diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 7513bd70bd..832e3d4b38 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -1,8 +1,8 @@ from typing import List import unittest, time, pytest from tinygrad import dtypes, Device -from tinygrad.helpers import DEBUG, AMX -from tinygrad.ops import Ops, UOp, KernelInfo, UPat, PatternMatcher +from tinygrad.helpers import DEBUG +from tinygrad.ops import Ops, UOp, KernelInfo, UPat, PatternMatcher, track_rewrites from tinygrad.renderer import Renderer from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index from tinygrad.codegen.devectorizer import full_graph_rewrite, graph_rewrite, sym @@ -502,7 +502,9 @@ class TestUOpGraph(unittest.TestCase): # ranges are closed in the right order self.assertEqual(endranges[-1].src[0], ranges[0]) +@track_rewrites() def expander_rewrite(sink): return graph_rewrite(sink, sym + expander) +@track_rewrites() def float4_rewrite(sink): return full_graph_rewrite(sink, Renderer()) class TestExpander(unittest.TestCase): @@ -652,72 +654,6 @@ class TestExpander(unittest.TestCase): sink = expander_rewrite(sink) print(sink) -class TestLoadStoreFolder(unittest.TestCase): - def test_simple_load_fold(self): - buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) - load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i)),)) for i in range(4)] - sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) - - sink = float4_rewrite(sink.sink()) - assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 1 - - @unittest.skipIf(Device.DEFAULT in {"CPU"} and AMX, "CPU with AMX upcasts float up to size 16") - def test_two_load_fold(self): - buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) - load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i)),)) for i in range(8)] - sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) - sink = float4_rewrite(sink.sink()) - assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 2 - - def test_simple_load_fold_gated(self): - buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) - gate = UOp(Ops.DEFINE_VAR, dtypes.bool) - load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i), gate),)) for i in range(4)] - sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) - sink = float4_rewrite(sink.sink()) - assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 1 - single_load = [x for x in sink.toposort if x.op is Ops.LOAD][0] - self.assertEqual(single_load.src[1].op, Ops.VECTORIZE) - - def test_simple_load_dont_fold_different_gated(self): - buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) - gate = UOp.variable("g1", False, True, dtypes.bool) - gate2 = UOp.variable("g2", False, True, dtypes.bool) - load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i), gate if i == 0 else gate2), - UOp.const(dtypes.float, 0))) for i in range(4)] - sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) - sink = float4_rewrite(sink.sink()) - assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 3 - - def test_simple_store_fold(self): - buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) - load = [UOp(Ops.STORE, dtypes.float, (buf.index(UOp.const(dtypes.int, i)), UOp.const(dtypes.float, 0))) for i in range(4)] - sink = UOp(Ops.SINK, dtypes.void, tuple(load)) - sink = float4_rewrite(sink) - assert len([x for x in sink.toposort if x.op is Ops.STORE]) == 1 - - def test_simple_store_fold_gate(self): - buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) - gate = UOp.variable("g1", False, True, dtypes.bool) - load = [UOp(Ops.STORE, dtypes.float, (buf.index(UOp.const(dtypes.int, i)), UOp.const(dtypes.float, 0), gate)) for i in range(4)] - sink = UOp(Ops.SINK, dtypes.void, tuple(load)) - sink = float4_rewrite(sink) - assert len([x for x in sink.toposort if x.op is Ops.STORE]) == 1 - one_store = [x for x in sink.toposort if x.op is Ops.STORE][0] - assert len(one_store.src) == 3 - _if_node = one_store.src[2] - assert _if_node.op == Ops.IF and _if_node.src[0] == gate - - def test_simple_store_dont_fold(self): - buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) - gate = UOp.variable("g1", False, True, dtypes.bool) - gate2 = UOp.variable("g2", False, True, dtypes.bool) - load = [UOp(Ops.STORE, dtypes.float, (buf.index(UOp.const(dtypes.int, i), gate if i == 0 else gate2), - UOp.const(dtypes.float, i))) for i in range(4)] - sink = UOp(Ops.SINK, dtypes.void, tuple(load)) - sink = float4_rewrite(sink) - assert len([x for x in sink.toposort if x.op is Ops.STORE]) == 3 - class TestIFUOps(unittest.TestCase): def test_create_ifs(self): gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index 773be41997..1ee9bfa8f4 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -54,6 +54,8 @@ class PtrDType(DType): def vec(self, sz:int) -> DType: assert self.v == 1, f"can't vectorize ptr {self} with size {sz}" if sz == 1: return self # sz=1 is a scalar + if isinstance(self, ImageDType): + return ImageDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz, self.size, self.shape) return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz, self.size) def ptr(self, size=-1, local=False): raise RuntimeError("can't make a pointer from a pointer") @property diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 2df87e2281..b54c09488a 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -688,7 +688,7 @@ def get_location() -> tuple[str, int]: # find the real frame in the file that has the UPat, TODO: is there a better way to do this? while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", "multi.py", "symbolic.py", "expander.py", "lowerer.py", "cstyle.py", - "linearize.py"}: + "linearize.py", "devectorizer.py"}: frm = frm.f_back return frm.f_code.co_filename, frm.f_lineno @functools.lru_cache(None)