mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
add max/min reduction support to ScatterND (#13562)
This commit is contained in:
3
test/external/external_test_onnx_backend.py
vendored
3
test/external/external_test_onnx_backend.py
vendored
@@ -184,9 +184,6 @@ backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad d
|
||||
backend_test.exclude('test_if_opt_cpu') # ValueError: 13 is not a valid AttributeType
|
||||
backend_test.exclude('test_if_seq_cpu') # NotImplementedError: op='SequenceConstruct' is not supported
|
||||
|
||||
backend_test.exclude('test_scatternd_min_cpu') # min not yet supported
|
||||
backend_test.exclude('test_scatternd_max_cpu') # max not yet supported
|
||||
|
||||
# regression from removing StrEnum in Domain
|
||||
backend_test.exclude('test_adam_cpu')
|
||||
backend_test.exclude('test_gradient_of_add_and_mul_cpu')
|
||||
|
||||
@@ -1158,7 +1158,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1])
|
||||
ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))]
|
||||
return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:])
|
||||
def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'):
|
||||
def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul", "max", "min"]='none'):
|
||||
assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):]
|
||||
x = x.contiguous()
|
||||
for index, u in zip(indices.split(1, 0), updates.split(1, 0)):
|
||||
@@ -1167,7 +1167,8 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
if reduction == "none": x[i] = u
|
||||
elif reduction == "add": x[i] += u
|
||||
elif reduction == "mul": x[i] *= u
|
||||
else: raise NotImplementedError("reduction doesn't support max or min")
|
||||
elif reduction == "max": x[i] = x[i].maximum(u)
|
||||
elif reduction == "min": x[i] = x[i].minimum(u)
|
||||
return x
|
||||
|
||||
def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul", "min", "max"]="none"):
|
||||
|
||||
Reference in New Issue
Block a user