Comments (8)
It would be nice (some distant day) to let users join directly through the browser in a federated setup:
https://github.com/epfml/disco
I'm a big fan of that framework, from the time when it was still called DeAI. However, it will take some time for us to sort through more immediate issues before we can run directly in the browser.
On client-side user interface
Design goals: based on non-representative random interviews with folks at BigScience and Yandex.Research
- a lot of people will just use pre-trained models without understanding the internals
- these users want things to just work without having to configure any technical bullshit
- some people will develop new model knock-offs (e.g. adapters for a new language) and new inference modes (e.g. diverse beam search or a restricted output space)
- these "devs" are few, but they make the model useful to others
- folks that develop knock-offs want nothing to do with networking internals, e.g. tcp/udp or fault tolerance; what they do want is to:
  - train the model using prompt-tuning or adapters and publish it to huggingface
  - implement and train custom adapters
  - write a new generation algorithm, e.g. re-score with some other model, or diverse beam search
API for users:
import os
import torch
from src import DistributedBloom

os.environ['BLOOM_TOKEN'] = input("Paste your account token here: ").strip()
os.environ['INITIAL_PEERS'] = "/ip4/123.123.123.123/tcp/1337 /ip6/something/udp/42/quic"

model = DistributedBloom.from_pretrained("bigscience/distributed-bloom-176b")
# or MyFinetunedDistributedBloom.from_pretrained("myname/finetuned-bloom")
# this loads only embeddings, logits, and config (hyperparameters)

# inference: same as most models from HF hub
generated_tokens = model.generate(tokenized_prefix, max_length=10, max_points_per_token=0.1)

# training (e.g. prompt-tuning)
opt = torch.optim.Adam(model.parameters(), lr=1e-5)
budget = dict(max_points_per_token=0.1, max_backward_points_per_token=0.25)  # all params are optional
for data_batch in my_dataloader:
    opt.zero_grad()
    loss = compute_loss(model.forward(**data_batch, **budget))
    loss.backward()
    opt.step()  # note: this step will only update local parameters: embeddings, lm_head biases, optional adapters
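For the "publish to huggingface" goal above, saving could plausibly mirror the regular HF API, since only the local modules need uploading (a sketch; save_pretrained/push_to_hub support on DistributedBloom is an assumption, not a committed interface):

# after prompt-tuning, only the local parts (embeddings, lm_head, adapters) need saving
model.save_pretrained("./my-finetuned-bloom")  # assumed to mirror the regular HF API
model.push_to_hub("myname/finetuned-bloom")    # hypothetical; another user could then load it
# via MyFinetunedDistributedBloom.from_pretrained("myname/finetuned-bloom") as above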
API for researchers / developers:
class DistributedBloom:
    emb: nn.Embedding  # local modules initialized in from_pretrained
    lm_head: nn.Linear
    remote_model: RemoteSequential  # networking, compression and fault tolerance magic happens here

    def forward(self, tokens):
        embs = self.emb(tokens)  # trains the same as a regular model
        last_hid = self.remote_model.forward(embs)  # can do backprop, but does not update the frozen transformer blocks
        # ^-- note 1: this module hides any network errors from the user, automatically finds alternative servers in case of errors
        # ^-- note 2: some remote_model knock-offs may contain local adapters that *are* trainable
        return self.lm_head(last_hid)  # trains the same as a regular model

    def generate(self, prefix: torch.LongTensor, max_length: int, **generation_params_including_budget):
        outputs = []
        with self.remote_model.begin_inference_session(**generation_params_including_budget) as sess:
            # ^-- figure out generation params, budget, etc. -- do not talk to servers YET
            past_inputs = prefix
            for i in range(max_length - prefix.shape[1]):
                hidden_states = self.emb(past_inputs)
                hidden_states = sess.step(hidden_states, **optional_beam_search_hypo_indices)
                # ^-- send 1 token through the pipeline, update attention caches, return last-layer activations
                # ^-- hidden from the user: find servers, establish connections, broker deals, recover from failures
                logits = self.lm_head(hidden_states)
                next_token = argmax_or_some_algorithm(logits)
                outputs.append(next_token)
                past_inputs = next_token  # feed only the new token; earlier ones are cached server-side
        return outputs
class RemoteSequential:
    dht: DHT
    metadata_about_servers: ...
    optional_trainable_parts: nn.Module

    def forward(self, *inputs: torch.Tensor, **budget_kwargs) -> torch.Tensor:
        """Split inputs along the batch dimension, send to remote servers, retry in case of failure/timeout"""

    def begin_inference_session(self, **generation_params) -> RemoteInferenceSession:
        """Create a fault-tolerant inference session that supports generic incremental inference"""


class RemoteInferenceSession:
    """
    Represents a single inference session on the client side:
    finds the best remote servers across the DHT, dynamically finds replacements for failed servers,
    keeps local caches of intermediate hidden states between servers.
    """

    def step(self, hidden_states: torch.Tensor, **additional_step_params) -> torch.Tensor:
        """
        Run a chunk of hidden states through all transformer blocks, return last-layer hidden states
        :param hidden_states: tensor [batch_size_usually_1, chunk_length_usually_1, hidden_size]
        :param additional_step_params: beam search hypothesis indices, backtrack steps, etc.
        :returns: hidden_states, a tensor of the same shape as the input
        """
@borzunov please think about how to maintain/update initial_peers
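One low-maintenance option (a sketch; the URL and helper name are hypothetical): publish a bootstrap list at a well-known address and fall back to the env var when it is unreachable:

import os
import requests

def get_initial_peers():
    fallback = os.environ.get("INITIAL_PEERS", "").split()  # multiaddrs, as in the example above
    try:
        # hypothetical well-known URL that maintainers update out-of-band
        fresh = requests.get("https://example.org/bloom/initial_peers.txt", timeout=5).text.split()
        return fresh or fallback
    except requests.RequestException:
        return fallback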
From conversation with @artek0chumak
In InferenceSession, it might be useful to support a special case where the user specifies less than a full batch of hidden states (see the sketch after this list), such as:
- in a chat bot: pre-allocate 10 sequences, encode the user's input once, split it into 10 hypotheses, sample them independently, then re-score with some algorithm
- in beam search: when some of the hypotheses have finished and we no longer have K alternatives
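The chat-bot case could look roughly like this under the proposed interface (a sketch; the batch_size and hypo_indices parameter names are assumptions):

with model.remote_model.begin_inference_session(batch_size=10) as sess:  # pre-allocate 10 slots
    hidden = sess.step(model.emb(user_prompt_tokens))  # encode the shared prefix once
    hidden = sess.step(model.emb(first_tokens), hypo_indices=[0] * 10)
    # ^-- fork the single cached prefix into 10 independent hypotheses
    hidden = sess.step(model.emb(next_tokens), hypo_indices=[0, 4, 7])
    # ^-- later, only 3 hypotheses survive: a partial batch referencing caches 0, 4 and 7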
Summarized conversation from @TimDettmers
- user shouldn't have to set initial peers (okay, Tim, Alex, you win)
- RemoteSequential feels like a huge black box that may be difficult to modify. Even in the early version, we should at least think about how to make it un-hardcode-able (see the sketch below)
- the beam search interface (as proposed by @artek0chumak) sounds generally legit, but not necessary for the first version
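On the black-box point: one cheap way to keep RemoteSequential un-hardcoded is to accept its routing and transport policies as constructor arguments rather than baking them in (a sketch; all names are illustrative):

class RemoteSequential:
    def __init__(self, dht, block_uids, *, server_selector=None, transport=None):
        # None falls back to the built-in policies; devs can swap either one
        # without ever touching the networking internals
        self.server_selector = server_selector or default_server_selector  # (dht, block_uids) -> server chain
        self.transport = transport or default_rpc_transport                # (server, tensors) -> tensors
        self.dht, self.block_uids = dht, block_uids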
From a casual conversation with @younesbelkada
If we ever want to let non-programmers use the model, we can set it up using gradio via HF Spaces. It is CPU-only, but the CPU it provides is better than in Colab.
The image below is from an existing (but private) bloom inference space. It is written in pretty simple python with a customizable environment. So, imagine this, but running a 176B model, with an option to enter your token to spend pre-amassed bloom points.
Best of all, it can be customized: if someone wants to change the prompts, they can fork the existing space and get full access to the code.
CC @borzunov
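For reference, a minimal gradio app of this kind is only a few lines (a sketch; assumes model and a matching tokenizer are already loaded as in the user API above):

import os
import gradio as gr

def complete(prompt: str, token: str) -> str:
    os.environ["BLOOM_TOKEN"] = token  # spend the user's pre-amassed bloom points
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    output_ids = model.generate(input_ids, max_length=64, max_points_per_token=0.1)
    return tokenizer.decode(output_ids[0])

gr.Interface(complete, inputs=["text", "text"], outputs="text").launch()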
The current interface has been fully implemented; I will start a new issue later.