* Enable fast mod

* Add test
This commit is contained in:
Sieds Lykles
2025-05-05 18:15:43 +02:00
committed by GitHub
parent 363481e2fb
commit 338f33efae
2 changed files with 9 additions and 3 deletions

View File

@@ -359,7 +359,7 @@ class TestAssembly(unittest.TestCase):
self.assertIn(Ops.SHR, ops)
self.assertNotIn(Ops.IDIV, ops)
def test_fast_idiv(self):
def test_fast_idiv_and_mod(self):
g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0)
c = UOp(Ops.CONST, dtypes.uint, (), 3)
l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
@@ -370,6 +370,13 @@ class TestAssembly(unittest.TestCase):
self.assertIn(Ops.SHR, ops)
self.assertNotIn(Ops.IDIV, ops)
b = UOp(Ops.MOD, dtypes.uint, (l, c))
uops = to_uops_list([b], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render(uops)
ops = [x.op for x in uops]
self.assertIn(Ops.SHR, ops)
self.assertNotIn(Ops.MOD, ops)
@unittest.expectedFailure
def test_fast_idiv_overflow(self):
# This will be possible with a slightly different method for fast_idiv

View File

@@ -180,8 +180,7 @@ def get_late_rewrite_patterns(ops, force_transcendental=False):
pat += [(UPat.var("x", dtypes.sints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) and resolve(x>=0,False) else None)]
if not getenv("DISABLE_FAST_IDIV"):
pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d"), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))]
# TODO: This breaks validate_index because of the way _min_max is calculated on uops
# pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("d"), lambda x, d: x - d*f if (f:=fast_idiv(x, d.arg)) is not None else None)]
pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("d"), lambda ctx, x, d: x - d*f if (f:=fast_idiv(ctx, x, d.arg)) is not None else None)]
if Ops.NEG in ops:
pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]