From fc30e4825d3315dfe32888427a0509d07dec149f Mon Sep 17 00:00:00 2001 From: Vyacheslav Pachkov Date: Tue, 3 Sep 2024 22:19:07 +0300 Subject: [PATCH] qcom refactor regs setters (#6347) * add Qreg * less loc for qreg --------- Co-authored-by: nimlgen <138685161+nimlgen@users.noreply.github.com> --- tinygrad/runtime/ops_qcom.py | 58 +++++++++++++++++------------------- 1 file changed, 28 insertions(+), 30 deletions(-) diff --git a/tinygrad/runtime/ops_qcom.py b/tinygrad/runtime/ops_qcom.py index 436287ae7e..23018a646b 100644 --- a/tinygrad/runtime/ops_qcom.py +++ b/tinygrad/runtime/ops_qcom.py @@ -9,6 +9,12 @@ from tinygrad.renderer.cstyle import QCOMRenderer from tinygrad.helpers import getenv, from_mv, mv_address, to_mv, round_up, data64_le, prod if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import +def _qreg_exec(reg, __val=0, **kwargs): + for k, v in kwargs.items(): + __val |= (getattr(adreno, reg[4:] + "_" + k.upper())) if isinstance(v, bool) else (v << getattr(adreno, reg[4:] + "_" + k.upper() + "__SHIFT")) + return __val +qreg: Any = type("QREG", (object,), {name[4:].lower(): functools.partial(_qreg_exec, name) for name in adreno.__dict__.keys() if name[:4] == 'REG_'}) + def next_power2(x): return 1 if x == 0 else 1 << (x - 1).bit_length() def prt(val: int): @@ -90,14 +96,14 @@ class QCOMComputeQueue(HWComputeQueue): def setup(self): self.cmd(adreno.CP_WAIT_FOR_IDLE) self.cmd(adreno.CP_SET_MARKER, adreno.RM6_COMPUTE) - self.reg(adreno.REG_A6XX_HLSQ_INVALIDATE_CMD, adreno.A6XX_HLSQ_INVALIDATE_CMD_CS_STATE | adreno.A6XX_HLSQ_INVALIDATE_CMD_CS_IBO) - self.reg(adreno.REG_A6XX_HLSQ_INVALIDATE_CMD, 0x0) - self.reg(adreno.REG_A6XX_SP_CS_TEX_COUNT, 0xff) # set to max - self.reg(adreno.REG_A6XX_SP_CS_IBO_COUNT, 0xff) # set to max - self.reg(adreno.REG_A6XX_SP_MODE_CONTROL, adreno.A6XX_SP_MODE_CONTROL_ISAMMODE(adreno.ISAMMODE_CL)) - self.reg(adreno.REG_A6XX_SP_PERFCTR_ENABLE, adreno.A6XX_SP_PERFCTR_ENABLE_CS) - self.reg(adreno.REG_A6XX_SP_TP_MODE_CNTL, adreno.ISAMMODE_CL | (1 << 3)) # ISAMMODE|UNK3 - self.reg(adreno.REG_A6XX_TPL1_DBG_ECO_CNTL, 0) + self.reg(adreno.REG_A6XX_HLSQ_INVALIDATE_CMD, qreg.a6xx_hlsq_invalidate_cmd(cs_state=True, cs_ibo=True)) + self.reg(adreno.REG_A6XX_HLSQ_INVALIDATE_CMD, qreg.a6xx_hlsq_invalidate_cmd()) + self.reg(adreno.REG_A6XX_SP_CS_TEX_COUNT, qreg.a6xx_sp_cs_tex_count(0xff)) + self.reg(adreno.REG_A6XX_SP_CS_IBO_COUNT, qreg.a6xx_sp_cs_ibo_count(0xff)) + self.reg(adreno.REG_A6XX_SP_MODE_CONTROL, qreg.a6xx_sp_mode_control(isammode=adreno.ISAMMODE_CL)) + self.reg(adreno.REG_A6XX_SP_PERFCTR_ENABLE, qreg.a6xx_sp_perfctr_enable(cs=True)) + self.reg(adreno.REG_A6XX_SP_TP_MODE_CNTL, qreg.a6xx_sp_tp_mode_cntl(isammode=adreno.ISAMMODE_CL, unk3=2)) + self.reg(adreno.REG_A6XX_TPL1_DBG_ECO_CNTL, qreg.a6xx_tpl1_dbg_eco_cntl()) def _exec(self, prg, args_state, global_size, local_size): global_size_mp = cast(Tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size))) if local_size else global_size @@ -105,17 +111,14 @@ class QCOMComputeQueue(HWComputeQueue): self.cmd(adreno.CP_WAIT_FOR_IDLE) self.reg(adreno.REG_A6XX_HLSQ_CS_NDRANGE_0, - adreno.A6XX_HLSQ_CS_NDRANGE_0_KERNELDIM(3) | adreno.A6XX_HLSQ_CS_NDRANGE_0_LOCALSIZEX(local_size[0] - 1) - | adreno.A6XX_HLSQ_CS_NDRANGE_0_LOCALSIZEY(local_size[1] - 1) | adreno.A6XX_HLSQ_CS_NDRANGE_0_LOCALSIZEZ(local_size[2] - 1), - global_size_mp[0], 0, global_size_mp[1], 0, global_size_mp[2], 0, 0xccc0cf, - 0xfc | adreno.A6XX_HLSQ_CS_CNTL_1_THREADSIZE(adreno.THREAD64), + qreg.a6xx_hlsq_cs_ndrange_0(kerneldim=3, localsizex=local_size[0] - 1, localsizey=local_size[1] - 1, localsizez=local_size[2] - 1), + global_size_mp[0], 0, global_size_mp[1], 0, global_size_mp[2], 0, 0xccc0cf, 0xfc | qreg.a6xx_hlsq_cs_cntl_1(threadsize=adreno.THREAD64), int(math.ceil(global_size[0])), int(math.ceil(global_size[1])), int(math.ceil(global_size[2]))) self.reg(adreno.REG_A6XX_SP_CS_CTRL_REG0, - adreno.A6XX_SP_CS_CTRL_REG0_THREADSIZE(adreno.THREAD64) | adreno.A6XX_SP_CS_CTRL_REG0_HALFREGFOOTPRINT(prg.hregs_count) - | adreno.A6XX_SP_CS_CTRL_REG0_FULLREGFOOTPRINT(prg.fregs_count) | adreno.A6XX_SP_CS_CTRL_REG0_BRANCHSTACK(prg.branch_stack // 2), - adreno.A6XX_SP_CS_UNKNOWN_A9B1_UNK5 | adreno.A6XX_SP_CS_UNKNOWN_A9B1_UNK6 | adreno.A6XX_SP_CS_UNKNOWN_A9B1_SHARED_SIZE(prg.shared_size), - 0, prg.prg_offset, *data64_le(prg.lib_gpu.va_addr), adreno.A6XX_SP_CS_PVT_MEM_PARAM_MEMSIZEPERITEM(prg.pvtmem_size_per_item), - *data64_le(prg.device._stack.va_addr), adreno.A6XX_SP_CS_PVT_MEM_SIZE_TOTALPVTMEMSIZE(prg.pvtmem_size_total)) + qreg.a6xx_sp_cs_ctrl_reg0(threadsize=adreno.THREAD64, halfregfootprint=prg.hregs, fullregfootprint=prg.fregs, branchstack=prg.brnchstck), + qreg.a6xx_sp_cs_unknown_a9b1(unk5=True, unk6=True, shared_size=prg.shared_size), 0, prg.prg_offset, *data64_le(prg.lib_gpu.va_addr), + qreg.a6xx_sp_cs_pvt_mem_param(memsizeperitem=prg.pvtmem_size_per_item), *data64_le(prg.device._stack.va_addr), + qreg.a6xx_sp_cs_pvt_mem_size(totalpvtmemsize=prg.pvtmem_size_total)) self.cmd(adreno.CP_LOAD_STATE6_FRAG, adreno.CP_LOAD_STATE6_0_STATE_TYPE(adreno.ST_CONSTANTS) | adreno.CP_LOAD_STATE6_0_STATE_SRC(adreno.SS6_INDIRECT) | adreno.CP_LOAD_STATE6_0_STATE_BLOCK(adreno.SB6_CS_SHADER) | adreno.CP_LOAD_STATE6_0_NUM_UNIT(prg.kernargs_alloc_size // 4), @@ -124,10 +127,9 @@ class QCOMComputeQueue(HWComputeQueue): | adreno.CP_LOAD_STATE6_0_STATE_BLOCK(adreno.SB6_CS_SHADER) | adreno.CP_LOAD_STATE6_0_NUM_UNIT(round_up(prg.image_size, 128) // 128), *data64_le(prg.lib_gpu.va_addr)) self.reg(adreno.REG_A6XX_HLSQ_CONTROL_2_REG, 0xfcfcfcfc, 0xfcfcfcfc, 0xfcfcfcfc, 0xfc, - adreno.A6XX_HLSQ_CS_CNTL_CONSTLEN(prg.kernargs_alloc_size // 4) | adreno.A6XX_HLSQ_CS_CNTL_ENABLED) - - self.reg(adreno.REG_A6XX_SP_CS_PVT_MEM_HW_STACK_OFFSET, prg.hw_stack_offset) - self.reg(adreno.REG_A6XX_SP_CS_INSTRLEN, prg.image_size // 4) + qreg.a6xx_hlsq_cs_cntl(constlen=prg.kernargs_alloc_size // 4, enabled=True)) + self.reg(adreno.REG_A6XX_SP_CS_PVT_MEM_HW_STACK_OFFSET, qreg.a6xx_sp_cs_pvt_mem_hw_stack_offset(prg.hw_stack_offset)) + self.reg(adreno.REG_A6XX_SP_CS_INSTRLEN, qreg.a6xx_sp_cs_instrlen(prg.image_size // 4)) if hasattr(args_state, 'samplers_ptr'): self.cmd(adreno.CP_LOAD_STATE6_FRAG, @@ -153,9 +155,7 @@ class QCOMComputeQueue(HWComputeQueue): self.reg(adreno.REG_A6XX_SP_CS_IBO, *data64_le(args_state.ibos_ptr.va_addr)) self.reg(adreno.REG_A6XX_SP_CS_CONFIG, - adreno.A6XX_SP_CS_CONFIG_ENABLED | adreno.A6XX_SP_CS_CONFIG_NSAMP(args_state.samplers_cnt) - | adreno.A6XX_SP_CS_CONFIG_NTEX(args_state.descriptors_cnt) | adreno.A6XX_SP_CS_CONFIG_NIBO(args_state.ibos_cnt)) - + qreg.a6xx_sp_cs_config(enabled=True, nsamp=args_state.samplers_cnt, ntex=args_state.descriptors_cnt, nibo=args_state.ibos_cnt)) self.cmd(adreno.CP_RUN_OPENCL, 0) def _update_exec(self, cmd_idx, global_size, local_size): @@ -164,9 +164,7 @@ class QCOMComputeQueue(HWComputeQueue): self.cmd_idx_to_dims[cmd_idx][0] = global_size if local_size is not None: - payload = (adreno.A6XX_HLSQ_CS_NDRANGE_0_KERNELDIM(3) | adreno.A6XX_HLSQ_CS_NDRANGE_0_LOCALSIZEX(local_size[0] - 1) - | adreno.A6XX_HLSQ_CS_NDRANGE_0_LOCALSIZEY(local_size[1] - 1) | adreno.A6XX_HLSQ_CS_NDRANGE_0_LOCALSIZEZ(local_size[2] - 1)) - + payload = qreg.a6xx_hlsq_cs_ndrange_0(kerneldim=3, localsizex=local_size[0] - 1, localsizey=local_size[1] - 1, localsizez=local_size[2] - 1) self._patch(cmd_idx, offset=2, data=[payload]) self.cmd_idx_to_dims[cmd_idx][1] = local_size @@ -232,7 +230,7 @@ class QCOMProgram(HCQProgram): self.pvtmem_size_total = self.pvtmem_size_per_item * 128 * 2 self.hw_stack_offset = round_up(next_power2(round_up(self.pvtmem, 512)) * 128 * 16, 0x1000) self.shared_size = max(1, (self.shmem - 1) // 1024) - self.max_threads = min(1024, ((384 * 32) // (max(1, (self.fregs_count + round_up(self.hregs_count, 2) // 2)) * 128)) * 128) + self.max_threads = min(1024, ((384 * 32) // (max(1, (self.fregs + round_up(self.hregs, 2) // 2)) * 128)) * 128) device._ensure_stack_size(self.hw_stack_offset * 4) super().__init__(QCOMArgsState, self.device, self.name, kernargs_alloc_size=1024) @@ -252,7 +250,7 @@ class QCOMProgram(HCQProgram): # Parse image descriptors image_desc_off = _read_lib(0x110) - self.prg_offset, self.branch_stack = _read_lib(image_desc_off+0xc4), _read_lib(image_desc_off+0x108) + self.prg_offset, self.brnchstck = _read_lib(image_desc_off+0xc4), _read_lib(image_desc_off+0x108) // 2 self.pvtmem, self.shmem = _read_lib(image_desc_off+0xc8), _read_lib(image_desc_off+0xd8) # Fill up constants and buffers info @@ -275,7 +273,7 @@ class QCOMProgram(HCQProgram): # Registers info reg_desc_off = _read_lib(0x34) - self.fregs_count, self.hregs_count = _read_lib(reg_desc_off + 0x14), _read_lib(reg_desc_off + 0x18) + self.fregs, self.hregs = _read_lib(reg_desc_off + 0x14), _read_lib(reg_desc_off + 0x18) def __del__(self): if hasattr(self, 'lib_gpu'): self.device.allocator.free(self.lib_gpu, self.lib_gpu.size, options=BufferOptions(cpu_access=True, nolru=True))