diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0f90ba745e..6aaa41210e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -35,16 +35,18 @@ jobs: IMAGE=2 PYTHON=1 python3 test/test_ops.py TestOps.test_simple_conv2d - name: Test emulated METAL tensor cores run: DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_big_gemm - - name: Test emulated HIP tensor cores + - name: Test emulated HSA tensor cores run: | - PYTHONPATH=. DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py - PYTHONPATH=. DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py - PYTHONPATH=. DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py - PYTHONPATH=. DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py + PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py + PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py + PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py + PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py + - name: Test enumlated CUDA tensor cores + run: DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_big_gemm - name: Full test tensor cores run: | DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores - DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores + DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores - name: Test dtype with Python emulator run: PYTHONPATH=. DEBUG=2 PYTHON=1 python3 test/test_dtype.py diff --git a/README.md b/README.md index 339e4d4a38..2384692cb1 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ tinygrad already supports numerous accelerators, including: - [x] [LLVM](tinygrad/runtime/ops_llvm.py) - [x] [METAL](tinygrad/runtime/ops_metal.py) - [x] [CUDA](tinygrad/runtime/ops_cuda.py) -- [x] [HIP](tinygrad/runtime/ops_hip.py) +- [x] [HSA](tinygrad/runtime/ops_hsa.py) And it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops. More information can be found in the [documentation for adding new accelerators](/docs/adding_new_accelerators.md). diff --git a/test/test_linearizer_overflows.py b/test/test_linearizer_overflows.py index c03de6094f..930ce48e4f 100644 --- a/test/test_linearizer_overflows.py +++ b/test/test_linearizer_overflows.py @@ -63,7 +63,7 @@ class TestLinearizerOverflow(unittest.TestCase): opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)] _test_overflow(ast, opts) -#@unittest.skipIf(Device.DEFAULT not in {"GPU", "HIP", "HSA", "CUDA", "METAL"}, "only backends with locals") +#@unittest.skipIf(Device.DEFAULT not in {"GPU", "HSA", "CUDA", "METAL"}, "only backends with locals") @unittest.skipIf(CI, "slow") class TestLinearizerOverflowAlt(unittest.TestCase): def test_overflow_1(self): diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 40859b63ff..71653e4950 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -56,7 +56,7 @@ tensor_cores: Dict[str, List[TensorCore]] = { TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__metal_wmma", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501 TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__metal_wmma", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501 ], - "HIP": [ + "HSA": [ TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__builtin_amdgcn_wmma_f32_16x16x16_f16_w32", threads=[(0,16),(1,2)], thread_local_sizes=[[16],[16],[8]], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501 TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__hip_wmma_f16_f16", threads=[(0,16),(1,2)], thread_local_sizes=[[16],[16],[8]], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501 ], @@ -64,7 +64,6 @@ tensor_cores: Dict[str, List[TensorCore]] = { TensorCore(dims=[8,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__cuda_mma_m16n8k16_f16_f32", threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[-2],[5],[0],[0],[-1,1,2,-3],[3,4]], [[5],[0],[0],[4],[3],[-1,1,2,-2],[0]], [[2],[-2],[5],[1],[-1],[0],[3,4]] ]), # noqa: E501 ], } -tensor_cores["HSA"] = tensor_cores["HIP"] class LocalBuffer(NamedTuple): name: str diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py index 1175e9d2db..0540957a74 100644 --- a/tinygrad/features/search.py +++ b/tinygrad/features/search.py @@ -107,7 +107,7 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea beam: List[Tuple[Linearizer, float]] = [] seen_libs = set() - default_parallel, min_progress_micros = 1 if lin.opts.device in {"CUDA", "HIP", "HSA"} else 0, getenv("BEAM_MIN_PROGRESS",0.01) + default_parallel, min_progress_micros = 1 if lin.opts.device in {"CUDA", "HSA"} else 0, getenv("BEAM_MIN_PROGRESS",0.01) if beam_pool is None and getenv("PARALLEL", default_parallel): beam_pool = multiprocessing.Pool(multiprocessing.cpu_count(), _init_worker) try: diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 807199146a..eb729fa7e9 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -82,11 +82,10 @@ class LazyBuffer: def is_unrealized_contiguous_const(self): return self.base == self and not self.base.realized and self.op is LoadOps.CONST def _copy(self, device:str) -> LazyBuffer: - sync_size = 1 if self.device.startswith("HIP") else 0 if self.device.startswith("EXT") or self.device.startswith("DISK"): # DISK/EXT don't sync return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self,), enable_cache=False) - sync = LazyBuffer.loadop(LoadOps.SYNC, (sync_size,), dtypes.uint32, self.device, src=(self,), enable_cache=True) + sync = LazyBuffer.loadop(LoadOps.SYNC, (0,), dtypes.uint32, self.device, src=(self,), enable_cache=True) wait = LazyBuffer.loadop(LoadOps.WAIT, (0,), dtypes.uint32, device, src=(sync,), enable_cache=True) return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self, wait), enable_cache=False) diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index aa868cf5e2..d86707fcb6 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -188,7 +188,7 @@ class PythonProgram: class PythonCompiler(Compiler): linearizer_opts = LinearizerOptions("METAL", has_tensor_cores=True) if getenv("EMULATE_METAL") else \ - (LinearizerOptions("HIP", has_tensor_cores=True) if getenv("EMULATE_HIP") else \ + (LinearizerOptions("HSA", has_tensor_cores=True) if getenv("EMULATE_HSA") else \ (LinearizerOptions("CUDA", has_tensor_cores=True) if getenv("EMULATE_CUDA") else LinearizerOptions("PYTHON"))) def render(self, name:str, uops:UOpGraph) -> str: lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops]