minor onnx Gather cleanup (#11375)

removed a type ignore and one error code skip
This commit is contained in:
chenyu
2025-07-25 21:08:08 -04:00
committed by GitHub
parent 88c338bfcc
commit 3d68feb67d

View File

@@ -1,4 +1,4 @@
# mypy: disable-error-code="misc, list-item, assignment, attr-defined, operator, index, arg-type"
# mypy: disable-error-code="misc, list-item, assignment, operator, index, arg-type"
from types import SimpleNamespace
from typing import Any, Sequence, cast, Literal, Callable, get_args, NamedTuple
import dataclasses, functools, io, math, types, warnings, pathlib, sys, enum
@@ -798,12 +798,11 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
def Gather(x:Tensor, indices:Tensor, axis:int=0):
if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices
x_sh = list(x.shape)
ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
ret_shape = x.shape[:axis] + indices.shape + x.shape[axis+1:]
if indices.ndim > 1: indices = indices.flatten()
indices = [_cached_to_python_const(indices)] if indices.shape == () else _cached_to_python_const(indices)
indices = [x_sh[axis]+x if x<0 else x for x in indices]
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore
index_consts = [_cached_to_python_const(indices)] if indices.shape == () else _cached_to_python_const(indices)
index_consts = [x.shape[axis]+i if i<0 else i for i in index_consts]
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x.shape)] for i in index_consts]
return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
# NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot
return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]