SDXL ControlNet: DWPose

Here are the ControlNet weights trained on SDXL with DWPose conditioning.

Using in 🧨 diffusers

First, install all the libraries:

pip install -q easy-dwpose transformers accelerate
pip install -q git+https://github.com/huggingface/diffusers

Example 1

To generate a realistic image of a DJ, with the following image driving the pose:

Pose Image 1

Run the following code:

from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
import torch
from diffusers.utils import load_image

from easy_dwpose import DWposeDetector


pose_image = load_image("./images/pose_image_1.png")

# Load detector
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dwpose = DWposeDetector(device=device)

# Compute DWpose conditioning image.
skeleton = dwpose(
    pose_image,
    detect_resolution=pose_image.width,
    output_type="pil",
    include_hands=True,
    include_face=True,
)

# Initialize ControlNet pipeline.
controlnet = ControlNetModel.from_pretrained(
    "dimitribarbot/controlnet-dwpose-sdxl-1.0",
    torch_dtype=torch.float16,
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
    variant="fp16",
).to(device)

# Infer.
prompt = "DJ in a party, shallow depth of field, highly detailed, high budget, gorgeous"
negative_prompt = "bad quality, blur, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured"
image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=50,
    guidance_scale=5,
    image=skeleton,
    generator=torch.manual_seed(97),
).images[0]

skeleton.save("./images/dwpose_1.png")
image.save("./images/dwpose_image_1.png")

The detected pose skeleton is:

Pose 1

The image generated by SDXL is:

Pose 1

Example 2

To generate an anime version of a woman sitting on a bench, with the following image driving the pose:

Pose Image 2

Run the following code:

from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
import torch
from diffusers.utils import load_image

from easy_dwpose import DWposeDetector


pose_image = load_image("./images/pose_image_2.png")

# Load detector
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dwpose = DWposeDetector(device=device)

# Compute DWpose conditioning image.
skeleton = dwpose(
    pose_image,
    detect_resolution=pose_image.width,
    output_type="pil",
    include_hands=True,
    include_face=True,
)

# Initialize ControlNet pipeline.
controlnet = ControlNetModel.from_pretrained(
    "dimitribarbot/controlnet-dwpose-sdxl-1.0",
    torch_dtype=torch.float16,
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
    variant="fp16",
).to(device)

# Infer.
prompt = "Anime girl sitting on a bench, highly detailed, noon, ambiant light"
negative_prompt = "bad quality, blur, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured"
image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=25,
    guidance_scale=18,
    image=skeleton,
    generator=torch.manual_seed(79),
).images[0]

skeleton.save("./images/dwpose_2.png")
image.save("./images/dwpose_image_2.png")

The detected pose skeleton is:

Pose 2

The image generated by SDXL is:

Pose 2
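
Both examples pass the pose image to the detector and the pipeline at its original size. Since this checkpoint was trained at a resolution of 1024, it may be worth resizing very large or very small inputs first. The sketch below is not part of the original examples: the helper name `sdxl_friendly_size`, the choice to scale the longer side to 1024, and the rounding to multiples of 8 are all assumptions made for illustration.

```python
from __future__ import annotations


def sdxl_friendly_size(
    width: int, height: int, target: int = 1024, multiple: int = 8
) -> tuple[int, int]:
    """Hypothetical helper: scale (width, height) so the longer side is
    about `target` pixels, keeping the aspect ratio and rounding each
    side to the nearest multiple of `multiple`."""
    if width <= 0 or height <= 0:
        raise ValueError("width and height must be positive")
    scale = target / max(width, height)
    # Round each side to the nearest multiple, never below one multiple.
    new_width = max(multiple, round(width * scale / multiple) * multiple)
    new_height = max(multiple, round(height * scale / multiple) * multiple)
    return new_width, new_height


print(sdxl_friendly_size(1920, 1080))
```

The returned tuple could then be passed to `pose_image.resize(...)` before calling the detector.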

Training

The model was trained with the Hugging Face 🤗 ControlNet training script.

Training data

This checkpoint was trained for 15,000 steps on the dimitribarbot/dw_pose_controlnet dataset with a resolution of 1024.

Compute

One machine with a single A40 GPU, for 48 hours.

Batch size

Data parallel with a per-GPU batch size of 2 and 8 gradient accumulation steps.
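
The arithmetic behind these settings can be sketched as follows. The effective batch size follows directly from the values in this card. The sample count additionally assumes that each of the 15,000 training steps corresponds to one optimizer update, which is not stated here.

```python
# Values stated in this model card.
per_gpu_batch_size = 2
gradient_accumulation_steps = 8
num_gpus = 1  # one A40
train_steps = 15_000

# Samples contributing to each optimizer update.
effective_batch_size = per_gpu_batch_size * gradient_accumulation_steps * num_gpus

# Assumption: one training step equals one optimizer update.
samples_seen = effective_batch_size * train_steps

print(effective_batch_size, samples_seen)
```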

Hyper Parameters

Constant learning rate of 8e-5

Mixed precision

fp16
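
For orientation, the settings in this section map roughly onto command-line flags of the diffusers `train_controlnet_sdxl.py` example script, as sketched below. This is a reconstruction from the values in this card, not the exact command that was used. The script name, the base model, the `output_dir` value, and the absence of any other flags (dataset column names, validation settings, optimizer options) are assumptions.

```python
import shlex

# Hyperparameters stated in this model card, keyed by flag name.
settings = {
    # Assumption: same SDXL base model as in the inference examples.
    "pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",
    "dataset_name": "dimitribarbot/dw_pose_controlnet",
    "resolution": 1024,
    "learning_rate": 8e-5,
    "lr_scheduler": "constant",
    "max_train_steps": 15000,
    "train_batch_size": 2,
    "gradient_accumulation_steps": 8,
    "mixed_precision": "fp16",
    # Assumption: placeholder output directory, not taken from the card.
    "output_dir": "./controlnet-dwpose-sdxl",
}

# Assumption: the script is the diffusers example train_controlnet_sdxl.py.
command = ["accelerate", "launch", "train_controlnet_sdxl.py"]
for flag, value in settings.items():
    command += [f"--{flag}", str(value)]

# Print a shell-quoted version of the command for copy and paste.
print(shlex.join(command))
```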

Thanks

StabilityAI SDXL: for the SDXL model.

IDEA Research DWPose: for the DWPose model.

Hugging Face: for the ControlNet training script 🤗 and libraries.

raulc0399: for inspiring me to create the DWPose dataset, which is based on the OpenPose dataset.

thibaud: for inspiring the hyperparameters used with the HF training script, which are based on the OpenPose ControlNet.

RedHash: for the easy_dwpose module, which greatly simplifies DWPose inference and which I used in the examples above.
