forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path51_hopper_gett.cu
371 lines (318 loc) · 16.9 KB
/
51_hopper_gett.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Example of a GETT targeting Hopper tensor cores using the CUTLASS 3.x API.
CUTLASS has long provided implementations of Generalized Matrix times Matrix (GEMM) kernels.
However, a plethora of workloads compute on higher ranked tensors. Products of such tensors,
called tensor contractions, can be executed as multiple batched GEMMs, however, they can be
further accelerated with kernels that natively operate on these higher ranked tensors to
perform Generalized Tensor times Tensor contractions (GETT). CuTe's hierarchical layouts
and CUTLASS 3.0's unified micro-kernels make implementation of GETTs trivial. In this example,
we show how CUTLASS 3.0, CuTe, and Hopper's TMA feature together can accelerate GETTs while
making the process of authoring custom GETT kernels easier than ever before.
The modes of a tensor that participate in a GETT can be fundamentally grouped into four
semantic categories. The contraction modes (or K-modes) only appear in the A and B (left and right)
inputs but not in the C output tensor. Row modes (or M-modes) only appear in the left
input tensor (A) and the output tensor (C). Column modes (or N-modes) only appear in the
right (B) input tensor and the output tensor (C). Batch modes (or L-modes) appear in all
input and output tensors. If we fold the many modes of a tensor contraction into these four
categories, it would allow us to represent the input and output tensors as rank-3 "matrices"
that can be computed upon as if we were computing a batched GEMM!
This is exactly what CuTe's hierarchical layout representation allows us to do! Instead of having
simple integers as strides for these four modes, we can have nested strides for each of these
semantic categories that themselves have multiple modes within them -- multi-mode strides!
In CUTLASS 3.0, all one has to do to take advantage of this capability is to substitute the
required multi-mode strides instead of the default ones provided by gemm::detail::TagToStrideX.
In the following example, we illustrate how every Hopper GEMM in CUTLASS 3.0 is a GETT in disguise.
We begin by defining the four modes detailed above as Row, Col (column), Red (reduction), and
Bat (batch) strides, which we then nest for each of the in/out tensors to create our rank-3 stride
tuples. Note that although we do not define the problem shape type explicitly, it too remains a
rank-4 shape tuple just like any other batched GEMM, but instead with multi-mode shapes for each
of the four corresponding multi-modes within it. After this, the same CollectiveMma and
CollectiveBuilder we describe in examples 50 and 49 are used to create our kernel type. Nothing
else changes from a user's point of view. Note that multi-mode strides do not affect our
specializations in any way -- the lexical spelling of our kernels remains the same. The
only difference between a CUTLASS 3 batched GEMM and a GETT is the instantiated CuTe Layouts.
CollectiveBuilders rely on detecting the static-1 in the stride tuples to determine the major mode,
which is what the example demonstrates. However, it is possible to have all modes be dynamic as well
if the user assembles a CollectiveMma manually and ensures that the runtime strides are compatible
with the static micro-kernel of the collective (TiledMma, TiledCopy, and smem layouts). On the other
hand, a user can have more than one static stride too (which need not correspond to the major mode).
In particular, this example demonstrates a GETT where the 0th M-mode (M0) in A and the 0th K-mode (K0)
in B are major. All other combinations of major modes are supported, with the exception of mixed
K-major scenarios where both A and B are K-major (e.g. K0 is major in A but K1 is major in B).
NVIDIA Hopper architecture's TMA feature makes the predication required to implement these complicated
kernels trivial, as it is all handled by TMA itself without requiring any programmer effort.
Example executions, where the stride order defines the major-order (major on the left):
51_hopper_gett --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extents=m:4096,n:4096,k:4096
51_hopper_gett --modeC=l,m,n --modeA=m,l,k --modeB=k,n,l --extents=m:128,n:128,k:128,l:64
51_hopper_gett --modeC=m,a,b,p,q,n,l --modeA=m,l,b,k,a --modeB=k,n,p,q,l --extents=m:32,a:32,b:3,n:128,k:128,l:4,p:3,q:3
*/
#include "gett_kernel.cuh"
#include "thrust/host_vector.h"
#include "thrust/device_vector.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/util/gett_commandline.hpp"
#include "cutlass/util/reference/device/gett.hpp"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/print_error.hpp"
namespace example {
// Returns true if the left-most value in the tuple is statically known to be 1
template<class Stride>
constexpr bool
is_left_major() {
// Account for stride types with and without batch mode and batch modes with static zero stride
return cute::is_constant<1, decltype(cute::size<0,0>(Stride{}))>::value;
}
// Same as cute::make_int_tuple but inserts a major stride (Int<1>) for the leftmost mode if required
template <int Rank, bool IsMajor, class Indexable>
static constexpr
auto
make_stride_tuple(Indexable const& t, int n, int64_t init_default = 0) {
static_assert(Rank > 1);
if constexpr (IsMajor) {
return cute::transform(cute::make_seq<Rank>{}, [&](auto i) {
if constexpr (i == 0) {
return cute::Int<1>{};
}
else {
return i < n ? t[i] : init_default;
}
});
}
else {
return cute::make_int_tuple<Rank>(t, n, init_default);
}
}
} // namespace example
//////////////////////////////////////////////////////////////////////////////
// Entry point of the GETT example.
//
// Parses the GETT problem description (modes, extents and strides) from the command line.
// Folds it into a rank-4 multi-mode problem shape (M, N, K, L) and rank-3 multi-mode strides,
// exactly like a batched GEMM. Then runs the CUTLASS kernel (example::gett_kernel) and compares
// its output against the device reference implementation (cutlass::reference::device::gett).
//
// Returns 0 when verification passes, when usage is printed, or when SM90 MMA support is not
// compiled in. Returns 1 on input-validation, launch, execution or verification failure.
int
main(int argc, char const* argv[]) {
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
  using namespace cute;

  if (argc != 5) {
    std::cout << "Number of command line args must be 4.\n";
    cutlass::GettCommandLine::print_usage();
    return 0;
  }

  //
  // Define the stride types for A, B, C, and D
  //

  // Stride for A (left input). If reduction mode is major, same must be major in B
  // For this example, M0 is major in A.
  using RowModeStridesA = cute::Stride<cute::Int<1>, int64_t, int64_t, int64_t>;
  using RedModeStridesA = cute::Stride<int64_t, int64_t, int64_t>;
  using BatModeStridesA = cute::Stride<int64_t, int64_t, int64_t, int64_t>;

  // Stride for B (right input). If reduction mode is major, same must be major in A
  // For this example, K0 is major in B.
  using ColModeStridesB = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
  using RedModeStridesB = cute::Stride<cute::Int<1>, int64_t, int64_t>;
  using BatModeStridesB = cute::Stride<int64_t, int64_t, int64_t, int64_t>;

  // Strides for output, which can all be dynamic.
  using RowModeStridesC = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
  using ColModeStridesC = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
  using BatModeStridesC = cute::Stride<int64_t, int64_t, int64_t, int64_t>;

  // Assemble our rank-3 multi-mode strides for the in/out tensors
  using StrideA = cute::Stride<RowModeStridesA, RedModeStridesA, BatModeStridesA>;
  using StrideB = cute::Stride<ColModeStridesB, RedModeStridesB, BatModeStridesB>;
  using StrideC = cute::Stride<RowModeStridesC, ColModeStridesC, BatModeStridesC>;

  // Note: C and D share strides here for simplicity.
  // In general, they need not have the same layout.
  using StrideD = StrideC;

  //
  // Define element types for tensors and intermediate values
  //

  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = cutlass::half_t;
  using ElementD = float;
  using ElementAccumulator = float;
  using ElementEpilogue = float;

  // The following constexpr values set the max number of modes in each MNKL mode
  constexpr int MaxRank_M = cute::rank(RowModeStridesA{}); // Max row modes
  constexpr int MaxRank_N = cute::rank(ColModeStridesB{}); // Max column modes
  constexpr int MaxRank_K = cute::rank(RedModeStridesA{}); // Max contraction modes
  constexpr int MaxRank_L = cute::rank(BatModeStridesA{}); // Max batch modes

  // Every pair of tensors sharing a mode category must agree on the number of modes in it,
  // since the same MaxRank_* is used to build both tensors' stride tuples below.
  static_assert(cute::rank(RowModeStridesA{}) == cute::rank(RowModeStridesC{}),
                "A and C must have the same number of row (M) modes.");
  static_assert(cute::rank(ColModeStridesB{}) == cute::rank(ColModeStridesC{}),
                "B and C must have the same number of column (N) modes.");
  static_assert(cute::rank(RedModeStridesA{}) == cute::rank(RedModeStridesB{}),
                "A and B must have the same number of reduction (K) modes.");
  static_assert(cute::rank(BatModeStridesA{}) == cute::rank(BatModeStridesC{}),
                "A and C must have the same number of batch (L) modes.");
  static_assert(cute::rank(BatModeStridesB{}) == cute::rank(BatModeStridesC{}),
                "B and C must have the same number of batch (L) modes.");

  // Parse command line to get modes, extents, and strides
  cutlass::GettCommandLine cmd;
  auto parsed_args = cmd.parse(argc, argv, true);

  auto& m = parsed_args.M;
  auto& ldAm = parsed_args.ldAm;
  auto& ldCm = parsed_args.ldCm;
  int rank_m = int(m.size());

  auto& n = parsed_args.N;
  auto& ldBn = parsed_args.ldBn;
  auto& ldCn = parsed_args.ldCn;
  int rank_n = int(n.size());

  auto& k = parsed_args.K;
  auto& ldAk = parsed_args.ldAk;
  auto& ldBk = parsed_args.ldBk;
  int rank_k = int(k.size());

  auto& l = parsed_args.L;
  auto& ldAl = parsed_args.ldAl;
  auto& ldBl = parsed_args.ldBl;
  auto& ldCl = parsed_args.ldCl;
  int rank_l = int(l.size());

  if ((rank_m > MaxRank_M) || (rank_n > MaxRank_N) || (rank_k > MaxRank_K) || (rank_l > MaxRank_L)) {
    std::cerr << "ERROR: Input has more modes than statically configured.\n";
    return 1;
  }

  // Check that the user input major stride match the static major strides.
  // The empty() guards avoid indexing [0] of an empty stride list when a category has no modes.
  if (example::is_left_major<RowModeStridesA>() && (ldAm.empty() || ldAm[0] != 1)) {
    std::cerr << "ERROR: A_M0 is expected to be major, but was not in the provided input!\n";
    return 1;
  }

  if (example::is_left_major<RedModeStridesA>() && (ldAk.empty() || ldAk[0] != 1)) {
    std::cerr << "ERROR: A_K0 is expected to be major, but was not in the provided input!\n";
    return 1;
  }

  if (example::is_left_major<ColModeStridesB>() && (ldBn.empty() || ldBn[0] != 1)) {
    std::cerr << "ERROR: B_N0 is expected to be major, but was not in the provided input!\n";
    return 1;
  }

  if (example::is_left_major<RedModeStridesB>() && (ldBk.empty() || ldBk[0] != 1)) {
    std::cerr << "ERROR: B_K0 is expected to be major, but was not in the provided input!\n";
    return 1;
  }

  // Convert to `cute::Tuple`s and set up arguments.
  // Unused trailing modes get extent 1 and stride 0 so they do not affect the contraction.
  auto M = make_int_tuple<MaxRank_M>(m.data(), rank_m, 1);
  auto dAm = example::make_stride_tuple<MaxRank_M, example::is_left_major<RowModeStridesA>()>(ldAm.data(), rank_m);
  auto dCm = example::make_stride_tuple<MaxRank_M, example::is_left_major<RowModeStridesC>()>(ldCm.data(), rank_m);

  auto N = make_int_tuple<MaxRank_N>(n.data(), rank_n, 1);
  auto dBn = example::make_stride_tuple<MaxRank_N, example::is_left_major<ColModeStridesB>()>(ldBn.data(), rank_n);
  auto dCn = example::make_stride_tuple<MaxRank_N, example::is_left_major<ColModeStridesC>()>(ldCn.data(), rank_n);

  auto K = make_int_tuple<MaxRank_K>(k.data(), rank_k, 1);
  auto dAk = example::make_stride_tuple<MaxRank_K, example::is_left_major<RedModeStridesA>()>(ldAk.data(), rank_k);
  auto dBk = example::make_stride_tuple<MaxRank_K, example::is_left_major<RedModeStridesB>()>(ldBk.data(), rank_k);

  auto L = make_int_tuple<MaxRank_L>(l.data(), rank_l, 1);
  auto dAl = make_int_tuple<MaxRank_L>(ldAl.data(), rank_l, 0);
  auto dBl = make_int_tuple<MaxRank_L>(ldBl.data(), rank_l, 0);
  auto dCl = make_int_tuple<MaxRank_L>(ldCl.data(), rank_l, 0);

  // Concat tuples to turn it into rank-4 problem shape and rank-3 strides, just like GEMM
  auto problem_shape = make_shape(M, N, K, L);
  StrideA stride_A = make_stride(dAm, dAk, dAl);
  StrideB stride_B = make_stride(dBn, dBk, dBl);
  StrideC stride_C = make_stride(dCm, dCn, dCl);
  StrideD stride_D = stride_C;

  auto alpha = ElementEpilogue(1.0f);
  auto beta  = ElementEpilogue(1.0f);

  //
  // Allocate and init tensors
  //

  // Accumulate in int64_t: std::accumulate computes in the type of its initial value, so an
  // `int` seed (and the products below) would silently overflow for large problem sizes.
  auto M_size = std::accumulate(std::begin(m), std::end(m), int64_t(1), std::multiplies<>{});
  auto N_size = std::accumulate(std::begin(n), std::end(n), int64_t(1), std::multiplies<>{});
  auto K_size = std::accumulate(std::begin(k), std::end(k), int64_t(1), std::multiplies<>{});
  auto L_size = std::accumulate(std::begin(l), std::end(l), int64_t(1), std::multiplies<>{});

  thrust::host_vector<ElementA> h_A(M_size * K_size * L_size);
  thrust::host_vector<ElementB> h_B(N_size * K_size * L_size);
  thrust::host_vector<ElementC> h_C(M_size * N_size * L_size);
  thrust::host_vector<ElementD> h_D(M_size * N_size * L_size);

  // Note: the cast to int here is to avoid false-negative ref-checks which can
  // occur due to floating point arithmetic not being purely associative.
  for (auto& a : h_A) a = ElementA(int(4*(rand() / double(RAND_MAX)) - 1));
  for (auto& b : h_B) b = ElementB(int(4*(rand() / double(RAND_MAX)) - 1));
  for (auto& c : h_C) c = ElementC(int(4*(rand() / double(RAND_MAX)) - 1));
  for (auto& d : h_D) d = ElementD(-1);

  thrust::device_vector<ElementA> d_A = h_A;
  thrust::device_vector<ElementB> d_B = h_B;
  thrust::device_vector<ElementC> d_C = h_C;
  thrust::device_vector<ElementD> cutlass_result = h_D;
  thrust::device_vector<ElementD> reference_result = h_D;

  //
  // Compute GETT
  //

  auto status = example::gett_kernel(
    problem_shape,
    d_A.data().get(), stride_A,
    d_B.data().get(), stride_B,
    ElementAccumulator{},
    d_C.data().get(), stride_C,
    cutlass_result.data().get(), stride_D,
    alpha, beta);

  if (cutlass::Status::kSuccess != status) {
    std::cerr << "ERROR: GETT operator launch failed.\n";
    return 1;
  }

  // Synchronize so that asynchronous execution errors surface here, not later.
  auto cuda_err = cudaDeviceSynchronize();
  if (cudaSuccess != cuda_err) {
    std::cerr << "ERROR: GETT operator execution failed. with error :";
    std::cerr << cudaGetErrorString(cuda_err) << "\n";
    return 1;
  }

  //
  // Verify
  //

  cutlass::reference::device::gett(
    problem_shape,
    d_A.data().get(), stride_A,
    d_B.data().get(), stride_B,
    ElementAccumulator{},
    d_C.data().get(), stride_C,
    reference_result.data().get(), stride_D,
    alpha, beta);

  cuda_err = cudaDeviceSynchronize();
  if (cudaSuccess != cuda_err) {
    std::cerr << "ERROR: GETT reference execution failed. with error :";
    std::cerr << cudaGetErrorString(cuda_err) << "\n";
    return 1;
  }

  // Check if output from CUTLASS kernel and reference kernel are equal or not
  bool passed = cutlass::reference::device::BlockCompareEqual(
    reference_result.data().get(), cutlass_result.data().get(), cutlass_result.size());

  if (passed) {
    std::cout << "GETT verification passed.\n";
    return 0;
  }
  else {
    std::cerr << "ERROR: GETT verification failed! Printing detailed stats.\n";
    h_D = reference_result;
    thrust::host_vector<ElementD> h_cutlass_result = cutlass_result;
    print_relative_error(h_cutlass_result.size(), h_cutlass_result.data(), h_D.data());

    std::cout << "StrideA: "; print(stride_A); std::cout << '\n';
    std::cout << "StrideB: "; print(stride_B); std::cout << '\n';
    std::cout << "StrideC: "; print(stride_C); std::cout << '\n';
    std::cout << "StrideD: "; print(stride_D); std::cout << '\n';
    return 1;
  }

#else

  std::cerr << "Unsupported example. Please ensure CUTLASS_ARCH_MMA_SM90_SUPPORTED is defined.\n";
  return 0;

#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
}