[DOCS] fixed typo assert_almost_equal -> assert_allclose in tutorials (#1456)

This commit is contained in:
Philippe Tillet
2023-03-31 11:27:18 -07:00
committed by GitHub
parent 28ea484dab
commit 123afdf423
2 changed files with 8 additions and 8 deletions

View File

@@ -311,10 +311,10 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
y_ref.backward(dy, retain_graph=True)
dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
# compare
triton.testing.assert_almost_equal(y_tri, y_ref)
triton.testing.assert_almost_equal(dx_tri, dx_ref)
triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)
assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
assert torch.allclose(dx_tri, dx_ref, atol=1e-2, rtol=0)
assert torch.allclose(db_tri, db_ref, atol=1e-2, rtol=0)
assert torch.allclose(dw_tri, dw_ref, atol=1e-2, rtol=0)
@triton.testing.perf_report(

View File

@@ -299,10 +299,10 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
triton.testing.assert_almost_equal(ref_out, tri_out)
triton.testing.assert_almost_equal(ref_dv, tri_dv)
triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq)
assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0)
assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0)
assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0)
try: