Sophia Yang commited on
Commit
a0019ba
·
1 Parent(s): e92a987
Files changed (1) hide show
  1. app.py +60 -53
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import io
2
 
3
- import hvplot.pandas
4
  import numpy as np
5
  import panel as pn
6
  import param
@@ -10,29 +9,29 @@ import torch
10
 
11
  from diffusers import StableDiffusionInstructPix2PixPipeline
12
 
13
- pn.extension(template="bootstrap")
14
- pn.state.template.main_max_width = "690px"
15
- pn.state.template.accent_base_color = "#F08080"
16
- pn.state.template.header_background = "#F08080"
17
 
18
- # Set up device
19
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
20
 
21
- # Model
22
  model_id = "timbrooks/instruct-pix2pix"
 
 
 
23
 
24
- if device == "cuda":
25
- pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
26
- model_id, torch_dtype=torch.float16
27
- )
28
- else:
29
- pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
30
- model_id
31
- )
32
- pipe = pipe.to(device)
33
-
34
 
35
- def new_image(prompt, image, img_guidance, guidance, steps):
36
  edit = pipe(
37
  prompt,
38
  image=image,
@@ -42,11 +41,10 @@ def new_image(prompt, image, img_guidance, guidance, steps):
42
  ).images[0]
43
  return edit
44
 
45
-
46
- # Panel widgets
47
  file_input = pn.widgets.FileInput(width=600)
48
- prompt = pn.widgets.TextInput(
49
- value="", placeholder="Enter image editing instruction here...", width=600
 
50
  )
