Skip to content

Commit

Permalink
Nit updates (#950)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalioglu authored Jan 4, 2025
1 parent 9b03479 commit 612a2a4
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 8 deletions.
3 changes: 0 additions & 3 deletions src/fairseq2/assets/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,3 @@ def _get_final_asset_path(self) -> Path:
)

return asset_path


default_asset_download_manager = InProcAssetDownloadManager()
3 changes: 3 additions & 0 deletions src/fairseq2/nn/transformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@
from fairseq2.nn.transformer.encoder_layer import (
TransformerEncoderLayer as TransformerEncoderLayer,
)
from fairseq2.nn.transformer.ffn import (
DauphinFeedForwardNetwork as DauphinFeedForwardNetwork,
)
from fairseq2.nn.transformer.ffn import FeedForwardNetwork as FeedForwardNetwork
from fairseq2.nn.transformer.ffn import GLUFeedForwardNetwork as GLUFeedForwardNetwork
from fairseq2.nn.transformer.ffn import (
Expand Down
2 changes: 1 addition & 1 deletion src/fairseq2/nn/utils/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class _GradientScaleFunction(Function):
def forward(ctx: Any, x: Tensor, scale: float) -> Tensor: # type: ignore[override]
if not x.dtype.is_floating_point:
raise TypeError(
f"`x` is expected to be a float tensor, but is a `{x.dtype}` tensor instead."
f"`x` must be a float tensor, but is a `{x.dtype}` tensor instead."
)

ctx.scale = scale
Expand Down
2 changes: 1 addition & 1 deletion src/fairseq2/utils/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def merge_dataclass(target: T, source: T) -> T:
"""Merge ``target`` with the data contained in ``source``."""
if type(target) is not type(source):
raise TypeError(
f"`target` and `source` are expected to be of the same type, but they are of types `{type(target)}` and `{type(source)}` instead."
f"`target` and `source` must be of the same type, but they are of types `{type(target)}` and `{type(source)}` instead."
)

return cast(T, _copy_dataclass(target, source))
Expand Down
10 changes: 8 additions & 2 deletions src/fairseq2/utils/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Any,
Literal,
Protocol,
TypeVar,
Union,
cast,
get_args,
Expand Down Expand Up @@ -604,8 +605,13 @@ def _unstructure_set(self, obj: object) -> list[object]:
default_value_converter = ValueConverter()


def structure(obj: object, type_: object, *, set_empty: bool = False) -> Any:
return default_value_converter.structure(obj, type_, set_empty=set_empty)
T = TypeVar("T")


def structure(obj: object, kls: type[T], *, set_empty: bool = False) -> T:
    """Structure ``obj`` as an instance of ``kls`` via the default value converter."""
    structured = default_value_converter.structure(obj, kls, set_empty=set_empty)

    return cast(T, structured)


def unstructure(obj: object) -> object:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/nn/utils/test_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ def test_scale_gradient_raises_error_if_tensor_is_non_float() -> None:

with pytest.raises(
TypeError,
match=r"^`x` is expected to be a float tensor, but is a `torch\.int32` tensor instead\.$",
match=r"^`x` must be a float tensor, but is a `torch\.int32` tensor instead\.$",
):
scale_gradient(a, 1.0)

0 comments on commit 612a2a4

Please sign in to comment.