From 3d68feb67dd335e8d59e5a174cb945e704bb62d6 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 25 Jul 2025 21:08:08 -0400 Subject: [PATCH] minor onnx Gather cleanup (#11375) removed a type ignore and one error code skip --- extra/onnx.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/extra/onnx.py b/extra/onnx.py index 2ae3c2d2e5..2c634b4f06 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -1,4 +1,4 @@ -# mypy: disable-error-code="misc, list-item, assignment, attr-defined, operator, index, arg-type" +# mypy: disable-error-code="misc, list-item, assignment, operator, index, arg-type" from types import SimpleNamespace from typing import Any, Sequence, cast, Literal, Callable, get_args, NamedTuple import dataclasses, functools, io, math, types, warnings, pathlib, sys, enum @@ -798,12 +798,11 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT def Gather(x:Tensor, indices:Tensor, axis:int=0): if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices - x_sh = list(x.shape) - ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:] + ret_shape = x.shape[:axis] + indices.shape + x.shape[axis+1:] if indices.ndim > 1: indices = indices.flatten() - indices = [_cached_to_python_const(indices)] if indices.shape == () else _cached_to_python_const(indices) - indices = [x_sh[axis]+x if x<0 else x for x in indices] - args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore + index_consts = [_cached_to_python_const(indices)] if indices.shape == () else _cached_to_python_const(indices) + index_consts = [x.shape[axis]+i if i<0 else i for i in index_consts] + args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x.shape)] for i in index_consts] return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape) # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]