51
  img_guidance = pn.widgets.DiscreteSlider(
52
  name="Image guidance scale", options=list(np.arange(1, 10.5, 0.5)), value=1.5
@@ -54,27 +52,24 @@ img_guidance = pn.widgets.DiscreteSlider(
54
  guidance = pn.widgets.DiscreteSlider(
55
  name="Guidance scale", options=list(np.arange(1, 10.5, 0.5)), value=7
56
  )
57
- steps = pn.widgets.IntSlider(name="Inference Steps", start=1, end=100, step=1, value=20)
58
- run_button = pn.widgets.Button(name="Run!", width=600)
59
-
 
 
 
 
 
 
 
 
 
60
 
61
  # define global variables to keep track of things
62
  convos = [] # store all panel objects in a list
63
  image = None
64
  filename = None
65
 
66
-
67
- def normalize_image(value, width):
68
- """
69
- normalize image to RBG channels and to the same size
70
- """
71
- b = io.BytesIO(value)
72
- image = PIL.Image.open(b).convert("RGB")
73
- aspect = image.size[1] / image.size[0]
74
- height = int(aspect * width)
75
- return image.resize((width, height), PIL.Image.ANTIALIAS)
76
-
77
-
78
  def get_conversations(_, img, img_guidance, guidance, steps, width=600):
79
  """
80
  Get all the conversations in a Panel object
@@ -89,26 +84,38 @@ def get_conversations(_, img, img_guidance, guidance, steps, width=600):
89
  image = normalize_image(file_input.value, width)
90
  convos.clear()
91
 
 
92
  if prompt_text:
93
- # generate new image
94
  image = new_image(prompt_text, image, img_guidance, guidance, steps)
95
- convos.append(pn.Row("\U0001F60A", pn.pane.Markdown(prompt_text, width=600)))
96
- convos.append(pn.Row("\U0001F916", image))
97
- return pn.Column(*convos)
98
-
 
 
 
 
 
 
 
 
 
99
 
100
  # bind widgets to functions
101
- interactive_conversation = pn.bind(
102
- get_conversations, run_button, file_input, img_guidance, guidance, steps
 
 
 
 
103
  )
104
- interactive_upload = pn.bind(pn.panel, file_input, width=600)
105
 
106
  # layout
107
  pn.Column(
108
- pn.pane.Markdown("## \U0001F60A Upload an image file and start editing!"),
109
- pn.Column(file_input, pn.panel(interactive_upload)),
110
- pn.panel(interactive_conversation, loading_indicator=True),
111
- prompt,
112
- pn.Row(run_button),
113
- pn.Card(img_guidance, guidance, steps, width=600, header="Advance settings"),
114
- ).servable(title="Stablel Diffusion InstructPix2pix Image Editing Chatbot")
 
1
  import io
2
 
 
3
  import numpy as np
4
  import panel as pn
5
  import param
 
9
 
10
  from diffusers import StableDiffusionInstructPix2PixPipeline
11
 
12
+ pn.extension('texteditor', template="bootstrap", sizing_mode='stretch_width')
 
 
 
13
 
14
+ pn.state.template.param.update(
15
+ main_max_width="690px",
16
+ header_background="#F08080",
17
+ )
18
 
 
19
  model_id = "timbrooks/instruct-pix2pix"
20
+ pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
21
+ model_id, torch_dtype=torch.float16
22
+ ).to("cuda")
23
 
24
+ def normalize_image(value, width):
25
+ """
26
+ normalize image to RBG channels and to the same size
27
+ """
28
+ b = io.BytesIO(value)
29
+ image = PIL.Image.open(b).convert("RGB")
30
+ aspect = image.size[1] / image.size[0]
31
+ height = int(aspect * width)
32
+ return image.resize((width, height), PIL.Image.LANCZOS)
 
33
 
34
+ def new_image(prompt, image, img_guidance, guidance, steps, width=600):
35
  edit = pipe(
36
  prompt,
37
  image=image,
 
41
  ).images[0]
42
  return edit
43
 
 
 
44
  file_input = pn.widgets.FileInput(width=600)
45
+
46
+ prompt = pn.widgets.TextEditor(
47
+ value="", placeholder="Enter image editing instruction here...", height=160, toolbar=False
48
  )
49
  img_guidance = pn.widgets.DiscreteSlider(
50
  name="Image guidance scale", options=list(np.arange(1, 10.5, 0.5)), value=1.5
 
52
  guidance = pn.widgets.DiscreteSlider(
53
  name="Guidance scale", options=list(np.arange(1, 10.5, 0.5)), value=7
54
  )
55
+ steps = pn.widgets.IntSlider(
56
+ name="Inference Steps", start=1, end=100, step=1, value=20
57
+ )
58
+ run_button = pn.widgets.Button(name="Run!")
59
+
60
+ widgets = pn.Row(
61
+ pn.Column(prompt, run_button, margin=5),
62
+ pn.Card(
63
+ pn.Column(img_guidance, guidance, steps),
64
+ title="Advanced settings", margin=10
65
+ ), width=600
66
+ )
67
 
68
  # define global variables to keep track of things
69
  convos = [] # store all panel objects in a list
70
  image = None
71
  filename = None
72
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def get_conversations(_, img, img_guidance, guidance, steps, width=600):
74
  """
75
  Get all the conversations in a Panel object
 
84
  image = normalize_image(file_input.value, width)
85
  convos.clear()
86
 
87
+ # if there is a prompt run output
88
  if prompt_text:
 
89
  image = new_image(prompt_text, image, img_guidance, guidance, steps)
90
+ convos.extend([
91
+ pn.Row(
92
+ pn.panel("\U0001F60A", width=10),
93
+ prompt_text,
94
+ width=600
95
+ ),
96
+ pn.Row(
97
+ pn.panel(image, align='end', width=500),
98
+ pn.panel("\U0001F916", width=10),
99
+ align='end'
100
+ )
101
+ ])
102
+ return pn.Column(*convos, margin=15, width=575)
103
 
104
  # bind widgets to functions
105
+ interactive_upload = pn.panel(pn.bind(pn.panel, file_input, width=575, min_height=400, margin=15))
106
+
107
+ interactive_conversation = pn.panel(
108
+ pn.bind(
109
+ get_conversations, run_button, file_input, img_guidance, guidance, steps
110
+ ), loading_indicator=True
111
  )
112
+
113
 
114
  # layout
115
  pn.Column(
116
+ "## \U0001F60A Upload an image file and start editing!",
117
+ file_input,
118
+ interactive_upload,
119
+ interactive_conversation,
120
+ widgets
121
+ ).servable(title="Stable Diffusion InstructPix2pix Image Editing Chatbot")