jamie-stirling / retnet
An implementation of "Retentive Network: A Successor to Transformer for Large Language Models"
License: MIT License
2511 if has_torch_function_variadic(input, weight, bias):
2512 return handle_torch_function(
2513 layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps
2514 )
-> 2515 return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument weight in method wrapper_CUDA__native_layer_norm)
The implementation of the chunkwise retention paradigm on the chunkwise-real branch gives different outputs from the other two paradigms.
It appears there may be a mistake in the paper on which the implementation was based, in equation (7). A pull request fixing this and obtaining outputs consistent with the other two paradigms would be greatly appreciated.
This can be reproduced by running `python src/tests.py`, with stdout:
FFF
======================================================================
FAIL: test_retnet (__main__.TestRetNet)
verify that the three implementations of RetNet are identical
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/jamie/Repos/RetNet/src/tests.py", line 137, in test_retnet
self.assertTrue(torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5)) # fails
AssertionError: False is not true
======================================================================
FAIL: test_multiscale (__main__.TestRetention)
verify that the three implementations of MultiScaleRetention are identical
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/jamie/Repos/RetNet/src/tests.py", line 86, in test_multiscale
self.assertTrue(torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5)) # fails
AssertionError: False is not true
======================================================================
FAIL: test_simple (__main__.TestRetention)
verify that the three implementations of SimpleRetention are identical
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/jamie/Repos/RetNet/src/tests.py", line 45, in test_simple
assert torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5) # fails
AssertionError
----------------------------------------------------------------------
Ran 3 tests in 0.098s
FAILED (failures=3)
Hi there, just found this work thanks to @yk's recent video. Nice job! There are similarities with work I've been doing for a few months. While I'm a little bummed you beat me to publish, I wasn't going to be able to do a good job of evaluating the architectures anyway (this is a side project that is currently thrashing my laptop, and I'm not sure I could justify the cloud costs to train even a moderately sized model just out of curiosity), and I'm glad the idea is being investigated and released with a permissive license.
I'm not sure if you're looking for suggestions or collaborations, but thought I'd put my ideas out there and see what happens. I'm happy to provide more details/collaborate on a future work if there's interest, or feel free to point me towards someone else who might be interested or run with it yourself.
From my understanding of the paper/code (and I apologise if I've got any of this wrong), computing retention values is still O(T^2) in sequence length T and prone to underflow (hence the NaN replacement). Neither of these is necessary. The computation you're performing is just an exponential moving average, which can be computed in O(T) with a scan using an associative operator, meaning associative_scan implementations can do it very efficiently in parallel.
Unfortunately we're still waiting on PyTorch's associative_scan implementation, so I'll be using jax below, for which a primitive exists. Note I've got a PyTorch version working which wraps the jax implementations with jax2torch, though I can't make it work nicely with torch's compile, and I'm more comfortable with jax anyway.
Below is an implementation that takes an arbitrary decay factor at each step. To get the same behaviour as in your paper, I think you can just set factors = gamma * ones_like(values):
import typing as tp
import jax
import jax.numpy as jnp
Pair = tp.Tuple[jnp.ndarray, jnp.ndarray]
def _cumulative_ema_op(a: Pair, b: Pair) -> Pair:
    xa, fa = a
    xb, fb = b
    return xa * fb + xb, fa * fb

def cumulative_ema(
    values: jnp.ndarray, factors: jnp.ndarray, reverse: bool = False, axis: int = 0
) -> jnp.ndarray:
"""
Compute cumulative exponential moving average.
If `reverse == False` and axis == 0,
output[i+1] = output[i] * factors[i+1] + output[i+1]
If `reverse == True`, then the result is the reverse of the non-reversed call on
arguments reversed on the given axis.
Args:
values: N-D float values
factors: same shape/dtype as values
axis: the axis to compute exponential moving average along.
reverse: if True, perform accumulation in reverse.
Returns:
cumulative ema values, same shape as values/factors.
"""
    if axis < 0:
        axis += len(values.shape)
    assert values.shape == factors.shape, (values.shape, factors.shape)
    f, t = jax.lax.associative_scan(
        _cumulative_ema_op, (values, factors), reverse=reverse, axis=axis
    )
    del t
    return f
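As a quick sanity check of the semantics (a toy example of my own, not from the original post), the scan agrees with a plain Python reference loop:

import numpy as np

v = jnp.asarray(np.random.default_rng(0).normal(size=(6,)))
f = jnp.full_like(v, 0.9)

# reference loop: out[t] = out[t-1] * f[t] + v[t], with out[-1] taken to be 0
expected, acc = [], 0.0
for t in range(v.shape[0]):
    acc = acc * f[t] + v[t]
    expected.append(acc)

assert jnp.allclose(cumulative_ema(v, f), jnp.stack(expected), atol=1e-6)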
Thus computing retention values from Q, K and V values would be:
def retention(Q, K, gamma, V, reverse=False):
    """
    Notation:
        T: time dimension
        A: attention dimension
        C: number of output channels

    Args:
        Q: [T, A] query
        K: [T, A] key
        gamma: [] decay constant
        V: [T, C] values

    Returns:
        [T, C]
    """
    rhs = jnp.einsum('ta,tc->tac', K, V)
    rhs = cumulative_ema(rhs, jnp.full_like(rhs, gamma), axis=0, reverse=reverse)
    return jnp.einsum('ta,tac->tc', Q, rhs)
I've left out the batch dimension for simplicity, but I'm sure you could make the appropriate modifications (or, if you decide to use jax, just vmap it). I'll spare you the full theoretical derivation of why this computes (Q @ K.T * D) @ V, but the short version is that we use property 1 from here (see last slide) and note that D @ X = cumulative_ema(X, jnp.full_like(X, gamma), axis=0). This is O(TAC) in space/time rather than O(T^2 (A + C)) in time and O(T (T + C)) in space.
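As a quick numerical check of that claim (again a toy example of mine, not part of the original code):

import numpy as np

T, A, C = 8, 4, 5
gamma = 0.9
rng = np.random.default_rng(0)
Q = jnp.asarray(rng.normal(size=(T, A)))
K = jnp.asarray(rng.normal(size=(T, A)))
V = jnp.asarray(rng.normal(size=(T, C)))

# naive parallel form: (Q @ K.T * D) @ V with D[n, m] = gamma ** (n - m) for n >= m, else 0
n = jnp.arange(T)[:, None]
m = jnp.arange(T)[None, :]
D = jnp.tril(gamma ** (n - m))
Y_parallel = (Q @ K.T * D) @ V

assert jnp.allclose(Y_parallel, retention(Q, K, gamma, V), atol=1e-5)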
Creating a bidirectional encoder is thus trivial by combining two - one with reverse=False and the other with reverse=True.
Now, with that implementation you might be tempted to play around with the architecture a little. I've played with creating only two transformed matrices, factors (sigmoid-activated to ensure decay) and values, of the same shape (rather than Q, K, V), and using them in the cumulative_ema directly, which reduces the O(TAC) memory/time requirement to O(TC). Conceptually this just means that each token embedding at each layer decides how much of the past to forget and what to add, based on the previous layer's embedding. I don't see any barriers to implementing a complex version to allow for periodic behaviour, but I haven't attempted that.
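For concreteness, here is a rough sketch of that O(TC) variant using the cumulative_ema above; the function name and weight shapes are my own guesses at what is meant, not code from either implementation:

def gated_ema_layer(x, W_f, W_v):
    """x: [T, d_in]; W_f, W_v: [d_in, C]. Returns [T, C]."""
    factors = jax.nn.sigmoid(x @ W_f)  # per-token, per-channel decay in (0, 1)
    values = x @ W_v                   # what each token contributes to the running state
    return cumulative_ema(values, factors, axis=0)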
My implementation is keras_core-based (so you can use the pytorch backend, so long as you don't try to compile). It needs a lot of cleaning up before I'm prepared to make it public, but I'm happy to share privately. Very small-scale experiments where I've just replaced BERT's self-attention mechanism with the bidirectional O(TC) implementation discussed above and removed positional embeddings entirely have proved promising (faster training, better performance than BERT). I have no way of validating whether performance scales with model size - I was planning on looking for collaborators/sponsors for that, so if you're interested let me know :).
I have a question regarding the "Chunkwise Recurrent Representation of Retention" (equation (7) in the paper). In your implementation, the code looks like this:
r_i = (K.transpose(-1, -2) @ (V * D[-1].view(1, chunk_size, 1))) + (self.gamma ** chunk_size) * r_i_1
The first part of this equation calculates the KV matrix for the current chunk and then multiplies it by a scaling factor. My understanding is that, ignoring the batch dimension, the shapes of K and V for the current chunk are both (2, 3); in other words, the current chunk contains 2 tokens, so the KV matrix should have a shape of (3, 3). Then, based on your code, you multiply this KV matrix by the last row of the D matrix (which has shape (2, 2)). For example, if the D matrix is [[1, 0], [0.9, 1]], then D[-1].view(1, chunk_size, 1) becomes [[0.9], [1]], and these values are multiplied with the first and second rows of the V matrix to implement decay. However, when we take the inner product of the Q matrix for the chunk and the first half of R_i, it seems like both q tokens within the Q matrix use the same decay factor, is that correct? In other words, within the same chunk, if we want to perform attention, the second q token should intuitively be multiplied by a decay factor (0.9) when attending to the first v token, but when the first q token operates on the first v token, it doesn't need this decay factor.
Additionally, for the second half of R_i, it seems that the entire chunk is considered as a whole, and R_i_1 is directly subjected to decay as a whole, and the decay occurs as many times as the length of the chunk.
There's another question I have regarding the cross-chunk calculations.
# e[i,j] = gamma ** (i+1)
e = torch.zeros(batch, chunk_size, 1)
for _i in range(chunk_size):
    e[:, _i, :] = self.gamma ** (_i + 1)
cross_chunk = (Q @ r_i_1) * e
In the code, the variable 'e' appears to play a role in decay as well. However, based on the code, the final result after calculating (Q @ r_i_1) might be something like [o1, o2, o3]^T, where each 'oi' is a row vector with D dimensions. What I'd like to point out is that, according to your code, 'o1' actually has the least decay, and 'o3' has the most decay. But intuitively, for the current Q, shouldn't the vector corresponding to 'o1' be the farthest from the q tokens within the current chunk? In other words, shouldn't the decay of 'o1' be the greatest? So, should the code be like this:
# e[i,j] = gamma ** (i+1)
e = torch.zeros(batch, chunk_size, 1)
for _i in range(chunk_size):
    # e[:, _i, :] = self.gamma ** (_i + 1)
    e[:, _i, :] = self.gamma ** (chunk_size - _i)
cross_chunk = (Q @ r_i_1) * e
This is very confusing to me. Is there a more detailed derivation or a clearer explanation of how equation (7) in the original article is obtained? In particular, for the exponent of the decay factor, is the result of this calculation consistent with the result of the fully parallel computation? Can someone help me with this?
First, many thanks for your implementation!
It seems that the _get_D function
def _get_D(self, sequence_length):
    # D[n,m] = gamma ** (n - m) if n >= m else 0
    D = torch.zeros((sequence_length, sequence_length), requires_grad=False)
    for n in range(sequence_length):
        for m in range(sequence_length):
            if n >= m:
                D[n, m] = self.gamma ** (n - m)
    return D
gets really slow for long sequence lengths, resulting in very low GPU utilization.
By changing it to the style below it gets better. Not sure if it's perfectly correct, but for gamma < 1 it all seems good.
def _get_D(self, sequence_length):
    n = torch.arange(sequence_length).unsqueeze(1)
    m = torch.arange(sequence_length).unsqueeze(0)
    # Broadcast self.gamma ** (n - m) with appropriate masking to set values where n < m to 0
    D = (self.gamma ** (n - m)) * (n >= m).float()  # this produces some NaN (inf * 0) where m is much larger than n
    # fill the NaN with 0
    D[D != D] = 0
    return D
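If you'd rather avoid producing the NaNs in the first place, one option (a sketch with the same interface, assuming self.gamma is a plain scalar as above) is to clamp the exponent before masking, so gamma is never raised to a large negative power:

def _get_D(self, sequence_length):
    n = torch.arange(sequence_length).unsqueeze(1)
    m = torch.arange(sequence_length).unsqueeze(0)
    # clamp (n - m) at 0 so no inf is produced; the mask then zeroes the n < m entries
    D = (self.gamma ** (n - m).clamp(min=0)) * (n >= m).float()
    return D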
Maybe I am missing something but do we need the Theta? Since its magnitude is 1, multiplying with its conjugate should cancel out in the parallel version.
Can this model achieve cross-attention, similar to how the Transformer handles embedding matrices from different modalities?
The current implementation uses complex arithmetic to implement the original paper, which has known issues with stability and precision. It's been suggested that xPos is a more stable and efficient way to achieve the same thing, by using Euler's formula to represent the rotations in real arithmetic.
It would be nice if all constructors had an additional option to do arithmetic in real algebra using xPos (rotary positional embeddings), as described in this paper:
https://arxiv.org/abs/2212.10554
and implemented here:
https://github.com/microsoft/torchscale/blob/main/torchscale/component/xpos_relative_position.py
This may solve the current issues with numerical stability.
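For reference, a rough sketch of what such an xPos-style rotation looks like in real arithmetic (loosely following the torchscale code linked above; the function names, the 10000 base, the 0.4/1.4 scale constants and scale_base=512 are assumptions on my part, not the repo's API):

import torch

def _rotate_every_two(x):
    # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def xpos_rotate(x, scale_base=512, downscale=False):
    # x: [T, d]; rotary rotation plus the xPos per-dimension decay, all in real arithmetic
    T, d = x.shape
    freqs = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))          # [d/2]
    pos = torch.arange(T).float()                                         # [T]
    angles = torch.outer(pos, freqs).repeat_interleave(2, dim=-1)         # [T, d]
    scale = ((torch.arange(0, d, 2).float() / d + 0.4) / 1.4).repeat_interleave(2, dim=-1)
    scale = scale ** ((pos - T // 2) / scale_base).unsqueeze(-1)          # [T, d]
    if downscale:  # keys get the inverse scale, queries the scale itself
        scale = 1.0 / scale
    return (x * torch.cos(angles) + _rotate_every_two(x) * torch.sin(angles)) * scale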
I'm a little confused about what RetNet does in practice. In the formula Retention(X) = (Q @ K.T * D) @ V, if the decay is 1, the mathematical derivation proving the equivalence between the RNN form and RetNet's transformer-style form still works. When the decay is equal to 1, D becomes the normal attention mask used by almost all existing GPT models. Does that mean all existing GPT models can be modified into a RetNet by simply modifying the inference function, without any further training? Am I correct, or am I missing something?
Sorry for bothering you, and this may be a dumb question: what is the Complex type here for?
I'm not very good at math, so it would be great if you could explain why complex numbers are needed.
As we all know, xPos already has a decay ability, but you also apply the decay matrix D after Q @ K^T.
Isn't that redundant?
return (self.swish(X @ self.W_G.to(self.complex_type)) + Y) @ self.W_O.to(self.complex_type)
You may want to look at microsoft/torchscale@bf65397
Thanks for the well-written package! RetNet's official implementation has had several updates; see https://github.com/microsoft/unilm/blob/master/retnet/README.md#changelog .
Line 104 in 2acf026
You should change the device of these tensors so that it matches the model's device.
When training on a GPU I get a device-mismatch error.
Thank you very much for this work. Can RetNet be applied to point cloud tasks?
In the MultiScaleRetention class, it is mentioned that 's_n_1s' has dimensions (batch_size, heads, head_size, head_size), while in SimpleRetention, 's_n_1' is defined as 's_n_1s[i]'. However, you mentioned that 's_n_1' has dimensions (batch_size, hidden_size, v_dim). Can you clarify this?
Hello,
I am new to the LLM world and I can't seem to train a new RetNet model. Is there a script I can use, or can you guide me to some resources?
Thank you in advance.
Hello, I have reviewed some of the code and noticed that no attention mask is used. Since this is RetNet, don't you need to mask out the pad IDs? Or do the pad IDs have no impact on the preceding sequence?
Excellent! Can you write a demo example for training? Compared to the Microsoft code, have you checked that it produces the same number of parameters for the same settings?
Line 45 in 2acf026
Q and K are put onto whatever device the model is on because they are model parameters, while D is created in SimpleRetention._get_D and is not moved to any device. Therefore, if you train on CUDA, Q and K are on cuda while D is on the CPU, and the error arises.
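A minimal sketch of one way to fix this (assuming the forward pass has access to the input tensor X, as in the repo's SimpleRetention):

# create D on the same device as the incoming activations
D = self._get_D(sequence_length).to(X.device)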
Hi! Is there a way to pass an attention mask, like in the transformers library, or a src_key_padding_mask as in nn.Transformer, so that the model wouldn't "pay attention" to padding?