-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathenhanced_error_mode.py
40 lines (34 loc) · 1.3 KB
/
enhanced_error_mode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
import itertools
# cribbed from https://github.com/albanD/subclass_zoo/blob/main/logging_mode.py
class Lit:
    """Wrapper whose ``repr()`` is exactly the string it was built with.

    Placing a ``Lit`` inside a container (list, tuple, dict) lets that text
    appear verbatim in the container's repr, without surrounding quotes.
    """

    def __init__(self, s):
        # Kept as-is; __repr__ hands it back untouched.
        self.s = s

    def __repr__(self):
        text = self.s
        return text
def fmt(t: object) -> object:
    """Return a compact, repr-friendly stand-in for ``t``.

    A ``torch.Tensor`` is replaced by a ``Lit`` whose repr reports only the
    tensor's size, dtype and device (never its data), keeping error messages
    short. Any other value is returned unchanged, which is why the return type
    is ``object`` rather than ``str``.
    """
    if isinstance(t, torch.Tensor):
        return Lit(
            f"torch.tensor(..., size={tuple(t.shape)}, dtype={t.dtype}, device='{t.device}')"
        )
    return t
class EnhancedErrorMode(TorchDispatchMode):
    """Dispatch mode that appends the failing operator call to any exception.

    Every operator dispatched while this mode is active runs unchanged. If one
    raises, the exception's first argument (its message) is extended with a
    line of the form ``...when running <op>(<args>)``, where tensor arguments
    are summarized via ``fmt`` (size/dtype/device, not contents). The original
    exception object is then re-raised, so its type and traceback are kept.
    """

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` and annotate any exception it raises with the call.

        Args:
            func: the operator being dispatched.
            types: supplied by the dispatcher; not used here.
            args: positional arguments for ``func``.
            kwargs: keyword arguments for ``func``; ``None`` is treated as ``{}``.

        Returns:
            Whatever ``func(*args, **kwargs)`` returns.

        Raises:
            The same exception instance ``func`` raised, with its message
            extended by a description of the failing call.
        """
        # kwargs may legitimately be None under the dispatch protocol; normalise
        # so both the call and the formatting below work.
        if kwargs is None:
            kwargs = {}
        try:
            return func(*args, **kwargs)
        except Exception as ex:
            fmt_args = ", ".join(
                itertools.chain(
                    (repr(tree_map(fmt, a)) for a in args),
                    # repr (not str) so keyword values are rendered exactly like
                    # positional ones, e.g. string kwargs keep their quotes.
                    (f"{k}={tree_map(fmt, v)!r}" for k, v in kwargs.items()),
                )
            )
            msg = f"...when running {func}({fmt_args})"
            # Rewrite args in place and re-raise the same object so the type and
            # traceback survive; technique from
            # https://stackoverflow.com/questions/17677680/how-can-i-add-context-to-an-exception-in-python
            msg = f"{ex.args[0]}\n{msg}" if ex.args else msg
            ex.args = (msg,) + ex.args[1:]
            raise
if __name__ == "__main__":
    # Demo: a deliberate shape mismatch, so the enriched error text is printed.
    with EnhancedErrorMode():
        lhs = torch.randn(3)
        rhs = torch.randn(4, 5)
        torch.matmul(lhs, rhs)