mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fix vec dtype in fast idiv (#12080)
* fix * add vec dtypes to fuzzer * add vec=False --------- Co-authored-by: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com>
This commit is contained in:
2
test/external/fuzz_fast_idiv.py
vendored
2
test/external/fuzz_fast_idiv.py
vendored
@@ -11,7 +11,7 @@ if __name__ == "__main__":
|
||||
for i in range(10_000):
|
||||
if i % 1000 == 0:
|
||||
print(f"Progress: {i}")
|
||||
dt = random.choice(dtypes.ints)
|
||||
dt = random.choice(dtypes.ints + tuple(dt.vec(4) for dt in dtypes.ints))
|
||||
u = UOp.variable('x', random.randint(dt.min, 0), random.randint(1, dt.max), dtype=dt)
|
||||
d = random.randint(1, max(1, u.arg[2]))
|
||||
if d in powers_of_two: continue
|
||||
|
||||
@@ -293,7 +293,7 @@ def fast_idiv(device: str, x: UOp, d: int, dont_cast=False) -> UOp|None:
|
||||
if (ret:=fast_idiv(device, x//largest_factor_of_two_in_d, d//largest_factor_of_two_in_d, dont_cast=True)) is not None: return ret
|
||||
if dont_cast: return None
|
||||
# promo_lattice needs to return an unsigned type if the type is unsigned
|
||||
if dtypes.is_int(next_dtype := promo_lattice[x.dtype][-1]) and is_dtype_supported(next_dtype, None if device=='' else device):
|
||||
if dtypes.is_int(next_dtype := promo_lattice[x.dtype.scalar()][-1]) and is_dtype_supported(next_dtype, None if device=='' else device):
|
||||
if m*vmin >= dtypes.min(next_dtype) and m*vmax <= dtypes.max(next_dtype):
|
||||
return ((x.cast(next_dtype)*m) >> s).cast(x.dtype) if is_unsigned else ((x.cast(next_dtype)*m) >> s).cast(x.dtype) + (x<0).where(x.ufix(1), 0)
|
||||
return None
|
||||
@@ -343,7 +343,7 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], force_transcendental=False):
|
||||
pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: (x+(l.const_like(l.vmin) if (l:=(x<0)).vmin==l.vmax else l).where(
|
||||
c-1, 0)) >> v if (v:=powers_of_two.get(c.arg, 0)) else None)] # (x+(x<0).where(c-1, 0)) >> v
|
||||
if not DISABLE_FAST_IDIV:
|
||||
pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d"), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))]
|
||||
pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d", vec=False), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))]
|
||||
pat += [(UPat.var("x", dtypes.ints)%UPat.var("d"), lambda x, d: x-d*(x//d))]
|
||||
if Ops.NEG in ops:
|
||||
pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
|
||||
|
||||
Reference in New Issue
Block a user