-
Notifications
You must be signed in to change notification settings - Fork 57
/
Copy pathresidual_encoder.py
136 lines (114 loc) · 6.93 KB
/
residual_encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import torch
import torch.nn as nn
class residual_encoder(nn.Module):
    '''
    Neural network that can be used to parametrize q(z_{l}|x) and q(z_{o}|x).

    Pipeline: Conv1d over the mel frames -> 2-layer bidirectional LSTM ->
    average over time -> linear projection to the mean and log-variance of a
    diagonal Gaussian.

    Args:
        hparams: object exposing ``n_mel_channels`` and ``residual_encoding_dim``.
            This encoder produces a latent of size ``residual_encoding_dim/2``.
        log_min_std_dev: log of the lower bound applied to the standard deviation.
    '''
    def __init__(self, hparams, log_min_std_dev=-1):
        super(residual_encoder, self).__init__()
        # Each posterior covers half of the full residual encoding (z_l or z_o).
        self.residual_encoding_dim = int(hparams.residual_encoding_dim / 2)
        self.conv1 = nn.Conv1d(hparams.n_mel_channels, 512, 3, 1)
        self.bi_lstm = nn.LSTM(512, 256, 2, bidirectional=True, batch_first=True)
        # Projects to [mean, log_variance]. Sized from hparams rather than a
        # hard-coded 32 so that forward() splits it into two equal halves for
        # any residual_encoding_dim (identical to before when it is 32).
        self.linear = nn.Linear(512, 2 * self.residual_encoding_dim)
        self.register_buffer('min_std_dev', torch.exp(torch.tensor([log_min_std_dev]).float()))

    def forward(self, x):
        '''
        Args:
            x: mel frames of shape [batch_size, seq_len, n_mel_channels].
                The convolution has kernel size 3 and no padding, so seq_len
                must be at least 3.

        Returns:
            torch.distributions.Normal whose loc and scale have shape
            [batch_size, residual_encoding_dim/2]; scale is bounded below by
            exp(log_min_std_dev).
        '''
        # Conv1d expects channels first; LSTM is batch_first.
        x = self.conv1(x.transpose(2, 1)).transpose(2, 1)
        output, _ = self.bi_lstm(x)
        pooled = output.mean(dim=1)  # average over time steps
        x = self.linear(pooled)
        mean = x[:, :self.residual_encoding_dim]
        log_variance = x[:, self.residual_encoding_dim:]
        # exp(0.5 * logvar) equals sqrt(exp(logvar)) but overflows much later.
        std_dev = torch.exp(0.5 * log_variance)
        return torch.distributions.normal.Normal(mean, torch.max(std_dev, self.min_std_dev))
class continuous_given_discrete(nn.Module):
    '''
    Class for p(z_{o}|y_{o}) and p(z_{l}|y_{l}).

    Args:
        hparams: object exposing ``residual_encoding_dim``; each z has size
            ``residual_encoding_dim/2``.
        n_disc: number of possible discrete values of y_{o/l}.
        std_init: standard deviations are initialized to e^(std_init) and,
            in after_optim_step, clamped to be >= e^(2*std_init).

    Attributes:
        distrib_lis[i]: the distribution over z, p(z|y=i); n_disc distributions in total.
        distribs: p(z|y) for all y at once (batch shape [n_disc, residual_encoding_dim/2]).
            Sampling it draws one z from each of the n_disc distributions simultaneously.

    Both distribution attributes are rebuilt from the current parameters on
    every access, so they stay valid after optimizer steps and after the
    module is moved to another device or dtype.
    '''
    def __init__(self, hparams, n_disc, std_init=-1):
        super(continuous_given_discrete, self).__init__()
        self.n_disc = n_disc
        self.residual_encoding_dim = int(hparams.residual_encoding_dim / 2)
        self.std_init = torch.tensor([std_init]).float()
        self.cont_given_disc_mus = nn.Parameter(torch.randn((self.n_disc, self.residual_encoding_dim)))
        self.cont_given_disc_sigmas = nn.Parameter(torch.exp(self.std_init) * torch.ones((self.n_disc, self.residual_encoding_dim)))

    def make_normal_distribs(self, mus, sigmas, make_lis=False):
        '''Return Normal(mus, sigmas), or one Normal per row when make_lis is True.'''
        if make_lis:
            return [torch.distributions.normal.Normal(mus[i], sigmas[i]) for i in range(mus.shape[0])]
        return torch.distributions.normal.Normal(mus, sigmas)

    @property
    def distrib_lis(self):
        '''List of p(z|y=i) for i in range(n_disc), built from the current parameters.'''
        return self.make_normal_distribs(self.cont_given_disc_mus, self.cont_given_disc_sigmas, make_lis=True)

    @property
    def distribs(self):
        '''p(z|y) for all y as one batched Normal, built from the current parameters.'''
        return self.make_normal_distribs(self.cont_given_disc_mus, self.cont_given_disc_sigmas, make_lis=False)

    def after_optim_step(self):
        '''Clamp the standard deviations in place to be >= e^(2*std_init).'''
        floor = float(torch.exp(torch.tensor(2.) * self.std_init))
        # In-place under no_grad keeps the same Parameter object (optimizer
        # state stays attached) and leaves requires_grad untouched.
        with torch.no_grad():
            self.cont_given_disc_sigmas.clamp_(min=floor)
