Currently `TensorDomain::flatten` just merges all the specified IDs. To follow the convention of the reshape transform, broadcast IDs should be squeezed instead.
#3682 exposed this issue. There is no fundamental reason that broadcast IDs must be squeezed, but there is also no reason to change what the usual reshape transform does, so the two should be consistent.
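A minimal Python sketch of the difference, modeling IDs as hypothetical `IterDomain(name, extent, is_broadcast)` records rather than nvFuser's actual C++ classes. Both functions produce the same merged extent; what differs is whether the broadcast ID becomes an input of the merge or is squeezed out first. The all-broadcast fallback is an assumption of this sketch.

```python
from dataclasses import dataclass
from math import prod
from typing import Sequence, Tuple


@dataclass(frozen=True)
class IterDomain:
    """Hypothetical stand-in for an nvFuser IterDomain (not the real class)."""
    name: str
    extent: int
    is_broadcast: bool = False


@dataclass(frozen=True)
class FlattenResult:
    merged: IterDomain         # the single ID produced by flattening
    squeezed: Tuple[str, ...]  # broadcast IDs dropped instead of merged


def _merge(ids: Sequence[IterDomain]) -> IterDomain:
    # A merged ID is a broadcast only if every one of its inputs is a broadcast.
    return IterDomain(
        name="*".join(d.name for d in ids),
        extent=prod(d.extent for d in ids),
        is_broadcast=all(d.is_broadcast for d in ids),
    )


def flatten_merge_all(ids: Sequence[IterDomain]) -> FlattenResult:
    """Behavior described in the issue: every ID, broadcast or not, is merged."""
    return FlattenResult(_merge(ids), ())


def flatten_squeeze_broadcast(ids: Sequence[IterDomain]) -> FlattenResult:
    """Reshape convention: squeeze broadcast IDs, then merge the remaining IDs."""
    concrete = [d for d in ids if not d.is_broadcast]
    if not concrete:
        # Assumption for this sketch: if every ID is a broadcast, nothing is squeezed.
        return FlattenResult(_merge(ids), ())
    squeezed = tuple(d.name for d in ids if d.is_broadcast)
    return FlattenResult(_merge(concrete), squeezed)


ids = [IterDomain("i0", 4), IterDomain("b1", 1, True), IterDomain("i2", 3)]
print(flatten_merge_all(ids).merged.name)          # i0*b1*i2
print(flatten_squeeze_broadcast(ids).merged.name)  # i0*i2
```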
Adds `repeat` as an alias op, as well as the `RepeatOp` IR node. The
`repeat` op has almost the same semantics as PyTorch's `repeat`.
The main motivation is to fix #3682, which was caused by #3645. That PR
introduced a preseg pass that detects a repeat pattern and translates it
into broadcast, expand and reshape. #3682 occurs because the
translation-based method does not work when a broadcast ID is repeated.
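The broadcast-expand-reshape translation can be illustrated with a small pure-Python sketch over row-major flat lists. This is not nvFuser's implementation, and the helper names are hypothetical. It assumes one repeat factor per dimension (as with PyTorch's `repeat` when the number of factors equals the tensor's rank) and only non-broadcast input dimensions. Each dimension `s_i` gets a new size-1 broadcast dimension in front of it, that dimension is expanded to the repeat factor `r_i`, and each `(r_i, s_i)` pair is then reshaped into one dimension of size `r_i * s_i`.

```python
from itertools import product
from math import prod
from typing import List, Sequence, Tuple


def flat_index(idx: Sequence[int], shape: Sequence[int]) -> int:
    """Row-major offset of multi-index `idx` in a tensor of `shape`."""
    off = 0
    for i, n in zip(idx, shape):
        off = off * n + i
    return off


def all_indices(shape: Sequence[int]):
    """Enumerate the multi-indices of `shape` in row-major order."""
    return product(*(range(n) for n in shape))


def repeat_reference(data, shape, reps) -> Tuple[List, List[int]]:
    """PyTorch-style repeat: output index k reads input index k % s in each dim."""
    out_shape = [r * s for r, s in zip(reps, shape)]
    out = [data[flat_index([k % s for k, s in zip(idx, shape)], shape)]
           for idx in all_indices(out_shape)]
    return out, out_shape


def repeat_via_broadcast_expand_reshape(data, shape, reps) -> Tuple[List, List[int]]:
    # 1) broadcast: insert a size-1 dim before each dim -> [1, s0, 1, s1, ...]
    # 2) expand: grow each inserted dim to its factor  -> [r0, s0, r1, s1, ...]
    #    Each expanded element reads the input at the original (odd-position) indices.
    expanded_shape = [d for r, s in zip(reps, shape) for d in (r, s)]
    expanded = [data[flat_index(idx[1::2], shape)]
                for idx in all_indices(expanded_shape)]
    # 3) reshape: merge each (r_i, s_i) pair into one dim of size r_i * s_i.
    #    In row-major order this leaves the flat data unchanged.
    out_shape = [r * s for r, s in zip(reps, shape)]
    assert len(expanded) == prod(out_shape)
    return expanded, out_shape


data, shape = [1, 2, 3, 4, 5, 6], [2, 3]
print(repeat_via_broadcast_expand_reshape(data, shape, [2, 1]))
# ([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6], [4, 3])
```

Because plain lists carry no broadcast flag, the sketch cannot reproduce the failing case from #3682; it only shows why the pattern is equivalent to `repeat` for ordinary dimensions.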
I originally just used `TensorDomain::flatten`
(https://github.com/NVIDIA/Fuser/blob/main/csrc/ir/nodes.cpp#L3674-L3740),
which simply merges broadcast IDs. However, a reshape should squeeze
broadcast IDs rather than merge them. Merging broadcast IDs triggered an
assertion in the transpose scheduler, as seen in #3682.
`TensorDomain::flatten` needs to be fixed (#3691), but that is a separate
issue. To fix #3682, since repeated broadcast IDs cannot be translated
to the broadcast-expand-reshape pattern anyway, I added the new
`RepeatOp` node. I initially thought it could just be a `LoadStoreOp`,
but decided on a separate IR node because, unlike the usual LoadStore
case, some broadcast IDs of the producer become concrete IDs in the
corresponding consumer logical domain. I did actually try using
`LoadStoreOp`, but some of the preseg passes complained about the
mismatched broadcast pattern.
Repeating non-broadcast IDs is still done with the
broadcast-expand-reshape pattern. Only repeated broadcast IDs are
represented using the `RepeatOp` node.
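The resulting split can be summarized per dimension with another hypothetical Python model (again, not nvFuser's API): an ID whose factor is 1 is left alone, a repeated broadcast ID is handled by `RepeatOp` and becomes a concrete ID in the consumer, and every other repeated ID goes through broadcast-expand-reshape. Making the decision per dimension and requiring exactly one factor per ID are simplifications for illustration.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class IterDomain:
    """Hypothetical stand-in for an nvFuser IterDomain."""
    extent: int
    is_broadcast: bool = False


def plan_repeat(ids: Sequence[IterDomain], reps: Sequence[int]) -> List[Tuple[str, IterDomain]]:
    """Return (lowering, consumer ID) for each producer ID."""
    if len(ids) != len(reps):
        raise ValueError("this sketch assumes one repeat factor per ID")
    plan = []
    for d, r in zip(ids, reps):
        if r == 1:
            plan.append(("unchanged", d))
        elif d.is_broadcast:
            # Producer broadcast ID -> concrete consumer ID. This is the case a
            # plain LoadStoreOp does not fit, so a dedicated RepeatOp is used.
            plan.append(("RepeatOp", IterDomain(d.extent * r)))
        else:
            # Non-broadcast IDs keep using broadcast + expand + reshape.
            plan.append(("broadcast_expand_reshape", IterDomain(d.extent * r)))
    return plan


for step in plan_repeat([IterDomain(4), IterDomain(1, True), IterDomain(3)], [2, 3, 1]):
    print(step)
```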
Fixes #3682