tensor.py: add normalize function (#15159)

* tensor.py: add normalize function

* p==0 should match torch
This commit is contained in:
Roelof van Dijk
2026-03-05 11:55:53 +01:00
committed by GitHub
parent 4544da1c54
commit d65923bda5
3 changed files with 32 additions and 0 deletions

View File

@@ -505,7 +505,9 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
"aten.lt.Tensor_out": Tensor.__lt__, "aten.lt.Scalar_out": Tensor.__lt__,
"aten.le.Tensor_out": Tensor.__le__, "aten.le.Scalar_out": Tensor.__le__,
"aten.clamp_max.Tensor_out": lambda input,max_: input.clamp(max_=max_),
"aten.clamp_max.out": lambda input,max_: input.clamp(max_=max_),
"aten.clamp_min.Tensor_out": lambda input,min_: input.clamp(min_=min_),
"aten.clamp_min.out": lambda input,min_: input.clamp(min_=min_),
"aten.fmod.Tensor_out": lambda input,other: input-input.div(other, rounding_mode="trunc")*other,
# TODO: this might result in overflow issues
"aten.round.decimals_out": lambda self,decimals: (self*10**decimals).round()/10**decimals,

View File

@@ -1665,6 +1665,15 @@ class TestOps(unittest.TestCase):
helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7)
helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7)
def test_normalize(self):
helper_test_op([(45,65)], lambda x: torch.nn.functional.normalize(x), lambda x: x.normalize(), atol=1e-7, grad_atol=1e-7)
helper_test_op([(45,65)], lambda x: torch.nn.functional.normalize(x, dim=0), lambda x: x.normalize(dim=0), atol=1e-7, grad_atol=1e-7)
helper_test_op([(10,10,10)], lambda x: torch.nn.functional.normalize(x, dim=2), lambda x: x.normalize(dim=2), atol=1e-7, grad_atol=1e-7)
helper_test_op([(45,65)], lambda x: torch.nn.functional.normalize(x, p=1), lambda x: x.normalize(p=1), atol=1e-7, grad_atol=1e-7)
helper_test_op([(45,65)], lambda x: torch.nn.functional.normalize(x, p=3, dim=0), lambda x: x.normalize(p=3, dim=0), atol=1e-7, grad_atol=1e-7)
helper_test_op([(45,65)], lambda x: torch.nn.functional.normalize(x, p=0), lambda x: x.normalize(p=0), atol=1e-7, grad_atol=1e-7)
helper_test_op([(45,65)], lambda x: torch.nn.functional.normalize(x, p=-1), lambda x: x.normalize(p=-1), atol=1e-7, grad_atol=1e-7)
def test_logsumexp(self):
helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0, keepdim=True), lambda x: x.logsumexp(0, True), atol=1e-7, grad_atol=1e-7)

View File

@@ -2028,6 +2028,27 @@ class Tensor(OpMixin):
m, _, ss = self._softmax(axis, dtype)
return m - ss.log()
def normalize(self, p:float=2.0, dim:int=1, eps:float=1e-12) -> Tensor:
"""
Performs Lp normalization of the tensor along the specified dimension.
See: https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.randn(2, 3)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.normalize().numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.normalize(p=1, dim=0).numpy())
```
"""
if p == 0: return self / (self != 0).sum(dim, keepdim=True).maximum(eps)
return self / self.abs().pow(p).sum(dim, keepdim=True).pow(1/p).maximum(eps)
def logsumexp(self, axis=None, keepdim=False) -> Tensor:
"""
Computes the log-sum-exp of the tensor along the specified axis or axes.