diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..d92aeef --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] +ignore = E203, E266, E501, W503, F403, F401 +max-line-length = 79 +max-complexity = 50 +select = B,C,E,F,W,T4,B9 diff --git a/.gitignore b/.gitignore index 26376df..9874406 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ __pycache__ /tests/tests_temp/ /.pytest_cache/ /carvekit.egg-info/ +venv \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..37dfec9 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,10 @@ +repos: +- repo: https://github.com/psf/black + rev: 22.10.0 + hooks: + - id: black + language_version: python3.10 +- repo: https://gitlab.com/pycqa/flake8 + rev: 3.9.2 + hooks: + - id: flake8 \ No newline at end of file diff --git a/Dockerfile.cpu b/Dockerfile.cpu index ede8c1c..975093f 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -5,7 +5,7 @@ WORKDIR /app RUN pip3 install --no-cache-dir tqdm==4.64.0 requests==2.27.1 RUN mkdir -p ./carvekit/utils/ RUN mkdir -p ./carvekit/ml/files -RUN touch ./carvekit/__init__.py +COPY ./carvekit/__init__.py ./carvekit/__init__.py RUN touch ./carvekit/ml/__init__.py RUN touch ./carvekit/utils/__init__.py COPY ./carvekit/utils/download_models.py ./carvekit/utils/download_models.py @@ -18,7 +18,7 @@ FROM python:3.10.4 WORKDIR /app RUN apt-get update && apt-get -y install libgl1 # Install cv2 dep. 
-COPY --from=builder /root/.carvekit /root/.carvekit +COPY --from=builder /root/.cache/carvekit /root/.cache/carvekit # Install requirements COPY requirements.txt ./ @@ -34,15 +34,19 @@ RUN pip3 install -e ./ ENV CARVEKIT_PORT '5000' ENV CARVEKIT_HOST '0.0.0.0' -ENV CARVEKIT_SEGMENTATION_NETWORK 'u2net' +ENV CARVEKIT_SEGMENTATION_NETWORK 'tracer_b7' ENV CARVEKIT_PREPROCESSING_METHOD 'none' ENV CARVEKIT_POSTPROCESSING_METHOD 'fba' ENV CARVEKIT_DEVICE 'cpu' ENV CARVEKIT_BATCH_SIZE_SEG '5' ENV CARVEKIT_BATCH_SIZE_MATTING '1' -ENV CARVEKIT_SEG_MASK_SIZE '320' +ENV CARVEKIT_SEG_MASK_SIZE '640' ENV CARVEKIT_MATTING_MASK_SIZE '2048' ENV CARVEKIT_AUTH_ENABLE '1' +ENV CARVEKIT_FP16 '0' +ENV CARVEKIT_TRIMAP_PROB_THRESHOLD=231 +ENV CARVEKIT_TRIMAP_DILATION=30 +ENV CARVEKIT_TRIMAP_EROSION=5 # Tokens will be generated automatically every time the container is restarted if ENV is not set. diff --git a/Dockerfile.cuda b/Dockerfile.cuda index 630744c..b5d31df 100644 --- a/Dockerfile.cuda +++ b/Dockerfile.cuda @@ -5,7 +5,7 @@ WORKDIR /app RUN pip3 install --no-cache-dir tqdm==4.64.0 requests==2.27.1 RUN mkdir -p ./carvekit/utils/ RUN mkdir -p ./carvekit/ml/files -RUN touch ./carvekit/__init__.py +COPY ./carvekit/__init__.py ./carvekit/__init__.py RUN touch ./carvekit/ml/__init__.py RUN touch ./carvekit/utils/__init__.py COPY ./carvekit/utils/download_models.py ./carvekit/utils/download_models.py @@ -18,7 +18,7 @@ FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime WORKDIR /app RUN apt-get update && apt-get -y install libgl1 libglib2.0-0 # Install cv2 dep. 
-COPY --from=builder /root/.carvekit /root/.carvekit +COPY --from=builder /root/.cache/carvekit /root/.cache/carvekit # Install requirements COPY requirements.txt ./ @@ -34,15 +34,19 @@ RUN pip3 install -e ./ ENV CARVEKIT_PORT '5000' ENV CARVEKIT_HOST '0.0.0.0' -ENV CARVEKIT_SEGMENTATION_NETWORK 'u2net' +ENV CARVEKIT_SEGMENTATION_NETWORK 'tracer_b7' ENV CARVEKIT_PREPROCESSING_METHOD 'none' ENV CARVEKIT_POSTPROCESSING_METHOD 'fba' ENV CARVEKIT_DEVICE 'cuda' ENV CARVEKIT_BATCH_SIZE_SEG '5' ENV CARVEKIT_BATCH_SIZE_MATTING '1' -ENV CARVEKIT_SEG_MASK_SIZE '320' +ENV CARVEKIT_SEG_MASK_SIZE '640' ENV CARVEKIT_MATTING_MASK_SIZE '2048' ENV CARVEKIT_AUTH_ENABLE '1' +ENV CARVEKIT_FP16 '0' +ENV CARVEKIT_TRIMAP_PROB_THRESHOLD=231 +ENV CARVEKIT_TRIMAP_DILATION=30 +ENV CARVEKIT_TRIMAP_EROSION=5 # Tokens will be generated automatically every time the container is restarted if ENV is not set. diff --git a/README.md b/README.md index 255c132..8f44cc7 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@

- +

@@ -24,11 +24,11 @@ ## 📄 Description: Automated high-quality background removal framework for an image using neural networks. - ## 🎆 Features: - High Quality - Batch Processing - NVIDIA CUDA and CPU processing +- FP16 inference: Fast inference with low memory usage - Easy inference - 100% remove.bg compatible FastAPI HTTP API - Removes background from hairs @@ -42,11 +42,25 @@ It can be briefly described as 3. Using machine learning technology, the background of the image is removed 4. Image post-processing to improve the quality of the processed image ## 🎓 Implemented Neural Networks: -* [U^2-net](https://github.com/NathanUA/U-2-Net) -* [BASNet](https://github.com/NathanUA/BASNet) -* [DeepLabV3](https://github.com/tensorflow/models/tree/master/research/deeplab) - - +| Networks | Target | Accuracy | +|:-----------------------:|:-------------------------------------------:|:--------------------------------:| +| **Tracer-B7** (default) | **General** (objects, animals, etc) | **90%** (mean F1-Score, DUTS-TE) | +| U^2-net | **Hairs** (hairs, people, animals, objects) | 80.4% (mean F1-Score, DUTS-TE) | +| BASNet | **General** (people, objects) | 80.3% (mean F1-Score, DUTS-TE) | +| DeepLabV3 | People, Animals, Cars, etc | 67.4% (mean IoU, COCO val2017) | + +### Recommended parameters for different models +| Networks | Segmentation mask size | Trimap parameters (dilation, erosion) | +|:-----------:|:-----------------------:|:-------------------------------------:| +| `tracer_b7` | 640 | (30, 5) | +| `u2net` | 320 | (30, 5) | +| `basnet` | 320 | (30, 5) | +| `deeplabv3` | 1024 | (40, 20) | + +> ### Notes: +> 1. The final quality may depend on the resolution of your image, the type of scene or object. +> 2. Use **U2-Net for hairs** and **Tracer-B7 for general images** with the correct parameters. \ +> It is very important for final quality! Example images were taken using U2-Net and FBA post-processing. 
## 🖼️ Image pre-processing and post-processing methods: ### 🔍 Preprocessing methods: * `none` - No preprocessing methods used. @@ -69,12 +83,21 @@ It can be briefly described as import torch from carvekit.api.high import HiInterface -interface = HiInterface(batch_size_seg=5, batch_size_matting=1, - device='cuda' if torch.cuda.is_available() else 'cpu', - seg_mask_size=320, matting_mask_size=2048) -images_without_background = interface(['./tests/data/cat.jpg']) +# Check doc strings for more information +interface = HiInterface(object_type="hairs-like", # Can be "object" or "hairs-like". + batch_size_seg=5, + batch_size_matting=1, + device='cuda' if torch.cuda.is_available() else 'cpu', + seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net + matting_mask_size=2048, + trimap_prob_threshold=231, + trimap_dilation=30, + trimap_erosion_iters=5, + fp16=False) +images_without_background = interface(['./tests/data/cat.jpg']) cat_wo_bg = images_without_background[0] cat_wo_bg.save('2.png') + ``` @@ -84,12 +107,13 @@ import PIL.Image from carvekit.api.interface import Interface from carvekit.ml.wrap.fba_matting import FBAMatting -from carvekit.ml.wrap.u2net import U2NET +from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 from carvekit.pipelines.postprocessing import MattingMethod from carvekit.pipelines.preprocessing import PreprocessingStub from carvekit.trimap.generator import TrimapGenerator -u2net = U2NET(device='cpu', +# Check doc strings for more information +seg_net = TracerUniversalB7(device='cpu', batch_size=1) fba = FBAMatting(device='cpu', @@ -106,7 +130,7 @@ postprocessing = MattingMethod(matting_module=fba, interface = Interface(pre_pipe=preprocessing, post_pipe=postprocessing, - seg_pipe=u2net) + seg_pipe=seg_net) image = PIL.Image.open('tests/data/cat.jpg') cat_wo_bg = interface([image])[0] @@ -129,7 +153,7 @@ Options: -o ./2.png Path to output file or dir --pre none Preprocessing method --post fba Postprocessing method. 
- --net u2net Segmentation Network + --net tracer_b7 Segmentation Network. Check README for more info. --recursive Enables recursive search for images in a folder --batch_size 10 Batch Size for list of images to be loaded to RAM @@ -140,20 +164,34 @@ Options: --batch_size_mat 1 Batch size for list of images to be processed by matting network - --seg_mask_size 320 The size of the input image for the - segmentation neural network. + --seg_mask_size 640 The size of the input image for the + segmentation neural network. Use 640 for Tracer B7 and 320 for U2Net --matting_mask_size 2048 The size of the input image for the matting neural network. + --trimap_dilation 30 The size of the offset radius from the + object mask in pixels when forming an + unknown area + --trimap_erosion 5 The number of iterations of erosion that the + object's mask will be subjected to before + forming an unknown area + --trimap_prob_threshold 231 + Probability threshold at which the + prob_filter and prob_as_unknown_area + operations will be applied --device cpu Processing Device. + --fp16 Enables mixed precision processing. Use only with CUDA. CPU support is experimental! --help Show this message and exit. + ```` ## 📦 Running the Framework / FastAPI HTTP API server via Docker: +Using the API via docker is a **fast** and non-complex way to have a working API. +> **Our docker images are available on [Docker Hub](https://hub.docker.com/r/anodev/carvekit).** \ +> Version tags are the same as the releases of the project with suffixes `-cpu` and `-cuda` for CPU and CUDA versions respectively. + -Using the API via docker is a **fast** and non-complex way to have a working API.\ -**This HTTP API is 100% compatible with remove.bg API clients.**

