mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
return types for all math.py function (#14413)
calling int() on sint -> int, i think it's better support since some UOp can be safely cast to int
This commit is contained in:
@@ -139,7 +139,7 @@ class TransformerBlock:
|
||||
v = self.cache_kv[1, :, :, 0:start_pos+T, :]
|
||||
|
||||
# NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_casual = True
|
||||
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if T > 1 else None
|
||||
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(int(start_pos)+1) if T > 1 else None
|
||||
attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd)
|
||||
attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D)
|
||||
attn = self.attn_output(attn)
|
||||
|
||||
@@ -201,7 +201,7 @@ def image_fixup(ls:UOp):
|
||||
x_mod_4 = x % 4
|
||||
def sel(ret, i): return x_mod_4.ne(i).where(ret, vec_load.gep(i))
|
||||
# if x is non-negative, x % 4 is in [0, 3] and we can skip NAN fallback
|
||||
if x_mod_4.vmin >= 0: return functools.reduce(sel, range(x_mod_4.vmin+1, x_mod_4.vmax+1), vec_load.gep(x_mod_4.vmin))
|
||||
if x_mod_4.vmin >= 0: return functools.reduce(sel, range(int(x_mod_4.vmin)+1, int(x_mod_4.vmax)+1), vec_load.gep(int(x_mod_4.vmin)))
|
||||
return functools.reduce(sel, range(4), ls.const_like(float('nan')))
|
||||
|
||||
return None
|
||||
|
||||
@@ -27,7 +27,7 @@ class MathMixin:
|
||||
raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
|
||||
return self.logical_not() if dtype.scalar() == dtypes.bool else self * (-1)
|
||||
|
||||
def _check_dtype(self):
|
||||
def _check_dtype(self) -> None:
|
||||
if (dtype := getattr(self, "dtype")) is not None:
|
||||
if isinstance(dtype, tuple):
|
||||
dtype = dtype[0]
|
||||
@@ -144,25 +144,25 @@ class MathMixin:
|
||||
def __neg__(self) -> Self:
|
||||
return self.neg()
|
||||
|
||||
def __add__(self, x: Self | ConstType):
|
||||
def __add__(self, x: Self | ConstType) -> Self:
|
||||
return self.add(x)
|
||||
|
||||
def __sub__(self, x: Self | ConstType):
|
||||
def __sub__(self, x: Self | ConstType) -> Self:
|
||||
return self.sub(x)
|
||||
|
||||
def __mul__(self, x: Self | ConstType):
|
||||
def __mul__(self, x: Self | ConstType) -> Self:
|
||||
return self.mul(x)
|
||||
|
||||
def __truediv__(self, x: Self | ConstType) -> Self:
|
||||
return self.div(x)
|
||||
|
||||
def __floordiv__(self, x: Self | ConstType):
|
||||
def __floordiv__(self, x: Self | ConstType) -> Self:
|
||||
return self.idiv(x) # TODO: idiv is trunc div, not floordiv
|
||||
|
||||
def __mod__(self, x: Self | ConstType):
|
||||
def __mod__(self, x: Self | ConstType) -> Self:
|
||||
return self.mod(x)
|
||||
|
||||
def __and__(self, x: Self | ConstType):
|
||||
def __and__(self, x: Self | ConstType) -> Self:
|
||||
return self.bitwise_and(x)
|
||||
|
||||
def __or__(self, x: Self | ConstType) -> Self:
|
||||
@@ -174,16 +174,16 @@ class MathMixin:
|
||||
def __radd__(self, x: Self | ConstType) -> Self:
|
||||
return self.add(x, True)
|
||||
|
||||
def __rsub__(self, x: Self | ConstType):
|
||||
def __rsub__(self, x: Self | ConstType) -> Self:
|
||||
return self.sub(x, True)
|
||||
|
||||
def __rmul__(self, x: Self | ConstType):
|
||||
def __rmul__(self, x: Self | ConstType) -> Self:
|
||||
return self.mul(x, True)
|
||||
|
||||
def __rtruediv__(self, x: Self | ConstType) -> Self:
|
||||
return self.div(x, True)
|
||||
|
||||
def __rfloordiv__(self, x: Self | ConstType):
|
||||
def __rfloordiv__(self, x: Self | ConstType) -> Self:
|
||||
return self.idiv(x, True)
|
||||
|
||||
def __rand__(self, x: Self | ConstType) -> Self:
|
||||
|
||||
@@ -6,6 +6,7 @@ from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr
|
||||
from tinygrad.helpers import getenv, all_same, prod, flatten, make_tuple, argsort, is_numpy_ndarray, get_single_element, polyN
|
||||
from tinygrad.dtype import DType, ConstType, dtypes, _from_np_dtype, truncate, least_upper_dtype, DTYPES_DICT
|
||||
from tinygrad.device import is_dtype_supported, Device
|
||||
from tinygrad.uop.ops import sint
|
||||
|
||||
# ***** protobuf definitions ******
|
||||
class WireType(enum.IntEnum):
|
||||
@@ -676,7 +677,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
|
||||
return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log()
|
||||
def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0):
|
||||
if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
|
||||
if select_last_index: return ((int(x.shape[axis])-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
|
||||
return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
|
||||
def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0):
|
||||
return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
|
||||
@@ -703,7 +704,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
return data[tuple(slices)]
|
||||
|
||||
def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0):
|
||||
sz = data.shape[axis]
|
||||
sz = int(data.shape[axis])
|
||||
if split is None: split = [sz // num_outputs + (1 if i < sz % num_outputs else 0) for i in range(num_outputs)]
|
||||
return data.split(split, axis)
|
||||
|
||||
@@ -716,8 +717,8 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value)
|
||||
|
||||
def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None):
|
||||
shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim
|
||||
pad_arg:list[None|tuple[int,int]] = [None] * t.ndim
|
||||
shrink_arg:list[None|tuple[sint,sint]] = [None] * t.ndim
|
||||
pad_arg:list[None|tuple[sint,sint]] = [None] * t.ndim
|
||||
for s, x in zip(shape, axes or range(t.ndim)):
|
||||
tx = t.shape[x]
|
||||
if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2)
|
||||
@@ -751,8 +752,8 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
pads = _auto_pad([s_*(i-1) + op_ + ((k_-1)*d_+1) - os for s_,i,op_,k_,d_,os in
|
||||
zip(strides_, input_shape_, output_padding_, kernel_shape_, dilations_, output_shape)], auto_pad)
|
||||
if pads is None: # we generate pads
|
||||
output_shape = output_shape or [X.shape[i+2] * strides_[i] for i in range(len(strides_))]
|
||||
pads = [strides_[i]*(input_shape_[i]-1)+output_padding_[i]+((kernel_shape_[i]-1)*dilations_[i]+1)-output_shape[i]
|
||||
output_shape = output_shape or [int(X.shape[i+2]) * strides_[i] for i in range(len(strides_))]
|
||||
pads = [int(strides_[i]*(input_shape_[i]-1)+output_padding_[i]+((kernel_shape_[i]-1)*dilations_[i]+1)-output_shape[i])
|
||||
for i in range(len(input_shape_))]
|
||||
pads = _auto_pad(pads, auto_pad) if auto_pad != "NOTSET" else [0] * len(input_shape_) * 2
|
||||
pads = _onnx_pads_to_tiny_pads(pads)
|
||||
@@ -1014,7 +1015,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
num_heads:int|None=None, past_present_share_buffer:int|None=None, qkv_hidden_sizes:list[int]|None=None,
|
||||
rotary_embedding_dim:int|None=None, scale:float|None=None, unidirectional:int=0):
|
||||
assert not do_rotary and not attention_bias, "TODO"
|
||||
if qkv_hidden_sizes is None: qkv_hidden_sizes = [weights.shape[1] // 3] * 3
|
||||
if qkv_hidden_sizes is None: qkv_hidden_sizes = [int(weights.shape[1] // 3)] * 3
|
||||
qkv = x.linear(weights, bias)
|
||||
q, k, v = qkv.split(qkv_hidden_sizes, dim=2)
|
||||
|
||||
@@ -1112,7 +1113,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
if X.ndim == 4: X = X.permute(0, 2, 1, 3)
|
||||
elif X.ndim == 3:
|
||||
assert num_heads is not None, "num_heads must be provided for 3D input"
|
||||
X = X.unflatten(-1, (num_heads, X.shape[-1] // num_heads))
|
||||
X = X.unflatten(-1, (num_heads, int(X.shape[-1]) // num_heads))
|
||||
|
||||
head_size = cast(int, X.shape[-1])
|
||||
rot_dim = rotary_embedding_dim or head_size
|
||||
@@ -1182,7 +1183,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
|
||||
def TensorScatter(data: Tensor, updates: Tensor, indices: Tensor, mode: str = 'default'):
|
||||
# scatter updates along axis -2 at positions given by indices, for each batch
|
||||
B, U, D = indices.shape[0], updates.shape[-2], data.shape[-2]
|
||||
B, U, D = indices.shape[0], updates.shape[-2], int(data.shape[-2])
|
||||
orig_shape, data_flat, updates_flat = data.shape, data.reshape(-1, D, data.shape[-1]), updates.reshape(-1, U, updates.shape[-1])
|
||||
B_total = data_flat.shape[0]
|
||||
batch_idx = Tensor.arange(B_total, device=data.device).reshape(B_total, 1).expand(B_total, U)
|
||||
|
||||
@@ -87,7 +87,8 @@ def create_non_native_float_pats(dts:tuple[DType, ...], casting:bool=True):
|
||||
def cast_float_to_bf16(x: UOp) -> UOp:
|
||||
assert x.dtype == dtypes.float, "cast float -> bf16 must start with float"
|
||||
x = x.bitcast(dtypes.uint)
|
||||
x = ((-x & 0x7f800000) != 0).where(x + ((x >> 16) & 1) + 0x7fff, ((x & 0xffff) != 0).where((x | 0x10000), x))
|
||||
# NOTE: != returns UOp, not bool, issue with mypy
|
||||
x = ((-x & 0x7f800000) != 0).where(x + ((x >> 16) & 1) + 0x7fff, ((x & 0xffff) != 0).where((x | 0x10000), x)) # type: ignore[comparison-overlap]
|
||||
return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
|
||||
|
||||
# manual bfloat16 casting patterns (shared between LLVM, Clang, and AMD renderers to avoid compiler intrinsics)
|
||||
|
||||
@@ -753,7 +753,7 @@ class KFDIface:
|
||||
|
||||
def as_dmaref(self, mem:HCQBuffer) -> DMAFdRef:
|
||||
base = mem._base if mem._base is not None else mem
|
||||
dmaref = DMAFdRef(kfd.AMDKFD_IOC_EXPORT_DMABUF(KFDIface.kfd, handle=base.meta.handle, flags=0).dmabuf_fd, mem.va_addr-base.va_addr, mem.size)
|
||||
dmaref = DMAFdRef(kfd.AMDKFD_IOC_EXPORT_DMABUF(KFDIface.kfd, handle=base.meta.handle, flags=0).dmabuf_fd, int(mem.va_addr-base.va_addr), mem.size)
|
||||
weakref.finalize(dmaref, os.close, dmaref.fd)
|
||||
return dmaref
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ class QCOMComputeQueue(HWQueue):
|
||||
qreg.a6xx_sp_cs_pvt_mem_param(memsizeperitem=prg.pvtmem_size_per_item), *data64_le(prg.dev._stack.va_addr),
|
||||
qreg.a6xx_sp_cs_pvt_mem_size(totalpvtmemsize=prg.pvtmem_size_total))
|
||||
|
||||
if prg.NIR and prg.wgsz != 0xfc: to_mv(args_state.buf.va_addr + prg.wgsz * 4, 12)[:] = struct.pack("III", *local_size)
|
||||
if prg.NIR and prg.wgsz != 0xfc: to_mv(int(args_state.buf.va_addr) + prg.wgsz * 4, 12)[:] = struct.pack("III", *local_size)
|
||||
self.cmd(mesa.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=mesa.ST_CONSTANTS, state_src=mesa.SS6_INDIRECT,
|
||||
state_block=mesa.SB6_CS_SHADER, num_unit=1024 // 4),
|
||||
*data64_le(args_state.buf.va_addr))
|
||||
@@ -199,7 +199,7 @@ class QCOMArgsState(HCQArgsState):
|
||||
ibos, texs = uavs[:prg.ibo_cnt], uavs[prg.ibo_cnt:]
|
||||
for cnst_val,cnst_off,cnst_sz in prg.consts_info: to_mv(self.buf.va_addr + cnst_off, cnst_sz)[:] = cnst_val.to_bytes(cnst_sz, byteorder='little')
|
||||
|
||||
if prg.samp_cnt > 0: to_mv(self.buf.va_addr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers)
|
||||
if prg.samp_cnt > 0: to_mv(int(self.buf.va_addr) + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers)
|
||||
if prg.NIR:
|
||||
self.bind_sints_to_buf(*[b.va_addr for b in ubos], buf=self.buf, fmt='Q', offset=prg.buf_off)
|
||||
self.bind_sints_to_buf(*vals, buf=self.buf, fmt='I', offset=prg.buf_off + len(ubos) * 8)
|
||||
|
||||
@@ -122,7 +122,7 @@ pm_apply_rangeify = PatternMatcher([
|
||||
|
||||
@functools.cache
|
||||
def _apply_reshape(in_shape:tuple[sint,...], out_shape:tuple[sint, ...], urngs:UOp) -> UOp:
|
||||
acc = 1
|
||||
acc:sint = 1
|
||||
axes_in:list[UOp] = []
|
||||
for s,src in list(zip(out_shape, urngs.src))[::-1]:
|
||||
axes_in.append(acc*src)
|
||||
|
||||
@@ -2247,7 +2247,7 @@ class Tensor(OpMixin):
|
||||
if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation)
|
||||
pooled = self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation)
|
||||
if not return_indices: return pooled.max(axis)
|
||||
spatial_sz = math.prod(spatial_shape := self.shape[-len(k_):])
|
||||
spatial_sz = int(math.prod(spatial_shape := self.shape[-len(k_):]))
|
||||
idx = Tensor.arange(spatial_sz,0,-1, requires_grad=False, device=self.device).reshape(spatial_shape)
|
||||
m = pooled == pooled.max(axis, keepdim=True)
|
||||
idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation)
|
||||
@@ -2504,7 +2504,7 @@ class Tensor(OpMixin):
|
||||
```
|
||||
"""
|
||||
if self.ndim == 0: return self._split_cumalu(axis, Ops.MAX), Tensor.zeros(self.shape, dtype=dtypes.int32, device=self.device)
|
||||
values, n = self._split_cumalu(axis, Ops.MAX), self.shape[axis]
|
||||
values, n = self._split_cumalu(axis, Ops.MAX), int(self.shape[axis])
|
||||
x, values_t = self.transpose(axis, -1), values.transpose(axis, -1)
|
||||
match = (x.unsqueeze(-1) == values_t.unsqueeze(-2)) * Tensor.ones(n, n, requires_grad=False, device=self.device).triu()
|
||||
idx = (-(match * Tensor.arange(n, 0, -1, requires_grad=False, device=self.device).reshape(n, 1)).max(-2) + n).cast(dtypes.int32)
|
||||
@@ -2595,7 +2595,7 @@ class Tensor(OpMixin):
|
||||
assert not (align_corners and mode != "linear"), "align_corners option can only be set with the interpolating mode linear"
|
||||
x, expand = self, list(self.shape)
|
||||
for i in range(-1,-len(size)-1,-1):
|
||||
scale = (self.shape[i] - int(align_corners)) / (size[i] - int(align_corners))
|
||||
scale = (int(self.shape[i]) - int(align_corners)) / (size[i] - int(align_corners))
|
||||
arr, reshape = Tensor.arange(size[i], dtype=dtypes.float32, device=self.device), [1] * self.ndim
|
||||
reshape[i] = expand[i] = size[i]
|
||||
if mode == "linear":
|
||||
@@ -2716,7 +2716,7 @@ class Tensor(OpMixin):
|
||||
```
|
||||
"""
|
||||
x, dim = self, self._resolve_dim(dim)
|
||||
if (orig_len:= x.shape[dim]) <= 1: return x, x.zeros_like(dtype=dtypes.default_int)
|
||||
if (orig_len := int(x.shape[dim])) <= 1: return x, x.zeros_like(dtype=dtypes.default_int)
|
||||
# pad to power of 2
|
||||
n_stages = (orig_len-1).bit_length()
|
||||
pads = tuple((0, 2**n_stages - orig_len) if i == dim else None for i in range(x.ndim))
|
||||
@@ -3543,7 +3543,7 @@ class Tensor(OpMixin):
|
||||
```
|
||||
"""
|
||||
if not dtypes.is_int(self.dtype): raise RuntimeError(f"expect integer dtype, getting {self.dtype=}")
|
||||
if num_classes == -1: num_classes = (self.max()+1).item()
|
||||
if num_classes == -1: num_classes = int((self.max()+1).item())
|
||||
return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
|
||||
|
||||
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0,
|
||||
@@ -3570,8 +3570,8 @@ class Tensor(OpMixin):
|
||||
|
||||
# GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
if enable_gqa:
|
||||
key = key.repeat_interleave(self.shape[-3] // key.shape[-3], dim=-3)
|
||||
value = value.repeat_interleave(self.shape[-3] // value.shape[-3], dim=-3)
|
||||
key = key.repeat_interleave(int(self.shape[-3] // key.shape[-3]), dim=-3)
|
||||
value = value.repeat_interleave(int(self.shape[-3] // value.shape[-3]), dim=-3)
|
||||
|
||||
q = self
|
||||
qk = q.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(q.dtype, key.dtype, dtypes.float32)) / math.sqrt(q.shape[-1])
|
||||
@@ -3706,7 +3706,8 @@ class Tensor(OpMixin):
|
||||
assert self.ndim > 1, "NS only works for two or more dims"
|
||||
if self.shape[-2] > self.shape[-1]: return self.transpose(-2, -1).newton_schulz(steps, params, eps).transpose(-2, -1)
|
||||
G = self / (self.square().sum(axis=(-2, -1), keepdim=True).sqrt() + eps)
|
||||
for _ in range(steps): G = sum(p * functools.reduce(lambda x, y: (y @ y.transpose(-2, -1)) @ x, [G]*i, G) for i,p in enumerate(params))
|
||||
for _ in range(steps):
|
||||
G = cast(Tensor, sum(p * functools.reduce(lambda x, y: (y @ y.transpose(-2, -1)) @ x, [G]*i, G) for i,p in enumerate(params)))
|
||||
return G
|
||||
|
||||
def qr(self) -> tuple[Tensor, Tensor]:
|
||||
@@ -3801,7 +3802,7 @@ class Tensor(OpMixin):
|
||||
print(t.nbytes())
|
||||
```
|
||||
"""
|
||||
return self.numel() * self.element_size()
|
||||
return int(self.numel()) * self.element_size()
|
||||
|
||||
def is_floating_point(self) -> bool:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user