diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py index b87b2a1b39..52a2ba19a6 100644 --- a/tinygrad/features/search.py +++ b/tinygrad/features/search.py @@ -29,8 +29,8 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=Tru try: lin.linearize() prg = cast(Compiled, Device[Device.DEFAULT]).to_program(lin) - real_global_size = prg.global_size[:] - if allow_test_size: + real_global_size = prg.global_size + if allow_test_size and prg.global_size: test_global_size = prg.global_size[:] while prod(test_global_size) > max_global_size: for j in range(2,-1,-1): @@ -45,12 +45,13 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=Tru # TODO: this is super broken for var_vals # TODO: this is copied from prg.__call__ global_size, local_size = prg.launch_dims(var_vals) - if local_size is None: + if global_size is not None and local_size is None: local_size = prg.optimize_local_size(global_size, rawbufs) global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)] tms = [prg.clprg(global_size, local_size, *rawbufs, *var_vals.values(), wait=True)*factor for _ in range(cnt)] prg.global_size = real_global_size except Exception: + #import traceback; traceback.print_exc() #print("FAILED") #print(lin.ast) #print(lin.applied_opts)