onnx TensorScatter (#14024)

This commit is contained in:
chenyu
2026-01-05 09:05:22 -05:00
committed by GitHub
parent 9497ec00f2
commit 83063cc3e4
2 changed files with 11 additions and 1 deletions

View File

@@ -167,7 +167,6 @@ backend_test.exclude('test_split_to_sequence_*')
backend_test.exclude('test_ai_onnx_ml_tree_ensemble_*') # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/aionnxml/op_tree_ensemble.py#L121
# TODO: not yet implemented
backend_test.exclude('test_tensorscatter_*')
backend_test.exclude('test_l1normalization_*')
backend_test.exclude('test_l2normalization_*')
backend_test.exclude('test_lpnormalization_*')

View File

@@ -1171,6 +1171,17 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
elif reduction == "min": x[i] = x[i].minimum(u)
return x
def TensorScatter(data: Tensor, updates: Tensor, indices: Tensor, mode: str = 'default'):
# scatter updates along axis -2 at positions given by indices, for each batch
B, U, D = indices.shape[0], updates.shape[-2], data.shape[-2]
orig_shape, data_flat, updates_flat = data.shape, data.reshape(-1, D, data.shape[-1]), updates.reshape(-1, U, updates.shape[-1])
B_total = data_flat.shape[0]
batch_idx = Tensor.arange(B_total, device=data.device).reshape(B_total, 1).expand(B_total, U)
indices_expanded = indices.reshape(B, *([1] * (data.ndim - 3))).expand(*orig_shape[:-2]).reshape(B_total)
row_idx = indices_expanded.reshape(B_total, 1).expand(B_total, U) + Tensor.arange(U, device=data.device).reshape(1, U).expand(B_total, U)
if mode == 'circular': row_idx = row_idx % D
return ScatterND(data_flat, batch_idx.unsqueeze(-1).cat(row_idx.unsqueeze(-1), dim=-1), updates_flat).reshape(orig_shape)
def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul", "min", "max"]="none"):
indices = (indices < 0).where(x.shape[axis], 0) + indices
if reduction == "none": return x.scatter(axis, indices, updates)