CSE at uop level (#1483)
* uop-level cse
* add test
* don't cache reduce alu ops
* types
* rename variable
* fix
* delete lines
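For context, a minimal standalone sketch of the idea (illustrative only: SimpleUOp, emit_alu, and the string tokens below are hypothetical, not tinygrad classes): instead of caching whole parsed LazyOp subtrees, each ALU operation is memoized on the pair (op, input tokens), so a repeat request for the same expression returns the token that was already emitted rather than emitting a new uop.

from typing import Dict, List, NamedTuple, Tuple

class SimpleUOp(NamedTuple):
  # hypothetical stand-in for a linearized ALU uop: op name, output token, input tokens
  op: str
  out: str
  vin: Tuple[str, ...]

uops: List[SimpleUOp] = []                                # the emitted program
saved_exprs: Dict[Tuple[str, Tuple[str, ...]], str] = {}  # (op, inputs) -> output token

def emit_alu(op: str, vin: Tuple[str, ...]) -> str:
  key = (op, vin)
  if key not in saved_exprs:
    # first time this expression is requested: emit a new uop and remember its output token
    out = f"alu{len(uops)}"
    uops.append(SimpleUOp(op, out, vin))
    saved_exprs[key] = out
  # on a repeat request, reuse the token that was already emitted
  return saved_exprs[key]

# the same ADD requested twice only produces one uop
assert emit_alu("ADD", ("val0", "val1")) == emit_alu("ADD", ("val0", "val1"))
assert len(uops) == 1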
@@ -39,5 +39,22 @@ class TestLinearizer(unittest.TestCase):
     assert num_loads <= 4, "more load uops than needed"
     assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
+
+  def test_upcast_cse(self):
+    # when upcasting, within a subtree, there may be common expressions.
+
+    if not isinstance(Device[Device.DEFAULT], Compiled):
+      self.skipTest("Only Compiled uses linearizer")
+
+    a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
+    r = a.expand([2]) + b.expand([2])
+    ast = r.lazydata.op
+    r = r.realize() # realize an output buffer
+    k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
+    k.process()
+    k.upcast()
+    k.linearize()
+    num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
+    assert num_ops <= 1, "more alu uops than needed"
 
 if __name__ == '__main__':
   unittest.main()
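Why the test expects at most one ALU uop: a and b each hold a single element and expand([2]) only broadcasts them, so after k.upcast() both output lanes compute the same a[0] + b[0]; with the (op, inputs) cache the second ADD is deduplicated, leaving a single ALU uop in k.uops.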
@@ -4,7 +4,7 @@ from collections import defaultdict
 from enum import Enum, auto
 
 from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType, all_same, partition
-from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps
+from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, Op
 from tinygrad.lazy import LazyBuffer
 from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, TernaryOps
 from tinygrad.runtime.lib import RawConst, buf_is_kernel_arg
@@ -298,7 +298,7 @@ class Linearizer:
     # uops
     self.uops: List[UOp] = []
     self.load_cache: Dict[str, Token] = {}
-    self.saved_exprs: Dict[LazyOp, List[Token]] = dict()
+    self.saved_exprs: Dict[Tuple[Op, Tuple[Token, ...]], Token] = dict()
 
     # add global buffers
     for buf,name in self.arg_bufs.items():
@@ -492,6 +492,11 @@ class Linearizer:
     if DEBUG >= 4: print(self.uops[-1])
     return out
 
+  def uop_alu(self, out: Token, vin: List[Token], op: Op) -> Token:
+    key = (op, tuple(vin))
+    if key not in self.saved_exprs: self.saved_exprs[key] = self.uop(UOps.ALU, out, vin, op)
+    return self.saved_exprs[key]
+
   def ast_parse(self, x, acc, loaded_buffers, ssa, do_reduce=False) -> List[Token]:
     if x.__class__ is not LazyOp: return loaded_buffers[x]
     if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers, ssa) # cast isn't an ALU op
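The new uop_alu is a thin caching wrapper around self.uop: the (op, tuple(vin)) pair is the CSE key, a miss emits the ALU uop and records its output token, and a hit returns the previously emitted token (the freshly allocated out token is simply discarded in that case).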
@@ -505,21 +510,19 @@ class Linearizer:
       # Reorder sources to put constants first so get_grouped_maybe_float4 can fold the op
       srcs = sorted(x.src, key=lambda x: (x.realized.__class__ != RawConst) if x.__class__ == LazyBuffer else 0)
       x.src = tuple(srcs)
-    if x not in self.saved_exprs:
-      values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
-      ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
-      if x.op in ops:
-        ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), ops[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.opts.supports_float4_alu)]
-      else:
-        ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.opts.supports_float4_alu and x.op not in {BinaryOps.CMPLT, TernaryOps.WHERE})]
-      ordered_ret: List[Optional[Token]] = [None]*len(values[0])
-      # scatter
-      for i,j in ret:
-        for o,k in enumerate(i):
-          ordered_ret[k] = Token(j.name, j.dtype, o) if j.dtype == dtypes._float4 else j
-      assert all(isinstance(x, Token) for x in ordered_ret), "some tokens didn't get scattered?"
-      self.saved_exprs[x] = cast(List[Token], ordered_ret)
-    return self.saved_exprs[x]
+    values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
+    ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
+    if x.op in ops:
+      ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), ops[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.opts.supports_float4_alu)]
+    else:
+      ret = [(idx, self.uop_alu(ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.opts.supports_float4_alu and x.op not in {BinaryOps.CMPLT, TernaryOps.WHERE})]
+    ordered_ret: List[Optional[Token]] = [None]*len(values[0])
+    # scatter
+    for i,j in ret:
+      for o,k in enumerate(i):
+        ordered_ret[k] = Token(j.name, j.dtype, o) if j.dtype == dtypes._float4 else j
+    assert all(isinstance(x, Token) for x in ordered_ret), "some tokens didn't get scattered?"
+    return cast(List[Token], ordered_ret)
 
   @property
   def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True)
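With the uop-level cache in place, the old LazyOp-keyed saved_exprs lookup and its extra indentation level are removed. Note the split: the reduce-style ops (ReduceOps.SUM mapped to ADD, ReduceOps.MAX, TernaryOps.MULACC) still go through self.uop uncached, per the "don't cache reduce alu ops" commit note, since they fold values into the accumulator rather than producing a reusable value; only the remaining ALU ops are routed through uop_alu and deduplicated.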