diff --git a/extra/onnx.py b/extra/onnx.py index ae0d6a2e96..7c5515edbe 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -125,29 +125,29 @@ class OnnxRunner: self.onnx_ops = onnx_ops - def _is_valid(self, value: Tensor | list, spec: OnnxValue): + def _is_valid(self, name: str, value: Tensor | list, spec: OnnxValue): if isinstance(value, list): assert spec.is_sequence - if not all_same(tuple(v.shape for v in value)): raise RuntimeError - return value - for onnx_dim, user_dim_input in zip(spec.shape, value.shape, strict=True): - if isinstance(onnx_dim, str): - onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input)) - if user_dim_input != onnx_dim: raise RuntimeError + if not all_same(tuple(v.shape for v in value)): raise RuntimeError(f"sequence '{name}' must have homogenous shape. Received {value}.") + else: + for dim, (onnx_dim, value_dim) in enumerate(zip(spec.shape, value.shape, strict=True)): + if isinstance(onnx_dim, str): + onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(value_dim)) + if value_dim != onnx_dim: raise RuntimeError(f"tensor '{name}' has mismatch on {dim=}. Expected {onnx_dim}, received {value_dim}.") return value def _parse_input(self, name: str, value: Any, spec: OnnxValue): if spec.is_optional and value is None: return None elif spec.is_sequence: - if not isinstance(value, Sequence): raise RuntimeError(f"input '{name}' received {value}, expected a sequence type") + if not isinstance(value, Sequence): raise RuntimeError(f"'{name}' received {value}, expected a sequence type.") # sequence inputs and outputs are interpreted in OnnxRunner as list type value = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value] elif not isinstance(value, Tensor): value = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training) - return self._is_valid(value, spec) + return self._is_valid(name, value, spec) def _parse_output(self, name: str): value, spec = self.graph_values[name], self.graph_outputs[name] - if isinstance(value, (Tensor, list)): return self._is_valid(value, spec) + if isinstance(value, (Tensor, list)): return self._is_valid(name, value, spec) return value def _dispatch_op(self, op, inps, opts):