# vaesweep.py
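"""Weights & Biases hyperparameter sweep for a convolutional VAE trained on
the UTKFace dataset with pythae's TrainingPipeline."""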
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import random_split
from pythae.models import VAE, VAEConfig
from pythae.models.nn import BaseEncoder, BaseDecoder
from pythae.models.base.base_utils import ModelOutput
from pythae.trainers.training_callbacks import WandbCallback
from pythae.data.datasets import DatasetOutput
from pythae.trainers import BaseTrainerConfig
from pythae.pipelines.training import TrainingPipeline


class UTKFaceDataset(datasets.ImageFolder):
    """ImageFolder wrapper that yields pythae `DatasetOutput` objects."""

    def __init__(self, root, transform=None, target_transform=None):
        super().__init__(root=root, transform=transform, target_transform=target_transform)

    def __getitem__(self, index):
        # Discard the folder-derived label; the VAE only needs the image.
        X, _ = super().__getitem__(index)
        return DatasetOutput(data=X)


class UTKFace_Encoder(BaseEncoder):
    """Three stride-2 convolutions (200 -> 100 -> 50 -> 25), followed by linear
    heads for the mean and log-variance of the approximate posterior."""

    def __init__(self, lat_dim):
        super(UTKFace_Encoder, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.fc_mu = nn.Linear(128 * 25 * 25, lat_dim)
        self.fc_logvar = nn.Linear(128 * 25 * 25, lat_dim)

    def forward(self, x: torch.Tensor) -> ModelOutput:
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)  # flatten to (batch, 128 * 25 * 25)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return ModelOutput(
            embedding=mu,           # mean
            log_covariance=logvar,  # log variance
        )


class UTKFace_Decoder(BaseDecoder):
    """Mirror of the encoder: project the latent vector back to a 25x25 feature
    map, then upsample with three stride-2 transposed convolutions
    (25 -> 50 -> 100 -> 200)."""

    def __init__(self, lat_dim):
        super(UTKFace_Decoder, self).__init__()
        self.fc = nn.Linear(lat_dim, 128 * 25 * 25)
        self.deconv1 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> ModelOutput:
        x = self.fc(x)
        x = x.view(-1, 128, 25, 25)  # reshape to a spatial feature map
        x = F.relu(self.deconv1(x))
        x = F.relu(self.deconv2(x))
        x = torch.sigmoid(self.deconv3(x))  # pixel values in [0, 1]
        return ModelOutput(reconstruction=x)
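

# Quick shape sanity check (a minimal sketch; `lat_dim=128` is an arbitrary
# value inside the sweep's 100-200 range, and this is not run by the sweep):
#   enc = UTKFace_Encoder(lat_dim=128)
#   dec = UTKFace_Decoder(lat_dim=128)
#   out = dec(enc(torch.randn(2, 3, 200, 200)).embedding)
#   assert out.reconstruction.shape == (2, 3, 200, 200)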


def train():
    # Report the available accelerator. `device` is informational only here;
    # the pythae trainer selects its own device.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        gpu_name = torch.cuda.get_device_name(0)
        print(f"CUDA is available. GPU Name: {gpu_name}")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Training on MPS device")
    else:
        device = torch.device("cpu")
        print("MPS not available, training on CPU")

    # Initialize W&B; the sweep agent populates run.config with the sampled
    # hyperparameters. Dotted keys mirror the parameter names in the sweep
    # configuration below.
    with wandb.init() as run:
        batch_size = run.config.batch_size
        latent_dim = run.config.latent_dim
        learning_rate = run.config['training_config.learning_rate']
        num_epochs = run.config['training_config.num_epochs']

        transform = transforms.Compose([transforms.ToTensor()])
        all_dataset = UTKFaceDataset(root="./data", transform=transform)

        # 80/20 train/eval split
        train_size = int(0.8 * len(all_dataset))
        eval_size = len(all_dataset) - train_size
        train_dataset, eval_dataset = random_split(all_dataset, [train_size, eval_size])

        encoder = UTKFace_Encoder(latent_dim)
        decoder = UTKFace_Decoder(latent_dim)
        model_config = VAEConfig(input_dim=(3, 200, 200), latent_dim=latent_dim)
        model = VAE(model_config=model_config, encoder=encoder, decoder=decoder)

        config = BaseTrainerConfig(
            output_dir='my_model',
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_epochs=num_epochs,
        )

        callbacks = []
        wandb_cb = WandbCallback()
        # The entity_name will need adjusting for different W&B accounts.
        wandb_cb.setup(
            training_config=config,
            model_config=model_config,
            project_name="VAE_UTKFACE",
            entity_name="charlesdoyne",
        )
        callbacks.append(wandb_cb)

        pipeline = TrainingPipeline(training_config=config, model=model)
        pipeline(train_data=train_dataset, eval_data=eval_dataset, callbacks=callbacks)


if __name__ == "__main__":
    sweep_configuration = {
        'method': 'bayes',
        'metric': {
            'goal': 'minimize',
            'name': 'train/epoch_loss',
        },
        'parameters': {
            'batch_size': {'values': [4, 8, 16]},
            'latent_dim': {'min': 100, 'max': 200},
            'training_config.learning_rate': {'min': 0.0005, 'max': 0.001},
            'training_config.num_epochs': {'values': [10]},
        },
        'early_terminate': {
            'type': 'hyperband',
            'min_iter': 5,
        },
    }
    sweep_id = wandb.sweep(sweep=sweep_configuration, project="VAE_UTKFACE")
    # Cap the sweep at 50 runs. `count` is an argument to wandb.agent, not a
    # valid key inside the sweep configuration itself.
    wandb.agent(sweep_id, train, count=50)
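
# To launch the sweep (assumes the UTKFace images live under ./data in the
# ImageFolder layout, i.e. ./data/<subdir>/*.jpg):
#   python vaesweep.py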