mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
_get_add_chain(x) -> _get_chain(x, BinaryOps.ADD) (#6523)
The helper now takes the chaining op as a parameter, because MUL chains are needed for valid. [run_process_replay]
This commit is contained in:
@@ -90,9 +90,9 @@ float4_folding = PatternMatcher([
|
||||
|
||||
# ***** mod *****
|
||||
|
||||
def _get_add_chain(x:UOp):
|
||||
if x.op is UOps.ALU and x.arg is BinaryOps.ADD:
|
||||
for s in x.src: yield from _get_add_chain(s)
|
||||
def _get_chain(x:UOp, sep:BinaryOps):
  """Yield the operands of the chain of `sep` operations rooted at `x`.

  A node whose op is UOps.ALU and whose arg is `sep` is treated as a link in
  the chain: it is not yielded itself; instead each of its sources is expanded
  recursively, in `src` order. Any other node is yielded as-is.
  """
  # guard clause: anything that is not an ALU node carrying `sep` is a leaf
  if x.op is not UOps.ALU or x.arg is not sep:
    yield x
    return
  for operand in x.src:
    yield from _get_chain(operand, sep)
|
||||
|
||||
def mod_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
@@ -102,7 +102,7 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c
|
||||
|
||||
remainder, something_changed = [], False
|
||||
for u in _get_add_chain(x):
|
||||
for u in _get_chain(x, BinaryOps.ADD):
|
||||
if (factor:=u.const_factor())%c != factor:
|
||||
remainder.append(u.divides(factor)*(factor%c))
|
||||
something_changed = True
|
||||
@@ -120,7 +120,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
if 0 <= x.vmin and x.vmax < c: return x.const_like(0)
|
||||
|
||||
quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1
|
||||
for u in _get_add_chain(x):
|
||||
for u in _get_chain(x, BinaryOps.ADD):
|
||||
if u.op is UOps.CONST:
|
||||
# add all const together first
|
||||
if rem_const != 0: something_changed = True
|
||||
@@ -157,7 +157,7 @@ def lt_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
def fold_unrolled_divs(divs:UOp, c:UOp):
|
||||
# div pattern in unrolled arange
|
||||
# example: (-x+2561)//-4+(-x+2562)//-4+(-x+2560)//-4+(-x+2559)//-4+2559 -> x
|
||||
add_chain, seen_const, ans = list(_get_add_chain(divs)), [], None
|
||||
add_chain, seen_const, ans = list(_get_chain(divs, BinaryOps.ADD)), [], None
|
||||
for u in add_chain:
|
||||
if not (u.op is UOps.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is UOps.CONST and u.src[1].arg==-len(add_chain)): return None
|
||||
# assumed CONST is the last of an ADD
|
||||
|
||||
@@ -9,7 +9,7 @@ from tinygrad.shape.view import View, strides_for_shape
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, UOps, BinaryOps
|
||||
from tinygrad.ops import graph_rewrite
|
||||
from tinygrad.codegen.uopgraph import constant_folder, _get_add_chain
|
||||
from tinygrad.codegen.uopgraph import constant_folder, _get_chain
|
||||
|
||||
# TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
|
||||
def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x.render(render_ops, ctx)
|
||||
@@ -104,7 +104,7 @@ class ShapeTracker:
|
||||
ret: List[Optional[sint]] = [None] * len(self.shape)
|
||||
idx, valid = self.to_indexed_uops()
|
||||
idx = graph_rewrite(idx, pm=constant_folder)
|
||||
for c in _get_add_chain(idx):
|
||||
for c in _get_chain(idx, BinaryOps.ADD):
|
||||
if c.op is UOps.RANGE: ret[c.arg] = 1
|
||||
if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[0].op is UOps.RANGE and c.src[1].op is UOps.CONST: ret[c.src[0].arg] = c.src[1].arg
|
||||
if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[1].op is UOps.RANGE and c.src[0].op is UOps.CONST: ret[c.src[1].arg] = c.src[0].arg
|
||||
|
||||
Reference in New Issue
Block a user