Comments (8)
I'm going to close this because the original question is answered; if you want to chat about ideas for how to implement megablocks MoE, perhaps a dedicated discussion would be a better place. Thanks!
Description
I'm implementing the Mixtral model in JAX, and when I use jnp.nonzero it raises an error, so I can't init the params or get shapes.
Here's how I'm using it:
class FlaxMixtralBlocKSparesTop2MLPCollection(nn.Module):
    config: MixtralConfig
    dtype: jnp.dtype = jnp.bfloat16
    param_dtype: jnp.dtype = jnp.bfloat16
    precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest")

    def setup(self) -> None:
        self.layers = [
            FlaxMixtralBLockSparseTop2MLP(
                config=self.config,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
                precision=self.precision,
                name=str(i)
            )
            for i in range(self.config.num_local_experts)
        ]

    def __call__(
            self,
            expert_mask: chex.Array,
            hidden_states: chex.Array,
            routing_weights: chex.Array,
            batch_size: int,
            sequence_length: int,
            hidden_dim: int
    ) -> chex.Array:
        assert hidden_states.ndim == 2
        final_hidden_states = jnp.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype
        )
        for expert_idx, expert_layer in enumerate(self.layers):
            selected_mask = expert_mask[expert_idx]
            idx, top_x = jnp.nonzero(selected_mask)
            if top_x.shape[0] == 0:
                continue
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(
                current_state
            ) * routing_weights[top_x, idx, None]
            final_hidden_states = final_hidden_states.at[top_x].set(
                current_hidden_states + final_hidden_states[top_x]
            )
        return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)

Is there any recommendation or help I can get?
What jax/jaxlib version are you using?
JAX and jaxlib 0.4.20
Which accelerator(s) are you using?
CPU/TPU
Additional system info?
Linux
NVIDIA GPU info
No response
Hi - the issue is that jnp.nonzero creates a dynamically-shaped array (i.e. an array whose shape depends on the values in the array passed to it), and thus is incompatible with eval_shape.
You can address this by passing a static size argument to nonzero, in order to statically specify the size of the output array; for example:
import jax
import jax.numpy as jnp

def f1(x):
    return jnp.nonzero(x)

def f2(x):
    return jnp.nonzero(x, size=5)

x = jax.ShapeDtypeStruct(shape=(10,), dtype='float32')

jax.eval_shape(f1, x)  # error
jax.eval_shape(f2, x)  # ok
If the array passed to nonzero has fewer nonzero entries than the specified size, the results will be padded with zeros. If it has more nonzero entries, they will be truncated.
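For concreteness, here is a small illustration of that padding/truncation behavior (my own example; it also shows the optional fill_value argument for choosing the pad value):

import jax.numpy as jnp

x = jnp.array([0., 2., 0., 3., 0.])    # nonzero entries at indices 1 and 3

jnp.nonzero(x, size=4)                 # (Array([1, 3, 0, 0], dtype=int32),): padded with zeros
jnp.nonzero(x, size=1)                 # (Array([1], dtype=int32),): extra entries truncated
jnp.nonzero(x, size=4, fill_value=-1)  # (Array([ 1,  3, -1, -1], dtype=int32),): custom pad value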
This static shape requirement is fundamental to the design of JAX transformations; for more discussion, see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#dynamic-shapes.
Yes, I have tried that and noticed that with the size argument it works correctly, but then it's no longer dynamic, which defeats the purpose I'm using it for.
@erfanzar If you're trying to do megablocks-style MoE layers, I don't know if you actually need/want to be computing the nonzeros for each expert. Instead, you know that each token gets assigned exactly k experts, so the number of token/expert pairs actually is statically knowable.
The thing I haven't thought carefully about is whether you can actually implement the necessary blocksparse matmul efficiently in pure JAX. I suspect you can't do it memory-efficiently without a custom kernel but I'm not sure.
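For what it's worth, here is a minimal sketch of that static-shape observation (my own illustration, not megablocks itself; experts_fn and the other names are hypothetical): a top-k router whose output shapes are fixed at trace time, combined with the naive dense dispatch that runs every expert on every token. It avoids jnp.nonzero entirely, at exactly the memory cost mentioned above:

import jax
import jax.numpy as jnp

def route_top_k(router_logits, k=2):
    # Each token is assigned exactly k experts, so all shapes here are
    # static at trace time; nothing depends on the data values.
    weights, expert_ids = jax.lax.top_k(router_logits, k)    # both (tokens, k)
    return jax.nn.softmax(weights, axis=-1), expert_ids

def moe_dense_dispatch(hidden, router_logits, experts_fn, num_experts, k=2):
    # Naive dense formulation: run all experts on all tokens, then combine
    # with a one-hot mask. Statically shaped, but O(num_experts) work per token.
    weights, expert_ids = route_top_k(router_logits, k)
    one_hot = jax.nn.one_hot(expert_ids, num_experts)        # (tokens, k, E)
    combine = jnp.einsum('tk,tke->te', weights, one_hot)     # (tokens, E)
    expert_outs = experts_fn(hidden)                         # (E, tokens, hidden)
    return jnp.einsum('te,eth->th', combine, expert_outs)    # (tokens, hidden)

# Toy usage: each "expert" just scales its input.
T, H, E = 8, 4, 3
hidden = jax.random.normal(jax.random.PRNGKey(0), (T, H))
logits = jax.random.normal(jax.random.PRNGKey(1), (T, E))
scales = jnp.arange(1., E + 1.)
out = moe_dense_dispatch(hidden, logits,
                         lambda h: scales[:, None, None] * h[None],
                         num_experts=E)
print(out.shape)  # (8, 4)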
@davisyoshida I guess a scan function does the job, but anyway, have you seen anybody implement this in JAX, or should I figure it out myself?
@erfanzar I've actually been experimenting with implementing the megablocks stuff. So far I've made the forward pass in pure JAX (so backwards will work as well), and I also have pallas kernels for the DSD and SDD matmuls done. There's still quite a lot to do, but I can prioritize getting something shareable uploaded if you're interested.
I would be more than happy to connect with you in that case, but have you tried flax.linen.scan and conditionals to build it?
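For reference, here is a minimal sketch of the conditional flavor (my own illustration, not code from this thread): pick one expert per token with jax.lax.switch and vmap over tokens. Shapes stay static, but note that under vmap, switch evaluates every branch and selects, so it shares the memory/compute inefficiency discussed above:

import jax
import jax.numpy as jnp

# Toy stand-ins for experts; real ones would be Flax modules with parameters.
def make_expert(scale):
    return lambda x: scale * x

experts = [make_expert(s) for s in (1., 2., 3.)]

def apply_expert(expert_id, token):
    # The number of branches is static, so the output shape never depends on data.
    return jax.lax.switch(expert_id, experts, token)

tokens = jnp.ones((5, 4))
expert_ids = jnp.array([0, 2, 1, 0, 2])
out = jax.vmap(apply_expert)(expert_ids, tokens)
print(out.shape)  # (5, 4)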