@@ -188,21 +226,20 @@ See `docker-compose..yml` for more information. \ 1. Run `docker-compose -f docker-compose.cpu.yml run carvekit_api pytest` # For testing on CPU 2. Run `docker-compose -f docker-compose.cuda.yml run carvekit_api pytest` # For testing on GPU - ## 👪 Credits: [More info](docs/CREDITS.md) ## 💵 Support You can thank me for developing this project and buy me a small cup of coffee ☕ -| Blockchain | Cryptocurrency | Network | Wallet | +| Blockchain | Cryptocurrency | Network | Wallet | |:----------:|:-----------------------------------:|:-------------------------:|:-----------------------------------------------------------------------------------------------:| -| Ethereum | ETH / USDT / USDC / BNB / Dogecoin | Mainnet | 0x7Ab1B8015020242D2a9bC48F09b2F34b994bc2F8 | -| Ethereum | ETH / USDT / USDC / BNB / Dogecoin | BSC (Binance Smart Chain) | 0x7Ab1B8015020242D2a9bC48F09b2F34b994bc2F8 | -| Bitcoin | BTC | - | bc1qmf4qedujhhvcsg8kxpg5zzc2s3jvqssmu7mmhq | -| ZCash | ZEC | - | t1d7b9WxdboGFrcVVHG2ZuwWBgWEKhNUbtm | -| Tron | TRX | - | TH12CADSqSTcNZPvG77GVmYKAe4nrrJB5X | +| Ethereum | ETH / USDT / USDC / BNB / Dogecoin | Mainnet | 0x7Ab1B8015020242D2a9bC48F09b2F34b994bc2F8 | +| Ethereum | ETH / USDT / USDC / BNB / Dogecoin | BSC (Binance Smart Chain) | 0x7Ab1B8015020242D2a9bC48F09b2F34b994bc2F8 | +| Bitcoin | BTC | - | bc1qmf4qedujhhvcsg8kxpg5zzc2s3jvqssmu7mmhq | +| ZCash | ZEC | - | t1d7b9WxdboGFrcVVHG2ZuwWBgWEKhNUbtm | +| Tron | TRX | - | TH12CADSqSTcNZPvG77GVmYKAe4nrrJB5X | | Monero | XMR | Mainnet | 48w2pDYgPtPenwqgnNneEUC9Qt1EE6eD5MucLvU3FGpY3SABudDa4ce5bT1t32oBwchysRCUimCkZVsD1HQRBbxVLF9GTh3 | -| TON | TON | - | EQCznqTdfOKI3L06QX-3Q802tBL0ecSWIKfkSjU-qsoy0CWE | +| TON | TON | - | EQCznqTdfOKI3L06QX-3Q802tBL0ecSWIKfkSjU-qsoy0CWE | ## 📧 __Feedback__ I will be glad to receive feedback on the project and suggestions for integration. 
diff --git a/carvekit/__init__.py b/carvekit/__init__.py index c221c91..b58821b 100644 --- a/carvekit/__init__.py +++ b/carvekit/__init__.py @@ -1 +1 @@ -version = "4.0.8" +version = "4.1.0" diff --git a/carvekit/__main__.py b/carvekit/__main__.py index 183d0a1..acf901d 100644 --- a/carvekit/__main__.py +++ b/carvekit/__main__.py @@ -10,29 +10,95 @@ from carvekit.utils.fs_utils import save_file -@click.command('removebg', help="Performs background removal on specified photos using console interface.") -@click.option('-i', required=True, type=str, help='Path to input file or dir') -@click.option('-o', default="none", type=str, help="Path to output file or dir") -@click.option('--pre', default='none', type=str, help='Preprocessing method') -@click.option('--post', default='fba', type=str, help='Postprocessing method.') -@click.option('--net', default='u2net', type=str, help='Segmentation Network') -@click.option('--recursive', default=False, type=bool, help='Enables recursive search for images in a folder') -@click.option('--batch_size', default=10, type=int, help='Batch Size for list of images to be loaded to RAM') -@click.option('--batch_size_seg', default=5, type=int, - help='Batch size for list of images to be processed by segmentation ' - 'network') -@click.option('--batch_size_mat', default=1, type=int, help='Batch size for list of images to be processed by matting ' - 'network') -@click.option('--seg_mask_size', default=320, type=int, - help='The size of the input image for the segmentation neural network.') -@click.option('--matting_mask_size', default=2048, type=int, - help='The size of the input image for the matting neural network.') -@click.option('--device', default="cpu", type=str, - help='Processing Device.') -def removebg(i: str, o: str, pre: str, post: str, net: str, recursive: bool, - batch_size: int, batch_size_seg: int, batch_size_mat: int, seg_mask_size: int, - matting_mask_size: int, - device: str): +@click.command( + "removebg", + 
help="Performs background removal on specified photos using console interface.", +) +@click.option("-i", required=True, type=str, help="Path to input file or dir") +@click.option("-o", default="none", type=str, help="Path to output file or dir") +@click.option("--pre", default="none", type=str, help="Preprocessing method") +@click.option("--post", default="fba", type=str, help="Postprocessing method.") +@click.option("--net", default="tracer_b7", type=str, help="Segmentation Network") +@click.option( + "--recursive", + default=False, + type=bool, + help="Enables recursive search for images in a folder", +) +@click.option( + "--batch_size", + default=10, + type=int, + help="Batch Size for list of images to be loaded to RAM", +) +@click.option( + "--batch_size_seg", + default=5, + type=int, + help="Batch size for list of images to be processed by segmentation " "network", +) +@click.option( + "--batch_size_mat", + default=1, + type=int, + help="Batch size for list of images to be processed by matting " "network", +) +@click.option( + "--seg_mask_size", + default=640, + type=int, + help="The size of the input image for the segmentation neural network.", +) +@click.option( + "--matting_mask_size", + default=2048, + type=int, + help="The size of the input image for the matting neural network.", +) +@click.option( + "--trimap_dilation", + default=30, + type=int, + help="The size of the offset radius from the object mask in " + "pixels when forming an unknown area", +) +@click.option( + "--trimap_erosion", + default=5, + type=int, + help="The number of iterations of erosion that the object's " + "mask will be subjected to before forming an unknown area", +) +@click.option( + "--trimap_prob_threshold", + default=231, + type=int, + help="Probability threshold at which the prob_filter " + "and prob_as_unknown_area operations will be " + "applied", +) +@click.option("--device", default="cpu", type=str, help="Processing Device.") +@click.option( + "--fp16", default=False, 
type=bool, help="Enables mixed precision processing." +) +def removebg( + i: str, + o: str, + pre: str, + post: str, + net: str, + recursive: bool, + batch_size: int, + batch_size_seg: int, + batch_size_mat: int, + seg_mask_size: int, + matting_mask_size: int, + device: str, + fp16: bool, + trimap_dilation: int, + trimap_erosion: int, + trimap_prob_threshold: int, +): out_path = Path(o) input_path = Path(i) if input_path.is_dir(): @@ -40,8 +106,11 @@ def removebg(i: str, o: str, pre: str, post: str, net: str, recursive: bool, all_images = input_path.rglob("*.*") else: all_images = input_path.glob("*.*") - all_images = [i for i in all_images if i.suffix.lower() in ALLOWED_SUFFIXES - and '_bg_removed' not in i.name] + all_images = [ + i + for i in all_images + if i.suffix.lower() in ALLOWED_SUFFIXES and "_bg_removed" not in i.name + ] else: all_images = [input_path] @@ -53,18 +122,27 @@ def removebg(i: str, o: str, pre: str, post: str, net: str, recursive: bool, batch_size_seg=batch_size_seg, batch_size_matting=batch_size_mat, seg_mask_size=seg_mask_size, - matting_mask_size=matting_mask_size + matting_mask_size=matting_mask_size, + fp16=fp16, + trimap_dilation=trimap_dilation, + trimap_erosion=trimap_erosion, + trimap_prob_threshold=trimap_prob_threshold, ) interface = init_interface(interface_config) - for image_batch in tqdm.tqdm(batch_generator(all_images, n=batch_size), - total=int(len(all_images) / batch_size), - desc="Removing background", unit=" image batch", - colour="blue"): + for image_batch in tqdm.tqdm( + batch_generator(all_images, n=batch_size), + total=int(len(all_images) / batch_size), + desc="Removing background", + unit=" image batch", + colour="blue", + ): images_without_background = interface(image_batch) # Remove background - thread_pool_processing(lambda x: save_file(out_path, image_batch[x], images_without_background[x]), - range((len(image_batch)))) # Drop images to fs + thread_pool_processing( + lambda x: save_file(out_path, image_batch[x], 
images_without_background[x]), + range((len(image_batch))), + ) # Drop images to fs if __name__ == "__main__": diff --git a/carvekit/api/high.py b/carvekit/api/high.py index 8b0edfd..46fb9d3 100644 --- a/carvekit/api/high.py +++ b/carvekit/api/high.py @@ -3,39 +3,98 @@ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. License: Apache License 2.0 """ +import warnings + from carvekit.api.interface import Interface from carvekit.ml.wrap.fba_matting import FBAMatting +from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 from carvekit.ml.wrap.u2net import U2NET from carvekit.pipelines.postprocessing import MattingMethod from carvekit.trimap.generator import TrimapGenerator class HiInterface(Interface): - def __init__(self, batch_size_seg=5, batch_size_matting=1, - device='cpu', seg_mask_size=320, matting_mask_size=2048): + def __init__( + self, + object_type: str = "object", + batch_size_seg=2, + batch_size_matting=1, + device="cpu", + seg_mask_size=640, + matting_mask_size=2048, + trimap_prob_threshold=231, + trimap_dilation=30, + trimap_erosion_iters=5, + fp16=False, + ): """ Initializes High Level interface. Args: + object_type: Interest object type. Can be "object" or "hairs-like". matting_mask_size: The size of the input image for the matting neural network. seg_mask_size: The size of the input image for the segmentation neural network. batch_size_seg: Number of images processed per one segmentation neural network call. batch_size_matting: Number of images processed per one matting neural network call. device: Processing device + fp16: Use half precision. Reduce memory usage and increase speed. 
Experimental support + trimap_prob_threshold: Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied + trimap_dilation: The size of the offset radius from the object mask in pixels when forming an unknown area + trimap_erosion_iters: The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area Notes: - Changing seg_mask_size may cause an out-of-memory error if the value is too large, and it may also + 1. Changing seg_mask_size may cause an out-of-memory error if the value is too large, and it may also result in reduced precision. I do not recommend changing this value. You can change matting_mask_size in range from (1024 to 4096) to improve object edge refining quality, but it will cause extra large RAM and video memory consume. Also, you can change batch size to accelerate background removal, but it also causes extra large video memory consume, if value is too big. + + 2. Changing trimap_prob_threshold, trimap_dilation, trimap_erosion_iters may improve object edge + refining quality. """ - self.u2net = U2NET(device=device, batch_size=batch_size_seg, input_image_size=seg_mask_size) - self.fba = FBAMatting(batch_size=batch_size_matting, device=device, input_tensor_size=matting_mask_size) - self.trimap_generator = TrimapGenerator() - super(HiInterface, self).__init__(pre_pipe=None, - seg_pipe=self.u2net, - post_pipe=MattingMethod(matting_module=self.fba, - trimap_generator=self.trimap_generator, - device=device), - device=device) \ No newline at end of file + if object_type == "object": + self.u2net = TracerUniversalB7( + device=device, + batch_size=batch_size_seg, + input_image_size=seg_mask_size, + fp16=fp16, + ) + elif object_type == "hairs-like": + self.u2net = U2NET( + device=device, + batch_size=batch_size_seg, + input_image_size=seg_mask_size, + fp16=fp16, + ) + else: + warnings.warn( + f"Unknown object type: {object_type}. 
Using default object type: object" + ) + self.u2net = TracerUniversalB7( + device=device, + batch_size=batch_size_seg, + input_image_size=seg_mask_size, + fp16=fp16, + ) + + self.fba = FBAMatting( + batch_size=batch_size_matting, + device=device, + input_tensor_size=matting_mask_size, + fp16=fp16, + ) + self.trimap_generator = TrimapGenerator( + prob_threshold=trimap_prob_threshold, + kernel_size=trimap_dilation, + erosion_iters=trimap_erosion_iters, + ) + super(HiInterface, self).__init__( + pre_pipe=None, + seg_pipe=self.u2net, + post_pipe=MattingMethod( + matting_module=self.fba, + trimap_generator=self.trimap_generator, + device=device, + ), + device=device, + ) diff --git a/carvekit/api/interface.py b/carvekit/api/interface.py index 0ddf1f5..364d247 100644 --- a/carvekit/api/interface.py +++ b/carvekit/api/interface.py @@ -11,6 +11,7 @@ from carvekit.ml.wrap.basnet import BASNET from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 from carvekit.ml.wrap.u2net import U2NET +from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 from carvekit.pipelines.preprocessing import PreprocessingStub from carvekit.pipelines.postprocessing import MattingMethod from carvekit.utils.image_utils import load_image @@ -19,11 +20,13 @@ class Interface: - def __init__(self, - seg_pipe: Union[U2NET, BASNET, DeepLabV3], - pre_pipe: Optional[Union[PreprocessingStub]] = None, - post_pipe: Optional[Union[MattingMethod]] = None, - device="cpu"): + def __init__( + self, + seg_pipe: Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7], + pre_pipe: Optional[Union[PreprocessingStub]] = None, + post_pipe: Optional[Union[MattingMethod]] = None, + device="cpu", + ): """ Initializes an object for interacting with pipelines and other components of the CarveKit framework. 
@@ -38,7 +41,9 @@ def __init__(self, self.segmentation_pipeline = seg_pipe self.postprocessing_pipeline = post_pipe - def __call__(self, images: List[Union[str, Path, Image.Image]]) -> List[Image.Image]: + def __call__( + self, images: List[Union[str, Path, Image.Image]] + ) -> List[Image.Image]: """ Removes the background from the specified images. @@ -50,13 +55,23 @@ def __call__(self, images: List[Union[str, Path, Image.Image]]) -> List[Image.Im """ images = thread_pool_processing(load_image, images) if self.preprocessing_pipeline is not None: - masks: List[Image.Image] = self.preprocessing_pipeline(interface=self, images=images) + masks: List[Image.Image] = self.preprocessing_pipeline( + interface=self, images=images + ) else: masks: List[Image.Image] = self.segmentation_pipeline(images=images) if self.postprocessing_pipeline is not None: - images: List[Image.Image] = self.postprocessing_pipeline(images=images, masks=masks) + images: List[Image.Image] = self.postprocessing_pipeline( + images=images, masks=masks + ) else: - images = list(map(lambda x: apply_mask(image=images[x], mask=masks[x], device=self.device), - range(len(images)))) + images = list( + map( + lambda x: apply_mask( + image=images[x], mask=masks[x], device=self.device + ), + range(len(images)), + ) + ) return images diff --git a/carvekit/ml/arch/basnet/basnet.py b/carvekit/ml/arch/basnet/basnet.py index caaf648..e2ead6a 100644 --- a/carvekit/ml/arch/basnet/basnet.py +++ b/carvekit/ml/arch/basnet/basnet.py @@ -10,8 +10,9 @@ def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) + return nn.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False + ) class BasicBlock(nn.Module): @@ -92,8 +93,9 @@ def __init__(self, inplanes, planes, stride=1, downsample=None): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, 
kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, - padding=1, bias=False) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=stride, padding=1, bias=False + ) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * 4) @@ -176,7 +178,9 @@ def __init__(self, in_ch, inc_ch): self.conv_d0 = nn.Conv2d(64, 1, 3, padding=1) - self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + self.upscore2 = nn.Upsample( + scale_factor=2, mode="bilinear", align_corners=False + ) def forward(self, x): hx = x @@ -269,7 +273,7 @@ def __init__(self, n_channels, n_classes): self.bn6d_1 = nn.BatchNorm2d(512) self.relu6d_1 = nn.ReLU(inplace=True) - self.conv6d_m = nn.Conv2d(512, 512, 3, dilation=2, padding=2) ### + self.conv6d_m = nn.Conv2d(512, 512, 3, dilation=2, padding=2) self.bn6d_m = nn.BatchNorm2d(512) self.relu6d_m = nn.ReLU(inplace=True) @@ -282,7 +286,7 @@ def __init__(self, n_channels, n_classes): self.bn5d_1 = nn.BatchNorm2d(512) self.relu5d_1 = nn.ReLU(inplace=True) - self.conv5d_m = nn.Conv2d(512, 512, 3, padding=1) ### + self.conv5d_m = nn.Conv2d(512, 512, 3, padding=1) self.bn5d_m = nn.BatchNorm2d(512) self.relu5d_m = nn.ReLU(inplace=True) @@ -295,7 +299,7 @@ def __init__(self, n_channels, n_classes): self.bn4d_1 = nn.BatchNorm2d(512) self.relu4d_1 = nn.ReLU(inplace=True) - self.conv4d_m = nn.Conv2d(512, 512, 3, padding=1) ### + self.conv4d_m = nn.Conv2d(512, 512, 3, padding=1) self.bn4d_m = nn.BatchNorm2d(512) self.relu4d_m = nn.ReLU(inplace=True) @@ -308,7 +312,7 @@ def __init__(self, n_channels, n_classes): self.bn3d_1 = nn.BatchNorm2d(256) self.relu3d_1 = nn.ReLU(inplace=True) - self.conv3d_m = nn.Conv2d(256, 256, 3, padding=1) ### + self.conv3d_m = nn.Conv2d(256, 256, 3, padding=1) self.bn3d_m = nn.BatchNorm2d(256) self.relu3d_m = nn.ReLU(inplace=True) @@ -322,7 
+326,7 @@ def __init__(self, n_channels, n_classes): self.bn2d_1 = nn.BatchNorm2d(128) self.relu2d_1 = nn.ReLU(inplace=True) - self.conv2d_m = nn.Conv2d(128, 128, 3, padding=1) ### + self.conv2d_m = nn.Conv2d(128, 128, 3, padding=1) self.bn2d_m = nn.BatchNorm2d(128) self.relu2d_m = nn.ReLU(inplace=True) @@ -335,7 +339,7 @@ def __init__(self, n_channels, n_classes): self.bn1d_1 = nn.BatchNorm2d(64) self.relu1d_1 = nn.ReLU(inplace=True) - self.conv1d_m = nn.Conv2d(64, 64, 3, padding=1) ### + self.conv1d_m = nn.Conv2d(64, 64, 3, padding=1) self.bn1d_m = nn.BatchNorm2d(64) self.relu1d_m = nn.ReLU(inplace=True) @@ -344,11 +348,21 @@ def __init__(self, n_channels, n_classes): self.relu1d_2 = nn.ReLU(inplace=True) # -------------Bilinear Upsampling-------------- - self.upscore6 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=False) ### - self.upscore5 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=False) - self.upscore4 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=False) - self.upscore3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False) - self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + self.upscore6 = nn.Upsample( + scale_factor=32, mode="bilinear", align_corners=False + ) + self.upscore5 = nn.Upsample( + scale_factor=16, mode="bilinear", align_corners=False + ) + self.upscore4 = nn.Upsample( + scale_factor=8, mode="bilinear", align_corners=False + ) + self.upscore3 = nn.Upsample( + scale_factor=4, mode="bilinear", align_corners=False + ) + self.upscore2 = nn.Upsample( + scale_factor=2, mode="bilinear", align_corners=False + ) # -------------Side Output-------------- self.outconvb = nn.Conv2d(512, 1, 3, padding=1) @@ -452,6 +466,13 @@ def forward(self, x): # -------------Refine Module------------- dout = self.refunet(d1) # 256 - return torch.sigmoid(dout), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid( - d4), torch.sigmoid(d5), torch.sigmoid( - d6), 
torch.sigmoid(db) + return ( + torch.sigmoid(dout), + torch.sigmoid(d1), + torch.sigmoid(d2), + torch.sigmoid(d3), + torch.sigmoid(d4), + torch.sigmoid(d5), + torch.sigmoid(d6), + torch.sigmoid(db), + ) diff --git a/carvekit/ml/arch/fba_matting/layers_WS.py b/carvekit/ml/arch/fba_matting/layers_WS.py index d972da4..5108598 100644 --- a/carvekit/ml/arch/fba_matting/layers_WS.py +++ b/carvekit/ml/arch/fba_matting/layers_WS.py @@ -9,23 +9,48 @@ class Conv2d(nn.Conv2d): - - def __init__(self, in_channels, out_channels, kernel_size, stride=1, - padding=0, dilation=1, groups=1, bias=True): - super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, - padding, dilation, groups, bias) + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + super(Conv2d, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) def forward(self, x): # return super(Conv2d, self).forward(x) weight = self.weight - weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, - keepdim=True).mean(dim=3, keepdim=True) + weight_mean = ( + weight.mean(dim=1, keepdim=True) + .mean(dim=2, keepdim=True) + .mean(dim=3, keepdim=True) + ) weight = weight - weight_mean # std = (weight).view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5 - std = torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1) + 1e-5 + std = ( + torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view( + -1, 1, 1, 1 + ) + + 1e-5 + ) weight = weight / std.expand_as(weight) - return F.conv2d(x, weight, self.bias, self.stride, - self.padding, self.dilation, self.groups) + return F.conv2d( + x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups + ) def BatchNorm2d(num_features): diff --git a/carvekit/ml/arch/fba_matting/models.py b/carvekit/ml/arch/fba_matting/models.py index 35fd565..dc2b0a8 100644 
--- a/carvekit/ml/arch/fba_matting/models.py +++ b/carvekit/ml/arch/fba_matting/models.py @@ -15,7 +15,7 @@ class FBA(nn.Module): def __init__(self, encoder: str): super(FBA, self).__init__() self.encoder = build_encoder(arch=encoder) - self.decoder = fba_decoder(batch_norm=True if 'BN' in encoder else False) + self.decoder = fba_decoder(batch_norm=True if "BN" in encoder else False) def forward(self, image, two_chan_trimap, image_n, trimap_transformed): resnet_input = torch.cat((image_n, trimap_transformed, two_chan_trimap), 1) @@ -28,13 +28,10 @@ def __init__(self, orig_resnet, dilate_scale=8): super(ResnetDilatedBN, self).__init__() if dilate_scale == 8: - orig_resnet.layer3.apply( - partial(self._nostride_dilate, dilate=2)) - orig_resnet.layer4.apply( - partial(self._nostride_dilate, dilate=4)) + orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2)) + orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4)) elif dilate_scale == 16: - orig_resnet.layer4.apply( - partial(self._nostride_dilate, dilate=2)) + orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2)) # take pretrained resnet, except AvgPool and FC self.conv1 = orig_resnet.conv1 @@ -54,7 +51,7 @@ def __init__(self, orig_resnet, dilate_scale=8): def _nostride_dilate(self, m, dilate): classname = m.__class__.__name__ - if classname.find('Conv') != -1: + if classname.find("Conv") != -1: # the convolution with stride if m.stride == (2, 2): m.stride = (1, 1) @@ -136,13 +133,10 @@ def __init__(self, orig_resnet, dilate_scale=8): super(ResnetDilated, self).__init__() if dilate_scale == 8: - orig_resnet.layer3.apply( - partial(self._nostride_dilate, dilate=2)) - orig_resnet.layer4.apply( - partial(self._nostride_dilate, dilate=4)) + orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2)) + orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4)) elif dilate_scale == 16: - orig_resnet.layer4.apply( - partial(self._nostride_dilate, dilate=2)) + 
orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2)) # take pretrained resnet, except AvgPool and FC self.conv1 = orig_resnet.conv1 @@ -156,7 +150,7 @@ def __init__(self, orig_resnet, dilate_scale=8): def _nostride_dilate(self, m, dilate): classname = m.__class__.__name__ - if classname.find('Conv') != -1: + if classname.find("Conv") != -1: # the convolution with stride if m.stride == (2, 2): m.stride = (1, 1) @@ -196,14 +190,15 @@ def norm(dim, bn=False): def fba_fusion(alpha, img, F, B): - F = (alpha * img + (1 - alpha ** 2) * F - alpha * (1 - alpha) * B) - B = ((1 - alpha) * img + (2 * alpha - alpha ** 2) * B - alpha * (1 - alpha) * F) + F = alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B + B = (1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F F = torch.clamp(F, 0, 1) B = torch.clamp(B, 0, 1) la = 0.1 alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / ( - torch.sum((F - B) * (F - B), 1, keepdim=True) + la) + torch.sum((F - B) * (F - B), 1, keepdim=True) + la + ) alpha = torch.clamp(alpha, 0, 1) return alpha, F, B @@ -217,53 +212,50 @@ def __init__(self, batch_norm=False): self.ppm = [] for scale in pool_scales: - self.ppm.append(nn.Sequential( - nn.AdaptiveAvgPool2d(scale), - L.Conv2d(2048, 256, kernel_size=1, bias=True), - norm(256, self.batch_norm), - nn.LeakyReLU() - )) + self.ppm.append( + nn.Sequential( + nn.AdaptiveAvgPool2d(scale), + L.Conv2d(2048, 256, kernel_size=1, bias=True), + norm(256, self.batch_norm), + nn.LeakyReLU(), + ) + ) self.ppm = nn.ModuleList(self.ppm) self.conv_up1 = nn.Sequential( - L.Conv2d(2048 + len(pool_scales) * 256, 256, - kernel_size=3, padding=1, bias=True), - + L.Conv2d( + 2048 + len(pool_scales) * 256, 256, kernel_size=3, padding=1, bias=True + ), norm(256, self.batch_norm), nn.LeakyReLU(), L.Conv2d(256, 256, kernel_size=3, padding=1), norm(256, self.batch_norm), - nn.LeakyReLU() + nn.LeakyReLU(), ) self.conv_up2 = nn.Sequential( - L.Conv2d(256 + 256, 256, - 
kernel_size=3, padding=1, bias=True), + L.Conv2d(256 + 256, 256, kernel_size=3, padding=1, bias=True), norm(256, self.batch_norm), - nn.LeakyReLU() + nn.LeakyReLU(), ) if self.batch_norm: d_up3 = 128 else: d_up3 = 64 self.conv_up3 = nn.Sequential( - L.Conv2d(256 + d_up3, 64, - kernel_size=3, padding=1, bias=True), + L.Conv2d(256 + d_up3, 64, kernel_size=3, padding=1, bias=True), norm(64, self.batch_norm), - nn.LeakyReLU() + nn.LeakyReLU(), ) self.unpool = nn.MaxUnpool2d(2, stride=2) self.conv_up4 = nn.Sequential( - nn.Conv2d(64 + 3 + 3 + 2, 32, - kernel_size=3, padding=1, bias=True), + nn.Conv2d(64 + 3 + 3 + 2, 32, kernel_size=3, padding=1, bias=True), nn.LeakyReLU(), - nn.Conv2d(32, 16, - kernel_size=3, padding=1, bias=True), - + nn.Conv2d(32, 16, kernel_size=3, padding=1, bias=True), nn.LeakyReLU(), - nn.Conv2d(16, 7, kernel_size=1, padding=0, bias=True) + nn.Conv2d(16, 7, kernel_size=1, padding=0, bias=True), ) def forward(self, conv_out, img, indices, two_chan_trimap): @@ -272,24 +264,34 @@ def forward(self, conv_out, img, indices, two_chan_trimap): input_size = conv5.size() ppm_out = [conv5] for pool_scale in self.ppm: - ppm_out.append(nn.functional.interpolate( - pool_scale(conv5), - (input_size[2], input_size[3]), - mode='bilinear', align_corners=False)) + ppm_out.append( + nn.functional.interpolate( + pool_scale(conv5), + (input_size[2], input_size[3]), + mode="bilinear", + align_corners=False, + ) + ) ppm_out = torch.cat(ppm_out, 1) x = self.conv_up1(ppm_out) - x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + x = torch.nn.functional.interpolate( + x, scale_factor=2, mode="bilinear", align_corners=False + ) x = torch.cat((x, conv_out[-4]), 1) x = self.conv_up2(x) - x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + x = torch.nn.functional.interpolate( + x, scale_factor=2, mode="bilinear", align_corners=False + ) x = torch.cat((x, conv_out[-5]), 1) x = self.conv_up3(x) 
- x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + x = torch.nn.functional.interpolate( + x, scale_factor=2, mode="bilinear", align_corners=False + ) x = torch.cat((x, conv_out[-6][:, :3], img, two_chan_trimap), 1) output = self.conv_up4(x) @@ -306,22 +308,22 @@ def forward(self, conv_out, img, indices, two_chan_trimap): return output -def build_encoder(arch='resnet50_GN'): - if arch == 'resnet50_GN_WS': - orig_resnet = resnet_GN_WS.__dict__['l_resnet50']() +def build_encoder(arch="resnet50_GN"): + if arch == "resnet50_GN_WS": + orig_resnet = resnet_GN_WS.__dict__["l_resnet50"]() net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) - elif arch == 'resnet50_BN': - orig_resnet = resnet_bn.__dict__['l_resnet50']() + elif arch == "resnet50_BN": + orig_resnet = resnet_bn.__dict__["l_resnet50"]() net_encoder = ResnetDilatedBN(orig_resnet, dilate_scale=8) else: - raise ValueError('Architecture undefined!') + raise ValueError("Architecture undefined!") num_channels = 3 + 6 + 2 if num_channels > 3: net_encoder_sd = net_encoder.state_dict() - conv1_weights = net_encoder_sd['conv1.weight'] + conv1_weights = net_encoder_sd["conv1.weight"] c_out, c_in, h, w = conv1_weights.size() conv1_mod = torch.zeros(c_out, num_channels, h, w) @@ -333,7 +335,7 @@ def build_encoder(arch='resnet50_GN'): net_encoder.conv1 = conv1 - net_encoder_sd['conv1.weight'] = conv1_mod + net_encoder_sd["conv1.weight"] = conv1_mod net_encoder.load_state_dict(net_encoder_sd) return net_encoder diff --git a/carvekit/ml/arch/fba_matting/resnet_GN_WS.py b/carvekit/ml/arch/fba_matting/resnet_GN_WS.py index 2945f09..a730ed6 100644 --- a/carvekit/ml/arch/fba_matting/resnet_GN_WS.py +++ b/carvekit/ml/arch/fba_matting/resnet_GN_WS.py @@ -6,13 +6,14 @@ import torch.nn as nn import carvekit.ml.arch.fba_matting.layers_WS as L -__all__ = ['ResNet', 'l_resnet50'] +__all__ = ["ResNet", "l_resnet50"] def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with 
padding""" - return L.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) + return L.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False + ) def conv1x1(in_planes, out_planes, stride=1): @@ -91,15 +92,15 @@ def forward(self, x): class ResNet(nn.Module): - def __init__(self, block, layers, num_classes=1000): super(ResNet, self).__init__() self.inplanes = 64 - self.conv1 = L.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, - bias=False) + self.conv1 = L.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = L.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True) + self.maxpool = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1, return_indices=True + ) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2) self.layer3 = self._make_layer(block, 256, layers[2], stride=2) diff --git a/carvekit/ml/arch/fba_matting/resnet_bn.py b/carvekit/ml/arch/fba_matting/resnet_bn.py index 89d6cd5..9662ca8 100644 --- a/carvekit/ml/arch/fba_matting/resnet_bn.py +++ b/carvekit/ml/arch/fba_matting/resnet_bn.py @@ -7,13 +7,14 @@ import math from torch.nn import BatchNorm2d -__all__ = ['ResNet'] +__all__ = ["ResNet"] def conv3x3(in_planes, out_planes, stride=1): "3x3 convolution with padding" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) + return nn.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False + ) class BasicBlock(nn.Module): @@ -55,8 +56,9 @@ def __init__(self, inplanes, planes, stride=1, downsample=None): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, - padding=1, bias=False) + self.conv2 = nn.Conv2d( + planes, 
planes, kernel_size=3, stride=stride, padding=1, bias=False + ) self.bn2 = BatchNorm2d(planes, momentum=0.01) self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) self.bn3 = BatchNorm2d(planes * 4) @@ -88,7 +90,6 @@ def forward(self, x): class ResNet(nn.Module): - def __init__(self, block, layers, num_classes=1000): self.inplanes = 128 super(ResNet, self).__init__() @@ -101,7 +102,9 @@ def __init__(self, block, layers, num_classes=1000): self.conv3 = conv3x3(64, 128) self.bn3 = BatchNorm2d(128) self.relu3 = nn.ReLU(inplace=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True) + self.maxpool = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1, return_indices=True + ) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2) @@ -113,7 +116,7 @@ def __init__(self, block, layers, num_classes=1000): for m in self.modules(): if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + m.weight.data.normal_(0, math.sqrt(2.0 / n)) elif isinstance(m, BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() @@ -122,8 +125,13 @@ def _make_layer(self, block, planes, blocks, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False), + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), BatchNorm2d(planes * block.expansion), ) diff --git a/carvekit/ml/arch/fba_matting/transforms.py b/carvekit/ml/arch/fba_matting/transforms.py index b5b792d..20251ef 100644 --- a/carvekit/ml/arch/fba_matting/transforms.py +++ b/carvekit/ml/arch/fba_matting/transforms.py @@ -19,7 +19,7 @@ def trimap_transform(trimap): clicks = np.zeros((h, w, 6)) for k in range(2): - if (np.count_nonzero(trimap[:, :, k]) > 0): + if np.count_nonzero(trimap[:, :, k]) > 0: dt_mask = -dt(1 - trimap[:, :, k]) ** 2 L = 320 clicks[:, :, 3 * k] = np.exp(dt_mask / (2 * ((0.02 * L) ** 2))) @@ -29,15 +29,17 @@ def trimap_transform(trimap): return clicks -def groupnorm_normalise_image(img, format='nhwc'): - ''' - Accept rgb in range 0,1 - ''' - if (format == 'nhwc'): +def groupnorm_normalise_image(img, format="nhwc"): + """ + Accept rgb in range 0,1 + """ + if format == "nhwc": for i in range(3): img[..., i] = (img[..., i] - group_norm_mean[i]) / group_norm_std[i] else: for i in range(3): - img[..., i, :, :] = (img[..., i, :, :] - group_norm_mean[i]) / group_norm_std[i] + img[..., i, :, :] = ( + img[..., i, :, :] - group_norm_mean[i] + ) / group_norm_std[i] return img diff --git a/carvekit/ml/arch/tracerb7/__init__.py b/carvekit/ml/arch/tracerb7/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/carvekit/ml/arch/tracerb7/att_modules.py b/carvekit/ml/arch/tracerb7/att_modules.py new file mode 100644 index 0000000..07e4740 --- /dev/null +++ 
b/carvekit/ml/arch/tracerb7/att_modules.py @@ -0,0 +1,290 @@ +""" +Source url: https://github.com/Karel911/TRACER +Author: Min Seok Lee and Wooseok Shin +License: Apache License 2.0 +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from carvekit.ml.arch.tracerb7.conv_modules import BasicConv2d, DWConv, DWSConv + + +class RFB_Block(nn.Module): + def __init__(self, in_channel, out_channel): + super(RFB_Block, self).__init__() + self.relu = nn.ReLU(True) + self.branch0 = nn.Sequential( + BasicConv2d(in_channel, out_channel, 1), + ) + self.branch1 = nn.Sequential( + BasicConv2d(in_channel, out_channel, 1), + BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)), + BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)), + BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3), + ) + self.branch2 = nn.Sequential( + BasicConv2d(in_channel, out_channel, 1), + BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)), + BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)), + BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5), + ) + self.branch3 = nn.Sequential( + BasicConv2d(in_channel, out_channel, 1), + BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)), + BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)), + BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7), + ) + self.conv_cat = BasicConv2d(4 * out_channel, out_channel, 3, padding=1) + self.conv_res = BasicConv2d(in_channel, out_channel, 1) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + x_cat = torch.cat((x0, x1, x2, x3), 1) + x_cat = self.conv_cat(x_cat) + + x = self.relu(x_cat + self.conv_res(x)) + return x + + +class GlobalAvgPool(nn.Module): + def __init__(self, flatten=False): + super(GlobalAvgPool, self).__init__() + self.flatten = flatten + + def 
forward(self, x): + if self.flatten: + in_size = x.size() + return x.view((in_size[0], in_size[1], -1)).mean(dim=2) + else: + return ( + x.view(x.size(0), x.size(1), -1) + .mean(-1) + .view(x.size(0), x.size(1), 1, 1) + ) + + +class UnionAttentionModule(nn.Module): + def __init__(self, n_channels, only_channel_tracing=False): + super(UnionAttentionModule, self).__init__() + self.GAP = GlobalAvgPool() + self.confidence_ratio = 0.1 + self.bn = nn.BatchNorm2d(n_channels) + self.norm = nn.Sequential( + nn.BatchNorm2d(n_channels), nn.Dropout3d(self.confidence_ratio) + ) + self.channel_q = nn.Conv2d( + in_channels=n_channels, + out_channels=n_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.channel_k = nn.Conv2d( + in_channels=n_channels, + out_channels=n_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.channel_v = nn.Conv2d( + in_channels=n_channels, + out_channels=n_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + + self.fc = nn.Conv2d( + in_channels=n_channels, + out_channels=n_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + + if only_channel_tracing is False: + self.spatial_q = nn.Conv2d( + in_channels=n_channels, + out_channels=1, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.spatial_k = nn.Conv2d( + in_channels=n_channels, + out_channels=1, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.spatial_v = nn.Conv2d( + in_channels=n_channels, + out_channels=1, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.sigmoid = nn.Sigmoid() + + def masking(self, x, mask): + mask = mask.squeeze(3).squeeze(2) + threshold = torch.quantile( + mask.float(), self.confidence_ratio, dim=-1, keepdim=True + ) + mask[mask <= threshold] = 0.0 + mask = mask.unsqueeze(2).unsqueeze(3) + mask = mask.expand(-1, x.shape[1], x.shape[2], x.shape[3]).contiguous() + masked_x = x * mask + + return masked_x + + def Channel_Tracer(self, x): + 
avg_pool = self.GAP(x) + x_norm = self.norm(avg_pool) + + q = self.channel_q(x_norm).squeeze(-1) + k = self.channel_k(x_norm).squeeze(-1) + v = self.channel_v(x_norm).squeeze(-1) + + # softmax(Q*K^T) + QK_T = torch.matmul(q, k.transpose(1, 2)) + alpha = F.softmax(QK_T, dim=-1) + + # a*v + att = torch.matmul(alpha, v).unsqueeze(-1) + att = self.fc(att) + att = self.sigmoid(att) + + output = (x * att) + x + alpha_mask = att.clone() + + return output, alpha_mask + + def forward(self, x): + X_c, alpha_mask = self.Channel_Tracer(x) + X_c = self.bn(X_c) + x_drop = self.masking(X_c, alpha_mask) + + q = self.spatial_q(x_drop).squeeze(1) + k = self.spatial_k(x_drop).squeeze(1) + v = self.spatial_v(x_drop).squeeze(1) + + # softmax(Q*K^T) + QK_T = torch.matmul(q, k.transpose(1, 2)) + alpha = F.softmax(QK_T, dim=-1) + + output = torch.matmul(alpha, v).unsqueeze(1) + v.unsqueeze(1) + + return output + + +class aggregation(nn.Module): + def __init__(self, channel): + super(aggregation, self).__init__() + self.relu = nn.ReLU(True) + + self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + self.conv_upsample1 = BasicConv2d(channel[2], channel[1], 3, padding=1) + self.conv_upsample2 = BasicConv2d(channel[2], channel[0], 3, padding=1) + self.conv_upsample3 = BasicConv2d(channel[1], channel[0], 3, padding=1) + self.conv_upsample4 = BasicConv2d(channel[2], channel[2], 3, padding=1) + self.conv_upsample5 = BasicConv2d( + channel[2] + channel[1], channel[2] + channel[1], 3, padding=1 + ) + + self.conv_concat2 = BasicConv2d( + (channel[2] + channel[1]), (channel[2] + channel[1]), 3, padding=1 + ) + self.conv_concat3 = BasicConv2d( + (channel[0] + channel[1] + channel[2]), + (channel[0] + channel[1] + channel[2]), + 3, + padding=1, + ) + + self.UAM = UnionAttentionModule(channel[0] + channel[1] + channel[2]) + + def forward(self, e4, e3, e2): + e4_1 = e4 + e3_1 = self.conv_upsample1(self.upsample(e4)) * e3 + e2_1 = ( + 
self.conv_upsample2(self.upsample(self.upsample(e4))) + * self.conv_upsample3(self.upsample(e3)) + * e2 + ) + + e3_2 = torch.cat((e3_1, self.conv_upsample4(self.upsample(e4_1))), 1) + e3_2 = self.conv_concat2(e3_2) + + e2_2 = torch.cat((e2_1, self.conv_upsample5(self.upsample(e3_2))), 1) + x = self.conv_concat3(e2_2) + + output = self.UAM(x) + + return output + + +class ObjectAttention(nn.Module): + def __init__(self, channel, kernel_size): + super(ObjectAttention, self).__init__() + self.channel = channel + self.DWSConv = DWSConv( + channel, channel // 2, kernel=kernel_size, padding=1, kernels_per_layer=1 + ) + self.DWConv1 = nn.Sequential( + DWConv(channel // 2, channel // 2, kernel=1, padding=0, dilation=1), + BasicConv2d(channel // 2, channel // 8, 1), + ) + self.DWConv2 = nn.Sequential( + DWConv(channel // 2, channel // 2, kernel=3, padding=1, dilation=1), + BasicConv2d(channel // 2, channel // 8, 1), + ) + self.DWConv3 = nn.Sequential( + DWConv(channel // 2, channel // 2, kernel=3, padding=3, dilation=3), + BasicConv2d(channel // 2, channel // 8, 1), + ) + self.DWConv4 = nn.Sequential( + DWConv(channel // 2, channel // 2, kernel=3, padding=5, dilation=5), + BasicConv2d(channel // 2, channel // 8, 1), + ) + self.conv1 = BasicConv2d(channel // 2, 1, 1) + + def forward(self, decoder_map, encoder_map): + """ + Args: + decoder_map: decoder representation (B, 1, H, W). + encoder_map: encoder block output (B, C, H, W). 
+ Returns: + decoder representation: (B, 1, H, W) + """ + mask_bg = -1 * torch.sigmoid(decoder_map) + 1 # Sigmoid & Reverse + mask_ob = torch.sigmoid(decoder_map) # object attention + x = mask_ob.expand(-1, self.channel, -1, -1).mul(encoder_map) + + edge = mask_bg.clone() + edge[edge > 0.93] = 0 + x = x + (edge * encoder_map) + + x = self.DWSConv(x) + skip = x.clone() + x = ( + torch.cat( + [self.DWConv1(x), self.DWConv2(x), self.DWConv3(x), self.DWConv4(x)], + dim=1, + ) + + skip + ) + x = torch.relu(self.conv1(x)) + + return x + decoder_map diff --git a/carvekit/ml/arch/tracerb7/conv_modules.py b/carvekit/ml/arch/tracerb7/conv_modules.py new file mode 100644 index 0000000..90395a6 --- /dev/null +++ b/carvekit/ml/arch/tracerb7/conv_modules.py @@ -0,0 +1,88 @@ +""" +Source url: https://github.com/Karel911/TRACER +Author: Min Seok Lee and Wooseok Shin +License: Apache License 2.0 +""" +import torch.nn as nn + + +class BasicConv2d(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + ): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d( + in_channel, + out_channel, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=False, + ) + self.bn = nn.BatchNorm2d(out_channel) + self.selu = nn.SELU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.selu(x) + + return x + + +class DWConv(nn.Module): + def __init__(self, in_channel, out_channel, kernel, dilation, padding): + super(DWConv, self).__init__() + self.out_channel = out_channel + self.DWConv = nn.Conv2d( + in_channel, + out_channel, + kernel_size=kernel, + padding=padding, + groups=in_channel, + dilation=dilation, + bias=False, + ) + self.bn = nn.BatchNorm2d(out_channel) + self.selu = nn.SELU() + + def forward(self, x): + x = self.DWConv(x) + out = self.selu(self.bn(x)) + + return out + + +class DWSConv(nn.Module): + def __init__(self, in_channel, out_channel, 
kernel, padding, kernels_per_layer): + super(DWSConv, self).__init__() + self.out_channel = out_channel + self.DWConv = nn.Conv2d( + in_channel, + in_channel * kernels_per_layer, + kernel_size=kernel, + padding=padding, + groups=in_channel, + bias=False, + ) + self.bn = nn.BatchNorm2d(in_channel * kernels_per_layer) + self.selu = nn.SELU() + self.PWConv = nn.Conv2d( + in_channel * kernels_per_layer, out_channel, kernel_size=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(out_channel) + + def forward(self, x): + x = self.DWConv(x) + x = self.selu(self.bn(x)) + out = self.PWConv(x) + out = self.selu(self.bn2(out)) + + return out diff --git a/carvekit/ml/arch/tracerb7/effi_utils.py b/carvekit/ml/arch/tracerb7/effi_utils.py new file mode 100644 index 0000000..b578ca2 --- /dev/null +++ b/carvekit/ml/arch/tracerb7/effi_utils.py @@ -0,0 +1,579 @@ +""" +Original author: lukemelas (github username) +Github repo: https://github.com/lukemelas/EfficientNet-PyTorch +With adjustments and added comments by workingcoder (github username). 
+License: Apache License 2.0 +Reimplemented: Min Seok Lee and Wooseok Shin +""" + +import collections +import re +from functools import partial + +import math +import torch +from torch import nn +from torch.nn import functional as F + +# Parameters for the entire model (stem, all blocks, and head) +GlobalParams = collections.namedtuple( + "GlobalParams", + [ + "width_coefficient", + "depth_coefficient", + "image_size", + "dropout_rate", + "num_classes", + "batch_norm_momentum", + "batch_norm_epsilon", + "drop_connect_rate", + "depth_divisor", + "min_depth", + "include_top", + ], +) + +# Parameters for an individual model block +BlockArgs = collections.namedtuple( + "BlockArgs", + [ + "num_repeat", + "kernel_size", + "stride", + "expand_ratio", + "input_filters", + "output_filters", + "se_ratio", + "id_skip", + ], +) + +# Set GlobalParams and BlockArgs's defaults +GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) +BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) + + +# An ordinary implementation of Swish function +class Swish(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +# A memory-efficient implementation of Swish function +class SwishImplementation(torch.autograd.Function): + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_tensors[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class MemoryEfficientSwish(nn.Module): + def forward(self, x): + return SwishImplementation.apply(x) + + +def round_filters(filters, global_params): + """Calculate and round number of filters based on width multiplier. + Use width_coefficient, depth_divisor and min_depth of global_params. + + Args: + filters (int): Filters number to be calculated. + global_params (namedtuple): Global params of the model. 
+ + Returns: + new_filters: New filters number after calculating. + """ + multiplier = global_params.width_coefficient + if not multiplier: + return filters + divisor = global_params.depth_divisor + min_depth = global_params.min_depth + filters *= multiplier + min_depth = min_depth or divisor # pay attention to this line when using min_depth + # follow the formula transferred from official TensorFlow implementation + new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) + if new_filters < 0.9 * filters: # prevent rounding by more than 10% + new_filters += divisor + return int(new_filters) + + +def round_repeats(repeats, global_params): + """Calculate module's repeat number of a block based on depth multiplier. + Use depth_coefficient of global_params. + + Args: + repeats (int): num_repeat to be calculated. + global_params (namedtuple): Global params of the model. + + Returns: + new repeat: New repeat number after calculating. + """ + multiplier = global_params.depth_coefficient + if not multiplier: + return repeats + # follow the formula transferred from official TensorFlow implementation + return int(math.ceil(multiplier * repeats)) + + +def drop_connect(inputs, p, training): + """Drop connect. + + Args: + input (tensor: BCWH): Input of this structure. + p (float: 0.0~1.0): Probability of drop connection. + training (bool): The running mode. + + Returns: + output: Output after drop connection. + """ + assert 0 <= p <= 1, "p must be in range of [0,1]" + + if not training: + return inputs + + batch_size = inputs.shape[0] + keep_prob = 1 - p + + # generate binary_tensor mask according to probability (p for 0, 1-p for 1) + random_tensor = keep_prob + random_tensor += torch.rand( + [batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device + ) + binary_tensor = torch.floor(random_tensor) + + output = inputs / keep_prob * binary_tensor + return output + + +def get_width_and_height_from_size(x): + """Obtain height and width from x. 
+ + Args: + x (int, tuple or list): Data size. + + Returns: + size: A tuple or list (H,W). + """ + if isinstance(x, int): + return x, x + if isinstance(x, list) or isinstance(x, tuple): + return x + else: + raise TypeError() + + +def calculate_output_image_size(input_image_size, stride): + """Calculates the output image size when using Conv2dSamePadding with a stride. + Necessary for static padding. Thanks to mannatsingh for pointing this out. + + Args: + input_image_size (int, tuple or list): Size of input image. + stride (int, tuple or list): Conv2d operation's stride. + + Returns: + output_image_size: A list [H,W]. + """ + if input_image_size is None: + return None + image_height, image_width = get_width_and_height_from_size(input_image_size) + stride = stride if isinstance(stride, int) else stride[0] + image_height = int(math.ceil(image_height / stride)) + image_width = int(math.ceil(image_width / stride)) + return [image_height, image_width] + + +# Note: +# The following 'SamePadding' functions make output size equal ceil(input size/stride). +# Only when stride equals 1, can the output size be the same as input size. +# Don't be confused by their function names ! ! ! + + +def get_same_padding_conv2d(image_size=None): + """Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. + + Args: + image_size (int or tuple): Size of the image. + + Returns: + Conv2dDynamicSamePadding or Conv2dStaticSamePadding. + """ + if image_size is None: + return Conv2dDynamicSamePadding + else: + return partial(Conv2dStaticSamePadding, image_size=image_size) + + +class Conv2dDynamicSamePadding(nn.Conv2d): + """2D Convolutions like TensorFlow, for a dynamic image size. + The padding is operated in forward function by calculating dynamically. + """ + + # Tips for 'SAME' mode padding. 
+ # Given the following: + # i: width or height + # s: stride + # k: kernel size + # d: dilation + # p: padding + # Output after Conv2d: + # o = floor((i+p-((k-1)*d+1))/s+1) + # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1), + # => p = (i-1)*s+((k-1)*d+1)-i + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + groups=1, + bias=True, + ): + super().__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias + ) + self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil( + iw / sw + ) # change the output size according to stride ! ! ! + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad( + x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + ) + return F.conv2d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + +class Conv2dStaticSamePadding(nn.Conv2d): + """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size. + The padding mudule is calculated in construction function, then used in forward. 
+ """ + + # With the same calculation as Conv2dDynamicSamePadding + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + image_size=None, + **kwargs + ): + super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs) + self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 + + # Calculate padding based on image size and save it + assert image_size is not None + ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + self.static_padding = nn.ZeroPad2d( + (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + ) + else: + self.static_padding = nn.Identity() + + def forward(self, x): + x = self.static_padding(x) + x = F.conv2d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + return x + + +def get_same_padding_maxPool2d(image_size=None): + """Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. + + Args: + image_size (int or tuple): Size of the image. + + Returns: + MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding. + """ + if image_size is None: + return MaxPool2dDynamicSamePadding + else: + return partial(MaxPool2dStaticSamePadding, image_size=image_size) + + +class MaxPool2dDynamicSamePadding(nn.MaxPool2d): + """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size. + The padding is operated in forward function by calculating dynamically. 
+ """ + + def __init__( + self, + kernel_size, + stride, + padding=0, + dilation=1, + return_indices=False, + ceil_mode=False, + ): + super().__init__( + kernel_size, stride, padding, dilation, return_indices, ceil_mode + ) + self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride + self.kernel_size = ( + [self.kernel_size] * 2 + if isinstance(self.kernel_size, int) + else self.kernel_size + ) + self.dilation = ( + [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation + ) + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.kernel_size + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad( + x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + ) + return F.max_pool2d( + x, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.ceil_mode, + self.return_indices, + ) + + +class MaxPool2dStaticSamePadding(nn.MaxPool2d): + """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size. + The padding mudule is calculated in construction function, then used in forward. 
+ """ + + def __init__(self, kernel_size, stride, image_size=None, **kwargs): + super().__init__(kernel_size, stride, **kwargs) + self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride + self.kernel_size = ( + [self.kernel_size] * 2 + if isinstance(self.kernel_size, int) + else self.kernel_size + ) + self.dilation = ( + [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation + ) + + # Calculate padding based on image size and save it + assert image_size is not None + ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size + kh, kw = self.kernel_size + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + self.static_padding = nn.ZeroPad2d( + (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + ) + else: + self.static_padding = nn.Identity() + + def forward(self, x): + x = self.static_padding(x) + x = F.max_pool2d( + x, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.ceil_mode, + self.return_indices, + ) + return x + + +class BlockDecoder(object): + """Block Decoder for readability, + straight from the official TensorFlow repository. + """ + + @staticmethod + def _decode_block_string(block_string): + """Get a block through a string notation of arguments. + + Args: + block_string (str): A string notation of arguments. + Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'. + + Returns: + BlockArgs: The namedtuple defined at the top of this file. 
+ """ + assert isinstance(block_string, str) + + ops = block_string.split("_") + options = {} + for op in ops: + splits = re.split(r"(\d.*)", op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # Check stride + assert ("s" in options and len(options["s"]) == 1) or ( + len(options["s"]) == 2 and options["s"][0] == options["s"][1] + ) + + return BlockArgs( + num_repeat=int(options["r"]), + kernel_size=int(options["k"]), + stride=[int(options["s"][0])], + expand_ratio=int(options["e"]), + input_filters=int(options["i"]), + output_filters=int(options["o"]), + se_ratio=float(options["se"]) if "se" in options else None, + id_skip=("noskip" not in block_string), + ) + + @staticmethod + def _encode_block_string(block): + """Encode a block to a string. + + Args: + block (namedtuple): A BlockArgs type argument. + + Returns: + block_string: A String form of BlockArgs. + """ + args = [ + "r%d" % block.num_repeat, + "k%d" % block.kernel_size, + "s%d%d" % (block.strides[0], block.strides[1]), + "e%s" % block.expand_ratio, + "i%d" % block.input_filters, + "o%d" % block.output_filters, + ] + if 0 < block.se_ratio <= 1: + args.append("se%s" % block.se_ratio) + if block.id_skip is False: + args.append("noskip") + return "_".join(args) + + @staticmethod + def decode(string_list): + """Decode a list of string notations to specify blocks inside the network. + + Args: + string_list (list[str]): A list of strings, each string is a notation of block. + + Returns: + blocks_args: A list of BlockArgs namedtuples of block args. + """ + assert isinstance(string_list, list) + blocks_args = [] + for block_string in string_list: + blocks_args.append(BlockDecoder._decode_block_string(block_string)) + return blocks_args + + @staticmethod + def encode(blocks_args): + """Encode a list of BlockArgs to a list of strings. + + Args: + blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args. 
+ + Returns: + block_strings: A list of strings, each string is a notation of block. + """ + block_strings = [] + for block in blocks_args: + block_strings.append(BlockDecoder._encode_block_string(block)) + return block_strings + + +def create_block_args( + width_coefficient=None, + depth_coefficient=None, + image_size=None, + dropout_rate=0.2, + drop_connect_rate=0.2, + num_classes=1000, + include_top=True, +): + """Create BlockArgs and GlobalParams for efficientnet model. + + Args: + width_coefficient (float) + depth_coefficient (float) + image_size (int) + dropout_rate (float) + drop_connect_rate (float) + num_classes (int) + + Meaning as the name suggests. + + Returns: + blocks_args, global_params. + """ + + # Blocks args for the whole model(efficientnet-b0 by default) + # It will be modified in the construction of EfficientNet Class according to model + blocks_args = [ + "r1_k3_s11_e1_i32_o16_se0.25", + "r2_k3_s22_e6_i16_o24_se0.25", + "r2_k5_s22_e6_i24_o40_se0.25", + "r3_k3_s22_e6_i40_o80_se0.25", + "r3_k5_s11_e6_i80_o112_se0.25", + "r4_k5_s22_e6_i112_o192_se0.25", + "r1_k3_s11_e6_i192_o320_se0.25", + ] + blocks_args = BlockDecoder.decode(blocks_args) + + global_params = GlobalParams( + width_coefficient=width_coefficient, + depth_coefficient=depth_coefficient, + image_size=image_size, + dropout_rate=dropout_rate, + num_classes=num_classes, + batch_norm_momentum=0.99, + batch_norm_epsilon=1e-3, + drop_connect_rate=drop_connect_rate, + depth_divisor=8, + min_depth=None, + include_top=include_top, + ) + + return blocks_args, global_params diff --git a/carvekit/ml/arch/tracerb7/efficientnet.py b/carvekit/ml/arch/tracerb7/efficientnet.py new file mode 100644 index 0000000..ea327d0 --- /dev/null +++ b/carvekit/ml/arch/tracerb7/efficientnet.py @@ -0,0 +1,325 @@ +""" +Source url: https://github.com/lukemelas/EfficientNet-PyTorch +Modified by Min Seok Lee, Wooseok Shin, Nikita Selin +License: Apache License 2.0 +Changes: + - Added support for extracting edge features 
+ - Added support for extracting object features at different levels + - Refactored the code +""" +from typing import Any, List + +import torch +from torch import nn +from torch.nn import functional as F + +from carvekit.ml.arch.tracerb7.effi_utils import ( + get_same_padding_conv2d, + calculate_output_image_size, + MemoryEfficientSwish, + drop_connect, + round_filters, + round_repeats, + Swish, + create_block_args, +) + + +class MBConvBlock(nn.Module): + """Mobile Inverted Residual Bottleneck Block. + + Args: + block_args (namedtuple): BlockArgs, defined in utils.py. + global_params (namedtuple): GlobalParam, defined in utils.py. + image_size (tuple or list): [image_height, image_width]. + + References: + [1] https://arxiv.org/abs/1704.04861 (MobileNet v1) + [2] https://arxiv.org/abs/1801.04381 (MobileNet v2) + [3] https://arxiv.org/abs/1905.02244 (MobileNet v3) + """ + + def __init__(self, block_args, global_params, image_size=None): + super().__init__() + self._block_args = block_args + self._bn_mom = ( + 1 - global_params.batch_norm_momentum + ) # pytorch's difference from tensorflow + self._bn_eps = global_params.batch_norm_epsilon + self.has_se = (self._block_args.se_ratio is not None) and ( + 0 < self._block_args.se_ratio <= 1 + ) + self.id_skip = ( + block_args.id_skip + ) # whether to use skip connection and drop connect + + # Expansion phase (Inverted Bottleneck) + inp = self._block_args.input_filters # number of input channels + oup = ( + self._block_args.input_filters * self._block_args.expand_ratio + ) # number of output channels + if self._block_args.expand_ratio != 1: + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._expand_conv = Conv2d( + in_channels=inp, out_channels=oup, kernel_size=1, bias=False + ) + self._bn0 = nn.BatchNorm2d( + num_features=oup, momentum=self._bn_mom, eps=self._bn_eps + ) + # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size + + # Depthwise convolution phase + k = 
self._block_args.kernel_size + s = self._block_args.stride + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._depthwise_conv = Conv2d( + in_channels=oup, + out_channels=oup, + groups=oup, # groups makes it depthwise + kernel_size=k, + stride=s, + bias=False, + ) + self._bn1 = nn.BatchNorm2d( + num_features=oup, momentum=self._bn_mom, eps=self._bn_eps + ) + image_size = calculate_output_image_size(image_size, s) + + # Squeeze and Excitation layer, if desired + if self.has_se: + Conv2d = get_same_padding_conv2d(image_size=(1, 1)) + num_squeezed_channels = max( + 1, int(self._block_args.input_filters * self._block_args.se_ratio) + ) + self._se_reduce = Conv2d( + in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1 + ) + self._se_expand = Conv2d( + in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1 + ) + + # Pointwise convolution phase + final_oup = self._block_args.output_filters + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._project_conv = Conv2d( + in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False + ) + self._bn2 = nn.BatchNorm2d( + num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps + ) + self._swish = MemoryEfficientSwish() + + def forward(self, inputs, drop_connect_rate=None): + """MBConvBlock's forward function. + + Args: + inputs (tensor): Input tensor. + drop_connect_rate (bool): Drop connect rate (float, between 0 and 1). + + Returns: + Output of this block after processing. 
+ """ + + # Expansion and Depthwise Convolution + x = inputs + if self._block_args.expand_ratio != 1: + x = self._expand_conv(inputs) + x = self._bn0(x) + x = self._swish(x) + + x = self._depthwise_conv(x) + x = self._bn1(x) + x = self._swish(x) + + # Squeeze and Excitation + if self.has_se: + x_squeezed = F.adaptive_avg_pool2d(x, 1) + x_squeezed = self._se_reduce(x_squeezed) + x_squeezed = self._swish(x_squeezed) + x_squeezed = self._se_expand(x_squeezed) + x = torch.sigmoid(x_squeezed) * x + + # Pointwise Convolution + x = self._project_conv(x) + x = self._bn2(x) + + # Skip connection and drop connect + input_filters, output_filters = ( + self._block_args.input_filters, + self._block_args.output_filters, + ) + if ( + self.id_skip + and self._block_args.stride == 1 + and input_filters == output_filters + ): + # The combination of skip connection and drop connect brings about stochastic depth. + if drop_connect_rate: + x = drop_connect(x, p=drop_connect_rate, training=self.training) + x = x + inputs # skip connection + return x + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export). + + Args: + memory_efficient (bool): Whether to use memory-efficient version of swish. 
+ """ + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + + +class EfficientNet(nn.Module): + def __init__(self, blocks_args=None, global_params=None): + super().__init__() + assert isinstance(blocks_args, list), "blocks_args should be a list" + assert len(blocks_args) > 0, "block args must be greater than 0" + self._global_params = global_params + self._blocks_args = blocks_args + + # Batch norm parameters + bn_mom = 1 - self._global_params.batch_norm_momentum + bn_eps = self._global_params.batch_norm_epsilon + + # Get stem static or dynamic convolution depending on image size + image_size = global_params.image_size + Conv2d = get_same_padding_conv2d(image_size=image_size) + + # Stem + in_channels = 3 # rgb + out_channels = round_filters( + 32, self._global_params + ) # number of output channels + self._conv_stem = Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, bias=False + ) + self._bn0 = nn.BatchNorm2d( + num_features=out_channels, momentum=bn_mom, eps=bn_eps + ) + image_size = calculate_output_image_size(image_size, 2) + + # Build blocks + self._blocks = nn.ModuleList([]) + for block_args in self._blocks_args: + + # Update block input and output filters based on depth multiplier. + block_args = block_args._replace( + input_filters=round_filters( + block_args.input_filters, self._global_params + ), + output_filters=round_filters( + block_args.output_filters, self._global_params + ), + num_repeat=round_repeats(block_args.num_repeat, self._global_params), + ) + + # The first block needs to take care of stride and filter size increase. 
+ self._blocks.append( + MBConvBlock(block_args, self._global_params, image_size=image_size) + ) + image_size = calculate_output_image_size(image_size, block_args.stride) + if block_args.num_repeat > 1: # modify block_args to keep same output size + block_args = block_args._replace( + input_filters=block_args.output_filters, stride=1 + ) + for _ in range(block_args.num_repeat - 1): + self._blocks.append( + MBConvBlock(block_args, self._global_params, image_size=image_size) + ) + # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1 + + self._swish = MemoryEfficientSwish() + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export). + + Args: + memory_efficient (bool): Whether to use memory-efficient version of swish. + + """ + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + for block in self._blocks: + block.set_swish(memory_efficient) + + def extract_endpoints(self, inputs): + endpoints = dict() + + # Stem + x = self._swish(self._bn0(self._conv_stem(inputs))) + prev_x = x + + # Blocks + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len( + self._blocks + ) # scale drop connect_rate + x = block(x, drop_connect_rate=drop_connect_rate) + if prev_x.size(2) > x.size(2): + endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x + prev_x = x + + # Head + x = self._swish(self._bn1(self._conv_head(x))) + endpoints["reduction_{}".format(len(endpoints) + 1)] = x + + return endpoints + + def _change_in_channels(self, in_channels): + """Adjust model's first convolution layer to in_channels, if in_channels not equals 3. + + Args: + in_channels (int): Input data's channel number. 
+ """ + if in_channels != 3: + Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size) + out_channels = round_filters(32, self._global_params) + self._conv_stem = Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, bias=False + ) + + +class EfficientEncoderB7(EfficientNet): + def __init__(self): + super().__init__( + *create_block_args( + width_coefficient=2.0, + depth_coefficient=3.1, + dropout_rate=0.5, + image_size=600, + ) + ) + self._change_in_channels(3) + self.block_idx = [10, 17, 37, 54] + self.channels = [48, 80, 224, 640] + + def initial_conv(self, inputs): + x = self._swish(self._bn0(self._conv_stem(inputs))) + return x + + def get_blocks(self, x, H, W, block_idx): + features = [] + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len( + self._blocks + ) # scale drop connect_rate + x = block(x, drop_connect_rate=drop_connect_rate) + if idx == block_idx[0]: + features.append(x.clone()) + if idx == block_idx[1]: + features.append(x.clone()) + if idx == block_idx[2]: + features.append(x.clone()) + if idx == block_idx[3]: + features.append(x.clone()) + + return features + + def forward(self, inputs: torch.Tensor) -> List[Any]: + B, C, H, W = inputs.size() + x = self.initial_conv(inputs) # Prepare input for the backbone + return self.get_blocks( + x, H, W, block_idx=self.block_idx + ) # Get backbone features and edge maps diff --git a/carvekit/ml/arch/tracerb7/tracer.py b/carvekit/ml/arch/tracerb7/tracer.py new file mode 100644 index 0000000..70cc3f2 --- /dev/null +++ b/carvekit/ml/arch/tracerb7/tracer.py @@ -0,0 +1,97 @@ +""" +Source url: https://github.com/Karel911/TRACER +Author: Min Seok Lee and Wooseok Shin +Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
+License: Apache License 2.0 +Changes: + - Refactored code + - Removed unused code + - Added comments +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import List, Optional, Tuple + +from torch import Tensor + +from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7 +from carvekit.ml.arch.tracerb7.att_modules import ( + RFB_Block, + aggregation, + ObjectAttention, +) + + +class TracerDecoder(nn.Module): + """Tracer Decoder""" + + def __init__( + self, + encoder: EfficientEncoderB7, + features_channels: Optional[List[int]] = None, + rfb_channel: Optional[List[int]] = None, + ): + """ + Initialize the tracer decoder. + + Args: + encoder: The encoder to use. + features_channels: The channels of the backbone features at different stages. default: [48, 80, 224, 640] + rfb_channel: The channels of the RFB features. default: [32, 64, 128] + """ + super().__init__() + if rfb_channel is None: + rfb_channel = [32, 64, 128] + if features_channels is None: + features_channels = [48, 80, 224, 640] + self.encoder = encoder + self.features_channels = features_channels + + # Receptive Field Blocks + features_channels = rfb_channel + self.rfb2 = RFB_Block(self.features_channels[1], features_channels[0]) + self.rfb3 = RFB_Block(self.features_channels[2], features_channels[1]) + self.rfb4 = RFB_Block(self.features_channels[3], features_channels[2]) + + # Multi-level aggregation + self.agg = aggregation(features_channels) + + # Object Attention + self.ObjectAttention2 = ObjectAttention( + channel=self.features_channels[1], kernel_size=3 + ) + self.ObjectAttention1 = ObjectAttention( + channel=self.features_channels[0], kernel_size=3 + ) + + def forward(self, inputs: torch.Tensor) -> Tensor: + """ + Forward pass of the tracer decoder. + + Args: + inputs: Preprocessed images. + + Returns: + Tensors of segmentation masks and mask of object edges. 
+ """ + features = self.encoder(inputs) + x3_rfb = self.rfb2(features[1]) + x4_rfb = self.rfb3(features[2]) + x5_rfb = self.rfb4(features[3]) + + D_0 = self.agg(x5_rfb, x4_rfb, x3_rfb) + + ds_map0 = F.interpolate(D_0, scale_factor=8, mode="bilinear") + + D_1 = self.ObjectAttention2(D_0, features[1]) + ds_map1 = F.interpolate(D_1, scale_factor=8, mode="bilinear") + + ds_map = F.interpolate(D_1, scale_factor=2, mode="bilinear") + D_2 = self.ObjectAttention1(ds_map, features[0]) + ds_map2 = F.interpolate(D_2, scale_factor=4, mode="bilinear") + + final_map = (ds_map2 + ds_map1 + ds_map0) / 3 + + return torch.sigmoid(final_map) diff --git a/carvekit/ml/arch/u2net/u2net.py b/carvekit/ml/arch/u2net/u2net.py index e951523..2225acf 100644 --- a/carvekit/ml/arch/u2net/u2net.py +++ b/carvekit/ml/arch/u2net/u2net.py @@ -10,11 +10,11 @@ import math -__all__ = ['U2NETArchitecture'] +__all__ = ["U2NETArchitecture"] def _upsample_like(x, size): - return nn.Upsample(size=size, mode='bilinear', align_corners=False)(x) + return nn.Upsample(size=size, mode="bilinear", align_corners=False)(x) def _size_map(x, height): @@ -31,7 +31,9 @@ class REBNCONV(nn.Module): def __init__(self, in_ch=3, out_ch=3, dilate=1): super(REBNCONV, self).__init__() - self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dilate, dilation=1 * dilate) + self.conv_s1 = nn.Conv2d( + in_ch, out_ch, 3, padding=1 * dilate, dilation=1 * dilate + ) self.bn_s1 = nn.BatchNorm2d(out_ch) self.relu_s1 = nn.ReLU(inplace=True) @@ -54,33 +56,39 @@ def forward(self, x): # U-Net like symmetric encoder-decoder structure def unet(x, height=1): if height < self.height: - x1 = getattr(self, f'rebnconv{height}')(x) + x1 = getattr(self, f"rebnconv{height}")(x) if not self.dilated and height < self.height - 1: - x2 = unet(getattr(self, 'downsample')(x1), height + 1) + x2 = unet(getattr(self, "downsample")(x1), height + 1) else: x2 = unet(x1, height + 1) - x = getattr(self, f'rebnconv{height}d')(torch.cat((x2, x1), 1)) - return 
_upsample_like(x, sizes[height - 1]) if not self.dilated and height > 1 else x + x = getattr(self, f"rebnconv{height}d")(torch.cat((x2, x1), 1)) + return ( + _upsample_like(x, sizes[height - 1]) + if not self.dilated and height > 1 + else x + ) else: - return getattr(self, f'rebnconv{height}')(x) + return getattr(self, f"rebnconv{height}")(x) return x + unet(x) def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False): - self.add_module('rebnconvin', REBNCONV(in_ch, out_ch)) - self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True)) + self.add_module("rebnconvin", REBNCONV(in_ch, out_ch)) + self.add_module("downsample", nn.MaxPool2d(2, stride=2, ceil_mode=True)) - self.add_module(f'rebnconv1', REBNCONV(out_ch, mid_ch)) - self.add_module(f'rebnconv1d', REBNCONV(mid_ch * 2, out_ch)) + self.add_module("rebnconv1", REBNCONV(out_ch, mid_ch)) + self.add_module("rebnconv1d", REBNCONV(mid_ch * 2, out_ch)) for i in range(2, height): dilate = 1 if not dilated else 2 ** (i - 1) - self.add_module(f'rebnconv{i}', REBNCONV(mid_ch, mid_ch, dilate=dilate)) - self.add_module(f'rebnconv{i}d', REBNCONV(mid_ch * 2, mid_ch, dilate=dilate)) + self.add_module(f"rebnconv{i}", REBNCONV(mid_ch, mid_ch, dilate=dilate)) + self.add_module( + f"rebnconv{i}d", REBNCONV(mid_ch * 2, mid_ch, dilate=dilate) + ) dilate = 2 if not dilated else 2 ** (height - 1) - self.add_module(f'rebnconv{height}', REBNCONV(mid_ch, mid_ch, dilate=dilate)) + self.add_module(f"rebnconv{height}", REBNCONV(mid_ch, mid_ch, dilate=dilate)) class U2NETArchitecture(nn.Module): @@ -91,17 +99,17 @@ def __init__(self, cfg_type: Union[dict, str] = "full", out_ch: int = 1): layers_cfgs = { # cfgs for building RSUs and sides # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]} - 'stage1': ['En_1', (7, 3, 32, 64), -1], - 'stage2': ['En_2', (6, 64, 32, 128), -1], - 'stage3': ['En_3', (5, 128, 64, 256), -1], - 'stage4': ['En_4', (4, 256, 128, 512), -1], - 'stage5': ['En_5', (4, 512, 256, 
512, True), -1], - 'stage6': ['En_6', (4, 512, 256, 512, True), 512], - 'stage5d': ['De_5', (4, 1024, 256, 512, True), 512], - 'stage4d': ['De_4', (4, 1024, 128, 256), 256], - 'stage3d': ['De_3', (5, 512, 64, 128), 128], - 'stage2d': ['De_2', (6, 256, 32, 64), 64], - 'stage1d': ['De_1', (7, 128, 16, 64), 64], + "stage1": ["En_1", (7, 3, 32, 64), -1], + "stage2": ["En_2", (6, 64, 32, 128), -1], + "stage3": ["En_3", (5, 128, 64, 256), -1], + "stage4": ["En_4", (4, 256, 128, 512), -1], + "stage5": ["En_5", (4, 512, 256, 512, True), -1], + "stage6": ["En_6", (4, 512, 256, 512, True), 512], + "stage5d": ["De_5", (4, 1024, 256, 512, True), 512], + "stage4d": ["De_4", (4, 1024, 128, 256), 256], + "stage3d": ["De_3", (5, 512, 64, 128), 128], + "stage2d": ["De_2", (6, 256, 32, 64), 64], + "stage1d": ["De_1", (7, 128, 16, 64), 64], } else: raise ValueError("Unknown U^2-Net architecture conf. name") @@ -119,19 +127,19 @@ def forward(self, x): # side saliency map def unet(x, height=1): if height < 6: - x1 = getattr(self, f'stage{height}')(x) - x2 = unet(getattr(self, 'downsample')(x1), height + 1) - x = getattr(self, f'stage{height}d')(torch.cat((x2, x1), 1)) + x1 = getattr(self, f"stage{height}")(x) + x2 = unet(getattr(self, "downsample")(x1), height + 1) + x = getattr(self, f"stage{height}d")(torch.cat((x2, x1), 1)) side(x, height) return _upsample_like(x, sizes[height - 1]) if height > 1 else x else: - x = getattr(self, f'stage{height}')(x) + x = getattr(self, f"stage{height}")(x) side(x, height) return _upsample_like(x, sizes[height - 1]) def side(x, h): # side output saliency map (before sigmoid) - x = getattr(self, f'side{h}')(x) + x = getattr(self, f"side{h}")(x) x = _upsample_like(x, sizes[1]) maps.append(x) @@ -139,7 +147,7 @@ def fuse(): # fuse saliency probability maps maps.reverse() x = torch.cat(maps, 1) - x = getattr(self, 'outconv')(x) + x = getattr(self, "outconv")(x) maps.insert(0, x) return [torch.sigmoid(x) for x in maps] @@ -149,12 +157,16 @@ def fuse(): 
def _make_layers(self, cfgs): self.height = int((len(cfgs) + 1) / 2) - self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True)) + self.add_module("downsample", nn.MaxPool2d(2, stride=2, ceil_mode=True)) for k, v in cfgs.items(): # build rsu block self.add_module(k, RSU(v[0], *v[1])) if v[2] > 0: # build side layer - self.add_module(f'side{v[0][-1]}', nn.Conv2d(v[2], self.out_ch, 3, padding=1)) + self.add_module( + f"side{v[0][-1]}", nn.Conv2d(v[2], self.out_ch, 3, padding=1) + ) # build fuse layer - self.add_module('outconv', nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1)) + self.add_module( + "outconv", nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1) + ) diff --git a/carvekit/ml/files/__init__.py b/carvekit/ml/files/__init__.py index 563c5d6..dc535a7 100644 --- a/carvekit/ml/files/__init__.py +++ b/carvekit/ml/files/__init__.py @@ -1,4 +1,7 @@ from pathlib import Path -carvekit_dir = Path.home().joinpath('.carvekit') -checkpoints_dir = carvekit_dir.joinpath('checkpoints') +carvekit_dir = Path.home().joinpath(".cache/carvekit") + +carvekit_dir.mkdir(parents=True, exist_ok=True) + +checkpoints_dir = carvekit_dir.joinpath("checkpoints") diff --git a/carvekit/ml/files/models_loc.py b/carvekit/ml/files/models_loc.py index 4c9a464..45f9a56 100644 --- a/carvekit/ml/files/models_loc.py +++ b/carvekit/ml/files/models_loc.py @@ -5,43 +5,61 @@ """ import pathlib from carvekit.ml.files import checkpoints_dir -from carvekit.utils.download_models import check_for_exists +from carvekit.utils.download_models import downloader def u2net_full_pretrained() -> pathlib.Path: - """ Returns u2net pretrained model location + """Returns u2net pretrained model location Returns: pathlib.Path to model location """ - return check_for_exists(checkpoints_dir.joinpath('u2net/u2net.pth')) + return downloader("u2net.pth") def basnet_pretrained() -> pathlib.Path: - """ Returns basnet pretrained model location + """Returns basnet pretrained model location 
Returns: pathlib.Path to model location """ - return check_for_exists(checkpoints_dir.joinpath('basnet/basnet.pth')) + return downloader("basnet.pth") def deeplab_pretrained() -> pathlib.Path: - """ Returns basnet pretrained model location + """Returns basnet pretrained model location Returns: pathlib.Path to model location """ - return check_for_exists(checkpoints_dir.joinpath('deeplab/deeplab.pth')) + return downloader("deeplab.pth") def fba_pretrained() -> pathlib.Path: - """ Returns basnet pretrained model location + """Returns basnet pretrained model location Returns: pathlib.Path to model location """ - return check_for_exists(checkpoints_dir.joinpath('fba_matting/fba_matting.pth')) + return downloader("fba_matting.pth") + + +def tracer_b7_pretrained() -> pathlib.Path: + """Returns TRACER with EfficientNet v1 b7 encoder pretrained model location + + Returns: + pathlib.Path to model location + """ + return downloader("tracer_b7.pth") + + +def tracer_hair_pretrained() -> pathlib.Path: + """Returns TRACER with EfficientNet v1 b7 encoder model for hair segmentation location + + Returns: + pathlib.Path to model location + """ + return downloader("tracer_hair.pth") def download_all(): @@ -49,3 +67,4 @@ def download_all(): fba_pretrained() deeplab_pretrained() basnet_pretrained() + tracer_b7_pretrained() diff --git a/carvekit/ml/wrap/basnet.py b/carvekit/ml/wrap/basnet.py index 0ab81a0..9912e81 100644 --- a/carvekit/ml/wrap/basnet.py +++ b/carvekit/ml/wrap/basnet.py @@ -22,41 +22,48 @@ class BASNET(BASNet): """BASNet model interface""" - def __init__(self, device='cpu', - input_tensor_size: Union[List[int], int] = 320, - batch_size: int = 10, - load_pretrained: bool = True): + def __init__( + self, + device="cpu", + input_image_size: Union[List[int], int] = 320, + batch_size: int = 10, + load_pretrained: bool = True, + fp16: bool = False, + ): """ - Initialize the BASNET model + Initialize the BASNET model - Args: - device: processing device - input_tensor_size: 
input image size - batch_size: the number of images that the neural network processes in one run - load_pretrained: loading pretrained model + Args: + device: processing device + input_image_size: input image size + batch_size: the number of images that the neural network processes in one run + load_pretrained: loading pretrained model + fp16: use fp16 precision // not supported at this moment """ super(BASNET, self).__init__(n_channels=3, n_classes=1) self.device = device self.batch_size = batch_size - if isinstance(input_tensor_size, list): - self.input_image_size = input_tensor_size[:2] + if isinstance(input_image_size, list): + self.input_image_size = input_image_size[:2] else: - self.input_image_size = (input_tensor_size, input_tensor_size) + self.input_image_size = (input_image_size, input_image_size) self.to(device) if load_pretrained: - self.load_state_dict(torch.load(basnet_pretrained(), map_location=self.device)) + self.load_state_dict( + torch.load(basnet_pretrained(), map_location=self.device) + ) self.eval() def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor: """ - Transform input image to suitable data format for neural network + Transform input image to suitable data format for neural network - Args: - data: input image + Args: + data: input image - Returns: - input for neural network + Returns: + input for neural network """ resized = data.resize(self.input_image_size) @@ -73,18 +80,19 @@ def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor: return torch.from_numpy(temp_image).type(torch.FloatTensor) @staticmethod - def data_postprocessing(data: torch.tensor, - original_image: PIL.Image.Image) -> PIL.Image.Image: + def data_postprocessing( + data: torch.tensor, original_image: PIL.Image.Image + ) -> PIL.Image.Image: """ - Transforms output data from neural network to suitable data - format for using with other components of this framework. 
+ Transforms output data from neural network to suitable data + format for using with other components of this framework. - Args: - data: output data from neural network - original_image: input image which was used for predicted data + Args: + data: output data from neural network + original_image: input image which was used for predicted data - Returns: - Segmentation mask as PIL Image instance + Returns: + Segmentation mask as PIL Image instance """ data = data.unsqueeze(0) @@ -97,27 +105,37 @@ def data_postprocessing(data: torch.tensor, mask = mask.resize(original_image.size, resample=3) return mask - def __call__(self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]) -> List[PIL.Image.Image]: + def __call__( + self, images: List[Union[str, pathlib.Path, PIL.Image.Image]] + ) -> List[PIL.Image.Image]: """ - Passes input images through neural network and returns segmentation masks as PIL.Image.Image instances + Passes input images through neural network and returns segmentation masks as PIL.Image.Image instances - Args: - images: input images + Args: + images: input images - Returns: - segmentation masks as for input images, as PIL.Image.Image instances + Returns: + segmentation masks as for input images, as PIL.Image.Image instances """ collect_masks = [] for image_batch in batch_generator(images, self.batch_size): - images = thread_pool_processing(lambda x: convert_image(load_image(x)), image_batch) - batches = torch.vstack(thread_pool_processing(self.data_preprocessing, images)) + images = thread_pool_processing( + lambda x: convert_image(load_image(x)), image_batch + ) + batches = torch.vstack( + thread_pool_processing(self.data_preprocessing, images) + ) with torch.no_grad(): batches = batches.to(self.device) - masks, d2, d3, d4, d5, d6, d7, d8 = super(BASNET, self).__call__(batches) + masks, d2, d3, d4, d5, d6, d7, d8 = super(BASNET, self).__call__( + batches + ) masks_cpu = masks.cpu() del d2, d3, d4, d5, d6, d7, d8, batches, masks - masks = 
thread_pool_processing(lambda x: self.data_postprocessing(masks_cpu[x], images[x]), - range(len(images))) + masks = thread_pool_processing( + lambda x: self.data_postprocessing(masks_cpu[x], images[x]), + range(len(images)), + ) collect_masks += masks return collect_masks diff --git a/carvekit/ml/wrap/deeplab_v3.py b/carvekit/ml/wrap/deeplab_v3.py index 11dd517..4b19542 100644 --- a/carvekit/ml/wrap/deeplab_v3.py +++ b/carvekit/ml/wrap/deeplab_v3.py @@ -13,42 +13,56 @@ from torchvision.models.segmentation import deeplabv3_resnet101 from carvekit.ml.files.models_loc import deeplab_pretrained from carvekit.utils.image_utils import convert_image, load_image +from carvekit.utils.models_utils import get_precision_autocast, cast_network from carvekit.utils.pool_utils import batch_generator, thread_pool_processing __all__ = ["DeepLabV3"] class DeepLabV3: - def __init__(self, device='cpu', - batch_size: int = 10, - input_image_size: Union[List[int], int] = 512, - load_pretrained: bool = True): + def __init__( + self, + device="cpu", + batch_size: int = 10, + input_image_size: Union[List[int], int] = 1024, + load_pretrained: bool = True, + fp16: bool = False, + ): """ - Initialize the DeepLabV3 model + Initialize the DeepLabV3 model - Args: - device: processing device - input_tensor_size: input image size - batch_size: the number of images that the neural network processes in one run - load_pretrained: loading pretrained model + Args: + device: processing device + input_image_size: input image size + batch_size: the number of images that the neural network processes in one run + load_pretrained: loading pretrained model + fp16: use half precision """ self.device = device self.batch_size = batch_size - self.network = deeplabv3_resnet101(pretrained=False, pretrained_backbone=False, aux_loss=True) + self.network = deeplabv3_resnet101( + pretrained=False, pretrained_backbone=False, aux_loss=True + ) self.network.to(self.device) if load_pretrained: - 
self.network.load_state_dict(torch.load(deeplab_pretrained(), map_location=self.device)) + self.network.load_state_dict( + torch.load(deeplab_pretrained(), map_location=self.device) + ) if isinstance(input_image_size, list): self.input_image_size = input_image_size[:2] else: self.input_image_size = (input_image_size, input_image_size) self.network.eval() - self.data_preprocessing = transforms.Compose([ - transforms.ToTensor(), - transforms.Resize(self.input_image_size), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ]) + self.fp16 = fp16 + self.transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) def to(self, device: str): """ @@ -62,42 +76,75 @@ def to(self, device: str): """ self.network.to(device) + def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor: + """ + Transform input image to suitable data format for neural network + + Args: + data: input image + + Returns: + input for neural network + + """ + copy = data.copy() + copy.thumbnail(self.input_image_size, resample=3) + return self.transform(copy) + @staticmethod - def data_postprocessing(data: torch.tensor, - original_image: PIL.Image.Image) -> PIL.Image.Image: + def data_postprocessing( + data: torch.tensor, original_image: PIL.Image.Image + ) -> PIL.Image.Image: """ - Transforms output data from neural network to suitable data - format for using with other components of this framework. + Transforms output data from neural network to suitable data + format for using with other components of this framework. 
- Args: - data: output data from neural network - original_image: input image which was used for predicted data + Args: + data: output data from neural network + original_image: input image which was used for predicted data - Returns: - Segmentation mask as PIL Image instance + Returns: + Segmentation mask as PIL Image instance """ - return Image.fromarray(data.numpy() * 255).convert("L").resize(original_image.size) + return ( + Image.fromarray(data.numpy() * 255).convert("L").resize(original_image.size) + ) - def __call__(self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]) -> List[PIL.Image.Image]: + def __call__( + self, images: List[Union[str, pathlib.Path, PIL.Image.Image]] + ) -> List[PIL.Image.Image]: """ - Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances + Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances - Args: - images: input images + Args: + images: input images - Returns: - segmentation masks as for input images, as PIL.Image.Image instances + Returns: + segmentation masks as for input images, as PIL.Image.Image instances """ collect_masks = [] - for image_batch in batch_generator(images, self.batch_size): - images = thread_pool_processing(lambda x: convert_image(load_image(x)), image_batch) - batches = thread_pool_processing(self.data_preprocessing, images) - with torch.no_grad(): - masks = [self.network(i.to(self.device).unsqueeze(0))['out'][0].argmax(0).byte().cpu() for i in batches] - del batches - masks = thread_pool_processing(lambda x: self.data_postprocessing(masks[x], images[x]), - range(len(images))) - collect_masks += masks + autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16) + with autocast: + cast_network(self.network, dtype) + for image_batch in batch_generator(images, self.batch_size): + images = thread_pool_processing( + lambda x: convert_image(load_image(x)), image_batch + ) + batches = 
thread_pool_processing(self.data_preprocessing, images) + with torch.no_grad(): + masks = [ + self.network(i.to(self.device).unsqueeze(0))["out"][0] + .argmax(0) + .byte() + .cpu() + for i in batches + ] + del batches + masks = thread_pool_processing( + lambda x: self.data_postprocessing(masks[x], images[x]), + range(len(images)), + ) + collect_masks += masks return collect_masks diff --git a/carvekit/ml/wrap/fba_matting.py b/carvekit/ml/wrap/fba_matting.py index efac07c..c285df0 100644 --- a/carvekit/ml/wrap/fba_matting.py +++ b/carvekit/ml/wrap/fba_matting.py @@ -13,9 +13,13 @@ from PIL import Image from carvekit.ml.arch.fba_matting.models import FBA -from carvekit.ml.arch.fba_matting.transforms import trimap_transform, groupnorm_normalise_image +from carvekit.ml.arch.fba_matting.transforms import ( + trimap_transform, + groupnorm_normalise_image, +) from carvekit.ml.files.models_loc import fba_pretrained from carvekit.utils.image_utils import convert_image, load_image +from carvekit.utils.models_utils import get_precision_autocast, cast_network from carvekit.utils.pool_utils import batch_generator, thread_pool_processing __all__ = ["FBAMatting"] @@ -26,23 +30,29 @@ class FBAMatting(FBA): FBA Matting Neural Network to improve edges on image. 
""" - def __init__(self, device='cpu', - input_tensor_size: Union[List[int], int] = 2048, - batch_size: int = 2, - encoder="resnet50_GN_WS", - load_pretrained: bool = True): + def __init__( + self, + device="cpu", + input_tensor_size: Union[List[int], int] = 2048, + batch_size: int = 2, + encoder="resnet50_GN_WS", + load_pretrained: bool = True, + fp16: bool = False, + ): """ - Initialize the FBAMatting model + Initialize the FBAMatting model - Args: - device: processing device - input_tensor_size: input image size - batch_size: the number of images that the neural network processes in one run - encoder: neural network encoder head - load_pretrained: loading pretrained model + Args: + device: processing device + input_tensor_size: input image size + batch_size: the number of images that the neural network processes in one run + encoder: neural network encoder head + load_pretrained: loading pretrained model + fp16: use half precision """ super(FBAMatting, self).__init__(encoder=encoder) + self.fp16 = fp16 self.device = device self.batch_size = batch_size if isinstance(input_tensor_size, list): @@ -54,16 +64,17 @@ def __init__(self, device='cpu', self.load_state_dict(torch.load(fba_pretrained(), map_location=self.device)) self.eval() - def data_preprocessing(self, data: Union[PIL.Image.Image, np.ndarray]) -> Tuple[torch.FloatTensor, - torch.FloatTensor]: + def data_preprocessing( + self, data: Union[PIL.Image.Image, np.ndarray] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: """ - Transform input image to suitable data format for neural network + Transform input image to suitable data format for neural network - Args: - data: input image + Args: + data: input image - Returns: - input for neural network + Returns: + input for neural network """ resized = data.copy() @@ -74,7 +85,7 @@ def data_preprocessing(self, data: Union[PIL.Image.Image, np.ndarray]) -> Tuple[ # noinspection PyTypeChecker image = np.array(resized, dtype=np.float64) image = image / 255.0 # 
Normalize image to [0, 1] values range - if resized.mode == 'RGB': + if resized.mode == "RGB": image = image[:, :, ::-1] elif resized.mode == "L": image2 = np.copy(image) @@ -83,38 +94,44 @@ def data_preprocessing(self, data: Union[PIL.Image.Image, np.ndarray]) -> Tuple[ image[image2 == 1, 1] = 1 image[image2 == 0, 0] = 1 else: - raise ValueError('Incorrect color mode for image') + raise ValueError("Incorrect color mode for image") h, w = image.shape[:2] # Scale input mlt to 8 h1 = int(np.ceil(1.0 * h / 8) * 8) w1 = int(np.ceil(1.0 * w / 8) * 8) x_scale = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_LANCZOS4) image_tensor = torch.from_numpy(x_scale).permute(2, 0, 1)[None, :, :, :].float() if resized.mode == "RGB": - return image_tensor, groupnorm_normalise_image(image_tensor.clone(), format='nchw') + return image_tensor, groupnorm_normalise_image( + image_tensor.clone(), format="nchw" + ) else: - return image_tensor, torch.from_numpy(trimap_transform(x_scale)).permute(2, 0, 1)[None, :, :, :].float() + return ( + image_tensor, + torch.from_numpy(trimap_transform(x_scale)) + .permute(2, 0, 1)[None, :, :, :] + .float(), + ) @staticmethod - def data_postprocessing(data: torch.tensor, - trimap: PIL.Image.Image) -> PIL.Image.Image: + def data_postprocessing( + data: torch.tensor, trimap: PIL.Image.Image + ) -> PIL.Image.Image: """ - Transforms output data from neural network to suitable data - format for using with other components of this framework. + Transforms output data from neural network to suitable data + format for using with other components of this framework. 
- Args: - data: output data from neural network - trimap: Map with the area we need to refine + Args: + data: output data from neural network + trimap: Map with the area we need to refine - Returns: - Segmentation mask as PIL Image instance + Returns: + Segmentation mask as PIL Image instance """ if trimap.mode != "L": raise ValueError("Incorrect color mode for trimap") pred = data.numpy().transpose((1, 2, 0)) - pred = cv2.resize(pred, - trimap.size, - cv2.INTER_LANCZOS4)[:, :, 0] + pred = cv2.resize(pred, trimap.size, cv2.INTER_LANCZOS4)[:, :, 0] # noinspection PyTypeChecker # Clean mask by removing all false predictions outside trimap and already known area trimap_arr = np.array(trimap.copy()) @@ -123,53 +140,85 @@ def data_postprocessing(data: torch.tensor, pred[pred < 0.3] = 0 return Image.fromarray(pred * 255).convert("L") - def __call__(self, - images: List[Union[str, pathlib.Path, PIL.Image.Image]], - trimaps: List[Union[str, pathlib.Path, PIL.Image.Image]]) -> List[PIL.Image.Image]: + def __call__( + self, + images: List[Union[str, pathlib.Path, PIL.Image.Image]], + trimaps: List[Union[str, pathlib.Path, PIL.Image.Image]], + ) -> List[PIL.Image.Image]: """ - Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances + Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances - Args: - images: input images - trimaps: Maps with the areas we need to refine + Args: + images: input images + trimaps: Maps with the areas we need to refine - Returns: - segmentation masks as for input images, as PIL.Image.Image instances + Returns: + segmentation masks as for input images, as PIL.Image.Image instances """ if len(images) != len(trimaps): - raise ValueError("Len of specified arrays of images and trimaps should be equal!") + raise ValueError( + "Len of specified arrays of images and trimaps should be equal!" 
+ ) collect_masks = [] - for idx_batch in batch_generator(range(len(images)), self.batch_size): - inpt_images = thread_pool_processing(lambda x: convert_image(load_image(images[x])), - idx_batch) - - inpt_trimaps = thread_pool_processing(lambda x: convert_image(load_image(trimaps[x]), mode="L"), - idx_batch) - - inpt_img_batches = thread_pool_processing(self.data_preprocessing, inpt_images) - inpt_trimaps_batches = thread_pool_processing(self.data_preprocessing, inpt_trimaps) - - inpt_img_batches_transformed = torch.vstack([i[1] for i in inpt_img_batches]) - inpt_img_batches = torch.vstack([i[0] for i in inpt_img_batches]) - - inpt_trimaps_transformed = torch.vstack([i[1] for i in inpt_trimaps_batches]) - inpt_trimaps_batches = torch.vstack([i[0] for i in inpt_trimaps_batches]) - - with torch.no_grad(): - inpt_img_batches = inpt_img_batches.to(self.device) - inpt_trimaps_batches = inpt_trimaps_batches.to(self.device) - inpt_img_batches_transformed = inpt_img_batches_transformed.to(self.device) - inpt_trimaps_transformed = inpt_trimaps_transformed.to(self.device) - - output = super(FBAMatting, self).__call__(inpt_img_batches, inpt_trimaps_batches, - inpt_img_batches_transformed, inpt_trimaps_transformed) - output_cpu = output.cpu() - del inpt_img_batches, inpt_trimaps_batches, \ - inpt_img_batches_transformed, inpt_trimaps_transformed, output - masks = thread_pool_processing(lambda x: self.data_postprocessing(output_cpu[x], inpt_trimaps[x]), - range(len(inpt_images))) - collect_masks += masks - return collect_masks + autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16) + with autocast: + cast_network(self, dtype) + for idx_batch in batch_generator(range(len(images)), self.batch_size): + inpt_images = thread_pool_processing( + lambda x: convert_image(load_image(images[x])), idx_batch + ) + + inpt_trimaps = thread_pool_processing( + lambda x: convert_image(load_image(trimaps[x]), mode="L"), idx_batch + ) + + inpt_img_batches = 
thread_pool_processing( + self.data_preprocessing, inpt_images + ) + inpt_trimaps_batches = thread_pool_processing( + self.data_preprocessing, inpt_trimaps + ) + + inpt_img_batches_transformed = torch.vstack( + [i[1] for i in inpt_img_batches] + ) + inpt_img_batches = torch.vstack([i[0] for i in inpt_img_batches]) + + inpt_trimaps_transformed = torch.vstack( + [i[1] for i in inpt_trimaps_batches] + ) + inpt_trimaps_batches = torch.vstack( + [i[0] for i in inpt_trimaps_batches] + ) + + with torch.no_grad(): + inpt_img_batches = inpt_img_batches.to(self.device) + inpt_trimaps_batches = inpt_trimaps_batches.to(self.device) + inpt_img_batches_transformed = inpt_img_batches_transformed.to( + self.device + ) + inpt_trimaps_transformed = inpt_trimaps_transformed.to(self.device) + + output = super(FBAMatting, self).__call__( + inpt_img_batches, + inpt_trimaps_batches, + inpt_img_batches_transformed, + inpt_trimaps_transformed, + ) + output_cpu = output.cpu() + del ( + inpt_img_batches, + inpt_trimaps_batches, + inpt_img_batches_transformed, + inpt_trimaps_transformed, + output, + ) + masks = thread_pool_processing( + lambda x: self.data_postprocessing(output_cpu[x], inpt_trimaps[x]), + range(len(inpt_images)), + ) + collect_masks += masks + return collect_masks diff --git a/carvekit/ml/wrap/tracer_b7.py b/carvekit/ml/wrap/tracer_b7.py new file mode 100644 index 0000000..20a8e45 --- /dev/null +++ b/carvekit/ml/wrap/tracer_b7.py @@ -0,0 +1,178 @@ +""" +Source url: https://github.com/OPHoperHPO/image-background-remove-tool +Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
+License: Apache License 2.0 +""" +import pathlib +import warnings +from typing import List, Union +import PIL.Image +import numpy as np +import torch +import torchvision.transforms as transforms +from PIL import Image + +from carvekit.ml.arch.tracerb7.tracer import TracerDecoder +from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7 +from carvekit.ml.files.models_loc import tracer_b7_pretrained, tracer_hair_pretrained +from carvekit.utils.models_utils import get_precision_autocast, cast_network +from carvekit.utils.image_utils import load_image, convert_image +from carvekit.utils.pool_utils import thread_pool_processing, batch_generator + +__all__ = ["TracerUniversalB7"] + + +class TracerUniversalB7(TracerDecoder): + """TRACER B7 model interface""" + + def __init__( + self, + device="cpu", + input_image_size: Union[List[int], int] = 640, + batch_size: int = 4, + load_pretrained: bool = True, + fp16: bool = False, + model_path: Union[str, pathlib.Path] = None, + ): + """ + Initialize the U2NET model + + Args: + layers_cfg: neural network layers configuration + device: processing device + input_image_size: input image size + batch_size: the number of images that the neural network processes in one run + load_pretrained: loading pretrained model + fp16: use fp16 precision + + """ + if model_path is None: + model_path = tracer_b7_pretrained() + super(TracerUniversalB7, self).__init__( + encoder=EfficientEncoderB7(), + rfb_channel=[32, 64, 128], + features_channels=[48, 80, 224, 640], + ) + + self.fp16 = fp16 + self.device = device + self.batch_size = batch_size + if isinstance(input_image_size, list): + self.input_image_size = input_image_size[:2] + else: + self.input_image_size = (input_image_size, input_image_size) + + self.transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Resize(self.input_image_size), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + self.to(device) + if load_pretrained: + # 
TODO remove edge detector from weights. It doesn't work well with this model! + self.load_state_dict( + torch.load(model_path, map_location=self.device), strict=False + ) + self.eval() + + def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor: + """ + Transform input image to suitable data format for neural network + + Args: + data: input image + + Returns: + input for neural network + + """ + + return torch.unsqueeze(self.transform(data), 0).type(torch.FloatTensor) + + @staticmethod + def data_postprocessing( + data: torch.tensor, original_image: PIL.Image.Image + ) -> PIL.Image.Image: + """ + Transforms output data from neural network to suitable data + format for using with other components of this framework. + + Args: + data: output data from neural network + original_image: input image which was used for predicted data + + Returns: + Segmentation mask as PIL Image instance + + """ + output = (data.type(torch.FloatTensor).detach().cpu().numpy() * 255.0).astype( + np.uint8 + ) + output = output.squeeze(0) + mask = Image.fromarray(output).convert("L") + mask = mask.resize(original_image.size, resample=Image.BILINEAR) + return mask + + def __call__( + self, images: List[Union[str, pathlib.Path, PIL.Image.Image]] + ) -> List[PIL.Image.Image]: + """ + Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances + + Args: + images: input images + + Returns: + segmentation masks as for input images, as PIL.Image.Image instances + + """ + collect_masks = [] + autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16) + with autocast: + cast_network(self, dtype) + for image_batch in batch_generator(images, self.batch_size): + images = thread_pool_processing( + lambda x: convert_image(load_image(x)), image_batch + ) + batches = torch.vstack( + thread_pool_processing(self.data_preprocessing, images) + ) + with torch.no_grad(): + batches = batches.to(self.device) + masks = super(TracerDecoder, 
self).__call__(batches) + masks_cpu = masks.cpu() + del batches, masks + masks = thread_pool_processing( + lambda x: self.data_postprocessing(masks_cpu[x], images[x]), + range(len(images)), + ) + collect_masks += masks + + return collect_masks + + +class TracerHair(TracerUniversalB7): + """TRACER HAIR model interface""" + + def __init__( + self, + device="cpu", + input_image_size: Union[List[int], int] = 640, + batch_size: int = 4, + load_pretrained: bool = True, + fp16: bool = False, + model_path: Union[str, pathlib.Path] = None, + ): + if model_path is None: + model_path = tracer_hair_pretrained() + warnings.warn("TracerHair has not public model yet. Don't use it!", UserWarning) + super(TracerHair, self).__init__( + device=device, + input_image_size=input_image_size, + batch_size=batch_size, + load_pretrained=load_pretrained, + fp16=fp16, + model_path=model_path, + ) diff --git a/carvekit/ml/wrap/u2net.py b/carvekit/ml/wrap/u2net.py index 55b8f7f..7d126df 100644 --- a/carvekit/ml/wrap/u2net.py +++ b/carvekit/ml/wrap/u2net.py @@ -21,21 +21,25 @@ class U2NET(U2NETArchitecture): """U^2-Net model interface""" - def __init__(self, - layers_cfg="full", - device='cpu', - input_image_size: Union[List[int], int] = 320, - batch_size: int = 10, - load_pretrained: bool = True): + def __init__( + self, + layers_cfg="full", + device="cpu", + input_image_size: Union[List[int], int] = 320, + batch_size: int = 10, + load_pretrained: bool = True, + fp16: bool = False, + ): """ - Initialize the U2NET model + Initialize the U2NET model - Args: - layers_cfg: neural network layers configuration - device: processing device - input_image_size: input image size - batch_size: the number of images that the neural network processes in one run - load_pretrained: loading pretrained model + Args: + layers_cfg: neural network layers configuration + device: processing device + input_image_size: input image size + batch_size: the number of images that the neural network processes in one run + 
load_pretrained: loading pretrained model + fp16: use fp16 precision // not supported at this moment. """ super(U2NET, self).__init__(cfg_type=layers_cfg, out_ch=1) @@ -47,18 +51,20 @@ def __init__(self, self.input_image_size = (input_image_size, input_image_size) self.to(device) if load_pretrained: - self.load_state_dict(torch.load(u2net_full_pretrained(), map_location=self.device)) + self.load_state_dict( + torch.load(u2net_full_pretrained(), map_location=self.device) + ) self.eval() def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor: """ - Transform input image to suitable data format for neural network + Transform input image to suitable data format for neural network - Args: - data: input image + Args: + data: input image - Returns: - input for neural network + Returns: + input for neural network """ resized = data.resize(self.input_image_size, resample=3) @@ -75,18 +81,19 @@ def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor: return torch.from_numpy(temp_image).type(torch.FloatTensor) @staticmethod - def data_postprocessing(data: torch.tensor, - original_image: PIL.Image.Image) -> PIL.Image.Image: + def data_postprocessing( + data: torch.tensor, original_image: PIL.Image.Image + ) -> PIL.Image.Image: """ - Transforms output data from neural network to suitable data - format for using with other components of this framework. + Transforms output data from neural network to suitable data + format for using with other components of this framework. 
- Args: - data: output data from neural network - original_image: input image which was used for predicted data + Args: + data: output data from neural network + original_image: input image which was used for predicted data - Returns: - Segmentation mask as PIL Image instance + Returns: + Segmentation mask as PIL Image instance """ data = data.unsqueeze(0) @@ -99,27 +106,35 @@ def data_postprocessing(data: torch.tensor, mask = mask.resize(original_image.size, resample=3) return mask - def __call__(self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]) -> List[PIL.Image.Image]: + def __call__( + self, images: List[Union[str, pathlib.Path, PIL.Image.Image]] + ) -> List[PIL.Image.Image]: """ - Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances + Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances - Args: - images: input images + Args: + images: input images - Returns: - segmentation masks as for input images, as PIL.Image.Image instances + Returns: + segmentation masks as for input images, as PIL.Image.Image instances """ collect_masks = [] for image_batch in batch_generator(images, self.batch_size): - images = thread_pool_processing(lambda x: convert_image(load_image(x)), image_batch) - batches = torch.vstack(thread_pool_processing(self.data_preprocessing, images)) + images = thread_pool_processing( + lambda x: convert_image(load_image(x)), image_batch + ) + batches = torch.vstack( + thread_pool_processing(self.data_preprocessing, images) + ) with torch.no_grad(): batches = batches.to(self.device) masks, d2, d3, d4, d5, d6, d7 = super(U2NET, self).__call__(batches) masks_cpu = masks.cpu() del d2, d3, d4, d5, d6, d7, batches, masks - masks = thread_pool_processing(lambda x: self.data_postprocessing(masks_cpu[x], images[x]), - range(len(images))) + masks = thread_pool_processing( + lambda x: self.data_postprocessing(masks_cpu[x], images[x]), + 
range(len(images)), + ) collect_masks += masks return collect_masks diff --git a/carvekit/pipelines/postprocessing.py b/carvekit/pipelines/postprocessing.py index 1623732..fc22451 100644 --- a/carvekit/pipelines/postprocessing.py +++ b/carvekit/pipelines/postprocessing.py @@ -22,9 +22,12 @@ class MattingMethod: Neural network for matting performs accurate object edge detection by using a special map called trimap, with unknown area that we scan for boundary, already known general object area and the background.""" - def __init__(self, matting_module: Union[FBAMatting], - trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator], - device="cpu"): + def __init__( + self, + matting_module: Union[FBAMatting], + trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator], + device="cpu", + ): """ Initializes Matting Method class. @@ -37,8 +40,11 @@ def __init__(self, matting_module: Union[FBAMatting], self.matting_module = matting_module self.trimap_generator = trimap_generator - def __call__(self, images: List[Union[str, Path, Image.Image]], - masks: List[Union[str, Path, Image.Image]]): + def __call__( + self, + images: List[Union[str, Path, Image.Image]], + masks: List[Union[str, Path, Image.Image]], + ): """ Passes data through apply_mask function @@ -52,9 +58,19 @@ def __call__(self, images: List[Union[str, Path, Image.Image]], if len(images) != len(masks): raise ValueError("Images and Masks lists should have same length!") images = thread_pool_processing(lambda x: convert_image(load_image(x)), images) - masks = thread_pool_processing(lambda x: convert_image(load_image(x), mode="L"), masks) - trimaps = thread_pool_processing(lambda x: self.trimap_generator(original_image=images[x], - mask=masks[x]), range(len(images))) + masks = thread_pool_processing( + lambda x: convert_image(load_image(x), mode="L"), masks + ) + trimaps = thread_pool_processing( + lambda x: self.trimap_generator(original_image=images[x], mask=masks[x]), + range(len(images)), + ) alpha = 
self.matting_module(images=images, trimaps=trimaps) - return list(map(lambda x: apply_mask(image=images[x], mask=alpha[x], device=self.device), - range(len(images)))) + return list( + map( + lambda x: apply_mask( + image=images[x], mask=alpha[x], device=self.device + ), + range(len(images)), + ) + ) diff --git a/carvekit/pipelines/preprocessing.py b/carvekit/pipelines/preprocessing.py index 678f13f..3d1e848 100644 --- a/carvekit/pipelines/preprocessing.py +++ b/carvekit/pipelines/preprocessing.py @@ -11,8 +11,6 @@ __all__ = ["PreprocessingStub"] - - class PreprocessingStub: """Stub for future preprocessing methods""" diff --git a/carvekit/trimap/add_ops.py b/carvekit/trimap/add_ops.py index 01f91d4..dfb37ca 100644 --- a/carvekit/trimap/add_ops.py +++ b/carvekit/trimap/add_ops.py @@ -3,6 +3,7 @@ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. License: Apache License 2.0 """ +import cv2 import numpy as np from PIL import Image @@ -30,7 +31,9 @@ def prob_filter(mask: Image.Image, prob_threshold=231) -> Image.Image: return Image.fromarray(mask_array).convert("L") -def prob_as_unknown_area(trimap: Image.Image, mask: Image.Image, prob_threshold=255) -> Image.Image: +def prob_as_unknown_area( + trimap: Image.Image, mask: Image.Image, prob_threshold=255 +) -> Image.Image: """ Marks any uncertainty in the seg mask as an unknown region. @@ -53,3 +56,36 @@ def prob_as_unknown_area(trimap: Image.Image, mask: Image.Image, prob_threshold= trimap_array = np.array(trimap) trimap_array[np.logical_and(mask_array <= prob_threshold, mask_array > 0)] = 127 return Image.fromarray(trimap_array).convert("L") + + +def post_erosion(trimap: Image.Image, erosion_iters=1) -> Image.Image: + """ + Performs erosion on the mask and marks the resulting area as an unknown region. + + Args: + erosion_iters: The number of iterations of erosion that + the object's mask will be subjected to before forming an unknown area + trimap: Generated trimap. 
+ mask: Predicted object mask + + Returns: + Generated trimap for image. + """ + if trimap.mode != "L": + raise ValueError("Input mask has wrong color mode.") + # noinspection PyTypeChecker + trimap_array = np.array(trimap) + if erosion_iters > 0: + without_unknown_area = trimap_array.copy() + without_unknown_area[without_unknown_area == 127] = 0 + + erosion_kernel = np.ones((3, 3), np.uint8) + erode = cv2.erode( + without_unknown_area, erosion_kernel, iterations=erosion_iters + ) + erode = np.where(erode == 0, 0, without_unknown_area) + trimap_array[np.logical_and(erode == 0, without_unknown_area > 0)] = 127 + erode = trimap_array.copy() + else: + erode = trimap_array.copy() + return Image.fromarray(erode).convert("L") diff --git a/carvekit/trimap/cv_gen.py b/carvekit/trimap/cv_gen.py index 8b049ba..fc2c229 100644 --- a/carvekit/trimap/cv_gen.py +++ b/carvekit/trimap/cv_gen.py @@ -22,7 +22,9 @@ def __init__(self, kernel_size: int = 30, erosion_iters: int = 1): self.kernel_size = kernel_size self.erosion_iters = erosion_iters - def __call__(self, original_image: PIL.Image.Image, mask: PIL.Image.Image) -> PIL.Image.Image: + def __call__( + self, original_image: PIL.Image.Image, mask: PIL.Image.Image + ) -> PIL.Image.Image: """ Generates trimap based on predicted object mask to refine object mask borders. Based on cv2 erosion algorithm. 
@@ -46,7 +48,7 @@ def __call__(self, original_image: PIL.Image.Image, mask: PIL.Image.Image) -> PI if self.erosion_iters > 0: erosion_kernel = np.ones((3, 3), np.uint8) erode = cv2.erode(mask_array, erosion_kernel, iterations=self.erosion_iters) - erode = np.where(erode > 0, 255, mask_array) + erode = np.where(erode == 0, 0, mask_array) else: erode = mask_array.copy() diff --git a/carvekit/trimap/generator.py b/carvekit/trimap/generator.py index 21f827d..0656f45 100644 --- a/carvekit/trimap/generator.py +++ b/carvekit/trimap/generator.py @@ -5,11 +5,13 @@ """ from PIL import Image from carvekit.trimap.cv_gen import CV2TrimapGenerator -from carvekit.trimap.add_ops import prob_filter, prob_as_unknown_area +from carvekit.trimap.add_ops import prob_filter, prob_as_unknown_area, post_erosion class TrimapGenerator(CV2TrimapGenerator): - def __init__(self, prob_threshold: float = 231, kernel_size: int = 30, erosion_iters: int = 5): + def __init__( + self, prob_threshold: int = 231, kernel_size: int = 30, erosion_iters: int = 5 + ): """ Initialize a TrimapGenerator instance @@ -21,8 +23,9 @@ def __init__(self, prob_threshold: float = 231, kernel_size: int = 30, erosion_i erosion_iters: The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area """ - super().__init__(kernel_size, erosion_iters) + super().__init__(kernel_size, erosion_iters=0) self.prob_threshold = prob_threshold + self.__erosion_iters = erosion_iters def __call__(self, original_image: Image.Image, mask: Image.Image) -> Image.Image: """ @@ -37,6 +40,8 @@ def __call__(self, original_image: Image.Image, mask: Image.Image) -> Image.Imag """ filter_mask = prob_filter(mask=mask, prob_threshold=self.prob_threshold) trimap = super(TrimapGenerator, self).__call__(original_image, filter_mask) - new_trimap = prob_as_unknown_area(trimap=trimap, mask=mask, prob_threshold=self.prob_threshold) - + new_trimap = prob_as_unknown_area( + trimap=trimap, mask=mask, 
prob_threshold=self.prob_threshold + ) + new_trimap = post_erosion(new_trimap, self.__erosion_iters) return new_trimap diff --git a/carvekit/utils/download_models.py b/carvekit/utils/download_models.py index 1c25b32..c1d7705 100644 --- a/carvekit/utils/download_models.py +++ b/carvekit/utils/download_models.py @@ -6,72 +6,68 @@ import hashlib import os import warnings +from abc import ABCMeta, abstractmethod, ABC from pathlib import Path +from typing import Optional + +import carvekit +from carvekit.ml.files import checkpoints_dir import requests import tqdm +requests = requests.Session() +requests.headers.update({"User-Agent": f"Carvekit/{carvekit.version}"}) + MODELS_URLS = { - "basnet.pth": - "https://huggingface.co/anodev/basnet-universal/resolve/870becbdb364fda6d8fdb2c10b072542f8d08701/basnet.pth", - "deeplab.pth": - "https://huggingface.co/anodev/deeplabv3-resnet101/resolve/d504005392fc877565afdf58aad0cd524682d2b0/deeplab.pth", - "fba_matting.pth": - "https://huggingface.co/anodev/fba/resolve/a5d3457df0fb9c88ea19ed700d409756ca2069d1/fba_matting.pth", - "u2net.pth": - "https://huggingface.co/anodev/u2net-universal/resolve/10305d785481cf4b2eee1d447c39cd6e5f43d74b/full_weights" - ".pth", + "basnet.pth": { + "repository": "Carve/basnet-universal", + "revision": "870becbdb364fda6d8fdb2c10b072542f8d08701", + "filename": "basnet.pth", + }, + "deeplab.pth": { + "repository": "Carve/deeplabv3-resnet101", + "revision": "d504005392fc877565afdf58aad0cd524682d2b0", + "filename": "deeplab.pth", + }, + "fba_matting.pth": { + "repository": "Carve/fba", + "revision": "a5d3457df0fb9c88ea19ed700d409756ca2069d1", + "filename": "fba_matting.pth", + }, + "u2net.pth": { + "repository": "Carve/u2net-universal", + "revision": "10305d785481cf4b2eee1d447c39cd6e5f43d74b", + "filename": "full_weights.pth", + }, + "tracer_b7.pth": { + "repository": "Carve/tracer_b7", + "revision": "d8a8fd9e7b3fa0d2f1506fe7242966b34381e9c5", + "filename": "tracer_b7.pth", + }, + "tracer_hair.pth": { + 
"repository": "Carve/tracer_b7", + "revision": "d8a8fd9e7b3fa0d2f1506fe7242966b34381e9c5", + "filename": "tracer_b7.pth", # TODO don't forget change this link!! + }, } MODELS_CHECKSUMS = { "basnet.pth": "e409cb709f4abca87cb11bd44a9ad3f909044a917977ab65244b4c94dd33" - "8b1a37755c4253d7cb54526b7763622a094d7b676d34b5e6886689256754e5a5e6ad", - "deeplab.pth": - "9c5a1795bc8baa267200a44b49ac544a1ba2687d210f63777e4bd715387324469a59b072f8a28" - "9cc471c637b367932177e5b312e8ea6351c1763d9ff44b4857c", - "fba_matting.pth": - "890906ec94c1bfd2ad08707a63e4ccb0955d7f5d25e32853950c24c78" - "4cbad2e59be277999defc3754905d0f15aa75702cdead3cfe669ff72f08811c52971613", - "u2net.pth": - "16f8125e2fedd8c85db0e001ee15338b4aa2fda77bab8ba70c25e" - "bea1533fda5ee70a909b934a9bd495b432cef89d629f00a07858a517742476fa8b346de24f7", - + "8b1a37755c4253d7cb54526b7763622a094d7b676d34b5e6886689256754e5a5e6ad", + "deeplab.pth": "9c5a1795bc8baa267200a44b49ac544a1ba2687d210f63777e4bd715387324469a59b072f8a28" + "9cc471c637b367932177e5b312e8ea6351c1763d9ff44b4857c", + "fba_matting.pth": "890906ec94c1bfd2ad08707a63e4ccb0955d7f5d25e32853950c24c78" + "4cbad2e59be277999defc3754905d0f15aa75702cdead3cfe669ff72f08811c52971613", + "u2net.pth": "16f8125e2fedd8c85db0e001ee15338b4aa2fda77bab8ba70c25e" + "bea1533fda5ee70a909b934a9bd495b432cef89d629f00a07858a517742476fa8b346de24f7", + "tracer_b7.pth": "c439c5c12d4d43d5f9be9ec61e68b2e54658a541bccac2577ef5a54fb252b6e8415d41f7e" + "c2487033d0c02b4dd08367958e4e62091318111c519f93e2632be7b", + "tracer_hair.pth": "5c2fb9973fc42fa6208920ffa9ac233cc2ea9f770b24b4a96969d3449aed7ac89e6d37e" + "e486a13e63be5499f2df6ccef1109e9e8797d1326207ac89b2f39a7cf", } -def download_model(path: Path) -> Path: - """ Downloads model from repo. - - Args: - path (pathlib.Path): Path to file - - Returns: - Path if exists - - Raises: - FileNotFoundError: if model checkpoint is not exists in known checkpoints models - ConnectionError: if the model cannot be loaded from the URL. 
- """ - if path.name in MODELS_URLS: - model_url = MODELS_URLS[path.name] - path.parent.mkdir(parents=True, exist_ok=True) - try: - r = requests.get(model_url, stream=True) - if r.status_code == 200: - with path.absolute().open('wb') as f: - r.raw.decode_content = True - for chunk in tqdm.tqdm(r, desc="Downloading " + path.name + ' model', colour='blue'): - f.write(chunk) - except BaseException as e: - if path.exists(): - os.remove(path) - raise ConnectionError(f"Exception caused when downloading model! " - f"Model name: {path.name}. Exception: {str(e)}") - return path - else: - raise FileNotFoundError("Unknown model!") - - def sha512_checksum_calc(file: Path) -> str: """ Calculates the SHA512 hash digest of a file on fs @@ -89,45 +85,117 @@ def sha512_checksum_calc(file: Path) -> str: return dd.hexdigest() -def check_model(path: Path) -> bool: - """ Verifies model checksums and existence in the file system +class CachedDownloader: + __metaclass__ = ABCMeta - Args: - path: Path to the model + @property + @abstractmethod + def fallback_downloader(self) -> Optional["CachedDownloader"]: + pass - Returns: - True if all is okay and False if not - - Raises: - FileNotFoundError: if model checkpoint is not exists in known checkpoints models - """ - if path.exists(): - if path.name in MODELS_URLS: - if MODELS_CHECKSUMS[path.name] != sha512_checksum_calc(path): - warnings.warn(f"Invalid checksum for model {path.name}. Downloading correct model!") - os.remove(path) - return False - return True - else: + def download_model(self, file_name: str) -> Path: + try: + return self.download_model_base(file_name) + except BaseException as e: + if self.fallback_downloader is not None: + warnings.warn( + f"Failed to download model from {self.__class__.__name__} downloader." + f" Trying to download from {self.fallback_downloader.__class__.__name__} downloader." 
+ ) + return self.fallback_downloader.download_model(file_name) + else: + warnings.warn( + f"Failed to download model from {self.__class__.__name__} downloader." + f" No fallback downloader available." + ) + raise e + + @abstractmethod + def download_model_base(self, file_name: str) -> Path: + """Download model from any source if not cached. Returns path if cached""" + + def __call__(self, file_name: str): + return self.download_model(file_name) + + +class HuggingFaceCompatibleDownloader(CachedDownloader, ABC): + def __init__( + self, + base_url: str = "https://huggingface.co", + fb_downloader: Optional["CachedDownloader"] = None, + ): + self.cache_dir = checkpoints_dir + self.base_url = base_url + self._fallback_downloader = fb_downloader + + @property + def fallback_downloader(self) -> Optional["CachedDownloader"]: + return self._fallback_downloader + + def check_for_existence(self, file_name: str) -> Optional[Path]: + if file_name not in MODELS_URLS.keys(): raise FileNotFoundError("Unknown model!") - else: - return False - - -def check_for_exists(path: Path) -> Path: - """ Checks for checkpoint path exists - - Args: - path (pathlib.Path): Path to file - - Returns: - Path if exists - - Raises: - FileNotFoundError: if model checkpoint is not exists in known checkpoints models - ConnectionError: if the model cannot be loaded from the URL. - """ - if not check_model(path): - download_model(path) + path = ( + self.cache_dir + / MODELS_URLS[file_name]["repository"].split("/")[1] + / file_name + ) + + if not path.exists(): + return None + + if MODELS_CHECKSUMS[path.name] != sha512_checksum_calc(path): + warnings.warn( + f"Invalid checksum for model {path.name}. Downloading correct model!" 
+ ) + os.remove(path) + return None + return path - return path + def download_model_base(self, file_name: str) -> Path: + cached_path = self.check_for_existence(file_name) + if cached_path is not None: + return cached_path + else: + cached_path = ( + self.cache_dir + / MODELS_URLS[file_name]["repository"].split("/")[1] + / file_name + ) + cached_path.parent.mkdir(parents=True, exist_ok=True) + url = MODELS_URLS[file_name] + hugging_face_url = f"{self.base_url}/{url['repository']}/resolve/{url['revision']}/{url['filename']}" + + try: + r = requests.get(hugging_face_url, stream=True) + if r.status_code < 400: + with open(cached_path, "wb") as f: + r.raw.decode_content = True + for chunk in tqdm.tqdm( + r, + desc="Downloading " + cached_path.name + " model", + colour="blue", + ): + f.write(chunk) + else: + if r.status_code == 404: + raise FileNotFoundError(f"Model {file_name} not found!") + else: + raise ConnectionError( + f"Error {r.status_code} while downloading model {file_name}!" + ) + except BaseException as e: + if cached_path.exists(): + os.remove(cached_path) + raise ConnectionError( + f"Exception caught when downloading model! " + f"Model name: {cached_path.name}. Exception: {str(e)}." 
+ ) + return cached_path + + +downloader: CachedDownloader = HuggingFaceCompatibleDownloader( + base_url="https://cdn.carve.photos" +) +fallback_downloader: CachedDownloader = HuggingFaceCompatibleDownloader() +downloader._fallback_downloader = fallback_downloader diff --git a/carvekit/utils/fs_utils.py b/carvekit/utils/fs_utils.py index 0a22c5d..bd6291e 100644 --- a/carvekit/utils/fs_utils.py +++ b/carvekit/utils/fs_utils.py @@ -20,13 +20,19 @@ def save_file(output: Optional[Path], input_path: Path, image: Image.Image): """ if isinstance(output, Path) and str(output) != "none": if output.is_dir() and output.exists(): - image.save(output.joinpath(input_path.with_suffix('.png').name)) - elif output.suffix != '': + image.save(output.joinpath(input_path.with_suffix(".png").name)) + elif output.suffix != "": if output.suffix != ".png": - warnings.warn(f"Only export with .png extension is supported! Your {output.suffix}" - f" extension will be ignored and replaced with .png!") - image.save(output.with_suffix('.png')) + warnings.warn( + f"Only export with .png extension is supported! Your {output.suffix}" + f" extension will be ignored and replaced with .png!" 
+ ) + image.save(output.with_suffix(".png")) else: raise ValueError("Wrong output path!") elif output is None or str(output) == "none": - image.save(input_path.with_name(input_path.stem.split('.')[0] + '_bg_removed').with_suffix('.png')) + image.save( + input_path.with_name( + input_path.stem.split(".")[0] + "_bg_removed" + ).with_suffix(".png") + ) diff --git a/carvekit/utils/image_utils.py b/carvekit/utils/image_utils.py index 88bbb5e..8b939f5 100644 --- a/carvekit/utils/image_utils.py +++ b/carvekit/utils/image_utils.py @@ -28,7 +28,7 @@ def to_tensor(x: Any) -> torch.Tensor: def load_image(file: Union[str, pathlib.Path, PIL.Image.Image]) -> PIL.Image.Image: - """ Returns a PIL.Image.Image class by string path or pathlib path or PIL.Image.Image instance + """Returns a PIL.Image.Image class by string path or pathlib path or PIL.Image.Image instance Args: file: File path or PIL.Image.Image instance @@ -51,7 +51,7 @@ def load_image(file: Union[str, pathlib.Path, PIL.Image.Image]) -> PIL.Image.Ima def convert_image(image: PIL.Image.Image, mode="RGB") -> PIL.Image.Image: - """ Performs image conversion to correct color mode + """Performs image conversion to correct color mode Args: image: PIL.Image.Image instance @@ -86,18 +86,22 @@ def is_image_valid(image: Union[pathlib.Path, PIL.Image.Image]) -> bool: elif image.is_dir(): raise ValueError("File is a directory") elif image.suffix.lower() not in ALLOWED_SUFFIXES: - raise ValueError(f"Unsupported image format. Supported file formats: {', '.join(ALLOWED_SUFFIXES)}") + raise ValueError( + f"Unsupported image format. 
Supported file formats: {', '.join(ALLOWED_SUFFIXES)}" + ) elif isinstance(image, PIL.Image.Image): if not (image.size[0] > 32 and image.size[1] > 32): raise ValueError("Image should be bigger then (32x32) pixels.") elif image.mode not in ["RGB", "RGBA", "L"]: - raise ValueError('Wrong image color mode.') + raise ValueError("Wrong image color mode.") else: raise ValueError("Unknown input file type") return True -def transparency_paste(bg_img: PIL.Image.Image, fg_img: PIL.Image.Image, box=(0, 0)) -> PIL.Image.Image: +def transparency_paste( + bg_img: PIL.Image.Image, fg_img: PIL.Image.Image, box=(0, 0) +) -> PIL.Image.Image: """ Inserts an image into another image while maintaining transparency. @@ -115,9 +119,14 @@ def transparency_paste(bg_img: PIL.Image.Image, fg_img: PIL.Image.Image, box=(0, return new_img -def add_margin(pil_img: PIL.Image.Image, - top: int, right: int, bottom: int, left: int, - color: Tuple[int, int, int, int])->PIL.Image.Image: +def add_margin( + pil_img: PIL.Image.Image, + top: int, + right: int, + bottom: int, + left: int, + color: Tuple[int, int, int, int], +) -> PIL.Image.Image: """ Adds margin to the image. diff --git a/carvekit/utils/mask_utils.py b/carvekit/utils/mask_utils.py index 9802f6d..4402036 100644 --- a/carvekit/utils/mask_utils.py +++ b/carvekit/utils/mask_utils.py @@ -8,10 +8,12 @@ from carvekit.utils.image_utils import to_tensor -def composite(foreground: PIL.Image.Image, - background: PIL.Image.Image, - alpha: PIL.Image.Image, - device="cpu"): +def composite( + foreground: PIL.Image.Image, + background: PIL.Image.Image, + alpha: PIL.Image.Image, + device="cpu", +): """ Composites foreground with background by following https://pymatting.github.io/intro.html#alpha-matting math formula. 
@@ -49,7 +51,9 @@ def composite(foreground: PIL.Image.Image, return PIL.Image.fromarray(bg.cpu().numpy()).convert("RGBA") -def apply_mask(image: PIL.Image.Image, mask: PIL.Image.Image, device="cpu") -> PIL.Image.Image: +def apply_mask( + image: PIL.Image.Image, mask: PIL.Image.Image, device="cpu" +) -> PIL.Image.Image: """ Applies mask to foreground. diff --git a/carvekit/utils/models_utils.py b/carvekit/utils/models_utils.py index 16be87d..da0141d 100644 --- a/carvekit/utils/models_utils.py +++ b/carvekit/utils/models_utils.py @@ -6,7 +6,93 @@ import random import warnings +from typing import Union, Tuple, Any + import torch +from torch import autocast + + +class EmptyAutocast(object): + """ + Empty class for disable any autocasting. + """ + + def __enter__(self): + return None + + def __exit__(self, exc_type, exc_val, exc_tb): + return + + def __call__(self, func): + return + + +def get_precision_autocast( + device="cpu", fp16=True, override_dtype=None +) -> Union[ + Tuple[EmptyAutocast, Union[torch.dtype, Any]], + Tuple[autocast, Union[torch.dtype, Any]], +]: + """ + Returns precision and autocast settings for given device and fp16 settings. + Args: + device: Device to get precision and autocast settings for. + fp16: Whether to use fp16 precision. + override_dtype: Override dtype for autocast. + + Returns: + Autocast object, dtype + """ + dtype = torch.float32 + cache_enabled = None + + if device == "cpu" and fp16: + warnings.warn('FP16 is not supported on CPU. Using FP32 instead.') + dtype = torch.float32 + + # TODO: Implement BFP16 on CPU. There are unexpected slowdowns on cpu on a clean environment. + # warnings.warn( + # "Accuracy BFP16 has experimental support on the CPU. " + # "This may result in an unexpected reduction in quality." 
+ # ) + # dtype = ( + # torch.bfloat16 + # ) # Using bfloat16 for CPU, since autocast is not supported for float16 + + + if "cuda" in device and fp16: + dtype = torch.float16 + cache_enabled = True + + if override_dtype is not None: + dtype = override_dtype + + if dtype == torch.float32 and device == "cpu": + return EmptyAutocast(), dtype + + return ( + torch.autocast( + device_type=device, dtype=dtype, enabled=True, cache_enabled=cache_enabled + ), + dtype, + ) + + +def cast_network(network: torch.nn.Module, dtype: torch.dtype): + """Cast network to given dtype + + Args: + network: Network to be casted + dtype: Dtype to cast network to + """ + if dtype == torch.float16: + network.half() + elif dtype == torch.bfloat16: + network.bfloat16() + elif dtype == torch.float32: + network.float() + else: + raise ValueError(f"Unknown dtype {dtype}") def fix_seed(seed=42): @@ -30,9 +116,11 @@ def fix_seed(seed=42): def suppress_warnings(): # Suppress PyTorch 1.11.0 warning associated with changing order of args in nn.MaxPool2d layer, # since source code is not affected by this issue and there aren't any other correct way to hide this message. 
- warnings.filterwarnings("ignore", - category=UserWarning, - message="Note that order of the arguments: ceil_mode and " - "return_indices will changeto match the args list " - "in nn.MaxPool2d in a future release.", - module="torch") + warnings.filterwarnings( + "ignore", + category=UserWarning, + message="Note that order of the arguments: ceil_mode and " + "return_indices will changeto match the args list " + "in nn.MaxPool2d in a future release.", + module="torch", + ) diff --git a/carvekit/utils/pool_utils.py b/carvekit/utils/pool_utils.py index f9a7d49..ae3b741 100644 --- a/carvekit/utils/pool_utils.py +++ b/carvekit/utils/pool_utils.py @@ -9,15 +9,15 @@ def thread_pool_processing(func: Any, data: Iterable, workers=18): """ - Passes all iterator data through the given function + Passes all iterator data through the given function - Args: - workers: Count of workers. - func: function to pass data through - data: input iterator + Args: + workers: Count of workers. + func: function to pass data through + data: input iterator - Returns: - function return list + Returns: + function return list """ with ThreadPoolExecutor(workers) as p: @@ -26,15 +26,15 @@ def thread_pool_processing(func: Any, data: Iterable, workers=18): def batch_generator(iterable, n=1): """ - Splits any iterable into n-size packets + Splits any iterable into n-size packets - Args: - iterable: iterator - n: size of packets + Args: + iterable: iterator + n: size of packets - Returns: - new n-size packet + Returns: + new n-size packet """ it = len(iterable) for ndx in range(0, it, n): - yield iterable[ndx:min(ndx + n, it)] + yield iterable[ndx : min(ndx + n, it)] diff --git a/carvekit/web/app.py b/carvekit/web/app.py index 24ba1df..cea3526 100644 --- a/carvekit/web/app.py +++ b/carvekit/web/app.py @@ -9,18 +9,22 @@ from carvekit.web.deps import config from carvekit.web.routers.api_router import api_router -app = FastAPI(title='CarveKit Web API', version=version) +app = FastAPI(title="CarveKit Web 
API", version=version) app.add_middleware( CORSMiddleware, - allow_origins=['*'], + allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) app.include_router(api_router, prefix="/api") -app.mount('/', StaticFiles(directory=Path(__file__).parent.joinpath('static'), html=True), name="static") +app.mount( + "/", + StaticFiles(directory=Path(__file__).parent.joinpath("static"), html=True), + name="static", +) if __name__ == "__main__": uvicorn.run(app, host=config.host, port=config.port) diff --git a/carvekit/web/deps.py b/carvekit/web/deps.py index 5802dd6..37a41ae 100644 --- a/carvekit/web/deps.py +++ b/carvekit/web/deps.py @@ -4,4 +4,3 @@ config: WebAPIConfig = init_config() ml_processor = MLProcessor(api_config=config) - diff --git a/carvekit/web/handlers/response.py b/carvekit/web/handlers/response.py index 3f35679..f359b3c 100644 --- a/carvekit/web/handlers/response.py +++ b/carvekit/web/handlers/response.py @@ -26,21 +26,31 @@ def handle_response(response, original_image) -> Response: response_object = None if isinstance(response, dict): if response["type"] == "jpg": - response_object = Response(content=response["data"][0].read(), media_type='image/jpeg') + response_object = Response( + content=response["data"][0].read(), media_type="image/jpeg" + ) elif response["type"] == "png": - response_object = Response(content=response["data"][0].read(), media_type='image/png') + response_object = Response( + content=response["data"][0].read(), media_type="image/png" + ) elif response["type"] == "zip": - response_object = Response(content=response["data"][0], media_type='application/zip') - response_object.headers['Content-Disposition'] = 'attachment; filename=\'no-bg.zip\'' + response_object = Response( + content=response["data"][0], media_type="application/zip" + ) + response_object.headers[ + "Content-Disposition" + ] = "attachment; filename='no-bg.zip'" # Add headers to output result - response_object.headers["X-Credits-Charged"] = 
'0' + response_object.headers["X-Credits-Charged"] = "0" response_object.headers["X-Type"] = "other" # TODO Make support for this response_object.headers["X-Max-Width"] = str(original_image.size[0]) response_object.headers["X-Max-Height"] = str(original_image.size[1]) - response_object.headers["X-Ratelimit-Limit"] = '500' # TODO Make ratelimit support - response_object.headers["X-Ratelimit-Remaining"] = '500' - response_object.headers["X-Ratelimit-Reset"] = '1' + response_object.headers[ + "X-Ratelimit-Limit" + ] = "500" # TODO Make ratelimit support + response_object.headers["X-Ratelimit-Remaining"] = "500" + response_object.headers["X-Ratelimit-Reset"] = "1" response_object.headers["X-Width"] = str(response["data"][1][0]) response_object.headers["X-Height"] = str(response["data"][1][1]) diff --git a/carvekit/web/other/removebg.py b/carvekit/web/other/removebg.py index 0270594..30dca1e 100644 --- a/carvekit/web/other/removebg.py +++ b/carvekit/web/other/removebg.py @@ -11,7 +11,9 @@ from carvekit.api.interface import Interface -def process_remove_bg(interface: Interface, params, image, bg, is_json_or_www_encoded=False): +def process_remove_bg( + interface: Interface, params, image, bg, is_json_or_www_encoded=False +): """ Handles a request to the removebg api method @@ -51,25 +53,45 @@ def process_remove_bg(interface: Interface, params, image, bg, is_json_or_www_en try: coord = int(coord) except BaseException: - return error_dict("Error converting roi coordinate string to number!"), 400 + return ( + error_dict( + "Error converting roi coordinate string to number!" + ), + 400, + ) if coord < 0: - error_dict( - "Bad roi coordinate."), 400 + error_dict("Bad roi coordinate."), 400 if (i == 0 or i == 2) and coord > image.size[0]: - return error_dict( - "The roi coordinate cannot be larger than the image size."), 400 + return ( + error_dict( + "The roi coordinate cannot be larger than the image size." 
+ ), + 400, + ) elif (i == 1 or i == 3) and coord > image.size[1]: - return error_dict( - "The roi coordinate cannot be larger than the image size."), 400 + return ( + error_dict( + "The roi coordinate cannot be larger than the image size." + ), + 400, + ) roi_box[i] = int(coord) elif "%" in coord: coord = coord.replace("%", "") try: coord = int(coord) except BaseException: - return error_dict("Error converting roi coordinate string to number!"), 400 + return ( + error_dict( + "Error converting roi coordinate string to number!" + ), + 400, + ) if coord > 100: - return error_dict("The coordinate cannot be more than 100%"), 400 + return ( + error_dict("The coordinate cannot be more than 100%"), + 400, + ) elif coord < 0: return error_dict("Coordinate cannot be less than 0%"), 400 if i == 0 or i == 2: @@ -88,10 +110,12 @@ def process_remove_bg(interface: Interface, params, image, bg, is_json_or_www_en new_image = interface([new_image])[0] scaled = False - if "scale" in params.keys() and params['scale'] != 100: + if "scale" in params.keys() and params["scale"] != 100: value = params["scale"] - new_image.thumbnail((int(image.size[0] * value / 100), - int(image.size[1] * value / 100)), resample=3) + new_image.thumbnail( + (int(image.size[0] * value / 100), int(image.size[1] * value / 100)), + resample=3, + ) scaled = True if "crop" in params.keys(): value = params["crop"] @@ -103,28 +127,51 @@ def process_remove_bg(interface: Interface, params, image, bg, is_json_or_www_en crop_margin = crop_margin.replace("px", "") crop_margin = abs(int(crop_margin)) if crop_margin > 500: - return error_dict( - "The crop_margin cannot be larger than the original image size."), 400 - new_image = add_margin(new_image, crop_margin, - crop_margin, crop_margin, crop_margin, (0, 0, 0, 0)) + return ( + error_dict( + "The crop_margin cannot be larger than the original image size." 
+ ), + 400, + ) + new_image = add_margin( + new_image, + crop_margin, + crop_margin, + crop_margin, + crop_margin, + (0, 0, 0, 0), + ) elif "%" in crop_margin: crop_margin = crop_margin.replace("%", "") crop_margin = int(crop_margin) - new_image = add_margin(new_image, int(new_image.size[1] * crop_margin / 100), - int(new_image.size[0] * crop_margin / 100), - int(new_image.size[1] * crop_margin / 100), - int(new_image.size[0] * crop_margin / 100), (0, 0, 0, 0)) + new_image = add_margin( + new_image, + int(new_image.size[1] * crop_margin / 100), + int(new_image.size[0] * crop_margin / 100), + int(new_image.size[1] * crop_margin / 100), + int(new_image.size[0] * crop_margin / 100), + (0, 0, 0, 0), + ) else: if "position" in params.keys() and scaled is False: value = params["position"] if len(value) == 2: - new_image = transparency_paste(Image.new("RGBA", image.size), new_image, - (int(image.size[0] * value[0] / 100), - int(image.size[1] * value[1] / 100))) + new_image = transparency_paste( + Image.new("RGBA", image.size), + new_image, + ( + int(image.size[0] * value[0] / 100), + int(image.size[1] * value[1] / 100), + ), + ) else: - new_image = transparency_paste(Image.new("RGBA", image.size), new_image, roi_box) + new_image = transparency_paste( + Image.new("RGBA", image.size), new_image, roi_box + ) elif scaled is False: - new_image = transparency_paste(Image.new("RGBA", image.size), new_image, roi_box) + new_image = transparency_paste( + Image.new("RGBA", image.size), new_image, roi_box + ) if "channels" in params.keys(): value = params["channels"] @@ -163,19 +210,19 @@ def process_remove_bg(interface: Interface, params, image, bg, is_json_or_www_en if value == "jpg": new_image = new_image.convert("RGB") img_io = io.BytesIO() - new_image.save(img_io, 'JPEG', quality=100) + new_image.save(img_io, "JPEG", quality=100) img_io.seek(0) return {"type": "jpg", "data": [img_io, new_image.size]} elif value == "zip": mask = extract_alpha_channel(new_image) mask_buff = 
io.BytesIO() - mask.save(mask_buff, 'PNG') + mask.save(mask_buff, "PNG") mask_buff.seek(0) image_buff = io.BytesIO() - image.save(image_buff, 'JPEG') + image.save(image_buff, "JPEG") image_buff.seek(0) fileobj = io.BytesIO() - with zipfile.ZipFile(fileobj, 'w') as zip_file: + with zipfile.ZipFile(fileobj, "w") as zip_file: zip_info = zipfile.ZipInfo(filename="color.jpg") zip_info.date_time = time.localtime(time.time())[:6] zip_info.compress_type = zipfile.ZIP_DEFLATED @@ -188,8 +235,13 @@ def process_remove_bg(interface: Interface, params, image, bg, is_json_or_www_en return {"type": "zip", "data": [fileobj.read(), new_image.size]} else: buff = io.BytesIO() - new_image.save(buff, 'PNG') + new_image.save(buff, "PNG") buff.seek(0) return {"type": "png", "data": [buff, new_image.size]} - return error_dict("Something wrong with request or http api. Please, open new issue on Github! This is error in " - "code."), 400 + return ( + error_dict( + "Something wrong with request or http api. Please, open new issue on Github! This is error in " + "code." 
+ ), + 400, + ) diff --git a/carvekit/web/routers/api_router.py b/carvekit/web/routers/api_router.py index 359309b..c452cac 100644 --- a/carvekit/web/routers/api_router.py +++ b/carvekit/web/routers/api_router.py @@ -18,52 +18,60 @@ from carvekit.web.schemas.request import Parameters from carvekit.web.utils.net_utils import is_loopback -api_router = APIRouter(prefix='', tags=['api']) +api_router = APIRouter(prefix="", tags=["api"]) # noinspection PyBroadException -@api_router.post('/removebg') +@api_router.post("/removebg") async def removebg( - request: Request, - image_file: Optional[bytes] = File(None), - auth: bool = Depends(Authenticate), - content_type: str = Header(""), - image_file_b64: Optional[str] = Form(None), - image_url: Optional[str] = Form(None), - bg_image_file: Optional[bytes] = File(None), - size: Optional[str] = Form("full"), - type: Optional[str] = Form("auto"), - format: Optional[str] = Form("auto"), - roi: str = Form("0% 0% 100% 100%"), - crop: bool = Form(False), - crop_margin: Optional[str] = Form("0px"), - scale: Optional[str] = Form("original"), - position: Optional[str] = Form("original"), - channels: Optional[str] = Form("rgba"), - add_shadow: bool = Form(False), # Not supported at the moment - semitransparency: bool = Form(False), # Not supported at the moment - bg_color: Optional[str] = Form("") + request: Request, + image_file: Optional[bytes] = File(None), + auth: bool = Depends(Authenticate), + content_type: str = Header(""), + image_file_b64: Optional[str] = Form(None), + image_url: Optional[str] = Form(None), + bg_image_file: Optional[bytes] = File(None), + size: Optional[str] = Form("full"), + type: Optional[str] = Form("auto"), + format: Optional[str] = Form("auto"), + roi: str = Form("0% 0% 100% 100%"), + crop: bool = Form(False), + crop_margin: Optional[str] = Form("0px"), + scale: Optional[str] = Form("original"), + position: Optional[str] = Form("original"), + channels: Optional[str] = Form("rgba"), + add_shadow: bool = 
Form(False), # Not supported at the moment + semitransparency: bool = Form(False), # Not supported at the moment + bg_color: Optional[str] = Form(""), ): if auth is False: return JSONResponse(content=error_dict("Missing API Key"), status_code=403) - if content_type not in ["application/x-www-form-urlencoded", - "application/json"] and "multipart/form-data" not in content_type: - return JSONResponse(content=error_dict("Invalid request content type"), status_code=400) + if ( + content_type not in ["application/x-www-form-urlencoded", "application/json"] + and "multipart/form-data" not in content_type + ): + return JSONResponse( + content=error_dict("Invalid request content type"), status_code=400 + ) if image_url: - if ( - not image_url.startswith("http://") or - not image_url.startswith("https://") or - is_loopback(image_url) - ): - print(f"Possible ssrf attempt to /api/removebg endpoint with image url: {image_url}") - return JSONResponse(content=error_dict("Invalid image url."), - status_code=400) # possible ssrf attempt + if not ( + image_url.startswith("http://") or image_url.startswith("https://") + ) or is_loopback(image_url): + print( + f"Possible ssrf attempt to /api/removebg endpoint with image url: {image_url}" + ) + return JSONResponse( + content=error_dict("Invalid image url."), status_code=400 + ) # possible ssrf attempt image = None bg = None parameters = None - if content_type == "application/x-www-form-urlencoded" or "multipart/form-data" in content_type: + if ( + content_type == "application/x-www-form-urlencoded" + or "multipart/form-data" in content_type + ): if image_file_b64 is None and image_url is None and image_file is None: return JSONResponse(content=error_dict("File not found"), status_code=400) @@ -73,12 +81,16 @@ async def removebg( try: image = Image.open(io.BytesIO(base64.b64decode(image_file_b64))) except BaseException: - return JSONResponse(content=error_dict("Error decode image!"), status_code=400) + return JSONResponse( + 
content=error_dict("Error decode image!"), status_code=400 + ) elif image_url: try: image = Image.open(io.BytesIO(requests.get(image_url).content)) except BaseException: - return JSONResponse(content=error_dict("Error download image!"), status_code=400) + return JSONResponse( + content=error_dict("Error download image!"), status_code=400 + ) elif image_file: if len(image_file) == 0: return JSONResponse(content=error_dict("Empty image"), status_code=400) @@ -106,7 +118,9 @@ async def removebg( bg_color=bg_color, ) except ValidationError as e: - return JSONResponse(content=e.json(), status_code=400, media_type='application/json') + return JSONResponse( + content=e.json(), status_code=400, media_type="application/json" + ) else: payload = None @@ -117,7 +131,9 @@ async def removebg( try: parameters = Parameters(**payload) except ValidationError as e: - return Response(content=e.json(), status_code=400, media_type='application/json') + return Response( + content=e.json(), status_code=400, media_type="application/json" + ) if parameters.image_file_b64 is None and parameters.image_url is None: return JSONResponse(content=error_dict("File not found"), status_code=400) @@ -125,30 +141,44 @@ async def removebg( if len(parameters.image_file_b64) == 0: return JSONResponse(content=error_dict("Empty image"), status_code=400) try: - image = Image.open(io.BytesIO(base64.b64decode(parameters.image_file_b64))) + image = Image.open( + io.BytesIO(base64.b64decode(parameters.image_file_b64)) + ) except BaseException: - return JSONResponse(content=error_dict("Error decode image!"), status_code=400) + return JSONResponse( + content=error_dict("Error decode image!"), status_code=400 + ) elif parameters.image_url: - if ( - not parameters.image_url.startswith("http://") or - not parameters.image_url.startswith("https://") or - is_loopback(parameters.image_url) - ): - print(f"Possible ssrf attempt to /api/removebg endpoint with image url: {parameters.image_url}") - return 
JSONResponse(content=error_dict("Invalid image url."), - status_code=400) # possible ssrf attempt + if not ( + parameters.image_url.startswith("http://") + or parameters.image_url.startswith("https://") + ) or is_loopback(parameters.image_url): + print( + f"Possible ssrf attempt to /api/removebg endpoint with image url: {parameters.image_url}" + ) + return JSONResponse( + content=error_dict("Invalid image url."), status_code=400 + ) # possible ssrf attempt try: - image = Image.open(io.BytesIO(requests.get(parameters.image_url).content)) + image = Image.open( + io.BytesIO(requests.get(parameters.image_url).content) + ) except BaseException: - return JSONResponse(content=error_dict("Error download image!"), status_code=400) + return JSONResponse( + content=error_dict("Error download image!"), status_code=400 + ) if image is None: - return JSONResponse(content=error_dict("Error download image!"), status_code=400) + return JSONResponse( + content=error_dict("Error download image!"), status_code=400 + ) job_id = ml_processor.job_create([parameters.dict(), image, bg, False]) while ml_processor.job_status(job_id) != "finished": if ml_processor.job_status(job_id) == "not_found": - return JSONResponse(content=error_dict("Job ID not found!"), status_code=500) + return JSONResponse( + content=error_dict("Job ID not found!"), status_code=500 + ) time.sleep(5) result = ml_processor.job_result(job_id) @@ -160,9 +190,22 @@ def account(): """ Stub for compatibility with remove.bg api libraries """ - return JSONResponse(content={"data": {"attributes": { - "credits": {"total": 99999, "subscription": 99999, "payg": 99999, "enterprise": 99999}, - "api": {"free_calls": 99999, "sizes": "all"}}}}, status_code=200) + return JSONResponse( + content={ + "data": { + "attributes": { + "credits": { + "total": 99999, + "subscription": 99999, + "payg": 99999, + "enterprise": 99999, + }, + "api": {"free_calls": 99999, "sizes": "all"}, + } + } + }, + status_code=200, + ) 
@api_router.get("/admin/config") @@ -171,7 +214,9 @@ def status(auth: str = Depends(Authenticate)): Returns the current server config. """ if not auth or auth != "admin": - return JSONResponse(content=error_dict("Authentication failed"), status_code=403) + return JSONResponse( + content=error_dict("Authentication failed"), status_code=403 + ) resp = JSONResponse(content=config.json(), status_code=200) resp.headers["X-Credits-Charged"] = "0" return resp diff --git a/carvekit/web/schemas/config.py b/carvekit/web/schemas/config.py index de5d337..5d47ffc 100644 --- a/carvekit/web/schemas/config.py +++ b/carvekit/web/schemas/config.py @@ -7,7 +7,8 @@ class AuthConfig(BaseModel): - """Config for web api token authentication """ + """Config for web api token authentication""" + auth: bool = True """Enables Token Authentication for API""" admin_token: str = secrets.token_hex(32) @@ -18,7 +19,10 @@ class AuthConfig(BaseModel): class MLConfig(BaseModel): """Config for ml part of framework""" - segmentation_network: Literal["u2net", "deeplabv3", "basnet"] = "u2net" + + segmentation_network: Literal[ + "u2net", "deeplabv3", "basnet", "tracer_b7" + ] = "tracer_b7" """Segmentation Network""" preprocessing_method: Literal["none", "stub"] = "none" """Pre-processing Method""" @@ -30,50 +34,61 @@ class MLConfig(BaseModel): """Batch size for segmentation network""" batch_size_matting: int = 1 """Batch size for matting network""" - seg_mask_size: int = 320 + seg_mask_size: int = 640 """The size of the input image for the segmentation neural network.""" matting_mask_size: int = 2048 """The size of the input image for the matting neural network.""" + fp16: bool = False + """Use half precision for inference""" + trimap_dilation: int = 30 + """Dilation size for trimap""" + trimap_erosion: int = 5 + """Erosion levels for trimap""" + trimap_prob_threshold: int = 231 + """Probability threshold for trimap generation""" - @validator('seg_mask_size') + @validator("seg_mask_size") def 
seg_mask_size_validator(cls, value: int, values): if value > 0: return value else: raise ValueError("Incorrect seg_mask_size!") - @validator('matting_mask_size') + @validator("matting_mask_size") def matting_mask_size_validator(cls, value: int, values): if value > 0: return value else: raise ValueError("Incorrect matting_mask_size!") - @validator('batch_size_seg') + @validator("batch_size_seg") def batch_size_seg_validator(cls, value: int, values): if value > 0: return value else: raise ValueError("Incorrect batch size!") - @validator('batch_size_matting') - def batch_size_matting_validator(cls, value: int,values): + @validator("batch_size_matting") + def batch_size_matting_validator(cls, value: int, values): if value > 0: return value else: raise ValueError("Incorrect batch size!") - @validator('device') + @validator("device") def device_validator(cls, value): if torch.cuda.is_available() is False and "cuda" in value: - raise ValueError("GPU is not available, but specified as processing device!") - if 'cuda' not in value and "cpu" != value: + raise ValueError( + "GPU is not available, but specified as processing device!" + ) + if "cuda" not in value and "cpu" != value: raise ValueError("Unknown processing device! 
It should be cpu or cuda!") return value class WebAPIConfig(BaseModel): """FastAPI app config""" + port: int = 5000 """Web API port""" host: str = "0.0.0.0" diff --git a/carvekit/web/schemas/request.py b/carvekit/web/schemas/request.py index 16c7fc5..d7ebefc 100644 --- a/carvekit/web/schemas/request.py +++ b/carvekit/web/schemas/request.py @@ -8,57 +8,65 @@ class Parameters(BaseModel): image_file_b64: Optional[str] = "" image_url: Optional[str] = "" - size: Optional[Literal['preview', 'full', 'auto']] = "preview" - type: Optional[Literal['auto', 'product', 'person', 'car']] = "auto" # Not supported at the moment - format: Optional[Literal['auto', 'jpg', 'png', 'zip']] = "auto" + size: Optional[Literal["preview", "full", "auto"]] = "preview" + type: Optional[ + Literal["auto", "product", "person", "car"] + ] = "auto" # Not supported at the moment + format: Optional[Literal["auto", "jpg", "png", "zip"]] = "auto" roi: str = "0% 0% 100% 100%" crop: bool = False crop_margin: Optional[str] = "0px" scale: Optional[str] = "original" position: Optional[str] = "original" - channels: Optional[Literal['rgba', 'alpha']] = "rgba" + channels: Optional[Literal["rgba", "alpha"]] = "rgba" add_shadow: str = "false" # Not supported at the moment semitransparency: str = "false" # Not supported at the moment bg_color: Optional[str] = "" bg_image_url: Optional[str] = "" - @validator('crop_margin') + @validator("crop_margin") def crop_margin_validator(cls, value): - if not re.match(r'[0-9]+(px|%)$', value): - raise ValueError('crop_margin paramter is not valid') # TODO: Add support of several values - if '%' in value and (int(value[:-1]) < 0 or int(value[:-1]) > 100): - raise ValueError('crop_margin mast be in range between 0% and 100%') + if not re.match(r"[0-9]+(px|%)$", value): + raise ValueError( + "crop_margin paramter is not valid" + ) # TODO: Add support of several values + if "%" in value and (int(value[:-1]) < 0 or int(value[:-1]) > 100): + raise ValueError("crop_margin mast be 
in range between 0% and 100%") return value - @validator('scale') + @validator("scale") def scale_validator(cls, value): - if value != 'original' and ( - not re.match(r'[0-9]+%$', value) or not int(value[:-1]) <= 100 or not int(value[:-1]) >= 10): - raise ValueError('scale must be original or in between of 10% and 100%') + if value != "original" and ( + not re.match(r"[0-9]+%$", value) + or not int(value[:-1]) <= 100 + or not int(value[:-1]) >= 10 + ): + raise ValueError("scale must be original or in between of 10% and 100%") - if value == 'original': + if value == "original": return 100 return int(value[:-1]) - @validator('position') + @validator("position") def position_validator(cls, value, values): - if len(value.split(' ')) > 2: + if len(value.split(" ")) > 2: raise ValueError( - 'Position must be a value from 0 to 100 ' - 'for both vertical and horizontal axises or for both axises respectively') + "Position must be a value from 0 to 100 " + "for both vertical and horizontal axises or for both axises respectively" + ) - if value == 'original': - return 'original' - elif len(value.split(' ')) == 1: + if value == "original": + return "original" + elif len(value.split(" ")) == 1: return [int(value[:-1]), int(value[:-1])] else: - return [int(value.split(' ')[0][:-1]), int(value.split(' ')[1][:-1])] + return [int(value.split(" ")[0][:-1]), int(value.split(" ")[1][:-1])] - @validator('bg_color') + @validator("bg_color") def bg_color_validator(cls, value): - if not re.match(r'(#{0,1}[0-9a-f]{3}){0,2}$', value): - raise ValueError('bg_color is not in hex') - if len(value) and value[0] != '#': - value = '#' + value + if not re.match(r"(#{0,1}[0-9a-f]{3}){0,2}$", value): + raise ValueError("bg_color is not in hex") + if len(value) and value[0] != "#": + value = "#" + value return value diff --git a/carvekit/web/utils/init_utils.py b/carvekit/web/utils/init_utils.py index cb78041..f687182 100644 --- a/carvekit/web/utils/init_utils.py +++ 
b/carvekit/web/utils/init_utils.py @@ -9,6 +9,7 @@ from carvekit.ml.wrap.u2net import U2NET from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 from carvekit.ml.wrap.basnet import BASNET +from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 from carvekit.pipelines.postprocessing import MattingMethod from carvekit.pipelines.preprocessing import PreprocessingStub @@ -18,29 +19,72 @@ def init_config() -> WebAPIConfig: default_config = WebAPIConfig() config = WebAPIConfig( - **dict(port=int(getenv('CARVEKIT_PORT', default_config.port)), - host=getenv('CARVEKIT_HOST', default_config.host), - ml=MLConfig( - segmentation_network=getenv('CARVEKIT_SEGMENTATION_NETWORK', default_config.ml.segmentation_network), - preprocessing_method=getenv('CARVEKIT_PREPROCESSING_METHOD', default_config.ml.preprocessing_method), - postprocessing_method=getenv('CARVEKIT_POSTPROCESSING_METHOD', - default_config.ml.postprocessing_method), - device=getenv('CARVEKIT_DEVICE', default_config.ml.device), - batch_size_seg=int(getenv('CARVEKIT_BATCH_SIZE_SEG', default_config.ml.batch_size_seg)), - batch_size_matting=int(getenv('CARVEKIT_BATCH_SIZE_MATTING', default_config.ml.batch_size_matting)), - seg_mask_size=int(getenv('CARVEKIT_SEG_MASK_SIZE', default_config.ml.seg_mask_size)), - matting_mask_size=int(getenv('CARVEKIT_MATTING_MASK_SIZE', default_config.ml.matting_mask_size)) - ), auth=AuthConfig( - auth=bool(int(getenv('CARVEKIT_AUTH_ENABLE', default_config.auth.auth))), - admin_token=getenv('CARVEKIT_ADMIN_TOKEN', default_config.auth.admin_token), + **dict( + port=int(getenv("CARVEKIT_PORT", default_config.port)), + host=getenv("CARVEKIT_HOST", default_config.host), + ml=MLConfig( + segmentation_network=getenv( + "CARVEKIT_SEGMENTATION_NETWORK", + default_config.ml.segmentation_network, + ), + preprocessing_method=getenv( + "CARVEKIT_PREPROCESSING_METHOD", + default_config.ml.preprocessing_method, + ), + postprocessing_method=getenv( + "CARVEKIT_POSTPROCESSING_METHOD", + 
default_config.ml.postprocessing_method, + ), + device=getenv("CARVEKIT_DEVICE", default_config.ml.device), + batch_size_seg=int( + getenv("CARVEKIT_BATCH_SIZE_SEG", default_config.ml.batch_size_seg) + ), + batch_size_matting=int( + getenv( + "CARVEKIT_BATCH_SIZE_MATTING", + default_config.ml.batch_size_matting, + ) + ), + seg_mask_size=int( + getenv("CARVEKIT_SEG_MASK_SIZE", default_config.ml.seg_mask_size) + ), + matting_mask_size=int( + getenv( + "CARVEKIT_MATTING_MASK_SIZE", + default_config.ml.matting_mask_size, + ) + ), + fp16=bool(int(getenv("CARVEKIT_FP16", default_config.ml.fp16))), + trimap_prob_threshold=int( + getenv( + "CARVEKIT_TRIMAP_PROB_THRESHOLD", + default_config.ml.trimap_prob_threshold, + ) + ), + trimap_dilation=int( + getenv( + "CARVEKIT_TRIMAP_DILATION", default_config.ml.trimap_dilation + ) + ), + trimap_erosion=int( + getenv("CARVEKIT_TRIMAP_EROSION", default_config.ml.trimap_erosion) + ), + ), + auth=AuthConfig( + auth=bool( + int(getenv("CARVEKIT_AUTH_ENABLE", default_config.auth.auth)) + ), + admin_token=getenv( + "CARVEKIT_ADMIN_TOKEN", default_config.auth.admin_token + ), allowed_tokens=default_config.auth.allowed_tokens - if getenv('CARVEKIT_ALLOWED_TOKENS') is None else getenv('CARVEKIT_ALLOWED_TOKENS').split(',') - - )) - + if getenv("CARVEKIT_ALLOWED_TOKENS") is None + else getenv("CARVEKIT_ALLOWED_TOKENS").split(","), + ), + ) ) - logger.info(f'Admin token for Web API is {config.auth.admin_token}') + logger.info(f"Admin token for Web API is {config.auth.admin_token}") logger.debug(f"Running Web API with this config: {config.json()}") return config @@ -49,21 +93,40 @@ def init_interface(config: Union[WebAPIConfig, MLConfig]) -> Interface: if isinstance(config, WebAPIConfig): config = config.ml if config.segmentation_network == "u2net": - seg_net = U2NET(device=config.device, - batch_size=config.batch_size_seg, - input_image_size=config.seg_mask_size) + seg_net = U2NET( + device=config.device, + batch_size=config.batch_size_seg, + 
input_image_size=config.seg_mask_size, + fp16=config.fp16, + ) elif config.segmentation_network == "deeplabv3": - seg_net = DeepLabV3(device=config.device, - batch_size=config.batch_size_seg, - input_image_size=config.seg_mask_size) + seg_net = DeepLabV3( + device=config.device, + batch_size=config.batch_size_seg, + input_image_size=config.seg_mask_size, + fp16=config.fp16, + ) elif config.segmentation_network == "basnet": - seg_net = BASNET(device=config.device, - batch_size=config.batch_size_seg, - input_tensor_size=config.seg_mask_size) + seg_net = BASNET( + device=config.device, + batch_size=config.batch_size_seg, + input_image_size=config.seg_mask_size, + fp16=config.fp16, + ) + elif config.segmentation_network == "tracer_b7": + seg_net = TracerUniversalB7( + device=config.device, + batch_size=config.batch_size_seg, + input_image_size=config.seg_mask_size, + fp16=config.fp16, + ) else: - seg_net = U2NET(device=config.device, - batch_size=config.batch_size_seg, - input_image_size=config.seg_mask_size) + seg_net = TracerUniversalB7( + device=config.device, + batch_size=config.batch_size_seg, + input_image_size=config.seg_mask_size, + fp16=config.fp16, + ) if config.preprocessing_method == "stub": preprocessing = PreprocessingStub() @@ -73,21 +136,30 @@ def init_interface(config: Union[WebAPIConfig, MLConfig]) -> Interface: preprocessing = None if config.postprocessing_method == "fba": - fba = FBAMatting(device=config.device, - batch_size=config.batch_size_matting, - input_tensor_size=config.matting_mask_size) - trimap_generator = TrimapGenerator() - postprocessing = MattingMethod(device=config.device, - matting_module=fba, - trimap_generator=trimap_generator) + fba = FBAMatting( + device=config.device, + batch_size=config.batch_size_matting, + input_tensor_size=config.matting_mask_size, + fp16=config.fp16, + ) + trimap_generator = TrimapGenerator( + prob_threshold=config.trimap_prob_threshold, + kernel_size=config.trimap_dilation, + 
erosion_iters=config.trimap_erosion, + ) + postprocessing = MattingMethod( + device=config.device, matting_module=fba, trimap_generator=trimap_generator + ) elif config.postprocessing_method == "none": postprocessing = None else: postprocessing = None - interface = Interface(pre_pipe=preprocessing, - post_pipe=postprocessing, - seg_pipe=seg_net, - device=config.device) + interface = Interface( + pre_pipe=preprocessing, + post_pipe=postprocessing, + seg_pipe=seg_net, + device=config.device, + ) return interface diff --git a/carvekit/web/utils/net_utils.py b/carvekit/web/utils/net_utils.py index 6f7a225..12a9620 100644 --- a/carvekit/web/utils/net_utils.py +++ b/carvekit/web/utils/net_utils.py @@ -10,12 +10,14 @@ def is_loopback(address): try: parsed_url = urlparse(address) host = parsed_url.hostname - except: + except ValueError: return False # url is not even a url loopback_checker = { - socket.AF_INET: lambda x: struct.unpack('!I', socket.inet_aton(x))[0] >> (32 - 8) == 127, - socket.AF_INET6: lambda x: x == '::1' + socket.AF_INET: lambda x: struct.unpack("!I", socket.inet_aton(x))[0] + >> (32 - 8) + == 127, + socket.AF_INET6: lambda x: x == "::1", } for family in (socket.AF_INET, socket.AF_INET6): try: diff --git a/carvekit/web/utils/task_queue.py b/carvekit/web/utils/task_queue.py index 2b85af8..f821434 100644 --- a/carvekit/web/utils/task_queue.py +++ b/carvekit/web/utils/task_queue.py @@ -14,6 +14,7 @@ class MLProcessor(threading.Thread): """Simple ml task queue processor""" + def __init__(self, api_config: WebAPIConfig): super().__init__() self.api_config = api_config @@ -36,7 +37,9 @@ def run(self): id = list(self.jobs.keys())[0] data = self.jobs[id] # TODO add pydantic scheme here - response = process_remove_bg(self.interface, data[0], data[1], data[2], data[3]) + response = process_remove_bg( + self.interface, data[0], data[1], data[2], data[3] + ) self.completed_jobs[id] = [response, time.time()] try: del self.jobs[id] diff --git a/conftest.py 
b/conftest.py index c218370..f328d35 100644 --- a/conftest.py +++ b/conftest.py @@ -22,15 +22,30 @@ from carvekit.ml.wrap.basnet import BASNET from carvekit.ml.wrap.fba_matting import FBAMatting from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 +from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 @pytest.fixture() -def u2net_model() -> Callable[[], U2NET]: - return lambda: U2NET(layers_cfg="full", - device='cuda' if torch.cuda.is_available() else 'cpu', - input_image_size=320, - batch_size=10, - load_pretrained=True) +def u2net_model() -> Callable[[bool], U2NET]: + return lambda fb16: U2NET( + layers_cfg="full", + device="cuda" if torch.cuda.is_available() else "cpu", + input_image_size=320, + batch_size=10, + load_pretrained=True, + fp16=fb16, + ) + + +@pytest.fixture() +def tracer_model() -> Callable[[bool], TracerUniversalB7]: + return lambda fb16: TracerUniversalB7( + device="cuda" if torch.cuda.is_available() else "cpu", + input_image_size=320, + batch_size=10, + load_pretrained=True, + fp16=fb16, + ) @pytest.fixture() @@ -50,46 +65,66 @@ def preprocessing_stub_instance() -> Callable[[], PreprocessingStub]: @pytest.fixture() def matting_method_instance(fba_model, trimap_instance): - return lambda: MattingMethod(matting_module=fba_model(), trimap_generator=trimap_instance(), device="cpu") + return lambda: MattingMethod( + matting_module=fba_model(False), + trimap_generator=trimap_instance(), + device="cpu", + ) @pytest.fixture() def high_interface_instance() -> Callable[[], HiInterface]: - return lambda: HiInterface(batch_size_seg=5, batch_size_matting=1, - device='cuda' if torch.cuda.is_available() else 'cpu', - seg_mask_size=320, matting_mask_size=2048) + return lambda: HiInterface( + batch_size_seg=5, + batch_size_matting=1, + device="cuda" if torch.cuda.is_available() else "cpu", + seg_mask_size=320, + matting_mask_size=2048, + ) @pytest.fixture() -def interface_instance(u2net_model, preprocessing_stub_instance, - matting_method_instance) -> 
Callable[[], Interface]: - return lambda: Interface(u2net_model(), - pre_pipe=preprocessing_stub_instance(), - post_pipe=matting_method_instance(), - device='cuda' if torch.cuda.is_available() else 'cpu') +def interface_instance( + u2net_model, preprocessing_stub_instance, matting_method_instance +) -> Callable[[], Interface]: + return lambda: Interface( + u2net_model(False), + pre_pipe=preprocessing_stub_instance(), + post_pipe=matting_method_instance(), + device="cuda" if torch.cuda.is_available() else "cpu", + ) @pytest.fixture() -def fba_model() -> Callable[[], FBAMatting]: - return lambda: FBAMatting(device='cuda' if torch.cuda.is_available() else 'cpu', - input_tensor_size=1024, - batch_size=2, - load_pretrained=True) +def fba_model() -> Callable[[bool], FBAMatting]: + return lambda fp16: FBAMatting( + device="cuda" if torch.cuda.is_available() else "cpu", + input_tensor_size=1024, + batch_size=2, + load_pretrained=True, + fp16=fp16, + ) @pytest.fixture() -def deeplabv3_model() -> Callable[[], DeepLabV3]: - return lambda: DeepLabV3(device='cuda' if torch.cuda.is_available() else 'cpu', - batch_size=10, - load_pretrained=True) +def deeplabv3_model() -> Callable[[bool], DeepLabV3]: + return lambda fp16: DeepLabV3( + device="cuda" if torch.cuda.is_available() else "cpu", + batch_size=10, + load_pretrained=True, + fp16=fp16, + ) @pytest.fixture() -def basnet_model() -> Callable[[], BASNET]: - return lambda: BASNET(device='cuda' if torch.cuda.is_available() else 'cpu', - input_tensor_size=320, - batch_size=10, - load_pretrained=True) +def basnet_model() -> Callable[[bool], BASNET]: + return lambda fp16: BASNET( + device="cuda" if torch.cuda.is_available() else "cpu", + input_image_size=320, + batch_size=10, + load_pretrained=True, + fp16=fp16, + ) @pytest.fixture() @@ -99,17 +134,19 @@ def image_str(image_path) -> str: @pytest.fixture() def image_path() -> Path: - return Path(__file__).parent.joinpath('tests').joinpath('data', 'cat.jpg') + return 
Path(__file__).parent.joinpath("tests").joinpath("data", "cat.jpg") @pytest.fixture() def image_mask(image_path) -> Image.Image: - return Image.open(image_path.with_name('cat_mask').with_suffix(".png")) + return Image.open(image_path.with_name("cat_mask").with_suffix(".png")) @pytest.fixture() def image_trimap(image_path) -> Image.Image: - return Image.open(image_path.with_name('cat_trimap').with_suffix(".png")).convert("L") + return Image.open(image_path.with_name("cat_trimap").with_suffix(".png")).convert( + "L" + ) @pytest.fixture() @@ -128,10 +165,17 @@ def converted_pil_image(image_pil) -> Image.Image: @pytest.fixture() -def available_models(u2net_model, deeplabv3_model, basnet_model, - preprocessing_stub_instance, matting_method_instance) -> Tuple[ - List[Union[Callable[[], U2NET], Callable[[], DeepLabV3], Callable[[], BASNET]]], List[ - Optional[Callable[[], PreprocessingStub]]], List[Union[Optional[Callable[[], MattingMethod]], Any]]]: +def available_models( + u2net_model, + deeplabv3_model, + basnet_model, + preprocessing_stub_instance, + matting_method_instance, +) -> Tuple[ + List[Union[Callable[[], U2NET], Callable[[], DeepLabV3], Callable[[], BASNET]]], + List[Optional[Callable[[], PreprocessingStub]]], + List[Union[Optional[Callable[[], MattingMethod]], Any]], +]: models = [u2net_model, deeplabv3_model, basnet_model] pre_pipes = [None, preprocessing_stub_instance] post_pipes = [None, matting_method_instance] diff --git a/docker-compose.cpu.yml b/docker-compose.cpu.yml index 9ee8de7..1fe3f5a 100644 --- a/docker-compose.cpu.yml +++ b/docker-compose.cpu.yml @@ -1,22 +1,23 @@ services: carvekit_api: - build: - dockerfile: Dockerfile.cpu - context: . 
+ image: anodev/carvekit:latest-cpu ports: - "5000:5000" # 5000 environment: - CARVEKIT_PORT=5000 - CARVEKIT_HOST=0.0.0.0 - - CARVEKIT_SEGMENTATION_NETWORK=u2net # can be u2net, basnet, deeplabv3 + - CARVEKIT_SEGMENTATION_NETWORK=tracer_b7 # can be u2net, tracer_b7, basnet, deeplabv3 - CARVEKIT_PREPROCESSING_METHOD=none # can be none, stub - CARVEKIT_POSTPROCESSING_METHOD=fba # can be none, fba - CARVEKIT_DEVICE=cpu # can be cuda (req. cuda docker image), cpu - CARVEKIT_BATCH_SIZE_SEG=5 # Number of images processed per one segmentation nn call. NOT USED IF WEB API IS USED - CARVEKIT_BATCH_SIZE_MATTING=1 # Number of images processed per one matting nn call. NOT USED IF WEB API IS USED - - CARVEKIT_SEG_MASK_SIZE=320 # The size of the input image for the segmentation neural network. + - CARVEKIT_SEG_MASK_SIZE=640 # The size of the input image for the segmentation neural network. - CARVEKIT_MATTING_MASK_SIZE=2048 # The size of the input image for the matting neural network. - + - CARVEKIT_FP16=0 # Enables FP16 mode (Only CUDA at the moment) + - CARVEKIT_TRIMAP_PROB_THRESHOLD=231 # Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied + - CARVEKIT_TRIMAP_DILATION=30 # The size of the offset radius from the object mask in pixels when forming an unknown area + - CARVEKIT_TRIMAP_EROSION=5 # The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area - CARVEKIT_AUTH_ENABLE=1 # Enables authentication by tokens # Tokens will be generated automatically every time the container is restarted if these ENV is not set. #- CARVEKIT_ADMIN_TOKEN=admin diff --git a/docker-compose.cuda.yml b/docker-compose.cuda.yml index 71c2dc5..8308594 100644 --- a/docker-compose.cuda.yml +++ b/docker-compose.cuda.yml @@ -1,22 +1,23 @@ services: carvekit_api: - build: - dockerfile: Dockerfile.cuda - context: . 
+ image: anodev/carvekit:latest-cuda ports: - "5000:5000" # 5000 environment: - CARVEKIT_PORT=5000 - CARVEKIT_HOST=0.0.0.0 - - CARVEKIT_SEGMENTATION_NETWORK=u2net # can be u2net, basnet, deeplabv3 + - CARVEKIT_SEGMENTATION_NETWORK=tracer_b7 # can be u2net, tracer_b7, basnet, deeplabv3 - CARVEKIT_PREPROCESSING_METHOD=none # can be none, stub - CARVEKIT_POSTPROCESSING_METHOD=fba # can be none, fba - CARVEKIT_DEVICE=cuda # can be cuda (req. cuda docker image), cpu - CARVEKIT_BATCH_SIZE_SEG=5 # Number of images processed per one segmentation nn call. NOT USED IF WEB API IS USED - CARVEKIT_BATCH_SIZE_MATTING=1 # Number of images processed per one matting nn call. NOT USED IF WEB API IS USED - - CARVEKIT_SEG_MASK_SIZE=320 # The size of the input image for the segmentation neural network. + - CARVEKIT_SEG_MASK_SIZE=640 # The size of the input image for the segmentation neural network. - CARVEKIT_MATTING_MASK_SIZE=2048 # The size of the input image for the matting neural network. - + - CARVEKIT_FP16=0 # Enables FP16 mode (Only CUDA at the moment) + - CARVEKIT_TRIMAP_PROB_THRESHOLD=231 # Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied + - CARVEKIT_TRIMAP_DILATION=30 # The size of the offset radius from the object mask in pixels when forming an unknown area + - CARVEKIT_TRIMAP_EROSION=5 # The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area - CARVEKIT_AUTH_ENABLE=1 # Enables authentication by tokens # Tokens will be generated automatically every time the container is restarted if these ENV is not set. #- CARVEKIT_ADMIN_TOKEN=admin diff --git a/docs/CREDITS.md b/docs/CREDITS.md index 33f37a6..c544c65 100644 --- a/docs/CREDITS.md +++ b/docs/CREDITS.md @@ -16,10 +16,11 @@ All images are copyrighted by their authors. 2. https://github.com/NathanUA/U-2-Net 3. https://github.com/NathanUA/BASNet 4. https://github.com/MarcoForte/FBA_Matting -5. 
https://gluon-cv.mxnet.io/model_zoo/detection.html -6. https://arxiv.org/abs/1706.05587 -7. https://arxiv.org/pdf/2005.09007.pdf -8. http://openaccess.thecvf.com/content_CVPR_2019/html/Qin_BASNet_Boundary-Aware_Salient_Object_Detection_CVPR_2019_paper.html -9. https://arxiv.org/abs/2003.07711 -10. https://arxiv.org/abs/1506.01497 -11. https://arxiv.org/abs/1703.06870 +5. https://arxiv.org/abs/1706.05587 +6. https://arxiv.org/pdf/2005.09007.pdf +7. http://openaccess.thecvf.com/content_CVPR_2019/html/Qin_BASNet_Boundary-Aware_Salient_Object_Detection_CVPR_2019_paper.html +8. https://arxiv.org/abs/2003.07711 +9. https://arxiv.org/abs/1506.01497 +10. https://arxiv.org/abs/1703.06870 +11. https://github.com/Karel911/TRACER +12. https://arxiv.org/abs/2112.07380 diff --git a/docs/code_examples/python/http_api_lib.py b/docs/code_examples/python/http_api_lib.py index 7956efb..cef55ff 100644 --- a/docs/code_examples/python/http_api_lib.py +++ b/docs/code_examples/python/http_api_lib.py @@ -11,22 +11,25 @@ remove_bg_api.API_URL = "http://localhost:5000/api" # Change the endpoint url removebg = remove_bg_api.RemoveBg("test") -settings = \ - { # API settings. See https://www.remove.bg/api for more details. - "size": "preview", # ["preview", "full", "auto", "medium", "hd", "4k", "small", "regular"] - "type": "auto", # ["auto", "person", "product", "car"] - "format": "auto", # ["auto", "png", "jpg", "zip"] - "roi": "", # {}% {}% {}% {}% or {}px {}px {}px {}px - "crop": False, # True or False - "crop_margin": "0px", # {}% or {}px - "scale": "original", # "{}%" or "original" - "position": "original", # "original" "center", or {}% - "channels": "rgba", # "rgba" or "alpha" - "add_shadow": "false", # Not supported at the moment - "semitransparency": "false", # Not supported at the moment - "bg_color": "", # "81d4fa" or "red" or any other color - "bg_image_url": "" # URL - } +settings = { # API settings. See https://www.remove.bg/api for more details. 
+ "size": "preview", # ["preview", "full", "auto", "medium", "hd", "4k", "small", "regular"] + "type": "auto", # ["auto", "person", "product", "car"] + "format": "auto", # ["auto", "png", "jpg", "zip"] + "roi": "", # {}% {}% {}% {}% or {}px {}px {}px {}px + "crop": False, # True or False + "crop_margin": "0px", # {}% or {}px + "scale": "original", # "{}%" or "original" + "position": "original", # "original" "center", or {}% + "channels": "rgba", # "rgba" or "alpha" + "add_shadow": "false", # Not supported at the moment + "semitransparency": "false", # Not supported at the moment + "bg_color": "", # "81d4fa" or "red" or any other color + "bg_image_url": "", # URL +} -removebg.remove_bg_file(str(Path("images/4.jpg").absolute()), raw=False, - out_path=str(Path("./4.png").absolute()), data=settings) \ No newline at end of file +removebg.remove_bg_file( + str(Path("images/4.jpg").absolute()), + raw=False, + out_path=str(Path("./4.png").absolute()), + data=settings, +) diff --git a/docs/code_examples/python/http_api_requests.py b/docs/code_examples/python/http_api_requests.py index a2f2e69..f0370a7 100644 --- a/docs/code_examples/python/http_api_requests.py +++ b/docs/code_examples/python/http_api_requests.py @@ -8,12 +8,12 @@ from pathlib import Path response = requests.post( - 'http://localhost:5000/api/removebg', - files={'image_file': Path("images/4.jpg").read_bytes()}, - data={'size': 'auto'}, - headers={'X-Api-Key': 'test'}, + "http://localhost:5000/api/removebg", + files={"image_file": Path("images/4.jpg").read_bytes()}, + data={"size": "auto"}, + headers={"X-Api-Key": "test"}, ) if response.status_code == 200: Path("image_without_bg.png").write_bytes(response.content) else: - print("Error:", response.status_code, response.text) \ No newline at end of file + print("Error:", response.status_code, response.text) diff --git a/docs/imgs/input/1_bg_removed.png b/docs/imgs/input/1_bg_removed.png index 4a05649..a1e44f6 100644 Binary files 
a/docs/imgs/input/1_bg_removed.png and b/docs/imgs/input/1_bg_removed.png differ diff --git a/docs/imgs/input/2_bg_removed.png b/docs/imgs/input/2_bg_removed.png index d416cf3..a30c041 100644 Binary files a/docs/imgs/input/2_bg_removed.png and b/docs/imgs/input/2_bg_removed.png differ diff --git a/docs/imgs/input/3_bg_removed.png b/docs/imgs/input/3_bg_removed.png index abfc1cc..298e17f 100644 Binary files a/docs/imgs/input/3_bg_removed.png and b/docs/imgs/input/3_bg_removed.png differ diff --git a/docs/imgs/input/4_bg_removed.png b/docs/imgs/input/4_bg_removed.png index 48b56e0..32b6a1c 100644 Binary files a/docs/imgs/input/4_bg_removed.png and b/docs/imgs/input/4_bg_removed.png differ diff --git a/docs/other/carvekit_try.ipynb b/docs/other/carvekit_try.ipynb index dcb38c6..484ee9c 100644 --- a/docs/other/carvekit_try.ipynb +++ b/docs/other/carvekit_try.ipynb @@ -1,259 +1,204 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "carvekit-try.ipynb", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU", - "gpuClass": "standard" - }, - "cells": [ - { - "cell_type": "markdown", - "source": [ - "![logo.png]()" - ], - "metadata": { - "id": "-BV5wSJzQ-ev", - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "markdown", - "source": [ - "### Automated high-quality background removal framework for an image using neural networks\n", - "\n", - "\n", - "\n", - "- 🏢 [Project at GitHub](https://github.com/OPHoperHPO/image-background-remove-tool) 🏢\n", - "- 🔗 [Author at GitHub](https://github.com/OPHoperHPO) 🔗\n", - "\n", - "> Please rate our repository with ⭐ if you like our work! Thanks! 😀" - ], - "metadata": { - "id": "Yq1sa5BbRV4c", - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "markdown", - "source": [ - "This notebook supports **Google Colab GPU runtime**. 
\n", - "\n", - "> **Enabling and testing the GPU** \\\n", - "> Navigate to `Edit → Notebook Settings`. \\\n", - "> Select `GPU` from the `Hardware Accelerator` drop-down." - ], - "metadata": { - "id": "lrGOILABYqXx", - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sqwsUfoI3SnG", - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "# Install CarveKit" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "7C4rC_HQi1gq", - "outputId": "4b17792b-8f83-4195-be15-0a46d9f80534", + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { "colab": { - "base_uri": "https://localhost:8080/" + "name": "carvekit-try.ipynb", + "provenance": [], + "collapsed_sections": [] }, - "pycharm": { - "name": "#%%\n" - } - }, - "source": [ - "#@title Install colab-ready python package (Click the arrow on the left)\n", - "%cd /content\n", - "!pip install carvekit_colab\n" - ], - "execution_count": 1, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "/content\n", - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting carvekit_colab\n", - " Downloading carvekit_colab-4.0.1-py3-none-any.whl (56 kB)\n", - "\u001B[K |████████████████████████████████| 56 kB 4.9 MB/s \n", - "\u001B[?25hCollecting loguru\n", - " Downloading loguru-0.6.0-py3-none-any.whl (58 kB)\n", - "\u001B[K |████████████████████████████████| 58 kB 6.9 MB/s \n", - "\u001B[?25hRequirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from carvekit_colab) (4.64.1)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from carvekit_colab) (1.21.6)\n", - "Collecting uvicorn\n", - " Downloading uvicorn-0.18.3-py3-none-any.whl (57 kB)\n", - "\u001B[K |████████████████████████████████| 57 kB 6.0 MB/s \n", - "\u001B[?25hRequirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from carvekit_colab) 
(2.23.0)\n", - "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from carvekit_colab) (1.12.1+cu113)\n", - "Requirement already satisfied: pydantic in /usr/local/lib/python3.7/dist-packages (from carvekit_colab) (1.9.2)\n", - "Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from carvekit_colab) (7.1.2)\n", - "Requirement already satisfied: opencv-python in /usr/local/lib/python3.7/dist-packages (from carvekit_colab) (4.6.0.66)\n", - "Collecting aiofiles\n", - " Downloading aiofiles-22.1.0-py3-none-any.whl (14 kB)\n", - "Requirement already satisfied: Pillow in /usr/local/lib/python3.7/dist-packages (from carvekit_colab) (7.1.2)\n", - "Collecting python-multipart\n", - " Downloading python-multipart-0.0.5.tar.gz (32 kB)\n", - "Collecting fastapi\n", - " Downloading fastapi-0.85.0-py3-none-any.whl (55 kB)\n", - "\u001B[K |████████████████████████████████| 55 kB 4.4 MB/s \n", - "\u001B[?25hCollecting starlette\n", - " Downloading starlette-0.20.4-py3-none-any.whl (63 kB)\n", - "\u001B[K |████████████████████████████████| 63 kB 2.8 MB/s \n", - "\u001B[?25hRequirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (from carvekit_colab) (0.13.1+cu113)\n", - "Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from carvekit_colab) (57.4.0)\n", - "Requirement already satisfied: typing-extensions>=3.10.0 in /usr/local/lib/python3.7/dist-packages (from starlette->carvekit_colab) (4.1.1)\n", - "Collecting anyio<5,>=3.4.0\n", - " Downloading anyio-3.6.1-py3-none-any.whl (80 kB)\n", - "\u001B[K |████████████████████████████████| 80 kB 9.8 MB/s \n", - "\u001B[?25hCollecting sniffio>=1.1\n", - " Downloading sniffio-1.3.0-py3-none-any.whl (10 kB)\n", - "Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.7/dist-packages (from anyio<5,>=3.4.0->starlette->carvekit_colab) (2.10)\n", - "Requirement already satisfied: six>=1.4.0 in 
/usr/local/lib/python3.7/dist-packages (from python-multipart->carvekit_colab) (1.15.0)\n", - "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->carvekit_colab) (1.24.3)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->carvekit_colab) (2022.6.15)\n", - "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->carvekit_colab) (3.0.4)\n", - "Collecting h11>=0.8\n", - " Downloading h11-0.13.0-py3-none-any.whl (58 kB)\n", - "\u001B[K |████████████████████████████████| 58 kB 6.5 MB/s \n", - "\u001B[?25hBuilding wheels for collected packages: python-multipart\n", - " Building wheel for python-multipart (setup.py) ... \u001B[?25l\u001B[?25hdone\n", - " Created wheel for python-multipart: filename=python_multipart-0.0.5-py3-none-any.whl size=31678 sha256=07256848d4610c7b38d9d5b2d789f20eea390c0eb7a7fe8497b2858367b7ce96\n", - " Stored in directory: /root/.cache/pip/wheels/2c/41/7c/bfd1c180534ffdcc0972f78c5758f89881602175d48a8bcd2c\n", - "Successfully built python-multipart\n", - "Installing collected packages: sniffio, anyio, starlette, h11, uvicorn, python-multipart, loguru, fastapi, aiofiles, carvekit-colab\n", - "Successfully installed aiofiles-22.1.0 anyio-3.6.1 carvekit-colab-4.0.1 fastapi-0.85.0 h11-0.13.0 loguru-0.6.0 python-multipart-0.0.5 sniffio-1.3.0 starlette-0.20.4 uvicorn-0.18.3\n" - ] - } - ] + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU", + "gpuClass": "standard" }, - { - "cell_type": "code", - "source": [ - "#@title Download all models\n", - "from carvekit.ml.files.models_loc import download_all\n", - "\n", - "download_all();" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cells": [ + { + "cell_type": "markdown", + "source": [ + "![logo.png]()" + ], + "metadata": { + "id": "-BV5wSJzQ-ev", 
+ "pycharm": { + "name": "#%% md\n" + } + } }, - "cellView": "form", - "id": "EPjtRXRpQ2k7", - "outputId": "3353de30-4153-4fba-fc5d-5fe5cf60f4f8", - "pycharm": { - "name": "#%%\n" - } - }, - "execution_count": 2, - "outputs": [ { - "output_type": "stream", - "name": "stderr", - "text": [ - "Downloading u2net.pth model: 1377273it [00:18, 73486.19it/s]\n", - "Downloading fba_matting.pth model: 1084688it [00:12, 86147.52it/s]\n", - "Downloading deeplab.pth model: 1910513it [00:20, 91823.93it/s]\n", - "Downloading basnet.pth model: 2722581it [00:31, 87620.62it/s]\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pF-4SVcB3gjK", - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "# Remove background using CarveKit" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "rgm6pR6U22a9", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 529 + "cell_type": "markdown", + "source": [ + "### Automated high-quality background removal framework for an image using neural networks\n", + "\n", + "\n", + "\n", + "- 🏢 [Project at GitHub](https://github.com/OPHoperHPO/image-background-remove-tool) 🏢\n", + "- 🔗 [Author at GitHub](https://github.com/OPHoperHPO) 🔗\n", + "\n", + "> Please rate our repository with ⭐ if you like our work! Thanks! 😀" + ], + "metadata": { + "id": "Yq1sa5BbRV4c", + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "This notebook supports **Google Colab GPU runtime**. \n", + "\n", + "> **Enabling and testing the GPU** \\\n", + "> Navigate to `Edit → Notebook Settings`. \\\n", + "> Select `GPU` from the `Hardware Accelerator` drop-down." 
+ ], + "metadata": { + "id": "lrGOILABYqXx", + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sqwsUfoI3SnG", + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# Install CarveKit" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "7C4rC_HQi1gq", + "pycharm": { + "name": "#%%\n" + } + }, + "source": [ + "#@title Install colab-ready python package (Click the arrow on the left)\n", + "%cd /content\n", + "!pip install carvekit_colab\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Download all models\n", + "from carvekit.ml.files.models_loc import download_all\n", + "\n", + "download_all();" + ], + "metadata": { + "cellView": "form", + "id": "EPjtRXRpQ2k7", + "pycharm": { + "name": "#%%\n" + } + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pF-4SVcB3gjK", + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# Remove background using CarveKit" + ] }, - "cellView": "form", - "outputId": "c147713d-1e5d-4eaf-a901-f345ae846ca8", - "pycharm": { - "name": "#%%\n" - } - }, - "source": [ - "#@title Upload images from your computer\n", - "import torch\n", - "from IPython import display\n", - "from google.colab import files\n", - "from carvekit.api.high import HiInterface\n", - "\n", - "SHOW_FULLSIZE = False #@param {type:\"boolean\"}\n", - "\n", - "#@markdown Description of parameters\n", - "#@markdown - `SHOW_FULLSIZE` - Shows image in full size (may take a long time to load)\n", - "\n", - "interface = HiInterface(batch_size_seg=5,\n", - " batch_size_matting=1, \n", - " device='cuda' if torch.cuda.is_available() else 'cpu',\n", - " seg_mask_size=320, matting_mask_size=2048)\n", - "\n", - "\n", - "\n", - "\n", - "uploaded = files.upload().keys()\n", - "display.clear_output()\n", - "images = interface(uploaded)\n", - "for im in enumerate(images):\n", - " if not 
SHOW_FULLSIZE:\n", - " im[1].thumbnail((768, 768), resample=3)\n", - " display.display(im[1])\n", - "\n" - ], - "execution_count": 4, - "outputs": [ { - "output_type": "display_data", - "data": { - "text/plain": [ - "" + "cell_type": "code", + "metadata": { + "id": "rgm6pR6U22a9", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 529 + }, + "cellView": "form", + "outputId": "a908d208-0520-42ec-dbe0-c06e6c4ee260", + "pycharm": { + "name": "#%%\n" + } + }, + "source": [ + "#@title Upload images from your computer\n", + "#@markdown Description of parameters\n", + "#@markdown - `SHOW_FULLSIZE` - Shows image in full size (may take a long time to load)\n", + "#@markdown - `PREPROCESSING_METHOD` - Preprocessing method\n", + "#@markdown - `SEGMENTATION_NETWORK` - Segmentation network. Use `u2net` for hairs-like objects and `tracer_b7` for objects\n", + "#@markdown - `POSTPROCESSING_METHOD` - Postprocessing method\n", + "#@markdown - `SEGMENTATION_MASK_SIZE` - Segmentation mask size. 
Use 640 for Tracer B7 and 320 for U2Net\n", + "#@markdown - `TRIMAP_DILATION` - The size of the offset radius from the object mask in pixels when forming an unknown area\n", + "#@markdown - `TRIMAP_EROSION` - The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area\n", + "#@markdown > Look README.md and code for more details on networks and methods\n", + "\n", + "\n", + "import torch\n", + "from IPython import display\n", + "from google.colab import files\n", + "from carvekit.web.schemas.config import MLConfig\n", + "from carvekit.web.utils.init_utils import init_interface\n", + "\n", + "SHOW_FULLSIZE = False #@param {type:\"boolean\"}\n", + "PREPROCESSING_METHOD = \"none\" #@param [\"stub\", \"none\"]\n", + "SEGMENTATION_NETWORK = \"tracer_b7\" #@param [\"u2net\", \"deeplabv3\", \"basnet\", \"tracer_b7\"]\n", + "POSTPROCESSING_METHOD = \"fba\" #@param [\"fba\", \"none\"] \n", + "SEGMENTATION_MASK_SIZE = 640 #@param [\"640\", \"320\"] {type:\"raw\", allow-input: true}\n", + "TRIMAP_DILATION = 30 #@param {type:\"integer\"}\n", + "TRIMAP_EROSION = 5 #@param {type:\"integer\"}\n", + "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "\n", + "\n", + "config = MLConfig(segmentation_network=SEGMENTATION_NETWORK,\n", + " preprocessing_method=PREPROCESSING_METHOD,\n", + " postprocessing_method=POSTPROCESSING_METHOD,\n", + " seg_mask_size=SEGMENTATION_MASK_SIZE,\n", + " trimap_dilation=TRIMAP_DILATION,\n", + " trimap_erosion=TRIMAP_EROSION,\n", + " device=DEVICE)\n", + "\n", + "\n", + "interface = init_interface(config)\n", + "\n", + "\n", + "\n", + "\n", + "uploaded = files.upload().keys()\n", + "display.clear_output()\n", + "images = interface(uploaded)\n", + "for im in enumerate(images):\n", + " if not SHOW_FULLSIZE:\n", + " im[1].thumbnail((768, 768), resample=3)\n", + " display.display(im[1])\n", + "\n" ], - "image/png": "\n" - }, - "metadata": {} + "execution_count": 5, + "outputs": [ + { + 
"output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "image/png": "\n" + }, + "metadata": {} + } + ] } - ] - } - ] + ] } \ No newline at end of file diff --git a/docs/readme/ru.md b/docs/readme/ru.md index f938fd0..c001adc 100644 --- a/docs/readme/ru.md +++ b/docs/readme/ru.md @@ -5,7 +5,9 @@

+ +

********************************************************************** @@ -25,6 +27,7 @@ - Высокое качество выходного изображения - Пакетная обработка изображений - Поддержка NVIDIA CUDA и процессорной обработки +- Поддержка FP16: быстрая обработка с низким потреблением памяти - Легкое взаимодействие и запуск - 100% совместимое с remove.bg API FastAPI HTTP API - Удаляет фон с волос @@ -37,12 +40,14 @@ 2. Происходит предобработка фотографии для обеспечения лучшего качества выходного изображения 3. С помощью технологии машинного обучения убирается фон у изображения 4. Происходит постобработка изображения для улучшения качества обработанного изображения -## 🎓 Интегрированные нейронные сети: -* [U^2-net](https://github.com/NathanUA/U-2-Net) -* [BASNet](https://github.com/NathanUA/BASNet) -* [DeepLabV3](https://github.com/tensorflow/models/tree/master/research/deeplab) - - +## 🎓 Реализованные нейронные сети: +| Нейронные сети | Целевая область | Точность | +|:--------------:|:--------------------------------------------:|:--------------------------------:| +| **Tracer-B7** | **Общий** (objects, people, animals, etc) | **90%** (mean F1-Score, DUTS-TE) | +| U^2-net | **Волосы** (hairs, people, animals, objects) | 80% (mean F1-Score, DUTS-TE) | +| BASNet | **Общий** (people, objects) | 80% (mean F1-Score, DUTS-TE) | +| DeepLabV3 | People, Animals, Cars, etc | 67.4% (mean IoU, COCO val2017) | +> Используйте U2-Net для волос и Tracer-B7 для обычных изображений. ## 🖼️ Методы предварительной обработки и постобработки изображений: ### 🔍 Методы предобработки: * `none` - методы предобработки не используются. 
@@ -67,12 +72,21 @@ import torch from carvekit.api.high import HiInterface -interface = HiInterface(batch_size_seg=5, batch_size_matting=1, - device='cuda' if torch.cuda.is_available() else 'cpu', - seg_mask_size=320, matting_mask_size=2048) -images_without_background = interface(['./tests/data/cat.jpg']) +# Check doc strings for more information +interface = HiInterface(object_type="hairs-like", # Can be "object" or "hairs-like". + batch_size_seg=5, + batch_size_matting=1, + device='cuda' if torch.cuda.is_available() else 'cpu', + seg_mask_size=640, + matting_mask_size=2048, + trimap_prob_threshold=231, + trimap_dilation=30, + trimap_erosion_iters=5, + fp16=False) +images_without_background = interface(['./tests/data/cat.jpg']) cat_wo_bg = images_without_background[0] cat_wo_bg.save('2.png') + ``` @@ -82,12 +96,13 @@ import PIL.Image from carvekit.api.interface import Interface from carvekit.ml.wrap.fba_matting import FBAMatting -from carvekit.ml.wrap.u2net import U2NET +from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 from carvekit.pipelines.postprocessing import MattingMethod from carvekit.pipelines.preprocessing import PreprocessingStub from carvekit.trimap.generator import TrimapGenerator -u2net = U2NET(device='cpu', +# Check doc strings for more information +seg_net = TracerUniversalB7(device='cpu', batch_size=1) fba = FBAMatting(device='cpu', @@ -104,7 +119,7 @@ postprocessing = MattingMethod(matting_module=fba, interface = Interface(pre_pipe=preprocessing, post_pipe=postprocessing, - seg_pipe=u2net) + seg_pipe=seg_net) image = PIL.Image.open('tests/data/cat.jpg') cat_wo_bg = interface([image])[0] @@ -143,14 +158,30 @@ Options: --matting_mask_size 2048 Размер исходного изображения для матирующей нейронной сети - --device cpu Обрабатывающий девайс - --help Показывает это сообщение + --trimap_dilation 30 Размер радиуса смещения от маски объекта в пикселях при + формировании неизвестной области + + --trimap_erosion 5 Количество итераций эрозии, которым 
будет подвергаться маска + объекта перед формированием неизвестной области. + + --trimap_prob_threshold 231 Порог вероятности, при котором будут применяться + операции prob_filter и prob_as_unknown_area + + --device cpu Устройство обработки. + + --fp16 Включает обработку со смешанной точностью. + Используйте только с CUDA. Поддержка процессора является экспериментальной! + + --help Показать это сообщение и выйти. ```` ## 📦 Запустить фреймворк / FastAPI HTTP API сервер с помощью Docker: Использование API через Docker — это **быстрый** и эффективный способ получить работающий API.\ -**Этот HTTP API на 100% совместим с API клиентами сайта remove.bg** +> Наши образы Docker доступны на [Docker Hub](https://hub.docker.com/r/anodev/carvekit). \ +> Теги версий совпадают с релизами проекта с суффиксами `-cpu` и `-cuda` для версий CPU и CUDA соответственно. + +

diff --git a/requirements_dev.txt b/requirements_dev.txt new file mode 100644 index 0000000..eb6a8d8 --- /dev/null +++ b/requirements_dev.txt @@ -0,0 +1,2 @@ +pre-commit==2.20.0 + diff --git a/setup.py b/setup.py index 42c6302..489e60c 100644 --- a/setup.py +++ b/setup.py @@ -15,12 +15,12 @@ def read(filename: str): filepath = os.path.join(os.path.dirname(__file__), filename) - file = open(filepath, 'r', encoding='utf-8') + file = open(filepath, "r", encoding="utf-8") return file.read() def req_file(filename: str, folder: str = "."): - with open(os.path.join(folder, filename), encoding='utf-8') as f: + with open(os.path.join(folder, filename), encoding="utf-8") as f: content = f.readlines() # you may also want to remove whitespace characters # Example: `\n` at the end of each line @@ -30,35 +30,42 @@ def req_file(filename: str, folder: str = "."): setup( - name='carvekit' if IS_COLAB_PACKAGE is None else 'carvekit_colab', + name="carvekit" if IS_COLAB_PACKAGE is None else "carvekit_colab", version=version, author="Nikita Selin (Anodev)", - author_email='farvard34@gmail.com', - description='Open-Source background removal framework', - long_description=read('README.md'), + author_email="farvard34@gmail.com", + description="Open-Source background removal framework", + long_description=read("README.md"), long_description_content_type="text/markdown", - license='Apache License v2.0', - keywords=["ml", "carvekit", "background removal", "neural networks", "machine learning", "remove bg"], - url='https://github.com/OPHoperHPO/image-background-remove-tool', + license="Apache License v2.0", + keywords=[ + "ml", + "carvekit", + "background removal", + "neural networks", + "machine learning", + "remove bg", + ], + url="https://github.com/OPHoperHPO/image-background-remove-tool", packages=find_packages(), scripts=[], install_requires=req_file("requirements.txt"), include_package_data=True, zip_safe=False, entry_points={ - 'console_scripts': [ - 
'carvekit=carvekit:__main__.removebg', + "console_scripts": [ + "carvekit=carvekit:__main__.removebg", ], }, python_requires=">=3.8" if IS_COLAB_PACKAGE is None else ">=3.6", classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Environment :: Web Environment', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: Apache Software License', - 'Natural Language :: English', - 'Operating System :: OS Independent', - 'Programming Language :: Python :: 3.10', - 'Topic :: Scientific/Engineering :: Artificial Intelligence' + "Development Status :: 5 - Production/Stable", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Natural Language :: English", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering :: Artificial Intelligence", ], ) diff --git a/tests/test_basnet.py b/tests/test_basnet.py index 5ab6667..1bf28fb 100644 --- a/tests/test_basnet.py +++ b/tests/test_basnet.py @@ -11,23 +11,41 @@ def test_init(): - BASNET(input_tensor_size=[320, 320], load_pretrained=True) + BASNET(input_image_size=[320, 320], load_pretrained=True) BASNET(load_pretrained=False) def test_preprocessing(basnet_model, converted_pil_image, black_image_pil): - basnet_model = basnet_model() - assert isinstance(basnet_model.data_preprocessing(converted_pil_image), torch.FloatTensor) is True - assert isinstance(basnet_model.data_preprocessing(black_image_pil), torch.FloatTensor) is True + basnet_model = basnet_model(False) + assert ( + isinstance( + basnet_model.data_preprocessing(converted_pil_image), torch.FloatTensor + ) + is True + ) + assert ( + isinstance(basnet_model.data_preprocessing(black_image_pil), torch.FloatTensor) + is True + ) def test_postprocessing(basnet_model, converted_pil_image, black_image_pil): - basnet_model = basnet_model() - assert isinstance(basnet_model.data_postprocessing(torch.ones((1, 320, 320), 
dtype=torch.float64), - converted_pil_image), Image.Image) + basnet_model = basnet_model(False) + assert isinstance( + basnet_model.data_postprocessing( + torch.ones((1, 320, 320), dtype=torch.float64), converted_pil_image + ), + Image.Image, + ) def test_seg(basnet_model, image_pil, image_str, image_path, black_image_pil): - basnet_model = basnet_model() + basnet_model = basnet_model(False) + basnet_model([image_pil]) + basnet_model([image_pil, image_str, image_path, black_image_pil]) + + +def test_seg_fp12(basnet_model, image_pil, image_str, image_path, black_image_pil): + basnet_model = basnet_model(True) basnet_model([image_pil]) basnet_model([image_pil, image_str, image_path, black_image_pil]) diff --git a/tests/test_deeplabv3.py b/tests/test_deeplabv3.py index eeff0f6..6df089e 100644 --- a/tests/test_deeplabv3.py +++ b/tests/test_deeplabv3.py @@ -12,24 +12,45 @@ def test_init(): DeepLabV3(load_pretrained=True) - DeepLabV3(load_pretrained=False).to('cpu') + DeepLabV3(load_pretrained=False).to("cpu") DeepLabV3(input_image_size=[128, 256]) - def test_preprocessing(deeplabv3_model, converted_pil_image, black_image_pil): - deeplabv3_model = deeplabv3_model() - assert isinstance(deeplabv3_model.data_preprocessing(converted_pil_image), torch.FloatTensor) is True - assert isinstance(deeplabv3_model.data_preprocessing(black_image_pil), torch.FloatTensor) is True + deeplabv3_model = deeplabv3_model(False) + assert ( + isinstance( + deeplabv3_model.data_preprocessing(converted_pil_image), torch.FloatTensor + ) + is True + ) + assert ( + isinstance( + deeplabv3_model.data_preprocessing(black_image_pil), torch.FloatTensor + ) + is True + ) def test_postprocessing(deeplabv3_model, converted_pil_image, black_image_pil): - deeplabv3_model = deeplabv3_model() - assert isinstance(deeplabv3_model.data_postprocessing(torch.ones((320, 320), dtype=torch.float64), - converted_pil_image), Image.Image) + deeplabv3_model = deeplabv3_model(False) + assert isinstance( + 
deeplabv3_model.data_postprocessing( + torch.ones((320, 320), dtype=torch.float64), converted_pil_image + ), + Image.Image, + ) def test_seg(deeplabv3_model, image_pil, image_str, image_path, black_image_pil): - deeplabv3_model = deeplabv3_model() + deeplabv3_model = deeplabv3_model(False) + deeplabv3_model([image_pil]) + deeplabv3_model([image_pil, image_str, image_path, black_image_pil]) + + +def test_seg_with_fp12( + deeplabv3_model, image_pil, image_str, image_path, black_image_pil +): + deeplabv3_model = deeplabv3_model(True) deeplabv3_model([image_pil]) deeplabv3_model([image_pil, image_str, image_path, black_image_pil]) diff --git a/tests/test_fba.py b/tests/test_fba.py index b212ea6..a36f69a 100644 --- a/tests/test_fba.py +++ b/tests/test_fba.py @@ -18,39 +18,104 @@ def test_init(): def test_preprocessing(fba_model, converted_pil_image, black_image_pil, image_mask): - fba_model = fba_model() - assert isinstance(fba_model.data_preprocessing(converted_pil_image)[0], torch.FloatTensor) is True - assert isinstance(fba_model.data_preprocessing(black_image_pil)[0], torch.FloatTensor) is True - assert isinstance(fba_model.data_preprocessing(image_mask)[0], torch.FloatTensor) is True + fba_model = fba_model(False) + assert ( + isinstance( + fba_model.data_preprocessing(converted_pil_image)[0], torch.FloatTensor + ) + is True + ) + assert ( + isinstance(fba_model.data_preprocessing(black_image_pil)[0], torch.FloatTensor) + is True + ) + assert ( + isinstance(fba_model.data_preprocessing(image_mask)[0], torch.FloatTensor) + is True + ) with pytest.raises(ValueError): - assert isinstance(fba_model.data_preprocessing(Image.new("P", (512, 512)))[0], torch.FloatTensor) is True - fba_model = FBAMatting(device='cuda' if torch.cuda.is_available() else 'cpu', - input_tensor_size=1024, - batch_size=1, - load_pretrained=True) - assert isinstance(fba_model.data_preprocessing(converted_pil_image)[0], torch.FloatTensor) is True - assert 
isinstance(fba_model.data_preprocessing(black_image_pil)[0], torch.FloatTensor) is True - assert isinstance(fba_model.data_preprocessing(image_mask)[0], torch.FloatTensor) is True + assert ( + isinstance( + fba_model.data_preprocessing(Image.new("P", (512, 512)))[0], + torch.FloatTensor, + ) + is True + ) + fba_model = FBAMatting( + device="cuda" if torch.cuda.is_available() else "cpu", + input_tensor_size=1024, + batch_size=1, + load_pretrained=True, + ) + assert ( + isinstance( + fba_model.data_preprocessing(converted_pil_image)[0], torch.FloatTensor + ) + is True + ) + assert ( + isinstance(fba_model.data_preprocessing(black_image_pil)[0], torch.FloatTensor) + is True + ) + assert ( + isinstance(fba_model.data_preprocessing(image_mask)[0], torch.FloatTensor) + is True + ) with pytest.raises(ValueError): - assert isinstance(fba_model.data_preprocessing(Image.new("P", (512, 512)))[0], torch.FloatTensor) is True + assert ( + isinstance( + fba_model.data_preprocessing(Image.new("P", (512, 512)))[0], + torch.FloatTensor, + ) + is True + ) def test_postprocessing(fba_model, converted_pil_image, black_image_pil): - fba_model = fba_model() - assert isinstance(fba_model.data_postprocessing(torch.ones((7, 320, 320), dtype=torch.float64), - black_image_pil.convert("L")), Image.Image) + fba_model = fba_model(False) + assert isinstance( + fba_model.data_postprocessing( + torch.ones((7, 320, 320), dtype=torch.float64), black_image_pil.convert("L") + ), + Image.Image, + ) with pytest.raises(ValueError): - assert isinstance(fba_model.data_postprocessing(torch.ones((7, 320, 320), dtype=torch.float64), - black_image_pil.convert("RGBA")), Image.Image) + assert isinstance( + fba_model.data_postprocessing( + torch.ones((7, 320, 320), dtype=torch.float64), + black_image_pil.convert("RGBA"), + ), + Image.Image, + ) -def test_seg(fba_model, image_pil, image_str, image_path, black_image_pil, image_trimap): - fba_model = fba_model() +def test_seg( + fba_model, image_pil, image_str, 
image_path, black_image_pil, image_trimap +): + fba_model = fba_model(False) fba_model([image_pil], [image_trimap]) - fba_model([image_pil, image_str, image_path], - [image_trimap, image_trimap, image_trimap]) - fba_model([Image.new('RGB', (512, 512)), - Image.new('RGB', (512, 512))], [Image.new('L', (512, 512)), - Image.new('L', (512, 512))]) + fba_model( + [image_pil, image_str, image_path], [image_trimap, image_trimap, image_trimap] + ) + fba_model( + [Image.new("RGB", (512, 512)), Image.new("RGB", (512, 512))], + [Image.new("L", (512, 512)), Image.new("L", (512, 512))], + ) + with pytest.raises(ValueError): + fba_model([image_pil], [image_trimap, image_trimap]) + + +def test_seg_with_fp12( + fba_model, image_pil, image_str, image_path, black_image_pil, image_trimap +): + fba_model = fba_model(True) + fba_model([image_pil], [image_trimap]) + fba_model( + [image_pil, image_str, image_path], [image_trimap, image_trimap, image_trimap] + ) + fba_model( + [Image.new("RGB", (512, 512)), Image.new("RGB", (512, 512))], + [Image.new("L", (512, 512)), Image.new("L", (512, 512))], + ) with pytest.raises(ValueError): fba_model([image_pil], [image_trimap, image_trimap]) diff --git a/tests/test_fs_utils.py b/tests/test_fs_utils.py index 3f55772..edca14d 100644 --- a/tests/test_fs_utils.py +++ b/tests/test_fs_utils.py @@ -13,11 +13,21 @@ def test_save_file(): save_file(Path("output.png"), Path("input.png"), PIL.Image.new("RGB", (512, 512))) os.remove(Path("output.png")) - save_file(Path(__file__).parent.joinpath("data"), Path("input.png"), PIL.Image.new("RGB", (512, 512))) - os.remove(Path(__file__).parent.joinpath("data").joinpath('input.png')) + save_file( + Path(__file__).parent.joinpath("data"), + Path("input.png"), + PIL.Image.new("RGB", (512, 512)), + ) + os.remove(Path(__file__).parent.joinpath("data").joinpath("input.png")) save_file(Path("output.jpg"), Path("input.jpg"), PIL.Image.new("RGB", (512, 512))) os.remove(Path("output.png")) with pytest.raises(ValueError): - 
save_file(Path("NotExistedPath"), Path("input.png"), PIL.Image.new("RGB", (512, 512))) - save_file(output=None, input_path=Path("input.png"), image=PIL.Image.new("RGB", (512, 512))) - os.remove(Path("input_bg_removed.png")) \ No newline at end of file + save_file( + Path("NotExistedPath"), Path("input.png"), PIL.Image.new("RGB", (512, 512)) + ) + save_file( + output=None, + input_path=Path("input.png"), + image=PIL.Image.new("RGB", (512, 512)), + ) + os.remove(Path("input_bg_removed.png")) diff --git a/tests/test_high.py b/tests/test_high.py index c5c0d70..2f91331 100644 --- a/tests/test_high.py +++ b/tests/test_high.py @@ -8,9 +8,25 @@ def test_init(): - HiInterface(batch_size_seg=1, batch_size_matting=4, - device='cpu', - seg_mask_size=160, matting_mask_size=1024) - HiInterface(batch_size_seg=0, batch_size_matting=0, - device='cpu', - seg_mask_size=0, matting_mask_size=0) + HiInterface( + batch_size_seg=1, + batch_size_matting=4, + device="cpu", + seg_mask_size=160, + matting_mask_size=1024, + trimap_prob_threshold=1, + trimap_dilation=2, + trimap_erosion_iters=3, + fp16=False, + ) + HiInterface( + batch_size_seg=0, + batch_size_matting=0, + device="cpu", + seg_mask_size=0, + matting_mask_size=0, + trimap_prob_threshold=0, + trimap_dilation=0, + trimap_erosion_iters=0, + fp16=True, + ) diff --git a/tests/test_image_utils.py b/tests/test_image_utils.py index d661912..f498acd 100644 --- a/tests/test_image_utils.py +++ b/tests/test_image_utils.py @@ -10,8 +10,14 @@ import pytest import torch from PIL import Image -from carvekit.utils.image_utils import load_image, convert_image, is_image_valid, \ - to_tensor, transparency_paste, add_margin +from carvekit.utils.image_utils import ( + load_image, + convert_image, + is_image_valid, + to_tensor, + transparency_paste, + add_margin, +) def test_load_image(image_path, image_pil, image_str): @@ -25,34 +31,34 @@ def test_load_image(image_path, image_pil, image_str): def test_is_image_valid(image_path, image_pil, image_str): 
assert is_image_valid(image_path) is True - assert is_image_valid(image_path.with_suffix('.JPG')) is True + assert is_image_valid(image_path.with_suffix(".JPG")) is True with pytest.raises(ValueError): - is_image_valid(Path(uuid.uuid1().hex).with_suffix('.jpg')) + is_image_valid(Path(uuid.uuid1().hex).with_suffix(".jpg")) with pytest.raises(ValueError): is_image_valid(Path(__file__).parent) with pytest.raises(ValueError): - is_image_valid(image_path.with_suffix('.mp3')) + is_image_valid(image_path.with_suffix(".mp3")) with pytest.raises(ValueError): - is_image_valid(image_path.with_suffix('.MP3')) + is_image_valid(image_path.with_suffix(".MP3")) with pytest.raises(ValueError): is_image_valid(23) assert is_image_valid(image_pil) is True - assert is_image_valid(Image.new('RGB', (512, 512))) is True - assert is_image_valid(Image.new('L', (512, 512))) is True - assert is_image_valid(Image.new('RGBA', (512, 512))) is True + assert is_image_valid(Image.new("RGB", (512, 512))) is True + assert is_image_valid(Image.new("L", (512, 512))) is True + assert is_image_valid(Image.new("RGBA", (512, 512))) is True with pytest.raises(ValueError): - is_image_valid(Image.new('P', (512, 512))) + is_image_valid(Image.new("P", (512, 512))) with pytest.raises(ValueError): - is_image_valid(Image.new('RGB', (32, 10))) + is_image_valid(Image.new("RGB", (32, 10))) def test_convert_image(image_pil): with pytest.raises(ValueError): - convert_image(Image.new('L', (10, 10))) - assert convert_image(image_pil.convert('RGBA')).mode == "RGB" + convert_image(Image.new("L", (10, 10))) + assert convert_image(image_pil.convert("RGBA")).mode == "RGB" def test_to_tensor(image_pil): @@ -60,12 +66,27 @@ def test_to_tensor(image_pil): def test_transparency_paste(): - assert isinstance(transparency_paste(PIL.Image.new("RGBA", (1024, 1024)), - PIL.Image.new("RGBA", (1024, 1024))), PIL.Image.Image) - assert isinstance(transparency_paste(PIL.Image.new("RGBA", (512, 512)), - PIL.Image.new("RGBA", (512, 512))), 
PIL.Image.Image) + assert isinstance( + transparency_paste( + PIL.Image.new("RGBA", (1024, 1024)), PIL.Image.new("RGBA", (1024, 1024)) + ), + PIL.Image.Image, + ) + assert isinstance( + transparency_paste( + PIL.Image.new("RGBA", (512, 512)), PIL.Image.new("RGBA", (512, 512)) + ), + PIL.Image.Image, + ) def test_add_margin(): - assert isinstance(add_margin(PIL.Image.new("RGB", (512, 512)), - 10, 10, 10, 10, (10, 10, 10, 10)), PIL.Image.Image) is True + assert ( + isinstance( + add_margin( + PIL.Image.new("RGB", (512, 512)), 10, 10, 10, 10, (10, 10, 10, 10) + ), + PIL.Image.Image, + ) + is True + ) diff --git a/tests/test_interface.py b/tests/test_interface.py index 22f8a64..a34ca88 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -14,16 +14,20 @@ def test_init(available_models): models, pre_pipes, post_pipes = available_models devices = ["cpu", "cuda"] for model in models: - mdl = model() + mdl = model(False) for pre_pipe in pre_pipes: pre = pre_pipe() if pre_pipe is not None else pre_pipe for post_pipe in post_pipes: post = post_pipe() if post_pipe is not None else post_pipe for device in devices: if device == "cuda" and torch.cuda.is_available() is False: - warnings.warn('Cuda GPU is not available! Testing on cuda skipped!') + warnings.warn( + "Cuda GPU is not available! Testing on cuda skipped!" 
+ ) continue - inf = Interface(seg_pipe=mdl, post_pipe=post, pre_pipe=pre, device=device) + inf = Interface( + seg_pipe=mdl, post_pipe=post, pre_pipe=pre, device=device + ) del inf del post del pre @@ -33,13 +37,17 @@ def test_init(available_models): def test_seg(image_pil, image_str, image_path, available_models): models, pre_pipes, post_pipes = available_models for model in models: - mdl = model() + mdl = model(False) for pre_pipe in pre_pipes: pre = pre_pipe() if pre_pipe is not None else pre_pipe for post_pipe in post_pipes: post = post_pipe() if post_pipe is not None else post_pipe - interface = Interface(seg_pipe=mdl, post_pipe=post, pre_pipe=pre, - device='cuda' if torch.cuda.is_available() else 'cpu') + interface = Interface( + seg_pipe=mdl, + post_pipe=post, + pre_pipe=pre, + device="cuda" if torch.cuda.is_available() else "cpu", + ) interface([image_pil, image_str, image_path]) del post, interface del pre diff --git a/tests/test_mask_utils.py b/tests/test_mask_utils.py index 52790f7..a979874 100644 --- a/tests/test_mask_utils.py +++ b/tests/test_mask_utils.py @@ -9,15 +9,38 @@ def test_composite(): - assert isinstance(composite(PIL.Image.new("RGB", (512, 512)), PIL.Image.new("RGB", (512, 512)), - PIL.Image.new("RGB", (512, 512)), device="cpu"), PIL.Image.Image) is True + assert ( + isinstance( + composite( + PIL.Image.new("RGB", (512, 512)), + PIL.Image.new("RGB", (512, 512)), + PIL.Image.new("RGB", (512, 512)), + device="cpu", + ), + PIL.Image.Image, + ) + is True + ) def test_apply_mask(): - assert isinstance(apply_mask(image=PIL.Image.new("RGB", (512, 512)), mask=PIL.Image.new("RGB", (512, 512)), - device="cpu"), PIL.Image.Image) is True + assert ( + isinstance( + apply_mask( + image=PIL.Image.new("RGB", (512, 512)), + mask=PIL.Image.new("RGB", (512, 512)), + device="cpu", + ), + PIL.Image.Image, + ) + is True + ) def test_extract_alpha_channel(): - assert isinstance(extract_alpha_channel(PIL.Image.new("RGB", (512, 512))), PIL.Image.Image) is True - + 
assert ( + isinstance( + extract_alpha_channel(PIL.Image.new("RGB", (512, 512))), PIL.Image.Image + ) + is True + ) diff --git a/tests/test_models_utils.py b/tests/test_models_utils.py index fb0cb65..0f81409 100644 --- a/tests/test_models_utils.py +++ b/tests/test_models_utils.py @@ -6,9 +6,17 @@ import os import pytest from pathlib import Path -from carvekit.utils.download_models import check_for_exists, check_model, sha512_checksum_calc, download_model -from carvekit.ml.files.models_loc import u2net_full_pretrained, fba_pretrained, deeplab_pretrained, basnet_pretrained, \ - download_all +from carvekit.utils.download_models import sha512_checksum_calc +from carvekit.ml.files.models_loc import ( + u2net_full_pretrained, + fba_pretrained, + deeplab_pretrained, + basnet_pretrained, + download_all, + checkpoints_dir, + downloader, + tracer_b7_pretrained, +) from carvekit.utils.models_utils import fix_seed, suppress_warnings @@ -25,39 +33,40 @@ def test_download_all(): def test_download_model(): - hh = Path(__file__).parent.joinpath('data', 'u2net.pth') - hh.write_text('1234') - assert download_model(hh) == hh + hh = checkpoints_dir / "u2net-universal" / "u2net.pth" + hh.write_text("1234") + assert downloader("u2net.pth") == hh os.remove(hh) with pytest.raises(FileNotFoundError): - download_model(Path("NotExistedPath/2.dl")) - with pytest.raises(FileNotFoundError): - download_model(Path(__file__).parent.joinpath('data', 'cat.jpg')) + downloader("NotExistedPath/2.dl") def test_sha512(): - hh = Path(__file__).parent.joinpath('data', 'basnet.pth') - hh.write_text('1234') - assert sha512_checksum_calc(hh) == "d404559f602eab6fd602ac7680dacbfaadd13630335e951f097a" \ - "f3900e9de176b6db28512f2e000" \ - "b9d04fba5133e8b1c6e8df59db3a8ab9d60be4b97cc9e81db" + hh = checkpoints_dir / "basnet-universal" / "basnet.pth" + hh.write_text("1234") + assert ( + sha512_checksum_calc(hh) + == "d404559f602eab6fd602ac7680dacbfaadd13630335e951f097a" + "f3900e9de176b6db28512f2e000" + 
"b9d04fba5133e8b1c6e8df59db3a8ab9d60be4b97cc9e81db" + ) def test_check_model(): - invalid_hash_file = Path(__file__).parent.joinpath('data', 'basnet.pth') - invalid_hash_file.write_text('1234') - assert check_model(invalid_hash_file) is False - assert check_model(Path(__file__).parent.joinpath('data', 'u2net.pth')) is False - assert check_model(u2net_full_pretrained()) is True - assert check_model(Path("NotExistedPath/2.dl")) is False - with pytest.raises(FileNotFoundError): - assert check_model(Path(__file__).parent.joinpath('data', 'cat.jpg')) is False + invalid_hash_file = checkpoints_dir / "basnet-universal" / "basnet.pth" + invalid_hash_file.write_text("1234") + downloader("basnet.pth") + assert ( + sha512_checksum_calc(invalid_hash_file) + != "d404559f602eab6fd602ac7680dacbfaadd13630335e951f097a" + "f3900e9de176b6db28512f2e000" + "b9d04fba5133e8b1c6e8df59db3a8ab9d60be4b97cc9e81db" + ) def test_check_for_exists(): - assert isinstance(check_for_exists(u2net_full_pretrained()), Path) is True - assert isinstance(check_for_exists(fba_pretrained()), Path) is True - assert isinstance(check_for_exists(deeplab_pretrained()), Path) is True - assert isinstance(check_for_exists(basnet_pretrained()), Path) is True - with pytest.raises(FileNotFoundError): - check_for_exists(Path(__file__).parent.joinpath('data', 'cat.jpg')) + assert u2net_full_pretrained().exists() + assert fba_pretrained().exists() + assert deeplab_pretrained().exists() + assert basnet_pretrained().exists() + assert tracer_b7_pretrained().exists() diff --git a/tests/test_pool_utils.py b/tests/test_pool_utils.py index 07e0a8a..39cbaba 100644 --- a/tests/test_pool_utils.py +++ b/tests/test_pool_utils.py @@ -13,4 +13,4 @@ def test_thread_pool_processing(): def test_batch_generator(): assert list(batch_generator([1, 2, 3], n=1)) == [[1], [2], [3]] - assert list(batch_generator([1, 2, 3, 4], n=2)) == [[1, 2], [3, 4]] \ No newline at end of file + assert list(batch_generator([1, 2, 3, 4], n=2)) == [[1, 2], [3, 
4]] diff --git a/tests/test_postprocessing.py b/tests/test_postprocessing.py index 02a201e..f2f8128 100644 --- a/tests/test_postprocessing.py +++ b/tests/test_postprocessing.py @@ -8,7 +8,7 @@ def test_init(fba_model, trimap_instance): - fba_model = fba_model() + fba_model = fba_model(False) trimap_instance = trimap_instance() MattingMethod(fba_model, trimap_instance, "cpu") MattingMethod(fba_model, trimap_instance, device="cuda") @@ -16,6 +16,8 @@ def test_init(fba_model, trimap_instance): def test_seg(matting_method_instance, image_str, image_path, image_pil): matting_method_instance = matting_method_instance() - matting_method_instance(images=[image_str, image_path], masks=[image_pil, image_path]) + matting_method_instance( + images=[image_str, image_path], masks=[image_pil, image_path] + ) with pytest.raises(ValueError): matting_method_instance(images=[image_str], masks=[image_pil, image_path]) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 434b351..b898715 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -5,7 +5,9 @@ """ -def test_seg(preprocessing_stub_instance, image_str, image_path, image_pil, interface_instance): +def test_seg( + preprocessing_stub_instance, image_str, image_path, image_pil, interface_instance +): preprocessing_stub_instance = preprocessing_stub_instance() interface_instance = interface_instance() preprocessing_stub_instance(interface_instance, [image_str, image_path]) diff --git a/tests/test_tracer.py b/tests/test_tracer.py new file mode 100644 index 0000000..8b5a19f --- /dev/null +++ b/tests/test_tracer.py @@ -0,0 +1,48 @@ +import pytest +import torch +from PIL import Image + +from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 + + +def test_init(): + TracerUniversalB7(input_image_size=[640, 640], load_pretrained=True) + TracerUniversalB7(input_image_size=640, load_pretrained=True) + TracerUniversalB7(load_pretrained=False) + TracerUniversalB7(fp16=True) + + +def 
test_preprocessing(tracer_model, converted_pil_image, black_image_pil): + tracer_model = tracer_model(False) + assert ( + isinstance( + tracer_model.data_preprocessing(converted_pil_image), torch.FloatTensor + ) + is True + ) + assert ( + isinstance(tracer_model.data_preprocessing(black_image_pil), torch.FloatTensor) + is True + ) + + +def test_postprocessing(tracer_model, converted_pil_image, black_image_pil): + tracer_model = tracer_model(False) + assert isinstance( + tracer_model.data_postprocessing( + torch.ones((1, 640, 640), dtype=torch.float64), converted_pil_image + ), + Image.Image, + ) + + +def test_seg(tracer_model, image_pil, image_str, image_path, black_image_pil): + tracer_model = tracer_model(False) + tracer_model([image_pil]) + tracer_model([image_pil, image_str, image_path, black_image_pil]) + + +def test_seg_with_fp12(tracer_model, image_pil, image_str, image_path, black_image_pil): + tracer_model = tracer_model(True) + tracer_model([image_pil]) + tracer_model([image_pil, image_str, image_path, black_image_pil]) diff --git a/tests/test_trimap.py b/tests/test_trimap.py index 66b0112..47ba728 100644 --- a/tests/test_trimap.py +++ b/tests/test_trimap.py @@ -12,9 +12,17 @@ def test_trimap_generator(trimap_instance, image_mask, image_pil): te = trimap_instance() assert isinstance(te(image_pil, image_mask), PIL.Image.Image) - assert isinstance(te(PIL.Image.new("RGB", (512, 512)), PIL.Image.new("L", (512, 512))), PIL.Image.Image) - assert isinstance(te(PIL.Image.new("RGB", (512, 512), color=(255, 255, 255)), - PIL.Image.new("L", (512, 512), color=255)), PIL.Image.Image) + assert isinstance( + te(PIL.Image.new("RGB", (512, 512)), PIL.Image.new("L", (512, 512))), + PIL.Image.Image, + ) + assert isinstance( + te( + PIL.Image.new("RGB", (512, 512), color=(255, 255, 255)), + PIL.Image.new("L", (512, 512), color=255), + ), + PIL.Image.Image, + ) with pytest.raises(ValueError): te(PIL.Image.new("RGB", (512, 512)), PIL.Image.new("RGB", (512, 512))) with 
pytest.raises(ValueError): diff --git a/tests/test_u2net.py b/tests/test_u2net.py index 5960497..098386b 100644 --- a/tests/test_u2net.py +++ b/tests/test_u2net.py @@ -13,20 +13,22 @@ def test_init(): U2NET(layers_cfg="full", input_image_size=[320, 320], load_pretrained=True) - U2NET(layers_cfg='full', load_pretrained=False) - U2NET(layers_cfg={ - 'stage1': ['En_1', (7, 3, 32, 64), -1], - 'stage2': ['En_2', (6, 64, 32, 128), -1], - 'stage3': ['En_3', (5, 128, 64, 256), -1], - 'stage4': ['En_4', (4, 256, 128, 512), -1], - 'stage5': ['En_5', (4, 512, 256, 512, True), -1], - 'stage6': ['En_6', (4, 512, 256, 512, True), 512], - 'stage5d': ['De_5', (4, 1024, 256, 512, True), 512], - 'stage4d': ['De_4', (4, 1024, 128, 256), 256], - 'stage3d': ['De_3', (5, 512, 64, 128), 128], - 'stage2d': ['De_2', (6, 256, 32, 64), 64], - 'stage1d': ['De_1', (7, 128, 16, 64), 64], - }) + U2NET(layers_cfg="full", load_pretrained=False) + U2NET( + layers_cfg={ + "stage1": ["En_1", (7, 3, 32, 64), -1], + "stage2": ["En_2", (6, 64, 32, 128), -1], + "stage3": ["En_3", (5, 128, 64, 256), -1], + "stage4": ["En_4", (4, 256, 128, 512), -1], + "stage5": ["En_5", (4, 512, 256, 512, True), -1], + "stage6": ["En_6", (4, 512, 256, 512, True), 512], + "stage5d": ["De_5", (4, 1024, 256, 512, True), 512], + "stage4d": ["De_4", (4, 1024, 128, 256), 256], + "stage3d": ["De_3", (5, 512, 64, 128), 128], + "stage2d": ["De_2", (6, 256, 32, 64), 64], + "stage1d": ["De_1", (7, 128, 16, 64), 64], + } + ) with pytest.raises(ValueError): U2NET(layers_cfg="nan") with pytest.raises(ValueError): @@ -34,18 +36,36 @@ def test_init(): def test_preprocessing(u2net_model, converted_pil_image, black_image_pil): - u2net_model = u2net_model() - assert isinstance(u2net_model.data_preprocessing(converted_pil_image), torch.FloatTensor) is True - assert isinstance(u2net_model.data_preprocessing(black_image_pil), torch.FloatTensor) is True + u2net_model = u2net_model(False) + assert ( + isinstance( + 
u2net_model.data_preprocessing(converted_pil_image), torch.FloatTensor + ) + is True + ) + assert ( + isinstance(u2net_model.data_preprocessing(black_image_pil), torch.FloatTensor) + is True + ) def test_postprocessing(u2net_model, converted_pil_image, black_image_pil): - u2net_model = u2net_model() - assert isinstance(u2net_model.data_postprocessing(torch.ones((1, 320, 320), dtype=torch.float64), - converted_pil_image), Image.Image) + u2net_model = u2net_model(False) + assert isinstance( + u2net_model.data_postprocessing( + torch.ones((1, 320, 320), dtype=torch.float64), converted_pil_image + ), + Image.Image, + ) def test_seg(u2net_model, image_pil, image_str, image_path, black_image_pil): - u2net_model = u2net_model() + u2net_model = u2net_model(False) + u2net_model([image_pil]) + u2net_model([image_pil, image_str, image_path, black_image_pil]) + + +def test_seg_with_fp12(u2net_model, image_pil, image_str, image_path, black_image_pil): + u2net_model = u2net_model(True) u2net_model([image_pil]) u2net_model([image_pil, image_str, image_path, black_image_pil])