From 7758f7211bc3bfa60f34375937e7cb2d6c229f68 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 3 Nov 2024 16:33:02 -0500 Subject: [PATCH] Revert "s/UPat/Pat (#7506)" [pr] (#7517) * Revert "s/UPat/Pat (#7506)" This reverts commit 400011a8c17c296459ee8b5391485c789a992dbb. * fix --- test/test_pickle.py | 4 +- test/test_uop_graph.py | 10 +- test/test_uops.py | 8 +- test/test_viz.py | 14 +- test/unit/test_pattern_matcher.py | 44 ++--- tinygrad/codegen/lowerer.py | 12 +- tinygrad/codegen/uopgraph.py | 148 +++++++-------- tinygrad/engine/schedule.py | 38 ++-- tinygrad/ops.py | 292 +++++++++++++++--------------- tinygrad/renderer/cstyle.py | 102 +++++------ tinygrad/renderer/ptx.py | 20 +- tinygrad/viz/serve.py | 2 +- 12 files changed, 347 insertions(+), 347 deletions(-) diff --git a/test/test_pickle.py b/test/test_pickle.py index f5d048f24f..1f0053a3ca 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -2,7 +2,7 @@ import unittest, pickle, types import numpy as np from tinygrad import Tensor, TinyJit, Variable, dtypes from tinygrad.engine.schedule import create_schedule -from tinygrad.ops import PatternMatcher, Pat, UOp +from tinygrad.ops import PatternMatcher, UPat, UOp class TestPickle(unittest.TestCase): def test_pickle_code_object(self): @@ -12,7 +12,7 @@ class TestPickle(unittest.TestCase): self.assertEqual(fxn(2), 4) def test_pickle_pattern_matcher(self): - pm = PatternMatcher([(Pat.cvar('x'), lambda x: x*2)]) + pm = PatternMatcher([(UPat.cvar('x'), lambda x: x*2)]) sink = UOp.const(dtypes.int, 2) tt = pm.rewrite(sink) pm_str = pickle.dumps(pm) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 7711a2c13b..8d56eb2d98 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -3,7 +3,7 @@ import unittest, time from tinygrad import dtypes, Device from tinygrad.helpers import DEBUG from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, Ops, UOp, KernelInfo -from tinygrad.ops import Pat, PatternMatcher +from tinygrad.ops import UPat, PatternMatcher from tinygrad.renderer import Renderer from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, sym @@ -11,10 +11,10 @@ from tinygrad.codegen.linearize import linearize_uop from tinygrad.shape.shapetracker import ShapeTracker, View simple_pm = PatternMatcher([ - (Pat.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)), - (Pat.cvar('x') + Pat.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)), - (Pat.cvar('x') * Pat.cvar('y') * Pat.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)), - ((Pat.var('x') + Pat.cvar('c1')) + Pat.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)), + (UPat.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)), + (UPat.cvar('x') + UPat.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)), + (UPat.cvar('x') * UPat.cvar('y') * UPat.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)), + ((UPat.var('x') + UPat.cvar('c1')) + UPat.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)), ]) def to_uops_list(u:List[UOp]) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u))) diff --git a/test/test_uops.py b/test/test_uops.py index 0f233cf596..47b80780d0 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.helpers import CI, DEBUG, getenv, Context from tinygrad.dtype import dtypes, DType from tinygrad.device import Buffer, Device -from tinygrad.ops import Ops, UOp, Pat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401 +from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401 from tinygrad.renderer import Program from tinygrad.engine.schedule import create_schedule, to_si from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel @@ -441,13 +441,13 @@ class TestIndexingOrdering(unittest.TestCase): stores = [st for st in uops if st.op is Ops.STORE] assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}" -class TestPatHelpers(unittest.TestCase): +class TestUPatHelpers(unittest.TestCase): def test_location(self): self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "uopgraph.py") self.assertEqual(to_si.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "schedule.py") self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "ops.py") - with self.assertRaises(AssertionError): # TODO: location Pat files created in test/*? - test_upat = Pat(Ops.CONST, dtypes.bool) + with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*? + test_upat = UPat(Ops.CONST, dtypes.bool) self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1]) if __name__ == '__main__': diff --git a/test/test_viz.py b/test/test_viz.py index 2db2c57b43..02f8820cc5 100644 --- a/test/test_viz.py +++ b/test/test_viz.py @@ -1,7 +1,7 @@ from typing import Dict, List, Optional import unittest from tinygrad.dtype import dtypes -from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, Ops, Pat, \ +from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, \ graph_rewrite, contexts, track_rewrites from tinygrad.viz.serve import get_details, get_metadata, uop_to_json @@ -25,7 +25,7 @@ class TestViz(unittest.TestCase): def test_viz_simple(self): pm = PatternMatcher([ - (Pat.var("x")*1, lambda x:x), + (UPat.var("x")*1, lambda x:x), ]) a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) uops = helper_test_viz(a*1, pm) @@ -34,8 +34,8 @@ class TestViz(unittest.TestCase): def test_rewrite_twice(self): pm = PatternMatcher([ - (Pat.var("x")+Pat.var("x"), lambda x:x*2), - (Pat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))), + (UPat.var("x")+UPat.var("x"), lambda x:x*2), + (UPat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))), ]) a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) uops = helper_test_viz(a+a, pm) @@ -51,14 +51,14 @@ class TestViz(unittest.TestCase): ctx[x] = None return UOp.store(*x.src, x) pm = PatternMatcher([ - (Pat(Ops.LOAD, name="x"), store_load), + (UPat(Ops.LOAD, name="x"), store_load), ]) uops = helper_test_viz(a+b, pm, {}) self.assertEqual(len(uops), 2) self.assertEqual(uops[-1], graph_rewrite(a+b, pm, {})) def test_track_rewrites(self): - simple = PatternMatcher([(Pat.var("x")*1, lambda x:x)]) + simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)]) @track_rewrites(named=True) def do_rewrite(x:UOp): return graph_rewrite(x, simple) ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0))) @@ -74,7 +74,7 @@ class TestViz(unittest.TestCase): self.assertEqual(len(m.upats), 0) def test_track_rewrites_with_exception(self): - simple = PatternMatcher([(Pat.var("x")*1, lambda x:x)]) + simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)]) @track_rewrites() def do_rewrite(x:UOp): x = graph_rewrite(x, simple) # NOTE: viz tracks this diff --git a/test/unit/test_pattern_matcher.py b/test/unit/test_pattern_matcher.py index 6a920b75df..b4a1715bcd 100644 --- a/test/unit/test_pattern_matcher.py +++ b/test/unit/test_pattern_matcher.py @@ -1,11 +1,11 @@ import unittest, itertools from tinygrad.dtype import dtypes from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401 -from tinygrad.ops import PatternMatcher, Pat +from tinygrad.ops import PatternMatcher, UPat class TestPatternMatcher(unittest.TestCase): def test_simple_match(self): - matcher = PatternMatcher([(Pat(Ops.CONST, name="x", dtype=dtypes.float), lambda x: x)]) + matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype=dtypes.float), lambda x: x)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c2 = UOp(Ops.CONST, dtypes.int, arg=1) self.assertEqual(matcher.rewrite(c1), c1) @@ -16,7 +16,7 @@ class TestPatternMatcher(unittest.TestCase): #print(x,y,z) if y is not None: return a+y matcher = PatternMatcher([ - (Pat.var("a")+Pat.any(Pat.var("x"), Pat.var("y"), Pat.var("z")), test), + (UPat.var("a")+UPat.any(UPat.var("x"), UPat.var("y"), UPat.var("z")), test), ]) v1 = UOp.variable("a", 0, 10) v2 = UOp.variable("b", 0, 10) @@ -31,7 +31,7 @@ class TestPatternMatcher(unittest.TestCase): match_cnt += 1 assert len(x.src) == 0 return UOp(Ops.CONST, src=(UOp(Ops.CONST),)) - matcher = PatternMatcher([(Pat(Ops.CONST, src=(), name="x"), fxn)]) + matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) # second rewrite shouldn't match anything c1 = matcher.rewrite(c1) @@ -43,7 +43,7 @@ class TestPatternMatcher(unittest.TestCase): ctx.append(True) assert len(x.src) == 0 return UOp(Ops.CONST, src=(UOp(Ops.CONST),)) - matcher = PatternMatcher([(Pat(Ops.CONST, src=(), name="x"), fxn)]) + matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) # second rewrite shouldn't match anything ctx = [] @@ -52,14 +52,14 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(len(ctx), 1) def test_uop(self): - matcher = PatternMatcher([(Pat(Ops.CONST, name="x"), lambda x: x)]) + matcher = PatternMatcher([(UPat(Ops.CONST, name="x"), lambda x: x)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c2 = UOp(Ops.ALU, dtypes.float, (c1, c1), BinaryOps.ADD) self.assertEqual(matcher.rewrite(c1), c1) self.assertEqual(matcher.rewrite(c2), None) def test_uop_set(self): - matcher = PatternMatcher([(Pat({Ops.CONST, Ops.CAST}, name="x"), lambda x: x)]) + matcher = PatternMatcher([(UPat({Ops.CONST, Ops.CAST}, name="x"), lambda x: x)]) c1 = UOp(Ops.CONST, dtypes.bool, arg=False) c2 = UOp(Ops.CAST, dtypes.int, (c1,)) c3 = UOp(Ops.CONST, dtypes.float, arg=1.0) @@ -70,9 +70,9 @@ class TestPatternMatcher(unittest.TestCase): def test_arg(self): matcher = PatternMatcher([ - (Pat(Ops.CONST, arg=0, name="x"), lambda x: x), - (Pat(Ops.CONST, arg=False, name="x"), lambda x: x), - (Pat(Ops.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x), + (UPat(Ops.CONST, arg=0, name="x"), lambda x: x), + (UPat(Ops.CONST, arg=False, name="x"), lambda x: x), + (UPat(Ops.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x), ]) c1 = UOp(Ops.CONST, dtypes.float, arg=0.0) c2 = UOp(Ops.CONST, dtypes.bool, arg=False) @@ -87,7 +87,7 @@ class TestPatternMatcher(unittest.TestCase): def test_filter_arg(self): matcher = PatternMatcher([ - (Pat(Ops.ALU, arg=BinaryOps.MUL, src=[Pat(Ops.CONST, name="c"), Pat(Ops.CONST, arg=2)], name="x"), + (UPat(Ops.ALU, arg=BinaryOps.MUL, src=[UPat(Ops.CONST, name="c"), UPat(Ops.CONST, arg=2)], name="x"), lambda x,c: x if c.arg in {1, -1} else None) ]) y1 = UOp(Ops.CONST, dtypes.int, arg=1) @@ -105,7 +105,7 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(matcher.rewrite(c5), c5) def test_dup_name(self): - matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=(Pat(Ops.CONST, name="y"), Pat(Ops.CONST, name="y"))), lambda x, y: x)]) + matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST, name="y"), UPat(Ops.CONST, name="y"))), lambda x, y: x)]) y1 = UOp(Ops.CONST, dtypes.float, arg=1.0) y2 = UOp(Ops.CONST, dtypes.float, arg=1.0) c1 = UOp(Ops.ALU, dtypes.float, (y1, y1), BinaryOps.ADD) @@ -114,14 +114,14 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(matcher.rewrite(c2), c1) def test_dtype(self): - matcher = PatternMatcher([(Pat(Ops.CONST, name="x", dtype=dtypes.float32), lambda x: x)]) + matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype=dtypes.float32), lambda x: x)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c2 = UOp(Ops.CONST, dtypes.float64, arg=1.0) self.assertEqual(matcher.rewrite(c1), c1) self.assertEqual(matcher.rewrite(c2), None) def test_dtype_set(self): - matcher = PatternMatcher([(Pat(Ops.CONST, name="x", dtype={dtypes.float32, dtypes.float64}), lambda x: x)]) + matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype={dtypes.float32, dtypes.float64}), lambda x: x)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c2 = UOp(Ops.CONST, dtypes.float64, arg=1.0) c3 = UOp(Ops.CONST, dtypes.float16, arg=1.0) @@ -132,7 +132,7 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(matcher.rewrite(c4), None) def test_src_one(self): - matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=(Pat(Ops.CONST), Pat(Ops.CONST))), lambda x: x)]) + matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST), UPat(Ops.CONST))), lambda x: x)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c2 = UOp(Ops.CONST, dtypes.float, arg=2.0) c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD) @@ -140,7 +140,7 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(matcher.rewrite(c2), None) # that CONST/ALU -> ALU/CONST rewrite is now instant """ - matcher = PatternMatcher([(Pat(UOps.ALU, name="x", src=(Pat(UOps.CONST), Pat(UOps.ALU))), lambda x: x)]) + matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)]) c4 = UOp(UOps.ALU, dtypes.float, (c1,c3), BinaryOps.ADD) c5 = UOp(UOps.ALU, dtypes.float, (c3,c1), BinaryOps.ADD) self.assertEqual(matcher.rewrite(c3), None) @@ -149,7 +149,7 @@ class TestPatternMatcher(unittest.TestCase): """ def test_src_permutations(self): - matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=[Pat(Ops.CONST), Pat(Ops.ALU)]), lambda x: x)]) + matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=[UPat(Ops.CONST), UPat(Ops.ALU)]), lambda x: x)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c2 = UOp(Ops.CONST, dtypes.float, arg=2.0) c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD) @@ -162,7 +162,7 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(matcher.rewrite(c6), None) def test_src_repeat(self): - matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=Pat(Ops.CONST)), lambda x: x)]) + matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=UPat(Ops.CONST)), lambda x: x)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c2 = UOp(Ops.CONST, dtypes.float, arg=2.0) c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD) @@ -171,7 +171,7 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(matcher.rewrite(c4), None) def test_allow_len(self): - matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=(Pat(Ops.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)]) + matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c2 = UOp(Ops.CONST, dtypes.float, arg=2.0) c3 = UOp(Ops.CONST, dtypes.float, arg=3.0) @@ -188,16 +188,16 @@ class TestPatternMatcher(unittest.TestCase): u1 = (c1 + c2) + c1 u2 = (c2 + c1) + c1 matcher = PatternMatcher([ - (Pat(Ops.ALU, src=[Pat(Ops.ALU, src=[Pat(name='a'), Pat(name='b')]), Pat(name='b')]), lambda a,b: b) + (UPat(Ops.ALU, src=[UPat(Ops.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b) ]) self.assertIsNotNone(matcher.rewrite(u1)) self.assertIsNotNone(matcher.rewrite(u2)) - def _assert_eq_upat(self, a:Pat, b:Pat): + def _assert_eq_upat(self, a:UPat, b:UPat): assert (sorted(map(str,a.op)) if a.op else [] == (sorted(map(str,b.op)) if b.op else [])) assert (sorted(a.dtype) if a.dtype else [] == (sorted(b.dtype) if b.dtype else [])) assert (a.name, type(a.src)) == (b.name, type(b.src)) - def simple_src(u:Pat): + def simple_src(u:UPat): if u.src is None: return [] if isinstance(u.src, itertools.repeat): return next(u.src[0]) return u.src[0] diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index d354151b4e..1c2562f189 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -6,7 +6,7 @@ from typing import List, Tuple, cast, Optional from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import variable_to_uop from tinygrad.dtype import dtypes -from tinygrad.ops import KernelInfo, BinaryOps, UOp, Ops, graph_rewrite, PatternMatcher, Pat, sint, identity_element +from tinygrad.ops import KernelInfo, BinaryOps, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element from tinygrad.renderer import Renderer from tinygrad.helpers import all_int, prod, partition, flatten @@ -109,7 +109,7 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp): def lower_load_store(ctx: IndexContext, x: UOp): idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs) - # TODO: check has_valid in Pat, not here + # TODO: check has_valid in UPat, not here has_valid = valid.op is not Ops.CONST or valid.arg is not True buf = x.src[0] if x.op is Ops.LOAD: @@ -127,10 +127,10 @@ def lower_load_store(ctx: IndexContext, x: UOp): return UOp(Ops.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ())) pm_lowerer = PatternMatcher([ - (Pat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis), - (Pat(Ops.VALID, src=(Pat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]), + (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis), + (UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]), # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed - (Pat((Ops.LOAD, Ops.STORE), src=(Pat(), Pat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store), + (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store), ]) def do_reduce(ctx:List[int], root:UOp): @@ -141,7 +141,7 @@ def do_reduce(ctx:List[int], root:UOp): just_reduce = PatternMatcher([ # do reduce - (Pat(Ops.REDUCE, name="root"), do_reduce), + (UPat(Ops.REDUCE, name="root"), do_reduce), ]) def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp: diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index ab45a02b4a..6990dc2454 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -3,7 +3,7 @@ from typing import Optional, Tuple, Dict, List, TYPE_CHECKING, Any, DefaultDict, import functools, itertools, operator from collections import defaultdict from tinygrad.dtype import dtypes, ImageDType, PtrDType -from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, Ops, Pat, PatternMatcher, symbolic_flat, symbolic_simple +from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES @@ -77,10 +77,10 @@ def fix_unfoldable_image_load(load:UOp, buf:UOp): vec_load = UOp(Ops.LOAD, load.dtype.vec(4), tuple(new_src)) return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), load.const_like(float('nan'))) -buf_idx_pat = Pat(Ops.INDEX, src=(Pat.var("buf"),), allow_any_len=True) +buf_idx_pat = UPat(Ops.INDEX, src=(UPat.var("buf"),), allow_any_len=True) float4_folding = PatternMatcher([ - (Pat(Ops.VECTORIZE, src=Pat(Ops.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded), - (Pat((Ops.BARRIER, Ops.SINK), src=Pat(Ops.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded), + (UPat(Ops.VECTORIZE, src=UPat(Ops.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded), + (UPat((Ops.BARRIER, Ops.SINK), src=UPat(Ops.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded), ]) # ***** image load valid simplification ***** @@ -124,24 +124,24 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]: powers_of_two = {2**i:i for i in range(64)} @functools.lru_cache(None) def get_late_rewrite_patterns(ops, force_transcendental=False): - pat: List[Tuple[Pat, Callable]] = [(Pat(Ops.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(Pat.var("d"),), arg=op), f) for op,f in \ + pat: List[Tuple[UPat, Callable]] = [(UPat(Ops.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=op), f) for op,f in \ ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if op not in ops or force_transcendental] # rewrite MOD to AND (which should always be supported, but not for generic in tests) if BinaryOps.AND in ops: - pat += [(Pat(Ops.ALU, arg=BinaryOps.MOD, src=(Pat.var('base'), Pat.cvar("const"))), + pat += [(UPat(Ops.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))), lambda base,const: base & (const.arg-1) if const.arg in powers_of_two else None)] # rewrite MUL/IDIV to SHL+SHR if BinaryOps.SHL in ops and BinaryOps.SHR in ops: pat += [ - (Pat(Ops.ALU, arg=BinaryOps.MUL, dtype=dtypes.ints, src=[Pat.cvar("const"), Pat.var("mul")]), lambda mul, const: + (UPat(Ops.ALU, arg=BinaryOps.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const: mul << powers_of_two[const.arg] if const.arg in powers_of_two else None), # (x * (2**y)) -> shl(x,y) - (Pat(Ops.ALU, arg=BinaryOps.IDIV, src=(Pat.var("div"), Pat.cvar("const"))), lambda div, const: + (UPat(Ops.ALU, arg=BinaryOps.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const: div >> powers_of_two[const.arg] if const.arg in powers_of_two else None)] # (x // (2**y)) -> shr(x,y) if UnaryOps.NEG in ops: - pat += [(Pat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))] - if BinaryOps.SUB in ops: pat += [(Pat.var('x')+Pat.var('y').alu(UnaryOps.NEG), lambda x,y: x.alu(BinaryOps.SUB, y))] + pat += [(UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))] + if BinaryOps.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(UnaryOps.NEG), lambda x,y: x.alu(BinaryOps.SUB, y))] if TernaryOps.MULACC in ops: - pat += [(Pat.var('a')*Pat.var('b')+Pat.var('c'), lambda a,b,c: a.alu(TernaryOps.MULACC, b, c))] + pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(TernaryOps.MULACC, b, c))] return PatternMatcher(pat) # ***** threefry ***** @@ -231,79 +231,79 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp): for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count) return ret -acc_pat, rng_pat = Pat(Ops.DEFINE_ACC, name="acc"), Pat(Ops.RANGE, name="rng") -rng_aug = Pat.any(rng_pat, Pat.var("add")+rng_pat, Pat.var("mul")*rng_pat, Pat.var("add")+Pat.var("mul")*rng_pat) +acc_pat, rng_pat = UPat(Ops.DEFINE_ACC, name="acc"), UPat(Ops.RANGE, name="rng") +rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UPat.var("add")+UPat.var("mul")*rng_pat) -index_load = Pat.var("buf").index(rng_aug).load(name="ld") +index_load = UPat.var("buf").index(rng_aug).load(name="ld") -arange_augrng = Pat.any(rng_aug, rng_aug+Pat.var("idx2"), rng_aug+Pat.var("idx2")+Pat.var("idx3"), Pat(Ops.VECTORIZE, name="vec", src=rng_aug)) -arange_m = arange_augrng.lt(Pat.cvar("compval")).ne(Pat(Ops.CONST, name="ne", arg=True)).where(Pat.cvar("multconst"), Pat.const(None, 0)) +arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug)) +arange_m = arange_augrng.lt(UPat.cvar("compval")).ne(UPat(Ops.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0)) # this is symbolic 2.0 sym = symbolic_flat+PatternMatcher([ # self ASSIGN is just self - (Pat(Ops.ASSIGN, src=(Pat.var('x'), Pat.var('x'))), lambda x: x), + (UPat(Ops.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x), # ASSIGN to global is just self - (Pat(Ops.ASSIGN, src=(Pat(Ops.DEFINE_GLOBAL), Pat.var("x"))), lambda x: x), + (UPat(Ops.ASSIGN, src=(UPat(Ops.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x), # VECTORIZE/CONST, VECTORIZE/GEP - (Pat(Ops.VECTORIZE, src=Pat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))), - (Pat(Ops.VECTORIZE, src=Pat(Ops.GEP, src=(Pat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))), + (UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))), + (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))), # reorder ALU/VECTORIZE - (Pat(Ops.ALU, src=(Pat(Ops.VECTORIZE, src=Pat(name='x')), Pat(Ops.VECTORIZE, src=Pat(name='y'))), name='alu'), + (UPat(Ops.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'), lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(Ops.ALU, alu.dtype.scalar(), (x,y), alu.arg),)*alu.dtype.count)), # VECTORIZE of a single element is just that element - (Pat(Ops.VECTORIZE, src=(Pat(name='x'),)), lambda x: x), + (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x), # VECTORIZE void is SINK - (Pat(Ops.VECTORIZE, dtype=dtypes.void, src=Pat(Ops.BARRIER, name='b')), lambda b: b), - (Pat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)), + (UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b), + (UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)), # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST - (Pat(Ops.GEP, src=(Pat(Ops.GEP, name='g2'),), name='g1'), + (UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'), lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))), - (Pat(Ops.GEP, src=(Pat(Ops.VECTORIZE, name="vec"),), name="gep"), + (UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"), lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]), - (Pat(Ops.GEP, src=(Pat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)), - (Pat(Ops.GEP, src=(Pat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))), + (UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)), + (UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))), # push all GEPs through ALUs (fix arange stuff) - (Pat(Ops.GEP, src=(Pat((Ops.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'), + (UPat(Ops.GEP, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'), lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg)), # push some GEPs through WMMAs - (Pat(Ops.GEP, src=(Pat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma), + (UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma), # tensor core with a 0 input is acc - (Pat(Ops.WMMA, src=(Pat.const(None, 0.0), Pat.var(), Pat.var("acc"))), lambda acc: acc), - (Pat(Ops.WMMA, src=(Pat.var(), Pat.const(None, 0.0), Pat.var("acc"))), lambda acc: acc), + (UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc), + (UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc), # tensor core cleanups - (Pat.var("add") + Pat(Ops.WMMA, name="wmma"), + (UPat.var("add") + UPat(Ops.WMMA, name="wmma"), lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)), # threefry - (Pat(Ops.ALU, dtype=dtypes.uint64, src=(Pat.var("x"), Pat.var("key")), arg=BinaryOps.THREEFRY), threefry2x32), + (UPat(Ops.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key")), arg=BinaryOps.THREEFRY), threefry2x32), # arange loop folding - (acc_pat.assign(Pat.any(arange_m, arange_m+Pat.var("extra"))+acc_pat), loop_collapse), + (acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse), # indexing, with cast or where - (acc_pat.assign(Pat.var("idx").eq(Pat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse), - (acc_pat.assign(Pat.var("idx").eq(Pat(Ops.RANGE, name="rng")).where(index_load, Pat.const(None, 0.0))+acc_pat), index_collapse), + (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse), + (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse), # parentless reduce - (acc_pat.assign(Pat(Ops.ALU, src=[acc_pat, Pat.var("ret")], arg=BinaryOps.ADD, name="alu")), reduce_collapse), - (acc_pat.assign(Pat(Ops.ALU, src=[acc_pat, Pat.var("ret")], arg=BinaryOps.MAX, name="alu")), reduce_collapse), + (acc_pat.assign(UPat(Ops.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.ADD, name="alu")), reduce_collapse), + (acc_pat.assign(UPat(Ops.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.MAX, name="alu")), reduce_collapse), # ** self folding ** - (Pat(Ops.DEFINE_ACC, src=(Pat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST - (Pat(Ops.ASSIGN, src=(Pat.cvar(),Pat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP + (UPat(Ops.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST + (UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP # x!=0 -> (bool)x - (Pat.var("x").ne(0), lambda x: x.cast(dtypes.bool.vec(x.dtype.count))), + (UPat.var("x").ne(0), lambda x: x.cast(dtypes.bool.vec(x.dtype.count))), # ** load/store folding ** - (Pat.store(Pat(Ops.INDEX, name="index"), Pat.load(Pat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)), - (Pat.store(Pat(Ops.INDEX, name="index"), Pat.var("gate").where(Pat.var("alt"), Pat.load(Pat(Ops.INDEX, name="index")))), + (UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)), + (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index")))), lambda index, gate, alt: UOp.store(index.src[0].index(index.src[1], gate), alt)), # fold gated LOAD/STORE - (Pat().index(Pat(), Pat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True - (Pat().index(Pat(), Pat.const(dtypes.bool, False)).named("idx"), lambda idx: idx.const_like(0)), # False -> NULL pointer - (Pat(Ops.LOAD, src=(Pat.const(None, 0),), allow_any_len=True, name="x"), lambda x: x.const_like(0)), # NULL pointer load loads 0 - (Pat(Ops.STORE, src=(Pat.const(None, 0),), allow_any_len=True), lambda: UOp(Ops.NOOP)), # NULL pointer store does nothing + (UPat().index(UPat(), UPat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True + (UPat().index(UPat(), UPat.const(dtypes.bool, False)).named("idx"), lambda idx: idx.const_like(0)), # False -> NULL pointer + (UPat(Ops.LOAD, src=(UPat.const(None, 0),), allow_any_len=True, name="x"), lambda x: x.const_like(0)), # NULL pointer load loads 0 + (UPat(Ops.STORE, src=(UPat.const(None, 0),), allow_any_len=True), lambda: UOp(Ops.NOOP)), # NULL pointer store does nothing # remove NOOPs from SINK - (Pat(Ops.SINK, name="root"), + (UPat(Ops.SINK, name="root"), lambda root: UOp(Ops.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not Ops.NOOP)) != len(root.src) else None), # remove EXPANDs from SINK/BARRIER - (Pat(Ops.BARRIER, src=(Pat((Ops.VECTORIZE, Ops.SINK), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)), - (Pat(Ops.SINK, name="root"), + (UPat(Ops.BARRIER, src=(UPat((Ops.VECTORIZE, Ops.SINK), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)), + (UPat(Ops.SINK, name="root"), lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.EXPAND} else (x,) for x in root.src)), root.arg) if any(x.op in {Ops.SINK, Ops.EXPAND} for x in root.src) else None), ]) @@ -400,21 +400,21 @@ def create_gate(root:UOp) -> Optional[UOp]: expander = PatternMatcher([ # double expand - (Pat(Ops.EXPAND, name="outer", src=(Pat(Ops.EXPAND, name="inner"),)), + (UPat(Ops.EXPAND, name="outer", src=(UPat(Ops.EXPAND, name="inner"),)), lambda outer, inner: UOp(Ops.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)), # do expansion - (Pat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN, + (UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN, Ops.VECTORIZE, Ops.REDUCE, Ops.IF), name="root", custom_early_reject=set([(Ops.EXPAND, None)])), do_expand), - (Pat(Ops.CONTRACT, name="con"), do_contract), + (UPat(Ops.CONTRACT, name="con"), do_contract), # vectorize DEFINE_ACC - (Pat(Ops.VECTORIZE, src=Pat(Ops.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)), + (UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)), # BARRIERs aren't actually expanded - (Pat(Ops.BARRIER, src=(Pat(Ops.EXPAND, name="ex"),)), + (UPat(Ops.BARRIER, src=(UPat(Ops.EXPAND, name="ex"),)), lambda ex: UOp(Ops.EXPAND, dtypes.void, (UOp(Ops.BARRIER, dtypes.void, ex.src),)*len(ex.src), ex.arg)), # empty EXPAND is NOOP - (Pat(Ops.EXPAND, src=(Pat.var('x'),), arg=()), lambda x: x), + (UPat(Ops.EXPAND, src=(UPat.var('x'),), arg=()), lambda x: x), # EXPAND GEP (needed for WMMA, generalize this) -> vectorized ALU - (Pat(Ops.EXPAND, name="ex", src=tuple(Pat.var('x').gep(i)+Pat.var('y').gep(i) for i in range(256 if AMX else 8))), + (UPat(Ops.EXPAND, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))), lambda ex,x,y: UOp(Ops.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)), ]) @@ -433,10 +433,10 @@ def no_vectorized_acc(acc:UOp): devectorize = PatternMatcher([ # no ALU on vectorized dtypes - (Pat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu), - (Pat(Ops.WMMA, name="wmma"), no_vectorized_wmma), - (Pat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc), - (Pat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store), + (UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu), + (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma), + (UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc), + (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store), ]) def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:Optional[UOp]=None) -> Optional[UOp]: @@ -446,14 +446,14 @@ def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:Optio load_store_indexing = PatternMatcher([ # late fixup of unfoldable image loads - (Pat(Ops.LOAD, src=(Pat.var("buf"), Pat()), allow_any_len=True, name="load"), fix_unfoldable_image_load), + (UPat(Ops.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load), # simplify valid - (Pat(Ops.ALU, name="valid", arg=BinaryOps.AND), simplify_valid), + (UPat(Ops.ALU, name="valid", arg=BinaryOps.AND), simplify_valid), # image load valid idx simplification - (Pat(Ops.INDEX, src=(Pat.var("buf"), Pat.var("start_idx"), Pat.var("valid"))), simplify_valid_load), + (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load), # delete_redundant_gates (after expand) - (Pat(Ops.STORE, src=(Pat.any(stidx:=Pat.var("buf").index(Pat.var("idx"), Pat.var("store_gate")), stidx.cast().named("cast")), - Pat.var("val"))), delete_redundant_gates), + (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")), + UPat.var("val"))), delete_redundant_gates), ]) def idx_load_store(x:UOp): @@ -466,9 +466,9 @@ def idx_load_store(x:UOp): migrate_indexing = PatternMatcher([ # use indexing for LOAD/STORE - (Pat((Ops.LOAD, Ops.STORE), src=(Pat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store), + (UPat((Ops.LOAD, Ops.STORE), src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store), # create gate MUST BE BEFORE expander - (Pat(Ops.STORE, name="root"), create_gate), + (UPat(Ops.STORE, name="root"), create_gate), ]) def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:Optional[UOp]=None) -> UOp: @@ -478,13 +478,13 @@ def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:Optional[UOp]=None) -> UOp pm_render = PatternMatcher([ # for rendering, we use explicit VECTORIZE - (Pat(Ops.CONST, name='c'), + (UPat(Ops.CONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None), - (Pat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))), - (Pat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None), - (Pat(Ops.VECTORIZE, src=(Pat(name='x'),)), lambda x: x), + (UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))), + (UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None), + (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x), # move masks of loads/stores - (Pat((Ops.LOAD, Ops.STORE), src=(Pat.any(masked_index:=Pat(Ops.INDEX, src=(Pat(name="buf"), Pat(name="idx"), Pat(name="mask"))), + (UPat((Ops.LOAD, Ops.STORE), src=(UPat.any(masked_index:=UPat(Ops.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))), masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask), ]) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 4496b410c7..4e8412df7c 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -2,7 +2,7 @@ import sys, atexit, functools, itertools from collections import defaultdict, deque from dataclasses import dataclass, field from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast -from tinygrad.ops import BUFOPS, MetaOps, ReduceOps, UnaryOps, UOp, Ops, PatternMatcher, Pat, Variable, graph_rewrite, track_rewrites, sint +from tinygrad.ops import BUFOPS, MetaOps, ReduceOps, UnaryOps, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint from tinygrad.helpers import DEBUG, Metadata, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap from tinygrad.dtype import ImageDType, dtypes from tinygrad.shape.shapetracker import ShapeTracker @@ -138,27 +138,27 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time" return first_reduce.src[0].r(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg) -merge_views = PatternMatcher([(Pat(Ops.VIEW, src=(Pat(Ops.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))]) +merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))]) # push VIEW to loads view_left = merge_views+PatternMatcher([ # view before ALU - (Pat(Ops.VIEW, src=(Pat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *BUFOPS), name="e"),), name="v"), + (UPat(Ops.VIEW, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *BUFOPS), name="e"),), name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))), ]) # push VIEW to stores view_right = merge_views+PatternMatcher([ # ASSIGN can override st - (Pat(Ops.STORE, src=(Pat.var("b"), Pat.var("st"), Pat(Ops.ASSIGN, name="a"))), + (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))), lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None), # VIEW on a reduce creates a new VIEW - (Pat(Ops.VIEW, src=(Pat(Ops.REDUCE_AXIS, src=Pat.var("rsrc"), name="r"),), name="view"), view_r), + (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r), # push a VIEW down to STORE, through a reduce (ONLY reshapes) - (Pat(Ops.REDUCE_AXIS, src=(Pat(Ops.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce), + (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce), # push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes) - (Pat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise), - (Pat(Ops.REDUCE_AXIS, src=(Pat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), + (UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise), + (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), ]) # ** ScheduleItem context builder @@ -181,26 +181,26 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]: def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp: ctx.bufs.append(x) return UOp(Ops.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1) -append_bufs = PatternMatcher([(Pat(Ops.BUFFER, name="x"), _append_buf)]) +append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)]) def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp: if b in ctx.assigned: ctx.assign_preloads.append(b) return x.replace(op=Ops.LOAD) to_si = PatternMatcher([ - (Pat(Ops.VIEW, name="x"), _append_st_vars), - (Pat(Ops.PRELOAD, src=(Pat.var("b"), Pat()), name="x"), _append_preload), - (Pat(Ops.CONTIGUOUS, src=(Pat.var("x"),)), lambda ctx,x: x), - (Pat(Ops.SINK, src=(Pat.store(Pat(), Pat(), Pat(tuple(METAOPS.values()), name="x")),)), lambda ctx,x: x), + (UPat(Ops.VIEW, name="x"), _append_st_vars), + (UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload), + (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x), + (UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(tuple(METAOPS.values()), name="x")),)), lambda ctx,x: x), ]) # ** fusion lazy = PatternMatcher([ - (Pat.load(b:=Pat.var("b"), Pat(), Pat.store(b, Pat(), Pat.var("v"))), lambda ctx,b,v: v), + (UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), UPat.var("v"))), lambda ctx,b,v: v), ]) -multioutput = PatternMatcher([(Pat.load(Pat.var("b"), Pat()), lambda ctx,b: ctx.get(b)),]) +multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),]) def full_ast_rewrite(pre:UOp, var_vals:Dict[Variable, int], assigned:Set[UOp]) -> Tuple[UOp, ScheduleItemContext]: # fuse and fold store -> loads @@ -234,14 +234,14 @@ def realize(ctx:Dict[UOp, UOp], b:UOp, load:UOp, store:UOp) -> UOp: ctx[b] = store return UOp(Ops.LOAD, load.dtype, (b, load.st_arg.to_uop())) -def PatLoadStore(to_store=Pat()): return Pat.load(b:=Pat.var("b"), Pat(), Pat.store(b, Pat(), to_store, name="store"), name="load") +def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), to_store, name="store"), name="load") do_realize = PatternMatcher([ # always realize meta ops - (PatLoadStore(Pat((Ops.ASSIGN, Ops.CONTIGUOUS, *METAOPS.values()))), realize), - (Pat((Ops.COPY, Ops.BUFFER_VIEW), src=(Pat.var("u"), Pat.any(PatLoadStore(), PatLoadStore().view(name="v"))), name="root"), + (UPatLoadStore(UPat((Ops.ASSIGN, Ops.CONTIGUOUS, *METAOPS.values()))), realize), + (UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.var("u"), UPat.any(UPatLoadStore(), UPatLoadStore().view(name="v"))), name="root"), lambda ctx,root,u,v=None,**kwargs: root.replace(src=(u, realize(ctx,**kwargs) if v is None else realize(ctx,**kwargs).view(v.st))),) ]) -break_sched = PatternMatcher([(PatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),]) +break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),]) @track_rewrites(named=True) def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]: diff --git a/tinygrad/ops.py b/tinygrad/ops.py index b204e2933b..97bb3bffd9 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -205,7 +205,7 @@ def smin(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else ls def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop def sym_infer(uop: Union[UOp, int], var_vals: Dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop -# used for UOp and Pat +# used for UOp and UPat def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str: def dfs(x:Any, cache:dict): for s in srcfn(x) or []: @@ -514,7 +514,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]: def get_location() -> Tuple[str, int]: frm = sys._getframe(1) - # find the real frame in the file that has the Pat, TODO: is there a better way to do this? + # find the real frame in the file that has the UPat, TODO: is there a better way to do this? while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py", "cstyle.py"}: frm = frm.f_back @@ -523,75 +523,75 @@ def get_location() -> Tuple[str, int]: def lines(fn) -> List[str]: with open(fn) as f: return f.readlines() -class Pat(MathTrait): +class UPat(MathTrait): __slots__ = ["op", "dtype", "arg", "name", "src"] def __init__(self, op:Optional[Union[Ops, Tuple[Ops, ...]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None, - src:Optional[Union[Tuple[Pat, ...], List[Pat], Pat]]=None, arg:Any=None, + src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None, name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[Set[Tuple[Ops, Any]]]=None): self.op: Optional[Tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else op self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject self.src: Any = None - assert self.name != "ctx", "Pat can't be named ctx" + assert self.name != "ctx", "UPat can't be named ctx" # try all permutations if it's a list if isinstance(src, list): self.src = list(itertools.permutations(src)) if not all_same(src) else [src] # only one if it's a tuple elif isinstance(src, tuple): self.src = [src] - # repeat if it's a Pat - elif isinstance(src, Pat): self.src = [itertools.repeat(src)] + # repeat if it's a UPat + elif isinstance(src, UPat): self.src = [itertools.repeat(src)] - self.allowed_len: int = -1 if allow_any_len or isinstance(src, Pat) or src is None else len(src) + self.allowed_len: int = -1 if allow_any_len or isinstance(src, UPat) or src is None else len(src) self.location = location or get_location() if custom_early_reject is not None: self.early_reject = custom_early_reject else: - upat_match = [src] if isinstance(src, Pat) else ([] if src is None else self.src[0]) + upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0]) self.early_reject = set((pp.op[0], pp.arg) for pp in upat_match if pp.op is not None and len(pp.op) == 1) - def named(self, name:str): return Pat(self.op, self.dtype, self._in_src, self.arg, name, self.allowed_len == -1, self.custom_early_reject) + def named(self, name:str): return UPat(self.op, self.dtype, self._in_src, self.arg, name, self.allowed_len == -1, self.custom_early_reject) @staticmethod - def any(*src): return PatAny(src=src) + def any(*src): return UPatAny(src=src) @staticmethod @functools.lru_cache(None) - def var(name:Optional[str]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None): return Pat(dtype=dtype, name=name) + def var(name:Optional[str]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name) @staticmethod @functools.lru_cache(None) def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True): - return Pat((Ops.CONST, Ops.VCONST) if vec else Ops.CONST, dtype=dtype, name=name) + return UPat((Ops.CONST, Ops.VCONST) if vec else Ops.CONST, dtype=dtype, name=name) @staticmethod - def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return Pat(Ops.CONST, dtype=dtype, arg=b) + def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b) # copied from UOp - def index(self, idx:Pat, valid:Optional[Pat]=None): return Pat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx)) - def view(self, st=None, **kwargs): return Pat(Ops.VIEW, self.dtype, (self,), st, **kwargs) - def cast(self, dtype=None): return Pat(Ops.CAST, dtype, (self,)) - def bitcast(self, dtype=None): return Pat(Ops.BITCAST, dtype, (self,)) - def gep(self, i:int): return Pat(Ops.GEP, None, (self,), (i,)) - def load(self, *src:Pat, **kwargs): return Pat(Ops.LOAD, src=(self,)+src, **kwargs) - def store(self, *src:Pat, **kwargs): return Pat(Ops.STORE, dtypes.void, (self,)+src, **kwargs) - def assign(self, x:Pat): return Pat(Ops.ASSIGN, self.dtype, (self,x)) + def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx)) + def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs) + def cast(self, dtype=None): return UPat(Ops.CAST, dtype, (self,)) + def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,)) + def gep(self, i:int): return UPat(Ops.GEP, None, (self,), (i,)) + def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs) + def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, dtypes.void, (self,)+src, **kwargs) + def assign(self, x:UPat): return UPat(Ops.ASSIGN, self.dtype, (self,x)) - def const_like(self, b:ConstLike): return Pat.const(self.dtype, cast(ConstType, b)) - def alu(self, arg, *src:Pat): + def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b)) + def alu(self, arg, *src:UPat): asrc = (self,)+src - return Pat(Ops.ALU, None if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else asrc[-1].dtype, list(asrc) if arg in COMMUTATIVE else asrc, arg) + return UPat(Ops.ALU, None if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else asrc[-1].dtype, list(asrc) if arg in COMMUTATIVE else asrc, arg) - def printable(self:Pat) -> str: + def printable(self:UPat) -> str: try: return lines(self.location[0])[self.location[1]-1].strip() except FileNotFoundError: return "" def __repr__(self): def rep(x): - form = "Pat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)" + form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)" return form % (None if x.op is None else ('(%s)'%', '.join(map(str, x.op))), x.arg, repr(x.name), set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)") return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0]) - def match(self:Pat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]: + def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]: if (self.name is not None and store.setdefault(self.name, uop) is not uop) or \ (self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \ (self.arg is not None and self.arg != uop.arg) or \ @@ -607,8 +607,8 @@ class Pat(MathTrait): res.extend(stores) return res -class PatAny(Pat): - def match(self:Pat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]: +class UPatAny(UPat): + def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]: ret = [] for x in self.src[0]: if (match:=x.match(uop, store.copy())): ret.extend(match) @@ -624,10 +624,10 @@ def deconstruct_function(fxn:Callable) -> Tuple: return pickle.loads(pickle.dumps(ret)) if getenv("TEST_PICKLE") else ret class PatternMatcher: - def __init__(self, patterns:List[Tuple[Pat, Callable]]): + def __init__(self, patterns:List[Tuple[UPat, Callable]]): self.patterns = patterns # NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher! - self.pdict: Dict[Tuple[Ops, Any], List[Tuple[Pat, Callable, Set, bool]]] = {} + self.pdict: Dict[Tuple[Ops, Any], List[Tuple[UPat, Callable, Set, bool]]] = {} # uop is required, arg is optional for p,fxn in self.patterns: assert p.op is not None @@ -652,12 +652,12 @@ class PatternMatcher: # *** tracking pattern matcher *** TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0) -match_stats:Dict[Pat, List[Union[int, float]]] = dict() +match_stats:Dict[UPat, List[Union[int, float]]] = dict() @dataclass(frozen=True) class TrackedRewriteContext: loc: Tuple[str, int] # location that called graph_rewrite sink: UOp # the sink passed into the rewrite - matches: List[Tuple[UOp, Optional[UOp], Optional[Pat], float]] = field(default_factory=list) # all matches of sparents + matches: List[Tuple[UOp, Optional[UOp], Optional[UPat], float]] = field(default_factory=list) # all matches of sparents rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = [] contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = [] @@ -676,7 +676,7 @@ def track_rewrites(named=False): return _decorator class TrackedPatternMatcher(PatternMatcher): - def __init__(self, patterns:List[Tuple[Pat, Callable]]): + def __init__(self, patterns:List[Tuple[UPat, Callable]]): super().__init__(patterns) for p,_ in self.patterns: if p not in match_stats: match_stats[p] = [0,0,0.0,0.0] @@ -745,84 +745,84 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp: # this is the matcher for the final rendered UOps # matcher functions returns True or False (or None to not match) spec = PatternMatcher([ - (Pat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local), - (Pat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local), - (Pat(Ops.DEFINE_ACC, src=(Pat.var("c"),), name="x", allow_any_len=True), + (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local), + (UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local), + (UPat(Ops.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True), lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype), - (Pat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)), + (UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)), - (Pat(Ops.RANGE, src=(Pat(name="x"), Pat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype), - (Pat(Ops.SPECIAL, src=()), lambda: True), + (UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype), + (UPat(Ops.SPECIAL, src=()), lambda: True), # TODO: confirm the args of both of these are shapetrackers - (Pat(Ops.VIEW, src=()), lambda: True), - (Pat(Ops.VIEW, src=(Pat(),)), lambda: True), + (UPat(Ops.VIEW, src=()), lambda: True), + (UPat(Ops.VIEW, src=(UPat(),)), lambda: True), - (Pat(Ops.VALID, dtypes.bool, (Pat(Ops.VIEW),)), lambda: True), - (Pat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))), + (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True), + (UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))), # early LOAD has a - (Pat(Ops.LOAD, src=(Pat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), Pat(Ops.VIEW))), lambda: True), - (Pat(Ops.LOAD, src=(Pat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), Pat(Ops.VIEW), Pat(Ops.STORE))), lambda: True), + (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True), + (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True), # early STORE has a - (Pat(Ops.STORE, src=(Pat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), Pat(Ops.VIEW), Pat())), lambda: True), + (UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True), # **** new style load/store **** # INDEX is used in new style load/store - (Pat(Ops.INDEX, src=(Pat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), Pat())), lambda: True), + (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True), # LOAD takes a - (Pat(Ops.LOAD, src=(Pat((Ops.INDEX, Ops.CAST)),)), lambda: True), - (Pat(Ops.LOAD, src=(Pat((Ops.INDEX, Ops.CAST)), Pat((Ops.IF, Ops.BARRIER)))), lambda: True), - (Pat(Ops.LOAD, src=(Pat((Ops.INDEX, Ops.CAST)), Pat(name="alt"), Pat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype), + (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True), + (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True), + (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype), # STORE takes a - (Pat(Ops.STORE, dtype=dtypes.void, src=(Pat((Ops.INDEX, Ops.CAST)), Pat())), lambda: True), - (Pat(Ops.STORE, dtype=dtypes.void, src=(Pat((Ops.INDEX, Ops.CAST)), Pat(), Pat(dtype=dtypes.bool))), lambda: True), - (Pat(Ops.STORE, dtype=dtypes.void, src=(Pat((Ops.INDEX, Ops.CAST)), Pat(), Pat(Ops.IF))), lambda: True), + (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True), + (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True), + (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True), # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE - (Pat(Ops.ALU, name="w", src=(Pat(dtype=dtypes.bool), Pat(name="x"), Pat(name="y")), arg=TernaryOps.WHERE), + (UPat(Ops.ALU, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y")), arg=TernaryOps.WHERE), lambda w,x,y: w.dtype == x.dtype == y.dtype), - (Pat(Ops.ALU, dtype=dtypes.bool, src=(Pat(name="x"), Pat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype), - (Pat(Ops.ALU, dtype=dtypes.bool, src=(Pat(name="x"), Pat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype), + (UPat(Ops.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype), + (UPat(Ops.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype), # and SHL/SHR, the shift distance is an int - (Pat(Ops.ALU, src=(Pat(name="x"), Pat(name="y")), name="alu", arg=BinaryOps.SHL), + (UPat(Ops.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHL), lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)), - (Pat(Ops.ALU, src=(Pat(name="x"), Pat(name="y")), name="alu", arg=BinaryOps.SHR), + (UPat(Ops.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHR), lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)), - (Pat(Ops.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), - (Pat(Ops.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)), + (UPat(Ops.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), + (UPat(Ops.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)), - (Pat(Ops.ASSIGN, src=(Pat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), Pat())), lambda: True), - (Pat(Ops.ENDRANGE, dtype=dtypes.void, src=(Pat(Ops.RANGE),)), lambda: True), + (UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True), + (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True), # all WMMA has 3 args, - (Pat(Ops.WMMA, src=(Pat(), Pat(), Pat())), lambda: True), - (Pat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)), - (Pat(Ops.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)), + (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat())), lambda: True), + (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)), + (UPat(Ops.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)), # if has a - (Pat(Ops.IF, dtype=dtypes.void, src=(Pat(),)), lambda: True), - (Pat(Ops.IF, dtype=dtypes.void, src=(Pat(), Pat(Ops.BARRIER))), lambda: True), - (Pat(Ops.ENDIF, dtype=dtypes.void, src=(Pat(Ops.IF),)), lambda: True), + (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True), + (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True), + (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True), - (Pat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in REDUCE_ALU.values()), - (Pat(Ops.GEP, src=(Pat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()), - (Pat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)), - (Pat((Ops.BITCAST, Ops.CAST), src=(Pat(),), name="x"), lambda x: x.arg is None), - (Pat(Ops.BARRIER, dtypes.void, src=Pat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local + (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in REDUCE_ALU.values()), + (UPat(Ops.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()), + (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)), + (UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None), + (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local # NOTE: for testing, we let sinks be anything - #(Pat(UOps.SINK, src=Pat(UOps.STORE)), lambda: True), - (Pat(Ops.SINK, dtypes.void), lambda: True), - (Pat(Ops.NOOP), lambda: True), + #(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True), + (UPat(Ops.SINK, dtypes.void), lambda: True), + (UPat(Ops.NOOP), lambda: True), # PTX LOAD/STORE - (Pat((Ops.LOAD, Ops.STORE), src=(Pat(dtype=dtypes.int64),), allow_any_len=True), lambda: True), - (Pat(Ops.BARRIER, dtypes.void, src=Pat(Ops.STORE, src=(Pat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True), + (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True), + (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True), ]) def type_verify(uops:List[UOp]): @@ -1013,125 +1013,125 @@ def max_var_const(x:UOp, c1:UOp, c2:UOp): symbolic_simple = PatternMatcher([ # ** self folding ** - (Pat.var("x") + 0, lambda x: x), # x+0 -> x - (Pat.var("x") * 1, lambda x: x), # x*1 -> x - (Pat.var("x") // Pat.var("x"), lambda x: x.const_like(1)), # x//x -> 1 - (Pat.var("x") // 1, lambda x: x), # x//1 -> x - (Pat.var("x") // -1, lambda x: -x), # x//-1 -> -x - (Pat.var("x") / Pat.var("x"), lambda x: x.const_like(1)), # x/x -> 1 - ((Pat.var("x") * Pat.var("x2")) / Pat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x - ((Pat.var() % Pat.var("y")).named("base") % Pat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed) - (Pat.var("x")%Pat.cvar("c")+(Pat.var("x")//Pat.cvar("c"))*Pat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x - (Pat.var("x", dtype=dtypes.bool) & Pat.cvar("c", vec=False), lambda x,c: x if c.arg else c), - (Pat.var("x", dtype=dtypes.bool) | Pat.cvar("c", vec=False), lambda x,c: c if c.arg else x), - (Pat.var("x").maximum(Pat.var("x")), lambda x: x), - ((Pat.var("x") & Pat.var("x")), lambda x: x), - ((Pat.var("x") | Pat.var("x")), lambda x: x), - (Pat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x), + (UPat.var("x") + 0, lambda x: x), # x+0 -> x + (UPat.var("x") * 1, lambda x: x), # x*1 -> x + (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1 + (UPat.var("x") // 1, lambda x: x), # x//1 -> x + (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x + (UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1 + ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x + ((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed) + (UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x + (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c), + (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x), + (UPat.var("x").maximum(UPat.var("x")), lambda x: x), + ((UPat.var("x") & UPat.var("x")), lambda x: x), + ((UPat.var("x") | UPat.var("x")), lambda x: x), + (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x), # ** zero folding ** - (Pat.var("x") < Pat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False - (Pat.var("x", dtype=dtypes.ints) != Pat.var("x", dtype=dtypes.ints), + (UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False + (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints) # x*0 -> 0 or 0*x -> 0 # if x is nan or inf it should render the nan value. # NOTE: this can be wrong for loaded NaN - (Pat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)), + (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)), # ** constant folding ** - (Pat(Ops.ALU, name="root", src=Pat((Ops.VCONST, Ops.CONST))), + (UPat(Ops.ALU, name="root", src=UPat((Ops.VCONST, Ops.CONST))), lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], truncate_output=False))), # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly - (Pat.var('x', dtype=dtypes.bool) * Pat.var('y', dtype=dtypes.bool), lambda x,y: x&y), - (Pat.var('x', dtype=dtypes.bool) + Pat.var('y', dtype=dtypes.bool), lambda x,y: x|y), - (Pat.var('x', dtype=dtypes.bool).maximum(Pat.var('y', dtype=dtypes.bool)), lambda x,y: x|y), + (UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y), + (UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y), + (UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y), # *** cast *** - (Pat(Ops.CAST, name="root", src=Pat.cvar("c")), lambda root, c: root.const_like(c.arg)), - (Pat(Ops.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None), + (UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)), + (UPat(Ops.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None), ]) symbolic = symbolic_simple+PatternMatcher([ # ** COMMUTATIVE flipping ** - *[(Pat(Ops.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) for op in COMMUTATIVE], + *[(UPat(Ops.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) for op in COMMUTATIVE], # group like - ((Pat.var("x") + Pat.var("y")) + Pat.var("x") * Pat.cvar("c"), lambda x,y,c: (x+x*c)+y), + ((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y), # ** combine terms ** - (Pat.var("x") * Pat.cvar("c0") + Pat.var("x") * Pat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1) - (Pat.var("x") + Pat.var("x") * Pat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1) - (Pat.var("x") + Pat.var("x"), lambda x: x*2), # (x+x)-> x*2 - ((Pat.var("x") / Pat.var("x2")) / Pat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3) - (-1 * (Pat.var("x") + Pat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c + (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1) + (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1) + (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2 + ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3) + (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c # a conditional with the same results either way is a noop, also fold const conditionals - (Pat.var().where(Pat.var("val"), Pat.var("val")), lambda val: val), - (Pat.cvar("gate", vec=False).where(Pat.var("c0"), Pat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1), + (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val), + (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1), # ALU min==max -> CONST (slow!) - (Pat(Ops.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), + (UPat(Ops.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), # max folding - (Pat.maximum(Pat.var("x"), Pat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None), + (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None), # TODO: why does this rule break beautiful_mnist? - #((Pat.var("x")+Pat.var("z")).maximum(Pat.var("y")+Pat.var("z")), lambda x,y,z: x.maximum(y) + z), - ((Pat.var("x")*Pat.cvar("c1")).maximum(Pat.var("x")*Pat.cvar("c2")), max_var_const), + #((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z), + ((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const), # ** two stage ALU folding ** - ((Pat.var("x") + Pat.cvar("c1")) + Pat.cvar("c2"), lambda x,c1,c2: x+(c1+c2)), - ((Pat.var("x") * Pat.cvar("c1")) * Pat.cvar("c2"), lambda x,c1,c2: x*(c1*c2)), - ((Pat.var("x") & Pat.cvar("c1")) & Pat.cvar("c2"), lambda x,c1,c2: x&(c1&c2)), - ((Pat.var("x") | Pat.cvar("c1")) | Pat.cvar("c2"), lambda x,c1,c2: x|(c1|c2)), - ((Pat.cvar("c0") + Pat.var("x")) < Pat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)), # c0 + x < c1 -> x < c1 - c0 - ((Pat.var("x") // Pat.cvar("c1")) // Pat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2) + ((UPat.var("x") + UPat.cvar("c1")) + UPat.cvar("c2"), lambda x,c1,c2: x+(c1+c2)), + ((UPat.var("x") * UPat.cvar("c1")) * UPat.cvar("c2"), lambda x,c1,c2: x*(c1*c2)), + ((UPat.var("x") & UPat.cvar("c1")) & UPat.cvar("c2"), lambda x,c1,c2: x&(c1&c2)), + ((UPat.var("x") | UPat.cvar("c1")) | UPat.cvar("c2"), lambda x,c1,c2: x|(c1|c2)), + ((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)), # c0 + x < c1 -> x < c1 - c0 + ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2) # ** lt ** # c0*x 0 and c1.arg > 0 else None), # c0*x 0 else None), # mul add lt - (((Pat.cvar("c0", vec=False)*Pat.var("x"))+Pat.var("x2")).lt(Pat.cvar("c1", vec=False)), + (((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)), lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None), # ** move add/mul consts to end (NOTE: this is still happening before constant folding) ** - (Pat(Ops.ALU, arg=BinaryOps.ADD, src=(Pat.var("x"), Pat.cvar("c1"))) + Pat.var("y"), lambda x,c1,y: (x+y)+c1), - (Pat(Ops.ALU, arg=BinaryOps.MUL, src=(Pat.var("x"), Pat.cvar("c1"))) * Pat.var("y"), lambda x,c1,y: (x*y)*c1), + (UPat(Ops.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1), + (UPat(Ops.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1), # *** rules from symbolic *** # unrolled arange div folding - (Pat(Ops.ALU, name="divs", src=[Pat(), Pat(Ops.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs), + (UPat(Ops.ALU, name="divs", src=[UPat(), UPat(Ops.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs), # generic lt folding - (Pat.var("x", dtypes.sints).lt(Pat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None), + (UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None), # canonicalize a simplex with positive coefficients > 0 # not x < 1 -> X > 0 - (Pat.var("x", dtypes.ints).lt(1).ne(True), lambda x: newx.lt(1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None), + (UPat.var("x", dtypes.ints).lt(1).ne(True), lambda x: newx.lt(1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None), # ** div ** # # div folding - (Pat.var("x", dtypes.sints) // Pat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=div_folding(x,c.arg)) is not None else None), + (UPat.var("x", dtypes.sints) // UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=div_folding(x,c.arg)) is not None else None), # ** mod ** # mod folding - (Pat.var("x") % Pat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None), + (UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None), ]) symbolic_flat = symbolic+PatternMatcher([ # ** combine terms (opinionated) ** - (-1 * (Pat.var("x") + Pat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y + (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue - ((Pat.var("x", dtypes.ints) + Pat.var("y")) * Pat.cvar("c"), lambda x,y,c: x*c+y*c), + ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c), ]) -_substitute = PatternMatcher([(Pat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))]) +_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))]) # for debug syms = { BinaryOps.ADD: "+", BinaryOps.SUB: "-", BinaryOps.IDIV: "//", BinaryOps.MOD: "%", BinaryOps.SHL: "<<", BinaryOps.SHR: ">>", BinaryOps.MUL: "*", BinaryOps.CMPLT: "<", BinaryOps.CMPNE: "!=", BinaryOps.AND: "&", BinaryOps.OR: "|", BinaryOps.XOR: "^"} renderer = PatternMatcher([ - (Pat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])), - (Pat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")), - (Pat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))), - (Pat(Ops.BIND, src=Pat(Ops.NOOP), name="x"), lambda x: x.src[0]), - (Pat(Ops.ALU, src=Pat(Ops.NOOP), name="x", arg=UnaryOps.NEG), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")), - (Pat(Ops.ALU, src=Pat(Ops.NOOP), name="x", arg=BinaryOps.MAX), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")), - (Pat(Ops.ALU, src=Pat(Ops.NOOP), name="x", arg=TernaryOps.MULACC), + (UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])), + (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")), + (UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))), + (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]), + (UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=UnaryOps.NEG), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")), + (UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=BinaryOps.MAX), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")), + (UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=TernaryOps.MULACC), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")), - (Pat(Ops.ALU, src=Pat(Ops.NOOP), name="x", arg=TernaryOps.WHERE), + (UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=TernaryOps.WHERE), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")), - (Pat(Ops.ALU, src=Pat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")), + (UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")), ]) # *** what was symbolic.py *** diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 4dc28c3b87..95ea4dddc9 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -2,66 +2,66 @@ from __future__ import annotations from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast import os, math from collections import defaultdict, Counter -from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, Pat, cast_float_to_bf16 +from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16 from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType from tinygrad.renderer import Renderer, TensorCore base_rewrite = PatternMatcher([ - (Pat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]), - (Pat(Ops.ASSIGN, name="x"), lambda ctx,x: f"{ctx[x.src[0]]} = {ctx[x.src[1]]};"), - (Pat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"), - (Pat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"), - (Pat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"), + (UPat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]), + (UPat(Ops.ASSIGN, name="x"), lambda ctx,x: f"{ctx[x.src[0]]} = {ctx[x.src[1]]};"), + (UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"), + (UPat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"), + (UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"), # r method accesses - (Pat(Ops.RANGE, name="x"), + (UPat(Ops.RANGE, name="x"), lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"), - (Pat(Ops.VECTORIZE, name="x"), + (UPat(Ops.VECTORIZE, name="x"), lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \ (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device == "CLANG" else f"({','.join([ctx[y] for y in x.src])})")), - (Pat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_dtype(x.dtype)})({ctx[x.src[0]]})"), - (Pat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"), - (Pat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.arg[1]}];"), - (Pat(Ops.BARRIER), lambda ctx: ctx.barrier), - (Pat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]), - (Pat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"), + (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_dtype(x.dtype)})({ctx[x.src[0]]})"), + (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"), + (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.arg[1]}];"), + (UPat(Ops.BARRIER), lambda ctx: ctx.barrier), + (UPat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]), + (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"), # const - (Pat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)}){ctx.infinity})"), - (Pat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)})-{ctx.infinity})"), - (Pat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){ctx.nan})" if math.isnan(x.arg) else None), - (Pat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"), - (Pat(Ops.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"), - (Pat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"), - (Pat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"), - (Pat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"), + (UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)}){ctx.infinity})"), + (UPat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)})-{ctx.infinity})"), + (UPat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){ctx.nan})" if math.isnan(x.arg) else None), + (UPat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"), + (UPat(Ops.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"), + (UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"), + (UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"), + (UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"), # consts are rendered to larger type and casted - (Pat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}f)"), - (Pat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}u)"), - (Pat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg})"), + (UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}f)"), + (UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}u)"), + (UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg})"), # default const render - (Pat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)), + (UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)), # new load/store - (Pat(Ops.INDEX, src=(Pat.var("buf"), Pat.var('idx'))), + (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))), lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == BinaryOps.ADD else ctx[idx]})"), - (Pat(Ops.LOAD, src=(Pat.var('bidx'), Pat.var("var"), Pat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"), - (Pat(Ops.LOAD, src=(Pat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"), - (Pat(Ops.STORE, src=(Pat.var('bidx'), Pat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"), + (UPat(Ops.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"), + (UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"), + (UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"), # alu/gep - (Pat(Ops.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.arg]( + (UPat(Ops.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.arg]( *([strip_parens(ctx[v]) if v.arg == x.arg and x.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else ctx[v] for v in x.src]), x.dtype)), - (Pat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \ + (UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \ (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")), ]) extra_pm = PatternMatcher([ # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends? - (Pat(Ops.BITCAST, name="x"), + (UPat(Ops.BITCAST, name="x"), lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None), # gate any stores that aren't gated with ifs - (Pat(Ops.STORE, dtype=dtypes.void, src=(Pat(), Pat(), Pat(dtype=dtypes.bool)), name="store"), + (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"), lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))), # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends) - (Pat(Ops.ALU, name="m", arg=BinaryOps.MAX), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])), + (UPat(Ops.ALU, name="m", arg=BinaryOps.MAX), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])), ]) def uops_to_dtypes(uops:List[UOp]) -> List[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType))) @@ -214,13 +214,13 @@ class OpenCLRenderer(CStyleLanguage): type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" } string_rewrite = PatternMatcher([ - (Pat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"), + (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"), # load/store image (OpenCL) - (Pat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(Pat.var('buf').index(Pat.var('idx', dtypes.int.vec(2))), Pat.var("var"), Pat.var("gate"))), + (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))), lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"), - (Pat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(Pat.var('buf').index(Pat.var('idx', dtypes.int.vec(2))),)), + (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)), lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"), - (Pat(Ops.STORE, src=(Pat.var('buf').index(Pat.var('idx', dtypes.int.vec(2))), Pat.var("var", dtypes.float.vec(4))), allow_any_len=True), + (UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True), lambda ctx,buf,idx,var: f"write_imagef({ctx[buf]}, {ctx[idx]}, {ctx[var]});"), ]) + base_rewrite @@ -234,8 +234,8 @@ class IntelRenderer(OpenCLRenderer): st1_pattern=(((1,0),),((1,2),(1,1),(0,0))),expanded_shape=(8,2,8)) for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]] string_rewrite = PatternMatcher([ - (Pat(Ops.CAST, dtype=dtypes.bfloat16, src=(Pat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x[0]]})"), - (Pat(Ops.CAST, dtype=dtypes.float, src=(Pat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x[0]]})"), + (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x[0]]})"), + (UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x[0]]})"), ]) + OpenCLRenderer.string_rewrite def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str: @@ -272,13 +272,13 @@ class MetalRenderer(CStyleLanguage): # upcast to float32 all the ops that don't support bfloat16 extra_matcher = PatternMatcher([ # NOTE: this is copied from PTX - *[(Pat(Ops.ALU, arg=op, dtype=dtypes.bfloat16, name="x"), + *[(UPat(Ops.ALU, arg=op, dtype=dtypes.bfloat16, name="x"), lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))) for op in [UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN]] ]) + extra_pm string_rewrite = PatternMatcher([ - (Pat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>({ctx[x.src[0]]})"), + (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>({ctx[x.src[0]]})"), ]) + base_rewrite def render_kernel(self, function_name, kernel, bufs, uops, prefix=None): @@ -387,20 +387,20 @@ class AMDRenderer(CStyleLanguage): type_map = {dtypes.bfloat16: "hip_bfloat16"} extra_matcher = PatternMatcher([ # cast bfloat16 alus to float - (Pat(Ops.ALU, arg=TernaryOps.WHERE, src=(Pat.var("b"), Pat.var("x", dtype=dtypes.bfloat16), Pat.var("y", dtype=dtypes.bfloat16))), + (UPat(Ops.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))), lambda b,x,y: UOp(Ops.ALU, arg=TernaryOps.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)), - (Pat(Ops.ALU, dtype=dtypes.bfloat16, name="x"), + (UPat(Ops.ALU, dtype=dtypes.bfloat16, name="x"), lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)), - (Pat(Ops.ALU, dtypes.bool, name="alu", src=(Pat.var("x", dtype=dtypes.bfloat16), Pat.var("y", dtype=dtypes.bfloat16))), + (UPat(Ops.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))), lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)), # add float intermediate casting for bfloat16 - (Pat(Ops.CAST, name="x", src=Pat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None), - (Pat(Ops.CAST, dtypes.bfloat16, Pat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None), + (UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None), + (UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None), # bfloat16 casting - (Pat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))), - (Pat(Ops.CAST, dtype=dtypes.float, src=Pat.var("x", dtype=dtypes.bfloat16)), + (UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))), + (UPat(Ops.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)), lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)), - (Pat(Ops.CAST, dtype=dtypes.bfloat16, src=Pat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm + (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm def render_vector_prefix(self, dtype:DType) -> str: vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar()) diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 21e8bb2d83..78470e93dd 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -1,7 +1,7 @@ from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable import struct from collections import defaultdict -from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, Ops, UOp, PatternMatcher, Pat +from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, Ops, UOp, PatternMatcher, UPat from tinygrad.dtype import dtypes, DType, PtrDType, ConstType from tinygrad.renderer import Renderer from tinygrad.renderer.cstyle import CUDARenderer @@ -35,23 +35,23 @@ asm_for_op: Dict[Op, Callable] = { supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE] ptx_matcher = PatternMatcher([ # bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only) - (Pat.var('x', dtype=dtypes.bool).ne(Pat.var('y')), lambda x,y: x^y), - (Pat.var('x', dtype=dtypes.bool).lt(Pat.var('y')), lambda x,y: (x^True)&y), + (UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y), + (UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y), # upcast to float32 all the ops that don't support half - *[(Pat(Ops.ALU, arg=op, dtype=dtypes.half, name="x"), + *[(UPat(Ops.ALU, arg=op, dtype=dtypes.half, name="x"), lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half))) for op in asm_for_op.keys() if op not in supports_half], # load/store bool -> uint8 - (Pat(Ops.LOAD, dtypes.bool, src=(Pat(dtype=dtypes.int64),), name="x", allow_any_len=True), + (UPat(Ops.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True), lambda x: UOp(x.op, dtypes.uint8, x.src[0:1] + ((x.src[1].cast(dtypes.uint8),) if len(x.src) >= 2 else ()) + x.src[2:]).cast(dtypes.bool)), - (Pat(Ops.STORE, src=(Pat(dtype=dtypes.int64), Pat(dtype=dtypes.bool)), name="x", allow_any_len=True), + (UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True), lambda x: UOp(x.op, dtypes.void, x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])), # load/store use pointer arithmetic, and the cast does nothing - (Pat(Ops.INDEX, src=(Pat.var("buf"), Pat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize), - (Pat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None), + (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize), + (UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None), # ptx shr and shl instructions require y to be uint - (Pat.var("x") << Pat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHL) if y.dtype != dtypes.uint else None), - (Pat.var("x") >> Pat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHR) if y.dtype != dtypes.uint else None), + (UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHL) if y.dtype != dtypes.uint else None), + (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHR) if y.dtype != dtypes.uint else None), ]) class PTXRenderer(Renderer): diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 69eb725ca5..867ce3f902 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -25,7 +25,7 @@ class GraphRewriteMetadata: kernel_name: Optional[str] """The kernel calling graph_rewrite""" upats: List[Tuple[Tuple[str, int], str, float]] - """List of all the applied Pats""" + """List of all the applied UPats""" @dataclass class GraphRewriteDetails(GraphRewriteMetadata):