mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
@@ -8,7 +8,7 @@ from tinygrad.dtype import dtypes
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.helpers import dedup, flatten, prod
|
||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||
from tinygrad.ops import BinaryOps, UOp, UOps
|
||||
from tinygrad.ops import BinaryOps, UOp, Ops
|
||||
from tinygrad.renderer import Program
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
@@ -20,7 +20,7 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
|
||||
def _recursive_add(uop:UOp) -> List[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
|
||||
uops = dedup(flatten(_recursive_add(st) for st in stores))
|
||||
outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
|
||||
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is UOps.STORE]
|
||||
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
|
||||
inbufs = [cast(LazyBuffer,x.lazydata).base.buffer for x in inputs]
|
||||
src = Device[Device.DEFAULT].renderer.render("test", uops)
|
||||
ei = CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops, local_size=local_size))
|
||||
@@ -30,13 +30,13 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
|
||||
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, CStyleLanguage), "uops are for cstyle")
|
||||
class TestCStyleFailures(unittest.TestCase):
|
||||
def test_inline_const_alu(self):
|
||||
a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
b = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
|
||||
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
b = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (b, idx))
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (b, idx))
|
||||
alu = ld.alu(BinaryOps.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
|
||||
store = UOp.store(a, idx, alu)
|
||||
sink = UOp(UOps.SINK, dtypes.void, (store,))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (store,))
|
||||
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
|
||||
# CLANG doesn't use the max function
|
||||
ret = _test_uop_result([Tensor([1])], uops)[0]
|
||||
@@ -45,21 +45,21 @@ class TestCStyleFailures(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local and Device.DEFAULT == "PTX", "need local")
|
||||
class TestPTXFailures(unittest.TestCase):
|
||||
def test_gated_store_with_alu(self):
|
||||
a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
|
||||
gated_alu_store = UOp(UOps.STORE, dtypes.void, (a, lidx0, UOp.const(dtypes.int, 1), gate_alu))
|
||||
sink = UOp(UOps.SINK, dtypes.void, (gated_alu_store,))
|
||||
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a, lidx0, UOp.const(dtypes.int, 1), gate_alu))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
|
||||
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
|
||||
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
|
||||
np.testing.assert_equal(ret, [0, 1, 1, 1])
|
||||
|
||||
def test_gated_store_with_if(self):
|
||||
a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
|
||||
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
|
||||
val = UOp.const(dtypes.int, 1)
|
||||
if_uop = UOp(UOps.IF, dtypes.void, (gate_alu,))
|
||||
gated_alu_store = UOp(UOps.STORE, dtypes.void, (a, lidx0, val, if_uop))
|
||||
sink = UOp(UOps.SINK, dtypes.void, (gated_alu_store,))
|
||||
if_uop = UOp(Ops.IF, dtypes.void, (gate_alu,))
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a, lidx0, val, if_uop))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
|
||||
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
|
||||
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
|
||||
np.testing.assert_equal(ret, [0, 1, 1, 1])
|
||||
|
||||
Reference in New Issue
Block a user