-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathnan_detect.py
30 lines (25 loc) · 886 Bytes
/
nan_detect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten
class NanDetect(TorchDispatchMode):
    """Dispatch mode that raises as soon as an operator returns a NaN.

    Use as a context manager (``with NanDetect(): ...``). Every operator call
    routed through ``__torch_dispatch__`` is executed unchanged, then each
    tensor found in its (possibly nested) result is scanned for NaN values.
    """

    def __torch_dispatch__(self, func, types, args, kwargs=None):
        """Run ``func`` and verify that none of its tensor outputs hold a NaN.

        Args:
            func: The operator being dispatched.
            types: Tensor types involved in the call (unused here).
            args: Positional arguments forwarded to ``func``.
            kwargs: Keyword arguments forwarded to ``func``; ``None`` is
                treated as no keyword arguments.

        Returns:
            Whatever ``func(*args, **kwargs)`` returned, unmodified.

        Raises:
            RuntimeError: If any tensor in the result contains a NaN.
        """
        kwargs = kwargs or {}
        res = func(*args, **kwargs)
        # The result may be a single tensor or a nested container of values;
        # flatten it so every leaf can be inspected uniformly.
        flat_res, _ = tree_flatten(res)
        for t in flat_res:
            if not torch.is_tensor(t):
                continue
            # Integer and bool tensors cannot represent NaN, so skip the
            # extra comparison/reduction for them.
            if not (t.is_floating_point() or t.is_complex()):
                continue
            try:
                has_nan = bool(torch.isnan(t).any())
            except NotImplementedError:
                # Best effort: if the NaN check itself is not implemented for
                # this tensor, skip it rather than fail the user's operation.
                continue
            if has_nan:
                raise RuntimeError(
                    f"Function {func}(*{args}, **{kwargs}) returned a NaN"
                )
        return res
# Demo: 0 / 0 evaluated without the mode simply prints a NaN tensor.
a = torch.tensor([0.0])
print(torch.div(a, a))

# The same division under NanDetect raises instead of printing:
# RuntimeError: Function aten.div.Tensor(*(tensor([0.]), tensor([0.])), **{}) returned a NaN
with NanDetect():
    print(torch.div(a, a))