mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
amd tc 1616128 (#13439)
* amd tc 1616128 * fix test * remove hardcoded check in test
This commit is contained in:
@@ -11,6 +11,7 @@ from tinygrad.helpers import AMX, AMD_LLVM, CPU_LLVM, Context
|
||||
from test.helpers import slow
|
||||
from tinygrad.engine.realize import CompiledRunner, get_program
|
||||
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
|
||||
from tinygrad.codegen.opt.tc import amd_cdna_1616128
|
||||
|
||||
# TODO: write a clean version of this
|
||||
from test.test_linearizer import helper_realized_ast, helper_linearizer_opt
|
||||
@@ -120,7 +121,7 @@ class TestTensorCores(unittest.TestCase):
|
||||
# check excessive padding doesn't trigger padded TC in TC_OPT=2
|
||||
helper_tc_ensure_uops_and_opts_count(tc.dims[0]//4, tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
|
||||
helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1]//4, tc.dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
|
||||
if not AMX: # AMX tc.dims[2] == 1
|
||||
if not AMX and tc not in amd_cdna_1616128: # AMX tc.dims[2] == 1
|
||||
helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1], tc.dims[2]//8, tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
|
||||
|
||||
@Context(ALLOW_TF32=1)
|
||||
|
||||
@@ -121,9 +121,15 @@ amd_cdna_161632 = [TensorCore(dims=(16,16,32), threads=64, elements_per_thread=(
|
||||
(('l0', 'l1', 'l2', 'l3', 'r3', 'r4'), ('r0', 'r1'), ('l4', 'l5', 'u0', 'u1', 'r2'))))
|
||||
for di,do in [(dtypes.fp8e5m2,dtypes.float),(dtypes.fp8e4m3,dtypes.float),(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
|
||||
|
||||
amd_cdna_1616128 = [TensorCore(dims=(16,16,128), threads=64, elements_per_thread=(32,32,4), dtype_in=di, dtype_out=do,
|
||||
opts=("l0","l0","l0","l0","u1","u1","l1","l1"),
|
||||
swizzle=((('u0', 'u1', 'l4', 'l5', 'r5', 'r6'), ('r0', 'r1'), ('l0', 'l1', 'l2', 'l3', 'r2', 'r3', 'r4')),
|
||||
(('l0', 'l1', 'l2', 'l3', 'r5', 'r6'), ('r0', 'r1'), ('l4', 'l5', 'u0', 'u1', 'r2', 'r3', 'r4'))))
|
||||
for di,do in [(dtypes.fp8e5m2,dtypes.float),(dtypes.fp8e4m3,dtypes.float)]]
|
||||
|
||||
amd_cdna3 = amd_cdna_161632[:2] + amd_cdna_161616
|
||||
|
||||
amd_cdna4 = amd_cdna_161632 + amd_cdna_161616
|
||||
amd_cdna4 = amd_cdna_1616128 + amd_cdna_161632 + amd_cdna_161616
|
||||
|
||||
# ***** Apple Metal *****
|
||||
|
||||
|
||||
@@ -373,7 +373,7 @@ class MetalRenderer(CStyleLanguage):
|
||||
simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n return {dstr_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}""")
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
|
||||
|
||||
_nms = "xyzwabcdefghijkl"
|
||||
_nms = list("xyzwabcdefghijkl") + [f'v{i}' for i in range(16, 32)]
|
||||
|
||||
class CUDARenderer(CStyleLanguage):
|
||||
device = "CUDA"
|
||||
@@ -446,6 +446,8 @@ class CUDARenderer(CStyleLanguage):
|
||||
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
|
||||
|
||||
def fp8_index(dtype: DType): return (dtypes.fp8e4m3, dtypes.fp8e5m2).index(dtype.scalar())
|
||||
|
||||
class AMDHIPRenderer(CStyleLanguage):
|
||||
device = "AMD"
|
||||
shared_max = 65536
|
||||
@@ -463,11 +465,13 @@ class AMDHIPRenderer(CStyleLanguage):
|
||||
self.tensor_cores = self.get_tensor_cores(arch)
|
||||
if self.is_cdna(self.arch):
|
||||
self.string_rewrite = PatternMatcher([
|
||||
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]},"
|
||||
f" {fp8_index(x.src[0].dtype)}, {fp8_index(x.src[0].dtype)}, 0, 0, 0, 0)" if x.arg[1][2] == 128 else None),
|
||||
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]}, 0, 0, 0)"),
|
||||
(UPat(Ops.CAST, dtypes.fp8s, (UPat.var("y", dtypes.float),), name="x",),
|
||||
lambda ctx,x, y: f"f32_to_fp8({ctx[x.src[0]]}, {'1' if x.dtype == dtypes.fp8e5m2 else '0'})"),
|
||||
lambda ctx,x,y: f"f32_to_fp8({ctx[x.src[0]]}, {fp8_index(x.dtype)})"),
|
||||
(UPat(Ops.CAST, dtypes.float, (UPat.var("y", dtypes.fp8s),), name="x",),
|
||||
lambda ctx,x, y: f"__builtin_amdgcn_cvt_f32_{'bf8' if y.dtype == dtypes.fp8e5m2 else 'fp8'}((unsigned int){ctx[x.src[0]]}, 0)"),
|
||||
lambda ctx,x,y: f"__builtin_amdgcn_cvt_f32_{('fp8', 'bf8')[fp8_index(y.dtype)]}((unsigned int){ctx[x.src[0]]}, 0)"),
|
||||
]) + base_rewrite
|
||||
def __reduce__(self): return self.__class__, (self.arch,)
|
||||
|
||||
@@ -527,7 +531,8 @@ class AMDHIPRenderer(CStyleLanguage):
|
||||
if self.is_cdna(self.arch):
|
||||
if (N, M, K) == (16, 16, 16): type_map[dtypes.bfloat16] = 'bf16_1k'
|
||||
elif (N, M, K) == (16, 16, 32): type_map = {**type_map, dtypes.bfloat16: "_bf16", dtypes.half: "_f16"}
|
||||
prefix.append(f"#define __{name} __builtin_amdgcn_mfma_f32_{N}x{M}x{K}{type_map[dtype_in]}")
|
||||
elif (N, M, K) == (16, 16, 128): type_map = {**type_map, dtypes.fp8e4m3: "_f8f6f4", dtypes.fp8e5m2: "_f8f6f4"}
|
||||
prefix.append(f"#define __{name} __builtin_amdgcn_mfma_{'scale_' if K == 128 else ''}f32_{N}x{M}x{K}{type_map[dtype_in]}")
|
||||
# #define __WMMA_16_16_16_half_half __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12
|
||||
elif self.tensor_cores == tc.amd_rdna4:
|
||||
prefix.append(f"#define __{name} __builtin_amdgcn_wmma_{type_map[dtype_out]}_16x16x16_{type_map[dtype_in]}_w32_gfx12")
|
||||
|
||||
Reference in New Issue
Block a user