mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
shared_codegen_spec and fix index spec (#12967)
* split shared_codegen_spec and fix index * add VCONST to program_spec and move index to shared_codegen_spec * working ignore_oob=0 * cleanup * fix spec * undo that * move barrier and special earlier * fix more spec issues * more updates * remove special from program_spec * cleanup and fixes * move more to shared * special is not in shared_spec * some comments * dont do bounds check there
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import unittest, functools
|
||||
from tinygrad import Tensor
|
||||
from tinygrad import Tensor, Context
|
||||
import numpy as np
|
||||
|
||||
def orthogonality_helper(A:Tensor, tolerance=1e-5):
|
||||
@@ -27,15 +27,16 @@ class TestLinAlg(unittest.TestCase):
|
||||
reconstruction_helper([U,s_diag,V],a)
|
||||
|
||||
def _test_svd_nonfull(self, size):
|
||||
a = Tensor.randn(size).realize()
|
||||
U,S,V = a.svd(full_matrices=False)
|
||||
b_shape,m,n = size[0:-2],size[-2],size[-1]
|
||||
k = min(m,n)
|
||||
s_diag = (S.unsqueeze(-2) * Tensor.eye(k).reshape((1,) * len(b_shape) + (k,k)).expand(b_shape + (k,k)))
|
||||
#reduced U,V is only orthogonal along smaller dim
|
||||
if (m < n): orthogonality_helper(U),orthogonality_helper(V)
|
||||
else: orthogonality_helper(U.transpose(-2,-1)),orthogonality_helper(V.transpose(-2,-1))
|
||||
reconstruction_helper([U,s_diag,V],a)
|
||||
with Context(IGNORE_OOB=1): # sometimes this is slow in CI
|
||||
a = Tensor.randn(size).realize()
|
||||
U,S,V = a.svd(full_matrices=False)
|
||||
b_shape,m,n = size[0:-2],size[-2],size[-1]
|
||||
k = min(m,n)
|
||||
s_diag = (S.unsqueeze(-2) * Tensor.eye(k).reshape((1,) * len(b_shape) + (k,k)).expand(b_shape + (k,k)))
|
||||
#reduced U,V is only orthogonal along smaller dim
|
||||
if (m < n): orthogonality_helper(U),orthogonality_helper(V)
|
||||
else: orthogonality_helper(U.transpose(-2,-1)),orthogonality_helper(V.transpose(-2,-1))
|
||||
reconstruction_helper([U,s_diag,V],a)
|
||||
|
||||
# faster for parallel pytest
|
||||
def test_svd_nonfull_2_2(self): self._test_svd_nonfull((2,2))
|
||||
@@ -75,4 +76,4 @@ class TestLinAlg(unittest.TestCase):
|
||||
orthogonality_helper(b if size[-1] > size[-2] else b.transpose(-2, -1), tolerance=1e-3)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user