Performance gap between manual nvfuser definition and thunder.jit #3629

Open
Priya2698 opened this issue Dec 20, 2024 · 4 comments
Priya2698 commented Dec 20, 2024

I am seeing lower performance for thunder.jit (with the nvfuserex executor) than for the manual nvfuser definition existing in the Python benchmark suite: http://nv/etb. This came up while testing PR #3394.

For size = (2048, 8192), dtype=torch.bfloat16 (on my local system with an Ada card):

--------------------------------------------------------------------------------------------------------------------------- benchmark: 4 tests ---------------------------------------------------------------------------------------------------------------------------
Name (time in us)                                                                                      Min                 Max                Mean            StdDev              Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rmsnorm_bwd_nvf_benchmark[dtype=torch.bfloat16-size=[2048_8192]]                             136.8970 (1.0)      145.0250 (1.0)      140.3881 (1.0)      2.3346 (1.68)     140.0140 (1.0)      3.4200 (1.84)          2;0        7.1231 (1.0)          10           1
test_rmsnorm_bwd_baseline_benchmark[dtype=torch.bfloat16-size=[2048_8192]-executor='thunder']     223.9020 (1.64)     228.9010 (1.58)     226.1649 (1.61)     1.3899 (1.0)      226.0655 (1.61)     1.8540 (1.0)           2;0        4.4216 (0.62)         10           1
test_rmsnorm_bwd_nvf_benchmark[dtype=torch.float32-size=[2048_8192]]                              256.4510 (1.87)     265.5080 (1.83)     260.5773 (1.86)     3.0545 (2.20)     259.8870 (1.86)     4.8270 (2.60)          4;0        3.8376 (0.54)         10           1
test_rmsnorm_bwd_baseline_benchmark[dtype=torch.float32-size=[2048_8192]-executor='thunder']      271.0090 (1.98)     274.9130 (1.90)     273.3845 (1.95)     1.4553 (1.05)     273.9035 (1.96)     2.7580 (1.49)          5;0        3.6579 (0.51)         10           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

The above numbers use an rmsnorm composed of primitives:

import torch

def rmsnorm_prims(inputs: list):
    inp, weights = inputs
    # Mean of squares over the hidden dimension, kept as [N, 1] for broadcasting.
    squared_mean = (inp**2).mean(1, keepdim=True)
    rms_eps = torch.sqrt(squared_mean + 1e-5)
    output = weights * (inp / rms_eps)
    return output
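
For reference, the same row-wise computation can be written as a minimal pure-Python sketch (toy helper, not part of the benchmark suite) to make the math explicit: out = w * x / sqrt(mean(x^2) + eps).

```python
import math

def rmsnorm_ref(inp, weights, eps=1e-5):
    """Row-wise RMSNorm over plain Python lists of floats."""
    out = []
    for row in inp:
        # Root of the mean square plus epsilon, per row.
        rms = math.sqrt(sum(x * x for x in row) / len(row) + eps)
        out.append([w * (x / rms) for w, x in zip(weights, row)])
    return out
```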

I recover some of the performance using torch.nn.functional.rms_norm (note that the manual nvfuser definition was generated through Thunder from the rmsnorm_prims above):

import torch
import torch.nn.functional as F

def rmsnorm_func(inputs: list):
    inp, weights = inputs
    # Normalize over all dimensions after the batch dimension.
    output = F.rms_norm(inp, inp.shape[1:], weights, eps=1e-5)
    return output

--------------------------------------------------------------------------------------------------------------------------- benchmark: 4 tests ---------------------------------------------------------------------------------------------------------------------------
Name (time in us)                                                                                      Min                 Max                Mean            StdDev              Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rmsnorm_bwd_nvf_benchmark[dtype=torch.bfloat16-size=[2048_8192]]                             137.6300 (1.0)      143.1660 (1.0)      140.5177 (1.0)      1.9396 (1.91)     139.8885 (1.0)      3.2640 (2.55)          4;0        7.1165 (1.0)          10           1
test_rmsnorm_bwd_baseline_benchmark[dtype=torch.bfloat16-size=[2048_8192]-executor='thunder']     175.1710 (1.27)     178.3350 (1.25)     176.9573 (1.26)     1.0168 (1.0)      176.9435 (1.26)     1.2810 (1.0)           4;0        5.6511 (0.79)         10           1
test_rmsnorm_bwd_baseline_benchmark[dtype=torch.float32-size=[2048_8192]-executor='thunder']      255.0390 (1.85)     264.3810 (1.85)     258.7758 (1.84)     2.6816 (2.64)     258.5290 (1.85)     2.9120 (2.27)          3;0        3.8643 (0.54)         10           1
test_rmsnorm_bwd_nvf_benchmark[dtype=torch.float32-size=[2048_8192]]                              258.3390 (1.88)     267.1710 (1.87)     261.6898 (1.86)     2.7284 (2.68)     261.1510 (1.87)     2.7240 (2.13)          4;1        3.8213 (0.54)         10           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Priya2698 commented Dec 20, 2024

I have a (mostly) standalone script for nsys profiling here.

I'll run a sweep using F.rms_norm. The existing fusion definition in the Python benchmarks was obtained through Thunder but modified to allow dynamic shapes and dtypes. Some casts and broadcast ops may have been simplified, which may be responsible for the performance gap.
I'll also look at the differences in the operators present.

@Priya2698

CC: @kevinstephano @mruberry

@mruberry

I filed Lightning-AI/lightning-thunder#1582 to also track this in the thunder repository. Looking forward to hearing the results of your analysis, @Priya2698!

@Priya2698

I compared the existing nvfuser definition (nvf_rmsnorm) with the one generated from Thunder when using F.rms_norm (thunder_rmsnorm). I am no longer using the primitives-based implementation for comparison, since Thunder now uses F.rms_norm, which is also faster.

I have been looking at input size: [2048, 8192], dtype=bfloat16 on my local machine with an Ada card.

  1. The launch parameters for both cases are the same: BlockDim.x = 16, BlockDim.y = 16, BlockDim.z = 1, GridDim.x = -1, GridDim.y = 142, GridDim.z = -1, Smem Size = 50176

  2. One snippet I noted in the thunder_rmsnorm definition is:

   T29 = fd.ops.cast(T3, dtype=DataType.BFloat16)
   T37 = fd.ops.broadcast_in_dim(T3, shape=[2048, 8192], broadcast_dims=[0, 1])
   T42 = fd.ops.cast(T37, dtype=DataType.Float)

where T3 = fd.define_tensor(shape=[2048, 1], contiguity=[True, None], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])

Simplifying this to avoid the roundtrip cast, i.e. T42 = fd.ops.broadcast_in_dim(T3, shape=[2048, 8192], broadcast_dims=[0, 1]), reduces the time from 175us to 162us (input size: [2048, 8192], dtype=bfloat16).
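To see why the roundtrip cast is lossy and not just redundant work, here is a small illustration that emulates a float32 -> bfloat16 cast by truncating the low 16 bits of the float32 bit pattern (nvfuser's real cast rounds to nearest even; truncation is used here only as a simple approximation):

```python
import struct

def to_bfloat16_and_back(x: float) -> float:
    """Emulate float32 -> bfloat16 -> float32 by zeroing the low 16 bits
    of the float32 representation (truncation; illustration only)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

x = 1.0 / 3.0
roundtrip = to_bfloat16_and_back(x)
# The roundtrip drops mantissa bits, so casting T3 to BFloat16 and back to
# Float (as in T29/T42) changes the value slightly; broadcasting the fp32
# tensor directly both skips the extra kernels and keeps full precision.
```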

  3. The computation itself also differs between the two definitions: nvf_rmsnorm takes RMS as input, whereas thunder_rmsnorm takes 1/RMS -- this difference comes from the RMSNorm implementations (the primitives-based implementation of RMSNorm stores RMS for the backward pass).
  4. I see similar instructions for memory loads, shared memory accesses, warp reductions, waits, etc. in the CUDA kernels for both definitions.
  5. Here is a link to the benchmarking run: http://nv/etT. The performance looks better than what I previously saw, which might have been due to an infrastructure difference. I will run it again for verification.
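
The RMS vs 1/RMS difference in point 3 does not change the forward result, only whether the backward input is a divisor or a multiplier. A minimal pure-Python sketch (toy values, not from the benchmark) showing the two formulations agree up to floating-point rounding:

```python
import math

x = [1.0, 2.0, 3.0]
w = [0.5, 0.5, 0.5]
rms = math.sqrt(sum(v * v for v in x) / len(x) + 1e-5)

# nvf_rmsnorm-style: consume the stored RMS and divide.
out_div = [wi * (xi / rms) for wi, xi in zip(w, x)]

# thunder_rmsnorm-style: consume the stored reciprocal 1/RMS and multiply.
inv_rms = 1.0 / rms
out_mul = [wi * (xi * inv_rms) for wi, xi in zip(w, x)]
```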

My next steps are to isolate the common computations in the two definitions and identify which instructions show the largest performance difference between the two fusion definitions.
