From 5e6a72c33f989a40cb2b3da36f94acd93522d9da Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 16 Jan 2026 22:24:07 -0500 Subject: [PATCH] new Onnx Gather (#14187) instead of assuming const indices, check if it showed as a const --- test/external/external_test_onnx_ops.py | 3 +-- tinygrad/nn/onnx.py | 22 ++++++++++++++-------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/test/external/external_test_onnx_ops.py b/test/external/external_test_onnx_ops.py index d5bb75f5a5..194d734988 100644 --- a/test/external/external_test_onnx_ops.py +++ b/test/external/external_test_onnx_ops.py @@ -89,10 +89,9 @@ class TestMainOnnxOps(TestOnnxOps): # without JIT: correct self.assertEqual([Gather(x, Tensor(idx)).tolist() for idx in indices_list], expected) - # TODO: Gather should not assume indices is const, result should be [[10, 20], [30, 40], [50, 10]] @TinyJit def gather_jit(x, indices): return Gather(x, indices) - self.assertEqual([gather_jit(x, Tensor(idx)).tolist() for idx in indices_list], [[10, 20], [30, 40], [30, 40]]) + self.assertEqual([gather_jit(x, Tensor(idx)).tolist() for idx in indices_list], expected) # NOTE: resize OP is sensitive to numerical errors def _test_resize_scales(self, scale_values, **kwargs): diff --git a/tinygrad/nn/onnx.py b/tinygrad/nn/onnx.py index 24f509ec55..da1eea9efe 100644 --- a/tinygrad/nn/onnx.py +++ b/tinygrad/nn/onnx.py @@ -404,6 +404,8 @@ class OnnxRunner: self.graph_inputs = {i["name"]: i["parsed_type"] for i in graph["input"] if i["name"] not in self.graph_values} self.graph_outputs = tuple(o["name"] for o in graph["output"]) self.graph_nodes = tuple(n["parsed_node"] for n in graph["node"]) + # track names from initializers and Constant nodes for fast path optimizations + self.const_names: set[str] = set(self.graph_values.keys()) | {o for n in self.graph_nodes if n.op == "Constant" for o in n.outputs} self.old_training = Tensor.training Tensor.training = self.is_training @@ -466,6 +468,9 @@ class OnnxRunner: # provide additional opts if node.op == "Split" and 'num_outputs' not in opts: opts['num_outputs'] = len(node.outputs) if node.op in {"Gradient", "If"}: opts['intermediate_tensors'] = self.graph_values + # for Gather, convert indices to python const if from Constant/initializer for shrink fast path + if node.op == "Gather" and len(node.inputs) > 1 and node.inputs[1] in self.const_names: + inps[1] = _cached_to_python_const(self.graph_values[node.inputs[1]]) if debug >= 1: print((f"[{self.graph_name}] " if self.graph_name else "") + f"{num}: op '{node.op}' opt {opts}") if debug >= 2 and node.inputs: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {i!r}" for x,i in zip(node.inputs, inps))) @@ -1142,15 +1147,16 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices] - def Gather(x:Tensor, indices:Tensor, axis:int=0): - if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices - ret_shape = x.shape[:axis] + indices.shape + x.shape[axis+1:] - if indices.ndim > 1: indices = indices.flatten() - index_consts = [_cached_to_python_const(indices)] if indices.shape == () else _cached_to_python_const(indices) - index_consts = [x.shape[axis]+i if i<0 else i for i in index_consts] - args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x.shape)] for i in index_consts] + def Gather(x:Tensor, indices:Tensor|list[int]|int, axis:int=0): + axis = x._resolve_dim(axis) + # fast path for constant indices (passed as python list/int from to_python_const) + if not isinstance(indices, Tensor): + indices_list = [indices] if isinstance(indices, int) else list(indices) + indices_shape = () if isinstance(indices, int) else (len(indices_list),) + ret_shape = x.shape[:axis] + indices_shape + x.shape[axis+1:] + index_consts = [x.shape[axis]+i if i<0 else i for i in indices_list] + args = [[(0,s) if j != axis else (i,i+1) for j, s in enumerate(x.shape)] for i in index_consts] return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape) - # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])] def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated