diff --git a/mnist_forward_forward/main.py b/mnist_forward_forward/main.py index f137dee48a..a175126067 100644 --- a/mnist_forward_forward/main.py +++ b/mnist_forward_forward/main.py @@ -72,9 +72,8 @@ def train(self, x_pos, x_neg): for i in range(self.num_epochs): g_pos = self.forward(x_pos).pow(2).mean(1) g_neg = self.forward(x_neg).pow(2).mean(1) - loss = torch.log( - 1 - + torch.exp( + loss = torch.log1p( + torch.exp( torch.cat([-g_pos + self.threshold, g_neg - self.threshold]) ) ).mean()