AR-Diffusion

This repo provides the code and models for AR-Diffusion: Auto-Regressive Diffusion Model for Text Generation.

🚀 Overview

We introduce Auto-Regressive Diffusion (AR-Diffusion). AR-Diffusion ensures that the generation of tokens on the right depends on the tokens already generated on the left, a mechanism achieved by employing a dynamic number of denoising steps that varies with token position. As a result, tokens on the left undergo fewer denoising steps than those on the right, so they are generated earlier and can in turn influence the generation of the tokens to their right.
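As a purely illustrative sketch of this idea (an assumption for intuition only, not the exact movement schedule from the paper): if the rightmost position undergoes the full $T$ denoising steps, a token at position $n$ in a sequence of length $N$ can be assigned proportionally fewer steps, e.g. $t_n \approx \frac{n}{N}\,T$, so positions near the left finish denoising earlier and then condition the positions to their right.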

In experiments on a range of text generation tasks, including text summarization, machine translation, and commonsense generation, AR-Diffusion clearly demonstrates superiority over existing diffusion language models and can be $100\times$ to $600\times$ faster while achieving comparable results.

You can find more details in the paper.

⚙️ Experiment Preparation

Dependencies:
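A minimal environment sketch, assuming a CUDA-capable PyTorch build and that the remaining Python packages are listed in a requirements.txt at the repo root (the exact dependency list and versions are assumptions, not pinned by this README):

# Sketch only: package sources and versions are assumptions
conda create -n torch python=3.8 -y
conda activate torch
pip install torch                 # pick the build matching your CUDA version
pip install -r requirements.txt   # assumed to exist at the repo root

Note that the evaluation commands below invoke ./.conda/envs/torch/bin/python, i.e. a conda environment named torch.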

Downstream Task Dataset:

The text generation benchmarks we use are well known and widely used: XSum, CNN/DailyMail, IWSLT14, and CommonGen. More detailed information, along with instructions for obtaining each dataset, can be found here.
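The evaluation commands below read references from ./data/$DATA_NAME, so a plausible on-disk layout is the following (inferred from the scripts in this README; directory names other than xsum are assumptions):

data/
├── xsum/
├── cnn_dailymail/
├── iwslt14/
└── commongen/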

Model

We have released AR-Diffusion checkpoints for each dataset here (6-layer encoder and 6-layer decoder):

💡 Training

In this section, we use the XSum dataset as an example to demonstrate AR-Diffusion training on downstream tasks. (Training scripts for all datasets are available in scripts/train.sh.) The training command is as follows:

FILE_NAME=xsum
STEP=80000

# Train AR-Diffusion on XSum with 8 GPUs on a single node
torchrun --nproc_per_node=8 --nnodes=1 ./train_utils/trainer_main.py \
model.name='bert-base-uncased' batch_size=128 grad_accum=3 \
total_steps=$STEP exp.name=$FILE_NAME \
data.name=xsum tgt_len=50 max_pos_len=512 lr=8e-4 lr_step=40000 \
intermediate_size=2048 num_attention_heads=8 dropout=0.2 \
in_channels=128 out_channels=128 time_channels=128 \
eval_interval=3000 log_interval=1000 \
schedule_sampler='xy_uniform' time_att=True att_strategy='txl' use_AMP=True
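With the flags above, and assuming batch_size is per device, each optimizer update sees 128 × 3 × 8 = 3072 sequences (batch_size × grad_accum × GPUs). For a quick single-GPU sanity run, a scaled-down variant might look like the sketch below; the reduced batch settings are assumptions, and the much smaller effective batch size means results will not match the reported ones without re-tuning:

# Single-GPU sketch (assumed batch settings; all other flags copied from the 8-GPU run above)
torchrun --nproc_per_node=1 --nnodes=1 ./train_utils/trainer_main.py \
model.name='bert-base-uncased' batch_size=32 grad_accum=4 \
total_steps=$STEP exp.name=$FILE_NAME \
data.name=xsum tgt_len=50 max_pos_len=512 lr=8e-4 lr_step=40000 \
intermediate_size=2048 num_attention_heads=8 dropout=0.2 \
in_channels=128 out_channels=128 time_channels=128 \
eval_interval=3000 log_interval=1000 \
schedule_sampler='xy_uniform' time_att=True att_strategy='txl' use_AMP=True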

💬 Inference

In this section, we show how to batch-generate text from a trained AR-Diffusion model, again using the XSum dataset as an example. (Generation scripts for all datasets are available in scripts/gen.sh.) The generation command is as follows:

FILE_NAME=xsum
STEP=80000

# Generate candidate samples on 8 GPUs from the step-80000 EMA checkpoint
torchrun --nproc_per_node=8 --nnodes=1 ./gen_utils/generate.py \
model.name='bert-base-uncased' batch_size=800 \
exp.name=$FILE_NAME load_step=$STEP \
data.name=xsum tgt_len=50 max_pos_len=512 num_samples=50 \
intermediate_size=2048 num_attention_heads=8 dropout=0.2 \
in_channels=128 out_channels=128 time_channels=128 \
skip_sample=True gen_timesteps=20 \
schedule_sampler='xy_uniform' time_att=True att_strategy='txl' load_from_ema=True prediction=True
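Generated candidates are written under ./my_output/, in a per-run directory whose name encodes the checkpoint step, the EMA rate, and the sampling settings; this is the path the evaluation step below reads from. num_samples=50 produces 50 candidate samples, matching NUM=50 in the evaluation loop (num0 … num49). With the settings above, a quick check might be (path inferred from the evaluation script below; adjust if your run settings differ):

ls ./my_output/xsum/xsum/80000_ema_0.9999_skip__xy_20/   # expect num0 ... num49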

⚽ Evaluation

In this section, we show how to select the best sample from the candidates via Minimum Bayes Risk (MBR) decoding and how to evaluate the selected samples, again using the XSum dataset as an example. (Evaluation scripts for all datasets are available in scripts/concat_eval.sh.)
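As a reminder of what MBR selection does (the standard formulation; the concrete similarity metric implemented in eval_utils/mbr/mbr_select.py is not specified here), it picks, for each input, the candidate that is on average most similar to the other candidates:

$$\hat{y} = \operatorname*{arg\,max}_{y \in \mathcal{C}} \; \frac{1}{|\mathcal{C}|-1} \sum_{y' \in \mathcal{C},\, y' \neq y} \mathrm{sim}(y, y')$$

where $\mathcal{C}$ is the set of NUM candidates. The evaluation commands are as follows: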

FILE_NAME=xsum
DATA_NAME=xsum
STEP=80000
NUM=50

echo "model step" $STEP
j=0
while [ "$j" -lt $NUM ]; do
    echo "gen num $j"
    # Run concat.py on candidate sample $j (the 8 GPUs' outputs under src_path)
    ./.conda/envs/torch/bin/python ./eval_utils/concat.py \
        --n_gpu=8 --num=$j \
        --src_path=./my_output/$DATA_NAME/$FILE_NAME/${STEP}_ema_0.9999_skip__xy_20/num$j \
        --tgt_path=./data/$DATA_NAME

    j=$(($j+1))
done

# Select the best candidate per input with MBR and evaluate the selection
./.conda/envs/torch/bin/python ./eval_utils/mbr/mbr_select.py \
    --data_name=$DATA_NAME --num=$NUM --process=50  # --exp_name=500

Repo Reference

This repo is partially based on Diffusion-LM and GENIE.

📜 Citation

Please cite our paper if you use AR-Diffusion in your work:

@misc{wu2023ardiffusion,
      title={AR-Diffusion: Auto-Regressive Diffusion Model for Text Generation}, 
      author={Tong Wu and Zhihao Fan and Xiao Liu and Yeyun Gong and Yelong Shen and Jian Jiao and Hai-Tao Zheng and Juntao Li and Zhongyu Wei and Jian Guo and Nan Duan and Weizhu Chen},
      year={2023},
      eprint={2305.09515},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
