make onnx runner a class (#8647)

* this

* clean up

* more clean ups and improve debug msg

* more correct training toggler

* remove manual training toggling

* change some variable names

* actually just add the training toggle for LIMIT envvar too

* more refinement

* __call__ and OnnxRunner

* fix half of the pylint errors; the other half comes from importing onnx while this file is onnx.py, figure out later

* ahhhh found another mistake

* remove limit from __call__

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
Author: geohotstan
Date: 2025-01-21 02:11:05 +08:00
Committed by: GitHub
Parent: 46a8c5e1e5
Commit: dd82b4c913
10 changed files with 172 additions and 173 deletions
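
For orientation, a minimal sketch of the call-site change the diffs below make (assumes an already-loaded onnx.ModelProto named onnx_model and an input dict of Tensors):

# before this commit: get_run_onnx built and returned a run function
# from extra.onnx import get_run_onnx
# run_onnx = get_run_onnx(onnx_model)
# outputs = run_onnx(inputs, debug=True)

# after: OnnxRunner is a class; __call__ runs the graph, and the runner itself
# manages Tensor.training / Tensor.no_grad instead of the call sites
from extra.onnx import OnnxRunner
run_onnx = OnnxRunner(onnx_model)
outputs = run_onnx(inputs)   # debug level is an int, default taken from DEBUGONNX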


@@ -1,14 +1,12 @@
import sys, onnx, time
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, fetch
from tinygrad.tensor import _from_np_dtype
from extra.onnx import get_run_onnx
from extra.onnx import OnnxRunner
def load_onnx_model(fn):
onnx_file = fetch(fn)
onnx_model = onnx.load(onnx_file)
Tensor.no_grad = True
Tensor.training = False
run_onnx = get_run_onnx(onnx_model)
run_onnx = OnnxRunner(onnx_model)
# find preinitted tensors and ignore them
initted_tensors = {inp.name:None for inp in onnx_model.graph.initializer}
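
With the new class, the same question — which graph inputs actually need user data — can be answered from the runner itself: graph_inputs (built in the extra/onnx.py diff below) already excludes names that appear in graph.initializer. A hedged sketch, reusing onnx_model from load_onnx_model above:

from extra.onnx import OnnxRunner

runner = OnnxRunner(onnx_model)
for name, spec in runner.graph_inputs.items():
  # spec is an OnnxValue(shape, dtype, is_optional, is_sequence); shape entries are ints or dim_param strings
  print(name, spec.shape, spec.dtype)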


@@ -8,7 +8,7 @@ import numpy as np
import subprocess
import tensorflow as tf
import tf2onnx
from extra.onnx import get_run_onnx
from extra.onnx import OnnxRunner
from tinygrad.tensor import Tensor
from extra.export_model import export_model_clang, compile_net, jit_model
@@ -25,7 +25,7 @@ class TinyOnnx:
def __init__(self, keras_model):
input_signature = [tf.TensorSpec([1,32], tf.float32, name='x')]
onnx_model, _ = tf2onnx.convert.from_keras(keras_model, input_signature, opset=13)
self.run_onnx = get_run_onnx(onnx_model)
self.run_onnx = OnnxRunner(onnx_model)
def forward(self, x):
return self.run_onnx({"x": x}, debug=False)['predictions']


@@ -12,17 +12,14 @@ from tinygrad.engine.realize import CompiledRunner
import onnx
from onnx.helper import tensor_dtype_to_np_dtype
from extra.onnx import get_run_onnx # TODO: port to main tinygrad
from extra.onnx import OnnxRunner # TODO: port to main tinygrad
OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"
OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/openpilot.pkl"
def compile(onnx_file):
onnx_model = onnx.load(onnx_file)
Tensor.no_grad = True
Tensor.training = False
run_onnx = get_run_onnx(onnx_model)
run_onnx = OnnxRunner(onnx_model)
print("loaded model")
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
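
The Tensor.no_grad / Tensor.training lines removed here (and in the other call sites) move into OnnxRunner itself: per the __init__ code further down, training mode is enabled only when the graph contains ai.onnx.preview.training nodes, and the previous flag values are restored at the end of __call__. A small standalone sketch of that detection:

def training_flags(model):
  # mirrors OnnxRunner.__init__ (sketch): train iff the graph uses the preview training domain
  is_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in model.graph.node)
  return is_training, not is_training   # -> (Tensor.training, Tensor.no_grad)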


@@ -3,7 +3,7 @@ import os
from ultralytics import YOLO
import onnx
from pathlib import Path
from extra.onnx import get_run_onnx
from extra.onnx import OnnxRunner
from tinygrad.tensor import Tensor
os.chdir("/tmp")
@@ -14,5 +14,5 @@ onnx_model = onnx.load(open("yolov8n-seg.onnx", "rb"))
# TODO: move get example inputs to onnx
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
print(input_shapes)
run_onnx = get_run_onnx(onnx_model)
run_onnx = OnnxRunner(onnx_model)
run_onnx({"images": Tensor.zeros(1,3,480,640)}, debug=True)
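
__call__ returns a dict keyed by the graph's output names; a hedged sketch of consuming it (output names depend on the exported model):

import numpy as np
from tinygrad.tensor import Tensor

out = run_onnx({"images": Tensor.zeros(1, 3, 480, 640)}, debug=True)
arrays = {name: v.numpy() if isinstance(v, Tensor) else np.array(v) for name, v in out.items()}
print({name: a.shape for name, a in arrays.items()})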


@@ -1,53 +1,42 @@
from typing import Callable, Any, Sequence
import importlib, functools
import numpy as np
from tinygrad import Tensor, dtypes
import importlib, functools, dataclasses
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, DEBUG, all_same
from tinygrad.dtype import DType, ConstType
from tinygrad.dtype import DType, ConstType, dtypes
from tinygrad.device import is_dtype_supported
from onnx import AttributeProto, ModelProto, TensorProto, ValueInfoProto, helper
from google.protobuf.json_format import MessageToDict
cache_misses = 0
@functools.lru_cache(None)
def _cached_to_python_const(t:Tensor):
if t.dtype is dtypes.uint8: return t.data().tobytes()
if 0 in t.shape: return []
return t.tolist()
# ***** protobuf parsing ******
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto, helper
import numpy as np
# Tensor -> python value cache for parameters
def to_python_const(t) -> list[ConstType]|ConstType|bytes:
if not isinstance(t, Tensor): return t
global cache_misses
ret = _cached_to_python_const(t)
if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3:
print(f"Cache miss for {t}")
cache_misses = info.misses
return ret
# TODO: use real float16
# src: onnx/mapping.py
DTYPE_MAP: dict[int, DType] = {
TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8,
TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64,
TensorProto.BOOL:dtypes.bool, TensorProto.FLOAT16:dtypes.float32, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32,
TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16,
}
def dtype_parse(onnx_dtype: int) -> DType:
if onnx_dtype not in DTYPE_MAP: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported")
return DTYPE_MAP[onnx_dtype] if is_dtype_supported(DTYPE_MAP[onnx_dtype]) else dtypes.float
supported: dict[int, DType] = {
TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8,
TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64,
TensorProto.BOOL:dtypes.bool, TensorProto.FLOAT16:dtypes.float32, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32,
TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16,
}
unsupported = {
TensorProto.UNDEFINED, TensorProto.STRING, TensorProto.COMPLEX64, TensorProto.COMPLEX128, TensorProto.FLOAT8E4M3FN, TensorProto.FLOAT8E4M3FNUZ,
TensorProto.FLOAT8E5M2, TensorProto.FLOAT8E5M2FNUZ, TensorProto.UINT4, TensorProto.INT4
}
if onnx_dtype in unsupported: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported")
return supported[onnx_dtype] if is_dtype_supported(supported[onnx_dtype]) else dtypes.float
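
A quick sketch of what dtype_parse returns under this mapping (FLOAT16 is still widened to float32 per the TODO above, and dtypes in the unsupported set raise NotImplementedError); assumes a backend where these dtypes are supported:

from onnx import TensorProto
from tinygrad.dtype import dtypes
from extra.onnx import dtype_parse

assert dtype_parse(TensorProto.FLOAT) == dtypes.float32
assert dtype_parse(TensorProto.FLOAT16) == dtypes.float32   # TODO: use real float16
try: dtype_parse(TensorProto.STRING)
except NotImplementedError as e: print(e)   # onnx dtype STRING is not supported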
# src: onnx/onnx_ml_pb2.pyi
ATTRIBUTE_MAP: dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = {
AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i),
AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t),
AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints),
AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings)
}
def attribute_parse(onnx_attribute: AttributeProto):
if onnx_attribute.type not in ATTRIBUTE_MAP:
supported: dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = {
AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i),
AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t),
AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints),
AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings)
}
unsupported = {
AttributeProto.UNDEFINED, AttributeProto.GRAPH, AttributeProto.SPARSE_TENSOR, AttributeProto.TYPE_PROTO, AttributeProto.TENSORS,
AttributeProto.GRAPHS, AttributeProto.SPARSE_TENSORS, AttributeProto.TYPE_PROTOS
}
if onnx_attribute.type in unsupported:
raise NotImplementedError(f"attribute with type {AttributeProto.AttributeType.Name(onnx_attribute.type)} is not supported")
return ATTRIBUTE_MAP[onnx_attribute.type](onnx_attribute)
return supported[onnx_attribute.type](onnx_attribute)
def buffer_parse(onnx_tensor: TensorProto) -> Tensor:
if onnx_tensor.string_data: raise NotImplementedError("Parsing for buffer with string data is not implemented.")
@@ -62,116 +51,137 @@ def buffer_parse(onnx_tensor: TensorProto) -> Tensor:
return Tensor(np_buffer, dtype=dtype)
return Tensor(None)
onnx_ops = importlib.import_module('extra.onnx_ops')
ONNXLIMIT = getenv("ONNXLIMIT", -1)
def get_run_onnx(onnx_model: ModelProto):
# model initialization data
model_tensors = {inp.name:buffer_parse(inp) for inp in onnx_model.graph.initializer}
model_expected_inputs = {inp.name:inp for inp in onnx_model.graph.input if inp.name not in model_tensors}
model_attributes = {num:{x.name:attribute_parse(x) for x in n.attribute} for num,n in enumerate(onnx_model.graph.node)}
def type_parse(onnx_type: TypeProto):
elem_type = onnx_type
if elem_type.HasField("map_type") or elem_type.HasField("sparse_tensor_type") or elem_type.HasField("opaque_type"):
raise NotImplementedError("parsing for map_type, sparse_tensor_type and opaque_type are not implemented")
if is_optional := elem_type.HasField("optional_type"): elem_type = elem_type.optional_type.elem_type
if is_sequence := elem_type.HasField("sequence_type"): elem_type = elem_type.sequence_type.elem_type
if elem_type.HasField("tensor_type"):
shape = tuple(d.dim_param or d.dim_value for d in elem_type.tensor_type.shape.dim)
dtype = dtype_parse(elem_type.tensor_type.elem_type)
return OnnxValue(shape, dtype, is_optional, is_sequence)
raise RuntimeError(f"TypeProto was not parsed properly: {onnx_type=}")
# model descriptions
# TODO: need a better way of controlling training vs non-training
is_onnx_preview_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in onnx_model.graph.node)
onnx_model_version = onnx_model.opset_import[0].version
# ***** onnx spec *****
@dataclasses.dataclass(frozen=True)
class OnnxValue:
shape: tuple[str|int]
dtype: DType
is_optional: bool
is_sequence: bool
# used to check validity of user_input according to their dimension variables
variable_dims = {}
@dataclasses.dataclass(frozen=True)
class OnnxNode:
num: int
op: str
inputs: tuple[str]
outputs: tuple[str]
opts: dict[str, Any]
# mapping from onnx ops to tensor.py ops
tensor_methods = {
op:op.lower() for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan",
"Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh",
"Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")
}
# ***** python const *****
required_input_python_consts: dict[str, tuple[int, ...]] = {
"Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,),
"CumSum": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,),
"ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4),
**{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")},
**{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")}
}
# these values are expected to be python consts
required_input_python_consts: dict[str, tuple[int, ...]] = {
"Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,),
"CumSum": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,),
"ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4),
**{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")},
**{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")}
}
cache_misses = 0
@functools.lru_cache(None)
def _cached_to_python_const(t:Tensor):
if t.dtype is dtypes.uint8: return t.data().tobytes()
if 0 in t.shape: return []
return t.tolist()
# src: https://onnx.ai/onnx/repo-docs/IR.html#input-output-data-types
# parses and validates inputs based on their shape and dtype specified by model
def prepare_input(user_input:Any, model_input:ValueInfoProto):
type_proto = model_input.type
if type_proto.HasField("optional_type"):
if user_input is None: return None
type_proto = type_proto.optional_type.elem_type
if type_proto.HasField("sequence_type"):
if not isinstance(user_input, Sequence): raise RuntimeError(f"{model_input.name} received {user_input}, expected sequence type")
dtype = dtype_parse(type_proto.sequence_type.elem_type.tensor_type.elem_type)
sequence = [Tensor(i, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(i, Tensor) else i for i in user_input]
if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"shapes for {model_input.name} must be homogeneous")
# TODO: need true float16 for dtype checking
# if not all(t.dtype is dtype for t in sequence):
# raise RuntimeError(f"{model_input.name} has dtype mismatch for sequence type. Expected {dtype}, received {tensor.dtype}.")
# Tensor -> python value cache for parameters
def to_python_const(t:Any, op:str, idx:int) -> list[ConstType]|ConstType|bytes:
if idx not in required_input_python_consts.get(op, ()) or not isinstance(t, Tensor): return t
global cache_misses
ret = _cached_to_python_const(t)
if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3:
print(f"Cache miss for {t}")
cache_misses = info.misses
return ret
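
What the new (op, idx) arguments buy: only input positions listed in required_input_python_consts are converted to python values (through the lru_cache above); everything else passes through unchanged. An illustrative call, assuming to_python_const stays module-level as shown:

from tinygrad import Tensor
from extra.onnx import to_python_const

shape = Tensor([2, 3])                        # Reshape's input index 1 is the target shape
print(to_python_const(shape, "Reshape", 1))   # -> [2, 3], pulled back to a python list
data = Tensor.ones(6)
print(to_python_const(data, "Reshape", 0))    # index 0 is not a required const: the Tensor passes through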
# ***** runner ******
debug = int(getenv("DEBUGONNX", "0"))
limit = int(getenv("ONNXLIMIT", "-1"))
class OnnxRunner:
def __init__(self, model: ModelProto):
# parse model protobuf
self.is_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in model.graph.node)
self.old_training, self.old_no_grad = Tensor.training, Tensor.no_grad
Tensor.training = True if self.is_training else False
Tensor.no_grad = False if self.is_training else True
self.graph_values = {x.name:buffer_parse(x) for x in model.graph.initializer}
self.graph_inputs = {x.name:type_parse(x.type) for x in model.graph.input if x.name not in self.graph_values}
self.graph_outputs = {x.name:type_parse(x.type) for x in model.graph.output}
self.graph_nodes = tuple(OnnxNode(num, n.op_type, tuple(n.input), tuple(n.output), {x.name:attribute_parse(x) for x in n.attribute})
for num,n in enumerate(model.graph.node))
self.opset_version = model.opset_import[0].version
self.variable_dims: dict[str, int] = {}
# TODO: move extra.onnx_ops here so we don't have to deal with annoying circular import
# TODO: clean up opset stuff after moving extra.onnx_ops here
self.onnx_ops_module = importlib.import_module('extra.onnx_ops')
self.onnx_ops = {
**{op: getattr(Tensor, op.lower()) for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan",
"Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh",
"Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")},
}
def _parse_input(self, name: str, value: Any, spec: OnnxValue):
if spec.is_optional and value is None: return None
# TODO: need true float16 for dtype checking
if spec.is_sequence:
if not isinstance(value, Sequence): raise RuntimeError(f"{name} received {value}, expected a sequence type")
sequence = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value]
if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"Shapes for {name} sequence must be homogeneous")
return sequence
if type_proto.HasField("tensor_type"):
dtype = dtype_parse(type_proto.tensor_type.elem_type)
tensor = Tensor(user_input, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(user_input, Tensor) else user_input
# TODO: need true float16 for dtype checking
# if dtype is not tensor.dtype: raise RuntimeError(f"{model_input.name} has mismatch for dtype. Expected {dtype}, received {tensor.dtype}.")
for dim, onnx_dim in enumerate(type_proto.tensor_type.shape.dim):
dim_param, dim_value = onnx_dim.dim_param, onnx_dim.dim_value
user_dim_input = tensor.shape[dim]
if dim_param: dim_value = variable_dims[dim_param] if dim_param in variable_dims else variable_dims.setdefault(dim_param, user_dim_input)
if user_dim_input != dim_value:
raise RuntimeError(f"{model_input.name} has mismatch for dim={dim_param or dim}. Expected {dim_value}, received {user_dim_input}.")
return tensor
type_field_names = [field.name for field,_ in type_proto.ListFields()]
raise NotImplementedError(f"{model_input.name} with {type_field_names=} is not supported")
tensor = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(value, Tensor) else value
for dim, (onnx_dim, user_dim_input) in enumerate(zip(spec.shape, tensor.shape, strict=True)):
if isinstance(onnx_dim, str):
onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input))
if user_dim_input != onnx_dim: raise RuntimeError(f"{name} has mismatch on {dim=}. Expected {onnx_dim}, received {user_dim_input}.")
return tensor
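
The shape check at the end of _parse_input binds symbolic dims on first use and enforces them afterwards; a standalone sketch of the same logic (names and shapes illustrative):

def check_dims(spec_shape, tensor_shape, variable_dims, name="input"):
  # mirrors the zip(spec.shape, tensor.shape, strict=True) loop in _parse_input (sketch)
  for dim, (onnx_dim, user_dim) in enumerate(zip(spec_shape, tensor_shape, strict=True)):
    if isinstance(onnx_dim, str): onnx_dim = variable_dims.setdefault(onnx_dim, int(user_dim))
    if user_dim != onnx_dim: raise RuntimeError(f"{name} has mismatch on dim={dim}. Expected {onnx_dim}, received {user_dim}.")

dims = {}
check_dims(("N", 3, 224, 224), (8, 3, 224, 224), dims)   # binds N=8
check_dims(("N", 1000), (8, 1000), dims)                 # ok, N already bound to 8
# check_dims(("N", 1000), (4, 1000), dims)               # would raise: 4 != 8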
def run_onnx(inputs={}, debug=0):
debug = getenv("DEBUGONNX") or debug
if debug >= 3: print("Model initialization data:\n" + "\n".join(f"\t{i.name} - {model_tensors[i.name]}" for i in onnx_model.graph.initializer))
def _dispatch_op(self, op, inps, opts):
if op in self.onnx_ops: return self.onnx_ops[op](*inps, **opts)
if hasattr(self.onnx_ops_module, op):
fxn = getattr(self.onnx_ops_module, op)
if isinstance(fxn, dict):
for k in sorted(fxn.keys()):
if k <= self.opset_version:
real_fxn = fxn[k]
else: real_fxn = fxn
return real_fxn(*inps, **opts)
raise NotImplementedError(f"{op=} not supported")
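
An op in extra.onnx_ops can be either a plain function or a dict mapping a minimum opset version to an implementation; _dispatch_op walks the sorted keys and keeps the newest one that does not exceed the model's opset. A standalone sketch with a made-up op table:

def pick_impl(fxn, opset_version):
  # mirrors the selection in _dispatch_op (sketch)
  if not isinstance(fxn, dict): return fxn
  real_fxn = None
  for k in sorted(fxn):
    if k <= opset_version: real_fxn = fxn[k]
  return real_fxn

FakeOp = {1: lambda x: "opset-1 semantics", 13: lambda x: "opset-13 semantics"}
print(pick_impl(FakeOp, 12)(None))   # -> opset-1 semantics
print(pick_impl(FakeOp, 18)(None))   # -> opset-13 semantics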
if debug >= 1: print("Model input:")
for name, value_info in model_expected_inputs.items():
def __call__(self, inputs:dict[str, Any], debug=debug):
for name, input_spec in self.graph_inputs.items():
if name not in inputs: raise RuntimeError(f"Please provide input data for {name}")
model_tensors[name] = prepare_input(inputs[name], value_info)
if debug >= 1: print(f"\t{name} - {model_tensors[name]}")
if debug >= 2: print(f"\t\t{MessageToDict(value_info.type)}")
self.graph_values[name] = self._parse_input(name, inputs[name], input_spec)
for num,n in enumerate(onnx_model.graph.node):
inp_tensors = [model_tensors.get(x) for x in n.input]
required_consts = required_input_python_consts.get(n.op_type, ())
inp = [to_python_const(t) if i in required_consts else t for i,t in enumerate(inp_tensors)]
opt = model_attributes[num]
for node in self.graph_nodes:
inps = [to_python_const(self.graph_values.get(name), node.op, i) for i,name in enumerate(node.inputs)]
opts = node.opts
if debug >= 1: print(f"{num}: op \"{n.op_type}\" input shapes {[x.shape if isinstance(x, Tensor) else x for x in inp_tensors]} opt {opt}")
if debug >= 3:
print("\tinputs:")
print("\n".join(f"\t\t{x} - {t!r}" + (" (to_python_const)" if i in required_consts else "") for i,(x,t) in enumerate(zip(n.input, inp))))
# provide additional opts
if node.op == "Split" and 'num_outputs' not in opts: opts['num_outputs'] = len(node.outputs)
if node.op == "Gradient": opts['intermediate_tensors'] = self.graph_values
# provide additional arguments
if n.op_type == "Split" and 'num_outputs' not in opt: opt['num_outputs'] = len(n.output)
if n.op_type == "Gradient": opt['intermediate_tensors'] = model_tensors
if debug >= 1: print(f"{node.num}: op '{node.op}' opt {opts}")
if debug >= 2 and node.inputs: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {i!r}" for x,i in zip(node.inputs, inps)))
ret = self._dispatch_op(node.op, inps, opts)
ret = ret if isinstance(ret, tuple) else (ret,)
if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{x} - {o!r}" for x,o in zip(node.outputs, ret)))
# run op
if n.op_type in tensor_methods: ret = getattr(Tensor, tensor_methods[n.op_type])(*inp, **opt)
elif hasattr(onnx_ops, n.op_type):
fxn = getattr(onnx_ops, n.op_type)
if isinstance(fxn, dict):
for k in sorted(fxn.keys()):
if k <= onnx_model_version:
real_fxn = fxn[k]
else:
real_fxn = fxn
ret = real_fxn(*inp, **opt)
else:
print("UNSUPPORTED", n.op_type, n.input, n.output)
raise NotImplementedError(f"op_type {n.op_type} not supported")
self.graph_values.update(dict(zip(node.outputs, ret[:len(node.outputs)], strict=True)))
# finalization after running the op
if not isinstance(ret, tuple): ret = (ret, )
if len(n.output) > len(ret): raise RuntimeError(f"expected output size must be less than {len(ret)}, it's {n.output}")
for i in range(len(n.output)): model_tensors[n.output[i]] = ret[i]
if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{n.output[i]} - {ret[i]}" for i in range(len(n.output))))
if num == ONNXLIMIT: return {name:model_tensors[name] for name in n.output}
return {x.name:model_tensors[x.name] for x in onnx_model.graph.output}
return run_onnx
if node.num == limit:
Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
return {name:self.graph_values[name] for name in node.outputs}
Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
return {name:self.graph_values[name] for name in self.graph_outputs}
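
Putting it together, a hedged end-to-end sketch of the new runner (the URL is a placeholder; DEBUGONNX and ONNXLIMIT are read once at module import per the getenv calls above; symbolic dims are naively set to 1 here):

import onnx
from tinygrad import Tensor
from tinygrad.helpers import fetch
from extra.onnx import OnnxRunner

runner = OnnxRunner(onnx.load(fetch("https://example.com/model.onnx")))   # placeholder URL
inputs = {name: Tensor.zeros(*[d if isinstance(d, int) and d > 0 else 1 for d in spec.shape])
          for name, spec in runner.graph_inputs.items()}
outputs = runner(inputs, debug=1)   # debug>=1 logs per-node op/opt, debug>=2 also prints inputs/outputs
print({name: t.shape for name, t in outputs.items()})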


@@ -3,7 +3,7 @@ from typing import cast, Literal
from tinygrad.tensor import Tensor, _broadcast_shape, ConstType, ReductionStr
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.helpers import prod, flatten, make_tuple
from extra.onnx import dtype_parse, to_python_const
from extra.onnx import dtype_parse, _cached_to_python_const
import numpy as np
# **************** Free Ops ****************
@@ -282,7 +282,7 @@ def Gather(x:Tensor, indices:Tensor, axis:int=0):
x_sh = list(x.shape)
ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
if indices.ndim > 1: indices = indices.flatten()
indices = [to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in to_python_const(indices)] # type: ignore
indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)]
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore
return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
# NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot
@@ -575,12 +575,9 @@ from tinygrad.nn.optim import SGD
def onnx_training(input_group_size):
def _decorator(func):
def __wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs):
old_training = Tensor.training
Tensor.training = True
R = R.detach()
groups = len(inputs) // input_group_size
ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))]
Tensor.training = old_training
return tuple(flatten(zip(*ret)))
return __wrapper
return _decorator
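
The grouping in __wrapper takes the flat ONNX training inputs (all tensors of one kind first, then the next kind) and reassembles one tuple per optimized parameter; a pure-python sketch (names illustrative):

def group(inputs, input_group_size):
  # mirrors groups = len(inputs) // input_group_size; inputs[i::groups] in __wrapper (sketch)
  groups = len(inputs) // input_group_size
  return [inputs[i::groups] for i in range(groups)]

# e.g. two parameters with four tensors each (value, gradient, and two optimizer state tensors):
print(group(("X1", "X2", "G1", "G2", "V1", "V2", "H1", "H2"), 4))
# -> [('X1', 'G1', 'V1', 'H1'), ('X2', 'G2', 'V2', 'H2')]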


@@ -2,7 +2,7 @@ import time, sys, hashlib
from pathlib import Path
import onnx
from onnx.helper import tensor_dtype_to_np_dtype
from extra.onnx import get_run_onnx
from extra.onnx import OnnxRunner
from tinygrad import Tensor, dtypes, TinyJit
from tinygrad.helpers import IMAGE, GlobalCounters, fetch, colored, getenv, trange
from tinygrad.tensor import _from_np_dtype
@@ -11,11 +11,8 @@ import numpy as np
OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"
if __name__ == "__main__":
Tensor.no_grad = True
Tensor.training = False
onnx_model = onnx.load(onnx_path := fetch(OPENPILOT_MODEL))
run_onnx = get_run_onnx(onnx_model)
run_onnx = OnnxRunner(onnx_model)
Tensor.manual_seed(100)
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}


@@ -6,7 +6,7 @@ import onnx
from onnx.helper import tensor_dtype_to_np_dtype
import onnxruntime as ort
from onnx2torch import convert
from extra.onnx import get_run_onnx
from extra.onnx import OnnxRunner
from tinygrad.helpers import OSX, DEBUG, fetch
from tinygrad import Tensor, Device
from tinygrad.device import CompileError
@@ -65,7 +65,7 @@ def benchmark_model(m, devices, validate_outs=False):
try:
Device.DEFAULT = device
inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
tinygrad_model = get_run_onnx(onnx_model)
tinygrad_model = OnnxRunner(onnx_model)
benchmark(m, f"tinygrad_{device.lower()}_jitless", lambda: {k:v.numpy() for k,v in tinygrad_model(inputs).items()})
from tinygrad.engine.jit import TinyJit
@@ -115,7 +115,7 @@ def benchmark_model(m, devices, validate_outs=False):
rtol, atol = 2e-3, 2e-3 # tolerance for fp16 models
Device.DEFAULT = device
inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
tinygrad_model = get_run_onnx(onnx_model)
tinygrad_model = OnnxRunner(onnx_model)
tinygrad_out = tinygrad_model(inputs)
ort_sess = ort.InferenceSession(str(fn), ort_options, ["CPUExecutionProvider"])
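
For the jitted benchmark path mentioned above, the pattern with the new runner would look roughly like this (a sketch under the assumption that TinyJit accepts keyword Tensor inputs and that the jitted function returns realized Tensors; not the benchmark's verbatim code):

from tinygrad import Tensor, TinyJit

tinygrad_model = OnnxRunner(onnx_model)
jitted = TinyJit(lambda **kwargs: {k: v.realize() for k, v in tinygrad_model(kwargs).items()})
for _ in range(3):   # TinyJit captures after warm-up calls, then replays the kernels
  out = jitted(**{k: Tensor(v) for k, v in np_inputs.items()})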


@@ -10,7 +10,7 @@ from tinygrad.device import is_dtype_supported
# pip3 install tabulate
pytest_plugins = 'onnx.backend.test.report',
from extra.onnx import get_run_onnx
from extra.onnx import OnnxRunner
class TinygradModel(BackendRep):
def __init__(self, run_onnx, input_names):
@@ -20,7 +20,7 @@ class TinygradModel(BackendRep):
def run(self, inputs: Any, **kwargs: Any) -> Tuple[Any, ...]:
real_inputs = dict(zip(self.input_names, inputs))
ret = self.fxn(real_inputs, debug=True)
ret = self.fxn(real_inputs, debug=2)
return tuple(x.numpy() if isinstance(x, Tensor) else [i.numpy() for i in x] if isinstance(x, list) else np.array(x) for x in ret.values())
class TinygradBackend(Backend):
@@ -30,7 +30,7 @@ class TinygradBackend(Backend):
input_initializer = [x.name for x in model.graph.initializer]
net_feed_input = [x for x in input_all if x not in input_initializer]
print("prepare", cls, device, net_feed_input)
run_onnx = get_run_onnx(model)
run_onnx = OnnxRunner(model)
return TinygradModel(run_onnx, net_feed_input)
@classmethod


@@ -7,7 +7,7 @@ try:
import onnx
except ModuleNotFoundError:
raise unittest.SkipTest("onnx not installed, skipping onnx test")
from extra.onnx import get_run_onnx
from extra.onnx import OnnxRunner
from tinygrad.tensor import Tensor
from tinygrad.helpers import CI, fetch, temp
@@ -26,7 +26,7 @@ np.random.seed(1337)
class TestOnnxModel(unittest.TestCase):
def test_benchmark_openpilot_model(self):
onnx_model = onnx.load(fetch(OPENPILOT_MODEL))
run_onnx = get_run_onnx(onnx_model)
run_onnx = OnnxRunner(onnx_model)
def get_inputs():
np_inputs = {
"input_imgs": np.random.randn(*(1, 12, 128, 256)),
@@ -70,7 +70,7 @@ class TestOnnxModel(unittest.TestCase):
def test_openpilot_model(self):
onnx_model = onnx.load(fetch(OPENPILOT_MODEL))
run_onnx = get_run_onnx(onnx_model)
run_onnx = OnnxRunner(onnx_model)
print("got run_onnx")
inputs = {
"input_imgs": np.random.randn(*(1, 12, 128, 256)),
@@ -124,7 +124,7 @@ class TestOnnxModel(unittest.TestCase):
onnx_model = onnx.load(fn)
print("onnx loaded")
from test.models.test_efficientnet import chicken_img, car_img, preprocess, _LABELS
run_onnx = get_run_onnx(onnx_model)
run_onnx = OnnxRunner(onnx_model)
def run(img):
inputs = {input_name: preprocess(img, new=input_new)}