
pytensor-federated's Issues

Come up with a reconnect/failover mechanism

This is the relevant traceback when the server disconnects while streaming:

  File "...\aesara_federated\common.py", line 131, in evaluate
    logp, *gradients = self._client.evaluate(*inputs, use_stream=use_stream)
  File "...\aesara_federated\service.py", line 203, in evaluate
    output = loop.run_until_complete(eval_task)
  File "...\aefenv\lib\asyncio\base_events.py", line 646, in run_until_complete
    return future.result()
  File "...\aesara_federated\service.py", line 219, in _streamed_evaluate
    response = await self._lazy_stream.recv_message()
  File "...\aefenv\lib\site-packages\grpclib\client.py", line 427, in recv_message
    with self._wrapper:
  File "...\aefenv\lib\site-packages\grpclib\utils.py", line 70, in __exit__
    raise self._error
grpclib.exceptions.StreamTerminatedError: Connection lost

And this is the error when each request is sent as an independent message:

  File "...\aesara_federated\common.py", line 131, in evaluate
    logp, *gradients = self._client.evaluate(*inputs, use_stream=use_stream)
  File "...\aesara_federated\service.py", line 203, in evaluate
    output = loop.run_until_complete(eval_task)
  File "...\aefenv\lib\asyncio\base_events.py", line 646, in run_until_complete
    return future.result()
  File "...\aesara_federated\rpc.py", line 54, in evaluate
    return await self._unary_unary(
  File "...\aefenv\lib\site-packages\betterproto\grpc\grpclib_client.py", line 85, in _unary_unary
    response = await stream.recv_message()
  File "...\aefenv\lib\site-packages\grpclib\client.py", line 425, in recv_message
    await self.recv_initial_metadata()
  File "...\aefenv\lib\site-packages\grpclib\client.py", line 367, in recv_initial_metadata
    with self._wrapper:
  File "...\aefenv\lib\site-packages\grpclib\utils.py", line 70, in __exit__
    raise self._error
grpclib.exceptions.StreamTerminatedError: Connection lost

⚠ Note that with the demo example, use_stream=True takes 40 seconds for the parallelized MCMC sampling while use_stream=False takes 51 seconds.
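A failover mechanism could retry the call on another server from the pool whenever the stream dies. A minimal sketch of the idea (all names here are hypothetical; `ConnectionError` stands in for `grpclib.exceptions.StreamTerminatedError`, and each entry in `servers` stands in for an async callable that would re-open a stream on each attempt):

```python
import asyncio


async def evaluate_with_failover(servers, payload, retries=3):
    """Retry the evaluation on the next server when a connection drops.

    `servers` is a list of async callables; a real client would re-open
    a stream to the respective server on each attempt.
    """
    last_exc = None
    for attempt in range(retries):
        server = servers[attempt % len(servers)]
        try:
            return await server(payload)
        except ConnectionError as exc:  # stand-in for StreamTerminatedError
            last_exc = exc
    raise last_exc
```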

Send consecutive evaluation requests to different servers

If the service client was given a server pool, it should not send consecutive calls to the same server.

In the following example the evaluation should parallelize across two servers.

t1 = client.evaluate_async(...)
t2 = client.evaluate_async(...)
await t1
await t2

A workaround is to avoid this pattern in a model by creating a new service client and Op-wrapper for each symbolic call to the remote model.

The solution here might look like some kind of queueing, where new streams are opened as needed, in the order determined by the load balancing.
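One simple way to keep consecutive calls off the same server is a round-robin dispatcher in front of the pool. A minimal sketch (class and method names hypothetical; a real implementation would additionally track which streams are busy and open new ones on demand):

```python
import asyncio
import itertools


class RoundRobinPool:
    """Dispatch consecutive evaluate_async calls to different servers."""

    def __init__(self, clients):
        # cycle endlessly through the pool in a fixed order
        self._cycle = itertools.cycle(clients)

    async def evaluate_async(self, *inputs):
        client = next(self._cycle)
        return await client(*inputs)
```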

Add CI test pipeline

  • Write an environment.yml
  • Add CI pipeline
  • Parametrize PyMC version like done in PyMC CI

Implement graph rewrites to vectorize asynchronous Ops

Design option: RPC-Aware Ops (old design idea)

This was the first idea.
+ [ ] Find `Op`s in the graph that use the `ArraysToArraysServiceClient` and can be parallelized (they must not depend on each other). This can be implemented by adding a mixin-interface by which an `isinstance(op, ArraysToArraysOp)` can be identified.
+ [ ] Write a `ParallelArraysToArraysOp` that keeps a list of streams and runs evaluations in parallel.
+ [ ] Do a subgraph replacement where the independent `ArraysToArraysOp`s nodes are substituted by a subgraph that routes the inputs to a new `ParallelArraysToArraysOp` node and redistributes the outputs.

Design option: Async Ops (preferred)

This would be RPC-unaware and more generic overall.

  • #26
  • Implement async homologues to the function-wrapping convenience-Ops: AsyncArraysToArraysOp, AsyncLogpOp, AsyncLogpGradOp.
  • Write a class ParallelAsyncOp(Op) similar to aesara.graph.basic.Composite that parallelizes the .perform_async() calls of a bunch of AsyncOps.
  • Write a graph optimization that finds AsyncOps and merges them into a ParallelAsyncOp.
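The merged node could drive the individual .perform_async() calls with asyncio.gather. A rough sketch under the assumption that each AsyncOp exposes a perform_async coroutine (both class names below are hypothetical stand-ins, not the real Aesara Op API):

```python
import asyncio


class AsyncOpSketch:
    """Stand-in for an AsyncOp whose perform_async awaits a remote call."""

    async def perform_async(self, x):
        await asyncio.sleep(0)  # the real Op would await an RPC here
        return x + 1


class ParallelAsyncOpSketch:
    """Runs the perform_async calls of several AsyncOps concurrently."""

    def __init__(self, ops):
        self.ops = ops

    def perform(self, inputs):
        async def gather_all():
            coros = (op.perform_async(x) for op, x in zip(self.ops, inputs))
            return await asyncio.gather(*coros)

        return asyncio.run(gather_all())
```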

Avoid all uses of `asyncio.get_event_loop` and `loop.run_until_complete`

The problem with these is that they don't work if the loop is already running.

This is a stark and inconvenient contrast to C#, where one can easily nest asynchronous and synchronous code:

void MySyncFunction()
{
    // Blocking synchronously on a Task is fine in C#
    Task task = Task.Delay(1000);
    task.Wait();
}

async Task OuterAsync()
{
    await Task.Delay(500);
    MySyncFunction();
    await Task.Delay(500);
}

void Main()
{
    OuterAsync().Wait();
}

Naively, the Python equivalent would be 👇, but this does not work.

import asyncio


def my_sync_function():
    coro = asyncio.sleep(1)  # 1000 ms, matching Task.Delay(1000)
    asyncio.get_event_loop().run_until_complete(coro)


# This is fine (no event loop is running yet):
my_sync_function()


async def outer_async():
    await asyncio.sleep(0.5)
    my_sync_function()
    await asyncio.sleep(0.5)


# RuntimeError: This event loop is already running (raised in my_sync_function)
asyncio.run(outer_async())

Think: Once async, always async.

Instead of implementing a synchronous function that uses `asyncio.get_event_loop` + `loop.run_until_complete` or `asyncio.run`, the following should be implemented:

  • Each function that calls coroutines should be async and use await inside.
  • It should be wrapped by a synchronous version that uses `asyncio.get_event_loop` + `loop.run_until_complete`.

Only the `ArraysToArraysServiceClient.__call__` method should do the trick of launching a separate thread to run the `evaluate_async` coroutine without starting a new event loop on the main thread, at least when a loop is already running.
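The wrapper pattern can be sketched as follows; `evaluate_async` is a placeholder for the real coroutine, and the thread-based branch is the trick described above (blocking the running loop while the worker thread finishes is the accepted tradeoff here):

```python
import asyncio
import concurrent.futures


async def evaluate_async(x):
    # placeholder for the real coroutine that talks to the server
    await asyncio.sleep(0)
    return 2 * x


def evaluate(x):
    """Synchronous wrapper that works with or without a running loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # no loop running on this thread: drive one ourselves
        return asyncio.run(evaluate_async(x))
    # a loop is already running: run the coroutine on a separate thread
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, evaluate_async(x)).result()
```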

Apply `at.as_tensor` to `make_node` inputs automatically

Let `logp_op` be a `LogpOp` instance whose underlying function takes just one scalar input. Currently `logp_op(2)` raises `TypeError: The 'inputs' argument to Apply must contain Variable instances, not 2`.

However, `at.log(2)` works just fine, so the error is unexpected and inconvenient.

We should simply apply `at.as_tensor` to the inputs automatically.
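The coercion pattern, illustrated with a minimal pure-Python stand-in for `Variable` and `at.as_tensor` (everything below is illustrative, not the real Aesara API):

```python
class Variable:
    """Minimal stand-in for a symbolic tensor Variable."""

    def __init__(self, data):
        self.data = data


def as_tensor(x):
    """Stand-in for at.as_tensor: wrap raw values, pass Variables through."""
    return x if isinstance(x, Variable) else Variable(x)


class LogpOpSketch:
    def make_node(self, *inputs):
        # coerce first, so logp_op(2) behaves like at.log(2)
        inputs = [as_tensor(i) for i in inputs]
        assert all(isinstance(i, Variable) for i in inputs)
        return inputs
```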

Let `ArraysToArraysServiceClient` choose from a list of possible servers

Instead of passing just one host/port combination, the ArraysToArraysServiceClient could take a list of servers to choose from.

Each server must, of course, behave identically.

Doing this allows for a failover mechanism (#9), but also enables client-side load balancing in situations where the ArraysToArraysServiceClient is forked/spawned into tens or hundreds of copies.

For the load balancing, we probably need some get_number_of_connected_clients endpoint on the server so that each client can pick a least-loaded server with random tie-breaking (something like np.argmin(np.random.permutation(options))).
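Client-side, the pick could look like the following pure-Python sketch (function and endpoint names are hypothetical; random tie-breaking replaces the np.argmin/np.random.permutation shorthand):

```python
import random


def pick_server(loads):
    """Return the index of a least-loaded server, breaking ties randomly.

    `loads` would come from querying a hypothetical
    get_number_of_connected_clients endpoint on each server in the pool.
    """
    least = min(loads)
    candidates = [i for i, load in enumerate(loads) if load == least]
    return random.choice(candidates)
```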

Make `ArraysToArraysServiceClient` parallelizable

Running a PyMC pm.sample(cores=2) currently breaks during pickling of the step method:

Traceback (most recent call last):
  File "...\aesara-federated\demo_model.py", line 38, in <module>
    run_model()
  File "...\aesara-federated\demo_model.py", line 31, in run_model
    idata = pm.sample(tune=500, draws=200)
  File "...\aefenv\lib\site-packages\pymc\sampling.py", line 607, in sample
    mtrace = _mp_sample(**sample_args, **parallel_args)
  File "...\aefenv\lib\site-packages\pymc\sampling.py", line 1520, in _mp_sample
    sampler = ps.ParallelSampler(
  File "...\aefenv\lib\site-packages\pymc\parallel_sampling.py", line 415, in __init__
    step_method_pickled = cloudpickle.dumps(step_method, protocol=-1)
  File "...\aefenv\lib\site-packages\cloudpickle\cloudpickle_fast.py", line 73, in dumps
    cp.dump(obj)
  File "...\aefenv\lib\site-packages\cloudpickle\cloudpickle_fast.py", line 633, in dump
    return Pickler.dump(self, obj)
TypeError: cannot pickle '_OverlappedFuture' object
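One common fix is to exclude the live connection objects from the pickled state and reconnect after unpickling. A sketch of the `__getstate__`/`__setstate__` pattern (the class and its attributes are illustrative; the real client's internals differ):

```python
import pickle


class ClientSketch:
    """Illustrates __getstate__/__setstate__ for a client holding
    unpicklable connection state (e.g. grpc channels, event loops)."""

    def __init__(self, host, port):
        self.host = host
        self.port = port
        self._channel = object()  # stand-in for an unpicklable channel

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_channel"] = None  # drop the live connection
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # the real client would lazily reconnect on first use
```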
