
Feature request: Schedule a small matmul op as a reduction (or pointwise) op #3646

naoyam opened this issue Dec 25, 2024 · 5 comments

@naoyam

naoyam commented Dec 25, 2024

In RoPE, there's a small matmul that is currently sent to ATen. For example, this is the first part of the Mistral forward RoPE module:

Inputs:
  T0_g___bfloat[bS0{1}, iS1{4096}, iS2{4096}]
  T1_g___bfloat[bS3{1}, iS4{4096}, iS5{1024}]
  T2_g___bfloat[bS6{1}, iS7{4096}, iS8{1024}]
  T3_g___bfloat[iS9{64}]
  T4_g_int64_t[bS10{1}, iS11{4096}]
Outputs:
  T29_g___bfloat[bS99{1}, bS100{1}, iS101{4096}, iS102{128}]
  T31_g___bfloat[bS107{1}, bS108{1}, iS109{4096}, iS110{128}]
  T51_g___bfloat[bS191{1}, iS192{32}, iS193{4096}, iS194{128}]
  T76_g___bfloat[bS299{1}, iS306{32}rf, iS302{4096}, iS303{128}]
  T81_g___bfloat[bS327{1}, iS334{32}rf, iS330{4096}, iS331{128}]

%kernel_math {
T5_l___bfloat[bS12{1}, iS13{4096}, iS16{32}rf, iS17{128}rf] = view( T0_g___bfloat[bS0{1}, iS1{4096}, iS2{4096}] )
T6_l___bfloat[bS18{1}, iS20{32}, iS19{4096}, iS21{128}]
   = Set.Permute( T5_l___bfloat[bS12{1}, iS13{4096}, iS16{32}rf, iS17{128}rf], cache_op=Streaming )
T34_l_float[bS119{1}, iS120{32}, iS121{4096}, iS122{128}]
   = __bfloat2float(T6_l___bfloat[bS18{1}, iS20{32}, iS19{4096}, iS21{128}]);
T11_l___bfloat[bS42{1}, iS43{64}, bS44{1}]
   = broadcast( T3_g___bfloat[iS9{64}] )
T12_l___bfloat[bS45{1}, iS46{64}, bS47{1}]
   = Set( T11_l___bfloat[bS42{1}, iS43{64}, bS44{1}], cache_op=Streaming )
T13_l_float[bS48{1}, iS49{64}, bS50{1}]
   = __bfloat2float(T12_l___bfloat[bS45{1}, iS46{64}, bS47{1}]);
T14_l_float[bS51{1}, iS52{64}, bS53{1}]
   = Set( T13_l_float[bS48{1}, iS49{64}, bS50{1}], cache_op=Streaming )
T15_l_float[bS54{1}, iS55{64}, bS56{1}]
   = Set( T14_l_float[bS51{1}, iS52{64}, bS53{1}], cache_op=Streaming )
T16_l_int64_t[bS57{1}, bS58{1}, iS59{4096}]
   = broadcast( T4_g_int64_t[bS10{1}, iS11{4096}] )
T17_l_int64_t[bS60{1}, bS61{1}, iS62{4096}]
   = Set( T16_l_int64_t[bS57{1}, bS58{1}, iS59{4096}], cache_op=Streaming )
T18_l_float[bS63{1}, bS64{1}, iS65{4096}]
   = (float)(T17_l_int64_t[bS60{1}, bS61{1}, iS62{4096}]);
T19_l_float[bS66{1}, iS67{64}, iS68{4096}]
   = matmul(T15_l_float[bS54{1}, iS55{64}, bS56{1}],
            T18_l_float[bS63{1}, bS64{1}, iS65{4096}])
T20_l_float[bS69{1}, iS71{4096}, iS70{64}]
   = Set.Permute( T19_l_float[bS66{1}, iS67{64}, iS68{4096}], cache_op=Streaming )
T21_l_float[bS72{1}, iS73{4096}, iS75{128}rf]
   = pad( T20_l_float[bS69{1}, iS71{4096}, iS70{64}], {0, 0, 0, 0, 0, 64} )
i85 = 0 + 64;
T22_l_float[bS76{1}, iS77{4096}, iS79{( ( 0 + 64 ) + 64 )}rf]
   = pad( T20_l_float[bS69{1}, iS71{4096}, iS70{64}], {0, 0, 0, 0, i85, 0} )
T23_l_float[bS80{1}, iS81{4096}, iS82{128}]
   = cat( T21_l_float[bS72{1}, iS73{4096}, iS75{128}rf], T22_l_float[bS76{1}, iS77{4096}, iS79{( ( 0 + 64 ) + 64 )}rf], 2 )
T24_l_float[bS83{1}, iS84{4096}, iS85{128}]
   = cosf(T23_l_float[bS80{1}, iS81{4096}, iS82{128}]);
T26_l___bfloat[bS89{1}, iS90{4096}, iS91{128}]
   = __float2bfloat(T24_l_float[bS83{1}, iS84{4096}, iS85{128}]);

The matmul op producing T19 becomes a segmentation boundary because that op alone is handled by ATen, while the preceding and following sections are handled by the other schedulers. This would make sense if the matmul were compute-heavy, but that is unlikely here since the dimensions are quite small.

T19_l_float[bS66{1}, iS67{64}, iS68{4096}]
   = matmul(T15_l_float[bS54{1}, iS55{64}, bS56{1}],
            T18_l_float[bS63{1}, bS64{1}, iS65{4096}])

This could be translated to just a sequence of pointwise ops:

T15_b = broadcast(T15, {false, false, false, true});
T18_b = broadcast(T18, {false, true, false, false});
T19 = squeeze(mul(T15_b, T18_b), -2);
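
The equivalence can be checked numerically. Here is a NumPy sketch (shapes chosen to mirror T15 `[1, 64, 1]` and T18 `[1, 1, 4096]` from the fusion above; the broadcast/squeeze axes follow the pseudocode):

```python
import numpy as np

# Shapes mirroring the fusion: T15 is [1, 64, 1], T18 is [1, 1, 4096].
rng = np.random.default_rng(0)
T15 = rng.random((1, 64, 1), dtype=np.float32)
T18 = rng.random((1, 1, 4096), dtype=np.float32)

# Reference: batched matmul with K = 1 -> [1, 64, 4096].
ref = T15 @ T18

# Pointwise decomposition: broadcast both operands to [1, 64, 1, 4096],
# multiply elementwise, then squeeze out the (size-1) K dimension.
T15_b = T15[:, :, :, np.newaxis]   # [1, 64, 1, 1]
T18_b = T18[:, np.newaxis, :, :]   # [1, 1, 1, 4096]
out = np.squeeze(T15_b * T18_b, axis=-2)

assert out.shape == ref.shape
assert np.allclose(ref, out)
```

With K = 1 each output element is a single product, so the matmul and the broadcast-multiply-squeeze sequence compute exactly the same values.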

Combined with #3645, the above section of the forward module would be likely fused into a single kernel with no segmentation.

@naoyam naoyam added the rope label Dec 25, 2024
@jacobhinkle

@Priya2698, I think it's time we start handling these K=1 cases in the matmul op (and similar for linear) as we discussed a while back. What do you think?

@Priya2698

> @Priya2698, I think it's time we start handling these K=1 cases in the matmul op (and similar for linear) as we discussed a while back. What do you think?

Yes, we can start decomposing them since we have a use case.

@Priya2698
Copy link
Collaborator

This PR: #2397 removed the pointwise decomposition of matmul in such cases to avoid performance regressions compared to ATen (especially gemv).
We can reintroduce it. Should we decompose only the K=1 case for now, or also cases where A/B might be 1D?
CC: @kevinstephano

@naoyam

naoyam commented Jan 16, 2025

Is there any concern doing this when K=1?

@Priya2698

Priya2698 commented Jan 17, 2025

> Is there any concern doing this when K=1?

The dot product is somewhat tricky: nvFuser was slightly faster (around 2% for the test cases I tried) for very large K (on the order of 10^8) and slower for smaller K (on the order of 10^5). For smaller K, the kernel time is under 10 us, and the difference between eager and nvFuser can be up to 25%.

I did not notice this for K=1, as far as I recall. I would need to re-run the cases, but the overall gains from fusing operators in such cases may be high enough to mask any small regressions. It might be better to focus on the overall performance of cases like the one above.
