-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathdispatch_mem_profiler.py
151 lines (122 loc) · 4.66 KB
/
dispatch_mem_profiler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from collections import defaultdict
from typing import Dict
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils._python_dispatch import TorchDispatchMode
# Shorthand handle to the ATen operator namespace.
aten = torch.ops.aten
# Bytes in one mebibyte; allocator byte counts are divided by this to get MB.
MB = 1024.0 ** 2

# ---- Module-level profiler state (read and mutated by the functions below) ----
# Per-operator-name call counter, used to build unique "<op>_<n>" sample keys.
operator_names: Dict[str, int] = defaultdict(int)
# Samples of the trace currently being recorded: sample key -> allocated MB.
mem_usage: Dict[str, float] = defaultdict(float)
# Vertical plot markers: marker label -> x position (operator-call index).
markers: Dict[str, int] = defaultdict(int)
# Completed traces: series name -> that series' sample dict.
series: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float))
def normalize_tuple(x):
    """Return *x* itself when it is a tuple, otherwise wrap it in a 1-tuple."""
    return x if isinstance(x, tuple) else (x,)
def reduce_to_scalar_loss(inp):
    """Collapse *inp* to a single value by summing all of its elements.

    Serves as a trivial loss so ``.backward()`` can be called on a model's
    output. Any object with a ``sum()`` method is accepted.
    """
    scalar = inp.sum()
    return scalar
