diff --git a/src/fairseq2/models/loader.py b/src/fairseq2/models/loader.py index 0f4dabfab..589532e3f 100644 --- a/src/fairseq2/models/loader.py +++ b/src/fairseq2/models/loader.py @@ -82,7 +82,6 @@ def __call__( dtype: DataType | None = None, force: bool = False, progress: bool = True, - strict_state_dict: bool = True, ) -> ModelT_co: """ :param model_name_or_card: @@ -99,9 +98,6 @@ def __call__( cache. :param progress: If ``True``, displays a progress bar to stderr. - :param strict_state_dict: - If ``True``, checkpoint' parameters and layers must be identical to - the model state dict) :returns: A model loaded from the checkpoint of ``model_name_or_card``. @@ -205,7 +201,6 @@ def __call__( dtype: DataType | None = None, force: bool = False, progress: bool = True, - strict_state_dict: bool = True, ) -> ModelT: if isinstance(model_name_or_card, AssetCard): @@ -361,7 +356,7 @@ def __call__( consume_prefix_in_state_dict_if_present(state_dict, prefix="module.") try: - load_state_dict(model, state_dict, strict=strict_state_dict) + load_state_dict(model, state_dict) except (KeyError, ValueError) as ex: raise AssetError( f"{card.name} cannot be loaded. See nested exception for details." @@ -402,7 +397,6 @@ def __call__( dtype: DataType | None = None, force: bool = False, progress: bool = True, - strict_state_dict: bool = True, ) -> ModelT: if isinstance(model_name_or_card, AssetCard): card = model_name_or_card @@ -426,7 +420,6 @@ def __call__( dtype=dtype, force=force, progress=progress, - strict_state_dict=strict_state_dict, ) def register(self, family: str, loader: ModelLoader[ModelT]) -> None: diff --git a/src/fairseq2/nn/utils/module.py b/src/fairseq2/nn/utils/module.py index 02ab7d1d2..e76200599 100644 --- a/src/fairseq2/nn/utils/module.py +++ b/src/fairseq2/nn/utils/module.py @@ -431,15 +431,16 @@ def broadcast_module( _broadcast_coalesced(pg, tensors, bucket_size, source_rank) -def load_state_dict(module: Module, state_dict: Mapping[str, object], strict: bool = True) -> None: +def load_state_dict(module: Module, state_dict: Mapping[str, object]) -> None: """Copy parameters and buffers from ``state_dict`` into ``module`` and its descendant modules. - This implementation internally calls :meth:`Module.load_state_dict()`, and also enforces that - ``state_dict`` does not contain any keys corresponding to descendants that are set to ``None`` - via :meth:`Module.register_module()`. + This implementation internally calls :meth:`Module.load_state_dict()` with + ``strict`` set to ``True``, and also enforces that ``state_dict`` does not + contain any keys corresponding to descendants that are set to ``None`` via + :meth:`Module.register_module()`. """ - module.load_state_dict(state_dict, strict=strict) + module.load_state_dict(state_dict, strict=True) unexpected_keys = []