mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix odd number in get_test_global_size (#10671)
factor might not be a integer if input global_size has an odd number in it
This commit is contained in:
11
test/unit/test_search.py
Normal file
11
test/unit/test_search.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import unittest
|
||||
from tinygrad.engine.search import get_test_global_size
|
||||
|
||||
class TestSearchUtil(unittest.TestCase):
|
||||
def test_get_test_global_size(self):
|
||||
self.assertEqual(get_test_global_size([256, 256, 256], 65536, {}), ([256, 16, 16], 256.0))
|
||||
self.assertEqual(get_test_global_size([65536, 1, 1], 256, {}), ([256, 1, 1], 256.0))
|
||||
self.assertEqual(get_test_global_size([77, 1, 1], 16, {}), ([9, 1, 1], 77/9))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -25,21 +25,21 @@ actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2), getenv("T
|
||||
actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)]
|
||||
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
|
||||
|
||||
def _get_test_global_size(global_size, max_global_size, var_vals):
|
||||
test_global_size, factor = [sym_infer(sz, var_vals) for sz in global_size], 1
|
||||
def get_test_global_size(global_size, max_global_size, var_vals):
|
||||
test_global_size = [sym_infer(sz, var_vals) for sz in global_size]
|
||||
input_size = prod(test_global_size)
|
||||
while prod(test_global_size) > max_global_size:
|
||||
for j in range(len(global_size)-1,-1,-1):
|
||||
if test_global_size[j] > 16:
|
||||
test_global_size[j] //= 2
|
||||
factor *= 2
|
||||
break
|
||||
return test_global_size, factor
|
||||
return test_global_size, input_size / prod(test_global_size)
|
||||
|
||||
def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbufs:list[Buffer], early_stop:Optional[float]=None,
|
||||
allow_test_size:int=True, max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
|
||||
factor = 1
|
||||
if allow_test_size and p.global_size is not None and max_global_size is not None:
|
||||
global_size, factor = _get_test_global_size(p.global_size, max_global_size, var_vals)
|
||||
global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals)
|
||||
p = replace(p, global_size=global_size)
|
||||
try: car = CompiledRunner(p, precompiled=lib)
|
||||
except AssertionError: return [math.inf] * cnt
|
||||
|
||||
Reference in New Issue
Block a user