diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4dc14fed0e..d6cdf7c5fe 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -332,7 +332,7 @@ jobs: strategy: fail-fast: false matrix: - backend: [llvm, clang, gpu, cuda, hip, ptx] #, triton] + backend: [llvm, clang, gpu, cuda, hip, ptx, amd] #, triton] name: Tests on (${{ matrix.backend }}) runs-on: ubuntu-latest @@ -356,7 +356,7 @@ jobs: path: ~/.cache/tinygrad/downloads/ key: downloads-cache-${{ matrix.backend }}-${{ env.DOWNLOAD_CACHE_VERSION }} - name: Set env - run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'cuda' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\n' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nPTX=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas' || matrix.backend == 'hip' && 'RHIP=1\nFORWARD_ONLY=1' }}" >> $GITHUB_ENV + run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'cuda' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\n' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nPTX=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas' || matrix.backend == 'hip' && 'RHIP=1\nFORWARD_ONLY=1' || matrix.backend == 'amd' && 'AMD=1\nMOCKGPU=1\nFORWARD_ONLY=1' }}" >> $GITHUB_ENV - name: Install OpenCL if: matrix.backend == 'gpu' run: | @@ -399,7 +399,7 @@ jobs: cd ${{ github.workspace }}/gpuocelot/ocelot/build sudo ninja install -d explain - name: Install packages (hip) - if: matrix.backend == 'hip' + if: matrix.backend == 'hip' || matrix.backend == 'amd' run: | echo 'Acquire::http::Pipeline-Depth "5";' | sudo tee -a 
/etc/apt/apt.conf.d/99parallel wget https://repo.radeon.com/rocm/rocm.gpg.key -O - | gpg --dearmor | sudo tee /etc/apt/keyrings/rocm.gpg > /dev/null @@ -416,7 +416,7 @@ jobs: run: pip install -e '.[testing${{matrix.backend=='llvm'&&',llvm'||matrix.backend=='cuda'&&',cuda'||matrix.backend=='ptx'&&',cuda'||matrix.backend=='triton'&&',triton'||''}}]' --extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ - name: Check Device.DEFAULT and print some source run: | - python -c "from tinygrad import Device; assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU','RHIP'], Device.DEFAULT" + PYTHONPATH=${{ github.workspace }} python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU','RHIP','AMD'], Device.DEFAULT" DEBUG=5 PYTHONPATH=${{ github.workspace }} FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add - name: Verify OpenCL autogen if: matrix.backend == 'gpu' @@ -442,8 +442,8 @@ jobs: ./autogen_stubs.sh comgr diff /tmp/hsa.py.bak tinygrad/runtime/autogen/hsa.py diff /tmp/comgr.py.bak tinygrad/runtime/autogen/comgr.py - - name: Run pytest (not cuda or hip) - if: matrix.backend!='cuda' && matrix.backend!='ptx' && matrix.backend!='triton' && matrix.backend != 'hip' + - name: Run pytest (not cuda or hip/amd) + if: matrix.backend!='cuda' && matrix.backend!='ptx' && matrix.backend!='triton' && matrix.backend != 'hip' && matrix.backend != 'amd' run: python -m pytest -n=auto test/ --durations=20 - name: Run ONNX (only LLVM) if: matrix.backend == 'llvm' @@ -454,6 +454,9 @@ jobs: - name: Run pytest (hip) if: matrix.backend=='hip' run: python -m pytest -n=auto test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/test_randomness.py test/imported/test_indexing.py test/external/external_test_hip_compile.py --durations=20 + - name: Run pytest (amd) + if: matrix.backend=='amd' + run: python -m pytest 
-n=auto test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/test_randomness.py test/imported/test_indexing.py test/external/external_test_hcq.py --durations=20 - name: Compile EfficientNet to C and test it if: matrix.backend=='clang' run: | diff --git a/extra/mockgpu/amd/amddriver.py b/extra/mockgpu/amd/amddriver.py new file mode 100644 index 0000000000..edbc9443ff --- /dev/null +++ b/extra/mockgpu/amd/amddriver.py @@ -0,0 +1,142 @@ +import pathlib, re, ctypes, ctypes.util, mmap, collections, struct, functools, os, copy +import tinygrad.runtime.autogen.kfd as kfd +from typing import Optional, Any +from tinygrad.helpers import from_mv +from extra.mockgpu.driver import VirtDriver, VirtFileDesc, TextFileDesc, DirFileDesc, VirtFile +from extra.mockgpu.amd.amdgpu import AMDGPU, gpu_props + +libc = ctypes.CDLL(ctypes.util.find_library("c")) +libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_long] +libc.mmap.restype = ctypes.c_void_p + +def ioctls_from_header(): + hdrpy = (pathlib.Path(__file__).parent.parent.parent.parent / "tinygrad" / "runtime" / "autogen" / "kfd.py").read_text() + pattern = r'# (AMDKFD_IOC_[A-Z0-9_]+)\s=\s_(IOW?R?).*\(( 0x[0-9a-fA-F]+) ,\s+struct\s([A-Za-z0-9_]+)\s+\)' + matches = re.findall(pattern, hdrpy, re.MULTILINE) + return type("KFD_IOCTLS", (object, ), {name: int(nr, 0x10) for name, _, nr, _ in matches}), \ + {int(nr, 0x10): getattr(kfd, "struct_"+sname) for name, idir, nr, sname in matches} +kfd_ioctls, kfd_headers = ioctls_from_header() + +class KFDFileDesc(VirtFileDesc): + def __init__(self, fd, driver): + super().__init__(fd) + self.driver = driver + + def ioctl(self, fd, request, argp): return self.driver.kfd_ioctl(request, argp) + def mmap(self, start, sz, prot, flags, fd, offset): return offset + +class DRMFileDesc(VirtFileDesc): + def __init__(self, fd, driver, gpu): + super().__init__(fd) + self.driver, self.gpu = driver, gpu + + def mmap(self, start, sz, 
prot, flags, fd, offset): return libc.mmap(start, sz, prot, flags|mmap.MAP_ANONYMOUS, -1, 0) + +class AMDDriver(VirtDriver): + def __init__(self, gpus=6): + super().__init__() + + self.tracked_files += [VirtFile('/dev/kfd', functools.partial(KFDFileDesc, driver=self))] + \ + [VirtFile('/sys/devices/virtual/kfd/kfd/topology/nodes', functools.partial(DirFileDesc, child_names=[str(i) for i in range(gpus)]))] + + self.gpus = {} + self.next_fd = (1 << 30) + self.next_handle = 1 + self.next_event = 1 + + self.object_by_handle = {} + self.doorbells = {} + self.next_doorbell = collections.defaultdict(int) + + for i in range(gpus): self._prepare_gpu(i) + + def _alloc_fd(self): + my_fd = self.next_fd + self.next_fd = self.next_fd + 1 + return my_fd + + def _alloc_handle(self): + handle = self.next_handle + self.next_handle += 1 + return handle + + def _alloc_next_event_slot(self): + ev = self.next_event + self.next_event += 1 + return ev + + def _alloc_doorbell(self, gpu_id): + x = ctypes.addressof(from_mv(self.doorbells[gpu_id])) + self.next_doorbell[gpu_id] * 8 + self.next_doorbell[gpu_id] += 1 + return x + + def _prepare_gpu(self, gpu_id): + self.doorbells[gpu_id] = memoryview(bytearray(0x2000)) + self.gpus[gpu_id] = AMDGPU(gpu_id) + self.tracked_files += [ + VirtFile(f'/sys/devices/virtual/kfd/kfd/topology/nodes/{gpu_id}/gpu_id', functools.partial(TextFileDesc, text=f"{gpu_id}")), + VirtFile(f'/sys/devices/virtual/kfd/kfd/topology/nodes/{gpu_id}/properties', + functools.partial(TextFileDesc, text=gpu_props.format(drm_render_minor=gpu_id))), + VirtFile(f'/dev/dri/renderD{gpu_id}', functools.partial(DRMFileDesc, driver=self, gpu=self.gpus[gpu_id])), + ] + + def open(self, name, flags, mode, virtfile): return virtfile.fdcls(self._alloc_fd()) + + def kfd_ioctl(self, req, argp): + nr = req & 0xFF + struct = kfd_headers[nr].from_address(argp) + + if nr == kfd_ioctls.AMDKFD_IOC_ACQUIRE_VM: pass + elif nr == kfd_ioctls.AMDKFD_IOC_ALLOC_MEMORY_OF_GPU: + if struct.gpu_id not 
in self.gpus: return -1 + struct.handle = self._alloc_handle() + self.object_by_handle[struct.handle] = copy.deepcopy(struct) # save memory struct to know what mem it is + elif nr == kfd_ioctls.AMDKFD_IOC_FREE_MEMORY_OF_GPU: + self.object_by_handle.pop(struct.handle) + elif nr == kfd_ioctls.AMDKFD_IOC_MAP_MEMORY_TO_GPU: + dev_ids = (ctypes.c_int32 * struct.n_devices).from_address(struct.device_ids_array_ptr) + for i in range(struct.n_devices): + gpu = self.gpus[dev_ids[i]] + mem_obj = self.object_by_handle[struct.handle] + gpu.map_range(mem_obj.va_addr, mem_obj.size) + struct.n_success = i + 1 + elif nr == kfd_ioctls.AMDKFD_IOC_UNMAP_MEMORY_FROM_GPU: + dev_ids = (ctypes.c_int32 * struct.n_devices).from_address(struct.device_ids_array_ptr) + for i in range(struct.n_devices): + gpu = self.gpus[dev_ids[i]] + mem_obj = self.object_by_handle[struct.handle] + gpu.unmap_range(mem_obj.va_addr, mem_obj.size) + struct.n_success = i + 1 + elif nr == kfd_ioctls.AMDKFD_IOC_CREATE_EVENT: + struct.event_slot_index = self._alloc_next_event_slot() + struct.event_id = struct.event_slot_index + elif nr == kfd_ioctls.AMDKFD_IOC_CREATE_QUEUE: + gpu = self.gpus[struct.gpu_id] + if struct.queue_type == kfd.KFD_IOC_QUEUE_TYPE_SDMA: + gpu.add_sdma_queue(struct.ring_base_address, struct.ring_size, struct.read_pointer_address, struct.write_pointer_address) + elif struct.queue_type == kfd.KFD_IOC_QUEUE_TYPE_COMPUTE: + gpu.add_pm4_queue(struct.ring_base_address, struct.ring_size, struct.read_pointer_address, struct.write_pointer_address) + else: raise RuntimeError("Unsuported, queue") + + # Track writes to doorbell, calling callback + struct.doorbell_offset = self._alloc_doorbell(struct.gpu_id) + self.track_address(struct.doorbell_offset, struct.doorbell_offset + 8, lambda mv,off: None, lambda mv, off: self._emulate_execute()) + elif nr == kfd_ioctls.AMDKFD_IOC_WAIT_EVENTS: + pass + else: + name = "unknown" + for k,v in kfd_ioctls.__dict__.items(): + if nr == v: name = k + assert False, 
f"unknown kfd ioctl, {nr} {name}" + exit(1) + return 0 + + def _emulate_execute(self): + any_progress = True + while any_progress: + any_progress = False + for gpu in self.gpus.values(): + for q in gpu.queues: + if (prev_rptr:=q.rptr[0]) != q.wptr[0]: + q.execute() + any_progress |= (prev_rptr != q.rptr[0]) diff --git a/extra/mockgpu/amd/amdgpu.py b/extra/mockgpu/amd/amdgpu.py new file mode 100644 index 0000000000..eadb685497 --- /dev/null +++ b/extra/mockgpu/amd/amdgpu.py @@ -0,0 +1,260 @@ +import ctypes, time +from extra.mockgpu.gpu import VirtGPU +from tinygrad.helpers import to_mv, init_c_struct_t +import tinygrad.runtime.autogen.amd_gpu as amd_gpu + +SDMA_MAX_COPY_SIZE = 0x400000 + +BASE_ADDR = 0x00001260 +PACKET3_SET_SH_REG_START = 0x2c00 +SUB = PACKET3_SET_SH_REG_START - BASE_ADDR + +regCOMPUTE_PGM_LO = 0x1bac - SUB +regCOMPUTE_USER_DATA_0 = 0x1be0 - SUB +regCOMPUTE_START_X = 0x1ba4 - SUB + +CACHE_FLUSH_AND_INV_TS_EVENT = 0x14 + +WAIT_REG_MEM_FUNCTION_ALWAYS = 0 +WAIT_REG_MEM_FUNCTION_EQ = 3 # == +WAIT_REG_MEM_FUNCTION_GEQ = 5 # >= + +remu = ctypes.CDLL("/usr/local/lib/libremu.so") +remu.run_asm.restype = ctypes.c_uint32 +remu.run_asm.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_void_p] + +def create_sdma_packets(): + # TODO: clean up this, if we want to keep it + structs = {} + for name,pkt in [(name,s) for name,s in amd_gpu.__dict__.items() if name.startswith("struct_SDMA_PKT_") and name.endswith("_TAG")]: + names = set() + fields = [] + for pkt_fields in pkt._fields_: + if not pkt_fields[0].endswith("_UNION"): fields.append(pkt_fields) + else: + assert pkt_fields[1]._fields_[0][0] == '_0' + for union_fields in pkt_fields[1]._fields_[0][1]._fields_: + fname = union_fields[0] + if fname in names: fname = pkt_fields[0]+fname + names.add(fname) + # merge together 64-bit fields, otherwise just append them + if fname.endswith("_63_32") and 
fields[-1][0].endswith("_31_0"): fields[-1] = tuple([fname[:-6], ctypes.c_ulong, 64]) + else: fields.append(tuple([fname, *union_fields[1:]])) + new_name = name[16:-4].lower() + structs[new_name] = init_c_struct_t(tuple(fields)) + assert ctypes.sizeof(structs[new_name]) == ctypes.sizeof(pkt), f"{ctypes.sizeof(structs[new_name])} != {ctypes.sizeof(pkt)}" + return type("SDMA_PKTS", (object, ), structs) +sdma_pkts = create_sdma_packets() + +class AMDQueue(): + def __init__(self, base, size, rptr, wptr): + self.queue, self.size = to_mv(base, size).cast("I"), size + self.rptr = to_mv(rptr, 8).cast("Q") + self.wptr = to_mv(wptr, 8).cast("Q") + +class PM4Executor(AMDQueue): + def __init__(self, gpu, base, size, rptr, wptr): + self.gpu = gpu + super().__init__(base, size, rptr, wptr) + + def _next_dword(self): + x = self.queue[self.rptr[0] % (self.size // 4)] + self.rptr[0] += 1 + return x + + def execute(self): + while self.rptr[0] < self.wptr[0]: + cont = True + header = self._next_dword() + packet_type = header >> 30 + op = (header >> 8) & 0xFF + n = (header >> 16) & 0x3FFF + assert packet_type == 3, "Can parse only packet3" + if op == amd_gpu.PACKET3_SET_SH_REG: self._exec_set_sh_reg(n) + elif op == amd_gpu.PACKET3_ACQUIRE_MEM: self._exec_acquire_mem(n) + elif op == amd_gpu.PACKET3_RELEASE_MEM: self._exec_release_mem(n) + elif op == amd_gpu.PACKET3_WAIT_REG_MEM: cont = self._exec_wait_reg_mem(n) + elif op == amd_gpu.PACKET3_DISPATCH_DIRECT: self._exec_dispatch_direct(n) + else: raise RuntimeError(f"PM4: Unknown opcode: {op}") + if not cont: return + + def _exec_acquire_mem(self, n): + assert n == 6 + for _ in range(7): self._next_dword() # TODO: implement + + def _exec_release_mem(self, n): + assert n == 6 + mem_event_type = (self._next_dword() >> 0) & 0xff + selectors = self._next_dword() + mem_data_sel = (selectors >> 29) & 0b111 + int_sel = (selectors >> 24) & 0b11 + mem_dst_sel = (selectors >> 16) & 0b1 + addr_lo = self._next_dword() + addr_hi = self._next_dword() 
+ val_lo = self._next_dword() + val_hi = self._next_dword() + val = val_lo + (val_hi << 32) + ev = self._next_dword() + + ptr = to_mv(addr_lo + (addr_hi << 32), 8) + if mem_data_sel == 1 or mem_data_sel == 2: ptr.cast('Q')[0] = val + elif mem_data_sel == 3: + if mem_event_type == CACHE_FLUSH_AND_INV_TS_EVENT: ptr.cast('I')[0] = int(time.perf_counter()) + else: raise RuntimeError(f"Unknown {mem_data_sel=} {mem_event_type=}") + else: raise RuntimeError(f"Unknown {mem_data_sel=}") + + def _exec_wait_reg_mem(self, n): + assert n == 5 + info = self._next_dword() + addr_lo = self._next_dword() + addr_hi = self._next_dword() + val = self._next_dword() + mask = self._next_dword() + timeout = self._next_dword() + + mem_function = (info >> 0) & 0b111 + mem_space = (info >> 4) & 0b1 + mem_op = (info >> 6) & 0b1 + mem_engine = (info >> 8) & 0b1 + + if mem_space == 0: read_op = lambda: val + elif mem_space == 1: read_op = lambda: to_mv(addr_lo + (addr_hi << 32), 4).cast('I')[0] + + if mem_function == WAIT_REG_MEM_FUNCTION_GEQ: cmp = lambda x,y: x >= y + elif mem_function == WAIT_REG_MEM_FUNCTION_EQ: cmp = lambda x,y: x == y + else: raise RuntimeError(f"Do not support {mem_function=}") + + mval = read_op() + can_cont = cmp(mval, val) + if not can_cont: self.rptr[0] = self.rptr[0] - 7 # revert packet, need to wait again + return can_cont + + def _exec_set_sh_reg(self, n): + reg = self._next_dword() + for i in range(n): + self.gpu.regs[reg] = self._next_dword() + reg += 1 + + def _exec_dispatch_direct(self, n): + assert n == 3 + gl = [self._next_dword() for _ in range(3)] + flags = self._next_dword() + + prg_addr = (self.gpu.regs[regCOMPUTE_PGM_LO] + (self.gpu.regs[regCOMPUTE_PGM_LO + 1] << 32)) << 8 + args_addr = self.gpu.regs[regCOMPUTE_USER_DATA_0] + (self.gpu.regs[regCOMPUTE_USER_DATA_0 + 1] << 32) + lc = [self.gpu.regs[i] for i in range(regCOMPUTE_START_X+3, regCOMPUTE_START_X+6)] + + prg_sz = 0 + for st,sz in self.gpu.mapped_ranges: + if st <= prg_addr < st+sz: prg_sz = sz 
- (prg_addr - st) + + assert prg_sz > 0, "Invalid prg ptr (not found in mapped ranges)" + remu.run_asm(prg_addr, prg_sz, *gl, *lc, args_addr) + +class SDMAExecutor(AMDQueue): + def __init__(self, gpu, base, size, rptr, wptr): + self.gpu, self.base = gpu, base + super().__init__(base, size, rptr, wptr) + + def execute(self): + while self.rptr[0] < self.wptr[0]: + cont = True + header = self.queue[(self.rptr[0] // 4) % (self.size // 4)] + op = (header >> 0) & 0xff + if op == 0: self.rptr[0] += 4 + elif op == amd_gpu.SDMA_OP_FENCE: self._execute_fence() + elif op == amd_gpu.SDMA_OP_TRAP: self._execute_trap() + elif op == amd_gpu.SDMA_OP_POLL_REGMEM: cont = self._execute_poll_regmem() + elif op == amd_gpu.SDMA_OP_GCR: self._execute_gcr() + elif op == amd_gpu.SDMA_OP_COPY: self._execute_copy() + else: raise RuntimeError(f"Unknown SDMA op {op}") + if not cont: return + + def _execute_fence(self): + struct = sdma_pkts.fence.from_address(self.base + self.rptr[0] % self.size) + to_mv(struct.addr, 8).cast('Q')[0] = struct.data + self.rptr[0] += ctypes.sizeof(struct) + + def _execute_trap(self): + struct = sdma_pkts.trap.from_address(self.base + self.rptr[0] % self.size) + self.rptr[0] += ctypes.sizeof(struct) + + def _execute_poll_regmem(self): + struct = sdma_pkts.poll_regmem.from_address(self.base + self.rptr[0] % self.size) + + if struct.mem_poll == 0: read_op = lambda: struct.value + elif struct.mem_poll == 1: read_op = lambda: to_mv(struct.addr, 4).cast('I')[0] + + if struct.func == WAIT_REG_MEM_FUNCTION_GEQ: cmp = lambda x,y: x >= y + elif struct.func == WAIT_REG_MEM_FUNCTION_EQ: cmp = lambda x,y: x == y + elif struct.func == WAIT_REG_MEM_FUNCTION_ALWAYS: cmp = lambda x,y: True + else: raise RuntimeError(f"Do not support {struct.func=}") + + mval = read_op() & struct.mask + if not cmp(mval, struct.value): return False + + self.rptr[0] += ctypes.sizeof(struct) + return True + + def _execute_gcr(self): + struct = sdma_pkts.gcr.from_address(self.base + self.rptr[0] % 
self.size) + self.rptr[0] += ctypes.sizeof(struct) + + def _execute_copy(self): + struct = sdma_pkts.copy_linear.from_address(self.base + self.rptr[0] % self.size) + ctypes.memmove(struct.dst_addr, struct.src_addr, struct.count + 1) + self.rptr[0] += ctypes.sizeof(struct) + +class AMDGPU(VirtGPU): + def __init__(self, gpuid): + super().__init__(gpuid) + self.mapped_ranges = set() + self.queues = [] + + def map_range(self, vaddr, size): self.mapped_ranges.add((vaddr, size)) + def unmap_range(self, vaddr, size): self.mapped_ranges.remove((vaddr, size)) + def add_pm4_queue(self, base, size, rptr, wptr): + self.queues.append(PM4Executor(self, base, size, rptr, wptr)) + return len(self.queues) - 1 + def add_sdma_queue(self, base, size, rptr, wptr): + self.queues.append(SDMAExecutor(self, base, size, rptr, wptr)) + return len(self.queues) - 1 + +gpu_props = """cpu_cores_count 0 +simd_count 192 +mem_banks_count 1 +caches_count 206 +io_links_count 1 +p2p_links_count 5 +cpu_core_id_base 0 +simd_id_base 2147488032 +max_waves_per_simd 16 +lds_size_in_kb 64 +gds_size_in_kb 0 +num_gws 64 +wave_front_size 32 +array_count 12 +simd_arrays_per_engine 2 +cu_per_simd_array 8 +simd_per_cu 2 +max_slots_scratch_cu 32 +gfx_target_version 110000 +vendor_id 4098 +device_id 29772 +location_id 34304 +domain 0 +drm_render_minor {drm_render_minor} +hive_id 0 +num_sdma_engines 2 +num_sdma_xgmi_engines 0 +num_sdma_queues_per_engine 6 +num_cp_queues 8 +max_engine_clk_fcompute 2482 +local_mem_size 0 +fw_version 2140 +capability 671588992 +debug_prop 1495 +sdma_fw_version 20 +unique_id 11673270660693242239 +num_xcc 1 +max_engine_clk_ccompute 2400""" \ No newline at end of file diff --git a/extra/mockgpu/driver.py b/extra/mockgpu/driver.py new file mode 100644 index 0000000000..8c1d2aa245 --- /dev/null +++ b/extra/mockgpu/driver.py @@ -0,0 +1,96 @@ +import ctypes, struct, os, functools +from typing import Union +from dataclasses import dataclass +from tinygrad.helpers import round_up, to_mv + +class 
VirtFileDesc: + def __init__(self, fd): self.fd, self.off = fd, 0 + def read(self, fd, buf, sz): raise NotImplementedError() + def ioctl(self, fd, req, argp): raise NotImplementedError() + def mmap(self, st, sz, prot, flags, fd, off): raise NotImplementedError() + def write(self, fd, buf, sz): raise NotImplementedError() + def lseek(self, fd, off, whence): raise NotImplementedError() + def fstat(self, fd, buf): raise NotImplementedError() + def getdents(self, fd, buf, sz): return -1 + def close(self, fd): return 0 + +class TextFileDesc(VirtFileDesc): + def __init__(self, fd, text): + super().__init__(fd) + self.content = ctypes.create_string_buffer(text.encode()) + self.sz = len(self.content) - 1 + + def ioctl(self, fd, req, argp): return 0 + def write(self, fd, buf, sz): return -1 + def read(self, fd, buf, sz): + ctypes.memmove(buf, ctypes.addressof(self.content) + self.off, rdsz:=min(sz, self.sz - self.off)) + self.off += rdsz + return rdsz + def lseek(self, fd, off, whence): + if whence == os.SEEK_SET: self.off = off + elif whence == os.SEEK_CUR: self.off += off + elif whence == os.SEEK_END: self.off = self.sz + off + else: return -1 + return 0 + def fstat(self, fd, buf): + ctypes.memmove(buf, VirtFile.build_fstat(st_size=self.sz), 88) + return 0 + +class DirFileDesc(VirtFileDesc): + def __init__(self, fd, child_names): + super().__init__(fd) + child_names = ['.', '..'] + child_names + + tmp = b'' + for ino, name in enumerate(child_names): + tmp += VirtFile.build_dirent(ino + 1, 0, name) + self.content = ctypes.create_string_buffer(tmp) + self.sz = len(self.content) - 1 + + def ioctl(self, fd, req, argp): return 0 + def write(self, fd, buf, sz): return -1 + def read(self, fd, buf, sz): return -1 + def lseek(self, fd, off, whence): + if whence == os.SEEK_SET: self.off = off + elif whence == os.SEEK_CUR: self.off += off + elif whence == os.SEEK_END: self.off = self.sz + off + else: return -1 + return 0 + + def getdents(self, fd, buf, sz): + if self.sz == self.off: 
return 0 + if sz < self.sz: return -1 + ctypes.memmove(buf, ctypes.addressof(self.content) + self.off, self.sz) + self.off = self.sz + return self.sz + + def fstat(self, fd, buf): + ctypes.memmove(buf, VirtFile.build_fstat(st_mode=0o40755), 96) + return 0 + +@dataclass(frozen=True) +class VirtFile(): + path: str + fdcls: Union[VirtFileDesc, functools.partial[VirtFileDesc]] + + @staticmethod + def build_fstat(st_dev=0x20, st_ino=0x100000, st_mode=0o100777, st_nlink=1, st_uid=0, st_gid=0, st_rdev=0, st_size=0, + st_blksize=4096, st_blocks=0, st_atime=0, st_mtime=0, st_ctime=0): + assert (ssz:=struct.calcsize(fmt_string:='QQQIIIQQiQqqq')) == 96, f"{ssz} != 96" + return struct.pack(fmt_string, st_dev, st_ino, st_nlink, st_mode, st_uid, st_gid, + st_rdev, st_size, st_blksize, st_blocks, st_atime, st_mtime, st_ctime) + + @staticmethod + def build_dirent(d_ino, d_off, d_name, d_type=None): + # Start with packing inode number, offset, and record length + d_reclen = round_up(19 + len(d_name) + 1, 8) + packed_data = struct.pack('QQHc', d_ino, d_off, d_reclen, b'\x04') + d_name_bytes = d_name.encode() + return packed_data + d_name_bytes + b'\x00' + b'\x00' * (d_reclen - (19 + len(d_name) + 1)) + +class VirtDriver: + def __init__(self): + self.tracked_files = [] + self.tracked_addresses = [] + def track_address(self, staddr, enaddr, rcb, wcb): self.tracked_addresses.append((staddr, enaddr, rcb, wcb)) + def open(self, name, flags, mode, virtfile=None): raise NotImplementedError() diff --git a/extra/mockgpu/gpu.py b/extra/mockgpu/gpu.py new file mode 100644 index 0000000000..8987902d8e --- /dev/null +++ b/extra/mockgpu/gpu.py @@ -0,0 +1,6 @@ +class VirtGPU: + def __init__(self, gpuid): + self.gpuid = gpuid + self.regs = {} + def map_range(self, vaddr, size): raise NotImplementedError() + def unmap_range(self, vaddr, size): raise NotImplementedError() diff --git a/extra/mockgpu/mockgpu.py b/extra/mockgpu/mockgpu.py new file mode 100644 index 0000000000..e32b21186b --- /dev/null +++ 
b/extra/mockgpu/mockgpu.py @@ -0,0 +1,197 @@ +import ctypes, ctypes.util, struct, platform, pathlib, re, time, os, builtins, atexit +from extra.mockgpu.amd.amddriver import AMDDriver +from tinygrad.helpers import from_mv, to_mv +start = time.perf_counter() + +# *** ioctl lib *** +libc = ctypes.CDLL(ctypes.util.find_library("c")) +libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_long] +libc.mmap.restype = ctypes.c_void_p +libc.munmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t] +libc.munmap.restype = ctypes.c_int +libc.fdopendir.argtypes = [ctypes.c_int] +libc.fdopendir.restype = ctypes.c_void_p + +processor = platform.processor() +OPEN_SYSCALL = {"aarch64": None, "x86_64": 2}[processor] +CLOSE_SYSCALL = {"aarch64": 57, "x86_64": 3}[processor] +READ_SYSCALL = {"aarch64": 63, "x86_64": 0}[processor] +IOCTL_SYSCALL = {"aarch64": 29, "x86_64": 16}[processor] +MMAP_SYSCALL = {"aarch64": 222, "x86_64": 9}[processor] +LSEEK_SYSCALL = {"aarch64": 62, "x86_64": 8}[processor] +NEWFSTATAT_SYSCALL = {"aarch64": 79, "x86_64": 262}[processor] +GETDENTS64_SYSCALL = {"aarch64": 61, "x86_64": 217}[processor] + +def install_hook(c_function, python_function): + python_function_addr = ctypes.cast(ctypes.byref(python_function), ctypes.POINTER(ctypes.c_ulong)).contents.value + if processor == "x86_64": + # tramp = b"\x49\xB8" + struct.pack("Q", python_function_addr) + b"\x41\xFF\xE0" + # push r9 + # push r9 + # mov r9, 0x1122334455667788 + # mov [rsp+8], r9 + # pop r9 + # ret + tramp = b"\x41\x51\x41\x51\x49\xB9" + struct.pack("Q", python_function_addr) + b"\x4C\x89\x4C\x24\x08\x41\x59\xC3" + else: + raise Exception(f"processor {processor} not supported") + + original_bc = (ctypes.c_char * 64)() + + # get real ioctl address + ioctl_address = ctypes.cast(ctypes.byref(c_function), ctypes.POINTER(ctypes.c_ulong)) + + # hook ioctl + ret = libc.mprotect(ctypes.c_ulong((ioctl_address.contents.value//0x1000)*0x1000), 0x2000, 7) + 
assert ret == 0 + libc.memcpy(original_bc, ioctl_address.contents, len(tramp)) + libc.memcpy(ioctl_address.contents, ctypes.create_string_buffer(tramp), len(tramp)) + + # Restore correct functions to close libs after python exits + def __restore(): libc.memcpy(ioctl_address.contents, original_bc, len(tramp)) + atexit.register(__restore) + +drivers = [AMDDriver()] +tracked_fds = {} + +@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_ulong) +def _open(name, flags, mode): + for d in drivers: + pyname = name.decode() + for x in d.tracked_files: + if pyname == x.path: + virtfd = d.open(pyname, flags, mode, x) + tracked_fds[virtfd.fd] = virtfd + return virtfd.fd + + libc.syscall.argtypes = [ctypes.c_ulong, ctypes.c_char_p, ctypes.c_int, ctypes.c_ulong] + libc.syscall.restype = ctypes.c_int + return libc.syscall(OPEN_SYSCALL, name, flags, mode) + +@ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_char_p) +def _opendir(name): + fd = _open(name, os.O_RDONLY| os.O_DIRECTORY, 0) + if fd >= 0x80: + fake_dirfd = _open(".".encode(), os.O_RDONLY| os.O_DIRECTORY, 0) + st = libc.fdopendir(fake_dirfd) + to_mv(st, 8).cast('Q')[0] = fd + return st + else: return libc.fdopendir(fd) + +@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int) +def _close(fd): + if fd in tracked_fds: + tracked_fds[fd].close(fd) + tracked_fds.pop(fd) + return 0 + + libc.syscall.argtypes = [ctypes.c_ulong, ctypes.c_int] + libc.syscall.restype = ctypes.c_int + return libc.syscall(CLOSE_SYSCALL, fd) + +@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p) +def _closedir(st): return _close(to_mv(st, 8).cast('Q')[0]) + +@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_ulong, ctypes.c_void_p) +def _ioctl(fd, request, argp): + if fd in tracked_fds: return tracked_fds[fd].ioctl(fd, request, argp) + + libc.syscall.argtypes = [ctypes.c_ulong, ctypes.c_int, ctypes.c_ulong, ctypes.c_void_p] + libc.syscall.restype = ctypes.c_int + return libc.syscall(IOCTL_SYSCALL, ctypes.c_int(fd), ctypes.c_ulong(request), 
ctypes.c_void_p(argp)) + +@ctypes.CFUNCTYPE(ctypes.c_long, ctypes.c_int, ctypes.c_void_p, ctypes.c_size_t) +def _read(fd, buf, sz): + if fd in tracked_fds: return tracked_fds[fd].read(fd, buf, sz) + + libc.syscall.argtypes = [ctypes.c_ulong, ctypes.c_int, ctypes.c_void_p, ctypes.c_size_t] + libc.syscall.restype = ctypes.c_int + return libc.syscall(READ_SYSCALL, ctypes.c_int(fd), ctypes.c_void_p(buf), ctypes.c_size_t(sz)) + +@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_ulong, ctypes.c_int) +def _lseek64(fd, off, whence): + if fd in tracked_fds: return tracked_fds[fd].lseek(fd, off, whence) + + libc.syscall.argtypes = [ctypes.c_ulong, ctypes.c_int, ctypes.c_ulong, ctypes.c_int] + libc.syscall.restype = ctypes.c_int + return libc.syscall(LSEEK_SYSCALL, fd, off, whence) + +@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p) +def _stat64(name, buf): + for d in drivers: + pyname = name.decode() + for x in d.tracked_files: + if pyname == x.path: + virtfd = d.open(pyname, 0, 0, x) + return virtfd.fstat(virtfd.fd, buf) + + libc.syscall.argtypes = [ctypes.c_ulong, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p, ctypes.c_ulong] + libc.syscall.restype = ctypes.c_int + return libc.syscall(NEWFSTATAT_SYSCALL, -100, name, ctypes.c_void_p(buf), 0) + +@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_void_p) +def _fstat64(fd, buf): + if fd in tracked_fds: return tracked_fds[fd].fstat(fd, buf) + + empty_str = (ctypes.c_char*1)() + libc.syscall.argtypes = [ctypes.c_ulong, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p, ctypes.c_ulong] + libc.syscall.restype = ctypes.c_int + return libc.syscall(NEWFSTATAT_SYSCALL, ctypes.c_int(fd), empty_str, ctypes.c_void_p(buf), 0x1000) + +@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_void_p, ctypes.c_ulong) +def _getdents64(fd, buf, sz): + if fd in tracked_fds: return tracked_fds[fd].getdents(fd, buf, sz) + + libc.syscall.argtypes = [ctypes.c_ulong, ctypes.c_int, ctypes.c_void_p, ctypes.c_ulong] + 
libc.syscall.restype = ctypes.c_int + return libc.syscall(GETDENTS64_SYSCALL, fd, buf, sz) + +def _mmap(start, sz, prot, flags, fd, offset): + if fd in tracked_fds: return tracked_fds[fd].mmap(start, sz, prot, flags, fd, offset) + return libc.mmap(start, sz, prot, flags, fd, offset) + +def _munmap(buf, sz): + return libc.munmap(buf, sz) + +orignal_memoryview = builtins.memoryview +class TrackedMemoryView: + def __init__(self, data, rcb, wcb): + self.mv = orignal_memoryview(data) + self.rcb, self.wcb = rcb, wcb + + def __getitem__(self, index): + self.rcb(self.mv, index) + return self.mv[index] + + def __setitem__(self, index, value): + self.mv[index] = value + self.wcb(self.mv, index) + + def cast(self, new_type, **kwargs): + self.mv = self.mv.cast(new_type, **kwargs) + return self + + @property + def nbytes(self): return self.mv.nbytes + def __len__(self): return len(self.mv) + def __repr__(self): return repr(self.mv) + +def _memoryview(mem): + if isinstance(mem, int) or isinstance(mem, ctypes.Array): + addr = ctypes.addressof(mem) if isinstance(mem, ctypes.Array) else mem + for d in drivers: + for st,en,rcb,wcb in d.tracked_addresses: + if st <= addr < en: return TrackedMemoryView(mem, rcb, wcb) + return orignal_memoryview(mem) + +install_hook(libc.open, _open) +install_hook(libc.opendir, _opendir) +install_hook(libc.close, _close) +install_hook(libc.closedir, _closedir) +install_hook(libc.ioctl, _ioctl) +install_hook(libc.read, _read) +install_hook(libc.lseek64, _lseek64) +install_hook(libc.stat64, _stat64) +install_hook(libc.fstat64, _fstat64) +install_hook(libc.getdents64, _getdents64) +builtins.memoryview = _memoryview # type: ignore diff --git a/test/external/external_test_hcq.py b/test/external/external_test_hcq.py index 9677bcd9c2..3b2935c329 100644 --- a/test/external/external_test_hcq.py +++ b/test/external/external_test_hcq.py @@ -1,8 +1,9 @@ import unittest, ctypes, struct, time, array from tinygrad import Device, Tensor, dtypes -from tinygrad.helpers 
import to_mv +from tinygrad.helpers import to_mv, CI from tinygrad.device import Buffer, BufferOptions from tinygrad.engine.schedule import create_schedule +from tinygrad.engine.realize import get_runner def _time_queue(q, d): st = time.perf_counter() @@ -21,7 +22,7 @@ class TestHCQ(unittest.TestCase): TestHCQ.a = Tensor([0.,1.], device=Device.DEFAULT).realize() TestHCQ.b = self.a + 1 si = create_schedule([self.b.lazydata])[-1] - TestHCQ.runner = TestHCQ.d0.get_runner(*si.ast) + TestHCQ.runner = get_runner(TestHCQ.d0.dname, si.ast) TestHCQ.b.lazydata.buffer.allocate() # wow that's a lot of abstraction layers TestHCQ.addr = struct.pack("QQ", TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr) @@ -53,11 +54,11 @@ class TestHCQ(unittest.TestCase): temp_signal, temp_value = TestHCQ.d0._get_signal(value=0), 0 q = TestHCQ.compute_queue() for _ in range(1000): - q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.global_size, TestHCQ.runner.local_size) + q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) q.signal(temp_signal, temp_value + 1).wait(temp_signal, temp_value + 1) temp_value += 1 - q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.runner.global_size, TestHCQ.runner.local_size) + q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) q.signal(temp_signal, temp_value + 1).wait(temp_signal, temp_value + 1) temp_value += 1 @@ -70,10 +71,10 @@ class TestHCQ(unittest.TestCase): def test_run_1000_times(self): temp_signal = TestHCQ.d0._get_signal(value=0) q = TestHCQ.compute_queue() - q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.global_size, TestHCQ.runner.local_size) + q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) q.signal(temp_signal, 2).wait(temp_signal, 2) - 
q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.runner.global_size, - TestHCQ.runner.local_size) + q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.runner.p.global_size, + TestHCQ.runner.p.local_size) for _ in range(1000): TestHCQ.d0._set_signal(temp_signal, 1) q.submit(TestHCQ.d0) @@ -85,16 +86,17 @@ class TestHCQ(unittest.TestCase): def test_run_to_3(self): temp_signal = TestHCQ.d0._get_signal(value=0) q = TestHCQ.compute_queue() - q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.global_size, TestHCQ.runner.local_size) + q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) q.signal(temp_signal, 1).wait(temp_signal, 1) - q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.runner.global_size, TestHCQ.runner.local_size) + q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) q.signal(temp_signal, 2).wait(temp_signal, 2) - q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.global_size, TestHCQ.runner.local_size) + q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 3.0, f"got val {val}" + @unittest.skipIf(CI, "Can't handle async update on CPU") def test_wait_signal(self): temp_signal = TestHCQ.d0._get_signal(value=0) TestHCQ.compute_queue().wait(temp_signal, value=1).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) @@ -105,6 +107,7 @@ class TestHCQ(unittest.TestCase): TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, 
TestHCQ.d0.timeline_value, timeout=100) TestHCQ.d0.timeline_value += 1 + @unittest.skipIf(CI, "Can't handle async update on CPU") def test_wait_copy_signal(self): temp_signal = TestHCQ.d0._get_signal(value=0) TestHCQ.copy_queue().wait(temp_signal, value=1).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) @@ -117,7 +120,7 @@ class TestHCQ(unittest.TestCase): def test_run_normal(self): q = TestHCQ.compute_queue() - q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.global_size, TestHCQ.runner.local_size) + q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 @@ -147,7 +150,7 @@ class TestHCQ(unittest.TestCase): def test_run_signal(self): q = TestHCQ.compute_queue() - q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.global_size, TestHCQ.runner.local_size) + q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) q.submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) @@ -187,7 +190,7 @@ class TestHCQ(unittest.TestCase): et = _time_queue(q, TestHCQ.d0) gb_s = (SZ/1e9)/et print(f"same device copy: {et*1e3:.2f} ms, {gb_s:.2f} GB/s") - assert gb_s > 10 and gb_s < 1000 + assert (0.3 if CI else 10) <= gb_s <= 1000 def test_cross_device_copy_bandwidth(self): SZ = 2_000_000_000 @@ -199,12 +202,12 @@ class TestHCQ(unittest.TestCase): et = _time_queue(q, TestHCQ.d0) gb_s = (SZ/1e9)/et print(f"cross device copy: {et*1e3:.2f} ms, {gb_s:.2f} GB/s") - assert gb_s > 2 and gb_s < 50 + assert (0.3 if CI else 2) <= gb_s <= 50 def test_interleave_compute_and_copy(self): q = TestHCQ.compute_queue() qc = 
TestHCQ.copy_queue() - q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.global_size, TestHCQ.runner.local_size) # b = [1, 2] + q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) # b = [1, 2] q.signal(sig:=TestHCQ.d0._get_signal(value=0), value=1) qc.wait(sig, value=1) qc.copy(TestHCQ.a.lazydata.buffer._buf.va_addr, TestHCQ.b.lazydata.buffer._buf.va_addr, 8) @@ -240,7 +243,7 @@ class TestHCQ(unittest.TestCase): for _ in range(40): q = TestHCQ.compute_queue() q.wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) - q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.global_size, TestHCQ.runner.local_size) + q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 diff --git a/test/helpers.py b/test/helpers.py index 71ace0c766..f1f39f9ce3 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -26,7 +26,7 @@ def assert_jit_cache_len(fxn, expected_len): def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT): if dtype == dtypes.bfloat16: # NOTE: this requires bf16 buffer support - return device in {"RHIP", "HSA"} or (device == "CUDA" and not CI and not getenv("PTX")) + return device in {"RHIP", "HSA", "AMD"} or (device == "CUDA" and not CI and not getenv("PTX")) if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32] if device == "CUDA" and getenv("PTX") and dtype in (dtypes.int8, dtypes.uint8): return False # for CI GPU and OSX, cl_khr_fp16 isn't supported diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py index e87ba62e62..32f615dc9c 100644 --- a/test/test_dtype_alu.py +++ b/test/test_dtype_alu.py @@ -145,7 +145,7 @@ class TestDTypeALU(unittest.TestCase): def 
test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.int32, dtypes.float32) # Metal and CUDACPU and HIP behave differently than numpy in CI for overflows - skip_overflow = CI and (Device.DEFAULT in {"RHIP", "HSA"} or getenv("CUDACPU")) + skip_overflow = CI and (Device.DEFAULT in {"RHIP", "HSA", "AMD"} or getenv("CUDACPU")) @given(strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32, strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32, ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations)) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index b76967f0a3..ff1536aabb 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -292,7 +292,7 @@ class TestLinearizer(unittest.TestCase): # check correctness helper_tc_allclose(tc.dims[0]+pad, tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2) - @unittest.skipIf(Device.DEFAULT == "RHIP", "RHIP is really slow here") + @unittest.skipIf(CI and Device.DEFAULT in {"RHIP", "AMD"}, "RHIP/AMD CI is really slow here") def test_tensor_cores_multi_reduce(self): if not Device[Device.DEFAULT].renderer.has_tensor_cores: self.skipTest("device doesn't have tensor cores") @@ -852,7 +852,7 @@ class TestKernelOpts(unittest.TestCase): ], apply_tc=True, atol=atol, rtol=rtol) def test_padto_matmul(self): - if CI and Device.DEFAULT in ["CUDA", "RHIP"]: self.skipTest("super slow on CUDA and RHIP because of the big grid dims") + if CI and Device.DEFAULT in ["CUDA", "RHIP", "AMD"]: self.skipTest("super slow on CUDA and RHIP because of the big grid dims") N = 17 * 17 Tensor.manual_seed(289) a = Tensor.rand(N, N) diff --git a/test/test_randomness.py b/test/test_randomness.py index 999ca5451c..e66beb4637 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -104,7 +104,7 @@ class TestRandomness(unittest.TestCase): 
self.assertTrue(equal_distribution(Tensor.randn, torch.randn, lambda x: np.random.randn(*x))) @given(strat.sampled_from([dtypes.float, dtypes.float16, dtypes.bfloat16])) - @unittest.skipIf(Device.DEFAULT in ["HSA", "RHIP"], "bfloat16 local buffer broken in HSA") + @unittest.skipIf(Device.DEFAULT in ["HSA", "RHIP", "AMD"], "bfloat16 local buffer broken in HSA") def test_randn_finite(self, default_float): if not is_dtype_supported(default_float): return old_default_float = dtypes.default_float diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 6504cf984e..d70e0dd11b 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Tuple, List, Any, cast -import os, fcntl, ctypes, functools, re, pathlib, mmap, struct, errno, subprocess, time +import os, fcntl, ctypes, ctypes.util, functools, re, pathlib, mmap, struct, errno, subprocess, time from tinygrad.device import Compiled, Compiler, BufferOptions, LRUAllocator from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, DEBUG from tinygrad.renderer.cstyle import HIPRenderer @@ -11,12 +11,17 @@ import tinygrad.runtime.autogen.hsa as hsa import tinygrad.runtime.autogen.amd_gpu as amd_gpu if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 -libc = ctypes.CDLL("libc.so.6") +libc = ctypes.CDLL(ctypes.util.find_library("c")) libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_long] libc.mmap.restype = ctypes.c_void_p libc.munmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t] libc.munmap.restype = ctypes.c_int +if getenv("MOCKGPU"): + import extra.mockgpu.mockgpu # noqa: F401 + libc.mmap = extra.mockgpu.mockgpu._mmap # type: ignore + libc.munmap = extra.mockgpu.mockgpu._munmap # type: ignore + def is_usable_gpu(gpu_id): try: with gpu_id.open() as f: