fix beam search for llvm, this needs tests (#2101)

This commit is contained in:
George Hotz
2023-10-17 20:09:42 -07:00
committed by GitHub
parent 4d1e59abfd
commit 2498802b46

View File

@@ -29,8 +29,8 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=Tru
try:
lin.linearize()
prg = cast(Compiled, Device[Device.DEFAULT]).to_program(lin)
real_global_size = prg.global_size[:]
if allow_test_size:
real_global_size = prg.global_size
if allow_test_size and prg.global_size:
test_global_size = prg.global_size[:]
while prod(test_global_size) > max_global_size:
for j in range(2,-1,-1):
@@ -45,12 +45,13 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=Tru
# TODO: this is super broken for var_vals
# TODO: this is copied from prg.__call__
global_size, local_size = prg.launch_dims(var_vals)
if local_size is None:
if global_size is not None and local_size is None:
local_size = prg.optimize_local_size(global_size, rawbufs)
global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
tms = [prg.clprg(global_size, local_size, *rawbufs, *var_vals.values(), wait=True)*factor for _ in range(cnt)]
prg.global_size = real_global_size
except Exception:
#import traceback; traceback.print_exc()
#print("FAILED")
#print(lin.ast)
#print(lin.applied_opts)