Fix cuda runtime (#625)

This commit is contained in:
Martin Loretz
2023-03-02 15:52:34 +01:00
committed by GitHub
parent fca055bd66
commit 51fb6aeb45

View File

@@ -24,7 +24,7 @@ class CUDAProgram:
global_size = global_size + [1] * (3 - len(global_size))
assert all(x%y == 0 for x,y in zip(global_size, local_size)), f"local:{local_size} must divide global:{global_size}"
global_size = [x//y for x,y in zip(global_size, local_size)]
-self.prg(*args, block=tuple(local_size), grid=tuple(global_size))
+self.prg(*[x._cl for x in args], block=tuple(local_size), grid=tuple(global_size))
class CUDACodegen(GPUCodegen):
lang = GPULanguage(