From eaa1e0eeebf85469b331f337f5f0425f38087b64 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 27 Sep 2024 14:54:54 +0800 Subject: [PATCH] rename constant_folder to sym [run_process_replay] (#6780) --- test/test_uop_graph.py | 20 ++++++++++---------- test/test_uops.py | 4 ++-- test/unit/test_shapetracker.py | 4 ++-- tinygrad/codegen/uopgraph.py | 20 ++++++++++---------- tinygrad/renderer/assembly.py | 4 ++-- tinygrad/shape/shapetracker.py | 10 +++++----- viz/test_viz.py | 4 ++-- 7 files changed, 33 insertions(+), 33 deletions(-) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 51570b21a8..0281f724fd 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -7,7 +7,7 @@ from tinygrad.helpers import DEBUG from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo from tinygrad.ops import UPat, PatternMatcher from tinygrad.codegen.lowerer import ast_to_uop -from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, graph_rewrite, expander, reducer, constant_folder, float4_folding +from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, graph_rewrite, expander, reducer, sym, float4_folding from tinygrad.shape.shapetracker import ShapeTracker, View simple_pm = PatternMatcher([ @@ -71,21 +71,21 @@ class TestGraphRewriteConst(unittest.TestCase): def test_gep_const(self): v1 = UOp.const(dtypes.int.vec(3), (0,1,2)) v2 = v1.gep(1) - ret = graph_rewrite(v2, constant_folder) + ret = graph_rewrite(v2, sym) self.assertEqual(ret.dtype, dtypes.int) self.assertEqual(ret.arg, 1) def test_gep_const_single(self): v1 = UOp.const(dtypes.int.vec(3), 4) v2 = v1.gep(1) - ret = graph_rewrite(v2, constant_folder) + ret = graph_rewrite(v2, sym) self.assertEqual(ret.dtype, dtypes.int) self.assertEqual(ret.arg, 4) def test_add_const(self): v1 = UOp.const(dtypes.int.vec(3), (0,1,2)) v2 = UOp.const(dtypes.int.vec(3), (5,6,7)) - ret = graph_rewrite(v1+v2, constant_folder) + ret = graph_rewrite(v1+v2, sym) self.assertEqual(ret.op, UOps.VCONST) self.assertEqual(ret.dtype, dtypes.int.vec(3)) self.assertEqual(ret.arg, (5,7,9)) @@ -93,7 +93,7 @@ class TestGraphRewriteConst(unittest.TestCase): def test_add_const_lose_v(self): v1 = UOp.const(dtypes.int.vec(3), (0,1,2)) v2 = UOp.const(dtypes.int.vec(3), (2,1,0)) - ret = graph_rewrite(v1+v2, constant_folder) + ret = graph_rewrite(v1+v2, sym) self.assertEqual(ret.op, UOps.CONST) self.assertEqual(ret.dtype, dtypes.int.vec(3)) self.assertEqual(ret.arg, 2) @@ -168,7 +168,7 @@ class TestGraphRewrite(unittest.TestCase): d = UOp.define_var('d', dtypes.int, 0, 1) outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b] for out in outs: - sink = graph_rewrite(out, constant_folder) + sink = graph_rewrite(out, sym) print(sink) self.assertEqual(sink.op, UOps.ALU) self.assertEqual(sink.src[1].op, UOps.CONST) @@ -433,9 +433,9 @@ class TestUOpGraph(unittest.TestCase): self.assertEqual(endranges[-1].src[0], ranges[0]) def expander_rewrite(sink): - sink = graph_rewrite(sink, constant_folder + expander) - return graph_rewrite(sink, constant_folder + reducer) -def float4_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + float4_folding) + sink = graph_rewrite(sink, sym + expander) + return graph_rewrite(sink, sym + reducer) +def float4_rewrite(sink): return graph_rewrite(sink, sym + expander + float4_folding) class TestExpander(unittest.TestCase): def test_expand_add_broadcast(self): @@ -645,7 +645,7 @@ class TestLoadStoreFolder(unittest.TestCase): print(sink) assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 3 -def gate_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + reducer) +def gate_rewrite(sink): return graph_rewrite(sink, sym + expander + reducer) class TestIFUOps(unittest.TestCase): def test_create_ifs(self): diff --git a/test/test_uops.py b/test/test_uops.py index a08ca29f41..4dc9dcc5d7 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -10,7 +10,7 @@ from tinygrad.ops import UOps, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, Reduc from tinygrad.renderer import Program from tinygrad.engine.schedule import create_schedule, reduceop_fusor from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel -from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, constant_folder +from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, sym from tinygrad.shape.symbolic import Variable from test.helpers import is_dtype_supported, assert_equiv_uops @@ -433,7 +433,7 @@ class TestIndexingOrdering(unittest.TestCase): class TestUPatHelpers(unittest.TestCase): def test_location(self): - self.assertEqual(constant_folder.patterns[0][0].location[0].split("/")[-1], "uopgraph.py") + self.assertEqual(sym.patterns[0][0].location[0].split("/")[-1], "uopgraph.py") self.assertEqual(reduceop_fusor.patterns[0][0].location[0].split("/")[-1], "schedule.py") self.assertEqual(spec.patterns[0][0].location[0].split("/")[-1], "ops.py") with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*? diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index 5d3d6143a0..251d527a26 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -6,12 +6,12 @@ from tinygrad.helpers import prod from tinygrad.shape.shapetracker import ShapeTracker, View from tinygrad.shape.symbolic import Variable, NumNode from tinygrad.ops import UOp, UOps, graph_rewrite -from tinygrad.codegen.uopgraph import constant_folder +from tinygrad.codegen.uopgraph import sym from itertools import product def shapetracker_getitem(st:ShapeTracker, val:int): idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.pyint, val)]) - idx, valid = graph_rewrite(idx, constant_folder), graph_rewrite(valid, constant_folder) + idx, valid = graph_rewrite(idx, sym), graph_rewrite(valid, sym) assert idx.op is UOps.CONST and valid.op is UOps.CONST return idx.arg, valid.arg diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 6c87561f63..55262f35fb 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -224,7 +224,7 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]: for candidate in candidates: newidxs:List[List[UOp]] = [[], []] for X,newX in candidate: - newidx = replace_uop(graph_rewrite(replace_uop(idx, X, newX), constant_folder), newX, X) + newidx = replace_uop(graph_rewrite(replace_uop(idx, X, newX), sym), newX, X) newidxs[0].append(newidx.src[0]) newidxs[1].append(newidx.src[1]) @@ -247,7 +247,7 @@ def simplify_valid_image_load(load:UOp, buf:UOp): if not is_upper_bound and c == 1 and X.op is UOps.ALU and X.arg is BinaryOps.ADD and \ all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(X, BinaryOps.ADD)): testidx = functools.reduce(lambda nowidx,u: replace_uop(nowidx, u, u.const_like(0)), _get_chain(X, BinaryOps.ADD), idx) - testidx = graph_rewrite(testidx, constant_folder) + testidx = graph_rewrite(testidx, sym) if testidx.src[0].vmax < 0 or testidx.src[1].vmax < 0: drop_stmt.append(stmt) continue @@ -257,7 +257,7 @@ def simplify_valid_image_load(load:UOp, buf:UOp): test_value = c + 1 if is_upper_bound else c - 1 for i,b in zip(idx.src, (buf_dtype.shape[1], buf_dtype.shape[0])): if is_increasing(i): - rw = graph_rewrite(replace_uop(i, X, X.const_like(test_value)), constant_folder) + rw = graph_rewrite(replace_uop(i, X, X.const_like(test_value)), sym) if rw.vmin >= b or rw.vmax < 0: drop_stmt.append(stmt) break @@ -365,7 +365,7 @@ def no_vectorized_wmma(wmma:UOp): return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex)) # this is symbolic 2.0 -constant_folder = PatternMatcher([ +sym = PatternMatcher([ # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly (UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y), (UPat.var('x', dtype=dtypes.bool) + UPat.var('y'), lambda x,y: x|y), @@ -737,7 +737,7 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp: # do graph rewrite acc_number = 0 - sink = graph_rewrite(sink, constant_folder) + sink = graph_rewrite(sink, sym) # rewrite pyint to int32 sink = graph_rewrite(sink, no_pyint) @@ -745,12 +745,12 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp: # expand linearize_cnt += 1 if linearize_cnt != (de:=getenv("DEBUG_EXPAND", 0)) and de != -1: - sink = graph_rewrite(sink, constant_folder+expander) + sink = graph_rewrite(sink, sym+expander) if getenv("DO_REDUCE", 1): - sink = graph_rewrite(sink, constant_folder+just_reduce) - sink = graph_rewrite(sink, constant_folder+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)) - sink = graph_rewrite(sink, constant_folder+reducer) - sink = graph_rewrite(sink, constant_folder+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2)) + sink = graph_rewrite(sink, sym+just_reduce) + sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)) + sink = graph_rewrite(sink, sym+reducer) + sink = graph_rewrite(sink, sym+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2)) if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, opts.extra_matcher) return sink diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index bcf76684d1..cb3f174d1c 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -2,7 +2,7 @@ from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable import struct from collections import defaultdict from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat -from tinygrad.codegen.uopgraph import constant_folder +from tinygrad.codegen.uopgraph import sym from tinygrad.dtype import dtypes, DType, PtrDType, ConstType from tinygrad.renderer import Renderer from tinygrad.renderer.cstyle import CUDARenderer @@ -39,7 +39,7 @@ def load_store_ptr_arithmetic(x:UOp, buf:UOp, alu:Optional[UOp]=None, const:Opti return x.replace(src=tuple(src)) supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE] -ptx_matcher = constant_folder+PatternMatcher([ +ptx_matcher = sym+PatternMatcher([ # bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only) (UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y), (UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y), diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 4fff3da1ee..577dcbafc4 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -8,7 +8,7 @@ from tinygrad.shape.symbolic import Variable, MulNode, SumNode, NumNode, DivNode from tinygrad.shape.view import View, strides_for_shape from tinygrad.dtype import dtypes from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite -from tinygrad.codegen.uopgraph import constant_folder, _get_chain +from tinygrad.codegen.uopgraph import sym, _get_chain # TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x.render(render_ops, ctx) @@ -101,15 +101,15 @@ class ShapeTracker: if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides ret: List[Optional[sint]] = [None] * len(self.shape) idx, valid = self.to_indexed_uops() - idx = graph_rewrite(idx, pm=constant_folder) + idx = graph_rewrite(idx, pm=sym) for c in _get_chain(idx, BinaryOps.ADD): if c.op is UOps.RANGE: ret[c.arg] = 1 if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[0].op is UOps.RANGE and c.src[1].op is UOps.CONST: ret[c.src[0].arg] = c.src[1].arg if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[1].op is UOps.RANGE and c.src[0].op is UOps.CONST: ret[c.src[1].arg] = c.src[0].arg - used_ranges = [x.arg for x in graph_rewrite(idx, pm=constant_folder).sparents if x.op is UOps.RANGE] + used_ranges = [x.arg for x in graph_rewrite(idx, pm=sym).sparents if x.op is UOps.RANGE] ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)] if not ignore_valid: - masked_axis = [x.arg for x in graph_rewrite(valid, pm=constant_folder).sparents if x.op is UOps.RANGE] + masked_axis = [x.arg for x in graph_rewrite(valid, pm=sym).sparents if x.op is UOps.RANGE] ret = [None if i in masked_axis else x for i,x in enumerate(ret)] return tuple(ret) @@ -117,7 +117,7 @@ class ShapeTracker: def axis_is_masked(self, axis:int) -> bool: _, valid = self.to_indexed_uops() - return axis in [x.arg for x in graph_rewrite(valid, constant_folder).sparents if x.op is UOps.RANGE] + return axis in [x.arg for x in graph_rewrite(valid, sym).sparents if x.op is UOps.RANGE] def simplify(self) -> ShapeTracker: if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None: diff --git a/viz/test_viz.py b/viz/test_viz.py index 8b06274939..b030b30bd5 100644 --- a/viz/test_viz.py +++ b/viz/test_viz.py @@ -7,7 +7,7 @@ from tinygrad.engine.realize import lower_schedule from tinygrad.ops import UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps from tinygrad.dtype import dtypes, PtrDType from tinygrad.helpers import CI, Context, all_same, DEBUG, colored, getenv -from tinygrad.codegen.uopgraph import constant_folder, devectorize, float4_folding +from tinygrad.codegen.uopgraph import sym, devectorize, float4_folding from test.external.process_replay.helpers import print_diff from viz.serve import KernelRet, UOpRet, load_kernels, uop_to_json @@ -96,7 +96,7 @@ class TestViz(unittest.TestCase): UOp(UOps.LOAD, dtypes.float.vec(4), arg=None, src=( x11, x7,)),)),)),)) - pm = constant_folder+(devectorize+float4_folding) + pm = sym+(devectorize+float4_folding) new_sink = graph_rewrite(sink, pm) if DEBUG >= 4: print_diff(sink, new_sink, unified=0) self.assert_valid_ctx(contexts)