fix CUDA float4 issues

This commit is contained in:
George Hotz
2023-03-06 07:16:38 -08:00
parent 7dbcc26582
commit 8c5dea8d72
3 changed files with 38 additions and 6 deletions

24
test/test_example.py Normal file
View File

@@ -0,0 +1,24 @@
import unittest
from tinygrad.tensor import Tensor
class TestExample(unittest.TestCase):
def test_example_readme(self):
x = Tensor.eye(3, requires_grad=True)
y = Tensor([[2.0,0,-2.0]], requires_grad=True)
z = y.matmul(x).sum()
z.backward()
print(x.grad.numpy()) # dz/dx
print(y.grad.numpy()) # dz/dy
def test_example_matmul(self):
x = Tensor.eye(256, requires_grad=True)
y = Tensor.eye(256, requires_grad=True)
z = y.matmul(x).sum()
z.backward()
print(x.grad.numpy()) # dz/dx
print(y.grad.numpy()) # dz/dy
if __name__ == '__main__':
unittest.main()