diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 5a3491d21f..a38a4e2da0 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -262,6 +262,7 @@ class Linearizer:
   def linearize(self):
     # uops
     self.uops: List[UOp] = []
+    self.saved_exprs: Dict[LazyOp, List[Token]] = dict()
 
     # add a local buffer for multistage reduce
     if len(self.group_for_reduce):
@@ -445,18 +446,20 @@ class Linearizer:
     # Reorder sources to put constants first so get_grouped_maybe_float4 can fold the op
     srcs = sorted(x.src, key=lambda x: (x.realized.__class__ != RawConst) if x.__class__ == LazyBuffer else 0)
     x.src = tuple(srcs)
-    values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
-    if x.op.__class__ in {ReduceOps, FusedOps}:
-      ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.supports_float4_alu)]
-    else:
-      ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
-    ordered_ret: List[Optional[Token]] = [None]*len(values[0])
-    # scatter
-    for i,j in ret:
-      for o,k in enumerate(i):
-        ordered_ret[k] = Token(j.name, j.dtype, o) if j.dtype == dtypes._float4 else j
-    assert all(isinstance(x, Token) for x in ordered_ret), "some tokens didn't get scattered?"
-    return cast(List[Token], ordered_ret)
+    if x not in self.saved_exprs:
+      values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
+      if x.op.__class__ in {ReduceOps, FusedOps}:
+        ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.supports_float4_alu)]
+      else:
+        ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
+      ordered_ret: List[Optional[Token]] = [None]*len(values[0])
+      # scatter
+      for i,j in ret:
+        for o,k in enumerate(i):
+          ordered_ret[k] = Token(j.name, j.dtype, o) if j.dtype == dtypes._float4 else j
+      assert all(isinstance(x, Token) for x in ordered_ret), "some tokens didn't get scattered?"
+      self.saved_exprs[x] = cast(List[Token], ordered_ret)
+    return self.saved_exprs[x]
 
   @property
   def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True)