Skip to content

Commit

Permalink
Adds default params for controlnet models
Browse files Browse the repository at this point in the history
  • Loading branch information
Vikram Voleti committed Nov 26, 2024
1 parent 123558f commit 106db06
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 36 deletions.
30 changes: 19 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,18 @@ Download the following models from HuggingFace into `models` directory:
This code also works for [Stability AI SD3 Medium](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium.safetensors).

#### ControlNets
Optionally, download [SD3.5 ControlNets](https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets):
(a) [Blur ControlNet](https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets/resolve/main/sd3.5_large_controlnet_blur.safetensors)
(b) [Canny ControlNet](https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets/resolve/main/sd3.5_large_controlnet_canny.safetensors)
(c) [Depth ControlNet](https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets/resolve/main/sd3.5_large_controlnet_depth.safetensors)
For example:

Optionally, download [SD3.5 Large ControlNets](https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets):
(a) [Blur ControlNet](https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets/resolve/main/blur_8b.safetensors)
(b) [Canny ControlNet](https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets/resolve/main/canny_8b.safetensors)
(c) [Depth ControlNet](https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets/resolve/main/depth_8b.safetensors)

```py
from huggingface_hub import hf_hub_download
hf_hub_download("stabilityai/stable-diffusion-3.5-controlnets", "sd3.5_large_controlnet_blur.safetensors", local_dir="models")
hf_hub_download("stabilityai/stable-diffusion-3.5-controlnets", "sd3.5_large_controlnet_canny.safetensors", local_dir="models")
hf_hub_download("stabilityai/stable-diffusion-3.5-controlnets", "sd3.5_large_controlnet_depth.safetensors", local_dir="models")
```
<!-- ```sh
wget -O models/sd3.5_large_controlnet_canny.safetensors https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets/resolve/main/sd3.5_large_controlnet_canny.safetensors
```
or -->

### Install

Expand Down Expand Up @@ -77,9 +76,18 @@ python3 sd3_infer.py --prompt path/to/my_prompts.txt --model models/sd3.5_medium

#### ControlNets

To use ControlNets, also download your chosen ControlNet model from the [model repository](https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets), then run inference, like so:
To use SD3.5 Large ControlNets, additionally download your chosen ControlNet model from the [model repository](https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets), then run inference, like so:
(a) Blur:
```sh
python sd3_infer.py --model models/sd3.5_large.safetensors --controlnet_ckpt models/sd3.5_large_controlnet_blur.safetensors --controlnet_cond_image inputs/blur.png --prompt "generated ai art, a tiny, lost rubber ducky in an action shot close-up, surfing the humongous waves, inside the tube, in the style of Kelly Slater"
```
(b) Canny:
```sh
python sd3_infer.py --model models/sd3.5_large.safetensors --controlnet_ckpt models/sd3.5_large_controlnet_canny.safetensors --controlnet_cond_image inputs/canny.png --prompt "A Night time photo taken by Leica M11, portrait of a Japanese woman in a kimono, looking at the camera, Cherry blossoms"
```
(c) Depth:
```sh
python sd3_infer.py --controlnet_ckpt models/sd3.5_large_controlnet_canny.safetensors --controlnet_cond_image inputs/canny.png --prompt "A Night time photo taken by Leica M11, portrait of a Japanese woman in a loose and half worn kimono showing her bare shoulders, looking at the camera, Cherry blossoms"
python sd3_infer.py --model models/sd3.5_large.safetensors --controlnet_ckpt models/sd3.5_large_controlnet_depth.safetensors --controlnet_cond_image inputs/depth.png --prompt "photo of woman, presumably in her mid-thirties, striking a balanced yoga pose on a rocky outcrop during dusk or dawn. She wears a light gray t-shirt and dark leggings. Her pose is dynamic, with one leg extended backward and the other bent at the knee, holding the moon close to her hand."
```

For details on preprocessing for each of the ControlNets, and examples, please review the [model card](https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets).
Expand Down
71 changes: 46 additions & 25 deletions sd3_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,19 @@

import fire
import numpy as np
import sd3_impls
import torch
from other_impls import SD3Tokenizer, SDClipModel, SDXLClipG, T5XXLModel
from PIL import Image
from safetensors import safe_open
from tqdm import tqdm

import sd3_impls
from other_impls import SD3Tokenizer, SDClipModel, SDXLClipG, T5XXLModel
from sd3_impls import (
SDVAE,
BaseModel,
CFGDenoiser,
SD3LatentFormat,
SkipLayerCFGDenoiser,
)
from tqdm import tqdm

