nikunjkdtechnoland commited on
Commit
89c278d
·
1 Parent(s): 063372b

init commit some more files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +21 -0
  2. data/dataset.py +75 -0
  3. full-stack-server.py +231 -0
  4. iopaint/cli.py +223 -0
  5. iopaint/const.py +121 -0
  6. iopaint/download.py +294 -0
  7. iopaint/file_manager/file_manager.py +215 -0
  8. iopaint/helper.py +425 -0
  9. iopaint/installer.py +12 -0
  10. iopaint/model/anytext/cldm/cldm.py +630 -0
  11. iopaint/model/anytext/cldm/ddim_hacked.py +486 -0
  12. iopaint/model/anytext/cldm/embedding_manager.py +165 -0
  13. iopaint/model/anytext/cldm/hack.py +111 -0
  14. iopaint/model/anytext/cldm/model.py +40 -0
  15. iopaint/model/anytext/ldm/models/diffusion/ddim.py +354 -0
  16. iopaint/model/anytext/ldm/models/diffusion/ddpm.py +2380 -0
  17. iopaint/model/anytext/ldm/models/diffusion/dpm_solver/dpm_solver.py +1154 -0
  18. iopaint/model/anytext/ldm/modules/diffusionmodules/model.py +973 -0
  19. iopaint/model/anytext/ldm/modules/diffusionmodules/openaimodel.py +786 -0
  20. iopaint/model/anytext/ldm/modules/distributions/distributions.py +92 -0
  21. iopaint/model/anytext/ldm/modules/ema.py +80 -0
  22. iopaint/model/anytext/ldm/modules/encoders/modules.py +411 -0
  23. iopaint/model/anytext/main.py +45 -0
  24. iopaint/model/anytext/ocr_recog/common.py +74 -0
  25. iopaint/model/anytext/ocr_recog/en_dict.txt +95 -0
  26. iopaint/model/controlnet.py +190 -0
  27. iopaint/model/ddim_sampler.py +193 -0
  28. iopaint/model/fcf.py +1737 -0
  29. iopaint/model/helper/controlnet_preprocess.py +68 -0
  30. iopaint/model/helper/cpu_text_encoder.py +41 -0
  31. iopaint/model/helper/g_diffuser_bot.py +167 -0
  32. iopaint/model/instruct_pix2pix.py +64 -0
  33. iopaint/model/kandinsky.py +65 -0
  34. iopaint/model/lama.py +57 -0
  35. iopaint/model/ldm.py +336 -0
  36. iopaint/model/manga.py +97 -0
  37. iopaint/model/mat.py +1945 -0
  38. iopaint/model/mi_gan.py +110 -0
  39. iopaint/model/opencv2.py +29 -0
  40. iopaint/model_manager.py +191 -0
  41. iopaint/plugins/briarmbg.py +512 -0
  42. iopaint/plugins/gfpgan_plugin.py +74 -0
  43. iopaint/plugins/gfpganer.py +84 -0
  44. iopaint/plugins/interactive_seg.py +89 -0
  45. iopaint/plugins/segment_anything/build_sam.py +168 -0
  46. iopaint/plugins/segment_anything/modeling/common.py +43 -0
  47. iopaint/plugins/segment_anything/modeling/image_encoder.py +395 -0
  48. iopaint/plugins/segment_anything/modeling/mask_decoder.py +176 -0
  49. model/networks.py +563 -0
  50. only_gradio_server.py +188 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2021 Du Ang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
data/dataset.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import torch.utils.data as data
3
+ from os import listdir
4
+ from utils.tools import default_loader, is_image_file, normalize
5
+ import os
6
+
7
+ import torchvision.transforms as transforms
8
+
9
+
10
+ class Dataset(data.Dataset):
11
+ def __init__(self, data_path, image_shape, with_subfolder=False, random_crop=True, return_name=False):
12
+ super(Dataset, self).__init__()
13
+ if with_subfolder:
14
+ self.samples = self._find_samples_in_subfolders(data_path)
15
+ else:
16
+ self.samples = [x for x in listdir(data_path) if is_image_file(x)]
17
+ self.data_path = data_path
18
+ self.image_shape = image_shape[:-1]
19
+ self.random_crop = random_crop
20
+ self.return_name = return_name
21
+
22
+ def __getitem__(self, index):
23
+ path = os.path.join(self.data_path, self.samples[index])
24
+ img = default_loader(path)
25
+
26
+ if self.random_crop:
27
+ imgw, imgh = img.size
28
+ if imgh < self.image_shape[0] or imgw < self.image_shape[1]:
29
+ img = transforms.Resize(min(self.image_shape))(img)
30
+ img = transforms.RandomCrop(self.image_shape)(img)
31
+ else:
32
+ img = transforms.Resize(self.image_shape)(img)
33
+ img = transforms.RandomCrop(self.image_shape)(img)
34
+
35
+ img = transforms.ToTensor()(img) # turn the image to a tensor
36
+ img = normalize(img)
37
+
38
+ if self.return_name:
39
+ return self.samples[index], img
40
+ else:
41
+ return img
42
+
43
+ def _find_samples_in_subfolders(self, dir):
44
+ """
45
+ Finds the class folders in a dataset.
46
+ Args:
47
+ dir (string): Root directory path.
48
+ Returns:
49
+ tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
50
+ Ensures:
51
+ No class is a subdirectory of another.
52
+ """
53
+ if sys.version_info >= (3, 5):
54
+ # Faster and available in Python 3.5 and above
55
+ classes = [d.name for d in os.scandir(dir) if d.is_dir()]
56
+ else:
57
+ classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
58
+ classes.sort()
59
+ class_to_idx = {classes[i]: i for i in range(len(classes))}
60
+ samples = []
61
+ for target in sorted(class_to_idx.keys()):
62
+ d = os.path.join(dir, target)
63
+ if not os.path.isdir(d):
64
+ continue
65
+ for root, _, fnames in sorted(os.walk(d)):
66
+ for fname in sorted(fnames):
67
+ if is_image_file(fname):
68
+ path = os.path.join(root, fname)
69
+ # item = (path, class_to_idx[target])
70
+ # samples.append(item)
71
+ samples.append(path)
72
+ return samples
73
+
74
+ def __len__(self):
75
+ return len(self.samples)
full-stack-server.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import base64
3
+ import io
4
+ import uuid
5
+ from ultralytics import YOLO
6
+ import cv2
7
+ import torch
8
+ import numpy as np
9
+ from PIL import Image
10
+ from torchvision import transforms
11
+ import imageio.v2 as imageio
12
+ from trainer import Trainer
13
+ from utils.tools import get_config
14
+ import torch.nn.functional as F
15
+ from iopaint.single_processing import batch_inpaint
16
+ from pathlib import Path
17
+ from flask import Flask, request, jsonify,render_template
18
+ from flask_cors import CORS
19
+
20
+ app = Flask(__name__)
21
+ CORS(app)
22
+
23
+ # set current working directory cache instead of default
24
+ os.environ["TORCH_HOME"] = "./pretrained-model"
25
+ os.environ["HUGGINGFACE_HUB_CACHE"] = "./pretrained-model"
26
+
27
+
28
+ def resize_image(input_image_base64, width=640, height=640):
29
+ """Resizes an image from base64 data and returns the resized image as bytes."""
30
+ try:
31
+ # Decode base64 string to bytes
32
+ input_image_data = base64.b64decode(input_image_base64)
33
+ # Convert bytes to NumPy array
34
+ img = np.frombuffer(input_image_data, dtype=np.uint8)
35
+ # Decode NumPy array as an image
36
+ img = cv2.imdecode(img, cv2.IMREAD_COLOR)
37
+
38
+ # Resize while maintaining the aspect ratio
39
+ shape = img.shape[:2] # current shape [height, width]
40
+ new_shape = (width, height) # the shape to resize to
41
+
42
+ # Scale ratio (new / old)
43
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
44
+ ratio = r, r # width, height ratios
45
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
46
+
47
+ # Resize the image
48
+ im = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
49
+
50
+ # Pad the image
51
+ color = (114, 114, 114) # color used for padding
52
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
53
+ # divide padding into 2 sides
54
+ dw /= 2
55
+ dh /= 2
56
+ # compute padding on all corners
57
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
58
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
59
+ im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
60
+
61
+ # Convert the resized and padded image to bytes
62
+ resized_image_bytes = cv2.imencode('.png', im)[1].tobytes()
63
+ return resized_image_bytes
64
+
65
+ except Exception as e:
66
+ print(f"Error resizing image: {e}")
67
+ return None # Or handle differently as needed
68
+
69
+
70
+ def load_weights(path, device):
71
+ model_weights = torch.load(path)
72
+ return {
73
+ k: v.to(device)
74
+ for k, v in model_weights.items()
75
+ }
76
+
77
+
78
+ # Function to convert image to base64
79
+ def convert_image_to_base64(image):
80
+ # Convert image to bytes
81
+ _, buffer = cv2.imencode('.png', image)
82
+ # Convert bytes to base64
83
+ image_base64 = base64.b64encode(buffer).decode('utf-8')
84
+ return image_base64
85
+
86
+
87
+ def convert_to_base64(image):
88
+ # Read the image file as binary data
89
+ image_data = image.read()
90
+ # Encode the binary data as base64
91
+ base64_encoded = base64.b64encode(image_data).decode('utf-8')
92
+ return base64_encoded
93
+
94
+
95
+ @app.route('/')
96
+ def index():
97
+ return render_template('index.html')
98
+
99
+
100
+ @app.route('/process_images', methods=['POST'])
101
+ def process_images():
102
+ # Static paths
103
+ config_path = Path('configs/config.yaml')
104
+ model_path = Path('pretrained-model/torch_model.p')
105
+
106
+ # Check if the request contains files
107
+ if 'input_image' not in request.files or 'append_image' not in request.files:
108
+ return jsonify({'error': 'No files found'}), 419
109
+
110
+ # Get the objectName from the request or use default "chair" if not provided
111
+ default_class = request.form.get('objectName', 'chair')
112
+
113
+ # Convert the images to base64
114
+ try:
115
+ input_base64 = convert_to_base64(request.files['input_image'])
116
+ append_base64 = convert_to_base64(request.files['append_image'])
117
+ except Exception as e:
118
+ return jsonify({'error': 'Failed to read files'}), 419
119
+
120
+ # Resize input image and get base64 data of resized image
121
+ input_resized_image_bytes = resize_image(input_base64)
122
+
123
+ # Convert resized image bytes to base64
124
+ input_resized_base64 = base64.b64encode(input_resized_image_bytes).decode('utf-8')
125
+
126
+ # Decode the resized image from base64 data directly
127
+ img = cv2.imdecode(np.frombuffer(input_resized_image_bytes, np.uint8), cv2.IMREAD_COLOR)
128
+
129
+ if img is None:
130
+ return jsonify({'error': 'Failed to decode resized image'}), 419
131
+
132
+ H, W, _ = img.shape
133
+ x_point = 0
134
+ y_point = 0
135
+ width = 1
136
+ height = 1
137
+
138
+ # Load a model
139
+ model = YOLO('pretrained-model/yolov8m-seg.pt') # pretrained YOLOv8m-seg model
140
+
141
+ # Run batched inference on a list of images
142
+ results = model(img, imgsz=(W,H), conf=0.5) # chair class 56 with confidence >= 0.5
143
+ names = model.names
144
+ # print(names)
145
+
146
+ class_found = False
147
+ for result in results:
148
+ for i, label in enumerate(result.boxes.cls):
149
+ # Check if the label matches the chair label
150
+ if names[int(label)] == default_class:
151
+ class_found = True
152
+ # Convert the tensor to a numpy array
153
+ chair_mask_np = result.masks.data[i].numpy()
154
+
155
+ kernel = np.ones((5, 5), np.uint8) # Create a 5x5 kernel for dilation
156
+ chair_mask_np = cv2.dilate(chair_mask_np, kernel, iterations=2) # Apply dilation
157
+
158
+ # Find contours to get bounding box
159
+ contours, _ = cv2.findContours((chair_mask_np == 1).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
160
+
161
+ # Iterate over contours to find the bounding box of each object
162
+ for contour in contours:
163
+ x, y, w, h = cv2.boundingRect(contour)
164
+ x_point = x
165
+ y_point = y
166
+ width = w
167
+ height = h
168
+
169
+ # Get the corresponding mask
170
+ mask = result.masks.data[i].numpy() * 255
171
+ dilated_mask = cv2.dilate(mask, kernel, iterations=2) # Apply dilation
172
+ # Resize the mask to match the dimensions of the original image
173
+ resized_mask = cv2.resize(dilated_mask, (img.shape[1], img.shape[0]))
174
+ # Convert mask to base64
175
+ mask_base64 = convert_image_to_base64(resized_mask)
176
+
177
+ # call repainting and merge function
178
+ output_base64 = repaitingAndMerge(append_base64,str(model_path), str(config_path),width, height, x_point, y_point, input_resized_base64, mask_base64)
179
+ # Return the output base64 image in the API response
180
+ return jsonify({'output_base64': output_base64}), 200
181
+
182
+ # return class not found in prediction
183
+ if not class_found:
184
+ return jsonify({'message': f'{default_class} object not found in the image'}), 200
185
+
186
+ def repaitingAndMerge(append_image_base64_image, model_path, config_path, width, height, xposition, yposition, input_base64, mask_base64):
187
+ config = get_config(config_path)
188
+ device = torch.device("cpu")
189
+ trainer = Trainer(config)
190
+ trainer.load_state_dict(load_weights(model_path, device), strict=False)
191
+ trainer.eval()
192
+
193
+ # lama inpainting start
194
+ print("lama inpainting start")
195
+ inpaint_result_base64 = batch_inpaint('lama', 'cpu', input_base64, mask_base64)
196
+ print("lama inpainting end")
197
+
198
+ # Decode base64 to bytes
199
+ inpaint_result_bytes = base64.b64decode(inpaint_result_base64)
200
+
201
+ # Convert bytes to NumPy array
202
+ inpaint_result_np = np.array(Image.open(io.BytesIO(inpaint_result_bytes)))
203
+
204
+ # Create PIL Image from NumPy array
205
+ final_image = Image.fromarray(inpaint_result_np)
206
+
207
+ print("merge start")
208
+ # Decode base64 to binary data
209
+ decoded_image_data = base64.b64decode(append_image_base64_image)
210
+ # Convert binary data to a NumPy array
211
+ append_image = cv2.imdecode(np.frombuffer(decoded_image_data, np.uint8), cv2.IMREAD_UNCHANGED)
212
+ # Resize the append image while preserving transparency
213
+ resized_image = cv2.resize(append_image, (width, height), interpolation=cv2.INTER_AREA)
214
+ # Convert the resized image to RGBA format (assuming it's in BGRA format)
215
+ resized_image = cv2.cvtColor(resized_image, cv2.COLOR_BGRA2RGBA)
216
+ # Create a PIL Image from the resized image with transparent background
217
+ append_image_pil = Image.fromarray(resized_image)
218
+ # Paste the append image onto the final image
219
+ final_image.paste(append_image_pil, (xposition, yposition), append_image_pil)
220
+ # Save the resulting image
221
+ print("merge end")
222
+ # Convert the final image to base64
223
+ with io.BytesIO() as output_buffer:
224
+ final_image.save(output_buffer, format='PNG')
225
+ output_base64 = base64.b64encode(output_buffer.getvalue()).decode('utf-8')
226
+
227
+ return output_base64
228
+
229
+
230
+ if __name__ == '__main__':
231
+ app.run(host='0.0.0.0',debug=True)
iopaint/cli.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import webbrowser
2
+ from contextlib import asynccontextmanager
3
+ from pathlib import Path
4
+ from typing import Dict, Optional
5
+
6
+ import typer
7
+ from fastapi import FastAPI
8
+ from loguru import logger
9
+ from typer import Option
10
+ from typer_config import use_json_config
11
+
12
+ from iopaint.const import *
13
+ from iopaint.runtime import setup_model_dir, dump_environment_info, check_device
14
+ from iopaint.schema import InteractiveSegModel, Device, RealESRGANModel, RemoveBGModel
15
+
16
+ typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
17
+
18
+
19
+ @typer_app.command(help="Install all plugins dependencies")
20
+ def install_plugins_packages():
21
+ from iopaint.installer import install_plugins_package
22
+
23
+ install_plugins_package()
24
+
25
+
26
+ @typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace")
27
+ def download(
28
+ model: str = Option(
29
+ ..., help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting"
30
+ ),
31
+ model_dir: Path = Option(
32
+ DEFAULT_MODEL_DIR,
33
+ help=MODEL_DIR_HELP,
34
+ file_okay=False,
35
+ callback=setup_model_dir,
36
+ ),
37
+ ):
38
+ from iopaint.download import cli_download_model
39
+
40
+ cli_download_model(model)
41
+
42
+
43
+ @typer_app.command(name="list", help="List downloaded models")
44
+ def list_model(
45
+ model_dir: Path = Option(
46
+ DEFAULT_MODEL_DIR,
47
+ help=MODEL_DIR_HELP,
48
+ file_okay=False,
49
+ callback=setup_model_dir,
50
+ ),
51
+ ):
52
+ from iopaint.download import scan_models
53
+
54
+ scanned_models = scan_models()
55
+ for it in scanned_models:
56
+ print(it.name)
57
+
58
+
59
+ @typer_app.command(help="Batch processing images")
60
+ def run(
61
+ model: str = Option("lama"),
62
+ device: Device = Option(Device.cpu),
63
+ image: Path = Option(..., help="Image folders or file path"),
64
+ mask: Path = Option(
65
+ ...,
66
+ help="Mask folders or file path. "
67
+ "If it is a directory, the mask images in the directory should have the same name as the original image."
68
+ "If it is a file, all images will use this mask."
69
+ "Mask will automatically resize to the same size as the original image.",
70
+ ),
71
+ output: Path = Option(..., help="Output directory or file path"),
72
+ config: Path = Option(
73
+ None, help="Config file path. You can use dump command to create a base config."
74
+ ),
75
+ concat: bool = Option(
76
+ False, help="Concat original image, mask and output images into one image"
77
+ ),
78
+ model_dir: Path = Option(
79
+ DEFAULT_MODEL_DIR,
80
+ help=MODEL_DIR_HELP,
81
+ file_okay=False,
82
+ callback=setup_model_dir,
83
+ ),
84
+ ):
85
+ from iopaint.download import cli_download_model, scan_models
86
+
87
+ scanned_models = scan_models()
88
+ if model not in [it.name for it in scanned_models]:
89
+ logger.info(f"{model} not found in {model_dir}, try to downloading")
90
+ cli_download_model(model)
91
+
92
+ from iopaint.batch_processing import batch_inpaint
93
+
94
+ batch_inpaint(model, device, image, mask, output, config, concat)
95
+
96
+
97
+ @typer_app.command(help="Start IOPaint server")
98
+ @use_json_config()
99
+ def start(
100
+ host: str = Option("127.0.0.1"),
101
+ port: int = Option(8080),
102
+ inbrowser: bool = Option(False, help=INBROWSER_HELP),
103
+ model: str = Option(
104
+ DEFAULT_MODEL,
105
+ help=f"Erase models: [{', '.join(AVAILABLE_MODELS)}].\n"
106
+ f"Diffusion models: [{', '.join(DIFFUSION_MODELS)}] or any SD/SDXL normal/inpainting models on HuggingFace.",
107
+ ),
108
+ model_dir: Path = Option(
109
+ DEFAULT_MODEL_DIR,
110
+ help=MODEL_DIR_HELP,
111
+ dir_okay=True,
112
+ file_okay=False,
113
+ callback=setup_model_dir,
114
+ ),
115
+ low_mem: bool = Option(False, help=LOW_MEM_HELP),
116
+ no_half: bool = Option(False, help=NO_HALF_HELP),
117
+ cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
118
+ disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
119
+ cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP),
120
+ local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
121
+ device: Device = Option(Device.cpu),
122
+ input: Optional[Path] = Option(None, help=INPUT_HELP),
123
+ output_dir: Optional[Path] = Option(
124
+ None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
125
+ ),
126
+ quality: int = Option(95, help=QUALITY_HELP),
127
+ enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP),
128
+ interactive_seg_model: InteractiveSegModel = Option(
129
+ InteractiveSegModel.vit_b, help=INTERACTIVE_SEG_MODEL_HELP
130
+ ),
131
+ interactive_seg_device: Device = Option(Device.cpu),
132
+ enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
133
+ remove_bg_model: RemoveBGModel = Option(RemoveBGModel.briaai_rmbg_1_4),
134
+ enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
135
+ enable_realesrgan: bool = Option(False),
136
+ realesrgan_device: Device = Option(Device.cpu),
137
+ realesrgan_model: RealESRGANModel = Option(RealESRGANModel.realesr_general_x4v3),
138
+ enable_gfpgan: bool = Option(False),
139
+ gfpgan_device: Device = Option(Device.cpu),
140
+ enable_restoreformer: bool = Option(False),
141
+ restoreformer_device: Device = Option(Device.cpu),
142
+ ):
143
+ dump_environment_info()
144
+ device = check_device(device)
145
+ if input and not input.exists():
146
+ logger.error(f"invalid --input: {input} not exists")
147
+ exit(-1)
148
+ if input and input.is_dir() and not output_dir:
149
+ logger.error(f"invalid --output-dir: must be set when --input is a directory")
150
+ exit(-1)
151
+ if output_dir:
152
+ output_dir = output_dir.expanduser().absolute()
153
+ logger.info(f"Image will be saved to {output_dir}")
154
+ if not output_dir.exists():
155
+ logger.info(f"Create output directory {output_dir}")
156
+ output_dir.mkdir(parents=True)
157
+
158
+ model_dir = model_dir.expanduser().absolute()
159
+
160
+ if local_files_only:
161
+ os.environ["TRANSFORMERS_OFFLINE"] = "1"
162
+ os.environ["HF_HUB_OFFLINE"] = "1"
163
+
164
+ from iopaint.download import cli_download_model, scan_models
165
+
166
+ scanned_models = scan_models()
167
+ if model not in [it.name for it in scanned_models]:
168
+ logger.info(f"{model} not found in {model_dir}, try to downloading")
169
+ cli_download_model(model)
170
+
171
+ from iopaint.api import Api
172
+ from iopaint.schema import ApiConfig
173
+
174
+ @asynccontextmanager
175
+ async def lifespan(app: FastAPI):
176
+ if inbrowser:
177
+ webbrowser.open(f"http://localhost:{port}", new=0, autoraise=True)
178
+ yield
179
+
180
+ app = FastAPI(lifespan=lifespan)
181
+
182
+ api_config = ApiConfig(
183
+ host=host,
184
+ port=port,
185
+ inbrowser=inbrowser,
186
+ model=model,
187
+ no_half=no_half,
188
+ low_mem=low_mem,
189
+ cpu_offload=cpu_offload,
190
+ disable_nsfw_checker=disable_nsfw_checker,
191
+ local_files_only=local_files_only,
192
+ cpu_textencoder=cpu_textencoder if device == Device.cuda else False,
193
+ device=device,
194
+ input=input,
195
+ output_dir=output_dir,
196
+ quality=quality,
197
+ enable_interactive_seg=enable_interactive_seg,
198
+ interactive_seg_model=interactive_seg_model,
199
+ interactive_seg_device=interactive_seg_device,
200
+ enable_remove_bg=enable_remove_bg,
201
+ remove_bg_model=remove_bg_model,
202
+ enable_anime_seg=enable_anime_seg,
203
+ enable_realesrgan=enable_realesrgan,
204
+ realesrgan_device=realesrgan_device,
205
+ realesrgan_model=realesrgan_model,
206
+ enable_gfpgan=enable_gfpgan,
207
+ gfpgan_device=gfpgan_device,
208
+ enable_restoreformer=enable_restoreformer,
209
+ restoreformer_device=restoreformer_device,
210
+ )
211
+ print(api_config.model_dump_json(indent=4))
212
+ api = Api(app, api_config)
213
+ api.launch()
214
+
215
+
216
+ @typer_app.command(help="Start IOPaint web config page")
217
+ def start_web_config(
218
+ config_file: Path = Option("config.json"),
219
+ ):
220
+ dump_environment_info()
221
+ from iopaint.web_config import main
222
+
223
+ main(config_file)
iopaint/const.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix"
5
+ KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
6
+ POWERPAINT_NAME = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
7
+ ANYTEXT_NAME = "Sanster/AnyText"
8
+
9
+
10
+ DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
11
+ DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
12
+ DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
13
+ DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline"
14
+
15
+ MPS_UNSUPPORT_MODELS = [
16
+ "lama",
17
+ "ldm",
18
+ "zits",
19
+ "mat",
20
+ "fcf",
21
+ "cv2",
22
+ "manga",
23
+ ]
24
+
25
+ DEFAULT_MODEL = "lama"
26
+ AVAILABLE_MODELS = ["lama", "ldm", "zits", "mat", "fcf", "manga", "cv2", "migan"]
27
+ DIFFUSION_MODELS = [
28
+ "runwayml/stable-diffusion-inpainting",
29
+ "Uminosachi/realisticVisionV51_v51VAE-inpainting",
30
+ "redstonehero/dreamshaper-inpainting",
31
+ "Sanster/anything-4.0-inpainting",
32
+ "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
33
+ "Fantasy-Studio/Paint-by-Example",
34
+ POWERPAINT_NAME,
35
+ ANYTEXT_NAME,
36
+ ]
37
+
38
+ NO_HALF_HELP = """
39
+ Using full precision(fp32) model.
40
+ If your diffusion model generate result is always black or green, use this argument.
41
+ """
42
+
43
+ CPU_OFFLOAD_HELP = """
44
+ Offloads diffusion model's weight to CPU RAM, significantly reducing vRAM usage.
45
+ """
46
+
47
+ LOW_MEM_HELP = "Enable attention slicing and vae tiling to save memory."
48
+
49
+ DISABLE_NSFW_HELP = """
50
+ Disable NSFW checker for diffusion model.
51
+ """
52
+
53
+ CPU_TEXTENCODER_HELP = """
54
+ Run diffusion models text encoder on CPU to reduce vRAM usage.
55
+ """
56
+
57
+ SD_CONTROLNET_CHOICES: List[str] = [
58
+ "lllyasviel/control_v11p_sd15_canny",
59
+ # "lllyasviel/control_v11p_sd15_seg",
60
+ "lllyasviel/control_v11p_sd15_openpose",
61
+ "lllyasviel/control_v11p_sd15_inpaint",
62
+ "lllyasviel/control_v11f1p_sd15_depth",
63
+ ]
64
+
65
+ SD2_CONTROLNET_CHOICES = [
66
+ "thibaud/controlnet-sd21-canny-diffusers",
67
+ "thibaud/controlnet-sd21-depth-diffusers",
68
+ "thibaud/controlnet-sd21-openpose-diffusers",
69
+ ]
70
+
71
+ SDXL_CONTROLNET_CHOICES = [
72
+ "thibaud/controlnet-openpose-sdxl-1.0",
73
+ "destitech/controlnet-inpaint-dreamer-sdxl",
74
+ "diffusers/controlnet-canny-sdxl-1.0",
75
+ "diffusers/controlnet-canny-sdxl-1.0-mid",
76
+ "diffusers/controlnet-canny-sdxl-1.0-small",
77
+ "diffusers/controlnet-depth-sdxl-1.0",
78
+ "diffusers/controlnet-depth-sdxl-1.0-mid",
79
+ "diffusers/controlnet-depth-sdxl-1.0-small",
80
+ ]
81
+
82
+ LOCAL_FILES_ONLY_HELP = """
83
+ When loading diffusion models, using local files only, not connect to HuggingFace server.
84
+ """
85
+
86
+ DEFAULT_MODEL_DIR = os.path.abspath(
87
+ os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache"))
88
+ )
89
+ #DEFAULT_MODEL_DIR = os.path.abspath("pretrained-models")
90
+
91
+ MODEL_DIR_HELP = f"""
92
+ Model download directory (by setting XDG_CACHE_HOME environment variable), by default model download to {DEFAULT_MODEL_DIR}
93
+ """
94
+
95
+ OUTPUT_DIR_HELP = """
96
+ Result images will be saved to output directory automatically.
97
+ """
98
+
99
+ INPUT_HELP = """
100
+ If input is image, it will be loaded by default.
101
+ If input is directory, you can browse and select image in file manager.
102
+ """
103
+
104
+ GUI_HELP = """
105
+ Launch Lama Cleaner as desktop app
106
+ """
107
+
108
+ QUALITY_HELP = """
109
+ Quality of image encoding, 0-100. Default is 95, higher quality will generate larger file size.
110
+ """
111
+
112
+ INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything."
113
+ INTERACTIVE_SEG_MODEL_HELP = "Model size: mobile_sam < vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed."
114
+ REMOVE_BG_HELP = "Enable remove background plugin. Always run on CPU"
115
+ ANIMESEG_HELP = "Enable anime segmentation plugin. Always run on CPU"
116
+ REALESRGAN_HELP = "Enable realesrgan super resolution"
117
+ GFPGAN_HELP = "Enable GFPGAN face restore. To also enhance background, use with --enable-realesrgan"
118
+ RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To also enhance background, use with --enable-realesrgan"
119
+ GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image"
120
+
121
+ INBROWSER_HELP = "Automatically launch IOPaint in a new tab on the default browser"
iopaint/download.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from functools import lru_cache
4
+ from typing import List
5
+
6
+ from iopaint.schema import ModelType, ModelInfo
7
+ from loguru import logger
8
+ from pathlib import Path
9
+
10
+ from iopaint.const import (
11
+ DEFAULT_MODEL_DIR,
12
+ DIFFUSERS_SD_CLASS_NAME,
13
+ DIFFUSERS_SD_INPAINT_CLASS_NAME,
14
+ DIFFUSERS_SDXL_CLASS_NAME,
15
+ DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
16
+ ANYTEXT_NAME,
17
+ )
18
+ from iopaint.model.original_sd_configs import get_config_files
19
+
20
+
21
+ def cli_download_model(model: str):
22
+ from iopaint.model import models
23
+ from iopaint.model.utils import handle_from_pretrained_exceptions
24
+
25
+ if model in models and models[model].is_erase_model:
26
+ logger.info(f"Downloading {model}...")
27
+ models[model].download()
28
+ logger.info(f"Done.")
29
+ elif model == ANYTEXT_NAME:
30
+ logger.info(f"Downloading {model}...")
31
+ models[model].download()
32
+ logger.info(f"Done.")
33
+ else:
34
+ logger.info(f"Downloading model from Huggingface: {model}")
35
+ from diffusers import DiffusionPipeline
36
+
37
+ downloaded_path = handle_from_pretrained_exceptions(
38
+ DiffusionPipeline.download,
39
+ pretrained_model_name=model,
40
+ variant="fp16",
41
+ resume_download=True,
42
+ )
43
+ logger.info(f"Done. Downloaded to {downloaded_path}")
44
+
45
+
46
+ def folder_name_to_show_name(name: str) -> str:
47
+ return name.replace("models--", "").replace("--", "/")
48
+
49
+
50
+ @lru_cache(maxsize=512)
51
+ def get_sd_model_type(model_abs_path: str) -> ModelType:
52
+ if "inpaint" in Path(model_abs_path).name.lower():
53
+ model_type = ModelType.DIFFUSERS_SD_INPAINT
54
+ else:
55
+ # load once to check num_in_channels
56
+ from diffusers import StableDiffusionInpaintPipeline
57
+
58
+ try:
59
+ StableDiffusionInpaintPipeline.from_single_file(
60
+ model_abs_path,
61
+ load_safety_checker=False,
62
+ num_in_channels=9,
63
+ config_files=get_config_files(),
64
+ )
65
+ model_type = ModelType.DIFFUSERS_SD_INPAINT
66
+ except ValueError as e:
67
+ if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e):
68
+ model_type = ModelType.DIFFUSERS_SD
69
+ else:
70
+ raise e
71
+ return model_type
72
+
73
+
74
+ @lru_cache()
75
+ def get_sdxl_model_type(model_abs_path: str) -> ModelType:
76
+ if "inpaint" in model_abs_path:
77
+ model_type = ModelType.DIFFUSERS_SDXL_INPAINT
78
+ else:
79
+ # load once to check num_in_channels
80
+ from diffusers import StableDiffusionXLInpaintPipeline
81
+
82
+ try:
83
+ model = StableDiffusionXLInpaintPipeline.from_single_file(
84
+ model_abs_path,
85
+ load_safety_checker=False,
86
+ num_in_channels=9,
87
+ config_files=get_config_files(),
88
+ )
89
+ if model.unet.config.in_channels == 9:
90
+ # https://github.com/huggingface/diffusers/issues/6610
91
+ model_type = ModelType.DIFFUSERS_SDXL_INPAINT
92
+ else:
93
+ model_type = ModelType.DIFFUSERS_SDXL
94
+ except ValueError as e:
95
+ if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e):
96
+ model_type = ModelType.DIFFUSERS_SDXL
97
+ else:
98
+ raise e
99
+ return model_type
100
+
101
+
102
+ def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
103
+ cache_dir = Path(cache_dir)
104
+ stable_diffusion_dir = cache_dir / "stable_diffusion"
105
+ cache_file = stable_diffusion_dir / "iopaint_cache.json"
106
+ model_type_cache = {}
107
+ if cache_file.exists():
108
+ try:
109
+ with open(cache_file, "r", encoding="utf-8") as f:
110
+ model_type_cache = json.load(f)
111
+ assert isinstance(model_type_cache, dict)
112
+ except:
113
+ pass
114
+
115
+ res = []
116
+ for it in stable_diffusion_dir.glob(f"*.*"):
117
+ if it.suffix not in [".safetensors", ".ckpt"]:
118
+ continue
119
+ model_abs_path = str(it.absolute())
120
+ model_type = model_type_cache.get(it.name)
121
+ if model_type is None:
122
+ model_type = get_sd_model_type(model_abs_path)
123
+ model_type_cache[it.name] = model_type
124
+ res.append(
125
+ ModelInfo(
126
+ name=it.name,
127
+ path=model_abs_path,
128
+ model_type=model_type,
129
+ is_single_file_diffusers=True,
130
+ )
131
+ )
132
+ if stable_diffusion_dir.exists():
133
+ with open(cache_file, "w", encoding="utf-8") as fw:
134
+ json.dump(model_type_cache, fw, indent=2, ensure_ascii=False)
135
+
136
+ stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
137
+ sdxl_cache_file = stable_diffusion_xl_dir / "iopaint_cache.json"
138
+ sdxl_model_type_cache = {}
139
+ if sdxl_cache_file.exists():
140
+ try:
141
+ with open(sdxl_cache_file, "r", encoding="utf-8") as f:
142
+ sdxl_model_type_cache = json.load(f)
143
+ assert isinstance(sdxl_model_type_cache, dict)
144
+ except:
145
+ pass
146
+
147
+ for it in stable_diffusion_xl_dir.glob(f"*.*"):
148
+ if it.suffix not in [".safetensors", ".ckpt"]:
149
+ continue
150
+ model_abs_path = str(it.absolute())
151
+ model_type = sdxl_model_type_cache.get(it.name)
152
+ if model_type is None:
153
+ model_type = get_sdxl_model_type(model_abs_path)
154
+ sdxl_model_type_cache[it.name] = model_type
155
+ if stable_diffusion_xl_dir.exists():
156
+ with open(sdxl_cache_file, "w", encoding="utf-8") as fw:
157
+ json.dump(sdxl_model_type_cache, fw, indent=2, ensure_ascii=False)
158
+
159
+ res.append(
160
+ ModelInfo(
161
+ name=it.name,
162
+ path=model_abs_path,
163
+ model_type=model_type,
164
+ is_single_file_diffusers=True,
165
+ )
166
+ )
167
+ return res
168
+
169
+
170
+ def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]:
171
+ res = []
172
+ from iopaint.model import models
173
+
174
+ # logger.info(f"Scanning inpaint models in {model_dir}")
175
+
176
+ for name, m in models.items():
177
+ if m.is_erase_model and m.is_downloaded():
178
+ res.append(
179
+ ModelInfo(
180
+ name=name,
181
+ path=name,
182
+ model_type=ModelType.INPAINT,
183
+ )
184
+ )
185
+ return res
186
+
187
+
188
+ def scan_diffusers_models() -> List[ModelInfo]:
189
+ from huggingface_hub.constants import HF_HUB_CACHE
190
+
191
+ available_models = []
192
+ cache_dir = Path(HF_HUB_CACHE)
193
+ # logger.info(f"Scanning diffusers models in {cache_dir}")
194
+ diffusers_model_names = []
195
+ for it in cache_dir.glob("**/*/model_index.json"):
196
+ with open(it, "r", encoding="utf-8") as f:
197
+ try:
198
+ data = json.load(f)
199
+ except:
200
+ continue
201
+
202
+ _class_name = data["_class_name"]
203
+ name = folder_name_to_show_name(it.parent.parent.parent.name)
204
+ if name in diffusers_model_names:
205
+ continue
206
+ if "PowerPaint" in name:
207
+ model_type = ModelType.DIFFUSERS_OTHER
208
+ elif _class_name == DIFFUSERS_SD_CLASS_NAME:
209
+ model_type = ModelType.DIFFUSERS_SD
210
+ elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
211
+ model_type = ModelType.DIFFUSERS_SD_INPAINT
212
+ elif _class_name == DIFFUSERS_SDXL_CLASS_NAME:
213
+ model_type = ModelType.DIFFUSERS_SDXL
214
+ elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
215
+ model_type = ModelType.DIFFUSERS_SDXL_INPAINT
216
+ elif _class_name in [
217
+ "StableDiffusionInstructPix2PixPipeline",
218
+ "PaintByExamplePipeline",
219
+ "KandinskyV22InpaintPipeline",
220
+ "AnyText",
221
+ ]:
222
+ model_type = ModelType.DIFFUSERS_OTHER
223
+ else:
224
+ continue
225
+
226
+ diffusers_model_names.append(name)
227
+ available_models.append(
228
+ ModelInfo(
229
+ name=name,
230
+ path=name,
231
+ model_type=model_type,
232
+ )
233
+ )
234
+ return available_models
235
+
236
+
237
+ def _scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
238
+ cache_dir = Path(cache_dir)
239
+ available_models = []
240
+ diffusers_model_names = []
241
+ for it in cache_dir.glob("**/*/model_index.json"):
242
+ with open(it, "r", encoding="utf-8") as f:
243
+ try:
244
+ data = json.load(f)
245
+ except:
246
+ logger.error(
247
+ f"Failed to load {it}, please try revert from original model or fix model_index.json by hand."
248
+ )
249
+ continue
250
+
251
+ _class_name = data["_class_name"]
252
+ name = folder_name_to_show_name(it.parent.name)
253
+ if name in diffusers_model_names:
254
+ continue
255
+ elif _class_name == DIFFUSERS_SD_CLASS_NAME:
256
+ model_type = ModelType.DIFFUSERS_SD
257
+ elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
258
+ model_type = ModelType.DIFFUSERS_SD_INPAINT
259
+ elif _class_name == DIFFUSERS_SDXL_CLASS_NAME:
260
+ model_type = ModelType.DIFFUSERS_SDXL
261
+ elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
262
+ model_type = ModelType.DIFFUSERS_SDXL_INPAINT
263
+ else:
264
+ continue
265
+
266
+ diffusers_model_names.append(name)
267
+ available_models.append(
268
+ ModelInfo(
269
+ name=name,
270
+ path=str(it.parent.absolute()),
271
+ model_type=model_type,
272
+ )
273
+ )
274
+ return available_models
275
+
276
+
277
+ def scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
278
+ cache_dir = Path(cache_dir)
279
+ available_models = []
280
+ stable_diffusion_dir = cache_dir / "stable_diffusion"
281
+ stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
282
+ available_models.extend(_scan_converted_diffusers_models(stable_diffusion_dir))
283
+ available_models.extend(_scan_converted_diffusers_models(stable_diffusion_xl_dir))
284
+ return available_models
285
+
286
+
287
+ def scan_models() -> List[ModelInfo]:
288
+ model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR)
289
+ available_models = []
290
+ available_models.extend(scan_inpaint_models(model_dir))
291
+ available_models.extend(scan_single_file_diffusion_models(model_dir))
292
+ available_models.extend(scan_diffusers_models())
293
+ available_models.extend(scan_converted_diffusers_models(model_dir))
294
+ return available_models
iopaint/file_manager/file_manager.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from io import BytesIO
3
+ from pathlib import Path
4
+ from typing import List
5
+
6
+ from PIL import Image, ImageOps, PngImagePlugin
7
+ from fastapi import FastAPI, UploadFile, HTTPException
8
+ from starlette.responses import FileResponse
9
+
10
+ from ..schema import MediasResponse, MediaTab
11
+
12
+ LARGE_ENOUGH_NUMBER = 100
13
+ PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
14
+ from .storage_backends import FilesystemStorageBackend
15
+ from .utils import aspect_to_string, generate_filename, glob_img
16
+
17
+
18
+ class FileManager:
19
+ def __init__(self, app: FastAPI, input_dir: Path, output_dir: Path):
20
+ self.app = app
21
+ self.input_dir: Path = input_dir
22
+ self.output_dir: Path = output_dir
23
+
24
+ self.image_dir_filenames = []
25
+ self.output_dir_filenames = []
26
+ if not self.thumbnail_directory.exists():
27
+ self.thumbnail_directory.mkdir(parents=True)
28
+
29
+ # fmt: off
30
+ self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["GET"], response_model=List[MediasResponse])
31
+ self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["GET"])
32
+ self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["GET"])
33
+ # fmt: on
34
+
35
+ def api_medias(self, tab: MediaTab) -> List[MediasResponse]:
36
+ img_dir = self._get_dir(tab)
37
+ return self._media_names(img_dir)
38
+
39
+ def api_media_file(self, tab: MediaTab, filename: str) -> FileResponse:
40
+ file_path = self._get_file(tab, filename)
41
+ return FileResponse(file_path, media_type="image/png")
42
+
43
+ # tab=${tab}?filename=${filename.name}?width=${width}&height=${height}
44
+ def api_media_thumbnail_file(
45
+ self, tab: MediaTab, filename: str, width: int, height: int
46
+ ) -> FileResponse:
47
+ img_dir = self._get_dir(tab)
48
+ thumb_filename, (width, height) = self.get_thumbnail(
49
+ img_dir, filename, width=width, height=height
50
+ )
51
+ thumbnail_filepath = self.thumbnail_directory / thumb_filename
52
+ return FileResponse(
53
+ thumbnail_filepath,
54
+ headers={
55
+ "X-Width": str(width),
56
+ "X-Height": str(height),
57
+ },
58
+ media_type="image/jpeg",
59
+ )
60
+
61
+ def _get_dir(self, tab: MediaTab) -> Path:
62
+ if tab == "input":
63
+ return self.input_dir
64
+ elif tab == "output":
65
+ return self.output_dir
66
+ else:
67
+ raise HTTPException(status_code=422, detail=f"tab not found: {tab}")
68
+
69
+ def _get_file(self, tab: MediaTab, filename: str) -> Path:
70
+ file_path = self._get_dir(tab) / filename
71
+ if not file_path.exists():
72
+ raise HTTPException(status_code=422, detail=f"file not found: {file_path}")
73
+ return file_path
74
+
75
+ @property
76
+ def thumbnail_directory(self) -> Path:
77
+ return self.output_dir / "thumbnails"
78
+
79
+ @staticmethod
80
+ def _media_names(directory: Path) -> List[MediasResponse]:
81
+ names = sorted([it.name for it in glob_img(directory)])
82
+ res = []
83
+ for name in names:
84
+ path = os.path.join(directory, name)
85
+ img = Image.open(path)
86
+ res.append(
87
+ MediasResponse(
88
+ name=name,
89
+ height=img.height,
90
+ width=img.width,
91
+ ctime=os.path.getctime(path),
92
+ mtime=os.path.getmtime(path),
93
+ )
94
+ )
95
+ return res
96
+
97
+ def get_thumbnail(
98
+ self, directory: Path, original_filename: str, width, height, **options
99
+ ):
100
+ directory = Path(directory)
101
+ storage = FilesystemStorageBackend(self.app)
102
+ crop = options.get("crop", "fit")
103
+ background = options.get("background")
104
+ quality = options.get("quality", 90)
105
+
106
+ original_path, original_filename = os.path.split(original_filename)
107
+ original_filepath = os.path.join(directory, original_path, original_filename)
108
+ image = Image.open(BytesIO(storage.read(original_filepath)))
109
+
110
+ # keep ratio resize
111
+ if not width and not height:
112
+ width = 256
113
+
114
+ if width != 0:
115
+ height = int(image.height * width / image.width)
116
+ else:
117
+ width = int(image.width * height / image.height)
118
+
119
+ thumbnail_size = (width, height)
120
+
121
+ thumbnail_filename = generate_filename(
122
+ directory,
123
+ original_filename,
124
+ aspect_to_string(thumbnail_size),
125
+ crop,
126
+ background,
127
+ quality,
128
+ )
129
+
130
+ thumbnail_filepath = os.path.join(
131
+ self.thumbnail_directory, original_path, thumbnail_filename
132
+ )
133
+
134
+ if storage.exists(thumbnail_filepath):
135
+ return thumbnail_filepath, (width, height)
136
+
137
+ try:
138
+ image.load()
139
+ except (IOError, OSError):
140
+ self.app.logger.warning("Thumbnail not load image: %s", original_filepath)
141
+ return thumbnail_filepath, (width, height)
142
+
143
+ # get original image format
144
+ options["format"] = options.get("format", image.format)
145
+
146
+ image = self._create_thumbnail(
147
+ image, thumbnail_size, crop, background=background
148
+ )
149
+
150
+ raw_data = self.get_raw_data(image, **options)
151
+ storage.save(thumbnail_filepath, raw_data)
152
+
153
+ return thumbnail_filepath, (width, height)
154
+
155
+ def get_raw_data(self, image, **options):
156
+ data = {
157
+ "format": self._get_format(image, **options),
158
+ "quality": options.get("quality", 90),
159
+ }
160
+
161
+ _file = BytesIO()
162
+ image.save(_file, **data)
163
+ return _file.getvalue()
164
+
165
+ @staticmethod
166
+ def colormode(image, colormode="RGB"):
167
+ if colormode == "RGB" or colormode == "RGBA":
168
+ if image.mode == "RGBA":
169
+ return image
170
+ if image.mode == "LA":
171
+ return image.convert("RGBA")
172
+ return image.convert(colormode)
173
+
174
+ if colormode == "GRAY":
175
+ return image.convert("L")
176
+
177
+ return image.convert(colormode)
178
+
179
+ @staticmethod
180
+ def background(original_image, color=0xFF):
181
+ size = (max(original_image.size),) * 2
182
+ image = Image.new("L", size, color)
183
+ image.paste(
184
+ original_image,
185
+ tuple(map(lambda x: (x[0] - x[1]) / 2, zip(size, original_image.size))),
186
+ )
187
+
188
+ return image
189
+
190
+ def _get_format(self, image, **options):
191
+ if options.get("format"):
192
+ return options.get("format")
193
+ if image.format:
194
+ return image.format
195
+
196
+ return "JPEG"
197
+
198
+ def _create_thumbnail(self, image, size, crop="fit", background=None):
199
+ try:
200
+ resample = Image.Resampling.LANCZOS
201
+ except AttributeError: # pylint: disable=raise-missing-from
202
+ resample = Image.ANTIALIAS
203
+
204
+ if crop == "fit":
205
+ image = ImageOps.fit(image, size, resample)
206
+ else:
207
+ image = image.copy()
208
+ image.thumbnail(size, resample=resample)
209
+
210
+ if background is not None:
211
+ image = self.background(image)
212
+
213
+ image = self.colormode(image)
214
+
215
+ return image
iopaint/helper.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import imghdr
3
+ import io
4
+ import os
5
+ import sys
6
+ from typing import List, Optional, Dict, Tuple
7
+
8
+ from urllib.parse import urlparse
9
+ import cv2
10
+ from PIL import Image, ImageOps, PngImagePlugin
11
+ import numpy as np
12
+ import torch
13
+ from iopaint.const import MPS_UNSUPPORT_MODELS
14
+ from loguru import logger
15
+ from torch.hub import download_url_to_file, get_dir
16
+ import hashlib
17
+
18
+
19
+ def md5sum(filename):
20
+ md5 = hashlib.md5()
21
+ with open(filename, "rb") as f:
22
+ for chunk in iter(lambda: f.read(128 * md5.block_size), b""):
23
+ md5.update(chunk)
24
+ return md5.hexdigest()
25
+
26
+
27
+ def switch_mps_device(model_name, device):
28
+ if model_name in MPS_UNSUPPORT_MODELS and str(device) == "mps":
29
+ logger.info(f"{model_name} not support mps, switch to cpu")
30
+ return torch.device("cpu")
31
+ return device
32
+
33
+
34
+ def get_cache_path_by_url(url):
35
+ parts = urlparse(url)
36
+ hub_dir = get_dir()
37
+ model_dir = os.path.join(hub_dir, "checkpoints")
38
+ if not os.path.isdir(model_dir):
39
+ os.makedirs(model_dir)
40
+ filename = os.path.basename(parts.path)
41
+ cached_file = os.path.join(model_dir, filename)
42
+ return cached_file
43
+
44
+ def get_cache_path_by_local(url):
45
+ root_path = os.getcwd()
46
+ model_path = os.path.join(root_path, 'pretrained-model', 'big-lama.pt')
47
+ return model_path
48
+
49
+ def download_model(url, model_md5: str = None):
50
+ cached_file = get_cache_path_by_url(url)
51
+ # cached_file = get_cache_path_by_local(url)
52
+ if not os.path.exists(cached_file):
53
+ sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
54
+ hash_prefix = None
55
+ download_url_to_file(url, cached_file, hash_prefix, progress=True)
56
+ if model_md5:
57
+ _md5 = md5sum(cached_file)
58
+ if model_md5 == _md5:
59
+ logger.info(f"Download model success, md5: {_md5}")
60
+ else:
61
+ try:
62
+ os.remove(cached_file)
63
+ logger.error(
64
+ f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart iopaint."
65
+ f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
66
+ )
67
+ except:
68
+ logger.error(
69
+ f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart iopaint."
70
+ )
71
+ exit(-1)
72
+
73
+ return cached_file
74
+
75
+
76
+ def ceil_modulo(x, mod):
77
+ if x % mod == 0:
78
+ return x
79
+ return (x // mod + 1) * mod
80
+
81
+
82
+ def handle_error(model_path, model_md5, e):
83
+ _md5 = md5sum(model_path)
84
+ if _md5 != model_md5:
85
+ try:
86
+ os.remove(model_path)
87
+ logger.error(
88
+ f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart iopaint."
89
+ f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
90
+ )
91
+ except:
92
+ logger.error(
93
+ f"Model md5: {_md5}, expected md5: {model_md5}, please delete {model_path} and restart iopaint."
94
+ )
95
+ else:
96
+ logger.error(
97
+ f"Failed to load model {model_path},"
98
+ f"please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}"
99
+ )
100
+ exit(-1)
101
+
102
+
103
+ def load_jit_model(url_or_path, device, model_md5: str):
104
+ if os.path.exists(url_or_path):
105
+ model_path = url_or_path
106
+ else:
107
+ model_path = download_model(url_or_path, model_md5)
108
+
109
+ logger.info(f"Loading model from: {model_path}")
110
+ try:
111
+ model = torch.jit.load(model_path, map_location="cpu").to(device)
112
+ except Exception as e:
113
+ handle_error(model_path, model_md5, e)
114
+ model.eval()
115
+ return model
116
+
117
+
118
+ def load_model(model: torch.nn.Module, url_or_path, device, model_md5):
119
+ if os.path.exists(url_or_path):
120
+ model_path = url_or_path
121
+ else:
122
+ model_path = download_model(url_or_path, model_md5)
123
+
124
+ try:
125
+ logger.info(f"Loading model from: {model_path}")
126
+ state_dict = torch.load(model_path, map_location="cpu")
127
+ model.load_state_dict(state_dict, strict=True)
128
+ model.to(device)
129
+ except Exception as e:
130
+ handle_error(model_path, model_md5, e)
131
+ model.eval()
132
+ return model
133
+
134
+
135
+ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
136
+ data = cv2.imencode(
137
+ f".{ext}",
138
+ image_numpy,
139
+ [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
140
+ )[1]
141
+ image_bytes = data.tobytes()
142
+ return image_bytes
143
+
144
+
145
+ def pil_to_bytes(pil_img, ext: str, quality: int = 95, infos={}) -> bytes:
146
+ with io.BytesIO() as output:
147
+ kwargs = {k: v for k, v in infos.items() if v is not None}
148
+ if ext == "jpg":
149
+ ext = "jpeg"
150
+ if "png" == ext.lower() and "parameters" in kwargs:
151
+ pnginfo_data = PngImagePlugin.PngInfo()
152
+ pnginfo_data.add_text("parameters", kwargs["parameters"])
153
+ kwargs["pnginfo"] = pnginfo_data
154
+
155
+ pil_img.save(output, format=ext, quality=quality, **kwargs)
156
+ image_bytes = output.getvalue()
157
+ return image_bytes
158
+
159
+ def pil_to_bytes_single(pil_img, ext: str, quality: int = 95, infos=None) -> bytes:
160
+ infos = infos or {} # Use an empty dictionary if infos is None
161
+ with io.BytesIO() as output:
162
+ kwargs = {k: v for k, v in infos.items() if v is not None}
163
+ if ext == "jpg":
164
+ ext = "jpeg"
165
+ if "png" == ext.lower() and "parameters" in kwargs:
166
+ pnginfo_data = PngImagePlugin.PngInfo()
167
+ pnginfo_data.add_text("parameters", kwargs["parameters"])
168
+ kwargs["pnginfo"] = pnginfo_data
169
+
170
+ pil_img.save(output, format=ext, quality=quality, **kwargs)
171
+ image_bytes = output.getvalue()
172
+ return image_bytes
173
+
174
+
175
+ def load_img(img_bytes, gray: bool = False, return_info: bool = False):
176
+ alpha_channel = None
177
+ image = Image.open(io.BytesIO(img_bytes))
178
+
179
+ if return_info:
180
+ infos = image.info
181
+
182
+ try:
183
+ image = ImageOps.exif_transpose(image)
184
+ except:
185
+ pass
186
+
187
+ if gray:
188
+ image = image.convert("L")
189
+ np_img = np.array(image)
190
+ else:
191
+ if image.mode == "RGBA":
192
+ np_img = np.array(image)
193
+ alpha_channel = np_img[:, :, -1]
194
+ np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
195
+ else:
196
+ image = image.convert("RGB")
197
+ np_img = np.array(image)
198
+
199
+ if return_info:
200
+ return np_img, alpha_channel, infos
201
+ return np_img, alpha_channel
202
+
203
+
204
+ def norm_img(np_img):
205
+ if len(np_img.shape) == 2:
206
+ np_img = np_img[:, :, np.newaxis]
207
+ np_img = np.transpose(np_img, (2, 0, 1))
208
+ np_img = np_img.astype("float32") / 255
209
+ return np_img
210
+
211
+
212
+ def resize_max_size(
213
+ np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
214
+ ) -> np.ndarray:
215
+ # Resize image's longer size to size_limit if longer size larger than size_limit
216
+ h, w = np_img.shape[:2]
217
+ if max(h, w) > size_limit:
218
+ ratio = size_limit / max(h, w)
219
+ new_w = int(w * ratio + 0.5)
220
+ new_h = int(h * ratio + 0.5)
221
+ return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
222
+ else:
223
+ return np_img
224
+
225
+
226
+ def pad_img_to_modulo(
227
+ img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
228
+ ):
229
+ """
230
+
231
+ Args:
232
+ img: [H, W, C]
233
+ mod:
234
+ square: 是否为正方形
235
+ min_size:
236
+
237
+ Returns:
238
+
239
+ """
240
+ if len(img.shape) == 2:
241
+ img = img[:, :, np.newaxis]
242
+ height, width = img.shape[:2]
243
+ out_height = ceil_modulo(height, mod)
244
+ out_width = ceil_modulo(width, mod)
245
+
246
+ if min_size is not None:
247
+ assert min_size % mod == 0
248
+ out_width = max(min_size, out_width)
249
+ out_height = max(min_size, out_height)
250
+
251
+ if square:
252
+ max_size = max(out_height, out_width)
253
+ out_height = max_size
254
+ out_width = max_size
255
+
256
+ return np.pad(
257
+ img,
258
+ ((0, out_height - height), (0, out_width - width), (0, 0)),
259
+ mode="symmetric",
260
+ )
261
+
262
+
263
+ def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
264
+ """
265
+ Args:
266
+ mask: (h, w, 1) 0~255
267
+
268
+ Returns:
269
+
270
+ """
271
+ height, width = mask.shape[:2]
272
+ _, thresh = cv2.threshold(mask, 127, 255, 0)
273
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
274
+
275
+ boxes = []
276
+ for cnt in contours:
277
+ x, y, w, h = cv2.boundingRect(cnt)
278
+ box = np.array([x, y, x + w, y + h]).astype(int)
279
+
280
+ box[::2] = np.clip(box[::2], 0, width)
281
+ box[1::2] = np.clip(box[1::2], 0, height)
282
+ boxes.append(box)
283
+
284
+ return boxes
285
+
286
+
287
+ def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
288
+ """
289
+ Args:
290
+ mask: (h, w) 0~255
291
+
292
+ Returns:
293
+
294
+ """
295
+ _, thresh = cv2.threshold(mask, 127, 255, 0)
296
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
297
+
298
+ max_area = 0
299
+ max_index = -1
300
+ for i, cnt in enumerate(contours):
301
+ area = cv2.contourArea(cnt)
302
+ if area > max_area:
303
+ max_area = area
304
+ max_index = i
305
+
306
+ if max_index != -1:
307
+ new_mask = np.zeros_like(mask)
308
+ return cv2.drawContours(new_mask, contours, max_index, 255, -1)
309
+ else:
310
+ return mask
311
+
312
+
313
+ def is_mac():
314
+ return sys.platform == "darwin"
315
+
316
+
317
+ def get_image_ext(img_bytes):
318
+ w = imghdr.what("", img_bytes)
319
+ if w is None:
320
+ w = "jpeg"
321
+ return w
322
+
323
+
324
+ def decode_base64_to_image(
325
+ encoding: str, gray=False
326
+ ) -> Tuple[np.array, Optional[np.array], Dict]:
327
+ if encoding.startswith("data:image/") or encoding.startswith(
328
+ "data:application/octet-stream;base64,"
329
+ ):
330
+ encoding = encoding.split(";")[1].split(",")[1]
331
+ image = Image.open(io.BytesIO(base64.b64decode(encoding)))
332
+
333
+ alpha_channel = None
334
+ try:
335
+ image = ImageOps.exif_transpose(image)
336
+ except:
337
+ pass
338
+ # exif_transpose will remove exif rotate info,we must call image.info after exif_transpose
339
+ infos = image.info
340
+
341
+ if gray:
342
+ image = image.convert("L")
343
+ np_img = np.array(image)
344
+ else:
345
+ if image.mode == "RGBA":
346
+ np_img = np.array(image)
347
+ alpha_channel = np_img[:, :, -1]
348
+ np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
349
+ else:
350
+ image = image.convert("RGB")
351
+ np_img = np.array(image)
352
+
353
+ return np_img, alpha_channel, infos
354
+
355
+
356
+ def encode_pil_to_base64(image: Image, quality: int, infos: Dict) -> bytes:
357
+ img_bytes = pil_to_bytes(
358
+ image,
359
+ "png",
360
+ quality=quality,
361
+ infos=infos,
362
+ )
363
+ return base64.b64encode(img_bytes)
364
+
365
+
366
+ def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray:
367
+ if alpha_channel is not None:
368
+ if alpha_channel.shape[:2] != rgb_np_img.shape[:2]:
369
+ alpha_channel = cv2.resize(
370
+ alpha_channel, dsize=(rgb_np_img.shape[1], rgb_np_img.shape[0])
371
+ )
372
+ rgb_np_img = np.concatenate(
373
+ (rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
374
+ )
375
+ return rgb_np_img
376
+
377
+
378
+ def adjust_mask(mask: np.ndarray, kernel_size: int, operate):
379
+ # fronted brush color "ffcc00bb"
380
+ # kernel_size = kernel_size*2+1
381
+ mask[mask >= 127] = 255
382
+ mask[mask < 127] = 0
383
+
384
+ if operate == "reverse":
385
+ mask = 255 - mask
386
+ else:
387
+ kernel = cv2.getStructuringElement(
388
+ cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
389
+ )
390
+ if operate == "expand":
391
+ mask = cv2.dilate(
392
+ mask,
393
+ kernel,
394
+ iterations=1,
395
+ )
396
+ else:
397
+ mask = cv2.erode(
398
+ mask,
399
+ kernel,
400
+ iterations=1,
401
+ )
402
+ res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
403
+ res_mask[mask > 128] = [255, 203, 0, int(255 * 0.73)]
404
+ res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
405
+ return res_mask
406
+
407
+
408
+ def gen_frontend_mask(bgr_or_gray_mask):
409
+ if len(bgr_or_gray_mask.shape) == 3 and bgr_or_gray_mask.shape[2] != 1:
410
+ bgr_or_gray_mask = cv2.cvtColor(bgr_or_gray_mask, cv2.COLOR_BGR2GRAY)
411
+
412
+ # fronted brush color "ffcc00bb"
413
+ # TODO: how to set kernel size?
414
+ kernel_size = 9
415
+ bgr_or_gray_mask = cv2.dilate(
416
+ bgr_or_gray_mask,
417
+ np.ones((kernel_size, kernel_size), np.uint8),
418
+ iterations=1,
419
+ )
420
+ res_mask = np.zeros(
421
+ (bgr_or_gray_mask.shape[0], bgr_or_gray_mask.shape[1], 4), dtype=np.uint8
422
+ )
423
+ res_mask[bgr_or_gray_mask > 128] = [255, 203, 0, int(255 * 0.73)]
424
+ res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
425
+ return res_mask
iopaint/installer.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import sys
3
+
4
+
5
+ def install(package):
6
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
7
+
8
+
9
+ def install_plugins_package():
10
+ install("rembg")
11
+ install("realesrgan")
12
+ install("gfpgan")
iopaint/model/anytext/cldm/cldm.py ADDED
@@ -0,0 +1,630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import einops
5
+ import torch
6
+ import torch as th
7
+ import torch.nn as nn
8
+ import copy
9
+ from easydict import EasyDict as edict
10
+
11
+ from iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
12
+ conv_nd,
13
+ linear,
14
+ zero_module,
15
+ timestep_embedding,
16
+ )
17
+
18
+ from einops import rearrange, repeat
19
+ from iopaint.model.anytext.ldm.modules.attention import SpatialTransformer
20
+ from iopaint.model.anytext.ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
21
+ from iopaint.model.anytext.ldm.models.diffusion.ddpm import LatentDiffusion
22
+ from iopaint.model.anytext.ldm.util import log_txt_as_img, exists, instantiate_from_config
23
+ from iopaint.model.anytext.ldm.models.diffusion.ddim import DDIMSampler
24
+ from iopaint.model.anytext.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
25
+ from .recognizer import TextRecognizer, create_predictor
26
+
27
+ CURRENT_DIR = Path(os.path.dirname(os.path.abspath(__file__)))
28
+
29
+
30
+ def count_parameters(model):
31
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
32
+
33
+
34
+ class ControlledUnetModel(UNetModel):
35
+ def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
36
+ hs = []
37
+ with torch.no_grad():
38
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
39
+ if self.use_fp16:
40
+ t_emb = t_emb.half()
41
+ emb = self.time_embed(t_emb)
42
+ h = x.type(self.dtype)
43
+ for module in self.input_blocks:
44
+ h = module(h, emb, context)
45
+ hs.append(h)
46
+ h = self.middle_block(h, emb, context)
47
+
48
+ if control is not None:
49
+ h += control.pop()
50
+
51
+ for i, module in enumerate(self.output_blocks):
52
+ if only_mid_control or control is None:
53
+ h = torch.cat([h, hs.pop()], dim=1)
54
+ else:
55
+ h = torch.cat([h, hs.pop() + control.pop()], dim=1)
56
+ h = module(h, emb, context)
57
+
58
+ h = h.type(x.dtype)
59
+ return self.out(h)
60
+
61
+
62
+ class ControlNet(nn.Module):
63
+ def __init__(
64
+ self,
65
+ image_size,
66
+ in_channels,
67
+ model_channels,
68
+ glyph_channels,
69
+ position_channels,
70
+ num_res_blocks,
71
+ attention_resolutions,
72
+ dropout=0,
73
+ channel_mult=(1, 2, 4, 8),
74
+ conv_resample=True,
75
+ dims=2,
76
+ use_checkpoint=False,
77
+ use_fp16=False,
78
+ num_heads=-1,
79
+ num_head_channels=-1,
80
+ num_heads_upsample=-1,
81
+ use_scale_shift_norm=False,
82
+ resblock_updown=False,
83
+ use_new_attention_order=False,
84
+ use_spatial_transformer=False, # custom transformer support
85
+ transformer_depth=1, # custom transformer support
86
+ context_dim=None, # custom transformer support
87
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
88
+ legacy=True,
89
+ disable_self_attentions=None,
90
+ num_attention_blocks=None,
91
+ disable_middle_self_attn=False,
92
+ use_linear_in_transformer=False,
93
+ ):
94
+ super().__init__()
95
+ if use_spatial_transformer:
96
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
97
+
98
+ if context_dim is not None:
99
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
100
+ from omegaconf.listconfig import ListConfig
101
+ if type(context_dim) == ListConfig:
102
+ context_dim = list(context_dim)
103
+
104
+ if num_heads_upsample == -1:
105
+ num_heads_upsample = num_heads
106
+
107
+ if num_heads == -1:
108
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
109
+
110
+ if num_head_channels == -1:
111
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
112
+ self.dims = dims
113
+ self.image_size = image_size
114
+ self.in_channels = in_channels
115
+ self.model_channels = model_channels
116
+ if isinstance(num_res_blocks, int):
117
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
118
+ else:
119
+ if len(num_res_blocks) != len(channel_mult):
120
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
121
+ "as a list/tuple (per-level) with the same length as channel_mult")
122
+ self.num_res_blocks = num_res_blocks
123
+ if disable_self_attentions is not None:
124
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
125
+ assert len(disable_self_attentions) == len(channel_mult)
126
+ if num_attention_blocks is not None:
127
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
128
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
129
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
130
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
131
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
132
+ f"attention will still not be set.")
133
+ self.attention_resolutions = attention_resolutions
134
+ self.dropout = dropout
135
+ self.channel_mult = channel_mult
136
+ self.conv_resample = conv_resample
137
+ self.use_checkpoint = use_checkpoint
138
+ self.use_fp16 = use_fp16
139
+ self.dtype = th.float16 if use_fp16 else th.float32
140
+ self.num_heads = num_heads
141
+ self.num_head_channels = num_head_channels
142
+ self.num_heads_upsample = num_heads_upsample
143
+ self.predict_codebook_ids = n_embed is not None
144
+
145
+ time_embed_dim = model_channels * 4
146
+ self.time_embed = nn.Sequential(
147
+ linear(model_channels, time_embed_dim),
148
+ nn.SiLU(),
149
+ linear(time_embed_dim, time_embed_dim),
150
+ )
151
+
152
+ self.input_blocks = nn.ModuleList(
153
+ [
154
+ TimestepEmbedSequential(
155
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
156
+ )
157
+ ]
158
+ )
159
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
160
+
161
+ self.glyph_block = TimestepEmbedSequential(
162
+ conv_nd(dims, glyph_channels, 8, 3, padding=1),
163
+ nn.SiLU(),
164
+ conv_nd(dims, 8, 8, 3, padding=1),
165
+ nn.SiLU(),
166
+ conv_nd(dims, 8, 16, 3, padding=1, stride=2),
167
+ nn.SiLU(),
168
+ conv_nd(dims, 16, 16, 3, padding=1),
169
+ nn.SiLU(),
170
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
171
+ nn.SiLU(),
172
+ conv_nd(dims, 32, 32, 3, padding=1),
173
+ nn.SiLU(),
174
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
175
+ nn.SiLU(),
176
+ conv_nd(dims, 96, 96, 3, padding=1),
177
+ nn.SiLU(),
178
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
179
+ nn.SiLU(),
180
+ )
181
+
182
+ self.position_block = TimestepEmbedSequential(
183
+ conv_nd(dims, position_channels, 8, 3, padding=1),
184
+ nn.SiLU(),
185
+ conv_nd(dims, 8, 8, 3, padding=1),
186
+ nn.SiLU(),
187
+ conv_nd(dims, 8, 16, 3, padding=1, stride=2),
188
+ nn.SiLU(),
189
+ conv_nd(dims, 16, 16, 3, padding=1),
190
+ nn.SiLU(),
191
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
192
+ nn.SiLU(),
193
+ conv_nd(dims, 32, 32, 3, padding=1),
194
+ nn.SiLU(),
195
+ conv_nd(dims, 32, 64, 3, padding=1, stride=2),
196
+ nn.SiLU(),
197
+ )
198
+
199
+ self.fuse_block = zero_module(conv_nd(dims, 256+64+4, model_channels, 3, padding=1))
200
+
201
+ self._feature_size = model_channels
202
+ input_block_chans = [model_channels]
203
+ ch = model_channels
204
+ ds = 1
205
+ for level, mult in enumerate(channel_mult):
206
+ for nr in range(self.num_res_blocks[level]):
207
+ layers = [
208
+ ResBlock(
209
+ ch,
210
+ time_embed_dim,
211
+ dropout,
212
+ out_channels=mult * model_channels,
213
+ dims=dims,
214
+ use_checkpoint=use_checkpoint,
215
+ use_scale_shift_norm=use_scale_shift_norm,
216
+ )
217
+ ]
218
+ ch = mult * model_channels
219
+ if ds in attention_resolutions:
220
+ if num_head_channels == -1:
221
+ dim_head = ch // num_heads
222
+ else:
223
+ num_heads = ch // num_head_channels
224
+ dim_head = num_head_channels
225
+ if legacy:
226
+ # num_heads = 1
227
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
228
+ if exists(disable_self_attentions):
229
+ disabled_sa = disable_self_attentions[level]
230
+ else:
231
+ disabled_sa = False
232
+
233
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
234
+ layers.append(
235
+ AttentionBlock(
236
+ ch,
237
+ use_checkpoint=use_checkpoint,
238
+ num_heads=num_heads,
239
+ num_head_channels=dim_head,
240
+ use_new_attention_order=use_new_attention_order,
241
+ ) if not use_spatial_transformer else SpatialTransformer(
242
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
243
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
244
+ use_checkpoint=use_checkpoint
245
+ )
246
+ )
247
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
248
+ self.zero_convs.append(self.make_zero_conv(ch))
249
+ self._feature_size += ch
250
+ input_block_chans.append(ch)
251
+ if level != len(channel_mult) - 1:
252
+ out_ch = ch
253
+ self.input_blocks.append(
254
+ TimestepEmbedSequential(
255
+ ResBlock(
256
+ ch,
257
+ time_embed_dim,
258
+ dropout,
259
+ out_channels=out_ch,
260
+ dims=dims,
261
+ use_checkpoint=use_checkpoint,
262
+ use_scale_shift_norm=use_scale_shift_norm,
263
+ down=True,
264
+ )
265
+ if resblock_updown
266
+ else Downsample(
267
+ ch, conv_resample, dims=dims, out_channels=out_ch
268
+ )
269
+ )
270
+ )
271
+ ch = out_ch
272
+ input_block_chans.append(ch)
273
+ self.zero_convs.append(self.make_zero_conv(ch))
274
+ ds *= 2
275
+ self._feature_size += ch
276
+
277
+ if num_head_channels == -1:
278
+ dim_head = ch // num_heads
279
+ else:
280
+ num_heads = ch // num_head_channels
281
+ dim_head = num_head_channels
282
+ if legacy:
283
+ # num_heads = 1
284
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
285
+ self.middle_block = TimestepEmbedSequential(
286
+ ResBlock(
287
+ ch,
288
+ time_embed_dim,
289
+ dropout,
290
+ dims=dims,
291
+ use_checkpoint=use_checkpoint,
292
+ use_scale_shift_norm=use_scale_shift_norm,
293
+ ),
294
+ AttentionBlock(
295
+ ch,
296
+ use_checkpoint=use_checkpoint,
297
+ num_heads=num_heads,
298
+ num_head_channels=dim_head,
299
+ use_new_attention_order=use_new_attention_order,
300
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
301
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
302
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
303
+ use_checkpoint=use_checkpoint
304
+ ),
305
+ ResBlock(
306
+ ch,
307
+ time_embed_dim,
308
+ dropout,
309
+ dims=dims,
310
+ use_checkpoint=use_checkpoint,
311
+ use_scale_shift_norm=use_scale_shift_norm,
312
+ ),
313
+ )
314
+ self.middle_block_out = self.make_zero_conv(ch)
315
+ self._feature_size += ch
316
+
317
+ def make_zero_conv(self, channels):
318
+ return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
319
+
320
+ def forward(self, x, hint, text_info, timesteps, context, **kwargs):
321
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
322
+ if self.use_fp16:
323
+ t_emb = t_emb.half()
324
+ emb = self.time_embed(t_emb)
325
+
326
+ # guided_hint from text_info
327
+ B, C, H, W = x.shape
328
+ glyphs = torch.cat(text_info['glyphs'], dim=1).sum(dim=1, keepdim=True)
329
+ positions = torch.cat(text_info['positions'], dim=1).sum(dim=1, keepdim=True)
330
+ enc_glyph = self.glyph_block(glyphs, emb, context)
331
+ enc_pos = self.position_block(positions, emb, context)
332
+ guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info['masked_x']], dim=1))
333
+
334
+ outs = []
335
+
336
+ h = x.type(self.dtype)
337
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
338
+ if guided_hint is not None:
339
+ h = module(h, emb, context)
340
+ h += guided_hint
341
+ guided_hint = None
342
+ else:
343
+ h = module(h, emb, context)
344
+ outs.append(zero_conv(h, emb, context))
345
+
346
+ h = self.middle_block(h, emb, context)
347
+ outs.append(self.middle_block_out(h, emb, context))
348
+
349
+ return outs
350
+
351
+
352
+ class ControlLDM(LatentDiffusion):
353
+
354
+ def __init__(self, control_stage_config, control_key, glyph_key, position_key, only_mid_control, loss_alpha=0, loss_beta=0, with_step_weight=False, use_vae_upsample=False, latin_weight=1.0, embedding_manager_config=None, *args, **kwargs):
355
+ self.use_fp16 = kwargs.pop('use_fp16', False)
356
+ super().__init__(*args, **kwargs)
357
+ self.control_model = instantiate_from_config(control_stage_config)
358
+ self.control_key = control_key
359
+ self.glyph_key = glyph_key
360
+ self.position_key = position_key
361
+ self.only_mid_control = only_mid_control
362
+ self.control_scales = [1.0] * 13
363
+ self.loss_alpha = loss_alpha
364
+ self.loss_beta = loss_beta
365
+ self.with_step_weight = with_step_weight
366
+ self.use_vae_upsample = use_vae_upsample
367
+ self.latin_weight = latin_weight
368
+
369
+ if embedding_manager_config is not None and embedding_manager_config.params.valid:
370
+ self.embedding_manager = self.instantiate_embedding_manager(embedding_manager_config, self.cond_stage_model)
371
+ for param in self.embedding_manager.embedding_parameters():
372
+ param.requires_grad = True
373
+ else:
374
+ self.embedding_manager = None
375
+ if self.loss_alpha > 0 or self.loss_beta > 0 or self.embedding_manager:
376
+ if embedding_manager_config.params.emb_type == 'ocr':
377
+ self.text_predictor = create_predictor().eval()
378
+ args = edict()
379
+ args.rec_image_shape = "3, 48, 320"
380
+ args.rec_batch_num = 6
381
+ args.rec_char_dict_path = str(CURRENT_DIR.parent / "ocr_recog" / "ppocr_keys_v1.txt")
382
+ args.use_fp16 = self.use_fp16
383
+ self.cn_recognizer = TextRecognizer(args, self.text_predictor)
384
+ for param in self.text_predictor.parameters():
385
+ param.requires_grad = False
386
+ if self.embedding_manager:
387
+ self.embedding_manager.recog = self.cn_recognizer
388
+
389
+ @torch.no_grad()
390
+ def get_input(self, batch, k, bs=None, *args, **kwargs):
391
+ if self.embedding_manager is None: # fill in full caption
392
+ self.fill_caption(batch)
393
+ x, c, mx = super().get_input(batch, self.first_stage_key, mask_k='masked_img', *args, **kwargs)
394
+ control = batch[self.control_key] # for log_images and loss_alpha, not real control
395
+ if bs is not None:
396
+ control = control[:bs]
397
+ control = control.to(self.device)
398
+ control = einops.rearrange(control, 'b h w c -> b c h w')
399
+ control = control.to(memory_format=torch.contiguous_format).float()
400
+
401
+ inv_mask = batch['inv_mask']
402
+ if bs is not None:
403
+ inv_mask = inv_mask[:bs]
404
+ inv_mask = inv_mask.to(self.device)
405
+ inv_mask = einops.rearrange(inv_mask, 'b h w c -> b c h w')
406
+ inv_mask = inv_mask.to(memory_format=torch.contiguous_format).float()
407
+
408
+ glyphs = batch[self.glyph_key]
409
+ gly_line = batch['gly_line']
410
+ positions = batch[self.position_key]
411
+ n_lines = batch['n_lines']
412
+ language = batch['language']
413
+ texts = batch['texts']
414
+ assert len(glyphs) == len(positions)
415
+ for i in range(len(glyphs)):
416
+ if bs is not None:
417
+ glyphs[i] = glyphs[i][:bs]
418
+ gly_line[i] = gly_line[i][:bs]
419
+ positions[i] = positions[i][:bs]
420
+ n_lines = n_lines[:bs]
421
+ glyphs[i] = glyphs[i].to(self.device)
422
+ gly_line[i] = gly_line[i].to(self.device)
423
+ positions[i] = positions[i].to(self.device)
424
+ glyphs[i] = einops.rearrange(glyphs[i], 'b h w c -> b c h w')
425
+ gly_line[i] = einops.rearrange(gly_line[i], 'b h w c -> b c h w')
426
+ positions[i] = einops.rearrange(positions[i], 'b h w c -> b c h w')
427
+ glyphs[i] = glyphs[i].to(memory_format=torch.contiguous_format).float()
428
+ gly_line[i] = gly_line[i].to(memory_format=torch.contiguous_format).float()
429
+ positions[i] = positions[i].to(memory_format=torch.contiguous_format).float()
430
+ info = {}
431
+ info['glyphs'] = glyphs
432
+ info['positions'] = positions
433
+ info['n_lines'] = n_lines
434
+ info['language'] = language
435
+ info['texts'] = texts
436
+ info['img'] = batch['img'] # nhwc, (-1,1)
437
+ info['masked_x'] = mx
438
+ info['gly_line'] = gly_line
439
+ info['inv_mask'] = inv_mask
440
+ return x, dict(c_crossattn=[c], c_concat=[control], text_info=info)
441
+
442
+ def apply_model(self, x_noisy, t, cond, *args, **kwargs):
443
+ assert isinstance(cond, dict)
444
+ diffusion_model = self.model.diffusion_model
445
+ _cond = torch.cat(cond['c_crossattn'], 1)
446
+ _hint = torch.cat(cond['c_concat'], 1)
447
+ if self.use_fp16:
448
+ x_noisy = x_noisy.half()
449
+ control = self.control_model(x=x_noisy, timesteps=t, context=_cond, hint=_hint, text_info=cond['text_info'])
450
+ control = [c * scale for c, scale in zip(control, self.control_scales)]
451
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=_cond, control=control, only_mid_control=self.only_mid_control)
452
+
453
+ return eps
454
+
455
+ def instantiate_embedding_manager(self, config, embedder):
456
+ model = instantiate_from_config(config, embedder=embedder)
457
+ return model
458
+
459
+ @torch.no_grad()
460
+ def get_unconditional_conditioning(self, N):
461
+ return self.get_learned_conditioning(dict(c_crossattn=[[""] * N], text_info=None))
462
+
463
+ def get_learned_conditioning(self, c):
464
+ if self.cond_stage_forward is None:
465
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
466
+ if self.embedding_manager is not None and c['text_info'] is not None:
467
+ self.embedding_manager.encode_text(c['text_info'])
468
+ if isinstance(c, dict):
469
+ cond_txt = c['c_crossattn'][0]
470
+ else:
471
+ cond_txt = c
472
+ if self.embedding_manager is not None:
473
+ cond_txt = self.cond_stage_model.encode(cond_txt, embedding_manager=self.embedding_manager)
474
+ else:
475
+ cond_txt = self.cond_stage_model.encode(cond_txt)
476
+ if isinstance(c, dict):
477
+ c['c_crossattn'][0] = cond_txt
478
+ else:
479
+ c = cond_txt
480
+ if isinstance(c, DiagonalGaussianDistribution):
481
+ c = c.mode()
482
+ else:
483
+ c = self.cond_stage_model(c)
484
+ else:
485
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
486
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
487
+ return c
488
+
489
+ def fill_caption(self, batch, place_holder='*'):
490
+ bs = len(batch['n_lines'])
491
+ cond_list = copy.deepcopy(batch[self.cond_stage_key])
492
+ for i in range(bs):
493
+ n_lines = batch['n_lines'][i]
494
+ if n_lines == 0:
495
+ continue
496
+ cur_cap = cond_list[i]
497
+ for j in range(n_lines):
498
+ r_txt = batch['texts'][j][i]
499
+ cur_cap = cur_cap.replace(place_holder, f'"{r_txt}"', 1)
500
+ cond_list[i] = cur_cap
501
+ batch[self.cond_stage_key] = cond_list
502
+
503
+ @torch.no_grad()
504
+ def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
505
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
506
+ plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
507
+ use_ema_scope=True,
508
+ **kwargs):
509
+ use_ddim = ddim_steps is not None
510
+
511
+ log = dict()
512
+ z, c = self.get_input(batch, self.first_stage_key, bs=N)
513
+ if self.cond_stage_trainable:
514
+ with torch.no_grad():
515
+ c = self.get_learned_conditioning(c)
516
+ c_crossattn = c["c_crossattn"][0][:N]
517
+ c_cat = c["c_concat"][0][:N]
518
+ text_info = c["text_info"]
519
+ text_info['glyphs'] = [i[:N] for i in text_info['glyphs']]
520
+ text_info['gly_line'] = [i[:N] for i in text_info['gly_line']]
521
+ text_info['positions'] = [i[:N] for i in text_info['positions']]
522
+ text_info['n_lines'] = text_info['n_lines'][:N]
523
+ text_info['masked_x'] = text_info['masked_x'][:N]
524
+ text_info['img'] = text_info['img'][:N]
525
+
526
+ N = min(z.shape[0], N)
527
+ n_row = min(z.shape[0], n_row)
528
+ log["reconstruction"] = self.decode_first_stage(z)
529
+ log["masked_image"] = self.decode_first_stage(text_info['masked_x'])
530
+ log["control"] = c_cat * 2.0 - 1.0
531
+ log["img"] = text_info['img'].permute(0, 3, 1, 2) # log source image if needed
532
+ # get glyph
533
+ glyph_bs = torch.stack(text_info['glyphs'])
534
+ glyph_bs = torch.sum(glyph_bs, dim=0) * 2.0 - 1.0
535
+ log["glyph"] = torch.nn.functional.interpolate(glyph_bs, size=(512, 512), mode='bilinear', align_corners=True,)
536
+ # fill caption
537
+ if not self.embedding_manager:
538
+ self.fill_caption(batch)
539
+ captions = batch[self.cond_stage_key]
540
+ log["conditioning"] = log_txt_as_img((512, 512), captions, size=16)
541
+
542
+ if plot_diffusion_rows:
543
+ # get diffusion row
544
+ diffusion_row = list()
545
+ z_start = z[:n_row]
546
+ for t in range(self.num_timesteps):
547
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
548
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
549
+ t = t.to(self.device).long()
550
+ noise = torch.randn_like(z_start)
551
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
552
+ diffusion_row.append(self.decode_first_stage(z_noisy))
553
+
554
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
555
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
556
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
557
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
558
+ log["diffusion_row"] = diffusion_grid
559
+
560
+ if sample:
561
+ # get denoise row
562
+ samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c], "text_info": text_info},
563
+ batch_size=N, ddim=use_ddim,
564
+ ddim_steps=ddim_steps, eta=ddim_eta)
565
+ x_samples = self.decode_first_stage(samples)
566
+ log["samples"] = x_samples
567
+ if plot_denoise_rows:
568
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
569
+ log["denoise_row"] = denoise_grid
570
+
571
+ if unconditional_guidance_scale > 1.0:
572
+ uc_cross = self.get_unconditional_conditioning(N)
573
+ uc_cat = c_cat # torch.zeros_like(c_cat)
574
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross['c_crossattn'][0]], "text_info": text_info}
575
+ samples_cfg, tmps = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c_crossattn], "text_info": text_info},
576
+ batch_size=N, ddim=use_ddim,
577
+ ddim_steps=ddim_steps, eta=ddim_eta,
578
+ unconditional_guidance_scale=unconditional_guidance_scale,
579
+ unconditional_conditioning=uc_full,
580
+ )
581
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
582
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
583
+ pred_x0 = False # wether log pred_x0
584
+ if pred_x0:
585
+ for idx in range(len(tmps['pred_x0'])):
586
+ pred_x0 = self.decode_first_stage(tmps['pred_x0'][idx])
587
+ log[f"pred_x0_{tmps['index'][idx]}"] = pred_x0
588
+
589
+ return log
590
+
591
+ @torch.no_grad()
592
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
593
+ ddim_sampler = DDIMSampler(self)
594
+ b, c, h, w = cond["c_concat"][0].shape
595
+ shape = (self.channels, h // 8, w // 8)
596
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, log_every_t=5, **kwargs)
597
+ return samples, intermediates
598
+
599
+ def configure_optimizers(self):
600
+ lr = self.learning_rate
601
+ params = list(self.control_model.parameters())
602
+ if self.embedding_manager:
603
+ params += list(self.embedding_manager.embedding_parameters())
604
+ if not self.sd_locked:
605
+ # params += list(self.model.diffusion_model.input_blocks.parameters())
606
+ # params += list(self.model.diffusion_model.middle_block.parameters())
607
+ params += list(self.model.diffusion_model.output_blocks.parameters())
608
+ params += list(self.model.diffusion_model.out.parameters())
609
+ if self.unlockKV:
610
+ nCount = 0
611
+ for name, param in self.model.diffusion_model.named_parameters():
612
+ if 'attn2.to_k' in name or 'attn2.to_v' in name:
613
+ params += [param]
614
+ nCount += 1
615
+ print(f'Cross attention is unlocked, and {nCount} Wk or Wv are added to potimizers!!!')
616
+
617
+ opt = torch.optim.AdamW(params, lr=lr)
618
+ return opt
619
+
620
+ def low_vram_shift(self, is_diffusing):
621
+ if is_diffusing:
622
+ self.model = self.model.cuda()
623
+ self.control_model = self.control_model.cuda()
624
+ self.first_stage_model = self.first_stage_model.cpu()
625
+ self.cond_stage_model = self.cond_stage_model.cpu()
626
+ else:
627
+ self.model = self.model.cpu()
628
+ self.control_model = self.control_model.cpu()
629
+ self.first_stage_model = self.first_stage_model.cuda()
630
+ self.cond_stage_model = self.cond_stage_model.cuda()
iopaint/model/anytext/cldm/ddim_hacked.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ from iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
8
+ make_ddim_sampling_parameters,
9
+ make_ddim_timesteps,
10
+ noise_like,
11
+ extract_into_tensor,
12
+ )
13
+
14
+
15
+ class DDIMSampler(object):
16
+ def __init__(self, model, device, schedule="linear", **kwargs):
17
+ super().__init__()
18
+ self.device = device
19
+ self.model = model
20
+ self.ddpm_num_timesteps = model.num_timesteps
21
+ self.schedule = schedule
22
+
23
+ def register_buffer(self, name, attr):
24
+ if type(attr) == torch.Tensor:
25
+ if attr.device != torch.device(self.device):
26
+ attr = attr.to(torch.device(self.device))
27
+ setattr(self, name, attr)
28
+
29
+ def make_schedule(
30
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
31
+ ):
32
+ self.ddim_timesteps = make_ddim_timesteps(
33
+ ddim_discr_method=ddim_discretize,
34
+ num_ddim_timesteps=ddim_num_steps,
35
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
36
+ verbose=verbose,
37
+ )
38
+ alphas_cumprod = self.model.alphas_cumprod
39
+ assert (
40
+ alphas_cumprod.shape[0] == self.ddpm_num_timesteps
41
+ ), "alphas have to be defined for each timestep"
42
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
43
+
44
+ self.register_buffer("betas", to_torch(self.model.betas))
45
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
46
+ self.register_buffer(
47
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
48
+ )
49
+
50
+ # calculations for diffusion q(x_t | x_{t-1}) and others
51
+ self.register_buffer(
52
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
53
+ )
54
+ self.register_buffer(
55
+ "sqrt_one_minus_alphas_cumprod",
56
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
57
+ )
58
+ self.register_buffer(
59
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
60
+ )
61
+ self.register_buffer(
62
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
63
+ )
64
+ self.register_buffer(
65
+ "sqrt_recipm1_alphas_cumprod",
66
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
67
+ )
68
+
69
+ # ddim sampling parameters
70
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
71
+ alphacums=alphas_cumprod.cpu(),
72
+ ddim_timesteps=self.ddim_timesteps,
73
+ eta=ddim_eta,
74
+ verbose=verbose,
75
+ )
76
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
77
+ self.register_buffer("ddim_alphas", ddim_alphas)
78
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
79
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
80
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
81
+ (1 - self.alphas_cumprod_prev)
82
+ / (1 - self.alphas_cumprod)
83
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
84
+ )
85
+ self.register_buffer(
86
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
87
+ )
88
+
89
+ @torch.no_grad()
90
+ def sample(
91
+ self,
92
+ S,
93
+ batch_size,
94
+ shape,
95
+ conditioning=None,
96
+ callback=None,
97
+ normals_sequence=None,
98
+ img_callback=None,
99
+ quantize_x0=False,
100
+ eta=0.0,
101
+ mask=None,
102
+ x0=None,
103
+ temperature=1.0,
104
+ noise_dropout=0.0,
105
+ score_corrector=None,
106
+ corrector_kwargs=None,
107
+ verbose=True,
108
+ x_T=None,
109
+ log_every_t=100,
110
+ unconditional_guidance_scale=1.0,
111
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
112
+ dynamic_threshold=None,
113
+ ucg_schedule=None,
114
+ **kwargs,
115
+ ):
116
+ if conditioning is not None:
117
+ if isinstance(conditioning, dict):
118
+ ctmp = conditioning[list(conditioning.keys())[0]]
119
+ while isinstance(ctmp, list):
120
+ ctmp = ctmp[0]
121
+ cbs = ctmp.shape[0]
122
+ if cbs != batch_size:
123
+ print(
124
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
125
+ )
126
+
127
+ elif isinstance(conditioning, list):
128
+ for ctmp in conditioning:
129
+ if ctmp.shape[0] != batch_size:
130
+ print(
131
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
132
+ )
133
+
134
+ else:
135
+ if conditioning.shape[0] != batch_size:
136
+ print(
137
+ f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
138
+ )
139
+
140
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
141
+ # sampling
142
+ C, H, W = shape
143
+ size = (batch_size, C, H, W)
144
+ print(f"Data shape for DDIM sampling is {size}, eta {eta}")
145
+
146
+ samples, intermediates = self.ddim_sampling(
147
+ conditioning,
148
+ size,
149
+ callback=callback,
150
+ img_callback=img_callback,
151
+ quantize_denoised=quantize_x0,
152
+ mask=mask,
153
+ x0=x0,
154
+ ddim_use_original_steps=False,
155
+ noise_dropout=noise_dropout,
156
+ temperature=temperature,
157
+ score_corrector=score_corrector,
158
+ corrector_kwargs=corrector_kwargs,
159
+ x_T=x_T,
160
+ log_every_t=log_every_t,
161
+ unconditional_guidance_scale=unconditional_guidance_scale,
162
+ unconditional_conditioning=unconditional_conditioning,
163
+ dynamic_threshold=dynamic_threshold,
164
+ ucg_schedule=ucg_schedule,
165
+ )
166
+ return samples, intermediates
167
+
168
+ @torch.no_grad()
169
+ def ddim_sampling(
170
+ self,
171
+ cond,
172
+ shape,
173
+ x_T=None,
174
+ ddim_use_original_steps=False,
175
+ callback=None,
176
+ timesteps=None,
177
+ quantize_denoised=False,
178
+ mask=None,
179
+ x0=None,
180
+ img_callback=None,
181
+ log_every_t=100,
182
+ temperature=1.0,
183
+ noise_dropout=0.0,
184
+ score_corrector=None,
185
+ corrector_kwargs=None,
186
+ unconditional_guidance_scale=1.0,
187
+ unconditional_conditioning=None,
188
+ dynamic_threshold=None,
189
+ ucg_schedule=None,
190
+ ):
191
+ device = self.model.betas.device
192
+ b = shape[0]
193
+ if x_T is None:
194
+ img = torch.randn(shape, device=device)
195
+ else:
196
+ img = x_T
197
+
198
+ if timesteps is None:
199
+ timesteps = (
200
+ self.ddpm_num_timesteps
201
+ if ddim_use_original_steps
202
+ else self.ddim_timesteps
203
+ )
204
+ elif timesteps is not None and not ddim_use_original_steps:
205
+ subset_end = (
206
+ int(
207
+ min(timesteps / self.ddim_timesteps.shape[0], 1)
208
+ * self.ddim_timesteps.shape[0]
209
+ )
210
+ - 1
211
+ )
212
+ timesteps = self.ddim_timesteps[:subset_end]
213
+
214
+ intermediates = {"x_inter": [img], "pred_x0": [img]}
215
+ time_range = (
216
+ reversed(range(0, timesteps))
217
+ if ddim_use_original_steps
218
+ else np.flip(timesteps)
219
+ )
220
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
221
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
222
+
223
+ iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
224
+
225
+ for i, step in enumerate(iterator):
226
+ index = total_steps - i - 1
227
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
228
+
229
+ if mask is not None:
230
+ assert x0 is not None
231
+ img_orig = self.model.q_sample(
232
+ x0, ts
233
+ ) # TODO: deterministic forward pass?
234
+ img = img_orig * mask + (1.0 - mask) * img
235
+
236
+ if ucg_schedule is not None:
237
+ assert len(ucg_schedule) == len(time_range)
238
+ unconditional_guidance_scale = ucg_schedule[i]
239
+
240
+ outs = self.p_sample_ddim(
241
+ img,
242
+ cond,
243
+ ts,
244
+ index=index,
245
+ use_original_steps=ddim_use_original_steps,
246
+ quantize_denoised=quantize_denoised,
247
+ temperature=temperature,
248
+ noise_dropout=noise_dropout,
249
+ score_corrector=score_corrector,
250
+ corrector_kwargs=corrector_kwargs,
251
+ unconditional_guidance_scale=unconditional_guidance_scale,
252
+ unconditional_conditioning=unconditional_conditioning,
253
+ dynamic_threshold=dynamic_threshold,
254
+ )
255
+ img, pred_x0 = outs
256
+ if callback:
257
+ callback(None, i, None, None)
258
+ if img_callback:
259
+ img_callback(pred_x0, i)
260
+
261
+ if index % log_every_t == 0 or index == total_steps - 1:
262
+ intermediates["x_inter"].append(img)
263
+ intermediates["pred_x0"].append(pred_x0)
264
+
265
+ return img, intermediates
266
+
267
+ @torch.no_grad()
268
+ def p_sample_ddim(
269
+ self,
270
+ x,
271
+ c,
272
+ t,
273
+ index,
274
+ repeat_noise=False,
275
+ use_original_steps=False,
276
+ quantize_denoised=False,
277
+ temperature=1.0,
278
+ noise_dropout=0.0,
279
+ score_corrector=None,
280
+ corrector_kwargs=None,
281
+ unconditional_guidance_scale=1.0,
282
+ unconditional_conditioning=None,
283
+ dynamic_threshold=None,
284
+ ):
285
+ b, *_, device = *x.shape, x.device
286
+
287
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
288
+ model_output = self.model.apply_model(x, t, c)
289
+ else:
290
+ model_t = self.model.apply_model(x, t, c)
291
+ model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
292
+ model_output = model_uncond + unconditional_guidance_scale * (
293
+ model_t - model_uncond
294
+ )
295
+
296
+ if self.model.parameterization == "v":
297
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
298
+ else:
299
+ e_t = model_output
300
+
301
+ if score_corrector is not None:
302
+ assert self.model.parameterization == "eps", "not implemented"
303
+ e_t = score_corrector.modify_score(
304
+ self.model, e_t, x, t, c, **corrector_kwargs
305
+ )
306
+
307
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
308
+ alphas_prev = (
309
+ self.model.alphas_cumprod_prev
310
+ if use_original_steps
311
+ else self.ddim_alphas_prev
312
+ )
313
+ sqrt_one_minus_alphas = (
314
+ self.model.sqrt_one_minus_alphas_cumprod
315
+ if use_original_steps
316
+ else self.ddim_sqrt_one_minus_alphas
317
+ )
318
+ sigmas = (
319
+ self.model.ddim_sigmas_for_original_num_steps
320
+ if use_original_steps
321
+ else self.ddim_sigmas
322
+ )
323
+ # select parameters corresponding to the currently considered timestep
324
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
325
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
326
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
327
+ sqrt_one_minus_at = torch.full(
328
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
329
+ )
330
+
331
+ # current prediction for x_0
332
+ if self.model.parameterization != "v":
333
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
334
+ else:
335
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
336
+
337
+ if quantize_denoised:
338
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
339
+
340
+ if dynamic_threshold is not None:
341
+ raise NotImplementedError()
342
+
343
+ # direction pointing to x_t
344
+ dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
345
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
346
+ if noise_dropout > 0.0:
347
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
348
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
349
+ return x_prev, pred_x0
350
+
351
+ @torch.no_grad()
352
+ def encode(
353
+ self,
354
+ x0,
355
+ c,
356
+ t_enc,
357
+ use_original_steps=False,
358
+ return_intermediates=None,
359
+ unconditional_guidance_scale=1.0,
360
+ unconditional_conditioning=None,
361
+ callback=None,
362
+ ):
363
+ timesteps = (
364
+ np.arange(self.ddpm_num_timesteps)
365
+ if use_original_steps
366
+ else self.ddim_timesteps
367
+ )
368
+ num_reference_steps = timesteps.shape[0]
369
+
370
+ assert t_enc <= num_reference_steps
371
+ num_steps = t_enc
372
+
373
+ if use_original_steps:
374
+ alphas_next = self.alphas_cumprod[:num_steps]
375
+ alphas = self.alphas_cumprod_prev[:num_steps]
376
+ else:
377
+ alphas_next = self.ddim_alphas[:num_steps]
378
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
379
+
380
+ x_next = x0
381
+ intermediates = []
382
+ inter_steps = []
383
+ for i in tqdm(range(num_steps), desc="Encoding Image"):
384
+ t = torch.full(
385
+ (x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long
386
+ )
387
+ if unconditional_guidance_scale == 1.0:
388
+ noise_pred = self.model.apply_model(x_next, t, c)
389
+ else:
390
+ assert unconditional_conditioning is not None
391
+ e_t_uncond, noise_pred = torch.chunk(
392
+ self.model.apply_model(
393
+ torch.cat((x_next, x_next)),
394
+ torch.cat((t, t)),
395
+ torch.cat((unconditional_conditioning, c)),
396
+ ),
397
+ 2,
398
+ )
399
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (
400
+ noise_pred - e_t_uncond
401
+ )
402
+
403
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
404
+ weighted_noise_pred = (
405
+ alphas_next[i].sqrt()
406
+ * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
407
+ * noise_pred
408
+ )
409
+ x_next = xt_weighted + weighted_noise_pred
410
+ if (
411
+ return_intermediates
412
+ and i % (num_steps // return_intermediates) == 0
413
+ and i < num_steps - 1
414
+ ):
415
+ intermediates.append(x_next)
416
+ inter_steps.append(i)
417
+ elif return_intermediates and i >= num_steps - 2:
418
+ intermediates.append(x_next)
419
+ inter_steps.append(i)
420
+ if callback:
421
+ callback(i)
422
+
423
+ out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
424
+ if return_intermediates:
425
+ out.update({"intermediates": intermediates})
426
+ return x_next, out
427
+
428
+ @torch.no_grad()
429
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
430
+ # fast, but does not allow for exact reconstruction
431
+ # t serves as an index to gather the correct alphas
432
+ if use_original_steps:
433
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
434
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
435
+ else:
436
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
437
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
438
+
439
+ if noise is None:
440
+ noise = torch.randn_like(x0)
441
+ return (
442
+ extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
443
+ + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
444
+ )
445
+
446
+ @torch.no_grad()
447
+ def decode(
448
+ self,
449
+ x_latent,
450
+ cond,
451
+ t_start,
452
+ unconditional_guidance_scale=1.0,
453
+ unconditional_conditioning=None,
454
+ use_original_steps=False,
455
+ callback=None,
456
+ ):
457
+ timesteps = (
458
+ np.arange(self.ddpm_num_timesteps)
459
+ if use_original_steps
460
+ else self.ddim_timesteps
461
+ )
462
+ timesteps = timesteps[:t_start]
463
+
464
+ time_range = np.flip(timesteps)
465
+ total_steps = timesteps.shape[0]
466
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
467
+
468
+ iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
469
+ x_dec = x_latent
470
+ for i, step in enumerate(iterator):
471
+ index = total_steps - i - 1
472
+ ts = torch.full(
473
+ (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
474
+ )
475
+ x_dec, _ = self.p_sample_ddim(
476
+ x_dec,
477
+ cond,
478
+ ts,
479
+ index=index,
480
+ use_original_steps=use_original_steps,
481
+ unconditional_guidance_scale=unconditional_guidance_scale,
482
+ unconditional_conditioning=unconditional_conditioning,
483
+ )
484
+ if callback:
485
+ callback(i)
486
+ return x_dec
iopaint/model/anytext/cldm/embedding_manager.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Copyright (c) Alibaba, Inc. and its affiliates.
3
+ '''
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from functools import partial
8
+ from iopaint.model.anytext.ldm.modules.diffusionmodules.util import conv_nd, linear
9
+
10
+
11
+ def get_clip_token_for_string(tokenizer, string):
12
+ batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
13
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
14
+ tokens = batch_encoding["input_ids"]
15
+ assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"
16
+ return tokens[0, 1]
17
+
18
+
19
+ def get_bert_token_for_string(tokenizer, string):
20
+ token = tokenizer(string)
21
+ assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
22
+ token = token[0, 1]
23
+ return token
24
+
25
+
26
+ def get_clip_vision_emb(encoder, processor, img):
27
+ _img = img.repeat(1, 3, 1, 1)*255
28
+ inputs = processor(images=_img, return_tensors="pt")
29
+ inputs['pixel_values'] = inputs['pixel_values'].to(img.device)
30
+ outputs = encoder(**inputs)
31
+ emb = outputs.image_embeds
32
+ return emb
33
+
34
+
35
+ def get_recog_emb(encoder, img_list):
36
+ _img_list = [(img.repeat(1, 3, 1, 1)*255)[0] for img in img_list]
37
+ encoder.predictor.eval()
38
+ _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
39
+ return preds_neck
40
+
41
+
42
+ def pad_H(x):
43
+ _, _, H, W = x.shape
44
+ p_top = (W - H) // 2
45
+ p_bot = W - H - p_top
46
+ return F.pad(x, (0, 0, p_top, p_bot))
47
+
48
+
49
+ class EncodeNet(nn.Module):
50
+ def __init__(self, in_channels, out_channels):
51
+ super(EncodeNet, self).__init__()
52
+ chan = 16
53
+ n_layer = 4 # downsample
54
+
55
+ self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1)
56
+ self.conv_list = nn.ModuleList([])
57
+ _c = chan
58
+ for i in range(n_layer):
59
+ self.conv_list.append(conv_nd(2, _c, _c*2, 3, padding=1, stride=2))
60
+ _c *= 2
61
+ self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1)
62
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
63
+ self.act = nn.SiLU()
64
+
65
+ def forward(self, x):
66
+ x = self.act(self.conv1(x))
67
+ for layer in self.conv_list:
68
+ x = self.act(layer(x))
69
+ x = self.act(self.conv2(x))
70
+ x = self.avgpool(x)
71
+ x = x.view(x.size(0), -1)
72
+ return x
73
+
74
+
75
+ class EmbeddingManager(nn.Module):
76
+ def __init__(
77
+ self,
78
+ embedder,
79
+ valid=True,
80
+ glyph_channels=20,
81
+ position_channels=1,
82
+ placeholder_string='*',
83
+ add_pos=False,
84
+ emb_type='ocr',
85
+ **kwargs
86
+ ):
87
+ super().__init__()
88
+ if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder
89
+ get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
90
+ token_dim = 768
91
+ if hasattr(embedder, 'vit'):
92
+ assert emb_type == 'vit'
93
+ self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor)
94
+ self.get_recog_emb = None
95
+ else: # using LDM's BERT encoder
96
+ get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
97
+ token_dim = 1280
98
+ self.token_dim = token_dim
99
+ self.emb_type = emb_type
100
+
101
+ self.add_pos = add_pos
102
+ if add_pos:
103
+ self.position_encoder = EncodeNet(position_channels, token_dim)
104
+ if emb_type == 'ocr':
105
+ self.proj = linear(40*64, token_dim)
106
+ if emb_type == 'conv':
107
+ self.glyph_encoder = EncodeNet(glyph_channels, token_dim)
108
+
109
+ self.placeholder_token = get_token_for_string(placeholder_string)
110
+
111
+ def encode_text(self, text_info):
112
+ if self.get_recog_emb is None and self.emb_type == 'ocr':
113
+ self.get_recog_emb = partial(get_recog_emb, self.recog)
114
+
115
+ gline_list = []
116
+ pos_list = []
117
+ for i in range(len(text_info['n_lines'])): # sample index in a batch
118
+ n_lines = text_info['n_lines'][i]
119
+ for j in range(n_lines): # line
120
+ gline_list += [text_info['gly_line'][j][i:i+1]]
121
+ if self.add_pos:
122
+ pos_list += [text_info['positions'][j][i:i+1]]
123
+
124
+ if len(gline_list) > 0:
125
+ if self.emb_type == 'ocr':
126
+ recog_emb = self.get_recog_emb(gline_list)
127
+ enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
128
+ elif self.emb_type == 'vit':
129
+ enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
130
+ elif self.emb_type == 'conv':
131
+ enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
132
+ if self.add_pos:
133
+ enc_pos = self.position_encoder(torch.cat(gline_list, dim=0))
134
+ enc_glyph = enc_glyph+enc_pos
135
+
136
+ self.text_embs_all = []
137
+ n_idx = 0
138
+ for i in range(len(text_info['n_lines'])): # sample index in a batch
139
+ n_lines = text_info['n_lines'][i]
140
+ text_embs = []
141
+ for j in range(n_lines): # line
142
+ text_embs += [enc_glyph[n_idx:n_idx+1]]
143
+ n_idx += 1
144
+ self.text_embs_all += [text_embs]
145
+
146
+ def forward(
147
+ self,
148
+ tokenized_text,
149
+ embedded_text,
150
+ ):
151
+ b, device = tokenized_text.shape[0], tokenized_text.device
152
+ for i in range(b):
153
+ idx = tokenized_text[i] == self.placeholder_token.to(device)
154
+ if sum(idx) > 0:
155
+ if i >= len(self.text_embs_all):
156
+ print('truncation for log images...')
157
+ break
158
+ text_emb = torch.cat(self.text_embs_all[i], dim=0)
159
+ if sum(idx) != len(text_emb):
160
+ print('truncation for long caption...')
161
+ embedded_text[i][idx] = text_emb[:sum(idx)]
162
+ return embedded_text
163
+
164
+ def embedding_parameters(self):
165
+ return self.parameters()
iopaint/model/anytext/cldm/hack.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import einops
3
+
4
+ import iopaint.model.anytext.ldm.modules.encoders.modules
5
+ import iopaint.model.anytext.ldm.modules.attention
6
+
7
+ from transformers import logging
8
+ from iopaint.model.anytext.ldm.modules.attention import default
9
+
10
+
11
+ def disable_verbosity():
12
+ logging.set_verbosity_error()
13
+ print('logging improved.')
14
+ return
15
+
16
+
17
+ def enable_sliced_attention():
18
+ iopaint.model.anytext.ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
19
+ print('Enabled sliced_attention.')
20
+ return
21
+
22
+
23
+ def hack_everything(clip_skip=0):
24
+ disable_verbosity()
25
+ iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
26
+ iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
27
+ print('Enabled clip hacks.')
28
+ return
29
+
30
+
31
+ # Written by Lvmin
32
+ def _hacked_clip_forward(self, text):
33
+ PAD = self.tokenizer.pad_token_id
34
+ EOS = self.tokenizer.eos_token_id
35
+ BOS = self.tokenizer.bos_token_id
36
+
37
+ def tokenize(t):
38
+ return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
39
+
40
+ def transformer_encode(t):
41
+ if self.clip_skip > 1:
42
+ rt = self.transformer(input_ids=t, output_hidden_states=True)
43
+ return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
44
+ else:
45
+ return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
46
+
47
+ def split(x):
48
+ return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
49
+
50
+ def pad(x, p, i):
51
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
52
+
53
+ raw_tokens_list = tokenize(text)
54
+ tokens_list = []
55
+
56
+ for raw_tokens in raw_tokens_list:
57
+ raw_tokens_123 = split(raw_tokens)
58
+ raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
59
+ raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
60
+ tokens_list.append(raw_tokens_123)
61
+
62
+ tokens_list = torch.IntTensor(tokens_list).to(self.device)
63
+
64
+ feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
65
+ y = transformer_encode(feed)
66
+ z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
67
+
68
+ return z
69
+
70
+
71
+ # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
72
+ def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
73
+ h = self.heads
74
+
75
+ q = self.to_q(x)
76
+ context = default(context, x)
77
+ k = self.to_k(context)
78
+ v = self.to_v(context)
79
+ del context, x
80
+
81
+ q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
82
+
83
+ limit = k.shape[0]
84
+ att_step = 1
85
+ q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
86
+ k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
87
+ v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
88
+
89
+ q_chunks.reverse()
90
+ k_chunks.reverse()
91
+ v_chunks.reverse()
92
+ sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
93
+ del k, q, v
94
+ for i in range(0, limit, att_step):
95
+ q_buffer = q_chunks.pop()
96
+ k_buffer = k_chunks.pop()
97
+ v_buffer = v_chunks.pop()
98
+ sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
99
+
100
+ del k_buffer, q_buffer
101
+ # attention, what we cannot get enough of, by chunks
102
+
103
+ sim_buffer = sim_buffer.softmax(dim=-1)
104
+
105
+ sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
106
+ del v_buffer
107
+ sim[i:i + att_step, :, :] = sim_buffer
108
+
109
+ del sim_buffer
110
+ sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
111
+ return self.to_out(sim)
iopaint/model/anytext/cldm/model.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
+ from omegaconf import OmegaConf
5
+ from iopaint.model.anytext.ldm.util import instantiate_from_config
6
+
7
+
8
+ def get_state_dict(d):
9
+ return d.get("state_dict", d)
10
+
11
+
12
+ def load_state_dict(ckpt_path, location="cpu"):
13
+ _, extension = os.path.splitext(ckpt_path)
14
+ if extension.lower() == ".safetensors":
15
+ import safetensors.torch
16
+
17
+ state_dict = safetensors.torch.load_file(ckpt_path, device=location)
18
+ else:
19
+ state_dict = get_state_dict(
20
+ torch.load(ckpt_path, map_location=torch.device(location))
21
+ )
22
+ state_dict = get_state_dict(state_dict)
23
+ print(f"Loaded state_dict from [{ckpt_path}]")
24
+ return state_dict
25
+
26
+
27
+ def create_model(config_path, device, cond_stage_path=None, use_fp16=False):
28
+ config = OmegaConf.load(config_path)
29
+ # if cond_stage_path:
30
+ # config.model.params.cond_stage_config.params.version = (
31
+ # cond_stage_path # use pre-downloaded ckpts, in case blocked
32
+ # )
33
+ config.model.params.cond_stage_config.params.device = str(device)
34
+ if use_fp16:
35
+ config.model.params.use_fp16 = True
36
+ config.model.params.control_stage_config.params.use_fp16 = True
37
+ config.model.params.unet_config.params.use_fp16 = True
38
+ model = instantiate_from_config(config.model).cpu()
39
+ print(f"Loaded model config from [{config_path}]")
40
+ return model
iopaint/model/anytext/ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ from iopaint.model.anytext.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
8
+
9
+
10
+ class DDIMSampler(object):
11
+ def __init__(self, model, schedule="linear", **kwargs):
12
+ super().__init__()
13
+ self.model = model
14
+ self.ddpm_num_timesteps = model.num_timesteps
15
+ self.schedule = schedule
16
+
17
+ def register_buffer(self, name, attr):
18
+ if type(attr) == torch.Tensor:
19
+ if attr.device != torch.device("cuda"):
20
+ attr = attr.to(torch.device("cuda"))
21
+ setattr(self, name, attr)
22
+
23
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
24
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
25
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
26
+ alphas_cumprod = self.model.alphas_cumprod
27
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
28
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
29
+
30
+ self.register_buffer('betas', to_torch(self.model.betas))
31
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
32
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
33
+
34
+ # calculations for diffusion q(x_t | x_{t-1}) and others
35
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
36
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
37
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
38
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
39
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
40
+
41
+ # ddim sampling parameters
42
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
43
+ ddim_timesteps=self.ddim_timesteps,
44
+ eta=ddim_eta,verbose=verbose)
45
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
46
+ self.register_buffer('ddim_alphas', ddim_alphas)
47
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
48
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
49
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
50
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
51
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
52
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
53
+
54
+ @torch.no_grad()
55
+ def sample(self,
56
+ S,
57
+ batch_size,
58
+ shape,
59
+ conditioning=None,
60
+ callback=None,
61
+ normals_sequence=None,
62
+ img_callback=None,
63
+ quantize_x0=False,
64
+ eta=0.,
65
+ mask=None,
66
+ x0=None,
67
+ temperature=1.,
68
+ noise_dropout=0.,
69
+ score_corrector=None,
70
+ corrector_kwargs=None,
71
+ verbose=True,
72
+ x_T=None,
73
+ log_every_t=100,
74
+ unconditional_guidance_scale=1.,
75
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
76
+ dynamic_threshold=None,
77
+ ucg_schedule=None,
78
+ **kwargs
79
+ ):
80
+ if conditioning is not None:
81
+ if isinstance(conditioning, dict):
82
+ ctmp = conditioning[list(conditioning.keys())[0]]
83
+ while isinstance(ctmp, list): ctmp = ctmp[0]
84
+ cbs = ctmp.shape[0]
85
+ # cbs = len(ctmp[0])
86
+ if cbs != batch_size:
87
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
88
+
89
+ elif isinstance(conditioning, list):
90
+ for ctmp in conditioning:
91
+ if ctmp.shape[0] != batch_size:
92
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
93
+
94
+ else:
95
+ if conditioning.shape[0] != batch_size:
96
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
97
+
98
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
99
+ # sampling
100
+ C, H, W = shape
101
+ size = (batch_size, C, H, W)
102
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
103
+
104
+ samples, intermediates = self.ddim_sampling(conditioning, size,
105
+ callback=callback,
106
+ img_callback=img_callback,
107
+ quantize_denoised=quantize_x0,
108
+ mask=mask, x0=x0,
109
+ ddim_use_original_steps=False,
110
+ noise_dropout=noise_dropout,
111
+ temperature=temperature,
112
+ score_corrector=score_corrector,
113
+ corrector_kwargs=corrector_kwargs,
114
+ x_T=x_T,
115
+ log_every_t=log_every_t,
116
+ unconditional_guidance_scale=unconditional_guidance_scale,
117
+ unconditional_conditioning=unconditional_conditioning,
118
+ dynamic_threshold=dynamic_threshold,
119
+ ucg_schedule=ucg_schedule
120
+ )
121
+ return samples, intermediates
122
+
123
+ @torch.no_grad()
124
+ def ddim_sampling(self, cond, shape,
125
+ x_T=None, ddim_use_original_steps=False,
126
+ callback=None, timesteps=None, quantize_denoised=False,
127
+ mask=None, x0=None, img_callback=None, log_every_t=100,
128
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
129
+ unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
130
+ ucg_schedule=None):
131
+ device = self.model.betas.device
132
+ b = shape[0]
133
+ if x_T is None:
134
+ img = torch.randn(shape, device=device)
135
+ else:
136
+ img = x_T
137
+
138
+ if timesteps is None:
139
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
140
+ elif timesteps is not None and not ddim_use_original_steps:
141
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
142
+ timesteps = self.ddim_timesteps[:subset_end]
143
+
144
+ intermediates = {'x_inter': [img], 'pred_x0': [img], "index": [10000]}
145
+ time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
146
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
147
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
148
+
149
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
150
+
151
+ for i, step in enumerate(iterator):
152
+ index = total_steps - i - 1
153
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
154
+
155
+ if mask is not None:
156
+ assert x0 is not None
157
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
158
+ img = img_orig * mask + (1. - mask) * img
159
+
160
+ if ucg_schedule is not None:
161
+ assert len(ucg_schedule) == len(time_range)
162
+ unconditional_guidance_scale = ucg_schedule[i]
163
+
164
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
165
+ quantize_denoised=quantize_denoised, temperature=temperature,
166
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
167
+ corrector_kwargs=corrector_kwargs,
168
+ unconditional_guidance_scale=unconditional_guidance_scale,
169
+ unconditional_conditioning=unconditional_conditioning,
170
+ dynamic_threshold=dynamic_threshold)
171
+ img, pred_x0 = outs
172
+ if callback:
173
+ callback(i)
174
+ if img_callback:
175
+ img_callback(pred_x0, i)
176
+
177
+ if index % log_every_t == 0 or index == total_steps - 1:
178
+ intermediates['x_inter'].append(img)
179
+ intermediates['pred_x0'].append(pred_x0)
180
+ intermediates['index'].append(index)
181
+
182
+ return img, intermediates
183
+
184
+ @torch.no_grad()
185
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
186
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
187
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
188
+ dynamic_threshold=None):
189
+ b, *_, device = *x.shape, x.device
190
+
191
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
192
+ model_output = self.model.apply_model(x, t, c)
193
+ else:
194
+ x_in = torch.cat([x] * 2)
195
+ t_in = torch.cat([t] * 2)
196
+ if isinstance(c, dict):
197
+ assert isinstance(unconditional_conditioning, dict)
198
+ c_in = dict()
199
+ for k in c:
200
+ if isinstance(c[k], list):
201
+ c_in[k] = [torch.cat([
202
+ unconditional_conditioning[k][i],
203
+ c[k][i]]) for i in range(len(c[k]))]
204
+ elif isinstance(c[k], dict):
205
+ c_in[k] = dict()
206
+ for key in c[k]:
207
+ if isinstance(c[k][key], list):
208
+ if not isinstance(c[k][key][0], torch.Tensor):
209
+ continue
210
+ c_in[k][key] = [torch.cat([
211
+ unconditional_conditioning[k][key][i],
212
+ c[k][key][i]]) for i in range(len(c[k][key]))]
213
+ else:
214
+ c_in[k][key] = torch.cat([
215
+ unconditional_conditioning[k][key],
216
+ c[k][key]])
217
+
218
+ else:
219
+ c_in[k] = torch.cat([
220
+ unconditional_conditioning[k],
221
+ c[k]])
222
+ elif isinstance(c, list):
223
+ c_in = list()
224
+ assert isinstance(unconditional_conditioning, list)
225
+ for i in range(len(c)):
226
+ c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
227
+ else:
228
+ c_in = torch.cat([unconditional_conditioning, c])
229
+ model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
230
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
231
+
232
+ if self.model.parameterization == "v":
233
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
234
+ else:
235
+ e_t = model_output
236
+
237
+ if score_corrector is not None:
238
+ assert self.model.parameterization == "eps", 'not implemented'
239
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
240
+
241
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
242
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
243
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
244
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
245
+ # select parameters corresponding to the currently considered timestep
246
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
247
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
248
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
249
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
250
+
251
+ # current prediction for x_0
252
+ if self.model.parameterization != "v":
253
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
254
+ else:
255
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
256
+
257
+ if quantize_denoised:
258
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
259
+
260
+ if dynamic_threshold is not None:
261
+ raise NotImplementedError()
262
+
263
+ # direction pointing to x_t
264
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
265
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
266
+ if noise_dropout > 0.:
267
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
268
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
269
+ return x_prev, pred_x0
270
+
271
+ @torch.no_grad()
272
+ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
273
+ unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
274
+ num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
275
+
276
+ assert t_enc <= num_reference_steps
277
+ num_steps = t_enc
278
+
279
+ if use_original_steps:
280
+ alphas_next = self.alphas_cumprod[:num_steps]
281
+ alphas = self.alphas_cumprod_prev[:num_steps]
282
+ else:
283
+ alphas_next = self.ddim_alphas[:num_steps]
284
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
285
+
286
+ x_next = x0
287
+ intermediates = []
288
+ inter_steps = []
289
+ for i in tqdm(range(num_steps), desc='Encoding Image'):
290
+ t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
291
+ if unconditional_guidance_scale == 1.:
292
+ noise_pred = self.model.apply_model(x_next, t, c)
293
+ else:
294
+ assert unconditional_conditioning is not None
295
+ e_t_uncond, noise_pred = torch.chunk(
296
+ self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
297
+ torch.cat((unconditional_conditioning, c))), 2)
298
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
299
+
300
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
301
+ weighted_noise_pred = alphas_next[i].sqrt() * (
302
+ (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
303
+ x_next = xt_weighted + weighted_noise_pred
304
+ if return_intermediates and i % (
305
+ num_steps // return_intermediates) == 0 and i < num_steps - 1:
306
+ intermediates.append(x_next)
307
+ inter_steps.append(i)
308
+ elif return_intermediates and i >= num_steps - 2:
309
+ intermediates.append(x_next)
310
+ inter_steps.append(i)
311
+ if callback: callback(i)
312
+
313
+ out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
314
+ if return_intermediates:
315
+ out.update({'intermediates': intermediates})
316
+ return x_next, out
317
+
318
+ @torch.no_grad()
319
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
320
+ # fast, but does not allow for exact reconstruction
321
+ # t serves as an index to gather the correct alphas
322
+ if use_original_steps:
323
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
324
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
325
+ else:
326
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
327
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
328
+
329
+ if noise is None:
330
+ noise = torch.randn_like(x0)
331
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
332
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
333
+
334
+ @torch.no_grad()
335
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
336
+ use_original_steps=False, callback=None):
337
+
338
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
339
+ timesteps = timesteps[:t_start]
340
+
341
+ time_range = np.flip(timesteps)
342
+ total_steps = timesteps.shape[0]
343
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
344
+
345
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
346
+ x_dec = x_latent
347
+ for i, step in enumerate(iterator):
348
+ index = total_steps - i - 1
349
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
350
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
351
+ unconditional_guidance_scale=unconditional_guidance_scale,
352
+ unconditional_conditioning=unconditional_conditioning)
353
+ if callback: callback(i)
354
+ return x_dec
iopaint/model/anytext/ldm/models/diffusion/ddpm.py ADDED
@@ -0,0 +1,2380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Part of the implementation is borrowed and modified from ControlNet, publicly available at https://github.com/lllyasviel/ControlNet/blob/main/ldm/models/diffusion/ddpm.py
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import numpy as np
8
+ from torch.optim.lr_scheduler import LambdaLR
9
+ from einops import rearrange, repeat
10
+ from contextlib import contextmanager, nullcontext
11
+ from functools import partial
12
+ import itertools
13
+ from tqdm import tqdm
14
+ from torchvision.utils import make_grid
15
+ from omegaconf import ListConfig
16
+
17
+ from iopaint.model.anytext.ldm.util import (
18
+ log_txt_as_img,
19
+ exists,
20
+ default,
21
+ ismap,
22
+ isimage,
23
+ mean_flat,
24
+ count_params,
25
+ instantiate_from_config,
26
+ )
27
+ from iopaint.model.anytext.ldm.modules.ema import LitEma
28
+ from iopaint.model.anytext.ldm.modules.distributions.distributions import (
29
+ normal_kl,
30
+ DiagonalGaussianDistribution,
31
+ )
32
+ from iopaint.model.anytext.ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
33
+ from iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
34
+ make_beta_schedule,
35
+ extract_into_tensor,
36
+ noise_like,
37
+ )
38
+ from iopaint.model.anytext.ldm.models.diffusion.ddim import DDIMSampler
39
+ import cv2
40
+
41
+
42
+ __conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
43
+
44
+ PRINT_DEBUG = False
45
+
46
+
47
+ def print_grad(grad):
48
+ # print('Gradient:', grad)
49
+ # print(grad.shape)
50
+ a = grad.max()
51
+ b = grad.min()
52
+ # print(f'mean={grad.mean():.4f}, max={a:.4f}, min={b:.4f}')
53
+ s = 255.0 / (a - b)
54
+ c = 255 * (-b / (a - b))
55
+ grad = grad * s + c
56
+ # print(f'mean={grad.mean():.4f}, max={grad.max():.4f}, min={grad.min():.4f}')
57
+ img = grad[0].permute(1, 2, 0).detach().cpu().numpy()
58
+ if img.shape[0] == 512:
59
+ cv2.imwrite("grad-img.jpg", img)
60
+ elif img.shape[0] == 64:
61
+ cv2.imwrite("grad-latent.jpg", img)
62
+
63
+
64
+ def disabled_train(self, mode=True):
65
+ """Overwrite model.train with this function to make sure train/eval mode
66
+ does not change anymore."""
67
+ return self
68
+
69
+
70
+ def uniform_on_device(r1, r2, shape, device):
71
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
72
+
73
+
74
+ class DDPM(torch.nn.Module):
75
+ # classic DDPM with Gaussian diffusion, in image space
76
+ def __init__(
77
+ self,
78
+ unet_config,
79
+ timesteps=1000,
80
+ beta_schedule="linear",
81
+ loss_type="l2",
82
+ ckpt_path=None,
83
+ ignore_keys=[],
84
+ load_only_unet=False,
85
+ monitor="val/loss",
86
+ use_ema=True,
87
+ first_stage_key="image",
88
+ image_size=256,
89
+ channels=3,
90
+ log_every_t=100,
91
+ clip_denoised=True,
92
+ linear_start=1e-4,
93
+ linear_end=2e-2,
94
+ cosine_s=8e-3,
95
+ given_betas=None,
96
+ original_elbo_weight=0.0,
97
+ v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
98
+ l_simple_weight=1.0,
99
+ conditioning_key=None,
100
+ parameterization="eps", # all assuming fixed variance schedules
101
+ scheduler_config=None,
102
+ use_positional_encodings=False,
103
+ learn_logvar=False,
104
+ logvar_init=0.0,
105
+ make_it_fit=False,
106
+ ucg_training=None,
107
+ reset_ema=False,
108
+ reset_num_ema_updates=False,
109
+ ):
110
+ super().__init__()
111
+ assert parameterization in [
112
+ "eps",
113
+ "x0",
114
+ "v",
115
+ ], 'currently only supporting "eps" and "x0" and "v"'
116
+ self.parameterization = parameterization
117
+ print(
118
+ f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode"
119
+ )
120
+ self.cond_stage_model = None
121
+ self.clip_denoised = clip_denoised
122
+ self.log_every_t = log_every_t
123
+ self.first_stage_key = first_stage_key
124
+ self.image_size = image_size # try conv?
125
+ self.channels = channels
126
+ self.use_positional_encodings = use_positional_encodings
127
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
128
+ count_params(self.model, verbose=True)
129
+ self.use_ema = use_ema
130
+ if self.use_ema:
131
+ self.model_ema = LitEma(self.model)
132
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
133
+
134
+ self.use_scheduler = scheduler_config is not None
135
+ if self.use_scheduler:
136
+ self.scheduler_config = scheduler_config
137
+
138
+ self.v_posterior = v_posterior
139
+ self.original_elbo_weight = original_elbo_weight
140
+ self.l_simple_weight = l_simple_weight
141
+
142
+ if monitor is not None:
143
+ self.monitor = monitor
144
+ self.make_it_fit = make_it_fit
145
+ if reset_ema:
146
+ assert exists(ckpt_path)
147
+ if ckpt_path is not None:
148
+ self.init_from_ckpt(
149
+ ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet
150
+ )
151
+ if reset_ema:
152
+ assert self.use_ema
153
+ print(
154
+ f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint."
155
+ )
156
+ self.model_ema = LitEma(self.model)
157
+ if reset_num_ema_updates:
158
+ print(
159
+ " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ "
160
+ )
161
+ assert self.use_ema
162
+ self.model_ema.reset_num_updates()
163
+
164
+ self.register_schedule(
165
+ given_betas=given_betas,
166
+ beta_schedule=beta_schedule,
167
+ timesteps=timesteps,
168
+ linear_start=linear_start,
169
+ linear_end=linear_end,
170
+ cosine_s=cosine_s,
171
+ )
172
+
173
+ self.loss_type = loss_type
174
+
175
+ self.learn_logvar = learn_logvar
176
+ logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
177
+ if self.learn_logvar:
178
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
179
+ else:
180
+ self.register_buffer("logvar", logvar)
181
+
182
+ self.ucg_training = ucg_training or dict()
183
+ if self.ucg_training:
184
+ self.ucg_prng = np.random.RandomState()
185
+
186
+ def register_schedule(
187
+ self,
188
+ given_betas=None,
189
+ beta_schedule="linear",
190
+ timesteps=1000,
191
+ linear_start=1e-4,
192
+ linear_end=2e-2,
193
+ cosine_s=8e-3,
194
+ ):
195
+ if exists(given_betas):
196
+ betas = given_betas
197
+ else:
198
+ betas = make_beta_schedule(
199
+ beta_schedule,
200
+ timesteps,
201
+ linear_start=linear_start,
202
+ linear_end=linear_end,
203
+ cosine_s=cosine_s,
204
+ )
205
+ alphas = 1.0 - betas
206
+ alphas_cumprod = np.cumprod(alphas, axis=0)
207
+ # np.save('1.npy', alphas_cumprod)
208
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
209
+
210
+ (timesteps,) = betas.shape
211
+ self.num_timesteps = int(timesteps)
212
+ self.linear_start = linear_start
213
+ self.linear_end = linear_end
214
+ assert (
215
+ alphas_cumprod.shape[0] == self.num_timesteps
216
+ ), "alphas have to be defined for each timestep"
217
+
218
+ to_torch = partial(torch.tensor, dtype=torch.float32)
219
+
220
+ self.register_buffer("betas", to_torch(betas))
221
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
222
+ self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
223
+
224
+ # calculations for diffusion q(x_t | x_{t-1}) and others
225
+ self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
226
+ self.register_buffer(
227
+ "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
228
+ )
229
+ self.register_buffer(
230
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
231
+ )
232
+ self.register_buffer(
233
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
234
+ )
235
+ self.register_buffer(
236
+ "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
237
+ )
238
+
239
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
240
+ posterior_variance = (1 - self.v_posterior) * betas * (
241
+ 1.0 - alphas_cumprod_prev
242
+ ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
243
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
244
+ self.register_buffer("posterior_variance", to_torch(posterior_variance))
245
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
246
+ self.register_buffer(
247
+ "posterior_log_variance_clipped",
248
+ to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
249
+ )
250
+ self.register_buffer(
251
+ "posterior_mean_coef1",
252
+ to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
253
+ )
254
+ self.register_buffer(
255
+ "posterior_mean_coef2",
256
+ to_torch(
257
+ (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
258
+ ),
259
+ )
260
+
261
+ if self.parameterization == "eps":
262
+ lvlb_weights = self.betas**2 / (
263
+ 2
264
+ * self.posterior_variance
265
+ * to_torch(alphas)
266
+ * (1 - self.alphas_cumprod)
267
+ )
268
+ elif self.parameterization == "x0":
269
+ lvlb_weights = (
270
+ 0.5
271
+ * np.sqrt(torch.Tensor(alphas_cumprod))
272
+ / (2.0 * 1 - torch.Tensor(alphas_cumprod))
273
+ )
274
+ elif self.parameterization == "v":
275
+ lvlb_weights = torch.ones_like(
276
+ self.betas**2
277
+ / (
278
+ 2
279
+ * self.posterior_variance
280
+ * to_torch(alphas)
281
+ * (1 - self.alphas_cumprod)
282
+ )
283
+ )
284
+ else:
285
+ raise NotImplementedError("mu not supported")
286
+ lvlb_weights[0] = lvlb_weights[1]
287
+ self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
288
+ assert not torch.isnan(self.lvlb_weights).all()
289
+
290
+ @contextmanager
291
+ def ema_scope(self, context=None):
292
+ if self.use_ema:
293
+ self.model_ema.store(self.model.parameters())
294
+ self.model_ema.copy_to(self.model)
295
+ if context is not None:
296
+ print(f"{context}: Switched to EMA weights")
297
+ try:
298
+ yield None
299
+ finally:
300
+ if self.use_ema:
301
+ self.model_ema.restore(self.model.parameters())
302
+ if context is not None:
303
+ print(f"{context}: Restored training weights")
304
+
305
+ @torch.no_grad()
306
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
307
+ sd = torch.load(path, map_location="cpu")
308
+ if "state_dict" in list(sd.keys()):
309
+ sd = sd["state_dict"]
310
+ keys = list(sd.keys())
311
+ for k in keys:
312
+ for ik in ignore_keys:
313
+ if k.startswith(ik):
314
+ print("Deleting key {} from state_dict.".format(k))
315
+ del sd[k]
316
+ if self.make_it_fit:
317
+ n_params = len(
318
+ [
319
+ name
320
+ for name, _ in itertools.chain(
321
+ self.named_parameters(), self.named_buffers()
322
+ )
323
+ ]
324
+ )
325
+ for name, param in tqdm(
326
+ itertools.chain(self.named_parameters(), self.named_buffers()),
327
+ desc="Fitting old weights to new weights",
328
+ total=n_params,
329
+ ):
330
+ if not name in sd:
331
+ continue
332
+ old_shape = sd[name].shape
333
+ new_shape = param.shape
334
+ assert len(old_shape) == len(new_shape)
335
+ if len(new_shape) > 2:
336
+ # we only modify first two axes
337
+ assert new_shape[2:] == old_shape[2:]
338
+ # assumes first axis corresponds to output dim
339
+ if not new_shape == old_shape:
340
+ new_param = param.clone()
341
+ old_param = sd[name]
342
+ if len(new_shape) == 1:
343
+ for i in range(new_param.shape[0]):
344
+ new_param[i] = old_param[i % old_shape[0]]
345
+ elif len(new_shape) >= 2:
346
+ for i in range(new_param.shape[0]):
347
+ for j in range(new_param.shape[1]):
348
+ new_param[i, j] = old_param[
349
+ i % old_shape[0], j % old_shape[1]
350
+ ]
351
+
352
+ n_used_old = torch.ones(old_shape[1])
353
+ for j in range(new_param.shape[1]):
354
+ n_used_old[j % old_shape[1]] += 1
355
+ n_used_new = torch.zeros(new_shape[1])
356
+ for j in range(new_param.shape[1]):
357
+ n_used_new[j] = n_used_old[j % old_shape[1]]
358
+
359
+ n_used_new = n_used_new[None, :]
360
+ while len(n_used_new.shape) < len(new_shape):
361
+ n_used_new = n_used_new.unsqueeze(-1)
362
+ new_param /= n_used_new
363
+
364
+ sd[name] = new_param
365
+
366
+ missing, unexpected = (
367
+ self.load_state_dict(sd, strict=False)
368
+ if not only_model
369
+ else self.model.load_state_dict(sd, strict=False)
370
+ )
371
+ print(
372
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
373
+ )
374
+ if len(missing) > 0:
375
+ print(f"Missing Keys:\n {missing}")
376
+ if len(unexpected) > 0:
377
+ print(f"\nUnexpected Keys:\n {unexpected}")
378
+
379
+ def q_mean_variance(self, x_start, t):
380
+ """
381
+ Get the distribution q(x_t | x_0).
382
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
383
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
384
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
385
+ """
386
+ mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
387
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
388
+ log_variance = extract_into_tensor(
389
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
390
+ )
391
+ return mean, variance, log_variance
392
+
393
+ def predict_start_from_noise(self, x_t, t, noise):
394
+ return (
395
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
396
+ - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
397
+ * noise
398
+ )
399
+
400
+ def predict_start_from_z_and_v(self, x_t, t, v):
401
+ # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
402
+ # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
403
+ return (
404
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
405
+ - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
406
+ )
407
+
408
+ def predict_eps_from_z_and_v(self, x_t, t, v):
409
+ return (
410
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
411
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
412
+ * x_t
413
+ )
414
+
415
+ def q_posterior(self, x_start, x_t, t):
416
+ posterior_mean = (
417
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
418
+ + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
419
+ )
420
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
421
+ posterior_log_variance_clipped = extract_into_tensor(
422
+ self.posterior_log_variance_clipped, t, x_t.shape
423
+ )
424
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
425
+
426
+ def p_mean_variance(self, x, t, clip_denoised: bool):
427
+ model_out = self.model(x, t)
428
+ if self.parameterization == "eps":
429
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
430
+ elif self.parameterization == "x0":
431
+ x_recon = model_out
432
+ if clip_denoised:
433
+ x_recon.clamp_(-1.0, 1.0)
434
+
435
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
436
+ x_start=x_recon, x_t=x, t=t
437
+ )
438
+ return model_mean, posterior_variance, posterior_log_variance
439
+
440
+ @torch.no_grad()
441
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
442
+ b, *_, device = *x.shape, x.device
443
+ model_mean, _, model_log_variance = self.p_mean_variance(
444
+ x=x, t=t, clip_denoised=clip_denoised
445
+ )
446
+ noise = noise_like(x.shape, device, repeat_noise)
447
+ # no noise when t == 0
448
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
449
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
450
+
451
+ @torch.no_grad()
452
+ def p_sample_loop(self, shape, return_intermediates=False):
453
+ device = self.betas.device
454
+ b = shape[0]
455
+ img = torch.randn(shape, device=device)
456
+ intermediates = [img]
457
+ for i in tqdm(
458
+ reversed(range(0, self.num_timesteps)),
459
+ desc="Sampling t",
460
+ total=self.num_timesteps,
461
+ ):
462
+ img = self.p_sample(
463
+ img,
464
+ torch.full((b,), i, device=device, dtype=torch.long),
465
+ clip_denoised=self.clip_denoised,
466
+ )
467
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
468
+ intermediates.append(img)
469
+ if return_intermediates:
470
+ return img, intermediates
471
+ return img
472
+
473
+ @torch.no_grad()
474
+ def sample(self, batch_size=16, return_intermediates=False):
475
+ image_size = self.image_size
476
+ channels = self.channels
477
+ return self.p_sample_loop(
478
+ (batch_size, channels, image_size, image_size),
479
+ return_intermediates=return_intermediates,
480
+ )
481
+
482
+ def q_sample(self, x_start, t, noise=None):
483
+ noise = default(noise, lambda: torch.randn_like(x_start))
484
+ return (
485
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
486
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
487
+ * noise
488
+ )
489
+
490
+ def get_v(self, x, noise, t):
491
+ return (
492
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
493
+ - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
494
+ )
495
+
496
+ def get_loss(self, pred, target, mean=True):
497
+ if self.loss_type == "l1":
498
+ loss = (target - pred).abs()
499
+ if mean:
500
+ loss = loss.mean()
501
+ elif self.loss_type == "l2":
502
+ if mean:
503
+ loss = torch.nn.functional.mse_loss(target, pred)
504
+ else:
505
+ loss = torch.nn.functional.mse_loss(target, pred, reduction="none")
506
+ else:
507
+ raise NotImplementedError("unknown loss type '{loss_type}'")
508
+
509
+ return loss
510
+
511
+ def p_losses(self, x_start, t, noise=None):
512
+ noise = default(noise, lambda: torch.randn_like(x_start))
513
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
514
+ model_out = self.model(x_noisy, t)
515
+
516
+ loss_dict = {}
517
+ if self.parameterization == "eps":
518
+ target = noise
519
+ elif self.parameterization == "x0":
520
+ target = x_start
521
+ elif self.parameterization == "v":
522
+ target = self.get_v(x_start, noise, t)
523
+ else:
524
+ raise NotImplementedError(
525
+ f"Parameterization {self.parameterization} not yet supported"
526
+ )
527
+
528
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
529
+
530
+ log_prefix = "train" if self.training else "val"
531
+
532
+ loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()})
533
+ loss_simple = loss.mean() * self.l_simple_weight
534
+
535
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
536
+ loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb})
537
+
538
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
539
+
540
+ loss_dict.update({f"{log_prefix}/loss": loss})
541
+
542
+ return loss, loss_dict
543
+
544
+ def forward(self, x, *args, **kwargs):
545
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
546
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
547
+ t = torch.randint(
548
+ 0, self.num_timesteps, (x.shape[0],), device=self.device
549
+ ).long()
550
+ return self.p_losses(x, t, *args, **kwargs)
551
+
552
+ def get_input(self, batch, k):
553
+ x = batch[k]
554
+ if len(x.shape) == 3:
555
+ x = x[..., None]
556
+ x = rearrange(x, "b h w c -> b c h w")
557
+ x = x.to(memory_format=torch.contiguous_format).float()
558
+ return x
559
+
560
+ def shared_step(self, batch):
561
+ x = self.get_input(batch, self.first_stage_key)
562
+ loss, loss_dict = self(x)
563
+ return loss, loss_dict
564
+
565
+ def training_step(self, batch, batch_idx):
566
+ for k in self.ucg_training:
567
+ p = self.ucg_training[k]["p"]
568
+ val = self.ucg_training[k]["val"]
569
+ if val is None:
570
+ val = ""
571
+ for i in range(len(batch[k])):
572
+ if self.ucg_prng.choice(2, p=[1 - p, p]):
573
+ batch[k][i] = val
574
+
575
+ loss, loss_dict = self.shared_step(batch)
576
+
577
+ self.log_dict(
578
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True
579
+ )
580
+
581
+ self.log(
582
+ "global_step",
583
+ self.global_step,
584
+ prog_bar=True,
585
+ logger=True,
586
+ on_step=True,
587
+ on_epoch=False,
588
+ )
589
+
590
+ if self.use_scheduler:
591
+ lr = self.optimizers().param_groups[0]["lr"]
592
+ self.log(
593
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
594
+ )
595
+
596
+ return loss
597
+
598
+ @torch.no_grad()
599
+ def validation_step(self, batch, batch_idx):
600
+ _, loss_dict_no_ema = self.shared_step(batch)
601
+ with self.ema_scope():
602
+ _, loss_dict_ema = self.shared_step(batch)
603
+ loss_dict_ema = {key + "_ema": loss_dict_ema[key] for key in loss_dict_ema}
604
+ self.log_dict(
605
+ loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True
606
+ )
607
+ self.log_dict(
608
+ loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True
609
+ )
610
+
611
+ def on_train_batch_end(self, *args, **kwargs):
612
+ if self.use_ema:
613
+ self.model_ema(self.model)
614
+
615
+ def _get_rows_from_list(self, samples):
616
+ n_imgs_per_row = len(samples)
617
+ denoise_grid = rearrange(samples, "n b c h w -> b n c h w")
618
+ denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
619
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
620
+ return denoise_grid
621
+
622
+ @torch.no_grad()
623
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
624
+ log = dict()
625
+ x = self.get_input(batch, self.first_stage_key)
626
+ N = min(x.shape[0], N)
627
+ n_row = min(x.shape[0], n_row)
628
+ x = x.to(self.device)[:N]
629
+ log["inputs"] = x
630
+
631
+ # get diffusion row
632
+ diffusion_row = list()
633
+ x_start = x[:n_row]
634
+
635
+ for t in range(self.num_timesteps):
636
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
637
+ t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
638
+ t = t.to(self.device).long()
639
+ noise = torch.randn_like(x_start)
640
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
641
+ diffusion_row.append(x_noisy)
642
+
643
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
644
+
645
+ if sample:
646
+ # get denoise row
647
+ with self.ema_scope("Plotting"):
648
+ samples, denoise_row = self.sample(
649
+ batch_size=N, return_intermediates=True
650
+ )
651
+
652
+ log["samples"] = samples
653
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
654
+
655
+ if return_keys:
656
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
657
+ return log
658
+ else:
659
+ return {key: log[key] for key in return_keys}
660
+ return log
661
+
662
+ def configure_optimizers(self):
663
+ lr = self.learning_rate
664
+ params = list(self.model.parameters())
665
+ if self.learn_logvar:
666
+ params = params + [self.logvar]
667
+ opt = torch.optim.AdamW(params, lr=lr)
668
+ return opt
669
+
670
+
671
+ class LatentDiffusion(DDPM):
672
+ """main class"""
673
+
674
+ def __init__(
675
+ self,
676
+ first_stage_config,
677
+ cond_stage_config,
678
+ num_timesteps_cond=None,
679
+ cond_stage_key="image",
680
+ cond_stage_trainable=False,
681
+ concat_mode=True,
682
+ cond_stage_forward=None,
683
+ conditioning_key=None,
684
+ scale_factor=1.0,
685
+ scale_by_std=False,
686
+ force_null_conditioning=False,
687
+ *args,
688
+ **kwargs,
689
+ ):
690
+ self.force_null_conditioning = force_null_conditioning
691
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
692
+ self.scale_by_std = scale_by_std
693
+ assert self.num_timesteps_cond <= kwargs["timesteps"]
694
+ # for backwards compatibility after implementation of DiffusionWrapper
695
+ if conditioning_key is None:
696
+ conditioning_key = "concat" if concat_mode else "crossattn"
697
+ if (
698
+ cond_stage_config == "__is_unconditional__"
699
+ and not self.force_null_conditioning
700
+ ):
701
+ conditioning_key = None
702
+ ckpt_path = kwargs.pop("ckpt_path", None)
703
+ reset_ema = kwargs.pop("reset_ema", False)
704
+ reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
705
+ ignore_keys = kwargs.pop("ignore_keys", [])
706
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
707
+ self.concat_mode = concat_mode
708
+ self.cond_stage_trainable = cond_stage_trainable
709
+ self.cond_stage_key = cond_stage_key
710
+ try:
711
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
712
+ except:
713
+ self.num_downs = 0
714
+ if not scale_by_std:
715
+ self.scale_factor = scale_factor
716
+ else:
717
+ self.register_buffer("scale_factor", torch.tensor(scale_factor))
718
+ self.instantiate_first_stage(first_stage_config)
719
+ self.instantiate_cond_stage(cond_stage_config)
720
+ self.cond_stage_forward = cond_stage_forward
721
+ self.clip_denoised = False
722
+ self.bbox_tokenizer = None
723
+
724
+ self.restarted_from_ckpt = False
725
+ if ckpt_path is not None:
726
+ self.init_from_ckpt(ckpt_path, ignore_keys)
727
+ self.restarted_from_ckpt = True
728
+ if reset_ema:
729
+ assert self.use_ema
730
+ print(
731
+ f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint."
732
+ )
733
+ self.model_ema = LitEma(self.model)
734
+ if reset_num_ema_updates:
735
+ print(
736
+ " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ "
737
+ )
738
+ assert self.use_ema
739
+ self.model_ema.reset_num_updates()
740
+
741
+ def make_cond_schedule(
742
+ self,
743
+ ):
744
+ self.cond_ids = torch.full(
745
+ size=(self.num_timesteps,),
746
+ fill_value=self.num_timesteps - 1,
747
+ dtype=torch.long,
748
+ )
749
+ ids = torch.round(
750
+ torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
751
+ ).long()
752
+ self.cond_ids[: self.num_timesteps_cond] = ids
753
+
754
+ @torch.no_grad()
755
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
756
+ # only for very first batch
757
+ if (
758
+ self.scale_by_std
759
+ and self.current_epoch == 0
760
+ and self.global_step == 0
761
+ and batch_idx == 0
762
+ and not self.restarted_from_ckpt
763
+ ):
764
+ assert (
765
+ self.scale_factor == 1.0
766
+ ), "rather not use custom rescaling and std-rescaling simultaneously"
767
+ # set rescale weight to 1./std of encodings
768
+ print("### USING STD-RESCALING ###")
769
+ x = super().get_input(batch, self.first_stage_key)
770
+ x = x.to(self.device)
771
+ encoder_posterior = self.encode_first_stage(x)
772
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
773
+ del self.scale_factor
774
+ self.register_buffer("scale_factor", 1.0 / z.flatten().std())
775
+ print(f"setting self.scale_factor to {self.scale_factor}")
776
+ print("### USING STD-RESCALING ###")
777
+
778
+ def register_schedule(
779
+ self,
780
+ given_betas=None,
781
+ beta_schedule="linear",
782
+ timesteps=1000,
783
+ linear_start=1e-4,
784
+ linear_end=2e-2,
785
+ cosine_s=8e-3,
786
+ ):
787
+ super().register_schedule(
788
+ given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
789
+ )
790
+
791
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
792
+ if self.shorten_cond_schedule:
793
+ self.make_cond_schedule()
794
+
795
+ def instantiate_first_stage(self, config):
796
+ model = instantiate_from_config(config)
797
+ self.first_stage_model = model.eval()
798
+ self.first_stage_model.train = disabled_train
799
+ for param in self.first_stage_model.parameters():
800
+ param.requires_grad = False
801
+
802
+ def instantiate_cond_stage(self, config):
803
+ if not self.cond_stage_trainable:
804
+ if config == "__is_first_stage__":
805
+ print("Using first stage also as cond stage.")
806
+ self.cond_stage_model = self.first_stage_model
807
+ elif config == "__is_unconditional__":
808
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
809
+ self.cond_stage_model = None
810
+ # self.be_unconditional = True
811
+ else:
812
+ model = instantiate_from_config(config)
813
+ self.cond_stage_model = model.eval()
814
+ self.cond_stage_model.train = disabled_train
815
+ for param in self.cond_stage_model.parameters():
816
+ param.requires_grad = False
817
+ else:
818
+ assert config != "__is_first_stage__"
819
+ assert config != "__is_unconditional__"
820
+ model = instantiate_from_config(config)
821
+ self.cond_stage_model = model
822
+
823
+ def _get_denoise_row_from_list(
824
+ self, samples, desc="", force_no_decoder_quantization=False
825
+ ):
826
+ denoise_row = []
827
+ for zd in tqdm(samples, desc=desc):
828
+ denoise_row.append(
829
+ self.decode_first_stage(
830
+ zd.to(self.device), force_not_quantize=force_no_decoder_quantization
831
+ )
832
+ )
833
+ n_imgs_per_row = len(denoise_row)
834
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
835
+ denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w")
836
+ denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
837
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
838
+ return denoise_grid
839
+
840
+ def get_first_stage_encoding(self, encoder_posterior):
841
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
842
+ z = encoder_posterior.sample()
843
+ elif isinstance(encoder_posterior, torch.Tensor):
844
+ z = encoder_posterior
845
+ else:
846
+ raise NotImplementedError(
847
+ f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
848
+ )
849
+ return self.scale_factor * z
850
+
851
+ def get_learned_conditioning(self, c):
852
+ if self.cond_stage_forward is None:
853
+ if hasattr(self.cond_stage_model, "encode") and callable(
854
+ self.cond_stage_model.encode
855
+ ):
856
+ c = self.cond_stage_model.encode(c)
857
+ if isinstance(c, DiagonalGaussianDistribution):
858
+ c = c.mode()
859
+ else:
860
+ c = self.cond_stage_model(c)
861
+ else:
862
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
863
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
864
+ return c
865
+
866
+ def meshgrid(self, h, w):
867
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
868
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
869
+
870
+ arr = torch.cat([y, x], dim=-1)
871
+ return arr
872
+
873
+ def delta_border(self, h, w):
874
+ """
875
+ :param h: height
876
+ :param w: width
877
+ :return: normalized distance to image border,
878
+ wtith min distance = 0 at border and max dist = 0.5 at image center
879
+ """
880
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
881
+ arr = self.meshgrid(h, w) / lower_right_corner
882
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
883
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
884
+ edge_dist = torch.min(
885
+ torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1
886
+ )[0]
887
+ return edge_dist
888
+
889
+ def get_weighting(self, h, w, Ly, Lx, device):
890
+ weighting = self.delta_border(h, w)
891
+ weighting = torch.clip(
892
+ weighting,
893
+ self.split_input_params["clip_min_weight"],
894
+ self.split_input_params["clip_max_weight"],
895
+ )
896
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
897
+
898
+ if self.split_input_params["tie_braker"]:
899
+ L_weighting = self.delta_border(Ly, Lx)
900
+ L_weighting = torch.clip(
901
+ L_weighting,
902
+ self.split_input_params["clip_min_tie_weight"],
903
+ self.split_input_params["clip_max_tie_weight"],
904
+ )
905
+
906
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
907
+ weighting = weighting * L_weighting
908
+ return weighting
909
+
910
+ def get_fold_unfold(
911
+ self, x, kernel_size, stride, uf=1, df=1
912
+ ): # todo load once not every time, shorten code
913
+ """
914
+ :param x: img of size (bs, c, h, w)
915
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
916
+ """
917
+ bs, nc, h, w = x.shape
918
+
919
+ # number of crops in image
920
+ Ly = (h - kernel_size[0]) // stride[0] + 1
921
+ Lx = (w - kernel_size[1]) // stride[1] + 1
922
+
923
+ if uf == 1 and df == 1:
924
+ fold_params = dict(
925
+ kernel_size=kernel_size, dilation=1, padding=0, stride=stride
926
+ )
927
+ unfold = torch.nn.Unfold(**fold_params)
928
+
929
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
930
+
931
+ weighting = self.get_weighting(
932
+ kernel_size[0], kernel_size[1], Ly, Lx, x.device
933
+ ).to(x.dtype)
934
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
935
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
936
+
937
+ elif uf > 1 and df == 1:
938
+ fold_params = dict(
939
+ kernel_size=kernel_size, dilation=1, padding=0, stride=stride
940
+ )
941
+ unfold = torch.nn.Unfold(**fold_params)
942
+
943
+ fold_params2 = dict(
944
+ kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
945
+ dilation=1,
946
+ padding=0,
947
+ stride=(stride[0] * uf, stride[1] * uf),
948
+ )
949
+ fold = torch.nn.Fold(
950
+ output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2
951
+ )
952
+
953
+ weighting = self.get_weighting(
954
+ kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device
955
+ ).to(x.dtype)
956
+ normalization = fold(weighting).view(
957
+ 1, 1, h * uf, w * uf
958
+ ) # normalizes the overlap
959
+ weighting = weighting.view(
960
+ (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)
961
+ )
962
+
963
+ elif df > 1 and uf == 1:
964
+ fold_params = dict(
965
+ kernel_size=kernel_size, dilation=1, padding=0, stride=stride
966
+ )
967
+ unfold = torch.nn.Unfold(**fold_params)
968
+
969
+ fold_params2 = dict(
970
+ kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
971
+ dilation=1,
972
+ padding=0,
973
+ stride=(stride[0] // df, stride[1] // df),
974
+ )
975
+ fold = torch.nn.Fold(
976
+ output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2
977
+ )
978
+
979
+ weighting = self.get_weighting(
980
+ kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device
981
+ ).to(x.dtype)
982
+ normalization = fold(weighting).view(
983
+ 1, 1, h // df, w // df
984
+ ) # normalizes the overlap
985
+ weighting = weighting.view(
986
+ (1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)
987
+ )
988
+
989
+ else:
990
+ raise NotImplementedError
991
+
992
+ return fold, unfold, normalization, weighting
993
+
994
+ @torch.no_grad()
995
+ def get_input(
996
+ self,
997
+ batch,
998
+ k,
999
+ return_first_stage_outputs=False,
1000
+ force_c_encode=False,
1001
+ cond_key=None,
1002
+ return_original_cond=False,
1003
+ bs=None,
1004
+ return_x=False,
1005
+ mask_k=None,
1006
+ ):
1007
+ x = super().get_input(batch, k)
1008
+ if bs is not None:
1009
+ x = x[:bs]
1010
+ x = x.to(self.device)
1011
+ encoder_posterior = self.encode_first_stage(x)
1012
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
1013
+
1014
+ if mask_k is not None:
1015
+ mx = super().get_input(batch, mask_k)
1016
+ if bs is not None:
1017
+ mx = mx[:bs]
1018
+ mx = mx.to(self.device)
1019
+ encoder_posterior = self.encode_first_stage(mx)
1020
+ mx = self.get_first_stage_encoding(encoder_posterior).detach()
1021
+
1022
+ if self.model.conditioning_key is not None and not self.force_null_conditioning:
1023
+ if cond_key is None:
1024
+ cond_key = self.cond_stage_key
1025
+ if cond_key != self.first_stage_key:
1026
+ if cond_key in ["caption", "coordinates_bbox", "txt"]:
1027
+ xc = batch[cond_key]
1028
+ elif cond_key in ["class_label", "cls"]:
1029
+ xc = batch
1030
+ else:
1031
+ xc = super().get_input(batch, cond_key).to(self.device)
1032
+ else:
1033
+ xc = x
1034
+ if not self.cond_stage_trainable or force_c_encode:
1035
+ if isinstance(xc, dict) or isinstance(xc, list):
1036
+ c = self.get_learned_conditioning(xc)
1037
+ else:
1038
+ c = self.get_learned_conditioning(xc.to(self.device))
1039
+ else:
1040
+ c = xc
1041
+ if bs is not None:
1042
+ c = c[:bs]
1043
+
1044
+ if self.use_positional_encodings:
1045
+ pos_x, pos_y = self.compute_latent_shifts(batch)
1046
+ ckey = __conditioning_keys__[self.model.conditioning_key]
1047
+ c = {ckey: c, "pos_x": pos_x, "pos_y": pos_y}
1048
+
1049
+ else:
1050
+ c = None
1051
+ xc = None
1052
+ if self.use_positional_encodings:
1053
+ pos_x, pos_y = self.compute_latent_shifts(batch)
1054
+ c = {"pos_x": pos_x, "pos_y": pos_y}
1055
+ out = [z, c]
1056
+ if return_first_stage_outputs:
1057
+ xrec = self.decode_first_stage(z)
1058
+ out.extend([x, xrec])
1059
+ if return_x:
1060
+ out.extend([x])
1061
+ if return_original_cond:
1062
+ out.append(xc)
1063
+ if mask_k:
1064
+ out.append(mx)
1065
+ return out
1066
+
1067
+ @torch.no_grad()
1068
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
1069
+ if predict_cids:
1070
+ if z.dim() == 4:
1071
+ z = torch.argmax(z.exp(), dim=1).long()
1072
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
1073
+ z = rearrange(z, "b h w c -> b c h w").contiguous()
1074
+
1075
+ z = 1.0 / self.scale_factor * z
1076
+ return self.first_stage_model.decode(z)
1077
+
1078
+ def decode_first_stage_grad(self, z, predict_cids=False, force_not_quantize=False):
1079
+ if predict_cids:
1080
+ if z.dim() == 4:
1081
+ z = torch.argmax(z.exp(), dim=1).long()
1082
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
1083
+ z = rearrange(z, "b h w c -> b c h w").contiguous()
1084
+
1085
+ z = 1.0 / self.scale_factor * z
1086
+ return self.first_stage_model.decode(z)
1087
+
1088
+ @torch.no_grad()
1089
+ def encode_first_stage(self, x):
1090
+ return self.first_stage_model.encode(x)
1091
+
1092
+ def shared_step(self, batch, **kwargs):
1093
+ x, c = self.get_input(batch, self.first_stage_key)
1094
+ loss = self(x, c)
1095
+ return loss
1096
+
1097
+ def forward(self, x, c, *args, **kwargs):
1098
+ t = torch.randint(
1099
+ 0, self.num_timesteps, (x.shape[0],), device=self.device
1100
+ ).long()
1101
+ # t = torch.randint(500, 501, (x.shape[0],), device=self.device).long()
1102
+ if self.model.conditioning_key is not None:
1103
+ assert c is not None
1104
+ if self.cond_stage_trainable:
1105
+ c = self.get_learned_conditioning(c)
1106
+ if self.shorten_cond_schedule: # TODO: drop this option
1107
+ tc = self.cond_ids[t].to(self.device)
1108
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
1109
+ return self.p_losses(x, c, t, *args, **kwargs)
1110
+
1111
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
1112
+ if isinstance(cond, dict):
1113
+ # hybrid case, cond is expected to be a dict
1114
+ pass
1115
+ else:
1116
+ if not isinstance(cond, list):
1117
+ cond = [cond]
1118
+ key = (
1119
+ "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn"
1120
+ )
1121
+ cond = {key: cond}
1122
+
1123
+ x_recon = self.model(x_noisy, t, **cond)
1124
+
1125
+ if isinstance(x_recon, tuple) and not return_ids:
1126
+ return x_recon[0]
1127
+ else:
1128
+ return x_recon
1129
+
1130
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
1131
+ return (
1132
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
1133
+ - pred_xstart
1134
+ ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
1135
+
1136
+ def _prior_bpd(self, x_start):
1137
+ """
1138
+ Get the prior KL term for the variational lower-bound, measured in
1139
+ bits-per-dim.
1140
+ This term can't be optimized, as it only depends on the encoder.
1141
+ :param x_start: the [N x C x ...] tensor of inputs.
1142
+ :return: a batch of [N] KL values (in bits), one per batch element.
1143
+ """
1144
+ batch_size = x_start.shape[0]
1145
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1146
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1147
+ kl_prior = normal_kl(
1148
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
1149
+ )
1150
+ return mean_flat(kl_prior) / np.log(2.0)
1151
+
1152
+ def p_mean_variance(
1153
+ self,
1154
+ x,
1155
+ c,
1156
+ t,
1157
+ clip_denoised: bool,
1158
+ return_codebook_ids=False,
1159
+ quantize_denoised=False,
1160
+ return_x0=False,
1161
+ score_corrector=None,
1162
+ corrector_kwargs=None,
1163
+ ):
1164
+ t_in = t
1165
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
1166
+
1167
+ if score_corrector is not None:
1168
+ assert self.parameterization == "eps"
1169
+ model_out = score_corrector.modify_score(
1170
+ self, model_out, x, t, c, **corrector_kwargs
1171
+ )
1172
+
1173
+ if return_codebook_ids:
1174
+ model_out, logits = model_out
1175
+
1176
+ if self.parameterization == "eps":
1177
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
1178
+ elif self.parameterization == "x0":
1179
+ x_recon = model_out
1180
+ else:
1181
+ raise NotImplementedError()
1182
+
1183
+ if clip_denoised:
1184
+ x_recon.clamp_(-1.0, 1.0)
1185
+ if quantize_denoised:
1186
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
1187
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
1188
+ x_start=x_recon, x_t=x, t=t
1189
+ )
1190
+ if return_codebook_ids:
1191
+ return model_mean, posterior_variance, posterior_log_variance, logits
1192
+ elif return_x0:
1193
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
1194
+ else:
1195
+ return model_mean, posterior_variance, posterior_log_variance
1196
+
1197
+ @torch.no_grad()
1198
+ def p_sample(
1199
+ self,
1200
+ x,
1201
+ c,
1202
+ t,
1203
+ clip_denoised=False,
1204
+ repeat_noise=False,
1205
+ return_codebook_ids=False,
1206
+ quantize_denoised=False,
1207
+ return_x0=False,
1208
+ temperature=1.0,
1209
+ noise_dropout=0.0,
1210
+ score_corrector=None,
1211
+ corrector_kwargs=None,
1212
+ ):
1213
+ b, *_, device = *x.shape, x.device
1214
+ outputs = self.p_mean_variance(
1215
+ x=x,
1216
+ c=c,
1217
+ t=t,
1218
+ clip_denoised=clip_denoised,
1219
+ return_codebook_ids=return_codebook_ids,
1220
+ quantize_denoised=quantize_denoised,
1221
+ return_x0=return_x0,
1222
+ score_corrector=score_corrector,
1223
+ corrector_kwargs=corrector_kwargs,
1224
+ )
1225
+ if return_codebook_ids:
1226
+ raise DeprecationWarning("Support dropped.")
1227
+ model_mean, _, model_log_variance, logits = outputs
1228
+ elif return_x0:
1229
+ model_mean, _, model_log_variance, x0 = outputs
1230
+ else:
1231
+ model_mean, _, model_log_variance = outputs
1232
+
1233
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
1234
+ if noise_dropout > 0.0:
1235
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1236
+ # no noise when t == 0
1237
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
1238
+
1239
+ if return_codebook_ids:
1240
+ return model_mean + nonzero_mask * (
1241
+ 0.5 * model_log_variance
1242
+ ).exp() * noise, logits.argmax(dim=1)
1243
+ if return_x0:
1244
+ return (
1245
+ model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
1246
+ x0,
1247
+ )
1248
+ else:
1249
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
1250
+
1251
+ @torch.no_grad()
1252
+ def progressive_denoising(
1253
+ self,
1254
+ cond,
1255
+ shape,
1256
+ verbose=True,
1257
+ callback=None,
1258
+ quantize_denoised=False,
1259
+ img_callback=None,
1260
+ mask=None,
1261
+ x0=None,
1262
+ temperature=1.0,
1263
+ noise_dropout=0.0,
1264
+ score_corrector=None,
1265
+ corrector_kwargs=None,
1266
+ batch_size=None,
1267
+ x_T=None,
1268
+ start_T=None,
1269
+ log_every_t=None,
1270
+ ):
1271
+ if not log_every_t:
1272
+ log_every_t = self.log_every_t
1273
+ timesteps = self.num_timesteps
1274
+ if batch_size is not None:
1275
+ b = batch_size if batch_size is not None else shape[0]
1276
+ shape = [batch_size] + list(shape)
1277
+ else:
1278
+ b = batch_size = shape[0]
1279
+ if x_T is None:
1280
+ img = torch.randn(shape, device=self.device)
1281
+ else:
1282
+ img = x_T
1283
+ intermediates = []
1284
+ if cond is not None:
1285
+ if isinstance(cond, dict):
1286
+ cond = {
1287
+ key: cond[key][:batch_size]
1288
+ if not isinstance(cond[key], list)
1289
+ else list(map(lambda x: x[:batch_size], cond[key]))
1290
+ for key in cond
1291
+ }
1292
+ else:
1293
+ cond = (
1294
+ [c[:batch_size] for c in cond]
1295
+ if isinstance(cond, list)
1296
+ else cond[:batch_size]
1297
+ )
1298
+
1299
+ if start_T is not None:
1300
+ timesteps = min(timesteps, start_T)
1301
+ iterator = (
1302
+ tqdm(
1303
+ reversed(range(0, timesteps)),
1304
+ desc="Progressive Generation",
1305
+ total=timesteps,
1306
+ )
1307
+ if verbose
1308
+ else reversed(range(0, timesteps))
1309
+ )
1310
+ if type(temperature) == float:
1311
+ temperature = [temperature] * timesteps
1312
+
1313
+ for i in iterator:
1314
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1315
+ if self.shorten_cond_schedule:
1316
+ assert self.model.conditioning_key != "hybrid"
1317
+ tc = self.cond_ids[ts].to(cond.device)
1318
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1319
+
1320
+ img, x0_partial = self.p_sample(
1321
+ img,
1322
+ cond,
1323
+ ts,
1324
+ clip_denoised=self.clip_denoised,
1325
+ quantize_denoised=quantize_denoised,
1326
+ return_x0=True,
1327
+ temperature=temperature[i],
1328
+ noise_dropout=noise_dropout,
1329
+ score_corrector=score_corrector,
1330
+ corrector_kwargs=corrector_kwargs,
1331
+ )
1332
+ if mask is not None:
1333
+ assert x0 is not None
1334
+ img_orig = self.q_sample(x0, ts)
1335
+ img = img_orig * mask + (1.0 - mask) * img
1336
+
1337
+ if i % log_every_t == 0 or i == timesteps - 1:
1338
+ intermediates.append(x0_partial)
1339
+ if callback:
1340
+ callback(i)
1341
+ if img_callback:
1342
+ img_callback(img, i)
1343
+ return img, intermediates
1344
+
1345
+ @torch.no_grad()
1346
+ def p_sample_loop(
1347
+ self,
1348
+ cond,
1349
+ shape,
1350
+ return_intermediates=False,
1351
+ x_T=None,
1352
+ verbose=True,
1353
+ callback=None,
1354
+ timesteps=None,
1355
+ quantize_denoised=False,
1356
+ mask=None,
1357
+ x0=None,
1358
+ img_callback=None,
1359
+ start_T=None,
1360
+ log_every_t=None,
1361
+ ):
1362
+ if not log_every_t:
1363
+ log_every_t = self.log_every_t
1364
+ device = self.betas.device
1365
+ b = shape[0]
1366
+ if x_T is None:
1367
+ img = torch.randn(shape, device=device)
1368
+ else:
1369
+ img = x_T
1370
+
1371
+ intermediates = [img]
1372
+ if timesteps is None:
1373
+ timesteps = self.num_timesteps
1374
+
1375
+ if start_T is not None:
1376
+ timesteps = min(timesteps, start_T)
1377
+ iterator = (
1378
+ tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps)
1379
+ if verbose
1380
+ else reversed(range(0, timesteps))
1381
+ )
1382
+
1383
+ if mask is not None:
1384
+ assert x0 is not None
1385
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1386
+
1387
+ for i in iterator:
1388
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
1389
+ if self.shorten_cond_schedule:
1390
+ assert self.model.conditioning_key != "hybrid"
1391
+ tc = self.cond_ids[ts].to(cond.device)
1392
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1393
+
1394
+ img = self.p_sample(
1395
+ img,
1396
+ cond,
1397
+ ts,
1398
+ clip_denoised=self.clip_denoised,
1399
+ quantize_denoised=quantize_denoised,
1400
+ )
1401
+ if mask is not None:
1402
+ img_orig = self.q_sample(x0, ts)
1403
+ img = img_orig * mask + (1.0 - mask) * img
1404
+
1405
+ if i % log_every_t == 0 or i == timesteps - 1:
1406
+ intermediates.append(img)
1407
+ if callback:
1408
+ callback(i)
1409
+ if img_callback:
1410
+ img_callback(img, i)
1411
+
1412
+ if return_intermediates:
1413
+ return img, intermediates
1414
+ return img
1415
+
1416
+ @torch.no_grad()
1417
+ def sample(
1418
+ self,
1419
+ cond,
1420
+ batch_size=16,
1421
+ return_intermediates=False,
1422
+ x_T=None,
1423
+ verbose=True,
1424
+ timesteps=None,
1425
+ quantize_denoised=False,
1426
+ mask=None,
1427
+ x0=None,
1428
+ shape=None,
1429
+ **kwargs,
1430
+ ):
1431
+ if shape is None:
1432
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
1433
+ if cond is not None:
1434
+ if isinstance(cond, dict):
1435
+ cond = {
1436
+ key: cond[key][:batch_size]
1437
+ if not isinstance(cond[key], list)
1438
+ else list(map(lambda x: x[:batch_size], cond[key]))
1439
+ for key in cond
1440
+ }
1441
+ else:
1442
+ cond = (
1443
+ [c[:batch_size] for c in cond]
1444
+ if isinstance(cond, list)
1445
+ else cond[:batch_size]
1446
+ )
1447
+ return self.p_sample_loop(
1448
+ cond,
1449
+ shape,
1450
+ return_intermediates=return_intermediates,
1451
+ x_T=x_T,
1452
+ verbose=verbose,
1453
+ timesteps=timesteps,
1454
+ quantize_denoised=quantize_denoised,
1455
+ mask=mask,
1456
+ x0=x0,
1457
+ )
1458
+
1459
+ @torch.no_grad()
1460
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
1461
+ if ddim:
1462
+ ddim_sampler = DDIMSampler(self)
1463
+ shape = (self.channels, self.image_size, self.image_size)
1464
+ samples, intermediates = ddim_sampler.sample(
1465
+ ddim_steps, batch_size, shape, cond, verbose=False, **kwargs
1466
+ )
1467
+
1468
+ else:
1469
+ samples, intermediates = self.sample(
1470
+ cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs
1471
+ )
1472
+
1473
+ return samples, intermediates
1474
+
1475
+ @torch.no_grad()
1476
+ def get_unconditional_conditioning(self, batch_size, null_label=None):
1477
+ if null_label is not None:
1478
+ xc = null_label
1479
+ if isinstance(xc, ListConfig):
1480
+ xc = list(xc)
1481
+ if isinstance(xc, dict) or isinstance(xc, list):
1482
+ c = self.get_learned_conditioning(xc)
1483
+ else:
1484
+ if hasattr(xc, "to"):
1485
+ xc = xc.to(self.device)
1486
+ c = self.get_learned_conditioning(xc)
1487
+ else:
1488
+ if self.cond_stage_key in ["class_label", "cls"]:
1489
+ xc = self.cond_stage_model.get_unconditional_conditioning(
1490
+ batch_size, device=self.device
1491
+ )
1492
+ return self.get_learned_conditioning(xc)
1493
+ else:
1494
+ raise NotImplementedError("todo")
1495
+ if isinstance(c, list): # in case the encoder gives us a list
1496
+ for i in range(len(c)):
1497
+ c[i] = repeat(c[i], "1 ... -> b ...", b=batch_size).to(self.device)
1498
+ else:
1499
+ c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
1500
+ return c
1501
+
1502
+ @torch.no_grad()
1503
+ def log_images(
1504
+ self,
1505
+ batch,
1506
+ N=8,
1507
+ n_row=4,
1508
+ sample=True,
1509
+ ddim_steps=50,
1510
+ ddim_eta=0.0,
1511
+ return_keys=None,
1512
+ quantize_denoised=True,
1513
+ inpaint=True,
1514
+ plot_denoise_rows=False,
1515
+ plot_progressive_rows=True,
1516
+ plot_diffusion_rows=True,
1517
+ unconditional_guidance_scale=1.0,
1518
+ unconditional_guidance_label=None,
1519
+ use_ema_scope=True,
1520
+ **kwargs,
1521
+ ):
1522
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
1523
+ use_ddim = ddim_steps is not None
1524
+
1525
+ log = dict()
1526
+ z, c, x, xrec, xc = self.get_input(
1527
+ batch,
1528
+ self.first_stage_key,
1529
+ return_first_stage_outputs=True,
1530
+ force_c_encode=True,
1531
+ return_original_cond=True,
1532
+ bs=N,
1533
+ )
1534
+ N = min(x.shape[0], N)
1535
+ n_row = min(x.shape[0], n_row)
1536
+ log["inputs"] = x
1537
+ log["reconstruction"] = xrec
1538
+ if self.model.conditioning_key is not None:
1539
+ if hasattr(self.cond_stage_model, "decode"):
1540
+ xc = self.cond_stage_model.decode(c)
1541
+ log["conditioning"] = xc
1542
+ elif self.cond_stage_key in ["caption", "txt"]:
1543
+ xc = log_txt_as_img(
1544
+ (x.shape[2], x.shape[3]),
1545
+ batch[self.cond_stage_key],
1546
+ size=x.shape[2] // 25,
1547
+ )
1548
+ log["conditioning"] = xc
1549
+ elif self.cond_stage_key in ["class_label", "cls"]:
1550
+ try:
1551
+ xc = log_txt_as_img(
1552
+ (x.shape[2], x.shape[3]),
1553
+ batch["human_label"],
1554
+ size=x.shape[2] // 25,
1555
+ )
1556
+ log["conditioning"] = xc
1557
+ except KeyError:
1558
+ # probably no "human_label" in batch
1559
+ pass
1560
+ elif isimage(xc):
1561
+ log["conditioning"] = xc
1562
+ if ismap(xc):
1563
+ log["original_conditioning"] = self.to_rgb(xc)
1564
+
1565
+ if plot_diffusion_rows:
1566
+ # get diffusion row
1567
+ diffusion_row = list()
1568
+ z_start = z[:n_row]
1569
+ for t in range(self.num_timesteps):
1570
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1571
+ t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
1572
+ t = t.to(self.device).long()
1573
+ noise = torch.randn_like(z_start)
1574
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1575
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1576
+
1577
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1578
+ diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
1579
+ diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
1580
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1581
+ log["diffusion_row"] = diffusion_grid
1582
+
1583
+ if sample:
1584
+ # get denoise row
1585
+ with ema_scope("Sampling"):
1586
+ samples, z_denoise_row = self.sample_log(
1587
+ cond=c,
1588
+ batch_size=N,
1589
+ ddim=use_ddim,
1590
+ ddim_steps=ddim_steps,
1591
+ eta=ddim_eta,
1592
+ )
1593
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1594
+ x_samples = self.decode_first_stage(samples)
1595
+ log["samples"] = x_samples
1596
+ if plot_denoise_rows:
1597
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1598
+ log["denoise_row"] = denoise_grid
1599
+
1600
+ if (
1601
+ quantize_denoised
1602
+ and not isinstance(self.first_stage_model, AutoencoderKL)
1603
+ and not isinstance(self.first_stage_model, IdentityFirstStage)
1604
+ ):
1605
+ # also display when quantizing x0 while sampling
1606
+ with ema_scope("Plotting Quantized Denoised"):
1607
+ samples, z_denoise_row = self.sample_log(
1608
+ cond=c,
1609
+ batch_size=N,
1610
+ ddim=use_ddim,
1611
+ ddim_steps=ddim_steps,
1612
+ eta=ddim_eta,
1613
+ quantize_denoised=True,
1614
+ )
1615
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1616
+ # quantize_denoised=True)
1617
+ x_samples = self.decode_first_stage(samples.to(self.device))
1618
+ log["samples_x0_quantized"] = x_samples
1619
+
1620
+ if unconditional_guidance_scale > 1.0:
1621
+ uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
1622
+ if self.model.conditioning_key == "crossattn-adm":
1623
+ uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
1624
+ with ema_scope("Sampling with classifier-free guidance"):
1625
+ samples_cfg, _ = self.sample_log(
1626
+ cond=c,
1627
+ batch_size=N,
1628
+ ddim=use_ddim,
1629
+ ddim_steps=ddim_steps,
1630
+ eta=ddim_eta,
1631
+ unconditional_guidance_scale=unconditional_guidance_scale,
1632
+ unconditional_conditioning=uc,
1633
+ )
1634
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
1635
+ log[
1636
+ f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
1637
+ ] = x_samples_cfg
1638
+
1639
+ if inpaint:
1640
+ # make a simple center square
1641
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
1642
+ mask = torch.ones(N, h, w).to(self.device)
1643
+ # zeros will be filled in
1644
+ mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0
1645
+ mask = mask[:, None, ...]
1646
+ with ema_scope("Plotting Inpaint"):
1647
+ samples, _ = self.sample_log(
1648
+ cond=c,
1649
+ batch_size=N,
1650
+ ddim=use_ddim,
1651
+ eta=ddim_eta,
1652
+ ddim_steps=ddim_steps,
1653
+ x0=z[:N],
1654
+ mask=mask,
1655
+ )
1656
+ x_samples = self.decode_first_stage(samples.to(self.device))
1657
+ log["samples_inpainting"] = x_samples
1658
+ log["mask"] = mask
1659
+
1660
+ # outpaint
1661
+ mask = 1.0 - mask
1662
+ with ema_scope("Plotting Outpaint"):
1663
+ samples, _ = self.sample_log(
1664
+ cond=c,
1665
+ batch_size=N,
1666
+ ddim=use_ddim,
1667
+ eta=ddim_eta,
1668
+ ddim_steps=ddim_steps,
1669
+ x0=z[:N],
1670
+ mask=mask,
1671
+ )
1672
+ x_samples = self.decode_first_stage(samples.to(self.device))
1673
+ log["samples_outpainting"] = x_samples
1674
+
1675
+ if plot_progressive_rows:
1676
+ with ema_scope("Plotting Progressives"):
1677
+ img, progressives = self.progressive_denoising(
1678
+ c,
1679
+ shape=(self.channels, self.image_size, self.image_size),
1680
+ batch_size=N,
1681
+ )
1682
+ prog_row = self._get_denoise_row_from_list(
1683
+ progressives, desc="Progressive Generation"
1684
+ )
1685
+ log["progressive_row"] = prog_row
1686
+
1687
+ if return_keys:
1688
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1689
+ return log
1690
+ else:
1691
+ return {key: log[key] for key in return_keys}
1692
+ return log
1693
+
1694
+ def configure_optimizers(self):
1695
+ lr = self.learning_rate
1696
+ params = list(self.model.parameters())
1697
+ if self.cond_stage_trainable:
1698
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1699
+ params = params + list(self.cond_stage_model.parameters())
1700
+ if self.learn_logvar:
1701
+ print("Diffusion model optimizing logvar")
1702
+ params.append(self.logvar)
1703
+ opt = torch.optim.AdamW(params, lr=lr)
1704
+ if self.use_scheduler:
1705
+ assert "target" in self.scheduler_config
1706
+ scheduler = instantiate_from_config(self.scheduler_config)
1707
+
1708
+ print("Setting up LambdaLR scheduler...")
1709
+ scheduler = [
1710
+ {
1711
+ "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
1712
+ "interval": "step",
1713
+ "frequency": 1,
1714
+ }
1715
+ ]
1716
+ return [opt], scheduler
1717
+ return opt
1718
+
1719
+ @torch.no_grad()
1720
+ def to_rgb(self, x):
1721
+ x = x.float()
1722
+ if not hasattr(self, "colorize"):
1723
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1724
+ x = nn.functional.conv2d(x, weight=self.colorize)
1725
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
1726
+ return x
1727
+
1728
+
1729
+ class DiffusionWrapper(torch.nn.Module):
1730
+ def __init__(self, diff_model_config, conditioning_key):
1731
+ super().__init__()
1732
+ self.sequential_cross_attn = diff_model_config.pop(
1733
+ "sequential_crossattn", False
1734
+ )
1735
+ self.diffusion_model = instantiate_from_config(diff_model_config)
1736
+ self.conditioning_key = conditioning_key
1737
+ assert self.conditioning_key in [
1738
+ None,
1739
+ "concat",
1740
+ "crossattn",
1741
+ "hybrid",
1742
+ "adm",
1743
+ "hybrid-adm",
1744
+ "crossattn-adm",
1745
+ ]
1746
+
1747
+ def forward(
1748
+ self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None
1749
+ ):
1750
+ if self.conditioning_key is None:
1751
+ out = self.diffusion_model(x, t)
1752
+ elif self.conditioning_key == "concat":
1753
+ xc = torch.cat([x] + c_concat, dim=1)
1754
+ out = self.diffusion_model(xc, t)
1755
+ elif self.conditioning_key == "crossattn":
1756
+ if not self.sequential_cross_attn:
1757
+ cc = torch.cat(c_crossattn, 1)
1758
+ else:
1759
+ cc = c_crossattn
1760
+ out = self.diffusion_model(x, t, context=cc)
1761
+ elif self.conditioning_key == "hybrid":
1762
+ xc = torch.cat([x] + c_concat, dim=1)
1763
+ cc = torch.cat(c_crossattn, 1)
1764
+ out = self.diffusion_model(xc, t, context=cc)
1765
+ elif self.conditioning_key == "hybrid-adm":
1766
+ assert c_adm is not None
1767
+ xc = torch.cat([x] + c_concat, dim=1)
1768
+ cc = torch.cat(c_crossattn, 1)
1769
+ out = self.diffusion_model(xc, t, context=cc, y=c_adm)
1770
+ elif self.conditioning_key == "crossattn-adm":
1771
+ assert c_adm is not None
1772
+ cc = torch.cat(c_crossattn, 1)
1773
+ out = self.diffusion_model(x, t, context=cc, y=c_adm)
1774
+ elif self.conditioning_key == "adm":
1775
+ cc = c_crossattn[0]
1776
+ out = self.diffusion_model(x, t, y=cc)
1777
+ else:
1778
+ raise NotImplementedError()
1779
+
1780
+ return out
1781
+
1782
+
1783
+ class LatentUpscaleDiffusion(LatentDiffusion):
1784
+ def __init__(
1785
+ self,
1786
+ *args,
1787
+ low_scale_config,
1788
+ low_scale_key="LR",
1789
+ noise_level_key=None,
1790
+ **kwargs,
1791
+ ):
1792
+ super().__init__(*args, **kwargs)
1793
+ # assumes that neither the cond_stage nor the low_scale_model contain trainable params
1794
+ assert not self.cond_stage_trainable
1795
+ self.instantiate_low_stage(low_scale_config)
1796
+ self.low_scale_key = low_scale_key
1797
+ self.noise_level_key = noise_level_key
1798
+
1799
+ def instantiate_low_stage(self, config):
1800
+ model = instantiate_from_config(config)
1801
+ self.low_scale_model = model.eval()
1802
+ self.low_scale_model.train = disabled_train
1803
+ for param in self.low_scale_model.parameters():
1804
+ param.requires_grad = False
1805
+
1806
+ @torch.no_grad()
1807
+ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
1808
+ if not log_mode:
1809
+ z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
1810
+ else:
1811
+ z, c, x, xrec, xc = super().get_input(
1812
+ batch,
1813
+ self.first_stage_key,
1814
+ return_first_stage_outputs=True,
1815
+ force_c_encode=True,
1816
+ return_original_cond=True,
1817
+ bs=bs,
1818
+ )
1819
+ x_low = batch[self.low_scale_key][:bs]
1820
+ x_low = rearrange(x_low, "b h w c -> b c h w")
1821
+ x_low = x_low.to(memory_format=torch.contiguous_format).float()
1822
+ zx, noise_level = self.low_scale_model(x_low)
1823
+ if self.noise_level_key is not None:
1824
+ # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
1825
+ raise NotImplementedError("TODO")
1826
+
1827
+ all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
1828
+ if log_mode:
1829
+ # TODO: maybe disable if too expensive
1830
+ x_low_rec = self.low_scale_model.decode(zx)
1831
+ return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
1832
+ return z, all_conds
1833
+
1834
+ @torch.no_grad()
1835
+ def log_images(
1836
+ self,
1837
+ batch,
1838
+ N=8,
1839
+ n_row=4,
1840
+ sample=True,
1841
+ ddim_steps=200,
1842
+ ddim_eta=1.0,
1843
+ return_keys=None,
1844
+ plot_denoise_rows=False,
1845
+ plot_progressive_rows=True,
1846
+ plot_diffusion_rows=True,
1847
+ unconditional_guidance_scale=1.0,
1848
+ unconditional_guidance_label=None,
1849
+ use_ema_scope=True,
1850
+ **kwargs,
1851
+ ):
1852
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
1853
+ use_ddim = ddim_steps is not None
1854
+
1855
+ log = dict()
1856
+ z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(
1857
+ batch, self.first_stage_key, bs=N, log_mode=True
1858
+ )
1859
+ N = min(x.shape[0], N)
1860
+ n_row = min(x.shape[0], n_row)
1861
+ log["inputs"] = x
1862
+ log["reconstruction"] = xrec
1863
+ log["x_lr"] = x_low
1864
+ log[
1865
+ f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"
1866
+ ] = x_low_rec
1867
+ if self.model.conditioning_key is not None:
1868
+ if hasattr(self.cond_stage_model, "decode"):
1869
+ xc = self.cond_stage_model.decode(c)
1870
+ log["conditioning"] = xc
1871
+ elif self.cond_stage_key in ["caption", "txt"]:
1872
+ xc = log_txt_as_img(
1873
+ (x.shape[2], x.shape[3]),
1874
+ batch[self.cond_stage_key],
1875
+ size=x.shape[2] // 25,
1876
+ )
1877
+ log["conditioning"] = xc
1878
+ elif self.cond_stage_key in ["class_label", "cls"]:
1879
+ xc = log_txt_as_img(
1880
+ (x.shape[2], x.shape[3]),
1881
+ batch["human_label"],
1882
+ size=x.shape[2] // 25,
1883
+ )
1884
+ log["conditioning"] = xc
1885
+ elif isimage(xc):
1886
+ log["conditioning"] = xc
1887
+ if ismap(xc):
1888
+ log["original_conditioning"] = self.to_rgb(xc)
1889
+
1890
+ if plot_diffusion_rows:
1891
+ # get diffusion row
1892
+ diffusion_row = list()
1893
+ z_start = z[:n_row]
1894
+ for t in range(self.num_timesteps):
1895
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1896
+ t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
1897
+ t = t.to(self.device).long()
1898
+ noise = torch.randn_like(z_start)
1899
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1900
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1901
+
1902
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1903
+ diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
1904
+ diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
1905
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1906
+ log["diffusion_row"] = diffusion_grid
1907
+
1908
+ if sample:
1909
+ # get denoise row
1910
+ with ema_scope("Sampling"):
1911
+ samples, z_denoise_row = self.sample_log(
1912
+ cond=c,
1913
+ batch_size=N,
1914
+ ddim=use_ddim,
1915
+ ddim_steps=ddim_steps,
1916
+ eta=ddim_eta,
1917
+ )
1918
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1919
+ x_samples = self.decode_first_stage(samples)
1920
+ log["samples"] = x_samples
1921
+ if plot_denoise_rows:
1922
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1923
+ log["denoise_row"] = denoise_grid
1924
+
1925
+ if unconditional_guidance_scale > 1.0:
1926
+ uc_tmp = self.get_unconditional_conditioning(
1927
+ N, unconditional_guidance_label
1928
+ )
1929
+ # TODO explore better "unconditional" choices for the other keys
1930
+ # maybe guide away from empty text label and highest noise level and maximally degraded zx?
1931
+ uc = dict()
1932
+ for k in c:
1933
+ if k == "c_crossattn":
1934
+ assert isinstance(c[k], list) and len(c[k]) == 1
1935
+ uc[k] = [uc_tmp]
1936
+ elif k == "c_adm": # todo: only run with text-based guidance?
1937
+ assert isinstance(c[k], torch.Tensor)
1938
+ # uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
1939
+ uc[k] = c[k]
1940
+ elif isinstance(c[k], list):
1941
+ uc[k] = [c[k][i] for i in range(len(c[k]))]
1942
+ else:
1943
+ uc[k] = c[k]
1944
+
1945
+ with ema_scope("Sampling with classifier-free guidance"):
1946
+ samples_cfg, _ = self.sample_log(
1947
+ cond=c,
1948
+ batch_size=N,
1949
+ ddim=use_ddim,
1950
+ ddim_steps=ddim_steps,
1951
+ eta=ddim_eta,
1952
+ unconditional_guidance_scale=unconditional_guidance_scale,
1953
+ unconditional_conditioning=uc,
1954
+ )
1955
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
1956
+ log[
1957
+ f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
1958
+ ] = x_samples_cfg
1959
+
1960
+ if plot_progressive_rows:
1961
+ with ema_scope("Plotting Progressives"):
1962
+ img, progressives = self.progressive_denoising(
1963
+ c,
1964
+ shape=(self.channels, self.image_size, self.image_size),
1965
+ batch_size=N,
1966
+ )
1967
+ prog_row = self._get_denoise_row_from_list(
1968
+ progressives, desc="Progressive Generation"
1969
+ )
1970
+ log["progressive_row"] = prog_row
1971
+
1972
+ return log
1973
+
1974
+
1975
+ class LatentFinetuneDiffusion(LatentDiffusion):
1976
+ """
1977
+ Basis for different finetunas, such as inpainting or depth2image
1978
+ To disable finetuning mode, set finetune_keys to None
1979
+ """
1980
+
1981
+ def __init__(
1982
+ self,
1983
+ concat_keys: tuple,
1984
+ finetune_keys=(
1985
+ "model.diffusion_model.input_blocks.0.0.weight",
1986
+ "model_ema.diffusion_modelinput_blocks00weight",
1987
+ ),
1988
+ keep_finetune_dims=4,
1989
+ # if model was trained without concat mode before and we would like to keep these channels
1990
+ c_concat_log_start=None, # to log reconstruction of c_concat codes
1991
+ c_concat_log_end=None,
1992
+ *args,
1993
+ **kwargs,
1994
+ ):
1995
+ ckpt_path = kwargs.pop("ckpt_path", None)
1996
+ ignore_keys = kwargs.pop("ignore_keys", list())
1997
+ super().__init__(*args, **kwargs)
1998
+ self.finetune_keys = finetune_keys
1999
+ self.concat_keys = concat_keys
2000
+ self.keep_dims = keep_finetune_dims
2001
+ self.c_concat_log_start = c_concat_log_start
2002
+ self.c_concat_log_end = c_concat_log_end
2003
+ if exists(self.finetune_keys):
2004
+ assert exists(ckpt_path), "can only finetune from a given checkpoint"
2005
+ if exists(ckpt_path):
2006
+ self.init_from_ckpt(ckpt_path, ignore_keys)
2007
+
2008
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
2009
+ sd = torch.load(path, map_location="cpu")
2010
+ if "state_dict" in list(sd.keys()):
2011
+ sd = sd["state_dict"]
2012
+ keys = list(sd.keys())
2013
+ for k in keys:
2014
+ for ik in ignore_keys:
2015
+ if k.startswith(ik):
2016
+ print("Deleting key {} from state_dict.".format(k))
2017
+ del sd[k]
2018
+
2019
+ # make it explicit, finetune by including extra input channels
2020
+ if exists(self.finetune_keys) and k in self.finetune_keys:
2021
+ new_entry = None
2022
+ for name, param in self.named_parameters():
2023
+ if name in self.finetune_keys:
2024
+ print(
2025
+ f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only"
2026
+ )
2027
+ new_entry = torch.zeros_like(param) # zero init
2028
+ assert exists(new_entry), "did not find matching parameter to modify"
2029
+ new_entry[:, : self.keep_dims, ...] = sd[k]
2030
+ sd[k] = new_entry
2031
+
2032
+ missing, unexpected = (
2033
+ self.load_state_dict(sd, strict=False)
2034
+ if not only_model
2035
+ else self.model.load_state_dict(sd, strict=False)
2036
+ )
2037
+ print(
2038
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
2039
+ )
2040
+ if len(missing) > 0:
2041
+ print(f"Missing Keys: {missing}")
2042
+ if len(unexpected) > 0:
2043
+ print(f"Unexpected Keys: {unexpected}")
2044
+
2045
+ @torch.no_grad()
2046
+ def log_images(
2047
+ self,
2048
+ batch,
2049
+ N=8,
2050
+ n_row=4,
2051
+ sample=True,
2052
+ ddim_steps=200,
2053
+ ddim_eta=1.0,
2054
+ return_keys=None,
2055
+ quantize_denoised=True,
2056
+ inpaint=True,
2057
+ plot_denoise_rows=False,
2058
+ plot_progressive_rows=True,
2059
+ plot_diffusion_rows=True,
2060
+ unconditional_guidance_scale=1.0,
2061
+ unconditional_guidance_label=None,
2062
+ use_ema_scope=True,
2063
+ **kwargs,
2064
+ ):
2065
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
2066
+ use_ddim = ddim_steps is not None
2067
+
2068
+ log = dict()
2069
+ z, c, x, xrec, xc = self.get_input(
2070
+ batch, self.first_stage_key, bs=N, return_first_stage_outputs=True
2071
+ )
2072
+ c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
2073
+ N = min(x.shape[0], N)
2074
+ n_row = min(x.shape[0], n_row)
2075
+ log["inputs"] = x
2076
+ log["reconstruction"] = xrec
2077
+ if self.model.conditioning_key is not None:
2078
+ if hasattr(self.cond_stage_model, "decode"):
2079
+ xc = self.cond_stage_model.decode(c)
2080
+ log["conditioning"] = xc
2081
+ elif self.cond_stage_key in ["caption", "txt"]:
2082
+ xc = log_txt_as_img(
2083
+ (x.shape[2], x.shape[3]),
2084
+ batch[self.cond_stage_key],
2085
+ size=x.shape[2] // 25,
2086
+ )
2087
+ log["conditioning"] = xc
2088
+ elif self.cond_stage_key in ["class_label", "cls"]:
2089
+ xc = log_txt_as_img(
2090
+ (x.shape[2], x.shape[3]),
2091
+ batch["human_label"],
2092
+ size=x.shape[2] // 25,
2093
+ )
2094
+ log["conditioning"] = xc
2095
+ elif isimage(xc):
2096
+ log["conditioning"] = xc
2097
+ if ismap(xc):
2098
+ log["original_conditioning"] = self.to_rgb(xc)
2099
+
2100
+ if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
2101
+ log["c_concat_decoded"] = self.decode_first_stage(
2102
+ c_cat[:, self.c_concat_log_start : self.c_concat_log_end]
2103
+ )
2104
+
2105
+ if plot_diffusion_rows:
2106
+ # get diffusion row
2107
+ diffusion_row = list()
2108
+ z_start = z[:n_row]
2109
+ for t in range(self.num_timesteps):
2110
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
2111
+ t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
2112
+ t = t.to(self.device).long()
2113
+ noise = torch.randn_like(z_start)
2114
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
2115
+ diffusion_row.append(self.decode_first_stage(z_noisy))
2116
+
2117
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
2118
+ diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
2119
+ diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
2120
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
2121
+ log["diffusion_row"] = diffusion_grid
2122
+
2123
+ if sample:
2124
+ # get denoise row
2125
+ with ema_scope("Sampling"):
2126
+ samples, z_denoise_row = self.sample_log(
2127
+ cond={"c_concat": [c_cat], "c_crossattn": [c]},
2128
+ batch_size=N,
2129
+ ddim=use_ddim,
2130
+ ddim_steps=ddim_steps,
2131
+ eta=ddim_eta,
2132
+ )
2133
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
2134
+ x_samples = self.decode_first_stage(samples)
2135
+ log["samples"] = x_samples
2136
+ if plot_denoise_rows:
2137
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
2138
+ log["denoise_row"] = denoise_grid
2139
+
2140
+ if unconditional_guidance_scale > 1.0:
2141
+ uc_cross = self.get_unconditional_conditioning(
2142
+ N, unconditional_guidance_label
2143
+ )
2144
+ uc_cat = c_cat
2145
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
2146
+ with ema_scope("Sampling with classifier-free guidance"):
2147
+ samples_cfg, _ = self.sample_log(
2148
+ cond={"c_concat": [c_cat], "c_crossattn": [c]},
2149
+ batch_size=N,
2150
+ ddim=use_ddim,
2151
+ ddim_steps=ddim_steps,
2152
+ eta=ddim_eta,
2153
+ unconditional_guidance_scale=unconditional_guidance_scale,
2154
+ unconditional_conditioning=uc_full,
2155
+ )
2156
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
2157
+ log[
2158
+ f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
2159
+ ] = x_samples_cfg
2160
+
2161
+ return log
2162
+
2163
+
2164
+ class LatentInpaintDiffusion(LatentFinetuneDiffusion):
2165
+ """
2166
+ can either run as pure inpainting model (only concat mode) or with mixed conditionings,
2167
+ e.g. mask as concat and text via cross-attn.
2168
+ To disable finetuning mode, set finetune_keys to None
2169
+ """
2170
+
2171
+ def __init__(
2172
+ self,
2173
+ concat_keys=("mask", "masked_image"),
2174
+ masked_image_key="masked_image",
2175
+ *args,
2176
+ **kwargs,
2177
+ ):
2178
+ super().__init__(concat_keys, *args, **kwargs)
2179
+ self.masked_image_key = masked_image_key
2180
+ assert self.masked_image_key in concat_keys
2181
+
2182
+ @torch.no_grad()
2183
+ def get_input(
2184
+ self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
2185
+ ):
2186
+ # note: restricted to non-trainable encoders currently
2187
+ assert (
2188
+ not self.cond_stage_trainable
2189
+ ), "trainable cond stages not yet supported for inpainting"
2190
+ z, c, x, xrec, xc = super().get_input(
2191
+ batch,
2192
+ self.first_stage_key,
2193
+ return_first_stage_outputs=True,
2194
+ force_c_encode=True,
2195
+ return_original_cond=True,
2196
+ bs=bs,
2197
+ )
2198
+
2199
+ assert exists(self.concat_keys)
2200
+ c_cat = list()
2201
+ for ck in self.concat_keys:
2202
+ cc = (
2203
+ rearrange(batch[ck], "b h w c -> b c h w")
2204
+ .to(memory_format=torch.contiguous_format)
2205
+ .float()
2206
+ )
2207
+ if bs is not None:
2208
+ cc = cc[:bs]
2209
+ cc = cc.to(self.device)
2210
+ bchw = z.shape
2211
+ if ck != self.masked_image_key:
2212
+ cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
2213
+ else:
2214
+ cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
2215
+ c_cat.append(cc)
2216
+ c_cat = torch.cat(c_cat, dim=1)
2217
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
2218
+ if return_first_stage_outputs:
2219
+ return z, all_conds, x, xrec, xc
2220
+ return z, all_conds
2221
+
2222
+ @torch.no_grad()
2223
+ def log_images(self, *args, **kwargs):
2224
+ log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
2225
+ log["masked_image"] = (
2226
+ rearrange(args[0]["masked_image"], "b h w c -> b c h w")
2227
+ .to(memory_format=torch.contiguous_format)
2228
+ .float()
2229
+ )
2230
+ return log
2231
+
2232
+
2233
+ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
2234
+ """
2235
+ condition on monocular depth estimation
2236
+ """
2237
+
2238
+ def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
2239
+ super().__init__(concat_keys=concat_keys, *args, **kwargs)
2240
+ self.depth_model = instantiate_from_config(depth_stage_config)
2241
+ self.depth_stage_key = concat_keys[0]
2242
+
2243
+ @torch.no_grad()
2244
+ def get_input(
2245
+ self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
2246
+ ):
2247
+ # note: restricted to non-trainable encoders currently
2248
+ assert (
2249
+ not self.cond_stage_trainable
2250
+ ), "trainable cond stages not yet supported for depth2img"
2251
+ z, c, x, xrec, xc = super().get_input(
2252
+ batch,
2253
+ self.first_stage_key,
2254
+ return_first_stage_outputs=True,
2255
+ force_c_encode=True,
2256
+ return_original_cond=True,
2257
+ bs=bs,
2258
+ )
2259
+
2260
+ assert exists(self.concat_keys)
2261
+ assert len(self.concat_keys) == 1
2262
+ c_cat = list()
2263
+ for ck in self.concat_keys:
2264
+ cc = batch[ck]
2265
+ if bs is not None:
2266
+ cc = cc[:bs]
2267
+ cc = cc.to(self.device)
2268
+ cc = self.depth_model(cc)
2269
+ cc = torch.nn.functional.interpolate(
2270
+ cc,
2271
+ size=z.shape[2:],
2272
+ mode="bicubic",
2273
+ align_corners=False,
2274
+ )
2275
+
2276
+ depth_min, depth_max = torch.amin(
2277
+ cc, dim=[1, 2, 3], keepdim=True
2278
+ ), torch.amax(cc, dim=[1, 2, 3], keepdim=True)
2279
+ cc = 2.0 * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.0
2280
+ c_cat.append(cc)
2281
+ c_cat = torch.cat(c_cat, dim=1)
2282
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
2283
+ if return_first_stage_outputs:
2284
+ return z, all_conds, x, xrec, xc
2285
+ return z, all_conds
2286
+
2287
+ @torch.no_grad()
2288
+ def log_images(self, *args, **kwargs):
2289
+ log = super().log_images(*args, **kwargs)
2290
+ depth = self.depth_model(args[0][self.depth_stage_key])
2291
+ depth_min, depth_max = torch.amin(
2292
+ depth, dim=[1, 2, 3], keepdim=True
2293
+ ), torch.amax(depth, dim=[1, 2, 3], keepdim=True)
2294
+ log["depth"] = 2.0 * (depth - depth_min) / (depth_max - depth_min) - 1.0
2295
+ return log
2296
+
2297
+
2298
+ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
2299
+ """
2300
+ condition on low-res image (and optionally on some spatial noise augmentation)
2301
+ """
2302
+
2303
+ def __init__(
2304
+ self,
2305
+ concat_keys=("lr",),
2306
+ reshuffle_patch_size=None,
2307
+ low_scale_config=None,
2308
+ low_scale_key=None,
2309
+ *args,
2310
+ **kwargs,
2311
+ ):
2312
+ super().__init__(concat_keys=concat_keys, *args, **kwargs)
2313
+ self.reshuffle_patch_size = reshuffle_patch_size
2314
+ self.low_scale_model = None
2315
+ if low_scale_config is not None:
2316
+ print("Initializing a low-scale model")
2317
+ assert exists(low_scale_key)
2318
+ self.instantiate_low_stage(low_scale_config)
2319
+ self.low_scale_key = low_scale_key
2320
+
2321
+ def instantiate_low_stage(self, config):
2322
+ model = instantiate_from_config(config)
2323
+ self.low_scale_model = model.eval()
2324
+ self.low_scale_model.train = disabled_train
2325
+ for param in self.low_scale_model.parameters():
2326
+ param.requires_grad = False
2327
+
2328
+ @torch.no_grad()
2329
+ def get_input(
2330
+ self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
2331
+ ):
2332
+ # note: restricted to non-trainable encoders currently
2333
+ assert (
2334
+ not self.cond_stage_trainable
2335
+ ), "trainable cond stages not yet supported for upscaling-ft"
2336
+ z, c, x, xrec, xc = super().get_input(
2337
+ batch,
2338
+ self.first_stage_key,
2339
+ return_first_stage_outputs=True,
2340
+ force_c_encode=True,
2341
+ return_original_cond=True,
2342
+ bs=bs,
2343
+ )
2344
+
2345
+ assert exists(self.concat_keys)
2346
+ assert len(self.concat_keys) == 1
2347
+ # optionally make spatial noise_level here
2348
+ c_cat = list()
2349
+ noise_level = None
2350
+ for ck in self.concat_keys:
2351
+ cc = batch[ck]
2352
+ cc = rearrange(cc, "b h w c -> b c h w")
2353
+ if exists(self.reshuffle_patch_size):
2354
+ assert isinstance(self.reshuffle_patch_size, int)
2355
+ cc = rearrange(
2356
+ cc,
2357
+ "b c (p1 h) (p2 w) -> b (p1 p2 c) h w",
2358
+ p1=self.reshuffle_patch_size,
2359
+ p2=self.reshuffle_patch_size,
2360
+ )
2361
+ if bs is not None:
2362
+ cc = cc[:bs]
2363
+ cc = cc.to(self.device)
2364
+ if exists(self.low_scale_model) and ck == self.low_scale_key:
2365
+ cc, noise_level = self.low_scale_model(cc)
2366
+ c_cat.append(cc)
2367
+ c_cat = torch.cat(c_cat, dim=1)
2368
+ if exists(noise_level):
2369
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
2370
+ else:
2371
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
2372
+ if return_first_stage_outputs:
2373
+ return z, all_conds, x, xrec, xc
2374
+ return z, all_conds
2375
+
2376
+ @torch.no_grad()
2377
+ def log_images(self, *args, **kwargs):
2378
+ log = super().log_images(*args, **kwargs)
2379
+ log["lr"] = rearrange(args[0]["lr"], "b h w c -> b c h w")
2380
+ return log
iopaint/model/anytext/ldm/models/diffusion/dpm_solver/dpm_solver.py ADDED
@@ -0,0 +1,1154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ from tqdm import tqdm
5
+
6
+
7
+ class NoiseScheduleVP:
8
+ def __init__(
9
+ self,
10
+ schedule='discrete',
11
+ betas=None,
12
+ alphas_cumprod=None,
13
+ continuous_beta_0=0.1,
14
+ continuous_beta_1=20.,
15
+ ):
16
+ """Create a wrapper class for the forward SDE (VP type).
17
+ ***
18
+ Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
19
+ We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
20
+ ***
21
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
22
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
23
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
24
+ log_alpha_t = self.marginal_log_mean_coeff(t)
25
+ sigma_t = self.marginal_std(t)
26
+ lambda_t = self.marginal_lambda(t)
27
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
28
+ t = self.inverse_lambda(lambda_t)
29
+ ===============================================================
30
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
31
+ 1. For discrete-time DPMs:
32
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
33
+ t_i = (i + 1) / N
34
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
35
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
36
+ Args:
37
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
38
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
39
+ Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
40
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
41
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
42
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
43
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
44
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
45
+ and
46
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
47
+ 2. For continuous-time DPMs:
48
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
49
+ schedule are the default settings in DDPM and improved-DDPM:
50
+ Args:
51
+ beta_min: A `float` number. The smallest beta for the linear schedule.
52
+ beta_max: A `float` number. The largest beta for the linear schedule.
53
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
54
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
55
+ T: A `float` number. The ending time of the forward process.
56
+ ===============================================================
57
+ Args:
58
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
59
+ 'linear' or 'cosine' for continuous-time DPMs.
60
+ Returns:
61
+ A wrapper object of the forward SDE (VP type).
62
+
63
+ ===============================================================
64
+ Example:
65
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
66
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
67
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
68
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
69
+ # For continuous-time DPMs (VPSDE), linear schedule:
70
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
71
+ """
72
+
73
+ if schedule not in ['discrete', 'linear', 'cosine']:
74
+ raise ValueError(
75
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
76
+ schedule))
77
+
78
+ self.schedule = schedule
79
+ if schedule == 'discrete':
80
+ if betas is not None:
81
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
82
+ else:
83
+ assert alphas_cumprod is not None
84
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
85
+ self.total_N = len(log_alphas)
86
+ self.T = 1.
87
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
88
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
89
+ else:
90
+ self.total_N = 1000
91
+ self.beta_0 = continuous_beta_0
92
+ self.beta_1 = continuous_beta_1
93
+ self.cosine_s = 0.008
94
+ self.cosine_beta_max = 999.
95
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
96
+ 1. + self.cosine_s) / math.pi - self.cosine_s
97
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
98
+ self.schedule = schedule
99
+ if schedule == 'cosine':
100
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
101
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
102
+ self.T = 0.9946
103
+ else:
104
+ self.T = 1.
105
+
106
+ def marginal_log_mean_coeff(self, t):
107
+ """
108
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
109
+ """
110
+ if self.schedule == 'discrete':
111
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
112
+ self.log_alpha_array.to(t.device)).reshape((-1))
113
+ elif self.schedule == 'linear':
114
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
115
+ elif self.schedule == 'cosine':
116
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
117
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
118
+ return log_alpha_t
119
+
120
+ def marginal_alpha(self, t):
121
+ """
122
+ Compute alpha_t of a given continuous-time label t in [0, T].
123
+ """
124
+ return torch.exp(self.marginal_log_mean_coeff(t))
125
+
126
+ def marginal_std(self, t):
127
+ """
128
+ Compute sigma_t of a given continuous-time label t in [0, T].
129
+ """
130
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
131
+
132
+ def marginal_lambda(self, t):
133
+ """
134
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
135
+ """
136
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
137
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
138
+ return log_mean_coeff - log_std
139
+
140
+ def inverse_lambda(self, lamb):
141
+ """
142
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
143
+ """
144
+ if self.schedule == 'linear':
145
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
146
+ Delta = self.beta_0 ** 2 + tmp
147
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
148
+ elif self.schedule == 'discrete':
149
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
150
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
151
+ torch.flip(self.t_array.to(lamb.device), [1]))
152
+ return t.reshape((-1,))
153
+ else:
154
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
155
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
156
+ 1. + self.cosine_s) / math.pi - self.cosine_s
157
+ t = t_fn(log_alpha)
158
+ return t
159
+
160
+
161
+ def model_wrapper(
162
+ model,
163
+ noise_schedule,
164
+ model_type="noise",
165
+ model_kwargs={},
166
+ guidance_type="uncond",
167
+ condition=None,
168
+ unconditional_condition=None,
169
+ guidance_scale=1.,
170
+ classifier_fn=None,
171
+ classifier_kwargs={},
172
+ ):
173
+ """Create a wrapper function for the noise prediction model.
174
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
175
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
176
+ We support four types of the diffusion model by setting `model_type`:
177
+ 1. "noise": noise prediction model. (Trained by predicting noise).
178
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
179
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
180
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
181
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
182
+ arXiv preprint arXiv:2202.00512 (2022).
183
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
184
+ arXiv preprint arXiv:2210.02303 (2022).
185
+
186
+ 4. "score": marginal score function. (Trained by denoising score matching).
187
+ Note that the score function and the noise prediction model follows a simple relationship:
188
+ ```
189
+ noise(x_t, t) = -sigma_t * score(x_t, t)
190
+ ```
191
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
192
+ 1. "uncond": unconditional sampling by DPMs.
193
+ The input `model` has the following format:
194
+ ``
195
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
196
+ ``
197
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
198
+ The input `model` has the following format:
199
+ ``
200
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
201
+ ``
202
+ The input `classifier_fn` has the following format:
203
+ ``
204
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
205
+ ``
206
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
207
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
208
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
209
+ The input `model` has the following format:
210
+ ``
211
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
212
+ ``
213
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
214
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
215
+ arXiv preprint arXiv:2207.12598 (2022).
216
+
217
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
218
+ or continuous-time labels (i.e. epsilon to T).
219
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
220
+ ``
221
+ def model_fn(x, t_continuous) -> noise:
222
+ t_input = get_model_input_time(t_continuous)
223
+ return noise_pred(model, x, t_input, **model_kwargs)
224
+ ``
225
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
226
+ ===============================================================
227
+ Args:
228
+ model: A diffusion model with the corresponding format described above.
229
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
230
+ model_type: A `str`. The parameterization type of the diffusion model.
231
+ "noise" or "x_start" or "v" or "score".
232
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
233
+ guidance_type: A `str`. The type of the guidance for sampling.
234
+ "uncond" or "classifier" or "classifier-free".
235
+ condition: A pytorch tensor. The condition for the guided sampling.
236
+ Only used for "classifier" or "classifier-free" guidance type.
237
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
238
+ Only used for "classifier-free" guidance type.
239
+ guidance_scale: A `float`. The scale for the guided sampling.
240
+ classifier_fn: A classifier function. Only used for the classifier guidance.
241
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
242
+ Returns:
243
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
244
+ """
245
+
246
+ def get_model_input_time(t_continuous):
247
+ """
248
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
249
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
250
+ For continuous-time DPMs, we just use `t_continuous`.
251
+ """
252
+ if noise_schedule.schedule == 'discrete':
253
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
254
+ else:
255
+ return t_continuous
256
+
257
+ def noise_pred_fn(x, t_continuous, cond=None):
258
+ if t_continuous.reshape((-1,)).shape[0] == 1:
259
+ t_continuous = t_continuous.expand((x.shape[0]))
260
+ t_input = get_model_input_time(t_continuous)
261
+ if cond is None:
262
+ output = model(x, t_input, **model_kwargs)
263
+ else:
264
+ output = model(x, t_input, cond, **model_kwargs)
265
+ if model_type == "noise":
266
+ return output
267
+ elif model_type == "x_start":
268
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
269
+ dims = x.dim()
270
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
271
+ elif model_type == "v":
272
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
273
+ dims = x.dim()
274
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
275
+ elif model_type == "score":
276
+ sigma_t = noise_schedule.marginal_std(t_continuous)
277
+ dims = x.dim()
278
+ return -expand_dims(sigma_t, dims) * output
279
+
280
+ def cond_grad_fn(x, t_input):
281
+ """
282
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
283
+ """
284
+ with torch.enable_grad():
285
+ x_in = x.detach().requires_grad_(True)
286
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
287
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
288
+
289
+ def model_fn(x, t_continuous):
290
+ """
291
+ The noise predicition model function that is used for DPM-Solver.
292
+ """
293
+ if t_continuous.reshape((-1,)).shape[0] == 1:
294
+ t_continuous = t_continuous.expand((x.shape[0]))
295
+ if guidance_type == "uncond":
296
+ return noise_pred_fn(x, t_continuous)
297
+ elif guidance_type == "classifier":
298
+ assert classifier_fn is not None
299
+ t_input = get_model_input_time(t_continuous)
300
+ cond_grad = cond_grad_fn(x, t_input)
301
+ sigma_t = noise_schedule.marginal_std(t_continuous)
302
+ noise = noise_pred_fn(x, t_continuous)
303
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
304
+ elif guidance_type == "classifier-free":
305
+ if guidance_scale == 1. or unconditional_condition is None:
306
+ return noise_pred_fn(x, t_continuous, cond=condition)
307
+ else:
308
+ x_in = torch.cat([x] * 2)
309
+ t_in = torch.cat([t_continuous] * 2)
310
+ c_in = torch.cat([unconditional_condition, condition])
311
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
312
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
313
+
314
+ assert model_type in ["noise", "x_start", "v"]
315
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
316
+ return model_fn
317
+
318
+
319
+ class DPM_Solver:
320
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
321
+ """Construct a DPM-Solver.
322
+ We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
323
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
324
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
325
+ In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
326
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
327
+ Args:
328
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
329
+ ``
330
+ def model_fn(x, t_continuous):
331
+ return noise
332
+ ``
333
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
334
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
335
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
336
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
337
+
338
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
339
+ """
340
+ self.model = model_fn
341
+ self.noise_schedule = noise_schedule
342
+ self.predict_x0 = predict_x0
343
+ self.thresholding = thresholding
344
+ self.max_val = max_val
345
+
346
+ def noise_prediction_fn(self, x, t):
347
+ """
348
+ Return the noise prediction model.
349
+ """
350
+ return self.model(x, t)
351
+
352
+ def data_prediction_fn(self, x, t):
353
+ """
354
+ Return the data prediction model (with thresholding).
355
+ """
356
+ noise = self.noise_prediction_fn(x, t)
357
+ dims = x.dim()
358
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
359
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
360
+ if self.thresholding:
361
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
362
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
363
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
364
+ x0 = torch.clamp(x0, -s, s) / s
365
+ return x0
366
+
367
+ def model_fn(self, x, t):
368
+ """
369
+ Convert the model to the noise prediction model or the data prediction model.
370
+ """
371
+ if self.predict_x0:
372
+ return self.data_prediction_fn(x, t)
373
+ else:
374
+ return self.noise_prediction_fn(x, t)
375
+
376
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
377
+ """Compute the intermediate time steps for sampling.
378
+ Args:
379
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
380
+ - 'logSNR': uniform logSNR for the time steps.
381
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
382
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
383
+ t_T: A `float`. The starting time of the sampling (default is T).
384
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
385
+ N: A `int`. The total number of the spacing of the time steps.
386
+ device: A torch device.
387
+ Returns:
388
+ A pytorch tensor of the time steps, with the shape (N + 1,).
389
+ """
390
+ if skip_type == 'logSNR':
391
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
392
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
393
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
394
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
395
+ elif skip_type == 'time_uniform':
396
+ return torch.linspace(t_T, t_0, N + 1).to(device)
397
+ elif skip_type == 'time_quadratic':
398
+ t_order = 2
399
+ t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
400
+ return t
401
+ else:
402
+ raise ValueError(
403
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
404
+
405
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
406
+ """
407
+ Get the order of each step for sampling by the singlestep DPM-Solver.
408
+ We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
409
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
410
+ - If order == 1:
411
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
412
+ - If order == 2:
413
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
414
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
415
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
416
+ - If order == 3:
417
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
418
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
419
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
420
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
421
+ ============================================
422
+ Args:
423
+ order: A `int`. The max order for the solver (2 or 3).
424
+ steps: A `int`. The total number of function evaluations (NFE).
425
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
426
+ - 'logSNR': uniform logSNR for the time steps.
427
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
428
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
429
+ t_T: A `float`. The starting time of the sampling (default is T).
430
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
431
+ device: A torch device.
432
+ Returns:
433
+ orders: A list of the solver order of each step.
434
+ """
435
+ if order == 3:
436
+ K = steps // 3 + 1
437
+ if steps % 3 == 0:
438
+ orders = [3, ] * (K - 2) + [2, 1]
439
+ elif steps % 3 == 1:
440
+ orders = [3, ] * (K - 1) + [1]
441
+ else:
442
+ orders = [3, ] * (K - 1) + [2]
443
+ elif order == 2:
444
+ if steps % 2 == 0:
445
+ K = steps // 2
446
+ orders = [2, ] * K
447
+ else:
448
+ K = steps // 2 + 1
449
+ orders = [2, ] * (K - 1) + [1]
450
+ elif order == 1:
451
+ K = 1
452
+ orders = [1, ] * steps
453
+ else:
454
+ raise ValueError("'order' must be '1' or '2' or '3'.")
455
+ if skip_type == 'logSNR':
456
+ # To reproduce the results in DPM-Solver paper
457
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
458
+ else:
459
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
460
+ torch.cumsum(torch.tensor([0, ] + orders)).to(device)]
461
+ return timesteps_outer, orders
462
+
463
+ def denoise_to_zero_fn(self, x, s):
464
+ """
465
+ Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
466
+ """
467
+ return self.data_prediction_fn(x, s)
468
+
469
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
470
+ """
471
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
472
+ Args:
473
+ x: A pytorch tensor. The initial value at time `s`.
474
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
475
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
476
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
477
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
478
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
479
+ Returns:
480
+ x_t: A pytorch tensor. The approximated solution at time `t`.
481
+ """
482
+ ns = self.noise_schedule
483
+ dims = x.dim()
484
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
485
+ h = lambda_t - lambda_s
486
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
487
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
488
+ alpha_t = torch.exp(log_alpha_t)
489
+
490
+ if self.predict_x0:
491
+ phi_1 = torch.expm1(-h)
492
+ if model_s is None:
493
+ model_s = self.model_fn(x, s)
494
+ x_t = (
495
+ expand_dims(sigma_t / sigma_s, dims) * x
496
+ - expand_dims(alpha_t * phi_1, dims) * model_s
497
+ )
498
+ if return_intermediate:
499
+ return x_t, {'model_s': model_s}
500
+ else:
501
+ return x_t
502
+ else:
503
+ phi_1 = torch.expm1(h)
504
+ if model_s is None:
505
+ model_s = self.model_fn(x, s)
506
+ x_t = (
507
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
508
+ - expand_dims(sigma_t * phi_1, dims) * model_s
509
+ )
510
+ if return_intermediate:
511
+ return x_t, {'model_s': model_s}
512
+ else:
513
+ return x_t
514
+
515
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
516
+ solver_type='dpm_solver'):
517
+ """
518
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
519
+ Args:
520
+ x: A pytorch tensor. The initial value at time `s`.
521
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
522
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
523
+ r1: A `float`. The hyperparameter of the second-order solver.
524
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
525
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
526
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
527
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
528
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
529
+ Returns:
530
+ x_t: A pytorch tensor. The approximated solution at time `t`.
531
+ """
532
+ if solver_type not in ['dpm_solver', 'taylor']:
533
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
534
+ if r1 is None:
535
+ r1 = 0.5
536
+ ns = self.noise_schedule
537
+ dims = x.dim()
538
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
539
+ h = lambda_t - lambda_s
540
+ lambda_s1 = lambda_s + r1 * h
541
+ s1 = ns.inverse_lambda(lambda_s1)
542
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
543
+ s1), ns.marginal_log_mean_coeff(t)
544
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
545
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
546
+
547
+ if self.predict_x0:
548
+ phi_11 = torch.expm1(-r1 * h)
549
+ phi_1 = torch.expm1(-h)
550
+
551
+ if model_s is None:
552
+ model_s = self.model_fn(x, s)
553
+ x_s1 = (
554
+ expand_dims(sigma_s1 / sigma_s, dims) * x
555
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
556
+ )
557
+ model_s1 = self.model_fn(x_s1, s1)
558
+ if solver_type == 'dpm_solver':
559
+ x_t = (
560
+ expand_dims(sigma_t / sigma_s, dims) * x
561
+ - expand_dims(alpha_t * phi_1, dims) * model_s
562
+ - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
563
+ )
564
+ elif solver_type == 'taylor':
565
+ x_t = (
566
+ expand_dims(sigma_t / sigma_s, dims) * x
567
+ - expand_dims(alpha_t * phi_1, dims) * model_s
568
+ + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
569
+ model_s1 - model_s)
570
+ )
571
+ else:
572
+ phi_11 = torch.expm1(r1 * h)
573
+ phi_1 = torch.expm1(h)
574
+
575
+ if model_s is None:
576
+ model_s = self.model_fn(x, s)
577
+ x_s1 = (
578
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
579
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
580
+ )
581
+ model_s1 = self.model_fn(x_s1, s1)
582
+ if solver_type == 'dpm_solver':
583
+ x_t = (
584
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
585
+ - expand_dims(sigma_t * phi_1, dims) * model_s
586
+ - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
587
+ )
588
+ elif solver_type == 'taylor':
589
+ x_t = (
590
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
591
+ - expand_dims(sigma_t * phi_1, dims) * model_s
592
+ - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
593
+ )
594
+ if return_intermediate:
595
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
596
+ else:
597
+ return x_t
598
+
599
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
600
+ return_intermediate=False, solver_type='dpm_solver'):
601
+ """
602
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
603
+ Args:
604
+ x: A pytorch tensor. The initial value at time `s`.
605
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
606
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
607
+ r1: A `float`. The hyperparameter of the third-order solver.
608
+ r2: A `float`. The hyperparameter of the third-order solver.
609
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
610
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
611
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
612
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
613
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
614
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
615
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
616
+ Returns:
617
+ x_t: A pytorch tensor. The approximated solution at time `t`.
618
+ """
619
+ if solver_type not in ['dpm_solver', 'taylor']:
620
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
621
+ if r1 is None:
622
+ r1 = 1. / 3.
623
+ if r2 is None:
624
+ r2 = 2. / 3.
625
+ ns = self.noise_schedule
626
+ dims = x.dim()
627
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
628
+ h = lambda_t - lambda_s
629
+ lambda_s1 = lambda_s + r1 * h
630
+ lambda_s2 = lambda_s + r2 * h
631
+ s1 = ns.inverse_lambda(lambda_s1)
632
+ s2 = ns.inverse_lambda(lambda_s2)
633
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
634
+ s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
635
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
636
+ s2), ns.marginal_std(t)
637
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
638
+
639
+ if self.predict_x0:
640
+ phi_11 = torch.expm1(-r1 * h)
641
+ phi_12 = torch.expm1(-r2 * h)
642
+ phi_1 = torch.expm1(-h)
643
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
644
+ phi_2 = phi_1 / h + 1.
645
+ phi_3 = phi_2 / h - 0.5
646
+
647
+ if model_s is None:
648
+ model_s = self.model_fn(x, s)
649
+ if model_s1 is None:
650
+ x_s1 = (
651
+ expand_dims(sigma_s1 / sigma_s, dims) * x
652
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
653
+ )
654
+ model_s1 = self.model_fn(x_s1, s1)
655
+ x_s2 = (
656
+ expand_dims(sigma_s2 / sigma_s, dims) * x
657
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
658
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
659
+ )
660
+ model_s2 = self.model_fn(x_s2, s2)
661
+ if solver_type == 'dpm_solver':
662
+ x_t = (
663
+ expand_dims(sigma_t / sigma_s, dims) * x
664
+ - expand_dims(alpha_t * phi_1, dims) * model_s
665
+ + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
666
+ )
667
+ elif solver_type == 'taylor':
668
+ D1_0 = (1. / r1) * (model_s1 - model_s)
669
+ D1_1 = (1. / r2) * (model_s2 - model_s)
670
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
671
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
672
+ x_t = (
673
+ expand_dims(sigma_t / sigma_s, dims) * x
674
+ - expand_dims(alpha_t * phi_1, dims) * model_s
675
+ + expand_dims(alpha_t * phi_2, dims) * D1
676
+ - expand_dims(alpha_t * phi_3, dims) * D2
677
+ )
678
+ else:
679
+ phi_11 = torch.expm1(r1 * h)
680
+ phi_12 = torch.expm1(r2 * h)
681
+ phi_1 = torch.expm1(h)
682
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
683
+ phi_2 = phi_1 / h - 1.
684
+ phi_3 = phi_2 / h - 0.5
685
+
686
+ if model_s is None:
687
+ model_s = self.model_fn(x, s)
688
+ if model_s1 is None:
689
+ x_s1 = (
690
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
691
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
692
+ )
693
+ model_s1 = self.model_fn(x_s1, s1)
694
+ x_s2 = (
695
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
696
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
697
+ - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
698
+ )
699
+ model_s2 = self.model_fn(x_s2, s2)
700
+ if solver_type == 'dpm_solver':
701
+ x_t = (
702
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
703
+ - expand_dims(sigma_t * phi_1, dims) * model_s
704
+ - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
705
+ )
706
+ elif solver_type == 'taylor':
707
+ D1_0 = (1. / r1) * (model_s1 - model_s)
708
+ D1_1 = (1. / r2) * (model_s2 - model_s)
709
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
710
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
711
+ x_t = (
712
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
713
+ - expand_dims(sigma_t * phi_1, dims) * model_s
714
+ - expand_dims(sigma_t * phi_2, dims) * D1
715
+ - expand_dims(sigma_t * phi_3, dims) * D2
716
+ )
717
+
718
+ if return_intermediate:
719
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
720
+ else:
721
+ return x_t
722
+
723
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
724
+ """
725
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
726
+ Args:
727
+ x: A pytorch tensor. The initial value at time `s`.
728
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
729
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
730
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
731
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
732
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
733
+ Returns:
734
+ x_t: A pytorch tensor. The approximated solution at time `t`.
735
+ """
736
+ if solver_type not in ['dpm_solver', 'taylor']:
737
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
738
+ ns = self.noise_schedule
739
+ dims = x.dim()
740
+ model_prev_1, model_prev_0 = model_prev_list
741
+ t_prev_1, t_prev_0 = t_prev_list
742
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
743
+ t_prev_0), ns.marginal_lambda(t)
744
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
745
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
746
+ alpha_t = torch.exp(log_alpha_t)
747
+
748
+ h_0 = lambda_prev_0 - lambda_prev_1
749
+ h = lambda_t - lambda_prev_0
750
+ r0 = h_0 / h
751
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
752
+ if self.predict_x0:
753
+ if solver_type == 'dpm_solver':
754
+ x_t = (
755
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
756
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
757
+ - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
758
+ )
759
+ elif solver_type == 'taylor':
760
+ x_t = (
761
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
762
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
763
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
764
+ )
765
+ else:
766
+ if solver_type == 'dpm_solver':
767
+ x_t = (
768
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
769
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
770
+ - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
771
+ )
772
+ elif solver_type == 'taylor':
773
+ x_t = (
774
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
775
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
776
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
777
+ )
778
+ return x_t
779
+
780
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
781
+ """
782
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
783
+ Args:
784
+ x: A pytorch tensor. The initial value at time `s`.
785
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
786
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
787
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
788
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
789
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
790
+ Returns:
791
+ x_t: A pytorch tensor. The approximated solution at time `t`.
792
+ """
793
+ ns = self.noise_schedule
794
+ dims = x.dim()
795
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
796
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
797
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
798
+ t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
799
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
800
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
801
+ alpha_t = torch.exp(log_alpha_t)
802
+
803
+ h_1 = lambda_prev_1 - lambda_prev_2
804
+ h_0 = lambda_prev_0 - lambda_prev_1
805
+ h = lambda_t - lambda_prev_0
806
+ r0, r1 = h_0 / h, h_1 / h
807
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
808
+ D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
809
+ D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
810
+ D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
811
+ if self.predict_x0:
812
+ x_t = (
813
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
814
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
815
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
816
+ - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
817
+ )
818
+ else:
819
+ x_t = (
820
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
821
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
822
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
823
+ - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
824
+ )
825
+ return x_t
826
+
827
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
828
+ r2=None):
829
+ """
830
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
831
+ Args:
832
+ x: A pytorch tensor. The initial value at time `s`.
833
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
834
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
835
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
836
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
837
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
838
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
839
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
840
+ r2: A `float`. The hyperparameter of the third-order solver.
841
+ Returns:
842
+ x_t: A pytorch tensor. The approximated solution at time `t`.
843
+ """
844
+ if order == 1:
845
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
846
+ elif order == 2:
847
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
848
+ solver_type=solver_type, r1=r1)
849
+ elif order == 3:
850
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
851
+ solver_type=solver_type, r1=r1, r2=r2)
852
+ else:
853
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
854
+
855
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
856
+ """
857
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
858
+ Args:
859
+ x: A pytorch tensor. The initial value at time `s`.
860
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
861
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
862
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
863
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
864
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
865
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
866
+ Returns:
867
+ x_t: A pytorch tensor. The approximated solution at time `t`.
868
+ """
869
+ if order == 1:
870
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
871
+ elif order == 2:
872
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
873
+ elif order == 3:
874
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
875
+ else:
876
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
877
+
878
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
879
+ solver_type='dpm_solver'):
880
+ """
881
+ The adaptive step size solver based on singlestep DPM-Solver.
882
+ Args:
883
+ x: A pytorch tensor. The initial value at time `t_T`.
884
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
885
+ t_T: A `float`. The starting time of the sampling (default is T).
886
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
887
+ h_init: A `float`. The initial step size (for logSNR).
888
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
889
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
890
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
891
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
892
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
893
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
894
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
895
+ Returns:
896
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
897
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
898
+ """
899
+ ns = self.noise_schedule
900
+ s = t_T * torch.ones((x.shape[0],)).to(x)
901
+ lambda_s = ns.marginal_lambda(s)
902
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
903
+ h = h_init * torch.ones_like(s).to(x)
904
+ x_prev = x
905
+ nfe = 0
906
+ if order == 2:
907
+ r1 = 0.5
908
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
909
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
910
+ solver_type=solver_type,
911
+ **kwargs)
912
+ elif order == 3:
913
+ r1, r2 = 1. / 3., 2. / 3.
914
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
915
+ return_intermediate=True,
916
+ solver_type=solver_type)
917
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
918
+ solver_type=solver_type,
919
+ **kwargs)
920
+ else:
921
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
922
+ while torch.abs((s - t_0)).mean() > t_err:
923
+ t = ns.inverse_lambda(lambda_s + h)
924
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
925
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
926
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
927
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
928
+ E = norm_fn((x_higher - x_lower) / delta).max()
929
+ if torch.all(E <= 1.):
930
+ x = x_higher
931
+ s = t
932
+ x_prev = x_lower
933
+ lambda_s = ns.marginal_lambda(s)
934
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
935
+ nfe += order
936
+ print('adaptive solver nfe', nfe)
937
+ return x
938
+
939
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
940
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
941
+ atol=0.0078, rtol=0.05,
942
+ ):
943
+ """
944
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
945
+ =====================================================
946
+ We support the following algorithms for both noise prediction model and data prediction model:
947
+ - 'singlestep':
948
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
949
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
950
+ The total number of function evaluations (NFE) == `steps`.
951
+ Given a fixed NFE == `steps`, the sampling procedure is:
952
+ - If `order` == 1:
953
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
954
+ - If `order` == 2:
955
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
956
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
957
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
958
+ - If `order` == 3:
959
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
960
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
961
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
962
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
963
+ - 'multistep':
964
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
965
+ We initialize the first `order` values by lower order multistep solvers.
966
+ Given a fixed NFE == `steps`, the sampling procedure is:
967
+ Denote K = steps.
968
+ - If `order` == 1:
969
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
970
+ - If `order` == 2:
971
+ - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
972
+ - If `order` == 3:
973
+ - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
974
+ - 'singlestep_fixed':
975
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
976
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
977
+ - 'adaptive':
978
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
979
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
980
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
981
+ (NFE) and the sample quality.
982
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
983
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
984
+ =====================================================
985
+ Some advices for choosing the algorithm:
986
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
987
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
988
+ e.g.
989
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
990
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
991
+ skip_type='time_uniform', method='singlestep')
992
+ - For **guided sampling with large guidance scale** by DPMs:
993
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
994
+ e.g.
995
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
996
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
997
+ skip_type='time_uniform', method='multistep')
998
+ We support three types of `skip_type`:
999
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
1000
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
1001
+ - 'time_quadratic': quadratic time for the time steps.
1002
+ =====================================================
1003
+ Args:
1004
+ x: A pytorch tensor. The initial value at time `t_start`
1005
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1006
+ steps: A `int`. The total number of function evaluations (NFE).
1007
+ t_start: A `float`. The starting time of the sampling.
1008
+ If `T` is None, we use self.noise_schedule.T (default is 1.0).
1009
+ t_end: A `float`. The ending time of the sampling.
1010
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1011
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
1012
+ For discrete-time DPMs:
1013
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1014
+ For continuous-time DPMs:
1015
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1016
+ order: A `int`. The order of DPM-Solver.
1017
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1018
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1019
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1020
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1021
+ This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
1022
+ score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
1023
+ for diffusion models sampling by diffusion SDEs for low-resolutional images
1024
+ (such as CIFAR-10). However, we observed that such trick does not matter for
1025
+ high-resolutional images. As it needs an additional NFE, we do not recommend
1026
+ it for high-resolutional images.
1027
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1028
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
1029
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
1030
+ (especially for steps <= 10). So we recommend to set it to be `True`.
1031
+ solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
1032
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1033
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1034
+ Returns:
1035
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
1036
+ """
1037
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
1038
+ t_T = self.noise_schedule.T if t_start is None else t_start
1039
+ device = x.device
1040
+ if method == 'adaptive':
1041
+ with torch.no_grad():
1042
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
1043
+ solver_type=solver_type)
1044
+ elif method == 'multistep':
1045
+ assert steps >= order
1046
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
1047
+ assert timesteps.shape[0] - 1 == steps
1048
+ with torch.no_grad():
1049
+ vec_t = timesteps[0].expand((x.shape[0]))
1050
+ model_prev_list = [self.model_fn(x, vec_t)]
1051
+ t_prev_list = [vec_t]
1052
+ # Init the first `order` values by lower order multistep DPM-Solver.
1053
+ for init_order in tqdm(range(1, order), desc="DPM init order"):
1054
+ vec_t = timesteps[init_order].expand(x.shape[0])
1055
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
1056
+ solver_type=solver_type)
1057
+ model_prev_list.append(self.model_fn(x, vec_t))
1058
+ t_prev_list.append(vec_t)
1059
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
1060
+ for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
1061
+ vec_t = timesteps[step].expand(x.shape[0])
1062
+ if lower_order_final and steps < 15:
1063
+ step_order = min(order, steps + 1 - step)
1064
+ else:
1065
+ step_order = order
1066
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
1067
+ solver_type=solver_type)
1068
+ for i in range(order - 1):
1069
+ t_prev_list[i] = t_prev_list[i + 1]
1070
+ model_prev_list[i] = model_prev_list[i + 1]
1071
+ t_prev_list[-1] = vec_t
1072
+ # We do not need to evaluate the final model value.
1073
+ if step < steps:
1074
+ model_prev_list[-1] = self.model_fn(x, vec_t)
1075
+ elif method in ['singlestep', 'singlestep_fixed']:
1076
+ if method == 'singlestep':
1077
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
1078
+ skip_type=skip_type,
1079
+ t_T=t_T, t_0=t_0,
1080
+ device=device)
1081
+ elif method == 'singlestep_fixed':
1082
+ K = steps // order
1083
+ orders = [order, ] * K
1084
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1085
+ for i, order in enumerate(orders):
1086
+ t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
1087
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
1088
+ N=order, device=device)
1089
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1090
+ vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
1091
+ h = lambda_inner[-1] - lambda_inner[0]
1092
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1093
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1094
+ x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
1095
+ if denoise_to_zero:
1096
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
1097
+ return x
1098
+
1099
+
1100
+ #############################################################
1101
+ # other utility functions
1102
+ #############################################################
1103
+
1104
+ def interpolate_fn(x, xp, yp):
1105
+ """
1106
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
1107
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
1108
+ The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
1109
+ Args:
1110
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1111
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1112
+ yp: PyTorch tensor with shape [C, K].
1113
+ Returns:
1114
+ The function values f(x), with shape [N, C].
1115
+ """
1116
+ N, K = x.shape[0], xp.shape[1]
1117
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1118
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1119
+ x_idx = torch.argmin(x_indices, dim=2)
1120
+ cand_start_idx = x_idx - 1
1121
+ start_idx = torch.where(
1122
+ torch.eq(x_idx, 0),
1123
+ torch.tensor(1, device=x.device),
1124
+ torch.where(
1125
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1126
+ ),
1127
+ )
1128
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
1129
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1130
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1131
+ start_idx2 = torch.where(
1132
+ torch.eq(x_idx, 0),
1133
+ torch.tensor(0, device=x.device),
1134
+ torch.where(
1135
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1136
+ ),
1137
+ )
1138
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1139
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
1140
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
1141
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1142
+ return cand
1143
+
1144
+
1145
+ def expand_dims(v, dims):
1146
+ """
1147
+ Expand the tensor `v` to the dim `dims`.
1148
+ Args:
1149
+ `v`: a PyTorch tensor with shape [N].
1150
+ `dim`: a `int`.
1151
+ Returns:
1152
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1153
+ """
1154
+ return v[(...,) + (None,) * (dims - 1)]
iopaint/model/anytext/ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,973 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ def get_timestep_embedding(timesteps, embedding_dim):
10
+ """
11
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
12
+ From Fairseq.
13
+ Build sinusoidal embeddings.
14
+ This matches the implementation in tensor2tensor, but differs slightly
15
+ from the description in Section 3.5 of "Attention Is All You Need".
16
+ """
17
+ assert len(timesteps.shape) == 1
18
+
19
+ half_dim = embedding_dim // 2
20
+ emb = math.log(10000) / (half_dim - 1)
21
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
22
+ emb = emb.to(device=timesteps.device)
23
+ emb = timesteps.float()[:, None] * emb[None, :]
24
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
25
+ if embedding_dim % 2 == 1: # zero pad
26
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
27
+ return emb
28
+
29
+
30
+ def nonlinearity(x):
31
+ # swish
32
+ return x * torch.sigmoid(x)
33
+
34
+
35
+ def Normalize(in_channels, num_groups=32):
36
+ return torch.nn.GroupNorm(
37
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
38
+ )
39
+
40
+
41
+ class Upsample(nn.Module):
42
+ def __init__(self, in_channels, with_conv):
43
+ super().__init__()
44
+ self.with_conv = with_conv
45
+ if self.with_conv:
46
+ self.conv = torch.nn.Conv2d(
47
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
48
+ )
49
+
50
+ def forward(self, x):
51
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
52
+ if self.with_conv:
53
+ x = self.conv(x)
54
+ return x
55
+
56
+
57
+ class Downsample(nn.Module):
58
+ def __init__(self, in_channels, with_conv):
59
+ super().__init__()
60
+ self.with_conv = with_conv
61
+ if self.with_conv:
62
+ # no asymmetric padding in torch conv, must do it ourselves
63
+ self.conv = torch.nn.Conv2d(
64
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
65
+ )
66
+
67
+ def forward(self, x):
68
+ if self.with_conv:
69
+ pad = (0, 1, 0, 1)
70
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
71
+ x = self.conv(x)
72
+ else:
73
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
74
+ return x
75
+
76
+
77
+ class ResnetBlock(nn.Module):
78
+ def __init__(
79
+ self,
80
+ *,
81
+ in_channels,
82
+ out_channels=None,
83
+ conv_shortcut=False,
84
+ dropout,
85
+ temb_channels=512,
86
+ ):
87
+ super().__init__()
88
+ self.in_channels = in_channels
89
+ out_channels = in_channels if out_channels is None else out_channels
90
+ self.out_channels = out_channels
91
+ self.use_conv_shortcut = conv_shortcut
92
+
93
+ self.norm1 = Normalize(in_channels)
94
+ self.conv1 = torch.nn.Conv2d(
95
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
96
+ )
97
+ if temb_channels > 0:
98
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
99
+ self.norm2 = Normalize(out_channels)
100
+ self.dropout = torch.nn.Dropout(dropout)
101
+ self.conv2 = torch.nn.Conv2d(
102
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
103
+ )
104
+ if self.in_channels != self.out_channels:
105
+ if self.use_conv_shortcut:
106
+ self.conv_shortcut = torch.nn.Conv2d(
107
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
108
+ )
109
+ else:
110
+ self.nin_shortcut = torch.nn.Conv2d(
111
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
112
+ )
113
+
114
+ def forward(self, x, temb):
115
+ h = x
116
+ h = self.norm1(h)
117
+ h = nonlinearity(h)
118
+ h = self.conv1(h)
119
+
120
+ if temb is not None:
121
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
122
+
123
+ h = self.norm2(h)
124
+ h = nonlinearity(h)
125
+ h = self.dropout(h)
126
+ h = self.conv2(h)
127
+
128
+ if self.in_channels != self.out_channels:
129
+ if self.use_conv_shortcut:
130
+ x = self.conv_shortcut(x)
131
+ else:
132
+ x = self.nin_shortcut(x)
133
+
134
+ return x + h
135
+
136
+
137
+ class AttnBlock(nn.Module):
138
+ def __init__(self, in_channels):
139
+ super().__init__()
140
+ self.in_channels = in_channels
141
+
142
+ self.norm = Normalize(in_channels)
143
+ self.q = torch.nn.Conv2d(
144
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
145
+ )
146
+ self.k = torch.nn.Conv2d(
147
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
148
+ )
149
+ self.v = torch.nn.Conv2d(
150
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
151
+ )
152
+ self.proj_out = torch.nn.Conv2d(
153
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
154
+ )
155
+
156
+ def forward(self, x):
157
+ h_ = x
158
+ h_ = self.norm(h_)
159
+ q = self.q(h_)
160
+ k = self.k(h_)
161
+ v = self.v(h_)
162
+
163
+ # compute attention
164
+ b, c, h, w = q.shape
165
+ q = q.reshape(b, c, h * w)
166
+ q = q.permute(0, 2, 1) # b,hw,c
167
+ k = k.reshape(b, c, h * w) # b,c,hw
168
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
169
+ w_ = w_ * (int(c) ** (-0.5))
170
+ w_ = torch.nn.functional.softmax(w_, dim=2)
171
+
172
+ # attend to values
173
+ v = v.reshape(b, c, h * w)
174
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
175
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
176
+ h_ = h_.reshape(b, c, h, w)
177
+
178
+ h_ = self.proj_out(h_)
179
+
180
+ return x + h_
181
+
182
+
183
+ class AttnBlock2_0(nn.Module):
184
+ def __init__(self, in_channels):
185
+ super().__init__()
186
+ self.in_channels = in_channels
187
+
188
+ self.norm = Normalize(in_channels)
189
+ self.q = torch.nn.Conv2d(
190
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
191
+ )
192
+ self.k = torch.nn.Conv2d(
193
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
194
+ )
195
+ self.v = torch.nn.Conv2d(
196
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
197
+ )
198
+ self.proj_out = torch.nn.Conv2d(
199
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
200
+ )
201
+
202
+ def forward(self, x):
203
+ h_ = x
204
+ h_ = self.norm(h_)
205
+ # output: [1, 512, 64, 64]
206
+ q = self.q(h_)
207
+ k = self.k(h_)
208
+ v = self.v(h_)
209
+
210
+ # compute attention
211
+ b, c, h, w = q.shape
212
+
213
+ # q = q.reshape(b, c, h * w).transpose()
214
+ # q = q.permute(0, 2, 1) # b,hw,c
215
+ # k = k.reshape(b, c, h * w) # b,c,hw
216
+ q = q.transpose(1, 2)
217
+ k = k.transpose(1, 2)
218
+ v = v.transpose(1, 2)
219
+ # (batch, num_heads, seq_len, head_dim)
220
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(
221
+ q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
222
+ )
223
+ hidden_states = hidden_states.transpose(1, 2)
224
+ hidden_states = hidden_states.to(q.dtype)
225
+
226
+ h_ = self.proj_out(hidden_states)
227
+
228
+ return x + h_
229
+
230
+
231
+ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
232
+ assert attn_type in [
233
+ "vanilla",
234
+ "vanilla-xformers",
235
+ "memory-efficient-cross-attn",
236
+ "linear",
237
+ "none",
238
+ ], f"attn_type {attn_type} unknown"
239
+ assert attn_kwargs is None
240
+ if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
241
+ # print(f"Using torch.nn.functional.scaled_dot_product_attention")
242
+ return AttnBlock2_0(in_channels)
243
+ return AttnBlock(in_channels)
244
+
245
+
246
+ class Model(nn.Module):
247
+ def __init__(
248
+ self,
249
+ *,
250
+ ch,
251
+ out_ch,
252
+ ch_mult=(1, 2, 4, 8),
253
+ num_res_blocks,
254
+ attn_resolutions,
255
+ dropout=0.0,
256
+ resamp_with_conv=True,
257
+ in_channels,
258
+ resolution,
259
+ use_timestep=True,
260
+ use_linear_attn=False,
261
+ attn_type="vanilla",
262
+ ):
263
+ super().__init__()
264
+ if use_linear_attn:
265
+ attn_type = "linear"
266
+ self.ch = ch
267
+ self.temb_ch = self.ch * 4
268
+ self.num_resolutions = len(ch_mult)
269
+ self.num_res_blocks = num_res_blocks
270
+ self.resolution = resolution
271
+ self.in_channels = in_channels
272
+
273
+ self.use_timestep = use_timestep
274
+ if self.use_timestep:
275
+ # timestep embedding
276
+ self.temb = nn.Module()
277
+ self.temb.dense = nn.ModuleList(
278
+ [
279
+ torch.nn.Linear(self.ch, self.temb_ch),
280
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
281
+ ]
282
+ )
283
+
284
+ # downsampling
285
+ self.conv_in = torch.nn.Conv2d(
286
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
287
+ )
288
+
289
+ curr_res = resolution
290
+ in_ch_mult = (1,) + tuple(ch_mult)
291
+ self.down = nn.ModuleList()
292
+ for i_level in range(self.num_resolutions):
293
+ block = nn.ModuleList()
294
+ attn = nn.ModuleList()
295
+ block_in = ch * in_ch_mult[i_level]
296
+ block_out = ch * ch_mult[i_level]
297
+ for i_block in range(self.num_res_blocks):
298
+ block.append(
299
+ ResnetBlock(
300
+ in_channels=block_in,
301
+ out_channels=block_out,
302
+ temb_channels=self.temb_ch,
303
+ dropout=dropout,
304
+ )
305
+ )
306
+ block_in = block_out
307
+ if curr_res in attn_resolutions:
308
+ attn.append(make_attn(block_in, attn_type=attn_type))
309
+ down = nn.Module()
310
+ down.block = block
311
+ down.attn = attn
312
+ if i_level != self.num_resolutions - 1:
313
+ down.downsample = Downsample(block_in, resamp_with_conv)
314
+ curr_res = curr_res // 2
315
+ self.down.append(down)
316
+
317
+ # middle
318
+ self.mid = nn.Module()
319
+ self.mid.block_1 = ResnetBlock(
320
+ in_channels=block_in,
321
+ out_channels=block_in,
322
+ temb_channels=self.temb_ch,
323
+ dropout=dropout,
324
+ )
325
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
326
+ self.mid.block_2 = ResnetBlock(
327
+ in_channels=block_in,
328
+ out_channels=block_in,
329
+ temb_channels=self.temb_ch,
330
+ dropout=dropout,
331
+ )
332
+
333
+ # upsampling
334
+ self.up = nn.ModuleList()
335
+ for i_level in reversed(range(self.num_resolutions)):
336
+ block = nn.ModuleList()
337
+ attn = nn.ModuleList()
338
+ block_out = ch * ch_mult[i_level]
339
+ skip_in = ch * ch_mult[i_level]
340
+ for i_block in range(self.num_res_blocks + 1):
341
+ if i_block == self.num_res_blocks:
342
+ skip_in = ch * in_ch_mult[i_level]
343
+ block.append(
344
+ ResnetBlock(
345
+ in_channels=block_in + skip_in,
346
+ out_channels=block_out,
347
+ temb_channels=self.temb_ch,
348
+ dropout=dropout,
349
+ )
350
+ )
351
+ block_in = block_out
352
+ if curr_res in attn_resolutions:
353
+ attn.append(make_attn(block_in, attn_type=attn_type))
354
+ up = nn.Module()
355
+ up.block = block
356
+ up.attn = attn
357
+ if i_level != 0:
358
+ up.upsample = Upsample(block_in, resamp_with_conv)
359
+ curr_res = curr_res * 2
360
+ self.up.insert(0, up) # prepend to get consistent order
361
+
362
+ # end
363
+ self.norm_out = Normalize(block_in)
364
+ self.conv_out = torch.nn.Conv2d(
365
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
366
+ )
367
+
368
+ def forward(self, x, t=None, context=None):
369
+ # assert x.shape[2] == x.shape[3] == self.resolution
370
+ if context is not None:
371
+ # assume aligned context, cat along channel axis
372
+ x = torch.cat((x, context), dim=1)
373
+ if self.use_timestep:
374
+ # timestep embedding
375
+ assert t is not None
376
+ temb = get_timestep_embedding(t, self.ch)
377
+ temb = self.temb.dense[0](temb)
378
+ temb = nonlinearity(temb)
379
+ temb = self.temb.dense[1](temb)
380
+ else:
381
+ temb = None
382
+
383
+ # downsampling
384
+ hs = [self.conv_in(x)]
385
+ for i_level in range(self.num_resolutions):
386
+ for i_block in range(self.num_res_blocks):
387
+ h = self.down[i_level].block[i_block](hs[-1], temb)
388
+ if len(self.down[i_level].attn) > 0:
389
+ h = self.down[i_level].attn[i_block](h)
390
+ hs.append(h)
391
+ if i_level != self.num_resolutions - 1:
392
+ hs.append(self.down[i_level].downsample(hs[-1]))
393
+
394
+ # middle
395
+ h = hs[-1]
396
+ h = self.mid.block_1(h, temb)
397
+ h = self.mid.attn_1(h)
398
+ h = self.mid.block_2(h, temb)
399
+
400
+ # upsampling
401
+ for i_level in reversed(range(self.num_resolutions)):
402
+ for i_block in range(self.num_res_blocks + 1):
403
+ h = self.up[i_level].block[i_block](
404
+ torch.cat([h, hs.pop()], dim=1), temb
405
+ )
406
+ if len(self.up[i_level].attn) > 0:
407
+ h = self.up[i_level].attn[i_block](h)
408
+ if i_level != 0:
409
+ h = self.up[i_level].upsample(h)
410
+
411
+ # end
412
+ h = self.norm_out(h)
413
+ h = nonlinearity(h)
414
+ h = self.conv_out(h)
415
+ return h
416
+
417
+ def get_last_layer(self):
418
+ return self.conv_out.weight
419
+
420
+
421
+ class Encoder(nn.Module):
422
+ def __init__(
423
+ self,
424
+ *,
425
+ ch,
426
+ out_ch,
427
+ ch_mult=(1, 2, 4, 8),
428
+ num_res_blocks,
429
+ attn_resolutions,
430
+ dropout=0.0,
431
+ resamp_with_conv=True,
432
+ in_channels,
433
+ resolution,
434
+ z_channels,
435
+ double_z=True,
436
+ use_linear_attn=False,
437
+ attn_type="vanilla",
438
+ **ignore_kwargs,
439
+ ):
440
+ super().__init__()
441
+ if use_linear_attn:
442
+ attn_type = "linear"
443
+ self.ch = ch
444
+ self.temb_ch = 0
445
+ self.num_resolutions = len(ch_mult)
446
+ self.num_res_blocks = num_res_blocks
447
+ self.resolution = resolution
448
+ self.in_channels = in_channels
449
+
450
+ # downsampling
451
+ self.conv_in = torch.nn.Conv2d(
452
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
453
+ )
454
+
455
+ curr_res = resolution
456
+ in_ch_mult = (1,) + tuple(ch_mult)
457
+ self.in_ch_mult = in_ch_mult
458
+ self.down = nn.ModuleList()
459
+ for i_level in range(self.num_resolutions):
460
+ block = nn.ModuleList()
461
+ attn = nn.ModuleList()
462
+ block_in = ch * in_ch_mult[i_level]
463
+ block_out = ch * ch_mult[i_level]
464
+ for i_block in range(self.num_res_blocks):
465
+ block.append(
466
+ ResnetBlock(
467
+ in_channels=block_in,
468
+ out_channels=block_out,
469
+ temb_channels=self.temb_ch,
470
+ dropout=dropout,
471
+ )
472
+ )
473
+ block_in = block_out
474
+ if curr_res in attn_resolutions:
475
+ attn.append(make_attn(block_in, attn_type=attn_type))
476
+ down = nn.Module()
477
+ down.block = block
478
+ down.attn = attn
479
+ if i_level != self.num_resolutions - 1:
480
+ down.downsample = Downsample(block_in, resamp_with_conv)
481
+ curr_res = curr_res // 2
482
+ self.down.append(down)
483
+
484
+ # middle
485
+ self.mid = nn.Module()
486
+ self.mid.block_1 = ResnetBlock(
487
+ in_channels=block_in,
488
+ out_channels=block_in,
489
+ temb_channels=self.temb_ch,
490
+ dropout=dropout,
491
+ )
492
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
493
+ self.mid.block_2 = ResnetBlock(
494
+ in_channels=block_in,
495
+ out_channels=block_in,
496
+ temb_channels=self.temb_ch,
497
+ dropout=dropout,
498
+ )
499
+
500
+ # end
501
+ self.norm_out = Normalize(block_in)
502
+ self.conv_out = torch.nn.Conv2d(
503
+ block_in,
504
+ 2 * z_channels if double_z else z_channels,
505
+ kernel_size=3,
506
+ stride=1,
507
+ padding=1,
508
+ )
509
+
510
+ def forward(self, x):
511
+ # timestep embedding
512
+ temb = None
513
+
514
+ # downsampling
515
+ hs = [self.conv_in(x)]
516
+ for i_level in range(self.num_resolutions):
517
+ for i_block in range(self.num_res_blocks):
518
+ h = self.down[i_level].block[i_block](hs[-1], temb)
519
+ if len(self.down[i_level].attn) > 0:
520
+ h = self.down[i_level].attn[i_block](h)
521
+ hs.append(h)
522
+ if i_level != self.num_resolutions - 1:
523
+ hs.append(self.down[i_level].downsample(hs[-1]))
524
+
525
+ # middle
526
+ h = hs[-1]
527
+ h = self.mid.block_1(h, temb)
528
+ h = self.mid.attn_1(h)
529
+ h = self.mid.block_2(h, temb)
530
+
531
+ # end
532
+ h = self.norm_out(h)
533
+ h = nonlinearity(h)
534
+ h = self.conv_out(h)
535
+ return h
536
+
537
+
538
+ class Decoder(nn.Module):
539
+ def __init__(
540
+ self,
541
+ *,
542
+ ch,
543
+ out_ch,
544
+ ch_mult=(1, 2, 4, 8),
545
+ num_res_blocks,
546
+ attn_resolutions,
547
+ dropout=0.0,
548
+ resamp_with_conv=True,
549
+ in_channels,
550
+ resolution,
551
+ z_channels,
552
+ give_pre_end=False,
553
+ tanh_out=False,
554
+ use_linear_attn=False,
555
+ attn_type="vanilla",
556
+ **ignorekwargs,
557
+ ):
558
+ super().__init__()
559
+ if use_linear_attn:
560
+ attn_type = "linear"
561
+ self.ch = ch
562
+ self.temb_ch = 0
563
+ self.num_resolutions = len(ch_mult)
564
+ self.num_res_blocks = num_res_blocks
565
+ self.resolution = resolution
566
+ self.in_channels = in_channels
567
+ self.give_pre_end = give_pre_end
568
+ self.tanh_out = tanh_out
569
+
570
+ # compute in_ch_mult, block_in and curr_res at lowest res
571
+ in_ch_mult = (1,) + tuple(ch_mult)
572
+ block_in = ch * ch_mult[self.num_resolutions - 1]
573
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
574
+ self.z_shape = (1, z_channels, curr_res, curr_res)
575
+ print(
576
+ "Working with z of shape {} = {} dimensions.".format(
577
+ self.z_shape, np.prod(self.z_shape)
578
+ )
579
+ )
580
+
581
+ # z to block_in
582
+ self.conv_in = torch.nn.Conv2d(
583
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
584
+ )
585
+
586
+ # middle
587
+ self.mid = nn.Module()
588
+ self.mid.block_1 = ResnetBlock(
589
+ in_channels=block_in,
590
+ out_channels=block_in,
591
+ temb_channels=self.temb_ch,
592
+ dropout=dropout,
593
+ )
594
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
595
+ self.mid.block_2 = ResnetBlock(
596
+ in_channels=block_in,
597
+ out_channels=block_in,
598
+ temb_channels=self.temb_ch,
599
+ dropout=dropout,
600
+ )
601
+
602
+ # upsampling
603
+ self.up = nn.ModuleList()
604
+ for i_level in reversed(range(self.num_resolutions)):
605
+ block = nn.ModuleList()
606
+ attn = nn.ModuleList()
607
+ block_out = ch * ch_mult[i_level]
608
+ for i_block in range(self.num_res_blocks + 1):
609
+ block.append(
610
+ ResnetBlock(
611
+ in_channels=block_in,
612
+ out_channels=block_out,
613
+ temb_channels=self.temb_ch,
614
+ dropout=dropout,
615
+ )
616
+ )
617
+ block_in = block_out
618
+ if curr_res in attn_resolutions:
619
+ attn.append(make_attn(block_in, attn_type=attn_type))
620
+ up = nn.Module()
621
+ up.block = block
622
+ up.attn = attn
623
+ if i_level != 0:
624
+ up.upsample = Upsample(block_in, resamp_with_conv)
625
+ curr_res = curr_res * 2
626
+ self.up.insert(0, up) # prepend to get consistent order
627
+
628
+ # end
629
+ self.norm_out = Normalize(block_in)
630
+ self.conv_out = torch.nn.Conv2d(
631
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
632
+ )
633
+
634
+ def forward(self, z):
635
+ # assert z.shape[1:] == self.z_shape[1:]
636
+ self.last_z_shape = z.shape
637
+
638
+ # timestep embedding
639
+ temb = None
640
+
641
+ # z to block_in
642
+ h = self.conv_in(z)
643
+
644
+ # middle
645
+ h = self.mid.block_1(h, temb)
646
+ h = self.mid.attn_1(h)
647
+ h = self.mid.block_2(h, temb)
648
+
649
+ # upsampling
650
+ for i_level in reversed(range(self.num_resolutions)):
651
+ for i_block in range(self.num_res_blocks + 1):
652
+ h = self.up[i_level].block[i_block](h, temb)
653
+ if len(self.up[i_level].attn) > 0:
654
+ h = self.up[i_level].attn[i_block](h)
655
+ if i_level != 0:
656
+ h = self.up[i_level].upsample(h)
657
+
658
+ # end
659
+ if self.give_pre_end:
660
+ return h
661
+
662
+ h = self.norm_out(h)
663
+ h = nonlinearity(h)
664
+ h = self.conv_out(h)
665
+ if self.tanh_out:
666
+ h = torch.tanh(h)
667
+ return h
668
+
669
+
670
+ class SimpleDecoder(nn.Module):
671
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
672
+ super().__init__()
673
+ self.model = nn.ModuleList(
674
+ [
675
+ nn.Conv2d(in_channels, in_channels, 1),
676
+ ResnetBlock(
677
+ in_channels=in_channels,
678
+ out_channels=2 * in_channels,
679
+ temb_channels=0,
680
+ dropout=0.0,
681
+ ),
682
+ ResnetBlock(
683
+ in_channels=2 * in_channels,
684
+ out_channels=4 * in_channels,
685
+ temb_channels=0,
686
+ dropout=0.0,
687
+ ),
688
+ ResnetBlock(
689
+ in_channels=4 * in_channels,
690
+ out_channels=2 * in_channels,
691
+ temb_channels=0,
692
+ dropout=0.0,
693
+ ),
694
+ nn.Conv2d(2 * in_channels, in_channels, 1),
695
+ Upsample(in_channels, with_conv=True),
696
+ ]
697
+ )
698
+ # end
699
+ self.norm_out = Normalize(in_channels)
700
+ self.conv_out = torch.nn.Conv2d(
701
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
702
+ )
703
+
704
+ def forward(self, x):
705
+ for i, layer in enumerate(self.model):
706
+ if i in [1, 2, 3]:
707
+ x = layer(x, None)
708
+ else:
709
+ x = layer(x)
710
+
711
+ h = self.norm_out(x)
712
+ h = nonlinearity(h)
713
+ x = self.conv_out(h)
714
+ return x
715
+
716
+
717
+ class UpsampleDecoder(nn.Module):
718
+ def __init__(
719
+ self,
720
+ in_channels,
721
+ out_channels,
722
+ ch,
723
+ num_res_blocks,
724
+ resolution,
725
+ ch_mult=(2, 2),
726
+ dropout=0.0,
727
+ ):
728
+ super().__init__()
729
+ # upsampling
730
+ self.temb_ch = 0
731
+ self.num_resolutions = len(ch_mult)
732
+ self.num_res_blocks = num_res_blocks
733
+ block_in = in_channels
734
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
735
+ self.res_blocks = nn.ModuleList()
736
+ self.upsample_blocks = nn.ModuleList()
737
+ for i_level in range(self.num_resolutions):
738
+ res_block = []
739
+ block_out = ch * ch_mult[i_level]
740
+ for i_block in range(self.num_res_blocks + 1):
741
+ res_block.append(
742
+ ResnetBlock(
743
+ in_channels=block_in,
744
+ out_channels=block_out,
745
+ temb_channels=self.temb_ch,
746
+ dropout=dropout,
747
+ )
748
+ )
749
+ block_in = block_out
750
+ self.res_blocks.append(nn.ModuleList(res_block))
751
+ if i_level != self.num_resolutions - 1:
752
+ self.upsample_blocks.append(Upsample(block_in, True))
753
+ curr_res = curr_res * 2
754
+
755
+ # end
756
+ self.norm_out = Normalize(block_in)
757
+ self.conv_out = torch.nn.Conv2d(
758
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
759
+ )
760
+
761
+ def forward(self, x):
762
+ # upsampling
763
+ h = x
764
+ for k, i_level in enumerate(range(self.num_resolutions)):
765
+ for i_block in range(self.num_res_blocks + 1):
766
+ h = self.res_blocks[i_level][i_block](h, None)
767
+ if i_level != self.num_resolutions - 1:
768
+ h = self.upsample_blocks[k](h)
769
+ h = self.norm_out(h)
770
+ h = nonlinearity(h)
771
+ h = self.conv_out(h)
772
+ return h
773
+
774
+
775
+ class LatentRescaler(nn.Module):
776
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
777
+ super().__init__()
778
+ # residual block, interpolate, residual block
779
+ self.factor = factor
780
+ self.conv_in = nn.Conv2d(
781
+ in_channels, mid_channels, kernel_size=3, stride=1, padding=1
782
+ )
783
+ self.res_block1 = nn.ModuleList(
784
+ [
785
+ ResnetBlock(
786
+ in_channels=mid_channels,
787
+ out_channels=mid_channels,
788
+ temb_channels=0,
789
+ dropout=0.0,
790
+ )
791
+ for _ in range(depth)
792
+ ]
793
+ )
794
+ self.attn = AttnBlock(mid_channels)
795
+ self.res_block2 = nn.ModuleList(
796
+ [
797
+ ResnetBlock(
798
+ in_channels=mid_channels,
799
+ out_channels=mid_channels,
800
+ temb_channels=0,
801
+ dropout=0.0,
802
+ )
803
+ for _ in range(depth)
804
+ ]
805
+ )
806
+
807
+ self.conv_out = nn.Conv2d(
808
+ mid_channels,
809
+ out_channels,
810
+ kernel_size=1,
811
+ )
812
+
813
+ def forward(self, x):
814
+ x = self.conv_in(x)
815
+ for block in self.res_block1:
816
+ x = block(x, None)
817
+ x = torch.nn.functional.interpolate(
818
+ x,
819
+ size=(
820
+ int(round(x.shape[2] * self.factor)),
821
+ int(round(x.shape[3] * self.factor)),
822
+ ),
823
+ )
824
+ x = self.attn(x)
825
+ for block in self.res_block2:
826
+ x = block(x, None)
827
+ x = self.conv_out(x)
828
+ return x
829
+
830
+
831
+ class MergedRescaleEncoder(nn.Module):
832
+ def __init__(
833
+ self,
834
+ in_channels,
835
+ ch,
836
+ resolution,
837
+ out_ch,
838
+ num_res_blocks,
839
+ attn_resolutions,
840
+ dropout=0.0,
841
+ resamp_with_conv=True,
842
+ ch_mult=(1, 2, 4, 8),
843
+ rescale_factor=1.0,
844
+ rescale_module_depth=1,
845
+ ):
846
+ super().__init__()
847
+ intermediate_chn = ch * ch_mult[-1]
848
+ self.encoder = Encoder(
849
+ in_channels=in_channels,
850
+ num_res_blocks=num_res_blocks,
851
+ ch=ch,
852
+ ch_mult=ch_mult,
853
+ z_channels=intermediate_chn,
854
+ double_z=False,
855
+ resolution=resolution,
856
+ attn_resolutions=attn_resolutions,
857
+ dropout=dropout,
858
+ resamp_with_conv=resamp_with_conv,
859
+ out_ch=None,
860
+ )
861
+ self.rescaler = LatentRescaler(
862
+ factor=rescale_factor,
863
+ in_channels=intermediate_chn,
864
+ mid_channels=intermediate_chn,
865
+ out_channels=out_ch,
866
+ depth=rescale_module_depth,
867
+ )
868
+
869
+ def forward(self, x):
870
+ x = self.encoder(x)
871
+ x = self.rescaler(x)
872
+ return x
873
+
874
+
875
+ class MergedRescaleDecoder(nn.Module):
876
+ def __init__(
877
+ self,
878
+ z_channels,
879
+ out_ch,
880
+ resolution,
881
+ num_res_blocks,
882
+ attn_resolutions,
883
+ ch,
884
+ ch_mult=(1, 2, 4, 8),
885
+ dropout=0.0,
886
+ resamp_with_conv=True,
887
+ rescale_factor=1.0,
888
+ rescale_module_depth=1,
889
+ ):
890
+ super().__init__()
891
+ tmp_chn = z_channels * ch_mult[-1]
892
+ self.decoder = Decoder(
893
+ out_ch=out_ch,
894
+ z_channels=tmp_chn,
895
+ attn_resolutions=attn_resolutions,
896
+ dropout=dropout,
897
+ resamp_with_conv=resamp_with_conv,
898
+ in_channels=None,
899
+ num_res_blocks=num_res_blocks,
900
+ ch_mult=ch_mult,
901
+ resolution=resolution,
902
+ ch=ch,
903
+ )
904
+ self.rescaler = LatentRescaler(
905
+ factor=rescale_factor,
906
+ in_channels=z_channels,
907
+ mid_channels=tmp_chn,
908
+ out_channels=tmp_chn,
909
+ depth=rescale_module_depth,
910
+ )
911
+
912
+ def forward(self, x):
913
+ x = self.rescaler(x)
914
+ x = self.decoder(x)
915
+ return x
916
+
917
+
918
+ class Upsampler(nn.Module):
919
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
920
+ super().__init__()
921
+ assert out_size >= in_size
922
+ num_blocks = int(np.log2(out_size // in_size)) + 1
923
+ factor_up = 1.0 + (out_size % in_size)
924
+ print(
925
+ f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
926
+ )
927
+ self.rescaler = LatentRescaler(
928
+ factor=factor_up,
929
+ in_channels=in_channels,
930
+ mid_channels=2 * in_channels,
931
+ out_channels=in_channels,
932
+ )
933
+ self.decoder = Decoder(
934
+ out_ch=out_channels,
935
+ resolution=out_size,
936
+ z_channels=in_channels,
937
+ num_res_blocks=2,
938
+ attn_resolutions=[],
939
+ in_channels=None,
940
+ ch=in_channels,
941
+ ch_mult=[ch_mult for _ in range(num_blocks)],
942
+ )
943
+
944
+ def forward(self, x):
945
+ x = self.rescaler(x)
946
+ x = self.decoder(x)
947
+ return x
948
+
949
+
950
+ class Resize(nn.Module):
951
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
952
+ super().__init__()
953
+ self.with_conv = learned
954
+ self.mode = mode
955
+ if self.with_conv:
956
+ print(
957
+ f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
958
+ )
959
+ raise NotImplementedError()
960
+ assert in_channels is not None
961
+ # no asymmetric padding in torch conv, must do it ourselves
962
+ self.conv = torch.nn.Conv2d(
963
+ in_channels, in_channels, kernel_size=4, stride=2, padding=1
964
+ )
965
+
966
+ def forward(self, x, scale_factor=1.0):
967
+ if scale_factor == 1.0:
968
+ return x
969
+ else:
970
+ x = torch.nn.functional.interpolate(
971
+ x, mode=self.mode, align_corners=False, scale_factor=scale_factor
972
+ )
973
+ return x
iopaint/model/anytext/ldm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,786 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch as th
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
10
+ checkpoint,
11
+ conv_nd,
12
+ linear,
13
+ avg_pool_nd,
14
+ zero_module,
15
+ normalization,
16
+ timestep_embedding,
17
+ )
18
+ from iopaint.model.anytext.ldm.modules.attention import SpatialTransformer
19
+ from iopaint.model.anytext.ldm.util import exists
20
+
21
+
22
+ # dummy replace
23
+ def convert_module_to_f16(x):
24
+ pass
25
+
26
+ def convert_module_to_f32(x):
27
+ pass
28
+
29
+
30
+ ## go
31
+ class AttentionPool2d(nn.Module):
32
+ """
33
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ spacial_dim: int,
39
+ embed_dim: int,
40
+ num_heads_channels: int,
41
+ output_dim: int = None,
42
+ ):
43
+ super().__init__()
44
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
45
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
46
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
47
+ self.num_heads = embed_dim // num_heads_channels
48
+ self.attention = QKVAttention(self.num_heads)
49
+
50
+ def forward(self, x):
51
+ b, c, *_spatial = x.shape
52
+ x = x.reshape(b, c, -1) # NC(HW)
53
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
54
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
55
+ x = self.qkv_proj(x)
56
+ x = self.attention(x)
57
+ x = self.c_proj(x)
58
+ return x[:, :, 0]
59
+
60
+
61
+ class TimestepBlock(nn.Module):
62
+ """
63
+ Any module where forward() takes timestep embeddings as a second argument.
64
+ """
65
+
66
+ @abstractmethod
67
+ def forward(self, x, emb):
68
+ """
69
+ Apply the module to `x` given `emb` timestep embeddings.
70
+ """
71
+
72
+
73
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
74
+ """
75
+ A sequential module that passes timestep embeddings to the children that
76
+ support it as an extra input.
77
+ """
78
+
79
+ def forward(self, x, emb, context=None):
80
+ for layer in self:
81
+ if isinstance(layer, TimestepBlock):
82
+ x = layer(x, emb)
83
+ elif isinstance(layer, SpatialTransformer):
84
+ x = layer(x, context)
85
+ else:
86
+ x = layer(x)
87
+ return x
88
+
89
+
90
+ class Upsample(nn.Module):
91
+ """
92
+ An upsampling layer with an optional convolution.
93
+ :param channels: channels in the inputs and outputs.
94
+ :param use_conv: a bool determining if a convolution is applied.
95
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
96
+ upsampling occurs in the inner-two dimensions.
97
+ """
98
+
99
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
100
+ super().__init__()
101
+ self.channels = channels
102
+ self.out_channels = out_channels or channels
103
+ self.use_conv = use_conv
104
+ self.dims = dims
105
+ if use_conv:
106
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
107
+
108
+ def forward(self, x):
109
+ assert x.shape[1] == self.channels
110
+ if self.dims == 3:
111
+ x = F.interpolate(
112
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
113
+ )
114
+ else:
115
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
116
+ if self.use_conv:
117
+ x = self.conv(x)
118
+ return x
119
+
120
+ class TransposedUpsample(nn.Module):
121
+ 'Learned 2x upsampling without padding'
122
+ def __init__(self, channels, out_channels=None, ks=5):
123
+ super().__init__()
124
+ self.channels = channels
125
+ self.out_channels = out_channels or channels
126
+
127
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
128
+
129
+ def forward(self,x):
130
+ return self.up(x)
131
+
132
+
133
+ class Downsample(nn.Module):
134
+ """
135
+ A downsampling layer with an optional convolution.
136
+ :param channels: channels in the inputs and outputs.
137
+ :param use_conv: a bool determining if a convolution is applied.
138
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
139
+ downsampling occurs in the inner-two dimensions.
140
+ """
141
+
142
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
143
+ super().__init__()
144
+ self.channels = channels
145
+ self.out_channels = out_channels or channels
146
+ self.use_conv = use_conv
147
+ self.dims = dims
148
+ stride = 2 if dims != 3 else (1, 2, 2)
149
+ if use_conv:
150
+ self.op = conv_nd(
151
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
152
+ )
153
+ else:
154
+ assert self.channels == self.out_channels
155
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
156
+
157
+ def forward(self, x):
158
+ assert x.shape[1] == self.channels
159
+ return self.op(x)
160
+
161
+
162
+ class ResBlock(TimestepBlock):
163
+ """
164
+ A residual block that can optionally change the number of channels.
165
+ :param channels: the number of input channels.
166
+ :param emb_channels: the number of timestep embedding channels.
167
+ :param dropout: the rate of dropout.
168
+ :param out_channels: if specified, the number of out channels.
169
+ :param use_conv: if True and out_channels is specified, use a spatial
170
+ convolution instead of a smaller 1x1 convolution to change the
171
+ channels in the skip connection.
172
+ :param dims: determines if the signal is 1D, 2D, or 3D.
173
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
174
+ :param up: if True, use this block for upsampling.
175
+ :param down: if True, use this block for downsampling.
176
+ """
177
+
178
+ def __init__(
179
+ self,
180
+ channels,
181
+ emb_channels,
182
+ dropout,
183
+ out_channels=None,
184
+ use_conv=False,
185
+ use_scale_shift_norm=False,
186
+ dims=2,
187
+ use_checkpoint=False,
188
+ up=False,
189
+ down=False,
190
+ ):
191
+ super().__init__()
192
+ self.channels = channels
193
+ self.emb_channels = emb_channels
194
+ self.dropout = dropout
195
+ self.out_channels = out_channels or channels
196
+ self.use_conv = use_conv
197
+ self.use_checkpoint = use_checkpoint
198
+ self.use_scale_shift_norm = use_scale_shift_norm
199
+
200
+ self.in_layers = nn.Sequential(
201
+ normalization(channels),
202
+ nn.SiLU(),
203
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
204
+ )
205
+
206
+ self.updown = up or down
207
+
208
+ if up:
209
+ self.h_upd = Upsample(channels, False, dims)
210
+ self.x_upd = Upsample(channels, False, dims)
211
+ elif down:
212
+ self.h_upd = Downsample(channels, False, dims)
213
+ self.x_upd = Downsample(channels, False, dims)
214
+ else:
215
+ self.h_upd = self.x_upd = nn.Identity()
216
+
217
+ self.emb_layers = nn.Sequential(
218
+ nn.SiLU(),
219
+ linear(
220
+ emb_channels,
221
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
222
+ ),
223
+ )
224
+ self.out_layers = nn.Sequential(
225
+ normalization(self.out_channels),
226
+ nn.SiLU(),
227
+ nn.Dropout(p=dropout),
228
+ zero_module(
229
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
230
+ ),
231
+ )
232
+
233
+ if self.out_channels == channels:
234
+ self.skip_connection = nn.Identity()
235
+ elif use_conv:
236
+ self.skip_connection = conv_nd(
237
+ dims, channels, self.out_channels, 3, padding=1
238
+ )
239
+ else:
240
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
241
+
242
+ def forward(self, x, emb):
243
+ """
244
+ Apply the block to a Tensor, conditioned on a timestep embedding.
245
+ :param x: an [N x C x ...] Tensor of features.
246
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
247
+ :return: an [N x C x ...] Tensor of outputs.
248
+ """
249
+ return checkpoint(
250
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
251
+ )
252
+
253
+
254
+ def _forward(self, x, emb):
255
+ if self.updown:
256
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
257
+ h = in_rest(x)
258
+ h = self.h_upd(h)
259
+ x = self.x_upd(x)
260
+ h = in_conv(h)
261
+ else:
262
+ h = self.in_layers(x)
263
+ emb_out = self.emb_layers(emb).type(h.dtype)
264
+ while len(emb_out.shape) < len(h.shape):
265
+ emb_out = emb_out[..., None]
266
+ if self.use_scale_shift_norm:
267
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
268
+ scale, shift = th.chunk(emb_out, 2, dim=1)
269
+ h = out_norm(h) * (1 + scale) + shift
270
+ h = out_rest(h)
271
+ else:
272
+ h = h + emb_out
273
+ h = self.out_layers(h)
274
+ return self.skip_connection(x) + h
275
+
276
+
277
+ class AttentionBlock(nn.Module):
278
+ """
279
+ An attention block that allows spatial positions to attend to each other.
280
+ Originally ported from here, but adapted to the N-d case.
281
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
282
+ """
283
+
284
+ def __init__(
285
+ self,
286
+ channels,
287
+ num_heads=1,
288
+ num_head_channels=-1,
289
+ use_checkpoint=False,
290
+ use_new_attention_order=False,
291
+ ):
292
+ super().__init__()
293
+ self.channels = channels
294
+ if num_head_channels == -1:
295
+ self.num_heads = num_heads
296
+ else:
297
+ assert (
298
+ channels % num_head_channels == 0
299
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
300
+ self.num_heads = channels // num_head_channels
301
+ self.use_checkpoint = use_checkpoint
302
+ self.norm = normalization(channels)
303
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
304
+ if use_new_attention_order:
305
+ # split qkv before split heads
306
+ self.attention = QKVAttention(self.num_heads)
307
+ else:
308
+ # split heads before split qkv
309
+ self.attention = QKVAttentionLegacy(self.num_heads)
310
+
311
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
312
+
313
+ def forward(self, x):
314
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
315
+ #return pt_checkpoint(self._forward, x) # pytorch
316
+
317
+ def _forward(self, x):
318
+ b, c, *spatial = x.shape
319
+ x = x.reshape(b, c, -1)
320
+ qkv = self.qkv(self.norm(x))
321
+ h = self.attention(qkv)
322
+ h = self.proj_out(h)
323
+ return (x + h).reshape(b, c, *spatial)
324
+
325
+
326
+ def count_flops_attn(model, _x, y):
327
+ """
328
+ A counter for the `thop` package to count the operations in an
329
+ attention operation.
330
+ Meant to be used like:
331
+ macs, params = thop.profile(
332
+ model,
333
+ inputs=(inputs, timestamps),
334
+ custom_ops={QKVAttention: QKVAttention.count_flops},
335
+ )
336
+ """
337
+ b, c, *spatial = y[0].shape
338
+ num_spatial = int(np.prod(spatial))
339
+ # We perform two matmuls with the same number of ops.
340
+ # The first computes the weight matrix, the second computes
341
+ # the combination of the value vectors.
342
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
343
+ model.total_ops += th.DoubleTensor([matmul_ops])
344
+
345
+
346
+ class QKVAttentionLegacy(nn.Module):
347
+ """
348
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
349
+ """
350
+
351
+ def __init__(self, n_heads):
352
+ super().__init__()
353
+ self.n_heads = n_heads
354
+
355
+ def forward(self, qkv):
356
+ """
357
+ Apply QKV attention.
358
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
359
+ :return: an [N x (H * C) x T] tensor after attention.
360
+ """
361
+ bs, width, length = qkv.shape
362
+ assert width % (3 * self.n_heads) == 0
363
+ ch = width // (3 * self.n_heads)
364
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
365
+ scale = 1 / math.sqrt(math.sqrt(ch))
366
+ weight = th.einsum(
367
+ "bct,bcs->bts", q * scale, k * scale
368
+ ) # More stable with f16 than dividing afterwards
369
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
370
+ a = th.einsum("bts,bcs->bct", weight, v)
371
+ return a.reshape(bs, -1, length)
372
+
373
+ @staticmethod
374
+ def count_flops(model, _x, y):
375
+ return count_flops_attn(model, _x, y)
376
+
377
+
378
+ class QKVAttention(nn.Module):
379
+ """
380
+ A module which performs QKV attention and splits in a different order.
381
+ """
382
+
383
+ def __init__(self, n_heads):
384
+ super().__init__()
385
+ self.n_heads = n_heads
386
+
387
+ def forward(self, qkv):
388
+ """
389
+ Apply QKV attention.
390
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
391
+ :return: an [N x (H * C) x T] tensor after attention.
392
+ """
393
+ bs, width, length = qkv.shape
394
+ assert width % (3 * self.n_heads) == 0
395
+ ch = width // (3 * self.n_heads)
396
+ q, k, v = qkv.chunk(3, dim=1)
397
+ scale = 1 / math.sqrt(math.sqrt(ch))
398
+ weight = th.einsum(
399
+ "bct,bcs->bts",
400
+ (q * scale).view(bs * self.n_heads, ch, length),
401
+ (k * scale).view(bs * self.n_heads, ch, length),
402
+ ) # More stable with f16 than dividing afterwards
403
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
404
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
405
+ return a.reshape(bs, -1, length)
406
+
407
+ @staticmethod
408
+ def count_flops(model, _x, y):
409
+ return count_flops_attn(model, _x, y)
410
+
411
+
412
+ class UNetModel(nn.Module):
413
+ """
414
+ The full UNet model with attention and timestep embedding.
415
+ :param in_channels: channels in the input Tensor.
416
+ :param model_channels: base channel count for the model.
417
+ :param out_channels: channels in the output Tensor.
418
+ :param num_res_blocks: number of residual blocks per downsample.
419
+ :param attention_resolutions: a collection of downsample rates at which
420
+ attention will take place. May be a set, list, or tuple.
421
+ For example, if this contains 4, then at 4x downsampling, attention
422
+ will be used.
423
+ :param dropout: the dropout probability.
424
+ :param channel_mult: channel multiplier for each level of the UNet.
425
+ :param conv_resample: if True, use learned convolutions for upsampling and
426
+ downsampling.
427
+ :param dims: determines if the signal is 1D, 2D, or 3D.
428
+ :param num_classes: if specified (as an int), then this model will be
429
+ class-conditional with `num_classes` classes.
430
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
431
+ :param num_heads: the number of attention heads in each attention layer.
432
+ :param num_heads_channels: if specified, ignore num_heads and instead use
433
+ a fixed channel width per attention head.
434
+ :param num_heads_upsample: works with num_heads to set a different number
435
+ of heads for upsampling. Deprecated.
436
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
437
+ :param resblock_updown: use residual blocks for up/downsampling.
438
+ :param use_new_attention_order: use a different attention pattern for potentially
439
+ increased efficiency.
440
+ """
441
+
442
+ def __init__(
443
+ self,
444
+ image_size,
445
+ in_channels,
446
+ model_channels,
447
+ out_channels,
448
+ num_res_blocks,
449
+ attention_resolutions,
450
+ dropout=0,
451
+ channel_mult=(1, 2, 4, 8),
452
+ conv_resample=True,
453
+ dims=2,
454
+ num_classes=None,
455
+ use_checkpoint=False,
456
+ use_fp16=False,
457
+ num_heads=-1,
458
+ num_head_channels=-1,
459
+ num_heads_upsample=-1,
460
+ use_scale_shift_norm=False,
461
+ resblock_updown=False,
462
+ use_new_attention_order=False,
463
+ use_spatial_transformer=False, # custom transformer support
464
+ transformer_depth=1, # custom transformer support
465
+ context_dim=None, # custom transformer support
466
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
467
+ legacy=True,
468
+ disable_self_attentions=None,
469
+ num_attention_blocks=None,
470
+ disable_middle_self_attn=False,
471
+ use_linear_in_transformer=False,
472
+ ):
473
+ super().__init__()
474
+ if use_spatial_transformer:
475
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
476
+
477
+ if context_dim is not None:
478
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
479
+ from omegaconf.listconfig import ListConfig
480
+ if type(context_dim) == ListConfig:
481
+ context_dim = list(context_dim)
482
+
483
+ if num_heads_upsample == -1:
484
+ num_heads_upsample = num_heads
485
+
486
+ if num_heads == -1:
487
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
488
+
489
+ if num_head_channels == -1:
490
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
491
+
492
+ self.image_size = image_size
493
+ self.in_channels = in_channels
494
+ self.model_channels = model_channels
495
+ self.out_channels = out_channels
496
+ if isinstance(num_res_blocks, int):
497
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
498
+ else:
499
+ if len(num_res_blocks) != len(channel_mult):
500
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
501
+ "as a list/tuple (per-level) with the same length as channel_mult")
502
+ self.num_res_blocks = num_res_blocks
503
+ if disable_self_attentions is not None:
504
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
505
+ assert len(disable_self_attentions) == len(channel_mult)
506
+ if num_attention_blocks is not None:
507
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
508
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
509
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
510
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
511
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
512
+ f"attention will still not be set.")
513
+ self.use_fp16 = use_fp16
514
+ self.attention_resolutions = attention_resolutions
515
+ self.dropout = dropout
516
+ self.channel_mult = channel_mult
517
+ self.conv_resample = conv_resample
518
+ self.num_classes = num_classes
519
+ self.use_checkpoint = use_checkpoint
520
+ self.dtype = th.float16 if use_fp16 else th.float32
521
+ self.num_heads = num_heads
522
+ self.num_head_channels = num_head_channels
523
+ self.num_heads_upsample = num_heads_upsample
524
+ self.predict_codebook_ids = n_embed is not None
525
+
526
+ time_embed_dim = model_channels * 4
527
+ self.time_embed = nn.Sequential(
528
+ linear(model_channels, time_embed_dim),
529
+ nn.SiLU(),
530
+ linear(time_embed_dim, time_embed_dim),
531
+ )
532
+
533
+ if self.num_classes is not None:
534
+ if isinstance(self.num_classes, int):
535
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
536
+ elif self.num_classes == "continuous":
537
+ print("setting up linear c_adm embedding layer")
538
+ self.label_emb = nn.Linear(1, time_embed_dim)
539
+ else:
540
+ raise ValueError()
541
+
542
+ self.input_blocks = nn.ModuleList(
543
+ [
544
+ TimestepEmbedSequential(
545
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
546
+ )
547
+ ]
548
+ )
549
+ self._feature_size = model_channels
550
+ input_block_chans = [model_channels]
551
+ ch = model_channels
552
+ ds = 1
553
+ for level, mult in enumerate(channel_mult):
554
+ for nr in range(self.num_res_blocks[level]):
555
+ layers = [
556
+ ResBlock(
557
+ ch,
558
+ time_embed_dim,
559
+ dropout,
560
+ out_channels=mult * model_channels,
561
+ dims=dims,
562
+ use_checkpoint=use_checkpoint,
563
+ use_scale_shift_norm=use_scale_shift_norm,
564
+ )
565
+ ]
566
+ ch = mult * model_channels
567
+ if ds in attention_resolutions:
568
+ if num_head_channels == -1:
569
+ dim_head = ch // num_heads
570
+ else:
571
+ num_heads = ch // num_head_channels
572
+ dim_head = num_head_channels
573
+ if legacy:
574
+ #num_heads = 1
575
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
576
+ if exists(disable_self_attentions):
577
+ disabled_sa = disable_self_attentions[level]
578
+ else:
579
+ disabled_sa = False
580
+
581
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
582
+ layers.append(
583
+ AttentionBlock(
584
+ ch,
585
+ use_checkpoint=use_checkpoint,
586
+ num_heads=num_heads,
587
+ num_head_channels=dim_head,
588
+ use_new_attention_order=use_new_attention_order,
589
+ ) if not use_spatial_transformer else SpatialTransformer(
590
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
591
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
592
+ use_checkpoint=use_checkpoint
593
+ )
594
+ )
595
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
596
+ self._feature_size += ch
597
+ input_block_chans.append(ch)
598
+ if level != len(channel_mult) - 1:
599
+ out_ch = ch
600
+ self.input_blocks.append(
601
+ TimestepEmbedSequential(
602
+ ResBlock(
603
+ ch,
604
+ time_embed_dim,
605
+ dropout,
606
+ out_channels=out_ch,
607
+ dims=dims,
608
+ use_checkpoint=use_checkpoint,
609
+ use_scale_shift_norm=use_scale_shift_norm,
610
+ down=True,
611
+ )
612
+ if resblock_updown
613
+ else Downsample(
614
+ ch, conv_resample, dims=dims, out_channels=out_ch
615
+ )
616
+ )
617
+ )
618
+ ch = out_ch
619
+ input_block_chans.append(ch)
620
+ ds *= 2
621
+ self._feature_size += ch
622
+
623
+ if num_head_channels == -1:
624
+ dim_head = ch // num_heads
625
+ else:
626
+ num_heads = ch // num_head_channels
627
+ dim_head = num_head_channels
628
+ if legacy:
629
+ #num_heads = 1
630
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
631
+ self.middle_block = TimestepEmbedSequential(
632
+ ResBlock(
633
+ ch,
634
+ time_embed_dim,
635
+ dropout,
636
+ dims=dims,
637
+ use_checkpoint=use_checkpoint,
638
+ use_scale_shift_norm=use_scale_shift_norm,
639
+ ),
640
+ AttentionBlock(
641
+ ch,
642
+ use_checkpoint=use_checkpoint,
643
+ num_heads=num_heads,
644
+ num_head_channels=dim_head,
645
+ use_new_attention_order=use_new_attention_order,
646
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
647
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
648
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
649
+ use_checkpoint=use_checkpoint
650
+ ),
651
+ ResBlock(
652
+ ch,
653
+ time_embed_dim,
654
+ dropout,
655
+ dims=dims,
656
+ use_checkpoint=use_checkpoint,
657
+ use_scale_shift_norm=use_scale_shift_norm,
658
+ ),
659
+ )
660
+ self._feature_size += ch
661
+
662
+ self.output_blocks = nn.ModuleList([])
663
+ for level, mult in list(enumerate(channel_mult))[::-1]:
664
+ for i in range(self.num_res_blocks[level] + 1):
665
+ ich = input_block_chans.pop()
666
+ layers = [
667
+ ResBlock(
668
+ ch + ich,
669
+ time_embed_dim,
670
+ dropout,
671
+ out_channels=model_channels * mult,
672
+ dims=dims,
673
+ use_checkpoint=use_checkpoint,
674
+ use_scale_shift_norm=use_scale_shift_norm,
675
+ )
676
+ ]
677
+ ch = model_channels * mult
678
+ if ds in attention_resolutions:
679
+ if num_head_channels == -1:
680
+ dim_head = ch // num_heads
681
+ else:
682
+ num_heads = ch // num_head_channels
683
+ dim_head = num_head_channels
684
+ if legacy:
685
+ #num_heads = 1
686
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
687
+ if exists(disable_self_attentions):
688
+ disabled_sa = disable_self_attentions[level]
689
+ else:
690
+ disabled_sa = False
691
+
692
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
693
+ layers.append(
694
+ AttentionBlock(
695
+ ch,
696
+ use_checkpoint=use_checkpoint,
697
+ num_heads=num_heads_upsample,
698
+ num_head_channels=dim_head,
699
+ use_new_attention_order=use_new_attention_order,
700
+ ) if not use_spatial_transformer else SpatialTransformer(
701
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
702
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
703
+ use_checkpoint=use_checkpoint
704
+ )
705
+ )
706
+ if level and i == self.num_res_blocks[level]:
707
+ out_ch = ch
708
+ layers.append(
709
+ ResBlock(
710
+ ch,
711
+ time_embed_dim,
712
+ dropout,
713
+ out_channels=out_ch,
714
+ dims=dims,
715
+ use_checkpoint=use_checkpoint,
716
+ use_scale_shift_norm=use_scale_shift_norm,
717
+ up=True,
718
+ )
719
+ if resblock_updown
720
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
721
+ )
722
+ ds //= 2
723
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
724
+ self._feature_size += ch
725
+
726
+ self.out = nn.Sequential(
727
+ normalization(ch),
728
+ nn.SiLU(),
729
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
730
+ )
731
+ if self.predict_codebook_ids:
732
+ self.id_predictor = nn.Sequential(
733
+ normalization(ch),
734
+ conv_nd(dims, model_channels, n_embed, 1),
735
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
736
+ )
737
+
738
+ def convert_to_fp16(self):
739
+ """
740
+ Convert the torso of the model to float16.
741
+ """
742
+ self.input_blocks.apply(convert_module_to_f16)
743
+ self.middle_block.apply(convert_module_to_f16)
744
+ self.output_blocks.apply(convert_module_to_f16)
745
+
746
+ def convert_to_fp32(self):
747
+ """
748
+ Convert the torso of the model to float32.
749
+ """
750
+ self.input_blocks.apply(convert_module_to_f32)
751
+ self.middle_block.apply(convert_module_to_f32)
752
+ self.output_blocks.apply(convert_module_to_f32)
753
+
754
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
755
+ """
756
+ Apply the model to an input batch.
757
+ :param x: an [N x C x ...] Tensor of inputs.
758
+ :param timesteps: a 1-D batch of timesteps.
759
+ :param context: conditioning plugged in via crossattn
760
+ :param y: an [N] Tensor of labels, if class-conditional.
761
+ :return: an [N x C x ...] Tensor of outputs.
762
+ """
763
+ assert (y is not None) == (
764
+ self.num_classes is not None
765
+ ), "must specify y if and only if the model is class-conditional"
766
+ hs = []
767
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
768
+ emb = self.time_embed(t_emb)
769
+
770
+ if self.num_classes is not None:
771
+ assert y.shape[0] == x.shape[0]
772
+ emb = emb + self.label_emb(y)
773
+
774
+ h = x.type(self.dtype)
775
+ for module in self.input_blocks:
776
+ h = module(h, emb, context)
777
+ hs.append(h)
778
+ h = self.middle_block(h, emb, context)
779
+ for module in self.output_blocks:
780
+ h = th.cat([h, hs.pop()], dim=1)
781
+ h = module(h, emb, context)
782
+ h = h.type(x.dtype)
783
+ if self.predict_codebook_ids:
784
+ return self.id_predictor(h)
785
+ else:
786
+ return self.out(h)
iopaint/model/anytext/ldm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34
+
35
+ def sample(self):
36
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37
+ return x
38
+
39
+ def kl(self, other=None):
40
+ if self.deterministic:
41
+ return torch.Tensor([0.])
42
+ else:
43
+ if other is None:
44
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
45
+ + self.var - 1.0 - self.logvar,
46
+ dim=[1, 2, 3])
47
+ else:
48
+ return 0.5 * torch.sum(
49
+ torch.pow(self.mean - other.mean, 2) / other.var
50
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
51
+ dim=[1, 2, 3])
52
+
53
+ def nll(self, sample, dims=[1,2,3]):
54
+ if self.deterministic:
55
+ return torch.Tensor([0.])
56
+ logtwopi = np.log(2.0 * np.pi)
57
+ return 0.5 * torch.sum(
58
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59
+ dim=dims)
60
+
61
+ def mode(self):
62
+ return self.mean
63
+
64
+
65
+ def normal_kl(mean1, logvar1, mean2, logvar2):
66
+ """
67
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68
+ Compute the KL divergence between two gaussians.
69
+ Shapes are automatically broadcasted, so batches can be compared to
70
+ scalars, among other use cases.
71
+ """
72
+ tensor = None
73
+ for obj in (mean1, logvar1, mean2, logvar2):
74
+ if isinstance(obj, torch.Tensor):
75
+ tensor = obj
76
+ break
77
+ assert tensor is not None, "at least one argument must be a Tensor"
78
+
79
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
80
+ # Tensors, but it does not work for torch.exp().
81
+ logvar1, logvar2 = [
82
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83
+ for x in (logvar1, logvar2)
84
+ ]
85
+
86
+ return 0.5 * (
87
+ -1.0
88
+ + logvar2
89
+ - logvar1
90
+ + torch.exp(logvar1 - logvar2)
91
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92
+ )
iopaint/model/anytext/ldm/modules/ema.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LitEma(nn.Module):
6
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
7
+ super().__init__()
8
+ if decay < 0.0 or decay > 1.0:
9
+ raise ValueError('Decay must be between 0 and 1')
10
+
11
+ self.m_name2s_name = {}
12
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
14
+ else torch.tensor(-1, dtype=torch.int))
15
+
16
+ for name, p in model.named_parameters():
17
+ if p.requires_grad:
18
+ # remove as '.'-character is not allowed in buffers
19
+ s_name = name.replace('.', '')
20
+ self.m_name2s_name.update({name: s_name})
21
+ self.register_buffer(s_name, p.clone().detach().data)
22
+
23
+ self.collected_params = []
24
+
25
+ def reset_num_updates(self):
26
+ del self.num_updates
27
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
28
+
29
+ def forward(self, model):
30
+ decay = self.decay
31
+
32
+ if self.num_updates >= 0:
33
+ self.num_updates += 1
34
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35
+
36
+ one_minus_decay = 1.0 - decay
37
+
38
+ with torch.no_grad():
39
+ m_param = dict(model.named_parameters())
40
+ shadow_params = dict(self.named_buffers())
41
+
42
+ for key in m_param:
43
+ if m_param[key].requires_grad:
44
+ sname = self.m_name2s_name[key]
45
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
46
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
47
+ else:
48
+ assert not key in self.m_name2s_name
49
+
50
+ def copy_to(self, model):
51
+ m_param = dict(model.named_parameters())
52
+ shadow_params = dict(self.named_buffers())
53
+ for key in m_param:
54
+ if m_param[key].requires_grad:
55
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
56
+ else:
57
+ assert not key in self.m_name2s_name
58
+
59
+ def store(self, parameters):
60
+ """
61
+ Save the current parameters for restoring later.
62
+ Args:
63
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
64
+ temporarily stored.
65
+ """
66
+ self.collected_params = [param.clone() for param in parameters]
67
+
68
+ def restore(self, parameters):
69
+ """
70
+ Restore the parameters stored with the `store` method.
71
+ Useful to validate the model with EMA parameters without affecting the
72
+ original optimization process. Store the parameters before the
73
+ `copy_to` method. After validation (or model saving), use this to
74
+ restore the former parameters.
75
+ Args:
76
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
77
+ updated with the stored parameters.
78
+ """
79
+ for c_param, param in zip(self.collected_params, parameters):
80
+ param.data.copy_(c_param.data)
iopaint/model/anytext/ldm/modules/encoders/modules.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ from transformers import (
6
+ T5Tokenizer,
7
+ T5EncoderModel,
8
+ CLIPTokenizer,
9
+ CLIPTextModel,
10
+ AutoProcessor,
11
+ CLIPVisionModelWithProjection,
12
+ )
13
+
14
+ from iopaint.model.anytext.ldm.util import count_params
15
+
16
+
17
+ def _expand_mask(mask, dtype, tgt_len=None):
18
+ """
19
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
20
+ """
21
+ bsz, src_len = mask.size()
22
+ tgt_len = tgt_len if tgt_len is not None else src_len
23
+
24
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
25
+
26
+ inverted_mask = 1.0 - expanded_mask
27
+
28
+ return inverted_mask.masked_fill(
29
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
30
+ )
31
+
32
+
33
+ def _build_causal_attention_mask(bsz, seq_len, dtype):
34
+ # lazily create causal attention mask, with full attention between the vision tokens
35
+ # pytorch uses additive attention mask; fill with -inf
36
+ mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
37
+ mask.fill_(torch.tensor(torch.finfo(dtype).min))
38
+ mask.triu_(1) # zero out the lower diagonal
39
+ mask = mask.unsqueeze(1) # expand mask
40
+ return mask
41
+
42
+
43
+ class AbstractEncoder(nn.Module):
44
+ def __init__(self):
45
+ super().__init__()
46
+
47
+ def encode(self, *args, **kwargs):
48
+ raise NotImplementedError
49
+
50
+
51
+ class IdentityEncoder(AbstractEncoder):
52
+ def encode(self, x):
53
+ return x
54
+
55
+
56
+ class ClassEmbedder(nn.Module):
57
+ def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
58
+ super().__init__()
59
+ self.key = key
60
+ self.embedding = nn.Embedding(n_classes, embed_dim)
61
+ self.n_classes = n_classes
62
+ self.ucg_rate = ucg_rate
63
+
64
+ def forward(self, batch, key=None, disable_dropout=False):
65
+ if key is None:
66
+ key = self.key
67
+ # this is for use in crossattn
68
+ c = batch[key][:, None]
69
+ if self.ucg_rate > 0.0 and not disable_dropout:
70
+ mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
71
+ c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
72
+ c = c.long()
73
+ c = self.embedding(c)
74
+ return c
75
+
76
+ def get_unconditional_conditioning(self, bs, device="cuda"):
77
+ uc_class = (
78
+ self.n_classes - 1
79
+ ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
80
+ uc = torch.ones((bs,), device=device) * uc_class
81
+ uc = {self.key: uc}
82
+ return uc
83
+
84
+
85
+ def disabled_train(self, mode=True):
86
+ """Overwrite model.train with this function to make sure train/eval mode
87
+ does not change anymore."""
88
+ return self
89
+
90
+
91
+ class FrozenT5Embedder(AbstractEncoder):
92
+ """Uses the T5 transformer encoder for text"""
93
+
94
+ def __init__(
95
+ self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
96
+ ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
97
+ super().__init__()
98
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
99
+ self.transformer = T5EncoderModel.from_pretrained(version)
100
+ self.device = device
101
+ self.max_length = max_length # TODO: typical value?
102
+ if freeze:
103
+ self.freeze()
104
+
105
+ def freeze(self):
106
+ self.transformer = self.transformer.eval()
107
+ # self.train = disabled_train
108
+ for param in self.parameters():
109
+ param.requires_grad = False
110
+
111
+ def forward(self, text):
112
+ batch_encoding = self.tokenizer(
113
+ text,
114
+ truncation=True,
115
+ max_length=self.max_length,
116
+ return_length=True,
117
+ return_overflowing_tokens=False,
118
+ padding="max_length",
119
+ return_tensors="pt",
120
+ )
121
+ tokens = batch_encoding["input_ids"].to(self.device)
122
+ outputs = self.transformer(input_ids=tokens)
123
+
124
+ z = outputs.last_hidden_state
125
+ return z
126
+
127
+ def encode(self, text):
128
+ return self(text)
129
+
130
+
131
+ class FrozenCLIPEmbedder(AbstractEncoder):
132
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
133
+
134
+ LAYERS = ["last", "pooled", "hidden"]
135
+
136
+ def __init__(
137
+ self,
138
+ version="openai/clip-vit-large-patch14",
139
+ device="cuda",
140
+ max_length=77,
141
+ freeze=True,
142
+ layer="last",
143
+ layer_idx=None,
144
+ ): # clip-vit-base-patch32
145
+ super().__init__()
146
+ assert layer in self.LAYERS
147
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
148
+ self.transformer = CLIPTextModel.from_pretrained(version)
149
+ self.device = device
150
+ self.max_length = max_length
151
+ if freeze:
152
+ self.freeze()
153
+ self.layer = layer
154
+ self.layer_idx = layer_idx
155
+ if layer == "hidden":
156
+ assert layer_idx is not None
157
+ assert 0 <= abs(layer_idx) <= 12
158
+
159
+ def freeze(self):
160
+ self.transformer = self.transformer.eval()
161
+ # self.train = disabled_train
162
+ for param in self.parameters():
163
+ param.requires_grad = False
164
+
165
+ def forward(self, text):
166
+ batch_encoding = self.tokenizer(
167
+ text,
168
+ truncation=True,
169
+ max_length=self.max_length,
170
+ return_length=True,
171
+ return_overflowing_tokens=False,
172
+ padding="max_length",
173
+ return_tensors="pt",
174
+ )
175
+ tokens = batch_encoding["input_ids"].to(self.device)
176
+ outputs = self.transformer(
177
+ input_ids=tokens, output_hidden_states=self.layer == "hidden"
178
+ )
179
+ if self.layer == "last":
180
+ z = outputs.last_hidden_state
181
+ elif self.layer == "pooled":
182
+ z = outputs.pooler_output[:, None, :]
183
+ else:
184
+ z = outputs.hidden_states[self.layer_idx]
185
+ return z
186
+
187
+ def encode(self, text):
188
+ return self(text)
189
+
190
+
191
+ class FrozenCLIPT5Encoder(AbstractEncoder):
192
+ def __init__(
193
+ self,
194
+ clip_version="openai/clip-vit-large-patch14",
195
+ t5_version="google/t5-v1_1-xl",
196
+ device="cuda",
197
+ clip_max_length=77,
198
+ t5_max_length=77,
199
+ ):
200
+ super().__init__()
201
+ self.clip_encoder = FrozenCLIPEmbedder(
202
+ clip_version, device, max_length=clip_max_length
203
+ )
204
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
205
+ print(
206
+ f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
207
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params."
208
+ )
209
+
210
+ def encode(self, text):
211
+ return self(text)
212
+
213
+ def forward(self, text):
214
+ clip_z = self.clip_encoder.encode(text)
215
+ t5_z = self.t5_encoder.encode(text)
216
+ return [clip_z, t5_z]
217
+
218
+
219
+ class FrozenCLIPEmbedderT3(AbstractEncoder):
220
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
221
+
222
+ def __init__(
223
+ self,
224
+ version="openai/clip-vit-large-patch14",
225
+ device="cuda",
226
+ max_length=77,
227
+ freeze=True,
228
+ use_vision=False,
229
+ ):
230
+ super().__init__()
231
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
232
+ self.transformer = CLIPTextModel.from_pretrained(version)
233
+ if use_vision:
234
+ self.vit = CLIPVisionModelWithProjection.from_pretrained(version)
235
+ self.processor = AutoProcessor.from_pretrained(version)
236
+ self.device = device
237
+ self.max_length = max_length
238
+ if freeze:
239
+ self.freeze()
240
+
241
+ def embedding_forward(
242
+ self,
243
+ input_ids=None,
244
+ position_ids=None,
245
+ inputs_embeds=None,
246
+ embedding_manager=None,
247
+ ):
248
+ seq_length = (
249
+ input_ids.shape[-1]
250
+ if input_ids is not None
251
+ else inputs_embeds.shape[-2]
252
+ )
253
+ if position_ids is None:
254
+ position_ids = self.position_ids[:, :seq_length]
255
+ if inputs_embeds is None:
256
+ inputs_embeds = self.token_embedding(input_ids)
257
+ if embedding_manager is not None:
258
+ inputs_embeds = embedding_manager(input_ids, inputs_embeds)
259
+ position_embeddings = self.position_embedding(position_ids)
260
+ embeddings = inputs_embeds + position_embeddings
261
+ return embeddings
262
+
263
+ self.transformer.text_model.embeddings.forward = embedding_forward.__get__(
264
+ self.transformer.text_model.embeddings
265
+ )
266
+
267
+ def encoder_forward(
268
+ self,
269
+ inputs_embeds,
270
+ attention_mask=None,
271
+ causal_attention_mask=None,
272
+ output_attentions=None,
273
+ output_hidden_states=None,
274
+ return_dict=None,
275
+ ):
276
+ output_attentions = (
277
+ output_attentions
278
+ if output_attentions is not None
279
+ else self.config.output_attentions
280
+ )
281
+ output_hidden_states = (
282
+ output_hidden_states
283
+ if output_hidden_states is not None
284
+ else self.config.output_hidden_states
285
+ )
286
+ return_dict = (
287
+ return_dict if return_dict is not None else self.config.use_return_dict
288
+ )
289
+ encoder_states = () if output_hidden_states else None
290
+ all_attentions = () if output_attentions else None
291
+ hidden_states = inputs_embeds
292
+ for idx, encoder_layer in enumerate(self.layers):
293
+ if output_hidden_states:
294
+ encoder_states = encoder_states + (hidden_states,)
295
+ layer_outputs = encoder_layer(
296
+ hidden_states,
297
+ attention_mask,
298
+ causal_attention_mask,
299
+ output_attentions=output_attentions,
300
+ )
301
+ hidden_states = layer_outputs[0]
302
+ if output_attentions:
303
+ all_attentions = all_attentions + (layer_outputs[1],)
304
+ if output_hidden_states:
305
+ encoder_states = encoder_states + (hidden_states,)
306
+ return hidden_states
307
+
308
+ self.transformer.text_model.encoder.forward = encoder_forward.__get__(
309
+ self.transformer.text_model.encoder
310
+ )
311
+
312
+ def text_encoder_forward(
313
+ self,
314
+ input_ids=None,
315
+ attention_mask=None,
316
+ position_ids=None,
317
+ output_attentions=None,
318
+ output_hidden_states=None,
319
+ return_dict=None,
320
+ embedding_manager=None,
321
+ ):
322
+ output_attentions = (
323
+ output_attentions
324
+ if output_attentions is not None
325
+ else self.config.output_attentions
326
+ )
327
+ output_hidden_states = (
328
+ output_hidden_states
329
+ if output_hidden_states is not None
330
+ else self.config.output_hidden_states
331
+ )
332
+ return_dict = (
333
+ return_dict if return_dict is not None else self.config.use_return_dict
334
+ )
335
+ if input_ids is None:
336
+ raise ValueError("You have to specify either input_ids")
337
+ input_shape = input_ids.size()
338
+ input_ids = input_ids.view(-1, input_shape[-1])
339
+ hidden_states = self.embeddings(
340
+ input_ids=input_ids,
341
+ position_ids=position_ids,
342
+ embedding_manager=embedding_manager,
343
+ )
344
+ bsz, seq_len = input_shape
345
+ # CLIP's text model uses causal mask, prepare it here.
346
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
347
+ causal_attention_mask = _build_causal_attention_mask(
348
+ bsz, seq_len, hidden_states.dtype
349
+ ).to(hidden_states.device)
350
+ # expand attention_mask
351
+ if attention_mask is not None:
352
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
353
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
354
+ last_hidden_state = self.encoder(
355
+ inputs_embeds=hidden_states,
356
+ attention_mask=attention_mask,
357
+ causal_attention_mask=causal_attention_mask,
358
+ output_attentions=output_attentions,
359
+ output_hidden_states=output_hidden_states,
360
+ return_dict=return_dict,
361
+ )
362
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
363
+ return last_hidden_state
364
+
365
+ self.transformer.text_model.forward = text_encoder_forward.__get__(
366
+ self.transformer.text_model
367
+ )
368
+
369
+ def transformer_forward(
370
+ self,
371
+ input_ids=None,
372
+ attention_mask=None,
373
+ position_ids=None,
374
+ output_attentions=None,
375
+ output_hidden_states=None,
376
+ return_dict=None,
377
+ embedding_manager=None,
378
+ ):
379
+ return self.text_model(
380
+ input_ids=input_ids,
381
+ attention_mask=attention_mask,
382
+ position_ids=position_ids,
383
+ output_attentions=output_attentions,
384
+ output_hidden_states=output_hidden_states,
385
+ return_dict=return_dict,
386
+ embedding_manager=embedding_manager,
387
+ )
388
+
389
+ self.transformer.forward = transformer_forward.__get__(self.transformer)
390
+
391
+ def freeze(self):
392
+ self.transformer = self.transformer.eval()
393
+ for param in self.parameters():
394
+ param.requires_grad = False
395
+
396
+ def forward(self, text, **kwargs):
397
+ batch_encoding = self.tokenizer(
398
+ text,
399
+ truncation=True,
400
+ max_length=self.max_length,
401
+ return_length=True,
402
+ return_overflowing_tokens=False,
403
+ padding="max_length",
404
+ return_tensors="pt",
405
+ )
406
+ tokens = batch_encoding["input_ids"].to(self.device)
407
+ z = self.transformer(input_ids=tokens, **kwargs)
408
+ return z
409
+
410
+ def encode(self, text, **kwargs):
411
+ return self(text, **kwargs)
iopaint/model/anytext/main.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import os
3
+
4
+ from anytext_pipeline import AnyTextPipeline
5
+ from utils import save_images
6
+
7
+ seed = 66273235
8
+ # seed_everything(seed)
9
+
10
+ pipe = AnyTextPipeline(
11
+ ckpt_path="/Users/cwq/code/github/IOPaint/iopaint/model/anytext/anytext_v1.1_fp16.ckpt",
12
+ font_path="/Users/cwq/code/github/AnyText/anytext/font/SourceHanSansSC-Medium.otf",
13
+ use_fp16=False,
14
+ device="mps",
15
+ )
16
+
17
+ img_save_folder = "SaveImages"
18
+ rgb_image = cv2.imread(
19
+ "/Users/cwq/code/github/AnyText/anytext/example_images/ref7.jpg"
20
+ )[..., ::-1]
21
+
22
+ masked_image = cv2.imread(
23
+ "/Users/cwq/code/github/AnyText/anytext/example_images/edit7.png"
24
+ )[..., ::-1]
25
+
26
+ rgb_image = cv2.resize(rgb_image, (512, 512))
27
+ masked_image = cv2.resize(masked_image, (512, 512))
28
+
29
+ # results: list of rgb ndarray
30
+ results, rtn_code, rtn_warning = pipe(
31
+ prompt='A cake with colorful characters that reads "EVERYDAY", best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks',
32
+ negative_prompt="low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture",
33
+ image=rgb_image,
34
+ masked_image=masked_image,
35
+ num_inference_steps=20,
36
+ strength=1.0,
37
+ guidance_scale=9.0,
38
+ height=rgb_image.shape[0],
39
+ width=rgb_image.shape[1],
40
+ seed=seed,
41
+ sort_priority="y",
42
+ )
43
+ if rtn_code >= 0:
44
+ save_images(results, img_save_folder)
45
+ print(f"Done, result images are saved in: {img_save_folder}")
iopaint/model/anytext/ocr_recog/common.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class Hswish(nn.Module):
9
+ def __init__(self, inplace=True):
10
+ super(Hswish, self).__init__()
11
+ self.inplace = inplace
12
+
13
+ def forward(self, x):
14
+ return x * F.relu6(x + 3., inplace=self.inplace) / 6.
15
+
16
+ # out = max(0, min(1, slop*x+offset))
17
+ # paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)
18
+ class Hsigmoid(nn.Module):
19
+ def __init__(self, inplace=True):
20
+ super(Hsigmoid, self).__init__()
21
+ self.inplace = inplace
22
+
23
+ def forward(self, x):
24
+ # torch: F.relu6(x + 3., inplace=self.inplace) / 6.
25
+ # paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
26
+ return F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
27
+
28
+ class GELU(nn.Module):
29
+ def __init__(self, inplace=True):
30
+ super(GELU, self).__init__()
31
+ self.inplace = inplace
32
+
33
+ def forward(self, x):
34
+ return torch.nn.functional.gelu(x)
35
+
36
+
37
+ class Swish(nn.Module):
38
+ def __init__(self, inplace=True):
39
+ super(Swish, self).__init__()
40
+ self.inplace = inplace
41
+
42
+ def forward(self, x):
43
+ if self.inplace:
44
+ x.mul_(torch.sigmoid(x))
45
+ return x
46
+ else:
47
+ return x*torch.sigmoid(x)
48
+
49
+
50
+ class Activation(nn.Module):
51
+ def __init__(self, act_type, inplace=True):
52
+ super(Activation, self).__init__()
53
+ act_type = act_type.lower()
54
+ if act_type == 'relu':
55
+ self.act = nn.ReLU(inplace=inplace)
56
+ elif act_type == 'relu6':
57
+ self.act = nn.ReLU6(inplace=inplace)
58
+ elif act_type == 'sigmoid':
59
+ raise NotImplementedError
60
+ elif act_type == 'hard_sigmoid':
61
+ self.act = Hsigmoid(inplace)
62
+ elif act_type == 'hard_swish':
63
+ self.act = Hswish(inplace=inplace)
64
+ elif act_type == 'leakyrelu':
65
+ self.act = nn.LeakyReLU(inplace=inplace)
66
+ elif act_type == 'gelu':
67
+ self.act = GELU(inplace=inplace)
68
+ elif act_type == 'swish':
69
+ self.act = Swish(inplace=inplace)
70
+ else:
71
+ raise NotImplementedError
72
+
73
+ def forward(self, inputs):
74
+ return self.act(inputs)
iopaint/model/anytext/ocr_recog/en_dict.txt ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 0
2
+ 1
3
+ 2
4
+ 3
5
+ 4
6
+ 5
7
+ 6
8
+ 7
9
+ 8
10
+ 9
11
+ :
12
+ ;
13
+ <
14
+ =
15
+ >
16
+ ?
17
+ @
18
+ A
19
+ B
20
+ C
21
+ D
22
+ E
23
+ F
24
+ G
25
+ H
26
+ I
27
+ J
28
+ K
29
+ L
30
+ M
31
+ N
32
+ O
33
+ P
34
+ Q
35
+ R
36
+ S
37
+ T
38
+ U
39
+ V
40
+ W
41
+ X
42
+ Y
43
+ Z
44
+ [
45
+ \
46
+ ]
47
+ ^
48
+ _
49
+ `
50
+ a
51
+ b
52
+ c
53
+ d
54
+ e
55
+ f
56
+ g
57
+ h
58
+ i
59
+ j
60
+ k
61
+ l
62
+ m
63
+ n
64
+ o
65
+ p
66
+ q
67
+ r
68
+ s
69
+ t
70
+ u
71
+ v
72
+ w
73
+ x
74
+ y
75
+ z
76
+ {
77
+ |
78
+ }
79
+ ~
80
+ !
81
+ "
82
+ #
83
+ $
84
+ %
85
+ &
86
+ '
87
+ (
88
+ )
89
+ *
90
+ +
91
+ ,
92
+ -
93
+ .
94
+ /
95
+
iopaint/model/controlnet.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL.Image
2
+ import cv2
3
+ import torch
4
+ from diffusers import ControlNetModel
5
+ from loguru import logger
6
+ from iopaint.schema import InpaintRequest, ModelType
7
+
8
+ from .base import DiffusionInpaintModel
9
+ from .helper.controlnet_preprocess import (
10
+ make_canny_control_image,
11
+ make_openpose_control_image,
12
+ make_depth_control_image,
13
+ make_inpaint_control_image,
14
+ )
15
+ from .helper.cpu_text_encoder import CPUTextEncoderWrapper
16
+ from .original_sd_configs import get_config_files
17
+ from .utils import (
18
+ get_scheduler,
19
+ handle_from_pretrained_exceptions,
20
+ get_torch_dtype,
21
+ enable_low_mem,
22
+ is_local_files_only,
23
+ )
24
+
25
+
26
+ class ControlNet(DiffusionInpaintModel):
27
+ name = "controlnet"
28
+ pad_mod = 8
29
+ min_size = 512
30
+
31
+ @property
32
+ def lcm_lora_id(self):
33
+ if self.model_info.model_type in [
34
+ ModelType.DIFFUSERS_SD,
35
+ ModelType.DIFFUSERS_SD_INPAINT,
36
+ ]:
37
+ return "latent-consistency/lcm-lora-sdv1-5"
38
+ if self.model_info.model_type in [
39
+ ModelType.DIFFUSERS_SDXL,
40
+ ModelType.DIFFUSERS_SDXL_INPAINT,
41
+ ]:
42
+ return "latent-consistency/lcm-lora-sdxl"
43
+ raise NotImplementedError(f"Unsupported controlnet lcm model {self.model_info}")
44
+
45
+ def init_model(self, device: torch.device, **kwargs):
46
+ model_info = kwargs["model_info"]
47
+ controlnet_method = kwargs["controlnet_method"]
48
+
49
+ self.model_info = model_info
50
+ self.controlnet_method = controlnet_method
51
+
52
+ model_kwargs = {
53
+ **kwargs.get("pipe_components", {}),
54
+ "local_files_only": is_local_files_only(**kwargs),
55
+ }
56
+ self.local_files_only = model_kwargs["local_files_only"]
57
+
58
+ disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
59
+ "cpu_offload", False
60
+ )
61
+ if disable_nsfw_checker:
62
+ logger.info("Disable Stable Diffusion Model NSFW checker")
63
+ model_kwargs.update(
64
+ dict(
65
+ safety_checker=None,
66
+ feature_extractor=None,
67
+ requires_safety_checker=False,
68
+ )
69
+ )
70
+
71
+ use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
72
+ self.torch_dtype = torch_dtype
73
+
74
+ if model_info.model_type in [
75
+ ModelType.DIFFUSERS_SD,
76
+ ModelType.DIFFUSERS_SD_INPAINT,
77
+ ]:
78
+ from diffusers import (
79
+ StableDiffusionControlNetInpaintPipeline as PipeClass,
80
+ )
81
+ elif model_info.model_type in [
82
+ ModelType.DIFFUSERS_SDXL,
83
+ ModelType.DIFFUSERS_SDXL_INPAINT,
84
+ ]:
85
+ from diffusers import (
86
+ StableDiffusionXLControlNetInpaintPipeline as PipeClass,
87
+ )
88
+
89
+ controlnet = ControlNetModel.from_pretrained(
90
+ pretrained_model_name_or_path=controlnet_method,
91
+ resume_download=True,
92
+ local_files_only=model_kwargs["local_files_only"],
93
+ torch_dtype=self.torch_dtype,
94
+ )
95
+ if model_info.is_single_file_diffusers:
96
+ if self.model_info.model_type == ModelType.DIFFUSERS_SD:
97
+ model_kwargs["num_in_channels"] = 4
98
+ else:
99
+ model_kwargs["num_in_channels"] = 9
100
+
101
+ self.model = PipeClass.from_single_file(
102
+ model_info.path,
103
+ controlnet=controlnet,
104
+ load_safety_checker=not disable_nsfw_checker,
105
+ torch_dtype=torch_dtype,
106
+ config_files=get_config_files(),
107
+ **model_kwargs,
108
+ )
109
+ else:
110
+ self.model = handle_from_pretrained_exceptions(
111
+ PipeClass.from_pretrained,
112
+ pretrained_model_name_or_path=model_info.path,
113
+ controlnet=controlnet,
114
+ variant="fp16",
115
+ torch_dtype=torch_dtype,
116
+ **model_kwargs,
117
+ )
118
+
119
+ enable_low_mem(self.model, kwargs.get("low_mem", False))
120
+
121
+ if kwargs.get("cpu_offload", False) and use_gpu:
122
+ logger.info("Enable sequential cpu offload")
123
+ self.model.enable_sequential_cpu_offload(gpu_id=0)
124
+ else:
125
+ self.model = self.model.to(device)
126
+ if kwargs["sd_cpu_textencoder"]:
127
+ logger.info("Run Stable Diffusion TextEncoder on CPU")
128
+ self.model.text_encoder = CPUTextEncoderWrapper(
129
+ self.model.text_encoder, torch_dtype
130
+ )
131
+
132
+ self.callback = kwargs.pop("callback", None)
133
+
134
+ def switch_controlnet_method(self, new_method: str):
135
+ self.controlnet_method = new_method
136
+ controlnet = ControlNetModel.from_pretrained(
137
+ new_method,
138
+ resume_download=True,
139
+ local_files_only=self.local_files_only,
140
+ torch_dtype=self.torch_dtype,
141
+ ).to(self.model.device)
142
+ self.model.controlnet = controlnet
143
+
144
+ def _get_control_image(self, image, mask):
145
+ if "canny" in self.controlnet_method:
146
+ control_image = make_canny_control_image(image)
147
+ elif "openpose" in self.controlnet_method:
148
+ control_image = make_openpose_control_image(image)
149
+ elif "depth" in self.controlnet_method:
150
+ control_image = make_depth_control_image(image)
151
+ elif "inpaint" in self.controlnet_method:
152
+ control_image = make_inpaint_control_image(image, mask)
153
+ else:
154
+ raise NotImplementedError(f"{self.controlnet_method} not implemented")
155
+ return control_image
156
+
157
+ def forward(self, image, mask, config: InpaintRequest):
158
+ """Input image and output image have same size
159
+ image: [H, W, C] RGB
160
+ mask: [H, W, 1] 255 means area to repaint
161
+ return: BGR IMAGE
162
+ """
163
+ scheduler_config = self.model.scheduler.config
164
+ scheduler = get_scheduler(config.sd_sampler, scheduler_config)
165
+ self.model.scheduler = scheduler
166
+
167
+ img_h, img_w = image.shape[:2]
168
+ control_image = self._get_control_image(image, mask)
169
+ mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
170
+ image = PIL.Image.fromarray(image)
171
+
172
+ output = self.model(
173
+ image=image,
174
+ mask_image=mask_image,
175
+ control_image=control_image,
176
+ prompt=config.prompt,
177
+ negative_prompt=config.negative_prompt,
178
+ num_inference_steps=config.sd_steps,
179
+ guidance_scale=config.sd_guidance_scale,
180
+ output_type="np",
181
+ callback_on_step_end=self.callback,
182
+ height=img_h,
183
+ width=img_w,
184
+ generator=torch.manual_seed(config.sd_seed),
185
+ controlnet_conditioning_scale=config.controlnet_conditioning_scale,
186
+ ).images[0]
187
+
188
+ output = (output * 255).round().astype("uint8")
189
+ output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
190
+ return output
iopaint/model/ddim_sampler.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+
5
+ from .utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
6
+
7
+ from loguru import logger
8
+
9
+
10
+ class DDIMSampler(object):
11
+ def __init__(self, model, schedule="linear"):
12
+ super().__init__()
13
+ self.model = model
14
+ self.ddpm_num_timesteps = model.num_timesteps
15
+ self.schedule = schedule
16
+
17
+ def register_buffer(self, name, attr):
18
+ setattr(self, name, attr)
19
+
20
+ def make_schedule(
21
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
22
+ ):
23
+ self.ddim_timesteps = make_ddim_timesteps(
24
+ ddim_discr_method=ddim_discretize,
25
+ num_ddim_timesteps=ddim_num_steps,
26
+ # array([1])
27
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
28
+ verbose=verbose,
29
+ )
30
+ alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000])
31
+ assert (
32
+ alphas_cumprod.shape[0] == self.ddpm_num_timesteps
33
+ ), "alphas have to be defined for each timestep"
34
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
35
+
36
+ self.register_buffer("betas", to_torch(self.model.betas))
37
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
38
+ self.register_buffer(
39
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
40
+ )
41
+
42
+ # calculations for diffusion q(x_t | x_{t-1}) and others
43
+ self.register_buffer(
44
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
45
+ )
46
+ self.register_buffer(
47
+ "sqrt_one_minus_alphas_cumprod",
48
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
49
+ )
50
+ self.register_buffer(
51
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
52
+ )
53
+ self.register_buffer(
54
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
55
+ )
56
+ self.register_buffer(
57
+ "sqrt_recipm1_alphas_cumprod",
58
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
59
+ )
60
+
61
+ # ddim sampling parameters
62
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
63
+ alphacums=alphas_cumprod.cpu(),
64
+ ddim_timesteps=self.ddim_timesteps,
65
+ eta=ddim_eta,
66
+ verbose=verbose,
67
+ )
68
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
69
+ self.register_buffer("ddim_alphas", ddim_alphas)
70
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
71
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
72
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
73
+ (1 - self.alphas_cumprod_prev)
74
+ / (1 - self.alphas_cumprod)
75
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
76
+ )
77
+ self.register_buffer(
78
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
79
+ )
80
+
81
+ @torch.no_grad()
82
+ def sample(self, steps, conditioning, batch_size, shape):
83
+ self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
84
+ # sampling
85
+ C, H, W = shape
86
+ size = (batch_size, C, H, W)
87
+
88
+ # samples: 1,3,128,128
89
+ return self.ddim_sampling(
90
+ conditioning,
91
+ size,
92
+ quantize_denoised=False,
93
+ ddim_use_original_steps=False,
94
+ noise_dropout=0,
95
+ temperature=1.0,
96
+ )
97
+
98
+ @torch.no_grad()
99
+ def ddim_sampling(
100
+ self,
101
+ cond,
102
+ shape,
103
+ ddim_use_original_steps=False,
104
+ quantize_denoised=False,
105
+ temperature=1.0,
106
+ noise_dropout=0.0,
107
+ ):
108
+ device = self.model.betas.device
109
+ b = shape[0]
110
+ img = torch.randn(shape, device=device, dtype=cond.dtype)
111
+ timesteps = (
112
+ self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
113
+ )
114
+
115
+ time_range = (
116
+ reversed(range(0, timesteps))
117
+ if ddim_use_original_steps
118
+ else np.flip(timesteps)
119
+ )
120
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
121
+ logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
122
+
123
+ iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
124
+
125
+ for i, step in enumerate(iterator):
126
+ index = total_steps - i - 1
127
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
128
+
129
+ outs = self.p_sample_ddim(
130
+ img,
131
+ cond,
132
+ ts,
133
+ index=index,
134
+ use_original_steps=ddim_use_original_steps,
135
+ quantize_denoised=quantize_denoised,
136
+ temperature=temperature,
137
+ noise_dropout=noise_dropout,
138
+ )
139
+ img, _ = outs
140
+
141
+ return img
142
+
143
+ @torch.no_grad()
144
+ def p_sample_ddim(
145
+ self,
146
+ x,
147
+ c,
148
+ t,
149
+ index,
150
+ repeat_noise=False,
151
+ use_original_steps=False,
152
+ quantize_denoised=False,
153
+ temperature=1.0,
154
+ noise_dropout=0.0,
155
+ ):
156
+ b, *_, device = *x.shape, x.device
157
+ e_t = self.model.apply_model(x, t, c)
158
+
159
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
160
+ alphas_prev = (
161
+ self.model.alphas_cumprod_prev
162
+ if use_original_steps
163
+ else self.ddim_alphas_prev
164
+ )
165
+ sqrt_one_minus_alphas = (
166
+ self.model.sqrt_one_minus_alphas_cumprod
167
+ if use_original_steps
168
+ else self.ddim_sqrt_one_minus_alphas
169
+ )
170
+ sigmas = (
171
+ self.model.ddim_sigmas_for_original_num_steps
172
+ if use_original_steps
173
+ else self.ddim_sigmas
174
+ )
175
+ # select parameters corresponding to the currently considered timestep
176
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
177
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
178
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
179
+ sqrt_one_minus_at = torch.full(
180
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
181
+ )
182
+
183
+ # current prediction for x_0
184
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
185
+ if quantize_denoised: # 没用
186
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
187
+ # direction pointing to x_t
188
+ dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
189
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
190
+ if noise_dropout > 0.0: # 没用
191
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
192
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
193
+ return x_prev, pred_x0
iopaint/model/fcf.py ADDED
@@ -0,0 +1,1737 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import cv2
5
+ import torch
6
+ import numpy as np
7
+ import torch.fft as fft
8
+
9
+ from iopaint.schema import InpaintRequest
10
+
11
+ from iopaint.helper import (
12
+ load_model,
13
+ get_cache_path_by_url,
14
+ norm_img,
15
+ boxes_from_mask,
16
+ resize_max_size,
17
+ download_model,
18
+ )
19
+ from .base import InpaintModel
20
+ from torch import conv2d, nn
21
+ import torch.nn.functional as F
22
+
23
+ from .utils import (
24
+ setup_filter,
25
+ _parse_scaling,
26
+ _parse_padding,
27
+ Conv2dLayer,
28
+ FullyConnectedLayer,
29
+ MinibatchStdLayer,
30
+ activation_funcs,
31
+ conv2d_resample,
32
+ bias_act,
33
+ upsample2d,
34
+ normalize_2nd_moment,
35
+ downsample2d,
36
+ )
37
+
38
+
39
+ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"):
40
+ assert isinstance(x, torch.Tensor)
41
+ return _upfirdn2d_ref(
42
+ x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain
43
+ )
44
+
45
+
46
+ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
47
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops."""
48
+ # Validate arguments.
49
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
50
+ if f is None:
51
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
52
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
53
+ assert f.dtype == torch.float32 and not f.requires_grad
54
+ batch_size, num_channels, in_height, in_width = x.shape
55
+ upx, upy = _parse_scaling(up)
56
+ downx, downy = _parse_scaling(down)
57
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
58
+
59
+ # Upsample by inserting zeros.
60
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
61
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
62
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
63
+
64
+ # Pad or crop.
65
+ x = torch.nn.functional.pad(
66
+ x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]
67
+ )
68
+ x = x[
69
+ :,
70
+ :,
71
+ max(-pady0, 0) : x.shape[2] - max(-pady1, 0),
72
+ max(-padx0, 0) : x.shape[3] - max(-padx1, 0),
73
+ ]
74
+
75
+ # Setup filter.
76
+ f = f * (gain ** (f.ndim / 2))
77
+ f = f.to(x.dtype)
78
+ if not flip_filter:
79
+ f = f.flip(list(range(f.ndim)))
80
+
81
+ # Convolve with the filter.
82
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
83
+ if f.ndim == 4:
84
+ x = conv2d(input=x, weight=f, groups=num_channels)
85
+ else:
86
+ x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
87
+ x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
88
+
89
+ # Downsample by throwing away pixels.
90
+ x = x[:, :, ::downy, ::downx]
91
+ return x
92
+
93
+
94
+ class EncoderEpilogue(torch.nn.Module):
95
+ def __init__(
96
+ self,
97
+ in_channels, # Number of input channels.
98
+ cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
99
+ z_dim, # Output Latent (Z) dimensionality.
100
+ resolution, # Resolution of this block.
101
+ img_channels, # Number of input color channels.
102
+ architecture="resnet", # Architecture: 'orig', 'skip', 'resnet'.
103
+ mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
104
+ mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
105
+ activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
106
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
107
+ ):
108
+ assert architecture in ["orig", "skip", "resnet"]
109
+ super().__init__()
110
+ self.in_channels = in_channels
111
+ self.cmap_dim = cmap_dim
112
+ self.resolution = resolution
113
+ self.img_channels = img_channels
114
+ self.architecture = architecture
115
+
116
+ if architecture == "skip":
117
+ self.fromrgb = Conv2dLayer(
118
+ self.img_channels, in_channels, kernel_size=1, activation=activation
119
+ )
120
+ self.mbstd = (
121
+ MinibatchStdLayer(
122
+ group_size=mbstd_group_size, num_channels=mbstd_num_channels
123
+ )
124
+ if mbstd_num_channels > 0
125
+ else None
126
+ )
127
+ self.conv = Conv2dLayer(
128
+ in_channels + mbstd_num_channels,
129
+ in_channels,
130
+ kernel_size=3,
131
+ activation=activation,
132
+ conv_clamp=conv_clamp,
133
+ )
134
+ self.fc = FullyConnectedLayer(
135
+ in_channels * (resolution**2), z_dim, activation=activation
136
+ )
137
+ self.dropout = torch.nn.Dropout(p=0.5)
138
+
139
+ def forward(self, x, cmap, force_fp32=False):
140
+ _ = force_fp32 # unused
141
+ dtype = torch.float32
142
+ memory_format = torch.contiguous_format
143
+
144
+ # FromRGB.
145
+ x = x.to(dtype=dtype, memory_format=memory_format)
146
+
147
+ # Main layers.
148
+ if self.mbstd is not None:
149
+ x = self.mbstd(x)
150
+ const_e = self.conv(x)
151
+ x = self.fc(const_e.flatten(1))
152
+ x = self.dropout(x)
153
+
154
+ # Conditioning.
155
+ if self.cmap_dim > 0:
156
+ x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
157
+
158
+ assert x.dtype == dtype
159
+ return x, const_e
160
+
161
+
162
+ class EncoderBlock(torch.nn.Module):
163
+ def __init__(
164
+ self,
165
+ in_channels, # Number of input channels, 0 = first block.
166
+ tmp_channels, # Number of intermediate channels.
167
+ out_channels, # Number of output channels.
168
+ resolution, # Resolution of this block.
169
+ img_channels, # Number of input color channels.
170
+ first_layer_idx, # Index of the first layer.
171
+ architecture="skip", # Architecture: 'orig', 'skip', 'resnet'.
172
+ activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
173
+ resample_filter=[
174
+ 1,
175
+ 3,
176
+ 3,
177
+ 1,
178
+ ], # Low-pass filter to apply when resampling activations.
179
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
180
+ use_fp16=False, # Use FP16 for this block?
181
+ fp16_channels_last=False, # Use channels-last memory format with FP16?
182
+ freeze_layers=0, # Freeze-D: Number of layers to freeze.
183
+ ):
184
+ assert in_channels in [0, tmp_channels]
185
+ assert architecture in ["orig", "skip", "resnet"]
186
+ super().__init__()
187
+ self.in_channels = in_channels
188
+ self.resolution = resolution
189
+ self.img_channels = img_channels + 1
190
+ self.first_layer_idx = first_layer_idx
191
+ self.architecture = architecture
192
+ self.use_fp16 = use_fp16
193
+ self.channels_last = use_fp16 and fp16_channels_last
194
+ self.register_buffer("resample_filter", setup_filter(resample_filter))
195
+
196
+ self.num_layers = 0
197
+
198
+ def trainable_gen():
199
+ while True:
200
+ layer_idx = self.first_layer_idx + self.num_layers
201
+ trainable = layer_idx >= freeze_layers
202
+ self.num_layers += 1
203
+ yield trainable
204
+
205
+ trainable_iter = trainable_gen()
206
+
207
+ if in_channels == 0:
208
+ self.fromrgb = Conv2dLayer(
209
+ self.img_channels,
210
+ tmp_channels,
211
+ kernel_size=1,
212
+ activation=activation,
213
+ trainable=next(trainable_iter),
214
+ conv_clamp=conv_clamp,
215
+ channels_last=self.channels_last,
216
+ )
217
+
218
+ self.conv0 = Conv2dLayer(
219
+ tmp_channels,
220
+ tmp_channels,
221
+ kernel_size=3,
222
+ activation=activation,
223
+ trainable=next(trainable_iter),
224
+ conv_clamp=conv_clamp,
225
+ channels_last=self.channels_last,
226
+ )
227
+
228
+ self.conv1 = Conv2dLayer(
229
+ tmp_channels,
230
+ out_channels,
231
+ kernel_size=3,
232
+ activation=activation,
233
+ down=2,
234
+ trainable=next(trainable_iter),
235
+ resample_filter=resample_filter,
236
+ conv_clamp=conv_clamp,
237
+ channels_last=self.channels_last,
238
+ )
239
+
240
+ if architecture == "resnet":
241
+ self.skip = Conv2dLayer(
242
+ tmp_channels,
243
+ out_channels,
244
+ kernel_size=1,
245
+ bias=False,
246
+ down=2,
247
+ trainable=next(trainable_iter),
248
+ resample_filter=resample_filter,
249
+ channels_last=self.channels_last,
250
+ )
251
+
252
+ def forward(self, x, img, force_fp32=False):
253
+ # dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
254
+ dtype = torch.float32
255
+ memory_format = (
256
+ torch.channels_last
257
+ if self.channels_last and not force_fp32
258
+ else torch.contiguous_format
259
+ )
260
+
261
+ # Input.
262
+ if x is not None:
263
+ x = x.to(dtype=dtype, memory_format=memory_format)
264
+
265
+ # FromRGB.
266
+ if self.in_channels == 0:
267
+ img = img.to(dtype=dtype, memory_format=memory_format)
268
+ y = self.fromrgb(img)
269
+ x = x + y if x is not None else y
270
+ img = (
271
+ downsample2d(img, self.resample_filter)
272
+ if self.architecture == "skip"
273
+ else None
274
+ )
275
+
276
+ # Main layers.
277
+ if self.architecture == "resnet":
278
+ y = self.skip(x, gain=np.sqrt(0.5))
279
+ x = self.conv0(x)
280
+ feat = x.clone()
281
+ x = self.conv1(x, gain=np.sqrt(0.5))
282
+ x = y.add_(x)
283
+ else:
284
+ x = self.conv0(x)
285
+ feat = x.clone()
286
+ x = self.conv1(x)
287
+
288
+ assert x.dtype == dtype
289
+ return x, img, feat
290
+
291
+
292
+ class EncoderNetwork(torch.nn.Module):
293
+ def __init__(
294
+ self,
295
+ c_dim, # Conditioning label (C) dimensionality.
296
+ z_dim, # Input latent (Z) dimensionality.
297
+ img_resolution, # Input resolution.
298
+ img_channels, # Number of input color channels.
299
+ architecture="orig", # Architecture: 'orig', 'skip', 'resnet'.
300
+ channel_base=16384, # Overall multiplier for the number of channels.
301
+ channel_max=512, # Maximum number of channels in any layer.
302
+ num_fp16_res=0, # Use FP16 for the N highest resolutions.
303
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
304
+ cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
305
+ block_kwargs={}, # Arguments for DiscriminatorBlock.
306
+ mapping_kwargs={}, # Arguments for MappingNetwork.
307
+ epilogue_kwargs={}, # Arguments for EncoderEpilogue.
308
+ ):
309
+ super().__init__()
310
+ self.c_dim = c_dim
311
+ self.z_dim = z_dim
312
+ self.img_resolution = img_resolution
313
+ self.img_resolution_log2 = int(np.log2(img_resolution))
314
+ self.img_channels = img_channels
315
+ self.block_resolutions = [
316
+ 2**i for i in range(self.img_resolution_log2, 2, -1)
317
+ ]
318
+ channels_dict = {
319
+ res: min(channel_base // res, channel_max)
320
+ for res in self.block_resolutions + [4]
321
+ }
322
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
323
+
324
+ if cmap_dim is None:
325
+ cmap_dim = channels_dict[4]
326
+ if c_dim == 0:
327
+ cmap_dim = 0
328
+
329
+ common_kwargs = dict(
330
+ img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp
331
+ )
332
+ cur_layer_idx = 0
333
+ for res in self.block_resolutions:
334
+ in_channels = channels_dict[res] if res < img_resolution else 0
335
+ tmp_channels = channels_dict[res]
336
+ out_channels = channels_dict[res // 2]
337
+ use_fp16 = res >= fp16_resolution
338
+ use_fp16 = False
339
+ block = EncoderBlock(
340
+ in_channels,
341
+ tmp_channels,
342
+ out_channels,
343
+ resolution=res,
344
+ first_layer_idx=cur_layer_idx,
345
+ use_fp16=use_fp16,
346
+ **block_kwargs,
347
+ **common_kwargs,
348
+ )
349
+ setattr(self, f"b{res}", block)
350
+ cur_layer_idx += block.num_layers
351
+ if c_dim > 0:
352
+ self.mapping = MappingNetwork(
353
+ z_dim=0,
354
+ c_dim=c_dim,
355
+ w_dim=cmap_dim,
356
+ num_ws=None,
357
+ w_avg_beta=None,
358
+ **mapping_kwargs,
359
+ )
360
+ self.b4 = EncoderEpilogue(
361
+ channels_dict[4],
362
+ cmap_dim=cmap_dim,
363
+ z_dim=z_dim * 2,
364
+ resolution=4,
365
+ **epilogue_kwargs,
366
+ **common_kwargs,
367
+ )
368
+
369
+ def forward(self, img, c, **block_kwargs):
370
+ x = None
371
+ feats = {}
372
+ for res in self.block_resolutions:
373
+ block = getattr(self, f"b{res}")
374
+ x, img, feat = block(x, img, **block_kwargs)
375
+ feats[res] = feat
376
+
377
+ cmap = None
378
+ if self.c_dim > 0:
379
+ cmap = self.mapping(None, c)
380
+ x, const_e = self.b4(x, cmap)
381
+ feats[4] = const_e
382
+
383
+ B, _ = x.shape
384
+ z = torch.zeros(
385
+ (B, self.z_dim), requires_grad=False, dtype=x.dtype, device=x.device
386
+ ) ## Noise for Co-Modulation
387
+ return x, z, feats
388
+
389
+
390
+ def fma(a, b, c): # => a * b + c
391
+ return _FusedMultiplyAdd.apply(a, b, c)
392
+
393
+
394
+ class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
395
+ @staticmethod
396
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
397
+ out = torch.addcmul(c, a, b)
398
+ ctx.save_for_backward(a, b)
399
+ ctx.c_shape = c.shape
400
+ return out
401
+
402
+ @staticmethod
403
+ def backward(ctx, dout): # pylint: disable=arguments-differ
404
+ a, b = ctx.saved_tensors
405
+ c_shape = ctx.c_shape
406
+ da = None
407
+ db = None
408
+ dc = None
409
+
410
+ if ctx.needs_input_grad[0]:
411
+ da = _unbroadcast(dout * b, a.shape)
412
+
413
+ if ctx.needs_input_grad[1]:
414
+ db = _unbroadcast(dout * a, b.shape)
415
+
416
+ if ctx.needs_input_grad[2]:
417
+ dc = _unbroadcast(dout, c_shape)
418
+
419
+ return da, db, dc
420
+
421
+
422
+ def _unbroadcast(x, shape):
423
+ extra_dims = x.ndim - len(shape)
424
+ assert extra_dims >= 0
425
+ dim = [
426
+ i
427
+ for i in range(x.ndim)
428
+ if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)
429
+ ]
430
+ if len(dim):
431
+ x = x.sum(dim=dim, keepdim=True)
432
+ if extra_dims:
433
+ x = x.reshape(-1, *x.shape[extra_dims + 1 :])
434
+ assert x.shape == shape
435
+ return x
436
+
437
+
438
+ def modulated_conv2d(
439
+ x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
440
+ weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
441
+ styles, # Modulation coefficients of shape [batch_size, in_channels].
442
+ noise=None, # Optional noise tensor to add to the output activations.
443
+ up=1, # Integer upsampling factor.
444
+ down=1, # Integer downsampling factor.
445
+ padding=0, # Padding with respect to the upsampled image.
446
+ resample_filter=None,
447
+ # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
448
+ demodulate=True, # Apply weight demodulation?
449
+ flip_weight=True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
450
+ fused_modconv=True, # Perform modulation, convolution, and demodulation as a single fused operation?
451
+ ):
452
+ batch_size = x.shape[0]
453
+ out_channels, in_channels, kh, kw = weight.shape
454
+
455
+ # Pre-normalize inputs to avoid FP16 overflow.
456
+ if x.dtype == torch.float16 and demodulate:
457
+ weight = weight * (
458
+ 1
459
+ / np.sqrt(in_channels * kh * kw)
460
+ / weight.norm(float("inf"), dim=[1, 2, 3], keepdim=True)
461
+ ) # max_Ikk
462
+ styles = styles / styles.norm(float("inf"), dim=1, keepdim=True) # max_I
463
+
464
+ # Calculate per-sample weights and demodulation coefficients.
465
+ w = None
466
+ dcoefs = None
467
+ if demodulate or fused_modconv:
468
+ w = weight.unsqueeze(0) # [NOIkk]
469
+ w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
470
+ if demodulate:
471
+ dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO]
472
+ if demodulate and fused_modconv:
473
+ w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
474
+ # Execute by scaling the activations before and after the convolution.
475
+ if not fused_modconv:
476
+ x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
477
+ x = conv2d_resample.conv2d_resample(
478
+ x=x,
479
+ w=weight.to(x.dtype),
480
+ f=resample_filter,
481
+ up=up,
482
+ down=down,
483
+ padding=padding,
484
+ flip_weight=flip_weight,
485
+ )
486
+ if demodulate and noise is not None:
487
+ x = fma(
488
+ x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)
489
+ )
490
+ elif demodulate:
491
+ x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
492
+ elif noise is not None:
493
+ x = x.add_(noise.to(x.dtype))
494
+ return x
495
+
496
+ # Execute as one fused op using grouped convolution.
497
+ batch_size = int(batch_size)
498
+ x = x.reshape(1, -1, *x.shape[2:])
499
+ w = w.reshape(-1, in_channels, kh, kw)
500
+ x = conv2d_resample(
501
+ x=x,
502
+ w=w.to(x.dtype),
503
+ f=resample_filter,
504
+ up=up,
505
+ down=down,
506
+ padding=padding,
507
+ groups=batch_size,
508
+ flip_weight=flip_weight,
509
+ )
510
+ x = x.reshape(batch_size, -1, *x.shape[2:])
511
+ if noise is not None:
512
+ x = x.add_(noise)
513
+ return x
514
+
515
+
516
+ class SynthesisLayer(torch.nn.Module):
517
+ def __init__(
518
+ self,
519
+ in_channels, # Number of input channels.
520
+ out_channels, # Number of output channels.
521
+ w_dim, # Intermediate latent (W) dimensionality.
522
+ resolution, # Resolution of this layer.
523
+ kernel_size=3, # Convolution kernel size.
524
+ up=1, # Integer upsampling factor.
525
+ use_noise=True, # Enable noise input?
526
+ activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
527
+ resample_filter=[
528
+ 1,
529
+ 3,
530
+ 3,
531
+ 1,
532
+ ], # Low-pass filter to apply when resampling activations.
533
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
534
+ channels_last=False, # Use channels_last format for the weights?
535
+ ):
536
+ super().__init__()
537
+ self.resolution = resolution
538
+ self.up = up
539
+ self.use_noise = use_noise
540
+ self.activation = activation
541
+ self.conv_clamp = conv_clamp
542
+ self.register_buffer("resample_filter", setup_filter(resample_filter))
543
+ self.padding = kernel_size // 2
544
+ self.act_gain = activation_funcs[activation].def_gain
545
+
546
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
547
+ memory_format = (
548
+ torch.channels_last if channels_last else torch.contiguous_format
549
+ )
550
+ self.weight = torch.nn.Parameter(
551
+ torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
552
+ memory_format=memory_format
553
+ )
554
+ )
555
+ if use_noise:
556
+ self.register_buffer("noise_const", torch.randn([resolution, resolution]))
557
+ self.noise_strength = torch.nn.Parameter(torch.zeros([]))
558
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
559
+
560
+ def forward(self, x, w, noise_mode="none", fused_modconv=True, gain=1):
561
+ assert noise_mode in ["random", "const", "none"]
562
+ in_resolution = self.resolution // self.up
563
+ styles = self.affine(w)
564
+
565
+ noise = None
566
+ if self.use_noise and noise_mode == "random":
567
+ noise = (
568
+ torch.randn(
569
+ [x.shape[0], 1, self.resolution, self.resolution], device=x.device
570
+ )
571
+ * self.noise_strength
572
+ )
573
+ if self.use_noise and noise_mode == "const":
574
+ noise = self.noise_const * self.noise_strength
575
+
576
+ flip_weight = self.up == 1 # slightly faster
577
+ x = modulated_conv2d(
578
+ x=x,
579
+ weight=self.weight,
580
+ styles=styles,
581
+ noise=noise,
582
+ up=self.up,
583
+ padding=self.padding,
584
+ resample_filter=self.resample_filter,
585
+ flip_weight=flip_weight,
586
+ fused_modconv=fused_modconv,
587
+ )
588
+
589
+ act_gain = self.act_gain * gain
590
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
591
+ x = F.leaky_relu(x, negative_slope=0.2, inplace=False)
592
+ if act_gain != 1:
593
+ x = x * act_gain
594
+ if act_clamp is not None:
595
+ x = x.clamp(-act_clamp, act_clamp)
596
+ return x
597
+
598
+
599
+ class ToRGBLayer(torch.nn.Module):
600
+ def __init__(
601
+ self,
602
+ in_channels,
603
+ out_channels,
604
+ w_dim,
605
+ kernel_size=1,
606
+ conv_clamp=None,
607
+ channels_last=False,
608
+ ):
609
+ super().__init__()
610
+ self.conv_clamp = conv_clamp
611
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
612
+ memory_format = (
613
+ torch.channels_last if channels_last else torch.contiguous_format
614
+ )
615
+ self.weight = torch.nn.Parameter(
616
+ torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
617
+ memory_format=memory_format
618
+ )
619
+ )
620
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
621
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
622
+
623
+ def forward(self, x, w, fused_modconv=True):
624
+ styles = self.affine(w) * self.weight_gain
625
+ x = modulated_conv2d(
626
+ x=x,
627
+ weight=self.weight,
628
+ styles=styles,
629
+ demodulate=False,
630
+ fused_modconv=fused_modconv,
631
+ )
632
+ x = bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
633
+ return x
634
+
635
+
636
+ class SynthesisForeword(torch.nn.Module):
637
+ def __init__(
638
+ self,
639
+ z_dim, # Output Latent (Z) dimensionality.
640
+ resolution, # Resolution of this block.
641
+ in_channels,
642
+ img_channels, # Number of input color channels.
643
+ architecture="skip", # Architecture: 'orig', 'skip', 'resnet'.
644
+ activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
645
+ ):
646
+ super().__init__()
647
+ self.in_channels = in_channels
648
+ self.z_dim = z_dim
649
+ self.resolution = resolution
650
+ self.img_channels = img_channels
651
+ self.architecture = architecture
652
+
653
+ self.fc = FullyConnectedLayer(
654
+ self.z_dim, (self.z_dim // 2) * 4 * 4, activation=activation
655
+ )
656
+ self.conv = SynthesisLayer(
657
+ self.in_channels, self.in_channels, w_dim=(z_dim // 2) * 3, resolution=4
658
+ )
659
+
660
+ if architecture == "skip":
661
+ self.torgb = ToRGBLayer(
662
+ self.in_channels,
663
+ self.img_channels,
664
+ kernel_size=1,
665
+ w_dim=(z_dim // 2) * 3,
666
+ )
667
+
668
+ def forward(self, x, ws, feats, img, force_fp32=False):
669
+ _ = force_fp32 # unused
670
+ dtype = torch.float32
671
+ memory_format = torch.contiguous_format
672
+
673
+ x_global = x.clone()
674
+ # ToRGB.
675
+ x = self.fc(x)
676
+ x = x.view(-1, self.z_dim // 2, 4, 4)
677
+ x = x.to(dtype=dtype, memory_format=memory_format)
678
+
679
+ # Main layers.
680
+ x_skip = feats[4].clone()
681
+ x = x + x_skip
682
+
683
+ mod_vector = []
684
+ mod_vector.append(ws[:, 0])
685
+ mod_vector.append(x_global.clone())
686
+ mod_vector = torch.cat(mod_vector, dim=1)
687
+
688
+ x = self.conv(x, mod_vector)
689
+
690
+ mod_vector = []
691
+ mod_vector.append(ws[:, 2 * 2 - 3])
692
+ mod_vector.append(x_global.clone())
693
+ mod_vector = torch.cat(mod_vector, dim=1)
694
+
695
+ if self.architecture == "skip":
696
+ img = self.torgb(x, mod_vector)
697
+ img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format)
698
+
699
+ assert x.dtype == dtype
700
+ return x, img
701
+
702
+
703
+ class SELayer(nn.Module):
704
+ def __init__(self, channel, reduction=16):
705
+ super(SELayer, self).__init__()
706
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
707
+ self.fc = nn.Sequential(
708
+ nn.Linear(channel, channel // reduction, bias=False),
709
+ nn.ReLU(inplace=False),
710
+ nn.Linear(channel // reduction, channel, bias=False),
711
+ nn.Sigmoid(),
712
+ )
713
+
714
+ def forward(self, x):
715
+ b, c, _, _ = x.size()
716
+ y = self.avg_pool(x).view(b, c)
717
+ y = self.fc(y).view(b, c, 1, 1)
718
+ res = x * y.expand_as(x)
719
+ return res
720
+
721
+
722
+ class FourierUnit(nn.Module):
723
+ def __init__(
724
+ self,
725
+ in_channels,
726
+ out_channels,
727
+ groups=1,
728
+ spatial_scale_factor=None,
729
+ spatial_scale_mode="bilinear",
730
+ spectral_pos_encoding=False,
731
+ use_se=False,
732
+ se_kwargs=None,
733
+ ffc3d=False,
734
+ fft_norm="ortho",
735
+ ):
736
+ # bn_layer not used
737
+ super(FourierUnit, self).__init__()
738
+ self.groups = groups
739
+
740
+ self.conv_layer = torch.nn.Conv2d(
741
+ in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
742
+ out_channels=out_channels * 2,
743
+ kernel_size=1,
744
+ stride=1,
745
+ padding=0,
746
+ groups=self.groups,
747
+ bias=False,
748
+ )
749
+ self.relu = torch.nn.ReLU(inplace=False)
750
+
751
+ # squeeze and excitation block
752
+ self.use_se = use_se
753
+ if use_se:
754
+ if se_kwargs is None:
755
+ se_kwargs = {}
756
+ self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
757
+
758
+ self.spatial_scale_factor = spatial_scale_factor
759
+ self.spatial_scale_mode = spatial_scale_mode
760
+ self.spectral_pos_encoding = spectral_pos_encoding
761
+ self.ffc3d = ffc3d
762
+ self.fft_norm = fft_norm
763
+
764
+ def forward(self, x):
765
+ batch = x.shape[0]
766
+
767
+ if self.spatial_scale_factor is not None:
768
+ orig_size = x.shape[-2:]
769
+ x = F.interpolate(
770
+ x,
771
+ scale_factor=self.spatial_scale_factor,
772
+ mode=self.spatial_scale_mode,
773
+ align_corners=False,
774
+ )
775
+
776
+ r_size = x.size()
777
+ # (batch, c, h, w/2+1, 2)
778
+ fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
779
+ ffted = fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
780
+ ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
781
+ ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
782
+ ffted = ffted.view(
783
+ (
784
+ batch,
785
+ -1,
786
+ )
787
+ + ffted.size()[3:]
788
+ )
789
+
790
+ if self.spectral_pos_encoding:
791
+ height, width = ffted.shape[-2:]
792
+ coords_vert = (
793
+ torch.linspace(0, 1, height)[None, None, :, None]
794
+ .expand(batch, 1, height, width)
795
+ .to(ffted)
796
+ )
797
+ coords_hor = (
798
+ torch.linspace(0, 1, width)[None, None, None, :]
799
+ .expand(batch, 1, height, width)
800
+ .to(ffted)
801
+ )
802
+ ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
803
+
804
+ if self.use_se:
805
+ ffted = self.se(ffted)
806
+
807
+ ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1)
808
+ ffted = self.relu(ffted)
809
+
810
+ ffted = (
811
+ ffted.view(
812
+ (
813
+ batch,
814
+ -1,
815
+ 2,
816
+ )
817
+ + ffted.size()[2:]
818
+ )
819
+ .permute(0, 1, 3, 4, 2)
820
+ .contiguous()
821
+ ) # (batch,c, t, h, w/2+1, 2)
822
+ ffted = torch.complex(ffted[..., 0], ffted[..., 1])
823
+
824
+ ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
825
+ output = torch.fft.irfftn(
826
+ ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm
827
+ )
828
+
829
+ if self.spatial_scale_factor is not None:
830
+ output = F.interpolate(
831
+ output,
832
+ size=orig_size,
833
+ mode=self.spatial_scale_mode,
834
+ align_corners=False,
835
+ )
836
+
837
+ return output
838
+
839
+
840
+ class SpectralTransform(nn.Module):
841
+ def __init__(
842
+ self,
843
+ in_channels,
844
+ out_channels,
845
+ stride=1,
846
+ groups=1,
847
+ enable_lfu=True,
848
+ **fu_kwargs,
849
+ ):
850
+ # bn_layer not used
851
+ super(SpectralTransform, self).__init__()
852
+ self.enable_lfu = enable_lfu
853
+ if stride == 2:
854
+ self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
855
+ else:
856
+ self.downsample = nn.Identity()
857
+
858
+ self.stride = stride
859
+ self.conv1 = nn.Sequential(
860
+ nn.Conv2d(
861
+ in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False
862
+ ),
863
+ # nn.BatchNorm2d(out_channels // 2),
864
+ nn.ReLU(inplace=True),
865
+ )
866
+ self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups, **fu_kwargs)
867
+ if self.enable_lfu:
868
+ self.lfu = FourierUnit(out_channels // 2, out_channels // 2, groups)
869
+ self.conv2 = torch.nn.Conv2d(
870
+ out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False
871
+ )
872
+
873
+ def forward(self, x):
874
+ x = self.downsample(x)
875
+ x = self.conv1(x)
876
+ output = self.fu(x)
877
+
878
+ if self.enable_lfu:
879
+ n, c, h, w = x.shape
880
+ split_no = 2
881
+ split_s = h // split_no
882
+ xs = torch.cat(
883
+ torch.split(x[:, : c // 4], split_s, dim=-2), dim=1
884
+ ).contiguous()
885
+ xs = torch.cat(torch.split(xs, split_s, dim=-1), dim=1).contiguous()
886
+ xs = self.lfu(xs)
887
+ xs = xs.repeat(1, 1, split_no, split_no).contiguous()
888
+ else:
889
+ xs = 0
890
+
891
+ output = self.conv2(x + output + xs)
892
+
893
+ return output
894
+
895
+
896
+ class FFC(nn.Module):
897
+ def __init__(
898
+ self,
899
+ in_channels,
900
+ out_channels,
901
+ kernel_size,
902
+ ratio_gin,
903
+ ratio_gout,
904
+ stride=1,
905
+ padding=0,
906
+ dilation=1,
907
+ groups=1,
908
+ bias=False,
909
+ enable_lfu=True,
910
+ padding_type="reflect",
911
+ gated=False,
912
+ **spectral_kwargs,
913
+ ):
914
+ super(FFC, self).__init__()
915
+
916
+ assert stride == 1 or stride == 2, "Stride should be 1 or 2."
917
+ self.stride = stride
918
+
919
+ in_cg = int(in_channels * ratio_gin)
920
+ in_cl = in_channels - in_cg
921
+ out_cg = int(out_channels * ratio_gout)
922
+ out_cl = out_channels - out_cg
923
+ # groups_g = 1 if groups == 1 else int(groups * ratio_gout)
924
+ # groups_l = 1 if groups == 1 else groups - groups_g
925
+
926
+ self.ratio_gin = ratio_gin
927
+ self.ratio_gout = ratio_gout
928
+ self.global_in_num = in_cg
929
+
930
+ module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
931
+ self.convl2l = module(
932
+ in_cl,
933
+ out_cl,
934
+ kernel_size,
935
+ stride,
936
+ padding,
937
+ dilation,
938
+ groups,
939
+ bias,
940
+ padding_mode=padding_type,
941
+ )
942
+ module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
943
+ self.convl2g = module(
944
+ in_cl,
945
+ out_cg,
946
+ kernel_size,
947
+ stride,
948
+ padding,
949
+ dilation,
950
+ groups,
951
+ bias,
952
+ padding_mode=padding_type,
953
+ )
954
+ module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
955
+ self.convg2l = module(
956
+ in_cg,
957
+ out_cl,
958
+ kernel_size,
959
+ stride,
960
+ padding,
961
+ dilation,
962
+ groups,
963
+ bias,
964
+ padding_mode=padding_type,
965
+ )
966
+ module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
967
+ self.convg2g = module(
968
+ in_cg,
969
+ out_cg,
970
+ stride,
971
+ 1 if groups == 1 else groups // 2,
972
+ enable_lfu,
973
+ **spectral_kwargs,
974
+ )
975
+
976
+ self.gated = gated
977
+ module = (
978
+ nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
979
+ )
980
+ self.gate = module(in_channels, 2, 1)
981
+
982
+ def forward(self, x, fname=None):
983
+ x_l, x_g = x if type(x) is tuple else (x, 0)
984
+ out_xl, out_xg = 0, 0
985
+
986
+ if self.gated:
987
+ total_input_parts = [x_l]
988
+ if torch.is_tensor(x_g):
989
+ total_input_parts.append(x_g)
990
+ total_input = torch.cat(total_input_parts, dim=1)
991
+
992
+ gates = torch.sigmoid(self.gate(total_input))
993
+ g2l_gate, l2g_gate = gates.chunk(2, dim=1)
994
+ else:
995
+ g2l_gate, l2g_gate = 1, 1
996
+
997
+ spec_x = self.convg2g(x_g)
998
+
999
+ if self.ratio_gout != 1:
1000
+ out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
1001
+ if self.ratio_gout != 0:
1002
+ out_xg = self.convl2g(x_l) * l2g_gate + spec_x
1003
+
1004
+ return out_xl, out_xg
1005
+
1006
+
1007
+ class FFC_BN_ACT(nn.Module):
1008
+ def __init__(
1009
+ self,
1010
+ in_channels,
1011
+ out_channels,
1012
+ kernel_size,
1013
+ ratio_gin,
1014
+ ratio_gout,
1015
+ stride=1,
1016
+ padding=0,
1017
+ dilation=1,
1018
+ groups=1,
1019
+ bias=False,
1020
+ norm_layer=nn.SyncBatchNorm,
1021
+ activation_layer=nn.Identity,
1022
+ padding_type="reflect",
1023
+ enable_lfu=True,
1024
+ **kwargs,
1025
+ ):
1026
+ super(FFC_BN_ACT, self).__init__()
1027
+ self.ffc = FFC(
1028
+ in_channels,
1029
+ out_channels,
1030
+ kernel_size,
1031
+ ratio_gin,
1032
+ ratio_gout,
1033
+ stride,
1034
+ padding,
1035
+ dilation,
1036
+ groups,
1037
+ bias,
1038
+ enable_lfu,
1039
+ padding_type=padding_type,
1040
+ **kwargs,
1041
+ )
1042
+ lnorm = nn.Identity if ratio_gout == 1 else norm_layer
1043
+ gnorm = nn.Identity if ratio_gout == 0 else norm_layer
1044
+ global_channels = int(out_channels * ratio_gout)
1045
+ # self.bn_l = lnorm(out_channels - global_channels)
1046
+ # self.bn_g = gnorm(global_channels)
1047
+
1048
+ lact = nn.Identity if ratio_gout == 1 else activation_layer
1049
+ gact = nn.Identity if ratio_gout == 0 else activation_layer
1050
+ self.act_l = lact(inplace=True)
1051
+ self.act_g = gact(inplace=True)
1052
+
1053
+ def forward(self, x, fname=None):
1054
+ x_l, x_g = self.ffc(
1055
+ x,
1056
+ fname=fname,
1057
+ )
1058
+ x_l = self.act_l(x_l)
1059
+ x_g = self.act_g(x_g)
1060
+ return x_l, x_g
1061
+
1062
+
1063
+ class FFCResnetBlock(nn.Module):
1064
+ def __init__(
1065
+ self,
1066
+ dim,
1067
+ padding_type,
1068
+ norm_layer,
1069
+ activation_layer=nn.ReLU,
1070
+ dilation=1,
1071
+ spatial_transform_kwargs=None,
1072
+ inline=False,
1073
+ ratio_gin=0.75,
1074
+ ratio_gout=0.75,
1075
+ ):
1076
+ super().__init__()
1077
+ self.conv1 = FFC_BN_ACT(
1078
+ dim,
1079
+ dim,
1080
+ kernel_size=3,
1081
+ padding=dilation,
1082
+ dilation=dilation,
1083
+ norm_layer=norm_layer,
1084
+ activation_layer=activation_layer,
1085
+ padding_type=padding_type,
1086
+ ratio_gin=ratio_gin,
1087
+ ratio_gout=ratio_gout,
1088
+ )
1089
+ self.conv2 = FFC_BN_ACT(
1090
+ dim,
1091
+ dim,
1092
+ kernel_size=3,
1093
+ padding=dilation,
1094
+ dilation=dilation,
1095
+ norm_layer=norm_layer,
1096
+ activation_layer=activation_layer,
1097
+ padding_type=padding_type,
1098
+ ratio_gin=ratio_gin,
1099
+ ratio_gout=ratio_gout,
1100
+ )
1101
+ self.inline = inline
1102
+
1103
+ def forward(self, x, fname=None):
1104
+ if self.inline:
1105
+ x_l, x_g = (
1106
+ x[:, : -self.conv1.ffc.global_in_num],
1107
+ x[:, -self.conv1.ffc.global_in_num :],
1108
+ )
1109
+ else:
1110
+ x_l, x_g = x if type(x) is tuple else (x, 0)
1111
+
1112
+ id_l, id_g = x_l, x_g
1113
+
1114
+ x_l, x_g = self.conv1((x_l, x_g), fname=fname)
1115
+ x_l, x_g = self.conv2((x_l, x_g), fname=fname)
1116
+
1117
+ x_l, x_g = id_l + x_l, id_g + x_g
1118
+ out = x_l, x_g
1119
+ if self.inline:
1120
+ out = torch.cat(out, dim=1)
1121
+ return out
1122
+
1123
+
1124
+ class ConcatTupleLayer(nn.Module):
1125
+ def forward(self, x):
1126
+ assert isinstance(x, tuple)
1127
+ x_l, x_g = x
1128
+ assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
1129
+ if not torch.is_tensor(x_g):
1130
+ return x_l
1131
+ return torch.cat(x, dim=1)
1132
+
1133
+
1134
+ class FFCBlock(torch.nn.Module):
1135
+ def __init__(
1136
+ self,
1137
+ dim, # Number of output/input channels.
1138
+ kernel_size, # Width and height of the convolution kernel.
1139
+ padding,
1140
+ ratio_gin=0.75,
1141
+ ratio_gout=0.75,
1142
+ activation="linear", # Activation function: 'relu', 'lrelu', etc.
1143
+ ):
1144
+ super().__init__()
1145
+ if activation == "linear":
1146
+ self.activation = nn.Identity
1147
+ else:
1148
+ self.activation = nn.ReLU
1149
+ self.padding = padding
1150
+ self.kernel_size = kernel_size
1151
+ self.ffc_block = FFCResnetBlock(
1152
+ dim=dim,
1153
+ padding_type="reflect",
1154
+ norm_layer=nn.SyncBatchNorm,
1155
+ activation_layer=self.activation,
1156
+ dilation=1,
1157
+ ratio_gin=ratio_gin,
1158
+ ratio_gout=ratio_gout,
1159
+ )
1160
+
1161
+ self.concat_layer = ConcatTupleLayer()
1162
+
1163
+ def forward(self, gen_ft, mask, fname=None):
1164
+ x = gen_ft.float()
1165
+
1166
+ x_l, x_g = (
1167
+ x[:, : -self.ffc_block.conv1.ffc.global_in_num],
1168
+ x[:, -self.ffc_block.conv1.ffc.global_in_num :],
1169
+ )
1170
+ id_l, id_g = x_l, x_g
1171
+
1172
+ x_l, x_g = self.ffc_block((x_l, x_g), fname=fname)
1173
+ x_l, x_g = id_l + x_l, id_g + x_g
1174
+ x = self.concat_layer((x_l, x_g))
1175
+
1176
+ return x + gen_ft.float()
1177
+
1178
+
1179
+ class FFCSkipLayer(torch.nn.Module):
1180
+ def __init__(
1181
+ self,
1182
+ dim, # Number of input/output channels.
1183
+ kernel_size=3, # Convolution kernel size.
1184
+ ratio_gin=0.75,
1185
+ ratio_gout=0.75,
1186
+ ):
1187
+ super().__init__()
1188
+ self.padding = kernel_size // 2
1189
+
1190
+ self.ffc_act = FFCBlock(
1191
+ dim=dim,
1192
+ kernel_size=kernel_size,
1193
+ activation=nn.ReLU,
1194
+ padding=self.padding,
1195
+ ratio_gin=ratio_gin,
1196
+ ratio_gout=ratio_gout,
1197
+ )
1198
+
1199
+ def forward(self, gen_ft, mask, fname=None):
1200
+ x = self.ffc_act(gen_ft, mask, fname=fname)
1201
+ return x
1202
+
1203
+
1204
+ class SynthesisBlock(torch.nn.Module):
1205
+ def __init__(
1206
+ self,
1207
+ in_channels, # Number of input channels, 0 = first block.
1208
+ out_channels, # Number of output channels.
1209
+ w_dim, # Intermediate latent (W) dimensionality.
1210
+ resolution, # Resolution of this block.
1211
+ img_channels, # Number of output color channels.
1212
+ is_last, # Is this the last block?
1213
+ architecture="skip", # Architecture: 'orig', 'skip', 'resnet'.
1214
+ resample_filter=[
1215
+ 1,
1216
+ 3,
1217
+ 3,
1218
+ 1,
1219
+ ], # Low-pass filter to apply when resampling activations.
1220
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
1221
+ use_fp16=False, # Use FP16 for this block?
1222
+ fp16_channels_last=False, # Use channels-last memory format with FP16?
1223
+ **layer_kwargs, # Arguments for SynthesisLayer.
1224
+ ):
1225
+ assert architecture in ["orig", "skip", "resnet"]
1226
+ super().__init__()
1227
+ self.in_channels = in_channels
1228
+ self.w_dim = w_dim
1229
+ self.resolution = resolution
1230
+ self.img_channels = img_channels
1231
+ self.is_last = is_last
1232
+ self.architecture = architecture
1233
+ self.use_fp16 = use_fp16
1234
+ self.channels_last = use_fp16 and fp16_channels_last
1235
+ self.register_buffer("resample_filter", setup_filter(resample_filter))
1236
+ self.num_conv = 0
1237
+ self.num_torgb = 0
1238
+ self.res_ffc = {4: 0, 8: 0, 16: 0, 32: 1, 64: 1, 128: 1, 256: 1, 512: 1}
1239
+
1240
+ if in_channels != 0 and resolution >= 8:
1241
+ self.ffc_skip = nn.ModuleList()
1242
+ for _ in range(self.res_ffc[resolution]):
1243
+ self.ffc_skip.append(FFCSkipLayer(dim=out_channels))
1244
+
1245
+ if in_channels == 0:
1246
+ self.const = torch.nn.Parameter(
1247
+ torch.randn([out_channels, resolution, resolution])
1248
+ )
1249
+
1250
+ if in_channels != 0:
1251
+ self.conv0 = SynthesisLayer(
1252
+ in_channels,
1253
+ out_channels,
1254
+ w_dim=w_dim * 3,
1255
+ resolution=resolution,
1256
+ up=2,
1257
+ resample_filter=resample_filter,
1258
+ conv_clamp=conv_clamp,
1259
+ channels_last=self.channels_last,
1260
+ **layer_kwargs,
1261
+ )
1262
+ self.num_conv += 1
1263
+
1264
+ self.conv1 = SynthesisLayer(
1265
+ out_channels,
1266
+ out_channels,
1267
+ w_dim=w_dim * 3,
1268
+ resolution=resolution,
1269
+ conv_clamp=conv_clamp,
1270
+ channels_last=self.channels_last,
1271
+ **layer_kwargs,
1272
+ )
1273
+ self.num_conv += 1
1274
+
1275
+ if is_last or architecture == "skip":
1276
+ self.torgb = ToRGBLayer(
1277
+ out_channels,
1278
+ img_channels,
1279
+ w_dim=w_dim * 3,
1280
+ conv_clamp=conv_clamp,
1281
+ channels_last=self.channels_last,
1282
+ )
1283
+ self.num_torgb += 1
1284
+
1285
+ if in_channels != 0 and architecture == "resnet":
1286
+ self.skip = Conv2dLayer(
1287
+ in_channels,
1288
+ out_channels,
1289
+ kernel_size=1,
1290
+ bias=False,
1291
+ up=2,
1292
+ resample_filter=resample_filter,
1293
+ channels_last=self.channels_last,
1294
+ )
1295
+
1296
+ def forward(
1297
+ self,
1298
+ x,
1299
+ mask,
1300
+ feats,
1301
+ img,
1302
+ ws,
1303
+ fname=None,
1304
+ force_fp32=False,
1305
+ fused_modconv=None,
1306
+ **layer_kwargs,
1307
+ ):
1308
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
1309
+ dtype = torch.float32
1310
+ memory_format = (
1311
+ torch.channels_last
1312
+ if self.channels_last and not force_fp32
1313
+ else torch.contiguous_format
1314
+ )
1315
+ if fused_modconv is None:
1316
+ fused_modconv = (not self.training) and (
1317
+ dtype == torch.float32 or int(x.shape[0]) == 1
1318
+ )
1319
+
1320
+ x = x.to(dtype=dtype, memory_format=memory_format)
1321
+ x_skip = (
1322
+ feats[self.resolution].clone().to(dtype=dtype, memory_format=memory_format)
1323
+ )
1324
+
1325
+ # Main layers.
1326
+ if self.in_channels == 0:
1327
+ x = self.conv1(x, ws[1], fused_modconv=fused_modconv, **layer_kwargs)
1328
+ elif self.architecture == "resnet":
1329
+ y = self.skip(x, gain=np.sqrt(0.5))
1330
+ x = self.conv0(
1331
+ x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs
1332
+ )
1333
+ if len(self.ffc_skip) > 0:
1334
+ mask = F.interpolate(
1335
+ mask,
1336
+ size=x_skip.shape[2:],
1337
+ )
1338
+ z = x + x_skip
1339
+ for fres in self.ffc_skip:
1340
+ z = fres(z, mask)
1341
+ x = x + z
1342
+ else:
1343
+ x = x + x_skip
1344
+ x = self.conv1(
1345
+ x,
1346
+ ws[1].clone(),
1347
+ fused_modconv=fused_modconv,
1348
+ gain=np.sqrt(0.5),
1349
+ **layer_kwargs,
1350
+ )
1351
+ x = y.add_(x)
1352
+ else:
1353
+ x = self.conv0(
1354
+ x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs
1355
+ )
1356
+ if len(self.ffc_skip) > 0:
1357
+ mask = F.interpolate(
1358
+ mask,
1359
+ size=x_skip.shape[2:],
1360
+ )
1361
+ z = x + x_skip
1362
+ for fres in self.ffc_skip:
1363
+ z = fres(z, mask)
1364
+ x = x + z
1365
+ else:
1366
+ x = x + x_skip
1367
+ x = self.conv1(
1368
+ x, ws[1].clone(), fused_modconv=fused_modconv, **layer_kwargs
1369
+ )
1370
+ # ToRGB.
1371
+ if img is not None:
1372
+ img = upsample2d(img, self.resample_filter)
1373
+ if self.is_last or self.architecture == "skip":
1374
+ y = self.torgb(x, ws[2].clone(), fused_modconv=fused_modconv)
1375
+ y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
1376
+ img = img.add_(y) if img is not None else y
1377
+
1378
+ x = x.to(dtype=dtype)
1379
+ assert x.dtype == dtype
1380
+ assert img is None or img.dtype == torch.float32
1381
+ return x, img
1382
+
1383
+
1384
+ class SynthesisNetwork(torch.nn.Module):
1385
+ def __init__(
1386
+ self,
1387
+ w_dim, # Intermediate latent (W) dimensionality.
1388
+ z_dim, # Output Latent (Z) dimensionality.
1389
+ img_resolution, # Output image resolution.
1390
+ img_channels, # Number of color channels.
1391
+ channel_base=16384, # Overall multiplier for the number of channels.
1392
+ channel_max=512, # Maximum number of channels in any layer.
1393
+ num_fp16_res=0, # Use FP16 for the N highest resolutions.
1394
+ **block_kwargs, # Arguments for SynthesisBlock.
1395
+ ):
1396
+ assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
1397
+ super().__init__()
1398
+ self.w_dim = w_dim
1399
+ self.img_resolution = img_resolution
1400
+ self.img_resolution_log2 = int(np.log2(img_resolution))
1401
+ self.img_channels = img_channels
1402
+ self.block_resolutions = [
1403
+ 2**i for i in range(3, self.img_resolution_log2 + 1)
1404
+ ]
1405
+ channels_dict = {
1406
+ res: min(channel_base // res, channel_max) for res in self.block_resolutions
1407
+ }
1408
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
1409
+
1410
+ self.foreword = SynthesisForeword(
1411
+ img_channels=img_channels,
1412
+ in_channels=min(channel_base // 4, channel_max),
1413
+ z_dim=z_dim * 2,
1414
+ resolution=4,
1415
+ )
1416
+
1417
+ self.num_ws = self.img_resolution_log2 * 2 - 2
1418
+ for res in self.block_resolutions:
1419
+ if res // 2 in channels_dict.keys():
1420
+ in_channels = channels_dict[res // 2] if res > 4 else 0
1421
+ else:
1422
+ in_channels = min(channel_base // (res // 2), channel_max)
1423
+ out_channels = channels_dict[res]
1424
+ use_fp16 = res >= fp16_resolution
1425
+ use_fp16 = False
1426
+ is_last = res == self.img_resolution
1427
+ block = SynthesisBlock(
1428
+ in_channels,
1429
+ out_channels,
1430
+ w_dim=w_dim,
1431
+ resolution=res,
1432
+ img_channels=img_channels,
1433
+ is_last=is_last,
1434
+ use_fp16=use_fp16,
1435
+ **block_kwargs,
1436
+ )
1437
+ setattr(self, f"b{res}", block)
1438
+
1439
+ def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs):
1440
+ img = None
1441
+
1442
+ x, img = self.foreword(x_global, ws, feats, img)
1443
+
1444
+ for res in self.block_resolutions:
1445
+ block = getattr(self, f"b{res}")
1446
+ mod_vector0 = []
1447
+ mod_vector0.append(ws[:, int(np.log2(res)) * 2 - 5])
1448
+ mod_vector0.append(x_global.clone())
1449
+ mod_vector0 = torch.cat(mod_vector0, dim=1)
1450
+
1451
+ mod_vector1 = []
1452
+ mod_vector1.append(ws[:, int(np.log2(res)) * 2 - 4])
1453
+ mod_vector1.append(x_global.clone())
1454
+ mod_vector1 = torch.cat(mod_vector1, dim=1)
1455
+
1456
+ mod_vector_rgb = []
1457
+ mod_vector_rgb.append(ws[:, int(np.log2(res)) * 2 - 3])
1458
+ mod_vector_rgb.append(x_global.clone())
1459
+ mod_vector_rgb = torch.cat(mod_vector_rgb, dim=1)
1460
+ x, img = block(
1461
+ x,
1462
+ mask,
1463
+ feats,
1464
+ img,
1465
+ (mod_vector0, mod_vector1, mod_vector_rgb),
1466
+ fname=fname,
1467
+ **block_kwargs,
1468
+ )
1469
+ return img
1470
+
1471
+
1472
+ class MappingNetwork(torch.nn.Module):
1473
+ def __init__(
1474
+ self,
1475
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
1476
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
1477
+ w_dim, # Intermediate latent (W) dimensionality.
1478
+ num_ws, # Number of intermediate latents to output, None = do not broadcast.
1479
+ num_layers=8, # Number of mapping layers.
1480
+ embed_features=None, # Label embedding dimensionality, None = same as w_dim.
1481
+ layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim.
1482
+ activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
1483
+ lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
1484
+ w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track.
1485
+ ):
1486
+ super().__init__()
1487
+ self.z_dim = z_dim
1488
+ self.c_dim = c_dim
1489
+ self.w_dim = w_dim
1490
+ self.num_ws = num_ws
1491
+ self.num_layers = num_layers
1492
+ self.w_avg_beta = w_avg_beta
1493
+
1494
+ if embed_features is None:
1495
+ embed_features = w_dim
1496
+ if c_dim == 0:
1497
+ embed_features = 0
1498
+ if layer_features is None:
1499
+ layer_features = w_dim
1500
+ features_list = (
1501
+ [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
1502
+ )
1503
+
1504
+ if c_dim > 0:
1505
+ self.embed = FullyConnectedLayer(c_dim, embed_features)
1506
+ for idx in range(num_layers):
1507
+ in_features = features_list[idx]
1508
+ out_features = features_list[idx + 1]
1509
+ layer = FullyConnectedLayer(
1510
+ in_features,
1511
+ out_features,
1512
+ activation=activation,
1513
+ lr_multiplier=lr_multiplier,
1514
+ )
1515
+ setattr(self, f"fc{idx}", layer)
1516
+
1517
+ if num_ws is not None and w_avg_beta is not None:
1518
+ self.register_buffer("w_avg", torch.zeros([w_dim]))
1519
+
1520
+ def forward(
1521
+ self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False
1522
+ ):
1523
+ # Embed, normalize, and concat inputs.
1524
+ x = None
1525
+ with torch.autograd.profiler.record_function("input"):
1526
+ if self.z_dim > 0:
1527
+ x = normalize_2nd_moment(z.to(torch.float32))
1528
+ if self.c_dim > 0:
1529
+ y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
1530
+ x = torch.cat([x, y], dim=1) if x is not None else y
1531
+
1532
+ # Main layers.
1533
+ for idx in range(self.num_layers):
1534
+ layer = getattr(self, f"fc{idx}")
1535
+ x = layer(x)
1536
+
1537
+ # Update moving average of W.
1538
+ if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
1539
+ with torch.autograd.profiler.record_function("update_w_avg"):
1540
+ self.w_avg.copy_(
1541
+ x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)
1542
+ )
1543
+
1544
+ # Broadcast.
1545
+ if self.num_ws is not None:
1546
+ with torch.autograd.profiler.record_function("broadcast"):
1547
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
1548
+
1549
+ # Apply truncation.
1550
+ if truncation_psi != 1:
1551
+ with torch.autograd.profiler.record_function("truncate"):
1552
+ assert self.w_avg_beta is not None
1553
+ if self.num_ws is None or truncation_cutoff is None:
1554
+ x = self.w_avg.lerp(x, truncation_psi)
1555
+ else:
1556
+ x[:, :truncation_cutoff] = self.w_avg.lerp(
1557
+ x[:, :truncation_cutoff], truncation_psi
1558
+ )
1559
+ return x
1560
+
1561
+
1562
+ class Generator(torch.nn.Module):
1563
+ def __init__(
1564
+ self,
1565
+ z_dim, # Input latent (Z) dimensionality.
1566
+ c_dim, # Conditioning label (C) dimensionality.
1567
+ w_dim, # Intermediate latent (W) dimensionality.
1568
+ img_resolution, # Output resolution.
1569
+ img_channels, # Number of output color channels.
1570
+ encoder_kwargs={}, # Arguments for EncoderNetwork.
1571
+ mapping_kwargs={}, # Arguments for MappingNetwork.
1572
+ synthesis_kwargs={}, # Arguments for SynthesisNetwork.
1573
+ ):
1574
+ super().__init__()
1575
+ self.z_dim = z_dim
1576
+ self.c_dim = c_dim
1577
+ self.w_dim = w_dim
1578
+ self.img_resolution = img_resolution
1579
+ self.img_channels = img_channels
1580
+ self.encoder = EncoderNetwork(
1581
+ c_dim=c_dim,
1582
+ z_dim=z_dim,
1583
+ img_resolution=img_resolution,
1584
+ img_channels=img_channels,
1585
+ **encoder_kwargs,
1586
+ )
1587
+ self.synthesis = SynthesisNetwork(
1588
+ z_dim=z_dim,
1589
+ w_dim=w_dim,
1590
+ img_resolution=img_resolution,
1591
+ img_channels=img_channels,
1592
+ **synthesis_kwargs,
1593
+ )
1594
+ self.num_ws = self.synthesis.num_ws
1595
+ self.mapping = MappingNetwork(
1596
+ z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs
1597
+ )
1598
+
1599
+ def forward(
1600
+ self,
1601
+ img,
1602
+ c,
1603
+ fname=None,
1604
+ truncation_psi=1,
1605
+ truncation_cutoff=None,
1606
+ **synthesis_kwargs,
1607
+ ):
1608
+ mask = img[:, -1].unsqueeze(1)
1609
+ x_global, z, feats = self.encoder(img, c)
1610
+ ws = self.mapping(
1611
+ z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff
1612
+ )
1613
+ img = self.synthesis(x_global, mask, feats, ws, fname=fname, **synthesis_kwargs)
1614
+ return img
1615
+
1616
+
1617
+ FCF_MODEL_URL = os.environ.get(
1618
+ "FCF_MODEL_URL",
1619
+ "https://github.com/Sanster/models/releases/download/add_fcf/places_512_G.pth",
1620
+ )
1621
+ FCF_MODEL_MD5 = os.environ.get("FCF_MODEL_MD5", "3323152bc01bf1c56fd8aba74435a211")
1622
+
1623
+
1624
+ class FcF(InpaintModel):
1625
+ name = "fcf"
1626
+ min_size = 512
1627
+ pad_mod = 512
1628
+ pad_to_square = True
1629
+ is_erase_model = True
1630
+
1631
+ def init_model(self, device, **kwargs):
1632
+ seed = 0
1633
+ random.seed(seed)
1634
+ np.random.seed(seed)
1635
+ torch.manual_seed(seed)
1636
+ torch.cuda.manual_seed_all(seed)
1637
+ torch.backends.cudnn.deterministic = True
1638
+ torch.backends.cudnn.benchmark = False
1639
+
1640
+ kwargs = {
1641
+ "channel_base": 1 * 32768,
1642
+ "channel_max": 512,
1643
+ "num_fp16_res": 4,
1644
+ "conv_clamp": 256,
1645
+ }
1646
+ G = Generator(
1647
+ z_dim=512,
1648
+ c_dim=0,
1649
+ w_dim=512,
1650
+ img_resolution=512,
1651
+ img_channels=3,
1652
+ synthesis_kwargs=kwargs,
1653
+ encoder_kwargs=kwargs,
1654
+ mapping_kwargs={"num_layers": 2},
1655
+ )
1656
+ self.model = load_model(G, FCF_MODEL_URL, device, FCF_MODEL_MD5)
1657
+ self.label = torch.zeros([1, self.model.c_dim], device=device)
1658
+
1659
+ @staticmethod
1660
+ def download():
1661
+ download_model(FCF_MODEL_URL, FCF_MODEL_MD5)
1662
+
1663
+ @staticmethod
1664
+ def is_downloaded() -> bool:
1665
+ return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL))
1666
+
1667
+ @torch.no_grad()
1668
+ def __call__(self, image, mask, config: InpaintRequest):
1669
+ """
1670
+ images: [H, W, C] RGB, not normalized
1671
+ masks: [H, W]
1672
+ return: BGR IMAGE
1673
+ """
1674
+ if image.shape[0] == 512 and image.shape[1] == 512:
1675
+ return self._pad_forward(image, mask, config)
1676
+
1677
+ boxes = boxes_from_mask(mask)
1678
+ crop_result = []
1679
+ config.hd_strategy_crop_margin = 128
1680
+ for box in boxes:
1681
+ crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config)
1682
+ origin_size = crop_image.shape[:2]
1683
+ resize_image = resize_max_size(crop_image, size_limit=512)
1684
+ resize_mask = resize_max_size(crop_mask, size_limit=512)
1685
+ inpaint_result = self._pad_forward(resize_image, resize_mask, config)
1686
+
1687
+ # only paste masked area result
1688
+ inpaint_result = cv2.resize(
1689
+ inpaint_result,
1690
+ (origin_size[1], origin_size[0]),
1691
+ interpolation=cv2.INTER_CUBIC,
1692
+ )
1693
+
1694
+ original_pixel_indices = crop_mask < 127
1695
+ inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][
1696
+ original_pixel_indices
1697
+ ]
1698
+
1699
+ crop_result.append((inpaint_result, crop_box))
1700
+
1701
+ inpaint_result = image[:, :, ::-1].copy()
1702
+ for crop_image, crop_box in crop_result:
1703
+ x1, y1, x2, y2 = crop_box
1704
+ inpaint_result[y1:y2, x1:x2, :] = crop_image
1705
+
1706
+ return inpaint_result
1707
+
1708
+ def forward(self, image, mask, config: InpaintRequest):
1709
+ """Input images and output images have same size
1710
+ images: [H, W, C] RGB
1711
+ masks: [H, W] mask area == 255
1712
+ return: BGR IMAGE
1713
+ """
1714
+
1715
+ image = norm_img(image) # [0, 1]
1716
+ image = image * 2 - 1 # [0, 1] -> [-1, 1]
1717
+ mask = (mask > 120) * 255
1718
+ mask = norm_img(mask)
1719
+
1720
+ image = torch.from_numpy(image).unsqueeze(0).to(self.device)
1721
+ mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
1722
+
1723
+ erased_img = image * (1 - mask)
1724
+ input_image = torch.cat([0.5 - mask, erased_img], dim=1)
1725
+
1726
+ output = self.model(
1727
+ input_image, self.label, truncation_psi=0.1, noise_mode="none"
1728
+ )
1729
+ output = (
1730
+ (output.permute(0, 2, 3, 1) * 127.5 + 127.5)
1731
+ .round()
1732
+ .clamp(0, 255)
1733
+ .to(torch.uint8)
1734
+ )
1735
+ output = output[0].cpu().numpy()
1736
+ cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
1737
+ return cur_res
iopaint/model/helper/controlnet_preprocess.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import PIL
3
+ import cv2
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
+ from iopaint.helper import pad_img_to_modulo
8
+
9
+
10
+ def make_canny_control_image(image: np.ndarray) -> Image:
11
+ canny_image = cv2.Canny(image, 100, 200)
12
+ canny_image = canny_image[:, :, None]
13
+ canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
14
+ canny_image = PIL.Image.fromarray(canny_image)
15
+ control_image = canny_image
16
+ return control_image
17
+
18
+
19
+ def make_openpose_control_image(image: np.ndarray) -> Image:
20
+ from controlnet_aux import OpenposeDetector
21
+
22
+ processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
23
+ control_image = processor(image, hand_and_face=True)
24
+ return control_image
25
+
26
+
27
+ def resize_image(input_image, resolution):
28
+ H, W, C = input_image.shape
29
+ H = float(H)
30
+ W = float(W)
31
+ k = float(resolution) / min(H, W)
32
+ H *= k
33
+ W *= k
34
+ H = int(np.round(H / 64.0)) * 64
35
+ W = int(np.round(W / 64.0)) * 64
36
+ img = cv2.resize(
37
+ input_image,
38
+ (W, H),
39
+ interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA,
40
+ )
41
+ return img
42
+
43
+
44
+ def make_depth_control_image(image: np.ndarray) -> Image:
45
+ from controlnet_aux import MidasDetector
46
+
47
+ midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
48
+
49
+ origin_height, origin_width = image.shape[:2]
50
+ pad_image = pad_img_to_modulo(image, mod=64, square=False, min_size=512)
51
+ depth_image = midas(pad_image)
52
+ depth_image = depth_image[0:origin_height, 0:origin_width]
53
+ depth_image = depth_image[:, :, None]
54
+ depth_image = np.concatenate([depth_image, depth_image, depth_image], axis=2)
55
+ control_image = PIL.Image.fromarray(depth_image)
56
+ return control_image
57
+
58
+
59
+ def make_inpaint_control_image(image: np.ndarray, mask: np.ndarray) -> torch.Tensor:
60
+ """
61
+ image: [H, W, C] RGB
62
+ mask: [H, W, 1] 255 means area to repaint
63
+ """
64
+ image = image.astype(np.float32) / 255.0
65
+ image[mask[:, :, -1] > 128] = -1.0 # set as masked pixel
66
+ image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
67
+ image = torch.from_numpy(image)
68
+ return image
iopaint/model/helper/cpu_text_encoder.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import PreTrainedModel
3
+
4
+ from ..utils import torch_gc
5
+
6
+
7
+ class CPUTextEncoderWrapper(PreTrainedModel):
8
+ def __init__(self, text_encoder, torch_dtype):
9
+ super().__init__(text_encoder.config)
10
+ self.config = text_encoder.config
11
+ self._device = text_encoder.device
12
+ # cpu not support float16
13
+ self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
14
+ self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
15
+ self.torch_dtype = torch_dtype
16
+ del text_encoder
17
+ torch_gc()
18
+
19
+ def __call__(self, x, **kwargs):
20
+ input_device = x.device
21
+ original_output = self.text_encoder(x.to(self.text_encoder.device), **kwargs)
22
+ for k, v in original_output.items():
23
+ if isinstance(v, tuple):
24
+ original_output[k] = [
25
+ v[i].to(input_device).to(self.torch_dtype) for i in range(len(v))
26
+ ]
27
+ else:
28
+ original_output[k] = v.to(input_device).to(self.torch_dtype)
29
+ return original_output
30
+
31
+ @property
32
+ def dtype(self):
33
+ return self.torch_dtype
34
+
35
+ @property
36
+ def device(self) -> torch.device:
37
+ """
38
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
39
+ device).
40
+ """
41
+ return self._device
iopaint/model/helper/g_diffuser_bot.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code copy from: https://github.com/parlance-zz/g-diffuser-bot
2
+ import cv2
3
+ import numpy as np
4
+
5
+
6
+ def np_img_grey_to_rgb(data):
7
+ if data.ndim == 3:
8
+ return data
9
+ return np.expand_dims(data, 2) * np.ones((1, 1, 3))
10
+
11
+
12
+ def convolve(data1, data2): # fast convolution with fft
13
+ if data1.ndim != data2.ndim: # promote to rgb if mismatch
14
+ if data1.ndim < 3:
15
+ data1 = np_img_grey_to_rgb(data1)
16
+ if data2.ndim < 3:
17
+ data2 = np_img_grey_to_rgb(data2)
18
+ return ifft2(fft2(data1) * fft2(data2))
19
+
20
+
21
+ def fft2(data):
22
+ if data.ndim > 2: # multiple channels
23
+ out_fft = np.zeros(
24
+ (data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128
25
+ )
26
+ for c in range(data.shape[2]):
27
+ c_data = data[:, :, c]
28
+ out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho")
29
+ out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
30
+ else: # single channel
31
+ out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
32
+ out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
33
+ out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
34
+
35
+ return out_fft
36
+
37
+
38
+ def ifft2(data):
39
+ if data.ndim > 2: # multiple channels
40
+ out_ifft = np.zeros(
41
+ (data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128
42
+ )
43
+ for c in range(data.shape[2]):
44
+ c_data = data[:, :, c]
45
+ out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho")
46
+ out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
47
+ else: # single channel
48
+ out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
49
+ out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
50
+ out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
51
+
52
+ return out_ifft
53
+
54
+
55
+ def get_gradient_kernel(width, height, std=3.14, mode="linear"):
56
+ window_scale_x = float(
57
+ width / min(width, height)
58
+ ) # for non-square aspect ratios we still want a circular kernel
59
+ window_scale_y = float(height / min(width, height))
60
+ if mode == "gaussian":
61
+ x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
62
+ kx = np.exp(-x * x * std)
63
+ if window_scale_x != window_scale_y:
64
+ y = (np.arange(height) / height * 2.0 - 1.0) * window_scale_y
65
+ ky = np.exp(-y * y * std)
66
+ else:
67
+ y = x
68
+ ky = kx
69
+ return np.outer(kx, ky)
70
+ elif mode == "linear":
71
+ x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
72
+ if window_scale_x != window_scale_y:
73
+ y = (np.arange(height) / height * 2.0 - 1.0) * window_scale_y
74
+ else:
75
+ y = x
76
+ return np.clip(1.0 - np.sqrt(np.add.outer(x * x, y * y)) * std / 3.14, 0.0, 1.0)
77
+ else:
78
+ raise Exception("Error: Unknown mode in get_gradient_kernel: {0}".format(mode))
79
+
80
+
81
+ def image_blur(data, std=3.14, mode="linear"):
82
+ width = data.shape[0]
83
+ height = data.shape[1]
84
+ kernel = get_gradient_kernel(width, height, std, mode=mode)
85
+ return np.real(convolve(data, kernel / np.sqrt(np.sum(kernel * kernel))))
86
+
87
+
88
+ def soften_mask(mask_img, softness, space):
89
+ if softness == 0:
90
+ return mask_img
91
+ softness = min(softness, 1.0)
92
+ space = np.clip(space, 0.0, 1.0)
93
+ original_max_opacity = np.max(mask_img)
94
+ out_mask = mask_img <= 0.0
95
+ blurred_mask = image_blur(mask_img, 3.5 / softness, mode="linear")
96
+ blurred_mask = np.maximum(blurred_mask - np.max(blurred_mask[out_mask]), 0.0)
97
+ mask_img *= blurred_mask # preserve partial opacity in original input mask
98
+ mask_img /= np.max(mask_img) # renormalize
99
+ mask_img = np.clip(mask_img - space, 0.0, 1.0) # make space
100
+ mask_img /= np.max(mask_img) # and renormalize again
101
+ mask_img *= original_max_opacity # restore original max opacity
102
+ return mask_img
103
+
104
+
105
+ def expand_image(
106
+ cv2_img, top: int, right: int, bottom: int, left: int, softness: float, space: float
107
+ ):
108
+ assert cv2_img.shape[2] == 3
109
+ origin_h, origin_w = cv2_img.shape[:2]
110
+ new_width = cv2_img.shape[1] + left + right
111
+ new_height = cv2_img.shape[0] + top + bottom
112
+
113
+ # TODO: which is better?
114
+ # new_img = np.random.randint(0, 255, (new_height, new_width, 3), np.uint8)
115
+ new_img = cv2.copyMakeBorder(
116
+ cv2_img, top, bottom, left, right, cv2.BORDER_REPLICATE
117
+ )
118
+ mask_img = np.zeros((new_height, new_width), np.uint8)
119
+ mask_img[top : top + cv2_img.shape[0], left : left + cv2_img.shape[1]] = 255
120
+
121
+ if softness > 0.0:
122
+ mask_img = soften_mask(mask_img / 255.0, softness / 100.0, space / 100.0)
123
+ mask_img = (np.clip(mask_img, 0.0, 1.0) * 255.0).astype(np.uint8)
124
+
125
+ mask_image = 255.0 - mask_img # extract mask from alpha channel and invert
126
+ rgb_init_image = (
127
+ 0.0 + new_img[:, :, 0:3]
128
+ ) # strip mask from init_img leaving only rgb channels
129
+
130
+ hard_mask = np.zeros_like(cv2_img[:, :, 0])
131
+ if top != 0:
132
+ hard_mask[0 : origin_h // 2, :] = 255
133
+ if bottom != 0:
134
+ hard_mask[origin_h // 2 :, :] = 255
135
+ if left != 0:
136
+ hard_mask[:, 0 : origin_w // 2] = 255
137
+ if right != 0:
138
+ hard_mask[:, origin_w // 2 :] = 255
139
+
140
+ hard_mask = cv2.copyMakeBorder(
141
+ hard_mask, top, bottom, left, right, cv2.BORDER_DEFAULT, value=255
142
+ )
143
+ mask_image = np.where(hard_mask > 0, mask_image, 0)
144
+ return rgb_init_image.astype(np.uint8), mask_image.astype(np.uint8)
145
+
146
+
147
+ if __name__ == "__main__":
148
+ from pathlib import Path
149
+
150
+ current_dir = Path(__file__).parent.absolute().resolve()
151
+ image_path = current_dir.parent / "tests" / "bunny.jpeg"
152
+ init_image = cv2.imread(str(image_path))
153
+ init_image, mask_image = expand_image(
154
+ init_image,
155
+ top=100,
156
+ right=100,
157
+ bottom=100,
158
+ left=100,
159
+ softness=20,
160
+ space=20,
161
+ )
162
+ print(mask_image.dtype, mask_image.min(), mask_image.max())
163
+ print(init_image.dtype, init_image.min(), init_image.max())
164
+ mask_image = mask_image.astype(np.uint8)
165
+ init_image = init_image.astype(np.uint8)
166
+ cv2.imwrite("expanded_image.png", init_image)
167
+ cv2.imwrite("expanded_mask.png", mask_image)
iopaint/model/instruct_pix2pix.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL.Image
2
+ import cv2
3
+ import torch
4
+ from loguru import logger
5
+
6
+ from iopaint.const import INSTRUCT_PIX2PIX_NAME
7
+ from .base import DiffusionInpaintModel
8
+ from iopaint.schema import InpaintRequest
9
+ from .utils import get_torch_dtype, enable_low_mem, is_local_files_only
10
+
11
+
12
+ class InstructPix2Pix(DiffusionInpaintModel):
13
+ name = INSTRUCT_PIX2PIX_NAME
14
+ pad_mod = 8
15
+ min_size = 512
16
+
17
+ def init_model(self, device: torch.device, **kwargs):
18
+ from diffusers import StableDiffusionInstructPix2PixPipeline
19
+
20
+ use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
21
+
22
+ model_kwargs = {"local_files_only": is_local_files_only(**kwargs)}
23
+ if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
24
+ logger.info("Disable Stable Diffusion Model NSFW checker")
25
+ model_kwargs.update(
26
+ dict(
27
+ safety_checker=None,
28
+ feature_extractor=None,
29
+ requires_safety_checker=False,
30
+ )
31
+ )
32
+
33
+ self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
34
+ self.name, variant="fp16", torch_dtype=torch_dtype, **model_kwargs
35
+ )
36
+ enable_low_mem(self.model, kwargs.get("low_mem", False))
37
+
38
+ if kwargs.get("cpu_offload", False) and use_gpu:
39
+ logger.info("Enable sequential cpu offload")
40
+ self.model.enable_sequential_cpu_offload(gpu_id=0)
41
+ else:
42
+ self.model = self.model.to(device)
43
+
44
+ def forward(self, image, mask, config: InpaintRequest):
45
+ """Input image and output image have same size
46
+ image: [H, W, C] RGB
47
+ mask: [H, W, 1] 255 means area to repaint
48
+ return: BGR IMAGE
49
+ edit = pipe(prompt, image=image, num_inference_steps=20, image_guidance_scale=1.5, guidance_scale=7).images[0]
50
+ """
51
+ output = self.model(
52
+ image=PIL.Image.fromarray(image),
53
+ prompt=config.prompt,
54
+ negative_prompt=config.negative_prompt,
55
+ num_inference_steps=config.sd_steps,
56
+ image_guidance_scale=config.p2p_image_guidance_scale,
57
+ guidance_scale=config.sd_guidance_scale,
58
+ output_type="np",
59
+ generator=torch.manual_seed(config.sd_seed),
60
+ ).images[0]
61
+
62
+ output = (output * 255).round().astype("uint8")
63
+ output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
64
+ return output
iopaint/model/kandinsky.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL.Image
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+
6
+ from iopaint.const import KANDINSKY22_NAME
7
+ from .base import DiffusionInpaintModel
8
+ from iopaint.schema import InpaintRequest
9
+ from .utils import get_torch_dtype, enable_low_mem, is_local_files_only
10
+
11
+
12
+ class Kandinsky(DiffusionInpaintModel):
13
+ pad_mod = 64
14
+ min_size = 512
15
+
16
+ def init_model(self, device: torch.device, **kwargs):
17
+ from diffusers import AutoPipelineForInpainting
18
+
19
+ use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
20
+
21
+ model_kwargs = {
22
+ "torch_dtype": torch_dtype,
23
+ "local_files_only": is_local_files_only(**kwargs),
24
+ }
25
+ self.model = AutoPipelineForInpainting.from_pretrained(
26
+ self.name, **model_kwargs
27
+ ).to(device)
28
+ enable_low_mem(self.model, kwargs.get("low_mem", False))
29
+
30
+ self.callback = kwargs.pop("callback", None)
31
+
32
+ def forward(self, image, mask, config: InpaintRequest):
33
+ """Input image and output image have same size
34
+ image: [H, W, C] RGB
35
+ mask: [H, W, 1] 255 means area to repaint
36
+ return: BGR IMAGE
37
+ """
38
+ self.set_scheduler(config)
39
+
40
+ generator = torch.manual_seed(config.sd_seed)
41
+ mask = mask.astype(np.float32) / 255
42
+ img_h, img_w = image.shape[:2]
43
+
44
+ # kandinsky 没有 strength
45
+ output = self.model(
46
+ prompt=config.prompt,
47
+ negative_prompt=config.negative_prompt,
48
+ image=PIL.Image.fromarray(image),
49
+ mask_image=mask[:, :, 0],
50
+ height=img_h,
51
+ width=img_w,
52
+ num_inference_steps=config.sd_steps,
53
+ guidance_scale=config.sd_guidance_scale,
54
+ output_type="np",
55
+ callback_on_step_end=self.callback,
56
+ generator=generator,
57
+ ).images[0]
58
+
59
+ output = (output * 255).round().astype("uint8")
60
+ output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
61
+ return output
62
+
63
+
64
+ class Kandinsky22(Kandinsky):
65
+ name = KANDINSKY22_NAME
iopaint/model/lama.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+
7
+ from iopaint.helper import (
8
+ norm_img,
9
+ get_cache_path_by_url,
10
+ load_jit_model,
11
+ download_model,
12
+ )
13
+ from iopaint.schema import InpaintRequest
14
+ from .base import InpaintModel
15
+
16
+ LAMA_MODEL_URL = os.environ.get(
17
+ "LAMA_MODEL_URL",
18
+ "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
19
+ )
20
+ LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e500")
21
+
22
+
23
+ class LaMa(InpaintModel):
24
+ name = "lama"
25
+ pad_mod = 8
26
+ is_erase_model = True
27
+
28
+ @staticmethod
29
+ def download():
30
+ download_model(LAMA_MODEL_URL, LAMA_MODEL_MD5)
31
+
32
+ def init_model(self, device, **kwargs):
33
+ self.model = load_jit_model(LAMA_MODEL_URL, device, LAMA_MODEL_MD5).eval()
34
+
35
+ @staticmethod
36
+ def is_downloaded() -> bool:
37
+ return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))
38
+
39
+ def forward(self, image, mask, config: InpaintRequest):
40
+ """Input image and output image have same size
41
+ image: [H, W, C] RGB
42
+ mask: [H, W]
43
+ return: BGR IMAGE
44
+ """
45
+ image = norm_img(image)
46
+ mask = norm_img(mask)
47
+
48
+ mask = (mask > 0) * 1
49
+ image = torch.from_numpy(image).unsqueeze(0).to(self.device)
50
+ mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
51
+
52
+ inpainted_image = self.model(image, mask)
53
+
54
+ cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
55
+ cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
56
+ cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
57
+ return cur_res
iopaint/model/ldm.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ from loguru import logger
6
+
7
+ from .base import InpaintModel
8
+ from .ddim_sampler import DDIMSampler
9
+ from .plms_sampler import PLMSSampler
10
+ from iopaint.schema import InpaintRequest, LDMSampler
11
+
12
+ torch.manual_seed(42)
13
+ import torch.nn as nn
14
+ from iopaint.helper import (
15
+ download_model,
16
+ norm_img,
17
+ get_cache_path_by_url,
18
+ load_jit_model,
19
+ )
20
+ from .utils import (
21
+ make_beta_schedule,
22
+ timestep_embedding,
23
+ )
24
+
25
+ LDM_ENCODE_MODEL_URL = os.environ.get(
26
+ "LDM_ENCODE_MODEL_URL",
27
+ "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt",
28
+ )
29
+ LDM_ENCODE_MODEL_MD5 = os.environ.get(
30
+ "LDM_ENCODE_MODEL_MD5", "23239fc9081956a3e70de56472b3f296"
31
+ )
32
+
33
+ LDM_DECODE_MODEL_URL = os.environ.get(
34
+ "LDM_DECODE_MODEL_URL",
35
+ "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt",
36
+ )
37
+ LDM_DECODE_MODEL_MD5 = os.environ.get(
38
+ "LDM_DECODE_MODEL_MD5", "fe419cd15a750d37a4733589d0d3585c"
39
+ )
40
+
41
+ LDM_DIFFUSION_MODEL_URL = os.environ.get(
42
+ "LDM_DIFFUSION_MODEL_URL",
43
+ "https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt",
44
+ )
45
+
46
+ LDM_DIFFUSION_MODEL_MD5 = os.environ.get(
47
+ "LDM_DIFFUSION_MODEL_MD5", "b0afda12bf790c03aba2a7431f11d22d"
48
+ )
49
+
50
+
51
+ class DDPM(nn.Module):
52
+ # classic DDPM with Gaussian diffusion, in image space
53
+ def __init__(
54
+ self,
55
+ device,
56
+ timesteps=1000,
57
+ beta_schedule="linear",
58
+ linear_start=0.0015,
59
+ linear_end=0.0205,
60
+ cosine_s=0.008,
61
+ original_elbo_weight=0.0,
62
+ v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
63
+ l_simple_weight=1.0,
64
+ parameterization="eps", # all assuming fixed variance schedules
65
+ use_positional_encodings=False,
66
+ ):
67
+ super().__init__()
68
+ self.device = device
69
+ self.parameterization = parameterization
70
+ self.use_positional_encodings = use_positional_encodings
71
+
72
+ self.v_posterior = v_posterior
73
+ self.original_elbo_weight = original_elbo_weight
74
+ self.l_simple_weight = l_simple_weight
75
+
76
+ self.register_schedule(
77
+ beta_schedule=beta_schedule,
78
+ timesteps=timesteps,
79
+ linear_start=linear_start,
80
+ linear_end=linear_end,
81
+ cosine_s=cosine_s,
82
+ )
83
+
84
+ def register_schedule(
85
+ self,
86
+ given_betas=None,
87
+ beta_schedule="linear",
88
+ timesteps=1000,
89
+ linear_start=1e-4,
90
+ linear_end=2e-2,
91
+ cosine_s=8e-3,
92
+ ):
93
+ betas = make_beta_schedule(
94
+ self.device,
95
+ beta_schedule,
96
+ timesteps,
97
+ linear_start=linear_start,
98
+ linear_end=linear_end,
99
+ cosine_s=cosine_s,
100
+ )
101
+ alphas = 1.0 - betas
102
+ alphas_cumprod = np.cumprod(alphas, axis=0)
103
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
104
+
105
+ (timesteps,) = betas.shape
106
+ self.num_timesteps = int(timesteps)
107
+ self.linear_start = linear_start
108
+ self.linear_end = linear_end
109
+ assert (
110
+ alphas_cumprod.shape[0] == self.num_timesteps
111
+ ), "alphas have to be defined for each timestep"
112
+
113
+ to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
114
+
115
+ self.register_buffer("betas", to_torch(betas))
116
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
117
+ self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
118
+
119
+ # calculations for diffusion q(x_t | x_{t-1}) and others
120
+ self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
121
+ self.register_buffer(
122
+ "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
123
+ )
124
+ self.register_buffer(
125
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
126
+ )
127
+ self.register_buffer(
128
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
129
+ )
130
+ self.register_buffer(
131
+ "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
132
+ )
133
+
134
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
135
+ posterior_variance = (1 - self.v_posterior) * betas * (
136
+ 1.0 - alphas_cumprod_prev
137
+ ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
138
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
139
+ self.register_buffer("posterior_variance", to_torch(posterior_variance))
140
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
141
+ self.register_buffer(
142
+ "posterior_log_variance_clipped",
143
+ to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
144
+ )
145
+ self.register_buffer(
146
+ "posterior_mean_coef1",
147
+ to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
148
+ )
149
+ self.register_buffer(
150
+ "posterior_mean_coef2",
151
+ to_torch(
152
+ (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
153
+ ),
154
+ )
155
+
156
+ if self.parameterization == "eps":
157
+ lvlb_weights = self.betas**2 / (
158
+ 2
159
+ * self.posterior_variance
160
+ * to_torch(alphas)
161
+ * (1 - self.alphas_cumprod)
162
+ )
163
+ elif self.parameterization == "x0":
164
+ lvlb_weights = (
165
+ 0.5
166
+ * np.sqrt(torch.Tensor(alphas_cumprod))
167
+ / (2.0 * 1 - torch.Tensor(alphas_cumprod))
168
+ )
169
+ else:
170
+ raise NotImplementedError("mu not supported")
171
+ # TODO how to choose this term
172
+ lvlb_weights[0] = lvlb_weights[1]
173
+ self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
174
+ assert not torch.isnan(self.lvlb_weights).all()
175
+
176
+
177
+ class LatentDiffusion(DDPM):
178
+ def __init__(
179
+ self,
180
+ diffusion_model,
181
+ device,
182
+ cond_stage_key="image",
183
+ cond_stage_trainable=False,
184
+ concat_mode=True,
185
+ scale_factor=1.0,
186
+ scale_by_std=False,
187
+ *args,
188
+ **kwargs,
189
+ ):
190
+ self.num_timesteps_cond = 1
191
+ self.scale_by_std = scale_by_std
192
+ super().__init__(device, *args, **kwargs)
193
+ self.diffusion_model = diffusion_model
194
+ self.concat_mode = concat_mode
195
+ self.cond_stage_trainable = cond_stage_trainable
196
+ self.cond_stage_key = cond_stage_key
197
+ self.num_downs = 2
198
+ self.scale_factor = scale_factor
199
+
200
+ def make_cond_schedule(
201
+ self,
202
+ ):
203
+ self.cond_ids = torch.full(
204
+ size=(self.num_timesteps,),
205
+ fill_value=self.num_timesteps - 1,
206
+ dtype=torch.long,
207
+ )
208
+ ids = torch.round(
209
+ torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
210
+ ).long()
211
+ self.cond_ids[: self.num_timesteps_cond] = ids
212
+
213
+ def register_schedule(
214
+ self,
215
+ given_betas=None,
216
+ beta_schedule="linear",
217
+ timesteps=1000,
218
+ linear_start=1e-4,
219
+ linear_end=2e-2,
220
+ cosine_s=8e-3,
221
+ ):
222
+ super().register_schedule(
223
+ given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
224
+ )
225
+
226
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
227
+ if self.shorten_cond_schedule:
228
+ self.make_cond_schedule()
229
+
230
+ def apply_model(self, x_noisy, t, cond):
231
+ # x_recon = self.model(x_noisy, t, cond['c_concat'][0]) # cond['c_concat'][0].shape 1,4,128,128
232
+ t_emb = timestep_embedding(x_noisy.device, t, 256, repeat_only=False)
233
+ x_recon = self.diffusion_model(x_noisy, t_emb, cond)
234
+ return x_recon
235
+
236
+
237
+ class LDM(InpaintModel):
238
+ name = "ldm"
239
+ pad_mod = 32
240
+ is_erase_model = True
241
+
242
+ def __init__(self, device, fp16: bool = True, **kwargs):
243
+ self.fp16 = fp16
244
+ super().__init__(device)
245
+ self.device = device
246
+
247
+ def init_model(self, device, **kwargs):
248
+ self.diffusion_model = load_jit_model(
249
+ LDM_DIFFUSION_MODEL_URL, device, LDM_DIFFUSION_MODEL_MD5
250
+ )
251
+ self.cond_stage_model_decode = load_jit_model(
252
+ LDM_DECODE_MODEL_URL, device, LDM_DECODE_MODEL_MD5
253
+ )
254
+ self.cond_stage_model_encode = load_jit_model(
255
+ LDM_ENCODE_MODEL_URL, device, LDM_ENCODE_MODEL_MD5
256
+ )
257
+ if self.fp16 and "cuda" in str(device):
258
+ self.diffusion_model = self.diffusion_model.half()
259
+ self.cond_stage_model_decode = self.cond_stage_model_decode.half()
260
+ self.cond_stage_model_encode = self.cond_stage_model_encode.half()
261
+
262
+ self.model = LatentDiffusion(self.diffusion_model, device)
263
+
264
+ @staticmethod
265
+ def download():
266
+ download_model(LDM_DIFFUSION_MODEL_URL, LDM_DIFFUSION_MODEL_MD5)
267
+ download_model(LDM_DECODE_MODEL_URL, LDM_DECODE_MODEL_MD5)
268
+ download_model(LDM_ENCODE_MODEL_URL, LDM_ENCODE_MODEL_MD5)
269
+
270
+ @staticmethod
271
+ def is_downloaded() -> bool:
272
+ model_paths = [
273
+ get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL),
274
+ get_cache_path_by_url(LDM_DECODE_MODEL_URL),
275
+ get_cache_path_by_url(LDM_ENCODE_MODEL_URL),
276
+ ]
277
+ return all([os.path.exists(it) for it in model_paths])
278
+
279
+ @torch.cuda.amp.autocast()
280
+ def forward(self, image, mask, config: InpaintRequest):
281
+ """
282
+ image: [H, W, C] RGB
283
+ mask: [H, W, 1]
284
+ return: BGR IMAGE
285
+ """
286
+ # image [1,3,512,512] float32
287
+ # mask: [1,1,512,512] float32
288
+ # masked_image: [1,3,512,512] float32
289
+ if config.ldm_sampler == LDMSampler.ddim:
290
+ sampler = DDIMSampler(self.model)
291
+ elif config.ldm_sampler == LDMSampler.plms:
292
+ sampler = PLMSSampler(self.model)
293
+ else:
294
+ raise ValueError()
295
+
296
+ steps = config.ldm_steps
297
+ image = norm_img(image)
298
+ mask = norm_img(mask)
299
+
300
+ mask[mask < 0.5] = 0
301
+ mask[mask >= 0.5] = 1
302
+
303
+ image = torch.from_numpy(image).unsqueeze(0).to(self.device)
304
+ mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
305
+ masked_image = (1 - mask) * image
306
+
307
+ mask = self._norm(mask)
308
+ masked_image = self._norm(masked_image)
309
+
310
+ c = self.cond_stage_model_encode(masked_image)
311
+ torch.cuda.empty_cache()
312
+
313
+ cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:]) # 1,1,128,128
314
+ c = torch.cat((c, cc), dim=1) # 1,4,128,128
315
+
316
+ shape = (c.shape[1] - 1,) + c.shape[2:]
317
+ samples_ddim = sampler.sample(
318
+ steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
319
+ )
320
+ torch.cuda.empty_cache()
321
+ x_samples_ddim = self.cond_stage_model_decode(
322
+ samples_ddim
323
+ ) # samples_ddim: 1, 3, 128, 128 float32
324
+ torch.cuda.empty_cache()
325
+
326
+ # image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
327
+ # mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
328
+ inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
329
+
330
+ # inpainted = (1 - mask) * image + mask * predicted_image
331
+ inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
332
+ inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
333
+ return inpainted_image
334
+
335
+ def _norm(self, tensor):
336
+ return tensor * 2.0 - 1.0
iopaint/model/manga.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ import time
8
+ from loguru import logger
9
+
10
+ from iopaint.helper import get_cache_path_by_url, load_jit_model, download_model
11
+ from .base import InpaintModel
12
+ from iopaint.schema import InpaintRequest
13
+
14
+
15
+ MANGA_INPAINTOR_MODEL_URL = os.environ.get(
16
+ "MANGA_INPAINTOR_MODEL_URL",
17
+ "https://github.com/Sanster/models/releases/download/manga/manga_inpaintor.jit",
18
+ )
19
+ MANGA_INPAINTOR_MODEL_MD5 = os.environ.get(
20
+ "MANGA_INPAINTOR_MODEL_MD5", "7d8b269c4613b6b3768af714610da86c"
21
+ )
22
+
23
+ MANGA_LINE_MODEL_URL = os.environ.get(
24
+ "MANGA_LINE_MODEL_URL",
25
+ "https://github.com/Sanster/models/releases/download/manga/erika.jit",
26
+ )
27
+ MANGA_LINE_MODEL_MD5 = os.environ.get(
28
+ "MANGA_LINE_MODEL_MD5", "0c926d5a4af8450b0d00bc5b9a095644"
29
+ )
30
+
31
+
32
+ class Manga(InpaintModel):
33
+ name = "manga"
34
+ pad_mod = 16
35
+ is_erase_model = True
36
+
37
+ def init_model(self, device, **kwargs):
38
+ self.inpaintor_model = load_jit_model(
39
+ MANGA_INPAINTOR_MODEL_URL, device, MANGA_INPAINTOR_MODEL_MD5
40
+ )
41
+ self.line_model = load_jit_model(
42
+ MANGA_LINE_MODEL_URL, device, MANGA_LINE_MODEL_MD5
43
+ )
44
+ self.seed = 42
45
+
46
+ @staticmethod
47
+ def download():
48
+ download_model(MANGA_INPAINTOR_MODEL_URL, MANGA_INPAINTOR_MODEL_MD5)
49
+ download_model(MANGA_LINE_MODEL_URL, MANGA_LINE_MODEL_MD5)
50
+
51
+ @staticmethod
52
+ def is_downloaded() -> bool:
53
+ model_paths = [
54
+ get_cache_path_by_url(MANGA_INPAINTOR_MODEL_URL),
55
+ get_cache_path_by_url(MANGA_LINE_MODEL_URL),
56
+ ]
57
+ return all([os.path.exists(it) for it in model_paths])
58
+
59
+ def forward(self, image, mask, config: InpaintRequest):
60
+ """
61
+ image: [H, W, C] RGB
62
+ mask: [H, W, 1]
63
+ return: BGR IMAGE
64
+ """
65
+ seed = self.seed
66
+ random.seed(seed)
67
+ np.random.seed(seed)
68
+ torch.manual_seed(seed)
69
+ torch.cuda.manual_seed_all(seed)
70
+
71
+ gray_img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
72
+ gray_img = torch.from_numpy(
73
+ gray_img[np.newaxis, np.newaxis, :, :].astype(np.float32)
74
+ ).to(self.device)
75
+ start = time.time()
76
+ lines = self.line_model(gray_img)
77
+ torch.cuda.empty_cache()
78
+ lines = torch.clamp(lines, 0, 255)
79
+ logger.info(f"erika_model time: {time.time() - start}")
80
+
81
+ mask = torch.from_numpy(mask[np.newaxis, :, :, :]).to(self.device)
82
+ mask = mask.permute(0, 3, 1, 2)
83
+ mask = torch.where(mask > 0.5, 1.0, 0.0)
84
+ noise = torch.randn_like(mask)
85
+ ones = torch.ones_like(mask)
86
+
87
+ gray_img = gray_img / 255 * 2 - 1.0
88
+ lines = lines / 255 * 2 - 1.0
89
+
90
+ start = time.time()
91
+ inpainted_image = self.inpaintor_model(gray_img, lines, mask, noise, ones)
92
+ logger.info(f"image_inpaintor_model time: {time.time() - start}")
93
+
94
+ cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
95
+ cur_res = (cur_res * 127.5 + 127.5).astype(np.uint8)
96
+ cur_res = cv2.cvtColor(cur_res, cv2.COLOR_GRAY2BGR)
97
+ return cur_res
iopaint/model/mat.py ADDED
@@ -0,0 +1,1945 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.utils.checkpoint as checkpoint
10
+
11
+ from iopaint.helper import (
12
+ load_model,
13
+ get_cache_path_by_url,
14
+ norm_img,
15
+ download_model,
16
+ )
17
+ from iopaint.schema import InpaintRequest
18
+ from .base import InpaintModel
19
+ from .utils import (
20
+ setup_filter,
21
+ Conv2dLayer,
22
+ FullyConnectedLayer,
23
+ conv2d_resample,
24
+ bias_act,
25
+ upsample2d,
26
+ activation_funcs,
27
+ MinibatchStdLayer,
28
+ to_2tuple,
29
+ normalize_2nd_moment,
30
+ set_seed,
31
+ )
32
+
33
+
34
+ class ModulatedConv2d(nn.Module):
35
+ def __init__(
36
+ self,
37
+ in_channels, # Number of input channels.
38
+ out_channels, # Number of output channels.
39
+ kernel_size, # Width and height of the convolution kernel.
40
+ style_dim, # dimension of the style code
41
+ demodulate=True, # perfrom demodulation
42
+ up=1, # Integer upsampling factor.
43
+ down=1, # Integer downsampling factor.
44
+ resample_filter=[
45
+ 1,
46
+ 3,
47
+ 3,
48
+ 1,
49
+ ], # Low-pass filter to apply when resampling activations.
50
+ conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
51
+ ):
52
+ super().__init__()
53
+ self.demodulate = demodulate
54
+
55
+ self.weight = torch.nn.Parameter(
56
+ torch.randn([1, out_channels, in_channels, kernel_size, kernel_size])
57
+ )
58
+ self.out_channels = out_channels
59
+ self.kernel_size = kernel_size
60
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
61
+ self.padding = self.kernel_size // 2
62
+ self.up = up
63
+ self.down = down
64
+ self.register_buffer("resample_filter", setup_filter(resample_filter))
65
+ self.conv_clamp = conv_clamp
66
+
67
+ self.affine = FullyConnectedLayer(style_dim, in_channels, bias_init=1)
68
+
69
+ def forward(self, x, style):
70
+ batch, in_channels, height, width = x.shape
71
+ style = self.affine(style).view(batch, 1, in_channels, 1, 1)
72
+ weight = self.weight * self.weight_gain * style
73
+
74
+ if self.demodulate:
75
+ decoefs = (weight.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt()
76
+ weight = weight * decoefs.view(batch, self.out_channels, 1, 1, 1)
77
+
78
+ weight = weight.view(
79
+ batch * self.out_channels, in_channels, self.kernel_size, self.kernel_size
80
+ )
81
+ x = x.view(1, batch * in_channels, height, width)
82
+ x = conv2d_resample(
83
+ x=x,
84
+ w=weight,
85
+ f=self.resample_filter,
86
+ up=self.up,
87
+ down=self.down,
88
+ padding=self.padding,
89
+ groups=batch,
90
+ )
91
+ out = x.view(batch, self.out_channels, *x.shape[2:])
92
+
93
+ return out
94
+
95
+
96
+ class StyleConv(torch.nn.Module):
97
+ def __init__(
98
+ self,
99
+ in_channels, # Number of input channels.
100
+ out_channels, # Number of output channels.
101
+ style_dim, # Intermediate latent (W) dimensionality.
102
+ resolution, # Resolution of this layer.
103
+ kernel_size=3, # Convolution kernel size.
104
+ up=1, # Integer upsampling factor.
105
+ use_noise=False, # Enable noise input?
106
+ activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
107
+ resample_filter=[
108
+ 1,
109
+ 3,
110
+ 3,
111
+ 1,
112
+ ], # Low-pass filter to apply when resampling activations.
113
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
114
+ demodulate=True, # perform demodulation
115
+ ):
116
+ super().__init__()
117
+
118
+ self.conv = ModulatedConv2d(
119
+ in_channels=in_channels,
120
+ out_channels=out_channels,
121
+ kernel_size=kernel_size,
122
+ style_dim=style_dim,
123
+ demodulate=demodulate,
124
+ up=up,
125
+ resample_filter=resample_filter,
126
+ conv_clamp=conv_clamp,
127
+ )
128
+
129
+ self.use_noise = use_noise
130
+ self.resolution = resolution
131
+ if use_noise:
132
+ self.register_buffer("noise_const", torch.randn([resolution, resolution]))
133
+ self.noise_strength = torch.nn.Parameter(torch.zeros([]))
134
+
135
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
136
+ self.activation = activation
137
+ self.act_gain = activation_funcs[activation].def_gain
138
+ self.conv_clamp = conv_clamp
139
+
140
+ def forward(self, x, style, noise_mode="random", gain=1):
141
+ x = self.conv(x, style)
142
+
143
+ assert noise_mode in ["random", "const", "none"]
144
+
145
+ if self.use_noise:
146
+ if noise_mode == "random":
147
+ xh, xw = x.size()[-2:]
148
+ noise = (
149
+ torch.randn([x.shape[0], 1, xh, xw], device=x.device)
150
+ * self.noise_strength
151
+ )
152
+ if noise_mode == "const":
153
+ noise = self.noise_const * self.noise_strength
154
+ x = x + noise
155
+
156
+ act_gain = self.act_gain * gain
157
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
158
+ out = bias_act(
159
+ x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp
160
+ )
161
+
162
+ return out
163
+
164
+
165
+ class ToRGB(torch.nn.Module):
166
+ def __init__(
167
+ self,
168
+ in_channels,
169
+ out_channels,
170
+ style_dim,
171
+ kernel_size=1,
172
+ resample_filter=[1, 3, 3, 1],
173
+ conv_clamp=None,
174
+ demodulate=False,
175
+ ):
176
+ super().__init__()
177
+
178
+ self.conv = ModulatedConv2d(
179
+ in_channels=in_channels,
180
+ out_channels=out_channels,
181
+ kernel_size=kernel_size,
182
+ style_dim=style_dim,
183
+ demodulate=demodulate,
184
+ resample_filter=resample_filter,
185
+ conv_clamp=conv_clamp,
186
+ )
187
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
188
+ self.register_buffer("resample_filter", setup_filter(resample_filter))
189
+ self.conv_clamp = conv_clamp
190
+
191
+ def forward(self, x, style, skip=None):
192
+ x = self.conv(x, style)
193
+ out = bias_act(x, self.bias, clamp=self.conv_clamp)
194
+
195
+ if skip is not None:
196
+ if skip.shape != out.shape:
197
+ skip = upsample2d(skip, self.resample_filter)
198
+ out = out + skip
199
+
200
+ return out
201
+
202
+
203
+ def get_style_code(a, b):
204
+ return torch.cat([a, b], dim=1)
205
+
206
+
207
+ class DecBlockFirst(nn.Module):
208
+ def __init__(
209
+ self,
210
+ in_channels,
211
+ out_channels,
212
+ activation,
213
+ style_dim,
214
+ use_noise,
215
+ demodulate,
216
+ img_channels,
217
+ ):
218
+ super().__init__()
219
+ self.fc = FullyConnectedLayer(
220
+ in_features=in_channels * 2,
221
+ out_features=in_channels * 4**2,
222
+ activation=activation,
223
+ )
224
+ self.conv = StyleConv(
225
+ in_channels=in_channels,
226
+ out_channels=out_channels,
227
+ style_dim=style_dim,
228
+ resolution=4,
229
+ kernel_size=3,
230
+ use_noise=use_noise,
231
+ activation=activation,
232
+ demodulate=demodulate,
233
+ )
234
+ self.toRGB = ToRGB(
235
+ in_channels=out_channels,
236
+ out_channels=img_channels,
237
+ style_dim=style_dim,
238
+ kernel_size=1,
239
+ demodulate=False,
240
+ )
241
+
242
+ def forward(self, x, ws, gs, E_features, noise_mode="random"):
243
+ x = self.fc(x).view(x.shape[0], -1, 4, 4)
244
+ x = x + E_features[2]
245
+ style = get_style_code(ws[:, 0], gs)
246
+ x = self.conv(x, style, noise_mode=noise_mode)
247
+ style = get_style_code(ws[:, 1], gs)
248
+ img = self.toRGB(x, style, skip=None)
249
+
250
+ return x, img
251
+
252
+
253
+ class DecBlockFirstV2(nn.Module):
254
+ def __init__(
255
+ self,
256
+ in_channels,
257
+ out_channels,
258
+ activation,
259
+ style_dim,
260
+ use_noise,
261
+ demodulate,
262
+ img_channels,
263
+ ):
264
+ super().__init__()
265
+ self.conv0 = Conv2dLayer(
266
+ in_channels=in_channels,
267
+ out_channels=in_channels,
268
+ kernel_size=3,
269
+ activation=activation,
270
+ )
271
+ self.conv1 = StyleConv(
272
+ in_channels=in_channels,
273
+ out_channels=out_channels,
274
+ style_dim=style_dim,
275
+ resolution=4,
276
+ kernel_size=3,
277
+ use_noise=use_noise,
278
+ activation=activation,
279
+ demodulate=demodulate,
280
+ )
281
+ self.toRGB = ToRGB(
282
+ in_channels=out_channels,
283
+ out_channels=img_channels,
284
+ style_dim=style_dim,
285
+ kernel_size=1,
286
+ demodulate=False,
287
+ )
288
+
289
+ def forward(self, x, ws, gs, E_features, noise_mode="random"):
290
+ # x = self.fc(x).view(x.shape[0], -1, 4, 4)
291
+ x = self.conv0(x)
292
+ x = x + E_features[2]
293
+ style = get_style_code(ws[:, 0], gs)
294
+ x = self.conv1(x, style, noise_mode=noise_mode)
295
+ style = get_style_code(ws[:, 1], gs)
296
+ img = self.toRGB(x, style, skip=None)
297
+
298
+ return x, img
299
+
300
+
301
+ class DecBlock(nn.Module):
302
+ def __init__(
303
+ self,
304
+ res,
305
+ in_channels,
306
+ out_channels,
307
+ activation,
308
+ style_dim,
309
+ use_noise,
310
+ demodulate,
311
+ img_channels,
312
+ ): # res = 2, ..., resolution_log2
313
+ super().__init__()
314
+ self.res = res
315
+
316
+ self.conv0 = StyleConv(
317
+ in_channels=in_channels,
318
+ out_channels=out_channels,
319
+ style_dim=style_dim,
320
+ resolution=2**res,
321
+ kernel_size=3,
322
+ up=2,
323
+ use_noise=use_noise,
324
+ activation=activation,
325
+ demodulate=demodulate,
326
+ )
327
+ self.conv1 = StyleConv(
328
+ in_channels=out_channels,
329
+ out_channels=out_channels,
330
+ style_dim=style_dim,
331
+ resolution=2**res,
332
+ kernel_size=3,
333
+ use_noise=use_noise,
334
+ activation=activation,
335
+ demodulate=demodulate,
336
+ )
337
+ self.toRGB = ToRGB(
338
+ in_channels=out_channels,
339
+ out_channels=img_channels,
340
+ style_dim=style_dim,
341
+ kernel_size=1,
342
+ demodulate=False,
343
+ )
344
+
345
+ def forward(self, x, img, ws, gs, E_features, noise_mode="random"):
346
+ style = get_style_code(ws[:, self.res * 2 - 5], gs)
347
+ x = self.conv0(x, style, noise_mode=noise_mode)
348
+ x = x + E_features[self.res]
349
+ style = get_style_code(ws[:, self.res * 2 - 4], gs)
350
+ x = self.conv1(x, style, noise_mode=noise_mode)
351
+ style = get_style_code(ws[:, self.res * 2 - 3], gs)
352
+ img = self.toRGB(x, style, skip=img)
353
+
354
+ return x, img
355
+
356
+
357
+ class MappingNet(torch.nn.Module):
358
+ def __init__(
359
+ self,
360
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
361
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
362
+ w_dim, # Intermediate latent (W) dimensionality.
363
+ num_ws, # Number of intermediate latents to output, None = do not broadcast.
364
+ num_layers=8, # Number of mapping layers.
365
+ embed_features=None, # Label embedding dimensionality, None = same as w_dim.
366
+ layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim.
367
+ activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
368
+ lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
369
+ w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track.
370
+ torch_dtype=torch.float32,
371
+ ):
372
+ super().__init__()
373
+ self.z_dim = z_dim
374
+ self.c_dim = c_dim
375
+ self.w_dim = w_dim
376
+ self.num_ws = num_ws
377
+ self.num_layers = num_layers
378
+ self.w_avg_beta = w_avg_beta
379
+ self.torch_dtype = torch_dtype
380
+
381
+ if embed_features is None:
382
+ embed_features = w_dim
383
+ if c_dim == 0:
384
+ embed_features = 0
385
+ if layer_features is None:
386
+ layer_features = w_dim
387
+ features_list = (
388
+ [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
389
+ )
390
+
391
+ if c_dim > 0:
392
+ self.embed = FullyConnectedLayer(c_dim, embed_features)
393
+ for idx in range(num_layers):
394
+ in_features = features_list[idx]
395
+ out_features = features_list[idx + 1]
396
+ layer = FullyConnectedLayer(
397
+ in_features,
398
+ out_features,
399
+ activation=activation,
400
+ lr_multiplier=lr_multiplier,
401
+ )
402
+ setattr(self, f"fc{idx}", layer)
403
+
404
+ if num_ws is not None and w_avg_beta is not None:
405
+ self.register_buffer("w_avg", torch.zeros([w_dim]))
406
+
407
+ def forward(
408
+ self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False
409
+ ):
410
+ # Embed, normalize, and concat inputs.
411
+ x = None
412
+ if self.z_dim > 0:
413
+ x = normalize_2nd_moment(z)
414
+ if self.c_dim > 0:
415
+ y = normalize_2nd_moment(self.embed(c))
416
+ x = torch.cat([x, y], dim=1) if x is not None else y
417
+
418
+ # Main layers.
419
+ for idx in range(self.num_layers):
420
+ layer = getattr(self, f"fc{idx}")
421
+ x = layer(x)
422
+
423
+ # Update moving average of W.
424
+ if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
425
+ self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
426
+
427
+ # Broadcast.
428
+ if self.num_ws is not None:
429
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
430
+
431
+ # Apply truncation.
432
+ if truncation_psi != 1:
433
+ assert self.w_avg_beta is not None
434
+ if self.num_ws is None or truncation_cutoff is None:
435
+ x = self.w_avg.lerp(x, truncation_psi)
436
+ else:
437
+ x[:, :truncation_cutoff] = self.w_avg.lerp(
438
+ x[:, :truncation_cutoff], truncation_psi
439
+ )
440
+
441
+ return x
442
+
443
+
444
+ class DisFromRGB(nn.Module):
445
+ def __init__(
446
+ self, in_channels, out_channels, activation
447
+ ): # res = 2, ..., resolution_log2
448
+ super().__init__()
449
+ self.conv = Conv2dLayer(
450
+ in_channels=in_channels,
451
+ out_channels=out_channels,
452
+ kernel_size=1,
453
+ activation=activation,
454
+ )
455
+
456
+ def forward(self, x):
457
+ return self.conv(x)
458
+
459
+
460
+ class DisBlock(nn.Module):
461
+ def __init__(
462
+ self, in_channels, out_channels, activation
463
+ ): # res = 2, ..., resolution_log2
464
+ super().__init__()
465
+ self.conv0 = Conv2dLayer(
466
+ in_channels=in_channels,
467
+ out_channels=in_channels,
468
+ kernel_size=3,
469
+ activation=activation,
470
+ )
471
+ self.conv1 = Conv2dLayer(
472
+ in_channels=in_channels,
473
+ out_channels=out_channels,
474
+ kernel_size=3,
475
+ down=2,
476
+ activation=activation,
477
+ )
478
+ self.skip = Conv2dLayer(
479
+ in_channels=in_channels,
480
+ out_channels=out_channels,
481
+ kernel_size=1,
482
+ down=2,
483
+ bias=False,
484
+ )
485
+
486
+ def forward(self, x):
487
+ skip = self.skip(x, gain=np.sqrt(0.5))
488
+ x = self.conv0(x)
489
+ x = self.conv1(x, gain=np.sqrt(0.5))
490
+ out = skip + x
491
+
492
+ return out
493
+
494
+
495
+ class Discriminator(torch.nn.Module):
496
+ def __init__(
497
+ self,
498
+ c_dim, # Conditioning label (C) dimensionality.
499
+ img_resolution, # Input resolution.
500
+ img_channels, # Number of input color channels.
501
+ channel_base=32768, # Overall multiplier for the number of channels.
502
+ channel_max=512, # Maximum number of channels in any layer.
503
+ channel_decay=1,
504
+ cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
505
+ activation="lrelu",
506
+ mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
507
+ mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
508
+ ):
509
+ super().__init__()
510
+ self.c_dim = c_dim
511
+ self.img_resolution = img_resolution
512
+ self.img_channels = img_channels
513
+
514
+ resolution_log2 = int(np.log2(img_resolution))
515
+ assert img_resolution == 2**resolution_log2 and img_resolution >= 4
516
+ self.resolution_log2 = resolution_log2
517
+
518
+ def nf(stage):
519
+ return np.clip(
520
+ int(channel_base / 2 ** (stage * channel_decay)), 1, channel_max
521
+ )
522
+
523
+ if cmap_dim == None:
524
+ cmap_dim = nf(2)
525
+ if c_dim == 0:
526
+ cmap_dim = 0
527
+ self.cmap_dim = cmap_dim
528
+
529
+ if c_dim > 0:
530
+ self.mapping = MappingNet(
531
+ z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None
532
+ )
533
+
534
+ Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)]
535
+ for res in range(resolution_log2, 2, -1):
536
+ Dis.append(DisBlock(nf(res), nf(res - 1), activation))
537
+
538
+ if mbstd_num_channels > 0:
539
+ Dis.append(
540
+ MinibatchStdLayer(
541
+ group_size=mbstd_group_size, num_channels=mbstd_num_channels
542
+ )
543
+ )
544
+ Dis.append(
545
+ Conv2dLayer(
546
+ nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation
547
+ )
548
+ )
549
+ self.Dis = nn.Sequential(*Dis)
550
+
551
+ self.fc0 = FullyConnectedLayer(nf(2) * 4**2, nf(2), activation=activation)
552
+ self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
553
+
554
+ def forward(self, images_in, masks_in, c):
555
+ x = torch.cat([masks_in - 0.5, images_in], dim=1)
556
+ x = self.Dis(x)
557
+ x = self.fc1(self.fc0(x.flatten(start_dim=1)))
558
+
559
+ if self.c_dim > 0:
560
+ cmap = self.mapping(None, c)
561
+
562
+ if self.cmap_dim > 0:
563
+ x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
564
+
565
+ return x
566
+
567
+
568
+ def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512):
569
+ NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512}
570
+ return NF[2**stage]
571
+
572
+
573
+ class Mlp(nn.Module):
574
+ def __init__(
575
+ self,
576
+ in_features,
577
+ hidden_features=None,
578
+ out_features=None,
579
+ act_layer=nn.GELU,
580
+ drop=0.0,
581
+ ):
582
+ super().__init__()
583
+ out_features = out_features or in_features
584
+ hidden_features = hidden_features or in_features
585
+ self.fc1 = FullyConnectedLayer(
586
+ in_features=in_features, out_features=hidden_features, activation="lrelu"
587
+ )
588
+ self.fc2 = FullyConnectedLayer(
589
+ in_features=hidden_features, out_features=out_features
590
+ )
591
+
592
+ def forward(self, x):
593
+ x = self.fc1(x)
594
+ x = self.fc2(x)
595
+ return x
596
+
597
+
598
+ def window_partition(x, window_size):
599
+ """
600
+ Args:
601
+ x: (B, H, W, C)
602
+ window_size (int): window size
603
+ Returns:
604
+ windows: (num_windows*B, window_size, window_size, C)
605
+ """
606
+ B, H, W, C = x.shape
607
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
608
+ windows = (
609
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
610
+ )
611
+ return windows
612
+
613
+
614
+ def window_reverse(windows, window_size: int, H: int, W: int):
615
+ """
616
+ Args:
617
+ windows: (num_windows*B, window_size, window_size, C)
618
+ window_size (int): Window size
619
+ H (int): Height of image
620
+ W (int): Width of image
621
+ Returns:
622
+ x: (B, H, W, C)
623
+ """
624
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
625
+ # B = windows.shape[0] / (H * W / window_size / window_size)
626
+ x = windows.view(
627
+ B, H // window_size, W // window_size, window_size, window_size, -1
628
+ )
629
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
630
+ return x
631
+
632
+
633
+ class Conv2dLayerPartial(nn.Module):
634
+ def __init__(
635
+ self,
636
+ in_channels, # Number of input channels.
637
+ out_channels, # Number of output channels.
638
+ kernel_size, # Width and height of the convolution kernel.
639
+ bias=True, # Apply additive bias before the activation function?
640
+ activation="linear", # Activation function: 'relu', 'lrelu', etc.
641
+ up=1, # Integer upsampling factor.
642
+ down=1, # Integer downsampling factor.
643
+ resample_filter=[
644
+ 1,
645
+ 3,
646
+ 3,
647
+ 1,
648
+ ], # Low-pass filter to apply when resampling activations.
649
+ conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
650
+ trainable=True, # Update the weights of this layer during training?
651
+ ):
652
+ super().__init__()
653
+ self.conv = Conv2dLayer(
654
+ in_channels,
655
+ out_channels,
656
+ kernel_size,
657
+ bias,
658
+ activation,
659
+ up,
660
+ down,
661
+ resample_filter,
662
+ conv_clamp,
663
+ trainable,
664
+ )
665
+
666
+ self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size)
667
+ self.slide_winsize = kernel_size**2
668
+ self.stride = down
669
+ self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0
670
+
671
+ def forward(self, x, mask=None):
672
+ if mask is not None:
673
+ with torch.no_grad():
674
+ if self.weight_maskUpdater.type() != x.type():
675
+ self.weight_maskUpdater = self.weight_maskUpdater.to(x)
676
+ update_mask = F.conv2d(
677
+ mask,
678
+ self.weight_maskUpdater,
679
+ bias=None,
680
+ stride=self.stride,
681
+ padding=self.padding,
682
+ )
683
+ mask_ratio = self.slide_winsize / (update_mask.to(torch.float32) + 1e-8)
684
+ update_mask = torch.clamp(update_mask, 0, 1) # 0 or 1
685
+ mask_ratio = torch.mul(mask_ratio, update_mask).to(x.dtype)
686
+ x = self.conv(x)
687
+ x = torch.mul(x, mask_ratio)
688
+ return x, update_mask
689
+ else:
690
+ x = self.conv(x)
691
+ return x, None
692
+
693
+
694
+ class WindowAttention(nn.Module):
695
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
696
+ It supports both of shifted and non-shifted window.
697
+ Args:
698
+ dim (int): Number of input channels.
699
+ window_size (tuple[int]): The height and width of the window.
700
+ num_heads (int): Number of attention heads.
701
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
702
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
703
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
704
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
705
+ """
706
+
707
+ def __init__(
708
+ self,
709
+ dim,
710
+ window_size,
711
+ num_heads,
712
+ down_ratio=1,
713
+ qkv_bias=True,
714
+ qk_scale=None,
715
+ attn_drop=0.0,
716
+ proj_drop=0.0,
717
+ ):
718
+ super().__init__()
719
+ self.dim = dim
720
+ self.window_size = window_size # Wh, Ww
721
+ self.num_heads = num_heads
722
+ head_dim = dim // num_heads
723
+ self.scale = qk_scale or head_dim**-0.5
724
+
725
+ self.q = FullyConnectedLayer(in_features=dim, out_features=dim)
726
+ self.k = FullyConnectedLayer(in_features=dim, out_features=dim)
727
+ self.v = FullyConnectedLayer(in_features=dim, out_features=dim)
728
+ self.proj = FullyConnectedLayer(in_features=dim, out_features=dim)
729
+
730
+ self.softmax = nn.Softmax(dim=-1)
731
+
732
+ def forward(self, x, mask_windows=None, mask=None):
733
+ """
734
+ Args:
735
+ x: input features with shape of (num_windows*B, N, C)
736
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
737
+ """
738
+ B_, N, C = x.shape
739
+ norm_x = F.normalize(x, p=2.0, dim=-1, eps=torch.finfo(x.dtype).eps)
740
+ q = (
741
+ self.q(norm_x)
742
+ .reshape(B_, N, self.num_heads, C // self.num_heads)
743
+ .permute(0, 2, 1, 3)
744
+ )
745
+ k = (
746
+ self.k(norm_x)
747
+ .view(B_, -1, self.num_heads, C // self.num_heads)
748
+ .permute(0, 2, 3, 1)
749
+ )
750
+ v = (
751
+ self.v(x)
752
+ .view(B_, -1, self.num_heads, C // self.num_heads)
753
+ .permute(0, 2, 1, 3)
754
+ )
755
+
756
+ attn = (q @ k) * self.scale
757
+
758
+ if mask is not None:
759
+ nW = mask.shape[0]
760
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
761
+ 1
762
+ ).unsqueeze(0)
763
+ attn = attn.view(-1, self.num_heads, N, N)
764
+
765
+ if mask_windows is not None:
766
+ attn_mask_windows = mask_windows.squeeze(-1).unsqueeze(1).unsqueeze(1)
767
+ attn = attn + attn_mask_windows.masked_fill(
768
+ attn_mask_windows == 0, float(-100.0)
769
+ ).masked_fill(attn_mask_windows == 1, float(0.0))
770
+ with torch.no_grad():
771
+ mask_windows = torch.clamp(
772
+ torch.sum(mask_windows, dim=1, keepdim=True), 0, 1
773
+ ).repeat(1, N, 1)
774
+
775
+ attn = self.softmax(attn)
776
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
777
+ x = self.proj(x)
778
+ return x, mask_windows
779
+
780
+
781
+ class SwinTransformerBlock(nn.Module):
782
+ r"""Swin Transformer Block.
783
+ Args:
784
+ dim (int): Number of input channels.
785
+ input_resolution (tuple[int]): Input resulotion.
786
+ num_heads (int): Number of attention heads.
787
+ window_size (int): Window size.
788
+ shift_size (int): Shift size for SW-MSA.
789
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
790
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
791
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
792
+ drop (float, optional): Dropout rate. Default: 0.0
793
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
794
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
795
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
796
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
797
+ """
798
+
799
+ def __init__(
800
+ self,
801
+ dim,
802
+ input_resolution,
803
+ num_heads,
804
+ down_ratio=1,
805
+ window_size=7,
806
+ shift_size=0,
807
+ mlp_ratio=4.0,
808
+ qkv_bias=True,
809
+ qk_scale=None,
810
+ drop=0.0,
811
+ attn_drop=0.0,
812
+ drop_path=0.0,
813
+ act_layer=nn.GELU,
814
+ norm_layer=nn.LayerNorm,
815
+ ):
816
+ super().__init__()
817
+ self.dim = dim
818
+ self.input_resolution = input_resolution
819
+ self.num_heads = num_heads
820
+ self.window_size = window_size
821
+ self.shift_size = shift_size
822
+ self.mlp_ratio = mlp_ratio
823
+ if min(self.input_resolution) <= self.window_size:
824
+ # if window size is larger than input resolution, we don't partition windows
825
+ self.shift_size = 0
826
+ self.window_size = min(self.input_resolution)
827
+ assert (
828
+ 0 <= self.shift_size < self.window_size
829
+ ), "shift_size must in 0-window_size"
830
+
831
+ if self.shift_size > 0:
832
+ down_ratio = 1
833
+ self.attn = WindowAttention(
834
+ dim,
835
+ window_size=to_2tuple(self.window_size),
836
+ num_heads=num_heads,
837
+ down_ratio=down_ratio,
838
+ qkv_bias=qkv_bias,
839
+ qk_scale=qk_scale,
840
+ attn_drop=attn_drop,
841
+ proj_drop=drop,
842
+ )
843
+
844
+ self.fuse = FullyConnectedLayer(
845
+ in_features=dim * 2, out_features=dim, activation="lrelu"
846
+ )
847
+
848
+ mlp_hidden_dim = int(dim * mlp_ratio)
849
+ self.mlp = Mlp(
850
+ in_features=dim,
851
+ hidden_features=mlp_hidden_dim,
852
+ act_layer=act_layer,
853
+ drop=drop,
854
+ )
855
+
856
+ if self.shift_size > 0:
857
+ attn_mask = self.calculate_mask(self.input_resolution)
858
+ else:
859
+ attn_mask = None
860
+
861
+ self.register_buffer("attn_mask", attn_mask)
862
+
863
+ def calculate_mask(self, x_size):
864
+ # calculate attention mask for SW-MSA
865
+ H, W = x_size
866
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
867
+ h_slices = (
868
+ slice(0, -self.window_size),
869
+ slice(-self.window_size, -self.shift_size),
870
+ slice(-self.shift_size, None),
871
+ )
872
+ w_slices = (
873
+ slice(0, -self.window_size),
874
+ slice(-self.window_size, -self.shift_size),
875
+ slice(-self.shift_size, None),
876
+ )
877
+ cnt = 0
878
+ for h in h_slices:
879
+ for w in w_slices:
880
+ img_mask[:, h, w, :] = cnt
881
+ cnt += 1
882
+
883
+ mask_windows = window_partition(
884
+ img_mask, self.window_size
885
+ ) # nW, window_size, window_size, 1
886
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
887
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
888
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
889
+ attn_mask == 0, float(0.0)
890
+ )
891
+
892
+ return attn_mask
893
+
894
+ def forward(self, x, x_size, mask=None):
895
+ # H, W = self.input_resolution
896
+ H, W = x_size
897
+ B, L, C = x.shape
898
+ # assert L == H * W, "input feature has wrong size"
899
+
900
+ shortcut = x
901
+ x = x.view(B, H, W, C)
902
+ if mask is not None:
903
+ mask = mask.view(B, H, W, 1)
904
+
905
+ # cyclic shift
906
+ if self.shift_size > 0:
907
+ shifted_x = torch.roll(
908
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
909
+ )
910
+ if mask is not None:
911
+ shifted_mask = torch.roll(
912
+ mask, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
913
+ )
914
+ else:
915
+ shifted_x = x
916
+ if mask is not None:
917
+ shifted_mask = mask
918
+
919
+ # partition windows
920
+ x_windows = window_partition(
921
+ shifted_x, self.window_size
922
+ ) # nW*B, window_size, window_size, C
923
+ x_windows = x_windows.view(
924
+ -1, self.window_size * self.window_size, C
925
+ ) # nW*B, window_size*window_size, C
926
+ if mask is not None:
927
+ mask_windows = window_partition(shifted_mask, self.window_size)
928
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size, 1)
929
+ else:
930
+ mask_windows = None
931
+
932
+ # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
933
+ if self.input_resolution == x_size:
934
+ attn_windows, mask_windows = self.attn(
935
+ x_windows, mask_windows, mask=self.attn_mask
936
+ ) # nW*B, window_size*window_size, C
937
+ else:
938
+ attn_windows, mask_windows = self.attn(
939
+ x_windows,
940
+ mask_windows,
941
+ mask=self.calculate_mask(x_size).to(x.dtype).to(x.device),
942
+ ) # nW*B, window_size*window_size, C
943
+
944
+ # merge windows
945
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
946
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
947
+ if mask is not None:
948
+ mask_windows = mask_windows.view(-1, self.window_size, self.window_size, 1)
949
+ shifted_mask = window_reverse(mask_windows, self.window_size, H, W)
950
+
951
+ # reverse cyclic shift
952
+ if self.shift_size > 0:
953
+ x = torch.roll(
954
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
955
+ )
956
+ if mask is not None:
957
+ mask = torch.roll(
958
+ shifted_mask, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
959
+ )
960
+ else:
961
+ x = shifted_x
962
+ if mask is not None:
963
+ mask = shifted_mask
964
+ x = x.view(B, H * W, C)
965
+ if mask is not None:
966
+ mask = mask.view(B, H * W, 1)
967
+
968
+ # FFN
969
+ x = self.fuse(torch.cat([shortcut, x], dim=-1))
970
+ x = self.mlp(x)
971
+
972
+ return x, mask
973
+
974
+
975
+ class PatchMerging(nn.Module):
976
+ def __init__(self, in_channels, out_channels, down=2):
977
+ super().__init__()
978
+ self.conv = Conv2dLayerPartial(
979
+ in_channels=in_channels,
980
+ out_channels=out_channels,
981
+ kernel_size=3,
982
+ activation="lrelu",
983
+ down=down,
984
+ )
985
+ self.down = down
986
+
987
+ def forward(self, x, x_size, mask=None):
988
+ x = token2feature(x, x_size)
989
+ if mask is not None:
990
+ mask = token2feature(mask, x_size)
991
+ x, mask = self.conv(x, mask)
992
+ if self.down != 1:
993
+ ratio = 1 / self.down
994
+ x_size = (int(x_size[0] * ratio), int(x_size[1] * ratio))
995
+ x = feature2token(x)
996
+ if mask is not None:
997
+ mask = feature2token(mask)
998
+ return x, x_size, mask
999
+
1000
+
1001
+ class PatchUpsampling(nn.Module):
1002
+ def __init__(self, in_channels, out_channels, up=2):
1003
+ super().__init__()
1004
+ self.conv = Conv2dLayerPartial(
1005
+ in_channels=in_channels,
1006
+ out_channels=out_channels,
1007
+ kernel_size=3,
1008
+ activation="lrelu",
1009
+ up=up,
1010
+ )
1011
+ self.up = up
1012
+
1013
+ def forward(self, x, x_size, mask=None):
1014
+ x = token2feature(x, x_size)
1015
+ if mask is not None:
1016
+ mask = token2feature(mask, x_size)
1017
+ x, mask = self.conv(x, mask)
1018
+ if self.up != 1:
1019
+ x_size = (int(x_size[0] * self.up), int(x_size[1] * self.up))
1020
+ x = feature2token(x)
1021
+ if mask is not None:
1022
+ mask = feature2token(mask)
1023
+ return x, x_size, mask
1024
+
1025
+
1026
+ class BasicLayer(nn.Module):
1027
+ """A basic Swin Transformer layer for one stage.
1028
+ Args:
1029
+ dim (int): Number of input channels.
1030
+ input_resolution (tuple[int]): Input resolution.
1031
+ depth (int): Number of blocks.
1032
+ num_heads (int): Number of attention heads.
1033
+ window_size (int): Local window size.
1034
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
1035
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
1036
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
1037
+ drop (float, optional): Dropout rate. Default: 0.0
1038
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
1039
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
1040
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
1041
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
1042
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
1043
+ """
1044
+
1045
+ def __init__(
1046
+ self,
1047
+ dim,
1048
+ input_resolution,
1049
+ depth,
1050
+ num_heads,
1051
+ window_size,
1052
+ down_ratio=1,
1053
+ mlp_ratio=2.0,
1054
+ qkv_bias=True,
1055
+ qk_scale=None,
1056
+ drop=0.0,
1057
+ attn_drop=0.0,
1058
+ drop_path=0.0,
1059
+ norm_layer=nn.LayerNorm,
1060
+ downsample=None,
1061
+ use_checkpoint=False,
1062
+ ):
1063
+ super().__init__()
1064
+ self.dim = dim
1065
+ self.input_resolution = input_resolution
1066
+ self.depth = depth
1067
+ self.use_checkpoint = use_checkpoint
1068
+
1069
+ # patch merging layer
1070
+ if downsample is not None:
1071
+ # self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
1072
+ self.downsample = downsample
1073
+ else:
1074
+ self.downsample = None
1075
+
1076
+ # build blocks
1077
+ self.blocks = nn.ModuleList(
1078
+ [
1079
+ SwinTransformerBlock(
1080
+ dim=dim,
1081
+ input_resolution=input_resolution,
1082
+ num_heads=num_heads,
1083
+ down_ratio=down_ratio,
1084
+ window_size=window_size,
1085
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
1086
+ mlp_ratio=mlp_ratio,
1087
+ qkv_bias=qkv_bias,
1088
+ qk_scale=qk_scale,
1089
+ drop=drop,
1090
+ attn_drop=attn_drop,
1091
+ drop_path=drop_path[i]
1092
+ if isinstance(drop_path, list)
1093
+ else drop_path,
1094
+ norm_layer=norm_layer,
1095
+ )
1096
+ for i in range(depth)
1097
+ ]
1098
+ )
1099
+
1100
+ self.conv = Conv2dLayerPartial(
1101
+ in_channels=dim, out_channels=dim, kernel_size=3, activation="lrelu"
1102
+ )
1103
+
1104
+ def forward(self, x, x_size, mask=None):
1105
+ if self.downsample is not None:
1106
+ x, x_size, mask = self.downsample(x, x_size, mask)
1107
+ identity = x
1108
+ for blk in self.blocks:
1109
+ if self.use_checkpoint:
1110
+ x, mask = checkpoint.checkpoint(blk, x, x_size, mask)
1111
+ else:
1112
+ x, mask = blk(x, x_size, mask)
1113
+ if mask is not None:
1114
+ mask = token2feature(mask, x_size)
1115
+ x, mask = self.conv(token2feature(x, x_size), mask)
1116
+ x = feature2token(x) + identity
1117
+ if mask is not None:
1118
+ mask = feature2token(mask)
1119
+ return x, x_size, mask
1120
+
1121
+
1122
+ class ToToken(nn.Module):
1123
+ def __init__(self, in_channels=3, dim=128, kernel_size=5, stride=1):
1124
+ super().__init__()
1125
+
1126
+ self.proj = Conv2dLayerPartial(
1127
+ in_channels=in_channels,
1128
+ out_channels=dim,
1129
+ kernel_size=kernel_size,
1130
+ activation="lrelu",
1131
+ )
1132
+
1133
+ def forward(self, x, mask):
1134
+ x, mask = self.proj(x, mask)
1135
+
1136
+ return x, mask
1137
+
1138
+
1139
+ class EncFromRGB(nn.Module):
1140
+ def __init__(
1141
+ self, in_channels, out_channels, activation
1142
+ ): # res = 2, ..., resolution_log2
1143
+ super().__init__()
1144
+ self.conv0 = Conv2dLayer(
1145
+ in_channels=in_channels,
1146
+ out_channels=out_channels,
1147
+ kernel_size=1,
1148
+ activation=activation,
1149
+ )
1150
+ self.conv1 = Conv2dLayer(
1151
+ in_channels=out_channels,
1152
+ out_channels=out_channels,
1153
+ kernel_size=3,
1154
+ activation=activation,
1155
+ )
1156
+
1157
+ def forward(self, x):
1158
+ x = self.conv0(x)
1159
+ x = self.conv1(x)
1160
+
1161
+ return x
1162
+
1163
+
1164
+ class ConvBlockDown(nn.Module):
1165
+ def __init__(
1166
+ self, in_channels, out_channels, activation
1167
+ ): # res = 2, ..., resolution_log
1168
+ super().__init__()
1169
+
1170
+ self.conv0 = Conv2dLayer(
1171
+ in_channels=in_channels,
1172
+ out_channels=out_channels,
1173
+ kernel_size=3,
1174
+ activation=activation,
1175
+ down=2,
1176
+ )
1177
+ self.conv1 = Conv2dLayer(
1178
+ in_channels=out_channels,
1179
+ out_channels=out_channels,
1180
+ kernel_size=3,
1181
+ activation=activation,
1182
+ )
1183
+
1184
+ def forward(self, x):
1185
+ x = self.conv0(x)
1186
+ x = self.conv1(x)
1187
+
1188
+ return x
1189
+
1190
+
1191
+ def token2feature(x, x_size):
1192
+ B, N, C = x.shape
1193
+ h, w = x_size
1194
+ x = x.permute(0, 2, 1).reshape(B, C, h, w)
1195
+ return x
1196
+
1197
+
1198
+ def feature2token(x):
1199
+ B, C, H, W = x.shape
1200
+ x = x.view(B, C, -1).transpose(1, 2)
1201
+ return x
1202
+
1203
+
1204
+ class Encoder(nn.Module):
1205
+ def __init__(
1206
+ self,
1207
+ res_log2,
1208
+ img_channels,
1209
+ activation,
1210
+ patch_size=5,
1211
+ channels=16,
1212
+ drop_path_rate=0.1,
1213
+ ):
1214
+ super().__init__()
1215
+
1216
+ self.resolution = []
1217
+
1218
+ for idx, i in enumerate(range(res_log2, 3, -1)): # from input size to 16x16
1219
+ res = 2**i
1220
+ self.resolution.append(res)
1221
+ if i == res_log2:
1222
+ block = EncFromRGB(img_channels * 2 + 1, nf(i), activation)
1223
+ else:
1224
+ block = ConvBlockDown(nf(i + 1), nf(i), activation)
1225
+ setattr(self, "EncConv_Block_%dx%d" % (res, res), block)
1226
+
1227
+ def forward(self, x):
1228
+ out = {}
1229
+ for res in self.resolution:
1230
+ res_log2 = int(np.log2(res))
1231
+ x = getattr(self, "EncConv_Block_%dx%d" % (res, res))(x)
1232
+ out[res_log2] = x
1233
+
1234
+ return out
1235
+
1236
+
1237
+ class ToStyle(nn.Module):
1238
+ def __init__(self, in_channels, out_channels, activation, drop_rate):
1239
+ super().__init__()
1240
+ self.conv = nn.Sequential(
1241
+ Conv2dLayer(
1242
+ in_channels=in_channels,
1243
+ out_channels=in_channels,
1244
+ kernel_size=3,
1245
+ activation=activation,
1246
+ down=2,
1247
+ ),
1248
+ Conv2dLayer(
1249
+ in_channels=in_channels,
1250
+ out_channels=in_channels,
1251
+ kernel_size=3,
1252
+ activation=activation,
1253
+ down=2,
1254
+ ),
1255
+ Conv2dLayer(
1256
+ in_channels=in_channels,
1257
+ out_channels=in_channels,
1258
+ kernel_size=3,
1259
+ activation=activation,
1260
+ down=2,
1261
+ ),
1262
+ )
1263
+
1264
+ self.pool = nn.AdaptiveAvgPool2d(1)
1265
+ self.fc = FullyConnectedLayer(
1266
+ in_features=in_channels, out_features=out_channels, activation=activation
1267
+ )
1268
+ # self.dropout = nn.Dropout(drop_rate)
1269
+
1270
+ def forward(self, x):
1271
+ x = self.conv(x)
1272
+ x = self.pool(x)
1273
+ x = self.fc(x.flatten(start_dim=1))
1274
+ # x = self.dropout(x)
1275
+
1276
+ return x
1277
+
1278
+
1279
+ class DecBlockFirstV2(nn.Module):
1280
+ def __init__(
1281
+ self,
1282
+ res,
1283
+ in_channels,
1284
+ out_channels,
1285
+ activation,
1286
+ style_dim,
1287
+ use_noise,
1288
+ demodulate,
1289
+ img_channels,
1290
+ ):
1291
+ super().__init__()
1292
+ self.res = res
1293
+
1294
+ self.conv0 = Conv2dLayer(
1295
+ in_channels=in_channels,
1296
+ out_channels=in_channels,
1297
+ kernel_size=3,
1298
+ activation=activation,
1299
+ )
1300
+ self.conv1 = StyleConv(
1301
+ in_channels=in_channels,
1302
+ out_channels=out_channels,
1303
+ style_dim=style_dim,
1304
+ resolution=2**res,
1305
+ kernel_size=3,
1306
+ use_noise=use_noise,
1307
+ activation=activation,
1308
+ demodulate=demodulate,
1309
+ )
1310
+ self.toRGB = ToRGB(
1311
+ in_channels=out_channels,
1312
+ out_channels=img_channels,
1313
+ style_dim=style_dim,
1314
+ kernel_size=1,
1315
+ demodulate=False,
1316
+ )
1317
+
1318
+ def forward(self, x, ws, gs, E_features, noise_mode="random"):
1319
+ # x = self.fc(x).view(x.shape[0], -1, 4, 4)
1320
+ x = self.conv0(x)
1321
+ x = x + E_features[self.res]
1322
+ style = get_style_code(ws[:, 0], gs)
1323
+ x = self.conv1(x, style, noise_mode=noise_mode)
1324
+ style = get_style_code(ws[:, 1], gs)
1325
+ img = self.toRGB(x, style, skip=None)
1326
+
1327
+ return x, img
1328
+
1329
+
1330
+ class DecBlock(nn.Module):
1331
+ def __init__(
1332
+ self,
1333
+ res,
1334
+ in_channels,
1335
+ out_channels,
1336
+ activation,
1337
+ style_dim,
1338
+ use_noise,
1339
+ demodulate,
1340
+ img_channels,
1341
+ ): # res = 4, ..., resolution_log2
1342
+ super().__init__()
1343
+ self.res = res
1344
+
1345
+ self.conv0 = StyleConv(
1346
+ in_channels=in_channels,
1347
+ out_channels=out_channels,
1348
+ style_dim=style_dim,
1349
+ resolution=2**res,
1350
+ kernel_size=3,
1351
+ up=2,
1352
+ use_noise=use_noise,
1353
+ activation=activation,
1354
+ demodulate=demodulate,
1355
+ )
1356
+ self.conv1 = StyleConv(
1357
+ in_channels=out_channels,
1358
+ out_channels=out_channels,
1359
+ style_dim=style_dim,
1360
+ resolution=2**res,
1361
+ kernel_size=3,
1362
+ use_noise=use_noise,
1363
+ activation=activation,
1364
+ demodulate=demodulate,
1365
+ )
1366
+ self.toRGB = ToRGB(
1367
+ in_channels=out_channels,
1368
+ out_channels=img_channels,
1369
+ style_dim=style_dim,
1370
+ kernel_size=1,
1371
+ demodulate=False,
1372
+ )
1373
+
1374
+ def forward(self, x, img, ws, gs, E_features, noise_mode="random"):
1375
+ style = get_style_code(ws[:, self.res * 2 - 9], gs)
1376
+ x = self.conv0(x, style, noise_mode=noise_mode)
1377
+ x = x + E_features[self.res]
1378
+ style = get_style_code(ws[:, self.res * 2 - 8], gs)
1379
+ x = self.conv1(x, style, noise_mode=noise_mode)
1380
+ style = get_style_code(ws[:, self.res * 2 - 7], gs)
1381
+ img = self.toRGB(x, style, skip=img)
1382
+
1383
+ return x, img
1384
+
1385
+
1386
+ class Decoder(nn.Module):
1387
+ def __init__(
1388
+ self, res_log2, activation, style_dim, use_noise, demodulate, img_channels
1389
+ ):
1390
+ super().__init__()
1391
+ self.Dec_16x16 = DecBlockFirstV2(
1392
+ 4, nf(4), nf(4), activation, style_dim, use_noise, demodulate, img_channels
1393
+ )
1394
+ for res in range(5, res_log2 + 1):
1395
+ setattr(
1396
+ self,
1397
+ "Dec_%dx%d" % (2**res, 2**res),
1398
+ DecBlock(
1399
+ res,
1400
+ nf(res - 1),
1401
+ nf(res),
1402
+ activation,
1403
+ style_dim,
1404
+ use_noise,
1405
+ demodulate,
1406
+ img_channels,
1407
+ ),
1408
+ )
1409
+ self.res_log2 = res_log2
1410
+
1411
+ def forward(self, x, ws, gs, E_features, noise_mode="random"):
1412
+ x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode)
1413
+ for res in range(5, self.res_log2 + 1):
1414
+ block = getattr(self, "Dec_%dx%d" % (2**res, 2**res))
1415
+ x, img = block(x, img, ws, gs, E_features, noise_mode=noise_mode)
1416
+
1417
+ return img
1418
+
1419
+
1420
+ class DecStyleBlock(nn.Module):
1421
+ def __init__(
1422
+ self,
1423
+ res,
1424
+ in_channels,
1425
+ out_channels,
1426
+ activation,
1427
+ style_dim,
1428
+ use_noise,
1429
+ demodulate,
1430
+ img_channels,
1431
+ ):
1432
+ super().__init__()
1433
+ self.res = res
1434
+
1435
+ self.conv0 = StyleConv(
1436
+ in_channels=in_channels,
1437
+ out_channels=out_channels,
1438
+ style_dim=style_dim,
1439
+ resolution=2**res,
1440
+ kernel_size=3,
1441
+ up=2,
1442
+ use_noise=use_noise,
1443
+ activation=activation,
1444
+ demodulate=demodulate,
1445
+ )
1446
+ self.conv1 = StyleConv(
1447
+ in_channels=out_channels,
1448
+ out_channels=out_channels,
1449
+ style_dim=style_dim,
1450
+ resolution=2**res,
1451
+ kernel_size=3,
1452
+ use_noise=use_noise,
1453
+ activation=activation,
1454
+ demodulate=demodulate,
1455
+ )
1456
+ self.toRGB = ToRGB(
1457
+ in_channels=out_channels,
1458
+ out_channels=img_channels,
1459
+ style_dim=style_dim,
1460
+ kernel_size=1,
1461
+ demodulate=False,
1462
+ )
1463
+
1464
+ def forward(self, x, img, style, skip, noise_mode="random"):
1465
+ x = self.conv0(x, style, noise_mode=noise_mode)
1466
+ x = x + skip
1467
+ x = self.conv1(x, style, noise_mode=noise_mode)
1468
+ img = self.toRGB(x, style, skip=img)
1469
+
1470
+ return x, img
1471
+
1472
+
1473
+ class FirstStage(nn.Module):
1474
+ def __init__(
1475
+ self,
1476
+ img_channels,
1477
+ img_resolution=256,
1478
+ dim=180,
1479
+ w_dim=512,
1480
+ use_noise=False,
1481
+ demodulate=True,
1482
+ activation="lrelu",
1483
+ ):
1484
+ super().__init__()
1485
+ res = 64
1486
+
1487
+ self.conv_first = Conv2dLayerPartial(
1488
+ in_channels=img_channels + 1,
1489
+ out_channels=dim,
1490
+ kernel_size=3,
1491
+ activation=activation,
1492
+ )
1493
+ self.enc_conv = nn.ModuleList()
1494
+ down_time = int(np.log2(img_resolution // res))
1495
+ # 根据图片尺寸构建 swim transformer 的层数
1496
+ for i in range(down_time): # from input size to 64
1497
+ self.enc_conv.append(
1498
+ Conv2dLayerPartial(
1499
+ in_channels=dim,
1500
+ out_channels=dim,
1501
+ kernel_size=3,
1502
+ down=2,
1503
+ activation=activation,
1504
+ )
1505
+ )
1506
+
1507
+ # from 64 -> 16 -> 64
1508
+ depths = [2, 3, 4, 3, 2]
1509
+ ratios = [1, 1 / 2, 1 / 2, 2, 2]
1510
+ num_heads = 6
1511
+ window_sizes = [8, 16, 16, 16, 8]
1512
+ drop_path_rate = 0.1
1513
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
1514
+
1515
+ self.tran = nn.ModuleList()
1516
+ for i, depth in enumerate(depths):
1517
+ res = int(res * ratios[i])
1518
+ if ratios[i] < 1:
1519
+ merge = PatchMerging(dim, dim, down=int(1 / ratios[i]))
1520
+ elif ratios[i] > 1:
1521
+ merge = PatchUpsampling(dim, dim, up=ratios[i])
1522
+ else:
1523
+ merge = None
1524
+ self.tran.append(
1525
+ BasicLayer(
1526
+ dim=dim,
1527
+ input_resolution=[res, res],
1528
+ depth=depth,
1529
+ num_heads=num_heads,
1530
+ window_size=window_sizes[i],
1531
+ drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
1532
+ downsample=merge,
1533
+ )
1534
+ )
1535
+
1536
+ # global style
1537
+ down_conv = []
1538
+ for i in range(int(np.log2(16))):
1539
+ down_conv.append(
1540
+ Conv2dLayer(
1541
+ in_channels=dim,
1542
+ out_channels=dim,
1543
+ kernel_size=3,
1544
+ down=2,
1545
+ activation=activation,
1546
+ )
1547
+ )
1548
+ down_conv.append(nn.AdaptiveAvgPool2d((1, 1)))
1549
+ self.down_conv = nn.Sequential(*down_conv)
1550
+ self.to_style = FullyConnectedLayer(
1551
+ in_features=dim, out_features=dim * 2, activation=activation
1552
+ )
1553
+ self.ws_style = FullyConnectedLayer(
1554
+ in_features=w_dim, out_features=dim, activation=activation
1555
+ )
1556
+ self.to_square = FullyConnectedLayer(
1557
+ in_features=dim, out_features=16 * 16, activation=activation
1558
+ )
1559
+
1560
+ style_dim = dim * 3
1561
+ self.dec_conv = nn.ModuleList()
1562
+ for i in range(down_time): # from 64 to input size
1563
+ res = res * 2
1564
+ self.dec_conv.append(
1565
+ DecStyleBlock(
1566
+ res,
1567
+ dim,
1568
+ dim,
1569
+ activation,
1570
+ style_dim,
1571
+ use_noise,
1572
+ demodulate,
1573
+ img_channels,
1574
+ )
1575
+ )
1576
+
1577
+ def forward(self, images_in, masks_in, ws, noise_mode="random"):
1578
+ x = torch.cat([masks_in - 0.5, images_in * masks_in], dim=1)
1579
+
1580
+ skips = []
1581
+ x, mask = self.conv_first(x, masks_in) # input size
1582
+ skips.append(x)
1583
+ for i, block in enumerate(self.enc_conv): # input size to 64
1584
+ x, mask = block(x, mask)
1585
+ if i != len(self.enc_conv) - 1:
1586
+ skips.append(x)
1587
+
1588
+ x_size = x.size()[-2:]
1589
+ x = feature2token(x)
1590
+ mask = feature2token(mask)
1591
+ mid = len(self.tran) // 2
1592
+ for i, block in enumerate(self.tran): # 64 to 16
1593
+ if i < mid:
1594
+ x, x_size, mask = block(x, x_size, mask)
1595
+ skips.append(x)
1596
+ elif i > mid:
1597
+ x, x_size, mask = block(x, x_size, None)
1598
+ x = x + skips[mid - i]
1599
+ else:
1600
+ x, x_size, mask = block(x, x_size, None)
1601
+
1602
+ mul_map = torch.ones_like(x) * 0.5
1603
+ mul_map = F.dropout(mul_map, training=True)
1604
+ ws = self.ws_style(ws[:, -1])
1605
+ add_n = self.to_square(ws).unsqueeze(1)
1606
+ add_n = (
1607
+ F.interpolate(
1608
+ add_n, size=x.size(1), mode="linear", align_corners=False
1609
+ )
1610
+ .squeeze(1)
1611
+ .unsqueeze(-1)
1612
+ )
1613
+ x = x * mul_map + add_n * (1 - mul_map)
1614
+ gs = self.to_style(
1615
+ self.down_conv(token2feature(x, x_size)).flatten(start_dim=1)
1616
+ )
1617
+ style = torch.cat([gs, ws], dim=1)
1618
+
1619
+ x = token2feature(x, x_size).contiguous()
1620
+ img = None
1621
+ for i, block in enumerate(self.dec_conv):
1622
+ x, img = block(
1623
+ x, img, style, skips[len(self.dec_conv) - i - 1], noise_mode=noise_mode
1624
+ )
1625
+
1626
+ # ensemble
1627
+ img = img * (1 - masks_in) + images_in * masks_in
1628
+
1629
+ return img
1630
+
1631
+
1632
+ class SynthesisNet(nn.Module):
1633
+ def __init__(
1634
+ self,
1635
+ w_dim, # Intermediate latent (W) dimensionality.
1636
+ img_resolution, # Output image resolution.
1637
+ img_channels=3, # Number of color channels.
1638
+ channel_base=32768, # Overall multiplier for the number of channels.
1639
+ channel_decay=1.0,
1640
+ channel_max=512, # Maximum number of channels in any layer.
1641
+ activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
1642
+ drop_rate=0.5,
1643
+ use_noise=False,
1644
+ demodulate=True,
1645
+ ):
1646
+ super().__init__()
1647
+ resolution_log2 = int(np.log2(img_resolution))
1648
+ assert img_resolution == 2**resolution_log2 and img_resolution >= 4
1649
+
1650
+ self.num_layers = resolution_log2 * 2 - 3 * 2
1651
+ self.img_resolution = img_resolution
1652
+ self.resolution_log2 = resolution_log2
1653
+
1654
+ # first stage
1655
+ self.first_stage = FirstStage(
1656
+ img_channels,
1657
+ img_resolution=img_resolution,
1658
+ w_dim=w_dim,
1659
+ use_noise=False,
1660
+ demodulate=demodulate,
1661
+ )
1662
+
1663
+ # second stage
1664
+ self.enc = Encoder(
1665
+ resolution_log2, img_channels, activation, patch_size=5, channels=16
1666
+ )
1667
+ self.to_square = FullyConnectedLayer(
1668
+ in_features=w_dim, out_features=16 * 16, activation=activation
1669
+ )
1670
+ self.to_style = ToStyle(
1671
+ in_channels=nf(4),
1672
+ out_channels=nf(2) * 2,
1673
+ activation=activation,
1674
+ drop_rate=drop_rate,
1675
+ )
1676
+ style_dim = w_dim + nf(2) * 2
1677
+ self.dec = Decoder(
1678
+ resolution_log2, activation, style_dim, use_noise, demodulate, img_channels
1679
+ )
1680
+
1681
+ def forward(self, images_in, masks_in, ws, noise_mode="random", return_stg1=False):
1682
+ out_stg1 = self.first_stage(images_in, masks_in, ws, noise_mode=noise_mode)
1683
+
1684
+ # encoder
1685
+ x = images_in * masks_in + out_stg1 * (1 - masks_in)
1686
+ x = torch.cat([masks_in - 0.5, x, images_in * masks_in], dim=1)
1687
+ E_features = self.enc(x)
1688
+
1689
+ fea_16 = E_features[4]
1690
+ mul_map = torch.ones_like(fea_16) * 0.5
1691
+ mul_map = F.dropout(mul_map, training=True)
1692
+ add_n = self.to_square(ws[:, 0]).view(-1, 16, 16).unsqueeze(1)
1693
+ add_n = F.interpolate(
1694
+ add_n, size=fea_16.size()[-2:], mode="bilinear", align_corners=False
1695
+ )
1696
+ fea_16 = fea_16 * mul_map + add_n * (1 - mul_map)
1697
+ E_features[4] = fea_16
1698
+
1699
+ # style
1700
+ gs = self.to_style(fea_16)
1701
+
1702
+ # decoder
1703
+ img = self.dec(fea_16, ws, gs, E_features, noise_mode=noise_mode)
1704
+
1705
+ # ensemble
1706
+ img = img * (1 - masks_in) + images_in * masks_in
1707
+
1708
+ if not return_stg1:
1709
+ return img
1710
+ else:
1711
+ return img, out_stg1
1712
+
1713
+
1714
+ class Generator(nn.Module):
1715
+ def __init__(
1716
+ self,
1717
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
1718
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
1719
+ w_dim, # Intermediate latent (W) dimensionality.
1720
+ img_resolution, # resolution of generated image
1721
+ img_channels, # Number of input color channels.
1722
+ synthesis_kwargs={}, # Arguments for SynthesisNetwork.
1723
+ mapping_kwargs={}, # Arguments for MappingNetwork.
1724
+ ):
1725
+ super().__init__()
1726
+ self.z_dim = z_dim
1727
+ self.c_dim = c_dim
1728
+ self.w_dim = w_dim
1729
+ self.img_resolution = img_resolution
1730
+ self.img_channels = img_channels
1731
+
1732
+ self.synthesis = SynthesisNet(
1733
+ w_dim=w_dim,
1734
+ img_resolution=img_resolution,
1735
+ img_channels=img_channels,
1736
+ **synthesis_kwargs,
1737
+ )
1738
+ self.mapping = MappingNet(
1739
+ z_dim=z_dim,
1740
+ c_dim=c_dim,
1741
+ w_dim=w_dim,
1742
+ num_ws=self.synthesis.num_layers,
1743
+ **mapping_kwargs,
1744
+ )
1745
+
1746
+ def forward(
1747
+ self,
1748
+ images_in,
1749
+ masks_in,
1750
+ z,
1751
+ c,
1752
+ truncation_psi=1,
1753
+ truncation_cutoff=None,
1754
+ skip_w_avg_update=False,
1755
+ noise_mode="none",
1756
+ return_stg1=False,
1757
+ ):
1758
+ ws = self.mapping(
1759
+ z,
1760
+ c,
1761
+ truncation_psi=truncation_psi,
1762
+ truncation_cutoff=truncation_cutoff,
1763
+ skip_w_avg_update=skip_w_avg_update,
1764
+ )
1765
+ img = self.synthesis(images_in, masks_in, ws, noise_mode=noise_mode)
1766
+ return img
1767
+
1768
+
1769
+ class Discriminator(torch.nn.Module):
1770
+ def __init__(
1771
+ self,
1772
+ c_dim, # Conditioning label (C) dimensionality.
1773
+ img_resolution, # Input resolution.
1774
+ img_channels, # Number of input color channels.
1775
+ channel_base=32768, # Overall multiplier for the number of channels.
1776
+ channel_max=512, # Maximum number of channels in any layer.
1777
+ channel_decay=1,
1778
+ cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
1779
+ activation="lrelu",
1780
+ mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
1781
+ mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
1782
+ ):
1783
+ super().__init__()
1784
+ self.c_dim = c_dim
1785
+ self.img_resolution = img_resolution
1786
+ self.img_channels = img_channels
1787
+
1788
+ resolution_log2 = int(np.log2(img_resolution))
1789
+ assert img_resolution == 2**resolution_log2 and img_resolution >= 4
1790
+ self.resolution_log2 = resolution_log2
1791
+
1792
+ if cmap_dim == None:
1793
+ cmap_dim = nf(2)
1794
+ if c_dim == 0:
1795
+ cmap_dim = 0
1796
+ self.cmap_dim = cmap_dim
1797
+
1798
+ if c_dim > 0:
1799
+ self.mapping = MappingNet(
1800
+ z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None
1801
+ )
1802
+
1803
+ Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)]
1804
+ for res in range(resolution_log2, 2, -1):
1805
+ Dis.append(DisBlock(nf(res), nf(res - 1), activation))
1806
+
1807
+ if mbstd_num_channels > 0:
1808
+ Dis.append(
1809
+ MinibatchStdLayer(
1810
+ group_size=mbstd_group_size, num_channels=mbstd_num_channels
1811
+ )
1812
+ )
1813
+ Dis.append(
1814
+ Conv2dLayer(
1815
+ nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation
1816
+ )
1817
+ )
1818
+ self.Dis = nn.Sequential(*Dis)
1819
+
1820
+ self.fc0 = FullyConnectedLayer(nf(2) * 4**2, nf(2), activation=activation)
1821
+ self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
1822
+
1823
+ # for 64x64
1824
+ Dis_stg1 = [DisFromRGB(img_channels + 1, nf(resolution_log2) // 2, activation)]
1825
+ for res in range(resolution_log2, 2, -1):
1826
+ Dis_stg1.append(DisBlock(nf(res) // 2, nf(res - 1) // 2, activation))
1827
+
1828
+ if mbstd_num_channels > 0:
1829
+ Dis_stg1.append(
1830
+ MinibatchStdLayer(
1831
+ group_size=mbstd_group_size, num_channels=mbstd_num_channels
1832
+ )
1833
+ )
1834
+ Dis_stg1.append(
1835
+ Conv2dLayer(
1836
+ nf(2) // 2 + mbstd_num_channels,
1837
+ nf(2) // 2,
1838
+ kernel_size=3,
1839
+ activation=activation,
1840
+ )
1841
+ )
1842
+ self.Dis_stg1 = nn.Sequential(*Dis_stg1)
1843
+
1844
+ self.fc0_stg1 = FullyConnectedLayer(
1845
+ nf(2) // 2 * 4**2, nf(2) // 2, activation=activation
1846
+ )
1847
+ self.fc1_stg1 = FullyConnectedLayer(
1848
+ nf(2) // 2, 1 if cmap_dim == 0 else cmap_dim
1849
+ )
1850
+
1851
+ def forward(self, images_in, masks_in, images_stg1, c):
1852
+ x = self.Dis(torch.cat([masks_in - 0.5, images_in], dim=1))
1853
+ x = self.fc1(self.fc0(x.flatten(start_dim=1)))
1854
+
1855
+ x_stg1 = self.Dis_stg1(torch.cat([masks_in - 0.5, images_stg1], dim=1))
1856
+ x_stg1 = self.fc1_stg1(self.fc0_stg1(x_stg1.flatten(start_dim=1)))
1857
+
1858
+ if self.c_dim > 0:
1859
+ cmap = self.mapping(None, c)
1860
+
1861
+ if self.cmap_dim > 0:
1862
+ x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
1863
+ x_stg1 = (x_stg1 * cmap).sum(dim=1, keepdim=True) * (
1864
+ 1 / np.sqrt(self.cmap_dim)
1865
+ )
1866
+
1867
+ return x, x_stg1
1868
+
1869
+
1870
+ MAT_MODEL_URL = os.environ.get(
1871
+ "MAT_MODEL_URL",
1872
+ "https://github.com/Sanster/models/releases/download/add_mat/Places_512_FullData_G.pth",
1873
+ )
1874
+
1875
+ MAT_MODEL_MD5 = os.environ.get("MAT_MODEL_MD5", "8ca927835fa3f5e21d65ffcb165377ed")
1876
+
1877
+
1878
+ class MAT(InpaintModel):
1879
+ name = "mat"
1880
+ min_size = 512
1881
+ pad_mod = 512
1882
+ pad_to_square = True
1883
+ is_erase_model = True
1884
+
1885
+ def init_model(self, device, **kwargs):
1886
+ seed = 240 # pick up a random number
1887
+ set_seed(seed)
1888
+
1889
+ fp16 = not kwargs.get("no_half", False)
1890
+ use_gpu = "cuda" in str(device) and torch.cuda.is_available()
1891
+ self.torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
1892
+
1893
+ G = Generator(
1894
+ z_dim=512,
1895
+ c_dim=0,
1896
+ w_dim=512,
1897
+ img_resolution=512,
1898
+ img_channels=3,
1899
+ mapping_kwargs={"torch_dtype": self.torch_dtype},
1900
+ ).to(self.torch_dtype)
1901
+ # fmt: off
1902
+ self.model = load_model(G, MAT_MODEL_URL, device, MAT_MODEL_MD5)
1903
+ self.z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(self.torch_dtype).to(device)
1904
+ self.label = torch.zeros([1, self.model.c_dim], device=device).to(self.torch_dtype)
1905
+ # fmt: on
1906
+
1907
+ @staticmethod
1908
+ def download():
1909
+ download_model(MAT_MODEL_URL, MAT_MODEL_MD5)
1910
+
1911
+ @staticmethod
1912
+ def is_downloaded() -> bool:
1913
+ return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL))
1914
+
1915
+ def forward(self, image, mask, config: InpaintRequest):
1916
+ """Input images and output images have same size
1917
+ images: [H, W, C] RGB
1918
+ masks: [H, W] mask area == 255
1919
+ return: BGR IMAGE
1920
+ """
1921
+
1922
+ image = norm_img(image) # [0, 1]
1923
+ image = image * 2 - 1 # [0, 1] -> [-1, 1]
1924
+
1925
+ mask = (mask > 127) * 255
1926
+ mask = 255 - mask
1927
+ mask = norm_img(mask)
1928
+
1929
+ image = (
1930
+ torch.from_numpy(image).unsqueeze(0).to(self.torch_dtype).to(self.device)
1931
+ )
1932
+ mask = torch.from_numpy(mask).unsqueeze(0).to(self.torch_dtype).to(self.device)
1933
+
1934
+ output = self.model(
1935
+ image, mask, self.z, self.label, truncation_psi=1, noise_mode="none"
1936
+ )
1937
+ output = (
1938
+ (output.permute(0, 2, 3, 1) * 127.5 + 127.5)
1939
+ .round()
1940
+ .clamp(0, 255)
1941
+ .to(torch.uint8)
1942
+ )
1943
+ output = output[0].cpu().numpy()
1944
+ cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
1945
+ return cur_res
iopaint/model/mi_gan.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import torch
5
+
6
+ from iopaint.helper import (
7
+ load_jit_model,
8
+ download_model,
9
+ get_cache_path_by_url,
10
+ boxes_from_mask,
11
+ resize_max_size,
12
+ norm_img,
13
+ )
14
+ from .base import InpaintModel
15
+ from iopaint.schema import InpaintRequest
16
+
17
+ MIGAN_MODEL_URL = os.environ.get(
18
+ "MIGAN_MODEL_URL",
19
+ "https://github.com/Sanster/models/releases/download/migan/migan_traced.pt",
20
+ )
21
+ MIGAN_MODEL_MD5 = os.environ.get("MIGAN_MODEL_MD5", "76eb3b1a71c400ee3290524f7a11b89c")
22
+
23
+
24
+ class MIGAN(InpaintModel):
25
+ name = "migan"
26
+ min_size = 512
27
+ pad_mod = 512
28
+ pad_to_square = True
29
+ is_erase_model = True
30
+
31
+ def init_model(self, device, **kwargs):
32
+ self.model = load_jit_model(MIGAN_MODEL_URL, device, MIGAN_MODEL_MD5).eval()
33
+
34
+ @staticmethod
35
+ def download():
36
+ download_model(MIGAN_MODEL_URL, MIGAN_MODEL_MD5)
37
+
38
+ @staticmethod
39
+ def is_downloaded() -> bool:
40
+ return os.path.exists(get_cache_path_by_url(MIGAN_MODEL_URL))
41
+
42
+ @torch.no_grad()
43
+ def __call__(self, image, mask, config: InpaintRequest):
44
+ """
45
+ images: [H, W, C] RGB, not normalized
46
+ masks: [H, W]
47
+ return: BGR IMAGE
48
+ """
49
+ if image.shape[0] == 512 and image.shape[1] == 512:
50
+ return self._pad_forward(image, mask, config)
51
+
52
+ boxes = boxes_from_mask(mask)
53
+ crop_result = []
54
+ config.hd_strategy_crop_margin = 128
55
+ for box in boxes:
56
+ crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config)
57
+ origin_size = crop_image.shape[:2]
58
+ resize_image = resize_max_size(crop_image, size_limit=512)
59
+ resize_mask = resize_max_size(crop_mask, size_limit=512)
60
+ inpaint_result = self._pad_forward(resize_image, resize_mask, config)
61
+
62
+ # only paste masked area result
63
+ inpaint_result = cv2.resize(
64
+ inpaint_result,
65
+ (origin_size[1], origin_size[0]),
66
+ interpolation=cv2.INTER_CUBIC,
67
+ )
68
+
69
+ original_pixel_indices = crop_mask < 127
70
+ inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][
71
+ original_pixel_indices
72
+ ]
73
+
74
+ crop_result.append((inpaint_result, crop_box))
75
+
76
+ inpaint_result = image[:, :, ::-1].copy()
77
+ for crop_image, crop_box in crop_result:
78
+ x1, y1, x2, y2 = crop_box
79
+ inpaint_result[y1:y2, x1:x2, :] = crop_image
80
+
81
+ return inpaint_result
82
+
83
+ def forward(self, image, mask, config: InpaintRequest):
84
+ """Input images and output images have same size
85
+ images: [H, W, C] RGB
86
+ masks: [H, W] mask area == 255
87
+ return: BGR IMAGE
88
+ """
89
+
90
+ image = norm_img(image) # [0, 1]
91
+ image = image * 2 - 1 # [0, 1] -> [-1, 1]
92
+ mask = (mask > 120) * 255
93
+ mask = norm_img(mask)
94
+
95
+ image = torch.from_numpy(image).unsqueeze(0).to(self.device)
96
+ mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
97
+
98
+ erased_img = image * (1 - mask)
99
+ input_image = torch.cat([0.5 - mask, erased_img], dim=1)
100
+
101
+ output = self.model(input_image)
102
+ output = (
103
+ (output.permute(0, 2, 3, 1) * 127.5 + 127.5)
104
+ .round()
105
+ .clamp(0, 255)
106
+ .to(torch.uint8)
107
+ )
108
+ output = output[0].cpu().numpy()
109
+ cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
110
+ return cur_res
iopaint/model/opencv2.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ from .base import InpaintModel
3
+ from iopaint.schema import InpaintRequest
4
+
5
+ flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA}
6
+
7
+
8
+ class OpenCV2(InpaintModel):
9
+ name = "cv2"
10
+ pad_mod = 1
11
+ is_erase_model = True
12
+
13
+ @staticmethod
14
+ def is_downloaded() -> bool:
15
+ return True
16
+
17
+ def forward(self, image, mask, config: InpaintRequest):
18
+ """Input image and output image have same size
19
+ image: [H, W, C] RGB
20
+ mask: [H, W, 1]
21
+ return: BGR IMAGE
22
+ """
23
+ cur_res = cv2.inpaint(
24
+ image[:, :, ::-1],
25
+ mask,
26
+ inpaintRadius=config.cv2_radius,
27
+ flags=flag_map[config.cv2_flag],
28
+ )
29
+ return cur_res
iopaint/model_manager.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict
2
+
3
+ import torch
4
+ from loguru import logger
5
+ import numpy as np
6
+
7
+ from iopaint.download import scan_models
8
+ from iopaint.helper import switch_mps_device
9
+ from iopaint.model import models, ControlNet, SD, SDXL
10
+ from iopaint.model.utils import torch_gc, is_local_files_only
11
+ from iopaint.schema import InpaintRequest, ModelInfo, ModelType
12
+
13
+
14
+ class ModelManager:
15
+ def __init__(self, name: str, device: torch.device, **kwargs):
16
+ self.name = name
17
+ self.device = device
18
+ self.kwargs = kwargs
19
+ self.available_models: Dict[str, ModelInfo] = {}
20
+ self.scan_models()
21
+
22
+ self.enable_controlnet = kwargs.get("enable_controlnet", False)
23
+ controlnet_method = kwargs.get("controlnet_method", None)
24
+ if (
25
+ controlnet_method is None
26
+ and name in self.available_models
27
+ and self.available_models[name].support_controlnet
28
+ ):
29
+ controlnet_method = self.available_models[name].controlnets[0]
30
+ self.controlnet_method = controlnet_method
31
+ self.model = self.init_model(name, device, **kwargs)
32
+
33
+ @property
34
+ def current_model(self) -> ModelInfo:
35
+ return self.available_models[self.name]
36
+
37
+ def init_model(self, name: str, device, **kwargs):
38
+ logger.info(f"Loading model: {name}")
39
+ if name not in self.available_models:
40
+ raise NotImplementedError(
41
+ f"Unsupported model: {name}. Available models: {list(self.available_models.keys())}"
42
+ )
43
+
44
+ model_info = self.available_models[name]
45
+ kwargs = {
46
+ **kwargs,
47
+ "model_info": model_info,
48
+ "enable_controlnet": self.enable_controlnet,
49
+ "controlnet_method": self.controlnet_method,
50
+ }
51
+
52
+ if model_info.support_controlnet and self.enable_controlnet:
53
+ return ControlNet(device, **kwargs)
54
+ elif model_info.name in models:
55
+ return models[name](device, **kwargs)
56
+ else:
57
+ if model_info.model_type in [
58
+ ModelType.DIFFUSERS_SD_INPAINT,
59
+ ModelType.DIFFUSERS_SD,
60
+ ]:
61
+ return SD(device, **kwargs)
62
+
63
+ if model_info.model_type in [
64
+ ModelType.DIFFUSERS_SDXL_INPAINT,
65
+ ModelType.DIFFUSERS_SDXL,
66
+ ]:
67
+ return SDXL(device, **kwargs)
68
+
69
+ raise NotImplementedError(f"Unsupported model: {name}")
70
+
71
+ @torch.inference_mode()
72
+ def __call__(self, image, mask, config: InpaintRequest):
73
+ """
74
+
75
+ Args:
76
+ image: [H, W, C] RGB
77
+ mask: [H, W, 1] 255 means area to repaint
78
+ config:
79
+
80
+ Returns:
81
+ BGR image
82
+ """
83
+ self.switch_controlnet_method(config)
84
+ self.enable_disable_freeu(config)
85
+ self.enable_disable_lcm_lora(config)
86
+ return self.model(image, mask, config).astype(np.uint8)
87
+
88
+ def scan_models(self) -> List[ModelInfo]:
89
+ available_models = scan_models()
90
+ self.available_models = {it.name: it for it in available_models}
91
+ return available_models
92
+
93
+ def switch(self, new_name: str):
94
+ if new_name == self.name:
95
+ return
96
+
97
+ old_name = self.name
98
+ old_controlnet_method = self.controlnet_method
99
+ self.name = new_name
100
+
101
+ if (
102
+ self.available_models[new_name].support_controlnet
103
+ and self.controlnet_method
104
+ not in self.available_models[new_name].controlnets
105
+ ):
106
+ self.controlnet_method = self.available_models[new_name].controlnets[0]
107
+ try:
108
+ # TODO: enable/disable controlnet without reload model
109
+ del self.model
110
+ torch_gc()
111
+
112
+ self.model = self.init_model(
113
+ new_name, switch_mps_device(new_name, self.device), **self.kwargs
114
+ )
115
+ except Exception as e:
116
+ self.name = old_name
117
+ self.controlnet_method = old_controlnet_method
118
+ logger.info(f"Switch model from {old_name} to {new_name} failed, rollback")
119
+ self.model = self.init_model(
120
+ old_name, switch_mps_device(old_name, self.device), **self.kwargs
121
+ )
122
+ raise e
123
+
124
+ def switch_controlnet_method(self, config):
125
+ if not self.available_models[self.name].support_controlnet:
126
+ return
127
+
128
+ if (
129
+ self.enable_controlnet
130
+ and config.controlnet_method
131
+ and self.controlnet_method != config.controlnet_method
132
+ ):
133
+ old_controlnet_method = self.controlnet_method
134
+ self.controlnet_method = config.controlnet_method
135
+ self.model.switch_controlnet_method(config.controlnet_method)
136
+ logger.info(
137
+ f"Switch Controlnet method from {old_controlnet_method} to {config.controlnet_method}"
138
+ )
139
+ elif self.enable_controlnet != config.enable_controlnet:
140
+ self.enable_controlnet = config.enable_controlnet
141
+ self.controlnet_method = config.controlnet_method
142
+
143
+ pipe_components = {
144
+ "vae": self.model.model.vae,
145
+ "text_encoder": self.model.model.text_encoder,
146
+ "unet": self.model.model.unet,
147
+ }
148
+ if hasattr(self.model.model, "text_encoder_2"):
149
+ pipe_components["text_encoder_2"] = self.model.model.text_encoder_2
150
+
151
+ self.model = self.init_model(
152
+ self.name,
153
+ switch_mps_device(self.name, self.device),
154
+ pipe_components=pipe_components,
155
+ **self.kwargs,
156
+ )
157
+ if not config.enable_controlnet:
158
+ logger.info(f"Disable controlnet")
159
+ else:
160
+ logger.info(f"Enable controlnet: {config.controlnet_method}")
161
+
162
+ def enable_disable_freeu(self, config: InpaintRequest):
163
+ if str(self.model.device) == "mps":
164
+ return
165
+
166
+ if self.available_models[self.name].support_freeu:
167
+ if config.sd_freeu:
168
+ freeu_config = config.sd_freeu_config
169
+ self.model.model.enable_freeu(
170
+ s1=freeu_config.s1,
171
+ s2=freeu_config.s2,
172
+ b1=freeu_config.b1,
173
+ b2=freeu_config.b2,
174
+ )
175
+ else:
176
+ self.model.model.disable_freeu()
177
+
178
+ def enable_disable_lcm_lora(self, config: InpaintRequest):
179
+ if self.available_models[self.name].support_lcm_lora:
180
+ # TODO: change this if load other lora is supported
181
+ lcm_lora_loaded = bool(self.model.model.get_list_adapters())
182
+ if config.sd_lcm_lora:
183
+ if not lcm_lora_loaded:
184
+ self.model.model.load_lora_weights(
185
+ self.model.lcm_lora_id,
186
+ weight_name="pytorch_lora_weights.safetensors",
187
+ local_files_only=is_local_files_only(),
188
+ )
189
+ else:
190
+ if lcm_lora_loaded:
191
+ self.model.model.disable_lora()
iopaint/plugins/briarmbg.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # copy from: https://huggingface.co/spaces/briaai/BRIA-RMBG-1.4/blob/main/briarmbg.py
2
+ import cv2
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from PIL import Image
7
+ import numpy as np
8
+ from torchvision.transforms.functional import normalize
9
+
10
+
11
+ class REBNCONV(nn.Module):
12
+ def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
13
+ super(REBNCONV, self).__init__()
14
+
15
+ self.conv_s1 = nn.Conv2d(
16
+ in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
17
+ )
18
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
19
+ self.relu_s1 = nn.ReLU(inplace=True)
20
+
21
+ def forward(self, x):
22
+ hx = x
23
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
24
+
25
+ return xout
26
+
27
+
28
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
29
+ def _upsample_like(src, tar):
30
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
31
+
32
+ return src
33
+
34
+
35
+ ### RSU-7 ###
36
+ class RSU7(nn.Module):
37
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
38
+ super(RSU7, self).__init__()
39
+
40
+ self.in_ch = in_ch
41
+ self.mid_ch = mid_ch
42
+ self.out_ch = out_ch
43
+
44
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
45
+
46
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
47
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
48
+
49
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
50
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
51
+
52
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
53
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
54
+
55
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
56
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
57
+
58
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
59
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
60
+
61
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
62
+
63
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
64
+
65
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
66
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
67
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
68
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
69
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
70
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
71
+
72
+ def forward(self, x):
73
+ b, c, h, w = x.shape
74
+
75
+ hx = x
76
+ hxin = self.rebnconvin(hx)
77
+
78
+ hx1 = self.rebnconv1(hxin)
79
+ hx = self.pool1(hx1)
80
+
81
+ hx2 = self.rebnconv2(hx)
82
+ hx = self.pool2(hx2)
83
+
84
+ hx3 = self.rebnconv3(hx)
85
+ hx = self.pool3(hx3)
86
+
87
+ hx4 = self.rebnconv4(hx)
88
+ hx = self.pool4(hx4)
89
+
90
+ hx5 = self.rebnconv5(hx)
91
+ hx = self.pool5(hx5)
92
+
93
+ hx6 = self.rebnconv6(hx)
94
+
95
+ hx7 = self.rebnconv7(hx6)
96
+
97
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
98
+ hx6dup = _upsample_like(hx6d, hx5)
99
+
100
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
101
+ hx5dup = _upsample_like(hx5d, hx4)
102
+
103
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
104
+ hx4dup = _upsample_like(hx4d, hx3)
105
+
106
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
107
+ hx3dup = _upsample_like(hx3d, hx2)
108
+
109
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
110
+ hx2dup = _upsample_like(hx2d, hx1)
111
+
112
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
113
+
114
+ return hx1d + hxin
115
+
116
+
117
+ ### RSU-6 ###
118
+ class RSU6(nn.Module):
119
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
120
+ super(RSU6, self).__init__()
121
+
122
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
123
+
124
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
125
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
126
+
127
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
128
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
129
+
130
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
131
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
132
+
133
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
134
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
135
+
136
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
137
+
138
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
139
+
140
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
141
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
142
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
143
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
144
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
145
+
146
+ def forward(self, x):
147
+ hx = x
148
+
149
+ hxin = self.rebnconvin(hx)
150
+
151
+ hx1 = self.rebnconv1(hxin)
152
+ hx = self.pool1(hx1)
153
+
154
+ hx2 = self.rebnconv2(hx)
155
+ hx = self.pool2(hx2)
156
+
157
+ hx3 = self.rebnconv3(hx)
158
+ hx = self.pool3(hx3)
159
+
160
+ hx4 = self.rebnconv4(hx)
161
+ hx = self.pool4(hx4)
162
+
163
+ hx5 = self.rebnconv5(hx)
164
+
165
+ hx6 = self.rebnconv6(hx5)
166
+
167
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
168
+ hx5dup = _upsample_like(hx5d, hx4)
169
+
170
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
171
+ hx4dup = _upsample_like(hx4d, hx3)
172
+
173
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
174
+ hx3dup = _upsample_like(hx3d, hx2)
175
+
176
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
177
+ hx2dup = _upsample_like(hx2d, hx1)
178
+
179
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
180
+
181
+ return hx1d + hxin
182
+
183
+
184
+ ### RSU-5 ###
185
+ class RSU5(nn.Module):
186
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
187
+ super(RSU5, self).__init__()
188
+
189
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
190
+
191
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
192
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
193
+
194
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
195
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
196
+
197
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
198
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
199
+
200
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
201
+
202
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
203
+
204
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
205
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
206
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
207
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
208
+
209
+ def forward(self, x):
210
+ hx = x
211
+
212
+ hxin = self.rebnconvin(hx)
213
+
214
+ hx1 = self.rebnconv1(hxin)
215
+ hx = self.pool1(hx1)
216
+
217
+ hx2 = self.rebnconv2(hx)
218
+ hx = self.pool2(hx2)
219
+
220
+ hx3 = self.rebnconv3(hx)
221
+ hx = self.pool3(hx3)
222
+
223
+ hx4 = self.rebnconv4(hx)
224
+
225
+ hx5 = self.rebnconv5(hx4)
226
+
227
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
228
+ hx4dup = _upsample_like(hx4d, hx3)
229
+
230
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
231
+ hx3dup = _upsample_like(hx3d, hx2)
232
+
233
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
234
+ hx2dup = _upsample_like(hx2d, hx1)
235
+
236
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
237
+
238
+ return hx1d + hxin
239
+
240
+
241
+ ### RSU-4 ###
242
+ class RSU4(nn.Module):
243
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
244
+ super(RSU4, self).__init__()
245
+
246
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
247
+
248
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
249
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
250
+
251
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
252
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
253
+
254
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
255
+
256
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
257
+
258
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
259
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
260
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
261
+
262
+ def forward(self, x):
263
+ hx = x
264
+
265
+ hxin = self.rebnconvin(hx)
266
+
267
+ hx1 = self.rebnconv1(hxin)
268
+ hx = self.pool1(hx1)
269
+
270
+ hx2 = self.rebnconv2(hx)
271
+ hx = self.pool2(hx2)
272
+
273
+ hx3 = self.rebnconv3(hx)
274
+
275
+ hx4 = self.rebnconv4(hx3)
276
+
277
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
278
+ hx3dup = _upsample_like(hx3d, hx2)
279
+
280
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
281
+ hx2dup = _upsample_like(hx2d, hx1)
282
+
283
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
284
+
285
+ return hx1d + hxin
286
+
287
+
288
+ ### RSU-4F ###
289
+ class RSU4F(nn.Module):
290
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
291
+ super(RSU4F, self).__init__()
292
+
293
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
294
+
295
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
296
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
297
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
298
+
299
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
300
+
301
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
302
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
303
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
304
+
305
+ def forward(self, x):
306
+ hx = x
307
+
308
+ hxin = self.rebnconvin(hx)
309
+
310
+ hx1 = self.rebnconv1(hxin)
311
+ hx2 = self.rebnconv2(hx1)
312
+ hx3 = self.rebnconv3(hx2)
313
+
314
+ hx4 = self.rebnconv4(hx3)
315
+
316
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
317
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
318
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
319
+
320
+ return hx1d + hxin
321
+
322
+
323
+ class myrebnconv(nn.Module):
324
+ def __init__(
325
+ self,
326
+ in_ch=3,
327
+ out_ch=1,
328
+ kernel_size=3,
329
+ stride=1,
330
+ padding=1,
331
+ dilation=1,
332
+ groups=1,
333
+ ):
334
+ super(myrebnconv, self).__init__()
335
+
336
+ self.conv = nn.Conv2d(
337
+ in_ch,
338
+ out_ch,
339
+ kernel_size=kernel_size,
340
+ stride=stride,
341
+ padding=padding,
342
+ dilation=dilation,
343
+ groups=groups,
344
+ )
345
+ self.bn = nn.BatchNorm2d(out_ch)
346
+ self.rl = nn.ReLU(inplace=True)
347
+
348
+ def forward(self, x):
349
+ return self.rl(self.bn(self.conv(x)))
350
+
351
+
352
+ class BriaRMBG(nn.Module):
353
+ def __init__(self, in_ch=3, out_ch=1):
354
+ super(BriaRMBG, self).__init__()
355
+
356
+ self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
357
+ self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
358
+
359
+ self.stage1 = RSU7(64, 32, 64)
360
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
361
+
362
+ self.stage2 = RSU6(64, 32, 128)
363
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
364
+
365
+ self.stage3 = RSU5(128, 64, 256)
366
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
367
+
368
+ self.stage4 = RSU4(256, 128, 512)
369
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
370
+
371
+ self.stage5 = RSU4F(512, 256, 512)
372
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
373
+
374
+ self.stage6 = RSU4F(512, 256, 512)
375
+
376
+ # decoder
377
+ self.stage5d = RSU4F(1024, 256, 512)
378
+ self.stage4d = RSU4(1024, 128, 256)
379
+ self.stage3d = RSU5(512, 64, 128)
380
+ self.stage2d = RSU6(256, 32, 64)
381
+ self.stage1d = RSU7(128, 16, 64)
382
+
383
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
384
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
385
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
386
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
387
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
388
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
389
+
390
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
391
+
392
+ def forward(self, x):
393
+ hx = x
394
+
395
+ hxin = self.conv_in(hx)
396
+ # hx = self.pool_in(hxin)
397
+
398
+ # stage 1
399
+ hx1 = self.stage1(hxin)
400
+ hx = self.pool12(hx1)
401
+
402
+ # stage 2
403
+ hx2 = self.stage2(hx)
404
+ hx = self.pool23(hx2)
405
+
406
+ # stage 3
407
+ hx3 = self.stage3(hx)
408
+ hx = self.pool34(hx3)
409
+
410
+ # stage 4
411
+ hx4 = self.stage4(hx)
412
+ hx = self.pool45(hx4)
413
+
414
+ # stage 5
415
+ hx5 = self.stage5(hx)
416
+ hx = self.pool56(hx5)
417
+
418
+ # stage 6
419
+ hx6 = self.stage6(hx)
420
+ hx6up = _upsample_like(hx6, hx5)
421
+
422
+ # -------------------- decoder --------------------
423
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
424
+ hx5dup = _upsample_like(hx5d, hx4)
425
+
426
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
427
+ hx4dup = _upsample_like(hx4d, hx3)
428
+
429
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
430
+ hx3dup = _upsample_like(hx3d, hx2)
431
+
432
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
433
+ hx2dup = _upsample_like(hx2d, hx1)
434
+
435
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
436
+
437
+ # side output
438
+ d1 = self.side1(hx1d)
439
+ d1 = _upsample_like(d1, x)
440
+
441
+ d2 = self.side2(hx2d)
442
+ d2 = _upsample_like(d2, x)
443
+
444
+ d3 = self.side3(hx3d)
445
+ d3 = _upsample_like(d3, x)
446
+
447
+ d4 = self.side4(hx4d)
448
+ d4 = _upsample_like(d4, x)
449
+
450
+ d5 = self.side5(hx5d)
451
+ d5 = _upsample_like(d5, x)
452
+
453
+ d6 = self.side6(hx6)
454
+ d6 = _upsample_like(d6, x)
455
+
456
+ return [
457
+ F.sigmoid(d1),
458
+ F.sigmoid(d2),
459
+ F.sigmoid(d3),
460
+ F.sigmoid(d4),
461
+ F.sigmoid(d5),
462
+ F.sigmoid(d6),
463
+ ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
464
+
465
+
466
+ def resize_image(image):
467
+ image = image.convert("RGB")
468
+ model_input_size = (1024, 1024)
469
+ image = image.resize(model_input_size, Image.BILINEAR)
470
+ return image
471
+
472
+
473
+ def create_briarmbg_session():
474
+ from huggingface_hub import hf_hub_download
475
+
476
+ net = BriaRMBG()
477
+ model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
478
+ net.load_state_dict(torch.load(model_path, map_location="cpu"))
479
+ net.eval()
480
+ return net
481
+
482
+
483
+ def briarmbg_process(bgr_np_image, session, only_mask=False):
484
+ # prepare input
485
+ orig_bgr_image = Image.fromarray(bgr_np_image)
486
+ w, h = orig_im_size = orig_bgr_image.size
487
+ image = resize_image(orig_bgr_image)
488
+ im_np = np.array(image)
489
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
490
+ im_tensor = torch.unsqueeze(im_tensor, 0)
491
+ im_tensor = torch.divide(im_tensor, 255.0)
492
+ im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
493
+ # inference
494
+ result = session(im_tensor)
495
+ # post process
496
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
497
+ ma = torch.max(result)
498
+ mi = torch.min(result)
499
+ result = (result - mi) / (ma - mi)
500
+ # image to pil
501
+ im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
502
+
503
+ mask = np.squeeze(im_array)
504
+ if only_mask:
505
+ return mask
506
+
507
+ pil_im = Image.fromarray(mask)
508
+ # paste the mask on the original image
509
+ new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
510
+ new_im.paste(orig_bgr_image, mask=pil_im)
511
+ rgba_np_img = np.asarray(new_im)
512
+ return rgba_np_img
iopaint/plugins/gfpgan_plugin.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from loguru import logger
4
+
5
+ from iopaint.helper import download_model
6
+ from iopaint.plugins.base_plugin import BasePlugin
7
+ from iopaint.schema import RunPluginRequest
8
+
9
+
10
+ class GFPGANPlugin(BasePlugin):
11
+ name = "GFPGAN"
12
+ support_gen_image = True
13
+
14
+ def __init__(self, device, upscaler=None):
15
+ super().__init__()
16
+ from .gfpganer import MyGFPGANer
17
+
18
+ url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
19
+ model_md5 = "94d735072630ab734561130a47bc44f8"
20
+ model_path = download_model(url, model_md5)
21
+ logger.info(f"GFPGAN model path: {model_path}")
22
+
23
+ import facexlib
24
+
25
+ if hasattr(facexlib.detection.retinaface, "device"):
26
+ facexlib.detection.retinaface.device = device
27
+
28
+ # Use GFPGAN for face enhancement
29
+ self.face_enhancer = MyGFPGANer(
30
+ model_path=model_path,
31
+ upscale=1,
32
+ arch="clean",
33
+ channel_multiplier=2,
34
+ device=device,
35
+ bg_upsampler=upscaler.model if upscaler is not None else None,
36
+ )
37
+ self.face_enhancer.face_helper.face_det.mean_tensor.to(device)
38
+ self.face_enhancer.face_helper.face_det = (
39
+ self.face_enhancer.face_helper.face_det.to(device)
40
+ )
41
+
42
+ def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
43
+ weight = 0.5
44
+ bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
45
+ logger.info(f"GFPGAN input shape: {bgr_np_img.shape}")
46
+ _, _, bgr_output = self.face_enhancer.enhance(
47
+ bgr_np_img,
48
+ has_aligned=False,
49
+ only_center_face=False,
50
+ paste_back=True,
51
+ weight=weight,
52
+ )
53
+ logger.info(f"GFPGAN output shape: {bgr_output.shape}")
54
+
55
+ # try:
56
+ # if scale != 2:
57
+ # interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
58
+ # h, w = img.shape[0:2]
59
+ # output = cv2.resize(
60
+ # output,
61
+ # (int(w * scale / 2), int(h * scale / 2)),
62
+ # interpolation=interpolation,
63
+ # )
64
+ # except Exception as error:
65
+ # print("wrong scale input.", error)
66
+ return bgr_output
67
+
68
+ def check_dep(self):
69
+ try:
70
+ import gfpgan
71
+ except ImportError:
72
+ return (
73
+ "gfpgan is not installed, please install it first. pip install gfpgan"
74
+ )
iopaint/plugins/gfpganer.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
5
+ from gfpgan import GFPGANv1Clean, GFPGANer
6
+ from torch.hub import get_dir
7
+
8
+
9
+ class MyGFPGANer(GFPGANer):
10
+ """Helper for restoration with GFPGAN.
11
+
12
+ It will detect and crop faces, and then resize the faces to 512x512.
13
+ GFPGAN is used to restored the resized faces.
14
+ The background is upsampled with the bg_upsampler.
15
+ Finally, the faces will be pasted back to the upsample background image.
16
+
17
+ Args:
18
+ model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
19
+ upscale (float): The upscale of the final output. Default: 2.
20
+ arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
21
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
22
+ bg_upsampler (nn.Module): The upsampler for the background. Default: None.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ model_path,
28
+ upscale=2,
29
+ arch="clean",
30
+ channel_multiplier=2,
31
+ bg_upsampler=None,
32
+ device=None,
33
+ ):
34
+ self.upscale = upscale
35
+ self.bg_upsampler = bg_upsampler
36
+
37
+ # initialize model
38
+ self.device = (
39
+ torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+ if device is None
41
+ else device
42
+ )
43
+ # initialize the GFP-GAN
44
+ if arch == "clean":
45
+ self.gfpgan = GFPGANv1Clean(
46
+ out_size=512,
47
+ num_style_feat=512,
48
+ channel_multiplier=channel_multiplier,
49
+ decoder_load_path=None,
50
+ fix_decoder=False,
51
+ num_mlp=8,
52
+ input_is_latent=True,
53
+ different_w=True,
54
+ narrow=1,
55
+ sft_half=True,
56
+ )
57
+ elif arch == "RestoreFormer":
58
+ from gfpgan.archs.restoreformer_arch import RestoreFormer
59
+
60
+ self.gfpgan = RestoreFormer()
61
+
62
+ hub_dir = get_dir()
63
+ model_dir = os.path.join(hub_dir, "checkpoints")
64
+
65
+ # initialize face helper
66
+ self.face_helper = FaceRestoreHelper(
67
+ upscale,
68
+ face_size=512,
69
+ crop_ratio=(1, 1),
70
+ det_model="retinaface_resnet50",
71
+ save_ext="png",
72
+ use_parse=True,
73
+ device=self.device,
74
+ model_rootpath=model_dir,
75
+ )
76
+
77
+ loadnet = torch.load(model_path)
78
+ if "params_ema" in loadnet:
79
+ keyname = "params_ema"
80
+ else:
81
+ keyname = "params"
82
+ self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
83
+ self.gfpgan.eval()
84
+ self.gfpgan = self.gfpgan.to(self.device)
iopaint/plugins/interactive_seg.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import torch
6
+ from loguru import logger
7
+
8
+ from iopaint.helper import download_model
9
+ from iopaint.plugins.base_plugin import BasePlugin
10
+ from iopaint.plugins.segment_anything import SamPredictor, sam_model_registry
11
+ from iopaint.schema import RunPluginRequest
12
+
13
+ # 从小到大
14
+ SEGMENT_ANYTHING_MODELS = {
15
+ "vit_b": {
16
+ "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
17
+ "md5": "01ec64d29a2fca3f0661936605ae66f8",
18
+ },
19
+ "vit_l": {
20
+ "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
21
+ "md5": "0b3195507c641ddb6910d2bb5adee89c",
22
+ },
23
+ "vit_h": {
24
+ "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
25
+ "md5": "4b8939a88964f0f4ff5f5b2642c598a6",
26
+ },
27
+ "mobile_sam": {
28
+ "url": "https://github.com/Sanster/models/releases/download/MobileSAM/mobile_sam.pt",
29
+ "md5": "f3c0d8cda613564d499310dab6c812cd",
30
+ },
31
+ }
32
+
33
+
34
+ class InteractiveSeg(BasePlugin):
35
+ name = "InteractiveSeg"
36
+ support_gen_mask = True
37
+
38
+ def __init__(self, model_name, device):
39
+ super().__init__()
40
+ self.model_name = model_name
41
+ self.device = device
42
+ self._init_session(model_name)
43
+
44
+ def _init_session(self, model_name: str):
45
+ model_path = download_model(
46
+ SEGMENT_ANYTHING_MODELS[model_name]["url"],
47
+ SEGMENT_ANYTHING_MODELS[model_name]["md5"],
48
+ )
49
+ logger.info(f"SegmentAnything model path: {model_path}")
50
+ self.predictor = SamPredictor(
51
+ sam_model_registry[model_name](checkpoint=model_path).to(self.device)
52
+ )
53
+ self.prev_img_md5 = None
54
+
55
+ def switch_model(self, new_model_name):
56
+ if self.model_name == new_model_name:
57
+ return
58
+
59
+ logger.info(
60
+ f"Switching InteractiveSeg model from {self.model_name} to {new_model_name}"
61
+ )
62
+ self._init_session(new_model_name)
63
+ self.model_name = new_model_name
64
+
65
+ def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
66
+ img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
67
+ return self.forward(rgb_np_img, req.clicks, img_md5)
68
+
69
+ @torch.inference_mode()
70
+ def forward(self, rgb_np_img, clicks: List[List], img_md5: str):
71
+ input_point = []
72
+ input_label = []
73
+ for click in clicks:
74
+ x = click[0]
75
+ y = click[1]
76
+ input_point.append([x, y])
77
+ input_label.append(click[2])
78
+
79
+ if img_md5 and img_md5 != self.prev_img_md5:
80
+ self.prev_img_md5 = img_md5
81
+ self.predictor.set_image(rgb_np_img)
82
+
83
+ masks, scores, _ = self.predictor.predict(
84
+ point_coords=np.array(input_point),
85
+ point_labels=np.array(input_label),
86
+ multimask_output=False,
87
+ )
88
+ mask = masks[0].astype(np.uint8) * 255
89
+ return mask
iopaint/plugins/segment_anything/build_sam.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+
9
+ from functools import partial
10
+
11
+ from iopaint.plugins.segment_anything.modeling.tiny_vit_sam import TinyViT
12
+
13
+ from .modeling import (
14
+ ImageEncoderViT,
15
+ MaskDecoder,
16
+ PromptEncoder,
17
+ Sam,
18
+ TwoWayTransformer,
19
+ )
20
+
21
+
22
+ def build_sam_vit_h(checkpoint=None):
23
+ return _build_sam(
24
+ encoder_embed_dim=1280,
25
+ encoder_depth=32,
26
+ encoder_num_heads=16,
27
+ encoder_global_attn_indexes=[7, 15, 23, 31],
28
+ checkpoint=checkpoint,
29
+ )
30
+
31
+
32
+ build_sam = build_sam_vit_h
33
+
34
+
35
+ def build_sam_vit_l(checkpoint=None):
36
+ return _build_sam(
37
+ encoder_embed_dim=1024,
38
+ encoder_depth=24,
39
+ encoder_num_heads=16,
40
+ encoder_global_attn_indexes=[5, 11, 17, 23],
41
+ checkpoint=checkpoint,
42
+ )
43
+
44
+
45
+ def build_sam_vit_b(checkpoint=None):
46
+ return _build_sam(
47
+ encoder_embed_dim=768,
48
+ encoder_depth=12,
49
+ encoder_num_heads=12,
50
+ encoder_global_attn_indexes=[2, 5, 8, 11],
51
+ checkpoint=checkpoint,
52
+ )
53
+
54
+
55
+ def build_sam_vit_t(checkpoint=None):
56
+ prompt_embed_dim = 256
57
+ image_size = 1024
58
+ vit_patch_size = 16
59
+ image_embedding_size = image_size // vit_patch_size
60
+ mobile_sam = Sam(
61
+ image_encoder=TinyViT(
62
+ img_size=1024,
63
+ in_chans=3,
64
+ num_classes=1000,
65
+ embed_dims=[64, 128, 160, 320],
66
+ depths=[2, 2, 6, 2],
67
+ num_heads=[2, 4, 5, 10],
68
+ window_sizes=[7, 7, 14, 7],
69
+ mlp_ratio=4.0,
70
+ drop_rate=0.0,
71
+ drop_path_rate=0.0,
72
+ use_checkpoint=False,
73
+ mbconv_expand_ratio=4.0,
74
+ local_conv_size=3,
75
+ layer_lr_decay=0.8,
76
+ ),
77
+ prompt_encoder=PromptEncoder(
78
+ embed_dim=prompt_embed_dim,
79
+ image_embedding_size=(image_embedding_size, image_embedding_size),
80
+ input_image_size=(image_size, image_size),
81
+ mask_in_chans=16,
82
+ ),
83
+ mask_decoder=MaskDecoder(
84
+ num_multimask_outputs=3,
85
+ transformer=TwoWayTransformer(
86
+ depth=2,
87
+ embedding_dim=prompt_embed_dim,
88
+ mlp_dim=2048,
89
+ num_heads=8,
90
+ ),
91
+ transformer_dim=prompt_embed_dim,
92
+ iou_head_depth=3,
93
+ iou_head_hidden_dim=256,
94
+ ),
95
+ pixel_mean=[123.675, 116.28, 103.53],
96
+ pixel_std=[58.395, 57.12, 57.375],
97
+ )
98
+
99
+ mobile_sam.eval()
100
+ if checkpoint is not None:
101
+ with open(checkpoint, "rb") as f:
102
+ state_dict = torch.load(f)
103
+ mobile_sam.load_state_dict(state_dict)
104
+ return mobile_sam
105
+
106
+
107
+ sam_model_registry = {
108
+ "default": build_sam,
109
+ "vit_h": build_sam,
110
+ "vit_l": build_sam_vit_l,
111
+ "vit_b": build_sam_vit_b,
112
+ "mobile_sam": build_sam_vit_t,
113
+ }
114
+
115
+
116
+ def _build_sam(
117
+ encoder_embed_dim,
118
+ encoder_depth,
119
+ encoder_num_heads,
120
+ encoder_global_attn_indexes,
121
+ checkpoint=None,
122
+ ):
123
+ prompt_embed_dim = 256
124
+ image_size = 1024
125
+ vit_patch_size = 16
126
+ image_embedding_size = image_size // vit_patch_size
127
+ sam = Sam(
128
+ image_encoder=ImageEncoderViT(
129
+ depth=encoder_depth,
130
+ embed_dim=encoder_embed_dim,
131
+ img_size=image_size,
132
+ mlp_ratio=4,
133
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
134
+ num_heads=encoder_num_heads,
135
+ patch_size=vit_patch_size,
136
+ qkv_bias=True,
137
+ use_rel_pos=True,
138
+ global_attn_indexes=encoder_global_attn_indexes,
139
+ window_size=14,
140
+ out_chans=prompt_embed_dim,
141
+ ),
142
+ prompt_encoder=PromptEncoder(
143
+ embed_dim=prompt_embed_dim,
144
+ image_embedding_size=(image_embedding_size, image_embedding_size),
145
+ input_image_size=(image_size, image_size),
146
+ mask_in_chans=16,
147
+ ),
148
+ mask_decoder=MaskDecoder(
149
+ num_multimask_outputs=3,
150
+ transformer=TwoWayTransformer(
151
+ depth=2,
152
+ embedding_dim=prompt_embed_dim,
153
+ mlp_dim=2048,
154
+ num_heads=8,
155
+ ),
156
+ transformer_dim=prompt_embed_dim,
157
+ iou_head_depth=3,
158
+ iou_head_hidden_dim=256,
159
+ ),
160
+ pixel_mean=[123.675, 116.28, 103.53],
161
+ pixel_std=[58.395, 57.12, 57.375],
162
+ )
163
+ sam.eval()
164
+ if checkpoint is not None:
165
+ with open(checkpoint, "rb") as f:
166
+ state_dict = torch.load(f)
167
+ sam.load_state_dict(state_dict)
168
+ return sam
iopaint/plugins/segment_anything/modeling/common.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from typing import Type
11
+
12
+
13
+ class MLPBlock(nn.Module):
14
+ def __init__(
15
+ self,
16
+ embedding_dim: int,
17
+ mlp_dim: int,
18
+ act: Type[nn.Module] = nn.GELU,
19
+ ) -> None:
20
+ super().__init__()
21
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23
+ self.act = act()
24
+
25
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
26
+ return self.lin2(self.act(self.lin1(x)))
27
+
28
+
29
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31
+ class LayerNorm2d(nn.Module):
32
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33
+ super().__init__()
34
+ self.weight = nn.Parameter(torch.ones(num_channels))
35
+ self.bias = nn.Parameter(torch.zeros(num_channels))
36
+ self.eps = eps
37
+
38
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
39
+ u = x.mean(1, keepdim=True)
40
+ s = (x - u).pow(2).mean(1, keepdim=True)
41
+ x = (x - u) / torch.sqrt(s + self.eps)
42
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
43
+ return x
iopaint/plugins/segment_anything/modeling/image_encoder.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from typing import Optional, Tuple, Type
12
+
13
+ from .common import LayerNorm2d, MLPBlock
14
+
15
+
16
+ # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
17
+ class ImageEncoderViT(nn.Module):
18
+ def __init__(
19
+ self,
20
+ img_size: int = 1024,
21
+ patch_size: int = 16,
22
+ in_chans: int = 3,
23
+ embed_dim: int = 768,
24
+ depth: int = 12,
25
+ num_heads: int = 12,
26
+ mlp_ratio: float = 4.0,
27
+ out_chans: int = 256,
28
+ qkv_bias: bool = True,
29
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
30
+ act_layer: Type[nn.Module] = nn.GELU,
31
+ use_abs_pos: bool = True,
32
+ use_rel_pos: bool = False,
33
+ rel_pos_zero_init: bool = True,
34
+ window_size: int = 0,
35
+ global_attn_indexes: Tuple[int, ...] = (),
36
+ ) -> None:
37
+ """
38
+ Args:
39
+ img_size (int): Input image size.
40
+ patch_size (int): Patch size.
41
+ in_chans (int): Number of input image channels.
42
+ embed_dim (int): Patch embedding dimension.
43
+ depth (int): Depth of ViT.
44
+ num_heads (int): Number of attention heads in each ViT block.
45
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
46
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
47
+ norm_layer (nn.Module): Normalization layer.
48
+ act_layer (nn.Module): Activation layer.
49
+ use_abs_pos (bool): If True, use absolute positional embeddings.
50
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
51
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
52
+ window_size (int): Window size for window attention blocks.
53
+ global_attn_indexes (list): Indexes for blocks using global attention.
54
+ """
55
+ super().__init__()
56
+ self.img_size = img_size
57
+
58
+ self.patch_embed = PatchEmbed(
59
+ kernel_size=(patch_size, patch_size),
60
+ stride=(patch_size, patch_size),
61
+ in_chans=in_chans,
62
+ embed_dim=embed_dim,
63
+ )
64
+
65
+ self.pos_embed: Optional[nn.Parameter] = None
66
+ if use_abs_pos:
67
+ # Initialize absolute positional embedding with pretrain image size.
68
+ self.pos_embed = nn.Parameter(
69
+ torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
70
+ )
71
+
72
+ self.blocks = nn.ModuleList()
73
+ for i in range(depth):
74
+ block = Block(
75
+ dim=embed_dim,
76
+ num_heads=num_heads,
77
+ mlp_ratio=mlp_ratio,
78
+ qkv_bias=qkv_bias,
79
+ norm_layer=norm_layer,
80
+ act_layer=act_layer,
81
+ use_rel_pos=use_rel_pos,
82
+ rel_pos_zero_init=rel_pos_zero_init,
83
+ window_size=window_size if i not in global_attn_indexes else 0,
84
+ input_size=(img_size // patch_size, img_size // patch_size),
85
+ )
86
+ self.blocks.append(block)
87
+
88
+ self.neck = nn.Sequential(
89
+ nn.Conv2d(
90
+ embed_dim,
91
+ out_chans,
92
+ kernel_size=1,
93
+ bias=False,
94
+ ),
95
+ LayerNorm2d(out_chans),
96
+ nn.Conv2d(
97
+ out_chans,
98
+ out_chans,
99
+ kernel_size=3,
100
+ padding=1,
101
+ bias=False,
102
+ ),
103
+ LayerNorm2d(out_chans),
104
+ )
105
+
106
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
107
+ x = self.patch_embed(x)
108
+ if self.pos_embed is not None:
109
+ x = x + self.pos_embed
110
+
111
+ for blk in self.blocks:
112
+ x = blk(x)
113
+
114
+ x = self.neck(x.permute(0, 3, 1, 2))
115
+
116
+ return x
117
+
118
+
119
+ class Block(nn.Module):
120
+ """Transformer blocks with support of window attention and residual propagation blocks"""
121
+
122
+ def __init__(
123
+ self,
124
+ dim: int,
125
+ num_heads: int,
126
+ mlp_ratio: float = 4.0,
127
+ qkv_bias: bool = True,
128
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
129
+ act_layer: Type[nn.Module] = nn.GELU,
130
+ use_rel_pos: bool = False,
131
+ rel_pos_zero_init: bool = True,
132
+ window_size: int = 0,
133
+ input_size: Optional[Tuple[int, int]] = None,
134
+ ) -> None:
135
+ """
136
+ Args:
137
+ dim (int): Number of input channels.
138
+ num_heads (int): Number of attention heads in each ViT block.
139
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
140
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
141
+ norm_layer (nn.Module): Normalization layer.
142
+ act_layer (nn.Module): Activation layer.
143
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
144
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
145
+ window_size (int): Window size for window attention blocks. If it equals 0, then
146
+ use global attention.
147
+ input_size (int or None): Input resolution for calculating the relative positional
148
+ parameter size.
149
+ """
150
+ super().__init__()
151
+ self.norm1 = norm_layer(dim)
152
+ self.attn = Attention(
153
+ dim,
154
+ num_heads=num_heads,
155
+ qkv_bias=qkv_bias,
156
+ use_rel_pos=use_rel_pos,
157
+ rel_pos_zero_init=rel_pos_zero_init,
158
+ input_size=input_size if window_size == 0 else (window_size, window_size),
159
+ )
160
+
161
+ self.norm2 = norm_layer(dim)
162
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
163
+
164
+ self.window_size = window_size
165
+
166
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
167
+ shortcut = x
168
+ x = self.norm1(x)
169
+ # Window partition
170
+ if self.window_size > 0:
171
+ H, W = x.shape[1], x.shape[2]
172
+ x, pad_hw = window_partition(x, self.window_size)
173
+
174
+ x = self.attn(x)
175
+ # Reverse window partition
176
+ if self.window_size > 0:
177
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
178
+
179
+ x = shortcut + x
180
+ x = x + self.mlp(self.norm2(x))
181
+
182
+ return x
183
+
184
+
185
+ class Attention(nn.Module):
186
+ """Multi-head Attention block with relative position embeddings."""
187
+
188
+ def __init__(
189
+ self,
190
+ dim: int,
191
+ num_heads: int = 8,
192
+ qkv_bias: bool = True,
193
+ use_rel_pos: bool = False,
194
+ rel_pos_zero_init: bool = True,
195
+ input_size: Optional[Tuple[int, int]] = None,
196
+ ) -> None:
197
+ """
198
+ Args:
199
+ dim (int): Number of input channels.
200
+ num_heads (int): Number of attention heads.
201
+ qkv_bias (bool: If True, add a learnable bias to query, key, value.
202
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
203
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
204
+ input_size (int or None): Input resolution for calculating the relative positional
205
+ parameter size.
206
+ """
207
+ super().__init__()
208
+ self.num_heads = num_heads
209
+ head_dim = dim // num_heads
210
+ self.scale = head_dim**-0.5
211
+
212
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
213
+ self.proj = nn.Linear(dim, dim)
214
+
215
+ self.use_rel_pos = use_rel_pos
216
+ if self.use_rel_pos:
217
+ assert (
218
+ input_size is not None
219
+ ), "Input size must be provided if using relative positional encoding."
220
+ # initialize relative positional embeddings
221
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
222
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
223
+
224
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
225
+ B, H, W, _ = x.shape
226
+ # qkv with shape (3, B, nHead, H * W, C)
227
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
228
+ # q, k, v with shape (B * nHead, H * W, C)
229
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
230
+
231
+ attn = (q * self.scale) @ k.transpose(-2, -1)
232
+
233
+ if self.use_rel_pos:
234
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
235
+
236
+ attn = attn.softmax(dim=-1)
237
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
238
+ x = self.proj(x)
239
+
240
+ return x
241
+
242
+
243
+ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
244
+ """
245
+ Partition into non-overlapping windows with padding if needed.
246
+ Args:
247
+ x (tensor): input tokens with [B, H, W, C].
248
+ window_size (int): window size.
249
+
250
+ Returns:
251
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
252
+ (Hp, Wp): padded height and width before partition
253
+ """
254
+ B, H, W, C = x.shape
255
+
256
+ pad_h = (window_size - H % window_size) % window_size
257
+ pad_w = (window_size - W % window_size) % window_size
258
+ if pad_h > 0 or pad_w > 0:
259
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
260
+ Hp, Wp = H + pad_h, W + pad_w
261
+
262
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
263
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
264
+ return windows, (Hp, Wp)
265
+
266
+
267
+ def window_unpartition(
268
+ windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
269
+ ) -> torch.Tensor:
270
+ """
271
+ Window unpartition into original sequences and removing padding.
272
+ Args:
273
+ x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
274
+ window_size (int): window size.
275
+ pad_hw (Tuple): padded height and width (Hp, Wp).
276
+ hw (Tuple): original height and width (H, W) before padding.
277
+
278
+ Returns:
279
+ x: unpartitioned sequences with [B, H, W, C].
280
+ """
281
+ Hp, Wp = pad_hw
282
+ H, W = hw
283
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
284
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
285
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
286
+
287
+ if Hp > H or Wp > W:
288
+ x = x[:, :H, :W, :].contiguous()
289
+ return x
290
+
291
+
292
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
293
+ """
294
+ Get relative positional embeddings according to the relative positions of
295
+ query and key sizes.
296
+ Args:
297
+ q_size (int): size of query q.
298
+ k_size (int): size of key k.
299
+ rel_pos (Tensor): relative position embeddings (L, C).
300
+
301
+ Returns:
302
+ Extracted positional embeddings according to relative positions.
303
+ """
304
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
305
+ # Interpolate rel pos if needed.
306
+ if rel_pos.shape[0] != max_rel_dist:
307
+ # Interpolate rel pos.
308
+ rel_pos_resized = F.interpolate(
309
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
310
+ size=max_rel_dist,
311
+ mode="linear",
312
+ )
313
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
314
+ else:
315
+ rel_pos_resized = rel_pos
316
+
317
+ # Scale the coords with short length if shapes for q and k are different.
318
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
319
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
320
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
321
+
322
+ return rel_pos_resized[relative_coords.long()]
323
+
324
+
325
+ def add_decomposed_rel_pos(
326
+ attn: torch.Tensor,
327
+ q: torch.Tensor,
328
+ rel_pos_h: torch.Tensor,
329
+ rel_pos_w: torch.Tensor,
330
+ q_size: Tuple[int, int],
331
+ k_size: Tuple[int, int],
332
+ ) -> torch.Tensor:
333
+ """
334
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
335
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
336
+ Args:
337
+ attn (Tensor): attention map.
338
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
339
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
340
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
341
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
342
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
343
+
344
+ Returns:
345
+ attn (Tensor): attention map with added relative positional embeddings.
346
+ """
347
+ q_h, q_w = q_size
348
+ k_h, k_w = k_size
349
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
350
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
351
+
352
+ B, _, dim = q.shape
353
+ r_q = q.reshape(B, q_h, q_w, dim)
354
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
355
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
356
+
357
+ attn = (
358
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
359
+ ).view(B, q_h * q_w, k_h * k_w)
360
+
361
+ return attn
362
+
363
+
364
+ class PatchEmbed(nn.Module):
365
+ """
366
+ Image to Patch Embedding.
367
+ """
368
+
369
+ def __init__(
370
+ self,
371
+ kernel_size: Tuple[int, int] = (16, 16),
372
+ stride: Tuple[int, int] = (16, 16),
373
+ padding: Tuple[int, int] = (0, 0),
374
+ in_chans: int = 3,
375
+ embed_dim: int = 768,
376
+ ) -> None:
377
+ """
378
+ Args:
379
+ kernel_size (Tuple): kernel size of the projection layer.
380
+ stride (Tuple): stride of the projection layer.
381
+ padding (Tuple): padding size of the projection layer.
382
+ in_chans (int): Number of input image channels.
383
+ embed_dim (int): embed_dim (int): Patch embedding dimension.
384
+ """
385
+ super().__init__()
386
+
387
+ self.proj = nn.Conv2d(
388
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
389
+ )
390
+
391
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
392
+ x = self.proj(x)
393
+ # B C H W -> B H W C
394
+ x = x.permute(0, 2, 3, 1)
395
+ return x
iopaint/plugins/segment_anything/modeling/mask_decoder.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ from typing import List, Tuple, Type
12
+
13
+ from .common import LayerNorm2d
14
+
15
+
16
+ class MaskDecoder(nn.Module):
17
+ def __init__(
18
+ self,
19
+ *,
20
+ transformer_dim: int,
21
+ transformer: nn.Module,
22
+ num_multimask_outputs: int = 3,
23
+ activation: Type[nn.Module] = nn.GELU,
24
+ iou_head_depth: int = 3,
25
+ iou_head_hidden_dim: int = 256,
26
+ ) -> None:
27
+ """
28
+ Predicts masks given an image and prompt embeddings, using a
29
+ tranformer architecture.
30
+
31
+ Arguments:
32
+ transformer_dim (int): the channel dimension of the transformer
33
+ transformer (nn.Module): the transformer used to predict masks
34
+ num_multimask_outputs (int): the number of masks to predict
35
+ when disambiguating masks
36
+ activation (nn.Module): the type of activation to use when
37
+ upscaling masks
38
+ iou_head_depth (int): the depth of the MLP used to predict
39
+ mask quality
40
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
41
+ used to predict mask quality
42
+ """
43
+ super().__init__()
44
+ self.transformer_dim = transformer_dim
45
+ self.transformer = transformer
46
+
47
+ self.num_multimask_outputs = num_multimask_outputs
48
+
49
+ self.iou_token = nn.Embedding(1, transformer_dim)
50
+ self.num_mask_tokens = num_multimask_outputs + 1
51
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
52
+
53
+ self.output_upscaling = nn.Sequential(
54
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
55
+ LayerNorm2d(transformer_dim // 4),
56
+ activation(),
57
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
58
+ activation(),
59
+ )
60
+ self.output_hypernetworks_mlps = nn.ModuleList(
61
+ [
62
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
63
+ for i in range(self.num_mask_tokens)
64
+ ]
65
+ )
66
+
67
+ self.iou_prediction_head = MLP(
68
+ transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
69
+ )
70
+
71
+ def forward(
72
+ self,
73
+ image_embeddings: torch.Tensor,
74
+ image_pe: torch.Tensor,
75
+ sparse_prompt_embeddings: torch.Tensor,
76
+ dense_prompt_embeddings: torch.Tensor,
77
+ multimask_output: bool,
78
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
79
+ """
80
+ Predict masks given image and prompt embeddings.
81
+
82
+ Arguments:
83
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
84
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
85
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
86
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
87
+ multimask_output (bool): Whether to return multiple masks or a single
88
+ mask.
89
+
90
+ Returns:
91
+ torch.Tensor: batched predicted masks
92
+ torch.Tensor: batched predictions of mask quality
93
+ """
94
+ masks, iou_pred = self.predict_masks(
95
+ image_embeddings=image_embeddings,
96
+ image_pe=image_pe,
97
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
98
+ dense_prompt_embeddings=dense_prompt_embeddings,
99
+ )
100
+
101
+ # Select the correct mask or masks for outptu
102
+ if multimask_output:
103
+ mask_slice = slice(1, None)
104
+ else:
105
+ mask_slice = slice(0, 1)
106
+ masks = masks[:, mask_slice, :, :]
107
+ iou_pred = iou_pred[:, mask_slice]
108
+
109
+ # Prepare output
110
+ return masks, iou_pred
111
+
112
+ def predict_masks(
113
+ self,
114
+ image_embeddings: torch.Tensor,
115
+ image_pe: torch.Tensor,
116
+ sparse_prompt_embeddings: torch.Tensor,
117
+ dense_prompt_embeddings: torch.Tensor,
118
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
119
+ """Predicts masks. See 'forward' for more details."""
120
+ # Concatenate output tokens
121
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
122
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
123
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
124
+
125
+ # Expand per-image data in batch direction to be per-mask
126
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
127
+ src = src + dense_prompt_embeddings
128
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
129
+ b, c, h, w = src.shape
130
+
131
+ # Run the transformer
132
+ hs, src = self.transformer(src, pos_src, tokens)
133
+ iou_token_out = hs[:, 0, :]
134
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
135
+
136
+ # Upscale mask embeddings and predict masks using the mask tokens
137
+ src = src.transpose(1, 2).view(b, c, h, w)
138
+ upscaled_embedding = self.output_upscaling(src)
139
+ hyper_in_list: List[torch.Tensor] = []
140
+ for i in range(self.num_mask_tokens):
141
+ hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
142
+ hyper_in = torch.stack(hyper_in_list, dim=1)
143
+ b, c, h, w = upscaled_embedding.shape
144
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
145
+
146
+ # Generate mask quality predictions
147
+ iou_pred = self.iou_prediction_head(iou_token_out)
148
+
149
+ return masks, iou_pred
150
+
151
+
152
+ # Lightly adapted from
153
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
154
+ class MLP(nn.Module):
155
+ def __init__(
156
+ self,
157
+ input_dim: int,
158
+ hidden_dim: int,
159
+ output_dim: int,
160
+ num_layers: int,
161
+ sigmoid_output: bool = False,
162
+ ) -> None:
163
+ super().__init__()
164
+ self.num_layers = num_layers
165
+ h = [hidden_dim] * (num_layers - 1)
166
+ self.layers = nn.ModuleList(
167
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
168
+ )
169
+ self.sigmoid_output = sigmoid_output
170
+
171
+ def forward(self, x):
172
+ for i, layer in enumerate(self.layers):
173
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
174
+ if self.sigmoid_output:
175
+ x = F.sigmoid(x)
176
+ return x
model/networks.py ADDED
@@ -0,0 +1,563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.utils import spectral_norm as spectral_norm_fn
5
+ from torch.nn.utils import weight_norm as weight_norm_fn
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ from torchvision import utils as vutils
9
+
10
+ from utils.tools import extract_image_patches, flow_to_image, \
11
+ reduce_mean, reduce_sum, default_loader, same_padding
12
+
13
+
14
+ class Generator(nn.Module):
15
+ def __init__(self, config, use_cuda, device_ids):
16
+ super(Generator, self).__init__()
17
+ self.input_dim = config['input_dim']
18
+ self.cnum = config['ngf']
19
+ self.use_cuda = use_cuda
20
+ self.device_ids = device_ids
21
+
22
+ self.coarse_generator = CoarseGenerator(self.input_dim, self.cnum, self.use_cuda, self.device_ids)
23
+ self.fine_generator = FineGenerator(self.input_dim, self.cnum, self.use_cuda, self.device_ids)
24
+
25
+ def forward(self, x, mask):
26
+ x_stage1 = self.coarse_generator(x, mask)
27
+ x_stage2, offset_flow = self.fine_generator(x, x_stage1, mask)
28
+ return x_stage1, x_stage2, offset_flow
29
+
30
+
31
+ class CoarseGenerator(nn.Module):
32
+ def __init__(self, input_dim, cnum, use_cuda=True, device_ids=None):
33
+ super(CoarseGenerator, self).__init__()
34
+ self.use_cuda = use_cuda
35
+ self.device_ids = device_ids
36
+
37
+ self.conv1 = gen_conv(input_dim + 2, cnum, 5, 1, 2)
38
+ self.conv2_downsample = gen_conv(cnum, cnum*2, 3, 2, 1)
39
+ self.conv3 = gen_conv(cnum*2, cnum*2, 3, 1, 1)
40
+ self.conv4_downsample = gen_conv(cnum*2, cnum*4, 3, 2, 1)
41
+ self.conv5 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
42
+ self.conv6 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
43
+
44
+ self.conv7_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 2, rate=2)
45
+ self.conv8_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 4, rate=4)
46
+ self.conv9_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 8, rate=8)
47
+ self.conv10_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 16, rate=16)
48
+
49
+ self.conv11 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
50
+ self.conv12 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
51
+
52
+ self.conv13 = gen_conv(cnum*4, cnum*2, 3, 1, 1)
53
+ self.conv14 = gen_conv(cnum*2, cnum*2, 3, 1, 1)
54
+ self.conv15 = gen_conv(cnum*2, cnum, 3, 1, 1)
55
+ self.conv16 = gen_conv(cnum, cnum//2, 3, 1, 1)
56
+ self.conv17 = gen_conv(cnum//2, input_dim, 3, 1, 1, activation='none')
57
+
58
+ def forward(self, x, mask):
59
+ # For indicating the boundaries of images
60
+ ones = torch.ones(x.size(0), 1, x.size(2), x.size(3))
61
+ if self.use_cuda:
62
+ ones = ones.cuda()
63
+ mask = mask.cuda()
64
+ # 5 x 256 x 256
65
+ x = self.conv1(torch.cat([x, ones, mask], dim=1))
66
+ x = self.conv2_downsample(x)
67
+ # cnum*2 x 128 x 128
68
+ x = self.conv3(x)
69
+ x = self.conv4_downsample(x)
70
+ # cnum*4 x 64 x 64
71
+ x = self.conv5(x)
72
+ x = self.conv6(x)
73
+ x = self.conv7_atrous(x)
74
+ x = self.conv8_atrous(x)
75
+ x = self.conv9_atrous(x)
76
+ x = self.conv10_atrous(x)
77
+ x = self.conv11(x)
78
+ x = self.conv12(x)
79
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
80
+ # cnum*2 x 128 x 128
81
+ x = self.conv13(x)
82
+ x = self.conv14(x)
83
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
84
+ # cnum x 256 x 256
85
+ x = self.conv15(x)
86
+ x = self.conv16(x)
87
+ x = self.conv17(x)
88
+ # 3 x 256 x 256
89
+ x_stage1 = torch.clamp(x, -1., 1.)
90
+
91
+ return x_stage1
92
+
93
+
94
+ class FineGenerator(nn.Module):
95
+ def __init__(self, input_dim, cnum, use_cuda=True, device_ids=None):
96
+ super(FineGenerator, self).__init__()
97
+ self.use_cuda = use_cuda
98
+ self.device_ids = device_ids
99
+
100
+ # 3 x 256 x 256
101
+ self.conv1 = gen_conv(input_dim + 2, cnum, 5, 1, 2)
102
+ self.conv2_downsample = gen_conv(cnum, cnum, 3, 2, 1)
103
+ # cnum*2 x 128 x 128
104
+ self.conv3 = gen_conv(cnum, cnum*2, 3, 1, 1)
105
+ self.conv4_downsample = gen_conv(cnum*2, cnum*2, 3, 2, 1)
106
+ # cnum*4 x 64 x 64
107
+ self.conv5 = gen_conv(cnum*2, cnum*4, 3, 1, 1)
108
+ self.conv6 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
109
+
110
+ self.conv7_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 2, rate=2)
111
+ self.conv8_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 4, rate=4)
112
+ self.conv9_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 8, rate=8)
113
+ self.conv10_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 16, rate=16)
114
+
115
+ # attention branch
116
+ # 3 x 256 x 256
117
+ self.pmconv1 = gen_conv(input_dim + 2, cnum, 5, 1, 2)
118
+ self.pmconv2_downsample = gen_conv(cnum, cnum, 3, 2, 1)
119
+ # cnum*2 x 128 x 128
120
+ self.pmconv3 = gen_conv(cnum, cnum*2, 3, 1, 1)
121
+ self.pmconv4_downsample = gen_conv(cnum*2, cnum*4, 3, 2, 1)
122
+ # cnum*4 x 64 x 64
123
+ self.pmconv5 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
124
+ self.pmconv6 = gen_conv(cnum*4, cnum*4, 3, 1, 1, activation='relu')
125
+ self.contextul_attention = ContextualAttention(ksize=3, stride=1, rate=2, fuse_k=3, softmax_scale=10,
126
+ fuse=True, use_cuda=self.use_cuda, device_ids=self.device_ids)
127
+ self.pmconv9 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
128
+ self.pmconv10 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
129
+ self.allconv11 = gen_conv(cnum*8, cnum*4, 3, 1, 1)
130
+ self.allconv12 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
131
+ self.allconv13 = gen_conv(cnum*4, cnum*2, 3, 1, 1)
132
+ self.allconv14 = gen_conv(cnum*2, cnum*2, 3, 1, 1)
133
+ self.allconv15 = gen_conv(cnum*2, cnum, 3, 1, 1)
134
+ self.allconv16 = gen_conv(cnum, cnum//2, 3, 1, 1)
135
+ self.allconv17 = gen_conv(cnum//2, input_dim, 3, 1, 1, activation='none')
136
+
137
+ def forward(self, xin, x_stage1, mask):
138
+ x1_inpaint = x_stage1 * mask + xin * (1. - mask)
139
+ # For indicating the boundaries of images
140
+ ones = torch.ones(xin.size(0), 1, xin.size(2), xin.size(3))
141
+ if self.use_cuda:
142
+ ones = ones.cuda()
143
+ mask = mask.cuda()
144
+ # conv branch
145
+ xnow = torch.cat([x1_inpaint, ones, mask], dim=1)
146
+ x = self.conv1(xnow)
147
+ x = self.conv2_downsample(x)
148
+ x = self.conv3(x)
149
+ x = self.conv4_downsample(x)
150
+ x = self.conv5(x)
151
+ x = self.conv6(x)
152
+ x = self.conv7_atrous(x)
153
+ x = self.conv8_atrous(x)
154
+ x = self.conv9_atrous(x)
155
+ x = self.conv10_atrous(x)
156
+ x_hallu = x
157
+ # attention branch
158
+ x = self.pmconv1(xnow)
159
+ x = self.pmconv2_downsample(x)
160
+ x = self.pmconv3(x)
161
+ x = self.pmconv4_downsample(x)
162
+ x = self.pmconv5(x)
163
+ x = self.pmconv6(x)
164
+ x, offset_flow = self.contextul_attention(x, x, mask)
165
+ x = self.pmconv9(x)
166
+ x = self.pmconv10(x)
167
+ pm = x
168
+ x = torch.cat([x_hallu, pm], dim=1)
169
+ # merge two branches
170
+ x = self.allconv11(x)
171
+ x = self.allconv12(x)
172
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
173
+ x = self.allconv13(x)
174
+ x = self.allconv14(x)
175
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
176
+ x = self.allconv15(x)
177
+ x = self.allconv16(x)
178
+ x = self.allconv17(x)
179
+ x_stage2 = torch.clamp(x, -1., 1.)
180
+
181
+ return x_stage2, offset_flow
182
+
183
+
184
+ class ContextualAttention(nn.Module):
185
+ def __init__(self, ksize=3, stride=1, rate=1, fuse_k=3, softmax_scale=10,
186
+ fuse=False, use_cuda=False, device_ids=None):
187
+ super(ContextualAttention, self).__init__()
188
+ self.ksize = ksize
189
+ self.stride = stride
190
+ self.rate = rate
191
+ self.fuse_k = fuse_k
192
+ self.softmax_scale = softmax_scale
193
+ self.fuse = fuse
194
+ self.use_cuda = use_cuda
195
+ self.device_ids = device_ids
196
+
197
+ def forward(self, f, b, mask=None):
198
+ """ Contextual attention layer implementation.
199
+ Contextual attention is first introduced in publication:
200
+ Generative Image Inpainting with Contextual Attention, Yu et al.
201
+ Args:
202
+ f: Input feature to match (foreground).
203
+ b: Input feature for match (background).
204
+ mask: Input mask for b, indicating patches not available.
205
+ ksize: Kernel size for contextual attention.
206
+ stride: Stride for extracting patches from b.
207
+ rate: Dilation for matching.
208
+ softmax_scale: Scaled softmax for attention.
209
+ Returns:
210
+ torch.tensor: output
211
+ """
212
+ # get shapes
213
+ raw_int_fs = list(f.size()) # b*c*h*w
214
+ raw_int_bs = list(b.size()) # b*c*h*w
215
+
216
+ # extract patches from background with stride and rate
217
+ kernel = 2 * self.rate
218
+ # raw_w is extracted for reconstruction
219
+ raw_w = extract_image_patches(b, ksizes=[kernel, kernel],
220
+ strides=[self.rate*self.stride,
221
+ self.rate*self.stride],
222
+ rates=[1, 1],
223
+ padding='same') # [N, C*k*k, L]
224
+ # raw_shape: [N, C, k, k, L]
225
+ raw_w = raw_w.view(raw_int_bs[0], raw_int_bs[1], kernel, kernel, -1)
226
+ raw_w = raw_w.permute(0, 4, 1, 2, 3) # raw_shape: [N, L, C, k, k]
227
+ raw_w_groups = torch.split(raw_w, 1, dim=0)
228
+
229
+ # downscaling foreground option: downscaling both foreground and
230
+ # background for matching and use original background for reconstruction.
231
+ f = F.interpolate(f, scale_factor=1./self.rate, mode='nearest')
232
+ b = F.interpolate(b, scale_factor=1./self.rate, mode='nearest')
233
+ int_fs = list(f.size()) # b*c*h*w
234
+ int_bs = list(b.size())
235
+ f_groups = torch.split(f, 1, dim=0) # split tensors along the batch dimension
236
+ # w shape: [N, C*k*k, L]
237
+ w = extract_image_patches(b, ksizes=[self.ksize, self.ksize],
238
+ strides=[self.stride, self.stride],
239
+ rates=[1, 1],
240
+ padding='same')
241
+ # w shape: [N, C, k, k, L]
242
+ w = w.view(int_bs[0], int_bs[1], self.ksize, self.ksize, -1)
243
+ w = w.permute(0, 4, 1, 2, 3) # w shape: [N, L, C, k, k]
244
+ w_groups = torch.split(w, 1, dim=0)
245
+
246
+ # process mask
247
+ if mask is None:
248
+ mask = torch.zeros([int_bs[0], 1, int_bs[2], int_bs[3]])
249
+ if self.use_cuda:
250
+ mask = mask.cuda()
251
+ else:
252
+ mask = F.interpolate(mask, scale_factor=1./(4*self.rate), mode='nearest')
253
+ int_ms = list(mask.size())
254
+ # m shape: [N, C*k*k, L]
255
+ m = extract_image_patches(mask, ksizes=[self.ksize, self.ksize],
256
+ strides=[self.stride, self.stride],
257
+ rates=[1, 1],
258
+ padding='same')
259
+ # m shape: [N, C, k, k, L]
260
+ m = m.view(int_ms[0], int_ms[1], self.ksize, self.ksize, -1)
261
+ m = m.permute(0, 4, 1, 2, 3) # m shape: [N, L, C, k, k]
262
+ m = m[0] # m shape: [L, C, k, k]
263
+ # mm shape: [L, 1, 1, 1]
264
+ mm = (reduce_mean(m, axis=[1, 2, 3], keepdim=True)==0.).to(torch.float32)
265
+ mm = mm.permute(1, 0, 2, 3) # mm shape: [1, L, 1, 1]
266
+
267
+ y = []
268
+ offsets = []
269
+ k = self.fuse_k
270
+ scale = self.softmax_scale # to fit the PyTorch tensor image value range
271
+ fuse_weight = torch.eye(k).view(1, 1, k, k) # 1*1*k*k
272
+ if self.use_cuda:
273
+ fuse_weight = fuse_weight.cuda()
274
+
275
+ for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
276
+ '''
277
+ O => output channel as a conv filter
278
+ I => input channel as a conv filter
279
+ xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32)
280
+ wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3)
281
+ raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4)
282
+ '''
283
+ # conv for compare
284
+ escape_NaN = torch.FloatTensor([1e-4])
285
+ if self.use_cuda:
286
+ escape_NaN = escape_NaN.cuda()
287
+ wi = wi[0] # [L, C, k, k]
288
+ max_wi = torch.sqrt(reduce_sum(torch.pow(wi, 2) + escape_NaN, axis=[1, 2, 3], keepdim=True))
289
+ wi_normed = wi / max_wi
290
+ # xi shape: [1, C, H, W], yi shape: [1, L, H, W]
291
+ xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1]) # xi: 1*c*H*W
292
+ yi = F.conv2d(xi, wi_normed, stride=1) # [1, L, H, W]
293
+ # conv implementation for fuse scores to encourage large patches
294
+ if self.fuse:
295
+ # make all of depth to spatial resolution
296
+ yi = yi.view(1, 1, int_bs[2]*int_bs[3], int_fs[2]*int_fs[3]) # (B=1, I=1, H=32*32, W=32*32)
297
+ yi = same_padding(yi, [k, k], [1, 1], [1, 1])
298
+ yi = F.conv2d(yi, fuse_weight, stride=1) # (B=1, C=1, H=32*32, W=32*32)
299
+ yi = yi.contiguous().view(1, int_bs[2], int_bs[3], int_fs[2], int_fs[3]) # (B=1, 32, 32, 32, 32)
300
+ yi = yi.permute(0, 2, 1, 4, 3)
301
+ yi = yi.contiguous().view(1, 1, int_bs[2]*int_bs[3], int_fs[2]*int_fs[3])
302
+ yi = same_padding(yi, [k, k], [1, 1], [1, 1])
303
+ yi = F.conv2d(yi, fuse_weight, stride=1)
304
+ yi = yi.contiguous().view(1, int_bs[3], int_bs[2], int_fs[3], int_fs[2])
305
+ yi = yi.permute(0, 2, 1, 4, 3).contiguous()
306
+ yi = yi.view(1, int_bs[2] * int_bs[3], int_fs[2], int_fs[3]) # (B=1, C=32*32, H=32, W=32)
307
+ # softmax to match
308
+ yi = yi * mm
309
+ yi = F.softmax(yi*scale, dim=1)
310
+ yi = yi * mm # [1, L, H, W]
311
+
312
+ offset = torch.argmax(yi, dim=1, keepdim=True) # 1*1*H*W
313
+
314
+ if int_bs != int_fs:
315
+ # Normalize the offset value to match foreground dimension
316
+ times = float(int_fs[2] * int_fs[3]) / float(int_bs[2] * int_bs[3])
317
+ offset = ((offset + 1).float() * times - 1).to(torch.int64)
318
+ offset = torch.cat([offset//int_fs[3], offset%int_fs[3]], dim=1) # 1*2*H*W
319
+
320
+ # deconv for patch pasting
321
+ wi_center = raw_wi[0]
322
+ # yi = F.pad(yi, [0, 1, 0, 1]) # here may need conv_transpose same padding
323
+ yi = F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1) / 4. # (B=1, C=128, H=64, W=64)
324
+ y.append(yi)
325
+ offsets.append(offset)
326
+
327
+ y = torch.cat(y, dim=0) # back to the mini-batch
328
+ y.contiguous().view(raw_int_fs)
329
+
330
+ offsets = torch.cat(offsets, dim=0)
331
+ offsets = offsets.view(int_fs[0], 2, *int_fs[2:])
332
+
333
+ # case1: visualize optical flow: minus current position
334
+ h_add = torch.arange(int_fs[2]).view([1, 1, int_fs[2], 1]).expand(int_fs[0], -1, -1, int_fs[3])
335
+ w_add = torch.arange(int_fs[3]).view([1, 1, 1, int_fs[3]]).expand(int_fs[0], -1, int_fs[2], -1)
336
+ ref_coordinate = torch.cat([h_add, w_add], dim=1)
337
+ if self.use_cuda:
338
+ ref_coordinate = ref_coordinate.cuda()
339
+
340
+ offsets = offsets - ref_coordinate
341
+ # flow = pt_flow_to_image(offsets)
342
+
343
+ flow = torch.from_numpy(flow_to_image(offsets.permute(0, 2, 3, 1).cpu().data.numpy())) / 255.
344
+ flow = flow.permute(0, 3, 1, 2)
345
+ if self.use_cuda:
346
+ flow = flow.cuda()
347
+ # case2: visualize which pixels are attended
348
+ # flow = torch.from_numpy(highlight_flow((offsets * mask.long()).cpu().data.numpy()))
349
+
350
+ if self.rate != 1:
351
+ flow = F.interpolate(flow, scale_factor=self.rate*4, mode='nearest')
352
+
353
+ return y, flow
354
+
355
+
356
+ def test_contextual_attention(args):
357
+ import cv2
358
+ import os
359
+ # run on cpu
360
+ os.environ['CUDA_VISIBLE_DEVICES'] = '2'
361
+
362
+ def float_to_uint8(img):
363
+ img = img * 255
364
+ return img.astype('uint8')
365
+
366
+ rate = 2
367
+ stride = 1
368
+ grid = rate*stride
369
+
370
+ b = default_loader(args.imageA)
371
+ w, h = b.size
372
+ b = b.resize((w//grid*grid//2, h//grid*grid//2), Image.ANTIALIAS)
373
+ # b = b.resize((w//grid*grid, h//grid*grid), Image.ANTIALIAS)
374
+ print('Size of imageA: {}'.format(b.size))
375
+
376
+ f = default_loader(args.imageB)
377
+ w, h = f.size
378
+ f = f.resize((w//grid*grid, h//grid*grid), Image.ANTIALIAS)
379
+ print('Size of imageB: {}'.format(f.size))
380
+
381
+ f, b = transforms.ToTensor()(f), transforms.ToTensor()(b)
382
+ f, b = f.unsqueeze(0), b.unsqueeze(0)
383
+ if torch.cuda.is_available():
384
+ f, b = f.cuda(), b.cuda()
385
+
386
+ contextual_attention = ContextualAttention(ksize=3, stride=stride, rate=rate, fuse=True)
387
+
388
+ if torch.cuda.is_available():
389
+ contextual_attention = contextual_attention.cuda()
390
+
391
+ yt, flow_t = contextual_attention(f, b)
392
+ vutils.save_image(yt, 'vutils' + args.imageOut, normalize=True)
393
+ vutils.save_image(flow_t, 'flow' + args.imageOut, normalize=True)
394
+ # y = tensor_img_to_npimg(yt.cpu()[0])
395
+ # flow = tensor_img_to_npimg(flow_t.cpu()[0])
396
+ # cv2.imwrite('flow' + args.imageOut, flow_t)
397
+
398
+
399
+ class LocalDis(nn.Module):
400
+ def __init__(self, config, use_cuda=True, device_ids=None):
401
+ super(LocalDis, self).__init__()
402
+ self.input_dim = config['input_dim']
403
+ self.cnum = config['ndf']
404
+ self.use_cuda = use_cuda
405
+ self.device_ids = device_ids
406
+
407
+ self.dis_conv_module = DisConvModule(self.input_dim, self.cnum)
408
+ self.linear = nn.Linear(self.cnum*4*8*8, 1)
409
+
410
+ def forward(self, x):
411
+ x = self.dis_conv_module(x)
412
+ x = x.view(x.size()[0], -1)
413
+ x = self.linear(x)
414
+
415
+ return x
416
+
417
+
418
+ class GlobalDis(nn.Module):
419
+ def __init__(self, config, use_cuda=True, device_ids=None):
420
+ super(GlobalDis, self).__init__()
421
+ self.input_dim = config['input_dim']
422
+ self.cnum = config['ndf']
423
+ self.use_cuda = use_cuda
424
+ self.device_ids = device_ids
425
+
426
+ self.dis_conv_module = DisConvModule(self.input_dim, self.cnum)
427
+ self.linear = nn.Linear(self.cnum*4*16*16, 1)
428
+
429
+ def forward(self, x):
430
+ x = self.dis_conv_module(x)
431
+ x = x.view(x.size()[0], -1)
432
+ x = self.linear(x)
433
+
434
+ return x
435
+
436
+
437
+ class DisConvModule(nn.Module):
438
+ def __init__(self, input_dim, cnum, use_cuda=True, device_ids=None):
439
+ super(DisConvModule, self).__init__()
440
+ self.use_cuda = use_cuda
441
+ self.device_ids = device_ids
442
+
443
+ self.conv1 = dis_conv(input_dim, cnum, 5, 2, 2)
444
+ self.conv2 = dis_conv(cnum, cnum*2, 5, 2, 2)
445
+ self.conv3 = dis_conv(cnum*2, cnum*4, 5, 2, 2)
446
+ self.conv4 = dis_conv(cnum*4, cnum*4, 5, 2, 2)
447
+
448
+ def forward(self, x):
449
+ x = self.conv1(x)
450
+ x = self.conv2(x)
451
+ x = self.conv3(x)
452
+ x = self.conv4(x)
453
+
454
+ return x
455
+
456
+
457
+ def gen_conv(input_dim, output_dim, kernel_size=3, stride=1, padding=0, rate=1,
458
+ activation='elu'):
459
+ return Conv2dBlock(input_dim, output_dim, kernel_size, stride,
460
+ conv_padding=padding, dilation=rate,
461
+ activation=activation)
462
+
463
+
464
+ def dis_conv(input_dim, output_dim, kernel_size=5, stride=2, padding=0, rate=1,
465
+ activation='lrelu'):
466
+ return Conv2dBlock(input_dim, output_dim, kernel_size, stride,
467
+ conv_padding=padding, dilation=rate,
468
+ activation=activation)
469
+
470
+
471
+ class Conv2dBlock(nn.Module):
472
+ def __init__(self, input_dim, output_dim, kernel_size, stride, padding=0,
473
+ conv_padding=0, dilation=1, weight_norm='none', norm='none',
474
+ activation='relu', pad_type='zero', transpose=False):
475
+ super(Conv2dBlock, self).__init__()
476
+ self.use_bias = True
477
+ # initialize padding
478
+ if pad_type == 'reflect':
479
+ self.pad = nn.ReflectionPad2d(padding)
480
+ elif pad_type == 'replicate':
481
+ self.pad = nn.ReplicationPad2d(padding)
482
+ elif pad_type == 'zero':
483
+ self.pad = nn.ZeroPad2d(padding)
484
+ elif pad_type == 'none':
485
+ self.pad = None
486
+ else:
487
+ assert 0, "Unsupported padding type: {}".format(pad_type)
488
+
489
+ # initialize normalization
490
+ norm_dim = output_dim
491
+ if norm == 'bn':
492
+ self.norm = nn.BatchNorm2d(norm_dim)
493
+ elif norm == 'in':
494
+ self.norm = nn.InstanceNorm2d(norm_dim)
495
+ elif norm == 'none':
496
+ self.norm = None
497
+ else:
498
+ assert 0, "Unsupported normalization: {}".format(norm)
499
+
500
+ if weight_norm == 'sn':
501
+ self.weight_norm = spectral_norm_fn
502
+ elif weight_norm == 'wn':
503
+ self.weight_norm = weight_norm_fn
504
+ elif weight_norm == 'none':
505
+ self.weight_norm = None
506
+ else:
507
+ assert 0, "Unsupported normalization: {}".format(weight_norm)
508
+
509
+ # initialize activation
510
+ if activation == 'relu':
511
+ self.activation = nn.ReLU(inplace=True)
512
+ elif activation == 'elu':
513
+ self.activation = nn.ELU(inplace=True)
514
+ elif activation == 'lrelu':
515
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
516
+ elif activation == 'prelu':
517
+ self.activation = nn.PReLU()
518
+ elif activation == 'selu':
519
+ self.activation = nn.SELU(inplace=True)
520
+ elif activation == 'tanh':
521
+ self.activation = nn.Tanh()
522
+ elif activation == 'none':
523
+ self.activation = None
524
+ else:
525
+ assert 0, "Unsupported activation: {}".format(activation)
526
+
527
+ # initialize convolution
528
+ if transpose:
529
+ self.conv = nn.ConvTranspose2d(input_dim, output_dim,
530
+ kernel_size, stride,
531
+ padding=conv_padding,
532
+ output_padding=conv_padding,
533
+ dilation=dilation,
534
+ bias=self.use_bias)
535
+ else:
536
+ self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride,
537
+ padding=conv_padding, dilation=dilation,
538
+ bias=self.use_bias)
539
+
540
+ if self.weight_norm:
541
+ self.conv = self.weight_norm(self.conv)
542
+
543
+ def forward(self, x):
544
+ if self.pad:
545
+ x = self.conv(self.pad(x))
546
+ else:
547
+ x = self.conv(x)
548
+ if self.norm:
549
+ x = self.norm(x)
550
+ if self.activation:
551
+ x = self.activation(x)
552
+ return x
553
+
554
+
555
+
556
+ if __name__ == "__main__":
557
+ import argparse
558
+ parser = argparse.ArgumentParser()
559
+ parser.add_argument('--imageA', default='', type=str, help='Image A as background patches to reconstruct image B.')
560
+ parser.add_argument('--imageB', default='', type=str, help='Image B is reconstructed with image A.')
561
+ parser.add_argument('--imageOut', default='result.png', type=str, help='Image B is reconstructed with image A.')
562
+ args = parser.parse_args()
563
+ test_contextual_attention(args)
only_gradio_server.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import base64
3
+ import io
4
+ import uuid
5
+ from ultralytics import YOLO
6
+ import cv2
7
+ import torch
8
+ import numpy as np
9
+ from PIL import Image
10
+ from torchvision import transforms
11
+ import imageio.v2 as imageio
12
+ from trainer import Trainer
13
+ from utils.tools import get_config
14
+ import torch.nn.functional as F
15
+ from iopaint.single_processing import batch_inpaint_cv2
16
+ from pathlib import Path
17
+
18
+ # set current working directory cache instead of default
19
+ os.environ["TORCH_HOME"] = "./pretrained-model"
20
+ os.environ["HUGGINGFACE_HUB_CACHE"] = "./pretrained-model"
21
+
22
+ def resize_image(input_image_path, width=640, height=640):
23
+ """Resizes an image from image data and returns the resized image."""
24
+ try:
25
+ # Read the image using cv2.imread
26
+ img = cv2.imread(input_image_path, cv2.IMREAD_COLOR)
27
+
28
+ # Resize while maintaining the aspect ratio
29
+ shape = img.shape[:2] # current shape [height, width]
30
+ new_shape = (width, height) # the shape to resize to
31
+
32
+ # Scale ratio (new / old)
33
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
34
+ ratio = r, r # width, height ratios
35
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
36
+
37
+ # Resize the image
38
+ im = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
39
+
40
+ # Pad the image
41
+ color = (114, 114, 114) # color used for padding
42
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
43
+ # divide padding into 2 sides
44
+ dw /= 2
45
+ dh /= 2
46
+ # compute padding on all corners
47
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
48
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
49
+ im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
50
+ return im
51
+
52
+ except Exception as e:
53
+ print(f"Error resizing image: {e}")
54
+ return None # Or handle differently as needed
55
+
56
+
57
+ def load_weights(path, device):
58
+ model_weights = torch.load(path)
59
+ return {
60
+ k: v.to(device)
61
+ for k, v in model_weights.items()
62
+ }
63
+
64
+
65
+ # Function to convert image to base64
66
+ def convert_image_to_base64(image):
67
+ # Convert image to bytes
68
+ _, buffer = cv2.imencode('.png', image)
69
+ # Convert bytes to base64
70
+ image_base64 = base64.b64encode(buffer).decode('utf-8')
71
+ return image_base64
72
+
73
+
74
+ def convert_to_base64(image):
75
+ # Read the image file as binary data
76
+ image_data = image.read()
77
+ # Encode the binary data as base64
78
+ base64_encoded = base64.b64encode(image_data).decode('utf-8')
79
+ return base64_encoded
80
+
81
+ def convert_to_base64_file(image):
82
+ # Convert the image to binary data
83
+ image_data = cv2.imencode('.png', image)[1].tobytes()
84
+ # Encode the binary data as base64
85
+ base64_encoded = base64.b64encode(image_data).decode('utf-8')
86
+ return base64_encoded
87
+
88
+
89
+ def process_images(input_image, append_image, default_class="chair"):
90
+ # Static paths
91
+ config_path = Path('configs/config.yaml')
92
+ model_path = Path('pretrained-model/torch_model.p')
93
+
94
+ # Resize input image and get base64 data of resized image
95
+ img = resize_image(input_image)
96
+
97
+ if img is None:
98
+ return {'error': 'Failed to decode resized image'}, 419
99
+
100
+ H, W, _ = img.shape
101
+ x_point = 0
102
+ y_point = 0
103
+ width = 1
104
+ height = 1
105
+
106
+ # Load a model
107
+ model = YOLO('pretrained-model/yolov8m-seg.pt') # pretrained YOLOv8m-seg model
108
+
109
+ # Run batched inference on a list of images
110
+ results = model(img, imgsz=(W,H), conf=0.5) # chair class 56 with confidence >= 0.5
111
+ names = model.names
112
+
113
+ class_found = False
114
+ for result in results:
115
+ for i, label in enumerate(result.boxes.cls):
116
+ # Check if the label matches the chair label
117
+ if names[int(label)] == default_class:
118
+ class_found = True
119
+ # Convert the tensor to a numpy array
120
+ chair_mask_np = result.masks.data[i].numpy()
121
+
122
+ kernel = np.ones((5, 5), np.uint8) # Create a 5x5 kernel for dilation
123
+ chair_mask_np = cv2.dilate(chair_mask_np, kernel, iterations=2) # Apply dilation
124
+
125
+ # Find contours to get bounding box
126
+ contours, _ = cv2.findContours((chair_mask_np == 1).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
127
+
128
+ # Iterate over contours to find the bounding box of each object
129
+ for contour in contours:
130
+ x, y, w, h = cv2.boundingRect(contour)
131
+ x_point = x
132
+ y_point = y
133
+ width = w
134
+ height = h
135
+
136
+ # Get the corresponding mask
137
+ mask = result.masks.data[i].numpy() * 255
138
+ dilated_mask = cv2.dilate(mask, kernel, iterations=2) # Apply dilation
139
+ # Resize the mask to match the dimensions of the original image
140
+ resized_mask = cv2.resize(dilated_mask, (img.shape[1], img.shape[0]))
141
+
142
+ # call repainting and merge function
143
+ output_base64 = repaitingAndMerge(append_image,str(model_path), str(config_path),width, height, x_point, y_point, img, resized_mask)
144
+ # Return the output base64 image in the API response
145
+ return output_base64
146
+
147
+ # return class not found in prediction
148
+ if not class_found:
149
+ return {'message': f'{default_class} object not found in the image'}, 200
150
+
151
+ def repaitingAndMerge(append_image_path, model_path, config_path, width, height, xposition, yposition, input_base, mask_base):
152
+ config = get_config(config_path)
153
+ device = torch.device("cpu")
154
+ trainer = Trainer(config)
155
+ trainer.load_state_dict(load_weights(model_path, device), strict=False)
156
+ trainer.eval()
157
+
158
+ # lama inpainting start
159
+ print("lama inpainting start")
160
+ inpaint_result_np = batch_inpaint_cv2('lama', 'cpu', input_base, mask_base)
161
+ print("lama inpainting end")
162
+
163
+ # Create PIL Image from NumPy array
164
+ final_image = Image.fromarray(inpaint_result_np)
165
+
166
+ print("merge start")
167
+
168
+ # Load the append image using cv2.imread
169
+ append_image = cv2.imread(append_image_path, cv2.IMREAD_UNCHANGED)
170
+ cv2.imwrite('appneded-image.png',append_image)
171
+ # Resize the append image while preserving transparency
172
+ resized_image = cv2.resize(append_image, (width, height), interpolation=cv2.INTER_AREA)
173
+ # Convert the resized image to RGBA format (assuming it's in BGRA format)
174
+ resized_image = cv2.cvtColor(resized_image, cv2.COLOR_BGRA2RGBA)
175
+ # Create a PIL Image from the resized image with transparent background
176
+ append_image_pil = Image.fromarray(resized_image)
177
+
178
+ # Paste the append image onto the final image
179
+ final_image.paste(append_image_pil, (xposition, yposition), append_image_pil)
180
+ # Save the resulting image
181
+ print("merge end")
182
+
183
+ # Convert the final image to base64
184
+ with io.BytesIO() as output_buffer:
185
+ final_image.save(output_buffer, format='PNG')
186
+ output_numpy = np.array(final_image)
187
+
188
+ return output_numpy