diff --git a/social-attention/srnn-pytorch/srnn/train.py b/social-attention/srnn-pytorch/srnn/train.py index b0d8bc5..916dfde 100644 --- a/social-attention/srnn-pytorch/srnn/train.py +++ b/social-attention/srnn-pytorch/srnn/train.py @@ -178,7 +178,7 @@ def checkpoint_path(x): # Compute loss loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length) - loss_batch += loss.data[0] + loss_batch += loss.item() # Compute gradients loss.backward() @@ -241,7 +241,7 @@ def checkpoint_path(x): # Compute loss loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length) - loss_batch += loss.data[0] + loss_batch += loss.item() # Reset the stgraph stgraph.reset()