From 7884d1cdecb1ea2fe509a49dd65cfe1b37018d1e Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Wed, 1 Jan 2025 08:56:57 +0100 Subject: [PATCH 01/11] revert loader --- src/fairseq2/models/jepa/loader.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/src/fairseq2/models/jepa/loader.py b/src/fairseq2/models/jepa/loader.py index 27cb9f189..3fc5f95c2 100644 --- a/src/fairseq2/models/jepa/loader.py +++ b/src/fairseq2/models/jepa/loader.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Any +from typing import Any, cast import torch @@ -26,7 +26,10 @@ def convert_jepa_checkpoint( checkpoint: dict[str, Any], config: JepaConfig ) -> dict[str, Any]: - checkpoint = checkpoint["encoder"] + try: + checkpoint = checkpoint["target_encoder"] + except Exception as _: + checkpoint = checkpoint["encoder"] del checkpoint["module.backbone.pos_embed"] @@ -50,6 +53,25 @@ def convert_jepa_checkpoint( new_checkpoint[name[:-8] + "v_proj.bias"] = v_bias continue + + # if name == "module.backbone.pos_embed": + # # TODO: This only works for checkpoint that uses sinusoidal interpolated encoders + # _config = config.encoder_config + # input_3d_dims = cast(tuple[int, int, int], _config.input_dims) + # patch_3d_dims = cast(tuple[int, int, int], _config.patch_dims) + + # d_input_dim, h_input_dim, w_input_dim = input_3d_dims + # d_patch_dim, h_patch_dim, w_patch_dim = patch_3d_dims + + # target_dims = ( + # (d_input_dim // d_patch_dim), + # (h_input_dim // h_patch_dim), + # (w_input_dim // w_patch_dim), + # _config.model_dim, + # ) + # new_checkpoint["encoder_frontend.pos_encoder.freqs"] = torch.reshape(param, target_dims) + + # continue new_checkpoint[name] = param From 27d290ccdb6392ed1686d86eb642e7989b3b37aa Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Wed, 1 Jan 2025 08:58:29 +0100 Subject: [PATCH 02/11] cleanup comments --- src/fairseq2/models/jepa/loader.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/fairseq2/models/jepa/loader.py b/src/fairseq2/models/jepa/loader.py index 3fc5f95c2..d2bb9e855 100644 --- a/src/fairseq2/models/jepa/loader.py +++ b/src/fairseq2/models/jepa/loader.py @@ -53,25 +53,6 @@ def convert_jepa_checkpoint( new_checkpoint[name[:-8] + "v_proj.bias"] = v_bias continue - - # if name == "module.backbone.pos_embed": - # # TODO: This only works for checkpoint that uses sinusoidal interpolated encoders - # _config = config.encoder_config - # input_3d_dims = cast(tuple[int, int, int], _config.input_dims) - # patch_3d_dims = cast(tuple[int, int, int], _config.patch_dims) - - # d_input_dim, h_input_dim, w_input_dim = input_3d_dims - # d_patch_dim, h_patch_dim, w_patch_dim = patch_3d_dims - - # target_dims = ( - # (d_input_dim // d_patch_dim), - # (h_input_dim // h_patch_dim), - # (w_input_dim // w_patch_dim), - # _config.model_dim, - # ) - # new_checkpoint["encoder_frontend.pos_encoder.freqs"] = torch.reshape(param, target_dims) - - # continue new_checkpoint[name] = param From 8beeddd9108c2920a3c1981c401b790fa30b97a2 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 2 Jan 2025 11:22:48 +0100 Subject: [PATCH 03/11] update jepa loader --- src/fairseq2/models/jepa/loader.py | 33 ++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/src/fairseq2/models/jepa/loader.py b/src/fairseq2/models/jepa/loader.py index d2bb9e855..063535efb 100644 --- a/src/fairseq2/models/jepa/loader.py +++ b/src/fairseq2/models/jepa/loader.py @@ -6,6 +6,7 @@ from __future__ import annotations +from pathlib import Path from typing import Any, cast import torch @@ -19,6 +20,7 @@ ) from fairseq2.models.loader import StandardModelLoader from fairseq2.models.utils.checkpoint import convert_model_state_dict +from fairseq2.utils.file import MapLocation, load_tensors load_jepa_config = StandardModelConfigLoader(JEPA_FAMILY, JepaConfig, jepa_archs) @@ -26,10 +28,6 @@ def convert_jepa_checkpoint( checkpoint: dict[str, Any], config: JepaConfig ) -> dict[str, Any]: - try: - checkpoint = checkpoint["target_encoder"] - except Exception as _: - checkpoint = checkpoint["encoder"] del checkpoint["module.backbone.pos_embed"] @@ -76,8 +74,35 @@ def convert_jepa_checkpoint( return {"model": checkpoint} +def load_encoder_tensor( + path: Path, *, map_location: MapLocation = None, restrict: bool = False +) -> dict[str, object]: + """Load encoder tensor""" + + state_dict = load_tensors(path, map_location=map_location, restrict=restrict) + + if "encoder" not in state_dict: + raise ValueError(f"`encoder` not found in state dict (available key: {state_dict.keys()})") + + return state_dict["encoder"] + + +def load_target_encoder_tensor( + path: Path, *, map_location: MapLocation = None, restrict: bool = False +) -> dict[str, object]: + """Load encoder tensor""" + + state_dict = load_tensors(path, map_location=map_location, restrict=restrict) + + if "encoder" not in state_dict: + raise ValueError(f"`encoder` not found in state dict (available key: {state_dict.keys()})") + + return state_dict["encoder"] + + load_jepa_model = StandardModelLoader( config_loader=load_jepa_config, + tensor_loader=load_encoder_tensor, factory=create_jepa_model, checkpoint_converter=convert_jepa_checkpoint, ) From 4d920c9c1ea36b2356cfe76c405e14fbeb20115c Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 2 Jan 2025 11:37:30 +0100 Subject: [PATCH 04/11] rebase --- src/fairseq2/models/jepa/loader.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/fairseq2/models/jepa/loader.py b/src/fairseq2/models/jepa/loader.py index 063535efb..79f81bac5 100644 --- a/src/fairseq2/models/jepa/loader.py +++ b/src/fairseq2/models/jepa/loader.py @@ -87,19 +87,6 @@ def load_encoder_tensor( return state_dict["encoder"] -def load_target_encoder_tensor( - path: Path, *, map_location: MapLocation = None, restrict: bool = False -) -> dict[str, object]: - """Load encoder tensor""" - - state_dict = load_tensors(path, map_location=map_location, restrict=restrict) - - if "encoder" not in state_dict: - raise ValueError(f"`encoder` not found in state dict (available key: {state_dict.keys()})") - - return state_dict["encoder"] - - load_jepa_model = StandardModelLoader( config_loader=load_jepa_config, tensor_loader=load_encoder_tensor, From 0ffab250744f0cda5486cdb14a5d4b241538b3d6 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 2 Jan 2025 17:25:29 +0100 Subject: [PATCH 05/11] lint --- src/fairseq2/models/jepa/loader.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/fairseq2/models/jepa/loader.py b/src/fairseq2/models/jepa/loader.py index 79f81bac5..4e15460a7 100644 --- a/src/fairseq2/models/jepa/loader.py +++ b/src/fairseq2/models/jepa/loader.py @@ -7,7 +7,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any, cast +from typing import Any import torch @@ -28,7 +28,6 @@ def convert_jepa_checkpoint( checkpoint: dict[str, Any], config: JepaConfig ) -> dict[str, Any]: - del checkpoint["module.backbone.pos_embed"] new_checkpoint = {} @@ -82,7 +81,9 @@ def load_encoder_tensor( state_dict = load_tensors(path, map_location=map_location, restrict=restrict) if "encoder" not in state_dict: - raise ValueError(f"`encoder` not found in state dict (available key: {state_dict.keys()})") + raise ValueError( + f"`encoder` not found in state dict (available key: {state_dict.keys()})" + ) return state_dict["encoder"] From 24045f6e3801156d2ee5ef95ef7abdedf6de73ae Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 2 Jan 2025 17:32:40 +0100 Subject: [PATCH 06/11] mypy --- src/fairseq2/models/jepa/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fairseq2/models/jepa/loader.py b/src/fairseq2/models/jepa/loader.py index 4e15460a7..67e496ae8 100644 --- a/src/fairseq2/models/jepa/loader.py +++ b/src/fairseq2/models/jepa/loader.py @@ -7,7 +7,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any +from typing import Any, cast import torch @@ -85,7 +85,7 @@ def load_encoder_tensor( f"`encoder` not found in state dict (available key: {state_dict.keys()})" ) - return state_dict["encoder"] + return cast(dict[str, object], state_dict["encoder"])ß load_jepa_model = StandardModelLoader( From bfd7951ab479bfafa7e6a717ac2870a7df5780c5 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 2 Jan 2025 17:48:41 +0100 Subject: [PATCH 07/11] typo --- src/fairseq2/models/jepa/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fairseq2/models/jepa/loader.py b/src/fairseq2/models/jepa/loader.py index 67e496ae8..e140659dd 100644 --- a/src/fairseq2/models/jepa/loader.py +++ b/src/fairseq2/models/jepa/loader.py @@ -85,7 +85,7 @@ def load_encoder_tensor( f"`encoder` not found in state dict (available key: {state_dict.keys()})" ) - return cast(dict[str, object], state_dict["encoder"])ß + return cast(dict[str, object], state_dict["encoder"]) load_jepa_model = StandardModelLoader( From 09b843a419bc41b4c1601fccd9cb7784b743d2c8 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 2 Jan 2025 19:26:52 +0100 Subject: [PATCH 08/11] Can's comments --- src/fairseq2/models/jepa/loader.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/fairseq2/models/jepa/loader.py b/src/fairseq2/models/jepa/loader.py index e140659dd..64ce178a7 100644 --- a/src/fairseq2/models/jepa/loader.py +++ b/src/fairseq2/models/jepa/loader.py @@ -28,6 +28,20 @@ def convert_jepa_checkpoint( checkpoint: dict[str, Any], config: JepaConfig ) -> dict[str, Any]: + # We have a shared checkpoint, used for other use cases (frozen evaluation,..) + if "target_encoder" in checkpoint: + return convert_jepa_encoder_checkpoint(checkpoint["target_encoder"], config=config) + + if "encoder" in checkpoint: + return convert_jepa_encoder_checkpoint(checkpoint["encoder"], config=config) + + raise ValueError(f"encoder not found (available keys: {checkpoint.keys()})") + + +def convert_jepa_encoder_checkpoint( + checkpoint: dict[str, Any], config: JepaConfig +) -> dict[str, Any]: + del checkpoint["module.backbone.pos_embed"] new_checkpoint = {} @@ -73,24 +87,8 @@ def convert_jepa_checkpoint( return {"model": checkpoint} -def load_encoder_tensor( - path: Path, *, map_location: MapLocation = None, restrict: bool = False -) -> dict[str, object]: - """Load encoder tensor""" - - state_dict = load_tensors(path, map_location=map_location, restrict=restrict) - - if "encoder" not in state_dict: - raise ValueError( - f"`encoder` not found in state dict (available key: {state_dict.keys()})" - ) - - return cast(dict[str, object], state_dict["encoder"]) - - load_jepa_model = StandardModelLoader( config_loader=load_jepa_config, - tensor_loader=load_encoder_tensor, factory=create_jepa_model, checkpoint_converter=convert_jepa_checkpoint, ) From 615120db735a651020485bc4b4373bc6229c0740 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Fri, 3 Jan 2025 10:00:09 +0100 Subject: [PATCH 09/11] lint --- src/fairseq2/models/jepa/loader.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/fairseq2/models/jepa/loader.py b/src/fairseq2/models/jepa/loader.py index 64ce178a7..046806d71 100644 --- a/src/fairseq2/models/jepa/loader.py +++ b/src/fairseq2/models/jepa/loader.py @@ -6,8 +6,7 @@ from __future__ import annotations -from pathlib import Path -from typing import Any, cast +from typing import Any import torch @@ -20,7 +19,6 @@ ) from fairseq2.models.loader import StandardModelLoader from fairseq2.models.utils.checkpoint import convert_model_state_dict -from fairseq2.utils.file import MapLocation, load_tensors load_jepa_config = StandardModelConfigLoader(JEPA_FAMILY, JepaConfig, jepa_archs) From a94f8cbf5b64dc04bdf6370e132c4ec65ee0eb37 Mon Sep 17 00:00:00 2001 From: Anh Tuan Tran <1254753+antoine-tran@users.noreply.github.com> Date: Fri, 3 Jan 2025 18:58:49 +0000 Subject: [PATCH 10/11] black --- src/fairseq2/models/jepa/loader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/fairseq2/models/jepa/loader.py b/src/fairseq2/models/jepa/loader.py index 046806d71..a4e546509 100644 --- a/src/fairseq2/models/jepa/loader.py +++ b/src/fairseq2/models/jepa/loader.py @@ -28,7 +28,9 @@ def convert_jepa_checkpoint( ) -> dict[str, Any]: # We have a shared checkpoint, used for other use cases (frozen evaluation,..) if "target_encoder" in checkpoint: - return convert_jepa_encoder_checkpoint(checkpoint["target_encoder"], config=config) + return convert_jepa_encoder_checkpoint( + checkpoint["target_encoder"], config=config + ) if "encoder" in checkpoint: return convert_jepa_encoder_checkpoint(checkpoint["encoder"], config=config) From 6f79d1b885d17ed3633e65b85e0b62947094f2e3 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Fri, 3 Jan 2025 20:22:37 +0100 Subject: [PATCH 11/11] black --- src/fairseq2/models/jepa/loader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/fairseq2/models/jepa/loader.py b/src/fairseq2/models/jepa/loader.py index a4e546509..a38c613e7 100644 --- a/src/fairseq2/models/jepa/loader.py +++ b/src/fairseq2/models/jepa/loader.py @@ -41,7 +41,6 @@ def convert_jepa_checkpoint( def convert_jepa_encoder_checkpoint( checkpoint: dict[str, Any], config: JepaConfig ) -> dict[str, Any]: - del checkpoint["module.backbone.pos_embed"] new_checkpoint = {}