mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-14 01:18:26 -05:00
disable fast_idiv on metal (#14137)
there's a metal compiler bug which was the root cause that keccak needs a contigous hack
This commit is contained in:
4
test/external/fuzz_fast_idiv.py
vendored
4
test/external/fuzz_fast_idiv.py
vendored
@@ -1,6 +1,6 @@
|
||||
import random
|
||||
import z3
|
||||
from tinygrad import dtypes
|
||||
from tinygrad import dtypes, Device
|
||||
from tinygrad.uop.validate import uops_to_z3, z3_cdiv
|
||||
from tinygrad.uop.ops import UOp
|
||||
from tinygrad.uop.decompositions import fast_idiv
|
||||
@@ -15,7 +15,7 @@ if __name__ == "__main__":
|
||||
u = UOp.variable('x', random.randint(dt.min, 0), random.randint(1, dt.max), dtype=dt)
|
||||
d = random.randint(1, max(1, u.arg[2]))
|
||||
if d in powers_of_two: continue
|
||||
expr = fast_idiv(None, u, d)
|
||||
expr = fast_idiv(Device.DEFAULT, u, d)
|
||||
if expr is None: continue
|
||||
|
||||
solver = z3.Solver()
|
||||
|
||||
@@ -1881,7 +1881,6 @@ class Tensor(OpMixin):
|
||||
p = state.reshape(bs, 5, 5).transpose(2, 1)
|
||||
t1 = (p[:,:,0] ^ p[:,:,1] ^ p[:,:,2] ^ p[:,:,3] ^ p[:,:,4]).roll(-1, 1) # xor reduce
|
||||
state = state ^ (t1.roll(2, 1).bitwise_xor((t1 << 1) ^ (t1 >> 63)).unsqueeze(2).expand(bs, 5, 5).transpose(2, 1).flatten(1))
|
||||
state = state.contiguous() # required for correct indexing in π step # TODO: why is it needed?
|
||||
# ρ and π steps
|
||||
state = state[:, reorder_indexes]
|
||||
state = (state * rot_offsets_v0).bitwise_or(state // rot_offsets_v1).reshape(bs, 5, 5)
|
||||
|
||||
@@ -279,6 +279,8 @@ def magicgu(vmax:int, d:int) -> tuple[int,int]:
|
||||
assert False
|
||||
|
||||
def fast_idiv(device: str, x: UOp, d: int, dont_cast=False) -> UOp|None:
|
||||
# NOTE: disable for METAL due to compiler bug. keccak with -O0 works but not with optimization
|
||||
if device.startswith("METAL"): return None
|
||||
# If d is a power of two this is not valid for signed ints!
|
||||
is_unsigned = x.vmin>=0 or x.dtype in dtypes.uints
|
||||
assert d>0, "Sign should have been taken out of divisor"
|
||||
|
||||
Reference in New Issue
Block a user