-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathloss.py
executable file
·196 lines (182 loc) · 9.41 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import torch
import torch.nn as nn
import torch.nn.functional as F
from tools import *
class MVSLoss(nn.Module):
    """Aggregate training loss: unsupervised photometric + ICC + SCC terms.

    Each term is a multi-stage loss module exposing a `.name` attribute used
    as the key in the returned per-term scalar dictionary.
    """

    def __init__(self, args):
        super(MVSLoss, self).__init__()
        self.args = args
        # One entry per loss term; all are summed into the total.
        self.loss_funcs = [
            UnsupLossMultiStage_l05(args),
            ICCLossMultiStage(args),
            SCCLossMultiStage(args),
        ]

    def forward(self, data, outputs, epoch_idx):
        """Return (total_loss tensor, {term_name: float}) for this batch."""
        per_term = {}
        total = torch.tensor(0.0, dtype=torch.float32,
                             device=data["imgs"].device, requires_grad=False)
        for term_fn in self.loss_funcs:
            term_loss, _ = term_fn(data, outputs, epoch_idx)
            per_term[term_fn.name] = term_loss.item()
            total = total + term_loss
        return total, per_term
class UnSupLoss(nn.Module):
    """Unsupervised photometric-consistency loss for one pyramid stage.

    Each source view is warped into the reference view through the estimated
    depth; the loss combines
      * a per-pixel best-view (top-1) reconstruction term,
      * an SSIM term over the first two source views only,
      * a depth-smoothness regularizer,
    weighted by args.wrecon / 6 / 0.18 respectively.
    """

    def __init__(self, args):
        super(UnSupLoss, self).__init__()
        self.ssim = SSIM()
        self.args = args

    def forward(self, imgs, cams, depth, stage_idx):
        """Compute the stage loss.

        imgs: image stack, reference view first (unbound along dim 1).
        cams: matching projection matrices (unbound along dim 1).
        depth: reference-view depth map at this stage's resolution.
        stage_idx: 0/1/2 — stages 0 and 1 downsample images to 1/4 and 1/2.
        """
        imgs = torch.unbind(imgs, 1)
        cams = torch.unbind(cams, 1)
        assert len(imgs) == len(cams), "Different number of images and projection matrices"
        num_views = len(imgs)

        ref_img = imgs[0]
        # Match image resolution to this stage's depth resolution.
        if stage_idx == 0:
            ref_img = F.interpolate(ref_img, scale_factor=0.25)
        elif stage_idx == 1:
            ref_img = F.interpolate(ref_img, scale_factor=0.5)
        ref_img = ref_img.permute(0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]
        ref_cam = cams[0]

        # Exposed as attributes so UnsupLossMultiStage_l05 can log them per stage.
        self.reconstr_loss = 0
        self.ssim_loss = 0
        self.smooth_loss = 0

        warped_img_list = []
        mask_list = []
        reprojection_losses = []
        for view in range(1, num_views):
            view_img = imgs[view]
            view_cam = cams[view]
            if stage_idx == 0:
                view_img = F.interpolate(view_img, scale_factor=0.25)
            elif stage_idx == 1:
                view_img = F.interpolate(view_img, scale_factor=0.5)
            view_img = view_img.permute(0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]
            warped_img, mask = inverse_warping(view_img, ref_cam, view_cam, depth)
            if mask.sum() == 0:
                # No valid reprojected pixel in this view: nothing to supervise.
                self.unsup_loss = torch.tensor(0.0, dtype=torch.float32, device=mask.device)
                return self.unsup_loss
            warped_img_list.append(warped_img)
            mask_list.append(mask)

            reconstr_loss = compute_reconstr_loss_l0_5(warped_img, ref_img, mask, simple=False)
            valid_mask = 1 - mask
            # Push invalid pixels to a huge value so top-k never selects them.
            reprojection_losses.append(reconstr_loss + 1e4 * valid_mask)

            # SSIM only for the first two source views.
            if view < 3:
                self.ssim_loss += torch.mean(self.ssim(ref_img, warped_img, mask))
            self.smooth_loss += depth_smoothness(depth.unsqueeze(dim=-1), ref_img, 1.0)

        # Per pixel, keep only the smallest reprojection error across views.
        reprojection_volume = torch.stack(reprojection_losses).permute(1, 2, 3, 4, 0)
        top_vals, top_inds = torch.topk(torch.neg(reprojection_volume), k=1, sorted=False)
        top_vals = torch.neg(top_vals)
        # BUGFIX: compare device-agnostically. The previous
        # `torch.ones_like(top_vals).cuda()` hard-coded CUDA and crashed on
        # CPU-only runs; ones_like already lives on top_vals' device.
        top_mask = (top_vals < 1e4 * torch.ones_like(top_vals)).float()
        top_vals = torch.mul(top_vals, top_mask)
        self.reconstr_loss = torch.mean(torch.sum(top_vals, dim=-1))

        self.unsup_loss = self.args.wrecon * self.reconstr_loss + 6 * self.ssim_loss + 0.18 * self.smooth_loss
        return self.unsup_loss
