ibaiGorordo commited on
Commit
7e908f7
·
1 Parent(s): aa3e7f3

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +696 -0
model.py ADDED
@@ -0,0 +1,696 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as models
5
+ from torch.jit import script
6
+
7
+
8
class WSConv2d(nn.Conv2d):
    """``nn.Conv2d`` variant that standardizes its kernel on every forward pass.

    For each output channel the kernel is shifted to zero mean and divided by
    its standard deviation (plus a small epsilon) before the convolution is
    evaluated. The learnable parameters are the ordinary ``weight`` / ``bias``
    of ``nn.Conv2d``; only the weight used for the convolution is standardized.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, **kwargs):
        """Build the underlying ``nn.Conv2d``.

        The positional parameters mirror ``nn.Conv2d``. Any extra keyword
        arguments (for example ``padding_mode``) are forwarded unchanged so
        that calls which worked against the inherited constructor keep working.
        """
        # NOTE: this method used to be spelled ``__init___`` (three trailing
        # underscores), which made it dead code that was never invoked.
        super(WSConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                       padding, dilation, groups, bias, **kwargs)

    def forward(self, x):
        """Convolve ``x`` with the standardized kernel and return the result."""
        weight = self.weight
        # Per-output-channel mean over the (in_channels, kH, kW) axes.
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        # Per-output-channel std; the epsilon guards against division by zero.
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        # NOTE(review): F.conv2d is called with self.padding directly, so a
        # non-default padding_mode would not be honoured here - confirm if needed.
        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
22
+
23
+
24
def conv_ws(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    """Return a ``WSConv2d`` layer configured with the given convolution settings."""
    options = dict(stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
    return WSConv2d(in_channels, out_channels, kernel_size, **options)
27
+
28
+
29
+ '''
30
+ class Mish(nn.Module):
31
+ def __init__(self):
32
+ super(Mish, self).__init__()
33
+ def forward(self, x):
34
+ return x*torch.tanh(F.softplus(x))
35
+ '''
36
+
37
+
38
@script
def _mish_jit_fwd(x):
    """TorchScript kernel for Mish: ``x * tanh(softplus(x))``."""
    return x.mul(torch.tanh(F.softplus(x)))


@script
def _mish_jit_bwd(x, grad_output):
    """TorchScript kernel for the Mish gradient.

    Computes ``grad_output * (tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2))``
    where ``sp = softplus(x)``.
    """
    x_sigmoid = torch.sigmoid(x)
    x_tanh_sp = F.softplus(x).tanh()
    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))


class MishJitAutoFn(torch.autograd.Function):
    """Mish activation as a custom autograd function.

    Only the input tensor is saved in ``forward``; ``backward`` recomputes the
    gradient from it with the scripted kernel.
    """

    @staticmethod
    def forward(ctx, x):
        """Save ``x`` for the backward pass and return ``mish(x)``."""
        ctx.save_for_backward(x)
        return _mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the loss with respect to the forward input."""
        # ``ctx.saved_variables`` is deprecated; ``saved_tensors`` is the
        # supported accessor for tensors stored via ``save_for_backward``.
        x = ctx.saved_tensors[0]
        return _mish_jit_bwd(x, grad_output)
59
+
60
+
61
+ # Cell
62
def mish(x):
    """Functional form of Mish: run ``x`` through ``MishJitAutoFn``."""
    activated = MishJitAutoFn.apply(x)
    return activated
63
+
64
+
65
class Mish(nn.Module):
    """Module wrapper around the Mish activation implemented by ``MishJitAutoFn``."""

    def __init__(self, inplace: bool = False):
        # ``inplace`` is accepted but not used anywhere in this class.
        super().__init__()

    def forward(self, x):
        """Return the Mish activation of ``x``."""
        activated = MishJitAutoFn.apply(x)
        return activated
71
+
72
+
73
+ ######################################################################################################################
74
+ ######################################################################################################################
75
+
76
+ # pre-activation based upsampling conv block
77
class upConvLayer(nn.Module):
    """Pre-activation upsampling block: norm -> activation -> bilinear resize -> 3x3 WS conv.

    Args:
        in_channels: channels of the incoming feature map (also the width of the norm layer).
        out_channels: channels produced by the 3x3 convolution.
        scale_factor: factor passed to ``F.interpolate``.
        norm: ``'GN'`` selects GroupNorm; anything else selects BatchNorm2d.
        act: ``'ELU'`` or ``'Mish'``; any other value selects in-place ReLU.
        num_groups: group count, used only for GroupNorm.
    """

    def __init__(self, in_channels, out_channels, scale_factor, norm, act, num_groups):
        super().__init__()
        activation = nn.ELU() if act == 'ELU' else Mish() if act == 'Mish' else nn.ReLU(True)
        self.conv = conv_ws(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        if norm == 'GN':
            normalizer = nn.GroupNorm(num_groups, in_channels)
        else:
            normalizer = nn.BatchNorm2d(in_channels, eps=0.001, momentum=0.1, affine=True,
                                        track_running_stats=True)
        self.norm = normalizer
        self.act = activation
        self.scale_factor = scale_factor

    def forward(self, x):
        """Normalize and activate ``x`` first, then upsample and convolve."""
        activated = self.act(self.norm(x))  # pre-activation
        resized = F.interpolate(activated, scale_factor=self.scale_factor, mode='bilinear')
        return self.conv(resized)
102
+
103
+
104
+ # pre-activation based conv block
105
class myConv(nn.Module):
    """Pre-activation conv block: norm -> activation -> weight-standardized conv.

    Args:
        in_ch: input channels (also the width of the norm layer).
        out_ch: output channels of the convolution.
        kSize: convolution kernel size.
        stride, padding, dilation, bias: forwarded to the convolution.
        norm: ``'GN'`` selects GroupNorm; anything else selects BatchNorm2d.
        act: ``'ELU'`` or ``'Mish'``; any other value selects in-place ReLU.
        num_groups: group count, used only for GroupNorm.
    """

    def __init__(self, in_ch, out_ch, kSize, stride=1,
                 padding=0, dilation=1, bias=True, norm='GN', act='ELU', num_groups=32):
        super().__init__()
        activation = nn.ELU() if act == 'ELU' else Mish() if act == 'Mish' else nn.ReLU(True)
        if norm == 'GN':
            normalizer = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch)
        else:
            normalizer = nn.BatchNorm2d(in_ch, eps=0.001, momentum=0.1, affine=True,
                                        track_running_stats=True)
        convolution = conv_ws(in_ch, out_ch, kernel_size=kSize, stride=stride,
                              padding=padding, dilation=dilation, groups=1, bias=bias)
        # Kept as a Sequential named ``module`` so parameter names stay the same.
        self.module = nn.Sequential(normalizer, activation, convolution)

    def forward(self, x):
        """Apply the norm/activation/conv stack to ``x``."""
        return self.module(x)
