mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix symbolic_ops tests with Tensor.training=True (#1686)
This commit is contained in:
@@ -44,9 +44,8 @@ class TestSymbolicOps(unittest.TestCase):
|
||||
with self.assertRaises(AssertionError):
|
||||
f(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "CLANG" and CI, "broken on CLANG CI")
|
||||
def test_attention(self):
|
||||
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
|
||||
def test_attention(self, dropout_p=0.0):
|
||||
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p).realize()
|
||||
vi = Variable("i", 1, 10)
|
||||
for i in range(1, 5):
|
||||
q = Tensor.rand(2, 1, 4, 8)
|
||||
@@ -56,6 +55,13 @@ class TestSymbolicOps(unittest.TestCase):
|
||||
expected = f(q, k, v).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_attention_training(self):
|
||||
Tensor.training = True
|
||||
self.test_attention(dropout_p=0.0)
|
||||
with self.assertRaises(TypeError):
|
||||
# symbolic shape dropout is not supported
|
||||
self.test_attention(dropout_p=0.5)
|
||||
|
||||
def test_cat_dim0(self):
|
||||
def f(a, b): return a.cat(b, dim=0).realize()
|
||||
vi = Variable("i", 1, 10)
|
||||
|
||||
Reference in New Issue
Block a user