From 30e4d32e97fb3e97206d22cb72e84420b1f78cbc Mon Sep 17 00:00:00 2001 From: Vikram Voleti Date: Mon, 18 Mar 2024 16:52:31 +0000 Subject: [PATCH] Removes SV3D video_decoder, keeps SV3D image_decoder --- configs/inference/sv3d_p.yaml | 11 +- configs/inference/sv3d_p_image_decoder.yaml | 118 ------ configs/inference/sv3d_u.yaml | 11 +- configs/inference/sv3d_u_image_decoder.yaml | 106 ----- scripts/demo/video_sampling.py | 40 -- scripts/sampling/configs/sv3d_p.yaml | 11 +- .../configs/sv3d_p_image_decoder.yaml | 132 ------- scripts/sampling/configs/sv3d_u.yaml | 11 +- .../configs/sv3d_u_image_decoder.yaml | 120 ------ scripts/sampling/simple_video_sample.py | 34 +- scripts/sampling/simple_video_sample_GSO.py | 374 ++++++++++++++++++ 11 files changed, 400 insertions(+), 568 deletions(-) delete mode 100644 configs/inference/sv3d_p_image_decoder.yaml delete mode 100644 configs/inference/sv3d_u_image_decoder.yaml delete mode 100644 scripts/sampling/configs/sv3d_p_image_decoder.yaml delete mode 100644 scripts/sampling/configs/sv3d_u_image_decoder.yaml create mode 100644 scripts/sampling/simple_video_sample_GSO.py diff --git a/configs/inference/sv3d_p.yaml b/configs/inference/sv3d_p.yaml index 5aac9333..d3781fe5 100644 --- a/configs/inference/sv3d_p.yaml +++ b/configs/inference/sv3d_p.yaml @@ -103,17 +103,16 @@ model: encoder_config: target: torch.nn.Identity decoder_config: - target: sgm.modules.autoencoding.temporal_ae.VideoDecoder + target: sgm.modules.diffusionmodules.model.Decoder params: - attn_type: vanilla + attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 - ch_mult: [1, 2, 4, 4] + ch_mult: [ 1, 2, 4, 4 ] num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - video_kernel_size: [3, 1, 1] \ No newline at end of file + attn_resolutions: [ ] + dropout: 0.0 \ No newline at end of file diff --git a/configs/inference/sv3d_p_image_decoder.yaml b/configs/inference/sv3d_p_image_decoder.yaml deleted file mode 100644 index d3781fe5..00000000 --- a/configs/inference/sv3d_p_image_decoder.yaml +++ /dev/null @@ -1,118 +0,0 @@ -model: - target: sgm.models.diffusion.DiffusionEngine - params: - scale_factor: 0.18215 - disable_first_stage_autocast: True - - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.Denoiser - params: - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise - - network_config: - target: sgm.modules.diffusionmodules.video_model.VideoUNet - params: - adm_in_channels: 1280 - num_classes: sequential - use_checkpoint: True - in_channels: 8 - out_channels: 4 - model_channels: 320 - attention_resolutions: [4, 2, 1] - num_res_blocks: 2 - channel_mult: [1, 2, 4, 4] - num_head_channels: 64 - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - spatial_transformer_attn_type: softmax-xformers - extra_ff_mix_layer: True - use_spatial_context: True - merge_strategy: learned_with_images - video_kernel_size: [3, 1, 1] - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - input_key: cond_frames_without_noise - is_trainable: False - target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder - params: - n_cond_frames: 1 - n_copies: 1 - open_clip_embedding_config: - target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder - params: - freeze: True - - - input_key: cond_frames - is_trainable: False - target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder - params: - disable_encoder_autocast: 
True - n_cond_frames: 1 - n_copies: 1 - is_ae: True - encoder_config: - target: sgm.models.autoencoder.AutoencoderKLModeOnly - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - attn_type: vanilla-xformers - double_z: True - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - - input_key: cond_aug - is_trainable: False - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - - input_key: polars_rad - is_trainable: False - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 512 - - - input_key: azimuths_rad - is_trainable: False - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 512 - - first_stage_config: - target: sgm.models.autoencoder.AutoencodingEngine - params: - loss_config: - target: torch.nn.Identity - regularizer_config: - target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer - encoder_config: - target: torch.nn.Identity - decoder_config: - target: sgm.modules.diffusionmodules.model.Decoder - params: - attn_type: vanilla-xformers - double_z: True - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [ 1, 2, 4, 4 ] - num_res_blocks: 2 - attn_resolutions: [ ] - dropout: 0.0 \ No newline at end of file diff --git a/configs/inference/sv3d_u.yaml b/configs/inference/sv3d_u.yaml index 32de359c..5c48a5ff 100644 --- a/configs/inference/sv3d_u.yaml +++ b/configs/inference/sv3d_u.yaml @@ -91,17 +91,16 @@ model: encoder_config: target: torch.nn.Identity decoder_config: - target: sgm.modules.autoencoding.temporal_ae.VideoDecoder + target: sgm.modules.diffusionmodules.model.Decoder params: - attn_type: vanilla + attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 - ch_mult: [1, 2, 4, 4] + ch_mult: [ 1, 2, 4, 4 ] num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - video_kernel_size: [3, 1, 1] \ No newline at end of file + attn_resolutions: [ ] + dropout: 0.0 \ No newline at end of file diff --git a/configs/inference/sv3d_u_image_decoder.yaml b/configs/inference/sv3d_u_image_decoder.yaml deleted file mode 100644 index 5c48a5ff..00000000 --- a/configs/inference/sv3d_u_image_decoder.yaml +++ /dev/null @@ -1,106 +0,0 @@ -model: - target: sgm.models.diffusion.DiffusionEngine - params: - scale_factor: 0.18215 - disable_first_stage_autocast: True - - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.Denoiser - params: - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise - - network_config: - target: sgm.modules.diffusionmodules.video_model.VideoUNet - params: - adm_in_channels: 256 - num_classes: sequential - use_checkpoint: True - in_channels: 8 - out_channels: 4 - model_channels: 320 - attention_resolutions: [4, 2, 1] - num_res_blocks: 2 - channel_mult: [1, 2, 4, 4] - num_head_channels: 64 - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - spatial_transformer_attn_type: softmax-xformers - extra_ff_mix_layer: True - use_spatial_context: True - merge_strategy: learned_with_images - video_kernel_size: [3, 1, 1] - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - input_key: cond_frames_without_noise - is_trainable: False - target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder - params: - 
n_cond_frames: 1 - n_copies: 1 - open_clip_embedding_config: - target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder - params: - freeze: True - - - input_key: cond_frames - is_trainable: False - target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder - params: - disable_encoder_autocast: True - n_cond_frames: 1 - n_copies: 1 - is_ae: True - encoder_config: - target: sgm.models.autoencoder.AutoencoderKLModeOnly - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - attn_type: vanilla-xformers - double_z: True - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - - input_key: cond_aug - is_trainable: False - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - first_stage_config: - target: sgm.models.autoencoder.AutoencodingEngine - params: - loss_config: - target: torch.nn.Identity - regularizer_config: - target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer - encoder_config: - target: torch.nn.Identity - decoder_config: - target: sgm.modules.diffusionmodules.model.Decoder - params: - attn_type: vanilla-xformers - double_z: True - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [ 1, 2, 4, 4 ] - num_res_blocks: 2 - attn_resolutions: [ ] - dropout: 0.0 \ No newline at end of file diff --git a/scripts/demo/video_sampling.py b/scripts/demo/video_sampling.py index 8425b2a1..1f4fcfc4 100644 --- a/scripts/demo/video_sampling.py +++ b/scripts/demo/video_sampling.py @@ -109,26 +109,6 @@ "decoding_t": 14, }, }, - "sv3d_u_image_decoder": { - "T": 21, - "H": 576, - "W": 576, - "C": 4, - "f": 8, - "config": "configs/inference/sv3d_u_image_decoder.yaml", - "ckpt": "checkpoints/sv3d_u_image_decoder.safetensors", - "options": { - "discretization": 1, - "cfg": 2.5, - "sigma_min": 0.002, - "sigma_max": 700.0, - "rho": 7.0, - "guider": 3, - "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"], - "num_steps": 50, - "decoding_t": 14, - }, - }, "sv3d_p": { "T": 21, "H": 576, @@ -149,26 +129,6 @@ "decoding_t": 14, }, }, - "sv3d_p_image_decoder": { - "T": 21, - "H": 576, - "W": 576, - "C": 4, - "f": 8, - "config": "configs/inference/sv3d_p_image_decoder.yaml", - "ckpt": "checkpoints/sv3d_p_image_decoder.safetensors", - "options": { - "discretization": 1, - "cfg": 2.5, - "sigma_min": 0.002, - "sigma_max": 700.0, - "rho": 7.0, - "guider": 3, - "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"], - "num_steps": 50, - "decoding_t": 14, - }, - }, } diff --git a/scripts/sampling/configs/sv3d_p.yaml b/scripts/sampling/configs/sv3d_p.yaml index 5906ce75..bb3747c7 100644 --- a/scripts/sampling/configs/sv3d_p.yaml +++ b/scripts/sampling/configs/sv3d_p.yaml @@ -3,7 +3,7 @@ model: params: scale_factor: 0.18215 disable_first_stage_autocast: True - ckpt_path: checkpoints/sv3d_p.safetensors + ckpt_path: checkpoints/sv3d_p_image_decoder.safetensors denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser @@ -104,20 +104,19 @@ model: encoder_config: target: torch.nn.Identity decoder_config: - target: sgm.modules.autoencoding.temporal_ae.VideoDecoder + target: sgm.modules.diffusionmodules.model.Decoder params: - attn_type: vanilla + attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 - ch_mult: [1, 2, 4, 4] + ch_mult: [ 1, 2, 4, 4 ] num_res_blocks: 2 - 
attn_resolutions: [] + attn_resolutions: [ ] dropout: 0.0 - video_kernel_size: [3, 1, 1] sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler diff --git a/scripts/sampling/configs/sv3d_p_image_decoder.yaml b/scripts/sampling/configs/sv3d_p_image_decoder.yaml deleted file mode 100644 index bb3747c7..00000000 --- a/scripts/sampling/configs/sv3d_p_image_decoder.yaml +++ /dev/null @@ -1,132 +0,0 @@ -model: - target: sgm.models.diffusion.DiffusionEngine - params: - scale_factor: 0.18215 - disable_first_stage_autocast: True - ckpt_path: checkpoints/sv3d_p_image_decoder.safetensors - - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.Denoiser - params: - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise - - network_config: - target: sgm.modules.diffusionmodules.video_model.VideoUNet - params: - adm_in_channels: 1280 - num_classes: sequential - use_checkpoint: True - in_channels: 8 - out_channels: 4 - model_channels: 320 - attention_resolutions: [4, 2, 1] - num_res_blocks: 2 - channel_mult: [1, 2, 4, 4] - num_head_channels: 64 - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - spatial_transformer_attn_type: softmax-xformers - extra_ff_mix_layer: True - use_spatial_context: True - merge_strategy: learned_with_images - video_kernel_size: [3, 1, 1] - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - input_key: cond_frames_without_noise - is_trainable: False - target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder - params: - n_cond_frames: 1 - n_copies: 1 - open_clip_embedding_config: - target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder - params: - freeze: True - - - input_key: cond_frames - is_trainable: False - target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder - params: - disable_encoder_autocast: True - n_cond_frames: 1 - n_copies: 1 - is_ae: True - encoder_config: - target: sgm.models.autoencoder.AutoencoderKLModeOnly - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - attn_type: vanilla-xformers - double_z: True - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - - input_key: cond_aug - is_trainable: False - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - - input_key: polars_rad - is_trainable: False - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 512 - - - input_key: azimuths_rad - is_trainable: False - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 512 - - first_stage_config: - target: sgm.models.autoencoder.AutoencodingEngine - params: - loss_config: - target: torch.nn.Identity - regularizer_config: - target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer - encoder_config: - target: torch.nn.Identity - decoder_config: - target: sgm.modules.diffusionmodules.model.Decoder - params: - attn_type: vanilla-xformers - double_z: True - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [ 1, 2, 4, 4 ] - num_res_blocks: 2 - attn_resolutions: [ ] - dropout: 0.0 - - sampler_config: - target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler - params: - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization - params: - sigma_max: 700.0 - - 
guider_config: - target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider - params: - max_scale: 2.5 diff --git a/scripts/sampling/configs/sv3d_u.yaml b/scripts/sampling/configs/sv3d_u.yaml index 47b11b1a..8a7ce212 100644 --- a/scripts/sampling/configs/sv3d_u.yaml +++ b/scripts/sampling/configs/sv3d_u.yaml @@ -3,7 +3,7 @@ model: params: scale_factor: 0.18215 disable_first_stage_autocast: True - ckpt_path: checkpoints/sv3d_u.safetensors + ckpt_path: checkpoints/sv3d_u_image_decoder.safetensors denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser @@ -92,20 +92,19 @@ model: encoder_config: target: torch.nn.Identity decoder_config: - target: sgm.modules.autoencoding.temporal_ae.VideoDecoder + target: sgm.modules.diffusionmodules.model.Decoder params: - attn_type: vanilla + attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 - ch_mult: [1, 2, 4, 4] + ch_mult: [ 1, 2, 4, 4 ] num_res_blocks: 2 - attn_resolutions: [] + attn_resolutions: [ ] dropout: 0.0 - video_kernel_size: [3, 1, 1] sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler diff --git a/scripts/sampling/configs/sv3d_u_image_decoder.yaml b/scripts/sampling/configs/sv3d_u_image_decoder.yaml deleted file mode 100644 index 8a7ce212..00000000 --- a/scripts/sampling/configs/sv3d_u_image_decoder.yaml +++ /dev/null @@ -1,120 +0,0 @@ -model: - target: sgm.models.diffusion.DiffusionEngine - params: - scale_factor: 0.18215 - disable_first_stage_autocast: True - ckpt_path: checkpoints/sv3d_u_image_decoder.safetensors - - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.Denoiser - params: - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise - - network_config: - target: sgm.modules.diffusionmodules.video_model.VideoUNet - params: - adm_in_channels: 256 - num_classes: sequential - use_checkpoint: True - in_channels: 8 - out_channels: 4 - model_channels: 320 - attention_resolutions: [4, 2, 1] - num_res_blocks: 2 - channel_mult: [1, 2, 4, 4] - num_head_channels: 64 - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - spatial_transformer_attn_type: softmax-xformers - extra_ff_mix_layer: True - use_spatial_context: True - merge_strategy: learned_with_images - video_kernel_size: [3, 1, 1] - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - is_trainable: False - input_key: cond_frames_without_noise - target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder - params: - n_cond_frames: 1 - n_copies: 1 - open_clip_embedding_config: - target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder - params: - freeze: True - - - input_key: cond_frames - is_trainable: False - target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder - params: - disable_encoder_autocast: True - n_cond_frames: 1 - n_copies: 1 - is_ae: True - encoder_config: - target: sgm.models.autoencoder.AutoencoderKLModeOnly - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - attn_type: vanilla-xformers - double_z: True - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - - input_key: cond_aug - is_trainable: False - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - first_stage_config: - target: 
sgm.models.autoencoder.AutoencodingEngine - params: - loss_config: - target: torch.nn.Identity - regularizer_config: - target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer - encoder_config: - target: torch.nn.Identity - decoder_config: - target: sgm.modules.diffusionmodules.model.Decoder - params: - attn_type: vanilla-xformers - double_z: True - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [ 1, 2, 4, 4 ] - num_res_blocks: 2 - attn_resolutions: [ ] - dropout: 0.0 - - sampler_config: - target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler - params: - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization - params: - sigma_max: 700.0 - - guider_config: - target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider - params: - max_scale: 2.5 diff --git a/scripts/sampling/simple_video_sample.py b/scripts/sampling/simple_video_sample.py index 9de27937..bc64591e 100644 --- a/scripts/sampling/simple_video_sample.py +++ b/scripts/sampling/simple_video_sample.py @@ -51,24 +51,6 @@ def sample( num_steps = default(num_steps, 30) output_folder = default(output_folder, "outputs/simple_video_sample/svd_xt/") model_config = "scripts/sampling/configs/svd_xt.yaml" - elif version == "sv3d_u": - num_frames = 21 - num_steps = default(num_steps, 50) - output_folder = default(output_folder, "outputs/simple_video_sample/sv3d_u/") - model_config = "scripts/sampling/configs/sv3d_u.yaml" - cond_aug = 1e-5 - elif version == "sv3d_p": - num_frames = 21 - num_steps = default(num_steps, 50) - output_folder = default(output_folder, "outputs/simple_video_sample/sv3d_p/") - model_config = "scripts/sampling/configs/sv3d_p.yaml" - cond_aug = 1e-5 - if isinstance(elevations_deg, float) or isinstance(elevations_deg, int): - elevations_deg = [elevations_deg] * num_frames - polars_rad = [np.deg2rad(90 - e) for e in elevations_deg] - if azimuths_deg is None: - azimuths_deg = np.linspace(0, 360, num_frames + 1)[1:] % 360 - azimuths_rad = [np.deg2rad(a) for a in azimuths_deg] elif version == "svd_image_decoder": num_frames = default(num_frames, 14) num_steps = default(num_steps, 25) @@ -83,21 +65,17 @@ def sample( output_folder, "outputs/simple_video_sample/svd_xt_image_decoder/" ) model_config = "scripts/sampling/configs/svd_xt_image_decoder.yaml" - elif version == "sv3d_u_image_decoder": + elif version == "sv3d_u": num_frames = 21 num_steps = default(num_steps, 50) - output_folder = default( - output_folder, "outputs/simple_video_sample/sv3d_u_image_decoder/" - ) - model_config = "scripts/sampling/configs/sv3d_u_image_decoder.yaml" + output_folder = default(output_folder, "outputs/simple_video_sample/sv3d_u/") + model_config = "scripts/sampling/configs/sv3d_u.yaml" cond_aug = 1e-5 - elif version == "sv3d_p_image_decoder": + elif version == "sv3d_p": num_frames = 21 num_steps = default(num_steps, 50) - output_folder = default( - output_folder, "outputs/simple_video_sample/sv3d_p_image_decoder/" - ) - model_config = "scripts/sampling/configs/sv3d_p_image_decoder.yaml" + output_folder = default(output_folder, "outputs/simple_video_sample/sv3d_p/") + model_config = "scripts/sampling/configs/sv3d_p.yaml" cond_aug = 1e-5 if isinstance(elevations_deg, float) or isinstance(elevations_deg, int): elevations_deg = [elevations_deg] * num_frames diff --git a/scripts/sampling/simple_video_sample_GSO.py b/scripts/sampling/simple_video_sample_GSO.py new file mode 100644 index 00000000..455a8c43 --- /dev/null +++ 
b/scripts/sampling/simple_video_sample_GSO.py @@ -0,0 +1,374 @@ +import json +import math +import os +import sys +from glob import glob +from pathlib import Path +from typing import List, Optional + +sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../"))) + +import cv2 +import imageio +import numpy as np +import torch +from einops import rearrange, repeat +from fire import Fire +from omegaconf import OmegaConf +from PIL import Image +from rembg import remove +from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering +from sgm.inference.helpers import embed_watermark +from sgm.util import default, instantiate_from_config +from torchvision.transforms import ToTensor + + +def sample( + input_path: str = "assets/test_image.png", # Can either be image file or folder with image files + num_frames: Optional[int] = None, # 21 for SV3D + num_steps: Optional[int] = None, + version: str = "svd", + fps_id: int = 6, + motion_bucket_id: int = 127, + cond_aug: float = 0.02, + seed: int = 23, + decoding_t: int = 7, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. + device: str = "cuda", + output_folder: Optional[str] = None, + elevations_deg: Optional[float | List[float]] = 10.0, # For SV3D + azimuths_deg: Optional[float | List[float]] = None, # For SV3D + image_frame_ratio: Optional[float] = None, +): + """ + Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each + image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`. + """ + + if version == "svd": + num_frames = default(num_frames, 14) + num_steps = default(num_steps, 25) + output_folder = default(output_folder, "outputs/simple_video_sample/svd/") + model_config = "scripts/sampling/configs/svd.yaml" + elif version == "svd_xt": + num_frames = default(num_frames, 25) + num_steps = default(num_steps, 30) + output_folder = default(output_folder, "outputs/simple_video_sample/svd_xt/") + model_config = "scripts/sampling/configs/svd_xt.yaml" + elif version == "sv3d_u": + num_frames = 21 + num_steps = default(num_steps, 50) + output_folder = default(output_folder, "outputs/simple_video_sample/sv3d_u/") + model_config = "scripts/sampling/configs/sv3d_u.yaml" + cond_aug = 1e-5 + elif version == "sv3d_p": + num_frames = 21 + num_steps = default(num_steps, 50) + output_folder = default(output_folder, "outputs/simple_video_sample/sv3d_p/") + model_config = "scripts/sampling/configs/sv3d_p.yaml" + cond_aug = 1e-5 + if isinstance(elevations_deg, float) or isinstance(elevations_deg, int): + elevations_deg = [elevations_deg] * num_frames + polars_rad = [np.deg2rad(90 - e) for e in elevations_deg] + if azimuths_deg is None: + azimuths_deg = np.linspace(0, 360, num_frames + 1)[1:] % 360 + azimuths_rad = [np.deg2rad(a) for a in azimuths_deg] + elif version == "svd_image_decoder": + num_frames = default(num_frames, 14) + num_steps = default(num_steps, 25) + output_folder = default( + output_folder, "outputs/simple_video_sample/svd_image_decoder/" + ) + model_config = "scripts/sampling/configs/svd_image_decoder.yaml" + elif version == "svd_xt_image_decoder": + num_frames = default(num_frames, 25) + num_steps = default(num_steps, 30) + output_folder = default( + output_folder, "outputs/simple_video_sample/svd_xt_image_decoder/" + ) + model_config = "scripts/sampling/configs/svd_xt_image_decoder.yaml" + elif version == "sv3d_u_image_decoder": + num_frames = 21 + num_steps = default(num_steps, 50) + 
output_folder = default( + output_folder, "outputs/simple_video_sample/sv3d_u_image_decoder/" + ) + model_config = "scripts/sampling/configs/sv3d_u_image_decoder.yaml" + cond_aug = 1e-5 + elif version == "sv3d_p_image_decoder": + num_frames = 21 + num_steps = default(num_steps, 50) + output_folder = default( + output_folder, "outputs/simple_video_sample/sv3d_p_image_decoder/" + ) + model_config = "scripts/sampling/configs/sv3d_p_image_decoder.yaml" + cond_aug = 1e-5 + if isinstance(elevations_deg, float) or isinstance(elevations_deg, int): + elevations_deg = [elevations_deg] * num_frames + polars_rad = [np.deg2rad(90 - e) for e in elevations_deg] + if azimuths_deg is None: + azimuths_deg = np.linspace(0, 360, num_frames + 1)[1:] % 360 + azimuths_rad = [np.deg2rad(a) for a in azimuths_deg] + else: + raise ValueError(f"Version {version} does not exist.") + + model, filter = load_model( + model_config, + device, + num_frames, + num_steps, + ) + torch.manual_seed(seed) + + path = Path(input_path) + all_img_paths = [] + if path.is_file(): + if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]): + all_img_paths = [input_path] + else: + raise ValueError("Path is not valid image file.") + elif path.is_dir(): + all_img_paths = sorted( + [ + f + for f in path.iterdir() + if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"] + ] + ) + if len(all_img_paths) == 0: + raise ValueError("Folder does not contain any images.") + else: + raise ValueError + + for input_img_path in all_img_paths: + if "sv3d" in version: + image = Image.open(input_img_path) + if image.mode == "RGBA": + pass + else: + # remove bg + image.thumbnail([768, 768], Image.Resampling.LANCZOS) + image = remove(image.convert("RGBA"), alpha_matting=True) + + # resize object in frame + image_arr = np.array(image) + in_w, in_h = image_arr.shape[:2] + ret, mask = cv2.threshold( + np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY + ) + x, y, w, h = cv2.boundingRect(mask) + max_size = max(w, h) + side_len = ( + int(max_size / image_frame_ratio) + if image_frame_ratio is not None + else in_w + ) + padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8) + center = side_len // 2 + padded_image[ + center - h // 2 : center - h // 2 + h, + center - w // 2 : center - w // 2 + w, + ] = image_arr[y : y + h, x : x + w] + # resize frame to 576x576 + rgba = Image.fromarray(padded_image).resize((576, 576), Image.LANCZOS) + # white bg + rgba_arr = np.array(rgba) / 255.0 + rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:]) + input_image = Image.fromarray((rgb * 255).astype(np.uint8)) + + else: + with Image.open(input_img_path) as image: + if image.mode == "RGBA": + input_image = image.convert("RGB") + w, h = image.size + + if h % 64 != 0 or w % 64 != 0: + width, height = map(lambda x: x - x % 64, (w, h)) + input_image = input_image.resize((width, height)) + print( + f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!" + ) + + image = ToTensor()(input_image) + image = image * 2.0 - 1.0 + + image = image.unsqueeze(0).to(device) + H, W = image.shape[2:] + assert image.shape[1] == 3 + F = 8 + C = 4 + shape = (num_frames, C, H // F, W // F) + if (H, W) != (576, 1024) and "sv3d" not in version: + print( + "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`." 
+ ) + if (H, W) != (576, 576) and "sv3d" in version: + print( + "WARNING: The conditioning frame you provided is not 576x576. This leads to suboptimal performance as model was only trained on 576x576." + ) + if motion_bucket_id > 255: + print( + "WARNING: High motion bucket! This may lead to suboptimal performance." + ) + + if fps_id < 5: + print("WARNING: Small fps value! This may lead to suboptimal performance.") + + if fps_id > 30: + print("WARNING: Large fps value! This may lead to suboptimal performance.") + + json_files = sorted( + glob(os.path.join(os.path.dirname(input_img_path), "../", "*.json")) + )[:21] + polars_rad, azimuths_rad = [], [] + for json_file in json_files: + with open(json_file, "r") as f: + f_dict = json.load(f) + polars_rad.append(f_dict["polar"]) + azimuths_rad.append(f_dict["azimuth"]) + + azimuths_rad = (np.array(azimuths_rad) - azimuths_rad[-1]) % np.deg2rad(360) + + value_dict = {} + value_dict["cond_frames_without_noise"] = image + value_dict["motion_bucket_id"] = motion_bucket_id + value_dict["fps_id"] = fps_id + value_dict["cond_aug"] = cond_aug + value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image) + if "sv3d_p" in version: + value_dict["polars_rad"] = polars_rad + value_dict["azimuths_rad"] = azimuths_rad + + with torch.no_grad(): + with torch.autocast(device): + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + [1, num_frames], + T=num_frames, + device=device, + ) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=[ + "cond_frames", + "cond_frames_without_noise", + ], + ) + + for k in ["crossattn", "concat"]: + uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames) + uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames) + c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames) + c[k] = rearrange(c[k], "b t ... 
-> (b t) ...", t=num_frames) + + randn = torch.randn(shape, device=device) + + additional_model_inputs = {} + additional_model_inputs["image_only_indicator"] = torch.zeros( + 2, num_frames + ).to(device) + additional_model_inputs["num_video_frames"] = batch["num_video_frames"] + + def denoiser(input, sigma, c): + return model.denoiser( + model.model, input, sigma, c, **additional_model_inputs + ) + + samples_z = model.sampler(denoiser, randn, cond=c, uc=uc) + model.en_and_decode_n_samples_a_time = decoding_t + samples_x = model.decode_first_stage(samples_z) + if "sv3d" in version: + samples_x[-1:] = value_dict["cond_frames_without_noise"] + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + + os.makedirs(output_folder, exist_ok=True) + base_count = len(glob(os.path.join(output_folder, "*.mp4"))) + + imageio.imwrite( + os.path.join(output_folder, f"{base_count:06d}.jpg"), input_image + ) + + samples = embed_watermark(samples) + samples = filter(samples) + vid = ( + (rearrange(samples, "t c h w -> t h w c") * 255) + .cpu() + .numpy() + .astype(np.uint8) + ) + video_path = os.path.join(output_folder, f"{base_count:06d}.mp4") + imageio.mimwrite(video_path, vid) + + +def get_unique_embedder_keys_from_conditioner(conditioner): + return list(set([x.input_key for x in conditioner.embedders])) + + +def get_batch(keys, value_dict, N, T, device): + batch = {} + batch_uc = {} + + for key in keys: + if key == "fps_id": + batch[key] = ( + torch.tensor([value_dict["fps_id"]]) + .to(device) + .repeat(int(math.prod(N))) + ) + elif key == "motion_bucket_id": + batch[key] = ( + torch.tensor([value_dict["motion_bucket_id"]]) + .to(device) + .repeat(int(math.prod(N))) + ) + elif key == "cond_aug": + batch[key] = repeat( + torch.tensor([value_dict["cond_aug"]]).to(device), + "1 -> b", + b=math.prod(N), + ) + elif key == "cond_frames" or key == "cond_frames_without_noise": + batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=N[0]) + elif key == "polars_rad" or key == "azimuths_rad": + batch[key] = torch.tensor(value_dict[key]).to(device).repeat(N[0]) + else: + batch[key] = value_dict[key] + + if T is not None: + batch["num_video_frames"] = T + + for key in batch.keys(): + if key not in batch_uc and isinstance(batch[key], torch.Tensor): + batch_uc[key] = torch.clone(batch[key]) + return batch, batch_uc + + +def load_model( + config: str, + device: str, + num_frames: int, + num_steps: int, +): + config = OmegaConf.load(config) + if device == "cuda": + config.model.params.conditioner_config.params.emb_models[ + 0 + ].params.open_clip_embedding_config.params.init_device = device + + config.model.params.sampler_config.params.num_steps = num_steps + config.model.params.sampler_config.params.guider_config.params.num_frames = ( + num_frames + ) + if device == "cuda": + with torch.device(device): + model = instantiate_from_config(config.model).to(device).eval() + else: + model = instantiate_from_config(config.model).to(device).eval() + + filter = DeepFloydDataFiltering(verbose=False, device=device) + return model, filter + + +if __name__ == "__main__": + Fire(sample)
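
After this patch, both sv3d_u and sv3d_p configs decode latents with the frame-wise sgm.modules.diffusionmodules.model.Decoder (the former *_image_decoder behavior) and load the *_image_decoder.safetensors checkpoints; the temporal VideoDecoder target and its video_kernel_size are gone. A minimal sketch of driving the new GSO sampling entry point, assuming a GSO-style layout in which per-view *.json files carrying "polar" and "azimuth" values (in radians) sit one directory above the input images; the input path below is a hypothetical placeholder, not a path from the patch:

    # Minimal usage sketch; input_path is a hypothetical placeholder.
    # For sv3d_p, the script reads per-view camera poses from the JSON files
    # found one directory above the image, overriding the elevations_deg /
    # azimuths_deg defaults computed earlier in sample().
    from scripts.sampling.simple_video_sample_GSO import sample

    sample(
        input_path="datasets/GSO/alarm_clock/rgba/000.png",  # hypothetical
        version="sv3d_p",  # loads scripts/sampling/configs/sv3d_p.yaml
        decoding_t=7,      # frames decoded per VAE pass; lower this on small GPUs
    )
    # Equivalent Fire CLI (the script ends with Fire(sample)):
    #   python scripts/sampling/simple_video_sample_GSO.py \
    #       --input_path datasets/GSO/alarm_clock/rgba/000.png --version sv3d_p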
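
Design note: unlike simple_video_sample.py, which synthesizes a uniform orbit from elevations_deg/azimuths_deg, the GSO variant reads ground-truth poses from the JSON files, re-zeros the azimuths against the final view ((azimuths - azimuths[-1]) mod 2*pi), and overwrites the last decoded frame with the conditioning image (samples_x[-1:] = cond_frames_without_noise), so the generated orbit ends exactly on the GSO reference view.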