From d0d39885c386d730da29f6a9cc1fdac589319b9e Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 14 Aug 2025 16:57:21 -0700 Subject: [PATCH] onnx in tinygrad (#11675) --- .github/workflows/test.yml | 8 +- extra/onnx.py | 1261 ------------------- extra/onnx_helpers.py | 3 +- test/external/external_test_onnx_runner.py | 3 +- tinygrad/frontend/onnx.py | 1267 +++++++++++++++++++- 5 files changed, 1266 insertions(+), 1276 deletions(-) delete mode 100644 extra/onnx.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b8abbc8a85..27a93773d2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -329,15 +329,13 @@ jobs: run: | pip3 install --upgrade --force-reinstall ruff==0.11.0 python3 -m ruff check . - python3 -m ruff check extra/onnx.py python3 -m ruff check examples/mlperf/ --ignore E501 - name: Lint tinygrad with pylint - run: python -m pylint tinygrad/ extra/onnx.py + run: python -m pylint tinygrad/ - name: Run mypy run: | python -m mypy --strict-equality --lineprecision-report . cat lineprecision.txt - python -m mypy --strict-equality extra/onnx.py unittest: name: Unit Tests @@ -375,8 +373,8 @@ jobs: PYTHONPATH=. python extra/optimization/extract_dataset.py gzip -c /tmp/sops > extra/datasets/sops.gz DEBUG=1 MIN_ASTS=1 PYTHONPATH=. python extra/optimization/get_action_space.py - - name: Repo line count < 16000 lines - run: MAX_LINE_COUNT=16000 python sz.py + - name: Repo line count < 17000 lines + run: MAX_LINE_COUNT=17000 python sz.py fuzzing: name: Fuzzing diff --git a/extra/onnx.py b/extra/onnx.py deleted file mode 100644 index bcead3994d..0000000000 --- a/extra/onnx.py +++ /dev/null @@ -1,1261 +0,0 @@ -# pylint: disable=possibly-unused-variable -from typing import Any, Sequence, cast, Literal, NamedTuple, Generator -import dataclasses, functools, io, math, types, warnings, pathlib, sys, os, struct, enum -from io import BufferedReader -from tinygrad.nn.state import TensorIO -from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr -from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort, is_numpy_ndarray, get_single_element, polyN -from tinygrad.dtype import DType, ConstType, dtypes, _from_np_dtype -from tinygrad.device import is_dtype_supported, Device - -# ***** protobuf definitions ****** -class WireType(enum.IntEnum): - """ - Protocol Buffer wire types for decoding fields. - Reference: https://github.com/protocolbuffers/protobuf/blob/main/python/google/protobuf/internal/wire_format.py#L24-L29 - """ - VARINT = 0; FIXED64 = 1; LENGTH_DELIMITED = 2; START_GROUP = 3; END_GROUP = 4; FIXED32 = 5 # noqa: E702 - -class AttributeType(enum.IntEnum): - """ - ONNX attribute type identifiers. - Reference: https://github.com/onnx/onnx/blob/rel-1.18.0/onnx/onnx.proto3#L128-L145 - """ - FLOAT = 1; INT = 2; STRING = 3; TENSOR = 4; FLOATS = 6; INTS = 7; STRINGS = 8 # noqa: E702 - - def to_field_name(self) -> str: return {1: "f", 2: "i", 3: "s", 4: "t", 6: "floats", 7: "ints", 8: "strings"}[self.value] - -class OnnxDataType(enum.IntEnum): - """ - ONNX tensor data type identifiers. - Reference: https://github.com/onnx/onnx/blob/rel-1.18.0/onnx/onnx.proto3#L500-L544 - """ - FLOAT = 1; UINT8 = 2; INT8 = 3; UINT16 = 4; INT16 = 5; INT32 = 6; INT64 = 7; BOOL = 9; FLOAT16 = 10; DOUBLE = 11; UINT32 = 12 # noqa: E702 - UINT64 = 13; BFLOAT16 = 16 # noqa: E702 - - def to_dtype(self) -> DType: return dtypes.fields()[self.name.lower()] - -def dtype_fallback(dtype: DType, fallback_context: str) -> DType: - if is_dtype_supported(dtype): return dtype - default_dtype = dtypes.default_int if dtypes.is_int(dtype) else dtypes.default_float - warnings.warn(f"dtype {dtype} on {Device.DEFAULT} from {fallback_context} is not supported, falling back to {default_dtype}") - assert is_dtype_supported(default_dtype), f"dtype {default_dtype} must be supported on {Device.DEFAULT}" - return default_dtype - -# ***** onnx spec definitions ***** -class Domain(enum.Enum): - ONNX = "ai.onnx" - ONNX_ML = "ai.onnx.ml" - AI_ONNX_TRAINING = "ai.onnx.training" - AI_ONNX_PREVIEW_TRAINING = "ai.onnx.preview.training" - MICROSOFT_CONTRIB_OPS = "com.microsoft" - MICROSOFT_NCHWC = "com.microsoft.nchwc" - MICROSOFT_EXPERIMENTAL = "com.microsoft.experimental" - PYTORCH_ATEN = "org.pytorch.aten" - @classmethod - def from_onnx(cls, domain: str | None) -> "Domain": return cls.ONNX if domain is None or domain == "" else cls(domain) - -class OpSetId(NamedTuple): - domain: Domain - version: int - -@dataclasses.dataclass(frozen=True) -class OnnxValue: - shape: tuple[str|int, ...] - dtype: DType - is_optional: bool - is_sequence: bool - -@dataclasses.dataclass(frozen=True) -class OnnxNode: - op: str - opset_id: OpSetId - inputs: tuple[str, ...] - outputs: tuple[str, ...] - opts: dict[str, Any] - -# ***** protobuf parsing ****** -class PBBufferedReader(BufferedReader): - def __init__(self, tensor: Tensor): - assert tensor.dtype == dtypes.uint8, tensor - super().__init__(TensorIO(tensor)) - self.len = tensor.nbytes() - - def decode_varint(self) -> int: - """Reference: https://protobuf.dev/programming-guides/encoding/#varints""" - result = 0 - shift = 0 - while True: - data = self.read(1) - if data == b"": raise EOFError("decode_varint EOF") - result |= (data[0] & 0x7F) << shift - if not (data[0] & 0x80): return result - shift += 7 - if shift >= 70: raise ValueError("Varint too long") - - def read_delimited(self, use_tensor=False): - str_len = self.decode_varint() - if not use_tensor: return self.read(str_len) - raw = self.raw - assert isinstance(raw, TensorIO) - res = raw._tensor[self.tell():(self.tell()+str_len)] - self.seek(str_len, os.SEEK_CUR) - return res - def read_string(self) -> str: return self.read_delimited().decode("utf-8") - def read_bytes(self) -> Tensor: return self.read_delimited(use_tensor=True) - def read_float(self) -> float: return struct.unpack(" Tensor: return self.read_delimited(use_tensor=True) - def read_int64(self) -> int: - val = self.decode_varint() - return val - 2**64 if val & (1 << 63) else val - def read_packed_int64s(self) -> list[int]: - total_bytes_len = self.decode_varint() - old_pos = self.tell() - values = [] - # need copy here because packed ints are varint - while self.tell() < total_bytes_len + old_pos: values.append(self.read_int64()) - return values - - def skip_field(self, wire_type: WireType) -> None: - """Skip a field based on its wire type.""" - match wire_type: - case WireType.VARINT: self.decode_varint() - case WireType.FIXED64: self.seek(8, os.SEEK_CUR) - case WireType.FIXED32: self.seek(4, os.SEEK_CUR) - case WireType.LENGTH_DELIMITED: self.seek(self.decode_varint(), os.SEEK_CUR) - case _: raise ValueError(f"Unknown wire type: {wire_type}") - -class OnnxPBParser: - """ - ONNX protobuf parser. - Reference: https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3 - """ - def __init__(self, inp: Tensor|str|pathlib.Path, load_external_data: bool=True): - self.file_path: pathlib.Path|None = None - self.load_external_data = load_external_data - if not isinstance(inp, Tensor): - self.file_path = pathlib.Path(inp) - self.tensor = Tensor(self.file_path) - else: self.tensor = inp - self.reader = PBBufferedReader(self.tensor) - - def parse(self) -> dict: - """Parses the ONNX model into a nested dictionary. """ - return self._parse_ModelProto() - - def _parse_message(self, end_pos: int) -> Generator[tuple[int, WireType], None, None]: - while self.reader.tell() < end_pos: - tag = self.reader.decode_varint() - yield tag >> 3, WireType(tag & 0x07) - - def _decode_end_pos(self) -> int: - str_len = self.reader.decode_varint() - start_pos = self.reader.tell() - return start_pos + str_len - - def _parse_ModelProto(self) -> dict: - """Entry point for parsing the ONNX model.""" - obj: dict[str, Any] = {"opset_import": []} - for fid, wire_type in self._parse_message(self.reader.len): - match fid: - case 4: obj["domain"] = self.reader.read_string() - case 5: obj["model_version"] = self.reader.read_int64() - case 7: obj["graph"] = self._parse_GraphProto() - case 8: obj["opset_import"].append(self._parse_OperatorSetIdProto()) - case _: self.reader.skip_field(wire_type) - - # update opset version - opset_imports = {Domain.from_onnx(x.get('domain')):x.get('version', 1) for x in obj["opset_import"]} - for n in obj["graph"]["node"]: - n_ = n["parsed_node"] - n["parsed_node"] = OnnxNode(n_.op, OpSetId(n_.opset_id.domain, opset_imports.get(n_.opset_id.domain, 1)), n_.inputs, n_.outputs, n_.opts) - return obj - - def _parse_GraphProto(self) -> dict: - obj: dict[str, Any] = {"node": [], "initializer": [], "input": [], "output": []} - for fid, wire_type in self._parse_message(self._decode_end_pos()): - match fid: - case 1: obj["node"].append(self._parse_NodeProto()) - case 2: obj["name"] = self.reader.read_string() - case 5: obj["initializer"].append(self._parse_TensorProto()) - case 11: obj["input"].append(self._parse_ValueInfoProto()) - case 12: obj["output"].append(self._parse_ValueInfoProto()) - case _: self.reader.skip_field(wire_type) - return obj - - def _parse_NodeProto(self) -> dict: - obj: dict[str, Any] = {"input": [], "output": [], "attribute": [], "domain": None} - for fid, wire_type in self._parse_message(self._decode_end_pos()): - match fid: - case 1: obj["input"].append(self.reader.read_string()) - case 2: obj["output"].append(self.reader.read_string()) - case 3: obj["name"] = self.reader.read_string() - case 4: obj["op_type"] = self.reader.read_string() - case 5: obj["attribute"].append(self._parse_AttributeProto()) - case 6: obj["doc_string"] = self.reader.read_string() - case 7: obj["domain"] = self.reader.read_string() - case _: self.reader.skip_field(wire_type) - - # parse node - attributes = {attr_dict["name"]: attr_dict[AttributeType(attr_dict["type"]).to_field_name()] for attr_dict in obj["attribute"]} - opset_id = OpSetId(Domain.from_onnx(obj.get('domain')), 1) # default version, to be updated later in _parse_ModelProto - obj["parsed_node"] = OnnxNode(obj["op_type"], opset_id, tuple(obj["input"]), tuple(obj["output"]), attributes) - return obj - - def _parse_TensorProto(self) -> dict: - obj: dict[str, Any] = {"dims": []} - for fid, wire_type in self._parse_message(self._decode_end_pos()): - match fid: - case 1: obj["dims"].append(self.reader.read_int64()) - case 2: obj["data_type"] = self.reader.read_int64() - case 4: obj["float_data"] = self.reader.read_packed_floats() - case 5: obj["int32_data"] = self.reader.read_packed_int64s() - case 7: obj["int64_data"] = self.reader.read_packed_int64s() - case 8: obj["name"] = self.reader.read_string() - case 9: obj["raw_data"] = self.reader.read_bytes() - case 10: obj["double_data"] = self.reader.read_packed_floats() - case 11: obj["uint64_data"] = self.reader.read_packed_int64s() - case 13: obj.setdefault("external_data", []).append(self._parse_StringStringEntryProto()) - case 14: obj["data_location"] = self.reader.read_int64() - case _: self.reader.skip_field(wire_type) - - # load external data - if self.load_external_data and obj.get("data_location", 0) == 1: - if "external_data" not in obj: raise ValueError("no external_data") - location, length, offset = None, None, 0 - for kv in obj["external_data"]: - if kv["key"] == "location": location = kv["value"] - elif kv["key"] == "offset": offset = int(kv["value"]) - elif kv["key"] == "length": length = int(kv["value"]) - if location is None: raise ValueError("no location in external_data") - - if self.file_path is None: - if isinstance(self.tensor.device, str) and self.tensor.device.startswith("DISK:"): - self.file_path = pathlib.Path(self.tensor.device[5:]) - else: raise ValueError("onnx external_data needs the origin file path, try passing onnx file path to onnx_load") - ext_path = self.file_path.parent.joinpath(location) - if not ext_path.exists(): raise FileNotFoundError(f"external location not exists: {ext_path}") - - ext_tensor = Tensor(ext_path) - obj["raw_data"] = ext_tensor[offset:offset+length] if length is not None else ext_tensor[offset:] - obj["data_location"] = 0 - - # parse tensor - to_dtype = dtype_fallback(true_dtype := OnnxDataType(obj['data_type']).to_dtype(), "buffer parse") - shape = tuple(obj['dims']) - present_fields = [field for field in ['float_data', 'int32_data', 'int64_data', 'double_data', 'uint64_data', 'raw_data'] if field in obj] - assert len(present_fields) == 1, f"only 1 data field is allowed from {obj=}" - data = obj[present_fields[0]] - if not isinstance(data, Tensor): - obj["parsed_tensor"] = Tensor(data, dtype=to_dtype).reshape(shape) - return obj - assert isinstance(data, Tensor) and data.dtype == dtypes.uint8, data - data = data.bitcast(true_dtype).reshape(shape) - data = data.to(Device.DEFAULT) if true_dtype is to_dtype else data.to("cpu").cast(to_dtype).to(Device.DEFAULT) - # const folding - if shape == (): - if data.dtype == dtypes.float16 and sys.version_info < (3, 12): data = data.cast(dtypes.float32) - data = Tensor(data.item(), dtype=to_dtype).reshape(shape) - obj["parsed_tensor"] = data - return obj - - def _parse_AttributeProto(self) -> dict: - obj: dict[str, Any] = {"floats": [], "ints": [], "strings": []} - for fid, wire_type in self._parse_message(self._decode_end_pos()): - match fid: - case 1: obj["name"] = self.reader.read_string() - case 2: obj["f"] = self.reader.read_float() - case 3: obj["i"] = self.reader.read_int64() - case 4: obj["s"] = self.reader.read_bytes().data().tobytes().decode("utf8") - case 5: obj["t"] = self._parse_TensorProto()['parsed_tensor'] - case 7: obj["floats"].append(self.reader.read_float()) - case 8: obj["ints"].append(self.reader.read_int64()) - case 9: obj["strings"].append(self.reader.read_bytes().data().tobytes().decode("utf8")) - case 20: obj["type"] = self.reader.read_int64() - case _: self.reader.skip_field(wire_type) - obj["floats"], obj["ints"], obj["strings"] = tuple(obj["floats"]), tuple(obj["ints"]), tuple(obj["strings"]) - return obj - - def _parse_ValueInfoProto(self) -> dict: - obj: dict[str, Any] = {} - for fid, wire_type in self._parse_message(self._decode_end_pos()): - match fid: - case 1: obj["name"] = self.reader.read_string() - case 2: obj["type"] = self._parse_TypeProto() - case _: self.reader.skip_field(wire_type) - - # parse type - if "type" not in obj: return {**obj, "parsed_type": None} - type_obj = obj["type"] - if is_optional := "optional_type" in type_obj: type_obj = type_obj["optional_type"]["elem_type"] - if is_sequence := "sequence_type" in type_obj: type_obj = type_obj["sequence_type"]["elem_type"] - assert "tensor_type" in type_obj, type_obj - shape_dims = type_obj['tensor_type'].get('shape', {}).get('dim', []) - obj['parsed_type'] = OnnxValue(tuple(d.get('dim_param') or d.get('dim_value') for d in shape_dims), - OnnxDataType(type_obj['tensor_type']['elem_type']).to_dtype(), is_optional, is_sequence) - return obj - - def _parse_TypeProto(self) -> dict: - obj: dict[str, Any] = {} - for fid, wire_type in self._parse_message(self._decode_end_pos()): - match fid: - case 1: obj["tensor_type"] = self._parse_TypeProtoTensor() - case 4: obj["sequence_type"] = self._parse_TypeProtoSequence() - case 9: obj["optional_type"] = self._parse_TypeProtoOptional() - case _: self.reader.skip_field(wire_type) - return obj - - def _parse_TypeProtoTensor(self) -> dict: - obj: dict[str, Any] = {} - for fid, wire_type in self._parse_message(self._decode_end_pos()): - match fid: - case 1: obj["elem_type"] = self.reader.read_int64() - case 2: obj["shape"] = self._parse_TensorShapeProto() - case _: self.reader.skip_field(wire_type) - return obj - - def _parse_TypeProtoSequence(self) -> dict: - obj = {} - for fid, wire_type in self._parse_message(self._decode_end_pos()): - match fid: - case 1: obj["elem_type"] = self._parse_TypeProto() - case _: self.reader.skip_field(wire_type) - return obj - - def _parse_TypeProtoOptional(self) -> dict: - obj = {} - for fid, wire_type in self._parse_message(self._decode_end_pos()): - match fid: - case 1: obj["elem_type"] = self._parse_TypeProto() - case _: self.reader.skip_field(wire_type) - return obj - - def _parse_TensorShapeProto(self) -> dict: - obj: dict[str, Any] = {"dim": []} - for fid, wire_type in self._parse_message(self._decode_end_pos()): - match fid: - case 1: obj["dim"].append(self._parse_TensorShapeProtoDimension()) - case _: self.reader.skip_field(wire_type) - return obj - - def _parse_TensorShapeProtoDimension(self) -> dict: - obj: dict[str, Any] = {} - for fid, wire_type in self._parse_message(self._decode_end_pos()): - match fid: - case 1: obj["dim_value"] = self.reader.read_int64() - case 2: obj["dim_param"] = self.reader.read_string() - case _: self.reader.skip_field(wire_type) - return obj - - def _parse_StringStringEntryProto(self) -> dict: - obj: dict[str, Any] = {} - for fid, wire_type in self._parse_message(self._decode_end_pos()): - match fid: - case 1: obj["key"] = self.reader.read_string() - case 2: obj["value"] = self.reader.read_string() - case _: self.reader.skip_field(wire_type) - return obj - - def _parse_OperatorSetIdProto(self) -> dict: - obj: dict[str, Any] = {} - for fid, wire_type in self._parse_message(self._decode_end_pos()): - match fid: - case 1: obj["domain"] = self.reader.read_string() - case 2: obj["version"] = self.reader.read_int64() - case _: self.reader.skip_field(wire_type) - return obj - -# ***** python const ***** -required_input_python_consts: dict[str, tuple[int, ...]] = { - "Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,), - "CumSum": (1,), "TopK": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,), - "ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4), - **{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")}, - **{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")} -} - -cache_misses = 0 -@functools.cache -def _cached_to_python_const(t:Tensor): - if t.dtype == dtypes.uint8: return t.data().tobytes() - if 0 in t.shape: return [] - return t.tolist() - -# Tensor -> python value cache for parameters -def to_python_const(t:Any, op:str, idx:int) -> list[ConstType]|ConstType|bytes: - if idx not in required_input_python_consts.get(op, ()) or not isinstance(t, Tensor): return t - global cache_misses - ret = _cached_to_python_const(t) - if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3: - print(f"Cache miss for {t}") - cache_misses = info.misses - return ret - -# ***** runner ****** -debug = int(getenv("DEBUGONNX", "0")) -limit = int(getenv("ONNXLIMIT", "-1")) -class OnnxRunner: - """ - `OnnxRunner` executes an ONNX model using Tinygrad. - - Args: - model_path: The ONNX model, provided as a file path (a string or Path object) or a Tensor. - """ - def __init__(self, model_path: Tensor | str | pathlib.Path): - model = OnnxPBParser(model_path, load_external_data=True).parse() - graph = model["graph"] - self.is_training = any(n['parsed_node'].opset_id.domain in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in graph["node"]) - self.graph_values = {"": None, **{i["name"]: i["parsed_tensor"] for i in graph["initializer"]}} - self.graph_inputs = {i["name"]: i["parsed_type"] for i in graph["input"] if i["name"] not in self.graph_values} - self.graph_outputs = tuple(o["name"] for o in graph["output"]) - self.graph_nodes = tuple(n["parsed_node"] for n in graph["node"]) - - self.old_training = Tensor.training - Tensor.training = True if self.is_training else False - - self.variable_dims: dict[str, int] = {} - self.onnx_ops = onnx_ops - - def _parse_input(self, name: str, value: Any, spec: OnnxValue): - if spec.is_optional and value is None: return None - if spec.is_sequence: - if not isinstance(value, Sequence): raise RuntimeError(f"input {name} received {value}, expected a sequence type") - sequence = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value] - if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"Shapes for input {name} sequence must be homogeneous") - if not all(t.dtype is spec.dtype for t in sequence): warnings.warn(f"Dtypes for input {name} sequence aren't all {spec.dtype}") - return sequence - dtype = _from_np_dtype(value.dtype) if is_numpy_ndarray(value) else spec.dtype - tensor = Tensor(value, dtype=dtype, requires_grad=self.is_training) if not isinstance(value, Tensor) else value - if tensor.dtype is not spec.dtype: warnings.warn(f"input {name} has mismatch on dtype. Expected {spec.dtype}, received {tensor.dtype}.") - for dim, (onnx_dim, user_dim_input) in enumerate(zip(spec.shape, tensor.shape, strict=True)): - if isinstance(onnx_dim, str): - onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input)) - if user_dim_input != onnx_dim: raise RuntimeError(f"input {name} has mismatch on {dim=}. Expected {onnx_dim}, received {user_dim_input}.") - return tensor - - def _select_op(self, op:str, required_opset:OpSetId) -> types.FunctionType: - if op not in self.onnx_ops: raise NotImplementedError(f"{op=} is not supported") - # return default implementation if no opset_id is specified - if isinstance(impl := self.onnx_ops[op], types.FunctionType): return impl - # match domain and select implementation with latest compatible version - eligible_ops = {impl_opset.version:impl_fxn for impl_opset,impl_fxn in impl.items() - if impl_opset.domain == required_opset.domain and impl_opset.version <= required_opset.version} - if not eligible_ops: raise NotImplementedError(f"{op=} is not supported for domain {required_opset.domain} and version {required_opset.version}") - return eligible_ops[max(eligible_ops.keys())] - - def get_empty_input_data(self, device:str|None=None, dtype:DType|None=None) -> dict[str, Tensor]: - return {name:Tensor.empty(*spec.shape, device=device, dtype=dtype or spec.dtype) for name, spec in self.graph_inputs.items()} - - def to(self, device:str|None): - self.graph_values = {k:v.to(device) if isinstance(v, Tensor) else v for k,v in self.graph_values.items()} - self.graph_nodes = tuple(OnnxNode(n.op, n.opset_id, tuple(n.inputs), tuple(n.outputs), - {k:v.to(device) if isinstance(v, Tensor) else v for k,v in n.opts.items()}) for n in self.graph_nodes) - return self - - def __call__(self, inputs:dict[str, Any], debug=debug): - for name, input_spec in self.graph_inputs.items(): - if name not in inputs: raise RuntimeError(f"Please provide input data for {name}") - self.graph_values[name] = self._parse_input(name, inputs[name], input_spec) - - for num, node in enumerate(self.graph_nodes): - inps = [to_python_const(self.graph_values[name], node.op, i) for i,name in enumerate(node.inputs)] - opts = node.opts - - # provide additional opts - if node.op == "Split" and 'num_outputs' not in opts: opts['num_outputs'] = len(node.outputs) - if node.op == "Gradient": opts['intermediate_tensors'] = self.graph_values - - if debug >= 1: print(f"{num}: op '{node.op}' opt {opts}") - if debug >= 2 and node.inputs: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {i!r}" for x,i in zip(node.inputs, inps))) - ret = self._select_op(node.op, node.opset_id)(*inps, **opts) - ret = ret if isinstance(ret, tuple) else (ret,) - if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{x} - {o!r}" for x,o in zip(node.outputs, ret))) - - self.graph_values.update(dict(zip(node.outputs, ret[:len(node.outputs)], strict=True))) - - if num == limit: - Tensor.training = self.old_training - return {name:self.graph_values[name] for name in node.outputs} - Tensor.training = self.old_training - return {name:self.graph_values[name] for name in self.graph_outputs} - -#################### -##### ONNX OPS ##### -#################### -def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionType]]: - # ***** helper functions ***** - def _resolve_const(x: Sequence[ConstType]|ConstType): return get_single_element(x) if isinstance(x, Sequence) else x - - def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None) - - # (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...) - def _onnx_pads_to_tiny_pads(pads): return tuple(flatten(reversed(list(zip(pads, pads[len(pads)//2:]))))) - - AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"] - # (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right) - def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS): - if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))] - return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))] - - def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS): - if auto_pad == "VALID": return [0]*(len(k_)*2) - i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_)) - if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2) - o_ = [((i - (1 if auto_pad in ("SAME_UPPER", "SAME_LOWER") else k)) // s + 1) for i,k,s in zip(i_, k_, s_)] - return _onnx_pads_to_tiny_pads(_auto_pad([(o-1)*s+k-i for o,i,k,s in zip(o_, i_, k_, s_)], auto_pad)) - - def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype) - - def _prepare_quantize(x:Tensor, scale:Tensor, zero_point:Tensor|int, axis=1, block_size=0): - if axis < 0: axis += x.ndim - # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_quantize_linear.py#L31 - def reshape(val:Tensor): - if val.numel() == 1: return val - if block_size == 0: return val.reshape([val.shape[0] if dim == axis else 1 for dim in range(x.ndim)]) - return val.repeat_interleave(block_size, axis) - return (reshape(scale), reshape(zero_point) if isinstance(zero_point, Tensor) else zero_point) - - def _op_integer(op, inputs:list[Tensor], zero_points:list[Tensor], **opts): - adjusted_inputs = [inp.int() - zp for inp, zp in zip(inputs, zero_points)] - return op(*adjusted_inputs, **opts) - - def _qlinearop_quantized(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts): - # op execution is done in quantized int - out = _op_integer(op, inputs, zero_points, **opts) - assert dtypes.is_int(out.dtype), "quantized op should've done math in int" - out_quantized = (out * prod(scales) / out_scale).round() + out_zero_point - return _clamp_cast(out_quantized, out_zero_point.dtype) - - def _qlinearop_float(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts): - # op execution is done in float32 - dequantized_inputs = [(inp.int() - zp) * scale for inp, zp, scale in zip(inputs, zero_points, scales)] - out = op(*dequantized_inputs, **opts) - assert dtypes.is_float(out.dtype), "op should've done math in float" - out_quantized = (out / out_scale).round() + out_zero_point - return _clamp_cast(out_quantized, out_zero_point.dtype) - - def _onnx_training(input_group_size): - def __decorator(func): - def ___wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs): - R = R.detach() - groups = len(inputs) // input_group_size - ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))] - return tuple(flatten(zip(*ret))) - return ___wrapper - return __decorator - - # ***** Property/Graph Ops ***** - def Identity(x:Tensor): return x - def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, value_floats:list[float]|None=None, - value_int:int|None=None, value_ints:list[int]|None=None, value_string:str|None=None, value_strings:list[str]|None=None): - if value is not None: return value - if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False) - if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False) - if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False) - if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False) - if value_string is not None or value_strings is not None or sparse_value is not None: - raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value') - - def Range(start:float|int|list[float|int], limit:float|int|list[float|int], delta:float|int|list[float|int]): - return Tensor.arange(start=_resolve_const(start), stop=_resolve_const(limit), step=_resolve_const(delta)) - - def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"): - try: import PIL.Image - except ImportError as e: raise ImportError("Pillow must be installed for the ImageDecoder operator") from e - img = PIL.Image.open(io.BytesIO(encoded_stream)) - if pixel_format == "BGR": return Tensor(img.tobytes(), dtype=dtypes.uint8).reshape(*img.size, 3).flip(-1) - if pixel_format == "RGB": return Tensor(img.tobytes(), dtype=dtypes.uint8).reshape(*img.size, 3) - if pixel_format == "Grayscale": return Tensor(img.convert("L").tobytes(), dtype=dtypes.uint8).reshape(*img.size, 1) - raise ValueError(f"pixel_format={pixel_format!r} is not supported.") - - def EyeLike(x:Tensor, dtype:int|None=None, k:int=0): - ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_fallback(OnnxDataType(dtype).to_dtype(), "EyeLike op") if dtype is not None else x.dtype) - return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape)) - - def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0) - def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([]) - def ConstantOfShape(shape:list[int], value:Tensor|None=None): - if value is None: value = Tensor(0, dtype=dtypes.float32) - if shape == [0]: return Tensor([], dtype=value.dtype) - return value.expand(shape) - - def Size(data:Tensor): return data.numel() - def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64) - - # ***** Unary Ops (math) ***** - def Not(x:Tensor): return x.logical_not() - def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None): return x if min is None and max is None else x.clip(min, max) # noqa: A002 # pylint: disable=redefined-builtin - def IsInf(x:Tensor, detect_negative:int=1, detect_positive:int=1): return x.isinf(bool(detect_positive), bool(detect_negative)) - - # ***** Unary Ops (activation) ***** - def softmax_1(x:Tensor, axis:int=1): return x.softmax(axis) - def softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis) - Softmax = {OpSetId(Domain.ONNX, 1):softmax_1, OpSetId(Domain.ONNX, 13):softmax_13} - def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1) - def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf()) - def BiasGelu(x: Tensor, bias: Tensor, approximate: str | None = None) -> Tensor: return Gelu(x + bias, approximate) - def FastGelu(x:Tensor, bias:Tensor|None=None): return (x + bias).gelu() if bias is not None else x.gelu() # this is tanh approximated - def PRelu(X:Tensor, slope:Tensor): return (X > 0).where(X, X * slope) - def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leaky_relu(alpha) - def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0) - def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis) - def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float() - - # ***** Unary Ops (broadcasted) ***** - def Add(x:Tensor,y:Tensor, broadcast=None, axis=None): return x + y - def Sub(x:Tensor|int,y:Tensor): return x - y # some test has input as int - def Div(x:Tensor,y:Tensor): return x.div(y, rounding_mode='trunc' if dtypes.is_int(x.dtype) else None) - def Less(x:Tensor,y:Tensor): return x < y - def LessOrEqual(x:Tensor,y:Tensor): return x <= y - def Greater(x:Tensor,y:Tensor): return x > y - def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y - def Equal(x:Tensor,y:Tensor): return x == y - def And(x:Tensor,y:Tensor): return (x==y).where(x, False) - def Or(x:Tensor,y:Tensor): return (x==y).where(x, True) - def Xor(x:Tensor,y:Tensor): return x.bool().bitwise_xor(y.bool()) - def BitwiseAnd(x:Tensor,y:Tensor): return x & y - def BitwiseOr(x:Tensor,y:Tensor): return x | y - def BitwiseXor(x:Tensor,y:Tensor): return x ^ y - def BitwiseNot(x:Tensor): return ~x - def Mod(x:Tensor,y:Tensor,fmod=0): return x - x.div(y, rounding_mode="trunc") * y if fmod else x % y - - # ***** Casting Ops ***** - # TODO: saturate - def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_fallback(OnnxDataType(to).to_dtype(), "Cast op")) - def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype) - - # ***** Reduce Ops ***** - def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0) - def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0) - def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0) - def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0) - def ReduceMax(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims) - def ReduceMin(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims) - def ReduceSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims) - def ReduceMean(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims) - def ReduceSumSquare(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return ReduceSum(data.square(), axes, keepdims, noop_with_empty_axes) - def ReduceProd(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return data.prod(_axes(axes, noop_with_empty_axes), keepdim=keepdims) - def ReduceL1(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes) - def ReduceL2(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt() - def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log() - def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log() - def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0): - if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64) - return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64) - def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0): - return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index) - - # ***** Movement Ops ***** - def Reshape(data:Tensor, shape:list[int], allowzero:int=0): - return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)]) - def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1) - def Expand(x:Tensor, shape:list[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape))) - def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias) - def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=perm or list(range(x.ndim)[::-1])) - - def Squeeze(data:Tensor, axes:list[int]|None=None): - return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data) - def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data) - - def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats) - def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis) - def Slice(data:Tensor, starts:list[int], ends:list[int], axes:list[int]|None=None, steps:list[int]|None=None): - axes = axes or list(range(data.ndim)) - steps = steps or [1]*data.ndim - slices = [slice(0,x,1) for x in data.shape] - for i, axis in enumerate(axes): slices[axis] = slice(starts[i], ends[i], steps[i]) - return data[tuple(slices)] - - def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0): - sz = data.shape[axis] - if split is None: split = [sz // num_outputs + (1 if i < sz % num_outputs else 0) for i in range(num_outputs)] - return data.split(split, axis) - - def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None, - mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0): - value = constant_value or value - axes = axes or list(range(x.ndim)) - real_pads = [0] * (x.ndim*2) - for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)] - return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value) - - def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None): - shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim - pad_arg:list[None|tuple[int,int]] = [None] * t.ndim - for s, x in zip(shape, axes or range(t.ndim)): - tx = t.shape[x] - if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2) - elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2) - return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg)) - - # ***** Processing Ops ***** - def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0, - dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1): - pool_pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad) - return X.avg_pool2d(tuple(kernel_shape), strides, dilations, pool_pads, ceil_mode=ceil_mode, count_include_pad=count_include_pad) - - def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0, - storage_order:int=0, strides:list[int]|int=1): - pool_pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad) - out = X.max_pool2d(tuple(kernel_shape), strides, dilations, pool_pads, ceil_mode=ceil_mode, return_indices=True) - ret, idx = cast(tuple[Tensor, Tensor], out) - return ret, idx.transpose(-2, -1).cast(dtypes.int64) if storage_order else idx.cast(dtypes.int64) - - def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1, - kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1): - return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, - padding=_resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad)) - - def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1, - kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0, - strides:list[int]|int=1): - input_shape_, kernel_shape_ = X.shape[2:], (kernel_shape or W.shape[2:]) - strides_, dilations_, output_padding_ = (make_tuple(x, len(input_shape_)) for x in (strides, dilations, output_padding)) - if output_shape is not None: # we pad according to output_shape - pads = _auto_pad([s_*(i-1) + op_ + ((k_-1)*d_+1) - os for s_,i,op_,k_,d_,os in - zip(strides_, input_shape_, output_padding_, kernel_shape_, dilations_, output_shape)], auto_pad) - if pads is None: # we generate pads - output_shape = output_shape or [X.shape[i+2] * strides_[i] for i in range(len(strides_))] - pads = [strides_[i]*(input_shape_[i]-1)+output_padding_[i]+((kernel_shape_[i]-1)*dilations_[i]+1)-output_shape[i] - for i in range(len(input_shape_))] - pads = _auto_pad(pads, auto_pad) if auto_pad != "NOTSET" else [0] * len(input_shape_) * 2 - pads = _onnx_pads_to_tiny_pads(pads) - return X.conv_transpose2d(W, B, group, strides_, dilations_, pads, output_padding_) - - def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]|None=None, pads:list[int]|int=0, - strides:list[int]|int=1): - if kernel_shape is None: kernel_shape = [] - pads_: int | tuple[int, ...] = tuple(pads) if isinstance(pads, list) else pads - return Tensor.max_unpool2d(xT, xI, tuple(kernel_shape), strides, 1, pads_, outshape if outshape is None else tuple(outshape)) - - def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True) - def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True) - - def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0): - ret = alpha * (A.transpose(transA) @ B.transpose(transB)) - if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1])) - return ret - - def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs) - - def CumSum(X:Tensor, axis:int|list[int], exclusive:int=0, reverse:int=0): - axis = X._resolve_dim(_resolve_const(axis)) - if reverse: X = X.flip(axis) - if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\ - .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim))) - return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis) - - def Trilu(x:Tensor, k:int|list[int]=0, upper:int=1): - k_ = _resolve_const(k) - return x.triu(k_) if upper else x.tril(k_) - - def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, sizes:list[int]|None=None, antialias:int=0, - axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0, - extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'): - def _apply_transformation(input_sz, output_sz, scale_dim, mode): - index = Tensor.arange(output_sz, requires_grad=False, device=X.device) - if mode == "half_pixel": return (index + 0.5) / scale_dim - 0.5 - if mode == "align_corners": return index * (input_sz - 1) / (output_sz - 1) if output_sz != 1 else Tensor.zeros_like(index) - if mode == "asymmetric": return index / scale_dim - if mode == "pytorch_half_pixel": return ((index + 0.5) / scale_dim - 0.5) if output_sz != 1 else Tensor.zeros_like(index) - if mode == "half_pixel_symmetric": - output_dim_scaled = input_sz * scale_dim - return (input_sz / 2) * (1 - (output_sz / output_dim_scaled)) + (index + 0.5) / scale_dim - 0.5 - raise ValueError(f"invalid {coordinate_transformation_mode=}") - - if antialias: raise NotImplementedError("antialias is not implemented") - axes = axes or list(range(X.ndim)) - perm = [a for a in range(len(X.shape)) if a not in axes] + list(axes) - # we pre-permute the axes and permute back after resize - # the permute aligns X's axes to scales, sizes, and roi - X = X.permute(*perm) - - input_shape = cast(tuple[int, ...], X.shape[2:]) - if scales is not None: assert all(sc==1 for sc in scales[:-len(input_shape)]), "resizing batch_size dim or channel dim not supported" - if sizes is not None: assert tuple(sizes[:-2]) == tuple(X.shape[X.ndim-len(sizes):-2]), "resizing batch_size dim or channel dim not supported" - - scales, sizes = (None if scales is None else scales[-len(input_shape):]), (None if sizes is None else sizes[-len(input_shape):]) - if sizes is not None: - if keep_aspect_ratio_policy in ["not_larger", "not_smaller"]: - scale_fxn = min if keep_aspect_ratio_policy == "not_larger" else max - scale = scale_fxn(sz / sh for sz,sh in zip(sizes, input_shape)) - sizes, scales = [int(scale * sh + 0.5) for sh in input_shape], [scale]*len(input_shape) - else: scales = [sz / sh for sz, sh in zip(sizes, input_shape)] - else: - assert scales is not None, "either sizes or scales must be provided" - sizes = [int(sc * sh) for sc, sh in zip(scales, input_shape)] - - if all(sz == sh for sz, sh in zip(sizes, input_shape)): return X.permute(*argsort(perm)) if perm else X - - indexes = [] - for input_sz, output_sz, scale in zip(input_shape, sizes, scales): - indexes.append(_apply_transformation(input_sz, output_sz, scale, coordinate_transformation_mode)) - - if mode in ["nearest", "linear"]: indexes = [idx.clip(0, sz-1) for idx, sz in zip(indexes, input_shape)] - - if mode == "nearest": - mode_operations = { - "round_prefer_floor": lambda idx: (idx - 0.5).ceil(), - "round_prefer_ceil": lambda idx: (idx + 0.5).floor(), - "floor": lambda idx: idx.floor(), - "ceil": lambda idx: idx.ceil() - } - if nearest_mode not in mode_operations: raise ValueError(f"invalid {nearest_mode=}") - indexes = [mode_operations[nearest_mode](idx).int() for idx in indexes] - X = X[(..., *Tensor.meshgrid(*indexes))] - - if mode == "linear": - expand = list(X.shape) - for i in range(-len(sizes), 0): - reshape, index = [1] * X.ndim, indexes[i] - reshape[i] = expand[i] = sizes[i] - low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor().int(), index.ceil().int(), index - index.floor())] - X = X.gather(i, low).lerp(X.gather(i, high), perc) - - if mode == "cubic": - A = cubic_coeff_a - - # Keys weights - # see piecewise function in: https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm - def W0_1(x:Tensor): return polyN(x, [A + 2, -(A + 3), 0, 1]) - def W1_2(x: Tensor): return polyN(x, [A, -5 * A, 8 * A, -4 * A]) - - expand = list(X.shape) - for i in range(-len(sizes), 0): - input_sz = cast(int, X.shape[i]) - reshape, index = [1] * X.ndim, indexes[i] - reshape[i] = expand[i] = sizes[i] - - p = index.floor().int() - ratio = index - p # in [0, 1] - - # Neighbor indices - idx0, idx1, idx2, idx3 = [p + d for d in [-1, 0, 1, 2]] - # Weights of distance from index and neighbor indices - c0, c1, c2, c3 = W1_2(ratio+1), W0_1(ratio), W0_1(-(ratio-1)), W1_2(-(ratio-2)) - - if exclude_outside: - c0 = ((idx0 >= 0) & (idx0 < input_sz)).where(c0, 0) - c1 = ((idx1 >= 0) & (idx1 < input_sz)).where(c1, 0) - c2 = ((idx2 >= 0) & (idx2 < input_sz)).where(c2, 0) - c3 = ((idx3 >= 0) & (idx3 < input_sz)).where(c3, 0) - - total = c0 + c1 + c2 + c3 - c0, c1, c2, c3 = c0 / (total + 1e-9), c1 / (total + 1e-9), c2 / (total + 1e-9), c3 / (total + 1e-9) - - # Reshape and expand - expanded_indices = [y.clip(0, input_sz - 1).reshape(reshape).expand(expand) for y in [idx0, idx1, idx2, idx3]] - expanded_coeffs = [y.reshape(reshape).expand(expand) for y in [c0, c1, c2, c3]] - - # Gather values and apply coefficients - gathered_values = [X.gather(i, idx) for idx in expanded_indices] - X = sum(v * c for v, c in zip(gathered_values, expanded_coeffs)) - return X.permute(*argsort(perm)) if perm else X - def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated - - def TopK(X:Tensor, K:int|list[int], axis:int=-1, largest:int=1, sorted:int=1): # noqa: A002 # pylint: disable=redefined-builtin - val, idx = X.topk(_resolve_const(K), axis, bool(largest), bool(sorted)) - return val, idx.cast(dtypes.int64) - - # ***** Neural Network Ops ***** - def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9, - training_mode:int=0, spatial=1, is_test=0): - if training_mode: - x_detached = X.detach() - current_mean = x_detached.mean(axis=(0,2,3)) - y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1])) - current_var = (y*y).mean(axis=(0,2,3)) - current_invstd = current_var.add(epsilon).rsqrt() - - running_mean = input_mean * momentum + current_mean * (1 - momentum) - running_var = input_var * momentum + current_var * (1 - momentum) - - return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var - return X.batchnorm(scale, B, input_mean, (input_var + epsilon).rsqrt()) - def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05): - x = x.reshape(x.shape[0], num_groups, -1).layernorm(eps=epsilon).reshape(x.shape) - return x * scale.reshape(1, -1, *[1] * (x.ndim-2)) + bias.reshape(1, -1, *[1] * (x.ndim-2)) - def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05): - return GroupNormalization(x, scale, bias, num_groups=cast(int, x.shape[1]), epsilon=epsilon) - def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1): - assert stash_type == 1, "only float32 is supported" - axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim)) - mean = x.mean(axis=axes, keepdim=True) - return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt() - def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12): - x = x + skip - if bias is not None: x = x + bias - ret = x.layernorm(eps=epsilon) * gamma - if beta is not None: ret = ret + beta - return ret, None, None, x - def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor, - segment_embedding:Tensor, gamma=None, beta=None, mask:Tensor|None=None, - position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0): - # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization - assert (segment_ids is None) is (segment_embedding is None) - assert mask is None and not mask_index_type, "functionality not supported yet" # TODO - input_shape = input_ids.shape - seq_length = input_shape[1] - compute_seg_emb = (segment_embedding is not None and segment_ids is not None) - vocab_size, max_position_embeddings = word_embedding.shape[0], position_embedding.shape[0] - type_vocab_size = (segment_embedding.shape[0] if compute_seg_emb else None) - - def embedding(x:Tensor, vocab_size, weight:Tensor) -> Tensor: - return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight - - # bert embedding layer - if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape) - wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding) - pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding) - seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None - - embedding_sum = wrd_embedding_res + pos_embedding_res - if seg_embedding_res is not None: embedding_sum = embedding_sum + seg_embedding_res - out = embedding_sum.layernorm(eps=epsilon) * gamma + beta - return out, None, embedding_sum - def MeanVarianceNormalization(x:Tensor, axis:list[int]|None=None): - if axis is None: axis = [0,2,3] - return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9) - - def OneHot(indices:Tensor, depth:float|int|list[int|float], values:Tensor, axis:int=-1): - # Scalar or Rank 1 tensor containing exactly one element - depth = int(_resolve_const(depth)) - indices = indices.int() - indices = (indices < 0).where(indices+depth, indices) - return indices.unsqueeze(axis)._one_hot_along_dim(depth, dim=axis).where(values[1], values[0]) - - def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"): - return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize) - def SpaceToDepth(X:Tensor, blocksize:int): - return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize) - - # Reimplemented here because you need legacy RNG for passing ONNX tests. - def dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None): - import numpy as np - if not training_mode: return data, data.full_like(True, dtype=dtypes.bool) - if seed is not None: - rand = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)), requires_grad=False, dtype=data.dtype, device=data.device) - else: - rand = data.rand_like(requires_grad=False) - mask = rand >= ratio - return data * mask / (1.0 - ratio), mask - # 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx - def dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return dropout_7(data, ratio, training_mode=not is_test) - Dropout = {OpSetId(Domain.ONNX, 6):dropout_6, OpSetId(Domain.ONNX, 7):dropout_7} - - def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0): - pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1) - return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta) - - def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): - return x.nll_loss(target, weight, ignore_index, reduction) - def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): - log_probs = scores.log_softmax(1) - return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs - - def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0): - N, _, *spatial_dims = size - def generate_grid(steps): - if align_corners: return Tensor.linspace(-1, 1, steps, device=theta.device) - return Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device) - grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims)) - base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1) - base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1) - return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1) - - def attention_contrib(x:Tensor, weights:Tensor, bias:Tensor|None=None, mask_index:Tensor|None=None, past:Tensor|None=None, - attention_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int=0, mask_filter_value:float=-10000.0, - num_heads:int|None=None, past_present_share_buffer:int|None=None, qkv_hidden_sizes:list[int]|None=None, - rotary_embedding_dim:int|None=None, scale:float|None=None, unidirectional:int=0): - assert not do_rotary and not attention_bias, "TODO" - if qkv_hidden_sizes is None: qkv_hidden_sizes = [weights.shape[1] // 3] * 3 - qkv = x.linear(weights, bias) - q, k, v = qkv.split(qkv_hidden_sizes, dim=2) - - batch_size, seq_len, _ = x.shape - assert num_heads is not None, "num_heads must be provided" - q_head_size, k_head_size, v_head_size = (sz // num_heads for sz in qkv_hidden_sizes) - q, k, v = (x.reshape(batch_size, seq_len, num_heads, hsz).transpose(1, 2) for x, hsz in zip((q, k, v), (q_head_size, k_head_size, v_head_size))) - - present = None - if past is not None: - k, v = past[0].cat(k, dim=2), past[1].cat(v, dim=2) - present = k.stack(v) - - if scale is None: scale = 1.0 / math.sqrt(q_head_size) - attn_scores = q @ k.transpose(-1, -2) * scale - - if mask_index is not None: - assert 4 >= mask_index.ndim >= 1, f"{mask_index.ndim=}" - assert isinstance(batch_size, int), f"{batch_size=}" - if mask_index.ndim != 1: mask = mask_index.bool() - else: - if mask_index.shape[0] == batch_size: - mask = Tensor.arange(attn_scores.shape[-1], requires_grad=False, device=mask_index.device).unsqueeze(0) < mask_index.unsqueeze(1) - elif mask_index.shape[0] == 2*batch_size: - end_positions = mask_index[:batch_size] - start_positions = mask_index[batch_size:] - arange = Tensor.arange(seq_len).unsqueeze(0) - mask = (arange < end_positions.unsqueeze(1)) & (arange >= start_positions.unsqueeze(1)) - else: raise NotImplementedError("mask_index with shape (3 * batch_size + 2) is not implemented") - while mask.ndim < 4: mask = mask.unsqueeze(1) - attn_scores = mask.where(attn_scores, mask_filter_value) - - if unidirectional: - causal_mask = Tensor.ones((seq_len, seq_len), dtype=dtypes.bool).tril() - attn_scores = causal_mask.where(attn_scores, mask_filter_value) - - output = attn_scores.softmax(-1) @ v - output = output.transpose(1, 2).reshape(batch_size, seq_len, -1) - return output, present - - def attention_onnx(Q:Tensor, K:Tensor, V:Tensor, attn_mask:Tensor|None=None, past_key:Tensor|None=None, past_value:Tensor|None=None, - is_causal:int=0, kv_num_heads:int|None=None, q_num_heads:int|None=None, qk_matmul_output_mode:int=0, scale:float|None=None, - softcap:float=0.0, softmax_precision:int|None=None): - input_shape_len = Q.ndim - if input_shape_len == 3: - assert q_num_heads is not None and kv_num_heads is not None - Q = Q.reshape(Q.shape[0], q_num_heads, Q.shape[1], -1) - K = K.reshape(K.shape[0], kv_num_heads, K.shape[1], -1) - V = V.reshape(V.shape[0], kv_num_heads, V.shape[1], -1) - - if past_key is not None: K = past_key.cat(K, dim=2) - if past_value is not None: V = past_value.cat(V, dim=2) - present_key, present_value = K, V - - _q_heads, _kv_heads = q_num_heads or Q.shape[1], kv_num_heads or K.shape[1] - if _q_heads != _kv_heads: - K = K.repeat((1, _q_heads // _kv_heads, 1, 1)) - V = V.repeat((1, _q_heads // _kv_heads, 1, 1)) - - effective_scale = scale if scale is not None else 1.0 / (cast(int, Q.shape[-1]) ** 0.5) - scores = (Q @ K.transpose(-1, -2)) * effective_scale - qk_matmul_return_val = scores - - if is_causal: - causal_mask = Tensor.ones(Q.shape[-2], K.shape[-2], device=Q.device, dtype=dtypes.bool, requires_grad=False).tril(0) - scores = scores.masked_fill(causal_mask.logical_not(), -float("inf")) - - if attn_mask is not None: - mask_to_add = attn_mask.where(0, -float("inf")) if attn_mask.dtype == dtypes.bool else attn_mask - scores = scores + mask_to_add - if qk_matmul_output_mode == 1: qk_matmul_return_val = scores - - if softcap > 0.0: scores = (scores / softcap).tanh() * softcap - if qk_matmul_output_mode == 2: qk_matmul_return_val = scores - - if softmax_precision: scores = scores.cast({1: dtypes.float32, 10: dtypes.float16, 16: dtypes.bfloat16}[softmax_precision]) - qk_softmax = scores.softmax(-1).cast(Q.dtype) - if qk_matmul_output_mode == 3: qk_matmul_return_val = qk_softmax - - output = (qk_softmax @ V).cast(Q.dtype) - if input_shape_len == 3: output = output.permute(0, 2, 1, 3).reshape(Q.shape[0], Q.shape[2], -1) - return output, present_key, present_value, qk_matmul_return_val - Attention = {OpSetId(Domain.ONNX, 1): attention_onnx, OpSetId(Domain.MICROSOFT_CONTRIB_OPS, 1): attention_contrib} - - def RMSNormalization(X:Tensor, scale:Tensor, axis:int=-1, epsilon:float=1e-5): - norm = X.square().mean(axis=tuple(range(axis + X.ndim if axis < 0 else axis, X.ndim)), keepdim=True).add(epsilon).rsqrt() - return X * norm * scale - - def RotaryEmbedding(X:Tensor, cos_cache:Tensor, sin_cache:Tensor, position_ids:Tensor|None=None, interleaved:int=0, num_heads:int|None=None, - rotary_embedding_dim:int=0): - original_input_shape = X.shape - - if X.ndim == 4: X = X.permute(0, 2, 1, 3) - elif X.ndim == 3: - assert num_heads is not None, "num_heads must be provided for 3D input" - X = X.reshape(*X.shape[:-1], num_heads, X.shape[-1] // num_heads) - - head_size = cast(int, X.shape[-1]) - rot_dim = rotary_embedding_dim or head_size - x_rotate, x_pass = X[..., :rot_dim], X[..., rot_dim:] - - cos = cos_cache[position_ids] if position_ids is not None else cos_cache[:head_size] - sin = sin_cache[position_ids] if position_ids is not None else sin_cache[:head_size] - cos = cos[..., :rot_dim//2].unsqueeze(2) - sin = sin[..., :rot_dim//2].unsqueeze(2) - - if interleaved: - x1, x2 = x_rotate[..., ::2], x_rotate[..., 1::2] - real = x1 * cos - x2 * sin - imag = x1 * sin + x2 * cos - x_rotated = Tensor.stack(real, imag, dim=-1).flatten(start_dim=-2) - else: - x1, x2 = x_rotate.chunk(2, dim=-1) - real = x1 * cos - x2 * sin - imag = x1 * sin + x2 * cos - x_rotated = real.cat(imag, dim=-1) - - output = x_rotated.cat(x_pass, dim=-1) - return output.flatten(start_dim=2) if len(original_input_shape) == 3 else output.permute(0, 2, 1, 3) - - # ***** Indexing Ops ***** - def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices] - - def Gather(x:Tensor, indices:Tensor, axis:int=0): - if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices - ret_shape = x.shape[:axis] + indices.shape + x.shape[axis+1:] - if indices.ndim > 1: indices = indices.flatten() - index_consts = [_cached_to_python_const(indices)] if indices.shape == () else _cached_to_python_const(indices) - index_consts = [x.shape[axis]+i if i<0 else i for i in index_consts] - args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x.shape)] for i in index_consts] - return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape) - # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot - return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])] - def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated - - def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0): - if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))] - x_shape, i_shape = x.shape, indices.shape - b = math.prod(x.shape[dim] for dim in range(batch_dims)) - # NOTE: each batched dim of both input and indices are equal - x = x.reshape(b, *x.shape[batch_dims:]) - indices = indices.reshape(b, *indices.shape[batch_dims:]) - b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1]) - ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))] - return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:]) - def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'): - assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):] - x = x.contiguous() - for index, u in zip(indices.split(1, 0), updates.split(1, 0)): - i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1)) - u = u.squeeze(0) - if reduction == "none": x[i] = u - elif reduction == "add": x[i] += u - elif reduction == "mul": x[i] *= u - else: raise NotImplementedError("reduction doesn't support max or min") - return x - - def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul", "min", "max"]="none"): - indices = (indices < 0).where(x.shape[axis], 0) + indices - if reduction == "none": return x.scatter(axis, indices, updates) - reduction_ = cast(Literal["sum", "prod", "amin", "amax"], {"add": "sum", "mul": "prod", "min": "amin", "max": "amax"}[reduction]) - return x.scatter_reduce(axis, indices, updates, reduction_) - def GatherElements(x:Tensor, indices:Tensor, axis:int): - indices = (indices < 0).where(x.shape[axis], 0) + indices - return x.gather(axis, indices) - - def Compress(inp:Tensor, condition:list[bool], axis:int|None=None): - if axis is None: - inp = inp.flatten() - axis = 0 - axis = inp._resolve_dim(axis) - con = Tensor([i for i,cond in enumerate(condition) if cond]) # compress in python - return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))] - - # ***** Quantization Ops ***** - def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1): - if isinstance(y_zero_point, Tensor): out_dtype = y_zero_point.dtype - elif output_dtype != 0: out_dtype = dtype_fallback(OnnxDataType(output_dtype).to_dtype(), "QuantizeLinear op") - else: out_dtype = dtypes.uint8 - y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size) - if out_dtype == dtypes.uchar: - # this appears to work in practice, at least for uchar out_dtype. it folds with the quantize stuff - ret = _clamp_cast((x / y_scale + 0.4999999 + y_zero_point).int(), out_dtype) - else: - ret = _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype) - return ret.contiguous() - - def DynamicQuantizeLinear(x: Tensor): - # only support uint8 - qmin, qmax = dtypes.min(dtypes.uint8), dtypes.max(dtypes.uint8) - scale = (x.max().maximum(0) + ((-x).max()).maximum(0)) / (qmax - qmin) - zero_point = _clamp_cast((qmin - x.min() / scale).round(), dtypes.uint8) - y = _clamp_cast((x / scale).round() + zero_point, dtypes.uint8) - return y, scale, zero_point - - def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0): - x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size) - return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype) - - def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor, w:Tensor, w_scale:Tensor, w_zero_point:Tensor, y_scale:Tensor, - y_zero_point:Tensor, B:Tensor|None=None, **opts): - return _qlinearop_quantized(Conv, [x,w], [x_zero_point,w_zero_point], [x_scale,w_scale], y_scale, y_zero_point, **{"B":B, **opts}) - - def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, y_scale:Tensor, - y_zero_point:Tensor) -> Tensor: - return _qlinearop_quantized(Tensor.matmul, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], y_scale, y_zero_point) - - def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor): - return _qlinearop_float(Tensor.add, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], c_scale, c_zero_point) - - def QLinearMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor): - return _qlinearop_quantized(Tensor.mul, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], c_scale, c_zero_point) - - def QLinearGlobalAveragePool(X:Tensor, x_scale:Tensor, x_zero_point:Tensor, y_scale:Tensor, y_zero_point:Tensor, channels_last:int): - assert channels_last == 0, "TODO NHWC" - return _qlinearop_float(GlobalAveragePool, [X], [x_zero_point], [x_scale], y_scale, y_zero_point) - - def ConvInteger(x: Tensor, w: Tensor, x_zero_point:Tensor = Tensor(0), w_zero_point:Tensor = Tensor(0), B: Tensor | None = None, **opts) -> Tensor: - return _op_integer(Conv, [x,w], [x_zero_point,w_zero_point], **{"B":B, **opts}) - - def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor = Tensor(0), b_zero_point: Tensor = Tensor(0)) -> Tensor: - return _op_integer(Tensor.matmul, [A,B], [a_zero_point,b_zero_point]) - - # ***** Training Ops ***** - # NOTE: onnx training ops actually don't need the state for optim, all the ops work in a functional way, but we still can reuse optim.py code - @_onnx_training(3) - def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0): - X, G, H = (i.detach() for i in inputs) - grad = norm_coefficient * X + G - H.assign(H + grad.square()) - up = grad / (H.sqrt() + epsilon) - r = R / (1 + T * decay_factor) - X.assign(X.detach() - r * up) - return [X, H] - - @_onnx_training(4) - def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0, - norm_coefficient_post:float=0.0): - from tinygrad.nn.optim import Adam as TinyAdam - X, G, V, H = inputs - G, V, H = G.detach(), V.detach(), H.detach() - X.grad = norm_coefficient * X.detach() + G - opt = TinyAdam([X], b1=alpha, b2=beta, eps=epsilon) - opt.m, opt.v, opt.lr = [V], [H], R - # need no-op for m_hat and v_hat if T == 0 - if T == 0: opt.b1_t, opt.b2_t = opt.b1_t.zeros_like(), opt.b2_t.zeros_like() - else: - # `T-1` since it's applied again at the start of `_step` - opt.b1_t = Tensor([alpha**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False) - opt.b2_t = Tensor([beta**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False) - opt.step() - X = (1 - norm_coefficient_post) * X - return [X, V, H] - - @_onnx_training(3) - def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float): - X, G, V = (i.detach() for i in inputs) - grad = norm_coefficient * X + G - # NOTE: this beta_adjusted term makes it so we can't use SGD for nesterov - beta_adjusted = beta if T > 0 else 1 - V.assign(alpha * V + grad * beta_adjusted) - X.assign(X - R * (V if mode == "standard" else (grad + alpha * V))) - return [X, V] - - def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **_): - intermediate_tensors[y].backward() - return tuple([t.grad for t in inputs]) - - return { - # Tensor ops - **{op: getattr(Tensor, op.lower()) for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", - "Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", - "Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Round", "Erf")}, - # Implemented ops - **{name:obj for name,obj in locals().items() if isinstance(obj, types.FunctionType) and not name.startswith("_") and name[0].isupper()}, - # Version ops - **{name:obj for name,obj in locals().items() if isinstance(obj, dict)}, - } - -onnx_ops = get_onnx_ops() diff --git a/extra/onnx_helpers.py b/extra/onnx_helpers.py index 750f54d9e3..632d5df8d7 100644 --- a/extra/onnx_helpers.py +++ b/extra/onnx_helpers.py @@ -1,7 +1,6 @@ from tinygrad import Tensor from tinygrad.tensor import _to_np_dtype -from tinygrad.frontend.onnx import OnnxRunner -from extra.onnx import OnnxValue +from tinygrad.frontend.onnx import OnnxRunner, OnnxValue import numpy as np import onnxruntime as ort diff --git a/test/external/external_test_onnx_runner.py b/test/external/external_test_onnx_runner.py index b4acd19995..f0d8941b45 100644 --- a/test/external/external_test_onnx_runner.py +++ b/test/external/external_test_onnx_runner.py @@ -3,8 +3,7 @@ import numpy as np from tinygrad import dtypes, Tensor from tinygrad.uop.ops import Ops from tinygrad.device import is_dtype_supported -from extra.onnx import OnnxDataType -from tinygrad.frontend.onnx import OnnxRunner +from tinygrad.frontend.onnx import OnnxRunner, OnnxDataType from hypothesis import given, strategies as st # copied from test_const_folding.py diff --git a/tinygrad/frontend/onnx.py b/tinygrad/frontend/onnx.py index 2d6703bf1f..bcead3994d 100644 --- a/tinygrad/frontend/onnx.py +++ b/tinygrad/frontend/onnx.py @@ -1,6 +1,1261 @@ -# type: ignore -import sys, pathlib -sys.path.append(pathlib.Path(__file__).parent.parent.as_posix()) -try: - from extra.onnx import OnnxRunner # noqa: F401 # pylint: disable=unused-import -except ImportError as e: raise ImportError("onnx frontend not in release\nTo fix, install tinygrad from a git checkout with pip install -e .") from e \ No newline at end of file +# pylint: disable=possibly-unused-variable +from typing import Any, Sequence, cast, Literal, NamedTuple, Generator +import dataclasses, functools, io, math, types, warnings, pathlib, sys, os, struct, enum +from io import BufferedReader +from tinygrad.nn.state import TensorIO +from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr +from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort, is_numpy_ndarray, get_single_element, polyN +from tinygrad.dtype import DType, ConstType, dtypes, _from_np_dtype +from tinygrad.device import is_dtype_supported, Device + +# ***** protobuf definitions ****** +class WireType(enum.IntEnum): + """ + Protocol Buffer wire types for decoding fields. + Reference: https://github.com/protocolbuffers/protobuf/blob/main/python/google/protobuf/internal/wire_format.py#L24-L29 + """ + VARINT = 0; FIXED64 = 1; LENGTH_DELIMITED = 2; START_GROUP = 3; END_GROUP = 4; FIXED32 = 5 # noqa: E702 + +class AttributeType(enum.IntEnum): + """ + ONNX attribute type identifiers. + Reference: https://github.com/onnx/onnx/blob/rel-1.18.0/onnx/onnx.proto3#L128-L145 + """ + FLOAT = 1; INT = 2; STRING = 3; TENSOR = 4; FLOATS = 6; INTS = 7; STRINGS = 8 # noqa: E702 + + def to_field_name(self) -> str: return {1: "f", 2: "i", 3: "s", 4: "t", 6: "floats", 7: "ints", 8: "strings"}[self.value] + +class OnnxDataType(enum.IntEnum): + """ + ONNX tensor data type identifiers. + Reference: https://github.com/onnx/onnx/blob/rel-1.18.0/onnx/onnx.proto3#L500-L544 + """ + FLOAT = 1; UINT8 = 2; INT8 = 3; UINT16 = 4; INT16 = 5; INT32 = 6; INT64 = 7; BOOL = 9; FLOAT16 = 10; DOUBLE = 11; UINT32 = 12 # noqa: E702 + UINT64 = 13; BFLOAT16 = 16 # noqa: E702 + + def to_dtype(self) -> DType: return dtypes.fields()[self.name.lower()] + +def dtype_fallback(dtype: DType, fallback_context: str) -> DType: + if is_dtype_supported(dtype): return dtype + default_dtype = dtypes.default_int if dtypes.is_int(dtype) else dtypes.default_float + warnings.warn(f"dtype {dtype} on {Device.DEFAULT} from {fallback_context} is not supported, falling back to {default_dtype}") + assert is_dtype_supported(default_dtype), f"dtype {default_dtype} must be supported on {Device.DEFAULT}" + return default_dtype + +# ***** onnx spec definitions ***** +class Domain(enum.Enum): + ONNX = "ai.onnx" + ONNX_ML = "ai.onnx.ml" + AI_ONNX_TRAINING = "ai.onnx.training" + AI_ONNX_PREVIEW_TRAINING = "ai.onnx.preview.training" + MICROSOFT_CONTRIB_OPS = "com.microsoft" + MICROSOFT_NCHWC = "com.microsoft.nchwc" + MICROSOFT_EXPERIMENTAL = "com.microsoft.experimental" + PYTORCH_ATEN = "org.pytorch.aten" + @classmethod + def from_onnx(cls, domain: str | None) -> "Domain": return cls.ONNX if domain is None or domain == "" else cls(domain) + +class OpSetId(NamedTuple): + domain: Domain + version: int + +@dataclasses.dataclass(frozen=True) +class OnnxValue: + shape: tuple[str|int, ...] + dtype: DType + is_optional: bool + is_sequence: bool + +@dataclasses.dataclass(frozen=True) +class OnnxNode: + op: str + opset_id: OpSetId + inputs: tuple[str, ...] + outputs: tuple[str, ...] + opts: dict[str, Any] + +# ***** protobuf parsing ****** +class PBBufferedReader(BufferedReader): + def __init__(self, tensor: Tensor): + assert tensor.dtype == dtypes.uint8, tensor + super().__init__(TensorIO(tensor)) + self.len = tensor.nbytes() + + def decode_varint(self) -> int: + """Reference: https://protobuf.dev/programming-guides/encoding/#varints""" + result = 0 + shift = 0 + while True: + data = self.read(1) + if data == b"": raise EOFError("decode_varint EOF") + result |= (data[0] & 0x7F) << shift + if not (data[0] & 0x80): return result + shift += 7 + if shift >= 70: raise ValueError("Varint too long") + + def read_delimited(self, use_tensor=False): + str_len = self.decode_varint() + if not use_tensor: return self.read(str_len) + raw = self.raw + assert isinstance(raw, TensorIO) + res = raw._tensor[self.tell():(self.tell()+str_len)] + self.seek(str_len, os.SEEK_CUR) + return res + def read_string(self) -> str: return self.read_delimited().decode("utf-8") + def read_bytes(self) -> Tensor: return self.read_delimited(use_tensor=True) + def read_float(self) -> float: return struct.unpack(" Tensor: return self.read_delimited(use_tensor=True) + def read_int64(self) -> int: + val = self.decode_varint() + return val - 2**64 if val & (1 << 63) else val + def read_packed_int64s(self) -> list[int]: + total_bytes_len = self.decode_varint() + old_pos = self.tell() + values = [] + # need copy here because packed ints are varint + while self.tell() < total_bytes_len + old_pos: values.append(self.read_int64()) + return values + + def skip_field(self, wire_type: WireType) -> None: + """Skip a field based on its wire type.""" + match wire_type: + case WireType.VARINT: self.decode_varint() + case WireType.FIXED64: self.seek(8, os.SEEK_CUR) + case WireType.FIXED32: self.seek(4, os.SEEK_CUR) + case WireType.LENGTH_DELIMITED: self.seek(self.decode_varint(), os.SEEK_CUR) + case _: raise ValueError(f"Unknown wire type: {wire_type}") + +class OnnxPBParser: + """ + ONNX protobuf parser. + Reference: https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3 + """ + def __init__(self, inp: Tensor|str|pathlib.Path, load_external_data: bool=True): + self.file_path: pathlib.Path|None = None + self.load_external_data = load_external_data + if not isinstance(inp, Tensor): + self.file_path = pathlib.Path(inp) + self.tensor = Tensor(self.file_path) + else: self.tensor = inp + self.reader = PBBufferedReader(self.tensor) + + def parse(self) -> dict: + """Parses the ONNX model into a nested dictionary. """ + return self._parse_ModelProto() + + def _parse_message(self, end_pos: int) -> Generator[tuple[int, WireType], None, None]: + while self.reader.tell() < end_pos: + tag = self.reader.decode_varint() + yield tag >> 3, WireType(tag & 0x07) + + def _decode_end_pos(self) -> int: + str_len = self.reader.decode_varint() + start_pos = self.reader.tell() + return start_pos + str_len + + def _parse_ModelProto(self) -> dict: + """Entry point for parsing the ONNX model.""" + obj: dict[str, Any] = {"opset_import": []} + for fid, wire_type in self._parse_message(self.reader.len): + match fid: + case 4: obj["domain"] = self.reader.read_string() + case 5: obj["model_version"] = self.reader.read_int64() + case 7: obj["graph"] = self._parse_GraphProto() + case 8: obj["opset_import"].append(self._parse_OperatorSetIdProto()) + case _: self.reader.skip_field(wire_type) + + # update opset version + opset_imports = {Domain.from_onnx(x.get('domain')):x.get('version', 1) for x in obj["opset_import"]} + for n in obj["graph"]["node"]: + n_ = n["parsed_node"] + n["parsed_node"] = OnnxNode(n_.op, OpSetId(n_.opset_id.domain, opset_imports.get(n_.opset_id.domain, 1)), n_.inputs, n_.outputs, n_.opts) + return obj + + def _parse_GraphProto(self) -> dict: + obj: dict[str, Any] = {"node": [], "initializer": [], "input": [], "output": []} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["node"].append(self._parse_NodeProto()) + case 2: obj["name"] = self.reader.read_string() + case 5: obj["initializer"].append(self._parse_TensorProto()) + case 11: obj["input"].append(self._parse_ValueInfoProto()) + case 12: obj["output"].append(self._parse_ValueInfoProto()) + case _: self.reader.skip_field(wire_type) + return obj + + def _parse_NodeProto(self) -> dict: + obj: dict[str, Any] = {"input": [], "output": [], "attribute": [], "domain": None} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["input"].append(self.reader.read_string()) + case 2: obj["output"].append(self.reader.read_string()) + case 3: obj["name"] = self.reader.read_string() + case 4: obj["op_type"] = self.reader.read_string() + case 5: obj["attribute"].append(self._parse_AttributeProto()) + case 6: obj["doc_string"] = self.reader.read_string() + case 7: obj["domain"] = self.reader.read_string() + case _: self.reader.skip_field(wire_type) + + # parse node + attributes = {attr_dict["name"]: attr_dict[AttributeType(attr_dict["type"]).to_field_name()] for attr_dict in obj["attribute"]} + opset_id = OpSetId(Domain.from_onnx(obj.get('domain')), 1) # default version, to be updated later in _parse_ModelProto + obj["parsed_node"] = OnnxNode(obj["op_type"], opset_id, tuple(obj["input"]), tuple(obj["output"]), attributes) + return obj + + def _parse_TensorProto(self) -> dict: + obj: dict[str, Any] = {"dims": []} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["dims"].append(self.reader.read_int64()) + case 2: obj["data_type"] = self.reader.read_int64() + case 4: obj["float_data"] = self.reader.read_packed_floats() + case 5: obj["int32_data"] = self.reader.read_packed_int64s() + case 7: obj["int64_data"] = self.reader.read_packed_int64s() + case 8: obj["name"] = self.reader.read_string() + case 9: obj["raw_data"] = self.reader.read_bytes() + case 10: obj["double_data"] = self.reader.read_packed_floats() + case 11: obj["uint64_data"] = self.reader.read_packed_int64s() + case 13: obj.setdefault("external_data", []).append(self._parse_StringStringEntryProto()) + case 14: obj["data_location"] = self.reader.read_int64() + case _: self.reader.skip_field(wire_type) + + # load external data + if self.load_external_data and obj.get("data_location", 0) == 1: + if "external_data" not in obj: raise ValueError("no external_data") + location, length, offset = None, None, 0 + for kv in obj["external_data"]: + if kv["key"] == "location": location = kv["value"] + elif kv["key"] == "offset": offset = int(kv["value"]) + elif kv["key"] == "length": length = int(kv["value"]) + if location is None: raise ValueError("no location in external_data") + + if self.file_path is None: + if isinstance(self.tensor.device, str) and self.tensor.device.startswith("DISK:"): + self.file_path = pathlib.Path(self.tensor.device[5:]) + else: raise ValueError("onnx external_data needs the origin file path, try passing onnx file path to onnx_load") + ext_path = self.file_path.parent.joinpath(location) + if not ext_path.exists(): raise FileNotFoundError(f"external location not exists: {ext_path}") + + ext_tensor = Tensor(ext_path) + obj["raw_data"] = ext_tensor[offset:offset+length] if length is not None else ext_tensor[offset:] + obj["data_location"] = 0 + + # parse tensor + to_dtype = dtype_fallback(true_dtype := OnnxDataType(obj['data_type']).to_dtype(), "buffer parse") + shape = tuple(obj['dims']) + present_fields = [field for field in ['float_data', 'int32_data', 'int64_data', 'double_data', 'uint64_data', 'raw_data'] if field in obj] + assert len(present_fields) == 1, f"only 1 data field is allowed from {obj=}" + data = obj[present_fields[0]] + if not isinstance(data, Tensor): + obj["parsed_tensor"] = Tensor(data, dtype=to_dtype).reshape(shape) + return obj + assert isinstance(data, Tensor) and data.dtype == dtypes.uint8, data + data = data.bitcast(true_dtype).reshape(shape) + data = data.to(Device.DEFAULT) if true_dtype is to_dtype else data.to("cpu").cast(to_dtype).to(Device.DEFAULT) + # const folding + if shape == (): + if data.dtype == dtypes.float16 and sys.version_info < (3, 12): data = data.cast(dtypes.float32) + data = Tensor(data.item(), dtype=to_dtype).reshape(shape) + obj["parsed_tensor"] = data + return obj + + def _parse_AttributeProto(self) -> dict: + obj: dict[str, Any] = {"floats": [], "ints": [], "strings": []} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["name"] = self.reader.read_string() + case 2: obj["f"] = self.reader.read_float() + case 3: obj["i"] = self.reader.read_int64() + case 4: obj["s"] = self.reader.read_bytes().data().tobytes().decode("utf8") + case 5: obj["t"] = self._parse_TensorProto()['parsed_tensor'] + case 7: obj["floats"].append(self.reader.read_float()) + case 8: obj["ints"].append(self.reader.read_int64()) + case 9: obj["strings"].append(self.reader.read_bytes().data().tobytes().decode("utf8")) + case 20: obj["type"] = self.reader.read_int64() + case _: self.reader.skip_field(wire_type) + obj["floats"], obj["ints"], obj["strings"] = tuple(obj["floats"]), tuple(obj["ints"]), tuple(obj["strings"]) + return obj + + def _parse_ValueInfoProto(self) -> dict: + obj: dict[str, Any] = {} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["name"] = self.reader.read_string() + case 2: obj["type"] = self._parse_TypeProto() + case _: self.reader.skip_field(wire_type) + + # parse type + if "type" not in obj: return {**obj, "parsed_type": None} + type_obj = obj["type"] + if is_optional := "optional_type" in type_obj: type_obj = type_obj["optional_type"]["elem_type"] + if is_sequence := "sequence_type" in type_obj: type_obj = type_obj["sequence_type"]["elem_type"] + assert "tensor_type" in type_obj, type_obj + shape_dims = type_obj['tensor_type'].get('shape', {}).get('dim', []) + obj['parsed_type'] = OnnxValue(tuple(d.get('dim_param') or d.get('dim_value') for d in shape_dims), + OnnxDataType(type_obj['tensor_type']['elem_type']).to_dtype(), is_optional, is_sequence) + return obj + + def _parse_TypeProto(self) -> dict: + obj: dict[str, Any] = {} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["tensor_type"] = self._parse_TypeProtoTensor() + case 4: obj["sequence_type"] = self._parse_TypeProtoSequence() + case 9: obj["optional_type"] = self._parse_TypeProtoOptional() + case _: self.reader.skip_field(wire_type) + return obj + + def _parse_TypeProtoTensor(self) -> dict: + obj: dict[str, Any] = {} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["elem_type"] = self.reader.read_int64() + case 2: obj["shape"] = self._parse_TensorShapeProto() + case _: self.reader.skip_field(wire_type) + return obj + + def _parse_TypeProtoSequence(self) -> dict: + obj = {} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["elem_type"] = self._parse_TypeProto() + case _: self.reader.skip_field(wire_type) + return obj + + def _parse_TypeProtoOptional(self) -> dict: + obj = {} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["elem_type"] = self._parse_TypeProto() + case _: self.reader.skip_field(wire_type) + return obj + + def _parse_TensorShapeProto(self) -> dict: + obj: dict[str, Any] = {"dim": []} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["dim"].append(self._parse_TensorShapeProtoDimension()) + case _: self.reader.skip_field(wire_type) + return obj + + def _parse_TensorShapeProtoDimension(self) -> dict: + obj: dict[str, Any] = {} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["dim_value"] = self.reader.read_int64() + case 2: obj["dim_param"] = self.reader.read_string() + case _: self.reader.skip_field(wire_type) + return obj + + def _parse_StringStringEntryProto(self) -> dict: + obj: dict[str, Any] = {} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["key"] = self.reader.read_string() + case 2: obj["value"] = self.reader.read_string() + case _: self.reader.skip_field(wire_type) + return obj + + def _parse_OperatorSetIdProto(self) -> dict: + obj: dict[str, Any] = {} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["domain"] = self.reader.read_string() + case 2: obj["version"] = self.reader.read_int64() + case _: self.reader.skip_field(wire_type) + return obj + +# ***** python const ***** +required_input_python_consts: dict[str, tuple[int, ...]] = { + "Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,), + "CumSum": (1,), "TopK": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,), + "ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4), + **{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")}, + **{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")} +} + +cache_misses = 0 +@functools.cache +def _cached_to_python_const(t:Tensor): + if t.dtype == dtypes.uint8: return t.data().tobytes() + if 0 in t.shape: return [] + return t.tolist() + +# Tensor -> python value cache for parameters +def to_python_const(t:Any, op:str, idx:int) -> list[ConstType]|ConstType|bytes: + if idx not in required_input_python_consts.get(op, ()) or not isinstance(t, Tensor): return t + global cache_misses + ret = _cached_to_python_const(t) + if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3: + print(f"Cache miss for {t}") + cache_misses = info.misses + return ret + +# ***** runner ****** +debug = int(getenv("DEBUGONNX", "0")) +limit = int(getenv("ONNXLIMIT", "-1")) +class OnnxRunner: + """ + `OnnxRunner` executes an ONNX model using Tinygrad. + + Args: + model_path: The ONNX model, provided as a file path (a string or Path object) or a Tensor. + """ + def __init__(self, model_path: Tensor | str | pathlib.Path): + model = OnnxPBParser(model_path, load_external_data=True).parse() + graph = model["graph"] + self.is_training = any(n['parsed_node'].opset_id.domain in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in graph["node"]) + self.graph_values = {"": None, **{i["name"]: i["parsed_tensor"] for i in graph["initializer"]}} + self.graph_inputs = {i["name"]: i["parsed_type"] for i in graph["input"] if i["name"] not in self.graph_values} + self.graph_outputs = tuple(o["name"] for o in graph["output"]) + self.graph_nodes = tuple(n["parsed_node"] for n in graph["node"]) + + self.old_training = Tensor.training + Tensor.training = True if self.is_training else False + + self.variable_dims: dict[str, int] = {} + self.onnx_ops = onnx_ops + + def _parse_input(self, name: str, value: Any, spec: OnnxValue): + if spec.is_optional and value is None: return None + if spec.is_sequence: + if not isinstance(value, Sequence): raise RuntimeError(f"input {name} received {value}, expected a sequence type") + sequence = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value] + if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"Shapes for input {name} sequence must be homogeneous") + if not all(t.dtype is spec.dtype for t in sequence): warnings.warn(f"Dtypes for input {name} sequence aren't all {spec.dtype}") + return sequence + dtype = _from_np_dtype(value.dtype) if is_numpy_ndarray(value) else spec.dtype + tensor = Tensor(value, dtype=dtype, requires_grad=self.is_training) if not isinstance(value, Tensor) else value + if tensor.dtype is not spec.dtype: warnings.warn(f"input {name} has mismatch on dtype. Expected {spec.dtype}, received {tensor.dtype}.") + for dim, (onnx_dim, user_dim_input) in enumerate(zip(spec.shape, tensor.shape, strict=True)): + if isinstance(onnx_dim, str): + onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input)) + if user_dim_input != onnx_dim: raise RuntimeError(f"input {name} has mismatch on {dim=}. Expected {onnx_dim}, received {user_dim_input}.") + return tensor + + def _select_op(self, op:str, required_opset:OpSetId) -> types.FunctionType: + if op not in self.onnx_ops: raise NotImplementedError(f"{op=} is not supported") + # return default implementation if no opset_id is specified + if isinstance(impl := self.onnx_ops[op], types.FunctionType): return impl + # match domain and select implementation with latest compatible version + eligible_ops = {impl_opset.version:impl_fxn for impl_opset,impl_fxn in impl.items() + if impl_opset.domain == required_opset.domain and impl_opset.version <= required_opset.version} + if not eligible_ops: raise NotImplementedError(f"{op=} is not supported for domain {required_opset.domain} and version {required_opset.version}") + return eligible_ops[max(eligible_ops.keys())] + + def get_empty_input_data(self, device:str|None=None, dtype:DType|None=None) -> dict[str, Tensor]: + return {name:Tensor.empty(*spec.shape, device=device, dtype=dtype or spec.dtype) for name, spec in self.graph_inputs.items()} + + def to(self, device:str|None): + self.graph_values = {k:v.to(device) if isinstance(v, Tensor) else v for k,v in self.graph_values.items()} + self.graph_nodes = tuple(OnnxNode(n.op, n.opset_id, tuple(n.inputs), tuple(n.outputs), + {k:v.to(device) if isinstance(v, Tensor) else v for k,v in n.opts.items()}) for n in self.graph_nodes) + return self + + def __call__(self, inputs:dict[str, Any], debug=debug): + for name, input_spec in self.graph_inputs.items(): + if name not in inputs: raise RuntimeError(f"Please provide input data for {name}") + self.graph_values[name] = self._parse_input(name, inputs[name], input_spec) + + for num, node in enumerate(self.graph_nodes): + inps = [to_python_const(self.graph_values[name], node.op, i) for i,name in enumerate(node.inputs)] + opts = node.opts + + # provide additional opts + if node.op == "Split" and 'num_outputs' not in opts: opts['num_outputs'] = len(node.outputs) + if node.op == "Gradient": opts['intermediate_tensors'] = self.graph_values + + if debug >= 1: print(f"{num}: op '{node.op}' opt {opts}") + if debug >= 2 and node.inputs: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {i!r}" for x,i in zip(node.inputs, inps))) + ret = self._select_op(node.op, node.opset_id)(*inps, **opts) + ret = ret if isinstance(ret, tuple) else (ret,) + if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{x} - {o!r}" for x,o in zip(node.outputs, ret))) + + self.graph_values.update(dict(zip(node.outputs, ret[:len(node.outputs)], strict=True))) + + if num == limit: + Tensor.training = self.old_training + return {name:self.graph_values[name] for name in node.outputs} + Tensor.training = self.old_training + return {name:self.graph_values[name] for name in self.graph_outputs} + +#################### +##### ONNX OPS ##### +#################### +def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionType]]: + # ***** helper functions ***** + def _resolve_const(x: Sequence[ConstType]|ConstType): return get_single_element(x) if isinstance(x, Sequence) else x + + def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None) + + # (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...) + def _onnx_pads_to_tiny_pads(pads): return tuple(flatten(reversed(list(zip(pads, pads[len(pads)//2:]))))) + + AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"] + # (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right) + def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS): + if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))] + return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))] + + def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS): + if auto_pad == "VALID": return [0]*(len(k_)*2) + i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_)) + if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2) + o_ = [((i - (1 if auto_pad in ("SAME_UPPER", "SAME_LOWER") else k)) // s + 1) for i,k,s in zip(i_, k_, s_)] + return _onnx_pads_to_tiny_pads(_auto_pad([(o-1)*s+k-i for o,i,k,s in zip(o_, i_, k_, s_)], auto_pad)) + + def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype) + + def _prepare_quantize(x:Tensor, scale:Tensor, zero_point:Tensor|int, axis=1, block_size=0): + if axis < 0: axis += x.ndim + # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_quantize_linear.py#L31 + def reshape(val:Tensor): + if val.numel() == 1: return val + if block_size == 0: return val.reshape([val.shape[0] if dim == axis else 1 for dim in range(x.ndim)]) + return val.repeat_interleave(block_size, axis) + return (reshape(scale), reshape(zero_point) if isinstance(zero_point, Tensor) else zero_point) + + def _op_integer(op, inputs:list[Tensor], zero_points:list[Tensor], **opts): + adjusted_inputs = [inp.int() - zp for inp, zp in zip(inputs, zero_points)] + return op(*adjusted_inputs, **opts) + + def _qlinearop_quantized(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts): + # op execution is done in quantized int + out = _op_integer(op, inputs, zero_points, **opts) + assert dtypes.is_int(out.dtype), "quantized op should've done math in int" + out_quantized = (out * prod(scales) / out_scale).round() + out_zero_point + return _clamp_cast(out_quantized, out_zero_point.dtype) + + def _qlinearop_float(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts): + # op execution is done in float32 + dequantized_inputs = [(inp.int() - zp) * scale for inp, zp, scale in zip(inputs, zero_points, scales)] + out = op(*dequantized_inputs, **opts) + assert dtypes.is_float(out.dtype), "op should've done math in float" + out_quantized = (out / out_scale).round() + out_zero_point + return _clamp_cast(out_quantized, out_zero_point.dtype) + + def _onnx_training(input_group_size): + def __decorator(func): + def ___wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs): + R = R.detach() + groups = len(inputs) // input_group_size + ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))] + return tuple(flatten(zip(*ret))) + return ___wrapper + return __decorator + + # ***** Property/Graph Ops ***** + def Identity(x:Tensor): return x + def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, value_floats:list[float]|None=None, + value_int:int|None=None, value_ints:list[int]|None=None, value_string:str|None=None, value_strings:list[str]|None=None): + if value is not None: return value + if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False) + if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False) + if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False) + if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False) + if value_string is not None or value_strings is not None or sparse_value is not None: + raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value') + + def Range(start:float|int|list[float|int], limit:float|int|list[float|int], delta:float|int|list[float|int]): + return Tensor.arange(start=_resolve_const(start), stop=_resolve_const(limit), step=_resolve_const(delta)) + + def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"): + try: import PIL.Image + except ImportError as e: raise ImportError("Pillow must be installed for the ImageDecoder operator") from e + img = PIL.Image.open(io.BytesIO(encoded_stream)) + if pixel_format == "BGR": return Tensor(img.tobytes(), dtype=dtypes.uint8).reshape(*img.size, 3).flip(-1) + if pixel_format == "RGB": return Tensor(img.tobytes(), dtype=dtypes.uint8).reshape(*img.size, 3) + if pixel_format == "Grayscale": return Tensor(img.convert("L").tobytes(), dtype=dtypes.uint8).reshape(*img.size, 1) + raise ValueError(f"pixel_format={pixel_format!r} is not supported.") + + def EyeLike(x:Tensor, dtype:int|None=None, k:int=0): + ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_fallback(OnnxDataType(dtype).to_dtype(), "EyeLike op") if dtype is not None else x.dtype) + return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape)) + + def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0) + def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([]) + def ConstantOfShape(shape:list[int], value:Tensor|None=None): + if value is None: value = Tensor(0, dtype=dtypes.float32) + if shape == [0]: return Tensor([], dtype=value.dtype) + return value.expand(shape) + + def Size(data:Tensor): return data.numel() + def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64) + + # ***** Unary Ops (math) ***** + def Not(x:Tensor): return x.logical_not() + def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None): return x if min is None and max is None else x.clip(min, max) # noqa: A002 # pylint: disable=redefined-builtin + def IsInf(x:Tensor, detect_negative:int=1, detect_positive:int=1): return x.isinf(bool(detect_positive), bool(detect_negative)) + + # ***** Unary Ops (activation) ***** + def softmax_1(x:Tensor, axis:int=1): return x.softmax(axis) + def softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis) + Softmax = {OpSetId(Domain.ONNX, 1):softmax_1, OpSetId(Domain.ONNX, 13):softmax_13} + def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1) + def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf()) + def BiasGelu(x: Tensor, bias: Tensor, approximate: str | None = None) -> Tensor: return Gelu(x + bias, approximate) + def FastGelu(x:Tensor, bias:Tensor|None=None): return (x + bias).gelu() if bias is not None else x.gelu() # this is tanh approximated + def PRelu(X:Tensor, slope:Tensor): return (X > 0).where(X, X * slope) + def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leaky_relu(alpha) + def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0) + def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis) + def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float() + + # ***** Unary Ops (broadcasted) ***** + def Add(x:Tensor,y:Tensor, broadcast=None, axis=None): return x + y + def Sub(x:Tensor|int,y:Tensor): return x - y # some test has input as int + def Div(x:Tensor,y:Tensor): return x.div(y, rounding_mode='trunc' if dtypes.is_int(x.dtype) else None) + def Less(x:Tensor,y:Tensor): return x < y + def LessOrEqual(x:Tensor,y:Tensor): return x <= y + def Greater(x:Tensor,y:Tensor): return x > y + def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y + def Equal(x:Tensor,y:Tensor): return x == y + def And(x:Tensor,y:Tensor): return (x==y).where(x, False) + def Or(x:Tensor,y:Tensor): return (x==y).where(x, True) + def Xor(x:Tensor,y:Tensor): return x.bool().bitwise_xor(y.bool()) + def BitwiseAnd(x:Tensor,y:Tensor): return x & y + def BitwiseOr(x:Tensor,y:Tensor): return x | y + def BitwiseXor(x:Tensor,y:Tensor): return x ^ y + def BitwiseNot(x:Tensor): return ~x + def Mod(x:Tensor,y:Tensor,fmod=0): return x - x.div(y, rounding_mode="trunc") * y if fmod else x % y + + # ***** Casting Ops ***** + # TODO: saturate + def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_fallback(OnnxDataType(to).to_dtype(), "Cast op")) + def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype) + + # ***** Reduce Ops ***** + def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0) + def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0) + def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0) + def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0) + def ReduceMax(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims) + def ReduceMin(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims) + def ReduceSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims) + def ReduceMean(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims) + def ReduceSumSquare(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return ReduceSum(data.square(), axes, keepdims, noop_with_empty_axes) + def ReduceProd(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return data.prod(_axes(axes, noop_with_empty_axes), keepdim=keepdims) + def ReduceL1(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes) + def ReduceL2(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt() + def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log() + def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log() + def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0): + if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64) + return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64) + def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0): + return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index) + + # ***** Movement Ops ***** + def Reshape(data:Tensor, shape:list[int], allowzero:int=0): + return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)]) + def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1) + def Expand(x:Tensor, shape:list[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape))) + def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias) + def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=perm or list(range(x.ndim)[::-1])) + + def Squeeze(data:Tensor, axes:list[int]|None=None): + return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data) + def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data) + + def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats) + def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis) + def Slice(data:Tensor, starts:list[int], ends:list[int], axes:list[int]|None=None, steps:list[int]|None=None): + axes = axes or list(range(data.ndim)) + steps = steps or [1]*data.ndim + slices = [slice(0,x,1) for x in data.shape] + for i, axis in enumerate(axes): slices[axis] = slice(starts[i], ends[i], steps[i]) + return data[tuple(slices)] + + def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0): + sz = data.shape[axis] + if split is None: split = [sz // num_outputs + (1 if i < sz % num_outputs else 0) for i in range(num_outputs)] + return data.split(split, axis) + + def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None, + mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0): + value = constant_value or value + axes = axes or list(range(x.ndim)) + real_pads = [0] * (x.ndim*2) + for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)] + return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value) + + def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None): + shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim + pad_arg:list[None|tuple[int,int]] = [None] * t.ndim + for s, x in zip(shape, axes or range(t.ndim)): + tx = t.shape[x] + if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2) + elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2) + return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg)) + + # ***** Processing Ops ***** + def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0, + dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1): + pool_pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad) + return X.avg_pool2d(tuple(kernel_shape), strides, dilations, pool_pads, ceil_mode=ceil_mode, count_include_pad=count_include_pad) + + def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0, + storage_order:int=0, strides:list[int]|int=1): + pool_pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad) + out = X.max_pool2d(tuple(kernel_shape), strides, dilations, pool_pads, ceil_mode=ceil_mode, return_indices=True) + ret, idx = cast(tuple[Tensor, Tensor], out) + return ret, idx.transpose(-2, -1).cast(dtypes.int64) if storage_order else idx.cast(dtypes.int64) + + def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1, + kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1): + return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, + padding=_resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad)) + + def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1, + kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0, + strides:list[int]|int=1): + input_shape_, kernel_shape_ = X.shape[2:], (kernel_shape or W.shape[2:]) + strides_, dilations_, output_padding_ = (make_tuple(x, len(input_shape_)) for x in (strides, dilations, output_padding)) + if output_shape is not None: # we pad according to output_shape + pads = _auto_pad([s_*(i-1) + op_ + ((k_-1)*d_+1) - os for s_,i,op_,k_,d_,os in + zip(strides_, input_shape_, output_padding_, kernel_shape_, dilations_, output_shape)], auto_pad) + if pads is None: # we generate pads + output_shape = output_shape or [X.shape[i+2] * strides_[i] for i in range(len(strides_))] + pads = [strides_[i]*(input_shape_[i]-1)+output_padding_[i]+((kernel_shape_[i]-1)*dilations_[i]+1)-output_shape[i] + for i in range(len(input_shape_))] + pads = _auto_pad(pads, auto_pad) if auto_pad != "NOTSET" else [0] * len(input_shape_) * 2 + pads = _onnx_pads_to_tiny_pads(pads) + return X.conv_transpose2d(W, B, group, strides_, dilations_, pads, output_padding_) + + def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]|None=None, pads:list[int]|int=0, + strides:list[int]|int=1): + if kernel_shape is None: kernel_shape = [] + pads_: int | tuple[int, ...] = tuple(pads) if isinstance(pads, list) else pads + return Tensor.max_unpool2d(xT, xI, tuple(kernel_shape), strides, 1, pads_, outshape if outshape is None else tuple(outshape)) + + def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True) + def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True) + + def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0): + ret = alpha * (A.transpose(transA) @ B.transpose(transB)) + if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1])) + return ret + + def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs) + + def CumSum(X:Tensor, axis:int|list[int], exclusive:int=0, reverse:int=0): + axis = X._resolve_dim(_resolve_const(axis)) + if reverse: X = X.flip(axis) + if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\ + .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim))) + return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis) + + def Trilu(x:Tensor, k:int|list[int]=0, upper:int=1): + k_ = _resolve_const(k) + return x.triu(k_) if upper else x.tril(k_) + + def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, sizes:list[int]|None=None, antialias:int=0, + axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0, + extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'): + def _apply_transformation(input_sz, output_sz, scale_dim, mode): + index = Tensor.arange(output_sz, requires_grad=False, device=X.device) + if mode == "half_pixel": return (index + 0.5) / scale_dim - 0.5 + if mode == "align_corners": return index * (input_sz - 1) / (output_sz - 1) if output_sz != 1 else Tensor.zeros_like(index) + if mode == "asymmetric": return index / scale_dim + if mode == "pytorch_half_pixel": return ((index + 0.5) / scale_dim - 0.5) if output_sz != 1 else Tensor.zeros_like(index) + if mode == "half_pixel_symmetric": + output_dim_scaled = input_sz * scale_dim + return (input_sz / 2) * (1 - (output_sz / output_dim_scaled)) + (index + 0.5) / scale_dim - 0.5 + raise ValueError(f"invalid {coordinate_transformation_mode=}") + + if antialias: raise NotImplementedError("antialias is not implemented") + axes = axes or list(range(X.ndim)) + perm = [a for a in range(len(X.shape)) if a not in axes] + list(axes) + # we pre-permute the axes and permute back after resize + # the permute aligns X's axes to scales, sizes, and roi + X = X.permute(*perm) + + input_shape = cast(tuple[int, ...], X.shape[2:]) + if scales is not None: assert all(sc==1 for sc in scales[:-len(input_shape)]), "resizing batch_size dim or channel dim not supported" + if sizes is not None: assert tuple(sizes[:-2]) == tuple(X.shape[X.ndim-len(sizes):-2]), "resizing batch_size dim or channel dim not supported" + + scales, sizes = (None if scales is None else scales[-len(input_shape):]), (None if sizes is None else sizes[-len(input_shape):]) + if sizes is not None: + if keep_aspect_ratio_policy in ["not_larger", "not_smaller"]: + scale_fxn = min if keep_aspect_ratio_policy == "not_larger" else max + scale = scale_fxn(sz / sh for sz,sh in zip(sizes, input_shape)) + sizes, scales = [int(scale * sh + 0.5) for sh in input_shape], [scale]*len(input_shape) + else: scales = [sz / sh for sz, sh in zip(sizes, input_shape)] + else: + assert scales is not None, "either sizes or scales must be provided" + sizes = [int(sc * sh) for sc, sh in zip(scales, input_shape)] + + if all(sz == sh for sz, sh in zip(sizes, input_shape)): return X.permute(*argsort(perm)) if perm else X + + indexes = [] + for input_sz, output_sz, scale in zip(input_shape, sizes, scales): + indexes.append(_apply_transformation(input_sz, output_sz, scale, coordinate_transformation_mode)) + + if mode in ["nearest", "linear"]: indexes = [idx.clip(0, sz-1) for idx, sz in zip(indexes, input_shape)] + + if mode == "nearest": + mode_operations = { + "round_prefer_floor": lambda idx: (idx - 0.5).ceil(), + "round_prefer_ceil": lambda idx: (idx + 0.5).floor(), + "floor": lambda idx: idx.floor(), + "ceil": lambda idx: idx.ceil() + } + if nearest_mode not in mode_operations: raise ValueError(f"invalid {nearest_mode=}") + indexes = [mode_operations[nearest_mode](idx).int() for idx in indexes] + X = X[(..., *Tensor.meshgrid(*indexes))] + + if mode == "linear": + expand = list(X.shape) + for i in range(-len(sizes), 0): + reshape, index = [1] * X.ndim, indexes[i] + reshape[i] = expand[i] = sizes[i] + low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor().int(), index.ceil().int(), index - index.floor())] + X = X.gather(i, low).lerp(X.gather(i, high), perc) + + if mode == "cubic": + A = cubic_coeff_a + + # Keys weights + # see piecewise function in: https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm + def W0_1(x:Tensor): return polyN(x, [A + 2, -(A + 3), 0, 1]) + def W1_2(x: Tensor): return polyN(x, [A, -5 * A, 8 * A, -4 * A]) + + expand = list(X.shape) + for i in range(-len(sizes), 0): + input_sz = cast(int, X.shape[i]) + reshape, index = [1] * X.ndim, indexes[i] + reshape[i] = expand[i] = sizes[i] + + p = index.floor().int() + ratio = index - p # in [0, 1] + + # Neighbor indices + idx0, idx1, idx2, idx3 = [p + d for d in [-1, 0, 1, 2]] + # Weights of distance from index and neighbor indices + c0, c1, c2, c3 = W1_2(ratio+1), W0_1(ratio), W0_1(-(ratio-1)), W1_2(-(ratio-2)) + + if exclude_outside: + c0 = ((idx0 >= 0) & (idx0 < input_sz)).where(c0, 0) + c1 = ((idx1 >= 0) & (idx1 < input_sz)).where(c1, 0) + c2 = ((idx2 >= 0) & (idx2 < input_sz)).where(c2, 0) + c3 = ((idx3 >= 0) & (idx3 < input_sz)).where(c3, 0) + + total = c0 + c1 + c2 + c3 + c0, c1, c2, c3 = c0 / (total + 1e-9), c1 / (total + 1e-9), c2 / (total + 1e-9), c3 / (total + 1e-9) + + # Reshape and expand + expanded_indices = [y.clip(0, input_sz - 1).reshape(reshape).expand(expand) for y in [idx0, idx1, idx2, idx3]] + expanded_coeffs = [y.reshape(reshape).expand(expand) for y in [c0, c1, c2, c3]] + + # Gather values and apply coefficients + gathered_values = [X.gather(i, idx) for idx in expanded_indices] + X = sum(v * c for v, c in zip(gathered_values, expanded_coeffs)) + return X.permute(*argsort(perm)) if perm else X + def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated + + def TopK(X:Tensor, K:int|list[int], axis:int=-1, largest:int=1, sorted:int=1): # noqa: A002 # pylint: disable=redefined-builtin + val, idx = X.topk(_resolve_const(K), axis, bool(largest), bool(sorted)) + return val, idx.cast(dtypes.int64) + + # ***** Neural Network Ops ***** + def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9, + training_mode:int=0, spatial=1, is_test=0): + if training_mode: + x_detached = X.detach() + current_mean = x_detached.mean(axis=(0,2,3)) + y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1])) + current_var = (y*y).mean(axis=(0,2,3)) + current_invstd = current_var.add(epsilon).rsqrt() + + running_mean = input_mean * momentum + current_mean * (1 - momentum) + running_var = input_var * momentum + current_var * (1 - momentum) + + return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var + return X.batchnorm(scale, B, input_mean, (input_var + epsilon).rsqrt()) + def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05): + x = x.reshape(x.shape[0], num_groups, -1).layernorm(eps=epsilon).reshape(x.shape) + return x * scale.reshape(1, -1, *[1] * (x.ndim-2)) + bias.reshape(1, -1, *[1] * (x.ndim-2)) + def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05): + return GroupNormalization(x, scale, bias, num_groups=cast(int, x.shape[1]), epsilon=epsilon) + def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1): + assert stash_type == 1, "only float32 is supported" + axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim)) + mean = x.mean(axis=axes, keepdim=True) + return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt() + def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12): + x = x + skip + if bias is not None: x = x + bias + ret = x.layernorm(eps=epsilon) * gamma + if beta is not None: ret = ret + beta + return ret, None, None, x + def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor, + segment_embedding:Tensor, gamma=None, beta=None, mask:Tensor|None=None, + position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0): + # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization + assert (segment_ids is None) is (segment_embedding is None) + assert mask is None and not mask_index_type, "functionality not supported yet" # TODO + input_shape = input_ids.shape + seq_length = input_shape[1] + compute_seg_emb = (segment_embedding is not None and segment_ids is not None) + vocab_size, max_position_embeddings = word_embedding.shape[0], position_embedding.shape[0] + type_vocab_size = (segment_embedding.shape[0] if compute_seg_emb else None) + + def embedding(x:Tensor, vocab_size, weight:Tensor) -> Tensor: + return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight + + # bert embedding layer + if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape) + wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding) + pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding) + seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None + + embedding_sum = wrd_embedding_res + pos_embedding_res + if seg_embedding_res is not None: embedding_sum = embedding_sum + seg_embedding_res + out = embedding_sum.layernorm(eps=epsilon) * gamma + beta + return out, None, embedding_sum + def MeanVarianceNormalization(x:Tensor, axis:list[int]|None=None): + if axis is None: axis = [0,2,3] + return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9) + + def OneHot(indices:Tensor, depth:float|int|list[int|float], values:Tensor, axis:int=-1): + # Scalar or Rank 1 tensor containing exactly one element + depth = int(_resolve_const(depth)) + indices = indices.int() + indices = (indices < 0).where(indices+depth, indices) + return indices.unsqueeze(axis)._one_hot_along_dim(depth, dim=axis).where(values[1], values[0]) + + def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"): + return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize) + def SpaceToDepth(X:Tensor, blocksize:int): + return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize) + + # Reimplemented here because you need legacy RNG for passing ONNX tests. + def dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None): + import numpy as np + if not training_mode: return data, data.full_like(True, dtype=dtypes.bool) + if seed is not None: + rand = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)), requires_grad=False, dtype=data.dtype, device=data.device) + else: + rand = data.rand_like(requires_grad=False) + mask = rand >= ratio + return data * mask / (1.0 - ratio), mask + # 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx + def dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return dropout_7(data, ratio, training_mode=not is_test) + Dropout = {OpSetId(Domain.ONNX, 6):dropout_6, OpSetId(Domain.ONNX, 7):dropout_7} + + def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0): + pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1) + return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta) + + def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): + return x.nll_loss(target, weight, ignore_index, reduction) + def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): + log_probs = scores.log_softmax(1) + return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs + + def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0): + N, _, *spatial_dims = size + def generate_grid(steps): + if align_corners: return Tensor.linspace(-1, 1, steps, device=theta.device) + return Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device) + grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims)) + base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1) + base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1) + return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1) + + def attention_contrib(x:Tensor, weights:Tensor, bias:Tensor|None=None, mask_index:Tensor|None=None, past:Tensor|None=None, + attention_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int=0, mask_filter_value:float=-10000.0, + num_heads:int|None=None, past_present_share_buffer:int|None=None, qkv_hidden_sizes:list[int]|None=None, + rotary_embedding_dim:int|None=None, scale:float|None=None, unidirectional:int=0): + assert not do_rotary and not attention_bias, "TODO" + if qkv_hidden_sizes is None: qkv_hidden_sizes = [weights.shape[1] // 3] * 3 + qkv = x.linear(weights, bias) + q, k, v = qkv.split(qkv_hidden_sizes, dim=2) + + batch_size, seq_len, _ = x.shape + assert num_heads is not None, "num_heads must be provided" + q_head_size, k_head_size, v_head_size = (sz // num_heads for sz in qkv_hidden_sizes) + q, k, v = (x.reshape(batch_size, seq_len, num_heads, hsz).transpose(1, 2) for x, hsz in zip((q, k, v), (q_head_size, k_head_size, v_head_size))) + + present = None + if past is not None: + k, v = past[0].cat(k, dim=2), past[1].cat(v, dim=2) + present = k.stack(v) + + if scale is None: scale = 1.0 / math.sqrt(q_head_size) + attn_scores = q @ k.transpose(-1, -2) * scale + + if mask_index is not None: + assert 4 >= mask_index.ndim >= 1, f"{mask_index.ndim=}" + assert isinstance(batch_size, int), f"{batch_size=}" + if mask_index.ndim != 1: mask = mask_index.bool() + else: + if mask_index.shape[0] == batch_size: + mask = Tensor.arange(attn_scores.shape[-1], requires_grad=False, device=mask_index.device).unsqueeze(0) < mask_index.unsqueeze(1) + elif mask_index.shape[0] == 2*batch_size: + end_positions = mask_index[:batch_size] + start_positions = mask_index[batch_size:] + arange = Tensor.arange(seq_len).unsqueeze(0) + mask = (arange < end_positions.unsqueeze(1)) & (arange >= start_positions.unsqueeze(1)) + else: raise NotImplementedError("mask_index with shape (3 * batch_size + 2) is not implemented") + while mask.ndim < 4: mask = mask.unsqueeze(1) + attn_scores = mask.where(attn_scores, mask_filter_value) + + if unidirectional: + causal_mask = Tensor.ones((seq_len, seq_len), dtype=dtypes.bool).tril() + attn_scores = causal_mask.where(attn_scores, mask_filter_value) + + output = attn_scores.softmax(-1) @ v + output = output.transpose(1, 2).reshape(batch_size, seq_len, -1) + return output, present + + def attention_onnx(Q:Tensor, K:Tensor, V:Tensor, attn_mask:Tensor|None=None, past_key:Tensor|None=None, past_value:Tensor|None=None, + is_causal:int=0, kv_num_heads:int|None=None, q_num_heads:int|None=None, qk_matmul_output_mode:int=0, scale:float|None=None, + softcap:float=0.0, softmax_precision:int|None=None): + input_shape_len = Q.ndim + if input_shape_len == 3: + assert q_num_heads is not None and kv_num_heads is not None + Q = Q.reshape(Q.shape[0], q_num_heads, Q.shape[1], -1) + K = K.reshape(K.shape[0], kv_num_heads, K.shape[1], -1) + V = V.reshape(V.shape[0], kv_num_heads, V.shape[1], -1) + + if past_key is not None: K = past_key.cat(K, dim=2) + if past_value is not None: V = past_value.cat(V, dim=2) + present_key, present_value = K, V + + _q_heads, _kv_heads = q_num_heads or Q.shape[1], kv_num_heads or K.shape[1] + if _q_heads != _kv_heads: + K = K.repeat((1, _q_heads // _kv_heads, 1, 1)) + V = V.repeat((1, _q_heads // _kv_heads, 1, 1)) + + effective_scale = scale if scale is not None else 1.0 / (cast(int, Q.shape[-1]) ** 0.5) + scores = (Q @ K.transpose(-1, -2)) * effective_scale + qk_matmul_return_val = scores + + if is_causal: + causal_mask = Tensor.ones(Q.shape[-2], K.shape[-2], device=Q.device, dtype=dtypes.bool, requires_grad=False).tril(0) + scores = scores.masked_fill(causal_mask.logical_not(), -float("inf")) + + if attn_mask is not None: + mask_to_add = attn_mask.where(0, -float("inf")) if attn_mask.dtype == dtypes.bool else attn_mask + scores = scores + mask_to_add + if qk_matmul_output_mode == 1: qk_matmul_return_val = scores + + if softcap > 0.0: scores = (scores / softcap).tanh() * softcap + if qk_matmul_output_mode == 2: qk_matmul_return_val = scores + + if softmax_precision: scores = scores.cast({1: dtypes.float32, 10: dtypes.float16, 16: dtypes.bfloat16}[softmax_precision]) + qk_softmax = scores.softmax(-1).cast(Q.dtype) + if qk_matmul_output_mode == 3: qk_matmul_return_val = qk_softmax + + output = (qk_softmax @ V).cast(Q.dtype) + if input_shape_len == 3: output = output.permute(0, 2, 1, 3).reshape(Q.shape[0], Q.shape[2], -1) + return output, present_key, present_value, qk_matmul_return_val + Attention = {OpSetId(Domain.ONNX, 1): attention_onnx, OpSetId(Domain.MICROSOFT_CONTRIB_OPS, 1): attention_contrib} + + def RMSNormalization(X:Tensor, scale:Tensor, axis:int=-1, epsilon:float=1e-5): + norm = X.square().mean(axis=tuple(range(axis + X.ndim if axis < 0 else axis, X.ndim)), keepdim=True).add(epsilon).rsqrt() + return X * norm * scale + + def RotaryEmbedding(X:Tensor, cos_cache:Tensor, sin_cache:Tensor, position_ids:Tensor|None=None, interleaved:int=0, num_heads:int|None=None, + rotary_embedding_dim:int=0): + original_input_shape = X.shape + + if X.ndim == 4: X = X.permute(0, 2, 1, 3) + elif X.ndim == 3: + assert num_heads is not None, "num_heads must be provided for 3D input" + X = X.reshape(*X.shape[:-1], num_heads, X.shape[-1] // num_heads) + + head_size = cast(int, X.shape[-1]) + rot_dim = rotary_embedding_dim or head_size + x_rotate, x_pass = X[..., :rot_dim], X[..., rot_dim:] + + cos = cos_cache[position_ids] if position_ids is not None else cos_cache[:head_size] + sin = sin_cache[position_ids] if position_ids is not None else sin_cache[:head_size] + cos = cos[..., :rot_dim//2].unsqueeze(2) + sin = sin[..., :rot_dim//2].unsqueeze(2) + + if interleaved: + x1, x2 = x_rotate[..., ::2], x_rotate[..., 1::2] + real = x1 * cos - x2 * sin + imag = x1 * sin + x2 * cos + x_rotated = Tensor.stack(real, imag, dim=-1).flatten(start_dim=-2) + else: + x1, x2 = x_rotate.chunk(2, dim=-1) + real = x1 * cos - x2 * sin + imag = x1 * sin + x2 * cos + x_rotated = real.cat(imag, dim=-1) + + output = x_rotated.cat(x_pass, dim=-1) + return output.flatten(start_dim=2) if len(original_input_shape) == 3 else output.permute(0, 2, 1, 3) + + # ***** Indexing Ops ***** + def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices] + + def Gather(x:Tensor, indices:Tensor, axis:int=0): + if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices + ret_shape = x.shape[:axis] + indices.shape + x.shape[axis+1:] + if indices.ndim > 1: indices = indices.flatten() + index_consts = [_cached_to_python_const(indices)] if indices.shape == () else _cached_to_python_const(indices) + index_consts = [x.shape[axis]+i if i<0 else i for i in index_consts] + args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x.shape)] for i in index_consts] + return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape) + # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot + return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])] + def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated + + def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0): + if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))] + x_shape, i_shape = x.shape, indices.shape + b = math.prod(x.shape[dim] for dim in range(batch_dims)) + # NOTE: each batched dim of both input and indices are equal + x = x.reshape(b, *x.shape[batch_dims:]) + indices = indices.reshape(b, *indices.shape[batch_dims:]) + b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1]) + ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))] + return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:]) + def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'): + assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):] + x = x.contiguous() + for index, u in zip(indices.split(1, 0), updates.split(1, 0)): + i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1)) + u = u.squeeze(0) + if reduction == "none": x[i] = u + elif reduction == "add": x[i] += u + elif reduction == "mul": x[i] *= u + else: raise NotImplementedError("reduction doesn't support max or min") + return x + + def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul", "min", "max"]="none"): + indices = (indices < 0).where(x.shape[axis], 0) + indices + if reduction == "none": return x.scatter(axis, indices, updates) + reduction_ = cast(Literal["sum", "prod", "amin", "amax"], {"add": "sum", "mul": "prod", "min": "amin", "max": "amax"}[reduction]) + return x.scatter_reduce(axis, indices, updates, reduction_) + def GatherElements(x:Tensor, indices:Tensor, axis:int): + indices = (indices < 0).where(x.shape[axis], 0) + indices + return x.gather(axis, indices) + + def Compress(inp:Tensor, condition:list[bool], axis:int|None=None): + if axis is None: + inp = inp.flatten() + axis = 0 + axis = inp._resolve_dim(axis) + con = Tensor([i for i,cond in enumerate(condition) if cond]) # compress in python + return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))] + + # ***** Quantization Ops ***** + def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1): + if isinstance(y_zero_point, Tensor): out_dtype = y_zero_point.dtype + elif output_dtype != 0: out_dtype = dtype_fallback(OnnxDataType(output_dtype).to_dtype(), "QuantizeLinear op") + else: out_dtype = dtypes.uint8 + y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size) + if out_dtype == dtypes.uchar: + # this appears to work in practice, at least for uchar out_dtype. it folds with the quantize stuff + ret = _clamp_cast((x / y_scale + 0.4999999 + y_zero_point).int(), out_dtype) + else: + ret = _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype) + return ret.contiguous() + + def DynamicQuantizeLinear(x: Tensor): + # only support uint8 + qmin, qmax = dtypes.min(dtypes.uint8), dtypes.max(dtypes.uint8) + scale = (x.max().maximum(0) + ((-x).max()).maximum(0)) / (qmax - qmin) + zero_point = _clamp_cast((qmin - x.min() / scale).round(), dtypes.uint8) + y = _clamp_cast((x / scale).round() + zero_point, dtypes.uint8) + return y, scale, zero_point + + def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0): + x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size) + return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype) + + def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor, w:Tensor, w_scale:Tensor, w_zero_point:Tensor, y_scale:Tensor, + y_zero_point:Tensor, B:Tensor|None=None, **opts): + return _qlinearop_quantized(Conv, [x,w], [x_zero_point,w_zero_point], [x_scale,w_scale], y_scale, y_zero_point, **{"B":B, **opts}) + + def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, y_scale:Tensor, + y_zero_point:Tensor) -> Tensor: + return _qlinearop_quantized(Tensor.matmul, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], y_scale, y_zero_point) + + def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor): + return _qlinearop_float(Tensor.add, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], c_scale, c_zero_point) + + def QLinearMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor): + return _qlinearop_quantized(Tensor.mul, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], c_scale, c_zero_point) + + def QLinearGlobalAveragePool(X:Tensor, x_scale:Tensor, x_zero_point:Tensor, y_scale:Tensor, y_zero_point:Tensor, channels_last:int): + assert channels_last == 0, "TODO NHWC" + return _qlinearop_float(GlobalAveragePool, [X], [x_zero_point], [x_scale], y_scale, y_zero_point) + + def ConvInteger(x: Tensor, w: Tensor, x_zero_point:Tensor = Tensor(0), w_zero_point:Tensor = Tensor(0), B: Tensor | None = None, **opts) -> Tensor: + return _op_integer(Conv, [x,w], [x_zero_point,w_zero_point], **{"B":B, **opts}) + + def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor = Tensor(0), b_zero_point: Tensor = Tensor(0)) -> Tensor: + return _op_integer(Tensor.matmul, [A,B], [a_zero_point,b_zero_point]) + + # ***** Training Ops ***** + # NOTE: onnx training ops actually don't need the state for optim, all the ops work in a functional way, but we still can reuse optim.py code + @_onnx_training(3) + def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0): + X, G, H = (i.detach() for i in inputs) + grad = norm_coefficient * X + G + H.assign(H + grad.square()) + up = grad / (H.sqrt() + epsilon) + r = R / (1 + T * decay_factor) + X.assign(X.detach() - r * up) + return [X, H] + + @_onnx_training(4) + def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0, + norm_coefficient_post:float=0.0): + from tinygrad.nn.optim import Adam as TinyAdam + X, G, V, H = inputs + G, V, H = G.detach(), V.detach(), H.detach() + X.grad = norm_coefficient * X.detach() + G + opt = TinyAdam([X], b1=alpha, b2=beta, eps=epsilon) + opt.m, opt.v, opt.lr = [V], [H], R + # need no-op for m_hat and v_hat if T == 0 + if T == 0: opt.b1_t, opt.b2_t = opt.b1_t.zeros_like(), opt.b2_t.zeros_like() + else: + # `T-1` since it's applied again at the start of `_step` + opt.b1_t = Tensor([alpha**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False) + opt.b2_t = Tensor([beta**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False) + opt.step() + X = (1 - norm_coefficient_post) * X + return [X, V, H] + + @_onnx_training(3) + def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float): + X, G, V = (i.detach() for i in inputs) + grad = norm_coefficient * X + G + # NOTE: this beta_adjusted term makes it so we can't use SGD for nesterov + beta_adjusted = beta if T > 0 else 1 + V.assign(alpha * V + grad * beta_adjusted) + X.assign(X - R * (V if mode == "standard" else (grad + alpha * V))) + return [X, V] + + def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **_): + intermediate_tensors[y].backward() + return tuple([t.grad for t in inputs]) + + return { + # Tensor ops + **{op: getattr(Tensor, op.lower()) for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", + "Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", + "Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Round", "Erf")}, + # Implemented ops + **{name:obj for name,obj in locals().items() if isinstance(obj, types.FunctionType) and not name.startswith("_") and name[0].isupper()}, + # Version ops + **{name:obj for name,obj in locals().items() if isinstance(obj, dict)}, + } + +onnx_ops = get_onnx_ops()