mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-11 07:58:08 -05:00)
use unravel in views_to_indexed_uops [pr] (#8560)
* use unravel in shape
* make process replay work
* earlier View.minify()
* fix
* fix tests
* mypy
* get rid of early minify
* fix
* linter
* clean and add test

Co-authored-by: chenyu <chenyu@fastmail.com>
@@ -398,6 +398,17 @@ class TestSymbolic(unittest.TestCase):
     self.helper_test_variable((idx0*v+idx1)//v, 0, 2, "(idx0)")
     self.helper_test_variable((idx0*v+idx1)%v, 0, start_pos, "idx1")
 
+  def test_divmod_variable_denom_fold_to_const(self):
+    x = Variable("x", 20, 23)
+    y = Variable("y", 8, 10)
+    self.helper_test_variable(x//y, 2, 2, "2")
+    self.helper_test_variable(x%y, 0, 7, "(x+(y*-2))")
+    # ensure all 4 corners are checked
+    x = Variable("x", -10, 10)
+    y = Variable("y", -8, 9)
+    self.helper_test_variable(x//y, -2147483648, 2147483647, "(x//y)")
+    self.helper_test_variable(x%y, -2147483648, 2147483647, "(x%y)")
+
   # TODO: simplify the expression
   def test_div_neg_all_range(self):
     gidx = Variable("gidx", 0, 124)
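A plain-Python sanity check (not tinygrad code) of the arithmetic the new test pins down: with x in [20, 23] and y in [8, 10], floor division gives the same quotient everywhere in the box, so x//y folds to the constant 2 and x%y becomes x - 2*y, whose range is [0, 7]:

from itertools import product

xmin, xmax, ymin, ymax = 20, 23, 8, 10
# all four (min, max) corners give the same quotient, so the fold to 2 is sound
assert {a // b for a, b in product((xmin, xmax), (ymin, ymax))} == {2}
# x % y == x - (x//y)*y, so with x//y == 2 the remainder is x - 2*y in [0, 7]
rems = [a - 2 * b for a in range(xmin, xmax + 1) for b in range(ymin, ymax + 1)]
assert min(rems) == 0 and max(rems) == 7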
@@ -993,14 +993,13 @@ def split_uop(x:UOp, sep:Ops):
     for s in x.src: yield from split_uop(s, sep)
   else: yield x
 
-def div_and_mod_folding(x: UOp, c: int, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None:
-  # simplify x // c or x % c, None means no change, c must be > 0
-  assert c > 0
-  if x.dtype.count > 1: return None
+def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None:
+  # simplify x // y or x % y, None means no change
   # simple cancel div/mod case
-  if (q:=x.vmin//c) == (x.vmax//c):
-    if which is Ops.MOD: return x - q*c
-    return x.const_like(q)
+  if y.vmin != 0 != y.vmax and (q:=x.vmin//y.vmin) == x.vmin//y.vmax == x.vmax//y.vmin == x.vmax//y.vmax:
+    return x - q*y if which is Ops.MOD else x.const_like(q)
+
+  if (y.op is not Ops.CONST) or ((c := y.arg) <= 0) or (x.dtype.count > 1): return None
 
   svars, factors, quotients, remainders, gcd, div, const, offset, something_changed = [], [], [], [], c, 1, 0, 0, False
   for u in split_uop(x, Ops.ADD):
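A brute-force illustration of the new "simple cancel" rule (plain Python, independent of tinygrad; corner_fold is a hypothetical stand-in, not a tinygrad function): for y of one sign, floor division is monotone in each argument, so agreement at the four (vmin, vmax) corners pins x//y to that constant q over the whole box, and x%y is then x - q*y:

def corner_fold(xmin, xmax, ymin, ymax):
    # hypothetical helper mirroring the cancel rule: return q if all four
    # corner quotients agree, else None (no fold)
    corners = {a // b for a in (xmin, xmax) for b in (ymin, ymax)}
    return corners.pop() if len(corners) == 1 else None

q = corner_fold(20, 23, 8, 10)
assert q == 2
for x in range(20, 24):
    for y in range(8, 11):
        assert x // y == q and x % y == x - q * y

# disagreeing corners (the new test's x in [-10, 10], y in [-8, 9]) fold to nothing
assert corner_fold(-10, 10, -8, 9) is None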
@@ -1039,7 +1038,7 @@ def div_and_mod_folding(x: UOp, c: int, which: Literal[Ops.MOD, Ops.IDIV], split
 
   if gcd != 1: something_changed = True
   if not something_changed:
-    if which is Ops.IDIV and (1 < div < c) and (newx:=div_and_mod_folding(x, div, Ops.IDIV)) is not None: return newx//(c//div)
+    if which is Ops.IDIV and (1 < div < c) and (newx:=div_and_mod_folding(x, UOp.const(dtypes.int, div), Ops.IDIV)) is not None: return newx//(c//div)
     return None
   quo, rem = x.const_like(const//c), x.const_like((const%c)//gcd)
   for q,r,f,v in zip(quotients, remainders, factors, svars):
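The recursive call now wraps the integer div as UOp.const(dtypes.int, div) to match the UOp-typed denominator. The rewrite it feeds, newx//(c//div), leans on a standard floor-division identity for positive divisors, checked here by brute force in plain Python:

# (x // d) // e == x // (d * e) for all ints x and positive ints d, e
for x in range(-100, 100):
    for d in (2, 3, 5):
        for e in (2, 4, 7):
            assert (x // d) // e == x // (d * e)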
@@ -1259,10 +1258,10 @@ symbolic = symbolic_simple+PatternMatcher([
   # ** div **
   # div folding
   ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)), # (x//c+a)//d -> (x+a*c)//(c*d)
-  (UPat.var("x", dtypes.sints) // UPat.cvar("c", vec=False), lambda x,c: div_and_mod_folding(x,c.arg,Ops.IDIV) if 0 < c.arg else None),
+  (UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)),
   # ** mod **
   # mod folding
-  (UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: div_and_mod_folding(x,c.arg,Ops.MOD) if 0 < c.arg else None),
+  (UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
 ])
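With the constant/positivity guard moved inside div_and_mod_folding (the (y.op is not Ops.CONST) or ((c := y.arg) <= 0) check), both matcher rules shrink to matching an arbitrary denominator. A hedged usage sketch, assuming a tinygrad checkout with this change and the UOp.simplify()/render() helpers its unit tests rely on:

from tinygrad.ops import Variable

x = Variable("x", 20, 23)
y = Variable("y", 8, 10)
# a division by a *variable* denominator can now fold through the matcher
print((x // y).simplify().render())  # expected "2", per the new unit test
print((x % y).simplify().render())   # expected "(x+(y*-2))"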
@@ -4,20 +4,16 @@ from dataclasses import dataclass
 import functools
 from typing import Optional, Callable
 from tinygrad.helpers import merge_dicts, getenv
-from tinygrad.shape.view import View, strides_for_shape
+from tinygrad.shape.view import View, strides_for_shape, unravel
 from tinygrad.dtype import dtypes
-from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid
+from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid, sint_to_uop
 
 @functools.lru_cache(None)
 def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
   idx, valid = views[-1].to_indexed_uops(_idxs)
   for view in reversed(views[0:-1]):
     view = view.minify()
-    acc, idxs = 1, []
-    for d in reversed(view.shape):
-      idxs.append((idx//acc)%d)
-      acc *= d
-    idx, valid = view.to_indexed_uops(idxs[::-1], valid)
+    idx, valid = view.to_indexed_uops([sint_to_uop(i) for i in unravel(view.shape, idx)], valid)
   return idx, valid
 
 @functools.lru_cache(None)
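unravel(view.shape, idx) plays the role of numpy's unravel_index: it splits a flat index into per-dimension coordinates for the view's shape, which to_indexed_uops then consumes. A plain-Python sketch of the computation (unravel_py is a stand-in; numpy is used only to cross-check):

import numpy as np

def unravel_py(shape, offset):
    # coordinate of `offset` along each dimension, innermost dimension last
    acc, idxs = 1, []
    for d in reversed(shape):
        idxs.append((offset // acc) % d)
        acc *= d
    return idxs[::-1]

assert unravel_py((3, 4, 5), 37) == list(np.unravel_index(37, (3, 4, 5)))  # [1, 3, 2]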
@@ -73,11 +73,11 @@ def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple
 def unravel(shape:tuple[sint, ...], offset:sint) -> list[sint]:
   # find the position of offset on each dimension based on shape
   # similar to unravel_index in numpy/torch
-  ret = []
-  for stride in strides_for_shape(shape):
-    ret.append(offset // stride if stride != 0 else 0)
-    offset -= ret[-1] * stride
-  return ret
+  acc, idxs = 1, []
+  for d in reversed(shape):
+    idxs.append((offset//acc)%d)
+    acc *= d
+  return idxs[::-1]
 
 @dataclass(frozen=True)
 class View:
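The two unravel bodies agree for in-range offsets; the modulo form simply sidesteps the stride != 0 special case that size-1 (stride-0) dimensions force on the stride walk. A plain-Python equivalence check (strides_for_shape_py, unravel_strides, and unravel_mod are hypothetical stand-ins; the strides helper zeroes strides on size-1 dims, as tinygrad's strides_for_shape canonicalizes):

def strides_for_shape_py(shape):
    # C-contiguous strides, with stride 0 on size-1 dims
    acc, strides = 1, []
    for d in reversed(shape):
        strides.append(acc if d != 1 else 0)
        acc *= d
    return strides[::-1]

def unravel_strides(shape, offset):  # the old, stride-based body
    ret = []
    for stride in strides_for_shape_py(shape):
        ret.append(offset // stride if stride != 0 else 0)
        offset -= ret[-1] * stride
    return ret

def unravel_mod(shape, offset):      # the new, modulo-based body
    acc, idxs = 1, []
    for d in reversed(shape):
        idxs.append((offset // acc) % d)
        acc *= d
    return idxs[::-1]

shape = (2, 1, 3, 4)
for off in range(2 * 1 * 3 * 4):
    assert unravel_strides(shape, off) == unravel_mod(shape, off)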