diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 781347a2d..8839fe482 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -859,10 +859,11 @@ def test_slice_distribution() -> None: selected_idx = [1, 3] sliced_batch = batch[selected_idx] sliced_probs = cat_probs[selected_idx] - assert (sliced_batch.dist.probs == Categorical(probs=sliced_probs).probs).all() - assert ( - Categorical(probs=sliced_probs).probs == get_sliced_dist(dist, selected_idx).probs - ).all() + assert torch.allclose(sliced_batch.dist.probs, Categorical(probs=sliced_probs).probs) + assert torch.allclose( + Categorical(probs=sliced_probs).probs, + get_sliced_dist(dist, selected_idx).probs, + ) # retrieving a single index assert torch.allclose(batch[0].dist.probs, dist.probs[0])