mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
take gcd out of trunc div (#10238)
This commit is contained in:
@@ -345,9 +345,21 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_mul_div_factor_mul(self):
|
||||
self.helper_test_variable((Variable("a", 0, 10)*8)//4, 0, 20, "(a*2)")
|
||||
|
||||
def test_mul_div_factor_mul_neg(self):
|
||||
self.helper_test_variable((Variable("a", 0, 10)*-8+16)//4, -16, 4, "((a*-2)+4)")
|
||||
|
||||
def test_mul_div_factor_div(self):
|
||||
self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)")
|
||||
|
||||
def test_mul_div_factor_div_neg(self):
|
||||
self.helper_test_variable((Variable("a", 0, 10)*-4+4)//8, -4, 0, "(((a*-1)+1)//2)")
|
||||
|
||||
def test_mod_gcd_factor_neg(self):
|
||||
self.helper_test_variable((Variable("a", 0, 10)*-4+4)%8, -4, 4, "((((a*-1)+1)%2)*4)")
|
||||
|
||||
def test_mod_gcd_fold_neg(self):
|
||||
self.helper_test_variable((Variable("a", 0, 10)*-8+20)%4, 0, 0, "0")
|
||||
|
||||
def test_sum_div_partial_remove(self):
|
||||
self.helper_test_variable(usum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
|
||||
|
||||
|
||||
@@ -156,6 +156,11 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
|
||||
if which is Ops.MOD: return rem - rem.vmin//c*c
|
||||
return sum((f-r)//c * v for f,r,v in zip(factors,rems,svars)) + (const-const%c+rem.vmin//c*c)//c
|
||||
|
||||
if math.gcd(gcd, const)!=1:
|
||||
gcd = math.gcd(gcd, const)
|
||||
ret = UOp(which, x.dtype, src=(sum(f//gcd * v for f,v in zip(factors, svars)) + const//gcd, x.const_like(c//gcd)))
|
||||
return ret*gcd if which is Ops.MOD else ret
|
||||
|
||||
if gcd != 1: something_changed = True
|
||||
if not something_changed:
|
||||
if which is Ops.IDIV and (1 < div < c) and (newx:=div_and_mod_folding(x, x.const_like(div), Ops.IDIV)) is not None: return newx//(c//div)
|
||||
|
||||
Reference in New Issue
Block a user