[Dinov2 with Registers] Some fixes (#35411)
* First draft

* Thanks claude

* Remove print statement

* Use torch_int

* Address comments

* Address comment
NielsRogge authored Jan 6, 2025
1 parent ca00950 commit 12ba96a
Showing 3 changed files with 73 additions and 50 deletions.
src/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py
@@ -19,6 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...configuration_utils import PretrainedConfig
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices

@@ -69,10 +70,6 @@ class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
Whether to use the SwiGLU feedforward neural network.
num_register_tokens (`int`, *optional*, defaults to 4):
Number of register tokens to use.
- interpolate_antialias (`bool`, *optional*, defaults to `True`):
- Whether to use antialiasing when interpolating the image patches.
- interpolate_offset (`float`, *optional*, defaults to 0.0):
- Offset to use when interpolating the image patches.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the
@@ -105,7 +102,7 @@ class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
>>> configuration = model.config
```"""

model_type = "dinov2-with-registers-base"
model_type = "dinov2_with_registers"

def __init__(
self,
@@ -126,8 +123,6 @@ def __init__(
drop_path_rate=0.0,
use_swiglu_ffn=False,
num_register_tokens=4,
- interpolate_antialias=True,
- interpolate_offset=0.0,
out_features=None,
out_indices=None,
apply_layernorm=True,
@@ -153,8 +148,6 @@ def __init__(
self.drop_path_rate = drop_path_rate
self.use_swiglu_ffn = use_swiglu_ffn
self.num_register_tokens = num_register_tokens
- self.interpolate_antialias = interpolate_antialias
- self.interpolate_offset = interpolate_offset
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
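For context, here is a minimal sketch of what this configuration change means downstream. The snippet is illustrative only (it assumes a transformers build that includes this commit) and is not part of the diff:

```python
from transformers import Dinov2WithRegistersConfig

config = Dinov2WithRegistersConfig()

# The model_type string is now underscore-separated and checkpoint-agnostic.
print(config.model_type)  # "dinov2_with_registers"

# The two interpolation knobs were dropped from the config entirely:
# antialiasing is now always enabled in the modeling code and the float
# offset hack is gone, so a freshly built config no longer carries them.
print(hasattr(config, "interpolate_antialias"))  # False
print(hasattr(config, "interpolate_offset"))     # False
```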
src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py
@@ -19,6 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union
@@ -37,6 +38,7 @@
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
+ torch_int,
)
from ...utils.backbone_utils import BackboneMixin
from .configuration_dinov2_with_registers import Dinov2WithRegistersConfig
@@ -99,43 +101,61 @@ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.patch_size = config.patch_size
self.config = config

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
- resolution images.
+ resolution images. This implementation supports torch.jit tracing while maintaining backwards compatibility
+ with the original implementation.
- Source:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+ - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
- if num_patches == num_positions and height == width:

+ # Skip interpolation for matching dimensions (unless tracing)
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings

+ # Handle class token and patch embeddings separately
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]

+ # Calculate new dimensions
height = height // self.config.patch_size
width = width // self.config.patch_size
- # we add a small number to avoid floating point error in the interpolation
- # see discussion at https://github.com/facebookresearch/dino/issues/8
- height, width = height + self.config.interpolate_offset, width + self.config.interpolate_offset
- patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)

+ # Reshape for interpolation
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

+ # Store original dtype for restoration after interpolation
target_dtype = patch_pos_embed.dtype

+ # Interpolate at float32 precision
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.to(dtype=torch.float32),
- scale_factor=(float(height / math.sqrt(num_positions)), float(width / math.sqrt(num_positions))),
+ size=(torch_int(height), torch_int(width)),  # Explicit size instead of scale_factor
mode="bicubic",
align_corners=False,
- antialias=self.config.interpolate_antialias,
- )
- patch_pos_embed = patch_pos_embed.to(dtype=target_dtype)
- if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
- raise ValueError("Width or height does not match with the interpolated position embeddings")
+ antialias=True,
+ ).to(dtype=target_dtype)

+ # Validate output dimensions if not tracing
+ if not torch.jit.is_tracing():
+ if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
+ raise ValueError("Width or height does not match with the interpolated position embeddings")

+ # Reshape back to original format
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

+ # Combine class and patch embeddings
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
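The recurring helper here is `torch_int`. Roughly — and this is a paraphrase based on how it is used above, not a verbatim copy of the transformers utility — it returns a plain Python int in eager mode but keeps the value as a tensor while `torch.jit` is tracing, so the interpolation size stays symbolic in the traced graph:

```python
import torch

def torch_int(x):
    """Paraphrase of the transformers helper: cast to a Python int in eager
    mode, but keep tensors as int64 tensors while tracing so the value stays
    dynamic in the traced graph."""
    if torch.jit.is_tracing() and isinstance(x, torch.Tensor):
        return x.to(torch.int64)
    return int(x)
```

Passing an explicit `size=` to `nn.functional.interpolate` (instead of the old `scale_factor=` plus the `interpolate_offset` epsilon) sidesteps the floating-point rounding issue discussed in facebookresearch/dino#8 and gives the tracer a concrete output-shape expression, which is why both config knobs could be removed.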
src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- import math

from typing import Optional

import torch
@@ -30,7 +30,7 @@
)
from ...configuration_utils import PretrainedConfig
from ...modeling_outputs import BackboneOutput
- from ...utils import logging
+ from ...utils import logging, torch_int
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices


@@ -83,10 +83,6 @@ class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
Whether to use the SwiGLU feedforward neural network.
num_register_tokens (`int`, *optional*, defaults to 4):
Number of register tokens to use.
- interpolate_antialias (`bool`, *optional*, defaults to `True`):
- Whether to use antialiasing when interpolating the image patches.
- interpolate_offset (`float`, *optional*, defaults to 0.0):
- Offset to use when interpolating the image patches.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the
@@ -119,7 +115,7 @@ class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
>>> configuration = model.config
```"""

model_type = "dinov2-with-registers-base"
model_type = "dinov2_with_registers"

def __init__(
self,
@@ -140,8 +136,6 @@ def __init__(
drop_path_rate=0.0,
use_swiglu_ffn=False,
num_register_tokens=4,
- interpolate_antialias=True,
- interpolate_offset=0.0,
out_features=None,
out_indices=None,
apply_layernorm=True,
@@ -167,8 +161,6 @@ def __init__(
self.drop_path_rate = drop_path_rate
self.use_swiglu_ffn = use_swiglu_ffn
self.num_register_tokens = num_register_tokens
- self.interpolate_antialias = interpolate_antialias
- self.interpolate_offset = interpolate_offset
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
@@ -196,43 +188,61 @@ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.patch_size = config.patch_size
self.config = config

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
- resolution images.
+ resolution images. This implementation supports torch.jit tracing while maintaining backwards compatibility
+ with the original implementation.
- Source:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+ - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
- if num_patches == num_positions and height == width:

+ # Skip interpolation for matching dimensions (unless tracing)
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings

+ # Handle class token and patch embeddings separately
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]

+ # Calculate new dimensions
height = height // self.config.patch_size
width = width // self.config.patch_size
- # we add a small number to avoid floating point error in the interpolation
- # see discussion at https://github.com/facebookresearch/dino/issues/8
- height, width = height + self.config.interpolate_offset, width + self.config.interpolate_offset
- patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)

+ # Reshape for interpolation
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

+ # Store original dtype for restoration after interpolation
target_dtype = patch_pos_embed.dtype

+ # Interpolate at float32 precision
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.to(dtype=torch.float32),
- scale_factor=(float(height / math.sqrt(num_positions)), float(width / math.sqrt(num_positions))),
+ size=(torch_int(height), torch_int(width)),  # Explicit size instead of scale_factor
mode="bicubic",
align_corners=False,
- antialias=self.config.interpolate_antialias,
- )
- patch_pos_embed = patch_pos_embed.to(dtype=target_dtype)
- if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
- raise ValueError("Width or height does not match with the interpolated position embeddings")
+ antialias=True,
+ ).to(dtype=target_dtype)

+ # Validate output dimensions if not tracing
+ if not torch.jit.is_tracing():
+ if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
+ raise ValueError("Width or height does not match with the interpolated position embeddings")

+ # Reshape back to original format
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

+ # Combine class and patch embeddings
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
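A quick smoke test of the tracing behavior the rewrite is aimed at. This is a sketch under assumptions, not part of the commit: the tiny config values and the two test resolutions are arbitrary, and `torchscript=True` is used only so the model returns tuples that `torch.jit.trace` accepts.

```python
import torch
from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel

config = Dinov2WithRegistersConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    image_size=64,
    patch_size=16,
    torchscript=True,  # return tuples instead of ModelOutput so tracing works
)
model = Dinov2WithRegistersModel(config).eval()

# Trace at one resolution...
traced = torch.jit.trace(model, torch.randn(1, 3, 64, 64))

# ...and, because interpolate_pos_encoding now goes through torch_int and an
# explicit size=, the traced graph should in principle accept another
# resolution without re-tracing.
out = traced(torch.randn(1, 3, 96, 96))
print(out[0].shape)
```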
