RysonFeng committed on
Commit cdb26a4 · 1 Parent(s): a9b35cc

Add source code

LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
__init__.py ADDED
File without changes
fill_anything.py ADDED
@@ -0,0 +1,128 @@
+ import cv2
+ import sys
+ import argparse
+ import numpy as np
+ import torch
+ from pathlib import Path
+ from matplotlib import pyplot as plt
+ from typing import Any, Dict, List
+
+ from sam_segment import predict_masks_with_sam
+ from stable_diffusion_inpaint import fill_img_with_sd
+ from utils import load_img_to_array, save_array_to_img, dilate_mask, \
+     show_mask, show_points
+
+
+ def setup_args(parser):
+     parser.add_argument(
+         "--input_img", type=str, required=True,
+         help="Path to a single input img",
+     )
+     parser.add_argument(
+         "--point_coords", type=float, nargs='+', required=True,
+         help="The coordinate of the point prompt, [coord_W coord_H].",
+     )
+     parser.add_argument(
+         "--point_labels", type=int, nargs='+', required=True,
+         help="The labels of the point prompt, 1 or 0.",
+     )
+     parser.add_argument(
+         "--text_prompt", type=str, required=True,
+         help="Text prompt",
+     )
+     parser.add_argument(
+         "--dilate_kernel_size", type=int, default=None,
+         help="Dilate kernel size. Default: None",
+     )
+     parser.add_argument(
+         "--output_dir", type=str, required=True,
+         help="Output path to the directory with results.",
+     )
+     parser.add_argument(
+         "--sam_model_type", type=str,
+         default="vit_h", choices=['vit_h', 'vit_l', 'vit_b'],
+         help="The type of sam model to load. Default: 'vit_h'"
+     )
+     parser.add_argument(
+         "--sam_ckpt", type=str, required=True,
+         help="The path to the SAM checkpoint to use for mask generation.",
+     )
+     parser.add_argument(
+         "--seed", type=int,
+         help="Specify seed for reproducibility.",
+     )
+     parser.add_argument(
+         "--deterministic", action="store_true",
+         help="Use deterministic algorithms for reproducibility.",
+     )
+
+
+
+ if __name__ == "__main__":
+     """Example usage:
+     python fill_anything.py \
+         --input_img FA_demo/FA1_dog.png \
+         --point_coords 750 500 \
+         --point_labels 1 \
+         --text_prompt "a teddy bear on a bench" \
+         --dilate_kernel_size 15 \
+         --output_dir ./results \
+         --sam_model_type "vit_h" \
+         --sam_ckpt sam_vit_h_4b8939.pth
+     """
+     parser = argparse.ArgumentParser()
+     setup_args(parser)
+     args = parser.parse_args(sys.argv[1:])
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     img = load_img_to_array(args.input_img)
+
+     masks, _, _ = predict_masks_with_sam(
+         img,
+         [args.point_coords],
+         args.point_labels,
+         model_type=args.sam_model_type,
+         ckpt_p=args.sam_ckpt,
+         device=device,
+     )
+     masks = masks.astype(np.uint8) * 255
+
+     # dilate mask to avoid unmasked edge effect
+     if args.dilate_kernel_size is not None:
+         masks = [dilate_mask(mask, args.dilate_kernel_size) for mask in masks]
+
+     # visualize the segmentation results
+     img_stem = Path(args.input_img).stem
+     out_dir = Path(args.output_dir) / img_stem
+     out_dir.mkdir(parents=True, exist_ok=True)
+     for idx, mask in enumerate(masks):
+         # path to the results
+         mask_p = out_dir / f"mask_{idx}.png"
+         img_points_p = out_dir / f"with_points.png"
+         img_mask_p = out_dir / f"with_{Path(mask_p).name}"
+
+         # save the mask
+         save_array_to_img(mask, mask_p)
+
+         # save the pointed and masked image
+         dpi = plt.rcParams['figure.dpi']
+         height, width = img.shape[:2]
+         plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
+         plt.imshow(img)
+         plt.axis('off')
+         show_points(plt.gca(), [args.point_coords], args.point_labels,
+                     size=(width*0.04)**2)
+         plt.savefig(img_points_p, bbox_inches='tight', pad_inches=0)
+         show_mask(plt.gca(), mask, random_color=False)
+         plt.savefig(img_mask_p, bbox_inches='tight', pad_inches=0)
+         plt.close()
+
+     # fill the masked image
+     for idx, mask in enumerate(masks):
+         if args.seed is not None:
+             torch.manual_seed(args.seed)
+         mask_p = out_dir / f"mask_{idx}.png"
+         img_filled_p = out_dir / f"filled_with_{Path(mask_p).name}"
+         img_filled = fill_img_with_sd(
+             img, mask, args.text_prompt, device=device)
+         save_array_to_img(img_filled, img_filled_p)
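
Note on the point prompt: fill_anything.py wraps `args.point_coords` in a single outer list, which suits one click given as two numbers. The block below is a minimal sketch, not part of this commit, of how a flat `nargs='+'` list could be regrouped into (x, y) pairs if several points were passed. The helper name `pair_point_prompts` is hypothetical, and it assumes that label 1 marks a foreground point and 0 a background point.

```python
import numpy as np


def pair_point_prompts(flat_coords, labels):
    """Hypothetical helper: regroup a flat [x0, y0, x1, y1, ...] list
    into an (N, 2) array and check it against N point labels."""
    coords = np.asarray(flat_coords, dtype=np.float32)
    if coords.size == 0 or coords.size % 2 != 0:
        raise ValueError("point coordinates must come in (x, y) pairs")
    coords = coords.reshape(-1, 2)
    labels = np.asarray(labels, dtype=np.int64)
    if labels.shape != (coords.shape[0],):
        raise ValueError("need exactly one label per point")
    if not np.isin(labels, (0, 1)).all():
        raise ValueError("labels must be 1 (foreground) or 0 (background)")
    return coords, labels


# Two clicks: one foreground point and one background point.
coords, labels = pair_point_prompts([750, 500, 100, 120], [1, 0])
print(coords.shape, labels.tolist())
```

With a helper like this, the coordinate array could be passed on directly instead of being wrapped in an extra list.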
lama_inpaint.py ADDED
@@ -0,0 +1,134 @@
+ import os
+ import sys
+ import numpy as np
+ import torch
+ import yaml
+ import glob
+ import argparse
+ from PIL import Image
+ from omegaconf import OmegaConf
+ from pathlib import Path
+
+ os.environ['OMP_NUM_THREADS'] = '1'
+ os.environ['OPENBLAS_NUM_THREADS'] = '1'
+ os.environ['MKL_NUM_THREADS'] = '1'
+ os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
+ os.environ['NUMEXPR_NUM_THREADS'] = '1'
+
+ sys.path.insert(0, str(Path(__file__).resolve().parent / "third_party" / "lama"))
+ from saicinpainting.evaluation.utils import move_to_device
+ from saicinpainting.training.trainers import load_checkpoint
+ from saicinpainting.evaluation.data import pad_tensor_to_modulo
+
+ from utils import load_img_to_array, save_array_to_img
+
+
+ @torch.no_grad()
+ def inpaint_img_with_lama(
+         img: np.ndarray,
+         mask: np.ndarray,
+         config_p: str,
+         ckpt_p: str,
+         mod=8,
+         device="cuda"
+ ):
+     assert len(mask.shape) == 2
+     if np.max(mask) == 1:
+         mask = mask * 255
+     img = torch.from_numpy(img).float().div(255.)
+     mask = torch.from_numpy(mask).float()
+     predict_config = OmegaConf.load(config_p)
+     predict_config.model.path = ckpt_p
+     # device = torch.device(predict_config.device)
+     device = torch.device(device)
+
+     train_config_path = os.path.join(
+         predict_config.model.path, 'config.yaml')
+
+     with open(train_config_path, 'r') as f:
+         train_config = OmegaConf.create(yaml.safe_load(f))
+
+     train_config.training_model.predict_only = True
+     train_config.visualizer.kind = 'noop'
+
+     checkpoint_path = os.path.join(
+         predict_config.model.path, 'models',
+         predict_config.model.checkpoint
+     )
+     model = load_checkpoint(
+         train_config, checkpoint_path, strict=False, map_location=device)
+     model.freeze()
+     if not predict_config.get('refine', False):
+         model.to(device)
+
+     batch = {}
+     batch['image'] = img.permute(2, 0, 1).unsqueeze(0)
+     batch['mask'] = mask[None, None]
+     unpad_to_size = [batch['image'].shape[2], batch['image'].shape[3]]
+     batch['image'] = pad_tensor_to_modulo(batch['image'], mod)
+     batch['mask'] = pad_tensor_to_modulo(batch['mask'], mod)
+     batch = move_to_device(batch, device)
+     batch['mask'] = (batch['mask'] > 0) * 1
+
+     batch = model(batch)
+     cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
+     cur_res = cur_res.detach().cpu().numpy()
+
+     if unpad_to_size is not None:
+         orig_height, orig_width = unpad_to_size
+         cur_res = cur_res[:orig_height, :orig_width]
+
+     cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
+     return cur_res
+
+ def setup_args(parser):
+     parser.add_argument(
+         "--input_img", type=str, required=True,
+         help="Path to a single input img",
+     )
+     parser.add_argument(
+         "--input_mask_glob", type=str, required=True,
+         help="Glob to input masks",
+     )
+     parser.add_argument(
+         "--output_dir", type=str, required=True,
+         help="Output path to the directory with results.",
+     )
+     parser.add_argument(
+         "--lama_config", type=str,
+         default="./third_party/lama/configs/prediction/default.yaml",
+         help="The path to the config file of lama model. "
+              "Default: the config of big-lama",
+     )
+     parser.add_argument(
+         "--lama_ckpt", type=str, required=True,
+         help="The path to the lama checkpoint.",
+     )
+
+
+ if __name__ == "__main__":
+     """Example usage:
+     python lama_inpaint.py \
+         --input_img FA_demo/FA1_dog.png \
+         --input_mask_glob "results/FA1_dog/mask*.png" \
+         --output_dir results \
+         --lama_config lama/configs/prediction/default.yaml \
+         --lama_ckpt big-lama
+     """
+     parser = argparse.ArgumentParser()
+     setup_args(parser)
+     args = parser.parse_args(sys.argv[1:])
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     img_stem = Path(args.input_img).stem
+     mask_ps = sorted(glob.glob(args.input_mask_glob))
+     out_dir = Path(args.output_dir) / img_stem
+     out_dir.mkdir(parents=True, exist_ok=True)
+
+     img = load_img_to_array(args.input_img)
+     for mask_p in mask_ps:
+         mask = load_img_to_array(mask_p)
+         img_inpainted_p = out_dir / f"inpainted_with_{Path(mask_p).name}"
+         img_inpainted = inpaint_img_with_lama(
+             img, mask, args.lama_config, args.lama_ckpt, device=device)
+         save_array_to_img(img_inpainted, img_inpainted_p)
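
Note on padding: `inpaint_img_with_lama` pads the image and mask so that height and width are multiples of `mod`, runs the model, and then crops the result back to `unpad_to_size`. The NumPy sketch below illustrates only that pad-then-crop round trip. It is not LaMa's `pad_tensor_to_modulo`, whose source is not part of this commit; the names `ceil_to_multiple` and `pad_to_modulo` and the choice of symmetric padding are assumptions made for illustration.

```python
import numpy as np


def ceil_to_multiple(x, mod):
    """Smallest multiple of mod that is >= x."""
    return -(-x // mod) * mod


def pad_to_modulo(arr, mod):
    """Pad a (C, H, W) array on the bottom and right so that
    H and W become multiples of mod (assumed: symmetric padding)."""
    _, h, w = arr.shape
    pad_h = ceil_to_multiple(h, mod) - h
    pad_w = ceil_to_multiple(w, mod) - w
    return np.pad(arr, ((0, 0), (0, pad_h), (0, pad_w)), mode="symmetric")


img = np.random.rand(3, 10, 13).astype(np.float32)
padded = pad_to_modulo(img, mod=8)

# Cropping to the original height and width undoes the padding,
# mirroring the unpad_to_size step in the function above.
restored = padded[:, :10, :13]
print(padded.shape)
```

Because the padding is added only on the bottom and right, a plain top-left crop is enough to recover the original extent.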
remove_anything.py ADDED
@@ -0,0 +1,122 @@
+ import torch
+ import sys
+ import argparse
+ import numpy as np
+ from pathlib import Path
+ from matplotlib import pyplot as plt
+
+ from sam_segment import predict_masks_with_sam
+ from lama_inpaint import inpaint_img_with_lama
+ from utils import load_img_to_array, save_array_to_img, dilate_mask, \
+     show_mask, show_points
+
+
+ def setup_args(parser):
+     parser.add_argument(
+         "--input_img", type=str, required=True,
+         help="Path to a single input img",
+     )
+     parser.add_argument(
+         "--point_coords", type=float, nargs='+', required=True,
+         help="The coordinate of the point prompt, [coord_W coord_H].",
+     )
+     parser.add_argument(
+         "--point_labels", type=int, nargs='+', required=True,
+         help="The labels of the point prompt, 1 or 0.",
+     )
+     parser.add_argument(
+         "--dilate_kernel_size", type=int, default=None,
+         help="Dilate kernel size. Default: None",
+     )
+     parser.add_argument(
+         "--output_dir", type=str, required=True,
+         help="Output path to the directory with results.",
+     )
+     parser.add_argument(
+         "--sam_model_type", type=str,
+         default="vit_h", choices=['vit_h', 'vit_l', 'vit_b'],
+         help="The type of sam model to load. Default: 'vit_h'"
+     )
+     parser.add_argument(
+         "--sam_ckpt", type=str, required=True,
+         help="The path to the SAM checkpoint to use for mask generation.",
+     )
+     parser.add_argument(
+         "--lama_config", type=str,
+         default="./lama/configs/prediction/default.yaml",
+         help="The path to the config file of lama model. "
+              "Default: the config of big-lama",
+     )
+     parser.add_argument(
+         "--lama_ckpt", type=str, required=True,
+         help="The path to the lama checkpoint.",
+     )
+
+
+ if __name__ == "__main__":
+     """Example usage:
+     python remove_anything.py \
+         --input_img FA_demo/FA1_dog.png \
+         --point_coords 750 500 \
+         --point_labels 1 \
+         --dilate_kernel_size 15 \
+         --output_dir ./results \
+         --sam_model_type "vit_h" \
+         --sam_ckpt sam_vit_h_4b8939.pth \
+         --lama_config lama/configs/prediction/default.yaml \
+         --lama_ckpt big-lama
+     """
+     parser = argparse.ArgumentParser()
+     setup_args(parser)
+     args = parser.parse_args(sys.argv[1:])
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     img = load_img_to_array(args.input_img)
+
+     masks, _, _ = predict_masks_with_sam(
+         img,
+         [args.point_coords],
+         args.point_labels,
+         model_type=args.sam_model_type,
+         ckpt_p=args.sam_ckpt,
+         device=device,
+     )
+     masks = masks.astype(np.uint8) * 255
+
+     # dilate mask to avoid unmasked edge effect
+     if args.dilate_kernel_size is not None:
+         masks = [dilate_mask(mask, args.dilate_kernel_size) for mask in masks]
+
+     # visualize the segmentation results
+     img_stem = Path(args.input_img).stem
+     out_dir = Path(args.output_dir) / img_stem
+     out_dir.mkdir(parents=True, exist_ok=True)
+     for idx, mask in enumerate(masks):
+         # path to the results
+         mask_p = out_dir / f"mask_{idx}.png"
+         img_points_p = out_dir / f"with_points.png"
+         img_mask_p = out_dir / f"with_{Path(mask_p).name}"
+
+         # save the mask
+         save_array_to_img(mask, mask_p)
+
+         # save the pointed and masked image
+         dpi = plt.rcParams['figure.dpi']
+         height, width = img.shape[:2]
+         plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
+         plt.imshow(img)
+         plt.axis('off')
+         show_points(plt.gca(), [args.point_coords], args.point_labels,
+                     size=(width*0.04)**2)
+         plt.savefig(img_points_p, bbox_inches='tight', pad_inches=0)
+         show_mask(plt.gca(), mask, random_color=False)
+         plt.savefig(img_mask_p, bbox_inches='tight', pad_inches=0)
+         plt.close()
+
+     # inpaint the masked image
+     for idx, mask in enumerate(masks):
+         mask_p = out_dir / f"mask_{idx}.png"
+         img_inpainted_p = out_dir / f"inpainted_with_{Path(mask_p).name}"
+         img_inpainted = inpaint_img_with_lama(
+             img, mask, args.lama_config, args.lama_ckpt, device=device)
+         save_array_to_img(img_inpainted, img_inpainted_p)
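
Note on mask dilation: the scripts call `dilate_mask` from `utils`, which is not shown in this part of the commit, to grow each mask before inpainting so that object edges are covered. The sketch below shows roughly what a square-kernel dilation does, using NumPy only. `dilate_mask_np` is a hypothetical stand-in and may differ from the real helper; the restriction to odd kernel sizes is an assumption of this sketch.

```python
import numpy as np


def dilate_mask_np(mask, kernel_size):
    """Dilate a 2-D uint8 mask with a kernel_size x kernel_size square.

    Each output pixel is the maximum of the input pixels in the square
    window centred on it, so foreground regions grow by kernel_size // 2.
    """
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError("this sketch assumes a positive odd kernel size")
    pad = kernel_size // 2
    padded = np.pad(mask, pad, mode="constant", constant_values=0)
    h, w = mask.shape
    out = np.zeros_like(mask)
    # Take the running maximum over every shifted view of the padded mask.
    for dy in range(kernel_size):
        for dx in range(kernel_size):
            out = np.maximum(out, padded[dy:dy + h, dx:dx + w])
    return out


mask = np.zeros((11, 11), dtype=np.uint8)
mask[5, 5] = 255
dilated = dilate_mask_np(mask, 3)
print(int((dilated == 255).sum()))
```

A larger kernel, such as the 15 used in the example commands, widens the margin around the object at the cost of inpainting more of the surrounding background.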
replace_anything.py ADDED
@@ -0,0 +1,126 @@
+ import cv2
+ import sys
+ import argparse
+ import numpy as np
+ import torch
+ from pathlib import Path
+ from matplotlib import pyplot as plt
+ from typing import Any, Dict, List
+ from sam_segment import predict_masks_with_sam
+ from stable_diffusion_inpaint import replace_img_with_sd
+ from utils import load_img_to_array, save_array_to_img, dilate_mask, \
+     show_mask, show_points
+
+
+ def setup_args(parser):
+     parser.add_argument(
+         "--input_img", type=str, required=True,
+         help="Path to a single input img",
+     )
+     parser.add_argument(
+         "--point_coords", type=float, nargs='+', required=True,
+         help="The coordinate of the point prompt, [coord_W coord_H].",
+     )
+     parser.add_argument(
+         "--point_labels", type=int, nargs='+', required=True,
+         help="The labels of the point prompt, 1 or 0.",
+     )
+     parser.add_argument(
+         "--text_prompt", type=str, required=True,
+         help="Text prompt",
+     )
+     parser.add_argument(
+         "--dilate_kernel_size", type=int, default=None,
+         help="Dilate kernel size. Default: None",
+     )
+     parser.add_argument(
+         "--output_dir", type=str, required=True,
+         help="Output path to the directory with results.",
+     )
+     parser.add_argument(
+         "--sam_model_type", type=str,
+         default="vit_h", choices=['vit_h', 'vit_l', 'vit_b'],
+         help="The type of sam model to load. Default: 'vit_h'"
+     )
+     parser.add_argument(
+         "--sam_ckpt", type=str, required=True,
+         help="The path to the SAM checkpoint to use for mask generation.",
+     )
+     parser.add_argument(
+         "--seed", type=int,
+         help="Specify seed for reproducibility.",
+     )
+     parser.add_argument(
+         "--deterministic", action="store_true",
+         help="Use deterministic algorithms for reproducibility.",
+     )
+
+
+
+ if __name__ == "__main__":
+     """Example usage:
+     python replace_anything.py \
+         --input_img FA_demo/FA1_dog.png \
+         --point_coords 750 500 \
+         --point_labels 1 \
+         --text_prompt "sit on the swing" \
+         --output_dir ./results \
+         --sam_model_type "vit_h" \
+         --sam_ckpt sam_vit_h_4b8939.pth
+     """
+     parser = argparse.ArgumentParser()
+     setup_args(parser)
+     args = parser.parse_args(sys.argv[1:])
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     img = load_img_to_array(args.input_img)
+
+     masks, _, _ = predict_masks_with_sam(
+         img,
+         [args.point_coords],
+         args.point_labels,
+         model_type=args.sam_model_type,
+         ckpt_p=args.sam_ckpt,
+         device=device,
+     )
+     masks = masks.astype(np.uint8) * 255
+
+     # dilate mask to avoid unmasked edge effect
+     if args.dilate_kernel_size is not None:
+         masks = [dilate_mask(mask, args.dilate_kernel_size) for mask in masks]
+
+     # visualize the segmentation results
+     img_stem = Path(args.input_img).stem
+     out_dir = Path(args.output_dir) / img_stem
+     out_dir.mkdir(parents=True, exist_ok=True)
+     for idx, mask in enumerate(masks):
+         # path to the results
+         mask_p = out_dir / f"mask_{idx}.png"
+         img_points_p = out_dir / f"with_points.png"
+         img_mask_p = out_dir / f"with_{Path(mask_p).name}"
+
+         # save the mask
+         save_array_to_img(mask, mask_p)
+
+         # save the pointed and masked image
+         dpi = plt.rcParams['figure.dpi']
+         height, width = img.shape[:2]
+         plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
+         plt.imshow(img)
+         plt.axis('off')
+         show_points(plt.gca(), [args.point_coords], args.point_labels,
+                     size=(width*0.04)**2)
+         plt.savefig(img_points_p, bbox_inches='tight', pad_inches=0)
+         show_mask(plt.gca(), mask, random_color=False)
+         plt.savefig(img_mask_p, bbox_inches='tight', pad_inches=0)
+         plt.close()
+
+     # fill the masked image
+     for idx, mask in enumerate(masks):
+         if args.seed is not None:
+             torch.manual_seed(args.seed)
+         mask_p = out_dir / f"mask_{idx}.png"
+         img_replaced_p = out_dir / f"replaced_with_{Path(mask_p).name}"
+         img_replaced = replace_img_with_sd(
+             img, mask, args.text_prompt, device=device)
+         save_array_to_img(img_replaced, img_replaced_p)
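
Note on mask selection: the scripts in this commit process every mask returned by `predict_masks_with_sam` and discard the accompanying scores. If only one result is wanted, the scores could be used to pick a single mask before the expensive inpainting step. The sketch below is one possible way to do that; `select_best_mask` is a hypothetical helper, and the arrays are stand-ins rather than real SAM output.

```python
import numpy as np


def select_best_mask(masks, scores):
    """Return the mask with the highest score, together with its index."""
    masks = np.asarray(masks)
    scores = np.asarray(scores)
    if masks.shape[0] != scores.shape[0]:
        raise ValueError("need exactly one score per mask")
    best = int(np.argmax(scores))
    return masks[best], best


# Stand-in data: three boolean 4x4 masks with made-up scores.
masks = np.zeros((3, 4, 4), dtype=bool)
masks[1, 1:3, 1:3] = True
scores = np.array([0.71, 0.93, 0.88])

best_mask, best_idx = select_best_mask(masks, scores)
# Same conversion the scripts apply before saving or inpainting.
best_mask_u8 = best_mask.astype(np.uint8) * 255
```

Keeping all masks, as the scripts do, lets the user compare candidates by eye; selecting by score trades that choice for a single, faster run.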
sam_segment.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import sys
+ import argparse
+ import numpy as np
+ from pathlib import Path
+ from matplotlib import pyplot as plt
+ from typing import Any, Dict, List
+ import torch
+
+ from segment_anything import SamPredictor, sam_model_registry
+ from utils import load_img_to_array, save_array_to_img, dilate_mask, \
+     show_mask, show_points
+
+
+ def predict_masks_with_sam(
+         img: np.ndarray,
+         point_coords: List[List[float]],
+         point_labels: List[int],
+         model_type: str,
+         ckpt_p: str,
+         device="cuda"
+ ):
+     point_coords = np.array(point_coords)
+     point_labels = np.array(point_labels)
+     sam = sam_model_registry[model_type](checkpoint=ckpt_p)
+     sam.to(device=device)
+     predictor = SamPredictor(sam)
+
+     predictor.set_image(img)
+     masks, scores, logits = predictor.predict(
+         point_coords=point_coords,
+         point_labels=point_labels,
+         multimask_output=True,
+     )
+     return masks, scores, logits
+
+
+ def setup_args(parser):
+     parser.add_argument(
+         "--input_img", type=str, required=True,
+         help="Path to a single input img",
+     )
+     parser.add_argument(
+         "--point_coords", type=float, nargs='+', required=True,
+         help="The coordinate of the point prompt, [coord_W coord_H].",
+     )
+     parser.add_argument(
+         "--point_labels", type=int, nargs='+', required=True,
+         help="The labels of the point prompt, 1 or 0.",
+     )
+     parser.add_argument(
+         "--dilate_kernel_size", type=int, default=None,
+         help="Dilate kernel size. Default: None",
+     )
+     parser.add_argument(
+         "--output_dir", type=str, required=True,
+         help="Output path to the directory with results.",
+     )
+     parser.add_argument(
+         "--sam_model_type", type=str,
+         default="vit_h", choices=['vit_h', 'vit_l', 'vit_b'],
+         help="The type of sam model to load. Default: 'vit_h'"
+     )
+     parser.add_argument(
+         "--sam_ckpt", type=str, required=True,
+         help="The path to the SAM checkpoint to use for mask generation.",
+     )
+
+
+ if __name__ == "__main__":
+     """Example usage:
+     python sam_segment.py \
+         --input_img FA_demo/FA1_dog.png \
+         --point_coords 750 500 \
+         --point_labels 1 \
+         --dilate_kernel_size 15 \
+         --output_dir ./results \
+         --sam_model_type "vit_h" \
+         --sam_ckpt sam_vit_h_4b8939.pth
+     """
+     parser = argparse.ArgumentParser()
+     setup_args(parser)
+     args = parser.parse_args(sys.argv[1:])
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     img = load_img_to_array(args.input_img)
+
+     masks, _, _ = predict_masks_with_sam(
+         img,
+         [args.point_coords],
+         args.point_labels,
+         model_type=args.sam_model_type,
+         ckpt_p=args.sam_ckpt,
+         device=device,
+     )
+     masks = masks.astype(np.uint8) * 255
+
+     # dilate mask to avoid unmasked edge effect
+     if args.dilate_kernel_size is not None:
+         masks = [dilate_mask(mask, args.dilate_kernel_size) for mask in masks]
+
+     # visualize the segmentation results
+     img_stem = Path(args.input_img).stem
+     out_dir = Path(args.output_dir) / img_stem
+     out_dir.mkdir(parents=True, exist_ok=True)
+     for idx, mask in enumerate(masks):
+         # path to the results
+         mask_p = out_dir / f"mask_{idx}.png"
+         img_points_p = out_dir / "with_points.png"
+         img_mask_p = out_dir / f"with_{Path(mask_p).name}"
+
+         # save the mask
+         save_array_to_img(mask, mask_p)
+
+         # save the pointed and masked image
+         dpi = plt.rcParams['figure.dpi']
+         height, width = img.shape[:2]
+         plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
+         plt.imshow(img)
+         plt.axis('off')
+         show_points(plt.gca(), [args.point_coords], args.point_labels,
+                     size=(width*0.04)**2)
+         plt.savefig(img_points_p, bbox_inches='tight', pad_inches=0)
+         show_mask(plt.gca(), mask, random_color=False)
+         plt.savefig(img_mask_p, bbox_inches='tight', pad_inches=0)
+         plt.close()
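The script in sam_segment.py wraps `args.point_coords` in a single-element list, so each run takes exactly one point prompt. As a rough sketch only, this is one way a flat `x0 y0 x1 y1 ...` argument list could be grouped into pairs if several points were wanted. `pair_point_coords` is a hypothetical helper and is not part of this commit.

```python
from typing import List, Sequence


def pair_point_coords(flat: Sequence[float]) -> List[List[float]]:
    """Group a flat [x0, y0, x1, y1, ...] sequence into [[x0, y0], [x1, y1], ...]."""
    if len(flat) == 0 or len(flat) % 2 != 0:
        raise ValueError("expected a non-empty, even number of coordinates")
    # Step through the list two values at a time: (x, y) per point.
    return [[float(flat[i]), float(flat[i + 1])] for i in range(0, len(flat), 2)]
```

With such a helper the number of entries in `--point_labels` would have to match the number of resulting pairs; the script as committed does not check this.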
stable_diffusion_inpaint.py ADDED
@@ -0,0 +1,117 @@
+ import os
+ import sys
+ import glob
+ import argparse
+ import torch
+ import numpy as np
+ import PIL.Image as Image
+ from pathlib import Path
+ from diffusers import StableDiffusionInpaintPipeline
+ from utils.mask_processing import crop_for_filling_pre, crop_for_filling_post
+ from utils.crop_for_replacing import recover_size, resize_and_pad
+ from utils import load_img_to_array, save_array_to_img
+
+
+ def fill_img_with_sd(
+         img: np.ndarray,
+         mask: np.ndarray,
+         text_prompt: str,
+         device="cuda"
+ ):
+     pipe = StableDiffusionInpaintPipeline.from_pretrained(
+         "stabilityai/stable-diffusion-2-inpainting",
+         torch_dtype=torch.float32,
+     ).to(device)
+     img_crop, mask_crop = crop_for_filling_pre(img, mask)
+     img_crop_filled = pipe(
+         prompt=text_prompt,
+         image=Image.fromarray(img_crop),
+         mask_image=Image.fromarray(mask_crop)
+     ).images[0]
+     img_filled = crop_for_filling_post(img, mask, np.array(img_crop_filled))
+     return img_filled
+
+
+ def replace_img_with_sd(
+         img: np.ndarray,
+         mask: np.ndarray,
+         text_prompt: str,
+         step: int = 50,
+         device="cuda"
+ ):
+     pipe = StableDiffusionInpaintPipeline.from_pretrained(
+         "stabilityai/stable-diffusion-2-inpainting",
+         torch_dtype=torch.float32,
+     ).to(device)
+     img_padded, mask_padded, padding_factors = resize_and_pad(img, mask)
+     img_padded = pipe(
+         prompt=text_prompt,
+         image=Image.fromarray(img_padded),
+         mask_image=Image.fromarray(255 - mask_padded),
+         num_inference_steps=step,
+     ).images[0]
+     height, width, _ = img.shape
+     img_resized, mask_resized = recover_size(
+         np.array(img_padded), mask_padded, (height, width), padding_factors)
+     mask_resized = np.expand_dims(mask_resized, -1) / 255
+     img_resized = img_resized * (1-mask_resized) + img * mask_resized
+     return img_resized
+
+
+ def setup_args(parser):
+     parser.add_argument(
+         "--input_img", type=str, required=True,
+         help="Path to a single input img",
+     )
+     parser.add_argument(
+         "--text_prompt", type=str, required=True,
+         help="Text prompt",
+     )
+     parser.add_argument(
+         "--input_mask_glob", type=str, required=True,
+         help="Glob to input masks",
+     )
+     parser.add_argument(
+         "--output_dir", type=str, required=True,
+         help="Output path to the directory with results.",
+     )
+     parser.add_argument(
+         "--seed", type=int,
+         help="Specify seed for reproducibility.",
+     )
+     parser.add_argument(
+         "--deterministic", action="store_true",
+         help="Use deterministic algorithms for reproducibility.",
+     )
+
+ if __name__ == "__main__":
+     """Example usage:
+     python stable_diffusion_inpaint.py \
+         --input_img FA_demo/FA1_dog.png \
+         --input_mask_glob "results/FA1_dog/mask*.png" \
+         --text_prompt "a teddy bear on a bench" \
+         --output_dir results
+     """
+     parser = argparse.ArgumentParser()
+     setup_args(parser)
+     args = parser.parse_args(sys.argv[1:])
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     if args.deterministic:
+         os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+         torch.use_deterministic_algorithms(True)
+
+     img_stem = Path(args.input_img).stem
+     mask_ps = sorted(glob.glob(args.input_mask_glob))
+     out_dir = Path(args.output_dir) / img_stem
+     out_dir.mkdir(parents=True, exist_ok=True)
+
+     img = load_img_to_array(args.input_img)
+     for mask_p in mask_ps:
+         if args.seed is not None:
+             torch.manual_seed(args.seed)
+         mask = load_img_to_array(mask_p)
+         img_filled_p = out_dir / f"filled_with_{Path(mask_p).name}"
+         img_filled = fill_img_with_sd(
+             img, mask, args.text_prompt, device=device)
+         save_array_to_img(img_filled, img_filled_p)
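The last step of `replace_img_with_sd` composites the generated image with the original. Pixels where the mask is 255 keep the original object, and every other pixel takes the generated background. The NumPy sketch below shows that blend on toy arrays. `blend_keep_masked` is an illustrative name, not a function in this commit, and the tiny arrays stand in for real images.

```python
import numpy as np


def blend_keep_masked(generated: np.ndarray, original: np.ndarray,
                      mask: np.ndarray) -> np.ndarray:
    """Keep `original` where mask == 255 and `generated` elsewhere.

    generated, original: (H, W, 3) arrays; mask: (H, W) with values in [0, 255].
    """
    weight = np.expand_dims(mask, -1) / 255  # (H, W, 1), values in [0, 1]
    return generated * (1 - weight) + original * weight


original = np.full((2, 2, 3), 10, dtype=np.uint8)
generated = np.full((2, 2, 3), 200, dtype=np.uint8)
mask = np.array([[255, 0], [0, 0]], dtype=np.uint8)
blended = blend_keep_masked(generated, original, mask)
```

The division by 255 promotes the result to floating point, which is why `save_array_to_img` casts back to `uint8` before writing.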
utils/__init__.py ADDED
@@ -0,0 +1 @@
+ from .utils import *
utils/crop_for_replacing.py ADDED
@@ -0,0 +1,101 @@
+ import cv2
+ import numpy as np
+ from typing import Tuple
+
+ def resize_and_pad(image: np.ndarray, mask: np.ndarray, target_size: int = 512) -> Tuple[np.ndarray, np.ndarray, Tuple[int, int, int, int]]:
+     """
+     Resizes an image and its corresponding mask so that the longer side equals `target_size`, then pads both
+     to a square. The resulting image and mask have dimensions (target_size, target_size).
+
+     Args:
+         image: A numpy array representing the image to resize and pad.
+         mask: A numpy array representing the mask to resize and pad.
+         target_size: An integer specifying the desired size of the longer side after resizing.
+
+     Returns:
+         A tuple of the padded image, the padded mask, and the padding (top_pad, bottom_pad, left_pad, right_pad).
+     """
+     height, width, _ = image.shape
+     max_dim = max(height, width)
+     scale = target_size / max_dim
+     new_height = int(height * scale)
+     new_width = int(width * scale)
+     image_resized = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
+     mask_resized = cv2.resize(mask, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
+     pad_height = target_size - new_height
+     pad_width = target_size - new_width
+     top_pad = pad_height // 2
+     bottom_pad = pad_height - top_pad
+     left_pad = pad_width // 2
+     right_pad = pad_width - left_pad
+     image_padded = np.pad(image_resized, ((top_pad, bottom_pad), (left_pad, right_pad), (0, 0)), mode='constant')
+     mask_padded = np.pad(mask_resized, ((top_pad, bottom_pad), (left_pad, right_pad)), mode='constant')
+     return image_padded, mask_padded, (top_pad, bottom_pad, left_pad, right_pad)
+
+ def recover_size(image_padded: np.ndarray, mask_padded: np.ndarray, orig_size: Tuple[int, int],
+                  padding_factors: Tuple[int, int, int, int]) -> Tuple[np.ndarray, np.ndarray]:
+     """
+     Removes the padding from a padded and resized image and mask, then resizes both to the original size.
+
+     Args:
+         image_padded: A numpy array representing the padded and resized image.
+         mask_padded: A numpy array representing the padded and resized mask.
+         orig_size: A tuple containing two integers - the original height and width of the image before resizing and padding.
+         padding_factors: The (top_pad, bottom_pad, left_pad, right_pad) tuple returned by `resize_and_pad`.
+
+     Returns:
+         A tuple containing two numpy arrays - the recovered image and the recovered mask with dimensions `orig_size`.
+     """
+     h,w,c = image_padded.shape
+     top_pad, bottom_pad, left_pad, right_pad = padding_factors
+     image = image_padded[top_pad:h-bottom_pad, left_pad:w-right_pad, :]
+     mask = mask_padded[top_pad:h-bottom_pad, left_pad:w-right_pad]
+     image_resized = cv2.resize(image, orig_size[::-1], interpolation=cv2.INTER_LINEAR)
+     mask_resized = cv2.resize(mask, orig_size[::-1], interpolation=cv2.INTER_LINEAR)
+     return image_resized, mask_resized
+
+
+
+ if __name__ == '__main__':
+
+     # image = cv2.imread('example/boat.jpg')
+     # mask = cv2.imread('example/boat_mask_2.png', cv2.IMREAD_GRAYSCALE)
+     # image = cv2.imread('example/groceries.jpg')
+     # mask = cv2.imread('example/groceries_mask_2.png', cv2.IMREAD_GRAYSCALE)
+     # image = cv2.imread('example/bridge.jpg')
+     # mask = cv2.imread('example/bridge_mask_2.png', cv2.IMREAD_GRAYSCALE)
+     # image = cv2.imread('example/person_umbrella.jpg')
+     # mask = cv2.imread('example/person_umbrella_mask_2.png', cv2.IMREAD_GRAYSCALE)
+     # image = cv2.imread('example/hippopotamus.jpg')
+     # mask = cv2.imread('example/hippopotamus_mask_1.png', cv2.IMREAD_GRAYSCALE)
+     image = cv2.imread('/data1/yutao/projects/IAM/Inpaint-Anything/example/fill-anything/sample5.jpeg')
+     mask = cv2.imread('/data1/yutao/projects/IAM/Inpaint-Anything/example/fill-anything/sample5/mask.png', cv2.IMREAD_GRAYSCALE)
+     print(image.shape)
+     print(mask.shape)
+     cv2.imwrite('original_image.jpg', image)
+     cv2.imwrite('original_mask.jpg', mask)
+     image_padded, mask_padded, padding_factors = resize_and_pad(image, mask)
+     cv2.imwrite('padded_image.png', image_padded)
+     cv2.imwrite('padded_mask.png', mask_padded)
+     print(image_padded.shape, mask_padded.shape, padding_factors)
+
+     # ^ ------------------------------------------------------------------------------------
+     # ^ Please conduct inpainting or filling here on the cropped image with the cropped mask
+     # ^ ------------------------------------------------------------------------------------
+
+     # resize and pad the image and mask
+
+     # perform some operation on the 512x512 image and mask
+     # ...
+
+     # recover the image and mask to the original size
+     height, width, _ = image.shape
+     image_resized, mask_resized = recover_size(image_padded, mask_padded, (height, width), padding_factors)
+
+     # save the resized and recovered image and mask
+     cv2.imwrite('resized_and_padded_image.png', image_padded)
+     cv2.imwrite('resized_and_padded_mask.png', mask_padded)
+     cv2.imwrite('recovered_image.png', image_resized)
+     cv2.imwrite('recovered_mask.png', mask_resized)
+
+
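The padding arithmetic in `resize_and_pad` can be checked without OpenCV. The sketch below reproduces only the size and padding computation in plain Python. `compute_padding` is a hypothetical helper introduced for illustration; it is not in this commit and does not touch any pixels.

```python
from typing import Tuple


def compute_padding(height: int, width: int,
                    target_size: int = 512) -> Tuple[int, int, Tuple[int, int, int, int]]:
    """Return (new_height, new_width, (top, bottom, left, right)) for a square canvas."""
    # Scale so that the longer side becomes target_size.
    scale = target_size / max(height, width)
    new_height = int(height * scale)
    new_width = int(width * scale)
    # Split the leftover space as evenly as possible; any odd pixel goes to bottom/right.
    pad_height = target_size - new_height
    pad_width = target_size - new_width
    top = pad_height // 2
    left = pad_width // 2
    return new_height, new_width, (top, pad_height - top, left, pad_width - left)
```

For a 256x512 (height x width) input this gives 128 rows of padding above and below and none on the sides, which is the tuple `recover_size` later uses to crop the padding away.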
utils/get_point_coor.py ADDED
@@ -0,0 +1,12 @@
+ import cv2
+
+ def click_event(event, x, y, flags, param):
+     if event == cv2.EVENT_LBUTTONDOWN:
+         print("Point coordinates ({}, {})".format(x, y))
+ img = cv2.imread("./example/remove-anything/dog.jpg")
+
+ cv2.imshow("Image", img)
+ cv2.setMouseCallback("Image", click_event)
+ cv2.waitKey(0)
+
+ cv2.destroyAllWindows()
utils/mask_processing.py ADDED
@@ -0,0 +1,160 @@
+ import cv2
+ from matplotlib import pyplot as plt
+ import PIL.Image as Image
+ import numpy as np
+
+
+ def crop_for_filling_pre(image: np.ndarray, mask: np.ndarray, crop_size: int = 512):
+     # Calculate the aspect ratio of the image
+     height, width = image.shape[:2]
+     aspect_ratio = float(width) / float(height)
+
+     # If the shorter side is less than 512, resize the image proportionally
+     if min(height, width) < crop_size:
+         if height < width:
+             new_height = crop_size
+             new_width = int(new_height * aspect_ratio)
+         else:
+             new_width = crop_size
+             new_height = int(new_width / aspect_ratio)
+
+         image = cv2.resize(image, (new_width, new_height))
+         mask = cv2.resize(mask, (new_width, new_height))
+
+     # Find the bounding box of the mask
+     x, y, w, h = cv2.boundingRect(mask)
+
+     # Update the height and width of the resized image
+     height, width = image.shape[:2]
+
+     # If the 512x512 square cannot cover the entire mask, resize the image accordingly
+     if w > crop_size or h > crop_size:
+         # padding to square at first
+         if height < width:
+             padding = width - height
+             image = np.pad(image, ((padding // 2, padding - padding // 2), (0, 0), (0, 0)), 'constant')
+             mask = np.pad(mask, ((padding // 2, padding - padding // 2), (0, 0)), 'constant')
+         else:
+             padding = height - width
+             image = np.pad(image, ((0, 0), (padding // 2, padding - padding // 2), (0, 0)), 'constant')
+             mask = np.pad(mask, ((0, 0), (padding // 2, padding - padding // 2)), 'constant')
+
+         resize_factor = crop_size / max(w, h)
+         image = cv2.resize(image, (0, 0), fx=resize_factor, fy=resize_factor)
+         mask = cv2.resize(mask, (0, 0), fx=resize_factor, fy=resize_factor)
+         x, y, w, h = cv2.boundingRect(mask)
+         height, width = image.shape[:2]  # keep the crop clamp below in sync with the resized image
+
+     # Calculate the crop coordinates
+     crop_x = min(max(x + w // 2 - crop_size // 2, 0), width - crop_size)
+     crop_y = min(max(y + h // 2 - crop_size // 2, 0), height - crop_size)
+
+     # Crop the image
+     cropped_image = image[crop_y:crop_y + crop_size, crop_x:crop_x + crop_size]
+     cropped_mask = mask[crop_y:crop_y + crop_size, crop_x:crop_x + crop_size]
+
+     return cropped_image, cropped_mask
+
+
+ def crop_for_filling_post(
+         image: np.ndarray,
+         mask: np.ndarray,
+         filled_image: np.ndarray,
+         crop_size: int = 512,
+ ):
+     image_copy = image.copy()
+     mask_copy = mask.copy()
+     # Calculate the aspect ratio of the image
+     height, width = image.shape[:2]
+     height_ori, width_ori = height, width
+     aspect_ratio = float(width) / float(height)
+
+     # If the shorter side is less than 512, resize the image proportionally
+     if min(height, width) < crop_size:
+         if height < width:
+             new_height = crop_size
+             new_width = int(new_height * aspect_ratio)
+         else:
+             new_width = crop_size
+             new_height = int(new_width / aspect_ratio)
+
+         image = cv2.resize(image, (new_width, new_height))
+         mask = cv2.resize(mask, (new_width, new_height))
+
+     # Find the bounding box of the mask
+     x, y, w, h = cv2.boundingRect(mask)
+
+     # Update the height and width of the resized image
+     height, width = image.shape[:2]
+
+     # If the 512x512 square cannot cover the entire mask, resize the image accordingly
+     if w > crop_size or h > crop_size:
+         flag_padding = True
+         # padding to square at first
+         if height < width:
+             padding = width - height
+             image = np.pad(image, ((padding // 2, padding - padding // 2), (0, 0), (0, 0)), 'constant')
+             mask = np.pad(mask, ((padding // 2, padding - padding // 2), (0, 0)), 'constant')
+             padding_side = 'h'
+         else:
+             padding = height - width
+             image = np.pad(image, ((0, 0), (padding // 2, padding - padding // 2), (0, 0)), 'constant')
+             mask = np.pad(mask, ((0, 0), (padding // 2, padding - padding // 2)), 'constant')
+             padding_side = 'w'
+
+         resize_factor = crop_size / max(w, h)
+         image = cv2.resize(image, (0, 0), fx=resize_factor, fy=resize_factor)
+         mask = cv2.resize(mask, (0, 0), fx=resize_factor, fy=resize_factor)
+         x, y, w, h = cv2.boundingRect(mask)
+         height, width = image.shape[:2]  # must mirror crop_for_filling_pre so both pick the same crop
+     else:
+         flag_padding = False
+
+     # Calculate the crop coordinates
+     crop_x = min(max(x + w // 2 - crop_size // 2, 0), width - crop_size)
+     crop_y = min(max(y + h // 2 - crop_size // 2, 0), height - crop_size)
+
+     # Fill the image
+     image[crop_y:crop_y + crop_size, crop_x:crop_x + crop_size] = filled_image
+     if flag_padding:
+         image = cv2.resize(image, (0, 0), fx=1/resize_factor, fy=1/resize_factor)
+         if padding_side == 'h':
+             image = image[padding // 2:padding // 2 + height_ori, :]
+         else:
+             image = image[:, padding // 2:padding // 2 + width_ori]
+
+     image = cv2.resize(image, (width_ori, height_ori))
+
+     image_copy[mask_copy==255] = image[mask_copy==255]
+     return image_copy
+
+
+ if __name__ == '__main__':
+
+     # image = cv2.imread('example/boat.jpg')
+     # mask = cv2.imread('example/boat_mask_2.png', cv2.IMREAD_GRAYSCALE)
+     image = cv2.imread('./example/groceries.jpg')
+     mask = cv2.imread('example/groceries_mask_2.png', cv2.IMREAD_GRAYSCALE)
+     # image = cv2.imread('example/bridge.jpg')
+     # mask = cv2.imread('example/bridge_mask_2.png', cv2.IMREAD_GRAYSCALE)
+     # image = cv2.imread('example/person_umbrella.jpg')
+     # mask = cv2.imread('example/person_umbrella_mask_2.png', cv2.IMREAD_GRAYSCALE)
+     # image = cv2.imread('example/hippopotamus.jpg')
+     # mask = cv2.imread('example/hippopotamus_mask_1.png', cv2.IMREAD_GRAYSCALE)
+
+     cropped_image, cropped_mask = crop_for_filling_pre(image, mask)
+     # ^ ------------------------------------------------------------------------------------
+     # ^ Please conduct inpainting or filling here on the cropped image with the cropped mask
+     # ^ ------------------------------------------------------------------------------------
+
+     # e.g.
+     # cropped_image[cropped_mask==255] = 0
+     cv2.imwrite('cropped_image.jpg', cropped_image)
+     cv2.imwrite('cropped_mask.jpg', cropped_mask)
+     print(cropped_image.shape)
+     print(cropped_mask.shape)
+
+     image = crop_for_filling_post(image, mask, cropped_image)
+     cv2.imwrite('filled_image.jpg', image)
+     print(image.shape)
+
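Both `crop_for_filling_pre` and `crop_for_filling_post` place a `crop_size` window centred on the mask's bounding box and clamp it so it stays inside the image. The one-dimensional sketch below isolates that clamp. `clamp_crop_origin` is an illustrative name that does not appear in this commit, and it assumes `image_len >= crop_size`.

```python
def clamp_crop_origin(box_start: int, box_len: int, image_len: int,
                      crop_size: int = 512) -> int:
    """Center a crop window on a bounding-box span, clamped to [0, image_len - crop_size]."""
    # Window origin that would put the box centre at the window centre.
    centered = box_start + box_len // 2 - crop_size // 2
    # Clamp so the window neither starts before 0 nor runs past image_len.
    return min(max(centered, 0), image_len - crop_size)
```

The same function applies to the x axis (with `x`, `w`, `width`) and the y axis (with `y`, `h`, `height`). The clamp is only correct when `image_len` is the length of the array actually being sliced.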
utils/paste_object.py ADDED
@@ -0,0 +1,50 @@
+ import cv2
+ import numpy as np
+
+ def paste_object(source, source_mask, target, target_coords, resize_scale=1):
+     assert target_coords[0] < target.shape[1] and target_coords[1] < target.shape[0]
+     # Find the bounding box of the source_mask
+     x, y, w, h = cv2.boundingRect(source_mask)
+     assert h < source.shape[0] and w < source.shape[1]
+     obj = source[y:y+h, x:x+w]
+     obj_msk = source_mask[y:y+h, x:x+w]
+     if resize_scale != 1:
+         obj = cv2.resize(obj, (0,0), fx=resize_scale, fy=resize_scale)
+         obj_msk = cv2.resize(obj_msk, (0,0), fx=resize_scale, fy=resize_scale)
+         _, _, w, h = cv2.boundingRect(obj_msk)
+
+     xt = max(0, target_coords[0]-w//2)
+     yt = max(0, target_coords[1]-h//2)
+     if target_coords[0]-w//2 < 0:
+         obj = obj[:, w//2-target_coords[0]:]
+         obj_msk = obj_msk[:, w//2-target_coords[0]:]
+     if target_coords[0]+w//2 > target.shape[1]:
+         obj = obj[:, :target.shape[1]-target_coords[0]+w//2]
+         obj_msk = obj_msk[:, :target.shape[1]-target_coords[0]+w//2]
+     if target_coords[1]-h//2 < 0:
+         obj = obj[h//2-target_coords[1]:, :]
+         obj_msk = obj_msk[h//2-target_coords[1]:, :]
+     if target_coords[1]+h//2 > target.shape[0]:
+         obj = obj[:target.shape[0]-target_coords[1]+h//2, :]
+         obj_msk = obj_msk[:target.shape[0]-target_coords[1]+h//2, :]
+     _, _, w, h = cv2.boundingRect(obj_msk)
+
+     target[yt:yt+h, xt:xt+w][obj_msk==255] = obj[obj_msk==255]
+     target_mask = np.zeros_like(target)
+     target_mask = cv2.cvtColor(target_mask, cv2.COLOR_BGR2GRAY)
+     target_mask[yt:yt+h, xt:xt+w][obj_msk==255] = 255
+
+     return target, target_mask
+
+ if __name__ == '__main__':
+     source = cv2.imread('example/boat.jpg')
+     source_mask = cv2.imread('example/boat_mask_1.png', 0)
+     target = cv2.imread('example/hippopotamus.jpg')
+     print(source.shape, source_mask.shape, target.shape)
+
+     target_coords = (700, 400)  # (x, y)
+     resize_scale = 1
+     target, target_mask = paste_object(source, source_mask, target, target_coords, resize_scale)
+     cv2.imwrite('target_pasted.png', target)
+     cv2.imwrite('target_mask.png', target_mask)
+     print(target.shape, target_mask.shape)
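The paste itself in `paste_object` is a boolean-mask assignment into a slice of `target`. It works in place because basic slicing returns a view rather than a copy. Here is a toy NumPy sketch of just that step, with made-up sizes and coordinates:

```python
import numpy as np

target = np.zeros((4, 4, 3), dtype=np.uint8)
obj = np.full((2, 2, 3), 9, dtype=np.uint8)
obj_msk = np.array([[255, 0], [255, 255]], dtype=np.uint8)

yt, xt = 1, 1  # top-left corner of the paste region (illustrative values)
h, w = obj_msk.shape

# Basic slicing returns a view, so the masked assignment writes into `target`.
target[yt:yt + h, xt:xt + w][obj_msk == 255] = obj[obj_msk == 255]
```

Only pixels where `obj_msk` equals 255 are copied, so the rest of the pasted rectangle keeps the target's own content.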
utils/utils.py ADDED
@@ -0,0 +1,53 @@
+ import cv2
+ import numpy as np
+ from PIL import Image
+ from typing import Any, Dict, List
+
+
+ def load_img_to_array(img_p):
+     return np.array(Image.open(img_p))
+
+
+ def save_array_to_img(img_arr, img_p):
+     Image.fromarray(img_arr.astype(np.uint8)).save(img_p)
+
+
+ def dilate_mask(mask, dilate_factor=15):
+     mask = mask.astype(np.uint8)
+     mask = cv2.dilate(
+         mask,
+         np.ones((dilate_factor, dilate_factor), np.uint8),
+         iterations=1
+     )
+     return mask
+
+ def erode_mask(mask, dilate_factor=15):
+     mask = mask.astype(np.uint8)
+     mask = cv2.erode(
+         mask,
+         np.ones((dilate_factor, dilate_factor), np.uint8),
+         iterations=1
+     )
+     return mask
+
+ def show_mask(ax, mask: np.ndarray, random_color=False):
+     mask = mask.astype(np.uint8)
+     if np.max(mask) == 255:
+         mask = mask / 255
+     if random_color:
+         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+     else:
+         color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
+     h, w = mask.shape[-2:]
+     mask_img = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+     ax.imshow(mask_img)
+
+
+ def show_points(ax, coords: List[List[float]], labels: List[int], size=375):
+     coords = np.array(coords)
+     labels = np.array(labels)
+     color_table = {0: 'red', 1: 'green'}
+     for label_value, color in color_table.items():
+         points = coords[labels == label_value]
+         ax.scatter(points[:, 0], points[:, 1], color=color, marker='*',
+                    s=size, edgecolor='white', linewidth=1.25)
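`show_mask` builds its overlay by broadcasting a binary mask against an RGBA colour, giving an (H, W, 4) image that is transparent wherever the mask is 0. This sketch repeats that computation on a 2x2 toy mask and leaves out the matplotlib call:

```python
import numpy as np

mask = np.array([[0, 255], [255, 0]], dtype=np.uint8)
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])  # RGBA, as in show_mask

# Normalise a 0/255 mask to 0/1, mirroring the check in show_mask.
mask01 = mask / 255 if mask.max() == 255 else mask
h, w = mask01.shape[-2:]
overlay = mask01.reshape(h, w, 1) * color.reshape(1, 1, -1)  # shape (H, W, 4)
```

Because the alpha channel is multiplied by the mask too, background pixels end up fully transparent and the underlying image stays visible when the overlay is drawn on top.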
utils/visual_mask_on_img.py ADDED
@@ -0,0 +1,62 @@
+ import cv2
+ import sys
+ import argparse
+ import numpy as np
+ from PIL import Image
+ from pathlib import Path
+ from matplotlib import pyplot as plt
+ from typing import Any, Dict, List
+ import glob
+
+ from utils import load_img_to_array, show_mask
+
+
+ def setup_args(parser):
+     parser.add_argument(
+         "--input_img", type=str, required=True,
+         help="Path to a single input img",
+     )
+     parser.add_argument(
+         "--input_mask_glob", type=str, required=True,
+         help="Glob to input masks",
+     )
+     parser.add_argument(
+         "--output_dir", type=str, required=True,
+         help="Output path to the directory with results.",
+     )
+
+ if __name__ == "__main__":
+     """Example usage:
+     python visual_mask_on_img.py \
+         --input_img FA_demo/FA1_dog.png \
+         --input_mask_glob "results/FA1_dog/mask*.png" \
+         --output_dir results
+     """
+     parser = argparse.ArgumentParser()
+     setup_args(parser)
+     args = parser.parse_args(sys.argv[1:])
+
+     img = load_img_to_array(args.input_img)
+     img_stem = Path(args.input_img).stem
+
+     mask_ps = sorted(glob.glob(args.input_mask_glob))
+
+     out_dir = Path(args.output_dir) / img_stem
+     out_dir.mkdir(parents=True, exist_ok=True)
+
+     for mask_p in mask_ps:
+         mask = load_img_to_array(mask_p)
+         mask = mask.astype(np.uint8)
+
+         # path to the results
+         img_mask_p = out_dir / f"with_{Path(mask_p).name}"
+
+         # save the masked image
+         dpi = plt.rcParams['figure.dpi']
+         height, width = img.shape[:2]
+         plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
+         plt.imshow(img)
+         plt.axis('off')
+         show_mask(plt.gca(), mask, random_color=False)
+         plt.savefig(img_mask_p, bbox_inches='tight', pad_inches=0)
+         plt.close()