qcom cache flush (#6367)

* qcom cache flush

* bench

* linter

* move
This commit is contained in:
nimlgen
2024-09-05 13:23:39 +03:00
committed by GitHub
parent 62f9f273f7
commit a1a15b54c9
2 changed files with 28 additions and 11 deletions

View File

@@ -4,7 +4,7 @@ import onnx
from onnx.helper import tensor_dtype_to_np_dtype
from extra.onnx import get_run_onnx
from tinygrad import Tensor, dtypes, TinyJit
from tinygrad.helpers import IMAGE, GlobalCounters, fetch, colored, getenv
from tinygrad.helpers import IMAGE, GlobalCounters, fetch, colored, getenv, trange
from tinygrad.tensor import _from_np_dtype
import numpy as np
@@ -47,9 +47,14 @@ if __name__ == "__main__":
if getenv("SAVE_OUTPUT"):
np.save(path, tinygrad_out)
print(f"saved output to {path}!")
elif getenv("FUZZ") and path.exists():
known_good_out = np.load(path)
for _ in trange(1000):
ret = next(iter(run_onnx_jit(new_inputs).values())).cast(dtypes.float32).numpy()
np.testing.assert_allclose(known_good_out, ret, atol=1e-2, rtol=1e-2)
print(colored("fuzz validated!", "green"))
elif path.exists():
known_good_out = np.load(path)
np.testing.assert_allclose(known_good_out, tinygrad_out, atol=1e-2, rtol=1e-2)
print(colored("outputs validated!", "green"))
else:

View File

@@ -51,11 +51,21 @@ class QCOMComputeQueue(HWComputeQueue):
def reg(self, reg: int, *vals: int): self.q += [pkt4_hdr(reg, len(vals)), *vals]
def _cache_flush(self, write_back=True, invalidate=False, sync=True, memsync=False):
# TODO: 7xx support.
if write_back: self.cmd(adreno.CP_EVENT_WRITE, adreno.CACHE_FLUSH_TS, *data64_le(QCOMDevice.dummy_addr), 0) # dirty cache write-back.
if invalidate: self.cmd(adreno.CP_EVENT_WRITE, adreno.CACHE_INVALIDATE) # invalidate cache lines (following reads from RAM).
if memsync: self.cmd(adreno.CP_WAIT_MEM_WRITES)
if sync: self.cmd(adreno.CP_WAIT_FOR_IDLE)
def _memory_barrier(self): self._cache_flush(write_back=True, invalidate=True, sync=True, memsync=True)
def _signal(self, signal, value=0, ts=False):
self.cmd(adreno.CP_WAIT_FOR_IDLE)
if QCOMDevice.gpu_id < 700:
self.cmd(adreno.CP_EVENT_WRITE, qreg.cp_event_write_0(event=adreno.CACHE_FLUSH_TS, timestamp=ts),
*data64_le(mv_address(signal._signal) + (0 if not ts else 8)), qreg.cp_event_write_3(value&0xFFFFFFFF))
self.cmd(adreno.CP_EVENT_WRITE, adreno.CACHE_INVALIDATE)
*data64_le(mv_address(signal._signal) + (0 if not ts else 8)), qreg.cp_event_write_3(value & 0xFFFFFFFF))
self._cache_flush(write_back=True, invalidate=False, sync=False, memsync=False)
else:
# TODO: support devices starting with 8 Gen 1. Also, 700th series have convenient CP_GLOBAL_TIMESTAMP and CP_LOCAL_TIMESTAMP
raise RuntimeError('CP_EVENT_WRITE7 is not supported')
@@ -67,8 +77,8 @@ class QCOMComputeQueue(HWComputeQueue):
qreg.cp_wait_reg_mem_3(ref=value&0xFFFFFFFF), qreg.cp_wait_reg_mem_4(mask=0xFFFFFFFF), qreg.cp_wait_reg_mem_5(delay_loop_cycles=32))
def _update_signal(self, cmd_idx, signal, value):
if signal is not None: self._patch(cmd_idx, offset=2, data=data64_le(mv_address(signal._signal)))
if value is not None: self._patch(cmd_idx, offset=4, data=[value & 0xFFFFFFFF])
if signal is not None: self._patch(cmd_idx, offset=3, data=data64_le(mv_address(signal._signal)))
if value is not None: self._patch(cmd_idx, offset=5, data=[value & 0xFFFFFFFF])
def _update_wait(self, cmd_idx, signal, value):
if signal is not None: self._patch(cmd_idx, offset=2, data=data64_le(mv_address(signal._signal)))
@@ -94,7 +104,6 @@ class QCOMComputeQueue(HWComputeQueue):
@hcq_command
def setup(self):
self.cmd(adreno.CP_WAIT_FOR_IDLE)
self.cmd(adreno.CP_SET_MARKER, qreg.a6xx_cp_set_marker_0(mode=adreno.RM6_COMPUTE))
self.reg(adreno.REG_A6XX_HLSQ_INVALIDATE_CMD, qreg.a6xx_hlsq_invalidate_cmd(cs_state=True, cs_ibo=True))
self.reg(adreno.REG_A6XX_HLSQ_INVALIDATE_CMD, qreg.a6xx_hlsq_invalidate_cmd())
@@ -104,12 +113,12 @@ class QCOMComputeQueue(HWComputeQueue):
self.reg(adreno.REG_A6XX_SP_PERFCTR_ENABLE, qreg.a6xx_sp_perfctr_enable(cs=True))
self.reg(adreno.REG_A6XX_SP_TP_MODE_CNTL, qreg.a6xx_sp_tp_mode_cntl(isammode=adreno.ISAMMODE_CL, unk3=2))
self.reg(adreno.REG_A6XX_TPL1_DBG_ECO_CNTL, qreg.a6xx_tpl1_dbg_eco_cntl())
self.cmd(adreno.CP_WAIT_FOR_IDLE)
def _exec(self, prg, args_state, global_size, local_size):
global_size_mp = cast(Tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size))) if local_size else global_size
self.cmd_idx_to_dims[len(self) - 1] = [global_size, local_size]
self.cmd(adreno.CP_WAIT_FOR_IDLE)
self.reg(adreno.REG_A6XX_HLSQ_CS_NDRANGE_0,
qreg.a6xx_hlsq_cs_ndrange_0(kerneldim=3, localsizex=local_size[0] - 1, localsizey=local_size[1] - 1, localsizez=local_size[2] - 1),
global_size_mp[0], 0, global_size_mp[1], 0, global_size_mp[2], 0, 0xccc0cf, 0xfc | qreg.a6xx_hlsq_cs_cntl_1(threadsize=adreno.THREAD64),
@@ -153,21 +162,22 @@ class QCOMComputeQueue(HWComputeQueue):
self.reg(adreno.REG_A6XX_SP_CS_CONFIG,
qreg.a6xx_sp_cs_config(enabled=True, nsamp=args_state.samplers_cnt, ntex=args_state.descriptors_cnt, nibo=args_state.ibos_cnt))
self.cmd(adreno.CP_RUN_OPENCL, 0)
self._cache_flush(write_back=True, invalidate=False, sync=False, memsync=False)
def _update_exec(self, cmd_idx, global_size, local_size):
if global_size is not None:
self._patch(cmd_idx, offset=11, data=[int(math.ceil(global_size[0])), int(math.ceil(global_size[1])), int(math.ceil(global_size[2]))])
self._patch(cmd_idx, offset=10, data=[int(math.ceil(global_size[0])), int(math.ceil(global_size[1])), int(math.ceil(global_size[2]))])
self.cmd_idx_to_dims[cmd_idx][0] = global_size
if local_size is not None:
payload = qreg.a6xx_hlsq_cs_ndrange_0(kerneldim=3, localsizex=local_size[0] - 1, localsizey=local_size[1] - 1, localsizez=local_size[2] - 1)
self._patch(cmd_idx, offset=2, data=[payload])
self._patch(cmd_idx, offset=1, data=[payload])
self.cmd_idx_to_dims[cmd_idx][1] = local_size
global_size_mp = self.cmd_idx_to_dims[cmd_idx][0]
if self.cmd_idx_to_dims[cmd_idx][1]:
global_size_mp = cast(Tuple[int,int,int], tuple(int(g*l) for g,l in zip(self.cmd_idx_to_dims[cmd_idx][0], self.cmd_idx_to_dims[cmd_idx][1])))
self._patch(cmd_idx, offset=3, data=[global_size_mp[0], 0, global_size_mp[1], 0, global_size_mp[2], 0])
self._patch(cmd_idx, offset=2, data=[global_size_mp[0], 0, global_size_mp[1], 0, global_size_mp[2], 0])
class QCOMArgsState(HCQArgsState):
def __init__(self, ptr:int, prg:QCOMProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()):
@@ -312,9 +322,11 @@ class QCOMDevice(HCQCompiled):
signals_page: Any = None
signals_pool: List[Any] = []
gpu_id: int = 0
dummy_addr: int = 0
def __init__(self, device:str=""):
self.fd = os.open('/dev/kgsl-3d0', os.O_RDWR)
QCOMDevice.dummy_addr = self._gpu_alloc(0x1000, map_to_cpu=False).va_addr
QCOMDevice.signals_page = self._gpu_alloc(16 * 65536, map_to_cpu=True, uncached=True)
QCOMDevice.signals_pool = [to_mv(self.signals_page.va_addr + off, 16).cast("Q") for off in range(0, self.signals_page.size, 16)]
info, self.ctx, self.cmd_buf, self.cmd_buf_ptr = self._info(), self._ctx_create(), self._gpu_alloc(0x1000000, map_to_cpu=True), 0