
mengzi-retrieval-lm's Introduction

Mengzi-Retrieval-LM

At Langboat Technology, we focus on making pre-trained models lighter so they can meet real industry needs. A retrieval-based approach (like RETRO, REALM, and RAG) is crucial to achieving this goal.

This repository is an experimental implementation of the retrieval-enhanced language model. Currently, it only supports retrieval fitting on GPT-Neo.

We forked Huggingface Transformers and lm-evaluation-harness to add retrieval support. The indexing part is implemented as an HTTP server to better decouple retrieval and training.

Most of the model implementation is copied from RETRO-pytorch and GPT-Neo. We used transformers-cli to add a new model named Re_gptForCausalLM based on GPT-Neo, and then added the retrieval part to it.

We uploaded a model fitted on EleutherAI/gpt-neo-125M using the 200G retrieval database.

You can initialize a model like this:

from transformers import Re_gptForCausalLM
model = Re_gptForCausalLM.from_pretrained('Langboat/ReGPT-125M-200G')
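
As a quick smoke test, the model can be used like any other transformers causal LM. The following is a minimal sketch, assuming the standard GPT-Neo-125M tokenizer is compatible and that the retrieval index server described later in this README is already running and reachable; it is not an official usage example.

from transformers import AutoTokenizer, Re_gptForCausalLM

# Assumption: the GPT-Neo-125M tokenizer is used; swap in the model repo's own tokenizer if it ships one.
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
model = Re_gptForCausalLM.from_pretrained('Langboat/ReGPT-125M-200G')

prompt = "Retrieval-augmented language models"
inputs = tokenizer(prompt, return_tensors='pt')

# The forward pass may query the retrieval index server, so set it up first (see "Setup index server").
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))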

And evaluate the model like this:

python main.py \
    --model retrieval \
    --model_args pretrained=model_path \
    --device 0 \
    --tasks wikitext,lambada,winogrande,mathqa,pubmedqa  \
    --batch_size 1

We compute similarity using sentence_transformers embeddings as the text representation. You can initialize a Sentence-BERT model like this:

from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-MiniLM-L12-v2')
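
To illustrate how these embeddings drive similarity search, here is a minimal sketch (not the repo's retrieval code) that encodes a query and a few passages and ranks them by cosine similarity:

from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer('all-MiniLM-L12-v2')

query = "retrieval-enhanced language models"
passages = [
    "RETRO augments a language model with chunks retrieved from a large database.",
    "The weather is nice today.",
]

# Encode the query and candidate passages into dense vectors (384-dim for this model).
query_emb = model.encode(query, convert_to_tensor=True)
passage_embs = model.encode(passages, convert_to_tensor=True)

# Rank passages by cosine similarity to the query.
scores = util.cos_sim(query_emb, passage_embs)
print(scores)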

Architecture

[Architecture diagram: Cloud Architecture]

Usage

Environment

conda create -n mengzi-retrieval-fit python=3.7
conda activate mengzi-retrieval-fit
conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia
git clone https://github.com/Langboat/mengzi-retrieval-lm.git
cd mengzi-retrieval-lm
git submodule update --init --recursive
pip install -r requirement.txt
cd transformers/
pip install -e .
cd ..
python -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('all-MiniLM-L12-v2')"

Download

Index and DB

Using IVF1024PQ48 as the faiss index factory, we uploaded the index and database to the Hugging Face model hub; they can be downloaded with the following command.

In download_index_db.py, you can specify the number of indexes and databases you want to download.

python -u download_index_db.py  --num 200
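
For reference, an index with the same structure can be built locally with faiss from a matrix of embeddings. This is a minimal sketch assuming 384-dimensional all-MiniLM-L12-v2 embeddings and random placeholder data; it is not the repo's indexing code:

import numpy as np
import faiss

d = 384  # embedding dimension of all-MiniLM-L12-v2 (assumption)
# faiss factory string for 1024 inverted lists with 48-byte product-quantized codes
index = faiss.index_factory(d, "IVF1024,PQ48")

xb = np.random.rand(100000, d).astype('float32')  # placeholder embeddings
index.train(xb)   # IVF/PQ indexes must be trained before vectors are added
index.add(xb)

xq = np.random.rand(1, d).astype('float32')       # placeholder query embedding
distances, ids = index.search(xq, 4)              # retrieve the 4 nearest neighbours
print(ids)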

Model

You can manually download the fitted model from here: https://huggingface.co/Langboat/ReGPT-125M-200G

Setup index server

Start

The index server is based on FastAPI and Ray. With Ray actors, computationally intensive tasks are encapsulated asynchronously, allowing us to efficiently utilize CPU and GPU resources with just one FastAPI server instance (a simplified sketch of this pattern appears at the end of this subsection). You can initialize an index server like this:

cd index-server/
ray start --head
python -u api.py \
--config config_IVF1024PQ48.json \
--db_path ../db/models--Langboat--Pile-DB/snapshots/fd35bcce75db5c1b7385a28018029f7465b4e966
  • Keep in mind that the shard count in config_IVF1024PQ48.json must match the number of downloaded indexes. You can check the number of downloaded index shards under the db_path.
  • This config has been tested on an A100-40G, so if you have a different GPU, we recommend adjusting it to your hardware.
  • After deploying the index server, you need to modify request_server in lm-evaluation-harness/config.json and train/config.json.
  • You can reduce encoder_actor_count in config_IVF1024PQ48.json to reduce the required memory resources.

  • db_path: the database's download location from Hugging Face. "../db/models--Langboat--Pile-DB/snapshots/fd35bcce75db5c1b7385a28018029f7465b4e966" is an example.

The download command above fetches the database and index data from Hugging Face.

Change the index folder in the configuration file (config_IVF1024PQ48.json) to point to the index folder's path, and pass the database folder's snapshots directory as the db_path to the api.py script.
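
The serving pattern described above (a single FastAPI process that delegates heavy work to Ray actors) looks roughly like the following hypothetical sketch; the real api.py differs in its endpoints, configuration, and the faiss lookups it performs:

import ray
from fastapi import FastAPI

ray.init(address="auto")  # connect to the cluster started with `ray start --head`
app = FastAPI()

@ray.remote
class Encoder:
    """Holds the heavy model state; one actor per CPU/GPU slice."""
    def __init__(self):
        from sentence_transformers import SentenceTransformer
        self.model = SentenceTransformer('all-MiniLM-L12-v2')

    def encode(self, text):
        return self.model.encode(text).tolist()

encoder = Encoder.remote()

@app.get("/embed")
async def embed(text: str):
    # The actor call runs outside the web-server process, so FastAPI stays responsive.
    embedding = await encoder.encode.remote(text)
    return {"embedding": embedding}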

Stop

Stop the index server with the following command:

ray stop
  • Keep in mind that you need to keep the index server running during training, evaluation, and inference

Training

Use train/train.py to run training; train/config.json can be modified to change the training parameters.

You can initialize training like this:

cd train
python -u train.py
  • Since the index server needs memory resources, it is better to deploy the index server and model training on different GPUs

Inference

Use train/inference.py to compute the loss and perplexity of a text.

cd train
python -u inference.py \
    --model_path Langboat/ReGPT-125M-200G \
    --file_name data/test_data.json
  • The test_data.json and train_data.json files in the data folder show the currently supported file format; you can convert your own data to this format.
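
For reference, the perplexity reported here is just the exponential of the average token-level cross-entropy loss. A minimal sketch with a plain transformers causal LM (using the baseline GPT-Neo for illustration, not the repo's inference.py, and ignoring retrieval):

import math
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
model = AutoModelForCausalLM.from_pretrained('EleutherAI/gpt-neo-125M')

text = "Retrieval-augmented language models condition on retrieved passages."
inputs = tokenizer(text, return_tensors='pt')

with torch.no_grad():
    # With labels equal to input_ids, the model returns the mean cross-entropy loss.
    loss = model(**inputs, labels=inputs['input_ids']).loss

print(f"loss = {loss.item():.4f}, perplexity = {math.exp(loss.item()):.2f}")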

Evaluations

We use lm-evaluation-harness as the evaluation method.

Because our model is trained with a seq_len of 1025, we set the seq_len of lm-evaluation-harness to 1025 as the initial setting for model comparison.

cd lm-evaluation-harness
python setup.py install

with retrieval

python main.py \
    --model retrieval \
    --model_args pretrained=Langboat/ReGPT-125M-200G \
    --device 0 \
    --tasks wikitext  \
    --batch_size 1

  • model_path: the path of the fitted model

without retrieval

python main.py \
	--model gpt2 \
	--model_args pretrained=EleutherAI/gpt-neo-125M \
	--device 0 \
	--tasks wikitext \
	--batch_size 1

The evaluation results are as follows:

| model                    | wikitext word_perplexity |
|--------------------------|-------------------------:|
| EleutherAI/gpt-neo-125M  | 35.8774                  |
| Langboat/ReGPT-125M-200G | 22.115                   |
| EleutherAI/gpt-neo-1.3B  | 17.6979                  |
| Langboat/ReGPT-125M-400G | 14.1327                  |

Citing Mengzi Retrieval LM

@software{mengzi-retrieval-lm-library,
  title = {{Mengzi-Retrieval-LM}},
  author = {Wang, Yulong and Bo, Lin},
  url = {https://github.com/Langboat/mengzi-retrieval-lm},
  month = {9},
  year = {2022},
  version = {0.0.1},
}

mengzi-retrieval-lm's People

Contributors

ag2s1, bling0830


mengzi-retrieval-lm's Issues

Re_gptForCausalLM were not initialized from the model checkpoint at EleutherAI/gpt-neo-125M

When I try to load gpt-neo-125M using train/trainer.py, the following log shows up. Is this OK? If I change Re_gptForCausalLM to GPTNeoForCausalLM, it disappears.
Some weights of Re_gptForCausalLM were not initialized from the model checkpoint at EleutherAI/gpt-neo-125M and are newly initialized: ['transformer.h.5.cross_attn.fn.cross_attn.to_q.weight', 'transformer.encoder.layers.1.0.fn.to_q.weight', 'transformer.encoder.layers.1.0.fn.to_k.weight', 'transformer.encoder.layers.0.1.fn.to_out.weight', 'transformer.encoder.layers.0.0.fn.to_v.weight', 'transformer.encoder.layers.0.1.fn.to_v.weight', 'transformer.encoder.layers.0.0.fn.to_k.weight', 'transformer.encoder.layers.1.1.fn.to_out.bias', 'transformer.encoder.layers.0.2.fn.ff.0.weight', 'transformer.encoder.layers.0.0.fn.to_out.weight', 'transformer.rotary_pos_emb.inv_freq', 'transformer.h.5.cross_attn.fn.cross_attn.null_v', 'transformer.encoder.layers.1.1.fn.to_q.weight', 'transformer.encoder.layers.1.1.fn.to_k.weight', 'transformer.encoder.layers.1.2.norm.weight', 'transformer.encoder.layers.1.1.fn.to_v.weight', 'transformer.encoder.layers.1.0.fn.to_v.weight', 'transformer.encoder.layers.1.2.fn.ff.3.bias', 'transformer.encoder.layers.0.1.fn.to_k.weight', 'transformer.encoder.layers.1.2.fn.ff.0.weight', 'transformer.encoder.norm_out.weight', 'transformer.encoder.project_out.bias', 'transformer.encoder.layers.0.1.fn.to_q.weight', 'transformer.encoder.layers.0.2.norm.weight', 'transformer.encoder.layers.0.1.norm.weight', 'transformer.encoder.rotary_pos_emb.inv_freq', 'transformer.encoder.layers.1.1.fn.to_out.weight', 'transformer.h.5.cross_attn.fn.cross_attn.to_v.weight', 'transformer.encoder.layers.1.0.fn.to_out.weight', 'transformer.h.5.cross_attn.fn.cross_attn.null_k', 'transformer.h.5.cross_attn.fn.cross_attn.to_out.bias', 'transformer.encoder.layers.1.2.fn.ff.3.weight', 'transformer.encoder.layers.1.1.norm.weight', 'transformer.encoder.layers.0.2.fn.ff.3.bias', 'transformer.h.5.cross_attn.norm.weight', 'transformer.encoder.layers.1.2.fn.ff.0.bias', 'transformer.encoder.layers.0.2.fn.ff.0.bias', 'transformer.encoder.layers.0.1.fn.to_out.bias', 'transformer.encoder.layers.0.2.fn.ff.3.weight', 'transformer.encoder.layers.0.0.fn.to_q.weight', 'transformer.encoder.layers.1.0.fn.to_out.bias', 'transformer.encoder.project_out.weight', 'transformer.h.5.cross_attn.fn.cross_attn.to_k.weight', 'transformer.h.5.cross_attn.fn.cross_attn.to_out.weight', 'transformer.encoder.layers.1.0.norm.weight', 'transformer.encoder.layers.0.0.norm.weight', 'transformer.encoder.layers.0.0.fn.to_out.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Whole stack didn't work with python 3.7 but does with python 3.8

The installation instructions include:

conda create -n mengzi-retrieval-fit python=3.7

I found that this created loads of errors relating to importlib.metadata and importlib_metadata (not for the index, but for almost everything else). After a bit of digging I found that Python 3.8 seemed to fix the issue, so I upgraded my conda environment to 3.8 (I was lazy and left the index on 3.7). Anyway, for whoever comes after me: if you have these kinds of troubles, try upgrading to Python 3.8 and re-installing.

Custom knowledge db

Hello, thanks for the valuable repo. I already tried to run this code and it worked very well! It looks like the db can be downloaded through Hugging Face. Can we build our own custom knowledge database instead of downloading it from Hugging Face? Thanks!

Langboat/ReGPT-125M-200G score isn't reproducible

When I run:

python main.py \
    --model retrieval \
    --model_args pretrained=Langboat/ReGPT-125M-200G \
    --device 0 \
    --tasks wikitext  \
    --batch_size 1

I get the following:

  "config": {
    "model": "retrieval",
    "model_args": "pretrained=Langboat/ReGPT-125M-200G",
    "num_fewshot": 0,
    "batch_size": 1,
    "device": "0",
    "no_cache": false,
    "limit": null,
    "bootstrap_iters": 100000,
    "description_dict": {}
  }
}
retrieval (pretrained=Langboat/ReGPT-125M-200G), limit: None, provide_description: False, num_fewshot: 0, batch_size: 1
|  Task  |Version|    Metric     | Value |   |Stderr|
|--------|------:|---------------|------:|---|------|
|wikitext|      1|word_perplexity|36.1793|   |      |
|        |       |byte_perplexity| 1.9563|   |      |
|        |       |bits_per_byte  | 0.9681|   |      |

whereas I believe it should be closer to 22 word perplexity (according to the README).

Trainer training is very slow

I forced retrieval to None, but training with the trainer on 8 V100s is still very slow, roughly 30k samples per hour. Is something wrong?

About the compute resources

Thanks for making your work public!
I'd like to know how many computing resources were used for training and retrieval when you trained the GPT-125M model.

Any Web Demo to have a look at it?

After reading some issues, I realized that it would take a lot of time and heavy resources to train a model in my own environment. Is there a web demo page where I can give it a try? At the least I just want to know how it behaves and how good it is.

Approximate Training Time

Hello,
Thanks to the authors for this repo; it's really helpful to me. I'm now using it to train with my own dataset and custom database. I use 8 V100 GPUs and the utilization of each GPU is nearly 100%. However, training is extremely slow, only 1 epoch per day. If I train GPT-Neo-125M without RETRO (just with Hugging Face), it can train 40 epochs per day. Is there a bottleneck that makes training much slower in the retrieval process? And may I ask how long you trained the RETRO model to get the results in this repo? Thanks!

Prompt for Result

Hi,

Can you explain or give an example of the prompt we should use for Q&A? The code shows how to compute the loss for a whole file, but if I want to get the answer to a single question, how do I go about it?

Unable to reproduce PPL for GPT-Neo-125M using lm-eval

Hey!

I'm trying to run the following command using the lm-eval CLI, but I can't reproduce the results you shared. Did you do something differently? If not, do you have any idea what I'm doing wrong?

python main.py \
	--model gpt2 \
	--model_args pretrained=EleutherAI/gpt-neo-125M \
	--device 0 \
	--tasks wikitext \
	--batch_size 1

api.py

What is the purpose of api.py? Should it produce any output? I ran api.py as described in the README and it stopped at the position shown in the screenshot below, without any output. Am I doing something wrong?
[screenshot of api.py output]

Eval_loss NaN with train/train.py

I built the pre-retrieval dataset with train/preload.py and tried to train the model with train/train.py. The training loss is fine, but the eval loss is all NaN. Could you help me out?

Question For Training Dataset

Apart from the database and index data on Hugging Face, the train_data.json in the repo should be considered just an example, right? Would you mind releasing the full versions of the train and test datasets so the results can be reproduced?

Need more resources? :)

I am one of the founders of LAION ( https://laion.ai ) and we are very interested in getting retrieval-augmented transformers to work.
If you would like to train a bigger model, let me know.
My discord ID is spirit-from-germany#1488

Prepare_load

How do we use the prepare_load file for training?
