_get_add_chain(x) -> _get_chain(x, BinaryOps.ADD) (#6523)

need MUL for valid [run_process_replay]
This commit is contained in:
chenyu
2024-09-15 10:54:13 -04:00
committed by GitHub
parent b2c286f567
commit 6be0cc387c
2 changed files with 8 additions and 8 deletions

View File

@@ -90,9 +90,9 @@ float4_folding = PatternMatcher([
# ***** mod *****
def _get_add_chain(x:UOp):
if x.op is UOps.ALU and x.arg is BinaryOps.ADD:
for s in x.src: yield from _get_add_chain(s)
def _get_chain(x:UOp, sep:BinaryOps):
if x.op is UOps.ALU and x.arg is sep:
for s in x.src: yield from _get_chain(s, sep)
else: yield x
def mod_folding(x:UOp, c:int) -> Optional[UOp]:
@@ -102,7 +102,7 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]:
if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c
remainder, something_changed = [], False
for u in _get_add_chain(x):
for u in _get_chain(x, BinaryOps.ADD):
if (factor:=u.const_factor())%c != factor:
remainder.append(u.divides(factor)*(factor%c))
something_changed = True
@@ -120,7 +120,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
if 0 <= x.vmin and x.vmax < c: return x.const_like(0)
quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1
for u in _get_add_chain(x):
for u in _get_chain(x, BinaryOps.ADD):
if u.op is UOps.CONST:
# add all const together first
if rem_const != 0: something_changed = True
@@ -157,7 +157,7 @@ def lt_folding(x:UOp, c:int) -> Optional[UOp]:
def fold_unrolled_divs(divs:UOp, c:UOp):
# div pattern in unrolled arange
# example: (-x+2561)//-4+(-x+2562)//-4+(-x+2560)//-4+(-x+2559)//-4+2559 -> x
add_chain, seen_const, ans = list(_get_add_chain(divs)), [], None
add_chain, seen_const, ans = list(_get_chain(divs, BinaryOps.ADD)), [], None
for u in add_chain:
if not (u.op is UOps.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is UOps.CONST and u.src[1].arg==-len(add_chain)): return None
# assumed CONST is the last of an ADD

View File

@@ -9,7 +9,7 @@ from tinygrad.shape.view import View, strides_for_shape
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, UOps, BinaryOps
from tinygrad.ops import graph_rewrite
from tinygrad.codegen.uopgraph import constant_folder, _get_add_chain
from tinygrad.codegen.uopgraph import constant_folder, _get_chain
# TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x.render(render_ops, ctx)
@@ -104,7 +104,7 @@ class ShapeTracker:
ret: List[Optional[sint]] = [None] * len(self.shape)
idx, valid = self.to_indexed_uops()
idx = graph_rewrite(idx, pm=constant_folder)
for c in _get_add_chain(idx):
for c in _get_chain(idx, BinaryOps.ADD):
if c.op is UOps.RANGE: ret[c.arg] = 1
if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[0].op is UOps.RANGE and c.src[1].op is UOps.CONST: ret[c.src[0].arg] = c.src[1].arg
if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[1].op is UOps.RANGE and c.src[0].op is UOps.CONST: ret[c.src[1].arg] = c.src[0].arg