adirik commited on
Commit
b262a3f
·
1 Parent(s): f4c3c2b

update repo

Browse files
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
.gitattributes CHANGED
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
  *.pth* filter=lfs diff=lfs merge=lfs -text
36
  filter=lfs diff=lfs merge=lfs -text
 
 
 
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
  *.pth* filter=lfs diff=lfs merge=lfs -text
36
  filter=lfs diff=lfs merge=lfs -text
37
+ *.pkl* filter=lfs diff=lfs merge=lfs -text
38
+ filter=lfs diff=lfs merge=lfs -text
dnnlib/__pycache__/util.cpython-38.pyc CHANGED
Binary files a/dnnlib/__pycache__/util.cpython-38.pyc and b/dnnlib/__pycache__/util.cpython-38.pyc differ
 
id_loss.py CHANGED
@@ -15,7 +15,7 @@ class IDLoss(nn.Module):
15
  super(IDLoss, self).__init__()
16
  print('Loading ResNet ArcFace')
17
  self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
18
- self.facenet.load_state_dict(torch.load("model_ir_se50.pth", map_location=device))
19
  self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
20
  self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
21
  self.facenet.eval()
 
15
  super(IDLoss, self).__init__()
16
  print('Loading ResNet ArcFace')
17
  self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
18
+ self.facenet.load_state_dict(torch.load("./pretrained/model_ir_se50.pth", map_location=device))
19
  self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
20
  self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
21
  self.facenet.eval()
metrics/__init__.py DELETED
@@ -1,9 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- # empty
 
 
 
 
 
 
 
 
 
 
metrics/frechet_inception_distance.py DELETED
@@ -1,41 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Frechet Inception Distance (FID) from the paper
10
- "GANs trained by a two time-scale update rule converge to a local Nash
11
- equilibrium". Matches the original implementation by Heusel et al. at
12
- https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
13
-
14
- import numpy as np
15
- import scipy.linalg
16
- from . import metric_utils
17
-
18
- #----------------------------------------------------------------------------
19
-
20
- def compute_fid(opts, max_real, num_gen):
21
- # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
22
- detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
23
- detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
24
-
25
- mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
26
- opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
27
- rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
28
-
29
- mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
30
- opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
31
- rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
32
-
33
- if opts.rank != 0:
34
- return float('nan')
35
-
36
- m = np.square(mu_gen - mu_real).sum()
37
- s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
38
- fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
39
- return float(fid)
40
-
41
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
metrics/inception_score.py DELETED
@@ -1,38 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Inception Score (IS) from the paper "Improved techniques for training
10
- GANs". Matches the original implementation by Salimans et al. at
11
- https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
12
-
13
- import numpy as np
14
- from . import metric_utils
15
-
16
- #----------------------------------------------------------------------------
17
-
18
- def compute_is(opts, num_gen, num_splits):
19
- # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20
- detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21
- detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
22
-
23
- gen_probs = metric_utils.compute_feature_stats_for_generator(
24
- opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25
- capture_all=True, max_items=num_gen).get_all()
26
-
27
- if opts.rank != 0:
28
- return float('nan'), float('nan')
29
-
30
- scores = []
31
- for i in range(num_splits):
32
- part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
33
- kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
34
- kl = np.mean(np.sum(kl, axis=1))
35
- scores.append(np.exp(kl))
36
- return float(np.mean(scores)), float(np.std(scores))
37
-
38
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
metrics/kernel_inception_distance.py DELETED
@@ -1,46 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Kernel Inception Distance (KID) from the paper "Demystifying MMD
10
- GANs". Matches the original implementation by Binkowski et al. at
11
- https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
12
-
13
- import numpy as np
14
- from . import metric_utils
15
-
16
- #----------------------------------------------------------------------------
17
-
18
- def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
19
- # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20
- detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21
- detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
22
-
23
- real_features = metric_utils.compute_feature_stats_for_dataset(
24
- opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25
- rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
26
-
27
- gen_features = metric_utils.compute_feature_stats_for_generator(
28
- opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
29
- rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
30
-
31
- if opts.rank != 0:
32
- return float('nan')
33
-
34
- n = real_features.shape[1]
35
- m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
36
- t = 0
37
- for _subset_idx in range(num_subsets):
38
- x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
39
- y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
40
- a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
41
- b = (x @ y.T / n + 1) ** 3
42
- t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
43
- kid = t / num_subsets / m
44
- return float(kid)
45
-
46
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
metrics/metric_main.py DELETED
@@ -1,152 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- import os
10
- import time
11
- import json
12
- import torch
13
- import dnnlib
14
-
15
- from . import metric_utils
16
- from . import frechet_inception_distance
17
- from . import kernel_inception_distance
18
- from . import precision_recall
19
- from . import perceptual_path_length
20
- from . import inception_score
21
-
22
- #----------------------------------------------------------------------------
23
-
24
- _metric_dict = dict() # name => fn
25
-
26
- def register_metric(fn):
27
- assert callable(fn)
28
- _metric_dict[fn.__name__] = fn
29
- return fn
30
-
31
- def is_valid_metric(metric):
32
- return metric in _metric_dict
33
-
34
- def list_valid_metrics():
35
- return list(_metric_dict.keys())
36
-
37
- #----------------------------------------------------------------------------
38
-
39
- def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
40
- assert is_valid_metric(metric)
41
- opts = metric_utils.MetricOptions(**kwargs)
42
-
43
- # Calculate.
44
- start_time = time.time()
45
- results = _metric_dict[metric](opts)
46
- total_time = time.time() - start_time
47
-
48
- # Broadcast results.
49
- for key, value in list(results.items()):
50
- if opts.num_gpus > 1:
51
- value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
52
- torch.distributed.broadcast(tensor=value, src=0)
53
- value = float(value.cpu())
54
- results[key] = value
55
-
56
- # Decorate with metadata.
57
- return dnnlib.EasyDict(
58
- results = dnnlib.EasyDict(results),
59
- metric = metric,
60
- total_time = total_time,
61
- total_time_str = dnnlib.util.format_time(total_time),
62
- num_gpus = opts.num_gpus,
63
- )
64
-
65
- #----------------------------------------------------------------------------
66
-
67
- def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
68
- metric = result_dict['metric']
69
- assert is_valid_metric(metric)
70
- if run_dir is not None and snapshot_pkl is not None:
71
- snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
72
-
73
- jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
74
- print(jsonl_line)
75
- if run_dir is not None and os.path.isdir(run_dir):
76
- with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
77
- f.write(jsonl_line + '\n')
78
-
79
- #----------------------------------------------------------------------------
80
- # Primary metrics.
81
-
82
- @register_metric
83
- def fid50k_full(opts):
84
- opts.dataset_kwargs.update(max_size=None, xflip=False)
85
- fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
86
- return dict(fid50k_full=fid)
87
-
88
- @register_metric
89
- def kid50k_full(opts):
90
- opts.dataset_kwargs.update(max_size=None, xflip=False)
91
- kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
92
- return dict(kid50k_full=kid)
93
-
94
- @register_metric
95
- def pr50k3_full(opts):
96
- opts.dataset_kwargs.update(max_size=None, xflip=False)
97
- precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
98
- return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
99
-
100
- @register_metric
101
- def ppl2_wend(opts):
102
- ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
103
- return dict(ppl2_wend=ppl)
104
-
105
- @register_metric
106
- def is50k(opts):
107
- opts.dataset_kwargs.update(max_size=None, xflip=False)
108
- mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
109
- return dict(is50k_mean=mean, is50k_std=std)
110
-
111
- #----------------------------------------------------------------------------
112
- # Legacy metrics.
113
-
114
- @register_metric
115
- def fid50k(opts):
116
- opts.dataset_kwargs.update(max_size=None)
117
- fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
118
- return dict(fid50k=fid)
119
-
120
- @register_metric
121
- def kid50k(opts):
122
- opts.dataset_kwargs.update(max_size=None)
123
- kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
124
- return dict(kid50k=kid)
125
-
126
- @register_metric
127
- def pr50k3(opts):
128
- opts.dataset_kwargs.update(max_size=None)
129
- precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
130
- return dict(pr50k3_precision=precision, pr50k3_recall=recall)
131
-
132
- @register_metric
133
- def ppl_zfull(opts):
134
- ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, batch_size=2)
135
- return dict(ppl_zfull=ppl)
136
-
137
- @register_metric
138
- def ppl_wfull(opts):
139
- ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, batch_size=2)
140
- return dict(ppl_wfull=ppl)
141
-
142
- @register_metric
143
- def ppl_zend(opts):
144
- ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, batch_size=2)
145
- return dict(ppl_zend=ppl)
146
-
147
- @register_metric
148
- def ppl_wend(opts):
149
- ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, batch_size=2)
150
- return dict(ppl_wend=ppl)
151
-
152
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
metrics/metric_utils.py DELETED
@@ -1,275 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- import os
10
- import time
11
- import hashlib
12
- import pickle
13
- import copy
14
- import uuid
15
- import numpy as np
16
- import torch
17
- import dnnlib
18
-
19
- #----------------------------------------------------------------------------
20
-
21
- class MetricOptions:
22
- def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
23
- assert 0 <= rank < num_gpus
24
- self.G = G
25
- self.G_kwargs = dnnlib.EasyDict(G_kwargs)
26
- self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
27
- self.num_gpus = num_gpus
28
- self.rank = rank
29
- self.device = device if device is not None else torch.device('cuda', rank)
30
- self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
31
- self.cache = cache
32
-
33
- #----------------------------------------------------------------------------
34
-
35
- _feature_detector_cache = dict()
36
-
37
- def get_feature_detector_name(url):
38
- return os.path.splitext(url.split('/')[-1])[0]
39
-
40
- def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
41
- assert 0 <= rank < num_gpus
42
- key = (url, device)
43
- if key not in _feature_detector_cache:
44
- is_leader = (rank == 0)
45
- if not is_leader and num_gpus > 1:
46
- torch.distributed.barrier() # leader goes first
47
- with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
48
- _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
49
- if is_leader and num_gpus > 1:
50
- torch.distributed.barrier() # others follow
51
- return _feature_detector_cache[key]
52
-
53
- #----------------------------------------------------------------------------
54
-
55
- class FeatureStats:
56
- def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
57
- self.capture_all = capture_all
58
- self.capture_mean_cov = capture_mean_cov
59
- self.max_items = max_items
60
- self.num_items = 0
61
- self.num_features = None
62
- self.all_features = None
63
- self.raw_mean = None
64
- self.raw_cov = None
65
-
66
- def set_num_features(self, num_features):
67
- if self.num_features is not None:
68
- assert num_features == self.num_features
69
- else:
70
- self.num_features = num_features
71
- self.all_features = []
72
- self.raw_mean = np.zeros([num_features], dtype=np.float64)
73
- self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
74
-
75
- def is_full(self):
76
- return (self.max_items is not None) and (self.num_items >= self.max_items)
77
-
78
- def append(self, x):
79
- x = np.asarray(x, dtype=np.float32)
80
- assert x.ndim == 2
81
- if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
82
- if self.num_items >= self.max_items:
83
- return
84
- x = x[:self.max_items - self.num_items]
85
-
86
- self.set_num_features(x.shape[1])
87
- self.num_items += x.shape[0]
88
- if self.capture_all:
89
- self.all_features.append(x)
90
- if self.capture_mean_cov:
91
- x64 = x.astype(np.float64)
92
- self.raw_mean += x64.sum(axis=0)
93
- self.raw_cov += x64.T @ x64
94
-
95
- def append_torch(self, x, num_gpus=1, rank=0):
96
- assert isinstance(x, torch.Tensor) and x.ndim == 2
97
- assert 0 <= rank < num_gpus
98
- if num_gpus > 1:
99
- ys = []
100
- for src in range(num_gpus):
101
- y = x.clone()
102
- torch.distributed.broadcast(y, src=src)
103
- ys.append(y)
104
- x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
105
- self.append(x.cpu().numpy())
106
-
107
- def get_all(self):
108
- assert self.capture_all
109
- return np.concatenate(self.all_features, axis=0)
110
-
111
- def get_all_torch(self):
112
- return torch.from_numpy(self.get_all())
113
-
114
- def get_mean_cov(self):
115
- assert self.capture_mean_cov
116
- mean = self.raw_mean / self.num_items
117
- cov = self.raw_cov / self.num_items
118
- cov = cov - np.outer(mean, mean)
119
- return mean, cov
120
-
121
- def save(self, pkl_file):
122
- with open(pkl_file, 'wb') as f:
123
- pickle.dump(self.__dict__, f)
124
-
125
- @staticmethod
126
- def load(pkl_file):
127
- with open(pkl_file, 'rb') as f:
128
- s = dnnlib.EasyDict(pickle.load(f))
129
- obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
130
- obj.__dict__.update(s)
131
- return obj
132
-
133
- #----------------------------------------------------------------------------
134
-
135
- class ProgressMonitor:
136
- def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
137
- self.tag = tag
138
- self.num_items = num_items
139
- self.verbose = verbose
140
- self.flush_interval = flush_interval
141
- self.progress_fn = progress_fn
142
- self.pfn_lo = pfn_lo
143
- self.pfn_hi = pfn_hi
144
- self.pfn_total = pfn_total
145
- self.start_time = time.time()
146
- self.batch_time = self.start_time
147
- self.batch_items = 0
148
- if self.progress_fn is not None:
149
- self.progress_fn(self.pfn_lo, self.pfn_total)
150
-
151
- def update(self, cur_items):
152
- assert (self.num_items is None) or (cur_items <= self.num_items)
153
- if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
154
- return
155
- cur_time = time.time()
156
- total_time = cur_time - self.start_time
157
- time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
158
- if (self.verbose) and (self.tag is not None):
159
- print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
160
- self.batch_time = cur_time
161
- self.batch_items = cur_items
162
-
163
- if (self.progress_fn is not None) and (self.num_items is not None):
164
- self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
165
-
166
- def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
167
- return ProgressMonitor(
168
- tag = tag,
169
- num_items = num_items,
170
- flush_interval = flush_interval,
171
- verbose = self.verbose,
172
- progress_fn = self.progress_fn,
173
- pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
174
- pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
175
- pfn_total = self.pfn_total,
176
- )
177
-
178
- #----------------------------------------------------------------------------
179
-
180
- def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
181
- dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
182
- if data_loader_kwargs is None:
183
- data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
184
-
185
- # Try to lookup from cache.
186
- cache_file = None
187
- if opts.cache:
188
- # Choose cache file name.
189
- args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
190
- md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
191
- cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
192
- cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
193
-
194
- # Check if the file exists (all processes must agree).
195
- flag = os.path.isfile(cache_file) if opts.rank == 0 else False
196
- if opts.num_gpus > 1:
197
- flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
198
- torch.distributed.broadcast(tensor=flag, src=0)
199
- flag = (float(flag.cpu()) != 0)
200
-
201
- # Load.
202
- if flag:
203
- return FeatureStats.load(cache_file)
204
-
205
- # Initialize.
206
- num_items = len(dataset)
207
- if max_items is not None:
208
- num_items = min(num_items, max_items)
209
- stats = FeatureStats(max_items=num_items, **stats_kwargs)
210
- progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
211
- detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
212
-
213
- # Main loop.
214
- item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
215
- for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
216
- if images.shape[1] == 1:
217
- images = images.repeat([1, 3, 1, 1])
218
- features = detector(images.to(opts.device), **detector_kwargs)
219
- stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
220
- progress.update(stats.num_items)
221
-
222
- # Save to cache.
223
- if cache_file is not None and opts.rank == 0:
224
- os.makedirs(os.path.dirname(cache_file), exist_ok=True)
225
- temp_file = cache_file + '.' + uuid.uuid4().hex
226
- stats.save(temp_file)
227
- os.replace(temp_file, cache_file) # atomic
228
- return stats
229
-
230
- #----------------------------------------------------------------------------
231
-
232
- def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, **stats_kwargs):
233
- if batch_gen is None:
234
- batch_gen = min(batch_size, 4)
235
- assert batch_size % batch_gen == 0
236
-
237
- # Setup generator and load labels.
238
- G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
239
- dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
240
-
241
- # Image generation func.
242
- def run_generator(z, c):
243
- img = G(z=z, c=c, **opts.G_kwargs)
244
- img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
245
- return img
246
-
247
- # JIT.
248
- if jit:
249
- z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
250
- c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
251
- run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False)
252
-
253
- # Initialize.
254
- stats = FeatureStats(**stats_kwargs)
255
- assert stats.max_items is not None
256
- progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
257
- detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
258
-
259
- # Main loop.
260
- while not stats.is_full():
261
- images = []
262
- for _i in range(batch_size // batch_gen):
263
- z = torch.randn([batch_gen, G.z_dim], device=opts.device)
264
- c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
265
- c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
266
- images.append(run_generator(z, c))
267
- images = torch.cat(images)
268
- if images.shape[1] == 1:
269
- images = images.repeat([1, 3, 1, 1])
270
- features = detector(images, **detector_kwargs)
271
- stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
272
- progress.update(stats.num_items)
273
- return stats
274
-
275
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
metrics/perceptual_path_length.py DELETED
@@ -1,131 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Perceptual Path Length (PPL) from the paper "A Style-Based Generator
10
- Architecture for Generative Adversarial Networks". Matches the original
11
- implementation by Karras et al. at
12
- https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
13
-
14
- import copy
15
- import numpy as np
16
- import torch
17
- import dnnlib
18
- from . import metric_utils
19
-
20
- #----------------------------------------------------------------------------
21
-
22
- # Spherical interpolation of a batch of vectors.
23
- def slerp(a, b, t):
24
- a = a / a.norm(dim=-1, keepdim=True)
25
- b = b / b.norm(dim=-1, keepdim=True)
26
- d = (a * b).sum(dim=-1, keepdim=True)
27
- p = t * torch.acos(d)
28
- c = b - d * a
29
- c = c / c.norm(dim=-1, keepdim=True)
30
- d = a * torch.cos(p) + c * torch.sin(p)
31
- d = d / d.norm(dim=-1, keepdim=True)
32
- return d
33
-
34
- #----------------------------------------------------------------------------
35
-
36
- class PPLSampler(torch.nn.Module):
37
- def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
38
- assert space in ['z', 'w']
39
- assert sampling in ['full', 'end']
40
- super().__init__()
41
- self.G = copy.deepcopy(G)
42
- self.G_kwargs = G_kwargs
43
- self.epsilon = epsilon
44
- self.space = space
45
- self.sampling = sampling
46
- self.crop = crop
47
- self.vgg16 = copy.deepcopy(vgg16)
48
-
49
- def forward(self, c):
50
- # Generate random latents and interpolation t-values.
51
- t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
52
- z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
53
-
54
- # Interpolate in W or Z.
55
- if self.space == 'w':
56
- w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
57
- wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
58
- wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
59
- else: # space == 'z'
60
- zt0 = slerp(z0, z1, t.unsqueeze(1))
61
- zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
62
- wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
63
-
64
- # Randomize noise buffers.
65
- for name, buf in self.G.named_buffers():
66
- if name.endswith('.noise_const'):
67
- buf.copy_(torch.randn_like(buf))
68
-
69
- # Generate images.
70
- img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
71
-
72
- # Center crop.
73
- if self.crop:
74
- assert img.shape[2] == img.shape[3]
75
- c = img.shape[2] // 8
76
- img = img[:, :, c*3 : c*7, c*2 : c*6]
77
-
78
- # Downsample to 256x256.
79
- factor = self.G.img_resolution // 256
80
- if factor > 1:
81
- img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
82
-
83
- # Scale dynamic range from [-1,1] to [0,255].
84
- img = (img + 1) * (255 / 2)
85
- if self.G.img_channels == 1:
86
- img = img.repeat([1, 3, 1, 1])
87
-
88
- # Evaluate differential LPIPS.
89
- lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
90
- dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
91
- return dist
92
-
93
- #----------------------------------------------------------------------------
94
-
95
- def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False):
96
- dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
97
- vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
98
- vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
99
-
100
- # Setup sampler.
101
- sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
102
- sampler.eval().requires_grad_(False).to(opts.device)
103
- if jit:
104
- c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
105
- sampler = torch.jit.trace(sampler, [c], check_trace=False)
106
-
107
- # Sampling loop.
108
- dist = []
109
- progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
110
- for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
111
- progress.update(batch_start)
112
- c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
113
- c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
114
- x = sampler(c)
115
- for src in range(opts.num_gpus):
116
- y = x.clone()
117
- if opts.num_gpus > 1:
118
- torch.distributed.broadcast(y, src=src)
119
- dist.append(y)
120
- progress.update(num_samples)
121
-
122
- # Compute PPL.
123
- if opts.rank != 0:
124
- return float('nan')
125
- dist = torch.cat(dist)[:num_samples].cpu().numpy()
126
- lo = np.percentile(dist, 1, interpolation='lower')
127
- hi = np.percentile(dist, 99, interpolation='higher')
128
- ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
129
- return float(ppl)
130
-
131
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
metrics/precision_recall.py DELETED
@@ -1,62 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Precision/Recall (PR) from the paper "Improved Precision and Recall
10
- Metric for Assessing Generative Models". Matches the original implementation
11
- by Kynkaanniemi et al. at
12
- https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
13
-
14
- import torch
15
- from . import metric_utils
16
-
17
- #----------------------------------------------------------------------------
18
-
19
- def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
20
- assert 0 <= rank < num_gpus
21
- num_cols = col_features.shape[0]
22
- num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
23
- col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
24
- dist_batches = []
25
- for col_batch in col_batches[rank :: num_gpus]:
26
- dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
27
- for src in range(num_gpus):
28
- dist_broadcast = dist_batch.clone()
29
- if num_gpus > 1:
30
- torch.distributed.broadcast(dist_broadcast, src=src)
31
- dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
32
- return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
33
-
34
- #----------------------------------------------------------------------------
35
-
36
- def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
37
- detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
38
- detector_kwargs = dict(return_features=True)
39
-
40
- real_features = metric_utils.compute_feature_stats_for_dataset(
41
- opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
42
- rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
43
-
44
- gen_features = metric_utils.compute_feature_stats_for_generator(
45
- opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
46
- rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
47
-
48
- results = dict()
49
- for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
50
- kth = []
51
- for manifold_batch in manifold.split(row_batch_size):
52
- dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
53
- kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
54
- kth = torch.cat(kth) if opts.rank == 0 else None
55
- pred = []
56
- for probes_batch in probes.split(row_batch_size):
57
- dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
58
- pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
59
- results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
60
- return results['precision'], results['recall']
61
-
62
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pretrained/.DS_Store ADDED
Binary file (6.15 kB). View file
 
pretrained/ffhq.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a205a346e86a9ddaae702e118097d014b7b8bd719491396a162cca438f2f524c
3
+ size 381624121
pretrained/metfaces.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:880a460d011a3696c088f58f5844b44271b17903963f2671f96f72dfbce5f76f
3
+ size 381624133
model_ir_se50.pth → pretrained/model_ir_se50.pth RENAMED
File without changes