From 683913ff2fe9bc2520b742cb35aef31d48f95130 Mon Sep 17 00:00:00 2001
From: yibozhong
Date: Sun, 19 Jan 2025 17:15:41 +0800
Subject: [PATCH] remove separate folder

---
 fla/vision_models/__init__.py | 43 --
 fla/vision_models/abc/__init__.py | 12 -
 fla/vision_models/abc/configuration_abc.py | 97 ----
 fla/vision_models/abc/modeling_abc.py | 203 --------
 fla/vision_models/bitnet/__init__.py | 12 -
 .../bitnet/configuration_bitnet.py | 95 ----
 fla/vision_models/bitnet/modeling_bitnet.py | 201 --------
 fla/vision_models/delta_net/__init__.py | 16 -
 .../delta_net/configuration_delta_net.py | 101 ----
 .../delta_net/modeling_delta_net.py | 385 --------------
 fla/vision_models/gated_deltanet/__init__.py | 13 -
 .../configuration_gated_deltanet.py | 89 ----
 .../gated_deltanet/modeling_gated_deltanet.py | 202 --------
 fla/vision_models/gla/__init__.py | 12 -
 fla/vision_models/gla/configuration_gla.py | 101 ----
 fla/vision_models/gla/modeling_gla.py | 205 --------
 fla/vision_models/gsa/__init__.py | 12 -
 fla/vision_models/gsa/configuration_gsa.py | 107 ----
 fla/vision_models/gsa/modeling_gsa.py | 209 --------
 fla/vision_models/hgrn/__init__.py | 12 -
 fla/vision_models/hgrn/configuration_hgrn.py | 86 ----
 fla/vision_models/hgrn/modeling_hgrn.py | 197 -------
 fla/vision_models/hgrn2/__init__.py | 12 -
 .../hgrn2/configuration_hgrn2.py | 89 ----
 fla/vision_models/hgrn2/modeling_hgrn2.py | 198 --------
 fla/vision_models/linear_attn/__init__.py | 12 -
 .../linear_attn/configuration_linear_attn.py | 96 ----
 .../linear_attn/modeling_linear_attn.py | 197 -------
 fla/vision_models/retnet/__init__.py | 12 -
 .../retnet/configuration_retnet.py | 101 ----
 fla/vision_models/retnet/modeling_retnet.py | 202 --------
 fla/vision_models/rwkv6/__init__.py | 12 -
 .../rwkv6/configuration_rwkv6.py | 94 ----
 fla/vision_models/rwkv6/modeling_rwkv6.py | 199 --------
 fla/vision_models/transformer/__init__.py | 12 -
 .../transformer/configuration_transformer.py | 81 ---
 .../transformer/modeling_transformer.py | 190 -------
 fla/vision_models/utils.py | 480 ------------------
 38 files changed, 4397 deletions(-)
 delete mode 100644 fla/vision_models/__init__.py
 delete mode 100644 fla/vision_models/abc/__init__.py
 delete mode 100644 fla/vision_models/abc/configuration_abc.py
 delete mode 100644 fla/vision_models/abc/modeling_abc.py
 delete mode 100644 fla/vision_models/bitnet/__init__.py
 delete mode 100644 fla/vision_models/bitnet/configuration_bitnet.py
 delete mode 100644 fla/vision_models/bitnet/modeling_bitnet.py
 delete mode 100644 fla/vision_models/delta_net/__init__.py
 delete mode 100644 fla/vision_models/delta_net/configuration_delta_net.py
 delete mode 100644 fla/vision_models/delta_net/modeling_delta_net.py
 delete mode 100644 fla/vision_models/gated_deltanet/__init__.py
 delete mode 100644 fla/vision_models/gated_deltanet/configuration_gated_deltanet.py
 delete mode 100644 fla/vision_models/gated_deltanet/modeling_gated_deltanet.py
 delete mode 100644 fla/vision_models/gla/__init__.py
 delete mode 100644 fla/vision_models/gla/configuration_gla.py
 delete mode 100644 fla/vision_models/gla/modeling_gla.py
 delete mode 100644 fla/vision_models/gsa/__init__.py
 delete mode 100644 fla/vision_models/gsa/configuration_gsa.py
 delete mode 100644 fla/vision_models/gsa/modeling_gsa.py
 delete mode 100644 fla/vision_models/hgrn/__init__.py
 delete mode 100644 fla/vision_models/hgrn/configuration_hgrn.py
 delete mode 100644 fla/vision_models/hgrn/modeling_hgrn.py
 delete mode 100644 fla/vision_models/hgrn2/__init__.py
 delete mode 100644 fla/vision_models/hgrn2/configuration_hgrn2.py
 delete mode 100644 fla/vision_models/hgrn2/modeling_hgrn2.py
 delete mode 100644 fla/vision_models/linear_attn/__init__.py
 delete mode 100644 fla/vision_models/linear_attn/configuration_linear_attn.py
 delete mode 100644 fla/vision_models/linear_attn/modeling_linear_attn.py
 delete mode 100644 fla/vision_models/retnet/__init__.py
 delete mode 100644 fla/vision_models/retnet/configuration_retnet.py
 delete mode 100644 fla/vision_models/retnet/modeling_retnet.py
 delete mode 100644 fla/vision_models/rwkv6/__init__.py
 delete mode 100644 fla/vision_models/rwkv6/configuration_rwkv6.py
 delete mode 100644 fla/vision_models/rwkv6/modeling_rwkv6.py
 delete mode 100644 fla/vision_models/transformer/__init__.py
 delete mode 100644 fla/vision_models/transformer/configuration_transformer.py
 delete mode 100644 fla/vision_models/transformer/modeling_transformer.py
 delete mode 100644 fla/vision_models/utils.py

diff --git a/fla/vision_models/__init__.py b/fla/vision_models/__init__.py
deleted file mode 100644
index 9538e6c2c..000000000
--- a/fla/vision_models/__init__.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from fla.vision_models.abc import ABCVisionConfig, ABCForImageClassification
-from fla.vision_models.bitnet import BitNetVisionConfig, BitNetForImageClassification
-from fla.vision_models.delta_net import DeltaNetVisionConfig, DeltaNetForImageClassification
-from fla.vision_models.gated_deltanet import GatedDeltaNetVisionConfig, GatedDeltaNetForImageClassification
-from fla.vision_models.gla import GLAVisionConfig, GLAForImageClassification
-from fla.vision_models.gsa import GSAVisionConfig, GSAForImageClassification
-from fla.vision_models.hgrn import HGRNVisionConfig, HGRNForImageClassification
-from fla.vision_models.hgrn2 import HGRN2VisionConfig, HGRN2ForImageClassification
-from fla.vision_models.linear_attn import LinearAttentionVisionConfig, LinearAttentionForImageClassification
-from fla.vision_models.retnet import RetNetVisionConfig, RetNetForImageClassification
-from fla.vision_models.rwkv6 import RWKV6VisionConfig, RWKV6ForImageClassification
-from fla.vision_models.transformer import TransformerVisionConfig, TransformerForImageClassification
-from fla.vision_models.utils import ImageEmbeddings, PatchEmbeddings, Pooler
-
-__all__ = [
-    'ABCVisionConfig',
-    'ABCForImageClassification',
-    'BitNetVisionConfig',
-    'BitNetForImageClassification',
-    'DeltaNetVisionConfig',
-    'DeltaNetForImageClassification',
-    'GatedDeltaNetVisionConfig',
-    'GatedDeltaNetForImageClassification',
-    'GLAVisionConfig',
-    'GLAForImageClassification',
-    'GSAVisionConfig',
-    'GSAForImageClassification',
-    'HGRNVisionConfig',
-    'HGRNForImageClassification',
-    'HGRN2VisionConfig',
-    'HGRN2ForImageClassification',
-    'LinearAttentionVisionConfig',
-    'LinearAttentionForImageClassification',
-    'RetNetVisionConfig',
-    'RetNetForImageClassification',
-    'RWKV6VisionConfig',
-    'RWKV6ForImageClassification',
-    'TransformerVisionConfig',
-    'TransformerForImageClassification',
-    'ImageEmbeddings',
-    'PatchEmbeddings',
-    'Pooler',
-]
diff --git a/fla/vision_models/abc/__init__.py b/fla/vision_models/abc/__init__.py
deleted file mode 100644
index 67d013691..000000000
--- a/fla/vision_models/abc/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from transformers import AutoConfig, AutoModelForImageClassification
-
-from fla.vision_models.abc.configuration_abc import ABCVisionConfig
-from fla.vision_models.abc.modeling_abc import ABCForImageClassification
-
-AutoConfig.register(ABCVisionConfig.model_type,
ABCVisionConfig) -AutoModelForImageClassification.register(ABCVisionConfig, ABCForImageClassification) - -__all__ = [ - 'ABCVisionConfig', - 'ABCForImageClassification' -] diff --git a/fla/vision_models/abc/configuration_abc.py b/fla/vision_models/abc/configuration_abc.py deleted file mode 100644 index 13de13c09..000000000 --- a/fla/vision_models/abc/configuration_abc.py +++ /dev/null @@ -1,97 +0,0 @@ -from typing import Dict, Optional - -from transformers.configuration_utils import PretrainedConfig - - -class ABCVisionConfig(PretrainedConfig): - - model_type = 'abc_vision' - - def __init__( - self, - # ABC core parameters - hidden_size: int = 2048, - gate_low_rank_dim: int = 16, - clamp_min: float = -32, - clamp_max: float = 32, - num_hidden_layers: int = 24, - num_heads: int = 4, - num_slots: Optional[int] = 64, - use_short_conv: bool = False, - conv_size: int = 4, - exapnd_k: float = 0.5, - exapnd_v: float = 1, - hidden_act: str = "swish", - max_position_embeddings: int = 2048, - elementwise_affine: Optional[bool] = True, - norm_eps: float = 1e-6, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_norm: bool = True, - fuse_cross_entropy: bool = True, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - qkv_bias: bool = True, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - # Initialize ABC core parameters - self.hidden_size = hidden_size - self.gate_low_rank_dim = gate_low_rank_dim - self.clamp_min = clamp_min - self.clamp_max = clamp_max - self.num_hidden_layers = num_hidden_layers - self.num_heads = num_heads - self.num_slots = num_slots - self.use_short_conv = use_short_conv - self.conv_size = conv_size - self.expand_k = exapnd_k - self.expand_v = exapnd_v - self.hidden_act = hidden_act - self.max_position_embeddings = max_position_embeddings - self.elementwise_affine = elementwise_affine - self.norm_eps = norm_eps - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_norm = fuse_norm - self.fuse_cross_entropy = fuse_cross_entropy - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.qkv_bias = qkv_bias - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) diff --git a/fla/vision_models/abc/modeling_abc.py 
b/fla/vision_models/abc/modeling_abc.py deleted file mode 100644 index ecba55a35..000000000 --- a/fla/vision_models/abc/modeling_abc.py +++ /dev/null @@ -1,203 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_abc import ABCVisionConfig -from fla.layers.abc import ABCAttention -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class ABCMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class ABCBlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = ABCAttention( - hidden_size=config.hidden_size, - expand_k=config.expand_k, - expand_v=config.expand_v, - num_heads=config.num_heads, - num_slots=config.num_slots, - use_short_conv=config.use_short_conv, - conv_size=config.conv_size, - gate_fn=config.hidden_act, - elementwise_affine=config.elementwise_affine, - norm_eps=config.norm_eps, - clamp_min=config.clamp_min, - clamp_max=config.clamp_max, - fuse_norm=config.fuse_norm, - layer_idx=layer_idx - ) - - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = ABCMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection 
- hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class ABCVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = ABCVisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class ABCForImageClassification(ABCVisionPreTrainedModel): - config_class = ABCVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - ABCBlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/bitnet/__init__.py b/fla/vision_models/bitnet/__init__.py deleted file mode 100644 index 148b4557f..000000000 --- a/fla/vision_models/bitnet/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.bitnet.configuration_bitnet import BitNetVisionConfig -from fla.vision_models.bitnet.modeling_bitnet import BitNetForImageClassification - 
-AutoConfig.register(BitNetVisionConfig.model_type, BitNetVisionConfig) -AutoModelForImageClassification.register(BitNetVisionConfig, BitNetForImageClassification) - -__all__ = [ - 'BitNetVisionConfig', - 'BitNetForImageClassification' -] diff --git a/fla/vision_models/bitnet/configuration_bitnet.py b/fla/vision_models/bitnet/configuration_bitnet.py deleted file mode 100644 index 902f3f9a3..000000000 --- a/fla/vision_models/bitnet/configuration_bitnet.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import Dict, Optional - -from transformers.configuration_utils import PretrainedConfig - - -class BitNetVisionConfig(PretrainedConfig): - - model_type = 'bitnet_vision' - - def __init__( - self, - # BitNet core parameters - hidden_size: int = 2048, - num_hidden_layers: int = 24, - num_heads: int = 32, - num_kv_heads: int = None, - window_size: Optional[int] = None, - rope_theta: Optional[float] = 10000., - max_position_embeddings: int = 2048, - hidden_act: str = "swish", - initializer_range: float = 0.02, - elementwise_affine: Optional[bool] = True, - norm_first: bool = False, - norm_eps: float = 1e-6, - use_cache: bool = True, - attention_bias: bool = False, - fuse_norm: bool = True, - fuse_cross_entropy: bool = True, - attn: Optional[Dict] = None, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - qkv_bias: bool = True, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - # Initialize BitNet core parameters - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.window_size = window_size - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - self.hidden_act = hidden_act - - self.initializer_range = initializer_range - self.elementwise_affine = elementwise_affine - self.norm_first = norm_first - self.norm_eps = norm_eps - self.use_cache = use_cache - self.attention_bias = attention_bias - self.fuse_cross_entropy = fuse_cross_entropy - self.fuse_norm = fuse_norm - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.qkv_bias = qkv_bias - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/vision_models/bitnet/modeling_bitnet.py 
b/fla/vision_models/bitnet/modeling_bitnet.py deleted file mode 100644 index fa9675095..000000000 --- a/fla/vision_models/bitnet/modeling_bitnet.py +++ /dev/null @@ -1,201 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_bitnet import BitNetVisionConfig -from fla.layers.bitattn import BitAttention -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class BitNetMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class BitNetBlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - if not config.norm_first: - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = BitAttention( - hidden_size=config.hidden_size, - num_heads=config.num_heads, - num_kv_heads=config.num_kv_heads, - window_size=config.window_size, - rope_theta=config.rope_theta, - max_position_embeddings=config.max_position_embeddings, - norm_first=config.norm_first, - norm_eps=config.norm_eps, - layer_idx=layer_idx - ) - - if not config.norm_first: - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = BitNetMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states 
- - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class BitNetVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = BitNetVisionConfig - base_model_prefix = "bitnet" - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class BitNetForImageClassification(BitNetVisionPreTrainedModel): - config_class = BitNetVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - BitNetBlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/delta_net/__init__.py b/fla/vision_models/delta_net/__init__.py deleted file mode 100644 index eef31ccbc..000000000 --- a/fla/vision_models/delta_net/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from transformers import AutoConfig, AutoModel, AutoModelForImageClassification, AutoModelForMaskedImageModeling - -from fla.vision_models.delta_net.configuration_delta_net import DeltaNetVisionConfig -from 
fla.vision_models.delta_net.modeling_delta_net import DeltaNetForImageClassification, DeltaNetVisionModel, DeltaNetForMaskedImageModeling - -AutoConfig.register(DeltaNetVisionConfig.model_type, DeltaNetVisionConfig) -AutoModelForImageClassification.register(DeltaNetVisionConfig, DeltaNetForImageClassification) -AutoModelForMaskedImageModeling.register(DeltaNetVisionConfig, DeltaNetForMaskedImageModeling) -AutoModel.register(DeltaNetVisionConfig, DeltaNetVisionModel) - -__all__ = [ - "DeltaNetVisionConfig", - "DeltaNetForImageClassification", - "DeltaNetVisionModel", - "DeltaNetForMaskedImageModeling" -] diff --git a/fla/vision_models/delta_net/configuration_delta_net.py b/fla/vision_models/delta_net/configuration_delta_net.py deleted file mode 100644 index b24b48908..000000000 --- a/fla/vision_models/delta_net/configuration_delta_net.py +++ /dev/null @@ -1,101 +0,0 @@ -from typing import Dict, Optional -from transformers.configuration_utils import PretrainedConfig - -class DeltaNetVisionConfig(PretrainedConfig): - model_type = 'delta_net_vision' - - def __init__( - self, - # DeltaNet core parameters - attn_mode: str = "chunk", - hidden_size: int = 2048, - expand_k: int = 1, - expand_v: int = 1, - use_gate: bool = False, - use_short_conv: bool = True, - conv_size: int = 4, - use_beta: bool = True, - use_output_norm: bool = True, - num_heads: int = 16, - qk_norm: str = 'l2', - qk_activation: str = 'silu', - intermediate_size: Optional[int] = None, - hidden_act: str = "swish", - num_hidden_layers: int = 12, - norm_first: bool = False, - norm_eps: float = 1e-6, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_cross_entropy: bool = True, - max_position_embeddings: int = 2048, - - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - encoder_stride=16, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - # Initialize DeltaNet core parameters - self.attn_mode = attn_mode - self.hidden_size = hidden_size - self.expand_k = expand_k - self.expand_v = expand_v - self.use_gate = use_gate - self.use_short_conv = use_short_conv - self.conv_size = conv_size - self.use_beta = use_beta - self.use_output_norm = use_output_norm - self.num_heads = num_heads - self.qk_norm = qk_norm - self.qk_activation = qk_activation - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.num_hidden_layers = num_hidden_layers - self.norm_first = norm_first - self.norm_eps = norm_eps - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_cross_entropy = fuse_cross_entropy - self.max_position_embeddings = max_position_embeddings - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - self.encoder_stride = encoder_stride - - - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not 
in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) diff --git a/fla/vision_models/delta_net/modeling_delta_net.py b/fla/vision_models/delta_net/modeling_delta_net.py deleted file mode 100644 index e9ece6adc..000000000 --- a/fla/vision_models/delta_net/modeling_delta_net.py +++ /dev/null @@ -1,385 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput, BaseModelOutput, BaseModelOutputWithPooling, MaskedImageModelingOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_delta_net import DeltaNetVisionConfig -from fla.layers.delta_net import DeltaNet -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class DeltaNetMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class DeltaNetBlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - if not config.norm_first: - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = DeltaNet( - mode=config.attn_mode, - hidden_size=config.hidden_size, - expand_k=config.expand_k, - expand_v=config.expand_v, - num_heads=config.num_heads, - use_gate=config.use_gate, - use_beta=config.use_beta, - use_short_conv=config.use_short_conv, - use_output_norm=config.use_output_norm, - conv_size=config.conv_size, - qk_norm=config.qk_norm, - qk_activation=config.qk_activation, - norm_first=config.norm_first, - norm_eps=config.norm_eps, - layer_idx=layer_idx - ) - - if not config.norm_first: - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = DeltaNetMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # 
Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class DeltaNetVisionPreTrainedModel(PreTrainedModel): - config_class = DeltaNetVisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - - -class DeltaNetVisionEncoder(nn.Module): - def __init__(self, config) -> None: - super().__init__() - self.config = config - self.blocks = nn.ModuleList([ - DeltaNetBlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - output_attentions: bool = False, - output_hidden_states: bool = False, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - return_dict: bool = True, - **kwargs - ) -> Union[tuple, BaseModelOutput]: - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - - for i, block in enumerate(self.blocks): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - else: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - if output_attentions: - all_self_attentions = all_self_attentions + (attentions,) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) - - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - -class DeltaNetVisionModel(DeltaNetVisionPreTrainedModel): - def __init__(self, config, add_pooling_layer=True, use_mask_token=False): 
- super().__init__(config) - self.config = config - self.embeddings = ImageEmbeddings(config, use_mask_token=use_mask_token) - self.encoder = DeltaNetVisionEncoder(config) - self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) if add_pooling_layer else None - self.init_weights() - - def get_input_embeddings(self): - return self.embeddings.patch_embeddings - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - bool_masked_pos: Optional[torch.BoolTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - interpolate_pos_encoding: Optional[bool] = None, - use_cache: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs - ) -> Union[Tuple, BaseModelOutputWithPooling]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding) - - encoder_outputs = self.encoder( - hidden_states, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - return_dict=return_dict, - **kwargs - ) - - sequence_output = encoder_outputs[0] - sequence_output = self.layernorm(sequence_output) - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - - if not return_dict: - head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) - return head_outputs + encoder_outputs[1:] - - return BaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - -class DeltaNetForImageClassification(DeltaNetVisionPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - self.backbone = DeltaNetVisionModel(config, add_pooling_layer=True) # Here we should use mean pooling - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.backbone( - pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - return_dict=return_dict, - ) - - pooled_output = outputs.pooler_output - logits = self.classifier(pooled_output) # only use mean pooling - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, 
self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - -class DeltaNetForMaskedImageModeling(DeltaNetVisionPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.backbone = DeltaNetVisionModel(config, add_pooling_layer=False, use_mask_token=True) - self.decoder = nn.Sequential( - nn.Conv2d( - in_channels=config.hidden_size, - out_channels=config.encoder_stride**2 * config.num_channels, - kernel_size=1, - ), - nn.PixelShuffle(config.encoder_stride), - ) - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - bool_masked_pos: Optional[torch.BoolTensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple, MaskedImageModelingOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): - raise ValueError( - "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " - "the reconstructed image has the same dimensions as the input. " - f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." - ) - - outputs = self.backbone( - pixel_values, - bool_masked_pos=bool_masked_pos, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - return_dict=return_dict, - ) - - - sequence_output = outputs[0] - batch_size, sequence_length, num_channels = sequence_output.shape - height = width = math.floor(sequence_length**0.5) - sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) - - # Reconstruct pixel values - reconstructed_pixel_values = self.decoder(sequence_output) - - masked_im_loss = None - if bool_masked_pos is not None: - size = self.config.image_size // self.config.patch_size - bool_masked_pos = bool_masked_pos.reshape(-1, size, size) - mask = ( - bool_masked_pos.repeat_interleave(self.config.patch_size, 1) - .repeat_interleave(self.config.patch_size, 2) - .unsqueeze(1) - .contiguous() - ) - reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") - masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels - - if not return_dict: - output = (reconstructed_pixel_values,) + outputs[1:] - return ((masked_im_loss,) + output) if masked_im_loss is not None else output - - return MaskedImageModelingOutput( - loss=masked_im_loss, - reconstruction=reconstructed_pixel_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/fla/vision_models/gated_deltanet/__init__.py b/fla/vision_models/gated_deltanet/__init__.py deleted file mode 100644 index 45bb5ffbf..000000000 --- a/fla/vision_models/gated_deltanet/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetVisionConfig -from 
fla.vision_models.gated_deltanet.modeling_gated_deltanet import GatedDeltaNetForImageClassification - -AutoConfig.register(GatedDeltaNetVisionConfig.model_type, GatedDeltaNetVisionConfig) -AutoModelForImageClassification.register(GatedDeltaNetVisionConfig, GatedDeltaNetForImageClassification) - -__all__ = [ - 'GatedDeltaNetVisionConfig', - 'GatedDeltaNetForImageClassification' -] - diff --git a/fla/vision_models/gated_deltanet/configuration_gated_deltanet.py b/fla/vision_models/gated_deltanet/configuration_gated_deltanet.py deleted file mode 100644 index 6cbbd9e72..000000000 --- a/fla/vision_models/gated_deltanet/configuration_gated_deltanet.py +++ /dev/null @@ -1,89 +0,0 @@ -from typing import Dict, Optional -from transformers.configuration_utils import PretrainedConfig - -class GatedDeltaNetVisionConfig(PretrainedConfig): - model_type = 'gated_deltanet_vision' - - def __init__( - self, - # GatedDeltaNet core parameters - attn_mode: str = "chunk", - hidden_size: int = 2048, - expand_v: int = 2, - use_gate: bool = True, - use_short_conv: bool = True, - conv_size: int = 4, - head_dim: int = 256, - num_heads: int = 6, - max_position_embeddings: int = 2048, - hidden_act: str = "swish", - num_hidden_layers: int = 21, - norm_first: bool = False, - norm_eps: float = 1e-6, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_cross_entropy: bool = True, - - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", - **kwargs - ): - # Initialize GatedDeltaNet core parameters - self.attn_mode = attn_mode - self.hidden_size = hidden_size - self.expand_v = expand_v - self.head_dim = head_dim - self.use_gate = use_gate - self.use_short_conv = use_short_conv - self.conv_size = conv_size - self.num_heads = num_heads - self.hidden_act = hidden_act - self.num_hidden_layers = num_hidden_layers - self.norm_first = norm_first - self.norm_eps = norm_eps - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_cross_entropy = fuse_cross_entropy - self.attn = attn - self.max_position_embeddings = max_position_embeddings - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) diff --git a/fla/vision_models/gated_deltanet/modeling_gated_deltanet.py b/fla/vision_models/gated_deltanet/modeling_gated_deltanet.py 
deleted file mode 100644 index 94694cca8..000000000 --- a/fla/vision_models/gated_deltanet/modeling_gated_deltanet.py +++ /dev/null @@ -1,202 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_gated_deltanet import GatedDeltaNetVisionConfig -from fla.layers.gated_deltanet import GatedDeltaNet -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class GatedDeltaNetMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class GatedDeltaNetBlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - if not config.norm_first: - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = GatedDeltaNet( - mode=config.attn_mode, - hidden_size=config.hidden_size, - expand_v=config.expand_v, - head_dim=config.head_dim, - num_heads=config.num_heads, - use_gate=config.use_gate, - use_short_conv=config.use_short_conv, - conv_size=config.conv_size, - norm_first=config.norm_first, - norm_eps=config.norm_eps, - layer_idx=layer_idx - ) - - if not config.norm_first: - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = GatedDeltaNetMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - 
hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class GatedDeltaNetVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = GatedDeltaNetVisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class GatedDeltaNetForImageClassification(GatedDeltaNetVisionPreTrainedModel): - config_class = GatedDeltaNetVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - GatedDeltaNetBlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/gla/__init__.py b/fla/vision_models/gla/__init__.py deleted file mode 100644 index dc7d6e93c..000000000 --- a/fla/vision_models/gla/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.gla.configuration_gla import GLAVisionConfig -from fla.vision_models.gla.modeling_gla import 
GLAForImageClassification - -AutoConfig.register(GLAVisionConfig.model_type, GLAVisionConfig) -AutoModelForImageClassification.register(GLAVisionConfig, GLAForImageClassification) - -__all__ = [ - 'GLAVisionConfig', - 'GLAForImageClassification' -] diff --git a/fla/vision_models/gla/configuration_gla.py b/fla/vision_models/gla/configuration_gla.py deleted file mode 100644 index af52bbe6f..000000000 --- a/fla/vision_models/gla/configuration_gla.py +++ /dev/null @@ -1,101 +0,0 @@ -from typing import Dict, Optional - -from transformers.configuration_utils import PretrainedConfig - -class GLAVisionConfig(PretrainedConfig): - - model_type = 'gla_vision' - - def __init__( - self, - # GLA core parameters - hidden_size: int = 2048, - expand_k: int = 0.5, - expand_v: int = 1, - num_hidden_layers: int = 24, - num_heads: int = 4, - num_kv_heads: Optional[int] = None, - feature_map: Optional[str] = None, - attn_mode: str = "chunk", - use_short_conv: bool = False, - conv_size: int = 4, - use_output_gate: bool = True, - clamp_min: Optional[float] = None, - hidden_act: str = "swish", - max_position_embeddings: int = 2048, - elementwise_affine: Optional[bool] = True, - norm_eps: float = 1e-6, - use_gk: bool = True, - use_gv: bool = False, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_norm: bool = True, - fuse_cross_entropy: bool = True, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - # Initialize DeltaNet core parameters - self.hidden_size = hidden_size - self.expand_k = expand_k - self.expand_v = expand_v - self.num_hidden_layers = num_hidden_layers - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.feature_map = feature_map - self.attn_mode = attn_mode - self.use_short_conv = use_short_conv - self.conv_size = conv_size - self.use_output_gate = use_output_gate - self.clamp_min = clamp_min - self.hidden_act = hidden_act - self.max_position_embeddings = max_position_embeddings - self.elementwise_affine = elementwise_affine - self.norm_eps = norm_eps - self.use_gk = use_gk - self.use_gv = use_gv - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_norm = fuse_norm - self.fuse_cross_entropy = fuse_cross_entropy - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - - if mlp_dim is None: 
- self.mlp_dim = 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) diff --git a/fla/vision_models/gla/modeling_gla.py b/fla/vision_models/gla/modeling_gla.py deleted file mode 100644 index 311e000a0..000000000 --- a/fla/vision_models/gla/modeling_gla.py +++ /dev/null @@ -1,205 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_gla import GLAVisionConfig -from fla.layers.gla import GatedLinearAttention -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class GLAMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class GLABlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = GatedLinearAttention( - mode=config.attn_mode, - hidden_size=config.hidden_size, - expand_k=config.expand_k, - expand_v=config.expand_v, - num_heads=config.num_heads, - num_kv_heads=config.num_kv_heads, - feature_map=config.feature_map, - use_short_conv=config.use_short_conv, - conv_size=config.conv_size, - use_output_gate=config.use_output_gate, - gate_fn=config.hidden_act, - elementwise_affine=config.elementwise_affine, - norm_eps=config.norm_eps, - clamp_min=config.clamp_min, - fuse_norm=config.fuse_norm, - layer_idx=layer_idx - ) - - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = GLAMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual 
+ hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class GLAVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = GLAVisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class GLAForImageClassification(GLAVisionPreTrainedModel): - config_class = GLAVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - GLABlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/gsa/__init__.py b/fla/vision_models/gsa/__init__.py deleted file mode 100644 index 3da164504..000000000 --- a/fla/vision_models/gsa/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import 
AutoConfig, AutoModelForImageClassification - -from fla.vision_models.gsa.configuration_gsa import GSAVisionConfig -from fla.vision_models.gsa.modeling_gsa import GSAForImageClassification - -AutoConfig.register(GSAVisionConfig.model_type, GSAVisionConfig) -AutoModelForImageClassification.register(GSAVisionConfig, GSAForImageClassification) - -__all__ = [ - 'GSAVisionConfig', - 'GSAForImageClassification' -] diff --git a/fla/vision_models/gsa/configuration_gsa.py b/fla/vision_models/gsa/configuration_gsa.py deleted file mode 100644 index de4bbcb8d..000000000 --- a/fla/vision_models/gsa/configuration_gsa.py +++ /dev/null @@ -1,107 +0,0 @@ -from typing import Dict, Optional - -from transformers.configuration_utils import PretrainedConfig - - -class GSAVisionConfig(PretrainedConfig): - - model_type = 'gsa_vision' - - def __init__( - self, - # GSA core parameters - hidden_size: int = 2048, - gate_logit_normalizer: Optional[int] = 8, - clamp_min: Optional[float] = None, - clamp_max: Optional[float] = None, - num_hidden_layers: int = 24, - num_heads: int = 4, - num_kv_heads: Optional[int] = None, - num_slots: Optional[int] = 64, - use_short_conv: bool = False, - conv_size: int = 4, - exapnd_k: float = 1, - exapnd_v: float = 1, - feature_map: str = 'swish', - use_output_gate: bool = False, - use_norm: bool = True, - max_position_embeddings: int = 2048, - hidden_act: str = "swish", - elementwise_affine: Optional[bool] = True, - norm_first: bool = True, - norm_eps: float = 1e-6, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_norm: bool = True, - fuse_cross_entropy: bool = True, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - qkv_bias: bool = True, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - self.hidden_size = hidden_size - self.gate_logit_normalizer = gate_logit_normalizer - self.clamp_min = clamp_min - self.clamp_max = clamp_max - self.num_hidden_layers = num_hidden_layers - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.num_slots = num_slots - self.use_short_conv = use_short_conv - self.conv_size = conv_size - self.expand_k = exapnd_k - self.expand_v = exapnd_v - self.feature_map = feature_map - self.use_output_gate = use_output_gate - self.use_norm = use_norm - self.max_position_embeddings = max_position_embeddings - self.hidden_act = hidden_act - self.elementwise_affine = elementwise_affine - self.norm_first = norm_first - self.norm_eps = norm_eps - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_cross_entropy = fuse_cross_entropy - self.fuse_norm = fuse_norm - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.qkv_bias = qkv_bias - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise 
ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) diff --git a/fla/vision_models/gsa/modeling_gsa.py b/fla/vision_models/gsa/modeling_gsa.py deleted file mode 100644 index 856eea93d..000000000 --- a/fla/vision_models/gsa/modeling_gsa.py +++ /dev/null @@ -1,209 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_gsa import GSAVisionConfig -from fla.layers.gsa import GatedSlotAttention -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class GSAMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class GSABlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - if not config.norm_first: - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = GatedSlotAttention( - hidden_size=config.hidden_size, - expand_k=config.expand_k, - expand_v=config.expand_v, - num_heads=config.num_heads, - num_kv_heads=config.num_kv_heads, - num_slots=config.num_slots, - use_short_conv=config.use_short_conv, - conv_size=config.conv_size, - feature_map=config.feature_map, - use_output_gate=config.use_output_gate, - use_norm=config.use_norm, - gate_fn=config.hidden_act, - gate_logit_normalizer=config.gate_logit_normalizer, - elementwise_affine=config.elementwise_affine, - norm_first=config.norm_first, - norm_eps=config.norm_eps, - fuse_norm=config.fuse_norm, - layer_idx=layer_idx - ) - - if not config.norm_first: - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = GSAMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # 
Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class GSAVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = GSAVisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class GSAForImageClassification(GSAVisionPreTrainedModel): - config_class = GSAVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - GSABlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - 
else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/hgrn/__init__.py b/fla/vision_models/hgrn/__init__.py deleted file mode 100644 index e9ab00ae0..000000000 --- a/fla/vision_models/hgrn/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.hgrn.configuration_hgrn import HGRNVisionConfig -from fla.vision_models.hgrn.modeling_hgrn import HGRNForImageClassification - -AutoConfig.register(HGRNVisionConfig.model_type, HGRNVisionConfig) -AutoModelForImageClassification.register(HGRNVisionConfig, HGRNForImageClassification) - -__all__ = [ - 'HGRNVisionConfig', - 'HGRNForImageClassification' -] diff --git a/fla/vision_models/hgrn/configuration_hgrn.py b/fla/vision_models/hgrn/configuration_hgrn.py deleted file mode 100644 index e9724239b..000000000 --- a/fla/vision_models/hgrn/configuration_hgrn.py +++ /dev/null @@ -1,86 +0,0 @@ -from typing import Dict, Optional - -from transformers.configuration_utils import PretrainedConfig - - -class HGRNVisionConfig(PretrainedConfig): - - model_type = 'hgrn_vision' - - def __init__( - self, - # HGRN core parameters - attn_mode: str = "chunk", - hidden_size: int = 2048, - num_hidden_layers: int = 24, - expand_ratio: Optional[int] = 1, - use_short_conv: bool = False, - conv_size: int = 4, - use_lower_bound: bool = True, - max_position_embeddings: int = 2048, - hidden_act: str = "swish", - elementwise_affine: Optional[bool] = True, - norm_eps: float = 1e-6, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_cross_entropy: bool = True, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - qkv_bias: bool = True, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - # Initialize HGRN core parameters - self.attn_mode = attn_mode - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.expand_ratio = expand_ratio - self.use_short_conv = use_short_conv - self.conv_size = conv_size - self.use_lower_bound = use_lower_bound - self.max_position_embeddings = max_position_embeddings - self.elementwise_affine = elementwise_affine - self.norm_eps = norm_eps - self.hidden_act = hidden_act - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_cross_entropy = fuse_cross_entropy - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.qkv_bias = qkv_bias - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise 
ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) diff --git a/fla/vision_models/hgrn/modeling_hgrn.py b/fla/vision_models/hgrn/modeling_hgrn.py deleted file mode 100644 index 35d6e21bf..000000000 --- a/fla/vision_models/hgrn/modeling_hgrn.py +++ /dev/null @@ -1,197 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_hgrn import HGRNVisionConfig -from fla.layers.hgrn import HGRNAttention -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class HGRNMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class HGRNBlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = HGRNAttention( - mode=config.attn_mode, - hidden_size=config.hidden_size, - expand_ratio=config.expand_ratio, - use_short_conv=config.use_short_conv, - conv_size=config.conv_size, - elementwise_affine=config.elementwise_affine, - norm_eps=config.norm_eps, - layer_idx=layer_idx - ) - - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = HGRNMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - 
**kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class HGRNVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = HGRNVisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class HGRNForImageClassification(HGRNVisionPreTrainedModel): - config_class = HGRNVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - HGRNBlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/hgrn2/__init__.py 
b/fla/vision_models/hgrn2/__init__.py deleted file mode 100644 index 69a2c9c55..000000000 --- a/fla/vision_models/hgrn2/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.hgrn2.configuration_hgrn2 import HGRN2VisionConfig -from fla.vision_models.hgrn2.modeling_hgrn2 import HGRN2ForImageClassification - -AutoConfig.register(HGRN2VisionConfig.model_type, HGRN2VisionConfig) -AutoModelForImageClassification.register(HGRN2VisionConfig, HGRN2ForImageClassification) - -__all__ = [ - 'HGRN2VisionConfig', - 'HGRN2ForImageClassification' -] diff --git a/fla/vision_models/hgrn2/configuration_hgrn2.py b/fla/vision_models/hgrn2/configuration_hgrn2.py deleted file mode 100644 index ef6ffc83b..000000000 --- a/fla/vision_models/hgrn2/configuration_hgrn2.py +++ /dev/null @@ -1,89 +0,0 @@ -from typing import Dict, Optional - -from transformers.configuration_utils import PretrainedConfig - - -class HGRN2VisionConfig(PretrainedConfig): - - model_type = 'hgrn2_vision' - - def __init__( - self, - # HGRN2 core parameters - hidden_size: int = 2048, - num_hidden_layers: int = 24, - attn_mode: str = "chunk", - num_heads: Optional[int] = None, - expand_ratio: Optional[int] = 128, - use_short_conv: bool = False, - conv_size: int = 4, - use_lower_bound: bool = True, - hidden_act: str = "swish", - max_position_embeddings: int = 2048, - elementwise_affine: Optional[bool] = True, - norm_eps: float = 1e-6, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_cross_entropy: bool = True, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - qkv_bias: bool = True, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - # Initialize HGRN2 core parameters - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.attn_mode = attn_mode - self.num_heads = num_heads - self.expand_ratio = expand_ratio - self.use_short_conv = use_short_conv - self.conv_size = conv_size - self.use_lower_bound = use_lower_bound - self.max_position_embeddings = max_position_embeddings - self.hidden_act = hidden_act - self.elementwise_affine = elementwise_affine - self.norm_eps = norm_eps - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_cross_entropy = fuse_cross_entropy - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.qkv_bias = qkv_bias - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = 
attn.get('window_size', None) - - self.attn = attn - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/vision_models/hgrn2/modeling_hgrn2.py b/fla/vision_models/hgrn2/modeling_hgrn2.py deleted file mode 100644 index cbae1a64a..000000000 --- a/fla/vision_models/hgrn2/modeling_hgrn2.py +++ /dev/null @@ -1,198 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_hgrn2 import HGRN2VisionConfig -from fla.layers.hgrn2 import HGRN2Attention -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class HGRN2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class HGRN2Block(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = HGRN2Attention( - mode=config.attn_mode, - hidden_size=config.hidden_size, - num_heads=config.num_heads, - expand_ratio=config.expand_ratio, - use_short_conv=config.use_short_conv, - conv_size=config.conv_size, - elementwise_affine=config.elementwise_affine, - norm_eps=config.norm_eps, - layer_idx=layer_idx - ) - - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = HGRN2MLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization 
for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class HGRN2VisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = HGRN2VisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class HGRN2ForImageClassification(HGRN2VisionPreTrainedModel): - config_class = HGRN2VisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - HGRN2Block(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/linear_attn/__init__.py b/fla/vision_models/linear_attn/__init__.py deleted file mode 100644 index d56bc5e04..000000000 --- a/fla/vision_models/linear_attn/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoConfig, 
AutoModelForImageClassification - -from fla.vision_models.linear_attn.configuration_linear_attn import LinearAttentionVisionConfig -from fla.vision_models.linear_attn.modeling_linear_attn import LinearAttentionForImageClassification - -AutoConfig.register(LinearAttentionVisionConfig.model_type, LinearAttentionVisionConfig) -AutoModelForImageClassification.register(LinearAttentionVisionConfig, LinearAttentionForImageClassification) - -__all__ = [ - 'LinearAttentionVisionConfig', - 'LinearAttentionForImageClassification' -] diff --git a/fla/vision_models/linear_attn/configuration_linear_attn.py b/fla/vision_models/linear_attn/configuration_linear_attn.py deleted file mode 100644 index 8aa0d2ef0..000000000 --- a/fla/vision_models/linear_attn/configuration_linear_attn.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import Dict, Optional - -from transformers.configuration_utils import PretrainedConfig - - -class LinearAttentionVisionConfig(PretrainedConfig): - - model_type = 'linear_attn_vision' - - def __init__( - self, - # LinearAttention core parameters - attn_mode: str = "fused_chunk", - hidden_size: int = 2048, - expand_k: int = 1, - expand_v: int = 1, - num_hidden_layers: int = 24, - num_heads: int = 4, - num_kv_heads: Optional[int] = None, - feature_map: str = "elementwise_product", - tie_feature_map_qk: bool = False, - norm_q: bool = False, - norm_k: bool = False, - norm_feature_map: bool = False, - hidden_act: str = "swish", - max_position_embeddings: int = 2048, - elementwise_affine: Optional[bool] = True, - norm_eps: float = 1e-6, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_cross_entropy: bool = True, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - qkv_bias: bool = True, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - # Initialize LinearAttention core parameters - self.attn_mode = attn_mode - self.hidden_size = hidden_size - self.expand_k = expand_k - self.expand_v = expand_v - self.num_hidden_layers = num_hidden_layers - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.feature_map = feature_map - self.tie_feature_map_qk = tie_feature_map_qk - self.norm_q = norm_q - self.norm_k = norm_k - self.norm_feature_map = norm_feature_map - self.max_position_embeddings = max_position_embeddings - self.elementwise_affine = elementwise_affine - self.norm_eps = norm_eps - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_cross_entropy = fuse_cross_entropy - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.qkv_bias = qkv_bias - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise 
ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/vision_models/linear_attn/modeling_linear_attn.py b/fla/vision_models/linear_attn/modeling_linear_attn.py deleted file mode 100644 index f0889a493..000000000 --- a/fla/vision_models/linear_attn/modeling_linear_attn.py +++ /dev/null @@ -1,197 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_linear_attn import LinearAttentionVisionConfig -from fla.layers.linear_attn import LinearAttention -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class LinearAttentionMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class LinearAttentionBlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = LinearAttention( - mode=config.attn_mode, - hidden_size=config.hidden_size, - expand_k=config.expand_k, - expand_v=config.expand_v, - num_heads=config.num_heads, - num_kv_heads=config.num_kv_heads, - feature_map=config.feature_map, - tie_feature_map_qk=config.tie_feature_map_qk, - norm_q=config.norm_q, - norm_k=config.norm_k, - do_feature_map_norm=config.norm_feature_map, - elementwise_affine=config.elementwise_affine, - norm_eps=config.norm_eps, - layer_idx=layer_idx - ) - - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = LinearAttentionMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, 
self.scan_type) - - hidden_states = self.attn(hidden_states) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - return outputs - -class LinearAttentionVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = LinearAttentionVisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class LinearAttentionForImageClassification(LinearAttentionVisionPreTrainedModel): - config_class = LinearAttentionVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - LinearAttentionBlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff 
--git a/fla/vision_models/retnet/__init__.py b/fla/vision_models/retnet/__init__.py deleted file mode 100644 index 4a32b420f..000000000 --- a/fla/vision_models/retnet/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.retnet.configuration_retnet import RetNetVisionConfig -from fla.vision_models.retnet.modeling_retnet import RetNetForImageClassification - -AutoConfig.register(RetNetVisionConfig.model_type, RetNetVisionConfig) -AutoModelForImageClassification.register(RetNetVisionConfig, RetNetForImageClassification) - -__all__ = [ - 'RetNetVisionConfig', - 'RetNetForImageClassification' -] diff --git a/fla/vision_models/retnet/configuration_retnet.py b/fla/vision_models/retnet/configuration_retnet.py deleted file mode 100644 index 53df13698..000000000 --- a/fla/vision_models/retnet/configuration_retnet.py +++ /dev/null @@ -1,101 +0,0 @@ -from __future__ import annotations - -from typing import Dict, Optional - -from transformers.configuration_utils import PretrainedConfig - - -class RetNetVisionConfig(PretrainedConfig): - - model_type = 'retnet_vision' - - def __init__( - self, - # RetNet core parameters - attn_mode: str = "chunk", - hidden_size: int = 2048, - expand_k: int = 1, - expand_v: int = 2, - num_hidden_layers: int = 24, - num_heads: int = 8, - num_kv_heads: Optional[int] = None, - feature_map: Optional[str] = None, - hidden_act: str = "swish", - use_short_conv: bool = False, - conv_size: int = 4, - use_output_gate: bool = True, - max_position_embeddings: int = 2048, - elementwise_affine: Optional[bool] = True, - norm_eps: float = 1e-6, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_norm: bool = True, - fuse_cross_entropy: bool = True, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - qkv_bias: bool = True, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ) -> RetNetVisionConfig: - # Initialize RetNet core parameters - self.attn_mode = attn_mode - self.hidden_size = hidden_size - self.expand_k = expand_k - self.expand_v = expand_v - self.num_hidden_layers = num_hidden_layers - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.feature_map = feature_map - self.hidden_act = hidden_act - self.use_short_conv = use_short_conv - self.conv_size = conv_size - self.use_output_gate = use_output_gate - self.hidden_act = hidden_act - self.max_position_embeddings = max_position_embeddings - self.elementwise_affine = elementwise_affine - self.norm_eps = norm_eps - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_norm = fuse_norm - self.fuse_cross_entropy = fuse_cross_entropy - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.qkv_bias = qkv_bias - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - if attn is not None: - if not isinstance(attn, Dict): - raise 
ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/vision_models/retnet/modeling_retnet.py b/fla/vision_models/retnet/modeling_retnet.py deleted file mode 100644 index 961ea7c71..000000000 --- a/fla/vision_models/retnet/modeling_retnet.py +++ /dev/null @@ -1,202 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_retnet import RetNetVisionConfig -from fla.layers.multiscale_retention import MultiScaleRetention -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class RetNetMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class RetNetBlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = MultiScaleRetention( - mode=config.attn_mode, - hidden_size=config.hidden_size, - expand_k=config.expand_k, - expand_v=config.expand_v, - num_heads=config.num_heads, - num_kv_heads=config.num_kv_heads, - feature_map=config.feature_map, - use_output_gate=config.use_output_gate, - gate_fn=config.hidden_act, - elementwise_affine=config.elementwise_affine, - norm_eps=config.norm_eps, - fuse_norm=config.fuse_norm, - layer_idx=layer_idx - ) - - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = RetNetMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = 
self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class RetNetVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = RetNetVisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class RetNetForImageClassification(RetNetVisionPreTrainedModel): - config_class = RetNetVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - RetNetBlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = 
loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/rwkv6/__init__.py b/fla/vision_models/rwkv6/__init__.py deleted file mode 100644 index 2df666ac4..000000000 --- a/fla/vision_models/rwkv6/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.rwkv6.configuration_rwkv6 import RWKV6VisionConfig -from fla.vision_models.rwkv6.modeling_rwkv6 import RWKV6ForImageClassification - -AutoConfig.register(RWKV6VisionConfig.model_type, RWKV6VisionConfig) -AutoModelForImageClassification.register(RWKV6VisionConfig, RWKV6ForImageClassification) - -__all__ = [ - 'RWKV6VisionConfig', - 'RWKV6ForImageClassification' -] diff --git a/fla/vision_models/rwkv6/configuration_rwkv6.py b/fla/vision_models/rwkv6/configuration_rwkv6.py deleted file mode 100644 index 9e4d54cd0..000000000 --- a/fla/vision_models/rwkv6/configuration_rwkv6.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import Dict, Optional - -from transformers.configuration_utils import PretrainedConfig - - -class RWKV6VisionConfig(PretrainedConfig): - - model_type = 'rwkv6_vision' - - def __init__( - self, - # RWKV6 core parameters - attn_mode: str = "chunk", - hidden_size: int = 2048, - expand_k: int = 0.5, - expand_v: int = 1, - num_hidden_layers: int = 24, - num_heads: int = 4, - proj_low_rank_dim: int = 32, - gate_low_rank_dim: int = 64, - hidden_act: str = "sqrelu", - max_position_embeddings: int = 2048, - norm_first: bool = True, - norm_bias: bool = True, - norm_eps: float = 1e-5, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_norm: bool = True, - fuse_cross_entropy: bool = True, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - qkv_bias: bool = True, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - # Initialize RWKV6 core parameters - self.attn_mode = attn_mode - self.hidden_size = hidden_size - self.expand_k = expand_k - self.expand_v = expand_v - self.norm_first = norm_first - self.num_hidden_layers = num_hidden_layers - self.num_heads = num_heads - self.proj_low_rank_dim = proj_low_rank_dim - self.gate_low_rank_dim = gate_low_rank_dim - self.hidden_act = hidden_act - self.max_position_embeddings = max_position_embeddings - self.norm_bias = norm_bias - self.norm_eps = norm_eps - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_norm = fuse_norm - self.fuse_cross_entropy = fuse_cross_entropy - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.qkv_bias = qkv_bias - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - if attn is not None: - if not isinstance(attn, Dict): - 
raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/vision_models/rwkv6/modeling_rwkv6.py b/fla/vision_models/rwkv6/modeling_rwkv6.py deleted file mode 100644 index 45c4df011..000000000 --- a/fla/vision_models/rwkv6/modeling_rwkv6.py +++ /dev/null @@ -1,199 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from fla.layers.rwkv6 import RWKV6Attention -from .configuration_rwkv6 import RWKV6VisionConfig -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class RWKV6MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class RWKV6Block(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = RWKV6Attention( - mode=config.attn_mode, - hidden_size=config.hidden_size, - expand_k=config.expand_k, - expand_v=config.expand_v, - num_heads=config.num_heads, - proj_low_rank_dim=config.proj_low_rank_dim, - gate_low_rank_dim=config.gate_low_rank_dim, - norm_eps=config.norm_eps, - fuse_norm=config.fuse_norm, - layer_idx=layer_idx - ) - - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = RWKV6MLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - 
hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class RWKV6VisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = RWKV6VisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class RWKV6ForImageClassification(RWKV6VisionPreTrainedModel): - config_class = RWKV6VisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - RWKV6Block(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if 
loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/transformer/__init__.py b/fla/vision_models/transformer/__init__.py deleted file mode 100644 index 25d4e9d2b..000000000 --- a/fla/vision_models/transformer/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.transformer.configuration_transformer import TransformerVisionConfig -from fla.vision_models.transformer.modeling_transformer import TransformerForImageClassification - -AutoConfig.register(TransformerVisionConfig.model_type, TransformerVisionConfig) -AutoModelForImageClassification.register(TransformerVisionConfig, TransformerForImageClassification) - -__all__ = [ - 'TransformerVisionConfig', - 'TransformerForImageClassification' -] diff --git a/fla/vision_models/transformer/configuration_transformer.py b/fla/vision_models/transformer/configuration_transformer.py deleted file mode 100644 index cc8246270..000000000 --- a/fla/vision_models/transformer/configuration_transformer.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import Optional - -from transformers.configuration_utils import PretrainedConfig - - -class TransformerVisionConfig(PretrainedConfig): - - model_type = 'transformer_vision' - - def __init__( - self, - # Transformer core parameters - hidden_size: int = 2048, - num_hidden_layers: int = 24, - num_heads: int = 32, - num_kv_heads: int = None, - window_size: Optional[int] = None, - rope_theta: Optional[float] = 10000., - max_position_embeddings: int = 2048, - hidden_act: str = "swish", - initializer_range: float = 0.02, - elementwise_affine: Optional[bool] = True, - norm_first: bool = False, - norm_eps: float = 1e-6, - use_cache: bool = True, - attention_bias: bool = False, - fuse_norm: bool = True, - fuse_cross_entropy: bool = True, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - qkv_bias: bool = True, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - # Initialize transformer core parameters - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.window_size = window_size - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - self.hidden_act = hidden_act - - self.initializer_range = initializer_range - self.elementwise_affine = elementwise_affine - self.norm_first = norm_first - self.norm_eps = norm_eps - self.use_cache = use_cache - self.attention_bias = attention_bias - self.fuse_cross_entropy = fuse_cross_entropy - self.fuse_norm = fuse_norm - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.qkv_bias = qkv_bias - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * 
hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/vision_models/transformer/modeling_transformer.py b/fla/vision_models/transformer/modeling_transformer.py deleted file mode 100644 index 6441293a5..000000000 --- a/fla/vision_models/transformer/modeling_transformer.py +++ /dev/null @@ -1,190 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_transformer import TransformerVisionConfig -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class TransformerMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class TransformerBlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - if not config.norm_first: - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.num_heads, - num_kv_heads=config.num_kv_heads, - window_size=config.window_size, - rope_theta=config.rope_theta, - max_position_embeddings=config.max_position_embeddings, - norm_first=config.norm_first, - norm_eps=config.norm_eps, - layer_idx=layer_idx - ) - - - if not config.norm_first: - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = TransformerMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class TransformerVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted 
from huggingface/transformers vit implementation - config_class = TransformerVisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class TransformerForImageClassification(TransformerVisionPreTrainedModel): - config_class = TransformerVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - TransformerBlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/utils.py b/fla/vision_models/utils.py deleted file mode 100644 index 246dcf931..000000000 --- a/fla/vision_models/utils.py +++ /dev/null @@ -1,480 +0,0 @@ -""" -Vision model utilities adapted from huggingface/transformers ViT implementation. -""" - -import collections.abc -import torch -from torch import nn -from typing import Optional -from transformers.utils import torch_int -import triton -import triton.language as tl -import einops -import math - -""" -Basic component of a vision model, like the patch embeddings, image embeddings, and pooler. 
-""" - -class PatchEmbeddings(nn.Module): - """ - Convert image into patch embeddings. - Adapted from huggingface/transformers ViT implementation. - """ - def __init__(self, config): - super().__init__() - image_size, patch_size = config.image_size, config.patch_size - num_channels, hidden_size = config.num_channels, config.hidden_size - - image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) - patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_patches = num_patches - - self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) - - def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: - batch_size, num_channels, height, width = pixel_values.shape - if num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - f" Expected {self.num_channels} but got {num_channels}." - ) - if not interpolate_pos_encoding: - if height != self.image_size[0] or width != self.image_size[1]: - raise ValueError( - f"Input image size ({height}*{width}) doesn't match model" - f" ({self.image_size[0]}*{self.image_size[1]})." - ) - embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) - return embeddings - -class ImageEmbeddings(nn.Module): - """ - Construct the position and patch embeddings. - Adapted from huggingface/transformers ViT implementation. No cls token is used in this implementation. - """ - def __init__(self, config, use_mask_token: bool = False) -> None: - super().__init__() - - self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None - self.patch_embeddings = PatchEmbeddings(config) - num_patches = self.patch_embeddings.num_patches - self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size)) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.patch_size = config.patch_size - self.config = config - - def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: - """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution - images. This method is also adapted to support torch.jit tracing. 
- - Adapted from: - - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and - - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 - """ - - num_patches = embeddings.shape[1] - num_positions = self.position_embeddings.shape[1] - - if not torch.jit.is_tracing() and num_patches == num_positions and height == width: - return self.position_embeddings - - dim = embeddings.shape[-1] - - new_height = height // self.patch_size - new_width = width // self.patch_size - - sqrt_num_positions = torch_int(num_positions**0.5) - pos_embed = self.position_embeddings.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) - - pos_embed = pos_embed.permute(0, 3, 1, 2) - - pos_embed = nn.functional.interpolate( - pos_embed, - size=(new_height, new_width), - mode="bicubic", - align_corners=False, - ) - - pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - - return pos_embed - - def forward( - self, - pixel_values: torch.Tensor, - bool_masked_pos: Optional[torch.BoolTensor] = None, - interpolate_pos_encoding: bool = False, - ) -> torch.Tensor: - batch_size, num_channels, height, width = pixel_values.shape - embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - - if bool_masked_pos is not None: - seq_length = embeddings.shape[1] - mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) - # replace the masked visual tokens by mask_tokens - mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) - embeddings = embeddings * (1.0 - mask) + mask_tokens * mask - - # add positional encoding to each token - if interpolate_pos_encoding: - embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) - else: - embeddings = embeddings + self.position_embeddings - - embeddings = self.dropout(embeddings) - - return embeddings - -class Pooler(nn.Module): - """ - Pool the output of a vision model by taking the mean of all tokens. - Adapted from huggingface/transformers ViT implementation. - """ - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - pooled_output = hidden_states.mean(dim=1) # always use mean pooling - pooled_output = self.dense(pooled_output) - pooled_output = self.activation(pooled_output) - return pooled_output - -""" -Cross Scan and Cross Merge implemented in Triton (only). 
taken from https://github.com/MzeroMiko/VMamba/blob/main/classification/models/csm_triton.py -""" - -@triton.jit -def triton_cross_scan_flex( - x: tl.tensor, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) - y: tl.tensor, # (B, 4, C, H, W) | (B, H, W, 4, C) - x_layout: tl.constexpr, - y_layout: tl.constexpr, - operation: tl.constexpr, - onebyone: tl.constexpr, - scans: tl.constexpr, - BC: tl.constexpr, - BH: tl.constexpr, - BW: tl.constexpr, - DC: tl.constexpr, - DH: tl.constexpr, - DW: tl.constexpr, - NH: tl.constexpr, - NW: tl.constexpr, -): - # x_layout = 0 - # y_layout = 1 # 0 BCHW, 1 BHWC - # operation = 0 # 0 scan, 1 merge - # onebyone = 0 # 0 false, 1 true - # scans = 0 # 0 cross scan, 1 unidirectional, 2 bidirectional - - i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_h, i_w = (i_hw // NW), (i_hw % NW) - _mask_h = (i_h * BH + tl.arange(0, BH)) < DH - _mask_w = (i_w * BW + tl.arange(0, BW)) < DW - _mask_hw = _mask_h[:, None] & _mask_w[None, :] - _for_C = min(DC - i_c * BC, BC) - - pos_h = (i_h * BH + tl.arange(0, BH)[:, None]) - pos_w = (i_w * BW + tl.arange(0, BW)[None, :]) - neg_h = (DH - i_h * BH - 1 - tl.arange(0, BH)[:, None]) - neg_w = (DW - i_w * BW - 1 - tl.arange(0, BW)[None, :]) - if scans == 0: - # none; trans; flip; trans + flip; - HWRoute0 = pos_h * DW + pos_w - HWRoute1 = pos_w * DH + pos_h # trans - HWRoute2 = neg_h * DW + neg_w # flip - HWRoute3 = neg_w * DH + neg_h # trans + flip - elif scans == 1: - # none; none; none; none; - HWRoute0 = pos_h * DW + pos_w - HWRoute1 = HWRoute0 - HWRoute2 = HWRoute0 - HWRoute3 = HWRoute0 - elif scans == 2: - # none; none; flip; flip; - HWRoute0 = pos_h * DW + pos_w - HWRoute1 = HWRoute0 - HWRoute2 = neg_h * DW + neg_w # flip - HWRoute3 = HWRoute2 - - _tmp1 = DC * DH * DW - - y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC) - if y_layout == 0: - p_y1 = y_ptr_base + HWRoute0 - p_y2 = y_ptr_base + _tmp1 + HWRoute1 - p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2 - p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3 - else: - p_y1 = y_ptr_base + HWRoute0 * 4 * DC - p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC - p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC - p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC - - if onebyone == 0: - x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) - if x_layout == 0: - p_x = x_ptr_base + HWRoute0 - else: - p_x = x_ptr_base + HWRoute0 * DC - - if operation == 0: - for idxc in range(_for_C): - _idx_x = idxc * DH * DW if x_layout == 0 else idxc - _idx_y = idxc * DH * DW if y_layout == 0 else idxc - _x = tl.load(p_x + _idx_x, mask=_mask_hw) - tl.store(p_y1 + _idx_y, _x, mask=_mask_hw) - tl.store(p_y2 + _idx_y, _x, mask=_mask_hw) - tl.store(p_y3 + _idx_y, _x, mask=_mask_hw) - tl.store(p_y4 + _idx_y, _x, mask=_mask_hw) - elif operation == 1: - for idxc in range(_for_C): - _idx_x = idxc * DH * DW if x_layout == 0 else idxc - _idx_y = idxc * DH * DW if y_layout == 0 else idxc - _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw) - _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw) - _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw) - _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw) - tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw) - - else: - x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) - if x_layout == 0: - p_x1 = x_ptr_base + HWRoute0 - p_x2 = p_x1 + _tmp1 - p_x3 = p_x2 + _tmp1 - p_x4 = p_x3 + _tmp1 - else: - p_x1 = x_ptr_base + HWRoute0 * 4 * DC - p_x2 = p_x1 + DC - p_x3 = p_x2 + DC - 
p_x4 = p_x3 + DC - - if operation == 0: - for idxc in range(_for_C): - _idx_x = idxc * DH * DW if x_layout == 0 else idxc - _idx_y = idxc * DH * DW if y_layout == 0 else idxc - tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw) - tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw) - tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw) - tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw) - else: - for idxc in range(_for_C): - _idx_x = idxc * DH * DW if x_layout == 0 else idxc - _idx_y = idxc * DH * DW if y_layout == 0 else idxc - tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw) - tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw) - tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw) - tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw) - - -class CrossScanTritonF(torch.autograd.Function): - @staticmethod - def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): - if one_by_one: - if in_channel_first: - B, _, C, H, W = x.shape - else: - B, H, W, _, C = x.shape - else: - if in_channel_first: - B, C, H, W = x.shape - else: - B, H, W, C = x.shape - B, C, H, W = int(B), int(C), int(H), int(W) - BC, BH, BW = 1, 32, 32 - NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) - - ctx.in_channel_first = in_channel_first - ctx.out_channel_first = out_channel_first - ctx.one_by_one = one_by_one - ctx.scans = scans - ctx.shape = (B, C, H, W) - ctx.triton_shape = (BC, BH, BW, NC, NH, NW) - - y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C)) - triton_cross_scan_flex[(NH * NW, NC, B)]( - x.contiguous(), y, - (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans, - BC, BH, BW, C, H, W, NH, NW - ) - return y - - @staticmethod - def backward(ctx, y: torch.Tensor): - in_channel_first = ctx.in_channel_first - out_channel_first = ctx.out_channel_first - one_by_one = ctx.one_by_one - scans = ctx.scans - B, C, H, W = ctx.shape - BC, BH, BW, NC, NH, NW = ctx.triton_shape - if one_by_one: - x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C)) - else: - x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C)) - - triton_cross_scan_flex[(NH * NW, NC, B)]( - x, y.contiguous(), - (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans, - BC, BH, BW, C, H, W, NH, NW - ) - return x, None, None, None, None - - -class CrossMergeTritonF(torch.autograd.Function): - @staticmethod - def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): - if out_channel_first: - B, _, C, H, W = y.shape - else: - B, H, W, _, C = y.shape - B, C, H, W = int(B), int(C), int(H), int(W) - BC, BH, BW = 1, 32, 32 - NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) - ctx.in_channel_first = in_channel_first - ctx.out_channel_first = out_channel_first - ctx.one_by_one = one_by_one - ctx.scans = scans - ctx.shape = (B, C, H, W) - ctx.triton_shape = (BC, BH, BW, NC, NH, NW) - if one_by_one: - x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C)) - else: - x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C)) - triton_cross_scan_flex[(NH * NW, NC, B)]( - x, y.contiguous(), - (0 if in_channel_first else 1), (0 
if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans, - BC, BH, BW, C, H, W, NH, NW - ) - return x - - @staticmethod - def backward(ctx, x: torch.Tensor): - in_channel_first = ctx.in_channel_first - out_channel_first = ctx.out_channel_first - one_by_one = ctx.one_by_one - scans = ctx.scans - B, C, H, W = ctx.shape - BC, BH, BW, NC, NH, NW = ctx.triton_shape - y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C)) - triton_cross_scan_flex[(NH * NW, NC, B)]( - x.contiguous(), y, - (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans, - BC, BH, BW, C, H, W, NH, NW - ) - return y, None, None, None, None, None - - -# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) -def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False): - # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) - # y: (B, 4, C, L) | (B, L, 4, C) - # scans: 0: cross scan; 1 unidirectional; 2: bidirectional; - assert x.is_cuda - CSF = CrossScanTritonF - with torch.cuda.device(x.device): - return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans) - - -# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) -def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False): - # y: (B, 4, C, L) | (B, L, 4, C) - # x: (B, C, H * W) | (B, H * W, C) | (B, 4, C, H * W) | (B, H * W, 4, C) - # scans: 0: cross scan; 1 unidirectional; 2: bidirectional; - assert y.is_cuda - CMF = CrossMergeTritonF - with torch.cuda.device(y.device): - return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans) - -def prepare_hidden_states_for_cross_scan(hidden_states: torch.Tensor, scan_type: str = "uni-scan"): - # hidden_states shape should be: (B, L, D) - if scan_type == "uni-scan": - # in this case, nothing need to be done - return hidden_states - elif scan_type == "bi-scan": - flipped_hidden_states = hidden_states.flip(-2) - hidden_states = torch.cat([hidden_states, flipped_hidden_states], dim=0) # (B, L, D) -> (2B, L, D) - return hidden_states - - # apply cross scan to the sequence - B, L, D = hidden_states.shape - hw = int(math.sqrt(L)) - assert (hw * hw == L) # make sure L is a square - hidden_states = einops.rearrange(hidden_states, "b (h w) d -> b h w d", h=hw, w=hw) # change the shape to feed to cross_scan - hidden_states = cross_scan_fn(hidden_states, in_channel_first=False, out_channel_first=False, one_by_one=False, scans=0) - hidden_states = einops.rearrange(hidden_states, "b l k d -> (b k) l d") - return hidden_states - -def prepare_hidden_states_for_cross_merge(hidden_states: torch.Tensor, scan_type: str = "uni-scan"): - # hidden_states shape should be: (BK, L, D), K=2 for bi-scan, K=1 for uni-scan, K=4 for cross-scan - if scan_type == "uni-scan": - # in this case, nothing need to be done - return hidden_states - elif scan_type == "bi-scan": - # merge the two sequences - B = hidden_states.shape[0] // 2 - hidden_states = hidden_states[:B] + hidden_states[B:] - return hidden_states - - B, L, D = hidden_states.shape - hw = int(math.sqrt(L)) - hidden_states = einops.rearrange(hidden_states, "(b k) (h w) d -> b h w k d", k=4, h=hw, w=hw) - # apply cross merge to the sequence - hidden_states = cross_merge_fn(hidden_states, in_channel_first=False, out_channel_first=False, one_by_one=False, scans=0) - return hidden_states - -# check the 
implementation -if __name__ == "__main__": - B, L, D = 1, 4, 3 - transformation = nn.Linear(D, D).cuda() - # firstly test bi-scan - print("Checking bi-scan") - h1 = torch.randn(B, L, D).cuda() - h2 = h1.clone().cuda() - h1 = prepare_hidden_states_for_cross_scan(h1, scan_type="bi-scan") - h1 = transformation(h1) - h1 = prepare_hidden_states_for_cross_merge(h1, scan_type="bi-scan") - h2_ = h2.clone().cuda() - h2_ = h2_.flip(-2) - h2 = transformation(h2) - h2_ = transformation(h2_) - h2 = h2 + h2_ - # check whether the two sequences are the same - print(f"h1: \n{h1}") - print(f"h2: \n{h2}") - print(f"""The two sequences are the same: {torch.allclose(h1, h2)}""") - # Then check cross-scan - print("checking cross-scan") - h1 = torch.randn(B, L, D).cuda() - h2 = h1.clone().cuda() - h1 = prepare_hidden_states_for_cross_scan(h1, scan_type="cross-scan") - h1 = transformation(h1) - h1 = prepare_hidden_states_for_cross_merge(h1, scan_type="cross-scan") - B, L, D = h2.shape - hw = int(math.sqrt(L)) - assert (hw * hw == L) # make sure L is a square - h2 = einops.rearrange(h2, "b (h w) d -> b h w d", h=hw, w=hw) # change the shape to feed to cross_scan - h2 = cross_scan_fn(h2, in_channel_first=False, out_channel_first=False, one_by_one=False, scans=0) - h2 = h2.permute(2, 0, 1, 3) - h2_0 = h2[0] - h2_1 = h2[1] - h2_2 = h2[2] - h2_3 = h2[3] - h2_0 = transformation(h2_0) - h2_1 = transformation(h2_1) - h2_2 = transformation(h2_2) - h2_3 = transformation(h2_3) - h2 = torch.cat([h2_0, h2_1, h2_2, h2_3], dim=0) - h2 = prepare_hidden_states_for_cross_merge(h2, scan_type="cross-scan") - # check whether the two sequences are the same - print(f"h1: \n{h1}") - print(f"h2: \n{h2}") - print(f"""The two sequences are the same: {torch.allclose(h1, h2)}""") \ No newline at end of file
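
Usage note (illustrative, not part of the patch): the deleted `__init__.py` files registered each vision config/model pair with `AutoConfig` and `AutoModelForImageClassification`, so before this removal the classes could be driven through the standard transformers Auto API. The sketch below shows that wiring for the RWKV6 variant; it is a minimal sketch assuming the pre-removal `fla.vision_models` layout and a CUDA device for the fused/Triton paths, and all hyperparameter values are made up for illustration.

import torch
from transformers import AutoModelForImageClassification
# importing the (now removed) subpackage performed the Auto* registration
from fla.vision_models.rwkv6 import RWKV6VisionConfig

config = RWKV6VisionConfig(
    hidden_size=256,        # toy size; the deleted config defaulted to 2048
    num_hidden_layers=2,
    num_heads=4,
    image_size=224,
    patch_size=16,
    num_classes=10,
    scan_type="uni-scan",   # "bi-scan"/"cross-scan" route through the Triton helpers in utils.py
)
model = AutoModelForImageClassification.from_config(config).cuda()

pixel_values = torch.randn(1, 3, 224, 224, device="cuda")
labels = torch.tensor([3], device="cuda")
outputs = model(pixel_values=pixel_values, labels=labels)
print(outputs.logits.shape)  # (1, 10): pooled hidden state -> classifier head
print(outputs.loss)          # cross-entropy, since num_classes > 1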