mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
hotfix: ptx threads match cuda threads
This commit is contained in:
@@ -21,7 +21,7 @@ class PTXRenderer(Renderer):
|
||||
global_max = (2147483647, 65535, 65535)
|
||||
local_max = (1024, 1024, 64)
|
||||
shared_max = 49152
|
||||
tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501
|
||||
tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(1,2)], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501
|
||||
def __init__(self, arch:str, device="CUDA"): self.device, self.tensor_cores = device, PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else []
|
||||
|
||||
# language options
|
||||
|
||||
Reference in New Issue
Block a user