diff --git a/tinygrad/uop/decompositions.py b/tinygrad/uop/decompositions.py index aaa4e8472c..54985c02be 100644 --- a/tinygrad/uop/decompositions.py +++ b/tinygrad/uop/decompositions.py @@ -282,7 +282,7 @@ def magicgu(vmax:int, d:int) -> tuple[int,int]: def fast_idiv(device: str, x: UOp, d: int, dont_cast=False) -> UOp|None: # If d is a power of two this is not valid for signed ints! - is_unsigned = True if x.vmin>=0 or x.dtype in dtypes.uints else False + is_unsigned = x.vmin>=0 or x.dtype in dtypes.uints assert d>0, "Sign should have been taken out of divisor" vmin,vmax = max(x.vmin, x.dtype.min), min(x.vmax, x.dtype.max) m,s = magicgu(max(vmax, abs(vmin)), d) @@ -293,7 +293,7 @@ def fast_idiv(device: str, x: UOp, d: int, dont_cast=False) -> UOp|None: if (ret:=fast_idiv(device, x//largest_factor_of_two_in_d, d//largest_factor_of_two_in_d, dont_cast=True)) is not None: return ret if dont_cast: return None # promo_lattice needs to return an unsigned type if the type is unsigned - if dtypes.is_int(next_dtype := promo_lattice[x.dtype.scalar()][-1]) and is_dtype_supported(next_dtype, None if device=='' else device): + if dtypes.is_int(next_dtype := promo_lattice[x.dtype.scalar()][-1]) and is_dtype_supported(next_dtype, device): if m*vmin >= dtypes.min(next_dtype) and m*vmax <= dtypes.max(next_dtype): return ((x.cast(next_dtype)*m) >> s).cast(x.dtype) if is_unsigned else ((x.cast(next_dtype)*m) >> s).cast(x.dtype) + (x<0).where(x.ufix(1), 0) return None