Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for Python 3.11 and PyTorch 2.0 #295

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cldm/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torchvision
from PIL import Image
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.rank_zero import rank_zero_only


class ImageLogger(Callback):
Expand Down Expand Up @@ -71,6 +71,6 @@ def log_img(self, pl_module, batch, batch_idx, split="train"):
def check_frequency(self, check_idx):
return check_idx % self.batch_freq == 0

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if not self.disabled:
self.log_img(pl_module, batch, batch_idx, split="train")
174 changes: 144 additions & 30 deletions environment.yaml
Original file line number Diff line number Diff line change
@@ -1,35 +1,149 @@
name: control
channels:
- pytorch
- defaults
dependencies:
- python=3.8.5
- pip=20.3
- cudatoolkit=11.3
- pytorch=1.12.1
- torchvision=0.13.1
- numpy=1.23.1
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2023.01.10=h06a4308_0
- certifi=2022.12.7=py311h06a4308_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.4.2=h6a678d5_6
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libstdcxx-ng=11.2.0=h1234567_1
- libuuid=1.41.5=h5eee18b_0
- ncurses=6.4=h6a678d5_0
- openssl=1.1.1t=h7f8727e_0
- pip=22.3.1=py311h06a4308_0
- python=3.11.0=h7a1cb2a_3
- readline=8.2=h5eee18b_0
- setuptools=65.5.0=py311h06a4308_0
- sqlite=3.41.1=h5eee18b_0
- tk=8.6.12=h1ccaba5_0
- wheel=0.37.1=pyhd3eb1b0_0
- xz=5.2.10=h5eee18b_1
- zlib=1.2.13=h5eee18b_0
- pip:
- gradio==3.16.2
- albumentations==1.3.0
- opencv-contrib-python==4.3.0.36
- imageio==2.9.0
- imageio-ffmpeg==0.4.2
- pytorch-lightning==1.5.0
- omegaconf==2.1.1
- test-tube>=0.7.5
- streamlit==1.12.1
- einops==0.3.0
- transformers==4.19.2
- webdataset==0.2.5
- kornia==0.6
- open_clip_torch==2.0.2
- invisible-watermark>=0.1.5
- streamlit-drawable-canvas==0.8.0
- torchmetrics==0.6.0
- timm==0.6.12
- addict==2.4.0
- yapf==0.32.0
- prettytable==3.6.0
- safetensors==0.2.7
- basicsr==1.4.2
- aiofiles==23.1.0
- aiohttp==3.8.4
- aiosignal==1.3.1
- albumentations==1.3.0
- altair==4.2.2
- antlr4-python3-runtime==4.9.3
- anyio==3.6.2
- async-timeout==4.0.2
- attrs==22.2.0
- blinker==1.5
- cachetools==5.3.0
- charset-normalizer==2.1.1
- click==8.1.3
- cmake==3.25.0
- contourpy==1.0.7
- cycler==0.11.0
- decorator==5.1.1
- einops==0.6.0
- entrypoints==0.4
- fastapi==0.95.0
- ffmpy==0.3.0
- filelock==3.9.0
- fonttools==4.39.2
- frozenlist==1.3.3
- fsspec==2023.3.0
- ftfy==6.1.1
- gitdb==4.0.10
- gitpython==3.1.31
- gradio==3.22.1
- h11==0.14.0
- httpcore==0.16.3
- httpx==0.23.3
- huggingface-hub==0.13.2
- idna==3.4
- imageio==2.26.0
- imageio-ffmpeg==0.4.8
- importlib-metadata==6.1.0
- jinja2==3.1.2
- joblib==1.2.0
- jsonschema==4.17.3
- kiwisolver==1.4.4
- lazy-loader==0.1
- lightning-utilities==0.8.0
- linkify-it-py==2.0.0
- lit==15.0.7
- markdown-it-py==2.2.0
- markupsafe==2.1.2
- matplotlib==3.7.1
- mdit-py-plugins==0.3.3
- mdurl==0.1.2
- mpmath==1.2.1
- multidict==6.0.4
- networkx==3.0
- numpy==1.24.1
- omegaconf==2.3.0
- open-clip-torch==2.16.0
- opencv-contrib-python==4.7.0.72
- opencv-python-headless==4.7.0.72
- orjson==3.8.7
- packaging==23.0
- pandas==1.5.3
- pillow==9.3.0
- prettytable==3.6.0
- protobuf==3.20.3
- pyarrow==11.0.0
- pydantic==1.10.6
- pydeck==0.8.0
- pydub==0.25.1
- pygments==2.14.0
- pympler==1.0.1
- pyparsing==3.0.9
- pyrsistent==0.19.3
- python-dateutil==2.8.2
- python-multipart==0.0.6
- pytorch-lightning==2.0.0
- pytz==2022.7.1
- pytz-deprecation-shim==0.1.0.post0
- pywavelets==1.4.1
- pyyaml==6.0
- qudida==0.0.4
- regex==2022.10.31
- requests==2.28.1
- rfc3986==1.5.0
- rich==13.3.2
- safetensors==0.3.0
- scikit-image==0.20.0
- scikit-learn==1.2.2
- scipy==1.10.1
- semver==2.13.0
- sentencepiece==0.1.97
- six==1.16.0
- smmap==5.0.0
- sniffio==1.3.0
- starlette==0.26.1
- streamlit==1.20.0
- sympy==1.11.1
- threadpoolctl==3.1.0
- tifffile==2023.3.15
- timm==0.8.15.dev0
- tokenizers==0.13.2
- toml==0.10.2
- toolz==0.12.0
- torch==2.0.0+cu118
- torchaudio==2.0.1+cu118
- torchmetrics==0.11.4
- torchvision==0.15.1+cu118
- tornado==6.2
- tqdm==4.65.0
- transformers==4.27.1
- triton==2.0.0
- typing-extensions==4.4.0
- tzdata==2022.7
- tzlocal==4.3
- uc-micro-py==1.0.1
- urllib3==1.26.13
- uvicorn==0.21.1
- validators==0.20.0
- watchdog==2.3.1
- wcwidth==0.2.6
- websockets==10.4
- yarl==1.8.2
- zipp==3.15.0
4 changes: 2 additions & 2 deletions ldm/models/diffusion/ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import itertools
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from omegaconf import ListConfig

from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
Expand Down Expand Up @@ -588,7 +588,7 @@ def make_cond_schedule(self, ):

@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
def on_train_batch_start(self, batch, batch_idx):
# only for very first batch
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
Expand Down
10 changes: 7 additions & 3 deletions tutorial_train_sd21.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import os
from share import *

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from tutorial_dataset import MyDataset
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict

torch.set_float32_matmul_precision('medium')

# Configs
resume_path = './models/control_sd21_ini.ckpt'
Expand All @@ -14,6 +17,7 @@
learning_rate = 1e-5
sd_locked = True
only_mid_control = False
workers = os.cpu_count() // 2


# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
Expand All @@ -26,10 +30,10 @@

# Misc
dataset = MyDataset()
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
dataloader = DataLoader(dataset, num_workers=workers, batch_size=batch_size, shuffle=True)
logger = ImageLogger(batch_frequency=logger_freq)
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])
trainer = pl.Trainer(accelerator='gpu', devices='auto', precision=32, callbacks=[logger])


# Train!
trainer.fit(model, dataloader)
trainer.fit(model, dataloader)