diff --git a/README.md b/README.md index 96a28364..25e015b7 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,30 @@ ## News + +**July 24, 2024** +- We are releasing **[Stable Video 4D (SV4D)](https://huggingface.co/stabilityai/sv4d)**, a video-to-4D diffusion model for novel-view video synthesis. For research purposes: + - **SV4D** was trained to generate 40 frames (5 video frames x 8 camera views) at 576x576 resolution, given 5 context frames (the input video), and 8 reference views (synthesised from the first frame of the input video, using a multi-view diffusion model like SV3D) of the same size, ideally white-background images with one object. + - To generate longer novel-view videos (21 frames), we propose a novel sampling method using SV4D: we first sample 5 anchor frames and then densely sample the remaining frames while maintaining temporal consistency. + - Please check our [project page](), [tech report]() and [video summary]() for more details. + +**QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [SV4D](https://huggingface.co/stabilityai/sv4d) and [SV3D_u](https://huggingface.co/stabilityai/sv3d) from HuggingFace) + +To run **SV4D** on a single input video of 21 frames: +- Download the SV3D models (`sv3d_u.safetensors` and `sv3d_p.safetensors`) from [here](https://huggingface.co/stabilityai/sv3d) and the SV4D model (`sv4d.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d) to `checkpoints/` +- Run `python scripts/sampling/simple_video_sample_4d.py --input_path <path/to/video>` + - `input_path` : The input video `<path/to/video>` can be + - a single video file in `gif` or `mp4` format, such as `assets/test_video1.mp4`, or + - a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or + - a file name pattern matching images of video frames. + - `num_steps` : default is 20; increase to 50 for better quality at the cost of longer sampling time. + - `sv3d_version` : To specify the SV3D model used to generate the reference multi-views, set `--sv3d_version=sv3d_u` for SV3D_u or `--sv3d_version=sv3d_p` for SV3D_p. + - `elevations_deg` : To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path assets/test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0` + - **Background removal** : For input videos with a plain background, optionally use [rembg](https://github.com/danielgatis/rembg) to remove the background and crop the video frames by setting `--remove_bg=True`. To obtain higher-quality outputs on real-world input videos (with noisy backgrounds), try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) before running SV4D. + + ![tile](assets/sv4d.gif) + + **March 18, 2024** - We are releasing **[SV3D](https://huggingface.co/stabilityai/sv3d)**, an image-to-video model for novel multi-view synthesis, for research purposes: - **SV3D** was trained to generate 21 frames at resolution 576x576, given 1 context frame of the same size, ideally a white-background image with one object. 
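For readers who prefer calling the sampler from Python rather than the shell: `scripts/sampling/simple_video_sample_4d.py` (added in this PR) exposes its `sample()` function through `fire`, so it can be imported directly. Below is a minimal sketch, not part of the PR itself, assuming the SV4D and SV3D_u checkpoints are already in `checkpoints/` and that it is run from the repository root so the `scripts` and `sgm` packages resolve:

```python
# Minimal sketch: programmatic version of the CLI quickstart above.
# Assumes checkpoints/sv4d.safetensors and checkpoints/sv3d_u.safetensors exist
# and that this is executed from the repository root.
from scripts.sampling.simple_video_sample_4d import sample

sample(
    input_path="assets/test_video1.mp4",  # .gif/.mp4 file, image folder, or glob pattern
    output_folder="outputs/sv4d",
    num_steps=20,            # increase to 50 for better quality, at the cost of runtime
    sv3d_version="sv3d_u",   # use "sv3d_p" together with elevations_deg for a custom elevation
    elevations_deg=10.0,
    remove_bg=False,         # set True for plain-background inputs
    decoding_t=14,           # reduce if you run out of VRAM
)
```

The keyword arguments mirror the CLI flags described above; anything omitted falls back to the script's defaults.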
diff --git a/assets/hiphop_parrot.mp4 b/assets/hiphop_parrot.mp4 new file mode 100644 index 00000000..f86514d4 Binary files /dev/null and b/assets/hiphop_parrot.mp4 differ diff --git a/assets/sv4d.gif b/assets/sv4d.gif new file mode 100644 index 00000000..a3127c26 Binary files /dev/null and b/assets/sv4d.gif differ diff --git a/assets/test_video1.mp4 b/assets/test_video1.mp4 new file mode 100644 index 00000000..0da88736 Binary files /dev/null and b/assets/test_video1.mp4 differ diff --git a/assets/test_video2.mp4 b/assets/test_video2.mp4 new file mode 100644 index 00000000..424f14dc Binary files /dev/null and b/assets/test_video2.mp4 differ diff --git a/scripts/demo/sv4d_helpers.py b/scripts/demo/sv4d_helpers.py new file mode 100644 index 00000000..b533b5e5 --- /dev/null +++ b/scripts/demo/sv4d_helpers.py @@ -0,0 +1,1207 @@ +import math +import os +from glob import glob +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import cv2 +import imageio +import numpy as np +import torch +import torchvision.transforms as TT +from einops import rearrange, repeat +from omegaconf import ListConfig, OmegaConf +from PIL import Image, ImageSequence +from rembg import remove +from torch import autocast +from torchvision.transforms import ToTensor + +from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering +from sgm.modules.autoencoding.temporal_ae import VideoDecoder +from sgm.modules.diffusionmodules.guiders import ( + LinearPredictionGuider, + SpatiotemporalPredictionGuider, + TrapezoidPredictionGuider, + TrianglePredictionGuider, + VanillaCFG, +) +from sgm.modules.diffusionmodules.sampling import ( + DPMPP2MSampler, + DPMPP2SAncestralSampler, + EulerAncestralSampler, + EulerEDMSampler, + HeunEDMSampler, + LinearMultistepSampler, +) +from sgm.util import default, instantiate_from_config + + +def get_resizing_factor( + desired_shape: Tuple[int, int], current_shape: Tuple[int, int] +) -> float: + r_bound = desired_shape[1] / desired_shape[0] + aspect_r = current_shape[1] / current_shape[0] + if r_bound >= 1.0: + if aspect_r >= r_bound: + factor = min(desired_shape) / min(current_shape) + else: + if aspect_r < 1.0: + factor = max(desired_shape) / min(current_shape) + else: + factor = max(desired_shape) / max(current_shape) + else: + if aspect_r <= r_bound: + factor = min(desired_shape) / min(current_shape) + else: + if aspect_r > 1: + factor = max(desired_shape) / min(current_shape) + else: + factor = max(desired_shape) / max(current_shape) + return factor + + +def load_img_for_prediction_no_st( + image_path: str, + mask_path: str, + W: int, + H: int, + crop_h: int, + crop_w: int, + device="cuda", +) -> torch.Tensor: + image = Image.open(image_path) + if image is None: + return None + image = np.array(image).astype(np.float32) / 255 + h, w = image.shape[:2] + rotated = 0 + + mask = None + if mask_path is not None: + mask = Image.open(mask_path) + mask = np.array(mask).astype(np.float32) / 255 + mask = np.any(mask.reshape(h, w, -1) > 0, axis=2, keepdims=True).astype( + np.float32 + ) + elif image.shape[-1] == 4: + mask = image[:, :, 3:] + + if mask is not None: + image = image[:, :, :3] * mask + (1 - mask) + # if "DAVIS" in image_path: + # y, x, _ = np.where(mask > 0) + # x_mean, y_mean = np.mean(x), np.mean(y) + # else: + # x_mean, y_mean = w//2, h//2 + # h_new = int(max(crop_h, crop_w) * 1.33) + # x_min = max(int(x_mean - h_new//2), 0) + # y_min = max(int(y_mean - h_new//2), 0) + # image_cropped = image[y_min : y_min + h_new, x_min : x_min + 
h_new] + # h_crop, w_crop = image_cropped.shape[:2] + # h_new = max(h_crop, w_crop) + # top = max((h_new - h_crop) // 2, 0) + # left = max((h_new - w_crop) // 2, 0) + # image_padded = np.ones((h_new, h_new, 3)).astype(np.float32) + # image_padded[top : top + h_crop, left : left + w_crop, :] = image_cropped + # image = image_padded + # h, w = image.shape[:2] + + image = image.transpose(2, 0, 1) + image = torch.from_numpy(image).to(dtype=torch.float32) + image = image.unsqueeze(0) + + rfs = get_resizing_factor((H, W), (h, w)) + resize_size = [int(np.ceil(rfs * s)) for s in (h, w)] + top = (resize_size[0] - H) // 2 + left = (resize_size[1] - W) // 2 + + image = torch.nn.functional.interpolate( + image, resize_size, mode="area", antialias=False + ) + image = TT.functional.crop(image, top=top, left=left, height=H, width=W) + return image.to(device) * 2.0 - 1.0, rotated + + +def read_gif(input_path, n_frames): + frames = [] + video = Image.open(input_path) + if video.n_frames < n_frames: + return frames + for img in ImageSequence.Iterator(video): + frames.append(img.convert("RGB")) + if len(frames) == n_frames: + break + return frames + + +def read_mp4(input_path, n_frames): + frames = [] + vidcap = cv2.VideoCapture(input_path) + success, image = vidcap.read() + while success: + frames.append(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))) + success, image = vidcap.read() + if len(frames) == n_frames: + break + return frames + + +def save_img(file_name, img): + output_dir = os.path.dirname(file_name) + os.makedirs(output_dir, exist_ok=True) + imageio.imwrite( + file_name, + (((img[0].permute(1, 2, 0) + 1) / 2).cpu().numpy() * 255.0).astype(np.uint8), + ) + + +def save_video(file_name, imgs, fps=10): + output_dir = os.path.dirname(file_name) + os.makedirs(output_dir, exist_ok=True) + img_grid = [ + (((img[0].permute(1, 2, 0) + 1) / 2).cpu().numpy() * 255.0).astype(np.uint8) + for img in imgs + ] + if file_name.endswith(".gif"): + imageio.mimwrite(file_name, img_grid, fps=fps, loop=0) + else: + imageio.mimwrite(file_name, img_grid, fps=fps) + + +def read_video( + input_path: str, + n_frames: int, + W: int, + H: int, + remove_bg: bool = False, + image_frame_ratio: Optional[float] = None, + device: str = "cuda", +): + path = Path(input_path) + is_video_file = False + all_img_paths = [] + if path.is_file(): + if any([input_path.endswith(x) for x in [".gif", ".mp4"]]): + is_video_file = True + else: + raise ValueError("Path is not a valid video file.") + elif path.is_dir(): + all_img_paths = sorted( + [ + f + for f in path.iterdir() + if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"] + ] + )[:n_frames] + elif "*" in input_path: + all_img_paths = sorted(glob(input_path))[:n_frames] + else: + raise ValueError + + if is_video_file and input_path.endswith(".gif"): + images = read_gif(input_path, n_frames)[:n_frames] + elif is_video_file and input_path.endswith(".mp4"): + images = read_mp4(input_path, n_frames)[:n_frames] + else: + print(f"Loading {len(all_img_paths)} video frames...") + images = [Image.open(img_path) for img_path in all_img_paths] + + if len(images) != n_frames: + raise ValueError(f"Input video contains fewer than {n_frames} frames.") + + # Remove background and crop video frames + images_v0 = [] + for image in images: + if remove_bg: + if image.mode == "RGBA": + pass + else: + image.thumbnail([W, H], Image.Resampling.LANCZOS) + image = remove(image.convert("RGBA"), alpha_matting=True) + image_arr = np.array(image) + in_w, in_h = image_arr.shape[:2] + ret, mask = 
cv2.threshold( + np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY + ) + x, y, w, h = cv2.boundingRect(mask) + max_size = max(w, h) + side_len = ( + int(max_size / image_frame_ratio) + if image_frame_ratio is not None + else in_w + ) + padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8) + center = side_len // 2 + padded_image[ + center - h // 2 : center - h // 2 + h, + center - w // 2 : center - w // 2 + w, + ] = image_arr[y : y + h, x : x + w] + rgba = Image.fromarray(padded_image).resize((W, H), Image.LANCZOS) + rgba_arr = np.array(rgba) / 255.0 + rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:]) + image = Image.fromarray((rgb * 255).astype(np.uint8)) + image = ToTensor()(image).unsqueeze(0).to(device) + images_v0.append(image * 2.0 - 1.0) + return images_v0 + + +def sample_sv3d( + image, + num_frames: Optional[int] = None, # 21 for SV3D + num_steps: Optional[int] = None, + version: str = "sv3d_u", + fps_id: int = 6, + motion_bucket_id: int = 127, + cond_aug: float = 0.02, + decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. + device: str = "cuda", + polar_rad: Optional[Union[float, List[float]]] = None, + azim_rad: Optional[List[float]] = None, + verbose: Optional[bool] = False, +): + """ + Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each + image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`. + """ + + if version == "sv3d_u": + model_config = "scripts/sampling/configs/sv3d_u.yaml" + elif version == "sv3d_p": + model_config = "scripts/sampling/configs/sv3d_p.yaml" + else: + raise ValueError(f"Version {version} does not exist.") + + model, filter = load_model( + model_config, + device, + num_frames, + num_steps, + verbose, + ) + + H, W = image.shape[2:] + F = 8 + C = 4 + shape = (num_frames, C, H // F, W // F) + + value_dict = {} + value_dict["cond_frames_without_noise"] = image + value_dict["motion_bucket_id"] = motion_bucket_id + value_dict["fps_id"] = fps_id + value_dict["cond_aug"] = cond_aug + value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image) + if "sv3d_p" in version: + value_dict["polars_rad"] = polar_rad + value_dict["azimuths_rad"] = azim_rad + + with torch.no_grad(): + with torch.autocast(device): + batch, batch_uc = get_batch_sv3d( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + [1, num_frames], + T=num_frames, + device=device, + ) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=[ + "cond_frames", + "cond_frames_without_noise", + ], + ) + + for k in ["crossattn", "concat"]: + uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames) + uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames) + c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames) + c[k] = rearrange(c[k], "b t ... 
-> (b t) ...", t=num_frames) + + randn = torch.randn(shape, device=device) + + additional_model_inputs = {} + additional_model_inputs["image_only_indicator"] = torch.zeros( + 2, num_frames + ).to(device) + additional_model_inputs["num_video_frames"] = batch["num_video_frames"] + + def denoiser(input, sigma, c): + return model.denoiser( + model.model, input, sigma, c, **additional_model_inputs + ) + + samples_z = model.sampler(denoiser, randn, cond=c, uc=uc) + model.en_and_decode_n_samples_a_time = decoding_t + samples_x = model.decode_first_stage(samples_z) + samples_x[-1:] = value_dict["cond_frames_without_noise"] + samples = torch.clamp(samples_x, min=-1.0, max=1.0) + + return samples + + +def decode_latents(model, samples_z, timesteps): + if isinstance(model.first_stage_model.decoder, VideoDecoder): + samples_x = model.decode_first_stage(samples_z, timesteps=timesteps) + else: + samples_x = model.decode_first_stage(samples_z) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + return samples + + +def init_embedder_options_no_st(keys, init_dict, prompt=None, negative_prompt=None): + # Hardcoded demo settings; might undergo some changes in the future + + value_dict = {} + for key in keys: + if key == "txt": + if prompt is None: + prompt = "A professional photograph of an astronaut riding a pig" + if negative_prompt is None: + negative_prompt = "" + + value_dict["prompt"] = prompt + value_dict["negative_prompt"] = negative_prompt + + if key == "original_size_as_tuple": + orig_width = init_dict["orig_width"] + orig_height = init_dict["orig_height"] + + value_dict["orig_width"] = orig_width + value_dict["orig_height"] = orig_height + + if key == "crop_coords_top_left": + crop_coord_top = 0 + crop_coord_left = 0 + + value_dict["crop_coords_top"] = crop_coord_top + value_dict["crop_coords_left"] = crop_coord_left + + if key == "aesthetic_score": + value_dict["aesthetic_score"] = 6.0 + value_dict["negative_aesthetic_score"] = 2.5 + + if key == "target_size_as_tuple": + value_dict["target_width"] = init_dict["target_width"] + value_dict["target_height"] = init_dict["target_height"] + + if key in ["fps_id", "fps"]: + fps = 6 + + value_dict["fps"] = fps + value_dict["fps_id"] = fps - 1 + + if key == "motion_bucket_id": + mb_id = 127 + value_dict["motion_bucket_id"] = mb_id + + if key == "noise_level": + value_dict["noise_level"] = 0 + + return value_dict + + +def get_discretization_no_st(discretization, options, key=1): + if discretization == "LegacyDDPMDiscretization": + discretization_config = { + "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", + } + elif discretization == "EDMDiscretization": + sigma_min = options.get("sigma_min", 0.03) + sigma_max = options.get("sigma_max", 14.61) + rho = options.get("rho", 3.0) + discretization_config = { + "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization", + "params": { + "sigma_min": sigma_min, + "sigma_max": sigma_max, + "rho": rho, + }, + } + return discretization_config + + +def get_guider_no_st(options, key): + guider = [ + "VanillaCFG", + "IdentityGuider", + "LinearPredictionGuider", + "TrianglePredictionGuider", + "TrapezoidPredictionGuider", + "SpatiotemporalPredictionGuider", + ][options.get("guider", 2)] + + additional_guider_kwargs = ( + options["additional_guider_kwargs"] + if "additional_guider_kwargs" in options + else {} + ) + + if guider == "IdentityGuider": + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" + } + elif guider == 
"VanillaCFG": + scale_schedule = "Identity" + + if scale_schedule == "Identity": + scale = options.get("cfg", 5.0) + + scale_schedule_config = { + "target": "sgm.modules.diffusionmodules.guiders.IdentitySchedule", + "params": {"scale": scale}, + } + + elif scale_schedule == "Oscillating": + small_scale = 4.0 + large_scale = 16.0 + sigma_cutoff = 1.0 + + scale_schedule_config = { + "target": "sgm.modules.diffusionmodules.guiders.OscillatingSchedule", + "params": { + "small_scale": small_scale, + "large_scale": large_scale, + "sigma_cutoff": sigma_cutoff, + }, + } + else: + raise NotImplementedError + + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", + "params": { + "scale_schedule_config": scale_schedule_config, + **additional_guider_kwargs, + }, + } + elif guider == "LinearPredictionGuider": + max_scale = options.get("cfg", 1.5) + + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.LinearPredictionGuider", + "params": { + "max_scale": max_scale, + "num_frames": options["num_frames"], + **additional_guider_kwargs, + }, + } + elif guider == "TrianglePredictionGuider": + max_scale = options.get("cfg", 1.5) + period = options.get("period", 1.0) + period_fusing = options.get("period_fusing", "max") + + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider", + "params": { + "max_scale": max_scale, + "num_frames": options["num_frames"], + "period": period, + "period_fusing": period_fusing, + **additional_guider_kwargs, + }, + } + elif guider == "TrapezoidPredictionGuider": + max_scale = options.get("cfg", 1.5) + edge_perc = options.get("edge_perc", 0.1) + + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.TrapezoidPredictionGuider", + "params": { + "max_scale": max_scale, + "num_frames": options["num_frames"], + "edge_perc": edge_perc, + **additional_guider_kwargs, + }, + } + elif guider == "SpatiotemporalPredictionGuider": + max_scale = options.get("cfg", 1.5) + + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider", + "params": { + "max_scale": max_scale, + "num_frames": options["num_frames"], + **additional_guider_kwargs, + }, + } + else: + raise NotImplementedError + return guider_config + + +def get_sampler_no_st(sampler_name, steps, discretization_config, guider_config, key=1): + if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler": + s_churn = 0.0 + s_tmin = 0.0 + s_tmax = 999.0 + s_noise = 1.0 + + if sampler_name == "EulerEDMSampler": + sampler = EulerEDMSampler( + num_steps=steps, + discretization_config=discretization_config, + guider_config=guider_config, + s_churn=s_churn, + s_tmin=s_tmin, + s_tmax=s_tmax, + s_noise=s_noise, + verbose=False, + ) + elif sampler_name == "HeunEDMSampler": + sampler = HeunEDMSampler( + num_steps=steps, + discretization_config=discretization_config, + guider_config=guider_config, + s_churn=s_churn, + s_tmin=s_tmin, + s_tmax=s_tmax, + s_noise=s_noise, + verbose=False, + ) + elif ( + sampler_name == "EulerAncestralSampler" + or sampler_name == "DPMPP2SAncestralSampler" + ): + s_noise = 1.0 + eta = 1.0 + + if sampler_name == "EulerAncestralSampler": + sampler = EulerAncestralSampler( + num_steps=steps, + discretization_config=discretization_config, + guider_config=guider_config, + eta=eta, + s_noise=s_noise, + verbose=False, + ) + elif sampler_name == "DPMPP2SAncestralSampler": + sampler = DPMPP2SAncestralSampler( + num_steps=steps, + discretization_config=discretization_config, + 
guider_config=guider_config, + eta=eta, + s_noise=s_noise, + verbose=False, + ) + elif sampler_name == "DPMPP2MSampler": + sampler = DPMPP2MSampler( + num_steps=steps, + discretization_config=discretization_config, + guider_config=guider_config, + verbose=False, + ) + elif sampler_name == "LinearMultistepSampler": + order = 4 + sampler = LinearMultistepSampler( + num_steps=steps, + discretization_config=discretization_config, + guider_config=guider_config, + order=order, + verbose=False, + ) + else: + raise ValueError(f"unknown sampler {sampler_name}!") + + return sampler + + +def init_sampling_no_st( + key=1, + options: Optional[Dict[str, int]] = None, +): + options = {} if options is None else options + + num_rows, num_cols = 1, 1 + steps = options.get("num_steps", 40) + sampler = [ + "EulerEDMSampler", + "HeunEDMSampler", + "EulerAncestralSampler", + "DPMPP2SAncestralSampler", + "DPMPP2MSampler", + "LinearMultistepSampler", + ][options.get("sampler", 0)] + discretization = [ + "LegacyDDPMDiscretization", + "EDMDiscretization", + ][options.get("discretization", 1)] + + discretization_config = get_discretization_no_st( + discretization, options=options, key=key + ) + + guider_config = get_guider_no_st(options=options, key=key) + + sampler = get_sampler_no_st( + sampler, steps, discretization_config, guider_config, key=key + ) + return sampler, num_rows, num_cols + + +def run_img2vid( + version_dict, + model, + image, + seed=23, + polar_rad=[10] * 21, + azim_rad=np.linspace(0, 360, 21 + 1)[1:], + cond_motion=None, + cond_view=None, +): + options = version_dict["options"] + H = version_dict["H"] + W = version_dict["W"] + T = version_dict["T"] + C = version_dict["C"] + F = version_dict["f"] + init_dict = { + "orig_width": 576, + "orig_height": 576, + "target_width": W, + "target_height": H, + } + ukeys = set(get_unique_embedder_keys_from_conditioner(model.conditioner)) + + value_dict = init_embedder_options_no_st( + ukeys, + init_dict, + negative_prompt=options.get("negative_promt", ""), + prompt="A 3D model.", + ) + if "fps" not in ukeys: + value_dict["fps"] = 6 + + value_dict["is_image"] = 0 + value_dict["is_webvid"] = 0 + value_dict["image_only_indicator"] = 0 + + cond_aug = 0.00 + if cond_motion is not None: + value_dict["cond_frames_without_noise"] = cond_motion + value_dict["cond_frames"] = ( + cond_motion[:, None].repeat(1, cond_view.shape[0], 1, 1, 1).flatten(0, 1) + ) + value_dict["cond_motion"] = cond_motion + value_dict["cond_view"] = cond_view + else: + value_dict["cond_frames_without_noise"] = image + value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image) + value_dict["cond_aug"] = cond_aug + value_dict["polar_rad"] = polar_rad + value_dict["azimuth_rad"] = azim_rad + value_dict["rotated"] = False + value_dict["cond_motion"] = cond_motion + value_dict["cond_view"] = cond_view + + # seed_everything(seed) + + options["num_frames"] = T + sampler, num_rows, num_cols = init_sampling_no_st(options=options) + num_samples = num_rows * num_cols + + samples = do_sample( + model, + sampler, + value_dict, + num_samples, + H, + W, + C, + F, + T=T, + batch2model_input=["num_video_frames", "image_only_indicator"], + force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None), + force_cond_zero_embeddings=options.get("force_cond_zero_embeddings", None), + return_latents=False, + decoding_t=options.get("decoding_T", T), + ) + + return samples + + +def do_sample( + model, + sampler, + value_dict, + num_samples, + H, + W, + C, + F, + force_uc_zero_embeddings: 
Optional[List] = None, + force_cond_zero_embeddings: Optional[List] = None, + batch2model_input: List = None, + return_latents=False, + filter=None, + T=None, + additional_batch_uc_fields=None, + decoding_t=None, +): + force_uc_zero_embeddings = default(force_uc_zero_embeddings, []) + batch2model_input = default(batch2model_input, []) + additional_batch_uc_fields = default(additional_batch_uc_fields, []) + + precision_scope = autocast + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + if T is not None: + num_samples = [num_samples, T] + else: + num_samples = [num_samples] + + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + num_samples, + T=T, + additional_batch_uc_fields=additional_batch_uc_fields, + ) + + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + force_cond_zero_embeddings=force_cond_zero_embeddings, + ) + + for k in c: + if not k == "crossattn": + c[k], uc[k] = map( + lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc) + ) + + additional_model_inputs = {} + for k in batch2model_input: + if k == "image_only_indicator": + assert T is not None + + if isinstance( + sampler.guider, + ( + VanillaCFG, + LinearPredictionGuider, + TrianglePredictionGuider, + TrapezoidPredictionGuider, + SpatiotemporalPredictionGuider, + ), + ): + additional_model_inputs[k] = torch.zeros( + num_samples[0] * 2, num_samples[1] + ).to("cuda") + else: + additional_model_inputs[k] = torch.zeros(num_samples).to( + "cuda" + ) + else: + additional_model_inputs[k] = batch[k] + + shape = (math.prod(num_samples), C, H // F, W // F) + randn = torch.randn(shape).to("cuda") + + def denoiser(input, sigma, c): + return model.denoiser( + model.model, input, sigma, c, **additional_model_inputs + ) + + samples_z = sampler(denoiser, randn, cond=c, uc=uc) + + if isinstance(model.first_stage_model.decoder, VideoDecoder): + samples_x = model.decode_first_stage( + samples_z, timesteps=default(decoding_t, T) + ) + else: + samples_x = model.decode_first_stage(samples_z) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + + if filter is not None: + samples = filter(samples) + + if return_latents: + return samples, samples_z + + return samples + + +def do_sample_per_step( + model, + sampler, + value_dict, + num_samples, + force_uc_zero_embeddings: Optional[List] = None, + force_cond_zero_embeddings: Optional[List] = None, + batch2model_input: List = None, + T=None, + additional_batch_uc_fields=None, + step=None, + noisy_latents=None, +): + force_uc_zero_embeddings = default(force_uc_zero_embeddings, []) + batch2model_input = default(batch2model_input, []) + additional_batch_uc_fields = default(additional_batch_uc_fields, []) + + precision_scope = autocast + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + if T is not None: + num_samples = [num_samples, T] + else: + num_samples = [num_samples] + + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + num_samples, + T=T, + additional_batch_uc_fields=additional_batch_uc_fields, + ) + + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + force_cond_zero_embeddings=force_cond_zero_embeddings, + ) + + for k in c: + if not k == "crossattn": + c[k], uc[k] = map( + lambda y: y[k][: 
math.prod(num_samples)].to("cuda"), (c, uc) + ) + + additional_model_inputs = {} + for k in batch2model_input: + if k == "image_only_indicator": + assert T is not None + + if isinstance( + sampler.guider, + ( + VanillaCFG, + LinearPredictionGuider, + TrianglePredictionGuider, + TrapezoidPredictionGuider, + SpatiotemporalPredictionGuider, + ), + ): + additional_model_inputs[k] = torch.zeros( + num_samples[0] * 2, num_samples[1] + ).to("cuda") + else: + additional_model_inputs[k] = torch.zeros(num_samples).to( + "cuda" + ) + else: + additional_model_inputs[k] = batch[k] + + noisy_latents_scaled, s_in, sigmas, num_sigmas, _, _ = ( + sampler.prepare_sampling_loop( + noisy_latents.clone(), c, uc, sampler.num_steps + ) + ) + + if step == 0: + latents = noisy_latents_scaled + else: + latents = noisy_latents + + def denoiser(input, sigma, c): + return model.denoiser( + model.model, input, sigma, c, **additional_model_inputs + ) + + gamma = ( + min(sampler.s_churn / (num_sigmas - 1), 2**0.5 - 1) + if sampler.s_tmin <= sigmas[step] <= sampler.s_tmax + else 0.0 + ) + samples_z = sampler.sampler_step( + s_in * sigmas[step], + s_in * sigmas[step + 1], + denoiser, + latents, + c, + uc, + gamma, + ) + + return samples_z + + +def run_img2vid_per_step( + version_dict, + model, + image, + seed=23, + polar_rad=[10] * 21, + azim_rad=np.linspace(0, 360, 21 + 1)[1:], + cond_motion=None, + cond_view=None, + step=None, + noisy_latents=None, +): + options = version_dict["options"] + H = version_dict["H"] + W = version_dict["W"] + T = version_dict["T"] + C = version_dict["C"] + F = version_dict["f"] + init_dict = { + "orig_width": 576, + "orig_height": 576, + "target_width": W, + "target_height": H, + } + ukeys = set(get_unique_embedder_keys_from_conditioner(model.conditioner)) + + value_dict = init_embedder_options_no_st( + ukeys, + init_dict, + negative_prompt=options.get("negative_promt", ""), + prompt="A 3D model.", + ) + if "fps" not in ukeys: + value_dict["fps"] = 6 + + value_dict["is_image"] = 0 + value_dict["is_webvid"] = 0 + value_dict["image_only_indicator"] = 0 + + cond_aug = 0.00 + if cond_motion is not None: + value_dict["cond_frames_without_noise"] = cond_motion + value_dict["cond_frames"] = ( + cond_motion[:, None].repeat(1, cond_view.shape[0], 1, 1, 1).flatten(0, 1) + ) + value_dict["cond_motion"] = cond_motion + value_dict["cond_view"] = cond_view + else: + value_dict["cond_frames_without_noise"] = image + value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image) + value_dict["cond_aug"] = cond_aug + value_dict["polar_rad"] = polar_rad + value_dict["azimuth_rad"] = azim_rad + value_dict["rotated"] = False + value_dict["cond_motion"] = cond_motion + value_dict["cond_view"] = cond_view + + # seed_everything(seed) + + options["num_frames"] = T + sampler, num_rows, num_cols = init_sampling_no_st(options=options) + num_samples = num_rows * num_cols + + samples = do_sample_per_step( + model, + sampler, + value_dict, + num_samples, + force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None), + force_cond_zero_embeddings=options.get("force_cond_zero_embeddings", None), + batch2model_input=["num_video_frames", "image_only_indicator"], + T=T, + step=step, + noisy_latents=noisy_latents, + ) + + return samples + + +def get_unique_embedder_keys_from_conditioner(conditioner): + return list(set([x.input_key for x in conditioner.embedders])) + + +def get_batch_sv3d(keys, value_dict, N, T, device): + batch = {} + batch_uc = {} + + for key in keys: + if key == "fps_id": + batch[key] = ( + 
torch.tensor([value_dict["fps_id"]]) + .to(device) + .repeat(int(math.prod(N))) + ) + elif key == "motion_bucket_id": + batch[key] = ( + torch.tensor([value_dict["motion_bucket_id"]]) + .to(device) + .repeat(int(math.prod(N))) + ) + elif key == "cond_aug": + batch[key] = repeat( + torch.tensor([value_dict["cond_aug"]]).to(device), + "1 -> b", + b=math.prod(N), + ) + elif key == "cond_frames" or key == "cond_frames_without_noise": + batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=N[0]) + elif key == "polars_rad" or key == "azimuths_rad": + batch[key] = torch.tensor(value_dict[key]).to(device).repeat(N[0]) + else: + batch[key] = value_dict[key] + + if T is not None: + batch["num_video_frames"] = T + + for key in batch.keys(): + if key not in batch_uc and isinstance(batch[key], torch.Tensor): + batch_uc[key] = torch.clone(batch[key]) + return batch, batch_uc + + +def get_batch( + keys, + value_dict: dict, + N: Union[List, ListConfig], + device: str = "cuda", + T: int = None, + additional_batch_uc_fields: List[str] = [], +): + batch = {} + batch_uc = {} + + for key in keys: + if key == "txt": + batch["txt"] = [value_dict["prompt"]] * math.prod(N) + batch_uc["txt"] = [value_dict["negative_prompt"]] * math.prod(N) + + elif key == "original_size_as_tuple": + batch["original_size_as_tuple"] = ( + torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) + .to(device) + .repeat(math.prod(N), 1) + ) + elif key == "crop_coords_top_left": + batch["crop_coords_top_left"] = ( + torch.tensor( + [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] + ) + .to(device) + .repeat(math.prod(N), 1) + ) + elif key == "aesthetic_score": + batch["aesthetic_score"] = ( + torch.tensor([value_dict["aesthetic_score"]]) + .to(device) + .repeat(math.prod(N), 1) + ) + batch_uc["aesthetic_score"] = ( + torch.tensor([value_dict["negative_aesthetic_score"]]) + .to(device) + .repeat(math.prod(N), 1) + ) + + elif key == "target_size_as_tuple": + batch["target_size_as_tuple"] = ( + torch.tensor([value_dict["target_height"], value_dict["target_width"]]) + .to(device) + .repeat(math.prod(N), 1) + ) + elif key == "fps": + batch[key] = ( + torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N)) + ) + elif key == "fps_id": + batch[key] = ( + torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N)) + ) + elif key == "motion_bucket_id": + batch[key] = ( + torch.tensor([value_dict["motion_bucket_id"]]) + .to(device) + .repeat(math.prod(N)) + ) + elif key == "pool_image": + batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to( + device, dtype=torch.half + ) + elif key == "is_image": + batch[key] = ( + torch.tensor([value_dict["is_image"]]) + .to(device) + .repeat(math.prod(N)) + .long() + ) + elif key == "is_webvid": + batch[key] = ( + torch.tensor([value_dict["is_webvid"]]) + .to(device) + .repeat(math.prod(N)) + .long() + ) + elif key == "cond_aug": + batch[key] = repeat( + torch.tensor([value_dict["cond_aug"]]).to("cuda"), + "1 -> b", + b=math.prod(N), + ) + elif ( + key == "cond_frames" + or key == "cond_frames_without_noise" + or key == "back_frames" + ): + # batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=N[0]) + batch[key] = value_dict[key] + + elif key == "interpolation_context": + batch[key] = repeat( + value_dict["interpolation_context"], "b ... -> (b n) ...", n=N[1] + ) + + elif key == "start_frame": + assert T is not None + batch[key] = repeat(value_dict[key], "b ... 
-> (b t) ...", t=T) + + elif key == "polar_rad" or key == "azimuth_rad": + batch[key] = ( + torch.tensor(value_dict[key]).to(device).repeat(math.prod(N) // T) + ) + + elif key == "rotated": + batch[key] = ( + torch.tensor([value_dict["rotated"]]).to(device).repeat(math.prod(N)) + ) + + else: + batch[key] = value_dict[key] + + if T is not None: + batch["num_video_frames"] = T + + for key in batch.keys(): + if key not in batch_uc and isinstance(batch[key], torch.Tensor): + batch_uc[key] = torch.clone(batch[key]) + elif key in additional_batch_uc_fields and key not in batch_uc: + batch_uc[key] = copy.copy(batch[key]) + return batch, batch_uc + + +def load_model( + config: str, + device: str, + num_frames: int, + num_steps: int, + verbose: bool = False, +): + config = OmegaConf.load(config) + if device == "cuda": + config.model.params.conditioner_config.params.emb_models[ + 0 + ].params.open_clip_embedding_config.params.init_device = device + + config.model.params.sampler_config.params.verbose = verbose + config.model.params.sampler_config.params.num_steps = num_steps + config.model.params.sampler_config.params.guider_config.params.num_frames = ( + num_frames + ) + if device == "cuda": + with torch.device(device): + model = instantiate_from_config(config.model).to(device).eval() + else: + model = instantiate_from_config(config.model).to(device).eval() + + filter = DeepFloydDataFiltering(verbose=False, device=device) + return model, filter diff --git a/scripts/sampling/configs/sv4d.yaml b/scripts/sampling/configs/sv4d.yaml new file mode 100644 index 00000000..b908b758 --- /dev/null +++ b/scripts/sampling/configs/sv4d.yaml @@ -0,0 +1,208 @@ +N_TIME: 5 +N_VIEW: 8 +N_FRAMES: 40 + +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.18215 + en_and_decode_n_samples_a_time: 7 + disable_first_stage_autocast: True + ckpt_path: checkpoints/sv4d.safetensors + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.Denoiser + params: + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise + + network_config: + target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime + params: + adm_in_channels: 1280 + attention_resolutions: [4, 2, 1] + channel_mult: [1, 2, 4, 4] + context_dim: 1024 + extra_ff_mix_layer: True + in_channels: 8 + legacy: False + model_channels: 320 + num_classes: sequential + num_head_channels: 64 + num_res_blocks: 2 + out_channels: 4 + replicate_time_mix_bug: True + spatial_transformer_attn_type: softmax-xformers + time_block_merge_factor: 0.0 + time_block_merge_strategy: learned_with_images + time_kernel_size: [3, 1, 1] + time_mix_legacy: False + transformer_depth: 1 + use_checkpoint: False + use_linear_in_transformer: True + use_spatial_context: True + use_spatial_transformer: True + use_motion_attention: True + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + + - input_key: cond_frames_without_noise + target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder + is_trainable: False + params: + n_cond_frames: ${N_TIME} + n_copies: 1 + open_clip_embedding_config: + target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder + params: + freeze: True + + - input_key: cond_frames + target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder + is_trainable: False + params: + is_ae: True + n_cond_frames: ${N_FRAMES} + n_copies: 1 + encoder_config: + target: sgm.models.autoencoder.AutoencoderKLModeOnly + params: + ddconfig: + 
attn_resolutions: [] + attn_type: vanilla-xformers + ch: 128 + ch_mult: [1, 2, 4, 4] + double_z: True + dropout: 0.0 + in_channels: 3 + num_res_blocks: 2 + out_ch: 3 + resolution: 256 + z_channels: 4 + embed_dim: 4 + lossconfig: + target: torch.nn.Identity + monitor: val/rec_loss + sigma_cond_config: + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler + + # - input_key: cond_aug + # is_trainable: False + # target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + # params: + # outdim: 256 + + - input_key: polar_rad + is_trainable: False + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 512 + + - input_key: azimuth_rad + is_trainable: False + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 512 + + - input_key: cond_view + is_trainable: False + target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder + params: + encoder_config: + target: sgm.models.autoencoder.AutoencoderKLModeOnly + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_resolutions: [] + attn_type: vanilla-xformers + ch: 128 + ch_mult: [1, 2, 4, 4] + double_z: True + dropout: 0.0 + in_channels: 3 + num_res_blocks: 2 + out_ch: 3 + resolution: 256 + z_channels: 4 + lossconfig: + target: torch.nn.Identity + is_ae: True + n_cond_frames: ${N_VIEW} + n_copies: 1 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler + + - input_key: cond_motion + is_trainable: False + target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder + params: + is_ae: True + n_cond_frames: ${N_TIME} + n_copies: 1 + encoder_config: + target: sgm.models.autoencoder.AutoencoderKLModeOnly + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_resolutions: [] + attn_type: vanilla-xformers + ch: 128 + ch_mult: [1, 2, 4, 4] + double_z: True + dropout: 0.0 + in_channels: 3 + num_res_blocks: 2 + out_ch: 3 + resolution: 256 + z_channels: 4 + lossconfig: + target: torch.nn.Identity + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler + + first_stage_config: + target: sgm.models.autoencoder.AutoencodingEngine + params: + loss_config: + target: torch.nn.Identity + regularizer_config: + target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer + encoder_config: + target: torch.nn.Identity + decoder_config: + target: sgm.modules.diffusionmodules.model.Decoder + params: + attn_resolutions: [] + attn_type: vanilla-xformers + ch: 128 + ch_mult: [1, 2, 4, 4] + double_z: True + dropout: 0.0 + in_channels: 3 + num_res_blocks: 2 + out_ch: 3 + resolution: 256 + z_channels: 4 + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler + params: + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization + params: + sigma_max: 500.0 + guider_config: + target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider + params: + max_scale: 2.5 + num_frames: ${N_FRAMES} + additional_cond_keys: [ cond_view, cond_motion ] diff --git a/scripts/sampling/simple_video_sample_4d.py b/scripts/sampling/simple_video_sample_4d.py new file mode 100644 index 00000000..51c809ea --- /dev/null +++ b/scripts/sampling/simple_video_sample_4d.py @@ -0,0 +1,236 @@ +import os +import sys +from glob import glob +from typing import List, Optional, Union + +from tqdm import tqdm + 
+sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../"))) +import numpy as np +import torch +from fire import Fire + +from scripts.demo.sv4d_helpers import ( + decode_latents, + load_model, + read_video, + run_img2vid, + run_img2vid_per_step, + sample_sv3d, + save_video, +) + + +def sample( + input_path: str = "assets/test_video.mp4", # Can either be image file or folder with image files + output_folder: Optional[str] = "outputs/sv4d", + num_steps: Optional[int] = 20, + sv3d_version: str = "sv3d_u", # sv3d_u or sv3d_p + fps_id: int = 6, + motion_bucket_id: int = 127, + cond_aug: float = 1e-5, + seed: int = 23, + decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. + device: str = "cuda", + elevations_deg: Optional[Union[float, List[float]]] = 10.0, + azimuths_deg: Optional[List[float]] = None, + image_frame_ratio: Optional[float] = None, + verbose: Optional[bool] = False, + remove_bg: bool = False, +): + """ + Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each + image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`. + """ + # Set model config + T = 5 # number of frames per sample + V = 8 # number of views per sample + F = 8 # vae factor to downsize image->latent + C = 4 + H, W = 576, 576 + n_frames = 21 # number of input and output video frames + n_views = V + 1 # number of output video views (1 input view + 8 novel views) + n_views_sv3d = 21 + subsampled_views = np.array( + [0, 2, 5, 7, 9, 12, 14, 16, 19] + ) # subsample (V+1=)9 (uniform) views from 21 SV3D views + + model_config = "scripts/sampling/configs/sv4d.yaml" + version_dict = { + "T": T * V, + "H": H, + "W": W, + "C": C, + "f": F, + "options": { + "discretization": 1, + "cfg": 2.5, + "sigma_min": 0.002, + "sigma_max": 700.0, + "rho": 7.0, + "guider": 5, + "num_steps": num_steps, + "force_uc_zero_embeddings": [ + "cond_frames", + "cond_frames_without_noise", + "cond_view", + "cond_motion", + ], + "additional_guider_kwargs": { + "additional_cond_keys": ["cond_view", "cond_motion"] + }, + }, + } + + torch.manual_seed(seed) + os.makedirs(output_folder, exist_ok=True) + + # Read input video frames i.e. images at view 0 + print(f"Reading {input_path}") + images_v0 = read_video( + input_path, + n_frames=n_frames, + W=W, + H=H, + remove_bg=remove_bg, + image_frame_ratio=image_frame_ratio, + device=device, + ) + + # Get camera viewpoints + if isinstance(elevations_deg, float) or isinstance(elevations_deg, int): + elevations_deg = [elevations_deg] * n_views_sv3d + assert ( + len(elevations_deg) == n_views_sv3d + ), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}" + if azimuths_deg is None: + azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360 + assert ( + len(azimuths_deg) == n_views_sv3d + ), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}" + polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg]) + azimuths_rad = np.array( + [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg] + ) + + # Sample multi-view images of the first frame using SV3D i.e. 
images at time 0 + images_t0 = sample_sv3d( + images_v0[0], + n_views_sv3d, + num_steps, + sv3d_version, + fps_id, + motion_bucket_id, + cond_aug, + decoding_t, + device, + polars_rad, + azimuths_rad, + verbose, + ) + images_t0 = torch.roll(images_t0, 1, 0) # move conditioning image to first frame + + # Initialize image matrix + img_matrix = [[None] * n_views for _ in range(n_frames)] + for i, v in enumerate(subsampled_views): + img_matrix[0][i] = images_t0[v].unsqueeze(0) + for t in range(n_frames): + img_matrix[t][0] = images_v0[t] + + base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 10 + save_video( + os.path.join(output_folder, f"{base_count:06d}_t000.mp4"), + img_matrix[0], + ) + save_video( + os.path.join(output_folder, f"{base_count:06d}_v000.mp4"), + [img_matrix[t][0] for t in range(n_frames)], + ) + + # Load SV4D model + model, filter = load_model( + model_config, + device, + version_dict["T"], + num_steps, + verbose, + ) + + # Interleaved sampling for anchor frames + t0, v0 = 0, 0 + frame_indices = np.arange(T - 1, n_frames, T - 1) # [4, 8, 12, 16, 20] + view_indices = np.arange(V) + 1 + print(f"Sampling anchor frames {frame_indices}") + image = img_matrix[t0][v0] + cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0) + cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0) + polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() + azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() + azims = (azims - azimuths_rad[v0]) % (torch.pi * 2) + samples = run_img2vid( + version_dict, model, image, seed, polars, azims, cond_motion, cond_view + ) + samples = samples.view(T, V, 3, H, W) + for i, t in enumerate(frame_indices): + for j, v in enumerate(view_indices): + if img_matrix[t][v] is None: + img_matrix[t][v] = samples[i, j][None] * 2 - 1 + + # Dense sampling for the rest + print(f"Sampling dense frames:") + for t0 in tqdm(np.arange(0, n_frames - 1, T - 1)): # [0, 4, 8, 12, 16] + frame_indices = t0 + np.arange(T) + print(f"Sampling dense frames {frame_indices}") + latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F).to("cuda") + for step in tqdm(range(num_steps)): + frame_indices = frame_indices[ + ::-1 + ].copy() # alternate between forward and backward conditioning + t0 = frame_indices[0] + image = img_matrix[t0][v0] + cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0) + cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0) + polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() + azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() + azims = (azims - azimuths_rad[v0]) % (torch.pi * 2) + noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1) + samples = run_img2vid_per_step( + version_dict, + model, + image, + seed, + polars, + azims, + cond_motion, + cond_view, + step, + noisy_latents, + ) + samples = samples.view(T, V, C, H // F, W // F) + for i, t in enumerate(frame_indices): + for j, v in enumerate(view_indices): + latent_matrix[t, v] = samples[i, j] + + for t in frame_indices: + for v in view_indices: + if t != 0 and v != 0: + img = decode_latents(model, latent_matrix[t, v][None], T) + img_matrix[t][v] = img * 2 - 1 + + # Save output videos + for v in view_indices: + vid_file = os.path.join(output_folder, f"{base_count:06d}_v{v:03d}.mp4") + print(f"Saving {vid_file}") + save_video(vid_file, [img_matrix[t][v] for t in range(n_frames)]) + + # Save diagonal video + diag_frames = [ + img_matrix[t][(t 
// (n_frames // n_views)) % n_views] for t in range(n_frames) + ] + vid_file = os.path.join(output_folder, f"{base_count:06d}_diag.mp4") + print(f"Saving {vid_file}") + save_video(vid_file, diag_frames) + + +if __name__ == "__main__": + Fire(sample) diff --git a/sgm/modules/diffusionmodules/guiders.py b/sgm/modules/diffusionmodules/guiders.py index bcaa01c5..db2cede0 100644 --- a/sgm/modules/diffusionmodules/guiders.py +++ b/sgm/modules/diffusionmodules/guiders.py @@ -94,7 +94,7 @@ def prepare_inputs( if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys: c_out[k] = torch.cat((uc[k], c[k]), 0) else: - assert c[k] == uc[k] + # assert c[k] == uc[k] c_out[k] = c[k] return torch.cat([x] * 2), torch.cat([s] * 2), c_out @@ -105,7 +105,7 @@ def __init__( max_scale: float, num_frames: int, min_scale: float = 1.0, - period: float | List[float] = 1.0, + period: Union[float, List[float]] = 1.0, period_fusing: Literal["mean", "multiply", "max"] = "max", additional_cond_keys: Optional[Union[List[str], str]] = None, ): @@ -129,3 +129,47 @@ def __init__( def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor: return 2 * (values / period - torch.floor(values / period + 0.5)).abs() + + +class TrapezoidPredictionGuider(LinearPredictionGuider): + def __init__( + self, + max_scale: float, + num_frames: int, + min_scale: float = 1.0, + edge_perc: float = 0.1, + additional_cond_keys: Optional[Union[List[str], str]] = None, + ): + super().__init__(max_scale, num_frames, min_scale, additional_cond_keys) + + rise_steps = torch.linspace(min_scale, max_scale, int(num_frames * edge_perc)) + fall_steps = torch.flip(rise_steps, [0]) + self.scale = torch.cat( + [ + rise_steps, + torch.ones(num_frames - 2 * int(num_frames * edge_perc)), + fall_steps, + ] + ).unsqueeze(0) + + +class SpatiotemporalPredictionGuider(LinearPredictionGuider): + def __init__( + self, + max_scale: float, + num_frames: int, + num_views: int = 1, + min_scale: float = 1.0, + additional_cond_keys: Optional[Union[List[str], str]] = None, + ): + super().__init__(max_scale, num_frames, min_scale, additional_cond_keys) + V = num_views + T = num_frames // V + scale = torch.zeros(num_frames).view(T, V) + scale += torch.linspace(0, 1, T)[:,None] * 0.5 + scale += self.triangle_wave(torch.linspace(0, 1, V))[None,:] * 0.5 + scale = scale.flatten() + self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0) + + def triangle_wave(self, values: torch.Tensor, period=1) -> torch.Tensor: + return 2 * (values / period - torch.floor(values / period + 0.5)).abs() \ No newline at end of file diff --git a/sgm/modules/diffusionmodules/openaimodel.py b/sgm/modules/diffusionmodules/openaimodel.py index b58e1b0e..275f4a01 100644 --- a/sgm/modules/diffusionmodules/openaimodel.py +++ b/sgm/modules/diffusionmodules/openaimodel.py @@ -75,20 +75,43 @@ def forward( emb: th.Tensor, context: Optional[th.Tensor] = None, image_only_indicator: Optional[th.Tensor] = None, + cond_view: Optional[th.Tensor] = None, + cond_motion: Optional[th.Tensor] = None, time_context: Optional[int] = None, num_video_frames: Optional[int] = None, + time_step: Optional[int] = None, + name: Optional[str] = None, ): - from ...modules.diffusionmodules.video_model import VideoResBlock + from ...modules.diffusionmodules.video_model import VideoResBlock, PostHocResBlockWithTime + from ...modules.spacetime_attention import ( + BasicTransformerTimeMixBlock, + PostHocSpatialTransformerWithTimeMixing, + PostHocSpatialTransformerWithTimeMixingAndMotion + ) for layer in 
self: module = layer - if isinstance(module, TimestepBlock) and not isinstance( - module, VideoResBlock + if isinstance( + module, + ( + BasicTransformerTimeMixBlock, + PostHocSpatialTransformerWithTimeMixing, + PostHocSpatialTransformerWithTimeMixingAndMotion + ), ): - x = layer(x, emb) - elif isinstance(module, VideoResBlock): - x = layer(x, emb, num_video_frames, image_only_indicator) + x = layer( + x, + context, + # cam, + time_context, + num_video_frames, + image_only_indicator, + cond_view, + cond_motion, + time_step, + name, + ) elif isinstance(module, SpatialVideoTransformer): x = layer( x, @@ -96,7 +119,16 @@ def forward( time_context, num_video_frames, image_only_indicator, + # time_step, ) + elif isinstance(module, PostHocResBlockWithTime): + x = layer(x, emb, num_video_frames, image_only_indicator) + elif isinstance(module, VideoResBlock): + x = layer(x, emb, num_video_frames, image_only_indicator) + elif isinstance(module, TimestepBlock) and not isinstance( + module, VideoResBlock + ): + x = layer(x, emb) elif isinstance(module, SpatialTransformer): x = layer(x, context) else: diff --git a/sgm/modules/diffusionmodules/sigma_sampling.py b/sgm/modules/diffusionmodules/sigma_sampling.py index d54724c6..e327e851 100644 --- a/sgm/modules/diffusionmodules/sigma_sampling.py +++ b/sgm/modules/diffusionmodules/sigma_sampling.py @@ -1,5 +1,5 @@ import torch - +from typing import Optional, Union from ...util import default, instantiate_from_config @@ -29,3 +29,10 @@ def __call__(self, n_samples, rand=None): torch.randint(0, self.num_idx, (n_samples,)), ) return self.idx_to_sigma(idx) + + +class ZeroSampler: + def __call__( + self, n_samples: int, rand: Optional[torch.Tensor] = None + ) -> torch.Tensor: + return torch.zeros_like(default(rand, torch.randn((n_samples,)))) + 1.0e-5 diff --git a/sgm/modules/diffusionmodules/util.py b/sgm/modules/diffusionmodules/util.py index 389f0e44..1d06ce8f 100644 --- a/sgm/modules/diffusionmodules/util.py +++ b/sgm/modules/diffusionmodules/util.py @@ -17,6 +17,36 @@ from einops import rearrange, repeat +def get_alpha( + merge_strategy: str, + mix_factor: Optional[torch.Tensor], + image_only_indicator: torch.Tensor, + apply_sigmoid: bool = True, + is_attn: bool = False, +) -> torch.Tensor: + if merge_strategy == "fixed" or merge_strategy == "learned": + alpha = mix_factor + elif merge_strategy == "learned_with_images": + alpha = torch.where( + image_only_indicator.bool(), + torch.ones(1, 1, device=image_only_indicator.device), + rearrange(mix_factor, "... -> ... 
1"), + ) + if is_attn: + alpha = rearrange(alpha, "b t -> (b t) 1 1") + else: + alpha = rearrange(alpha, "b t -> b 1 t 1 1") + elif merge_strategy == "fixed_with_images": + alpha = image_only_indicator + if is_attn: + alpha = rearrange(alpha, "b t -> (b t) 1 1") + else: + alpha = rearrange(alpha, "b t -> b 1 t 1 1") + else: + raise NotImplementedError + return torch.sigmoid(alpha) if apply_sigmoid else alpha + + def make_beta_schedule( schedule, n_timestep, diff --git a/sgm/modules/diffusionmodules/video_model.py b/sgm/modules/diffusionmodules/video_model.py index ff2d077c..64b9844a 100644 --- a/sgm/modules/diffusionmodules/video_model.py +++ b/sgm/modules/diffusionmodules/video_model.py @@ -5,8 +5,13 @@ from ...modules.diffusionmodules.openaimodel import * from ...modules.video_attention import SpatialVideoTransformer +from ...modules.spacetime_attention import ( + BasicTransformerTimeMixBlock, + PostHocSpatialTransformerWithTimeMixing, + PostHocSpatialTransformerWithTimeMixingAndMotion +) from ...util import default -from .util import AlphaBlender +from .util import AlphaBlender # , LegacyAlphaBlenderWithBug, get_alpha class VideoResBlock(ResBlock): @@ -491,3 +496,746 @@ def forward( ) h = h.type(x.dtype) return self.out(h) + + +class PostHocAttentionBlockWithTimeMixing(AttentionBlock): + def __init__( + self, + in_channels: int, + n_heads: int, + d_head: int, + use_checkpoint: bool = False, + use_new_attention_order: bool = False, + dropout: float = 0.0, + use_spatial_context: bool = False, + merge_strategy: bool = "fixed", + merge_factor: float = 0.5, + apply_sigmoid_to_merge: bool = True, + ff_in: bool = False, + attn_mode: str = "softmax", + disable_temporal_crossattention: bool = False, + ): + super().__init__( + in_channels, + n_heads, + d_head, + use_checkpoint=use_checkpoint, + use_new_attention_order=use_new_attention_order, + ) + inner_dim = n_heads * d_head + + self.time_mix_blocks = nn.ModuleList( + [ + BasicTransformerTimeMixBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + checkpoint=use_checkpoint, + ff_in=ff_in, + attn_mode=attn_mode, + disable_temporal_crossattention=disable_temporal_crossattention, + ) + ] + ) + self.in_channels = in_channels + + time_embed_dim = self.in_channels * 4 + self.time_mix_time_embed = nn.Sequential( + linear(self.in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, self.in_channels), + ) + + self.use_spatial_context = use_spatial_context + + if merge_strategy == "fixed": + self.register_buffer("mix_factor", th.Tensor([merge_factor])) + elif merge_strategy == "learned" or merge_strategy == "learned_with_images": + self.register_parameter( + "mix_factor", th.nn.Parameter(th.Tensor([merge_factor])) + ) + elif merge_strategy == "fixed_with_images": + self.mix_factor = None + else: + raise ValueError(f"unknown merge strategy {merge_strategy}") + + self.get_alpha_fn = functools.partial( + get_alpha, + merge_strategy, + self.mix_factor, + apply_sigmoid=apply_sigmoid_to_merge, + ) + + def forward( + self, + x: th.Tensor, + context: Optional[th.Tensor] = None, + # cam: Optional[th.Tensor] = None, + time_context: Optional[th.Tensor] = None, + timesteps: Optional[int] = None, + image_only_indicator: Optional[th.Tensor] = None, + conv_view: Optional[th.Tensor] = None, + conv_motion: Optional[th.Tensor] = None, + ): + if time_context is not None: + raise NotImplementedError + + _, _, h, w = x.shape + if exists(context): + context = rearrange(context, "b t ... 
-> (b t) ...") + if self.use_spatial_context: + time_context = repeat(context[:, 0], "b ... -> (b n) ...", n=h * w) + + x = super().forward( + x, + ) + + x = rearrange(x, "b c h w -> b (h w) c") + x_mix = x + + num_frames = th.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) + emb = self.time_mix_time_embed(t_emb) + emb = emb[:, None, :] + x_mix = x_mix + emb + + x_mix = self.time_mix_blocks[0]( + x_mix, context=time_context, timesteps=timesteps + ) + + alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator) + x = alpha * x + (1.0 - alpha) * x_mix + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + return x + + +class PostHocResBlockWithTime(ResBlock): + def __init__( + self, + channels: int, + emb_channels: int, + dropout: float, + time_kernel_size: Union[int, List[int]] = 3, + merge_strategy: bool = "fixed", + merge_factor: float = 0.5, + apply_sigmoid_to_merge: bool = True, + out_channels: Optional[int] = None, + use_conv: bool = False, + use_scale_shift_norm: bool = False, + dims: int = 2, + use_checkpoint: bool = False, + up: bool = False, + down: bool = False, + time_mix_legacy: bool = True, + replicate_bug: bool = False, + ): + super().__init__( + channels, + emb_channels, + dropout, + out_channels=out_channels, + use_conv=use_conv, + use_scale_shift_norm=use_scale_shift_norm, + dims=dims, + use_checkpoint=use_checkpoint, + up=up, + down=down, + ) + + self.time_mix_blocks = ResBlock( + default(out_channels, channels), + emb_channels, + dropout=dropout, + dims=3, + out_channels=default(out_channels, channels), + use_scale_shift_norm=False, + use_conv=False, + up=False, + down=False, + kernel_size=time_kernel_size, + use_checkpoint=use_checkpoint, + exchange_temb_dims=True, + ) + self.time_mix_legacy = time_mix_legacy + if self.time_mix_legacy: + if merge_strategy == "fixed": + self.register_buffer("mix_factor", th.Tensor([merge_factor])) + elif merge_strategy == "learned" or merge_strategy == "learned_with_images": + self.register_parameter( + "mix_factor", th.nn.Parameter(th.Tensor([merge_factor])) + ) + elif merge_strategy == "fixed_with_images": + self.mix_factor = None + else: + raise ValueError(f"unknown merge strategy {merge_strategy}") + + self.get_alpha_fn = functools.partial( + get_alpha, + merge_strategy, + self.mix_factor, + apply_sigmoid=apply_sigmoid_to_merge, + ) + else: + if False: # replicate_bug: + logpy.warning( + "*****************************************************************************************\n" + "GRAVE WARNING: YOU'RE USING THE BUGGY LEGACY ALPHABLENDER!!! 
ARE YOU SURE YOU WANT THIS?!\n" + "*****************************************************************************************" + ) + self.time_mixer = LegacyAlphaBlenderWithBug( + alpha=merge_factor, + merge_strategy=merge_strategy, + rearrange_pattern="b t -> b 1 t 1 1", + ) + else: + self.time_mixer = AlphaBlender( + alpha=merge_factor, + merge_strategy=merge_strategy, + rearrange_pattern="b t -> b 1 t 1 1", + ) + + def forward( + self, + x: th.Tensor, + emb: th.Tensor, + num_video_frames: int, + image_only_indicator: Optional[th.Tensor] = None, + cond_view: Optional[th.Tensor] = None, + cond_motion: Optional[th.Tensor] = None, + ) -> th.Tensor: + x = super().forward(x, emb) + + x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames) + x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames) + + x = self.time_mix_blocks( + x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames) + ) + + if self.time_mix_legacy: + alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator) + x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix + else: + x = self.time_mixer( + x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator + ) + x = rearrange(x, "b c t h w -> (b t) c h w") + return x + + +class SpatialUNetModelWithTime(nn.Module): + def __init__( + self, + in_channels: int, + model_channels: int, + out_channels: int, + num_res_blocks: int, + attention_resolutions: int, + dropout: float = 0.0, + channel_mult: List[int] = (1, 2, 4, 8), + conv_resample: bool = True, + dims: int = 2, + num_classes: Optional[int] = None, + use_checkpoint: bool = False, + num_heads: int = -1, + num_head_channels: int = -1, + num_heads_upsample: int = -1, + use_scale_shift_norm: bool = False, + resblock_updown: bool = False, + use_new_attention_order: bool = False, + use_spatial_transformer: bool = False, + transformer_depth: Union[List[int], int] = 1, + transformer_depth_middle: Optional[int] = None, + context_dim: Optional[int] = None, + time_downup: bool = False, + time_context_dim: Optional[int] = None, + extra_ff_mix_layer: bool = False, + use_spatial_context: bool = False, + time_block_merge_strategy: str = "fixed", + time_block_merge_factor: float = 0.5, + spatial_transformer_attn_type: str = "softmax", + time_kernel_size: Union[int, List[int]] = 3, + use_linear_in_transformer: bool = False, + legacy: bool = True, + adm_in_channels: Optional[int] = None, + use_temporal_resblock: bool = True, + disable_temporal_crossattention: bool = False, + time_mix_legacy: bool = True, + max_ddpm_temb_period: int = 10000, + replicate_time_mix_bug: bool = False, + use_motion_attention: bool = False, + ): + super().__init__() + + if use_spatial_transformer: + assert context_dim is not None + + if context_dim is not None: + assert use_spatial_transformer + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1 + + if num_head_channels == -1: + assert num_heads != -1 + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + transformer_depth_middle = default( + transformer_depth_middle, transformer_depth[-1] + ) + + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = 
use_checkpoint + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.use_temporal_resblocks = use_temporal_resblock + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "timestep": + self.label_emb = nn.Sequential( + Timestep(model_channels), + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ), + ) + + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + + def get_attention_layer( + ch, + num_heads, + dim_head, + depth=1, + context_dim=None, + use_checkpoint=False, + disabled_sa=False, + ): + if not use_spatial_transformer: + return PostHocAttentionBlockWithTimeMixing( + ch, + num_heads, + dim_head, + use_checkpoint=use_checkpoint, + use_new_attention_order=use_new_attention_order, + dropout=dropout, + ff_in=extra_ff_mix_layer, + use_spatial_context=use_spatial_context, + merge_strategy=time_block_merge_strategy, + merge_factor=time_block_merge_factor, + attn_mode=spatial_transformer_attn_type, + disable_temporal_crossattention=disable_temporal_crossattention, + ) + + elif use_motion_attention: + return PostHocSpatialTransformerWithTimeMixingAndMotion( + ch, + num_heads, + dim_head, + depth=depth, + context_dim=context_dim, + time_context_dim=time_context_dim, + dropout=dropout, + ff_in=extra_ff_mix_layer, + use_spatial_context=use_spatial_context, + merge_strategy=time_block_merge_strategy, + merge_factor=time_block_merge_factor, + checkpoint=use_checkpoint, + use_linear=use_linear_in_transformer, + attn_mode=spatial_transformer_attn_type, + disable_self_attn=disabled_sa, + disable_temporal_crossattention=disable_temporal_crossattention, + time_mix_legacy=time_mix_legacy, + max_time_embed_period=max_ddpm_temb_period, + ) + + else: + return PostHocSpatialTransformerWithTimeMixing( + ch, + num_heads, + dim_head, + depth=depth, + context_dim=context_dim, + time_context_dim=time_context_dim, + dropout=dropout, + ff_in=extra_ff_mix_layer, + use_spatial_context=use_spatial_context, + merge_strategy=time_block_merge_strategy, + merge_factor=time_block_merge_factor, + checkpoint=use_checkpoint, + use_linear=use_linear_in_transformer, + attn_mode=spatial_transformer_attn_type, + disable_self_attn=disabled_sa, + disable_temporal_crossattention=disable_temporal_crossattention, + time_mix_legacy=time_mix_legacy, + max_time_embed_period=max_ddpm_temb_period, + ) + + def get_resblock( + time_block_merge_factor, + time_block_merge_strategy, + time_kernel_size, + ch, + time_embed_dim, + dropout, + out_ch, + dims, + use_checkpoint, + use_scale_shift_norm, + down=False, + up=False, + ): + if 
self.use_temporal_resblocks: + return PostHocResBlockWithTime( + merge_factor=time_block_merge_factor, + merge_strategy=time_block_merge_strategy, + time_kernel_size=time_kernel_size, + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=down, + up=up, + time_mix_legacy=time_mix_legacy, + replicate_bug=replicate_time_mix_bug, + ) + else: + return ResBlock( + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=out_ch, + use_checkpoint=use_checkpoint, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + down=down, + up=up, + ) + + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + get_resblock( + time_block_merge_factor=time_block_merge_factor, + time_block_merge_strategy=time_block_merge_strategy, + time_kernel_size=time_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_ch=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + + layers.append( + get_attention_layer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + use_checkpoint=use_checkpoint, + disabled_sa=False, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + ds *= 2 + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + get_resblock( + time_block_merge_factor=time_block_merge_factor, + time_block_merge_strategy=time_block_merge_strategy, + time_kernel_size=time_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_ch=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, + conv_resample, + dims=dims, + out_channels=out_ch, + third_down=time_downup, + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + + self.middle_block = TimestepEmbedSequential( + get_resblock( + time_block_merge_factor=time_block_merge_factor, + time_block_merge_strategy=time_block_merge_strategy, + time_kernel_size=time_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + out_ch=None, + dropout=dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + get_attention_layer( + ch, + num_heads, + dim_head, + depth=transformer_depth_middle, + context_dim=context_dim, + use_checkpoint=use_checkpoint, + ), + get_resblock( + time_block_merge_factor=time_block_merge_factor, + time_block_merge_strategy=time_block_merge_strategy, + time_kernel_size=time_kernel_size, + ch=ch, + out_ch=None, + time_embed_dim=time_embed_dim, + dropout=dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + 
self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + get_resblock( + time_block_merge_factor=time_block_merge_factor, + time_block_merge_strategy=time_block_merge_strategy, + time_kernel_size=time_kernel_size, + ch=ch + ich, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_ch=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + + layers.append( + get_attention_layer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + use_checkpoint=use_checkpoint, + disabled_sa=False, + ) + ) + if level and i == num_res_blocks: + out_ch = ch + ds //= 2 + layers.append( + get_resblock( + time_block_merge_factor=time_block_merge_factor, + time_block_merge_strategy=time_block_merge_strategy, + time_kernel_size=time_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_ch=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample( + ch, + conv_resample, + dims=dims, + out_channels=out_ch, + third_up=time_downup, + ) + ) + + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + + def forward( + self, + x: th.Tensor, + timesteps: th.Tensor, + context: Optional[th.Tensor] = None, + y: Optional[th.Tensor] = None, + # cam: Optional[th.Tensor] = None, + time_context: Optional[th.Tensor] = None, + num_video_frames: Optional[int] = None, + image_only_indicator: Optional[th.Tensor] = None, + cond_view: Optional[th.Tensor] = None, + cond_motion: Optional[th.Tensor] = None, + time_step: Optional[int] = None, + ): + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional -> no, relax this TODO" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # 21 x 320 + emb = self.time_embed(t_emb) # 21 x 1280 + time = str(timesteps[0].data.cpu().numpy()) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) # 21 x 1280 + + h = x # 21 x 8 x 64 x 64 + for i, module in enumerate(self.input_blocks): + h = module( + h, + emb, + context=context, + # cam=cam, + image_only_indicator=image_only_indicator, + cond_view=cond_view, + cond_motion=cond_motion, + time_context=time_context, + num_video_frames=num_video_frames, + time_step=time_step, + name='encoder_{}_{}'.format(time, i) + ) + hs.append(h) + h = self.middle_block( + h, + emb, + context=context, + # cam=cam, + image_only_indicator=image_only_indicator, + cond_view=cond_view, + cond_motion=cond_motion, + time_context=time_context, + num_video_frames=num_video_frames, + time_step=time_step, + name='middle_{}_0'.format(time, i) + ) + for i, module in enumerate(self.output_blocks): + h = th.cat([h, hs.pop()], dim=1) + h = module( + h, + emb, + context=context, + # 
cam=cam, + image_only_indicator=image_only_indicator, + cond_view=cond_view, + cond_motion=cond_motion, + time_context=time_context, + num_video_frames=num_video_frames, + time_step=time_step, + name='decoder_{}_{}'.format(time, i) + ) + h = h.type(x.dtype) + return self.out(h) diff --git a/sgm/modules/diffusionmodules/wrappers.py b/sgm/modules/diffusionmodules/wrappers.py index 37449ea6..23c7d073 100644 --- a/sgm/modules/diffusionmodules/wrappers.py +++ b/sgm/modules/diffusionmodules/wrappers.py @@ -25,10 +25,21 @@ def forward( self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs ) -> torch.Tensor: x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) - return self.diffusion_model( - x, - timesteps=t, - context=c.get("crossattn", None), - y=c.get("vector", None), - **kwargs, - ) + if "cond_view" in c: + return self.diffusion_model( + x, + timesteps=t, + context=c.get("crossattn", None), + y=c.get("vector", None), + cond_view=c.get("cond_view", None), + cond_motion=c.get("cond_motion", None), + **kwargs, + ) + else: + return self.diffusion_model( + x, + timesteps=t, + context=c.get("crossattn", None), + y=c.get("vector", None), + **kwargs, + ) diff --git a/sgm/modules/encoders/modules.py b/sgm/modules/encoders/modules.py index d77b8ed7..48bd5ea8 100644 --- a/sgm/modules/encoders/modules.py +++ b/sgm/modules/encoders/modules.py @@ -69,8 +69,8 @@ def input_key(self): class GeneralConditioner(nn.Module): - OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"} - KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1} + OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat"} # , 5: "concat"} + KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1, "cond_view": 1, "cond_motion": 1} def __init__(self, emb_models: Union[List, ListConfig]): super().__init__() @@ -138,7 +138,11 @@ def forward( if not isinstance(emb_out, (list, tuple)): emb_out = [emb_out] for emb in emb_out: - out_key = self.OUTPUT_DIM2KEYS[emb.dim()] + if embedder.input_key in ["cond_view", "cond_motion"]: + out_key = embedder.input_key + else: + out_key = self.OUTPUT_DIM2KEYS[emb.dim()] + if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None: emb = ( expand_dims_like( @@ -994,7 +998,10 @@ def forward( sigmas = self.sigma_sampler(b).to(vid.device) if self.sigma_cond is not None: sigma_cond = self.sigma_cond(sigmas) - sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies) + if self.n_cond_frames == 1: + sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies) + else: + sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_cond_frames) # For SV4D sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames) noise = torch.randn_like(vid) vid = vid + noise * append_dims(sigmas, vid.ndim) @@ -1017,8 +1024,9 @@ def forward( vid = torch.cat(all_out, dim=0) vid *= self.scale_factor - vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames) - vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies) + if self.n_cond_frames == 1: + vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames) + vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies) return_val = (vid, sigma_cond) if self.sigma_cond is not None else vid diff --git a/sgm/modules/spacetime_attention.py b/sgm/modules/spacetime_attention.py new file mode 100644 index 00000000..c604c1b8 --- /dev/null +++ b/sgm/modules/spacetime_attention.py @@ -0,0 +1,596 @@ +from functools import partial + +import torch + +from ..modules.attention import * +from 
..modules.diffusionmodules.util import ( + AlphaBlender, + get_alpha, + linear, + mixed_checkpoint, + timestep_embedding, +) + + +class TimeMixSequential(nn.Sequential): + def forward(self, x, context=None, timesteps=None): + for layer in self: + x = layer(x, context, timesteps) + + return x + + +class BasicTransformerTimeMixBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, + "softmax-xformers": MemoryEfficientCrossAttention, + } + + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + timesteps=None, + ff_in=False, + inner_dim=None, + attn_mode="softmax", + disable_self_attn=False, + disable_temporal_crossattention=False, + switch_temporal_ca_to_sa=False, + ): + super().__init__() + + attn_cls = self.ATTENTION_MODES[attn_mode] + + self.ff_in = ff_in or inner_dim is not None + if inner_dim is None: + inner_dim = dim + + assert int(n_heads * d_head) == inner_dim + + self.is_res = inner_dim == dim + + if self.ff_in: + self.norm_in = nn.LayerNorm(dim) + self.ff_in = FeedForward( + dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff + ) + + self.timesteps = timesteps + self.disable_self_attn = disable_self_attn + if self.disable_self_attn: + self.attn1 = attn_cls( + query_dim=inner_dim, + heads=n_heads, + dim_head=d_head, + context_dim=context_dim, + dropout=dropout, + ) # is a cross-attention + else: + self.attn1 = attn_cls( + query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + + self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff) + + if disable_temporal_crossattention: + if switch_temporal_ca_to_sa: + raise ValueError + else: + self.attn2 = None + else: + self.norm2 = nn.LayerNorm(inner_dim) + if switch_temporal_ca_to_sa: + self.attn2 = attn_cls( + query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + else: + self.attn2 = attn_cls( + query_dim=inner_dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + + self.norm1 = nn.LayerNorm(inner_dim) + self.norm3 = nn.LayerNorm(inner_dim) + self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa + + self.checkpoint = checkpoint + if self.checkpoint: + logpy.info(f"{self.__class__.__name__} is using checkpointing") + + def forward( + self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None + ) -> torch.Tensor: + if self.checkpoint: + return checkpoint(self._forward, x, context, timesteps) + else: + return self._forward(x, context, timesteps=timesteps) + + def _forward(self, x, context=None, timesteps=None): + assert self.timesteps or timesteps + assert not (self.timesteps and timesteps) or self.timesteps == timesteps + timesteps = self.timesteps or timesteps + B, S, C = x.shape + x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps) + + if self.ff_in: + x_skip = x + x = self.ff_in(self.norm_in(x)) + if self.is_res: + x += x_skip + + if self.disable_self_attn: + x = self.attn1(self.norm1(x), context=context) + x + else: + x = self.attn1(self.norm1(x)) + x + + if self.attn2 is not None: + if self.switch_temporal_ca_to_sa: + x = self.attn2(self.norm2(x)) + x + else: + x = self.attn2(self.norm2(x), context=context) + x + x_skip = x + x = self.ff(self.norm3(x)) + if self.is_res: + x += x_skip + + x = rearrange( + x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps + ) + return x + + def get_last_layer(self): + return self.ff.net[-1].weight + + 
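A quick sanity check for `BasicTransformerTimeMixBlock` is a shape round-trip: it takes tokens laid out as `(b*t, s, c)`, folds them to `(b*s, t, c)` so attention runs across the `t` frames (or views), and restores the original layout. The sizes below are purely illustrative, assuming the `sgm` package from this repo is importable:

```python
import torch
from sgm.modules.spacetime_attention import BasicTransformerTimeMixBlock

# Illustrative sizes: b=1 sample, t=5 frames, s=16 spatial tokens, c=320 channels.
b, t, s, c = 1, 5, 16, 320
block = BasicTransformerTimeMixBlock(
    dim=c, n_heads=8, d_head=40,   # n_heads * d_head must equal dim
    attn_mode="softmax", checkpoint=False,
)
x = torch.randn(b * t, s, c)       # layout produced by the spatial transformer blocks
y = block(x, timesteps=t)          # internally rearranged to (b*s, t, c) for temporal attention
assert y.shape == x.shape          # the (b*t, s, c) layout is restored on output
```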
+class PostHocSpatialTransformerWithTimeMixing(SpatialTransformer): + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + use_linear=False, + context_dim=None, + use_spatial_context=False, + timesteps=None, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + apply_sigmoid_to_merge: bool = True, + time_context_dim=None, + ff_in=False, + checkpoint=False, + time_depth=1, + attn_mode="softmax", + disable_self_attn=False, + disable_temporal_crossattention=False, + time_mix_legacy: bool = True, + max_time_embed_period: int = 10000, + ): + super().__init__( + in_channels, + n_heads, + d_head, + depth=depth, + dropout=dropout, + attn_type=attn_mode, + use_checkpoint=checkpoint, + context_dim=context_dim, + use_linear=use_linear, + disable_self_attn=disable_self_attn, + ) + self.time_depth = time_depth + self.depth = depth + self.max_time_embed_period = max_time_embed_period + + time_mix_d_head = d_head + n_time_mix_heads = n_heads + + time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads) + + inner_dim = n_heads * d_head + if use_spatial_context: + time_context_dim = context_dim + + self.time_mix_blocks = nn.ModuleList( + [ + BasicTransformerTimeMixBlock( + inner_dim, + n_time_mix_heads, + time_mix_d_head, + dropout=dropout, + context_dim=time_context_dim, + timesteps=timesteps, + checkpoint=checkpoint, + ff_in=ff_in, + inner_dim=time_mix_inner_dim, + attn_mode=attn_mode, + disable_self_attn=disable_self_attn, + disable_temporal_crossattention=disable_temporal_crossattention, + ) + for _ in range(self.depth) + ] + ) + + assert len(self.time_mix_blocks) == len(self.transformer_blocks) + + self.use_spatial_context = use_spatial_context + self.in_channels = in_channels + + time_embed_dim = self.in_channels * 4 + self.time_mix_time_embed = nn.Sequential( + linear(self.in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, self.in_channels), + ) + + self.time_mix_legacy = time_mix_legacy + if self.time_mix_legacy: + if merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([merge_factor])) + elif merge_strategy == "learned" or merge_strategy == "learned_with_images": + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([merge_factor])) + ) + elif merge_strategy == "fixed_with_images": + self.mix_factor = None + else: + raise ValueError(f"unknown merge strategy {merge_strategy}") + + self.get_alpha_fn = partial( + get_alpha, + merge_strategy, + self.mix_factor, + apply_sigmoid=apply_sigmoid_to_merge, + is_attn=True, + ) + else: + self.time_mixer = AlphaBlender( + alpha=merge_factor, merge_strategy=merge_strategy + ) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + # cam: Optional[torch.Tensor] = None, + time_context: Optional[torch.Tensor] = None, + timesteps: Optional[int] = None, + image_only_indicator: Optional[torch.Tensor] = None, + cond_view: Optional[torch.Tensor] = None, + cond_motion: Optional[torch.Tensor] = None, + time_step: Optional[int] = None, + name: Optional[str] = None, + ) -> torch.Tensor: + _, _, h, w = x.shape + x_in = x + spatial_context = None + if exists(context): + spatial_context = context + + if self.use_spatial_context: + assert ( + context.ndim == 3 + ), f"n dims of spatial context should be 3 but are {context.ndim}" + + time_context = context + time_context_first_timestep = time_context[::timesteps] + time_context = repeat( + time_context_first_timestep, "b ... 
-> (b n) ...", n=h * w + ) + elif time_context is not None and not self.use_spatial_context: + time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w) + if time_context.ndim == 2: + time_context = rearrange(time_context, "b c -> b 1 c") + + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + if self.use_linear: + x = self.proj_in(x) + + if self.time_mix_legacy: + alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator) + + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding( + num_frames, + self.in_channels, + repeat_only=False, + max_period=self.max_time_embed_period, + ) + emb = self.time_mix_time_embed(t_emb) + emb = emb[:, None, :] + + for it_, (block, mix_block) in enumerate( + zip(self.transformer_blocks, self.time_mix_blocks) + ): + # spatial attention + x = block( + x, + context=spatial_context, + time_step=time_step, + name=name + '_' + str(it_) + ) + + x_mix = x + x_mix = x_mix + emb + + # temporal attention + x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps) + if self.time_mix_legacy: + x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix + else: + x = self.time_mixer( + x_spatial=x, + x_temporal=x_mix, + image_only_indicator=image_only_indicator, + ) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + if not self.use_linear: + x = self.proj_out(x) + out = x + x_in + return out + + +class PostHocSpatialTransformerWithTimeMixingAndMotion(SpatialTransformer): + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + use_linear=False, + context_dim=None, + use_spatial_context=False, + timesteps=None, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + apply_sigmoid_to_merge: bool = True, + time_context_dim=None, + ff_in=False, + checkpoint=False, + time_depth=1, + attn_mode="softmax", + disable_self_attn=False, + disable_temporal_crossattention=False, + time_mix_legacy: bool = True, + max_time_embed_period: int = 10000, + ): + super().__init__( + in_channels, + n_heads, + d_head, + depth=depth, + dropout=dropout, + attn_type=attn_mode, + use_checkpoint=checkpoint, + context_dim=context_dim, + use_linear=use_linear, + disable_self_attn=disable_self_attn, + ) + self.time_depth = time_depth + self.depth = depth + self.max_time_embed_period = max_time_embed_period + + time_mix_d_head = d_head + n_time_mix_heads = n_heads + + time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads) + + inner_dim = n_heads * d_head + if use_spatial_context: + time_context_dim = context_dim + + camera_context_dim = time_context_dim + motion_context_dim = 4 # time_context_dim + + # Camera attention layer + self.time_mix_blocks = nn.ModuleList( + [ + BasicTransformerTimeMixBlock( + inner_dim, + n_time_mix_heads, + time_mix_d_head, + dropout=dropout, + context_dim=camera_context_dim, + timesteps=timesteps, + checkpoint=checkpoint, + ff_in=ff_in, + inner_dim=time_mix_inner_dim, + attn_mode=attn_mode, + disable_self_attn=disable_self_attn, + disable_temporal_crossattention=disable_temporal_crossattention, + ) + for _ in range(self.depth) + ] + ) + + # Motion attention layer + self.motion_blocks = nn.ModuleList( + [ + BasicTransformerTimeMixBlock( + inner_dim, + n_time_mix_heads, + time_mix_d_head, + dropout=dropout, + context_dim=motion_context_dim, + 
timesteps=timesteps, + checkpoint=checkpoint, + ff_in=ff_in, + inner_dim=time_mix_inner_dim, + attn_mode=attn_mode, + disable_self_attn=disable_self_attn, + disable_temporal_crossattention=disable_temporal_crossattention, + ) + for _ in range(self.depth) + ] + ) + + assert len(self.time_mix_blocks) == len(self.transformer_blocks) + + self.use_spatial_context = use_spatial_context + self.in_channels = in_channels + + time_embed_dim = self.in_channels * 4 + # Camera view embedding + self.time_mix_time_embed = nn.Sequential( + linear(self.in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, self.in_channels), + ) + # Motion time embedding + self.time_mix_motion_embed = nn.Sequential( + linear(self.in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, self.in_channels), + ) + + self.time_mix_legacy = time_mix_legacy + if self.time_mix_legacy: + if merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([merge_factor])) + elif merge_strategy == "learned" or merge_strategy == "learned_with_images": + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([merge_factor])) + ) + elif merge_strategy == "fixed_with_images": + self.mix_factor = None + else: + raise ValueError(f"unknown merge strategy {merge_strategy}") + + self.get_alpha_fn = partial( + get_alpha, + merge_strategy, + self.mix_factor, + apply_sigmoid=apply_sigmoid_to_merge, + is_attn=True, + ) + else: + self.time_mixer = AlphaBlender( + alpha=merge_factor, merge_strategy=merge_strategy + ) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + # cam: Optional[torch.Tensor] = None, + time_context: Optional[torch.Tensor] = None, + timesteps: Optional[int] = None, + image_only_indicator: Optional[torch.Tensor] = None, + cond_view: Optional[torch.Tensor] = None, + cond_motion: Optional[torch.Tensor] = None, + time_step: Optional[int] = None, + name: Optional[str] = None, + ) -> torch.Tensor: + _, _, h, w = x.shape + x_in = x + spatial_context = None + if exists(context): + spatial_context = context + + # cond_view: b v 4 h w + # cond_motion: b t 4 h w + b, t, d1 = context.shape # CLIP + v, d2 = cond_view.shape[0]//b, cond_view.shape[1] # VAE + cond_view = torch.nn.functional.interpolate(cond_view, size=(h,w), mode="bilinear") # b*v d h w + spatial_context = context[:,:,None].repeat(1,1,v,1).reshape(-1,1,d1) # (b*t*v) 1 d1 + camera_context = context[:,:,None].repeat(1,1,h*w,1).reshape(-1,1,d1) # (b*t*h*w) 1 d1 + motion_context = cond_view.permute(0,2,3,1).reshape(-1,1,d2) # (b*v*h*w) 1 d2 + + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") # 21 x 4096 x 320 + if self.use_linear: + x = self.proj_in(x) + c = x.shape[-1] + + if self.time_mix_legacy: + alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator) + + num_frames = torch.arange(t, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=b) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding( + num_frames, + self.in_channels, + repeat_only=False, + max_period=self.max_time_embed_period, + ) + emb_time = self.time_mix_motion_embed(t_emb) + emb_time = emb_time[:, None, :] # b*t x 1 x 320 + + num_views = torch.arange(v, device=x.device) + num_views = repeat(num_views, "t -> b t", b=b) + num_views = rearrange(num_views, "b t -> (b t)") + v_emb = timestep_embedding( + num_views, + self.in_channels, + repeat_only=False, + max_period=self.max_time_embed_period, + ) + emb_view = 
self.time_mix_time_embed(v_emb) + emb_view = emb_view[:, None, :] # b*v x 1 x 320 + + for it_, (block, time_block, mot_block) in enumerate( + zip(self.transformer_blocks, self.time_mix_blocks, self.motion_blocks) + ): + # Spatial attention + x = block( + x, + context=spatial_context, + ) + + # Camera attention + x = x.view(b, t, v, h*w, c).permute(0,2,1,3,4).reshape(b*v,-1,c) # b*v t*h*w c + x_mix = x + emb_view + x_mix = time_block(x_mix, context=camera_context, timesteps=v) + if self.time_mix_legacy: + x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix + else: + x = self.time_mixer( + x_spatial=x, + x_temporal=x_mix, + image_only_indicator=image_only_indicator[:,:v], + ) + + # Motion attention + x = x.view(b, v, t, h*w, c).permute(0,2,1,3,4).reshape(b*t,-1,c) # b*t v*h*w c + x_mix = x + emb_time + x_mix = mot_block(x_mix, context=motion_context, timesteps=t) + if self.time_mix_legacy: + x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix + else: + x = self.time_mixer( + x_spatial=x, + x_temporal=x_mix, + image_only_indicator=image_only_indicator[:,:t], + ) + + x = x.view(b, t, v, h*w, c).reshape(-1,h*w,c) # b*t*v h*w c + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + if not self.use_linear: + x = self.proj_out(x) + out = x + x_in + return out
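To follow the tensor layout in `PostHocSpatialTransformerWithTimeMixingAndMotion.forward`, note that the block alternates attention over the `v` camera views and the `t` video frames by regrouping the flattened `(b*t*v)` batch axis between steps. The plain-PyTorch sketch below reproduces only that reshape pattern, with illustrative sizes and no model weights involved:

```python
import torch

# Illustrative sizes: b=1 sample, t=5 frames, v=8 views, 8x8 latent grid, c=320 channels.
b, t, v, h, w, c = 1, 5, 8, 8, 8, 320
x = torch.randn(b * t * v, h * w, c)  # tokens after the spatial attention blocks

# Camera/view attention: group tokens per view so the time-mix block, which folds
# (b t) s c -> (b s) t c with t=v, attends across the v views.
x_view = x.view(b, t, v, h * w, c).permute(0, 2, 1, 3, 4).reshape(b * v, t * h * w, c)

# Motion attention: regroup per frame so the same fold with t=num_frames attends
# across the t video frames instead.
x_time = x_view.view(b, v, t, h * w, c).permute(0, 2, 1, 3, 4).reshape(b * t, v * h * w, c)

# Back to the per-image layout (b*t*v, h*w, c) expected by the surrounding UNet.
x_out = x_time.view(b, t, v, h * w, c).reshape(-1, h * w, c)
assert x_out.shape == x.shape
```

After each of the two attention passes, the alpha blending (`get_alpha` in the legacy path, `AlphaBlender` otherwise) mixes the attended features back into the spatial features, just as in the time-mixing-only variant.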