From 286b0f40518c5556fddae23e8b94ac547a43f178 Mon Sep 17 00:00:00 2001 From: Xingyu Date: Sat, 17 May 2025 14:39:49 +0800 Subject: [PATCH] Add equal function implementation and corresponding test (#10351) - Implemented a new function `equal` in the torch backend to compare two tensors for equality. - Added unit tests for the `equal` function to verify its correctness with different tensor inputs. --- extra/torch_backend/backend.py | 3 +++ extra/torch_backend/test.py | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py index e811221ba7..3d10577bdd 100644 --- a/extra/torch_backend/backend.py +++ b/extra/torch_backend/backend.py @@ -583,6 +583,9 @@ def wrap_fxn(k,f): for k,v in tiny_backend.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_fxn(k,v)) +@torch.library.impl("aten::equal", "privateuseone") +def equal(x: torch.Tensor, y: torch.Tensor): return (x==y).all().item() + if TORCH_DEBUG: from torch.utils._python_dispatch import TorchDispatchMode class DispatchLog(TorchDispatchMode): diff --git a/extra/torch_backend/test.py b/extra/torch_backend/test.py index 30117c3dd1..f92be531c1 100644 --- a/extra/torch_backend/test.py +++ b/extra/torch_backend/test.py @@ -163,6 +163,13 @@ class TestTorchBackend(unittest.TestCase): a = torch.randn(10, 10, device=device, dtype=torch_dtype) self.assertEqual(a.dtype, torch_dtype) + def test_equal(self): + tensor_a = torch.tensor([[1, 2], [3, 4]], device=device) + tensor_b = torch.tensor([[1, 2], [3, 4]], device=device) + tensor_c = torch.tensor([[1, 2], [1, 2]], device=device) + assert torch.equal(tensor_a, tensor_b) + assert not torch.equal(tensor_a, tensor_c) + @unittest.skip("meh") def test_str(self): a = torch.ones(4, device=device)