Skip to content

Commit

Permalink
decouple #938
Browse files Browse the repository at this point in the history
  • Loading branch information
Tuan Tran authored and Tuan Tran committed Dec 26, 2024
1 parent f96344b commit 75472ab
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 13 deletions.
9 changes: 1 addition & 8 deletions src/fairseq2/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def __call__(
dtype: DataType | None = None,
force: bool = False,
progress: bool = True,
strict_state_dict: bool = True,
) -> ModelT_co:
"""
:param model_name_or_card:
Expand All @@ -99,9 +98,6 @@ def __call__(
cache.
:param progress:
If ``True``, displays a progress bar to stderr.
:param strict_state_dict:
If ``True``, checkpoint' parameters and layers must be identical to
the model state dict)
:returns:
A model loaded from the checkpoint of ``model_name_or_card``.
Expand Down Expand Up @@ -205,7 +201,6 @@ def __call__(
dtype: DataType | None = None,
force: bool = False,
progress: bool = True,
strict_state_dict: bool = True,

) -> ModelT:
if isinstance(model_name_or_card, AssetCard):
Expand Down Expand Up @@ -361,7 +356,7 @@ def __call__(
consume_prefix_in_state_dict_if_present(state_dict, prefix="module.")

try:
load_state_dict(model, state_dict, strict=strict_state_dict)
load_state_dict(model, state_dict)
except (KeyError, ValueError) as ex:
raise AssetError(
f"{card.name} cannot be loaded. See nested exception for details."
Expand Down Expand Up @@ -402,7 +397,6 @@ def __call__(
dtype: DataType | None = None,
force: bool = False,
progress: bool = True,
strict_state_dict: bool = True,
) -> ModelT:
if isinstance(model_name_or_card, AssetCard):
card = model_name_or_card
Expand All @@ -426,7 +420,6 @@ def __call__(
dtype=dtype,
force=force,
progress=progress,
strict_state_dict=strict_state_dict,
)

def register(self, family: str, loader: ModelLoader[ModelT]) -> None:
Expand Down
11 changes: 6 additions & 5 deletions src/fairseq2/nn/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,15 +431,16 @@ def broadcast_module(
_broadcast_coalesced(pg, tensors, bucket_size, source_rank)


def load_state_dict(module: Module, state_dict: Mapping[str, object], strict: bool = True) -> None:
def load_state_dict(module: Module, state_dict: Mapping[str, object]) -> None:
"""Copy parameters and buffers from ``state_dict`` into ``module`` and its
descendant modules.
This implementation internally calls :meth:`Module.load_state_dict()`, and also enforces that
``state_dict`` does not contain any keys corresponding to descendants that are set to ``None``
via :meth:`Module.register_module()`.
This implementation internally calls :meth:`Module.load_state_dict()` with
``strict`` set to ``True``, and also enforces that ``state_dict`` does not
contain any keys corresponding to descendants that are set to ``None`` via
:meth:`Module.register_module()`.
"""
module.load_state_dict(state_dict, strict=strict)
module.load_state_dict(state_dict, strict=True)

unexpected_keys = []

Expand Down

0 comments on commit 75472ab

Please sign in to comment.