Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-08 22:48:25 -05:00
add max/min reduction support to ScatterND (#13562)
test/external/external_test_onnx_backend.py (vendored, 3 lines changed)
@@ -184,9 +184,6 @@ backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad d
 backend_test.exclude('test_if_opt_cpu') # ValueError: 13 is not a valid AttributeType
 backend_test.exclude('test_if_seq_cpu') # NotImplementedError: op='SequenceConstruct' is not supported
 
-backend_test.exclude('test_scatternd_min_cpu') # min not yet supported
-backend_test.exclude('test_scatternd_max_cpu') # max not yet supported
-
 # regression from removing StrEnum in Domain
 backend_test.exclude('test_adam_cpu')
 backend_test.exclude('test_gradient_of_add_and_mul_cpu')
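The two un-excluded tests exercise ONNX's ScatterND reduction semantics. As a reading aid (not part of the commit), a minimal loop-form reference of what the ONNX operator spec requires, written in numpy; scatter_nd_ref is a hypothetical helper name:

    import numpy as np

    def scatter_nd_ref(data, indices, updates, reduction="none"):
      # loop-form ScatterND per the ONNX operator spec
      out = np.copy(data)
      k = indices.shape[-1]                            # length of each index tuple
      flat_idx = indices.reshape(-1, k)                # one index tuple per update slice
      flat_upd = updates.reshape(-1, *data.shape[k:])  # matching update slices
      for idx, upd in zip(flat_idx, flat_upd):
        i = tuple(idx)
        if reduction == "none": out[i] = upd
        elif reduction == "add": out[i] += upd
        elif reduction == "mul": out[i] *= upd
        elif reduction == "max": out[i] = np.maximum(out[i], upd)
        elif reduction == "min": out[i] = np.minimum(out[i], upd)
      return out

    # reduction="max" keeps the larger of existing value and update:
    # scatter_nd_ref(np.array([1., 5., 3.]), np.array([[1], [2]]), np.array([2., 9.]), "max")
    # -> [1., 5., 9.]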
@@ -1158,7 +1158,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
     b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1])
     ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))]
     return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:])
-  def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'):
+  def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul", "max", "min"]='none'):
     assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):]
     x = x.contiguous()
     for index, u in zip(indices.split(1, 0), updates.split(1, 0)):
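The assert encodes ScatterND's shape contract: updates must have shape indices.shape[:-1] + x.shape[indices.shape[-1]:]. A small illustrative check with hypothetical shapes, assuming tinygrad's public Tensor API:

    from tinygrad import Tensor

    x = Tensor.zeros(4, 3)         # data to scatter into
    indices = Tensor([[0], [2]])   # shape (2, 1): two 1-tuple indices, each picks a row
    updates = Tensor.ones(2, 3)    # required: indices.shape[:-1] + x.shape[1:] == (2, 3)
    assert updates.shape == indices.shape[:-1] + x.shape[int(indices.shape[-1]):]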
@@ -1167,7 +1167,8 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
       if reduction == "none": x[i] = u
       elif reduction == "add": x[i] += u
       elif reduction == "mul": x[i] *= u
-      else: raise NotImplementedError("reduction doesn't support max or min")
+      elif reduction == "max": x[i] = x[i].maximum(u)
+      elif reduction == "min": x[i] = x[i].minimum(u)
     return x
 
   def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul", "min", "max"]="none"):
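Tensor.maximum and Tensor.minimum are tinygrad's elementwise max/min ops. A quick sanity sketch of the new branches, approximating a single loop iteration of the handler rather than the op's actual test harness:

    from tinygrad import Tensor

    x = Tensor([1.0, 5.0, 3.0, 4.0]).contiguous()  # setitem needs a contiguous tensor
    x[1] = x[1].maximum(Tensor(2.0))               # max(5.0, 2.0) -> 5.0, existing value wins
    x[2] = x[2].minimum(Tensor(1.0))               # min(3.0, 1.0) -> 1.0, update wins
    print(x.tolist())                              # [1.0, 5.0, 1.0, 4.0]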