class residual_encoders(nn.Module):
    '''
    Residual encoders: variational posteriors q(z_l|X), q(z_o|X) plus the
    priors p(y_l), p(z_l|y_l) and p(z_o|y_o).

    Args:
        hparams: object exposing residual_encoding_dim, mcn (number of Monte
            Carlo samples per input), dim_yl and dim_yo, plus whatever the
            sub-modules read.
    '''
    def __init__(self, hparams):
        super(residual_encoders, self).__init__()
        # Variational posteriors
        self.q_zl_given_X = residual_encoder(hparams, -2)  # q(z_{l}|X)
        self.q_zo_given_X = residual_encoder(hparams, -4)  # q(z_{o}|X)
        # Posterior distributions evaluated at the latest forward() input.
        self.q_zl_given_X_at_x = None
        self.q_zo_given_X_at_x = None
        self.residual_encoding_dim = hparams.residual_encoding_dim
        self.mcn = hparams.mcn
        # Priors. y_l_probs starts as all ones (uniform after normalisation)
        # and is frozen (requires_grad=False).
        self.y_l_probs = nn.Parameter(torch.ones((hparams.dim_yl)))
        self.y_l_probs.requires_grad = False
        self.p_zo_given_yo = continuous_given_discrete(hparams, hparams.dim_yo, -2)
        self.p_zl_given_yl = continuous_given_discrete(hparams, hparams.dim_yl, -1)
        # Monte Carlo estimate of q(y_l|X), shape [dim_yl, batch_size]; set by calc_q_tilde.
        self.q_yl_given_X = None

    @property
    def y_l(self):
        '''Prior p(y_l), rebuilt from y_l_probs on each access so it follows device moves.'''
        return torch.distributions.categorical.Categorical(self.y_l_probs)

    def calc_q_tilde(self, sampled_zl):
        '''
        Calculate a Monte Carlo approximation to q(y_l|X) for each element in a batch.

        q(y_l=k|X) ~= mean_s [ p(z_s|y_l=k) p(y_l=k) / sum_j p(z_s|y_l=j) p(y_l=j) ]

        Supposed to be recalculated for each batch.

        Args:
            sampled_zl: samples of z_l, shape [n_samples, batch_size, residual_encoding_dim/2]
                (n_samples == mcn when called from forward).

        Sets self.q_yl_given_X to a float64 tensor of shape [K, batch_size], K = dim_yl.
        '''
        # [S, batch, 1, d] broadcasts against the [K, d] batch shape of p(z_l|y_l).
        log_p_zl_given_yl = self.p_zl_given_yl.distribs.log_prob(sampled_zl.unsqueeze(-2))
        # Work in log space (float64): summing log-densities and normalising
        # with softmax (log-sum-exp) avoids the 0/0 = NaN that exp-then-product
        # produces when every per-dimension density underflows.
        log_joint = log_p_zl_given_yl.double().sum(dim=-1) + self.y_l.probs.double().log()  # [S, batch, K]
        posterior = torch.softmax(log_joint, dim=-1)  # [S, batch, K]
        self.q_yl_given_X = posterior.permute(2, 0, 1).mean(dim=1)  # [K, batch]

    def forward(self, x):
        '''
        Args:
            x: shape [seq_len, batch_size, n_mel_channels].

        Returns:
            Concatenation of z_l and z_o (reparameterised samples from the
            respective posteriors, each [mcn, batch_size, residual_encoding_dim/2]),
            reshaped to [mcn * batch_size, residual_encoding_dim].

        Side effects: stores the posteriors in q_zl_given_X_at_x /
        q_zo_given_X_at_x and updates q_yl_given_X via calc_q_tilde.
        '''
        x = x.transpose(1, 0)  # -> [batch_size, seq_len, n_mel_channels]
        self.q_zl_given_X_at_x = self.q_zl_given_X(x)
        self.q_zo_given_X_at_x = self.q_zo_given_X(x)
        z_l = self.q_zl_given_X_at_x.rsample((self.mcn,))  # [mcn, batch_size, residual_encoding_dim/2]
        z_o = self.q_zo_given_X_at_x.rsample((self.mcn,))
        self.calc_q_tilde(z_l)
        return torch.cat([z_l, z_o], dim=-1).reshape(-1, self.residual_encoding_dim)

    def redefine_y_l(self):
        '''
        Kept for backward compatibility. y_l is derived from y_l_probs on every
        access, so nothing needs refreshing after moving to a new device.
        '''
        pass

    def after_optim_step(self):
        '''Call after each optimizer step; delegates to both conditional priors.'''
        self.p_zo_given_yo.after_optim_step()
        self.p_zl_given_yl.after_optim_step()

    def infer(self, y_o_idx, y_l_idx=None):
        '''
        Draw one residual encoding from the priors.

        Args:
            y_o_idx: index of the discrete y_o value.
            y_l_idx: index of the discrete y_l value; sampled from p(y_l) when None.

        Returns:
            Tensor of shape [1, dim(z_l) + dim(z_o)]: z_l ~ p(z_l|y_l)
            concatenated with z_o ~ p(z_o|y_o).
        '''
        if y_l_idx is None:
            y_l_idx = self.y_l.sample()
        z_l = self.p_zl_given_yl.distrib_lis[y_l_idx].sample()
        z_o = self.p_zo_given_yo.distrib_lis[y_o_idx].sample()
        return torch.cat([z_l, z_o], dim=-1).unsqueeze(dim=0)