mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fix CUDA float4 issues
This commit is contained in:
24
test/test_example.py
Normal file
24
test/test_example.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import unittest
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
class TestExample(unittest.TestCase):
|
||||
def test_example_readme(self):
|
||||
x = Tensor.eye(3, requires_grad=True)
|
||||
y = Tensor([[2.0,0,-2.0]], requires_grad=True)
|
||||
z = y.matmul(x).sum()
|
||||
z.backward()
|
||||
|
||||
print(x.grad.numpy()) # dz/dx
|
||||
print(y.grad.numpy()) # dz/dy
|
||||
|
||||
def test_example_matmul(self):
|
||||
x = Tensor.eye(256, requires_grad=True)
|
||||
y = Tensor.eye(256, requires_grad=True)
|
||||
z = y.matmul(x).sum()
|
||||
z.backward()
|
||||
|
||||
print(x.grad.numpy()) # dz/dx
|
||||
print(y.grad.numpy()) # dz/dy
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user