@@ -1,53 +1,42 @@
from typing import Callable, Any, Sequence
import importlib, functools
import numpy as np
from tinygrad import Tensor, dtypes
import importlib, functools, dataclasses
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, DEBUG, all_same
from tinygrad.dtype import DType, ConstType
from tinygrad.dtype import DType, ConstType, dtypes
from tinygrad.device import is_dtype_supported
from onnx import AttributeProto, ModelProto, TensorProto, ValueInfoProto, helper
from google.protobuf.json_format import MessageToDict

cache_misses = 0
@functools.lru_cache(None)
def _cached_to_python_const(t:Tensor):
  if t.dtype is dtypes.uint8: return t.data().tobytes()
  if 0 in t.shape: return []
  return t.tolist()

# ***** protobuf parsing ******
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto, helper
import numpy as np

# Tensor -> python value cache for parameters
def to_python_const(t) -> list[ConstType]|ConstType|bytes:
  if not isinstance(t, Tensor): return t
  global cache_misses
  ret = _cached_to_python_const(t)
  if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3:
    print(f"Cache miss for {t}")
    cache_misses = info.misses
  return ret

# TODO: use real float16
# src: onnx/mapping.py
DTYPE_MAP: dict[int, DType] = {
  TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8,
  TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64,
  TensorProto.BOOL:dtypes.bool, TensorProto.FLOAT16:dtypes.float32, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32,
  TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16,
}
def dtype_parse(onnx_dtype: int) -> DType:
  if onnx_dtype not in DTYPE_MAP: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported")
  return DTYPE_MAP[onnx_dtype] if is_dtype_supported(DTYPE_MAP[onnx_dtype]) else dtypes.float
  supported: dict[int, DType] = {
    TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8,
    TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64,
    TensorProto.BOOL:dtypes.bool, TensorProto.FLOAT16:dtypes.float32, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32,
    TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16,
  }
  unsupported = {
    TensorProto.UNDEFINED, TensorProto.STRING, TensorProto.COMPLEX64, TensorProto.COMPLEX128, TensorProto.FLOAT8E4M3FN, TensorProto.FLOAT8E4M3FNUZ,
    TensorProto.FLOAT8E5M2, TensorProto.FLOAT8E5M2FNUZ, TensorProto.UINT4, TensorProto.INT4
  }
  if onnx_dtype in unsupported: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported")
  return supported[onnx_dtype] if is_dtype_supported(supported[onnx_dtype]) else dtypes.float

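# e.g. dtype_parse(TensorProto.FLOAT16) currently resolves to dtypes.float32 (see the float16 TODO above),
# and a dtype the active device cannot hold falls back to dtypes.float via is_dtype_supported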
# src: onnx/onnx_ml_pb2.pyi
ATTRIBUTE_MAP: dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = {
  AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i),
  AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t),
  AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints),
  AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings)
}
def attribute_parse(onnx_attribute: AttributeProto):
  if onnx_attribute.type not in ATTRIBUTE_MAP:
  supported: dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = {
    AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i),
    AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t),
    AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints),
    AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings)
  }
  unsupported = {
    AttributeProto.UNDEFINED, AttributeProto.GRAPH, AttributeProto.SPARSE_TENSOR, AttributeProto.TYPE_PROTO, AttributeProto.TENSORS,
    AttributeProto.GRAPHS, AttributeProto.SPARSE_TENSORS, AttributeProto.TYPE_PROTOS
  }
  if onnx_attribute.type in unsupported:
    raise NotImplementedError(f"attribute with type {AttributeProto.AttributeType.Name(onnx_attribute.type)} is not supported")
  return ATTRIBUTE_MAP[onnx_attribute.type](onnx_attribute)
  return supported[onnx_attribute.type](onnx_attribute)

def buffer_parse(onnx_tensor: TensorProto) -> Tensor:
  if onnx_tensor.string_data: raise NotImplementedError("Parsing for buffer with string data is not implemented.")
@@ -62,116 +51,137 @@ def buffer_parse(onnx_tensor: TensorProto) -> Tensor:
    return Tensor(np_buffer, dtype=dtype)
  return Tensor(None)

onnx_ops = importlib.import_module('extra.onnx_ops')
ONNXLIMIT = getenv("ONNXLIMIT", -1)

def get_run_onnx(onnx_model: ModelProto):
  # model initialization data
  model_tensors = {inp.name:buffer_parse(inp) for inp in onnx_model.graph.initializer}
  model_expected_inputs = {inp.name:inp for inp in onnx_model.graph.input if inp.name not in model_tensors}
  model_attributes = {num:{x.name:attribute_parse(x) for x in n.attribute} for num,n in enumerate(onnx_model.graph.node)}
def type_parse(onnx_type: TypeProto):
  elem_type = onnx_type
  if elem_type.HasField("map_type") or elem_type.HasField("sparse_tensor_type") or elem_type.HasField("opaque_type"):
    raise NotImplementedError("parsing for map_type, sparse_tensor_type and opaque_type are not implemented")
  if is_optional := elem_type.HasField("optional_type"): elem_type = elem_type.optional_type.elem_type
  if is_sequence := elem_type.HasField("sequence_type"): elem_type = elem_type.sequence_type.elem_type
  if elem_type.HasField("tensor_type"):
    shape = tuple(d.dim_param or d.dim_value for d in elem_type.tensor_type.shape.dim)
    dtype = dtype_parse(elem_type.tensor_type.elem_type)
    return OnnxValue(shape, dtype, is_optional, is_sequence)
  raise RuntimeError(f"TypeProto was not parsed properly: {onnx_type=}")

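# e.g. (illustrative) a float tensor input declared with dims ("N", 3, 224, 224) parses to
# OnnxValue(shape=("N", 3, 224, 224), dtype=dtypes.float32, is_optional=False, is_sequence=False)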
  # model descriptions
  # TODO: need a better way of controlling training vs non-training
  is_onnx_preview_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in onnx_model.graph.node)
  onnx_model_version = onnx_model.opset_import[0].version
# ***** onnx spec *****
@dataclasses.dataclass(frozen=True)
class OnnxValue:
  shape: tuple[str|int]
  dtype: DType
  is_optional: bool
  is_sequence: bool

  # used to check validity of user_input according to their dimension variables
  variable_dims = {}
@dataclasses.dataclass(frozen=True)
class OnnxNode:
  num: int
  op: str
  inputs: tuple[str]
  outputs: tuple[str]
  opts: dict[str, Any]

# mapping from onnx ops to tensor.py ops
tensor_methods = {
  op:op.lower() for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan",
    "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh",
    "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")
}

# ***** python const *****
required_input_python_consts: dict[str, tuple[int, ...]] = {
  "Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,),
  "CumSum": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,),
  "ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4),
  **{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")},
  **{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")}
}

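# e.g. "Reshape": (1,) means Reshape's second input (the target shape) must be realized to a python value
# before dispatch, since it parameterizes the output shape rather than flowing through as tensor data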
# these values are expected to be python consts
required_input_python_consts: dict[str, tuple[int, ...]] = {
  "Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,),
  "CumSum": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,),
  "ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4),
  **{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")},
  **{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")}
}
cache_misses = 0
@functools.lru_cache(None)
def _cached_to_python_const(t:Tensor):
  if t.dtype is dtypes.uint8: return t.data().tobytes()
  if 0 in t.shape: return []
  return t.tolist()

  # src: https://onnx.ai/onnx/repo-docs/IR.html#input-output-data-types
  # parses and validates inputs based on their shape and dtype specified by model
  def prepare_input(user_input:Any, model_input:ValueInfoProto):
    type_proto = model_input.type
    if type_proto.HasField("optional_type"):
      if user_input is None: return None
      type_proto = type_proto.optional_type.elem_type
    if type_proto.HasField("sequence_type"):
      if not isinstance(user_input, Sequence): raise RuntimeError(f"{model_input.name} received {user_input}, expected sequence type")
      dtype = dtype_parse(type_proto.sequence_type.elem_type.tensor_type.elem_type)
      sequence = [Tensor(i, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(i, Tensor) else i for i in user_input]
      if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"shapes for {model_input.name} must be homogeneous")
      # TODO: need true float16 for dtype checking
      # if not all(t.dtype is dtype for t in sequence):
      #   raise RuntimeError(f"{model_input.name} has dtype mismatch for sequence type. Expected {dtype}, received {tensor.dtype}.")
# Tensor -> python value cache for parameters
def to_python_const(t:Any, op:str, idx:int) -> list[ConstType]|ConstType|bytes:
  if idx not in required_input_python_consts.get(op, ()) or not isinstance(t, Tensor): return t
  global cache_misses
  ret = _cached_to_python_const(t)
  if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3:
    print(f"Cache miss for {t}")
    cache_misses = info.misses
  return ret

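# e.g. to_python_const(shape_tensor, "Reshape", 1) realizes the shape Tensor once (cached by lru_cache)
# and returns a python list; any (op, idx) pair not listed above passes the Tensor through untouched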
# ***** runner ******
debug = int(getenv("DEBUGONNX", "0"))
limit = int(getenv("ONNXLIMIT", "-1"))
class OnnxRunner:
  def __init__(self, model: ModelProto):
    # parse model protobuf
    self.is_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in model.graph.node)
    self.old_training, self.old_no_grad = Tensor.training, Tensor.no_grad
    Tensor.training = True if self.is_training else False
    Tensor.no_grad = False if self.is_training else True
    self.graph_values = {x.name:buffer_parse(x) for x in model.graph.initializer}
    self.graph_inputs = {x.name:type_parse(x.type) for x in model.graph.input if x.name not in self.graph_values}
    self.graph_outputs = {x.name:type_parse(x.type) for x in model.graph.output}
    self.graph_nodes = tuple(OnnxNode(num, n.op_type, tuple(n.input), tuple(n.output), {x.name:attribute_parse(x) for x in n.attribute})
                             for num,n in enumerate(model.graph.node))
    self.opset_version = model.opset_import[0].version
    self.variable_dims: dict[str, int] = {}

    # TODO: move extra.onnx_ops here so we don't have to deal with annoying circular import
    # TODO: clean up opset stuff after moving extra.onnx_ops here
    self.onnx_ops_module = importlib.import_module('extra.onnx_ops')
    self.onnx_ops = {
      **{op: getattr(Tensor, op.lower()) for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan",
      "Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh",
      "Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")},
    }

  def _parse_input(self, name: str, value: Any, spec: OnnxValue):
    if spec.is_optional and value is None: return None
    # TODO: need true float16 for dtype checking
    if spec.is_sequence:
      if not isinstance(value, Sequence): raise RuntimeError(f"{name} received {value}, expected a sequence type")
      sequence = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value]
      if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"Shapes for {name} sequence must be homogeneous")
      return sequence
    if type_proto.HasField("tensor_type"):
      dtype = dtype_parse(type_proto.tensor_type.elem_type)
      tensor = Tensor(user_input, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(user_input, Tensor) else user_input
      # TODO: need true float16 for dtype checking
      # if dtype is not tensor.dtype: raise RuntimeError(f"{model_input.name} has mismatch for dtype. Expected {dtype}, received {tensor.dtype}.")
      for dim, onnx_dim in enumerate(type_proto.tensor_type.shape.dim):
        dim_param, dim_value = onnx_dim.dim_param, onnx_dim.dim_value
        user_dim_input = tensor.shape[dim]
        if dim_param: dim_value = variable_dims[dim_param] if dim_param in variable_dims else variable_dims.setdefault(dim_param, user_dim_input)
        if user_dim_input != dim_value:
          raise RuntimeError(f"{model_input.name} has mismatch for dim={dim_param or dim}. Expected {dim_value}, received {user_dim_input}.")
      return tensor
    type_field_names = [field.name for field,_ in type_proto.ListFields()]
    raise NotImplementedError(f"{model_input.name} with {type_field_names=} is not supported")
    tensor = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(value, Tensor) else value
    for dim, (onnx_dim, user_dim_input) in enumerate(zip(spec.shape, tensor.shape, strict=True)):
      if isinstance(onnx_dim, str):
        onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input))
      if user_dim_input != onnx_dim: raise RuntimeError(f"{name} has mismatch on {dim=}. Expected {onnx_dim}, received {user_dim_input}.")
    return tensor

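  # e.g. with spec.shape ("batch", 3, 224, 224) the first call pins "batch" to the user's actual size in
  # self.variable_dims, and every later input that names "batch" must then match that pinned value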
  def run_onnx(inputs={}, debug=0):
    debug = getenv("DEBUGONNX") or debug
    if debug >= 3: print("Model initialization data:\n" + "\n".join(f"\t{i.name} - {model_tensors[i.name]}" for i in onnx_model.graph.initializer))
  def _dispatch_op(self, op, inps, opts):
    if op in self.onnx_ops: return self.onnx_ops[op](*inps, **opts)
    if hasattr(self.onnx_ops_module, op):
      fxn = getattr(self.onnx_ops_module, op)
      if isinstance(fxn, dict):
        for k in sorted(fxn.keys()):
          if k <= self.opset_version:
            real_fxn = fxn[k]
      else: real_fxn = fxn
      return real_fxn(*inps, **opts)
    raise NotImplementedError(f"{op=} not supported")

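  # e.g. if extra.onnx_ops exposes an op as a dict keyed by opset version (say {1: f_v1, 13: f_v13}, illustrative),
  # the loop above keeps the largest key <= self.opset_version, so opset 11 would dispatch to f_v1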
if debug >= 1: print("Model input:")
|
|
|
|
|
for name, value_info in model_expected_inputs.items():
|
|
|
|
|
  def __call__(self, inputs:dict[str, Any], debug=debug):
    for name, input_spec in self.graph_inputs.items():
      if name not in inputs: raise RuntimeError(f"Please provide input data for {name}")
      model_tensors[name] = prepare_input(inputs[name], value_info)
      if debug >= 1: print(f"\t{name} - {model_tensors[name]}")
      if debug >= 2: print(f"\t\t{MessageToDict(value_info.type)}")
      self.graph_values[name] = self._parse_input(name, inputs[name], input_spec)

    for num,n in enumerate(onnx_model.graph.node):
      inp_tensors = [model_tensors.get(x) for x in n.input]
      required_consts = required_input_python_consts.get(n.op_type, ())
      inp = [to_python_const(t) if i in required_consts else t for i,t in enumerate(inp_tensors)]
      opt = model_attributes[num]
    for node in self.graph_nodes:
      inps = [to_python_const(self.graph_values.get(name), node.op, i) for i,name in enumerate(node.inputs)]
      opts = node.opts

if debug >= 1: print(f"{num}: op \"{n.op_type}\" input shapes {[x.shape if isinstance(x, Tensor) else x for x in inp_tensors]} opt {opt}")
|
|
|
|
|
if debug >= 3:
|
|
|
|
|
print("\tinputs:")
|
|
|
|
|
print("\n".join(f"\t\t{x} - {t!r}" + (" (to_python_const)" if i in required_consts else "") for i,(x,t) in enumerate(zip(n.input, inp))))
|
|
|
|
|
      # provide additional opts
      if node.op == "Split" and 'num_outputs' not in opts: opts['num_outputs'] = len(node.outputs)
      if node.op == "Gradient": opts['intermediate_tensors'] = self.graph_values

      # provide additional arguments
      if n.op_type == "Split" and 'num_outputs' not in opt: opt['num_outputs'] = len(n.output)
      if n.op_type == "Gradient": opt['intermediate_tensors'] = model_tensors
if debug >= 1: print(f"{node.num}: op '{node.op}' opt {opts}")
|
|
|
|
|
if debug >= 2 and node.inputs: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {i!r}" for x,i in zip(node.inputs, inps)))
|
|
|
|
|
ret = self._dispatch_op(node.op, inps, opts)
|
|
|
|
|
ret = ret if isinstance(ret, tuple) else (ret,)
|
|
|
|
|
if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{x} - {o!r}" for x,o in zip(node.outputs, ret)))
|
|
|
|
|
|
|
|
|
|
      # run op
      if n.op_type in tensor_methods: ret = getattr(Tensor, tensor_methods[n.op_type])(*inp, **opt)
      elif hasattr(onnx_ops, n.op_type):
        fxn = getattr(onnx_ops, n.op_type)
        if isinstance(fxn, dict):
          for k in sorted(fxn.keys()):
            if k <= onnx_model_version:
              real_fxn = fxn[k]
        else:
          real_fxn = fxn
        ret = real_fxn(*inp, **opt)
      else:
        print("UNSUPPORTED", n.op_type, n.input, n.output)
        raise NotImplementedError(f"op_type {n.op_type} not supported")
      self.graph_values.update(dict(zip(node.outputs, ret[:len(node.outputs)], strict=True)))

      # finalization after running the op
      if not isinstance(ret, tuple): ret = (ret, )
      if len(n.output) > len(ret): raise RuntimeError(f"expected output size must be less than {len(ret)}, it's {n.output}")
      for i in range(len(n.output)): model_tensors[n.output[i]] = ret[i]
      if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{n.output[i]} - {ret[i]}" for i in range(len(n.output))))

      if num == ONNXLIMIT: return {name:model_tensors[name] for name in n.output}
    return {x.name:model_tensors[x.name] for x in onnx_model.graph.output}
  return run_onnx
      if node.num == limit:
        Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
        return {name:self.graph_values[name] for name in node.outputs}
    Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
    return {name:self.graph_values[name] for name in self.graph_outputs}
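
# --- usage sketch: a minimal driver for the OnnxRunner class above ---
# "model.onnx" and the zero-filled feeds are placeholders for illustration, not part of the original file;
# unknown (string) or zero dims are simply fed as 1 here
if __name__ == "__main__":
  import onnx
  runner = OnnxRunner(onnx.load("model.onnx"))
  print("expected inputs:", {name: (spec.shape, spec.dtype) for name, spec in runner.graph_inputs.items()})
  feeds = {name: np.zeros(tuple(d if isinstance(d, int) and d > 0 else 1 for d in spec.shape), dtype=np.float32)
           for name, spec in runner.graph_inputs.items()}
  outputs = runner(feeds, debug=1)   # returns {output_name: Tensor}
  print({name: t.shape for name, t in outputs.items()})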