mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
improve err msg
This commit is contained in:
@@ -125,29 +125,29 @@ class OnnxRunner:
|
||||
|
||||
self.onnx_ops = onnx_ops
|
||||
|
||||
def _is_valid(self, value: Tensor | list, spec: OnnxValue):
|
||||
def _is_valid(self, name: str, value: Tensor | list, spec: OnnxValue):
|
||||
if isinstance(value, list):
|
||||
assert spec.is_sequence
|
||||
if not all_same(tuple(v.shape for v in value)): raise RuntimeError
|
||||
return value
|
||||
for onnx_dim, user_dim_input in zip(spec.shape, value.shape, strict=True):
|
||||
if isinstance(onnx_dim, str):
|
||||
onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input))
|
||||
if user_dim_input != onnx_dim: raise RuntimeError
|
||||
if not all_same(tuple(v.shape for v in value)): raise RuntimeError(f"sequence '{name}' must have homogenous shape. Received {value}.")
|
||||
else:
|
||||
for dim, (onnx_dim, value_dim) in enumerate(zip(spec.shape, value.shape, strict=True)):
|
||||
if isinstance(onnx_dim, str):
|
||||
onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(value_dim))
|
||||
if value_dim != onnx_dim: raise RuntimeError(f"tensor '{name}' has mismatch on {dim=}. Expected {onnx_dim}, received {value_dim}.")
|
||||
return value
|
||||
|
||||
def _parse_input(self, name: str, value: Any, spec: OnnxValue):
|
||||
if spec.is_optional and value is None: return None
|
||||
elif spec.is_sequence:
|
||||
if not isinstance(value, Sequence): raise RuntimeError(f"input '{name}' received {value}, expected a sequence type")
|
||||
if not isinstance(value, Sequence): raise RuntimeError(f"'{name}' received {value}, expected a sequence type.")
|
||||
# sequence inputs and outputs are interpreted in OnnxRunner as list type
|
||||
value = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value]
|
||||
elif not isinstance(value, Tensor): value = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training)
|
||||
return self._is_valid(value, spec)
|
||||
return self._is_valid(name, value, spec)
|
||||
|
||||
def _parse_output(self, name: str):
|
||||
value, spec = self.graph_values[name], self.graph_outputs[name]
|
||||
if isinstance(value, (Tensor, list)): return self._is_valid(value, spec)
|
||||
if isinstance(value, (Tensor, list)): return self._is_valid(name, value, spec)
|
||||
return value
|
||||
|
||||
def _dispatch_op(self, op, inps, opts):
|
||||
|
||||
Reference in New Issue
Block a user