From d9a672ae74a6b6fa449a4cb2e671e511dad72032 Mon Sep 17 00:00:00 2001 From: yibozhong Date: Tue, 14 Jan 2025 01:50:01 +0800 Subject: [PATCH 01/17] Add a DeltaNet Image Classification Model Implementation based on huggingface-transformers/vit, currently with no cross-scan or bi-scan; also add a training script for cifar10 and cifar100 using plain pytorch. --- fla/vision_models/__init__.py | 10 + fla/vision_models/delta_net/__init__.py | 13 + .../delta_net/configuration_delta_net.py | 86 ++++ .../delta_net/modeling_delta_net.py | 207 +++++++++ fla/vision_models/utils.py | 152 +++++++ training/classification.py | 419 ++++++++++++++++++ 6 files changed, 887 insertions(+) create mode 100644 fla/vision_models/__init__.py create mode 100644 fla/vision_models/delta_net/__init__.py create mode 100644 fla/vision_models/delta_net/configuration_delta_net.py create mode 100644 fla/vision_models/delta_net/modeling_delta_net.py create mode 100644 fla/vision_models/utils.py create mode 100644 training/classification.py diff --git a/fla/vision_models/__init__.py b/fla/vision_models/__init__.py new file mode 100644 index 000000000..f93f6573d --- /dev/null +++ b/fla/vision_models/__init__.py @@ -0,0 +1,10 @@ +from fla.vision_models.delta_net import DeltaNetVisionConfig, DeltaNetForImageClassification +from fla.vision_models.utils import ImageEmbeddings, PatchEmbeddings, Pooler + +__all__ = [ + 'DeltaNetVisionConfig', + 'DeltaNetForImageClassification', + 'ImageEmbeddings', + 'PatchEmbeddings', + 'Pooler' +] diff --git a/fla/vision_models/delta_net/__init__.py b/fla/vision_models/delta_net/__init__.py new file mode 100644 index 000000000..6bc489850 --- /dev/null +++ b/fla/vision_models/delta_net/__init__.py @@ -0,0 +1,13 @@ +from transformers import AutoConfig, AutoModelForImageClassification + +from fla.vision_models.delta_net.configuration_delta_net import DeltaNetVisionConfig +from fla.vision_models.delta_net.modeling_delta_net import DeltaNetForImageClassification + +# Register the model with transformers +AutoConfig.register("delta_net_vision", DeltaNetVisionConfig) +AutoModelForImageClassification.register(DeltaNetVisionConfig, DeltaNetForImageClassification) + +__all__ = [ + 'DeltaNetVisionConfig', + 'DeltaNetForImageClassification' +] diff --git a/fla/vision_models/delta_net/configuration_delta_net.py b/fla/vision_models/delta_net/configuration_delta_net.py new file mode 100644 index 000000000..40fec3ab9 --- /dev/null +++ b/fla/vision_models/delta_net/configuration_delta_net.py @@ -0,0 +1,86 @@ +from typing import Dict, Optional +from transformers.configuration_utils import PretrainedConfig + +class DeltaNetVisionConfig(PretrainedConfig): + model_type = 'delta_net_vision' + + def __init__( + self, + # DeltaNet core parameters + attn_mode: str = "chunk", + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 1, + use_gate: bool = False, + use_short_conv: bool = True, + conv_size: int = 4, + use_beta: bool = True, + use_output_norm: bool = True, + num_heads: int = 16, + qk_norm: str = 'l2', + qk_activation: str = 'silu', + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + num_hidden_layers: int = 12, + norm_first: bool = False, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + fuse_cross_entropy: bool = True, + max_position_embeddings: int = 2048, + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: 
int = 1000, + qkv_bias: bool = True, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + mlp_dim: int = None, + pool_type: str = "mean", # use "mean" by default + **kwargs + ): + # Initialize DeltaNet core parameters + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_beta = use_beta + self.use_output_norm = use_output_norm + self.num_heads = num_heads + self.qk_norm = qk_norm + self.qk_activation = qk_activation + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.norm_first = norm_first + self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + self.attn = attn + self.max_position_embeddings = max_position_embeddings + + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.pool_type = pool_type + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) diff --git a/fla/vision_models/delta_net/modeling_delta_net.py b/fla/vision_models/delta_net/modeling_delta_net.py new file mode 100644 index 000000000..c7d638dca --- /dev/null +++ b/fla/vision_models/delta_net/modeling_delta_net.py @@ -0,0 +1,207 @@ +import collections.abc +import math +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from typing import Optional, Set, Tuple, Union, List, Dict, Unpack +from transformers.utils import logging +from fla.layers.attn import Attention +from transformers.modeling_outputs import ImageClassifierOutput +from transformers.modeling_utils import PreTrainedModel +from .configuration_delta_net import DeltaNetVisionConfig +from fla.layers.delta_net import DeltaNet +from fla.models.utils import Cache +from ..utils import ImageEmbeddings, Pooler + +logger = logging.get_logger(__name__) + +class DeltaNetMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class DeltaNetBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + if not config.norm_first: + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = DeltaNet( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + 
num_heads=config.num_heads, + use_gate=config.use_gate, + use_beta=config.use_beta, + use_short_conv=config.use_short_conv, + use_output_norm=config.use_output_norm, + conv_size=config.conv_size, + qk_norm=config.qk_norm, + qk_activation=config.qk_activation, + norm_first=config.norm_first, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + + if not config.norm_first: + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = DeltaNetMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + # MLP + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class DeltaNetVisionPreTrainedModel(PreTrainedModel): + # this part of the code is adapted from huggingface/transformers vit implementation + config_class = DeltaNetVisionConfig + base_model_prefix = "deltanet" + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + +class DeltaNetForImageClassification(DeltaNetVisionPreTrainedModel): + config_class = DeltaNetVisionConfig + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + + self.embeddings = ImageEmbeddings(config) + self.blocks = nn.ModuleList([ + DeltaNetBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + 
interpolate_pos_encoding: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding) + + for block in self.blocks: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = self.norm(hidden_states) + pooled_output = self.pooler(hidden_states) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + (hidden_states,) + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + ) diff --git a/fla/vision_models/utils.py b/fla/vision_models/utils.py new file mode 100644 index 000000000..321f6938c --- /dev/null +++ b/fla/vision_models/utils.py @@ -0,0 +1,152 @@ +""" +Vision model utilities adapted from huggingface/transformers ViT implementation. +""" + +import collections.abc +import torch +from torch import nn +from typing import Optional +from transformers.utils import torch_int + +class PatchEmbeddings(nn.Module): + """ + Convert image into patch embeddings. + Adapted from huggingface/transformers ViT implementation. + """ + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + +class ImageEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. + Adapted from huggingface/transformers ViT implementation. 
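+
+    A minimal usage sketch (illustrative only; the config values and shapes below are
+    assumptions, not requirements of this class):
+
+        import torch
+        from fla.vision_models import DeltaNetVisionConfig, ImageEmbeddings
+
+        config = DeltaNetVisionConfig(image_size=224, patch_size=16, hidden_size=768)
+        embeddings = ImageEmbeddings(config)
+        pixel_values = torch.randn(2, 3, 224, 224)
+        tokens = embeddings(pixel_values)
+        # tokens.shape == (2, 197, 768): 196 patch tokens (14 x 14 for a 224 image with
+        # 16 x 16 patches) plus one [CLS] token, with learned position embeddings added.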
+ """ + def __init__(self, config, use_mask_token: bool = False) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None + self.patch_embeddings = PatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.config = config + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + if bool_masked_pos is not None: + seq_length = embeddings.shape[1] + mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + +class Pooler(nn.Module): + """ + Pool the output of a vision model by taking either the CLS token or mean of all tokens. + Adapted from huggingface/transformers ViT implementation. 
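+
+    A small sketch of the two pooling modes (shapes are illustrative; `config` is
+    assumed to be a DeltaNetVisionConfig with hidden_size=768):
+
+        import torch
+
+        pooler = Pooler(config)                   # config.pool_type: "cls" or "mean"
+        hidden_states = torch.randn(2, 197, 768)  # (batch, tokens, hidden_size)
+        pooled = pooler(hidden_states)            # (2, 768)
+        # "cls" keeps hidden_states[:, 0]; "mean" averages over all tokens.
+        # The pooled vector then goes through a Linear layer and a Tanh.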
+ """ + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + self.pool_type = config.pool_type + + def forward(self, hidden_states): + if self.pool_type == 'cls': + pooled_output = hidden_states[:, 0] + else: # 'mean' + pooled_output = hidden_states.mean(dim=1) + pooled_output = self.dense(pooled_output) + pooled_output = self.activation(pooled_output) + return pooled_output diff --git a/training/classification.py b/training/classification.py new file mode 100644 index 000000000..88c160149 --- /dev/null +++ b/training/classification.py @@ -0,0 +1,419 @@ +import os +import torch +from tqdm import tqdm +import wandb +import logging +import random +import torch.optim as optim +from torch.utils.data import DataLoader +from torchvision import datasets, transforms +from transformers import get_scheduler +from torch.amp import GradScaler, autocast +from fla.vision_models.delta_net import DeltaNetVisionConfig, DeltaNetForImageClassification +import time + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +dtype = torch.bfloat16 # deafult dtype for FLA + +def setup_logging(args): + log_filename = f'training_{args.model}_vision_{args.dataset}.log' + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler(log_filename), + logging.StreamHandler() + ] + ) + logging.info(f"Logging to {log_filename}") + +def get_args(): + import argparse + parser = argparse.ArgumentParser(description='Vision Model Training') + parser.add_argument('--dataset', type=str, default='cifar100', help='Dataset name') + parser.add_argument('--num_hidden_layers', type=int, default=12, help='Number of hidden layers') + parser.add_argument('--hidden_size', type=int, default=768, help='Hidden size') + parser.add_argument('--patch_size', type=int, default=16, help='Patch size') + parser.add_argument('--image_size', type=int, default=224, help='Image size') + parser.add_argument('--epochs', type=int, default=50, help='Number of epochs') + parser.add_argument('--amp_enabled', action='store_true', help='Enable AMP if device supports it') + parser.add_argument('--b_lr', type=float, default=2e-4, help='Backbone learning rate') + parser.add_argument('--h_lr', type=float, default=2e-4, help='Head learning rate') + parser.add_argument('--wd', type=float, default=0., help='Weight decay') + parser.add_argument('--train_bs', type=int, default=128, help='Training batch size') + parser.add_argument('--eval_bs', type=int, default=256, help='Eval batch size') + parser.add_argument('--num_workers', type=int, default=4, help='Number of workers') + parser.add_argument('--num_heads', type=int, default=16, help='Number of attention heads') + parser.add_argument('--eval_epoch', type=int, default=1, help='Eval frequency') + parser.add_argument('--log_step', type=int, default=10, help='Log frequency') + parser.add_argument('--seed', type=int, default=42, help='Random seed') + parser.add_argument('--wandb', action='store_true', help='Enable wandb logging') + parser.add_argument('--expand_k', type=float, default=1.0, help='Key expansion ratio') + parser.add_argument('--expand_v', type=float, default=1.0, help='Value expansion ratio') + parser.add_argument('--attn_mode', type=str, default='chunk', choices=['chunk', 'fused_recurrent', 'fused_chunk']) + parser.add_argument('--pool_type', type=str, default='mean', choices=['mean', 'cls']) + parser.add_argument('--model', 
type=str, required=True, help='Model type (currently only supports "deltanet")') + parser.add_argument('--fuse_cross_entropy', action='store_true', help='Fuse cross entropy with logits') + + # Learning rate schedule related arguments + parser.add_argument('--lr_scheduler_type', type=str, default='constant_with_warmup', + choices=['linear', 'cosine', 'cosine_with_restarts', 'polynomial', + 'constant', 'constant_with_warmup']) + parser.add_argument('--warmup_ratio', type=float, default=0.1, + help='Ratio of total training steps for warmup') + # Add hybrid attention related arguments + parser.add_argument('--use_attn', action='store_true', help='Use hybrid attention in some layers') + parser.add_argument('--attn_layers', type=str, default='0,1', + help='Comma separated list of layer indices to use attention, e.g. "0,1,2"') + parser.add_argument('--attn_num_heads', type=int, default=16, + help='Number of attention heads for hybrid attention layers') + parser.add_argument('--attn_num_kv_heads', type=int, default=None, + help='Number of key/value heads for hybrid attention layers') + parser.add_argument('--attn_window_size', type=int, default=None, + help='Window size for hybrid attention layers') + parser.add_argument('--log_memory_epoch', type=int, default=100, help='Log memory usage frequency') + return parser.parse_args() + +def get_data(args): + """ + Prepare data transforms and loaders. + Ensures consistent data types with model. + """ + transform = transforms.Compose([ + transforms.Resize((args.image_size, args.image_size)), + transforms.ToTensor(), + transforms.ConvertImageDtype(dtype), # Match model dtype + ]) + + if args.dataset == 'cifar10': + train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) + test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) + num_classes = 10 + elif args.dataset == 'cifar100': + train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform) + test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform) + num_classes = 100 + else: + raise ValueError(f"Unsupported dataset: {args.dataset}") + + train_loader = DataLoader(train_dataset, batch_size=args.train_bs, shuffle=True, num_workers=args.num_workers) + test_loader = DataLoader(test_dataset, batch_size=args.eval_bs, shuffle=False, num_workers=args.num_workers) + + return train_loader, test_loader, num_classes + +def setup_deterministic_mode(args): + """Setup deterministic mode for reproducibility""" + import numpy as np + np.random.seed(args.seed) + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +def get_gpu_memory_info(): + """ + Get current GPU memory usage information + Returns a dictionary with: + - memory_allocated: actively allocated memory + - memory_reserved: reserved memory in GPU + - max_memory_allocated: max allocated memory since the beginning + """ + return { + 'memory_allocated': torch.cuda.memory_allocated() / 1024**2, # MB + 'memory_reserved': torch.cuda.memory_reserved() / 1024**2, # MB + 'max_memory_allocated': torch.cuda.max_memory_allocated() / 1024**2 # MB + } + +def log_gpu_memory(args, epoch): + """Log GPU memory usage if CUDA is available""" + if torch.cuda.is_available() and epoch % args.log_memory_epoch == 0: + memory_info = get_gpu_memory_info() + logging.info( + f"GPU Memory Usage (Epoch {epoch}) - " + 
f"Allocated: {memory_info['memory_allocated']:.2f}MB, " + f"Reserved: {memory_info['memory_reserved']:.2f}MB, " + f"Peak: {memory_info['max_memory_allocated']:.2f}MB" + ) + if args.wandb: + wandb.log({ + "gpu_memory/allocated": memory_info['memory_allocated'], + "gpu_memory/reserved": memory_info['memory_reserved'], + "gpu_memory/peak": memory_info['max_memory_allocated'], + "epoch": epoch + }) + +def evaluate(model, test_loader, device, args): + """ + Evaluation loop with proper CUDA timing. + Uses CUDA events for accurate GPU timing and ensures proper synchronization. + """ + model.eval() + correct = 0 + total = 0 + + # Create CUDA events for timing + if torch.cuda.is_available(): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_event.record() + else: + start_time = time.perf_counter() + + with torch.no_grad(): + for images, targets in tqdm(test_loader): + images = images.to(device=device, dtype=dtype) + targets = targets.to(device) + + if args.amp_enabled: + with autocast(): + outputs = model(images).logits + _, predicted = outputs.max(1) + else: + outputs = model(images).logits + _, predicted = outputs.max(1) + + total += targets.size(0) + correct += predicted.eq(targets).sum().item() + + # Measure time with proper CUDA synchronization + if torch.cuda.is_available(): + end_event.record() + torch.cuda.synchronize() + eval_time = start_event.elapsed_time(end_event) / 1000.0 # Convert to seconds + else: + eval_time = time.perf_counter() - start_time + + accuracy = 100. * correct / total + return accuracy, eval_time + +def get_model(args, num_classes): + """ + Initialize model based on configuration. + Supports both pure DeltaNet and hybrid models. + """ + if args.model == 'deltanet': + # Prepare attention config for hybrid model if enabled + attn_config = None + if args.use_attn: + attn_config = { + 'layers': [int(i) for i in args.attn_layers.split(',')], + 'num_heads': args.attn_num_heads, + 'num_kv_heads': args.attn_num_kv_heads, + 'window_size': args.attn_window_size + } + # Log hybrid attention configuration + logging.info("Hybrid Attention Configuration:") + logging.info(f"- Attention Layers: {attn_config['layers']}") + logging.info(f"- Number of Heads: {attn_config['num_heads']}") + logging.info(f"- Number of KV Heads: {attn_config['num_kv_heads']}") + logging.info(f"- Window Size: {attn_config['window_size']}") + + config = DeltaNetVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + expand_k=args.expand_k, + expand_v=args.expand_v, + attn_mode=args.attn_mode, + pool_type=args.pool_type, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config # Add attention config for hybrid model + ) + return DeltaNetForImageClassification(config).to(device=device, dtype=dtype) + else: + raise NotImplementedError(f"Model {args.model} is not implemented yet.") + +def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, args, epoch): + """ + Training loop for one epoch with proper CUDA timing. + Uses CUDA events for accurate GPU timing and ensures proper synchronization. 
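+
+    The timing pattern used below, shown in isolation (a sketch; note that
+    Event.elapsed_time() returns milliseconds, hence the division by 1000):
+
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        torch.cuda.synchronize()
+        start_event.record()
+        ...                      # GPU work (the training loop)
+        end_event.record()
+        torch.cuda.synchronize()
+        seconds = start_event.elapsed_time(end_event) / 1000.0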
+ """ + model.train() + total_loss = 0 + scaler = GradScaler() if args.amp_enabled else None + + # Create CUDA events for timing + if torch.cuda.is_available(): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_event.record() + else: + start_time = time.perf_counter() + + for i, (images, targets) in enumerate(tqdm(train_loader)): + images = images.to(device=device, dtype=dtype) + targets = targets.to(device) + + if args.amp_enabled: + with autocast(): + outputs = model(images).logits + loss = criterion(outputs, targets) + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + outputs = model(images).logits + loss = criterion(outputs, targets) + loss.backward() + optimizer.step() + + optimizer.zero_grad() + scheduler.step() # Update learning rate scheduler + total_loss += loss.item() + + if i % args.log_step == 0: + lrs = [group['lr'] for group in optimizer.param_groups] + logging.info(f'Epoch {epoch} Step {i}/{len(train_loader)}: ' + f'Loss={loss.item():.4f} ' + f'LR_backbone={lrs[0]:.2e} ' + f'LR_head={lrs[-1]:.2e}') + + if args.wandb: + wandb.log({ + "batch_loss": loss.item(), + "learning_rate/backbone": lrs[0], + "learning_rate/head": lrs[-1], + "global_step": epoch * len(train_loader) + i + }) + + # Measure time with proper CUDA synchronization + if torch.cuda.is_available(): + end_event.record() + torch.cuda.synchronize() + train_time = start_event.elapsed_time(end_event) / 1000.0 + else: + train_time = time.perf_counter() - start_time + + avg_loss = total_loss / len(train_loader) + return avg_loss, train_time + +def main(): + args = get_args() + + # Setup logging first, before any logging calls + setup_logging(args) + + # Then setup deterministic mode + setup_deterministic_mode(args) + + # Log all configuration parameters + logging.info("=" * 50) + logging.info("Training Configuration:") + logging.info("-" * 50) + for arg, value in sorted(vars(args).items()): + logging.info(f"{arg}: {value}") + logging.info("=" * 50) + + # Setup wandb after logging is initialized + if args.wandb: + project_name = f"{args.model}_vision_classification" + run_name = f"e{args.epochs}_b_lr{args.b_lr}_h_lr_{args.h_lr}_mode{args.attn_mode}_bs{args.train_bs}_p{args.patch_size}_i{args.image_size}_h{args.num_heads}_{args.dataset}" + wandb.init( + project=project_name, + name=run_name, + config=args.__dict__ + ) + logging.info(f"Wandb initialized with project: {project_name}, run: {run_name}") + + train_loader, test_loader, num_classes = get_data(args) + + # Calculate total training steps + num_training_steps = len(train_loader) * args.epochs + num_warmup_steps = int(args.warmup_ratio * num_training_steps) + + model = get_model(args, num_classes) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logging.info("=" * 50) + logging.info("Model Information:") + logging.info("-" * 50) + logging.info(f"Model Type: {args.model}") + logging.info(f"Number of trainable parameters: {trainable_params:,}") + logging.info(f"Number of layers: {args.num_hidden_layers}") + logging.info(f"Hidden size: {args.hidden_size}") + logging.info(f"Number of heads: {args.num_heads}") + logging.info(f"Learning rate scheduler: {args.lr_scheduler_type}") + logging.info(f"Total training steps: {num_training_steps}") + logging.info(f"Warmup steps: {num_warmup_steps}") + logging.info("=" * 50) + + if args.wandb: + wandb.log({"trainable_parameters": trainable_params}) + + criterion = 
torch.nn.CrossEntropyLoss() + optimizer = optim.AdamW([ + {'params': model.embeddings.parameters(), 'lr': args.b_lr}, + {'params': model.blocks.parameters(), 'lr': args.b_lr}, + {'params': model.classifier.parameters(), 'lr': args.h_lr} + ], weight_decay=args.wd) + + scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps + ) + + best_acc = 0 + total_train_time = 0 + total_eval_time = 0 + eval_num = 0 + + for epoch in range(args.epochs): + avg_loss, epoch_train_time = train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, args, epoch) + total_train_time += epoch_train_time + + # Log GPU memory usage + log_gpu_memory(args, epoch) + + if epoch % args.eval_epoch == 0: + accuracy, epoch_eval_time = evaluate(model, test_loader, device, args) + total_eval_time += epoch_eval_time + eval_num += 1 + + logging.info( + f'Epoch {epoch}: Loss={avg_loss:.4f}, Accuracy={accuracy:.2f}%, ' + f'Train time={epoch_train_time:.2f}s, Eval time={epoch_eval_time:.2f}s' + ) + + if args.wandb: + wandb.log({ + "epoch": epoch, + "train_loss": avg_loss, + "accuracy": accuracy, + "train_time": epoch_train_time, + "eval_time": epoch_eval_time, + "avg_epoch_train_time": total_train_time / (epoch + 1), + "avg_epoch_eval_time": total_eval_time / eval_num + }) + + if accuracy > best_acc: + best_acc = accuracy + torch.save(model.state_dict(), f'{args.model}_vision_best.pth') + + # Log final statistics + avg_train_time = total_train_time / args.epochs + avg_eval_time = total_eval_time / eval_num + logging.info( + f'Training completed. Best accuracy: {best_acc:.2f}%\n' + f'Average training time per epoch: {avg_train_time:.2f}s\n' + f'Average evaluation time: {avg_eval_time:.2f}s' + ) + + if args.wandb: + wandb.log({ + "final/best_accuracy": best_acc, + "final/avg_train_time": avg_train_time, + "final/avg_eval_time": avg_eval_time + }) + if args.wandb: + wandb.finish() + +if __name__ == '__main__': + main() From faf49d7eefc3f3fcdd9d8ea1d187b952fa2cc509 Mon Sep 17 00:00:00 2001 From: yibozhong Date: Tue, 14 Jan 2025 01:58:33 +0800 Subject: [PATCH 02/17] change position of vision trainig code --- classification.py | 419 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 419 insertions(+) create mode 100644 classification.py diff --git a/classification.py b/classification.py new file mode 100644 index 000000000..88c160149 --- /dev/null +++ b/classification.py @@ -0,0 +1,419 @@ +import os +import torch +from tqdm import tqdm +import wandb +import logging +import random +import torch.optim as optim +from torch.utils.data import DataLoader +from torchvision import datasets, transforms +from transformers import get_scheduler +from torch.amp import GradScaler, autocast +from fla.vision_models.delta_net import DeltaNetVisionConfig, DeltaNetForImageClassification +import time + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +dtype = torch.bfloat16 # deafult dtype for FLA + +def setup_logging(args): + log_filename = f'training_{args.model}_vision_{args.dataset}.log' + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler(log_filename), + logging.StreamHandler() + ] + ) + logging.info(f"Logging to {log_filename}") + +def get_args(): + import argparse + parser = argparse.ArgumentParser(description='Vision Model Training') + parser.add_argument('--dataset', type=str, default='cifar100', help='Dataset name') 
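+    # Example invocation (illustrative; only --model is required, all other flags have defaults):
+    #   python classification.py --model deltanet --dataset cifar10 --epochs 50 --train_bs 128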
+ parser.add_argument('--num_hidden_layers', type=int, default=12, help='Number of hidden layers') + parser.add_argument('--hidden_size', type=int, default=768, help='Hidden size') + parser.add_argument('--patch_size', type=int, default=16, help='Patch size') + parser.add_argument('--image_size', type=int, default=224, help='Image size') + parser.add_argument('--epochs', type=int, default=50, help='Number of epochs') + parser.add_argument('--amp_enabled', action='store_true', help='Enable AMP if device supports it') + parser.add_argument('--b_lr', type=float, default=2e-4, help='Backbone learning rate') + parser.add_argument('--h_lr', type=float, default=2e-4, help='Head learning rate') + parser.add_argument('--wd', type=float, default=0., help='Weight decay') + parser.add_argument('--train_bs', type=int, default=128, help='Training batch size') + parser.add_argument('--eval_bs', type=int, default=256, help='Eval batch size') + parser.add_argument('--num_workers', type=int, default=4, help='Number of workers') + parser.add_argument('--num_heads', type=int, default=16, help='Number of attention heads') + parser.add_argument('--eval_epoch', type=int, default=1, help='Eval frequency') + parser.add_argument('--log_step', type=int, default=10, help='Log frequency') + parser.add_argument('--seed', type=int, default=42, help='Random seed') + parser.add_argument('--wandb', action='store_true', help='Enable wandb logging') + parser.add_argument('--expand_k', type=float, default=1.0, help='Key expansion ratio') + parser.add_argument('--expand_v', type=float, default=1.0, help='Value expansion ratio') + parser.add_argument('--attn_mode', type=str, default='chunk', choices=['chunk', 'fused_recurrent', 'fused_chunk']) + parser.add_argument('--pool_type', type=str, default='mean', choices=['mean', 'cls']) + parser.add_argument('--model', type=str, required=True, help='Model type (currently only supports "deltanet")') + parser.add_argument('--fuse_cross_entropy', action='store_true', help='Fuse cross entropy with logits') + + # Learning rate schedule related arguments + parser.add_argument('--lr_scheduler_type', type=str, default='constant_with_warmup', + choices=['linear', 'cosine', 'cosine_with_restarts', 'polynomial', + 'constant', 'constant_with_warmup']) + parser.add_argument('--warmup_ratio', type=float, default=0.1, + help='Ratio of total training steps for warmup') + # Add hybrid attention related arguments + parser.add_argument('--use_attn', action='store_true', help='Use hybrid attention in some layers') + parser.add_argument('--attn_layers', type=str, default='0,1', + help='Comma separated list of layer indices to use attention, e.g. "0,1,2"') + parser.add_argument('--attn_num_heads', type=int, default=16, + help='Number of attention heads for hybrid attention layers') + parser.add_argument('--attn_num_kv_heads', type=int, default=None, + help='Number of key/value heads for hybrid attention layers') + parser.add_argument('--attn_window_size', type=int, default=None, + help='Window size for hybrid attention layers') + parser.add_argument('--log_memory_epoch', type=int, default=100, help='Log memory usage frequency') + return parser.parse_args() + +def get_data(args): + """ + Prepare data transforms and loaders. + Ensures consistent data types with model. 
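+
+    A sketch of what the loaders yield (shapes assume image_size=224, train_bs=128,
+    and dataset='cifar10'; these are defaults or illustrative choices):
+
+        train_loader, test_loader, num_classes = get_data(args)
+        images, targets = next(iter(train_loader))
+        # images:  (128, 3, 224, 224), converted to the module-level `dtype` (bfloat16)
+        # targets: (128,) integer class labels; num_classes is 10 for cifar10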
+ """ + transform = transforms.Compose([ + transforms.Resize((args.image_size, args.image_size)), + transforms.ToTensor(), + transforms.ConvertImageDtype(dtype), # Match model dtype + ]) + + if args.dataset == 'cifar10': + train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) + test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) + num_classes = 10 + elif args.dataset == 'cifar100': + train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform) + test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform) + num_classes = 100 + else: + raise ValueError(f"Unsupported dataset: {args.dataset}") + + train_loader = DataLoader(train_dataset, batch_size=args.train_bs, shuffle=True, num_workers=args.num_workers) + test_loader = DataLoader(test_dataset, batch_size=args.eval_bs, shuffle=False, num_workers=args.num_workers) + + return train_loader, test_loader, num_classes + +def setup_deterministic_mode(args): + """Setup deterministic mode for reproducibility""" + import numpy as np + np.random.seed(args.seed) + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +def get_gpu_memory_info(): + """ + Get current GPU memory usage information + Returns a dictionary with: + - memory_allocated: actively allocated memory + - memory_reserved: reserved memory in GPU + - max_memory_allocated: max allocated memory since the beginning + """ + return { + 'memory_allocated': torch.cuda.memory_allocated() / 1024**2, # MB + 'memory_reserved': torch.cuda.memory_reserved() / 1024**2, # MB + 'max_memory_allocated': torch.cuda.max_memory_allocated() / 1024**2 # MB + } + +def log_gpu_memory(args, epoch): + """Log GPU memory usage if CUDA is available""" + if torch.cuda.is_available() and epoch % args.log_memory_epoch == 0: + memory_info = get_gpu_memory_info() + logging.info( + f"GPU Memory Usage (Epoch {epoch}) - " + f"Allocated: {memory_info['memory_allocated']:.2f}MB, " + f"Reserved: {memory_info['memory_reserved']:.2f}MB, " + f"Peak: {memory_info['max_memory_allocated']:.2f}MB" + ) + if args.wandb: + wandb.log({ + "gpu_memory/allocated": memory_info['memory_allocated'], + "gpu_memory/reserved": memory_info['memory_reserved'], + "gpu_memory/peak": memory_info['max_memory_allocated'], + "epoch": epoch + }) + +def evaluate(model, test_loader, device, args): + """ + Evaluation loop with proper CUDA timing. + Uses CUDA events for accurate GPU timing and ensures proper synchronization. 
+ """ + model.eval() + correct = 0 + total = 0 + + # Create CUDA events for timing + if torch.cuda.is_available(): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_event.record() + else: + start_time = time.perf_counter() + + with torch.no_grad(): + for images, targets in tqdm(test_loader): + images = images.to(device=device, dtype=dtype) + targets = targets.to(device) + + if args.amp_enabled: + with autocast(): + outputs = model(images).logits + _, predicted = outputs.max(1) + else: + outputs = model(images).logits + _, predicted = outputs.max(1) + + total += targets.size(0) + correct += predicted.eq(targets).sum().item() + + # Measure time with proper CUDA synchronization + if torch.cuda.is_available(): + end_event.record() + torch.cuda.synchronize() + eval_time = start_event.elapsed_time(end_event) / 1000.0 # Convert to seconds + else: + eval_time = time.perf_counter() - start_time + + accuracy = 100. * correct / total + return accuracy, eval_time + +def get_model(args, num_classes): + """ + Initialize model based on configuration. + Supports both pure DeltaNet and hybrid models. + """ + if args.model == 'deltanet': + # Prepare attention config for hybrid model if enabled + attn_config = None + if args.use_attn: + attn_config = { + 'layers': [int(i) for i in args.attn_layers.split(',')], + 'num_heads': args.attn_num_heads, + 'num_kv_heads': args.attn_num_kv_heads, + 'window_size': args.attn_window_size + } + # Log hybrid attention configuration + logging.info("Hybrid Attention Configuration:") + logging.info(f"- Attention Layers: {attn_config['layers']}") + logging.info(f"- Number of Heads: {attn_config['num_heads']}") + logging.info(f"- Number of KV Heads: {attn_config['num_kv_heads']}") + logging.info(f"- Window Size: {attn_config['window_size']}") + + config = DeltaNetVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + expand_k=args.expand_k, + expand_v=args.expand_v, + attn_mode=args.attn_mode, + pool_type=args.pool_type, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config # Add attention config for hybrid model + ) + return DeltaNetForImageClassification(config).to(device=device, dtype=dtype) + else: + raise NotImplementedError(f"Model {args.model} is not implemented yet.") + +def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, args, epoch): + """ + Training loop for one epoch with proper CUDA timing. + Uses CUDA events for accurate GPU timing and ensures proper synchronization. 
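+
+    The AMP branch below follows the usual GradScaler recipe (sketch only; note that
+    torch.amp.autocast normally expects a device_type argument such as "cuda"):
+
+        scaler = GradScaler()
+        with autocast("cuda"):
+            loss = criterion(model(images).logits, targets)
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+        optimizer.zero_grad()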
+ """ + model.train() + total_loss = 0 + scaler = GradScaler() if args.amp_enabled else None + + # Create CUDA events for timing + if torch.cuda.is_available(): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_event.record() + else: + start_time = time.perf_counter() + + for i, (images, targets) in enumerate(tqdm(train_loader)): + images = images.to(device=device, dtype=dtype) + targets = targets.to(device) + + if args.amp_enabled: + with autocast(): + outputs = model(images).logits + loss = criterion(outputs, targets) + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + outputs = model(images).logits + loss = criterion(outputs, targets) + loss.backward() + optimizer.step() + + optimizer.zero_grad() + scheduler.step() # Update learning rate scheduler + total_loss += loss.item() + + if i % args.log_step == 0: + lrs = [group['lr'] for group in optimizer.param_groups] + logging.info(f'Epoch {epoch} Step {i}/{len(train_loader)}: ' + f'Loss={loss.item():.4f} ' + f'LR_backbone={lrs[0]:.2e} ' + f'LR_head={lrs[-1]:.2e}') + + if args.wandb: + wandb.log({ + "batch_loss": loss.item(), + "learning_rate/backbone": lrs[0], + "learning_rate/head": lrs[-1], + "global_step": epoch * len(train_loader) + i + }) + + # Measure time with proper CUDA synchronization + if torch.cuda.is_available(): + end_event.record() + torch.cuda.synchronize() + train_time = start_event.elapsed_time(end_event) / 1000.0 + else: + train_time = time.perf_counter() - start_time + + avg_loss = total_loss / len(train_loader) + return avg_loss, train_time + +def main(): + args = get_args() + + # Setup logging first, before any logging calls + setup_logging(args) + + # Then setup deterministic mode + setup_deterministic_mode(args) + + # Log all configuration parameters + logging.info("=" * 50) + logging.info("Training Configuration:") + logging.info("-" * 50) + for arg, value in sorted(vars(args).items()): + logging.info(f"{arg}: {value}") + logging.info("=" * 50) + + # Setup wandb after logging is initialized + if args.wandb: + project_name = f"{args.model}_vision_classification" + run_name = f"e{args.epochs}_b_lr{args.b_lr}_h_lr_{args.h_lr}_mode{args.attn_mode}_bs{args.train_bs}_p{args.patch_size}_i{args.image_size}_h{args.num_heads}_{args.dataset}" + wandb.init( + project=project_name, + name=run_name, + config=args.__dict__ + ) + logging.info(f"Wandb initialized with project: {project_name}, run: {run_name}") + + train_loader, test_loader, num_classes = get_data(args) + + # Calculate total training steps + num_training_steps = len(train_loader) * args.epochs + num_warmup_steps = int(args.warmup_ratio * num_training_steps) + + model = get_model(args, num_classes) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logging.info("=" * 50) + logging.info("Model Information:") + logging.info("-" * 50) + logging.info(f"Model Type: {args.model}") + logging.info(f"Number of trainable parameters: {trainable_params:,}") + logging.info(f"Number of layers: {args.num_hidden_layers}") + logging.info(f"Hidden size: {args.hidden_size}") + logging.info(f"Number of heads: {args.num_heads}") + logging.info(f"Learning rate scheduler: {args.lr_scheduler_type}") + logging.info(f"Total training steps: {num_training_steps}") + logging.info(f"Warmup steps: {num_warmup_steps}") + logging.info("=" * 50) + + if args.wandb: + wandb.log({"trainable_parameters": trainable_params}) + + criterion = 
torch.nn.CrossEntropyLoss() + optimizer = optim.AdamW([ + {'params': model.embeddings.parameters(), 'lr': args.b_lr}, + {'params': model.blocks.parameters(), 'lr': args.b_lr}, + {'params': model.classifier.parameters(), 'lr': args.h_lr} + ], weight_decay=args.wd) + + scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps + ) + + best_acc = 0 + total_train_time = 0 + total_eval_time = 0 + eval_num = 0 + + for epoch in range(args.epochs): + avg_loss, epoch_train_time = train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, args, epoch) + total_train_time += epoch_train_time + + # Log GPU memory usage + log_gpu_memory(args, epoch) + + if epoch % args.eval_epoch == 0: + accuracy, epoch_eval_time = evaluate(model, test_loader, device, args) + total_eval_time += epoch_eval_time + eval_num += 1 + + logging.info( + f'Epoch {epoch}: Loss={avg_loss:.4f}, Accuracy={accuracy:.2f}%, ' + f'Train time={epoch_train_time:.2f}s, Eval time={epoch_eval_time:.2f}s' + ) + + if args.wandb: + wandb.log({ + "epoch": epoch, + "train_loss": avg_loss, + "accuracy": accuracy, + "train_time": epoch_train_time, + "eval_time": epoch_eval_time, + "avg_epoch_train_time": total_train_time / (epoch + 1), + "avg_epoch_eval_time": total_eval_time / eval_num + }) + + if accuracy > best_acc: + best_acc = accuracy + torch.save(model.state_dict(), f'{args.model}_vision_best.pth') + + # Log final statistics + avg_train_time = total_train_time / args.epochs + avg_eval_time = total_eval_time / eval_num + logging.info( + f'Training completed. Best accuracy: {best_acc:.2f}%\n' + f'Average training time per epoch: {avg_train_time:.2f}s\n' + f'Average evaluation time: {avg_eval_time:.2f}s' + ) + + if args.wandb: + wandb.log({ + "final/best_accuracy": best_acc, + "final/avg_train_time": avg_train_time, + "final/avg_eval_time": avg_eval_time + }) + if args.wandb: + wandb.finish() + +if __name__ == '__main__': + main() From f676c81d0a9edcb5420a47399af93521e8862abd Mon Sep 17 00:00:00 2001 From: yibozhong Date: Tue, 14 Jan 2025 03:35:45 +0800 Subject: [PATCH 03/17] Update new vision delta net --- fla/vision_models/delta_net/modeling_delta_net.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fla/vision_models/delta_net/modeling_delta_net.py b/fla/vision_models/delta_net/modeling_delta_net.py index c7d638dca..635c88a41 100644 --- a/fla/vision_models/delta_net/modeling_delta_net.py +++ b/fla/vision_models/delta_net/modeling_delta_net.py @@ -84,6 +84,8 @@ def forward( if hasattr(self, 'ln_1'): hidden_states = self.ln_1(hidden_states) + print(hidden_states.shape) + # Apply attention hidden_states, attentions, past_key_values = self.attn( hidden_states=hidden_states, From 5bc94cec3ccede7a6641ac8a071c7283092426e6 Mon Sep 17 00:00:00 2001 From: yibozhong Date: Tue, 14 Jan 2025 03:37:16 +0800 Subject: [PATCH 04/17] Update new vision delta net --- fla/vision_models/delta_net/modeling_delta_net.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fla/vision_models/delta_net/modeling_delta_net.py b/fla/vision_models/delta_net/modeling_delta_net.py index 635c88a41..c7d638dca 100644 --- a/fla/vision_models/delta_net/modeling_delta_net.py +++ b/fla/vision_models/delta_net/modeling_delta_net.py @@ -84,8 +84,6 @@ def forward( if hasattr(self, 'ln_1'): hidden_states = self.ln_1(hidden_states) - print(hidden_states.shape) - # Apply attention hidden_states, attentions, past_key_values = self.attn( hidden_states=hidden_states, From 
d470814d42b5faa1d4cf149410886f32bb68ebb9 Mon Sep 17 00:00:00 2001 From: yibozhong Date: Wed, 15 Jan 2025 00:10:57 +0800 Subject: [PATCH 05/17] Add support for multiple scanning method --- .../delta_net/configuration_delta_net.py | 9 +- .../delta_net/modeling_delta_net.py | 20 +- fla/vision_models/utils.py | 345 ++++++++++++++++-- 3 files changed, 334 insertions(+), 40 deletions(-) diff --git a/fla/vision_models/delta_net/configuration_delta_net.py b/fla/vision_models/delta_net/configuration_delta_net.py index 40fec3ab9..9fa6b74e5 100644 --- a/fla/vision_models/delta_net/configuration_delta_net.py +++ b/fla/vision_models/delta_net/configuration_delta_net.py @@ -39,8 +39,10 @@ def __init__( hidden_dropout_prob: float = 0.0, use_mask_token: bool = False, layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, mlp_dim: int = None, - pool_type: str = "mean", # use "mean" by default + # FLA-for-vision-related parameters + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" **kwargs ): # Initialize DeltaNet core parameters @@ -77,7 +79,10 @@ def __init__( self.hidden_dropout_prob = hidden_dropout_prob self.use_mask_token = use_mask_token self.layer_norm_eps = layer_norm_eps - self.pool_type = pool_type + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + + if mlp_dim is None: self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size else: diff --git a/fla/vision_models/delta_net/modeling_delta_net.py b/fla/vision_models/delta_net/modeling_delta_net.py index c7d638dca..7dd7026f8 100644 --- a/fla/vision_models/delta_net/modeling_delta_net.py +++ b/fla/vision_models/delta_net/modeling_delta_net.py @@ -12,7 +12,7 @@ from .configuration_delta_net import DeltaNetVisionConfig from fla.layers.delta_net import DeltaNet from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge logger = logging.get_logger(__name__) @@ -70,6 +70,8 @@ def __init__(self, config, layer_idx: int): self.mlp = DeltaNetMLP(config) + self.scan_type = config.scan_type + def forward( self, hidden_states: torch.Tensor, @@ -85,6 +87,9 @@ def forward( hidden_states = self.ln_1(hidden_states) # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + hidden_states, attentions, past_key_values = self.attn( hidden_states=hidden_states, past_key_values=past_key_values, @@ -93,6 +98,8 @@ def forward( **kwargs ) + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + # First residual connection hidden_states = residual + hidden_states residual = hidden_states @@ -133,12 +140,6 @@ def _init_weights(self, module): std=self.config.initializer_range, ).to(module.position_embeddings.dtype) - module.cls_token.data = nn.init.trunc_normal_( - module.cls_token.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.cls_token.dtype) - class DeltaNetForImageClassification(DeltaNetVisionPreTrainedModel): config_class = DeltaNetVisionConfig @@ -154,7 +155,7 @@ def __init__(self, config): self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.pooler = Pooler(config) self.classifier = nn.Linear(config.hidden_size, config.num_classes) - + self.interpolate_pos_encoding = config.interpolate_pos_encoding self.init_weights() def forward( @@ -166,12 +167,11 @@ def 
forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, return_dict: Optional[bool] = None, - interpolate_pos_encoding: Optional[bool] = None, **kwargs: Unpack[Dict] ) -> Union[Tuple, ImageClassifierOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) for block in self.blocks: hidden_states, attentions, past_key_values = block( diff --git a/fla/vision_models/utils.py b/fla/vision_models/utils.py index 321f6938c..6ba296022 100644 --- a/fla/vision_models/utils.py +++ b/fla/vision_models/utils.py @@ -7,6 +7,14 @@ from torch import nn from typing import Optional from transformers.utils import torch_int +import triton +import triton.language as tl +import einops +import math + +""" +Basic component of a vision model, like the patch embeddings, image embeddings, and pooler. +""" class PatchEmbeddings(nn.Module): """ @@ -46,17 +54,16 @@ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = F class ImageEmbeddings(nn.Module): """ - Construct the CLS token, position and patch embeddings. - Adapted from huggingface/transformers ViT implementation. + Construct the position and patch embeddings. + Adapted from huggingface/transformers ViT implementation. No cls token is used in this implementation. """ def __init__(self, config, use_mask_token: bool = False) -> None: super().__init__() - self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None self.patch_embeddings = PatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches - self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size)) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.patch_size = config.patch_size self.config = config @@ -71,35 +78,32 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ - num_patches = embeddings.shape[1] - 1 - num_positions = self.position_embeddings.shape[1] - 1 + num_patches = embeddings.shape[1] + num_positions = self.position_embeddings.shape[1] - # always interpolate when tracing to ensure the exported model works for dynamic input shapes if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embeddings - - class_pos_embed = self.position_embeddings[:, :1] - patch_pos_embed = self.position_embeddings[:, 1:] - + dim = embeddings.shape[-1] - new_height = height // self.patch_size + new_height = height // self.patch_size new_width = width // self.patch_size sqrt_num_positions = torch_int(num_positions**0.5) - patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) - patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + pos_embed = self.position_embeddings.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + + pos_embed = pos_embed.permute(0, 3, 1, 2) - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed, + pos_embed = nn.functional.interpolate( + 
pos_embed, size=(new_height, new_width), mode="bicubic", align_corners=False, ) - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + return pos_embed def forward( self, @@ -117,10 +121,6 @@ def forward( mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) embeddings = embeddings * (1.0 - mask) + mask_tokens * mask - # add the [CLS] token to the embedded patch tokens - cls_tokens = self.cls_token.expand(batch_size, -1, -1) - embeddings = torch.cat((cls_tokens, embeddings), dim=1) - # add positional encoding to each token if interpolate_pos_encoding: embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) @@ -133,20 +133,309 @@ def forward( class Pooler(nn.Module): """ - Pool the output of a vision model by taking either the CLS token or mean of all tokens. + Pool the output of a vision model by taking the mean of all tokens. Adapted from huggingface/transformers ViT implementation. """ def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() - self.pool_type = config.pool_type def forward(self, hidden_states): - if self.pool_type == 'cls': - pooled_output = hidden_states[:, 0] - else: # 'mean' - pooled_output = hidden_states.mean(dim=1) + pooled_output = hidden_states.mean(dim=1) # always use mean pooling pooled_output = self.dense(pooled_output) pooled_output = self.activation(pooled_output) return pooled_output + +""" +Cross Scan and Cross Merge implemented in Triton (only). taken from https://github.com/MzeroMiko/VMamba/blob/main/classification/models/csm_triton.py +""" + +@triton.jit +def triton_cross_scan_flex( + x: tl.tensor, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) + y: tl.tensor, # (B, 4, C, H, W) | (B, H, W, 4, C) + x_layout: tl.constexpr, + y_layout: tl.constexpr, + operation: tl.constexpr, + onebyone: tl.constexpr, + scans: tl.constexpr, + BC: tl.constexpr, + BH: tl.constexpr, + BW: tl.constexpr, + DC: tl.constexpr, + DH: tl.constexpr, + DW: tl.constexpr, + NH: tl.constexpr, + NW: tl.constexpr, +): + # x_layout = 0 + # y_layout = 1 # 0 BCHW, 1 BHWC + # operation = 0 # 0 scan, 1 merge + # onebyone = 0 # 0 false, 1 true + # scans = 0 # 0 cross scan, 1 unidirectional, 2 bidirectional + + i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h, i_w = (i_hw // NW), (i_hw % NW) + _mask_h = (i_h * BH + tl.arange(0, BH)) < DH + _mask_w = (i_w * BW + tl.arange(0, BW)) < DW + _mask_hw = _mask_h[:, None] & _mask_w[None, :] + _for_C = min(DC - i_c * BC, BC) + + pos_h = (i_h * BH + tl.arange(0, BH)[:, None]) + pos_w = (i_w * BW + tl.arange(0, BW)[None, :]) + neg_h = (DH - i_h * BH - 1 - tl.arange(0, BH)[:, None]) + neg_w = (DW - i_w * BW - 1 - tl.arange(0, BW)[None, :]) + if scans == 0: + # none; trans; flip; trans + flip; + HWRoute0 = pos_h * DW + pos_w + HWRoute1 = pos_w * DH + pos_h # trans + HWRoute2 = neg_h * DW + neg_w # flip + HWRoute3 = neg_w * DH + neg_h # trans + flip + elif scans == 1: + # none; none; none; none; + HWRoute0 = pos_h * DW + pos_w + HWRoute1 = HWRoute0 + HWRoute2 = HWRoute0 + HWRoute3 = HWRoute0 + elif scans == 2: + # none; none; flip; flip; + HWRoute0 = pos_h * DW + pos_w + HWRoute1 = HWRoute0 + HWRoute2 = neg_h * DW + neg_w # flip + HWRoute3 = HWRoute2 + + _tmp1 = DC * DH * DW + + y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC) + 
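With the CLS token removed, the position table is a plain square patch grid and readout happens by averaging patch tokens. A minimal sketch of what the rewritten interpolate_pos_encoding and Pooler amount to (sizes are illustrative, not taken from the patch):

import torch
import torch.nn.functional as F

# Illustrative sizes: a 14x14 patch grid at pretraining, resized for an 18x18 grid.
pos = torch.randn(1, 196, 768)                                  # (1, num_positions, dim); no CLS slot
grid = int(196 ** 0.5)                                          # 14
pos_2d = pos.reshape(1, grid, grid, 768).permute(0, 3, 1, 2)    # (1, dim, 14, 14)
pos_2d = F.interpolate(pos_2d, size=(18, 18), mode="bicubic", align_corners=False)
pos_new = pos_2d.permute(0, 2, 3, 1).reshape(1, -1, 768)        # (1, 324, dim)

# Mean pooling over patch tokens replaces the CLS readout before the dense + tanh Pooler.
tokens = torch.randn(2, 324, 768)
pooled = tokens.mean(dim=1)                                     # (2, dim)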
if y_layout == 0: + p_y1 = y_ptr_base + HWRoute0 + p_y2 = y_ptr_base + _tmp1 + HWRoute1 + p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2 + p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3 + else: + p_y1 = y_ptr_base + HWRoute0 * 4 * DC + p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC + p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC + p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC + + if onebyone == 0: + x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) + if x_layout == 0: + p_x = x_ptr_base + HWRoute0 + else: + p_x = x_ptr_base + HWRoute0 * DC + + if operation == 0: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + _x = tl.load(p_x + _idx_x, mask=_mask_hw) + tl.store(p_y1 + _idx_y, _x, mask=_mask_hw) + tl.store(p_y2 + _idx_y, _x, mask=_mask_hw) + tl.store(p_y3 + _idx_y, _x, mask=_mask_hw) + tl.store(p_y4 + _idx_y, _x, mask=_mask_hw) + elif operation == 1: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw) + _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw) + _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw) + _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw) + tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw) + + else: + x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) + if x_layout == 0: + p_x1 = x_ptr_base + HWRoute0 + p_x2 = p_x1 + _tmp1 + p_x3 = p_x2 + _tmp1 + p_x4 = p_x3 + _tmp1 + else: + p_x1 = x_ptr_base + HWRoute0 * 4 * DC + p_x2 = p_x1 + DC + p_x3 = p_x2 + DC + p_x4 = p_x3 + DC + + if operation == 0: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw) + tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw) + tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw) + tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw) + else: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw) + tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw) + tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw) + tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw) + + +class CrossScanTritonF(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): + if one_by_one: + if in_channel_first: + B, _, C, H, W = x.shape + else: + B, H, W, _, C = x.shape + else: + if in_channel_first: + B, C, H, W = x.shape + else: + B, H, W, C = x.shape + B, C, H, W = int(B), int(C), int(H), int(W) + BC, BH, BW = 1, 32, 32 + NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) + + ctx.in_channel_first = in_channel_first + ctx.out_channel_first = out_channel_first + ctx.one_by_one = one_by_one + ctx.scans = scans + ctx.shape = (B, C, H, W) + ctx.triton_shape = (BC, BH, BW, NC, NH, NW) + + y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C)) + triton_cross_scan_flex[(NH * NW, NC, B)]( + x.contiguous(), y, + (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans, + BC, BH, 
BW, C, H, W, NH, NW + ) + return y + + @staticmethod + def backward(ctx, y: torch.Tensor): + in_channel_first = ctx.in_channel_first + out_channel_first = ctx.out_channel_first + one_by_one = ctx.one_by_one + scans = ctx.scans + B, C, H, W = ctx.shape + BC, BH, BW, NC, NH, NW = ctx.triton_shape + if one_by_one: + x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C)) + else: + x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C)) + + triton_cross_scan_flex[(NH * NW, NC, B)]( + x, y.contiguous(), + (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans, + BC, BH, BW, C, H, W, NH, NW + ) + return x, None, None, None, None + + +class CrossMergeTritonF(torch.autograd.Function): + @staticmethod + def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): + if out_channel_first: + B, _, C, H, W = y.shape + else: + B, H, W, _, C = y.shape + B, C, H, W = int(B), int(C), int(H), int(W) + BC, BH, BW = 1, 32, 32 + NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) + ctx.in_channel_first = in_channel_first + ctx.out_channel_first = out_channel_first + ctx.one_by_one = one_by_one + ctx.scans = scans + ctx.shape = (B, C, H, W) + ctx.triton_shape = (BC, BH, BW, NC, NH, NW) + if one_by_one: + x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C)) + else: + x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C)) + triton_cross_scan_flex[(NH * NW, NC, B)]( + x, y.contiguous(), + (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans, + BC, BH, BW, C, H, W, NH, NW + ) + return x + + @staticmethod + def backward(ctx, x: torch.Tensor): + in_channel_first = ctx.in_channel_first + out_channel_first = ctx.out_channel_first + one_by_one = ctx.one_by_one + scans = ctx.scans + B, C, H, W = ctx.shape + BC, BH, BW, NC, NH, NW = ctx.triton_shape + y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C)) + triton_cross_scan_flex[(NH * NW, NC, B)]( + x.contiguous(), y, + (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans, + BC, BH, BW, C, H, W, NH, NW + ) + return y, None, None, None, None, None + + +# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) +def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False): + # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) + # y: (B, 4, C, L) | (B, L, 4, C) + # scans: 0: cross scan; 1 unidirectional; 2: bidirectional; + assert x.is_cuda + CSF = CrossScanTritonF + with torch.cuda.device(x.device): + return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans) + + +# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) +def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False): + # y: (B, 4, C, L) | (B, L, 4, C) + # x: (B, C, H * W) | (B, H * W, C) | (B, 4, C, H * W) | (B, H * W, 4, C) + # scans: 0: cross scan; 1 unidirectional; 2: bidirectional; + assert y.is_cuda + CMF = CrossMergeTritonF + with torch.cuda.device(y.device): + return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans) + +def prepare_hidden_states_for_cross_scan(hidden_states: torch.Tensor, scan_type: str = "uni-scan"): + # 
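For reference, the following pure-PyTorch sketch reproduces what cross_scan_fn / cross_merge_fn compute for scans=0 with channel-last layouts (in_channel_first=False, out_channel_first=False, one_by_one=False). It is only meant to make the four HWRoute orderings in the Triton kernel concrete; it is not the fused implementation:

import torch

def cross_scan_ref(x: torch.Tensor) -> torch.Tensor:
    # x: (B, H, W, C) -> y: (B, H*W, 4, C)
    B, H, W, C = x.shape
    y1 = x.reshape(B, H * W, C)                      # row-major scan
    y2 = x.transpose(1, 2).reshape(B, H * W, C)      # column-major ("trans")
    y3 = y1.flip(1)                                  # reversed row-major ("flip")
    y4 = y2.flip(1)                                  # reversed column-major ("trans + flip")
    return torch.stack([y1, y2, y3, y4], dim=2)

def cross_merge_ref(y: torch.Tensor, H: int, W: int) -> torch.Tensor:
    # y: (B, H*W, 4, C) -> x: (B, H*W, C); undo each route, then sum the four branches
    B, L, _, C = y.shape
    x1 = y[:, :, 0]
    x2 = y[:, :, 1].reshape(B, W, H, C).transpose(1, 2).reshape(B, L, C)
    x3 = y[:, :, 2].flip(1)
    x4 = y[:, :, 3].flip(1).reshape(B, W, H, C).transpose(1, 2).reshape(B, L, C)
    return x1 + x2 + x3 + x4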
hidden_states shape should be: (B, L, D) + if scan_type == "uni-scan": + # in this case, nothing need to be done + return hidden_states + elif scan_type == "bi-scan": + flipped_hidden_states = hidden_states.flip(-2) + hidden_states = torch.cat([hidden_states, flipped_hidden_states], dim=0) # (B, L, D) -> (2B, L, D) + return hidden_states + + # apply cross scan to the sequence + B, L, D = hidden_states.shape + hw = int(math.sqrt(L)) + assert (hw * hw == L) # make sure L is a square + hidden_states = einops.rearrange(hidden_states, "b (h w) d -> b h w d", h=hw, w=hw) # change the shape to feed to cross_scan + hidden_states = cross_scan_fn(hidden_states, in_channel_first=False, out_channel_first=False, one_by_one=False, scans=0) + hidden_states = einops.rearrange(hidden_states, "b l k d -> (b k) l d") + return hidden_states + +def prepare_hidden_states_for_cross_merge(hidden_states: torch.Tensor, scan_type: str = "uni-scan"): + # hidden_states shape should be: (BK, L, D), K=2 for bi-scan, K=1 for uni-scan, K=4 for cross-scan + if scan_type == "uni-scan": + # in this case, nothing need to be done + return hidden_states + elif scan_type == "bi-scan": + # merge the two sequences + B = hidden_states.shape[0] // 2 + hidden_states = hidden_states[:B] + hidden_states[B:] + return hidden_states + + B, L, D = hidden_states.shape + hw = int(math.sqrt(L)) + hidden_states = einops.rearrange(hidden_states, "(b k) (h w) d -> b h w k d", k=4, h=hw, w=hw) + # apply cross merge to the sequence + hidden_states = cross_merge_fn(hidden_states, in_channel_first=False, out_channel_first=False, one_by_one=False, scans=0) + return hidden_states + +# check the implementation +if __name__ == "__main__": + B, L, D = 2, 16, 2048 + hidden_states = torch.randn(B, L, D).cuda() + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, scan_type="cross-scan") + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, scan_type="cross-scan") + print(hidden_states.shape) + print("Cross scan applied successfully!") \ No newline at end of file From 5ee4bfc4304397fa657393d7a474ff1ea8a16a5f Mon Sep 17 00:00:00 2001 From: yibozhong Date: Wed, 15 Jan 2025 00:11:36 +0800 Subject: [PATCH 06/17] Update training script --- classification.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/classification.py b/classification.py index 88c160149..26343fd2f 100644 --- a/classification.py +++ b/classification.py @@ -51,9 +51,9 @@ def get_args(): parser.add_argument('--expand_k', type=float, default=1.0, help='Key expansion ratio') parser.add_argument('--expand_v', type=float, default=1.0, help='Value expansion ratio') parser.add_argument('--attn_mode', type=str, default='chunk', choices=['chunk', 'fused_recurrent', 'fused_chunk']) - parser.add_argument('--pool_type', type=str, default='mean', choices=['mean', 'cls']) parser.add_argument('--model', type=str, required=True, help='Model type (currently only supports "deltanet")') parser.add_argument('--fuse_cross_entropy', action='store_true', help='Fuse cross entropy with logits') + parser.add_argument('--scan_type', type=str, default='uni-scan', choices=['uni-scan', 'bi-scan', 'cross-scan'],) # Learning rate schedule related arguments parser.add_argument('--lr_scheduler_type', type=str, default='constant_with_warmup', @@ -220,9 +220,9 @@ def get_model(args, num_classes): expand_k=args.expand_k, expand_v=args.expand_v, attn_mode=args.attn_mode, - pool_type=args.pool_type, fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config # Add 
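The prepare_* helpers only touch the batch dimension, so the token mixers stay sequence-agnostic. A small shape check using the helpers added above (illustrative sizes; the bi-scan path is pure PyTorch, cross-scan routes through the Triton kernels and needs a CUDA device):

import torch
from fla.vision_models.utils import (
    prepare_hidden_states_for_cross_scan,
    prepare_hidden_states_for_cross_merge,
)

h = torch.randn(2, 64, 192)                                    # (B, L, D), L = 8 * 8
bi = prepare_hidden_states_for_cross_scan(h, "bi-scan")        # (2B, L, D) = (4, 64, 192)
out = prepare_hidden_states_for_cross_merge(bi, "bi-scan")     # back to (2, 64, 192)
# "cross-scan" expands to (4B, L, D) and merges back to (B, L, D) the same way.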
attention config for hybrid model + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy ) return DeltaNetForImageClassification(config).to(device=device, dtype=dtype) else: From c308ac05252d7fdd85ec1eac6b72bc113ccc2b05 Mon Sep 17 00:00:00 2001 From: yibozhong Date: Fri, 17 Jan 2025 02:13:12 +0800 Subject: [PATCH 07/17] Add all fla-based vision models except mamba, mamba2 and samba --- fla/vision_models/abc/__init__.py | 12 + fla/vision_models/abc/configuration_abc.py | 97 ++++++++ fla/vision_models/abc/modeling_abc.py | 205 +++++++++++++++++ fla/vision_models/bitnet/__init__.py | 12 + .../bitnet/configuration_bitnet.py | 92 ++++++++ fla/vision_models/bitnet/modeling_bitnet.py | 201 +++++++++++++++++ fla/vision_models/delta_net/__init__.py | 3 +- .../delta_net/configuration_delta_net.py | 16 +- .../delta_net/modeling_delta_net.py | 1 - fla/vision_models/gated_deltanet/__init__.py | 13 ++ .../configuration_gated_deltanet.py | 87 ++++++++ .../gated_deltanet/modeling_gated_deltanet.py | 202 +++++++++++++++++ fla/vision_models/gla/__init__.py | 12 + fla/vision_models/gla/configuration_gla.py | 95 ++++++++ fla/vision_models/gla/modeling_gla.py | 207 +++++++++++++++++ fla/vision_models/gsa/__init__.py | 12 + fla/vision_models/gsa/configuration_gsa.py | 106 +++++++++ fla/vision_models/gsa/modeling_gsa.py | 209 ++++++++++++++++++ fla/vision_models/hgrn/__init__.py | 12 + fla/vision_models/hgrn/configuration_hgrn.py | 85 +++++++ fla/vision_models/hgrn/modeling_hgrn.py | 199 +++++++++++++++++ fla/vision_models/hgrn2/__init__.py | 12 + .../hgrn2/configuration_hgrn2.py | 88 ++++++++ fla/vision_models/hgrn2/modeling_hgrn2.py | 200 +++++++++++++++++ fla/vision_models/linear_attn/__init__.py | 12 + .../linear_attn/configuration_linear_attn.py | 96 ++++++++ .../linear_attn/modeling_linear_attn.py | 205 +++++++++++++++++ fla/vision_models/retnet/__init__.py | 12 + .../retnet/configuration_retnet.py | 100 +++++++++ fla/vision_models/retnet/modeling_retnet.py | 204 +++++++++++++++++ fla/vision_models/rwkv6/__init__.py | 12 + .../rwkv6/configuration_rwkv6.py | 93 ++++++++ fla/vision_models/rwkv6/modeling_rwkv6.py | 201 +++++++++++++++++ fla/vision_models/transformer/__init__.py | 12 + .../transformer/configuration_transformer.py | 81 +++++++ .../transformer/modeling_transformer.py | 190 ++++++++++++++++ fla/vision_models/utils.py | 51 ++++- 37 files changed, 3433 insertions(+), 14 deletions(-) create mode 100644 fla/vision_models/abc/__init__.py create mode 100644 fla/vision_models/abc/configuration_abc.py create mode 100644 fla/vision_models/abc/modeling_abc.py create mode 100644 fla/vision_models/bitnet/__init__.py create mode 100644 fla/vision_models/bitnet/configuration_bitnet.py create mode 100644 fla/vision_models/bitnet/modeling_bitnet.py create mode 100644 fla/vision_models/gated_deltanet/__init__.py create mode 100644 fla/vision_models/gated_deltanet/configuration_gated_deltanet.py create mode 100644 fla/vision_models/gated_deltanet/modeling_gated_deltanet.py create mode 100644 fla/vision_models/gla/__init__.py create mode 100644 fla/vision_models/gla/configuration_gla.py create mode 100644 fla/vision_models/gla/modeling_gla.py create mode 100644 fla/vision_models/gsa/__init__.py create mode 100644 fla/vision_models/gsa/configuration_gsa.py create mode 100644 fla/vision_models/gsa/modeling_gsa.py create mode 100644 fla/vision_models/hgrn/__init__.py create mode 100644 fla/vision_models/hgrn/configuration_hgrn.py 
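Every vision package added in this commit follows the same registration pattern, so the models should be reachable through the Auto classes once the package has been imported. A hedged sketch (it assumes the top-level fla.vision_models __init__ pulls in the subpackages so the register() calls run; values are illustrative):

from transformers import AutoConfig, AutoModelForImageClassification
import fla.vision_models  # importing runs the AutoConfig / AutoModel registrations

config = AutoConfig.for_model(
    "delta_net_vision",           # any registered model_type from this series
    hidden_size=384, num_hidden_layers=12, num_classes=100,
)
model = AutoModelForImageClassification.from_config(config)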
create mode 100644 fla/vision_models/hgrn/modeling_hgrn.py create mode 100644 fla/vision_models/hgrn2/__init__.py create mode 100644 fla/vision_models/hgrn2/configuration_hgrn2.py create mode 100644 fla/vision_models/hgrn2/modeling_hgrn2.py create mode 100644 fla/vision_models/linear_attn/__init__.py create mode 100644 fla/vision_models/linear_attn/configuration_linear_attn.py create mode 100644 fla/vision_models/linear_attn/modeling_linear_attn.py create mode 100644 fla/vision_models/retnet/__init__.py create mode 100644 fla/vision_models/retnet/configuration_retnet.py create mode 100644 fla/vision_models/retnet/modeling_retnet.py create mode 100644 fla/vision_models/rwkv6/__init__.py create mode 100644 fla/vision_models/rwkv6/configuration_rwkv6.py create mode 100644 fla/vision_models/rwkv6/modeling_rwkv6.py create mode 100644 fla/vision_models/transformer/__init__.py create mode 100644 fla/vision_models/transformer/configuration_transformer.py create mode 100644 fla/vision_models/transformer/modeling_transformer.py diff --git a/fla/vision_models/abc/__init__.py b/fla/vision_models/abc/__init__.py new file mode 100644 index 000000000..67d013691 --- /dev/null +++ b/fla/vision_models/abc/__init__.py @@ -0,0 +1,12 @@ +from transformers import AutoConfig, AutoModelForImageClassification + +from fla.vision_models.abc.configuration_abc import ABCVisionConfig +from fla.vision_models.abc.modeling_abc import ABCForImageClassification + +AutoConfig.register(ABCVisionConfig.model_type, ABCVisionConfig) +AutoModelForImageClassification.register(ABCVisionConfig, ABCForImageClassification) + +__all__ = [ + 'ABCVisionConfig', + 'ABCForImageClassification' +] diff --git a/fla/vision_models/abc/configuration_abc.py b/fla/vision_models/abc/configuration_abc.py new file mode 100644 index 000000000..6a7c2fa95 --- /dev/null +++ b/fla/vision_models/abc/configuration_abc.py @@ -0,0 +1,97 @@ +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class ABCVisionConfig(PretrainedConfig): + + model_type = 'abc_vision' + + def __init__( + self, + # ABC core parameters + hidden_size: int = 2048, + gate_low_rank_dim: int = 16, + clamp_min: float = -32, + clamp_max: float = 32, + num_hidden_layers: int = 24, + num_heads: int = 4, + num_slots: Optional[int] = 64, + use_short_conv: bool = False, + conv_size: int = 4, + exapnd_k: float = 0.5, + exapnd_v: float = 1, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + qkv_bias: bool = True, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + # FLA-for-vision-related parameters + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + ): + # Initialize ABC core parameters + self.hidden_size = hidden_size + self.gate_low_rank_dim = gate_low_rank_dim + self.clamp_min = clamp_min + self.clamp_max = clamp_max + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_slots = num_slots + self.use_short_conv = use_short_conv + self.conv_size = conv_size + 
self.expand_k = exapnd_k + self.expand_v = exapnd_v + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_cross_entropy = fuse_cross_entropy + + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) diff --git a/fla/vision_models/abc/modeling_abc.py b/fla/vision_models/abc/modeling_abc.py new file mode 100644 index 000000000..9fa33230a --- /dev/null +++ b/fla/vision_models/abc/modeling_abc.py @@ -0,0 +1,205 @@ +import collections.abc +import math +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from typing import Optional, Set, Tuple, Union, List, Dict, Unpack +from transformers.utils import logging +from fla.layers.attn import Attention +from transformers.modeling_outputs import ImageClassifierOutput +from transformers.modeling_utils import PreTrainedModel +from .configuration_abc import ABCVisionConfig +from fla.layers.abc import ABCAttention +from fla.models.utils import Cache +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge + +logger = logging.get_logger(__name__) + +class ABCMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class ABCBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + if not config.norm_first: + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = ABCAttention( + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_slots=config.num_slots, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + 
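The attn dictionary validated above is what turns selected layers into plain softmax attention while the remaining blocks keep the linear-attention mixer. A sketch of a hybrid configuration (layer indices and sizes are illustrative):

from fla.vision_models.abc.configuration_abc import ABCVisionConfig

attn = {
    "layers": [3, 7, 11],   # these block indices use the standard Attention layer
    "num_heads": 8,         # required; num_kv_heads defaults to num_heads, window_size to None
}
config = ABCVisionConfig(num_hidden_layers=12, hidden_size=384, attn=attn)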
gate_fn=config.hidden_act, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + clamp_min=config.clamp_min, + clamp_max=config.clamp_max, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + + if not config.norm_first: + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = ABCMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + # MLP + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class ABCVisionPreTrainedModel(PreTrainedModel): + # this part of the code is adapted from huggingface/transformers vit implementation + config_class = ABCVisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + +class ABCForImageClassification(ABCVisionPreTrainedModel): + config_class = ABCVisionConfig + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + + self.embeddings = ImageEmbeddings(config) + self.blocks = nn.ModuleList([ + ABCBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.interpolate_pos_encoding = config.interpolate_pos_encoding + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, ImageClassifierOutput]: + return_dict = return_dict 
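All of the vision blocks in this series share the same recipe and differ only in the token mixer. A condensed sketch of that recipe for the norm_first=False path (the function and its arguments are placeholders, not names from the patch):

from fla.vision_models.utils import (
    prepare_hidden_states_for_cross_scan,
    prepare_hidden_states_for_cross_merge,
)

def fla_vision_block(x, ln_1, mixer, ln_2, mlp, scan_type="uni-scan"):
    # pre-norm -> scan expansion -> token mixer -> scan merge -> residual,
    # then pre-norm MLP with its own residual
    y = prepare_hidden_states_for_cross_scan(ln_1(x), scan_type)
    y, _attentions, _past = mixer(hidden_states=y)
    y = prepare_hidden_states_for_cross_merge(y, scan_type)
    x = x + y
    return x + mlp(ln_2(x))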
if return_dict is not None else self.config.use_return_dict + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) + + for block in self.blocks: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = self.norm(hidden_states) + pooled_output = self.pooler(hidden_states) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + (hidden_states,) + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + ) diff --git a/fla/vision_models/bitnet/__init__.py b/fla/vision_models/bitnet/__init__.py new file mode 100644 index 000000000..8f372bc7c --- /dev/null +++ b/fla/vision_models/bitnet/__init__.py @@ -0,0 +1,12 @@ +from transformers import AutoConfig, AutoModelForImageClassification + +from fla.vision_models.bitnet.configuration_bitnet import BitNetVisionConfig +from fla.vision_models.bitnet.modeling_bitnet import BitNetForImageClassification + +AutoConfig.register(BitNetVisionConfig, BitNetVisionConfig) +AutoModelForImageClassification.register(BitNetVisionConfig, BitNetForImageClassification) + +__all__ = [ + 'BitNetVisionConfig', + 'BitNetForImageClassification' +] diff --git a/fla/vision_models/bitnet/configuration_bitnet.py b/fla/vision_models/bitnet/configuration_bitnet.py new file mode 100644 index 000000000..37a51b925 --- /dev/null +++ b/fla/vision_models/bitnet/configuration_bitnet.py @@ -0,0 +1,92 @@ +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class BitNetVisionConfig(PretrainedConfig): + + model_type = 'bitnet_vision' + + def __init__( + self, + # BitNet core parameters + hidden_size: int = 2048, + num_hidden_layers: int = 24, + num_heads: int = 32, + num_kv_heads: int = None, + window_size: Optional[int] = None, + rope_theta: Optional[float] = 10000., + max_position_embeddings: int = 2048, + hidden_act: str = "swish", + initializer_range: float = 0.02, + elementwise_affine: Optional[bool] = True, + norm_first: bool = False, + norm_eps: float = 1e-6, + use_cache: bool = True, + attention_bias: bool = False, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + attn: Optional[Dict] = None, + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + qkv_bias: bool = True, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + # FLA-for-vision-related parameters + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + ): + # Initialize BitNet core parameters + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.window_size = window_size + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + 
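The classification head follows the usual HF ViT convention: MSE loss when there is a single label (regression), cross entropy otherwise. A minimal sketch with illustrative shapes:

import torch
from torch.nn import CrossEntropyLoss, MSELoss

num_labels, logits = 100, torch.randn(8, 100)
if num_labels == 1:                               # regression head
    loss = MSELoss()(logits.squeeze(), torch.randn(8))
else:                                             # classification head
    labels = torch.randint(0, num_labels, (8,))
    loss = CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))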
+ self.hidden_act = hidden_act + + self.initializer_range = initializer_range + self.elementwise_affine = elementwise_affine + self.norm_first = norm_first + self.norm_eps = norm_eps + self.use_cache = use_cache + self.attention_bias = attention_bias + self.fuse_cross_entropy = fuse_cross_entropy + self.fuse_norm = fuse_norm + + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/vision_models/bitnet/modeling_bitnet.py b/fla/vision_models/bitnet/modeling_bitnet.py new file mode 100644 index 000000000..fa9675095 --- /dev/null +++ b/fla/vision_models/bitnet/modeling_bitnet.py @@ -0,0 +1,201 @@ +import collections.abc +import math +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from typing import Optional, Set, Tuple, Union, List, Dict, Unpack +from transformers.utils import logging +from fla.layers.attn import Attention +from transformers.modeling_outputs import ImageClassifierOutput +from transformers.modeling_utils import PreTrainedModel +from .configuration_bitnet import BitNetVisionConfig +from fla.layers.bitattn import BitAttention +from fla.models.utils import Cache +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge + +logger = logging.get_logger(__name__) + +class BitNetMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class BitNetBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + if not config.norm_first: + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = BitAttention( + hidden_size=config.hidden_size, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + window_size=config.window_size, + rope_theta=config.rope_theta, + max_position_embeddings=config.max_position_embeddings, + 
norm_first=config.norm_first, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + + if not config.norm_first: + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = BitNetMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + # MLP + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class BitNetVisionPreTrainedModel(PreTrainedModel): + # this part of the code is adapted from huggingface/transformers vit implementation + config_class = BitNetVisionConfig + base_model_prefix = "bitnet" + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + +class BitNetForImageClassification(BitNetVisionPreTrainedModel): + config_class = BitNetVisionConfig + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + + self.embeddings = ImageEmbeddings(config) + self.blocks = nn.ModuleList([ + BitNetBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.interpolate_pos_encoding = config.interpolate_pos_encoding + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + 
fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) + + for block in self.blocks: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = self.norm(hidden_states) + pooled_output = self.pooler(hidden_states) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + (hidden_states,) + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + ) diff --git a/fla/vision_models/delta_net/__init__.py b/fla/vision_models/delta_net/__init__.py index 6bc489850..3b951dd04 100644 --- a/fla/vision_models/delta_net/__init__.py +++ b/fla/vision_models/delta_net/__init__.py @@ -3,8 +3,7 @@ from fla.vision_models.delta_net.configuration_delta_net import DeltaNetVisionConfig from fla.vision_models.delta_net.modeling_delta_net import DeltaNetForImageClassification -# Register the model with transformers -AutoConfig.register("delta_net_vision", DeltaNetVisionConfig) +AutoConfig.register(DeltaNetVisionConfig.model_type, DeltaNetVisionConfig) AutoModelForImageClassification.register(DeltaNetVisionConfig, DeltaNetForImageClassification) __all__ = [ diff --git a/fla/vision_models/delta_net/configuration_delta_net.py b/fla/vision_models/delta_net/configuration_delta_net.py index 9fa6b74e5..d490a37f7 100644 --- a/fla/vision_models/delta_net/configuration_delta_net.py +++ b/fla/vision_models/delta_net/configuration_delta_net.py @@ -19,7 +19,6 @@ def __init__( num_heads: int = 16, qk_norm: str = 'l2', qk_activation: str = 'silu', - hidden_ratio: Optional[int] = 4, intermediate_size: Optional[int] = None, hidden_act: str = "swish", num_hidden_layers: int = 12, @@ -30,12 +29,12 @@ def __init__( initializer_range: float = 0.02, fuse_cross_entropy: bool = True, max_position_embeddings: int = 2048, + # Vision specific parameters image_size: int = 224, patch_size: int = 16, num_channels: int = 3, num_classes: int = 1000, - qkv_bias: bool = True, hidden_dropout_prob: float = 0.0, use_mask_token: bool = False, layer_norm_eps: float = 1e-6, @@ -58,7 +57,6 @@ def __init__( self.num_heads = num_heads self.qk_norm = qk_norm self.qk_activation = qk_activation - self.hidden_ratio = hidden_ratio self.intermediate_size = intermediate_size self.hidden_act = hidden_act self.num_hidden_layers = num_hidden_layers @@ -75,14 +73,22 @@ def __init__( self.patch_size = patch_size self.num_channels = num_channels self.num_classes = num_classes - self.qkv_bias = qkv_bias self.hidden_dropout_prob = hidden_dropout_prob self.use_mask_token = use_mask_token self.layer_norm_eps = layer_norm_eps self.interpolate_pos_encoding = interpolate_pos_encoding self.scan_type = scan_type - + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize 
hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + if mlp_dim is None: self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size else: diff --git a/fla/vision_models/delta_net/modeling_delta_net.py b/fla/vision_models/delta_net/modeling_delta_net.py index 7dd7026f8..879a16327 100644 --- a/fla/vision_models/delta_net/modeling_delta_net.py +++ b/fla/vision_models/delta_net/modeling_delta_net.py @@ -121,7 +121,6 @@ def forward( class DeltaNetVisionPreTrainedModel(PreTrainedModel): # this part of the code is adapted from huggingface/transformers vit implementation config_class = DeltaNetVisionConfig - base_model_prefix = "deltanet" def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv2d)): diff --git a/fla/vision_models/gated_deltanet/__init__.py b/fla/vision_models/gated_deltanet/__init__.py new file mode 100644 index 000000000..45bb5ffbf --- /dev/null +++ b/fla/vision_models/gated_deltanet/__init__.py @@ -0,0 +1,13 @@ +from transformers import AutoConfig, AutoModelForImageClassification + +from fla.vision_models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetVisionConfig +from fla.vision_models.gated_deltanet.modeling_gated_deltanet import GatedDeltaNetForImageClassification + +AutoConfig.register(GatedDeltaNetVisionConfig.model_type, GatedDeltaNetVisionConfig) +AutoModelForImageClassification.register(GatedDeltaNetVisionConfig, GatedDeltaNetForImageClassification) + +__all__ = [ + 'GatedDeltaNetVisionConfig', + 'GatedDeltaNetForImageClassification' +] + diff --git a/fla/vision_models/gated_deltanet/configuration_gated_deltanet.py b/fla/vision_models/gated_deltanet/configuration_gated_deltanet.py new file mode 100644 index 000000000..fe472f257 --- /dev/null +++ b/fla/vision_models/gated_deltanet/configuration_gated_deltanet.py @@ -0,0 +1,87 @@ +from typing import Dict, Optional +from transformers.configuration_utils import PretrainedConfig + +class GatedDeltaNetVisionConfig(PretrainedConfig): + model_type = 'gated_deltanet_vision' + + def __init__( + self, + # GatedDeltaNet core parameters + attn_mode: str = "chunk", + hidden_size: int = 2048, + expand_v: int = 2, + use_gate: bool = True, + use_short_conv: bool = True, + conv_size: int = 4, + head_dim: int = 256, + num_heads: int = 6, + max_position_embeddings: int = 2048, + hidden_act: str = "swish", + num_hidden_layers: int = 21, + norm_first: bool = False, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + fuse_cross_entropy: bool = True, + + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + # FLA-for-vision-related parameters + scan_type: str = "uni-scan", + **kwargs + ): + # Initialize GatedDeltaNet core parameters + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_v = expand_v + self.head_dim = head_dim + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.num_heads = num_heads + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.norm_first = norm_first + self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy 
= fuse_cross_entropy + self.attn = attn + self.max_position_embeddings = max_position_embeddings + + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) diff --git a/fla/vision_models/gated_deltanet/modeling_gated_deltanet.py b/fla/vision_models/gated_deltanet/modeling_gated_deltanet.py new file mode 100644 index 000000000..94694cca8 --- /dev/null +++ b/fla/vision_models/gated_deltanet/modeling_gated_deltanet.py @@ -0,0 +1,202 @@ +import collections.abc +import math +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from typing import Optional, Set, Tuple, Union, List, Dict, Unpack +from transformers.utils import logging +from fla.layers.attn import Attention +from transformers.modeling_outputs import ImageClassifierOutput +from transformers.modeling_utils import PreTrainedModel +from .configuration_gated_deltanet import GatedDeltaNetVisionConfig +from fla.layers.gated_deltanet import GatedDeltaNet +from fla.models.utils import Cache +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge + +logger = logging.get_logger(__name__) + +class GatedDeltaNetMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class GatedDeltaNetBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + if not config.norm_first: + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = GatedDeltaNet( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_v=config.expand_v, + head_dim=config.head_dim, + num_heads=config.num_heads, + use_gate=config.use_gate, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + norm_first=config.norm_first, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + + if not config.norm_first: + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = 
GatedDeltaNetMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + # MLP + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class GatedDeltaNetVisionPreTrainedModel(PreTrainedModel): + # this part of the code is adapted from huggingface/transformers vit implementation + config_class = GatedDeltaNetVisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + +class GatedDeltaNetForImageClassification(GatedDeltaNetVisionPreTrainedModel): + config_class = GatedDeltaNetVisionConfig + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + + self.embeddings = ImageEmbeddings(config) + self.blocks = nn.ModuleList([ + GatedDeltaNetBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.interpolate_pos_encoding = config.interpolate_pos_encoding + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) + + for block in 
self.blocks: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = self.norm(hidden_states) + pooled_output = self.pooler(hidden_states) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + (hidden_states,) + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + ) diff --git a/fla/vision_models/gla/__init__.py b/fla/vision_models/gla/__init__.py new file mode 100644 index 000000000..dc7d6e93c --- /dev/null +++ b/fla/vision_models/gla/__init__.py @@ -0,0 +1,12 @@ +from transformers import AutoConfig, AutoModelForImageClassification + +from fla.vision_models.gla.configuration_gla import GLAVisionConfig +from fla.vision_models.gla.modeling_gla import GLAForImageClassification + +AutoConfig.register(GLAVisionConfig.model_type, GLAVisionConfig) +AutoModelForImageClassification.register(GLAVisionConfig, GLAForImageClassification) + +__all__ = [ + 'GLAVisionConfig', + 'GLAForImageClassification' +] diff --git a/fla/vision_models/gla/configuration_gla.py b/fla/vision_models/gla/configuration_gla.py new file mode 100644 index 000000000..77d750f90 --- /dev/null +++ b/fla/vision_models/gla/configuration_gla.py @@ -0,0 +1,95 @@ +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + +class GLAVisionConfig(PretrainedConfig): + + model_type = 'gla_vision' + + def __init__( + self, + # GLA core parameters + hidden_size: int = 2048, + expand_k: int = 0.5, + expand_v: int = 1, + num_hidden_layers: int = 24, + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + feature_map: Optional[str] = None, + attn_mode: str = "chunk", + use_short_conv: bool = False, + conv_size: int = 4, + use_output_gate: bool = True, + clamp_min: Optional[float] = None, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_gk: bool = True, + use_gv: bool = False, + attn: Optional[Dict] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + # FLA-for-vision-related parameters + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + ): + # Initialize DeltaNet core parameters + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.feature_map = feature_map + self.attn_mode = attn_mode + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_output_gate = use_output_gate + self.clamp_min = clamp_min + self.hidden_act = hidden_act + self.max_position_embeddings = 
max_position_embeddings + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_gk = use_gk + self.use_gv = use_gv + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_cross_entropy = fuse_cross_entropy + + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + super().__init__(**kwargs) diff --git a/fla/vision_models/gla/modeling_gla.py b/fla/vision_models/gla/modeling_gla.py new file mode 100644 index 000000000..433bfb09d --- /dev/null +++ b/fla/vision_models/gla/modeling_gla.py @@ -0,0 +1,207 @@ +import collections.abc +import math +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from typing import Optional, Set, Tuple, Union, List, Dict, Unpack +from transformers.utils import logging +from fla.layers.attn import Attention +from transformers.modeling_outputs import ImageClassifierOutput +from transformers.modeling_utils import PreTrainedModel +from .configuration_gla import GLAVisionConfig +from fla.layers.gla import GatedLinearAttention +from fla.models.utils import Cache +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge + +logger = logging.get_logger(__name__) + +class GLAMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class GLABlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + if not config.norm_first: + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = GatedLinearAttention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + feature_map=config.feature_map, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + use_output_gate=config.use_output_gate, + gate_fn=config.hidden_act, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + 
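For orientation, expand_k and expand_v are ratios of hidden_size, so with the defaults above the key path is half-width. A hedged sketch of the dimensionality rule assumed from fla's GatedLinearAttention:

# Assumed sizing rule; values mirror the config defaults above.
hidden_size, expand_k, expand_v, num_heads = 2048, 0.5, 1, 4
key_dim = int(hidden_size * expand_k)       # 1024
value_dim = int(hidden_size * expand_v)     # 2048
head_k_dim = key_dim // num_heads           # 256 per head
head_v_dim = value_dim // num_heads         # 512 per head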
clamp_min=config.clamp_min, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + + if not config.norm_first: + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = GLAMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + # MLP + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class GLAVisionPreTrainedModel(PreTrainedModel): + # this part of the code is adapted from huggingface/transformers vit implementation + config_class = GLAVisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + +class GLAForImageClassification(GLAVisionPreTrainedModel): + config_class = GLAVisionConfig + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + + self.embeddings = ImageEmbeddings(config) + self.blocks = nn.ModuleList([ + GLABlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.interpolate_pos_encoding = config.interpolate_pos_encoding + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and 
self.training + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) + + for block in self.blocks: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = self.norm(hidden_states) + pooled_output = self.pooler(hidden_states) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + (hidden_states,) + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + ) diff --git a/fla/vision_models/gsa/__init__.py b/fla/vision_models/gsa/__init__.py new file mode 100644 index 000000000..3da164504 --- /dev/null +++ b/fla/vision_models/gsa/__init__.py @@ -0,0 +1,12 @@ +from transformers import AutoConfig, AutoModelForImageClassification + +from fla.vision_models.gsa.configuration_gsa import GSAVisionConfig +from fla.vision_models.gsa.modeling_gsa import GSAForImageClassification + +AutoConfig.register(GSAVisionConfig.model_type, GSAVisionConfig) +AutoModelForImageClassification.register(GSAVisionConfig, GSAForImageClassification) + +__all__ = [ + 'GSAVisionConfig', + 'GSAForImageClassification' +] diff --git a/fla/vision_models/gsa/configuration_gsa.py b/fla/vision_models/gsa/configuration_gsa.py new file mode 100644 index 000000000..deca79dce --- /dev/null +++ b/fla/vision_models/gsa/configuration_gsa.py @@ -0,0 +1,106 @@ +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class GSAVisionConfig(PretrainedConfig): + + model_type = 'gsa_vision' + + def __init__( + self, + # GSA core parameters + hidden_size: int = 2048, + gate_logit_normalizer: Optional[int] = 8, + clamp_min: Optional[float] = None, + clamp_max: Optional[float] = None, + num_hidden_layers: int = 24, + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + num_slots: Optional[int] = 64, + use_short_conv: bool = False, + conv_size: int = 4, + exapnd_k: float = 1, + exapnd_v: float = 1, + feature_map: str = 'swish', + use_output_gate: bool = False, + use_norm: bool = True, + max_position_embeddings: int = 2048, + hidden_act: str = "swish", + elementwise_affine: Optional[bool] = True, + norm_first: bool = True, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + qkv_bias: bool = True, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + # FLA-for-vision-related parameters + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + ): + self.hidden_size = hidden_size + self.gate_logit_normalizer = gate_logit_normalizer + self.clamp_min = clamp_min + self.clamp_max = clamp_max + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + 
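+        # num_slots (64 by default) fixes the number of memory slots in GatedSlotAttention,
+        # bounding the recurrent state size independently of sequence length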
self.num_slots = num_slots + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.expand_k = exapnd_k + self.expand_v = exapnd_v + self.feature_map = feature_map + self.use_output_gate = use_output_gate + self.use_norm = use_norm + self.max_position_embeddings = max_position_embeddings + self.hidden_act = hidden_act + self.elementwise_affine = elementwise_affine + self.norm_first = norm_first + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + self.fuse_norm = fuse_norm + + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) diff --git a/fla/vision_models/gsa/modeling_gsa.py b/fla/vision_models/gsa/modeling_gsa.py new file mode 100644 index 000000000..856eea93d --- /dev/null +++ b/fla/vision_models/gsa/modeling_gsa.py @@ -0,0 +1,209 @@ +import collections.abc +import math +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from typing import Optional, Set, Tuple, Union, List, Dict, Unpack +from transformers.utils import logging +from fla.layers.attn import Attention +from transformers.modeling_outputs import ImageClassifierOutput +from transformers.modeling_utils import PreTrainedModel +from .configuration_gsa import GSAVisionConfig +from fla.layers.gsa import GatedSlotAttention +from fla.models.utils import Cache +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge + +logger = logging.get_logger(__name__) + +class GSAMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class GSABlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + if not config.norm_first: + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = 
GatedSlotAttention( + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + num_slots=config.num_slots, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + feature_map=config.feature_map, + use_output_gate=config.use_output_gate, + use_norm=config.use_norm, + gate_fn=config.hidden_act, + gate_logit_normalizer=config.gate_logit_normalizer, + elementwise_affine=config.elementwise_affine, + norm_first=config.norm_first, + norm_eps=config.norm_eps, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + + if not config.norm_first: + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = GSAMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + # MLP + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class GSAVisionPreTrainedModel(PreTrainedModel): + # this part of the code is adapted from huggingface/transformers vit implementation + config_class = GSAVisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + +class GSAForImageClassification(GSAVisionPreTrainedModel): + config_class = GSAVisionConfig + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + + self.embeddings = ImageEmbeddings(config) + self.blocks = nn.ModuleList([ + GSABlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.interpolate_pos_encoding = config.interpolate_pos_encoding + self.init_weights() + + def forward( + self, + pixel_values: 
Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) + + for block in self.blocks: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = self.norm(hidden_states) + pooled_output = self.pooler(hidden_states) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + (hidden_states,) + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + ) diff --git a/fla/vision_models/hgrn/__init__.py b/fla/vision_models/hgrn/__init__.py new file mode 100644 index 000000000..e9ab00ae0 --- /dev/null +++ b/fla/vision_models/hgrn/__init__.py @@ -0,0 +1,12 @@ +from transformers import AutoConfig, AutoModelForImageClassification + +from fla.vision_models.hgrn.configuration_hgrn import HGRNVisionConfig +from fla.vision_models.hgrn.modeling_hgrn import HGRNForImageClassification + +AutoConfig.register(HGRNVisionConfig.model_type, HGRNVisionConfig) +AutoModelForImageClassification.register(HGRNVisionConfig, HGRNForImageClassification) + +__all__ = [ + 'HGRNVisionConfig', + 'HGRNForImageClassification' +] diff --git a/fla/vision_models/hgrn/configuration_hgrn.py b/fla/vision_models/hgrn/configuration_hgrn.py new file mode 100644 index 000000000..de5aae00b --- /dev/null +++ b/fla/vision_models/hgrn/configuration_hgrn.py @@ -0,0 +1,85 @@ +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class HGRNVisionConfig(PretrainedConfig): + + model_type = 'hgrn_vision' + + def __init__( + self, + # HGRN core parameters + attn_mode: str = "chunk", + hidden_size: int = 2048, + num_hidden_layers: int = 24, + expand_ratio: Optional[int] = 1, + use_short_conv: bool = False, + conv_size: int = 4, + use_lower_bound: bool = True, + max_position_embeddings: int = 2048, + hidden_act: str = "swish", + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + fuse_cross_entropy: bool = True, + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + qkv_bias: bool = True, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + # FLA-for-vision-related parameters + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + 
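+        # any extra kwargs are forwarded to PretrainedConfig.__init__ (e.g. num_labels, id2label)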
): + # Initialize HGRN core parameters + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.expand_ratio = expand_ratio + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_lower_bound = use_lower_bound + self.max_position_embeddings = max_position_embeddings + self.elementwise_affine = elementwise_affine + self.attn = attn + self.norm_eps = norm_eps + self.hidden_act = hidden_act + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) diff --git a/fla/vision_models/hgrn/modeling_hgrn.py b/fla/vision_models/hgrn/modeling_hgrn.py new file mode 100644 index 000000000..8d591cbc6 --- /dev/null +++ b/fla/vision_models/hgrn/modeling_hgrn.py @@ -0,0 +1,199 @@ +import collections.abc +import math +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from typing import Optional, Set, Tuple, Union, List, Dict, Unpack +from transformers.utils import logging +from fla.layers.attn import Attention +from transformers.modeling_outputs import ImageClassifierOutput +from transformers.modeling_utils import PreTrainedModel +from .configuration_hgrn import HGRNVisionConfig +from fla.layers.hgrn import HGRNAttention +from fla.models.utils import Cache +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge + +logger = logging.get_logger(__name__) + +class HGRNMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class HGRNBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + if not config.norm_first: + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = HGRNAttention( + 
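+                # unlike the hybrid Attention branch above, no head count is passed here:
+                # HGRNAttention is configured only via expand_ratio and the short-conv options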
mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_ratio=config.expand_ratio, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + + if not config.norm_first: + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = HGRNMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + # MLP + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class HGRNVisionPreTrainedModel(PreTrainedModel): + # this part of the code is adapted from huggingface/transformers vit implementation + config_class = HGRNVisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + +class HGRNForImageClassification(HGRNVisionPreTrainedModel): + config_class = HGRNVisionConfig + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + + self.embeddings = ImageEmbeddings(config) + self.blocks = nn.ModuleList([ + HGRNBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.interpolate_pos_encoding = config.interpolate_pos_encoding + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, 
ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) + + for block in self.blocks: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = self.norm(hidden_states) + pooled_output = self.pooler(hidden_states) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + (hidden_states,) + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + ) diff --git a/fla/vision_models/hgrn2/__init__.py b/fla/vision_models/hgrn2/__init__.py new file mode 100644 index 000000000..69a2c9c55 --- /dev/null +++ b/fla/vision_models/hgrn2/__init__.py @@ -0,0 +1,12 @@ +from transformers import AutoConfig, AutoModelForImageClassification + +from fla.vision_models.hgrn2.configuration_hgrn2 import HGRN2VisionConfig +from fla.vision_models.hgrn2.modeling_hgrn2 import HGRN2ForImageClassification + +AutoConfig.register(HGRN2VisionConfig.model_type, HGRN2VisionConfig) +AutoModelForImageClassification.register(HGRN2VisionConfig, HGRN2ForImageClassification) + +__all__ = [ + 'HGRN2VisionConfig', + 'HGRN2ForImageClassification' +] diff --git a/fla/vision_models/hgrn2/configuration_hgrn2.py b/fla/vision_models/hgrn2/configuration_hgrn2.py new file mode 100644 index 000000000..e8e5df182 --- /dev/null +++ b/fla/vision_models/hgrn2/configuration_hgrn2.py @@ -0,0 +1,88 @@ +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class HGRN2VisionConfig(PretrainedConfig): + + model_type = 'hgrn2_vision' + + def __init__( + self, + # HGRN2 core parameters + hidden_size: int = 2048, + num_hidden_layers: int = 24, + attn_mode: str = "chunk", + num_heads: Optional[int] = None, + expand_ratio: Optional[int] = 128, + use_short_conv: bool = False, + conv_size: int = 4, + use_lower_bound: bool = True, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + fuse_cross_entropy: bool = True, + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + qkv_bias: bool = True, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + # FLA-for-vision-related parameters + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + ): + # Initialize HGRN2 core parameters + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.attn_mode = attn_mode + self.num_heads = num_heads + self.expand_ratio = expand_ratio + self.use_short_conv = use_short_conv + self.conv_size = conv_size + 
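+        # use_lower_bound enables the learned lower bounds on the forget gates; with
+        # num_heads left as None, the head count is effectively set by expand_ratio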
self.use_lower_bound = use_lower_bound + self.max_position_embeddings = max_position_embeddings + self.hidden_act = hidden_act + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/vision_models/hgrn2/modeling_hgrn2.py b/fla/vision_models/hgrn2/modeling_hgrn2.py new file mode 100644 index 000000000..3284d1b76 --- /dev/null +++ b/fla/vision_models/hgrn2/modeling_hgrn2.py @@ -0,0 +1,200 @@ +import collections.abc +import math +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from typing import Optional, Set, Tuple, Union, List, Dict, Unpack +from transformers.utils import logging +from fla.layers.attn import Attention +from transformers.modeling_outputs import ImageClassifierOutput +from transformers.modeling_utils import PreTrainedModel +from .configuration_hgrn2 import HGRN2VisionConfig +from fla.layers.hgrn2 import HGRN2Attention +from fla.models.utils import Cache +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge + +logger = logging.get_logger(__name__) + +class HGRN2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class HGRN2Block(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + if not config.norm_first: + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = HGRN2Attention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + num_heads=config.num_heads, + expand_ratio=config.expand_ratio, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + 
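+                # norm_eps configures the attention layer's internal normalization, while the
+                # block-level LayerNorms (ln_1/ln_2) use config.layer_norm_eps instead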
elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + + if not config.norm_first: + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = HGRN2MLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + # MLP + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class HGRN2VisionPreTrainedModel(PreTrainedModel): + # this part of the code is adapted from huggingface/transformers vit implementation + config_class = HGRN2VisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + +class HGRN2ForImageClassification(HGRN2VisionPreTrainedModel): + config_class = HGRN2VisionConfig + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + + self.embeddings = ImageEmbeddings(config) + self.blocks = nn.ModuleList([ + HGRN2Block(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.interpolate_pos_encoding = config.interpolate_pos_encoding + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + fuse_linear_and_cross_entropy = 
self.config.fuse_cross_entropy and self.training + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) + + for block in self.blocks: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = self.norm(hidden_states) + pooled_output = self.pooler(hidden_states) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + (hidden_states,) + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + ) diff --git a/fla/vision_models/linear_attn/__init__.py b/fla/vision_models/linear_attn/__init__.py new file mode 100644 index 000000000..d56bc5e04 --- /dev/null +++ b/fla/vision_models/linear_attn/__init__.py @@ -0,0 +1,12 @@ +from transformers import AutoConfig, AutoModelForImageClassification + +from fla.vision_models.linear_attn.configuration_linear_attn import LinearAttentionVisionConfig +from fla.vision_models.linear_attn.modeling_linear_attn import LinearAttentionForImageClassification + +AutoConfig.register(LinearAttentionVisionConfig.model_type, LinearAttentionVisionConfig) +AutoModelForImageClassification.register(LinearAttentionVisionConfig, LinearAttentionForImageClassification) + +__all__ = [ + 'LinearAttentionVisionConfig', + 'LinearAttentionForImageClassification' +] diff --git a/fla/vision_models/linear_attn/configuration_linear_attn.py b/fla/vision_models/linear_attn/configuration_linear_attn.py new file mode 100644 index 000000000..d05e3c0ba --- /dev/null +++ b/fla/vision_models/linear_attn/configuration_linear_attn.py @@ -0,0 +1,96 @@ +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class LinearAttentionVisionConfig(PretrainedConfig): + + model_type = 'linear_attn_vision' + + def __init__( + self, + # LinearAttention core parameters + attn_mode: str = "fused_chunk", + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 1, + num_hidden_layers: int = 24, + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + feature_map: str = "elementwise_product", + tie_feature_map_qk: bool = False, + norm_q: bool = False, + norm_k: bool = False, + norm_feature_map: bool = False, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + fuse_cross_entropy: bool = True, + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + qkv_bias: bool = True, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + # FLA-for-vision-related parameters + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + ): + # Initialize LinearAttention core parameters + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_k = expand_k + 
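+        # expand_k/expand_v rescale the key and value dimensions relative to hidden_size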
self.expand_v = expand_v + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.feature_map = feature_map + self.tie_feature_map_qk = tie_feature_map_qk + self.norm_q = norm_q + self.norm_k = norm_k + self.norm_feature_map = norm_feature_map + self.max_position_embeddings = max_position_embeddings + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/vision_models/linear_attn/modeling_linear_attn.py b/fla/vision_models/linear_attn/modeling_linear_attn.py new file mode 100644 index 000000000..2cd01fb2b --- /dev/null +++ b/fla/vision_models/linear_attn/modeling_linear_attn.py @@ -0,0 +1,205 @@ +import collections.abc +import math +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from typing import Optional, Set, Tuple, Union, List, Dict, Unpack +from transformers.utils import logging +from fla.layers.attn import Attention +from transformers.modeling_outputs import ImageClassifierOutput +from transformers.modeling_utils import PreTrainedModel +from .configuration_linear_attn import LinearAttentionVisionConfig +from fla.layers.linear_attn import LinearAttention +from fla.models.utils import Cache +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge + +logger = logging.get_logger(__name__) + +class LinearAttentionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class LinearAttentionBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + if not config.norm_first: + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + 
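+                # layers listed in config.attn['layers'] use standard softmax attention;
+                # all remaining layers use the LinearAttention module constructed below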
max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = LinearAttention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + feature_map=config.feature_map, + tie_feature_map_qk=config.tie_feature_map_qk, + norm_q=config.norm_q, + norm_k=config.norm_k, + do_feature_map_norm=config.norm_feature_map, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + + if not config.norm_first: + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = LinearAttentionMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + # MLP + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class LinearAttentionVisionPreTrainedModel(PreTrainedModel): + # this part of the code is adapted from huggingface/transformers vit implementation + config_class = LinearAttentionVisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + +class LinearAttentionForImageClassification(LinearAttentionVisionPreTrainedModel): + config_class = LinearAttentionVisionConfig + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + + self.embeddings = ImageEmbeddings(config) + self.blocks = nn.ModuleList([ + LinearAttentionBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.interpolate_pos_encoding = config.interpolate_pos_encoding + self.init_weights() + + 
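+    # Usage sketch (hyperparameter values below are illustrative, not defaults of this patch):
+    #     config = LinearAttentionVisionConfig(hidden_size=192, num_hidden_layers=6, num_heads=3,
+    #                                          image_size=32, patch_size=4, num_classes=10)
+    #     model = LinearAttentionForImageClassification(config)
+    #     out = model(pixel_values=torch.randn(2, 3, 32, 32))  # out.logits -> shape (2, 10)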
def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) + + for block in self.blocks: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = self.norm(hidden_states) + pooled_output = self.pooler(hidden_states) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + (hidden_states,) + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + ) diff --git a/fla/vision_models/retnet/__init__.py b/fla/vision_models/retnet/__init__.py new file mode 100644 index 000000000..4a32b420f --- /dev/null +++ b/fla/vision_models/retnet/__init__.py @@ -0,0 +1,12 @@ +from transformers import AutoConfig, AutoModelForImageClassification + +from fla.vision_models.retnet.configuration_retnet import RetNetVisionConfig +from fla.vision_models.retnet.modeling_retnet import RetNetForImageClassification + +AutoConfig.register(RetNetVisionConfig.model_type, RetNetVisionConfig) +AutoModelForImageClassification.register(RetNetVisionConfig, RetNetForImageClassification) + +__all__ = [ + 'RetNetVisionConfig', + 'RetNetForImageClassification' +] diff --git a/fla/vision_models/retnet/configuration_retnet.py b/fla/vision_models/retnet/configuration_retnet.py new file mode 100644 index 000000000..4f27d5531 --- /dev/null +++ b/fla/vision_models/retnet/configuration_retnet.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class RetNetVisionConfig(PretrainedConfig): + + model_type = 'retnet_vision' + + def __init__( + self, + # RetNet core parameters + attn_mode: str = "chunk", + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 2, + num_hidden_layers: int = 24, + num_heads: int = 8, + num_kv_heads: Optional[int] = None, + feature_map: Optional[str] = None, + hidden_act: str = "swish", + use_short_conv: bool = False, + conv_size: int = 4, + use_output_gate: bool = True, + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + qkv_bias: bool = True, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + 
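+        # layer_norm_eps is used by the vision-side LayerNorms; norm_eps above is passed to the RetNet layers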
layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + # FLA-for-vision-related parameters + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + ) -> RetNetVisionConfig: + # Initialize RetNet core parameters + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.feature_map = feature_map + self.hidden_act = hidden_act + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_output_gate = use_output_gate + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_cross_entropy = fuse_cross_entropy + + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/vision_models/retnet/modeling_retnet.py b/fla/vision_models/retnet/modeling_retnet.py new file mode 100644 index 000000000..d7918696c --- /dev/null +++ b/fla/vision_models/retnet/modeling_retnet.py @@ -0,0 +1,204 @@ +import collections.abc +import math +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from typing import Optional, Set, Tuple, Union, List, Dict, Unpack +from transformers.utils import logging +from fla.layers.attn import Attention +from transformers.modeling_outputs import ImageClassifierOutput +from transformers.modeling_utils import PreTrainedModel +from .configuration_retnet import RetNetVisionConfig +from fla.layers.multiscale_retention import MultiScaleRetention +from fla.models.utils import Cache +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge + +logger = logging.get_logger(__name__) + +class RetNetMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class RetNetBlock(nn.Module): + def __init__(self, config, 
layer_idx: int): + super().__init__() + + if not config.norm_first: + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = MultiScaleRetention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + feature_map=config.feature_map, + use_output_gate=config.use_output_gate, + gate_fn=config.hidden_act, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + + if not config.norm_first: + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = RetNetMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + # MLP + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class RetNetVisionPreTrainedModel(PreTrainedModel): + # this part of the code is adapted from huggingface/transformers vit implementation + config_class = RetNetVisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + +class RetNetForImageClassification(RetNetVisionPreTrainedModel): + config_class = RetNetVisionConfig + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + + self.embeddings = ImageEmbeddings(config) + self.blocks = nn.ModuleList([ + RetNetBlock(config, layer_idx) + for layer_idx in 
range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.interpolate_pos_encoding = config.interpolate_pos_encoding + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) + + for block in self.blocks: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = self.norm(hidden_states) + pooled_output = self.pooler(hidden_states) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + (hidden_states,) + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + ) diff --git a/fla/vision_models/rwkv6/__init__.py b/fla/vision_models/rwkv6/__init__.py new file mode 100644 index 000000000..2df666ac4 --- /dev/null +++ b/fla/vision_models/rwkv6/__init__.py @@ -0,0 +1,12 @@ +from transformers import AutoConfig, AutoModelForImageClassification + +from fla.vision_models.rwkv6.configuration_rwkv6 import RWKV6VisionConfig +from fla.vision_models.rwkv6.modeling_rwkv6 import RWKV6ForImageClassification + +AutoConfig.register(RWKV6VisionConfig.model_type, RWKV6VisionConfig) +AutoModelForImageClassification.register(RWKV6VisionConfig, RWKV6ForImageClassification) + +__all__ = [ + 'RWKV6VisionConfig', + 'RWKV6ForImageClassification' +] diff --git a/fla/vision_models/rwkv6/configuration_rwkv6.py b/fla/vision_models/rwkv6/configuration_rwkv6.py new file mode 100644 index 000000000..3478c6d08 --- /dev/null +++ b/fla/vision_models/rwkv6/configuration_rwkv6.py @@ -0,0 +1,93 @@ +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class RWKV6VisionConfig(PretrainedConfig): + + model_type = 'rwkv6_vision' + + def __init__( + self, + # RWKV6 core parameters + attn_mode: str = "chunk", + hidden_size: int = 2048, + expand_k: int = 0.5, + expand_v: int = 1, + num_hidden_layers: int = 24, + num_heads: int = 4, + proj_low_rank_dim: int = 32, + gate_low_rank_dim: int = 64, + hidden_act: str = "sqrelu", + max_position_embeddings: int = 2048, + norm_first: bool = True, + norm_bias: bool = True, + norm_eps: float = 1e-5, + attn: Optional[Dict] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + # Vision specific parameters + image_size: int = 224, + patch_size: int = 
16, + num_channels: int = 3, + num_classes: int = 1000, + qkv_bias: bool = True, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + # FLA-for-vision-related parameters + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + ): + # Initialize RWKV6 core parameters + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.norm_first = norm_first + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.proj_low_rank_dim = proj_low_rank_dim + self.gate_low_rank_dim = gate_low_rank_dim + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.norm_bias = norm_bias + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_cross_entropy = fuse_cross_entropy + + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/vision_models/rwkv6/modeling_rwkv6.py b/fla/vision_models/rwkv6/modeling_rwkv6.py new file mode 100644 index 000000000..bd86d0d95 --- /dev/null +++ b/fla/vision_models/rwkv6/modeling_rwkv6.py @@ -0,0 +1,201 @@ +import collections.abc +import math +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from typing import Optional, Set, Tuple, Union, List, Dict, Unpack +from transformers.utils import logging +from fla.layers.attn import Attention +from transformers.modeling_outputs import ImageClassifierOutput +from transformers.modeling_utils import PreTrainedModel +from fla.layers.rwkv6 import RWKV6Attention +from .configuration_rwkv6 import RWKV6VisionConfig +from fla.models.utils import Cache +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge + +logger = logging.get_logger(__name__) + +class RWKV6MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class RWKV6Block(nn.Module): + def __init__(self, config, layer_idx: 
int): + super().__init__() + + if not config.norm_first: + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = RWKV6Attention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + proj_low_rank_dim=config.proj_low_rank_dim, + gate_low_rank_dim=config.gate_low_rank_dim, + norm_eps=config.norm_eps, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + + if not config.norm_first: + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = RWKV6MLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + # MLP + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class RWKV6VisionPreTrainedModel(PreTrainedModel): + # this part of the code is adapted from huggingface/transformers vit implementation + config_class = RWKV6VisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + +class RWKV6ForImageClassification(RWKV6VisionPreTrainedModel): + config_class = RWKV6VisionConfig + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + + self.embeddings = ImageEmbeddings(config) + self.blocks = nn.ModuleList([ + RWKV6Block(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = 
Pooler(config) + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.interpolate_pos_encoding = config.interpolate_pos_encoding + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) + + for block in self.blocks: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = self.norm(hidden_states) + pooled_output = self.pooler(hidden_states) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + (hidden_states,) + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + ) diff --git a/fla/vision_models/transformer/__init__.py b/fla/vision_models/transformer/__init__.py new file mode 100644 index 000000000..25d4e9d2b --- /dev/null +++ b/fla/vision_models/transformer/__init__.py @@ -0,0 +1,12 @@ +from transformers import AutoConfig, AutoModelForImageClassification + +from fla.vision_models.transformer.configuration_transformer import TransformerVisionConfig +from fla.vision_models.transformer.modeling_transformer import TransformerForImageClassification + +AutoConfig.register(TransformerVisionConfig.model_type, TransformerVisionConfig) +AutoModelForImageClassification.register(TransformerVisionConfig, TransformerForImageClassification) + +__all__ = [ + 'TransformerVisionConfig', + 'TransformerForImageClassification' +] diff --git a/fla/vision_models/transformer/configuration_transformer.py b/fla/vision_models/transformer/configuration_transformer.py new file mode 100644 index 000000000..cc8246270 --- /dev/null +++ b/fla/vision_models/transformer/configuration_transformer.py @@ -0,0 +1,81 @@ +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class TransformerVisionConfig(PretrainedConfig): + + model_type = 'transformer_vision' + + def __init__( + self, + # Transformer core parameters + hidden_size: int = 2048, + num_hidden_layers: int = 24, + num_heads: int = 32, + num_kv_heads: int = None, + window_size: Optional[int] = None, + rope_theta: Optional[float] = 10000., + max_position_embeddings: int = 2048, + hidden_act: str = "swish", + initializer_range: float = 0.02, + elementwise_affine: Optional[bool] = True, + norm_first: bool = False, + norm_eps: float = 1e-6, + use_cache: bool = True, + attention_bias: bool = False, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + # Vision specific parameters + image_size: int = 224, + patch_size: 
int = 16, + num_channels: int = 3, + num_classes: int = 1000, + qkv_bias: bool = True, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + # FLA-for-vision-related parameters + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + ): + # Initialize transformer core parameters + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.window_size = window_size + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.hidden_act = hidden_act + + self.initializer_range = initializer_range + self.elementwise_affine = elementwise_affine + self.norm_first = norm_first + self.norm_eps = norm_eps + self.use_cache = use_cache + self.attention_bias = attention_bias + self.fuse_cross_entropy = fuse_cross_entropy + self.fuse_norm = fuse_norm + + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/vision_models/transformer/modeling_transformer.py b/fla/vision_models/transformer/modeling_transformer.py new file mode 100644 index 000000000..6441293a5 --- /dev/null +++ b/fla/vision_models/transformer/modeling_transformer.py @@ -0,0 +1,190 @@ +import collections.abc +import math +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from typing import Optional, Set, Tuple, Union, List, Dict, Unpack +from transformers.utils import logging +from fla.layers.attn import Attention +from transformers.modeling_outputs import ImageClassifierOutput +from transformers.modeling_utils import PreTrainedModel +from .configuration_transformer import TransformerVisionConfig +from fla.models.utils import Cache +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge + +logger = logging.get_logger(__name__) + +class TransformerMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class TransformerBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + if not config.norm_first: + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + window_size=config.window_size, + rope_theta=config.rope_theta, + max_position_embeddings=config.max_position_embeddings, + norm_first=config.norm_first, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + + + 
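+        # ln_2 (like ln_1 above) is only created when config.norm_first is False;
+        # with norm_first=True the Attention module above receives norm_first and
+        # norm_eps and is assumed to apply its own normalization internally.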
if not config.norm_first: + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = TransformerMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + # MLP + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class TransformerVisionPreTrainedModel(PreTrainedModel): + # this part of the code is adapted from huggingface/transformers vit implementation + config_class = TransformerVisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + +class TransformerForImageClassification(TransformerVisionPreTrainedModel): + config_class = TransformerVisionConfig + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + + self.embeddings = ImageEmbeddings(config) + self.blocks = nn.ModuleList([ + TransformerBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.interpolate_pos_encoding = config.interpolate_pos_encoding + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + hidden_states = 
self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) + + for block in self.blocks: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = self.norm(hidden_states) + pooled_output = self.pooler(hidden_states) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + (hidden_states,) + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + ) diff --git a/fla/vision_models/utils.py b/fla/vision_models/utils.py index 6ba296022..246dcf931 100644 --- a/fla/vision_models/utils.py +++ b/fla/vision_models/utils.py @@ -433,9 +433,48 @@ def prepare_hidden_states_for_cross_merge(hidden_states: torch.Tensor, scan_type # check the implementation if __name__ == "__main__": - B, L, D = 2, 16, 2048 - hidden_states = torch.randn(B, L, D).cuda() - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, scan_type="cross-scan") - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, scan_type="cross-scan") - print(hidden_states.shape) - print("Cross scan applied successfully!") \ No newline at end of file + B, L, D = 1, 4, 3 + transformation = nn.Linear(D, D).cuda() + # firstly test bi-scan + print("Checking bi-scan") + h1 = torch.randn(B, L, D).cuda() + h2 = h1.clone().cuda() + h1 = prepare_hidden_states_for_cross_scan(h1, scan_type="bi-scan") + h1 = transformation(h1) + h1 = prepare_hidden_states_for_cross_merge(h1, scan_type="bi-scan") + h2_ = h2.clone().cuda() + h2_ = h2_.flip(-2) + h2 = transformation(h2) + h2_ = transformation(h2_) + h2 = h2 + h2_ + # check whether the two sequences are the same + print(f"h1: \n{h1}") + print(f"h2: \n{h2}") + print(f"""The two sequences are the same: {torch.allclose(h1, h2)}""") + # Then check cross-scan + print("checking cross-scan") + h1 = torch.randn(B, L, D).cuda() + h2 = h1.clone().cuda() + h1 = prepare_hidden_states_for_cross_scan(h1, scan_type="cross-scan") + h1 = transformation(h1) + h1 = prepare_hidden_states_for_cross_merge(h1, scan_type="cross-scan") + B, L, D = h2.shape + hw = int(math.sqrt(L)) + assert (hw * hw == L) # make sure L is a square + h2 = einops.rearrange(h2, "b (h w) d -> b h w d", h=hw, w=hw) # change the shape to feed to cross_scan + h2 = cross_scan_fn(h2, in_channel_first=False, out_channel_first=False, one_by_one=False, scans=0) + h2 = h2.permute(2, 0, 1, 3) + h2_0 = h2[0] + h2_1 = h2[1] + h2_2 = h2[2] + h2_3 = h2[3] + h2_0 = transformation(h2_0) + h2_1 = transformation(h2_1) + h2_2 = transformation(h2_2) + h2_3 = transformation(h2_3) + h2 = torch.cat([h2_0, h2_1, h2_2, h2_3], dim=0) + h2 = prepare_hidden_states_for_cross_merge(h2, scan_type="cross-scan") + # check whether the two sequences are the same + print(f"h1: \n{h1}") + print(f"h2: \n{h2}") + print(f"""The two sequences are the same: {torch.allclose(h1, h2)}""") \ No newline at end of file From 4188ebbbab3baaf8e692a3a01254f9048eb319a8 Mon Sep 17 00:00:00 2001 From: yibozhong Date: Fri, 17 Jan 2025 02:21:05 +0800 Subject: [PATCH 08/17] change script location --- training/classification.py | 419 
------------------------------------- 1 file changed, 419 deletions(-) delete mode 100644 training/classification.py diff --git a/training/classification.py b/training/classification.py deleted file mode 100644 index 88c160149..000000000 --- a/training/classification.py +++ /dev/null @@ -1,419 +0,0 @@ -import os -import torch -from tqdm import tqdm -import wandb -import logging -import random -import torch.optim as optim -from torch.utils.data import DataLoader -from torchvision import datasets, transforms -from transformers import get_scheduler -from torch.amp import GradScaler, autocast -from fla.vision_models.delta_net import DeltaNetVisionConfig, DeltaNetForImageClassification -import time - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -dtype = torch.bfloat16 # deafult dtype for FLA - -def setup_logging(args): - log_filename = f'training_{args.model}_vision_{args.dataset}.log' - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler(log_filename), - logging.StreamHandler() - ] - ) - logging.info(f"Logging to {log_filename}") - -def get_args(): - import argparse - parser = argparse.ArgumentParser(description='Vision Model Training') - parser.add_argument('--dataset', type=str, default='cifar100', help='Dataset name') - parser.add_argument('--num_hidden_layers', type=int, default=12, help='Number of hidden layers') - parser.add_argument('--hidden_size', type=int, default=768, help='Hidden size') - parser.add_argument('--patch_size', type=int, default=16, help='Patch size') - parser.add_argument('--image_size', type=int, default=224, help='Image size') - parser.add_argument('--epochs', type=int, default=50, help='Number of epochs') - parser.add_argument('--amp_enabled', action='store_true', help='Enable AMP if device supports it') - parser.add_argument('--b_lr', type=float, default=2e-4, help='Backbone learning rate') - parser.add_argument('--h_lr', type=float, default=2e-4, help='Head learning rate') - parser.add_argument('--wd', type=float, default=0., help='Weight decay') - parser.add_argument('--train_bs', type=int, default=128, help='Training batch size') - parser.add_argument('--eval_bs', type=int, default=256, help='Eval batch size') - parser.add_argument('--num_workers', type=int, default=4, help='Number of workers') - parser.add_argument('--num_heads', type=int, default=16, help='Number of attention heads') - parser.add_argument('--eval_epoch', type=int, default=1, help='Eval frequency') - parser.add_argument('--log_step', type=int, default=10, help='Log frequency') - parser.add_argument('--seed', type=int, default=42, help='Random seed') - parser.add_argument('--wandb', action='store_true', help='Enable wandb logging') - parser.add_argument('--expand_k', type=float, default=1.0, help='Key expansion ratio') - parser.add_argument('--expand_v', type=float, default=1.0, help='Value expansion ratio') - parser.add_argument('--attn_mode', type=str, default='chunk', choices=['chunk', 'fused_recurrent', 'fused_chunk']) - parser.add_argument('--pool_type', type=str, default='mean', choices=['mean', 'cls']) - parser.add_argument('--model', type=str, required=True, help='Model type (currently only supports "deltanet")') - parser.add_argument('--fuse_cross_entropy', action='store_true', help='Fuse cross entropy with logits') - - # Learning rate schedule related arguments - parser.add_argument('--lr_scheduler_type', type=str, default='constant_with_warmup', - choices=['linear', 'cosine', 
'cosine_with_restarts', 'polynomial', - 'constant', 'constant_with_warmup']) - parser.add_argument('--warmup_ratio', type=float, default=0.1, - help='Ratio of total training steps for warmup') - # Add hybrid attention related arguments - parser.add_argument('--use_attn', action='store_true', help='Use hybrid attention in some layers') - parser.add_argument('--attn_layers', type=str, default='0,1', - help='Comma separated list of layer indices to use attention, e.g. "0,1,2"') - parser.add_argument('--attn_num_heads', type=int, default=16, - help='Number of attention heads for hybrid attention layers') - parser.add_argument('--attn_num_kv_heads', type=int, default=None, - help='Number of key/value heads for hybrid attention layers') - parser.add_argument('--attn_window_size', type=int, default=None, - help='Window size for hybrid attention layers') - parser.add_argument('--log_memory_epoch', type=int, default=100, help='Log memory usage frequency') - return parser.parse_args() - -def get_data(args): - """ - Prepare data transforms and loaders. - Ensures consistent data types with model. - """ - transform = transforms.Compose([ - transforms.Resize((args.image_size, args.image_size)), - transforms.ToTensor(), - transforms.ConvertImageDtype(dtype), # Match model dtype - ]) - - if args.dataset == 'cifar10': - train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) - test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) - num_classes = 10 - elif args.dataset == 'cifar100': - train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform) - test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform) - num_classes = 100 - else: - raise ValueError(f"Unsupported dataset: {args.dataset}") - - train_loader = DataLoader(train_dataset, batch_size=args.train_bs, shuffle=True, num_workers=args.num_workers) - test_loader = DataLoader(test_dataset, batch_size=args.eval_bs, shuffle=False, num_workers=args.num_workers) - - return train_loader, test_loader, num_classes - -def setup_deterministic_mode(args): - """Setup deterministic mode for reproducibility""" - import numpy as np - np.random.seed(args.seed) - random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed(args.seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - -def get_gpu_memory_info(): - """ - Get current GPU memory usage information - Returns a dictionary with: - - memory_allocated: actively allocated memory - - memory_reserved: reserved memory in GPU - - max_memory_allocated: max allocated memory since the beginning - """ - return { - 'memory_allocated': torch.cuda.memory_allocated() / 1024**2, # MB - 'memory_reserved': torch.cuda.memory_reserved() / 1024**2, # MB - 'max_memory_allocated': torch.cuda.max_memory_allocated() / 1024**2 # MB - } - -def log_gpu_memory(args, epoch): - """Log GPU memory usage if CUDA is available""" - if torch.cuda.is_available() and epoch % args.log_memory_epoch == 0: - memory_info = get_gpu_memory_info() - logging.info( - f"GPU Memory Usage (Epoch {epoch}) - " - f"Allocated: {memory_info['memory_allocated']:.2f}MB, " - f"Reserved: {memory_info['memory_reserved']:.2f}MB, " - f"Peak: {memory_info['max_memory_allocated']:.2f}MB" - ) - if args.wandb: - wandb.log({ - "gpu_memory/allocated": memory_info['memory_allocated'], - "gpu_memory/reserved": memory_info['memory_reserved'], - "gpu_memory/peak": 
memory_info['max_memory_allocated'], - "epoch": epoch - }) - -def evaluate(model, test_loader, device, args): - """ - Evaluation loop with proper CUDA timing. - Uses CUDA events for accurate GPU timing and ensures proper synchronization. - """ - model.eval() - correct = 0 - total = 0 - - # Create CUDA events for timing - if torch.cuda.is_available(): - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start_event.record() - else: - start_time = time.perf_counter() - - with torch.no_grad(): - for images, targets in tqdm(test_loader): - images = images.to(device=device, dtype=dtype) - targets = targets.to(device) - - if args.amp_enabled: - with autocast(): - outputs = model(images).logits - _, predicted = outputs.max(1) - else: - outputs = model(images).logits - _, predicted = outputs.max(1) - - total += targets.size(0) - correct += predicted.eq(targets).sum().item() - - # Measure time with proper CUDA synchronization - if torch.cuda.is_available(): - end_event.record() - torch.cuda.synchronize() - eval_time = start_event.elapsed_time(end_event) / 1000.0 # Convert to seconds - else: - eval_time = time.perf_counter() - start_time - - accuracy = 100. * correct / total - return accuracy, eval_time - -def get_model(args, num_classes): - """ - Initialize model based on configuration. - Supports both pure DeltaNet and hybrid models. - """ - if args.model == 'deltanet': - # Prepare attention config for hybrid model if enabled - attn_config = None - if args.use_attn: - attn_config = { - 'layers': [int(i) for i in args.attn_layers.split(',')], - 'num_heads': args.attn_num_heads, - 'num_kv_heads': args.attn_num_kv_heads, - 'window_size': args.attn_window_size - } - # Log hybrid attention configuration - logging.info("Hybrid Attention Configuration:") - logging.info(f"- Attention Layers: {attn_config['layers']}") - logging.info(f"- Number of Heads: {attn_config['num_heads']}") - logging.info(f"- Number of KV Heads: {attn_config['num_kv_heads']}") - logging.info(f"- Window Size: {attn_config['window_size']}") - - config = DeltaNetVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - expand_k=args.expand_k, - expand_v=args.expand_v, - attn_mode=args.attn_mode, - pool_type=args.pool_type, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config # Add attention config for hybrid model - ) - return DeltaNetForImageClassification(config).to(device=device, dtype=dtype) - else: - raise NotImplementedError(f"Model {args.model} is not implemented yet.") - -def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, args, epoch): - """ - Training loop for one epoch with proper CUDA timing. - Uses CUDA events for accurate GPU timing and ensures proper synchronization. 
- """ - model.train() - total_loss = 0 - scaler = GradScaler() if args.amp_enabled else None - - # Create CUDA events for timing - if torch.cuda.is_available(): - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start_event.record() - else: - start_time = time.perf_counter() - - for i, (images, targets) in enumerate(tqdm(train_loader)): - images = images.to(device=device, dtype=dtype) - targets = targets.to(device) - - if args.amp_enabled: - with autocast(): - outputs = model(images).logits - loss = criterion(outputs, targets) - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - else: - outputs = model(images).logits - loss = criterion(outputs, targets) - loss.backward() - optimizer.step() - - optimizer.zero_grad() - scheduler.step() # Update learning rate scheduler - total_loss += loss.item() - - if i % args.log_step == 0: - lrs = [group['lr'] for group in optimizer.param_groups] - logging.info(f'Epoch {epoch} Step {i}/{len(train_loader)}: ' - f'Loss={loss.item():.4f} ' - f'LR_backbone={lrs[0]:.2e} ' - f'LR_head={lrs[-1]:.2e}') - - if args.wandb: - wandb.log({ - "batch_loss": loss.item(), - "learning_rate/backbone": lrs[0], - "learning_rate/head": lrs[-1], - "global_step": epoch * len(train_loader) + i - }) - - # Measure time with proper CUDA synchronization - if torch.cuda.is_available(): - end_event.record() - torch.cuda.synchronize() - train_time = start_event.elapsed_time(end_event) / 1000.0 - else: - train_time = time.perf_counter() - start_time - - avg_loss = total_loss / len(train_loader) - return avg_loss, train_time - -def main(): - args = get_args() - - # Setup logging first, before any logging calls - setup_logging(args) - - # Then setup deterministic mode - setup_deterministic_mode(args) - - # Log all configuration parameters - logging.info("=" * 50) - logging.info("Training Configuration:") - logging.info("-" * 50) - for arg, value in sorted(vars(args).items()): - logging.info(f"{arg}: {value}") - logging.info("=" * 50) - - # Setup wandb after logging is initialized - if args.wandb: - project_name = f"{args.model}_vision_classification" - run_name = f"e{args.epochs}_b_lr{args.b_lr}_h_lr_{args.h_lr}_mode{args.attn_mode}_bs{args.train_bs}_p{args.patch_size}_i{args.image_size}_h{args.num_heads}_{args.dataset}" - wandb.init( - project=project_name, - name=run_name, - config=args.__dict__ - ) - logging.info(f"Wandb initialized with project: {project_name}, run: {run_name}") - - train_loader, test_loader, num_classes = get_data(args) - - # Calculate total training steps - num_training_steps = len(train_loader) * args.epochs - num_warmup_steps = int(args.warmup_ratio * num_training_steps) - - model = get_model(args, num_classes) - trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - - logging.info("=" * 50) - logging.info("Model Information:") - logging.info("-" * 50) - logging.info(f"Model Type: {args.model}") - logging.info(f"Number of trainable parameters: {trainable_params:,}") - logging.info(f"Number of layers: {args.num_hidden_layers}") - logging.info(f"Hidden size: {args.hidden_size}") - logging.info(f"Number of heads: {args.num_heads}") - logging.info(f"Learning rate scheduler: {args.lr_scheduler_type}") - logging.info(f"Total training steps: {num_training_steps}") - logging.info(f"Warmup steps: {num_warmup_steps}") - logging.info("=" * 50) - - if args.wandb: - wandb.log({"trainable_parameters": trainable_params}) - - criterion = 
torch.nn.CrossEntropyLoss() - optimizer = optim.AdamW([ - {'params': model.embeddings.parameters(), 'lr': args.b_lr}, - {'params': model.blocks.parameters(), 'lr': args.b_lr}, - {'params': model.classifier.parameters(), 'lr': args.h_lr} - ], weight_decay=args.wd) - - scheduler = get_scheduler( - name=args.lr_scheduler_type, - optimizer=optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps - ) - - best_acc = 0 - total_train_time = 0 - total_eval_time = 0 - eval_num = 0 - - for epoch in range(args.epochs): - avg_loss, epoch_train_time = train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, args, epoch) - total_train_time += epoch_train_time - - # Log GPU memory usage - log_gpu_memory(args, epoch) - - if epoch % args.eval_epoch == 0: - accuracy, epoch_eval_time = evaluate(model, test_loader, device, args) - total_eval_time += epoch_eval_time - eval_num += 1 - - logging.info( - f'Epoch {epoch}: Loss={avg_loss:.4f}, Accuracy={accuracy:.2f}%, ' - f'Train time={epoch_train_time:.2f}s, Eval time={epoch_eval_time:.2f}s' - ) - - if args.wandb: - wandb.log({ - "epoch": epoch, - "train_loss": avg_loss, - "accuracy": accuracy, - "train_time": epoch_train_time, - "eval_time": epoch_eval_time, - "avg_epoch_train_time": total_train_time / (epoch + 1), - "avg_epoch_eval_time": total_eval_time / eval_num - }) - - if accuracy > best_acc: - best_acc = accuracy - torch.save(model.state_dict(), f'{args.model}_vision_best.pth') - - # Log final statistics - avg_train_time = total_train_time / args.epochs - avg_eval_time = total_eval_time / eval_num - logging.info( - f'Training completed. Best accuracy: {best_acc:.2f}%\n' - f'Average training time per epoch: {avg_train_time:.2f}s\n' - f'Average evaluation time: {avg_eval_time:.2f}s' - ) - - if args.wandb: - wandb.log({ - "final/best_accuracy": best_acc, - "final/avg_train_time": avg_train_time, - "final/avg_eval_time": avg_eval_time - }) - if args.wandb: - wandb.finish() - -if __name__ == '__main__': - main() From 16568d9c10f2e805cafaa201a79b23d114afc1f6 Mon Sep 17 00:00:00 2001 From: yibozhong Date: Fri, 17 Jan 2025 16:24:25 +0800 Subject: [PATCH 09/17] Test the implementations --- classification.py | 227 +++++++++++++++--- fla/layers/abc.py | 2 + fla/vision_models/abc/configuration_abc.py | 2 +- fla/vision_models/abc/modeling_abc.py | 6 +- fla/vision_models/bitnet/__init__.py | 2 +- .../bitnet/configuration_bitnet.py | 3 + .../delta_net/configuration_delta_net.py | 5 +- .../configuration_gated_deltanet.py | 2 + fla/vision_models/gla/configuration_gla.py | 8 +- fla/vision_models/gla/modeling_gla.py | 6 +- fla/vision_models/gsa/configuration_gsa.py | 3 +- fla/vision_models/hgrn/configuration_hgrn.py | 3 +- fla/vision_models/hgrn/modeling_hgrn.py | 6 +- .../hgrn2/configuration_hgrn2.py | 3 +- fla/vision_models/hgrn2/modeling_hgrn2.py | 6 +- .../linear_attn/configuration_linear_attn.py | 2 +- .../linear_attn/modeling_linear_attn.py | 18 +- .../retnet/configuration_retnet.py | 3 +- fla/vision_models/retnet/modeling_retnet.py | 6 +- .../rwkv6/configuration_rwkv6.py | 3 +- fla/vision_models/rwkv6/modeling_rwkv6.py | 6 +- 21 files changed, 246 insertions(+), 76 deletions(-) diff --git a/classification.py b/classification.py index 26343fd2f..07ad9be44 100644 --- a/classification.py +++ b/classification.py @@ -9,14 +9,28 @@ from torchvision import datasets, transforms from transformers import get_scheduler from torch.amp import GradScaler, autocast +from fla.vision_models.abc import ABCVisionConfig, 
ABCForImageClassification +from fla.vision_models.bitnet import BitNetVisionConfig, BitNetForImageClassification from fla.vision_models.delta_net import DeltaNetVisionConfig, DeltaNetForImageClassification +from fla.vision_models.gated_deltanet import GatedDeltaNetVisionConfig, GatedDeltaNetForImageClassification +from fla.vision_models.gla import GLAVisionConfig, GLAForImageClassification +from fla.vision_models.gsa import GSAVisionConfig, GSAForImageClassification +from fla.vision_models.hgrn import HGRNVisionConfig, HGRNForImageClassification +from fla.vision_models.hgrn2 import HGRN2VisionConfig, HGRN2ForImageClassification +from fla.vision_models.linear_attn import LinearAttentionVisionConfig, LinearAttentionForImageClassification +from fla.vision_models.retnet import RetNetVisionConfig, RetNetForImageClassification +from fla.vision_models.rwkv6 import RWKV6VisionConfig, RWKV6ForImageClassification +from fla.vision_models.transformer import TransformerVisionConfig, TransformerForImageClassification import time device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') dtype = torch.bfloat16 # deafult dtype for FLA def setup_logging(args): - log_filename = f'training_{args.model}_vision_{args.dataset}.log' + # check whether logs directory exists + if not os.path.exists('logs'): + os.makedirs('logs') + log_filename = f'logs/training_{args.model}_vision_{args.dataset}{"_hybrid" if args.use_attn else ""}_{args.scan_type}.log' logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', @@ -55,7 +69,7 @@ def get_args(): parser.add_argument('--fuse_cross_entropy', action='store_true', help='Fuse cross entropy with logits') parser.add_argument('--scan_type', type=str, default='uni-scan', choices=['uni-scan', 'bi-scan', 'cross-scan'],) - # Learning rate schedule related arguments + # Learning rate scheduler related arguments parser.add_argument('--lr_scheduler_type', type=str, default='constant_with_warmup', choices=['linear', 'cosine', 'cosine_with_restarts', 'polynomial', 'constant', 'constant_with_warmup']) @@ -65,6 +79,7 @@ def get_args(): parser.add_argument('--use_attn', action='store_true', help='Use hybrid attention in some layers') parser.add_argument('--attn_layers', type=str, default='0,1', help='Comma separated list of layer indices to use attention, e.g. "0,1,2"') + # Hybrid architecture related arguments parser.add_argument('--attn_num_heads', type=int, default=16, help='Number of attention heads for hybrid attention layers') parser.add_argument('--attn_num_kv_heads', type=int, default=None, @@ -78,11 +93,12 @@ def get_data(args): """ Prepare data transforms and loaders. Ensures consistent data types with model. + Current suppport only training with CIFAR-10 and CIFAR-100. """ transform = transforms.Compose([ transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor(), - transforms.ConvertImageDtype(dtype), # Match model dtype + transforms.ConvertImageDtype(dtype), ]) if args.dataset == 'cifar10': @@ -102,7 +118,7 @@ def get_data(args): return train_loader, test_loader, num_classes def setup_deterministic_mode(args): - """Setup deterministic mode for reproducibility""" + """Setup deterministic mode for reproducibility on the same device""" import numpy as np np.random.seed(args.seed) random.seed(args.seed) @@ -146,7 +162,6 @@ def log_gpu_memory(args, epoch): def evaluate(model, test_loader, device, args): """ Evaluation loop with proper CUDA timing. 
- Uses CUDA events for accurate GPU timing and ensures proper synchronization. """ model.eval() correct = 0 @@ -193,23 +208,23 @@ def get_model(args, num_classes): Initialize model based on configuration. Supports both pure DeltaNet and hybrid models. """ - if args.model == 'deltanet': - # Prepare attention config for hybrid model if enabled - attn_config = None - if args.use_attn: - attn_config = { - 'layers': [int(i) for i in args.attn_layers.split(',')], - 'num_heads': args.attn_num_heads, - 'num_kv_heads': args.attn_num_kv_heads, - 'window_size': args.attn_window_size - } - # Log hybrid attention configuration - logging.info("Hybrid Attention Configuration:") - logging.info(f"- Attention Layers: {attn_config['layers']}") - logging.info(f"- Number of Heads: {attn_config['num_heads']}") - logging.info(f"- Number of KV Heads: {attn_config['num_kv_heads']}") - logging.info(f"- Window Size: {attn_config['window_size']}") + # Prepare attention config for hybrid model if enabled + attn_config = None + if args.use_attn: + attn_config = { + 'layers': [int(i) for i in args.attn_layers.split(',')], + 'num_heads': args.attn_num_heads, + 'num_kv_heads': args.attn_num_kv_heads, + 'window_size': args.attn_window_size + } + # Log hybrid attention configuration + logging.info("Hybrid Attention Configuration:") + logging.info(f"- Attention Layers: {attn_config['layers']}") + logging.info(f"- Number of Heads: {attn_config['num_heads']}") + logging.info(f"- Number of KV Heads: {attn_config['num_kv_heads']}") + logging.info(f"- Window Size: {attn_config['window_size']}") + if args.model == 'deltanet': config = DeltaNetVisionConfig( num_hidden_layers=args.num_hidden_layers, hidden_size=args.hidden_size, @@ -217,21 +232,177 @@ def get_model(args, num_classes): patch_size=args.patch_size, image_size=args.image_size, num_classes=num_classes, - expand_k=args.expand_k, - expand_v=args.expand_v, attn_mode=args.attn_mode, fuse_cross_entropy=args.fuse_cross_entropy, attn=attn_config, # Add attention config for hybrid model scan_type=args.scan_type # Add scan type to choose different scaning strategy ) return DeltaNetForImageClassification(config).to(device=device, dtype=dtype) - else: - raise NotImplementedError(f"Model {args.model} is not implemented yet.") + + elif args.model == 'abc': + config = ABCVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return ABCForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'gated_deltanet': + config = GatedDeltaNetVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return GatedDeltaNetForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'bitnet': + config = BitNetVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, 
+ patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return BitNetForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'gla': + config = GLAVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return GLAForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'gsa': + config = GSAVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return GSAForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'hgrn': + config = HGRNVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return HGRNForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'hgrn2': + config = HGRN2VisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return HGRN2ForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'linear_attn': + config = LinearAttentionVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return LinearAttentionForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'retnet': + config = RetNetVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan 
type to choose different scaning strategy + ) + return RetNetForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'rwkv6': + config = RWKV6VisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return RWKV6ForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'transformer': + config = TransformerVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes + ) + return TransformerForImageClassification(config).to(device=device, dtype=dtype) def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, args, epoch): """ Training loop for one epoch with proper CUDA timing. - Uses CUDA events for accurate GPU timing and ensures proper synchronization. """ model.train() total_loss = 0 @@ -312,8 +483,8 @@ def main(): # Setup wandb after logging is initialized if args.wandb: - project_name = f"{args.model}_vision_classification" - run_name = f"e{args.epochs}_b_lr{args.b_lr}_h_lr_{args.h_lr}_mode{args.attn_mode}_bs{args.train_bs}_p{args.patch_size}_i{args.image_size}_h{args.num_heads}_{args.dataset}" + project_name = "fla-vision" + run_name = f'training_{args.model}_{args.dataset}{"_hybrid" if args.use_attn else ""}_{args.scan_type}_e{args.epochs}_blr_{args.b_lr}_hlr_{args.h_lr}_bs{args.train_bs}_mode_{args.attn_mode}' wandb.init( project=project_name, name=run_name, diff --git a/fla/layers/abc.py b/fla/layers/abc.py index 1db5d94fe..9676d3731 100644 --- a/fla/layers/abc.py +++ b/fla/layers/abc.py @@ -38,6 +38,7 @@ def __init__( use_input_gate: bool = False, use_output_gate: bool = True, use_norm: bool = True, + use_rope: bool = False, # FIXME clamp_min: Optional[float] = -32, clamp_max: Optional[float] = 32, layer_idx: Optional[int] = None, @@ -64,6 +65,7 @@ def __init__( self.use_input_gate = use_input_gate self.use_output_gate = use_output_gate self.use_norm = use_norm + self.use_rope = use_rope # FIXME if num_slots is None: num_slots = self.head_k_dim diff --git a/fla/vision_models/abc/configuration_abc.py b/fla/vision_models/abc/configuration_abc.py index 6a7c2fa95..13de13c09 100644 --- a/fla/vision_models/abc/configuration_abc.py +++ b/fla/vision_models/abc/configuration_abc.py @@ -61,7 +61,6 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.elementwise_affine = elementwise_affine self.norm_eps = norm_eps - self.attn = attn self.use_cache = use_cache self.initializer_range = initializer_range self.fuse_norm = fuse_norm @@ -89,6 +88,7 @@ def __init__( attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) attn['window_size'] = attn.get('window_size', None) + self.attn = attn if mlp_dim is None: self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size else: diff --git a/fla/vision_models/abc/modeling_abc.py b/fla/vision_models/abc/modeling_abc.py index 9fa33230a..ecba55a35 100644 --- a/fla/vision_models/abc/modeling_abc.py +++ b/fla/vision_models/abc/modeling_abc.py @@ -34,8 +34,7 @@ class ABCBlock(nn.Module): def __init__(self, config, layer_idx: int): 
super().__init__() - if not config.norm_first: - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.attn is not None and layer_idx in config.attn['layers']: self.attn = Attention( @@ -64,8 +63,7 @@ def __init__(self, config, layer_idx: int): layer_idx=layer_idx ) - if not config.norm_first: - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = ABCMLP(config) diff --git a/fla/vision_models/bitnet/__init__.py b/fla/vision_models/bitnet/__init__.py index 8f372bc7c..148b4557f 100644 --- a/fla/vision_models/bitnet/__init__.py +++ b/fla/vision_models/bitnet/__init__.py @@ -3,7 +3,7 @@ from fla.vision_models.bitnet.configuration_bitnet import BitNetVisionConfig from fla.vision_models.bitnet.modeling_bitnet import BitNetForImageClassification -AutoConfig.register(BitNetVisionConfig, BitNetVisionConfig) +AutoConfig.register(BitNetVisionConfig.model_type, BitNetVisionConfig) AutoModelForImageClassification.register(BitNetVisionConfig, BitNetForImageClassification) __all__ = [ diff --git a/fla/vision_models/bitnet/configuration_bitnet.py b/fla/vision_models/bitnet/configuration_bitnet.py index 37a51b925..902f3f9a3 100644 --- a/fla/vision_models/bitnet/configuration_bitnet.py +++ b/fla/vision_models/bitnet/configuration_bitnet.py @@ -74,6 +74,7 @@ def __init__( self.interpolate_pos_encoding = interpolate_pos_encoding self.scan_type = scan_type + if attn is not None: if not isinstance(attn, Dict): raise ValueError("attn must be a dictionary") @@ -84,6 +85,8 @@ def __init__( attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) attn['window_size'] = attn.get('window_size', None) + self.attn = attn + if mlp_dim is None: self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size else: diff --git a/fla/vision_models/delta_net/configuration_delta_net.py b/fla/vision_models/delta_net/configuration_delta_net.py index d490a37f7..c6921a0fe 100644 --- a/fla/vision_models/delta_net/configuration_delta_net.py +++ b/fla/vision_models/delta_net/configuration_delta_net.py @@ -65,7 +65,6 @@ def __init__( self.use_cache = use_cache self.initializer_range = initializer_range self.fuse_cross_entropy = fuse_cross_entropy - self.attn = attn self.max_position_embeddings = max_position_embeddings # Initialize vision specific parameters @@ -88,7 +87,9 @@ def __init__( raise ValueError("Number of heads must be provided to initialize hybrid attention layers") attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) attn['window_size'] = attn.get('window_size', None) - + + self.attn = attn + if mlp_dim is None: self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size else: diff --git a/fla/vision_models/gated_deltanet/configuration_gated_deltanet.py b/fla/vision_models/gated_deltanet/configuration_gated_deltanet.py index fe472f257..6cbbd9e72 100644 --- a/fla/vision_models/gated_deltanet/configuration_gated_deltanet.py +++ b/fla/vision_models/gated_deltanet/configuration_gated_deltanet.py @@ -79,6 +79,8 @@ def __init__( attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) attn['window_size'] = attn.get('window_size', None) + self.attn = attn + if mlp_dim is None: self.mlp_dim = 4 * hidden_size else: diff --git a/fla/vision_models/gla/configuration_gla.py b/fla/vision_models/gla/configuration_gla.py index 77d750f90..af52bbe6f 100644 --- 
a/fla/vision_models/gla/configuration_gla.py +++ b/fla/vision_models/gla/configuration_gla.py @@ -65,7 +65,6 @@ def __init__( self.norm_eps = norm_eps self.use_gk = use_gk self.use_gv = use_gv - self.attn = attn self.use_cache = use_cache self.initializer_range = initializer_range self.fuse_norm = fuse_norm @@ -92,4 +91,11 @@ def __init__( attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) attn['window_size'] = attn.get('window_size', None) + self.attn = attn + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size + else: + self.mlp_dim = mlp_dim + super().__init__(**kwargs) diff --git a/fla/vision_models/gla/modeling_gla.py b/fla/vision_models/gla/modeling_gla.py index 433bfb09d..311e000a0 100644 --- a/fla/vision_models/gla/modeling_gla.py +++ b/fla/vision_models/gla/modeling_gla.py @@ -34,8 +34,7 @@ class GLABlock(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() - if not config.norm_first: - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.attn is not None and layer_idx in config.attn['layers']: self.attn = Attention( @@ -66,8 +65,7 @@ def __init__(self, config, layer_idx: int): layer_idx=layer_idx ) - if not config.norm_first: - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = GLAMLP(config) diff --git a/fla/vision_models/gsa/configuration_gsa.py b/fla/vision_models/gsa/configuration_gsa.py index deca79dce..de4bbcb8d 100644 --- a/fla/vision_models/gsa/configuration_gsa.py +++ b/fla/vision_models/gsa/configuration_gsa.py @@ -70,7 +70,6 @@ def __init__( self.elementwise_affine = elementwise_affine self.norm_first = norm_first self.norm_eps = norm_eps - self.attn = attn self.use_cache = use_cache self.initializer_range = initializer_range self.fuse_cross_entropy = fuse_cross_entropy @@ -98,6 +97,8 @@ def __init__( attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) attn['window_size'] = attn.get('window_size', None) + self.attn = attn + if mlp_dim is None: self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size else: diff --git a/fla/vision_models/hgrn/configuration_hgrn.py b/fla/vision_models/hgrn/configuration_hgrn.py index de5aae00b..e9724239b 100644 --- a/fla/vision_models/hgrn/configuration_hgrn.py +++ b/fla/vision_models/hgrn/configuration_hgrn.py @@ -50,7 +50,6 @@ def __init__( self.use_lower_bound = use_lower_bound self.max_position_embeddings = max_position_embeddings self.elementwise_affine = elementwise_affine - self.attn = attn self.norm_eps = norm_eps self.hidden_act = hidden_act self.use_cache = use_cache @@ -77,6 +76,8 @@ def __init__( attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) attn['window_size'] = attn.get('window_size', None) + self.attn = attn + if mlp_dim is None: self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size else: diff --git a/fla/vision_models/hgrn/modeling_hgrn.py b/fla/vision_models/hgrn/modeling_hgrn.py index 8d591cbc6..35d6e21bf 100644 --- a/fla/vision_models/hgrn/modeling_hgrn.py +++ b/fla/vision_models/hgrn/modeling_hgrn.py @@ -34,8 +34,7 @@ class HGRNBlock(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() - if not config.norm_first: - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.attn is not None and 
layer_idx in config.attn['layers']: self.attn = Attention( @@ -58,8 +57,7 @@ def __init__(self, config, layer_idx: int): layer_idx=layer_idx ) - if not config.norm_first: - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = HGRNMLP(config) diff --git a/fla/vision_models/hgrn2/configuration_hgrn2.py b/fla/vision_models/hgrn2/configuration_hgrn2.py index e8e5df182..ef6ffc83b 100644 --- a/fla/vision_models/hgrn2/configuration_hgrn2.py +++ b/fla/vision_models/hgrn2/configuration_hgrn2.py @@ -54,7 +54,6 @@ def __init__( self.hidden_act = hidden_act self.elementwise_affine = elementwise_affine self.norm_eps = norm_eps - self.attn = attn self.use_cache = use_cache self.initializer_range = initializer_range self.fuse_cross_entropy = fuse_cross_entropy @@ -80,6 +79,8 @@ def __init__( attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) attn['window_size'] = attn.get('window_size', None) + self.attn = attn + if mlp_dim is None: self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size else: diff --git a/fla/vision_models/hgrn2/modeling_hgrn2.py b/fla/vision_models/hgrn2/modeling_hgrn2.py index 3284d1b76..cbae1a64a 100644 --- a/fla/vision_models/hgrn2/modeling_hgrn2.py +++ b/fla/vision_models/hgrn2/modeling_hgrn2.py @@ -34,8 +34,7 @@ class HGRN2Block(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() - if not config.norm_first: - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.attn is not None and layer_idx in config.attn['layers']: self.attn = Attention( @@ -59,8 +58,7 @@ def __init__(self, config, layer_idx: int): layer_idx=layer_idx ) - if not config.norm_first: - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = HGRN2MLP(config) diff --git a/fla/vision_models/linear_attn/configuration_linear_attn.py b/fla/vision_models/linear_attn/configuration_linear_attn.py index d05e3c0ba..8aa0d2ef0 100644 --- a/fla/vision_models/linear_attn/configuration_linear_attn.py +++ b/fla/vision_models/linear_attn/configuration_linear_attn.py @@ -61,7 +61,6 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.elementwise_affine = elementwise_affine self.norm_eps = norm_eps - self.attn = attn self.use_cache = use_cache self.initializer_range = initializer_range self.fuse_cross_entropy = fuse_cross_entropy @@ -88,6 +87,7 @@ def __init__( attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) attn['window_size'] = attn.get('window_size', None) + self.attn = attn if mlp_dim is None: self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size else: diff --git a/fla/vision_models/linear_attn/modeling_linear_attn.py b/fla/vision_models/linear_attn/modeling_linear_attn.py index 2cd01fb2b..f0889a493 100644 --- a/fla/vision_models/linear_attn/modeling_linear_attn.py +++ b/fla/vision_models/linear_attn/modeling_linear_attn.py @@ -34,8 +34,7 @@ class LinearAttentionBlock(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() - if not config.norm_first: - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.attn is not None and layer_idx in config.attn['layers']: self.attn = Attention( @@ -64,8 +63,7 
@@ def __init__(self, config, layer_idx: int): layer_idx=layer_idx ) - if not config.norm_first: - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = LinearAttentionMLP(config) @@ -89,13 +87,7 @@ def forward( hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) + hidden_states = self.attn(hidden_states) hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) @@ -113,7 +105,7 @@ def forward( # Second residual connection hidden_states = residual + hidden_states - outputs = (hidden_states, attentions, past_key_values) + outputs = (hidden_states,) return outputs @@ -172,7 +164,7 @@ def forward( hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) for block in self.blocks: - hidden_states, attentions, past_key_values = block( + hidden_states = block( hidden_states, past_key_values=past_key_values, use_cache=use_cache, diff --git a/fla/vision_models/retnet/configuration_retnet.py b/fla/vision_models/retnet/configuration_retnet.py index 4f27d5531..53df13698 100644 --- a/fla/vision_models/retnet/configuration_retnet.py +++ b/fla/vision_models/retnet/configuration_retnet.py @@ -64,7 +64,6 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.elementwise_affine = elementwise_affine self.norm_eps = norm_eps - self.attn = attn self.use_cache = use_cache self.initializer_range = initializer_range self.fuse_norm = fuse_norm @@ -92,6 +91,8 @@ def __init__( attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) attn['window_size'] = attn.get('window_size', None) + self.attn = attn + if mlp_dim is None: self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size else: diff --git a/fla/vision_models/retnet/modeling_retnet.py b/fla/vision_models/retnet/modeling_retnet.py index d7918696c..961ea7c71 100644 --- a/fla/vision_models/retnet/modeling_retnet.py +++ b/fla/vision_models/retnet/modeling_retnet.py @@ -34,8 +34,7 @@ class RetNetBlock(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() - if not config.norm_first: - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.attn is not None and layer_idx in config.attn['layers']: self.attn = Attention( @@ -63,8 +62,7 @@ def __init__(self, config, layer_idx: int): layer_idx=layer_idx ) - if not config.norm_first: - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = RetNetMLP(config) diff --git a/fla/vision_models/rwkv6/configuration_rwkv6.py b/fla/vision_models/rwkv6/configuration_rwkv6.py index 3478c6d08..9e4d54cd0 100644 --- a/fla/vision_models/rwkv6/configuration_rwkv6.py +++ b/fla/vision_models/rwkv6/configuration_rwkv6.py @@ -57,7 +57,6 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.norm_bias = norm_bias self.norm_eps = norm_eps - self.attn = attn self.use_cache = use_cache self.initializer_range = initializer_range self.fuse_norm = fuse_norm @@ -85,6 +84,8 @@ def __init__( attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) 
attn['window_size'] = attn.get('window_size', None) + self.attn = attn + if mlp_dim is None: self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size else: diff --git a/fla/vision_models/rwkv6/modeling_rwkv6.py b/fla/vision_models/rwkv6/modeling_rwkv6.py index bd86d0d95..45c4df011 100644 --- a/fla/vision_models/rwkv6/modeling_rwkv6.py +++ b/fla/vision_models/rwkv6/modeling_rwkv6.py @@ -34,8 +34,7 @@ class RWKV6Block(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() - if not config.norm_first: - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.attn is not None and layer_idx in config.attn['layers']: self.attn = Attention( @@ -60,8 +59,7 @@ def __init__(self, config, layer_idx: int): layer_idx=layer_idx ) - if not config.norm_first: - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = RWKV6MLP(config) From b2db8d041c210fe709b64960b164b03ae3d3d59e Mon Sep 17 00:00:00 2001 From: yibozhong Date: Fri, 17 Jan 2025 19:31:28 +0800 Subject: [PATCH 10/17] change script position --- classification.py | 590 ---------------------------------------------- 1 file changed, 590 deletions(-) delete mode 100644 classification.py diff --git a/classification.py b/classification.py deleted file mode 100644 index 07ad9be44..000000000 --- a/classification.py +++ /dev/null @@ -1,590 +0,0 @@ -import os -import torch -from tqdm import tqdm -import wandb -import logging -import random -import torch.optim as optim -from torch.utils.data import DataLoader -from torchvision import datasets, transforms -from transformers import get_scheduler -from torch.amp import GradScaler, autocast -from fla.vision_models.abc import ABCVisionConfig, ABCForImageClassification -from fla.vision_models.bitnet import BitNetVisionConfig, BitNetForImageClassification -from fla.vision_models.delta_net import DeltaNetVisionConfig, DeltaNetForImageClassification -from fla.vision_models.gated_deltanet import GatedDeltaNetVisionConfig, GatedDeltaNetForImageClassification -from fla.vision_models.gla import GLAVisionConfig, GLAForImageClassification -from fla.vision_models.gsa import GSAVisionConfig, GSAForImageClassification -from fla.vision_models.hgrn import HGRNVisionConfig, HGRNForImageClassification -from fla.vision_models.hgrn2 import HGRN2VisionConfig, HGRN2ForImageClassification -from fla.vision_models.linear_attn import LinearAttentionVisionConfig, LinearAttentionForImageClassification -from fla.vision_models.retnet import RetNetVisionConfig, RetNetForImageClassification -from fla.vision_models.rwkv6 import RWKV6VisionConfig, RWKV6ForImageClassification -from fla.vision_models.transformer import TransformerVisionConfig, TransformerForImageClassification -import time - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -dtype = torch.bfloat16 # deafult dtype for FLA - -def setup_logging(args): - # check whether logs directory exists - if not os.path.exists('logs'): - os.makedirs('logs') - log_filename = f'logs/training_{args.model}_vision_{args.dataset}{"_hybrid" if args.use_attn else ""}_{args.scan_type}.log' - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler(log_filename), - logging.StreamHandler() - ] - ) - logging.info(f"Logging to {log_filename}") - -def get_args(): - import argparse - 
parser = argparse.ArgumentParser(description='Vision Model Training') - parser.add_argument('--dataset', type=str, default='cifar100', help='Dataset name') - parser.add_argument('--num_hidden_layers', type=int, default=12, help='Number of hidden layers') - parser.add_argument('--hidden_size', type=int, default=768, help='Hidden size') - parser.add_argument('--patch_size', type=int, default=16, help='Patch size') - parser.add_argument('--image_size', type=int, default=224, help='Image size') - parser.add_argument('--epochs', type=int, default=50, help='Number of epochs') - parser.add_argument('--amp_enabled', action='store_true', help='Enable AMP if device supports it') - parser.add_argument('--b_lr', type=float, default=2e-4, help='Backbone learning rate') - parser.add_argument('--h_lr', type=float, default=2e-4, help='Head learning rate') - parser.add_argument('--wd', type=float, default=0., help='Weight decay') - parser.add_argument('--train_bs', type=int, default=128, help='Training batch size') - parser.add_argument('--eval_bs', type=int, default=256, help='Eval batch size') - parser.add_argument('--num_workers', type=int, default=4, help='Number of workers') - parser.add_argument('--num_heads', type=int, default=16, help='Number of attention heads') - parser.add_argument('--eval_epoch', type=int, default=1, help='Eval frequency') - parser.add_argument('--log_step', type=int, default=10, help='Log frequency') - parser.add_argument('--seed', type=int, default=42, help='Random seed') - parser.add_argument('--wandb', action='store_true', help='Enable wandb logging') - parser.add_argument('--expand_k', type=float, default=1.0, help='Key expansion ratio') - parser.add_argument('--expand_v', type=float, default=1.0, help='Value expansion ratio') - parser.add_argument('--attn_mode', type=str, default='chunk', choices=['chunk', 'fused_recurrent', 'fused_chunk']) - parser.add_argument('--model', type=str, required=True, help='Model type (currently only supports "deltanet")') - parser.add_argument('--fuse_cross_entropy', action='store_true', help='Fuse cross entropy with logits') - parser.add_argument('--scan_type', type=str, default='uni-scan', choices=['uni-scan', 'bi-scan', 'cross-scan'],) - - # Learning rate scheduler related arguments - parser.add_argument('--lr_scheduler_type', type=str, default='constant_with_warmup', - choices=['linear', 'cosine', 'cosine_with_restarts', 'polynomial', - 'constant', 'constant_with_warmup']) - parser.add_argument('--warmup_ratio', type=float, default=0.1, - help='Ratio of total training steps for warmup') - # Add hybrid attention related arguments - parser.add_argument('--use_attn', action='store_true', help='Use hybrid attention in some layers') - parser.add_argument('--attn_layers', type=str, default='0,1', - help='Comma separated list of layer indices to use attention, e.g. "0,1,2"') - # Hybrid architecture related arguments - parser.add_argument('--attn_num_heads', type=int, default=16, - help='Number of attention heads for hybrid attention layers') - parser.add_argument('--attn_num_kv_heads', type=int, default=None, - help='Number of key/value heads for hybrid attention layers') - parser.add_argument('--attn_window_size', type=int, default=None, - help='Window size for hybrid attention layers') - parser.add_argument('--log_memory_epoch', type=int, default=100, help='Log memory usage frequency') - return parser.parse_args() - -def get_data(args): - """ - Prepare data transforms and loaders. - Ensures consistent data types with model. 
- Current suppport only training with CIFAR-10 and CIFAR-100. - """ - transform = transforms.Compose([ - transforms.Resize((args.image_size, args.image_size)), - transforms.ToTensor(), - transforms.ConvertImageDtype(dtype), - ]) - - if args.dataset == 'cifar10': - train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) - test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) - num_classes = 10 - elif args.dataset == 'cifar100': - train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform) - test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform) - num_classes = 100 - else: - raise ValueError(f"Unsupported dataset: {args.dataset}") - - train_loader = DataLoader(train_dataset, batch_size=args.train_bs, shuffle=True, num_workers=args.num_workers) - test_loader = DataLoader(test_dataset, batch_size=args.eval_bs, shuffle=False, num_workers=args.num_workers) - - return train_loader, test_loader, num_classes - -def setup_deterministic_mode(args): - """Setup deterministic mode for reproducibility on the same device""" - import numpy as np - np.random.seed(args.seed) - random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed(args.seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - -def get_gpu_memory_info(): - """ - Get current GPU memory usage information - Returns a dictionary with: - - memory_allocated: actively allocated memory - - memory_reserved: reserved memory in GPU - - max_memory_allocated: max allocated memory since the beginning - """ - return { - 'memory_allocated': torch.cuda.memory_allocated() / 1024**2, # MB - 'memory_reserved': torch.cuda.memory_reserved() / 1024**2, # MB - 'max_memory_allocated': torch.cuda.max_memory_allocated() / 1024**2 # MB - } - -def log_gpu_memory(args, epoch): - """Log GPU memory usage if CUDA is available""" - if torch.cuda.is_available() and epoch % args.log_memory_epoch == 0: - memory_info = get_gpu_memory_info() - logging.info( - f"GPU Memory Usage (Epoch {epoch}) - " - f"Allocated: {memory_info['memory_allocated']:.2f}MB, " - f"Reserved: {memory_info['memory_reserved']:.2f}MB, " - f"Peak: {memory_info['max_memory_allocated']:.2f}MB" - ) - if args.wandb: - wandb.log({ - "gpu_memory/allocated": memory_info['memory_allocated'], - "gpu_memory/reserved": memory_info['memory_reserved'], - "gpu_memory/peak": memory_info['max_memory_allocated'], - "epoch": epoch - }) - -def evaluate(model, test_loader, device, args): - """ - Evaluation loop with proper CUDA timing. 
- """ - model.eval() - correct = 0 - total = 0 - - # Create CUDA events for timing - if torch.cuda.is_available(): - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start_event.record() - else: - start_time = time.perf_counter() - - with torch.no_grad(): - for images, targets in tqdm(test_loader): - images = images.to(device=device, dtype=dtype) - targets = targets.to(device) - - if args.amp_enabled: - with autocast(): - outputs = model(images).logits - _, predicted = outputs.max(1) - else: - outputs = model(images).logits - _, predicted = outputs.max(1) - - total += targets.size(0) - correct += predicted.eq(targets).sum().item() - - # Measure time with proper CUDA synchronization - if torch.cuda.is_available(): - end_event.record() - torch.cuda.synchronize() - eval_time = start_event.elapsed_time(end_event) / 1000.0 # Convert to seconds - else: - eval_time = time.perf_counter() - start_time - - accuracy = 100. * correct / total - return accuracy, eval_time - -def get_model(args, num_classes): - """ - Initialize model based on configuration. - Supports both pure DeltaNet and hybrid models. - """ - # Prepare attention config for hybrid model if enabled - attn_config = None - if args.use_attn: - attn_config = { - 'layers': [int(i) for i in args.attn_layers.split(',')], - 'num_heads': args.attn_num_heads, - 'num_kv_heads': args.attn_num_kv_heads, - 'window_size': args.attn_window_size - } - # Log hybrid attention configuration - logging.info("Hybrid Attention Configuration:") - logging.info(f"- Attention Layers: {attn_config['layers']}") - logging.info(f"- Number of Heads: {attn_config['num_heads']}") - logging.info(f"- Number of KV Heads: {attn_config['num_kv_heads']}") - logging.info(f"- Window Size: {attn_config['window_size']}") - - if args.model == 'deltanet': - config = DeltaNetVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return DeltaNetForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'abc': - config = ABCVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return ABCForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'gated_deltanet': - config = GatedDeltaNetVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return GatedDeltaNetForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'bitnet': - config = 
BitNetVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return BitNetForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'gla': - config = GLAVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return GLAForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'gsa': - config = GSAVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return GSAForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'hgrn': - config = HGRNVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return HGRNForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'hgrn2': - config = HGRN2VisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return HGRN2ForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'linear_attn': - config = LinearAttentionVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return LinearAttentionForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'retnet': - config = RetNetVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - 
fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return RetNetForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'rwkv6': - config = RWKV6VisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return RWKV6ForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'transformer': - config = TransformerVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes - ) - return TransformerForImageClassification(config).to(device=device, dtype=dtype) - -def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, args, epoch): - """ - Training loop for one epoch with proper CUDA timing. - """ - model.train() - total_loss = 0 - scaler = GradScaler() if args.amp_enabled else None - - # Create CUDA events for timing - if torch.cuda.is_available(): - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start_event.record() - else: - start_time = time.perf_counter() - - for i, (images, targets) in enumerate(tqdm(train_loader)): - images = images.to(device=device, dtype=dtype) - targets = targets.to(device) - - if args.amp_enabled: - with autocast(): - outputs = model(images).logits - loss = criterion(outputs, targets) - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - else: - outputs = model(images).logits - loss = criterion(outputs, targets) - loss.backward() - optimizer.step() - - optimizer.zero_grad() - scheduler.step() # Update learning rate scheduler - total_loss += loss.item() - - if i % args.log_step == 0: - lrs = [group['lr'] for group in optimizer.param_groups] - logging.info(f'Epoch {epoch} Step {i}/{len(train_loader)}: ' - f'Loss={loss.item():.4f} ' - f'LR_backbone={lrs[0]:.2e} ' - f'LR_head={lrs[-1]:.2e}') - - if args.wandb: - wandb.log({ - "batch_loss": loss.item(), - "learning_rate/backbone": lrs[0], - "learning_rate/head": lrs[-1], - "global_step": epoch * len(train_loader) + i - }) - - # Measure time with proper CUDA synchronization - if torch.cuda.is_available(): - end_event.record() - torch.cuda.synchronize() - train_time = start_event.elapsed_time(end_event) / 1000.0 - else: - train_time = time.perf_counter() - start_time - - avg_loss = total_loss / len(train_loader) - return avg_loss, train_time - -def main(): - args = get_args() - - # Setup logging first, before any logging calls - setup_logging(args) - - # Then setup deterministic mode - setup_deterministic_mode(args) - - # Log all configuration parameters - logging.info("=" * 50) - logging.info("Training Configuration:") - logging.info("-" * 50) - for arg, value in sorted(vars(args).items()): - logging.info(f"{arg}: {value}") - logging.info("=" * 50) - - # Setup wandb after logging is initialized - if args.wandb: - project_name = "fla-vision" - run_name = f'training_{args.model}_{args.dataset}{"_hybrid" if 
args.use_attn else ""}_{args.scan_type}_e{args.epochs}_blr_{args.b_lr}_hlr_{args.h_lr}_bs{args.train_bs}_mode_{args.attn_mode}' - wandb.init( - project=project_name, - name=run_name, - config=args.__dict__ - ) - logging.info(f"Wandb initialized with project: {project_name}, run: {run_name}") - - train_loader, test_loader, num_classes = get_data(args) - - # Calculate total training steps - num_training_steps = len(train_loader) * args.epochs - num_warmup_steps = int(args.warmup_ratio * num_training_steps) - - model = get_model(args, num_classes) - trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - - logging.info("=" * 50) - logging.info("Model Information:") - logging.info("-" * 50) - logging.info(f"Model Type: {args.model}") - logging.info(f"Number of trainable parameters: {trainable_params:,}") - logging.info(f"Number of layers: {args.num_hidden_layers}") - logging.info(f"Hidden size: {args.hidden_size}") - logging.info(f"Number of heads: {args.num_heads}") - logging.info(f"Learning rate scheduler: {args.lr_scheduler_type}") - logging.info(f"Total training steps: {num_training_steps}") - logging.info(f"Warmup steps: {num_warmup_steps}") - logging.info("=" * 50) - - if args.wandb: - wandb.log({"trainable_parameters": trainable_params}) - - criterion = torch.nn.CrossEntropyLoss() - optimizer = optim.AdamW([ - {'params': model.embeddings.parameters(), 'lr': args.b_lr}, - {'params': model.blocks.parameters(), 'lr': args.b_lr}, - {'params': model.classifier.parameters(), 'lr': args.h_lr} - ], weight_decay=args.wd) - - scheduler = get_scheduler( - name=args.lr_scheduler_type, - optimizer=optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps - ) - - best_acc = 0 - total_train_time = 0 - total_eval_time = 0 - eval_num = 0 - - for epoch in range(args.epochs): - avg_loss, epoch_train_time = train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, args, epoch) - total_train_time += epoch_train_time - - # Log GPU memory usage - log_gpu_memory(args, epoch) - - if epoch % args.eval_epoch == 0: - accuracy, epoch_eval_time = evaluate(model, test_loader, device, args) - total_eval_time += epoch_eval_time - eval_num += 1 - - logging.info( - f'Epoch {epoch}: Loss={avg_loss:.4f}, Accuracy={accuracy:.2f}%, ' - f'Train time={epoch_train_time:.2f}s, Eval time={epoch_eval_time:.2f}s' - ) - - if args.wandb: - wandb.log({ - "epoch": epoch, - "train_loss": avg_loss, - "accuracy": accuracy, - "train_time": epoch_train_time, - "eval_time": epoch_eval_time, - "avg_epoch_train_time": total_train_time / (epoch + 1), - "avg_epoch_eval_time": total_eval_time / eval_num - }) - - if accuracy > best_acc: - best_acc = accuracy - torch.save(model.state_dict(), f'{args.model}_vision_best.pth') - - # Log final statistics - avg_train_time = total_train_time / args.epochs - avg_eval_time = total_eval_time / eval_num - logging.info( - f'Training completed. 
Best accuracy: {best_acc:.2f}%\n' - f'Average training time per epoch: {avg_train_time:.2f}s\n' - f'Average evaluation time: {avg_eval_time:.2f}s' - ) - - if args.wandb: - wandb.log({ - "final/best_accuracy": best_acc, - "final/avg_train_time": avg_train_time, - "final/avg_eval_time": avg_eval_time - }) - if args.wandb: - wandb.finish() - -if __name__ == '__main__': - main() From be439b007b6d2f8c1b897f8bbdc01c40e197e0c7 Mon Sep 17 00:00:00 2001 From: yibozhong Date: Fri, 17 Jan 2025 19:32:04 +0800 Subject: [PATCH 11/17] change script position --- training/classification.py | 590 +++++++++++++++++++++++++++++++++++++ 1 file changed, 590 insertions(+) create mode 100644 training/classification.py diff --git a/training/classification.py b/training/classification.py new file mode 100644 index 000000000..07ad9be44 --- /dev/null +++ b/training/classification.py @@ -0,0 +1,590 @@ +import os +import torch +from tqdm import tqdm +import wandb +import logging +import random +import torch.optim as optim +from torch.utils.data import DataLoader +from torchvision import datasets, transforms +from transformers import get_scheduler +from torch.amp import GradScaler, autocast +from fla.vision_models.abc import ABCVisionConfig, ABCForImageClassification +from fla.vision_models.bitnet import BitNetVisionConfig, BitNetForImageClassification +from fla.vision_models.delta_net import DeltaNetVisionConfig, DeltaNetForImageClassification +from fla.vision_models.gated_deltanet import GatedDeltaNetVisionConfig, GatedDeltaNetForImageClassification +from fla.vision_models.gla import GLAVisionConfig, GLAForImageClassification +from fla.vision_models.gsa import GSAVisionConfig, GSAForImageClassification +from fla.vision_models.hgrn import HGRNVisionConfig, HGRNForImageClassification +from fla.vision_models.hgrn2 import HGRN2VisionConfig, HGRN2ForImageClassification +from fla.vision_models.linear_attn import LinearAttentionVisionConfig, LinearAttentionForImageClassification +from fla.vision_models.retnet import RetNetVisionConfig, RetNetForImageClassification +from fla.vision_models.rwkv6 import RWKV6VisionConfig, RWKV6ForImageClassification +from fla.vision_models.transformer import TransformerVisionConfig, TransformerForImageClassification +import time + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +dtype = torch.bfloat16 # deafult dtype for FLA + +def setup_logging(args): + # check whether logs directory exists + if not os.path.exists('logs'): + os.makedirs('logs') + log_filename = f'logs/training_{args.model}_vision_{args.dataset}{"_hybrid" if args.use_attn else ""}_{args.scan_type}.log' + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler(log_filename), + logging.StreamHandler() + ] + ) + logging.info(f"Logging to {log_filename}") + +def get_args(): + import argparse + parser = argparse.ArgumentParser(description='Vision Model Training') + parser.add_argument('--dataset', type=str, default='cifar100', help='Dataset name') + parser.add_argument('--num_hidden_layers', type=int, default=12, help='Number of hidden layers') + parser.add_argument('--hidden_size', type=int, default=768, help='Hidden size') + parser.add_argument('--patch_size', type=int, default=16, help='Patch size') + parser.add_argument('--image_size', type=int, default=224, help='Image size') + parser.add_argument('--epochs', type=int, default=50, help='Number of epochs') + parser.add_argument('--amp_enabled', action='store_true', help='Enable AMP 
if device supports it') + parser.add_argument('--b_lr', type=float, default=2e-4, help='Backbone learning rate') + parser.add_argument('--h_lr', type=float, default=2e-4, help='Head learning rate') + parser.add_argument('--wd', type=float, default=0., help='Weight decay') + parser.add_argument('--train_bs', type=int, default=128, help='Training batch size') + parser.add_argument('--eval_bs', type=int, default=256, help='Eval batch size') + parser.add_argument('--num_workers', type=int, default=4, help='Number of workers') + parser.add_argument('--num_heads', type=int, default=16, help='Number of attention heads') + parser.add_argument('--eval_epoch', type=int, default=1, help='Eval frequency') + parser.add_argument('--log_step', type=int, default=10, help='Log frequency') + parser.add_argument('--seed', type=int, default=42, help='Random seed') + parser.add_argument('--wandb', action='store_true', help='Enable wandb logging') + parser.add_argument('--expand_k', type=float, default=1.0, help='Key expansion ratio') + parser.add_argument('--expand_v', type=float, default=1.0, help='Value expansion ratio') + parser.add_argument('--attn_mode', type=str, default='chunk', choices=['chunk', 'fused_recurrent', 'fused_chunk']) + parser.add_argument('--model', type=str, required=True, help='Model type (currently only supports "deltanet")') + parser.add_argument('--fuse_cross_entropy', action='store_true', help='Fuse cross entropy with logits') + parser.add_argument('--scan_type', type=str, default='uni-scan', choices=['uni-scan', 'bi-scan', 'cross-scan'],) + + # Learning rate scheduler related arguments + parser.add_argument('--lr_scheduler_type', type=str, default='constant_with_warmup', + choices=['linear', 'cosine', 'cosine_with_restarts', 'polynomial', + 'constant', 'constant_with_warmup']) + parser.add_argument('--warmup_ratio', type=float, default=0.1, + help='Ratio of total training steps for warmup') + # Add hybrid attention related arguments + parser.add_argument('--use_attn', action='store_true', help='Use hybrid attention in some layers') + parser.add_argument('--attn_layers', type=str, default='0,1', + help='Comma separated list of layer indices to use attention, e.g. "0,1,2"') + # Hybrid architecture related arguments + parser.add_argument('--attn_num_heads', type=int, default=16, + help='Number of attention heads for hybrid attention layers') + parser.add_argument('--attn_num_kv_heads', type=int, default=None, + help='Number of key/value heads for hybrid attention layers') + parser.add_argument('--attn_window_size', type=int, default=None, + help='Window size for hybrid attention layers') + parser.add_argument('--log_memory_epoch', type=int, default=100, help='Log memory usage frequency') + return parser.parse_args() + +def get_data(args): + """ + Prepare data transforms and loaders. + Ensures consistent data types with model. + Current suppport only training with CIFAR-10 and CIFAR-100. 
+ """ + transform = transforms.Compose([ + transforms.Resize((args.image_size, args.image_size)), + transforms.ToTensor(), + transforms.ConvertImageDtype(dtype), + ]) + + if args.dataset == 'cifar10': + train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) + test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) + num_classes = 10 + elif args.dataset == 'cifar100': + train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform) + test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform) + num_classes = 100 + else: + raise ValueError(f"Unsupported dataset: {args.dataset}") + + train_loader = DataLoader(train_dataset, batch_size=args.train_bs, shuffle=True, num_workers=args.num_workers) + test_loader = DataLoader(test_dataset, batch_size=args.eval_bs, shuffle=False, num_workers=args.num_workers) + + return train_loader, test_loader, num_classes + +def setup_deterministic_mode(args): + """Setup deterministic mode for reproducibility on the same device""" + import numpy as np + np.random.seed(args.seed) + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +def get_gpu_memory_info(): + """ + Get current GPU memory usage information + Returns a dictionary with: + - memory_allocated: actively allocated memory + - memory_reserved: reserved memory in GPU + - max_memory_allocated: max allocated memory since the beginning + """ + return { + 'memory_allocated': torch.cuda.memory_allocated() / 1024**2, # MB + 'memory_reserved': torch.cuda.memory_reserved() / 1024**2, # MB + 'max_memory_allocated': torch.cuda.max_memory_allocated() / 1024**2 # MB + } + +def log_gpu_memory(args, epoch): + """Log GPU memory usage if CUDA is available""" + if torch.cuda.is_available() and epoch % args.log_memory_epoch == 0: + memory_info = get_gpu_memory_info() + logging.info( + f"GPU Memory Usage (Epoch {epoch}) - " + f"Allocated: {memory_info['memory_allocated']:.2f}MB, " + f"Reserved: {memory_info['memory_reserved']:.2f}MB, " + f"Peak: {memory_info['max_memory_allocated']:.2f}MB" + ) + if args.wandb: + wandb.log({ + "gpu_memory/allocated": memory_info['memory_allocated'], + "gpu_memory/reserved": memory_info['memory_reserved'], + "gpu_memory/peak": memory_info['max_memory_allocated'], + "epoch": epoch + }) + +def evaluate(model, test_loader, device, args): + """ + Evaluation loop with proper CUDA timing. 
+ """ + model.eval() + correct = 0 + total = 0 + + # Create CUDA events for timing + if torch.cuda.is_available(): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_event.record() + else: + start_time = time.perf_counter() + + with torch.no_grad(): + for images, targets in tqdm(test_loader): + images = images.to(device=device, dtype=dtype) + targets = targets.to(device) + + if args.amp_enabled: + with autocast(): + outputs = model(images).logits + _, predicted = outputs.max(1) + else: + outputs = model(images).logits + _, predicted = outputs.max(1) + + total += targets.size(0) + correct += predicted.eq(targets).sum().item() + + # Measure time with proper CUDA synchronization + if torch.cuda.is_available(): + end_event.record() + torch.cuda.synchronize() + eval_time = start_event.elapsed_time(end_event) / 1000.0 # Convert to seconds + else: + eval_time = time.perf_counter() - start_time + + accuracy = 100. * correct / total + return accuracy, eval_time + +def get_model(args, num_classes): + """ + Initialize model based on configuration. + Supports both pure DeltaNet and hybrid models. + """ + # Prepare attention config for hybrid model if enabled + attn_config = None + if args.use_attn: + attn_config = { + 'layers': [int(i) for i in args.attn_layers.split(',')], + 'num_heads': args.attn_num_heads, + 'num_kv_heads': args.attn_num_kv_heads, + 'window_size': args.attn_window_size + } + # Log hybrid attention configuration + logging.info("Hybrid Attention Configuration:") + logging.info(f"- Attention Layers: {attn_config['layers']}") + logging.info(f"- Number of Heads: {attn_config['num_heads']}") + logging.info(f"- Number of KV Heads: {attn_config['num_kv_heads']}") + logging.info(f"- Window Size: {attn_config['window_size']}") + + if args.model == 'deltanet': + config = DeltaNetVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return DeltaNetForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'abc': + config = ABCVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return ABCForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'gated_deltanet': + config = GatedDeltaNetVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return GatedDeltaNetForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'bitnet': + config = 
BitNetVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return BitNetForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'gla': + config = GLAVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return GLAForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'gsa': + config = GSAVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return GSAForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'hgrn': + config = HGRNVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return HGRNForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'hgrn2': + config = HGRN2VisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return HGRN2ForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'linear_attn': + config = LinearAttentionVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return LinearAttentionForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'retnet': + config = RetNetVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + 
fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return RetNetForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'rwkv6': + config = RWKV6VisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes, + attn_mode=args.attn_mode, + fuse_cross_entropy=args.fuse_cross_entropy, + attn=attn_config, # Add attention config for hybrid model + scan_type=args.scan_type # Add scan type to choose different scaning strategy + ) + return RWKV6ForImageClassification(config).to(device=device, dtype=dtype) + + elif args.model == 'transformer': + config = TransformerVisionConfig( + num_hidden_layers=args.num_hidden_layers, + hidden_size=args.hidden_size, + num_heads=args.num_heads, + patch_size=args.patch_size, + image_size=args.image_size, + num_classes=num_classes + ) + return TransformerForImageClassification(config).to(device=device, dtype=dtype) + +def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, args, epoch): + """ + Training loop for one epoch with proper CUDA timing. + """ + model.train() + total_loss = 0 + scaler = GradScaler() if args.amp_enabled else None + + # Create CUDA events for timing + if torch.cuda.is_available(): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_event.record() + else: + start_time = time.perf_counter() + + for i, (images, targets) in enumerate(tqdm(train_loader)): + images = images.to(device=device, dtype=dtype) + targets = targets.to(device) + + if args.amp_enabled: + with autocast(): + outputs = model(images).logits + loss = criterion(outputs, targets) + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + outputs = model(images).logits + loss = criterion(outputs, targets) + loss.backward() + optimizer.step() + + optimizer.zero_grad() + scheduler.step() # Update learning rate scheduler + total_loss += loss.item() + + if i % args.log_step == 0: + lrs = [group['lr'] for group in optimizer.param_groups] + logging.info(f'Epoch {epoch} Step {i}/{len(train_loader)}: ' + f'Loss={loss.item():.4f} ' + f'LR_backbone={lrs[0]:.2e} ' + f'LR_head={lrs[-1]:.2e}') + + if args.wandb: + wandb.log({ + "batch_loss": loss.item(), + "learning_rate/backbone": lrs[0], + "learning_rate/head": lrs[-1], + "global_step": epoch * len(train_loader) + i + }) + + # Measure time with proper CUDA synchronization + if torch.cuda.is_available(): + end_event.record() + torch.cuda.synchronize() + train_time = start_event.elapsed_time(end_event) / 1000.0 + else: + train_time = time.perf_counter() - start_time + + avg_loss = total_loss / len(train_loader) + return avg_loss, train_time + +def main(): + args = get_args() + + # Setup logging first, before any logging calls + setup_logging(args) + + # Then setup deterministic mode + setup_deterministic_mode(args) + + # Log all configuration parameters + logging.info("=" * 50) + logging.info("Training Configuration:") + logging.info("-" * 50) + for arg, value in sorted(vars(args).items()): + logging.info(f"{arg}: {value}") + logging.info("=" * 50) + + # Setup wandb after logging is initialized + if args.wandb: + project_name = "fla-vision" + run_name = f'training_{args.model}_{args.dataset}{"_hybrid" if 
args.use_attn else ""}_{args.scan_type}_e{args.epochs}_blr_{args.b_lr}_hlr_{args.h_lr}_bs{args.train_bs}_mode_{args.attn_mode}' + wandb.init( + project=project_name, + name=run_name, + config=args.__dict__ + ) + logging.info(f"Wandb initialized with project: {project_name}, run: {run_name}") + + train_loader, test_loader, num_classes = get_data(args) + + # Calculate total training steps + num_training_steps = len(train_loader) * args.epochs + num_warmup_steps = int(args.warmup_ratio * num_training_steps) + + model = get_model(args, num_classes) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logging.info("=" * 50) + logging.info("Model Information:") + logging.info("-" * 50) + logging.info(f"Model Type: {args.model}") + logging.info(f"Number of trainable parameters: {trainable_params:,}") + logging.info(f"Number of layers: {args.num_hidden_layers}") + logging.info(f"Hidden size: {args.hidden_size}") + logging.info(f"Number of heads: {args.num_heads}") + logging.info(f"Learning rate scheduler: {args.lr_scheduler_type}") + logging.info(f"Total training steps: {num_training_steps}") + logging.info(f"Warmup steps: {num_warmup_steps}") + logging.info("=" * 50) + + if args.wandb: + wandb.log({"trainable_parameters": trainable_params}) + + criterion = torch.nn.CrossEntropyLoss() + optimizer = optim.AdamW([ + {'params': model.embeddings.parameters(), 'lr': args.b_lr}, + {'params': model.blocks.parameters(), 'lr': args.b_lr}, + {'params': model.classifier.parameters(), 'lr': args.h_lr} + ], weight_decay=args.wd) + + scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps + ) + + best_acc = 0 + total_train_time = 0 + total_eval_time = 0 + eval_num = 0 + + for epoch in range(args.epochs): + avg_loss, epoch_train_time = train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, args, epoch) + total_train_time += epoch_train_time + + # Log GPU memory usage + log_gpu_memory(args, epoch) + + if epoch % args.eval_epoch == 0: + accuracy, epoch_eval_time = evaluate(model, test_loader, device, args) + total_eval_time += epoch_eval_time + eval_num += 1 + + logging.info( + f'Epoch {epoch}: Loss={avg_loss:.4f}, Accuracy={accuracy:.2f}%, ' + f'Train time={epoch_train_time:.2f}s, Eval time={epoch_eval_time:.2f}s' + ) + + if args.wandb: + wandb.log({ + "epoch": epoch, + "train_loss": avg_loss, + "accuracy": accuracy, + "train_time": epoch_train_time, + "eval_time": epoch_eval_time, + "avg_epoch_train_time": total_train_time / (epoch + 1), + "avg_epoch_eval_time": total_eval_time / eval_num + }) + + if accuracy > best_acc: + best_acc = accuracy + torch.save(model.state_dict(), f'{args.model}_vision_best.pth') + + # Log final statistics + avg_train_time = total_train_time / args.epochs + avg_eval_time = total_eval_time / eval_num + logging.info( + f'Training completed. 
Best accuracy: {best_acc:.2f}%\n' + f'Average training time per epoch: {avg_train_time:.2f}s\n' + f'Average evaluation time: {avg_eval_time:.2f}s' + ) + + if args.wandb: + wandb.log({ + "final/best_accuracy": best_acc, + "final/avg_train_time": avg_train_time, + "final/avg_eval_time": avg_eval_time + }) + if args.wandb: + wandb.finish() + +if __name__ == '__main__': + main() From 57fd584a18bdf168dd045e13add87cb2e361b810 Mon Sep 17 00:00:00 2001 From: yibozhong Date: Fri, 17 Jan 2025 19:43:00 +0800 Subject: [PATCH 12/17] update __init__.py for vision models --- fla/vision_models/__init__.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/fla/vision_models/__init__.py b/fla/vision_models/__init__.py index f93f6573d..9538e6c2c 100644 --- a/fla/vision_models/__init__.py +++ b/fla/vision_models/__init__.py @@ -1,10 +1,43 @@ +from fla.vision_models.abc import ABCVisionConfig, ABCForImageClassification +from fla.vision_models.bitnet import BitNetVisionConfig, BitNetForImageClassification from fla.vision_models.delta_net import DeltaNetVisionConfig, DeltaNetForImageClassification +from fla.vision_models.gated_deltanet import GatedDeltaNetVisionConfig, GatedDeltaNetForImageClassification +from fla.vision_models.gla import GLAVisionConfig, GLAForImageClassification +from fla.vision_models.gsa import GSAVisionConfig, GSAForImageClassification +from fla.vision_models.hgrn import HGRNVisionConfig, HGRNForImageClassification +from fla.vision_models.hgrn2 import HGRN2VisionConfig, HGRN2ForImageClassification +from fla.vision_models.linear_attn import LinearAttentionVisionConfig, LinearAttentionForImageClassification +from fla.vision_models.retnet import RetNetVisionConfig, RetNetForImageClassification +from fla.vision_models.rwkv6 import RWKV6VisionConfig, RWKV6ForImageClassification +from fla.vision_models.transformer import TransformerVisionConfig, TransformerForImageClassification from fla.vision_models.utils import ImageEmbeddings, PatchEmbeddings, Pooler __all__ = [ + 'ABCVisionConfig', + 'ABCForImageClassification', + 'BitNetVisionConfig', + 'BitNetForImageClassification', 'DeltaNetVisionConfig', 'DeltaNetForImageClassification', + 'GatedDeltaNetVisionConfig', + 'GatedDeltaNetForImageClassification', + 'GLAVisionConfig', + 'GLAForImageClassification', + 'GSAVisionConfig', + 'GSAForImageClassification', + 'HGRNVisionConfig', + 'HGRNForImageClassification', + 'HGRN2VisionConfig', + 'HGRN2ForImageClassification', + 'LinearAttentionVisionConfig', + 'LinearAttentionForImageClassification', + 'RetNetVisionConfig', + 'RetNetForImageClassification', + 'RWKV6VisionConfig', + 'RWKV6ForImageClassification', + 'TransformerVisionConfig', + 'TransformerForImageClassification', 'ImageEmbeddings', 'PatchEmbeddings', - 'Pooler' + 'Pooler', ] From 1674a22c855e1243e21b4c4c8eae2bf61047bd2b Mon Sep 17 00:00:00 2001 From: yibozhong Date: Sun, 19 Jan 2025 02:08:47 +0800 Subject: [PATCH 13/17] Standarized the code and add implementations for basemodel and masked image model --- fla/vision_models/delta_net/__init__.py | 12 +- .../delta_net/configuration_delta_net.py | 3 + .../delta_net/modeling_delta_net.py | 253 +++++++++++++++--- 3 files changed, 227 insertions(+), 41 deletions(-) diff --git a/fla/vision_models/delta_net/__init__.py b/fla/vision_models/delta_net/__init__.py index 3b951dd04..eef31ccbc 100644 --- a/fla/vision_models/delta_net/__init__.py +++ b/fla/vision_models/delta_net/__init__.py @@ -1,12 +1,16 @@ -from transformers import AutoConfig, 
AutoModelForImageClassification +from transformers import AutoConfig, AutoModel, AutoModelForImageClassification, AutoModelForMaskedImageModeling from fla.vision_models.delta_net.configuration_delta_net import DeltaNetVisionConfig -from fla.vision_models.delta_net.modeling_delta_net import DeltaNetForImageClassification +from fla.vision_models.delta_net.modeling_delta_net import DeltaNetForImageClassification, DeltaNetVisionModel, DeltaNetForMaskedImageModeling AutoConfig.register(DeltaNetVisionConfig.model_type, DeltaNetVisionConfig) AutoModelForImageClassification.register(DeltaNetVisionConfig, DeltaNetForImageClassification) +AutoModelForMaskedImageModeling.register(DeltaNetVisionConfig, DeltaNetForMaskedImageModeling) +AutoModel.register(DeltaNetVisionConfig, DeltaNetVisionModel) __all__ = [ - 'DeltaNetVisionConfig', - 'DeltaNetForImageClassification' + "DeltaNetVisionConfig", + "DeltaNetForImageClassification", + "DeltaNetVisionModel", + "DeltaNetForMaskedImageModeling" ] diff --git a/fla/vision_models/delta_net/configuration_delta_net.py b/fla/vision_models/delta_net/configuration_delta_net.py index c6921a0fe..b24b48908 100644 --- a/fla/vision_models/delta_net/configuration_delta_net.py +++ b/fla/vision_models/delta_net/configuration_delta_net.py @@ -39,6 +39,7 @@ def __init__( use_mask_token: bool = False, layer_norm_eps: float = 1e-6, interpolate_pos_encoding: bool = False, + encoder_stride=16, mlp_dim: int = None, # FLA-for-vision-related parameters scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" @@ -77,6 +78,8 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.interpolate_pos_encoding = interpolate_pos_encoding self.scan_type = scan_type + self.encoder_stride = encoder_stride + if attn is not None: if not isinstance(attn, Dict): diff --git a/fla/vision_models/delta_net/modeling_delta_net.py b/fla/vision_models/delta_net/modeling_delta_net.py index 879a16327..e9ece6adc 100644 --- a/fla/vision_models/delta_net/modeling_delta_net.py +++ b/fla/vision_models/delta_net/modeling_delta_net.py @@ -7,7 +7,7 @@ from typing import Optional, Set, Tuple, Union, List, Dict, Unpack from transformers.utils import logging from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput +from transformers.modeling_outputs import ImageClassifierOutput, BaseModelOutput, BaseModelOutputWithPooling, MaskedImageModelingOutput from transformers.modeling_utils import PreTrainedModel from .configuration_delta_net import DeltaNetVisionConfig from fla.layers.delta_net import DeltaNet @@ -108,7 +108,6 @@ def forward( if hasattr(self, 'ln_2'): hidden_states = self.ln_2(hidden_states) - # MLP hidden_states = self.mlp(hidden_states) # Second residual connection @@ -119,7 +118,6 @@ def forward( return outputs class DeltaNetVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation config_class = DeltaNetVisionConfig def _init_weights(self, module): @@ -139,53 +137,158 @@ def _init_weights(self, module): std=self.config.initializer_range, ).to(module.position_embeddings.dtype) -class DeltaNetForImageClassification(DeltaNetVisionPreTrainedModel): - config_class = DeltaNetVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) + +class DeltaNetVisionEncoder(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.config = 
config self.blocks = nn.ModuleList([ DeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers) ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + return_dict: bool = True, + **kwargs + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, block in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_self_attentions = all_self_attentions + (attentions,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + +class DeltaNetVisionModel(DeltaNetVisionPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.embeddings = ImageEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = DeltaNetVisionEncoder(config) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) if add_pooling_layer else None self.init_weights() - + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + def forward( self, pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + interpolate_pos_encoding: Optional[bool] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: + **kwargs + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - 
for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") - logits = self.classifier(pooled_output) + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding) + encoder_outputs = self.encoder( + hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + return_dict=return_dict, + **kwargs + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + +class DeltaNetForImageClassification(DeltaNetVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + self.backbone = DeltaNetVisionModel(config, add_pooling_layer=True) # Here we should use mean pooling + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.backbone( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output + logits = self.classifier(pooled_output) # only use mean pooling + loss = None if labels is not None: if self.num_labels == 1: @@ -196,11 +299,87 @@ def forward( loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: - output = (logits,) + (hidden_states,) + output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return ImageClassifierOutput( loss=loss, logits=logits, - hidden_states=hidden_states, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class DeltaNetForMaskedImageModeling(DeltaNetVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.backbone = DeltaNetVisionModel(config, add_pooling_layer=False, use_mask_token=True) + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=config.hidden_size, + out_channels=config.encoder_stride**2 * config.num_channels, + kernel_size=1, + ), + nn.PixelShuffle(config.encoder_stride), + ) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + labels: Optional[torch.Tensor] = None, + 
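# --- Editor's note (illustrative sketch, not part of the patch) ---
# DeltaNetForImageClassification above classifies from the backbone's pooler_output,
# and the in-line comments say mean pooling is intended. The Pooler itself lives in
# the shared utils module (not shown in this hunk); a stand-in with that assumed
# behaviour looks like this:
import torch
import torch.nn as nn

class MeanPooler(nn.Module):
    """Average over the patch-token dimension of (batch, seq_len, hidden_size)."""
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states.mean(dim=1)

tokens = torch.randn(2, 196, 256)            # e.g. 14x14 patches, hidden_size=256
pooled = MeanPooler()(tokens)                # (2, 256)
logits = nn.Linear(256, 1000)(pooled)        # (2, 1000) class scores
# --- end editor's note ---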
output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedImageModelingOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): + raise ValueError( + "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " + "the reconstructed image has the same dimensions as the input. " + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." + ) + + outputs = self.backbone( + pixel_values, + bool_masked_pos=bool_masked_pos, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + + sequence_output = outputs[0] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[1:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) From a3b03b608474ffe3ba7cb9bb772d2d51c1496bde Mon Sep 17 00:00:00 2001 From: yibozhong Date: Sun, 19 Jan 2025 14:04:52 +0800 Subject: [PATCH 14/17] Remove training script --- training/classification.py | 590 ------------------------------------- 1 file changed, 590 deletions(-) delete mode 100644 training/classification.py diff --git a/training/classification.py b/training/classification.py deleted file mode 100644 index 07ad9be44..000000000 --- a/training/classification.py +++ /dev/null @@ -1,590 +0,0 @@ -import os -import torch -from tqdm import tqdm -import wandb -import logging -import random -import torch.optim as optim -from torch.utils.data import DataLoader -from torchvision import datasets, transforms -from transformers import get_scheduler -from torch.amp import GradScaler, autocast -from fla.vision_models.abc import ABCVisionConfig, ABCForImageClassification -from fla.vision_models.bitnet import BitNetVisionConfig, BitNetForImageClassification -from fla.vision_models.delta_net import DeltaNetVisionConfig, DeltaNetForImageClassification -from fla.vision_models.gated_deltanet import GatedDeltaNetVisionConfig, GatedDeltaNetForImageClassification -from fla.vision_models.gla import GLAVisionConfig, GLAForImageClassification -from fla.vision_models.gsa import GSAVisionConfig, GSAForImageClassification -from fla.vision_models.hgrn import 
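# --- Editor's sketch (illustrative, not part of the patch) ---
# The masked-image-modeling head added above maps each patch token back to a
# patch_size x patch_size pixel block with a 1x1 conv followed by PixelShuffle,
# then takes an L1 loss only over the masked patches. The shape bookkeeping, with
# hypothetical toy sizes (hidden_size=64, patch_size=encoder_stride=4, image 16x16):
import torch
import torch.nn as nn

B, hidden_size, patch_size, image_size, num_channels = 2, 64, 4, 16, 3
h = w = image_size // patch_size                       # 4x4 grid of patches
seq_len = h * w                                        # 16 patch tokens

sequence_output = torch.randn(B, seq_len, hidden_size)             # encoder output (B, N, C)
feat = sequence_output.permute(0, 2, 1).reshape(B, hidden_size, h, w)

decoder = nn.Sequential(
    nn.Conv2d(hidden_size, patch_size**2 * num_channels, kernel_size=1),
    nn.PixelShuffle(patch_size),                       # (B, 48, 4, 4) -> (B, 3, 16, 16)
)
reconstruction = decoder(feat)
assert reconstruction.shape == (B, num_channels, image_size, image_size)

# Expand the per-patch mask to pixel resolution and average the loss over masked pixels.
bool_masked_pos = torch.randint(0, 2, (B, seq_len)).bool()
mask = (bool_masked_pos.reshape(B, h, w)
        .repeat_interleave(patch_size, 1)
        .repeat_interleave(patch_size, 2)
        .unsqueeze(1).float())
pixel_values = torch.randn(B, num_channels, image_size, image_size)
loss = (nn.functional.l1_loss(pixel_values, reconstruction, reduction="none") * mask
        ).sum() / (mask.sum() + 1e-5) / num_channels
# --- end editor's sketch ---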
HGRNVisionConfig, HGRNForImageClassification -from fla.vision_models.hgrn2 import HGRN2VisionConfig, HGRN2ForImageClassification -from fla.vision_models.linear_attn import LinearAttentionVisionConfig, LinearAttentionForImageClassification -from fla.vision_models.retnet import RetNetVisionConfig, RetNetForImageClassification -from fla.vision_models.rwkv6 import RWKV6VisionConfig, RWKV6ForImageClassification -from fla.vision_models.transformer import TransformerVisionConfig, TransformerForImageClassification -import time - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -dtype = torch.bfloat16 # deafult dtype for FLA - -def setup_logging(args): - # check whether logs directory exists - if not os.path.exists('logs'): - os.makedirs('logs') - log_filename = f'logs/training_{args.model}_vision_{args.dataset}{"_hybrid" if args.use_attn else ""}_{args.scan_type}.log' - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler(log_filename), - logging.StreamHandler() - ] - ) - logging.info(f"Logging to {log_filename}") - -def get_args(): - import argparse - parser = argparse.ArgumentParser(description='Vision Model Training') - parser.add_argument('--dataset', type=str, default='cifar100', help='Dataset name') - parser.add_argument('--num_hidden_layers', type=int, default=12, help='Number of hidden layers') - parser.add_argument('--hidden_size', type=int, default=768, help='Hidden size') - parser.add_argument('--patch_size', type=int, default=16, help='Patch size') - parser.add_argument('--image_size', type=int, default=224, help='Image size') - parser.add_argument('--epochs', type=int, default=50, help='Number of epochs') - parser.add_argument('--amp_enabled', action='store_true', help='Enable AMP if device supports it') - parser.add_argument('--b_lr', type=float, default=2e-4, help='Backbone learning rate') - parser.add_argument('--h_lr', type=float, default=2e-4, help='Head learning rate') - parser.add_argument('--wd', type=float, default=0., help='Weight decay') - parser.add_argument('--train_bs', type=int, default=128, help='Training batch size') - parser.add_argument('--eval_bs', type=int, default=256, help='Eval batch size') - parser.add_argument('--num_workers', type=int, default=4, help='Number of workers') - parser.add_argument('--num_heads', type=int, default=16, help='Number of attention heads') - parser.add_argument('--eval_epoch', type=int, default=1, help='Eval frequency') - parser.add_argument('--log_step', type=int, default=10, help='Log frequency') - parser.add_argument('--seed', type=int, default=42, help='Random seed') - parser.add_argument('--wandb', action='store_true', help='Enable wandb logging') - parser.add_argument('--expand_k', type=float, default=1.0, help='Key expansion ratio') - parser.add_argument('--expand_v', type=float, default=1.0, help='Value expansion ratio') - parser.add_argument('--attn_mode', type=str, default='chunk', choices=['chunk', 'fused_recurrent', 'fused_chunk']) - parser.add_argument('--model', type=str, required=True, help='Model type (currently only supports "deltanet")') - parser.add_argument('--fuse_cross_entropy', action='store_true', help='Fuse cross entropy with logits') - parser.add_argument('--scan_type', type=str, default='uni-scan', choices=['uni-scan', 'bi-scan', 'cross-scan'],) - - # Learning rate scheduler related arguments - parser.add_argument('--lr_scheduler_type', type=str, default='constant_with_warmup', - choices=['linear', 'cosine', 
'cosine_with_restarts', 'polynomial', - 'constant', 'constant_with_warmup']) - parser.add_argument('--warmup_ratio', type=float, default=0.1, - help='Ratio of total training steps for warmup') - # Add hybrid attention related arguments - parser.add_argument('--use_attn', action='store_true', help='Use hybrid attention in some layers') - parser.add_argument('--attn_layers', type=str, default='0,1', - help='Comma separated list of layer indices to use attention, e.g. "0,1,2"') - # Hybrid architecture related arguments - parser.add_argument('--attn_num_heads', type=int, default=16, - help='Number of attention heads for hybrid attention layers') - parser.add_argument('--attn_num_kv_heads', type=int, default=None, - help='Number of key/value heads for hybrid attention layers') - parser.add_argument('--attn_window_size', type=int, default=None, - help='Window size for hybrid attention layers') - parser.add_argument('--log_memory_epoch', type=int, default=100, help='Log memory usage frequency') - return parser.parse_args() - -def get_data(args): - """ - Prepare data transforms and loaders. - Ensures consistent data types with model. - Current suppport only training with CIFAR-10 and CIFAR-100. - """ - transform = transforms.Compose([ - transforms.Resize((args.image_size, args.image_size)), - transforms.ToTensor(), - transforms.ConvertImageDtype(dtype), - ]) - - if args.dataset == 'cifar10': - train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) - test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) - num_classes = 10 - elif args.dataset == 'cifar100': - train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform) - test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform) - num_classes = 100 - else: - raise ValueError(f"Unsupported dataset: {args.dataset}") - - train_loader = DataLoader(train_dataset, batch_size=args.train_bs, shuffle=True, num_workers=args.num_workers) - test_loader = DataLoader(test_dataset, batch_size=args.eval_bs, shuffle=False, num_workers=args.num_workers) - - return train_loader, test_loader, num_classes - -def setup_deterministic_mode(args): - """Setup deterministic mode for reproducibility on the same device""" - import numpy as np - np.random.seed(args.seed) - random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed(args.seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - -def get_gpu_memory_info(): - """ - Get current GPU memory usage information - Returns a dictionary with: - - memory_allocated: actively allocated memory - - memory_reserved: reserved memory in GPU - - max_memory_allocated: max allocated memory since the beginning - """ - return { - 'memory_allocated': torch.cuda.memory_allocated() / 1024**2, # MB - 'memory_reserved': torch.cuda.memory_reserved() / 1024**2, # MB - 'max_memory_allocated': torch.cuda.max_memory_allocated() / 1024**2 # MB - } - -def log_gpu_memory(args, epoch): - """Log GPU memory usage if CUDA is available""" - if torch.cuda.is_available() and epoch % args.log_memory_epoch == 0: - memory_info = get_gpu_memory_info() - logging.info( - f"GPU Memory Usage (Epoch {epoch}) - " - f"Allocated: {memory_info['memory_allocated']:.2f}MB, " - f"Reserved: {memory_info['memory_reserved']:.2f}MB, " - f"Peak: {memory_info['max_memory_allocated']:.2f}MB" - ) - if args.wandb: - wandb.log({ - "gpu_memory/allocated": 
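# --- Editor's sketch (illustrative, not part of the patch) ---
# The memory logging in the script being removed here reads three CUDA allocator
# counters; all of them report bytes, which the script converts to MB:
import torch

if torch.cuda.is_available():
    allocated_mb = torch.cuda.memory_allocated() / 1024**2       # memory held by live tensors
    reserved_mb = torch.cuda.memory_reserved() / 1024**2         # memory cached by the allocator
    peak_mb = torch.cuda.max_memory_allocated() / 1024**2        # high-water mark since start/reset
    torch.cuda.reset_peak_memory_stats()                         # optional: restart the peak counter
# --- end editor's sketch ---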
memory_info['memory_allocated'], - "gpu_memory/reserved": memory_info['memory_reserved'], - "gpu_memory/peak": memory_info['max_memory_allocated'], - "epoch": epoch - }) - -def evaluate(model, test_loader, device, args): - """ - Evaluation loop with proper CUDA timing. - """ - model.eval() - correct = 0 - total = 0 - - # Create CUDA events for timing - if torch.cuda.is_available(): - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start_event.record() - else: - start_time = time.perf_counter() - - with torch.no_grad(): - for images, targets in tqdm(test_loader): - images = images.to(device=device, dtype=dtype) - targets = targets.to(device) - - if args.amp_enabled: - with autocast(): - outputs = model(images).logits - _, predicted = outputs.max(1) - else: - outputs = model(images).logits - _, predicted = outputs.max(1) - - total += targets.size(0) - correct += predicted.eq(targets).sum().item() - - # Measure time with proper CUDA synchronization - if torch.cuda.is_available(): - end_event.record() - torch.cuda.synchronize() - eval_time = start_event.elapsed_time(end_event) / 1000.0 # Convert to seconds - else: - eval_time = time.perf_counter() - start_time - - accuracy = 100. * correct / total - return accuracy, eval_time - -def get_model(args, num_classes): - """ - Initialize model based on configuration. - Supports both pure DeltaNet and hybrid models. - """ - # Prepare attention config for hybrid model if enabled - attn_config = None - if args.use_attn: - attn_config = { - 'layers': [int(i) for i in args.attn_layers.split(',')], - 'num_heads': args.attn_num_heads, - 'num_kv_heads': args.attn_num_kv_heads, - 'window_size': args.attn_window_size - } - # Log hybrid attention configuration - logging.info("Hybrid Attention Configuration:") - logging.info(f"- Attention Layers: {attn_config['layers']}") - logging.info(f"- Number of Heads: {attn_config['num_heads']}") - logging.info(f"- Number of KV Heads: {attn_config['num_kv_heads']}") - logging.info(f"- Window Size: {attn_config['window_size']}") - - if args.model == 'deltanet': - config = DeltaNetVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return DeltaNetForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'abc': - config = ABCVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return ABCForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'gated_deltanet': - config = GatedDeltaNetVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - 
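# --- Editor's sketch (illustrative, not part of the patch) ---
# evaluate() and train_one_epoch() above time GPU work with CUDA events instead of
# wall-clock time, because CUDA kernels launch asynchronously. The core pattern:
import torch

if torch.cuda.is_available():
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()          # drain pending work before starting the clock
    start.record()
    torch.randn(1024, 1024, device="cuda") @ torch.randn(1024, 1024, device="cuda")  # measured workload
    end.record()
    torch.cuda.synchronize()          # make sure both events have completed
    elapsed_s = start.elapsed_time(end) / 1000.0   # elapsed_time() returns milliseconds
# --- end editor's sketch ---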
attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return GatedDeltaNetForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'bitnet': - config = BitNetVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return BitNetForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'gla': - config = GLAVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return GLAForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'gsa': - config = GSAVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return GSAForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'hgrn': - config = HGRNVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return HGRNForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'hgrn2': - config = HGRN2VisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return HGRN2ForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'linear_attn': - config = LinearAttentionVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return LinearAttentionForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'retnet': 
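# --- Editor's sketch (illustrative, not part of the patch) ---
# Every branch of get_model() passes the same optional `attn` dict, which the vision
# configs use to turn selected layers into standard softmax-attention layers (hybrid
# models). Assuming the PATCH 13 state of the package, a config built that way:
from fla.vision_models.delta_net import DeltaNetVisionConfig

attn_config = {
    "layers": [0, 1],   # layer indices that use softmax attention; required key
    "num_heads": 8,     # required key; num_kv_heads / window_size default if omitted
}
config = DeltaNetVisionConfig(
    num_hidden_layers=6,
    hidden_size=256,
    num_heads=4,
    num_classes=100,
    attn=attn_config,
    scan_type="uni-scan",
)
# --- end editor's sketch ---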
- config = RetNetVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return RetNetForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'rwkv6': - config = RWKV6VisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes, - attn_mode=args.attn_mode, - fuse_cross_entropy=args.fuse_cross_entropy, - attn=attn_config, # Add attention config for hybrid model - scan_type=args.scan_type # Add scan type to choose different scaning strategy - ) - return RWKV6ForImageClassification(config).to(device=device, dtype=dtype) - - elif args.model == 'transformer': - config = TransformerVisionConfig( - num_hidden_layers=args.num_hidden_layers, - hidden_size=args.hidden_size, - num_heads=args.num_heads, - patch_size=args.patch_size, - image_size=args.image_size, - num_classes=num_classes - ) - return TransformerForImageClassification(config).to(device=device, dtype=dtype) - -def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, args, epoch): - """ - Training loop for one epoch with proper CUDA timing. - """ - model.train() - total_loss = 0 - scaler = GradScaler() if args.amp_enabled else None - - # Create CUDA events for timing - if torch.cuda.is_available(): - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start_event.record() - else: - start_time = time.perf_counter() - - for i, (images, targets) in enumerate(tqdm(train_loader)): - images = images.to(device=device, dtype=dtype) - targets = targets.to(device) - - if args.amp_enabled: - with autocast(): - outputs = model(images).logits - loss = criterion(outputs, targets) - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - else: - outputs = model(images).logits - loss = criterion(outputs, targets) - loss.backward() - optimizer.step() - - optimizer.zero_grad() - scheduler.step() # Update learning rate scheduler - total_loss += loss.item() - - if i % args.log_step == 0: - lrs = [group['lr'] for group in optimizer.param_groups] - logging.info(f'Epoch {epoch} Step {i}/{len(train_loader)}: ' - f'Loss={loss.item():.4f} ' - f'LR_backbone={lrs[0]:.2e} ' - f'LR_head={lrs[-1]:.2e}') - - if args.wandb: - wandb.log({ - "batch_loss": loss.item(), - "learning_rate/backbone": lrs[0], - "learning_rate/head": lrs[-1], - "global_step": epoch * len(train_loader) + i - }) - - # Measure time with proper CUDA synchronization - if torch.cuda.is_available(): - end_event.record() - torch.cuda.synchronize() - train_time = start_event.elapsed_time(end_event) / 1000.0 - else: - train_time = time.perf_counter() - start_time - - avg_loss = total_loss / len(train_loader) - return avg_loss, train_time - -def main(): - args = get_args() - - # Setup logging first, before any logging calls - setup_logging(args) - - # Then setup deterministic mode - setup_deterministic_mode(args) - - # Log all configuration parameters - logging.info("=" * 50) - logging.info("Training Configuration:") - logging.info("-" * 50) - for arg, 
value in sorted(vars(args).items()): - logging.info(f"{arg}: {value}") - logging.info("=" * 50) - - # Setup wandb after logging is initialized - if args.wandb: - project_name = "fla-vision" - run_name = f'training_{args.model}_{args.dataset}{"_hybrid" if args.use_attn else ""}_{args.scan_type}_e{args.epochs}_blr_{args.b_lr}_hlr_{args.h_lr}_bs{args.train_bs}_mode_{args.attn_mode}' - wandb.init( - project=project_name, - name=run_name, - config=args.__dict__ - ) - logging.info(f"Wandb initialized with project: {project_name}, run: {run_name}") - - train_loader, test_loader, num_classes = get_data(args) - - # Calculate total training steps - num_training_steps = len(train_loader) * args.epochs - num_warmup_steps = int(args.warmup_ratio * num_training_steps) - - model = get_model(args, num_classes) - trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - - logging.info("=" * 50) - logging.info("Model Information:") - logging.info("-" * 50) - logging.info(f"Model Type: {args.model}") - logging.info(f"Number of trainable parameters: {trainable_params:,}") - logging.info(f"Number of layers: {args.num_hidden_layers}") - logging.info(f"Hidden size: {args.hidden_size}") - logging.info(f"Number of heads: {args.num_heads}") - logging.info(f"Learning rate scheduler: {args.lr_scheduler_type}") - logging.info(f"Total training steps: {num_training_steps}") - logging.info(f"Warmup steps: {num_warmup_steps}") - logging.info("=" * 50) - - if args.wandb: - wandb.log({"trainable_parameters": trainable_params}) - - criterion = torch.nn.CrossEntropyLoss() - optimizer = optim.AdamW([ - {'params': model.embeddings.parameters(), 'lr': args.b_lr}, - {'params': model.blocks.parameters(), 'lr': args.b_lr}, - {'params': model.classifier.parameters(), 'lr': args.h_lr} - ], weight_decay=args.wd) - - scheduler = get_scheduler( - name=args.lr_scheduler_type, - optimizer=optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps - ) - - best_acc = 0 - total_train_time = 0 - total_eval_time = 0 - eval_num = 0 - - for epoch in range(args.epochs): - avg_loss, epoch_train_time = train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, args, epoch) - total_train_time += epoch_train_time - - # Log GPU memory usage - log_gpu_memory(args, epoch) - - if epoch % args.eval_epoch == 0: - accuracy, epoch_eval_time = evaluate(model, test_loader, device, args) - total_eval_time += epoch_eval_time - eval_num += 1 - - logging.info( - f'Epoch {epoch}: Loss={avg_loss:.4f}, Accuracy={accuracy:.2f}%, ' - f'Train time={epoch_train_time:.2f}s, Eval time={epoch_eval_time:.2f}s' - ) - - if args.wandb: - wandb.log({ - "epoch": epoch, - "train_loss": avg_loss, - "accuracy": accuracy, - "train_time": epoch_train_time, - "eval_time": epoch_eval_time, - "avg_epoch_train_time": total_train_time / (epoch + 1), - "avg_epoch_eval_time": total_eval_time / eval_num - }) - - if accuracy > best_acc: - best_acc = accuracy - torch.save(model.state_dict(), f'{args.model}_vision_best.pth') - - # Log final statistics - avg_train_time = total_train_time / args.epochs - avg_eval_time = total_eval_time / eval_num - logging.info( - f'Training completed. 
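# --- Editor's sketch (illustrative, not part of the patch) ---
# main() above gives the backbone and the classification head separate learning
# rates through optimizer parameter groups and drives both with a transformers
# scheduler. A self-contained miniature of that setup, with stand-in modules:
import torch
import torch.nn as nn
from transformers import get_scheduler

model = nn.ModuleDict({
    "backbone": nn.Linear(16, 16),   # stand-in for embeddings + blocks
    "head": nn.Linear(16, 10),       # stand-in for the classifier
})
optimizer = torch.optim.AdamW(
    [
        {"params": model["backbone"].parameters(), "lr": 2e-4},
        {"params": model["head"].parameters(), "lr": 1e-3},
    ],
    weight_decay=0.0,
)
num_training_steps = 1000
scheduler = get_scheduler(
    name="constant_with_warmup",
    optimizer=optimizer,
    num_warmup_steps=int(0.1 * num_training_steps),
    num_training_steps=num_training_steps,
)
for _ in range(3):                   # one scheduler step per optimizer step, as in the script
    loss = model["head"](model["backbone"](torch.randn(4, 16))).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
# --- end editor's sketch ---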
Best accuracy: {best_acc:.2f}%\n' - f'Average training time per epoch: {avg_train_time:.2f}s\n' - f'Average evaluation time: {avg_eval_time:.2f}s' - ) - - if args.wandb: - wandb.log({ - "final/best_accuracy": best_acc, - "final/avg_train_time": avg_train_time, - "final/avg_eval_time": avg_eval_time - }) - if args.wandb: - wandb.finish() - -if __name__ == '__main__': - main() From 683913ff2fe9bc2520b742cb35aef31d48f95130 Mon Sep 17 00:00:00 2001 From: yibozhong Date: Sun, 19 Jan 2025 17:15:41 +0800 Subject: [PATCH 15/17] remove separate folder --- fla/vision_models/__init__.py | 43 -- fla/vision_models/abc/__init__.py | 12 - fla/vision_models/abc/configuration_abc.py | 97 ---- fla/vision_models/abc/modeling_abc.py | 203 -------- fla/vision_models/bitnet/__init__.py | 12 - .../bitnet/configuration_bitnet.py | 95 ---- fla/vision_models/bitnet/modeling_bitnet.py | 201 -------- fla/vision_models/delta_net/__init__.py | 16 - .../delta_net/configuration_delta_net.py | 101 ---- .../delta_net/modeling_delta_net.py | 385 -------------- fla/vision_models/gated_deltanet/__init__.py | 13 - .../configuration_gated_deltanet.py | 89 ---- .../gated_deltanet/modeling_gated_deltanet.py | 202 -------- fla/vision_models/gla/__init__.py | 12 - fla/vision_models/gla/configuration_gla.py | 101 ---- fla/vision_models/gla/modeling_gla.py | 205 -------- fla/vision_models/gsa/__init__.py | 12 - fla/vision_models/gsa/configuration_gsa.py | 107 ---- fla/vision_models/gsa/modeling_gsa.py | 209 -------- fla/vision_models/hgrn/__init__.py | 12 - fla/vision_models/hgrn/configuration_hgrn.py | 86 ---- fla/vision_models/hgrn/modeling_hgrn.py | 197 ------- fla/vision_models/hgrn2/__init__.py | 12 - .../hgrn2/configuration_hgrn2.py | 89 ---- fla/vision_models/hgrn2/modeling_hgrn2.py | 198 -------- fla/vision_models/linear_attn/__init__.py | 12 - .../linear_attn/configuration_linear_attn.py | 96 ---- .../linear_attn/modeling_linear_attn.py | 197 ------- fla/vision_models/retnet/__init__.py | 12 - .../retnet/configuration_retnet.py | 101 ---- fla/vision_models/retnet/modeling_retnet.py | 202 -------- fla/vision_models/rwkv6/__init__.py | 12 - .../rwkv6/configuration_rwkv6.py | 94 ---- fla/vision_models/rwkv6/modeling_rwkv6.py | 199 -------- fla/vision_models/transformer/__init__.py | 12 - .../transformer/configuration_transformer.py | 81 --- .../transformer/modeling_transformer.py | 190 ------- fla/vision_models/utils.py | 480 ------------------ 38 files changed, 4397 deletions(-) delete mode 100644 fla/vision_models/__init__.py delete mode 100644 fla/vision_models/abc/__init__.py delete mode 100644 fla/vision_models/abc/configuration_abc.py delete mode 100644 fla/vision_models/abc/modeling_abc.py delete mode 100644 fla/vision_models/bitnet/__init__.py delete mode 100644 fla/vision_models/bitnet/configuration_bitnet.py delete mode 100644 fla/vision_models/bitnet/modeling_bitnet.py delete mode 100644 fla/vision_models/delta_net/__init__.py delete mode 100644 fla/vision_models/delta_net/configuration_delta_net.py delete mode 100644 fla/vision_models/delta_net/modeling_delta_net.py delete mode 100644 fla/vision_models/gated_deltanet/__init__.py delete mode 100644 fla/vision_models/gated_deltanet/configuration_gated_deltanet.py delete mode 100644 fla/vision_models/gated_deltanet/modeling_gated_deltanet.py delete mode 100644 fla/vision_models/gla/__init__.py delete mode 100644 fla/vision_models/gla/configuration_gla.py delete mode 100644 fla/vision_models/gla/modeling_gla.py delete mode 100644 fla/vision_models/gsa/__init__.py delete 
mode 100644 fla/vision_models/gsa/configuration_gsa.py delete mode 100644 fla/vision_models/gsa/modeling_gsa.py delete mode 100644 fla/vision_models/hgrn/__init__.py delete mode 100644 fla/vision_models/hgrn/configuration_hgrn.py delete mode 100644 fla/vision_models/hgrn/modeling_hgrn.py delete mode 100644 fla/vision_models/hgrn2/__init__.py delete mode 100644 fla/vision_models/hgrn2/configuration_hgrn2.py delete mode 100644 fla/vision_models/hgrn2/modeling_hgrn2.py delete mode 100644 fla/vision_models/linear_attn/__init__.py delete mode 100644 fla/vision_models/linear_attn/configuration_linear_attn.py delete mode 100644 fla/vision_models/linear_attn/modeling_linear_attn.py delete mode 100644 fla/vision_models/retnet/__init__.py delete mode 100644 fla/vision_models/retnet/configuration_retnet.py delete mode 100644 fla/vision_models/retnet/modeling_retnet.py delete mode 100644 fla/vision_models/rwkv6/__init__.py delete mode 100644 fla/vision_models/rwkv6/configuration_rwkv6.py delete mode 100644 fla/vision_models/rwkv6/modeling_rwkv6.py delete mode 100644 fla/vision_models/transformer/__init__.py delete mode 100644 fla/vision_models/transformer/configuration_transformer.py delete mode 100644 fla/vision_models/transformer/modeling_transformer.py delete mode 100644 fla/vision_models/utils.py diff --git a/fla/vision_models/__init__.py b/fla/vision_models/__init__.py deleted file mode 100644 index 9538e6c2c..000000000 --- a/fla/vision_models/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -from fla.vision_models.abc import ABCVisionConfig, ABCForImageClassification -from fla.vision_models.bitnet import BitNetVisionConfig, BitNetForImageClassification -from fla.vision_models.delta_net import DeltaNetVisionConfig, DeltaNetForImageClassification -from fla.vision_models.gated_deltanet import GatedDeltaNetVisionConfig, GatedDeltaNetForImageClassification -from fla.vision_models.gla import GLAVisionConfig, GLAForImageClassification -from fla.vision_models.gsa import GSAVisionConfig, GSAForImageClassification -from fla.vision_models.hgrn import HGRNVisionConfig, HGRNForImageClassification -from fla.vision_models.hgrn2 import HGRN2VisionConfig, HGRN2ForImageClassification -from fla.vision_models.linear_attn import LinearAttentionVisionConfig, LinearAttentionForImageClassification -from fla.vision_models.retnet import RetNetVisionConfig, RetNetForImageClassification -from fla.vision_models.rwkv6 import RWKV6VisionConfig, RWKV6ForImageClassification -from fla.vision_models.transformer import TransformerVisionConfig, TransformerForImageClassification -from fla.vision_models.utils import ImageEmbeddings, PatchEmbeddings, Pooler - -__all__ = [ - 'ABCVisionConfig', - 'ABCForImageClassification', - 'BitNetVisionConfig', - 'BitNetForImageClassification', - 'DeltaNetVisionConfig', - 'DeltaNetForImageClassification', - 'GatedDeltaNetVisionConfig', - 'GatedDeltaNetForImageClassification', - 'GLAVisionConfig', - 'GLAForImageClassification', - 'GSAVisionConfig', - 'GSAForImageClassification', - 'HGRNVisionConfig', - 'HGRNForImageClassification', - 'HGRN2VisionConfig', - 'HGRN2ForImageClassification', - 'LinearAttentionVisionConfig', - 'LinearAttentionForImageClassification', - 'RetNetVisionConfig', - 'RetNetForImageClassification', - 'RWKV6VisionConfig', - 'RWKV6ForImageClassification', - 'TransformerVisionConfig', - 'TransformerForImageClassification', - 'ImageEmbeddings', - 'PatchEmbeddings', - 'Pooler', -] diff --git a/fla/vision_models/abc/__init__.py b/fla/vision_models/abc/__init__.py deleted file mode 100644 
index 67d013691..000000000 --- a/fla/vision_models/abc/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.abc.configuration_abc import ABCVisionConfig -from fla.vision_models.abc.modeling_abc import ABCForImageClassification - -AutoConfig.register(ABCVisionConfig.model_type, ABCVisionConfig) -AutoModelForImageClassification.register(ABCVisionConfig, ABCForImageClassification) - -__all__ = [ - 'ABCVisionConfig', - 'ABCForImageClassification' -] diff --git a/fla/vision_models/abc/configuration_abc.py b/fla/vision_models/abc/configuration_abc.py deleted file mode 100644 index 13de13c09..000000000 --- a/fla/vision_models/abc/configuration_abc.py +++ /dev/null @@ -1,97 +0,0 @@ -from typing import Dict, Optional - -from transformers.configuration_utils import PretrainedConfig - - -class ABCVisionConfig(PretrainedConfig): - - model_type = 'abc_vision' - - def __init__( - self, - # ABC core parameters - hidden_size: int = 2048, - gate_low_rank_dim: int = 16, - clamp_min: float = -32, - clamp_max: float = 32, - num_hidden_layers: int = 24, - num_heads: int = 4, - num_slots: Optional[int] = 64, - use_short_conv: bool = False, - conv_size: int = 4, - exapnd_k: float = 0.5, - exapnd_v: float = 1, - hidden_act: str = "swish", - max_position_embeddings: int = 2048, - elementwise_affine: Optional[bool] = True, - norm_eps: float = 1e-6, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_norm: bool = True, - fuse_cross_entropy: bool = True, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - qkv_bias: bool = True, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - # Initialize ABC core parameters - self.hidden_size = hidden_size - self.gate_low_rank_dim = gate_low_rank_dim - self.clamp_min = clamp_min - self.clamp_max = clamp_max - self.num_hidden_layers = num_hidden_layers - self.num_heads = num_heads - self.num_slots = num_slots - self.use_short_conv = use_short_conv - self.conv_size = conv_size - self.expand_k = exapnd_k - self.expand_v = exapnd_v - self.hidden_act = hidden_act - self.max_position_embeddings = max_position_embeddings - self.elementwise_affine = elementwise_affine - self.norm_eps = norm_eps - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_norm = fuse_norm - self.fuse_cross_entropy = fuse_cross_entropy - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.qkv_bias = qkv_bias - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - 
attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) diff --git a/fla/vision_models/abc/modeling_abc.py b/fla/vision_models/abc/modeling_abc.py deleted file mode 100644 index ecba55a35..000000000 --- a/fla/vision_models/abc/modeling_abc.py +++ /dev/null @@ -1,203 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_abc import ABCVisionConfig -from fla.layers.abc import ABCAttention -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class ABCMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class ABCBlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = ABCAttention( - hidden_size=config.hidden_size, - expand_k=config.expand_k, - expand_v=config.expand_v, - num_heads=config.num_heads, - num_slots=config.num_slots, - use_short_conv=config.use_short_conv, - conv_size=config.conv_size, - gate_fn=config.hidden_act, - elementwise_affine=config.elementwise_affine, - norm_eps=config.norm_eps, - clamp_min=config.clamp_min, - clamp_max=config.clamp_max, - fuse_norm=config.fuse_norm, - layer_idx=layer_idx - ) - - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = ABCMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = 
prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class ABCVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = ABCVisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class ABCForImageClassification(ABCVisionPreTrainedModel): - config_class = ABCVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - ABCBlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/bitnet/__init__.py b/fla/vision_models/bitnet/__init__.py deleted file 
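# --- Editor's sketch (illustrative and hedged, not part of the patch) ---
# The blocks being deleted here wrap their token mixer with
# prepare_hidden_states_for_cross_scan / _cross_merge from the (also deleted)
# utils module, whose body is not shown in this hunk. One common way such a
# "bi-scan" is realized for vision sequences is to run the token sequence in both
# directions and merge the results; the sketch below shows only that general idea,
# not this PR's exact implementation:
import torch

def bi_scan(hidden_states: torch.Tensor) -> torch.Tensor:
    # (B, N, C) -> (2B, N, C): second half is the token-reversed sequence
    return torch.cat([hidden_states, hidden_states.flip(1)], dim=0)

def bi_merge(hidden_states: torch.Tensor) -> torch.Tensor:
    # (2B, N, C) -> (B, N, C): undo the flip and average the two directions
    forward, backward = hidden_states.chunk(2, dim=0)
    return 0.5 * (forward + backward.flip(1))

x = torch.randn(2, 196, 64)
assert bi_merge(bi_scan(x)).shape == x.shape
# --- end editor's sketch ---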
mode 100644 index 148b4557f..000000000 --- a/fla/vision_models/bitnet/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.bitnet.configuration_bitnet import BitNetVisionConfig -from fla.vision_models.bitnet.modeling_bitnet import BitNetForImageClassification - -AutoConfig.register(BitNetVisionConfig.model_type, BitNetVisionConfig) -AutoModelForImageClassification.register(BitNetVisionConfig, BitNetForImageClassification) - -__all__ = [ - 'BitNetVisionConfig', - 'BitNetForImageClassification' -] diff --git a/fla/vision_models/bitnet/configuration_bitnet.py b/fla/vision_models/bitnet/configuration_bitnet.py deleted file mode 100644 index 902f3f9a3..000000000 --- a/fla/vision_models/bitnet/configuration_bitnet.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import Dict, Optional - -from transformers.configuration_utils import PretrainedConfig - - -class BitNetVisionConfig(PretrainedConfig): - - model_type = 'bitnet_vision' - - def __init__( - self, - # BitNet core parameters - hidden_size: int = 2048, - num_hidden_layers: int = 24, - num_heads: int = 32, - num_kv_heads: int = None, - window_size: Optional[int] = None, - rope_theta: Optional[float] = 10000., - max_position_embeddings: int = 2048, - hidden_act: str = "swish", - initializer_range: float = 0.02, - elementwise_affine: Optional[bool] = True, - norm_first: bool = False, - norm_eps: float = 1e-6, - use_cache: bool = True, - attention_bias: bool = False, - fuse_norm: bool = True, - fuse_cross_entropy: bool = True, - attn: Optional[Dict] = None, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - qkv_bias: bool = True, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - # Initialize BitNet core parameters - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.window_size = window_size - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - self.hidden_act = hidden_act - - self.initializer_range = initializer_range - self.elementwise_affine = elementwise_affine - self.norm_first = norm_first - self.norm_eps = norm_eps - self.use_cache = use_cache - self.attention_bias = attention_bias - self.fuse_cross_entropy = fuse_cross_entropy - self.fuse_norm = fuse_norm - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.qkv_bias = qkv_bias - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', 
attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/vision_models/bitnet/modeling_bitnet.py b/fla/vision_models/bitnet/modeling_bitnet.py deleted file mode 100644 index fa9675095..000000000 --- a/fla/vision_models/bitnet/modeling_bitnet.py +++ /dev/null @@ -1,201 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_bitnet import BitNetVisionConfig -from fla.layers.bitattn import BitAttention -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class BitNetMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class BitNetBlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - if not config.norm_first: - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = BitAttention( - hidden_size=config.hidden_size, - num_heads=config.num_heads, - num_kv_heads=config.num_kv_heads, - window_size=config.window_size, - rope_theta=config.rope_theta, - max_position_embeddings=config.max_position_embeddings, - norm_first=config.norm_first, - norm_eps=config.norm_eps, - layer_idx=layer_idx - ) - - if not config.norm_first: - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = BitNetMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # 
First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class BitNetVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = BitNetVisionConfig - base_model_prefix = "bitnet" - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class BitNetForImageClassification(BitNetVisionPreTrainedModel): - config_class = BitNetVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - BitNetBlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/delta_net/__init__.py b/fla/vision_models/delta_net/__init__.py deleted file mode 100644 index 
eef31ccbc..000000000 --- a/fla/vision_models/delta_net/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from transformers import AutoConfig, AutoModel, AutoModelForImageClassification, AutoModelForMaskedImageModeling - -from fla.vision_models.delta_net.configuration_delta_net import DeltaNetVisionConfig -from fla.vision_models.delta_net.modeling_delta_net import DeltaNetForImageClassification, DeltaNetVisionModel, DeltaNetForMaskedImageModeling - -AutoConfig.register(DeltaNetVisionConfig.model_type, DeltaNetVisionConfig) -AutoModelForImageClassification.register(DeltaNetVisionConfig, DeltaNetForImageClassification) -AutoModelForMaskedImageModeling.register(DeltaNetVisionConfig, DeltaNetForMaskedImageModeling) -AutoModel.register(DeltaNetVisionConfig, DeltaNetVisionModel) - -__all__ = [ - "DeltaNetVisionConfig", - "DeltaNetForImageClassification", - "DeltaNetVisionModel", - "DeltaNetForMaskedImageModeling" -] diff --git a/fla/vision_models/delta_net/configuration_delta_net.py b/fla/vision_models/delta_net/configuration_delta_net.py deleted file mode 100644 index b24b48908..000000000 --- a/fla/vision_models/delta_net/configuration_delta_net.py +++ /dev/null @@ -1,101 +0,0 @@ -from typing import Dict, Optional -from transformers.configuration_utils import PretrainedConfig - -class DeltaNetVisionConfig(PretrainedConfig): - model_type = 'delta_net_vision' - - def __init__( - self, - # DeltaNet core parameters - attn_mode: str = "chunk", - hidden_size: int = 2048, - expand_k: int = 1, - expand_v: int = 1, - use_gate: bool = False, - use_short_conv: bool = True, - conv_size: int = 4, - use_beta: bool = True, - use_output_norm: bool = True, - num_heads: int = 16, - qk_norm: str = 'l2', - qk_activation: str = 'silu', - intermediate_size: Optional[int] = None, - hidden_act: str = "swish", - num_hidden_layers: int = 12, - norm_first: bool = False, - norm_eps: float = 1e-6, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_cross_entropy: bool = True, - max_position_embeddings: int = 2048, - - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - encoder_stride=16, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - # Initialize DeltaNet core parameters - self.attn_mode = attn_mode - self.hidden_size = hidden_size - self.expand_k = expand_k - self.expand_v = expand_v - self.use_gate = use_gate - self.use_short_conv = use_short_conv - self.conv_size = conv_size - self.use_beta = use_beta - self.use_output_norm = use_output_norm - self.num_heads = num_heads - self.qk_norm = qk_norm - self.qk_activation = qk_activation - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.num_hidden_layers = num_hidden_layers - self.norm_first = norm_first - self.norm_eps = norm_eps - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_cross_entropy = fuse_cross_entropy - self.max_position_embeddings = max_position_embeddings - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = 
use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - self.encoder_stride = encoder_stride - - - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) diff --git a/fla/vision_models/delta_net/modeling_delta_net.py b/fla/vision_models/delta_net/modeling_delta_net.py deleted file mode 100644 index e9ece6adc..000000000 --- a/fla/vision_models/delta_net/modeling_delta_net.py +++ /dev/null @@ -1,385 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput, BaseModelOutput, BaseModelOutputWithPooling, MaskedImageModelingOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_delta_net import DeltaNetVisionConfig -from fla.layers.delta_net import DeltaNet -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class DeltaNetMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class DeltaNetBlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - if not config.norm_first: - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = DeltaNet( - mode=config.attn_mode, - hidden_size=config.hidden_size, - expand_k=config.expand_k, - expand_v=config.expand_v, - num_heads=config.num_heads, - use_gate=config.use_gate, - use_beta=config.use_beta, - use_short_conv=config.use_short_conv, - use_output_norm=config.use_output_norm, - conv_size=config.conv_size, - qk_norm=config.qk_norm, - qk_activation=config.qk_activation, - norm_first=config.norm_first, - norm_eps=config.norm_eps, - layer_idx=layer_idx - ) - - if not config.norm_first: - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = DeltaNetMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - 
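# Illustrative sketch (not from the original files): the `attn` dicts validated in the
# config classes above describe the hybrid layout, i.e. which block indices use a standard
# softmax Attention layer instead of the linear token mixer. The indices and head counts
# below are made-up example values; 'num_kv_heads' and 'window_size' are optional and are
# defaulted exactly as the validation code shows.
hybrid_attn = {
    'layers': [3, 7, 11],     # hypothetical block indices that fall back to softmax attention
    'num_heads': 12,
}
hybrid_attn.setdefault('num_kv_heads', hybrid_attn['num_heads'])   # default: same as num_heads
hybrid_attn.setdefault('window_size', None)                        # default: full (unwindowed) attention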
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class DeltaNetVisionPreTrainedModel(PreTrainedModel): - config_class = DeltaNetVisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - - -class DeltaNetVisionEncoder(nn.Module): - def __init__(self, config) -> None: - super().__init__() - self.config = config - self.blocks = nn.ModuleList([ - DeltaNetBlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - output_attentions: bool = False, - output_hidden_states: bool = False, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - return_dict: bool = True, - **kwargs - ) -> Union[tuple, BaseModelOutput]: - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - - for i, block in enumerate(self.blocks): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - else: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - if output_attentions: - all_self_attentions = all_self_attentions + (attentions,) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, 
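# Illustrative sketch: prepare_hidden_states_for_cross_scan / _cross_merge live in
# fla/vision_models/utils.py and are not shown in this hunk. Under the assumption that
# "bi-scan" simply runs the causal token mixer over the patch sequence in both directions,
# the pair could look roughly like the helpers below (the repository's real implementation
# and its "cross-scan" variant may differ).
import torch

def bi_scan(hidden_states: torch.Tensor) -> torch.Tensor:
    # (B, N, D) -> (2B, N, D): stack the original order with the reversed order
    return torch.cat([hidden_states, hidden_states.flip(1)], dim=0)

def bi_merge(hidden_states: torch.Tensor) -> torch.Tensor:
    # (2B, N, D) -> (B, N, D): undo the reversal on the second half and average both passes
    forward_pass, backward_pass = hidden_states.chunk(2, dim=0)
    return 0.5 * (forward_pass + backward_pass.flip(1))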
all_self_attentions] if v is not None) - - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - -class DeltaNetVisionModel(DeltaNetVisionPreTrainedModel): - def __init__(self, config, add_pooling_layer=True, use_mask_token=False): - super().__init__(config) - self.config = config - self.embeddings = ImageEmbeddings(config, use_mask_token=use_mask_token) - self.encoder = DeltaNetVisionEncoder(config) - self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) if add_pooling_layer else None - self.init_weights() - - def get_input_embeddings(self): - return self.embeddings.patch_embeddings - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - bool_masked_pos: Optional[torch.BoolTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - interpolate_pos_encoding: Optional[bool] = None, - use_cache: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs - ) -> Union[Tuple, BaseModelOutputWithPooling]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding) - - encoder_outputs = self.encoder( - hidden_states, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - return_dict=return_dict, - **kwargs - ) - - sequence_output = encoder_outputs[0] - sequence_output = self.layernorm(sequence_output) - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - - if not return_dict: - head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) - return head_outputs + encoder_outputs[1:] - - return BaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - -class DeltaNetForImageClassification(DeltaNetVisionPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - self.backbone = DeltaNetVisionModel(config, add_pooling_layer=True) # Here we should use mean pooling - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.backbone( - pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - return_dict=return_dict, - ) - - pooled_output = 
outputs.pooler_output - logits = self.classifier(pooled_output) # only use mean pooling - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - -class DeltaNetForMaskedImageModeling(DeltaNetVisionPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.backbone = DeltaNetVisionModel(config, add_pooling_layer=False, use_mask_token=True) - self.decoder = nn.Sequential( - nn.Conv2d( - in_channels=config.hidden_size, - out_channels=config.encoder_stride**2 * config.num_channels, - kernel_size=1, - ), - nn.PixelShuffle(config.encoder_stride), - ) - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - bool_masked_pos: Optional[torch.BoolTensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple, MaskedImageModelingOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): - raise ValueError( - "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " - "the reconstructed image has the same dimensions as the input. " - f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." 
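# Shape sketch for the masked-image-modeling decoder defined above: the 1x1 convolution
# expands the hidden dimension to encoder_stride**2 * num_channels, and PixelShuffle folds
# those channels back into pixels, so the (H/P, W/P) token grid is upsampled to the full
# (H, W) image. The sizes below are toy values chosen for the example, not config defaults.
import torch
from torch import nn

hidden_size, stride, channels = 192, 16, 3
decoder = nn.Sequential(
    nn.Conv2d(hidden_size, stride ** 2 * channels, kernel_size=1),
    nn.PixelShuffle(stride),
)
tokens = torch.randn(2, hidden_size, 14, 14)   # (B, D, H/P, W/P) for a 224x224 image, P=16
print(decoder(tokens).shape)                   # torch.Size([2, 3, 224, 224])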
- ) - - outputs = self.backbone( - pixel_values, - bool_masked_pos=bool_masked_pos, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - return_dict=return_dict, - ) - - - sequence_output = outputs[0] - batch_size, sequence_length, num_channels = sequence_output.shape - height = width = math.floor(sequence_length**0.5) - sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) - - # Reconstruct pixel values - reconstructed_pixel_values = self.decoder(sequence_output) - - masked_im_loss = None - if bool_masked_pos is not None: - size = self.config.image_size // self.config.patch_size - bool_masked_pos = bool_masked_pos.reshape(-1, size, size) - mask = ( - bool_masked_pos.repeat_interleave(self.config.patch_size, 1) - .repeat_interleave(self.config.patch_size, 2) - .unsqueeze(1) - .contiguous() - ) - reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") - masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels - - if not return_dict: - output = (reconstructed_pixel_values,) + outputs[1:] - return ((masked_im_loss,) + output) if masked_im_loss is not None else output - - return MaskedImageModelingOutput( - loss=masked_im_loss, - reconstruction=reconstructed_pixel_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/fla/vision_models/gated_deltanet/__init__.py b/fla/vision_models/gated_deltanet/__init__.py deleted file mode 100644 index 45bb5ffbf..000000000 --- a/fla/vision_models/gated_deltanet/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetVisionConfig -from fla.vision_models.gated_deltanet.modeling_gated_deltanet import GatedDeltaNetForImageClassification - -AutoConfig.register(GatedDeltaNetVisionConfig.model_type, GatedDeltaNetVisionConfig) -AutoModelForImageClassification.register(GatedDeltaNetVisionConfig, GatedDeltaNetForImageClassification) - -__all__ = [ - 'GatedDeltaNetVisionConfig', - 'GatedDeltaNetForImageClassification' -] - diff --git a/fla/vision_models/gated_deltanet/configuration_gated_deltanet.py b/fla/vision_models/gated_deltanet/configuration_gated_deltanet.py deleted file mode 100644 index 6cbbd9e72..000000000 --- a/fla/vision_models/gated_deltanet/configuration_gated_deltanet.py +++ /dev/null @@ -1,89 +0,0 @@ -from typing import Dict, Optional -from transformers.configuration_utils import PretrainedConfig - -class GatedDeltaNetVisionConfig(PretrainedConfig): - model_type = 'gated_deltanet_vision' - - def __init__( - self, - # GatedDeltaNet core parameters - attn_mode: str = "chunk", - hidden_size: int = 2048, - expand_v: int = 2, - use_gate: bool = True, - use_short_conv: bool = True, - conv_size: int = 4, - head_dim: int = 256, - num_heads: int = 6, - max_position_embeddings: int = 2048, - hidden_act: str = "swish", - num_hidden_layers: int = 21, - norm_first: bool = False, - norm_eps: float = 1e-6, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_cross_entropy: bool = True, - - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - 
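# Worked example of the mask construction used in the masked-image-modeling loss above:
# each True entry of bool_masked_pos marks one masked patch, and the two repeat_interleave
# calls grow it into a patch_size x patch_size pixel block that gates the per-pixel L1 loss.
# Toy sizes only (a 2x2 grid of 4x4 patches).
import torch

patch_size, grid_size = 4, 2
bool_masked_pos = torch.tensor([[True, False, False, True]])   # (B, num_patches)
mask = (
    bool_masked_pos.reshape(-1, grid_size, grid_size)
    .repeat_interleave(patch_size, 1)
    .repeat_interleave(patch_size, 2)
    .unsqueeze(1)                                              # (B, 1, H, W), broadcasts over channels
)
print(mask.shape)   # torch.Size([1, 1, 8, 8])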
interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", - **kwargs - ): - # Initialize GatedDeltaNet core parameters - self.attn_mode = attn_mode - self.hidden_size = hidden_size - self.expand_v = expand_v - self.head_dim = head_dim - self.use_gate = use_gate - self.use_short_conv = use_short_conv - self.conv_size = conv_size - self.num_heads = num_heads - self.hidden_act = hidden_act - self.num_hidden_layers = num_hidden_layers - self.norm_first = norm_first - self.norm_eps = norm_eps - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_cross_entropy = fuse_cross_entropy - self.attn = attn - self.max_position_embeddings = max_position_embeddings - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) diff --git a/fla/vision_models/gated_deltanet/modeling_gated_deltanet.py b/fla/vision_models/gated_deltanet/modeling_gated_deltanet.py deleted file mode 100644 index 94694cca8..000000000 --- a/fla/vision_models/gated_deltanet/modeling_gated_deltanet.py +++ /dev/null @@ -1,202 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_gated_deltanet import GatedDeltaNetVisionConfig -from fla.layers.gated_deltanet import GatedDeltaNet -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class GatedDeltaNetMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class GatedDeltaNetBlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - if not config.norm_first: - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - 
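# All the *Block classes in these files share the same layout once the cache/attention
# plumbing is stripped away: when norm_first is False they build ln_1/ln_2 explicitly and
# apply the usual two pre-norm residual branches. A condensed sketch of that pattern, with
# token_mixer and mlp standing in for the DeltaNet/GLA/GSA/... layers and the *MLP modules:
import torch
from torch import nn

class PreNormBlock(nn.Module):
    def __init__(self, dim: int, token_mixer: nn.Module, mlp: nn.Module, eps: float = 1e-6):
        super().__init__()
        self.ln_1 = nn.LayerNorm(dim, eps=eps)
        self.ln_2 = nn.LayerNorm(dim, eps=eps)
        self.token_mixer = token_mixer
        self.mlp = mlp

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.token_mixer(self.ln_1(x))   # first residual: sequence/token mixing
        x = x + self.mlp(self.ln_2(x))           # second residual: channel MLP
        return x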
num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = GatedDeltaNet( - mode=config.attn_mode, - hidden_size=config.hidden_size, - expand_v=config.expand_v, - head_dim=config.head_dim, - num_heads=config.num_heads, - use_gate=config.use_gate, - use_short_conv=config.use_short_conv, - conv_size=config.conv_size, - norm_first=config.norm_first, - norm_eps=config.norm_eps, - layer_idx=layer_idx - ) - - if not config.norm_first: - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = GatedDeltaNetMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class GatedDeltaNetVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = GatedDeltaNetVisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class GatedDeltaNetForImageClassification(GatedDeltaNetVisionPreTrainedModel): - config_class = GatedDeltaNetVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - GatedDeltaNetBlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - 
pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/gla/__init__.py b/fla/vision_models/gla/__init__.py deleted file mode 100644 index dc7d6e93c..000000000 --- a/fla/vision_models/gla/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.gla.configuration_gla import GLAVisionConfig -from fla.vision_models.gla.modeling_gla import GLAForImageClassification - -AutoConfig.register(GLAVisionConfig.model_type, GLAVisionConfig) -AutoModelForImageClassification.register(GLAVisionConfig, GLAForImageClassification) - -__all__ = [ - 'GLAVisionConfig', - 'GLAForImageClassification' -] diff --git a/fla/vision_models/gla/configuration_gla.py b/fla/vision_models/gla/configuration_gla.py deleted file mode 100644 index af52bbe6f..000000000 --- a/fla/vision_models/gla/configuration_gla.py +++ /dev/null @@ -1,101 +0,0 @@ -from typing import Dict, Optional - -from transformers.configuration_utils import PretrainedConfig - -class GLAVisionConfig(PretrainedConfig): - - model_type = 'gla_vision' - - def __init__( - self, - # GLA core parameters - hidden_size: int = 2048, - expand_k: int = 0.5, - expand_v: int = 1, - num_hidden_layers: int = 24, - num_heads: int = 4, - num_kv_heads: Optional[int] = None, - feature_map: Optional[str] = None, - attn_mode: str = "chunk", - use_short_conv: bool = False, - conv_size: int = 4, - use_output_gate: bool = True, - clamp_min: Optional[float] = None, - hidden_act: str = "swish", - max_position_embeddings: int = 2048, - elementwise_affine: Optional[bool] = True, - norm_eps: float = 1e-6, - use_gk: bool = True, - use_gv: bool = False, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_norm: bool = True, - fuse_cross_entropy: bool = True, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = 
False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - # Initialize DeltaNet core parameters - self.hidden_size = hidden_size - self.expand_k = expand_k - self.expand_v = expand_v - self.num_hidden_layers = num_hidden_layers - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.feature_map = feature_map - self.attn_mode = attn_mode - self.use_short_conv = use_short_conv - self.conv_size = conv_size - self.use_output_gate = use_output_gate - self.clamp_min = clamp_min - self.hidden_act = hidden_act - self.max_position_embeddings = max_position_embeddings - self.elementwise_affine = elementwise_affine - self.norm_eps = norm_eps - self.use_gk = use_gk - self.use_gv = use_gv - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_norm = fuse_norm - self.fuse_cross_entropy = fuse_cross_entropy - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) diff --git a/fla/vision_models/gla/modeling_gla.py b/fla/vision_models/gla/modeling_gla.py deleted file mode 100644 index 311e000a0..000000000 --- a/fla/vision_models/gla/modeling_gla.py +++ /dev/null @@ -1,205 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_gla import GLAVisionConfig -from fla.layers.gla import GatedLinearAttention -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class GLAMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class GLABlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in 
config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = GatedLinearAttention( - mode=config.attn_mode, - hidden_size=config.hidden_size, - expand_k=config.expand_k, - expand_v=config.expand_v, - num_heads=config.num_heads, - num_kv_heads=config.num_kv_heads, - feature_map=config.feature_map, - use_short_conv=config.use_short_conv, - conv_size=config.conv_size, - use_output_gate=config.use_output_gate, - gate_fn=config.hidden_act, - elementwise_affine=config.elementwise_affine, - norm_eps=config.norm_eps, - clamp_min=config.clamp_min, - fuse_norm=config.fuse_norm, - layer_idx=layer_idx - ) - - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = GLAMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class GLAVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = GLAVisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class GLAForImageClassification(GLAVisionPreTrainedModel): - config_class = GLAVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - GLABlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - 
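# Pooler is imported from fla/vision_models/utils.py, which is not part of this hunk.
# Going by the "use mean pooling" comments in the classification heads, it presumably
# collapses the (B, N, D) patch sequence into a single (B, D) feature, roughly like this sketch:
import torch
from torch import nn

class MeanPooler(nn.Module):
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states.mean(dim=1)   # average over the patch/token dimension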
self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/gsa/__init__.py b/fla/vision_models/gsa/__init__.py deleted file mode 100644 index 3da164504..000000000 --- a/fla/vision_models/gsa/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.gsa.configuration_gsa import GSAVisionConfig -from fla.vision_models.gsa.modeling_gsa import GSAForImageClassification - -AutoConfig.register(GSAVisionConfig.model_type, GSAVisionConfig) -AutoModelForImageClassification.register(GSAVisionConfig, GSAForImageClassification) - -__all__ = [ - 'GSAVisionConfig', - 'GSAForImageClassification' -] diff --git a/fla/vision_models/gsa/configuration_gsa.py b/fla/vision_models/gsa/configuration_gsa.py deleted file mode 100644 index de4bbcb8d..000000000 --- a/fla/vision_models/gsa/configuration_gsa.py +++ /dev/null @@ -1,107 +0,0 @@ -from typing import Dict, Optional - -from transformers.configuration_utils import PretrainedConfig - - -class GSAVisionConfig(PretrainedConfig): - - model_type = 'gsa_vision' - - def __init__( - self, - # GSA core parameters - hidden_size: int = 2048, - gate_logit_normalizer: Optional[int] = 8, - clamp_min: Optional[float] = None, - clamp_max: Optional[float] = None, - num_hidden_layers: int = 24, - num_heads: int = 4, - num_kv_heads: Optional[int] = None, - num_slots: Optional[int] = 64, - use_short_conv: bool = False, - conv_size: int = 4, - exapnd_k: float = 1, - exapnd_v: float = 1, - feature_map: str = 'swish', - use_output_gate: bool = False, - use_norm: bool = True, - max_position_embeddings: int = 2048, - hidden_act: str = "swish", - elementwise_affine: Optional[bool] = True, - norm_first: bool = True, - norm_eps: float = 1e-6, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_norm: bool = True, - fuse_cross_entropy: bool = 
True, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - qkv_bias: bool = True, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - self.hidden_size = hidden_size - self.gate_logit_normalizer = gate_logit_normalizer - self.clamp_min = clamp_min - self.clamp_max = clamp_max - self.num_hidden_layers = num_hidden_layers - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.num_slots = num_slots - self.use_short_conv = use_short_conv - self.conv_size = conv_size - self.expand_k = exapnd_k - self.expand_v = exapnd_v - self.feature_map = feature_map - self.use_output_gate = use_output_gate - self.use_norm = use_norm - self.max_position_embeddings = max_position_embeddings - self.hidden_act = hidden_act - self.elementwise_affine = elementwise_affine - self.norm_first = norm_first - self.norm_eps = norm_eps - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_cross_entropy = fuse_cross_entropy - self.fuse_norm = fuse_norm - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.qkv_bias = qkv_bias - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) diff --git a/fla/vision_models/gsa/modeling_gsa.py b/fla/vision_models/gsa/modeling_gsa.py deleted file mode 100644 index 856eea93d..000000000 --- a/fla/vision_models/gsa/modeling_gsa.py +++ /dev/null @@ -1,209 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_gsa import GSAVisionConfig -from fla.layers.gsa import GatedSlotAttention -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class GSAMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - 
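# Quick sanity check on the default vision geometry shared by these configs: a 224x224
# image cut into 16x16 patches gives 196 tokens, well under the max_position_embeddings
# default of 2048 carried over from the language-model configs.
image_size, patch_size = 224, 16
num_patches = (image_size // patch_size) ** 2
print(num_patches)   # 196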
nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class GSABlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - if not config.norm_first: - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = GatedSlotAttention( - hidden_size=config.hidden_size, - expand_k=config.expand_k, - expand_v=config.expand_v, - num_heads=config.num_heads, - num_kv_heads=config.num_kv_heads, - num_slots=config.num_slots, - use_short_conv=config.use_short_conv, - conv_size=config.conv_size, - feature_map=config.feature_map, - use_output_gate=config.use_output_gate, - use_norm=config.use_norm, - gate_fn=config.hidden_act, - gate_logit_normalizer=config.gate_logit_normalizer, - elementwise_affine=config.elementwise_affine, - norm_first=config.norm_first, - norm_eps=config.norm_eps, - fuse_norm=config.fuse_norm, - layer_idx=layer_idx - ) - - if not config.norm_first: - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = GSAMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class GSAVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = GSAVisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - 
module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class GSAForImageClassification(GSAVisionPreTrainedModel): - config_class = GSAVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - GSABlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/hgrn/__init__.py b/fla/vision_models/hgrn/__init__.py deleted file mode 100644 index e9ab00ae0..000000000 --- a/fla/vision_models/hgrn/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.hgrn.configuration_hgrn import HGRNVisionConfig -from fla.vision_models.hgrn.modeling_hgrn import HGRNForImageClassification - -AutoConfig.register(HGRNVisionConfig.model_type, HGRNVisionConfig) -AutoModelForImageClassification.register(HGRNVisionConfig, HGRNForImageClassification) - -__all__ = [ - 'HGRNVisionConfig', - 'HGRNForImageClassification' -] diff --git a/fla/vision_models/hgrn/configuration_hgrn.py b/fla/vision_models/hgrn/configuration_hgrn.py deleted file mode 100644 index e9724239b..000000000 --- a/fla/vision_models/hgrn/configuration_hgrn.py +++ /dev/null @@ -1,86 +0,0 @@ -from typing import Dict, Optional - -from transformers.configuration_utils import PretrainedConfig - - -class HGRNVisionConfig(PretrainedConfig): - - model_type = 'hgrn_vision' - - def __init__( - self, - # HGRN core parameters - attn_mode: str = "chunk", - hidden_size: int = 2048, - num_hidden_layers: int = 24, - expand_ratio: Optional[int] = 1, - use_short_conv: bool = 
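# The __init__.py files in these directories register each vision config with the
# transformers Auto classes, so a model can be built generically once the module is
# imported. Hypothetical usage sketch: the import path reflects the layout being removed
# in this patch, so adjust it to wherever the module lives after the refactor, and the
# keyword values are arbitrary example numbers.
from transformers import AutoModelForImageClassification
from fla.vision_models.hgrn import HGRNVisionConfig   # importing triggers the registration

config = HGRNVisionConfig(hidden_size=768, num_hidden_layers=12, num_classes=100)
model = AutoModelForImageClassification.from_config(config)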
False, - conv_size: int = 4, - use_lower_bound: bool = True, - max_position_embeddings: int = 2048, - hidden_act: str = "swish", - elementwise_affine: Optional[bool] = True, - norm_eps: float = 1e-6, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_cross_entropy: bool = True, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - qkv_bias: bool = True, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - # Initialize HGRN core parameters - self.attn_mode = attn_mode - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.expand_ratio = expand_ratio - self.use_short_conv = use_short_conv - self.conv_size = conv_size - self.use_lower_bound = use_lower_bound - self.max_position_embeddings = max_position_embeddings - self.elementwise_affine = elementwise_affine - self.norm_eps = norm_eps - self.hidden_act = hidden_act - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_cross_entropy = fuse_cross_entropy - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.qkv_bias = qkv_bias - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) diff --git a/fla/vision_models/hgrn/modeling_hgrn.py b/fla/vision_models/hgrn/modeling_hgrn.py deleted file mode 100644 index 35d6e21bf..000000000 --- a/fla/vision_models/hgrn/modeling_hgrn.py +++ /dev/null @@ -1,197 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_hgrn import HGRNVisionConfig -from fla.layers.hgrn import HGRNAttention -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class HGRNMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, 
eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class HGRNBlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = HGRNAttention( - mode=config.attn_mode, - hidden_size=config.hidden_size, - expand_ratio=config.expand_ratio, - use_short_conv=config.use_short_conv, - conv_size=config.conv_size, - elementwise_affine=config.elementwise_affine, - norm_eps=config.norm_eps, - layer_idx=layer_idx - ) - - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = HGRNMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class HGRNVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = HGRNVisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class HGRNForImageClassification(HGRNVisionPreTrainedModel): - config_class = HGRNVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = 
nn.ModuleList([ - HGRNBlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/hgrn2/__init__.py b/fla/vision_models/hgrn2/__init__.py deleted file mode 100644 index 69a2c9c55..000000000 --- a/fla/vision_models/hgrn2/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.hgrn2.configuration_hgrn2 import HGRN2VisionConfig -from fla.vision_models.hgrn2.modeling_hgrn2 import HGRN2ForImageClassification - -AutoConfig.register(HGRN2VisionConfig.model_type, HGRN2VisionConfig) -AutoModelForImageClassification.register(HGRN2VisionConfig, HGRN2ForImageClassification) - -__all__ = [ - 'HGRN2VisionConfig', - 'HGRN2ForImageClassification' -] diff --git a/fla/vision_models/hgrn2/configuration_hgrn2.py b/fla/vision_models/hgrn2/configuration_hgrn2.py deleted file mode 100644 index ef6ffc83b..000000000 --- a/fla/vision_models/hgrn2/configuration_hgrn2.py +++ /dev/null @@ -1,89 +0,0 @@ -from typing import Dict, Optional - -from transformers.configuration_utils import PretrainedConfig - - -class HGRN2VisionConfig(PretrainedConfig): - - model_type = 'hgrn2_vision' - - def __init__( - self, - # HGRN2 core parameters - hidden_size: int = 2048, - num_hidden_layers: int = 24, - attn_mode: str = "chunk", - num_heads: Optional[int] = None, - expand_ratio: Optional[int] = 128, - use_short_conv: bool = False, - conv_size: int = 4, - use_lower_bound: bool = True, - hidden_act: str = "swish", - max_position_embeddings: int = 2048, - elementwise_affine: Optional[bool] = True, - norm_eps: float = 1e-6, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_cross_entropy: bool = True, - # 
Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - qkv_bias: bool = True, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - # Initialize HGRN2 core parameters - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.attn_mode = attn_mode - self.num_heads = num_heads - self.expand_ratio = expand_ratio - self.use_short_conv = use_short_conv - self.conv_size = conv_size - self.use_lower_bound = use_lower_bound - self.max_position_embeddings = max_position_embeddings - self.hidden_act = hidden_act - self.elementwise_affine = elementwise_affine - self.norm_eps = norm_eps - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_cross_entropy = fuse_cross_entropy - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.qkv_bias = qkv_bias - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/vision_models/hgrn2/modeling_hgrn2.py b/fla/vision_models/hgrn2/modeling_hgrn2.py deleted file mode 100644 index cbae1a64a..000000000 --- a/fla/vision_models/hgrn2/modeling_hgrn2.py +++ /dev/null @@ -1,198 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_hgrn2 import HGRN2VisionConfig -from fla.layers.hgrn2 import HGRN2Attention -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class HGRN2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class HGRN2Block(nn.Module): - def 
__init__(self, config, layer_idx: int): - super().__init__() - - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = HGRN2Attention( - mode=config.attn_mode, - hidden_size=config.hidden_size, - num_heads=config.num_heads, - expand_ratio=config.expand_ratio, - use_short_conv=config.use_short_conv, - conv_size=config.conv_size, - elementwise_affine=config.elementwise_affine, - norm_eps=config.norm_eps, - layer_idx=layer_idx - ) - - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = HGRN2MLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class HGRN2VisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = HGRN2VisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class HGRN2ForImageClassification(HGRN2VisionPreTrainedModel): - config_class = HGRN2VisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - HGRN2Block(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = 
nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/linear_attn/__init__.py b/fla/vision_models/linear_attn/__init__.py deleted file mode 100644 index d56bc5e04..000000000 --- a/fla/vision_models/linear_attn/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.linear_attn.configuration_linear_attn import LinearAttentionVisionConfig -from fla.vision_models.linear_attn.modeling_linear_attn import LinearAttentionForImageClassification - -AutoConfig.register(LinearAttentionVisionConfig.model_type, LinearAttentionVisionConfig) -AutoModelForImageClassification.register(LinearAttentionVisionConfig, LinearAttentionForImageClassification) - -__all__ = [ - 'LinearAttentionVisionConfig', - 'LinearAttentionForImageClassification' -] diff --git a/fla/vision_models/linear_attn/configuration_linear_attn.py b/fla/vision_models/linear_attn/configuration_linear_attn.py deleted file mode 100644 index 8aa0d2ef0..000000000 --- a/fla/vision_models/linear_attn/configuration_linear_attn.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import Dict, Optional - -from transformers.configuration_utils import PretrainedConfig - - -class LinearAttentionVisionConfig(PretrainedConfig): - - model_type = 'linear_attn_vision' - - def __init__( - self, - # LinearAttention core parameters - attn_mode: str = "fused_chunk", - hidden_size: int = 2048, - expand_k: int = 1, - expand_v: int = 1, - num_hidden_layers: int = 24, - num_heads: int = 4, - num_kv_heads: Optional[int] = None, - feature_map: str = "elementwise_product", - tie_feature_map_qk: bool = False, - norm_q: bool = False, - norm_k: bool = False, - norm_feature_map: bool = False, - hidden_act: str = "swish", - max_position_embeddings: int = 2048, - elementwise_affine: Optional[bool] = True, - norm_eps: float = 1e-6, - attn: Optional[Dict] = None, - use_cache: bool = True, - 
initializer_range: float = 0.02, - fuse_cross_entropy: bool = True, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - qkv_bias: bool = True, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - # Initialize LinearAttention core parameters - self.attn_mode = attn_mode - self.hidden_size = hidden_size - self.expand_k = expand_k - self.expand_v = expand_v - self.num_hidden_layers = num_hidden_layers - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.feature_map = feature_map - self.tie_feature_map_qk = tie_feature_map_qk - self.norm_q = norm_q - self.norm_k = norm_k - self.norm_feature_map = norm_feature_map - self.max_position_embeddings = max_position_embeddings - self.elementwise_affine = elementwise_affine - self.norm_eps = norm_eps - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_cross_entropy = fuse_cross_entropy - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.qkv_bias = qkv_bias - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/vision_models/linear_attn/modeling_linear_attn.py b/fla/vision_models/linear_attn/modeling_linear_attn.py deleted file mode 100644 index f0889a493..000000000 --- a/fla/vision_models/linear_attn/modeling_linear_attn.py +++ /dev/null @@ -1,197 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_linear_attn import LinearAttentionVisionConfig -from fla.layers.linear_attn import LinearAttention -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class LinearAttentionMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - 
nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class LinearAttentionBlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = LinearAttention( - mode=config.attn_mode, - hidden_size=config.hidden_size, - expand_k=config.expand_k, - expand_v=config.expand_v, - num_heads=config.num_heads, - num_kv_heads=config.num_kv_heads, - feature_map=config.feature_map, - tie_feature_map_qk=config.tie_feature_map_qk, - norm_q=config.norm_q, - norm_k=config.norm_k, - do_feature_map_norm=config.norm_feature_map, - elementwise_affine=config.elementwise_affine, - norm_eps=config.norm_eps, - layer_idx=layer_idx - ) - - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = LinearAttentionMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states = self.attn(hidden_states) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - return outputs - -class LinearAttentionVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = LinearAttentionVisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class LinearAttentionForImageClassification(LinearAttentionVisionPreTrainedModel): - config_class = LinearAttentionVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - 
self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - LinearAttentionBlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/retnet/__init__.py b/fla/vision_models/retnet/__init__.py deleted file mode 100644 index 4a32b420f..000000000 --- a/fla/vision_models/retnet/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.retnet.configuration_retnet import RetNetVisionConfig -from fla.vision_models.retnet.modeling_retnet import RetNetForImageClassification - -AutoConfig.register(RetNetVisionConfig.model_type, RetNetVisionConfig) -AutoModelForImageClassification.register(RetNetVisionConfig, RetNetForImageClassification) - -__all__ = [ - 'RetNetVisionConfig', - 'RetNetForImageClassification' -] diff --git a/fla/vision_models/retnet/configuration_retnet.py b/fla/vision_models/retnet/configuration_retnet.py deleted file mode 100644 index 53df13698..000000000 --- a/fla/vision_models/retnet/configuration_retnet.py +++ /dev/null @@ -1,101 +0,0 @@ -from __future__ import annotations - -from typing import Dict, Optional - -from transformers.configuration_utils import PretrainedConfig - - -class RetNetVisionConfig(PretrainedConfig): - - model_type = 'retnet_vision' - - def __init__( - self, - # RetNet core parameters - attn_mode: str = "chunk", - hidden_size: int = 2048, - expand_k: int = 1, - expand_v: int = 2, - num_hidden_layers: int = 24, - num_heads: int = 8, - num_kv_heads: Optional[int] = None, - feature_map: Optional[str] = None, - hidden_act: str = "swish", - use_short_conv: bool = False, - conv_size: int = 4, - use_output_gate: bool = True, - max_position_embeddings: int = 2048, - elementwise_affine: 
Optional[bool] = True, - norm_eps: float = 1e-6, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_norm: bool = True, - fuse_cross_entropy: bool = True, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - qkv_bias: bool = True, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ) -> RetNetVisionConfig: - # Initialize RetNet core parameters - self.attn_mode = attn_mode - self.hidden_size = hidden_size - self.expand_k = expand_k - self.expand_v = expand_v - self.num_hidden_layers = num_hidden_layers - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.feature_map = feature_map - self.hidden_act = hidden_act - self.use_short_conv = use_short_conv - self.conv_size = conv_size - self.use_output_gate = use_output_gate - self.hidden_act = hidden_act - self.max_position_embeddings = max_position_embeddings - self.elementwise_affine = elementwise_affine - self.norm_eps = norm_eps - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_norm = fuse_norm - self.fuse_cross_entropy = fuse_cross_entropy - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.qkv_bias = qkv_bias - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/vision_models/retnet/modeling_retnet.py b/fla/vision_models/retnet/modeling_retnet.py deleted file mode 100644 index 961ea7c71..000000000 --- a/fla/vision_models/retnet/modeling_retnet.py +++ /dev/null @@ -1,202 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_retnet import RetNetVisionConfig -from fla.layers.multiscale_retention import MultiScaleRetention -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - 
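For reference, every backbone removed in this patch exposes the same ViT-style interface: ImageEmbeddings turns pixels into patch tokens, a stack of blocks mixes them, and a final LayerNorm, mean-pooling Pooler, and linear classifier produce the logits. A minimal usage sketch for the RetNet variant deleted here (illustrative only; it assumes the pre-deletion package layout shown in this diff and a deliberately small, made-up config) looks like this:

    import torch
    from fla.vision_models.retnet import RetNetVisionConfig, RetNetForImageClassification

    # Small illustrative configuration; the field names come from the
    # (removed) RetNetVisionConfig defined earlier in this patch, but the
    # specific values below are arbitrary examples.
    config = RetNetVisionConfig(
        hidden_size=192,
        num_hidden_layers=2,
        num_heads=4,
        image_size=224,
        patch_size=16,
        num_classes=10,
    )
    model = RetNetForImageClassification(config)

    pixel_values = torch.randn(2, 3, 224, 224)   # (batch, channels, height, width)
    labels = torch.randint(0, 10, (2,))
    outputs = model(pixel_values=pixel_values, labels=labels)
    # outputs.logits should have shape (2, 10); outputs.loss is the cross-entropy loss.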
-class RetNetMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class RetNetBlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = MultiScaleRetention( - mode=config.attn_mode, - hidden_size=config.hidden_size, - expand_k=config.expand_k, - expand_v=config.expand_v, - num_heads=config.num_heads, - num_kv_heads=config.num_kv_heads, - feature_map=config.feature_map, - use_output_gate=config.use_output_gate, - gate_fn=config.hidden_act, - elementwise_affine=config.elementwise_affine, - norm_eps=config.norm_eps, - fuse_norm=config.fuse_norm, - layer_idx=layer_idx - ) - - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = RetNetMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class RetNetVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = RetNetVisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, 
- ).to(module.position_embeddings.dtype) - -class RetNetForImageClassification(RetNetVisionPreTrainedModel): - config_class = RetNetVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - RetNetBlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/rwkv6/__init__.py b/fla/vision_models/rwkv6/__init__.py deleted file mode 100644 index 2df666ac4..000000000 --- a/fla/vision_models/rwkv6/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.rwkv6.configuration_rwkv6 import RWKV6VisionConfig -from fla.vision_models.rwkv6.modeling_rwkv6 import RWKV6ForImageClassification - -AutoConfig.register(RWKV6VisionConfig.model_type, RWKV6VisionConfig) -AutoModelForImageClassification.register(RWKV6VisionConfig, RWKV6ForImageClassification) - -__all__ = [ - 'RWKV6VisionConfig', - 'RWKV6ForImageClassification' -] diff --git a/fla/vision_models/rwkv6/configuration_rwkv6.py b/fla/vision_models/rwkv6/configuration_rwkv6.py deleted file mode 100644 index 9e4d54cd0..000000000 --- a/fla/vision_models/rwkv6/configuration_rwkv6.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import Dict, Optional - -from transformers.configuration_utils import PretrainedConfig - - -class RWKV6VisionConfig(PretrainedConfig): - - model_type = 'rwkv6_vision' - - def __init__( - self, - # RWKV6 core parameters - attn_mode: str = "chunk", - hidden_size: int = 2048, - expand_k: int = 0.5, - expand_v: int = 1, - num_hidden_layers: int = 24, - num_heads: int = 4, - proj_low_rank_dim: int = 32, - gate_low_rank_dim: int = 
64, - hidden_act: str = "sqrelu", - max_position_embeddings: int = 2048, - norm_first: bool = True, - norm_bias: bool = True, - norm_eps: float = 1e-5, - attn: Optional[Dict] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - fuse_norm: bool = True, - fuse_cross_entropy: bool = True, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - qkv_bias: bool = True, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - # Initialize RWKV6 core parameters - self.attn_mode = attn_mode - self.hidden_size = hidden_size - self.expand_k = expand_k - self.expand_v = expand_v - self.norm_first = norm_first - self.num_hidden_layers = num_hidden_layers - self.num_heads = num_heads - self.proj_low_rank_dim = proj_low_rank_dim - self.gate_low_rank_dim = gate_low_rank_dim - self.hidden_act = hidden_act - self.max_position_embeddings = max_position_embeddings - self.norm_bias = norm_bias - self.norm_eps = norm_eps - self.use_cache = use_cache - self.initializer_range = initializer_range - self.fuse_norm = fuse_norm - self.fuse_cross_entropy = fuse_cross_entropy - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.qkv_bias = qkv_bias - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - if attn is not None: - if not isinstance(attn, Dict): - raise ValueError("attn must be a dictionary") - if 'layers' not in attn: - raise ValueError("Layer indices must be provided to initialize hybrid attention layers") - if 'num_heads' not in attn: - raise ValueError("Number of heads must be provided to initialize hybrid attention layers") - attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) - attn['window_size'] = attn.get('window_size', None) - - self.attn = attn - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/vision_models/rwkv6/modeling_rwkv6.py b/fla/vision_models/rwkv6/modeling_rwkv6.py deleted file mode 100644 index 45c4df011..000000000 --- a/fla/vision_models/rwkv6/modeling_rwkv6.py +++ /dev/null @@ -1,199 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from fla.layers.rwkv6 import RWKV6Attention -from .configuration_rwkv6 import RWKV6VisionConfig -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class RWKV6MLP(nn.Module): - def __init__(self, config): 
- super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class RWKV6Block(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.attn is not None and layer_idx in config.attn['layers']: - self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.attn['num_heads'], - num_kv_heads=config.attn['num_kv_heads'], - window_size=config.attn['window_size'], - max_position_embeddings=config.max_position_embeddings, - layer_idx=layer_idx - ) - else: - self.attn = RWKV6Attention( - mode=config.attn_mode, - hidden_size=config.hidden_size, - expand_k=config.expand_k, - expand_v=config.expand_v, - num_heads=config.num_heads, - proj_low_rank_dim=config.proj_low_rank_dim, - gate_low_rank_dim=config.gate_low_rank_dim, - norm_eps=config.norm_eps, - fuse_norm=config.fuse_norm, - layer_idx=layer_idx - ) - - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = RWKV6MLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class RWKV6VisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = RWKV6VisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class RWKV6ForImageClassification(RWKV6VisionPreTrainedModel): - config_class = RWKV6VisionConfig - - def __init__(self, 
config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - RWKV6Block(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/transformer/__init__.py b/fla/vision_models/transformer/__init__.py deleted file mode 100644 index 25d4e9d2b..000000000 --- a/fla/vision_models/transformer/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoConfig, AutoModelForImageClassification - -from fla.vision_models.transformer.configuration_transformer import TransformerVisionConfig -from fla.vision_models.transformer.modeling_transformer import TransformerForImageClassification - -AutoConfig.register(TransformerVisionConfig.model_type, TransformerVisionConfig) -AutoModelForImageClassification.register(TransformerVisionConfig, TransformerForImageClassification) - -__all__ = [ - 'TransformerVisionConfig', - 'TransformerForImageClassification' -] diff --git a/fla/vision_models/transformer/configuration_transformer.py b/fla/vision_models/transformer/configuration_transformer.py deleted file mode 100644 index cc8246270..000000000 --- a/fla/vision_models/transformer/configuration_transformer.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import Optional - -from transformers.configuration_utils import PretrainedConfig - - -class TransformerVisionConfig(PretrainedConfig): - - model_type = 'transformer_vision' - - def __init__( - self, - # Transformer core parameters - hidden_size: int = 2048, - num_hidden_layers: int = 24, - num_heads: int = 32, - num_kv_heads: int = None, - window_size: Optional[int] = None, - rope_theta: Optional[float] = 10000., - max_position_embeddings: int = 2048, - hidden_act: 
str = "swish", - initializer_range: float = 0.02, - elementwise_affine: Optional[bool] = True, - norm_first: bool = False, - norm_eps: float = 1e-6, - use_cache: bool = True, - attention_bias: bool = False, - fuse_norm: bool = True, - fuse_cross_entropy: bool = True, - # Vision specific parameters - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - num_classes: int = 1000, - qkv_bias: bool = True, - hidden_dropout_prob: float = 0.0, - use_mask_token: bool = False, - layer_norm_eps: float = 1e-6, - interpolate_pos_encoding: bool = False, - mlp_dim: int = None, - # FLA-for-vision-related parameters - scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" - **kwargs - ): - # Initialize transformer core parameters - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.window_size = window_size - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - self.hidden_act = hidden_act - - self.initializer_range = initializer_range - self.elementwise_affine = elementwise_affine - self.norm_first = norm_first - self.norm_eps = norm_eps - self.use_cache = use_cache - self.attention_bias = attention_bias - self.fuse_cross_entropy = fuse_cross_entropy - self.fuse_norm = fuse_norm - - # Initialize vision specific parameters - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_classes = num_classes - self.qkv_bias = qkv_bias - self.hidden_dropout_prob = hidden_dropout_prob - self.use_mask_token = use_mask_token - self.layer_norm_eps = layer_norm_eps - self.interpolate_pos_encoding = interpolate_pos_encoding - self.scan_type = scan_type - - if mlp_dim is None: - self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size - else: - self.mlp_dim = mlp_dim - - super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/vision_models/transformer/modeling_transformer.py b/fla/vision_models/transformer/modeling_transformer.py deleted file mode 100644 index 6441293a5..000000000 --- a/fla/vision_models/transformer/modeling_transformer.py +++ /dev/null @@ -1,190 +0,0 @@ -import collections.abc -import math -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import Optional, Set, Tuple, Union, List, Dict, Unpack -from transformers.utils import logging -from fla.layers.attn import Attention -from transformers.modeling_outputs import ImageClassifierOutput -from transformers.modeling_utils import PreTrainedModel -from .configuration_transformer import TransformerVisionConfig -from fla.models.utils import Cache -from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge - -logger = logging.get_logger(__name__) - -class TransformerMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), - nn.Linear(config.hidden_size, config.mlp_dim), - nn.GELU(), - nn.Linear(config.mlp_dim, config.hidden_size), - nn.Dropout(config.hidden_dropout_prob) - ) - - def forward(self, x): - return self.net(x) - -class TransformerBlock(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - - if not config.norm_first: - self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - 
- self.attn = Attention( - hidden_size=config.hidden_size, - num_heads=config.num_heads, - num_kv_heads=config.num_kv_heads, - window_size=config.window_size, - rope_theta=config.rope_theta, - max_position_embeddings=config.max_position_embeddings, - norm_first=config.norm_first, - norm_eps=config.norm_eps, - layer_idx=layer_idx - ) - - - if not config.norm_first: - self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.mlp = TransformerMLP(config) - - self.scan_type = config.scan_type - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: - residual = hidden_states - - # Pre-normalization if enabled - if hasattr(self, 'ln_1'): - hidden_states = self.ln_1(hidden_states) - - # Apply attention - - hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) - - hidden_states, attentions, past_key_values = self.attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) - - # First residual connection - hidden_states = residual + hidden_states - residual = hidden_states - - # Pre-normalization for MLP if enabled - if hasattr(self, 'ln_2'): - hidden_states = self.ln_2(hidden_states) - - # MLP - hidden_states = self.mlp(hidden_states) - - # Second residual connection - hidden_states = residual + hidden_states - - outputs = (hidden_states, attentions, past_key_values) - - return outputs - -class TransformerVisionPreTrainedModel(PreTrainedModel): - # this part of the code is adapted from huggingface/transformers vit implementation - config_class = TransformerVisionConfig - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ImageEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - -class TransformerForImageClassification(TransformerVisionPreTrainedModel): - config_class = TransformerVisionConfig - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_classes - - self.embeddings = ImageEmbeddings(config) - self.blocks = nn.ModuleList([ - TransformerBlock(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = Pooler(config) - self.classifier = nn.Linear(config.hidden_size, config.num_classes) - self.interpolate_pos_encoding = config.interpolate_pos_encoding - self.init_weights() - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - 
output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[Dict] - ) -> Union[Tuple, ImageClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=self.interpolate_pos_encoding) - - for block in self.blocks: - hidden_states, attentions, past_key_values = block( - hidden_states, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs - ) - - hidden_states = self.norm(hidden_states) - pooled_output = self.pooler(hidden_states) - - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.num_labels == 1: - loss_fct = MSELoss() - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + (hidden_states,) - return ((loss,) + output) if loss is not None else output - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=hidden_states, - ) diff --git a/fla/vision_models/utils.py b/fla/vision_models/utils.py deleted file mode 100644 index 246dcf931..000000000 --- a/fla/vision_models/utils.py +++ /dev/null @@ -1,480 +0,0 @@ -""" -Vision model utilities adapted from huggingface/transformers ViT implementation. -""" - -import collections.abc -import torch -from torch import nn -from typing import Optional -from transformers.utils import torch_int -import triton -import triton.language as tl -import einops -import math - -""" -Basic component of a vision model, like the patch embeddings, image embeddings, and pooler. -""" - -class PatchEmbeddings(nn.Module): - """ - Convert image into patch embeddings. - Adapted from huggingface/transformers ViT implementation. - """ - def __init__(self, config): - super().__init__() - image_size, patch_size = config.image_size, config.patch_size - num_channels, hidden_size = config.num_channels, config.hidden_size - - image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) - patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_patches = num_patches - - self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) - - def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: - batch_size, num_channels, height, width = pixel_values.shape - if num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - f" Expected {self.num_channels} but got {num_channels}." - ) - if not interpolate_pos_encoding: - if height != self.image_size[0] or width != self.image_size[1]: - raise ValueError( - f"Input image size ({height}*{width}) doesn't match model" - f" ({self.image_size[0]}*{self.image_size[1]})." - ) - embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) - return embeddings - -class ImageEmbeddings(nn.Module): - """ - Construct the position and patch embeddings. 
- Adapted from huggingface/transformers ViT implementation. No cls token is used in this implementation. - """ - def __init__(self, config, use_mask_token: bool = False) -> None: - super().__init__() - - self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None - self.patch_embeddings = PatchEmbeddings(config) - num_patches = self.patch_embeddings.num_patches - self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size)) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.patch_size = config.patch_size - self.config = config - - def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: - """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution - images. This method is also adapted to support torch.jit tracing. - - Adapted from: - - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and - - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 - """ - - num_patches = embeddings.shape[1] - num_positions = self.position_embeddings.shape[1] - - if not torch.jit.is_tracing() and num_patches == num_positions and height == width: - return self.position_embeddings - - dim = embeddings.shape[-1] - - new_height = height // self.patch_size - new_width = width // self.patch_size - - sqrt_num_positions = torch_int(num_positions**0.5) - pos_embed = self.position_embeddings.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) - - pos_embed = pos_embed.permute(0, 3, 1, 2) - - pos_embed = nn.functional.interpolate( - pos_embed, - size=(new_height, new_width), - mode="bicubic", - align_corners=False, - ) - - pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - - return pos_embed - - def forward( - self, - pixel_values: torch.Tensor, - bool_masked_pos: Optional[torch.BoolTensor] = None, - interpolate_pos_encoding: bool = False, - ) -> torch.Tensor: - batch_size, num_channels, height, width = pixel_values.shape - embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - - if bool_masked_pos is not None: - seq_length = embeddings.shape[1] - mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) - # replace the masked visual tokens by mask_tokens - mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) - embeddings = embeddings * (1.0 - mask) + mask_tokens * mask - - # add positional encoding to each token - if interpolate_pos_encoding: - embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) - else: - embeddings = embeddings + self.position_embeddings - - embeddings = self.dropout(embeddings) - - return embeddings - -class Pooler(nn.Module): - """ - Pool the output of a vision model by taking the mean of all tokens. - Adapted from huggingface/transformers ViT implementation. - """ - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - pooled_output = hidden_states.mean(dim=1) # always use mean pooling - pooled_output = self.dense(pooled_output) - pooled_output = self.activation(pooled_output) - return pooled_output - -""" -Cross Scan and Cross Merge implemented in Triton (only). 
taken from https://github.com/MzeroMiko/VMamba/blob/main/classification/models/csm_triton.py -""" - -@triton.jit -def triton_cross_scan_flex( - x: tl.tensor, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) - y: tl.tensor, # (B, 4, C, H, W) | (B, H, W, 4, C) - x_layout: tl.constexpr, - y_layout: tl.constexpr, - operation: tl.constexpr, - onebyone: tl.constexpr, - scans: tl.constexpr, - BC: tl.constexpr, - BH: tl.constexpr, - BW: tl.constexpr, - DC: tl.constexpr, - DH: tl.constexpr, - DW: tl.constexpr, - NH: tl.constexpr, - NW: tl.constexpr, -): - # x_layout = 0 - # y_layout = 1 # 0 BCHW, 1 BHWC - # operation = 0 # 0 scan, 1 merge - # onebyone = 0 # 0 false, 1 true - # scans = 0 # 0 cross scan, 1 unidirectional, 2 bidirectional - - i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_h, i_w = (i_hw // NW), (i_hw % NW) - _mask_h = (i_h * BH + tl.arange(0, BH)) < DH - _mask_w = (i_w * BW + tl.arange(0, BW)) < DW - _mask_hw = _mask_h[:, None] & _mask_w[None, :] - _for_C = min(DC - i_c * BC, BC) - - pos_h = (i_h * BH + tl.arange(0, BH)[:, None]) - pos_w = (i_w * BW + tl.arange(0, BW)[None, :]) - neg_h = (DH - i_h * BH - 1 - tl.arange(0, BH)[:, None]) - neg_w = (DW - i_w * BW - 1 - tl.arange(0, BW)[None, :]) - if scans == 0: - # none; trans; flip; trans + flip; - HWRoute0 = pos_h * DW + pos_w - HWRoute1 = pos_w * DH + pos_h # trans - HWRoute2 = neg_h * DW + neg_w # flip - HWRoute3 = neg_w * DH + neg_h # trans + flip - elif scans == 1: - # none; none; none; none; - HWRoute0 = pos_h * DW + pos_w - HWRoute1 = HWRoute0 - HWRoute2 = HWRoute0 - HWRoute3 = HWRoute0 - elif scans == 2: - # none; none; flip; flip; - HWRoute0 = pos_h * DW + pos_w - HWRoute1 = HWRoute0 - HWRoute2 = neg_h * DW + neg_w # flip - HWRoute3 = HWRoute2 - - _tmp1 = DC * DH * DW - - y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC) - if y_layout == 0: - p_y1 = y_ptr_base + HWRoute0 - p_y2 = y_ptr_base + _tmp1 + HWRoute1 - p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2 - p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3 - else: - p_y1 = y_ptr_base + HWRoute0 * 4 * DC - p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC - p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC - p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC - - if onebyone == 0: - x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) - if x_layout == 0: - p_x = x_ptr_base + HWRoute0 - else: - p_x = x_ptr_base + HWRoute0 * DC - - if operation == 0: - for idxc in range(_for_C): - _idx_x = idxc * DH * DW if x_layout == 0 else idxc - _idx_y = idxc * DH * DW if y_layout == 0 else idxc - _x = tl.load(p_x + _idx_x, mask=_mask_hw) - tl.store(p_y1 + _idx_y, _x, mask=_mask_hw) - tl.store(p_y2 + _idx_y, _x, mask=_mask_hw) - tl.store(p_y3 + _idx_y, _x, mask=_mask_hw) - tl.store(p_y4 + _idx_y, _x, mask=_mask_hw) - elif operation == 1: - for idxc in range(_for_C): - _idx_x = idxc * DH * DW if x_layout == 0 else idxc - _idx_y = idxc * DH * DW if y_layout == 0 else idxc - _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw) - _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw) - _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw) - _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw) - tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw) - - else: - x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) - if x_layout == 0: - p_x1 = x_ptr_base + HWRoute0 - p_x2 = p_x1 + _tmp1 - p_x3 = p_x2 + _tmp1 - p_x4 = p_x3 + _tmp1 - else: - p_x1 = x_ptr_base + HWRoute0 * 4 * DC - p_x2 = p_x1 + DC - p_x3 = p_x2 + DC - 
p_x4 = p_x3 + DC - - if operation == 0: - for idxc in range(_for_C): - _idx_x = idxc * DH * DW if x_layout == 0 else idxc - _idx_y = idxc * DH * DW if y_layout == 0 else idxc - tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw) - tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw) - tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw) - tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw) - else: - for idxc in range(_for_C): - _idx_x = idxc * DH * DW if x_layout == 0 else idxc - _idx_y = idxc * DH * DW if y_layout == 0 else idxc - tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw) - tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw) - tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw) - tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw) - - -class CrossScanTritonF(torch.autograd.Function): - @staticmethod - def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): - if one_by_one: - if in_channel_first: - B, _, C, H, W = x.shape - else: - B, H, W, _, C = x.shape - else: - if in_channel_first: - B, C, H, W = x.shape - else: - B, H, W, C = x.shape - B, C, H, W = int(B), int(C), int(H), int(W) - BC, BH, BW = 1, 32, 32 - NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) - - ctx.in_channel_first = in_channel_first - ctx.out_channel_first = out_channel_first - ctx.one_by_one = one_by_one - ctx.scans = scans - ctx.shape = (B, C, H, W) - ctx.triton_shape = (BC, BH, BW, NC, NH, NW) - - y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C)) - triton_cross_scan_flex[(NH * NW, NC, B)]( - x.contiguous(), y, - (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans, - BC, BH, BW, C, H, W, NH, NW - ) - return y - - @staticmethod - def backward(ctx, y: torch.Tensor): - in_channel_first = ctx.in_channel_first - out_channel_first = ctx.out_channel_first - one_by_one = ctx.one_by_one - scans = ctx.scans - B, C, H, W = ctx.shape - BC, BH, BW, NC, NH, NW = ctx.triton_shape - if one_by_one: - x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C)) - else: - x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C)) - - triton_cross_scan_flex[(NH * NW, NC, B)]( - x, y.contiguous(), - (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans, - BC, BH, BW, C, H, W, NH, NW - ) - return x, None, None, None, None - - -class CrossMergeTritonF(torch.autograd.Function): - @staticmethod - def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): - if out_channel_first: - B, _, C, H, W = y.shape - else: - B, H, W, _, C = y.shape - B, C, H, W = int(B), int(C), int(H), int(W) - BC, BH, BW = 1, 32, 32 - NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) - ctx.in_channel_first = in_channel_first - ctx.out_channel_first = out_channel_first - ctx.one_by_one = one_by_one - ctx.scans = scans - ctx.shape = (B, C, H, W) - ctx.triton_shape = (BC, BH, BW, NC, NH, NW) - if one_by_one: - x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C)) - else: - x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C)) - triton_cross_scan_flex[(NH * NW, NC, B)]( - x, y.contiguous(), - (0 if in_channel_first else 1), (0 
if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans, - BC, BH, BW, C, H, W, NH, NW - ) - return x - - @staticmethod - def backward(ctx, x: torch.Tensor): - in_channel_first = ctx.in_channel_first - out_channel_first = ctx.out_channel_first - one_by_one = ctx.one_by_one - scans = ctx.scans - B, C, H, W = ctx.shape - BC, BH, BW, NC, NH, NW = ctx.triton_shape - y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C)) - triton_cross_scan_flex[(NH * NW, NC, B)]( - x.contiguous(), y, - (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans, - BC, BH, BW, C, H, W, NH, NW - ) - return y, None, None, None, None, None - - -# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) -def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False): - # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) - # y: (B, 4, C, L) | (B, L, 4, C) - # scans: 0: cross scan; 1 unidirectional; 2: bidirectional; - assert x.is_cuda - CSF = CrossScanTritonF - with torch.cuda.device(x.device): - return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans) - - -# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) -def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False): - # y: (B, 4, C, L) | (B, L, 4, C) - # x: (B, C, H * W) | (B, H * W, C) | (B, 4, C, H * W) | (B, H * W, 4, C) - # scans: 0: cross scan; 1 unidirectional; 2: bidirectional; - assert y.is_cuda - CMF = CrossMergeTritonF - with torch.cuda.device(y.device): - return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans) - -def prepare_hidden_states_for_cross_scan(hidden_states: torch.Tensor, scan_type: str = "uni-scan"): - # hidden_states shape should be: (B, L, D) - if scan_type == "uni-scan": - # in this case, nothing need to be done - return hidden_states - elif scan_type == "bi-scan": - flipped_hidden_states = hidden_states.flip(-2) - hidden_states = torch.cat([hidden_states, flipped_hidden_states], dim=0) # (B, L, D) -> (2B, L, D) - return hidden_states - - # apply cross scan to the sequence - B, L, D = hidden_states.shape - hw = int(math.sqrt(L)) - assert (hw * hw == L) # make sure L is a square - hidden_states = einops.rearrange(hidden_states, "b (h w) d -> b h w d", h=hw, w=hw) # change the shape to feed to cross_scan - hidden_states = cross_scan_fn(hidden_states, in_channel_first=False, out_channel_first=False, one_by_one=False, scans=0) - hidden_states = einops.rearrange(hidden_states, "b l k d -> (b k) l d") - return hidden_states - -def prepare_hidden_states_for_cross_merge(hidden_states: torch.Tensor, scan_type: str = "uni-scan"): - # hidden_states shape should be: (BK, L, D), K=2 for bi-scan, K=1 for uni-scan, K=4 for cross-scan - if scan_type == "uni-scan": - # in this case, nothing need to be done - return hidden_states - elif scan_type == "bi-scan": - # merge the two sequences - B = hidden_states.shape[0] // 2 - hidden_states = hidden_states[:B] + hidden_states[B:] - return hidden_states - - B, L, D = hidden_states.shape - hw = int(math.sqrt(L)) - hidden_states = einops.rearrange(hidden_states, "(b k) (h w) d -> b h w k d", k=4, h=hw, w=hw) - # apply cross merge to the sequence - hidden_states = cross_merge_fn(hidden_states, in_channel_first=False, out_channel_first=False, one_by_one=False, scans=0) - return hidden_states - -# check the 
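# (Editor's sketch, not part of the patch) A pure-PyTorch reference for the scans == 0
# ("cross-scan") path of the Triton kernel above, assuming channel-last layouts
# (in_channel_first=False, out_channel_first=False, one_by_one=False); handy for
# sanity-checking cross_scan_fn / cross_merge_fn on CPU. Names are illustrative only.
import torch

def cross_scan_reference(x: torch.Tensor) -> torch.Tensor:
    # x: (B, H, W, C) -> (B, H*W, 4, C): row-major, column-major ("trans"),
    # and both of them reversed ("flip"), matching HWRoute0..3 above.
    B, H, W, C = x.shape
    row = x.reshape(B, H * W, C)
    col = x.transpose(1, 2).reshape(B, H * W, C)
    return torch.stack([row, col, row.flip(1), col.flip(1)], dim=2)

def cross_merge_reference(y: torch.Tensor, H: int, W: int) -> torch.Tensor:
    # y: (B, H*W, 4, C) -> (B, H*W, C): undo each traversal and sum, as in operation == 1.
    B, L, _, C = y.shape
    row, col, rrow, rcol = y.unbind(dim=2)
    def to_row(t): return t.reshape(B, W, H, C).transpose(1, 2).reshape(B, L, C)
    return row + to_row(col) + rrow.flip(1) + to_row(rcol.flip(1))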
implementation -if __name__ == "__main__": - B, L, D = 1, 4, 3 - transformation = nn.Linear(D, D).cuda() - # firstly test bi-scan - print("Checking bi-scan") - h1 = torch.randn(B, L, D).cuda() - h2 = h1.clone().cuda() - h1 = prepare_hidden_states_for_cross_scan(h1, scan_type="bi-scan") - h1 = transformation(h1) - h1 = prepare_hidden_states_for_cross_merge(h1, scan_type="bi-scan") - h2_ = h2.clone().cuda() - h2_ = h2_.flip(-2) - h2 = transformation(h2) - h2_ = transformation(h2_) - h2 = h2 + h2_ - # check whether the two sequences are the same - print(f"h1: \n{h1}") - print(f"h2: \n{h2}") - print(f"""The two sequences are the same: {torch.allclose(h1, h2)}""") - # Then check cross-scan - print("checking cross-scan") - h1 = torch.randn(B, L, D).cuda() - h2 = h1.clone().cuda() - h1 = prepare_hidden_states_for_cross_scan(h1, scan_type="cross-scan") - h1 = transformation(h1) - h1 = prepare_hidden_states_for_cross_merge(h1, scan_type="cross-scan") - B, L, D = h2.shape - hw = int(math.sqrt(L)) - assert (hw * hw == L) # make sure L is a square - h2 = einops.rearrange(h2, "b (h w) d -> b h w d", h=hw, w=hw) # change the shape to feed to cross_scan - h2 = cross_scan_fn(h2, in_channel_first=False, out_channel_first=False, one_by_one=False, scans=0) - h2 = h2.permute(2, 0, 1, 3) - h2_0 = h2[0] - h2_1 = h2[1] - h2_2 = h2[2] - h2_3 = h2[3] - h2_0 = transformation(h2_0) - h2_1 = transformation(h2_1) - h2_2 = transformation(h2_2) - h2_3 = transformation(h2_3) - h2 = torch.cat([h2_0, h2_1, h2_2, h2_3], dim=0) - h2 = prepare_hidden_states_for_cross_merge(h2, scan_type="cross-scan") - # check whether the two sequences are the same - print(f"h1: \n{h1}") - print(f"h2: \n{h2}") - print(f"""The two sequences are the same: {torch.allclose(h1, h2)}""") \ No newline at end of file From 6deb624f949f61634d4a08042b2ef9be57a18887 Mon Sep 17 00:00:00 2001 From: yibozhong Date: Sun, 19 Jan 2025 17:23:30 +0800 Subject: [PATCH 16/17] migrate vision models to fla/models --- fla/models/__init__.py | 28 +- fla/models/abc/__init__.py | 12 +- fla/models/abc/configuration_abc.py | 95 ++++ fla/models/abc/modeling_abc.py | 377 +++++++++++++++- fla/models/bitnet/__init__.py | 12 +- fla/models/bitnet/configuration_bitnet.py | 95 +++- fla/models/bitnet/modeling_bitnet.py | 375 +++++++++++++++- fla/models/delta_net/__init__.py | 15 +- .../delta_net/configuration_delta_net.py | 99 ++++ fla/models/delta_net/modeling_delta_net.py | 373 ++++++++++++++- fla/models/gated_deltanet/__init__.py | 12 +- .../configuration_gated_deltanet.py | 91 +++- .../gated_deltanet/modeling_gated_deltanet.py | 375 +++++++++++++++- fla/models/gla/__init__.py | 12 +- fla/models/gla/configuration_gla.py | 99 ++++ fla/models/gla/modeling_gla.py | 378 +++++++++++++++- fla/models/gsa/__init__.py | 13 +- fla/models/gsa/configuration_gsa.py | 104 +++++ fla/models/gsa/modeling_gsa.py | 381 +++++++++++++++- fla/models/hgrn/__init__.py | 12 +- fla/models/hgrn/configuration_hgrn.py | 84 ++++ fla/models/hgrn/modeling_hgrn.py | 369 ++++++++++++++- fla/models/hgrn2/__init__.py | 12 +- fla/models/hgrn2/configuration_hgrn2.py | 86 ++++ fla/models/hgrn2/modeling_hgrn2.py | 370 ++++++++++++++- fla/models/linear_attn/__init__.py | 12 +- .../linear_attn/configuration_linear_attn.py | 93 ++++ .../linear_attn/modeling_linear_attn.py | 376 +++++++++++++++- fla/models/retnet/__init__.py | 12 +- fla/models/retnet/configuration_retnet.py | 96 ++++ fla/models/retnet/modeling_retnet.py | 374 +++++++++++++++- fla/models/rwkv6/__init__.py | 12 +- fla/models/rwkv6/configuration_rwkv6.py | 
91 ++++ fla/models/rwkv6/modeling_rwkv6.py | 373 ++++++++++++++- fla/models/transformer/__init__.py | 12 +- .../transformer/configuration_transformer.py | 78 ++++ .../transformer/modeling_transformer.py | 364 ++++++++++++++- fla/models/utils.py | 423 +++++++++++++++++- 38 files changed, 6104 insertions(+), 91 deletions(-) diff --git a/fla/models/__init__.py b/fla/models/__init__.py index 497453723..2271a6693 100644 --- a/fla/models/__init__.py +++ b/fla/models/__init__.py @@ -19,6 +19,20 @@ from fla.models.transformer import (TransformerConfig, TransformerForCausalLM, TransformerModel) from fla.models.gated_deltanet import GatedDeltaNetConfig, GatedDeltaNetForCausalLM, GatedDeltaNetModel + +from fla.models.abc import ABCVisionConfig, ABCForImageClassification, ABCForMaskedImageModeling, ABCVisionModel +from fla.models.bitnet import BitNetVisionConfig, BitNetForImageClassification, BitNetForMaskedImageModeling, BitNetVisionModel +from fla.models.delta_net import DeltaNetVisionConfig, DeltaNetForImageClassification, DeltaNetForMaskedImageModeling, DeltaNetVisionModel +from fla.models.gated_deltanet import GatedDeltaNetVisionConfig, GatedDeltaNetForImageClassification, GatedDeltaNetVisionModel, GatedDeltaNetForMaskedImageModeling +from fla.models.gla import GLAVisionConfig, GLAForImageClassification, GLAForMaskedImageModeling, GLAVisionModel +from fla.models.gsa import GSAVisionConfig, GSAForImageClassification, GSAForMaskedImageModeling, GSAVisionModel +from fla.models.hgrn import HGRNVisionConfig, HGRNForImageClassification, HGRNForMaskedImageModeling, HGRNVisionModel +from fla.models.hgrn2 import HGRN2VisionConfig, HGRN2ForImageClassification, HGRN2ForMaskedImageModeling, HGRN2VisionModel +from fla.models.linear_attn import LinearAttentionVisionConfig, LinearAttentionForImageClassification, LinearAttentionForMaskedImageModeling, LinearAttentionVisionModel +from fla.models.retnet import RetNetVisionConfig, RetNetForImageClassification, RetNetForMaskedImageModeling, RetNetVisionModel +from fla.models.rwkv6 import RWKV6VisionConfig, RWKV6ForImageClassification, RWKV6ForMaskedImageModeling, RWKV6VisionModel +from fla.models.transformer import TransformerVisionConfig, TransformerForImageClassification, TransformerForMaskedImageModeling, TransformerVisionModel + __all__ = [ 'ABCConfig', 'ABCForCausalLM', 'ABCModel', 'BitNetConfig', 'BitNetForCausalLM', 'BitNetModel', @@ -34,5 +48,17 @@ 'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model', 'SambaConfig', 'SambaForCausalLM', 'SambaModel', 'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel', - 'GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel' + 'GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel', + 'ABCVisionConfig', 'ABCForImageClassification', 'ABCForMaskedImageModeling', 'ABCVisionModel', + 'BitNetVisionConfig', 'BitNetForImageClassification', 'BitNetForMaskedImageModeling', 'BitNetVisionModel', + 'DeltaNetVisionConfig', 'DeltaNetForImageClassification', 'DeltaNetForMaskedImageModeling', 'DeltaNetVisionModel', + 'GatedDeltaNetVisionConfig', 'GatedDeltaNetForImageClassification', 'GatedDeltaNetVisionModel', 'GatedDeltaNetForMaskedImageModeling', + 'GLAVisionConfig', 'GLAForImageClassification', 'GLAForMaskedImageModeling', 'GLAVisionModel', + 'GSAVisionConfig', 'GSAForImageClassification', 'GSAForMaskedImageModeling', 'GSAVisionModel', + 'HGRNVisionConfig', 'HGRNForImageClassification', 'HGRNForMaskedImageModeling', 'HGRNVisionModel', + 'HGRN2VisionConfig', 'HGRN2ForImageClassification', 
'HGRN2ForMaskedImageModeling', 'HGRN2VisionModel', + 'LinearAttentionVisionConfig', 'LinearAttentionForImageClassification', 'LinearAttentionForMaskedImageModeling', 'LinearAttentionVisionModel', + 'RetNetVisionConfig', 'RetNetForImageClassification', 'RetNetForMaskedImageModeling', 'RetNetVisionModel', + 'RWKV6VisionConfig', 'RWKV6ForImageClassification', 'RWKV6ForMaskedImageModeling', 'RWKV6VisionModel', + 'TransformerVisionConfig', 'TransformerForImageClassification', 'TransformerForMaskedImageModeling', 'TransformerVisionModel', ] diff --git a/fla/models/abc/__init__.py b/fla/models/abc/__init__.py index f7021f22f..e7ca4a6b1 100644 --- a/fla/models/abc/__init__.py +++ b/fla/models/abc/__init__.py @@ -1,13 +1,17 @@ # -*- coding: utf-8 -*- -from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForImageClassification, AutoModelForMaskedImageModeling -from fla.models.abc.configuration_abc import ABCConfig -from fla.models.abc.modeling_abc import ABCForCausalLM, ABCModel +from fla.models.abc.configuration_abc import ABCConfig, ABCVisionConfig +from fla.models.abc.modeling_abc import ABCForCausalLM, ABCModel, ABCVisionModel, ABCForImageClassification, ABCForMaskedImageModeling AutoConfig.register(ABCConfig.model_type, ABCConfig) +AutoConfig.register(ABCVisionConfig.model_type, ABCVisionConfig) AutoModel.register(ABCConfig, ABCModel) AutoModelForCausalLM.register(ABCConfig, ABCForCausalLM) +AutoModelForImageClassification.register(ABCVisionConfig, ABCForImageClassification) +AutoModelForMaskedImageModeling.register(ABCVisionConfig, ABCForMaskedImageModeling) +AutoModel.register(ABCVisionConfig, ABCVisionModel) -__all__ = ['ABCConfig', 'ABCForCausalLM', 'ABCModel'] +__all__ = ['ABCConfig', 'ABCForCausalLM', 'ABCModel', 'ABCVisionModel', 'ABCForImageClassification', 'ABCForMaskedImageModeling', 'ABCVisionConfig'] diff --git a/fla/models/abc/configuration_abc.py b/fla/models/abc/configuration_abc.py index bdb7f4e22..d08e811f3 100644 --- a/fla/models/abc/configuration_abc.py +++ b/fla/models/abc/configuration_abc.py @@ -82,3 +82,98 @@ def __init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) + + +class ABCVisionConfig(PretrainedConfig): + + model_type = 'abc_vision' + + def __init__( + self, + # ABC core parameters + hidden_size: int = 2048, + gate_low_rank_dim: int = 16, + clamp_min: float = -32, + clamp_max: float = 32, + num_hidden_layers: int = 24, + num_heads: int = 4, + num_slots: Optional[int] = 64, + use_short_conv: bool = False, + conv_size: int = 4, + exapnd_k: float = 0.5, + exapnd_v: float = 1, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + qkv_bias: bool = True, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + encoder_stride=16, + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + ): + # Initialize ABC core parameters + self.hidden_size = hidden_size + self.gate_low_rank_dim = gate_low_rank_dim + self.clamp_min = clamp_min + 
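# (Editor's sketch, not part of the patch) With the AutoConfig / AutoModel* registrations
# above, the new vision variants can be built through the standard transformers factories;
# a minimal, assumed usage (hyperparameters below are illustrative, not defaults):
from transformers import AutoModelForImageClassification
from fla.models import ABCVisionConfig

config = ABCVisionConfig(hidden_size=256, num_hidden_layers=4, num_heads=4, num_classes=10)
model = AutoModelForImageClassification.from_config(config)   # -> ABCForImageClassification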
self.clamp_max = clamp_max + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_slots = num_slots + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.expand_k = exapnd_k + self.expand_v = exapnd_v + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_cross_entropy = fuse_cross_entropy + + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + self.encoder_stride = encoder_stride + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + self.attn = attn + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/models/abc/modeling_abc.py b/fla/models/abc/modeling_abc.py index c9a0c9645..2902aa2c7 100644 --- a/fla/models/abc/modeling_abc.py +++ b/fla/models/abc/modeling_abc.py @@ -4,7 +4,7 @@ import math import warnings -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Unpack, Dict import torch import torch.nn as nn @@ -12,17 +12,22 @@ from transformers.activations import ACT2FN from transformers.generation import GenerationMixin from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutputWithPast) + CausalLMOutputWithPast, + ImageClassifierOutput, + MaskedImageModelingOutput, + BaseModelOutput, + BaseModelOutputWithPooling) from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from fla.layers.abc import ABCAttention from fla.layers.attn import Attention -from fla.models.abc.configuration_abc import ABCConfig +from fla.models.abc.configuration_abc import ABCConfig, ABCVisionConfig from fla.models.utils import Cache from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm) from fla.modules.activations import swiglu_linear +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge logger = logging.get_logger(__name__) @@ -403,3 +408,369 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +class ABCVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class 
ABCVisionBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = ABCAttention( + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_slots=config.num_slots, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + gate_fn=config.hidden_act, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + clamp_min=config.clamp_min, + clamp_max=config.clamp_max, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = ABCVisionMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class ABCVisionPreTrainedModel(PreTrainedModel): + config_class = ABCVisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + +class ABCVisionEncoder(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.blocks = nn.ModuleList([ + ABCVisionBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + 
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + return_dict: bool = True, + **kwargs + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, block in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_self_attentions = all_self_attentions + (attentions,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + +class ABCVisionModel(ABCVisionPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.embeddings = ImageEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = ABCVisionEncoder(config) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) if add_pooling_layer else None + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + interpolate_pos_encoding: Optional[bool] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs = self.encoder( + hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + return_dict=return_dict, + **kwargs + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + 
last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + +class ABCForImageClassification(ABCVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + self.backbone = ABCVisionModel(config, add_pooling_layer=True) # Here we should use mean pooling + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.backbone( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output + logits = self.classifier(pooled_output) # only use mean pooling + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = torch.nn.MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class ABCForMaskedImageModeling(ABCVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.backbone = ABCVisionModel(config, add_pooling_layer=False, use_mask_token=True) + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=config.hidden_size, + out_channels=config.encoder_stride**2 * config.num_channels, + kernel_size=1, + ), + nn.PixelShuffle(config.encoder_stride), + ) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedImageModelingOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): + raise ValueError( + "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " + "the reconstructed image has the same dimensions as the input. " + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." 
+ ) + + outputs = self.backbone( + pixel_values, + bool_masked_pos=bool_masked_pos, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + + sequence_output = outputs[0] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[1:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/fla/models/bitnet/__init__.py b/fla/models/bitnet/__init__.py index bede22c64..b6a5f33b2 100644 --- a/fla/models/bitnet/__init__.py +++ b/fla/models/bitnet/__init__.py @@ -1,13 +1,17 @@ # -*- coding: utf-8 -*- -from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForImageClassification, AutoModelForMaskedImageModeling -from fla.models.bitnet.configuration_bitnet import BitNetConfig -from fla.models.bitnet.modeling_bitnet import BitNetForCausalLM, BitNetModel +from fla.models.bitnet.configuration_bitnet import BitNetConfig, BitNetVisionConfig +from fla.models.bitnet.modeling_bitnet import BitNetForCausalLM, BitNetModel, BitNetVisionModel, BitNetForImageClassification, BitNetForMaskedImageModeling AutoConfig.register(BitNetConfig.model_type, BitNetConfig) +AutoConfig.register(BitNetVisionConfig.model_type, BitNetVisionConfig) AutoModel.register(BitNetConfig, BitNetModel) AutoModelForCausalLM.register(BitNetConfig, BitNetForCausalLM) +AutoModelForImageClassification.register(BitNetVisionConfig, BitNetForImageClassification) +AutoModelForMaskedImageModeling.register(BitNetVisionConfig, BitNetForMaskedImageModeling) +AutoModel.register(BitNetVisionConfig, BitNetVisionModel) -__all__ = ['BitNetConfig', 'BitNetForCausalLM', 'BitNetModel'] +__all__ = ['BitNetConfig', 'BitNetForCausalLM', 'BitNetModel', 'BitNetVisionConfig', 'BitNetForImageClassification', 'BitNetForMaskedImageModeling', 'BitNetVisionModel'] diff --git a/fla/models/bitnet/configuration_bitnet.py b/fla/models/bitnet/configuration_bitnet.py index b6c50f8aa..f99530a47 100644 --- a/fla/models/bitnet/configuration_bitnet.py +++ b/fla/models/bitnet/configuration_bitnet.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from typing import Optional +from typing import Optional, Dict from transformers.configuration_utils import PretrainedConfig @@ -66,3 +66,96 @@ def __init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) + + +class 
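# (Editor's sketch, not part of the patch) Shape check for the masked-image decoder above,
# assuming image_size=224 and patch_size=encoder_stride=16 with an illustrative hidden size;
# PixelShuffle turns the per-patch channels back into 16x16 pixel blocks.
import math
import torch
import torch.nn as nn

B, D, S, C = 2, 192, 16, 3                          # batch, hidden_size, encoder_stride, channels
seq = torch.randn(B, 196, D)                        # 14 x 14 patches of a 224 x 224 image
h = w = math.floor(seq.shape[1] ** 0.5)             # 14, as computed in the forward above
feat = seq.permute(0, 2, 1).reshape(B, D, h, w)     # (B, D, 14, 14)
decoder = nn.Sequential(nn.Conv2d(D, S ** 2 * C, kernel_size=1), nn.PixelShuffle(S))
assert decoder(feat).shape == (B, C, 224, 224)      # reconstructed_pixel_values

mask = torch.zeros(B, 196, dtype=torch.bool).reshape(-1, h, w)
mask = mask.repeat_interleave(S, 1).repeat_interleave(S, 2).unsqueeze(1)
assert mask.shape == (B, 1, 224, 224)               # pixel-level mask used by the L1 loss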
BitNetVisionConfig(PretrainedConfig): + + model_type = 'bitnet_vision' + + def __init__( + self, + # BitNet core parameters + hidden_size: int = 2048, + num_hidden_layers: int = 24, + num_heads: int = 32, + num_kv_heads: int = None, + window_size: Optional[int] = None, + rope_theta: Optional[float] = 10000., + max_position_embeddings: int = 2048, + hidden_act: str = "swish", + initializer_range: float = 0.02, + elementwise_affine: Optional[bool] = True, + norm_first: bool = False, + norm_eps: float = 1e-6, + use_cache: bool = True, + attention_bias: bool = False, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + attn: Optional[Dict] = None, + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + qkv_bias: bool = True, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + encoder_stride=16, + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + ): + # Initialize BitNet core parameters + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.window_size = window_size + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.hidden_act = hidden_act + + self.initializer_range = initializer_range + self.elementwise_affine = elementwise_affine + self.norm_first = norm_first + self.norm_eps = norm_eps + self.use_cache = use_cache + self.attention_bias = attention_bias + self.fuse_cross_entropy = fuse_cross_entropy + self.fuse_norm = fuse_norm + + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + self.encoder_stride = encoder_stride + + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + self.attn = attn + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/models/bitnet/modeling_bitnet.py b/fla/models/bitnet/modeling_bitnet.py index 27fbff889..10bee4501 100644 --- a/fla/models/bitnet/modeling_bitnet.py +++ b/fla/models/bitnet/modeling_bitnet.py @@ -4,7 +4,7 @@ import math import warnings -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Unpack, Dict import torch import torch.nn as nn @@ -12,17 +12,22 @@ from transformers.activations import ACT2FN from transformers.generation import GenerationMixin from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutputWithPast) + CausalLMOutputWithPast, + 
ImageClassifierOutput, + MaskedImageModelingOutput, + BaseModelOutput, + BaseModelOutputWithPooling) from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging - +from fla.layers.attn import Attention from fla.layers.bitattn import BitAttention -from fla.models.bitnet.configuration_bitnet import BitNetConfig +from fla.models.bitnet.configuration_bitnet import BitNetConfig, BitNetVisionConfig from fla.models.utils import Cache from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm) from fla.modules.activations import swiglu_bitlinear from fla.modules.fused_bitlinear import BitLinear, rms_norm_linear_quant +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge logger = logging.get_logger(__name__) @@ -428,3 +433,365 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + +class BitNetVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class BitNetVisionBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + if not config.norm_first: + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = BitAttention( + hidden_size=config.hidden_size, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + window_size=config.window_size, + rope_theta=config.rope_theta, + max_position_embeddings=config.max_position_embeddings, + norm_first=config.norm_first, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + + if not config.norm_first: + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = BitNetVisionMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + hidden_states = self.mlp(hidden_states) + + # Second residual connection + 
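# (Editor's sketch, not part of the patch) Effect of the scan hooks used in the block above
# on the batch dimension, for hidden_states of shape (B, L, D):
#   "uni-scan"   : (B,  L, D) -> token mixer -> (B, L, D)   (identity)
#   "bi-scan"    : (2B, L, D) -> token mixer -> (B, L, D)   (forward + reversed, summed)
#   "cross-scan" : (4B, L, D) -> token mixer -> (B, L, D)   (four traversal orders, merged)
# A minimal CPU check of the bi-scan path, assuming the helpers live in
# fla/models/utils.py as listed in the diffstat of this patch:
import torch
from fla.models.utils import (prepare_hidden_states_for_cross_scan,
                              prepare_hidden_states_for_cross_merge)

x = torch.randn(2, 196, 64)
y = prepare_hidden_states_for_cross_scan(x, scan_type="bi-scan")    # (4, 196, 64)
z = prepare_hidden_states_for_cross_merge(y, scan_type="bi-scan")   # (2, 196, 64)
assert z.shape == x.shape and torch.allclose(z, x + x.flip(-2))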
hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class BitNetVisionPreTrainedModel(PreTrainedModel): + config_class = BitNetVisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + +class BitNetVisionEncoder(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.blocks = nn.ModuleList([ + BitNetVisionBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + return_dict: bool = True, + **kwargs + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, block in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_self_attentions = all_self_attentions + (attentions,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + +class BitNetVisionModel(BitNetVisionPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.embeddings = ImageEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = BitNetVisionEncoder(config) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) if add_pooling_layer else None + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + interpolate_pos_encoding: Optional[bool] = None, + use_cache: Optional[bool] = None, + 
return_dict: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs = self.encoder( + hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + return_dict=return_dict, + **kwargs + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + +class BitNetForImageClassification(BitNetVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + self.backbone = BitNetVisionModel(config, add_pooling_layer=True) # Here we should use mean pooling + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.backbone( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output + logits = self.classifier(pooled_output) # only use mean pooling + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = torch.nn.MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class BitNetForMaskedImageModeling(BitNetVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.backbone = BitNetVisionModel(config, add_pooling_layer=False, use_mask_token=True) + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=config.hidden_size, + out_channels=config.encoder_stride**2 * config.num_channels, + kernel_size=1, + ), + nn.PixelShuffle(config.encoder_stride), + ) + self.init_weights() + + def forward( + self, + pixel_values: 
Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedImageModelingOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): + raise ValueError( + "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " + "the reconstructed image has the same dimensions as the input. " + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." + ) + + outputs = self.backbone( + pixel_values, + bool_masked_pos=bool_masked_pos, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + + sequence_output = outputs[0] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[1:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/fla/models/delta_net/__init__.py b/fla/models/delta_net/__init__.py index 258908922..13011c066 100644 --- a/fla/models/delta_net/__init__.py +++ b/fla/models/delta_net/__init__.py @@ -1,13 +1,20 @@ # -*- coding: utf-8 -*- -from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForImageClassification, AutoModelForMaskedImageModeling -from fla.models.delta_net.configuration_delta_net import DeltaNetConfig +from fla.models.delta_net.configuration_delta_net import DeltaNetConfig, DeltaNetVisionConfig from fla.models.delta_net.modeling_delta_net import (DeltaNetForCausalLM, - DeltaNetModel) + DeltaNetModel, + DeltaNetVisionModel, + DeltaNetForImageClassification, + DeltaNetForMaskedImageModeling) AutoConfig.register(DeltaNetConfig.model_type, DeltaNetConfig) +AutoConfig.register(DeltaNetVisionConfig.model_type, DeltaNetVisionConfig) AutoModel.register(DeltaNetConfig, DeltaNetModel) AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM) +AutoModel.register(DeltaNetVisionConfig, DeltaNetVisionModel) +AutoModelForImageClassification.register(DeltaNetVisionConfig, 
DeltaNetForImageClassification) +AutoModelForMaskedImageModeling.register(DeltaNetVisionConfig, DeltaNetForMaskedImageModeling) -__all__ = ['DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel'] +__all__ = ['DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel', 'DeltaNetVisionModel', 'DeltaNetForImageClassification', 'DeltaNetForMaskedImageModeling', 'DeltaNetVisionConfig'] diff --git a/fla/models/delta_net/configuration_delta_net.py b/fla/models/delta_net/configuration_delta_net.py index 45ba7b498..80773dbe0 100644 --- a/fla/models/delta_net/configuration_delta_net.py +++ b/fla/models/delta_net/configuration_delta_net.py @@ -85,3 +85,102 @@ def __init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) + + +class DeltaNetVisionConfig(PretrainedConfig): + model_type = 'delta_net_vision' + + def __init__( + self, + # DeltaNet core parameters + attn_mode: str = "chunk", + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 1, + use_gate: bool = False, + use_short_conv: bool = True, + conv_size: int = 4, + use_beta: bool = True, + use_output_norm: bool = True, + num_heads: int = 16, + qk_norm: str = 'l2', + qk_activation: str = 'silu', + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + num_hidden_layers: int = 12, + norm_first: bool = False, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + fuse_cross_entropy: bool = True, + max_position_embeddings: int = 2048, + + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + encoder_stride=16, + mlp_dim: int = None, + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + ): + # Initialize DeltaNet core parameters + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_beta = use_beta + self.use_output_norm = use_output_norm + self.num_heads = num_heads + self.qk_norm = qk_norm + self.qk_activation = qk_activation + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.norm_first = norm_first + self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + self.max_position_embeddings = max_position_embeddings + + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + self.encoder_stride = encoder_stride + + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) 
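# (Editor's sketch, not part of the patch) Example of the hybrid-attention spec that the
# validation above accepts: softmax attention on layers 0 and 6, DeltaNet elsewhere;
# 'num_kv_heads' and 'window_size' are filled in by the two lines around this note.
# All values are illustrative, not defaults.
config = DeltaNetVisionConfig(
    hidden_size=384,
    num_hidden_layers=12,
    num_heads=6,
    num_classes=100,
    scan_type="bi-scan",
    attn={"layers": [0, 6], "num_heads": 6},
)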
+ attn['window_size'] = attn.get('window_size', None) + + self.attn = attn + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/models/delta_net/modeling_delta_net.py b/fla/models/delta_net/modeling_delta_net.py index cd2fe5811..3653702dc 100644 --- a/fla/models/delta_net/modeling_delta_net.py +++ b/fla/models/delta_net/modeling_delta_net.py @@ -12,18 +12,23 @@ from transformers.activations import ACT2FN from transformers.generation import GenerationMixin from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutputWithPast) + CausalLMOutputWithPast, + ImageClassifierOutput, + MaskedImageModelingOutput, + BaseModelOutput, + BaseModelOutputWithPooling) from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from fla.layers.attn import Attention from fla.layers.delta_net import DeltaNet -from fla.models.delta_net.configuration_delta_net import DeltaNetConfig +from fla.models.delta_net.configuration_delta_net import DeltaNetConfig, DeltaNetVisionConfig from fla.models.utils import Cache from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm) from fla.modules.activations import swiglu_linear from fla.modules.layernorm import rms_norm_linear +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge logger = logging.get_logger(__name__) @@ -449,3 +454,367 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +class DeltaNetVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class DeltaNetVisionBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + if not config.norm_first: + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = DeltaNet( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + use_gate=config.use_gate, + use_beta=config.use_beta, + use_short_conv=config.use_short_conv, + use_output_norm=config.use_output_norm, + conv_size=config.conv_size, + qk_norm=config.qk_norm, + qk_activation=config.qk_activation, + norm_first=config.norm_first, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + + if not config.norm_first: + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = DeltaNetVisionMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, 
Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + hidden_states = residual + hidden_states + residual = hidden_states + + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class DeltaNetVisionPreTrainedModel(PreTrainedModel): + config_class = DeltaNetVisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + +class DeltaNetVisionEncoder(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.blocks = nn.ModuleList([ + DeltaNetVisionBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + return_dict: bool = True, + **kwargs + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, block in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_self_attentions = all_self_attentions + (attentions,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + +class DeltaNetVisionModel(DeltaNetVisionPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.embeddings 
= ImageEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = DeltaNetVisionEncoder(config) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) if add_pooling_layer else None + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + interpolate_pos_encoding: Optional[bool] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs = self.encoder( + hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + return_dict=return_dict, + **kwargs + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + +class DeltaNetForImageClassification(DeltaNetVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + self.backbone = DeltaNetVisionModel(config, add_pooling_layer=True) # Here we should use mean pooling + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.backbone( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = torch.nn.MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + 
outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class DeltaNetForMaskedImageModeling(DeltaNetVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.backbone = DeltaNetVisionModel(config, add_pooling_layer=False, use_mask_token=True) + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=config.hidden_size, + out_channels=config.encoder_stride**2 * config.num_channels, + kernel_size=1, + ), + nn.PixelShuffle(config.encoder_stride), + ) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedImageModelingOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): + raise ValueError( + "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " + "the reconstructed image has the same dimensions as the input. " + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." + ) + + outputs = self.backbone( + pixel_values, + bool_masked_pos=bool_masked_pos, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + + sequence_output = outputs[0] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[1:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/fla/models/gated_deltanet/__init__.py b/fla/models/gated_deltanet/__init__.py index 29fb5e2ea..0da7444bc 100644 --- a/fla/models/gated_deltanet/__init__.py +++ b/fla/models/gated_deltanet/__init__.py @@ -1,14 +1,18 @@ # -*- coding: utf-8 -*- -from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForImageClassification, AutoModelForMaskedImageModeling from 
fla.models.gated_deltanet.configuration_gated_deltanet import \ - GatedDeltaNetConfig + GatedDeltaNetConfig, GatedDeltaNetVisionConfig from fla.models.gated_deltanet.modeling_gated_deltanet import ( - GatedDeltaNetForCausalLM, GatedDeltaNetModel) + GatedDeltaNetForCausalLM, GatedDeltaNetModel, GatedDeltaNetVisionModel, GatedDeltaNetForImageClassification, GatedDeltaNetForMaskedImageModeling) AutoConfig.register(GatedDeltaNetConfig.model_type, GatedDeltaNetConfig) +AutoConfig.register(GatedDeltaNetVisionConfig.model_type, GatedDeltaNetVisionConfig) AutoModel.register(GatedDeltaNetConfig, GatedDeltaNetModel) AutoModelForCausalLM.register(GatedDeltaNetConfig, GatedDeltaNetForCausalLM) +AutoModelForImageClassification.register(GatedDeltaNetVisionConfig, GatedDeltaNetForImageClassification) +AutoModelForMaskedImageModeling.register(GatedDeltaNetVisionConfig, GatedDeltaNetForMaskedImageModeling) +AutoModel.register(GatedDeltaNetVisionConfig, GatedDeltaNetVisionModel) -__all__ = ['GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel'] +__all__ = ['GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel', 'GatedDeltaNetVisionModel', 'GatedDeltaNetForImageClassification', 'GatedDeltaNetForMaskedImageModeling', 'GatedDeltaNetVisionConfig'] diff --git a/fla/models/gated_deltanet/configuration_gated_deltanet.py b/fla/models/gated_deltanet/configuration_gated_deltanet.py index 65a1418e5..e16e20c23 100644 --- a/fla/models/gated_deltanet/configuration_gated_deltanet.py +++ b/fla/models/gated_deltanet/configuration_gated_deltanet.py @@ -74,4 +74,93 @@ def __init__( eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, - ) \ No newline at end of file + ) + + +class GatedDeltaNetVisionConfig(PretrainedConfig): + model_type = 'gated_deltanet_vision' + + def __init__( + self, + # GatedDeltaNet core parameters + attn_mode: str = "chunk", + hidden_size: int = 2048, + expand_v: int = 2, + use_gate: bool = True, + use_short_conv: bool = True, + conv_size: int = 4, + head_dim: int = 256, + num_heads: int = 6, + max_position_embeddings: int = 2048, + hidden_act: str = "swish", + num_hidden_layers: int = 21, + norm_first: bool = False, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + fuse_cross_entropy: bool = True, + + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + encoder_stride=16, + scan_type: str = "uni-scan", + **kwargs + ): + # Initialize GatedDeltaNet core parameters + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_v = expand_v + self.head_dim = head_dim + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.num_heads = num_heads + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.norm_first = norm_first + self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + self.attn = attn + self.max_position_embeddings = max_position_embeddings + + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.hidden_dropout_prob = 
hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + self.encoder_stride = encoder_stride + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + self.attn = attn + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) diff --git a/fla/models/gated_deltanet/modeling_gated_deltanet.py b/fla/models/gated_deltanet/modeling_gated_deltanet.py index 1f3f507dd..da64716d6 100644 --- a/fla/models/gated_deltanet/modeling_gated_deltanet.py +++ b/fla/models/gated_deltanet/modeling_gated_deltanet.py @@ -12,20 +12,24 @@ from transformers.activations import ACT2FN from transformers.generation import GenerationMixin from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutputWithPast) + CausalLMOutputWithPast, + ImageClassifierOutput, + MaskedImageModelingOutput, + BaseModelOutput, + BaseModelOutputWithPooling) from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from fla.layers.attn import Attention from fla.layers.gated_deltanet import GatedDeltaNet from fla.models.gated_deltanet.configuration_gated_deltanet import \ - GatedDeltaNetConfig + GatedDeltaNetConfig, GatedDeltaNetVisionConfig from fla.models.utils import Cache from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm) from fla.modules.activations import swiglu_linear from fla.modules.layernorm import rms_norm_linear - +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge if TYPE_CHECKING: from transformers.processing_utils import Unpack @@ -447,3 +451,368 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +class GatedDeltaNetVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class GatedDeltaNetVisionBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + if not config.norm_first: + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = GatedDeltaNet( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_v=config.expand_v, + head_dim=config.head_dim, + num_heads=config.num_heads, + use_gate=config.use_gate, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + norm_first=config.norm_first, + 
norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + + if not config.norm_first: + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = GatedDeltaNetVisionMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class GatedDeltaNetVisionPreTrainedModel(PreTrainedModel): + config_class = GatedDeltaNetVisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + +class GatedDeltaNetVisionEncoder(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.blocks = nn.ModuleList([ + GatedDeltaNetVisionBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + return_dict: bool = True, + **kwargs + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, block in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + 
output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_self_attentions = all_self_attentions + (attentions,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + +class GatedDeltaNetVisionModel(GatedDeltaNetVisionPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.embeddings = ImageEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = GatedDeltaNetVisionEncoder(config) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) if add_pooling_layer else None + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + interpolate_pos_encoding: Optional[bool] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs = self.encoder( + hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + return_dict=return_dict, + **kwargs + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + +class GatedDeltaNetForImageClassification(GatedDeltaNetVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + self.backbone = GatedDeltaNetVisionModel(config, add_pooling_layer=True) # Here we should use mean pooling + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, 
ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.backbone( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output + logits = self.classifier(pooled_output) # only use mean pooling + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = torch.nn.MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class GatedDeltaNetForMaskedImageModeling(GatedDeltaNetVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.backbone = GatedDeltaNetVisionModel(config, add_pooling_layer=False, use_mask_token=True) + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=config.hidden_size, + out_channels=config.encoder_stride**2 * config.num_channels, + kernel_size=1, + ), + nn.PixelShuffle(config.encoder_stride), + ) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedImageModelingOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): + raise ValueError( + "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " + "the reconstructed image has the same dimensions as the input. " + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." 
+ ) + + outputs = self.backbone( + pixel_values, + bool_masked_pos=bool_masked_pos, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + + sequence_output = outputs[0] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[1:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/fla/models/gla/__init__.py b/fla/models/gla/__init__.py index edccb515a..419941024 100644 --- a/fla/models/gla/__init__.py +++ b/fla/models/gla/__init__.py @@ -1,13 +1,17 @@ # -*- coding: utf-8 -*- -from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForImageClassification, AutoModelForMaskedImageModeling -from fla.models.gla.configuration_gla import GLAConfig -from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel +from fla.models.gla.configuration_gla import GLAConfig, GLAVisionConfig +from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel, GLAVisionModel, GLAForImageClassification, GLAForMaskedImageModeling AutoConfig.register(GLAConfig.model_type, GLAConfig) +AutoConfig.register(GLAVisionConfig.model_type, GLAVisionConfig) AutoModel.register(GLAConfig, GLAModel) AutoModelForCausalLM.register(GLAConfig, GLAForCausalLM) +AutoModelForImageClassification.register(GLAVisionConfig, GLAForImageClassification) +AutoModelForMaskedImageModeling.register(GLAVisionConfig, GLAForMaskedImageModeling) +AutoModel.register(GLAVisionConfig, GLAVisionModel) -__all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel'] +__all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel', 'GLAVisionModel', 'GLAForImageClassification', 'GLAForMaskedImageModeling', 'GLAVisionConfig'] diff --git a/fla/models/gla/configuration_gla.py b/fla/models/gla/configuration_gla.py index 7991112b2..b73084896 100644 --- a/fla/models/gla/configuration_gla.py +++ b/fla/models/gla/configuration_gla.py @@ -88,3 +88,102 @@ def __init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) + +class GLAVisionConfig(PretrainedConfig): + + model_type = 'gla_vision' + + def __init__( + self, + # GLA core parameters + hidden_size: int = 2048, + expand_k: int = 0.5, + expand_v: int = 1, + num_hidden_layers: int = 24, + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + feature_map: Optional[str] = None, + attn_mode: str = "chunk", + 
use_short_conv: bool = False,
+        conv_size: int = 4,
+        use_output_gate: bool = True,
+        clamp_min: Optional[float] = None,
+        hidden_act: str = "swish",
+        max_position_embeddings: int = 2048,
+        elementwise_affine: Optional[bool] = True,
+        norm_eps: float = 1e-6,
+        use_gk: bool = True,
+        use_gv: bool = False,
+        attn: Optional[Dict] = None,
+        use_cache: bool = True,
+        initializer_range: float = 0.02,
+        fuse_norm: bool = True,
+        fuse_cross_entropy: bool = True,
+        # Vision specific parameters
+        image_size: int = 224,
+        patch_size: int = 16,
+        num_channels: int = 3,
+        num_classes: int = 1000,
+        hidden_dropout_prob: float = 0.0,
+        use_mask_token: bool = False,
+        layer_norm_eps: float = 1e-6,
+        interpolate_pos_encoding: bool = False,
+        mlp_dim: int = None,
+        encoder_stride=16,
+        scan_type: str = "uni-scan",  # scanning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan"
+        **kwargs
+    ):
+        # Initialize GLA core parameters
+        self.hidden_size = hidden_size
+        self.expand_k = expand_k
+        self.expand_v = expand_v
+        self.num_hidden_layers = num_hidden_layers
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.feature_map = feature_map
+        self.attn_mode = attn_mode
+        self.use_short_conv = use_short_conv
+        self.conv_size = conv_size
+        self.use_output_gate = use_output_gate
+        self.clamp_min = clamp_min
+        self.hidden_act = hidden_act
+        self.max_position_embeddings = max_position_embeddings
+        self.elementwise_affine = elementwise_affine
+        self.norm_eps = norm_eps
+        self.use_gk = use_gk
+        self.use_gv = use_gv
+        self.use_cache = use_cache
+        self.initializer_range = initializer_range
+        self.fuse_norm = fuse_norm
+        self.fuse_cross_entropy = fuse_cross_entropy
+
+        # Initialize vision specific parameters
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.num_classes = num_classes
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.use_mask_token = use_mask_token
+        self.layer_norm_eps = layer_norm_eps
+        self.interpolate_pos_encoding = interpolate_pos_encoding
+        self.scan_type = scan_type
+        self.encoder_stride = encoder_stride
+
+        if attn is not None:
+            if not isinstance(attn, Dict):
+                raise ValueError("attn must be a dictionary")
+            if 'layers' not in attn:
+                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
+            if 'num_heads' not in attn:
+                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
+            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
+            attn['window_size'] = attn.get('window_size', None)
+
+        self.attn = attn
+
+        if mlp_dim is None:
+            self.mlp_dim = 4 * hidden_size
+        else:
+            self.mlp_dim = mlp_dim
+
+        super().__init__(**kwargs)
\ No newline at end of file
diff --git a/fla/models/gla/modeling_gla.py b/fla/models/gla/modeling_gla.py
index d4d357b87..4f6f7ca24 100644
--- a/fla/models/gla/modeling_gla.py
+++ b/fla/models/gla/modeling_gla.py
@@ -12,18 +12,22 @@ from transformers.activations import ACT2FN
 from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import (BaseModelOutputWithPast,
-                                           CausalLMOutputWithPast)
+                                           CausalLMOutputWithPast,
+                                           ImageClassifierOutput,
+                                           MaskedImageModelingOutput,
+                                           BaseModelOutput,
+                                           BaseModelOutputWithPooling)
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 from fla.layers.attn import Attention
 from fla.layers.gla import GatedLinearAttention
-from fla.models.gla.configuration_gla import GLAConfig
+from
fla.models.gla.configuration_gla import GLAConfig, GLAVisionConfig from fla.models.utils import Cache from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm) from fla.modules.activations import swiglu_linear - +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge if TYPE_CHECKING: from transformers.processing_utils import Unpack @@ -431,3 +435,371 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +class GLAVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class GLAVisionBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = GatedLinearAttention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + feature_map=config.feature_map, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + use_output_gate=config.use_output_gate, + gate_fn=config.hidden_act, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + clamp_min=config.clamp_min, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = GLAVisionMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class GLAVisionPreTrainedModel(PreTrainedModel): + config_class = GLAVisionConfig + + def _init_weights(self, module): + if 
isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + +class GLAVisionEncoder(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.blocks = nn.ModuleList([ + GLAVisionBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + return_dict: bool = True, + **kwargs + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, block in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_self_attentions = all_self_attentions + (attentions,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + +class GLAVisionModel(GLAVisionPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.embeddings = ImageEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = GLAVisionEncoder(config) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) if add_pooling_layer else None + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + interpolate_pos_encoding: Optional[bool] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states 
is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs = self.encoder( + hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + return_dict=return_dict, + **kwargs + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + +class GLAForImageClassification(GLAVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + self.backbone = GLAVisionModel(config, add_pooling_layer=True) # Here we should use mean pooling + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.backbone( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output + logits = self.classifier(pooled_output) # only use mean pooling + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = torch.nn.MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class GLAForMaskedImageModeling(GLAVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.backbone = GLAVisionModel(config, add_pooling_layer=False, use_mask_token=True) + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=config.hidden_size, + out_channels=config.encoder_stride**2 * config.num_channels, + kernel_size=1, + ), + nn.PixelShuffle(config.encoder_stride), + ) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, 
+ ) -> Union[tuple, MaskedImageModelingOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): + raise ValueError( + "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " + "the reconstructed image has the same dimensions as the input. " + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." + ) + + outputs = self.backbone( + pixel_values, + bool_masked_pos=bool_masked_pos, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + + sequence_output = outputs[0] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[1:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/fla/models/gsa/__init__.py b/fla/models/gsa/__init__.py index a134f758e..94843feb0 100644 --- a/fla/models/gsa/__init__.py +++ b/fla/models/gsa/__init__.py @@ -1,13 +1,16 @@ # -*- coding: utf-8 -*- -from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForImageClassification, AutoModelForMaskedImageModeling -from fla.models.gsa.configuration_gsa import GSAConfig -from fla.models.gsa.modeling_gsa import GSAForCausalLM, GSAModel +from fla.models.gsa.configuration_gsa import GSAConfig, GSAVisionConfig +from fla.models.gsa.modeling_gsa import GSAForCausalLM, GSAModel, GSAVisionModel, GSAForImageClassification, GSAForMaskedImageModeling AutoConfig.register(GSAConfig.model_type, GSAConfig) +AutoConfig.register(GSAVisionConfig.model_type, GSAVisionConfig) AutoModel.register(GSAConfig, GSAModel) AutoModelForCausalLM.register(GSAConfig, GSAForCausalLM) +AutoModelForImageClassification.register(GSAVisionConfig, GSAForImageClassification) +AutoModelForMaskedImageModeling.register(GSAVisionConfig, GSAForMaskedImageModeling) +AutoModel.register(GSAVisionConfig, GSAVisionModel) - -__all__ = ['GSAConfig', 'GSAForCausalLM', 'GSAModel'] +__all__ = ['GSAConfig', 'GSAForCausalLM', 'GSAModel', 'GSAVisionModel', 'GSAForImageClassification', 'GSAForMaskedImageModeling', 'GSAVisionConfig'] diff --git a/fla/models/gsa/configuration_gsa.py b/fla/models/gsa/configuration_gsa.py index b2b37c843..38b91d17c 100644 
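# --- Illustrative sketch of the masked-image-modeling loss above, not part of the patch ---
# Mirrors the code in the *ForMaskedImageModeling heads: a (batch, num_patches) boolean mask
# is upsampled to pixel resolution with repeat_interleave, and the L1 reconstruction error is
# averaged over masked pixels only. Shapes are toy values chosen for the example.
import torch
import torch.nn.functional as F

batch, channels, image_size, patch_size = 2, 3, 32, 8
patches_per_side = image_size // patch_size                       # 4
pixel_values = torch.randn(batch, channels, image_size, image_size)
reconstructed_pixel_values = torch.randn_like(pixel_values)       # stand-in for the decoder output
bool_masked_pos = torch.rand(batch, patches_per_side ** 2) > 0.5   # (batch, 16)

mask = (
    bool_masked_pos.reshape(-1, patches_per_side, patches_per_side)
    .repeat_interleave(patch_size, 1)
    .repeat_interleave(patch_size, 2)
    .unsqueeze(1)                                                  # (batch, 1, 32, 32)
    .contiguous()
)
reconstruction_loss = F.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / channels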
--- a/fla/models/gsa/configuration_gsa.py
+++ b/fla/models/gsa/configuration_gsa.py
@@ -92,3 +92,107 @@ def __init__(
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
         )
+
+class GSAVisionConfig(PretrainedConfig):
+
+    model_type = 'gsa_vision'
+
+    def __init__(
+        self,
+        # GSA core parameters
+        hidden_size: int = 2048,
+        gate_logit_normalizer: Optional[int] = 8,
+        clamp_min: Optional[float] = None,
+        clamp_max: Optional[float] = None,
+        num_hidden_layers: int = 24,
+        num_heads: int = 4,
+        num_kv_heads: Optional[int] = None,
+        num_slots: Optional[int] = 64,
+        use_short_conv: bool = False,
+        conv_size: int = 4,
+        expand_k: float = 1,
+        expand_v: float = 1,
+        feature_map: str = 'swish',
+        use_output_gate: bool = False,
+        use_norm: bool = True,
+        max_position_embeddings: int = 2048,
+        hidden_act: str = "swish",
+        elementwise_affine: Optional[bool] = True,
+        norm_first: bool = True,
+        norm_eps: float = 1e-6,
+        attn: Optional[Dict] = None,
+        use_cache: bool = True,
+        initializer_range: float = 0.02,
+        fuse_norm: bool = True,
+        fuse_cross_entropy: bool = True,
+        # Vision specific parameters
+        image_size: int = 224,
+        patch_size: int = 16,
+        num_channels: int = 3,
+        num_classes: int = 1000,
+        qkv_bias: bool = True,
+        hidden_dropout_prob: float = 0.0,
+        use_mask_token: bool = False,
+        layer_norm_eps: float = 1e-6,
+        interpolate_pos_encoding: bool = False,
+        mlp_dim: int = None,
+        encoder_stride=16,
+        scan_type: str = "uni-scan",  # scanning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan"
+        **kwargs
+    ):
+        self.hidden_size = hidden_size
+        self.gate_logit_normalizer = gate_logit_normalizer
+        self.clamp_min = clamp_min
+        self.clamp_max = clamp_max
+        self.num_hidden_layers = num_hidden_layers
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.num_slots = num_slots
+        self.use_short_conv = use_short_conv
+        self.conv_size = conv_size
+        self.expand_k = expand_k
+        self.expand_v = expand_v
+        self.feature_map = feature_map
+        self.use_output_gate = use_output_gate
+        self.use_norm = use_norm
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_act = hidden_act
+        self.elementwise_affine = elementwise_affine
+        self.norm_first = norm_first
+        self.norm_eps = norm_eps
+        self.use_cache = use_cache
+        self.initializer_range = initializer_range
+        self.fuse_cross_entropy = fuse_cross_entropy
+        self.fuse_norm = fuse_norm
+
+        # Initialize vision specific parameters
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.num_classes = num_classes
+        self.qkv_bias = qkv_bias
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.use_mask_token = use_mask_token
+        self.layer_norm_eps = layer_norm_eps
+        self.interpolate_pos_encoding = interpolate_pos_encoding
+        self.scan_type = scan_type
+        self.encoder_stride = encoder_stride
+
+        if attn is not None:
+            if not isinstance(attn, Dict):
+                raise ValueError("attn must be a dictionary")
+            if 'layers' not in attn:
+                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
+            if 'num_heads' not in attn:
+                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
+            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
+            attn['window_size'] = attn.get('window_size', None)
+
+        self.attn = attn
+
+        if mlp_dim is None:
+            self.mlp_dim = 4 * hidden_size  # default value set to 4 * hidden_size
+        else:
+            self.mlp_dim = mlp_dim
+
+        super().__init__(**kwargs)
\ No newline at end of file
diff --git
a/fla/models/gsa/modeling_gsa.py b/fla/models/gsa/modeling_gsa.py index d11a2574e..a59fb5a4f 100644 --- a/fla/models/gsa/modeling_gsa.py +++ b/fla/models/gsa/modeling_gsa.py @@ -12,19 +12,23 @@ from transformers.activations import ACT2FN from transformers.generation import GenerationMixin from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutputWithPast) + CausalLMOutputWithPast, + ImageClassifierOutput, + MaskedImageModelingOutput, + BaseModelOutput, + BaseModelOutputWithPooling) from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from fla.layers.attn import Attention from fla.layers.gsa import GatedSlotAttention -from fla.models.gsa.configuration_gsa import GSAConfig +from fla.models.gsa.configuration_gsa import GSAConfig, GSAVisionConfig from fla.models.utils import Cache from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm) from fla.modules.activations import swiglu_linear from fla.modules.layernorm import rms_norm_linear - +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge if TYPE_CHECKING: from transformers.processing_utils import Unpack @@ -453,3 +457,374 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + +class GSAVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class GSAVisionBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + if not config.norm_first: + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = GatedSlotAttention( + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + num_slots=config.num_slots, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + feature_map=config.feature_map, + use_output_gate=config.use_output_gate, + use_norm=config.use_norm, + gate_fn=config.hidden_act, + gate_logit_normalizer=config.gate_logit_normalizer, + elementwise_affine=config.elementwise_affine, + norm_first=config.norm_first, + norm_eps=config.norm_eps, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + + if not config.norm_first: + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = GSAVisionMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply 
attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class GSAVisionPreTrainedModel(PreTrainedModel): + config_class = GSAVisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + +class GSAVisionEncoder(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.blocks = nn.ModuleList([ + GSAVisionBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + return_dict: bool = True, + **kwargs + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, block in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_self_attentions = all_self_attentions + (attentions,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + +class GSAVisionModel(GSAVisionPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.embeddings = ImageEmbeddings(config, use_mask_token=use_mask_token) + 
self.encoder = GSAVisionEncoder(config) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) if add_pooling_layer else None + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + interpolate_pos_encoding: Optional[bool] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs = self.encoder( + hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + return_dict=return_dict, + **kwargs + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + +class GSAForImageClassification(GSAVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + self.backbone = GSAVisionModel(config, add_pooling_layer=True) # Here we should use mean pooling + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.backbone( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output + logits = self.classifier(pooled_output) # only use mean pooling + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = torch.nn.MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not 
None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class GSAForMaskedImageModeling(GSAVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.backbone = GSAVisionModel(config, add_pooling_layer=False, use_mask_token=True) + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=config.hidden_size, + out_channels=config.encoder_stride**2 * config.num_channels, + kernel_size=1, + ), + nn.PixelShuffle(config.encoder_stride), + ) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedImageModelingOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): + raise ValueError( + "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " + "the reconstructed image has the same dimensions as the input. " + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." + ) + + outputs = self.backbone( + pixel_values, + bool_masked_pos=bool_masked_pos, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + + sequence_output = outputs[0] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[1:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/fla/models/hgrn/__init__.py b/fla/models/hgrn/__init__.py index 3b29a3dd8..4e524b707 100644 --- a/fla/models/hgrn/__init__.py +++ b/fla/models/hgrn/__init__.py @@ -1,13 +1,17 @@ # -*- coding: utf-8 -*- -from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForImageClassification, AutoModelForMaskedImageModeling -from fla.models.hgrn.configuration_hgrn import HGRNConfig -from fla.models.hgrn.modeling_hgrn import HGRNForCausalLM, HGRNModel +from 
fla.models.hgrn.configuration_hgrn import HGRNConfig, HGRNVisionConfig +from fla.models.hgrn.modeling_hgrn import HGRNForCausalLM, HGRNModel, HGRNVisionModel, HGRNForImageClassification, HGRNForMaskedImageModeling AutoConfig.register(HGRNConfig.model_type, HGRNConfig) +AutoConfig.register(HGRNVisionConfig.model_type, HGRNVisionConfig) AutoModel.register(HGRNConfig, HGRNModel) AutoModelForCausalLM.register(HGRNConfig, HGRNForCausalLM) +AutoModelForImageClassification.register(HGRNVisionConfig, HGRNForImageClassification) +AutoModelForMaskedImageModeling.register(HGRNVisionConfig, HGRNForMaskedImageModeling) +AutoModel.register(HGRNVisionConfig, HGRNVisionModel) -__all__ = ['HGRNConfig', 'HGRNForCausalLM', 'HGRNModel'] +__all__ = ['HGRNConfig', 'HGRNForCausalLM', 'HGRNModel', 'HGRNVisionModel', 'HGRNForImageClassification', 'HGRNForMaskedImageModeling', 'HGRNVisionConfig'] diff --git a/fla/models/hgrn/configuration_hgrn.py b/fla/models/hgrn/configuration_hgrn.py index 39dd38db6..74aa04d2b 100644 --- a/fla/models/hgrn/configuration_hgrn.py +++ b/fla/models/hgrn/configuration_hgrn.py @@ -72,3 +72,87 @@ def __init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) + +class HGRNVisionConfig(PretrainedConfig): + + model_type = 'hgrn_vision' + + def __init__( + self, + # HGRN core parameters + attn_mode: str = "chunk", + hidden_size: int = 2048, + num_hidden_layers: int = 24, + expand_ratio: Optional[int] = 1, + use_short_conv: bool = False, + conv_size: int = 4, + use_lower_bound: bool = True, + max_position_embeddings: int = 2048, + hidden_act: str = "swish", + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + fuse_cross_entropy: bool = True, + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + qkv_bias: bool = True, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + encoder_stride=16, + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + ): + # Initialize HGRN core parameters + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.expand_ratio = expand_ratio + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_lower_bound = use_lower_bound + self.max_position_embeddings = max_position_embeddings + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.hidden_act = hidden_act + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + self.encoder_stride = encoder_stride + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise 
ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + self.attn = attn + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/models/hgrn/modeling_hgrn.py b/fla/models/hgrn/modeling_hgrn.py index 087533caf..9006b4be3 100644 --- a/fla/models/hgrn/modeling_hgrn.py +++ b/fla/models/hgrn/modeling_hgrn.py @@ -12,18 +12,22 @@ from transformers.activations import ACT2FN from transformers.generation import GenerationMixin from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutputWithPast) + CausalLMOutputWithPast, + ImageClassifierOutput, + MaskedImageModelingOutput, + BaseModelOutput, + BaseModelOutputWithPooling) from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from fla.layers.attn import Attention from fla.layers.hgrn import HGRNAttention -from fla.models.hgrn.configuration_hgrn import HGRNConfig +from fla.models.hgrn.configuration_hgrn import HGRNConfig, HGRNVisionConfig from fla.models.utils import Cache from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm) from fla.modules.activations import swiglu_linear - +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge if TYPE_CHECKING: from transformers.processing_utils import Unpack @@ -434,3 +438,362 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + +class HGRNVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class HGRNVisionBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = HGRNAttention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_ratio=config.expand_ratio, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = HGRNVisionMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + 
+ hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class HGRNVisionPreTrainedModel(PreTrainedModel): + config_class = HGRNVisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + +class HGRNVisionEncoder(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.blocks = nn.ModuleList([ + HGRNVisionBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + return_dict: bool = True, + **kwargs + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, block in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_self_attentions = all_self_attentions + (attentions,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + +class HGRNVisionModel(HGRNVisionPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.embeddings = ImageEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = 
HGRNVisionEncoder(config) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) if add_pooling_layer else None + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + interpolate_pos_encoding: Optional[bool] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs = self.encoder( + hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + return_dict=return_dict, + **kwargs + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + +class HGRNForImageClassification(HGRNVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + self.backbone = HGRNVisionModel(config, add_pooling_layer=True) # Here we should use mean pooling + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.backbone( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output + logits = self.classifier(pooled_output) # only use mean pooling + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = torch.nn.MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else 
output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class HGRNForMaskedImageModeling(HGRNVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.backbone = HGRNVisionModel(config, add_pooling_layer=False, use_mask_token=True) + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=config.hidden_size, + out_channels=config.encoder_stride**2 * config.num_channels, + kernel_size=1, + ), + nn.PixelShuffle(config.encoder_stride), + ) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedImageModelingOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): + raise ValueError( + "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " + "the reconstructed image has the same dimensions as the input. " + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." + ) + + outputs = self.backbone( + pixel_values, + bool_masked_pos=bool_masked_pos, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + + sequence_output = outputs[0] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[1:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/fla/models/hgrn2/__init__.py b/fla/models/hgrn2/__init__.py index 306b80822..d3e75efd7 100644 --- a/fla/models/hgrn2/__init__.py +++ b/fla/models/hgrn2/__init__.py @@ -1,13 +1,17 @@ # -*- coding: utf-8 -*- -from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForImageClassification, AutoModelForMaskedImageModeling -from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config -from fla.models.hgrn2.modeling_hgrn2 import HGRN2ForCausalLM, HGRN2Model +from 
fla.models.hgrn2.configuration_hgrn2 import HGRN2Config, HGRN2VisionConfig +from fla.models.hgrn2.modeling_hgrn2 import HGRN2ForCausalLM, HGRN2Model, HGRN2VisionModel, HGRN2ForImageClassification, HGRN2ForMaskedImageModeling AutoConfig.register(HGRN2Config.model_type, HGRN2Config) +AutoConfig.register(HGRN2VisionConfig.model_type, HGRN2VisionConfig) AutoModel.register(HGRN2Config, HGRN2Model) AutoModelForCausalLM.register(HGRN2Config, HGRN2ForCausalLM) +AutoModelForImageClassification.register(HGRN2VisionConfig, HGRN2ForImageClassification) +AutoModelForMaskedImageModeling.register(HGRN2VisionConfig, HGRN2ForMaskedImageModeling) +AutoModel.register(HGRN2VisionConfig, HGRN2VisionModel) -__all__ = ['HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model'] +__all__ = ['HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model', 'HGRN2VisionModel', 'HGRN2ForImageClassification', 'HGRN2ForMaskedImageModeling', 'HGRN2VisionConfig'] diff --git a/fla/models/hgrn2/configuration_hgrn2.py b/fla/models/hgrn2/configuration_hgrn2.py index d7a5945b8..34761e277 100644 --- a/fla/models/hgrn2/configuration_hgrn2.py +++ b/fla/models/hgrn2/configuration_hgrn2.py @@ -74,3 +74,89 @@ def __init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) + +class HGRN2VisionConfig(PretrainedConfig): + + model_type = 'hgrn2_vision' + + def __init__( + self, + # HGRN2 core parameters + hidden_size: int = 2048, + num_hidden_layers: int = 24, + attn_mode: str = "chunk", + num_heads: Optional[int] = None, + expand_ratio: Optional[int] = 128, + use_short_conv: bool = False, + conv_size: int = 4, + use_lower_bound: bool = True, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + fuse_cross_entropy: bool = True, + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + qkv_bias: bool = True, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + encoder_stride=16, + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + ): + # Initialize HGRN2 core parameters + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.attn_mode = attn_mode + self.num_heads = num_heads + self.expand_ratio = expand_ratio + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_lower_bound = use_lower_bound + self.max_position_embeddings = max_position_embeddings + self.hidden_act = hidden_act + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + self.encoder_stride = encoder_stride + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices 
must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + self.attn = attn + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/models/hgrn2/modeling_hgrn2.py b/fla/models/hgrn2/modeling_hgrn2.py index e9f2cfff4..ea7423de7 100644 --- a/fla/models/hgrn2/modeling_hgrn2.py +++ b/fla/models/hgrn2/modeling_hgrn2.py @@ -12,18 +12,22 @@ from transformers.activations import ACT2FN from transformers.generation import GenerationMixin from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutputWithPast) + CausalLMOutputWithPast, + ImageClassifierOutput, + MaskedImageModelingOutput, + BaseModelOutput, + BaseModelOutputWithPooling) from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from fla.layers.attn import Attention from fla.layers.hgrn2 import HGRN2Attention -from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config +from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config, HGRN2VisionConfig from fla.models.utils import Cache from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm) from fla.modules.activations import swiglu_linear - +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge if TYPE_CHECKING: from transformers.processing_utils import Unpack @@ -435,3 +439,363 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + +class HGRN2VisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class HGRN2VisionBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = HGRN2Attention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + num_heads=config.num_heads, + expand_ratio=config.expand_ratio, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = HGRN2VisionMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual 
= hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class HGRN2VisionPreTrainedModel(PreTrainedModel): + config_class = HGRN2VisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + +class HGRN2VisionEncoder(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.blocks = nn.ModuleList([ + HGRN2VisionBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + return_dict: bool = True, + **kwargs + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, block in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_self_attentions = all_self_attentions + (attentions,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + +class HGRN2VisionModel(HGRN2VisionPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, 
use_mask_token=False): + super().__init__(config) + self.config = config + self.embeddings = ImageEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = HGRN2VisionEncoder(config) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) if add_pooling_layer else None + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + interpolate_pos_encoding: Optional[bool] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs = self.encoder( + hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + return_dict=return_dict, + **kwargs + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + +class HGRN2ForImageClassification(HGRN2VisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + self.backbone = HGRN2VisionModel(config, add_pooling_layer=True) # Here we should use mean pooling + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.backbone( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output + logits = self.classifier(pooled_output) # only use mean pooling + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = torch.nn.MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = torch.nn.CrossEntropyLoss() + loss = 
loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class HGRN2ForMaskedImageModeling(HGRN2VisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.backbone = HGRN2VisionModel(config, add_pooling_layer=False, use_mask_token=True) + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=config.hidden_size, + out_channels=config.encoder_stride**2 * config.num_channels, + kernel_size=1, + ), + nn.PixelShuffle(config.encoder_stride), + ) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedImageModelingOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): + raise ValueError( + "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " + "the reconstructed image has the same dimensions as the input. " + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." + ) + + outputs = self.backbone( + pixel_values, + bool_masked_pos=bool_masked_pos, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + + sequence_output = outputs[0] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[1:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/fla/models/linear_attn/__init__.py b/fla/models/linear_attn/__init__.py index 72d5d022d..340148b65 100644 --- a/fla/models/linear_attn/__init__.py +++ b/fla/models/linear_attn/__init__.py @@ -1,14 +1,18 @@ # -*- coding: utf-8 -*- -from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, 
AutoModelForImageClassification, AutoModelForMaskedImageModeling from fla.models.linear_attn.configuration_linear_attn import \ - LinearAttentionConfig + LinearAttentionConfig, LinearAttentionVisionConfig from fla.models.linear_attn.modeling_linear_attn import ( - LinearAttentionForCausalLM, LinearAttentionModel) + LinearAttentionForCausalLM, LinearAttentionModel, LinearAttentionVisionModel, LinearAttentionForImageClassification, LinearAttentionForMaskedImageModeling) AutoConfig.register(LinearAttentionConfig.model_type, LinearAttentionConfig) +AutoConfig.register(LinearAttentionVisionConfig.model_type, LinearAttentionVisionConfig) AutoModel.register(LinearAttentionConfig, LinearAttentionModel) AutoModelForCausalLM.register(LinearAttentionConfig, LinearAttentionForCausalLM) +AutoModelForImageClassification.register(LinearAttentionVisionConfig, LinearAttentionForImageClassification) +AutoModelForMaskedImageModeling.register(LinearAttentionVisionConfig, LinearAttentionForMaskedImageModeling) +AutoModel.register(LinearAttentionVisionConfig, LinearAttentionVisionModel) -__all__ = ['LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel'] +__all__ = ['LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel', 'LinearAttentionVisionModel', 'LinearAttentionForImageClassification', 'LinearAttentionForMaskedImageModeling', 'LinearAttentionVisionConfig'] diff --git a/fla/models/linear_attn/configuration_linear_attn.py b/fla/models/linear_attn/configuration_linear_attn.py index d1bff79e2..f69725420 100644 --- a/fla/models/linear_attn/configuration_linear_attn.py +++ b/fla/models/linear_attn/configuration_linear_attn.py @@ -81,3 +81,96 @@ def __init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) + +class LinearAttentionVisionConfig(PretrainedConfig): + + model_type = 'linear_attn_vision' + + def __init__( + self, + # LinearAttention core parameters + attn_mode: str = "fused_chunk", + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 1, + num_hidden_layers: int = 24, + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + feature_map: str = "elementwise_product", + tie_feature_map_qk: bool = False, + norm_q: bool = False, + norm_k: bool = False, + norm_feature_map: bool = False, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + fuse_cross_entropy: bool = True, + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + qkv_bias: bool = True, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + encoder_stride=16, + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + ): + # Initialize LinearAttention core parameters + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.feature_map = feature_map + self.tie_feature_map_qk = tie_feature_map_qk + self.norm_q = norm_q + self.norm_k = norm_k + self.norm_feature_map = norm_feature_map + self.max_position_embeddings = max_position_embeddings + self.elementwise_affine = elementwise_affine 
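A few lines below, this config runs the same `attn` validation used by every vision config in this patch: `attn` must be a dict naming the layers that fall back to standard softmax attention, with `layers` and `num_heads` required and `num_kv_heads`/`window_size` filled with defaults. A hypothetical example of such a hybrid setup (the values are illustrative only, not defaults from the patch):

# Hypothetical hybrid setup: layers 2 and 5 use softmax attention,
# all other blocks keep the LinearAttention token mixer.
hybrid_attn = {
    'layers': [2, 5],      # required: indices of the softmax-attention layers
    'num_heads': 8,        # required: heads for those layers
    'num_kv_heads': 8,     # optional: defaults to num_heads
    'window_size': None,   # optional: defaults to None (global attention)
}
config = LinearAttentionVisionConfig(num_hidden_layers=12, attn=hybrid_attn)

Leaving `attn=None` (the default) keeps every block on the linear-attention path.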
+ self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + self.encoder_stride = encoder_stride + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + self.attn = attn + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/models/linear_attn/modeling_linear_attn.py b/fla/models/linear_attn/modeling_linear_attn.py index 2fcea2c87..c7b13c87e 100644 --- a/fla/models/linear_attn/modeling_linear_attn.py +++ b/fla/models/linear_attn/modeling_linear_attn.py @@ -4,7 +4,7 @@ import math import warnings -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Unpack, Dict import torch import torch.nn as nn @@ -12,18 +12,23 @@ from transformers.activations import ACT2FN from transformers.generation import GenerationMixin from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutputWithPast) + CausalLMOutputWithPast, + ImageClassifierOutput, + MaskedImageModelingOutput, + BaseModelOutput, + BaseModelOutputWithPooling) from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from fla.layers.attn import Attention from fla.layers.linear_attn import LinearAttention from fla.models.linear_attn.configuration_linear_attn import \ - LinearAttentionConfig + LinearAttentionConfig, LinearAttentionVisionConfig from fla.models.utils import Cache from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm) from fla.modules.activations import swiglu_linear +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge logger = logging.get_logger(__name__) @@ -427,3 +432,368 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + +class LinearAttentionVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class LinearAttentionVisionBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + 
num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = LinearAttention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + feature_map=config.feature_map, + tie_feature_map_qk=config.tie_feature_map_qk, + norm_q=config.norm_q, + norm_k=config.norm_k, + do_feature_map_norm=config.norm_feature_map, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = LinearAttentionVisionMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class LinearAttentionVisionPreTrainedModel(PreTrainedModel): + config_class = LinearAttentionVisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + +class LinearAttentionVisionEncoder(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.blocks = nn.ModuleList([ + LinearAttentionVisionBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + return_dict: bool = True, + **kwargs + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if 
output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, block in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_self_attentions = all_self_attentions + (attentions,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + +class LinearAttentionVisionModel(LinearAttentionVisionPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.embeddings = ImageEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = LinearAttentionVisionEncoder(config) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) if add_pooling_layer else None + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + interpolate_pos_encoding: Optional[bool] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs = self.encoder( + hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + return_dict=return_dict, + **kwargs + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + +class 
LinearAttentionForImageClassification(LinearAttentionVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + self.backbone = LinearAttentionVisionModel(config, add_pooling_layer=True) # Here we should use mean pooling + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.backbone( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output + logits = self.classifier(pooled_output) # only use mean pooling + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = torch.nn.MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class LinearAttentionForMaskedImageModeling(LinearAttentionVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.backbone = LinearAttentionVisionModel(config, add_pooling_layer=False, use_mask_token=True) + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=config.hidden_size, + out_channels=config.encoder_stride**2 * config.num_channels, + kernel_size=1, + ), + nn.PixelShuffle(config.encoder_stride), + ) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedImageModelingOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): + raise ValueError( + "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " + "the reconstructed image has the same dimensions as the input. " + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." 
+ ) + + outputs = self.backbone( + pixel_values, + bool_masked_pos=bool_masked_pos, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + + sequence_output = outputs[0] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[1:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/fla/models/retnet/__init__.py b/fla/models/retnet/__init__.py index ad7d9e9da..ae36ebb50 100644 --- a/fla/models/retnet/__init__.py +++ b/fla/models/retnet/__init__.py @@ -1,13 +1,17 @@ # -*- coding: utf-8 -*- -from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForImageClassification, AutoModelForMaskedImageModeling -from fla.models.retnet.configuration_retnet import RetNetConfig -from fla.models.retnet.modeling_retnet import RetNetForCausalLM, RetNetModel +from fla.models.retnet.configuration_retnet import RetNetConfig, RetNetVisionConfig +from fla.models.retnet.modeling_retnet import RetNetForCausalLM, RetNetModel, RetNetVisionModel, RetNetForImageClassification, RetNetForMaskedImageModeling AutoConfig.register(RetNetConfig.model_type, RetNetConfig) +AutoConfig.register(RetNetVisionConfig.model_type, RetNetVisionConfig) AutoModel.register(RetNetConfig, RetNetModel) AutoModelForCausalLM.register(RetNetConfig, RetNetForCausalLM) +AutoModel.register(RetNetVisionConfig, RetNetVisionModel) +AutoModelForImageClassification.register(RetNetVisionConfig, RetNetForImageClassification) +AutoModelForMaskedImageModeling.register(RetNetVisionConfig, RetNetForMaskedImageModeling) -__all__ = ['RetNetConfig', 'RetNetForCausalLM', 'RetNetModel'] +__all__ = ['RetNetConfig', 'RetNetForCausalLM', 'RetNetModel', 'RetNetVisionModel', 'RetNetForImageClassification', 'RetNetForMaskedImageModeling', 'RetNetVisionConfig'] diff --git a/fla/models/retnet/configuration_retnet.py b/fla/models/retnet/configuration_retnet.py index 535841629..21f6be6a7 100644 --- a/fla/models/retnet/configuration_retnet.py +++ b/fla/models/retnet/configuration_retnet.py @@ -85,3 +85,99 @@ def __init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) + +class RetNetVisionConfig(PretrainedConfig): + + model_type = 'retnet_vision' + + def __init__( + self, + # RetNet core parameters + attn_mode: str = "chunk", + hidden_size: int = 2048, + 
expand_k: int = 1, + expand_v: int = 2, + num_hidden_layers: int = 24, + num_heads: int = 8, + num_kv_heads: Optional[int] = None, + feature_map: Optional[str] = None, + hidden_act: str = "swish", + use_short_conv: bool = False, + conv_size: int = 4, + use_output_gate: bool = True, + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + qkv_bias: bool = True, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + encoder_stride=16, + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + ) -> RetNetVisionConfig: + # Initialize RetNet core parameters + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.feature_map = feature_map + self.hidden_act = hidden_act + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_output_gate = use_output_gate + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_cross_entropy = fuse_cross_entropy + + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + self.encoder_stride = encoder_stride + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + self.attn = attn + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/models/retnet/modeling_retnet.py b/fla/models/retnet/modeling_retnet.py index cd0af2bff..f3e79c5e8 100644 --- a/fla/models/retnet/modeling_retnet.py +++ b/fla/models/retnet/modeling_retnet.py @@ -12,18 +12,22 @@ from transformers.activations import ACT2FN from transformers.generation import GenerationMixin from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutputWithPast) + CausalLMOutputWithPast, + ImageClassifierOutput, + MaskedImageModelingOutput, + BaseModelOutput, + BaseModelOutputWithPooling) from transformers.modeling_utils import PreTrainedModel from transformers.utils import 
logging from fla.layers.attn import Attention from fla.layers.multiscale_retention import MultiScaleRetention -from fla.models.retnet.configuration_retnet import RetNetConfig +from fla.models.retnet.configuration_retnet import RetNetConfig, RetNetVisionConfig from fla.models.utils import Cache from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm) from fla.modules.activations import swiglu_linear - +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge if TYPE_CHECKING: from transformers.processing_utils import Unpack @@ -439,3 +443,367 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + +class RetNetVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class RetNetVisionBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = MultiScaleRetention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + feature_map=config.feature_map, + use_output_gate=config.use_output_gate, + gate_fn=config.hidden_act, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = RetNetVisionMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class 
RetNetVisionPreTrainedModel(PreTrainedModel): + config_class = RetNetVisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + +class RetNetVisionEncoder(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.blocks = nn.ModuleList([ + RetNetVisionBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + return_dict: bool = True, + **kwargs + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, block in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_self_attentions = all_self_attentions + (attentions,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + +class RetNetVisionModel(RetNetVisionPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.embeddings = ImageEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = RetNetVisionEncoder(config) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) if add_pooling_layer else None + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + interpolate_pos_encoding: Optional[bool] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if 
output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs = self.encoder( + hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + return_dict=return_dict, + **kwargs + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + +class RetNetForImageClassification(RetNetVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + self.backbone = RetNetVisionModel(config, add_pooling_layer=True) # Here we should use mean pooling + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.backbone( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output + logits = self.classifier(pooled_output) # only use mean pooling + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = torch.nn.MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class RetNetForMaskedImageModeling(RetNetVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.backbone = RetNetVisionModel(config, add_pooling_layer=False, use_mask_token=True) + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=config.hidden_size, + out_channels=config.encoder_stride**2 * config.num_channels, + kernel_size=1, + ), + nn.PixelShuffle(config.encoder_stride), + ) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: 
Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedImageModelingOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): + raise ValueError( + "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " + "the reconstructed image has the same dimensions as the input. " + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." + ) + + outputs = self.backbone( + pixel_values, + bool_masked_pos=bool_masked_pos, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + + sequence_output = outputs[0] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[1:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/fla/models/rwkv6/__init__.py b/fla/models/rwkv6/__init__.py index 942c6dc20..d5ae03400 100644 --- a/fla/models/rwkv6/__init__.py +++ b/fla/models/rwkv6/__init__.py @@ -1,13 +1,17 @@ # -*- coding: utf-8 -*- -from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForImageClassification, AutoModelForMaskedImageModeling -from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config -from fla.models.rwkv6.modeling_rwkv6 import RWKV6ForCausalLM, RWKV6Model +from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config, RWKV6VisionConfig +from fla.models.rwkv6.modeling_rwkv6 import RWKV6ForCausalLM, RWKV6Model, RWKV6VisionModel, RWKV6ForImageClassification, RWKV6ForMaskedImageModeling AutoConfig.register(RWKV6Config.model_type, RWKV6Config) +AutoConfig.register(RWKV6VisionConfig.model_type, RWKV6VisionConfig) AutoModel.register(RWKV6Config, RWKV6Model) AutoModelForCausalLM.register(RWKV6Config, RWKV6ForCausalLM) +AutoModel.register(RWKV6VisionConfig, RWKV6VisionModel) +AutoModelForImageClassification.register(RWKV6VisionConfig, RWKV6ForImageClassification) +AutoModelForMaskedImageModeling.register(RWKV6VisionConfig, RWKV6ForMaskedImageModeling) -__all__ = ['RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model'] +__all__ = ['RWKV6Config', 
'RWKV6ForCausalLM', 'RWKV6Model', 'RWKV6VisionModel', 'RWKV6ForImageClassification', 'RWKV6ForMaskedImageModeling', 'RWKV6VisionConfig'] diff --git a/fla/models/rwkv6/configuration_rwkv6.py b/fla/models/rwkv6/configuration_rwkv6.py index 6e56614bf..4bd92dc4a 100644 --- a/fla/models/rwkv6/configuration_rwkv6.py +++ b/fla/models/rwkv6/configuration_rwkv6.py @@ -78,3 +78,94 @@ def __init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) + +class RWKV6VisionConfig(PretrainedConfig): + + model_type = 'rwkv6_vision' + + def __init__( + self, + # RWKV6 core parameters + attn_mode: str = "chunk", + hidden_size: int = 2048, + expand_k: int = 0.5, + expand_v: int = 1, + num_hidden_layers: int = 24, + num_heads: int = 4, + proj_low_rank_dim: int = 32, + gate_low_rank_dim: int = 64, + hidden_act: str = "sqrelu", + max_position_embeddings: int = 2048, + norm_first: bool = True, + norm_bias: bool = True, + norm_eps: float = 1e-5, + attn: Optional[Dict] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + qkv_bias: bool = True, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + encoder_stride=16, + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + ): + # Initialize RWKV6 core parameters + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.norm_first = norm_first + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.proj_low_rank_dim = proj_low_rank_dim + self.gate_low_rank_dim = gate_low_rank_dim + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.norm_bias = norm_bias + self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_cross_entropy = fuse_cross_entropy + + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + self.encoder_stride = encoder_stride + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + self.attn = attn + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/models/rwkv6/modeling_rwkv6.py b/fla/models/rwkv6/modeling_rwkv6.py index 0fa95670d..f4ab27ff7 100644 --- a/fla/models/rwkv6/modeling_rwkv6.py +++ b/fla/models/rwkv6/modeling_rwkv6.py @@ -4,25 +4,29 @@ 
import math import warnings -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, List, Unpack, Dict import torch import torch.nn as nn import torch.utils.checkpoint from transformers.generation import GenerationMixin from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutputWithPast) + CausalLMOutputWithPast, + ImageClassifierOutput, + MaskedImageModelingOutput, + BaseModelOutput, + BaseModelOutputWithPooling) from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from fla.layers.attn import Attention from fla.layers.rwkv6 import LerpLinear, RWKV6Attention -from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config +from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config, RWKV6VisionConfig from fla.models.utils import Cache from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, LayerNorm) from fla.modules.activations import ACT2FN - +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge logger = logging.get_logger(__name__) @@ -442,3 +446,364 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + +class RWKV6VisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class RWKV6VisionBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = RWKV6Attention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + proj_low_rank_dim=config.proj_low_rank_dim, + gate_low_rank_dim=config.gate_low_rank_dim, + norm_eps=config.norm_eps, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = RWKV6VisionMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, self.scan_type) + + # First residual connection + hidden_states = 
residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class RWKV6VisionPreTrainedModel(PreTrainedModel): + config_class = RWKV6VisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + +class RWKV6VisionEncoder(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.blocks = nn.ModuleList([ + RWKV6VisionBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + return_dict: bool = True, + **kwargs + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, block in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_self_attentions = all_self_attentions + (attentions,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + +class RWKV6VisionModel(RWKV6VisionPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.embeddings = ImageEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = RWKV6VisionEncoder(config) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) if add_pooling_layer else None + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + output_attentions: 
Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + interpolate_pos_encoding: Optional[bool] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs = self.encoder( + hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + return_dict=return_dict, + **kwargs + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + +class RWKV6ForImageClassification(RWKV6VisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + self.backbone = RWKV6VisionModel(config, add_pooling_layer=True) # Here we should use mean pooling + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.backbone( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output + logits = self.classifier(pooled_output) # only use mean pooling + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = torch.nn.MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class RWKV6ForMaskedImageModeling(RWKV6VisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.backbone = RWKV6VisionModel(config, add_pooling_layer=False, use_mask_token=True) + self.decoder = nn.Sequential( + nn.Conv2d( + 
in_channels=config.hidden_size, + out_channels=config.encoder_stride**2 * config.num_channels, + kernel_size=1, + ), + nn.PixelShuffle(config.encoder_stride), + ) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedImageModelingOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): + raise ValueError( + "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " + "the reconstructed image has the same dimensions as the input. " + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." + ) + + outputs = self.backbone( + pixel_values, + bool_masked_pos=bool_masked_pos, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + + sequence_output = outputs[0] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[1:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla/models/transformer/__init__.py b/fla/models/transformer/__init__.py index 47df999fe..0faf3cac0 100644 --- a/fla/models/transformer/__init__.py +++ b/fla/models/transformer/__init__.py @@ -1,14 +1,18 @@ # -*- coding: utf-8 -*- -from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForImageClassification, AutoModelForMaskedImageModeling -from fla.models.transformer.configuration_transformer import TransformerConfig +from fla.models.transformer.configuration_transformer import TransformerConfig, TransformerVisionConfig from fla.models.transformer.modeling_transformer import ( - TransformerForCausalLM, TransformerModel) + TransformerForCausalLM, TransformerModel, TransformerVisionModel, TransformerForImageClassification, TransformerForMaskedImageModeling) AutoConfig.register(TransformerConfig.model_type, TransformerConfig) +AutoConfig.register(TransformerVisionConfig.model_type, 
TransformerVisionConfig) AutoModel.register(TransformerConfig, TransformerModel) AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM) +AutoModelForImageClassification.register(TransformerVisionConfig, TransformerForImageClassification) +AutoModelForMaskedImageModeling.register(TransformerVisionConfig, TransformerForMaskedImageModeling) +AutoModel.register(TransformerVisionConfig, TransformerVisionModel) -__all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel'] +__all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel', 'TransformerVisionModel', 'TransformerForImageClassification', 'TransformerForMaskedImageModeling', 'TransformerVisionConfig'] diff --git a/fla/models/transformer/configuration_transformer.py b/fla/models/transformer/configuration_transformer.py index 35e27113c..8e3fb057a 100644 --- a/fla/models/transformer/configuration_transformer.py +++ b/fla/models/transformer/configuration_transformer.py @@ -66,3 +66,81 @@ def __init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) + +class TransformerVisionConfig(PretrainedConfig): + + model_type = 'transformer_vision' + + def __init__( + self, + # Transformer core parameters + hidden_size: int = 2048, + num_hidden_layers: int = 24, + num_heads: int = 32, + num_kv_heads: int = None, + window_size: Optional[int] = None, + rope_theta: Optional[float] = 10000., + max_position_embeddings: int = 2048, + hidden_act: str = "swish", + initializer_range: float = 0.02, + elementwise_affine: Optional[bool] = True, + norm_first: bool = False, + norm_eps: float = 1e-6, + use_cache: bool = True, + attention_bias: bool = False, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + # Vision specific parameters + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + num_classes: int = 1000, + qkv_bias: bool = True, + hidden_dropout_prob: float = 0.0, + use_mask_token: bool = False, + layer_norm_eps: float = 1e-6, + interpolate_pos_encoding: bool = False, + mlp_dim: int = None, + encoder_stride=16, + scan_type: str = "uni-scan", # scaning type, "uni-scan" or "bi-scan" or "cross-scan", default to "uni-scan" + **kwargs + ): + # Initialize transformer core parameters + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.window_size = window_size + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.hidden_act = hidden_act + + self.initializer_range = initializer_range + self.elementwise_affine = elementwise_affine + self.norm_first = norm_first + self.norm_eps = norm_eps + self.use_cache = use_cache + self.attention_bias = attention_bias + self.fuse_cross_entropy = fuse_cross_entropy + self.fuse_norm = fuse_norm + + # Initialize vision specific parameters + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_classes = num_classes + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.use_mask_token = use_mask_token + self.layer_norm_eps = layer_norm_eps + self.interpolate_pos_encoding = interpolate_pos_encoding + self.scan_type = scan_type + self.encoder_stride = encoder_stride + + if mlp_dim is None: + self.mlp_dim = 4 * hidden_size # default value set to 4 * hidden_size + else: + self.mlp_dim = mlp_dim + + super().__init__(**kwargs) \ No newline at end of file diff --git a/fla/models/transformer/modeling_transformer.py 
b/fla/models/transformer/modeling_transformer.py index f843ce682..78096590f 100644 --- a/fla/models/transformer/modeling_transformer.py +++ b/fla/models/transformer/modeling_transformer.py @@ -4,7 +4,7 @@ import math import warnings -from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, Dict import torch import torch.nn as nn @@ -12,18 +12,22 @@ from transformers.activations import ACT2FN from transformers.generation import GenerationMixin from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutputWithPast) + CausalLMOutputWithPast, + ImageClassifierOutput, + MaskedImageModelingOutput, + BaseModelOutput, + BaseModelOutputWithPooling) from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from fla.layers.attn import Attention -from fla.models.transformer.configuration_transformer import TransformerConfig +from fla.models.transformer.configuration_transformer import TransformerConfig, TransformerVisionConfig from fla.models.utils import Cache from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm) from fla.modules.activations import swiglu_linear from fla.modules.layernorm import rms_norm_linear - +from ..utils import ImageEmbeddings, Pooler, prepare_hidden_states_for_cross_scan, prepare_hidden_states_for_cross_merge if TYPE_CHECKING: from transformers.processing_utils import Unpack @@ -445,3 +449,355 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + +class TransformerVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), + nn.Linear(config.hidden_size, config.mlp_dim), + nn.GELU(), + nn.Linear(config.mlp_dim, config.hidden_size), + nn.Dropout(config.hidden_dropout_prob) + ) + + def forward(self, x): + return self.net(x) + +class TransformerVisionBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + if not config.norm_first: + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + window_size=config.window_size, + rope_theta=config.rope_theta, + max_position_embeddings=config.max_position_embeddings, + norm_first=config.norm_first, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + + if not config.norm_first: + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.mlp = TransformerVisionMLP(config) + + self.scan_type = config.scan_type + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor]]: + residual = hidden_states + + # Pre-normalization if enabled + if hasattr(self, 'ln_1'): + hidden_states = self.ln_1(hidden_states) + + # Apply attention + + hidden_states = prepare_hidden_states_for_cross_scan(hidden_states, self.scan_type) + + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + hidden_states = prepare_hidden_states_for_cross_merge(hidden_states, 
self.scan_type) + + # First residual connection + hidden_states = residual + hidden_states + residual = hidden_states + + # Pre-normalization for MLP if enabled + if hasattr(self, 'ln_2'): + hidden_states = self.ln_2(hidden_states) + + hidden_states = self.mlp(hidden_states) + + # Second residual connection + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + +class TransformerVisionPreTrainedModel(PreTrainedModel): + config_class = TransformerVisionConfig + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ImageEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + +class TransformerVisionEncoder(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.blocks = nn.ModuleList([ + TransformerVisionBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + return_dict: bool = True, + **kwargs + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, block in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = block( + hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_self_attentions = all_self_attentions + (attentions,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + +class TransformerVisionModel(TransformerVisionPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.embeddings = ImageEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = TransformerVisionEncoder(config) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = Pooler(config) if add_pooling_layer else None + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: 
Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + interpolate_pos_encoding: Optional[bool] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs = self.encoder( + hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + return_dict=return_dict, + **kwargs + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + +class TransformerForImageClassification(TransformerVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_classes + self.backbone = TransformerVisionModel(config, add_pooling_layer=True) # Here we should use mean pooling + self.classifier = nn.Linear(config.hidden_size, config.num_classes) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.backbone( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output + logits = self.classifier(pooled_output) # only use mean pooling + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = torch.nn.MSELoss() + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class TransformerForMaskedImageModeling(TransformerVisionPreTrainedModel): + def __init__(self, config): + super().__init__(config) + 
self.backbone = TransformerVisionModel(config, add_pooling_layer=False, use_mask_token=True) + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=config.hidden_size, + out_channels=config.encoder_stride**2 * config.num_channels, + kernel_size=1, + ), + nn.PixelShuffle(config.encoder_stride), + ) + self.init_weights() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedImageModelingOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): + raise ValueError( + "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " + "the reconstructed image has the same dimensions as the input. " + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." + ) + + outputs = self.backbone( + pixel_values, + bool_masked_pos=bool_masked_pos, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + + sequence_output = outputs[0] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[1:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/fla/models/utils.py b/fla/models/utils.py index ed6fa9002..f16f0039e 100644 --- a/fla/models/utils.py +++ b/fla/models/utils.py @@ -6,7 +6,13 @@ import torch import transformers - +from torch import nn +import collections.abc +from transformers.utils import torch_int +import triton +import triton.language as tl +import einops +import math class Cache(transformers.cache_utils.Cache): """ @@ -141,3 +147,418 @@ def from_legacy_cache( for layer_idx in range(len(past_key_values)): cache.states.append(past_key_values[layer_idx]) return cache + + +""" +Basic components of a vision model, including the patch embeddings, image embeddings, and pooler. 
Taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_vit.py +""" + +class PatchEmbeddings(nn.Module): + """ + Convert image into patch embeddings. + Adapted from huggingface/transformers ViT implementation. + """ + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + +class ImageEmbeddings(nn.Module): + """ + Construct the position and patch embeddings. + Adapted from huggingface/transformers ViT implementation. No cls token is used in this implementation. + """ + def __init__(self, config, use_mask_token: bool = False) -> None: + super().__init__() + + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None + self.patch_embeddings = PatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.config = config + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
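+        The learned position embeddings are reshaped to their 2D grid and resized with
+        bicubic interpolation to match the new (height // patch_size, width // patch_size) grid.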
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] + num_positions = self.position_embeddings.shape[1] + + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + pos_embed = self.position_embeddings.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + + pos_embed = pos_embed.permute(0, 3, 1, 2) + + pos_embed = nn.functional.interpolate( + pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return pos_embed + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + if bool_masked_pos is not None: + seq_length = embeddings.shape[1] + mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + +class Pooler(nn.Module): + """ + Pool the output of a vision model by taking the mean of all tokens. + Adapted from huggingface/transformers ViT implementation. + """ + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + pooled_output = hidden_states.mean(dim=1) # always use mean pooling + pooled_output = self.dense(pooled_output) + pooled_output = self.activation(pooled_output) + return pooled_output + +""" +Cross Scan and Cross Merge implemented in Triton (only). 
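+The `scans` constant selects the scan pattern: 0 is the four-directional cross scan
+(row-major, column-major and their flipped variants), 1 is unidirectional and 2 is bidirectional.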
Taken from https://github.com/MzeroMiko/VMamba/blob/main/classification/models/csm_triton.py +""" + +@triton.jit +def triton_cross_scan_flex( + x: tl.tensor, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) + y: tl.tensor, # (B, 4, C, H, W) | (B, H, W, 4, C) + x_layout: tl.constexpr, + y_layout: tl.constexpr, + operation: tl.constexpr, + onebyone: tl.constexpr, + scans: tl.constexpr, + BC: tl.constexpr, + BH: tl.constexpr, + BW: tl.constexpr, + DC: tl.constexpr, + DH: tl.constexpr, + DW: tl.constexpr, + NH: tl.constexpr, + NW: tl.constexpr, +): + # x_layout = 0 + # y_layout = 1 # 0 BCHW, 1 BHWC + # operation = 0 # 0 scan, 1 merge + # onebyone = 0 # 0 false, 1 true + # scans = 0 # 0 cross scan, 1 unidirectional, 2 bidirectional + + i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h, i_w = (i_hw // NW), (i_hw % NW) + _mask_h = (i_h * BH + tl.arange(0, BH)) < DH + _mask_w = (i_w * BW + tl.arange(0, BW)) < DW + _mask_hw = _mask_h[:, None] & _mask_w[None, :] + _for_C = min(DC - i_c * BC, BC) + + pos_h = (i_h * BH + tl.arange(0, BH)[:, None]) + pos_w = (i_w * BW + tl.arange(0, BW)[None, :]) + neg_h = (DH - i_h * BH - 1 - tl.arange(0, BH)[:, None]) + neg_w = (DW - i_w * BW - 1 - tl.arange(0, BW)[None, :]) + if scans == 0: + # none; trans; flip; trans + flip; + HWRoute0 = pos_h * DW + pos_w + HWRoute1 = pos_w * DH + pos_h # trans + HWRoute2 = neg_h * DW + neg_w # flip + HWRoute3 = neg_w * DH + neg_h # trans + flip + elif scans == 1: + # none; none; none; none; + HWRoute0 = pos_h * DW + pos_w + HWRoute1 = HWRoute0 + HWRoute2 = HWRoute0 + HWRoute3 = HWRoute0 + elif scans == 2: + # none; none; flip; flip; + HWRoute0 = pos_h * DW + pos_w + HWRoute1 = HWRoute0 + HWRoute2 = neg_h * DW + neg_w # flip + HWRoute3 = HWRoute2 + + _tmp1 = DC * DH * DW + + y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC) + if y_layout == 0: + p_y1 = y_ptr_base + HWRoute0 + p_y2 = y_ptr_base + _tmp1 + HWRoute1 + p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2 + p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3 + else: + p_y1 = y_ptr_base + HWRoute0 * 4 * DC + p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC + p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC + p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC + + if onebyone == 0: + x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) + if x_layout == 0: + p_x = x_ptr_base + HWRoute0 + else: + p_x = x_ptr_base + HWRoute0 * DC + + if operation == 0: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + _x = tl.load(p_x + _idx_x, mask=_mask_hw) + tl.store(p_y1 + _idx_y, _x, mask=_mask_hw) + tl.store(p_y2 + _idx_y, _x, mask=_mask_hw) + tl.store(p_y3 + _idx_y, _x, mask=_mask_hw) + tl.store(p_y4 + _idx_y, _x, mask=_mask_hw) + elif operation == 1: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw) + _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw) + _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw) + _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw) + tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw) + + else: + x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) + if x_layout == 0: + p_x1 = x_ptr_base + HWRoute0 + p_x2 = p_x1 + _tmp1 + p_x3 = p_x2 + _tmp1 + p_x4 = p_x3 + _tmp1 + else: + p_x1 = x_ptr_base + HWRoute0 * 4 * DC + p_x2 = p_x1 + DC + p_x3 = p_x2 + DC + 
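+            # in the channel-last (..., 4, C) layout the four direction slices sit
+            # side by side, so consecutive pointers are offset by DC elements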
p_x4 = p_x3 + DC + + if operation == 0: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw) + tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw) + tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw) + tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw) + else: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw) + tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw) + tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw) + tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw) + + +class CrossScanTritonF(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): + if one_by_one: + if in_channel_first: + B, _, C, H, W = x.shape + else: + B, H, W, _, C = x.shape + else: + if in_channel_first: + B, C, H, W = x.shape + else: + B, H, W, C = x.shape + B, C, H, W = int(B), int(C), int(H), int(W) + BC, BH, BW = 1, 32, 32 + NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) + + ctx.in_channel_first = in_channel_first + ctx.out_channel_first = out_channel_first + ctx.one_by_one = one_by_one + ctx.scans = scans + ctx.shape = (B, C, H, W) + ctx.triton_shape = (BC, BH, BW, NC, NH, NW) + + y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C)) + triton_cross_scan_flex[(NH * NW, NC, B)]( + x.contiguous(), y, + (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans, + BC, BH, BW, C, H, W, NH, NW + ) + return y + + @staticmethod + def backward(ctx, y: torch.Tensor): + in_channel_first = ctx.in_channel_first + out_channel_first = ctx.out_channel_first + one_by_one = ctx.one_by_one + scans = ctx.scans + B, C, H, W = ctx.shape + BC, BH, BW, NC, NH, NW = ctx.triton_shape + if one_by_one: + x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C)) + else: + x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C)) + + triton_cross_scan_flex[(NH * NW, NC, B)]( + x, y.contiguous(), + (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans, + BC, BH, BW, C, H, W, NH, NW + ) + return x, None, None, None, None + + +class CrossMergeTritonF(torch.autograd.Function): + @staticmethod + def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): + if out_channel_first: + B, _, C, H, W = y.shape + else: + B, H, W, _, C = y.shape + B, C, H, W = int(B), int(C), int(H), int(W) + BC, BH, BW = 1, 32, 32 + NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) + ctx.in_channel_first = in_channel_first + ctx.out_channel_first = out_channel_first + ctx.one_by_one = one_by_one + ctx.scans = scans + ctx.shape = (B, C, H, W) + ctx.triton_shape = (BC, BH, BW, NC, NH, NW) + if one_by_one: + x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C)) + else: + x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C)) + triton_cross_scan_flex[(NH * NW, NC, B)]( + x, y.contiguous(), + (0 if in_channel_first else 1), (0 
if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans, + BC, BH, BW, C, H, W, NH, NW + ) + return x + + @staticmethod + def backward(ctx, x: torch.Tensor): + in_channel_first = ctx.in_channel_first + out_channel_first = ctx.out_channel_first + one_by_one = ctx.one_by_one + scans = ctx.scans + B, C, H, W = ctx.shape + BC, BH, BW, NC, NH, NW = ctx.triton_shape + y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C)) + triton_cross_scan_flex[(NH * NW, NC, B)]( + x.contiguous(), y, + (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans, + BC, BH, BW, C, H, W, NH, NW + ) + return y, None, None, None, None, None + + +# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) +def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False): + # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) + # y: (B, 4, C, L) | (B, L, 4, C) + # scans: 0: cross scan; 1 unidirectional; 2: bidirectional; + assert x.is_cuda + CSF = CrossScanTritonF + with torch.cuda.device(x.device): + return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans) + + +# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) +def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False): + # y: (B, 4, C, L) | (B, L, 4, C) + # x: (B, C, H * W) | (B, H * W, C) | (B, 4, C, H * W) | (B, H * W, 4, C) + # scans: 0: cross scan; 1 unidirectional; 2: bidirectional; + assert y.is_cuda + CMF = CrossMergeTritonF + with torch.cuda.device(y.device): + return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans) + +def prepare_hidden_states_for_cross_scan(hidden_states: torch.Tensor, scan_type: str = "uni-scan"): + # hidden_states shape should be: (B, L, D) + if scan_type == "uni-scan": + return hidden_states + elif scan_type == "bi-scan": + flipped_hidden_states = hidden_states.flip(-2) + hidden_states = torch.cat([hidden_states, flipped_hidden_states], dim=0) + return hidden_states + + B, L, D = hidden_states.shape + hw = int(math.sqrt(L)) + assert (hw * hw == L) + hidden_states = einops.rearrange(hidden_states, "b (h w) d -> b h w d", h=hw, w=hw) # change the shape to feed to cross_scan + hidden_states = cross_scan_fn(hidden_states, in_channel_first=False, out_channel_first=False, one_by_one=False, scans=0) + hidden_states = einops.rearrange(hidden_states, "b l k d -> (b k) l d") + return hidden_states + +def prepare_hidden_states_for_cross_merge(hidden_states: torch.Tensor, scan_type: str = "uni-scan"): + # hidden_states shape should be: (BK, L, D), K=2 for bi-scan, K=1 for uni-scan, K=4 for cross-scan + if scan_type == "uni-scan": + return hidden_states + elif scan_type == "bi-scan": + B = hidden_states.shape[0] // 2 + hidden_states = hidden_states[:B] + hidden_states[B:] + return hidden_states + + B, L, D = hidden_states.shape + hw = int(math.sqrt(L)) + hidden_states = einops.rearrange(hidden_states, "(b k) (h w) d -> b h w k d", k=4, h=hw, w=hw) + hidden_states = cross_merge_fn(hidden_states, in_channel_first=False, out_channel_first=False, one_by_one=False, scans=0) + return hidden_states \ No newline at end of file From e1b05da1bfbe82e6162235ca208527c4d0a73d7b Mon Sep 17 00:00:00 2001 From: yibozhong Date: Sun, 19 Jan 2025 17:31:36 +0800 Subject: [PATCH 17/17] reverse previous changes for separate PR --- fla/layers/abc.py | 2 -- 1 file 
changed, 2 deletions(-) diff --git a/fla/layers/abc.py b/fla/layers/abc.py index d77e8d5f3..6d1cf15c8 100644 --- a/fla/layers/abc.py +++ b/fla/layers/abc.py @@ -38,7 +38,6 @@ def __init__( use_input_gate: bool = False, use_output_gate: bool = True, use_norm: bool = True, - use_rope: bool = False, # FIXME clamp_min: Optional[float] = -32, clamp_max: Optional[float] = 32, layer_idx: Optional[int] = None, @@ -65,7 +64,6 @@ def __init__( self.use_input_gate = use_input_gate self.use_output_gate = use_output_gate self.use_norm = use_norm - self.use_rope = use_rope # FIXME if num_slots is None: num_slots = self.head_k_dim