shared_codegen_spec and fix index spec (#12967)

* split shared_codegen_spec and fix index

* add VCONST to program_spec and move index to shared_codegen_spec

* working ignore_oob=0

* cleanup

* fix spec

* undo that

* move barrier and special earlier

* fix more spec issues

* more updates

* remove special from program_spec

* cleanup and fixes

* move more to shared

* special is not in shared_spec

* some comments

* don't do bounds check there
This commit is contained in:
Sieds Lykles
2025-10-29 09:14:11 +01:00
committed by GitHub
parent 1c362736aa
commit 9f39f6391c
7 changed files with 101 additions and 85 deletions

View File

@@ -1,5 +1,5 @@
import unittest, functools
from tinygrad import Tensor
from tinygrad import Tensor, Context
import numpy as np
def orthogonality_helper(A:Tensor, tolerance=1e-5):
@@ -27,15 +27,16 @@ class TestLinAlg(unittest.TestCase):
reconstruction_helper([U,s_diag,V],a)
def _test_svd_nonfull(self, size):
a = Tensor.randn(size).realize()
U,S,V = a.svd(full_matrices=False)
b_shape,m,n = size[0:-2],size[-2],size[-1]
k = min(m,n)
s_diag = (S.unsqueeze(-2) * Tensor.eye(k).reshape((1,) * len(b_shape) + (k,k)).expand(b_shape + (k,k)))
#reduced U,V is only orthogonal along smaller dim
if (m < n): orthogonality_helper(U),orthogonality_helper(V)
else: orthogonality_helper(U.transpose(-2,-1)),orthogonality_helper(V.transpose(-2,-1))
reconstruction_helper([U,s_diag,V],a)
with Context(IGNORE_OOB=1): # sometimes this is slow in CI
a = Tensor.randn(size).realize()
U,S,V = a.svd(full_matrices=False)
b_shape,m,n = size[0:-2],size[-2],size[-1]
k = min(m,n)
s_diag = (S.unsqueeze(-2) * Tensor.eye(k).reshape((1,) * len(b_shape) + (k,k)).expand(b_shape + (k,k)))
#reduced U,V is only orthogonal along smaller dim
if (m < n): orthogonality_helper(U),orthogonality_helper(V)
else: orthogonality_helper(U.transpose(-2,-1)),orthogonality_helper(V.transpose(-2,-1))
reconstruction_helper([U,s_diag,V],a)
# faster for parallel pytest
def test_svd_nonfull_2_2(self): self._test_svd_nonfull((2,2))
@@ -75,4 +76,4 @@ class TestLinAlg(unittest.TestCase):
orthogonality_helper(b if size[-1] > size[-2] else b.transpose(-2, -1), tolerance=1e-3)
if __name__ == "__main__":
unittest.main()
unittest.main()