mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
pm_render [pr] (#7430)
* pm_render [pr] * test fixes * use gep, not src * ptx only symbolic, not sym * move cast rules
This commit is contained in:
@@ -456,58 +456,58 @@ class TestExpander(unittest.TestCase):
|
||||
def test_expand_add_broadcast(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
sink = expander_rewrite(e1+3)
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 4
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [3,4,5,6])
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 4
|
||||
self.assertTupleEqual(sink.src[0].arg, (3,4,5,6))
|
||||
|
||||
def test_contract_simple(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.VECTORIZE and len(sink.src) == 4
|
||||
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3])
|
||||
self.assertEqual(sink.op, UOps.VCONST)
|
||||
self.assertTupleEqual(sink.arg, (0,1,2,3))
|
||||
|
||||
def test_contract_axis_1(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16 and sink.arg == ((2,4),)
|
||||
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 16
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0,4,8,12])
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src][12:], [3,7,11,15])
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((2,4),)
|
||||
assert sink.src[0].op is UOps.VCONST
|
||||
self.assertTupleEqual(sink.src[0].arg[0:4], (0,4,8,12))
|
||||
self.assertTupleEqual(sink.src[0].arg[12:], (3,7,11,15))
|
||||
|
||||
def test_contract_axis_2(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16 and sink.arg == ((1,4),)
|
||||
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 16
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0,1,2,3])
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src][12:], [12,13,14,15])
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((1,4),)
|
||||
assert sink.src[0].op is UOps.VCONST
|
||||
self.assertTupleEqual(sink.src[0].arg[0:4], (0,1,2,3))
|
||||
self.assertTupleEqual(sink.src[0].arg[12:], (12,13,14,15))
|
||||
|
||||
def test_contract_axis_2_big(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (3, 2), (4, 2))
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src][0:2], [0,4])
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src][12:14], [10,14])
|
||||
self.assertTupleEqual(sink.src[0].arg[0:2], (0,4))
|
||||
self.assertTupleEqual(sink.src[0].arg[12:14], (10,14))
|
||||
|
||||
def test_contract_multi_axis(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
|
||||
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((3, 2), (2, 2))))
|
||||
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0, 4, 2, 6])
|
||||
self.assertTupleEqual(sink.src[0].arg[0:4], (0, 4, 2, 6))
|
||||
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2, 2), (3, 2))))
|
||||
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0, 2, 4, 6])
|
||||
self.assertTupleEqual(sink.src[0].arg[0:4], (0, 2, 4, 6))
|
||||
|
||||
def test_contract_mid(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(8), tuple(x for x in range(8))),), ((1,2),(2,2),(3,2)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.EXPAND and sink.arg == ((1,2),(3,2))
|
||||
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 8
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [0,2,1,3,4,6,5,7])
|
||||
assert sink.src[0].op is UOps.VCONST and len(sink.src[0].arg) == 8
|
||||
self.assertTupleEqual(sink.src[0].arg, (0,2,1,3,4,6,5,7))
|
||||
|
||||
def test_contract_no_expand(self):
|
||||
e1 = UOp(UOps.DEFINE_VAR, dtypes.int)
|
||||
@@ -520,25 +520,26 @@ class TestExpander(unittest.TestCase):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2)))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.VECTORIZE and len(sink.src) == 8
|
||||
assert sink.src[0] == sink.src[1]
|
||||
assert sink.src[0] != sink.src[2]
|
||||
assert sink.src[6] == sink.src[7]
|
||||
assert sink.op is UOps.VCONST and len(sink.arg) == 8
|
||||
assert sink.arg[0] == sink.arg[1]
|
||||
assert sink.arg[0] != sink.arg[2]
|
||||
assert sink.arg[6] == sink.arg[7]
|
||||
|
||||
def test_expand_same_axis(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
|
||||
sink = expander_rewrite(e1+e2)
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 4
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [0,5,10,15])
|
||||
self.assertEqual(sink.op, UOps.EXPAND)
|
||||
self.assertEqual(sink.src[0].op, UOps.VCONST)
|
||||
self.assertTupleEqual(sink.src[0].arg, (0,5,10,15))
|
||||
|
||||
def test_expand_different_axis(self, flip=False):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((2,4),))
|
||||
sink = expander_rewrite((e2+e1) if flip else (e1+e2))
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16
|
||||
assert sink.arg == ((1, 4), (2, 4))
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
|
||||
self.assertTupleEqual(sink.src[0].arg, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
|
||||
|
||||
def test_expand_different_axis_flip(self): self.test_expand_different_axis(True)
|
||||
|
||||
@@ -621,7 +622,7 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
sink = float4_rewrite(sink)
|
||||
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
|
||||
single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
|
||||
self.assertEqual(single_load.src[1].op, UOps.VECTORIZE)
|
||||
self.assertEqual(single_load.src[1].op, UOps.CONST)
|
||||
|
||||
def test_simple_load_dont_fold_different_gated(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
|
||||
@@ -100,8 +100,8 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
|
||||
# for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
|
||||
if not is_upper_bound and c == 1 and all(is_irreducible(u) and u.vmin == 0 for u in split_uop(X, BinaryOps.ADD)):
|
||||
testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, BinaryOps.ADD), idx)
|
||||
testidx = graph_rewrite(testidx, sym)
|
||||
if testidx.src[0].vmax < 0 or testidx.src[1].vmax < 0:
|
||||
testidx = testidx.simplify()
|
||||
if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
|
||||
drop_stmt.append(stmt)
|
||||
continue
|
||||
|
||||
@@ -110,7 +110,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
|
||||
test_value = c + 1 if is_upper_bound else c - 1
|
||||
for i,b in zip(idx.src, (buf.dtype.shape[1], buf.dtype.shape[0])):
|
||||
if is_increasing(i):
|
||||
rw = graph_rewrite(i.substitute({X:X.const_like(test_value)}), sym)
|
||||
rw = i.substitute({X:X.const_like(test_value)}).simplify()
|
||||
if rw.vmin >= b or rw.vmax < 0:
|
||||
drop_stmt.append(stmt)
|
||||
break
|
||||
@@ -241,9 +241,9 @@ sym = symbolic_flat+PatternMatcher([
|
||||
(UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
|
||||
# ASSIGN to global is just self
|
||||
(UPat(UOps.ASSIGN, src=(UPat(UOps.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x),
|
||||
# VECTORIZE/GEP: the expander rule allows tuple GEP creation, this is just for removal
|
||||
(UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"),
|
||||
lambda vec,x: x if x.dtype == vec.dtype and tuple(y.arg[0] for y in vec.src) == tuple(range(len(vec.src))) else None),
|
||||
# VECTORIZE/CONST, VECTORIZE/GEP
|
||||
(UPat(UOps.VECTORIZE, src=UPat(UOps.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
|
||||
(UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
|
||||
# reorder ALU/VECTORIZE
|
||||
(UPat(UOps.ALU, src=(UPat(UOps.VECTORIZE, src=UPat(name='x')), UPat(UOps.VECTORIZE, src=UPat(name='y'))), name='alu'),
|
||||
lambda x,y,alu: UOp(UOps.VECTORIZE, alu.dtype, (UOp(UOps.ALU, alu.dtype.scalar(), (x,y), alu.arg),)*alu.dtype.count)),
|
||||
@@ -277,11 +277,7 @@ sym = symbolic_flat+PatternMatcher([
|
||||
# indexing, with cast or where
|
||||
(acc_pat.assign(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
|
||||
(acc_pat.assign(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
|
||||
# GEP/CAST const rules
|
||||
(UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
|
||||
# ** self folding **
|
||||
# cast NOOP (NOTE: it's str to deal with PtrDType)
|
||||
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
|
||||
(UPat(UOps.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
|
||||
(UPat(UOps.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
|
||||
# x!=0 -> (bool)x
|
||||
@@ -409,8 +405,6 @@ def create_gate(root:UOp) -> Optional[UOp]:
|
||||
return None if idx.op is not UOps.INDEX or len(idx.src) == 2 or (ret:=_gate_srcs(root, idx.src[2])) is root else ret
|
||||
|
||||
expander = PatternMatcher([
|
||||
(UPat(UOps.VECTORIZE, src=UPat(UOps.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
|
||||
(UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
|
||||
# double expand
|
||||
(UPat(UOps.EXPAND, name="outer", src=(UPat(UOps.EXPAND, name="inner"),)),
|
||||
lambda outer, inner: UOp(UOps.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
|
||||
@@ -496,10 +490,6 @@ def delete_redundant_gates(root:UOp) -> Optional[UOp]:
|
||||
return UOp(UOps.STORE, root.dtype, root.src[:2], root.arg)
|
||||
|
||||
finalize = PatternMatcher([
|
||||
(UPat(UOps.CONST, name='c'),
|
||||
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
|
||||
(UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
|
||||
(UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
||||
# move masks of loads/stores
|
||||
# TODO: this should be an IF instead of a masked STORE
|
||||
(UPat((UOps.LOAD, UOps.STORE), src=(UPat.any(masked_index:=UPat(UOps.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))),
|
||||
@@ -508,6 +498,15 @@ finalize = PatternMatcher([
|
||||
(UPat(UOps.STORE, name="root"), delete_redundant_gates),
|
||||
])
|
||||
|
||||
# for rendering, we don't use vector
|
||||
pm_render = PatternMatcher([
|
||||
(UPat(UOps.CONST, name='c'),
|
||||
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
|
||||
(UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
|
||||
(UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
||||
(UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
|
||||
])
|
||||
|
||||
# *** uop graph ***
|
||||
|
||||
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||
@@ -531,5 +530,6 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||
# finalize
|
||||
sink = graph_rewrite(sink, sym+finalize+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
|
||||
|
||||
if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, opts.extra_matcher)
|
||||
# for rendering without sym (including the rules from the renderer)
|
||||
sink = graph_rewrite(sink, (pm_render+opts.extra_matcher if opts is not None and opts.extra_matcher is not None else pm_render))
|
||||
return sink
|
||||
|
||||
@@ -1091,6 +1091,9 @@ symbolic = PatternMatcher([
|
||||
# ** mod **
|
||||
# mod folding
|
||||
(UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
|
||||
# *** cast ***
|
||||
(UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
|
||||
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
|
||||
])
|
||||
|
||||
symbolic_flat = symbolic+PatternMatcher([
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
|
||||
import struct
|
||||
from collections import defaultdict
|
||||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat
|
||||
from tinygrad.codegen.uopgraph import sym
|
||||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat, symbolic
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.renderer.cstyle import CUDARenderer
|
||||
@@ -34,7 +33,7 @@ asm_for_op: Dict[Op, Callable] = {
|
||||
}
|
||||
|
||||
supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
|
||||
ptx_matcher = sym+PatternMatcher([
|
||||
ptx_matcher = symbolic+PatternMatcher([
|
||||
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
|
||||
(UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
|
||||
(UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),
|
||||
|
||||
Reference in New Issue
Block a user