From 9b4de8abc7341f8f58857639cbb716a0687e9421 Mon Sep 17 00:00:00 2001 From: anu Date: Sat, 27 Dec 2025 16:24:22 -0500 Subject: [PATCH] fix beam in python 3.14+ (#13836) * fix beam search on python 3.14 * add PickleableCount class to helpers * change name, add test, add step * tidy count init --- test/unit/test_helpers.py | 16 ++++++++++++++-- tinygrad/codegen/opt/postrange.py | 4 ++-- tinygrad/helpers.py | 8 ++++++++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/test/unit/test_helpers.py b/test/unit/test_helpers.py index 63eca3f705..f72486ae52 100644 --- a/test/unit/test_helpers.py +++ b/test/unit/test_helpers.py @@ -1,6 +1,6 @@ -import ctypes, gzip, unittest, timeit +import ctypes, gzip, unittest, timeit, pickle from tinygrad import Variable -from tinygrad.helpers import Context, ContextVar, argfix, colored, word_wrap, is_numpy_ndarray, mv_address, get_contraction +from tinygrad.helpers import Context, ContextVar, argfix, colored, word_wrap, is_numpy_ndarray, mv_address, get_contraction, count from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, polyN, time_to_str, cdiv, cmod, getbits from tinygrad.tensor import Tensor, get_shape import numpy as np @@ -120,6 +120,18 @@ class TestRoundUp(unittest.TestCase): self.assertEqual(round_up(232, 24984), 24984) self.assertEqual(round_up(24984, 232), 25056) +class TestCount(unittest.TestCase): + def test_count_basic(self): + c = count(3) + self.assertEqual(next(c), 3) + self.assertEqual(next(c), 4) + + def test_count_step_pickle(self): + c = count(1, 2) + self.assertEqual(next(c), 1) + c2 = pickle.loads(pickle.dumps(c)) + self.assertEqual(next(c2), 3) + @unittest.skip("no fetch tests because they need internet") class TestFetch(unittest.TestCase): def test_fetch_bad_http(self): diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py index cc2809efb4..b00bd5bea3 100644 --- a/tinygrad/codegen/opt/postrange.py +++ b/tinygrad/codegen/opt/postrange.py @@ -6,7 +6,7 @@ from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos from tinygrad.device import Buffer from tinygrad.dtype import dtypes, ImageDType from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten -from tinygrad.helpers import ALLOW_TF32 +from tinygrad.helpers import ALLOW_TF32, count from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check from tinygrad.codegen.simplify import pm_flatten_range from tinygrad.renderer import Renderer @@ -18,7 +18,7 @@ class Scheduler: self.ast, self.ren = ast, ren self.dont_use_locals = self.ast.arg.dont_use_locals if self.ast.arg is not None else False self.applied_opts = list(self.ast.arg.applied_opts) if self.ast.arg is not None else [] - self.opt_range = itertools.count(start=max([x.arg[0] for x in self.rngs], default=0)+1) + self.opt_range = count(start=max([x.arg[0] for x in self.rngs], default=0)+1) @property def rngs(self): diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index e0000f1590..7514bd34ae 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -542,3 +542,11 @@ copyreg.pickle(types.CodeType, _serialize_code) def _serialize_module(module:types.ModuleType): return importlib.import_module, (module.__name__,) copyreg.pickle(types.ModuleType, _serialize_module) + +class count: + def __init__(self, start:int=0, step:int=1): + self.n, self.step = start, step + def __next__(self) -> int: + cur = self.n + self.n += self.step + return cur