mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
extra/gemm: add a simple_conv.py along with correctness check (#3236)
* extra/gemm: add a simple_conv.py along with correctness check The goal is to easily test tensor core triggering situations * test: add tests for acc_dtype handling and fixed typing
This commit is contained in:
@@ -12,7 +12,7 @@ from tinygrad.tensor import Tensor
|
||||
from tinygrad.jit import CacheCollector
|
||||
from tinygrad.realize import run_schedule
|
||||
from tinygrad.helpers import prod, Context
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
|
||||
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "linearizer is only for compiled backends")
|
||||
class TestLinearizer(unittest.TestCase):
|
||||
@@ -81,6 +81,28 @@ class TestLinearizer(unittest.TestCase):
|
||||
local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
|
||||
assert local[0].dtype == acc_dtype
|
||||
|
||||
def test_arg_acc_dtype(self):
|
||||
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
|
||||
k = Linearizer(c.lazydata.schedule()[-1].ast)
|
||||
k.linearize()
|
||||
local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
|
||||
assert local[0].dtype == expected_dtype
|
||||
|
||||
tests = (
|
||||
(dtypes.float16, None, dtypes.float),
|
||||
(dtypes.bfloat16, None, dtypes.float),
|
||||
(dtypes.float, None, dtypes.float),
|
||||
(dtypes.float16, dtypes.float16, dtypes.float16),
|
||||
(dtypes.bfloat16, dtypes.bfloat16, dtypes.bfloat16),
|
||||
(dtypes.float, dtypes.float16, dtypes.float16),
|
||||
)
|
||||
for tensor_dtype, acc_dtype, expected_dtype in tests:
|
||||
a, b = Tensor.rand(8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, dtype=tensor_dtype)
|
||||
helper_arg_acc_dtype(a.sum(acc_dtype=acc_dtype), expected_dtype)
|
||||
helper_arg_acc_dtype(a.matmul(b, acc_dtype=acc_dtype), expected_dtype)
|
||||
d, w = Tensor.rand(4, 8, 8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, 2, 2, dtype=tensor_dtype)
|
||||
helper_arg_acc_dtype(d.conv2d(w, acc_dtype=acc_dtype), expected_dtype)
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT in tensor_cores, "No tensor cores for device")
|
||||
def test_tensor_cores(self):
|
||||
for tc in tensor_cores[Device.DEFAULT]:
|
||||
|
||||
Reference in New Issue
Block a user