Comments (10)
@adarob could you provide a minimal repro for this?
from flax.
Doesn't work:
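# Assumed imports for the pre-Linen flax.nn API used in this thread.
import jax
from flax import nn

@jax.jit
def rnd():
  # Both keys are drawn from the enclosing nn.stochastic scope.
  return (jax.random.randint(nn.make_rng(), (5,), 0, 10),
          jax.random.randint(nn.make_rng(), (5,), 0, 10))

with nn.stochastic(jax.random.PRNGKey(0)):
  for _ in range(5):
    print(rnd())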
Output:
(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
Works:
@jax.jit
def rnd(rng):
  with nn.stochastic(rng):
    return (jax.random.randint(nn.make_rng(), (5,), 0, 10),
            jax.random.randint(nn.make_rng(), (5,), 0, 10))

with nn.stochastic(jax.random.PRNGKey(0)):
  for _ in range(5):
    print(rnd(nn.make_rng()))
Output:
(DeviceArray([8, 5, 6, 6, 7], dtype=int32), DeviceArray([4, 9, 7, 1, 5], dtype=int32))
(DeviceArray([9, 3, 1, 6, 0], dtype=int32), DeviceArray([6, 0, 5, 3, 9], dtype=int32))
(DeviceArray([2, 7, 8, 8, 1], dtype=int32), DeviceArray([9, 2, 5, 0, 6], dtype=int32))
(DeviceArray([0, 1, 2, 8, 1], dtype=int32), DeviceArray([5, 4, 6, 1, 1], dtype=int32))
(DeviceArray([1, 8, 4, 8, 3], dtype=int32), DeviceArray([1, 3, 6, 6, 4], dtype=int32))
In that second example, you meant to write `rng` in place of `nn.make_rng()`, no?
I don't know which line you're referring to, but it looks like what I intended.
Ah, my apologies, I misread it the first time.
This is part of a larger issue concerning mixing state and JAX transformations. `nn.stochastic` should throw an exception in this case, because mixing JAX transformations with internal state is ambiguous. I will make a PR for this, but it might lead to some false positives that need to be fixed.
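To illustrate why the first repro prints the same numbers every time, here is a minimal sketch in plain `jax.random` (not the `nn.stochastic` API from this thread). The names `rnd_closed` and `rnd_explicit` are hypothetical. The assumption is that a key obtained outside the traced arguments is fixed when `jax.jit` traces the function, so every later call reuses it.

```python
import jax

key = jax.random.PRNGKey(0)

@jax.jit
def rnd_closed():
    # `key` is not an argument, so it is captured as a constant at trace time.
    return jax.random.randint(key, (5,), 0, 10)

@jax.jit
def rnd_explicit(k):
    # The key is a traced argument, so each call can receive a fresh one.
    return jax.random.randint(k, (5,), 0, 10)

# Every call to the closed-over version returns the same values.
first, second = rnd_closed(), rnd_closed()

# Passing explicitly split keys gives independent draws per call.
subkeys = jax.random.split(key, 3)
samples = [rnd_explicit(k) for k in subkeys]
```

Passing the key as an argument, as in the "Works" example above, keeps the randomness visible to the transformation.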
I think #125 is the PR that should address this.
Effectively, it should make your code, @adarob, throw an explicit error, and then you can decide how to deal with the PRNGs. E.g. if you're using vmap, you will have to explicitly choose whether to split the PRNG or reuse it.
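The split-or-reuse choice under vmap can be sketched with plain `jax.random` as follows. This is only an illustration under assumptions: `noisy` is a hypothetical function, and the code does not use `nn.stochastic`.

```python
import jax
import jax.numpy as jnp

def noisy(x, key):
    # Adds Gaussian noise drawn from `key` to a single example.
    return x + jax.random.normal(key, x.shape)

x = jnp.zeros((4, 3))
key = jax.random.PRNGKey(0)

# Option A: split the key so each batch element gets its own noise.
keys = jax.random.split(key, x.shape[0])
split_out = jax.vmap(noisy, in_axes=(0, 0))(x, keys)

# Option B: reuse one key, so every batch element gets identical noise.
shared_out = jax.vmap(noisy, in_axes=(0, None))(x, key)
```

Neither option is the obviously right default, which is presumably why an explicit error is preferable to a silent choice.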
Btw, we are also looking into automatically supporting things like stateful and stochastic in combination with JAX transforms, together with the Haiku folks and the JAX core team. But for now we just try to avoid silent errors.
@jheek assigning to you because I believe you're looking into this
`nn.stochastic` correctly throws an error but it does now extend into `init_by_shape` (as of PR #159).