Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-25 06:48:22 -05:00
implement common subexpression elimination (#1204)
* implement common subexpr elimination
* Revert "implement common subexpr elimination"
This reverts commit 40c5487d20.
* move cse to ast_parse + add type annotations
* oneline if
* improve saved_exprs lookup
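In short: ast_parse is memoized. linearize() now seeds a self.saved_exprs: Dict[LazyOp, List[Token]] cache, and ast_parse lowers a LazyOp subtree to ALU uops only the first time it sees it; repeat occurrences return the previously emitted tokens, so identical subexpressions in the AST no longer generate duplicate uops. Below is a minimal sketch of the same memoize-the-lowering pattern, using a hypothetical Expr node and string "uops" for illustration, not tinygrad's LazyOp/Token classes:

from dataclasses import dataclass
from typing import Dict, List, Tuple

# Hypothetical expression node for illustration only -- not tinygrad's LazyOp.
# frozen=True makes instances hashable and equal by value, so structurally
# identical subtrees land on the same dict key, which is what the cache needs.
@dataclass(frozen=True)
class Expr:
  op: str
  src: Tuple["Expr", ...] = ()

def lower(x: Expr, saved_exprs: Dict[Expr, str], out: List[str]) -> str:
  # CSE by memoization: lower each distinct subtree once, reuse its name after.
  if x not in saved_exprs:
    srcs = [lower(s, saved_exprs, out) for s in x.src]
    name = f"t{len(out)}"
    out.append(f"{name} = {x.op}({', '.join(srcs)})")
    saved_exprs[x] = name
  return saved_exprs[x]

a, b = Expr("load_a"), Expr("load_b")
mul = Expr("mul", (a, b))
root = Expr("add", (mul, mul))   # (a*b) + (a*b): the shared subtree appears twice
prog: List[str] = []
lower(root, {}, prog)
print("\n".join(prog))
# t0 = load_a()
# t1 = load_b()
# t2 = mul(t0, t1)
# t3 = add(t2, t2)

Running the sketch prints four lines: t2 = mul(t0, t1) is emitted once even though root references it twice, and the final add consumes the cached t2 on both sides.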
@@ -262,6 +262,7 @@ class Linearizer:
   def linearize(self):
     # uops
     self.uops: List[UOp] = []
+    self.saved_exprs: Dict[LazyOp, List[Token]] = dict()

     # add a local buffer for multistage reduce
     if len(self.group_for_reduce):
@@ -445,18 +446,20 @@ class Linearizer:
     # Reorder sources to put constants first so get_grouped_maybe_float4 can fold the op
     srcs = sorted(x.src, key=lambda x: (x.realized.__class__ != RawConst) if x.__class__ == LazyBuffer else 0)
     x.src = tuple(srcs)
-    values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
-    if x.op.__class__ in {ReduceOps, FusedOps}:
-      ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.supports_float4_alu)]
-    else:
-      ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
-    ordered_ret: List[Optional[Token]] = [None]*len(values[0])
-    # scatter
-    for i,j in ret:
-      for o,k in enumerate(i):
-        ordered_ret[k] = Token(j.name, j.dtype, o) if j.dtype == dtypes._float4 else j
-    assert all(isinstance(x, Token) for x in ordered_ret), "some tokens didn't get scattered?"
-    return cast(List[Token], ordered_ret)
+    if x not in self.saved_exprs:
+      values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
+      if x.op.__class__ in {ReduceOps, FusedOps}:
+        ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.supports_float4_alu)]
+      else:
+        ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
+      ordered_ret: List[Optional[Token]] = [None]*len(values[0])
+      # scatter
+      for i,j in ret:
+        for o,k in enumerate(i):
+          ordered_ret[k] = Token(j.name, j.dtype, o) if j.dtype == dtypes._float4 else j
+      assert all(isinstance(x, Token) for x in ordered_ret), "some tokens didn't get scattered?"
+      self.saved_exprs[x] = cast(List[Token], ordered_ret)
+    return self.saved_exprs[x]

   @property
   def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True)
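Two details of the change worth noting. First, the cache is initialized inside linearize() next to self.uops, so it is rebuilt for every kernel and tokens are only reused within a single linearization. Second, the "x not in self.saved_exprs" lookup assumes LazyOp instances hash and compare by structural value, so separately constructed but identical subtrees hit the same entry; the diff itself doesn't show that definition, and the "improve saved_exprs lookup" follow-up in the commit message touches this lookup.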