mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
fix beam search for llvm, this needs tests (#2101)
This commit is contained in:
@@ -29,8 +29,8 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=Tru
|
||||
try:
|
||||
lin.linearize()
|
||||
prg = cast(Compiled, Device[Device.DEFAULT]).to_program(lin)
|
||||
real_global_size = prg.global_size[:]
|
||||
if allow_test_size:
|
||||
real_global_size = prg.global_size
|
||||
if allow_test_size and prg.global_size:
|
||||
test_global_size = prg.global_size[:]
|
||||
while prod(test_global_size) > max_global_size:
|
||||
for j in range(2,-1,-1):
|
||||
@@ -45,12 +45,13 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=Tru
|
||||
# TODO: this is super broken for var_vals
|
||||
# TODO: this is copied from prg.__call__
|
||||
global_size, local_size = prg.launch_dims(var_vals)
|
||||
if local_size is None:
|
||||
if global_size is not None and local_size is None:
|
||||
local_size = prg.optimize_local_size(global_size, rawbufs)
|
||||
global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
|
||||
tms = [prg.clprg(global_size, local_size, *rawbufs, *var_vals.values(), wait=True)*factor for _ in range(cnt)]
|
||||
prg.global_size = real_global_size
|
||||
except Exception:
|
||||
#import traceback; traceback.print_exc()
|
||||
#print("FAILED")
|
||||
#print(lin.ast)
|
||||
#print(lin.applied_opts)
|
||||
|
||||
Reference in New Issue
Block a user