# tinygrad/extra/onnx.py
from typing import Callable, Any, Sequence
import importlib, functools
import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad.helpers import getenv, DEBUG, all_same
from tinygrad.dtype import DType, ConstType
from tinygrad.device import is_dtype_supported
from onnx import AttributeProto, ModelProto, TensorProto, ValueInfoProto, helper
from google.protobuf.json_format import MessageToDict
cache_misses = 0
@functools.lru_cache(None)
def _cached_to_python_const(t:Tensor):
  if t.dtype is dtypes.uint8: return t.data().tobytes()
  if 0 in t.shape: return []
  return t.tolist()
# Tensor -> python value cache for parameters
def to_python_const(t) -> list[ConstType]|ConstType|bytes:
  if not isinstance(t, Tensor): return t
  global cache_misses
  ret = _cached_to_python_const(t)
  if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3:
    print(f"Cache miss for {t}")
    cache_misses = info.misses
  return ret
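# example (for illustration): to_python_const(Tensor([2, 3], dtype=dtypes.int64)) -> [2, 3];
# uint8 tensors come back as raw bytes and zero-size tensors as []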
# TODO: use real float16
# src: onnx/mapping.py
DTYPE_MAP: dict[int, DType] = {
  TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8,
  TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64,
  TensorProto.BOOL:dtypes.bool, TensorProto.FLOAT16:dtypes.float32, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32,
  TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16,
}
def dtype_parse(onnx_dtype: int) -> DType:
  if onnx_dtype not in DTYPE_MAP: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported")
  return DTYPE_MAP[onnx_dtype] if is_dtype_supported(DTYPE_MAP[onnx_dtype]) else dtypes.float
# src: onnx/onnx_ml_pb2.pyi
ATTRIBUTE_MAP: dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = {
  AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i),
  AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t),
  AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints),
  AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings)
}
def attribute_parse(onnx_attribute: AttributeProto):
  if onnx_attribute.type not in ATTRIBUTE_MAP:
    raise NotImplementedError(f"attribute with type {AttributeProto.AttributeType.Name(onnx_attribute.type)} is not supported")
  return ATTRIBUTE_MAP[onnx_attribute.type](onnx_attribute)
def buffer_parse(onnx_tensor: TensorProto) -> Tensor:
  if onnx_tensor.string_data: raise NotImplementedError("Parsing for buffer with string data is not implemented.")
  dtype, shape = dtype_parse(onnx_tensor.data_type), tuple(onnx_tensor.dims)
  if data := list(onnx_tensor.float_data) or list(onnx_tensor.int32_data) or list(onnx_tensor.int64_data) or list(onnx_tensor.double_data) or \
             list(onnx_tensor.uint64_data):
    if len(data) == 1: return Tensor(data[0], dtype=dtype).reshape(shape)
    return Tensor(data, dtype=dtype).reshape(shape).realize()
  if onnx_tensor.HasField("raw_data"):
    np_buffer = np.frombuffer(onnx_tensor.raw_data, dtype=helper.tensor_dtype_to_np_dtype(onnx_tensor.data_type)).copy().reshape(shape)
    if np_buffer.size == 1: return Tensor(np_buffer.item(), dtype=dtype).reshape(shape)
    return Tensor(np_buffer, dtype=dtype)
  return Tensor(None)
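# note on buffer_parse above: a TensorProto carries its payload either in the typed repeated fields
# (float_data, int32_data, ...), in raw_data, or in neither (an empty tensor), matching the three branches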
onnx_ops = importlib.import_module('extra.onnx_ops')
ONNXLIMIT = getenv("ONNXLIMIT", -1)
def get_run_onnx(onnx_model: ModelProto):
  # model initialization data
  model_tensors = {inp.name:buffer_parse(inp) for inp in onnx_model.graph.initializer}
  model_expected_inputs = {inp.name:inp for inp in onnx_model.graph.input if inp.name not in model_tensors}
  model_attributes = {num:{x.name:attribute_parse(x) for x in n.attribute} for num,n in enumerate(onnx_model.graph.node)}

  # model descriptions
  # TODO: need a better way of controlling training vs non-training
  is_onnx_preview_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in onnx_model.graph.node)
  onnx_model_version = onnx_model.opset_import[0].version

  # used to check validity of user_input according to their dimension variables
  variable_dims = {}

  # mapping from onnx ops to tensor.py ops
  tensor_methods = {
    op:op.lower() for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan",
    "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh",
    "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")
  }

  # these values are expected to be python consts
  required_input_python_consts: dict[str, tuple[int, ...]] = {
    "Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,),
    "CumSum": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,),
    "ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4),
    **{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")},
    **{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")}
  }
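  # e.g. Reshape's input 1 is the target shape and Slice's inputs 1-4 are starts/ends/axes/steps; these are
  # converted with to_python_const so the op implementations receive concrete python values, not Tensors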
  # src: https://onnx.ai/onnx/repo-docs/IR.html#input-output-data-types
  # parses and validates inputs based on their shape and dtype specified by model
  def prepare_input(user_input:Any, model_input:ValueInfoProto):
    type_proto = model_input.type
    if type_proto.HasField("optional_type"):
      if user_input is None: return None
      type_proto = type_proto.optional_type.elem_type
    if type_proto.HasField("sequence_type"):
      if not isinstance(user_input, Sequence): raise RuntimeError(f"{model_input.name} received {user_input}, expected sequence type")
      dtype = dtype_parse(type_proto.sequence_type.elem_type.tensor_type.elem_type)
      sequence = [Tensor(i, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(i, Tensor) else i for i in user_input]
      if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"shapes for {model_input.name} must be homogeneous")
      # TODO: need true float16 for dtype checking
      # if not all(t.dtype is dtype for t in sequence):
      #   raise RuntimeError(f"{model_input.name} has dtype mismatch for sequence type. Expected {dtype}, received {tensor.dtype}.")
      return sequence
    if type_proto.HasField("tensor_type"):
      dtype = dtype_parse(type_proto.tensor_type.elem_type)
      tensor = Tensor(user_input, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(user_input, Tensor) else user_input
      # TODO: need true float16 for dtype checking
      # if dtype is not tensor.dtype: raise RuntimeError(f"{model_input.name} has mismatch for dtype. Expected {dtype}, received {tensor.dtype}.")
      for dim, onnx_dim in enumerate(type_proto.tensor_type.shape.dim):
        dim_param, dim_value = onnx_dim.dim_param, onnx_dim.dim_value
        user_dim_input = tensor.shape[dim]
        if dim_param: dim_value = variable_dims[dim_param] if dim_param in variable_dims else variable_dims.setdefault(dim_param, user_dim_input)
        if user_dim_input != dim_value:
          raise RuntimeError(f"{model_input.name} has mismatch for dim={dim_param or dim}. Expected {dim_value}, received {user_dim_input}.")
      return tensor
    type_field_names = [field.name for field,_ in type_proto.ListFields()]
    raise NotImplementedError(f"{model_input.name} with {type_field_names=} is not supported")
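  # note on prepare_input above: a symbolic dim (dim_param, e.g. a batch dim named "N") is bound in variable_dims
  # to the first user-provided size seen for it, so every later input that uses "N" must match that size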
  def run_onnx(inputs={}, debug=0):
    debug = getenv("DEBUGONNX") or debug
    if debug >= 3: print("Model initialization data:\n" + "\n".join(f"\t{i.name} - {model_tensors[i.name]}" for i in onnx_model.graph.initializer))
    if debug >= 1: print("Model input:")
    for name, value_info in model_expected_inputs.items():
      if name not in inputs: raise RuntimeError(f"Please provide input data for {name}")
      model_tensors[name] = prepare_input(inputs[name], value_info)
      if debug >= 1: print(f"\t{name} - {model_tensors[name]}")
      if debug >= 2: print(f"\t\t{MessageToDict(value_info.type)}")
    for num,n in enumerate(onnx_model.graph.node):
      inp_tensors = [model_tensors.get(x) for x in n.input]
      required_consts = required_input_python_consts.get(n.op_type, ())
      inp = [to_python_const(t) if i in required_consts else t for i,t in enumerate(inp_tensors)]
      opt = model_attributes[num]
      if debug >= 1: print(f"{num}: op \"{n.op_type}\" input shapes {[x.shape if isinstance(x, Tensor) else x for x in inp_tensors]} opt {opt}")
      if debug >= 3:
        print("\tinputs:")
        print("\n".join(f"\t\t{x} - {t!r}" + (" (to_python_const)" if i in required_consts else "") for i,(x,t) in enumerate(zip(n.input, inp))))
      # provide additional arguments
      if n.op_type == "Split" and 'num_outputs' not in opt: opt['num_outputs'] = len(n.output)
      if n.op_type == "Gradient": opt['intermediate_tensors'] = model_tensors
      # run op
      if n.op_type in tensor_methods: ret = getattr(Tensor, tensor_methods[n.op_type])(*inp, **opt)
      elif hasattr(onnx_ops, n.op_type):
        fxn = getattr(onnx_ops, n.op_type)
        if isinstance(fxn, dict):
          for k in sorted(fxn.keys()):
            if k <= onnx_model_version:
              real_fxn = fxn[k]
        else:
          real_fxn = fxn
        ret = real_fxn(*inp, **opt)
      else:
        print("UNSUPPORTED", n.op_type, n.input, n.output)
        raise NotImplementedError(f"op_type {n.op_type} not supported")
      # finalization after running the op
      if not isinstance(ret, tuple): ret = (ret, )
      if len(n.output) > len(ret): raise RuntimeError(f"expected output size must be less than {len(ret)}, it's {n.output}")
      for i in range(len(n.output)): model_tensors[n.output[i]] = ret[i]
      if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{n.output[i]} - {ret[i]}" for i in range(len(n.output))))
      if num == ONNXLIMIT: return {name:model_tensors[name] for name in n.output}
    return {x.name:model_tensors[x.name] for x in onnx_model.graph.output}
  return run_onnx
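
# Minimal usage sketch (illustrative, not part of the original module): the file path "model.onnx" and the
# input name "input" are assumptions; a real model declares its own input names in onnx_model.graph.input.
if __name__ == "__main__":
  import onnx
  model = onnx.load("model.onnx")                          # deserialize a ModelProto from disk
  run = get_run_onnx(model)                                # parse initializers/attributes once, get a runner
  outputs = run({"input": Tensor.randn(1, 3, 224, 224)})   # map input names to Tensors (or python values)
  for name, t in outputs.items(): print(name, t.numpy())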