diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py
index eecb750e33..570a879cb9 100644
--- a/test/external/fuzz_linearizer.py
+++ b/test/external/fuzz_linearizer.py
@@ -252,7 +252,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2, opts_list=None):
 def _is_simple(lin: Kernel) -> bool:
   if len(lin.ast.src) > 1: return False
   ast:UOp = lin.ast.src[0]
-  if ast.src[0].arg is UnaryOps.CAST and ast.src[0].src[0].op is Ops.LOAD: return True
+  if ast.src[0].op is UnaryOps.CAST and ast.src[0].src[0].op is Ops.LOAD: return True
   return False

 if __name__ == "__main__":
diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py
index d4e407d5c6..f5be6f13be 100644
--- a/test/test_dtype_alu.py
+++ b/test/test_dtype_alu.py
@@ -8,7 +8,7 @@ from tinygrad.dtype import DType
 from tinygrad.helpers import CI, getenv
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
-from tinygrad.ops import UnaryOps, Ops
+from tinygrad.ops import GroupOp
 from tinygrad.tensor import _to_np_dtype
 from test.helpers import is_dtype_supported
 import pytest
@@ -79,7 +79,7 @@ def universal_test_unary(a, dtype, op):
     np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
   else: np.testing.assert_equal(tensor_value, numpy_value)
   if op[0] != Tensor.reciprocal: # reciprocal is not supported in most backends
-    op = [x for x in ast.parents if x.op is Ops.ALU and x.arg in UnaryOps][0]
+    op = [x for x in ast.parents if x.op in GroupOp.Unary][0]
     assert op.dtype == dtype

 def universal_test_cast(a, in_dtype, dtype):
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 940b65ca2c..7e8f518cbd 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -6,7 +6,7 @@ from dataclasses import replace
 from test.helpers import ast_const
 from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel
 from tinygrad.codegen.lowerer import get_grouped_dims
-from tinygrad.ops import UOp, Ops, BinaryOps, TernaryOps, UnaryOps
+from tinygrad.ops import UOp, Ops, BinaryOps, TernaryOps, UnaryOps, GroupOp
 from tinygrad.device import Device, Buffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
@@ -865,9 +865,9 @@ class TestLinearizer(unittest.TestCase):
     lin = helper_linearizer_opt(out, wanna_output=[24])[0]
     ranges = [i for i,u in enumerate(lin.uops) if u.op is Ops.RANGE]
     # RANGE -> ALU -> RANGE -> ALU + LOAD -> ASSIGN
-    assert any(x.op is Ops.ALU for x in lin.uops[ranges[0]:ranges[1]])
+    assert any(x.op in GroupOp.ALU for x in lin.uops[ranges[0]:ranges[1]])
     assert not any(x.op is Ops.LOAD for x in lin.uops[ranges[0]:ranges[1]])
-    assert any(x.op in {Ops.ALU, Ops.LOAD} for x in lin.uops[ranges[1]:])
+    assert any(x.op in {*GroupOp.ALU, Ops.LOAD} for x in lin.uops[ranges[1]:])

   def test_range_outer_op_before_phi(self):
     a = Tensor.randn(4, 1).realize()
@@ -902,7 +902,7 @@ class TestLinearizer(unittest.TestCase):
     lin = helper_linearizer_opt(out, wanna_output=[a.numpy().sum()*a.numpy().sum()])[0]
     # RANGE -> LOAD -> ASSIGN -> ALU
     end = max(i for i,u in enumerate(lin.uops) if u.op is Ops.ENDRANGE)
-    assert lin.uops[end+1].op is Ops.ALU
+    assert lin.uops[end+1].op in GroupOp.ALU

   def test_range_outer_op_after_phi_nested_range(self):
     a = Tensor.randn(2, ).realize()
@@ -910,7 +910,7 @@ class TestLinearizer(unittest.TestCase):
     lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3))).sum()*2])[0]
     # RANGE -> LOAD -> ASSIGN -> ALU
     end = max(i for i,u in enumerate(lin.uops) if u.op is Ops.ENDRANGE)
-    assert lin.uops[end+1].op is Ops.ALU
+    assert lin.uops[end+1].op in GroupOp.ALU

   def test_load_dedup(self):
     # for different leaves in the AST, the same loads may occur.
@@ -1141,7 +1141,7 @@ class TestLinearizer(unittest.TestCase):
     # the uops graph is RANGE -> DEFINE_ACC -> 4x ALU -> 4x ASSIGN -> ENDRANGE
     for u in k.uops:
       if u.op is Ops.ASSIGN:
-        assert u.src[1].op is Ops.ALU
+        assert u.src[1].op in GroupOp.ALU
       # children of ASSIGN are placed after ENDRANGE
       if any(x.op is Ops.ASSIGN for x in u.src):
         end_range = [i for i, x in enumerate(k.uops) if x.op is Ops.ENDRANGE][0]
@@ -1219,7 +1219,7 @@ class TestLinearizer(unittest.TestCase):
       assert len(sched) == 1
       lin = Kernel(sched[0].ast)
-      assert sum(u.arg is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg
+      assert sum(u.op is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg

     a = Tensor.empty((4,4))
     b = Tensor.empty((4,4))
@@ -1260,7 +1260,7 @@ class TestLinearizer(unittest.TestCase):
     lin = Kernel(sched_copy[-1].ast)
     lin.hand_coded_optimizations()
     lin.linearize()
-    assert not any(u.arg == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded"
+    assert not any(u.op == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded"

   def test_phi_simplification(self):
     def helper(t, max_ops=0):
@@ -1272,7 +1272,7 @@ class TestLinearizer(unittest.TestCase):
       assert len(set([u.op for u in uops if u.op in {Ops.RANGE, Ops.SPECIAL}])) == 1, "has either specials or ranges, not both"
       assert len([u for u in uops if u.op is Ops.ASSIGN]) == 0, "ASSIGN should have been simplified"
       # TODO: once uops track min/max this will be fixed
-      #assert len([u for u in uops if u.arg is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"
+      #assert len([u for u in uops if u.op is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"

     helper(Tensor.arange(5.5, (3.5*300), 3.5), max_ops=2)
     helper(Tensor.arange(-1, -100, -5), max_ops=2)
diff --git a/test/test_linearizer_dumb.py b/test/test_linearizer_dumb.py
index 44688cc60a..eaf5df9818 100644
--- a/test/test_linearizer_dumb.py
+++ b/test/test_linearizer_dumb.py
@@ -75,7 +75,7 @@ class TestLinearizerDumb(unittest.TestCase):
     for opt in opts: k.apply_opt(opt)
     prg = k.to_program()
     print(prg.src)
-    assert prg.uops is not None and not any(uop.op is Ops.ALU and uop.arg is BinaryOps.MAX for uop in prg.uops), "leftover MAX"
+    assert prg.uops is not None and not any(uop.op is BinaryOps.MAX for uop in prg.uops), "leftover MAX"

   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
   def test_expander_new_srcs(self):
diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index ab344c027a..d99e017a3b 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -620,21 +620,21 @@ class TestMultiTensor(unittest.TestCase):
     for si in t.schedule():
       ast = si.ast.src[0]
       assert ast.op is Ops.STORE
-      assert ast.src[2].arg is BinaryOps.ADD
+      assert ast.src[2].op is BinaryOps.ADD
       assert ast.src[2].src[0].op is Ops.LOAD
       assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 1
     t = 2 * t
     for si in t.schedule():
       ast = si.ast.src[0]
       assert ast.op is Ops.STORE
-      assert ast.src[2].arg is BinaryOps.MUL
+      assert ast.src[2].op is BinaryOps.MUL
       assert ast.src[2].src[0].src[1].op is Ops.CONST and ast.src[2].src[0].src[1].arg == 2
       assert ast.src[2].src[1].op is Ops.LOAD
     t = t + t.full_like(3)
     for si in t.schedule():
       ast = si.ast.src[0]
       assert ast.op is Ops.STORE
-      assert ast.src[2].arg is BinaryOps.ADD
+      assert ast.src[2].op is BinaryOps.ADD
       assert ast.src[2].src[0].op is Ops.LOAD
       assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 3
diff --git a/test/test_schedule.py b/test/test_schedule.py
index d88d09ae6f..3322f8a02d 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -1040,7 +1040,7 @@ class TestSchedule(unittest.TestCase):
     b = r.sum(0) * 4
     c = r.sum(1) * 2
     schedule = check_schedule([b, c], 3)
-    self.assertIs(schedule[0].ast.src[0].src[2].arg, BinaryOps.ADD)
+    self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)

   # multireduce spec
   def test_multireduce_simple_chase(self):
@@ -1064,7 +1064,7 @@ class TestSchedule(unittest.TestCase):
     d = r.T * 4
     e = r * d
     schedule = check_schedule([d, e], 3)
-    self.assertIs(schedule[0].ast.src[0].src[2].arg, BinaryOps.ADD)
+    self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)

   # multireduce spec
   def test_multireduce_push_permute_chase(self):
@@ -1075,7 +1075,7 @@ class TestSchedule(unittest.TestCase):
     d = r.T * 4
     e = r * (d + a).sum(2)
     schedule = check_schedule([d, e], 3) # make sure it doesn't fuse
-    self.assertIs(schedule[0].ast.src[0].src[2].arg, BinaryOps.ADD)
+    self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
     run_schedule(schedule)
     np.testing.assert_allclose(d.numpy(), (a.numpy().sum(2) + b.numpy()).T * 4, atol=1e-4, rtol=1e-4)
     np.testing.assert_allclose(e.numpy(), (a.numpy().sum(2) + b.numpy()) * (d.numpy() + a.numpy()).sum(2), atol=1e-4, rtol=1e-4)
@@ -1087,7 +1087,7 @@ class TestSchedule(unittest.TestCase):
     r = a.sum(1) + c
     d = r[:4] * b
     schedule = check_schedule(d, 2)
-    self.assertIs(schedule[0].ast.src[0].src[2].arg, BinaryOps.ADD)
+    self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)

   # multireduce spec
   def test_multireduce_push_shrink_chase(self):
@@ -1100,7 +1100,7 @@ class TestSchedule(unittest.TestCase):
     out = r[:4] * b + d.sum(1)[:4]
     # schedule = check_schedule(out, 2)
     schedule = check_schedule(out, 3)
-    self.assertIs(schedule[0].ast.src[0].src[2].arg, BinaryOps.ADD)
+    self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
     run_schedule(schedule)
     np.testing.assert_allclose(out.numpy(), (a.numpy().sum(1) + c.numpy())[:4] * b.numpy() + d.numpy().sum(1)[:4], atol=1e-4, rtol=1e-4)
diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py
index 36455e9804..1025424dda 100644
--- a/test/test_uop_graph.py
+++ b/test/test_uop_graph.py
@@ -161,7 +161,7 @@ class TestGraphRewrite(unittest.TestCase):
     c1 = UOp.const(dtypes.float, 1.0)
     c2 = UOp.const(dtypes.float, 2.0)
     nout = graph_rewrite(v+c1+c2, simple_pm)
-    self.assertEqual(nout.op, Ops.ALU)
+    self.assertEqual(nout.op, Ops.ADD)
     self.assertEqual(nout.src[0].op, Ops.DEFINE_VAR)
     self.assertEqual(nout.src[1].op, Ops.CONST)
     self.assertEqual(nout.src[1].arg, 3.0)
@@ -182,11 +182,11 @@ class TestGraphRewrite(unittest.TestCase):
     b = UOp.variable('b', 0, 1)
     c = UOp.variable('c', 0, 1)
     d = UOp.variable('d', 0, 1)
-    outs = [2+a, 2+a+d+3+b+c+4, UOp(Ops.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
+    outs = [2+a, 2+a+d+3+b+c+4, UOp(Ops.ADD, a.dtype, src=(a.const_like(2), a)), (4+d)+c+(2+a)+b]
     for out in outs:
       sink = graph_rewrite(out, sym)
       print(sink.render())
-      self.assertEqual(sink.op, Ops.ALU)
+      self.assertEqual(sink.op, Ops.ADD)
       self.assertEqual(sink.src[1].op, Ops.CONST)
       self.assertEqual(len([x for x in sink.sparents if x.op is Ops.CONST]), 1)
@@ -380,8 +380,7 @@ class TestUOpGraph(unittest.TestCase):
     uops = to_uops_list([out])
     self.assertEqual(len(uops), 3)
out = uops[-1] - self.assertEqual(out.op, Ops.ALU) - self.assertEqual(out.arg, BinaryOps.ADD) + self.assertEqual(out.op, BinaryOps.ADD) self.assertEqual(out.src[1].op, Ops.CONST) self.assertEqual(out.src[1].arg, 6) diff --git a/test/test_uops.py b/test/test_uops.py index 7c1b57d5eb..a7676a61ce 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -336,8 +336,8 @@ class TestAssembly(unittest.TestCase): a2 = UOp(Ops.ALU, dtypes.int, (l1, c2), BinaryOps.MUL) uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer) Device[Device.DEFAULT].renderer.render("test", uops) - self.assertEqual(uops[-1].arg, BinaryOps.SHL) - self.assertEqual(uops[-2].arg, BinaryOps.MUL) + self.assertEqual(uops[-1].op, BinaryOps.SHL) + self.assertEqual(uops[-2].op, BinaryOps.MUL) def test_bitshift_right(self): g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0) @@ -348,8 +348,8 @@ class TestAssembly(unittest.TestCase): a2 = UOp(Ops.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV) uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer) Device[Device.DEFAULT].renderer.render("test", uops) - self.assertEqual(uops[-1].arg, BinaryOps.SHR) - self.assertEqual(uops[-2].arg, BinaryOps.IDIV) + self.assertEqual(uops[-1].op, BinaryOps.SHR) + self.assertEqual(uops[-2].op, BinaryOps.IDIV) class TestUOpMethod(unittest.TestCase): @unittest.skip("uops lt no longer ordered") diff --git a/test/unit/test_graph_rewrite.py b/test/unit/test_graph_rewrite.py index 4cd031efda..3388013286 100644 --- a/test/unit/test_graph_rewrite.py +++ b/test/unit/test_graph_rewrite.py @@ -1,7 +1,7 @@ import unittest, math from tinygrad import dtypes from tinygrad.helpers import all_same -from tinygrad.ops import UOp, Ops, BinaryOps, exec_alu +from tinygrad.ops import GroupOp, UOp, Ops, BinaryOps, exec_alu from tinygrad.codegen.uopgraph import full_graph_rewrite # Helper function to apply the graph rewrite @@ -14,9 +14,9 @@ def evaluate_uop(uop, variables): elif uop.op == Ops.DEFINE_VAR: var_name = uop.arg[0] return variables[var_name] - elif uop.op == Ops.ALU: + elif uop.op in GroupOp.ALU: src_values = [evaluate_uop(src, variables) for src in uop.src] - return exec_alu(uop.arg, uop.dtype, src_values) + return exec_alu(uop.op, uop.dtype, src_values) else: raise NotImplementedError(f"Unsupported UOp {uop.op}") @@ -104,8 +104,7 @@ class TestModuloAndDivisionFolding(unittest.TestCase): def test_full_graph_rewrite_division_folding_with_define_var(self): n_var_uop = UOp.variable('n', 1, 1000) optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3) - self.assertEqual(optimized_div_uop.op, Ops.ALU) - self.assertEqual(optimized_div_uop.arg, BinaryOps.MUL) + self.assertEqual(optimized_div_uop.op, BinaryOps.MUL) self.assertEqual(optimized_div_uop.src[1].arg, 2) def test_full_graph_rewrite_complex_mod_div_folding(self): @@ -115,7 +114,7 @@ class TestModuloAndDivisionFolding(unittest.TestCase): self.assertEqual(optimized_div_uop.arg, 1) def test_graph_rewrite_div_folding_bug(self): - lhs = UOp(Ops.ALU, dtypes.int.vec(4), arg=BinaryOps.ADD, src=( + lhs = UOp(Ops.ADD, dtypes.int.vec(4), src=( UOp(Ops.VECTORIZE, dtypes.int.vec(4), arg=None, src=(UOp(Ops.SPECIAL, dtypes.int, arg=('lidx0', 32), src=()),)*4), UOp(Ops.VCONST, dtypes.int.vec(4), arg=(0, 256, 512, 768), src=()))) rhs = UOp.const(dtypes.int.vec(4), 2) diff --git a/test/unit/test_pattern_matcher.py b/test/unit/test_pattern_matcher.py index 715d71c525..7925fb0dff 100644 --- a/test/unit/test_pattern_matcher.py +++ b/test/unit/test_pattern_matcher.py @@ -1,6 +1,6 @@ import unittest, itertools from 
tinygrad.dtype import dtypes -from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401 +from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps, GroupOp # noqa: F401 from tinygrad.ops import PatternMatcher, UPat class TestPatternMatcher(unittest.TestCase): @@ -72,7 +72,7 @@ class TestPatternMatcher(unittest.TestCase): matcher = PatternMatcher([ (UPat(Ops.CONST, arg=0, name="x"), lambda x: x), (UPat(Ops.CONST, arg=False, name="x"), lambda x: x), - (UPat(Ops.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x), + (UPat(Ops.MAX, name="x"), lambda x: x), ]) c1 = UOp(Ops.CONST, dtypes.float, arg=0.0) c2 = UOp(Ops.CONST, dtypes.bool, arg=False) @@ -87,7 +87,7 @@ class TestPatternMatcher(unittest.TestCase): def test_filter_arg(self): matcher = PatternMatcher([ - (UPat(Ops.ALU, arg=BinaryOps.MUL, src=[UPat(Ops.CONST, name="c"), UPat(Ops.CONST, arg=2)], name="x"), + (UPat(Ops.MUL, src=[UPat(Ops.CONST, name="c"), UPat(Ops.CONST, arg=2)], name="x"), lambda x,c: x if c.arg in {1, -1} else None) ]) y1 = UOp(Ops.CONST, dtypes.int, arg=1) @@ -105,11 +105,11 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(matcher.rewrite(c5), c5) def test_dup_name(self): - matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST, name="y"), UPat(Ops.CONST, name="y"))), lambda x, y: x)]) + matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=(UPat(Ops.CONST, name="y"), UPat(Ops.CONST, name="y"))), lambda x, y: x)]) y1 = UOp(Ops.CONST, dtypes.float, arg=1.0) y2 = UOp(Ops.CONST, dtypes.float, arg=1.0) - c1 = UOp(Ops.ALU, dtypes.float, (y1, y1), BinaryOps.ADD) - c2 = UOp(Ops.ALU, dtypes.float, (y1, y2), BinaryOps.ADD) + c1 = UOp(Ops.ADD, dtypes.float, (y1, y1)) + c2 = UOp(Ops.ADD, dtypes.float, (y1, y2)) self.assertEqual(matcher.rewrite(c1), c1) self.assertEqual(matcher.rewrite(c2), c1) @@ -132,7 +132,7 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(matcher.rewrite(c4), None) def test_src_one(self): - matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST), UPat(Ops.CONST))), lambda x: x)]) + matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=(UPat(Ops.CONST), UPat(Ops.CONST))), lambda x: x)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c2 = UOp(Ops.CONST, dtypes.float, arg=2.0) c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD) @@ -149,7 +149,7 @@ class TestPatternMatcher(unittest.TestCase): """ def test_src_permutations(self): - matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=[UPat(Ops.CONST), UPat(Ops.ALU)]), lambda x: x)]) + matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=[UPat(Ops.CONST), UPat(GroupOp.ALU)]), lambda x: x)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c2 = UOp(Ops.CONST, dtypes.float, arg=2.0) c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD) @@ -162,7 +162,7 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(matcher.rewrite(c6), None) def test_src_repeat(self): - matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=UPat(Ops.CONST)), lambda x: x)]) + matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=UPat(Ops.CONST)), lambda x: x)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c2 = UOp(Ops.CONST, dtypes.float, arg=2.0) c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD) @@ -171,7 +171,7 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(matcher.rewrite(c4), None) def test_allow_len(self): - matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), 
lambda x: x)]) + matcher = PatternMatcher([(UPat(Ops.MULACC, name="x", src=(UPat(Ops.CONST),), allow_any_len=True), lambda x: x)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c2 = UOp(Ops.CONST, dtypes.float, arg=2.0) c3 = UOp(Ops.CONST, dtypes.float, arg=3.0) @@ -188,7 +188,7 @@ class TestPatternMatcher(unittest.TestCase): u1 = (c1 + c2) + c1 u2 = (c2 + c1) + c1 matcher = PatternMatcher([ - (UPat(Ops.ALU, src=[UPat(Ops.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b) + (UPat(GroupOp.ALU, src=[UPat(GroupOp.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b) ]) self.assertIsNotNone(matcher.rewrite(u1)) self.assertIsNotNone(matcher.rewrite(u2)) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 11663aa56b..7be6822b3d 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -276,7 +276,7 @@ class Kernel: if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0] - if mul_op.arg is not BinaryOps.MUL: return None + if mul_op.op is not BinaryOps.MUL: return None def buf_index(src:UOp) -> Optional[int]: # TODO: apply tc even if the sources are not from LOAD @@ -442,7 +442,7 @@ class Kernel: check(axis < self.first_upcast, "cannot pad upcasted") # ok to pad SUM if all parent ALU ops have f(0) = 0 if (r:=self.reduceop) is not None and self.first_reduce <= axis: - check(r.arg[0] is BinaryOps.ADD and not any(u.op is Ops.ALU and u.arg in GroupOp.UnsafePad for u in r.parents), "cannot pad UnsafePad") + check(r.arg[0] is BinaryOps.ADD and not any(u.op in GroupOp.UnsafePad for u in r.parents), "cannot pad UnsafePad") padded = False for i,st in enumerate(self.sts): if (s:=st.shape[axis]) == 1: continue # reduced @@ -472,7 +472,7 @@ class Kernel: MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4) if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \ self.reduceop is not None and self.reduceop.arg[0] is BinaryOps.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \ - (mulop:=self.reduceop.src[0]).arg is BinaryOps.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD: + (mulop:=self.reduceop.src[0]).op is BinaryOps.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD: st0, st1 = self.sts[self.bufs.index(mulop.src[0])], self.sts[self.bufs.index(mulop.src[1])] strides0, strides1 = st0.real_strides(), st1.real_strides() def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides)) @@ -628,7 +628,7 @@ class Kernel: if op in self.bufs_for_tensor_core and (tc := self.tensor_core): rsrc = op.src[0] if rsrc.op is Ops.CAST: rsrc = rsrc.src[0] - assert rsrc.op is Ops.ALU and rsrc.arg is BinaryOps.MUL + assert rsrc.op is Ops.MUL def fix_st(warp_dims, tcd_dims, tcd_expand, pattern_1, pattern_2, st1): wd, tcd = self.global_dims, self.first_upcast diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py index d213e22ff7..356835b7bc 100644 --- a/tinygrad/codegen/linearize.py +++ b/tinygrad/codegen/linearize.py @@ -1,6 +1,6 @@ from typing import List, Set, Dict, Tuple import functools, heapq -from tinygrad.ops import type_verify, END_FOR_UOP, UOp, Ops +from tinygrad.ops import type_verify, END_FOR_UOP, UOp, Ops, GroupOp from 
tinygrad.dtype import dtypes from tinygrad.helpers import DEBUG @@ -54,7 +54,7 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: # prevent priority inversion @functools.lru_cache(None) def fix_priority(u:UOp, lowest_priority): - if u.op in {Ops.CAST, Ops.BITCAST, Ops.ALU, Ops.VECTORIZE, Ops.GEP, Ops.SPECIAL, Ops.DEFINE_LOCAL, Ops.LOAD}: + if u.op in {Ops.CAST, Ops.BITCAST, *GroupOp.ALU, Ops.VECTORIZE, Ops.GEP, Ops.SPECIAL, Ops.DEFINE_LOCAL, Ops.LOAD}: priorities[u] = min(priorities[u], lowest_priority) if u.op is Ops.LOAD: priorities[u] += 100 # load penalty (here) for x in u.src: fix_priority(x, priorities[u]) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 3e7c23e679..f0186489d9 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -4,7 +4,7 @@ import functools, itertools, operator from collections import defaultdict from tinygrad.dtype import dtypes, ImageDType, PtrDType from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple -from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid +from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES @@ -23,7 +23,7 @@ def fold_expanded(ex, buf): for i,s in enumerate(new_srcs): idx = s.src[0].src[1] if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue - if idx.arg is BinaryOps.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg + if idx.op is BinaryOps.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg else: root_src, arg = idx, 0 # add gates for gated @@ -124,18 +124,18 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]: powers_of_two = {2**i:i for i in range(64)} @functools.lru_cache(None) def get_late_rewrite_patterns(ops, force_transcendental=False): - pat: List[Tuple[UPat, Callable]] = [(UPat(Ops.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=op), f) for op,f in \ + pat: List[Tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \ ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if op not in ops or force_transcendental] # rewrite MOD to AND (which should always be supported, but not for generic in tests) if BinaryOps.AND in ops: - pat += [(UPat(Ops.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))), + pat += [(UPat(Ops.MOD, src=(UPat.var('base'), UPat.cvar("const"))), lambda base,const: base & (const.arg-1) if const.arg in powers_of_two else None)] # rewrite MUL/IDIV to SHL+SHR if BinaryOps.SHL in ops and BinaryOps.SHR in ops: pat += [ - (UPat(Ops.ALU, arg=BinaryOps.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const: + (UPat(Ops.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const: mul << powers_of_two[const.arg] if const.arg in powers_of_two else None), # (x * (2**y)) -> shl(x,y) - (UPat(Ops.ALU, arg=BinaryOps.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const: + (UPat(Ops.IDIV, 
src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const: div >> powers_of_two[const.arg] if const.arg in powers_of_two else None)] # (x // (2**y)) -> shr(x,y) if UnaryOps.NEG in ops: pat += [(UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))] @@ -226,8 +226,8 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp): reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.sparents) if len(reduce_unparented) == 0: return None new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented)) - ret = new_acc.assign(new_acc.alu(alu.arg, ret)) - if alu.arg is BinaryOps.ADD: + ret = new_acc.assign(new_acc.alu(alu.op, ret)) + if alu.op is BinaryOps.ADD: for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count) return ret @@ -249,8 +249,8 @@ sym = symbolic_flat+PatternMatcher([ (UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))), (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))), # reorder ALU/VECTORIZE - (UPat(Ops.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'), - lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(Ops.ALU, alu.dtype.scalar(), (x,y), alu.arg),)*alu.dtype.count)), + (UPat(GroupOp.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'), + lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(alu.op, alu.dtype.scalar(), (x,y)),)*alu.dtype.count)), # VECTORIZE of a single element is just that element (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x), # VECTORIZE void is SINK @@ -264,7 +264,7 @@ sym = symbolic_flat+PatternMatcher([ (UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)), (UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))), # push all GEPs through ALUs (fix arange stuff) - (UPat(Ops.GEP, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'), + (UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'), lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg)), # push some GEPs through WMMAs (UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma), @@ -275,15 +275,15 @@ sym = symbolic_flat+PatternMatcher([ (UPat.var("add") + UPat(Ops.WMMA, name="wmma"), lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)), # threefry - (UPat(Ops.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key")), arg=BinaryOps.THREEFRY), threefry2x32), + (UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32), # arange loop folding (acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse), # indexing, with cast or where (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse), (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse), # parentless reduce - (acc_pat.assign(UPat(Ops.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.ADD, name="alu")), reduce_collapse), - (acc_pat.assign(UPat(Ops.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.MAX, name="alu")), reduce_collapse), + 
(acc_pat.assign(UPat(Ops.ADD, src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse), + (acc_pat.assign(UPat(Ops.MAX, src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse), # ** self folding ** (UPat(Ops.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST (UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP @@ -403,7 +403,7 @@ expander = PatternMatcher([ (UPat(Ops.EXPAND, name="outer", src=(UPat(Ops.EXPAND, name="inner"),)), lambda outer, inner: UOp(Ops.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)), # do expansion - (UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN, + (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN, Ops.VECTORIZE, Ops.REDUCE, Ops.IF), name="root", custom_early_reject=set([(Ops.EXPAND, None)])), do_expand), (UPat(Ops.CONTRACT, name="con"), do_contract), # vectorize DEFINE_ACC @@ -433,7 +433,7 @@ def no_vectorized_acc(acc:UOp): devectorize = PatternMatcher([ # no ALU on vectorized dtypes - (UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu), + (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu), (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma), (UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc), (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store), @@ -448,7 +448,7 @@ load_store_indexing = PatternMatcher([ # late fixup of unfoldable image loads (UPat(Ops.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load), # simplify valid - (UPat(Ops.ALU, name="valid", arg=BinaryOps.AND), simplify_valid), + (UPat(Ops.AND, name="valid"), simplify_valid), # image load valid idx simplification (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load), # delete_redundant_gates (after expand) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 43702fa4d4..0de75c30df 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,7 +1,7 @@ import sys, atexit, functools, itertools from collections import defaultdict, deque from dataclasses import dataclass, field -from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict +from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint from tinygrad.helpers import DEBUG, Metadata, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap from tinygrad.dtype import ImageDType, dtypes @@ -72,9 +72,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> elif buf.op is Ops.CONTIGUOUS: ret = UOp(Ops.CONTIGUOUS, dtype, src) elif buf.op is Ops.ASSIGN: ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg) elif buf.op in GroupOp.Meta: ret = UOp(buf.op, buf.dtype, (ubuf, *src), buf.arg) - elif buf.op is Ops.CAST: ret = UOp(Ops.CAST, dtype, src) - elif buf.op is Ops.BITCAST: ret = UOp(Ops.BITCAST, dtype, src) - else: ret = UOp(Ops.ALU, dtype, src, buf.op) + else: ret = UOp(cast(Ops, buf.op), dtype, src) cache[buf] = ret = UOp(Ops.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret))) if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata if buf.forced_realize: 
ctx.realizes[ubuf] = ubuf @@ -142,7 +140,7 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),), # push VIEW to loads view_left = merge_views+PatternMatcher([ # view before ALU - (UPat(Ops.VIEW, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Buffer), name="e"),), name="v"), + (UPat(Ops.VIEW, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Buffer), name="e"),), name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))), ]) @@ -156,7 +154,7 @@ view_right = merge_views+PatternMatcher([ # push a VIEW down to STORE, through a reduce (ONLY reshapes) (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce), # push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes) - (UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise), + (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise), (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), ]) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 68f92ffcd4..49a8b94ccf 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -169,7 +169,11 @@ class Ops(FastEnum): CONST = auto() class GroupOp: + Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG} Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB} + Ternary = {Ops.WHERE, Ops.MULACC} + ALU = set.union(Unary, Binary, Ternary) + Reduce = {Ops.SUM, Ops.PROD, Ops.REDUCE_MAX} # meta ops @@ -224,6 +228,8 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s class UOpMetaClass(type): ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary() def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None): + # TODO: remove this + if op is Ops.ALU: op, arg = arg, None if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg) return ret @@ -232,11 +238,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass): __slots__ = ["op", "dtype", "src", "arg"] def __init__(self, op:Ops, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None): # TODO: instant check rules here make debugging easier - #assert op in UOps and isinstance(dtype, DType), f"bad UOp creation with {op} {dtype}" - #if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool - #if op is UOps.VECTORIZE and dtype != dtypes.void: assert len(src) == dtype.count, f"{len(src)} invalid for {dtype}" - #if op is UOps.ALU and arg not in (BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE): assert all_same([dtype] + [x.dtype for x in src]) - #if op is UOps.CAST: assert dtype.count == src[0].dtype.count, f"cast can't change vectorization {src[0].dtype} --> {dtype}" self.op, self.dtype, self.src, self.arg = op, dtype, src, arg def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg) def replace(self, **kwargs) -> UOp: @@ -256,7 +257,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @functools.cached_property def tuplize(self:UOp) -> Tuple[int, Any, Optional[DType], Tuple]: - return (self.op.value, self.arg.value if self.op is Ops.ALU else self.arg, 
self.dtype, tuple(x.tuplize for x in self.src)) + return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src)) # *** uop shape stuff *** @@ -334,7 +335,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): out_dtype = (self, *src)[-1].dtype if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool - return UOp(Ops.ALU, out_dtype, (self,)+src, arg) + return UOp(arg, out_dtype, (self,)+src) @staticmethod def const(dtype:DType, b:ConstLike): if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b @@ -382,19 +383,17 @@ class UOp(MathTrait, metaclass=UOpMetaClass): """largest known int that divides self""" if self.op is Ops.CONST: return self.arg if self.op is Ops.VCONST: return functools.reduce(math.gcd, self.arg) - if self.op is Ops.ALU: - if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor()) - if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1 + if self.op is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor()) + if self.op is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1 return 1 def divides(self, v) -> Optional[UOp]: if v==1: return self if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None - if self.op is Ops.ALU: - if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None - if self.arg is BinaryOps.MUL: - if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1] - if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1 + if self.op is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None + if self.op is BinaryOps.MUL: + if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1] + if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1 return None # generic None if we aren't sure @property def vmin(self) -> ConstType: return self._min_max[0] @@ -411,25 +410,25 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype) if self.op is Ops.CONST: return self.arg, self.arg if self.op is Ops.VCONST: return (min(self.arg), max(self.arg)) - if self.op is Ops.ALU and not dtypes.is_float(self.dtype): + if self.op in GroupOp.ALU and not dtypes.is_float(self.dtype): s0,s1,s2 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(3)] - if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax - if self.arg is BinaryOps.MUL: return min(vals:=(s0.vmin*s1.vmin, s0.vmin*s1.vmax, s0.vmax*s1.vmin, s0.vmax*s1.vmax)), max(vals) - if self.arg is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1 - if self.arg is BinaryOps.IDIV and s1.op is Ops.CONST: + if self.op is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax + if self.op is BinaryOps.MUL: return min(vals:=(s0.vmin*s1.vmin, s0.vmin*s1.vmax, s0.vmax*s1.vmin, s0.vmax*s1.vmax)), max(vals) + if self.op is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1 + if self.op is BinaryOps.IDIV 
and s1.op is Ops.CONST: if s1.arg > 0: return s0.vmin//s1.arg, s0.vmax//s1.arg if s1.arg < 0 and s0.vmin >= 0: return -(s0.vmax//-s1.arg), -(s0.vmin//-s1.arg) - if self.arg is BinaryOps.MAX: return max(s0.vmin, s1.vmin), max(s0.vmax, s1.vmax) - if self.arg is BinaryOps.CMPLT: return (s0.vmax Tuple[sint, sint]: mem += u.dtype.itemsize * mults elif u.op is Ops.STORE: mem += u.src[1].dtype.itemsize * mults - elif u.op is Ops.ALU and u not in dont_count: - flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count + elif u.op in GroupOp.ALU and u not in dont_count: + flops += (mults * (2 if u.op is TernaryOps.MULACC else 1)) * u.dtype.count elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults return flops, mem @@ -582,9 +581,9 @@ class UPat(MathTrait): def assign(self, x:UPat): return UPat(Ops.ASSIGN, self.dtype, (self,x)) def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b)) - def alu(self, arg, *src:UPat): + def alu(self, op:Ops, *src:UPat): asrc = (self,)+src - return UPat(Ops.ALU, None if arg in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if arg in GroupOp.Commutative else asrc, arg) + return UPat(op, None if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc) def printable(self:UPat) -> str: try: return lines(self.location[0])[self.location[1]-1].strip() @@ -790,17 +789,13 @@ spec = PatternMatcher([ (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True), # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE - (UPat(Ops.ALU, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y")), arg=TernaryOps.WHERE), - lambda w,x,y: w.dtype == x.dtype == y.dtype), - (UPat(Ops.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype), - (UPat(Ops.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype), - # and SHL/SHR, the shift distance is an int - (UPat(Ops.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHL), + (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype), + (UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y"))), lambda x,y: x.dtype == y.dtype), + # and SHL/SHR, the shift distance can be an int + (UPat((Ops.SHL, Ops.SHR), src=(UPat(name="x"), UPat(name="y")), name="alu"), lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)), - (UPat(Ops.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHR), - lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)), - (UPat(Ops.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), - (UPat(Ops.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)), + (UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), + (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)), (UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True), (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True), @@ -848,7 +843,7 @@ def cast_float_to_bf16(x: UOp) -> UOp: # *** most of symbolic lives here now *** def split_uop(x:UOp, sep:Ops): - if x.op is Ops.ALU and x.arg is sep: + if 
x.op is sep: for s in x.src: yield from split_uop(s, sep) else: yield x @@ -865,7 +860,7 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]: assert divides is not None remainder.append(divides) something_changed = True - elif u.op is Ops.ALU and u.arg is BinaryOps.MOD and (s1:=u.src[1]).op is Ops.CONST and s1.arg%c == 0: + elif u.op is Ops.MOD and (s1:=u.src[1]).op is Ops.CONST and s1.arg%c == 0: remainder.append(u.src[0]) something_changed = True else: remainder.append(u) @@ -892,7 +887,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]: something_changed = True else: # divisor is the smallest common divisor of all MULs - if u.op is Ops.ALU and u.arg is BinaryOps.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor + if u.op is Ops.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor remainder.append(u) gcd = math.gcd(gcd, factor) @@ -923,11 +918,11 @@ def fold_unrolled_divs(divs:UOp): # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x add_chain, denominator, seen_const, ans = list(split_uop(divs, BinaryOps.ADD)), None, [], None for u in add_chain: - if not (u.op is Ops.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is Ops.CONST): return None + if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None if denominator is None: denominator = u.src[1].arg if denominator != u.src[1].arg: return None # assumed CONST is the last of an ADD - if (s0:=u.src[0]).op is Ops.ALU and s0.arg is BinaryOps.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST: + if (s0:=u.src[0]).op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST: seen_const.append(s0.src[1].arg) s0 = s0.src[0] else: seen_const.append(0) @@ -947,7 +942,7 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]: changed, ret = False, [] for u in split_uop(X, BinaryOps.ADD): # assumed the const is the last src of MUL - if u.op is Ops.ALU and u.arg is BinaryOps.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0: + if u.op is BinaryOps.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0: changed = True u = u.src[0] if not (is_irreducible(u) and u.vmin >= 0): return None @@ -957,8 +952,8 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]: def is_increasing(f:UOp) -> bool: # is f a monotonically increasing function regards its input if is_irreducible(f): return True - if f.op is Ops.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1]) - if f.op is Ops.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0]) + if f.op is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1]) + if f.op in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0]) return False # False if not sure def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]: @@ -966,10 +961,10 @@ def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]: # if it's X >= c, returns X, False, c # (X < c).ne(True) -> X >= c - if valid.op is Ops.ALU and valid.arg is BinaryOps.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \ - (s0:=valid.src[0]).op is Ops.ALU and s0.arg is BinaryOps.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg + if valid.op is BinaryOps.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \ + (s0:=valid.src[0]).op is BinaryOps.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg # X < 
c -> X <= c-1 - if valid.op is Ops.ALU and valid.arg is BinaryOps.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1 + if valid.op is BinaryOps.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1 raise ValueError(f"not able to parse {valid=}") def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]: @@ -989,7 +984,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]: # every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop candidates = [] - if expr.op is Ops.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and v[0] == 1 for u in split_uop(expr, BinaryOps.ADD)): + if expr.op is Ops.ADD and all(is_irreducible(u) and v[0] == 1 for u in split_uop(expr, BinaryOps.ADD)): # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, BinaryOps.ADD)]) # try checking the whole clause @@ -1043,8 +1038,8 @@ symbolic_simple = PatternMatcher([ # NOTE: this can be wrong for loaded NaN (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)), # ** constant folding ** - (UPat(Ops.ALU, name="root", src=UPat((Ops.VCONST, Ops.CONST))), - lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], truncate_output=False))), + (UPat(GroupOp.ALU, name="root", src=UPat((Ops.VCONST, Ops.CONST))), + lambda root: root.const_like(exec_alu(root.op, root.dtype, [x.arg for x in root.src], truncate_output=False))), # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly (UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y), (UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y), @@ -1056,8 +1051,7 @@ symbolic_simple = PatternMatcher([ symbolic = symbolic_simple+PatternMatcher([ # ** COMMUTATIVE flipping ** - *[(UPat(Ops.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) \ - for op in GroupOp.Commutative], + (UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None), # group like ((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y), # ** combine terms ** @@ -1070,7 +1064,7 @@ symbolic = symbolic_simple+PatternMatcher([ (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val), (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1), # ALU min==max -> CONST (slow!) - (UPat(Ops.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), + (UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), # max folding (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None), # TODO: why does this rule break beautiful_mnist? 
@@ -1097,11 +1091,11 @@ symbolic = symbolic_simple+PatternMatcher([ (((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)), lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None), # ** move add/mul consts to end (NOTE: this is still happening before constant folding) ** - (UPat(Ops.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1), - (UPat(Ops.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1), + (UPat(Ops.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1), + (UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1), # *** rules from symbolic *** # unrolled arange div folding - (UPat(Ops.ALU, name="divs", src=[UPat(), UPat(Ops.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs), + (UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs), # generic lt folding (UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None), # canonicalize a simplex with positive coefficients > 0 @@ -1132,13 +1126,11 @@ renderer = PatternMatcher([ (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")), (UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))), (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]), - (UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=UnaryOps.NEG), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")), - (UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=BinaryOps.MAX), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")), - (UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=TernaryOps.MULACC), - lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")), - (UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=TernaryOps.WHERE), - lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")), - (UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")), + (UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")), + (UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")), + (UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")), + (UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")), + (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.op]}{x.src[1].arg})")), ]) # *** what was symbolic.py *** diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 95ea4dddc9..b2938de777 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast import os, math from collections import defaultdict, Counter -from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16 +from tinygrad.ops import GroupOp, UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16 from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX from 
tinygrad.dtype import ImageDType, dtypes, DType, PtrDType from tinygrad.renderer import Renderer, TensorCore @@ -47,8 +47,8 @@ base_rewrite = PatternMatcher([ (UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"), (UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"), # alu/gep - (UPat(Ops.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.arg]( - *([strip_parens(ctx[v]) if v.arg == x.arg and x.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else ctx[v] for v in x.src]), x.dtype)), + (UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op]( + *([strip_parens(ctx[v]) if v.op == x.op and x.op in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else ctx[v] for v in x.src]), x.dtype)), (UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \ (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")), ]) @@ -61,7 +61,7 @@ extra_pm = PatternMatcher([ (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"), lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))), # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends) - (UPat(Ops.ALU, name="m", arg=BinaryOps.MAX), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])), + (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])), ]) def uops_to_dtypes(uops:List[UOp]) -> List[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType))) @@ -140,16 +140,16 @@ class CStyleLanguage(Renderer): if u.op is Ops.SPECIAL: r[u] = u.arg[0] else: - prefix = {Ops.RANGE: "ridx", Ops.ALU: "alu", Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const", + prefix = {Ops.RANGE: "ridx", Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const", Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.NOOP: "precast", - Ops.INDEX: "bidx", Ops.DEFINE_ACC: "acc", Ops.LOAD: "val"}.get(u.op, "unk") + Ops.INDEX: "bidx", Ops.DEFINE_ACC: "acc", Ops.LOAD: "val"}.get(u.op, "alu") r[u] = f"{prefix}{c[prefix]}" l = cast(str, self.string_rewrite.rewrite(u, ctx=self)) assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}" if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1 - if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or (u.op in {Ops.VECTORIZE, Ops.ALU, Ops.CAST, Ops.BITCAST} + if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or (u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA")): r[u] = l else: @@ -272,9 +272,8 @@ class MetalRenderer(CStyleLanguage): # upcast to float32 all the ops that don't support bfloat16 extra_matcher = PatternMatcher([ # NOTE: this is copied from PTX - *[(UPat(Ops.ALU, arg=op, dtype=dtypes.bfloat16, name="x"), - lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))) - for op in [UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN]] + (UPat((UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN), dtype=dtypes.bfloat16, name="x"), + lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))), ]) + extra_pm string_rewrite = PatternMatcher([ @@ -387,11 +386,11 @@ class AMDRenderer(CStyleLanguage): type_map = {dtypes.bfloat16: "hip_bfloat16"} 
extra_matcher = PatternMatcher([ # cast bfloat16 alus to float - (UPat(Ops.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))), - lambda b,x,y: UOp(Ops.ALU, arg=TernaryOps.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)), - (UPat(Ops.ALU, dtype=dtypes.bfloat16, name="x"), + (UPat(Ops.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))), + lambda b,x,y: UOp(Ops.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)), + (UPat(GroupOp.ALU, dtype=dtypes.bfloat16, name="x"), lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)), - (UPat(Ops.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))), + (UPat(GroupOp.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))), lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)), # add float intermediate casting for bfloat16 (UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None), diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 1f1682d763..9e5061a5be 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -1,7 +1,7 @@ from typing import Dict, Callable, List, Optional from llvmlite import ir from tinygrad.dtype import DType, PtrDType, dtypes -from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp +from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, GroupOp from tinygrad.renderer import Renderer MFLAGS = ('nsz', 'arcp', 'contract', 'afn') # All from fast math, but nnan and ninf and reassoc @@ -141,8 +141,8 @@ class LLVMRenderer(Renderer): backward = src[0] while backward.op is Ops.ASSIGN: backward = backward.src[0] lvars[backward] = lvars[u] - elif uop is Ops.ALU: - lvars[u] = self.code_for_op[args](bb[-1], *[lvars[x] for x in src], src[0].dtype if args in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype) + elif uop in GroupOp.ALU: + lvars[u] = self.code_for_op[uop](bb[-1], *[lvars[x] for x in src], src[0].dtype if uop in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype) elif uop in {Ops.CAST, Ops.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is Ops.BITCAST) elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]] elif uop is Ops.CONST: lvars[u] = const(args, dtype) diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index ddf2c9e645..4912c024de 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -1,7 +1,7 @@ -from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable +from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable, Tuple import struct from collections import defaultdict -from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat +from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, GroupOp from tinygrad.dtype import dtypes, DType, PtrDType, ConstType from tinygrad.renderer import Renderer from tinygrad.renderer.cstyle import CUDARenderer @@ -33,14 +33,14 @@ asm_for_op: Dict[Ops, Callable] = { } supports_half: List[Ops] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, 
BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE] +doesnt_support_half: Tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if op not in supports_half) ptx_matcher = PatternMatcher([ # bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only) (UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y), (UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y), # upcast to float32 all the ops that don't support half - *[(UPat(Ops.ALU, arg=op, dtype=dtypes.half, name="x"), - lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half))) - for op in asm_for_op.keys() if op not in supports_half], + (UPat(doesnt_support_half, dtype=dtypes.half, name="x"), + lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half))), # load/store bool -> uint8 (UPat(Ops.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True), lambda x: UOp(x.op, dtypes.uint8, x.src[0:1] + ((x.src[1].cast(dtypes.uint8),) if len(x.src) >= 2 else ()) + x.src[2:]).cast(dtypes.bool)), @@ -50,8 +50,8 @@ ptx_matcher = PatternMatcher([ (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize), (UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None), # ptx shr and shl instructions require y to be uint - (UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHL) if y.dtype != dtypes.uint else None), - (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHR) if y.dtype != dtypes.uint else None), + (UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.SHL, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None), + (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None), ]) class PTXRenderer(Renderer): @@ -166,9 +166,9 @@ class PTXRenderer(Renderer): kk(gate + f"st{mem_type}.{self.mem_types[src[1].dtype]} [{r[src[0]]}+0], {r[src[1]]};") else: if uop is Ops.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:])) - elif uop is Ops.ALU: - src_dtype = src[0].dtype if args in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype - kk(self.code_for_op[args](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype])) + elif uop in GroupOp.ALU: + src_dtype = src[0].dtype if uop in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype + kk(self.code_for_op[uop](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype])) elif uop is Ops.DEFINE_ACC: if dtype.count > 1: r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)] @@ -190,7 +190,7 @@ class PTXRenderer(Renderer): elif uop is Ops.LOAD: assert src[0].dtype == dtypes.int64, "load isn't int64" mem_type = '.shared' if src[0].op is Ops.DEFINE_LOCAL or any(x.op is Ops.DEFINE_LOCAL for x in src[0].parents) else '.global' - has_gate = len(src) > 2 and src[2].op is Ops.ALU + has_gate = len(src) > 2 and src[2].op in GroupOp.ALU if dtype.count > 1: r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)] if has_gate: diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 46a6e928da..452618f813 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -7,7 +7,7 @@ import pickle, 
base64, itertools, time, struct from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate from tinygrad.helpers import all_same, getenv, flatten from tinygrad.device import Compiled, Compiler, Allocator -from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, Ops, UOp +from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, Ops, UOp, GroupOp from tinygrad.renderer import Renderer from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer, ClangRenderer @@ -173,10 +173,10 @@ class PythonProgram: def c_map(_, elem): return (elem%16, elem//16) ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map) else: raise NotImplementedError(f"unimplemented tensor core {arg}") - elif uop is Ops.ALU: - assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}" - assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}" - ul[i] = [exec_alu(arg, dtype, p) for p in zip(*inp)] + elif uop in GroupOp.ALU: + assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {uop}" + assert all_same([dtype] + dtp) or uop in {BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {uop}" + ul[i] = [exec_alu(uop, dtype, p) for p in zip(*inp)] assert i in ul, (uop, dtype, idp, arg) i += 1 return time.perf_counter() - st diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 0a21f01097..98b3b56c64 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -76,8 +76,8 @@ class ShapeTracker: idx, valid = (graph_rewrite(u, symbolic_flat) for u in self.to_indexed_uops()) for c in split_uop(idx, BinaryOps.ADD): if c.op is Ops.RANGE: ret[c.arg] = 1 - if c.op is Ops.ALU and c.arg is BinaryOps.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg - if c.op is Ops.ALU and c.arg is BinaryOps.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg] = c.src[0].arg + if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg + if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg] = c.src[0].arg used_ranges = [x.arg for x in idx.sparents if x.op is Ops.RANGE] ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)] if not ignore_valid: diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index b37bc82188..b3778196d8 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -384,10 +384,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method def from_uop(y:UOp, **kwargs) -> Tensor: if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False) - if y.op is Ops.ALU: - if y.arg is BinaryOps.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1]) - if y.arg is BinaryOps.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1]) - if y.arg is BinaryOps.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1])) + if y.op is BinaryOps.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1]) + if y.op is BinaryOps.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1]) + if y.op is BinaryOps.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1])) raise RuntimeError(f"unhandled UOp {y}") # ***** creation entrypoint ***** diff --git 
a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index a0616a2d44..fe867e26af 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -5,13 +5,13 @@ from urllib.parse import parse_qs, urlparse from dataclasses import asdict, dataclass from typing import Any, Dict, List, Tuple, Optional from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap -from tinygrad.ops import TrackedRewriteContext, UOp, Ops, lines +from tinygrad.ops import TrackedRewriteContext, UOp, Ops, lines, GroupOp from tinygrad.codegen.kernel import Kernel -uops_colors = {Ops.ALU: "#ffffc0", Ops.LOAD: "#ffc0c0", Ops.STORE: "#c0ffc0", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", +uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#c0ffc0", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE: "#C4A484", Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", - Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.REDUCE_AXIS: "#f58488"} + Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.REDUCE_AXIS: "#f58488", **{x:"#ffffc0" for x in GroupOp.ALU}} # ** API spec
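
A few standalone sketches of the patterns these hunks rely on. The cstyle.py changes move the op off the generic ALU node's arg and onto the node itself: the renderer keys code_for_op on x.op, matches arithmetic ops by GroupOp.ALU membership, and lowers MAX through a compare-plus-select rather than a max() call. A minimal sketch of that shape (the Ops enum, GroupOp tuple, and code_for_op table below are simplified stand-ins invented for illustration, not tinygrad's real definitions):

from enum import Enum, auto

class Ops(Enum):
    ADD = auto(); MUL = auto(); MAX = auto(); LOAD = auto()

class GroupOp:
    ALU = (Ops.ADD, Ops.MUL, Ops.MAX)   # membership tuple, analogous to GroupOp.ALU

code_for_op = {Ops.ADD: "+", Ops.MUL: "*"}

def render(op: Ops, srcs: list[str]) -> str:
    # the op now lives on the node itself, so rendering keys on `op`, not on an ALU arg
    assert op in GroupOp.ALU
    if op is Ops.MAX:
        # same idea as the extra_pm rule above: lower MAX to a compare plus a select
        a, b = srcs
        return f"(({a}<{b})?{b}:{a})"
    return "(" + code_for_op[op].join(srcs) + ")"

assert render(Ops.MAX, ["x", "y"]) == "((x<y)?y:x)"
assert render(Ops.ADD, ["a", "b", "c"]) == "(a+b+c)"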
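
The MetalRenderer and AMDRenderer hunks apply one technique for bfloat16: match an ALU op by group membership, cast its sources to float, run the op in float, and cast the result back. A rough standalone model of that rewrite, using an invented Node dataclass in place of UOp and a plain tuple in place of GroupOp.ALU:

from dataclasses import dataclass
from enum import Enum, auto

class Ops(Enum):
    ADD = auto(); SQRT = auto(); CAST = auto(); CONST = auto()

ALU = (Ops.ADD, Ops.SQRT)        # stand-in for the matched group

@dataclass(frozen=True)
class Node:                      # toy UOp: an op, a dtype, and source nodes
    op: Ops
    dtype: str
    src: tuple = ()

def upcast_bf16(n: Node):
    # the shape of the Metal/AMD rules: do the ALU in float, cast the result back
    if n.op in ALU and n.dtype == "bfloat16":
        floated = tuple(Node(Ops.CAST, "float", (s,)) for s in n.src)
        return Node(Ops.CAST, "bfloat16", (Node(n.op, "float", floated),))
    return None                  # no match: leave the node for other rules

x = Node(Ops.ADD, "bfloat16", (Node(Ops.CONST, "bfloat16"), Node(Ops.CONST, "bfloat16")))
y = upcast_bf16(x)
assert y is not None and y.op is Ops.CAST and y.src[0].dtype == "float"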
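
In ptx.py, the loop that generated one rewrite rule per half-unsupported op is replaced by a single pattern over a precomputed tuple. Sketching the tuple construction and the membership test with toy string keys instead of real Ops (asm_for_op and supports_half below are placeholders, not the actual PTX tables):

# toy stand-ins: the real asm_for_op maps tinygrad Ops to PTX emitters
asm_for_op = {"exp2": None, "add": None, "mul": None, "sqrt": None, "sin": None}
supports_half = ["exp2", "add", "mul"]

# one tuple computed up front replaces emitting a separate rewrite rule per unsupported op;
# a single UPat can then match any op in this tuple
doesnt_support_half = tuple(op for op in asm_for_op if op not in supports_half)
assert doesnt_support_half == ("sqrt", "sin")

def needs_float_upcast(op: str, dtype: str) -> bool:
    # mirrors the single ptx_matcher rule: only half-typed ops outside supports_half are upcast
    return dtype == "half" and op in doesnt_support_half

assert needs_float_upcast("sin", "half") and not needs_float_upcast("add", "half")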
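
The shapetracker.py hunk reads strides straight off MUL nodes (c.op is Ops.MUL) instead of checking for a generic ALU node whose arg happens to be MUL. A toy version of that stride scan over a flattened index expression, with plain tuples standing in for the UOps returned by split_uop:

# each term of the flattened index is either a bare RANGE (stride 1) or a MUL of a
# RANGE and a CONST, which is where the stride for that range comes from
terms = [
    ("RANGE", 0),                              # ridx0 * 1
    ("MUL", ("RANGE", 1), ("CONST", 4)),       # ridx1 * 4
    ("MUL", ("CONST", 16), ("RANGE", 2)),      # 16 * ridx2
]

strides = {}
for t in terms:
    if t[0] == "RANGE":
        strides[t[1]] = 1
    elif t[0] == "MUL":
        a, b = t[1], t[2]
        if a[0] == "RANGE" and b[0] == "CONST": strides[a[1]] = b[1]
        if b[0] == "RANGE" and a[0] == "CONST": strides[b[1]] = a[1]

assert strides == {0: 1, 1: 4, 2: 16}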
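
Finally, viz/serve.py can no longer color a single Ops.ALU node, so the color table expands the group into one entry per concrete op. The same dict-unpacking idiom in isolation (Ops and ALU below are again invented stand-ins):

from enum import Enum, auto

class Ops(Enum):
    ADD = auto(); MUL = auto(); LOAD = auto(); STORE = auto()

ALU = (Ops.ADD, Ops.MUL)         # stand-in for GroupOp.ALU

# the former single Ops.ALU entry becomes one entry per concrete ALU op
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#c0ffc0", **{x: "#ffffc0" for x in ALU}}
assert uops_colors[Ops.MUL] == "#ffffc0" and uops_colors[Ops.LOAD] == "#ffc0c0"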