improve rendering of shapes in viz + investigate symbolic [pr] (#10091)

This commit is contained in:
George Hotz
2025-04-28 16:44:09 -04:00
committed by GitHub
parent dbb7aee02e
commit d32f5e9f3a
4 changed files with 20 additions and 9 deletions

View File

@@ -1,6 +1,6 @@
import unittest
from tinygrad import Variable
from tinygrad.helpers import Context
from tinygrad.helpers import Context, GlobalCounters
from tinygrad.tensor import Tensor
from examples.gpt2 import Attention
import numpy as np
@@ -43,17 +43,24 @@ class TestSymbolicOps(unittest.TestCase):
expected = f(a, b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_attention(self, dropout_p=0.0):
def test_attention(self, dropout_p=0.0, imin=1, imax=5, use_symbolic=True):
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p).realize()
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
for i in range(imin, imax):
vi = Variable("i", 1, 10).bind(i) if use_symbolic else i
q = Tensor.rand(2, 1, 4, 8)
k = Tensor.rand(2, i, 4, 8)
v = Tensor.rand(2, i, 4, 8)
Tensor.realize(q, k, v)
GlobalCounters.reset()
symbolic = f(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy()
expected = f(q, k, v).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_attention_cmp_symbolic(self):
# symbolic isn't seeing if i == i, so it's not putting them on the same axis
self.test_attention(imin=4, imax=5, use_symbolic=False)
self.test_attention(imin=4, imax=5, use_symbolic=True)
def test_attention_training(self):
with Tensor.train():
self.test_attention(dropout_p=0.0)