exclude ast with variables from beam search (#2140)

* exclude ast with variables from beam search

* test that

* add to CI
This commit is contained in:
chenyu
2023-10-25 16:35:29 -04:00
committed by GitHub
parent a52b420fb3
commit 0ca0e9ee5e
3 changed files with 20 additions and 1 deletions

View File

@@ -187,6 +187,9 @@ jobs:
- if: ${{ matrix.task == 'optimization' }}
name: Test Action Space
run: PYTHONPATH="." GPU=1 python3 extra/optimization/get_action_space.py
- if: ${{ matrix.task == 'optimization' }}
name: Test Beam Search
run: PYTHONPATH="." GPU=1 python3 -m pytest extra/optimization/test_beam_search.py
testmetalwebgpu:
name: Metal and WebGPU Tests

View File

@@ -0,0 +1,16 @@
import unittest
from tinygrad.helpers import BEAM
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor
class TestBeamSearch(unittest.TestCase):
def setUp(self):
self.old_beam = BEAM.value
BEAM.value = 2
def tearDown(self):
BEAM.value = self.old_beam
def test_variable_ast_no_beam(self):
a = Tensor.rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3))
a = (a+1).realize()

View File

@@ -288,7 +288,7 @@ class Compiled:
assert k.info.dtype == output.dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {output.dtype}"
if not getenv("NOOPT"):
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
if BEAM:
if BEAM and not vars_from_ast(ast):
kb = Linearizer(ast, self.linearizer_opts)
kb.required_optimizations()
kb.dont_use_locals = bool(getenv("NOLOCALS"))