-
Notifications
You must be signed in to change notification settings - Fork 286
/
Copy pathdino.py
166 lines (146 loc) · 6.1 KB
/
dino.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import copy
from typing import List, Tuple, Union
import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import Identity
from torch.optim import SGD
from torch.optim.optimizer import Optimizer
from torchvision.models import resnet50
from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import (
activate_requires_grad,
deactivate_requires_grad,
get_weight_decay_parameters,
update_momentum,
)
from lightly.transforms import DINOTransform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.scheduler import CosineWarmupScheduler, cosine_schedule
class DINO(LightningModule):
    """DINO self-distillation benchmark model with a ResNet-50 backbone.

    The teacher is ``backbone`` + ``projection_head``. Its parameters have
    ``requires_grad`` disabled for the duration of training and it is updated
    with a momentum (EMA) step from the student. The student is
    ``student_backbone`` + ``student_projection_head``; these are the parameters
    handed to the optimizer. A linear classifier is trained online on detached
    teacher features to monitor representation quality.

    Args:
        batch_size_per_device: Per-device batch size, used to scale the
            learning rate.
        num_classes: Number of classes for the online linear classifier.
    """

    def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.batch_size_per_device = batch_size_per_device

        resnet = resnet50()
        resnet.fc = Identity()  # Ignore classification head
        # Teacher network.
        self.backbone = resnet
        self.projection_head = DINOProjectionHead()
        # Student network.
        self.student_backbone = copy.deepcopy(self.backbone)
        self.student_projection_head = DINOProjectionHead(freeze_last_layer=1)
        # DINO starts the teacher as an exact copy of the student. The backbone
        # is deep-copied above, but the two heads were initialized
        # independently, so copy the student head weights into the teacher head.
        self.projection_head.load_state_dict(
            self.student_projection_head.state_dict()
        )
        self.criterion = DINOLoss()

        self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Return teacher-backbone features for ``x``."""
        return self.backbone(x)

    def forward_student(self, x: Tensor) -> Tensor:
        """Return student projections (backbone features flattened, then head)."""
        features = self.student_backbone(x).flatten(start_dim=1)
        projections = self.student_projection_head(features)
        return projections

    def on_train_start(self) -> None:
        # The teacher is only updated via momentum, never by the optimizer.
        deactivate_requires_grad(self.backbone)
        deactivate_requires_grad(self.projection_head)

    def on_train_end(self) -> None:
        # Restore gradients on the teacher once training is over.
        activate_requires_grad(self.backbone)
        activate_requires_grad(self.projection_head)

    def training_step(
        self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        """Run one DINO step plus one online-classifier step.

        ``batch[0]`` is the list of augmented views. The first two are global
        crops; the rest (possibly none) are local crops. ``batch[1]`` holds the
        targets used only by the online classifier.

        Returns:
            The sum of the DINO loss and the online-classifier loss.
        """
        # Momentum update teacher. The momentum follows a cosine schedule from
        # 0.996 to 1.0 over all estimated training steps.
        momentum = cosine_schedule(
            step=self.trainer.global_step,
            max_steps=self.trainer.estimated_stepping_batches,
            start_value=0.996,
            end_value=1.0,
        )
        update_momentum(self.student_backbone, self.backbone, m=momentum)
        update_momentum(self.student_projection_head, self.projection_head, m=momentum)

        views, targets = batch[0], batch[1]
        global_views = torch.cat(views[:2])

        # The teacher only sees the two global crops. Its parameters have
        # requires_grad disabled in on_train_start, so no gradient flows here.
        teacher_features = self.forward(global_views).flatten(start_dim=1)
        teacher_projections = self.projection_head(teacher_features)

        # The student sees every crop. Global and local crops are concatenated
        # and forwarded as separate groups, because torch.cat needs equal
        # spatial sizes within a group. The local group is skipped when the
        # batch has no local crops (torch.cat would fail on an empty list).
        student_inputs = [global_views]
        if len(views) > 2:
            student_inputs.append(torch.cat(views[2:]))
        student_projections = torch.cat(
            [self.forward_student(inputs) for inputs in student_inputs]
        )

        # Every view contributes the same number of samples, so chunking
        # recovers one tensor per view, in view order.
        loss = self.criterion(
            teacher_out=teacher_projections.chunk(2),
            student_out=student_projections.chunk(len(views)),
            epoch=self.current_epoch,
        )
        self.log_dict(
            {"train_loss": loss, "ema_momentum": momentum},
            prog_bar=True,
            sync_dist=True,
            batch_size=len(targets),
        )

        # Online classification on detached teacher features of the first
        # global crop, so the classifier loss does not backpropagate into the
        # SSL model.
        cls_loss, cls_log = self.online_classifier.training_step(
            (teacher_features.chunk(2)[0].detach(), targets), batch_idx
        )
        self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
        return loss + cls_loss

    def validation_step(
        self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        """Evaluate the online classifier on teacher-backbone features."""
        images, targets = batch[0], batch[1]
        features = self.forward(images).flatten(start_dim=1)
        cls_loss, cls_log = self.online_classifier.validation_step(
            (features.detach(), targets), batch_idx
        )
        self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
        return cls_loss

    def configure_optimizers(self):
        """Build SGD over student and classifier params with warmup-cosine LR."""
        # Don't use weight decay for batch norm, bias parameters, and classification
        # head to improve performance.
        params, params_no_weight_decay = get_weight_decay_parameters(
            [self.student_backbone, self.student_projection_head]
        )
        # For ResNet50 we use SGD instead of AdamW/LARS as recommended by the authors:
        # https://github.com/facebookresearch/dino#resnet-50-and-other-convnets-trainings
        optimizer = SGD(
            [
                {"name": "dino", "params": params},
                {
                    "name": "dino_no_weight_decay",
                    "params": params_no_weight_decay,
                    "weight_decay": 0.0,
                },
                {
                    "name": "online_classifier",
                    "params": self.online_classifier.parameters(),
                    "weight_decay": 0.0,
                },
            ],
            # Linear LR scaling: base LR 0.03 per 256 samples of global batch.
            lr=0.03 * self.batch_size_per_device * self.trainer.world_size / 256,
            momentum=0.9,
            weight_decay=1e-4,
        )
        scheduler = {
            # The scheduler is stepped every optimizer step (interval="step"),
            # so its "epochs" are steps. Warmup spans 10 epochs' worth of steps.
            "scheduler": CosineWarmupScheduler(
                optimizer=optimizer,
                warmup_epochs=int(
                    self.trainer.estimated_stepping_batches
                    / self.trainer.max_epochs
                    * 10
                ),
                max_epochs=int(self.trainer.estimated_stepping_batches),
            ),
            "interval": "step",
        }
        return [optimizer], [scheduler]

    def configure_gradient_clipping(
        self,
        optimizer: Optimizer,
        gradient_clip_val: Union[int, float, None] = None,
        gradient_clip_algorithm: Union[str, None] = None,
    ) -> None:
        """Clip gradients, then cancel the student head's last-layer gradients.

        Defaults to norm clipping at 3.0. Values configured on the Trainer, and
        passed in here by Lightning, take precedence. Lightning's
        ``clip_gradients`` raises if its arguments differ from the Trainer's
        settings, so always forcing 3.0 would fail whenever
        ``Trainer(gradient_clip_val=...)`` is set.
        """
        if gradient_clip_val is None:
            gradient_clip_val = 3.0
        if gradient_clip_algorithm is None:
            gradient_clip_algorithm = "norm"
        self.clip_gradients(
            optimizer=optimizer,
            gradient_clip_val=gradient_clip_val,
            gradient_clip_algorithm=gradient_clip_algorithm,
        )
        # The student head was built with freeze_last_layer=1. Presumably this
        # drops its last-layer gradients while current_epoch < 1 — confirm
        # against the lightly DINOProjectionHead docs.
        self.student_projection_head.cancel_last_layer_gradients(self.current_epoch)
# Crop scales adjusted for ResNet-50, following the DINO authors' advice:
# https://github.com/facebookresearch/dino#resnet-50-and-other-convnets-trainings
transform = DINOTransform(
    global_crop_scale=(0.14, 1),
    local_crop_scale=(0.05, 0.14),
)