From 15aed515448668c6b2c12b79dd73d44a647c2a7e Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 28 Jan 2026 20:10:11 -0500 Subject: [PATCH] return types for all math.py function (#14413) calling int() on sint -> int, i think it's better support since some UOp can be safely cast to int --- tinygrad/apps/llm.py | 2 +- tinygrad/codegen/late/devectorizer.py | 2 +- tinygrad/mixin/math.py | 20 ++++++++++---------- tinygrad/nn/onnx.py | 19 ++++++++++--------- tinygrad/renderer/cstyle.py | 3 ++- tinygrad/runtime/ops_amd.py | 2 +- tinygrad/runtime/ops_qcom.py | 4 ++-- tinygrad/schedule/indexing.py | 2 +- tinygrad/tensor.py | 19 ++++++++++--------- 9 files changed, 38 insertions(+), 35 deletions(-) diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py index c6eff5c5b7..604336a893 100644 --- a/tinygrad/apps/llm.py +++ b/tinygrad/apps/llm.py @@ -139,7 +139,7 @@ class TransformerBlock: v = self.cache_kv[1, :, :, 0:start_pos+T, :] # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_casual = True - mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if T > 1 else None + mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(int(start_pos)+1) if T > 1 else None attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd) attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D) attn = self.attn_output(attn) diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index 6faa693ee3..d10c37e176 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -201,7 +201,7 @@ def image_fixup(ls:UOp): x_mod_4 = x % 4 def sel(ret, i): return x_mod_4.ne(i).where(ret, vec_load.gep(i)) # if x is non-negative, x % 4 is in [0, 3] and we can skip NAN fallback - if x_mod_4.vmin >= 0: return functools.reduce(sel, range(x_mod_4.vmin+1, x_mod_4.vmax+1), vec_load.gep(x_mod_4.vmin)) + if x_mod_4.vmin >= 0: return functools.reduce(sel, range(int(x_mod_4.vmin)+1, int(x_mod_4.vmax)+1), vec_load.gep(int(x_mod_4.vmin))) return functools.reduce(sel, range(4), ls.const_like(float('nan'))) return None diff --git a/tinygrad/mixin/math.py b/tinygrad/mixin/math.py index 12c0456714..f38d6cc9d0 100644 --- a/tinygrad/mixin/math.py +++ b/tinygrad/mixin/math.py @@ -27,7 +27,7 @@ class MathMixin: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}") return self.logical_not() if dtype.scalar() == dtypes.bool else self * (-1) - def _check_dtype(self): + def _check_dtype(self) -> None: if (dtype := getattr(self, "dtype")) is not None: if isinstance(dtype, tuple): dtype = dtype[0] @@ -144,25 +144,25 @@ class MathMixin: def __neg__(self) -> Self: return self.neg() - def __add__(self, x: Self | ConstType): + def __add__(self, x: Self | ConstType) -> Self: return self.add(x) - def __sub__(self, x: Self | ConstType): + def __sub__(self, x: Self | ConstType) -> Self: return self.sub(x) - def __mul__(self, x: Self | ConstType): + def __mul__(self, x: Self | ConstType) -> Self: return self.mul(x) def __truediv__(self, x: Self | ConstType) -> Self: return self.div(x) - def __floordiv__(self, x: Self | ConstType): + def __floordiv__(self, x: Self | ConstType) -> Self: return self.idiv(x) # TODO: idiv is trunc div, not floordiv - def __mod__(self, x: Self | ConstType): + def __mod__(self, x: Self | ConstType) -> Self: return self.mod(x) - def __and__(self, x: Self | ConstType): + def __and__(self, x: Self | ConstType) -> Self: return self.bitwise_and(x) def __or__(self, x: Self | ConstType) -> Self: @@ -174,16 +174,16 @@ class MathMixin: def __radd__(self, x: Self | ConstType) -> Self: return self.add(x, True) - def __rsub__(self, x: Self | ConstType): + def __rsub__(self, x: Self | ConstType) -> Self: return self.sub(x, True) - def __rmul__(self, x: Self | ConstType): + def __rmul__(self, x: Self | ConstType) -> Self: return self.mul(x, True) def __rtruediv__(self, x: Self | ConstType) -> Self: return self.div(x, True) - def __rfloordiv__(self, x: Self | ConstType): + def __rfloordiv__(self, x: Self | ConstType) -> Self: return self.idiv(x, True) def __rand__(self, x: Self | ConstType) -> Self: diff --git a/tinygrad/nn/onnx.py b/tinygrad/nn/onnx.py index 89db73d3f7..164a04c85a 100644 --- a/tinygrad/nn/onnx.py +++ b/tinygrad/nn/onnx.py @@ -6,6 +6,7 @@ from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr from tinygrad.helpers import getenv, all_same, prod, flatten, make_tuple, argsort, is_numpy_ndarray, get_single_element, polyN from tinygrad.dtype import DType, ConstType, dtypes, _from_np_dtype, truncate, least_upper_dtype, DTYPES_DICT from tinygrad.device import is_dtype_supported, Device +from tinygrad.uop.ops import sint # ***** protobuf definitions ****** class WireType(enum.IntEnum): @@ -676,7 +677,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log() def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0): - if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64) + if select_last_index: return ((int(x.shape[axis])-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64) return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64) def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index) @@ -703,7 +704,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT return data[tuple(slices)] def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0): - sz = data.shape[axis] + sz = int(data.shape[axis]) if split is None: split = [sz // num_outputs + (1 if i < sz % num_outputs else 0) for i in range(num_outputs)] return data.split(split, axis) @@ -716,8 +717,8 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value) def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None): - shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim - pad_arg:list[None|tuple[int,int]] = [None] * t.ndim + shrink_arg:list[None|tuple[sint,sint]] = [None] * t.ndim + pad_arg:list[None|tuple[sint,sint]] = [None] * t.ndim for s, x in zip(shape, axes or range(t.ndim)): tx = t.shape[x] if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2) @@ -751,8 +752,8 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT pads = _auto_pad([s_*(i-1) + op_ + ((k_-1)*d_+1) - os for s_,i,op_,k_,d_,os in zip(strides_, input_shape_, output_padding_, kernel_shape_, dilations_, output_shape)], auto_pad) if pads is None: # we generate pads - output_shape = output_shape or [X.shape[i+2] * strides_[i] for i in range(len(strides_))] - pads = [strides_[i]*(input_shape_[i]-1)+output_padding_[i]+((kernel_shape_[i]-1)*dilations_[i]+1)-output_shape[i] + output_shape = output_shape or [int(X.shape[i+2]) * strides_[i] for i in range(len(strides_))] + pads = [int(strides_[i]*(input_shape_[i]-1)+output_padding_[i]+((kernel_shape_[i]-1)*dilations_[i]+1)-output_shape[i]) for i in range(len(input_shape_))] pads = _auto_pad(pads, auto_pad) if auto_pad != "NOTSET" else [0] * len(input_shape_) * 2 pads = _onnx_pads_to_tiny_pads(pads) @@ -1014,7 +1015,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT num_heads:int|None=None, past_present_share_buffer:int|None=None, qkv_hidden_sizes:list[int]|None=None, rotary_embedding_dim:int|None=None, scale:float|None=None, unidirectional:int=0): assert not do_rotary and not attention_bias, "TODO" - if qkv_hidden_sizes is None: qkv_hidden_sizes = [weights.shape[1] // 3] * 3 + if qkv_hidden_sizes is None: qkv_hidden_sizes = [int(weights.shape[1] // 3)] * 3 qkv = x.linear(weights, bias) q, k, v = qkv.split(qkv_hidden_sizes, dim=2) @@ -1112,7 +1113,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT if X.ndim == 4: X = X.permute(0, 2, 1, 3) elif X.ndim == 3: assert num_heads is not None, "num_heads must be provided for 3D input" - X = X.unflatten(-1, (num_heads, X.shape[-1] // num_heads)) + X = X.unflatten(-1, (num_heads, int(X.shape[-1]) // num_heads)) head_size = cast(int, X.shape[-1]) rot_dim = rotary_embedding_dim or head_size @@ -1182,7 +1183,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT def TensorScatter(data: Tensor, updates: Tensor, indices: Tensor, mode: str = 'default'): # scatter updates along axis -2 at positions given by indices, for each batch - B, U, D = indices.shape[0], updates.shape[-2], data.shape[-2] + B, U, D = indices.shape[0], updates.shape[-2], int(data.shape[-2]) orig_shape, data_flat, updates_flat = data.shape, data.reshape(-1, D, data.shape[-1]), updates.reshape(-1, U, updates.shape[-1]) B_total = data_flat.shape[0] batch_idx = Tensor.arange(B_total, device=data.device).reshape(B_total, 1).expand(B_total, U) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 4c3c861ada..6d76411c61 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -87,7 +87,8 @@ def create_non_native_float_pats(dts:tuple[DType, ...], casting:bool=True): def cast_float_to_bf16(x: UOp) -> UOp: assert x.dtype == dtypes.float, "cast float -> bf16 must start with float" x = x.bitcast(dtypes.uint) - x = ((-x & 0x7f800000) != 0).where(x + ((x >> 16) & 1) + 0x7fff, ((x & 0xffff) != 0).where((x | 0x10000), x)) + # NOTE: != returns UOp, not bool, issue with mypy + x = ((-x & 0x7f800000) != 0).where(x + ((x >> 16) & 1) + 0x7fff, ((x & 0xffff) != 0).where((x | 0x10000), x)) # type: ignore[comparison-overlap] return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16) # manual bfloat16 casting patterns (shared between LLVM, Clang, and AMD renderers to avoid compiler intrinsics) diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 5df0fb22a9..f7f106af47 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -753,7 +753,7 @@ class KFDIface: def as_dmaref(self, mem:HCQBuffer) -> DMAFdRef: base = mem._base if mem._base is not None else mem - dmaref = DMAFdRef(kfd.AMDKFD_IOC_EXPORT_DMABUF(KFDIface.kfd, handle=base.meta.handle, flags=0).dmabuf_fd, mem.va_addr-base.va_addr, mem.size) + dmaref = DMAFdRef(kfd.AMDKFD_IOC_EXPORT_DMABUF(KFDIface.kfd, handle=base.meta.handle, flags=0).dmabuf_fd, int(mem.va_addr-base.va_addr), mem.size) weakref.finalize(dmaref, os.close, dmaref.fd) return dmaref diff --git a/tinygrad/runtime/ops_qcom.py b/tinygrad/runtime/ops_qcom.py index 916a8f2546..02ab87bc8b 100644 --- a/tinygrad/runtime/ops_qcom.py +++ b/tinygrad/runtime/ops_qcom.py @@ -143,7 +143,7 @@ class QCOMComputeQueue(HWQueue): qreg.a6xx_sp_cs_pvt_mem_param(memsizeperitem=prg.pvtmem_size_per_item), *data64_le(prg.dev._stack.va_addr), qreg.a6xx_sp_cs_pvt_mem_size(totalpvtmemsize=prg.pvtmem_size_total)) - if prg.NIR and prg.wgsz != 0xfc: to_mv(args_state.buf.va_addr + prg.wgsz * 4, 12)[:] = struct.pack("III", *local_size) + if prg.NIR and prg.wgsz != 0xfc: to_mv(int(args_state.buf.va_addr) + prg.wgsz * 4, 12)[:] = struct.pack("III", *local_size) self.cmd(mesa.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=mesa.ST_CONSTANTS, state_src=mesa.SS6_INDIRECT, state_block=mesa.SB6_CS_SHADER, num_unit=1024 // 4), *data64_le(args_state.buf.va_addr)) @@ -199,7 +199,7 @@ class QCOMArgsState(HCQArgsState): ibos, texs = uavs[:prg.ibo_cnt], uavs[prg.ibo_cnt:] for cnst_val,cnst_off,cnst_sz in prg.consts_info: to_mv(self.buf.va_addr + cnst_off, cnst_sz)[:] = cnst_val.to_bytes(cnst_sz, byteorder='little') - if prg.samp_cnt > 0: to_mv(self.buf.va_addr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers) + if prg.samp_cnt > 0: to_mv(int(self.buf.va_addr) + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers) if prg.NIR: self.bind_sints_to_buf(*[b.va_addr for b in ubos], buf=self.buf, fmt='Q', offset=prg.buf_off) self.bind_sints_to_buf(*vals, buf=self.buf, fmt='I', offset=prg.buf_off + len(ubos) * 8) diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index e7bbf33db4..f362e6151b 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -122,7 +122,7 @@ pm_apply_rangeify = PatternMatcher([ @functools.cache def _apply_reshape(in_shape:tuple[sint,...], out_shape:tuple[sint, ...], urngs:UOp) -> UOp: - acc = 1 + acc:sint = 1 axes_in:list[UOp] = [] for s,src in list(zip(out_shape, urngs.src))[::-1]: axes_in.append(acc*src) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 07fb68a05e..2e664f5484 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2247,7 +2247,7 @@ class Tensor(OpMixin): if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation) pooled = self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation) if not return_indices: return pooled.max(axis) - spatial_sz = math.prod(spatial_shape := self.shape[-len(k_):]) + spatial_sz = int(math.prod(spatial_shape := self.shape[-len(k_):])) idx = Tensor.arange(spatial_sz,0,-1, requires_grad=False, device=self.device).reshape(spatial_shape) m = pooled == pooled.max(axis, keepdim=True) idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation) @@ -2504,7 +2504,7 @@ class Tensor(OpMixin): ``` """ if self.ndim == 0: return self._split_cumalu(axis, Ops.MAX), Tensor.zeros(self.shape, dtype=dtypes.int32, device=self.device) - values, n = self._split_cumalu(axis, Ops.MAX), self.shape[axis] + values, n = self._split_cumalu(axis, Ops.MAX), int(self.shape[axis]) x, values_t = self.transpose(axis, -1), values.transpose(axis, -1) match = (x.unsqueeze(-1) == values_t.unsqueeze(-2)) * Tensor.ones(n, n, requires_grad=False, device=self.device).triu() idx = (-(match * Tensor.arange(n, 0, -1, requires_grad=False, device=self.device).reshape(n, 1)).max(-2) + n).cast(dtypes.int32) @@ -2595,7 +2595,7 @@ class Tensor(OpMixin): assert not (align_corners and mode != "linear"), "align_corners option can only be set with the interpolating mode linear" x, expand = self, list(self.shape) for i in range(-1,-len(size)-1,-1): - scale = (self.shape[i] - int(align_corners)) / (size[i] - int(align_corners)) + scale = (int(self.shape[i]) - int(align_corners)) / (size[i] - int(align_corners)) arr, reshape = Tensor.arange(size[i], dtype=dtypes.float32, device=self.device), [1] * self.ndim reshape[i] = expand[i] = size[i] if mode == "linear": @@ -2716,7 +2716,7 @@ class Tensor(OpMixin): ``` """ x, dim = self, self._resolve_dim(dim) - if (orig_len:= x.shape[dim]) <= 1: return x, x.zeros_like(dtype=dtypes.default_int) + if (orig_len := int(x.shape[dim])) <= 1: return x, x.zeros_like(dtype=dtypes.default_int) # pad to power of 2 n_stages = (orig_len-1).bit_length() pads = tuple((0, 2**n_stages - orig_len) if i == dim else None for i in range(x.ndim)) @@ -3543,7 +3543,7 @@ class Tensor(OpMixin): ``` """ if not dtypes.is_int(self.dtype): raise RuntimeError(f"expect integer dtype, getting {self.dtype=}") - if num_classes == -1: num_classes = (self.max()+1).item() + if num_classes == -1: num_classes = int((self.max()+1).item()) return self[..., None]._one_hot_along_dim(num_classes).where(1, 0) def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0, @@ -3570,8 +3570,8 @@ class Tensor(OpMixin): # GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html if enable_gqa: - key = key.repeat_interleave(self.shape[-3] // key.shape[-3], dim=-3) - value = value.repeat_interleave(self.shape[-3] // value.shape[-3], dim=-3) + key = key.repeat_interleave(int(self.shape[-3] // key.shape[-3]), dim=-3) + value = value.repeat_interleave(int(self.shape[-3] // value.shape[-3]), dim=-3) q = self qk = q.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(q.dtype, key.dtype, dtypes.float32)) / math.sqrt(q.shape[-1]) @@ -3706,7 +3706,8 @@ class Tensor(OpMixin): assert self.ndim > 1, "NS only works for two or more dims" if self.shape[-2] > self.shape[-1]: return self.transpose(-2, -1).newton_schulz(steps, params, eps).transpose(-2, -1) G = self / (self.square().sum(axis=(-2, -1), keepdim=True).sqrt() + eps) - for _ in range(steps): G = sum(p * functools.reduce(lambda x, y: (y @ y.transpose(-2, -1)) @ x, [G]*i, G) for i,p in enumerate(params)) + for _ in range(steps): + G = cast(Tensor, sum(p * functools.reduce(lambda x, y: (y @ y.transpose(-2, -1)) @ x, [G]*i, G) for i,p in enumerate(params))) return G def qr(self) -> tuple[Tensor, Tensor]: @@ -3801,7 +3802,7 @@ class Tensor(OpMixin): print(t.nbytes()) ``` """ - return self.numel() * self.element_size() + return int(self.numel()) * self.element_size() def is_floating_point(self) -> bool: """