use unravel in views_to_indexed_uops [pr] (#8560)

* use unravel in shape

* make process replay work

* earlier View.minify()

* fix

* fix tests

* mypy

* get rid of early minify

* fix

* linter

* clean and add test

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
eliotgolding
2025-01-12 15:25:55 +00:00
committed by GitHub
parent 38b5ac4d4a
commit 867004fbeb
4 changed files with 28 additions and 22 deletions

View File

@@ -398,6 +398,17 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((idx0*v+idx1)//v, 0, 2, "(idx0)")
self.helper_test_variable((idx0*v+idx1)%v, 0, start_pos, "idx1")
def test_divmod_variable_denom_fold_to_const(self):
x = Variable("x", 20, 23)
y = Variable("y", 8, 10)
self.helper_test_variable(x//y, 2, 2, "2")
self.helper_test_variable(x%y, 0, 7, "(x+(y*-2))")
# ensure all 4 corners are checked
x = Variable("x", -10, 10)
y = Variable("y", -8, 9)
self.helper_test_variable(x//y, -2147483648, 2147483647, "(x//y)")
self.helper_test_variable(x%y, -2147483648, 2147483647, "(x%y)")
# TODO: simplify the expression
def test_div_neg_all_range(self):
gidx = Variable("gidx", 0, 124)

View File

@@ -993,14 +993,13 @@ def split_uop(x:UOp, sep:Ops):
for s in x.src: yield from split_uop(s, sep)
else: yield x
def div_and_mod_folding(x: UOp, c: int, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None:
# simplify x // c or x % c, None means no change, c must be > 0
assert c > 0
if x.dtype.count > 1: return None
def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None:
# simplify x // y or x % y, None means no change
# simple cancel div/mod case
if (q:=x.vmin//c) == (x.vmax//c):
if which is Ops.MOD: return x - q*c
return x.const_like(q)
if y.vmin != 0 != y.vmax and (q:=x.vmin//y.vmin) == x.vmin//y.vmax == x.vmax//y.vmin == x.vmax//y.vmax:
return x - q*y if which is Ops.MOD else x.const_like(q)
if (y.op is not Ops.CONST) or ((c := y.arg) <= 0) or (x.dtype.count > 1): return None
svars, factors, quotients, remainders, gcd, div, const, offset, something_changed = [], [], [], [], c, 1, 0, 0, False
for u in split_uop(x, Ops.ADD):
@@ -1039,7 +1038,7 @@ def div_and_mod_folding(x: UOp, c: int, which: Literal[Ops.MOD, Ops.IDIV], split
if gcd != 1: something_changed = True
if not something_changed:
if which is Ops.IDIV and (1 < div < c) and (newx:=div_and_mod_folding(x, div, Ops.IDIV)) is not None: return newx//(c//div)
if which is Ops.IDIV and (1 < div < c) and (newx:=div_and_mod_folding(x, UOp.const(dtypes.int, div), Ops.IDIV)) is not None: return newx//(c//div)
return None
quo, rem = x.const_like(const//c), x.const_like((const%c)//gcd)
for q,r,f,v in zip(quotients, remainders, factors, svars):
@@ -1259,10 +1258,10 @@ symbolic = symbolic_simple+PatternMatcher([
# ** div **
# div folding
((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)), # (x//c+a)//d -> (x+a*c)//(c*d)
(UPat.var("x", dtypes.sints) // UPat.cvar("c", vec=False), lambda x,c: div_and_mod_folding(x,c.arg,Ops.IDIV) if 0 < c.arg else None),
(UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)),
# ** mod **
# mod folding
(UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: div_and_mod_folding(x,c.arg,Ops.MOD) if 0 < c.arg else None),
(UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
])

View File

@@ -4,20 +4,16 @@ from dataclasses import dataclass
import functools
from typing import Optional, Callable
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.shape.view import View, strides_for_shape, unravel
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid
from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid, sint_to_uop
@functools.lru_cache(None)
def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
idx, valid = views[-1].to_indexed_uops(_idxs)
for view in reversed(views[0:-1]):
view = view.minify()
acc, idxs = 1, []
for d in reversed(view.shape):
idxs.append((idx//acc)%d)
acc *= d
idx, valid = view.to_indexed_uops(idxs[::-1], valid)
idx, valid = view.to_indexed_uops([sint_to_uop(i) for i in unravel(view.shape, idx)], valid)
return idx, valid
@functools.lru_cache(None)

View File

@@ -73,11 +73,11 @@ def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple
def unravel(shape:tuple[sint, ...], offset:sint) -> list[sint]:
# find the position of offset on each dimension based on shape
# similar to unravel_index in numpy/torch
ret = []
for stride in strides_for_shape(shape):
ret.append(offset // stride if stride != 0 else 0)
offset -= ret[-1] * stride
return ret
acc, idxs = 1, []
for d in reversed(shape):
idxs.append((offset//acc)%d)
acc *= d
return idxs[::-1]
@dataclass(frozen=True)
class View: