TensorDomain::flatten should squeeze broadcast IDs as done by the usual reshape transform #3691

Open
naoyam opened this issue Jan 9, 2025 · 0 comments

naoyam commented Jan 9, 2025

Currently, `TensorDomain::flatten` just merges all specified IDs, including broadcast IDs. To follow the convention of the reshape transform, broadcast IDs should be squeezed instead.

#3682 exposed this issue. I don't think there's any fundamental reason that broadcast IDs must be squeezed, but I also don't see any reason to change what the usual reshape transform does, so we should be consistent.
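To make the difference concrete, here is a minimal sketch in plain Python. It does not use the nvFuser API: the `IterDomain` class and both functions are hypothetical stand-ins that only track which IDs would be merged and which would be squeezed.

```python
from dataclasses import dataclass
from math import prod


@dataclass(frozen=True)
class IterDomain:
    """Hypothetical toy stand-in for an iteration domain (not the nvFuser class)."""
    name: str
    extent: int
    is_broadcast: bool = False


def flatten_merge_all(ids):
    """Merge every given ID, broadcast or not (the behavior described in this issue)."""
    return {
        "merged": [d.name for d in ids],
        "squeezed": [],
        "extent": prod(d.extent for d in ids),
    }


def flatten_squeeze_broadcast(ids):
    """Squeeze broadcast IDs first, then merge only the remaining IDs."""
    kept = [d for d in ids if not d.is_broadcast]
    dropped = [d for d in ids if d.is_broadcast]
    return {
        "merged": [d.name for d in kept],
        "squeezed": [d.name for d in dropped],
        "extent": prod(d.extent for d in kept),
    }


ids = [
    IterDomain("i0", 4),
    IterDomain("b1", 1, is_broadcast=True),
    IterDomain("i2", 8),
]
merge_result = flatten_merge_all(ids)
squeeze_result = flatten_squeeze_broadcast(ids)
```

In this toy model both variants produce the same flattened extent. They differ only in whether the broadcast ID `b1` takes part in the merge or is squeezed away beforehand, which is the consistency question raised here.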

naoyam added a commit that referenced this issue Jan 10, 2025
Adds `repeat` as an alias op as well as the `RepeatOp` IR node. The
`repeat` op has almost the same semantics as the PyTorch repeat.

The main motivation is to fix #3682, which is due to #3645, which
introduced a preseg pass that detects and translates a repeat pattern to
broadcast, expand and reshape. The issue in #3682 arises because the
translation-based method does not work when a broadcast ID is repeated.
I originally just used `TensorDomain::flatten`
(https://github.com/NVIDIA/Fuser/blob/main/csrc/ir/nodes.cpp#L3674-L3740),
which just merges broadcast IDs. However, for reshape, it should not
merge but squeeze them. Merging broadcast IDs triggered an assertion of
the transpose scheduler as seen in #3682.

`TensorDomain::flatten` needs to be fixed (#3691), but that's a separate
issue. For fixing #3682, since repeating broadcast IDs cannot be
translated to the broadcast-expand-reshape pattern anyway, I added the
new `RepeatOp` node. I initially thought it could be just `LoadStoreOp`
but decided to have a different IR node since, unlike the usual
LoadStore case, some of the broadcast IDs of a producer become concrete
IDs in the corresponding consumer logical domain. I did actually try
using `LoadStoreOp` but some of the preseg passes complained about the
mismatched broadcast pattern.

Repeating non-broadcast IDs is still done by the
broadcast-expand-reshape pattern. Only repeating broadcast IDs is
represented using the `RepeatOp` node.

Fixes #3682
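
For readers unfamiliar with the broadcast-expand-reshape pattern mentioned in the commit message, the following is a rough sketch of the idea on a plain 1-D Python list. It is only an analogy under stated assumptions: `repeat_1d` is a hypothetical helper, the placement of the new broadcast dimension is one possible ordering, and nothing here models the actual nvFuser preseg pass or `RepeatOp`.

```python
def repeat_1d(values, times):
    """Repeat a 1-D list via broadcast, expand, and reshape steps (toy analogy)."""
    # broadcast: [N] -> [1, N]
    broadcasted = [list(values)]
    # expand: [1, N] -> [times, N]
    expanded = [broadcasted[0] for _ in range(times)]
    # reshape: [times, N] -> [times * N]
    return [v for row in expanded for v in row]


x = [10, 20, 30]
repeated = repeat_1d(x, 2)
```

For a 1-D input this matches the result of tiling the whole sequence `times` times, which is what the PyTorch repeat does for a single dimension.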