Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 07:28:15 -05:00)
fix use_tensor_cores propagation (#10048)
* propagate use_tensor_cores
* add use_tensor_core to arg in test and search
* bugfix
* get TC val from ContextVar in search
* revert minor space change
* add tc emulation test to ci and benchmark
* revert
* revert whitespace change
* remove test for ptx
* add comment and remove llvm test run
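The thrust of the change: the tensor-core mode (use_tensor_cores) now travels inside the TC opt's arg tuple as a third element instead of being read from the USE_TC ContextVar when the opt is applied, so BEAM search and replayed opts carry the mode with them. A minimal before/after sketch of the call shape (the `ast` and axis values here are illustrative, not from this commit):

```python
from tinygrad import Device
from tinygrad.codegen.kernel import Kernel
from tinygrad.renderer import Opt, OptOps

k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)  # ast: any matmul-like kernel AST
# before: 2-tuple (tc_select, tc_opt); the mode came from USE_TC.value at apply time
# k.apply_opt(Opt(OptOps.TC, 0, (-1, 2)))
# after: 3-tuple (tc_select, tc_opt, use_tensor_cores); the mode is explicit
k.apply_opt(Opt(OptOps.TC, 0, (-1, 2, 1)))
```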
.github/workflows/benchmark.yml (10 changed lines)

@@ -62,10 +62,11 @@ jobs:
     - name: Test speed vs torch
       run: BIG=2 MPS=1 python3.11 test/test_speed_v_torch.py | tee torch_speed.txt
     - name: Test tensor cores
-      run: METAL=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded TestLinearizer.test_tensor_cores_padded_uops
+      run: METAL=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_emulation TestLinearizer.test_tensor_cores_padded TestLinearizer.test_tensor_cores_padded_uops
+    # TODO: add TestLinearizer.test_tensor_cores_emulation for llvm (#10093)
     - name: Test AMX tensor cores
       run: |
-        DEBUG=2 CPU=1 AMX=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded TestLinearizer.test_tensor_cores_padded_uops
+        DEBUG=2 CPU=1 AMX=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_emulation TestLinearizer.test_tensor_cores_padded TestLinearizer.test_tensor_cores_padded_uops
         DEBUG=2 LLVM=1 AMX=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded TestLinearizer.test_tensor_cores_padded_uops
     - name: Run Tensor Core GEMM (float)
       run: DEBUG=2 SHOULD_USE_TC=1 python3.11 extra/gemm/simple_matmul.py | tee matmul.txt
@@ -175,9 +176,10 @@ jobs:
       run: NV=1 IGNORE_BEAM_CACHE=1 BEAM_DEBUG=1 DEBUG=1 python -m pytest -rA test/external/speed_v_theoretical.py --durations=20
     - name: Test benchmark allreduce
       run: NV=1 python test/external/external_benchmark_multitensor_allreduce.py
+    # TODO: add TestLinearizer.test_tensor_cores_emulation for ptx (#9967)
     - name: Test tensor cores
       run: |
-        NV=1 ALLOW_TF32=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded TestLinearizer.test_tensor_cores_padded_uops
+        NV=1 ALLOW_TF32=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_emulation TestLinearizer.test_tensor_cores_padded TestLinearizer.test_tensor_cores_padded_uops
         PTX=1 ALLOW_TF32=1 NV=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded TestLinearizer.test_tensor_cores_padded_uops
     - name: Run Tensor Core GEMM (CUDA)
       run: |
@@ -373,7 +375,7 @@ jobs:
       run: AMD=1 IGNORE_BEAM_CACHE=1 BEAM_DEBUG=1 DEBUG=1 python -m pytest -rA test/external/speed_v_theoretical.py --durations=20
     - name: Test tensor cores
       run: |
-        AMD=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded_amd TestLinearizer.test_tensor_cores_padded_uops
+        AMD=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_emulation TestLinearizer.test_tensor_cores_padded_amd TestLinearizer.test_tensor_cores_padded_uops
         AMD=1 SHOULD_USE_TC=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
     - name: Run Tensor Core GEMM (AMD)
       run: AMD=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 ATOL=2e-2 python3 extra/gemm/simple_matmul.py | tee matmul_amd.txt
.github/workflows/test.yml (12 changed lines)

@@ -276,12 +276,12 @@ jobs:
         PYTHONPATH=. DEBUG=2 AMX=1 EMULATE_AMX=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
     - name: Test tensor cores (TC=3)
       run: |
-        TC=3 DEBUG=3 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm
-        TC=3 PYTHONPATH=. DEBUG=3 EMULATE_AMD=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 ATOL=3e-4 python3 ./extra/gemm/simple_matmul.py
-        TC=3 PYTHONPATH=. DEBUG=3 EMULATE_AMD_MFMA=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 ATOL=3e-4 python3 ./extra/gemm/simple_matmul.py
-        TC=3 DEBUG=3 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
-        TC=3 PYTHONPATH=. DEBUG=3 EMULATE_INTEL=1 PYTHON=1 N=16 HALF=1 ATOL=3e-4 python3 ./extra/gemm/simple_matmul.py
-        TC=3 PYTHONPATH=. DEBUG=3 AMX=1 EMULATE_AMX=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm
+        PYTHONPATH=. DEBUG=2 PYTHON=1 EMULATE_METAL=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores_emulation
+        PYTHONPATH=. DEBUG=2 PYTHON=1 EMULATE_AMD=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores_emulation
+        PYTHONPATH=. DEBUG=2 PYTHON=1 EMULATE_AMD_MFMA=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores_emulation
+        PYTHONPATH=. DEBUG=2 PYTHON=1 EMULATE_CUDA=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores_emulation
+        PYTHONPATH=. DEBUG=2 PYTHON=1 EMULATE_INTEL=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores_emulation
+        PYTHONPATH=. DEBUG=2 PYTHON=1 EMULATE_AMX=1 AMX=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores_emulation
     - name: Test device flop counts
       run: |
         PYTHONPATH=. DEBUG=2 EMULATE_METAL=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStatsMatmulHalf
test/external/external_test_nv.py (2 changed lines)

@@ -26,7 +26,7 @@ class TestNV(unittest.TestCase):
 
   def test_oor_kernels(self):
     ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
-    opts = [Opt(op=OptOps.TC, axis=6, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=1, arg=2)] # noqa: E501
+    opts = [Opt(op=OptOps.TC, axis=6, arg=(-1, 2, 1)), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=1, arg=2)] # noqa: E501
     helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["NV"])
 
   def test_error_on_huge_dims(self):
test/test_linearizer.py

@@ -26,16 +26,17 @@ def helper_realized_ast(r:Union[Tensor, list[Tensor]]) -> tuple[UOp, list[Buffer
   bufs = [Buffer((x).device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)]
   return s[-1].ast, bufs
 
-def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0):
+def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0, use_tensor_cores:int=1):
   a, b = Tensor.rand(M, K, dtype=dtype_in), Tensor.rand(K, N, dtype=dtype_in)
   np_a, np_b = a.numpy(), b.numpy()
   r = a.matmul(b, dtype=dtype_out)
   if dtype_in == dtypes.bfloat16: r = r.float()
   realized_ast, bufs = helper_realized_ast(r)
   k = Kernel(realized_ast)
-  k.apply_tensor_cores(1, axis=axis, tc_select=tc_select, tc_opt=tc_opt)
+  k.apply_tensor_cores(use_tensor_cores, axis=axis, tc_select=tc_select, tc_opt=tc_opt)
   prg = CompiledRunner(replace(k.to_program(), device=Device.DEFAULT))
-  assert len([uop for uop in k.uops if uop.op is Ops.WMMA]) > 0, "tensor core not triggered"
+  if use_tensor_cores == 1: assert len([uop for uop in k.uops if uop.op is Ops.WMMA]) > 0, "wmma not triggered"
+  elif use_tensor_cores == 3: assert len([uop for uop in k.uops if uop.op is Ops.DEFINE_LOCAL]) == 2, "local buffers not triggered"
   assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
   prg.exec(bufs)
   if dtype_in == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3
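The two assertion branches pin down what each mode should produce: with use_tensor_cores=1 the linearized kernel must contain WMMA uops, while with use_tensor_cores=3 (emulation) it must instead define two local buffers. A hedged sketch of driving the updated helper in both modes (assumes the default device's renderer lists at least one tensor core):

```python
# sketch: exercise both modes of helper_tc_allclose on the default device
from tinygrad import Device

tc = Device[Device.DEFAULT].renderer.tensor_cores[0]
# mode 1: real tensor cores; asserts an Ops.WMMA uop is present
helper_tc_allclose(tc.dims[0], tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, use_tensor_cores=1)
# mode 3: emulated tensor cores; asserts two Ops.DEFINE_LOCAL buffers instead
helper_tc_allclose(tc.dims[0], tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, use_tensor_cores=3)
```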
@@ -1060,6 +1061,14 @@ class TestLinearizer(unittest.TestCase):
       # for AMX, tc.dims[2] == 1 so reduceop is None thus tensor_cores are not triggered
       helper_tc_allclose(tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)
 
+  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
+  def test_tensor_cores_emulation(self):
+    for tc in Device[Device.DEFAULT].renderer.tensor_cores:
+      if CI and getenv("AMD_LLVM") and tc.dtype_in is dtypes.bfloat16: continue # TODO: compilation error in CI
+      if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue
+      # for AMX, tc.dims[2] == 1 so reduceop is None thus tensor_cores are not triggered
+      helper_tc_allclose(tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0, use_tensor_cores=3)
+
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
   def test_tensor_cores_codegen(self):
     for tc in Device[Device.DEFAULT].renderer.tensor_cores:
@@ -2016,7 +2025,7 @@ class TestKernelOpts(unittest.TestCase):
           UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)) # noqa: E501
     k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
     with self.assertRaises(KernelOptError):
-      k.apply_opt(Opt(OptOps.TC, 0, (-1, 1)))
+      k.apply_opt(Opt(OptOps.TC, 0, (-1, 1, 1)))
 
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
   def test_tensor_core_opts(self):
test/test_linearizer_dumb.py

@@ -35,7 +35,7 @@ class TestLinearizerDumb(unittest.TestCase):
           UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
         ast_const(dtypes.half, 0.0, st_src=(
           UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
-    opts = [Opt(op=OptOps.TC, axis=2, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0)]
+    opts = [Opt(op=OptOps.TC, axis=2, arg=(-1, 2, 1)), Opt(op=OptOps.UPCAST, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0)]
     k = Kernel(ast, opts=Device["METAL"].renderer)
     k.apply_opts(opts)
     prg = k.to_program()
test/test_linearizer_failures.py

@@ -739,7 +739,7 @@ class TestLinearizerFailures(unittest.TestCase):
         UOp(Ops.LOAD, dtypes.half, arg=None, src=(
           UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
           UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
-    opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 1)), Opt(op=OptOps.PADTO, axis=2, arg=32)]
+    opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 1, 1)), Opt(op=OptOps.PADTO, axis=2, arg=32)]
     helper_test_lin(Kernel(ast), opts, failed_platforms=[], atol=1.0)
 
   def test_failure_30(self):
@@ -799,7 +799,7 @@ class TestLinearizerFailures(unittest.TestCase):
         UOp(Ops.LOAD, dtypes.half, arg=None, src=(
           UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
           UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 256, 14, 14, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
-    opts = [Opt(op=OptOps.TC, axis=2, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=2, arg=7), Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=16)]
+    opts = [Opt(op=OptOps.TC, axis=2, arg=(-1, 2, 1)), Opt(op=OptOps.UPCAST, axis=2, arg=7), Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=16)]
     helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[], atol=0.1, rtol=0.05)
 
   def test_failure_33(self):
@@ -860,7 +860,7 @@ class TestLinearizerFailures(unittest.TestCase):
           UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 2, 5), strides=(0, 0, 10, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
         ast_const(dtypes.float, 0.0, st_src=(
           UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
-    opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 2)), Opt(op=OptOps.UNROLL, axis=0, arg=0)] if unroll else [Opt(op=OptOps.TC, axis=0, arg=(-1, 2))]
+    opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 2, 1)), Opt(op=OptOps.UNROLL, axis=0, arg=0)] if unroll else [Opt(op=OptOps.TC, axis=0, arg=(-1, 2, 1))]
     helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
 
   def test_failure_35(self): self.test_failure_34(True)
@@ -909,7 +909,7 @@ class TestLinearizerFailures(unittest.TestCase):
        ast_const(dtypes.float, 0.0, st_src=(
          UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
     for axis in [0,1,2,3,4,5]:
-      opts = [Opt(op=OptOps.TC, axis=axis, arg=(-1, 2))]
+      opts = [Opt(op=OptOps.TC, axis=axis, arg=(-1, 2, 1))]
       helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
 
   def test_failure_38(self):
@@ -929,7 +929,7 @@ class TestLinearizerFailures(unittest.TestCase):
           UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
           UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 32, 24, 24, 1, 5, 5, 256), strides=(18432, 0, 576, 24, 1, 0, 0, 0, 36864), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
     for axis in [0,1,3,4]:
-      opts = [Opt(op=OptOps.TC, axis=axis, arg=(-1, 2))]
+      opts = [Opt(op=OptOps.TC, axis=axis, arg=(-1, 2, 1))]
       helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
 
   @unittest.skip("very slow, similar to test_failure_37")
@@ -957,7 +957,7 @@ class TestLinearizerFailures(unittest.TestCase):
        ast_const(dtypes.float, 0.0, st_src=(
          UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
     for axis in [0,1,2,3,4,5]:
-      opts = [Opt(op=OptOps.TC, axis=axis, arg=(-1, 2))]
+      opts = [Opt(op=OptOps.TC, axis=axis, arg=(-1, 2, 1))]
       helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
 
   def test_failure_40(self):
@@ -995,7 +995,7 @@ class TestLinearizerFailures(unittest.TestCase):
         UOp(Ops.LOAD, dtypes.half, arg=None, src=(
           UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
           UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
-    opts=[Opt(op=OptOps.TC, axis=5, arg=(-1, 2)), Opt(op=OptOps.UNROLL, axis=0, arg=0)]
+    opts=[Opt(op=OptOps.TC, axis=5, arg=(-1, 2, 1)), Opt(op=OptOps.UNROLL, axis=0, arg=0)]
     helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"], atol=0.02)
 
   # llama3 8B failure with BEAM=2 https://github.com/tinygrad/tinygrad/actions/runs/10150118124/job/28066519425#step:14:1, these don't compile
@@ -1150,7 +1150,7 @@ class TestLinearizerFailures(unittest.TestCase):
         UOp(Ops.LOAD, dtypes.half, arg=None, src=(
           UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
           UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 56, 56, 256, 1, 1, 256), strides=(0, 0, 3136, 56, 1, 0, 0, 0, 200704), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),))
-    opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 0)), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=2)]
+    opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 0, 1)), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=2)]
     helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
 
   def test_failure_49(self):
@@ -1167,7 +1167,7 @@ class TestLinearizerFailures(unittest.TestCase):
         UOp(Ops.LOAD, dtypes.float, arg=None, src=(
           UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
           UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 10), strides=(0, 1, 6), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
-    opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=0, arg=2)]
+    opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 2, 1)), Opt(op=OptOps.UPCAST, axis=0, arg=2)]
     helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
 
   def test_failure_50(self):
@@ -1228,7 +1228,7 @@ class TestLinearizerFailures(unittest.TestCase):
           UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
         UOp(Ops.CONST, dtypes.half, arg=-1.4426950408889634, src=(
           x6,)),)),)),)),)),)),))
-    opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 2))]
+    opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 2, 1))]
     helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
 
   @unittest.skipIf(CI and Device.DEFAULT in {"METAL"}, "hangs metal gpu CI")
@@ -1250,7 +1250,7 @@ class TestLinearizerFailures(unittest.TestCase):
         UOp(Ops.LOAD, dtypes.half, arg=None, src=(
           UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
           UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
-    opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16)]
+    opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 2, 1)), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16)]
     helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
 
   def test_failure_53(self):
@@ -1304,7 +1304,7 @@ class TestLinearizerFailures(unittest.TestCase):
         UOp(Ops.LOAD, dtypes.half, arg=None, src=(
           UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
           UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
-    opts = [Opt(op=OptOps.TC, axis=2, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=2, arg=7), Opt(op=OptOps.UPCAST, axis=1, arg=2)]
+    opts = [Opt(op=OptOps.TC, axis=2, arg=(-1, 2, 1)), Opt(op=OptOps.UPCAST, axis=2, arg=7), Opt(op=OptOps.UPCAST, axis=1, arg=2)]
     helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["HIP", "AMD"])
 
   @unittest.skipIf(CI and Device.DEFAULT in {"METAL"}, "hangs metal gpu CI")
tinygrad/codegen/kernel.py

@@ -11,7 +11,7 @@ from tinygrad.device import Device
 from tinygrad.renderer import Renderer, TensorCore, ProgramSpec, Opt, OptOps
 from tinygrad.dtype import ImageDType
 from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap, ContextVar
-from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, USE_TC, AMX, CAPTURE_PROCESS_REPLAY
+from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, AMX, CAPTURE_PROCESS_REPLAY
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import strides_for_shape
 from tinygrad.codegen.linearize import linearize_uop
@@ -314,7 +314,7 @@ class Kernel:
     if tc_opt is None: tc_opt = TC_OPT.value
     if not self.opts.tensor_cores: return False
     try: # check TC first and apply hand-coded opts if successful
-      self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt)))
+      self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt, use_tensor_cores)))
 
       if (tc_opts:=self.tensor_core_opts) is not None:
         if extra_opts is not None: self.apply_opts(extra_opts)
@@ -344,10 +344,10 @@ class Kernel:
       check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
       check(len(self.opts.tensor_cores) > 0, "must have tensor cores")
       check(opt.axis is not None, "tensor core opts must have an axis")
-      check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 2, "tensor core opts must have tc_select and tc_opt")
+      check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg")
       check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
       check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
-      check(0 < (use_tensor_cores:=USE_TC.value) <= 3, "use_tensor_cores value is not valid")
+      check(0 < (use_tensor_cores:=cast(tuple, opt.arg)[2]) <= 3, "use_tensor_cores value is not valid")
       check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt), "no tensor core available")
       self.applied_opts.append(opt)
       return
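With the arg now carrying the mode, a stale two-element tuple fails validation up front instead of silently picking up USE_TC from the environment. A hedged sketch of the new contract (`ast` again stands in for any reduce kernel AST):

```python
# sketch: apply_opt now requires a length-3 TC arg
from tinygrad import Device
from tinygrad.codegen.kernel import Kernel, KernelOptError
from tinygrad.renderer import Opt, OptOps

k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
try:
  k.apply_opt(Opt(OptOps.TC, 0, (-1, 2)))    # old 2-tuple: "tensor core opts must have valid arg"
except KernelOptError:
  pass
k.apply_opt(Opt(OptOps.TC, 0, (-1, 2, 3)))   # 3-tuple: use_tensor_cores=3 selects emulation
```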
tinygrad/engine/search.py

@@ -19,8 +19,9 @@ actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,
 actions += [Opt(op=OptOps.GROUP, axis=axis, arg=amt) for amt in [0,4,8,16] for axis in range(3)]
 if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, arg=amt) for amt in [32] for axis in range(7)]
 actions += [Opt(op=OptOps.LOCAL, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=6, arg=2)]
-actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0))]
-actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2))) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
+actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0, getenv("TC", 1)))]
+# covers resnet kernels (3 global * 3 reduce)
+actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2), getenv("TC", 1))) for axis in range(9)]
 actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)]
 if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
@@ -111,7 +112,8 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
   for i, action in enumerate(kernel_actions):
     if action.op == OptOps.TC and (tc_arg := cast(tuple, action.arg))[0] == -1:
       # replace every tc_action with default tc with one tc_action for each available tc
-      kernel_actions[i:i+1] = [Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1])) for tc_select,_ in enumerate(lin.opts.tensor_cores)]
+      kernel_actions[i:i+1] = \
+        [Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1], tc_arg[2])) for tc_select,_ in enumerate(lin.opts.tensor_cores)]
 
   for i,a in enumerate(kernel_actions):
     if a.axis is not None and a.op is not OptOps.TC:
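Because get_kernel_actions now copies tc_arg[2] through, the TC mode chosen when the action table was built survives the fan-out from the default tc_select=-1 action to one action per available tensor core. A small sketch of that expansion (two tensor cores assumed for illustration):

```python
# sketch: fan-out of a default TC action, preserving tc_opt and use_tensor_cores
from tinygrad.renderer import Opt, OptOps

action = Opt(op=OptOps.TC, axis=0, arg=(-1, 2, 3))  # tc_select=-1 means "any tensor core"
tc_arg = action.arg
expanded = [Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1], tc_arg[2]))
            for tc_select in range(2)]
# expanded == [Opt(TC, axis=0, arg=(0, 2, 3)), Opt(TC, axis=0, arg=(1, 2, 3))]
```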