class UnsupLossMultiStage_l05(nn.Module):
    """Run UnSupLoss on every coarse-to-fine stage and weight-sum the results.

    Per-stage weights come from args.dlossw (1.0 each when it is None).
    """

    def __init__(self, args):
        super(UnsupLossMultiStage_l05, self).__init__()
        self.name = "unslossl05"
        self.args = args
        self.unsup_loss = UnSupLoss(args)

    def forward(self, data, outputs, epoch_idx, **kwargs):
        imgs = data["center_imgs"]
        cams = data["proj_matrices"]
        weights = self.args.dlossw
        total = torch.tensor(0.0, dtype=torch.float32,
                             device=imgs.device, requires_grad=False)
        scalars = {}
        for stage_key in [k for k in outputs.keys() if "stage" in k]:
            stage_idx = int(stage_key.replace("stage", "")) - 1
            depth_est = outputs[stage_key]["depth"]
            stage_loss = self.unsup_loss(imgs, cams[stage_key], depth_est, stage_idx)
            stage_weight = weights[stage_idx] if weights is not None else 1.0
            total = total + stage_weight * stage_loss
            # Log the stage loss plus the sub-terms UnSupLoss exposes.
            tag = stage_idx + 1
            scalars["depth_loss_stage{}".format(tag)] = stage_loss
            scalars["reconstr_loss_stage{}".format(tag)] = self.unsup_loss.reconstr_loss
            scalars["ssim_loss_stage{}".format(tag)] = self.unsup_loss.ssim_loss
            scalars["smooth_loss_stage{}".format(tag)] = self.unsup_loss.smooth_loss
        return total, scalars
class ICCLossMultiStage(nn.Module):
    """Cross-branch consistency loss: supervise the output2 depth of each
    stage with the detached output1 depth, restricted to pixels where
    output2's filter_mask exceeds 0.5. Returns (weighted_total, per-stage
    scalars); returns zero immediately when "output2" is absent.
    """

    def __init__(self, args):
        super(ICCLossMultiStage, self).__init__()
        self.name = "iccloss"
        self.args = args

    def forward(self, data, outputs, epoch_idx, **kwargs):
        # The second branch may not have been run (e.g. warm-up epochs).
        if "output2" not in outputs:
            zero = torch.tensor(0.0, dtype=torch.float32,
                                device=data["imgs"].device, requires_grad=False)
            return zero, {}
        inputs = outputs["output2"]
        # Detached: output1 acts as a fixed pseudo ground truth.
        pseudo_depth = outputs["output1"]["depth"].detach()
        filter_mask = inputs["filter_mask"]
        depth_loss_weights = self.args.dlossw
        total_loss = torch.tensor(0.0, dtype=torch.float32,
                                  device=pseudo_depth.device, requires_grad=False)
        scalar_outputs = {}
        for stage_key in [k for k in inputs.keys() if "stage" in k]:
            stage_idx = int(stage_key.replace("stage", "")) - 1
            depth_est = inputs[stage_key]["depth"]
            pseudo_gt = pseudo_depth.unsqueeze(dim=1)
            # Resample pseudo GT and mask to this stage's resolution.
            if stage_idx == 0:
                pseudo_gt_t = F.interpolate(pseudo_gt, scale_factor=(0.25, 0.25))
                filter_mask_t = F.interpolate(filter_mask, scale_factor=(0.25, 0.25))
            elif stage_idx == 1:
                pseudo_gt_t = F.interpolate(pseudo_gt, scale_factor=(0.5, 0.5))
                filter_mask_t = F.interpolate(filter_mask, scale_factor=(0.5, 0.5))
            else:
                pseudo_gt_t = pseudo_gt
                filter_mask_t = filter_mask
            filter_mask_t = filter_mask_t[:, 0, :, :]
            pseudo_gt_t = pseudo_gt_t.squeeze(dim=1)
            mask = filter_mask_t > 0.5
            if mask.any():
                depth_loss = F.smooth_l1_loss(depth_est[mask], pseudo_gt_t[mask], reduction='mean')
            else:
                # BUGFIX: smooth_l1_loss over zero elements with 'mean' yields
                # NaN and poisons the total; use 0 like SCCLossMultiStage does.
                depth_loss = torch.tensor(0.0, dtype=torch.float32,
                                          device=pseudo_depth.device)
            if depth_loss_weights is not None:
                total_loss = total_loss + depth_loss_weights[stage_idx] * depth_loss
            else:
                total_loss = total_loss + 1.0 * depth_loss
            scalar_outputs["mva_loss_stage{}".format(stage_idx + 1)] = depth_loss
        # Epoch-dependent ramp of the ICC weight (helper from tools).
        w_icc = adjust_w_icc(epoch_idx, self.args.w_icc, self.args.max_w_icc)
        total_loss = total_loss * w_icc
        return total_loss, scalar_outputs
