You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Hi all, I have already trained a diffusion model with class conditioning (via embedding layers) for weakly-supervised anomaly detection (using classifier-free guidance), following the tutorial here. It was trained for 1000 epochs.
I want to take the trained weights of this diffusion model and its embedding layers and train a ControlNet model conditioned on both the class and a boundary mask. The tutorial here trains a ControlNet model from the pretrained weights of a diffusion model checkpoint, but it doesn't implement a class-conditional ControlNet. I want my ControlNet to have both class and boundary-mask conditioning. I used the following training and inference code and trained the ControlNet for 100 epochs, but I was not able to detect any anomalies (although the initial diffusion model was able to detect anomalies). So, I would appreciate it if someone could verify my code before I train the ControlNet for more epochs (since training takes several days on my machines):
Training code (DDP implementation):
#%%
import os
import time
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import sys
from monai.utils import set_determinism
from torch.amp import GradScaler, autocast
from tqdm import tqdm
import argparse
import pandas as pd
from generative.inferers import DiffusionInferer, ControlNetDiffusionInferer
from generative.networks.nets import DiffusionModelUNet, ControlNet
from generative.networks.schedulers import DDIMScheduler, DDPMScheduler
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from get_data import get_train_valid_datasets_for_controlnet
torch.multiprocessing.set_sharing_strategy("file_system")
from monai.data import DataLoader, CacheDataset
torch.backends.cudnn.benchmark = True
WORKING_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
sys.path.append(WORKING_DIR)
from utils.utils import pad_zeros_at_front, str2bool
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
#%%
def ddp_setup():
    """Create the default distributed process group and fix the random seed.

    The group uses the NCCL backend and the ``env://`` init method, so the
    rendezvous settings are expected to come from environment variables set
    by the launcher. The seed is then fixed to 42 through MONAI's
    ``set_determinism``; every process that calls this uses the same seed.
    """
    dist.init_process_group(backend="nccl", init_method="env://")
    set_determinism(seed=42)
