From ca9e9437a967a74171d680d8c841e1e4434e7bbe Mon Sep 17 00:00:00 2001
From: Yu Zhang
Date: Sat, 18 Jan 2025 10:53:09 -0800
Subject: [PATCH] Add full support for varlen short convs

---
 fla/layers/abc.py                  | 10 +++++++---
 fla/layers/delta_net.py            | 15 ++++++++-------
 fla/layers/gated_deltanet.py       | 11 +++++++----
 fla/layers/gla.py                  | 11 +++++++----
 fla/layers/gsa.py                  | 11 +++++++----
 fla/layers/hgrn.py                 |  7 +++++--
 fla/layers/hgrn2.py                | 11 +++++++----
 fla/layers/multiscale_retention.py | 10 +++++++---
 fla/layers/simple_gla.py           | 10 +++++++---
 9 files changed, 62 insertions(+), 34 deletions(-)

diff --git a/fla/layers/abc.py b/fla/layers/abc.py
index 1db5d94fe..6d1cf15c8 100644
--- a/fla/layers/abc.py
+++ b/fla/layers/abc.py
@@ -142,18 +142,22 @@ def forward(
             if last_state is not None:
                 conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
             conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
+            position_ids = kwargs.get('position_ids', None)
             q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_q,
-                                            output_final_state=use_cache)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
             k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_k,
-                                            output_final_state=use_cache)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
             v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_v,
-                                            output_final_state=use_cache)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
         else:
             q = self.q_proj(hidden_states)
             k = self.k_proj(hidden_states)
diff --git a/fla/layers/delta_net.py b/fla/layers/delta_net.py
index 19f24fd89..f45751d58 100644
--- a/fla/layers/delta_net.py
+++ b/fla/layers/delta_net.py
@@ -11,8 +11,7 @@
 from torch.nn import functional as F
 
 from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
-from fla.ops.delta_rule import (chunk_delta_rule,
-                                fused_recurrent_delta_rule)
+from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule
 
 if TYPE_CHECKING:
     from transformers.processing_utils import Unpack
@@ -199,19 +198,22 @@ def forward(
             if last_state is not None:
                 conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
             conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
-            seq_idx=kwargs.get('seq_idx', None)
+            position_ids = kwargs.get('position_ids', None)
             q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_q,
-                                            output_final_state=use_cache,seq_idx=seq_idx)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
             k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_k,
-                                            output_final_state=use_cache,seq_idx=seq_idx)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
             v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_v,
-                                            output_final_state=use_cache,seq_idx=seq_idx)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
         else:
             q = self.q_proj(hidden_states)
             k = self.k_proj(hidden_states)
@@ -230,7 +232,6 @@ def forward(
         else:
             raise NotImplementedError
 
-
         if self.qk_norm == 'sum':
             q = sum_norm(q).to(q)
             k = sum_norm(k).to(k)
diff --git a/fla/layers/gated_deltanet.py b/fla/layers/gated_deltanet.py
index d0aca2f7f..d03977eb5 100644
--- a/fla/layers/gated_deltanet.py
+++ b/fla/layers/gated_deltanet.py
@@ -212,19 +212,22 @@ def forward(
             if last_state is not None:
                 conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
             conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
-            seq_idx=kwargs.get('seq_idx', None)
+            position_ids = kwargs.get('position_ids', None)
             q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_q,
-                                            output_final_state=use_cache,seq_idx=seq_idx)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
             k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_k,
-                                            output_final_state=use_cache,seq_idx=seq_idx)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
             v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_v,
-                                            output_final_state=use_cache,seq_idx=seq_idx)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
         else:
             q = self.silu(self.q_proj(hidden_states))
             k = self.silu(self.k_proj(hidden_states))
diff --git a/fla/layers/gla.py b/fla/layers/gla.py
index 560460cfa..e86e04183 100644
--- a/fla/layers/gla.py
+++ b/fla/layers/gla.py
@@ -184,19 +184,22 @@ def forward(
             if last_state is not None:
                 conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
             conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
-            seq_idx=kwargs.get('seq_idx', None)
+            position_ids = kwargs.get('position_ids', None)
             q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_q,
-                                            output_final_state=use_cache,seq_idx=seq_idx)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
             k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_k,
-                                            output_final_state=use_cache,seq_idx=seq_idx)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
             v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_v,
-                                            output_final_state=use_cache,seq_idx=seq_idx)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
         else:
             q = self.q_proj(hidden_states)
             k = self.k_proj(hidden_states)
diff --git a/fla/layers/gsa.py b/fla/layers/gsa.py
index fd30dcfc3..c89bf3ada 100644
--- a/fla/layers/gsa.py
+++ b/fla/layers/gsa.py
@@ -157,19 +157,22 @@ def forward(
             if last_state is not None:
                 conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
             conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
-            seq_idx=kwargs.get('seq_idx', None)
+            position_ids = kwargs.get('position_ids', None)
             q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_q,
-                                            output_final_state=use_cache,seq_idx=seq_idx)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
             k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_k,
-                                            output_final_state=use_cache,seq_idx=seq_idx)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
             v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_v,
-                                            output_final_state=use_cache,seq_idx=seq_idx)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
         else:
             q = self.q_proj(hidden_states)
             k = self.k_proj(hidden_states)
diff --git a/fla/layers/hgrn.py b/fla/layers/hgrn.py
index 716a56e7b..7b1bbd0e0 100644
--- a/fla/layers/hgrn.py
+++ b/fla/layers/hgrn.py
@@ -103,14 +103,17 @@ def forward(
             if last_state is not None:
                 conv_state_i, conv_state_f = last_state['conv_state']
             conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
+            position_ids = kwargs.get('position_ids', None)
             i, conv_state_i = self.i_conv1d(x=self.i_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_i,
-                                            output_final_state=use_cache)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
             f, conv_state_f = self.f_conv1d(x=self.f_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_f,
-                                            output_final_state=use_cache)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
         else:
             i = self.i_proj(hidden_states)
             f = self.f_proj(hidden_states)
diff --git a/fla/layers/hgrn2.py b/fla/layers/hgrn2.py
index 53a915a45..e19b773c8 100644
--- a/fla/layers/hgrn2.py
+++ b/fla/layers/hgrn2.py
@@ -120,19 +120,22 @@ def forward(
             if last_state is not None:
                 conv_state_q, conv_state_f, conv_state_i = last_state['conv_state']
             conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
-            seq_idx=kwargs.get('seq_idx', None)
+            position_ids = kwargs.get('position_ids', None)
             q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_q,
-                                            output_final_state=use_cache,seq_idx=seq_idx)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
             f, conv_state_f = self.f_conv1d(x=self.f_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_f,
-                                            output_final_state=use_cache,seq_idx=seq_idx)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
             i, conv_state_i = self.i_conv1d(x=self.i_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_i,
-                                            output_final_state=use_cache,seq_idx=seq_idx)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
         else:
             q = self.q_proj(hidden_states)
             f = self.f_proj(hidden_states)
diff --git a/fla/layers/multiscale_retention.py b/fla/layers/multiscale_retention.py
index 9238f1bc6..d5f1989b7 100644
--- a/fla/layers/multiscale_retention.py
+++ b/fla/layers/multiscale_retention.py
@@ -176,18 +176,22 @@ def forward(
             if last_state is not None:
                 conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
             conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
+            position_ids = kwargs.get('position_ids', None)
             q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_q,
-                                            output_final_state=use_cache)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
             k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_k,
-                                            output_final_state=use_cache)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
             v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_v,
-                                            output_final_state=use_cache)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
         else:
             q = self.q_proj(hidden_states)
             k = self.k_proj(hidden_states)
diff --git a/fla/layers/simple_gla.py b/fla/layers/simple_gla.py
index 805639e0b..e652d7b0a 100644
--- a/fla/layers/simple_gla.py
+++ b/fla/layers/simple_gla.py
@@ -170,18 +170,22 @@ def forward(
             if last_state is not None:
                 conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
             conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
+            position_ids = kwargs.get('position_ids', None)
             q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_q,
-                                            output_final_state=use_cache)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
             k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_k,
-                                            output_final_state=use_cache)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
             v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
                                             mask=conv_mask,
                                             cache=conv_state_v,
-                                            output_final_state=use_cache)
+                                            output_final_state=use_cache,
+                                            seq_idx=position_ids)
         else:
             q = self.q_proj(hidden_states)
             k = self.k_proj(hidden_states)
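Illustrative note (an editor's sketch, not part of the patch): every layer touched above now reads position_ids from its forward **kwargs and hands it to ShortConvolution as seq_idx. Assuming a caller that packs several variable-length sequences into a single row and a model wrapper that simply forwards position_ids down to these layers through kwargs (the model / input_ids names below are placeholders, not APIs confirmed by this patch), the varlen inputs could be prepared roughly like this:

import torch

# Three sequences of lengths 4, 2 and 3 packed into one row (varlen layout).
seq_lens = [4, 2, 3]

# position_ids restart at 0 at every packed-sequence boundary; the layers above
# pick them up via kwargs.get('position_ids', None) and pass them to the short
# convolution as seq_idx so the conv can respect those boundaries.
position_ids = torch.cat([torch.arange(n) for n in seq_lens]).unsqueeze(0)
# -> tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2]])

# Hypothetical call into an fla-based causal LM (placeholder names):
# outputs = model(input_ids=input_ids, position_ids=position_ids)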