move to_python_const from onnx_ops to onnx (#8158)

* move to_python_const out

* move more over

* try deleting alternative gather implementation

* Revert "try deleting alternative gather implementation"

This reverts commit d46b30b717.

* add types to onnx ops

* better debug msg

* improve some com.microsoft ops too

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
Authored by geohotstan on 2024-12-18 03:12:06 +08:00, committed by GitHub
parent 21b085b8ed
commit 32c995a5da
2 changed files with 157 additions and 140 deletions
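Before the diffs, a minimal sketch of the pattern this commit centralizes. It assumes to_python_const passes non-Tensor values through and lowers a Tensor to plain Python scalars/lists (roughly Tensor.tolist(), plus special cases such as bytes for ImageDecoder); the body below is illustrative, not the exact helper now living in extra/onnx.py. The first hunk adds the required_input_python_consts table and the second hunk consumes it like this:

from tinygrad import Tensor

def to_python_const(t):
  # assumption: python values (and None) pass through untouched
  if not isinstance(t, Tensor): return t
  # assumption: Tensors are lowered to python ints/floats/lists, roughly via tolist()
  return t.tolist()

# op_type -> positions of inputs that must be python consts (a one-entry subset of the table in the diff)
required_input_python_consts = {"Reshape": (1,)}
inp_tensors = [Tensor.ones(2, 3), Tensor([3, 2])]              # data tensor, shape tensor
required = required_input_python_consts.get("Reshape", ())
inp = [to_python_const(t) if i in required else t for i, t in enumerate(inp_tensors)]
# inp[1] is now the python list [3, 2], so Reshape receives a plain shape instead of a Tensor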

extra/onnx.py

@@ -85,6 +85,15 @@ def get_run_onnx(onnx_model: ModelProto):
"Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf")
}
# these values are expected to be python consts
required_input_python_consts: Dict[str, tuple[int, ...]] = {
"Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,),
"CumSum": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,),
"ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4),
**{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")},
**{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")}
}
# src: https://onnx.ai/onnx/repo-docs/IR.html#input-output-data-types
# parses and validates inputs based on their shape and dtype specified by model
def prepare_input(user_input:Any, model_input:ValueInfoProto):
@@ -121,11 +130,15 @@ def get_run_onnx(onnx_model: ModelProto):
model_tensors[name] = prepare_input(inputs[name], value_info)
for num,n in enumerate(onnx_model.graph.node):
inp = [model_tensors.get(x) for x in n.input]
inp_tensors = [model_tensors.get(x) for x in n.input]
required_consts = required_input_python_consts.get(n.op_type, ())
inp = [to_python_const(t) if i in required_consts else t for i,t in enumerate(inp_tensors)]
opt = model_attributes[num]
if debug >= 1: print(f"{num}: op \"{n.op_type}\" input shapes {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}")
if debug >= 3: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {t}" for i,(x,t) in enumerate(zip(n.input, inp))))
if debug >= 1: print(f"{num}: op \"{n.op_type}\" input shapes {[x.shape if isinstance(x, Tensor) else x for x in inp_tensors]} opt {opt}")
if debug >= 3:
print("\tinputs:")
print("\n".join(f"\t\t{x} - {t}" + (" (to_python_const)" if i in required_consts else "") for i,(x,t) in enumerate(zip(n.input, inp))))
if n.op_type in tensor_methods:
ret = getattr(Tensor, tensor_methods[n.op_type])(*inp, **opt)
@@ -134,7 +147,7 @@ def get_run_onnx(onnx_model: ModelProto):
elif n.op_type == "Split":
axis, n_outputs = opt.get('axis', 0), opt.get('num_outputs') or len(n.output)
sz = inp[0].shape[axis]
sizes = to_python_const(inp[1]) if len(inp) == 2 else [sz // n_outputs + (1 if i < sz % n_outputs else 0) for i in range(n_outputs)]
sizes = inp[1] if len(inp) == 2 else [sz // n_outputs + (1 if i < sz % n_outputs else 0) for i in range(n_outputs)]
ret = inp[0].split(sizes, axis)
elif n.op_type == "Gradient":
assert len(opt["xs"]) == len(inp), f"len(opt['xs']):{len(opt['xs'])}, len(inp):{len(inp)} output and input has to match"

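A hedged usage sketch of the runner touched above: get_run_onnx builds a callable over the graph, and the inputs[name] lookup in the hunk suggests that callable is fed a dict keyed by the model's declared input names. The model path, input name, and shape below are placeholders, and the exact return value (a per-output mapping of Tensors) is an assumption:

import onnx
from tinygrad import Tensor
from extra.onnx import get_run_onnx

model = onnx.load("model.onnx")                          # placeholder path
run = get_run_onnx(model)                                # builds the per-node interpreter shown above
outputs = run({"input": Tensor.randn(1, 3, 224, 224)})   # keys must match the model's input names
# the debug prints in the diff are gated by a debug level; how it is set (env var or argument) is not shown in this hunk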
extra/onnx_ops.py

@@ -1,6 +1,6 @@
import functools, io, math
from typing import Union, Tuple, Optional, List, Any, cast
from tinygrad.tensor import Tensor, _broadcast_shape
from tinygrad.tensor import Tensor, _broadcast_shape, ConstType
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.helpers import prod, flatten
from extra.onnx import dtype_parse, to_python_const
@@ -8,10 +8,10 @@ import numpy as np
# **************** Free Ops ****************
def Identity(x: Tensor): return x
def Identity(x:Tensor): return x
# TODO: fix buffer_parse
def Add(x: Tensor, other: Tensor, broadcast=None, axis=None): return x + other if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + other).cast(x.dtype)
def Sub(x: Union[Tensor, Any], other: Tensor): return x - other # some test has input as int
def Add(x:Tensor, other:Tensor, broadcast=None, axis=None): return x + other if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + other).cast(x.dtype)
def Sub(x:Union[Tensor, Any], other:Tensor): return x - other # some test has input as int
def Less(x:Tensor,y:Tensor): return x < y
def LessOrEqual(x:Tensor,y:Tensor): return x <= y
def Greater(x:Tensor,y:Tensor): return x > y
@@ -21,120 +21,124 @@ def BitwiseNot(x:Tensor): return ~x
def BitwiseOr(x:Tensor, y:Tensor): return x | y
def BitwiseAnd(x:Tensor, y:Tensor): return x & y
def BitwiseXor(x:Tensor, y:Tensor): return x ^ y
def Max(*data_0): return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0): return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0): return functools.reduce(Tensor.add, data_0)
def Mean(*data_0): return Sum(*data_0) / len(data_0)
def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0)
def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0)
# NOTE: does not support saturate
def Cast(x: Tensor, to: int, saturate=1): return x.cast(dtype_parse(to))
def CastLike(x: Tensor, target_type: Tensor, saturate=1): return x.cast(target_type.dtype)
def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to))
def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype)
# **************** Simple Ops ****************
# https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_div.py
def Div(x: Tensor, other: Tensor): return (x/other).cast(x.dtype)
def Div(x:Tensor, other:Tensor): return (x/other).cast(x.dtype)
def Constant(value:Optional[Tensor]=None, value_float=None, value_floats=None, value_int=None, value_ints=None, value_string=None, value_strings=None):
def Constant(sparse_value:Optional[Tensor]=None, value:Optional[Tensor]=None, value_float:Optional[float]=None,
value_floats:Optional[List[float]]=None, value_int:Optional[int]=None, value_ints:Optional[List[int]]=None,
value_string:Optional[str]=None, value_strings:Optional[List[str]]=None):
if value is not None: return value
if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
if value_string is not None or value_strings is not None: raise NotImplementedError('value_string or value_strings not implemented for Constant op')
if value_string is not None or value_strings is not None and sparse_value is not None:
raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value')
def HardSigmoid(x: Tensor, alpha=0.2, beta=0.5): return (alpha*x + beta).clip(0, 1)
def Gelu(x:Tensor, approximate=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1)
def Gelu(x:Tensor, approximate:Optional[str]=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
# TODO: fix this
def PRelu(X:Tensor, slope:Tensor):
slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
return (X > 0).where(X, X * slope)
def LeakyRelu(X: Tensor, alpha=0.01): return X.leakyrelu(alpha)
def ThresholdedRelu(X: Tensor, alpha=1.0): return (X > alpha).where(X, 0)
def Softmax_1(x: Tensor, axis=1): return x.softmax(axis)
def Softmax_13(x: Tensor, axis=-1): return x.softmax(axis)
def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha)
def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0)
def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis)
def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis)
Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed
def LogSoftmax(x: Tensor, axis=-1): return x.log_softmax(axis)
def Clip(x: Tensor, min=None, max=None): return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype)
def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis)
def Clip(x: Tensor, min:Optional[Tensor]=None, max:Optional[Tensor]=None):
return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype)
def _axes(axes, noop_with_empty_axes):
if axes is not None and not (isinstance(axes, Tensor) and axes.shape == (0,)): return to_python_const(axes)
return [] if noop_with_empty_axes else None
def ReduceMax(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMin(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMean(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSumSquare(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return ReduceSum(data.square(), axes, keepdims, noop_with_empty_axes)
def ReduceProd(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.prod(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL1(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes)
def ReduceL2(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt()
def ReduceLogSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
def ReduceLogSumExp(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log()
def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None)
def ReduceMax(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMin(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSum(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMean(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSumSquare(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
return ReduceSum(data.square(), axes, keepdims, noop_with_empty_axes)
def ReduceProd(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
return data.prod(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL1(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes)
def ReduceL2(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt()
def ReduceLogSum(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
def ReduceLogSumExp(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log()
def GlobalAveragePool(X: Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
def GlobalMaxPool(X: Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)
def OptionalHasElement(x: Optional[Tensor]=None): return Tensor(x is not None and x.numel() > 0)
def OptionalGetElement(x: Optional[Tensor]=None): return x if x is not None else Tensor([])
def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)
def OptionalHasElement(x:Optional[Tensor]=None): return Tensor(x is not None and x.numel() > 0)
def OptionalGetElement(x:Optional[Tensor]=None): return x if x is not None else Tensor([])
def Tile(x: Tensor, repeats): return x.repeat(to_python_const(repeats))
def Range(start: Tensor, limit, delta): return Tensor.arange(start=to_python_const(start), stop=to_python_const(limit), step=to_python_const(delta))
def Shape(data: Tensor, end=None, start=0): return Tensor(data.shape[start:end], dtype=dtypes.int64)
def Size(data: Tensor): return prod(data if isinstance(data, list) else data.shape)
def Flatten(x: Tensor, axis=1): return x.reshape(prod(x.shape[0:axis]), -1)
def Reshape(data: Tensor, shape: Tensor, allowzero=0):
return data.reshape([int(x) if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(to_python_const(shape))])
def Expand(x: Tensor, shape:Tensor): return x.expand(_broadcast_shape(x.shape, tuple(to_python_const(shape))))
def Shrink(x: Tensor, bias=0.0, lambd=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)
def Tile(x:Tensor, repeats:List[int]): return x.repeat(repeats)
def Range(start:Union[float, int], limit:Union[float, int], delta:Union[float, int]): return Tensor.arange(start=start, stop=limit, step=delta)
def Shape(data:Tensor, end:Optional[int]=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64)
def Size(data:Tensor): return prod(data if isinstance(data, list) else data.shape)
def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1)
def Reshape(data:Tensor, shape:List[int], allowzero:int=0):
return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)])
def Expand(x:Tensor, shape:List[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape)))
def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)
def And(x:Tensor, y:Tensor): return (x==y).where(x, False)
def Or(x:Tensor, y:Tensor): return (x==y).where(x, True)
def Not(x:Tensor): return x.logical_not()
def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
k = to_python_const(k) if isinstance(k, Tensor) else 0 # onnx passes k as a tensor int64 with one element, default is 0
return x.triu(k) if upper else x.tril(k)
def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k)
def Slice(data: Tensor, starts:Tensor, ends:Tensor, axes:Optional[Tensor]=None, steps:Optional[Tensor]=None):
if axes is None: axes = list(range(data.ndim))
if steps is None: steps = [1] * data.ndim
starts, ends, axes, steps = (to_python_const(x) for x in (starts, ends, axes, steps))
def Slice(data:Tensor, starts:List[int], ends:List[int], axes:Optional[List[int]]=None, steps:Optional[List[int]]=None):
axes = axes or list(range(data.ndim))
steps = steps or [1]*data.ndim
slices = [slice(0,x,1) for x in data.shape]
for i, axis in enumerate(axes): slices[axis] = slice(starts[i], ends[i], steps[i])
return data[tuple(slices)]
def Squeeze(data: Tensor, axes):
if isinstance(axes, Tensor): axes = to_python_const(axes)
axes = [data._resolve_dim(x) for x in axes]
return data.reshape([s for i,s in enumerate(data.shape) if i not in axes])
def Unsqueeze(data: Tensor, axes):
axes = sorted([x + data.ndim if x < 0 else x for x in to_python_const(axes)])
new_shape = list(data.shape)
for axis in axes: new_shape.insert(axis, 1)
return data.reshape(new_shape)
# TODO: add test for when axes is None
def Squeeze(data:Tensor, axes:Optional[List[int]]=None):
return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data)
def Unsqueeze(data:Tensor, axes:List[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data)
def Binarizer(x, threshold=0.0): return (x > threshold).float()
def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float()
def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0):
if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
def ArgMin(x, axis=0, keepdims=1, select_last_index=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
def Concat(*xs: List[Tensor], axis): return Tensor.cat(*xs, dim=axis)
def Transpose(x: Tensor, perm=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm)
def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis)
def Transpose(x:Tensor, perm:Optional[List[int]]=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm)
def ConstantOfShape(x, value:Tensor=None):
if value is None: value = 0.0
shape = to_python_const(x)
return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0]!=0 else 1)
def ConstantOfShape(shape:List[int], value:Optional[Tensor]=None):
if value is None: value = Tensor(0, dtype=dtypes.float32)
return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1)
# **************** Complex Ops ****************
def Gemm(A: Tensor, B: Tensor, C: Tensor=None, alpha=1.0, beta=1.0, transA=0, transB=0, broadcast=0):
def Gemm(A:Tensor, B:Tensor, C:Optional[Tensor]=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0):
ret = alpha * (A.transpose(transA) @ B.transpose(transB))
if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
return ret
def Einsum(*Inputs: List[Tensor], equation): return Tensor.einsum(equation, Inputs)
def Einsum(*Inputs:List[Tensor], equation:str): return Tensor.einsum(equation, Inputs)
def CumSum(X:Tensor, axis:Tensor, exclusive=0, reverse=0):
if (axis := to_python_const(axis)) < 0: axis += X.ndim
def CumSum(X:Tensor, axis:int, exclusive:int=0, reverse:int=0):
axis = X._resolve_dim(axis)
if reverse: X = X.flip(axis)
if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\
.shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim)))
@@ -142,7 +146,8 @@ def CumSum(X:Tensor, axis:Tensor, exclusive=0, reverse=0):
# TODO: this is copied from tinygrad/nn/__init__.py
# spatial is from opset 7 and has since been removed
def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05, momentum=0.9, training_mode=0, spatial=1, is_test=0):
def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9 ,
training_mode:int=0, spatial=1, is_test=0):
if training_mode:
x_detached = X.detach()
current_mean = x_detached.mean(axis=(0,2,3))
@@ -157,19 +162,19 @@ def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05
invstd = (input_var + epsilon).rsqrt()
return X.batchnorm(scale, B, input_mean, invstd)
def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05):
def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
axis = tuple(range(2, x.ndim))
mean = x.mean(axis=axis, keepdim=True)
invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_type=1):
def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
assert stash_type == 1, "only float32 is supported"
axis = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
mean = x.mean(axis=axis, keepdim=True)
return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05):
def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05):
return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
# onnx: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
@@ -213,11 +218,12 @@ def _auto_pad(X: Tensor, auto_pad, strides, kernel_shape, dilations):
# (x1_begin, x2_begin, ..., x1_end, x2_end, ...) -> (..., x2_start, x2_end, x1_start, x1_end)
def _onnx_pads_to_pad2d_pads(pads): return flatten(reversed(list((pB, pE) for pB, pE in zip(pads, pads[len(pads)//2:]))))
def Pad(x: Tensor, pads: Union[Tensor, Tuple[int, ...]], constant_value: Optional[Tensor]=None, axes: Optional[Tensor]=None, mode="constant", value=0):
pads, value, axes = to_python_const(pads), to_python_const(constant_value) or value or 0, to_python_const(axes) or list(range(x.ndim))
def Pad(x:Tensor, pads:List[int], constant_value:Optional[ConstType]=None, axes:Optional[List[int]]=None, mode:str="constant", value=0):
value, axes = constant_value or value or 0, axes or list(range(x.ndim))
real_pads = [0] * (x.ndim*2)
for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)]
return x.pad(padding=_onnx_pads_to_pad2d_pads(to_python_const(real_pads)), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value)
return x.pad(padding=_onnx_pads_to_pad2d_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value)
def AveragePool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, count_include_pad=0, dilations=1, pads=None, strides=1):
pixel_axes = tuple(range(2, X.ndim))
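To make the pad-order comment above concrete, here is the conversion from that hunk run on a two-axis pad spec; the only assumption is that Tensor.pad's flat padding form wants the innermost axis first, which is what the (..., x2_start, x2_end, x1_start, x1_end) comment describes:

from tinygrad.helpers import flatten

def _onnx_pads_to_pad2d_pads(pads): return flatten(reversed(list((pB, pE) for pB, pE in zip(pads, pads[len(pads)//2:]))))

# ONNX orders pads as begins-then-ends per axis: [x1_begin, x2_begin, x1_end, x2_end]
print(_onnx_pads_to_pad2d_pads([1, 2, 3, 4]))  # [2, 4, 1, 3] -> (x2_begin, x2_end, x1_begin, x1_end)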
@@ -244,7 +250,7 @@ def MaxUnpool(xT: Tensor, xI: Tensor, outshape: Optional[Tensor]=None, kernel_sh
arange = Tensor.arange(outlength, requires_grad=False).reshape(1, outlength).expand(xI.shape)
xT = xT.flatten().unsqueeze(1).expand(None, outlength)
ret = ((xI == arange) * xT).sum(0).reshape([1, 1] + out_sh)
if outshape is not None and (outshape := to_python_const(outshape)) != ret.shape:
if outshape is not None and outshape != ret.shape:
diff = [outshape[2] - ret.shape[2], outshape[3] - ret.shape[3]]
pad_args = [diff[0]//2, diff[1]//2, diff[0]-diff[0]//2, diff[1]-diff[1]//2]
ret = ret.pad((pad_args[1], pad_args[3], pad_args[0], pad_args[2]))
@@ -282,31 +288,28 @@ def SpaceToDepth(X:Tensor, blocksize:int):
return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize)
# Reimplemented here because you need legacy RNG for passing ONNX tests.
def Dropout(data: Tensor, ratio=0.5, training_mode=False, seed=None):
if isinstance(ratio, Tensor) and not ratio.shape: ratio = to_python_const(ratio) # ratio and tensor is passed in as Tensor with shape: ()
if isinstance(training_mode, Tensor) and not training_mode.shape: training_mode = to_python_const(training_mode)
def Dropout(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:Optional[int]=None):
if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's.
rng = np.random.RandomState(seed)
if isinstance(ratio, Tensor): ratio = ratio.item()
mask = Tensor(rng.random(data.shape) >= ratio, requires_grad=False, device=data.device)
mask = Tensor(np.random.RandomState(seed).random(cast(Tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device)
return data * mask * (1/(1.0 - ratio)), mask
def LRN(x: Tensor, size, alpha=1e-4, beta=0.75, bias=1.0):
def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0):
pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1)
return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta)
def MeanVarianceNormalization(x: Tensor, axis=(0, 2, 3)): return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9)
def MeanVarianceNormalization(x:Tensor, axis:List[int]=[0,2,3]):
return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9)
def NegativeLogLikelihoodLoss(x: Tensor, target: Tensor, weight=None, ignore_index=None, reduction="mean"):
def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Optional[Tensor]=None, ignore_index:Optional[int]=None, reduction:str="mean"):
return x.nll_loss(target, weight, ignore_index, reduction)
def SoftmaxCrossEntropyLoss(scores: Tensor, labels: Tensor, weights=None, ignore_index=None, reduction="mean"):
def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Optional[Tensor]=None, ignore_index:Optional[int]=None, reduction:str="mean"):
log_probs = scores.log_softmax(1)
return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs
def ArrayFeatureExtractor(x: Tensor, indices: Tensor): return x[..., indices]
def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices]
def Gather(x: Tensor, indices: Tensor, axis=0):
def Gather(x:Tensor, indices:Tensor, axis:int=0):
if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices
x_sh = list(x.shape)
ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
@@ -318,17 +321,17 @@ def Gather(x: Tensor, indices: Tensor, axis=0):
return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]
def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated
def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Optional[str]=None):
def ScatterElements(x:Tensor, indices:Tensor, updates:Tensor, axis:int=0, reduction:Optional[str]=None):
if reduction in {"min", "max"}: raise NotImplementedError("min and max reduction not supported")
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.scatter(axis, indices, updates, reduction)
def GatherElements(x: Tensor, indices: Tensor, axis):
def GatherElements(x:Tensor, indices:Tensor, axis:int):
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.gather(axis, indices)
def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None, coordinate_transformation_mode='half_pixel',
cubic_coeff_a=-0.75, exclude_outside=0, extrapolation_value=0.0, keep_aspect_ratio_policy='stretch',
mode='nearest', nearest_mode='round_prefer_floor'):
def Resize(X:Tensor, roi:Optional[List[float]]=None, scales:Optional[List[float]]=None, sizes:Optional[List[int]]=None, antialias:int=0,
axes:Optional[List[int]]=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0,
extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'):
def _apply_nearest_mode(index: Tensor, input_dim, mode: str):
if mode == "round_prefer_floor": index = (index - 0.5).ceil()
elif mode == "round_prefer_ceil": index = (index + 0.5).floor()
@@ -350,7 +353,6 @@ def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None,
else: raise ValueError(f"invalid {coordinate_transformation_mode=}")
return index.clip(0, input_dim-1)
roi, scales, sizes = (to_python_const(a) for a in (roi, scales, sizes))
scales, sizes = (None if scales is None else scales[-2:]), (None if sizes is None else sizes[-2:])
# we pre permute the axes and permute back after resize
axes, input_shape, = (axes or list(range(X.ndim))), X.shape[2:],
@@ -386,8 +388,7 @@ def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None,
if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented")
return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X
def CenterCropPad(t: Tensor, shape: Tensor, axes=None):
shape = to_python_const(shape)
def CenterCropPad(t:Tensor, shape:List[int], axes:Optional[List[int]]=None):
shrink_arg = [None] * t.ndim
pad_arg = [None] * t.ndim
for s, x in zip(shape, axes or range(t.ndim)):
@@ -396,28 +397,27 @@ def CenterCropPad(t: Tensor, shape: Tensor, axes=None):
elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2)
return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))
def OneHot(indices: Tensor, depth: Tensor, values: Tensor, axis=-1):
depth = int(to_python_const(depth))
def OneHot(indices:Tensor, depth:Union[int, float], values:Tensor, axis:int=-1):
# Scalar or Rank 1 tensor containing exactly one element
depth, indices = depth[0] if isinstance(depth, list) else depth, (indices < 0).where(indices+depth, indices),
depth = int(depth)
indices = (indices < 0).where(indices+depth, indices)
return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0])
def Compress(inp: Tensor, condition: Tensor, axis=None):
def Compress(inp:Tensor, condition:List[bool], axis:Optional[int]=None):
if axis is None:
inp = inp.flatten()
axis = 0
if axis < 0: axis += inp.ndim
con_np = to_python_const(condition)
con = Tensor(np.arange(condition.shape[0])[con_np]) # no boolean indexing in Tensor
con = Tensor(np.arange(len(condition))[condition]) # no boolean indexing in Tensor
return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]
def EyeLike(x: Tensor, dtype=None, k=0):
ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype else x.dtype)
def EyeLike(x:Tensor, dtype:Optional[int]=None, k:int=0):
ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype)
return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.size(0)-k) for d in x.shape))
def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode)
def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated
def DequantizeLinear(x: Tensor, x_scale: Tensor, x_zero_point: Union[Tensor, int] = 0, axis=1, block_size=0):
def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Union[Tensor, int] = 0, axis:int=1, block_size:int=0):
if axis < 0: axis += x.ndim
if not isinstance(x_zero_point, Tensor): x_zero_point = Tensor(x_zero_point)
if block_size: x_zer, x_sc = x_zero_point.repeat_interleave(block_size, axis), x_scale.repeat_interleave(block_size, axis)
@@ -427,18 +427,17 @@ def DequantizeLinear(x: Tensor, x_scale: Tensor, x_zero_point: Union[Tensor, int
return ((x.float() - x_zer) * x_sc).cast(x_scale.dtype)
# copied from https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_image_decoder.py
# without importing PIL we'll have to manually decode a bunch of image formats like PNG, JPEG, WebP, etc
def ImageDecoder(encoded_stream: Tensor, pixel_format="RGB"):
def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"):
try: import PIL.Image
except ImportError as e: raise ImportError("Pillow must be installed to use the reference implementation of the ImageDecoder operator") from e
img = PIL.Image.open(io.BytesIO(to_python_const(encoded_stream)))
img = PIL.Image.open(io.BytesIO(encoded_stream))
if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1]
if pixel_format == "RGB": return Tensor(np.array(img))
if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1)
raise ValueError(f"pixel_format={pixel_format!r} is not supported.")
def AffineGrid(theta: Tensor, size: Tensor, align_corners=0):
N, _, *spatial_dims = to_python_const(size)
def AffineGrid(theta:Tensor, size:Tensor, align_corners:int=0):
N, _, *spatial_dims = size
def generate_grid(steps):
return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device)
grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims))
@@ -448,30 +447,30 @@ def AffineGrid(theta: Tensor, size: Tensor, align_corners=0):
# **************** com.microsoft Ops ****************
def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma, beta:Optional[Tensor]=None, bias:Optional[Tensor]=None, epsilon=None):
if epsilon is None: epsilon=1e-12
def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Optional[Tensor]=None, bias:Optional[Tensor]=None, epsilon:float=1e-12):
x = x + skip + bias
return x.layernorm(eps=epsilon) * gamma + beta, None, None, x
def FastGelu(x:Tensor, bias:Optional[Tensor]=None):
# this is tanh approximated
return (x + bias).gelu()
return (x + bias).gelu() if bias is not None else x.gelu()
def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None, word_embedding:Tensor=None, position_embedding:Tensor=None, segment_embedding:Optional[Tensor]=None, gamma=None, beta=None, mask:Optional[Tensor]=None, position_ids:Optional[Tensor]=None, epsilon=None, mask_index_type=None):
def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None, word_embedding:Tensor=None, position_embedding:Tensor=None,
segment_embedding:Optional[Tensor]=None, gamma=None, beta=None, mask:Optional[Tensor]=None,
position_ids:Optional[Tensor]=None, epsilon=1e-12, mask_index_type=0):
# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
assert (segment_ids is None) is (segment_embedding is None)
assert (mask is None) is (mask_index_type is None)
assert mask is None, "functionality not supported yet" # TODO
assert mask is None and not mask_index_type, "functionality not supported yet" # TODO
input_shape = input_ids.shape
seq_length = input_shape[1]
compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
vocab_size, max_position_embeddings, type_vocab_size = word_embedding.shape[0], position_embedding.shape[0], (segment_embedding.shape[0] if compute_seg_emb else None)
vocab_size, max_position_embeddings = word_embedding.shape[0], position_embedding.shape[0]
type_vocab_size = (segment_embedding.shape[0] if compute_seg_emb else None)
def embedding(x:Tensor, vocab_size, weight:Tensor) -> Tensor:
return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight
# bert embedding layer
if epsilon is None: epsilon = 1e-12
if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
@@ -482,11 +481,15 @@ def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None
out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
return out, None, embedding_sum
def Attention(x:Tensor, weights, bias:Optional[Tensor]=None, mask_index:Optional[Tensor]=None, past:Optional[Tensor]=None, relative_position_bias:Optional[Tensor]=None, past_sequence_length:Optional[Tensor]=None, do_rotary=None, mask_filter_value=None, num_heads=None, past_present_share_buffer=None, qkv_hidden_sizes=None, scale=None, unidirectional=None):
def Attention(x:Tensor, weights, bias:Optional[Tensor]=None, mask_index:Optional[Tensor]=None, past:Optional[Tensor]=None,
relative_position_bias:Optional[Tensor]=None, past_sequence_length:Optional[Tensor]=None, do_rotary:Optional[int]=None,
mask_filter_value:Optional[float]=None, num_heads:Optional[int]=None, past_present_share_buffer:Optional[int]=None,
qkv_hidden_sizes:Optional[List[int]]=None, scale:Optional[float]=None, unidirectional:Optional[int]=None):
# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
assert num_heads is not None # required
assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
assert relative_position_bias==do_rotary==past_sequence_length==mask_filter_value==past_present_share_buffer==scale==None, "functionality not supported yet" # TODO strange params
assert relative_position_bias==do_rotary==past_sequence_length==mask_filter_value==past_present_share_buffer==scale==None, \
"functionality not supported yet" # TODO strange params
hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)
if unidirectional: # gpt-style
@@ -515,7 +518,7 @@ def Attention(x:Tensor, weights, bias:Optional[Tensor]=None, mask_index:Optional
bsz, _, seq_len, _ = xq.shape
out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
return out, present
return out, present if past is not None else out
# **************** ai.onnx.preview.training Ops ****************
# NOTE: onnx test coverage only covers `T==0` cases, so for all `T>0` this isn't tested
@@ -526,10 +529,10 @@ from tinygrad.nn.optim import SGD
def onnx_training(input_group_size):
def _decorator(func):
def __wrapper(R, T, *inputs, **kwargs):
def __wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs):
old_training = Tensor.training
Tensor.training = True
T, R = to_python_const(T), R.detach()
R = R.detach()
groups = len(inputs) // input_group_size
ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))]
Tensor.training = old_training
@@ -538,7 +541,7 @@ def onnx_training(input_group_size):
return _decorator
@onnx_training(3)
def Adagrad(R, T, *inputs, decay_factor=0.0, epsilon=0.0, norm_coefficient=0.0):
def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0):
X, G, H = (i.detach() for i in inputs)
grad = norm_coefficient * X + G
H.assign(H + grad.square())
@@ -548,7 +551,8 @@ def Adagrad(R, T, *inputs, decay_factor=0.0, epsilon=0.0, norm_coefficient=0.0):
return [X, H]
@onnx_training(4)
def Adam(R, T, *inputs, alpha=0.9, beta=0.999, epsilon=0.0, norm_coefficient=0.0, norm_coefficient_post=0.0):
def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0,
norm_coefficient_post:float=0.0):
X, G, V, H = inputs
G, V, H = G.detach(), V.detach(), H.detach() # TODO we shouldn't need these detaches
X.grad = norm_coefficient * X.detach() + G
@@ -565,7 +569,7 @@ def Adam(R, T, *inputs, alpha=0.9, beta=0.999, epsilon=0.0, norm_coefficient=0.0
return [X, V, H]
@onnx_training(3)
def Momentum(R, T, *inputs, alpha, beta, mode, norm_coefficient):
def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float):
X, G, V = inputs
G, V = G.detach(), V.detach()
X.grad = (norm_coefficient * X.detach() + G) * (beta if T > 0 else 1)