
adaptive-span's Introduction

Sequential Transformer

This is code for training Transformers on sequential tasks such as language modeling. Unlike the original Transformer architecture, it caches previous representations and uses relative position embeddings to better adapt to sequential tasks. The code also implements the projects described below and in this blog post.
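For illustration, here is a minimal sketch of what caching previous representations looks like in a block-wise sequential model; the function name and shapes are illustrative, not the repository's API:

```python
import torch

def attend_with_cache(query, hidden, cache):
    """Prepend the previous block's hidden states to the keys/values so that
    tokens can attend across the block boundary; the cache is detached so
    gradients do not flow into past blocks."""
    key_value = torch.cat([cache, hidden], dim=1)        # (B, cache_len + block_len, H)
    scores = query @ key_value.transpose(1, 2) / query.size(-1) ** 0.5
    out = torch.softmax(scores, dim=-1) @ key_value      # (B, block_len, H)
    new_cache = hidden.detach()                          # cache for the next block
    return out, new_cache
```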

Requirements

You need PyTorch 0.4.1 or above and a CUDA-enabled GPU to run the code. If multiple GPUs are available, the code uses nn.DataParallel to utilize them. For better efficiency, enable distributed training with the --distributed argument, which can also run on multiple nodes.

Adaptive Attention Span

This code can be used for running the experiments in the Adaptive Attention Span for Transformers paper. The adaptive span allows a model to learn an optimal context size for each self-attention head from the training data. As the paper's figures show, only a few heads require a long attention span, which makes it possible to increase the context size to 8k tokens without significantly increasing computation time and memory footprint.

The --adapt-span argument enables the adaptive span; otherwise the model uses a fixed attention span. The adaptive span is implemented as an nn.Module to make it easier to plug into other models.
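As a rough illustration of the mechanism, here is a minimal sketch of the soft masking behind the adaptive span: each head learns a parameter that controls how far back it attends, with a linear ramp that keeps the span differentiable. The class and attribute names below are illustrative, not the repository's own module:

```python
import torch
import torch.nn as nn

class SoftSpanMask(nn.Module):
    """Sketch: each head learns current_val in [0, 1]; its effective span is
    current_val * max_span, and attention weights farther back than that are
    ramped down to zero over ramp_size positions."""
    def __init__(self, n_heads, max_span, ramp_size=32, init_val=0.0):
        super().__init__()
        self.max_span, self.ramp_size = max_span, ramp_size
        self.current_val = nn.Parameter(torch.full((n_heads, 1, 1), init_val))
        # distance of each attended position from the current token,
        # ordered from farthest (max_span - 1) to nearest (0)
        self.register_buffer("distance", torch.arange(max_span - 1, -1, -1.0))

    def forward(self, attn):
        # attn: (batch, n_heads, query_len, max_span) attention weights
        span = self.current_val.clamp(0, 1) * self.max_span
        mask = ((span - self.distance) / self.ramp_size + 1).clamp(0, 1)
        attn = attn * mask
        return attn / (attn.sum(-1, keepdim=True) + 1e-8)  # renormalize
```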

Running experiments in the paper

Scripts for running the experiments in the paper are located in the ./experiments/ directory. For example, a smaller 8-layer version of our model can be trained on a single GPU by running:

bash experiments/enwik8_small.sh

It should reach about 1.3 bpc on dev after 150k steps.

For training larger models, multiple GPUs are recommended. In the script files, you can configure the number of available GPUs. Increase the --batch-split argument if you run out of GPU memory (it splits batches into smaller pieces without changing the final result).
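The idea behind --batch-split is plain gradient accumulation; the sketch below illustrates it in general terms and is not the repository's exact training loop:

```python
def train_step(model, criterion, optimizer, X, Y, batch_split):
    """Process one large batch as batch_split smaller chunks; scaling each
    chunk's loss by 1/batch_split makes the accumulated gradients match a
    single forward/backward over the full batch."""
    optimizer.zero_grad()
    for x, y in zip(X.chunk(batch_split), Y.chunk(batch_split)):
        loss = criterion(model(x), y) / batch_split
        loss.backward()          # gradients accumulate across chunks
    optimizer.step()
```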

We obtained the following results in our experiments:

Experiment | #params | dev | test
enwik8 | 38M | 1.04 bpb | 1.02 bpb
enwik8_large | 209M | 1.00 bpb | 0.98 bpb
text8 | 39M | 1.05 bpc | 1.11 bpc
text8_large | 209M | 1.01 bpc | 1.07 bpc

Training a large model takes about 1.2 sec/batch near the end (initially it is faster because the attention spans are smaller) on 8 V100 GPUs. So, for example, the whole enwik8_large training of 170k steps should take less than 2.4 days.
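A quick back-of-the-envelope check of that estimate, using the end-of-training rate as an upper bound:

```python
steps, sec_per_batch = 170_000, 1.2
print(steps * sec_per_batch / 86_400)   # ~2.36 days; actual time is lower since early steps are faster
```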

Pre-trained models

You can download pre-trained models by running the get_pretrained.sh script. The same scripts in ./experiments/ can then be used to evaluate those models. Since the download script puts models in ./checkpoints/, make sure there are no files with the same names. Note that these pre-trained models were obtained by rerunning the training scripts after the code cleanup, so there are small differences from the results above due to the randomness of training.

All-attention Network

The code can also be used for training All-attention Networks, introduced in Augmenting Self-attention with Persistent Memory. If the --pers-mem-size argument is set to N, all FF sublayers are removed from the model and N persistent memory vectors are added to every self-attention sublayer (a rough sketch of the idea appears after the table below). The following experiments can be found in the ./experiments/ directory.

Experiment | #params | dev | test
enwik8_pers_small.sh | 39M | 1.03 bpb | 1.01 bpb
enwik8_pers.sh | 114M | 1.00 bpb | 0.98 bpb
wiki103_pers.sh | 133M | 18.8 ppl * | 19.7 ppl *

(* This number is slightly better than in the paper because end-of-line is included as a token.)
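As promised above, here is a rough sketch of the persistent memory idea: a set of learned key/value vectors, shared across positions, is concatenated to the context keys/values of the attention, standing in for the removed feed-forward sublayer. Names and shapes are illustrative, not the repository's implementation:

```python
import torch
import torch.nn as nn

class PersistentMemoryAttention(nn.Module):
    """Sketch: N learned (key, value) vectors are appended to the context
    keys/values seen by the attention, playing the role of the FF sublayer."""
    def __init__(self, head_dim, pers_mem_size):
        super().__init__()
        scale = head_dim ** -0.5
        self.pers_key = nn.Parameter(torch.randn(pers_mem_size, head_dim) * scale)
        self.pers_val = nn.Parameter(torch.randn(pers_mem_size, head_dim) * scale)

    def forward(self, query, key, value):
        # query: (B, M, D); key, value: (B, L, D)
        B = key.size(0)
        key = torch.cat([key, self.pers_key.expand(B, -1, -1)], dim=1)
        value = torch.cat([value, self.pers_val.expand(B, -1, -1)], dim=1)
        attn = torch.softmax(query @ key.transpose(1, 2) / query.size(-1) ** 0.5, dim=-1)
        return attn @ value
```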

License

The code is licensed under the CC-BY-NC license. See the LICENSE file for more details.

Acknowledgement

We thank Xavier Martinet for helping clean up the code. The data preprocessing scripts are downloaded from the awd-lstm and transformer-XL repos. The adagrad_with_grad_clip.py script is mostly adapted from PyTorch.

adaptive-span's People

Contributors

geohot, tesatory


adaptive-span's Issues

Compute attention span of individual attention heads

I am working on model interpretability and would like to learn more about what each head is looking at and its attention span (similar to the graphs in the paper). Could you please share what you used to get the span of each individual head?
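One way to read this off a trained model (a hedged sketch: it assumes the per-head span parameter is exposed as current_val on the adaptive-span module, as the training code quoted in a later issue suggests; attribute paths may differ in your checkout):

```python
import torch

def per_head_spans(model, max_span):
    """Collect the learned span of every attention head by reading the
    per-head mask parameter and scaling it by the maximum span."""
    spans = {}
    for name, module in model.named_modules():
        param = getattr(module, 'current_val', None)
        if isinstance(param, torch.nn.Parameter):
            spans[name] = (param.detach().clamp(0, 1) * max_span).flatten().tolist()
    return spans
```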

Using mask can reduce FLOPs?

Hi,
Good paper.
Looking at the source code, the idea is implemented with a mask (AdaptiveMask), which means there is no actual FLOPs saving, right?

BPC

Scripts in the experiments directory calculate bits per byte, not bits per character. Am I right?

This matters when comparing character- or word-level perplexities.

For example, for English enwik8 the chars-to-bits ratio is 1.0033040809995477:
BPB: 1.0 -> byte perplexity: 2.718 -> char perplexity: 2.727

For Polish the chars-to-bits ratio is 1.0505100080652954:
BPB: 1.0 -> byte perplexity: 2.718 -> char perplexity: 2.859
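A small sketch of the conversion used in this example (assuming the quoted ratio is the average number of bytes per character; the example converts 1.0 with e rather than 2, so the same is done here to reproduce its numbers):

```python
import math

def char_perplexity(bpb, bytes_per_char):
    """Byte-level perplexity raised to the bytes-per-character ratio gives
    character-level perplexity."""
    byte_ppl = math.exp(bpb)   # reproduces the example's 2.718; with base-2 bits this would be 2 ** bpb
    return byte_ppl ** bytes_per_char

print(char_perplexity(1.0, 1.0033040809995477))  # ~2.727 (English enwik8 in the example)
print(char_perplexity(1.0, 1.0505100080652954))  # ~2.859 (Polish in the example)
```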

Question: How to reduce the memory in this project

Hi, I read your paper and it's great. I'm very interested in how the memory is reduced in a real project.

I guess the memory-related part is in this line:

key_pe = key_pe[:, :, trim_len:]

But I only see key_pe being trimmed, which saves just a little memory and, I think, would not help reduce the Q/K computation.

So, can you explain how the memory is reduced in the code?

thanks
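For context, a hedged sketch of where the saving can come from: once the learned spans shrink, the cached keys/values and the relative position embeddings can both be trimmed to the current maximum span, so the Q x K attention is computed over fewer positions. Names and shapes below are illustrative, not the repository's exact code:

```python
def trim_cache(key, value, key_pe, cache_size, current_max_span):
    """Cut cached keys/values and position embeddings down to the positions
    that can still be attended to under the current maximum span."""
    trim_len = cache_size - current_max_span
    if trim_len > 0:
        key = key[:, trim_len:, :]          # (B, L, H) -> (B, L - trim_len, H)
        value = value[:, trim_len:, :]
        key_pe = key_pe[:, :, trim_len:]    # the line quoted in the question
    return key, value, key_pe
```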

confuse

Are the dev and test results both computed on the test data set?

Experiment | #params | dev | test
enwik8 | 38M | 1.04 bpb | 1.02 bpb
enwik8_large | 209M | 1.00 bpb | 0.98 bpb
text8 | 39M | 1.05 bpc | 1.11 bpc
text8_large | 209M | 1.01 bpc | 1.07 bpc

Do they need to be evaluated multiple times on the test set? When I reproduce the model, the train and valid bpc are much larger than those obtained on test; is that normal?

The way you preprocess data is different from that of Transformer-XL

I noticed that you add an <eos> token at the end of each line:
https://github.com/facebookresearch/adaptive-span/blob/master/data.py#L34

But in Transformer-XL's code, they do not add <eos> for enwik8 and text8:
https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/data_utils.py#L205

In my experience, on enwik8 (where lines are short), using <eos> makes the final bpc/bpb about 0.02 lower.
It would be better to use the same setting for a fair comparison.
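The difference being discussed is roughly the following (an illustrative sketch, not the repository's actual data.py):

```python
def tokenize_line(line, add_eos=True):
    """Split a line into symbols, optionally appending an end-of-line token
    as adaptive-span's preprocessing does; Transformer-XL's enwik8/text8
    pipeline skips the <eos>."""
    tokens = line.split()
    if add_eos:
        tokens.append('<eos>')
    return tokens
```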

Please convert to a permissive license

Other Facebook projects like React use permissive licenses such as MIT. Would it be possible to relicense this for commercial use so that startups could also participate in development?

Queries about adaptive span

Hi, I had a few queries:

  • Does the adaptive span change over time as the model sees more data, or is the span static? In my experiments, the spans do not seem to change for some reason.
  • Secondly, as long as the values in current_val lie in [0, 1], the adaptive span loss won't change, right, since you are using _clamp(0, 1)? So how much weight does this loss carry?

Understanding adaptive-span loss

Hi,

Sorry to bother you. I have gone through the paper several times and I've also looked at the code many times.
I just have one query about the adaptive span loss. Here's what I interpreted:
The parameter self.current_val = nn.Parameter(torch.zeros(*shape) + init_val) is responsible for computing the loss, the mask, and the span.
In this case, the parameter will be initialized with zero values, since per your config init_val is kept at 0 (so the mean of all the values of the parameter will be 0).

My question is: how does this parameter get updated?

When I call adaptive_span.get_loss(), it in turn calls
self._loss_coeff * self._max_span * self._mask.current_val.mean(), which will also return 0.
When I call adaptive_span.clamp_param(), nothing happens, since all the values inside the parameter were initialized to 0. These are the only two function calls happening inside the train method.
Can you please point out what I am missing?
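For context, a sketch of why the parameter still moves away from its zero initialization: it receives gradients both from the LM loss (the soft mask built from it scales the attention weights in the forward pass) and from the auxiliary span penalty, which is linear in it. The names below are illustrative and the loop is a simplification, not the repository's train method:

```python
import torch

def train_step(model, span_mask, criterion, optimizer, x, y, loss_coeff, max_span):
    """current_val gets gradients from two paths: the LM loss through the
    soft attention mask, and the auxiliary span penalty."""
    optimizer.zero_grad()
    lm_loss = criterion(model(x), y)
    span_loss = loss_coeff * max_span * span_mask.current_val.mean()
    (lm_loss + span_loss).backward()
    optimizer.step()
    # clamp_param(): keep the raw parameter within [0, 1] after the update
    with torch.no_grad():
        span_mask.current_val.clamp_(0, 1)
```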

Understanding graphs from papers

Thanks for replying to my previous questions. I have a few queries about Fig. 3 of your paper.

  1. In Average Span vs. Span Limit (the central graph), you show that for the fixed-span model the span increases as the span limit increases. As far as I can tell from the code base, spans are monitored via current_val only when adapt_span_enabled is set to True (line). So how did you measure the span of the fixed-span model, given that in that case the flag is false and AdaptiveSpan won't track it?

  2. In FLOPS vs. Span Limit, you show that FLOPS keep increasing (approximately linearly) for the fixed-span model, while with adaptive span FLOPS stay constant. After thorough inspection, FLOPS are indeed constant with adaptive span, but they don't seem to be rising for standard attention either; in both cases the FLOPS are the same. Could you please share some insights?

Thanks
