From 757fc22dd3c65937d8cd7baa41ed86c01eef28f0 Mon Sep 17 00:00:00 2001
From: daniel <1534513+dantp-ai@users.noreply.github.com>
Date: Fri, 2 Aug 2024 11:59:44 +0200
Subject: [PATCH] Use torch.allclose instead of all()

* Seems that batch slicing leads to slightly different floats some of the time (See: https://github.com/thu-ml/tianshou/pull/1181)
---
 test/base/test_batch.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/test/base/test_batch.py b/test/base/test_batch.py
index 781347a2d..8839fe482 100644
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -859,10 +859,11 @@ def test_slice_distribution() -> None:
     selected_idx = [1, 3]
     sliced_batch = batch[selected_idx]
     sliced_probs = cat_probs[selected_idx]
-    assert (sliced_batch.dist.probs == Categorical(probs=sliced_probs).probs).all()
-    assert (
-        Categorical(probs=sliced_probs).probs == get_sliced_dist(dist, selected_idx).probs
-    ).all()
+    assert torch.allclose(sliced_batch.dist.probs, Categorical(probs=sliced_probs).probs)
+    assert torch.allclose(
+        Categorical(probs=sliced_probs).probs,
+        get_sliced_dist(dist, selected_idx).probs,
+    )
     # retrieving a single index
     assert torch.allclose(batch[0].dist.probs, dist.probs[0])
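
A minimal plain-PyTorch sketch (not part of the patch itself) of why exact == comparisons are brittle here, assuming the differences come from re-normalization when a Categorical is rebuilt from sliced probabilities; the shapes, seed, and variable names below are illustrative, not taken from the test.

import torch
from torch.distributions import Categorical

torch.manual_seed(0)

# Building a Categorical from already-normalized probabilities divides each
# row by a sum that is only approximately 1.0, so values can move by a few
# ULPs relative to the original tensor.
probs = torch.softmax(torch.randn(4, 3), dim=-1)
dist = Categorical(probs=probs)

selected_idx = [1, 3]
rebuilt = Categorical(probs=dist.probs[selected_idx])

# Exact equality may or may not hold, depending on the float round trip ...
print(bool((dist.probs[selected_idx] == rebuilt.probs).all()))
# ... whereas torch.allclose tolerates the tiny differences
# (default tolerances: rtol=1e-05, atol=1e-08).
print(torch.allclose(dist.probs[selected_idx], rebuilt.probs))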