mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add test, fix rewrite rule and raise error on division by zero (#11073)
This commit is contained in:
@@ -169,6 +169,12 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(Variable("a", 1, 7) // -2, -3, 0, "((a//2)*-1)")
|
||||
self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "((a//2)*-1)")
|
||||
|
||||
def test_div_mod_zero(self):
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
(Variable("a", 0, 7) // 0).simplify()
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
(Variable("a", 0, 7) % 0).simplify()
|
||||
|
||||
def test_sum_div_remove(self):
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 20, 0, 0, "0")
|
||||
|
||||
|
||||
@@ -131,7 +131,8 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
|
||||
if y_min*y_max > 0 and (q:=cdiv(x_min,y_min)) == cdiv(x_min,y_max) == cdiv(x_max,y_min) == cdiv(x_max,y_max):
|
||||
return x - q*y if which is Ops.MOD else x.const_like(q)
|
||||
|
||||
if (y.op is not Ops.CONST) or ((c := y.arg) <= 0) or (x.dtype.count > 1): return None
|
||||
if (y.op is not Ops.CONST) or ((c := y.arg) < 0) or (x.dtype.count > 1): return None
|
||||
if y.arg == 0: raise ZeroDivisionError(f"{'Division' if which is Ops.IDIV else 'Mod'} by zero trying to rewrite {x.alu(which, y)}")
|
||||
|
||||
svars, factors, quotients, remainders, gcd, div, const, something_changed = [], [], [], [], c, 1, 0, False
|
||||
for u in split_uop(x, Ops.ADD):
|
||||
@@ -285,7 +286,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)
|
||||
if c.vmin>0 and d.vmin>0 and ((x.vmin>=0 and a.vmin>=0) or (x.vmax<=0 and a.vmax<=0)) else None), # (x//c+a)//d -> (x+a*c)//(c*d)
|
||||
(UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)),
|
||||
(UPat.var("x") // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax <=0 else None),
|
||||
(UPat.var("x") // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None),
|
||||
(UPat.var("x") // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <=0 else None),
|
||||
((UPat.var("x", dtypes.sints)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
|
||||
lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),
|
||||
|
||||
Reference in New Issue
Block a user