class DepthAnythingV2Detector:
    """https://github.com/MackinationsAi/Upgraded-Depth-Anything-V2

    Wraps a ``DepthAnythingV2`` network (``vitl`` encoder) and turns an image
    array into an 8-bit depth map, optionally rendered with the INFERNO
    colormap.
    """

    # Directory handed to `load_model` as the location of the checkpoint.
    model_dir = os.path.join(models_path, "depth_anything_v2")

    def __init__(self, device: torch.device):
        """Build the network on *device* and load its safetensors weights.

        The checkpoint URL can be overridden through the
        ``CONTROLNET_DEPTH_ANYTHING_V2_MODEL_URL`` environment variable.
        """
        self.device = device
        self.model = (
            DepthAnythingV2(
                encoder="vitl",
                features=256,
                out_channels=[256, 512, 1024, 1024],
            )
            .to(device)
            .eval()
        )
        remote_url = os.environ.get(
            "CONTROLNET_DEPTH_ANYTHING_V2_MODEL_URL",
            "https://huggingface.co/MackinationsAi/Depth-Anything-V2_Safetensors/resolve/main/depth_anything_v2_vitl.safetensors",
        )
        model_path = load_model(
            "depth_anything_v2_vitl.safetensors",
            remote_url=remote_url,
            model_dir=self.model_dir,
        )
        self.model.load_state_dict(load_file(model_path))

    def __call__(self, image: np.ndarray, colored: bool = True) -> np.ndarray:
        """Estimate depth for *image* and return it as a ``uint8`` array.

        Args:
            image: Input image array; its first two dimensions are taken as
                height and width, and the output is resized back to them.
            colored: When true, return the depth map rendered with
                ``cv2.COLORMAP_INFERNO`` (channel order reversed); otherwise
                return the raw single-channel depth map.

        Returns:
            The depth map scaled to the 0-255 range. A prediction with no
            depth variation yields an all-zero map.
        """
        # The model may have been parked on the CPU by `unload_model`.
        self.model.to(self.device)
        h, w = image.shape[:2]

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
        image = transform({"image": image})["image"]
        image = torch.from_numpy(image).unsqueeze(0).to(self.device)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            depth = self.model(image)

        depth = F.interpolate(
            depth[None], (h, w), mode="bilinear", align_corners=False
        )[0, 0]

        # Min-max normalise to 0-255. A uniform prediction has zero range;
        # dividing by it would produce NaNs whose uint8 cast is undefined.
        depth_min = depth.min()
        depth_range = depth.max() - depth_min
        if depth_range > 0:
            depth = (depth - depth_min) / depth_range * 255.0
        else:
            depth = torch.zeros_like(depth)
        depth = depth.cpu().numpy().astype(np.uint8)

        if colored:
            return cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)[:, :, ::-1]
        return depth

    def unload_model(self):
        """Move the network to the CPU."""
        self.model.to("cpu")
# Lazily created DepthAnythingV2Detector shared by the two functions below.
model_depth_anything_v2 = None


def depth_anything_v2(img, res: int = 512, colored: bool = True, **kwargs):
    """Run the Depth Anything V2 preprocessor on *img*.

    The image is first resized/padded to *res*, the detector is created on
    first use, and the padding is stripped from the detector's output.

    Returns:
        A ``(depth_map, True)`` tuple; the second element is always ``True``.
    """
    global model_depth_anything_v2

    padded, remove_pad = resize_image_with_pad(img, res)

    if model_depth_anything_v2 is None:
        with Extra(torch_handler):
            from annotator.depth_anything_v2 import DepthAnythingV2Detector

            model_depth_anything_v2 = DepthAnythingV2Detector(
                devices.get_device_for("controlnet")
            )

    depth_map = model_depth_anything_v2(padded, colored=colored)
    return remove_pad(depth_map), True


def unload_depth_anything_v2():
    """Ask the detector to unload its model, if a detector was ever created."""
    detector = model_depth_anything_v2
    if detector is None:
        return
    detector.unload_model()
"depth_anything", + "depth_anything_v2", ] hand_refiner_module = "depth_hand_refiner" diff --git a/tests/web_api/modules_test.py b/tests/web_api/modules_test.py index de2dab132..0f50e995d 100644 --- a/tests/web_api/modules_test.py +++ b/tests/web_api/modules_test.py @@ -15,6 +15,7 @@ "densepose_parula", "depth", "depth_anything", + "depth_anything_v2", "depth_hand_refiner", "depth_leres", "depth_leres++", @@ -82,6 +83,7 @@ "densepose (pruple bg & purple torso)", "densepose_parula (black bg & blue torso)", "depth_anything", + "depth_anything_v2", "depth_hand_refiner", "depth_leres", "depth_leres++",