[fx_importer] Add support for 0D tensors
Adds an escape hatch for single-value tensors: instead of being embedded as a DenseResourceElementsAttr, they are imported as a DenseElementsAttr.
dan-garvey committed Feb 6, 2024
1 parent: 5fe6bb2 · commit: 699a963
Showing 2 changed files with 43 additions and 12 deletions.
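For context, the motivation for the escape hatch: DenseResourceElementsAttr.get_from_buffer cannot be built from a rank-0 buffer, while DenseElementsAttr accepts one. A minimal sketch of the fallback path, assuming the upstream MLIR Python bindings are available as mlir.ir (fx_importer.py imports the same names from its own bindings module):

import numpy as np
from mlir.ir import Context, DenseElementsAttr, F32Type

with Context():
    # A 0D ("scalar") torch tensor round-trips through numpy as a rank-0 array.
    np_tensor = np.array(1.0, dtype=np.float32)
    assert np_tensor.shape == () and np_tensor.size == 1
    # The resource path is skipped for size-1 tensors; an inline dense
    # attribute is created instead.
    attr = DenseElementsAttr.get(np_tensor, type=F32Type.get())
    print(attr)  # expected to print: dense<1.000000e+00> : tensor<f32>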
33 changes: 21 additions & 12 deletions core/shark_turbine/importers/fx_importer.py
@@ -41,6 +41,7 @@
     Attribute,
     Block,
     Context,
+    DenseElementsAttr,
     DenseResourceElementsAttr,
     FloatAttr,
     BF16Type,
@@ -573,9 +574,11 @@ def _import_symbolic_torch_op(
         # operations on symbolic arguments as regular python expressions rather than as torch ops
         if is_builtin_function_or_method(target):
             arg_types = [
-                arg.meta["val"].node.pytype
-                if isinstance(arg, torch.fx.Node)
-                else type(arg)
+                (
+                    arg.meta["val"].node.pytype
+                    if isinstance(arg, torch.fx.Node)
+                    else type(arg)
+                )
                 for arg in node.args
             ]
             is_int = [item == int for item in arg_types]
@@ -925,15 +928,21 @@ def _make_vtensor_literal_op(
         # buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as
         # desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above)
         np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
-        bytes_view = memoryview(np_tensor)
-        tensor_type = create_mlir_tensor_type(tensor)
-        shape_desc = "_".join([str(d) for d in tensor.shape])
-        blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
-        elements_attr = DenseResourceElementsAttr.get_from_buffer(
-            bytes_view,
-            blob_name,
-            tensor_type,
-        )
+        # DenseResourceElementsAttr creation doesn't support rank-0 tensors, so we use DenseElementsAttr instead.
+        if np_tensor.size == 1:
+            elements_attr = DenseElementsAttr.get(
+                type=TORCH_DTYPE_TO_MLIR_TYPE[tensor.dtype](), array=np_tensor
+            )
+        else:
+            bytes_view = memoryview(np_tensor)
+            tensor_type = create_mlir_tensor_type(tensor)
+            shape_desc = "_".join([str(d) for d in tensor.shape])
+            blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
+            elements_attr = DenseResourceElementsAttr.get_from_buffer(
+                bytes_view,
+                blob_name,
+                tensor_type,
+            )
         mapping.value = elements_attr
     else:
         elements_attr = mapping.value
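Illustrative effect of the change (sketched IR, not verbatim importer output): a size-1 tensor is now embedded as an inline dense literal, while larger tensors keep the named resource blob whose name is derived from the shape and dtype as above:

%0 = torch.vtensor.literal(dense<0> : tensor<si32>) : !torch.vtensor<[],si32>
%1 = torch.vtensor.literal(dense_resource<torch_tensor_4_4_torch.float32> : tensor<4x4xf32>) : !torch.vtensor<[4,4],f32>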
22 changes: 22 additions & 0 deletions core/tests/dynamo/importer_basic_test.py
@@ -96,6 +96,28 @@ def foo(x):
         opt_foo = torch.compile(foo, backend=create_backend())
         opt_foo(torch.randn(4, 4, 4, 4))
 
+    def testScalarLiteralConversion(self):
+        """
+        Test whether scalar tensors are appropriately converted to literals
+        """
+
+        def foo():
+            a = torch.tensor(0, dtype=torch.int32)
+            b = torch.tensor(0, dtype=torch.int64)
+            c = torch.tensor(0, dtype=torch.float32)
+            d = torch.tensor(0, dtype=torch.float64)
+            e = torch.tensor(0, dtype=torch.complex64)
+            f = torch.tensor(0, dtype=torch.complex128)
+            g = torch.tensor(0, dtype=torch.bool)
+            h = torch.tensor(0, dtype=torch.uint8)
+            i = torch.tensor(0, dtype=torch.int8)
+            j = torch.tensor(0, dtype=torch.int16)
+            return a, b, c, d, e, f, g, h, i, j
+
+        opt_foo = torch.compile(foo, backend=create_backend())
+        opt_foo()
+        print(opt_foo())
+
     def testPromoteScalarTensor(self):
         """
         Test whether scalar arguments are properly promoted to 0-rank Tensors for torch ops with no Scalar equivalent
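The new test can be run in isolation with standard pytest selection (assuming pytest is the repo's test runner):

pytest core/tests/dynamo/importer_basic_test.py -k testScalarLiteralConversion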
