Performance gap between manual nvfuser definition and thunder.jit
#3629
I have a (mostly) standalone script for nsys profiling here. I'll run a sweep using it.
I filed Lightning-AI/lightning-thunder#1582 to also track this in the thunder repository. Looking forward to hearing the results of your analysis, @Priya2698!
I compared the existing nvfuser definition (`nvf_rmsnorm`) with the one generated from Thunder when using `F.rms_norm`. I have been looking at the generated definition [fusion definition elided], which contains a roundtrip cast. Simplifying this to avoid the roundtrip cast: [simplified definition elided].

My next steps will be to isolate the common computations in the two definitions and identify which instructions show the largest performance difference between the two fusion definitions.
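As background on why a roundtrip cast matters here: bfloat16 keeps only the top 16 bits of a float32, so a bf16 → fp32 → bf16 roundtrip adds cast operations to the fusion and truncates precision on the way back down. A minimal pure-Python sketch emulating bf16 truncation (my own illustration, not the fusion definition from the issue):

```python
import struct

def to_bf16(x: float) -> float:
    """Emulate bfloat16 by zeroing the low 16 mantissa bits of a float32."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]
```

Values already representable in bf16 survive the roundtrip unchanged, while others lose their low mantissa bits, which is why an unnecessary up-and-down cast is pure overhead when the intermediate math could stay in one precision.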
I am seeing lower performance for `thunder.jit` (with the nvfuserex executor) than for the manual nvfuser definition present in the Python benchmark suite (http://nv/etb). This came up while testing PR #3394.

For `size = (2048, 8192), dtype=torch.bfloat16` (on my local system with an Ada card): [benchmark numbers elided]

The above numbers use an rmsnorm composed of primitives: [definition elided]

I recover some of the performance using `torch.nn.functional.rms_norm`. (Note that the manual nvfuser definition was generated through Thunder using the above `rmsnorm_prims`.) [fusion definition elided]
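For reference, the mathematical shape of the two variants being compared can be sketched in plain Python. The name `rmsnorm_prims` comes from the issue; the pure-Python arithmetic below is my own illustration of RMSNorm built from primitive ops versus the single-expression form that `torch.nn.functional.rms_norm` computes, not the actual benchmark code:

```python
import math

def rmsnorm_prims(xs, weight, eps=1e-6):
    # compose from primitive ops: square -> mean -> add eps -> rsqrt -> mul
    squared = [x * x for x in xs]
    mean_sq = sum(squared) / len(squared)
    rsqrt = 1.0 / math.sqrt(mean_sq + eps)
    scaled = [x * rsqrt for x in xs]
    return [s * w for s, w in zip(scaled, weight)]

def rmsnorm_fused(xs, weight, eps=1e-6):
    # equivalent single-expression form: x / sqrt(mean(x^2) + eps) * weight
    inv = 1.0 / math.sqrt(sum(x * x for x in xs) / len(xs) + eps)
    return [x * inv * w for x, w in zip(xs, weight)]
```

Both compute the same result; the performance question in this issue is about how the compiler fuses the primitive decomposition (extra casts and intermediate tensors) versus the composite op.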