remove hip backend (#3783)

* remove hip backend

* remove unused

* rhip

* more RHIP
Author: George Hotz
Date: 2024-03-17 10:12:16 -07:00
Committed by: GitHub
Parent: 2a14d1b5e0
Commit: 53adcb34f5
11 changed files with 33 additions and 15 deletions


@@ -375,7 +375,7 @@ jobs:
path: ~/.cache/tinygrad/downloads/
key: downloads-cache-${{ matrix.backend }}-${{ env.DOWNLOAD_CACHE_VERSION }}
- name: Set env
-run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'cuda' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\n' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nPTX=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas' || matrix.backend == 'hip' && 'HIP=1\nHIPCPU=1\nFORWARD_ONLY=1' }}" >> $GITHUB_ENV
+run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'cuda' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\n' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nPTX=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas' || matrix.backend == 'hip' && 'RHIP=1\nFORWARD_ONLY=1' }}" >> $GITHUB_ENV
- name: Install OpenCL
if: matrix.backend == 'gpu'
run: |
@@ -435,7 +435,7 @@ jobs:
run: pip install -e '.[testing${{matrix.backend=='llvm'&&',llvm'||matrix.backend=='cuda'&&',cuda'||matrix.backend=='ptx'&&',cuda'||matrix.backend=='triton'&&',triton'||''}}]' --extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/
- name: Check Device.DEFAULT and print some source
run: |
-python -c "from tinygrad import Device; assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU', 'HIP'], Device.DEFAULT"
+python -c "from tinygrad import Device; assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU','RHIP'], Device.DEFAULT"
DEBUG=5 PYTHONPATH=${{ github.workspace }} FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add
- name: Verify OpenCL autogen
if: matrix.backend == 'gpu'
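
Side note, not part of the diff: the assert in the step above is easy to reproduce locally. A minimal sketch, assuming a tinygrad checkout is importable and that setting RHIP=1 before import is enough for Device.DEFAULT to report the new backend without loading libremu.so:

# minimal local version of the CI check above; the env var must be set before tinygrad is imported
import os
os.environ["RHIP"] = "1"
from tinygrad import Device
assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU','RHIP'], Device.DEFAULT
print(Device.DEFAULT)  # expected: RHIP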


@@ -23,7 +23,7 @@ def assert_jit_cache_len(fxn, expected_len):
def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
if dtype == dtypes.bfloat16:
# NOTE: this requires bf16 buffer support
-return device in ["HIP"]
+return device in {"RHIP", "HSA"}
if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
# for CI GPU, cl_khr_fp16 isn't supported
# for CI LLVM, it segfaults because it can't link to the casting function
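
is_dtype_supported is how the rest of the test suite gates dtype-specific cases, so after this change bf16 tests only run on RHIP or HSA. A hedged sketch of that gating pattern (assumes it is run from the repo root so test.helpers imports; the test class and values are illustrative, not from the repo):

import unittest
from tinygrad import Tensor, dtypes
from test.helpers import is_dtype_supported  # the helper changed in the hunk above

class TestBF16Gate(unittest.TestCase):
  @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "needs bf16 buffer support (RHIP/HSA)")
  def test_bf16_roundtrip(self):
    # 1.0 and 2.0 are exactly representable in bfloat16, so the cast round-trips losslessly
    out = Tensor([1.0, 2.0]).cast(dtypes.bfloat16).cast(dtypes.float32).numpy()
    assert out.tolist() == [1.0, 2.0]

if __name__ == "__main__":
  unittest.main()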


@@ -142,7 +142,7 @@ class TestDTypeALU(unittest.TestCase):
def test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.int32, dtypes.float32)
# Metal and CUDACPU and HIP behave differently than numpy in CI for overflows
-skip_overflow = CI and (Device.DEFAULT == "HIP" or getenv("CUDACPU"))
+skip_overflow = CI and (Device.DEFAULT in {"RHIP", "HSA"} or getenv("CUDACPU"))
@given(strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations))
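
Commentary on the pattern above: when skip_overflow is set, the float strategies are clamped to [0, 10] so the midcast tests never generate values whose overflow behavior differs between the backend and numpy. A minimal standalone sketch of such a clamped strategy (plain hypothesis, no tinygrad needed):

from hypothesis import given, strategies as strat

bounded_f32 = strat.floats(width=32, min_value=0, max_value=10.0)  # same bounds as the branch above

@given(bounded_f32, bounded_f32)
def check_bounded(a, b):
  # with inputs clamped like this, an int32 midcast of a*b cannot overflow
  assert 0 <= a * b <= 100.0

check_bounded()  # hypothesis runs the property when the decorated function is called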


@@ -706,7 +706,7 @@ class TestLinearizerOpts(unittest.TestCase):
], apply_tc=True, atol=atol, rtol=rtol)
def test_padto_matmul(self):
-if Device.DEFAULT in ["CUDA", "HIP"]: self.skipTest("super slow on CUDA and HIP because of the big grid dims")
+if Device.DEFAULT in ["CUDA", "RHIP"]: self.skipTest("super slow on CUDA and RHIP because of the big grid dims")
N = 17 * 17
Tensor.manual_seed(289)
a = Tensor.rand(N, N)
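
For illustration (not from the diff), the shape this test exercises is cheap to reproduce on any backend that is not skipped; N=289 is deliberately not a multiple of common local sizes, which is presumably what forces the PADTO-style optimizations this test targets:

from tinygrad import Tensor

Tensor.manual_seed(289)
N = 17 * 17
a, b = Tensor.rand(N, N), Tensor.rand(N, N)
print((a @ b).numpy().shape)  # (289, 289)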


@@ -17,7 +17,6 @@ def _test_overflow(ast, opts):
lin.linearize()
bufs = bufs_from_lin(lin)
print(bufs)
-if bufs[0].device in {"HIP", "HSA"}: print([hex(x._buf.value) for x in bufs])
time_linearizer(lin, bufs)
# NOTE: if you want these to trigger, set launch bounds on HIP kernels


@@ -2,7 +2,7 @@ import pathlib, unittest
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
-from tinygrad.helpers import Timing, fetch, temp, getenv
+from tinygrad.helpers import Timing, fetch, temp
from test.helpers import is_dtype_supported
def compare_weights_both(url):
@@ -214,7 +214,7 @@ class TestDiskTensor(unittest.TestCase):
np.testing.assert_array_equal(t.numpy(), np.array([3] * 10))
-@unittest.skipIf(getenv("HIPCPU"), "no real HIP device exists in CI")
+@unittest.skipIf(Device.DEFAULT == "RHIP", "no real HIP device exists in CI")
def test_bf16_disk_write_read(self):
t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32)
t.to(f"disk:{temp('f32')}").realize()
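
Context for this file rather than for the RHIP change: the safetensors round trip these tests build on looks roughly like the sketch below (assumes a writable system temp dir; the file name is just an example):

from tinygrad import Tensor
from tinygrad.nn.state import safe_save, safe_load
from tinygrad.helpers import temp

fn = temp("example.safetensors")      # temp() is the helper kept in the import above
safe_save({"weight": Tensor([1.0, 2.0, 3.0])}, fn)
loaded = safe_load(fn)                # returns tensors backed by a disk: device
assert loaded["weight"].numpy().tolist() == [1.0, 2.0, 3.0]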


@@ -31,10 +31,6 @@ def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast)
assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
out, ast = si.outputs[0], si.ast[0]
-if ast.op in {LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY} and out.device.startswith("HIP") and si.inputs[0].device.startswith("HIP"):
-  from tinygrad.runtime.ops_hip import HIPSyncEvent, HIPWaitEvent
-  if ast.op is LoadOps.SYNC: return HIPSyncEvent(out)
-  if ast.op is LoadOps.WAIT: return HIPWaitEvent(out.device)
if ast.op in {LoadOps.SYNC, LoadOps.WAIT} and out.device.startswith("HSA") and si.inputs[0].device.startswith("HSA"):
# Our HSA runtime handles synchronization
if ast.op is LoadOps.SYNC: return None


@@ -3,10 +3,11 @@ import ctypes, functools, subprocess, io, atexit, collections, json
from typing import Tuple, TypeVar, List, Dict, Any
import tinygrad.runtime.autogen.hsa as hsa
from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t, getenv
-from tinygrad.device import Compiled, LRUAllocator, BufferOptions
+from tinygrad.device import Compiled, LRUAllocator, BufferOptions, Compiler
from tinygrad.codegen.kernel import LinearizerOptions
-from tinygrad.runtime.ops_hip import HIPCompiler
from tinygrad.runtime.driver.hsa import check, scan_agents, find_memory_pool, AQLQueue
+from tinygrad.renderer.cstyle import HIPRenderer
+from tinygrad.runtime.driver.hip_comgr import compile_hip
PROFILE = getenv("PROFILE", 0)
@@ -40,8 +41,13 @@ class HSAProfiler:
print(f"Saved HSA profile to {path}")
Profiler = HSAProfiler()
-class HSACompiler(HIPCompiler):
+class HSACompiler(Compiler):
  linearizer_opts = LinearizerOptions("HSA", has_tensor_cores=True, shared_max=65536)
+  def __init__(self, arch:str):
+    self.arch = arch
+    super().__init__(f"compile_hip_{self.arch}")
+  def render(self, name:str, uops) -> str: return HIPRenderer(name, uops)
+  def compile(self, src:str) -> bytes: return compile_hip(src, self.arch)
class HSAProgram:
  def __init__(self, device:HSADevice, name:str, lib:bytes):
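
Commentary, not part of the commit: HSACompiler now follows the generic Compiler pattern directly, where render() turns linearized uops into source, compile() turns source into a binary, and the string passed to Compiler.__init__ (compile_hip_<arch> here) keys the compile cache. A ROCm-free sketch of that pattern (EchoCompiler is made up for illustration and is not in the repo):

from tinygrad.device import Compiler  # same base class the hunk above imports

class EchoCompiler(Compiler):
  # illustrative stand-in mirroring HSACompiler's shape; it does no real compilation
  def __init__(self, arch:str):
    self.arch = arch
    super().__init__(f"compile_echo_{self.arch}")  # cache key, like compile_hip_<arch> above
  def render(self, name:str, uops) -> str:         # a real backend renders uops to HIP source here
    return f"// {name} for {self.arch}: {len(uops)} uops"
  def compile(self, src:str) -> bytes:             # a real backend calls compile_hip(src, arch) here
    return src.encode()

c = EchoCompiler("gfx1100")
print(c.compile(c.render("r_add", [])))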


@@ -0,0 +1,17 @@
import ctypes
from tinygrad.device import Compiled, MallocAllocator
from tinygrad.runtime.ops_hsa import HSACompiler
rhip = ctypes.CDLL("/usr/local/lib/libremu.so")
class RHIPProgram:
  def __init__(self, name:str, lib:bytes):
    self.name, self.lib = name, lib
  def __call__(self, *args, global_size, local_size, vals=(), wait=False):
    args = (*args, *vals)
    rhip.hipModuleLaunchKernel(self.lib, len(self.lib), *global_size, *local_size, 0, None, None,
                               len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args]))
class RHIPDevice(Compiled):
  def __init__(self, device:str=""):
    self.device = int(device.split(":")[1]) if ":" in device else 0
    super().__init__(device, MallocAllocator, HSACompiler("gfx1100"), RHIPProgram)
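
A hedged usage sketch for the new backend (assumptions: the remu emulator is installed at /usr/local/lib/libremu.so as hard-coded above, and the environment variable is set before tinygrad is imported, mirroring the RHIP=1 FORWARD_ONLY=1 CI job):

import os
os.environ["RHIP"] = "1"  # select the emulated backend before importing tinygrad, as CI does
from tinygrad import Tensor

# a small forward-only computation: the kernel is compiled by HSACompiler("gfx1100") and
# dispatched through libremu's hipModuleLaunchKernel entry point shown above
print((Tensor([1.0, 2.0, 3.0]) + 1).numpy())

Reusing HSACompiler means the emulator runs the same gfx1100 code objects a real RDNA3 device would; only the allocator (host memory) and the launch path are swapped out.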