Tech-Meld commited on
Commit
d1d1942
·
verified ·
1 Parent(s): b86eb9f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
3
+ from tqdm.auto import tqdm
4
+ from huggingface_hub import cached_download, hf_hub_url
5
+ import os
6
+
7
+ def display_image(image):
8
+ """
9
+ Replace this with your actual image display logic.
10
+ """
11
+ image.show()
12
+
13
+ def load_and_merge_lora(base_model_id, lora_id, lora_weight_name, lora_adapter_name):
14
+ try:
15
+ pipe = DiffusionPipeline.from_pretrained(
16
+ base_model_id,
17
+ torch_dtype=torch.float16,
18
+ scheduler=DPMSolverMultistepScheduler.from_config(
19
+ pipe.scheduler.config),
20
+ variant="fp16",
21
+ use_safetensors=True,
22
+ ).to("cuda")
23
+
24
+ lora_url = hf_hub_url(lora_id, revision="main", filename=lora_weight_name)
25
+ lora_path = cached_download(lora_url)
26
+
27
+ with tqdm(desc="Loading LoRA weights", unit="step") as pbar:
28
+ pipe.load_lora_weights(
29
+ lora_path,
30
+ weight_name=lora_weight_name,
31
+ adapter_name=lora_adapter_name,
32
+ progress_callback=lambda step, max_steps: pbar.update(1)
33
+ )
34
+
35
+ print("LoRA merged successfully!")
36
+ return pipe
37
+ except Exception as e:
38
+ print(f"Error merging LoRA: {e}")
39
+ return None
40
+
41
+ def save_merged_model(pipe, save_path):
42
+ """Saves the merged model to the specified path."""
43
+ try:
44
+ pipe.save_pretrained(save_path)
45
+ print(f"Merged model saved successfully to: {save_path}")
46
+ except Exception as e:
47
+ print(f"Error saving the merged model: {e}")
48
+
49
+ if __name__ == "__main__":
50
+ base_model_id = input("Enter the base model ID: ")
51
+ lora_id = input("Enter the LoRA Hugging Face Hub ID: ")
52
+ lora_weight_name = input("Enter the LoRA weight file name: ")
53
+ lora_adapter_name = input("Enter the LoRA adapter name: ")
54
+
55
+ pipe = load_and_merge_lora(base_model_id, lora_id, lora_weight_name, lora_adapter_name)
56
+
57
+ if pipe:
58
+ prompt = input("Enter your prompt: ")
59
+ lora_scale = float(input("Enter the LoRA scale (e.g., 0.9): "))
60
+
61
+ image = pipe(
62
+ prompt,
63
+ num_inference_steps=30,
64
+ cross_attention_kwargs={"scale": lora_scale},
65
+ generator=torch.manual_seed(0)
66
+ ).images[0]
67
+
68
+ display_image(image)
69
+
70
+ # Ask the user for a directory to save the model
71
+ save_path = input(
72
+ "Enter the directory where you want to save the merged model: "
73
+ )
74
+ save_merged_model(pipe, save_path)