new Onnx Gather (#14187)

Instead of assuming the indices are const, check whether they actually show up as consts (from an initializer or a Constant node).
This commit is contained in:
chenyu
2026-01-16 22:24:07 -05:00
committed by GitHub
parent 9f7f2f0e0c
commit 5e6a72c33f
2 changed files with 15 additions and 10 deletions

View File

@@ -89,10 +89,9 @@ class TestMainOnnxOps(TestOnnxOps):
# without JIT: correct
self.assertEqual([Gather(x, Tensor(idx)).tolist() for idx in indices_list], expected)
# TODO: Gather should not assume indices is const, result should be [[10, 20], [30, 40], [50, 10]]
@TinyJit
def gather_jit(x, indices): return Gather(x, indices)
self.assertEqual([gather_jit(x, Tensor(idx)).tolist() for idx in indices_list], [[10, 20], [30, 40], [30, 40]])
self.assertEqual([gather_jit(x, Tensor(idx)).tolist() for idx in indices_list], expected)
# NOTE: resize OP is sensitive to numerical errors
def _test_resize_scales(self, scale_values, **kwargs):

View File

@@ -404,6 +404,8 @@ class OnnxRunner:
self.graph_inputs = {i["name"]: i["parsed_type"] for i in graph["input"] if i["name"] not in self.graph_values}
self.graph_outputs = tuple(o["name"] for o in graph["output"])
self.graph_nodes = tuple(n["parsed_node"] for n in graph["node"])
# track names from initializers and Constant nodes for fast path optimizations
self.const_names: set[str] = set(self.graph_values.keys()) | {o for n in self.graph_nodes if n.op == "Constant" for o in n.outputs}
self.old_training = Tensor.training
Tensor.training = self.is_training
@@ -466,6 +468,9 @@ class OnnxRunner:
# provide additional opts
if node.op == "Split" and 'num_outputs' not in opts: opts['num_outputs'] = len(node.outputs)
if node.op in {"Gradient", "If"}: opts['intermediate_tensors'] = self.graph_values
# for Gather, convert indices to python const if from Constant/initializer for shrink fast path
if node.op == "Gather" and len(node.inputs) > 1 and node.inputs[1] in self.const_names:
inps[1] = _cached_to_python_const(self.graph_values[node.inputs[1]])
if debug >= 1: print((f"[{self.graph_name}] " if self.graph_name else "") + f"{num}: op '{node.op}' opt {opts}")
if debug >= 2 and node.inputs: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {i!r}" for x,i in zip(node.inputs, inps)))
@@ -1142,15 +1147,16 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
def ArrayFeatureExtractor(x:Tensor, indices:Tensor):
  # ONNX-ML ArrayFeatureExtractor: select entries along the last axis of x at the given indices.
  return x[(Ellipsis, indices)]
def Gather(x:Tensor, indices:Tensor, axis:int=0):
if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices
ret_shape = x.shape[:axis] + indices.shape + x.shape[axis+1:]
if indices.ndim > 1: indices = indices.flatten()
index_consts = [_cached_to_python_const(indices)] if indices.shape == () else _cached_to_python_const(indices)
index_consts = [x.shape[axis]+i if i<0 else i for i in index_consts]
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x.shape)] for i in index_consts]
def Gather(x:Tensor, indices:Tensor|list[int]|int, axis:int=0):
  """ONNX Gather: take slices of x along `axis` at the given indices.

  `indices` is either a Tensor (general path) or a python int/list — the runner
  converts Constant/initializer indices to python consts so the shrink fast
  path can be used.
  """
  axis = x._resolve_dim(axis)
  if isinstance(indices, Tensor):
    # general path: advanced indexing — fixed number of kernels, but can
    # exceed the limited kernel count for openpilot
    slicer = [indices if dim == axis else slice(None) for dim in range(x.ndim)]
    return x[tuple(slicer)]
  # fast path for constant indices (passed as python list/int from to_python_const)
  scalar = isinstance(indices, int)
  idx_list = [indices] if scalar else list(indices)
  # a scalar index drops the gathered axis from the output shape
  out_shape = x.shape[:axis] + (() if scalar else (len(idx_list),)) + x.shape[axis+1:]
  normed = [i + x.shape[axis] if i < 0 else i for i in idx_list]
  shrink_args = [tuple((i, i+1) if d == axis else (0, sz) for d, sz in enumerate(x.shape)) for i in normed]
  pieces = [x.shrink(arg=a) for a in shrink_args]
  return pieces[0].cat(*pieces[1:], dim=axis).reshape(out_shape)
def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated