"""Train a GAN using the techniques described in the paper |
"Training Generative Adversarial Networks with Limited Data".""" |
import os |
import click |
import re |
import json |
import tempfile |
import torch |
import dnnlib |
from training import training_loop |
from metrics import metric_main |
from torch_utils import training_stats |
from torch_utils import custom_ops |
class UserError(Exception):
    """Signals an invalid combination or value of user-supplied options.

    Raised while assembling the training configuration so that the CLI
    entry point can report the problem as a usage error instead of a
    traceback.
    """
def setup_training_loop_kwargs(
    # General options.
    gpus = None, # Number of GPUs: <int>, None -> 1; must be a power of two
    snap = None, # Snapshot interval in ticks: <int>, None -> 50
    metrics = None, # List of metric names, None -> ['fid50k_full']
    seed = None, # Random seed: <int>, None -> 0

    # Dataset.
    data = None, # Training dataset path (required): <str>
    cond = None, # Train a label-conditional model: <bool>, None -> False
    subset = None, # Train with only N images: <int>, None -> whole dataset
    mirror = None, # Enable dataset x-flips: <bool>, None -> False

    # Base config.
    cfg = None, # Key of cfg_specs below, None -> 'auto'
    gamma = None, # Override R1 gamma: <float>
    kimg = None, # Override training duration (kimg): <int>
    batch = None, # Override total batch size: <int>, must be divisible by gpus

    # Augmentation.
    aug = None, # 'ada', 'noaug' or 'fixed'; None -> 'ada'
    p = None, # Augmentation probability, only accepted with aug='fixed': <float>
    target = None, # Override ada_target, only accepted with aug='ada': <float>

    # Transfer learning.
    resume = None, # Key of resume_specs, a custom pickle path/URL, or 'noresume'; None -> 'noresume'
    freezed = None, # Stored as D block_kwargs.freeze_layers: <int>

    # Performance options.
    fp32 = None, # Disable FP16 layers and conv clamping: <bool>, None -> False
    nhwc = None, # Set fp16_channels_last on G and D: <bool>, None -> False
    allow_tf32 = None, # Set args.allow_tf32: <bool>, None -> False
    nobench = None, # Set args.cudnn_benchmark = False: <bool>, None -> False
    workers = None, # Override DataLoader num_workers: <int>

    # Contrastive (InsGen) options.
    no_insgen = False, # True switches the contrastive loss and heads off
    rqs = None, # Queue size for DHead_kwargs, None -> 5% of the dataset size
    fqs = None, # Queue size for GHead_kwargs, None -> 5% of the dataset size
    no_cl_on_g = False, # Forwarded verbatim as args.no_cl_on_g
    ada_linear = False, # Forwarded verbatim as args.ada_linear

    # Experiment tag and options forwarded to training.augment.AugmentPipe.
    exp = None, # Free-form experiment id appended to desc
    daug = 'ADA', # Passed to AugmentPipe as `aug`
    beta_schedule = None, # Passed to AugmentPipe as `beta_schedule`
    beta_start = None, # Passed to AugmentPipe as `beta_start`
    beta_end = None, # Passed to AugmentPipe as `beta_end`
    t_min = None, # Passed to AugmentPipe as `t_min`
    t_max = None, # Passed to AugmentPipe as `t_max`
    noise_sd = None, # Passed to AugmentPipe as `noise_std`
    ts_dist = None, # Passed to AugmentPipe as `ts_dist`; also appended to desc
    ada_maxp = None, # Passed to AugmentPipe as `ada_maxp`
):
    """Translate user-facing options into training-loop keyword arguments.

    Validates every option, fills in defaults, and assembles an
    ``dnnlib.EasyDict`` that is later splatted into
    ``training_loop.training_loop``. Along the way a description string is
    built from the dataset name and the options that were explicitly
    overridden; the caller uses it to name the run directory, so the order
    in which suffixes are appended below matters.

    Returns:
        (desc, args): the run description string and the assembled EasyDict.

    Raises:
        UserError: if an option has an invalid value or is combined with an
            incompatible option, or if the dataset cannot be opened.
    """
    args = dnnlib.EasyDict()

    # ------------------------------------------
    # General options: gpus, snap, metrics, seed
    # ------------------------------------------

    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    # x & (x - 1) == 0 holds exactly for powers of two (given x >= 1).
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError('--gpus must be a power of two')
    args.num_gpus = gpus

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError('--snap must be at least 1')
    # The same interval drives both image and network snapshots.
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap

    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
    args.metrics = metrics

    if seed is None:
        seed = 0
    assert isinstance(seed, int)
    args.random_seed = seed

    # -----------------------------------
    # Dataset: data, cond, subset, mirror
    # -----------------------------------

    assert data is not None
    assert isinstance(data, str)
    args.training_set_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False)
    args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=3, prefetch_factor=2)
    try:
        # Instantiate the dataset once to read its resolution, label
        # availability, size and name; the object itself is discarded.
        training_set = dnnlib.util.construct_class_by_name(**args.training_set_kwargs)
        args.training_set_kwargs.resolution = training_set.resolution
        args.training_set_kwargs.use_labels = training_set.has_labels
        args.training_set_kwargs.max_size = len(training_set)
        desc = training_set.name
        del training_set
    except IOError as err:
        raise UserError(f'--data: {err}')

    if exp is not None:
        desc += f'-{exp}'

    if cond is None:
        cond = False
    assert isinstance(cond, bool)
    if cond:
        if not args.training_set_kwargs.use_labels:
            raise UserError('--cond=True requires labels specified in dataset.json')
        desc += '-cond'
    else:
        # Unconditional training ignores labels even if the dataset has them.
        args.training_set_kwargs.use_labels = False

    if subset is not None:
        assert isinstance(subset, int)
        if not 1 <= subset <= args.training_set_kwargs.max_size:
            raise UserError(f'--subset must be between 1 and {args.training_set_kwargs.max_size}')
        desc += f'-subset{subset}'
        if subset < args.training_set_kwargs.max_size:
            args.training_set_kwargs.max_size = subset
            args.training_set_kwargs.random_seed = args.random_seed

    if mirror is None:
        mirror = False
    assert isinstance(mirror, bool)
    if mirror:
        desc += '-mirror'
        args.training_set_kwargs.xflip = True

    # ------------------------------------
    # Base config: cfg, gamma, kimg, batch
    # ------------------------------------

    if cfg is None:
        cfg = 'auto'
    assert isinstance(cfg, str)
    desc += f'-{cfg}'

    # Per-config presets. Fields (as consumed further down):
    #   ref_gpus -> divisor for the per-GPU batch, kimg -> total_kimg,
    #   mb -> batch_size, mbstd -> D epilogue mbstd_group_size,
    #   fmaps -> multiplier for channel_base, lrate -> Adam learning rate,
    #   gamma -> r1_gamma, ema -> ema_kimg, ramp -> ema_rampup,
    #   map -> number of G mapping layers.
    # The -1 entries of 'auto' are placeholders filled in below.
    cfg_specs = {
        'auto': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2),
        'stylegan2': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8),
        'paper256': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1, ema=20, ramp=None, map=8),
        'paper512': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8),
        'paper1024': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=2, ema=10, ramp=None, map=8),
        'cifar': dict(ref_gpus=4, kimg=100000, mb=64, mbstd=32, fmaps=1, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2),
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == 'auto':
        # 'auto' depends on the GPU count, so the count becomes part of desc.
        desc += f'{gpus:d}'
        spec.ref_gpus = gpus
        res = args.training_set_kwargs.resolution
        # Batch size grows with the GPU count, shrinks with resolution, and
        # is clamped to the range [gpus, 64].
        spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus)
        spec.mbstd = min(spec.mb // gpus, 4)
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res ** 2) / spec.mb
        spec.ema = spec.mb * 10 / 32

    args.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator', z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict(), synthesis_kwargs=dnnlib.EasyDict())
    args.D_kwargs = dnnlib.EasyDict(class_name='training.networks.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict())
    # G synthesis and D share channel counts, FP16 layer count and clamp.
    args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(spec.fmaps * 32768)
    args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
    args.G_kwargs.mapping_kwargs.num_layers = spec.map
    args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4
    args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256
    args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd
    args.D_kwargs.mapping_kwargs.num_layers = 0

    args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
    args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
    args.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma)

    args.total_kimg = spec.kimg
    args.batch_size = spec.mb
    # Per-GPU batch is derived from the preset's reference GPU count, not
    # from the actual `gpus` value (they coincide only for 'auto').
    args.batch_gpu = spec.mb // spec.ref_gpus
    args.ema_kimg = spec.ema
    args.ema_rampup = spec.ramp

    if cfg == 'cifar':
        args.loss_kwargs.pl_weight = 0
        args.loss_kwargs.style_mixing_prob = 0
        args.D_kwargs.architecture = 'orig'

    if gamma is not None:
        assert isinstance(gamma, float)
        if not gamma >= 0:
            raise UserError('--gamma must be non-negative')
        desc += f'-gamma{gamma:g}'
        args.loss_kwargs.r1_gamma = gamma

    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError('--kimg must be at least 1')
        desc += f'-kimg{kimg:d}'
        args.total_kimg = kimg

    if batch is not None:
        assert isinstance(batch, int)
        if not (batch >= 1 and batch % gpus == 0):
            raise UserError('--batch must be at least 1 and divisible by --gpus')
        desc += f'-batch{batch}'
        args.batch_size = batch
        # An explicit batch size is split over the actual GPU count.
        args.batch_gpu = batch // gpus

    # ---------------------------------------------------
    # Augmentation: aug, p, target, diffusion parameters
    # ---------------------------------------------------

    if aug is None:
        aug = 'ada'
    else:
        # Only an explicitly chosen mode is recorded in desc.
        assert isinstance(aug, str)
        desc += f'-{aug}'

    if aug == 'ada':
        args.ada_target = 0.6

    elif aug == 'noaug':
        pass

    elif aug == 'fixed':
        if p is None:
            raise UserError(f'--aug={aug} requires specifying --p')

    else:
        raise UserError(f'--aug={aug} not supported')

    if p is not None:
        assert isinstance(p, float)
        if aug != 'fixed':
            raise UserError('--p can only be specified with --aug=fixed')
        if not 0 <= p <= 1:
            raise UserError('--p must be between 0 and 1')
        desc += f'-p{p:g}'
        args.augment_p = p

    if target is not None:
        assert isinstance(target, float)
        if aug != 'ada':
            raise UserError('--target can only be specified with --aug=ada')
        if not 0 <= target <= 1:
            raise UserError('--target must be between 0 and 1')
        desc += f'-target{target:g}'
        args.ada_target = target

    # Keyword arguments forwarded to AugmentPipe. Note the renames:
    # noise_sd -> noise_std and daug -> aug.
    diffusion_specs = dict(beta_schedule=beta_schedule, beta_start=beta_start, beta_end=beta_end,
                           t_min=t_min, t_max=t_max, noise_std=noise_sd,
                           aug=daug, ada_maxp=ada_maxp, ts_dist=ts_dist)
    # ts_dist is appended unconditionally, so desc contains "None" here when
    # the caller leaves it unset.
    desc += f"-ts_dist-{ts_dist}"

    if aug != 'noaug':
        args.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', **diffusion_specs)

    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------

    # Short names for pretrained source networks.
    resume_specs = {
        'ffhq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = 'noresume'
    elif resume == 'noresume':
        desc += '-noresume'
    elif resume in resume_specs:
        desc += f'-resume{resume}'
        args.resume_pkl = resume_specs[resume] # predefined URL
    else:
        # Anything else is taken verbatim as a pickle path or URL.
        desc += '-resumecustom'
        args.resume_pkl = resume

    if resume != 'noresume':
        args.ada_kimg = 100
        args.ema_rampup = None # overrides the preset's ramp when resuming
    # NOTE(review): this repeats the assignment made in the resume branch
    # above. It is kept unconditional here (so ada_kimg is always 100) --
    # confirm that this, rather than a duplicated line inside the branch,
    # is the intended layout.
    args.ada_kimg = 100

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError('--freezed must be non-negative')
        desc += f'-freezed{freezed:d}'
        args.D_kwargs.block_kwargs.freeze_layers = freezed

    # -------------------------------------------------
    # Performance options: fp32, nhwc, nobench, workers
    # -------------------------------------------------

    if fp32 is None:
        fp32 = False
    assert isinstance(fp32, bool)
    if fp32:
        # No FP16 resolutions and no clamping in either network.
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None

    if nhwc is None:
        nhwc = False
    assert isinstance(nhwc, bool)
    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True

    if nobench is None:
        nobench = False
    assert isinstance(nobench, bool)
    if nobench:
        args.cudnn_benchmark = False

    if allow_tf32 is None:
        allow_tf32 = False
    assert isinstance(allow_tf32, bool)
    if allow_tf32:
        args.allow_tf32 = True

    if workers is not None:
        assert isinstance(workers, int)
        if not workers >= 1:
            raise UserError('--workers must be at least 1')
        args.data_loader_kwargs.num_workers = workers

    # ----------------------------------------------------------
    # Contrastive learning: no_insgen, rqs, fqs, no_cl_on_g, ...
    # ----------------------------------------------------------

    # Contrastive training is on unless no_insgen is explicitly True; a None
    # value (option not given) leaves it enabled.
    use_insgen = True
    if no_insgen is not None:
        assert isinstance(no_insgen, bool)
        use_insgen = not no_insgen

    if use_insgen:
        # Swap in the contrastive loss and create one head per network.
        args.loss_kwargs.class_name = 'training.contrastive_loss.StyleGAN2LossCL'
        args.DHead_kwargs = dnnlib.EasyDict(class_name='training.contrastive_head.CLHead', inplanes=512, temperature=0.2, momentum=0.999, queue_size=-1)
        args.GHead_kwargs = dnnlib.EasyDict(class_name='training.contrastive_head.CLHead', inplanes=512, temperature=0.2, momentum=0.999, queue_size=-1)
        # Default queue size: 5% of the (possibly subset-limited) dataset
        # size, doubled when x-flips are enabled.
        default_queue_size = int(0.05 * args.training_set_kwargs.max_size)
        if args.training_set_kwargs.xflip:
            default_queue_size *= 2
        args.DHead_kwargs.queue_size = default_queue_size if rqs is None else rqs
        args.GHead_kwargs.queue_size = default_queue_size if fqs is None else fqs

        if no_cl_on_g is not None:
            assert isinstance(no_cl_on_g, bool)
            args.no_cl_on_g = no_cl_on_g
        if ada_linear is not None:
            assert isinstance(ada_linear, bool)
            args.ada_linear = ada_linear
        args.cl_loss_weight = dnnlib.EasyDict(lw_real_cl=1.0, lw_fake_cl=1.0, lw_fake_cl_on_g=0.1)
    else:
        args.DHead_kwargs = None
        args.GHead_kwargs = None

    return desc, args
def subprocess_fn(rank, args, temp_dir):
    """Per-process entry point: set up logging and distribution, then train.

    Args:
        rank: Index of this process; also used as its CUDA device index
            when more than one GPU is in use.
        args: Training options; must provide ``run_dir`` and ``num_gpus`` and
            is splatted into ``training_loop.training_loop``.
        temp_dir: Directory holding the file used to initialise
            ``torch.distributed`` across processes.
    """
    log_path = os.path.join(args.run_dir, 'log.txt')
    dnnlib.util.Logger(file_name=log_path, file_mode='a', should_flush=True)

    multi_gpu = args.num_gpus > 1

    # Init torch.distributed through a shared file inside temp_dir.
    if multi_gpu:
        init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
        if os.name == 'nt':
            # Windows: gloo backend, URL with forward slashes.
            backend = 'gloo'
            init_method = 'file:///' + init_file.replace('\\', '/')
        else:
            backend = 'nccl'
            init_method = f'file://{init_file}'
        torch.distributed.init_process_group(
            backend=backend,
            init_method=init_method,
            rank=rank,
            world_size=args.num_gpus,
        )

    # Init torch_utils; statistics are synchronised only in multi-GPU runs.
    if multi_gpu:
        sync_device = torch.device('cuda', rank)
    else:
        sync_device = None
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)

    # Only rank 0 keeps custom-op build output enabled.
    if rank != 0:
        custom_ops.verbosity = 'none'

    # Execute the training loop.
    training_loop.training_loop(rank=rank, **args)
class CommaSeparatedList(click.ParamType):
    """Click parameter type turning ``'a,b,c'`` into ``['a', 'b', 'c']``.

    An empty string, ``None`` and the word ``none`` (any letter case) all
    convert to an empty list.
    """
    name = 'list'

    def convert(self, value, param, ctx):
        """Convert a raw option value into a list of strings.

        Args:
            value: The raw string from the command line, or a value that has
                already been converted (list/tuple), or ``None``.
            param: Unused; part of the click ``ParamType`` interface.
            ctx: Unused; part of the click ``ParamType`` interface.

        Returns:
            A list of the comma-separated items (no whitespace stripping).
        """
        _ = param, ctx
        # click may hand convert() a value that is already of the target
        # type (e.g. a default); pass it through instead of calling
        # str methods on it.
        if isinstance(value, (list, tuple)):
            return list(value)
        if value is None or value.lower() == 'none' or value == '':
            return []
        return value.split(',')
# NOTE: main()'s docstring is rendered by click as the --help text, so
# developer-facing notes live in comments rather than in the docstring.
#
# main() builds the training configuration from the command line, picks (or
# re-uses) a numbered run directory under --outdir, prints and saves the
# options, and launches one training process per GPU.
@click.command()
@click.pass_context

# General options.
@click.option('--outdir', help='Where to save the results', required=True, metavar='DIR')
@click.option('--gpus', help='Number of GPUs to use [default: 1]', type=int, metavar='INT')
@click.option('--snap', help='Snapshot interval [default: 50 ticks]', type=int, metavar='INT')
@click.option('--metrics', help='Comma-separated list or "none" [default: fid50k_full]', type=CommaSeparatedList())
@click.option('--seed', help='Random seed [default: 0]', type=int, metavar='INT')
@click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)
@click.option('--exp', help='exp id', type=str)

# Dataset.
@click.option('--data', help='Training data (directory or zip)', metavar='PATH', required=True)
@click.option('--cond', help='Train conditional model based on dataset labels [default: false]', type=bool, metavar='BOOL')
@click.option('--subset', help='Train with only N images [default: all]', type=int, metavar='INT')
# The default here is truthy, so the help text must say "true".
@click.option('--mirror', help='Enable dataset x-flips [default: true]', type=bool, metavar='BOOL', default=True)

# Base config.
@click.option('--cfg', help='Base config [default: auto]', type=click.Choice(['auto', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar']))
@click.option('--gamma', help='Override R1 gamma', type=float)
@click.option('--kimg', help='Override training duration', type=int, metavar='INT')
@click.option('--batch', help='Override batch size', type=int, metavar='INT')

# Augmentation and forward-diffusion options.
@click.option('--aug', help='Augmentation mode [default: ada]', type=click.Choice(['noaug', 'ada', 'fixed']))
@click.option('--daug', help='Augmentation mode [default: ada]', type=click.Choice(['NO', 'ADA', 'DIFF']), default='ADA')
@click.option('--p', help='Augmentation probability for --aug=fixed', type=float)
@click.option('--beta_schedule', help='Forward diffusion beta schedule (we use linear always)', type=str, default='linear')
@click.option('--beta_start', help='Forward diffusion process beta_start', type=float, default=1e-4)
@click.option('--beta_end', help='Forward diffusion process beta_end', type=float, default=2e-2)
@click.option('--t_min', help='Minimum # of timesteps for adaptively modification', type=int, default=10)
@click.option('--t_max', help='Maximum # of timesteps for adaptively modification', type=int, default=500)
@click.option('--noise_sd', help='Diffusion noise standard deviation', type=float, default=0.05)
@click.option('--ts_dist', help='Diffusion t sampling way', type=click.Choice(['priority', 'uniform']), default='uniform')
# No click-level default: the 0.6 default is applied inside main() only when
# the augmentation mode is ADA, because --target is rejected for other modes.
@click.option('--target', help='Discriminator target value for --aug=ada [default: 0.6]', type=float)

# Transfer learning.
@click.option('--resume', help='Resume training [default: noresume]', metavar='PKL')
@click.option('--freezed', help='Freeze-D [default: 0 layers]', type=int, metavar='INT')

# Performance options.
@click.option('--fp32', help='Disable mixed-precision training', type=bool, metavar='BOOL')
@click.option('--nhwc', help='Use NHWC memory format with FP16', type=bool, metavar='BOOL')
@click.option('--nobench', help='Disable cuDNN benchmarking', type=bool, metavar='BOOL')
@click.option('--allow-tf32', help='Allow PyTorch to use TF32 internally', type=bool, metavar='BOOL')
@click.option('--workers', help='Override number of DataLoader workers', type=int, metavar='INT')

# Contrastive-learning (InsGen) options.
@click.option('--no_insgen', help='Disable InsGen back to ADA [default: False]', type=bool, metavar='BOOL')
@click.option('--rqs', help='Size of real image queue [default: 5% * len(dataset)]', type=int, metavar='INT')
@click.option('--fqs', help='Size of fake image queue [default: 5% * len(dataset)]', type=int, metavar='INT')
@click.option('--no_cl_on_g', help='Disable fake instance discrimination for generator [default: False]', type=bool, metavar='BOOL')
@click.option('--ada_linear', help='Whether to linearly increase the strength of ADA [default: False]', type=bool, metavar='BOOL')

def main(ctx, outdir, dry_run, **config_kwargs):
    """Train a GAN using the techniques described in the paper
    "Training Generative Adversarial Networks with Limited Data".

    Examples:

    \b
    # Train with custom dataset using 1 GPU.
    python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1

    \b
    # Train class-conditional CIFAR-10 using 2 GPUs.
    python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \\
        --gpus=2 --cfg=cifar --cond=1

    \b
    # Transfer learn MetFaces from FFHQ using 4 GPUs.
    python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \\
        --gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10

    \b
    # Reproduce original StyleGAN2 config F.
    python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \\
        --gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug

    \b
    Base configs (--cfg):
      auto       Automatically select reasonable defaults based on resolution
                 and GPU count. Good starting point for new datasets.
      stylegan2  Reproduce results for StyleGAN2 config F at 1024x1024.
      paper256   Reproduce results for FFHQ and LSUN Cat at 256x256.
      paper512   Reproduce results for BreCaHAD and AFHQ at 512x512.
      paper1024  Reproduce results for MetFaces at 1024x1024.
      cifar      Reproduce results for CIFAR-10 at 32x32.

    \b
    Transfer learning source networks (--resume):
      ffhq256        FFHQ trained at 256x256 resolution.
      ffhq512        FFHQ trained at 512x512 resolution.
      ffhq1024       FFHQ trained at 1024x1024 resolution.
      celebahq256    CelebA-HQ trained at 256x256 resolution.
      lsundog256     LSUN Dog trained at 256x256 resolution.
      <PATH or URL>  Custom network pickle.
    """
    dnnlib.util.Logger(should_flush=True)

    # Apply the 0.6 target default only for ADA (the implicit or explicit
    # mode). For ADA runs this is identical to the former click-level
    # default, including the "-target0.6" suffix in the run description.
    # For --aug=noaug / --aug=fixed it avoids the spurious
    # "--target can only be specified with --aug=ada" failure that an
    # always-present default used to trigger.
    if config_kwargs.get('target') is None and config_kwargs.get('aug') in (None, 'ada'):
        config_kwargs['target'] = 0.6

    # Setup training options.
    try:
        run_desc, args = setup_training_loop_kwargs(**config_kwargs)
    except UserError as err:
        ctx.fail(str(err))

    # Pick output directory: re-use an existing "<5 digits>-<run_desc>"
    # directory if there is one, otherwise allocate the next free run id.
    prev_run_dirs = []
    if os.path.isdir(outdir):
        prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]

    # run_desc is literal text (it routinely contains '.', e.g. "gamma0.5"),
    # so it must be escaped before being embedded in the pattern.
    run_dir_pattern = re.compile(r'\d{5}-' + re.escape(run_desc))
    matching_dirs = [x for x in prev_run_dirs if run_dir_pattern.fullmatch(x) is not None]
    if len(matching_dirs) > 0:
        assert len(matching_dirs) == 1, f'Multiple directories found for resuming: {matching_dirs}'
        run_dir = os.path.join(outdir, matching_dirs[0])
    else:
        prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
        prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
        cur_run_id = max(prev_run_ids, default=-1) + 1
        run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{run_desc}')
        assert not os.path.exists(run_dir)
    args.run_dir = run_dir

    # Print options.
    print()
    print('Training options:')
    print(json.dumps(args, indent=2))
    print()
    print(f'Output directory: {args.run_dir}')
    print(f'Training data: {args.training_set_kwargs.path}')
    print(f'Training duration: {args.total_kimg} kimg')
    print(f'Number of GPUs: {args.num_gpus}')
    print(f'Number of images: {args.training_set_kwargs.max_size}')
    print(f'Image resolution: {args.training_set_kwargs.resolution}')
    print(f'Conditional model: {args.training_set_kwargs.use_labels}')
    print(f'Dataset x-flips: {args.training_set_kwargs.xflip}')
    print()

    # Dry run?
    if dry_run:
        print('Dry run; exiting.')
        return

    # Create output directory (exist_ok because a matching run directory
    # may be re-used) and record the options next to the results.
    print('Creating output directory...')
    os.makedirs(args.run_dir, exist_ok=True)
    with open(os.path.join(args.run_dir, 'training_options.json'), 'wt') as f:
        json.dump(args, f, indent=2)

    # Launch processes: in-process for a single GPU, one spawned process
    # per GPU otherwise.
    print('Launching processes...')
    torch.multiprocessing.set_start_method('spawn')
    with tempfile.TemporaryDirectory() as temp_dir:
        if args.num_gpus == 1:
            subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
        else:
            torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
# Script entry point: main is a click command, so calling it without
# arguments makes click parse the command line and supply the parameters.
if __name__ == "__main__":
    main()