migrate vision models to fla/models
yibozhong committed Jan 19, 2025
1 parent 683913f commit 6deb624
Showing 38 changed files with 6,104 additions and 91 deletions.
28 changes: 27 additions & 1 deletion fla/models/__init__.py
@@ -19,6 +19,20 @@
from fla.models.transformer import (TransformerConfig, TransformerForCausalLM,
                                    TransformerModel)
from fla.models.gated_deltanet import GatedDeltaNetConfig, GatedDeltaNetForCausalLM, GatedDeltaNetModel

from fla.models.abc import ABCVisionConfig, ABCForImageClassification, ABCForMaskedImageModeling, ABCVisionModel
from fla.models.bitnet import BitNetVisionConfig, BitNetForImageClassification, BitNetForMaskedImageModeling, BitNetVisionModel
from fla.models.delta_net import DeltaNetVisionConfig, DeltaNetForImageClassification, DeltaNetForMaskedImageModeling, DeltaNetVisionModel
from fla.models.gated_deltanet import GatedDeltaNetVisionConfig, GatedDeltaNetForImageClassification, GatedDeltaNetVisionModel, GatedDeltaNetForMaskedImageModeling
from fla.models.gla import GLAVisionConfig, GLAForImageClassification, GLAForMaskedImageModeling, GLAVisionModel
from fla.models.gsa import GSAVisionConfig, GSAForImageClassification, GSAForMaskedImageModeling, GSAVisionModel
from fla.models.hgrn import HGRNVisionConfig, HGRNForImageClassification, HGRNForMaskedImageModeling, HGRNVisionModel
from fla.models.hgrn2 import HGRN2VisionConfig, HGRN2ForImageClassification, HGRN2ForMaskedImageModeling, HGRN2VisionModel
from fla.models.linear_attn import LinearAttentionVisionConfig, LinearAttentionForImageClassification, LinearAttentionForMaskedImageModeling, LinearAttentionVisionModel
from fla.models.retnet import RetNetVisionConfig, RetNetForImageClassification, RetNetForMaskedImageModeling, RetNetVisionModel
from fla.models.rwkv6 import RWKV6VisionConfig, RWKV6ForImageClassification, RWKV6ForMaskedImageModeling, RWKV6VisionModel
from fla.models.transformer import TransformerVisionConfig, TransformerForImageClassification, TransformerForMaskedImageModeling, TransformerVisionModel

__all__ = [
'ABCConfig', 'ABCForCausalLM', 'ABCModel',
'BitNetConfig', 'BitNetForCausalLM', 'BitNetModel',
@@ -34,5 +48,17 @@
'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model',
'SambaConfig', 'SambaForCausalLM', 'SambaModel',
'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel',
'GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel'
'GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel',
'ABCVisionConfig', 'ABCForImageClassification', 'ABCForMaskedImageModeling', 'ABCVisionModel',
'BitNetVisionConfig', 'BitNetForImageClassification', 'BitNetForMaskedImageModeling', 'BitNetVisionModel',
'DeltaNetVisionConfig', 'DeltaNetForImageClassification', 'DeltaNetForMaskedImageModeling', 'DeltaNetVisionModel',
'GatedDeltaNetVisionConfig', 'GatedDeltaNetForImageClassification', 'GatedDeltaNetVisionModel', 'GatedDeltaNetForMaskedImageModeling',
'GLAVisionConfig', 'GLAForImageClassification', 'GLAForMaskedImageModeling', 'GLAVisionModel',
'GSAVisionConfig', 'GSAForImageClassification', 'GSAForMaskedImageModeling', 'GSAVisionModel',
'HGRNVisionConfig', 'HGRNForImageClassification', 'HGRNForMaskedImageModeling', 'HGRNVisionModel',
'HGRN2VisionConfig', 'HGRN2ForImageClassification', 'HGRN2ForMaskedImageModeling', 'HGRN2VisionModel',
'LinearAttentionVisionConfig', 'LinearAttentionForImageClassification', 'LinearAttentionForMaskedImageModeling', 'LinearAttentionVisionModel',
'RetNetVisionConfig', 'RetNetForImageClassification', 'RetNetForMaskedImageModeling', 'RetNetVisionModel',
'RWKV6VisionConfig', 'RWKV6ForImageClassification', 'RWKV6ForMaskedImageModeling', 'RWKV6VisionModel',
'TransformerVisionConfig', 'TransformerForImageClassification', 'TransformerForMaskedImageModeling', 'TransformerVisionModel',
]
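
With these exports in place, the vision variants import directly from fla.models. A minimal smoke-test sketch: the tiny hyperparameters are illustrative, the pixel_values -> logits forward signature is assumed to follow the usual HF vision-model convention, and the fla kernels generally expect a CUDA device with triton installed.

import torch
from fla.models import ABCVisionConfig, ABCForImageClassification

# A deliberately small config for a quick test; the real defaults are
# hidden_size=2048 and 24 layers.
config = ABCVisionConfig(
    hidden_size=256,
    num_hidden_layers=2,
    num_heads=4,
    num_classes=10,
)
model = ABCForImageClassification(config)

# One dummy image at the configured resolution.
pixel_values = torch.randn(1, config.num_channels, config.image_size, config.image_size)
logits = model(pixel_values).logits  # expected shape: (1, config.num_classes)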
12 changes: 8 additions & 4 deletions fla/models/abc/__init__.py
@@ -1,13 +1,17 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForImageClassification, AutoModelForMaskedImageModeling

from fla.models.abc.configuration_abc import ABCConfig
from fla.models.abc.modeling_abc import ABCForCausalLM, ABCModel
from fla.models.abc.configuration_abc import ABCConfig, ABCVisionConfig
from fla.models.abc.modeling_abc import ABCForCausalLM, ABCModel, ABCVisionModel, ABCForImageClassification, ABCForMaskedImageModeling

AutoConfig.register(ABCConfig.model_type, ABCConfig)
AutoConfig.register(ABCVisionConfig.model_type, ABCVisionConfig)
AutoModel.register(ABCConfig, ABCModel)
AutoModelForCausalLM.register(ABCConfig, ABCForCausalLM)
AutoModelForImageClassification.register(ABCVisionConfig, ABCForImageClassification)
AutoModelForMaskedImageModeling.register(ABCVisionConfig, ABCForMaskedImageModeling)
AutoModel.register(ABCVisionConfig, ABCVisionModel)


__all__ = ['ABCConfig', 'ABCForCausalLM', 'ABCModel']
__all__ = ['ABCConfig', 'ABCForCausalLM', 'ABCModel', 'ABCVisionModel', 'ABCForImageClassification', 'ABCForMaskedImageModeling', 'ABCVisionConfig']
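
Because the vision config and heads are now registered with the Auto* mappings above, they also resolve by model_type through the standard transformers machinery. A sketch of that round-trip (kwargs such as num_hidden_layers are illustrative):

from transformers import AutoConfig, AutoModelForImageClassification

# Resolve the registered config by its model_type, then build the head from it.
config = AutoConfig.for_model('abc_vision', num_hidden_layers=2)
model = AutoModelForImageClassification.from_config(config)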
95 changes: 95 additions & 0 deletions fla/models/abc/configuration_abc.py
@@ -82,3 +82,98 @@ def __init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class ABCVisionConfig(PretrainedConfig):

    model_type = 'abc_vision'

    def __init__(
        self,
        # ABC core parameters
        hidden_size: int = 2048,
        gate_low_rank_dim: int = 16,
        clamp_min: float = -32,
        clamp_max: float = 32,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        num_slots: Optional[int] = 64,
        use_short_conv: bool = False,
        conv_size: int = 4,
        expand_k: float = 0.5,
        expand_v: float = 1,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        initializer_range: float = 0.02,
        fuse_norm: bool = True,
        fuse_cross_entropy: bool = True,
        # Vision specific parameters
        image_size: int = 224,
        patch_size: int = 16,
        num_channels: int = 3,
        num_classes: int = 1000,
        qkv_bias: bool = True,
        hidden_dropout_prob: float = 0.0,
        use_mask_token: bool = False,
        layer_norm_eps: float = 1e-6,
        interpolate_pos_encoding: bool = False,
        mlp_dim: Optional[int] = None,
        encoder_stride: int = 16,
        scan_type: str = "uni-scan",  # scanning type: "uni-scan", "bi-scan" or "cross-scan"; defaults to "uni-scan"
        **kwargs
    ):
        # Initialize ABC core parameters
        self.hidden_size = hidden_size
        self.gate_low_rank_dim = gate_low_rank_dim
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_slots = num_slots
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.fuse_norm = fuse_norm
        self.fuse_cross_entropy = fuse_cross_entropy

        # Initialize vision specific parameters
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_classes = num_classes
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.use_mask_token = use_mask_token
        self.layer_norm_eps = layer_norm_eps
        self.interpolate_pos_encoding = interpolate_pos_encoding
        self.scan_type = scan_type
        self.encoder_stride = encoder_stride

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['window_size'] = attn.get('window_size', None)

        self.attn = attn
        if mlp_dim is None:
            self.mlp_dim = 4 * hidden_size  # default to 4 * hidden_size
        else:
            self.mlp_dim = mlp_dim

        super().__init__(**kwargs)
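
A quick sketch of the derived defaults encoded above, directly following the __init__ logic shown (the concrete values are hypothetical):

from fla.models.abc import ABCVisionConfig

# mlp_dim falls back to 4 * hidden_size when left unset.
config = ABCVisionConfig(hidden_size=512)
assert config.mlp_dim == 4 * 512

# Hybrid-attention dict: 'layers' and 'num_heads' are required;
# num_kv_heads and window_size are filled in with defaults.
config = ABCVisionConfig(attn={'layers': [0, 2], 'num_heads': 8})
assert config.attn['num_kv_heads'] == 8
assert config.attn['window_size'] is None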
