mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
@@ -3058,6 +3058,19 @@ class TestOps(unittest.TestCase):
|
||||
def test_bitcast(self):
|
||||
helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True)
|
||||
|
||||
def test_svd(self):
|
||||
# test for tiny backend. real svd tests are in test_linalg
|
||||
A = torch.randn(5, 5)
|
||||
U, S, Vh = torch.linalg.svd(A, full_matrices=True)
|
||||
np.testing.assert_allclose(torch.dist(A, U @ torch.diag(S) @ Vh).cpu().numpy(), 0, atol=1e-5)
|
||||
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
|
||||
np.testing.assert_allclose(torch.dist(A, U @ torch.diag(S) @ Vh).cpu().numpy(), 0, atol=1e-5)
|
||||
|
||||
# # TODO: this works with torch, but not TINY_BACKEND. U has a wrong shape
|
||||
# A = torch.randn(5, 3)
|
||||
# U, S, Vh = torch.linalg.svd(A, full_matrices=False)
|
||||
# np.testing.assert_allclose(torch.dist(A, U @ torch.diag(S) @ Vh).cpu().numpy(), 0, atol=1e-5)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
|
||||
class TestOpsUint8(unittest.TestCase):
|
||||
def test_cast(self):
|
||||
|
||||
Reference in New Issue
Block a user