# Patch for test/test_renderer_failures.py (three hunks).
#
# Hunk 1 (imports): additionally imports `ConstType` from `tinygrad.dtype`; it is used as
#   the annotation of `input_val` in the helper added by hunk 2.
#
# Hunk 2: adds the module-level helper
#   `_setup_and_test_alu(alu_op, input_val, *alu_src_uops)`. As written in the added lines it:
#     - takes `dtype` from `alu_src_uops[0].dtype`;
#     - creates two DEFINE_GLOBAL pointers of that dtype (`a` with arg 0, `b` with arg 1);
#     - loads `b` at constant index 0 and applies `ld.alu(alu_op, *alu_src_uops)`;
#     - stores the ALU result into `a` at index 0 and wraps the store in a SINK;
#     - runs `full_rewrite` with `Device[Device.DEFAULT].renderer`;
#     - returns element [0] of `_test_uop_result([Tensor([input_val])], uops)`.
#   NOTE(review): `alu_src_uops[0]` is indexed unconditionally, so callers must pass at least
#   one source UOp.
#   NOTE(review): `Tensor([input_val])` is built without an explicit dtype -- confirm the
#   inferred dtype matches `dtype` if this helper is reused for non-int operands.
#
# Hunk 3: `TestCStyleFailures.test_inline_const_alu` drops its inline UOp construction and
#   calls `_setup_and_test_alu(Ops.MAX, 1, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))`
#   instead; the `# CPU doesn't use the max function` comment and the
#   `self.assertEqual(ret[0], 1)` assertion are unchanged context.
#
# NOTE(review): the patch text below appears to have lost its original line breaks and
# indentation (every hunk is collapsed onto two physical lines); restore them before
# attempting to apply it.
diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py index 8c676c873e..64080e2b2c 100644 --- a/test/test_renderer_failures.py +++ b/test/test_renderer_failures.py @@ -2,7 +2,7 @@ import unittest from typing import List, cast import numpy as np from tinygrad.device import Buffer, Device, is_dtype_supported -from tinygrad.dtype import dtypes +from tinygrad.dtype import dtypes, ConstType from tinygrad.engine.realize import CompiledRunner from tinygrad.helpers import dedup, flatten, prod from tinygrad.renderer.cstyle import CStyleLanguage @@ -27,6 +27,18 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None): ei.exec(outbufs+inbufs) return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs] +def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp): + dtype = alu_src_uops[0].dtype + a = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 0) + b = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 1) + idx = UOp.const(dtypes.int, 0) + ld = UOp(Ops.LOAD, dtype, (b.index(idx),)) + alu = ld.alu(alu_op, *alu_src_uops) + store = UOp.store(a.index(idx), alu) + sink = UOp(Ops.SINK, dtypes.void, (store,)) + uops = full_rewrite(sink, Device[Device.DEFAULT].renderer) + return _test_uop_result([Tensor([input_val])], uops)[0] + class TestRendererFailures(unittest.TestCase): @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer") def test_gated_store_with_alu(self): @@ -52,16 +64,8 @@ class TestRendererFailures(unittest.TestCase): @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, CStyleLanguage), "uops are for cstyle") class TestCStyleFailures(unittest.TestCase): def test_inline_const_alu(self): - a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) - b = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1) - idx = UOp.const(dtypes.int, 0) - ld = UOp(Ops.LOAD, dtypes.int, (b.index(idx),)) - alu = ld.alu(Ops.MAX, UOp.const(dtypes.int, 
dtypes.min(dtypes.int)+1)) - store = UOp.store(a.index(idx), alu) - sink = UOp(Ops.SINK, dtypes.void, (store,)) - uops = full_rewrite(sink, Device[Device.DEFAULT].renderer) # CPU doesn't use the max function - ret = _test_uop_result([Tensor([1])], uops)[0] + ret = _setup_and_test_alu(Ops.MAX, 1, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1)) self.assertEqual(ret[0], 1) @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "tests for ptx renderer")