Ops.ALU is no more, the arg is just an op (#7525)

* op arg alu [pr]

* more

* more passing

* fix more tests

* more tests passing

* fix single failing test

* so much cleaner

* noop to not have process replay trigger

* fix ptx
George Hotz
2024-11-05 00:22:22 +08:00
committed by GitHub
parent 36488a2a43
commit 99bd4372a5
22 changed files with 171 additions and 185 deletions
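
The gist of the refactor, as a minimal sketch against the post-PR API shown in the hunks below: an ALU node's op is now the concrete operation itself instead of a generic Ops.ALU with the operation stashed in arg, and "is this an ALU?" checks become membership tests on the GroupOp sets.

from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops, GroupOp

a, b = UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)
# before: add = UOp(Ops.ALU, dtypes.int, (a, b), arg=BinaryOps.ADD)
add = UOp(Ops.ADD, dtypes.int, (a, b))   # after: the op is the operation, arg stays None
assert add.op is Ops.ADD and add.arg is None
assert add.op in GroupOp.ALU             # replaces the old `add.op is Ops.ALU` check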


@@ -252,7 +252,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2, opts_list=None):
def _is_simple(lin: Kernel) -> bool:
if len(lin.ast.src) > 1: return False
ast:UOp = lin.ast.src[0]
if ast.src[0].arg is UnaryOps.CAST and ast.src[0].src[0].op is Ops.LOAD: return True
if ast.src[0].op is UnaryOps.CAST and ast.src[0].src[0].op is Ops.LOAD: return True
return False
if __name__ == "__main__":


@@ -8,7 +8,7 @@ from tinygrad.dtype import DType
from tinygrad.helpers import CI, getenv
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from tinygrad.ops import UnaryOps, Ops
from tinygrad.ops import GroupOp
from tinygrad.tensor import _to_np_dtype
from test.helpers import is_dtype_supported
import pytest
@@ -79,7 +79,7 @@ def universal_test_unary(a, dtype, op):
np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
else: np.testing.assert_equal(tensor_value, numpy_value)
if op[0] != Tensor.reciprocal: # reciprocal is not supported in most backends
op = [x for x in ast.parents if x.op is Ops.ALU and x.arg in UnaryOps][0]
op = [x for x in ast.parents if x.op in GroupOp.Unary][0]
assert op.dtype == dtype
def universal_test_cast(a, in_dtype, dtype):


@@ -6,7 +6,7 @@ from dataclasses import replace
from test.helpers import ast_const
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel
from tinygrad.codegen.lowerer import get_grouped_dims
from tinygrad.ops import UOp, Ops, BinaryOps, TernaryOps, UnaryOps
from tinygrad.ops import UOp, Ops, BinaryOps, TernaryOps, UnaryOps, GroupOp
from tinygrad.device import Device, Buffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
@@ -865,9 +865,9 @@ class TestLinearizer(unittest.TestCase):
lin = helper_linearizer_opt(out, wanna_output=[24])[0]
ranges = [i for i,u in enumerate(lin.uops) if u.op is Ops.RANGE]
# RANGE -> ALU -> RANGE -> ALU + LOAD -> ASSIGN
assert any(x.op is Ops.ALU for x in lin.uops[ranges[0]:ranges[1]])
assert any(x.op in GroupOp.ALU for x in lin.uops[ranges[0]:ranges[1]])
assert not any(x.op is Ops.LOAD for x in lin.uops[ranges[0]:ranges[1]])
assert any(x.op in {Ops.ALU, Ops.LOAD} for x in lin.uops[ranges[1]:])
assert any(x.op in {*GroupOp.ALU, Ops.LOAD} for x in lin.uops[ranges[1]:])
def test_range_outer_op_before_phi(self):
a = Tensor.randn(4, 1).realize()
@@ -902,7 +902,7 @@ class TestLinearizer(unittest.TestCase):
lin = helper_linearizer_opt(out, wanna_output=[a.numpy().sum()*a.numpy().sum()])[0]
# RANGE -> LOAD -> ASSIGN -> ALU
end = max(i for i,u in enumerate(lin.uops) if u.op is Ops.ENDRANGE)
assert lin.uops[end+1].op is Ops.ALU
assert lin.uops[end+1].op in GroupOp.ALU
def test_range_outer_op_after_phi_nested_range(self):
a = Tensor.randn(2, ).realize()
@@ -910,7 +910,7 @@ class TestLinearizer(unittest.TestCase):
lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3))).sum()*2])[0]
# RANGE -> LOAD -> ASSIGN -> ALU
end = max(i for i,u in enumerate(lin.uops) if u.op is Ops.ENDRANGE)
assert lin.uops[end+1].op is Ops.ALU
assert lin.uops[end+1].op in GroupOp.ALU
def test_load_dedup(self):
# for different leaves in the AST, the same loads may occur.
@@ -1141,7 +1141,7 @@ class TestLinearizer(unittest.TestCase):
# the uops graph is RANGE -> DEFINE_ACC -> 4x ALU -> 4x ASSIGN -> ENDRANGE
for u in k.uops:
if u.op is Ops.ASSIGN:
assert u.src[1].op is Ops.ALU
assert u.src[1].op in GroupOp.ALU
# children of ASSIGN are placed after ENDRANGE
if any(x.op is Ops.ASSIGN for x in u.src):
end_range = [i for i, x in enumerate(k.uops) if x.op is Ops.ENDRANGE][0]
@@ -1219,7 +1219,7 @@ class TestLinearizer(unittest.TestCase):
assert len(sched) == 1
lin = Kernel(sched[0].ast)
assert sum(u.arg is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg
assert sum(u.op is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg
a = Tensor.empty((4,4))
b = Tensor.empty((4,4))
@@ -1260,7 +1260,7 @@ class TestLinearizer(unittest.TestCase):
lin = Kernel(sched_copy[-1].ast)
lin.hand_coded_optimizations()
lin.linearize()
assert not any(u.arg == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded"
assert not any(u.op == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded"
def test_phi_simplification(self):
def helper(t, max_ops=0):
@@ -1272,7 +1272,7 @@ class TestLinearizer(unittest.TestCase):
assert len(set([u.op for u in uops if u.op in {Ops.RANGE, Ops.SPECIAL}])) == 1, "has either specials or ranges, not both"
assert len([u for u in uops if u.op is Ops.ASSIGN]) == 0, "ASSIGN should have been simplified"
# TODO: once uops track min/max this will be fixed
#assert len([u for u in uops if u.arg is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"
#assert len([u for u in uops if u.op is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"
helper(Tensor.arange(5.5, (3.5*300), 3.5), max_ops=2)
helper(Tensor.arange(-1, -100, -5), max_ops=2)


@@ -75,7 +75,7 @@ class TestLinearizerDumb(unittest.TestCase):
for opt in opts: k.apply_opt(opt)
prg = k.to_program()
print(prg.src)
assert prg.uops is not None and not any(uop.op is Ops.ALU and uop.arg is BinaryOps.MAX for uop in prg.uops), "leftover MAX"
assert prg.uops is not None and not any(uop.op is BinaryOps.MAX for uop in prg.uops), "leftover MAX"
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
def test_expander_new_srcs(self):


@@ -620,21 +620,21 @@ class TestMultiTensor(unittest.TestCase):
for si in t.schedule():
ast = si.ast.src[0]
assert ast.op is Ops.STORE
assert ast.src[2].arg is BinaryOps.ADD
assert ast.src[2].op is BinaryOps.ADD
assert ast.src[2].src[0].op is Ops.LOAD
assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 1
t = 2 * t
for si in t.schedule():
ast = si.ast.src[0]
assert ast.op is Ops.STORE
assert ast.src[2].arg is BinaryOps.MUL
assert ast.src[2].op is BinaryOps.MUL
assert ast.src[2].src[0].src[1].op is Ops.CONST and ast.src[2].src[0].src[1].arg == 2
assert ast.src[2].src[1].op is Ops.LOAD
t = t + t.full_like(3)
for si in t.schedule():
ast = si.ast.src[0]
assert ast.op is Ops.STORE
assert ast.src[2].arg is BinaryOps.ADD
assert ast.src[2].op is BinaryOps.ADD
assert ast.src[2].src[0].op is Ops.LOAD
assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 3


@@ -1040,7 +1040,7 @@ class TestSchedule(unittest.TestCase):
b = r.sum(0) * 4
c = r.sum(1) * 2
schedule = check_schedule([b, c], 3)
self.assertIs(schedule[0].ast.src[0].src[2].arg, BinaryOps.ADD)
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
# multireduce spec
def test_multireduce_simple_chase(self):
@@ -1064,7 +1064,7 @@ class TestSchedule(unittest.TestCase):
d = r.T * 4
e = r * d
schedule = check_schedule([d, e], 3)
self.assertIs(schedule[0].ast.src[0].src[2].arg, BinaryOps.ADD)
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
# multireduce spec
def test_multireduce_push_permute_chase(self):
@@ -1075,7 +1075,7 @@ class TestSchedule(unittest.TestCase):
d = r.T * 4
e = r * (d + a).sum(2)
schedule = check_schedule([d, e], 3) # make sure it doesn't fuse
self.assertIs(schedule[0].ast.src[0].src[2].arg, BinaryOps.ADD)
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
run_schedule(schedule)
np.testing.assert_allclose(d.numpy(), (a.numpy().sum(2) + b.numpy()).T * 4, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(e.numpy(), (a.numpy().sum(2) + b.numpy()) * (d.numpy() + a.numpy()).sum(2), atol=1e-4, rtol=1e-4)
@@ -1087,7 +1087,7 @@ class TestSchedule(unittest.TestCase):
r = a.sum(1) + c
d = r[:4] * b
schedule = check_schedule(d, 2)
self.assertIs(schedule[0].ast.src[0].src[2].arg, BinaryOps.ADD)
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
# multireduce spec
def test_multireduce_push_shrink_chase(self):
@@ -1100,7 +1100,7 @@ class TestSchedule(unittest.TestCase):
out = r[:4] * b + d.sum(1)[:4]
# schedule = check_schedule(out, 2)
schedule = check_schedule(out, 3)
self.assertIs(schedule[0].ast.src[0].src[2].arg, BinaryOps.ADD)
self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
run_schedule(schedule)
np.testing.assert_allclose(out.numpy(), (a.numpy().sum(1) + c.numpy())[:4] * b.numpy() + d.numpy().sum(1)[:4], atol=1e-4, rtol=1e-4)


@@ -161,7 +161,7 @@ class TestGraphRewrite(unittest.TestCase):
c1 = UOp.const(dtypes.float, 1.0)
c2 = UOp.const(dtypes.float, 2.0)
nout = graph_rewrite(v+c1+c2, simple_pm)
self.assertEqual(nout.op, Ops.ALU)
self.assertEqual(nout.op, Ops.ADD)
self.assertEqual(nout.src[0].op, Ops.DEFINE_VAR)
self.assertEqual(nout.src[1].op, Ops.CONST)
self.assertEqual(nout.src[1].arg, 3.0)
@@ -182,11 +182,11 @@ class TestGraphRewrite(unittest.TestCase):
b = UOp.variable('b', 0, 1)
c = UOp.variable('c', 0, 1)
d = UOp.variable('d', 0, 1)
outs = [2+a, 2+a+d+3+b+c+4, UOp(Ops.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
outs = [2+a, 2+a+d+3+b+c+4, UOp(Ops.ADD, a.dtype, src=(a.const_like(2), a)), (4+d)+c+(2+a)+b]
for out in outs:
sink = graph_rewrite(out, sym)
print(sink.render())
self.assertEqual(sink.op, Ops.ALU)
self.assertEqual(sink.op, Ops.ADD)
self.assertEqual(sink.src[1].op, Ops.CONST)
self.assertEqual(len([x for x in sink.sparents if x.op is Ops.CONST]), 1)
@@ -380,8 +380,7 @@ class TestUOpGraph(unittest.TestCase):
uops = to_uops_list([out])
self.assertEqual(len(uops), 3)
out = uops[-1]
self.assertEqual(out.op, Ops.ALU)
self.assertEqual(out.arg, BinaryOps.ADD)
self.assertEqual(out.op, BinaryOps.ADD)
self.assertEqual(out.src[1].op, Ops.CONST)
self.assertEqual(out.src[1].arg, 6)


@@ -336,8 +336,8 @@ class TestAssembly(unittest.TestCase):
a2 = UOp(Ops.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops[-1].arg, BinaryOps.SHL)
self.assertEqual(uops[-2].arg, BinaryOps.MUL)
self.assertEqual(uops[-1].op, BinaryOps.SHL)
self.assertEqual(uops[-2].op, BinaryOps.MUL)
def test_bitshift_right(self):
g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
@@ -348,8 +348,8 @@ class TestAssembly(unittest.TestCase):
a2 = UOp(Ops.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops[-1].arg, BinaryOps.SHR)
self.assertEqual(uops[-2].arg, BinaryOps.IDIV)
self.assertEqual(uops[-1].op, BinaryOps.SHR)
self.assertEqual(uops[-2].op, BinaryOps.IDIV)
class TestUOpMethod(unittest.TestCase):
@unittest.skip("uops lt no longer ordered")


@@ -1,7 +1,7 @@
import unittest, math
from tinygrad import dtypes
from tinygrad.helpers import all_same
from tinygrad.ops import UOp, Ops, BinaryOps, exec_alu
from tinygrad.ops import GroupOp, UOp, Ops, BinaryOps, exec_alu
from tinygrad.codegen.uopgraph import full_graph_rewrite
# Helper function to apply the graph rewrite
@@ -14,9 +14,9 @@ def evaluate_uop(uop, variables):
elif uop.op == Ops.DEFINE_VAR:
var_name = uop.arg[0]
return variables[var_name]
elif uop.op == Ops.ALU:
elif uop.op in GroupOp.ALU:
src_values = [evaluate_uop(src, variables) for src in uop.src]
return exec_alu(uop.arg, uop.dtype, src_values)
return exec_alu(uop.op, uop.dtype, src_values)
else:
raise NotImplementedError(f"Unsupported UOp {uop.op}")
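
Since the op now lives directly on the node, exec_alu is keyed by the op rather than by the old ALU arg. A tiny sketch, assuming the (op, dtype, operands) signature used in the hunk above:

from tinygrad.dtype import dtypes
from tinygrad.ops import exec_alu, Ops

assert exec_alu(Ops.ADD, dtypes.int, [3, 4]) == 7    # the concrete op picks the python implementation
assert exec_alu(Ops.MUL, dtypes.int, [3, 4]) == 12
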
@@ -104,8 +104,7 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
def test_full_graph_rewrite_division_folding_with_define_var(self):
n_var_uop = UOp.variable('n', 1, 1000)
optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3)
self.assertEqual(optimized_div_uop.op, Ops.ALU)
self.assertEqual(optimized_div_uop.arg, BinaryOps.MUL)
self.assertEqual(optimized_div_uop.op, BinaryOps.MUL)
self.assertEqual(optimized_div_uop.src[1].arg, 2)
def test_full_graph_rewrite_complex_mod_div_folding(self):
@@ -115,7 +114,7 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
self.assertEqual(optimized_div_uop.arg, 1)
def test_graph_rewrite_div_folding_bug(self):
lhs = UOp(Ops.ALU, dtypes.int.vec(4), arg=BinaryOps.ADD, src=(
lhs = UOp(Ops.ADD, dtypes.int.vec(4), src=(
UOp(Ops.VECTORIZE, dtypes.int.vec(4), arg=None, src=(UOp(Ops.SPECIAL, dtypes.int, arg=('lidx0', 32), src=()),)*4),
UOp(Ops.VCONST, dtypes.int.vec(4), arg=(0, 256, 512, 768), src=())))
rhs = UOp.const(dtypes.int.vec(4), 2)


@@ -1,6 +1,6 @@
import unittest, itertools
from tinygrad.dtype import dtypes
from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps, GroupOp # noqa: F401
from tinygrad.ops import PatternMatcher, UPat
class TestPatternMatcher(unittest.TestCase):
@@ -72,7 +72,7 @@ class TestPatternMatcher(unittest.TestCase):
matcher = PatternMatcher([
(UPat(Ops.CONST, arg=0, name="x"), lambda x: x),
(UPat(Ops.CONST, arg=False, name="x"), lambda x: x),
(UPat(Ops.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
(UPat(Ops.MAX, name="x"), lambda x: x),
])
c1 = UOp(Ops.CONST, dtypes.float, arg=0.0)
c2 = UOp(Ops.CONST, dtypes.bool, arg=False)
@@ -87,7 +87,7 @@ class TestPatternMatcher(unittest.TestCase):
def test_filter_arg(self):
matcher = PatternMatcher([
(UPat(Ops.ALU, arg=BinaryOps.MUL, src=[UPat(Ops.CONST, name="c"), UPat(Ops.CONST, arg=2)], name="x"),
(UPat(Ops.MUL, src=[UPat(Ops.CONST, name="c"), UPat(Ops.CONST, arg=2)], name="x"),
lambda x,c: x if c.arg in {1, -1} else None)
])
y1 = UOp(Ops.CONST, dtypes.int, arg=1)
@@ -105,11 +105,11 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c5), c5)
def test_dup_name(self):
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST, name="y"), UPat(Ops.CONST, name="y"))), lambda x, y: x)])
matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=(UPat(Ops.CONST, name="y"), UPat(Ops.CONST, name="y"))), lambda x, y: x)])
y1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
y2 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c1 = UOp(Ops.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
c2 = UOp(Ops.ALU, dtypes.float, (y1, y2), BinaryOps.ADD)
c1 = UOp(Ops.ADD, dtypes.float, (y1, y1))
c2 = UOp(Ops.ADD, dtypes.float, (y1, y2))
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), c1)
@@ -132,7 +132,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c4), None)
def test_src_one(self):
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST), UPat(Ops.CONST))), lambda x: x)])
matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=(UPat(Ops.CONST), UPat(Ops.CONST))), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
@@ -149,7 +149,7 @@ class TestPatternMatcher(unittest.TestCase):
"""
def test_src_permutations(self):
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=[UPat(Ops.CONST), UPat(Ops.ALU)]), lambda x: x)])
matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=[UPat(Ops.CONST), UPat(GroupOp.ALU)]), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
@@ -162,7 +162,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c6), None)
def test_src_repeat(self):
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=UPat(Ops.CONST)), lambda x: x)])
matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=UPat(Ops.CONST)), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
@@ -171,7 +171,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c4), None)
def test_allow_len(self):
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)])
matcher = PatternMatcher([(UPat(Ops.MULACC, name="x", src=(UPat(Ops.CONST),), allow_any_len=True), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.CONST, dtypes.float, arg=3.0)
@@ -188,7 +188,7 @@ class TestPatternMatcher(unittest.TestCase):
u1 = (c1 + c2) + c1
u2 = (c2 + c1) + c1
matcher = PatternMatcher([
(UPat(Ops.ALU, src=[UPat(Ops.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b)
(UPat(GroupOp.ALU, src=[UPat(GroupOp.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b)
])
self.assertIsNotNone(matcher.rewrite(u1))
self.assertIsNotNone(matcher.rewrite(u2))
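
The net effect on patterns, as a small sketch (assuming the UPat/PatternMatcher behavior these tests exercise): a UPat built from the GroupOp.ALU set matches any concrete ALU op, so the old arg= filters on Ops.ALU are no longer needed.

from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops, GroupOp, UPat, PatternMatcher

matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x"), lambda x: x)])
c1, c2 = UOp.const(dtypes.float, 1.0), UOp.const(dtypes.float, 2.0)
add = UOp(Ops.ADD, dtypes.float, (c1, c2))
assert matcher.rewrite(add) is add   # any op in GroupOp.ALU matches
assert matcher.rewrite(c1) is None   # CONST is not an ALU op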


@@ -276,7 +276,7 @@ class Kernel:
if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None
mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
if mul_op.arg is not BinaryOps.MUL: return None
if mul_op.op is not BinaryOps.MUL: return None
def buf_index(src:UOp) -> Optional[int]:
# TODO: apply tc even if the sources are not from LOAD
@@ -442,7 +442,7 @@ class Kernel:
check(axis < self.first_upcast, "cannot pad upcasted")
# ok to pad SUM if all parent ALU ops have f(0) = 0
if (r:=self.reduceop) is not None and self.first_reduce <= axis:
check(r.arg[0] is BinaryOps.ADD and not any(u.op is Ops.ALU and u.arg in GroupOp.UnsafePad for u in r.parents), "cannot pad UnsafePad")
check(r.arg[0] is BinaryOps.ADD and not any(u.op in GroupOp.UnsafePad for u in r.parents), "cannot pad UnsafePad")
padded = False
for i,st in enumerate(self.sts):
if (s:=st.shape[axis]) == 1: continue # reduced
@@ -472,7 +472,7 @@ class Kernel:
MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
self.reduceop is not None and self.reduceop.arg[0] is BinaryOps.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \
(mulop:=self.reduceop.src[0]).arg is BinaryOps.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
(mulop:=self.reduceop.src[0]).op is BinaryOps.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
st0, st1 = self.sts[self.bufs.index(mulop.src[0])], self.sts[self.bufs.index(mulop.src[1])]
strides0, strides1 = st0.real_strides(), st1.real_strides()
def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
@@ -628,7 +628,7 @@ class Kernel:
if op in self.bufs_for_tensor_core and (tc := self.tensor_core):
rsrc = op.src[0]
if rsrc.op is Ops.CAST: rsrc = rsrc.src[0]
assert rsrc.op is Ops.ALU and rsrc.arg is BinaryOps.MUL
assert rsrc.op is Ops.MUL
def fix_st(warp_dims, tcd_dims, tcd_expand, pattern_1, pattern_2, st1):
wd, tcd = self.global_dims, self.first_upcast


@@ -1,6 +1,6 @@
from typing import List, Set, Dict, Tuple
import functools, heapq
from tinygrad.ops import type_verify, END_FOR_UOP, UOp, Ops
from tinygrad.ops import type_verify, END_FOR_UOP, UOp, Ops, GroupOp
from tinygrad.dtype import dtypes
from tinygrad.helpers import DEBUG
@@ -54,7 +54,7 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
# prevent priority inversion
@functools.lru_cache(None)
def fix_priority(u:UOp, lowest_priority):
if u.op in {Ops.CAST, Ops.BITCAST, Ops.ALU, Ops.VECTORIZE, Ops.GEP, Ops.SPECIAL, Ops.DEFINE_LOCAL, Ops.LOAD}:
if u.op in {Ops.CAST, Ops.BITCAST, *GroupOp.ALU, Ops.VECTORIZE, Ops.GEP, Ops.SPECIAL, Ops.DEFINE_LOCAL, Ops.LOAD}:
priorities[u] = min(priorities[u], lowest_priority)
if u.op is Ops.LOAD: priorities[u] += 100 # load penalty (here)
for x in u.src: fix_priority(x, priorities[u])


@@ -4,7 +4,7 @@ import functools, itertools, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, ImageDType, PtrDType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple
from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid
from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
@@ -23,7 +23,7 @@ def fold_expanded(ex, buf):
for i,s in enumerate(new_srcs):
idx = s.src[0].src[1]
if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue
if idx.arg is BinaryOps.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
if idx.op is BinaryOps.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
# add gates for gated
@@ -124,18 +124,18 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
powers_of_two = {2**i:i for i in range(64)}
@functools.lru_cache(None)
def get_late_rewrite_patterns(ops, force_transcendental=False):
pat: List[Tuple[UPat, Callable]] = [(UPat(Ops.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=op), f) for op,f in \
pat: List[Tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if op not in ops or force_transcendental]
# rewrite MOD to AND (which should always be supported, but not for generic in tests)
if BinaryOps.AND in ops:
pat += [(UPat(Ops.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
pat += [(UPat(Ops.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
lambda base,const: base & (const.arg-1) if const.arg in powers_of_two else None)]
# rewrite MUL/IDIV to SHL+SHR
if BinaryOps.SHL in ops and BinaryOps.SHR in ops:
pat += [
(UPat(Ops.ALU, arg=BinaryOps.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const:
(UPat(Ops.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const:
mul << powers_of_two[const.arg] if const.arg in powers_of_two else None), # (x * (2**y)) -> shl(x,y)
(UPat(Ops.ALU, arg=BinaryOps.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const:
(UPat(Ops.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const:
div >> powers_of_two[const.arg] if const.arg in powers_of_two else None)] # (x // (2**y)) -> shr(x,y)
if UnaryOps.NEG in ops:
pat += [(UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))]
@@ -226,8 +226,8 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.sparents)
if len(reduce_unparented) == 0: return None
new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented))
ret = new_acc.assign(new_acc.alu(alu.arg, ret))
if alu.arg is BinaryOps.ADD:
ret = new_acc.assign(new_acc.alu(alu.op, ret))
if alu.op is BinaryOps.ADD:
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
return ret
@@ -249,8 +249,8 @@ sym = symbolic_flat+PatternMatcher([
(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
(UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
# reorder ALU/VECTORIZE
(UPat(Ops.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(Ops.ALU, alu.dtype.scalar(), (x,y), alu.arg),)*alu.dtype.count)),
(UPat(GroupOp.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(alu.op, alu.dtype.scalar(), (x,y)),)*alu.dtype.count)),
# VECTORIZE of a single element is just that element
(UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
# VECTORIZE void is SINK
@@ -264,7 +264,7 @@ sym = symbolic_flat+PatternMatcher([
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
# push all GEPs through ALUs (fix arange stuff)
(UPat(Ops.GEP, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
(UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg)),
# push some GEPs through WMMAs
(UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
@@ -275,15 +275,15 @@ sym = symbolic_flat+PatternMatcher([
(UPat.var("add") + UPat(Ops.WMMA, name="wmma"),
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
# threefry
(UPat(Ops.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key")), arg=BinaryOps.THREEFRY), threefry2x32),
(UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32),
# arange loop folding
(acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse),
# indexing, with cast or where
(acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
(acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
# parentless reduce
(acc_pat.assign(UPat(Ops.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.ADD, name="alu")), reduce_collapse),
(acc_pat.assign(UPat(Ops.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.MAX, name="alu")), reduce_collapse),
(acc_pat.assign(UPat(Ops.ADD, src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
(acc_pat.assign(UPat(Ops.MAX, src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
# ** self folding **
(UPat(Ops.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
(UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
@@ -403,7 +403,7 @@ expander = PatternMatcher([
(UPat(Ops.EXPAND, name="outer", src=(UPat(Ops.EXPAND, name="inner"),)),
lambda outer, inner: UOp(Ops.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
# do expansion
(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
Ops.VECTORIZE, Ops.REDUCE, Ops.IF), name="root", custom_early_reject=set([(Ops.EXPAND, None)])), do_expand),
(UPat(Ops.CONTRACT, name="con"), do_contract),
# vectorize DEFINE_ACC
@@ -433,7 +433,7 @@ def no_vectorized_acc(acc:UOp):
devectorize = PatternMatcher([
# no ALU on vectorized dtypes
(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu),
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu),
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
(UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc),
(UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
@@ -448,7 +448,7 @@ load_store_indexing = PatternMatcher([
# late fixup of unfoldable image loads
(UPat(Ops.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
# simplify valid
(UPat(Ops.ALU, name="valid", arg=BinaryOps.AND), simplify_valid),
(UPat(Ops.AND, name="valid"), simplify_valid),
# image load valid idx simplification
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
# delete_redundant_gates (after expand)


@@ -1,7 +1,7 @@
import sys, atexit, functools, itertools
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
from tinygrad.helpers import DEBUG, Metadata, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
from tinygrad.dtype import ImageDType, dtypes
@@ -72,9 +72,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) ->
elif buf.op is Ops.CONTIGUOUS: ret = UOp(Ops.CONTIGUOUS, dtype, src)
elif buf.op is Ops.ASSIGN: ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
elif buf.op in GroupOp.Meta: ret = UOp(buf.op, buf.dtype, (ubuf, *src), buf.arg)
elif buf.op is Ops.CAST: ret = UOp(Ops.CAST, dtype, src)
elif buf.op is Ops.BITCAST: ret = UOp(Ops.BITCAST, dtype, src)
else: ret = UOp(Ops.ALU, dtype, src, buf.op)
else: ret = UOp(cast(Ops, buf.op), dtype, src)
cache[buf] = ret = UOp(Ops.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret)))
if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
if buf.forced_realize: ctx.realizes[ubuf] = ubuf
@@ -142,7 +140,7 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),),
# push VIEW to loads
view_left = merge_views+PatternMatcher([
# view before ALU
(UPat(Ops.VIEW, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Buffer), name="e"),), name="v"),
(UPat(Ops.VIEW, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Buffer), name="e"),), name="v"),
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
])
@@ -156,7 +154,7 @@ view_right = merge_views+PatternMatcher([
# push a VIEW down to STORE, through a reduce (ONLY reshapes)
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
# push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes)
(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])


@@ -169,7 +169,11 @@ class Ops(FastEnum):
CONST = auto()
class GroupOp:
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB}
Ternary = {Ops.WHERE, Ops.MULACC}
ALU = set.union(Unary, Binary, Ternary)
Reduce = {Ops.SUM, Ops.PROD, Ops.REDUCE_MAX}
# meta ops
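
These GroupOp sets are what the rest of the diff switches to: a check like "u.op is Ops.ALU and u.arg in UnaryOps" becomes a plain membership test on the concrete op. A sketch using the sets defined above:

from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops, GroupOp

x = UOp.const(dtypes.float, 2.0)
neg = UOp(Ops.NEG, dtypes.float, (x,))
assert neg.op in GroupOp.Unary     # was: neg.op is Ops.ALU and neg.arg in UnaryOps
assert neg.op in GroupOp.ALU       # ALU is the union of Unary, Binary and Ternary
assert Ops.LOAD not in GroupOp.ALU
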
@@ -224,6 +228,8 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
class UOpMetaClass(type):
ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
# TODO: remove this
if op is Ops.ALU: op, arg = arg, None
if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret
UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg)
return ret
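
The "if op is Ops.ALU: op, arg = arg, None" line above is a temporary compatibility shim (hence the TODO): old-style constructions are rewritten before hitting the cache, so they dedupe with the new form. A sketch of what that buys, assuming Ops.ALU is kept as an enum member for now:

from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops

a, b = UOp.const(dtypes.int, 3), UOp.const(dtypes.int, 4)
old_style = UOp(Ops.ALU, dtypes.int, (a, b), Ops.MUL)  # rewritten by the metaclass to (Ops.MUL, ..., arg=None)
new_style = UOp(Ops.MUL, dtypes.int, (a, b))
assert old_style is new_style                          # both hit the same ucache entry
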
@@ -232,11 +238,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
__slots__ = ["op", "dtype", "src", "arg"]
def __init__(self, op:Ops, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
# TODO: instant check rules here make debugging easier
#assert op in UOps and isinstance(dtype, DType), f"bad UOp creation with {op} {dtype}"
#if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool
#if op is UOps.VECTORIZE and dtype != dtypes.void: assert len(src) == dtype.count, f"{len(src)} invalid for {dtype}"
#if op is UOps.ALU and arg not in (BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE): assert all_same([dtype] + [x.dtype for x in src])
#if op is UOps.CAST: assert dtype.count == src[0].dtype.count, f"cast can't change vectorization {src[0].dtype} --> {dtype}"
self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg)
def replace(self, **kwargs) -> UOp:
@@ -256,7 +257,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def tuplize(self:UOp) -> Tuple[int, Any, Optional[DType], Tuple]:
return (self.op.value, self.arg.value if self.op is Ops.ALU else self.arg, self.dtype, tuple(x.tuplize for x in self.src))
return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src))
# *** uop shape stuff ***
@@ -334,7 +335,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
out_dtype = (self, *src)[-1].dtype
if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None:
out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
return UOp(Ops.ALU, out_dtype, (self,)+src, arg)
return UOp(arg, out_dtype, (self,)+src)
@staticmethod
def const(dtype:DType, b:ConstLike):
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
@@ -382,19 +383,17 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
"""largest known int that divides self"""
if self.op is Ops.CONST: return self.arg
if self.op is Ops.VCONST: return functools.reduce(math.gcd, self.arg)
if self.op is Ops.ALU:
if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
if self.op is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
if self.op is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
return 1
def divides(self, v) -> Optional[UOp]:
if v==1: return self
if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
if self.op is Ops.ALU:
if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
if self.arg is BinaryOps.MUL:
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
if self.op is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
if self.op is BinaryOps.MUL:
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
return None # generic None if we aren't sure
@property
def vmin(self) -> ConstType: return self._min_max[0]
@@ -411,25 +410,25 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype)
if self.op is Ops.CONST: return self.arg, self.arg
if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
if self.op is Ops.ALU and not dtypes.is_float(self.dtype):
if self.op in GroupOp.ALU and not dtypes.is_float(self.dtype):
s0,s1,s2 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(3)]
if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax
if self.arg is BinaryOps.MUL: return min(vals:=(s0.vmin*s1.vmin, s0.vmin*s1.vmax, s0.vmax*s1.vmin, s0.vmax*s1.vmax)), max(vals)
if self.arg is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1
if self.arg is BinaryOps.IDIV and s1.op is Ops.CONST:
if self.op is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax
if self.op is BinaryOps.MUL: return min(vals:=(s0.vmin*s1.vmin, s0.vmin*s1.vmax, s0.vmax*s1.vmin, s0.vmax*s1.vmax)), max(vals)
if self.op is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1
if self.op is BinaryOps.IDIV and s1.op is Ops.CONST:
if s1.arg > 0: return s0.vmin//s1.arg, s0.vmax//s1.arg
if s1.arg < 0 and s0.vmin >= 0: return -(s0.vmax//-s1.arg), -(s0.vmin//-s1.arg)
if self.arg is BinaryOps.MAX: return max(s0.vmin, s1.vmin), max(s0.vmax, s1.vmax)
if self.arg is BinaryOps.CMPLT: return (s0.vmax<s1.vmin, s0.vmin<s1.vmax)
if self.arg is BinaryOps.CMPNE:
if self.op is BinaryOps.MAX: return max(s0.vmin, s1.vmin), max(s0.vmax, s1.vmax)
if self.op is BinaryOps.CMPLT: return (s0.vmax<s1.vmin, s0.vmin<s1.vmax)
if self.op is BinaryOps.CMPNE:
always_ne = (s0.vmax < s1.vmin) or (s1.vmax < s0.vmin)
sometimes_ne = not (s0.vmin == s0.vmax == s1.vmin == s1.vmax)
return (always_ne, sometimes_ne)
# float has NAN issue and we use explicit NAN in transcendental
if self.arg is TernaryOps.WHERE and dtypes.is_int(s1.dtype): return min(s1.vmin, s2.vmin), max(s1.vmax, s2.vmax)
if self.op is TernaryOps.WHERE and dtypes.is_int(s1.dtype): return min(s1.vmin, s2.vmin), max(s1.vmax, s2.vmax)
if self.dtype == dtypes.bool:
if self.arg is BinaryOps.OR: return s0.vmin or s1.vmin, s0.vmax or s1.vmax
if self.arg is BinaryOps.AND: return s0.vmin and s1.vmin, s0.vmax and s1.vmax
if self.op is BinaryOps.OR: return s0.vmin or s1.vmin, s0.vmax or s1.vmax
if self.op is BinaryOps.AND: return s0.vmin and s1.vmin, s0.vmax and s1.vmax
return dtypes.min(self.dtype), dtypes.max(self.dtype)
@functools.cached_property
@@ -509,8 +508,8 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
mem += u.dtype.itemsize * mults
elif u.op is Ops.STORE:
mem += u.src[1].dtype.itemsize * mults
elif u.op is Ops.ALU and u not in dont_count:
flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
elif u.op in GroupOp.ALU and u not in dont_count:
flops += (mults * (2 if u.op is TernaryOps.MULACC else 1)) * u.dtype.count
elif u.op is Ops.WMMA and u not in dont_count:
flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return flops, mem
@@ -582,9 +581,9 @@ class UPat(MathTrait):
def assign(self, x:UPat): return UPat(Ops.ASSIGN, self.dtype, (self,x))
def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
def alu(self, arg, *src:UPat):
def alu(self, op:Ops, *src:UPat):
asrc = (self,)+src
return UPat(Ops.ALU, None if arg in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if arg in GroupOp.Commutative else asrc, arg)
return UPat(op, None if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
def printable(self:UPat) -> str:
try: return lines(self.location[0])[self.location[1]-1].strip()
@@ -790,17 +789,13 @@ spec = PatternMatcher([
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True),
# most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
(UPat(Ops.ALU, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y")), arg=TernaryOps.WHERE),
lambda w,x,y: w.dtype == x.dtype == y.dtype),
(UPat(Ops.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype),
(UPat(Ops.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype),
# and SHL/SHR, the shift distance is an int
(UPat(Ops.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHL),
(UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
(UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y"))), lambda x,y: x.dtype == y.dtype),
# and SHL/SHR, the shift distance can be an int
(UPat((Ops.SHL, Ops.SHR), src=(UPat(name="x"), UPat(name="y")), name="alu"),
lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)),
(UPat(Ops.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHR),
lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)),
(UPat(Ops.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(UPat(Ops.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
(UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
(UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
(UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
@@ -848,7 +843,7 @@ def cast_float_to_bf16(x: UOp) -> UOp:
# *** most of symbolic lives here now ***
def split_uop(x:UOp, sep:Ops):
if x.op is Ops.ALU and x.arg is sep:
if x.op is sep:
for s in x.src: yield from split_uop(s, sep)
else: yield x
@@ -865,7 +860,7 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]:
assert divides is not None
remainder.append(divides)
something_changed = True
elif u.op is Ops.ALU and u.arg is BinaryOps.MOD and (s1:=u.src[1]).op is Ops.CONST and s1.arg%c == 0:
elif u.op is Ops.MOD and (s1:=u.src[1]).op is Ops.CONST and s1.arg%c == 0:
remainder.append(u.src[0])
something_changed = True
else: remainder.append(u)
@@ -892,7 +887,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
something_changed = True
else:
# divisor is the smallest common divisor of all MULs
if u.op is Ops.ALU and u.arg is BinaryOps.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor
if u.op is Ops.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor
remainder.append(u)
gcd = math.gcd(gcd, factor)
@@ -923,11 +918,11 @@ def fold_unrolled_divs(divs:UOp):
# example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
add_chain, denominator, seen_const, ans = list(split_uop(divs, BinaryOps.ADD)), None, [], None
for u in add_chain:
if not (u.op is Ops.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is Ops.CONST): return None
if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
if denominator is None: denominator = u.src[1].arg
if denominator != u.src[1].arg: return None
# assumed CONST is the last of an ADD
if (s0:=u.src[0]).op is Ops.ALU and s0.arg is BinaryOps.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
if (s0:=u.src[0]).op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
seen_const.append(s0.src[1].arg)
s0 = s0.src[0]
else: seen_const.append(0)
@@ -947,7 +942,7 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
changed, ret = False, []
for u in split_uop(X, BinaryOps.ADD):
# assumed the const is the last src of MUL
if u.op is Ops.ALU and u.arg is BinaryOps.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
if u.op is BinaryOps.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
changed = True
u = u.src[0]
if not (is_irreducible(u) and u.vmin >= 0): return None
@@ -957,8 +952,8 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
def is_increasing(f:UOp) -> bool:
# is f a monotonically increasing function regards its input
if is_irreducible(f): return True
if f.op is Ops.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
if f.op is Ops.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
if f.op is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
if f.op in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
return False # False if not sure
def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
@@ -966,10 +961,10 @@ def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
# if it's X >= c, returns X, False, c
# (X < c).ne(True) -> X >= c
if valid.op is Ops.ALU and valid.arg is BinaryOps.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
(s0:=valid.src[0]).op is Ops.ALU and s0.arg is BinaryOps.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
if valid.op is BinaryOps.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
(s0:=valid.src[0]).op is BinaryOps.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
# X < c -> X <= c-1
if valid.op is Ops.ALU and valid.arg is BinaryOps.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1
if valid.op is BinaryOps.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1
raise ValueError(f"not able to parse {valid=}")
def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
@@ -989,7 +984,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
# every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
candidates = []
if expr.op is Ops.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and v[0] == 1 for u in split_uop(expr, BinaryOps.ADD)):
if expr.op is Ops.ADD and all(is_irreducible(u) and v[0] == 1 for u in split_uop(expr, BinaryOps.ADD)):
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, BinaryOps.ADD)])
# try checking the whole clause
@@ -1043,8 +1038,8 @@ symbolic_simple = PatternMatcher([
# NOTE: this can be wrong for loaded NaN
(UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
# ** constant folding **
(UPat(Ops.ALU, name="root", src=UPat((Ops.VCONST, Ops.CONST))),
lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], truncate_output=False))),
(UPat(GroupOp.ALU, name="root", src=UPat((Ops.VCONST, Ops.CONST))),
lambda root: root.const_like(exec_alu(root.op, root.dtype, [x.arg for x in root.src], truncate_output=False))),
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
@@ -1056,8 +1051,7 @@ symbolic_simple = PatternMatcher([
symbolic = symbolic_simple+PatternMatcher([
# ** COMMUTATIVE flipping **
*[(UPat(Ops.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) \
for op in GroupOp.Commutative],
(UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
# group like
((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y),
# ** combine terms **
@@ -1070,7 +1064,7 @@ symbolic = symbolic_simple+PatternMatcher([
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
# ALU min==max -> CONST (slow!)
(UPat(Ops.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
(UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# max folding
(UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
# TODO: why does this rule break beautiful_mnist?
@@ -1097,11 +1091,11 @@ symbolic = symbolic_simple+PatternMatcher([
(((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
# ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
(UPat(Ops.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
(UPat(Ops.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
(UPat(Ops.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
(UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
# *** rules from symbolic ***
# unrolled arange div folding
(UPat(Ops.ALU, name="divs", src=[UPat(), UPat(Ops.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs),
(UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs),
# generic lt folding
(UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
# canonicalize a simplex with positive coefficients > 0
@@ -1132,13 +1126,11 @@ renderer = PatternMatcher([
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")),
(UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=UnaryOps.NEG), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=BinaryOps.MAX), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=TernaryOps.MULACC),
lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=TernaryOps.WHERE),
lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")),
(UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
(UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
(UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
(UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
(UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.op]}{x.src[1].arg})")),
])
# *** what was symbolic.py ***


@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast
import os, math
from collections import defaultdict, Counter
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16
from tinygrad.ops import GroupOp, UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.renderer import Renderer, TensorCore
@@ -47,8 +47,8 @@ base_rewrite = PatternMatcher([
(UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
# alu/gep
(UPat(Ops.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.arg](
*([strip_parens(ctx[v]) if v.arg == x.arg and x.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else ctx[v] for v in x.src]), x.dtype)),
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else ctx[v] for v in x.src]), x.dtype)),
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
])
@@ -61,7 +61,7 @@ extra_pm = PatternMatcher([
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
(UPat(Ops.ALU, name="m", arg=BinaryOps.MAX), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
])
def uops_to_dtypes(uops:List[UOp]) -> List[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
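
The MAX rule above now keys directly on Ops.MAX and lowers it to a compare-plus-select. A sketch of the rewritten form it produces (not the matcher machinery itself):

from tinygrad.ops import UOp, Ops

x, y = UOp.variable("x", 0, 10), UOp.variable("y", 0, 10)
m = UOp(Ops.MAX, x.dtype, (x, y))
lowered = (x < y).where(y, x)     # the form the Ops.MAX pattern rewrites m into
assert lowered.op is Ops.WHERE
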
@@ -140,16 +140,16 @@ class CStyleLanguage(Renderer):
if u.op is Ops.SPECIAL:
r[u] = u.arg[0]
else:
prefix = {Ops.RANGE: "ridx", Ops.ALU: "alu", Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
prefix = {Ops.RANGE: "ridx", Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.NOOP: "precast",
Ops.INDEX: "bidx", Ops.DEFINE_ACC: "acc", Ops.LOAD: "val"}.get(u.op, "unk")
Ops.INDEX: "bidx", Ops.DEFINE_ACC: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
r[u] = f"{prefix}{c[prefix]}"
l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or (u.op in {Ops.VECTORIZE, Ops.ALU, Ops.CAST, Ops.BITCAST}
if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or (u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST}
and child_count[u] == 1 and not getenv("EXPAND_SSA")):
r[u] = l
else:
@@ -272,9 +272,8 @@ class MetalRenderer(CStyleLanguage):
# upcast to float32 all the ops that don't support bfloat16
extra_matcher = PatternMatcher([
# NOTE: this is copied from PTX
*[(UPat(Ops.ALU, arg=op, dtype=dtypes.bfloat16, name="x"),
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)))
for op in [UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN]]
(UPat((UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN), dtype=dtypes.bfloat16, name="x"),
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))),
]) + extra_pm
string_rewrite = PatternMatcher([
@@ -387,11 +386,11 @@ class AMDRenderer(CStyleLanguage):
type_map = {dtypes.bfloat16: "hip_bfloat16"}
extra_matcher = PatternMatcher([
# cast bfloat16 alus to float
(UPat(Ops.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
lambda b,x,y: UOp(Ops.ALU, arg=TernaryOps.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
(UPat(Ops.ALU, dtype=dtypes.bfloat16, name="x"),
(UPat(Ops.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
lambda b,x,y: UOp(Ops.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
(UPat(GroupOp.ALU, dtype=dtypes.bfloat16, name="x"),
lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)),
(UPat(Ops.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
(UPat(GroupOp.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
# add float intermediate casting for bfloat16
(UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),


@@ -1,7 +1,7 @@
from typing import Dict, Callable, List, Optional
from llvmlite import ir
from tinygrad.dtype import DType, PtrDType, dtypes
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, GroupOp
from tinygrad.renderer import Renderer
MFLAGS = ('nsz', 'arcp', 'contract', 'afn') # All from fast math, but nnan and ninf and reassoc
@@ -141,8 +141,8 @@ class LLVMRenderer(Renderer):
backward = src[0]
while backward.op is Ops.ASSIGN: backward = backward.src[0]
lvars[backward] = lvars[u]
elif uop is Ops.ALU:
lvars[u] = self.code_for_op[args](bb[-1], *[lvars[x] for x in src], src[0].dtype if args in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype)
elif uop in GroupOp.ALU:
lvars[u] = self.code_for_op[uop](bb[-1], *[lvars[x] for x in src], src[0].dtype if uop in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype)
elif uop in {Ops.CAST, Ops.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is Ops.BITCAST)
elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
elif uop is Ops.CONST: lvars[u] = const(args, dtype)
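Note: the LLVM backend's code_for_op table is now keyed by the op itself, so dispatch looks up u.op instead of u.arg. A minimal sketch of the new lookup shape (the table entries below are string-formatting stand-ins, not the real llvmlite IR builders):

from tinygrad.ops import Ops, GroupOp

# stand-in table: the real code_for_op entries emit llvmlite IR, these just format strings
code_for_op = {Ops.ADD: lambda a, b: f"fadd {a}, {b}", Ops.CMPLT: lambda a, b: f"fcmp olt {a}, {b}"}

def render_alu(op: Ops, *operands: str) -> str:
  assert op in GroupOp.ALU, f"{op} is not an ALU op"
  return code_for_op[op](*operands)  # keyed by the op itself, where it used to be keyed by u.arg

print(render_alu(Ops.ADD, "%a", "%b"))  # -> fadd %a, %b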

View File

@@ -1,7 +1,7 @@
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable, Tuple
import struct
from collections import defaultdict
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, GroupOp
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer
@@ -33,14 +33,14 @@ asm_for_op: Dict[Ops, Callable] = {
}
supports_half: List[Ops] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
doesnt_support_half: Tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if op not in supports_half)
ptx_matcher = PatternMatcher([
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
(UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
(UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),
# upcast to float32 all the ops that don't support half
*[(UPat(Ops.ALU, arg=op, dtype=dtypes.half, name="x"),
lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half)))
for op in asm_for_op.keys() if op not in supports_half],
(UPat(doesnt_support_half, dtype=dtypes.half, name="x"),
lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half))),
# load/store bool -> uint8
(UPat(Ops.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True),
lambda x: UOp(x.op, dtypes.uint8, x.src[0:1] + ((x.src[1].cast(dtypes.uint8),) if len(x.src) >= 2 else ()) + x.src[2:]).cast(dtypes.bool)),
@@ -50,8 +50,8 @@ ptx_matcher = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
(UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
# ptx shr and shl instructions require y to be uint
(UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHL) if y.dtype != dtypes.uint else None),
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHR) if y.dtype != dtypes.uint else None),
(UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.SHL, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
])
class PTXRenderer(Renderer):
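Note: precomputing the ops that need the float32 round-trip becomes a one-liner over plain Ops members. A small sketch with a stand-in op table (the real asm_for_op has more entries; only the keys matter here):

from typing import Tuple
from tinygrad.ops import Ops

asm_ops = (Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.CMPNE, Ops.WHERE, Ops.EXP2, Ops.SQRT)
supports_half = (Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE)
doesnt_support_half: Tuple[Ops, ...] = tuple(op for op in asm_ops if op not in supports_half)
assert doesnt_support_half == (Ops.CMPNE, Ops.SQRT)  # everything else stays in half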
@@ -166,9 +166,9 @@ class PTXRenderer(Renderer):
kk(gate + f"st{mem_type}.{self.mem_types[src[1].dtype]} [{r[src[0]]}+0], {r[src[1]]};")
else:
if uop is Ops.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
elif uop is Ops.ALU:
src_dtype = src[0].dtype if args in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype
kk(self.code_for_op[args](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype]))
elif uop in GroupOp.ALU:
src_dtype = src[0].dtype if uop in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype
kk(self.code_for_op[uop](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype]))
elif uop is Ops.DEFINE_ACC:
if dtype.count > 1:
r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
@@ -190,7 +190,7 @@ class PTXRenderer(Renderer):
elif uop is Ops.LOAD:
assert src[0].dtype == dtypes.int64, "load isn't int64"
mem_type = '.shared' if src[0].op is Ops.DEFINE_LOCAL or any(x.op is Ops.DEFINE_LOCAL for x in src[0].parents) else '.global'
has_gate = len(src) > 2 and src[2].op is Ops.ALU
has_gate = len(src) > 2 and src[2].op in GroupOp.ALU
if dtype.count > 1:
r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
if has_gate:
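Note: the renderer-side checks follow the same pattern: what used to be src[2].op is Ops.ALU becomes a membership test against GroupOp.ALU, and comparisons still take their operand dtype from the sources. A sketch of the gate check on hand-built UOps (for illustration only):

from tinygrad.ops import Ops, GroupOp, UOp
from tinygrad.dtype import dtypes

idx = UOp.const(dtypes.int64, 0)
alt = UOp.const(dtypes.float, 0.0)
gate = UOp(Ops.CMPLT, dtypes.bool, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 4)))
load_src = (idx, alt, gate)
# a gated load is now detected by the gate being any ALU op, not the removed Ops.ALU wrapper
assert len(load_src) > 2 and load_src[2].op in GroupOp.ALU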

View File

@@ -7,7 +7,7 @@ import pickle, base64, itertools, time, struct
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
from tinygrad.helpers import all_same, getenv, flatten
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, Ops, UOp
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, Ops, UOp, GroupOp
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer, ClangRenderer
@@ -173,10 +173,10 @@ class PythonProgram:
def c_map(_, elem): return (elem%16, elem//16)
ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
elif uop is Ops.ALU:
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"
assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"
ul[i] = [exec_alu(arg, dtype, p) for p in zip(*inp)]
elif uop in GroupOp.ALU:
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {uop}"
assert all_same([dtype] + dtp) or uop in {BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {uop}"
ul[i] = [exec_alu(uop, dtype, p) for p in zip(*inp)]
assert i in ul, (uop, dtype, idp, arg)
i += 1
return time.perf_counter() - st
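Note: exec_alu keeps its signature but is now handed the op itself rather than the arg. A small sketch under that assumption:

from tinygrad.ops import Ops, exec_alu
from tinygrad.dtype import dtypes

# before: exec_alu(u.arg, dtype, operands) while u.op was Ops.ALU
# after:  the op enum member itself names the operation
print(exec_alu(Ops.ADD, dtypes.int, (3, 4)))     # -> 7
print(exec_alu(Ops.CMPLT, dtypes.bool, (3, 4)))  # -> True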

View File

@@ -76,8 +76,8 @@ class ShapeTracker:
idx, valid = (graph_rewrite(u, symbolic_flat) for u in self.to_indexed_uops())
for c in split_uop(idx, BinaryOps.ADD):
if c.op is Ops.RANGE: ret[c.arg] = 1
if c.op is Ops.ALU and c.arg is BinaryOps.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg
if c.op is Ops.ALU and c.arg is BinaryOps.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg] = c.src[0].arg
if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg
if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg] = c.src[0].arg
used_ranges = [x.arg for x in idx.sparents if x.op is Ops.RANGE]
ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
if not ignore_valid:
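Note: the stride-recovery walk now asks c.op is Ops.MUL directly instead of Ops.ALU plus a MUL arg. A sketch of how one RANGE*CONST addend of a flattened index is picked apart (the term is built by hand and the RANGE arg is used as the axis index, mirroring to_indexed_uops):

from tinygrad.ops import Ops, UOp
from tinygrad.dtype import dtypes

rng = UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 8)), 0)
term = UOp(Ops.MUL, dtypes.int, (rng, UOp.const(dtypes.int, 4)))  # axis 0 with stride 4
if term.op is Ops.MUL and term.src[0].op is Ops.RANGE and term.src[1].op is Ops.CONST:
  stride_of_range = {term.src[0].arg: term.src[1].arg}  # {0: 4}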

View File

@@ -384,10 +384,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
def from_uop(y:UOp, **kwargs) -> Tensor:
if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor
if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
if y.op is Ops.ALU:
if y.arg is BinaryOps.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
if y.arg is BinaryOps.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
if y.arg is BinaryOps.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))
if y.op is BinaryOps.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
if y.op is BinaryOps.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
if y.op is BinaryOps.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))
raise RuntimeError(f"unhandled UOp {y}")
# ***** creation entrypoint *****
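Note: Tensor.from_uop now dispatches on y.op directly. A hedged sketch of the same dispatch written stand-alone, folding to plain Python numbers instead of Tensors (the helper name is illustrative, not the Tensor method):

from tinygrad.ops import Ops, UOp
from tinygrad.dtype import dtypes

def fold_uop(y: UOp):
  # same dispatch as from_uop above, but producing plain Python numbers
  if y.op is Ops.CONST: return y.arg
  if y.op is Ops.MUL: return fold_uop(y.src[0]) * fold_uop(y.src[1])
  if y.op is Ops.ADD: return fold_uop(y.src[0]) + fold_uop(y.src[1])
  if y.op is Ops.MAX: return max(fold_uop(y.src[0]), fold_uop(y.src[1]))
  raise RuntimeError(f"unhandled UOp {y}")

expr = UOp(Ops.ADD, dtypes.int, (UOp.const(dtypes.int, 2),
          UOp(Ops.MUL, dtypes.int, (UOp.const(dtypes.int, 3), UOp.const(dtypes.int, 4)))))
print(fold_uop(expr))  # -> 14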

View File

@@ -5,13 +5,13 @@ from urllib.parse import parse_qs, urlparse
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Tuple, Optional
from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap
from tinygrad.ops import TrackedRewriteContext, UOp, Ops, lines
from tinygrad.ops import TrackedRewriteContext, UOp, Ops, lines, GroupOp
from tinygrad.codegen.kernel import Kernel
uops_colors = {Ops.ALU: "#ffffc0", Ops.LOAD: "#ffc0c0", Ops.STORE: "#c0ffc0", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#c0ffc0", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE: "#C4A484",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.REDUCE_AXIS: "#f58488"}
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.REDUCE_AXIS: "#f58488", **{x:"#ffffc0" for x in GroupOp.ALU}}
# ** API spec