Files
tinygrad/extra/optimization/test_beam_search.py
nimlgen 4e0d47533e beam works with var vals (#2296)
* beam works with var vals

* test passes now

* better comment

* linter happy
2023-11-14 13:03:19 -05:00

35 lines
851 B
Python

import unittest
import numpy as np
from tinygrad.helpers import BEAM, Timing
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
class TestBeamSearch(unittest.TestCase):
  """Run small workloads with BEAM search enabled (BEAM.value = 2).

  setUp stashes the previous BEAM value and tearDown puts it back, so the
  setting does not leak into other tests.
  """

  def setUp(self):
    # The right-hand side is evaluated first, so old_beam captures the value
    # BEAM held before it is overwritten with 2.
    self.old_beam, BEAM.value = BEAM.value, 2

  def tearDown(self):
    BEAM.value = self.old_beam

  def test_variable_ast_beam(self):
    """Realize `x + 1` where x's first dim is a Variable("a", 1, 10) bound to 3.

    There is no assertion: the test passes as long as realize() under BEAM
    completes on the symbolic-shaped tensor.
    """
    base = Tensor.rand(3, 3)
    sym_rows = Variable("a", 1, 10).bind(3)
    shifted = base.reshape((sym_rows, 3)) + 1
    shifted.realize()

  def test_no_mutate_rawbuffers(self):
    """Check that `t.assign(t + 1)` under BEAM leaves t equal to its old contents plus one.

    NOTE(review): the name suggests this guards against beam search altering
    the real (raw) buffers while it tries kernels; confirm.
    """
    t = Tensor.rand(3, 3).realize()
    expected = t.numpy() + 1
    t.assign(t + 1)
    np.testing.assert_allclose(t.numpy(), expected)

  def test_conv_beam(self):
    """Run Conv2d(3, 16, (3, 3)) on a random (1, 3, 32, 32) input under BEAM.

    The forward pass is realized inside a Timing() context.
    """
    conv = Conv2d(3, 16, (3, 3))
    image = Tensor.rand(1, 3, 32, 32)
    with Timing():
      conv(image).realize()
if __name__ == '__main__':
  # Entry point when this file is executed directly instead of being collected
  # by a test runner: discover and run the TestCase classes defined above.
  unittest.main()