-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmodels.py
101 lines (85 loc) · 3.12 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch.nn as nn
class FC_G(nn.Module):
    """Fully connected generator network.

    Four ``nn.Linear`` layers (``idim -> hidden_dim -> hidden_dim ->
    hidden_dim -> odim``) with an in-place ``ReLU`` between consecutive
    layers and no activation after the last one.

    Args:
        idim: Size of the last dimension of the input.
        odim: Size of the last dimension of the output.
        hidden_dim: Width of each of the three hidden layers.
    """

    def __init__(self, idim=2, odim=2, hidden_dim=512):
        super().__init__()
        # Widths of consecutive Linear layers.
        widths = [idim, hidden_dim, hidden_dim, hidden_dim, odim]
        layers = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if depth:
                # Activation goes *between* Linear layers, never after the last.
                layers.append(nn.ReLU(True))
            layers.append(nn.Linear(fan_in, fan_out))
        # Stored as ``self.main`` with layers at indices 0, 2, 4, 6 so the
        # state_dict keys (main.0.weight, ...) stay stable for checkpoints.
        self.main = nn.Sequential(*layers)

    def forward(self, noise):
        """Apply the MLP to ``noise`` (Linear acts on its last dimension)."""
        return self.main(noise)
class FC_D(nn.Module):
    """Fully connected scoring network (the ``_D``/discriminator counterpart
    of ``FC_G``).

    Three ``Linear(…, hidden_dim)`` + in-place ``ReLU`` stages followed by a
    ``Linear(hidden_dim, 1)`` head; the result is flattened to 1-D.

    Args:
        idim: Size of the last dimension of the input.
        hidden_dim: Width of each of the three hidden layers.
    """

    def __init__(self, idim=2, hidden_dim=512):
        super().__init__()
        stack = []
        fan_in = idim
        for _ in range(3):
            stack += [nn.Linear(fan_in, hidden_dim), nn.ReLU(True)]
            fan_in = hidden_dim
        # Single-unit output head, no activation.
        stack.append(nn.Linear(fan_in, 1))
        # Same module layout as a literal nn.Sequential(L, R, L, R, L, R, L),
        # so state_dict keys are main.0, main.2, main.4, main.6.
        self.main = nn.Sequential(*stack)

    def forward(self, input):
        """Score ``input``; for an (N, idim) batch the result has shape (N,)."""
        # Drop the trailing size-1 axis produced by the final Linear layer.
        return self.main(input).view(-1)
class Resnet_G(nn.Module):  # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
    """Upsampling generator that produces 3x32x32 images bounded by tanh.

    Each noise vector is projected by a Linear+BatchNorm1d+ReLU stage to
    ``4*hidden_dim`` channels on a 4x4 grid. Three kernel-2/stride-2
    transposed convolutions then upsample it 4 -> 8 -> 16 -> 32, and ``tanh``
    maps the result into (-1, 1). Despite the class name, ``forward`` has
    no residual/skip connections.

    Args:
        hidden_dim: Base channel width; the stages use 4x, 2x and 1x of it.
        noise_dim: Length of each input noise vector. Defaults to 128, the
            value that was previously hard-coded, so existing callers and
            checkpoints are unaffected.
    """

    def __init__(self, hidden_dim=128, noise_dim=128):
        super(Resnet_G, self).__init__()
        self.hidden_dim = hidden_dim
        self.noise_dim = noise_dim
        # 4*hidden_dim channels x 4 x 4 spatial positions, flattened.
        preprocess = nn.Sequential(
            nn.Linear(noise_dim, 4 * 4 * 4 * hidden_dim),
            nn.BatchNorm1d(4 * 4 * 4 * hidden_dim),
            nn.ReLU(True),
        )
        # Each ConvTranspose2d(kernel=2, stride=2) doubles height and width.
        block1 = nn.Sequential(
            nn.ConvTranspose2d(4 * hidden_dim, 2 * hidden_dim, 2, stride=2),
            nn.BatchNorm2d(2 * hidden_dim),
            nn.ReLU(True),
        )
        block2 = nn.Sequential(
            nn.ConvTranspose2d(2 * hidden_dim, hidden_dim, 2, stride=2),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(True),
        )
        deconv_out = nn.ConvTranspose2d(hidden_dim, 3, 2, stride=2)
        # Submodules are created and registered in the original order so
        # parameter initialisation and state_dict keys are unchanged.
        self.preprocess = preprocess
        self.block1 = block1
        self.block2 = block2
        self.deconv_out = deconv_out
        self.tanh = nn.Tanh()

    def forward(self, input):
        """Map noise of shape (N, noise_dim) to images of shape (N, 3, 32, 32)."""
        output = self.preprocess(input)
        # Unflatten the projection into a (N, 4*hidden_dim, 4, 4) feature map.
        output = output.view(-1, 4 * self.hidden_dim, 4, 4)
        output = self.block1(output)      # -> (N, 2*hidden_dim, 8, 8)
        output = self.block2(output)      # -> (N, hidden_dim, 16, 16)
        output = self.deconv_out(output)  # -> (N, 3, 32, 32)
        output = self.tanh(output)
        return output.view(-1, 3, 32, 32)
class Convnet_D(nn.Module): # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
def __init__(self, hidden_dim=128):
super(Convnet_D, self).__init__()
self.hidden_dim = hidden_dim
main = nn.Sequential(
nn.Conv2d(3, hidden_dim, 3, 2, padding=1),
nn.LeakyReLU(),
nn.Conv2d(hidden_dim, 2 * hidden_dim, 3, 2, padding=1),
nn.LeakyReLU(),
nn.Conv2d(2 * hidden_dim, 4 * hidden_dim, 3, 2, padding=1),
nn.LeakyReLU(),
)
self.main = main
self.linear = nn.Linear(4*4*4*hidden_dim, 1)
def forward(self, input):
output = self.main(input)
output = output.view(-1, 4*4*4*self.hidden_dim)
output = self.linear(output)
return output