add sequence_type back

This commit is contained in:
zibokapi
2025-04-03 12:24:18 +08:00
parent 3e9b2c2474
commit 288170b2d9
3 changed files with 21 additions and 21 deletions

View File

@@ -1,7 +1,7 @@
from typing import Any, cast, Literal, Callable
from typing import Any, Sequence, cast, Literal, Callable
import dataclasses, functools, io, math, types
from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr
from tinygrad.helpers import getenv, DEBUG, prod, flatten, make_tuple, argsort
from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort
from tinygrad.dtype import DType, ConstType, dtypes, ImageDType
from tinygrad.device import is_dtype_supported
@@ -56,10 +56,11 @@ def type_parse(onnx_type: TypeProto):
if elem_type.HasField("map_type") or elem_type.HasField("sparse_tensor_type") or elem_type.HasField("opaque_type"):
raise NotImplementedError("parsing for map_type, sparse_tensor_type and opaque_type are not implemented")
if is_optional := elem_type.HasField("optional_type"): elem_type = elem_type.optional_type.elem_type
if is_sequence := elem_type.HasField("sequence_type"): elem_type = elem_type.sequence_type.elem_type
if elem_type.HasField("tensor_type"):
shape = tuple(d.dim_param or d.dim_value for d in elem_type.tensor_type.shape.dim)
dtype = dtype_parse(elem_type.tensor_type.elem_type)
return OnnxValue(shape, dtype, is_optional)
return OnnxValue(shape, dtype, is_optional, is_sequence)
raise RuntimeError(f"TypeProto was not parsed properly: {onnx_type=}")
# ***** onnx spec *****
@@ -68,6 +69,7 @@ class OnnxValue:
shape: tuple[str|int, ...]
dtype: DType
is_optional: bool
is_sequence: bool
@dataclasses.dataclass(frozen=True)
class OnnxNode:
@@ -123,25 +125,29 @@ class OnnxRunner:
self.onnx_ops = onnx_ops
def _valid_shape(self, value: Tensor, spec: OnnxValue):
if len(spec.shape) != len(value.shape): return False
for onnx_dim, user_dim_input in zip(spec.shape, value.shape):
def _is_valid(self, value: Tensor | list, spec: OnnxValue):
if isinstance(value, list):
assert spec.is_sequence
if not all_same(tuple(v.shape for v in value)): raise RuntimeError
return value
for onnx_dim, user_dim_input in zip(spec.shape, value.shape, strict=True):
if isinstance(onnx_dim, str):
onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input))
if user_dim_input != onnx_dim: return False
return True
if user_dim_input != onnx_dim: raise RuntimeError
return value
def _parse_input(self, name: str, value: Any, spec: OnnxValue):
if spec.is_optional and value is None: return None
if value is None: raise RuntimeError(f"'{name}' is not marked as optional, but received a None value")
if not isinstance(value, Tensor): value = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training)
if not self._valid_shape(value, spec): raise RuntimeError(f"input '{name}' has wrong shape, got {value.shape}, expected {spec}")
return value
elif spec.is_sequence:
if not isinstance(value, Sequence): raise RuntimeError(f"input '{name}' received {value}, expected a sequence type")
# sequence inputs and outputs are interpreted in OnnxRunner as list type
value = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value]
elif not isinstance(value, Tensor): value = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training)
return self._is_valid(value, spec)
def _parse_output(self, name: str):
value, spec = self.graph_values[name], self.graph_outputs[name]
if not isinstance(value, Tensor): return value
if not self._valid_shape(value, spec): raise RuntimeError(f"output '{name}' has wrong shape, got {value.shape}, expected {spec}")
if isinstance(value, (Tensor, list)): return self._is_valid(value, spec)
return value
def _dispatch_op(self, op, inps, opts):

View File

@@ -40,7 +40,7 @@ def get_example_inputs(graph_inputs:dict[str, OnnxValue], config={}):
ret: dict[str, Tensor] = {}
for name, spec in graph_inputs.items():
assert not spec.is_optional, "only allow tensor input for now"
assert not spec.is_optional and not spec.is_sequence, "only allow tensor input for now"
shape = _get_shape(spec.shape)
value = _get_value(name, shape, spec.dtype)
ret.update({name:value})

View File

@@ -175,12 +175,6 @@ backend_test.exclude('test_resize_downsample_sizes_linear_antialias_cpu') # anti
backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_cpu') # bad data type string
backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad data type string
# no support for sequence
backend_test.exclude('test_identity_opt_cpu')
backend_test.exclude('test_identity_sequence_cpu')
backend_test.exclude('test_optional_get_element_optional_sequence_cpu')
backend_test.exclude('test_optional_get_element_sequence_cpu')
backend_test.exclude('test_scatternd_min_cpu') # min not yet supported
backend_test.exclude('test_scatternd_max_cpu') # max not yet supported