mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[TESTS] remove get_proper_err, get_variant_golden (#2039)
Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -84,31 +84,6 @@ def matmul_kernel(
|
||||
tl.store(c_block_ptr, accumulator)
|
||||
|
||||
|
||||
def get_variant_golden(a, b):
|
||||
SIZE_M = a.shape[0]
|
||||
SIZE_K = a.shape[1]
|
||||
SIZE_N = b.shape[1]
|
||||
assert a.shape[1] == b.shape[0]
|
||||
zero_M_K = torch.zeros((SIZE_M, SIZE_K)).cuda()
|
||||
zero_3M_K = torch.zeros((3 * SIZE_M, SIZE_K)).cuda()
|
||||
zero_K_N = torch.zeros((SIZE_K, SIZE_N)).cuda()
|
||||
zero_3K_N = torch.zeros((3 * SIZE_K, SIZE_N)).cuda()
|
||||
a_padded = torch.cat((a, zero_M_K, zero_M_K), 0)
|
||||
a_padded = torch.cat((a_padded, zero_3M_K, zero_3M_K), 1)
|
||||
b_padded = torch.cat((b, zero_K_N, zero_K_N), 0)
|
||||
b_padded = torch.cat((b_padded, zero_3K_N, zero_3K_N), 1)
|
||||
c_padded = torch.matmul(a_padded, b_padded)
|
||||
return c_padded[:SIZE_M, :SIZE_N]
|
||||
|
||||
|
||||
def get_proper_err(a, b, golden):
|
||||
golden_variant = get_variant_golden(a, b)
|
||||
golden_diff = golden - golden_variant
|
||||
golden_abs_err = torch.max(torch.abs(golden_diff)).item()
|
||||
golden_rel_err = torch.max(torch.abs(golden_diff / golden)).item()
|
||||
return (golden_abs_err, golden_rel_err)
|
||||
|
||||
|
||||
def matmul(a, b):
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
@@ -134,15 +109,16 @@ def matmul(a, b):
|
||||
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((512, 512), device='cuda', dtype=torch.float16).T
|
||||
c = matmul(a, b)
|
||||
c = torch.nn.functional.normalize(c)
|
||||
|
||||
golden = torch.nn.functional.normalize(torch.matmul(a, b))
|
||||
|
||||
golden = torch.matmul(a, b)
|
||||
golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
|
||||
torch.set_printoptions(profile="full")
|
||||
assert_close(
|
||||
c,
|
||||
golden,
|
||||
rtol=max(1e-4, 1.5 * golden_rel_err),
|
||||
atol=max(1e-4, 1.5 * golden_abs_err),
|
||||
rtol=1e-2,
|
||||
atol=1e-3,
|
||||
check_dtype=False)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user