def main_worker(models_dir, logs_dir, args):
    """Train a class- and boundary-mask-conditioned ControlNet with DDP.

    A pretrained, class-conditional ``DiffusionModelUNet`` and its class
    ``Embedding`` are loaded from a checkpoint. The UNet is frozen; the
    ControlNet (initialised from the UNet weights) and the embedding are
    optimised with a noise-prediction MSE loss under mixed precision.

    Args:
        models_dir: Directory where checkpoints ``chkpt_epXXXXX.pth`` are written.
        logs_dir: Directory where the per-rank ``trainlog_gpu<rank>.csv`` is written.
        args: Namespace providing ``cache_rate``, ``batch_size``, ``num_workers``,
            ``epochs`` and ``val_interval``.

    Notes:
        Each batch is read through the keys ``'PT'`` (image), ``'Label'``
        (class index) and ``'BoundaryMask'`` (ControlNet condition).
        Label 0 doubles as the "unconditional" class: it is the embedding's
        ``padding_idx`` and the value produced by condition dropout.
    """
    # init_process_group
    ddp_setup()

    # The global rank identifies the process (logging, checkpoint ownership);
    # the local rank selects the GPU on this node. Launchers such as torchrun
    # export LOCAL_RANK; fall back to the global rank (single-node behaviour).
    global_rank = int(dist.get_rank())
    local_rank = int(os.environ.get("LOCAL_RANK", global_rank))
    is_main_process = global_rank == 0

    if is_main_process:
        print(f"The models will be saved in {models_dir}")
        print(f"The training/validation logs will be saved in {logs_dir}")

    dataset_train, dataset_valid = get_train_valid_datasets_for_controlnet(args.cache_rate)
    # Validation data is folded into the training set (no validation loop here).
    dataset_train = dataset_train + dataset_valid
    print(f'Total number of images: {len(dataset_train)}')

    sampler_train = DistributedSampler(dataset=dataset_train, shuffle=True)
    dataloader_train = DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        pin_memory=True,
        shuffle=False,  # shuffling is delegated to the DistributedSampler
        sampler=sampler_train,
        num_workers=args.num_workers,
    )

    trainlog_fpath = os.path.join(logs_dir, f'trainlog_gpu{global_rank}.csv')
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    embedding_dimension = 64
    model = DiffusionModelUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        num_channels=(64, 64, 64),
        attention_levels=(False, True, True),
        num_res_blocks=1,
        num_head_channels=16,
        with_conditioning=True,
        cross_attention_dim=embedding_dimension,
    ).to(device)
    embed = torch.nn.Embedding(num_embeddings=3, embedding_dim=embedding_dimension, padding_idx=0).to(device)

    diffusion_model_path = '/data/blobfuse/diffusion2d_wsad/results/ddpm_ig/AttnLayers_011/models/chkpt_ep01000.pth'
    state_dict = torch.load(diffusion_model_path, map_location=device)
    model.load_state_dict(state_dict['model_state_dict'])
    embed.load_state_dict(state_dict['embed_state_dict'])

    controlnet = ControlNet(
        spatial_dims=2,
        in_channels=1,
        num_channels=(64, 64, 64),
        attention_levels=(False, True, True),
        num_res_blocks=1,
        num_head_channels=16,
        conditioning_embedding_num_channels=(16,),
        with_conditioning=True,
        cross_attention_dim=embedding_dimension,
    ).to(device)
    # Start the ControlNet from the UNet weights; strict=False because only
    # the layers the two architectures share can be copied.
    controlnet.load_state_dict(model.state_dict(), strict=False)

    scheduler = DDIMScheduler(num_train_timesteps=1000)

    # The diffusion UNet stays frozen and in eval mode for the whole run.
    for p in model.parameters():
        p.requires_grad = False
    model.eval()

    optimizer = torch.optim.Adam(params=list(controlnet.parameters()) + list(embed.parameters()), lr=1e-5)
    controlnet_inferer = ControlNetDiffusionInferer(scheduler)
    start_epoch = 0

    # Only modules with trainable parameters are wrapped in DDP.
    embed = DDP(embed, device_ids=[device])
    controlnet = DDP(controlnet, device_ids=[device])

    condition_dropout = 0.15
    n_epochs = args.epochs
    val_interval = args.val_interval
    train_epoch_loss_list = []
    scaler = GradScaler()
    num_batches = len(dataloader_train)

    experiment_start_time = time.time()
    for epoch in range(n_epochs):
        epoch_start_time = time.time()
        controlnet.train()
        embed.train()
        epoch_loss = 0
        # Re-seed the sampler so every epoch sees a different shuffle.
        sampler_train.set_epoch(epoch)
        progress_bar = tqdm(enumerate(dataloader_train), total=num_batches, ncols=80)
        progress_bar.set_description(f"Epoch {start_epoch + epoch + 1}")

        for step, batch in progress_bar:
            images = batch['PT'].to(device)
            classes = batch['Label'].to(device)
            boundarymasks = batch['BoundaryMask'].to(device)

            # Classifier-free guidance: randomly replace the label with 0.
            # torch.rand (not rand_like) so integer label tensors also work.
            keep_mask = torch.rand(classes.shape, device=device) > condition_dropout
            classes = classes * keep_mask
            class_embedding = embed(classes.long()).unsqueeze(1)

            optimizer.zero_grad(set_to_none=True)
            timesteps = torch.randint(
                0, scheduler.num_train_timesteps, (len(images),), device=device
            )

            with autocast('cuda'):
                noise = torch.randn_like(images)
                noise_pred = controlnet_inferer(
                    inputs=images,
                    diffusion_model=model,
                    controlnet=controlnet,
                    noise=noise,
                    timesteps=timesteps,
                    condition=class_embedding,
                    cn_cond=boundarymasks,
                )
                loss = F.mse_loss(noise_pred.float(), noise.float())

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += loss.item()
            progress_bar.set_postfix({f"GPU[{global_rank}]: Train Loss": round(epoch_loss / (step + 1), 6)})

        # Mean loss over the epoch; max() guards against an empty loader.
        train_epoch_loss_list.append(epoch_loss / max(num_batches, 1))
        trainlog_df = pd.DataFrame(train_epoch_loss_list, columns=['Loss'])
        trainlog_df.to_csv(trainlog_fpath, index=False)

        # Only one process writes the checkpoint; otherwise every rank would
        # write the same file concurrently.
        if (epoch + 1) % val_interval == 0 and is_main_process:
            saved_dict = {
                'controlnet_state_dict': controlnet.module.state_dict(),
                'embed_state_dict': embed.module.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            checkpoint_fpath = os.path.join(models_dir, f'chkpt_ep{pad_zeros_at_front(start_epoch+epoch+1, 5)}.pth')
            torch.save(saved_dict, checkpoint_fpath)

        epoch_end_time = (time.time() - epoch_start_time)
        print(f"[GPU:{global_rank}]: Epoch {start_epoch + epoch + 1} time: {round(epoch_end_time,2)} sec")

    experiment_end_time = (time.time() - experiment_start_time)/(60)
    print(f"[GPU:{global_rank}]: Total time: {round(experiment_end_time,2)} min")
    print('Destroying process')
    dist.destroy_process_group()
    print('Destroyed process')
reacted with thumbs up emoji reacted with thumbs down emoji reacted with laugh emoji reacted with hooray emoji reacted with confused emoji reacted with heart emoji reacted with rocket emoji reacted with eyes emoji
-
Hi all, I have already trained a diffusion model with class conditioning (via embedding layers) for weakly-supervised anomaly detection (using classifier-free guidance), following the tutorial here. It was trained for 1000 epochs.
I want to take the trained weights of this diffusion model and its embedding layers and train a ControlNet model conditioned on both the class and a boundary mask. The tutorial here trains a ControlNet model from the pretrained weights of a diffusion model checkpoint, but it doesn't implement a class-conditional ControlNet. I want my ControlNet to have both class and boundary-mask conditioning. I used the following training and inference code and trained the ControlNet for 100 epochs, but I was not able to detect any anomalies (although the initial diffusion model was able to detect anomalies). So, I would appreciate it if someone could verify my code before I train the ControlNet for more epochs (since training takes several days on my machines):
Training code (DDP implementation):
Inference code:
Thanks in advance and Happy holidays!
Beta Was this translation helpful? Give feedback.
All reactions