memory_debugging_tensor.py
# -*- coding: utf-8 -*-
"""Subclass memory debugging
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1vxJkMT1kTUpd9RoqEzf0i825fvKNQAZl
"""
import torch
# Some mode APIs are changing on master, while this notebook runs on Colab with 1.12
# print(torch.__version__)
# if not torch.__version__.split("+")[0] == "1.12.1":
#     raise RuntimeError("This notebook is for pytorch 1.12")
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from collections import defaultdict
from typing import Dict, Tuple
from contextlib import contextmanager
from torch.utils._python_dispatch import TorchDispatchMode
# We do want to use non-full hooks here as they behave similarly to
# backward pre-hooks
import warnings
warnings.filterwarnings("ignore", "Using a non-full backward")
MB = 1024 * 1024.0
# Globals used to save state
operator_names: Dict[str, int] = defaultdict(int)
mem_usage: Dict[int, Tuple[str, float]] = {}
max_mem_usage: Dict[int, Tuple[str, float]] = {}
markers: Dict[str, int] = defaultdict(int)
cur_module: str = ""
op_id: int = 0

def clear_state():
    operator_names.clear()
    mem_usage.clear()
    max_mem_usage.clear()
    markers.clear()

# To add markers in the final print
def add_marker(marker_name):
    marker_val = len(mem_usage)
    markers[marker_name] = marker_val

def record_fn(fn_name):
    # Snapshot current and peak CUDA memory for this op, then reset the peak counter
    global op_id
    mem: float = torch.cuda.memory_allocated() / MB
    mem_usage[op_id] = (fn_name, mem)
    max_mem: float = torch.cuda.max_memory_allocated() / MB
    max_mem_usage[op_id] = (fn_name, max_mem)
    torch.cuda.reset_peak_memory_stats()
    op_id += 1

# Mode that records all allocations
class MemoryProfileDispatchMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=..., kwargs=None):
        kwargs = kwargs or {}
        rs = func(*args, **kwargs)
        # Skip recording aten.detach calls
        if func == torch.ops.aten.detach.default:
            return rs
        func_name: str = f"{cur_module}.{func.__name__}_{operator_names[func.__name__]}"
        operator_names[func.__name__] += 1
        record_fn(func_name)
        return rs

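# A minimal sketch (not part of the original notebook) of using the mode on its
# own, without the module hooks defined below: every aten op dispatched inside
# the block gets a (name, allocated_MB) entry in mem_usage. Assumes a CUDA
# device; on PyTorch 1.12 enter the mode via push_torch_dispatch_mode as
# mem_profile_model does below, on newer releases the mode is a context manager.
#
# clear_state()
# with MemoryProfileDispatchMode():
#     torch.randn(1024, 1024, device="cuda").relu_()
# print(dict(mem_usage))
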
# Functions to print and draw the graph
def show_graph():
    import matplotlib.pyplot as plt

    y = [mb for (name, mb) in mem_usage.values()]
    min_val = min(y)
    max_val = max(y)
    x = list(range(len(y)))
    plt.figure(figsize=(16, 8))
    plt.plot(x, y, label="memory")
    plt.xlabel("# Operator Calls")
    plt.ylabel("Allocated Memory (MB)")
    # plt.title(filename)
    for marker_name, marker in markers.items():
        if marker_name == "fw_bw_boundary":
            plt.plot([marker, marker], [min_val, max_val], "r", lw=2, label=marker_name)
        else:
            plt.plot([marker, marker], [min_val, max_val], "k-", lw=2, label=marker_name)
    plt.legend()
    plt.show()

def print_top_mem_op(top: int = 50):
    op_diff: Dict[str, float] = defaultdict(float)
    _, pre_mem = mem_usage[0]
    for i in range(1, op_id):
        op, mem = mem_usage[i]
        op_diff[op] = mem - pre_mem
        pre_mem = mem
    print("------------------------------------------------")
    print(f"Top {top} ops that generate memory are:")
    for k, v in sorted(op_diff.items(), key=lambda item: item[1], reverse=True)[:top]:
        print(f"{k}: {v:.2f} MB")
    print("------------------------------------------------")

# Module-level driver and hooks that make the Mode's output easier to read
def mem_profile_model(mod: torch.nn.Module, *args):
    # push_torch_dispatch_mode is the PyTorch 1.12-era entry point; newer releases
    # can enter the mode directly with `with MemoryProfileDispatchMode():`.
    with torch.utils._python_dispatch.push_torch_dispatch_mode(MemoryProfileDispatchMode):
        torch.cuda.reset_peak_memory_stats()
        mod.zero_grad(True)
        clear_state()
        record_fn("Start")
        loss = mod(*args)
        add_marker("fw_bw_boundary")
        # Segmentation models return a dict; take the "out" head as the loss proxy
        if isinstance(loss, dict):
            loss = loss['out']
        loss.sum().backward()
        add_marker("bw_zero_boundary")
        mod.zero_grad(set_to_none=True)
        record_fn("Finished")

def fwd_wrapped(name):
    def fwd_debug_hook(module, input) -> None:
        global cur_module
        cur_module = f"{name}.forward"
    return fwd_debug_hook

def bwd_wrapped(name):
    def bwd_debug_hook(module, input, out) -> None:
        global cur_module
        cur_module = f"{name}.backward"
    return bwd_debug_hook

# This context manager attaches/detaches the debugging hooks
@contextmanager
def debug_model(model):
    global op_id, cur_module
    hooks = []
    cur_module = 'forward'
    op_id = 0
    for name, module in model.named_modules():
        hooks.append(module.register_forward_pre_hook(fwd_wrapped(name)))
        hooks.append(module.register_backward_hook(bwd_wrapped(name)))
    try:
        yield model
    finally:
        for hook in hooks:
            hook.remove()

def run_one_model(net, input):
    net.cuda()
    input = input.cuda()
    with debug_model(net) as m:
        # mem_profile_model(m, input1, input2)
        mem_profile_model(m, input)
        print_top_mem_op(20)
        show_graph()

run_one_model(torchvision.models.resnet34(), torch.rand(32, 3, 224, 224, device="cuda"))

run_one_model(torchvision.models.mobilenet_v3_large(), torch.rand(32, 3, 224, 224, device="cuda"))

run_one_model(torchvision.models.segmentation.fcn_resnet50(), torch.rand(32, 3, 224, 224, device="cuda"))

run_one_model(torchvision.models.vision_transformer.vit_b_32(), torch.rand(32, 3, 224, 224, device="cuda"))
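
# mem_profile_model takes *args, so the same flow also covers models with several
# inputs. A minimal sketch with a hypothetical two-input module (not part of the
# original notebook); uncomment the last line to run it on a CUDA machine.
class TwoInputNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(64, 64)

    def forward(self, a, b):
        # Two tensor inputs, one shared linear layer so backward has parameters
        return self.lin(a) + self.lin(b)

def run_two_input_model(net, a, b):
    net.cuda()
    with debug_model(net) as m:
        mem_profile_model(m, a.cuda(), b.cuda())
        print_top_mem_op(10)
        show_graph()

# run_two_input_model(TwoInputNet(), torch.rand(32, 64), torch.rand(32, 64))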