Commit 973e3eb

Merge branch 'fla-org:main' into main

yibozhong authored Jan 18, 2025
2 parents ed12c30 + ca9e943
Showing 9 changed files with 62 additions and 34 deletions.
fla/layers/abc.py (10 changes: 7 additions & 3 deletions)
@@ -144,18 +144,22 @@ def forward(
if last_state is not None:
conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
+ position_ids = kwargs.get('position_ids', None)
q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
mask=conv_mask,
cache=conv_state_q,
- output_final_state=use_cache)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
mask=conv_mask,
cache=conv_state_k,
- output_final_state=use_cache)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
mask=conv_mask,
cache=conv_state_v,
- output_final_state=use_cache)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
else:
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
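
This hunk, and every other hunk in this commit, applies the same pattern: inside forward(), the layer now reads position_ids from its keyword arguments and forwards it to each ShortConvolution call as seq_idx (previously either nothing or the raw seq_idx kwarg was passed). The sketch below isolates that forwarding pattern with stand-in modules; ShortConvStub and ToyLayer are illustrative stubs written for this note, not fla.modules.ShortConvolution or any class in this repository.

import torch
import torch.nn as nn

class ShortConvStub(nn.Module):
    # Stand-in for fla.modules.ShortConvolution: a depthwise causal conv that
    # accepts the same keyword arguments the diff above passes.
    def __init__(self, dim, kernel_size=4):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, kernel_size, groups=dim, padding=kernel_size - 1)

    def forward(self, x, mask=None, cache=None, output_final_state=False, seq_idx=None):
        # seq_idx is accepted but unused in this stub; it mirrors the parameter
        # the updated layers now pass to the real ShortConvolution.
        y = self.conv(x.transpose(1, 2))[..., :x.shape[1]].transpose(1, 2)
        return y, cache

class ToyLayer(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.q_conv1d = ShortConvStub(dim)

    def forward(self, hidden_states, use_cache=False, **kwargs):
        # The change made across all nine layers: pull position_ids out of kwargs...
        position_ids = kwargs.get('position_ids', None)
        # ...and hand it to the short convolution as seq_idx.
        q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
                                        cache=None,
                                        output_final_state=use_cache,
                                        seq_idx=position_ids)
        return q

x = torch.randn(2, 8, 64)
position_ids = torch.arange(8).repeat(2, 1)
out = ToyLayer()(x, position_ids=position_ids)  # shape (2, 8, 64)
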
fla/layers/delta_net.py (15 changes: 8 additions & 7 deletions)
@@ -11,8 +11,7 @@
from torch.nn import functional as F

from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
- from fla.ops.delta_rule import (chunk_delta_rule,
- fused_recurrent_delta_rule)
+ from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule

if TYPE_CHECKING:
from transformers.processing_utils import Unpack
@@ -199,19 +198,22 @@ def forward(
if last_state is not None:
conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
- seq_idx=kwargs.get('seq_idx', None)
+ position_ids = kwargs.get('position_ids', None)
q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
mask=conv_mask,
cache=conv_state_q,
- output_final_state=use_cache,seq_idx=seq_idx)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
mask=conv_mask,
cache=conv_state_k,
- output_final_state=use_cache,seq_idx=seq_idx)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
mask=conv_mask,
cache=conv_state_v,
- output_final_state=use_cache,seq_idx=seq_idx)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
else:
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
@@ -230,7 +232,6 @@ def forward(
else:
raise NotImplementedError
-

if self.qk_norm == 'sum':
q = sum_norm(q).to(q)
k = sum_norm(k).to(k)
fla/layers/gated_deltanet.py (11 changes: 7 additions & 4 deletions)
@@ -212,19 +212,22 @@ def forward(
if last_state is not None:
conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
- seq_idx=kwargs.get('seq_idx', None)
+ position_ids = kwargs.get('position_ids', None)
q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
mask=conv_mask,
cache=conv_state_q,
- output_final_state=use_cache,seq_idx=seq_idx)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
mask=conv_mask,
cache=conv_state_k,
- output_final_state=use_cache,seq_idx=seq_idx)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
mask=conv_mask,
cache=conv_state_v,
- output_final_state=use_cache,seq_idx=seq_idx)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
else:
q = self.silu(self.q_proj(hidden_states))
k = self.silu(self.k_proj(hidden_states))
fla/layers/gla.py (11 changes: 7 additions & 4 deletions)
@@ -184,19 +184,22 @@ def forward(
if last_state is not None:
conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
- seq_idx=kwargs.get('seq_idx', None)
+ position_ids = kwargs.get('position_ids', None)
q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
mask=conv_mask,
cache=conv_state_q,
- output_final_state=use_cache,seq_idx=seq_idx)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
mask=conv_mask,
cache=conv_state_k,
- output_final_state=use_cache,seq_idx=seq_idx)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
mask=conv_mask,
cache=conv_state_v,
- output_final_state=use_cache,seq_idx=seq_idx)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
else:
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
fla/layers/gsa.py (11 changes: 7 additions & 4 deletions)
@@ -157,19 +157,22 @@ def forward(
if last_state is not None:
conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
- seq_idx=kwargs.get('seq_idx', None)
+ position_ids = kwargs.get('position_ids', None)
q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
mask=conv_mask,
cache=conv_state_q,
- output_final_state=use_cache,seq_idx=seq_idx)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
mask=conv_mask,
cache=conv_state_k,
- output_final_state=use_cache,seq_idx=seq_idx)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
mask=conv_mask,
cache=conv_state_v,
- output_final_state=use_cache,seq_idx=seq_idx)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
else:
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
fla/layers/hgrn.py (7 changes: 5 additions & 2 deletions)
@@ -103,14 +103,17 @@ def forward(
if last_state is not None:
conv_state_i, conv_state_f = last_state['conv_state']
conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
+ position_ids = kwargs.get('position_ids', None)
i, conv_state_i = self.i_conv1d(x=self.i_proj(hidden_states),
mask=conv_mask,
cache=conv_state_i,
- output_final_state=use_cache)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
f, conv_state_f = self.f_conv1d(x=self.f_proj(hidden_states),
mask=conv_mask,
cache=conv_state_f,
- output_final_state=use_cache)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
else:
i = self.i_proj(hidden_states)
f = self.f_proj(hidden_states)
fla/layers/hgrn2.py (11 changes: 7 additions & 4 deletions)
@@ -120,19 +120,22 @@ def forward(
if last_state is not None:
conv_state_q, conv_state_f, conv_state_i = last_state['conv_state']
conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
- seq_idx=kwargs.get('seq_idx', None)
+ position_ids = kwargs.get('position_ids', None)
q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
mask=conv_mask,
cache=conv_state_q,
- output_final_state=use_cache,seq_idx=seq_idx)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
f, conv_state_f = self.f_conv1d(x=self.f_proj(hidden_states),
mask=conv_mask,
cache=conv_state_f,
- output_final_state=use_cache,seq_idx=seq_idx)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
i, conv_state_i = self.i_conv1d(x=self.i_proj(hidden_states),
mask=conv_mask,
cache=conv_state_i,
- output_final_state=use_cache,seq_idx=seq_idx)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
else:
q = self.q_proj(hidden_states)
f = self.f_proj(hidden_states)
fla/layers/multiscale_retention.py (10 changes: 7 additions & 3 deletions)
@@ -176,18 +176,22 @@ def forward(
if last_state is not None:
conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
+ position_ids = kwargs.get('position_ids', None)
q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
mask=conv_mask,
cache=conv_state_q,
- output_final_state=use_cache)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
mask=conv_mask,
cache=conv_state_k,
- output_final_state=use_cache)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
mask=conv_mask,
cache=conv_state_v,
- output_final_state=use_cache)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
else:
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
fla/layers/simple_gla.py (10 changes: 7 additions & 3 deletions)
@@ -170,18 +170,22 @@ def forward(
if last_state is not None:
conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
+ position_ids = kwargs.get('position_ids', None)
q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
mask=conv_mask,
cache=conv_state_q,
- output_final_state=use_cache)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
mask=conv_mask,
cache=conv_state_k,
- output_final_state=use_cache)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
mask=conv_mask,
cache=conv_state_v,
- output_final_state=use_cache)
+ output_final_state=use_cache,
+ seq_idx=position_ids)
else:
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
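
Caller-facing summary of the commit: abc.py, hgrn.py, multiscale_retention.py, and simple_gla.py previously passed nothing for seq_idx, while delta_net.py, gated_deltanet.py, gla.py, gsa.py, and hgrn2.py read it from the seq_idx kwarg; after this commit all nine layers read it from position_ids and forward it to their ShortConvolution calls. A hypothetical compatibility helper for call sites still written against the old kwarg name (not part of the library, shown only to make the rename concrete):

def call_with_position_ids(layer, hidden_states, seq_idx=None, **kwargs):
    # Old call sites passed seq_idx=...; the updated layers read the same
    # tensor from position_ids instead and hand it to their convolutions.
    return layer(hidden_states, position_ids=seq_idx, **kwargs)
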
