
Comments (19)

ghostplant commented on September 20, 2024

You can follow this example: https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_demo.py, which can be executed with: python3 -m tutel.examples.helloworld_demo --batch_size=16
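For reference, a custom expert in that demo is shaped roughly like the sketch below. The constructor and forward signatures follow the demo; the experts-dict keys in the commented construction call are my reading of helloworld_demo.py and may differ across tutel versions.

    import torch
    from tutel import moe as tutel_moe

    class CustomExpertDemo(torch.nn.Module):
        def __init__(self, model_dim, local_experts, sharded_count, my_config):
            super().__init__()
            # One weight matrix per local expert, batched along dim 0.
            self.W = torch.nn.Parameter(torch.randn(local_experts, model_dim, model_dim) * 0.02)

        def forward(self, x, ctx):
            # x: [local_experts, capacity, model_dim]
            return torch.matmul(x, self.W)

    # Assumed construction call (key names taken from the demo, unverified):
    # moe = tutel_moe.moe_layer(
    #     gate_type={'type': 'top', 'k': 2},
    #     model_dim=1024,
    #     experts={'type': 'custom', 'module': CustomExpertDemo,
    #              'count_per_node': 2, 'my_config': None},
    # )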

zws98 commented on September 20, 2024

thanks a lot!

zws98 commented on September 20, 2024

What if I want to feed another parameter into "class CustomExpertDemo(torch.nn.Module):"? How should I revise the code in tutel?

zws98 commented on September 20, 2024

e.g., def forward(self, x, ctx, a_new_param):

ghostplant commented on September 20, 2024

Is that a static parameter that can be set just in the __init__ function of CustomExpertDemo?

zws98 commented on September 20, 2024

No, it is a learnable parameter initialized outside the class "CustomExpertDemo".

ghostplant commented on September 20, 2024

Tutel still needs a few API upgrades to meet your requirement.

zws98 commented on September 20, 2024

Thanks. Is there a way to modify it after installing tutel (e.g., revising xx.py in the installed package)?

ghostplant commented on September 20, 2024

You need to feed the extra argument data here: https://github.com/microsoft/tutel/blob/main/tutel/impls/moe_layer.py#L238, where self.experts is the layer object created from your custom CustomExpertDemo.

You also need to extend the corresponding argument list of the forward function to match the data you feed: https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_demo.py#L101

If you cannot clone and install tutel from source after applying the changes above, you will have to find the location of the installed file, probably at /usr/..../tutel/impls/moe_layer.py, and apply the changes there.
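Concretely, the two edits might look like the sketch below. This is illustrative only: the exact call site in moe_layer.py varies across tutel versions, and the name extra_param is hypothetical, not tutel API.

    # 1) In tutel/impls/moe_layer.py (around the linked line), thread your
    #    extra tensor through to the experts call, roughly:
    #        y = self.experts(dispatched_input, self)              # before
    #        y = self.experts(dispatched_input, self, extra_param) # after
    #
    # 2) Extend the forward signature of your custom expert to match:
    import torch

    class CustomExpertDemo(torch.nn.Module):
        def __init__(self, model_dim, local_experts, sharded_count, my_config):
            super().__init__()
            self.W = torch.nn.Parameter(torch.randn(local_experts, model_dim, model_dim) * 0.02)

        def forward(self, x, ctx, extra_param):
            # extra_param: a learnable tensor created and owned outside this
            # class; it arrives here only because moe_layer was patched above.
            return torch.matmul(x, self.W) + extra_param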

zws98 commented on September 20, 2024

Thanks a lot.

zws98 commented on September 20, 2024

When I use the custom expert, it stops here:
if ctx.sharded_count > 1:
raise Exception("sharded_count > 1 is not implemented within this expert, Model parallel is disabled.")

import math
import torch
import torch.nn as nn
from torch.nn import init

class CustomExpert_lora(torch.nn.Module):
    def __init__(self, model_dim, local_experts, sharded_count, my_config, act_layer=nn.GELU):
        super().__init__()
        self.r = 8
        self.scale = 1 / math.sqrt(self.r)
        # Batched LoRA factors, one pair per local expert:
        # lora_A*: [local_experts, r, model_dim], lora_B*: [local_experts, model_dim, r]
        self.lora_A1 = torch.nn.Parameter(torch.empty(local_experts, self.r, model_dim))
        self.lora_B1 = torch.nn.Parameter(torch.empty(local_experts, model_dim, self.r))
        self.act = act_layer()
        self.lora_A2 = torch.nn.Parameter(torch.empty(local_experts, self.r, model_dim))
        self.lora_B2 = torch.nn.Parameter(torch.empty(local_experts, model_dim, self.r))
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_normal_(self.lora_A1)
        self.lora_A1.data *= self.scale
        init.constant_(self.lora_B1, 0)
        init.kaiming_normal_(self.lora_A2)
        self.lora_A2.data *= self.scale
        init.constant_(self.lora_B2, 0)

    def forward(self, x, ctx):
        if ctx.sharded_count > 1:
            raise Exception("`sharded_count > 1` is not implemented within this expert, Model parallel is disabled.")

        # x: [local_experts, capacity, model_dim]. Compose B @ A so each
        # delta weight is [model_dim, model_dim]; the original A @ B order
        # yields [r, r] and would not match x.
        t1 = torch.matmul(self.lora_B1, self.lora_A1)
        t2 = torch.matmul(self.lora_B2, self.lora_A2)
        y = torch.matmul(x, t1)
        y = self.act(y)
        y = torch.matmul(y, t2)
        return y

ghostplant commented on September 20, 2024

> When I use the custom expert, it stops here: if ctx.sharded_count > 1: ... [full code quoted above]

What is the value of adaptive_r in your moe forward setting?

zws98 commented on September 20, 2024

Where can I find "adaptive_r"?

zws98 commented on September 20, 2024

I have not changed the value of adaptive_r. When I replace the above custom MLP with the default FFN, the program works fine.

ghostplant commented on September 20, 2024

So it looks like num_global_experts is smaller than the number of GPUs, right?

zws98 commented on September 20, 2024

num_global_experts=2, self.world_size=8

ghostplant commented on September 20, 2024

Yes. When num_global_experts < self.world_size, you have to handle the sharded_count > 1 case, which tells the expert how its parameters are partitioned across more than one GPU (in your setting, sharded_count = world_size / num_global_experts = 8 / 2 = 4). Typically, you can implement expert-data parallelism to enable this execution setting, which requires creating sharded parameters in initialization and then all_gather-ing the sharded parameters in the forward function. The built-in FFN layer already includes this implementation, but I'll share a simpler example.
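For instance, here is a condensed sketch of that pattern applied to a LoRA-style expert (one projection only; the class name and peer-group plumbing are illustrative, not tutel API): each of the sharded_count peer GPUs stores a 1/sharded_count slice of the rank dimension and re-assembles the full factors in forward() with a gradient-preserving all_gather.

    import math
    import torch
    from torch.nn import init
    import torch.distributed.nn.functional as dist_nn

    class ShardedLoraExpert(torch.nn.Module):
        def __init__(self, model_dim, local_experts, sharded_count, my_config, peer_group=None):
            super().__init__()
            r = 8
            assert r % sharded_count == 0, "LoRA rank must be divisible by sharded_count"
            self.sharded_count = sharded_count
            self.peer_group = peer_group  # process group of the GPUs sharing one expert
            self.scale = 1 / math.sqrt(r)
            # Each GPU stores only r // sharded_count of the rank dimension.
            self.lora_A = torch.nn.Parameter(torch.empty(local_experts, r // sharded_count, model_dim))
            self.lora_B = torch.nn.Parameter(torch.empty(local_experts, model_dim, r // sharded_count))
            init.kaiming_normal_(self.lora_A)
            self.lora_A.data *= self.scale
            init.constant_(self.lora_B, 0)

        def _full(self, p, dim):
            if self.sharded_count == 1:
                return p
            # Autograd-aware all_gather so the sharded slices keep gradients.
            slices = dist_nn.all_gather(p, group=self.peer_group)
            return torch.cat(slices, dim=dim)

        def forward(self, x, ctx):
            a = self._full(self.lora_A, dim=1)  # [local_experts, r, model_dim]
            b = self._full(self.lora_B, dim=2)  # [local_experts, model_dim, r]
            return torch.matmul(x, torch.matmul(b, a))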

zws98 commented on September 20, 2024

thanks a lot!

ghostplant commented on September 20, 2024

Please follow this example for handling sharded_count: https://github.com/microsoft/tutel/blob/main/tutel/experts/llama_ffn.py
And here is another end-to-end example: https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_custom_expert_sharded.py
