Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
Ops.ALU is no more, the arg is just an op (#7525)
* op arg alu [pr]
* more
* more passing
* fix more tests
* more tests passing
* fix single failing test
* so much cleaner
* noop to not have process replay trigger
* fix ptx
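In practice, the change is that an elementwise UOp no longer carries Ops.ALU as its op with the concrete operation stashed in arg; the concrete operation (Ops.ADD, Ops.MUL, ...) is the op itself, and the new GroupOp sets are used for membership checks. A minimal sketch of the before/after, with placeholder constants x and y that are not taken from the diff:

    from tinygrad.dtype import dtypes
    from tinygrad.ops import UOp, Ops, GroupOp

    # two dummy constant operands for illustration
    x, y = UOp.const(dtypes.float, 1.0), UOp.const(dtypes.float, 2.0)
    # before this commit: u = UOp(Ops.ALU, dtypes.float, (x, y), arg=BinaryOps.ADD)
    # after this commit: the op is the ALU operation and arg stays None
    u = UOp(Ops.ADD, dtypes.float, (x, y))
    # group membership replaces the old `u.op is Ops.ALU` / `u.arg is BinaryOps.ADD` checks
    assert u.op in GroupOp.ALU and u.op in GroupOp.Binary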
test/external/fuzz_linearizer.py (vendored, 2 lines changed)
@@ -252,7 +252,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2, opts_list=None):
 def _is_simple(lin: Kernel) -> bool:
   if len(lin.ast.src) > 1: return False
   ast:UOp = lin.ast.src[0]
-  if ast.src[0].arg is UnaryOps.CAST and ast.src[0].src[0].op is Ops.LOAD: return True
+  if ast.src[0].op is UnaryOps.CAST and ast.src[0].src[0].op is Ops.LOAD: return True
   return False

 if __name__ == "__main__":
@@ -8,7 +8,7 @@ from tinygrad.dtype import DType
 from tinygrad.helpers import CI, getenv
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
-from tinygrad.ops import UnaryOps, Ops
+from tinygrad.ops import GroupOp
 from tinygrad.tensor import _to_np_dtype
 from test.helpers import is_dtype_supported
 import pytest
@@ -79,7 +79,7 @@ def universal_test_unary(a, dtype, op):
     np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
   else: np.testing.assert_equal(tensor_value, numpy_value)
   if op[0] != Tensor.reciprocal: # reciprocal is not supported in most backends
-    op = [x for x in ast.parents if x.op is Ops.ALU and x.arg in UnaryOps][0]
+    op = [x for x in ast.parents if x.op in GroupOp.Unary][0]
     assert op.dtype == dtype

 def universal_test_cast(a, in_dtype, dtype):
@@ -6,7 +6,7 @@ from dataclasses import replace
 from test.helpers import ast_const
 from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel
 from tinygrad.codegen.lowerer import get_grouped_dims
-from tinygrad.ops import UOp, Ops, BinaryOps, TernaryOps, UnaryOps
+from tinygrad.ops import UOp, Ops, BinaryOps, TernaryOps, UnaryOps, GroupOp
 from tinygrad.device import Device, Buffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
@@ -865,9 +865,9 @@ class TestLinearizer(unittest.TestCase):
     lin = helper_linearizer_opt(out, wanna_output=[24])[0]
     ranges = [i for i,u in enumerate(lin.uops) if u.op is Ops.RANGE]
     # RANGE -> ALU -> RANGE -> ALU + LOAD -> ASSIGN
-    assert any(x.op is Ops.ALU for x in lin.uops[ranges[0]:ranges[1]])
+    assert any(x.op in GroupOp.ALU for x in lin.uops[ranges[0]:ranges[1]])
     assert not any(x.op is Ops.LOAD for x in lin.uops[ranges[0]:ranges[1]])
-    assert any(x.op in {Ops.ALU, Ops.LOAD} for x in lin.uops[ranges[1]:])
+    assert any(x.op in {*GroupOp.ALU, Ops.LOAD} for x in lin.uops[ranges[1]:])

   def test_range_outer_op_before_phi(self):
     a = Tensor.randn(4, 1).realize()
@@ -902,7 +902,7 @@ class TestLinearizer(unittest.TestCase):
     lin = helper_linearizer_opt(out, wanna_output=[a.numpy().sum()*a.numpy().sum()])[0]
     # RANGE -> LOAD -> ASSIGN -> ALU
     end = max(i for i,u in enumerate(lin.uops) if u.op is Ops.ENDRANGE)
-    assert lin.uops[end+1].op is Ops.ALU
+    assert lin.uops[end+1].op in GroupOp.ALU

   def test_range_outer_op_after_phi_nested_range(self):
     a = Tensor.randn(2, ).realize()
@@ -910,7 +910,7 @@ class TestLinearizer(unittest.TestCase):
     lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3))).sum()*2])[0]
     # RANGE -> LOAD -> ASSIGN -> ALU
     end = max(i for i,u in enumerate(lin.uops) if u.op is Ops.ENDRANGE)
-    assert lin.uops[end+1].op is Ops.ALU
+    assert lin.uops[end+1].op in GroupOp.ALU

   def test_load_dedup(self):
     # for different leaves in the AST, the same loads may occur.
@@ -1141,7 +1141,7 @@ class TestLinearizer(unittest.TestCase):
     # the uops graph is RANGE -> DEFINE_ACC -> 4x ALU -> 4x ASSIGN -> ENDRANGE
     for u in k.uops:
       if u.op is Ops.ASSIGN:
-        assert u.src[1].op is Ops.ALU
+        assert u.src[1].op in GroupOp.ALU
       # children of ASSIGN are placed after ENDRANGE
       if any(x.op is Ops.ASSIGN for x in u.src):
         end_range = [i for i, x in enumerate(k.uops) if x.op is Ops.ENDRANGE][0]
@@ -1219,7 +1219,7 @@ class TestLinearizer(unittest.TestCase):
       assert len(sched) == 1

       lin = Kernel(sched[0].ast)
-      assert sum(u.arg is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg
+      assert sum(u.op is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg

     a = Tensor.empty((4,4))
     b = Tensor.empty((4,4))
@@ -1260,7 +1260,7 @@ class TestLinearizer(unittest.TestCase):
     lin = Kernel(sched_copy[-1].ast)
     lin.hand_coded_optimizations()
     lin.linearize()
-    assert not any(u.arg == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded"
+    assert not any(u.op == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded"

   def test_phi_simplification(self):
     def helper(t, max_ops=0):
@@ -1272,7 +1272,7 @@ class TestLinearizer(unittest.TestCase):
       assert len(set([u.op for u in uops if u.op in {Ops.RANGE, Ops.SPECIAL}])) == 1, "has either specials or ranges, not both"
       assert len([u for u in uops if u.op is Ops.ASSIGN]) == 0, "ASSIGN should have been simplified"
       # TODO: once uops track min/max this will be fixed
-      #assert len([u for u in uops if u.arg is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"
+      #assert len([u for u in uops if u.op is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"

     helper(Tensor.arange(5.5, (3.5*300), 3.5), max_ops=2)
     helper(Tensor.arange(-1, -100, -5), max_ops=2)
@@ -75,7 +75,7 @@ class TestLinearizerDumb(unittest.TestCase):
     for opt in opts: k.apply_opt(opt)
     prg = k.to_program()
     print(prg.src)
-    assert prg.uops is not None and not any(uop.op is Ops.ALU and uop.arg is BinaryOps.MAX for uop in prg.uops), "leftover MAX"
+    assert prg.uops is not None and not any(uop.op is BinaryOps.MAX for uop in prg.uops), "leftover MAX"

   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
   def test_expander_new_srcs(self):
@@ -620,21 +620,21 @@ class TestMultiTensor(unittest.TestCase):
     for si in t.schedule():
       ast = si.ast.src[0]
       assert ast.op is Ops.STORE
-      assert ast.src[2].arg is BinaryOps.ADD
+      assert ast.src[2].op is BinaryOps.ADD
       assert ast.src[2].src[0].op is Ops.LOAD
       assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 1
     t = 2 * t
     for si in t.schedule():
       ast = si.ast.src[0]
       assert ast.op is Ops.STORE
-      assert ast.src[2].arg is BinaryOps.MUL
+      assert ast.src[2].op is BinaryOps.MUL
       assert ast.src[2].src[0].src[1].op is Ops.CONST and ast.src[2].src[0].src[1].arg == 2
       assert ast.src[2].src[1].op is Ops.LOAD
     t = t + t.full_like(3)
     for si in t.schedule():
       ast = si.ast.src[0]
       assert ast.op is Ops.STORE
-      assert ast.src[2].arg is BinaryOps.ADD
+      assert ast.src[2].op is BinaryOps.ADD
       assert ast.src[2].src[0].op is Ops.LOAD
       assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 3
@@ -1040,7 +1040,7 @@ class TestSchedule(unittest.TestCase):
     b = r.sum(0) * 4
     c = r.sum(1) * 2
     schedule = check_schedule([b, c], 3)
-    self.assertIs(schedule[0].ast.src[0].src[2].arg, BinaryOps.ADD)
+    self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)

   # multireduce spec
   def test_multireduce_simple_chase(self):
@@ -1064,7 +1064,7 @@ class TestSchedule(unittest.TestCase):
     d = r.T * 4
     e = r * d
     schedule = check_schedule([d, e], 3)
-    self.assertIs(schedule[0].ast.src[0].src[2].arg, BinaryOps.ADD)
+    self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)

   # multireduce spec
   def test_multireduce_push_permute_chase(self):
@@ -1075,7 +1075,7 @@ class TestSchedule(unittest.TestCase):
     d = r.T * 4
     e = r * (d + a).sum(2)
     schedule = check_schedule([d, e], 3) # make sure it doesn't fuse
-    self.assertIs(schedule[0].ast.src[0].src[2].arg, BinaryOps.ADD)
+    self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
     run_schedule(schedule)
     np.testing.assert_allclose(d.numpy(), (a.numpy().sum(2) + b.numpy()).T * 4, atol=1e-4, rtol=1e-4)
     np.testing.assert_allclose(e.numpy(), (a.numpy().sum(2) + b.numpy()) * (d.numpy() + a.numpy()).sum(2), atol=1e-4, rtol=1e-4)
@@ -1087,7 +1087,7 @@ class TestSchedule(unittest.TestCase):
     r = a.sum(1) + c
     d = r[:4] * b
     schedule = check_schedule(d, 2)
-    self.assertIs(schedule[0].ast.src[0].src[2].arg, BinaryOps.ADD)
+    self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)

   # multireduce spec
   def test_multireduce_push_shrink_chase(self):
@@ -1100,7 +1100,7 @@ class TestSchedule(unittest.TestCase):
     out = r[:4] * b + d.sum(1)[:4]
     # schedule = check_schedule(out, 2)
     schedule = check_schedule(out, 3)
-    self.assertIs(schedule[0].ast.src[0].src[2].arg, BinaryOps.ADD)
+    self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD)
     run_schedule(schedule)
     np.testing.assert_allclose(out.numpy(), (a.numpy().sum(1) + c.numpy())[:4] * b.numpy() + d.numpy().sum(1)[:4], atol=1e-4, rtol=1e-4)
@@ -161,7 +161,7 @@ class TestGraphRewrite(unittest.TestCase):
     c1 = UOp.const(dtypes.float, 1.0)
     c2 = UOp.const(dtypes.float, 2.0)
     nout = graph_rewrite(v+c1+c2, simple_pm)
-    self.assertEqual(nout.op, Ops.ALU)
+    self.assertEqual(nout.op, Ops.ADD)
     self.assertEqual(nout.src[0].op, Ops.DEFINE_VAR)
     self.assertEqual(nout.src[1].op, Ops.CONST)
     self.assertEqual(nout.src[1].arg, 3.0)
@@ -182,11 +182,11 @@ class TestGraphRewrite(unittest.TestCase):
     b = UOp.variable('b', 0, 1)
     c = UOp.variable('c', 0, 1)
     d = UOp.variable('d', 0, 1)
-    outs = [2+a, 2+a+d+3+b+c+4, UOp(Ops.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
+    outs = [2+a, 2+a+d+3+b+c+4, UOp(Ops.ADD, a.dtype, src=(a.const_like(2), a)), (4+d)+c+(2+a)+b]
     for out in outs:
       sink = graph_rewrite(out, sym)
       print(sink.render())
-      self.assertEqual(sink.op, Ops.ALU)
+      self.assertEqual(sink.op, Ops.ADD)
       self.assertEqual(sink.src[1].op, Ops.CONST)
       self.assertEqual(len([x for x in sink.sparents if x.op is Ops.CONST]), 1)

@@ -380,8 +380,7 @@ class TestUOpGraph(unittest.TestCase):
     uops = to_uops_list([out])
     self.assertEqual(len(uops), 3)
     out = uops[-1]
-    self.assertEqual(out.op, Ops.ALU)
-    self.assertEqual(out.arg, BinaryOps.ADD)
+    self.assertEqual(out.op, BinaryOps.ADD)
     self.assertEqual(out.src[1].op, Ops.CONST)
     self.assertEqual(out.src[1].arg, 6)
@@ -336,8 +336,8 @@ class TestAssembly(unittest.TestCase):
     a2 = UOp(Ops.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
     uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
     Device[Device.DEFAULT].renderer.render("test", uops)
-    self.assertEqual(uops[-1].arg, BinaryOps.SHL)
-    self.assertEqual(uops[-2].arg, BinaryOps.MUL)
+    self.assertEqual(uops[-1].op, BinaryOps.SHL)
+    self.assertEqual(uops[-2].op, BinaryOps.MUL)

   def test_bitshift_right(self):
     g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
@@ -348,8 +348,8 @@ class TestAssembly(unittest.TestCase):
     a2 = UOp(Ops.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
     uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
     Device[Device.DEFAULT].renderer.render("test", uops)
-    self.assertEqual(uops[-1].arg, BinaryOps.SHR)
-    self.assertEqual(uops[-2].arg, BinaryOps.IDIV)
+    self.assertEqual(uops[-1].op, BinaryOps.SHR)
+    self.assertEqual(uops[-2].op, BinaryOps.IDIV)

 class TestUOpMethod(unittest.TestCase):
   @unittest.skip("uops lt no longer ordered")
@@ -1,7 +1,7 @@
 import unittest, math
 from tinygrad import dtypes
 from tinygrad.helpers import all_same
-from tinygrad.ops import UOp, Ops, BinaryOps, exec_alu
+from tinygrad.ops import GroupOp, UOp, Ops, BinaryOps, exec_alu
 from tinygrad.codegen.uopgraph import full_graph_rewrite

 # Helper function to apply the graph rewrite
@@ -14,9 +14,9 @@ def evaluate_uop(uop, variables):
   elif uop.op == Ops.DEFINE_VAR:
     var_name = uop.arg[0]
     return variables[var_name]
-  elif uop.op == Ops.ALU:
+  elif uop.op in GroupOp.ALU:
     src_values = [evaluate_uop(src, variables) for src in uop.src]
-    return exec_alu(uop.arg, uop.dtype, src_values)
+    return exec_alu(uop.op, uop.dtype, src_values)
   else:
     raise NotImplementedError(f"Unsupported UOp {uop.op}")

@@ -104,8 +104,7 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
   def test_full_graph_rewrite_division_folding_with_define_var(self):
     n_var_uop = UOp.variable('n', 1, 1000)
     optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3)
-    self.assertEqual(optimized_div_uop.op, Ops.ALU)
-    self.assertEqual(optimized_div_uop.arg, BinaryOps.MUL)
+    self.assertEqual(optimized_div_uop.op, BinaryOps.MUL)
     self.assertEqual(optimized_div_uop.src[1].arg, 2)

   def test_full_graph_rewrite_complex_mod_div_folding(self):
@@ -115,7 +114,7 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
     self.assertEqual(optimized_div_uop.arg, 1)

   def test_graph_rewrite_div_folding_bug(self):
-    lhs = UOp(Ops.ALU, dtypes.int.vec(4), arg=BinaryOps.ADD, src=(
+    lhs = UOp(Ops.ADD, dtypes.int.vec(4), src=(
       UOp(Ops.VECTORIZE, dtypes.int.vec(4), arg=None, src=(UOp(Ops.SPECIAL, dtypes.int, arg=('lidx0', 32), src=()),)*4),
       UOp(Ops.VCONST, dtypes.int.vec(4), arg=(0, 256, 512, 768), src=())))
     rhs = UOp.const(dtypes.int.vec(4), 2)
@@ -1,6 +1,6 @@
 import unittest, itertools
 from tinygrad.dtype import dtypes
-from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
+from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps, GroupOp # noqa: F401
 from tinygrad.ops import PatternMatcher, UPat

 class TestPatternMatcher(unittest.TestCase):
@@ -72,7 +72,7 @@ class TestPatternMatcher(unittest.TestCase):
     matcher = PatternMatcher([
       (UPat(Ops.CONST, arg=0, name="x"), lambda x: x),
       (UPat(Ops.CONST, arg=False, name="x"), lambda x: x),
-      (UPat(Ops.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
+      (UPat(Ops.MAX, name="x"), lambda x: x),
     ])
     c1 = UOp(Ops.CONST, dtypes.float, arg=0.0)
     c2 = UOp(Ops.CONST, dtypes.bool, arg=False)
@@ -87,7 +87,7 @@ class TestPatternMatcher(unittest.TestCase):

   def test_filter_arg(self):
     matcher = PatternMatcher([
-      (UPat(Ops.ALU, arg=BinaryOps.MUL, src=[UPat(Ops.CONST, name="c"), UPat(Ops.CONST, arg=2)], name="x"),
+      (UPat(Ops.MUL, src=[UPat(Ops.CONST, name="c"), UPat(Ops.CONST, arg=2)], name="x"),
        lambda x,c: x if c.arg in {1, -1} else None)
     ])
     y1 = UOp(Ops.CONST, dtypes.int, arg=1)
@@ -105,11 +105,11 @@ class TestPatternMatcher(unittest.TestCase):
     self.assertEqual(matcher.rewrite(c5), c5)

   def test_dup_name(self):
-    matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST, name="y"), UPat(Ops.CONST, name="y"))), lambda x, y: x)])
+    matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=(UPat(Ops.CONST, name="y"), UPat(Ops.CONST, name="y"))), lambda x, y: x)])
     y1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
     y2 = UOp(Ops.CONST, dtypes.float, arg=1.0)
-    c1 = UOp(Ops.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
-    c2 = UOp(Ops.ALU, dtypes.float, (y1, y2), BinaryOps.ADD)
+    c1 = UOp(Ops.ADD, dtypes.float, (y1, y1))
+    c2 = UOp(Ops.ADD, dtypes.float, (y1, y2))
     self.assertEqual(matcher.rewrite(c1), c1)
     self.assertEqual(matcher.rewrite(c2), c1)

@@ -132,7 +132,7 @@ class TestPatternMatcher(unittest.TestCase):
     self.assertEqual(matcher.rewrite(c4), None)

   def test_src_one(self):
-    matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST), UPat(Ops.CONST))), lambda x: x)])
+    matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=(UPat(Ops.CONST), UPat(Ops.CONST))), lambda x: x)])
     c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
     c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
     c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
@@ -149,7 +149,7 @@ class TestPatternMatcher(unittest.TestCase):
     """

   def test_src_permutations(self):
-    matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=[UPat(Ops.CONST), UPat(Ops.ALU)]), lambda x: x)])
+    matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=[UPat(Ops.CONST), UPat(GroupOp.ALU)]), lambda x: x)])
     c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
     c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
     c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
@@ -162,7 +162,7 @@ class TestPatternMatcher(unittest.TestCase):
     self.assertEqual(matcher.rewrite(c6), None)

   def test_src_repeat(self):
-    matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=UPat(Ops.CONST)), lambda x: x)])
+    matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=UPat(Ops.CONST)), lambda x: x)])
     c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
     c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
     c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
@@ -171,7 +171,7 @@ class TestPatternMatcher(unittest.TestCase):
     self.assertEqual(matcher.rewrite(c4), None)

   def test_allow_len(self):
-    matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)])
+    matcher = PatternMatcher([(UPat(Ops.MULACC, name="x", src=(UPat(Ops.CONST),), allow_any_len=True), lambda x: x)])
     c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
     c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
     c3 = UOp(Ops.CONST, dtypes.float, arg=3.0)
@@ -188,7 +188,7 @@ class TestPatternMatcher(unittest.TestCase):
     u1 = (c1 + c2) + c1
     u2 = (c2 + c1) + c1
     matcher = PatternMatcher([
-      (UPat(Ops.ALU, src=[UPat(Ops.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b)
+      (UPat(GroupOp.ALU, src=[UPat(GroupOp.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b)
     ])
     self.assertIsNotNone(matcher.rewrite(u1))
     self.assertIsNotNone(matcher.rewrite(u2))
@@ -276,7 +276,7 @@ class Kernel:
     if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None

     mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
-    if mul_op.arg is not BinaryOps.MUL: return None
+    if mul_op.op is not BinaryOps.MUL: return None

     def buf_index(src:UOp) -> Optional[int]:
       # TODO: apply tc even if the sources are not from LOAD
@@ -442,7 +442,7 @@ class Kernel:
       check(axis < self.first_upcast, "cannot pad upcasted")
       # ok to pad SUM if all parent ALU ops have f(0) = 0
       if (r:=self.reduceop) is not None and self.first_reduce <= axis:
-        check(r.arg[0] is BinaryOps.ADD and not any(u.op is Ops.ALU and u.arg in GroupOp.UnsafePad for u in r.parents), "cannot pad UnsafePad")
+        check(r.arg[0] is BinaryOps.ADD and not any(u.op in GroupOp.UnsafePad for u in r.parents), "cannot pad UnsafePad")
     padded = False
     for i,st in enumerate(self.sts):
       if (s:=st.shape[axis]) == 1: continue # reduced
@@ -472,7 +472,7 @@ class Kernel:
     MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
     if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
         self.reduceop is not None and self.reduceop.arg[0] is BinaryOps.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \
-        (mulop:=self.reduceop.src[0]).arg is BinaryOps.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
+        (mulop:=self.reduceop.src[0]).op is BinaryOps.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
       st0, st1 = self.sts[self.bufs.index(mulop.src[0])], self.sts[self.bufs.index(mulop.src[1])]
       strides0, strides1 = st0.real_strides(), st1.real_strides()
       def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
@@ -628,7 +628,7 @@ class Kernel:
       if op in self.bufs_for_tensor_core and (tc := self.tensor_core):
         rsrc = op.src[0]
         if rsrc.op is Ops.CAST: rsrc = rsrc.src[0]
-        assert rsrc.op is Ops.ALU and rsrc.arg is BinaryOps.MUL
+        assert rsrc.op is Ops.MUL

         def fix_st(warp_dims, tcd_dims, tcd_expand, pattern_1, pattern_2, st1):
           wd, tcd = self.global_dims, self.first_upcast
@@ -1,6 +1,6 @@
 from typing import List, Set, Dict, Tuple
 import functools, heapq
-from tinygrad.ops import type_verify, END_FOR_UOP, UOp, Ops
+from tinygrad.ops import type_verify, END_FOR_UOP, UOp, Ops, GroupOp
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import DEBUG

@@ -54,7 +54,7 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
   # prevent priority inversion
   @functools.lru_cache(None)
   def fix_priority(u:UOp, lowest_priority):
-    if u.op in {Ops.CAST, Ops.BITCAST, Ops.ALU, Ops.VECTORIZE, Ops.GEP, Ops.SPECIAL, Ops.DEFINE_LOCAL, Ops.LOAD}:
+    if u.op in {Ops.CAST, Ops.BITCAST, *GroupOp.ALU, Ops.VECTORIZE, Ops.GEP, Ops.SPECIAL, Ops.DEFINE_LOCAL, Ops.LOAD}:
       priorities[u] = min(priorities[u], lowest_priority)
       if u.op is Ops.LOAD: priorities[u] += 100 # load penalty (here)
     for x in u.src: fix_priority(x, priorities[u])
@@ -4,7 +4,7 @@ import functools, itertools, operator
 from collections import defaultdict
 from tinygrad.dtype import dtypes, ImageDType, PtrDType
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple
-from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid
+from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp
 from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
 from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES

@@ -23,7 +23,7 @@ def fold_expanded(ex, buf):
   for i,s in enumerate(new_srcs):
     idx = s.src[0].src[1]
     if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue
-    if idx.arg is BinaryOps.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
+    if idx.op is BinaryOps.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
     elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
     else: root_src, arg = idx, 0
     # add gates for gated
@@ -124,18 +124,18 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
 powers_of_two = {2**i:i for i in range(64)}
 @functools.lru_cache(None)
 def get_late_rewrite_patterns(ops, force_transcendental=False):
-  pat: List[Tuple[UPat, Callable]] = [(UPat(Ops.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=op), f) for op,f in \
+  pat: List[Tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
    ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if op not in ops or force_transcendental]
   # rewrite MOD to AND (which should always be supported, but not for generic in tests)
   if BinaryOps.AND in ops:
-    pat += [(UPat(Ops.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
+    pat += [(UPat(Ops.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
      lambda base,const: base & (const.arg-1) if const.arg in powers_of_two else None)]
   # rewrite MUL/IDIV to SHL+SHR
   if BinaryOps.SHL in ops and BinaryOps.SHR in ops:
     pat += [
-      (UPat(Ops.ALU, arg=BinaryOps.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const:
+      (UPat(Ops.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const:
        mul << powers_of_two[const.arg] if const.arg in powers_of_two else None), # (x * (2**y)) -> shl(x,y)
-      (UPat(Ops.ALU, arg=BinaryOps.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const:
+      (UPat(Ops.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const:
        div >> powers_of_two[const.arg] if const.arg in powers_of_two else None)] # (x // (2**y)) -> shr(x,y)
   if UnaryOps.NEG in ops:
     pat += [(UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))]
@@ -226,8 +226,8 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
   reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.sparents)
   if len(reduce_unparented) == 0: return None
   new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented))
-  ret = new_acc.assign(new_acc.alu(alu.arg, ret))
-  if alu.arg is BinaryOps.ADD:
+  ret = new_acc.assign(new_acc.alu(alu.op, ret))
+  if alu.op is BinaryOps.ADD:
     for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
   return ret

@@ -249,8 +249,8 @@ sym = symbolic_flat+PatternMatcher([
   (UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
   (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
   # reorder ALU/VECTORIZE
-  (UPat(Ops.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
-   lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(Ops.ALU, alu.dtype.scalar(), (x,y), alu.arg),)*alu.dtype.count)),
+  (UPat(GroupOp.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
+   lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(alu.op, alu.dtype.scalar(), (x,y)),)*alu.dtype.count)),
   # VECTORIZE of a single element is just that element
   (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
   # VECTORIZE void is SINK
@@ -264,7 +264,7 @@ sym = symbolic_flat+PatternMatcher([
   (UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
   (UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
   # push all GEPs through ALUs (fix arange stuff)
-  (UPat(Ops.GEP, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
+  (UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
    lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg)),
   # push some GEPs through WMMAs
   (UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
@@ -275,15 +275,15 @@ sym = symbolic_flat+PatternMatcher([
   (UPat.var("add") + UPat(Ops.WMMA, name="wmma"),
    lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
   # threefry
-  (UPat(Ops.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key")), arg=BinaryOps.THREEFRY), threefry2x32),
+  (UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32),
   # arange loop folding
   (acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse),
   # indexing, with cast or where
   (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
   (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
   # parentless reduce
-  (acc_pat.assign(UPat(Ops.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.ADD, name="alu")), reduce_collapse),
-  (acc_pat.assign(UPat(Ops.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.MAX, name="alu")), reduce_collapse),
+  (acc_pat.assign(UPat(Ops.ADD, src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
+  (acc_pat.assign(UPat(Ops.MAX, src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
   # ** self folding **
   (UPat(Ops.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
   (UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
@@ -403,7 +403,7 @@ expander = PatternMatcher([
   (UPat(Ops.EXPAND, name="outer", src=(UPat(Ops.EXPAND, name="inner"),)),
    lambda outer, inner: UOp(Ops.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
   # do expansion
-  (UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
          Ops.VECTORIZE, Ops.REDUCE, Ops.IF), name="root", custom_early_reject=set([(Ops.EXPAND, None)])), do_expand),
   (UPat(Ops.CONTRACT, name="con"), do_contract),
   # vectorize DEFINE_ACC
@@ -433,7 +433,7 @@ def no_vectorized_acc(acc:UOp):

 devectorize = PatternMatcher([
   # no ALU on vectorized dtypes
-  (UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu),
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu),
   (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
   (UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc),
   (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
@@ -448,7 +448,7 @@ load_store_indexing = PatternMatcher([
   # late fixup of unfoldable image loads
   (UPat(Ops.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
   # simplify valid
-  (UPat(Ops.ALU, name="valid", arg=BinaryOps.AND), simplify_valid),
+  (UPat(Ops.AND, name="valid"), simplify_valid),
   # image load valid idx simplification
   (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
   # delete_redundant_gates (after expand)
@@ -1,7 +1,7 @@
 import sys, atexit, functools, itertools
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
-from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict
+from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
 from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
 from tinygrad.helpers import DEBUG, Metadata, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
 from tinygrad.dtype import ImageDType, dtypes
@@ -72,9 +72,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) ->
     elif buf.op is Ops.CONTIGUOUS: ret = UOp(Ops.CONTIGUOUS, dtype, src)
     elif buf.op is Ops.ASSIGN: ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
     elif buf.op in GroupOp.Meta: ret = UOp(buf.op, buf.dtype, (ubuf, *src), buf.arg)
-    elif buf.op is Ops.CAST: ret = UOp(Ops.CAST, dtype, src)
-    elif buf.op is Ops.BITCAST: ret = UOp(Ops.BITCAST, dtype, src)
-    else: ret = UOp(Ops.ALU, dtype, src, buf.op)
+    else: ret = UOp(cast(Ops, buf.op), dtype, src)
   cache[buf] = ret = UOp(Ops.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret)))
   if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
   if buf.forced_realize: ctx.realizes[ubuf] = ubuf
@@ -142,7 +140,7 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),),
 # push VIEW to loads
 view_left = merge_views+PatternMatcher([
   # view before ALU
-  (UPat(Ops.VIEW, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Buffer), name="e"),), name="v"),
+  (UPat(Ops.VIEW, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Buffer), name="e"),), name="v"),
    lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
 ])

@@ -156,7 +154,7 @@ view_right = merge_views+PatternMatcher([
   # push a VIEW down to STORE, through a reduce (ONLY reshapes)
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
   # push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes)
-  (UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
 ])
tinygrad/ops.py (126 lines changed)
@@ -169,7 +169,11 @@ class Ops(FastEnum):
   CONST = auto()

 class GroupOp:
+  Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
+  Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB}
+  Ternary = {Ops.WHERE, Ops.MULACC}
+  ALU = set.union(Unary, Binary, Ternary)

   Reduce = {Ops.SUM, Ops.PROD, Ops.REDUCE_MAX}

   # meta ops
@@ -224,6 +228,8 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
 class UOpMetaClass(type):
   ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
   def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
+    # TODO: remove this
+    if op is Ops.ALU: op, arg = arg, None
     if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret
     UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg)
     return ret
@@ -232,11 +238,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   __slots__ = ["op", "dtype", "src", "arg"]
   def __init__(self, op:Ops, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
     # TODO: instant check rules here make debugging easier
-    #assert op in UOps and isinstance(dtype, DType), f"bad UOp creation with {op} {dtype}"
-    #if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool
-    #if op is UOps.VECTORIZE and dtype != dtypes.void: assert len(src) == dtype.count, f"{len(src)} invalid for {dtype}"
-    #if op is UOps.ALU and arg not in (BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE): assert all_same([dtype] + [x.dtype for x in src])
-    #if op is UOps.CAST: assert dtype.count == src[0].dtype.count, f"cast can't change vectorization {src[0].dtype} --> {dtype}"
     self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
   def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg)
   def replace(self, **kwargs) -> UOp:
@@ -256,7 +257,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):

   @functools.cached_property
   def tuplize(self:UOp) -> Tuple[int, Any, Optional[DType], Tuple]:
-    return (self.op.value, self.arg.value if self.op is Ops.ALU else self.arg, self.dtype, tuple(x.tuplize for x in self.src))
+    return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src))

   # *** uop shape stuff ***
@@ -334,7 +335,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     out_dtype = (self, *src)[-1].dtype
     if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None:
       out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
-    return UOp(Ops.ALU, out_dtype, (self,)+src, arg)
+    return UOp(arg, out_dtype, (self,)+src)
   @staticmethod
   def const(dtype:DType, b:ConstLike):
     if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
@@ -382,19 +383,17 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     """largest known int that divides self"""
     if self.op is Ops.CONST: return self.arg
     if self.op is Ops.VCONST: return functools.reduce(math.gcd, self.arg)
-    if self.op is Ops.ALU:
-      if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
-      if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
+    if self.op is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
+    if self.op is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
     return 1
   def divides(self, v) -> Optional[UOp]:
     if v==1: return self
     if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
     if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
-    if self.op is Ops.ALU:
-      if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
-      if self.arg is BinaryOps.MUL:
-        if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
-        if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
+    if self.op is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
+    if self.op is BinaryOps.MUL:
+      if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
+      if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
     return None # generic None if we aren't sure
   @property
   def vmin(self) -> ConstType: return self._min_max[0]
@@ -411,25 +410,25 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype)
     if self.op is Ops.CONST: return self.arg, self.arg
     if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
-    if self.op is Ops.ALU and not dtypes.is_float(self.dtype):
+    if self.op in GroupOp.ALU and not dtypes.is_float(self.dtype):
       s0,s1,s2 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(3)]
-      if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax
-      if self.arg is BinaryOps.MUL: return min(vals:=(s0.vmin*s1.vmin, s0.vmin*s1.vmax, s0.vmax*s1.vmin, s0.vmax*s1.vmax)), max(vals)
-      if self.arg is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1
-      if self.arg is BinaryOps.IDIV and s1.op is Ops.CONST:
+      if self.op is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax
+      if self.op is BinaryOps.MUL: return min(vals:=(s0.vmin*s1.vmin, s0.vmin*s1.vmax, s0.vmax*s1.vmin, s0.vmax*s1.vmax)), max(vals)
+      if self.op is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1
+      if self.op is BinaryOps.IDIV and s1.op is Ops.CONST:
        if s1.arg > 0: return s0.vmin//s1.arg, s0.vmax//s1.arg
        if s1.arg < 0 and s0.vmin >= 0: return -(s0.vmax//-s1.arg), -(s0.vmin//-s1.arg)
-      if self.arg is BinaryOps.MAX: return max(s0.vmin, s1.vmin), max(s0.vmax, s1.vmax)
-      if self.arg is BinaryOps.CMPLT: return (s0.vmax<s1.vmin, s0.vmin<s1.vmax)
-      if self.arg is BinaryOps.CMPNE:
+      if self.op is BinaryOps.MAX: return max(s0.vmin, s1.vmin), max(s0.vmax, s1.vmax)
+      if self.op is BinaryOps.CMPLT: return (s0.vmax<s1.vmin, s0.vmin<s1.vmax)
+      if self.op is BinaryOps.CMPNE:
        always_ne = (s0.vmax < s1.vmin) or (s1.vmax < s0.vmin)
        sometimes_ne = not (s0.vmin == s0.vmax == s1.vmin == s1.vmax)
        return (always_ne, sometimes_ne)
       # float has NAN issue and we use explicit NAN in transcendental
-      if self.arg is TernaryOps.WHERE and dtypes.is_int(s1.dtype): return min(s1.vmin, s2.vmin), max(s1.vmax, s2.vmax)
+      if self.op is TernaryOps.WHERE and dtypes.is_int(s1.dtype): return min(s1.vmin, s2.vmin), max(s1.vmax, s2.vmax)
       if self.dtype == dtypes.bool:
-        if self.arg is BinaryOps.OR: return s0.vmin or s1.vmin, s0.vmax or s1.vmax
-        if self.arg is BinaryOps.AND: return s0.vmin and s1.vmin, s0.vmax and s1.vmax
+        if self.op is BinaryOps.OR: return s0.vmin or s1.vmin, s0.vmax or s1.vmax
+        if self.op is BinaryOps.AND: return s0.vmin and s1.vmin, s0.vmax and s1.vmax
     return dtypes.min(self.dtype), dtypes.max(self.dtype)

   @functools.cached_property
@@ -509,8 +508,8 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
       mem += u.dtype.itemsize * mults
     elif u.op is Ops.STORE:
       mem += u.src[1].dtype.itemsize * mults
-    elif u.op is Ops.ALU and u not in dont_count:
-      flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
+    elif u.op in GroupOp.ALU and u not in dont_count:
+      flops += (mults * (2 if u.op is TernaryOps.MULACC else 1)) * u.dtype.count
     elif u.op is Ops.WMMA and u not in dont_count:
       flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
   return flops, mem
@@ -582,9 +581,9 @@ class UPat(MathTrait):
   def assign(self, x:UPat): return UPat(Ops.ASSIGN, self.dtype, (self,x))

   def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
-  def alu(self, arg, *src:UPat):
+  def alu(self, op:Ops, *src:UPat):
     asrc = (self,)+src
-    return UPat(Ops.ALU, None if arg in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if arg in GroupOp.Commutative else asrc, arg)
+    return UPat(op, None if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)

   def printable(self:UPat) -> str:
     try: return lines(self.location[0])[self.location[1]-1].strip()
@@ -790,17 +789,13 @@ spec = PatternMatcher([
   (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True),

   # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
-  (UPat(Ops.ALU, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y")), arg=TernaryOps.WHERE),
-   lambda w,x,y: w.dtype == x.dtype == y.dtype),
-  (UPat(Ops.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype),
-  (UPat(Ops.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype),
-  # and SHL/SHR, the shift distance is an int
-  (UPat(Ops.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHL),
+  (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
+  (UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y"))), lambda x,y: x.dtype == y.dtype),
+  # and SHL/SHR, the shift distance can be an int
+  (UPat((Ops.SHL, Ops.SHR), src=(UPat(name="x"), UPat(name="y")), name="alu"),
    lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)),
-  (UPat(Ops.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHR),
-   lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)),
-  (UPat(Ops.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
-  (UPat(Ops.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
+  (UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
+  (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),

   (UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
   (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
@@ -848,7 +843,7 @@ def cast_float_to_bf16(x: UOp) -> UOp:
 # *** most of symbolic lives here now ***

 def split_uop(x:UOp, sep:Ops):
-  if x.op is Ops.ALU and x.arg is sep:
+  if x.op is sep:
     for s in x.src: yield from split_uop(s, sep)
   else: yield x

@@ -865,7 +860,7 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]:
       assert divides is not None
       remainder.append(divides)
       something_changed = True
-    elif u.op is Ops.ALU and u.arg is BinaryOps.MOD and (s1:=u.src[1]).op is Ops.CONST and s1.arg%c == 0:
+    elif u.op is Ops.MOD and (s1:=u.src[1]).op is Ops.CONST and s1.arg%c == 0:
       remainder.append(u.src[0])
       something_changed = True
     else: remainder.append(u)
@@ -892,7 +887,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
      something_changed = True
    else:
      # divisor is the smallest common divisor of all MULs
-      if u.op is Ops.ALU and u.arg is BinaryOps.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor
+      if u.op is Ops.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor
      remainder.append(u)
      gcd = math.gcd(gcd, factor)

@@ -923,11 +918,11 @@ def fold_unrolled_divs(divs:UOp):
   # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
   add_chain, denominator, seen_const, ans = list(split_uop(divs, BinaryOps.ADD)), None, [], None
   for u in add_chain:
-    if not (u.op is Ops.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is Ops.CONST): return None
+    if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
     if denominator is None: denominator = u.src[1].arg
     if denominator != u.src[1].arg: return None
     # assumed CONST is the last of an ADD
-    if (s0:=u.src[0]).op is Ops.ALU and s0.arg is BinaryOps.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
+    if (s0:=u.src[0]).op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
       seen_const.append(s0.src[1].arg)
       s0 = s0.src[0]
     else: seen_const.append(0)
@@ -947,7 +942,7 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
   changed, ret = False, []
   for u in split_uop(X, BinaryOps.ADD):
     # assumed the const is the last src of MUL
-    if u.op is Ops.ALU and u.arg is BinaryOps.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
+    if u.op is BinaryOps.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
       changed = True
       u = u.src[0]
     if not (is_irreducible(u) and u.vmin >= 0): return None
@@ -957,8 +952,8 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
 def is_increasing(f:UOp) -> bool:
   # is f a monotonically increasing function regards its input
   if is_irreducible(f): return True
-  if f.op is Ops.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
-  if f.op is Ops.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
+  if f.op is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
+  if f.op in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
   return False # False if not sure

 def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
@@ -966,10 +961,10 @@ def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
   # if it's X >= c, returns X, False, c

   # (X < c).ne(True) -> X >= c
-  if valid.op is Ops.ALU and valid.arg is BinaryOps.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
-     (s0:=valid.src[0]).op is Ops.ALU and s0.arg is BinaryOps.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
+  if valid.op is BinaryOps.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
+     (s0:=valid.src[0]).op is BinaryOps.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
   # X < c -> X <= c-1
-  if valid.op is Ops.ALU and valid.arg is BinaryOps.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1
+  if valid.op is BinaryOps.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1
   raise ValueError(f"not able to parse {valid=}")

 def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
@@ -989,7 +984,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:

     # every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
     candidates = []
-    if expr.op is Ops.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and v[0] == 1 for u in split_uop(expr, BinaryOps.ADD)):
+    if expr.op is Ops.ADD and all(is_irreducible(u) and v[0] == 1 for u in split_uop(expr, BinaryOps.ADD)):
       # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
       candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, BinaryOps.ADD)])
     # try checking the whole clause
@@ -1043,8 +1038,8 @@ symbolic_simple = PatternMatcher([
|
||||
# NOTE: this can be wrong for loaded NaN
|
||||
(UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
||||
# ** constant folding **
|
||||
(UPat(Ops.ALU, name="root", src=UPat((Ops.VCONST, Ops.CONST))),
|
||||
lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], truncate_output=False))),
|
||||
(UPat(GroupOp.ALU, name="root", src=UPat((Ops.VCONST, Ops.CONST))),
|
||||
lambda root: root.const_like(exec_alu(root.op, root.dtype, [x.arg for x in root.src], truncate_output=False))),
|
||||
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
|
||||
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
|
||||
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
|
||||
@@ -1056,8 +1051,7 @@ symbolic_simple = PatternMatcher([
|
||||
|
||||
symbolic = symbolic_simple+PatternMatcher([
|
||||
# ** COMMUTATIVE flipping **
|
||||
*[(UPat(Ops.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) \
|
||||
for op in GroupOp.Commutative],
|
||||
(UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
|
||||
# group like
|
||||
((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y),
|
||||
# ** combine terms **
|
||||
@@ -1070,7 +1064,7 @@ symbolic = symbolic_simple+PatternMatcher([
|
||||
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
|
||||
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
# ALU min==max -> CONST (slow!)
|
||||
(UPat(Ops.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
(UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
# max folding
|
||||
(UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
|
||||
# TODO: why does this rule break beautiful_mnist?
|
||||
@@ -1097,11 +1091,11 @@ symbolic = symbolic_simple+PatternMatcher([
|
||||
(((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
|
||||
lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
|
||||
# ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
|
||||
(UPat(Ops.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
|
||||
(UPat(Ops.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
|
||||
(UPat(Ops.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
|
||||
(UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
|
||||
# *** rules from symbolic ***
|
||||
# unrolled arange div folding
|
||||
(UPat(Ops.ALU, name="divs", src=[UPat(), UPat(Ops.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs),
|
||||
(UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs),
|
||||
# generic lt folding
|
||||
(UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
|
||||
# canonicalize a simplex with positive coefficients > 0
|
||||
@@ -1132,13 +1126,11 @@ renderer = PatternMatcher([
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")),
(UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=UnaryOps.NEG), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=BinaryOps.MAX), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=TernaryOps.MULACC),
lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=TernaryOps.WHERE),
lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")),
(UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
(UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
(UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
(UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
(UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.op]}{x.src[1].arg})")),
])

# *** what was symbolic.py ***

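Editor's note: in the renderer rules above, per-op patterns now sit ahead of a generic GroupOp.ALU pattern that reads syms[x.op] instead of syms[x.arg]. A hedged sketch of why that ordering matters, assuming PatternMatcher.rewrite tries rules in listing order and returns the first match; the syms dict here is a hypothetical stand-in for the real table.

from tinygrad.ops import UOp, Ops, GroupOp, UPat, PatternMatcher
from tinygrad.dtype import dtypes
syms = {Ops.ADD: "+", Ops.MUL: "*"}  # hypothetical stand-in for the real symbol table
pm = PatternMatcher([
  # the specific Ops.MAX rule must come first: UPat(GroupOp.ALU, ...) would also match MAX
  (UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
  (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.op]}{x.src[1].arg})")),
])
noop = lambda s: UOp(Ops.NOOP, arg=s)
out = pm.rewrite(UOp(Ops.MAX, dtypes.int, (noop("a"), noop("b"))))
assert out is not None and out.arg == "max(a, b)"  # the MAX rule won, not the generic ALU rule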
@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast
import os, math
from collections import defaultdict, Counter
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16
from tinygrad.ops import GroupOp, UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.renderer import Renderer, TensorCore
@@ -47,8 +47,8 @@ base_rewrite = PatternMatcher([
(UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
# alu/gep
(UPat(Ops.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.arg](
*([strip_parens(ctx[v]) if v.arg == x.arg and x.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else ctx[v] for v in x.src]), x.dtype)),
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else ctx[v] for v in x.src]), x.dtype)),
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
])
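Editor's note: the alu rule above now indexes the renderer's code_for_op table by x.op rather than x.arg. A sketch with a hypothetical two-entry table (the real one lives on CStyleLanguage and covers every ALU op):

from tinygrad.ops import Ops, GroupOp
from tinygrad.dtype import dtypes
code_for_op = {Ops.ADD: lambda a,b,dtype: f"({a}+{b})", Ops.MAX: lambda a,b,dtype: f"max({a},{b})"}  # hypothetical fragment
assert all(op in GroupOp.ALU for op in code_for_op)                # keys are plain Ops members now, not BinaryOps args
assert code_for_op[Ops.ADD]("x0", "x1", dtypes.float) == "(x0+x1)"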
@@ -61,7 +61,7 @@ extra_pm = PatternMatcher([
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
(UPat(Ops.ALU, name="m", arg=BinaryOps.MAX), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
])

def uops_to_dtypes(uops:List[UOp]) -> List[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
@@ -140,16 +140,16 @@ class CStyleLanguage(Renderer):
if u.op is Ops.SPECIAL:
r[u] = u.arg[0]
else:
prefix = {Ops.RANGE: "ridx", Ops.ALU: "alu", Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
prefix = {Ops.RANGE: "ridx", Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.NOOP: "precast",
Ops.INDEX: "bidx", Ops.DEFINE_ACC: "acc", Ops.LOAD: "val"}.get(u.op, "unk")
Ops.INDEX: "bidx", Ops.DEFINE_ACC: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
r[u] = f"{prefix}{c[prefix]}"

l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"

if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or (u.op in {Ops.VECTORIZE, Ops.ALU, Ops.CAST, Ops.BITCAST}
if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or (u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST}
and child_count[u] == 1 and not getenv("EXPAND_SSA")):
r[u] = l
else:
@@ -272,9 +272,8 @@ class MetalRenderer(CStyleLanguage):
# upcast to float32 all the ops that don't support bfloat16
extra_matcher = PatternMatcher([
# NOTE: this is copied from PTX
*[(UPat(Ops.ALU, arg=op, dtype=dtypes.bfloat16, name="x"),
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)))
for op in [UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN]]
(UPat((UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN), dtype=dtypes.bfloat16, name="x"),
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))),
]) + extra_pm

string_rewrite = PatternMatcher([
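Editor's note: the Metal hunk replaces a splatted list comprehension (one pattern per unary op) with a single UPat over a tuple of ops; since the operation now lives in x.op, one lambda can rebuild the same op at float32. A sketch of the consolidated rule on its own, assuming Ops.SQRT/EXP2/LOG2/SIN are the same members the hunk names via UnaryOps:

from tinygrad.ops import UOp, Ops, UPat, PatternMatcher
from tinygrad.dtype import dtypes
bf16_upcast = PatternMatcher([
  (UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"),
   lambda x: UOp(x.op, dtypes.float, tuple(s.cast(dtypes.float) for s in x.src), x.arg).cast(dtypes.bfloat16)),
])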
@@ -387,11 +386,11 @@ class AMDRenderer(CStyleLanguage):
type_map = {dtypes.bfloat16: "hip_bfloat16"}
extra_matcher = PatternMatcher([
# cast bfloat16 alus to float
(UPat(Ops.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
lambda b,x,y: UOp(Ops.ALU, arg=TernaryOps.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
(UPat(Ops.ALU, dtype=dtypes.bfloat16, name="x"),
(UPat(Ops.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
lambda b,x,y: UOp(Ops.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
(UPat(GroupOp.ALU, dtype=dtypes.bfloat16, name="x"),
lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)),
(UPat(Ops.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
(UPat(GroupOp.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
# add float intermediate casting for bfloat16
(UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),

@@ -1,7 +1,7 @@
from typing import Dict, Callable, List, Optional
from llvmlite import ir
from tinygrad.dtype import DType, PtrDType, dtypes
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, GroupOp
from tinygrad.renderer import Renderer

MFLAGS = ('nsz', 'arcp', 'contract', 'afn') # All from fast math, but nnan and ninf and reassoc
@@ -141,8 +141,8 @@ class LLVMRenderer(Renderer):
backward = src[0]
while backward.op is Ops.ASSIGN: backward = backward.src[0]
lvars[backward] = lvars[u]
elif uop is Ops.ALU:
lvars[u] = self.code_for_op[args](bb[-1], *[lvars[x] for x in src], src[0].dtype if args in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype)
elif uop in GroupOp.ALU:
lvars[u] = self.code_for_op[uop](bb[-1], *[lvars[x] for x in src], src[0].dtype if uop in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype)
elif uop in {Ops.CAST, Ops.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is Ops.BITCAST)
elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
elif uop is Ops.CONST: lvars[u] = const(args, dtype)

@@ -1,7 +1,7 @@
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable, Tuple
import struct
from collections import defaultdict
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, GroupOp
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer
@@ -33,14 +33,14 @@ asm_for_op: Dict[Ops, Callable] = {
}

supports_half: List[Ops] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
doesnt_support_half: Tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if op not in supports_half)
ptx_matcher = PatternMatcher([
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
(UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
(UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),
# upcast to float32 all the ops that don't support half
*[(UPat(Ops.ALU, arg=op, dtype=dtypes.half, name="x"),
lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half)))
for op in asm_for_op.keys() if op not in supports_half],
(UPat(doesnt_support_half, dtype=dtypes.half, name="x"),
lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half))),
# load/store bool -> uint8
(UPat(Ops.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True),
lambda x: UOp(x.op, dtypes.uint8, x.src[0:1] + ((x.src[1].cast(dtypes.uint8),) if len(x.src) >= 2 else ()) + x.src[2:]).cast(dtypes.bool)),
@@ -50,8 +50,8 @@ ptx_matcher = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
(UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
# ptx shr and shl instructions require y to be uint
(UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHL) if y.dtype != dtypes.uint else None),
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHR) if y.dtype != dtypes.uint else None),
(UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.SHL, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
])

class PTXRenderer(Renderer):
@@ -166,9 +166,9 @@ class PTXRenderer(Renderer):
kk(gate + f"st{mem_type}.{self.mem_types[src[1].dtype]} [{r[src[0]]}+0], {r[src[1]]};")
else:
if uop is Ops.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
elif uop is Ops.ALU:
src_dtype = src[0].dtype if args in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype
kk(self.code_for_op[args](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype]))
elif uop in GroupOp.ALU:
src_dtype = src[0].dtype if uop in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype
kk(self.code_for_op[uop](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype]))
elif uop is Ops.DEFINE_ACC:
if dtype.count > 1:
r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
@@ -190,7 +190,7 @@ class PTXRenderer(Renderer):
elif uop is Ops.LOAD:
assert src[0].dtype == dtypes.int64, "load isn't int64"
mem_type = '.shared' if src[0].op is Ops.DEFINE_LOCAL or any(x.op is Ops.DEFINE_LOCAL for x in src[0].parents) else '.global'
has_gate = len(src) > 2 and src[2].op is Ops.ALU
has_gate = len(src) > 2 and src[2].op in GroupOp.ALU
if dtype.count > 1:
r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
if has_gate:

@@ -7,7 +7,7 @@ import pickle, base64, itertools, time, struct
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
from tinygrad.helpers import all_same, getenv, flatten
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, Ops, UOp
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, Ops, UOp, GroupOp
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer, ClangRenderer

@@ -173,10 +173,10 @@ class PythonProgram:
def c_map(_, elem): return (elem%16, elem//16)
ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
elif uop is Ops.ALU:
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"
assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"
ul[i] = [exec_alu(arg, dtype, p) for p in zip(*inp)]
elif uop in GroupOp.ALU:
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {uop}"
assert all_same([dtype] + dtp) or uop in {BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {uop}"
ul[i] = [exec_alu(uop, dtype, p) for p in zip(*inp)]
assert i in ul, (uop, dtype, idp, arg)
i += 1
return time.perf_counter() - st

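Editor's note: the Python emulator above now hands exec_alu the op directly. A hedged sketch of that call shape, assuming exec_alu(op, dtype, operands) evaluates one scalar operation as it does inside the loop:

from tinygrad.ops import exec_alu, Ops
from tinygrad.dtype import dtypes
assert exec_alu(Ops.ADD, dtypes.int32, (2, 3)) == 5
assert exec_alu(Ops.WHERE, dtypes.int32, (True, 1, 0)) == 1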
@@ -76,8 +76,8 @@ class ShapeTracker:
idx, valid = (graph_rewrite(u, symbolic_flat) for u in self.to_indexed_uops())
for c in split_uop(idx, BinaryOps.ADD):
if c.op is Ops.RANGE: ret[c.arg] = 1
if c.op is Ops.ALU and c.arg is BinaryOps.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg
if c.op is Ops.ALU and c.arg is BinaryOps.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg] = c.src[0].arg
if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg
if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg] = c.src[0].arg
used_ranges = [x.arg for x in idx.sparents if x.op is Ops.RANGE]
ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
if not ignore_valid:

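Editor's note: in the ShapeTracker hunk above, a stride term of the flattened index is now literally an Ops.MUL of a RANGE and a CONST (previously an Ops.ALU whose arg was BinaryOps.MUL), so the stride is read straight off the CONST source. A small sketch of the node shape being matched, with constants standing in for the RANGE:

from tinygrad.ops import UOp, Ops
from tinygrad.dtype import dtypes
term = UOp(Ops.MUL, dtypes.int, (UOp.const(dtypes.int, 7), UOp.const(dtypes.int, 16)))  # ridx*stride stand-in
assert term.op is Ops.MUL and term.src[1].op is Ops.CONST
assert term.src[1].arg == 16   # the stride the loop above would record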
@@ -384,10 +384,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
def from_uop(y:UOp, **kwargs) -> Tensor:
if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor
if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
if y.op is Ops.ALU:
if y.arg is BinaryOps.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
if y.arg is BinaryOps.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
if y.arg is BinaryOps.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))
if y.op is BinaryOps.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
if y.op is BinaryOps.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
if y.op is BinaryOps.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))
raise RuntimeError(f"unhandled UOp {y}")

# ***** creation entrypoint *****

@@ -5,13 +5,13 @@ from urllib.parse import parse_qs, urlparse
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Tuple, Optional
from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap
from tinygrad.ops import TrackedRewriteContext, UOp, Ops, lines
from tinygrad.ops import TrackedRewriteContext, UOp, Ops, lines, GroupOp
from tinygrad.codegen.kernel import Kernel

uops_colors = {Ops.ALU: "#ffffc0", Ops.LOAD: "#ffc0c0", Ops.STORE: "#c0ffc0", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#c0ffc0", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE: "#C4A484",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.REDUCE_AXIS: "#f58488"}
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.REDUCE_AXIS: "#f58488", **{x:"#ffffc0" for x in GroupOp.ALU}}

# ** API spec
