extra/gemm: add a simple_conv.py along with correctness check (#3236)

* extra/gemm: add a simple_conv.py along with correctness check

The goal is to easily test tensor core triggering situations

* test: add tests for acc_dtype handling and fixed typing
This commit is contained in:
Francis Lam
2024-01-26 19:06:57 -08:00
committed by GitHub
parent 0aad8d238b
commit 4273aabe31
3 changed files with 58 additions and 6 deletions

View File

@@ -12,7 +12,7 @@ from tinygrad.tensor import Tensor
from tinygrad.jit import CacheCollector
from tinygrad.realize import run_schedule
from tinygrad.helpers import prod, Context
from tinygrad.dtype import dtypes
from tinygrad.dtype import DType, dtypes
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "linearizer is only for compiled backends")
class TestLinearizer(unittest.TestCase):
@@ -81,6 +81,28 @@ class TestLinearizer(unittest.TestCase):
local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
assert local[0].dtype == acc_dtype
def test_arg_acc_dtype(self):
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
k = Linearizer(c.lazydata.schedule()[-1].ast)
k.linearize()
local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
assert local[0].dtype == expected_dtype
tests = (
(dtypes.float16, None, dtypes.float),
(dtypes.bfloat16, None, dtypes.float),
(dtypes.float, None, dtypes.float),
(dtypes.float16, dtypes.float16, dtypes.float16),
(dtypes.bfloat16, dtypes.bfloat16, dtypes.bfloat16),
(dtypes.float, dtypes.float16, dtypes.float16),
)
for tensor_dtype, acc_dtype, expected_dtype in tests:
a, b = Tensor.rand(8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, dtype=tensor_dtype)
helper_arg_acc_dtype(a.sum(acc_dtype=acc_dtype), expected_dtype)
helper_arg_acc_dtype(a.matmul(b, acc_dtype=acc_dtype), expected_dtype)
d, w = Tensor.rand(4, 8, 8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, 2, 2, dtype=tensor_dtype)
helper_arg_acc_dtype(d.conv2d(w, acc_dtype=acc_dtype), expected_dtype)
@unittest.skipUnless(Device.DEFAULT in tensor_cores, "No tensor cores for device")
def test_tensor_cores(self):
for tc in tensor_cores[Device.DEFAULT]: