mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-14 09:28:04 -05:00
disable fast_idiv on metal (#14137)
There is a Metal compiler bug, which was the root cause of keccak needing a contiguous hack.
This commit is contained in:
4
test/external/fuzz_fast_idiv.py
vendored
4
test/external/fuzz_fast_idiv.py
vendored
@@ -1,6 +1,6 @@
|
|||||||
import random
|
import random
|
||||||
import z3
|
import z3
|
||||||
from tinygrad import dtypes
|
from tinygrad import dtypes, Device
|
||||||
from tinygrad.uop.validate import uops_to_z3, z3_cdiv
|
from tinygrad.uop.validate import uops_to_z3, z3_cdiv
|
||||||
from tinygrad.uop.ops import UOp
|
from tinygrad.uop.ops import UOp
|
||||||
from tinygrad.uop.decompositions import fast_idiv
|
from tinygrad.uop.decompositions import fast_idiv
|
||||||
@@ -15,7 +15,7 @@ if __name__ == "__main__":
|
|||||||
u = UOp.variable('x', random.randint(dt.min, 0), random.randint(1, dt.max), dtype=dt)
|
u = UOp.variable('x', random.randint(dt.min, 0), random.randint(1, dt.max), dtype=dt)
|
||||||
d = random.randint(1, max(1, u.arg[2]))
|
d = random.randint(1, max(1, u.arg[2]))
|
||||||
if d in powers_of_two: continue
|
if d in powers_of_two: continue
|
||||||
expr = fast_idiv(None, u, d)
|
expr = fast_idiv(Device.DEFAULT, u, d)
|
||||||
if expr is None: continue
|
if expr is None: continue
|
||||||
|
|
||||||
solver = z3.Solver()
|
solver = z3.Solver()
|
||||||
|
|||||||
@@ -1881,7 +1881,6 @@ class Tensor(OpMixin):
|
|||||||
p = state.reshape(bs, 5, 5).transpose(2, 1)
|
p = state.reshape(bs, 5, 5).transpose(2, 1)
|
||||||
t1 = (p[:,:,0] ^ p[:,:,1] ^ p[:,:,2] ^ p[:,:,3] ^ p[:,:,4]).roll(-1, 1) # xor reduce
|
t1 = (p[:,:,0] ^ p[:,:,1] ^ p[:,:,2] ^ p[:,:,3] ^ p[:,:,4]).roll(-1, 1) # xor reduce
|
||||||
state = state ^ (t1.roll(2, 1).bitwise_xor((t1 << 1) ^ (t1 >> 63)).unsqueeze(2).expand(bs, 5, 5).transpose(2, 1).flatten(1))
|
state = state ^ (t1.roll(2, 1).bitwise_xor((t1 << 1) ^ (t1 >> 63)).unsqueeze(2).expand(bs, 5, 5).transpose(2, 1).flatten(1))
|
||||||
state = state.contiguous() # required for correct indexing in π step # TODO: why is it needed?
|
|
||||||
# ρ and π steps
|
# ρ and π steps
|
||||||
state = state[:, reorder_indexes]
|
state = state[:, reorder_indexes]
|
||||||
state = (state * rot_offsets_v0).bitwise_or(state // rot_offsets_v1).reshape(bs, 5, 5)
|
state = (state * rot_offsets_v0).bitwise_or(state // rot_offsets_v1).reshape(bs, 5, 5)
|
||||||
|
|||||||
@@ -279,6 +279,8 @@ def magicgu(vmax:int, d:int) -> tuple[int,int]:
|
|||||||
assert False
|
assert False
|
||||||
|
|
||||||
def fast_idiv(device: str, x: UOp, d: int, dont_cast=False) -> UOp|None:
|
def fast_idiv(device: str, x: UOp, d: int, dont_cast=False) -> UOp|None:
|
||||||
|
# NOTE: disable for METAL due to compiler bug. keccak with -O0 works but not with optimization
|
||||||
|
if device.startswith("METAL"): return None
|
||||||
# If d is a power of two this is not valid for signed ints!
|
# If d is a power of two this is not valid for signed ints!
|
||||||
is_unsigned = x.vmin>=0 or x.dtype in dtypes.uints
|
is_unsigned = x.vmin>=0 or x.dtype in dtypes.uints
|
||||||
assert d>0, "Sign should have been taken out of divisor"
|
assert d>0, "Sign should have been taken out of divisor"
|
||||||
|
|||||||
Reference in New Issue
Block a user