diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 27550417b9..d62253e74c 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -39,5 +39,22 @@ class TestLinearizer(unittest.TestCase):
     assert num_loads <= 4, "more load uops than needed"
     assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
 
+  def test_upcast_cse(self):
+    # when upcasting, within a subtree, there may be common expressions.
+
+    if not isinstance(Device[Device.DEFAULT], Compiled):
+      self.skipTest("Only Compiled uses linearizer")
+
+    a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
+    r = a.expand([2]) + b.expand([2])
+    ast = r.lazydata.op
+    r = r.realize() # realize an output buffer
+    k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
+    k.process()
+    k.upcast()
+    k.linearize()
+    num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
+    assert num_ops <= 1, "more alu uops than needed"
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 2b920dec04..0da9ee00d5 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -4,7 +4,7 @@
 from collections import defaultdict
 from enum import Enum, auto
 from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType, all_same, partition
-from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps
+from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, Op
 from tinygrad.lazy import LazyBuffer
 from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, TernaryOps
 from tinygrad.runtime.lib import RawConst, buf_is_kernel_arg
@@ -298,7 +298,7 @@ class Linearizer:
     # uops
     self.uops: List[UOp] = []
     self.load_cache: Dict[str, Token] = {}
-    self.saved_exprs: Dict[LazyOp, List[Token]] = dict()
+    self.saved_exprs: Dict[Tuple[Op, Tuple[Token, ...]], Token] = dict()
 
     # add global buffers
     for buf,name in self.arg_bufs.items():
@@ -492,6 +492,11 @@ class Linearizer:
     if DEBUG >= 4: print(self.uops[-1])
     return out
 
+  def uop_alu(self, out: Token, vin: List[Token], op: Op) -> Token:
+    key = (op, tuple(vin))
+    if key not in self.saved_exprs: self.saved_exprs[key] = self.uop(UOps.ALU, out, vin, op)
+    return self.saved_exprs[key]
+
   def ast_parse(self, x, acc, loaded_buffers, ssa, do_reduce=False) -> List[Token]:
     if x.__class__ is not LazyOp: return loaded_buffers[x]
     if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers, ssa) # cast isn't an ALU op
@@ -505,21 +510,19 @@ class Linearizer:
     # Reorder sources to put constants first so get_grouped_maybe_float4 can fold the op
     srcs = sorted(x.src, key=lambda x: (x.realized.__class__ != RawConst) if x.__class__ == LazyBuffer else 0)
     x.src = tuple(srcs)
-    if x not in self.saved_exprs:
-      values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
-      ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
-      if x.op in ops:
-        ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), ops[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.opts.supports_float4_alu)]
-      else:
-        ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.opts.supports_float4_alu and x.op not in {BinaryOps.CMPLT, TernaryOps.WHERE})]
-      ordered_ret: List[Optional[Token]] = [None]*len(values[0])
-      # scatter
-      for i,j in ret:
-        for o,k in enumerate(i):
-          ordered_ret[k] = Token(j.name, j.dtype, o) if j.dtype == dtypes._float4 else j
-      assert all(isinstance(x, Token) for x in ordered_ret), "some tokens didn't get scattered?"
-      self.saved_exprs[x] = cast(List[Token], ordered_ret)
-    return self.saved_exprs[x]
+    values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
+    ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
+    if x.op in ops:
+      ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), ops[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.opts.supports_float4_alu)]
+    else:
+      ret = [(idx, self.uop_alu(ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.opts.supports_float4_alu and x.op not in {BinaryOps.CMPLT, TernaryOps.WHERE})]
+    ordered_ret: List[Optional[Token]] = [None]*len(values[0])
+    # scatter
+    for i,j in ret:
+      for o,k in enumerate(i):
+        ordered_ret[k] = Token(j.name, j.dtype, o) if j.dtype == dtypes._float4 else j
+    assert all(isinstance(x, Token) for x in ordered_ret), "some tokens didn't get scattered?"
+    return cast(List[Token], ordered_ret)
 
   @property
   def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True)