mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
new Onnx Gather (#14187)
instead of assuming the indices are const, check whether they actually showed up as a const
This commit is contained in:
3
test/external/external_test_onnx_ops.py
vendored
3
test/external/external_test_onnx_ops.py
vendored
@@ -89,10 +89,9 @@ class TestMainOnnxOps(TestOnnxOps):
|
||||
# without JIT: correct
|
||||
self.assertEqual([Gather(x, Tensor(idx)).tolist() for idx in indices_list], expected)
|
||||
|
||||
# TODO: Gather should not assume indices is const, result should be [[10, 20], [30, 40], [50, 10]]
|
||||
@TinyJit
|
||||
def gather_jit(x, indices): return Gather(x, indices)
|
||||
self.assertEqual([gather_jit(x, Tensor(idx)).tolist() for idx in indices_list], [[10, 20], [30, 40], [30, 40]])
|
||||
self.assertEqual([gather_jit(x, Tensor(idx)).tolist() for idx in indices_list], expected)
|
||||
|
||||
# NOTE: resize OP is sensitive to numerical errors
|
||||
def _test_resize_scales(self, scale_values, **kwargs):
|
||||
|
||||
@@ -404,6 +404,8 @@ class OnnxRunner:
|
||||
self.graph_inputs = {i["name"]: i["parsed_type"] for i in graph["input"] if i["name"] not in self.graph_values}
|
||||
self.graph_outputs = tuple(o["name"] for o in graph["output"])
|
||||
self.graph_nodes = tuple(n["parsed_node"] for n in graph["node"])
|
||||
# track names from initializers and Constant nodes for fast path optimizations
|
||||
self.const_names: set[str] = set(self.graph_values.keys()) | {o for n in self.graph_nodes if n.op == "Constant" for o in n.outputs}
|
||||
|
||||
self.old_training = Tensor.training
|
||||
Tensor.training = self.is_training
|
||||
@@ -466,6 +468,9 @@ class OnnxRunner:
|
||||
# provide additional opts
|
||||
if node.op == "Split" and 'num_outputs' not in opts: opts['num_outputs'] = len(node.outputs)
|
||||
if node.op in {"Gradient", "If"}: opts['intermediate_tensors'] = self.graph_values
|
||||
# for Gather, convert indices to python const if from Constant/initializer for shrink fast path
|
||||
if node.op == "Gather" and len(node.inputs) > 1 and node.inputs[1] in self.const_names:
|
||||
inps[1] = _cached_to_python_const(self.graph_values[node.inputs[1]])
|
||||
|
||||
if debug >= 1: print((f"[{self.graph_name}] " if self.graph_name else "") + f"{num}: op '{node.op}' opt {opts}")
|
||||
if debug >= 2 and node.inputs: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {i!r}" for x,i in zip(node.inputs, inps)))
|
||||
@@ -1142,15 +1147,16 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
|
||||
def ArrayFeatureExtractor(x:Tensor, indices:Tensor):
  """Select entries of `x` along its last axis at the positions given by `indices`."""
  return x[(Ellipsis, indices)]
|
||||
|
||||
def Gather(x:Tensor, indices:Tensor, axis:int=0):
|
||||
if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices
|
||||
ret_shape = x.shape[:axis] + indices.shape + x.shape[axis+1:]
|
||||
if indices.ndim > 1: indices = indices.flatten()
|
||||
index_consts = [_cached_to_python_const(indices)] if indices.shape == () else _cached_to_python_const(indices)
|
||||
index_consts = [x.shape[axis]+i if i<0 else i for i in index_consts]
|
||||
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x.shape)] for i in index_consts]
|
||||
def Gather(x:Tensor, indices:Tensor|list[int]|int, axis:int=0):
  """Gather slices of `x` along `axis` at `indices`.

  `indices` arrives either as a Tensor, or — when the runner detected it came from a
  Constant/initializer — already converted to a python int or list of ints.
  """
  axis = x._resolve_dim(axis)
  if isinstance(indices, Tensor):
    # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot
    return x[tuple(indices if dim == axis else slice(None) for dim in range(x.ndim))]
  # fast path for constant indices (passed as python list/int from to_python_const)
  scalar = isinstance(indices, int)
  idx_list = [indices] if scalar else list(indices)
  # a scalar index drops the gathered dimension; a list keeps it with the list's length
  out_shape = x.shape[:axis] + (() if scalar else (len(idx_list),)) + x.shape[axis+1:]
  # normalize negative indices against the size of the gathered dimension
  normalized = [i + x.shape[axis] if i < 0 else i for i in idx_list]
  # one shrink arg per index: full range on every dim except `axis`, which narrows to [i, i+1)
  shrink_args = [tuple((i, i+1) if d == axis else (0, s) for d, s in enumerate(x.shape)) for i in normalized]
  pieces = [x.shrink(arg=sa) for sa in shrink_args]
  return pieces[0].cat(*pieces[1:], dim=axis).reshape(out_shape)
|
||||
def Scatter(*args, **kwargs):
  """Deprecated op kept as a thin alias; forwards all arguments to ScatterElements."""
  return ScatterElements(*args, **kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user