This is a PyTorch implementation of the Low Rank Factorization for Compact Multi-Head Attention (LAMA) mechanism and the corresponding pooler introduced in the paper "Low Rank Factorization for Compact Multi-Head Self-Attention".
Figure 1 from the paper "Low Rank Factorization for Compact Multi-Head Self-Attention".
Note: I am not one of the authors on the paper.
The only dependency is PyTorch. Installation instructions can be found on the official PyTorch website (https://pytorch.org/get-started/locally/).
import torch
from modules.lama import LAMA
# Example sizes, listed in the same order as the input tensor's dimensions.
batch_size = 16  # sentences/documents per mini-batch
max_seq_len = 100  # maximum length of an input sequence
input_dim = 768  # size of each token's hidden representation
num_heads = 8  # number of attention heads

# A random batch of token representations: (batch_size, max_seq_len, input_dim).
inputs = torch.randn(batch_size, max_seq_len, input_dim)

# The mask is optional; it has shape (batch_size, max_seq_len), with 1 for
# timesteps to keep and 0 for timesteps to mask (e.g. padding tokens).
# This example masks the final timestep of every sequence.
mask = torch.ones(batch_size, max_seq_len)
mask[:, max_seq_len - 1] = 0

# Build the attention mechanism and run it over the batch.
lama = LAMA(num_heads, input_dim)
output = lama(inputs, mask)

# The output holds one value per (sequence, head, timestep); presumably these
# are the per-head attention weights over timesteps.
assert tuple(output.shape) == (batch_size, num_heads, max_seq_len)
import torch
from modules.lama_encoder import LAMAEncoder
# Example sizes, listed in the same order as the input tensor's dimensions.
batch_size = 16  # sentences/documents per mini-batch
max_seq_len = 100  # maximum length of an input sequence
input_dim = 768  # size of each token's hidden representation
num_heads = 8  # number of attention heads

# A random batch of token representations: (batch_size, max_seq_len, input_dim).
inputs = torch.randn(batch_size, max_seq_len, input_dim)

# The mask is optional; it has shape (batch_size, max_seq_len), with 1 for
# timesteps to keep and 0 for timesteps to mask (e.g. padding tokens).
# This example masks the final timestep of every sequence.
mask = torch.ones(batch_size, max_seq_len)
mask[:, max_seq_len - 1] = 0

# Without output_dim, the encoder returns a "structured sentence embedding":
# one input_dim-sized vector per head for every sequence in the batch.
lama_encoder = LAMAEncoder(num_heads, input_dim)
output = lama_encoder(inputs, mask)
assert tuple(output.shape) == (batch_size, num_heads, input_dim)

# With a non-None output_dim, the structured sentence embedding is flattened
# by concatenation and passed through a linear layer. The result is one
# output_dim-sized vector per sequence.
lama_encoder = LAMAEncoder(num_heads, input_dim, output_dim=128)
output = lama_encoder(inputs, mask)
assert tuple(output.shape) == (batch_size, 128)