class SCCLossMultiStage(nn.Module):
    """Self-consistency loss: supervise each stage of output3 with the
    detached output1 depth, restricted to pixels whose (detached) output1
    photometric confidence exceeds args.mask_conf. Total is scaled by
    args.w_scc. Returns (weighted_total, per-stage scalars).
    """

    def __init__(self, args, **kwargs):
        super(SCCLossMultiStage, self).__init__()
        self.name = "sccloss"
        self.conf = args.mask_conf
        self.args = args

    def forward(self, data, outputs, epoch_idx, **kwargs):
        depth_loss_weights = self.args.dlossw
        total_loss = torch.tensor(0.0, dtype=torch.float32,
                                  device=data["center_imgs"].device, requires_grad=False)
        scalar_outputs = {}
        # Detached: output1 acts as a fixed pseudo ground truth.
        photometric_confidence = outputs["output1"]["photometric_confidence"].clone().detach()
        pseudo_depth = outputs["output1"]["depth"].clone().detach()
        output3 = outputs["output3"]
        w_scc = self.args.w_scc
        for stage_key in [k for k in output3.keys() if "stage" in k]:
            stage_idx = int(stage_key.replace("stage", "")) - 1  # 0, 1, 2
            pseudo_gt = pseudo_depth.unsqueeze(dim=1)
            confidence = photometric_confidence.unsqueeze(dim=1)
            # Resample pseudo GT and confidence to this stage's resolution.
            if stage_idx == 0:
                pseudo_gt_t = F.interpolate(pseudo_gt, scale_factor=(0.25, 0.25))
                confidence_t = F.interpolate(confidence, scale_factor=(0.25, 0.25))
            elif stage_idx == 1:
                pseudo_gt_t = F.interpolate(pseudo_gt, scale_factor=(0.5, 0.5))
                confidence_t = F.interpolate(confidence, scale_factor=(0.5, 0.5))
            else:
                pseudo_gt_t = pseudo_gt
                confidence_t = confidence
            mask_t = (confidence_t > self.conf).squeeze(dim=1)
            pseudo_gt_t = pseudo_gt_t.squeeze(dim=1)
            if torch.sum(mask_t.type(torch.float32)) == 0:
                # No confident pixel: 0 instead of a NaN 'mean' over nothing.
                depth_loss = torch.tensor(0.0, dtype=torch.float32,
                                          device=data["center_imgs"].device)
            else:
                depth_loss = F.smooth_l1_loss(pseudo_gt_t[mask_t],
                                              output3[stage_key]["depth"][mask_t],
                                              reduction='mean')
            if depth_loss_weights is not None:
                total_loss = total_loss + depth_loss_weights[stage_idx] * depth_loss
            else:
                total_loss = total_loss + 1.0 * depth_loss
            # CONSISTENCY FIX: scalar_outputs was returned but never filled;
            # log per-stage losses like the other multi-stage losses do.
            scalar_outputs["scc_loss_stage{}".format(stage_idx + 1)] = depth_loss
        total_loss = total_loss * w_scc
        return total_loss, scalar_outputs