mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
add sequence_type back
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from typing import Any, cast, Literal, Callable
|
||||
from typing import Any, Sequence, cast, Literal, Callable
|
||||
import dataclasses, functools, io, math, types
|
||||
from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr
|
||||
from tinygrad.helpers import getenv, DEBUG, prod, flatten, make_tuple, argsort
|
||||
from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort
|
||||
from tinygrad.dtype import DType, ConstType, dtypes, ImageDType
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
||||
@@ -56,10 +56,11 @@ def type_parse(onnx_type: TypeProto):
|
||||
if elem_type.HasField("map_type") or elem_type.HasField("sparse_tensor_type") or elem_type.HasField("opaque_type"):
|
||||
raise NotImplementedError("parsing for map_type, sparse_tensor_type and opaque_type are not implemented")
|
||||
if is_optional := elem_type.HasField("optional_type"): elem_type = elem_type.optional_type.elem_type
|
||||
if is_sequence := elem_type.HasField("sequence_type"): elem_type = elem_type.sequence_type.elem_type
|
||||
if elem_type.HasField("tensor_type"):
|
||||
shape = tuple(d.dim_param or d.dim_value for d in elem_type.tensor_type.shape.dim)
|
||||
dtype = dtype_parse(elem_type.tensor_type.elem_type)
|
||||
return OnnxValue(shape, dtype, is_optional)
|
||||
return OnnxValue(shape, dtype, is_optional, is_sequence)
|
||||
raise RuntimeError(f"TypeProto was not parsed properly: {onnx_type=}")
|
||||
|
||||
# ***** onnx spec *****
|
||||
@@ -68,6 +69,7 @@ class OnnxValue:
|
||||
shape: tuple[str|int, ...]
|
||||
dtype: DType
|
||||
is_optional: bool
|
||||
is_sequence: bool
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class OnnxNode:
|
||||
@@ -123,25 +125,29 @@ class OnnxRunner:
|
||||
|
||||
self.onnx_ops = onnx_ops
|
||||
|
||||
def _valid_shape(self, value: Tensor, spec: OnnxValue):
|
||||
if len(spec.shape) != len(value.shape): return False
|
||||
for onnx_dim, user_dim_input in zip(spec.shape, value.shape):
|
||||
def _is_valid(self, value: Tensor | list, spec: OnnxValue):
|
||||
if isinstance(value, list):
|
||||
assert spec.is_sequence
|
||||
if not all_same(tuple(v.shape for v in value)): raise RuntimeError
|
||||
return value
|
||||
for onnx_dim, user_dim_input in zip(spec.shape, value.shape, strict=True):
|
||||
if isinstance(onnx_dim, str):
|
||||
onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input))
|
||||
if user_dim_input != onnx_dim: return False
|
||||
return True
|
||||
if user_dim_input != onnx_dim: raise RuntimeError
|
||||
return value
|
||||
|
||||
def _parse_input(self, name: str, value: Any, spec: OnnxValue):
|
||||
if spec.is_optional and value is None: return None
|
||||
if value is None: raise RuntimeError(f"'{name}' is not marked as optional, but received a None value")
|
||||
if not isinstance(value, Tensor): value = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training)
|
||||
if not self._valid_shape(value, spec): raise RuntimeError(f"input '{name}' has wrong shape, got {value.shape}, expected {spec}")
|
||||
return value
|
||||
elif spec.is_sequence:
|
||||
if not isinstance(value, Sequence): raise RuntimeError(f"input '{name}' received {value}, expected a sequence type")
|
||||
# sequence inputs and outputs are interpreted in OnnxRunner as list type
|
||||
value = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value]
|
||||
elif not isinstance(value, Tensor): value = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training)
|
||||
return self._is_valid(value, spec)
|
||||
|
||||
def _parse_output(self, name: str):
|
||||
value, spec = self.graph_values[name], self.graph_outputs[name]
|
||||
if not isinstance(value, Tensor): return value
|
||||
if not self._valid_shape(value, spec): raise RuntimeError(f"output '{name}' has wrong shape, got {value.shape}, expected {spec}")
|
||||
if isinstance(value, (Tensor, list)): return self._is_valid(value, spec)
|
||||
return value
|
||||
|
||||
def _dispatch_op(self, op, inps, opts):
|
||||
|
||||
@@ -40,7 +40,7 @@ def get_example_inputs(graph_inputs:dict[str, OnnxValue], config={}):
|
||||
|
||||
ret: dict[str, Tensor] = {}
|
||||
for name, spec in graph_inputs.items():
|
||||
assert not spec.is_optional, "only allow tensor input for now"
|
||||
assert not spec.is_optional and not spec.is_sequence, "only allow tensor input for now"
|
||||
shape = _get_shape(spec.shape)
|
||||
value = _get_value(name, shape, spec.dtype)
|
||||
ret.update({name:value})
|
||||
|
||||
6
test/external/external_test_onnx_backend.py
vendored
6
test/external/external_test_onnx_backend.py
vendored
@@ -175,12 +175,6 @@ backend_test.exclude('test_resize_downsample_sizes_linear_antialias_cpu') # anti
|
||||
backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_cpu') # bad data type string
|
||||
backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad data type string
|
||||
|
||||
# no support for sequence
|
||||
backend_test.exclude('test_identity_opt_cpu')
|
||||
backend_test.exclude('test_identity_sequence_cpu')
|
||||
backend_test.exclude('test_optional_get_element_optional_sequence_cpu')
|
||||
backend_test.exclude('test_optional_get_element_sequence_cpu')
|
||||
|
||||
backend_test.exclude('test_scatternd_min_cpu') # min not yet supported
|
||||
backend_test.exclude('test_scatternd_max_cpu') # max not yet supported
|
||||
|
||||
|
||||
Reference in New Issue
Block a user