mirror of https://github.com/tinygrad/tinygrad.git
@@ -1,5 +1,4 @@
-# mypy: disable-error-code="misc, list-item, assignment, operator, index, arg-type"
-from typing import Any, Sequence, cast, Literal, NamedTuple, Generator, get_args
+from typing import Any, Sequence, cast, Literal, NamedTuple, Generator
 import dataclasses, functools, io, math, types, warnings, pathlib, sys, os, struct, enum
 from io import BufferedReader
 from tinygrad.nn.state import TensorIO
@@ -481,7 +480,7 @@ class OnnxRunner:
 ####################
 def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionType]]:
   # ***** helper functions *****
-  def _resolve_const(x: Sequence[ConstType]|ConstType): return x if isinstance(x, get_args(ConstType)) else get_single_element(x)
+  def _resolve_const(x: Sequence[ConstType]|ConstType): return get_single_element(x) if isinstance(x, Sequence) else x

   def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None)

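Note on the _resolve_const change above: the rewrite dispatches on the runtime container instead of on get_args(ConstType), a form the type checker cannot narrow through. A minimal standalone sketch of the new behavior (ConstType and get_single_element here are simplified stand-ins for tinygrad's own definitions):

from collections.abc import Sequence

def get_single_element(x):  # stand-in for tinygrad's helper
  assert len(x) == 1, f"expected one element, got {len(x)}"
  return x[0]

def _resolve_const(x): return get_single_element(x) if isinstance(x, Sequence) else x

assert _resolve_const(3.0) == 3.0    # bare consts pass through
assert _resolve_const([3.0]) == 3.0  # one-element sequences are unwrapped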
@@ -698,13 +697,14 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
   # ***** Processing Ops *****
   def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0,
                   dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1):
-    return X.avg_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad),
-                        ceil_mode=ceil_mode, count_include_pad=count_include_pad)
+    pool_pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad)
+    return X.avg_pool2d(tuple(kernel_shape), strides, dilations, pool_pads, ceil_mode=ceil_mode, count_include_pad=count_include_pad)

   def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0,
               storage_order:int=0, strides:list[int]|int=1):
-    pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad)
-    ret, idx = X.max_pool2d(kernel_shape, strides, dilations, pads, ceil_mode=ceil_mode, return_indices=True)
+    pool_pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad)
+    out = X.max_pool2d(tuple(kernel_shape), strides, dilations, pool_pads, ceil_mode=ceil_mode, return_indices=True)
+    ret, idx = cast(tuple[Tensor, Tensor], out)
     return ret, idx.transpose(-2, -1).cast(dtypes.int64) if storage_order else idx.cast(dtypes.int64)

   def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
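Note on the MaxPool change above: with return_indices=True the pool returns a tuple, but the method's declared return type is a union, so the cast narrows it for the type checker with no runtime effect. A generic sketch of the pattern (toy function, not the tinygrad API):

from typing import cast

def pool(return_indices: bool = False):  # toy stand-in whose static return type is a union
  return (1.0, 2) if return_indices else 1.0

out = pool(return_indices=True)
ret, idx = cast(tuple[float, int], out)  # tells the checker which union member we hold; no-op at runtime
assert (ret, idx) == (1.0, 2)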
@@ -715,20 +715,22 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
   def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
                     kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0,
                     strides:list[int]|int=1):
-    input_shape, kernel_shape = X.shape[2:], (kernel_shape or W.shape[2:])
-    strides, dilations, output_padding = (make_tuple(x, len(input_shape)) for x in (strides, dilations, output_padding))
+    input_shape_, kernel_shape_ = X.shape[2:], (kernel_shape or W.shape[2:])
+    strides_, dilations_, output_padding_ = (make_tuple(x, len(input_shape_)) for x in (strides, dilations, output_padding))
     if output_shape is not None: # we pad according to output_shape
-      pads = _auto_pad([s*(i-1) + op + ((k-1)*d+1) - os for s,i,op,k,d,os in
-                        zip(strides, input_shape, output_padding, kernel_shape, dilations, output_shape)], auto_pad)
+      pads = _auto_pad([s_*(i-1) + op_ + ((k_-1)*d_+1) - os for s_,i,op_,k_,d_,os in
+                        zip(strides_, input_shape_, output_padding_, kernel_shape_, dilations_, output_shape)], auto_pad)
     if pads is None: # we generate pads
-      output_shape = output_shape or [X.shape[i+2] * strides[i] for i in range(len(strides))]
-      pads = [strides[i]*(input_shape[i]-1)+output_padding[i]+((kernel_shape[i]-1)*dilations[i]+1)-output_shape[i] for i in range(len(input_shape))]
-      pads = _auto_pad(pads, auto_pad) if auto_pad != "NOTSET" else [0] * len(input_shape) * 2
+      output_shape = output_shape or [X.shape[i+2] * strides_[i] for i in range(len(strides_))]
+      pads = [strides_[i]*(input_shape_[i]-1)+output_padding_[i]+((kernel_shape_[i]-1)*dilations_[i]+1)-output_shape[i]
+              for i in range(len(input_shape_))]
+      pads = _auto_pad(pads, auto_pad) if auto_pad != "NOTSET" else [0] * len(input_shape_) * 2
     pads = _onnx_pads_to_tiny_pads(pads)
-    return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads, output_padding=output_padding)
+    return X.conv_transpose2d(W, B, group, strides_, dilations_, pads, output_padding_)

-  def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]=None, pads:list[int]|int=0, strides:list[int]|int=1):
-    return Tensor.max_unpool2d(xT, xI, kernel_shape, strides, 1, pads, outshape if outshape is None else tuple(outshape))
+  def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]=[], pads:list[int]|int=0, strides:list[int]|int=1):
+    pads_: int | tuple[int, ...] = tuple(pads) if isinstance(pads, list) else pads
+    return Tensor.max_unpool2d(xT, xI, tuple(kernel_shape), strides, 1, pads_, outshape if outshape is None else tuple(outshape))

   def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
   def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)
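Note on the ConvTranspose change above: the total-padding formula inverts the usual transposed-convolution output-size relation. A quick numeric check of s*(i-1) + op + ((k-1)*d+1) - os with illustrative values:

# stride 2, input 5, no output_padding, kernel 3, dilation 1, requested output 10
s, i, op, k, d, os_ = 2, 5, 0, 3, 1, 10
pads_total = s*(i-1) + op + ((k-1)*d + 1) - os_
assert pads_total == 1  # one pad cell trims the naive 11-wide result down to 10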
@@ -775,7 +777,6 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
     input_shape = cast(tuple[int, ...], X.shape[2:])
     if scales is not None: assert all(sc==1 for sc in scales[:-len(input_shape)]), "resizing batch_size dim or channel dim not supported"
     if sizes is not None: assert tuple(sizes[:-2]) == tuple(X.shape[X.ndim-len(sizes):-2]), "resizing batch_size dim or channel dim not supported"
-    assert (scales is not None) ^ (sizes is not None), "only provide one of `scales` or `sizes`"

     scales, sizes = (None if scales is None else scales[-len(input_shape):]), (None if sizes is None else sizes[-len(input_shape):])
     if sizes is not None:
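Note on the Resize change above: the removed XOR assert enforced that exactly one of `scales`/`sizes` is given; the next hunk adds an `assert scales is not None` in the else branch, which still catches the case where neither is provided. The removed check, for reference:

scales, sizes = None, [4, 4]
assert (scales is not None) ^ (sizes is not None)  # exactly one of the two may be set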
@@ -784,7 +785,9 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
         scale = scale_fxn(sz / sh for sz,sh in zip(sizes, input_shape))
         sizes, scales = [int(scale * sh + 0.5) for sh in input_shape], [scale]*len(input_shape)
       else: scales = [sz / sh for sz, sh in zip(sizes, input_shape)]
-    else: sizes = [int(sc * sh) for sc, sh in zip(scales, input_shape)]
+    else:
+      assert scales is not None, "either sizes or scales must be provided"
+      sizes = [int(sc * sh) for sc, sh in zip(scales, input_shape)]

     if all(sz == sh for sz, sh in zip(sizes, input_shape)): return X.permute(*argsort(perm)) if perm else X

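Note on the hunk above: sizes and scales are two views of the same resize, related by sizes[i] = int(scales[i] * input_shape[i]) one way and scales[i] = sizes[i] / input_shape[i] the other. A numeric illustration:

input_shape, scales = (4, 6), (2.0, 0.5)
sizes = [int(sc * sh) for sc, sh in zip(scales, input_shape)]
assert sizes == [8, 3]
back = [sz / sh for sz, sh in zip(sizes, input_shape)]
assert back == [2.0, 0.5]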
@@ -823,7 +826,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT

     expand = list(X.shape)
     for i in range(-len(sizes), 0):
-      input_sz = X.shape[i]
+      input_sz = cast(int, X.shape[i])
       reshape, index = [1] * X.ndim, indexes[i]
       reshape[i] = expand[i] = sizes[i]

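Note on the hunk above: tinygrad shape entries may be symbolic, so their static type is not plain int; cast(int, ...) narrows the type for mypy and is erased at runtime. In isolation:

from typing import cast
shape: tuple = (2, 3)           # imagine a shape tuple whose static element type is a union
input_sz = cast(int, shape[0])  # purely static narrowing, no conversion happens
assert input_sz == 2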
@@ -855,7 +858,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
   def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated

   def TopK(X:Tensor, K:int|list[int], axis:int=-1, largest:int=1, sorted:int=1): # noqa: A002
-    val, idx = X.topk(_resolve_const(K), axis, largest, sorted)
+    val, idx = X.topk(_resolve_const(K), axis, bool(largest), bool(sorted))
     return val, idx.cast(dtypes.int64)

   # ***** Neural Network Ops *****
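Note on the TopK change above: ONNX encodes boolean attributes as ints, so largest/sorted arrive as 0/1 and bool(...) coerces them to the bool parameters the method expects. In plain Python:

largest, sorted_ = 1, 0  # attribute values as decoded from the ONNX proto
assert bool(largest) is True and bool(sorted_) is False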
@@ -877,7 +880,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
     x = x.reshape(x.shape[0], num_groups, -1).layernorm(eps=epsilon).reshape(x.shape)
     return x * scale.reshape(1, -1, *[1] * (x.ndim-2)) + bias.reshape(1, -1, *[1] * (x.ndim-2))
   def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
-    return GroupNormalization(x, scale, bias, num_groups=x.shape[1], epsilon=epsilon)
+    return GroupNormalization(x, scale, bias, num_groups=cast(int, x.shape[1]), epsilon=epsilon)
   def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
     assert stash_type == 1, "only float32 is supported"
     axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
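Note on the hunk above: InstanceNormalization is GroupNormalization with one group per channel, which is why it just forwards with num_groups = x.shape[1]. A NumPy sketch of that equivalence (illustrative, not the tinygrad kernels):

import numpy as np
x = np.random.randn(2, 4, 8).astype(np.float32)  # (batch, channels, spatial)
g = x.reshape(2, 4, -1)                          # num_groups == channels: one group per channel
inst = (g - g.mean(-1, keepdims=True)) / np.sqrt(g.var(-1, keepdims=True) + 1e-5)
# each channel normalized over its own spatial extent is exactly instance norm
assert inst.shape == (2, 4, 8)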
@@ -973,6 +976,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
     q, k, v = qkv.split(qkv_hidden_sizes, dim=2)

     batch_size, seq_len, _ = x.shape
+    assert num_heads is not None, "num_heads must be provided"
     q_head_size, k_head_size, v_head_size = (sz // num_heads for sz in qkv_hidden_sizes)
     q, k, v = (x.reshape(batch_size, seq_len, num_heads, hsz).transpose(1, 2) for x, hsz in zip((q, k, v), (q_head_size, k_head_size, v_head_size)))

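Note on the Attention hunk above: the reshape/transpose pair splits each projection into heads, (batch, seq, hidden) -> (batch, heads, seq, head_size). Shape arithmetic with illustrative numbers:

batch_size, seq_len, num_heads, hidden = 2, 7, 4, 64
head_size = hidden // num_heads
# reshape(batch, seq, heads, head_size) then transpose(1, 2)
assert (batch_size, num_heads, seq_len, head_size) == (2, 4, 7, 16)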
@@ -986,6 +990,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT

     if mask_index is not None:
       assert 4 >= mask_index.ndim >= 1, f"{mask_index.ndim=}"
+      assert isinstance(batch_size, int), f"{batch_size=}"
       if mask_index.ndim != 1: mask = mask_index.bool()
       else:
         if mask_index.shape[0] == batch_size:
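Note on the hunk above: unlike cast, assert isinstance(...) narrows the type for mypy and checks it at runtime, which suits values like batch_size that come from a possibly-symbolic shape:

def f(batch_size: object):
  assert isinstance(batch_size, int), f"{batch_size=}"
  return batch_size + 1  # the checker now knows batch_size is int

assert f(3) == 4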
@@ -1026,7 +1031,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
       K = K.repeat((1, _q_heads // _kv_heads, 1, 1))
       V = V.repeat((1, _q_heads // _kv_heads, 1, 1))

-    effective_scale = scale if scale is not None else 1.0 / (Q.shape[-1] ** 0.5)
+    effective_scale = scale if scale is not None else 1.0 / (cast(int, Q.shape[-1]) ** 0.5)
     scores = (Q @ K.transpose(-1, -2)) * effective_scale
     qk_matmul_return_val = scores

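Note on the hunk above: when no explicit scale attribute is given, the op falls back to the standard 1/sqrt(head_dim) attention scaling:

import math
head_dim = 64
effective_scale = 1.0 / (head_dim ** 0.5)
assert math.isclose(effective_scale, 0.125)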
@@ -1064,12 +1069,12 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
       assert num_heads is not None, "num_heads must be provided for 3D input"
       X = X.reshape(*X.shape[:-1], num_heads, X.shape[-1] // num_heads)

-    head_size = X.shape[-1]
+    head_size = cast(int, X.shape[-1])
     rot_dim = rotary_embedding_dim or head_size
     x_rotate, x_pass = X[..., :rot_dim], X[..., rot_dim:]

-    cos = cos_cache[position_ids] if position_ids is not None else cos_cache[:X.shape[1]]
-    sin = sin_cache[position_ids] if position_ids is not None else sin_cache[:X.shape[1]]
+    cos = cos_cache[position_ids] if position_ids is not None else cos_cache[:head_size]
+    sin = sin_cache[position_ids] if position_ids is not None else sin_cache[:head_size]
     cos = cos[..., :rot_dim//2].unsqueeze(2)
     sin = sin[..., :rot_dim//2].unsqueeze(2)

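Note on the RotaryEmbedding hunk above: only the first rot_dim features are rotated, and the cos/sin caches hold rot_dim//2 angles because the rotation acts on feature pairs. The underlying 2D rotation, on one pair (the op's exact pairing depends on its interleaved attribute):

import math
x0, x1, theta = 1.0, 0.0, math.pi / 2            # one feature pair, 90-degree angle
r0 = x0 * math.cos(theta) - x1 * math.sin(theta)
r1 = x0 * math.sin(theta) + x1 * math.cos(theta)
assert abs(r0) < 1e-9 and abs(r1 - 1.0) < 1e-9   # (1, 0) rotates to (0, 1)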
@@ -1127,7 +1132,8 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
   def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul", "min", "max"]="none"):
     indices = (indices < 0).where(x.shape[axis], 0) + indices
     if reduction == "none": return x.scatter(axis, indices, updates)
-    return x.scatter_reduce(axis, indices, updates, {"add": "sum", "mul": "prod", "min": "amin", "max": "amax"}.get(reduction))
+    reduction_ = cast(Literal["sum", "prod", "amin", "amax"], {"add": "sum", "mul": "prod", "min": "amin", "max": "amax"}[reduction])
+    return x.scatter_reduce(axis, indices, updates, reduction_)
   def GatherElements(x:Tensor, indices:Tensor, axis:int):
     indices = (indices < 0).where(x.shape[axis], 0) + indices
     return x.gather(axis, indices)
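Note on the hunk above: (indices < 0).where(x.shape[axis], 0) + indices maps negative ONNX indices onto the valid range by adding the axis size to negative entries. The same fixup in plain Python:

axis_size = 5
indices = [0, -1, 2, -3]
normalized = [i + (axis_size if i < 0 else 0) for i in indices]
assert normalized == [0, 4, 2, 2]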
@@ -1165,12 +1171,12 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
     x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)
     return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype)

-  def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor,
-                  y_zero_point: Tensor|int, B:Tensor|None=None, **opts):
+  def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor, w:Tensor, w_scale:Tensor, w_zero_point:Tensor, y_scale:Tensor,
+                  y_zero_point:Tensor, B:Tensor|None=None, **opts):
     return _qlinearop_quantized(Conv, [x,w], [x_zero_point,w_zero_point], [x_scale,w_scale], y_scale, y_zero_point, **{"B":B, **opts})

-  def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor,
-                    y_zero_point:Tensor|int) -> Tensor:
+  def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, y_scale:Tensor,
+                    y_zero_point:Tensor) -> Tensor:
     return _qlinearop_quantized(Tensor.matmul, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], y_scale, y_zero_point)

   def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor):
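Note on the QLinear hunk above: all these ops share the dequantize relation visible in DequantizeLinear's return, real = (quantized - zero_point) * scale. With concrete numbers:

x_q, zero_point, scale = 130, 128, 0.5
assert (x_q - zero_point) * scale == 1.0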
@@ -1183,10 +1189,10 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
     assert channels_last == 0, "TODO NHWC"
     return _qlinearop_float(GlobalAveragePool, [X], [x_zero_point], [x_scale], y_scale, y_zero_point)

-  def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None, **opts) -> Tensor:
+  def ConvInteger(x: Tensor, w: Tensor, x_zero_point:Tensor = Tensor(0), w_zero_point:Tensor = Tensor(0), B: Tensor | None = None, **opts) -> Tensor:
     return _op_integer(Conv, [x,w], [x_zero_point,w_zero_point], **{"B":B, **opts})

-  def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor:
+  def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor = Tensor(0), b_zero_point: Tensor = Tensor(0)) -> Tensor:
     return _op_integer(Tensor.matmul, [A,B], [a_zero_point,b_zero_point])

   # ***** Training Ops *****
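Note on the integer-op hunk above: per the ONNX ConvInteger/MatMulInteger spec, zero points are subtracted and the accumulation happens in int32. A NumPy sketch:

import numpy as np
A = np.array([[130, 128]], dtype=np.uint8)
B = np.array([[1], [2]], dtype=np.uint8)
a_zp, b_zp = 128, 0
y = (A.astype(np.int32) - a_zp) @ (B.astype(np.int32) - b_zp)
assert y.tolist() == [[2]]  # (130-128)*1 + (128-128)*2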