majumderb / rezero
Official PyTorch Repo for "ReZero is All You Need: Fast Convergence at Large Depth"
Home Page: https://arxiv.org/pdf/2003.04887.pdf
License: MIT License
Should I use a single learning rate like 1e-2 for this parameter in the optimizer? Or would it be better to set a larger learning rate?
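One common way to experiment with this is a dedicated optimizer parameter group for the resweight scalars. A minimal sketch, assuming the resweight parameters are simply those whose names contain "resweight"; the specific learning rates are illustrative, not a repo recommendation:

```python
import torch
import torch.nn as nn

class ReZeroBlock(nn.Module):
    """Toy ReZero block: x + resweight * F(x), with resweight starting at 0."""
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.resweight = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return x + self.resweight * torch.relu(self.linear(x))

model = nn.Sequential(*[ReZeroBlock(64) for _ in range(4)])

# Two parameter groups: a (hypothetical) larger lr for the scalar residual
# weights, a standard lr for everything else.
resweight_params = [p for n, p in model.named_parameters() if "resweight" in n]
other_params = [p for n, p in model.named_parameters() if "resweight" not in n]
optimizer = torch.optim.Adam([
    {"params": other_params, "lr": 1e-4},
    {"params": resweight_params, "lr": 1e-2},
])
```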
Hi, nice work.
I noticed that in the encoder layer you multiply by resweight and then apply dropout, but in the decoder layer you apply dropout and then multiply by resweight. Does the order of dropout and *resweight matter?
Thanks!
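Since resweight is a scalar and standard (inverted) dropout is an elementwise multiply by a random mask, the two orders are mathematically equivalent; a quick self-contained check under that assumption:

```python
import torch
import torch.nn as nn

x = torch.randn(4, 8)
resweight = torch.tensor(0.1)
drop = nn.Dropout(p=0.5)  # a bare module defaults to training mode, so dropout is active

# Dropout multiplies elementwise by mask/(1-p); a scalar factor commutes with
# that elementwise product, so both orders give the same result
# (reseeding makes both calls draw the same mask).
torch.manual_seed(1)
a = drop(x * resweight)   # multiply by resweight, then dropout
torch.manual_seed(1)
b = drop(x) * resweight   # dropout, then multiply by resweight

print(torch.allclose(a, b))  # True
```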
Hello, I read the paper and found it interesting.
I have a question.
Many implementations, including Hugging Face's, exclude LayerNorm and bias parameters from weight decay for better convergence.
(huggingface/transformers#492)
Is it helpful to also exclude the resweight parameters from weight decay?
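A minimal sketch of that grouping, extended to also skip resweight; the no_decay name list is an assumption modeled on the Hugging Face pattern, not something this repo prescribes:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Toy module with a LayerNorm, a biased Linear, and a ReZero resweight."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.linear = nn.Linear(dim, dim)
        self.resweight = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return x + self.resweight * self.linear(self.norm(x))

model = Block(64)

# Hugging Face-style grouping: no weight decay for LayerNorm parameters,
# biases, and (the question here) the resweight scalars.
no_decay = ("bias", "norm.weight", "resweight")
grouped = [
    {"params": [p for n, p in model.named_parameters()
                if not any(k in n for k in no_decay)], "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters()
                if any(k in n for k in no_decay)], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(grouped)
```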
Hello, I saw your nice demos for the Transformer and fully connected networks.
I wonder, can it be applied to convolutional neural networks? Is there any demo project?
Thanks.
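The repo only demos Transformers and fully connected networks, but the pattern transfers directly; a hypothetical sketch of a ReZero-style residual conv block, not an official example:

```python
import torch
import torch.nn as nn

class ReZeroConvBlock(nn.Module):
    """Residual conv block: x + resweight * F(x), with resweight initialized
    to 0 so the block starts as the identity."""
    def __init__(self, channels):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
        )
        self.resweight = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return x + self.resweight * self.body(x)

x = torch.randn(2, 16, 32, 32)
block = ReZeroConvBlock(16)
print(torch.allclose(block(x), x))  # True at init: the block is the identity
```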
Have you tried experiments with image models, or just NLP?
Hi guys,
Did you run experiments on machine translation tasks, e.g. WMT En-De or En-Fr?
I experimented with ReZero in my machine translation task. While training with fp16, using ReZero caused the loss scale to hit its minimum, both with and without LayerNorm. Does that make sense?
Hi, nice work. When I apply it to a shallower BERT or GPT, I often get NaN gradients right after initialization (even for deeper architectures).
According to the paper, ReZero initializes each layer to perform the identity operation.
It seems that ReZero is designed for training networks from scratch. I wonder whether it is also applicable to fine-tuning and improves convergence there.
Great work! In your paper, ReZero shows two main benefits: training deeper networks and faster convergence. Various forms of norm and residual connections are listed in Table 1. I am curious about the form of ReZero combined with a norm, e.g. x(i+1) = x(i) + a*F(Norm(x(i))). Would it be worse or better? (A sketch of this variant follows below.)
Thanks
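For concreteness, the variant asked about simply inserts a normalization before the residual function while keeping the zero-initialized scalar; a purely illustrative sketch of that formula:

```python
import torch
import torch.nn as nn

class ReZeroPreNormBlock(nn.Module):
    """x(i+1) = x(i) + a * F(Norm(x(i))), with the scalar a initialized to 0."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.f = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        self.resweight = nn.Parameter(torch.zeros(1))  # the "a" in the formula

    def forward(self, x):
        return x + self.resweight * self.f(self.norm(x))

x = torch.randn(2, 64)
print(torch.allclose(ReZeroPreNormBlock(64)(x), x))  # identity at init: True
```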
The thing is, I see one genius, Jürgen Schmidhuber, who invented Highway Networks in May 2015, and here are his works and the subsequent works which you failed to cite:
Highway Networks (2015 May & Nov v2)
https://arxiv.org/abs/1505.00387
Training Very Deep Networks (2015 Jul & Nov v2)
https://arxiv.org/pdf/1507.06228.pdf
And ResNet is only a special case of HighwayNet, where the two gates are held constant at 1.
Highway and Residual Networks learn Unrolled Iterative Estimation (2016 Dec & 2017 Mar v2&v3)
https://arxiv.org/abs/1612.07771
And here, instead of using a gate tensor as in HighwayNet, they just use a scalar multiplier like you do,
but in 2016, not in 2020... your scientific lag is (significantly) more than zero.
Learning Identity Mappings with Residual Gates (2016 Nov & Dec v2)
https://arxiv.org/pdf/1611.01260v2.pdf
You cited neither the Gated ResNet from 2016 (which actually cites the HighwayNet) nor the HighwayNet from 2015, but you did cite Kaiming He's ResNet (which also cites the HighwayNet).
Thanks for your great work. Do you have any experiments applying ReZero at different Transformer depths, e.g. performance with 1 Transformer layer, with 2 Transformer layers, and so on? Does it make convergence faster in not-so-deep nets? Thank you.
Hi, thanks for your work. I used the ReZero method to train a 32-layer Transformer and found that, starting from the 20th layer, the resweight is almost 0 (layer 0 is the data input, layer 31 is the data output, dec_attn is the attention sublayer, pos_ff is the feed-forward sublayer, and resweight is the coefficient in x(i) + resweight * sublayer(x(i))). If resweight is almost 0, then that layer effectively does nothing.
Why does this happen?
Thanks for your help
layer | dec_attn.resweight | pos_ff.resweight
   0  |  0.0962 |  0.0908
   1  |  0.1198 |  0.1206
   2  |  0.1403 |  0.1274
   3  |  0.1621 |  0.1263
   4  |  0.2211 |  0.1438
   5  |  0.2545 |  0.1415
   6  |  0.3898 |  0.1338
   7  |  0.2653 |  0.1012
   8  | -0.0499 | -0.0796
   9  |  0.0203 | -0.0963
  10  | -0.0249 | -0.0963
  11  |  0.0133 |  0.0927
  12  | -0.0243 |  0.0958
  13  |  0.0287 | -0.0868
  14  | -0.0148 |  0.0814
  15  | -0.0198 | -0.0581
  16  |  0.0174 | -0.0743
  17  | -0.0107 | -0.0619
  18  | -0.0001 | -0.0001
  19  |  0.0061 | -0.0000
  20  |  0.0054 | -0.0001
  21  | -0.0001 | -0.0000
  22  | -0.0036 |  0.0001
  23  |  0.0042 | -0.0001
  24  |  0.0017 | -0.0000
  25  | -0.0037 | -0.0003
  26  |  0.0003 |  0.0001
  27  |  0.0004 | -0.0001
  28  | -0.0007 |  0.0001
  29  |  0.0002 | -0.0000
  30  |  0.0008 |  0.0000
  31  |  0.0008 | -0.0000
RZTXDecoderLayer is made up of self-attention and a feed-forward network with residual weights for faster convergence.
This decoder layer is based on the paper "ReZero is All You Need: Fast Convergence at Large Depth".
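The pattern the docstring describes, as a minimal sketch (not the repo's actual RZTXDecoderLayer; the sublayer composition and the choice of a single shared resweight are assumptions, and cross-attention is omitted for brevity):

```python
import torch
import torch.nn as nn

class MiniReZeroDecoderLayer(nn.Module):
    """Sketch: each sublayer's output is scaled by a zero-initialized resweight
    before being added back to the residual stream (some variants give each
    sublayer its own resweight, as in the resweight table above)."""
    def __init__(self, d_model, nhead, dim_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.ff = nn.Sequential(nn.Linear(d_model, dim_ff), nn.ReLU(),
                                nn.Linear(dim_ff, d_model))
        self.dropout = nn.Dropout(dropout)
        self.resweight = nn.Parameter(torch.zeros(1))

    def forward(self, tgt, tgt_mask=None):
        attn_out, _ = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask)
        tgt = tgt + self.resweight * self.dropout(attn_out)
        tgt = tgt + self.resweight * self.dropout(self.ff(tgt))
        return tgt
```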
Hello! Thanks for your interesting work and useful code.
I have one small question. In Table 1 of the paper, the formulation for Residual Network + Pre-Norm is x(i+1) = x(i) + F(Norm(x(i))). From my understanding, the corresponding formulation for Residual Network + Post-Norm should be x(i+1) = Norm(x(i) + F(x(i))), which is also the actual practice in ResNet. But the paper gives a different formulation. Is this a typo, or am I misunderstanding something?
In this formulation, a trick called the zero-gamma trick (setting gamma = 0 in the last batch normalization of each residual branch, where it rejoins the main branch) is commonly used [1, 2]. The similar Fixup Initialization [3] also benefits from this idea and shows the ability to train very deep neural networks. The trick is used in both the PyTorch and TensorFlow ResNet implementations. What is the relationship between ReZero and the zero-gamma trick? (A sketch of the trick follows the references below.) Thanks!
[1] Goyal et al. Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour.
[2] He et al. Bag of Tricks for Image Classification with Convolutional Neural Networks.
[3] Zhang et al. Fixup Initialization: Residual Learning Without Normalization.
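For reference, a minimal sketch of the zero-gamma trick described above, assuming a standard two-convolution residual block; illustrative only, not taken from either framework's ResNet code:

```python
import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    """Residual block with the zero-gamma trick: the gamma (weight) of the
    last BatchNorm is zeroed, so the residual branch contributes nothing at
    initialization, analogous to ReZero's zero-initialized resweight."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        nn.init.zeros_(self.bn2.weight)  # the zero-gamma trick

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))  # zero at init, so the block reduces to relu(x)
        return torch.relu(x + out)
```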