-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
reduce autotuning range and add second bmm to benchmark
- Loading branch information
1 parent
2d30576
commit cf4e6c7
Showing
2 changed files
with
33 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,18 +4,13 @@ | |
from pca_topk import gather_outer_bmv_optimized, gather_inner_matrix_only_bmv_optimized | ||
from sparq import gather_outer_bmv, gather_inner_matrix_only_bmv | ||
|
||
|
||
B = 1 | ||
B = 4 | ||
NH = 32 | ||
S = 500 | ||
S = 800 | ||
D = 128 | ||
dtype = torch.float16 | ||
|
||
print("===== BENCHMARKING s.v with various sparsities =======") | ||
print("Batch Size : ", B) | ||
print("Number of Heads : ", NH) | ||
print("Number of Key Tokens (or sequence length) : ", S) | ||
print("Hidden dimension per head : ", D) | ||
|
||
|
||
configs = [ | ||
triton.testing.Benchmark( | ||
|
@@ -25,7 +20,7 @@ | |
# Possible values for `line_arg` | ||
# Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment. | ||
line_vals=["torch", "triton-optimized"], # Label name for the lines | ||
line_names=["torch (full keys)", "Triton (Optimized)"], # Line styles | ||
line_names=["torch (full keys and values)", "Triton (Optimized)"], # Line styles | ||
styles=[("black", "-"), ("blue", "-")], | ||
ylabel="TFLOPS", # Label name for the y-axis | ||
plot_name="matmul-performance-" + ("fp16 (time in ms)" ), # Name for the plot, used also as a file name for saving the plot. | ||
|
@@ -72,7 +67,22 @@ def benchmark_bmm2(sparsity, B, NH, S, D, provider): | |
|
||
return ms, max_ms, min_ms | ||
|
||
|
||
|
||
print("===== BENCHMARKING q@k.t() with various sparsities =======") | ||
print("Batch Size : ", B) | ||
print("Number of Heads : ", NH) | ||
print("Number of Key Tokens (or sequence length) : ", S) | ||
print("Hidden dimension per head : ", D) | ||
result = benchmark_bmm1.run(print_data=True) | ||
|
||
|
||
|
||
print("===== BENCHMARKING s@v with various sparsities =======") | ||
print("Batch Size : ", B) | ||
print("Number of Heads : ", NH) | ||
print("Number of Key Tokens (or sequence length) : ", S) | ||
print("Hidden dimension per head : ", D) | ||
result = benchmark_bmm2.run(print_data=True) | ||
|
||
print(result) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters