mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
fix tests for rewrite [pr] (#10167)
* fix tests for rewrite [pr] * cleaner * delete linearize_uop * clean up the rest
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
import unittest
|
||||
from typing import List, cast
|
||||
import numpy as np
|
||||
from tinygrad.codegen.devectorizer import full_graph_rewrite
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.device import Buffer, Device, is_dtype_supported
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
@@ -13,6 +11,7 @@ from tinygrad.runtime.ops_python import PythonRenderer
|
||||
from tinygrad.ops import UOp, Ops
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.codegen import full_rewrite
|
||||
|
||||
def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
|
||||
for x in inputs: x.realize()
|
||||
@@ -35,7 +34,7 @@ class TestRendererFailures(unittest.TestCase):
|
||||
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, gate_alu), UOp.const(dtypes.int, 1)))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
|
||||
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
|
||||
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
|
||||
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
|
||||
np.testing.assert_equal(ret, [0, 1, 1, 1])
|
||||
|
||||
@@ -46,7 +45,7 @@ class TestRendererFailures(unittest.TestCase):
|
||||
gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx1', 2))).ne(0)
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0+lidx1*4, gate_alu_0&gate_alu_1), UOp.const(dtypes.int, 1)))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
|
||||
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
|
||||
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
|
||||
ret = _test_uop_result([], uops, local_size=[4, 2, 1])[0]
|
||||
np.testing.assert_equal(ret, [0, 0, 0, 0, 0, 1, 1, 1])
|
||||
|
||||
@@ -60,7 +59,7 @@ class TestCStyleFailures(unittest.TestCase):
|
||||
alu = ld.alu(Ops.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
|
||||
store = UOp.store(a.index(idx), alu)
|
||||
sink = UOp(Ops.SINK, dtypes.void, (store,))
|
||||
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
|
||||
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
|
||||
# CPU doesn't use the max function
|
||||
ret = _test_uop_result([Tensor([1])], uops)[0]
|
||||
self.assertEqual(ret[0], 1)
|
||||
@@ -75,7 +74,7 @@ class TestPTXFailures(unittest.TestCase):
|
||||
if_uop = UOp(Ops.IF, dtypes.void, (gate_alu,))
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, if_uop), val))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
|
||||
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
|
||||
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
|
||||
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
|
||||
np.testing.assert_equal(ret, [0, 1, 1, 1])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user