CSE at uop level (#1483)

* uop-level cse

* add test

* don't cache reduce alu ops

* types

* rename variable

* fix

* delete lines
This commit is contained in:
David Hou
2023-08-19 20:40:40 -10:00
committed by GitHub
parent b9feb1b743
commit 4fbce972d7
2 changed files with 37 additions and 17 deletions

View File

@@ -39,5 +39,22 @@ class TestLinearizer(unittest.TestCase):
assert num_loads <= 4, "more load uops than needed"
assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
def test_upcast_cse(self):
# when upcasting, within a subtree, there may be common expressions.
if not isinstance(Device[Device.DEFAULT], Compiled):
self.skipTest("Only Compiled uses linearizer")
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = a.expand([2]) + b.expand([2])
ast = r.lazydata.op
r = r.realize() # realize an output buffer
k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
k.process()
k.upcast()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
assert num_ops <= 1, "more alu uops than needed"
if __name__ == '__main__':
unittest.main()

View File

@@ -4,7 +4,7 @@ from collections import defaultdict
from enum import Enum, auto
from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType, all_same, partition
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, Op
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, TernaryOps
from tinygrad.runtime.lib import RawConst, buf_is_kernel_arg
@@ -298,7 +298,7 @@ class Linearizer:
# uops
self.uops: List[UOp] = []
self.load_cache: Dict[str, Token] = {}
self.saved_exprs: Dict[LazyOp, List[Token]] = dict()
self.saved_exprs: Dict[Tuple[Op, Tuple[Token, ...]], Token] = dict()
# add global buffers
for buf,name in self.arg_bufs.items():
@@ -492,6 +492,11 @@ class Linearizer:
if DEBUG >= 4: print(self.uops[-1])
return out
def uop_alu(self, out: Token, vin: List[Token], op: Op) -> Token:
key = (op, tuple(vin))
if key not in self.saved_exprs: self.saved_exprs[key] = self.uop(UOps.ALU, out, vin, op)
return self.saved_exprs[key]
def ast_parse(self, x, acc, loaded_buffers, ssa, do_reduce=False) -> List[Token]:
if x.__class__ is not LazyOp: return loaded_buffers[x]
if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers, ssa) # cast isn't an ALU op
@@ -505,21 +510,19 @@ class Linearizer:
# Reorder sources to put constants first so get_grouped_maybe_float4 can fold the op
srcs = sorted(x.src, key=lambda x: (x.realized.__class__ != RawConst) if x.__class__ == LazyBuffer else 0)
x.src = tuple(srcs)
if x not in self.saved_exprs:
values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
if x.op in ops:
ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), ops[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.opts.supports_float4_alu)]
else:
ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.opts.supports_float4_alu and x.op not in {BinaryOps.CMPLT, TernaryOps.WHERE})]
ordered_ret: List[Optional[Token]] = [None]*len(values[0])
# scatter
for i,j in ret:
for o,k in enumerate(i):
ordered_ret[k] = Token(j.name, j.dtype, o) if j.dtype == dtypes._float4 else j
assert all(isinstance(x, Token) for x in ordered_ret), "some tokens didn't get scattered?"
self.saved_exprs[x] = cast(List[Token], ordered_ret)
return self.saved_exprs[x]
values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
if x.op in ops:
ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), ops[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.opts.supports_float4_alu)]
else:
ret = [(idx, self.uop_alu(ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.opts.supports_float4_alu and x.op not in {BinaryOps.CMPLT, TernaryOps.WHERE})]
ordered_ret: List[Optional[Token]] = [None]*len(values[0])
# scatter
for i,j in ret:
for o,k in enumerate(i):
ordered_ret[k] = Token(j.name, j.dtype, o) if j.dtype == dtypes._float4 else j
assert all(isinstance(x, Token) for x in ordered_ret), "some tokens didn't get scattered?"
return cast(List[Token], ordered_ret)
@property
def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True)