Skip to content

Commit

Permalink
Use torch.allclose instead of all()
Browse files Browse the repository at this point in the history
 * It seems that batch slicing sometimes leads to slightly different float values,
 so exact equality checks are flaky (see: #1181)
  • Loading branch information
dantp-ai committed Aug 2, 2024
1 parent f81e225 commit 757fc22
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions test/base/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,10 +859,11 @@ def test_slice_distribution() -> None:
selected_idx = [1, 3]
sliced_batch = batch[selected_idx]
sliced_probs = cat_probs[selected_idx]
assert (sliced_batch.dist.probs == Categorical(probs=sliced_probs).probs).all()
assert (
Categorical(probs=sliced_probs).probs == get_sliced_dist(dist, selected_idx).probs
).all()
assert torch.allclose(sliced_batch.dist.probs, Categorical(probs=sliced_probs).probs)
assert torch.allclose(
Categorical(probs=sliced_probs).probs,
get_sliced_dist(dist, selected_idx).probs,
)
# retrieving a single index
assert torch.allclose(batch[0].dist.probs, dist.probs[0])

Expand Down

0 comments on commit 757fc22

Please sign in to comment.