[fx_importer] Add support for 0D tensors #401

Merged · 5 commits · Feb 6, 2024

Changes from 2 commits
40 changes: 27 additions & 13 deletions core/shark_turbine/importers/fx_importer.py
@@ -41,6 +41,7 @@
     Attribute,
     Block,
     Context,
+    DenseElementsAttr,
     DenseResourceElementsAttr,
     FloatAttr,
     BF16Type,
@@ -573,9 +574,11 @@ def _import_symbolic_torch_op(
         # operations on symbolic arguments as regular python expressions rather than as torch ops
         if is_builtin_function_or_method(target):
             arg_types = [
-                arg.meta["val"].node.pytype
-                if isinstance(arg, torch.fx.Node)
-                else type(arg)
+                (
+                    arg.meta["val"].node.pytype
+                    if isinstance(arg, torch.fx.Node)
+                    else type(arg)
+                )
                 for arg in node.args
             ]
             is_int = [item == int for item in arg_types]
@@ -905,7 +908,7 @@ def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType:
         tensor_type = RankedTensorType.get(tuple(tensor.size()), element_type)
         return tensor_type
     except KeyError:
-        raise TypeError(f"Could not map Torch dtype {dtype} to an IREE type")
+        raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type")


 def _make_vtensor_literal_op(
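
The TORCH_DTYPE_TO_MLIR_TYPE table this error path guards lives outside the diff hunk. As a rough illustration only (assumed, abbreviated entries, not the repo's actual table; import path assumed to be the upstream mlir bindings), it maps torch dtypes to MLIR element-type factories:

# Illustrative sketch of a torch-dtype -> MLIR-element-type table; the repo's
# real table covers more dtypes. Each factory must be called while an MLIR
# Context is active, which is why entries are thunks rather than types.
import torch
from mlir.ir import F32Type, F64Type, IntegerType

TORCH_DTYPE_TO_MLIR_TYPE = {
    torch.float32: lambda: F32Type.get(),
    torch.float64: lambda: F64Type.get(),
    torch.int32: lambda: IntegerType.get_signless(32),
    torch.int64: lambda: IntegerType.get_signless(64),
    torch.bool: lambda: IntegerType.get_signless(1),
}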
@@ -925,15 +928,26 @@ def _make_vtensor_literal_op(
         # buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as
         # desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above)
         np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
-        bytes_view = memoryview(np_tensor)
-        tensor_type = create_mlir_tensor_type(tensor)
-        shape_desc = "_".join([str(d) for d in tensor.shape])
-        blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
-        elements_attr = DenseResourceElementsAttr.get_from_buffer(
-            bytes_view,
-            blob_name,
-            tensor_type,
-        )
+        # one element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling 0d tensors.

Review comment (Contributor): Capitalize first letter.

+        if np_tensor.size == 1:
+            try:
+                dtype = tensor.dtype
+                element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
+            except KeyError:
+                raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type")
+            elements_attr = DenseElementsAttr.get(
+                type=element_type, array=np_tensor, shape=[]

Review comment (Contributor): Shape param is wrong for 1d, 1 element. Get it from np_tensor.shape.

Reply (Member, author): oof, sounds like a missing test case

Reply (Contributor): Yeah, add tests for shape [1].

+            )
+        else:
+            bytes_view = memoryview(np_tensor)
+            tensor_type = create_mlir_tensor_type(tensor)
+            shape_desc = "_".join([str(d) for d in tensor.shape])
+            blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
+            elements_attr = DenseResourceElementsAttr.get_from_buffer(
+                bytes_view,
+                blob_name,
+                tensor_type,
+            )
         mapping.value = elements_attr
     else:
         elements_attr = mapping.value
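
The reasoning above hinges on two attribute kinds: DenseElementsAttr can represent a splat (one value implicitly broadcast over any shape), while DenseResourceElementsAttr cannot, and at the time it also mishandled 0-d tensors. Below is a minimal standalone sketch of the splat path, assuming the upstream mlir Python bindings (this repo imports the same classes from its compiler's binding module), and, following the reviewer's suggestion, taking the attribute shape from the numpy array itself so a shape-[1] input keeps its rank:

# Standalone sketch, not this repo's code: build one-element constants as
# splat DenseElementsAttr, deriving the attribute shape from the numpy array
# so a 0-d tensor maps to tensor<f32> and a shape-[1] tensor to tensor<1xf32>.
import numpy as np
from mlir.ir import Context, DenseElementsAttr, F32Type

with Context():
    f32 = F32Type.get()
    for np_tensor in (
        np.array(0.5, dtype=np.float32),    # rank 0, shape ()
        np.array([0.5], dtype=np.float32),  # rank 1, shape (1,)
    ):
        elements_attr = DenseElementsAttr.get(
            np_tensor, type=f32, shape=list(np_tensor.shape)
        )
        print(elements_attr)
    # Expected output, roughly:
    #   dense<5.000000e-01> : tensor<f32>
    #   dense<5.000000e-01> : tensor<1xf32>

Hard-coding shape=[] as in this revision of the diff would instead collapse the shape-[1] case to the 0-d type, which is exactly the bug the reviewer flags.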
22 changes: 22 additions & 0 deletions core/tests/dynamo/importer_basic_test.py
@@ -96,6 +96,28 @@ def foo(x):
         opt_foo = torch.compile(foo, backend=create_backend())
         opt_foo(torch.randn(4, 4, 4, 4))

+    def testScalarLiteralConversion(self):
+        """
+        Test whether scalar tensors are appropriately converted to literals
+        """
+
+        def foo():
+            a = torch.tensor(0, dtype=torch.int32)
+            b = torch.tensor(0, dtype=torch.int64)
+            c = torch.tensor(0, dtype=torch.float32)
+            d = torch.tensor(0, dtype=torch.float64)
+            e = torch.tensor(0, dtype=torch.complex64)
+            f = torch.tensor(0, dtype=torch.complex128)
+            g = torch.tensor(0, dtype=torch.bool)
+            h = torch.tensor(0, dtype=torch.uint8)
+            i = torch.tensor(0, dtype=torch.int8)
+            j = torch.tensor(0, dtype=torch.int16)
+            return a, b, c, d, e, f, g, h, i, j
+
+        opt_foo = torch.compile(foo, backend=create_backend())
+        opt_foo()
+        print(opt_foo())
+
     def testPromoteScalarTensor(self):
         """
         Test whether scalar arguments are properly promoted to 0-rank Tensors for torch ops with no Scalar equivalent
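
Following the review suggestion to add coverage for shape [1], a hypothetical companion test in the style of this file might look like:

    # Hypothetical test sketched from the review discussion; not part of this
    # PR as of this commit. It mirrors the structure of the tests above.
    def testSingleElementTensorLiteralConversion(self):
        """
        Test whether one-element shape-[1] tensors keep their rank when imported
        """

        def foo():
            return torch.tensor([0], dtype=torch.int32)

        opt_foo = torch.compile(foo, backend=create_backend())
        opt_foo()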