129
+
130
+
131
+ # Deep Feature Extractor
132
class deepFeatureExtractor_ResNext101(nn.Module):
    """ResNeXt-101 (32x8d) encoder that returns a list of intermediate feature maps.

    Args:
        args: stored on the instance as ``self.args``; not otherwise used here.
        lv6: when exactly ``True`` the ``layer4`` stage is kept and its output is
            also returned; otherwise ``layer4`` and ``fc`` are removed from the encoder.
    """

    def __init__(self, args, lv6=False):
        super().__init__()
        self.args = args
        # after passing ReLU   : H/2  x W/2
        # after passing Layer1 : H/4  x W/4
        # after passing Layer2 : H/8  x W/8
        # after passing Layer3 : H/16 x W/16
        self.encoder = models.resnext101_32x8d(weights=models.ResNeXt101_32X8D_Weights.DEFAULT)
        # Parameters whose name contains one of these substrings are frozen.
        self.fixList = ['layer1.0', 'layer1.1', '.bn']
        self.lv6 = lv6

        self.layerList = ['relu', 'layer1', 'layer2', 'layer3']
        self.dimList = [64, 256, 512, 1024]
        if lv6 is True:
            self.layerList.append('layer4')
            self.dimList.append(2048)
        else:
            del self.encoder.layer4
            del self.encoder.fc

        for param_name, param in self.encoder.named_parameters():
            is_stem_conv = param_name == 'conv1.weight'
            if is_stem_conv or any(tag in param_name for tag in self.fixList):
                param.requires_grad = False

    def forward(self, x):
        """Run the encoder stage by stage and collect the outputs named in ``layerList``."""
        collected = []
        current = x
        for stage_name, stage in self.encoder._modules.items():
            # Everything from the global average pool onwards is not used.
            if stage_name == 'avgpool':
                break
            current = stage(current)
            if any(tag in stage_name for tag in self.layerList):
                collected.append(current)
        return collected

    def freeze_bn(self, enable=False):
        """ Adapted from https://discuss.pytorch.org/t/how-to-train-with-frozen-batchnorm/12106/8 """
        for module in self.modules():
            if not isinstance(module, nn.BatchNorm2d):
                continue
            if enable:
                module.train()
            else:
                module.eval()
            module.weight.requires_grad = enable
            module.bias.requires_grad = enable
