ViewDiffusion / app.py
BertChristiaens's picture
Update app.py
3c39785
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from PIL import Image
import time
import io
import cv2
import numpy as np
from camera_input_live import camera_input_live
from inference import inpainting
# Configure the page to use the "wide" layout before any widgets are rendered.
st.set_page_config(layout="wide")
def make_canvas(image):
    """Render a 512x512 free-draw canvas on top of ``image``.

    Args:
        image: background image handed to ``st_canvas`` as
            ``background_image``.

    Returns:
        Whatever ``st_canvas`` returns for the widget registered under
        the key ``"canvas"``.
    """
    # Keyword arguments are passed directly; strokes are black and
    # 40 units wide, and every stroke triggers a Streamlit update.
    return st_canvas(
        fill_color='#F00000',
        stroke_color='#000000',
        background_color="#FFFFFF",
        background_image=image,
        stroke_width=40,
        update_streamlit=True,
        height=512,
        width=512,
        drawing_mode='freedraw',
        key="canvas",
    )
def get_mask(image_mask: np.ndarray) -> "Image.Image":
    """Turn a drawn canvas image into a binary RGB mask.

    A pixel is part of the mask when the mean of its channel values is
    greater than zero; every other pixel is left out.

    Args:
        image_mask (np.ndarray): array with the channels on axis 2
            (height, width, channels).
            NOTE(review): the caller passes ``canvas.image_data``;
            presumably RGBA, so black strokes are still detected through
            their alpha channel -- confirm.

    Returns:
        PIL.Image.Image: RGB image of the same height and width, with
        value 255 on masked pixels and 0 elsewhere.
    """
    # Average over the channel axis; any non-zero channel marks the pixel.
    drawn = np.mean(image_mask, axis=2) > 0
    # Binarise to 0 / 255 in uint8 (the boolean array needs no separate
    # integer conversion beforehand).
    single_channel = np.where(drawn, 255, 0).astype(np.uint8)
    # Replicate the single channel three times to obtain an RGB array.
    rgb_mask = np.stack([single_channel, single_channel, single_channel], axis=2)
    return Image.fromarray(rgb_mask).convert("RGB")
def make_prompt_fields():
    """Render the prompt section and return the two entered texts.

    Returns:
        tuple: ``(prompt, negative_prompt)`` as returned by the two
        ``st.text_input`` widgets.
    """
    st.write("### Prompting")
    # Widgets are created in display order: prompt first, then negative prompt.
    positive_text = st.text_input(
        "Prompt",
        value="A person in a room with colored hair",
        key="prompt",
    )
    negative_text = st.text_input(
        "Negative Prompt",
        value="Facial hair",
        key="negative_prompt",
    )
    return positive_text, negative_text
def make_input_fields():
    """Render the parameter and latent-walk widgets and return their values.

    Returns:
        tuple: ``(guidance_scale, inference_steps, generator_seed,
        static_latents, latent_walk)`` in that order.
    """
    # --- sampling parameters -------------------------------------------
    st.write("### Parameters")
    scale = st.slider(
        "Guidance Scale",
        min_value=0.0,
        max_value=50.0,
        value=7.5,
        step=0.25,
        key="guidance_scale",
    )
    steps = st.slider(
        "Inference Steps",
        min_value=1,
        max_value=50,
        value=20,
        step=1,
        key="inference_steps",
    )
    seed = st.slider(
        "Generator Seed",
        min_value=0,
        max_value=10_000,
        value=0,
        step=1,
        key="generator_seed",
    )

    # --- latent walk controls ------------------------------------------
    st.write("### Latent walk")
    keep_latents = st.checkbox("Static Latents", value=False, key="static_latents")
    walk_amount = st.slider(
        "Latent Walk",
        min_value=0.0,
        max_value=1.0,
        value=0.0,
        step=0.01,
        key="latent_walk",
    )
    return scale, steps, seed, keep_latents, walk_amount
def decode_image(image):
    """Decode an encoded image buffer into a 512x512 RGB PIL image.

    Args:
        image: bytes-like object holding an encoded image; it is read via
            ``np.frombuffer`` and decoded with ``cv2.imdecode``.

    Returns:
        The decoded picture as a PIL image in RGB mode, resized to
        512x512.
    """
    raw_bytes = np.frombuffer(image, np.uint8)
    bgr_pixels = cv2.imdecode(raw_bytes, cv2.IMREAD_COLOR)
    # OpenCV decodes to BGR channel order; swap to RGB before handing to PIL.
    rgb_pixels = cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(rgb_pixels).convert("RGB")
    return pil_image.resize((512, 512))
if __name__ == "__main__":
    st.sidebar.title("Sidebar")

    # Sidebar: live webcam capture plus the prompt and parameter widgets.
    with st.sidebar:
        webcam = camera_input_live(debounce=1000, key="webcam", width=512, height=512)
        prompt, negative_prompt = make_prompt_fields()
        # NOTE(review): these values are collected but not forwarded to
        # `inpainting` below -- confirm whether that is intentional.
        guidance_scale, inference_steps, generator_seed, static_latents, latent_walk = make_input_fields()

    colA, colB = st.columns(2)

    # Nothing is rendered in the columns until a webcam frame is available.
    if webcam:
        # Left column: the captured frame with a canvas to draw the mask on.
        with colA:
            st.write("## Webcam image")
            st.write("You can draw the mask on the image below.")
            image = decode_image(webcam.getvalue())
            canvas = make_canvas(image)

            if st.button("Inpaint"):
                if canvas.image_data is None:
                    # The canvas has not reported any pixel data yet, so
                    # there is no mask to build; feeding None to get_mask
                    # would raise inside np.mean.
                    st.warning("No mask available yet. Draw on the image and try again.")
                    result = None
                else:
                    st.write("Start inpainting process")
                    mask_image = get_mask(np.array(canvas.image_data))
                    result = inpainting(image, mask_image, prompt, negative_prompt)
                    # Persist across Streamlit reruns so the result stays visible.
                    st.session_state["result"] = result
            else:
                result = None

        # Right column: the current webcam frame and the latest stored result.
        with colB:
            st.write("## Generated image")
            st.write("The generated image will appear here.")
            st.image(webcam)
            if 'result' in st.session_state:
                print("Showing result")
                st.image(st.session_state["result"])