Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add implementations for vision models #123

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d9a672a
Add a DeltaNet Image Classification Model Implementation based on hug…
yibozhong Jan 13, 2025
faf49d7
change position of vision training code
yibozhong Jan 13, 2025
f676c81
Update new vision delta net
yibozhong Jan 13, 2025
5bc94ce
Update new vision delta net
yibozhong Jan 13, 2025
b013b7e
Merge branch 'fla-org:main' into main
yibozhong Jan 14, 2025
d470814
Add support for multiple scanning methods
yibozhong Jan 14, 2025
5ee4bfc
Update training script
yibozhong Jan 14, 2025
f83e10a
Merge branch 'fla-org:main' into main
yibozhong Jan 14, 2025
4942afb
Merge branch 'fla-org:main' into main
yibozhong Jan 15, 2025
6353af5
Merge branch 'fla-org:main' into main
yibozhong Jan 15, 2025
dddc7ab
Merge branch 'fla-org:main' into main
yibozhong Jan 16, 2025
c308ac0
Add all fla-based vision models except mamba, mamba2 and samba
yibozhong Jan 16, 2025
4188ebb
change script location
yibozhong Jan 16, 2025
6b49935
Merge branch 'fla-org:main' into main
yibozhong Jan 17, 2025
16568d9
Test the implementations
yibozhong Jan 17, 2025
b2db8d0
change script position
yibozhong Jan 17, 2025
be439b0
change script position
yibozhong Jan 17, 2025
57fd584
update __init__.py for vision models
yibozhong Jan 17, 2025
562e12b
Merge branch 'fla-org:main' into main
yibozhong Jan 17, 2025
6f2276a
Merge branch 'fla-org:main' into main
yibozhong Jan 17, 2025
ed12c30
Merge branch 'fla-org:main' into main
yibozhong Jan 18, 2025
1674a22
Standardized the code and add implementations for basemodel and masked…
yibozhong Jan 18, 2025
973e3eb
Merge branch 'fla-org:main' into main
yibozhong Jan 18, 2025
5128b34
Merge branch 'main' of https://github.com/yibozhong/flash-linear-atte…
yibozhong Jan 18, 2025
749674e
Merge branch 'fla-org:main' into main
yibozhong Jan 19, 2025
b6daab9
Merge branch 'main' of https://github.com/yibozhong/flash-linear-atte…
yibozhong Jan 19, 2025
a3b03b6
Remove training script
yibozhong Jan 19, 2025
683913f
remove separate folder
yibozhong Jan 19, 2025
6deb624
migrate vision models to fla/models
yibozhong Jan 19, 2025
e1b05da
reverse previous changes for separate PR
yibozhong Jan 19, 2025
764da64
Merge branch 'fla-org:main' into main
yibozhong Jan 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion fla/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,20 @@
from fla.models.transformer import (TransformerConfig, TransformerForCausalLM,
TransformerModel)
from fla.models.gated_deltanet import GatedDeltaNetConfig, GatedDeltaNetForCausalLM, GatedDeltaNetModel

from fla.models.abc import ABCVisionConfig, ABCForImageClassification, ABCForMaskedImageModeling, ABCVisionModel
from fla.models.bitnet import BitNetVisionConfig, BitNetForImageClassification, BitNetForMaskedImageModeling, BitNetVisionModel
from fla.models.delta_net import DeltaNetVisionConfig, DeltaNetForImageClassification, DeltaNetForMaskedImageModeling, DeltaNetVisionModel
from fla.models.gated_deltanet import GatedDeltaNetVisionConfig, GatedDeltaNetForImageClassification, GatedDeltaNetVisionModel, GatedDeltaNetForMaskedImageModeling
from fla.models.gla import GLAVisionConfig, GLAForImageClassification, GLAForMaskedImageModeling, GLAVisionModel
from fla.models.gsa import GSAVisionConfig, GSAForImageClassification, GSAForMaskedImageModeling, GSAVisionModel
from fla.models.hgrn import HGRNVisionConfig, HGRNForImageClassification, HGRNForMaskedImageModeling, HGRNVisionModel
from fla.models.hgrn2 import HGRN2VisionConfig, HGRN2ForImageClassification, HGRN2ForMaskedImageModeling, HGRN2VisionModel
from fla.models.linear_attn import LinearAttentionVisionConfig, LinearAttentionForImageClassification, LinearAttentionForMaskedImageModeling, LinearAttentionVisionModel
from fla.models.retnet import RetNetVisionConfig, RetNetForImageClassification, RetNetForMaskedImageModeling, RetNetVisionModel
from fla.models.rwkv6 import RWKV6VisionConfig, RWKV6ForImageClassification, RWKV6ForMaskedImageModeling, RWKV6VisionModel
from fla.models.transformer import TransformerVisionConfig, TransformerForImageClassification, TransformerForMaskedImageModeling, TransformerVisionModel

