Add equal function implementation and corresponding test (#10351)

- Implemented a new function `equal` in the torch backend to compare two tensors for equality.
- Added unit tests for the `equal` function to verify its correctness with different tensor inputs.
This commit is contained in:
Xingyu
2025-05-17 14:39:49 +08:00
committed by GitHub
parent e13f2a3092
commit 286b0f4051
2 changed files with 10 additions and 0 deletions

View File

@@ -583,6 +583,9 @@ def wrap_fxn(k,f):
for k,v in tiny_backend.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_fxn(k,v))
@torch.library.impl("aten::equal", "privateuseone")
def equal(x: torch.Tensor, y: torch.Tensor): return (x==y).all().item()
if TORCH_DEBUG:
from torch.utils._python_dispatch import TorchDispatchMode
class DispatchLog(TorchDispatchMode):

View File

@@ -163,6 +163,13 @@ class TestTorchBackend(unittest.TestCase):
a = torch.randn(10, 10, device=device, dtype=torch_dtype)
self.assertEqual(a.dtype, torch_dtype)
def test_equal(self):
tensor_a = torch.tensor([[1, 2], [3, 4]], device=device)
tensor_b = torch.tensor([[1, 2], [3, 4]], device=device)
tensor_c = torch.tensor([[1, 2], [1, 2]], device=device)
assert torch.equal(tensor_a, tensor_b)
assert not torch.equal(tensor_a, tensor_c)
@unittest.skip("meh")
def test_str(self):
a = torch.ones(4, device=device)