#################################################################################################
### Wrappers for model parts
Expand Down Expand Up @@ -236,7 +235,7 @@ def __init__(self, model, dtype: torch.dtype = torch.float16):
# ControlNet
CONTROLNET_COND_IMAGE = None
# If init_image is given, this is the fraction of denoising steps to run (1.0 = full denoise, 0.0 = no denoise at all)
DENOISE = 0.6
DENOISE = 0.8
# Output file path
OUTDIR = "outputs"
# SAMPLER
Expand Down Expand Up @@ -396,7 +395,7 @@ def vae_encode(
image_torch = torch.from_numpy(batch_images).cuda()
if using_2b_controlnet:
image_torch = image_torch * 2.0 - 1.0
elif controlnet_type == 1: # canny
elif controlnet_type == 1: # canny
image_torch = image_torch * 255 * 0.5 + 0.5
else:
image_torch = 2.0 * image_torch - 1.0
Expand Down Expand Up @@ -503,14 +502,14 @@ def gen_image(
CONFIGS = {
"sd3_medium": {
"shift": 1.0,
"cfg": 5.0,
"steps": 50,
"cfg": 5.0,
"sampler": "dpmpp_2m",
},
"sd3.5_medium": {
"shift": 3.0,
"cfg": 5.0,
"steps": 50,
"cfg": 5.0,
"sampler": "dpmpp_2m",
"skip_layer_config": {
"scale": 2.5,
Expand All @@ -522,11 +521,29 @@ def gen_image(
},
"sd3.5_large": {
"shift": 3.0,
"cfg": 4.5,
"steps": 40,
"cfg": 4.5,
"sampler": "dpmpp_2m",
},
"sd3.5_large_turbo": {"shift": 3.0, "cfg": 1.0, "steps": 4, "sampler": "euler"},
"sd3.5_large_controlnet_blur": {
"shift": 3.0,
"steps": 60,
"cfg": 3.5,
"sampler": "euler",
},
"sd3.5_large_controlnet_canny": {
"shift": 3.0,
"steps": 60,
"cfg": 3.5,
"sampler": "euler",
},
"sd3.5_large_controlnet_depth": {
"shift": 3.0,
"steps": 60,
"cfg": 3.5,
"sampler": "euler",
},
}


Expand Down Expand Up @@ -556,18 +573,13 @@ def main(
**kwargs,
):
assert not kwargs, f"Unknown arguments: {kwargs}"
steps = steps or CONFIGS.get(os.path.splitext(os.path.basename(model))[0], {}).get(
"steps", 50
)
cfg = cfg or CONFIGS.get(os.path.splitext(os.path.basename(model))[0], {}).get(
"cfg", 5
)
shift = shift or CONFIGS.get(os.path.splitext(os.path.basename(model))[0], {}).get(
"shift", 3
)
sampler = sampler or CONFIGS.get(
os.path.splitext(os.path.basename(model))[0], {}
).get("sampler", "dpmpp_2m")

config = CONFIGS.get(os.path.splitext(os.path.basename(model))[0], {})
_shift = shift or config.get("shift", 3)
_steps = steps or config.get("steps", 50)
_cfg = cfg or config.get("cfg", 5)
_sampler = sampler or config.get("sampler", "dpmpp_2m")

if skip_layer_cfg:
skip_layer_config = CONFIGS.get(
os.path.splitext(os.path.basename(model))[0], {}
Expand All @@ -576,12 +588,21 @@ def main(
else:
skip_layer_config = {}

if controlnet_ckpt is not None:
controlnet_config = CONFIGS.get(
os.path.splitext(os.path.basename(controlnet_ckpt))[0], {}
)
_shift = shift or controlnet_config.get("shift", shift)
_steps = steps or controlnet_config.get("steps", steps)
_cfg = cfg or controlnet_config.get("cfg", cfg)
_sampler = sampler or controlnet_config.get("sampler", sampler)

inferencer = SD3Inferencer()

inferencer.load(
model,
vae,
shift,
_shift,
controlnet_ckpt,
model_folder,
text_encoder_device,
Expand Down Expand Up @@ -616,9 +637,9 @@ def main(
prompts,
width,
height,
steps,
cfg,
sampler,
_steps,
_cfg,
_sampler,
seed,
seed_type,
out_dir,
Expand Down

0 comments on commit 106db06

Please sign in to comment.