decorrelation.py
"""
Let y1 = f(x1, W) and y2 = f(x2, W).
Learning the weights either for (x1, y1) or (x2, y2) should decorrelate y1 and y2 signals.
"""

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

from mighty.loss import ContrastiveLossSampler
from mighty.utils.common import set_seed
from mighty.utils.data import DataLoader
from mighty.utils.domain import MonitorLevel

from nn.kwta import *
from nn.trainer import TrainerIWTA
from nn.utils import sample_bernoulli, NoShuffleLoader

set_seed(0)

# Layer sizes (input x, hidden h, output y) and Bernoulli sparsity levels.
N_x = N_y = N_h = 200
s_x = 0.2
s_w_xh = s_w_xy = s_w_hy = s_w_yy = s_w_hh = s_w_yh = 0.05


class TrainerIWTADecorrelation(TrainerIWTA):
    N_CHOOSE = 10
    LEARNING_RATE = 0.001


class RandomDataset(TensorDataset):
    def __init__(self, *args, **kwargs):
        # Ignore the passed arguments and wrap the module-level random data
        # defined below.
        super().__init__(x, labels)


# 100 random binary input vectors, each assigned a unique label.
x = sample_bernoulli((100, N_x), p=s_x)
labels = torch.arange(x.size(0), device=x.device)

Permanence = PermanenceVaryingSparsity

# Sparse random weight matrices wrapped in learnable permanence objects;
# projections from the h population (w_hy, w_hh) are marked non-excitatory.
w_xy = Permanence(sample_bernoulli((N_x, N_y), p=s_w_xy), excitatory=True, learn=True)
w_xh = Permanence(sample_bernoulli((N_x, N_h), p=s_w_xh), excitatory=True, learn=True)
w_hy = Permanence(sample_bernoulli((N_h, N_y), p=s_w_hy), excitatory=False, learn=True)
w_hh = Permanence(sample_bernoulli((N_h, N_h), p=s_w_hh), excitatory=False, learn=True)
# w_yy = Permanence(sample_bernoulli((N_y, N_y), p=s_w_yy), excitatory=True, learn=True)
w_yy = None
w_yh = Permanence(sample_bernoulli((N_y, N_h), p=s_w_yh), excitatory=True, learn=True)

iwta = IterativeWTA(w_xy=w_xy, w_xh=w_xh, w_hy=w_hy, w_hh=w_hh, w_yy=w_yy, w_yh=w_yh)
# iwta = KWTANet(w_xy=w_xy, w_xh=w_xh, w_hy=w_hy, kh=10, ky=10)
print(iwta)

data_loader = DataLoader(RandomDataset, transform=None,
                         loader_cls=NoShuffleLoader, batch_size=10)
criterion = ContrastiveLossSampler(nn.CosineEmbeddingLoss(margin=0))
trainer = TrainerIWTADecorrelation(model=iwta, criterion=criterion,
                                   data_loader=data_loader, verbosity=1)
trainer.monitor.advanced_monitoring(level=MonitorLevel.SIGN_FLIPS | MonitorLevel.WEIGHT_HISTOGRAM)
trainer.train(n_epochs=50)
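
# A rough post-training check (a sketch, not part of the original script): the
# docstring claims that the y codes of different inputs become decorrelated.
# ASSUMPTION: the model is called here as `iwta(x)` and is assumed to return the
# hidden and output activations as an (h, y) pair; the actual nn.kwta API may differ.
with torch.no_grad():
    h_out, y_out = iwta(x)
    y_norm = nn.functional.normalize(y_out.float(), dim=1)
    cosine = y_norm @ y_norm.t()
    # Mean cosine similarity over distinct pairs; training should push it toward 0.
    off_diag = cosine[~torch.eye(cosine.size(0), dtype=torch.bool)]
    print(f"mean pairwise cosine similarity of y codes: {off_diag.mean().item():.4f}")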