mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add fused tensor core opts tests (#4775)
* add fused tc opts tests * n=64
This commit is contained in:
@@ -2,11 +2,13 @@ from typing import List, Tuple
|
||||
import numpy as np
|
||||
import unittest
|
||||
from dataclasses import replace
|
||||
from test.external.fuzz_linearizer import compare_linearizer
|
||||
|
||||
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOp, UOps, expand_node, expand_idxs
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps, ReduceOps, UnaryOps
|
||||
from tinygrad.renderer import TensorCore
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.shape.symbolic import MulNode, Variable, NumNode, Node
|
||||
@@ -995,6 +997,29 @@ def _temp_create_multireduce_ast(r0:Tensor, r1:Tensor, merge=lambda r0,r1: LazyO
|
||||
if DEBUG >= 3: print_tree(op)
|
||||
return op,
|
||||
|
||||
def check_fused_tc_opt(tc:TensorCore, r0:Tensor, r1:Tensor, inputs:List[Tensor]):
|
||||
ast = _temp_create_multireduce_ast(r0, r1)
|
||||
(atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)
|
||||
helper_linearizer_ast(ast, inputs, [
|
||||
[],
|
||||
[Opt(OptOps.UPCAST, 0, 4)],
|
||||
[Opt(OptOps.UPCAST, 1, 4)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts
|
||||
[Opt(OptOps.UNROLL, 0, 2)], # check unroll
|
||||
[Opt(OptOps.UNROLL, 0, 0)], # check full unroll of reduce with locals
|
||||
[Opt(OptOps.LOCAL, 0, 4)], # check local
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of unroll and local
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)],
|
||||
[Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4)], # check permutations
|
||||
[Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4)],
|
||||
[Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)],
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
|
||||
# [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
|
||||
], apply_tc=True, atol=atol, rtol=rtol)
|
||||
|
||||
class TestKernelOpts(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||
@@ -1214,6 +1239,25 @@ class TestKernelOpts(unittest.TestCase):
|
||||
with self.assertRaises(KernelOptError):
|
||||
k.apply_opt(Opt(OptOps.TC, 0, 1))
|
||||
|
||||
@unittest.skip("multireduce isn't supported yet")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
def test_invalid_fused_tensor_core(self):
|
||||
Tensor.manual_seed(1552)
|
||||
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
|
||||
if tc.dtype_in == dtypes.bfloat16: continue
|
||||
M, N, K = 12, 8, 30
|
||||
a, b = Tensor.rand(M, K, dtype=tc.dtype_in).realize(), Tensor.rand(K, N, dtype=tc.dtype_in).realize()
|
||||
r0 = a.matmul(b, acc_dtype=tc.dtype_out)
|
||||
M, N, K = 16, 8, 33
|
||||
c, d = Tensor.rand(M, K, dtype=tc.dtype_in).realize(), Tensor.rand(K, N, dtype=tc.dtype_in).realize()
|
||||
r1 = c.matmul(d, acc_dtype=tc.dtype_out)
|
||||
ast = _temp_create_multireduce_ast(r0, r1)
|
||||
lin = Linearizer(*ast)
|
||||
lin.apply_opt(Opt(op=OptOps.TC, axis=0, amt=2))
|
||||
lin.linearize()
|
||||
result = compare_linearizer(lin)
|
||||
assert result[0] == "COMPARE_ERROR"
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
def test_tensor_core_opts(self):
|
||||
N = 128
|
||||
@@ -1244,39 +1288,35 @@ class TestKernelOpts(unittest.TestCase):
|
||||
# [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
|
||||
], apply_tc=True, atol=atol, rtol=rtol)
|
||||
|
||||
# NOTE: indexing issue happens here
|
||||
@unittest.skip("multireduce isn't supported yet")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
def test_tensor_core_opts_multireduce(self):
|
||||
N = 128
|
||||
def test_fused_tensor_core_simple(self):
|
||||
N = 64
|
||||
Tensor.manual_seed(1552)
|
||||
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
|
||||
# bf16 buffer returns float32 numpy outputs so test would fail. testing opt with half suffices.
|
||||
if tc.dtype_in == dtypes.bfloat16: continue
|
||||
a, b = Tensor.rand(N, N, dtype=tc.dtype_in).realize(), Tensor.rand(N, N, dtype=tc.dtype_in).realize()
|
||||
[a, b, c, d] = [Tensor.randn(N, N, dtype=tc.dtype_in).realize() for _ in range(4)]
|
||||
r0 = a.matmul(b, acc_dtype=tc.dtype_out)
|
||||
c, d = Tensor.rand(N, N, dtype=tc.dtype_in).realize(), Tensor.rand(N, N, dtype=tc.dtype_in).realize()
|
||||
r1 = c.matmul(d, acc_dtype=tc.dtype_out)
|
||||
ast = _temp_create_multireduce_ast(r0, r1)
|
||||
(atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)
|
||||
helper_linearizer_ast(ast, [a, b, c, d], [
|
||||
[],
|
||||
[Opt(OptOps.UPCAST, 0, 4)],
|
||||
[Opt(OptOps.UPCAST, 1, 4)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts
|
||||
[Opt(OptOps.UNROLL, 0, 2)], # check unroll
|
||||
[Opt(OptOps.UNROLL, 0, 0)], # check full unroll of reduce with locals
|
||||
[Opt(OptOps.LOCAL, 0, 4)], # check local
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of unroll and local
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)],
|
||||
[Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4)], # check permutations
|
||||
[Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4)],
|
||||
[Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)],
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
|
||||
# [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
|
||||
], apply_tc=True, atol=atol, rtol=rtol, wanna_output=[np.matmul(a.numpy(), b.numpy()).flatten() + np.matmul(c.numpy(), d.numpy()).flatten()])
|
||||
check_fused_tc_opt(tc, r0, r1, [a, b, c, d])
|
||||
|
||||
@unittest.skip("multireduce isn't supported yet")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
def test_fused_tensor_core_permuted(self):
|
||||
N = 64
|
||||
Tensor.manual_seed(1552)
|
||||
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
|
||||
if tc.dtype_in == dtypes.bfloat16: continue
|
||||
# one permuted
|
||||
[a, b, c, d] = [Tensor.randn(N, N, dtype=tc.dtype_in).realize() for _ in range(4)]
|
||||
r0 = a.matmul(b, acc_dtype=tc.dtype_out)
|
||||
r1 = c.T.matmul(d, acc_dtype=tc.dtype_out)
|
||||
check_fused_tc_opt(tc, r0, r1, [a, b, c, d])
|
||||
# both permuted
|
||||
r0 = a.T.matmul(b, acc_dtype=tc.dtype_out)
|
||||
r1 = c.T.matmul(d, acc_dtype=tc.dtype_out)
|
||||
check_fused_tc_opt(tc, r0, r1, [a, b, c, d])
|
||||
|
||||
def test_padto_matmul(self):
|
||||
if CI and Device.DEFAULT in ["CUDA", "AMD", "NV"]: self.skipTest("super slow on CUDA and AMD because of the big grid dims")
|
||||
|
||||
Reference in New Issue
Block a user