From 37df12f58b0fcfce02a56b2fce6e065a3c5552e0 Mon Sep 17 00:00:00 2001 From: tensorneko <22864465+tensorneko@users.noreply.github.com> Date: Mon, 20 Mar 2023 02:43:08 +0000 Subject: [PATCH 1/3] add support for py 3.11 and pytorch 2.0 --- cldm/logger.py | 4 ++-- ldm/models/diffusion/ddpm.py | 4 ++-- tutorial_train_sd21.py | 8 +++++--- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/cldm/logger.py b/cldm/logger.py index 6a8803846f..ebadddd19b 100644 --- a/cldm/logger.py +++ b/cldm/logger.py @@ -5,7 +5,7 @@ import torchvision from PIL import Image from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.rank_zero import rank_zero_only class ImageLogger(Callback): @@ -71,6 +71,6 @@ def log_img(self, pl_module, batch, batch_idx, split="train"): def check_frequency(self, check_idx): return check_idx % self.batch_freq == 0 - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): if not self.disabled: self.log_img(pl_module, batch, batch_idx, split="train") diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index f71a44af48..07501c7803 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -17,7 +17,7 @@ import itertools from tqdm import tqdm from torchvision.utils import make_grid -from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.rank_zero import rank_zero_only from omegaconf import ListConfig from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config @@ -588,7 +588,7 @@ def make_cond_schedule(self, ): @rank_zero_only @torch.no_grad() - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + def on_train_batch_start(self, batch, batch_idx): # only for very first batch if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' diff --git a/tutorial_train_sd21.py b/tutorial_train_sd21.py index 8bbc148f9b..9f9225b3fa 100644 --- a/tutorial_train_sd21.py +++ b/tutorial_train_sd21.py @@ -1,3 +1,4 @@ +import os from share import * import pytorch_lightning as pl @@ -14,6 +15,7 @@ learning_rate = 1e-5 sd_locked = True only_mid_control = False +workers = os.cpu_count() // 2 # First use cpu to load models. Pytorch Lightning will automatically move it to GPUs. @@ -26,10 +28,10 @@ # Misc dataset = MyDataset() -dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True) +dataloader = DataLoader(dataset, num_workers=workers, batch_size=batch_size, shuffle=True) logger = ImageLogger(batch_frequency=logger_freq) -trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger]) +trainer = pl.Trainer(accelerator='gpu', devices='auto', precision=32, callbacks=[logger]) # Train! -trainer.fit(model, dataloader) +trainer.fit(model, dataloader) \ No newline at end of file From b9b9998caa7bf2d34c7202551f8101e0ba6e3385 Mon Sep 17 00:00:00 2001 From: tensorneko <22864465+tensorneko@users.noreply.github.com> Date: Mon, 20 Mar 2023 02:58:07 +0000 Subject: [PATCH 2/3] conda env, needs pruning --- environment.yaml | 174 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 144 insertions(+), 30 deletions(-) diff --git a/environment.yaml b/environment.yaml index 91463f0fb1..7b15e3ddef 100644 --- a/environment.yaml +++ b/environment.yaml @@ -1,35 +1,149 @@ name: control channels: - - pytorch - defaults dependencies: - - python=3.8.5 - - pip=20.3 - - cudatoolkit=11.3 - - pytorch=1.12.1 - - torchvision=0.13.1 - - numpy=1.23.1 + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2023.01.10=h06a4308_0 + - certifi=2022.12.7=py311h06a4308_0 + - ld_impl_linux-64=2.38=h1181459_1 + - libffi=3.4.2=h6a678d5_6 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libstdcxx-ng=11.2.0=h1234567_1 + - libuuid=1.41.5=h5eee18b_0 + - ncurses=6.4=h6a678d5_0 + - openssl=1.1.1t=h7f8727e_0 + - pip=22.3.1=py311h06a4308_0 + - python=3.11.0=h7a1cb2a_3 + - readline=8.2=h5eee18b_0 + - setuptools=65.5.0=py311h06a4308_0 + - sqlite=3.41.1=h5eee18b_0 + - tk=8.6.12=h1ccaba5_0 + - wheel=0.37.1=pyhd3eb1b0_0 + - xz=5.2.10=h5eee18b_1 + - zlib=1.2.13=h5eee18b_0 - pip: - - gradio==3.16.2 - - albumentations==1.3.0 - - opencv-contrib-python==4.3.0.36 - - imageio==2.9.0 - - imageio-ffmpeg==0.4.2 - - pytorch-lightning==1.5.0 - - omegaconf==2.1.1 - - test-tube>=0.7.5 - - streamlit==1.12.1 - - einops==0.3.0 - - transformers==4.19.2 - - webdataset==0.2.5 - - kornia==0.6 - - open_clip_torch==2.0.2 - - invisible-watermark>=0.1.5 - - streamlit-drawable-canvas==0.8.0 - - torchmetrics==0.6.0 - - timm==0.6.12 - - addict==2.4.0 - - yapf==0.32.0 - - prettytable==3.6.0 - - safetensors==0.2.7 - - basicsr==1.4.2 + - aiofiles==23.1.0 + - aiohttp==3.8.4 + - aiosignal==1.3.1 + - albumentations==1.3.0 + - altair==4.2.2 + - antlr4-python3-runtime==4.9.3 + - anyio==3.6.2 + - async-timeout==4.0.2 + - attrs==22.2.0 + - blinker==1.5 + - cachetools==5.3.0 + - charset-normalizer==2.1.1 + - click==8.1.3 + - cmake==3.25.0 + - contourpy==1.0.7 + - cycler==0.11.0 + - decorator==5.1.1 + - einops==0.6.0 + - entrypoints==0.4 + - fastapi==0.95.0 + - ffmpy==0.3.0 + - filelock==3.9.0 + - fonttools==4.39.2 + - frozenlist==1.3.3 + - fsspec==2023.3.0 + - ftfy==6.1.1 + - gitdb==4.0.10 + - gitpython==3.1.31 + - gradio==3.22.1 + - h11==0.14.0 + - httpcore==0.16.3 + - httpx==0.23.3 + - huggingface-hub==0.13.2 + - idna==3.4 + - imageio==2.26.0 + - imageio-ffmpeg==0.4.8 + - importlib-metadata==6.1.0 + - jinja2==3.1.2 + - joblib==1.2.0 + - jsonschema==4.17.3 + - kiwisolver==1.4.4 + - lazy-loader==0.1 + - lightning-utilities==0.8.0 + - linkify-it-py==2.0.0 + - lit==15.0.7 + - markdown-it-py==2.2.0 + - markupsafe==2.1.2 + - matplotlib==3.7.1 + - mdit-py-plugins==0.3.3 + - mdurl==0.1.2 + - mpmath==1.2.1 + - multidict==6.0.4 + - networkx==3.0 + - numpy==1.24.1 + - omegaconf==2.3.0 + - open-clip-torch==2.16.0 + - opencv-contrib-python==4.7.0.72 + - opencv-python-headless==4.7.0.72 + - orjson==3.8.7 + - packaging==23.0 + - pandas==1.5.3 + - pillow==9.3.0 + - prettytable==3.6.0 + - protobuf==3.20.3 + - pyarrow==11.0.0 + - pydantic==1.10.6 + - pydeck==0.8.0 + - pydub==0.25.1 + - pygments==2.14.0 + - pympler==1.0.1 + - pyparsing==3.0.9 + - pyrsistent==0.19.3 + - python-dateutil==2.8.2 + - python-multipart==0.0.6 + - pytorch-lightning==2.0.0 + - pytz==2022.7.1 + - pytz-deprecation-shim==0.1.0.post0 + - pywavelets==1.4.1 + - pyyaml==6.0 + - qudida==0.0.4 + - regex==2022.10.31 + - requests==2.28.1 + - rfc3986==1.5.0 + - rich==13.3.2 + - safetensors==0.3.0 + - scikit-image==0.20.0 + - scikit-learn==1.2.2 + - scipy==1.10.1 + - semver==2.13.0 + - sentencepiece==0.1.97 + - six==1.16.0 + - smmap==5.0.0 + - sniffio==1.3.0 + - starlette==0.26.1 + - streamlit==1.20.0 + - sympy==1.11.1 + - threadpoolctl==3.1.0 + - tifffile==2023.3.15 + - timm==0.8.15.dev0 + - tokenizers==0.13.2 + - toml==0.10.2 + - toolz==0.12.0 + - torch==2.0.0+cu118 + - torchaudio==2.0.1+cu118 + - torchmetrics==0.11.4 + - torchvision==0.15.1+cu118 + - tornado==6.2 + - tqdm==4.65.0 + - transformers==4.27.1 + - triton==2.0.0 + - typing-extensions==4.4.0 + - tzdata==2022.7 + - tzlocal==4.3 + - uc-micro-py==1.0.1 + - urllib3==1.26.13 + - uvicorn==0.21.1 + - validators==0.20.0 + - watchdog==2.3.1 + - wcwidth==0.2.6 + - websockets==10.4 + - yarl==1.8.2 + - zipp==3.15.0 \ No newline at end of file From 43f1855959f2288a7df40fa914b128685a066601 Mon Sep 17 00:00:00 2001 From: tensorneko <22864465+tensorneko@users.noreply.github.com> Date: Wed, 22 Mar 2023 20:23:07 +0000 Subject: [PATCH 3/3] 1.8x speedup, ampere tensor precision --- tutorial_train_sd21.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tutorial_train_sd21.py b/tutorial_train_sd21.py index 9f9225b3fa..d439897818 100644 --- a/tutorial_train_sd21.py +++ b/tutorial_train_sd21.py @@ -1,12 +1,14 @@ import os from share import * +import torch import pytorch_lightning as pl from torch.utils.data import DataLoader from tutorial_dataset import MyDataset from cldm.logger import ImageLogger from cldm.model import create_model, load_state_dict +torch.set_float32_matmul_precision('medium') # Configs resume_path = './models/control_sd21_ini.ckpt'