-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathsam.py
82 lines (65 loc) · 2.48 KB
/
sam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from typing import Iterable
import torch
from torch.optim._multi_tensor import SGD
# Public API of this module: only the SAM-wrapped SGD optimizer is exported.
__all__ = ["SAMSGD"]
class SAMSGD(SGD):
    """SGD wrapped with Sharpness-Aware Minimization (SAM).

    Every call to :meth:`step` evaluates the closure twice: once at the
    current weights, to obtain the gradient that defines the perturbation,
    and once at the perturbed weights, to obtain the gradient that the
    underlying SGD update then applies to the original weights.

    Args:
        params: tensors to be optimized
        lr: learning rate
        momentum: momentum factor
        dampening: damping factor
        weight_decay: weight decay factor
        nesterov: enables Nesterov momentum
        rho: neighborhood size

    Raises:
        ValueError: if ``rho`` is not positive, or if more than one parameter
            group is given (only a single group is supported for now).
    """

    def __init__(self,
                 params: Iterable[torch.Tensor],
                 lr: float,
                 momentum: float = 0,
                 dampening: float = 0,
                 weight_decay: float = 0,
                 nesterov: bool = False,
                 rho: float = 0.05,
                 ):
        if rho <= 0:
            raise ValueError(f"Invalid neighborhood size: {rho}")
        # Keyword arguments keep this call independent of the positional
        # order of the base-class hyper-parameters.
        super().__init__(params,
                         lr=lr,
                         momentum=momentum,
                         dampening=dampening,
                         weight_decay=weight_decay,
                         nesterov=nesterov)
        # todo: generalize this
        if len(self.param_groups) > 1:
            raise ValueError("Not supported")
        self.param_groups[0]["rho"] = rho

    @torch.no_grad()
    def step(self,
             closure
             ) -> torch.Tensor:
        """Perform one SAM update.

        Args:
            closure: A closure that reevaluates the model and returns the loss.
                It is called twice per step and is expected to reset the
                gradients itself before calling ``backward()``; this method
                never zeroes them.

        Returns: the loss value evaluated on the original point
        """
        # step() runs under no_grad, so autograd must be re-enabled around
        # the user's closure for backward() to work.
        closure = torch.enable_grad()(closure)
        loss = closure().detach()

        for group in self.param_groups:
            rho = group['rho']
            params_with_grads = [p for p in group['params'] if p.grad is not None]
            if not params_with_grads:
                # Nothing to perturb; the base step below is then a no-op
                # for this group as well.
                continue

            # Snapshot the gradients: the second closure() call below
            # replaces p.grad, and these copies are reused as epsilon.
            grads = [p.grad.clone().detach() for p in params_with_grads]
            device = grads[0].device

            # compute \hat{\epsilon} = \rho * g / \|g\|_2
            grad_norm = torch.stack([g.norm(2).to(device) for g in grads]).norm(2)
            # Guard against an all-zero gradient: rho / 0 is inf and
            # 0 * inf would write NaN into every parameter. Clamping to the
            # smallest normal number leaves any non-degenerate norm untouched.
            grad_norm = grad_norm.clamp_min(torch.finfo(grad_norm.dtype).tiny)
            epsilon = grads  # alias for readability
            torch._foreach_mul_(epsilon, rho / grad_norm)

            # virtual step toward \epsilon
            torch._foreach_add_(params_with_grads, epsilon)
            # compute g=\nabla_w L_B(w)|_{w+\hat{\epsilon}}
            closure()
            # virtual step back to the original point
            torch._foreach_sub_(params_with_grads, epsilon)

        # Regular SGD update at the original point, using the gradient that
        # was computed at the perturbed point.
        super().step()
        return loss