Skip to content

Commit

Permalink
Fixes SV3D checkpoint, fixes rembg
Browse files Browse the repository at this point in the history
  • Loading branch information
Vikram Voleti committed Mar 18, 2024
1 parent 6b0a47a commit 4a9fa78
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 28 deletions.
1 change: 1 addition & 0 deletions requirements/pt2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pillow>=9.5.0
pudb>=2022.1.3
pytorch-lightning==2.0.1
pyyaml>=6.0.1
rembg
scipy>=1.10.1
streamlit>=0.73.1
tensorboardx==2.6
Expand Down
2 changes: 1 addition & 1 deletion scripts/sampling/configs/sv3d_p_image_decoder.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ model:
params:
scale_factor: 0.18215
disable_first_stage_autocast: True
ckpt_path: checkpoints/sv3d_p.safetensorsnew2.safetensors
ckpt_path: checkpoints/sv3d_p.safetensors

denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
Expand Down
2 changes: 1 addition & 1 deletion scripts/sampling/configs/sv3d_u_image_decoder.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ model:
params:
adm_in_channels: 256
num_classes: sequential
use_checkpoint: False
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
Expand Down
24 changes: 3 additions & 21 deletions scripts/sampling/simple_video_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,21 @@

import cv2
import numpy as np
import requests
import torch
from einops import rearrange, repeat
from fire import Fire
from omegaconf import OmegaConf
from PIL import Image
from rembg import remove
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.inference.helpers import embed_watermark
from sgm.util import default, instantiate_from_config
from torchvision.transforms import ToTensor


def remove_bg_stable_PIL(PIL_img, MYAPIKEY=""):
    """Remove the background of an image via a remote HTTP endpoint.

    The image is PNG-encoded in memory and POSTed to the remove-background
    URL below. The call is best-effort: on any failure an error message is
    printed and the input image is returned unchanged.

    Args:
        PIL_img: Image object exposing ``save(fp, format=...)``.
        MYAPIKEY: Key suffix; sent in the header as ``Bearer sk-<MYAPIKEY>``.

    Returns:
        The image opened from the response body when the server answers
        HTTP 200, otherwise ``PIL_img`` itself.
    """
    img_byte_arr = io.BytesIO()
    PIL_img.save(img_byte_arr, format="PNG")
    img_byte_arr.seek(0)
    try:
        response = requests.post(
            "https://dev.apiv2.stability.ai/v2alpha/generation/stable-image/remove-background",
            headers={"authorization": f"Bearer sk-{MYAPIKEY}"},
            files={"image": io.BufferedReader(img_byte_arr)},
            data={"output_format": "png"},
            # Without a timeout, requests waits forever on a stalled server.
            timeout=60,
        )
    except requests.RequestException as err:
        # Transport failures follow the same fallback as a non-200 reply.
        print("ERROR: Could not remove background!! " + str(err))
        return PIL_img
    if response.status_code == 200:
        return Image.open(io.BytesIO(response.content))
    try:
        details = response.json()
    except ValueError:
        # Error bodies are not guaranteed to be JSON (e.g. a proxy HTML page).
        details = response.text
    print("ERROR: Could not remove background!! " + str(details))
    return PIL_img


def sample(
input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
num_frames: Optional[int] = None,
num_frames: Optional[int] = None, # 21 for SV3D
num_steps: Optional[int] = None,
version: str = "svd",
fps_id: int = 6,
Expand Down Expand Up @@ -149,7 +132,7 @@ def sample(
else:
# remove bg
image.thumbnail([768, 768], Image.Resampling.LANCZOS)
image = remove_bg_stable_PIL(image)
image = remove(image.convert("RGBA"), alpha_matting=True)

# resize object in frame
image_arr = np.array(image)
Expand Down Expand Up @@ -357,7 +340,6 @@ def load_model(
0
].params.open_clip_embedding_config.params.init_device = device

config.model.params.sampler_config.params.verbose = True
config.model.params.sampler_config.params.num_steps = num_steps
config.model.params.sampler_config.params.guider_config.params.num_frames = (
num_frames
Expand Down
5 changes: 0 additions & 5 deletions sgm/modules/diffusionmodules/sigma_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,3 @@ def __call__(self, n_samples, rand=None):
torch.randint(0, self.num_idx, (n_samples,)),
)
return self.idx_to_sigma(idx)


class ZeroSampler:
    """Sigma sampler whose output is a constant tensor filled with 1e-5."""

    def __call__(self, n_samples: int, rand=None) -> torch.Tensor:
        """Return a tensor of 1e-5 values shaped like the chosen template.

        A random tensor of shape ``(n_samples,)`` is always created and
        handed to ``default`` together with ``rand``; only the shape/dtype
        of whatever ``default`` returns is used.
        NOTE(review): presumably ``default`` picks ``rand`` when it is
        given and the random tensor otherwise -- confirm in its definition.
        """
        fallback = torch.randn((n_samples,))
        template = default(rand, fallback)
        zeros = torch.zeros_like(template)
        return zeros + 1.0e-5

0 comments on commit 4a9fa78

Please sign in to comment.