Feature request: Schedule a small matmul op as a reduction (or pointwise) op #3646
Comments
@Priya2698, I think it's time we start handling these K=1 cases in the
Yes, we can start decomposing them since we have a use case.
PR #2397 removed the pointwise decomposition of matmul in such cases to avoid performance regressions compared to ATen (especially gemv).
Is there any concern doing this when K=1?
The dot product is somewhat tricky: nvFuser was a little faster (around 2% for the test cases I tried) for very large K (order of 10^8) and slower for smaller K (order of 10^5). For smaller K, the kernel time is under 10 us, and the difference between eager and nvFuser can be up to 25%. I did not notice this for K=1, as far as I recall. I would need to re-run these cases, but the overall gains from fusing operators in such cases may be higher, masking small regressions, if any. It might be better to focus on the overall performance of cases like the one above.
In RoPE, there's a small matmul, which is currently sent to ATen. For example, this is the first part of the Mistral forward RoPE module:
The matmul op producing T19 becomes a segmentation boundary because that op alone is handled by ATen, while the sections before and after it are handled by the other schedulers. While this would make sense if the matmul op were compute-heavy, that is unlikely in this particular case as the dimensions are quite small. This could be translated to just a sequence of pointwise ops:
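As a rough illustration of the K=1 equivalence (a sketch only, not nvFuser code), the pure-Python example below checks that a matmul of an [M, 1] operand with a [1, N] operand gives the same result as a broadcast followed by an elementwise multiply. The shapes, values, and the names `matmul`, `matmul_k1_pointwise`, `inv_freq`, and `positions` are assumptions chosen for illustration.

```python
def matmul(a, b):
    """Reference matmul on nested lists: a is [M][K], b is [K][N]."""
    m, k, n = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]


def matmul_k1_pointwise(a, b):
    """K=1 case: broadcast a[M][1] against b[1][N], then multiply elementwise."""
    assert all(len(row) == 1 for row in a) and len(b) == 1
    return [[a_row[0] * b_val for b_val in b[0]] for a_row in a]


# Hypothetical RoPE-like operands: 4 inverse frequencies, 3 positions.
inv_freq = [[1.0], [0.5], [0.25], [0.125]]  # shape [M=4][K=1]
positions = [[0.0, 1.0, 2.0]]               # shape [K=1][N=3]

freqs_ref = matmul(inv_freq, positions)
freqs_pw = matmul_k1_pointwise(inv_freq, positions)

# With K=1 the reduction has a single term, so the two results agree.
assert freqs_ref == freqs_pw
```

In PyTorch terms, for `a` of shape [M, 1] and `b` of shape [1, N], `a @ b` and `a * b` produce the same [M, N] tensor through broadcasting, which is why no reduction is needed when K=1.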
Combined with #3645, the above section of the forward module would likely be fused into a single kernel with no segmentation.