Skip to content

Commit

Permalink
Image2image - python (#115)
Browse files Browse the repository at this point in the history
* Add Encoder model to torch2coreml for image2image

and later for in-paining

* diagonal test with randn

* Revert "diagonal test with randn"

This reverts commit 270afe1.

* readme updates for encoder

* pr comments
  • Loading branch information
littleowl authored Jan 31, 2023
1 parent 6cd5c7a commit 086cc5e
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 2 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ Both of these products require the Core ML models and tokenization resources to
- `vocab.json` (tokenizer vocabulary file)
- `merges.text` (merges for byte pair encoding file)

Optionally, for image2image, in-painting, or similar:

- `VAEEncoder.mlmodelc` (image encoder model)

Optionally, it may also include the safety checker model that some versions of Stable Diffusion include:

- `SafetyChecker.mlmodelc`
Expand Down Expand Up @@ -321,6 +325,7 @@ Differences may be less or more pronounced for different inputs. Please see the
<b> A3: </b> In order to minimize the memory impact of the model conversion process, please execute the following command instead:

```bash
python -m python_coreml_stable_diffusion.torch2coreml --convert-vae-encoder -o <output-mlpackages-directory> && \
python -m python_coreml_stable_diffusion.torch2coreml --convert-vae-decoder -o <output-mlpackages-directory> && \
python -m python_coreml_stable_diffusion.torch2coreml --convert-unet -o <output-mlpackages-directory> && \
python -m python_coreml_stable_diffusion.torch2coreml --convert-text-encoder -o <output-mlpackages-directory> && \
Expand Down
179 changes: 177 additions & 2 deletions python_coreml_stable_diffusion/torch2coreml.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,21 @@ def _get_coreml_inputs(sample_inputs, args):
) for k, v in sample_inputs.items()
]

# Simpler version of `DiagonalGaussianDistribution` with only needed calculations
# as implemented in vae.py as part of the AutoencoderKL class
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/vae.py#L312
# coremltools-6.1 does not yet implement the randn operation with the option of setting a random seed
class CoreMLDiagonalGaussianDistribution(object):
def __init__(self, parameters, noise):
self.parameters = parameters
self.noise = noise
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.std = torch.exp(0.5 * self.logvar)

def sample(self) -> torch.FloatTensor:
x = self.mean + self.std * self.noise
return x

def compute_psnr(a, b):
""" Compute Peak-Signal-to-Noise-Ratio across two numpy.ndarray objects
Expand Down Expand Up @@ -140,7 +155,7 @@ def _convert_to_coreml(submodule_name, torchscript_module, sample_inputs,

def quantize_weights_to_8bits(args):
for model_name in [
"text_encoder", "vae_decoder", "unet", "unet_chunk1",
"text_encoder", "vae_decoder", "vae_encoder", "unet", "unet_chunk1",
"unet_chunk2", "safety_checker"
]:
out_path = _get_out_path(args, model_name)
Expand Down Expand Up @@ -190,6 +205,7 @@ def bundle_resources_for_swift_cli(args):
# Compile model using coremlcompiler (Significantly reduces the load time for unet)
for source_name, target_name in [("text_encoder", "TextEncoder"),
("vae_decoder", "VAEDecoder"),
("vae_encoder", "VAEEncoder"),
("unet", "Unet"),
("unet_chunk1", "UnetChunk1"),
("unet_chunk2", "UnetChunk2"),
Expand Down Expand Up @@ -453,6 +469,159 @@ def forward(self, z):
gc.collect()


def convert_vae_encoder(pipe, args):
""" Converts the VAE Encoder component of Stable Diffusion
"""
out_path = _get_out_path(args, "vae_encoder")
if os.path.exists(out_path):
logger.info(
f"`vae_encoder` already exists at {out_path}, skipping conversion."
)
return

if not hasattr(pipe, "unet"):
raise RuntimeError(
"convert_unet() deletes pipe.unet to save RAM. "
"Please use convert_vae_encoder() before convert_unet()")

sample_shape = (
1, # B
3, # C (RGB range from -1 to 1)
(args.latent_h or pipe.unet.config.sample_size) * 8, # H
(args.latent_w or pipe.unet.config.sample_size) * 8, # w
)

noise_shape = (
1, # B
4, # C
pipe.unet.config.sample_size, # H
pipe.unet.config.sample_size, # w
)

float_value_shape = (
1,
1,
)

sqrt_alphas_cumprod_torch_shape = torch.tensor([[0.2,]])
sqrt_one_minus_alphas_cumprod_torch_shape = torch.tensor([[0.8,]])

sample_vae_encoder_inputs = {
"sample": torch.rand(*sample_shape, dtype=torch.float16),
"diagonal_noise": torch.rand(*noise_shape, dtype=torch.float16),
"noise": torch.rand(*noise_shape, dtype=torch.float16),
"sqrt_alphas_cumprod": torch.rand(*float_value_shape, dtype=torch.float16),
"sqrt_one_minus_alphas_cumprod": torch.rand(*float_value_shape, dtype=torch.float16),
}

class VAEEncoder(nn.Module):
""" Wrapper nn.Module wrapper for pipe.encode() method
"""

def __init__(self):
super().__init__()
self.quant_conv = pipe.vae.quant_conv
self.alphas_cumprod = pipe.scheduler.alphas_cumprod
self.encoder = pipe.vae.encoder

# Because CoreMLTools does not support the torch.randn op, we pass in both
# the diagonal Noise for the `DiagonalGaussianDistribution` operation and
# the noise tensor combined with precalculated `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod`
# for faster computation.
def forward(self, sample, diagonal_noise, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
h = self.encoder(sample)
moments = self.quant_conv(h)
posterior = CoreMLDiagonalGaussianDistribution(moments, diagonal_noise)
posteriorSample = posterior.sample()

# Add the scaling operation and the latent noise for faster computation
init_latents = 0.18215 * posteriorSample
result = self.add_noise(init_latents, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod)
return result

def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
sqrt_alphas_cumprod: torch.FloatTensor,
sqrt_one_minus_alphas_cumprod: torch.FloatTensor
) -> torch.FloatTensor:
noisy_samples = sqrt_alphas_cumprod * original_samples + sqrt_one_minus_alphas_cumprod * noise
return noisy_samples


baseline_encoder = VAEEncoder().eval()

# No optimization needed for the VAE Encoder as it is a pure ConvNet
traced_vae_encoder = torch.jit.trace(
baseline_encoder, (
sample_vae_encoder_inputs["sample"].to(torch.float32),
sample_vae_encoder_inputs["diagonal_noise"].to(torch.float32),
sample_vae_encoder_inputs["noise"].to(torch.float32),
sqrt_alphas_cumprod_torch_shape.to(torch.float32),
sqrt_one_minus_alphas_cumprod_torch_shape.to(torch.float32)
))

modify_coremltools_torch_frontend_badbmm()
coreml_vae_encoder, out_path = _convert_to_coreml(
"vae_encoder", traced_vae_encoder, sample_vae_encoder_inputs,
["latent_dist"], args)

# Set model metadata
coreml_vae_encoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
coreml_vae_encoder.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
coreml_vae_encoder.version = args.model_version
coreml_vae_encoder.short_description = \
"Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. " \
"Please refer to https://arxiv.org/abs/2112.10752 for details."

# Set the input descriptions
coreml_vae_encoder.input_description["sample"] = \
"An image of the correct size to create the latent space with, image2image and in-painting."
coreml_vae_encoder.input_description["diagonal_noise"] = \
"Latent noise for `DiagonalGaussianDistribution` operation."
coreml_vae_encoder.input_description["noise"] = \
"Latent noise for use with strength parameter of image2image"
coreml_vae_encoder.input_description["sqrt_alphas_cumprod"] = \
"Precalculated `sqrt_alphas_cumprod` value based on strength and the current schedular's alphasCumprod values"
coreml_vae_encoder.input_description["sqrt_one_minus_alphas_cumprod"] = \
"Precalculated `sqrt_one_minus_alphas_cumprod` value based on strength and the current schedular's alphasCumprod values"

# Set the output descriptions
coreml_vae_encoder.output_description[
"latent_dist"] = "The latent embeddings from the unet model from the input image."

_save_mlpackage(coreml_vae_encoder, out_path)

logger.info(f"Saved vae_encoder into {out_path}")

# Parity check PyTorch vs CoreML
if args.check_output_correctness:
baseline_out = baseline_encoder(
sample=sample_vae_encoder_inputs["sample"].to(torch.float32),
diagonal_noise=sample_vae_encoder_inputs["diagonal_noise"].to(torch.float32),
noise=sample_vae_encoder_inputs["noise"].to(torch.float32),
sqrt_alphas_cumprod=sqrt_alphas_cumprod_torch_shape,
sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod_torch_shape,
).numpy(),

coreml_out = list(
coreml_vae_encoder.predict(
{
"sample": sample_vae_encoder_inputs["sample"].numpy(),
"diagonal_noise": sample_vae_encoder_inputs["diagonal_noise"].numpy(),
"noise": sample_vae_encoder_inputs["noise"].numpy(),
"sqrt_alphas_cumprod": sqrt_alphas_cumprod_torch_shape.numpy(),
"sqrt_one_minus_alphas_cumprod": sqrt_one_minus_alphas_cumprod_torch_shape.numpy()
}).values())

report_correctness(baseline_out[0], coreml_out[0],
"vae_encoder baseline PyTorch to baseline CoreML")

del traced_vae_encoder, pipe.vae.encoder, coreml_vae_encoder
gc.collect()


def convert_unet(pipe, args):
""" Converts the UNet component of Stable Diffusion
"""
Expand Down Expand Up @@ -801,7 +970,12 @@ def main(args):
logger.info("Converting vae_decoder")
convert_vae_decoder(pipe, args)
logger.info("Converted vae_decoder")


if args.convert_vae_encoder:
logger.info("Converting vae_encoder")
convert_vae_encoder(pipe, args)
logger.info("Converted vae_encoder")

if args.convert_unet:
logger.info("Converting unet")
convert_unet(pipe, args)
Expand Down Expand Up @@ -835,6 +1009,7 @@ def parser_spec():
# Select which models to export (All are needed for text-to-image pipeline to function)
parser.add_argument("--convert-text-encoder", action="store_true")
parser.add_argument("--convert-vae-decoder", action="store_true")
parser.add_argument("--convert-vae-encoder", action="store_true")
parser.add_argument("--convert-unet", action="store_true")
parser.add_argument("--convert-safety-checker", action="store_true")
parser.add_argument(
Expand Down

0 comments on commit 086cc5e

Please sign in to comment.