implement common subexpression elimination (#1204)

* implement common subexpr elimination

* Revert "implement common subexpr elimination"

This reverts commit 40c5487d20.

* move cse to ast_parse + add type annotations

* oneline if

* improve saved_exprs lookup
This commit is contained in:
Carson Radtke
2023-07-09 16:22:53 -05:00
committed by GitHub
parent beb4d3ab01
commit 1eb0e0cb3f

View File

@@ -262,6 +262,7 @@ class Linearizer:
def linearize(self):
# uops
self.uops: List[UOp] = []
self.saved_exprs: Dict[LazyOp, List[Token]] = dict()
# add a local buffer for multistage reduce
if len(self.group_for_reduce):
@@ -445,18 +446,20 @@ class Linearizer:
# Reorder sources to put constants first so get_grouped_maybe_float4 can fold the op
srcs = sorted(x.src, key=lambda x: (x.realized.__class__ != RawConst) if x.__class__ == LazyBuffer else 0)
x.src = tuple(srcs)
values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
if x.op.__class__ in {ReduceOps, FusedOps}:
ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.supports_float4_alu)]
else:
ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
ordered_ret: List[Optional[Token]] = [None]*len(values[0])
# scatter
for i,j in ret:
for o,k in enumerate(i):
ordered_ret[k] = Token(j.name, j.dtype, o) if j.dtype == dtypes._float4 else j
assert all(isinstance(x, Token) for x in ordered_ret), "some tokens didn't get scattered?"
return cast(List[Token], ordered_ret)
if x not in self.saved_exprs:
values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
if x.op.__class__ in {ReduceOps, FusedOps}:
ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.supports_float4_alu)]
else:
ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
ordered_ret: List[Optional[Token]] = [None]*len(values[0])
# scatter
for i,j in ret:
for o,k in enumerate(i):
ordered_ret[k] = Token(j.name, j.dtype, o) if j.dtype == dtypes._float4 else j
assert all(isinstance(x, Token) for x in ordered_ret), "some tokens didn't get scattered?"
self.saved_exprs[x] = cast(List[Token], ordered_ret)
return self.saved_exprs[x]
@property
def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True)