From 603c03bef261fadf493c8ad8451c121b8a546360 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 5 May 2025 19:19:49 -0700
Subject: [PATCH] fix tests for rewrite [pr] (#10167)

* fix tests for rewrite [pr]

* cleaner

* delete linearize_uop

* clean up the rest
---
 test/helpers.py                      |  7 ++-
 test/test_const_folding.py           |  6 +--
 test/test_renderer_failures.py       | 11 ++---
 test/test_tensor.py                  |  6 +--
 test/test_uop_graph.py               | 69 ++++------------------
 test/test_uops.py                    |  9 ++--
 test/test_uops_stats.py              |  6 +--
 test/unit/test_graph_rewrite.py      | 10 ++--
 test/unit/test_simplify_valid_idx.py |  8 ++--
 test/unit/test_uop_symbolic.py       |  8 ++--
 tinygrad/codegen/__init__.py         |  4 ++
 tinygrad/codegen/devectorizer.py     | 29 ++----
 tinygrad/codegen/expander.py         | 16 +------
 tinygrad/codegen/kernel.py           |  5 +-
 tinygrad/codegen/linearize.py        | 22 +--------
 tinygrad/codegen/lowerer.py          | 11 +----
 16 files changed, 56 insertions(+), 171 deletions(-)

diff --git a/test/helpers.py b/test/helpers.py
index a55373e3f7..a0a6318315 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -9,8 +9,7 @@ from tinygrad.engine.realize import Runner
 from tinygrad.dtype import ConstType, DType
 from tinygrad.nn.state import get_parameters
 from tinygrad.helpers import T, unwrap, CI
-from tinygrad.codegen.linearize import linearize_uop
-from tinygrad.codegen.devectorizer import full_graph_rewrite
+from tinygrad.codegen import full_rewrite
 from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler, PythonAllocator

 def derandomize_model(model):
@@ -59,8 +58,8 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None):
     bufs.append(buf:=allocator.alloc(len(data) * buf_dt.itemsize))
     allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + buf_dt.fmt, *data)))
   g = UOp(Ops.DEFINE_GLOBAL, uop.dtype.ptr(), arg=0, src=())
-  rw = full_graph_rewrite(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(), PythonRenderer)
-  prog = PythonProgram("run", PythonCompiler().compile(PythonRenderer().render(linearize_uop(rw))))
+  lst = full_rewrite(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(), PythonRenderer)
+  prog = PythonProgram("run", PythonCompiler().compile(PythonRenderer().render(lst)))
   prog(out_buf:=allocator.alloc(uop.dtype.itemsize), *bufs)
   return out_buf.cast(uop.dtype.fmt).tolist()[0]

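Note: every test-side change in this patch is the same mechanical substitution. The old two-step pattern

    rw = full_graph_rewrite(sink, renderer)   # graph-level rewrites
    uops = linearize_uop(rw)                  # linear order + blocks

collapses into the single entry point

    uops = full_rewrite(sink, renderer)

and tests that inspect the rewritten graph rather than the linear list call full_rewrite_to_sink(sink), which returns the SINK UOp. This snippet is a sketch of the pattern, not a line from any one file; `sink` and `renderer` stand in for whatever the surrounding test builds.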
diff --git a/test/test_const_folding.py b/test/test_const_folding.py
index 649cbd1a9c..8ed11a301e 100644
--- a/test/test_const_folding.py
+++ b/test/test_const_folding.py
@@ -3,7 +3,7 @@ from typing import Any
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.dtype import DType
 from tinygrad.ops import Ops, UOp
-from tinygrad.codegen.devectorizer import full_graph_rewrite
+from tinygrad.codegen import full_rewrite_to_sink
 import numpy as np
 from tinygrad.device import is_dtype_supported
 from test.helpers import not_support_multi_device
@@ -105,7 +105,7 @@ class TestBitcastConstFolding(unittest.TestCase):
     def t(cases: dict[DType, Any]):
       for (from_dt, from_v), (to_dt, to_v) in itertools.product(cases.items(), cases.items()):
         if not math.isnan(from_v):
-          r = full_graph_rewrite(UOp.const(from_dt, from_v).bitcast(to_dt).sink()).src[0]
+          r = full_rewrite_to_sink(UOp.const(from_dt, from_v).bitcast(to_dt).sink()).src[0]
           self.assertEqual(r.op, Ops.CONST, msg:=f"{from_dt} -> {to_dt} ({from_v} -> {to_v})")
           self.assertEqual(r.dtype, to_dt, msg)
           np.testing.assert_equal(r.arg, to_v, msg)
@@ -128,7 +128,7 @@ class TestBitcastConstFolding(unittest.TestCase):
       t({dtypes.int64: 4598983288165178391, dtypes.uint64: 4598983288165178391, dtypes.float64: 0.29485681936461233})

   def test_vec_bitcast(self):
-    r = full_graph_rewrite(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0]
+    r = full_rewrite_to_sink(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0]
     self.assertEqual(r.op, Ops.VECTORIZE)
     self.assertEqual(r.dtype, dtypes.uint32.vec(3))
     self.assertEqual(tuple(x.arg for x in r.src), (2**32-1, 2**31, 75))
diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py
index 2db5b41389..8c676c873e 100644
--- a/test/test_renderer_failures.py
+++ b/test/test_renderer_failures.py
@@ -1,8 +1,6 @@
 import unittest
 from typing import List, cast
 import numpy as np
-from tinygrad.codegen.devectorizer import full_graph_rewrite
-from tinygrad.codegen.linearize import linearize_uop
 from tinygrad.device import Buffer, Device, is_dtype_supported
 from tinygrad.dtype import dtypes
 from tinygrad.engine.realize import CompiledRunner
@@ -13,6 +11,7 @@ from tinygrad.runtime.ops_python import PythonRenderer
 from tinygrad.ops import UOp, Ops
 from tinygrad.renderer import ProgramSpec
 from tinygrad.tensor import Tensor, _to_np_dtype
+from tinygrad.codegen import full_rewrite

 def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
   for x in inputs: x.realize()
@@ -35,7 +34,7 @@ class TestRendererFailures(unittest.TestCase):
     gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
     gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, gate_alu), UOp.const(dtypes.int, 1)))
     sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
-    uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
+    uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
     ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
     np.testing.assert_equal(ret, [0, 1, 1, 1])

@@ -46,7 +45,7 @@ class TestRendererFailures(unittest.TestCase):
     gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx1', 2))).ne(0)
     gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0+lidx1*4, gate_alu_0&gate_alu_1), UOp.const(dtypes.int, 1)))
     sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
-    uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
+    uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
     ret = _test_uop_result([], uops, local_size=[4, 2, 1])[0]
     np.testing.assert_equal(ret, [0, 0, 0, 0, 0, 1, 1, 1])

@@ -60,7 +59,7 @@ class TestCStyleFailures(unittest.TestCase):
     alu = ld.alu(Ops.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
     store = UOp.store(a.index(idx), alu)
     sink = UOp(Ops.SINK, dtypes.void, (store,))
-    uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
+    uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
     # CPU doesn't use the max function
     ret = _test_uop_result([Tensor([1])], uops)[0]
     self.assertEqual(ret[0], 1)
@@ -75,7 +74,7 @@ class TestPTXFailures(unittest.TestCase):
     if_uop = UOp(Ops.IF, dtypes.void, (gate_alu,))
     gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, if_uop), val))
     sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
-    uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
+    uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
     ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
     np.testing.assert_equal(ret, [0, 1, 1, 1])

diff --git a/test/test_tensor.py b/test/test_tensor.py
index eacf413d89..f10c365e18 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -10,9 +10,7 @@ from hypothesis import given, settings, strategies as strat
 from tinygrad.device import is_dtype_supported
 from tinygrad.ops import Ops, UOp
 from tinygrad.runtime.support.compiler_cuda import PTX
-from tinygrad.codegen.linearize import linearize_uop
-from tinygrad.codegen.devectorizer import full_graph_rewrite
-from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
+from tinygrad.codegen import full_rewrite
 from tinygrad.dtype import DType

 settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
@@ -847,7 +845,7 @@ class TestIdxUpcast(unittest.TestCase):
     for s in schedule:
       if s.ast.op is Ops.SINK:
         renderer = Device[s.bufs[0].device].renderer
-        uops = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(s.ast, renderer), renderer))
+        uops = full_rewrite(s.ast, renderer)
         renderer.render(uops)
         return uops

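Note: the hunk above drops the explicit lowering call, not just the linearize step. full_rewrite(s.ast, renderer) takes the scheduled AST directly because lowering now runs inside the shared rewrite pipeline:

    # before: three hand-sequenced stages
    uops = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(s.ast, renderer), renderer))
    # after: one pipeline call
    uops = full_rewrite(s.ast, renderer)

rewrite_shapetracker_with_index itself is deleted from the lowerer at the bottom of this patch.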
diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py
index 3f8d3b47c3..2c8c8852f1 100644
--- a/test/test_uop_graph.py
+++ b/test/test_uop_graph.py
@@ -1,14 +1,11 @@
 from typing import List
-import unittest, time, pytest
-from tinygrad import dtypes, Device, Variable
+import unittest, pytest
+from tinygrad import dtypes, Variable
 from tinygrad.helpers import DEBUG, Context
-from tinygrad.ops import Ops, UOp, KernelInfo, UPat, PatternMatcher, track_rewrites
-from tinygrad.renderer import Renderer
-from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
-from tinygrad.codegen.devectorizer import full_graph_rewrite, graph_rewrite, sym
-from tinygrad.codegen.expander import expander, expand_rewrite
-from tinygrad.codegen.linearize import linearize_uop
-from tinygrad.shape.shapetracker import ShapeTracker, View
+from tinygrad.ops import Ops, UOp, UPat, PatternMatcher, track_rewrites, graph_rewrite
+from tinygrad.codegen.symbolic import sym
+from tinygrad.codegen import full_rewrite, full_rewrite_to_sink
+from tinygrad.codegen.expander import expander

 simple_pm = PatternMatcher([
   (UPat.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
@@ -19,54 +16,10 @@ def to_uops_list(u:List[UOp]) -> List[UOp]:
   # we strip the SINK here for legacy reasons
-  ret = linearize_uop(full_graph_rewrite(UOp.sink(*u)))
+  ret = full_rewrite(UOp.sink(*u))
   assert ret[-1].op is Ops.SINK
   return ret[:-1]

-class TestGraphRewriteEfficiency(unittest.TestCase):
-  def test_create_many_uops(self):
-    c1 = UOp.const(dtypes.int, 1)
-    c2 = UOp.const(dtypes.int, 2)
-    st = time.perf_counter()
-    uops = [UOp(Ops.ADD, dtypes.int, (c1, c2)) for _ in range(10000)]
-    et = time.perf_counter() - st
-    print(f"created {len(uops)} uops in {et*1000:.2f} ms")
-
-  def test_expand_rewrite(self):
-    sink = UOp(Ops.SINK, dtypes.void, arg=KernelInfo(local_dims=2, upcasted=4, dont_use_locals=False), src=(
-      UOp(Ops.STORE, dtypes.void, arg=None, src=(
-        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
-        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 4, 64, 8, 16, 1, 1, 3, 3, 4, 1),
-          strides=(1179648, 9216, 1, 147456, 576, 0, 0, 64, 192, 36864, 0),
-          offset=0, mask=None, contiguous=False),)), src=()),
-        UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 10)), src=(
-          UOp(Ops.CAST, dtypes.float, arg=None, src=(
-            UOp(Ops.MUL, dtypes.half, arg=None, src=(
-              UOp(Ops.LOAD, dtypes.half, arg=None, src=(
-                UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
-                UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(
-                  View(shape=(1, 1024, 1, 64, 4, 17, 4, 17), strides=(0, 14400, 0, 225, 0, 15, 0, 1), offset=-16,
-                    mask=((0, 1), (0, 1024), (0, 1), (0, 64), (0, 4), (1, 16), (0, 4), (1, 16)), contiguous=False),
-                  View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(0, 73984, 4734976, 0, 4624, 295936, 68, 18, 1224, 0, 1), offset=0,
-                    mask=None, contiguous=False))), src=()),)),
-              UOp(Ops.LOAD, dtypes.half, arg=None, src=(
-                UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
-                UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(
-                  View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(7200, 0, 230400, 900, 0, 14400, 15, 0, 0, 225, 1), offset=0,
-                    mask=None, contiguous=False),)), src=()),)),)),)),)),)),))
-    lower_sink = rewrite_shapetracker_with_index(sink, Device[Device.DEFAULT].renderer)
-    cnt = [0]
-    old_init = UOp.__init__
-    def uop_hook(self, *args, **kwargs):
-      cnt[0] += 1
-      old_init(self, *args, **kwargs)
-    UOp.__init__ = uop_hook
-    st = time.perf_counter()
-    new_sink = full_graph_rewrite(lower_sink)
-    et = time.perf_counter() - st
-    UOp.__init__ = old_init
-    print(f"rewrote in {et*1000:.2f} ms, from {len(lower_sink.toposort())} -> {len(new_sink.toposort())}, creating {cnt[0]} uops")
-
 class TestGraphRewriteConst(unittest.TestCase):
   def test_gep_const(self):
     v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
@@ -572,8 +525,6 @@ class TestUOpGraph(unittest.TestCase):

 @track_rewrites()
 def expander_rewrite(sink): return graph_rewrite(sink, sym + expander)
-@track_rewrites()
-def float4_rewrite(sink): return full_graph_rewrite(sink, Renderer())

 class TestExpander(unittest.TestCase):
   def test_expand_add_broadcast(self):
@@ -735,7 +686,7 @@ class TestIFUOps(unittest.TestCase):
     lbuf = UOp(Ops.LOAD, dtypes.float, (sbuf.index(UOp.const(dtypes.int, 0)), barrier))
     store = UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, 0), gate), lbuf))
     sink = UOp(Ops.SINK, dtypes.void, (store,))
-    sink = full_graph_rewrite(expand_rewrite(sink))
+    sink = full_rewrite_to_sink(sink)
     if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
     self.assertEqual(len(if_uops), 1)
     self.assertEqual(if_uops[0].src[0], gate)
@@ -753,7 +704,7 @@ class TestIFUOps(unittest.TestCase):
     lbufs = [UOp(Ops.LOAD, dtypes.float, (sbuf.index(UOp.const(dtypes.int, i)), barrier)) for i in range(4)]
     stores = [UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, i), gate), lbufs[i])) for i in range(4)]
     sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
-    sink = full_graph_rewrite(expand_rewrite(sink))
+    sink = full_rewrite_to_sink(sink)
     if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
     self.assertEqual(len(if_uops), 1)
     self.assertEqual(if_uops[0].src[0], gate)
@@ -769,7 +720,7 @@ class TestIFUOps(unittest.TestCase):
     gate = valid&(lidx.ne(2))
     stores = [UOp(Ops.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
     sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
-    sink = full_graph_rewrite(sink)
+    sink = full_rewrite_to_sink(sink)
     if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
     self.assertEqual(len(if_uops), 1)
     self.assertEqual(if_uops[0].src[0], gate)
diff --git a/test/test_uops.py b/test/test_uops.py
index 2563500c7f..dd111be637 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -12,16 +12,15 @@ from tinygrad.spec import spec
 from tinygrad.renderer import ProgramSpec
 from tinygrad.engine.grouper import fix_kernel_ops
 from tinygrad.engine.realize import CompiledRunner, get_kernel
-from tinygrad.codegen.linearize import linearize_uop
-from tinygrad.codegen.devectorizer import full_graph_rewrite
+from tinygrad.codegen import full_rewrite
 from tinygrad.codegen.symbolic import sym
 from tinygrad.device import is_dtype_supported
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps

-def to_uops_list(u:list[UOp], opts=None, skip_check=False) -> list[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)
+def to_uops_list(u:list[UOp], opts=None, skip_check=False) -> list[UOp]: return full_rewrite(UOp.sink(*u), opts)

 def _uops_to_prg(uops_list):
-  uops = linearize_uop(full_graph_rewrite(ast:=UOp.sink(*uops_list), opts=Device[Device.DEFAULT].renderer))
+  uops = full_rewrite(ast:=UOp.sink(*uops_list), opts=Device[Device.DEFAULT].renderer)
   src = Device[Device.DEFAULT].renderer.render(uops)
   has_local = Device[Device.DEFAULT].renderer.has_local
   return CompiledRunner(ProgramSpec("test", src, Device.DEFAULT, ast, uops=uops,
@@ -503,7 +502,7 @@ class TestIndexingOrdering(unittest.TestCase):
     gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
     st0 = UOp(Ops.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
     st1 = UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
-    uops = linearize_uop(UOp.sink(st1, st0), skip_check=True)
+    uops = full_rewrite(UOp.sink(st1, st0))
     stores = [st for st in uops if st.op is Ops.STORE]
     assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"

diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py
index 47f0b69a13..a5369bce51 100644
--- a/test/test_uops_stats.py
+++ b/test/test_uops_stats.py
@@ -3,7 +3,7 @@ from tinygrad import Tensor
 from tinygrad.helpers import getenv, GlobalCounters
 from tinygrad.engine.realize import lower_schedule_item, ProgramSpec
 from tinygrad.renderer import Estimates
-from tinygrad.codegen.linearize import linearize_uop
+from tinygrad.codegen import full_rewrite
 from tinygrad.ops import Ops, UOp
 from tinygrad.dtype import dtypes
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
@@ -144,7 +144,7 @@ class TestUOpsStats(unittest.TestCase):
     u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
     u4 = UOp(Ops.MUL, dtypes.int, (u1,u2))
     u5 = UOp(Ops.ADD, dtypes.int, (u4,u3))
-    uops = linearize_uop(u5.sink())
+    uops = full_rewrite(u5.sink())

     globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
     o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
@@ -153,7 +153,7 @@ class TestUOpsStats(unittest.TestCase):
     u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
     u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
     u4 = UOp(Ops.MULACC, dtypes.int, (u1,u2,u3))
-    uops_fma = linearize_uop(u4.sink())
+    uops_fma = full_rewrite(u4.sink())

     self.assertEqual(flops_mem(uops), flops_mem(uops_fma))

diff --git a/test/unit/test_graph_rewrite.py b/test/unit/test_graph_rewrite.py
index 43d0dd7385..ed33c9e764 100644
--- a/test/unit/test_graph_rewrite.py
+++ b/test/unit/test_graph_rewrite.py
@@ -2,11 +2,11 @@ import unittest, math
 from tinygrad import dtypes
 from tinygrad.helpers import all_same
 from tinygrad.ops import GroupOp, UOp, Ops, exec_alu
-from tinygrad.codegen.devectorizer import full_graph_rewrite
+from tinygrad.codegen import full_rewrite_to_sink

 # Helper function to apply the graph rewrite
 def apply_rewrite(expr):
-  return full_graph_rewrite(expr.sink()).src[0]
+  return full_rewrite_to_sink(expr.sink()).src[0]

 def evaluate_uop(uop, variables):
   if uop.op == Ops.CONST:
@@ -145,7 +145,7 @@ class TestModuloAndDivisionFolding(unittest.TestCase):

 class TestEdgeCasesAndSpecialOperations(unittest.TestCase):
   def test_full_graph_rewrite_transcendental_edge_cases(self):
-    optimized_sink = full_graph_rewrite(UOp.const(dtypes.float32, -1.0).log2().sink(UOp.const(dtypes.float32, 0.0).reciprocal()))
+    optimized_sink = full_rewrite_to_sink(UOp.const(dtypes.float32, -1.0).log2().sink(UOp.const(dtypes.float32, 0.0).reciprocal()))
     optimized_log2_neg, optimized_recip_zero = optimized_sink.src
     self.assertTrue(math.isnan(optimized_log2_neg.arg), f"Expected NaN for log2(-1.0), got {optimized_log2_neg.arg}")
     self.assertTrue(math.isinf(optimized_recip_zero.arg) and optimized_recip_zero.arg > 0,
@@ -154,14 +154,14 @@ class TestEdgeCasesAndSpecialOperations(unittest.TestCase):
   @unittest.skip("broken")
   def test_full_graph_rewrite_modulo_negative_dividend(self):
     x_var_uop = UOp.variable('x', -5, -1)
-    optimized_sink = full_graph_rewrite((x_var_uop % 3).sink())
+    optimized_sink = full_rewrite_to_sink((x_var_uop % 3).sink())
     for x_value in range(-5, 0):
       self.assertEqual(x_value % 3, evaluate_uop(optimized_sink.src[0], {'x': x_value}))

   @unittest.skip("broken")
   def test_full_graph_rewrite_division_negative_divisor(self):
     x_var_uop = UOp.variable('x', 1, 5)
-    optimized_sink = full_graph_rewrite((x_var_uop // -2).sink())
+    optimized_sink = full_rewrite_to_sink((x_var_uop // -2).sink())
     for x_value in range(1, 6):
       self.assertEqual(x_value // -2, evaluate_uop(optimized_sink.src[0], {'x': x_value}))

diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py
index b6bc56254b..b501c5a081 100644
--- a/test/unit/test_simplify_valid_idx.py
+++ b/test/unit/test_simplify_valid_idx.py
@@ -1,6 +1,6 @@
 import unittest, itertools

-from tinygrad.codegen.devectorizer import full_graph_rewrite
+from tinygrad.codegen import full_rewrite_to_sink
 from tinygrad.dtype import dtypes
 from tinygrad.ops import UOp, Ops
 from tinygrad.codegen.symbolic import simplify_valid
@@ -45,7 +45,7 @@ class TestHelpers(unittest.TestCase):

 class TestValidIdxSimplification(unittest.TestCase):
   def check(self, load, sidx, svalid):
-    load = full_graph_rewrite(load.sink()).src[0]
+    load = full_rewrite_to_sink(load.sink()).src[0]
     idx, valid = load.src[0].src[1], load.src[0].src[2]
     self.assertEqual(idx.render(simplify=False), sidx)
     self.assertEqual(valid.render(simplify=False), svalid)
@@ -167,7 +167,7 @@ class TestValidIdxSimplification(unittest.TestCase):

 class TestImageSimplification(unittest.TestCase):
   def check(self, load, svalid, sidx0, sidx1):
-    load = full_graph_rewrite(load.sink()).src[0]
+    load = full_rewrite_to_sink(load.sink()).src[0]
     idx = load.src[0].src[1]
     self.assertEqual(idx.op, Ops.VECTORIZE)
     self.assertEqual(len(idx.src), 2)
@@ -233,7 +233,7 @@ class TestImageSimplification(unittest.TestCase):

     # empty -> invalid
     load = get_load_image_uop(shape, (gidx0<8) & (gidx0<8).ne(True), idx)
-    load = full_graph_rewrite(load.sink()).src[0]
+    load = full_rewrite_to_sink(load.sink()).src[0]
     self.assertEqual(load.op, Ops.VECTORIZE)
     self.assertEqual(load.dtype.count, 4)

diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py
index 4df4260f84..7bcfe28dea 100644
--- a/test/unit/test_uop_symbolic.py
+++ b/test/unit/test_uop_symbolic.py
@@ -2,8 +2,8 @@

 import unittest, pickle
 from tinygrad.dtype import dtypes, ConstType
-from tinygrad.codegen.linearize import linearize_uop
-from tinygrad.codegen.devectorizer import full_graph_rewrite, sym
+from tinygrad.codegen import full_rewrite
+from tinygrad.codegen.devectorizer import sym
 from tinygrad.ops import UOp, Ops, graph_rewrite, sym_infer
 from tinygrad import Variable
 import functools
@@ -11,7 +11,7 @@ import functools
 def render(self) -> tuple[str, ConstType, ConstType]:
   # NOTE: we need STORE so the ALU op has children
   glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
-  uops = linearize_uop(full_graph_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), self)).sink()))
+  uops = full_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), self)).sink())
   rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1]
   return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax

@@ -569,7 +569,7 @@ class TestSymbolic(unittest.TestCase):

     # TODO: copied from render, render does not support cast
     glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
-    uops = linearize_uop(full_graph_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink()))
+    uops = full_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink())
     rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1]
     self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half)))

diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py
index 9d1e55e221..528577adc5 100644
--- a/tinygrad/codegen/__init__.py
+++ b/tinygrad/codegen/__init__.py
@@ -72,3 +72,7 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
     ret.append(RewriteStep(block_merge, name="Linearizer: Merge Blocks"))
     ret.append(RewriteStep(pm_finalize, name="Linearizer: Finalize"))
   return ret
+
+def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, linearizer:bool=False) -> UOp:
+  return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), linearizer))
+def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]: return list(full_rewrite_to_sink(sink, opts, linearizer=True).arg.lst)
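Note: these two wrappers are the whole replacement surface for full_graph_rewrite + linearize_uop. A minimal sketch of driving them by hand, patterned on the DEFINE_GLOBAL/STORE scaffolding in the tests above (when opts is None the default Renderer() is used; the stored constant is a placeholder):

    from tinygrad import dtypes
    from tinygrad.ops import UOp, Ops
    from tinygrad.codegen import full_rewrite, full_rewrite_to_sink

    glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
    sink = UOp.store(glbl.index(UOp.const(dtypes.int, 0)), UOp.const(dtypes.int, 42)).sink()
    uops = full_rewrite(sink)           # linearized list[UOp], ready for a renderer's render()
    graph = full_rewrite_to_sink(sink)  # rewritten SINK UOp, graph form, no linearizer steps

full_rewrite is just full_rewrite_to_sink with linearizer=True, reading the flat list out of the final BLOCKFINAL uop's arg (.arg.lst).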
diff --git a/tinygrad/codegen/devectorizer.py b/tinygrad/codegen/devectorizer.py
index e22fbf76bf..50ee88069f 100644
--- a/tinygrad/codegen/devectorizer.py
+++ b/tinygrad/codegen/devectorizer.py
@@ -1,12 +1,12 @@
-from typing import Optional, Any, Callable, cast
+from typing import Any, Callable, cast
 import functools, operator, itertools
 from collections import defaultdict
 from dataclasses import dataclass
 from tinygrad.device import is_dtype_supported
 from tinygrad.dtype import dtypes, ImageDType, PtrDType, promo_lattice, DType
 from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve, graph_rewrite, GroupOp, identity_element
-from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat, gep_pushing
-from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE, partition
+from tinygrad.codegen.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
+from tinygrad.helpers import getenv, flatten, AMX, prod, partition
 from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
 from tinygrad.renderer import Renderer

@@ -429,26 +429,3 @@ pm_reduce = PatternMatcher([
   (UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
    lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
 ])+sym
-
-# *** uop graph ***
-
-def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
-  assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
-  supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
-  extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
-
-  # remove reduce
-  sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
-
-  # devectorize is optional
-  if DEVECTORIZE >= 2: sink = graph_rewrite(sink, sym+load_store_folding+load_store_indexing, ctx=opts)
-  elif DEVECTORIZE: sink = graph_rewrite(sink, sym+devectorize+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
-  else: sink = graph_rewrite(sink, sym+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
-
-  # optional pre matcher
-  if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher)
-
-  # final rules for the renderer (without sym)
-  sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher,
-                       ctx=opts, name="final rewrite")
-  return sink
diff --git a/tinygrad/codegen/expander.py b/tinygrad/codegen/expander.py
index 739bb03c99..11f33afe4a 100644
--- a/tinygrad/codegen/expander.py
+++ b/tinygrad/codegen/expander.py
@@ -2,8 +2,7 @@

 import functools, itertools, operator
 from tinygrad.helpers import AMX, dedup, flatten, all_same, prod
-from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, graph_rewrite
-from tinygrad.codegen.symbolic import sym
+from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp

 def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int:
   idx, mul = 0, 1
@@ -137,16 +136,3 @@ pm_delete_ignore = PatternMatcher([
   # IGNORE on SELF is nothing
   (UPat(Ops.IGNORE, src=(UPat(name="x"), UPat())), lambda x: x),
 ])
-
-def expand_rewrite(sink:UOp) -> UOp:
-  # initial symbolic + migrate indexing (remove this)
-  sink = graph_rewrite(sink, sym+migrate_indexing)
-
-  # store IGNORE
-  sink = graph_rewrite(sink, pm_store_ignore, name="store_ignore")
-
-  # move IGNORE
-  sink = graph_rewrite(sink, pm_move_ignore, name="move_ignore")
-
-  # expand + remove surviving ignores
-  return graph_rewrite(sink, pm_delete_ignore+sym+expander)
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 15d653f9f7..dba3174e09 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -16,7 +16,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import strides_for_shape
 from tinygrad.codegen.lowerer import get_contraction
 from tinygrad.engine.grouper import view_left
-from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites
+from tinygrad.codegen import full_rewrite

 class KernelOptError(Exception): pass

@@ -553,8 +553,7 @@ class Kernel:
     #if __debug__: type_verify(list(modified_ast.toposort()), shape_spec)

     try:
-      rewrite_list = get_rewrites_for_renderer(self.opts)
-      self.uops:list[UOp] = list(apply_rewrites(modified_ast, rewrite_list).arg.lst)
+      self.uops:list[UOp] = full_rewrite(modified_ast, self.opts)
     except RuntimeError:
       print("***** LINEARIZE FAILURE *****")
       print(f"ast = {self.ast}")
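Note: with Kernel.linearize reduced to one call, the kernel compiler and the tests share a single code path through the staged rewrite list (including the named "Linearizer: ..." steps visible in the tinygrad/codegen/__init__.py context above):

    # before
    rewrite_list = get_rewrites_for_renderer(self.opts)
    self.uops:list[UOp] = list(apply_rewrites(modified_ast, rewrite_list).arg.lst)
    # after: same steps, via the shared helper
    self.uops:list[UOp] = full_rewrite(modified_ast, self.opts)

That makes the standalone drivers deleted in this patch (full_graph_rewrite and expand_rewrite above, linearize_uop and rewrite_shapetracker_with_index below) dead code rather than alternate entry points.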
diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py
index 7aa69be320..6cea5bb8ea 100644
--- a/tinygrad/codegen/linearize.py
+++ b/tinygrad/codegen/linearize.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 import heapq
 from collections import defaultdict
 from dataclasses import dataclass, replace
-from tinygrad.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat, GroupOp
+from tinygrad.ops import UOp, Ops, PatternMatcher, UPat, GroupOp
 from tinygrad.helpers import dedup, partition, all_same, flatten
 from tinygrad.spec import type_verify

@@ -243,23 +243,3 @@ def finalize(sink:UOp) -> UOp:
   return UOp(Ops.BLOCKFINAL, arg=BasicBlock2(tuple(lst)))

 pm_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), finalize)])
-
-def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
-  assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
-
-  # get block context
-  ctx = BlockContext.from_sink(sink)
-
-  # wrap all uops in blocks, already reordered
-  sink = graph_rewrite(sink, block_create, ctx=ctx, name="Linearizer: Create Blocks", bottom_up=True)
-
-  # merge blockends
-  sink = graph_rewrite(sink, pm_blockend_merge, name="Linearizer: Merge Blockends")
-
-  # merge blocks
-  sink = graph_rewrite(sink, block_merge, name="Linearizer: Merge Blocks")
-
-  # finalize
-  sink = graph_rewrite(sink, pm_finalize, name="Linearizer: Finalize")
-
-  return list(sink.arg.lst)
\ No newline at end of file
diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index 725184e2b1..38ac4203a9 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -3,10 +3,9 @@ import itertools, operator, math
 from dataclasses import dataclass
 from typing import cast
 from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype
-from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, sint_to_uop
+from tinygrad.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint, sint_to_uop
 from tinygrad.renderer import Renderer
-from tinygrad.helpers import all_int, prod, partition, flatten, unwrap, QUANTIZE
-from tinygrad.codegen.expander import expand_rewrite
+from tinygrad.helpers import all_int, prod, partition, flatten, unwrap
 from tinygrad.codegen.symbolic import symbolic

 # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
@@ -233,9 +232,3 @@ pm_quant = symbolic+PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")),), name="r"),
    lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
 ])
-
-def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
-  if QUANTIZE and opts.device in {"CPU", "DSP"}: ast = graph_rewrite(ast, pm_quant, name="quantize")
-  sink = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
-  # expand_rewrite turns this into a vectorized program
-  return expand_rewrite(sink)
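Note: end to end, the updated eval_uop helper in test/helpers.py is now representative of the whole flow. A sketch mirroring it (PythonRenderer and PythonCompiler are the pure-Python backend from tinygrad.runtime.ops_python; the store target and value are placeholders):

    from tinygrad import dtypes
    from tinygrad.ops import UOp, Ops
    from tinygrad.codegen import full_rewrite
    from tinygrad.runtime.ops_python import PythonRenderer, PythonCompiler

    g = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=())
    sink = UOp.store(g.index(UOp.const(dtypes.int, 0)), UOp.const(dtypes.int, 1)).sink()
    lst = full_rewrite(sink, PythonRenderer)   # rewrite + linearize in one call
    prog = PythonCompiler().compile(PythonRenderer().render(lst))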