__all__ = [
'ABCConfig', 'ABCForCausalLM', 'ABCModel',
'BitNetConfig', 'BitNetForCausalLM', 'BitNetModel',
Expand All @@ -34,5 +48,17 @@
'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model',
'SambaConfig', 'SambaForCausalLM', 'SambaModel',
'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel',
'GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel'
'GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel',
'ABCVisionConfig', 'ABCForImageClassification', 'ABCForMaskedImageModeling', 'ABCVisionModel',
'BitNetVisionConfig', 'BitNetForImageClassification', 'BitNetForMaskedImageModeling', 'BitNetVisionModel',
'DeltaNetVisionConfig', 'DeltaNetForImageClassification', 'DeltaNetForMaskedImageModeling', 'DeltaNetVisionModel',
'GatedDeltaNetVisionConfig', 'GatedDeltaNetForImageClassification', 'GatedDeltaNetVisionModel', 'GatedDeltaNetForMaskedImageModeling',
'GLAVisionConfig', 'GLAForImageClassification', 'GLAForMaskedImageModeling', 'GLAVisionModel',
'GSAVisionConfig', 'GSAForImageClassification', 'GSAForMaskedImageModeling', 'GSAVisionModel',
'HGRNVisionConfig', 'HGRNForImageClassification', 'HGRNForMaskedImageModeling', 'HGRNVisionModel',
'HGRN2VisionConfig', 'HGRN2ForImageClassification', 'HGRN2ForMaskedImageModeling', 'HGRN2VisionModel',
'LinearAttentionVisionConfig', 'LinearAttentionForImageClassification', 'LinearAttentionForMaskedImageModeling', 'LinearAttentionVisionModel',
'RetNetVisionConfig', 'RetNetForImageClassification', 'RetNetForMaskedImageModeling', 'RetNetVisionModel',
'RWKV6VisionConfig', 'RWKV6ForImageClassification', 'RWKV6ForMaskedImageModeling', 'RWKV6VisionModel',
'TransformerVisionConfig', 'TransformerForImageClassification', 'TransformerForMaskedImageModeling', 'TransformerVisionModel',
]
12 changes: 8 additions & 4 deletions fla/models/abc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
# -*- coding: utf-8 -*-

"""Register the ABC language and vision model classes with the
transformers ``Auto*`` factories so they can be loaded by ``model_type``."""

from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
                          AutoModelForImageClassification,
                          AutoModelForMaskedImageModeling)

from fla.models.abc.configuration_abc import ABCConfig, ABCVisionConfig
from fla.models.abc.modeling_abc import (ABCForCausalLM,
                                         ABCForImageClassification,
                                         ABCForMaskedImageModeling, ABCModel,
                                         ABCVisionModel)

# Language-model registrations.
AutoConfig.register(ABCConfig.model_type, ABCConfig)
AutoModel.register(ABCConfig, ABCModel)
AutoModelForCausalLM.register(ABCConfig, ABCForCausalLM)

# Vision-model registrations.
AutoConfig.register(ABCVisionConfig.model_type, ABCVisionConfig)
AutoModel.register(ABCVisionConfig, ABCVisionModel)
AutoModelForImageClassification.register(ABCVisionConfig, ABCForImageClassification)
AutoModelForMaskedImageModeling.register(ABCVisionConfig, ABCForMaskedImageModeling)


__all__ = [
    'ABCConfig', 'ABCForCausalLM', 'ABCModel',
    'ABCVisionConfig', 'ABCVisionModel',
    'ABCForImageClassification', 'ABCForMaskedImageModeling',
]
95 changes: 95 additions & 0 deletions fla/models/abc/configuration_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,98 @@ def __init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)


class ABCVisionConfig(PretrainedConfig):
    """Configuration for ABC-based vision models.

    Combines the ABC sequence-model hyper-parameters (hidden size, number of
    layers/heads, slot count, optional hybrid-attention layers, ...) with
    ViT-style vision parameters (image/patch size, channels, classification
    head size, masked-image-modeling options).

    Raises:
        ValueError: if ``attn`` is provided but is not a dict, or lacks the
            required ``'layers'`` / ``'num_heads'`` entries.
    """

    model_type = 'abc_vision'

    def __init__(
        self,
        # ABC core parameters
        hidden_size: int = 2048,
        gate_low_rank_dim: int = 16,
        clamp_min: float = -32,
        clamp_max: float = 32,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        num_slots: Optional[int] = 64,
        use_short_conv: bool = False,
        conv_size: int = 4,
        # NOTE(review): 'exapnd_k'/'exapnd_v' are misspellings of
        # 'expand_k'/'expand_v', but they are public keyword names — kept
        # as-is for backward compatibility with existing callers/checkpoints.
        exapnd_k: float = 0.5,
        exapnd_v: float = 1,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        initializer_range: float = 0.02,
        fuse_norm: bool = True,
        fuse_cross_entropy: bool = True,
        # Vision specific parameters
        image_size: int = 224,
        patch_size: int = 16,
        num_channels: int = 3,
        num_classes: int = 1000,
        qkv_bias: bool = True,
        hidden_dropout_prob: float = 0.0,
        use_mask_token: bool = False,
        layer_norm_eps: float = 1e-6,
        interpolate_pos_encoding: bool = False,
        mlp_dim: Optional[int] = None,
        encoder_stride: int = 16,
        scan_type: str = "uni-scan",  # scanning type: "uni-scan", "bi-scan" or "cross-scan"
        **kwargs
    ):
        # Initialize ABC core parameters
        self.hidden_size = hidden_size
        self.gate_low_rank_dim = gate_low_rank_dim
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_slots = num_slots
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        # Stored under the correctly spelled attribute names.
        self.expand_k = exapnd_k
        self.expand_v = exapnd_v
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.fuse_norm = fuse_norm
        self.fuse_cross_entropy = fuse_cross_entropy

        # Initialize vision specific parameters
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_classes = num_classes
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.use_mask_token = use_mask_token
        self.layer_norm_eps = layer_norm_eps
        self.interpolate_pos_encoding = interpolate_pos_encoding
        self.scan_type = scan_type
        self.encoder_stride = encoder_stride

        # Optional hybrid softmax-attention layers: validate the spec and
        # fill in defaults for the optional entries.
        if attn is not None:
            if not isinstance(attn, dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['window_size'] = attn.get('window_size', None)

        self.attn = attn
        # MLP width defaults to the conventional 4x expansion of hidden_size.
        self.mlp_dim = 4 * hidden_size if mlp_dim is None else mlp_dim

        super().__init__(**kwargs)
Loading