From 6dfa63cb2154b614c8aa9e910030c7f22f09ff7e Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 12 Sep 2024 14:58:14 +0800 Subject: [PATCH] more vconst stuff + gep tuple [run_process_replay] (#6494) * more vconst stuff [run_process_replay] * revert that * fix inf loop --- test/test_uop_graph.py | 15 +++++++++++++-- tinygrad/codegen/uopgraph.py | 3 ++- tinygrad/ops.py | 1 + 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 710740ef58..86dcba1798 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -86,9 +86,18 @@ class TestGraphRewriteConst(unittest.TestCase): v1 = UOp.const(dtypes.int.vec(3), (0,1,2)) v2 = UOp.const(dtypes.int.vec(3), (5,6,7)) ret = graph_rewrite(v1+v2, constant_folder) + self.assertEqual(ret.op, UOps.VCONST) self.assertEqual(ret.dtype, dtypes.int.vec(3)) self.assertEqual(ret.arg, (5,7,9)) + def test_add_const_lose_v(self): + v1 = UOp.const(dtypes.int.vec(3), (0,1,2)) + v2 = UOp.const(dtypes.int.vec(3), (2,1,0)) + ret = graph_rewrite(v1+v2, constant_folder) + self.assertEqual(ret.op, UOps.CONST) + self.assertEqual(ret.dtype, dtypes.int.vec(3)) + self.assertEqual(ret.arg, 2) + class TestGraphRewrite(unittest.TestCase): def test_dedup(self): v1 = UOp(UOps.DEFINE_VAR, dtypes.float) @@ -421,7 +430,9 @@ class TestUOpGraph(unittest.TestCase): # ranges are closed in the right order self.assertEqual(endranges[-1].src[0], ranges[0]) -def expander_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + reducer) +def expander_rewrite(sink): + sink = graph_rewrite(sink, constant_folder + expander) + return graph_rewrite(sink, constant_folder + reducer) def float4_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + float4_folding) class TestExpander(unittest.TestCase): @@ -593,7 +604,7 @@ class TestLoadStoreFolder(unittest.TestCase): sink = float4_rewrite(sink) assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1 single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0] - self.assertListEqual([src.arg for src in single_load.src[2].src], [0.0, 1.0, 2.0, 3.0]) + self.assertListEqual(list(single_load.src[2].arg), [0.0, 1.0, 2.0, 3.0]) def test_simple_load_dont_fold_different_gated(self): buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 8bed8b6729..ab9360319e 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -449,6 +449,7 @@ def create_gate(root:UOp) -> Optional[UOp]: return None if len(root.src) == 3 or (ret:=_gate_srcs(root, root.src[3])) is root else ret expander = PatternMatcher([ + (UPat(UOps.VECTORIZE, src=UPat(UOps.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))), # create gate MUST BE BEFORE expander (UPat(UOps.STORE, name="root"), create_gate), # do expansion @@ -490,7 +491,7 @@ reducer = PatternMatcher([ (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load), ]) -no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE), name="x"), +no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.VCONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE), name="x"), lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count) if x.dtype.count > 1 else dtypes.int32, x.src, x.arg) \ if x.dtype.scalar() == dtypes.pyint else None)]) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index d0919c77be..5b4d699d90 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -398,6 +398,7 @@ class UOp(MathTrait): def _const(cls, dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable): # TODO: fix dtype of b.max after Variable is just an UOp if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b.expr, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max)))) # type: ignore + if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same return cls(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore @functools.cached_property def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents.keys()}}