-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdemo.py
72 lines (62 loc) · 2.1 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from typing import Dict, Set
import tvm
import tvm.dlight as dl
from dlight_bench import DlightBench
def factorized(factor: int, minimum: int):
    """Build a sample function that enumerates power-of-two shape values.

    The returned function treats ``sample_idx`` as a mixed-radix number in
    base ``factor``, with one digit per dynamic shape variable. A digit ``d``
    (``0 <= d < factor``) maps to the value ``2 ** (d + minimum)``. Every
    variable therefore ranges over
    ``2**minimum, 2**(minimum + 1), ..., 2**(minimum + factor - 1)``, and
    ``factor ** len(dym_vars)`` consecutive indices cover every combination.

    Parameters
    ----------
    factor : int
        Number of distinct values each variable can take; must be >= 1.
    minimum : int
        Exponent of the smallest sampled value; must be >= 0 so that every
        sampled value is an ``int``.

    Returns
    -------
    sample_func : Callable[[Set[str], int, int], Dict[str, int]]
        Sample function with signature ``(dym_vars, sample_idx, sample_num)``.

    Raises
    ------
    ValueError
        If ``factor < 1`` or ``minimum < 0``.
    """
    if factor < 1:
        raise ValueError(f"factor must be >= 1, got {factor}")
    if minimum < 0:
        raise ValueError(f"minimum must be >= 0, got {minimum}")

    def sample_dym_var_sequential(
        dym_vars: Set[str], sample_idx: int, _: int
    ) -> Dict[str, int]:
        """
        Factorized dynamic shape variable sample function.

        Decode ``sample_idx`` digit by digit (base ``factor``), assigning one
        digit to each dynamic shape variable in sorted-name order.

        Parameters
        ----------
        dym_vars : Set[str]
            Dynamic shape variable set, e.g., {"n", "m"}
        sample_idx : int
            Index of this call among the calls made for the same dynamic
            shape variable set and function.
        _ : int
            Total number of samples; accepted for interface compatibility
            and unused here.

        Returns
        -------
        result : Dict[str, int]
            Dynamic shape variable sample, e.g., {"m": 64, "n": 128}
        """
        results = {}
        stride = 1
        # Sort so each variable's digit position does not depend on set
        # iteration order, which for strings varies with PYTHONHASHSEED.
        for var in sorted(dym_vars):
            results[var] = 2 ** (sample_idx // stride % factor + minimum)
            stride *= factor
        return results

    return sample_dym_var_sequential
# Run DlightBench on an "nvidia/geforce-rtx-3070" target, comparing TIR's
# DefaultGPUSchedule with dlight schedule rules. Every run samples each dynamic
# shape variable over 2**5 .. 2**9 (factorized(5, 5)), 10 samples per function.
with tvm.target.Target("nvidia/geforce-rtx-3070"):
    # (model name, function-selection kwargs, zero-arg factory for the pass).
    # Passes are built lazily so each one is constructed right before its run.
    _bench_runs = (
        (
            "vicuna_v1_7b_fp16",
            {"func_names": ["matmul"]},
            tvm.tir.transform.DefaultGPUSchedule,
        ),
        (
            "vicuna_v1_7b_fp16",
            {"func_names": ["matmul"]},
            lambda: dl.ApplyDefaultSchedule(dl.gpu.Fallback()),
        ),
        (
            "vicuna_v1_7b_fp16",
            {"category": "Fallback"},
            lambda: dl.ApplyDefaultSchedule(dl.gpu.Fallback()),
        ),
        (
            "llama_2_7b_chat_hf_q4f16_1",
            {"category": "GEMV"},
            lambda: dl.ApplyDefaultSchedule(dl.gpu.GEMV()),
        ),
    )
    for _model, _selector, _make_pass in _bench_runs:
        DlightBench.benchmark(
            _model,
            passes=[_make_pass()],
            sample_func=factorized(5, 5),
            sample_num_per_func=10,
            **_selector,
        )