Skip to content

Commit

Permalink
Add controlnet and controlled unet
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetters94 committed Feb 1, 2024
1 parent c1dc94c commit b078e6e
Show file tree
Hide file tree
Showing 2 changed files with 329 additions and 5 deletions.
178 changes: 178 additions & 0 deletions python/turbine_models/custom_models/sd_inference/controlnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
import sys

from iree import runtime as ireert
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from turbine_models.custom_models.sd_inference import utils
import torch
import torch._dynamo as dynamo
from diffusers import ControlNetModel as CNetModel

import safetensors
import argparse
import re

def _build_parser():
    """Create the command-line parser for the ControlNet export script.

    Options are registered in the order listed in ``option_table`` so the
    generated ``--help`` output keeps the same ordering.
    """
    cli = argparse.ArgumentParser()
    # Each entry is (flag, keyword arguments forwarded to add_argument).
    option_table = (
        (
            "--hf_auth_token",
            {"type": str, "help": "The Hugging Face auth token, required"},
        ),
        (
            "--hf_model_name",
            {
                "type": str,
                "help": "HF model name",
                "default": "lllyasviel/control_v11p_sd15_canny",
            },
        ),
        (
            "--batch_size",
            {"type": int, "default": 1, "help": "Batch size for inference"},
        ),
        (
            "--height",
            {"type": int, "default": 512, "help": "Height of Stable Diffusion"},
        ),
        (
            "--width",
            {"type": int, "default": 512, "help": "Width of Stable Diffusion"},
        ),
        ("--compile_to", {"type": str, "help": "torch, linalg, vmfb"}),
        ("--external_weight_path", {"type": str, "default": ""}),
        (
            "--external_weights",
            {
                "type": str,
                "default": None,
                "help": "saves ir without global weights for size and readability, options [safetensors]",
            },
        ),
        (
            "--device",
            {"type": str, "default": "cpu", "help": "cpu, cuda, vulkan, rocm"},
        ),
        # TODO: Bring in detection for target triple
        (
            "--iree_target_triple",
            {
                "type": str,
                "default": "",
                "help": "Specify vulkan target triple or rocm/cuda target device.",
            },
        ),
        ("--vulkan_max_allocation", {"type": str, "default": "4294967296"}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli


parser = _build_parser()


class ControlNetModel(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper around a pretrained diffusers ControlNet.

    The wrapped network is loaded with ``CNetModel.from_pretrained`` and the
    module is switched out of training mode on construction.
    """

    def __init__(
        self, model_id="lllyasviel/control_v11p_sd15_canny", low_cpu_mem_usage=False
    ):
        """Load the ControlNet weights identified by ``model_id``.

        Args:
            model_id: Identifier passed to ``CNetModel.from_pretrained``.
            low_cpu_mem_usage: Forwarded unchanged to ``from_pretrained``.
        """
        super().__init__()
        self.cnet = CNetModel.from_pretrained(
            model_id, low_cpu_mem_usage=low_cpu_mem_usage
        )
        self.in_channels = self.cnet.config.in_channels
        # Same effect as train(False): put the module in evaluation mode.
        self.eval()

    def forward(
        self,
        latent,
        timestep,
        text_embedding,
        stencil_image_input,
    ):
        """Run the ControlNet and return its residuals as one flat tuple.

        ``latent`` and ``stencil_image_input`` are each concatenated with
        themselves before being handed to the wrapped network.

        Returns:
            A tuple holding every down-block residual followed by the
            mid-block residual as the last element.
        """
        # expand the latents if we are doing classifier-free guidance to avoid
        # doing two forward passes.
        # TODO: guidance NOT NEEDED change in `get_input_info` later
        # Both inputs need to be duplicated the same way as the
        # controlledUNET latents.
        doubled_latents = torch.cat((latent, latent))
        doubled_stencil = torch.cat((stencil_image_input, stencil_image_input))
        down_residuals, mid_residual = self.cnet.forward(
            doubled_latents,
            timestep,
            encoder_hidden_states=text_embedding,
            controlnet_cond=doubled_stencil,
            return_dict=False,
        )
        return (*down_residuals, mid_residual)


def export_controlnet_model(
    controlnet_model,
    hf_model_name,
    batch_size,
    height,
    width,
    hf_auth_token=None,
    compile_to="torch",
    external_weights=None,
    external_weight_path=None,
    device=None,
    target_triple=None,
    max_alloc=None,
):
    """Export ``controlnet_model`` to MLIR text, optionally compiling it.

    Args:
        controlnet_model: Module whose ``forward`` is traced for export.
        hf_model_name: Model name; its last ``/``-separated component (with
            ``-`` replaced by ``_``) names the compiled artifact.
        batch_size, height, width, hf_auth_token: Accepted but not read by
            this function body.
        compile_to: ``"linalg"`` selects the ``"INPUT"`` import stage, any
            other value ``"IMPORT"``; ``"vmfb"`` additionally compiles.
        external_weights: When truthy, parameters are exported as external.
        external_weight_path: Passed to ``utils.save_external_weights``.
        device, target_triple, max_alloc: Forwarded to
            ``utils.compile_to_vmfb``.

    Returns:
        The MLIR module as a string, or ``None`` when ``compile_to`` is
        ``"vmfb"`` (the module is handed to ``utils.compile_to_vmfb`` instead).
    """
    mapper = {}
    utils.save_external_weights(
        mapper, controlnet_model, external_weights, external_weight_path
    )

    # Extra keyword arguments are only needed for externalized weights.
    if external_weights:
        parameter_options = {
            "external": True,
            "external_scope": "",
            "name_mapper": mapper.get,
        }
    else:
        parameter_options = {}

    class CompiledControlnet(CompiledModule):
        params = export_parameters(controlnet_model, **parameter_options)

        # NOTE(review): the input shapes below are fixed literals; they do not
        # depend on batch_size, height or width - confirm this is intended.
        def main(
            self,
            latent=AbstractTensor(1, 4, 512, 512, dtype=torch.float32),
            timestep=AbstractTensor(1, dtype=torch.float32),
            text_embedding=AbstractTensor(2, 72, 768, dtype=torch.float32),
            stencil_image_input=AbstractTensor(1, 3, 4096, 4096, dtype=torch.float32),
        ):
            traced_forward = jittable(controlnet_model.forward)
            return traced_forward(
                latent, timestep, text_embedding, stencil_image_input
            )

    if compile_to == "linalg":
        import_to = "INPUT"
    else:
        import_to = "IMPORT"
    inst = CompiledControlnet(context=Context(), import_to=import_to)
    module_str = str(CompiledModule.get_mlir_module(inst))

    if compile_to == "vmfb":
        model_basename = hf_model_name.split("/")[-1].strip()
        safe_name = model_basename.replace("-", "_")
        utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)
        return None
    return module_str


if __name__ == "__main__":
    # Script entry point: load the ControlNet, export it, and save the MLIR.
    args = parser.parse_args()
    controlnet_model = ControlNetModel(
        args.hf_model_name,
    )
    mod_str = export_controlnet_model(
        controlnet_model,
        args.hf_model_name,
        args.batch_size,
        args.height,
        args.width,
        args.hf_auth_token,
        args.compile_to,
        args.external_weights,
        args.external_weight_path,
        args.device,
        args.iree_target_triple,
        args.vulkan_max_allocation,
    )

    # export_controlnet_model returns the MLIR text unless it compiled to
    # vmfb, in which case it returns None and there is nothing to write.
    # (The previous `is None` check was inverted: it skipped saving real
    # output and tried to write None on the vmfb path.)
    if mod_str is not None:
        safe_name = args.hf_model_name.split("/")[-1].strip()
        safe_name = re.sub("-", "_", safe_name)
        with open(f"{safe_name}.mlir", "w+") as f:
            f.write(mod_str)
        print("Saved to", safe_name + ".mlir")
156 changes: 151 additions & 5 deletions python/turbine_models/custom_models/sd_inference/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,26 @@
help="Specify vulkan target triple or rocm/cuda target device.",
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")
# Paired on/off switches writing to the same `controlled` destination;
# the default (neither flag given) is False.
parser.add_argument(
    "--controlled",
    action="store_true",
    dest="controlled",
    help="Whether or not to use controlled unet (for use with controlnet)",
)
parser.add_argument(
    "--no-controlled",
    action="store_false",
    dest="controlled",
    help="Whether or not to use controlled unet (for use with controlnet)",
)
parser.set_defaults(controlled=False)


class UnetModel(torch.nn.Module):
def __init__(self, hf_model_name, hf_auth_token):
def __init__(self, hf_model_name, hf_auth_token, is_controlled):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
hf_model_name,
subfolder="unet",
token=hf_auth_token,
)
self.guidance_scale = 7.5
if is_controlled:
self.forward = self.forward_controlled
else:
self.forward = self.forward_default

def forward(self, sample, timestep, encoder_hidden_states):
def forward_default(self, sample, timestep, encoder_hidden_states):
samples = torch.cat([sample] * 2)
unet_out = self.unet.forward(
samples, timestep, encoder_hidden_states, return_dict=False
Expand All @@ -76,6 +83,65 @@ def forward(self, sample, timestep, encoder_hidden_states):
)
return noise_pred

def forward_controlled(
    self,
    sample,
    timestep,
    encoder_hidden_states,
    control1,
    control2,
    control3,
    control4,
    control5,
    control6,
    control7,
    control8,
    control9,
    control10,
    control11,
    control12,
    control13,
    scale1,
    scale2,
    scale3,
    scale4,
    scale5,
    scale6,
    scale7,
    scale8,
    scale9,
    scale10,
    scale11,
    scale12,
    scale13,
):
    """Run the UNet with scaled control residuals and combine the two halves.

    Each ``controlN`` is multiplied by its matching ``scaleN``. The first
    twelve products are passed as ``down_block_additional_residuals`` and the
    thirteenth as ``mid_block_additional_residual``. ``sample`` is
    concatenated with itself before the UNet call, and the two chunks of the
    output are blended using ``self.guidance_scale``.

    Returns:
        ``uncond + guidance_scale * (text - uncond)`` where ``uncond`` and
        ``text`` are the first and second chunk of the UNet output.
    """
    down_controls = (
        control1, control2, control3, control4, control5, control6,
        control7, control8, control9, control10, control11, control12,
    )
    down_scales = (
        scale1, scale2, scale3, scale4, scale5, scale6,
        scale7, scale8, scale9, scale10, scale11, scale12,
    )
    scaled_down_residuals = tuple(
        [ctrl * factor for ctrl, factor in zip(down_controls, down_scales)]
    )
    scaled_mid_residual = control13 * scale13

    doubled_sample = torch.cat((sample, sample))
    unet_outputs = self.unet.forward(
        doubled_sample,
        timestep,
        encoder_hidden_states,
        down_block_additional_residuals=scaled_down_residuals,
        mid_block_additional_residual=scaled_mid_residual,
        return_dict=False,
    )
    prediction = unet_outputs[0]
    uncond_part, text_part = prediction.chunk(2)
    return uncond_part + self.guidance_scale * (text_part - uncond_part)


def export_unet_model(
unet_model,
Expand All @@ -90,6 +156,7 @@ def export_unet_model(
device=None,
target_triple=None,
max_alloc=None,
is_controlled=False,
):
mapper = {}
utils.save_external_weights(
Expand All @@ -100,7 +167,7 @@ def export_unet_model(
if hf_model_name == "stabilityai/stable-diffusion-2-1-base":
encoder_hidden_states_sizes = (2, 77, 1024)

sample = (batch_size, unet_model.unet.config.in_channels, height // 8, width // 8)
sample = (batch_size, unet_model.unet.config.in_channels, height, width)

class CompiledUnet(CompiledModule):
if external_weights:
Expand All @@ -120,8 +187,85 @@ def main(
):
return jittable(unet_model.forward)(sample, timestep, encoder_hidden_states)

# Export wrapper for the controlled UNet variant: exposes a single `main`
# entry point that takes the regular UNet inputs plus 13 control tensors and
# 13 scale tensors and forwards all of them to `unet_model.forward`.
class CompiledControlledUnet(CompiledModule):
# When external weights are requested, parameters are exported with
# external=True and named through `mapper.get`; otherwise the plain
# export_parameters call is used.
if external_weights:
params = export_parameters(
unet_model, external=True, external_scope="", name_mapper=mapper.get
)
else:
params = export_parameters(unet_model)

def main(
self,
sample=AbstractTensor(*sample, dtype=torch.float32),
timestep=AbstractTensor(1, dtype=torch.float32),
encoder_hidden_states=AbstractTensor(
*encoder_hidden_states_sizes, dtype=torch.float32
),
# Control tensors: the leading dimension is the literal 2 (it is not
# derived from batch_size). Spatial dims step down from height x width
# to height//8 x width//8 while channels go 320 -> 640 -> 1280.
control1=AbstractTensor(2, 320, height, width, dtype=torch.float32),
control2=AbstractTensor(2, 320, height, width, dtype=torch.float32),
control3=AbstractTensor(2, 320, height, width, dtype=torch.float32),
control4=AbstractTensor(2, 320, height//2, width//2, dtype=torch.float32),
control5=AbstractTensor(2, 640, height//2, width//2, dtype=torch.float32),
control6=AbstractTensor(2, 640, height//2, width//2, dtype=torch.float32),
control7=AbstractTensor(2, 640, height//4, width//4, dtype=torch.float32),
control8=AbstractTensor(2, 1280, height//4, width//4, dtype=torch.float32),
control9=AbstractTensor(2, 1280, height//4, width//4, dtype=torch.float32),
control10=AbstractTensor(2, 1280, height//8, width//8, dtype=torch.float32),
control11=AbstractTensor(2, 1280, height//8, width//8, dtype=torch.float32),
control12=AbstractTensor(2, 1280, height//8, width//8, dtype=torch.float32),
control13=AbstractTensor(2, 1280, height//8, width//8, dtype=torch.float32),
# One single-element float32 scale per control tensor.
scale1=AbstractTensor(1, dtype=torch.float32),
scale2=AbstractTensor(1, dtype=torch.float32),
scale3=AbstractTensor(1, dtype=torch.float32),
scale4=AbstractTensor(1, dtype=torch.float32),
scale5=AbstractTensor(1, dtype=torch.float32),
scale6=AbstractTensor(1, dtype=torch.float32),
scale7=AbstractTensor(1, dtype=torch.float32),
scale8=AbstractTensor(1, dtype=torch.float32),
scale9=AbstractTensor(1, dtype=torch.float32),
scale10=AbstractTensor(1, dtype=torch.float32),
scale11=AbstractTensor(1, dtype=torch.float32),
scale12=AbstractTensor(1, dtype=torch.float32),
scale13=AbstractTensor(1, dtype=torch.float32),
):
# Arguments are forwarded positionally, in the same order as declared.
# NOTE(review): presumably unet_model.forward is forward_controlled here
# (model built with is_controlled=True) - confirm against the caller.
return jittable(unet_model.forward)(
sample,
timestep,
encoder_hidden_states,
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
control13,
scale1,
scale2,
scale3,
scale4,
scale5,
scale6,
scale7,
scale8,
scale9,
scale10,
scale11,
scale12,
scale13,
)

import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
inst = CompiledUnet(context=Context(), import_to=import_to)
if is_controlled:
inst = CompiledControlledUnet(context=Context(), import_to=import_to)
else:
inst = CompiledUnet(context=Context(), import_to=import_to)

module_str = str(CompiledModule.get_mlir_module(inst))
safe_name = utils.create_safe_name(hf_model_name, "-unet")
Expand All @@ -134,8 +278,9 @@ def main(
if __name__ == "__main__":
args = parser.parse_args()
unet_model = UnetModel(
args.hf_model_name,
args.hf_model_name if not args.controlled else "CompVis/stable-diffusion-v1-4",
args.hf_auth_token,
args.controlled,
)
mod_str = export_unet_model(
unet_model,
Expand All @@ -150,6 +295,7 @@ def main(
args.device,
args.iree_target_triple,
args.vulkan_max_allocation,
args.controlled,
)
safe_name = utils.create_safe_name(args.hf_model_name, "-unet")
with open(f"{safe_name}.mlir", "w+") as f:
Expand Down

0 comments on commit b078e6e

Please sign in to comment.