[TESTS] remove get_proper_err, get_variant_golden (#2039)

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
ben-zhang-609
2023-08-08 13:52:55 +08:00
committed by GitHub
parent 4ed8381fdb
commit 31e79aa384
7 changed files with 12 additions and 168 deletions

View File

@@ -84,31 +84,6 @@ def matmul_kernel(
tl.store(c_block_ptr, accumulator)
def get_variant_golden(a, b):
SIZE_M = a.shape[0]
SIZE_K = a.shape[1]
SIZE_N = b.shape[1]
assert a.shape[1] == b.shape[0]
zero_M_K = torch.zeros((SIZE_M, SIZE_K)).cuda()
zero_3M_K = torch.zeros((3 * SIZE_M, SIZE_K)).cuda()
zero_K_N = torch.zeros((SIZE_K, SIZE_N)).cuda()
zero_3K_N = torch.zeros((3 * SIZE_K, SIZE_N)).cuda()
a_padded = torch.cat((a, zero_M_K, zero_M_K), 0)
a_padded = torch.cat((a_padded, zero_3M_K, zero_3M_K), 1)
b_padded = torch.cat((b, zero_K_N, zero_K_N), 0)
b_padded = torch.cat((b_padded, zero_3K_N, zero_3K_N), 1)
c_padded = torch.matmul(a_padded, b_padded)
return c_padded[:SIZE_M, :SIZE_N]
def get_proper_err(a, b, golden):
golden_variant = get_variant_golden(a, b)
golden_diff = golden - golden_variant
golden_abs_err = torch.max(torch.abs(golden_diff)).item()
golden_rel_err = torch.max(torch.abs(golden_diff / golden)).item()
return (golden_abs_err, golden_rel_err)
def matmul(a, b):
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
@@ -134,15 +109,16 @@ def matmul(a, b):
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16).T
c = matmul(a, b)
c = torch.nn.functional.normalize(c)
golden = torch.nn.functional.normalize(torch.matmul(a, b))
golden = torch.matmul(a, b)
golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
torch.set_printoptions(profile="full")
assert_close(
c,
golden,
rtol=max(1e-4, 1.5 * golden_rel_err),
atol=max(1e-4, 1.5 * golden_abs_err),
rtol=1e-2,
atol=1e-3,
check_dtype=False)