Skip to content

Commit

Permalink
Update docstrings
Browse files Browse the repository at this point in the history
Signed-off-by: Beat Buesser <[email protected]>
  • Loading branch information
beat-buesser committed Dec 23, 2023
1 parent 7a8cce2 commit 4abbedb
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions art/estimators/object_detection/pytorch_object_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,8 @@ def _get_losses(
"""
self._model.train()

self.set_dropout(False)
self.set_multihead_attention(False)
self.set_dropout(train=False)
self.set_multihead_attention(train=False)

# Apply preprocessing and convert to tensors
x_preprocessed, y_preprocessed = self._preprocess_and_convert_inputs(x=x, y=y, fit=False, no_grad=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_loss_gradient(art_warning, get_pytorch_detr):
)

print("expected_gradients1")
print(expected_gradients1)
print(grads[0, 0, 10, :32])

np.testing.assert_array_almost_equal(grads[0, 0, 10, :32], expected_gradients1, decimal=1)

Expand Down Expand Up @@ -181,7 +181,7 @@ def test_loss_gradient(art_warning, get_pytorch_detr):
)

print("expected_gradients2")
print(expected_gradients2)
print(grads[1, 0, 10, :32])

np.testing.assert_array_almost_equal(grads[1, 0, 10, :32], expected_gradients2, decimal=2)

Expand Down

0 comments on commit 4abbedb

Please sign in to comment.