Spaces:
Running
Running
import numpy as np | |
import cv2 | |
from basicsr.utils import img2tensor | |
import torch | |
import torch.nn.functional as F | |
def resize_numpy_image(image, max_resolution=768 * 768, resize_short_edge=None):
    """Resize ``image`` so that both sides are multiples of 64.

    Args:
        image: Array whose first two dimensions are (height, width), as read
            via ``image.shape[:2]``.
        max_resolution: Target pixel count (height * width), used when
            ``resize_short_edge`` is None. The scale factor is
            ``sqrt(max_resolution / (h * w))``.
        resize_short_edge: If given, scale so the shorter side is
            approximately this length (before snapping to a multiple of 64).

    Returns:
        ``(resized_image, scale)``. ``resized_image`` comes from
        ``cv2.resize`` with Lanczos interpolation. ``scale`` is
        ``new_width / original_width``. Each side is rounded independently,
        so the height ratio can differ slightly from ``scale``.
    """
    h_org, w_org = image.shape[:2]
    if resize_short_edge is not None:
        k = resize_short_edge / min(h_org, w_org)
    else:
        k = (max_resolution / (h_org * w_org)) ** 0.5
    # Snap each side to the nearest multiple of 64. Keep at least one multiple
    # so that a very thin image cannot collapse to a 0-sized side, which
    # cv2.resize rejects.
    h = max(1, int(np.round(h_org * k / 64))) * 64
    w = max(1, int(np.round(w_org * k / 64))) * 64
    # cv2.resize takes dsize as (width, height).
    image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
    scale = w / w_org
    return image, scale
def split_ldm(ldm):
    """Separate a sequence of points into x and y coordinate lists.

    Only the first two entries of each point are read; any extra entries are
    ignored. ``ldm`` is consumed exactly once, so iterators are accepted.

    Returns:
        ``(xs, ys)``: two lists holding the first and second entries of each
        point, in input order.
    """
    pairs = [(point[0], point[1]) for point in ldm]
    xs = [pair[0] for pair in pairs]
    ys = [pair[1] for pair in pairs]
    return xs, ys
