Skip to content

Commit

Permalink
Fix unit test
Browse files Browse the repository at this point in the history
Signed-off-by: Beat Buesser <[email protected]>
  • Loading branch information
beat-buesser committed Dec 26, 2023
1 parent 3040e4c commit f83e15e
Showing 1 changed file with 65 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,21 +97,9 @@ def test_loss_gradient(art_warning, get_pytorch_detr):
try:
object_detector, x_test, y_test = get_pytorch_detr

print("x_test[0]")
print(x_test[0])
print("x_test[1]")
print(x_test[1])

grads = object_detector.loss_gradient(x=x_test, y=y_test)

assert grads.shape == (5, 3, 800, 800)

print("expected_gradients1")
print(grads[0, 0, 10, :32])
print("expected_gradients2")
print(grads[1, 0, 10, :32])

grads = object_detector.loss_gradient(x=x_test, y=y_test)
assert grads.shape == (2, 3, 800, 800)

print("expected_gradients1")
print(grads[0, 0, 10, :32])
Expand All @@ -120,77 +108,77 @@ def test_loss_gradient(art_warning, get_pytorch_detr):

expected_gradients1 = np.asarray(
[
-0.00757495,
-0.00101332,
0.00368362,
0.00283334,
-0.00096027,
0.00873749,
0.00546095,
-0.00823532,
-0.00710872,
0.00389713,
-0.00966289,
0.00448294,
0.00754991,
-0.00934104,
-0.00350194,
-0.00541577,
-0.00395624,
0.00147651,
0.0105616,
0.01231265,
-0.00148831,
-0.0043609,
0.00093031,
0.00884939,
-0.00356749,
0.00093475,
-0.00353712,
-0.0060132,
-0.00067899,
-0.00886974,
0.00108483,
-0.00052412,
-0.02030289,
-0.00355719,
0.0065711,
-0.01009711,
0.00190201,
0.01885923,
-0.00449042,
-0.02009461,
-0.00996577,
0.0073015,
-0.02389232,
0.00877987,
0.01518259,
-0.02014997,
-0.00818033,
-0.01121265,
-0.01399302,
-0.00167601,
0.02684669,
0.03023219,
-0.00318609,
-0.0069191,
0.00056615,
0.01815295,
-0.00779946,
0.00157681,
-0.00611856,
-0.01348296,
-0.0016219,
-0.0178297,
0.00483095,
-0.00505776,
]
)

np.testing.assert_array_almost_equal(grads[0, 0, 10, :32], expected_gradients1, decimal=4)

expected_gradients2 = np.asarray(
[
-0.00757495,
-0.00101332,
0.00368362,
0.00283334,
-0.00096027,
0.00873749,
0.00546095,
-0.00823532,
-0.00710872,
0.00389713,
-0.00966289,
0.00448294,
0.00754991,
-0.00934104,
-0.00350194,
-0.00541577,
-0.00395624,
0.00147651,
0.0105616,
0.01231265,
-0.00148831,
-0.0043609,
0.00093031,
0.00884939,
-0.00356749,
0.00093475,
-0.00353712,
-0.0060132,
-0.00067899,
-0.00886974,
0.00108483,
-0.00052412,
-0.02030289,
-0.00355719,
0.0065711,
-0.01009711,
0.00190201,
0.01885923,
-0.00449042,
-0.02009461,
-0.00996577,
0.0073015,
-0.02389232,
0.00877987,
0.01518259,
-0.02014997,
-0.00818033,
-0.01121265,
-0.01399302,
-0.00167601,
0.02684669,
0.03023219,
-0.00318609,
-0.0069191,
0.00056615,
0.01815295,
-0.00779946,
0.00157681,
-0.00611856,
-0.01348296,
-0.0016219,
-0.0178297,
0.00483095,
-0.00505776,
]
)

Expand Down

0 comments on commit f83e15e

Please sign in to comment.