
Commit cf4e6c7
reduce autotuning range and add second bmm to benchmark
siddharth9820 authored and prajwal1210 committed May 11, 2024
1 parent 2d30576 commit cf4e6c7
Showing 2 changed files with 33 additions and 13 deletions.
30 changes: 20 additions & 10 deletions methods/pca_topk/kernel/benchmark.py
@@ -4,18 +4,13 @@
from pca_topk import gather_outer_bmv_optimized, gather_inner_matrix_only_bmv_optimized
from sparq import gather_outer_bmv, gather_inner_matrix_only_bmv


-B = 1
+B = 4
NH = 32
-S = 500
+S = 800
D = 128
dtype = torch.float16

print("===== BENCHMARKING s.v with various sparsities =======")
print("Batch Size : ", B)
print("Number of Heads : ", NH)
print("Number of Key Tokens (or sequence length) : ", S)
print("Hidden dimension per head : ", D)


configs = [
    triton.testing.Benchmark(
@@ -25,7 +20,7 @@
        # Possible values for `line_arg`
        # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
        line_vals=["torch", "triton-optimized"],  # Label name for the lines
-        line_names=["torch (full keys)", "Triton (Optimized)"],  # Line styles
+        line_names=["torch (full keys and values)", "Triton (Optimized)"],  # Line styles
        styles=[("black", "-"), ("blue", "-")],
        ylabel="TFLOPS",  # Label name for the y-axis
        plot_name="matmul-performance-" + ("fp16 (time in ms)"),  # Name for the plot, used also as a file name for saving the plot.
@@ -72,7 +67,22 @@ def benchmark_bmm2(sparsity, B, NH, S, D, provider):

return ms, max_ms, min_ms



print("===== BENCHMARKING [email protected]() with various sparsities =======")
print("Batch Size : ", B)
print("Number of Heads : ", NH)
print("Number of Key Tokens (or sequence length) : ", S)
print("Hidden dimension per head : ", D)
result = benchmark_bmm1.run(print_data=True)



print("===== BENCHMARKING s@v with various sparsities =======")
print("Batch Size : ", B)
print("Number of Heads : ", NH)
print("Number of Key Tokens (or sequence length) : ", S)
print("Hidden dimension per head : ", D)
result = benchmark_bmm2.run(print_data=True)

print(result)
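For orientation: benchmark_bmm1 times the attention-score product ([email protected]()) and benchmark_bmm2 the weighted-value product (s@v); the sparsity axis appears to control what fraction of the S key/value rows the gather-based kernels read. Below is a minimal dense PyTorch sketch of the two products using the constants above; the decode-style single-query shape and the q/k/v/s names are illustrative assumptions, not code from this repository.

import torch

B, NH, S, D = 4, 32, 800, 128  # batch, heads, key tokens, head dim
dtype = torch.float16

# Hypothetical dense baseline ("torch (full keys and values)" in the plots);
# the Triton kernels instead gather a top-k subset of rows before each product.
q = torch.randn(B * NH, 1, D, dtype=dtype, device="cuda")  # one query per head
k = torch.randn(B * NH, S, D, dtype=dtype, device="cuda")
v = torch.randn(B * NH, S, D, dtype=dtype, device="cuda")

scores = torch.bmm(q, k.transpose(1, 2))             # bmm1: q @ k.t() -> (B*NH, 1, S)
s = torch.softmax(scores.float(), dim=-1).to(dtype)  # attention weights
out = torch.bmm(s, v)                                # bmm2: s @ v -> (B*NH, 1, D)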

16 changes: 13 additions & 3 deletions methods/pca_topk/kernel/pca_topk.py
@@ -6,7 +6,7 @@
import triton.language as tl
from torch import Tensor

-def get_autotune_config():
+def get_autotune_config_outer():
    return [
        triton.Config({"n_chunk": 4}),
        triton.Config({"n_chunk": 8}),
@@ -20,7 +20,7 @@ def get_autotune_config():
    ]

@triton.autotune(
-    configs=get_autotune_config(),
+    configs=get_autotune_config_outer(),
    key=['b', 'n', 'k'],
)
@triton.jit
@@ -110,8 +110,18 @@ def gather_outer_bmv_optimized(A: Tensor, B: Tensor, I: Tensor) -> Tensor:
    return Y


+def get_autotune_config_inner():
+    return [
+        triton.Config({"n_chunk": 4}),
+        triton.Config({"n_chunk": 8}),
+        triton.Config({"n_chunk": 16}),
+        triton.Config({"n_chunk": 32}),
+        triton.Config({"n_chunk": 64}),
+        triton.Config({"n_chunk": 128}),
+    ]

@triton.autotune(
-    configs=get_autotune_config(),
+    configs=get_autotune_config_inner(),
    key=['b', 'n', 'k'],
)
@triton.jit
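The split above gives each kernel its own autotune config list, so one kernel's n_chunk sweep can be reduced (per the commit title) without touching the other's. A minimal sketch of the resulting pattern, assuming the same n_chunk meta-parameter and ['b', 'n', 'k'] key as in the diff; demo_kernel is a placeholder for illustration, not the repository's kernel.

import triton
import triton.language as tl

def get_autotune_config_inner():
    # Reduced sweep: n_chunk in {4, 8, 16, 32, 64, 128} only.
    return [triton.Config({"n_chunk": c}) for c in (4, 8, 16, 32, 64, 128)]

@triton.autotune(configs=get_autotune_config_inner(), key=["b", "n", "k"])
@triton.jit
def demo_kernel(y_ptr, b, n, k, n_chunk: tl.constexpr):
    # Placeholder body; a real kernel would tile its loop over n by n_chunk.
    pid = tl.program_id(0)
    tl.store(y_ptr + pid, pid)

Triton benchmarks every config in the list on the first launch for a new (b, n, k) combination and caches the winner, so a shorter list directly reduces autotuning time for that kernel.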
