From 4a9fa787d11f76305f1bb3b185fb1798633b0c6d Mon Sep 17 00:00:00 2001 From: Vikram Voleti Date: Mon, 18 Mar 2024 01:47:09 +0000 Subject: [PATCH] Fixes SV3D checkpoint, fixes rembg --- requirements/pt2.txt | 1 + .../configs/sv3d_p_image_decoder.yaml | 2 +- .../configs/sv3d_u_image_decoder.yaml | 2 +- scripts/sampling/simple_video_sample.py | 24 +++---------------- .../diffusionmodules/sigma_sampling.py | 5 ---- 5 files changed, 6 insertions(+), 28 deletions(-) diff --git a/requirements/pt2.txt b/requirements/pt2.txt index 26bb71a6..824473ab 100644 --- a/requirements/pt2.txt +++ b/requirements/pt2.txt @@ -19,6 +19,7 @@ pillow>=9.5.0 pudb>=2022.1.3 pytorch-lightning==2.0.1 pyyaml>=6.0.1 +rembg scipy>=1.10.1 streamlit>=0.73.1 tensorboardx==2.6 diff --git a/scripts/sampling/configs/sv3d_p_image_decoder.yaml b/scripts/sampling/configs/sv3d_p_image_decoder.yaml index bd369e64..c6d8579d 100644 --- a/scripts/sampling/configs/sv3d_p_image_decoder.yaml +++ b/scripts/sampling/configs/sv3d_p_image_decoder.yaml @@ -3,7 +3,7 @@ model: params: scale_factor: 0.18215 disable_first_stage_autocast: True - ckpt_path: checkpoints/sv3d_p.safetensorsnew2.safetensors + ckpt_path: checkpoints/sv3d_p.safetensors denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser diff --git a/scripts/sampling/configs/sv3d_u_image_decoder.yaml b/scripts/sampling/configs/sv3d_u_image_decoder.yaml index 15fddef4..defdd83d 100644 --- a/scripts/sampling/configs/sv3d_u_image_decoder.yaml +++ b/scripts/sampling/configs/sv3d_u_image_decoder.yaml @@ -16,7 +16,7 @@ model: params: adm_in_channels: 256 num_classes: sequential - use_checkpoint: False + use_checkpoint: True in_channels: 8 out_channels: 4 model_channels: 320 diff --git a/scripts/sampling/simple_video_sample.py b/scripts/sampling/simple_video_sample.py index 1d038689..064ea4e9 100644 --- a/scripts/sampling/simple_video_sample.py +++ b/scripts/sampling/simple_video_sample.py @@ -13,38 +13,21 @@ import cv2 import numpy as np -import requests import torch from einops import rearrange, repeat from fire import Fire from omegaconf import OmegaConf from PIL import Image +from rembg import remove from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering from sgm.inference.helpers import embed_watermark from sgm.util import default, instantiate_from_config from torchvision.transforms import ToTensor -def remove_bg_stable_PIL(PIL_img, MYAPIKEY=""): - img_byte_arr = io.BytesIO() - PIL_img.save(img_byte_arr, format="PNG") - img_byte_arr.seek(0) - response = requests.post( - f"https://dev.apiv2.stability.ai/v2alpha/generation/stable-image/remove-background", - headers={"authorization": f"Bearer sk-{MYAPIKEY}"}, - files={"image": io.BufferedReader(img_byte_arr)}, - data={"output_format": "png"}, - ) - if response.status_code == 200: - return Image.open(io.BytesIO(response.content)) - else: - print("ERROR: Could not remove background!! " + str(response.json())) - return PIL_img - - def sample( input_path: str = "assets/test_image.png", # Can either be image file or folder with image files - num_frames: Optional[int] = None, + num_frames: Optional[int] = None, # 21 for SV3D num_steps: Optional[int] = None, version: str = "svd", fps_id: int = 6, @@ -149,7 +132,7 @@ def sample( else: # remove bg image.thumbnail([768, 768], Image.Resampling.LANCZOS) - image = remove_bg_stable_PIL(image) + image = remove(image.convert("RGBA"), alpha_matting=True) # resize object in frame image_arr = np.array(image) @@ -357,7 +340,6 @@ def load_model( 0 ].params.open_clip_embedding_config.params.init_device = device - config.model.params.sampler_config.params.verbose = True config.model.params.sampler_config.params.num_steps = num_steps config.model.params.sampler_config.params.guider_config.params.num_frames = ( num_frames diff --git a/sgm/modules/diffusionmodules/sigma_sampling.py b/sgm/modules/diffusionmodules/sigma_sampling.py index c2bac44b..d54724c6 100644 --- a/sgm/modules/diffusionmodules/sigma_sampling.py +++ b/sgm/modules/diffusionmodules/sigma_sampling.py @@ -29,8 +29,3 @@ def __call__(self, n_samples, rand=None): torch.randint(0, self.num_idx, (n_samples,)), ) return self.idx_to_sigma(idx) - - -class ZeroSampler: - def __call__(self, n_samples: int, rand=None) -> torch.Tensor: - return torch.zeros_like(default(rand, torch.randn((n_samples,)))) + 1.0e-5