180
+
181
+
182
+ # ASPP Module
183
class Dilated_bottleNeck(nn.Module):
    """Densely connected ASPP block with dilation rates 3, 6, 12 and 18.

    The input is first reduced to ``in_feat // 2`` channels by a 1x1 conv.
    Each branch then consumes the reduced features concatenated with all
    earlier branch outputs and emits ``in_feat // 4`` channels. A final 3x3
    conv maps the full concatenation back to ``in_feat // 2`` channels.
    """

    def __init__(self, norm, act, in_feat):
        super().__init__()
        half = in_feat // 2
        quarter = in_feat // 4

        def aspp_branch(in_ch, rate):
            # 1x1 squeeze to `quarter` channels, then a 3x3 conv dilated by `rate`.
            squeeze = myConv(in_ch, quarter, kSize=1, stride=1, padding=0, dilation=1, bias=False,
                             norm=norm, act=act, num_groups=in_ch // 16)
            dilate = myConv(quarter, quarter, kSize=3, stride=1, padding=rate, dilation=rate, bias=False,
                            norm=norm, act=act, num_groups=quarter // 16)
            return nn.Sequential(squeeze, dilate)

        # in feat = 1024 in ResNext101 and ResNet101
        self.reduction1 = conv_ws(in_feat, half, kernel_size=1, stride=1, bias=False, padding=0)
        self.aspp_d3 = aspp_branch(half, 3)
        self.aspp_d6 = aspp_branch(half + quarter, 6)
        self.aspp_d12 = aspp_branch(in_feat, 12)
        self.aspp_d18 = aspp_branch(in_feat + quarter, 18)
        fused = quarter * 4 + half
        self.reduction2 = myConv(fused, half, kSize=3, stride=1, padding=1,
                                 bias=False, norm=norm, act=act, num_groups=fused // 16)

    def forward(self, x):
        """Return the fused ASPP features (``in_feat // 2`` channels, same spatial size)."""
        context = self.reduction1(x)
        for branch in (self.aspp_d3, self.aspp_d6, self.aspp_d12, self.aspp_d18):
            # Each branch sees the reduced input plus every previous branch output.
            context = torch.cat([context, branch(context)], dim=1)
        return self.reduction2(context)  # 512 x H/16 x W/16
223
+
224
+
225
class Dilated_bottleNeck2(nn.Module):
    """Densely connected ASPP block with dilation rates 3, 6, 12, 18 and 24.

    Same layout as ``Dilated_bottleNeck`` except that the first reduction is a
    3x3 conv and there is one extra branch (rate 24).
    """

    def __init__(self, norm, act, in_feat):
        super().__init__()
        half = in_feat // 2
        quarter = in_feat // 4

        def aspp_branch(in_ch, rate):
            # 1x1 squeeze to `quarter` channels, then a 3x3 conv dilated by `rate`.
            squeeze = myConv(in_ch, quarter, kSize=1, stride=1, padding=0, dilation=1, bias=False,
                             norm=norm, act=act, num_groups=in_ch // 16)
            dilate = myConv(quarter, quarter, kSize=3, stride=1, padding=rate, dilation=rate, bias=False,
                            norm=norm, act=act, num_groups=quarter // 16)
            return nn.Sequential(squeeze, dilate)

        # in feat = 1024 in ResNext101 and ResNet101
        self.reduction1 = conv_ws(in_feat, half, kernel_size=3, stride=1, padding=1, bias=False)
        self.aspp_d3 = aspp_branch(half, 3)
        self.aspp_d6 = aspp_branch(half + quarter, 6)
        self.aspp_d12 = aspp_branch(in_feat, 12)
        self.aspp_d18 = aspp_branch(in_feat + quarter, 18)
        self.aspp_d24 = aspp_branch(in_feat + half, 24)
        fused = quarter * 5 + half
        self.reduction2 = myConv(fused, half, kSize=3, stride=1, padding=1,
                                 bias=False, norm=norm, act=act, num_groups=fused // 16)

    def forward(self, x):
        """Return the fused ASPP features (``in_feat // 2`` channels, same spatial size)."""
        context = self.reduction1(x)
        branches = (self.aspp_d3, self.aspp_d6, self.aspp_d12, self.aspp_d18, self.aspp_d24)
        for branch in branches:
            # Each branch sees the reduced input plus every previous branch output.
            context = torch.cat([context, branch(context)], dim=1)
        return self.reduction2(context)  # 512 x H/16 x W/16
273
+
274
+
275
class Dilated_bottleNeck_lv6(nn.Module):
    """ASPP block for the six-level decoder (dilation rates 3, 6, 12 and 18).

    All internal widths are derived from ``in_feat // 2``. With ``base`` as
    that halved value, the input (``base * 2`` channels) is reduced to
    ``base // 2`` by a pre-activation 3x3 conv, each branch emits ``base // 4``
    channels, and the final 3x3 conv outputs ``base`` channels.
    """

    def __init__(self, norm, act, in_feat):
        super().__init__()
        base = in_feat // 2
        half = base // 2
        quarter = base // 4

        def aspp_branch(in_ch, rate):
            # 1x1 squeeze to `quarter` channels, then a 3x3 conv dilated by `rate`.
            squeeze = myConv(in_ch, quarter, kSize=1, stride=1, padding=0, dilation=1, bias=False,
                             norm=norm, act=act, num_groups=in_ch // 16)
            dilate = myConv(quarter, quarter, kSize=3, stride=1, padding=rate, dilation=rate, bias=False,
                            norm=norm, act=act, num_groups=quarter // 16)
            return nn.Sequential(squeeze, dilate)

        self.reduction1 = myConv(base * 2, half, kSize=3, stride=1, padding=1, bias=False, norm=norm,
                                 act=act, num_groups=base // 16)
        self.aspp_d3 = aspp_branch(half, 3)
        self.aspp_d6 = aspp_branch(half + quarter, 6)
        self.aspp_d12 = aspp_branch(base, 12)
        self.aspp_d18 = aspp_branch(base + quarter, 18)
        fused = quarter * 4 + half
        self.reduction2 = myConv(fused, base, kSize=3, stride=1, padding=1,
                                 bias=False, norm=norm, act=act, num_groups=fused // 16)

    def forward(self, x):
        """Return the fused ASPP features at the input's spatial size."""
        context = self.reduction1(x)
        for branch in (self.aspp_d3, self.aspp_d6, self.aspp_d12, self.aspp_d18):
            # Each branch sees the reduced input plus every previous branch output.
            context = torch.cat([context, branch(context)], dim=1)
        return self.reduction2(context)
316
+
317
+
318
+ # Laplacian Decoder Network
319
class Lap_decoder_lv5(nn.Module):
    """Five-level Laplacian-pyramid depth decoder.

    The coarsest level predicts a sigmoid map from the ASPP features. Each
    finer level upsamples the previous decoder features, fuses them with an
    encoder skip feature, the upsampled previous prediction and an RGB
    residual, and predicts a tanh-bounded residual map. The residuals are then
    summed from coarse to fine to rebuild the final depth map.
    """

    def __init__(self, args, dimList):
        """Build all decoder stages.

        Args:
            args: namespace read for ``norm``, ``rank``, ``act`` and ``max_depth``.
            dimList: encoder channel widths; indices 0-3 are read here.
        """
        super(Lap_decoder_lv5, self).__init__()
        norm = args.norm
        conv = conv_ws  # not used below in this class
        # Only rank 0 reports which normalization is in use.
        if norm == 'GN':
            if args.rank == 0:
                print("==> Norm: GN")
        else:
            if args.rank == 0:
                print("==> Norm: BN")

        # Any activation name other than 'ELU' / 'Mish' is treated as 'ReLU'.
        if args.act == 'ELU':
            act = 'ELU'
        elif args.act == 'Mish':
            act = 'Mish'
        else:
            act = 'ReLU'
        kSize = 3
        self.max_depth = args.max_depth
        self.ASPP = Dilated_bottleNeck(norm, act, dimList[3])
        self.dimList = dimList
        ############################################ Pyramid Level 5 ###################################################
        # decoder1 out : 1 x H/16 x W/16 (Level 5)
        self.decoder1 = nn.Sequential(
            myConv(dimList[3] // 2, dimList[3] // 4, kSize, stride=1, padding=kSize // 2, bias=False,
                   norm=norm, act=act, num_groups=(dimList[3] // 2) // 16),
            myConv(dimList[3] // 4, dimList[3] // 8, kSize, stride=1, padding=kSize // 2, bias=False,
                   norm=norm, act=act, num_groups=(dimList[3] // 4) // 16),
            myConv(dimList[3] // 8, dimList[3] // 16, kSize, stride=1, padding=kSize // 2, bias=False,
                   norm=norm, act=act, num_groups=(dimList[3] // 8) // 16),
            myConv(dimList[3] // 16, dimList[3] // 32, kSize, stride=1, padding=kSize // 2, bias=False,
                   norm=norm, act=act, num_groups=(dimList[3] // 16) // 16),
            myConv(dimList[3] // 32, 1, kSize, stride=1, padding=kSize // 2, bias=False,
                   norm=norm, act=act, num_groups=(dimList[3] // 32) // 16)
        )
        ########################################################################################################################

        ############################################ Pyramid Level 4 ###################################################
        # decoder2 out : 1 x H/8 x W/8 (Level 4)
        # decoder2_up : (H/16,W/16)->(H/8,W/8)
        self.decoder2_up1 = upConvLayer(dimList[3] // 2, dimList[3] // 4, 2, norm, act, (dimList[3] // 2) // 16)
        # The reduction emits 4 channels fewer than decoder2_1 expects; forward()
        # fills the gap by concatenating the upsampled prediction and the RGB residual
        # (presumably 1 + 3 channels - TODO confirm against the caller).
        self.decoder2_reduc1 = myConv(dimList[3] // 4 + dimList[2], dimList[3] // 4 - 4, kSize=1, stride=1, padding=0,
                                      bias=False,
                                      norm=norm, act=act, num_groups=(dimList[3] // 4 + dimList[2]) // 16)
        self.decoder2_1 = myConv(dimList[3] // 4, dimList[3] // 4, kSize, stride=1, padding=kSize // 2, bias=False,
                                 norm=norm, act=act, num_groups=(dimList[3] // 4) // 16)

        self.decoder2_2 = myConv(dimList[3] // 4, dimList[3] // 8, kSize, stride=1, padding=kSize // 2, bias=False,
                                 norm=norm, act=act, num_groups=(dimList[3] // 4) // 16)
        self.decoder2_3 = myConv(dimList[3] // 8, dimList[3] // 16, kSize, stride=1, padding=kSize // 2, bias=False,
                                 norm=norm, act=act, num_groups=(dimList[3] // 8) // 16)

        self.decoder2_4 = myConv(dimList[3] // 16, 1, kSize, stride=1, padding=kSize // 2, bias=False,
                                 norm=norm, act=act, num_groups=(dimList[3] // 16) // 16)
        ########################################################################################################################

        ############################################ Pyramid Level 3 ###################################################
        # decoder2 out2 : 1 x H/4 x W/4 (Level 3)
        # decoder2_1_up2 : (H/8,W/8)->(H/4,W/4)
        self.decoder2_1_up2 = upConvLayer(dimList[3] // 4, dimList[3] // 8, 2, norm, act, (dimList[3] // 4) // 16)
        self.decoder2_1_reduc2 = myConv(dimList[3] // 8 + dimList[1], dimList[3] // 8 - 4, kSize=1, stride=1, padding=0,
                                        bias=False,
                                        norm=norm, act=act, num_groups=(dimList[3] // 8 + dimList[1]) // 16)
        self.decoder2_1_1 = myConv(dimList[3] // 8, dimList[3] // 8, kSize, stride=1, padding=kSize // 2, bias=False,
                                   norm=norm, act=act, num_groups=(dimList[3] // 8) // 16)

        self.decoder2_1_2 = myConv(dimList[3] // 8, dimList[3] // 16, kSize, stride=1, padding=kSize // 2, bias=False,
                                   norm=norm, act=act, num_groups=(dimList[3] // 8) // 16)

        self.decoder2_1_3 = myConv(dimList[3] // 16, 1, kSize, stride=1, padding=kSize // 2, bias=False,
                                   norm=norm, act=act, num_groups=(dimList[3] // 16) // 16)
        ########################################################################################################################

        ############################################ Pyramid Level 2 ###################################################
        # decoder2 out3 : 1 x H/2 x W/2 (Level 2)
        # decoder2_1_1_up3 : (H/4,W/4)->(H/2,W/2)
        self.decoder2_1_1_up3 = upConvLayer(dimList[3] // 8, dimList[3] // 16, 2, norm, act, (dimList[3] // 8) // 16)
        self.decoder2_1_1_reduc3 = myConv(dimList[3] // 16 + dimList[0], dimList[3] // 16 - 4, kSize=1, stride=1,
                                          padding=0, bias=False,
                                          norm=norm, act=act, num_groups=(dimList[3] // 16 + dimList[0]) // 16)
        self.decoder2_1_1_1 = myConv(dimList[3] // 16, dimList[3] // 16, kSize, stride=1, padding=kSize // 2,
                                     bias=False,
                                     norm=norm, act=act, num_groups=(dimList[3] // 16) // 16)

        self.decoder2_1_1_2 = myConv(dimList[3] // 16, 1, kSize, stride=1, padding=kSize // 2, bias=False,
                                     norm=norm, act=act, num_groups=(dimList[3] // 16) // 16)
        ########################################################################################################################

        ############################################ Pyramid Level 1 ###################################################
        # decoder5 out : 1 x H x W (Level 1)
        # decoder2_1_1_1_up4 : (H/2,W/2)->(H,W)
        # No encoder skip at this level: the up-conv itself leaves room for the
        # 4 channels that forward() concatenates.
        self.decoder2_1_1_1_up4 = upConvLayer(dimList[3] // 16, dimList[3] // 16 - 4, 2, norm, act,
                                              (dimList[3] // 16) // 16)
        self.decoder2_1_1_1_1 = myConv(dimList[3] // 16, dimList[3] // 16, kSize, stride=1, padding=kSize // 2,
                                       bias=False,
                                       norm=norm, act=act, num_groups=(dimList[3] // 16) // 16)

        self.decoder2_1_1_1_2 = myConv(dimList[3] // 16, dimList[3] // 32, kSize, stride=1, padding=kSize // 2,
                                       bias=False,
                                       norm=norm, act=act, num_groups=(dimList[3] // 16) // 16)
        self.decoder2_1_1_1_3 = myConv(dimList[3] // 32, 1, kSize, stride=1, padding=kSize // 2, bias=False,
                                       norm=norm, act=act, num_groups=(dimList[3] // 32) // 16)
        ########################################################################################################################
        self.upscale = F.interpolate

    def forward(self, x, rgb):
        """Decode encoder features into per-level maps and the final depth.

        Args:
            x: sequence of four encoder feature maps, finest first; ``x[3]`` feeds the ASPP.
            rgb: sequence of six per-level RGB inputs; only indices 2-5 are used here.
                Presumably Laplacian residuals of the input image - TODO confirm with the caller.

        Returns:
            A pair ``(levels, depth)``: ``levels`` is the list
            ``[lap_lv5, lap_lv4, lap_lv3, lap_lv2, lap_lv1]`` and ``depth`` is the
            sigmoid of the reconstructed pyramid, all scaled by ``self.max_depth``.
        """
        cat1, cat2, cat3, dense_feat = x[0], x[1], x[2], x[3]
        # rgb_lv6 and rgb_lv5 are unpacked but not used by this five-level decoder.
        rgb_lv6, rgb_lv5, rgb_lv4, rgb_lv3, rgb_lv2, rgb_lv1 = rgb[0], rgb[1], rgb[2], rgb[3], rgb[4], rgb[5]
        dense_feat = self.ASPP(dense_feat)  # Dense feature for lev 5
        # decoder 1 - Pyramid level 5
        lap_lv5 = torch.sigmoid(self.decoder1(dense_feat))
        lap_lv5_up = self.upscale(lap_lv5, scale_factor=2, mode='bilinear')

        # decoder 2 - Pyramid level 4
        dec2 = self.decoder2_up1(dense_feat)
        dec2 = self.decoder2_reduc1(torch.cat([dec2, cat3], dim=1))
        # dec2_up is kept separately because the next level upsamples from it.
        dec2_up = self.decoder2_1(torch.cat([dec2, lap_lv5_up, rgb_lv4], dim=1))
        dec2 = self.decoder2_2(dec2_up)
        dec2 = self.decoder2_3(dec2)
        # The channel-mean of the RGB input, scaled by 0.1, is added before the tanh.
        lap_lv4 = torch.tanh(self.decoder2_4(dec2) + (0.1 * rgb_lv4.mean(dim=1, keepdim=True)))
        # if depth range is (0,1), laplacian of image range is (-1,1)
        lap_lv4_up = self.upscale(lap_lv4, scale_factor=2, mode='bilinear')
        # decoder 2 - Pyramid level 3
        dec3 = self.decoder2_1_up2(dec2_up)
        dec3 = self.decoder2_1_reduc2(torch.cat([dec3, cat2], dim=1))
        dec3_up = self.decoder2_1_1(torch.cat([dec3, lap_lv4_up, rgb_lv3], dim=1))
        dec3 = self.decoder2_1_2(dec3_up)
        lap_lv3 = torch.tanh(self.decoder2_1_3(dec3) + (0.1 * rgb_lv3.mean(dim=1, keepdim=True)))
        # if depth range is (0,1), laplacian of image range is (-1,1)
        lap_lv3_up = self.upscale(lap_lv3, scale_factor=2, mode='bilinear')
        # decoder 2 - Pyramid level 2
        dec4 = self.decoder2_1_1_up3(dec3_up)
        dec4 = self.decoder2_1_1_reduc3(torch.cat([dec4, cat1], dim=1))
        dec4_up = self.decoder2_1_1_1(torch.cat([dec4, lap_lv3_up, rgb_lv2], dim=1))

        lap_lv2 = torch.tanh(self.decoder2_1_1_2(dec4_up) + (0.1 * rgb_lv2.mean(dim=1, keepdim=True)))
        # if depth range is (0,1), laplacian of image range is (-1,1)
        lap_lv2_up = self.upscale(lap_lv2, scale_factor=2, mode='bilinear')
        # decoder 2 - Pyramid level 1
        dec5 = self.decoder2_1_1_1_up4(dec4_up)
        dec5 = self.decoder2_1_1_1_1(torch.cat([dec5, lap_lv2_up, rgb_lv1], dim=1))
        dec5 = self.decoder2_1_1_1_2(dec5)
        lap_lv1 = torch.tanh(self.decoder2_1_1_1_3(dec5) + (0.1 * rgb_lv1.mean(dim=1, keepdim=True)))
        # if depth range is (0,1), laplacian of image range is (-1,1)

        # Laplacian restoration
        # Coarse to fine: upsample the running sum by 2 and add the next level's map.
        lap_lv4_img = lap_lv4 + lap_lv5_up
        lap_lv3_img = lap_lv3 + self.upscale(lap_lv4_img, scale_factor=2, mode='bilinear')
        lap_lv2_img = lap_lv2 + self.upscale(lap_lv3_img, scale_factor=2, mode='bilinear')
        final_depth = lap_lv1 + self.upscale(lap_lv2_img, scale_factor=2, mode='bilinear')
        final_depth = torch.sigmoid(final_depth)
        return [(lap_lv5) * self.max_depth, (lap_lv4) * self.max_depth, (lap_lv3) * self.max_depth,
                (lap_lv2) * self.max_depth, (lap_lv1) * self.max_depth], final_depth * self.max_depth
        # fit laplacian image range (-80,80), depth image range(0,80)
474
+
475
+
476
+ class Lap_decoder_lv6(nn.Module):
477
+ def __init__(self, args, dimList):
478
+ super(Lap_decoder_lv6, self).__init__()
479
+ norm = args.norm
480
+ conv = conv_ws
481
+ if norm == 'GN':
482
+ if args.rank == 0:
483
+ print("==> Norm: GN")
484
+ else:
485
+ if args.rank == 0:
486
+ print("==> Norm: BN")
487
+
488
+ if args.act == 'ELU':
489
+ act = 'ELU'
490
+ elif args.act == 'Mish':
491
+ act = 'Mish'
492
+ else:
493
+ act = 'ReLU'
494
+ kSize = 3
495
+ self.max_depth = args.max_depth
496
+ self.ASPP = Dilated_bottleNeck_lv6(norm, act, dimList[4])
497
+ dimList[4] = dimList[4] // 2
498
+ self.dimList = dimList
499
+ ############################################ Pyramid Level 6 ###################################################
500
+ # decoder1 out : 1 x H/32 x W/32 (Level 6)
501
+ self.decoder1 = nn.Sequential(
502
+ myConv(dimList[4] // 2, dimList[4] // 4, kSize, stride=1, padding=kSize // 2, bias=False,
503
+ norm=norm, act=act, num_groups=(dimList[4] // 2) // 16),
504
+ myConv(dimList[4] // 4, dimList[4] // 8, kSize, stride=1, padding=kSize // 2, bias=False,
505
+ norm=norm, act=act, num_groups=(dimList[4] // 4) // 16),
506
+ myConv(dimList[4] // 8, dimList[4] // 16, kSize, stride=1, padding=kSize // 2, bias=False,
507
+ norm=norm, act=act, num_groups=(dimList[4] // 8) // 16),
508
+ myConv(dimList[4] // 16, dimList[4] // 32, kSize, stride=1, padding=kSize // 2, bias=False,
509
+ norm=norm, act=act, num_groups=(dimList[4] // 16) // 16),
510
+ myConv(dimList[4] // 32, 1, kSize, stride=1, padding=kSize // 2, bias=False,
511
+ norm=norm, act=act, num_groups=(dimList[4] // 32) // 8)
512
+ )
513
+ ########################################################################################################################
514
+
515
+ ############################################ Pyramid Level 5 ###################################################
516
+ # decoder2 out : 1 x H/16 x W/16 (Level 5)
517
+ # decoder2_up : (H/32,W/32)->(H/16,W/16)
518
+ self.decoder2_up1 = upConvLayer(dimList[4] // 2, dimList[4] // 4, 2, norm, act, (dimList[4] // 2) // 16)
519
+ self.decoder2_reduc1 = myConv(dimList[4] // 4 + dimList[3], dimList[4] // 4 - 4, kSize=1, stride=1, padding=0,
520
+ bias=False,
521
+ norm=norm, act=act, num_groups=(dimList[4] // 4 + dimList[3]) // 16)
522
+ self.decoder2_1 = myConv(dimList[4] // 4, dimList[4] // 4, kSize, stride=1, padding=kSize // 2, bias=False,
523
+ norm=norm, act=act, num_groups=(dimList[4] // 4) // 16)
524
+
525
+ self.decoder2_2 = myConv(dimList[4] // 4, dimList[4] // 8, kSize, stride=1, padding=kSize // 2, bias=False,
526
+ norm=norm, act=act, num_groups=(dimList[4] // 4) // 16)
527
+ self.decoder2_3 = myConv(dimList[4] // 8, dimList[4] // 16, kSize, stride=1, padding=kSize // 2, bias=False,
528
+ norm=norm, act=act, num_groups=(dimList[4] // 8) // 16)
529
+
530
+ self.decoder2_4 = myConv(dimList[4] // 16, 1, kSize, stride=1, padding=kSize // 2, bias=False,
531
+ norm=norm, act=act, num_groups=(dimList[4] // 16) // 16)
532
+ ########################################################################################################################
533
+
534
+ ############################################ Pyramid Level 4 ###################################################
535
+ # decoder2 out2 : 1 x H/8 x W/8 (Level 4)
536
+ # decoder2_1_up2 : (H/16,W/16)->(H/8,W/8)
537
+ self.decoder2_1_up2 = upConvLayer(dimList[4] // 4, dimList[4] // 8, 2, norm, act, (dimList[4] // 4) // 16)
538
+ self.decoder2_1_reduc2 = myConv(dimList[4] // 8 + dimList[2], dimList[4] // 8 - 4, kSize=1, stride=1, padding=0,
539
+ bias=False,
540
+ norm=norm, act=act, num_groups=(dimList[4] // 8 + dimList[2]) // 16)
541
+ self.decoder2_1_1 = myConv(dimList[4] // 8, dimList[4] // 8, kSize, stride=1, padding=kSize // 2, bias=False,
542
+ norm=norm, act=act, num_groups=(dimList[4] // 8) // 16)
543
+
544
+ self.decoder2_1_2 = myConv(dimList[4] // 8, dimList[4] // 16, kSize, stride=1, padding=kSize // 2, bias=False,
545
+ norm=norm, act=act, num_groups=(dimList[4] // 8) // 16)
546
+
547
+ self.decoder2_1_3 = myConv(dimList[4] // 16, 1, kSize, stride=1, padding=kSize // 2, bias=False,
548
+ norm=norm, act=act, num_groups=(dimList[4] // 16) // 16)
549
+ ########################################################################################################################
550
+
551
+ ############################################ Pyramid Level 3 ###################################################
552
+ # decoder2 out3 : 1 x H/4 x W/4 (Level 3)
553
+ # decoder2_1_1_up3 : (H/8,W/8)->(H/4,W/4)
554
+ self.decoder2_1_1_up3 = upConvLayer(dimList[4] // 8, dimList[4] // 16, 2, norm, act, (dimList[4] // 8) // 16)
555
+ self.decoder2_1_1_reduc3 = myConv(dimList[4] // 16 + dimList[1], dimList[4] // 16 - 4, kSize=1, stride=1,
556
+ padding=0, bias=False,
557
+ norm=norm, act=act, num_groups=(dimList[4] // 16 + dimList[1]) // 8)
558
+ self.decoder2_1_1_1 = myConv(dimList[4] // 16, dimList[4] // 16, kSize, stride=1, padding=kSize // 2,
559
+ bias=False,
560
+ norm=norm, act=act, num_groups=(dimList[4] // 16) // 16)
561
+
562
+ self.decoder2_1_1_2 = myConv(dimList[4] // 16, 1, kSize, stride=1, padding=kSize // 2, bias=False,
563
+ norm=norm, act=act, num_groups=(dimList[4] // 16) // 16)
564
+ ########################################################################################################################
565
+
566
+ ############################################ Pyramid Level 2 ###################################################
567
+ # decoder2 out4 : 1 x H/2 x W/2 (Level 2)
568
+ # decoder2_1_1_1_up4 : (H/4,W/4)->(H/2,W/2)
569
+ self.decoder2_1_1_1_up4 = upConvLayer(dimList[4] // 16, dimList[4] // 32, 2, norm, act,
570
+ (dimList[4] // 16) // 16)
571
+ self.decoder2_1_1_1_reduc4 = myConv(dimList[4] // 32 + dimList[0], dimList[4] // 32 - 4, kSize=1, stride=1,
572
+ padding=0, bias=False,
573
+ norm=norm, act=act, num_groups=(dimList[4] // 32 + dimList[0]) // 8)
574
+ self.decoder2_1_1_1_1 = myConv(dimList[4] // 32, dimList[4] // 32, kSize, stride=1, padding=kSize // 2,
575
+ bias=False,
576
+ norm=norm, act=act, num_groups=(dimList[4] // 32) // 8)
577
+
578
+ self.decoder2_1_1_1_2 = myConv(dimList[4] // 32, 1, kSize, stride=1, padding=kSize // 2, bias=False,
579
+ norm=norm, act=act, num_groups=(dimList[4] // 32) // 8)
580
+ ########################################################################################################################
581
+
582
+ ############################################ Pyramid Level 1 ###################################################
583
+ # decoder5 out : 1 x H x W (Level 1)
584
+ # decoder2_1_1_1_1_up5 : (H/2,W/2)->(H,W)
585
+ self.decoder2_1_1_1_1_up5 = upConvLayer(dimList[4] // 32, dimList[4] // 32 - 4, 2, norm, act,
586
+ (dimList[4] // 32) // 8) # H x W (64 -> 60)
587
+ self.decoder2_1_1_1_1_1 = myConv(dimList[4] // 32, dimList[4] // 32, kSize, stride=1, padding=kSize // 2,
588
+ bias=False,
589
+ norm=norm, act=act, num_groups=(dimList[4] // 32) // 8)
590
+
591
+ self.decoder2_1_1_1_1_2 = myConv(dimList[4] // 32, dimList[4] // 64, kSize, stride=1, padding=kSize // 2,
592
+ bias=False,
593
+ norm=norm, act=act, num_groups=(dimList[4] // 32) // 8)
594
+ self.decoder2_1_1_1_1_3 = myConv(dimList[4] // 64, 1, kSize, stride=1, padding=kSize // 2, bias=False,
595
+ norm=norm, act=act, num_groups=(dimList[4] // 64) // 4)
596
+ ########################################################################################################################
597
+ self.upscale = F.interpolate
598
+
599
def forward(self, x, rgb):
    """Decode encoder features into per-level residuals and a depth map.

    NOTE: this is a method of the Laplacian pyramid decoder whose class
    header and ``__init__`` precede this block; ``self`` must provide the
    decoder sub-modules used below plus ``upscale`` and ``max_depth``.

    Args:
        x: indexable holding five encoder outputs. ``x[0]``..``x[3]`` are
            skip features concatenated at pyramid levels 2, 3, 4 and 5
            respectively; ``x[4]`` is the dense feature passed to ASPP.
        rgb: indexable with six entries. ``rgb[1]``..``rgb[5]`` are
            concatenated into levels 5..1 and their channel mean, scaled
            by 0.1, is added to each level's residual before ``tanh``.
            ``rgb[0]`` is not read by this method.

    Returns:
        A pair ``(residuals, depth)``: ``residuals`` is the list
        ``[lv6, lv5, lv4, lv3, lv2, lv1]`` of per-level outputs and
        ``depth`` is the reconstructed map, all multiplied by
        ``self.max_depth``.
    """

    def up2(t):
        # Each pyramid step doubles the spatial resolution bilinearly.
        return self.upscale(t, scale_factor=2, mode='bilinear')

    def rgb_bias(img):
        # Channel-averaged guidance image, damped before it joins the residual.
        return 0.1 * img.mean(dim=1, keepdim=True)

    cat1, cat2, cat3, cat4, dense_feat = x[0], x[1], x[2], x[3], x[4]
    rgb_lv5, rgb_lv4, rgb_lv3, rgb_lv2, rgb_lv1 = rgb[1], rgb[2], rgb[3], rgb[4], rgb[5]

    # Pyramid level 6 (coarsest): sigmoid keeps this base prediction in (0, 1).
    dense_feat = self.ASPP(dense_feat)
    lap_lv6 = torch.sigmoid(self.decoder1(dense_feat))
    lap_lv6_up = up2(lap_lv6)

    # Pyramid level 5. tanh keeps every residual in (-1, 1), matching a
    # depth range of (0, 1).
    dec2 = self.decoder2_up1(dense_feat)
    dec2 = self.decoder2_reduc1(torch.cat([dec2, cat4], dim=1))
    dec2_up = self.decoder2_1(torch.cat([dec2, lap_lv6_up, rgb_lv5], dim=1))
    dec2 = self.decoder2_3(self.decoder2_2(dec2_up))
    lap_lv5 = torch.tanh(self.decoder2_4(dec2) + rgb_bias(rgb_lv5))
    lap_lv5_up = up2(lap_lv5)

    # Pyramid level 4.
    dec3 = self.decoder2_1_up2(dec2_up)
    dec3 = self.decoder2_1_reduc2(torch.cat([dec3, cat3], dim=1))
    dec3_up = self.decoder2_1_1(torch.cat([dec3, lap_lv5_up, rgb_lv4], dim=1))
    dec3 = self.decoder2_1_2(dec3_up)
    lap_lv4 = torch.tanh(self.decoder2_1_3(dec3) + rgb_bias(rgb_lv4))
    lap_lv4_up = up2(lap_lv4)

    # Pyramid level 3.
    dec4 = self.decoder2_1_1_up3(dec3_up)
    dec4 = self.decoder2_1_1_reduc3(torch.cat([dec4, cat2], dim=1))
    dec4_up = self.decoder2_1_1_1(torch.cat([dec4, lap_lv4_up, rgb_lv3], dim=1))
    lap_lv3 = torch.tanh(self.decoder2_1_1_2(dec4_up) + rgb_bias(rgb_lv3))
    lap_lv3_up = up2(lap_lv3)

    # Pyramid level 2.
    dec5 = self.decoder2_1_1_1_up4(dec4_up)
    dec5 = self.decoder2_1_1_1_reduc4(torch.cat([dec5, cat1], dim=1))
    dec5_up = self.decoder2_1_1_1_1(torch.cat([dec5, lap_lv3_up, rgb_lv2], dim=1))
    lap_lv2 = torch.tanh(self.decoder2_1_1_1_2(dec5_up) + rgb_bias(rgb_lv2))
    lap_lv2_up = up2(lap_lv2)

    # Pyramid level 1 (finest): no encoder skip feature is concatenated here.
    dec6 = self.decoder2_1_1_1_1_up5(dec5_up)
    dec6 = self.decoder2_1_1_1_1_1(torch.cat([dec6, lap_lv2_up, rgb_lv1], dim=1))
    dec6 = self.decoder2_1_1_1_1_2(dec6)
    lap_lv1 = torch.tanh(self.decoder2_1_1_1_1_3(dec6) + rgb_bias(rgb_lv1))

    # Laplacian restoration: add each residual to the upsampled running
    # reconstruction, coarse to fine, then squash the result into (0, 1).
    lap_lv5_img = lap_lv5 + lap_lv6_up
    lap_lv4_img = lap_lv4 + up2(lap_lv5_img)
    lap_lv3_img = lap_lv3 + up2(lap_lv4_img)
    lap_lv2_img = lap_lv2 + up2(lap_lv3_img)
    final_depth = torch.sigmoid(lap_lv1 + up2(lap_lv2_img))

    # Scale to metric range: residuals span (-max_depth, max_depth) and the
    # depth map spans (0, max_depth).
    residuals = [lap * self.max_depth
                 for lap in (lap_lv6, lap_lv5, lap_lv4, lap_lv3, lap_lv2, lap_lv1)]
    return residuals, final_depth * self.max_depth
658
+
659
+
660
+ # Laplacian Depth Residual Network
661
class LDRN(nn.Module):
    """Laplacian Depth Residual Network: an encoder followed by a Laplacian decoder.

    ``args.lv6`` selects the pyramid depth. The same flag is handed to the
    encoder, and the decoder is sized from ``self.encoder.dimList``.
    """

    def __init__(self, args):
        """Build the encoder and the matching 5- or 6-level decoder from ``args``."""
        super(LDRN, self).__init__()
        lv6 = args.lv6
        self.encoder = deepFeatureExtractor_ResNext101(args, lv6)

        # Only the literal ``True`` selects the 6-level decoder; any other
        # value (including truthy non-bools) falls through to 5 levels.
        if lv6 is True:
            self.decoder = Lap_decoder_lv6(args, self.encoder.dimList)
        else:
            self.decoder = Lap_decoder_lv5(args, self.encoder.dimList)

    def forward(self, x):
        """Run the encoder, build the RGB Laplacian pyramid, and decode.

        Args:
            x: 4-D input tensor; bilinear resizing acts on its last two dims.

        Returns:
            ``(d_res_list, depth)`` exactly as produced by ``self.decoder``.
        """
        out_featList = self.encoder(x)

        # Gaussian-style pyramid: full resolution, then five successive
        # halvings (1/2, 1/4, 1/8, 1/16, 1/32).
        levels = [x]
        for _ in range(5):
            levels.append(F.interpolate(levels[-1], scale_factor=0.5, mode='bilinear'))

        # Laplacian residual at each scale: the level minus its next-coarser
        # neighbour resized back to the level's own spatial size.
        laps = [
            fine - F.interpolate(coarse, fine.shape[2:], mode='bilinear')
            for fine, coarse in zip(levels[:-1], levels[1:])
        ]

        # Decoder expects coarse-to-fine order: [1/32 image, lap5, ..., lap1].
        rgb_list = [levels[-1]] + laps[::-1]

        d_res_list, depth = self.decoder(out_featList, rgb_list)
        return d_res_list, depth

    def train(self, mode=True):
        """Set train/eval mode, then re-freeze the encoder's batch-norm layers.

        Returns ``self`` to honour the ``nn.Module.train`` contract;
        ``nn.Module.eval()`` forwards this return value, so without it both
        ``model.train()`` and ``model.eval()`` would evaluate to ``None``.
        """
        super().train(mode)
        self.encoder.freeze_bn()
        return self