From 5b84a7db1a2af3304c211600089bfcd7ecd52a8d Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 28 Jul 2024 16:53:24 -0700 Subject: [PATCH] hotfix: ptx threads match cuda threads --- tinygrad/renderer/assembly.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index e60281c94d..aad47f0065 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -21,7 +21,7 @@ class PTXRenderer(Renderer): global_max = (2147483647, 65535, 65535) local_max = (1024, 1024, 64) shared_max = 49152 - tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501 + tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(1,2)], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501 def __init__(self, arch:str, device="CUDA"): self.device, self.tensor_cores = device, PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else [] # language options