diff --git a/modeling/pythia-1B-dijiang/modeling_gpt_neox.py b/modeling/pythia-1B-dijiang/modeling_gpt_neox.py
--- a/modeling/pythia-1B-dijiang/modeling_gpt_neox.py
+++ b/modeling/pythia-1B-dijiang/modeling_gpt_neox.py
@@ -40,6 +40,9 @@
from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
from .configuration_gpt_neox import GPTNeoXConfig
+import numpy as np
+from scipy.fft import dct
+
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -98,7 +101,7 @@
class GPTNeoXAttention(nn.Module):
- def __init__(self, config):
+ def __init__(self, config, gamma):
super().__init__()
self.config = config
self.num_attention_heads = config.num_attention_heads
@@ -118,6 +121,17 @@
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.attention_bias)
self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
self.attention_dropout = nn.Dropout(config.attention_dropout)
+
+        self.proj_matrix = self._build_projection()  # fixed DCT-based feature projection (head_size x head_size), shared across heads
+        self.v_dim = config.hidden_size
+        self.W_G = nn.Parameter(torch.randn(self.hidden_size, self.v_dim) / self.hidden_size)  # output-gate projection (re-initialized below)
+        self.swish = nn.SiLU()
+        self.group_norm = nn.GroupNorm(self.head_size, self.v_dim)  # group-normalizes the merged attention output before gating
+        nn.init.xavier_uniform_(self.W_G.data, gain=2 ** -2.5)
+        self.D1 = self._get_D1(self.config.max_position_embeddings)  # per-head decay gamma_h ** n, applied to queries
+        self.D2 = self._get_D2(self.config.max_position_embeddings)  # reciprocal decay gamma_h ** (-m), applied to keys
+        self.mask = self._get_mask(self.config.max_position_embeddings).unsqueeze(0)  # (1, heads, L, L) causal mask
+
self.is_causal = True
def _init_bias(self, max_positions, device=None):
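
The constructor now builds, besides the usual QKV and dense projections: a fixed DCT-based feature projection (proj_matrix), per-head decay tensors D1/D2, a causal mask, and a gated, group-normalized output path (W_G, SiLU, GroupNorm). A minimal, self-contained sketch of how that output path is combined at the end of forward; toy sizes only, nothing here beyond the operations themselves is taken from the patched file:

    import torch
    import torch.nn as nn

    # Toy sizes for illustration; the real values come from GPTNeoXConfig.
    hidden_size, head_size, seq_len = 64, 16, 10

    W_G = nn.Parameter(torch.empty(hidden_size, hidden_size))
    nn.init.xavier_uniform_(W_G, gain=2 ** -2.5)
    swish = nn.SiLU()
    group_norm = nn.GroupNorm(head_size, hidden_size)  # head_size groups over the value dimension, as in the patch

    hidden_states = torch.randn(2, seq_len, hidden_size)  # layer input
    attn_output = torch.randn(2, seq_len, hidden_size)    # merged-head attention output

    normed = group_norm(attn_output.reshape(-1, hidden_size)).reshape(attn_output.shape)
    gated = swish(hidden_states @ W_G) * normed           # this product is what feeds self.dense
    print(gated.shape)  # torch.Size([2, 10, 64])
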
@@ -155,7 +169,32 @@
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+ def _get_D1(self, sequence_length):
+        D = ((1 - torch.exp(torch.linspace(np.log(1/32), np.log(1/512), self.num_attention_heads))).view(self.num_attention_heads, 1, 1) ** (torch.arange(sequence_length).unsqueeze(1))).float().unsqueeze(0)
+ return nn.Parameter(D, requires_grad=False)
+
+ def _get_D2(self, sequence_length):
+        D = 1/((1 - torch.exp(torch.linspace(np.log(1/32), np.log(1/512), self.num_attention_heads))).view(self.num_attention_heads, 1, 1) ** (torch.arange(sequence_length).unsqueeze(1))).float().unsqueeze(0)
+
+ return nn.Parameter(D, requires_grad=False)
+
+ def _get_mask(self, sequence_length):
+ n = torch.arange(sequence_length).unsqueeze(1)
+ m = torch.arange(sequence_length).unsqueeze(0)
+
+        M = torch.ones(self.num_attention_heads).view(self.num_attention_heads, 1, 1) * (n >= m).float()  # lower-triangular causal mask, one copy per head
+
+ return M
+
+ def _build_projection(self):
+ icdf_w = torch.distributions.Normal(0, 1).icdf(torch.diag_embed(torch.diag(torch.rand(self.head_size, self.head_size))))
+        icdf_w = torch.where(torch.isinf(icdf_w), torch.full_like(icdf_w, 0), icdf_w)  # icdf(0) = -inf off the diagonal; zero it out
+        C = dct(np.eye(self.head_size, self.head_size), axis=0, norm='ortho')  # orthonormal DCT basis
+        C = torch.from_numpy(C).type(torch.FloatTensor)
+        return nn.Parameter((C @ icdf_w).contiguous(), requires_grad=False)
+
def forward(
self,
hidden_states: torch.FloatTensor,
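
_get_D1/_get_D2 bake an exponential position decay into queries and keys (gamma_h ** n and gamma_h ** (-m) respectively), so the later q·k^T matmul carries a factor gamma_h ** (n - m) on every causally visible pair, while _build_projection produces the fixed feature map applied within each head: an orthonormal DCT basis reweighted by a diagonal of Gaussian samples. A self-contained sketch of both constructions (same constants, toy sizes; the assert only checks the decay factorization):

    import math
    import numpy as np
    import torch
    from scipy.fft import dct

    num_heads, head_size, seq_len = 4, 8, 16

    # gamma_h = 1 - exp(linspace(log(1/32), log(1/512))), the same per-head schedule used in the patch.
    gammas = 1 - torch.exp(torch.linspace(math.log(1 / 32), math.log(1 / 512), num_heads))

    n = torch.arange(seq_len).unsqueeze(1)                    # (L, 1)
    D1 = (gammas.view(num_heads, 1, 1) ** n).unsqueeze(0)     # (1, H, L, 1): gamma_h ** n
    D2 = (gammas.view(num_heads, 1, 1) ** (-n)).unsqueeze(0)  # (1, H, L, 1): gamma_h ** (-m)
    mask = (torch.arange(seq_len).unsqueeze(1) >= torch.arange(seq_len).unsqueeze(0)).float()

    # Scaling queries by D1 and keys by D2 before the matmul weights score (n, m) by gamma_h ** (n - m).
    decay = (D1 * D2.transpose(2, 3)) * mask                  # (1, H, L, L)
    assert torch.isclose(decay[0, 1, 10, 3], gammas[1] ** (10 - 3))

    # Fixed feature map: orthonormal DCT basis reweighted by a diagonal of standard-normal samples.
    diag_gauss = torch.diag(torch.distributions.Normal(0, 1).icdf(torch.rand(head_size)))
    C = torch.from_numpy(dct(np.eye(head_size), axis=0, norm="ortho")).float()
    proj = C @ diag_gauss                                     # analogue of self.proj_matrix
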
@@ -212,7 +251,8 @@
# Reshape outputs
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
- attn_output = self.dense(attn_output)
+        attn_output = self.group_norm(attn_output.reshape(-1, self.v_dim)).reshape(attn_output.shape)  # group-normalize the merged heads
+        attn_output = self.dense(self.swish(hidden_states @ self.W_G) * attn_output)  # SiLU output gate before the dense projection
outputs = (attn_output, present)
if output_attentions:
@@ -256,36 +296,16 @@
self._init_bias(key_length, device=key.device)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
- query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
- key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
- attn_scores = torch.zeros(
- batch_size * num_attention_heads,
- query_length,
- key_length,
- dtype=query.dtype,
- device=key.device,
- )
- attn_scores = torch.baddbmm(
- attn_scores,
- query,
- key.transpose(1, 2),
- beta=1.0,
- alpha=self.norm_factor,
- )
- attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
+        query = nn.functional.softmax(query @ self.proj_matrix, dim=-1)  # non-negative kernel features for queries
+        key = nn.functional.softmax(key @ self.proj_matrix, dim=-1)  # and for keys
+        query = query * self.D1[:, :, :query_length, :]
+        key = key * self.D2[:, :, :key_length, :]
- mask_value = torch.finfo(attn_scores.dtype).min
- # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
- # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
- mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
- attn_scores = torch.where(causal_mask, attn_scores, mask_value)
+ attn_scores = torch.matmul(query, key.transpose(2, 3))
- if attention_mask is not None:
- # Apply the attention mask
- attn_scores = attn_scores + attention_mask
+        attn_scores = attn_scores * self.mask[:, :, :query_length, :key_length].to(attn_scores.device, dtype=attn_scores.dtype)  # multiplicative causal mask; no softmax follows
- attn_weights = nn.functional.softmax(attn_scores, dim=-1)
- attn_weights = attn_weights.to(value.dtype)
+ attn_weights = attn_scores.to(value.dtype)
# Mask heads if we want to
if head_mask is not None:
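
With the rewritten _attn, the scores are no longer passed through a softmax: they are an inner product of non-negative query/key features, scaled by the decay tensors and zeroed outside the causal triangle. One property of this functional form worth keeping in mind (not something the patch itself exercises) is that it admits an equivalent recurrent evaluation with a constant-size state per head, which is what makes decay-weighted linear attention attractive at decode time. A sketch under that reading, with hypothetical helper names:

    import torch

    def parallel_form(q_feat, k_feat, v, gamma):
        # q_feat, k_feat: (L, d) non-negative features (softmax(x @ proj) in the patch); v: (L, dv).
        L = q_feat.shape[0]
        n = torch.arange(L, dtype=torch.float32).unsqueeze(1)
        scores = (q_feat * gamma ** n) @ (k_feat * gamma ** (-n)).T          # gamma ** (n - m) * <q_n, k_m>
        causal = (torch.arange(L).unsqueeze(1) >= torch.arange(L).unsqueeze(0)).float()
        return (scores * causal) @ v

    def recurrent_form(q_feat, k_feat, v, gamma):
        # Same computation with one running (d, dv) state: S_n = gamma * S_{n-1} + k_n v_n^T.
        S = torch.zeros(q_feat.shape[1], v.shape[1])
        out = []
        for t in range(q_feat.shape[0]):
            S = gamma * S + torch.outer(k_feat[t], v[t])
            out.append(q_feat[t] @ S)
        return torch.stack(out)

    torch.manual_seed(0)
    L, d, dv, gamma = 12, 8, 8, 0.97
    q_feat = torch.softmax(torch.randn(L, d), dim=-1)
    k_feat = torch.softmax(torch.randn(L, d), dim=-1)
    v = torch.randn(L, dv)
    assert torch.allclose(parallel_form(q_feat, k_feat, v, gamma),
                          recurrent_form(q_feat, k_feat, v, gamma), atol=1e-5)
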
@@ -667,14 +687,14 @@
class GPTNeoXLayer(nn.Module):
- def __init__(self, config):
+ def __init__(self, config, gamma):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_dropout = nn.Dropout(config.hidden_dropout)
self.post_mlp_dropout = nn.Dropout(config.hidden_dropout)
- self.attention = GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config)
+ self.attention = GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config, gamma)
self.mlp = GPTNeoXMLP(config)
def forward(
@@ -787,7 +807,8 @@
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
self.emb_dropout = nn.Dropout(config.hidden_dropout)
- self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gammas = (1 - torch.exp(torch.linspace(np.log(1/32), np.log(1/512), config.num_hidden_layers))).detach().cpu().tolist()
+ self.layers = nn.ModuleList([GPTNeoXLayer(config, gamma) for gamma in self.gammas])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
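
GPTNeoXModel draws one gamma per layer from the same geometric schedule and hands it to each GPTNeoXLayer and attention constructor. A quick way to inspect the schedule, assuming a hypothetical 16-layer config (the actual count comes from config.num_hidden_layers):

    import math
    import torch

    num_hidden_layers = 16  # assumed for illustration
    gammas = (1 - torch.exp(torch.linspace(math.log(1 / 32), math.log(1 / 512), num_hidden_layers))).tolist()
    print([round(g, 4) for g in gammas])
    # Runs from 1 - 1/32 = 0.9688 up to 1 - 1/512 = 0.998, spaced geometrically in (1 - gamma).
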