mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
Add equal function implementation and corresponding test (#10351)
- Implemented a new function `equal` in the torch backend to compare two tensors for equality. - Added unit tests for the `equal` function to verify its correctness with different tensor inputs.
This commit is contained in:
@@ -583,6 +583,9 @@ def wrap_fxn(k,f):
|
||||
|
||||
for k,v in tiny_backend.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_fxn(k,v))
|
||||
|
||||
@torch.library.impl("aten::equal", "privateuseone")
|
||||
def equal(x: torch.Tensor, y: torch.Tensor): return (x==y).all().item()
|
||||
|
||||
if TORCH_DEBUG:
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
class DispatchLog(TorchDispatchMode):
|
||||
|
||||
@@ -163,6 +163,13 @@ class TestTorchBackend(unittest.TestCase):
|
||||
a = torch.randn(10, 10, device=device, dtype=torch_dtype)
|
||||
self.assertEqual(a.dtype, torch_dtype)
|
||||
|
||||
def test_equal(self):
|
||||
tensor_a = torch.tensor([[1, 2], [3, 4]], device=device)
|
||||
tensor_b = torch.tensor([[1, 2], [3, 4]], device=device)
|
||||
tensor_c = torch.tensor([[1, 2], [1, 2]], device=device)
|
||||
assert torch.equal(tensor_a, tensor_b)
|
||||
assert not torch.equal(tensor_a, tensor_c)
|
||||
|
||||
@unittest.skip("meh")
|
||||
def test_str(self):
|
||||
a = torch.ones(4, device=device)
|
||||
|
||||
Reference in New Issue
Block a user