from __future__ import annotations
from typing import List, Dict, Union, Callable, Any
import importlib, functools
import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad.helpers import getenv, DEBUG
from tinygrad.dtype import DType, ConstType
from tinygrad.device import is_dtype_supported
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto
try:
  from onnx.helper import tensor_dtype_to_np_dtype
except ImportError:
  # for onnx < 1.13
  from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
  def tensor_dtype_to_np_dtype(tensor_dtype:int) -> np.dtype: return TENSOR_TYPE_TO_NP_TYPE[tensor_dtype]

cache_misses = 0
@functools.lru_cache(None)
def _cached_to_python_const(t:Tensor):
  if t.dtype is dtypes.uint8: return t.data().tobytes()
  if 0 in t.shape: return []
  return t.tolist()

# Tensor -> python value cache for parameters
def to_python_const(t) -> Union[List[ConstType], List[bytes], Union[ConstType, bytes]]:
  if not isinstance(t, Tensor): return t
  global cache_misses
  ret = _cached_to_python_const(t)
  if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3:
    print(f"Cache miss for {t}")
    cache_misses = info.misses
  return ret

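# Illustrative examples (a sketch, not part of the original source):
#   to_python_const(Tensor([1.0, 2.0]))                    # -> [1.0, 2.0]
#   to_python_const(Tensor([65, 66], dtype=dtypes.uint8))  # -> b'AB' (uint8 tensors become bytes)
#   to_python_const("already_python")                      # -> "already_python" (non-Tensors pass through)
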
# TODO: use real float16
# src: onnx/mapping.py
DTYPE_MAP: Dict[TensorProto.DataType | int, DType] = {
  TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8,
  TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64,
  TensorProto.BOOL:dtypes.bool, TensorProto.FLOAT16:dtypes.float32, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32,
  TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16, TensorProto.FLOAT8E4M3FN:dtypes.float,
  TensorProto.FLOAT8E4M3FNUZ:dtypes.float, TensorProto.FLOAT8E5M2:dtypes.float, TensorProto.FLOAT8E5M2FNUZ:dtypes.float
}
def dtype_parse(onnx_dtype: TensorProto.DataType | int) -> DType:
  if onnx_dtype not in DTYPE_MAP: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported")
  return DTYPE_MAP[onnx_dtype] if is_dtype_supported(DTYPE_MAP[onnx_dtype]) else dtypes.float

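# Illustrative examples (a sketch, not part of the original source):
#   dtype_parse(TensorProto.FLOAT)     # -> dtypes.float32
#   dtype_parse(TensorProto.BFLOAT16)  # -> dtypes.bfloat16 if the device supports it, else dtypes.float
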
# src: onnx/onnx_ml_pb2.pyi
ATTRIBUTE_MAP: Dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = {
  AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i),
  AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t),
  AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints),
  AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings)
}
def attribute_parse(onnx_attribute: AttributeProto):
  if onnx_attribute.type not in ATTRIBUTE_MAP:
    raise NotImplementedError(f"attribute with type {AttributeProto.AttributeType.Name(onnx_attribute.type)} is not supported")
  return ATTRIBUTE_MAP[onnx_attribute.type](onnx_attribute)

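# Illustrative note (a sketch, not part of the original source): an INTS attribute
# with ints [1, 2] parses to the tuple (1, 2); STRING attributes are decoded from
# bytes to str; TENSOR attributes are parsed into Tensors via buffer_parse below.
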
def buffer_parse(inp: TensorProto) -> Tensor:
  if dat := list(inp.float_data) or list(inp.int32_data) or list(inp.int64_data):
    return Tensor(dat, dtype=dtype_parse(inp.data_type), requires_grad=False).reshape(tuple(inp.dims))
  if len(inp.raw_data) > 0:
    return Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).copy().reshape(tuple(inp.dims)),
                  dtype=dtype_parse(inp.data_type), requires_grad=False)
  raise NotImplementedError(f"buffer with data type {TensorProto.DataType.Name(inp.data_type)} is not supported")

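# Illustrative example (a sketch, not part of the original source): a TensorProto
# carries its payload either in typed fields (float_data/int32_data/int64_data) or
# packed in raw_data, and both paths reshape to inp.dims:
#   from onnx.helper import make_tensor
#   t = make_tensor("w", TensorProto.FLOAT, dims=(2, 2), vals=[1., 2., 3., 4.])
#   buffer_parse(t).shape  # -> (2, 2)
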
onnx_ops = importlib.import_module('extra.onnx_ops')
ONNXLIMIT = getenv("ONNXLIMIT", -1)

def get_run_onnx(onnx_model: ModelProto):
  def type_parse(type_proto: TypeProto):
    ret = []
    while True:
      attr = type_proto.WhichOneof('value')
      if attr == 'tensor_type':
        if "dim_value" not in type_proto.tensor_type.shape.dim.__dir__(): return () # variable type, unable to determine shape
        elif not ret:
          return tuple([x.dim_value for x in type_proto.tensor_type.shape.dim])
        else:
          ret.extend([(x.dim_value,) for x in type_proto.tensor_type.shape.dim])
          return tuple(ret)
      elif attr == 'sequence_type':
        type_proto = getattr(type_proto, attr).elem_type
        ret.append(1)
      elif attr == 'optional_type': type_proto = getattr(type_proto, attr).elem_type
      elif attr == 'map_type': raise NotImplementedError(f"map_type is not implemented: {type_proto}")
      elif attr == 'opaque_type': raise NotImplementedError(f"opaque_type is not implemented: {type_proto}")
      elif attr == 'sparse_tensor_type': raise NotImplementedError(f"sparse_tensor_type is not implemented: {type_proto}")
      else: raise AttributeError(f"unknown attr: {attr}, {type_proto}")

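  # Illustrative behavior (a sketch, not part of the original source): a plain
  # tensor_type with dims (1, 3, 224, 224) parses to (1, 3, 224, 224); a sequence of
  # tensors parses to (1, (d0,), (d1,), ...); variable-typed inputs come back as ().
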
  # initialization data
  model_parameters = {inp.name:buffer_parse(inp) for inp in onnx_model.graph.initializer}
  model_attributes = {num:{x.name:attribute_parse(x) for x in n.attribute} for num,n in enumerate(onnx_model.graph.node)}

  # model specs
  is_onnx_preview_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in onnx_model.graph.node)
  onnx_model_version = onnx_model.opset_import[0].version

  # mapping from onnx ops to tensor.py ops
  tensor_methods = {
    op:op.lower() for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan",
    "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh",
    "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf")
  }

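  # illustrative note (not part of the original source): e.g. an ONNX "Sqrt" node
  # dispatches straight to Tensor.sqrt, and "MatMul" to Tensor.matmul
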
  def run_onnx(inputs={}, debug=0):
    debug = getenv("DEBUGONNX") or debug
    input_tensors: Dict[str,Tensor|List[Tensor]] = {}
    intermediate_tensors: Dict[str,Tensor] = {}

    # get inputs
    for model_input in onnx_model.graph.input:
      name = model_input.name
      if name in model_parameters: continue
      shape = type_parse(model_input.type)
      if name in inputs:
        if isinstance(inputs[name], Tensor):
          input_tensors[name] = inputs[name]
        elif isinstance(inputs[name], list):
          input_tensors[name] = [Tensor(i, requires_grad=False) for i in inputs[name]]
        # TODO: this is just to make training tests pass, need a principled way to handle training vs non-training
        elif is_onnx_preview_training:
          input_tensors[name] = Tensor(inputs[name], requires_grad=True)
        else:
          input_tensors[name] = Tensor(inputs[name], requires_grad=False)
        if shape: # a () shape means a variable type whose shape could not be determined, so skip the check
          ts = input_tensors[name]
          input_shape = ts.shape if isinstance(ts, Tensor) else (1, *[i.shape for i in ts])
          assert input_shape == shape, f"wrong shape for input {name}, {input_shape} isn't {shape}"
      else:
        raise RuntimeError(f"no data for {name} with shape {shape}")

    def fetch_tensor(x: str):
      if x in model_parameters: return model_parameters[x]
      if x in intermediate_tensors: return intermediate_tensors[x]
      if x != "": return input_tensors[x]
      return None

    for num,n in enumerate(onnx_model.graph.node):
      inp = [fetch_tensor(x) for x in n.input]
      opt = model_attributes[num]

      if debug >= 1: print(f"{num}: op \"{n.op_type}\" input shapes {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}")
      if debug >= 3: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {t}" for x,t in zip(n.input, inp)))

      if n.op_type in tensor_methods:
        ret = getattr(Tensor, tensor_methods[n.op_type])(*inp, **opt)

      # NOTE: some ops live here because they require access to local variables
      elif n.op_type == "Split":
        axis, n_outputs = opt.get('axis', 0), opt.get('num_outputs') or len(n.output)
        sz = inp[0].shape[axis]
        sizes = to_python_const(inp[1]) if len(inp) == 2 else [sz // n_outputs + (1 if i < sz % n_outputs else 0) for i in range(n_outputs)]
        ret = inp[0].split(sizes, axis)
      elif n.op_type == "Gradient":
        assert len(opt["xs"]) == len(inp), f"len(opt['xs']):{len(opt['xs'])} must match the number of inputs len(inp):{len(inp)}"
        y = opt["y"]
        intermediate_tensors[y].backward()
        ret = tuple([t.grad for t in inp])

      # onnx_ops.py
      elif hasattr(onnx_ops, n.op_type):
        fxn = getattr(onnx_ops, n.op_type)
        if isinstance(fxn, dict):
          # versioned op: pick the newest implementation whose opset version does not exceed the model's
          for k in sorted(fxn.keys()):
            if k <= onnx_model_version: real_fxn = fxn[k]
        else:
          real_fxn = fxn
        ret = real_fxn(*inp, **opt)
      else:
        print("UNSUPPORTED", n.op_type, n.input, n.output)
        raise NotImplementedError(f"op_type {n.op_type} not supported")

      # finalization after running the op
      if not isinstance(ret, tuple): ret = (ret, )
      if len(n.output) > len(ret): raise RuntimeError(f"op returned {len(ret)} outputs, but node expects {len(n.output)}: {n.output}")
      for i in range(len(n.output)): intermediate_tensors[n.output[i]] = ret[i]
      if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{n.output[i]} - {ret[i]}" for i in range(len(n.output))))

      if num == ONNXLIMIT: return {name:intermediate_tensors[name] for name in n.output}
    return {x.name:intermediate_tensors[x.name] for x in onnx_model.graph.output}
  return run_onnx
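
# Illustrative usage (a sketch, not part of the original source; the model path and
# input name "input" are hypothetical):
#   import onnx
#   model = onnx.load("model.onnx")
#   run = get_run_onnx(model)
#   outputs = run({"input": np.zeros((1, 3, 224, 224), dtype=np.float32)})
#   print({name: t.shape for name, t in outputs.items()})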