
Comments (8)

bhack avatar bhack commented on May 10, 2024

> If we ever want to let non-programmers use the model, we can set it up using Gradio via HF Spaces. It is CPU-only, but the CPU it provides is better than in Colab.

It would be nice (some distant day) to simply join through the browser directly in a federated setup:
https://github.com/epfml/disco


justheuristic avatar justheuristic commented on May 10, 2024

I'm a big fan of that framework, from back when it was still called DeAI. However, it will take some time for us to sort through more immediate issues before we can run directly in the browser.


justheuristic avatar justheuristic commented on May 10, 2024

On client-side user interface

Design goals, based on non-representative random interviews with folks at BigScience and Yandex.Research:

  • a lot of people will just use pre-trained models without understanding the internals
    • these users want things to just work without having to configure any technical bullshit
  • some people will develop new model knock-offs (e.g. adapters for a new language) and new inference modes (e.g. diverse beam search or a restricted output space)
    • these "devs" are few, but they make the model useful to others
    • folks who develop knock-offs want nothing to do with networking internals, e.g. TCP/UDP or fault tolerance
    • train the model with prompt-tuning or adapters, then publish to the Hugging Face Hub
    • implement and train custom adapters
    • write a new generation algorithm, e.g. one that re-scores with another model, or diverse beam search

API for users:

import os
import torch
from src import DistributedBloom

os.environ['BLOOM_TOKEN'] = input("Paste your account token here: ").strip()
os.environ['INITIAL_PEERS'] = "/ip4/123.123.123.123/tcp/1337 /ip6/something/udp/42/quic"  # space-separated multiaddrs of bootstrap peers
model = DistributedBloom.from_pretrained("bigscience/distributed-bloom-176b")
# or MyFinetunedDistributedBloom.from_pretrained("myname/finetuned-bloom")
# this loads only embeddings, logits, and config (hyperparameters)


# inference: same as most models from the HF hub
# (tokenized_prefix is a LongTensor of token ids, e.g. produced by the model's tokenizer)
generated_tokens = model.generate(tokenized_prefix, max_length=10, max_points_per_token=0.1)

# training (e.g. prompt-tuning)
opt = torch.optim.Adam(model.parameters(), lr=1e-5)
budget = dict(max_points_per_token=0.1, max_backward_points_per_token=0.25)  # all params are optional
for data_batch in my_dataloader:
    opt.zero_grad()
    loss = compute_loss(model.forward(**data_batch, **budget))
    loss.backward()
    opt.step()  # note: this step will only update local parameters: embeddings, lm_head biases, optional adapters
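
Since opt.step() only updates local parameters, prompt-tuning implies some locally trainable module on the client side. A minimal sketch of what that could look like; the class name and initialization scale are made up for illustration, not part of the proposal:

import torch
import torch.nn as nn

class TrainablePrompts(nn.Module):
    """Hypothetical local module: learned prefix embeddings, trained on the client."""
    def __init__(self, num_prompt_tokens: int, hidden_size: int):
        super().__init__()
        # small random init; the 0.02 scale is an arbitrary choice for this sketch
        self.prompts = nn.Parameter(torch.randn(num_prompt_tokens, hidden_size) * 0.02)

    def forward(self, input_embs: torch.Tensor) -> torch.Tensor:
        # prepend the learned prompts to every sequence in the batch
        batch_size = input_embs.shape[0]
        prompts = self.prompts.unsqueeze(0).expand(batch_size, -1, -1)
        return torch.cat([prompts, input_embs], dim=1)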

API for researchers / developers:

class DistributedBloom:
    emb: nn.Embedding                         # local modules initialized in from_pretrained
    lm_head: nn.Linear
    remote_model: RemoteSequential            # networking, compression and fault tolerance magic happens here
        
    def forward(self, tokens):
        embs = self.emb(tokens)                                   # trains same as regular model
        last_hid = self.remote_model.forward(embs)                # model can do backprop, but does not update frozen transformer blocks
        # ^-- note 1: this module hides any network errors from the user, automatically finds alternative servers in case of errors
        # ^-- note 2: some remote_model knock-offs may contain local adapters that *are* trainable
        return self.lm_head(last_hid)                             # trains same as regular model
    
    def generate(self, prefix: torch.LongTensor, max_length: int, **generation_params_including_budget):
        outputs = []
        with self.remote_model.begin_inference_session(**generation_params_including_budget) as sess:
            # ^-- find out generation params, budget, etc -- do not talk to servers YET
            past_inputs = prefix
            for i in range(max_length - prefix.shape[1]):
                hidden_states = self.emb(past_inputs)

                hidden_states = sess.step(hidden_states, **optional_beam_search_hypo_indices)
                # then send 1 token through pipeline, update attention caches, return last layer activities
                # ^-- hidden from the user: find servers, establish connections, broker deals, recover from failures

                logits = self.lm_head(hidden_states)
                next_token = argmax_or_some_algorithm(logits)
                outputs.append(next_token)
                past_inputs = next_token
            return outputs

class RemoteSequential:
    dht: DHT
    metadata_about_servers
    optional_trainable_parts: nn.Module

    def forward(self, *inputs: torch.Tensor, **budget_kwargs) -> torch.Tensor:
        """Split inputs along batch dimension, send to remote servers, retry in case of failure/timeout"""

    def begin_inference_session(self, **generation_params) -> RemoteInferenceSession:
        """Create a fault-tolerant inference session that supports generic incremental inference"""

class RemoteInferenceSession:
    """
    represents a single inference session on client side
    finds best remote servers from across DHT, dynamically finds replacements for failed servers
    keeps local caches for intermediate hidden states between servers
    """
    def step(self, hidden_states: torch.Tensor, **additional_step_params) -> torch.Tensor:
        """
        Run a chunk of hidden states through all transformer blocks, return last layer hidden states
        :param hidden_states: tensor [batch_size_usually_1, chunk_length_usually_1, hidden_size]
        :param additional_step_params: beam search hypothesis indices, backtrack steps, etc
        :returns: hidden_states, tensor of same shape as hidden_states
        """


justheuristic avatar justheuristic commented on May 10, 2024

@borzunov please think about how to maintain/update initial_peers
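
One illustrative option (not a decision): publish a regularly refreshed bootstrap list at a well-known URL and have the client fetch it on startup. The URL and JSON format below are placeholders:

import json
import os
from typing import List
from urllib.request import urlopen

def refresh_initial_peers(url: str = "https://example.org/bloom/initial_peers.json") -> List[str]:
    # hypothetical endpoint serving a JSON list of multiaddr strings
    with urlopen(url, timeout=10) as resp:
        peers = json.load(resp)
    os.environ['INITIAL_PEERS'] = " ".join(peers)
    return peers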


justheuristic avatar justheuristic commented on May 10, 2024

From conversation with @artek0chumak

In InferenceSession, it might be useful to support a special case where the user passes fewer hidden states than a full batch (see the sketch after this list), e.g.:

  • in a chat bot: pre-allocate 10 sequences, encode the user's input once, split it into 10 hypotheses, sample them independently, then re-score with some algorithm
  • in beam search: when some of the hypotheses have finished and we no longer have K alternatives
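
A sketch of how that special case could look on top of the step() interface above; hypo_ids is an illustrative name for one of the additional_step_params, and its exact semantics are up for design:

import torch

def step_active_hypotheses(sess, emb, lm_head, next_tokens: torch.LongTensor,
                           active: torch.LongTensor) -> torch.Tensor:
    """Run one incremental step for the unfinished hypotheses only.

    `active` holds indices of hypotheses that have not finished yet; passing it
    along lets the servers match the partial batch to their attention caches.
    """
    hidden_states = sess.step(emb(next_tokens[active]), hypo_ids=active)
    return lm_head(hidden_states[:, -1, :])  # logits only for the still-active rows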


justheuristic avatar justheuristic commented on May 10, 2024

Summary of a conversation with @TimDettmers

  • the user shouldn't have to set initial peers (okay, Tim, Alex, you win)
  • RemoteSequential feels like a huge black box that may be difficult to modify. Even in the early version, we should at least think about how to make it un-hardcode-able (see the sketch after this list)
  • the beam search interface (as proposed by @artek0chumak) sounds generally legit, but is not necessary for the first version
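
One way the un-hardcode-able requirement could be met, as a sketch only: expose a per-block hook on RemoteSequential that knock-offs can override to interleave local trainable parts with the remote calls. The hook name remote_block_forward is hypothetical, not an existing method:

import torch.nn as nn

class RemoteSequentialWithAdapters(RemoteSequential):
    """Knock-off that inserts local residual adapters between remote blocks."""
    def __init__(self, *args, adapters: nn.ModuleDict, **kwargs):
        super().__init__(*args, **kwargs)
        self.optional_trainable_parts = adapters  # e.g. {"after_block_12": nn.Linear(hid, hid)}

    def remote_block_forward(self, block_index: int, hidden_states):
        # hypothetical override point; the base class would call this once per remote block
        hidden_states = super().remote_block_forward(block_index, hidden_states)
        key = f"after_block_{block_index}"
        if key in self.optional_trainable_parts:
            # residual adapter: trained locally, never sent to servers
            hidden_states = hidden_states + self.optional_trainable_parts[key](hidden_states)
        return hidden_states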


justheuristic avatar justheuristic commented on May 10, 2024

From a casual conversation with @younesbelkada

If we ever want to let non-programmers use the model, we can set it up using Gradio via HF Spaces. It is CPU-only, but the CPU it provides is better than in Colab.

The image below is from an existing (but private) BLOOM inference space. It is written in pretty simple Python with a customizable environment. So, imagine this, but running a 176B model, with an option to enter your token to spend pre-amassed bloom points.

[screenshot: the private BLOOM inference space UI]

Best of all, it can be customized, so if someone wants to change the prompts, they can fork the existing space and they'll have full access to the code.

CC @borzunov
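
For reference, a minimal sketch of such a space, assuming model and a matching tokenizer are already loaded as in the user API above (all names and values here are illustrative, not the private space's actual code):

import gradio as gr

def complete(prompt: str) -> str:
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    output_ids = model.generate(input_ids, max_length=64, max_points_per_token=0.1)
    return tokenizer.decode(output_ids[0])

demo = gr.Interface(fn=complete, inputs="text", outputs="text",
                    title="BLOOM-176B (distributed)")
demo.launch()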


justheuristic avatar justheuristic commented on May 10, 2024

The current interface has been fully implemented; I'll start a new issue later.

