implement common subexpr elimination

This commit is contained in:
Carson Radtke
2023-07-09 01:45:33 -05:00
parent 67e34b356a
commit 40c5487d20

View File

@@ -240,6 +240,7 @@ class Linearizer:
def linearize(self):
# uops
self.uops: List[UOp] = []
self.alu_exprs = dict()
# add a local buffer for multistage reduce
if len(self.group_for_reduce):
@@ -354,6 +355,9 @@ class Linearizer:
_OT = TypeVar("_OT")
def uop(self, uop:UOps, out:_OT, vin:List[Token], arg:Any=None) -> _OT:
if uop == UOps.ALU:
if (arg, *vin) in self.alu_exprs: return self.alu_exprs[(arg, *vin)]
self.alu_exprs[(arg, *vin)] = out
self.uops.append(UOp(uop, cast(Optional[Token], out), vin, arg))
if DEBUG >= 4: print(self.uops[-1])
return out