sunana commited on
Commit
a667f59
·
1 Parent(s): ca12b2c

Upload flow_tools.py

Browse files
Files changed (1) hide show
  1. flow_tools.py +773 -0
flow_tools.py ADDED
@@ -0,0 +1,773 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+ from matplotlib.colors import hsv_to_rgb
6
+ import torch.nn.functional as tf
7
+ from PIL import Image
8
+ from os.path import *
9
+ from io import BytesIO
10
+
11
+ cv2.setNumThreads(0)
12
+ cv2.ocl.setUseOpenCL(False)
13
+ TAG_CHAR = np.array([202021.25], np.float32)
14
+
15
+
16
+ def load_flow(path):
17
+ # if path.endswith('.png'):
18
+ # # for KITTI which uses 16bit PNG images
19
+ # # see 'https://github.com/ClementPinard/FlowNetPytorch/blob/master/datasets/KITTI.py'
20
+ # # The -1 is here to specify not to change the image depth (16bit), and is compatible
21
+ # # with both OpenCV2 and OpenCV3
22
+ # flo_file = cv2.imread(path, -1)
23
+ # flo_img = flo_file[:, :, 2:0:-1].astype(np.float32)
24
+ # invalid = (flo_file[:, :, 0] == 0) # mask
25
+ # flo_img = flo_img - 32768
26
+ # flo_img = flo_img / 64
27
+ # flo_img[np.abs(flo_img) < 1e-10] = 1e-10
28
+ # flo_img[invalid, :] = 0
29
+ # return flo_img
30
+ if path.endswith('.png'):
31
+ # this method is only for the flow data generated by self-rendering
32
+ # read json file and get "forward" and "backward" flow
33
+ import json
34
+ path_range = path.replace(path.name, 'data_ranges.json')
35
+ with open(path_range, 'r') as f:
36
+ flow_dict = json.load(f)
37
+ flow_forward = flow_dict['forward_flow']
38
+ # get the max and min value of the flow
39
+ max_value = float(flow_forward["max"])
40
+ min_value = float(flow_forward["min"])
41
+ # read the flow data
42
+ flow_file = cv2.imread(path, -1).astype(np.float32)
43
+ # scale the flow data
44
+ flow_file = flow_file * (max_value - min_value) / 65535 + min_value
45
+ # only keep the last two channels
46
+ flow_file = flow_file[:, :, 1:]
47
+ return flow_file
48
+
49
+ # scaling = {"min": min_value.item(), "max": max_value.item()}
50
+ # data = (data - min_value) * 65535 / (max_value - min_value)
51
+ # data = data.astype(np.uint16)
52
+
53
+ elif path.endswith('.flo'):
54
+ with open(path, 'rb') as f:
55
+ magic = np.fromfile(f, np.float32, count=1)
56
+ assert (202021.25 == magic), 'Magic number incorrect. Invalid .flo file'
57
+ h = np.fromfile(f, np.int32, count=1)[0]
58
+ w = np.fromfile(f, np.int32, count=1)[0]
59
+ data = np.fromfile(f, np.float32, count=2 * w * h)
60
+ # Reshape data into 3D array (columns, rows, bands)
61
+ data2D = np.resize(data, (w, h, 2))
62
+ return data2D
63
+ elif path.endswith('.pfm'):
64
+ file = open(path, 'rb')
65
+
66
+ color = None
67
+ width = None
68
+ height = None
69
+ scale = None
70
+ endian = None
71
+ header = file.readline().rstrip()
72
+ if header == b'PF':
73
+ color = True
74
+ elif header == b'Pf':
75
+ color = False
76
+ else:
77
+ raise Exception('Not a PFM file.')
78
+
79
+ dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
80
+ if dim_match:
81
+ width, height = map(int, dim_match.groups())
82
+ else:
83
+ raise Exception('Malformed PFM header.')
84
+
85
+ scale = float(file.readline().rstrip())
86
+ if scale < 0: # little-endian
87
+ endian = '<'
88
+ scale = -scale
89
+ else:
90
+ endian = '>' # big-endian
91
+ data = np.fromfile(file, endian + 'f')
92
+ shape = (height, width, 3) if color else (height, width)
93
+ data = np.reshape(data, shape)
94
+ data = np.flipud(data).astype(np.float32)
95
+ if len(data.shape) == 2:
96
+ return data
97
+ else:
98
+ return data[:, :, :-1]
99
+ elif path.endswith('.bin') or path.endswith('.raw'):
100
+ return np.load(path)
101
+ else:
102
+ raise NotImplementedError("flow type")
103
+
104
+
105
+ def make_colorwheel():
106
+ """
107
+ Generates a color wheel for optical flow visualization as presented in:
108
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
109
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
110
+
111
+ Code follows the original C++ source code of Daniel Scharstein.
112
+ Code follows the the Matlab source code of Deqing Sun.
113
+
114
+ Returns:
115
+ np.ndarray: Color wheel
116
+ """
117
+
118
+ RY = 15
119
+ YG = 6
120
+ GC = 4
121
+ CB = 11
122
+ BM = 13
123
+ MR = 6
124
+
125
+ ncols = RY + YG + GC + CB + BM + MR
126
+ colorwheel = np.zeros((ncols, 3))
127
+ col = 0
128
+
129
+ # RY
130
+ colorwheel[0:RY, 0] = 255
131
+ colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
132
+ col = col + RY
133
+ # YG
134
+ colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
135
+ colorwheel[col:col + YG, 1] = 255
136
+ col = col + YG
137
+ # GC
138
+ colorwheel[col:col + GC, 1] = 255
139
+ colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
140
+ col = col + GC
141
+ # CB
142
+ colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
143
+ colorwheel[col:col + CB, 2] = 255
144
+ col = col + CB
145
+ # BM
146
+ colorwheel[col:col + BM, 2] = 255
147
+ colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
148
+ col = col + BM
149
+ # MR
150
+ colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
151
+ colorwheel[col:col + MR, 0] = 255
152
+ return colorwheel
153
+
154
+
155
+ def flow_uv_to_colors(u, v, convert_to_bgr=False):
156
+ """
157
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
158
+
159
+ According to the C++ source code of Daniel Scharstein
160
+ According to the Matlab source code of Deqing Sun
161
+
162
+ Args:
163
+ u (np.ndarray): Input horizontal flow of shape [H,W]
164
+ v (np.ndarray): Input vertical flow of shape [H,W]
165
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
166
+
167
+ Returns:
168
+ np.ndarray: Flow visualization image of shape [H,W,3]
169
+ """
170
+ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
171
+ colorwheel = make_colorwheel() # shape [55x3]
172
+ ncols = colorwheel.shape[0]
173
+ rad = np.sqrt(np.square(u) + np.square(v))
174
+ a = np.arctan2(-v, -u) / np.pi
175
+ fk = (a + 1) / 2 * (ncols - 1)
176
+ k0 = np.floor(fk).astype(np.int32)
177
+ k1 = k0 + 1
178
+ k1[k1 == ncols] = 0
179
+ f = fk - k0
180
+ for i in range(colorwheel.shape[1]):
181
+ tmp = colorwheel[:, i]
182
+ col0 = tmp[k0] / 255.0
183
+ col1 = tmp[k1] / 255.0
184
+ col = (1 - f) * col0 + f * col1
185
+ idx = (rad <= 1)
186
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
187
+ col[~idx] = col[~idx] * 0.75 # out of range
188
+ # Note the 2-i => BGR instead of RGB
189
+ ch_idx = 2 - i if convert_to_bgr else i
190
+ flow_image[:, :, ch_idx] = np.floor(255 * col)
191
+ return flow_image
192
+
193
+
194
+ # absolut color flow
195
+ def flow_to_image(flow, max_flow=256):
196
+ if max_flow is not None:
197
+ max_flow = max(max_flow, 1.)
198
+ else:
199
+ max_flow = np.max(flow)
200
+
201
+ n = 8
202
+ u, v = flow[:, :, 0], flow[:, :, 1]
203
+ mag = np.sqrt(np.square(u) + np.square(v))
204
+ angle = np.arctan2(v, u)
205
+ im_h = np.mod(angle / (2 * np.pi) + 1, 1)
206
+ im_s = np.clip(mag * n / max_flow, a_min=0, a_max=1)
207
+ im_v = np.clip(n - im_s, a_min=0, a_max=1)
208
+ im = hsv_to_rgb(np.stack([im_h, im_s, im_v], 2))
209
+ return (im * 255).astype(np.uint8)
210
+
211
+
212
+ # relative color
213
+ def flow_to_image_relative(flow_uv, clip_flow=None, convert_to_bgr=False):
214
+ """
215
+ Expects a two dimensional flow image of shape.
216
+
217
+ Args:
218
+ flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
219
+ clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
220
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
221
+
222
+ Returns:
223
+ np.ndarray: Flow visualization image of shape [H,W,3]
224
+ """
225
+ assert flow_uv.ndim == 3, 'input flow must have three dimensions'
226
+ assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
227
+ if clip_flow is not None:
228
+ flow_uv = np.clip(flow_uv, 0, clip_flow)
229
+ u = flow_uv[:, :, 0]
230
+ v = flow_uv[:, :, 1]
231
+ rad = np.sqrt(np.square(u) + np.square(v))
232
+ rad_max = np.max(rad)
233
+ epsilon = 1e-5
234
+ u = u / (rad_max + epsilon)
235
+ v = v / (rad_max + epsilon)
236
+ return flow_uv_to_colors(u, v, convert_to_bgr)
237
+
238
+
239
+ def resize_flow(flow, new_shape):
240
+ _, _, h, w = flow.shape
241
+ new_h, new_w = new_shape
242
+ flow = torch.nn.functional.interpolate(flow, (new_h, new_w),
243
+ mode='bilinear', align_corners=True)
244
+ scale_h, scale_w = h / float(new_h), w / float(new_w)
245
+ flow[:, 0] /= scale_w
246
+ flow[:, 1] /= scale_h
247
+ return flow
248
+
249
+
250
+ def evaluate_flow_api(gt_flows, pred_flows):
251
+ if len(gt_flows.shape) == 3:
252
+ gt_flows = gt_flows.unsqueeze(0)
253
+ if len(pred_flows.shape) == 3:
254
+ pred_flows = pred_flows.unsqueeze(0)
255
+ pred_flows = pred_flows.detach().cpu().numpy().transpose([0, 2, 3, 1])
256
+ gt_flows = gt_flows.detach().cpu().numpy().transpose([0, 2, 3, 1])
257
+ return evaluate_flow(gt_flows, pred_flows)
258
+
259
+
260
+ def evaluate_flow(gt_flows, pred_flows, moving_masks=None):
261
+ # credit "undepthflow/eval/evaluate_flow.py"
262
+ def calculate_error_rate(epe_map, gt_flow, mask):
263
+ bad_pixels = np.logical_and(
264
+ epe_map * mask > 3,
265
+ epe_map * mask / np.maximum(
266
+ np.sqrt(np.sum(np.square(gt_flow), axis=2)), 1e-10) > 0.05)
267
+ return bad_pixels.sum() / mask.sum() * 100.
268
+
269
+ error, error_noc, error_occ, error_move, error_static, error_rate = \
270
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
271
+ error_move_rate, error_static_rate = 0.0, 0.0
272
+ B = len(gt_flows)
273
+ for gt_flow, pred_flow, i in zip(gt_flows, pred_flows, range(B)):
274
+ H, W = gt_flow.shape[:2]
275
+
276
+ h, w = pred_flow.shape[:2]
277
+ pred_flow = np.copy(pred_flow)
278
+ pred_flow[:, :, 0] = pred_flow[:, :, 0] / w * W
279
+ pred_flow[:, :, 1] = pred_flow[:, :, 1] / h * H
280
+
281
+ flo_pred = cv2.resize(pred_flow, (W, H), interpolation=cv2.INTER_LINEAR)
282
+
283
+ epe_map = np.sqrt(
284
+ np.sum(np.square(flo_pred[:, :, :2] - gt_flow[:, :, :2]),
285
+ axis=2))
286
+ if gt_flow.shape[-1] == 2:
287
+ error += np.mean(epe_map)
288
+
289
+ elif gt_flow.shape[-1] == 4:
290
+ error += np.sum(epe_map * gt_flow[:, :, 2]) / np.sum(gt_flow[:, :, 2])
291
+ noc_mask = gt_flow[:, :, -1]
292
+ error_noc += np.sum(epe_map * noc_mask) / np.sum(noc_mask)
293
+
294
+ error_occ += np.sum(epe_map * (gt_flow[:, :, 2] - noc_mask)) / max(
295
+ np.sum(gt_flow[:, :, 2] - noc_mask), 1.0)
296
+
297
+ error_rate += calculate_error_rate(epe_map, gt_flow[:, :, 0:2],
298
+ gt_flow[:, :, 2])
299
+
300
+ if moving_masks is not None:
301
+ move_mask = moving_masks[i]
302
+
303
+ error_move_rate += calculate_error_rate(
304
+ epe_map, gt_flow[:, :, 0:2], gt_flow[:, :, 2] * move_mask)
305
+ error_static_rate += calculate_error_rate(
306
+ epe_map, gt_flow[:, :, 0:2],
307
+ gt_flow[:, :, 2] * (1.0 - move_mask))
308
+
309
+ error_move += np.sum(epe_map * gt_flow[:, :, 2] *
310
+ move_mask) / np.sum(gt_flow[:, :, 2] *
311
+ move_mask)
312
+ error_static += np.sum(epe_map * gt_flow[:, :, 2] * (
313
+ 1.0 - move_mask)) / np.sum(gt_flow[:, :, 2] *
314
+ (1.0 - move_mask))
315
+
316
+ if gt_flows[0].shape[-1] == 4:
317
+ res = [error / B, error_noc / B, error_occ / B, error_rate / B]
318
+ if moving_masks is not None:
319
+ res += [error_move / B, error_static / B]
320
+ return res
321
+ else:
322
+ return [error / B]
323
+
324
+
325
+ class InputPadder:
326
+ """ Pads images such that dimensions are divisible by 32 """
327
+
328
+ def __init__(self, dims, mode='sintel'):
329
+ self.ht, self.wd = dims[-2:]
330
+ pad_ht = (((self.ht // 16) + 1) * 16 - self.ht) % 16
331
+ pad_wd = (((self.wd // 16) + 1) * 16 - self.wd) % 16
332
+ if mode == 'sintel':
333
+ self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
334
+ else:
335
+ self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
336
+
337
+ def pad(self, inputs):
338
+ return [tf.pad(x, self._pad, mode='replicate') for x in inputs]
339
+
340
+ def unpad(self, x):
341
+ ht, wd = x.shape[-2:]
342
+ c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
343
+
344
+ return x[..., c[0]:c[1], c[2]:c[3]]
345
+
346
+
347
+ class ImageInputZoomer:
348
+ """ Pads images such that dimensions are divisible by 32 """
349
+
350
+ def __init__(self, dims, factor=32):
351
+ self.ht, self.wd = dims[-2:]
352
+ hf = self.ht % factor
353
+ wf = self.wd % factor
354
+ pad_ht = (self.ht // factor + 1) * factor if hf > (factor / 2) else (self.ht // factor) * factor
355
+ pad_wd = (self.wd // factor + 1) * factor if wf > (factor / 2) else (self.wd // factor) * factor
356
+ self.size = [pad_wd, pad_ht]
357
+
358
+ def zoom(self, inputs):
359
+ return [
360
+ torch.from_numpy(cv2.resize(x.cpu().numpy().transpose(1, 2, 0), dsize=self.size,
361
+ interpolation=cv2.INTER_CUBIC).transpose(2, 0, 1)) for x in inputs]
362
+
363
+ def unzoom(self, inputs):
364
+ return [cv2.resize(x.cpu().squeeze().numpy().transpose(1, 2, 0), dsize=(self.wd, self.ht),
365
+ interpolation=cv2.INTER_CUBIC) for x in inputs]
366
+
367
+
368
+ def readFlow(fn):
369
+ """ Read .flo file in Middlebury format"""
370
+ # Code adapted from:
371
+ # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
372
+
373
+ # WARNING: this will work on little-endian architectures (eg Intel x86) only!
374
+ # print 'fn = %s'%(fn)
375
+ with open(fn, 'rb') as f:
376
+ magic = np.fromfile(f, np.float32, count=1)
377
+ if 202021.25 != magic:
378
+ print('Magic number incorrect. Invalid .flo file')
379
+ return None
380
+ else:
381
+ w = np.fromfile(f, np.int32, count=1)
382
+ h = np.fromfile(f, np.int32, count=1)
383
+ # print 'Reading %d x %d flo file\n' % (w, h)
384
+ data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
385
+ # Reshape data into 3D array (columns, rows, bands)
386
+ # The reshape here is for visualization, the original code is (w,h,2)
387
+ return np.resize(data, (int(h), int(w), 2))
388
+
389
+
390
+ import re
391
+
392
+
393
+ def readPFM(file):
394
+ file = open(file, 'rb')
395
+
396
+ color = None
397
+ width = None
398
+ height = None
399
+ scale = None
400
+ endian = None
401
+
402
+ header = file.readline().rstrip()
403
+ if header == b'PF':
404
+ color = True
405
+ elif header == b'Pf':
406
+ color = False
407
+ else:
408
+ raise Exception('Not a PFM file.')
409
+
410
+ dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
411
+ if dim_match:
412
+ width, height = map(int, dim_match.groups())
413
+ else:
414
+ raise Exception('Malformed PFM header.')
415
+
416
+ scale = float(file.readline().rstrip())
417
+ if scale < 0: # little-endian
418
+ endian = '<'
419
+ scale = -scale
420
+ else:
421
+ endian = '>' # big-endian
422
+
423
+ data = np.fromfile(file, endian + 'f')
424
+ shape = (height, width, 3) if color else (height, width)
425
+
426
+ data = np.reshape(data, shape)
427
+ data = np.flipud(data)
428
+ return data
429
+
430
+
431
+ def writeFlow(filename, uv, v=None):
432
+ """ Write optical flow to file.
433
+
434
+ If v is None, uv is assumed to contain both u and v channels,
435
+ stacked in depth.
436
+ Original code by Deqing Sun, adapted from Daniel Scharstein.
437
+ """
438
+ nBands = 2
439
+
440
+ if v is None:
441
+ assert (uv.ndim == 3)
442
+ assert (uv.shape[2] == 2)
443
+ u = uv[:, :, 0]
444
+ v = uv[:, :, 1]
445
+ else:
446
+ u = uv
447
+
448
+ assert (u.shape == v.shape)
449
+ height, width = u.shape
450
+ f = open(filename, 'wb')
451
+ # write the header
452
+ f.write(TAG_CHAR)
453
+ np.array(width).astype(np.int32).tofile(f)
454
+ np.array(height).astype(np.int32).tofile(f)
455
+ # arrange into matrix form
456
+ tmp = np.zeros((height, width * nBands))
457
+ tmp[:, np.arange(width) * 2] = u
458
+ tmp[:, np.arange(width) * 2 + 1] = v
459
+ tmp.astype(np.float32).tofile(f)
460
+ f.close()
461
+
462
+
463
+ def readFlowKITTI(filename):
464
+ flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
465
+ flow = flow[:, :, ::-1].astype(np.float32)
466
+ flow, valid = flow[:, :, :2], flow[:, :, 2]
467
+ flow = (flow - 2 ** 15) / 64.0
468
+ return flow, valid
469
+
470
+
471
+ def readDispKITTI(filename):
472
+ disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
473
+ valid = disp > 0.0
474
+ flow = np.stack([-disp, np.zeros_like(disp)], -1)
475
+ return flow, valid
476
+
477
+
478
+ def writeFlowKITTI(filename, uv):
479
+ uv = 64.0 * uv + 2 ** 15
480
+ valid = np.ones([uv.shape[0], uv.shape[1], 1])
481
+ uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
482
+ cv2.imwrite(filename, uv[..., ::-1])
483
+
484
+
485
+ def read_gen(file_name, pil=False):
486
+ ext = splitext(file_name)[-1]
487
+ if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
488
+ return Image.open(file_name)
489
+ elif ext == '.bin' or ext == '.raw':
490
+ return np.load(file_name)
491
+ elif ext == '.flo':
492
+ return readFlow(file_name).astype(np.float32)
493
+ elif ext == '.pfm':
494
+ flow = readPFM(file_name).astype(np.float32)
495
+ if len(flow.shape) == 2:
496
+ return flow
497
+ else:
498
+ return flow[:, :, :-1]
499
+ return []
500
+
501
+
502
+ def flow_error_image_np(flow_pred, flow_gt, mask_occ, mask_noc=None, log_colors=True):
503
+ """Visualize the error between two flows as 3-channel color image.
504
+ Adapted from the KITTI C++ devkit.
505
+ Args:
506
+ flow_pred: prediction flow of shape [ height, width, 2].
507
+ flow_gt: ground truth
508
+ mask_occ: flow validity mask of shape [num_batch, height, width, 1].
509
+ Equals 1 at (occluded and non-occluded) valid pixels.
510
+ mask_noc: Is 1 only at valid pixels which are not occluded.
511
+ """
512
+ # mask_noc = tf.ones(tf.shape(mask_occ)) if mask_noc is None else mask_noc
513
+ mask_noc = np.ones(mask_occ.shape) if mask_noc is None else mask_noc
514
+ diff_sq = (flow_pred - flow_gt) ** 2
515
+ # diff = tf.sqrt(tf.reduce_sum(diff_sq, [3], keep_dims=True))
516
+ diff = np.sqrt(np.sum(diff_sq, axis=2, keepdims=True))
517
+ if log_colors:
518
+ height, width, _ = flow_pred.shape
519
+ # num_batch, height, width, _ = tf.unstack(tf.shape(flow_1))
520
+ colormap = [
521
+ [0, 0.0625, 49, 54, 149],
522
+ [0.0625, 0.125, 69, 117, 180],
523
+ [0.125, 0.25, 116, 173, 209],
524
+ [0.25, 0.5, 171, 217, 233],
525
+ [0.5, 1, 224, 243, 248],
526
+ [1, 2, 254, 224, 144],
527
+ [2, 4, 253, 174, 97],
528
+ [4, 8, 244, 109, 67],
529
+ [8, 16, 215, 48, 39],
530
+ [16, 1000000000.0, 165, 0, 38]]
531
+ colormap = np.asarray(colormap, dtype=np.float32)
532
+ colormap[:, 2:5] = colormap[:, 2:5] / 255
533
+ # mag = tf.sqrt(tf.reduce_sum(tf.square(flow_2), 3, keep_dims=True))
534
+ tempp = np.square(flow_gt)
535
+ # temp = np.sum(tempp, axis=2, keep_dims=True)
536
+ # mag = np.sqrt(temp)
537
+ mag = np.sqrt(np.sum(tempp, axis=2, keepdims=True))
538
+ # error = tf.minimum(diff / 3, 20 * diff / mag)
539
+ error = np.minimum(diff / 3, 20 * diff / (mag + 1e-7))
540
+ im = np.zeros([height, width, 3])
541
+ for i in range(colormap.shape[0]):
542
+ colors = colormap[i, :]
543
+ cond = np.logical_and(np.greater_equal(error, colors[0]), np.less(error, colors[1]))
544
+ # temp=np.tile(cond, [1, 1, 3])
545
+ im = np.where(np.tile(cond, [1, 1, 3]), np.ones([height, width, 1]) * colors[2:5], im)
546
+ # temp=np.cast(mask_noc, np.bool)
547
+ # im = np.where(np.tile(np.cast(mask_noc, np.bool), [1, 1, 3]), im, im * 0.5)
548
+ im = np.where(np.tile(mask_noc == 1, [1, 1, 3]), im, im * 0.5)
549
+ im = im * mask_occ
550
+ else:
551
+ error = (np.minimum(diff, 5) / 5) * mask_occ
552
+ im_r = error # errors in occluded areas will be red
553
+ im_g = error * mask_noc
554
+ im_b = error * mask_noc
555
+ im = np.concatenate([im_r, im_g, im_b], axis=2)
556
+ # im = np.concatenate(axis=2, values=[im_r, im_g, im_b])
557
+ return im[:, :, ::-1]
558
+
559
+
560
+ def viz_img_seq(img_list=[], flow_list=[], batch_index=0, if_debug=True):
561
+ '''visulize image sequence from cuda'''
562
+ if if_debug:
563
+
564
+ assert len(img_list) != 0
565
+ if len(img_list[0].shape) == 3:
566
+ img_list = [np.expand_dims(img, axis=0) for img in img_list]
567
+ elif img_list[0].shape[1] == 1:
568
+ img_list = [img1[batch_index].detach().cpu().permute(1, 2, 0).numpy() for img1 in img_list]
569
+ img_list = [cv2.cvtColor(flo * 255, cv2.COLOR_GRAY2BGR) for flo in img_list]
570
+ elif img_list[0].shape[1] == 2:
571
+ img_list = [img1[batch_index].detach().cpu().permute(1, 2, 0).numpy() for img1 in img_list]
572
+ img_list = [flow_to_image_relative(flo) / 255.0 for flo in img_list]
573
+ else:
574
+ img_list = [img1[batch_index].detach().cpu().permute(1, 2, 0).numpy() for img1 in img_list]
575
+
576
+ if len(flow_list) == 0:
577
+ flow_list = [np.zeros_like(img) for img in img_list]
578
+ elif len(flow_list[0].shape) == 3:
579
+ flow_list = [np.expand_dims(img, axis=0) for img in flow_list]
580
+ elif flow_list[0].shape[1] == 1:
581
+ flow_list = [img1[batch_index].detach().cpu().permute(1, 2, 0).numpy() for img1 in flow_list]
582
+ flow_list = [cv2.cvtColor(flo * 255, cv2.COLOR_GRAY2BGR) for flo in flow_list]
583
+ elif flow_list[0].shape[1] == 2:
584
+ flow_list = [img1[batch_index].detach().cpu().permute(1, 2, 0).numpy() for img1 in flow_list]
585
+ flow_list = [flow_to_image_relative(flo) / 255.0 for flo in flow_list]
586
+ else:
587
+ flow_list = [img1[batch_index].detach().cpu().permute(1, 2, 0).numpy() for img1 in flow_list]
588
+
589
+ if img_list[0].max() > 10:
590
+ img_list = [img / 255.0 for img in img_list]
591
+ if flow_list[0].max() > 10:
592
+ flow_list = [img / 255.0 for img in flow_list]
593
+
594
+ while len(img_list) > len(flow_list):
595
+ flow_list.append(np.zeros_like(flow_list[-1]))
596
+ while len(flow_list) > len(img_list):
597
+ img_list.append(np.zeros_like(img_list[-1]))
598
+ img_flo = np.concatenate([flow_list[0], img_list[0]], axis=0)
599
+ # map flow to rgb image
600
+ for i in range(1, len(img_list)):
601
+ temp = np.concatenate([flow_list[i], img_list[i]], axis=0)
602
+ img_flo = np.concatenate([img_flo, temp], axis=1)
603
+ cv2.imshow('image', img_flo[:, :, [2, 1, 0]])
604
+ cv2.waitKey()
605
+ else:
606
+ return
607
+
608
+
609
+ def plt_show_img_flow(img_list=[], flow_list=[], batch_index=0):
610
+ assert len(img_list) != 0
611
+ if len(img_list[0].shape) == 3:
612
+ img_list = [np.expand_dims(img, axis=0) for img in img_list]
613
+ elif img_list[0].shape[1] == 1:
614
+ img_list = [img1[batch_index].detach().cpu().permute(1, 2, 0).numpy() for img1 in img_list]
615
+ img_list = [cv2.cvtColor(flo * 255, cv2.COLOR_GRAY2BGR) for flo in img_list]
616
+ elif img_list[0].shape[1] == 2:
617
+ img_list = [img1[batch_index].detach().cpu().permute(1, 2, 0).numpy() for img1 in img_list]
618
+ img_list = [flow_to_image_relative(flo) / 255.0 for flo in img_list]
619
+ else:
620
+ img_list = [img1[batch_index].detach().cpu().permute(1, 2, 0).numpy() for img1 in img_list]
621
+
622
+ assert flow_list[0].shape[1] == 2
623
+ flow_vec = [img1[batch_index].detach().cpu().permute(1, 2, 0).numpy() for img1 in flow_list]
624
+ flow_list = [flow_to_image_relative(flo) / 255.0 for flo in flow_vec]
625
+
626
+ col = len(flow_list) // 2
627
+ fig = plt.figure(figsize=(10, 8))
628
+ for i in range(len(flow_list)):
629
+ ax1 = fig.add_subplot(2, col, i + 1)
630
+ plot_quiver(ax1, flow=flow_vec[i], mask=flow_list[i], spacing=(30 * flow_list[i].shape[0]) // 512)
631
+ if i == len(flow_list) - 1:
632
+ plt.title("Final Flow Result")
633
+ else:
634
+ plt.title("Flow from decoder (Layer %d)" % i)
635
+ plt.xticks([])
636
+ plt.yticks([])
637
+ plt.tight_layout()
638
+
639
+ # save image to buffer
640
+ buf = BytesIO()
641
+ plt.savefig(buf, format='png')
642
+ buf.seek(0)
643
+ # convert buffer to image
644
+ img = Image.open(buf)
645
+ # convert image to numpy array
646
+ img = np.asarray(img)
647
+ return img
648
+
649
+
650
+ def plt_attention(attention, h, w):
651
+ col = len(attention) // 2
652
+ fig = plt.figure(figsize=(10, 5))
653
+
654
+ for i in range(len(attention)):
655
+ viz = attention[i][0, :, :, h, w].detach().cpu().numpy()
656
+ # viz = viz[7:-7, 7:-7]
657
+ if i == 0:
658
+ viz_all = viz
659
+ else:
660
+ viz_all = viz_all + viz
661
+
662
+ ax1 = fig.add_subplot(2, col + 1, i + 1)
663
+ img = ax1.imshow(viz, cmap="rainbow", interpolation="bilinear")
664
+ plt.colorbar(img, ax=ax1)
665
+ ax1.scatter(h, w, color='red')
666
+ plt.title("Attention of Iteration %d" % (i + 1))
667
+
668
+ ax1 = fig.add_subplot(2, col + 1, 2 * (col + 1))
669
+ img = ax1.imshow(viz_all, cmap="rainbow", interpolation="bilinear")
670
+ plt.colorbar(img, ax=ax1)
671
+ ax1.scatter(h, w, color='red')
672
+ plt.title("Mean Attention")
673
+ plt.show()
674
+
675
+
676
+ def plot_quiver(ax, flow, spacing, mask=None, show_win=None, margin=0, **kwargs):
677
+ """Plots less dense quiver field.
678
+
679
+ Args:
680
+ ax: Matplotlib axis
681
+ flow: motion vectors
682
+ spacing: space (px) between each arrow in grid
683
+ margin: width (px) of enclosing region without arrows
684
+ kwargs: quiver kwargs (default: angles="xy", scale_units="xy")
685
+ """
686
+ h, w, *_ = flow.shape
687
+ spacing = 50
688
+ if show_win is None:
689
+ nx = int((w - 2 * margin) / spacing)
690
+ ny = int((h - 2 * margin) / spacing)
691
+ x = np.linspace(margin, w - margin - 1, nx, dtype=np.int64)
692
+ y = np.linspace(margin, h - margin - 1, ny, dtype=np.int64)
693
+ else:
694
+ h0, h1, w0, w1 = *show_win,
695
+ h0 = int(h0 * h)
696
+ h1 = int(h1 * h)
697
+ w0 = int(w0 * w)
698
+ w1 = int(w1 * w)
699
+ num_h = (h1 - h0) // spacing
700
+ num_w = (w1 - w0) // spacing
701
+ y = np.linspace(h0, h1, num_h, dtype=np.int64)
702
+ x = np.linspace(w0, w1, num_w, dtype=np.int64)
703
+
704
+ flow = flow[np.ix_(y, x)]
705
+ u = flow[:, :, 0]
706
+ v = flow[:, :, 1] * -1 # ----------
707
+
708
+ kwargs = {**dict(angles="xy", scale_units="xy"), **kwargs}
709
+ if mask is not None:
710
+ ax.imshow(mask)
711
+ # ax.quiver(x, y, u, v, color="black", scale=10, width=0.010, headwidth=5, minlength=0.5) # bigger is short
712
+ ax.quiver(x, y, u, v, color="black") # bigger is short
713
+ x_gird, y_gird = np.meshgrid(x, y)
714
+ ax.scatter(x_gird, y_gird, c="black", s=(h + w) // 50)
715
+ ax.scatter(x_gird, y_gird, c="black", s=(h + w) // 100)
716
+ ax.set_ylim(sorted(ax.get_ylim(), reverse=True))
717
+ ax.set_aspect("equal")
718
+
719
+
720
+ def save_img_seq(img_list, batch_index=0, name='img', if_debug=False):
721
+ if if_debug:
722
+ temp = img_list[0]
723
+ size = temp.shape
724
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
725
+ out = cv2.VideoWriter(name + '_flow.mp4', fourcc, 22, (size[-1], size[-2]))
726
+ if img_list[0].shape[1] == 2:
727
+ image_list = []
728
+ flow_vec = [img1[batch_index].detach().cpu().permute(1, 2, 0).numpy() for img1 in img_list]
729
+ flow_viz = [flow_to_image_relative(flo) for flo in flow_vec]
730
+ # for index, img in enumerate(flow_viz):
731
+ # image_list.append(viz(flow_viz[index], flow_vec[index], flow_viz[index]))
732
+ img_list = flow_viz
733
+ if img_list[0].shape[1] == 3:
734
+ img_list = [img1[batch_index].detach().cpu().permute(1, 2, 0).numpy() * 255.0 for img1 in img_list]
735
+ if img_list[0].shape[1] == 1:
736
+ img_list = [img1[batch_index].detach().cpu().permute(1, 2, 0).numpy() for img1 in img_list]
737
+ img_list = [cv2.cvtColor(flo * 255, cv2.COLOR_GRAY2BGR) for flo in img_list]
738
+
739
+ for index, img in enumerate(img_list):
740
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
741
+ cv2.imwrite(name + '_%d.png' % index, img)
742
+ out.write(img.astype(np.uint8))
743
+ out.release()
744
+ else:
745
+ return
746
+
747
+
748
+ from io import BytesIO
749
+
750
+
751
+ def viz(flo, flow_vec,
752
+ image):
753
+ fig, axes = plt.subplots(1, 2, figsize=(10, 5), dpi=500)
754
+ ax1 = axes[0]
755
+ plot_quiver(ax1, flow=flow_vec, mask=flo, spacing=40)
756
+ ax1.set_title('flow all')
757
+
758
+ ax1 = axes[1]
759
+ ax1.imshow(image)
760
+ ax1.set_title('image')
761
+
762
+ plt.tight_layout()
763
+ # eliminate the x and y-axis
764
+ plt.axis('off')
765
+ # save figure into a buffer
766
+ buf = BytesIO()
767
+ plt.savefig(buf, format='png', dpi=200)
768
+ buf.seek(0)
769
+ # convert to numpy array
770
+ im = np.array(Image.open(buf))
771
+ buf.close()
772
+ plt.close()
773
+ return im