
pytensor-federated

This package implements federated computing with PyTensor.

Using pytensor-federated, differentiable cost functions can be computed on federated nodes. Inputs and outputs are transmitted in binary via a bidirectional gRPC stream.

A client-side LogpGradOp is provided to conveniently embed federated compute operations in PyTensor graphs, such as a PyMC model.

The example code fits a simple Bayesian linear regression to data that is "private" to the federated compute process.

Run each command in its own terminal:

python demo_node.py
python demo_model.py
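
For orientation, here is a hedged sketch of what the client side of the demo could look like. The class names follow this README, but the module path, constructor signatures, port, and output structure are assumptions; see demo_model.py for the actual code.

import pymc as pm
from pytensor_federated import ArraysToArraysServiceClient, LogpGradOp

# Connect to the federated node started by demo_node.py
# (host/port and constructor signature are assumptions).
client = ArraysToArraysServiceClient("127.0.0.1", 50051)
remote_logp = LogpGradOp(client)

with pm.Model():
    slope = pm.Normal("slope")
    intercept = pm.Normal("intercept")
    # Assumed to return the scalar log-likelihood evaluated on the
    # node's private data; the gradient Op makes it differentiable.
    pm.Potential("federated_logp", remote_logp(slope, intercept))
    idata = pm.sample()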

Architecture

pytensor-federated is designed as a generalizable framework for federated computing with gRPC, but it ships with implementations for PyTensor, and specifically for Bayesian inference use cases. This is reflected in the implementation: the most basic gRPC service, the ArraysToArraysService, is wrapped by a few flavors tailored to common Bayesian inference workflows.

At the core, everything is built around an ArraysToArrays gRPC service, which takes any number of (NumPy) arrays as parameters, and returns any number of (NumPy) arrays as outputs. The arrays can have arbitrary dtype or shape, as long as the buffer interface is supported (meaning dtype=object doesn't work, but datetime dtypes are ok).

This ArraysToArraysService can be used to wrap arbitrary model functions, making it possible to run model simulations and MCMC/optimization on different machines. The protobuf files that specify the data types and the gRPC interface can be compiled to other programming languages, so the model could be implemented in, say, C++, while MCMC/optimization runs in Python.
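
As a sketch of the idea, any NumPy-in/NumPy-out callable could be served. The service name follows the text above, but the registration API shown here is an assumption; demo_node.py shows the real entry point.

import numpy as np
from pytensor_federated import ArraysToArraysService  # module path assumed

# A toy simulation: exponential decay parametrized by theta = [rate, initial].
def simulate(timepoints: np.ndarray, theta: np.ndarray) -> np.ndarray:
    return theta[1] * np.exp(-theta[0] * timepoints)

# The wrapped callable takes any number of arrays and returns a tuple of arrays.
service = ArraysToArraysService(lambda *arrays: (simulate(*arrays),))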

For the Bayesian inference or optimization use case, it helps to first understand the inputs and outputs of the underlying computation graph. For example, parameter estimation with a differential equation model requires...

  • observations to which the model should be fitted
  • timepoints at which there were observations
  • parameters (including initial states) theta, some of which are to be estimated

From timepoints and parameters theta, the model predicts trajectories. Together with observations, these predictions are fed into some kind of likelihood function, which produces a scalar log-likelihood as the output.

Different sub-graphs of this example could be wrapped by an ArraysToArraysService:

  • [theta,] -> [log-likelihood,]
  • [timepoints, theta] -> [trajectories,]
  • [timepoints, observations, theta] -> [log-likelihood,]

If the entire model is differentiable, one can even return gradients. For example, with a linear model: [slope, intercept] -> [LL, dLL_dslope, dLL_dintercept].
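
For the linear-model example, such a function boils down to a few lines of NumPy. This is a minimal sketch with made-up data; the Gaussian normalization constant is omitted.

import numpy as np

# The node's "private" data (illustrative values only).
x = np.linspace(0.0, 1.0, 20)
y = 2.0 * x + 0.5 + np.random.normal(scale=0.1, size=x.size)
sigma = 0.1

def logp_and_grads(slope: float, intercept: float):
    """Gaussian log-likelihood of a linear model and its gradient."""
    resid = y - (slope * x + intercept)
    LL = -0.5 * np.sum((resid / sigma) ** 2)  # up to an additive constant
    dLL_dslope = np.sum(resid * x) / sigma**2
    dLL_dintercept = np.sum(resid) / sigma**2
    return LL, dLL_dslope, dLL_dintercept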

The role of PyTensor here is purely technical: PyTensor is a graph computation framework that implements auto-differentiation. Wrapping the ArraysToArraysServiceClient in PyTensor Ops simply makes it easier to build more sophisticated compute graphs. PyTensor is also the computation backend of PyMC, the most popular framework for Bayesian inference in Python.

Installation & Contributing

conda env create -f environment.yml

Additional dependencies are needed to compile the protobufs:

conda install -c conda-forge protobuf
pip install --pre betterproto[compiler]
python protobufs/generate.py

Set up pre-commit for automated code style enforcement:

pip install pre-commit
pre-commit install


pytensor-federated's Issues

Implement graph rewrites to vectorize asynchronous Ops

Design option: RPC-Aware Ops (old idea)

This was the first idea.
+ [ ] Find `Op`s in the graph that use the `ArraysToArraysServiceClient` and can be parallelized (they must not depend on each other). This can be implemented by adding a mixin interface so that such Ops can be identified via `isinstance(op, ArraysToArraysOp)`.
+ [ ] Write a `ParallelArraysToArraysOp` that keeps a list of streams and runs evaluations in parallel.
+ [ ] Do a subgraph replacement where the independent `ArraysToArraysOp` nodes are substituted by a subgraph that routes the inputs to a new `ParallelArraysToArraysOp` node and redistributes the outputs.

Design option: Async Ops (preferred)

This would be RPC-unaware and more generic overall.

  • #26
  • Implement async homologues of the function-wrapping convenience Ops: AsyncArraysToArraysOp, AsyncLogpOp, AsyncLogpGradOp.
  • Write a class ParallelAsyncOp(Op), similar to aesara.graph.basic.Composite, that parallelizes the .perform_async() calls of a bunch of AsyncOps.
  • Write a graph optimization that finds AsyncOps and merges them into a ParallelAsyncOp, as sketched below.
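
A minimal sketch of the preferred design; this is purely hypothetical, since none of these classes exist yet.

import asyncio
from pytensor.graph.op import Op

class AsyncOp(Op):
    """Marker base: subclasses implement their computation as a coroutine."""

    async def perform_async(self, node, inputs, output_storage):
        raise NotImplementedError

    def perform(self, node, inputs, output_storage):
        # Naive synchronous fallback; breaks if a loop is already running,
        # which is exactly the problem discussed in the next issue.
        asyncio.run(self.perform_async(node, inputs, output_storage))

# A ParallelAsyncOp would then asyncio.gather() the perform_async()
# calls of several independent AsyncOp nodes.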

Avoid all uses of `asyncio.get_event_loop` and `loop.run_until_complete`

The problem with these is that they don't work if the loop is already running.

This is a stark and inconvenient contrast to C#, where one can easily nest asynchronous and synchronous code:

using System.Threading.Tasks;

void MySyncFunction()
{
    // Blocking on an async Task from synchronous code is fine in C#.
    Task task = Task.Delay(1000);
    task.Wait();
}

async Task OuterAsync()
{
    await Task.Delay(500);
    MySyncFunction();
    await Task.Delay(500);
}

OuterAsync().Wait();

Naively, the Python equivalent would be 👇, but this does not work.

import asyncio


def my_sync_function():
    coro = asyncio.sleep(1000)
    asyncio.get_event_loop().run_until_complete(coro)


# This is fine:
my_sync_function()


async def outer_async():
    await asyncio.sleep(500)
    my_sync_function()
    await asyncio.sleep(500)

asyncio.run(outer_async())  # RuntimeError: Loop is already running. (In my_sync_function.)

Think: Once async, always async.

Instead of implementing a synchronous function that uses asyncio.get_event_loop + loop.run_until_complete or asyncio.run, the following should be implemented:

  • Each function that calls coroutines should be async and use await inside.
  • It should be wrapped by a synchronous version that uses asyncio.get_event_loop + loop.run_until_complete.

Only the ArraysToArraysServiceClient.__call__ method should resort to the trick of launching a separate thread to run evaluate_async, so that no new event loop is started on the main thread, at least while a loop is already running.
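
A minimal sketch of that thread trick, assuming nothing about the actual client internals:

import asyncio
from concurrent.futures import ThreadPoolExecutor

def run_coro_blocking(coro):
    """Run a coroutine to completion, even if a loop is already running."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: safe to start one.
        return asyncio.run(coro)
    # A loop is already running: run the coroutine in a fresh event
    # loop on a worker thread and block until it completes.
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()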

Make `ArraysToArraysServiceClient` parallelizable

Running PyMC's pm.sample(cores=2) currently breaks during pickling:

Traceback (most recent call last):
  File "...\aesara-federated\demo_model.py", line 38, in <module>
    run_model()
  File "...\aesara-federated\demo_model.py", line 31, in run_model
    idata = pm.sample(tune=500, draws=200)
  File "...\aefenv\lib\site-packages\pymc\sampling.py", line 607, in sample
    mtrace = _mp_sample(**sample_args, **parallel_args)
  File "...\aefenv\lib\site-packages\pymc\sampling.py", line 1520, in _mp_sample
    sampler = ps.ParallelSampler(
  File "...\aefenv\lib\site-packages\pymc\parallel_sampling.py", line 415, in __init__
    step_method_pickled = cloudpickle.dumps(step_method, protocol=-1)
  File "...\aefenv\lib\site-packages\cloudpickle\cloudpickle_fast.py", line 73, in dumps
    cp.dump(obj)
  File "...\aefenv\lib\site-packages\cloudpickle\cloudpickle_fast.py", line 633, in dump
    return Pickler.dump(self, obj)
TypeError: cannot pickle '_OverlappedFuture' object
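
One common remedy, sketched here as an assumption rather than the project's actual fix, is to exclude live connection objects from the pickled state and re-open them lazily after unpickling:

class PicklableClientMixin:
    def __getstate__(self):
        state = self.__dict__.copy()
        # Live channels/streams hold OS handles (the '_OverlappedFuture'
        # from the traceback) that cloudpickle cannot serialize; drop them
        # and let the client reconnect on first use after unpickling.
        # (The attribute names here are assumptions.)
        state["_channel"] = None
        state["_lazy_stream"] = None
        return state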

Let `ArraysToArraysServiceClient` choose from a list of possible servers

Instead of passing just one host and port combination, the ArraysToArraysServiceClient could take a list of servers to choose from.

Each server must, of course, behave identically.

Doing this allows for a failover mechanism (#9), but also enables client-side load balancing in situations where the ArraysToArraysServiceClient is forked/spawned into tens or hundreds of copies.

For the load balancing, we probably need some get_number_of_connected_clients endpoint on the server, so the clients can do something like np.argmin(np.random.permutation(options)), as sketched below.
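
Illustrative selection logic, assuming the hypothetical get_number_of_connected_clients endpoint returns each server's current load:

import numpy as np

def pick_server(servers, client_counts):
    """Pick the least-loaded server; a random permutation breaks ties randomly."""
    order = np.random.permutation(len(servers))
    best = order[np.argmin(np.asarray(client_counts)[order])]
    return servers[best]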

Add CI test pipeline

  • Write an environment.yml
  • Add a CI pipeline
  • Parametrize the PyMC version like it's done in the PyMC CI

Apply `at.as_tensor` to `make_node` inputs automatically

Let `isinstance(logp_op, LogpOp)` hold, with an underlying function taking just one scalar input. Currently, logp_op(2) raises a TypeError: The 'inputs' argument to Apply must contain Variable instances, not 2.

However, at.log(2) works just fine, so the error might be unexpected and inconvenient.

We should just apply at.as_tensor to the inputs automatically.
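
A sketch of the proposed fix inside make_node (hypothetical; the real Op has more logic):

import pytensor.tensor as at
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op

class LogpOp(Op):
    def make_node(self, *inputs):
        # Coerce plain numbers/arrays into tensor Variables,
        # just like at.log(2) does internally.
        inputs = [at.as_tensor(i) for i in inputs]
        return Apply(self, inputs, [at.dscalar()])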

Send consecutive evaluation requests to different servers

If the service client was given a server pool, it should not send consecutive calls to the same server.

In the following example the evaluation should parallelize across two servers.

t1 = client.evaluate_async(...)
t2 = client.evaluate_async(...)
await t1
await t2

A workaround is to avoid this in a model by creating new service clients and Op wrappers for each symbolic call to the remote model.

The solution here might look like some kind of queuing, where new streams are opened when needed, in the order determined by the load balancing.
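
As a minimal stand-in for such a mechanism, a round-robin dispatch over the server pool (a purely hypothetical design, named plainly as such):

import itertools

class RoundRobinPool:
    """Dispatch consecutive evaluate_async calls to different servers."""

    def __init__(self, clients):
        self._cycle = itertools.cycle(clients)

    def evaluate_async(self, *inputs):
        # Consecutive calls hit different servers, so the two awaits
        # in the example above can actually run in parallel.
        return next(self._cycle).evaluate_async(*inputs)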

Come up with a reconnect/failover mechanism

This is the relevant traceback when the server disconnects while streaming:

  File "...\aesara_federated\common.py", line 131, in evaluate
    logp, *gradients = self._client.evaluate(*inputs, use_stream=use_stream)
  File "...\aesara_federated\service.py", line 203, in evaluate
    output = loop.run_until_complete(eval_task)
  File "...\aefenv\lib\asyncio\base_events.py", line 646, in run_until_complete
    return future.result()
  File "...\aesara_federated\service.py", line 219, in _streamed_evaluate
    response = await self._lazy_stream.recv_message()
  File "...\aefenv\lib\site-packages\grpclib\client.py", line 427, in recv_message
    with self._wrapper:
  File "...\aefenv\lib\site-packages\grpclib\utils.py", line 70, in __exit__
    raise self._error
grpclib.exceptions.StreamTerminatedError: Connection lost

And this is the error when each request is sent as an independent message:

  File "...\aesara_federated\common.py", line 131, in evaluate
    logp, *gradients = self._client.evaluate(*inputs, use_stream=use_stream)
  File "...\aesara_federated\service.py", line 203, in evaluate
    output = loop.run_until_complete(eval_task)
  File "...\aefenv\lib\asyncio\base_events.py", line 646, in run_until_complete
    return future.result()
  File "...\aesara_federated\rpc.py", line 54, in evaluate
    return await self._unary_unary(
  File "...\aefenv\lib\site-packages\betterproto\grpc\grpclib_client.py", line 85, in _unary_unary
    response = await stream.recv_message()
  File "...\aefenv\lib\site-packages\grpclib\client.py", line 425, in recv_message
    await self.recv_initial_metadata()
  File "...\aefenv\lib\site-packages\grpclib\client.py", line 367, in recv_initial_metadata
    with self._wrapper:
  File "...\aefenv\lib\site-packages\grpclib\utils.py", line 70, in __exit__
    raise self._error
grpclib.exceptions.StreamTerminatedError: Connection lost

⚠ Note that with the demo example, use_stream=True takes 40 seconds for the parallelized MCMC sampling while use_stream=False takes 51 seconds.
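
A retry wrapper could be a starting point. This is an assumption about what failover might look like, not existing behavior; the reconnect logic itself, e.g. via the server pool from #9, is the open question.

from grpclib.exceptions import StreamTerminatedError

def evaluate_with_retry(client, *inputs, retries=3, use_stream=True):
    """Retry an evaluation when the stream is terminated mid-flight."""
    for attempt in range(retries):
        try:
            return client.evaluate(*inputs, use_stream=use_stream)
        except StreamTerminatedError:
            if attempt == retries - 1:
                raise
            # A real implementation would reconnect, or fail over to
            # another server from the pool, before retrying.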
