mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
onnx TensorScatter (#14024)
This commit is contained in:
1
test/external/external_test_onnx_backend.py
vendored
1
test/external/external_test_onnx_backend.py
vendored
@@ -167,7 +167,6 @@ backend_test.exclude('test_split_to_sequence_*')
|
||||
backend_test.exclude('test_ai_onnx_ml_tree_ensemble_*') # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/aionnxml/op_tree_ensemble.py#L121
|
||||
|
||||
# TODO: not yet implemented
|
||||
backend_test.exclude('test_tensorscatter_*')
|
||||
backend_test.exclude('test_l1normalization_*')
|
||||
backend_test.exclude('test_l2normalization_*')
|
||||
backend_test.exclude('test_lpnormalization_*')
|
||||
|
||||
@@ -1171,6 +1171,17 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
elif reduction == "min": x[i] = x[i].minimum(u)
|
||||
return x
|
||||
|
||||
def TensorScatter(data: Tensor, updates: Tensor, indices: Tensor, mode: str = 'default'):
|
||||
# scatter updates along axis -2 at positions given by indices, for each batch
|
||||
B, U, D = indices.shape[0], updates.shape[-2], data.shape[-2]
|
||||
orig_shape, data_flat, updates_flat = data.shape, data.reshape(-1, D, data.shape[-1]), updates.reshape(-1, U, updates.shape[-1])
|
||||
B_total = data_flat.shape[0]
|
||||
batch_idx = Tensor.arange(B_total, device=data.device).reshape(B_total, 1).expand(B_total, U)
|
||||
indices_expanded = indices.reshape(B, *([1] * (data.ndim - 3))).expand(*orig_shape[:-2]).reshape(B_total)
|
||||
row_idx = indices_expanded.reshape(B_total, 1).expand(B_total, U) + Tensor.arange(U, device=data.device).reshape(1, U).expand(B_total, U)
|
||||
if mode == 'circular': row_idx = row_idx % D
|
||||
return ScatterND(data_flat, batch_idx.unsqueeze(-1).cat(row_idx.unsqueeze(-1), dim=-1), updates_flat).reshape(orig_shape)
|
||||
|
||||
def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul", "min", "max"]="none"):
|
||||
indices = (indices < 0).where(x.shape[axis], 0) + indices
|
||||
if reduction == "none": return x.scatter(axis, indices, updates)
|
||||
|
||||
Reference in New Issue
Block a user