jakevdp commented on June 9, 2024

I'm going to close this because the original question is answered; if you want to chat about ideas for how to implement megablocks MoE, perhaps a dedicated discussion would be a better place. Thanks!

erfanzar commented on June 9, 2024

Description

I'm implementing the Mixtral model in JAX, and when I use jnp.nonzero it raises an error, so I can't initialize the params or get their shapes.

Here's how I'm using it:

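from typing import Optional

import chex
import flax.linen as nn
import jax
import jax.numpy as jnp

# MixtralConfig and FlaxMixtralBLockSparseTop2MLP are defined elsewhere in the author's module.
class FlaxMixtralBlocKSparesTop2MLPCollection(nn.Module):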
    config: MixtralConfig
    dtype: jnp.dtype = jnp.bfloat16
    param_dtype: jnp.dtype = jnp.bfloat16
    precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest")

    def setup(self) -> None:
        self.layers = [
            FlaxMixtralBLockSparseTop2MLP(
                config=self.config,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
                precision=self.precision,
                name=str(i)
            )
            for i in range(self.config.num_local_experts)
        ]

    def __call__(
            self,
            expert_mask: chex.Array,
            hidden_states: chex.Array,
            routing_weights: chex.Array,
            batch_size: int,
            sequence_length: int,
            hidden_dim: int
    ) -> chex.Array:
        assert hidden_states.ndim == 2
        final_hidden_states = jnp.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype
        )

        for expert_idx, expert_layer in enumerate(self.layers):
            selected_mask = expert_mask[expert_idx]

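            # Fails under jit/eval_shape: the output length depends on the mask values.
            idx, top_x = jnp.nonzero(selected_mask)
            # With a static `size`, top_x.shape[0] is always `size`, so this check
            # would no longer skip experts that received no tokens.
            if top_x.shape[0] == 0:
                continue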

            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)

            current_hidden_states = expert_layer(
                current_state
            ) * routing_weights[top_x, idx, None]
            final_hidden_states = final_hidden_states.at[top_x].set(
                current_hidden_states + final_hidden_states[top_x]
            )

        return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)

Is there any recommendation or help I can get?

What jax/jaxlib version are you using?

JAX and jaxlib 0.4.20

Which accelerator(s) are you using?

CPU/TPU

Additional system info?

Linux

NVIDIA GPU info

No response

https://github.com/erfanzar/EasyDeL/blob/main/lib/python/EasyDel/modules/mixtral/modelling_mixtral_flax.py#L373

jakevdp commented on June 9, 2024

Hi - the issue is that jnp.nonzero creates a dynamically-shaped array (i.e. an array shape that depends on the values in the array passed to it), and thus is incompatible with eval_shape.
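A short illustration of what "dynamically-shaped" means here: two inputs with the same shape produce outputs of different lengths, and under jax.jit the tracer only knows the input's shape and dtype, so JAX cannot compute the output shape. This sketch assumes a recent 0.4.x JAX, where the failure surfaces as jax.errors.ConcretizationTypeError.

```python
import jax
import jax.numpy as jnp

a = jnp.array([1, 0, 0, 2])
b = jnp.array([1, 1, 1, 2])

# Same input shape (4,), different output shapes: the result depends on the values.
(ia,) = jnp.nonzero(a)  # indices 0 and 3
(ib,) = jnp.nonzero(b)  # indices 0, 1, 2 and 3

# Under jit (or eval_shape) only the shape and dtype of `a` are known while
# tracing, so the output length cannot be determined and JAX raises an error.
jit_failed = False
try:
    jax.jit(jnp.nonzero)(a)
except jax.errors.ConcretizationTypeError:
    jit_failed = True
```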

You can address this by passing a static size argument to nonzero, in order to statically specify the size of the output array; for example:

import jax
import jax.numpy as jnp

def f1(x):
  return jnp.nonzero(x)

def f2(x):
  return jnp.nonzero(x, size=5)

x = jax.ShapeDtypeStruct(shape=(10,), dtype='float32')

jax.eval_shape(f1, x)  # error
jax.eval_shape(f2, x)  # ok

If the array passed to nonzero has fewer nonzero entries than the specified size, the results will be padded with zeros. If it has more nonzero entries, they will be truncated.
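The padding and truncation behavior can be checked directly. jnp.nonzero also accepts a fill_value keyword, which is useful because the default padding value 0 is itself a valid index and cannot be told apart from a real hit. Keep in mind that JAX indexing treats -1 as the last element, so a -1 sentinel should be masked out before the indices are used for gathering or scattering.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0, 3, 0, 5, 7])  # nonzero entries at indices 1, 3 and 4

(padded,) = jnp.nonzero(x, size=5)                   # fewer hits than size: padded with 0
(sentinel,) = jnp.nonzero(x, size=5, fill_value=-1)  # pad with a distinguishable value
(truncated,) = jnp.nonzero(x, size=2)                # more hits than size: extras dropped

# With a static size the call also works inside jit.
(jitted,) = jax.jit(lambda v: jnp.nonzero(v, size=5, fill_value=-1))(x)
```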

This static shape requirement is fundamental to the design of JAX transformations; for more discussion, see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#dynamic-shapes.
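One common way to stay within static shapes for an MoE layer is to drop the data-dependent gather entirely: run every expert on every token and multiply by a dense [tokens, experts] routing matrix that is zero for experts a token did not pick. This is a swapped-in technique, not the code from the issue, and it costs num_experts times the FLOPs of a sparse layer. The stacked weight layout (w_in: [E, D, F], w_out: [E, F, D]) and the single SiLU layer per expert are hypothetical simplifications of Mixtral's gated MLP.

```python
import jax
import jax.numpy as jnp


def dense_moe(x, w_in, w_out, router_logits, k=2):
    """Static-shape MoE sketch: every expert sees every token, unselected ones get weight 0.

    x: [T, D]; w_in: [E, D, F]; w_out: [E, F, D]; router_logits: [T, E].
    """
    num_experts = router_logits.shape[-1]
    top_vals, top_ids = jax.lax.top_k(router_logits, k)   # [T, k]
    top_w = jax.nn.softmax(top_vals, axis=-1)              # normalize over the k picks
    # Dense [T, E] combine weights, zero for experts the token did not pick.
    combine = jnp.sum(jax.nn.one_hot(top_ids, num_experts) * top_w[..., None], axis=1)
    h = jax.nn.silu(jnp.einsum("td,edf->etf", x, w_in))   # [E, T, F]
    y = jnp.einsum("etf,efd->etd", h, w_out)              # [E, T, D]
    return jnp.einsum("etd,te->td", y, combine)            # [T, D]
```

Every shape here depends only on T, D, F, E and k, so the function works under jit and eval_shape; the trade-off is that no compute is saved by routing.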

erfanzar commented on June 9, 2024

Yes, I've tried that, and with the size argument it does work correctly, but then it's no longer dynamic, which defeats the purpose I'm using it for.
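The size argument can keep the issue's per-expert loop exact if size is an upper bound: top-k picks distinct experts per token, so one expert can receive each token at most once, and size=num_tokens never truncates. Padded slots then only need to be masked out. This sketch assumes expert_mask has shape [num_experts, k, num_tokens] (matching how the loop unpacks idx, top_x), and uses hypothetical stacked weights w_in/w_out with a single SiLU layer instead of the real expert module.

```python
import jax
import jax.numpy as jnp


def moe_static_nonzero(x, w_in, w_out, expert_mask, routing_weights):
    """Per-expert loop with a static `size` for jnp.nonzero (sketch).

    x: [T, D]; w_in: [E, D, F]; w_out: [E, F, D]
    expert_mask: [E, K, T], 1 where slot k of token t routes to expert e
    routing_weights: [T, K]
    """
    num_tokens = x.shape[0]
    out = jnp.zeros_like(x)
    for e in range(w_in.shape[0]):
        selected = expert_mask[e]                                      # [K, T]
        # Upper bound on hits for one expert, so nothing is ever truncated.
        idx, top_x = jnp.nonzero(selected, size=num_tokens)
        # Padded slots hold index 0 (a real token), so give them zero weight.
        valid = jnp.arange(num_tokens) < jnp.sum(selected)
        h = jax.nn.silu(x[top_x] @ w_in[e]) @ w_out[e]                 # [T, D]
        w = routing_weights[top_x, idx] * valid
        out = out.at[top_x].add(h * w[:, None])                       # padded rows add 0
    return out
```

This compiles, but each expert now processes num_tokens rows, so the compute matches the dense approach: the static size removes exactly the savings that nonzero was meant to provide.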

davisyoshida commented on June 9, 2024

@erfanzar If you're trying to do megablocks-style MoE layers, I don't know if you actually need/want to be computing the nonzeros for each expert. Instead, you know that each token gets assigned exactly k experts, so the number of token/expert pairs actually is statically knowable.
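A minimal sketch of that observation, assuming softmax-normalized top-k routing: with T tokens and k experts per token there are always exactly T * k (token, expert) pairs, so every array below has a static length. Only the values, such as the per-expert group sizes, depend on the data. The function name route_top_k is hypothetical.

```python
import jax
import jax.numpy as jnp


def route_top_k(router_logits, k):
    """Turn [T, E] router logits into T * k (token, expert, weight) pairs grouped by expert."""
    num_tokens, num_experts = router_logits.shape
    top_vals, expert_ids = jax.lax.top_k(router_logits, k)         # [T, k]
    weights = jax.nn.softmax(top_vals, axis=-1)
    flat_experts = expert_ids.reshape(-1)                            # [T * k], static length
    flat_tokens = jnp.repeat(jnp.arange(num_tokens), k)             # token id of each pair
    flat_weights = weights.reshape(-1)
    order = jnp.argsort(flat_experts)                                # make each expert's pairs contiguous
    group_sizes = jnp.bincount(flat_experts, length=num_experts)     # values vary, shape is [E]
    return flat_tokens[order], flat_experts[order], flat_weights[order], group_sizes


logits = jax.random.normal(jax.random.PRNGKey(2), (6, 4))
tokens, experts, weights, group_sizes = jax.jit(route_top_k, static_argnums=1)(logits, 2)
```

The sorted pairs plus group_sizes are the kind of input a grouped or block-sparse matmul consumes; how to run that matmul efficiently is a separate question.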

The thing I haven't thought carefully about is whether you can actually implement the necessary blocksparse matmul efficiently in pure JAX. I suspect you can't do it memory-efficiently without a custom kernel but I'm not sure.
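To make the memory concern concrete, here is a naive pure-JAX version built on the T * k pairs: it only spends FLOPs on the selected experts and keeps every shape static, but it gathers one weight matrix per pair, materializing [T * k, D, F] and [T * k, F, D] weight copies. This is an illustrative sketch under the same hypothetical stacked-weight layout as before, not the megablocks algorithm; a block-sparse kernel (for example in Pallas) is what avoids this copy.

```python
import jax
import jax.numpy as jnp


def gathered_moe(x, w_in, w_out, router_logits, k=2):
    """Per-pair MoE sketch: static shapes, sparse FLOPs, but memory-heavy weight gathers.

    x: [T, D]; w_in: [E, D, F]; w_out: [E, F, D]; router_logits: [T, E].
    """
    num_tokens = x.shape[0]
    top_vals, expert_ids = jax.lax.top_k(router_logits, k)
    weights = jax.nn.softmax(top_vals, axis=-1).reshape(-1)     # [P], P = T * k
    pair_experts = expert_ids.reshape(-1)                        # [P]
    pair_tokens = jnp.repeat(jnp.arange(num_tokens), k)          # [P]
    xs = x[pair_tokens]                                          # [P, D]
    # w_in[pair_experts] is [P, D, F]: one weight copy per pair, the memory cost.
    h = jax.nn.silu(jnp.einsum("pd,pdf->pf", xs, w_in[pair_experts]))
    y = jnp.einsum("pf,pfd->pd", h, w_out[pair_experts]) * weights[:, None]
    # Scatter-add each pair's output back to its token.
    return jnp.zeros_like(x).at[pair_tokens].add(y)
```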

erfanzar commented on June 9, 2024

@davisyoshida I guess a scan function would do the job. Anyway, have you seen anyone implement this in JAX, or should I figure it out myself?
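The scan idea can be sketched with jax.lax.scan over stacked expert weights: each step runs one expert on all tokens and accumulates its contribution weighted by that expert's column of a dense [T, E] routing matrix. Compared with running all experts at once, only one expert's activations are live per step, but the compute is still E times T rows. The stacked layout and single SiLU layer are hypothetical simplifications; for experts written as Flax modules, flax.linen.scan is the lifted counterpart.

```python
import jax
import jax.numpy as jnp


def scan_moe(x, w_in, w_out, combine):
    """Scan over experts (sketch).

    x: [T, D]; w_in: [E, D, F]; w_out: [E, F, D]
    combine: [T, E] routing weights, zero where a token did not pick the expert.
    """
    def step(acc, expert):
        w_i, w_o, gate = expert                      # one expert's weights and its [T] gate column
        y = jax.nn.silu(x @ w_i) @ w_o               # [T, D]
        return acc + y * gate[:, None], None

    out, _ = jax.lax.scan(step, jnp.zeros_like(x), (w_in, w_out, combine.T))
    return out
```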

davisyoshida commented on June 9, 2024

@erfanzar I've actually been experimenting with implementing the megablocks stuff. So far I've made the forward pass in pure JAX (so backwards will work as well), and I also have pallas kernels for the DSD and SDD matmuls done. There's still quite a lot to do, but I can prioritize getting something shareable uploaded if you're interested.

erfanzar commented on June 9, 2024

I'd be more than happy to connect with you in that case. Have you tried building it with flax.linen.scan and a conditional?
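The static-shape counterpart of the issue's `if top_x.shape[0] == 0: continue` is jax.lax.cond: both branches must return the same shape and dtype, and the predicate is evaluated at runtime, so an expert with no routed tokens can skip its matmuls. This is a sketch with hypothetical per-expert weights w_in/w_out and a [T] gate vector; note that under vmap with a batched predicate, lax.cond is converted to a select and both branches run.

```python
import jax
import jax.numpy as jnp


def expert_or_skip(x, w_in, w_out, gate):
    """Run one expert only if at least one token routes to it.

    x: [T, D]; w_in: [D, F]; w_out: [F, D]; gate: [T] routing weights (0 = not routed).
    """
    def compute(_):
        return (jax.nn.silu(x @ w_in) @ w_out) * gate[:, None]

    def skip(_):
        return jnp.zeros_like(x)  # same shape and dtype as the compute branch

    return jax.lax.cond(jnp.any(gate != 0), compute, skip, None)
```

Used as the body of a scan over experts, this keeps every shape static while still avoiding work for idle experts on backends that execute cond as real control flow.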
