Tech-Meld commited on
Commit
0259e4a
·
verified ·
1 Parent(s): d1d1942

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -40
app.py CHANGED
@@ -1,63 +1,89 @@
1
  import torch
2
  from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
3
  from tqdm.auto import tqdm
4
- from huggingface_hub import cached_download, hf_hub_url
5
  import os
 
 
 
6
 
7
  def display_image(image):
8
- """
9
- Replace this with your actual image display logic.
10
- """
11
- image.show()
12
 
13
- def load_and_merge_lora(base_model_id, lora_id, lora_weight_name, lora_adapter_name):
14
- try:
15
  pipe = DiffusionPipeline.from_pretrained(
16
- base_model_id,
17
- torch_dtype=torch.float16,
18
- scheduler=DPMSolverMultistepScheduler.from_config(
19
- pipe.scheduler.config),
20
  variant="fp16",
21
  use_safetensors=True,
22
  ).to("cuda")
23
 
24
- lora_url = hf_hub_url(lora_id, revision="main", filename=lora_weight_name)
25
- lora_path = cached_download(lora_url)
 
 
 
 
26
 
27
- with tqdm(desc="Loading LoRA weights", unit="step") as pbar:
28
- pipe.load_lora_weights(
29
- lora_path,
30
- weight_name=lora_weight_name,
31
- adapter_name=lora_adapter_name,
32
- progress_callback=lambda step, max_steps: pbar.update(1)
33
- )
 
 
 
34
 
35
  print("LoRA merged successfully!")
36
  return pipe
 
37
  except Exception as e:
38
- print(f"Error merging LoRA: {e}")
 
 
 
 
 
39
  return None
40
 
41
- def save_merged_model(pipe, save_path):
42
- """Saves the merged model to the specified path."""
43
  try:
44
  pipe.save_pretrained(save_path)
45
  print(f"Merged model saved successfully to: {save_path}")
46
- except Exception as e:
47
- print(f"Error saving the merged model: {e}")
48
 
49
- if __name__ == "__main__":
50
- base_model_id = input("Enter the base model ID: ")
51
- lora_id = input("Enter the LoRA Hugging Face Hub ID: ")
52
- lora_weight_name = input("Enter the LoRA weight file name: ")
53
- lora_adapter_name = input("Enter the LoRA adapter name: ")
54
 
55
- pipe = load_and_merge_lora(base_model_id, lora_id, lora_weight_name, lora_adapter_name)
 
56
 
57
- if pipe:
58
- prompt = input("Enter your prompt: ")
59
- lora_scale = float(input("Enter the LoRA scale (e.g., 0.9): "))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
 
 
61
  image = pipe(
62
  prompt,
63
  num_inference_steps=30,
@@ -65,10 +91,31 @@ if __name__ == "__main__":
65
  generator=torch.manual_seed(0)
66
  ).images[0]
67
 
68
- display_image(image)
 
69
 
70
- # Ask the user for a directory to save the model
71
- save_path = input(
72
- "Enter the directory where you want to save the merged model: "
73
- )
74
- save_merged_model(pipe, save_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
3
  from tqdm.auto import tqdm
4
+ from huggingface_hub import hf_hub_url, login, HfApi, create_repo
5
  import os
6
+ import traceback
7
+ from peft import PeftModel
8
+ import gradio as gr
9
 
10
  def display_image(image):
11
+ """Display the generated image."""
12
+ return image
 
 
13
 
14
+ def load_and_merge_lora(base_model_id, lora_id, lora_adapter_name):
15
+ try:
16
  pipe = DiffusionPipeline.from_pretrained(
17
+ base_model_id,
18
+ torch_dtype=torch.float16,
 
 
19
  variant="fp16",
20
  use_safetensors=True,
21
  ).to("cuda")
22
 
23
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(
24
+ pipe.scheduler.config
25
+ )
26
+
27
+ # Get the UNet model from the pipeline
28
+ unet = pipe.unet
29
 
30
+ # Apply PEFT to the UNet model
31
+ unet = PeftModel.from_pretrained(
32
+ unet,
33
+ lora_id,
34
+ torch_dtype=torch.float16,
35
+ adapter_name=lora_adapter_name
36
+ )
37
+
38
+ # Replace the original UNet in the pipeline with the PEFT-loaded one
39
+ pipe.unet = unet
40
 
41
  print("LoRA merged successfully!")
42
  return pipe
43
+
44
  except Exception as e:
45
+ error_msg = traceback.format_exc()
46
+ print(f"Error merging LoRA: {e}\n\nFull traceback saved to errors.txt")
47
+
48
+ with open("errors.txt", "w") as f:
49
+ f.write(error_msg)
50
+
51
  return None
52
 
53
+ def save_merged_model(pipe, save_path, push_to_hub=False, hf_token=None):
54
+ """Saves and optionally pushes the merged model to Hugging Face Hub."""
55
  try:
56
  pipe.save_pretrained(save_path)
57
  print(f"Merged model saved successfully to: {save_path}")
 
 
58
 
59
+ if push_to_hub:
60
+ if hf_token is None:
61
+ hf_token = input("Enter your Hugging Face write token: ")
62
+ login(token=hf_token)
 
63
 
64
+ repo_name = input("Enter the Hugging Face repository name "
65
+ "(e.g., your_username/your_model_name): ")
66
 
67
+ # Create the repository if it doesn't exist
68
+ create_repo(repo_name, token=hf_token, exist_ok=True)
69
+
70
+ api = HfApi()
71
+ api.upload_folder(
72
+ folder_path=save_path,
73
+ repo_id=repo_name,
74
+ token=hf_token,
75
+ repo_type="model",
76
+ )
77
+ print(f"Model pushed successfully to Hugging Face Hub: {repo_name}")
78
+
79
+ except Exception as e:
80
+ print(f"Error saving/pushing the merged model: {e}")
81
+
82
+ def generate_and_save(base_model_id, lora_id, lora_adapter_name, prompt, lora_scale, save_path, push_to_hub, hf_token):
83
+ pipe = load_and_merge_lora(base_model_id, lora_id, lora_adapter_name)
84
 
85
+ if pipe:
86
+ lora_scale = float(lora_scale)
87
  image = pipe(
88
  prompt,
89
  num_inference_steps=30,
 
91
  generator=torch.manual_seed(0)
92
  ).images[0]
93
 
94
+ image.save("generated_image.png")
95
+ print(f"Image saved to: generated_image.png")
96
 
97
+ save_merged_model(pipe, save_path, push_to_hub, hf_token)
98
+
99
+ return image, "Image generated and model saved/pushed (if selected)."
100
+
101
+ iface = gr.Interface(
102
+ fn=generate_and_save,
103
+ inputs=[
104
+ gr.Textbox(label="Base Model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)"),
105
+ gr.Textbox(label="LoRA ID (e.g., your_username/your_lora)"),
106
+ gr.Textbox(label="LoRA Adapter Name"),
107
+ gr.Textbox(label="Prompt"),
108
+ gr.Slider(label="LoRA Scale", minimum=0.0, maximum=1.0, value=0.7, step=0.1),
109
+ gr.Textbox(label="Save Path"),
110
+ gr.Checkbox(label="Push to Hugging Face Hub"),
111
+ gr.Textbox(label="Hugging Face Write Token", type="password")
112
+ ],
113
+ outputs=[
114
+ gr.Image(label="Generated Image"),
115
+ gr.Textbox(label="Status")
116
+ ],
117
+ title="LoRA Merger and Image Generator",
118
+ description="Merge a LoRA with a base Stable Diffusion model and generate images."
119
+ )
120
+
121
+ iface.launch()