def process_move(path_mask, h, w, dx, dy, scale, input_scale, resize_scale, up_scale, up_ft_index, w_edit, w_content, w_contrast, w_inpaint, precision, path_mask_ref=None):
    """Build the region masks and guidance settings for an object-move edit.

    The object mask is loaded, resized, binarised at 0.5 and moved to CUDA.
    It is then downsampled by ``scale``, rescaled by ``resize_scale`` about
    the canvas centre and shifted by ``(dx, dy)``.

    Args:
        path_mask: Path to the object mask image, or an already-loaded image
            array accepted by ``cv2.resize``.
        h, w: Passed to ``cv2.resize`` as ``dsize=(h, w)``. OpenCV reads dsize
            as (width, height), so ``h`` becomes the output width.
            NOTE(review): confirm callers pass the values in that order.
        dx, dy: Displacement; multiplied by ``input_scale`` before use.
        scale: Factor by which masks are downsampled (``shape // scale``).
        input_scale: Factor applied to ``dx`` and ``dy``.
        resize_scale: Factor by which the moved object is enlarged (> 1) or
            shrunk (<= 1).
        up_scale, up_ft_index, w_edit, w_content, w_contrast, w_inpaint:
            Returned unchanged.
        precision: torch dtype for the CUDA copies of the masks.
        path_mask_ref: Optional reference mask (path or array), loaded and
            binarised like ``path_mask``.

    Returns:
        dict with keys ``mask_x0``, ``mask_x0_ref`` (None when no reference
        mask is given), ``mask_tar``, ``mask_cur``, ``mask_other``,
        ``mask_overlap``, ``mask_non_overlap`` and the pass-through parameters.

    Raises:
        FileNotFoundError: if a mask path cannot be read by ``cv2.imread``.
        ValueError: if ``resize_scale > 1`` and part of the shifted, enlarged
            mask falls outside the canvas when it is cropped back.
    """
    dx, dy = dx * input_scale, dy * input_scale

    if isinstance(path_mask, str):
        mask_x0 = cv2.imread(path_mask)
        # cv2.imread signals failure by returning None instead of raising.
        if mask_x0 is None:
            raise FileNotFoundError(f"Could not read mask image: {path_mask}")
    else:
        mask_x0 = path_mask
    mask_x0 = cv2.resize(mask_x0, (h, w))

    if path_mask_ref is not None:
        if isinstance(path_mask_ref, str):
            mask_x0_ref = cv2.imread(path_mask_ref)
            if mask_x0_ref is None:
                raise FileNotFoundError(f"Could not read mask image: {path_mask_ref}")
        else:
            mask_x0_ref = path_mask_ref
        mask_x0_ref = cv2.resize(mask_x0_ref, (h, w))
    else:
        mask_x0_ref = None

    # [0] takes the first slice of img2tensor's output (presumably the first
    # channel; confirm against basicsr). Then binarise and move to CUDA.
    mask_x0 = img2tensor(mask_x0)[0]
    mask_x0 = (mask_x0 > 0.5).float().to('cuda', dtype=precision)
    if mask_x0_ref is not None:
        mask_x0_ref = img2tensor(mask_x0_ref)[0]
        mask_x0_ref = (mask_x0_ref > 0.5).float().to('cuda', dtype=precision)

    # Object at its original position, downsampled by `scale`.
    mask_org = F.interpolate(mask_x0[None, None], (int(mask_x0.shape[-2] // scale), int(mask_x0.shape[-1] // scale))) > 0.5

    # Object additionally rescaled by `resize_scale` (mask_tar), then shifted
    # by (dy, dx) (mask_cur). torch.roll wraps around: pixels pushed past one
    # edge re-enter on the opposite edge.
    # NOTE(review): wrap-around is not detected; the sum check below only
    # catches pixels lost when cropping.
    mask_tar = F.interpolate(mask_x0[None, None], (int(mask_x0.shape[-2] // scale * resize_scale), int(mask_x0.shape[-1] // scale * resize_scale))) > 0.5
    mask_cur = torch.roll(mask_tar, (int(dy // scale * resize_scale), int(dx // scale * resize_scale)), (-2, -1))

    # Bring mask_cur back to mask_org's size, keeping it centred: crop when
    # enlarged, zero-pad when shrunk or unchanged.
    pad_size_x = abs(mask_tar.shape[-1] - mask_org.shape[-1]) // 2
    pad_size_y = abs(mask_tar.shape[-2] - mask_org.shape[-2]) // 2
    if resize_scale > 1:
        sum_before = torch.sum(mask_cur)
        mask_cur = mask_cur[:, :, pad_size_y:pad_size_y + mask_org.shape[-2], pad_size_x:pad_size_x + mask_org.shape[-1]]
        sum_after = torch.sum(mask_cur)
        # Losing foreground pixels in the crop means the object left the canvas.
        if sum_after != sum_before:
            raise ValueError('Resize out of bounds, exiting.')
    else:
        temp = torch.zeros(1, 1, mask_org.shape[-2], mask_org.shape[-1]).to(mask_org.device)
        temp[:, :, pad_size_y:pad_size_y + mask_cur.shape[-2], pad_size_x:pad_size_x + mask_cur.shape[-1]] = mask_cur
        mask_cur = temp > 0.5

    # Pixels in neither the original nor the moved region (bool).
    mask_other = (1 - ((mask_cur + mask_org) > 0.5).float()) > 0.5
    # Pixels in both regions (float 0/1, not bool).
    mask_overlap = ((mask_cur.float() + mask_org.float()) > 1.5).float()
    # Original-region pixels not covered by the moved object (bool).
    mask_non_overlap = (mask_org.float() - mask_overlap) > 0.5

    return {
        "mask_x0": mask_x0,
        "mask_x0_ref": mask_x0_ref,
        "mask_tar": mask_tar,
        "mask_cur": mask_cur,
        "mask_other": mask_other,
        "mask_overlap": mask_overlap,
        "mask_non_overlap": mask_non_overlap,
        "up_scale": up_scale,
        "up_ft_index": up_ft_index,
        "resize_scale": resize_scale,
        "w_edit": w_edit,
        "w_content": w_content,
        "w_contrast": w_contrast,
        "w_inpaint": w_inpaint,
    }
def process_drag_face(h, w, x, y, x_cur, y_cur, scale, input_scale, up_scale, up_ft_index, w_edit, w_inpaint, precision):
    """Build 3x3 boolean point masks for face-landmark dragging.

    Side effect: for every index ``i < len(x)``, the entries ``x[i]``,
    ``y[i]``, ``x_cur[i]`` and ``y_cur[i]`` are overwritten in place with
    ``int(value * input_scale)``.

    Each point is then floor-divided by ``scale`` and clipped so a 3x3 patch
    fits. It is rendered as a CUDA boolean mask of shape
    ``(int(h // scale), int(w // scale))``.

    ``precision`` is accepted but not used. ``up_scale``, ``up_ft_index``,
    ``w_edit`` and ``w_inpaint`` are returned unchanged.

    Returns:
        dict with ``mask_tar`` (one mask per (x, y) point), ``mask_cur`` (one
        mask per (x_cur, y_cur) point) and the pass-through parameters.
    """
    # NOTE(review): this mutates the caller's lists; confirm no caller relies
    # on it before switching to local copies.
    for i in range(len(x)):
        for coords in (x, y, x_cur, y_cur):
            coords[i] = int(coords[i] * input_scale)

    rows = int(h // scale)
    cols = int(w // scale)

    def point_mask(px, py):
        # 3x3 patch centred on the (clipped) point, as a boolean mask.
        canvas = torch.zeros(rows, cols).cuda()
        cy = int(np.clip(py // scale, 1, canvas.shape[0] - 2))
        cx = int(np.clip(px // scale, 1, canvas.shape[1] - 2))
        canvas[cy - 1:cy + 2, cx - 1:cx + 2] = 1
        return canvas > 0.5

    mask_tar = [point_mask(px, y[k]) for k, px in enumerate(x)]
    mask_cur = [point_mask(px, y_cur[k]) for k, px in enumerate(x_cur)]

    return {
        "mask_tar": mask_tar,
        "mask_cur": mask_cur,
        "up_scale": up_scale,
        "up_ft_index": up_ft_index,
        "w_edit": w_edit,
        "w_inpaint": w_inpaint,
    }
def process_drag(path_mask, h, w, x, y, x_cur, y_cur, scale, input_scale, up_scale, up_ft_index, w_edit, w_inpaint, w_content, precision, latent_in):
    """Build point masks for a drag edit and copy latent patches between points.

    For each index ``i`` it does three things:
    - builds a 3x3 boolean mask around (x[i], y[i]) in ``mask_tar``;
    - builds one around (x_cur[i], y_cur[i]) in ``mask_cur``;
    - copies the 3x3 patch of ``latent_in`` at the first point (in
      ``// up_scale`` coordinates) onto the second point.
    Masks are built at the mask resolution divided by ``scale``. The latent
    copy modifies ``latent_in`` in place, and the same tensor is returned.

    Args:
        path_mask: Path to the mask image, or an already-loaded image array.
        h, w: Passed to ``cv2.resize`` as ``dsize=(h, w)``. OpenCV reads dsize
            as (width, height). NOTE(review): confirm callers' ordering.
        x, y, x_cur, y_cur: Point coordinates, indexed up to ``len(x)``.
        scale: Factor by which masks are downsampled.
        input_scale: Accepted but not used by this function.
        up_scale: Divisor mapping mask coordinates to latent coordinates;
            also returned unchanged.
        up_ft_index, w_edit, w_inpaint, w_content: Returned unchanged.
        precision: torch dtype for the CUDA tensors.
        latent_in: Tensor indexed as ``[:, :, row, col]``; edited in place.

    Returns:
        dict with ``dict_mask`` (``'base'``: the un-thresholded mask tensor),
        ``mask_x0``, ``mask_tar``, ``mask_cur``, ``mask_other`` (pixels
        outside the mask, downsampled), the pass-through parameters and
        ``latent_in``.

    Raises:
        FileNotFoundError: if ``path_mask`` is a path ``cv2.imread`` cannot read.
    """
    if isinstance(path_mask, str):
        mask_x0 = cv2.imread(path_mask)
        # cv2.imread signals failure by returning None instead of raising.
        if mask_x0 is None:
            raise FileNotFoundError(f"Could not read mask image: {path_mask}")
    else:
        mask_x0 = path_mask
    mask_x0 = cv2.resize(mask_x0, (h, w))
    mask_x0 = img2tensor(mask_x0)[0]
    dict_mask = {}
    dict_mask['base'] = mask_x0
    mask_x0 = (mask_x0 > 0.5).float().to('cuda', dtype=precision)

    mask_h = int(mask_x0.shape[-2] // scale)
    mask_w = int(mask_x0.shape[-1] // scale)
    mask_other = F.interpolate(mask_x0[None, None], (mask_h, mask_w)) < 0.5

    mask_tar = []
    mask_cur = []
    for p_idx in range(len(x)):
        mask_tar_i = torch.zeros(mask_h, mask_w).to('cuda', dtype=precision)
        mask_cur_i = torch.zeros(mask_h, mask_w).to('cuda', dtype=precision)
        # Clip so the 3x3 patch stays inside the mask. y is clipped against
        # the height (dim 0) and x against the width (dim 1); clipping x
        # against the height is wrong for non-square masks.
        y_tar_clip = int(np.clip(y[p_idx] // scale, 1, mask_h - 2))
        x_tar_clip = int(np.clip(x[p_idx] // scale, 1, mask_w - 2))
        y_cur_clip = int(np.clip(y_cur[p_idx] // scale, 1, mask_h - 2))
        x_cur_clip = int(np.clip(x_cur[p_idx] // scale, 1, mask_w - 2))
        mask_tar_i[y_tar_clip - 1:y_tar_clip + 2, x_tar_clip - 1:x_tar_clip + 2] = 1
        mask_cur_i[y_cur_clip - 1:y_cur_clip + 2, x_cur_clip - 1:x_cur_clip + 2] = 1
        mask_tar.append(mask_tar_i > 0.5)
        mask_cur.append(mask_cur_i > 0.5)
        # Seed the destination with the source's latent patch (in place).
        # NOTE(review): no latent-space bounds check. If a clipped coordinate
        # // up_scale is 0, the slice start is -1, the destination slice is
        # empty, and the assignment raises. Confirm callers keep points away
        # from the border.
        latent_in[:, :, y_cur_clip // up_scale - 1:y_cur_clip // up_scale + 2, x_cur_clip // up_scale - 1:x_cur_clip // up_scale + 2] = latent_in[:, :, y_tar_clip // up_scale - 1:y_tar_clip // up_scale + 2, x_tar_clip // up_scale - 1:x_tar_clip // up_scale + 2]

    return {
        "dict_mask": dict_mask,
        "mask_x0": mask_x0,
        "mask_tar": mask_tar,
        "mask_cur": mask_cur,
        "mask_other": mask_other,
        "up_scale": up_scale,
        "up_ft_index": up_ft_index,
        "w_edit": w_edit,
        "w_inpaint": w_inpaint,
        "w_content": w_content,
        "latent_in": latent_in,
    }
def process_appearance(path_mask, path_mask_replace, h, w, scale, input_scale, up_scale, up_ft_index, w_edit, w_content, precision):
    """Build base/replace masks for an appearance-replacement edit.

    Both masks are loaded (from a path or an array) and resized with
    ``cv2.resize(..., (h, w))``; OpenCV reads that dsize as (width, height).
    Each is then binarised at 0.5, moved to CUDA in ``precision`` and
    downsampled by ``scale``.

    ``input_scale`` is accepted but not used. ``up_scale``, ``up_ft_index``,
    ``w_edit`` and ``w_content`` are returned unchanged.

    Returns:
        dict with ``dict_mask`` (``'base'`` / ``'replace'``: the
        un-thresholded CPU tensors), ``mask_base_cur`` and
        ``mask_replace_cur`` (boolean, downsampled), plus the pass-through
        parameters.

    Raises:
        FileNotFoundError: if a mask path cannot be read by ``cv2.imread``.
    """
    def load_mask(src):
        # Read `src` if it is a path, then resize it. cv2.imread returns None
        # rather than raising on failure, so check explicitly.
        if isinstance(src, str):
            img = cv2.imread(src)
            if img is None:
                raise FileNotFoundError(f"Could not read mask image: {src}")
        else:
            img = src
        return cv2.resize(img, (h, w))

    def downsample(mask):
        # Downsample by `scale` (F.interpolate defaults to nearest) and re-binarise.
        return F.interpolate(mask[None, None], (int(mask.shape[-2] // scale), int(mask.shape[-1] // scale))) > 0.5

    mask_base = load_mask(path_mask)
    mask_replace = load_mask(path_mask_replace)

    dict_mask = {}
    mask_base = img2tensor(mask_base)[0]
    dict_mask['base'] = mask_base
    mask_base = (mask_base > 0.5).to('cuda', dtype=precision)
    mask_replace = img2tensor(mask_replace)[0]
    dict_mask['replace'] = mask_replace
    mask_replace = (mask_replace > 0.5).to('cuda', dtype=precision)

    mask_base_cur = downsample(mask_base)
    mask_replace_cur = downsample(mask_replace)

    return {
        "dict_mask": dict_mask,
        "mask_base_cur": mask_base_cur,
        "mask_replace_cur": mask_replace_cur,
        "up_scale": up_scale,
        "up_ft_index": up_ft_index,
        "w_edit": w_edit,
        "w_content": w_content,
    }
def process_paste(path_mask, h, w, dx, dy, scale, input_scale, up_scale, up_ft_index, w_edit, w_content, precision, resize_scale=None):
    """Build source/destination masks for pasting a masked region at an offset.

    The mask is loaded (path or array), resized with
    ``cv2.resize(..., (h, w))`` (OpenCV reads dsize as (width, height)),
    binarised at 0.5 and moved to CUDA. It is optionally rescaled by
    ``resize_scale`` about the canvas centre. The result is kept as the
    'replace' (source) mask, and a copy shifted by ``(dy, dx)`` becomes the
    'base' (destination) mask.

    Args:
        dx, dy: Offset; multiplied by ``input_scale`` before use.
        scale: Factor by which the destination mask is downsampled.
        resize_scale: Optional factor; None or 1 leaves the mask size alone.
        precision: torch dtype for the CUDA mask.
        up_scale, up_ft_index, w_edit, w_content: Returned unchanged.

    Returns:
        dict with ``dict_mask`` (``'base'`` / ``'replace'`` full-resolution
        2-D masks), ``mask_base_cur`` (boolean destination mask, downsampled),
        ``mask_replace_cur`` (that mask shifted back by the downsampled
        offset), plus the pass-through parameters.

    Raises:
        FileNotFoundError: if ``path_mask`` is a path ``cv2.imread`` cannot read.
    """
    dx, dy = dx * input_scale, dy * input_scale
    if isinstance(path_mask, str):
        mask_base = cv2.imread(path_mask)
        # cv2.imread signals failure by returning None instead of raising.
        if mask_base is None:
            raise FileNotFoundError(f"Could not read mask image: {path_mask}")
    else:
        mask_base = path_mask
    mask_base = cv2.resize(mask_base, (h, w))

    dict_mask = {}
    # Add batch and channel axes for F.interpolate / torch.roll on dims (-2, -1).
    mask_base = img2tensor(mask_base)[0][None, None]
    mask_base = (mask_base > 0.5).to('cuda', dtype=precision)
    if resize_scale is not None and resize_scale != 1:
        hi, wi = mask_base.shape[-2], mask_base.shape[-1]
        mask_base = F.interpolate(mask_base, (int(hi * resize_scale), int(wi * resize_scale)))
        pad_size_x = np.abs(mask_base.shape[-1] - wi) // 2
        pad_size_y = np.abs(mask_base.shape[-2] - hi) // 2
        if resize_scale > 1:
            # Enlarged: keep the centred hi x wi window.
            mask_base = mask_base[:, :, pad_size_y:pad_size_y + hi, pad_size_x:pad_size_x + wi]
        else:
            # Shrunk: centre it on a zero canvas of the original size.
            # torch.zeros uses the default float dtype, so the mask dtype can
            # change on this path.
            temp = torch.zeros(1, 1, hi, wi).to(mask_base.device)
            temp[:, :, pad_size_y:pad_size_y + mask_base.shape[-2], pad_size_x:pad_size_x + mask_base.shape[-1]] = mask_base
            mask_base = temp

    # 'replace' is the region at its source position; 'base' is the same
    # region shifted by (dy, dx). torch.roll wraps around at the edges.
    mask_replace = mask_base.clone()
    mask_base = torch.roll(mask_base, (int(dy), int(dx)), (-2, -1))
    dict_mask['base'] = mask_base[0, 0]
    dict_mask['replace'] = mask_replace[0, 0]

    mask_base_cur = F.interpolate(mask_base, (int(mask_base.shape[-2] // scale), int(mask_base.shape[-1] // scale))) > 0.5
    # Shift the downsampled destination mask back by the downsampled offset
    # to obtain the source region at the working resolution.
    mask_replace_cur = torch.roll(mask_base_cur, (-int(dy / scale), -int(dx / scale)), (-2, -1))

    return {
        "dict_mask": dict_mask,
        "mask_base_cur": mask_base_cur,
        "mask_replace_cur": mask_replace_cur,
        "up_scale": up_scale,
        "up_ft_index": up_ft_index,
        "w_edit": w_edit,
        "w_content": w_content,
    }