Merge branch 'master' into retinanet_mlperf

Francis Lata
2025-02-03 19:54:55 +00:00
17 changed files with 1010 additions and 977 deletions


@@ -1,6 +1,5 @@
import sys, onnx, time
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, fetch
from tinygrad.tensor import _from_np_dtype
from extra.onnx import OnnxRunner
def load_onnx_model(fn):
@@ -18,19 +17,25 @@ def load_onnx_model(fn):
run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())), prune=True)
return run_onnx_jit, input_shapes, input_types
def get_new_inputs(input_shapes):
#from tinygrad.tensor import _from_np_dtype
#return {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
import numpy as np
return {k:Tensor(np.random.uniform(size=shp).astype(input_types[k]) * 8).realize() for k,shp in sorted(input_shapes.items())}
if __name__ == "__main__":
run_onnx_jit, input_shapes, input_types = load_onnx_model(sys.argv[1])
print("loaded model")
for i in range(3):
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
new_inputs = get_new_inputs(input_shapes)
GlobalCounters.reset()
print(f"run {i}")
run_onnx_jit(**new_inputs)
# run 20 times
for _ in range(20):
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
new_inputs = get_new_inputs(input_shapes)
GlobalCounters.reset()
st = time.perf_counter()
out = run_onnx_jit(**new_inputs)
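For reference, a minimal sketch of what the refactored get_new_inputs produces, with hypothetical shapes and dtypes standing in for whatever the loaded model declares:

import numpy as np
from tinygrad import Tensor

input_shapes = {"input": (1, 3, 224, 224)}  # hypothetical shape parsed from an ONNX model
input_types = {"input": np.float32}         # hypothetical numpy dtype parsed from the model
new_inputs = {k: Tensor(np.random.uniform(size=shp).astype(input_types[k]) * 8).realize()
              for k, shp in sorted(input_shapes.items())}
# the script itself is invoked as: python3 examples/benchmark_onnx.py /path/to/model.onnx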


@@ -17,6 +17,10 @@ from tinygrad.helpers import fetch, getenv
# https://huggingface.co/qualcomm/MobileNet-v2-Quantized/resolve/main/MobileNet-v2-Quantized.onnx
# ~35% - https://github.com/axinc-ai/onnx-quantization/raw/refs/heads/main/models/mobilenev2_quantized.onnx
# QUANT=1 python3 examples/test_onnx_imagenet.py
# https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx
# VIZ=1 DONT_REALIZE_EXPAND=1 python3 examples/benchmark_onnx.py /tmp/model.quant.onnx
def imagenet_dataloader(cnt=0):
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
@@ -61,7 +65,7 @@ if __name__ == "__main__":
assert shape[1:] == (3,224,224), f"shape is {shape}"
hit = 0
for i,(img,y) in enumerate(imagenet_dataloader()):
for i,(img,y) in enumerate(imagenet_dataloader(cnt=100)):
p = run_onnx_jit(**{t_name:img})
assert p.shape == (1,1000)
t = p.argmax().item()


@@ -37,7 +37,7 @@ if __name__ == "__main__":
print("mmapped", hex(res))
to_mv(res, 0x10)[1] = 0xaa
from tinygrad.runtime.ops_clang import ClangCompiler
from tinygrad.runtime.ops_dsp import ClangCompiler
cc = ClangCompiler(args=["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib"])
obj = cc.compile("""

extra/dsp/opt.py (new file, 27 lines)

@@ -0,0 +1,27 @@
from tinygrad.runtime.ops_dsp import DSPCompiler
# PATH=/opt/homebrew/opt/llvm/bin:$PATH python3 extra/dsp/opt.py
if __name__ == "__main__":
compiler = DSPCompiler()
lib = compiler.compile("""
typedef long HVX_Vector __attribute__((__vector_size__(128))) __attribute__ ((aligned(128)));
typedef long HVX_VectorPair __attribute__((__vector_size__(256))) __attribute__ ((aligned(256)));
void test(unsigned char *c, unsigned char *a, unsigned char *b) {
HVX_Vector t0 = *(HVX_Vector*)a;
//HVX_VectorPair t1 = *((HVX_VectorPair*)b);
HVX_Vector acc = __builtin_HEXAGON_V6_vd0_128B();
for (int i = 0; i < 128; i++) {
//__builtin_HEXAGON_V6_lvsplatb_128B(t0[i])
//acc += __builtin_HEXAGON_V6_lvsplatb_128B(t0[i]) * t1;
//acc += t0[i] * t1;
unsigned int t1 = ((unsigned int *)b)[i];
//acc = __builtin_HEXAGON_V6_vrmpyub_acc_128B(acc, t0, t1);
acc = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc, t0, t1);
}
*((HVX_Vector*)c) = acc;
}""")
compiler.disassemble(lib)


@@ -1,8 +1,8 @@
from typing import Callable, Any, Sequence
import importlib, functools, dataclasses
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, DEBUG, all_same
from tinygrad.dtype import DType, ConstType, dtypes
from typing import Any, Sequence, cast, Literal, Callable
import dataclasses, functools, io, math, types
from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr
from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple
from tinygrad.dtype import DType, ConstType, dtypes, ImageDType
from tinygrad.device import is_dtype_supported
# ***** protobuf parsing ******
@@ -111,11 +111,11 @@ limit = int(getenv("ONNXLIMIT", "-1"))
class OnnxRunner:
def __init__(self, model: ModelProto):
# parse model protobuf
self.is_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in model.graph.node)
self.is_training = any(n.domain in {"ai.onnx.training", "ai.onnx.preview.training"} for n in model.graph.node)
self.old_training, self.old_no_grad = Tensor.training, Tensor.no_grad
Tensor.training = True if self.is_training else False
Tensor.no_grad = False if self.is_training else True
self.graph_values = {x.name:buffer_parse(x) for x in model.graph.initializer}
self.graph_values = {"": None, **{x.name:buffer_parse(x) for x in model.graph.initializer}}
self.graph_inputs = {x.name:type_parse(x.type) for x in model.graph.input if x.name not in self.graph_values}
self.graph_outputs = {x.name:type_parse(x.type) for x in model.graph.output}
self.graph_nodes = tuple(OnnxNode(num, n.op_type, tuple(n.input), tuple(n.output), {x.name:attribute_parse(x) for x in n.attribute})
@@ -123,14 +123,7 @@ class OnnxRunner:
self.opset_version = model.opset_import[0].version
self.variable_dims: dict[str, int] = {}
# TODO: move extra.onnx_ops here so we don't have to deal with annoying circular import
# TODO: clean up opset stuff after moving extra.onnx_ops here
self.onnx_ops_module = importlib.import_module('extra.onnx_ops')
self.onnx_ops = {
**{op: getattr(Tensor, op.lower()) for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan",
"Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh",
"Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")},
}
self.onnx_ops = onnx_ops
def _parse_input(self, name: str, value: Any, spec: OnnxValue):
if spec.is_optional and value is None: return None
@@ -148,9 +141,8 @@ class OnnxRunner:
return tensor
def _dispatch_op(self, op, inps, opts):
if op in self.onnx_ops: return self.onnx_ops[op](*inps, **opts)
if hasattr(self.onnx_ops_module, op):
fxn = getattr(self.onnx_ops_module, op)
if op in self.onnx_ops:
fxn = self.onnx_ops[op]
if isinstance(fxn, dict):
for k in sorted(fxn.keys()):
if k <= self.opset_version:
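To make the new dispatch concrete: per-opset dicts such as Softmax = {1: Softmax_1, 13: Softmax_13} are resolved against the model's opset_version. A minimal standalone sketch of that rule (illustrative only, assuming the loop keeps the newest opset that does not exceed the model's version):

def resolve_versioned_op(fxn, opset_version):
  if not isinstance(fxn, dict): return fxn      # plain callables are used as-is
  chosen = None
  for k in sorted(fxn.keys()):                  # walk opsets in ascending order
    if k <= opset_version: chosen = fxn[k]      # keep the newest one we are allowed to use
  return chosen

# e.g. an opset-11 model resolves Softmax to Softmax_1, an opset-13 (or newer) model to Softmax_13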
@@ -165,7 +157,7 @@ class OnnxRunner:
self.graph_values[name] = self._parse_input(name, inputs[name], input_spec)
for node in self.graph_nodes:
inps = [to_python_const(self.graph_values.get(name), node.op, i) for i,name in enumerate(node.inputs)]
inps = [to_python_const(self.graph_values[name], node.op, i) for i,name in enumerate(node.inputs)]
opts = node.opts
# provide additional opts
@@ -184,4 +176,623 @@ class OnnxRunner:
Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
return {name:self.graph_values[name] for name in node.outputs}
Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
return {name:self.graph_values[name] for name in self.graph_outputs}
return {name:self.graph_values[name] for name in self.graph_outputs}
####################
##### ONNX OPS #####
####################
def get_onnx_ops():
# ***** helper functions *****
def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None)
# (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...)
def _onnx_pads_to_tiny_pads(pads): return tuple(flatten(reversed(list(zip(pads, pads[len(pads)//2:])))))
AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"]
# (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right)
def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS):
if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))]
return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))]
def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS):
i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_))
if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2)
o_ = [((i - (1 if auto_pad in ("SAME_UPPER", "SAME_LOWER") else k)) // s + 1) for i,k,s in zip(i_, k_, s_)]
return _onnx_pads_to_tiny_pads(_auto_pad([(o-1)*s+k-i for o,i,k,s in zip(o_, i_, k_, s_)], auto_pad))
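# worked example: for a 2D pool, ONNX pads [top, left, bottom, right] = [1, 2, 3, 4] zip into [(1, 3), (2, 4)];
# reversing and flattening gives the tinygrad order (2, 4, 1, 3), i.e. (left, right, top, bottom)
# _onnx_pads_to_tiny_pads([1, 2, 3, 4]) == (2, 4, 1, 3)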
def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype)
def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0):
if axis < 0: axis += x.ndim
if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape)
if block_size == 0:
shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim))
return scale.reshape(shape), zero_point.reshape(shape)
return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis)
def _op_integer(op, inputs:list[Tensor], zero_points:list[Tensor], **opts):
adjusted_inputs = [inp.int() - zp for inp, zp in zip(inputs, zero_points)]
return op(*adjusted_inputs, **opts)
def _qlinearop_quantized(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
# op execution is done in quantized int
out = _op_integer(op, inputs, zero_points, **opts)
assert dtypes.is_int(out.dtype), "quantized op should've done math in int"
out_quantized = (out * prod(scales) / out_scale).round() + out_zero_point
return _clamp_cast(out_quantized, out_zero_point.dtype)
def _qlinearop_float(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
# op execution is done in float32
dequantized_inputs = [(inp.int() - zp) * scale for inp, zp, scale in zip(inputs, zero_points, scales)]
out = op(*dequantized_inputs, **opts)
assert dtypes.is_float(out.dtype), "op should've done math in float"
out_quantized = (out / out_scale).round() + out_zero_point
return _clamp_cast(out_quantized, out_zero_point.dtype)
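# worked example (a multiply with uint8 x=12, zero_point 2, scale 0.5 and w=6, zero_point 0, scale 0.25):
# the quantized path computes (12-2)*(6-0)=60 in int, then (60*0.5*0.25/out_scale).round()+out_zero_point,
# which for out_scale=0.1 and out_zero_point=128 is 75+128=203, clamped and cast to the zero point's uint8 dtype;
# the float path dequantizes first ((12-2)*0.5=5.0, 6*0.25=1.5), multiplies in float (7.5), then requantizes to the same 203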
def _onnx_training(input_group_size):
def __decorator(func):
def ___wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs):
R = R.detach()
groups = len(inputs) // input_group_size
ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))]
return tuple(flatten(zip(*ret)))
return ___wrapper
return __decorator
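# e.g. Adagrad below uses input_group_size=3: a call with inputs (X1, X2, G1, G2, H1, H2) has groups=2,
# inputs[i::groups] regroups them into (X1, G1, H1) and (X2, G2, H2), func runs once per group,
# and zip/flatten interleaves the per-group outputs back into (X1', X2', H1', H2')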
# ***** Property/Graph Ops *****
def Identity(x:Tensor): return x
def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, value_floats:list[float]|None=None,
value_int:int|None=None, value_ints:list[int]|None=None, value_string:str|None=None, value_strings:list[str]|None=None):
if value is not None: return value
if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
if value_string is not None or value_strings is not None or sparse_value is not None:
raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value')
def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta)
def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"):
try: import PIL.Image
except ImportError as e: raise ImportError("Pillow must be installed for the ImageDecoder operator") from e
img = PIL.Image.open(io.BytesIO(encoded_stream))
if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1]
if pixel_format == "RGB": return Tensor(np.array(img))
if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1)
raise ValueError(f"pixel_format={pixel_format!r} is not supported.")
def EyeLike(x:Tensor, dtype:int|None=None, k:int=0):
ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype)
return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape))
def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0)
def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([])
def ConstantOfShape(shape:list[int], value:Tensor|None=None):
if value is None: value = Tensor(0, dtype=dtypes.float32)
return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1)
def Size(data:Tensor): return data.numel()
def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64)
# ***** Unary Ops (math) *****
def Not(x:Tensor): return x.logical_not()
def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None):
return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype)
# ***** Unary Ops (activation) *****
def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis)
def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis)
Softmax = {1:Softmax_1, 13:Softmax_13}
def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1)
def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
def FastGelu(x:Tensor, bias:Tensor|None=None):
# this is tanh approximated
return (x + bias).gelu() if bias is not None else x.gelu()
# TODO: fix this
def PRelu(X:Tensor, slope:Tensor):
slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope
return (X > 0).where(X, X * slope)
def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha)
def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0)
def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis)
def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float()
# ***** Unary Ops (broadcasted) *****
def Add(x:Tensor,y:Tensor, broadcast=None, axis=None): return x + y if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + y).cast(x.dtype)
def Sub(x:Tensor|int,y:Tensor): return x - y # some test has input as int
def Div(x:Tensor,y:Tensor): return (x/y).cast(x.dtype)
def Less(x:Tensor,y:Tensor): return x < y
def LessOrEqual(x:Tensor,y:Tensor): return x <= y
def Greater(x:Tensor,y:Tensor): return x > y
def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y
def Equal(x:Tensor,y:Tensor): return x == y
def And(x:Tensor,y:Tensor): return (x==y).where(x, False)
def Or(x:Tensor,y:Tensor): return (x==y).where(x, True)
def BitwiseAnd(x:Tensor,y:Tensor): return x & y
def BitwiseOr(x:Tensor,y:Tensor): return x | y
def BitwiseXor(x:Tensor,y:Tensor): return x ^ y
def BitwiseNot(x:Tensor): return ~x
# ***** Casting Ops *****
# TODO: saturate
def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to))
def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype)
# ***** Reduce Ops *****
def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0)
def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0)
def ReduceMax(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMin(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMean(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSumSquare(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return ReduceSum(data.square(), axes, keepdims, noop_with_empty_axes)
def ReduceProd(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return data.prod(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL1(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes)
def ReduceL2(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt()
def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log()
def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0):
if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0):
return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
# ***** Movement Ops *****
def Reshape(data:Tensor, shape:list[int], allowzero:int=0):
return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)])
def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1)
def Expand(x:Tensor, shape:list[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape)))
def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)
def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm)
# TODO: add test for when axes is None
def Squeeze(data:Tensor, axes:list[int]|None=None):
return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data)
def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data)
def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats)
def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis)
def Slice(data:Tensor, starts:list[int], ends:list[int], axes:list[int]|None=None, steps:list[int]|None=None):
axes = axes or list(range(data.ndim))
steps = steps or [1]*data.ndim
slices = [slice(0,x,1) for x in data.shape]
for i, axis in enumerate(axes): slices[axis] = slice(starts[i], ends[i], steps[i])
return data[tuple(slices)]
def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0):
sz = data.shape[axis]
if split is None: split = [sz // num_outputs + (1 if i < sz % num_outputs else 0) for i in range(num_outputs)]
return data.split(split, axis)
def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None,
mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0):
value = constant_value or value
axes = axes or list(range(x.ndim))
real_pads = [0] * (x.ndim*2)
for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)]
return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value)
def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None):
shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim
pad_arg:list[None|tuple[int,int]] = [None] * t.ndim
for s, x in zip(shape, axes or range(t.ndim)):
tx = t.shape[x]
if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2)
elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2)
return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))
# ***** Processing Ops *****
def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0,
dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1):
return X.avg_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad),
ceil_mode=ceil_mode, count_include_pad=count_include_pad)
def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0,
storage_order:int=0, strides:list[int]|int=1):
ret = X.max_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), ceil_mode=ceil_mode)
# tests expect indices with int64 dtype
# TODO: if there are repeated values, this is wrong
indices = ((ret.reshape(-1, 1) == X.reshape(1, -1)) * Tensor.arange(X.numel(), dtype=dtypes.int64).unsqueeze(0)).sum(1).reshape(ret.shape)
return ret.cast(X.dtype), indices.transpose(-2, -1) if storage_order else indices
def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations,
padding=_resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad))
def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0,
strides:list[int]|int=1):
input_shape, kernel_shape = X.shape[2:], (kernel_shape or W.shape[2:])
strides, dilations, output_padding = (make_tuple(x, len(input_shape)) for x in (strides, dilations, output_padding))
if output_shape is not None: # we pad according to output_shape
pads = _auto_pad([s*(i-1) + op + ((k-1)*d+1) - os for s,i,op,k,d,os in
zip(strides, input_shape, output_padding, kernel_shape, dilations, output_shape)], auto_pad)
if pads is None: # we generate pads
output_shape = output_shape or [X.shape[i+2] * strides[i] for i in range(len(strides))]
pads = [strides[i]*(input_shape[i]-1) + output_padding[i] + ((kernel_shape[i]-1)*dilations[i]+1)-output_shape[i] for i in range(len(input_shape))]
pads = _auto_pad(pads, auto_pad) if auto_pad != "NOTSET" else [0] * len(input_shape) * 2
pads = _onnx_pads_to_tiny_pads(pads)
return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads, output_padding=output_padding)
def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
pads, strides = (make_tuple(x, len(xI.shape)) for x in (pads, strides))
out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)]
ret = (xI.reshape(-1, 1)._one_hot_along_dim(prod(out_sh)) * xT.reshape(-1, 1)).sum(0).reshape(1, 1, *out_sh)
if outshape is not None and outshape != ret.shape: pads = _auto_pad([outshape[-2] - ret.shape[-2], outshape[-1] - ret.shape[-1]], "SAME_UPPER")
return ret.pad(_onnx_pads_to_tiny_pads(pads))
def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)
def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0):
ret = alpha * (A.transpose(transA) @ B.transpose(transB))
if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
return ret
def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs)
def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0):
axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis)
if reverse: X = X.flip(axis)
if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\
.shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim)))
return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis)
def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k)
def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, sizes:list[int]|None=None, antialias:int=0,
axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0,
extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'):
def _apply_nearest_mode(index: Tensor, input_dim, mode: str):
if mode == "round_prefer_floor": index = (index - 0.5).ceil()
elif mode == "round_prefer_ceil": index = (index + 0.5).floor()
elif mode in ["floor", "ceil"]: index = getattr(index, mode)()
else: raise ValueError(f"invalid {nearest_mode=}")
return index.cast(dtypes.int32).clip(0, input_dim-1)
def _apply_transformation(index: Tensor, input_dim, scale_dim, roi_dim, mode):
# TODO: needs more testing, not confident in this
# NOTE: their reference implementation differs from the implementation in their reference docs
# https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_resize.py
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize
output_dim = scale_dim * input_dim
if mode == "half_pixel": index = (index + 0.5) / scale_dim - 0.5
elif mode == "align_corners": index = index * (input_dim - 1) / (output_dim - 1) if output_dim != 1 else Tensor([0])
elif mode == "asymmetric": index = index / scale_dim
elif mode == "pytorch_half_pixel": index = (index + 0.5) / scale_dim - 0.5 if output_dim != 1 else Tensor([-0.5])
elif mode == "half_pixel_symmetric": index = input_dim / 2 * (1 - int(output_dim) / output_dim) + (index + 0.5) / scale_dim - 0.5
elif mode == "tf_crop_and_resize": index = roi_dim[0] * (input_dim - 1) + index * ((roi_dim[1] - roi_dim[0]) * (input_dim - 1) / (output_dim - 1))
else: raise ValueError(f"invalid {coordinate_transformation_mode=}")
return index.clip(0, input_dim-1)
scales, sizes = (None if scales is None else scales[2-(X.ndim-len(scales)):]), (None if sizes is None else sizes[2-(X.ndim-len(sizes)):])
# we pre-permute the axes and permute back after resize
axes, input_shape, = (axes or list(range(X.ndim))), cast(tuple[int, ...], X.shape[2:]),
perm = [a for a in range(len(X.shape)) if a not in axes] + list(axes)
X = X.permute(*perm)
if sizes is not None:
if keep_aspect_ratio_policy in ["not_larger", "not_smaller"]:
scale_fxn = min if keep_aspect_ratio_policy == "not_larger" else max
scales = [scale_fxn([sizes[i] / input_shape[i] for i in range(len(input_shape)) if i+2 in axes])] * 2
sizes = [int((scales[0] * input_shape[i]) + 0.5) if i+2 in axes else input_shape[i] for i in range(X.ndim-2)]
else:
scales = [size / input_shape for size, input_shape in zip(sizes, input_shape)]
else:
sizes = [int(sc*sh) for sc, sh in zip(scales, input_shape)]
regions = [[st, ed] for st, ed in zip(roi, roi[len(roi)//2:])] if isinstance(roi, list) and roi else [[0.0, 0.0]] * (X.ndim-2)
# NOTE: this transformation makes it so that we can't just call Tensor.interpolate
# in Tensor.interpolate, we use indexes without any transformation
indexes = []
for shape, size, scale, region in zip(input_shape, sizes, scales, regions):
indexes.append(_apply_transformation(Tensor.arange(size), shape, scale, region, coordinate_transformation_mode))
if mode == "nearest":
indexes = [_apply_nearest_mode(index, shape, nearest_mode) for (index, shape) in zip(indexes, input_shape)]
X = X[(..., *Tensor.meshgrid(*indexes))]
if mode == "linear":
expand = list(X.shape)
for i in range(-len(sizes), 0):
reshape, index = [1] * X.ndim, indexes[i]
reshape[i] = expand[i] = sizes[i]
low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
X = X.gather(i, low).lerp(X.gather(i, high), perc)
if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented")
return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X
def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated
# ***** Neural Network Ops *****
# TODO: try to factor out common implementations for these ops
# https://medium.com/@zljdanceholic/groupnorm-then-batchnorm-instancenorm-layernorm-e2b2a1d350a0
def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9,
training_mode:int=0, spatial=1, is_test=0):
if training_mode:
x_detached = X.detach()
current_mean = x_detached.mean(axis=(0,2,3))
y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
current_var = (y*y).mean(axis=(0,2,3))
current_invstd = current_var.add(epsilon).rsqrt()
running_mean = input_mean * momentum + current_mean * (1 - momentum)
running_var = input_var * momentum + current_var * (1 - momentum)
return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
invstd = (input_var + epsilon).rsqrt()
return X.batchnorm(scale, B, input_mean, invstd)
def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
axis = tuple(range(2, x.ndim))
mean = x.mean(axis=axis, keepdim=True)
invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
assert stash_type == 1, "only float32 is supported"
axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
mean = x.mean(axis=axes, keepdim=True)
return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt()
def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05):
return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]):
return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9)
def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12):
x = x + skip + bias
return x.layernorm(eps=epsilon) * gamma + beta, None, None, x
def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor,
segment_embedding:Tensor, gamma=None, beta=None, mask:Tensor|None=None,
position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0):
# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
assert (segment_ids is None) is (segment_embedding is None)
assert mask is None and not mask_index_type, "functionality not supported yet" # TODO
input_shape = input_ids.shape
seq_length = input_shape[1]
compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
vocab_size, max_position_embeddings = word_embedding.shape[0], position_embedding.shape[0]
type_vocab_size = (segment_embedding.shape[0] if compute_seg_emb else None)
def embedding(x:Tensor, vocab_size, weight:Tensor) -> Tensor:
return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight
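# the one-hot matmul above is a gather: e.g. token id 3 with vocab_size 5 becomes the row [0, 0, 0, 1, 0],
# and multiplying by the (vocab_size, hidden) weight matrix selects row 3 of the embedding table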
# bert embedding layer
if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None
embedding_sum = wrd_embedding_res + pos_embedding_res
if seg_embedding_res is not None: embedding_sum = embedding_sum + seg_embedding_res
out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
return out, None, embedding_sum
def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1):
# Scalar or Rank 1 tensor containing exactly one element
depth = int(depth[0] if isinstance(depth, list) else depth)
indices = (indices < 0).where(indices+depth, indices)
return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0])
def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"):
return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize)
def SpaceToDepth(X:Tensor, blocksize:int):
return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize)
# Reimplemented here because you need legacy RNG for passing ONNX tests.
def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None):
if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's.
mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device)
return data * mask * (1/(1.0 - ratio)), mask
# 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test)
Dropout = {6:Dropout_6, 7:Dropout_7}
def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0):
pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1)
return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta)
def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
return x.nll_loss(target, weight, ignore_index, reduction)
def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
log_probs = scores.log_softmax(1)
return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs
def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0):
N, _, *spatial_dims = size
def generate_grid(steps):
return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device)
grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims))
base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1)
base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1)
return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)
def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past:Tensor|None=None,
relative_position_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int|None=None,
mask_filter_value:float|None=None, num_heads:int|None=None, past_present_share_buffer:int|None=None,
qkv_hidden_sizes:list[int]|None=None, scale:float|None=None, unidirectional:int|None=None):
# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
assert num_heads is not None # required
assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
assert relative_position_bias is do_rotary is past_sequence_length is mask_filter_value is past_present_share_buffer is scale is None, \
"functionality not supported yet" # TODO strange params
hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)
if unidirectional: # gpt-style
assert hidden_size == v_hidden_size
xqkv = x.linear(weights, bias)
xq, xk, xv = [xqkv.shrink([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
else: # bert-style
wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size:]) if bias is not None else None
xq, xk, xv = [x.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]
if past is not None:
xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))
def attn(query, key, value, attn_mask):
query_length, key_length = query.shape[-2], key.shape[-2]
cdim = max(query_length, key_length) + 1
attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
# This is where Tensor.scaled_dot_product_attention differs:
causal_mask = Tensor.ones((cdim, cdim), requires_grad=False, dtype=dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length]
masked = Tensor.where(causal_mask, attn_weights, -math.inf)
if attn_mask is not None: masked = masked + attn_mask
return masked.softmax(-1) @ value
bsz, _, seq_len, _ = xq.shape
out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
return out, present if past is not None else out
# ***** Indexing Ops *****
def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices]
def Gather(x:Tensor, indices:Tensor, axis:int=0):
if indices.numel() < 9: # NOTE: fewer kernels for small indices, but the kernel count grows with the size of indices
x_sh = list(x.shape)
ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
if indices.ndim > 1: indices = indices.flatten()
indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)]
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore
return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
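# e.g. gathering indices [1, 3] along axis 0 of a (4, 5) tensor takes this path: shrink args ((1, 2), (0, 5)) and
# ((3, 4), (0, 5)) select the two rows, cat joins them along axis 0, and the reshape gives the (2, 5) result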
# NOTE: faster gather with a fixed number of kernels, but it exceeds the kernel limit for openpilot
return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]
def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated
def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0):
if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))]
x_shape, i_shape = x.shape, indices.shape
b = math.prod(x.shape[dim] for dim in range(batch_dims))
# NOTE: each batched dim of input matches the corresponding dim of indices
x = x.reshape(b, *x.shape[batch_dims:])
indices = indices.reshape(b, *indices.shape[batch_dims:])
b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1])
ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))]
return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:])
def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'):
assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):]
x = x.contiguous()
for index, u in zip(indices.split(1, 0), updates.split(1, 0)):
i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1))
u = u.squeeze(0)
if reduction == "none": x[i] = u
elif reduction == "add": x[i] += u
elif reduction == "mul": x[i] *= u
else: raise NotImplementedError("reduction doesn't support max or min")
return x
def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"):
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction))
def GatherElements(x:Tensor, indices:Tensor, axis:int):
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.gather(axis, indices)
def Compress(inp:Tensor, condition:list[bool], axis:int|None=None):
if axis is None:
inp = inp.flatten()
axis = 0
if axis < 0: axis += inp.ndim
con = Tensor(np.arange(len(condition))[condition]) # no boolean indexing in Tensor
return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]
# ***** Quantization Ops *****
def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous()
def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)
return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype)
def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor,
y_zero_point: Tensor|int, B:Tensor|None=None, **opts):
return _qlinearop_quantized(Conv, [x,w], [x_zero_point,w_zero_point], [x_scale,w_scale], y_scale, y_zero_point, **{"B":B, **opts})
def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor,
y_zero_point:Tensor|int) -> Tensor:
return _qlinearop_quantized(Tensor.matmul, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], y_scale, y_zero_point)
def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor):
return _qlinearop_float(Tensor.add, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], c_scale, c_zero_point)
def QLinearGlobalAveragePool(X:Tensor, x_scale:Tensor, x_zero_point:Tensor, y_scale:Tensor, y_zero_point:Tensor, channels_last:int):
assert channels_last == 0, "unsure what this does"
return _qlinearop_float(GlobalAveragePool, [X], [x_zero_point], [x_scale], y_scale, y_zero_point)
def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None, **opts) -> Tensor:
return _op_integer(Conv, [x,w], [x_zero_point,w_zero_point], **{"B":B, **opts})
def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor:
return _op_integer(Tensor.matmul, [A,B], [a_zero_point,b_zero_point])
# ***** Training Ops *****
# NOTE: onnx test coverage only covers `T==0` cases, so for all `T>0` this isn't tested
# NOTE: the onnx training ops don't actually need optimizer state; they all work functionally, but we can still reuse the optim.py code
@_onnx_training(3)
def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0):
X, G, H = (i.detach() for i in inputs)
grad = norm_coefficient * X + G
H.assign(H + grad.square())
up = grad / (H.sqrt() + epsilon)
r = R / (1 + T * decay_factor)
X.assign(X.detach() - r * up)
return [X, H]
@_onnx_training(4)
def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0,
norm_coefficient_post:float=0.0):
from tinygrad.nn.optim import Adam as TinyAdam
X, G, V, H = inputs
G, V, H = G.detach(), V.detach(), H.detach() # TODO we shouldn't need these detaches
X.grad = norm_coefficient * X.detach() + G
opt = TinyAdam([X], b1=alpha, b2=beta, eps=epsilon)
opt.m, opt.v, opt.lr = [V], [H], R
# need no-op for m_hat and v_hat if T == 0
if T == 0: opt.b1_t, opt.b2_t = opt.b1_t.zeros_like(), opt.b2_t.zeros_like()
else:
# `T-1` since it's applied again at the start of `_step`
opt.b1_t = Tensor([alpha**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
opt.b2_t = Tensor([beta**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
opt.step()
X = (1 - norm_coefficient_post) * X
return [X, V, H]
@_onnx_training(3)
def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float):
from tinygrad.nn.optim import SGD
X, G, V = inputs
G, V = G.detach(), V.detach()
X.grad = (norm_coefficient * X.detach() + G) * (beta if T > 0 else 1)
opt = SGD([X], momentum=alpha, nesterov=(mode=="nesterov"))
opt.b, opt.lr = [V], R
opt.step()
return [X, V]
def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **_):
intermediate_tensors[y].backward()
return tuple([t.grad for t in inputs])
return {
# Tensor ops
**{op: getattr(Tensor, op.lower()) for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan",
"Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh",
"Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")},
# Implemented ops
**{name:obj for name,obj in locals().items() if isinstance(obj, types.FunctionType) and not name.startswith("_") and name[0].isupper()},
# Version ops
**{name:obj for name,obj in locals().items() if isinstance(obj, dict)},
}
onnx_ops = get_onnx_ops()


@@ -1,606 +0,0 @@
import functools, io, math
from typing import cast, Literal
from tinygrad.tensor import Tensor, _broadcast_shape, ConstType, ReductionStr
from tinygrad.dtype import ImageDType, dtypes, DType
from tinygrad.helpers import prod, flatten, make_tuple
from extra.onnx import dtype_parse, _cached_to_python_const
import numpy as np
# ***** Property/Graph Ops *****
def Identity(x:Tensor): return x
def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, value_floats:list[float]|None=None,
value_int:int|None=None, value_ints:list[int]|None=None, value_string:str|None=None, value_strings:list[str]|None=None):
if value is not None: return value
if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
if value_string is not None or value_strings is not None and sparse_value is not None:
raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value')
def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta)
def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"):
try: import PIL.Image
except ImportError as e: raise ImportError("Pillow must be installed for the ImageDecoder operator") from e
img = PIL.Image.open(io.BytesIO(encoded_stream))
if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1]
if pixel_format == "RGB": return Tensor(np.array(img))
if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1)
raise ValueError(f"pixel_format={pixel_format!r} is not supported.")
def EyeLike(x:Tensor, dtype:int|None=None, k:int=0):
ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype)
return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape))
def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0)
def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([])
def ConstantOfShape(shape:list[int], value:Tensor|None=None):
if value is None: value = Tensor(0, dtype=dtypes.float32)
return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1)
def Size(data:Tensor): return data.numel()
def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64)
# ***** Unary Ops (math) *****
def Not(x:Tensor): return x.logical_not()
def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None):
return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype)
# ***** Unary Ops (activation) *****
def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis)
def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis)
Softmax = {1:Softmax_1, 13:Softmax_13}
def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1)
def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
def FastGelu(x:Tensor, bias:Tensor|None=None):
# this is tanh approximated
return (x + bias).gelu() if bias is not None else x.gelu()
# TODO: fix this
def PRelu(X:Tensor, slope:Tensor):
slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope
return (X > 0).where(X, X * slope)
def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha)
def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0)
def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis)
def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float()
# ***** Unary Ops (broadcasted) *****
def Add(x:Tensor,y:Tensor, broadcast=None, axis=None): return x + y if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + y).cast(x.dtype)
def Sub(x:Tensor|int,y:Tensor): return x - y # some test has input as int
def Div(x:Tensor,y:Tensor): return (x/y).cast(x.dtype)
def Less(x:Tensor,y:Tensor): return x < y
def LessOrEqual(x:Tensor,y:Tensor): return x <= y
def Greater(x:Tensor,y:Tensor): return x > y
def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y
def Equal(x:Tensor,y:Tensor): return x == y
def And(x:Tensor,y:Tensor): return (x==y).where(x, False)
def Or(x:Tensor,y:Tensor): return (x==y).where(x, True)
def BitwiseAnd(x:Tensor,y:Tensor): return x & y
def BitwiseOr(x:Tensor,y:Tensor): return x | y
def BitwiseXor(x:Tensor,y:Tensor): return x ^ y
def BitwiseNot(x:Tensor): return ~x
# ***** Casting Ops *****
# TODO: saturate
def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to))
def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype)
# ***** Reduce Ops *****
def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0)
def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0)
def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None)
def ReduceMax(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMin(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMean(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSumSquare(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return ReduceSum(data.square(), axes, keepdims, noop_with_empty_axes)
def ReduceProd(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return data.prod(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL1(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes)
def ReduceL2(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt()
def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log()
def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0):
if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0):
return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
# ***** Movement Ops *****
def Reshape(data:Tensor, shape:list[int], allowzero:int=0):
return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)])
def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1)
def Expand(x:Tensor, shape:list[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape)))
def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)
def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm)
# TODO: add test for when axes is None
def Squeeze(data:Tensor, axes:list[int]|None=None):
return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data)
def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data)
def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats)
def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis)
def Slice(data:Tensor, starts:list[int], ends:list[int], axes:list[int]|None=None, steps:list[int]|None=None):
axes = axes or list(range(data.ndim))
steps = steps or [1]*data.ndim
slices = [slice(0,x,1) for x in data.shape]
for i, axis in enumerate(axes): slices[axis] = slice(starts[i], ends[i], steps[i])
return data[tuple(slices)]
def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0):
sz = data.shape[axis]
if split is None: split = [sz // num_outputs + (1 if i < sz % num_outputs else 0) for i in range(num_outputs)]
return data.split(split, axis)
def _onnx_pads_to_tiny_pads(pads):
# (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...)
return tuple(flatten(reversed(list(zip(pads, pads[len(pads)//2:])))))
def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None,
mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0):
value = constant_value or value
axes = axes or list(range(x.ndim))
real_pads = [0] * (x.ndim*2)
for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)]
return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value)
def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None):
shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim
pad_arg:list[None|tuple[int,int]] = [None] * t.ndim
for s, x in zip(shape, axes or range(t.ndim)):
tx = t.shape[x]
if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2)
elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2)
return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))
# ***** Processing Ops *****
AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"]
def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS):
# (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right)
if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))]
return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))]
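# Worked example (illustrative): a total pad of 3 on a single axis splits as
#   SAME_UPPER -> [3//2] + [3 - 3//2] = [1, 2]  (the odd element of padding goes at the end)
#   SAME_LOWER -> [3 - 3//2] + [3//2] = [2, 1]  (the odd element goes at the beginning)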
def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS):
i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_))
if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2)
o_ = [((i - (1 if auto_pad in ("SAME_UPPER", "SAME_LOWER") else k)) // s + 1) for i,k,s in zip(i_, k_, s_)]
return _onnx_pads_to_tiny_pads(_auto_pad([(o-1)*s+k-i for o,i,k,s in zip(o_, i_, k_, s_)], auto_pad))
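# Worked example (illustrative): one spatial dim with i=5, k=3, s=2 and auto_pad="SAME_UPPER":
#   o = (5-1)//2 + 1 = 3, total pad = (3-1)*2 + 3 - 5 = 2, _auto_pad -> [1, 1],
# so the padded size is 7 and the pool produces ceil(5/2) = 3 outputs.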
def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0,
dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1):
return X.avg_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad),
ceil_mode=ceil_mode, count_include_pad=count_include_pad)
def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0,
storage_order:int=0, strides:list[int]|int=1):
ret = X.max_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), ceil_mode=ceil_mode)
# tests expect indices with int64 dtype
# TODO: if there are repeated values, this is wrong
indices = ((ret.reshape(-1, 1) == X.reshape(1, -1)) * Tensor.arange(X.numel(), dtype=dtypes.int64).unsqueeze(0)).sum(1).reshape(ret.shape)
return ret.cast(X.dtype), indices.transpose(-2, -1) if storage_order else indices
def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations,
padding=_resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad))
def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0,
strides:list[int]|int=1):
input_shape, kernel_shape = X.shape[2:], (kernel_shape or W.shape[2:])
strides, dilations, output_padding = (make_tuple(x, len(input_shape)) for x in (strides, dilations, output_padding))
if output_shape is not None: # we pad according to output_shape
pads = _auto_pad([s*(i-1) + op + ((k-1)*d+1) - os for s,i,op,k,d,os in
zip(strides, input_shape, output_padding, kernel_shape, dilations, output_shape)], auto_pad)
if pads is None: # we generate pads
output_shape = output_shape or [X.shape[i+2] * strides[i] for i in range(len(strides))]
pads = [strides[i]*(input_shape[i]-1) + output_padding[i] + ((kernel_shape[i]-1)*dilations[i]+1)-output_shape[i] for i in range(len(input_shape))]
pads = _auto_pad(pads, auto_pad) if auto_pad != "NOTSET" else [0] * len(input_shape) * 2
pads = _onnx_pads_to_tiny_pads(pads)
return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads, output_padding=output_padding)
def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
pads, strides = (make_tuple(x, len(xI.shape)) for x in (pads, strides))
out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)]
ret = (xI.reshape(-1, 1)._one_hot_along_dim(prod(out_sh)) * xT.reshape(-1, 1)).sum(0).reshape(1, 1, *out_sh)
if outshape is not None and outshape != ret.shape: pads = _auto_pad([outshape[-2] - ret.shape[-2], outshape[-1] - ret.shape[-1]], "SAME_UPPER")
return ret.pad(_onnx_pads_to_tiny_pads(pads))
def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)
def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0):
ret = alpha * (A.transpose(transA) @ B.transpose(transB))
if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
return ret
def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs)
def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0):
axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis)
if reverse: X = X.flip(axis)
if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\
.shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim)))
return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis)
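# Worked example (illustrative): CumSum of X = [1, 2, 3] along axis 0
#   exclusive=0, reverse=0 -> [1, 3, 6]
#   exclusive=1, reverse=0 -> [0, 1, 3]   (shift a leading 0 in, drop the last element before summing)
#   exclusive=0, reverse=1 -> [6, 5, 3]
#   exclusive=1, reverse=1 -> [5, 3, 0]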
def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k)
def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, sizes:list[int]|None=None, antialias:int=0,
axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0,
extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'):
def _apply_nearest_mode(index: Tensor, input_dim, mode: str):
if mode == "round_prefer_floor": index = (index - 0.5).ceil()
elif mode == "round_prefer_ceil": index = (index + 0.5).floor()
elif mode in ["floor", "ceil"]: index = getattr(index, mode)()
else: raise ValueError(f"invalid {nearest_mode=}")
return index.cast(dtypes.int32).clip(0, input_dim-1)
def _apply_transformation(index: Tensor, input_dim, scale_dim, roi_dim, mode):
# TODO: needs more testing, not confident in this
# NOTE: their reference implementation differs from the implementation in their reference docs
# https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_resize.py
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize
output_dim = scale_dim * input_dim
if mode == "half_pixel": index = (index + 0.5) / scale_dim - 0.5
elif mode == "align_corners": index = index * (input_dim - 1) / (output_dim - 1) if output_dim != 1 else Tensor([0])
elif mode == "asymmetric": index = index / scale_dim
elif mode == "pytorch_half_pixel": index = (index + 0.5) / scale_dim - 0.5 if output_dim != 1 else Tensor([-0.5])
elif mode == "half_pixel_symmetric": index = input_dim / 2 * (1 - int(output_dim) / output_dim) + (index + 0.5) / scale_dim - 0.5
elif mode == "tf_crop_and_resize": index = roi_dim[0] * (input_dim - 1) + index * ((roi_dim[1] - roi_dim[0]) * (input_dim - 1) / (output_dim - 1))
else: raise ValueError(f"invalid {coordinate_transformation_mode=}")
return index.clip(0, input_dim-1)
scales, sizes = (None if scales is None else scales[2-(X.ndim-len(scales)):]), (None if sizes is None else sizes[2-(X.ndim-len(sizes)):])
# we pre-permute the axes and permute back after resize
axes, input_shape, = (axes or list(range(X.ndim))), cast(tuple[int, ...], X.shape[2:]),
perm = [a for a in range(len(X.shape)) if a not in axes] + list(axes)
X = X.permute(*perm)
if sizes is not None:
if keep_aspect_ratio_policy in ["not_larger", "not_smaller"]:
scale_fxn = min if keep_aspect_ratio_policy == "not_larger" else max
scales = [scale_fxn([sizes[i] / input_shape[i] for i in range(len(input_shape)) if i+2 in axes])] * 2
sizes = [int((scales[0] * input_shape[i]) + 0.5) if i+2 in axes else input_shape[i] for i in range(X.ndim-2)]
else:
scales = [size / input_shape for size, input_shape in zip(sizes, input_shape)]
else:
sizes = [int(sc*sh) for sc, sh in zip(scales, input_shape)]
regions = [[st, ed] for st, ed in zip(roi, roi[len(roi)//2:])] if isinstance(roi, list) and roi else [[0.0, 0.0]] * (X.ndim-2)
# NOTE: this transformation makes it so that we can't just call Tensor.interpolate
# in Tensor.interpolate, we use indexes without any transformation
indexes = []
for shape, size, scale, region in zip(input_shape, sizes, scales, regions):
indexes.append(_apply_transformation(Tensor.arange(size), shape, scale, region, coordinate_transformation_mode))
if mode == "nearest":
indexes = [_apply_nearest_mode(index, shape, nearest_mode) for (index, shape) in zip(indexes, input_shape)]
X = X[(..., *Tensor.meshgrid(*indexes))]
if mode == "linear":
expand = list(X.shape)
for i in range(-len(sizes), 0):
reshape, index = [1] * X.ndim, indexes[i]
reshape[i] = expand[i] = sizes[i]
low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
X = X.gather(i, low).lerp(X.gather(i, high), perc)
if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented")
return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X
def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated
# ***** Neural Network Ops *****
# TODO: try to factor out common implementations for these ops
# https://medium.com/@zljdanceholic/groupnorm-then-batchnorm-instancenorm-layernorm-e2b2a1d350a0
def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9,
training_mode:int=0, spatial=1, is_test=0):
if training_mode:
x_detached = X.detach()
current_mean = x_detached.mean(axis=(0,2,3))
y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
current_var = (y*y).mean(axis=(0,2,3))
current_invstd = current_var.add(epsilon).rsqrt()
running_mean = input_mean * momentum + current_mean * (1 - momentum)
running_var = input_var * momentum + current_var * (1 - momentum)
return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
invstd = (input_var + epsilon).rsqrt()
return X.batchnorm(scale, B, input_mean, invstd)
def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
axis = tuple(range(2, x.ndim))
mean = x.mean(axis=axis, keepdim=True)
invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
assert stash_type == 1, "only float32 is supported"
axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
mean = x.mean(axis=axes, keepdim=True)
return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt()
def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05):
return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]):
return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9)
def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12):
x = x + skip + bias if bias is not None else x + skip
out = x.layernorm(eps=epsilon) * gamma
return (out + beta if beta is not None else out), None, None, x
def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor,
segment_embedding:Tensor, gamma=None, beta=None, mask:Tensor|None=None,
position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0):
# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
assert (segment_ids is None) is (segment_embedding is None)
assert mask is None and not mask_index_type, "functionality not supported yet" # TODO
input_shape = input_ids.shape
seq_length = input_shape[1]
compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
vocab_size, max_position_embeddings = word_embedding.shape[0], position_embedding.shape[0]
type_vocab_size = (segment_embedding.shape[0] if compute_seg_emb else None)
def embedding(x:Tensor, vocab_size, weight:Tensor) -> Tensor:
return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight
# bert embedding layer
if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None
embedding_sum = wrd_embedding_res + pos_embedding_res
if seg_embedding_res is not None: embedding_sum = embedding_sum + seg_embedding_res
out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
return out, None, embedding_sum
def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1):
# Scalar or Rank 1 tensor containing exactly one element
depth = int(depth[0] if isinstance(depth, list) else depth)
indices = (indices < 0).where(indices+depth, indices)
return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0])
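# Worked example (illustrative): indices=[1, -1], depth=3, values=[0, 10] (off/on value), axis=-1
#   the -1 wraps to 2, so the hot positions are [1, 2] and the output is [[0, 10, 0], [0, 0, 10]]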
def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"):
return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize)
def SpaceToDepth(X:Tensor, blocksize:int):
return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize)
# Reimplemented here because you need legacy RNG for passing ONNX tests.
def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None):
if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's.
mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device)
return data * mask * (1/(1.0 - ratio)), mask
# 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test)
Dropout = {6:Dropout_6, 7:Dropout_7}
def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0):
pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1)
return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta)
def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
return x.nll_loss(target, weight, ignore_index, reduction)
def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
log_probs = scores.log_softmax(1)
return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs
def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0):
N, _, *spatial_dims = size
def generate_grid(steps):
return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device)
grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims))
base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1)
base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1)
return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)
def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past:Tensor|None=None,
relative_position_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int|None=None,
mask_filter_value:float|None=None, num_heads:int|None=None, past_present_share_buffer:int|None=None,
qkv_hidden_sizes:list[int]|None=None, scale:float|None=None, unidirectional:int|None=None):
# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
assert num_heads is not None # required
assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
assert relative_position_bias is do_rotary is past_sequence_length is mask_filter_value is past_present_share_buffer is scale is None, \
"functionality not supported yet" # TODO strange params
hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)
if unidirectional: # gpt-style
assert hidden_size == v_hidden_size
xqkv = x.linear(weights, bias)
xq, xk, xv = [xqkv.shrink([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
else: # bert-style
wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size:]) if bias is not None else (None, None, None)
xq, xk, xv = [x.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]
if past is not None:
xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))
def attn(query, key, value, attn_mask):
query_length, key_length = query.shape[-2], key.shape[-2]
cdim = max(query_length, key_length) + 1
attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
# This is where Tensor.scaled_dot_product_attention differs:
causal_mask = Tensor.ones((cdim, cdim), requires_grad=False, dtype=dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length]
masked = Tensor.where(causal_mask, attn_weights, -math.inf)
if attn_mask is not None: masked = masked + attn_mask
return masked.softmax(-1) @ value
bsz, _, seq_len, _ = xq.shape
out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
return out, present if past is not None else out
# ***** Indexing Ops *****
def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices]
def Gather(x:Tensor, indices:Tensor, axis:int=0):
if indices.numel() < 9: # NOTE: fewer kernels for small indices, but the kernel count grows with the number of indices
x_sh = list(x.shape)
ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
if indices.ndim > 1: indices = indices.flatten()
indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)]
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore
return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
# NOTE: faster gather with a fixed number of kernels, but it exceeds the kernel limit for openpilot
return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]
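# Worked example (illustrative): x = [[1, 2], [3, 4], [5, 6]], indices = [2, 0], axis = 0.
# With only two indices the shrink+cat path runs: shrink to row 2, cat with row 0, reshape to (2, 2),
# giving [[5, 6], [1, 2]] -- the same result the advanced-indexing path above would return.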
def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated
def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0):
if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))]
x_shape, i_shape = x.shape, indices.shape
b = math.prod(x.shape[dim] for dim in range(batch_dims))
# NOTE: every batched dim of input and indices must be equal
x = x.reshape(b, *x.shape[batch_dims:])
indices = indices.reshape(b, *indices.shape[batch_dims:])
b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1])
ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))]
return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:])
def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'):
assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):]
x = x.contiguous()
for index, u in zip(indices.split(1, 0), updates.split(1, 0)):
i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1))
u = u.squeeze(0)
if reduction == "none": x[i] = u
elif reduction == "add": x[i] += u
elif reduction == "mul": x[i] *= u
else: raise NotImplementedError("reduction doesn't support max or min")
return x
def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"):
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction))
def GatherElements(x:Tensor, indices:Tensor, axis:int):
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.gather(axis, indices)
def Compress(inp:Tensor, condition:list[bool], axis:int|None=None):
if axis is None:
inp = inp.flatten()
axis = 0
if axis < 0: axis += inp.ndim
con = Tensor(np.arange(len(condition))[condition]) # no boolean indexing in Tensor
return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]
# ***** Quantization Ops *****
def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype)
def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0):
if axis < 0: axis += x.ndim
if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape)
if block_size == 0:
shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim))
return scale.reshape(shape), zero_point.reshape(shape)
return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis)
def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous()
def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)
return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype)
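# Worked example (illustrative): with y_scale=0.1 and a uint8 y_zero_point=128,
#   QuantizeLinear(1.0)   -> round(1.0/0.1) + 128 = 138 (clamped to [0, 255] and cast to uint8)
#   DequantizeLinear(138) -> (138 - 128) * 0.1 = 1.0
# so values that land exactly on the scale grid round-trip through quantize/dequantize.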
def _op_integer(op, inputs:list[Tensor], zero_points:list[Tensor], **opts):
adjusted_inputs = [inp.int() - zp for inp, zp in zip(inputs, zero_points)]
return op(*adjusted_inputs, **opts)
def _qlinearop_quantized(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
# op execution is done in quantized int
out = _op_integer(op, inputs, zero_points, **opts)
assert dtypes.is_int(out.dtype), "quantized op should've done math in int"
out_quantized = (out * prod(scales) / out_scale).round() + out_zero_point
return _clamp_cast(out_quantized, out_zero_point.dtype)
def _qlinearop_float(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
# op execution is done in float32
dequantized_inputs = [(inp.int() - zp) * scale for inp, zp, scale in zip(inputs, zero_points, scales)]
out = op(*dequantized_inputs, **opts)
assert dtypes.is_float(out.dtype), "op should've done math in float"
out_quantized = (out / out_scale).round() + out_zero_point
return _clamp_cast(out_quantized, out_zero_point.dtype)
def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor,
y_zero_point: Tensor|int, B:Tensor|None=None, **opts):
return _qlinearop_quantized(Conv, [x,w], [x_zero_point,w_zero_point], [x_scale,w_scale], y_scale, y_zero_point, **{"B":B, **opts})
def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor,
y_zero_point:Tensor|int) -> Tensor:
return _qlinearop_quantized(Tensor.matmul, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], y_scale, y_zero_point)
def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor):
return _qlinearop_float(Tensor.add, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], c_scale, c_zero_point)
def QLinearGlobalAveragePool(X:Tensor, x_scale:Tensor, x_zero_point:Tensor, y_scale:Tensor, y_zero_point:Tensor, channels_last:int):
assert channels_last == 0, "unsure what this does"
return _qlinearop_float(GlobalAveragePool, [X], [x_zero_point], [x_scale], y_scale, y_zero_point)
def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None, **opts) -> Tensor:
return _op_integer(Conv, [x,w], [x_zero_point,w_zero_point], **{"B":B, **opts})
def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor:
return _op_integer(Tensor.matmul, [A,B], [a_zero_point,b_zero_point])
# ***** Training Ops *****
# NOTE: the ONNX tests only cover the `T==0` case, so the `T>0` paths are untested
# NOTE: ONNX training ops don't actually need optimizer state since they all work functionally, but we can still reuse the optim.py code
def _onnx_training(input_group_size):
def __decorator(func):
def ___wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs):
R = R.detach()
groups = len(inputs) // input_group_size
ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))]
return tuple(flatten(zip(*ret)))
return ___wrapper
return __decorator
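# Worked example (illustrative): ONNX packs training inputs as (X1, X2, G1, G2, H1, H2) for two tensors.
# With input_group_size=3: groups = 6//3 = 2, inputs[0::2] = (X1, G1, H1), inputs[1::2] = (X2, G2, H2),
# and flatten(zip(*ret)) re-interleaves the per-tensor results back into (X1', X2', H1', H2').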
@_onnx_training(3)
def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0):
X, G, H = (i.detach() for i in inputs)
grad = norm_coefficient * X + G
H.assign(H + grad.square())
up = grad / (H.sqrt() + epsilon)
r = R / (1 + T * decay_factor)
X.assign(X.detach() - r * up)
return [X, H]
@_onnx_training(4)
def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0,
norm_coefficient_post:float=0.0):
from tinygrad.nn.optim import Adam as TinyAdam
X, G, V, H = inputs
G, V, H = G.detach(), V.detach(), H.detach() # TODO we shouldn't need these detaches
X.grad = norm_coefficient * X.detach() + G
opt = TinyAdam([X], b1=alpha, b2=beta, eps=epsilon)
opt.m, opt.v, opt.lr = [V], [H], R
# need no-op for m_hat and v_hat if T == 0
if T == 0: opt.b1_t, opt.b2_t = opt.b1_t.zeros_like(), opt.b2_t.zeros_like()
else:
# `T-1` since it's applied again at the start of `_step`
opt.b1_t = Tensor([alpha**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
opt.b2_t = Tensor([beta**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
opt.step()
X = (1 - norm_coefficient_post) * X
return [X, V, H]
@_onnx_training(3)
def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float):
from tinygrad.nn.optim import SGD
X, G, V = inputs
G, V = G.detach(), V.detach()
X.grad = (norm_coefficient * X.detach() + G) * (beta if T > 0 else 1)
opt = SGD([X], momentum=alpha, nesterov=(mode=="nesterov"))
opt.b, opt.lr = [V], R
opt.step()
return [X, V]
def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **_):
intermediate_tensors[y].backward()
return tuple([t.grad for t in inputs])

View File

@@ -94,10 +94,9 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.ones(4))
def test_literal_one_pow(self):
_check_ast_count(0, 1 ** Tensor([1.0, 2, 3, 4]))
# this fails because of DETACH, it shouldn't
# update: passes after CONST(VIEW(DEVICE)) in tensor
# TODO: pow simplification
def test_tensor_one_pow(self):
_check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
_check_ast_count(1, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
# folds advance indexing into basic indexing
class TestIndexingConstFolding(unittest.TestCase):

View File

@@ -3,7 +3,7 @@ import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.dtype import DType
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
from tinygrad.helpers import Timing, fetch, temp, CI
from tinygrad.helpers import Timing, fetch, temp, CI, OSX
from tinygrad.device import is_dtype_supported
def compare_weights_both(url):
@@ -298,6 +298,7 @@ class TestDiskTensor(unittest.TestCase):
ret = t.bitcast(dtypes.uint16).to("CLANG") + 1
assert ret.tolist() == [2827, 3341, 3855, 4369]
@unittest.skipIf(OSX, "new LLVM has an issue on OSX")
def test_bf16_disk_write_read(self):
t = Tensor([10000, -1, -1000, -10000, 20], dtype=dtypes.float32)
t.to(f"disk:{temp('dt_bf16_disk_write_read_f32')}").realize()

View File

@@ -620,7 +620,7 @@ class Kernel:
if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))
st = store_st = ShapeTracker.from_shape(local_shape)
local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), f"temp{i + 1}")
local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), f"temp{i}")
if swizzle: store_st = get_tc_swizzle_st(store_st.shape, *swizzle)
local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i])
srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store))
@@ -648,7 +648,7 @@ class Kernel:
(1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
st_uop = ShapeTracker.from_shape(local_shape).to_uop()
local_size = st_uop.arg.real_size()
local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)+1}")
local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, ret)))
grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
if op is self.reduceops[-1]: return grouped_reduce

View File

@@ -43,13 +43,28 @@ def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) ->
from tinygrad.ops import PatternMatcher, UPat, GroupOp, graph_rewrite_map, track_rewrites
def get_axis(root:UOp):
if root.op is Ops.MULTI: return root.arg[0]
# NOTE: they all have to share an axis, we always choose [-1]
if root.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in root.src if x.axis is not None])) else None
src_axis = get_axis(root.src[0])
if root.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in root.arg[1] else src_axis
if root.op is Ops.RESHAPE:
if src_axis is None: return None
arg_acc:list[sint] = list(itertools.accumulate(root.arg, operator.mul, initial=1))
# new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
# TODO: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
return len(arg_acc) - arg_acc[::-1].index(prod(root.src[0].shape[:src_axis])) - 1
if root.op is Ops.PERMUTE: return root.arg.index(src_axis) if src_axis is not None else None
raise NotImplementedError("rest should be passthrough")
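# Worked example (illustrative): a (2, 8) tensor sharded on axis 1 and reshaped to (2, 4, 2):
#   arg_acc = [1, 2, 8, 16] and prod(src.shape[:1]) = 2, so new_axis = 1 and the shard follows the 4-sized dim.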
def alu_multi(root:UOp):
msrcs = root.src
assert all(x.op is Ops.MULTI for x in msrcs), f"all buffers must be MultiLazyBuffer {[x.op for x in msrcs]}"
assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
# NOTE: they all have to share an axis, we always choose [-1]
axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None)
axis = get_axis(root)
bounds = dedup([x.bounds for x in root.src if x.axis == axis])[-1] if axis is not None else None
srcs:list[list[UOp]] = []
not_all_real = not all(all(mlb.real) for mlb in msrcs)
new_real = tuple(all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])) if not_all_real else msrcs[0].real
@@ -64,28 +79,24 @@ def alu_multi(root:UOp):
return UOp.multi(*new_lbs, axis=axis, real=new_real)
def reduce_multi(root:UOp, multi:UOp):
op, axis = root.arg
(op, axis), new_axis = root.arg, get_axis(root)
if multi.axis is not None and multi.axis in axis:
# all-reduce on sharded axes
reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(multi.src, multi.real)]
# if all partitions are real, do all_reduce
if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=None)
if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=new_axis)
# only one partition is real, keep it
return UOp.multi(*reduced_parts, axis=None, real=multi.real)
return UOp.multi(*reduced_parts, axis=new_axis, real=multi.real)
# reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=multi.axis, real=multi.real)
return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=new_axis, real=multi.real)
def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))
def reshape_multi(root:UOp, multi:UOp):
arg = root.arg
if multi.axis is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=None, real=multi.real)
arg, new_axis = root.arg, get_axis(root)
if multi.axis is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=new_axis, real=multi.real)
assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
arg_acc:list[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
# new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
# todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
new_axis = len(arg_acc) - arg_acc[::-1].index(prod(multi.shape[:multi.axis])) - 1
assert all(prod(lb.shape[multi.axis:])%prod(arg[new_axis+1:])==0 for lb in multi.src), \
f"reshape cannot move items between shards {multi.shape} -> {root.arg=}"
lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[multi.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in multi.src]
@@ -109,7 +120,7 @@ def pad_multi(root:UOp, multi:UOp):
def permute_multi(root:UOp, multi:UOp):
# all permutes supported!
return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=root.arg.index(multi.axis) if multi.axis is not None else None, real=multi.real)
return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=get_axis(root), real=multi.real)
def shrink_multi(root:UOp, multi:UOp):
assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \

View File

@@ -13,7 +13,257 @@ from tinygrad.device import Buffer
# creation can recurse a lot
sys.setrecursionlimit(10000)
# **** ScheduleItem return type
# **** schedule simplifier
def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
if not all_int(x.shape): return None
# remove reduce on unmasked const
prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
ret = x.const_arg
match reduce.arg[0]:
case Ops.ADD: ret *= prshape
case Ops.MUL: ret **= prshape
case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
case _: return None
return reduce.const_like(ret)
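# Worked example (illustrative): reducing an unmasked const 2.0 of shape (3, 4) over both axes gives
# prshape = 12, so ADD folds to 24.0, MUL folds to 2.0**12 = 4096.0, and MAX passes 2.0 through.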
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
new_src = list(alu.src)
for i,s in enumerate(alu.src):
if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
sym = symbolic_simple+PatternMatcher([
# UOp with size 0 is zero
(UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# reduce of size 0 is the identity element
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
# reduce of const is collapsed (TODO: make this a generic rule for stride0)
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
# COPY(CONST) creates a new CONST on the destination device
(UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
# no COPY to same device, except clone (arg is True)
(UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
# remove cast to image when it's already a contiguous image
(UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)),
lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
# remove contiguous if we can just view the buffer
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
# contiguous/buffer is already contiguous
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER)),)), lambda root: root.src[0]),
# support for using a contiguous permuted view instead of the parent view if one exists
(UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
(UPat(GroupOp.ALU, name="alu"), replace_contiguous),
# remove CONST/BIND/BUFFER from SINK
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
])
remove_movement_ops = merge_views+PatternMatcher([
# NOTE: movement ops are always applied to base
(UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
# some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
(UPat(Ops.VIEW, name="view"),
lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
])
# **** UOp realization
@dataclass(frozen=True)
class ScheduleContext:
tensor_uops: dict[UOp, list[UOp]] # this maps BUFFER uops of this schedule to the tensor uop
assigns: set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule
realizes: dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule
allbufs: dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops to the actual op
ops_metadata: dict[UOp, Metadata] = field(default_factory=dict) # this maps fused ops to Metadata
children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
# wrap tensor uops around a VIEW(BUFFER, <uop>)
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp:
if (r:=cache.get(buf)) is not None: return r
# SINK is passthrough
if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
# skip creating buffers for CONST/BIND/DEVICE/BUFFER
if buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
if buf.base.op is Ops.BUFFER: return buf.view(unwrap(buf.st))
# VIEW is passthrough
if buf is not buf.base:
cache[buf] = ret = add_buffers(buf.base, buffer_map, cache).view(unwrap(buf.st))
return ret
# make things that can't be images not images
dtype = buf.dtype
if isinstance(dtype, ImageDType) and (prod(buf.shape)!=prod(dtype.shape) or not any(buf.shape[x]%4==0 for x in unwrap(buf.st).unit_stride_axes())):
if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
dtype = buf.dtype.base
# ASSIGN already has a target buffer, otherwise we create a new one
assert isinstance(buf.device, str), f"buf device is str, not {buf.device}"
buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
# track the buffer uop for the simplified uop
buffer_map[buf] = buf_uop
# (early) bufferize
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
return ret
class UPatScheduled(UPat):
def __init__(self, *args, **kwargs):
super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
st = unwrap(view.st)
# fold simple pads
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src)
# early realize before expand
if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
# otherwise safety check pads
return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)
def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) -> UOp|None:
if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(x.base).op is Ops.COPY: return None
del ctx.realizes[b]
return x.view(unwrap(view.st))
def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
do_realize = PatternMatcher([
# always realize SINK parents
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
# always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
(UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
# realize before expand or unsafe pad ops
(UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
# don't realize image to image casts
(UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="x"),), dtype=dtypes.float),)),
fold_img_cast),
# realize before COPY or BUFFER_VIEW
(UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
(UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
(UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
])
def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
ctx.allbufs[buf_uop] = view
if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
for x in op.base.src:
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
def uval(u:UOp) -> UOp:
assert is_scheduled(u), f"must be a scheduled op {u}"
return u.src[1]
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp], realizes:dict[UOp, UOp],
reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
"""recursively search the uop for groupable children, realize the UOp if a child can't group"""
if (tr, st) in cache: return
cache.setdefault((tr, st))
rsize = unwrap(allbufs[r].st).size
if tr in realizes and tr is not r:
# can only fuse contiguous
# max one reduceop per kernel
if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
return group.setdefault(tr)
for tr_next in children[tr]:
# max one reduceop per kernel
if (tr_next_uop:=uval(allbufs[tr_next]).base).op is Ops.REDUCE_AXIS: return group.setdefault(r)
# can only fuse contiguous
if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]:
# start by adding uops that always realize
sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
# find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
reduce_for_op: dict[UOp, UOp] = {}
double_reduces: list[UOp] = []
for r, r_uop in ctx.allbufs.items():
if (r_uop:=uval(r_uop)).op is not Ops.REDUCE_AXIS: continue
if FUSE_CONV_BW and is_scheduled((x:=r_uop.src[0]).base) and uval(x.base).op is r_uop.op and x.base is not x: double_reduces.append(r)
if r in ctx.realizes: continue
group: dict[UOp, None] = {}
recursive_group(r, unwrap(r_uop.st), r, ctx.children, ctx.allbufs, ctx.realizes, reduce_for_op, group, cache={})
# max one reduceop per kernel
can_chase = all(tr not in reduce_for_op for tr in group)
# TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
forced_realize = r in group
# can only have one output
if not forced_realize and len(group) > 1: forced_realize = True
# can only fuse assign if no other assign_target is used in the kernel
if not forced_realize and any(x in ctx.assigns for x in group):
parents = deque((r, *group))
while parents and not forced_realize:
if (p_uop:=ctx.allbufs.get(p:=parents.pop())) is None: continue
if (p_uop:=uval(p_uop)).op is Ops.ASSIGN and p not in group: forced_realize, can_chase = True, False
if p in ctx.realizes: continue
parents.extend([x.base.buf_uop for x in p_uop.src if x.base.is_realized or (x.base.op is Ops.VIEW and len(x.base.src) != 0)])
if forced_realize or not group:
tr = r
if can_chase:
# can chase this down to contiguous children
st = unwrap(r_uop.st)
while len(ctx.children[tr]) == 1:
tr_next_uop = uval(ctx.allbufs[(tr_next:=next(iter(ctx.children[tr])))])
st_childs = dedup([unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop is tr])
if len(st_childs) > 1: break
if st.size != st_childs[0].size: break
st = st + st_childs[0]
if not st.contiguous or tr_next_uop.op is Ops.REDUCE_AXIS: break
tr = tr_next
# don't cast to higher size before store (tr cannot be realized if forced_realize)
if (tr_uop:=uval(ctx.allbufs[tr])).op is Ops.CAST and tr_uop.dtype.base.itemsize > tr_uop.src[0].dtype.base.itemsize:
tr = tr_uop.src[0].base.buf_uop
group = {tr: None}
ctx.realizes[tr] = tr
reduce_for_op.update((tr, r) for tr in group)
if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.CONST:
# maybe fuse arange with its children
if len(flatten(ctx.children[tr] for tr in group)) != 0:
for tr in group: del ctx.realizes[tr]
# fuse double reduces with no other child
for reduceop in double_reduces:
top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
graph_rewrite(sink, break_sched, ctx)
return ctx.realizes
# break the SINK into stores
def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
# NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
if b not in ctx.realizes: return x # collapse BUFFER
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
break_sched = PatternMatcher([
# VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
])
# **** ScheduleItem creation
@dataclass(frozen=True)
class ScheduleItem:
@@ -31,49 +281,11 @@ class ScheduleItem:
@functools.cached_property
def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
# **** Schedule context and big graph
@dataclass(frozen=True)
class ScheduleContext:
tensor_uops: dict[UOp, list[UOp]] = field(default_factory=dict) # this maps BUFFER uops of this schedule to the tensor uop
assigns: set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule
realizes: dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule
allbufs: dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops to the actual op
ops_metadata: dict[UOp, Metadata] = field(default_factory=dict) # this maps fused ops to Metadata
children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
# wrap tensor uops around a VIEW(BUFFER, <uop>)
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
if (r:=cache.get(buf)) is not None: return r
# SINK is passthrough
if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
# skip creating buffers for CONST/BIND/DEVICE/BUFFER
if buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
if buf.base.op is Ops.BUFFER: return buf.view(unwrap(buf.st))
# VIEW is passthrough
if buf is not buf.base:
cache[buf] = ret = add_buffers(buf.base, tensor_map, ctx, cache).view(unwrap(buf.st))
return ret
# make things that can't be images not images
dtype = buf.dtype
if isinstance(dtype, ImageDType) and (prod(buf.shape)!=prod(dtype.shape) or not any(buf.shape[x]%4==0 for x in unwrap(buf.st).unit_stride_axes())):
if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
dtype = buf.dtype.base
# ASSIGN already has a target buffer, otherwise we create a new one
assert isinstance(buf.device, str), f"buf device is str, not {buf.device}"
buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
# track the underlying tensor uop for this buffer
ctx.tensor_uops[buf_uop] = tensor_map[buf]
# (early) bufferize
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
return ret
# **** AST graph rewrite
# ** movement ops
class ScheduleItemContext:
var_vals: dict[Variable, int]
sts: set[ShapeTracker] = field(default_factory=set)
bufs: list[UOp] = field(default_factory=list)
def apply_swizzle(u:UOp) -> UOp:
with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
@@ -127,14 +339,6 @@ view_right = merge_views+PatternMatcher([
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
# ** ScheduleItem context builder
@dataclass(frozen=True)
class ScheduleItemContext:
var_vals: dict[Variable, int]
sts: set[ShapeTracker] = field(default_factory=set)
bufs: list[UOp] = field(default_factory=list)
def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
if (st:=unwrap(x.st)) in ctx.sts: return None
st, var_vals = st.simplify().unbind()
@@ -203,222 +407,7 @@ if CAPTURE_PROCESS_REPLAY:
def save_process_replay() -> None:
for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
# **** Schedule grouping
def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
def uval(u:UOp) -> UOp:
assert is_scheduled(u), f"must be a scheduled op {u}"
return u.src[1]
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp], realizes:dict[UOp, UOp],
reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
"""recursively search the uop for groupable children, realize the UOp if a child can't group"""
if (tr, st) in cache: return
cache.setdefault((tr, st))
rsize = unwrap(allbufs[r].st).size
if tr in realizes and tr is not r:
# can only fuse contiguous
# max one reduceop per kernel
if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
return group.setdefault(tr)
for tr_next in children[tr]:
# max one reduceop per kernel
if (tr_next_uop:=uval(allbufs[tr_next]).base).op is Ops.REDUCE_AXIS: return group.setdefault(r)
# can only fuse contiguous
if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
def group_realizes(ctx:ScheduleContext) -> None:
"""search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop"""
# find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
reduce_for_op: dict[UOp, UOp] = {}
double_reduces: list[UOp] = []
for r, r_uop in ctx.allbufs.items():
if (r_uop:=uval(r_uop)).op is not Ops.REDUCE_AXIS: continue
if FUSE_CONV_BW and is_scheduled((x:=r_uop.src[0]).base) and uval(x.base).op is r_uop.op and x.base is not x: double_reduces.append(r)
if r in ctx.realizes: continue
group: dict[UOp, None] = {}
recursive_group(r, unwrap(r_uop.st), r, ctx.children, ctx.allbufs, ctx.realizes, reduce_for_op, group, cache={})
# max one reduceop per kernel
can_chase = all(tr not in reduce_for_op for tr in group)
# TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
forced_realize = r in group
# can only have one output
if not forced_realize and len(group) > 1: forced_realize = True
# can only fuse assign if no other assign_target is used in the kernel
if not forced_realize and any(x in ctx.assigns for x in group):
parents = deque((r, *group))
while parents and not forced_realize:
if (p_uop:=ctx.allbufs.get(p:=parents.pop())) is None: continue
if (p_uop:=uval(p_uop)).op is Ops.ASSIGN and p not in group: forced_realize, can_chase = True, False
if p in ctx.realizes: continue
parents.extend([x.base.buf_uop for x in p_uop.src if x.base.is_realized or (x.base.op is Ops.VIEW and len(x.base.src) != 0)])
if forced_realize or not group:
tr = r
if can_chase:
# can chase this down to contiguous children
st = unwrap(r_uop.st)
while len(ctx.children[tr]) == 1:
tr_next_uop = uval(ctx.allbufs[(tr_next:=next(iter(ctx.children[tr])))])
st_childs = dedup([unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop is tr])
if len(st_childs) > 1: break
if st.size != st_childs[0].size: break
st = st + st_childs[0]
if not st.contiguous or tr_next_uop.op is Ops.REDUCE_AXIS: break
tr = tr_next
# don't cast to higher size before store (tr cannot be realized if forced_realize)
if (tr_uop:=uval(ctx.allbufs[tr])).op is Ops.CAST and tr_uop.dtype.base.itemsize > tr_uop.src[0].dtype.base.itemsize:
tr = tr_uop.src[0].base.buf_uop
group = {tr: None}
ctx.realizes[tr] = tr
reduce_for_op.update((tr, r) for tr in group)
if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.CONST:
# maybe fuse arange with its children
if len(flatten(ctx.children[tr] for tr in group)) != 0:
for tr in group: del ctx.realizes[tr]
# fuse double reduces with no other child
for reduceop in double_reduces:
top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
# **** Schedule creation and BFS toposort
# ** this is schedule level const folding
def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
if not all_int(x.shape): return None
# remove reduce on unmasked const
prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
ret = x.const_arg
match reduce.arg[0]:
case Ops.ADD: ret *= prshape
case Ops.MUL: ret **= prshape
case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
case _: return None
return reduce.const_like(ret)
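# worked example of the fold above (illustrative numbers): reducing an unmasked CONST(2) over axes whose sizes multiply
# to prshape == 4 gives Ops.ADD -> 2*4 == 8, Ops.MUL -> 2**4 == 16, Ops.MAX -> 2 (passthrough), so the whole
# REDUCE_AXIS collapses to reduce.const_like(ret) and no kernel is launched for it.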
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
new_src = list(alu.src)
for i,s in enumerate(alu.src):
if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
sym = symbolic_simple+PatternMatcher([
# UOp with size 0 is zero
(UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# reduce of size 0 is the identity element
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
# reduce of const is collapsed (TODO: make this a generic rule for stride0)
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
# COPY(CONST) creates a new CONST on the destination device
(UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
# no COPY to same device, except clone (arg is True)
(UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
# remove cast to image when it's already a contiguous image
(UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)),
lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
# remove contiguous if we can just view the buffer
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
# contiguous/buffer is already contiguous
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER)),)), lambda root: root.src[0]),
# support for using a contiguous permuted view instead of the parent view if one exists
(UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
(UPat(GroupOp.ALU, name="alu"), replace_contiguous),
# remove CONST/BIND/BUFFER from SINK
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
])
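# hedged examples of these rules firing: any op whose output has size 0 folds straight to const 0, a COPY of a CONST
# materializes a new CONST on the destination device instead of a transfer, and a CONTIGUOUS of a VIEW that already
# covers its whole BUFFER contiguously is dropped in favor of the view itself.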
# ** this decides which ops get realized
class UPatScheduled(UPat):
def __init__(self, *args, **kwargs):
super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
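# e.g. UPatScheduled(Ops.CAST) matches a VIEW whose sources are (BUFFER, CAST), binding the buffer as "b", the CAST as
# "to_store" and the VIEW as "base" for the rewrite functions below (a sketch of how the pattern reads, not new API).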
def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
st = unwrap(view.st)
# fold simple pads
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src)
# early realize before expand
if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
# otherwise safety check pads
return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)
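# a hedged example of the expand rule: broadcasting a (1,1) buffer out to (512,512) has prod(src.shape) == 1 <
# prod(st.shape) == 262144, so the small buffer is realized before the expand unless DONT_REALIZE_EXPAND is set.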
def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) -> UOp|None:
if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(x.base).op is Ops.COPY: return None
del ctx.realizes[b]
return x.view(unwrap(view.st))
def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
do_realize = PatternMatcher([
# always realize SINK parents
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
# always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
(UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
# realize before expand or unsafe pad ops
(UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
# don't realize image to image casts
(UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="x"),), dtype=dtypes.float),)),
fold_img_cast),
# realize before COPY or BUFFER_VIEW
(UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
(UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
(UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
])
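# a hedged sketch of create_subbuffer in action: bitcasting a tensor that lives on a DISK device does not copy;
# the BITCAST becomes a BUFFER_VIEW whose Buffer is a .view() of the original bytes at the view's offset, so no
# kernel runs for it.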
# **** rewrite VIEW into LOAD/STORE or fuse the underlying UOp
def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
# NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
if b not in ctx.realizes: return x # collapse BUFFER
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
break_sched = PatternMatcher([
# VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
])
# **** Schedule context builder
def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
ctx.allbufs[buf_uop] = view
if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
for x in op.base.src:
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
# **** movement ops
remove_movement_ops = merge_views+PatternMatcher([
# NOTE: movement ops are always applied to base
(UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
# some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
(UPat(Ops.VIEW, name="view"),
lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
])
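# hedged examples: x.permute(...) or x.reshape(...) rewrite to x.view(mov.st) on x's base, with merge_views collapsing
# chains of views, and a view whose last mask contains a zero-length range (nothing of the source is visible) rewrites
# to a const 0 of the same shape.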
# **** schedule creation and toposort
@track_rewrites(named=True)
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
@@ -434,15 +423,13 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
elif v.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
# we group the rest of UOps into ScheduleItems
rev_tensor_map: dict[UOp, list[UOp]] = {}
for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
# add BUFFER uops
sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={})
# add realizes
sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
# group realizes into kernels
group_realizes(ctx)
graph_rewrite(sink, break_sched, ctx)
buffer_map: dict[UOp, UOp] = {}
sink = add_buffers(tensor_map[big_sink], buffer_map, cache={})
# get realizes
buf_tensors: dict[UOp, list[UOp]] = {}
for k,v in tensor_map.items():
if (b:=buffer_map.get(v)) is not None: buf_tensors.setdefault(b, []).append(k)
realize_map = group_realizes(sink, ctx:=ScheduleContext(buf_tensors))
# TODO: this should be the break between the "grouper" and the "linearizer"
# here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
@@ -451,11 +438,11 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
# create schedule items + map buffers to realized tensors
prescheduled: list[ScheduleItem] = []
var_vals: dict[Variable, int] = {}
for buf_uop,store in ctx.realizes.items():
for buf_uop,store in realize_map.items():
assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}"
prescheduled.append(schedule_uop(store.sink(), ctx, var_vals))
# can only schedule once
for tensor_uop in ctx.tensor_uops[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
for tensor_uop in buf_tensors[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
# increment refcount for this buffer
buf_uop.buffer.ref(1)

View File

@@ -93,7 +93,7 @@ class MathTrait(SimpleMathTrait):
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
# uops that aren't rendered
SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); PRELOAD = auto() # noqa: E702
SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); PRELOAD = auto(); KERNEL = auto() # noqa: E702
# TODO: empty continues to exist because of tensor
EMPTY = auto()

View File

@@ -1,25 +1,9 @@
import platform, tempfile, pathlib, subprocess, sys
from tinygrad.helpers import cpu_objdump, capstone_flatdump
import platform, subprocess, sys
from tinygrad.helpers import capstone_flatdump
from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram
from tinygrad.runtime.support.elf import jit_loader
from tinygrad.renderer.cstyle import ClangRenderer
# Used by ops_dsp.py
class ClangCompiler(Compiler):
def __init__(self, cachekey="compile_clang", args:list[str]|None=None, objdump_tool='objdump'):
self.args = ['-march=native'] if args is None else args
self.objdump_tool = objdump_tool
super().__init__(cachekey)
def compile(self, src:str) -> bytes:
# TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here
with tempfile.NamedTemporaryFile(delete=True) as output_file:
subprocess.check_output(['clang', '-shared', *self.args, '-O2', '-Wall', '-Werror', '-x', 'c', '-fPIC', '-ffreestanding', '-nostdlib',
'-', '-o', str(output_file.name)], input=src.encode('utf-8'))
return pathlib.Path(output_file.name).read_bytes()
def disassemble(self, lib:bytes): return cpu_objdump(lib, self.objdump_tool)
class ClangJITCompiler(Compiler):
def __init__(self, cachekey="compile_clang_jit"): super().__init__(cachekey)

View File

@@ -1,12 +1,11 @@
from __future__ import annotations
from typing import Tuple, Any, List
import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys
import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess
assert sys.platform != 'win32'
from tinygrad.device import BufferSpec, Compiled, Allocator
from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.ops import Ops, UOp
from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv
from tinygrad.runtime.ops_clang import ClangCompiler
from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv, cpu_objdump
from tinygrad.renderer.cstyle import ClangRenderer
from tinygrad.runtime.autogen import libc, qcom_dsp
if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import
@@ -91,10 +90,23 @@ class DSPAllocator(Allocator):
def _copyout(self, dest:memoryview, src:DSPBuffer): ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes)
def _offset(self, buf, size:int, offset:int): return DSPBuffer(buf.va_addr+offset, size, buf.share_info, buf.offset+offset)
class DSPDevice(Compiled):
def __init__(self, device:str=""):
self.ion_fd = os.open('/dev/ion', os.O_RDONLY)
class ClangCompiler(Compiler):
def __init__(self, cachekey="compile_clang", args:list[str]|None=None, objdump_tool='objdump'):
self.args = ['-march=native'] if args is None else args
self.objdump_tool = objdump_tool
super().__init__(cachekey)
def compile(self, src:str) -> bytes:
# TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here
with tempfile.NamedTemporaryFile(delete=True) as output_file:
subprocess.check_output(['clang', '-shared', *self.args, '-O2', '-Wall', '-Werror', '-x', 'c', '-fPIC', '-ffreestanding', '-nostdlib',
'-', '-o', str(output_file.name)], input=src.encode('utf-8'))
return pathlib.Path(output_file.name).read_bytes()
def disassemble(self, lib:bytes): return cpu_objdump(lib, self.objdump_tool)
class DSPCompiler(ClangCompiler):
def __init__(self):
# Generate a link script to pass into clang. Aligning all used sections to 4K fixes the invoke problem.
sections = ['hash', 'text', 'rela.plt', 'got', 'got.plt', 'dynamic', 'dynsym', 'dynstr', 'plt', 'data', 'bss']
sections_link = '\n'.join([f'.{n} : ALIGN(4096) {{ *(.{n}) }}' for n in sections])
@@ -103,15 +115,19 @@ class DSPDevice(Compiled):
self.link_ld.flush()
compiler_args = ["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib", "-mhvx=v65", "-mhvx-length=128b", f"-T{self.link_ld.name}"]
super().__init__(device, DSPAllocator(self), DSPRenderer(),
ClangCompiler("compile_dsp", args=compiler_args, objdump_tool='llvm-objdump'), functools.partial(DSPProgram, self))
return super().__init__("compile_dsp", args=compiler_args, objdump_tool='llvm-objdump')
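# so a DSP compile ends up invoking roughly (a sketch assembled from compiler_args plus ClangCompiler.compile above):
#   clang -shared --target=hexagon -mcpu=hexagonv65 -fuse-ld=lld -nostdlib -mhvx=v65 -mhvx-length=128b -T<link.ld> \
#     -O2 -Wall -Werror -x c -fPIC -ffreestanding -nostdlib - -o <tmpfile>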
class DSPDevice(Compiled):
def __init__(self, device:str=""):
self.ion_fd = os.open('/dev/ion', os.O_RDONLY)
super().__init__(device, DSPAllocator(self), DSPRenderer(), DSPCompiler(), functools.partial(DSPProgram, self))
fastrpc_shell = memoryview(bytearray(pathlib.Path('/dsp/cdsp/fastrpc_shell_3').read_bytes()))
self.shell_buf = self.allocator.alloc(round_up(fastrpc_shell.nbytes, 0x1000), BufferSpec(nolru=True))
ctypes.memmove(self.shell_buf.va_addr, mv_address(fastrpc_shell), fastrpc_shell.nbytes)
self.init_dsp()
RPCListner(self).start()
RPCListener(self).start()
def open_lib(self, lib):
self.binded_lib, self.binded_lib_off = lib, 0
@@ -149,7 +165,7 @@ class DSPDevice(Compiled):
qcom_dsp.FASTRPC_IOCTL_INIT(self.rpc_fd, flags=0x1, file=self.shell_buf.va_addr, filelen=self.shell_buf.size, filefd=self.shell_buf.share_info.fd)
qcom_dsp.FASTRPC_IOCTL_INVOKE(self.rpc_fd, handle=3, sc=rpc_sc(method=3, ins=0, outs=0))
class RPCListner(threading.Thread):
class RPCListener(threading.Thread):
def __init__(self, device:DSPDevice):
super().__init__()
self.device, self.daemon = device, True

View File

@@ -11,12 +11,15 @@ def expect(x, err, ret=None):
if x: raise RuntimeError(llvm.string_cast(err.contents) if not isinstance(err, str) else err)
return ret
HOST_ARCH = {'arm64': 'AArch64', 'aarch64': 'AArch64', 'x86_64': 'X86', 'AMD64': 'X86'}[platform.machine()]
HOST_TRIPLE = {'AArch64': 'aarch64', 'X86': 'x86_64'}[HOST_ARCH]
REQUIRED_COMPONENTS = ['Target', 'TargetInfo', 'TargetMC', 'AsmPrinter']
class LLVMCompiler(Compiler):
def __init__(self, target_machine, opt):
def __init__(self, host_arch:str, opt:bool):
for component in ['Target', 'TargetInfo', 'TargetMC', 'AsmPrinter']: getattr(llvm, f'LLVMInitialize{host_arch}{component}')()
triple = ({'AArch64': 'aarch64', 'X86': 'x86_64'}[host_arch]+'-none-unknown-elf').encode()
target = expect(llvm.LLVMGetTargetFromTriple(triple, ctypes.pointer(tgt:=llvm.LLVMTargetRef()), err:=cerr()), err, tgt)
target_machine = llvm.LLVMCreateTargetMachine(target, triple, b'', b'+reserve-x18' if platform.machine() == 'arm64' else b'',
llvm.LLVMCodeGenLevelDefault, llvm.LLVMRelocPIC, llvm.LLVMCodeModelDefault)
self.pbo = llvm.LLVMCreatePassBuilderOptions()
if opt:
self.passes = b'default<O2>'
@@ -48,14 +51,5 @@ class LLVMCompiler(Compiler):
class LLVMDevice(Compiled):
def __init__(self, device:str):
for component in REQUIRED_COMPONENTS:
getattr(llvm, f'LLVMInitialize{HOST_ARCH}{component}')()
triple = f'{HOST_TRIPLE}-none-unknown-elf'.encode()
target = expect(llvm.LLVMGetTargetFromTriple(triple, ctypes.pointer(tgt:=llvm.LLVMTargetRef()), err:=cerr()), err, tgt)
features = b'+reserve-x18' if platform.machine() == 'arm64' else b''
target_machine = llvm.LLVMCreateTargetMachine(target, triple, b'', features, llvm.LLVMCodeGenLevelDefault, llvm.LLVMRelocPIC,
llvm.LLVMCodeModelDefault)
super().__init__(device, MallocAllocator, LLVMRenderer('win64cc' if sys.platform == 'win32' else None),
LLVMCompiler(target_machine, getenv("LLVMOPT")), CPUProgram)
compiler = LLVMCompiler({'arm64': 'AArch64', 'aarch64': 'AArch64', 'x86_64': 'X86', 'AMD64': 'X86'}[platform.machine()], bool(getenv("LLVMOPT")))
super().__init__(device, MallocAllocator, LLVMRenderer('win64cc' if sys.platform == 'win32' else None), compiler, CPUProgram)
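# a minimal usage sketch of the refactored constructor (argument values assumed for illustration):
#   LLVMCompiler('AArch64', opt=False)  # initializes the AArch64 target components and builds an aarch64-none-unknown-elf TargetMachine
# which is what the LLVMDevice line above now derives from platform.machine() and getenv("LLVMOPT").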

View File

@@ -3314,10 +3314,10 @@ class Tensor(SimpleMathTrait):
if not base.is_floating_point(): raise RuntimeError("base needs to be float")
# start with b ** e = exp(e * log(b))
ret = base.abs().log().mul(exponent).exp()
# correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)
# correct sign of negative base with odd exponent
negative_base = (base < 0).detach().where(1, 0)
# 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent
correct_sign = 1 + negative_base * ((exponent * math.pi).cos() - 1)
correct_sign = (exponent.int()%2==0).where(1, 1-2*negative_base)
# inject nan for negative base and non-integer exponent
inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
# apply correct_sign, inject_nan, and fix 0 ** 0 = 1
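# worked example of the new sign correction (numbers for illustration): base=-2, exponent=3 gives ret = exp(3*log(2)) = 8,
# negative_base = 1 and exponent.int()%2 == 1, so correct_sign = 1-2*1 = -1 and the result is -8; with exponent=2,
# correct_sign = 1 and the result is 4. a non-integer exponent on a negative base is handled by inject_nan instead.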

View File

@@ -12,7 +12,7 @@ from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0"}