import math
from typing import List, Tuple

import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import Identity
from torchvision.models import resnet50

from lightly.loss.dcl_loss import DCLWLoss
from lightly.models.modules import SimCLRProjectionHead
from lightly.models.utils import get_weight_decay_parameters
from lightly.transforms import SimCLRTransform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.lars import LARS
from lightly.utils.scheduler import CosineWarmupScheduler


class DCLW(LightningModule):
    def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.batch_size_per_device = batch_size_per_device
        resnet = resnet50()
        resnet.fc = Identity()  # Ignore classification head
        self.backbone = resnet
        self.projection_head = SimCLRProjectionHead()  # DCLW uses SimCLR head
        self.criterion = DCLWLoss(temperature=0.1, sigma=0.5, gather_distributed=True)
        self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)
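
    # Added note (not part of the upstream benchmark script): DCLWLoss is the
    # weighted decoupled contrastive loss from "Decoupled Contrastive Learning"
    # (Yeh et al., https://arxiv.org/abs/2110.06848). It removes the positive
    # pair from the denominator of the InfoNCE loss (the "decoupling") and
    # weights the positive alignment term with a von Mises-Fisher-based
    # weighting function whose spread is controlled by sigma.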

    def forward(self, x: Tensor) -> Tensor:
        return self.backbone(x)

    def training_step(
        self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        views, targets = batch[0], batch[1]
        # Concatenate the views along the batch dimension: with two views of a
        # batch of size b, features has shape (2 * b, 2048) for a ResNet-50.
        features = self.forward(torch.cat(views)).flatten(start_dim=1)
        z = self.projection_head(features)
        # Split the projections back into one tensor per view.
        z0, z1 = z.chunk(len(views))
        loss = self.criterion(z0, z1)
        self.log(
            "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
        )
        # Repeat the targets so every view's features get a label for the
        # online linear probe; detach() keeps probe gradients out of the backbone.
        cls_loss, cls_log = self.online_classifier.training_step(
            (features.detach(), targets.repeat(len(views))), batch_idx
        )
        self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
        return loss + cls_loss

    def validation_step(
        self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        images, targets = batch[0], batch[1]
        features = self.forward(images).flatten(start_dim=1)
        cls_loss, cls_log = self.online_classifier.validation_step(
            (features.detach(), targets), batch_idx
        )
        self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
        return cls_loss

    def configure_optimizers(self):
        # Don't use weight decay for batch norm, bias parameters, and classification
        # head to improve performance.
        params, params_no_weight_decay = get_weight_decay_parameters(
            [self.backbone, self.projection_head]
        )
        optimizer = LARS(
            [
                {"name": "dclw", "params": params},
                {
                    "name": "dclw_no_weight_decay",
                    "params": params_no_weight_decay,
                    "weight_decay": 0.0,
                },
                {
                    "name": "online_classifier",
                    "params": self.online_classifier.parameters(),
                    "weight_decay": 0.0,
                },
            ],
            # DCLW uses SimCLR's learning rate scaling scheme.
            # Square root learning rate scaling improves performance for small
            # batch sizes (<=2048) and few training epochs (<=200). Alternatively,
            # linear scaling can be used for larger batches and longer training:
            #   lr=0.3 * self.batch_size_per_device * self.trainer.world_size / 256
            # See Appendix B.1. in the SimCLR paper https://arxiv.org/abs/2002.05709
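            # Worked example (illustrative numbers, not from the original file):
            # with 256 images per device on 2 devices, the effective batch size
            # is 512 and lr = 0.075 * sqrt(512) ≈ 1.70.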
            lr=0.075 * math.sqrt(self.batch_size_per_device * self.trainer.world_size),
            momentum=0.9,
            # Note: Paper uses weight decay of 1e-6 but reference code 1e-4. See:
            # https://github.com/google-research/simclr/blob/2fc637bdd6a723130db91b377ac15151e01e4fc2/README.md?plain=1#L103
            weight_decay=1e-6,
        )
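        # Added note: because the scheduler below is stepped with
        # interval="step", its warmup_epochs and max_epochs arguments are
        # measured in optimizer steps, not epochs. The warmup below therefore
        # lasts roughly 10 epochs' worth of steps, and the cosine decay spans
        # the full training run.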
        scheduler = {
            "scheduler": CosineWarmupScheduler(
                optimizer=optimizer,
                warmup_epochs=int(
                    self.trainer.estimated_stepping_batches
                    / self.trainer.max_epochs
                    * 10
                ),
                max_epochs=int(self.trainer.estimated_stepping_batches),
            ),
            "interval": "step",
        }
        return [optimizer], [scheduler]


# DCLW uses SimCLR augmentations
transform = SimCLRTransform()
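

# Minimal usage sketch (added for illustration; not part of the upstream
# benchmark script). The dataset path, batch size, and trainer settings below
# are assumptions. LightlyDataset wraps an image folder and applies the
# multi-view transform, so each batch has the (views, targets, filenames)
# layout that training_step expects.
if __name__ == "__main__":
    from pytorch_lightning import Trainer
    from torch.utils.data import DataLoader

    from lightly.data import LightlyDataset

    model = DCLW(batch_size_per_device=256, num_classes=1000)
    dataset = LightlyDataset("path/to/train", transform=transform)
    dataloader = DataLoader(
        dataset, batch_size=256, shuffle=True, drop_last=True, num_workers=8
    )
    trainer = Trainer(max_epochs=100, accelerator="gpu", devices=1)
    trainer.fit(model, dataloader)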