Add implementations for vision models #123

Open · wants to merge 30 commits into main
Conversation

yibozhong

Vision models based on sub-quadratic models like SSMs are quite popular, but their codebases are separate and often complex, making them hard to use and compare.

This PR implements vision models based on FLA, making it easier to use FLA models for vision tasks and to compare them within a single codebase. The architecture is mostly based on Hugging Face's Vision Transformer (ViT) implementation, with several modifications. Implementation details include:

  • Minimal code structure, consistent with the existing language models.
  • An accompanying training script is added: training/classification.py.
  • Three scanning types are implemented (uni-scan, bi-scan, cross-scan). I'm not certain these scanning methods actually help, but supporting them is useful for comparison, since many prior models use them.
  • The CLS token is not used, as it did not work in my early testing; mean pooling is used instead.
  • Common components (Embedding, Pooler) and the pretrained-model initialization code are reused from Hugging Face's ViT implementation.
  • Mamba, Mamba2, and Samba are not yet implemented, since their code differs slightly from that of the other models. They will be added in the future.
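To illustrate the bullet points above, here is a minimal sketch (not the PR's actual code) of the three scan orders over a flattened patch grid, plus the mean pooling used in place of a CLS token. `prepare_scan` and `pool` are hypothetical names for illustration only.

```python
# Hypothetical sketch of uni-/bi-/cross-scan orderings over (B, H*W, D) patch
# tokens, and mean pooling instead of a CLS token. Not the PR's actual API.
import torch

def prepare_scan(x: torch.Tensor, scan_type: str, h: int, w: int) -> torch.Tensor:
    """Reorder (and possibly duplicate) patch tokens of shape (B, H*W, D)."""
    b, n, d = x.shape
    assert n == h * w
    if scan_type == "uni-scan":    # plain row-major order
        return x
    if scan_type == "bi-scan":     # forward + reversed sequences, stacked on batch
        return torch.cat([x, x.flip(1)], dim=0)
    if scan_type == "cross-scan":  # row-major and column-major, each both directions
        xc = x.view(b, h, w, d).transpose(1, 2).reshape(b, n, d)
        return torch.cat([x, x.flip(1), xc, xc.flip(1)], dim=0)
    raise ValueError(f"unknown scan_type: {scan_type}")

def pool(hidden: torch.Tensor) -> torch.Tensor:
    """Mean pooling over the token dimension: (B, N, D) -> (B, D)."""
    return hidden.mean(dim=1)
```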

Additionally, there's a bug in fla/layers/abc.py, where the use_rope argument is not handled in the initialization code. This PR fixes it.
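The gist of that fix can be sketched as follows; the class below is a hypothetical stand-in, not the actual fla/layers/abc.py code.

```python
# Illustrative sketch of the described fix: the layer previously accepted a
# `use_rope` flag that was never wired into __init__, so it had no effect.
import torch.nn as nn

class ABCAttention(nn.Module):  # hypothetical stand-in for the real layer
    def __init__(self, hidden_size: int = 256, use_rope: bool = True):
        super().__init__()
        self.hidden_size = hidden_size
        # before the fix, this assignment was missing from __init__
        self.use_rope = use_rope
```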


I ran some quick tests to check whether these models work (i.e., can train with a decreasing loss). The training code is classification.py, and below is an example test script:

# test all pure fla models
for SCAN in uni-scan bi-scan cross-scan; do
    for MODEL in transformer abc bitnet deltanet gated_deltanet gla gsa hgrn hgrn2 linear_attn retnet rwkv6; do
        python classification.py \
            --model $MODEL \
            --dataset cifar100 \
            --num_hidden_layers 6 \
            --lr_scheduler_type constant \
            --hidden_size 256 \
            --attn_mode chunk \
            --train_bs 16 \
            --eval_bs 16 \
            --b_lr 1e-4 \
            --h_lr 1e-4 \
            --epochs 2 \
            --eval_epoch 1 \
            --scan_type $SCAN \
            # --wandb
    done
done
# test all hybrid models
for SCAN in uni-scan bi-scan cross-scan; do
    for MODEL in abc bitnet deltanet gated_deltanet gla gsa hgrn hgrn2 linear_attn retnet rwkv6; do
        python classification.py \
            --model $MODEL \
            --dataset cifar100 \
            --num_hidden_layers 6 \
            --lr_scheduler_type constant \
            --attn_mode chunk \
            --hidden_size 256 \
            --train_bs 16 \
            --eval_bs 16 \
            --b_lr 1e-4 \
            --h_lr 1e-4 \
            --epochs 2 \
            --eval_epoch 1 \
            --use_attn \
            --attn_layers 1,3,5 \
            --scan_type $SCAN \
            # --wandb
    done
done

Test environment:

CPU:    12th Gen Intel i9-12900HX (24) @ 2.495GHz
GPU:    NVIDIA RTX 4060
OS:     Ubuntu 22.04.5 LTS on Windows 10 x86_64
Kernel: 5.15.167.4-microsoft-standard-WSL2

and the virtual environment is:

Package                  Version
------------------------ ------------------
absl-py                  2.1.0
accelerate               1.1.1
aiohttp                  3.9.5
aiosignal                1.3.1
annotated-types          0.7.0
asgiref                  3.8.1
attrs                    23.2.0
bitsandbytes             0.43.3
blinker                  1.9.0
cachetools               5.5.0
causal-conv1d            1.4.0
certifi                  2024.6.2
chardet                  5.2.0
charset-normalizer       2.0.12
click                    8.1.7
colorama                 0.4.6
contourpy                1.2.1
cycler                   0.12.1
DataProperty             1.0.1
datasets                 3.2.0
diffusers                0.32.1
dill                     0.3.8
Django                   5.1.2
docker-pycreds           0.4.0
einops                   0.8.0
evaluate                 0.4.3
filelock                 3.15.3
flash-attn               2.7.3
Flask                    2.3.3
fonttools                4.53.0
frozenlist               1.4.1
fsspec                   2024.5.0
fvcore                   0.1.5.post20221221
gitdb                    4.0.12
GitPython                3.1.44
graphviz                 0.20.3
grpcio                   1.66.2
huggingface-hub          0.27.0
idna                     3.7
importlib_metadata       8.5.0
iopath                   0.1.10
itsdangerous             2.2.0
Jinja2                   3.1.4
jiwer                    3.0.5
joblib                   1.4.2
jsonlines                4.0.0
kiwisolver               1.4.5
lightning-utilities      0.11.7
lm_eval                  0.4.5
lxml                     5.3.0
mamba-ssm                2.2.2
Markdown                 3.7
markdown-it-py           3.0.0
MarkupSafe               2.1.5
matplotlib               3.9.0
mbstrdecoder             1.1.3
mdurl                    0.1.2
metrics                  0.3.3
more-itertools           10.5.0
mpmath                   1.3.0
multidict                6.0.5
multiprocess             0.70.16
networkx                 3.3
ninja                    1.11.1.1
nltk                     3.9.1
numexpr                  2.10.1
numpy                    2.0.0
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu12     12.1.0.106
nvidia-ml-py             12.535.161
nvidia-nccl-cu12         2.20.5
nvidia-nvjitlink-cu12    12.5.40
nvidia-nvtx-cu12         12.1.105
nvitop                   1.3.2
opencv-python            4.10.0.84
packaging                24.1
pandas                   2.2.2
pathlib2                 2.3.7.post1
pathspec                 0.5.5
pathvalidate             3.2.1
peft                     0.11.1
pillow                   10.3.0
pip                      24.0
platformdirs             4.3.6
portalocker              2.10.1
protobuf                 5.28.2
psutil                   6.0.0
psycopg2                 2.9.10
pyarrow                  16.1.0
pyarrow-hotfix           0.6
pybind11                 2.13.6
pydantic                 2.10.5
pydantic_core            2.27.2
Pygments                 2.18.0
pyparsing                3.1.2
pytablewriter            1.2.0
python-dateutil          2.9.0.post0
python-dotenv            1.0.0
pytorch-lightning        2.4.0
pytz                     2024.1
PyYAML                   6.0.1
RapidFuzz                3.11.0
regex                    2024.5.15
requests                 2.32.3
rich                     13.9.2
rouge_score              0.1.2
sacrebleu                2.4.3
safetensors              0.4.3
scikit-learn             1.5.1
scipy                    1.14.0
seaborn                  0.13.2
sentencepiece            0.2.0
sentry-sdk               2.19.2
setproctitle             1.3.4
setuptools               69.5.1
six                      1.16.0
smmap                    5.0.2
sqlitedict               2.1.0
sqlparse                 0.5.1
sympy                    1.12.1
tabledata                1.3.3
tabulate                 0.9.0
tcolorpy                 0.1.6
tensorboard              2.18.0
tensorboard-data-server  0.7.2
tensorboardX             2.6.2.2
termcolor                2.5.0
threadpoolctl            3.5.0
timm                     1.0.11
tokenizers               0.21.0
torch                    2.3.1
torchmetrics             1.4.3
torchvision              0.18.1
tqdm                     4.66.4
tqdm-multiprocess        0.0.11
transformers             4.47.1
triton                   3.0.0
typepy                   1.3.2
typing_extensions        4.12.2
tzdata                   2024.1
urllib3                  1.26.20
wandb                    0.19.2
Werkzeug                 2.3.7
wheel                    0.43.0
word2number              1.1
xxhash                   3.4.1
yacs                     0.1.8
yarl                     1.9.4
zipp                     3.21.0
zstandard                0.23.0

An additional machine with an A100 GPU was used to test gated_deltanet.

The test results are as follows. All runs use 6 layers in total; the hybrid setting places attention layers at indices 1, 3, 5. The attention mode is set to chunk by default, except for rwkv6. All of the errors below are caused by the respective attention implementations (e.g., Triton errors).

| model          | pure FLA                                        | hybrid                                          |
|----------------|-------------------------------------------------|-------------------------------------------------|
| abc            | ❌ CompilationError                             | ❌ CompilationError                             |
| bitnet         | ❌ AttributeError                               | ❌ AttributeError                               |
| deltanet       |                                                 |                                                 |
| gated_deltanet | RTX 4060: ❌ / A100: ✅                         | RTX 4060: ❌ / A100: ✅                         |
| gla            | ❌ CompilationError                             | ❌ CompilationError                             |
| gsa            |                                                 |                                                 |
| hgrn           |                                                 |                                                 |
| hgrn2          |                                                 |                                                 |
| linear_attn    | ❌ Matmul shape error                           | ❌ Matmul shape error                           |
| retnet         |                                                 |                                                 |
| rwkv6          | chunk: ❌ CompilationError / fused_recurrent: ✅ | chunk: ❌ CompilationError / fused_recurrent: ✅ |
| transformer    |                                                 |                                                 |
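For reference, the hybrid layout described above (attention layers at indices 1, 3, 5 out of 6) can be sketched as follows; `layer_plan` is an illustrative helper, not part of the PR.

```python
# Hypothetical sketch of how --attn_layers 1,3,5 with --num_hidden_layers 6
# would interleave standard attention blocks among linear-attention blocks.
def layer_plan(num_hidden_layers: int, attn_layers: str) -> list[str]:
    attn = {int(i) for i in attn_layers.split(",")}
    return ["attention" if i in attn else "linear_attention"
            for i in range(num_hidden_layers)]
```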

@yzhangcs
Member

@yibozhong Hi, thank you for your great work!
However, I have some concerns:

  1. What is the primary motivation for separating the independent vision_models folder from the current models? I believe this separation may complicate maintenance, especially as we continue to introduce more cutting-edge linear attention variants while vision_models stays detached from the existing models.
  2. What are the main differences between vision_models and models? I apologize for my lack of expertise in this area; the only distinction I have noticed is the inclusion of qknorm.
  3. Given the significant changes associated with this new folder, I believe it would be beneficial to create some new PRs to address existing bugs first, such as the bug in the ABC layer.
  4. We are transitioning from fla/training to a new repo, fla-org/flame, which is based on Torchtitan. Therefore, it is not advisable to update fla/training, as it will be deprecated in the near future. I also believe that maintaining an independent repository for various training recipes would be a better choice, as we intend to keep flame minimal.
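For context on point 2, a common reading of "qknorm" is normalizing queries and keys per head before computing attention scores, which tends to stabilize training in vision models. The sketch below is a generic illustration under that assumption, not FLA's implementation.

```python
# Generic QK-norm sketch: unit-normalize q and k along the head dimension
# before the dot product, so scores are bounded cosine similarities.
import torch
import torch.nn.functional as F

def qk_norm_attention_scores(q: torch.Tensor, k: torch.Tensor,
                             scale: float = 1.0) -> torch.Tensor:
    q = F.normalize(q, dim=-1)  # unit-norm queries
    k = F.normalize(k, dim=-1)  # unit-norm keys
    return (q @ k.transpose(-2, -1)) * scale
```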

Thank you!

@yibozhong
Author

@yzhangcs Hi, thank you for your response! My answers to your concerns:

  1. Actually, I initially wanted to modify the fla/models folder directly. However, I wasn't sure whether directly modifying files in fla/models would be acceptable to the maintainers of this repo, and that is the whole reason for the separation. If it is okay, I will be happy to migrate my changes into the usual fla/models folder.
  2. I may be misunderstanding your question (please correct me if I'm wrong). If you mean the folder issue, I will introduce my changes into the usual fla/models folder with no problem. If you mean the distinctions between the model implementations, the models do have notable differences, such as the embedding and some of the code logic.
  3. Thank you for your advice!
  4. Thank you for letting me know. I will delete my changes to this folder.

In general, I didn't mean to keep these two folders separate. If it is okay, I'll introduce my changes into the fla/models folder.
