Comments (4)
Does it make sense to have pmap as its own pass/function/whatever, or does it make more sense to provide simple constructs for people to accomplish the same thing?
from nx.
To be honest, I don’t know yet; probably the latter. I just put it here so we don’t forget to track it, but this is most likely a device/EXLA concern. I will move it. :)
@seanmor5 vmap may require changes to the underlying code being compiled, since it adds a new dimension to computations in order to make them batchable. So it is definitely a defn pass. It is hard to assess right now how big those changes are, since all of our operations are element-wise so far, and element-wise operations carry over to a batched input unchanged. But with JAX, this code:
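```python
import jax.numpy as jnp
from jax import make_jaxpr, vmap

def f(x, y):
    a = jnp.dot(x, y)
    b = jnp.tanh(a)
    return b

# A batch of 8 examples: each x is 2x3 and each y is 3x4.
xs = jnp.ones((8, 2, 3))
ys = jnp.ones((8, 3, 4))

print("f jaxpr")
print(make_jaxpr(f)(xs[0], ys[0]))  # trace a single example
print("vmap(f) jaxpr")
print(make_jaxpr(vmap(f))(xs, ys))  # trace the batched version
```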
prints:
```
f jaxpr
{ lambda ; a b.
  let c = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
                       precision=None ] a b
      d = tanh c
  in (d,) }
vmap(f) jaxpr
{ lambda ; a b.
  let c = dot_general[ dimension_numbers=(((2,), (1,)), ((0,), (0,)))
                       precision=None ] a b
      d = tanh c
  in (d,) }
```
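To make the vmap rewrite concrete, here is a small check, a sketch assuming a recent JAX install running on CPU. It compares vmap(f) against an explicit Python loop over the batch, and against a hand-written call to jax.lax.dot_general using the dimension numbers from the vmap(f) jaxpr above (contract lhs axis 2 with rhs axis 1, batch over axis 0 of both). The input values are arbitrary, chosen only so the outputs are not all identical.

```python
import jax.numpy as jnp
from jax import lax, vmap


def f(x, y):
    # Per-example computation: matrix product followed by tanh.
    return jnp.tanh(jnp.dot(x, y))


# Arbitrary, non-constant inputs: 8 examples of (2x3) @ (3x4).
xs = jnp.arange(8 * 2 * 3, dtype=jnp.float32).reshape(8, 2, 3) / 50.0
ys = jnp.arange(8 * 3 * 4, dtype=jnp.float32).reshape(8, 3, 4) / 100.0

# 1. vmap: the batch axis 0 is threaded through dot_general automatically.
batched = vmap(f)(xs, ys)

# 2. Explicit loop: apply f to each example and stack the results.
looped = jnp.stack([f(x, y) for x, y in zip(xs, ys)])

# 3. Hand-written batched dot_general with the vmap(f) dimension numbers:
#    ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)).
manual = jnp.tanh(
    lax.dot_general(xs, ys, dimension_numbers=(((2,), (1,)), ((0,), (0,))))
)

print(batched.shape)  # (8, 2, 4)
print(jnp.allclose(batched, looped, atol=1e-6),
      jnp.allclose(batched, manual, atol=1e-6))
```

Only the dot needed new dimension numbers; tanh, being element-wise, is applied to the batched tensor as-is. That is why code made only of element-wise operations says little about how much work a vmap pass would be.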
pmap, on the other hand, is about devices. We want to move the data to separate devices when sending it in, and then either read the results back from the multiple devices into a single binary or keep them as multiple references. So it is definitely a device-based operation. I will create a separate issue for tracking the device roadmap.
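For contrast, a rough sketch of the device side in JAX, done by hand rather than with pmap: split the batch across the local devices with jax.device_put, run f once per device, then either keep the per-device results as separate references or gather them into one host array. This only illustrates the two options described above; it assumes JAX's behavior of running a computation on the device its (committed) inputs live on, says nothing about how Nx or EXLA would implement this, and the variable names are illustrative. On a single-device machine, n is 1 and the batch has just one slice.

```python
import jax
import jax.numpy as jnp
import numpy as np


def f(x, y):
    # Same per-example computation as in the vmap example.
    return jnp.tanh(jnp.dot(x, y))


devices = jax.local_devices()
n = len(devices)  # 1 on a plain CPU install; more with several GPUs/TPUs

xs = jnp.ones((n, 2, 3))
ys = jnp.ones((n, 3, 4))

# "Send in": place slice i of the inputs on device i.
shards = [jax.device_put((x, y), d) for x, y, d in zip(xs, ys, devices)]

# Run f once per device; each call executes where its inputs live.
outs = [f(x, y) for x, y in shards]

# Option A: keep the results as separate per-device references (`outs`).
# Option B: read everything back into a single host array.
gathered = np.stack([np.asarray(o) for o in outs])
print(gathered.shape)  # (n, 2, 4)
```

pmap automates exactly this split/run/gather loop and compiles the per-device program once, which is why it reads as a device concern rather than a change to the traced computation.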
I have broken all remaining tasks out into separate issues. The pmap discussion is tied to #127.