pm_render [pr] (#7430)

* pm_render [pr]

* test fixes

* use gep, not src

* ptx only symbolic, not sym

* move cast rules
This commit is contained in:
George Hotz
2024-10-31 14:04:50 +07:00
committed by GitHub
parent 8fff8fc3e7
commit 17c9a9fde4
4 changed files with 50 additions and 47 deletions

View File

@@ -456,58 +456,58 @@ class TestExpander(unittest.TestCase):
def test_expand_add_broadcast(self):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
sink = expander_rewrite(e1+3)
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 4
self.assertListEqual([x.arg for x in sink.src[0].src], [3,4,5,6])
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 4
self.assertTupleEqual(sink.src[0].arg, (3,4,5,6))
def test_contract_simple(self):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
sink = expander_rewrite(con)
assert sink.op is UOps.VECTORIZE and len(sink.src) == 4
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3])
self.assertEqual(sink.op, UOps.VCONST)
self.assertTupleEqual(sink.arg, (0,1,2,3))
def test_contract_axis_1(self):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16 and sink.arg == ((2,4),)
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 16
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0,4,8,12])
self.assertListEqual([x.arg for x in sink.src[0].src][12:], [3,7,11,15])
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((2,4),)
assert sink.src[0].op is UOps.VCONST
self.assertTupleEqual(sink.src[0].arg[0:4], (0,4,8,12))
self.assertTupleEqual(sink.src[0].arg[12:], (3,7,11,15))
def test_contract_axis_2(self):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16 and sink.arg == ((1,4),)
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 16
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0,1,2,3])
self.assertListEqual([x.arg for x in sink.src[0].src][12:], [12,13,14,15])
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((1,4),)
assert sink.src[0].op is UOps.VCONST
self.assertTupleEqual(sink.src[0].arg[0:4], (0,1,2,3))
self.assertTupleEqual(sink.src[0].arg[12:], (12,13,14,15))
def test_contract_axis_2_big(self):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (3, 2), (4, 2))
self.assertListEqual([x.arg for x in sink.src[0].src][0:2], [0,4])
self.assertListEqual([x.arg for x in sink.src[0].src][12:14], [10,14])
self.assertTupleEqual(sink.src[0].arg[0:2], (0,4))
self.assertTupleEqual(sink.src[0].arg[12:14], (10,14))
def test_contract_multi_axis(self):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((3, 2), (2, 2))))
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0, 4, 2, 6])
self.assertTupleEqual(sink.src[0].arg[0:4], (0, 4, 2, 6))
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2, 2), (3, 2))))
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0, 2, 4, 6])
self.assertTupleEqual(sink.src[0].arg[0:4], (0, 2, 4, 6))
def test_contract_mid(self):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(8), tuple(x for x in range(8))),), ((1,2),(2,2),(3,2)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and sink.arg == ((1,2),(3,2))
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 8
self.assertListEqual([x.arg for x in sink.src[0].src], [0,2,1,3,4,6,5,7])
assert sink.src[0].op is UOps.VCONST and len(sink.src[0].arg) == 8
self.assertTupleEqual(sink.src[0].arg, (0,2,1,3,4,6,5,7))
def test_contract_no_expand(self):
e1 = UOp(UOps.DEFINE_VAR, dtypes.int)
@@ -520,25 +520,26 @@ class TestExpander(unittest.TestCase):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
con = UOp(UOps.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2)))
sink = expander_rewrite(con)
assert sink.op is UOps.VECTORIZE and len(sink.src) == 8
assert sink.src[0] == sink.src[1]
assert sink.src[0] != sink.src[2]
assert sink.src[6] == sink.src[7]
assert sink.op is UOps.VCONST and len(sink.arg) == 8
assert sink.arg[0] == sink.arg[1]
assert sink.arg[0] != sink.arg[2]
assert sink.arg[6] == sink.arg[7]
def test_expand_same_axis(self):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
sink = expander_rewrite(e1+e2)
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 4
self.assertListEqual([x.arg for x in sink.src[0].src], [0,5,10,15])
self.assertEqual(sink.op, UOps.EXPAND)
self.assertEqual(sink.src[0].op, UOps.VCONST)
self.assertTupleEqual(sink.src[0].arg, (0,5,10,15))
def test_expand_different_axis(self, flip=False):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((2,4),))
sink = expander_rewrite((e2+e1) if flip else (e1+e2))
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16
assert sink.arg == ((1, 4), (2, 4))
self.assertListEqual([x.arg for x in sink.src[0].src], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
self.assertTupleEqual(sink.src[0].arg, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
def test_expand_different_axis_flip(self): self.test_expand_different_axis(True)
@@ -621,7 +622,7 @@ class TestLoadStoreFolder(unittest.TestCase):
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
self.assertEqual(single_load.src[1].op, UOps.VECTORIZE)
self.assertEqual(single_load.src[1].op, UOps.CONST)
def test_simple_load_dont_fold_different_gated(self):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())

View File

@@ -100,8 +100,8 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
# for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
if not is_upper_bound and c == 1 and all(is_irreducible(u) and u.vmin == 0 for u in split_uop(X, BinaryOps.ADD)):
testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, BinaryOps.ADD), idx)
testidx = graph_rewrite(testidx, sym)
if testidx.src[0].vmax < 0 or testidx.src[1].vmax < 0:
testidx = testidx.simplify()
if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
drop_stmt.append(stmt)
continue
@@ -110,7 +110,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
test_value = c + 1 if is_upper_bound else c - 1
for i,b in zip(idx.src, (buf.dtype.shape[1], buf.dtype.shape[0])):
if is_increasing(i):
rw = graph_rewrite(i.substitute({X:X.const_like(test_value)}), sym)
rw = i.substitute({X:X.const_like(test_value)}).simplify()
if rw.vmin >= b or rw.vmax < 0:
drop_stmt.append(stmt)
break
@@ -241,9 +241,9 @@ sym = symbolic_flat+PatternMatcher([
(UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
# ASSIGN to global is just self
(UPat(UOps.ASSIGN, src=(UPat(UOps.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x),
# VECTORIZE/GEP: the expander rule allows tuple GEP creation, this is just for removal
(UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"),
lambda vec,x: x if x.dtype == vec.dtype and tuple(y.arg[0] for y in vec.src) == tuple(range(len(vec.src))) else None),
# VECTORIZE/CONST, VECTORIZE/GEP
(UPat(UOps.VECTORIZE, src=UPat(UOps.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
(UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
# reorder ALU/VECTORIZE
(UPat(UOps.ALU, src=(UPat(UOps.VECTORIZE, src=UPat(name='x')), UPat(UOps.VECTORIZE, src=UPat(name='y'))), name='alu'),
lambda x,y,alu: UOp(UOps.VECTORIZE, alu.dtype, (UOp(UOps.ALU, alu.dtype.scalar(), (x,y), alu.arg),)*alu.dtype.count)),
@@ -277,11 +277,7 @@ sym = symbolic_flat+PatternMatcher([
# indexing, with cast or where
(acc_pat.assign(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
(acc_pat.assign(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
# GEP/CAST const rules
(UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
# ** self folding **
# cast NOOP (NOTE: it's str to deal with PtrDType)
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
(UPat(UOps.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
(UPat(UOps.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
# x!=0 -> (bool)x
@@ -409,8 +405,6 @@ def create_gate(root:UOp) -> Optional[UOp]:
return None if idx.op is not UOps.INDEX or len(idx.src) == 2 or (ret:=_gate_srcs(root, idx.src[2])) is root else ret
expander = PatternMatcher([
(UPat(UOps.VECTORIZE, src=UPat(UOps.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
(UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
# double expand
(UPat(UOps.EXPAND, name="outer", src=(UPat(UOps.EXPAND, name="inner"),)),
lambda outer, inner: UOp(UOps.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
@@ -496,10 +490,6 @@ def delete_redundant_gates(root:UOp) -> Optional[UOp]:
return UOp(UOps.STORE, root.dtype, root.src[:2], root.arg)
finalize = PatternMatcher([
(UPat(UOps.CONST, name='c'),
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
(UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
(UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
# move masks of loads/stores
# TODO: this should be an IF instead of a masked STORE
(UPat((UOps.LOAD, UOps.STORE), src=(UPat.any(masked_index:=UPat(UOps.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))),
@@ -508,6 +498,15 @@ finalize = PatternMatcher([
(UPat(UOps.STORE, name="root"), delete_redundant_gates),
])
# for rendering, we don't use vector
pm_render = PatternMatcher([
(UPat(UOps.CONST, name='c'),
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
(UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
(UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
(UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
])
# *** uop graph ***
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
@@ -531,5 +530,6 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
# finalize
sink = graph_rewrite(sink, sym+finalize+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, opts.extra_matcher)
# for rendering without sym (including the rules from the renderer)
sink = graph_rewrite(sink, (pm_render+opts.extra_matcher if opts is not None and opts.extra_matcher is not None else pm_render))
return sink

View File

@@ -1091,6 +1091,9 @@ symbolic = PatternMatcher([
# ** mod **
# mod folding
(UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
# *** cast ***
(UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
])
symbolic_flat = symbolic+PatternMatcher([

View File

@@ -1,8 +1,7 @@
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
import struct
from collections import defaultdict
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat
from tinygrad.codegen.uopgraph import sym
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat, symbolic
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer
@@ -34,7 +33,7 @@ asm_for_op: Dict[Op, Callable] = {
}
supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
ptx_matcher = sym+PatternMatcher([
ptx_matcher = symbolic+PatternMatcher([
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
(UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
(UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),