mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
This reverts commit ac713e04db.
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from typing import Any, cast, Literal, Callable
|
||||
from typing import Any, Sequence, cast, Literal, Callable
|
||||
import dataclasses, functools, io, math, types
|
||||
from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr
|
||||
from tinygrad.helpers import getenv, DEBUG, prod, flatten, make_tuple, argsort
|
||||
from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort
|
||||
from tinygrad.dtype import DType, ConstType, dtypes, ImageDType
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
||||
@@ -56,10 +56,11 @@ def type_parse(onnx_type: TypeProto):
|
||||
if elem_type.HasField("map_type") or elem_type.HasField("sparse_tensor_type") or elem_type.HasField("opaque_type"):
|
||||
raise NotImplementedError("parsing for map_type, sparse_tensor_type and opaque_type are not implemented")
|
||||
if is_optional := elem_type.HasField("optional_type"): elem_type = elem_type.optional_type.elem_type
|
||||
if is_sequence := elem_type.HasField("sequence_type"): elem_type = elem_type.sequence_type.elem_type
|
||||
if elem_type.HasField("tensor_type"):
|
||||
shape = tuple(d.dim_param or d.dim_value for d in elem_type.tensor_type.shape.dim)
|
||||
dtype = dtype_parse(elem_type.tensor_type.elem_type)
|
||||
return OnnxValue(shape, dtype, is_optional)
|
||||
return OnnxValue(shape, dtype, is_optional, is_sequence)
|
||||
raise RuntimeError(f"TypeProto was not parsed properly: {onnx_type=}")
|
||||
|
||||
# ***** onnx spec *****
|
||||
@@ -68,6 +69,7 @@ class OnnxValue:
|
||||
shape: tuple[str|int, ...]
|
||||
dtype: DType
|
||||
is_optional: bool
|
||||
is_sequence: bool
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class OnnxNode:
|
||||
@@ -115,7 +117,7 @@ class OnnxRunner:
|
||||
Tensor.no_grad = False if self.is_training else True
|
||||
self.graph_values = {"": None, **{x.name:buffer_parse(x) for x in model.graph.initializer}}
|
||||
self.graph_inputs = {x.name:type_parse(x.type) for x in model.graph.input if x.name not in self.graph_values}
|
||||
self.graph_outputs = {x.name:type_parse(x.type) for x in model.graph.output}
|
||||
self.graph_outputs = tuple(x.name for x in model.graph.output)
|
||||
self.graph_nodes = tuple(OnnxNode(num, n.op_type, tuple(n.input), tuple(n.output), {x.name:attribute_parse(x) for x in n.attribute})
|
||||
for num,n in enumerate(model.graph.node))
|
||||
self.opset_version = model.opset_import[0].version
|
||||
@@ -123,25 +125,20 @@ class OnnxRunner:
|
||||
|
||||
self.onnx_ops = onnx_ops
|
||||
|
||||
def _validate_shape(self, name: str, value: Tensor, spec: OnnxValue):
|
||||
# update new variable dims
|
||||
self.variable_dims.update({sd:vd for sd,vd in zip(spec.shape, value.shape) if isinstance(sd, str) and sd not in self.variable_dims})
|
||||
# resolve dynamic shape
|
||||
expected_shape = tuple(self.variable_dims[sd] if isinstance(sd, str) else sd for sd in spec.shape)
|
||||
assert value.shape == expected_shape, f"'{name}' has wrong shape"
|
||||
|
||||
def _parse_input(self, name: str, value: Any, spec: OnnxValue):
|
||||
if spec.is_optional and value is None: return None
|
||||
if value is None: raise RuntimeError(f"'{name}' is not marked as optional, but received a None value")
|
||||
if not isinstance(value, Tensor): value = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training)
|
||||
self._validate_shape(name, value, spec)
|
||||
return value
|
||||
|
||||
def _parse_output(self, name: str):
|
||||
value, spec = self.graph_values[name], self.graph_outputs[name]
|
||||
if not isinstance(value, Tensor): return value
|
||||
self._validate_shape(name, value, spec)
|
||||
return value
|
||||
# TODO: need true float16 for dtype checking
|
||||
if spec.is_sequence:
|
||||
if not isinstance(value, Sequence): raise RuntimeError(f"{name} received {value}, expected a sequence type")
|
||||
sequence = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value]
|
||||
if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"Shapes for {name} sequence must be homogeneous")
|
||||
return sequence
|
||||
tensor = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(value, Tensor) else value
|
||||
for dim, (onnx_dim, user_dim_input) in enumerate(zip(spec.shape, tensor.shape, strict=True)):
|
||||
if isinstance(onnx_dim, str):
|
||||
onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input))
|
||||
if user_dim_input != onnx_dim: raise RuntimeError(f"{name} has mismatch on {dim=}. Expected {onnx_dim}, received {user_dim_input}.")
|
||||
return tensor
|
||||
|
||||
def _dispatch_op(self, op, inps, opts):
|
||||
if op in self.onnx_ops:
|
||||
@@ -179,7 +176,7 @@ class OnnxRunner:
|
||||
Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
|
||||
return {name:self.graph_values[name] for name in node.outputs}
|
||||
Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
|
||||
return {name:self._parse_output(name) for name in self.graph_outputs}
|
||||
return {name:self.graph_values[name] for name in self.graph_outputs}
|
||||
|
||||
####################
|
||||
##### ONNX OPS #####
|
||||
|
||||
@@ -40,7 +40,7 @@ def get_example_inputs(graph_inputs:dict[str, OnnxValue], config={}):
|
||||
|
||||
ret: dict[str, Tensor] = {}
|
||||
for name, spec in graph_inputs.items():
|
||||
assert not spec.is_optional, "only allow tensor input for now"
|
||||
assert not spec.is_optional and not spec.is_sequence, "only allow tensor input for now"
|
||||
shape = _get_shape(spec.shape)
|
||||
value = _get_value(name, shape, spec.dtype)
|
||||
ret.update({name:value})
|
||||
|
||||
6
test/external/external_test_onnx_backend.py
vendored
6
test/external/external_test_onnx_backend.py
vendored
@@ -175,12 +175,6 @@ backend_test.exclude('test_resize_downsample_sizes_linear_antialias_cpu') # anti
|
||||
backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_cpu') # bad data type string
|
||||
backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad data type string
|
||||
|
||||
# no support for sequence
|
||||
backend_test.exclude('test_identity_opt_cpu')
|
||||
backend_test.exclude('test_identity_sequence_cpu')
|
||||
backend_test.exclude('test_optional_get_element_optional_sequence_cpu')
|
||||
backend_test.exclude('test_optional_get_element_sequence_cpu')
|
||||
|
||||
backend_test.exclude('test_scatternd_min_cpu') # min not yet supported
|
||||
backend_test.exclude('test_scatternd_max_cpu') # max not yet supported
|
||||
|
||||
|
||||
56
test/external/external_test_onnx_ops.py
vendored
56
test/external/external_test_onnx_ops.py
vendored
@@ -11,11 +11,11 @@ from extra.onnx_helpers import validate
|
||||
|
||||
class TestOnnxOps(unittest.TestCase):
|
||||
DOMAIN = None
|
||||
def helper_build_model(self, op:str, inps:dict[str, np.ndarray], opts:dict[str, Any], outs:dict[str, np.ndarray]):
|
||||
inputs = [onnx.helper.make_tensor_value_info(name, onnx.helper.np_dtype_to_tensor_dtype(arr.dtype), arr.shape) for name, arr in inps.items()]
|
||||
outputs = [onnx.helper.make_tensor_value_info(name, onnx.helper.np_dtype_to_tensor_dtype(arr.dtype), arr.shape) for name, arr in outs.items()]
|
||||
def helper_build_model(self, op:str, inps:dict[str, np.ndarray], opts:dict[str, Any], outs:list[str]):
|
||||
onnx_inputs = [onnx.helper.make_tensor_value_info(name, onnx.helper.np_dtype_to_tensor_dtype(arr.dtype), arr.shape) for name, arr in inps.items()]
|
||||
onnx_outputs = [onnx.helper.make_empty_tensor_value_info(name) for name in outs]
|
||||
nodes = [onnx.helper.make_node(op, list(inps), list(outs), domain=self.DOMAIN, **opts)]
|
||||
graph = onnx.helper.make_graph(nodes, f"test_{op.lower()}", inputs, outputs)
|
||||
graph = onnx.helper.make_graph(nodes, f"test_{op.lower()}", onnx_inputs, onnx_outputs)
|
||||
model = onnx.helper.make_model(graph, producer_name=f"test_{op.lower()}")
|
||||
return model
|
||||
|
||||
@@ -30,7 +30,7 @@ class TestMainOnnxOps(TestOnnxOps):
|
||||
def test_reshape(self):
|
||||
inputs = {"in": np.arange(6, dtype=np.float32), "shape": np.array([2,3], dtype=np.int64)}
|
||||
attributes = {}
|
||||
outputs = {"out": np.empty((2,3), dtype=np.float32)}
|
||||
outputs = ["out"]
|
||||
self.helper_test_single_op("Reshape", inputs, attributes, outputs)
|
||||
|
||||
def test_conv(self):
|
||||
@@ -41,7 +41,7 @@ class TestMainOnnxOps(TestOnnxOps):
|
||||
"b": np.random.randn(1152).astype(np.float32)
|
||||
}
|
||||
attributes = {'auto_pad': 'VALID', 'dilations': (1, 1), 'group': 1, 'kernel_shape': (14, 14), 'strides': (14, 14)}
|
||||
outputs = {"y": np.empty((1, 1152, 27, 27), dtype=np.float32)}
|
||||
outputs = ["y"]
|
||||
self.helper_test_single_op("Conv", inputs, attributes, outputs, atol=1e-4)
|
||||
|
||||
def test_gather(self):
|
||||
@@ -51,7 +51,7 @@ class TestMainOnnxOps(TestOnnxOps):
|
||||
"indices": np.array(-2, dtype=np.int64),
|
||||
}
|
||||
attributes = {'axis': 1}
|
||||
outputs = {"y": np.empty((1, 3), dtype=np.float32)}
|
||||
outputs = ["y"]
|
||||
self.helper_test_single_op("Gather", inputs, attributes, outputs)
|
||||
|
||||
def test_maxunpool(self):
|
||||
@@ -61,7 +61,7 @@ class TestMainOnnxOps(TestOnnxOps):
|
||||
output_shape = np.array((1, 1, 5, 5), dtype=np.int64)
|
||||
inputs = {"x": xT, "indices": xI, "output_shape": output_shape}
|
||||
attributes = {"kernel_shape": [2, 2], "strides": [2, 2]}
|
||||
outputs = {"y": np.empty((1, 1, 5, 5), dtype=np.float32)}
|
||||
outputs = ["y"]
|
||||
self.helper_test_single_op("MaxUnpool", inputs, attributes, outputs)
|
||||
|
||||
def test_isinf(self):
|
||||
@@ -70,7 +70,7 @@ class TestMainOnnxOps(TestOnnxOps):
|
||||
x = np.array([-1.2, np.nan, np.inf, 2.8, -np.inf, np.inf], dtype=np.float32)
|
||||
inputs = {"x": x}
|
||||
attributes = {"detect_negative":1, "detect_positive":1}
|
||||
outputs = {"y": np.empty((6,), dtype=np.bool_)}
|
||||
outputs = ["y"]
|
||||
model = self.helper_build_model("IsInf", inputs, attributes, outputs)
|
||||
outputs = OnnxRunner(model)(inputs)
|
||||
assert outputs["y"].dtype is dtypes.bool
|
||||
@@ -87,8 +87,7 @@ class TestMainOnnxOps(TestOnnxOps):
|
||||
"y_scale": np.array(case["scale"], dtype=np.float32),
|
||||
"y_zero_point": np.array(case["qzero_point"], dtype=case["qdtype"])
|
||||
}
|
||||
outputs = {"y": np.empty_like(inputs["x"], dtype=case["qdtype"])}
|
||||
self.helper_test_single_op("QuantizeLinear", inputs, {}, outputs)
|
||||
self.helper_test_single_op("QuantizeLinear", inputs, {}, ["y"])
|
||||
|
||||
def test_dynamic_quantize_linear(self):
|
||||
test_cases = [
|
||||
@@ -101,12 +100,7 @@ class TestMainOnnxOps(TestOnnxOps):
|
||||
]
|
||||
for case in test_cases:
|
||||
with self.subTest(test_case=case["name"]):
|
||||
outputs = {
|
||||
"y": np.empty_like(case["x"], dtype=np.uint8),
|
||||
"y_scale": np.empty((), dtype=np.float32),
|
||||
"y_zero_point": np.empty((), dtype=np.uint8)
|
||||
}
|
||||
self.helper_test_single_op("DynamicQuantizeLinear", {"x": case["x"]}, {}, outputs)
|
||||
self.helper_test_single_op("DynamicQuantizeLinear", {"x": case["x"]}, {}, ["y", "y_scale", "y_zero_point"])
|
||||
|
||||
def test_qlinear_conv(self):
|
||||
for dtype, zero_point in [(np.uint8, 128), (np.int8, 0)]:
|
||||
@@ -125,7 +119,7 @@ class TestMainOnnxOps(TestOnnxOps):
|
||||
"b": b
|
||||
}
|
||||
attributes = {'auto_pad': 'NOTSET', 'dilations': (1, 1), 'group': 1, 'kernel_shape': (3, 3), 'pads': (1, 1, 1, 1), 'strides': (2, 2)}
|
||||
outputs = {"out": np.empty((1, 32, 112, 112), dtype=dtype)}
|
||||
outputs = ["out"]
|
||||
self.helper_test_single_op("QLinearConv", inputs, attributes, outputs, atol=1) # occasionally inaccurate
|
||||
|
||||
def test_qlinear_matmul(self):
|
||||
@@ -143,7 +137,7 @@ class TestMainOnnxOps(TestOnnxOps):
|
||||
"Y_zero_point": np.array(zero_point, dtype=dtype)
|
||||
}
|
||||
attributes = {}
|
||||
outputs = {"Y": np.empty((10, 10), dtype=dtype)}
|
||||
outputs = ["Y"]
|
||||
self.helper_test_single_op("QLinearMatMul", inputs, attributes, outputs)
|
||||
|
||||
for name,val in (("round_half_down_to_even", 1), ("round_half_up_to_even", 3)):
|
||||
@@ -159,7 +153,7 @@ class TestMainOnnxOps(TestOnnxOps):
|
||||
"Y_zero_point": np.array(0, dtype=np.int8)
|
||||
}
|
||||
attributes = {}
|
||||
outputs = {"Y": np.empty((), dtype=np.int8)}
|
||||
outputs = ["Y"]
|
||||
self.helper_test_single_op("QLinearMatMul", inputs, attributes, outputs)
|
||||
|
||||
class TestContribOnnxOps(TestOnnxOps):
|
||||
@@ -208,8 +202,7 @@ class TestContribOnnxOps(TestOnnxOps):
|
||||
with self.subTest(f"test_attention_{i}"):
|
||||
inps = {**base_inps, **extra_inps}
|
||||
opts = {**base_opts, **extra_opts}
|
||||
if "qkv_hidden_sizes" in opts: outputs = {"output": np.empty((batch_size, seq_len, 128), dtype=np.float32)}
|
||||
else: outputs = {"output": np.empty((batch_size, seq_len, hidden_size), dtype=np.float32)}
|
||||
outputs = ["output", "present"] if "past" in inps else ["output"]
|
||||
self.helper_test_single_op("Attention", inps, opts, outputs, atol=1e-4)
|
||||
|
||||
def test_skip_layer_normalization(self):
|
||||
@@ -226,12 +219,7 @@ class TestContribOnnxOps(TestOnnxOps):
|
||||
if has_beta: inputs["beta"] = np.random.randn(hidden_size).astype(np.float32)
|
||||
if has_bias: inputs["bias"] = np.random.randn(hidden_size).astype(np.float32)
|
||||
attributes = {"epsilon": 1e-12}
|
||||
outputs = {
|
||||
"output": np.empty(shape, dtype=np.float32),
|
||||
"mean": np.empty(0, dtype=np.float32),
|
||||
"inv_std_var": np.empty(0, dtype=np.float32),
|
||||
"input_skip_bias_sum": np.empty(shape, dtype=np.float32)
|
||||
}
|
||||
outputs = ["output", "mean", "inv_std_var", "input_skip_bias_sum"]
|
||||
self.helper_test_single_op("SkipLayerNormalization", inputs, attributes, outputs)
|
||||
|
||||
def test_bias_gelu(self):
|
||||
@@ -241,7 +229,7 @@ class TestContribOnnxOps(TestOnnxOps):
|
||||
"B": np.random.randn(shape[-1]).astype(np.float32)
|
||||
}
|
||||
attributes = {}
|
||||
outputs = {"C": np.empty(shape, dtype=np.float32)}
|
||||
outputs = ["C"]
|
||||
self.helper_test_single_op("BiasGelu", inputs, attributes, outputs)
|
||||
|
||||
def test_qlinear_add(self):
|
||||
@@ -259,7 +247,7 @@ class TestContribOnnxOps(TestOnnxOps):
|
||||
"C_zero_point": np.array(zero_point, dtype=dtype)
|
||||
}
|
||||
attributes = {}
|
||||
outputs = {"C": np.empty((10, 10), dtype=dtype)}
|
||||
outputs = ["C"]
|
||||
self.helper_test_single_op("QLinearAdd", inputs, attributes, outputs, atol=1) # TODO: look into why this is inaccurate
|
||||
|
||||
with self.subTest(test_case="round_half_to_even"):
|
||||
@@ -274,7 +262,7 @@ class TestContribOnnxOps(TestOnnxOps):
|
||||
"C_zero_point": np.array(0, dtype=np.int8)
|
||||
}
|
||||
attributes = {}
|
||||
outputs = {"C": np.empty((4,), dtype=np.int8)}
|
||||
outputs = ["C"]
|
||||
self.helper_test_single_op("QLinearAdd", inputs, attributes, outputs)
|
||||
|
||||
def test_qlinear_mul(self):
|
||||
@@ -292,7 +280,7 @@ class TestContribOnnxOps(TestOnnxOps):
|
||||
"C_zero_point": np.array(zero_point, dtype=dtype)
|
||||
}
|
||||
attributes = {}
|
||||
outputs = {"C": np.empty((10, 10), dtype=dtype)}
|
||||
outputs = ["C"]
|
||||
self.helper_test_single_op("QLinearMul", inputs, attributes, outputs)
|
||||
|
||||
with self.subTest(test_case="round_half_to_even"):
|
||||
@@ -307,7 +295,7 @@ class TestContribOnnxOps(TestOnnxOps):
|
||||
"C_zero_point": np.array(0, dtype=np.int8)
|
||||
}
|
||||
attributes = {}
|
||||
outputs = {"C": np.empty((4,), dtype=np.int8)}
|
||||
outputs = ["C"]
|
||||
self.helper_test_single_op("QLinearMul", inputs, attributes, outputs)
|
||||
|
||||
def test_qlinear_global_average_pool(self):
|
||||
@@ -322,7 +310,7 @@ class TestContribOnnxOps(TestOnnxOps):
|
||||
"y_zero_point": np.array(zero_point, dtype=dtype)
|
||||
}
|
||||
attributes = {"channels_last": 0}
|
||||
outputs = {"C": np.empty((1, 3, 1, 1), dtype=dtype)}
|
||||
outputs = ["C"]
|
||||
self.helper_test_single_op("QLinearGlobalAveragePool", inputs, attributes, outputs)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user