hotfix: ptx threads match cuda threads

This commit is contained in:
George Hotz
2024-07-28 16:53:24 -07:00
parent 460b120d62
commit 5b84a7db1a

View File

@@ -21,7 +21,7 @@ class PTXRenderer(Renderer):
global_max = (2147483647, 65535, 65535)
local_max = (1024, 1024, 64)
shared_max = 49152
tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501
tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(1,2)], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501
def __init__(self, arch:str, device="CUDA"): self.device, self.tensor_cores = device, PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else []
# language options