Spaces:
Running
on
Zero
Running
on
Zero
import hashlib | |
import json | |
import os | |
import re | |
from pathlib import Path | |
import folder_paths | |
import numpy as np | |
import torch | |
from PIL import Image, ImageOps | |
from PIL.PngImagePlugin import PngInfo | |
from ..log import log | |
class MTB_LoadImageSequence:
    """Load an image sequence from a folder.

    The current frame is used to determine which image to load. Usually used
    in conjunction with the `Primitive` node set to increment, to load a
    sequence of images from a folder.
    Use -1 to load all matching frames as a batch, or set `range`
    ("start-end") to load a subset of them as a batch.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs (ComfyUI node API)."""
        return {
            "required": {
                "path": ("STRING", {"default": "videos/####.png"}),
                "current_frame": (
                    "INT",
                    {"default": 0, "min": -1, "max": 9999999},
                ),
            },
            "optional": {
                "range": ("STRING", {"default": ""}),
            },
        }

    CATEGORY = "mtb/IO"
    FUNCTION = "load_image"
    RETURN_TYPES = (
        "IMAGE",
        "MASK",
        "INT",
        "INT",
    )
    RETURN_NAMES = (
        "image",
        "mask",
        "current_frame",
        "total_frames",
    )

    # NOTE: `range` shadows the builtin, but it is the input name declared in
    # INPUT_TYPES (and passed by keyword), so it must keep this name.
    def load_image(self, path=None, current_frame=0, range=""):
        """Load one frame, all frames, or a range of frames from ``path``.

        Args:
            path: Frame pattern (runs of ``#`` are replaced by the frame
                number), glob pattern, or directory.
            current_frame: Frame to load; ``-1`` loads every matching frame.
            range: Optional ``"start-end"`` selection; when set it takes
                precedence over ``current_frame``.

        Returns:
            ``(image, mask, current_frame, total_frames)``. In batch modes
            ``current_frame`` is ``-1`` and ``total_frames`` is the batch
            size; otherwise ``total_frames`` is ``1``.

        Raises:
            ValueError: if a batch mode finds no frames or ``range`` is
                invalid.
        """
        if range:
            frames = self.get_frames_from_range(path, range)
            return self._load_batch(path, frames)
        if current_frame == -1:
            log.debug(f"Loading all frames from {path}")
            frames = resolve_all_frames(path)
            log.debug(f"Found {len(frames)} frames")
            return self._load_batch(path, frames)

        log.debug(f"Loading image: {path}, {current_frame}")
        resolved_path = resolve_path(path, current_frame)
        image_path = folder_paths.get_annotated_filepath(resolved_path)
        image, mask = img_from_path(image_path)
        return (image, mask, current_frame, 1)

    @staticmethod
    def _load_batch(path, frames):
        """Load every file in ``frames`` and concatenate them into one batch."""
        if not frames:
            raise ValueError(f"No frames found for path: {path}")
        imgs, masks = zip(*(img_from_path(frame) for frame in frames))
        out_img = torch.cat(imgs, dim=0)
        # Per-frame masks may lack a batch axis (img_from_path returns 2-D
        # masks); add one so the result is (N, H, W) rather than N masks
        # stacked along the height axis.
        out_mask = torch.cat(
            [mask if mask.ndim == 3 else mask.unsqueeze(0) for mask in masks],
            dim=0,
        )
        return (out_img, out_mask, -1, len(imgs))

    @staticmethod
    def _frame_number_regex(file_pattern):
        """Compile a regex whose first group captures a frame number.

        Each run of N ``#`` becomes one group of at least N digits (so a
        multi-digit number is captured whole); every other character of
        ``file_pattern`` matches literally.
        """
        parts = re.split(r"(#+)", file_pattern)
        return re.compile(
            "".join(
                r"(\d{%d,})" % len(part) if part.startswith("#") else re.escape(part)
                for part in parts
            )
        )

    def get_frames_from_range(self, path, range_str):
        """Return the frames of ``path`` selected by ``range_str``.

        ``range_str`` has the form ``"start-end"`` (both inclusive). When the
        file name of ``path`` contains ``#``, start/end are frame numbers read
        from the file names; otherwise they are 0-based indexes into the list
        returned by ``resolve_all_frames``.

        Raises:
            ValueError: if ``range_str`` is malformed, ``start > end``, or
                (index mode) the range falls outside the available frames.
        """
        try:
            start, end = map(int, range_str.split("-"))
        except ValueError:
            raise ValueError(
                f"Invalid range format: {range_str}. Expected format is 'start-end'."
            ) from None
        if start > end:
            raise ValueError(
                f"Invalid range {range_str}: start is greater than end."
            )

        frames = resolve_all_frames(path)
        file_pattern = os.path.basename(path)
        if "#" in file_pattern:
            # Frame-number mode: bounds are frame numbers, not list indexes,
            # so no index bounds check applies. Matching on base names keeps
            # this independent of the path separator glob used.
            frame_regex = self._frame_number_regex(file_pattern)
            selected_frames = []
            for frame in frames:
                match = frame_regex.search(os.path.basename(frame))
                if match and start <= int(match.group(1)) <= end:
                    selected_frames.append(frame)
            return selected_frames

        total_frames = len(frames)
        if start < 0 or end >= total_frames:
            raise ValueError(
                f"Range {range_str} is out of bounds. Total frames available: {total_frames}"
            )
        log.warning(
            f"Wildcard pattern or directory will use indexes instead of frame numbers for : {path}"
        )
        return frames[start : end + 1]

    @staticmethod
    def IS_CHANGED(path="", current_frame=0, range=""):
        """Return a fingerprint that changes when the loaded input changes."""
        log.debug(f"Checking if changed: {path}, {current_frame}")
        if range or current_frame == -1:
            # Batch modes: fingerprint every candidate frame by path and
            # mtime, stat-ing the same paths load_image opens. One entry per
            # line keeps the concatenation unambiguous.
            fingerprint = "\n".join(
                f"{frame}:{os.path.getmtime(frame)}"
                for frame in resolve_all_frames(path)
            )
            return hashlib.sha256(
                fingerprint.encode("utf-8", "surrogateescape")
            ).hexdigest()

        resolved_path = resolve_path(path, current_frame)
        image_path = folder_paths.get_annotated_filepath(resolved_path)
        # isfile (not exists): a directory would otherwise reach open().
        if not os.path.isfile(image_path):
            return "NONE"
        digest = hashlib.sha256()
        with open(image_path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                digest.update(chunk)
        return digest.hexdigest()
import glob | |
def img_from_path(path):
    """Load one image file as an image tensor plus a mask tensor.

    Args:
        path: Filesystem path of the image to open.

    Returns:
        ``(image, mask)``: ``image`` is a float32 tensor of shape
        ``(1, H, W, 3)`` scaled from 8-bit RGB to [0, 1]; ``mask`` is a
        float32 ``(H, W)`` tensor equal to ``1 - alpha`` when the image has
        an alpha channel, otherwise a ``(64, 64)`` tensor of zeros.
    """
    # The context manager releases the file handle as soon as the pixels are
    # decoded; otherwise each loaded frame keeps a handle open until GC.
    with Image.open(path) as src:
        img = ImageOps.exif_transpose(src)
        # Palette images keep transparency in ``info`` instead of an alpha
        # band; expand them to RGBA so that transparency reaches the mask.
        if img.mode == "P" and "transparency" in img.info:
            img = img.convert("RGBA")
        rgb = np.array(img.convert("RGB")).astype(np.float32) / 255.0
        image = torch.from_numpy(rgb)[None,]
        if "A" in img.getbands():
            alpha = np.array(img.getchannel("A")).astype(np.float32) / 255.0
            mask = 1.0 - torch.from_numpy(alpha)
        else:
            mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
    return (
        image,
        mask,
    )
def resolve_all_frames(path: str) -> list[str]:
    """Return the list of frame files described by ``path``.

    ``path`` is interpreted in one of three ways:

    * contains ``#``: a frame pattern such as ``videos/####.png``. Each run of
      N ``#`` matches a frame number of at least N digits. Results are ordered
      by frame number, so e.g. ``10000`` sorts after ``9999``.
    * contains ``*`` (and no ``#``): a glob pattern; results sorted as strings.
    * otherwise: a directory whose ``.jpg``/``.jpeg``/``.png`` files
      (extension compared case-insensitively) are returned, sorted as strings.

    Raises:
        ValueError: if ``path`` has no ``#`` or ``*`` and is not a directory.
    """
    if "#" not in path:
        pth = Path(path)
        if pth.is_dir():
            image_suffixes = {".jpg", ".jpeg", ".png"}
            frames = [
                f.as_posix()
                for f in pth.iterdir()
                if f.is_file() and f.suffix.lower() in image_suffixes
            ]
        elif "*" in path:
            frames = glob.glob(path)
        else:
            raise ValueError(
                "The path doesn't contain a # or a * or is not a directory"
            )
        return sorted(frames)

    folder_path, file_pattern = os.path.split(path)
    log.debug(f"Resolving all frames in {folder_path}")
    glob_pattern = re.sub(r"#+", "*", file_pattern)
    log.debug(f"Found pattern: {glob_pattern}")
    matching_files = glob.glob(os.path.join(folder_path, glob_pattern))
    log.debug(f"Found {len(matching_files)} matching files")

    def chunk_to_regex(chunk: str) -> str:
        # One capture group per run of '#' (not per '#'), so a multi-digit
        # frame number is captured whole by group 1.
        if chunk.startswith("#"):
            return r"(\d{%d,})" % len(chunk)
        return re.escape(chunk)

    frame_regex = re.compile(
        "".join(chunk_to_regex(chunk) for chunk in re.split(r"(#+)", file_pattern))
    )
    numbered_frames = []
    for file in matching_files:
        match = frame_regex.search(file)
        if match:
            log.debug(f"Found frame number: {match.group(1)}")
            numbered_frames.append((int(match.group(1)), file))
    # Numeric order stays correct when numbers overflow the padding width.
    numbered_frames.sort()
    return [file for _, file in numbered_frames]
def resolve_path(path, frame):
    """Substitute ``frame`` into every run of ``#`` in ``path``.

    Each run is replaced by the frame number zero-padded to that run's own
    length, e.g. ``resolve_path("videos/####.png", 12)`` returns
    ``"videos/0012.png"``. Numbers wider than the run are not truncated, and
    paths without ``#`` are returned unchanged.
    """
    return re.sub(r"#+", lambda run: str(frame).zfill(len(run.group())), path)
class MTB_SaveImageSequence:
    """Save an image sequence to a folder, one frame per execution.

    The current frame is used to name the written file. This is a thin
    wrapper that formats the output folder and file name and writes a PNG
    with the prompt / extra info embedded as PNG text chunks.
    """

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs (ComfyUI node API)."""
        return {
            "required": {
                "images": ("IMAGE",),
                "filename_prefix": ("STRING", {"default": "Sequence"}),
                "current_frame": (
                    "INT",
                    {"default": 0, "min": 0, "max": 9999999},
                ),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    RETURN_TYPES = ()
    FUNCTION = "save_images"
    OUTPUT_NODE = True
    CATEGORY = "mtb/IO"

    def save_images(
        self,
        images,
        filename_prefix="Sequence",
        current_frame=0,
        prompt=None,
        extra_pnginfo=None,
    ):
        """Save a single image as frame ``current_frame`` of a sequence.

        The file is written to
        ``<output_dir>/<filename_prefix>/<name>_<current_frame:05>.png``,
        where ``<name>`` is the last path component of ``filename_prefix``.

        Args:
            images: Batch holding exactly one image. ``images[0]`` is scaled
                by 255 and clipped to uint8, so values are presumably floats
                in [0, 1] — confirm with the producing node.
            filename_prefix: Sub-folder of the output directory; its last
                component also names the file.
            current_frame: Frame number written into the file name.
            prompt: Optional value stored as JSON in the "prompt" text chunk.
            extra_pnginfo: Optional mapping; each entry is stored as a JSON
                text chunk under its key.

        Returns:
            A ``{"ui": {"images": [...]}}`` payload describing the saved file.

        Raises:
            ValueError: if ``images`` does not hold exactly one image, or if
                ``filename_prefix`` points outside the output directory.
        """
        if len(images) > 1:
            raise ValueError("Can only save one image at a time")
        if len(images) == 0:
            raise ValueError("No image to save")

        output_root = os.path.abspath(self.output_dir)
        target_dir = os.path.abspath(os.path.join(output_root, filename_prefix))
        # Reject prefixes such as "../x" or absolute paths, which would
        # otherwise write outside of the output directory.
        if os.path.commonpath((output_root, target_dir)) != output_root:
            raise ValueError(
                f"filename_prefix points outside the output directory: {filename_prefix}"
            )
        resolved_path = Path(target_dir)
        resolved_path.mkdir(parents=True, exist_ok=True)
        # Name the file after the prefix's last component only, so a nested
        # prefix like "shots/seq" yields shots/seq/seq_00001.png.
        resolved_img = (
            resolved_path / f"{Path(filename_prefix).name}_{current_frame:05}.png"
        )
        # The UI needs the folder relative to the output root, which for a
        # nested prefix is more than the last component.
        subfolder = Path(os.path.relpath(target_dir, output_root)).as_posix()
        if subfolder == ".":
            subfolder = ""

        output_image = images[0].cpu().numpy()
        img = Image.fromarray(
            np.clip(output_image * 255.0, 0, 255).astype(np.uint8)
        )
        metadata = PngInfo()
        if prompt is not None:
            metadata.add_text("prompt", json.dumps(prompt))
        if extra_pnginfo is not None:
            for key, value in extra_pnginfo.items():
                metadata.add_text(key, json.dumps(value))
        img.save(resolved_img, pnginfo=metadata, compress_level=4)
        return {
            "ui": {
                "images": [
                    {
                        "filename": resolved_img.name,
                        "subfolder": subfolder,
                        "type": self.type,
                    }
                ]
            }
        }
# Node classes exposed by this module (presumably collected by the package's
# node registration — confirm against the loader).
__nodes__ = [MTB_LoadImageSequence, MTB_SaveImageSequence]