From 141da4b36de87d56008bd16899188f7ac2cb6ea4 Mon Sep 17 00:00:00 2001
From: hypnopump
Date: Mon, 12 Aug 2024 23:47:00 +0100
Subject: [PATCH] fix test to cover case and report

---
 tests/ops/test_delta.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/tests/ops/test_delta.py b/tests/ops/test_delta.py
index 0f3df69c3..ff4739bcb 100644
--- a/tests/ops/test_delta.py
+++ b/tests/ops/test_delta.py
@@ -14,10 +14,11 @@
 @pytest.mark.parametrize("D", [128])
 @pytest.mark.parametrize("dtype", [torch.float])
 def test_beta_scalar_vector_equivalence(B: int, H: int, T: int, D: int, dtype: torch.dtype):
+    torch.manual_seed(17)
     q = torch.randn(B, H, T, D, dtype=dtype)
     k = torch.nn.functional.normalize(torch.randn(B, H, T, D, dtype=dtype), p=2, dim=-1)
     v = torch.randn(B, H, T, D, dtype=dtype)
-    beta = torch.rand(B, H, T, dtype=dtype).sigmoid()
+    beta = torch.rand(B, H, T, D, dtype=dtype).sigmoid()
     q, k, v, beta = map(lambda x: x.cuda().requires_grad_(True), (q, k, v, beta))

     do = torch.rand_like(v)
@@ -31,11 +32,12 @@ def test_beta_scalar_vector_equivalence(B: int, H: int, T: int, D: int, dtype: t
     q_grad2, k_grad2, v_grad2, beta_grad2 = q.grad, k.grad, v.grad, beta.grad
     q.grad = k.grad = v.grad = beta.grad = None

-    assert o.allclose(o2, 0, 1e-3), f"Diff: {torch.abs(o - o2).max()}"
-    assert q_grad.allclose(q_grad2, 0, 1e-3), f"Diff: {torch.abs(q_grad - q_grad2).max()}"
-    assert k_grad.allclose(k_grad2, 0, 1e-3), f"Diff: {torch.abs(k_grad - k_grad2).max()}"
-    assert v_grad.allclose(v_grad2, 0, 1e-3), f"Diff: {torch.abs(v_grad - v_grad2).max()}"
-    assert beta_grad.allclose(beta_grad2, 0, 1e-3), f"Diff: {torch.abs(beta_grad - beta_grad2).max()}"
+    assert o.allclose(o2, rtol=0, atol=2e-5), f"Diff: {torch.abs(o - o2).max()}"
+    assert q_grad.allclose(q_grad2, rtol=0, atol=2e-5), f"Diff: {torch.abs(q_grad - q_grad2).max()}"
+    assert k_grad.allclose(k_grad2, rtol=0, atol=2e-5), f"Diff: {torch.abs(k_grad - k_grad2).max()}"
+    assert v_grad.allclose(v_grad2, rtol=0, atol=2e-5), f"Diff: {torch.abs(v_grad - v_grad2).max()}"
+    # FIXME: this gradient does not match when beta is a vector; it does match when beta is a scalar.
+    assert beta_grad.allclose(beta_grad2, rtol=0, atol=1e-3), f"Diff: {torch.abs(beta_grad - beta_grad2).max()}"


 @pytest.mark.parametrize("B", [4])
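
Note (editorial, not part of the patch): the assertions rely on torch.Tensor.allclose(other, rtol, atol), whose elementwise check is |a - b| <= atol + rtol * |b|. The old positional call allclose(o2, 0, 1e-3) therefore already meant rtol=0, atol=1e-3; the patch spells the keywords out and tightens the absolute tolerance to 2e-5 for all but the beta gradient. A minimal standalone sketch of that behaviour, using made-up tensors rather than the test's real outputs:

    import torch

    # Two tensors whose largest elementwise difference is roughly 3e-5.
    a = torch.tensor([1.0, 2.0])
    b = torch.tensor([1.00001, 2.00003])

    # Old positional form: second argument is rtol, third is atol -> rtol=0, atol=1e-3.
    print(a.allclose(b, 0, 1e-3))             # True: 3e-5 <= 1e-3
    # New keyword form with the tightened absolute tolerance used in the patch.
    print(a.allclose(b, rtol=0, atol=2e-5))   # False: 3e-5 > 2e-5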