improve err msg

This commit is contained in:
zibokapi
2025-04-03 12:32:28 +08:00
parent 288170b2d9
commit dc9eaea4bb

View File

@@ -125,29 +125,29 @@ class OnnxRunner:
self.onnx_ops = onnx_ops
def _is_valid(self, value: Tensor | list, spec: OnnxValue):
def _is_valid(self, name: str, value: Tensor | list, spec: OnnxValue):
if isinstance(value, list):
assert spec.is_sequence
if not all_same(tuple(v.shape for v in value)): raise RuntimeError
return value
for onnx_dim, user_dim_input in zip(spec.shape, value.shape, strict=True):
if isinstance(onnx_dim, str):
onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input))
if user_dim_input != onnx_dim: raise RuntimeError
if not all_same(tuple(v.shape for v in value)): raise RuntimeError(f"sequence '{name}' must have homogenous shape. Received {value}.")
else:
for dim, (onnx_dim, value_dim) in enumerate(zip(spec.shape, value.shape, strict=True)):
if isinstance(onnx_dim, str):
onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(value_dim))
if value_dim != onnx_dim: raise RuntimeError(f"tensor '{name}' has mismatch on {dim=}. Expected {onnx_dim}, received {value_dim}.")
return value
def _parse_input(self, name: str, value: Any, spec: OnnxValue):
if spec.is_optional and value is None: return None
elif spec.is_sequence:
if not isinstance(value, Sequence): raise RuntimeError(f"input '{name}' received {value}, expected a sequence type")
if not isinstance(value, Sequence): raise RuntimeError(f"'{name}' received {value}, expected a sequence type.")
# sequence inputs and outputs are interpreted in OnnxRunner as list type
value = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value]
elif not isinstance(value, Tensor): value = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training)
return self._is_valid(value, spec)
return self._is_valid(name, value, spec)
def _parse_output(self, name: str):
value, spec = self.graph_values[name], self.graph_outputs[name]
if isinstance(value, (Tensor, list)): return self._is_valid(value, spec)
if isinstance(value, (Tensor, list)): return self._is_valid(name, value, spec)
return value
def _dispatch_op(self, op, inps, opts):