fix odd number in get_test_global_size (#10671)

factor might not be a integer if input global_size has an odd number in it
This commit is contained in:
chenyu
2025-06-06 17:31:35 -04:00
committed by GitHub
parent bf4ffc054c
commit bdede4924e
2 changed files with 16 additions and 5 deletions

11
test/unit/test_search.py Normal file
View File

@@ -0,0 +1,11 @@
import unittest
from tinygrad.engine.search import get_test_global_size
class TestSearchUtil(unittest.TestCase):
def test_get_test_global_size(self):
self.assertEqual(get_test_global_size([256, 256, 256], 65536, {}), ([256, 16, 16], 256.0))
self.assertEqual(get_test_global_size([65536, 1, 1], 256, {}), ([256, 1, 1], 256.0))
self.assertEqual(get_test_global_size([77, 1, 1], 16, {}), ([9, 1, 1], 77/9))
if __name__ == "__main__":
unittest.main()

View File

@@ -25,21 +25,21 @@ actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2), getenv("T
actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)]
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
def _get_test_global_size(global_size, max_global_size, var_vals):
test_global_size, factor = [sym_infer(sz, var_vals) for sz in global_size], 1
def get_test_global_size(global_size, max_global_size, var_vals):
test_global_size = [sym_infer(sz, var_vals) for sz in global_size]
input_size = prod(test_global_size)
while prod(test_global_size) > max_global_size:
for j in range(len(global_size)-1,-1,-1):
if test_global_size[j] > 16:
test_global_size[j] //= 2
factor *= 2
break
return test_global_size, factor
return test_global_size, input_size / prod(test_global_size)
def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbufs:list[Buffer], early_stop:Optional[float]=None,
allow_test_size:int=True, max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
factor = 1
if allow_test_size and p.global_size is not None and max_global_size is not None:
global_size, factor = _get_test_global_size(p.global_size, max_global_size, var_vals)
global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals)
p = replace(p, global_size=global_size)
try: car = CompiledRunner(p, precompiled=lib)
except AssertionError: return [math.inf] * cnt