Skip to content

Commit

Permalink
Small fixes : cond_aug
Browse files: browse the repository at this point in the history
  • Loading branch information
Vikram Voleti committed Mar 15, 2024
1 parent b897c1c commit 6b0a47a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 24 deletions.
16 changes: 8 additions & 8 deletions scripts/sampling/configs/sv3d_p_image_decoder.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ model:
params:
scale_factor: 0.18215
disable_first_stage_autocast: True
ckpt_path: checkpoints/sv3d_p.safetensors
ckpt_path: checkpoints/sv3d_p.safetensorsnew2.safetensors

denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
Expand Down Expand Up @@ -48,7 +48,7 @@ model:
params:
freeze: True

- input_key: cond_frames_without_noise
- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
Expand All @@ -75,12 +75,12 @@ model:
dropout: 0.0
lossconfig:
target: torch.nn.Identity
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
sigma_cond_config:
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256

- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256

- input_key: polars_rad
is_trainable: False
Expand Down
18 changes: 9 additions & 9 deletions scripts/sampling/configs/sv3d_u_image_decoder.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ model:
params:
adm_in_channels: 256
num_classes: sequential
use_checkpoint: True
use_checkpoint: False
in_channels: 8
out_channels: 4
model_channels: 320
Expand Down Expand Up @@ -48,7 +48,7 @@ model:
params:
freeze: True

- input_key: cond_frames_without_noise
- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
Expand All @@ -75,12 +75,12 @@ model:
dropout: 0.0
lossconfig:
target: torch.nn.Identity
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
sigma_cond_config:
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256

- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256

first_stage_config:
target: sgm.models.autoencoder.AutoencodingEngine
Expand Down Expand Up @@ -117,4 +117,4 @@ model:
guider_config:
target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider
params:
max_scale: 3.0
max_scale: 2.5
14 changes: 7 additions & 7 deletions scripts/sampling/simple_video_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,15 @@ def sample(
output_folder, "outputs/simple_video_sample/sv3d_u_image_decoder/"
)
model_config = "scripts/sampling/configs/sv3d_u_image_decoder.yaml"
cond_aug = 0.0
cond_aug = 1e-5
elif version == "sv3d_p":
num_frames = 21
num_steps = default(num_steps, 50)
output_folder = default(
output_folder, "outputs/simple_video_sample/sv3d_p_image_decoder/"
)
model_config = "scripts/sampling/configs/sv3d_p_image_decoder.yaml"
cond_aug = 0.0
cond_aug = 1e-5
if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
elevations_deg = [elevations_deg] * num_frames
polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
Expand Down Expand Up @@ -220,14 +220,13 @@ def sample(

value_dict = {}
value_dict["cond_frames_without_noise"] = image
value_dict["motion_bucket_id"] = motion_bucket_id
value_dict["fps_id"] = fps_id
value_dict["cond_aug"] = cond_aug
value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
if version == "sv3d_p":
value_dict["polars_rad"] = polars_rad
value_dict["azimuths_rad"] = azimuths_rad
elif "sv3d" not in version:
value_dict["motion_bucket_id"] = motion_bucket_id
value_dict["fps_id"] = fps_id
value_dict["cond_aug"] = cond_aug
value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)

with torch.no_grad():
with torch.autocast(device):
Expand Down Expand Up @@ -358,6 +357,7 @@ def load_model(
0
].params.open_clip_embedding_config.params.init_device = device

config.model.params.sampler_config.params.verbose = True
config.model.params.sampler_config.params.num_steps = num_steps
config.model.params.sampler_config.params.guider_config.params.num_frames = (
num_frames
Expand Down

0 comments on commit 6b0a47a

Please sign in to comment.