mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[DOCS] fixed typo assert_almost_equal -> assert_allclose in tutorials (#1456)
This commit is contained in:
@@ -311,10 +311,10 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
|
||||
y_ref.backward(dy, retain_graph=True)
|
||||
dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(y_tri, y_ref)
|
||||
triton.testing.assert_almost_equal(dx_tri, dx_ref)
|
||||
triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
|
||||
triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)
|
||||
assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
|
||||
assert torch.allclose(dx_tri, dx_ref, atol=1e-2, rtol=0)
|
||||
assert torch.allclose(db_tri, db_ref, atol=1e-2, rtol=0)
|
||||
assert torch.allclose(dw_tri, dw_ref, atol=1e-2, rtol=0)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
|
||||
@@ -299,10 +299,10 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(ref_out, tri_out)
|
||||
triton.testing.assert_almost_equal(ref_dv, tri_dv)
|
||||
triton.testing.assert_almost_equal(ref_dk, tri_dk)
|
||||
triton.testing.assert_almost_equal(ref_dq, tri_dq)
|
||||
assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
|
||||
assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0)
|
||||
assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0)
|
||||
assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0)
|
||||
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user