mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
Fix cuda runtime (#625)
This commit is contained in:
@@ -24,7 +24,7 @@ class CUDAProgram:
|
||||
global_size = global_size + [1] * (3 - len(global_size))
|
||||
assert all(x%y == 0 for x,y in zip(global_size, local_size)), f"local:{local_size} must divide global:{global_size}"
|
||||
global_size = [x//y for x,y in zip(global_size, local_size)]
|
||||
self.prg(*args, block=tuple(local_size), grid=tuple(global_size))
|
||||
self.prg(*[x._cl for x in args], block=tuple(local_size), grid=tuple(global_size))
|
||||
|
||||
class CUDACodegen(GPUCodegen):
|
||||
lang = GPULanguage(
|
||||
|
||||
Reference in New Issue
Block a user