class MemoryProfileDispatchMode(TorchDispatchMode):
    """Dispatch mode that samples CUDA allocated memory after every ATen op.

    While the mode is active, each operator call is executed and then
    ``torch.cuda.memory_allocated()`` (converted to MB) is stored in the
    module-level ``mem_usage`` dict. The key has the form
    ``"<op name>_<call index>"``, where the call index comes from the
    module-level ``operator_names`` counter. ``aten.detach.default`` calls
    are executed but not recorded.

    Args:
        verbose: if True, print every recorded sample.
    """

    def __init__(self, verbose=False):
        # TorchDispatchMode sets up its own bookkeeping state in __init__ on
        # newer PyTorch releases, so always chain to it.
        super().__init__()
        self.verbose: bool = verbose

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` and record the allocated memory right after it.

        Args:
            func: the ATen operator overload being dispatched.
            types: tensor subclass types involved (unused here).
            args: positional arguments for ``func``.
            kwargs: keyword arguments for ``func``; ``None`` means no kwargs.

        Returns:
            Whatever ``func`` returns, unchanged.
        """
        # ``**None`` raises TypeError, so normalise a missing kwargs dict.
        if kwargs is None:
            kwargs = {}
        rs = func(*args, **kwargs)
        if func == torch.ops.aten.detach.default:
            # detach calls are deliberately left out of the trace.
            return rs
        mem: float = torch.cuda.memory_allocated() / MB
        # Suffix the op name with a per-op counter so repeated calls get
        # distinct keys instead of overwriting each other.
        func_name: str = f"{func.__name__}_{operator_names[func.__name__]}"
        operator_names[func.__name__] += 1
        mem_usage[func_name] = mem
        if self.verbose:
            print("Mem Usage (" + func_name + "): ", mem)
        return rs
def clear_state():
    """Forget the per-operator call counts and the samples recorded so far.

    Only ``operator_names`` and the current ``mem_usage`` trace are emptied;
    completed ``series`` and pending ``markers`` are left untouched.
    """
    for table in (operator_names, mem_usage):
        table.clear()
def add_series(series_name):
    """Close the current trace and store it under *series_name*.

    A final ``torch.cuda.memory_allocated()`` sample (in MB) is added under
    the key ``"fin_usage"``. The trace dict is then stored in ``series`` and
    a fresh empty dict becomes the global ``mem_usage``. ``operator_names``
    is not reset here.
    """
    global mem_usage
    mem_usage["fin_usage"] = torch.cuda.memory_allocated() / MB
    # Targets are assigned left to right: store the finished trace first,
    # then install the fresh one.
    series[series_name], mem_usage = mem_usage, defaultdict(float)
def save_graph(filename: str):
    """Plot every recorded series plus markers, save the figure, reset state.

    Each entry of the module-level ``series`` dict becomes one line. Its
    x-axis is the sample index (operator-call order) and its y-axis the
    allocated memory in MB. Each entry of ``markers`` is drawn as a vertical
    black line at its x position. The line spans the min..max memory over
    *all* series, so it covers every plotted line and not only the last one.
    The figure is written with ``plt.savefig(filename)``, and both
    ``markers`` and ``series`` are then cleared.

    Args:
        filename: output path for ``plt.savefig``; also used as the title.
    """
    all_values = []
    for series_name, usage in series.items():
        y = list(usage.values())
        all_values.extend(y)
        x = list(range(len(y)))
        plt.plot(x, y, label=series_name)
    plt.xlabel("# Operator Calls")
    plt.ylabel("Allocated Memory (MB)")
    plt.title(filename)
    # With no recorded samples there is no vertical range to span, so
    # markers are skipped rather than failing on an undefined min/max.
    if all_values:
        min_val, max_val = min(all_values), max(all_values)
        for marker_name, marker in markers.items():
            plt.plot([marker, marker], [min_val, max_val], "k-", lw=2, label=marker_name)
    plt.legend()
    print("Saving Graph")
    plt.savefig(filename)
    plt.close()
    markers.clear()
    series.clear()
def add_marker(marker_name):
    """Record a vertical marker at the current end of the live trace.

    The marker's x position is the number of samples currently in
    ``mem_usage``. Its label is ``marker_name`` followed by the number of
    already-completed series, so reusing a name in different series yields
    distinct entries in ``markers``.
    """
    suffix = str(len(series))
    position = len(mem_usage)
    markers[marker_name + suffix] = position
def mem_profile_model(mod: torch.nn.Module, inp: torch.Tensor):
    """Record per-operator CUDA memory for one forward+backward pass of *mod*.

    Steps, all under a verbose ``MemoryProfileDispatchMode``:

    1. Run a warm-up forward/backward.
    2. Reset gradients with ``zero_grad(True)``, synchronize CUDA, and wipe
       the recorded samples with ``clear_state()``.
    3. Run a second forward/backward whose samples remain in ``mem_usage``.
       A ``"fw_bw_boundary"`` marker is placed between that forward and its
       backward.

    In this file the call is followed by ``add_series(...)`` to store the
    trace.

    Args:
        mod: module to profile.
        inp: input passed to ``mod``. NOTE(review): samples come from
            ``torch.cuda.memory_allocated()``, so this presumably needs to be
            on a CUDA device for the trace to be meaningful.
    """
    with MemoryProfileDispatchMode(True):
        # Warm-up pass; its samples are discarded by clear_state() below.
        pred = mod(inp)
        loss = reduce_to_scalar_loss(pred)
        loss.backward()
        # set_to_none=True: gradients are released rather than zero-filled.
        mod.zero_grad(True)
        torch.cuda.synchronize()
        clear_state()
        # Measured pass. NOTE(review): the warm-up ``pred``/``loss`` are still
        # referenced until these names are rebound, so they are alive during
        # the measured forward and count toward the recorded memory.
        pred = mod(inp)
        loss = reduce_to_scalar_loss(pred)
        # Marker position = number of samples so far, i.e. end of the forward.
        add_marker("fw_bw_boundary")
        loss.backward()
if __name__ == "__main__":
    # Guard only the optional dependencies. An ImportError raised while
    # profiling must surface instead of silently switching to the fallback.
    try:
        import torchvision.models as models
        from functorch.compile import aot_module
        from functorch.compile import min_cut_rematerialization_partition
        from functorch.compile import nop
    except ImportError:
        # torchvision / functorch unavailable: profile a small CNN in eager
        # mode only.

        class MyModel(torch.nn.Module):
            """Small CNN: two conv+max-pool stages followed by three linear layers."""

            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(3, 6, 5)
                self.pool = nn.MaxPool2d(2, 2)
                self.conv2 = nn.Conv2d(6, 16, 5)
                self.fc1 = nn.Linear(16 * 5 * 5, 120)
                self.fc2 = nn.Linear(120, 84)
                self.fc3 = nn.Linear(84, 10)

            def forward(self, x):
                x = self.pool(F.relu(self.conv1(x)))
                x = self.pool(F.relu(self.conv2(x)))
                x = torch.flatten(x, 1)  # flatten all dimensions except batch
                x = F.relu(self.fc1(x))
                x = F.relu(self.fc2(x))
                x = self.fc3(x)
                return x

        mod: torch.nn.Module = MyModel().cuda()
        # 32x32 inputs leave 16x5x5 features after the two conv/pool stages,
        # matching fc1's input size.
        inp: torch.Tensor = torch.randn(512, 3, 32, 32, device="cuda")
        mem_profile_model(mod, inp)
        add_series("eager_mode")
        save_graph("Model_mem_usage")
    else:
        mod: torch.nn.Module = models.resnet18().cuda()
        inp: torch.Tensor = torch.randn(32, 3, 224, 224, device="cuda")

        # Graph 1: eager vs. AOT Autograd with the min-cut partitioner.
        mem_profile_model(mod, inp)
        add_series("eager_mode")
        mod3 = aot_module(mod, nop, partition_fn=min_cut_rematerialization_partition)
        mem_profile_model(mod3, inp)
        add_series("aot_autograd_min_cut")
        save_graph("Resnet_mem_usage")

        # Graph 2: aot_module() is itself called inside the profiling mode,
        # so any ATen ops it dispatches are recorded alongside fw+bw.
        clear_state()
        with MemoryProfileDispatchMode(True):
            mod3 = aot_module(
                mod, nop, partition_fn=min_cut_rematerialization_partition
            )
            mod3(inp).sum().backward()
        add_series("aot_autograd_mem_usage")
        save_graph("autograd_mem_usage")