SDXL ControlNet: DWPose
This repository contains ControlNet weights trained on Stable Diffusion XL (SDXL) with DWPose conditioning.
Using in 🧨 diffusers
First, install the required libraries:
pip install -q easy-dwpose transformers accelerate
pip install -q git+https://github.com/huggingface/diffusers
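Before running the examples, it can help to confirm that the packages are visible from the active environment. The following is a minimal sketch using only the standard library; the helper name `missing_packages` is hypothetical, and the distribution names are assumed to match those used in the pip commands above.

```python
from importlib import metadata


def missing_packages(names):
    """Return the distribution names from `names` that are not installed."""
    missing = []
    for name in names:
        try:
            metadata.version(name)
        except metadata.PackageNotFoundError:
            missing.append(name)
    return missing


# Distribution names as they appear in the pip commands above.
required = ["easy-dwpose", "transformers", "accelerate", "diffusers"]

if __name__ == "__main__":
    missing = missing_packages(required)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are installed.")
```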
Example 1
To generate a realistic DJ with the following image driving the pose:
Run the following code:
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image
from easy_dwpose import DWposeDetector

pose_image = load_image("./images/pose_image_1.png")

# Load the DWPose detector.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dwpose = DWposeDetector(device=device)

# Compute the DWPose conditioning image (body, hands and face keypoints).
skeleton = dwpose(
    pose_image,
    detect_resolution=pose_image.width,
    output_type="pil",
    include_hands=True,
    include_face=True,
)

# Initialize the ControlNet pipeline.
controlnet = ControlNetModel.from_pretrained(
    "dimitribarbot/controlnet-dwpose-sdxl-1.0",
    torch_dtype=torch.float16,
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
    variant="fp16",
).to(device)

# Run inference.
prompt = "DJ in a party, shallow depth of field, highly detailed, high budget, gorgeous"
negative_prompt = "bad quality, blur, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured"
image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=50,
    guidance_scale=5,
    image=skeleton,
    generator=torch.manual_seed(97),
).images[0]

# Save the pose skeleton and the generated image.
skeleton.save("./images/dwpose_1.png")
image.save("./images/dwpose_image_1.png")
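The example above feeds the pose image to the detector at its native size. Since this checkpoint was trained at a resolution of 1024, it may be worth resizing inputs that are much smaller or larger first. The helper below is a sketch under that assumption: the name `sdxl_friendly_size`, the target area of 1024 × 1024 pixels and the rounding to multiples of 8 are illustrative choices, not part of the original examples.

```python
def sdxl_friendly_size(width, height, target_area=1024 * 1024, multiple=8):
    """Scale (width, height) to roughly `target_area` pixels, keeping the
    aspect ratio and rounding each side to a multiple of `multiple`."""
    if width <= 0 or height <= 0:
        raise ValueError("width and height must be positive")
    # Uniform scale factor that brings the pixel count close to target_area.
    scale = (target_area / (width * height)) ** 0.5
    new_width = max(multiple, round(width * scale / multiple) * multiple)
    new_height = max(multiple, round(height * scale / multiple) * multiple)
    return new_width, new_height


if __name__ == "__main__":
    print(sdxl_friendly_size(512, 512))    # (1024, 1024)
    print(sdxl_friendly_size(1920, 1080))  # (1368, 768)
```

With a PIL image, this could be applied before pose detection as `pose_image = pose_image.resize(sdxl_friendly_size(*pose_image.size))`.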
Generated pose is:
Image generated by SDXL is:
Example 2
To generate an anime version of a woman sitting on a bench with the following image driving the pose:
Run the following code:
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image
from easy_dwpose import DWposeDetector

pose_image = load_image("./images/pose_image_2.png")

# Load the DWPose detector.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dwpose = DWposeDetector(device=device)

# Compute the DWPose conditioning image (body, hands and face keypoints).
skeleton = dwpose(
    pose_image,
    detect_resolution=pose_image.width,
    output_type="pil",
    include_hands=True,
    include_face=True,
)

# Initialize the ControlNet pipeline.
controlnet = ControlNetModel.from_pretrained(
    "dimitribarbot/controlnet-dwpose-sdxl-1.0",
    torch_dtype=torch.float16,
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
    variant="fp16",
)

# Move the pipeline to the GPU when one is available.
if torch.cuda.is_available():
    pipe.to(torch.device("cuda"))

# Infer.
prompt = "Anime girl sitting on a bench, highly detailed, noon, ambient light"
negative_prompt = "bad quality, blur, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured"
image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=25,
    guidance_scale=18,
    image=skeleton,
    generator=torch.manual_seed(79),
).images[0]
skeleton.save("./images/dwpose_2.png")
image.save("./images/dwpose_image_2.png")
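Both examples use the pipeline's default ControlNet strength. To see how strictly the output follows the pose, one option is to run the same prompt at several values of `controlnet_conditioning_scale`, a call argument of `StableDiffusionXLControlNetPipeline`. The sketch below assumes `pipe`, `prompt` and `skeleton` are defined as in the example above; the helper name `sweep_conditioning_scale` and its `make_generator` parameter are hypothetical.

```python
def sweep_conditioning_scale(pipe, prompt, skeleton, scales, make_generator=None, **pipe_kwargs):
    """Run `pipe` once per ControlNet conditioning scale and collect the images.

    `make_generator` is an optional zero-argument callable that returns a fresh
    random generator, so that every run starts from the same seed.
    """
    results = {}
    for scale in scales:
        generator = make_generator() if make_generator is not None else None
        output = pipe(
            prompt,
            image=skeleton,
            controlnet_conditioning_scale=scale,
            generator=generator,
            **pipe_kwargs,
        )
        # Keep the first (and here only) image of each run, keyed by its scale.
        results[scale] = output.images[0]
    return results


# Possible usage with the objects from the example above (not executed here):
# images = sweep_conditioning_scale(
#     pipe, prompt, skeleton, scales=[0.5, 0.75, 1.0],
#     make_generator=lambda: torch.manual_seed(79),
#     negative_prompt=negative_prompt, num_inference_steps=25, guidance_scale=18,
# )
```

Lower values give the text prompt more freedom, while values near 1.0 keep the result closer to the detected skeleton.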
Generated pose is:
Image generated by SDXL is:
Training
The model was trained with the ControlNet training script from Hugging Face 🤗.
Training data
This checkpoint was trained for 15,000 steps on the dimitribarbot/dw_pose_controlnet dataset with a resolution of 1024.
Compute
One machine with a single A40 GPU, for 48 hours.
Batch size
Data parallel with a per-GPU batch size of 2 and 8 gradient accumulation steps.
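The effective batch size follows from these numbers. The short calculation below is only a worked example; the estimate of samples seen assumes that the 15,000 training steps are optimizer steps (one per accumulated batch), which the description above does not state explicitly.

```python
def effective_batch_size(per_gpu_batch_size, gradient_accumulation_steps, num_gpus=1):
    """Number of samples contributing to each optimizer update."""
    return per_gpu_batch_size * gradient_accumulation_steps * num_gpus


# Values from the Compute, Batch size and Training data sections.
batch = effective_batch_size(per_gpu_batch_size=2, gradient_accumulation_steps=8, num_gpus=1)
samples_seen = 15_000 * batch  # assumes 15,000 optimizer steps

print(batch)         # 16
print(samples_seen)  # 240000
```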
Hyper Parameters
Constant learning rate of 8e-5
Mixed precision
fp16
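For readers who want to run a comparable training job, the settings above can be mapped onto the command-line flags of the diffusers example script `train_controlnet_sdxl.py`. This is a sketch and not the exact command used for this checkpoint: the script name, the base model and the output directory are assumptions, and dataset-specific options (such as column names) are left out because they are not documented here.

```python
import shlex

# Settings taken from the Training section. Entries marked "assumption"
# are not stated in this document.
training_args = {
    "pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",  # assumption
    "dataset_name": "dimitribarbot/dw_pose_controlnet",
    "resolution": 1024,
    "max_train_steps": 15000,
    "train_batch_size": 2,
    "gradient_accumulation_steps": 8,
    "learning_rate": 8e-5,
    "lr_scheduler": "constant",
    "mixed_precision": "fp16",
    "output_dir": "./controlnet-dwpose-sdxl",  # assumption: placeholder path
}

# Build the argument list for `accelerate launch`.
command = ["accelerate", "launch", "train_controlnet_sdxl.py"]
for flag, value in training_args.items():
    command += [f"--{flag}", str(value)]

# Print a shell-safe version of the command for inspection.
print(shlex.join(command))
```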
Thanks
StabilityAI SDXL: for the SDXL model.
IDEA Research DWPose: for the DWPose model.
Hugging Face: for the ControlNet training script 🤗 and libraries.
raulc0399: for inspiring the creation of the DWPose dataset, which is based on the Openpose dataset.
thibaud: for inspiring the hyperparameters used with the HF training script, which are based on the Openpose ControlNet.
RedHash: for the easy_dwpose module, which greatly simplifies DWPose inference and is used in the examples above.