mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
exclude ast with variables from beam search (#2140)
* exclude ast with variables from beam search * test that * add to CI
This commit is contained in:
3
.github/workflows/test.yml
vendored
3
.github/workflows/test.yml
vendored
@@ -187,6 +187,9 @@ jobs:
|
||||
- if: ${{ matrix.task == 'optimization' }}
|
||||
name: Test Action Space
|
||||
run: PYTHONPATH="." GPU=1 python3 extra/optimization/get_action_space.py
|
||||
- if: ${{ matrix.task == 'optimization' }}
|
||||
name: Test Beam Search
|
||||
run: PYTHONPATH="." GPU=1 python3 -m pytest extra/optimization/test_beam_search.py
|
||||
|
||||
testmetalwebgpu:
|
||||
name: Metal and WebGPU Tests
|
||||
|
||||
16
extra/optimization/test_beam_search.py
Normal file
16
extra/optimization/test_beam_search.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import unittest
|
||||
|
||||
from tinygrad.helpers import BEAM
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
class TestBeamSearch(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.old_beam = BEAM.value
|
||||
BEAM.value = 2
|
||||
def tearDown(self):
|
||||
BEAM.value = self.old_beam
|
||||
|
||||
def test_variable_ast_no_beam(self):
|
||||
a = Tensor.rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3))
|
||||
a = (a+1).realize()
|
||||
@@ -288,7 +288,7 @@ class Compiled:
|
||||
assert k.info.dtype == output.dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {output.dtype}"
|
||||
if not getenv("NOOPT"):
|
||||
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
|
||||
if BEAM:
|
||||
if BEAM and not vars_from_ast(ast):
|
||||
kb = Linearizer(ast, self.linearizer_opts)
|
||||
kb.required_optimizations()
|
||||
kb.dont_use_locals = bool(getenv("NOLOCALS"))
|
||||
|
||||
Reference in New Issue
Block a user