diff --git a/examples/benchmark_onnx.py b/examples/benchmark_onnx.py
index 498e626aa6..333ac9fbba 100644
--- a/examples/benchmark_onnx.py
+++ b/examples/benchmark_onnx.py
@@ -1,6 +1,5 @@
 import sys, onnx, time
 from tinygrad import Tensor, TinyJit, Device, GlobalCounters, fetch
-from tinygrad.tensor import _from_np_dtype
 from extra.onnx import OnnxRunner
 
 def load_onnx_model(fn):
@@ -18,19 +17,25 @@ def load_onnx_model(fn):
   run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())), prune=True)
   return run_onnx_jit, input_shapes, input_types
 
+def get_new_inputs(input_shapes, input_types):
+  #from tinygrad.tensor import _from_np_dtype
+  #return {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
+  import numpy as np
+  return {k:Tensor(np.random.uniform(size=shp).astype(input_types[k]) * 8).realize() for k,shp in sorted(input_shapes.items())}
+
 if __name__ == "__main__":
   run_onnx_jit, input_shapes, input_types = load_onnx_model(sys.argv[1])
   print("loaded model")
 
   for i in range(3):
-    new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
+    new_inputs = get_new_inputs(input_shapes, input_types)
     GlobalCounters.reset()
     print(f"run {i}")
     run_onnx_jit(**new_inputs)
 
   # run 20 times
   for _ in range(20):
-    new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
+    new_inputs = get_new_inputs(input_shapes, input_types)
     GlobalCounters.reset()
     st = time.perf_counter()
     out = run_onnx_jit(**new_inputs)
diff --git a/examples/test_onnx_imagenet.py b/examples/test_onnx_imagenet.py
index a8e5a8c56a..1f27e23e0b 100644
--- a/examples/test_onnx_imagenet.py
+++ b/examples/test_onnx_imagenet.py
@@ -17,6 +17,10 @@ from tinygrad.helpers import fetch, getenv
 # https://huggingface.co/qualcomm/MobileNet-v2-Quantized/resolve/main/MobileNet-v2-Quantized.onnx
 # ~35% - https://github.com/axinc-ai/onnx-quantization/raw/refs/heads/main/models/mobilenev2_quantized.onnx
 
+# QUANT=1 python3 examples/test_onnx_imagenet.py
+# https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx
+# VIZ=1 DONT_REALIZE_EXPAND=1 python3 examples/benchmark_onnx.py /tmp/model.quant.onnx
+
 def imagenet_dataloader(cnt=0):
   input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
   input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
@@ -61,7 +65,7 @@ if __name__ == "__main__":
   assert shape[1:] == (3,224,224), f"shape is {shape}"
 
   hit = 0
-  for i,(img,y) in enumerate(imagenet_dataloader()):
+  for i,(img,y) in enumerate(imagenet_dataloader(cnt=100)):
     p = run_onnx_jit(**{t_name:img})
     assert p.shape == (1,1000)
     t = p.argmax().item()
diff --git a/extra/dsp/compile.py b/extra/dsp/compile.py
index 93ee4cf9c0..cb3c18a880 100755
--- a/extra/dsp/compile.py
+++ b/extra/dsp/compile.py
@@ -37,7 +37,7 @@ if __name__ == "__main__":
   print("mmapped", hex(res))
   to_mv(res, 0x10)[1] = 0xaa
 
-  from tinygrad.runtime.ops_clang import ClangCompiler
+  from tinygrad.runtime.ops_dsp import ClangCompiler
   cc = ClangCompiler(args=["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib"])
 
   obj = cc.compile("""
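The /tmp/model.quant.onnx referenced in the comments above is not produced by this diff; a minimal sketch of one way to create it with onnxruntime's dynamic quantizer (the paths and quantizer choice are assumptions for illustration, not part of this change):

# sketch: quantize a float ONNX model to uint8 weights (paths assumed)
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic("/tmp/model.onnx", "/tmp/model.quant.onnx", weight_type=QuantType.QUInt8)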
diff --git a/extra/dsp/opt.py b/extra/dsp/opt.py
new file mode 100644
index 0000000000..fbe35e7ccb
--- /dev/null
+++ b/extra/dsp/opt.py
@@ -0,0 +1,27 @@
+from tinygrad.runtime.ops_dsp import DSPCompiler
+
+# PATH=/opt/homebrew/opt/llvm/bin:$PATH python3 extra/dsp/opt.py
+
+if __name__ == "__main__":
+  compiler = DSPCompiler()
+
+  lib = compiler.compile("""
+typedef long HVX_Vector __attribute__((__vector_size__(128))) __attribute__ ((aligned(128)));
+typedef long HVX_VectorPair __attribute__((__vector_size__(256))) __attribute__ ((aligned(256)));
+
+void test(unsigned char *c, unsigned char *a, unsigned char *b) {
+  HVX_Vector t0 = *(HVX_Vector*)a;
+  //HVX_VectorPair t1 = *((HVX_VectorPair*)b);
+  HVX_Vector acc = __builtin_HEXAGON_V6_vd0_128B();
+  for (int i = 0; i < 128; i++) {
+    //__builtin_HEXAGON_V6_lvsplatb_128B(t0[i])
+    //acc += __builtin_HEXAGON_V6_lvsplatb_128B(t0[i]) * t1;
+    //acc += t0[i] * t1;
+    unsigned int t1 = ((unsigned int *)b)[i];
+    //acc = __builtin_HEXAGON_V6_vrmpyub_acc_128B(acc, t0, t1);
+    acc = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc, t0, t1);
+  }
+  *((HVX_Vector*)c) = acc;
+}""")
+
+  compiler.disassemble(lib)
diff --git a/extra/onnx.py b/extra/onnx.py
index fbf8e69904..81bb57199a 100644
--- a/extra/onnx.py
+++ b/extra/onnx.py
@@ -1,8 +1,9 @@
-from typing import Callable, Any, Sequence
-import importlib, functools, dataclasses
-from tinygrad.tensor import Tensor
-from tinygrad.helpers import getenv, DEBUG, all_same
-from tinygrad.dtype import DType, ConstType, dtypes
+from typing import Any, Sequence, cast, Literal, Callable
+import dataclasses, functools, io, math, types
+import numpy as np
+from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr
+from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple
+from tinygrad.dtype import DType, ConstType, dtypes, ImageDType
 from tinygrad.device import is_dtype_supported
 
 # ***** protobuf parsing ******
@@ -111,11 +111,11 @@ limit = int(getenv("ONNXLIMIT", "-1"))
 class OnnxRunner:
   def __init__(self, model: ModelProto):
     # parse model protobuf
-    self.is_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in model.graph.node)
+    self.is_training = any(n.domain in {"ai.onnx.training", "ai.onnx.preview.training"} for n in model.graph.node)
     self.old_training, self.old_no_grad = Tensor.training, Tensor.no_grad
     Tensor.training = True if self.is_training else False
     Tensor.no_grad = False if self.is_training else True
-    self.graph_values = {x.name:buffer_parse(x) for x in model.graph.initializer}
+    self.graph_values = {"": None, **{x.name:buffer_parse(x) for x in model.graph.initializer}}
     self.graph_inputs = {x.name:type_parse(x.type) for x in model.graph.input if x.name not in self.graph_values}
     self.graph_outputs = {x.name:type_parse(x.type) for x in model.graph.output}
     self.graph_nodes = tuple(OnnxNode(num, n.op_type, tuple(n.input), tuple(n.output), {x.name:attribute_parse(x) for x in n.attribute})
@@ -123,14 +123,7 @@ class OnnxRunner:
     self.opset_version = model.opset_import[0].version
     self.variable_dims: dict[str, int] = {}
 
-    # TODO: move extra.onnx_ops here so we don't have to deal with annoying circular import
-    # TODO: clean up opset stuff after moving extra.onnx_ops here
-    self.onnx_ops_module = importlib.import_module('extra.onnx_ops')
-    self.onnx_ops = {
-      **{op: getattr(Tensor, op.lower()) for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan",
-      "Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh",
-      "Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")},
-    }
+    self.onnx_ops = onnx_ops
 
   def _parse_input(self, name: str, value: Any, spec: OnnxValue):
     if spec.is_optional and value is None: return None
@@ -148,9 +141,8 @@ class OnnxRunner:
     return tensor
 
   def _dispatch_op(self, op, inps, opts):
-    if op in self.onnx_ops: return self.onnx_ops[op](*inps, **opts)
-    if hasattr(self.onnx_ops_module, op):
-      fxn = getattr(self.onnx_ops_module, op)
+    if op in self.onnx_ops:
+      fxn = self.onnx_ops[op]
       if isinstance(fxn, dict):
         for k in sorted(fxn.keys()):
           if k <= self.opset_version:
@@ -165,7 +157,7 @@
       self.graph_values[name] = self._parse_input(name, inputs[name], input_spec)
 
     for node in self.graph_nodes:
-      inps = [to_python_const(self.graph_values.get(name), node.op, i) for i,name in enumerate(node.inputs)]
+      inps = [to_python_const(self.graph_values[name], node.op, i) for i,name in enumerate(node.inputs)]
       opts = node.opts
 
       # provide additional opts
@@ -184,4 +176,623 @@ class OnnxRunner:
         Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
         return {name:self.graph_values[name] for name in node.outputs}
     Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
-    return {name:self.graph_values[name] for name in self.graph_outputs}
\ No newline at end of file
+    return {name:self.graph_values[name] for name in self.graph_outputs}
+
+####################
+##### ONNX OPS #####
+####################
+def get_onnx_ops():
+  # ***** helper functions *****
+  def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None)
+
+  # (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...)
+  def _onnx_pads_to_tiny_pads(pads): return tuple(flatten(reversed(list(zip(pads, pads[len(pads)//2:])))))
+
+  AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"]
+  # (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right)
+  def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS):
+    if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))]
+    return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))]
+
+  def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS):
+    i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_))
+    if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2)
+    o_ = [((i - (1 if auto_pad in ("SAME_UPPER", "SAME_LOWER") else k)) // s + 1) for i,k,s in zip(i_, k_, s_)]
+    return _onnx_pads_to_tiny_pads(_auto_pad([(o-1)*s+k-i for o,i,k,s in zip(o_, i_, k_, s_)], auto_pad))
+
+  def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype)
+
+  def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0):
+    if axis < 0: axis += x.ndim
+    if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape)
+    if block_size == 0:
+      shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim))
+      return scale.reshape(shape), zero_point.reshape(shape)
+    return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis)
+
+  def _op_integer(op, inputs:list[Tensor], zero_points:list[Tensor], **opts):
+    adjusted_inputs = [inp.int() - zp for inp, zp in zip(inputs, zero_points)]
+    return op(*adjusted_inputs, **opts)
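+
+  # a worked example of the quantized-op helpers below (numbers assumed for illustration): with scale=0.1 and
+  # zero_point=128, the byte 130 dequantizes to (130-128)*0.1 = 0.2; after the wrapped op runs, the result is
+  # divided by out_scale, rounded, shifted by out_zero_point, then clamped and cast to the output dtype.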
+
+  def _qlinearop_quantized(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
+    # op execution is done in quantized int
+    out = _op_integer(op, inputs, zero_points, **opts)
+    assert dtypes.is_int(out.dtype), "quantized op should've done math in int"
+    out_quantized = (out * prod(scales) / out_scale).round() + out_zero_point
+    return _clamp_cast(out_quantized, out_zero_point.dtype)
+
+  def _qlinearop_float(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
+    # op execution is done in float32
+    dequantized_inputs = [(inp.int() - zp) * scale for inp, zp, scale in zip(inputs, zero_points, scales)]
+    out = op(*dequantized_inputs, **opts)
+    assert dtypes.is_float(out.dtype), "op should've done math in float"
+    out_quantized = (out / out_scale).round() + out_zero_point
+    return _clamp_cast(out_quantized, out_zero_point.dtype)
+
+  def _onnx_training(input_group_size):
+    def __decorator(func):
+      def ___wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs):
+        R = R.detach()
+        groups = len(inputs) // input_group_size
+        ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))]
+        return tuple(flatten(zip(*ret)))
+      return ___wrapper
+    return __decorator
+
+  # ***** Property/Graph Ops *****
+  def Identity(x:Tensor): return x
+  def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, value_floats:list[float]|None=None,
+               value_int:int|None=None, value_ints:list[int]|None=None, value_string:str|None=None, value_strings:list[str]|None=None):
+    if value is not None: return value
+    if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
+    if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
+    if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
+    if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
+    if value_string is not None or value_strings is not None or sparse_value is not None:
+      raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value')
+
+  def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta)
+
+  def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"):
+    try: import PIL.Image
+    except ImportError as e: raise ImportError("Pillow must be installed for the ImageDecoder operator") from e
+    img = PIL.Image.open(io.BytesIO(encoded_stream))
+    if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1]
+    if pixel_format == "RGB": return Tensor(np.array(img))
+    if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1)
+    raise ValueError(f"pixel_format={pixel_format!r} is not supported.")
+
+  def EyeLike(x:Tensor, dtype:int|None=None, k:int=0):
+    ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype)
+    return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape))
+
+  def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0)
+  def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([])
+  def ConstantOfShape(shape:list[int], value:Tensor|None=None):
+    if value is None: value = Tensor(0, dtype=dtypes.float32)
+    return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1)
+
+  def Size(data:Tensor): return data.numel()
+  def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64)
+
+  # ***** Unary Ops (math) *****
+  def Not(x:Tensor): return x.logical_not()
+  def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None):
+    return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype)
+
+  # ***** Unary Ops (activation) *****
+  def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis)
+  def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis)
+  Softmax = {1:Softmax_1, 13:Softmax_13}
+  def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1)
+  def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
+  def FastGelu(x:Tensor, bias:Tensor|None=None):
+    # this is tanh approximated
+    return (x + bias).gelu() if bias is not None else x.gelu()
+  # TODO: fix this
+  def PRelu(X:Tensor, slope:Tensor):
+    slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope
+    return (X > 0).where(X, X * slope)
+  def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha)
+  def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0)
+  def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis)
+  def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float()
+
+  # ***** Unary Ops (broadcasted) *****
+  def Add(x:Tensor,y:Tensor, broadcast=None, axis=None): return x + y if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + y).cast(x.dtype)
+  def Sub(x:Tensor|int,y:Tensor): return x - y # some test has input as int
+  def Div(x:Tensor,y:Tensor): return (x/y).cast(x.dtype)
+  def Less(x:Tensor,y:Tensor): return x < y
+  def LessOrEqual(x:Tensor,y:Tensor): return x <= y
+  def Greater(x:Tensor,y:Tensor): return x > y
+  def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y
+  def Equal(x:Tensor,y:Tensor): return x == y
+  def And(x:Tensor,y:Tensor): return (x==y).where(x, False)
+  def Or(x:Tensor,y:Tensor): return (x==y).where(x, True)
+  def BitwiseAnd(x:Tensor,y:Tensor): return x & y
+  def BitwiseOr(x:Tensor,y:Tensor): return x | y
+  def BitwiseXor(x:Tensor,y:Tensor): return x ^ y
+  def BitwiseNot(x:Tensor): return ~x
+
+  # ***** Casting Ops *****
+  # TODO: saturate
+  def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to))
+  def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype)
+
+  # ***** Reduce Ops *****
+  def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0)
+  def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0)
+  def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0)
+  def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0)
+  def ReduceMax(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
+    return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
+  def ReduceMin(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
+    return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
+  def ReduceSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
+    return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
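+
+  # note the _axes asymmetry (example inputs assumed): axes=None reduces over every dim, while axes=[] with
+  # noop_with_empty_axes=1 stays [] and reduces nothing, so e.g. ReduceSum(t, axes=[], noop_with_empty_axes=1) is t.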
+  def ReduceMean(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
+    return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
+  def ReduceSumSquare(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
+    return ReduceSum(data.square(), axes, keepdims, noop_with_empty_axes)
+  def ReduceProd(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
+    return data.prod(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
+  def ReduceL1(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
+    return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes)
+  def ReduceL2(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
+    return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt()
+  def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
+    return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
+  def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
+    return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log()
+  def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0):
+    if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
+    return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
+  def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0):
+    return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
+
+  # ***** Movement Ops *****
+  def Reshape(data:Tensor, shape:list[int], allowzero:int=0):
+    return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)])
+  def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1)
+  def Expand(x:Tensor, shape:list[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape)))
+  def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)
+  def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm)
+
+  # TODO: add test for when axes is None
+  def Squeeze(data:Tensor, axes:list[int]|None=None):
+    return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data)
+  def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data)
+
+  def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats)
+  def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis)
+  def Slice(data:Tensor, starts:list[int], ends:list[int], axes:list[int]|None=None, steps:list[int]|None=None):
+    axes = axes or list(range(data.ndim))
+    steps = steps or [1]*data.ndim
+    slices = [slice(0,x,1) for x in data.shape]
+    for i, axis in enumerate(axes): slices[axis] = slice(starts[i], ends[i], steps[i])
+    return data[tuple(slices)]
+
+  def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0):
+    sz = data.shape[axis]
+    if split is None: split = [sz // num_outputs + (1 if i < sz % num_outputs else 0) for i in range(num_outputs)]
+    return data.split(split, axis)
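+  # e.g. (sizes assumed): sz=5 with num_outputs=2 gives split=[3,2], the remainder landing in the earlier chunks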
+
+  def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None,
+          mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0):
+    value = constant_value or value
+    axes = axes or list(range(x.ndim))
+    real_pads = [0] * (x.ndim*2)
+    for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)]
+    return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value)
+
+  def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None):
+    shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim
+    pad_arg:list[None|tuple[int,int]] = [None] * t.ndim
+    for s, x in zip(shape, axes or range(t.ndim)):
+      tx = t.shape[x]
+      if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2)
+      elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2)
+    return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))
+
+  # ***** Processing Ops *****
+  def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0,
+                  dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1):
+    return X.avg_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad),
+                        ceil_mode=ceil_mode, count_include_pad=count_include_pad)
+
+  def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0,
+              storage_order:int=0, strides:list[int]|int=1):
+    ret = X.max_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), ceil_mode=ceil_mode)
+    # tests expect indices with int64 dtype
+    # TODO: if there are repeated values, this is wrong
+    indices = ((ret.reshape(-1, 1) == X.reshape(1, -1)) * Tensor.arange(X.numel(), dtype=dtypes.int64).unsqueeze(0)).sum(1).reshape(ret.shape)
+    return ret.cast(X.dtype), indices.transpose(-2, -1) if storage_order else indices
+
+  def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
+           kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
+    return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations,
+                    padding=_resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad))
+
+  def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
+                    kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0,
+                    strides:list[int]|int=1):
+    input_shape, kernel_shape = X.shape[2:], (kernel_shape or W.shape[2:])
+    strides, dilations, output_padding = (make_tuple(x, len(input_shape)) for x in (strides, dilations, output_padding))
+    if output_shape is not None: # we pad according to output_shape
+      pads = _auto_pad([s*(i-1) + op + ((k-1)*d+1) - os for s,i,op,k,d,os in
+                        zip(strides, input_shape, output_padding, kernel_shape, dilations, output_shape)], auto_pad)
+    if pads is None: # we generate pads
+      output_shape = output_shape or [X.shape[i+2] * strides[i] for i in range(len(strides))]
+      pads = [strides[i]*(input_shape[i]-1) + output_padding[i] + ((kernel_shape[i]-1)*dilations[i]+1)-output_shape[i] for i in range(len(input_shape))]
+      pads = _auto_pad(pads, auto_pad) if auto_pad != "NOTSET" else [0] * len(input_shape) * 2
+    pads = _onnx_pads_to_tiny_pads(pads)
+    return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads, output_padding=output_padding)
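+
+  # auto_pad by example (numbers assumed): i=5, k=3, s=1 under SAME_UPPER needs 2 total pad, split (1,1);
+  # an odd total like 3 splits (1,2) with the extra at the end, which is what "UPPER" means in ONNX.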
+
+  def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
+    pads, strides = (make_tuple(x, len(xI.shape)) for x in (pads, strides))
+    out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)]
+    ret = (xI.reshape(-1, 1)._one_hot_along_dim(prod(out_sh)) * xT.reshape(-1, 1)).sum(0).reshape(1, 1, *out_sh)
+    if outshape is not None and outshape != ret.shape: pads = _auto_pad([outshape[-2] - ret.shape[-2], outshape[-1] - ret.shape[-1]], "SAME_UPPER")
+    return ret.pad(_onnx_pads_to_tiny_pads(pads))
+
+  def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
+  def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)
+
+  def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0):
+    ret = alpha * (A.transpose(transA) @ B.transpose(transB))
+    if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
+    return ret
+
+  def Einsum(*Inputs:Tensor, equation:str): return Tensor.einsum(equation, *Inputs)
+
+  def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0):
+    axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis)
+    if reverse: X = X.flip(axis)
+    if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\
+                       .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim)))
+    return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis)
+
+  def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k)
+
+  def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, sizes:list[int]|None=None, antialias:int=0,
+             axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0,
+             extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'):
+    def _apply_nearest_mode(index: Tensor, input_dim, mode: str):
+      if mode == "round_prefer_floor": index = (index - 0.5).ceil()
+      elif mode == "round_prefer_ceil": index = (index + 0.5).floor()
+      elif mode in ["floor", "ceil"]: index = getattr(index, mode)()
+      else: raise ValueError(f"invalid {nearest_mode=}")
+      return index.cast(dtypes.int32).clip(0, input_dim-1)
+    def _apply_transformation(index: Tensor, input_dim, scale_dim, roi_dim, mode):
+      # TODO: needs more testing, not confident in this
+      # NOTE: their reference implementation differs from the implementation in their reference docs
+      # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_resize.py
+      # https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize
+      output_dim = scale_dim * input_dim
+      if mode == "half_pixel": index = (index + 0.5) / scale_dim - 0.5
+      elif mode == "align_corners": index = index * (input_dim - 1) / (output_dim - 1) if output_dim != 1 else Tensor([0])
+      elif mode == "asymmetric": index = index / scale_dim
+      elif mode == "pytorch_half_pixel": index = (index + 0.5) / scale_dim - 0.5 if output_dim != 1 else Tensor([-0.5])
+      elif mode == "half_pixel_symmetric": index = input_dim / 2 * (1 - int(output_dim) / output_dim) + (index + 0.5) / scale_dim - 0.5
+      elif mode == "tf_crop_and_resize": index = roi_dim[0] * (input_dim - 1) + index * ((roi_dim[1] - roi_dim[0]) * (input_dim - 1) / (output_dim - 1))
+      else: raise ValueError(f"invalid {coordinate_transformation_mode=}")
+      return index.clip(0, input_dim-1)
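+
+    # coordinate mapping by example (numbers assumed): a 2x "half_pixel" upscale sends output index 0 to
+    # (0+0.5)/2 - 0.5 = -0.25, and the final clip pulls it back into [0, input_dim-1].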
+
+    scales, sizes = (None if scales is None else scales[2-(X.ndim-len(scales)):]), (None if sizes is None else sizes[2-(X.ndim-len(sizes)):])
+    # we pre permute the axes and permute back after resize
+    axes, input_shape, = (axes or list(range(X.ndim))), cast(tuple[int, ...], X.shape[2:]),
+    perm = [a for a in range(len(X.shape)) if a not in axes] + list(axes)
+    X = X.permute(*perm)
+
+    if sizes is not None:
+      if keep_aspect_ratio_policy in ["not_larger", "not_smaller"]:
+        scale_fxn = min if keep_aspect_ratio_policy == "not_larger" else max
+        scales = [scale_fxn([sizes[i] / input_shape[i] for i in range(len(input_shape)) if i+2 in axes])] * 2
+        sizes = [int((scales[0] * input_shape[i]) + 0.5) if i+2 in axes else input_shape[i] for i in range(X.ndim-2)]
+      else:
+        scales = [size / input_shape for size, input_shape in zip(sizes, input_shape)]
+    else:
+      sizes = [int(sc*sh) for sc, sh in zip(scales, input_shape)]
+    regions = [[st, ed] for st, ed in zip(roi, roi[len(roi)//2:])] if isinstance(roi, list) and roi else [[0.0, 0.0]] * (X.ndim-2)
+
+    # NOTE: this transformation makes it so that we can't just call Tensor.interpolate
+    # in Tensor.interpolate, we use indexes without any transformation
+    indexes = []
+    for shape, size, scale, region in zip(input_shape, sizes, scales, regions):
+      indexes.append(_apply_transformation(Tensor.arange(size), shape, scale, region, coordinate_transformation_mode))
+
+    if mode == "nearest":
+      indexes = [_apply_nearest_mode(index, shape, nearest_mode) for (index, shape) in zip(indexes, input_shape)]
+      X = X[(..., *Tensor.meshgrid(*indexes))]
+    if mode == "linear":
+      expand = list(X.shape)
+      for i in range(-len(sizes), 0):
+        reshape, index = [1] * X.ndim, indexes[i]
+        reshape[i] = expand[i] = sizes[i]
+        low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
+        X = X.gather(i, low).lerp(X.gather(i, high), perc)
+    if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented")
+    return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X
+  def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated
+
+  # ***** Neural Network Ops *****
+  # TODO: try to factor out common implementations for these ops
+  # https://medium.com/@zljdanceholic/groupnorm-then-batchnorm-instancenorm-layernorm-e2b2a1d350a0
+  def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9,
+                         training_mode:int=0, spatial=1, is_test=0):
+    if training_mode:
+      x_detached = X.detach()
+      current_mean = x_detached.mean(axis=(0,2,3))
+      y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
+      current_var = (y*y).mean(axis=(0,2,3))
+      current_invstd = current_var.add(epsilon).rsqrt()
+
+      running_mean = input_mean * momentum + current_mean * (1 - momentum)
+      running_var = input_var * momentum + current_var * (1 - momentum)
+
+      return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
+    invstd = (input_var + epsilon).rsqrt()
+    return X.batchnorm(scale, B, input_mean, invstd)
+  def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
+    axis = tuple(range(2, x.ndim))
+    mean = x.mean(axis=axis, keepdim=True)
+    invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
+    return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
+  def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
+    assert stash_type == 1, "only float32 is supported"
supported" + axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim)) + mean = x.mean(axis=axes, keepdim=True) + return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt() + def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05): + return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape) + def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]): + return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9) + def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12): + x = x + skip + bias + return x.layernorm(eps=epsilon) * gamma + beta, None, None, x + def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor, + segment_embedding:Tensor, gamma=None, beta=None, mask:Tensor|None=None, + position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0): + # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization + assert (segment_ids is None) is (segment_embedding is None) + assert mask is None and not mask_index_type, "functionality not supported yet" # TODO + input_shape = input_ids.shape + seq_length = input_shape[1] + compute_seg_emb = (segment_embedding is not None and segment_ids is not None) + vocab_size, max_position_embeddings = word_embedding.shape[0], position_embedding.shape[0] + type_vocab_size = (segment_embedding.shape[0] if compute_seg_emb else None) + + def embedding(x:Tensor, vocab_size, weight:Tensor) -> Tensor: + return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight + + # bert embedding layer + if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape) + wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding) + pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding) + seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None + + embedding_sum = wrd_embedding_res + pos_embedding_res + if seg_embedding_res is not None: embedding_sum = embedding_sum + seg_embedding_res + out = embedding_sum.layernorm(eps=epsilon) * gamma + beta + return out, None, embedding_sum + + def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1): + # Scalar or Rank 1 tensor containing exactly one element + depth = int(depth[0] if isinstance(depth, list) else depth) + indices = (indices < 0).where(indices+depth, indices) + return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0]) + + def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"): + return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize) + def SpaceToDepth(X:Tensor, blocksize:int): + return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize) + + # Reimplemented here because you need legacy RNG for passing ONNX tests. 
+
+  # Reimplemented here because you need legacy RNG for passing ONNX tests.
+  def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None):
+    if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's.
+    mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device)
+    return data * mask * (1/(1.0 - ratio)), mask
+  # 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
+  def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test)
+  Dropout = {6:Dropout_6, 7:Dropout_7}
+
+  def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0):
+    pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1)
+    return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta)
+
+  def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
+    return x.nll_loss(target, weight, ignore_index, reduction)
+  def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
+    log_probs = scores.log_softmax(1)
+    return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs
+
+  def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0):
+    N, _, *spatial_dims = size
+    def generate_grid(steps):
+      return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device)
+    grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims))
+    base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1)
+    base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1)
+    return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)
+
+  def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past:Tensor|None=None,
+                relative_position_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int|None=None,
+                mask_filter_value:float|None=None, num_heads:int|None=None, past_present_share_buffer:int|None=None,
+                qkv_hidden_sizes:list[int]|None=None, scale:float|None=None, unidirectional:int|None=None):
+    # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
+    assert num_heads is not None # required
+    assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
+    assert relative_position_bias is do_rotary is past_sequence_length is mask_filter_value is past_present_share_buffer is scale is None, \
+      "functionality not supported yet" # TODO strange params
+    hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)
+
+    if unidirectional: # gpt-style
+      assert hidden_size == v_hidden_size
+      xqkv = x.linear(weights, bias)
+      xq, xk, xv = [xqkv.shrink([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
+    else: # bert-style
+      wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
+      bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size:]) if bias is not None else (None, None, None)
+      xq, xk, xv = [x.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
+    xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]
+
+    if past is not None:
+      xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
+      present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))
+
+    def attn(query, key, value, attn_mask):
+      query_length, key_length = query.shape[-2], key.shape[-2]
+      cdim = max(query_length, key_length) + 1
+      attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
+      # This is where Tensor.scaled_dot_product_attention differs:
+      causal_mask = Tensor.ones((cdim, cdim), requires_grad=False, dtype=dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length]
+      masked = Tensor.where(causal_mask, attn_weights, -math.inf)
+      if attn_mask is not None: masked = masked + attn_mask
+      return masked.softmax(-1) @ value
+
+    bsz, _, seq_len, _ = xq.shape
+    out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
+    return (out, present) if past is not None else out
+
+  # ***** Indexing Ops *****
+  def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices]
+
+  def Gather(x:Tensor, indices:Tensor, axis:int=0):
+    if indices.numel() < 9: # NOTE: fewer kernels for smaller indices, but the kernel count grows with the size of indices
+      x_sh = list(x.shape)
+      ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
+      if indices.ndim > 1: indices = indices.flatten()
+      indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)]
+      args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore
+      return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
+    # NOTE: the fancy-indexing gather below uses a fixed number of kernels and is faster, but exceeds the kernel limit for openpilot
+    return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]
+  def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated
+
+  def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0):
+    if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))]
+    x_shape, i_shape = x.shape, indices.shape
+    b = math.prod(x.shape[dim] for dim in range(batch_dims))
+    # NOTE: each batched dim of both input and indices are equal
+    x = x.reshape(b, *x.shape[batch_dims:])
+    indices = indices.reshape(b, *indices.shape[batch_dims:])
+    b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1])
+    ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))]
+    return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:])
+  def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'):
+    assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):]
+    x = x.contiguous()
+    for index, u in zip(indices.split(1, 0), updates.split(1, 0)):
+      i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1))
+      u = u.squeeze(0)
+      if reduction == "none": x[i] = u
+      elif reduction == "add": x[i] += u
+      elif reduction == "mul": x[i] *= u
+      else: raise NotImplementedError("reduction doesn't support max or min")
+    return x
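+
+  # index normalization by example (values assumed): with x.shape[axis]=4, an index of -1 in the two
+  # element-wise ops below becomes -1+4=3, wrapping exactly like Python sequence indexing.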
{"none":None, "mul": "multiply"}.get(reduction, reduction)) + def GatherElements(x:Tensor, indices:Tensor, axis:int): + indices = (indices < 0).where(x.shape[axis], 0) + indices + return x.gather(axis, indices) + + def Compress(inp:Tensor, condition:list[bool], axis:int|None=None): + if axis is None: + inp = inp.flatten() + axis = 0 + if axis < 0: axis += inp.ndim + con = Tensor(np.arange(len(condition))[condition]) # no boolean indexing in Tensor + return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))] + + # ***** Quantization Ops ***** + def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1): + out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8 + y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size) + return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous() + + def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0): + x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size) + return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype) + + def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor, + y_zero_point: Tensor|int, B:Tensor|None=None, **opts): + return _qlinearop_quantized(Conv, [x,w], [x_zero_point,w_zero_point], [x_scale,w_scale], y_scale, y_zero_point, **{"B":B, **opts}) + + def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor, + y_zero_point:Tensor|int) -> Tensor: + return _qlinearop_quantized(Tensor.matmul, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], y_scale, y_zero_point) + + def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor): + return _qlinearop_float(Tensor.add, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], c_scale, c_zero_point) + + def QLinearGlobalAveragePool(X:Tensor, x_scale:Tensor, x_zero_point:Tensor, y_scale:Tensor, y_zero_point:Tensor, channels_last:int): + assert channels_last == 0, "unsure what this does" + return _qlinearop_float(GlobalAveragePool, [X], [x_zero_point], [x_scale], y_scale, y_zero_point) + + def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None, **opts) -> Tensor: + return _op_integer(Conv, [x,w], [x_zero_point,w_zero_point], **{"B":B, **opts}) + + def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor: + return _op_integer(Tensor.matmul, [A,B], [a_zero_point,b_zero_point]) + + # ***** Training Ops ***** + # NOTE: onnx test coverage only covers `T==0` cases, so for all `T>0` this isn't tested + # NOTE: onnx training ops actually don't need the state for optim, all the ops work in a functional way, but we still can reuse optim.py code + @_onnx_training(3) + def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0): + X, G, H = (i.detach() for i in inputs) + grad = norm_coefficient * X + G + H.assign(H + grad.square()) + up = grad / (H.sqrt() + epsilon) + r = R / (1 + T * decay_factor) + X.assign(X.detach() - r * up) + return [X, H] + + @_onnx_training(4) + def Adam(R:Tensor, 
+
+  @_onnx_training(4)
+  def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0,
+           norm_coefficient_post:float=0.0):
+    from tinygrad.nn.optim import Adam as TinyAdam
+    X, G, V, H = inputs
+    G, V, H = G.detach(), V.detach(), H.detach() # TODO we shouldn't need these detaches
+    X.grad = norm_coefficient * X.detach() + G
+    opt = TinyAdam([X], b1=alpha, b2=beta, eps=epsilon)
+    opt.m, opt.v, opt.lr = [V], [H], R
+    # need no-op for m_hat and v_hat if T == 0
+    if T == 0: opt.b1_t, opt.b2_t = opt.b1_t.zeros_like(), opt.b2_t.zeros_like()
+    else:
+      # `T-1` since it's applied again at the start of `_step`
+      opt.b1_t = Tensor([alpha**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
+      opt.b2_t = Tensor([beta**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
+    opt.step()
+    X = (1 - norm_coefficient_post) * X
+    return [X, V, H]
+
+  @_onnx_training(3)
+  def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float):
+    from tinygrad.nn.optim import SGD
+    X, G, V = inputs
+    G, V = G.detach(), V.detach()
+    X.grad = (norm_coefficient * X.detach() + G) * (beta if T > 0 else 1)
+    opt = SGD([X], momentum=alpha, nesterov=(mode=="nesterov"))
+    opt.b, opt.lr = [V], R
+    opt.step()
+    return [X, V]
+
+  def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **_):
+    intermediate_tensors[y].backward()
+    return tuple([t.grad for t in inputs])
+
+  return {
+    # Tensor ops
+    **{op: getattr(Tensor, op.lower()) for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan",
+    "Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh",
+    "Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")},
+    # Implemented ops
+    **{name:obj for name,obj in locals().items() if isinstance(obj, types.FunctionType) and not name.startswith("_") and name[0].isupper()},
+    # Version ops
+    **{name:obj for name,obj in locals().items() if isinstance(obj, dict)},
+  }
+
+onnx_ops = get_onnx_ops()
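The "Version ops" entries above are dicts keyed by minimum opset; a minimal standalone sketch of how _dispatch_op picks among them (the table contents here are invented for illustration):

# the highest version key not exceeding the model's opset wins
table = {"Softmax": {1: "softmax_v1", 13: "softmax_v13"}, "Relu": "relu"}
def pick(op: str, opset_version: int):
  fxn = table[op]
  if isinstance(fxn, dict):
    for k in sorted(fxn.keys()):
      if k <= opset_version: chosen = fxn[k]
    return chosen
  return fxn

assert pick("Softmax", 12) == "softmax_v1"
assert pick("Softmax", 13) == "softmax_v13"
assert pick("Relu", 13) == "relu"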
diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py
deleted file mode 100644
index 165de3154b..0000000000
--- a/extra/onnx_ops.py
+++ /dev/null
@@ -1,606 +0,0 @@
-import functools, io, math
-from typing import cast, Literal
-from tinygrad.tensor import Tensor, _broadcast_shape, ConstType, ReductionStr
-from tinygrad.dtype import ImageDType, dtypes, DType
-from tinygrad.helpers import prod, flatten, make_tuple
-from extra.onnx import dtype_parse, _cached_to_python_const
-import numpy as np
-
-# ***** Property/Graph Ops *****
-def Identity(x:Tensor): return x
-def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, value_floats:list[float]|None=None,
-             value_int:int|None=None, value_ints:list[int]|None=None, value_string:str|None=None, value_strings:list[str]|None=None):
-  if value is not None: return value
-  if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
-  if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
-  if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
-  if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
-  if value_string is not None or value_strings is not None and sparse_value is not None:
-    raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value')
-
-def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta)
-
-def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"):
-  try: import PIL.Image
-  except ImportError as e: raise ImportError("Pillow must be installed for the ImageDecoder operator") from e
-  img = PIL.Image.open(io.BytesIO(encoded_stream))
-  if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1]
-  if pixel_format == "RGB": return Tensor(np.array(img))
-  if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1)
-  raise ValueError(f"pixel_format={pixel_format!r} is not supported.")
-
-def EyeLike(x:Tensor, dtype:int|None=None, k:int=0):
-  ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype)
-  return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape))
-
-def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0)
-def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([])
-def ConstantOfShape(shape:list[int], value:Tensor|None=None):
-  if value is None: value = Tensor(0, dtype=dtypes.float32)
-  return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1)
-
-def Size(data:Tensor): return data.numel()
-def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64)
-
-# ***** Unary Ops (math) *****
-def Not(x:Tensor): return x.logical_not()
-def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None):
-  return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype)
-
-# ***** Unary Ops (activation) *****
-def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis)
-def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis)
-Softmax = {1:Softmax_1, 13:Softmax_13}
-def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1)
-def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
-def FastGelu(x:Tensor, bias:Tensor|None=None):
-  # this is tanh approximated
-  return (x + bias).gelu() if bias is not None else x.gelu()
-# TODO: fix this
-def PRelu(X:Tensor, slope:Tensor):
-  slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope
-  return (X > 0).where(X, X * slope)
-def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha)
-def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0)
-def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis)
-def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float()
-
-# ***** Unary Ops (broadcasted) *****
-def Add(x:Tensor,y:Tensor, broadcast=None, axis=None): return x + y if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + y).cast(x.dtype)
-def Sub(x:Tensor|int,y:Tensor): return x - y # some test has input as int
-def Div(x:Tensor,y:Tensor): return (x/y).cast(x.dtype)
-def Less(x:Tensor,y:Tensor): return x < y
-def LessOrEqual(x:Tensor,y:Tensor): return x <= y
-def Greater(x:Tensor,y:Tensor): return x > y
-def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y
-def Equal(x:Tensor,y:Tensor): return x == y
-def And(x:Tensor,y:Tensor): return (x==y).where(x, False)
-def Or(x:Tensor,y:Tensor): return (x==y).where(x, True)
-def BitwiseAnd(x:Tensor,y:Tensor): return x & y
-def BitwiseOr(x:Tensor,y:Tensor): return x | y
-def BitwiseXor(x:Tensor,y:Tensor): return x ^ y
-def BitwiseNot(x:Tensor): return ~x
-
-# ***** Casting Ops *****
-# TODO: saturate
-def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to))
-def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype)
-
-# ***** Reduce Ops *****
-def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0)
-def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0)
-def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0)
-def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0)
-def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None)
-def ReduceMax(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
-  return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
-def ReduceMin(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
-  return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
-def ReduceSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
-  return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
-def ReduceMean(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
-  return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
-def ReduceSumSquare(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
-  return ReduceSum(data.square(), axes, keepdims, noop_with_empty_axes)
-def ReduceProd(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
-  return data.prod(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
-def ReduceL1(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
-  return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes)
-def ReduceL2(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
-  return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt()
-def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
-  return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
-def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
-  return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log()
-def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0):
-  if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
-  return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
-def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0):
-  return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
-
-# ***** Movement Ops *****
-def Reshape(data:Tensor, shape:list[int], allowzero:int=0):
-  return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)])
-def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1)
-def Expand(x:Tensor, shape:list[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape)))
-def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)
-def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm)
-
-# TODO: add test for when axes is None
-def Squeeze(data:Tensor, axes:list[int]|None=None):
-  return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data)
-def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data)
-
-def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats)
-def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis)
-def Slice(data:Tensor, starts:list[int], ends:list[int], axes:list[int]|None=None, steps:list[int]|None=None):
-  axes = axes or list(range(data.ndim))
-  steps = steps or [1]*data.ndim
-  slices = [slice(0,x,1) for x in data.shape]
-  for i, axis in enumerate(axes): slices[axis] = slice(starts[i], ends[i], steps[i])
-  return data[tuple(slices)]
-
-def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0):
-  sz = data.shape[axis]
-  if split is None: split = [sz // num_outputs + (1 if i < sz % num_outputs else 0) for i in range(num_outputs)]
-  return data.split(split, axis)
-
-def _onnx_pads_to_tiny_pads(pads):
-  # (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...)
-  return tuple(flatten(reversed(list(zip(pads, pads[len(pads)//2:])))))
-def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None,
-        mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0):
-  value = constant_value or value
-  axes = axes or list(range(x.ndim))
-  real_pads = [0] * (x.ndim*2)
-  for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)]
-  return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value)
-
-def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None):
-  shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim
-  pad_arg:list[None|tuple[int,int]] = [None] * t.ndim
-  for s, x in zip(shape, axes or range(t.ndim)):
-    tx = t.shape[x]
-    if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2)
-    elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2)
-  return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))
-
-# ***** Processing Ops *****
-AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"]
-def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS):
-  # (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right)
-  if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))]
-  return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))]
-def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS):
-  i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_))
-  if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2)
-  o_ = [((i - (1 if auto_pad in ("SAME_UPPER", "SAME_LOWER") else k)) // s + 1) for i,k,s in zip(i_, k_, s_)]
-  return _onnx_pads_to_tiny_pads(_auto_pad([(o-1)*s+k-i for o,i,k,s in zip(o_, i_, k_, s_)], auto_pad))
-
-def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0,
-                dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1):
-  return X.avg_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad),
-def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0,
-                dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1):
-  return X.avg_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad),
-                      ceil_mode=ceil_mode, count_include_pad=count_include_pad)
-
-def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0,
-            storage_order:int=0, strides:list[int]|int=1):
-  ret = X.max_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), ceil_mode=ceil_mode)
-  # tests expect indices with int64 dtype
-  # TODO: if there are repeated values, this is wrong
-  indices = ((ret.reshape(-1, 1) == X.reshape(1, -1)) * Tensor.arange(X.numel(), dtype=dtypes.int64).unsqueeze(0)).sum(1).reshape(ret.shape)
-  return ret.cast(X.dtype), indices.transpose(-2, -1) if storage_order else indices
-
-def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
-         kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
-  return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations,
-                  padding=_resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad))
-
-def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
-                  kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0,
-                  strides:list[int]|int=1):
-  input_shape, kernel_shape = X.shape[2:], (kernel_shape or W.shape[2:])
-  strides, dilations, output_padding = (make_tuple(x, len(input_shape)) for x in (strides, dilations, output_padding))
-  if output_shape is not None: # we pad according to output_shape
-    pads = _auto_pad([s*(i-1) + op + ((k-1)*d+1) - os for s,i,op,k,d,os in
-                      zip(strides, input_shape, output_padding, kernel_shape, dilations, output_shape)], auto_pad)
-  if pads is None: # we generate pads
-    output_shape = output_shape or [X.shape[i+2] * strides[i] for i in range(len(strides))]
-    pads = [strides[i]*(input_shape[i]-1) + output_padding[i] + ((kernel_shape[i]-1)*dilations[i]+1)-output_shape[i] for i in range(len(input_shape))]
-    pads = _auto_pad(pads, auto_pad) if auto_pad != "NOTSET" else [0] * len(input_shape) * 2
-  pads = _onnx_pads_to_tiny_pads(pads)
-  return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads, output_padding=output_padding)
-
-def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
-  pads, strides = (make_tuple(x, len(xI.shape)) for x in (pads, strides))
-  out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)]
-  ret = (xI.reshape(-1, 1)._one_hot_along_dim(prod(out_sh)) * xT.reshape(-1, 1)).sum(0).reshape(1, 1, *out_sh)
-  if outshape is not None and outshape != ret.shape: pads = _auto_pad([outshape[-2] - ret.shape[-2], outshape[-1] - ret.shape[-1]], "SAME_UPPER")
-  return ret.pad(_onnx_pads_to_tiny_pads(pads))
-
-def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
-def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)
-
-def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0):
-  ret = alpha * (A.transpose(transA) @ B.transpose(transB))
-  if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
-  return ret
-
-def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs)
-
-def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0):
-  axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis)
-  if reverse: X = X.flip(axis)
-  if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\
-    .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim)))
-  return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis)
-
-def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k)
-
-def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, sizes:list[int]|None=None, antialias:int=0,
-           axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0,
-           extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'):
-  def _apply_nearest_mode(index: Tensor, input_dim, mode: str):
-    if mode == "round_prefer_floor": index = (index - 0.5).ceil()
-    elif mode == "round_prefer_ceil": index = (index + 0.5).floor()
-    elif mode in ["floor", "ceil"]: index = getattr(index, mode)()
-    else: raise ValueError(f"invalid {nearest_mode=}")
-    return index.cast(dtypes.int32).clip(0, input_dim-1)
-  def _apply_transformation(index: Tensor, input_dim, scale_dim, roi_dim, mode):
-    # TODO: needs more testing, not confident in this
-    # NOTE: their reference implementation differs from the implementation in their reference docs
-    # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_resize.py
-    # https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize
-    output_dim = scale_dim * input_dim
-    if mode == "half_pixel": index = (index + 0.5) / scale_dim - 0.5
-    elif mode == "align_corners": index = index * (input_dim - 1) / (output_dim - 1) if output_dim != 1 else Tensor([0])
-    elif mode == "asymmetric": index = index / scale_dim
-    elif mode == "pytorch_half_pixel": index = (index + 0.5) / scale_dim - 0.5 if output_dim != 1 else Tensor([-0.5])
-    elif mode == "half_pixel_symmetric": index = input_dim / 2 * (1 - int(output_dim) / output_dim) + (index + 0.5) / scale_dim - 0.5
-    elif mode == "tf_crop_and_resize": index = roi_dim[0] * (input_dim - 1) + index * ((roi_dim[1] - roi_dim[0]) * (input_dim - 1) / (output_dim - 1))
-    else: raise ValueError(f"invalid {coordinate_transformation_mode=}")
-    return index.clip(0, input_dim-1)
-
-  scales, sizes = (None if scales is None else scales[2-(X.ndim-len(scales)):]), (None if sizes is None else sizes[2-(X.ndim-len(sizes)):])
-  # we pre permute the axes and permute back after resize
-  axes, input_shape, = (axes or list(range(X.ndim))), cast(tuple[int, ...], X.shape[2:]),
-  perm = [a for a in range(len(X.shape)) if a not in axes] + list(axes)
-  X = X.permute(*perm)
-
-  if sizes is not None:
-    if keep_aspect_ratio_policy in ["not_larger", "not_smaller"]:
-      scale_fxn = min if keep_aspect_ratio_policy == "not_larger" else max
-      scales = [scale_fxn([sizes[i] / input_shape[i] for i in range(len(input_shape)) if i+2 in axes])] * 2
-      sizes = [int((scales[0] * input_shape[i]) + 0.5) if i+2 in axes else input_shape[i] for i in range(X.ndim-2)]
-    else:
-      scales = [size / input_shape for size, input_shape in zip(sizes, input_shape)]
-  else:
-    sizes = [int(sc*sh) for sc, sh in zip(scales, input_shape)]
-  regions = [[st, ed] for st, ed in zip(roi, roi[len(roi)//2:])] if isinstance(roi, list) and roi else [[0.0, 0.0]] * (X.ndim-2)
-
-  # NOTE: this transformation makes it so that we can't just call Tensor.interpolate
-  # in Tensor.interpolate, we use indexes without any transformation
-  indexes = []
-  for shape, size, scale, region in zip(input_shape, sizes, scales, regions):
-    indexes.append(_apply_transformation(Tensor.arange(size), shape, scale, region, coordinate_transformation_mode))
-
-  if mode == "nearest":
-    indexes = [_apply_nearest_mode(index, shape, nearest_mode) for (index, shape) in zip(indexes, input_shape)]
-    X = X[(..., *Tensor.meshgrid(*indexes))]
-  if mode == "linear":
-    expand = list(X.shape)
-    for i in range(-len(sizes), 0):
-      reshape, index = [1] * X.ndim, indexes[i]
-      reshape[i] = expand[i] = sizes[i]
-      low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
-      X = X.gather(i, low).lerp(X.gather(i, high), perc)
-  if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented")
-  return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X
-def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated
-
-# ***** Neural Network Ops *****
-# TODO: try to factor out common implementations for these ops
-# https://medium.com/@zljdanceholic/groupnorm-then-batchnorm-instancenorm-layernorm-e2b2a1d350a0
-def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9,
-                       training_mode:int=0, spatial=1, is_test=0):
-  if training_mode:
-    x_detached = X.detach()
-    current_mean = x_detached.mean(axis=(0,2,3))
-    y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
-    current_var = (y*y).mean(axis=(0,2,3))
-    current_invstd = current_var.add(epsilon).rsqrt()
-
-    running_mean = input_mean * momentum + current_mean * (1 - momentum)
-    running_var = input_var * momentum + current_var * (1 - momentum)
-
-    return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
-  invstd = (input_var + epsilon).rsqrt()
-  return X.batchnorm(scale, B, input_mean, invstd)
-def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
-  axis = tuple(range(2, x.ndim))
-  mean = x.mean(axis=axis, keepdim=True)
-  invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
-  return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
-def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
-  assert stash_type == 1, "only float32 is supported"
-  axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
-  mean = x.mean(axis=axes, keepdim=True)
-  return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt()
-def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05):
-  return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
-def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]):
-  return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9)
-def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12):
-  x = x + skip + bias
-  return x.layernorm(eps=epsilon) * gamma + beta, None, None, x
-def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor,
-                            segment_embedding:Tensor, gamma=None, beta=None, mask:Tensor|None=None,
-                            position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0):
-  # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
-  assert (segment_ids is None) is (segment_embedding is None)
-  assert mask is None and not mask_index_type, "functionality not supported yet"  # TODO
-  input_shape = input_ids.shape
-  seq_length = input_shape[1]
-  compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
-  vocab_size, max_position_embeddings = word_embedding.shape[0], position_embedding.shape[0]
-  type_vocab_size = (segment_embedding.shape[0] if compute_seg_emb else None)
-
-  def embedding(x:Tensor, vocab_size, weight:Tensor) -> Tensor:
-    return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight
-
-  # bert embedding layer
-  if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
-  wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
-  pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
-  seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None
-
-  embedding_sum = wrd_embedding_res + pos_embedding_res
-  if seg_embedding_res is not None: embedding_sum = embedding_sum + seg_embedding_res
-  out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
-  return out, None, embedding_sum
-
-def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1):
-  # Scalar or Rank 1 tensor containing exactly one element
-  depth = int(depth[0] if isinstance(depth, list) else depth)
-  indices = (indices < 0).where(indices+depth, indices)
-  return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0])
-
-def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"):
-  return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize)
-def SpaceToDepth(X:Tensor, blocksize:int):
-  return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize)
-
-# Reimplemented here because you need legacy RNG for passing ONNX tests.
-def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None):
-  if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool)  # if mask is requested as output it will contain all True's.
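-  # NOTE: np.random.RandomState(seed) is NumPy's legacy generator; the ONNX reference tests expect this
-  # exact stream, so a mask drawn from tinygrad's own RNG would not match the expected outputs.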
-  mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device)
-  return data * mask * (1/(1.0 - ratio)), mask
-# 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
-def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test)
-Dropout = {6:Dropout_6, 7:Dropout_7}
-
-def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0):
-  pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1)
-  return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta)
-
-def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
-  return x.nll_loss(target, weight, ignore_index, reduction)
-def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
-  log_probs = scores.log_softmax(1)
-  return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs
-
-def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0):
-  N, _, *spatial_dims = size
-  def generate_grid(steps):
-    return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device)
-  grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims))
-  base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1)
-  base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1)
-  return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)
-
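-# NOTE on the layout Attention assumes below: `weights` packs Q|K|V along dim 1, so with
-# qkv_hidden_sizes=(q,k,v) it has shape (input_hidden, q+k+v) and `bias` has shape (q+k+v,);
-# the slicing also assumes q == k (only the last two entries of qkv_hidden_sizes are read).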
-def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past:Tensor|None=None,
-              relative_position_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int|None=None,
-              mask_filter_value:float|None=None, num_heads:int|None=None, past_present_share_buffer:int|None=None,
-              qkv_hidden_sizes:list[int]|None=None, scale:float|None=None, unidirectional:int|None=None):
-  # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
-  assert num_heads is not None  # required
-  assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
-  assert relative_position_bias is do_rotary is past_sequence_length is mask_filter_value is past_present_share_buffer is scale is None, \
-    "functionality not supported yet"  # TODO strange params
-  hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)
-
-  if unidirectional:  # gpt-style
-    assert hidden_size == v_hidden_size
-    xqkv = x.linear(weights, bias)
-    xq, xk, xv = [xqkv.shrink([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
-  else:  # bert-style
-    wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
-    bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size:]) if bias is not None \
-      else (None, None, None)
-    xq, xk, xv = [x.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
-  xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]
-
-  if past is not None:
-    xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
-    present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))
-
-  def attn(query, key, value, attn_mask):
-    query_length, key_length = query.shape[-2], key.shape[-2]
-    cdim = max(query_length, key_length) + 1
-    attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
-    # This is where Tensor.scaled_dot_product_attention differs:
-    causal_mask = Tensor.ones((cdim, cdim), requires_grad=False, dtype=dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length]
-    masked = Tensor.where(causal_mask, attn_weights, -math.inf)
-    if attn_mask is not None: masked = masked + attn_mask
-    return masked.softmax(-1) @ value
-
-  bsz, _, seq_len, _ = xq.shape
-  out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
-  return (out, present) if past is not None else out
-
-# ***** Indexing Ops *****
-def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices]
-
-def Gather(x:Tensor, indices:Tensor, axis:int=0):
-  if indices.numel() < 9:  # NOTE: fewer kernels for smaller indices, but the kernel count grows with the size of indices
-    x_sh = list(x.shape)
-    ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
-    if indices.ndim > 1: indices = indices.flatten()
-    indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)]
-    args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices]  # type: ignore
-    return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
-  # NOTE: this gather is faster with a fixed number of kernels, but it exceeds the kernel limit for openpilot
-  return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]
-def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs)  # deprecated
-
-def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0):
-  if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))]
-  x_shape, i_shape = x.shape, indices.shape
-  b = math.prod(x.shape[dim] for dim in range(batch_dims))
-  # NOTE: each batched dim of both input and indices are equal
-  x = x.reshape(b, *x.shape[batch_dims:])
-  indices = indices.reshape(b, *indices.shape[batch_dims:])
-  b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1])
-  ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))]
-  return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:])
-def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'):
-  assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):]
-  x = x.contiguous()
-  for index, u in zip(indices.split(1, 0), updates.split(1, 0)):
-    i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1))
-    u = u.squeeze(0)
-    if reduction == "none": x[i] = u
-    elif reduction == "add": x[i] += u
-    elif reduction == "mul": x[i] *= u
-    else: raise NotImplementedError("reduction doesn't support max or min")
-  return x
-
-def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"):
-  indices = (indices < 0).where(x.shape[axis], 0) + indices
-  return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction))
-def GatherElements(x:Tensor, indices:Tensor, axis:int):
-  indices = (indices < 0).where(x.shape[axis], 0) + indices
-  return x.gather(axis, indices)
-
-def Compress(inp:Tensor, condition:list[bool], axis:int|None=None):
-  if axis is None:
-    inp = inp.flatten()
-    axis = 0
-  if axis < 0: axis += inp.ndim
-  con = Tensor(np.arange(len(condition))[condition])  # no boolean indexing in Tensor
-  return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]
-
-# ***** Quantization Ops *****
-def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype)
-
-def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0):
-  if axis < 0: axis += x.ndim
-  if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape)
-  if block_size == 0:
-    shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim))
-    return scale.reshape(shape), zero_point.reshape(shape)
-  return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis)
-
-def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
-  out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
-  y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
-  return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous()
-
-def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
-  x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)
-  return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype)
-
-def _op_integer(op, inputs:list[Tensor], zero_points:list[Tensor], **opts):
-  adjusted_inputs = [inp.int() - zp for inp, zp in zip(inputs, zero_points)]
-  return op(*adjusted_inputs, **opts)
-
-def _qlinearop_quantized(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
-  # op execution is done in quantized int
-  out = _op_integer(op, inputs, zero_points, **opts)
-  assert dtypes.is_int(out.dtype), "quantized op should've done math in int"
-  out_quantized = (out * prod(scales) / out_scale).round() + out_zero_point
-  return _clamp_cast(out_quantized, out_zero_point.dtype)
-
-def _qlinearop_float(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
-  # op execution is done in float32
-  dequantized_inputs = [(inp.int() - zp) * scale for inp, zp, scale in zip(inputs, zero_points, scales)]
-  out = op(*dequantized_inputs, **opts)
-  assert dtypes.is_float(out.dtype), "op should've done math in float"
-  out_quantized = (out / out_scale).round() + out_zero_point
-  return _clamp_cast(out_quantized, out_zero_point.dtype)
-
-def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor,
-                y_zero_point: Tensor|int, B:Tensor|None=None, **opts):
-  return _qlinearop_quantized(Conv, [x,w], [x_zero_point,w_zero_point], [x_scale,w_scale], y_scale, y_zero_point, **{"B":B, **opts})
-
-def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor,
-                  y_zero_point:Tensor|int) -> Tensor:
-  return _qlinearop_quantized(Tensor.matmul, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], y_scale, y_zero_point)
-
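-# A hedged worked example of _qlinearop_quantized (numbers invented): with x_scale=0.1, w_scale=0.2 and
-# y_scale=0.05, the integer accumulator acc = (x_int - x_zp) @ (w_int - w_zp) is requantized as
-# y_int = round(acc * 0.1 * 0.2 / 0.05) + y_zp, and _clamp_cast then saturates to the zero point's dtype range.
-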
-def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor,
-               c_scale:Tensor, c_zero_point:Tensor):
-  return _qlinearop_float(Tensor.add, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], c_scale, c_zero_point)
-
-def QLinearGlobalAveragePool(X:Tensor, x_scale:Tensor, x_zero_point:Tensor, y_scale:Tensor, y_zero_point:Tensor, channels_last:int):
-  assert channels_last == 0, "unsure what this does"
-  return _qlinearop_float(GlobalAveragePool, [X], [x_zero_point], [x_scale], y_scale, y_zero_point)
-
-def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None, **opts) -> Tensor:
-  return _op_integer(Conv, [x,w], [x_zero_point,w_zero_point], **{"B":B, **opts})
-
-def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor:
-  return _op_integer(Tensor.matmul, [A,B], [a_zero_point,b_zero_point])
-
-# ***** Training Ops *****
-# NOTE: onnx test coverage only covers `T==0` cases, so for all `T>0` this isn't tested
-# NOTE: onnx training ops actually don't need the state for optim, all the ops work in a functional way, but we can still reuse optim.py code
-def _onnx_training(input_group_size):
-  def __decorator(func):
-    def ___wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs):
-      R = R.detach()
-      groups = len(inputs) // input_group_size
-      ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))]
-      return tuple(flatten(zip(*ret)))
-    return ___wrapper
-  return __decorator
-
-@_onnx_training(3)
-def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0):
-  X, G, H = (i.detach() for i in inputs)
-  grad = norm_coefficient * X + G
-  H.assign(H + grad.square())
-  up = grad / (H.sqrt() + epsilon)
-  r = R / (1 + T * decay_factor)
-  X.assign(X.detach() - r * up)
-  return [X, H]
-
-@_onnx_training(4)
-def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0,
-         norm_coefficient_post:float=0.0):
-  from tinygrad.nn.optim import Adam as TinyAdam
-  X, G, V, H = inputs
-  G, V, H = G.detach(), V.detach(), H.detach()  # TODO we shouldn't need these detaches
-  X.grad = norm_coefficient * X.detach() + G
-  opt = TinyAdam([X], b1=alpha, b2=beta, eps=epsilon)
-  opt.m, opt.v, opt.lr = [V], [H], R
-  # need no-op for m_hat and v_hat if T == 0
-  if T == 0: opt.b1_t, opt.b2_t = opt.b1_t.zeros_like(), opt.b2_t.zeros_like()
-  else:
-    # `T-1` since it's applied again at the start of `_step`
-    opt.b1_t = Tensor([alpha**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
-    opt.b2_t = Tensor([beta**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
-  opt.step()
-  X = (1 - norm_coefficient_post) * X
-  return [X, V, H]
-
-@_onnx_training(3)
-def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float):
-  from tinygrad.nn.optim import SGD
-  X, G, V = inputs
-  G, V = G.detach(), V.detach()
-  X.grad = (norm_coefficient * X.detach() + G) * (beta if T > 0 else 1)
-  opt = SGD([X], momentum=alpha, nesterov=(mode=="nesterov"))
-  opt.b, opt.lr = [V], R
-  opt.step()
-  return [X, V]
-
-def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **_):
-  intermediate_tensors[y].backward()
-  return tuple([t.grad for t in inputs])
diff --git a/test/test_const_folding.py b/test/test_const_folding.py
index dfffca8989..b78faee145 100644
--- a/test/test_const_folding.py
+++ b/test/test_const_folding.py
@@ -94,10 +94,9 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
     _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.ones(4))
   def test_literal_one_pow(self):
     _check_ast_count(0, 1 ** Tensor([1.0, 2, 3, 4]))
-  # this fails because of DETACH, it shouldn't
-  # update: passes after CONST(VIEW(DEVICE)) in tensor
+  # TODO: pow simplification
   def test_tensor_one_pow(self):
-    _check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
+    _check_ast_count(1, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
 
 # folds advance indexing into basic indexing
 class TestIndexingConstFolding(unittest.TestCase):
diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py
index 6bb099aa76..6291216d1d 100644
--- a/test/unit/test_disk_tensor.py
+++ b/test/unit/test_disk_tensor.py
@@ -3,7 +3,7 @@ import numpy as np
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.dtype import DType
 from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
-from tinygrad.helpers import Timing, fetch, temp, CI
+from tinygrad.helpers import Timing, fetch, temp, CI, OSX
 from tinygrad.device import is_dtype_supported
 
 def compare_weights_both(url):
@@ -298,6 +298,7 @@ class TestDiskTensor(unittest.TestCase):
     ret = t.bitcast(dtypes.uint16).to("CLANG") + 1
     assert ret.tolist() == [2827, 3341, 3855, 4369]
 
+  @unittest.skipIf(OSX, "new LLVM has an issue on OSX")
   def test_bf16_disk_write_read(self):
     t = Tensor([10000, -1, -1000, -10000, 20], dtype=dtypes.float32)
     t.to(f"disk:{temp('dt_bf16_disk_write_read_f32')}").realize()
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 7c6592b912..2fc83e7711 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -620,7 +620,7 @@ class Kernel:
       if self.use_tensor_cores == 3:  # for TC=3, emulate the warp addressing with locals
         local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))
         st = store_st = ShapeTracker.from_shape(local_shape)
-        local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), f"temp{i + 1}")
+        local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), f"temp{i}")
         if swizzle: store_st = get_tc_swizzle_st(store_st.shape, *swizzle)
         local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i])
         srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store))
@@ -648,7 +648,7 @@ class Kernel:
         (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
       st_uop = ShapeTracker.from_shape(local_shape).to_uop()
       local_size = st_uop.arg.real_size()
-      local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)+1}")
+      local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
       local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, ret)))
       grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
       if op is self.reduceops[-1]: return grouped_reduce
diff --git a/tinygrad/engine/multi.py b/tinygrad/engine/multi.py
index 4803127690..17a3259b04 100644
--- a/tinygrad/engine/multi.py
+++ b/tinygrad/engine/multi.py
@@ -43,13 +43,28 @@ def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) ->
 
 from tinygrad.ops import PatternMatcher, UPat, GroupOp, graph_rewrite_map, track_rewrites
 
+def get_axis(root:UOp):
+  if root.op is Ops.MULTI: return root.arg[0]
+  # NOTE: they all have to share an axis, we always choose [-1]
+  if root.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in root.src if x.axis is not None])) else None
+  src_axis = get_axis(root.src[0])
+  if root.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in root.arg[1] else src_axis
+  if root.op is Ops.RESHAPE:
+    if src_axis is None: return None
+    arg_acc:list[sint] = list(itertools.accumulate(root.arg, operator.mul, initial=1))
+    # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
+    # TODO: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
+    return len(arg_acc) - arg_acc[::-1].index(prod(root.src[0].shape[:src_axis])) - 1
+  if root.op is Ops.PERMUTE: return root.arg.index(src_axis) if src_axis is not None else None
+  raise NotImplementedError("rest should be passthrough")
+
 def alu_multi(root:UOp):
   msrcs = root.src
   assert all(x.op is Ops.MULTI for x in msrcs), f"all buffers must be MultiLazyBuffer {[x.op for x in msrcs]}"
   assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
-  # NOTE: they all have to share an axis, we always choose [-1]
-  axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None)
+  axis = get_axis(root)
+  bounds = dedup([x.bounds for x in root.src if x.axis == axis])[-1] if axis is not None else None
   srcs:list[list[UOp]] = []
   not_all_real = not all(all(mlb.real) for mlb in msrcs)
   new_real = tuple(all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])) if not_all_real else msrcs[0].real
@@ -64,28 +79,24 @@ def alu_multi(root:UOp):
   return UOp.multi(*new_lbs, axis=axis, real=new_real)
 
 def reduce_multi(root:UOp, multi:UOp):
-  op, axis = root.arg
+  (op, axis), new_axis = root.arg, get_axis(root)
   if multi.axis is not None and multi.axis in axis:
    # all-reduce on sharded axes
    reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(multi.src, multi.real)]
    # if all partitions are real, do all_reduce
-   if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=None)
+   if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=new_axis)
    # only one partition is real, keep it
-   return UOp.multi(*reduced_parts, axis=None, real=multi.real)
+   return UOp.multi(*reduced_parts, axis=new_axis, real=multi.real)
  # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
-  return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=multi.axis, real=multi.real)
+  return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=new_axis, real=multi.real)
 
 def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
   return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))
 
 def reshape_multi(root:UOp, multi:UOp):
-  arg = root.arg
-  if multi.axis is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=None, real=multi.real)
+  arg, new_axis = root.arg, get_axis(root)
+  if multi.axis is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=new_axis, real=multi.real)
   assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
-  arg_acc:list[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
-  # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
-  # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
-  new_axis = len(arg_acc) - arg_acc[::-1].index(prod(multi.shape[:multi.axis])) - 1
   assert all(prod(lb.shape[multi.axis:])%prod(arg[new_axis+1:])==0 for lb in multi.src), \
     f"reshape cannot move items between shards {multi.shape} -> {root.arg=}"
   lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[multi.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in multi.src]
@@ -109,7 +120,7 @@ def pad_multi(root:UOp, multi:UOp):
 
 def permute_multi(root:UOp, multi:UOp):
   # all permutes supported!
-  return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=root.arg.index(multi.axis) if multi.axis is not None else None, real=multi.real)
+  return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=get_axis(root), real=multi.real)
 
 def shrink_multi(root:UOp, multi:UOp):
   assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 1083240ea7..98d527b97a 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -13,7 +13,257 @@ from tinygrad.device import Buffer
 # creation can recurse a lot
 sys.setrecursionlimit(10000)
 
-# **** ScheduleItem return type
+# **** schedule simplifier
+
+def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
+  if not all_int(x.shape): return None
+  # remove reduce on unmasked const
+  prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
+  ret = x.const_arg
+  match reduce.arg[0]:
+    case Ops.ADD: ret *= prshape
+    case Ops.MUL: ret **= prshape
+    case Ops.MAX: pass  # NOTE: Ops.MAX is passthrough
+    case _: return None
+  return reduce.const_like(ret)
+
+def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
+  if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
+def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
+  new_src = list(alu.src)
+  for i,s in enumerate(alu.src):
+    if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
+  if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
+
+sym = symbolic_simple+PatternMatcher([
+  # UOp with size 0 is zero
+  (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
+    and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
+  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
+  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
+  # reduce of size 0 is the identity element
+  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
+   lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
+  # reduce of const is collapsed (TODO: make this a generic rule for stride0)
+  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
+  # COPY(CONST) creates a new CONST on the destination device
+  (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
+  # no COPY to same device, except clone (arg is True)
+  (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
+   lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
+  # remove cast to image when it's already a contiguous image
+  (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)),
+   lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
+  # remove contiguous if we can just view the buffer
+  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
+   lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
+  # contiguous/buffer is already contiguous
+  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER)),)), lambda root: root.src[0]),
+  # support for using a contiguous permuted view instead of the parent view if one exists
+  (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
+  (UPat(GroupOp.ALU, name="alu"), replace_contiguous),
+  # remove CONST/BIND/BUFFER from SINK
+  (UPat(Ops.SINK, name="root"),
+   lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
+    if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
+])
+
+remove_movement_ops = merge_views+PatternMatcher([
+  # NOTE: movement ops are always applied to base
+  (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
+  # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
+  (UPat(Ops.VIEW, name="view"),
+   lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
+])
+
+# **** UOp realization
+
+@dataclass(frozen=True)
+class ScheduleContext:
+  tensor_uops: dict[UOp, list[UOp]]  # this maps BUFFER uops of this schedule to the tensor uop
+  assigns: set[UOp] = field(default_factory=set)  # this holds all the BUFFER uops we ASSIGN to in this schedule
+  realizes: dict[UOp, UOp] = field(default_factory=dict)  # this holds all the BUFFER uops we mutate in this schedule
+  allbufs: dict[UOp, UOp] = field(default_factory=dict)  # this maps BUFFER uops to the actual op
+  ops_metadata: dict[UOp, Metadata] = field(default_factory=dict)  # this maps fused ops to Metadata
+  children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
+  preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
+
+# wrap tensor uops around a VIEW(BUFFER, )
+# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
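+# (illustration, not from this patch: a realized CONTIGUOUS(x) becomes VIEW(BUFFER, CONTIGUOUS(x));
+# the BUFFER is a fresh UOp.new_buffer unless the op is an ASSIGN, which reuses its target buffer)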
+def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp:
+  if (r:=cache.get(buf)) is not None: return r
+  # SINK is passthrough
+  if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
+  # skip creating buffers for CONST/BIND/DEVICE/BUFFER
+  if buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
+  if buf.base.op is Ops.BUFFER: return buf.view(unwrap(buf.st))
+  # VIEW is passthrough
+  if buf is not buf.base:
+    cache[buf] = ret = add_buffers(buf.base, buffer_map, cache).view(unwrap(buf.st))
+    return ret
+  # make things that can't be images not images
+  dtype = buf.dtype
+  if isinstance(dtype, ImageDType) and (prod(buf.shape)!=prod(dtype.shape) or not any(buf.shape[x]%4==0 for x in unwrap(buf.st).unit_stride_axes())):
+    if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
+    dtype = buf.dtype.base
+  # ASSIGN already has a target buffer, otherwise we create a new one
+  assert isinstance(buf.device, str), f"buf device must be str, not {buf.device}"
+  buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
+  op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
+  # track the buffer uop for the simplified uop
+  buffer_map[buf] = buf_uop
+  # (early) bufferize
+  cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
+  return ret
+
+class UPatScheduled(UPat):
+  def __init__(self, *args, **kwargs):
+    super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
+
+def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
+
+def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
+  st = unwrap(view.st)
+  # fold simple pads
+  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
+    return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src)
+  # early realize before expand
+  if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
+  # otherwise safety check pads
+  return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)
+
+def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) -> UOp|None:
+  if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(x.base).op is Ops.COPY: return None
+  del ctx.realizes[b]
+  return x.view(unwrap(view.st))
+
+def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
+  if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
+  buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
+  return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
+
+do_realize = PatternMatcher([
+  # always realize SINK parents
+  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
+  # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
+  (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
+  # realize before expand or unsafe pad ops
+  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
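+  # (e.g. expanding a (1,16) intermediate to (1024,16) realizes the small source first so it is read,
+  # not recomputed, by every expanded element; DONT_REALIZE_EXPAND=1 disables this heuristic)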
+  # don't realize image to image casts
+  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="x"),), dtype=dtypes.float),)),
+   fold_img_cast),
+  # realize before COPY or BUFFER_VIEW
+  (UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
+  (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
+  # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
+  (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
+])
+
+def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
+  ctx.allbufs[buf_uop] = view
+  if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
+  for x in op.base.src:
+    if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
+create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])

+def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
+def uval(u:UOp) -> UOp:
+  assert is_scheduled(u), f"must be a scheduled op {u}"
+  return u.src[1]
+
+def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp], realizes:dict[UOp, UOp],
+                    reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
+  """recursively search the uop for groupable children, realize the UOp if a child can't group"""
+  if (tr, st) in cache: return
+  cache.setdefault((tr, st))
+  rsize = unwrap(allbufs[r].st).size
+  if tr in realizes and tr is not r:
+    # can only fuse contiguous
+    # max one reduceop per kernel
+    if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
+    return group.setdefault(tr)
+  for tr_next in children[tr]:
+    # max one reduceop per kernel
+    if (tr_next_uop:=uval(allbufs[tr_next]).base).op is Ops.REDUCE_AXIS: return group.setdefault(r)
+    # can only fuse contiguous
+    if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
+    recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
+
+def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]:
+  # start by adding uops that always realize
+  sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
+  # find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
+  reduce_for_op: dict[UOp, UOp] = {}
+  double_reduces: list[UOp] = []
+  for r, r_uop in ctx.allbufs.items():
+    if (r_uop:=uval(r_uop)).op is not Ops.REDUCE_AXIS: continue
+    if FUSE_CONV_BW and is_scheduled((x:=r_uop.src[0]).base) and uval(x.base).op is r_uop.op and x.base is not x: double_reduces.append(r)
+    if r in ctx.realizes: continue
+    group: dict[UOp, None] = {}
+    recursive_group(r, unwrap(r_uop.st), r, ctx.children, ctx.allbufs, ctx.realizes, reduce_for_op, group, cache={})
+    # max one reduceop per kernel
+    can_chase = all(tr not in reduce_for_op for tr in group)
+    # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
+    forced_realize = r in group
+    # can only have one output
+    if not forced_realize and len(group) > 1: forced_realize = True
+    # can only fuse assign if no other assign_target is used in the kernel
+    if not forced_realize and any(x in ctx.assigns for x in group):
+      parents = deque((r, *group))
+      while parents and not forced_realize:
+        if (p_uop:=ctx.allbufs.get(p:=parents.pop())) is None: continue
+        if (p_uop:=uval(p_uop)).op is Ops.ASSIGN and p not in group: forced_realize, can_chase = True, False
+        if p in ctx.realizes: continue
+        parents.extend([x.base.buf_uop for x in p_uop.src if x.base.is_realized or (x.base.op is Ops.VIEW and len(x.base.src) != 0)])
+    if forced_realize or not group:
+      tr = r
+      if can_chase:
+        # can chase this down to contiguous children
+        st = unwrap(r_uop.st)
+        while len(ctx.children[tr]) == 1:
+          tr_next_uop = uval(ctx.allbufs[(tr_next:=next(iter(ctx.children[tr])))])
+          st_childs = dedup([unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop is tr])
+          if len(st_childs) > 1: break
+          if st.size != st_childs[0].size: break
+          st = st + st_childs[0]
+          if not st.contiguous or tr_next_uop.op is Ops.REDUCE_AXIS: break
+          tr = tr_next
+        # don't cast to higher size before store (tr cannot be realized if forced_realize)
+        if (tr_uop:=uval(ctx.allbufs[tr])).op is Ops.CAST and tr_uop.dtype.base.itemsize > tr_uop.src[0].dtype.base.itemsize:
+          tr = tr_uop.src[0].base.buf_uop
+      group = {tr: None}
+      ctx.realizes[tr] = tr
+    reduce_for_op.update((tr, r) for tr in group)
+    if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.CONST:
+      # maybe fuse arange with its children
+      if len(flatten(ctx.children[tr] for tr in group)) != 0:
+        for tr in group: del ctx.realizes[tr]
+  # fuse double reduces with no other child
+  for reduceop in double_reduces:
+    top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
+    if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
+  graph_rewrite(sink, break_sched, ctx)
+  return ctx.realizes
+
+# break the SINK into stores
+
+def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
+  # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
+  return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
+
+def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
+  if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
+  if b not in ctx.realizes: return x  # collapse BUFFER
+  ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
+  return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
+
+break_sched = PatternMatcher([
+  # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
+  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
+  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
+])
+
+# **** ScheduleItem creation
 
 @dataclass(frozen=True)
 class ScheduleItem:
@@ -31,49 +281,11 @@ class ScheduleItem:
   @functools.cached_property
   def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
 
-# **** Schedule context and big graph
-
 @dataclass(frozen=True)
-class ScheduleContext:
-  tensor_uops: dict[UOp, list[UOp]] = field(default_factory=dict)  # this maps BUFFER uops of this schedule to the tensor uop
-  assigns: set[UOp] = field(default_factory=set)  # this holds all the BUFFER uops we ASSIGN to in this schedule
-  realizes: dict[UOp, UOp] = field(default_factory=dict)  # this holds all the BUFFER uops we mutate in this schedule
-  allbufs: dict[UOp, UOp] = field(default_factory=dict)  # this maps BUFFER uops the actual op
-  ops_metadata: dict[UOp, Metadata] = field(default_factory=dict)  # this maps fused ops to Metadata
-  children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
-  preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
-
-# wrap tensor uops around a VIEW(BUFFER, )
-# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
-def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
-  if (r:=cache.get(buf)) is not None: return r
-  # SINK is passthrough
-  if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
-  # skip creating buffers for CONST/BIND/DEVICE/BUFFER
-  if buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
-  if buf.base.op is Ops.BUFFER: return buf.view(unwrap(buf.st))
-  # VIEW is passthrough
-  if buf is not buf.base:
-    cache[buf] = ret = add_buffers(buf.base, tensor_map, ctx, cache).view(unwrap(buf.st))
-    return ret
-  # make things that can't be images not images
-  dtype = buf.dtype
-  if isinstance(dtype, ImageDType) and (prod(buf.shape)!=prod(dtype.shape) or not any(buf.shape[x]%4==0 for x in unwrap(buf.st).unit_stride_axes())):
-    if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
-    dtype = buf.dtype.base
-  # ASSIGN already has a target buffer, otherwise we create a new one
-  assert isinstance(buf.device, str), f"buf device is str, not {buf.device}"
-  buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
-  op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
-  # track the underlying tensor uop for this buffer
-  ctx.tensor_uops[buf_uop] = tensor_map[buf]
-  # (early) bufferize
-  cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
-  return ret
-
-# **** AST graph rewrite
-
-# ** movement ops
+class ScheduleItemContext:
+  var_vals: dict[Variable, int]
+  sts: set[ShapeTracker] = field(default_factory=set)
+  bufs: list[UOp] = field(default_factory=list)
 
 def apply_swizzle(u:UOp) -> UOp:
   with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
@@ -127,14 +339,6 @@ view_right = merge_views+PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
 ])
 
-# ** ScheduleItem context builder
-
-@dataclass(frozen=True)
-class ScheduleItemContext:
-  var_vals: dict[Variable, int]
-  sts: set[ShapeTracker] = field(default_factory=set)
-  bufs: list[UOp] = field(default_factory=list)
-
 def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
   if (st:=unwrap(x.st)) in ctx.sts: return None
   st, var_vals = st.simplify().unbind()
@@ -203,222 +407,7 @@ if CAPTURE_PROCESS_REPLAY:
   def save_process_replay() -> None:
     for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
 
-# **** Schedule grouping
-
-def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
-def uval(u:UOp) -> UOp:
-  assert is_scheduled(u), f"must be a scheduled op {u}"
-  return u.src[1]
-
-def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp], realizes:dict[UOp, UOp],
-                    reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
-  """recursively search the uop for groupable children, realize the UOp if a child can't group"""
-  if (tr, st) in cache: return
-  cache.setdefault((tr, st))
-  rsize = unwrap(allbufs[r].st).size
-  if tr in realizes and tr is not r:
-    # can only fuse contiguous
-    # max one reduceop per kernel
-    if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
-    return group.setdefault(tr)
-  for tr_next in children[tr]:
-    # max one reduceop per kernel
-    if (tr_next_uop:=uval(allbufs[tr_next]).base).op is Ops.REDUCE_AXIS: return group.setdefault(r)
-    # can only fuse contiguous
-    if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
-    recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
-
-def group_realizes(ctx:ScheduleContext) -> None:
-  """search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop"""
-  # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
-  reduce_for_op: dict[UOp, UOp] = {}
-  double_reduces: list[UOp] = []
-  for r, r_uop in ctx.allbufs.items():
-    if (r_uop:=uval(r_uop)).op is not Ops.REDUCE_AXIS: continue
-    if FUSE_CONV_BW and is_scheduled((x:=r_uop.src[0]).base) and uval(x.base).op is r_uop.op and x.base is not x: double_reduces.append(r)
-    if r in ctx.realizes: continue
-    group: dict[UOp, None] = {}
-    recursive_group(r, unwrap(r_uop.st), r, ctx.children, ctx.allbufs, ctx.realizes, reduce_for_op, group, cache={})
-    # max one reduceop per kernel
-    can_chase = all(tr not in reduce_for_op for tr in group)
-    # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
-    forced_realize = r in group
-    # can only have one output
-    if not forced_realize and len(group) > 1: forced_realize = True
-    # can only fuse assign if no other assign_target is used in the kernel
-    if not forced_realize and any(x in ctx.assigns for x in group):
-      parents = deque((r, *group))
-      while parents and not forced_realize:
-        if (p_uop:=ctx.allbufs.get(p:=parents.pop())) is None: continue
-        if (p_uop:=uval(p_uop)).op is Ops.ASSIGN and p not in group: forced_realize, can_chase = True, False
-        if p in ctx.realizes: continue
-        parents.extend([x.base.buf_uop for x in p_uop.src if x.base.is_realized or (x.base.op is Ops.VIEW and len(x.base.src) != 0)])
-    if forced_realize or not group:
-      tr = r
-      if can_chase:
-        # can chase this down to contiguous children
-        st = unwrap(r_uop.st)
-        while len(ctx.children[tr]) == 1:
-          tr_next_uop = uval(ctx.allbufs[(tr_next:=next(iter(ctx.children[tr])))])
-          st_childs = dedup([unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop is tr])
-          if len(st_childs) > 1: break
-          if st.size != st_childs[0].size: break
-          st = st + st_childs[0]
-          if not st.contiguous or tr_next_uop.op is Ops.REDUCE_AXIS: break
-          tr = tr_next
-        # don't cast to higher size before store (tr cannot be realized if forced_realize)
-        if (tr_uop:=uval(ctx.allbufs[tr])).op is Ops.CAST and tr_uop.dtype.base.itemsize > tr_uop.src[0].dtype.base.itemsize:
-          tr = tr_uop.src[0].base.buf_uop
-      group = {tr: None}
-      ctx.realizes[tr] = tr
-    reduce_for_op.update((tr, r) for tr in group)
-    if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.CONST:
-      # maybe fuse arange with its children
-      if len(flatten(ctx.children[tr] for tr in group)) != 0:
-        for tr in group: del ctx.realizes[tr]
-  # fuse double reduces with no other child
-  for reduceop in double_reduces:
-    top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
-    if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
-
-# **** Schedule creation and BFS toposort
-
-# ** this is schedule level const folding
-
-def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
-  if not all_int(x.shape): return None
-  # remove reduce on unmasked const
-  prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
-  ret = x.const_arg
-  match reduce.arg[0]:
-    case Ops.ADD: ret *= prshape
-    case Ops.MUL: ret **= prshape
-    case Ops.MAX: pass  # NOTE: Ops.MAX is passthrough
-    case _: return None
-  return reduce.const_like(ret)
-
-def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
-  if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
-def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
-  new_src = list(alu.src)
-  for i,s in enumerate(alu.src):
-    if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
-  if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
-
-sym = symbolic_simple+PatternMatcher([
-  # UOp with size 0 is zero
-  (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
-    and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
-  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
-  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
-  # reduce of size 0 is the identity element
-  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
-   lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
-  # reduce of const is collapsed (TODO: make this a generic rule for stride0)
-  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
-  # COPY(CONST) creates a new CONST on the destination device
-  (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
-  # no COPY to same device, except clone (arg is True)
-  (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
-   lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
-  # remove cast to image when it's already a contiguous image
-  (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)),
-   lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
-  # remove contiguous if we can just view the buffer
-  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
-   lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
-  # contiguous/buffer is already contiguous
-  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER)),)), lambda root: root.src[0]),
-  # support for using a contiguous permuted view instead of the parent view if one exists
-  (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
-  (UPat(GroupOp.ALU, name="alu"), replace_contiguous),
-  # remove CONST/BIND/BUFFER from SINK
-  (UPat(Ops.SINK, name="root"),
-   lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
-    if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
-])
-
-# ** this decides which ops get realized
-
-class UPatScheduled(UPat):
-  def __init__(self, *args, **kwargs):
-    super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
-
-def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
-
-def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
-  st = unwrap(view.st)
-  # fold simple pads
-  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
-    return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src)
-  # early realize before expand
-  if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
-  # otherwise safety check pads
-  return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)
-
-def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) -> UOp|None:
-  if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(x.base).op is Ops.COPY: return None
-  del ctx.realizes[b]
-  return x.view(unwrap(view.st))
-
-def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
-  if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
-  buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
-  return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
-
-do_realize = PatternMatcher([
-  # always realize SINK parents
-  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
-  # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
-  (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
-  # realize before expand or unsafe pad ops
-  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
-  # don't realize image to image casts
-  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="x"),), dtype=dtypes.float),)),
-   fold_img_cast),
-  # realize before COPY or BUFFER_VIEW
-  (UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
-  (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
-  # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
-  (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
-])
-
-# **** rewrite VIEW into LOAD/STORE or fuse the underlying UOp
-
-def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
-  # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
-  return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
-
-def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
-  if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
-  if b not in ctx.realizes: return x  # collapse BUFFER
-  ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
-  return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
-
-break_sched = PatternMatcher([
-  # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
-  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
-  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
-])
-
-# **** Schedule context builder
-
-def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
-  ctx.allbufs[buf_uop] = view
-  if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
-  for x in op.base.src:
-    if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
-create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
-
-# **** movement ops
-
-remove_movement_ops = merge_views+PatternMatcher([
-  # NOTE: movement ops are always applied to base
-  (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
-  # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
-  (UPat(Ops.VIEW, name="view"),
-   lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
-])
+# **** schedule creation and toposort
 
 @track_rewrites(named=True)
 def
create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]: @@ -434,15 +423,13 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va elif v.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v # we group the rest of UOps into ScheduleItems - rev_tensor_map: dict[UOp, list[UOp]] = {} - for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k) - # add BUFFER uops - sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={}) - # add realizes - sink = graph_rewrite(sink, do_realize+create_ctx, ctx) - # group realizes into kernels - group_realizes(ctx) - graph_rewrite(sink, break_sched, ctx) + buffer_map: dict[UOp, UOp] = {} + sink = add_buffers(tensor_map[big_sink], buffer_map, cache={}) + # get realizes + buf_tensors: dict[UOp, list[UOp]] = {} + for k,v in tensor_map.items(): + if (b:=buffer_map.get(v)) is not None: buf_tensors.setdefault(b, []).append(k) + realize_map = group_realizes(sink, ctx:=ScheduleContext(buf_tensors)) # TODO: this should be the break between the "grouper" and the "linearizer" # here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign) @@ -451,11 +438,11 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va # create schedule items + map buffers to realized tensors prescheduled: list[ScheduleItem] = [] var_vals: dict[Variable, int] = {} - for buf_uop,store in ctx.realizes.items(): + for buf_uop,store in realize_map.items(): assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}" prescheduled.append(schedule_uop(store.sink(), ctx, var_vals)) # can only schedule once - for tensor_uop in ctx.tensor_uops[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st)) + for tensor_uop in buf_tensors[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st)) # increment refcount for this buffer buf_uop.buffer.ref(1) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 1327537e19..f193bd9bfd 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -93,7 +93,7 @@ class MathTrait(SimpleMathTrait): # the order of these Ops controls the order of the toposort class Ops(FastEnum): # uops that aren't rendered - SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); PRELOAD = auto() # noqa: E702 + SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); PRELOAD = auto(); KERNEL = auto() # noqa: E702 # TODO: empty continues to exist because of tensor EMPTY = auto() diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py index 2baf572382..463799f305 100644 --- a/tinygrad/runtime/ops_clang.py +++ b/tinygrad/runtime/ops_clang.py @@ -1,25 +1,9 @@ -import platform, tempfile, pathlib, subprocess, sys -from tinygrad.helpers import cpu_objdump, capstone_flatdump +import platform, subprocess, sys +from tinygrad.helpers import capstone_flatdump from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram from tinygrad.runtime.support.elf import jit_loader from tinygrad.renderer.cstyle import ClangRenderer -# Used by ops_dsp.py -class ClangCompiler(Compiler): - def __init__(self, cachekey="compile_clang", args:list[str]|None=None, objdump_tool='objdump'): - self.args = ['-march=native'] if args is None else args - self.objdump_tool = objdump_tool - super().__init__(cachekey) - - def compile(self, src:str) -> bytes: - # TODO: remove file 
write. sadly clang doesn't like the use of /dev/stdout here - with tempfile.NamedTemporaryFile(delete=True) as output_file: - subprocess.check_output(['clang', '-shared', *self.args, '-O2', '-Wall', '-Werror', '-x', 'c', '-fPIC', '-ffreestanding', '-nostdlib', - '-', '-o', str(output_file.name)], input=src.encode('utf-8')) - return pathlib.Path(output_file.name).read_bytes() - - def disassemble(self, lib:bytes): return cpu_objdump(lib, self.objdump_tool) - class ClangJITCompiler(Compiler): def __init__(self, cachekey="compile_clang_jit"): super().__init__(cachekey) diff --git a/tinygrad/runtime/ops_dsp.py b/tinygrad/runtime/ops_dsp.py index 7cac17c6c1..8813bafa45 100644 --- a/tinygrad/runtime/ops_dsp.py +++ b/tinygrad/runtime/ops_dsp.py @@ -1,12 +1,11 @@ from __future__ import annotations from typing import Tuple, Any, List -import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys +import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess assert sys.platform != 'win32' -from tinygrad.device import BufferSpec, Compiled, Allocator +from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler from tinygrad.dtype import dtypes, DType, PtrDType from tinygrad.ops import Ops, UOp -from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv -from tinygrad.runtime.ops_clang import ClangCompiler +from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv, cpu_objdump from tinygrad.renderer.cstyle import ClangRenderer from tinygrad.runtime.autogen import libc, qcom_dsp if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import @@ -91,10 +90,23 @@ class DSPAllocator(Allocator): def _copyout(self, dest:memoryview, src:DSPBuffer): ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes) def _offset(self, buf, size:int, offset:int): return DSPBuffer(buf.va_addr+offset, size, buf.share_info, buf.offset+offset) -class DSPDevice(Compiled): - def __init__(self, device:str=""): - self.ion_fd = os.open('/dev/ion', os.O_RDONLY) +class ClangCompiler(Compiler): + def __init__(self, cachekey="compile_clang", args:list[str]|None=None, objdump_tool='objdump'): + self.args = ['-march=native'] if args is None else args + self.objdump_tool = objdump_tool + super().__init__(cachekey) + def compile(self, src:str) -> bytes: + # TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here + with tempfile.NamedTemporaryFile(delete=True) as output_file: + subprocess.check_output(['clang', '-shared', *self.args, '-O2', '-Wall', '-Werror', '-x', 'c', '-fPIC', '-ffreestanding', '-nostdlib', + '-', '-o', str(output_file.name)], input=src.encode('utf-8')) + return pathlib.Path(output_file.name).read_bytes() + + def disassemble(self, lib:bytes): return cpu_objdump(lib, self.objdump_tool) + +class DSPCompiler(ClangCompiler): + def __init__(self): # Generate link script to pass into clang. Aligning all used sections to 4k fixes invoke problem. 
sections = ['hash', 'text', 'rela.plt', 'got', 'got.plt', 'dynamic', 'dynsym', 'dynstr', 'plt', 'data', 'bss'] sections_link = '\n'.join([f'.{n} : ALIGN(4096) {{ *(.{n}) }}' for n in sections]) @@ -103,15 +115,19 @@ class DSPDevice(Compiled): self.link_ld.flush() compiler_args = ["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib", "-mhvx=v65", "-mhvx-length=128b", f"-T{self.link_ld.name}"] - super().__init__(device, DSPAllocator(self), DSPRenderer(), - ClangCompiler("compile_dsp", args=compiler_args, objdump_tool='llvm-objdump'), functools.partial(DSPProgram, self)) + return super().__init__("compile_dsp", args=compiler_args, objdump_tool='llvm-objdump') + +class DSPDevice(Compiled): + def __init__(self, device:str=""): + self.ion_fd = os.open('/dev/ion', os.O_RDONLY) + super().__init__(device, DSPAllocator(self), DSPRenderer(), DSPCompiler(), functools.partial(DSPProgram, self)) fastrpc_shell = memoryview(bytearray(pathlib.Path('/dsp/cdsp/fastrpc_shell_3').read_bytes())) self.shell_buf = self.allocator.alloc(round_up(fastrpc_shell.nbytes, 0x1000), BufferSpec(nolru=True)) ctypes.memmove(self.shell_buf.va_addr, mv_address(fastrpc_shell), fastrpc_shell.nbytes) self.init_dsp() - RPCListner(self).start() + RPCListener(self).start() def open_lib(self, lib): self.binded_lib, self.binded_lib_off = lib, 0 @@ -149,7 +165,7 @@ class DSPDevice(Compiled): qcom_dsp.FASTRPC_IOCTL_INIT(self.rpc_fd, flags=0x1, file=self.shell_buf.va_addr, filelen=self.shell_buf.size, filefd=self.shell_buf.share_info.fd) qcom_dsp.FASTRPC_IOCTL_INVOKE(self.rpc_fd, handle=3, sc=rpc_sc(method=3, ins=0, outs=0)) -class RPCListner(threading.Thread): +class RPCListener(threading.Thread): def __init__(self, device:DSPDevice): super().__init__() self.device, self.daemon = device, True diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index 0dff0e4de3..64583a5404 100644 --- a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -11,12 +11,15 @@ def expect(x, err, ret=None): if x: raise RuntimeError(llvm.string_cast(err.contents) if not isinstance(err, str) else err) return ret -HOST_ARCH = {'arm64': 'AArch64', 'aarch64': 'AArch64', 'x86_64': 'X86', 'AMD64': 'X86'}[platform.machine()] -HOST_TRIPLE = {'AArch64': 'aarch64', 'X86': 'x86_64'}[HOST_ARCH] -REQUIRED_COMPONENTS = ['Target', 'TargetInfo', 'TargetMC', 'AsmPrinter'] - class LLVMCompiler(Compiler): - def __init__(self, target_machine, opt): + def __init__(self, host_arch:str, opt:bool): + for component in ['Target', 'TargetInfo', 'TargetMC', 'AsmPrinter']: getattr(llvm, f'LLVMInitialize{host_arch}{component}')() + triple = ({'AArch64': 'aarch64', 'X86': 'x86_64'}[host_arch]+'-none-unknown-elf').encode() + + target = expect(llvm.LLVMGetTargetFromTriple(triple, ctypes.pointer(tgt:=llvm.LLVMTargetRef()), err:=cerr()), err, tgt) + target_machine = llvm.LLVMCreateTargetMachine(target, triple, b'', b'+reserve-x18' if platform.machine() == 'arm64' else b'', + llvm.LLVMCodeGenLevelDefault, llvm.LLVMRelocPIC, llvm.LLVMCodeModelDefault) + self.pbo = llvm.LLVMCreatePassBuilderOptions() if opt: self.passes = b'default' @@ -48,14 +51,5 @@ class LLVMCompiler(Compiler): class LLVMDevice(Compiled): def __init__(self, device:str): - for component in REQUIRED_COMPONENTS: - getattr(llvm, f'LLVMInitialize{HOST_ARCH}{component}')() - - triple = f'{HOST_TRIPLE}-none-unknown-elf'.encode() - target = expect(llvm.LLVMGetTargetFromTriple(triple, ctypes.pointer(tgt:=llvm.LLVMTargetRef()), err:=cerr()), err, tgt) - features = b'+reserve-x18' if 
platform.machine() == 'arm64' else b'' - target_machine = llvm.LLVMCreateTargetMachine(target, triple, b'', features, llvm.LLVMCodeGenLevelDefault, llvm.LLVMRelocPIC, - llvm.LLVMCodeModelDefault) - - super().__init__(device, MallocAllocator, LLVMRenderer('win64cc' if sys.platform == 'win32' else None), - LLVMCompiler(target_machine, getenv("LLVMOPT")), CPUProgram) + compiler = LLVMCompiler({'arm64': 'AArch64', 'aarch64': 'AArch64', 'x86_64': 'X86', 'AMD64': 'X86'}[platform.machine()], bool(getenv("LLVMOPT"))) + super().__init__(device, MallocAllocator, LLVMRenderer('win64cc' if sys.platform == 'win32' else None), compiler, CPUProgram) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 8d048eb431..de142aa06f 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3314,10 +3314,10 @@ class Tensor(SimpleMathTrait): if not base.is_floating_point(): raise RuntimeError("base needs to be float") # start with b ** e = exp(e * log(b)) ret = base.abs().log().mul(exponent).exp() - # correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent) + # correct sign of negative base with odd exponent negative_base = (base < 0).detach().where(1, 0) # 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent - correct_sign = 1 + negative_base * ((exponent * math.pi).cos() - 1) + correct_sign = (exponent.int()%2==0).where(1, 1-2*negative_base) # inject nan for negative base and non-integer exponent inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1) # apply correct_sign inject_nan, and fix 0 ** 0 = 1 diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 1e75338adc..a7163213d1 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -12,7 +12,7 @@ from tinygrad.dtype import dtypes uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B", Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", - Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", + Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0"}
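
The create_schedule_with_vars hunk above replaces the ScheduleContext-driven plumbing (rev_tensor_map, the create_ctx/break_sched rewrites) with an explicit buffer_map filled in by add_buffers, then inverts tensor_map through it so every realized buffer can be traced back to the tensor UOps it materializes. A minimal sketch of that inversion, with strings standing in for the UOp keys the real dicts hold:

# sketch of the buf_tensors inversion in the new create_schedule_with_vars;
# strings stand in for UOps (the real maps are dict[UOp, UOp])
tensor_map = {"t0": "v0", "t1": "v0", "t2": "v1", "t3": "v2"}  # tensor UOp -> rewritten UOp
buffer_map = {"v0": "b0", "v1": "b1"}                          # rewritten UOp -> BUFFER UOp ("v2" got none)
buf_tensors: dict[str, list[str]] = {}
for k, v in tensor_map.items():
    if (b := buffer_map.get(v)) is not None:
        buf_tensors.setdefault(b, []).append(k)
assert buf_tensors == {"b0": ["t0", "t1"], "b1": ["t2"]}

group_realizes then consumes the sink plus ScheduleContext(buf_tensors) and hands back realize_map, so the STORE loop that follows no longer reaches into ctx.realizes directly.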
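
With ClangCompiler relocated into tinygrad/runtime/ops_dsp.py and DSPCompiler now owning the hexagon flags plus the 4k-aligned linker script, the compiler can be constructed standalone, without going through DSPDevice (which still needs /dev/ion). A hedged usage sketch: the kernel source below is a made-up placeholder, and it assumes a non-Windows host with a hexagon-capable clang and lld on PATH:

from tinygrad.runtime.ops_dsp import DSPCompiler

compiler = DSPCompiler()  # wires up --target=hexagon, -mhvx=v65 and the generated -T link script
lib = compiler.compile("void test(unsigned char *out) { out[0] = 42; }")  # placeholder C source
compiler.disassemble(lib)  # dumps the ELF via llvm-objdump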
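
The ops_llvm.py change folds target selection into LLVMCompiler itself: the constructor now initializes the required LLVM components, resolves the triple from host_arch, and creates the target machine, so LLVMDevice shrinks to two lines. Building the compiler directly then looks like this sketch (assuming the libLLVM bindings are available; the arch map is the same one LLVMDevice passes in):

import platform
from tinygrad.runtime.ops_llvm import LLVMCompiler

host_arch = {'arm64': 'AArch64', 'aarch64': 'AArch64', 'x86_64': 'X86', 'AMD64': 'X86'}[platform.machine()]
compiler = LLVMCompiler(host_arch, opt=False)  # opt mirrors bool(getenv("LLVMOPT")) in LLVMDevice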
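
The tensor.py hunk swaps the trigonometric parity trick for an integer one: instead of recovering the exponent's oddness from cos(pi*e), which is only approximately ±1 in floating point, the sign factor is now exact, 1 for even integer exponents and 1-2*negative_base for odd ones. A plain-Python check of that formula against pow's reference semantics (scalar stand-ins for the Tensor ops, ignoring the NaN-injection path for non-integer exponents):

import math

def correct_sign(base: float, exponent: int) -> int:
    negative_base = 1 if base < 0 else 0           # mirrors (base < 0).where(1, 0)
    return 1 if int(exponent) % 2 == 0 else 1 - 2 * negative_base

for base, exp in [(-2.0, 2), (-2.0, 3), (2.0, 3), (-3.0, 5)]:
    ret = math.exp(exp * math.log(abs(base)))      # |base| ** exp, as in the surrounding code
    assert math.isclose(ret * correct_sign(base, exp), base ** exp)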