
[fx_importer] Add support for 0D tensors #401

Merged
merged 5 commits into main from rank0_fix
Feb 6, 2024

Conversation

dan-garvey
Member

@dan-garvey dan-garvey commented Feb 6, 2024

Adds an escape hatch so that single-value tensors are emitted as a DenseElementsAttr instead of a DenseResourceElementsAttr.

Addresses #398
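The branch this PR adds can be sketched as follows (numpy-only, with a hypothetical helper name; the real code calls the MLIR Python bindings' DenseElementsAttr / DenseResourceElementsAttr):

```python
import numpy as np

def is_single_element(np_tensor: np.ndarray) -> bool:
    # The escape hatch triggers on single-value tensors: rank-0 scalars
    # and shape-[1] tensors both have size == 1 and are emitted as a
    # splat DenseElementsAttr instead of a DenseResourceElementsAttr.
    return np_tensor.size == 1

assert is_single_element(np.array(3.0))          # rank 0
assert is_single_element(np.array([3.0]))        # shape [1]
assert not is_single_element(np.zeros((2, 3)))   # takes the resource path
```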

Contributor

@stellaraccident stellaraccident left a comment


Thanks. A few comments to address before landing.

blob_name,
tensor_type,
)
# DenseResourceElementsAttr creation doesnt support rank 0 tensors, so we use DenseElementsAttr instead.
Contributor


The comment should start with: one element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling 0d tensors.

# DenseResourceElementsAttr creation doesnt support rank 0 tensors, so we use DenseElementsAttr instead.
if np_tensor.size == 1:
elements_attr = DenseElementsAttr.get(
type=TORCH_DTYPE_TO_MLIR_TYPE[tensor.dtype](), array=np_tensor
Contributor


Can you also pass the shape= param? Keeps it from making dodgy local decisions and is marginally more robust.

# DenseResourceElementsAttr creation doesnt support rank 0 tensors, so we use DenseElementsAttr instead.
if np_tensor.size == 1:
elements_attr = DenseElementsAttr.get(
type=TORCH_DTYPE_TO_MLIR_TYPE[tensor.dtype](), array=np_tensor
Contributor


Wrap in try...except KeyError and raise a descriptive error message.
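The suggested guard might look like this sketch (the dict here is a stand-in for the importer's TORCH_DTYPE_TO_MLIR_TYPE table, and the error text is illustrative):

```python
# Stand-in for the importer's TORCH_DTYPE_TO_MLIR_TYPE mapping.
DTYPE_TO_MLIR_TYPE = {"float32": "f32", "int64": "i64"}

def mlir_type_for(dtype: str) -> str:
    # Wrap the lookup so an unsupported dtype fails with a descriptive
    # error instead of a bare KeyError.
    try:
        return DTYPE_TO_MLIR_TYPE[dtype]
    except KeyError:
        raise TypeError(
            f"Unsupported dtype {dtype!r} for vtensor literal; "
            f"supported dtypes: {sorted(DTYPE_TO_MLIR_TYPE)}"
        ) from None
```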

@stellaraccident
Contributor

I'll need to sync this to torch-mlir once this lands. We're halfway through using that one as the source of truth, so we need to port patches manually until that resolves. Can you leave the issue open and assign it to me to port so I don't forget?

elements_attr = DenseElementsAttr.get(
type=TORCH_DTYPE_TO_MLIR_TYPE[tensor.dtype](), array=np_tensor
type=element_type, array=np_tensor, shape=[]
Contributor


Shape param is wrong for 1d, 1 element. Get it from np_tensor.shape

Member Author


oof, sounds like a missing test case

Contributor


Yeah, add tests for shape [1].
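A minimal check of the shape logic under discussion (numpy-only sketch; the real tests exercise the importer end to end): deriving the attribute shape from np_tensor.shape covers both single-element cases, whereas hardcoding shape=[] would mislabel a shape-[1] tensor as rank 0.

```python
import numpy as np

def attr_shape(np_tensor: np.ndarray) -> list:
    # The fix: take the shape from the tensor instead of hardcoding [].
    return list(np_tensor.shape)

assert attr_shape(np.array(7)) == []       # rank 0: splat with empty shape
assert attr_shape(np.array([7])) == [1]    # shape [1]: must NOT be []
```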

@@ -928,10 +928,15 @@ def _make_vtensor_literal_op(
# buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as
# desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above)
np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
# DenseResourceElementsAttr creation doesnt support rank 0 tensors, so we use DenseElementsAttr instead.
# one element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling 0d tensors.
Contributor


Capitalize first letter.

Contributor

@stellaraccident stellaraccident left a comment


Couple of comments

@dan-garvey dan-garvey merged commit 66f79ab into main Feb 6, 2024
4 checks passed
@dan-garvey dan-garvey deleted the rank0_fix branch February 6, 2024 05:44