Make Jepa loader more flexible #945

Merged
merged 12 commits (Jan 3, 2025)
Changes from 5 commits
19 changes: 17 additions & 2 deletions src/fairseq2/models/jepa/loader.py
@@ -1,4 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

[Check failure on line 1 in src/fairseq2/models/jepa/loader.py (GitHub Actions / Lint Python / Lint): would reformat]
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -6,7 +6,8 @@

from __future__ import annotations

from typing import Any
from pathlib import Path
from typing import Any, cast

[Check failure on line 10 in src/fairseq2/models/jepa/loader.py (GitHub Actions / Lint Python / Lint): 'typing.cast' imported but unused]

import torch

@@ -19,14 +20,14 @@
)
from fairseq2.models.loader import StandardModelLoader
from fairseq2.models.utils.checkpoint import convert_model_state_dict
from fairseq2.utils.file import MapLocation, load_tensors

load_jepa_config = StandardModelConfigLoader(JEPA_FAMILY, JepaConfig, jepa_archs)


def convert_jepa_checkpoint(
checkpoint: dict[str, Any], config: JepaConfig
) -> dict[str, Any]:
checkpoint = checkpoint["encoder"]

del checkpoint["module.backbone.pos_embed"]

@@ -73,8 +74,22 @@
return {"model": checkpoint}


def load_encoder_tensor(
path: Path, *, map_location: MapLocation = None, restrict: bool = False
) -> dict[str, object]:
"""Load encoder tensor"""

state_dict = load_tensors(path, map_location=map_location, restrict=restrict)

if "encoder" not in state_dict:
raise ValueError(f"`encoder` not found in state dict (available key: {state_dict.keys()})")

return state_dict["encoder"]

[Check failure on line 87 in src/fairseq2/models/jepa/loader.py (GitHub Actions / Lint Python / Lint): Incompatible return value type (got "object", expected "dict[str, object]")]
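For reference, a minimal sketch (not part of this diff, and relying on the module's existing imports) of how this type error and the unused cast import flagged above could be resolved together:

def load_encoder_tensor(
    path: Path, *, map_location: MapLocation = None, restrict: bool = False
) -> dict[str, object]:
    """Load the encoder sub-dictionary of a JEPA checkpoint."""

    state_dict = load_tensors(path, map_location=map_location, restrict=restrict)

    if "encoder" not in state_dict:
        raise ValueError(
            f"`encoder` not found in state dict (available keys: {state_dict.keys()})"
        )

    # The values of `state_dict` are typed as `object`, so narrow the looked-up
    # sub-dictionary explicitly to satisfy mypy.
    return cast(dict[str, object], state_dict["encoder"])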
Contributor
Although I understood the PR description, I am not sure I understand the change here. Any reason for not handling this check in convert_jepa_checkpoint? I mean, instead of

checkpoint = checkpoint["encoder"]

having:

checkpoint = checkpoint.get("encoder")
if checkpoint is None:
    raise ValueError(...)

What is the benefit of having this check in a tensor_loader?

Contributor Author
@antoine-tran Jan 2, 2025

I had two thoughts in making this change, though both are opinionated:

  • We should narrow the scope of the convert_jepa_checkpoint function to converting only the parameters related to the JEPA model. How we get to these parameters is handled separately (in the TensorLoader).
  • With this, we do not have to list all possible checkpoint keys ("encoder", "target_encoder") and define their priority in convert_jepa_checkpoint. This allows us to inject pretrained encoders from other "exotic" checkpoints (for example, jepa-llava, where the encoder is stored under vision_tower); see the sketch after this comment.

The drawback of this approach, though, is that we have to write a custom TensorLoader for each checkpoint, so it is a matter of opinion here...
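For illustration, a custom TensorLoader for such an exotic checkpoint could follow the same pattern as load_encoder_tensor in this diff. This is only a sketch: the vision_tower key, the function name, and the use of cast are assumptions, and it relies on the same imports as loader.py.

def load_vision_tower_tensor(
    path: Path, *, map_location: MapLocation = None, restrict: bool = False
) -> dict[str, object]:
    """Load the encoder stored under the (hypothetical) `vision_tower` key."""

    state_dict = load_tensors(path, map_location=map_location, restrict=restrict)

    if "vision_tower" not in state_dict:
        raise ValueError(
            f"`vision_tower` not found in state dict (available keys: {state_dict.keys()})"
        )

    # Narrow the looked-up sub-dictionary for the type checker.
    return cast(dict[str, object], state_dict["vision_tower"])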

Contributor
How about doing something like:

# Handles the different variants of JEPA checkpoints and delegates the actual
# conversion to the standard converter.
def convert_jepa_checkpoint(
    checkpoint: dict[str, Any], config: JepaConfig
) -> dict[str, Any]:
    if "vision_tower" in checkpoint:
        return convert_jepa_encoder_checkpoint(checkpoint["vision_tower"], config)

    if "target_encoder" in checkpoint:
        return convert_jepa_encoder_checkpoint(checkpoint["target_encoder"], config)

    if "encoder" in checkpoint:
        return convert_jepa_encoder_checkpoint(checkpoint["encoder"], config)

    raise ValueError("encoder not found.")


def convert_jepa_encoder_checkpoint(
    checkpoint: dict[str, Any], config: JepaConfig
) -> dict[str, Any]:
    # Contains the current implementation.
    ...

My worry with the TensorLoader approach is that we leak state-dict handling logic into tensor loading. Essentially, we want to "pre-process" the checkpoint before passing it to the converter, so a wrapper function might do the job as well. Let me know what you think.



load_jepa_model = StandardModelLoader(
config_loader=load_jepa_config,
tensor_loader=load_encoder_tensor,
factory=create_jepa_model,
checkpoint_converter=convert_jepa_checkpoint,
)
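Under this design, supporting an exotic checkpoint would mean swapping only the tensor loader; the config loader, factory, and converter stay the same. A hypothetical sketch, where load_vision_tower_tensor is the illustrative loader sketched earlier and the loader name is made up:

load_jepa_llava_vision_model = StandardModelLoader(
    config_loader=load_jepa_config,
    tensor_loader=load_vision_tower_tensor,
    factory=create_jepa_model,
    checkpoint_converter=convert_jepa_checkpoint,
)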