From 1c16b6e0825784efffbf96a1badf24abf79f9173 Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Mon, 8 Dec 2025 11:02:08 -0800 Subject: [PATCH] Mesa: freedreno (#12746) * ir3 init * got a program * 1 + 1 works * use isa_disasm instead of shader_disasm * wip * matmul works * works on py3.14 * fix const loading * skip QCOM failing tests * cleanup * args actually work * add compile-only tests * fix typo and install tinymesa * IR3 NULL backend * (float32) images work * autogen fix * fix compile only test * typo * mypy happy * compile-only uses py3.14 * bump mesa * unify qcom disassembler * float16 works * disasm shows in viz * save a line * add real del * variable workgroup sizes * simplify diff * bump line count * properly set wgsz * regen mesa * no preamble * bump lines --- .github/actions/setup-tinygrad/action.yml | 2 +- .github/workflows/test.yml | 33 +- test/test_ops.py | 39 +- tinygrad/device.py | 2 +- tinygrad/helpers.py | 3 +- tinygrad/renderer/nir.py | 89 +- tinygrad/runtime/autogen/__init__.py | 21 +- tinygrad/runtime/autogen/mesa.py | 1793 +++++++++++++++++++++ tinygrad/runtime/ops_null.py | 6 +- tinygrad/runtime/ops_qcom.py | 99 +- tinygrad/runtime/support/compiler_mesa.py | 62 +- 11 files changed, 2077 insertions(+), 72 deletions(-) diff --git a/.github/actions/setup-tinygrad/action.yml b/.github/actions/setup-tinygrad/action.yml index e3a4c86814..934e3f811a 100644 --- a/.github/actions/setup-tinygrad/action.yml +++ b/.github/actions/setup-tinygrad/action.yml @@ -298,7 +298,7 @@ runs: - name: Install mesa (linux) if: inputs.mesa == 'true' && runner.os == 'Linux' shell: bash - run: sudo curl -fL https://github.com/sirhcm/tinymesa/releases/download/tinymesa-32dc66c/libtinymesa_cpu-mesa-25.2.4-linux-amd64.so -o /usr/lib/libtinymesa_cpu.so + run: sudo curl -fL https://github.com/sirhcm/tinymesa/releases/download/v1/libtinymesa_cpu-mesa-25.2.7-linux-amd64.so -o /usr/lib/libtinymesa_cpu.so - name: Install mesa (macOS) if: inputs.mesa == 'true' && runner.os == 'macOS' shell: bash diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b42e4aebd5..6b9748ce84 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -289,8 +289,8 @@ jobs: python extra/optimization/extract_dataset.py gzip -c /tmp/sops > extra/datasets/sops.gz #DEBUG=1 MIN_ASTS=1 python extra/optimization/get_action_space.py - - name: Repo line count < 19000 lines - run: MAX_LINE_COUNT=19000 python sz.py + - name: Repo line count < 19150 lines + run: MAX_LINE_COUNT=19150 python sz.py spec: strategy: @@ -972,3 +972,32 @@ jobs: run: | python -c "from tinygrad import Device; assert Device.DEFAULT == {'LLVM':'CPU'}.get(x:='${{ matrix.backend }}'.upper(), x), Device.DEFAULT" python -m pytest -n=auto test/test_tiny.py test/test_ops.py --durations=20 + +# ****** Compile-only Tests ****** + + compiletests: + strategy: + fail-fast: false + matrix: + backend: [ir3] + name: Compile-only (${{ matrix.backend }}) + runs-on: ubuntu-24.04 + timeout-minutes: 15 + steps: + - name: Checkout Code + uses: actions/checkout@v4 + - name: Setup Environment + uses: ./.github/actions/setup-tinygrad + with: + key: compile-${{ matrix.backend }} + deps: testing_minimal + mesa: ${{ matrix.backend == 'ir3' && 'true' }} + python-version: '3.14' + - name: Set env + shell: bash + run: printf "NULL=1\n${{ matrix.backend == 'ir3' && 'NULL_IR3=1' }}" >> $GITHUB_ENV + - name: Run test_ops + shell: bash + run: | + python -c "from tinygrad import Device; assert Device.DEFAULT == 'NULL'" + python -m pytest -n=auto test/test_ops.py --durations=20 diff --git a/test/test_ops.py b/test/test_ops.py index 995420e56b..3f864a7a7d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2,7 +2,7 @@ import time, math, unittest, functools, platform, warnings import numpy as np from typing import List, Callable import torch -from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, CPU_LLVM, CPU_LVP, AMD_LLVM +from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, CPU_LLVM, CPU_LVP, AMD_LLVM, EMULATE from tinygrad import Tensor, Device, dtypes from tinygrad.tensor import _to_np_dtype from tinygrad.device import is_dtype_supported @@ -16,6 +16,7 @@ if CI: FORWARD_ONLY = getenv("FORWARD_ONLY", 0) PRINT_TENSORS = getenv("PRINT_TENSORS", 0) +COMPILE_ONLY = Device.DEFAULT == "NULL" and not EMULATE def slow_test(test_func): return unittest.skipIf(getenv("SKIP_SLOW_TEST"), "Skipping slow test")(test_func) @@ -38,6 +39,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra tinygrad_fp = time.monotonic() - st def compare(s, tinygrad_output, torch_output, atol, rtol): + if COMPILE_ONLY: return if PRINT_TENSORS: print(s, tinygrad_output, torch_output) try: assert tinygrad_output.shape == torch_output.shape, f"shape mismatch: tinygrad={tinygrad_output.shape} | torch={torch_output.shape}" @@ -421,8 +423,9 @@ class TestOps(unittest.TestCase): def test_isinf(self): val = [float('-inf'), 0., float('inf'), float('nan'), 1.1] helper_test_op(None, torch.isinf, Tensor.isinf, vals=[val], forward_only=True) - np.testing.assert_equal(Tensor(val).isinf(detect_positive=True, detect_negative=False).numpy(), [False, False, True, False, False]) - np.testing.assert_equal(Tensor(val).isinf(detect_positive=False, detect_negative=True).numpy(), [True, False, False, False, False]) + if not COMPILE_ONLY: + np.testing.assert_equal(Tensor(val).isinf(detect_positive=True, detect_negative=False).numpy(), [False, False, True, False, False]) + np.testing.assert_equal(Tensor(val).isinf(detect_positive=False, detect_negative=True).numpy(), [True, False, False, False, False]) def test_isnan(self): helper_test_op(None, torch.isnan, Tensor.isnan, vals=[[float('-inf'), 0., float('inf'), float('nan'), 1.1]], forward_only=True) @@ -594,7 +597,7 @@ class TestOps(unittest.TestCase): helper_test_op(None, lambda x: x//2, forward_only=True, vals=[[3, 4, 5]]) helper_test_op(None, functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv, forward_only=True, vals=[[-4, 7, 5, 4, -7, 8], [2, -3, 8, -2, 3, 5]]) - if is_dtype_supported(dtypes.uint64): + if is_dtype_supported(dtypes.uint64) and not COMPILE_ONLY: x = Tensor(2**64 - 1, dtype=dtypes.uint64).idiv(1) np.testing.assert_equal(x.numpy(), 2**64 - 1) @@ -679,6 +682,7 @@ class TestOps(unittest.TestCase): # float to power of int helper_test_op(None, lambda x: 0.7**x, vals=[[-2,-1,0,1,2,3]], forward_only=True) + @unittest.skipIf(COMPILE_ONLY, "test requires runtime") def test_pow_const_direct(self): # x ** c def get_tiny_gradient(x, c): @@ -1088,8 +1092,9 @@ class TestOps(unittest.TestCase): # check if it returns the first index for multiple occurences helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[2, 2]]) helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[1, 2, 2]]) - np.testing.assert_equal(Tensor([2,2]).argmax().numpy(), 0) - np.testing.assert_equal(Tensor([1,2,2]).argmax().numpy(), 1) + if not COMPILE_ONLY: + np.testing.assert_equal(Tensor([2,2]).argmax().numpy(), 0) + np.testing.assert_equal(Tensor([1,2,2]).argmax().numpy(), 1) helper_test_op([(10,20)], lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True) helper_test_op([(10,20)], lambda x: x.argmax(0, False).type(torch.int32), lambda x: x.argmax(0, False), forward_only=True) helper_test_op([(10,20)], lambda x: x.argmax(1, False).type(torch.int32), lambda x: x.argmax(1, False), forward_only=True) @@ -1107,8 +1112,9 @@ class TestOps(unittest.TestCase): # check if it returns the first index for multiple occurences helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[2, 2]]) helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[3, 2, 2]]) - np.testing.assert_equal(Tensor([2,2]).argmin().numpy(), 0) - np.testing.assert_equal(Tensor([3,2,2]).argmin().numpy(), 1) + if not COMPILE_ONLY: + np.testing.assert_equal(Tensor([2,2]).argmin().numpy(), 0) + np.testing.assert_equal(Tensor([3,2,2]).argmin().numpy(), 1) helper_test_op([(10,20)], lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True) helper_test_op([(10,20)], lambda x: x.argmin(0, False).type(torch.int32), lambda x: x.argmin(0, False), forward_only=True) helper_test_op([(10,20)], lambda x: x.argmin(1, False).type(torch.int32), lambda x: x.argmin(1, False), forward_only=True) @@ -1156,12 +1162,13 @@ class TestOps(unittest.TestCase): lambda x: x.topk(4, dim, largest, sorted_).indices.type(torch.int32), lambda x: x.topk(4, dim, largest, sorted_)[1], forward_only=True) # repeated values - value, indices = Tensor([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0]).topk(3) - np.testing.assert_equal(value.numpy(), [1, 1, 1]) - np.testing.assert_equal(indices.numpy(), [0, 1, 3]) - value, indices = Tensor([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0]).topk(3, largest=False) - np.testing.assert_equal(value.numpy(), [0, 0, 0]) - np.testing.assert_equal(indices.numpy(), [2, 4, 6]) + if not COMPILE_ONLY: + value, indices = Tensor([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0]).topk(3) + np.testing.assert_equal(value.numpy(), [1, 1, 1]) + np.testing.assert_equal(indices.numpy(), [0, 1, 3]) + value, indices = Tensor([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0]).topk(3, largest=False) + np.testing.assert_equal(value.numpy(), [0, 0, 0]) + np.testing.assert_equal(indices.numpy(), [2, 4, 6]) self.helper_test_exception([(4)], lambda x: x.topk(5), expected=(RuntimeError, ValueError)) @slow_test @@ -1313,6 +1320,7 @@ class TestOps(unittest.TestCase): helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)]) @unittest.skipIf(CI and Device.DEFAULT in ["NV", "CL", "CUDA"] or (Device.DEFAULT == "CPU" and CPU_LLVM) or IMAGE or (Device.DEFAULT == "WEBGPU" and platform.system() == "Windows"), "not supported on these in CI/IMAGE") + @unittest.skipIf(Device.DEFAULT == "QCOM", "not precise enough") def test_gemm_fp16(self): helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()), atol=5e-3, rtol=5e-3, grad_atol=5e-3, grad_rtol=5e-3) def test_gemm(self): @@ -1723,6 +1731,7 @@ class TestOps(unittest.TestCase): helper_test_op([(7,5,10)], lambda x: x[1:5:2, 3, ::4]) helper_test_op([(7,5,10)], lambda x: x[1:5:2, None, None, 3, None, ::4]) + @unittest.skipIf(COMPILE_ONLY, "test requires runtime") def test_slice_negative_strides(self): # Torch doesn't support slicing with negative steps a = np.random.randn(10, 10, 10).astype(np.float32) @@ -2752,6 +2761,7 @@ class TestOps(unittest.TestCase): n = Tensor([1, float("nan")]).max().numpy() assert math.isnan(n.item()), f"{n.item()} is not nan" + @unittest.skipIf(COMPILE_ONLY, "test requires runtime") def test_inf_where(self): x = Tensor.full((3, 3), float("inf")) n = (x < 0).where(x, 1).numpy() @@ -3168,6 +3178,7 @@ class TestOps(unittest.TestCase): @unittest.skipIf((getenv("MOCKGPU") or Device.DEFAULT == "PYTHON"), "very slow on MOCKGPU because reduce does not fold") @unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu runtime issue") + @unittest.skipIf(Device.DEFAULT == "QCOM", "QCOM fails with: Resource deadlock avoided") def test_masked_select(self): helper_test_op([(32, 10)], lambda x: x.masked_select(x>0.5), lambda x: x.masked_select(x>0.5), forward_only=True) helper_test_op([(32, 10)], lambda x: x.masked_select(torch.tensor(True)), lambda x: x.masked_select(Tensor(True)), forward_only=True) diff --git a/tinygrad/device.py b/tinygrad/device.py index 27f1947d7f..455f6e8b3f 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -366,7 +366,7 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool: if device in ["CUDA", "NV"]: return not CI if device == "CPU" and CPU_LLVM: return OSX if device == "PYTHON": return sys.version_info >= (3, 12) - if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "CL") + if dtype == dtypes.float64: return device not in {"METAL", "QCOM"} and not (OSX and device == "CL") and not getenv("NULL_IR3") return True if PROFILE: diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 04fade6908..a27156de42 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -186,8 +186,9 @@ EMULATE = ContextVar("EMULATE", "") CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else (os.cpu_count() or 1))) # Compilers CPU_LLVM, CPU_LVP, AMD_LLVM = ContextVar("CPU_LLVM", 0), ContextVar("CPU_LVP", 0), ContextVar("AMD_LLVM", 0) -NV_PTX, CUDA_PTX, NV_NAK = ContextVar("NV_PTX", 0), ContextVar("CUDA_PTX", 0), ContextVar("NV_NAK", 0) +NV_PTX, CUDA_PTX, NV_NAK, QCOM_IR3 = ContextVar("NV_PTX", 0), ContextVar("CUDA_PTX", 0), ContextVar("NV_NAK", 0), ContextVar("QCOM_IR3", 0) AMD_CC, CPU_CC, NV_CC, CUDA_CC = ContextVar("AMD_CC", ""), ContextVar("CPU_CC", ""), ContextVar("NV_CC", ""), ContextVar("CUDA_CC", "") +QCOM_CC = ContextVar("QCOM_CC", "") # VIZ implies PROFILE, but you can run PROFILE without VIZ VIZ = ContextVar("VIZ", 0) PROFILE = ContextVar("PROFILE", VIZ.value) diff --git a/tinygrad/renderer/nir.py b/tinygrad/renderer/nir.py index 22cba1b62d..4e0c8e859a 100644 --- a/tinygrad/renderer/nir.py +++ b/tinygrad/renderer/nir.py @@ -1,11 +1,11 @@ from typing import Callable, cast, Any -from tinygrad.dtype import AddrSpace, DType, PtrDType, dtypes +from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes from tinygrad.helpers import DEBUG, OSX, unwrap, charptr from tinygrad.renderer import Renderer from tinygrad.renderer.cstyle import CUDARenderer from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str from tinygrad.runtime.autogen import mesa -import base64, ctypes, ctypes.util, struct, functools, inspect, contextlib +import base64, ctypes, ctypes.util, struct, functools, inspect, contextlib, itertools def g(s:str): return getattr(mesa, s) def nsrc(d:mesa.nir_def) -> mesa.nir_src: return mesa.nir_src(ssa=ctypes.pointer(d)) @@ -49,7 +49,7 @@ def nir_instr(nc=1, bs=lambda: None, intrins=None, srcs=None, has_def=True, df=N if has_def: mesa.nir_def_init(instr.contents.instr, getattr(instr.contents, "def"), go(nc), go(bs)) for k, v in go(intrins or {}).items(): idx = mesa.nir_intrinsic_infos[instr.contents.intrinsic.value].index_map[g(f"NIR_INTRINSIC_{k}")] - assert idx > 0 + assert idx > 0, "invalid intrinsic. mesa version mismatch?" instr.contents.const_index[idx - 1] = go(v) for i, src in enumerate(go(srcs or [])): ctypes.cast(instr.contents.src, ctypes.POINTER(mesa.nir_src))[i] = go(src) for k,v in {k:vcomp for k,v in contents.items() if (vcomp:=go(v)) is not None}.items(): setattr(instr.contents, k, go(v)) @@ -67,11 +67,16 @@ def nchannel(b:mesa.nir_builder, src:mesa.nir_def, c:int): ctypes.cast(mov.contents.src, ctypes.POINTER(mesa.nir_alu_src))[0] = alu_src return mov +def nimm_set(imm:mesa.nir_def, x, dtype:DType): + instr = ctypes.cast(imm.parent_instr, ctypes.POINTER(mesa.nir_load_const_instr)) + struct.pack_into(unwrap(dtype.fmt), (ctypes.c_ubyte * dtype.itemsize).from_address(ctypes.addressof(instr.contents.value)), 0, x) + @nir_instr(nc=1, bs=lambda dtype: 1 if dtype == dtypes.bool else dtype.itemsize * 8) def nimm(b:mesa.nir_builder, x, dtype:DType) -> mesa.nir_def: - instr = mesa.nir_load_const_instr_create(b.shader, 1, 1 if dtype == dtypes.bool else dtype.itemsize * 8) - struct.pack_into(unwrap(dtype.fmt), (ctypes.c_ubyte * dtype.itemsize).from_address(ctypes.addressof(instr.contents.value)), 0, x) + nimm_set(getattr((instr:=mesa.nir_load_const_instr_create(b.shader, 1, 1 if dtype==dtypes.bool else dtype.itemsize * 8)).contents, "def"), x, dtype) return instr +@nir_instr(nc=1, bs=lambda dtype: 1 if dtype == dtypes.bool else dtype.itemsize * 8) +def nundef(b, dtype): return mesa.nir_undef_instr_create(b.shader, 1, 1 if dtype == dtypes.bool else dtype.itemsize * 8) deref_var = nir_instr(nc=1, bs=32, modes=lambda var:var.data.mode, type=lambda var:var.type, var=lambda var:ctypes.pointer(var))( # pylint: disable=W0108 lambda b, var: mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_var)) @@ -87,6 +92,8 @@ nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.itemsize*8/ ngid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_id)) nlid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_local_invocation_id)) +ngsz = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_size)) +def nid(b): return nalu(b, "iadd", nalu(b, "imul", ngid(b), ngsz(b)), nlid(b)) nbarrier = nir_instr(has_def=False, intrins={"EXECUTION_SCOPE":mesa.SCOPE_WORKGROUP})( lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_barrier)) @@ -123,16 +130,16 @@ class NIRRenderer(Renderer): (UPat(Ops.STORE, src=(UPat(), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True), lambda x: x.replace(src=x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])), # load/store use pointer arithmetic, and the cast does nothing - (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True, name="x"), - lambda x,buf,off: x.replace(src=(buf,off.cast(dtypes.long))+x.src[2:]) if buf.dtype.addrspace != AddrSpace.REG and off.op != Ops.CAST else None), + (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True, name="x"), lambda x,buf,off: x.replace( + src=(buf,off.cast(dtypes.long))+x.src[2:]) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.VECTORIZE) else None), (UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) or x.src[0].dtype == dtypes.void else None), ]) def_rewrite = PatternMatcher([ (UPat(Ops.CONST, name="x"), lambda ctx,x: nimm(ctx.b, x.arg, x.dtype)), - (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx,x: ctx.param(ctx.b, x.dtype, 8)), - (UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x.dtype, 4)), - (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, ngid(ctx.b) if x.arg[0] == 'g' else nlid(ctx.b), int(x.arg[-1]))), + (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8)), + (UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)), + (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))), (UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val")), allow_any_len=True, name="x"), lambda ctx,x,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)), (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"), UPat.var("gate"))), UPat.var("alt")), allow_any_len=True, name="x"), @@ -158,9 +165,11 @@ class NIRRenderer(Renderer): @property def nir_options(self): raise NotImplementedError("needs nir_options") - def param(self, b:mesa.nir_builder, dtype:DType, sz:int) -> mesa.nir_def: raise NotImplementedError("needs param") + def param(self, b:mesa.nir_builder, x, sz:int) -> mesa.nir_def: raise NotImplementedError("needs param") def prerender(self, uops:list[UOp]): self.b = mesa.nir_builder_init_simple_shader(mesa.MESA_SHADER_COMPUTE, mesa.nir_shader_compiler_options.from_buffer_copy(self.nir_options), None) + self.b.shader.contents.info.workgroup_size_variable = any([u.op == Ops.SPECIAL and u.arg[0] == 'i' for u in uops]) + def postrender(self, uops:list[UOp]): pass def render(self, uops:list[UOp]): self.prerender(uops) @@ -193,6 +202,7 @@ class NIRRenderer(Renderer): else: if (d:=self.def_rewrite.rewrite(u, ctx=self)) is None: raise RuntimeError(f"failed to render {u.op} srcs {[x.dtype for x in u.src]}") self.r[u] = cast(mesa.nir_def, d) + self.postrender(uops) mesa.nir_validate_shader(self.b.shader, b"after render") if DEBUG >= 4: mesa.nir_print_shader(self.b.shader, ctypes.POINTER(mesa.struct__IO_FILE).in_dll(ctypes.CDLL(ctypes.util.find_library('c')), @@ -206,22 +216,23 @@ class NIRRenderer(Renderer): return ret -class NAKRenderer(NIRRenderer): - device = "NV" +class NIRRendererWithOpts(NIRRenderer): def __init__(self, dev=None, nir_options=None): self.dev, self._nir_options = dev, nir_options super().__init__() - def __reduce__(self): return NAKRenderer, (None, self.nir_options,) + def __reduce__(self): return self.__class__, (None, self.nir_options) @property def nir_options(self): if self._nir_options is None: self._nir_options = self.dev.compiler.nir_options return self._nir_options +class NAKRenderer(NIRRendererWithOpts): + device = "NV" param = nir_instr(nc=1, num_components=1, bs=lambda sz:sz*8, also=lambda self,sz: setattr(self, "param_idx", self.param_idx + sz), intrins={"ALIGN_MUL":lambda sz:sz}, srcs=lambda self,b: [nsrc(nimm(b, 0, dtypes.int)), nsrc(nimm(b, self.param_idx, dtypes.int))])( - lambda self, b, dtype, sz: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_ldc_nv)) + lambda self, b, x, sz: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_ldc_nv)) class LVPRenderer(NIRRenderer): device = "CPU" @@ -232,9 +243,55 @@ class LVPRenderer(NIRRenderer): param = nir_instr(nc=1, bs=lambda sz: sz * 8, num_components=1, intrins={"ALIGN_MUL":lambda sz: sz, "RANGE":lambda self: self.param_sz}, srcs=lambda b, self: [nsrc(nimm(b, 0, dtypes.int)), nsrc(nimm(b, self.param_idx, dtypes.int))], also=lambda self, sz: - setattr(self, "param_idx", self.param_idx+sz))(lambda self, b, dtype, sz: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_ubo)) + setattr(self, "param_idx", self.param_idx+sz))(lambda self,b,x,sz: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_ubo)) def prerender(self, uops:list[UOp]): super().prerender(uops) self.param_sz = sum([8 if u.op == Ops.DEFINE_GLOBAL else u.dtype.itemsize for u in uops if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR)]) +# FIXME: this should be a rewrite rule +def tovec(b, coord): return nalu(b, "vec4", nchannel(b, coord, 0), nchannel(b, coord, 1), nundef(b, dtypes.int), nundef(b, dtypes.int)) +def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32 +nstore_img = nir_instr(has_def=False, df=lambda img:img, num_components=lambda val:val.num_components, + intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'SRC_TYPE':nfloat(dtype)}, + srcs=lambda b,img,coord,val:[nsrc(x) for x in [img, tovec(b, coord), nundef(b, dtypes.int), val, nimm(b, 0, dtypes.int)]])( + lambda b,img,coord,val,dtype:mesa.nir_intrinsic_instr_create(b.shader,g("nir_intrinsic_image_store"))) + +_nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'DEST_TYPE':nfloat(dtype)}, + nc=4, bs=32, num_components=4, srcs=lambda b,img,coord:[nsrc(x) for x in [img, tovec(b, coord), nundef(b, dtypes.int), nimm(b, 0, dtypes.int)]])( + lambda b,img,coord,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load"))) + +class IR3Renderer(NIRRendererWithOpts): + device = "QCOM" + + def nload_img(ctx,img,coord): + ctx.texs.add(img) + return _nload_img(ctx.b, ctx.r[img], ctx.r[coord], img.dtype) + + def_rewrite = PatternMatcher([ + (UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2)), allow_any_len=True), UPat.var("val")), + allow_any_len=True), lambda ctx,img,coord,val: nstore_img(ctx.b, ctx.r[img], ctx.r[coord], ctx.r[val], val.dtype)), + (UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("alt"))), + lambda ctx,img,coord,alt,gate: if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, coord), lambda: ctx.r[alt])), + (UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))),)), nload_img), + ]) + NIRRenderer.def_rewrite + + _param = LVPRenderer.param + def _param_img(self, x): + self.img_idx += 1 + return nimm(self.b, self.img_idx - 1, dtypes.int) + + def param(self, b, x, sz): return self._param_img(x) if isinstance(x.dtype, ImageDType) else self._param(b, x, sz) + + def prerender(self, uops:list[UOp]): + super().prerender(uops) + self.texs:set[UOp] = set() + self.uops, self.ibo_idx, self.img_idx = uops, 0, 0 + self.param_sz = sum([8 if u.op == Ops.DEFINE_GLOBAL else u.dtype.itemsize for u in uops if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR)]) + + def postrender(self, uops:list[UOp]): + bufs, texs, imgs = [u for u in uops if u.op == Ops.DEFINE_GLOBAL], itertools.count().__next__, itertools.count().__next__ + for b in filter(lambda b: isinstance(b.dtype, ImageDType), bufs): nimm_set(self.r[b], texs() if b in self.texs else imgs(), dtypes.int) + + self.b.shader.contents.info.num_ubos = len([u for u in bufs if not isinstance(u.dtype, ImageDType)]) + self.b.shader.contents.info.num_images = texs() + imgs() diff --git a/tinygrad/runtime/autogen/__init__.py b/tinygrad/runtime/autogen/__init__.py index 3eadb25987..c0db6f2454 100644 --- a/tinygrad/runtime/autogen/__init__.py +++ b/tinygrad/runtime/autogen/__init__.py @@ -113,17 +113,20 @@ def __getattr__(nm): *[f"{{}}/src/nouveau/{s}.h" for s in ["headers/nv_device_info", "compiler/nak"]], *[f"{{}}/src/gallium/auxiliary/gallivm/lp_bld{s}.h" for s in ["", "_passmgr", "_misc", "_type", "_init", "_nir", "_struct", "_jit_types", "_flow", "_const"]], - "{}/src/compiler/glsl_types.h", "{}/src/util/blob.h", "{}/src/util/ralloc.h", "{}/gen/builtin_types.h", "{}/gen/a6xx.xml.h", - "{}/gen/adreno_pm4.xml.h", "{}/gen/a6xx_enums.xml.h", "{}/gen/a6xx_descriptors.xml.h"], args=lambda:[ + *[f"{{}}/src/freedreno/{s}.h" for s in ["common/freedreno_dev_info", "ir3/ir3_compiler", "ir3/ir3_shader", "ir3/ir3_nir"]], + "{}/src/compiler/glsl_types.h", "{}/src/util/blob.h", "{}/src/util/ralloc.h", "{}/gen/ir3-isa.h", "{}/gen/builtin_types.h", + "{}/gen/a6xx.xml.h", "{}/gen/adreno_pm4.xml.h", "{}/gen/a6xx_enums.xml.h", "{}/gen/a6xx_descriptors.xml.h"], args=lambda:[ "-DHAVE_ENDIAN_H", "-DHAVE_STRUCT_TIMESPEC", "-DHAVE_PTHREAD", "-DHAVE_FUNC_ATTRIBUTE_PACKED", "-I{}/src", "-I{}/include", "-I{}/gen", - "-I{}/src/compiler/nir", "-I{}/src/gallium/auxiliary", "-I{}/src/gallium/include", f"-I{system('llvm-config-20 --includedir')}"], + "-I{}/src/compiler/nir", "-I{}/src/gallium/auxiliary", "-I{}/src/gallium/include", "-I{}/src/freedreno/common", + f"-I{system('llvm-config-20 --includedir')}"], preprocess=lambda path: subprocess.run("\n".join(["mkdir -p gen/util/format", "python3 src/compiler/builtin_types_h.py gen/builtin_types.h", - "python3 src/util/format/u_format_table.py src/util/format/u_format.yaml --enums > gen/util/format/u_format_gen.h", - *["python3 src/freedreno/registers/gen_header.py --rnn src/freedreno/registers/ --xml " + - f"src/freedreno/registers/adreno/{s}.xml c-defines > gen/{s}.xml.h" for s in ["a6xx", "adreno_pm4", "a6xx_enums", "a6xx_descriptors"]], - *[f"python3 src/compiler/{s}_h.py > gen/{s.split('/')[-1]}.h" for s in ["nir/nir_opcodes", "nir/nir_builder_opcodes"]], - *[f"python3 src/compiler/nir/nir_{s}_h.py --outdir gen" for s in ["intrinsics", "intrinsics_indices"]]]), cwd=path, shell=True, check=True), - tarball="https://gitlab.freedesktop.org/mesa/mesa/-/archive/mesa-25.2.4/mesa-25.2.4.tar.gz", + "python3 src/compiler/isaspec/decode.py --xml src/freedreno/isa/ir3.xml --out-c /dev/null --out-h gen/ir3-isa.h", + "python3 src/util/format/u_format_table.py src/util/format/u_format.yaml --enums > gen/util/format/u_format_gen.h", + *["python3 src/freedreno/registers/gen_header.py --rnn src/freedreno/registers/ --xml " + + f"src/freedreno/registers/adreno/{s}.xml c-defines > gen/{s}.xml.h" for s in ["a6xx", "adreno_pm4", "a6xx_enums", "a6xx_descriptors"]], + *[f"python3 src/compiler/{s}_h.py > gen/{s.split('/')[-1]}.h" for s in ["nir/nir_opcodes", "nir/nir_builder_opcodes"]], + *[f"python3 src/compiler/nir/nir_{s}_h.py --outdir gen" for s in ["intrinsics", "intrinsics_indices"]]]), cwd=path, shell=True, check=True), + tarball="https://gitlab.freedesktop.org/mesa/mesa/-/archive/mesa-25.2.7/mesa-25.2.7.tar.gz", prolog=["import gzip, base64", "from tinygrad.helpers import OSX"], epilog=lambda path: [system(f"{root}/extra/mesa/lvp_nir_options.sh {path}")]) case "libclang": return load("libclang", ["os.getenv('LIBCLANG_PATH', find_library('clang-20'))"], diff --git a/tinygrad/runtime/autogen/mesa.py b/tinygrad/runtime/autogen/mesa.py index e241938e68..43ab6cefa2 100644 --- a/tinygrad/runtime/autogen/mesa.py +++ b/tinygrad/runtime/autogen/mesa.py @@ -7043,6 +7043,1733 @@ except AttributeError: pass try: (lp_build_const_func_pointer_from_type:=dll.lp_build_const_func_pointer_from_type).restype, lp_build_const_func_pointer_from_type.argtypes = LLVMValueRef, [ctypes.POINTER(struct_gallivm_state), ctypes.c_void_p, LLVMTypeRef, ctypes.POINTER(ctypes.c_char)] except AttributeError: pass +class struct_fd_dev_info(Struct): pass +class struct_fd_dev_info_0(ctypes.Union): pass +struct_fd_dev_info_0._fields_ = [ + ('num_sp_cores', uint32_t), + ('num_ccu', uint32_t), +] +class struct_fd_dev_info_a6xx(Struct): pass +class struct_fd_dev_info_a6xx_magic(Struct): pass +struct_fd_dev_info_a6xx_magic._fields_ = [ + ('PC_POWER_CNTL', uint32_t), + ('TPL1_DBG_ECO_CNTL', uint32_t), + ('GRAS_DBG_ECO_CNTL', uint32_t), + ('SP_CHICKEN_BITS', uint32_t), + ('UCHE_CLIENT_PF', uint32_t), + ('PC_MODE_CNTL', uint32_t), + ('SP_DBG_ECO_CNTL', uint32_t), + ('RB_DBG_ECO_CNTL', uint32_t), + ('RB_DBG_ECO_CNTL_blit', uint32_t), + ('HLSQ_DBG_ECO_CNTL', uint32_t), + ('RB_UNKNOWN_8E01', uint32_t), + ('VPC_DBG_ECO_CNTL', uint32_t), + ('UCHE_UNKNOWN_0E12', uint32_t), + ('RB_CCU_DBG_ECO_CNTL', uint32_t), +] +class struct_fd_dev_info_a6xx_magic_raw(Struct): pass +struct_fd_dev_info_a6xx_magic_raw._fields_ = [ + ('reg', uint32_t), + ('value', uint32_t), +] +struct_fd_dev_info_a6xx._fields_ = [ + ('reg_size_vec4', uint32_t), + ('instr_cache_size', uint32_t), + ('has_hw_multiview', ctypes.c_bool), + ('has_fs_tex_prefetch', ctypes.c_bool), + ('supports_multiview_mask', ctypes.c_bool), + ('concurrent_resolve', ctypes.c_bool), + ('has_z24uint_s8uint', ctypes.c_bool), + ('tess_use_shared', ctypes.c_bool), + ('has_legacy_pipeline_shading_rate', ctypes.c_bool), + ('storage_16bit', ctypes.c_bool), + ('indirect_draw_wfm_quirk', ctypes.c_bool), + ('depth_bounds_require_depth_test_quirk', ctypes.c_bool), + ('has_tex_filter_cubic', ctypes.c_bool), + ('has_separate_chroma_filter', ctypes.c_bool), + ('has_sample_locations', ctypes.c_bool), + ('has_cp_reg_write', ctypes.c_bool), + ('has_8bpp_ubwc', ctypes.c_bool), + ('has_lpac', ctypes.c_bool), + ('has_getfiberid', ctypes.c_bool), + ('mov_half_shared_quirk', ctypes.c_bool), + ('has_movs', ctypes.c_bool), + ('has_dp2acc', ctypes.c_bool), + ('has_dp4acc', ctypes.c_bool), + ('enable_lrz_fast_clear', ctypes.c_bool), + ('has_lrz_dir_tracking', ctypes.c_bool), + ('lrz_track_quirk', ctypes.c_bool), + ('has_lrz_feedback', ctypes.c_bool), + ('has_per_view_viewport', ctypes.c_bool), + ('has_gmem_fast_clear', ctypes.c_bool), + ('sysmem_per_ccu_depth_cache_size', uint32_t), + ('sysmem_per_ccu_color_cache_size', uint32_t), + ('gmem_ccu_color_cache_fraction', uint32_t), + ('prim_alloc_threshold', uint32_t), + ('vs_max_inputs_count', uint32_t), + ('supports_double_threadsize', ctypes.c_bool), + ('has_sampler_minmax', ctypes.c_bool), + ('broken_ds_ubwc_quirk', ctypes.c_bool), + ('has_scalar_alu', ctypes.c_bool), + ('has_early_preamble', ctypes.c_bool), + ('has_isam_v', ctypes.c_bool), + ('has_ssbo_imm_offsets', ctypes.c_bool), + ('has_coherent_ubwc_flag_caches', ctypes.c_bool), + ('has_attachment_shading_rate', ctypes.c_bool), + ('has_ubwc_linear_mipmap_fallback', ctypes.c_bool), + ('predtf_nop_quirk', ctypes.c_bool), + ('prede_nop_quirk', ctypes.c_bool), + ('has_sad', ctypes.c_bool), + ('is_a702', ctypes.c_bool), + ('magic', struct_fd_dev_info_a6xx_magic), + ('magic_raw', (struct_fd_dev_info_a6xx_magic_raw * 64)), + ('max_sets', uint32_t), + ('line_width_min', ctypes.c_float), + ('line_width_max', ctypes.c_float), + ('has_bin_mask', ctypes.c_bool), +] +class struct_fd_dev_info_a7xx(Struct): pass +struct_fd_dev_info_a7xx._fields_ = [ + ('stsc_duplication_quirk', ctypes.c_bool), + ('has_event_write_sample_count', ctypes.c_bool), + ('has_64b_ssbo_atomics', ctypes.c_bool), + ('cmdbuf_start_a725_quirk', ctypes.c_bool), + ('load_inline_uniforms_via_preamble_ldgk', ctypes.c_bool), + ('load_shader_consts_via_preamble', ctypes.c_bool), + ('has_gmem_vpc_attr_buf', ctypes.c_bool), + ('sysmem_vpc_attr_buf_size', uint32_t), + ('gmem_vpc_attr_buf_size', uint32_t), + ('supports_uav_ubwc', ctypes.c_bool), + ('ubwc_unorm_snorm_int_compatible', ctypes.c_bool), + ('fs_must_have_non_zero_constlen_quirk', ctypes.c_bool), + ('gs_vpc_adjacency_quirk', ctypes.c_bool), + ('enable_tp_ubwc_flag_hint', ctypes.c_bool), + ('storage_8bit', ctypes.c_bool), + ('ubwc_all_formats_compatible', ctypes.c_bool), + ('has_compliant_dp4acc', ctypes.c_bool), + ('has_generic_clear', ctypes.c_bool), + ('r8g8_faulty_fast_clear_quirk', ctypes.c_bool), + ('ubwc_coherency_quirk', ctypes.c_bool), + ('has_persistent_counter', ctypes.c_bool), + ('has_primitive_shading_rate', ctypes.c_bool), + ('reading_shading_rate_requires_smask_quirk', ctypes.c_bool), + ('has_ray_intersection', ctypes.c_bool), + ('has_sw_fuse', ctypes.c_bool), + ('has_rt_workaround', ctypes.c_bool), + ('has_alias_rt', ctypes.c_bool), + ('has_abs_bin_mask', ctypes.c_bool), + ('new_control_regs', ctypes.c_bool), +] +struct_fd_dev_info._anonymous_ = ['_0'] +struct_fd_dev_info._fields_ = [ + ('chip', uint8_t), + ('tile_align_w', uint32_t), + ('tile_align_h', uint32_t), + ('gmem_align_w', uint32_t), + ('gmem_align_h', uint32_t), + ('tile_max_w', uint32_t), + ('tile_max_h', uint32_t), + ('num_vsc_pipes', uint32_t), + ('cs_shared_mem_size', uint32_t), + ('wave_granularity', ctypes.c_int32), + ('highest_bank_bit', uint32_t), + ('ubwc_swizzle', uint32_t), + ('macrotile_mode', uint32_t), + ('fibers_per_sp', uint32_t), + ('threadsize_base', uint32_t), + ('max_waves', uint32_t), + ('compute_lb_size', uint32_t), + ('_0', struct_fd_dev_info_0), + ('a6xx', struct_fd_dev_info_a6xx), + ('a7xx', struct_fd_dev_info_a7xx), +] +class struct_fd_dev_id(Struct): pass +struct_fd_dev_id._fields_ = [ + ('gpu_id', uint32_t), + ('chip_id', uint64_t), +] +try: (fd_dev_info_raw:=dll.fd_dev_info_raw).restype, fd_dev_info_raw.argtypes = ctypes.POINTER(struct_fd_dev_info), [ctypes.POINTER(struct_fd_dev_id)] +except AttributeError: pass + +try: (fd_dev_info:=dll.fd_dev_info).restype, fd_dev_info.argtypes = struct_fd_dev_info, [ctypes.POINTER(struct_fd_dev_id)] +except AttributeError: pass + +try: (fd_dev_info_raw_by_name:=dll.fd_dev_info_raw_by_name).restype, fd_dev_info_raw_by_name.argtypes = ctypes.POINTER(struct_fd_dev_info), [ctypes.POINTER(ctypes.c_char)] +except AttributeError: pass + +try: (fd_dev_name:=dll.fd_dev_name).restype, fd_dev_name.argtypes = ctypes.POINTER(ctypes.c_char), [ctypes.POINTER(struct_fd_dev_id)] +except AttributeError: pass + +try: (fd_dev_info_apply_dbg_options:=dll.fd_dev_info_apply_dbg_options).restype, fd_dev_info_apply_dbg_options.argtypes = None, [ctypes.POINTER(struct_fd_dev_info)] +except AttributeError: pass + +class struct_ir3_ra_reg_set(Struct): pass +class struct_ir3_shader(Struct): pass +class struct_ir3_compiler_options(Struct): pass +struct_ir3_compiler_options._fields_ = [ + ('push_ubo_with_preamble', ctypes.c_bool), + ('disable_cache', ctypes.c_bool), + ('bindless_fb_read_descriptor', ctypes.c_int32), + ('bindless_fb_read_slot', ctypes.c_int32), + ('storage_16bit', ctypes.c_bool), + ('storage_8bit', ctypes.c_bool), + ('lower_base_vertex', ctypes.c_bool), + ('shared_push_consts', ctypes.c_bool), + ('dual_color_blend_by_location', ctypes.c_bool), + ('uche_trap_base', uint64_t), +] +class struct_ir3_compiler(Struct): pass +class struct_fd_device(Struct): pass +class struct_disk_cache(Struct): pass +type_t = CEnum(ctypes.c_uint32) +TYPE_F16 = type_t.define('TYPE_F16', 0) +TYPE_F32 = type_t.define('TYPE_F32', 1) +TYPE_U16 = type_t.define('TYPE_U16', 2) +TYPE_U32 = type_t.define('TYPE_U32', 3) +TYPE_S16 = type_t.define('TYPE_S16', 4) +TYPE_S32 = type_t.define('TYPE_S32', 5) +TYPE_ATOMIC_U64 = type_t.define('TYPE_ATOMIC_U64', 6) +TYPE_U8 = type_t.define('TYPE_U8', 6) +TYPE_U8_32 = type_t.define('TYPE_U8_32', 7) + +class struct_ir3_compiler_delay_slots(Struct): pass +struct_ir3_compiler_delay_slots._fields_ = [ + ('alu_to_alu', ctypes.c_uint32), + ('non_alu', ctypes.c_uint32), + ('cat3_src2_read', ctypes.c_uint32), +] +struct_ir3_compiler._fields_ = [ + ('dev', ctypes.POINTER(struct_fd_device)), + ('dev_id', ctypes.POINTER(struct_fd_dev_id)), + ('gen', uint8_t), + ('shader_count', uint32_t), + ('disk_cache', ctypes.POINTER(struct_disk_cache)), + ('nir_options', struct_nir_shader_compiler_options), + ('options', struct_ir3_compiler_options), + ('is_64bit', ctypes.c_bool), + ('flat_bypass', ctypes.c_bool), + ('levels_add_one', ctypes.c_bool), + ('unminify_coords', ctypes.c_bool), + ('txf_ms_with_isaml', ctypes.c_bool), + ('array_index_add_half', ctypes.c_bool), + ('samgq_workaround', ctypes.c_bool), + ('tess_use_shared', ctypes.c_bool), + ('mergedregs', ctypes.c_bool), + ('max_const_pipeline', uint16_t), + ('max_const_geom', uint16_t), + ('max_const_frag', uint16_t), + ('max_const_safe', uint16_t), + ('max_const_compute', uint16_t), + ('compute_lb_size', uint32_t), + ('instr_align', uint32_t), + ('const_upload_unit', uint32_t), + ('threadsize_base', uint32_t), + ('wave_granularity', uint32_t), + ('max_waves', uint32_t), + ('reg_size_vec4', uint32_t), + ('local_mem_size', uint32_t), + ('branchstack_size', uint32_t), + ('pvtmem_per_fiber_align', uint32_t), + ('has_clip_cull', ctypes.c_bool), + ('has_pvtmem', ctypes.c_bool), + ('has_isam_ssbo', ctypes.c_bool), + ('has_isam_v', ctypes.c_bool), + ('has_ssbo_imm_offsets', ctypes.c_bool), + ('has_getfiberid', ctypes.c_bool), + ('mov_half_shared_quirk', ctypes.c_bool), + ('has_movs', ctypes.c_bool), + ('has_shfl', ctypes.c_bool), + ('has_bitwise_triops', ctypes.c_bool), + ('num_predicates', uint32_t), + ('bitops_can_write_predicates', ctypes.c_bool), + ('has_branch_and_or', ctypes.c_bool), + ('has_predication', ctypes.c_bool), + ('predtf_nop_quirk', ctypes.c_bool), + ('prede_nop_quirk', ctypes.c_bool), + ('max_variable_workgroup_size', uint32_t), + ('has_dp2acc', ctypes.c_bool), + ('has_dp4acc', ctypes.c_bool), + ('has_compliant_dp4acc', ctypes.c_bool), + ('bool_type', type_t), + ('has_shared_regfile', ctypes.c_bool), + ('has_preamble', ctypes.c_bool), + ('shared_consts_base_offset', uint16_t), + ('shared_consts_size', uint64_t), + ('geom_shared_consts_size_quirk', uint64_t), + ('has_fs_tex_prefetch', ctypes.c_bool), + ('stsc_duplication_quirk', ctypes.c_bool), + ('load_shader_consts_via_preamble', ctypes.c_bool), + ('load_inline_uniforms_via_preamble_ldgk', ctypes.c_bool), + ('has_scalar_alu', ctypes.c_bool), + ('fs_must_have_non_zero_constlen_quirk', ctypes.c_bool), + ('has_early_preamble', ctypes.c_bool), + ('has_rpt_bary_f', ctypes.c_bool), + ('has_alias_tex', ctypes.c_bool), + ('has_alias_rt', ctypes.c_bool), + ('reading_shading_rate_requires_smask_quirk', ctypes.c_bool), + ('delay_slots', struct_ir3_compiler_delay_slots), +] +try: (ir3_compiler_destroy:=dll.ir3_compiler_destroy).restype, ir3_compiler_destroy.argtypes = None, [ctypes.POINTER(struct_ir3_compiler)] +except AttributeError: pass + +try: (ir3_compiler_create:=dll.ir3_compiler_create).restype, ir3_compiler_create.argtypes = ctypes.POINTER(struct_ir3_compiler), [ctypes.POINTER(struct_fd_device), ctypes.POINTER(struct_fd_dev_id), ctypes.POINTER(struct_fd_dev_info), ctypes.POINTER(struct_ir3_compiler_options)] +except AttributeError: pass + +try: (ir3_disk_cache_init:=dll.ir3_disk_cache_init).restype, ir3_disk_cache_init.argtypes = None, [ctypes.POINTER(struct_ir3_compiler)] +except AttributeError: pass + +try: (ir3_disk_cache_init_shader_key:=dll.ir3_disk_cache_init_shader_key).restype, ir3_disk_cache_init_shader_key.argtypes = None, [ctypes.POINTER(struct_ir3_compiler), ctypes.POINTER(struct_ir3_shader)] +except AttributeError: pass + +class struct_ir3_shader_variant(Struct): pass +try: (ir3_retrieve_variant:=dll.ir3_retrieve_variant).restype, ir3_retrieve_variant.argtypes = ctypes.POINTER(struct_ir3_shader_variant), [ctypes.POINTER(struct_blob_reader), ctypes.POINTER(struct_ir3_compiler), ctypes.c_void_p] +except AttributeError: pass + +try: (ir3_store_variant:=dll.ir3_store_variant).restype, ir3_store_variant.argtypes = None, [ctypes.POINTER(struct_blob), ctypes.POINTER(struct_ir3_shader_variant)] +except AttributeError: pass + +try: (ir3_disk_cache_retrieve:=dll.ir3_disk_cache_retrieve).restype, ir3_disk_cache_retrieve.argtypes = ctypes.c_bool, [ctypes.POINTER(struct_ir3_shader), ctypes.POINTER(struct_ir3_shader_variant)] +except AttributeError: pass + +try: (ir3_disk_cache_store:=dll.ir3_disk_cache_store).restype, ir3_disk_cache_store.argtypes = None, [ctypes.POINTER(struct_ir3_shader), ctypes.POINTER(struct_ir3_shader_variant)] +except AttributeError: pass + +try: (ir3_get_compiler_options:=dll.ir3_get_compiler_options).restype, ir3_get_compiler_options.argtypes = ctypes.POINTER(nir_shader_compiler_options), [ctypes.POINTER(struct_ir3_compiler)] +except AttributeError: pass + +try: (ir3_compile_shader_nir:=dll.ir3_compile_shader_nir).restype, ir3_compile_shader_nir.argtypes = ctypes.c_int32, [ctypes.POINTER(struct_ir3_compiler), ctypes.POINTER(struct_ir3_shader), ctypes.POINTER(struct_ir3_shader_variant)] +except AttributeError: pass + +enum_ir3_shader_debug = CEnum(ctypes.c_uint32) +IR3_DBG_SHADER_VS = enum_ir3_shader_debug.define('IR3_DBG_SHADER_VS', 1) +IR3_DBG_SHADER_TCS = enum_ir3_shader_debug.define('IR3_DBG_SHADER_TCS', 2) +IR3_DBG_SHADER_TES = enum_ir3_shader_debug.define('IR3_DBG_SHADER_TES', 4) +IR3_DBG_SHADER_GS = enum_ir3_shader_debug.define('IR3_DBG_SHADER_GS', 8) +IR3_DBG_SHADER_FS = enum_ir3_shader_debug.define('IR3_DBG_SHADER_FS', 16) +IR3_DBG_SHADER_CS = enum_ir3_shader_debug.define('IR3_DBG_SHADER_CS', 32) +IR3_DBG_DISASM = enum_ir3_shader_debug.define('IR3_DBG_DISASM', 64) +IR3_DBG_OPTMSGS = enum_ir3_shader_debug.define('IR3_DBG_OPTMSGS', 128) +IR3_DBG_FORCES2EN = enum_ir3_shader_debug.define('IR3_DBG_FORCES2EN', 256) +IR3_DBG_NOUBOOPT = enum_ir3_shader_debug.define('IR3_DBG_NOUBOOPT', 512) +IR3_DBG_NOFP16 = enum_ir3_shader_debug.define('IR3_DBG_NOFP16', 1024) +IR3_DBG_NOCACHE = enum_ir3_shader_debug.define('IR3_DBG_NOCACHE', 2048) +IR3_DBG_SPILLALL = enum_ir3_shader_debug.define('IR3_DBG_SPILLALL', 4096) +IR3_DBG_NOPREAMBLE = enum_ir3_shader_debug.define('IR3_DBG_NOPREAMBLE', 8192) +IR3_DBG_SHADER_INTERNAL = enum_ir3_shader_debug.define('IR3_DBG_SHADER_INTERNAL', 16384) +IR3_DBG_FULLSYNC = enum_ir3_shader_debug.define('IR3_DBG_FULLSYNC', 32768) +IR3_DBG_FULLNOP = enum_ir3_shader_debug.define('IR3_DBG_FULLNOP', 65536) +IR3_DBG_NOEARLYPREAMBLE = enum_ir3_shader_debug.define('IR3_DBG_NOEARLYPREAMBLE', 131072) +IR3_DBG_NODESCPREFETCH = enum_ir3_shader_debug.define('IR3_DBG_NODESCPREFETCH', 262144) +IR3_DBG_EXPANDRPT = enum_ir3_shader_debug.define('IR3_DBG_EXPANDRPT', 524288) +IR3_DBG_ASM_ROUNDTRIP = enum_ir3_shader_debug.define('IR3_DBG_ASM_ROUNDTRIP', 1048576) +IR3_DBG_SCHEDMSGS = enum_ir3_shader_debug.define('IR3_DBG_SCHEDMSGS', 2097152) +IR3_DBG_RAMSGS = enum_ir3_shader_debug.define('IR3_DBG_RAMSGS', 4194304) +IR3_DBG_NOALIASTEX = enum_ir3_shader_debug.define('IR3_DBG_NOALIASTEX', 8388608) +IR3_DBG_NOALIASRT = enum_ir3_shader_debug.define('IR3_DBG_NOALIASRT', 16777216) + +try: ir3_shader_debug = enum_ir3_shader_debug.in_dll(dll, 'ir3_shader_debug') +except (ValueError,AttributeError): pass +try: ir3_shader_override_path = ctypes.POINTER(ctypes.c_char).in_dll(dll, 'ir3_shader_override_path') +except (ValueError,AttributeError): pass +try: (ir3_shader_debug_as_string:=dll.ir3_shader_debug_as_string).restype, ir3_shader_debug_as_string.argtypes = ctypes.POINTER(ctypes.c_char), [] +except AttributeError: pass + +class struct_ir3_driver_params_cs(Struct): pass +struct_ir3_driver_params_cs._fields_ = [ + ('num_work_groups_x', uint32_t), + ('num_work_groups_y', uint32_t), + ('num_work_groups_z', uint32_t), + ('work_dim', uint32_t), + ('base_group_x', uint32_t), + ('base_group_y', uint32_t), + ('base_group_z', uint32_t), + ('subgroup_size', uint32_t), + ('local_group_size_x', uint32_t), + ('local_group_size_y', uint32_t), + ('local_group_size_z', uint32_t), + ('subgroup_id_shift', uint32_t), + ('workgroup_id_x', uint32_t), + ('workgroup_id_y', uint32_t), + ('workgroup_id_z', uint32_t), + ('__pad', uint32_t), +] +class struct_ir3_driver_params_vs(Struct): pass +class struct_ir3_driver_params_vs_ucp(Struct): pass +struct_ir3_driver_params_vs_ucp._fields_ = [ + ('x', uint32_t), + ('y', uint32_t), + ('z', uint32_t), + ('w', uint32_t), +] +struct_ir3_driver_params_vs._fields_ = [ + ('draw_id', uint32_t), + ('vtxid_base', uint32_t), + ('instid_base', uint32_t), + ('vtxcnt_max', uint32_t), + ('is_indexed_draw', uint32_t), + ('ucp', (struct_ir3_driver_params_vs_ucp * 8)), + ('__pad_37_39', (uint32_t * 3)), +] +class struct_ir3_driver_params_tcs(Struct): pass +struct_ir3_driver_params_tcs._fields_ = [ + ('default_outer_level_x', uint32_t), + ('default_outer_level_y', uint32_t), + ('default_outer_level_z', uint32_t), + ('default_outer_level_w', uint32_t), + ('default_inner_level_x', uint32_t), + ('default_inner_level_y', uint32_t), + ('__pad_06_07', (uint32_t * 2)), +] +class struct_ir3_driver_params_fs(Struct): pass +struct_ir3_driver_params_fs._fields_ = [ + ('subgroup_size', uint32_t), + ('__pad_01_03', (uint32_t * 3)), + ('frag_invocation_count', uint32_t), + ('__pad_05_07', (uint32_t * 3)), + ('frag_size', uint32_t), + ('__pad_09', uint32_t), + ('frag_offset', uint32_t), + ('__pad_11_12', (uint32_t * 2)), +] +enum_ir3_bary = CEnum(ctypes.c_uint32) +IJ_PERSP_PIXEL = enum_ir3_bary.define('IJ_PERSP_PIXEL', 0) +IJ_PERSP_SAMPLE = enum_ir3_bary.define('IJ_PERSP_SAMPLE', 1) +IJ_PERSP_CENTROID = enum_ir3_bary.define('IJ_PERSP_CENTROID', 2) +IJ_PERSP_CENTER_RHW = enum_ir3_bary.define('IJ_PERSP_CENTER_RHW', 3) +IJ_LINEAR_PIXEL = enum_ir3_bary.define('IJ_LINEAR_PIXEL', 4) +IJ_LINEAR_CENTROID = enum_ir3_bary.define('IJ_LINEAR_CENTROID', 5) +IJ_LINEAR_SAMPLE = enum_ir3_bary.define('IJ_LINEAR_SAMPLE', 6) +IJ_COUNT = enum_ir3_bary.define('IJ_COUNT', 7) + +enum_ir3_wavesize_option = CEnum(ctypes.c_uint32) +IR3_SINGLE_ONLY = enum_ir3_wavesize_option.define('IR3_SINGLE_ONLY', 0) +IR3_SINGLE_OR_DOUBLE = enum_ir3_wavesize_option.define('IR3_SINGLE_OR_DOUBLE', 1) +IR3_DOUBLE_ONLY = enum_ir3_wavesize_option.define('IR3_DOUBLE_ONLY', 2) + +class struct_ir3_ubo_info(Struct): pass +struct_ir3_ubo_info._fields_ = [ + ('global_base', ctypes.POINTER(struct_nir_def)), + ('block', uint32_t), + ('bindless_base', uint16_t), + ('bindless', ctypes.c_bool), + ('global', ctypes.c_bool), +] +class struct_ir3_ubo_range(Struct): pass +struct_ir3_ubo_range._fields_ = [ + ('ubo', struct_ir3_ubo_info), + ('offset', uint32_t), + ('start', uint32_t), + ('end', uint32_t), +] +class struct_ir3_ubo_analysis_state(Struct): pass +struct_ir3_ubo_analysis_state._fields_ = [ + ('range', (struct_ir3_ubo_range * 32)), + ('num_enabled', uint32_t), + ('size', uint32_t), +] +enum_ir3_push_consts_type = CEnum(ctypes.c_uint32) +IR3_PUSH_CONSTS_NONE = enum_ir3_push_consts_type.define('IR3_PUSH_CONSTS_NONE', 0) +IR3_PUSH_CONSTS_PER_STAGE = enum_ir3_push_consts_type.define('IR3_PUSH_CONSTS_PER_STAGE', 1) +IR3_PUSH_CONSTS_SHARED = enum_ir3_push_consts_type.define('IR3_PUSH_CONSTS_SHARED', 2) +IR3_PUSH_CONSTS_SHARED_PREAMBLE = enum_ir3_push_consts_type.define('IR3_PUSH_CONSTS_SHARED_PREAMBLE', 3) + +class struct_ir3_driver_ubo(Struct): pass +struct_ir3_driver_ubo._fields_ = [ + ('idx', int32_t), + ('size', uint32_t), +] +enum_ir3_const_alloc_type = CEnum(ctypes.c_uint32) +IR3_CONST_ALLOC_PUSH_CONSTS = enum_ir3_const_alloc_type.define('IR3_CONST_ALLOC_PUSH_CONSTS', 0) +IR3_CONST_ALLOC_DYN_DESCRIPTOR_OFFSET = enum_ir3_const_alloc_type.define('IR3_CONST_ALLOC_DYN_DESCRIPTOR_OFFSET', 1) +IR3_CONST_ALLOC_INLINE_UNIFORM_ADDRS = enum_ir3_const_alloc_type.define('IR3_CONST_ALLOC_INLINE_UNIFORM_ADDRS', 2) +IR3_CONST_ALLOC_DRIVER_PARAMS = enum_ir3_const_alloc_type.define('IR3_CONST_ALLOC_DRIVER_PARAMS', 3) +IR3_CONST_ALLOC_UBO_RANGES = enum_ir3_const_alloc_type.define('IR3_CONST_ALLOC_UBO_RANGES', 4) +IR3_CONST_ALLOC_PREAMBLE = enum_ir3_const_alloc_type.define('IR3_CONST_ALLOC_PREAMBLE', 5) +IR3_CONST_ALLOC_GLOBAL = enum_ir3_const_alloc_type.define('IR3_CONST_ALLOC_GLOBAL', 6) +IR3_CONST_ALLOC_UBO_PTRS = enum_ir3_const_alloc_type.define('IR3_CONST_ALLOC_UBO_PTRS', 7) +IR3_CONST_ALLOC_IMAGE_DIMS = enum_ir3_const_alloc_type.define('IR3_CONST_ALLOC_IMAGE_DIMS', 8) +IR3_CONST_ALLOC_TFBO = enum_ir3_const_alloc_type.define('IR3_CONST_ALLOC_TFBO', 9) +IR3_CONST_ALLOC_PRIMITIVE_PARAM = enum_ir3_const_alloc_type.define('IR3_CONST_ALLOC_PRIMITIVE_PARAM', 10) +IR3_CONST_ALLOC_PRIMITIVE_MAP = enum_ir3_const_alloc_type.define('IR3_CONST_ALLOC_PRIMITIVE_MAP', 11) +IR3_CONST_ALLOC_MAX = enum_ir3_const_alloc_type.define('IR3_CONST_ALLOC_MAX', 12) + +class struct_ir3_const_allocation(Struct): pass +struct_ir3_const_allocation._fields_ = [ + ('offset_vec4', uint32_t), + ('size_vec4', uint32_t), + ('reserved_size_vec4', uint32_t), + ('reserved_align_vec4', uint32_t), +] +class struct_ir3_const_allocations(Struct): pass +struct_ir3_const_allocations._fields_ = [ + ('consts', (struct_ir3_const_allocation * 12)), + ('max_const_offset_vec4', uint32_t), + ('reserved_vec4', uint32_t), +] +class struct_ir3_const_image_dims(Struct): pass +struct_ir3_const_image_dims._fields_ = [ + ('mask', uint32_t), + ('count', uint32_t), + ('off', (uint32_t * 32)), +] +class struct_ir3_imm_const_state(Struct): pass +struct_ir3_imm_const_state._fields_ = [ + ('size', ctypes.c_uint32), + ('count', ctypes.c_uint32), + ('values', ctypes.POINTER(uint32_t)), +] +class struct_ir3_const_state(Struct): pass +struct_ir3_const_state._fields_ = [ + ('num_ubos', ctypes.c_uint32), + ('num_app_ubos', ctypes.c_uint32), + ('num_driver_params', ctypes.c_uint32), + ('consts_ubo', struct_ir3_driver_ubo), + ('driver_params_ubo', struct_ir3_driver_ubo), + ('primitive_map_ubo', struct_ir3_driver_ubo), + ('primitive_param_ubo', struct_ir3_driver_ubo), + ('allocs', struct_ir3_const_allocations), + ('image_dims', struct_ir3_const_image_dims), + ('ubo_state', struct_ir3_ubo_analysis_state), + ('push_consts_type', enum_ir3_push_consts_type), +] +class struct_ir3_stream_output(Struct): pass +struct_ir3_stream_output._fields_ = [ + ('register_index', ctypes.c_uint32,6), + ('start_component', ctypes.c_uint32,2), + ('num_components', ctypes.c_uint32,3), + ('output_buffer', ctypes.c_uint32,3), + ('dst_offset', ctypes.c_uint32,16), + ('stream', ctypes.c_uint32,2), +] +class struct_ir3_stream_output_info(Struct): pass +struct_ir3_stream_output_info._fields_ = [ + ('num_outputs', ctypes.c_uint32), + ('stride', (uint16_t * 4)), + ('streams_written', uint8_t), + ('buffer_to_stream', (uint8_t * 4)), + ('output', (struct_ir3_stream_output * 128)), +] +class struct_ir3_sampler_prefetch(Struct): pass +opc_t = CEnum(ctypes.c_uint32) +OPC_NOP = opc_t.define('OPC_NOP', 0) +OPC_JUMP = opc_t.define('OPC_JUMP', 2) +OPC_CALL = opc_t.define('OPC_CALL', 3) +OPC_RET = opc_t.define('OPC_RET', 4) +OPC_KILL = opc_t.define('OPC_KILL', 5) +OPC_END = opc_t.define('OPC_END', 6) +OPC_EMIT = opc_t.define('OPC_EMIT', 7) +OPC_CUT = opc_t.define('OPC_CUT', 8) +OPC_CHMASK = opc_t.define('OPC_CHMASK', 9) +OPC_CHSH = opc_t.define('OPC_CHSH', 10) +OPC_FLOW_REV = opc_t.define('OPC_FLOW_REV', 11) +OPC_BKT = opc_t.define('OPC_BKT', 16) +OPC_STKS = opc_t.define('OPC_STKS', 17) +OPC_STKR = opc_t.define('OPC_STKR', 18) +OPC_XSET = opc_t.define('OPC_XSET', 19) +OPC_XCLR = opc_t.define('OPC_XCLR', 20) +OPC_GETONE = opc_t.define('OPC_GETONE', 21) +OPC_DBG = opc_t.define('OPC_DBG', 22) +OPC_SHPS = opc_t.define('OPC_SHPS', 23) +OPC_SHPE = opc_t.define('OPC_SHPE', 24) +OPC_GETLAST = opc_t.define('OPC_GETLAST', 25) +OPC_PREDT = opc_t.define('OPC_PREDT', 29) +OPC_PREDF = opc_t.define('OPC_PREDF', 30) +OPC_PREDE = opc_t.define('OPC_PREDE', 31) +OPC_BR = opc_t.define('OPC_BR', 40) +OPC_BRAO = opc_t.define('OPC_BRAO', 41) +OPC_BRAA = opc_t.define('OPC_BRAA', 42) +OPC_BRAC = opc_t.define('OPC_BRAC', 43) +OPC_BANY = opc_t.define('OPC_BANY', 44) +OPC_BALL = opc_t.define('OPC_BALL', 45) +OPC_BRAX = opc_t.define('OPC_BRAX', 46) +OPC_DEMOTE = opc_t.define('OPC_DEMOTE', 47) +OPC_MOV = opc_t.define('OPC_MOV', 128) +OPC_MOVP = opc_t.define('OPC_MOVP', 129) +OPC_MOVS = opc_t.define('OPC_MOVS', 130) +OPC_MOVMSK = opc_t.define('OPC_MOVMSK', 131) +OPC_SWZ = opc_t.define('OPC_SWZ', 132) +OPC_GAT = opc_t.define('OPC_GAT', 133) +OPC_SCT = opc_t.define('OPC_SCT', 134) +OPC_MOV_IMMED = opc_t.define('OPC_MOV_IMMED', 168) +OPC_MOV_CONST = opc_t.define('OPC_MOV_CONST', 169) +OPC_MOV_GPR = opc_t.define('OPC_MOV_GPR', 170) +OPC_MOV_RELGPR = opc_t.define('OPC_MOV_RELGPR', 171) +OPC_MOV_RELCONST = opc_t.define('OPC_MOV_RELCONST', 172) +OPC_MOVS_IMMED = opc_t.define('OPC_MOVS_IMMED', 173) +OPC_MOVS_A0 = opc_t.define('OPC_MOVS_A0', 174) +OPC_BALLOT_MACRO = opc_t.define('OPC_BALLOT_MACRO', 178) +OPC_ANY_MACRO = opc_t.define('OPC_ANY_MACRO', 179) +OPC_ALL_MACRO = opc_t.define('OPC_ALL_MACRO', 180) +OPC_ELECT_MACRO = opc_t.define('OPC_ELECT_MACRO', 181) +OPC_READ_COND_MACRO = opc_t.define('OPC_READ_COND_MACRO', 182) +OPC_READ_FIRST_MACRO = opc_t.define('OPC_READ_FIRST_MACRO', 183) +OPC_SHPS_MACRO = opc_t.define('OPC_SHPS_MACRO', 184) +OPC_READ_GETLAST_MACRO = opc_t.define('OPC_READ_GETLAST_MACRO', 185) +OPC_SCAN_MACRO = opc_t.define('OPC_SCAN_MACRO', 186) +OPC_SCAN_CLUSTERS_MACRO = opc_t.define('OPC_SCAN_CLUSTERS_MACRO', 188) +OPC_ADD_F = opc_t.define('OPC_ADD_F', 256) +OPC_MIN_F = opc_t.define('OPC_MIN_F', 257) +OPC_MAX_F = opc_t.define('OPC_MAX_F', 258) +OPC_MUL_F = opc_t.define('OPC_MUL_F', 259) +OPC_SIGN_F = opc_t.define('OPC_SIGN_F', 260) +OPC_CMPS_F = opc_t.define('OPC_CMPS_F', 261) +OPC_ABSNEG_F = opc_t.define('OPC_ABSNEG_F', 262) +OPC_CMPV_F = opc_t.define('OPC_CMPV_F', 263) +OPC_FLOOR_F = opc_t.define('OPC_FLOOR_F', 265) +OPC_CEIL_F = opc_t.define('OPC_CEIL_F', 266) +OPC_RNDNE_F = opc_t.define('OPC_RNDNE_F', 267) +OPC_RNDAZ_F = opc_t.define('OPC_RNDAZ_F', 268) +OPC_TRUNC_F = opc_t.define('OPC_TRUNC_F', 269) +OPC_ADD_U = opc_t.define('OPC_ADD_U', 272) +OPC_ADD_S = opc_t.define('OPC_ADD_S', 273) +OPC_SUB_U = opc_t.define('OPC_SUB_U', 274) +OPC_SUB_S = opc_t.define('OPC_SUB_S', 275) +OPC_CMPS_U = opc_t.define('OPC_CMPS_U', 276) +OPC_CMPS_S = opc_t.define('OPC_CMPS_S', 277) +OPC_MIN_U = opc_t.define('OPC_MIN_U', 278) +OPC_MIN_S = opc_t.define('OPC_MIN_S', 279) +OPC_MAX_U = opc_t.define('OPC_MAX_U', 280) +OPC_MAX_S = opc_t.define('OPC_MAX_S', 281) +OPC_ABSNEG_S = opc_t.define('OPC_ABSNEG_S', 282) +OPC_AND_B = opc_t.define('OPC_AND_B', 284) +OPC_OR_B = opc_t.define('OPC_OR_B', 285) +OPC_NOT_B = opc_t.define('OPC_NOT_B', 286) +OPC_XOR_B = opc_t.define('OPC_XOR_B', 287) +OPC_CMPV_U = opc_t.define('OPC_CMPV_U', 289) +OPC_CMPV_S = opc_t.define('OPC_CMPV_S', 290) +OPC_MUL_U24 = opc_t.define('OPC_MUL_U24', 304) +OPC_MUL_S24 = opc_t.define('OPC_MUL_S24', 305) +OPC_MULL_U = opc_t.define('OPC_MULL_U', 306) +OPC_BFREV_B = opc_t.define('OPC_BFREV_B', 307) +OPC_CLZ_S = opc_t.define('OPC_CLZ_S', 308) +OPC_CLZ_B = opc_t.define('OPC_CLZ_B', 309) +OPC_SHL_B = opc_t.define('OPC_SHL_B', 310) +OPC_SHR_B = opc_t.define('OPC_SHR_B', 311) +OPC_ASHR_B = opc_t.define('OPC_ASHR_B', 312) +OPC_BARY_F = opc_t.define('OPC_BARY_F', 313) +OPC_MGEN_B = opc_t.define('OPC_MGEN_B', 314) +OPC_GETBIT_B = opc_t.define('OPC_GETBIT_B', 315) +OPC_SETRM = opc_t.define('OPC_SETRM', 316) +OPC_CBITS_B = opc_t.define('OPC_CBITS_B', 317) +OPC_SHB = opc_t.define('OPC_SHB', 318) +OPC_MSAD = opc_t.define('OPC_MSAD', 319) +OPC_FLAT_B = opc_t.define('OPC_FLAT_B', 320) +OPC_MAD_U16 = opc_t.define('OPC_MAD_U16', 384) +OPC_MADSH_U16 = opc_t.define('OPC_MADSH_U16', 385) +OPC_MAD_S16 = opc_t.define('OPC_MAD_S16', 386) +OPC_MADSH_M16 = opc_t.define('OPC_MADSH_M16', 387) +OPC_MAD_U24 = opc_t.define('OPC_MAD_U24', 388) +OPC_MAD_S24 = opc_t.define('OPC_MAD_S24', 389) +OPC_MAD_F16 = opc_t.define('OPC_MAD_F16', 390) +OPC_MAD_F32 = opc_t.define('OPC_MAD_F32', 391) +OPC_SEL_B16 = opc_t.define('OPC_SEL_B16', 392) +OPC_SEL_B32 = opc_t.define('OPC_SEL_B32', 393) +OPC_SEL_S16 = opc_t.define('OPC_SEL_S16', 394) +OPC_SEL_S32 = opc_t.define('OPC_SEL_S32', 395) +OPC_SEL_F16 = opc_t.define('OPC_SEL_F16', 396) +OPC_SEL_F32 = opc_t.define('OPC_SEL_F32', 397) +OPC_SAD_S16 = opc_t.define('OPC_SAD_S16', 398) +OPC_SAD_S32 = opc_t.define('OPC_SAD_S32', 399) +OPC_SHRM = opc_t.define('OPC_SHRM', 400) +OPC_SHLM = opc_t.define('OPC_SHLM', 401) +OPC_SHRG = opc_t.define('OPC_SHRG', 402) +OPC_SHLG = opc_t.define('OPC_SHLG', 403) +OPC_ANDG = opc_t.define('OPC_ANDG', 404) +OPC_DP2ACC = opc_t.define('OPC_DP2ACC', 405) +OPC_DP4ACC = opc_t.define('OPC_DP4ACC', 406) +OPC_WMM = opc_t.define('OPC_WMM', 407) +OPC_WMM_ACCU = opc_t.define('OPC_WMM_ACCU', 408) +OPC_RCP = opc_t.define('OPC_RCP', 512) +OPC_RSQ = opc_t.define('OPC_RSQ', 513) +OPC_LOG2 = opc_t.define('OPC_LOG2', 514) +OPC_EXP2 = opc_t.define('OPC_EXP2', 515) +OPC_SIN = opc_t.define('OPC_SIN', 516) +OPC_COS = opc_t.define('OPC_COS', 517) +OPC_SQRT = opc_t.define('OPC_SQRT', 518) +OPC_HRSQ = opc_t.define('OPC_HRSQ', 521) +OPC_HLOG2 = opc_t.define('OPC_HLOG2', 522) +OPC_HEXP2 = opc_t.define('OPC_HEXP2', 523) +OPC_ISAM = opc_t.define('OPC_ISAM', 640) +OPC_ISAML = opc_t.define('OPC_ISAML', 641) +OPC_ISAMM = opc_t.define('OPC_ISAMM', 642) +OPC_SAM = opc_t.define('OPC_SAM', 643) +OPC_SAMB = opc_t.define('OPC_SAMB', 644) +OPC_SAML = opc_t.define('OPC_SAML', 645) +OPC_SAMGQ = opc_t.define('OPC_SAMGQ', 646) +OPC_GETLOD = opc_t.define('OPC_GETLOD', 647) +OPC_CONV = opc_t.define('OPC_CONV', 648) +OPC_CONVM = opc_t.define('OPC_CONVM', 649) +OPC_GETSIZE = opc_t.define('OPC_GETSIZE', 650) +OPC_GETBUF = opc_t.define('OPC_GETBUF', 651) +OPC_GETPOS = opc_t.define('OPC_GETPOS', 652) +OPC_GETINFO = opc_t.define('OPC_GETINFO', 653) +OPC_DSX = opc_t.define('OPC_DSX', 654) +OPC_DSY = opc_t.define('OPC_DSY', 655) +OPC_GATHER4R = opc_t.define('OPC_GATHER4R', 656) +OPC_GATHER4G = opc_t.define('OPC_GATHER4G', 657) +OPC_GATHER4B = opc_t.define('OPC_GATHER4B', 658) +OPC_GATHER4A = opc_t.define('OPC_GATHER4A', 659) +OPC_SAMGP0 = opc_t.define('OPC_SAMGP0', 660) +OPC_SAMGP1 = opc_t.define('OPC_SAMGP1', 661) +OPC_SAMGP2 = opc_t.define('OPC_SAMGP2', 662) +OPC_SAMGP3 = opc_t.define('OPC_SAMGP3', 663) +OPC_DSXPP_1 = opc_t.define('OPC_DSXPP_1', 664) +OPC_DSYPP_1 = opc_t.define('OPC_DSYPP_1', 665) +OPC_RGETPOS = opc_t.define('OPC_RGETPOS', 666) +OPC_RGETINFO = opc_t.define('OPC_RGETINFO', 667) +OPC_BRCST_ACTIVE = opc_t.define('OPC_BRCST_ACTIVE', 668) +OPC_QUAD_SHUFFLE_BRCST = opc_t.define('OPC_QUAD_SHUFFLE_BRCST', 669) +OPC_QUAD_SHUFFLE_HORIZ = opc_t.define('OPC_QUAD_SHUFFLE_HORIZ', 670) +OPC_QUAD_SHUFFLE_VERT = opc_t.define('OPC_QUAD_SHUFFLE_VERT', 671) +OPC_QUAD_SHUFFLE_DIAG = opc_t.define('OPC_QUAD_SHUFFLE_DIAG', 672) +OPC_TCINV = opc_t.define('OPC_TCINV', 673) +OPC_DSXPP_MACRO = opc_t.define('OPC_DSXPP_MACRO', 675) +OPC_DSYPP_MACRO = opc_t.define('OPC_DSYPP_MACRO', 676) +OPC_LDG = opc_t.define('OPC_LDG', 768) +OPC_LDL = opc_t.define('OPC_LDL', 769) +OPC_LDP = opc_t.define('OPC_LDP', 770) +OPC_STG = opc_t.define('OPC_STG', 771) +OPC_STL = opc_t.define('OPC_STL', 772) +OPC_STP = opc_t.define('OPC_STP', 773) +OPC_LDIB = opc_t.define('OPC_LDIB', 774) +OPC_G2L = opc_t.define('OPC_G2L', 775) +OPC_L2G = opc_t.define('OPC_L2G', 776) +OPC_PREFETCH = opc_t.define('OPC_PREFETCH', 777) +OPC_LDLW = opc_t.define('OPC_LDLW', 778) +OPC_STLW = opc_t.define('OPC_STLW', 779) +OPC_RESFMT = opc_t.define('OPC_RESFMT', 782) +OPC_RESINFO = opc_t.define('OPC_RESINFO', 783) +OPC_ATOMIC_ADD = opc_t.define('OPC_ATOMIC_ADD', 784) +OPC_ATOMIC_SUB = opc_t.define('OPC_ATOMIC_SUB', 785) +OPC_ATOMIC_XCHG = opc_t.define('OPC_ATOMIC_XCHG', 786) +OPC_ATOMIC_INC = opc_t.define('OPC_ATOMIC_INC', 787) +OPC_ATOMIC_DEC = opc_t.define('OPC_ATOMIC_DEC', 788) +OPC_ATOMIC_CMPXCHG = opc_t.define('OPC_ATOMIC_CMPXCHG', 789) +OPC_ATOMIC_MIN = opc_t.define('OPC_ATOMIC_MIN', 790) +OPC_ATOMIC_MAX = opc_t.define('OPC_ATOMIC_MAX', 791) +OPC_ATOMIC_AND = opc_t.define('OPC_ATOMIC_AND', 792) +OPC_ATOMIC_OR = opc_t.define('OPC_ATOMIC_OR', 793) +OPC_ATOMIC_XOR = opc_t.define('OPC_ATOMIC_XOR', 794) +OPC_LDGB = opc_t.define('OPC_LDGB', 795) +OPC_STGB = opc_t.define('OPC_STGB', 796) +OPC_STIB = opc_t.define('OPC_STIB', 797) +OPC_LDC = opc_t.define('OPC_LDC', 798) +OPC_LDLV = opc_t.define('OPC_LDLV', 799) +OPC_PIPR = opc_t.define('OPC_PIPR', 800) +OPC_PIPC = opc_t.define('OPC_PIPC', 801) +OPC_EMIT2 = opc_t.define('OPC_EMIT2', 802) +OPC_ENDLS = opc_t.define('OPC_ENDLS', 803) +OPC_GETSPID = opc_t.define('OPC_GETSPID', 804) +OPC_GETWID = opc_t.define('OPC_GETWID', 805) +OPC_GETFIBERID = opc_t.define('OPC_GETFIBERID', 806) +OPC_SHFL = opc_t.define('OPC_SHFL', 807) +OPC_STC = opc_t.define('OPC_STC', 808) +OPC_RESINFO_B = opc_t.define('OPC_RESINFO_B', 809) +OPC_LDIB_B = opc_t.define('OPC_LDIB_B', 810) +OPC_STIB_B = opc_t.define('OPC_STIB_B', 811) +OPC_ATOMIC_B_ADD = opc_t.define('OPC_ATOMIC_B_ADD', 812) +OPC_ATOMIC_B_SUB = opc_t.define('OPC_ATOMIC_B_SUB', 813) +OPC_ATOMIC_B_XCHG = opc_t.define('OPC_ATOMIC_B_XCHG', 814) +OPC_ATOMIC_B_INC = opc_t.define('OPC_ATOMIC_B_INC', 815) +OPC_ATOMIC_B_DEC = opc_t.define('OPC_ATOMIC_B_DEC', 816) +OPC_ATOMIC_B_CMPXCHG = opc_t.define('OPC_ATOMIC_B_CMPXCHG', 817) +OPC_ATOMIC_B_MIN = opc_t.define('OPC_ATOMIC_B_MIN', 818) +OPC_ATOMIC_B_MAX = opc_t.define('OPC_ATOMIC_B_MAX', 819) +OPC_ATOMIC_B_AND = opc_t.define('OPC_ATOMIC_B_AND', 820) +OPC_ATOMIC_B_OR = opc_t.define('OPC_ATOMIC_B_OR', 821) +OPC_ATOMIC_B_XOR = opc_t.define('OPC_ATOMIC_B_XOR', 822) +OPC_ATOMIC_S_ADD = opc_t.define('OPC_ATOMIC_S_ADD', 823) +OPC_ATOMIC_S_SUB = opc_t.define('OPC_ATOMIC_S_SUB', 824) +OPC_ATOMIC_S_XCHG = opc_t.define('OPC_ATOMIC_S_XCHG', 825) +OPC_ATOMIC_S_INC = opc_t.define('OPC_ATOMIC_S_INC', 826) +OPC_ATOMIC_S_DEC = opc_t.define('OPC_ATOMIC_S_DEC', 827) +OPC_ATOMIC_S_CMPXCHG = opc_t.define('OPC_ATOMIC_S_CMPXCHG', 828) +OPC_ATOMIC_S_MIN = opc_t.define('OPC_ATOMIC_S_MIN', 829) +OPC_ATOMIC_S_MAX = opc_t.define('OPC_ATOMIC_S_MAX', 830) +OPC_ATOMIC_S_AND = opc_t.define('OPC_ATOMIC_S_AND', 831) +OPC_ATOMIC_S_OR = opc_t.define('OPC_ATOMIC_S_OR', 832) +OPC_ATOMIC_S_XOR = opc_t.define('OPC_ATOMIC_S_XOR', 833) +OPC_ATOMIC_G_ADD = opc_t.define('OPC_ATOMIC_G_ADD', 834) +OPC_ATOMIC_G_SUB = opc_t.define('OPC_ATOMIC_G_SUB', 835) +OPC_ATOMIC_G_XCHG = opc_t.define('OPC_ATOMIC_G_XCHG', 836) +OPC_ATOMIC_G_INC = opc_t.define('OPC_ATOMIC_G_INC', 837) +OPC_ATOMIC_G_DEC = opc_t.define('OPC_ATOMIC_G_DEC', 838) +OPC_ATOMIC_G_CMPXCHG = opc_t.define('OPC_ATOMIC_G_CMPXCHG', 839) +OPC_ATOMIC_G_MIN = opc_t.define('OPC_ATOMIC_G_MIN', 840) +OPC_ATOMIC_G_MAX = opc_t.define('OPC_ATOMIC_G_MAX', 841) +OPC_ATOMIC_G_AND = opc_t.define('OPC_ATOMIC_G_AND', 842) +OPC_ATOMIC_G_OR = opc_t.define('OPC_ATOMIC_G_OR', 843) +OPC_ATOMIC_G_XOR = opc_t.define('OPC_ATOMIC_G_XOR', 844) +OPC_LDG_A = opc_t.define('OPC_LDG_A', 845) +OPC_STG_A = opc_t.define('OPC_STG_A', 846) +OPC_SPILL_MACRO = opc_t.define('OPC_SPILL_MACRO', 847) +OPC_RELOAD_MACRO = opc_t.define('OPC_RELOAD_MACRO', 848) +OPC_LDC_K = opc_t.define('OPC_LDC_K', 849) +OPC_STSC = opc_t.define('OPC_STSC', 850) +OPC_LDG_K = opc_t.define('OPC_LDG_K', 851) +OPC_PUSH_CONSTS_LOAD_MACRO = opc_t.define('OPC_PUSH_CONSTS_LOAD_MACRO', 852) +OPC_RAY_INTERSECTION = opc_t.define('OPC_RAY_INTERSECTION', 858) +OPC_RESBASE = opc_t.define('OPC_RESBASE', 859) +OPC_BAR = opc_t.define('OPC_BAR', 896) +OPC_FENCE = opc_t.define('OPC_FENCE', 897) +OPC_SLEEP = opc_t.define('OPC_SLEEP', 898) +OPC_ICINV = opc_t.define('OPC_ICINV', 899) +OPC_DCCLN = opc_t.define('OPC_DCCLN', 900) +OPC_DCINV = opc_t.define('OPC_DCINV', 901) +OPC_DCFLU = opc_t.define('OPC_DCFLU', 902) +OPC_LOCK = opc_t.define('OPC_LOCK', 903) +OPC_UNLOCK = opc_t.define('OPC_UNLOCK', 904) +OPC_ALIAS = opc_t.define('OPC_ALIAS', 905) +OPC_CCINV = opc_t.define('OPC_CCINV', 906) +OPC_META_INPUT = opc_t.define('OPC_META_INPUT', 1024) +OPC_META_SPLIT = opc_t.define('OPC_META_SPLIT', 1026) +OPC_META_COLLECT = opc_t.define('OPC_META_COLLECT', 1027) +OPC_META_TEX_PREFETCH = opc_t.define('OPC_META_TEX_PREFETCH', 1028) +OPC_META_PARALLEL_COPY = opc_t.define('OPC_META_PARALLEL_COPY', 1029) +OPC_META_PHI = opc_t.define('OPC_META_PHI', 1030) +OPC_META_RAW = opc_t.define('OPC_META_RAW', 1031) + +struct_ir3_sampler_prefetch._fields_ = [ + ('src', uint8_t), + ('bindless', ctypes.c_bool), + ('samp_id', uint8_t), + ('tex_id', uint8_t), + ('samp_bindless_id', uint16_t), + ('tex_bindless_id', uint16_t), + ('dst', uint8_t), + ('wrmask', uint8_t), + ('half_precision', uint8_t), + ('tex_opc', opc_t), +] +class struct_ir3_shader_key(Struct): pass +class struct_ir3_shader_key_0(ctypes.Union): pass +class struct_ir3_shader_key_0_0(Struct): pass +struct_ir3_shader_key_0_0._fields_ = [ + ('ucp_enables', ctypes.c_uint32,8), + ('has_per_samp', ctypes.c_uint32,1), + ('sample_shading', ctypes.c_uint32,1), + ('msaa', ctypes.c_uint32,1), + ('rasterflat', ctypes.c_uint32,1), + ('tessellation', ctypes.c_uint32,2), + ('has_gs', ctypes.c_uint32,1), + ('tcs_store_primid', ctypes.c_uint32,1), + ('safe_constlen', ctypes.c_uint32,1), + ('force_dual_color_blend', ctypes.c_uint32,1), +] +struct_ir3_shader_key_0._anonymous_ = ['_0'] +struct_ir3_shader_key_0._fields_ = [ + ('_0', struct_ir3_shader_key_0_0), + ('global', uint32_t), +] +struct_ir3_shader_key._anonymous_ = ['_0'] +struct_ir3_shader_key._fields_ = [ + ('_0', struct_ir3_shader_key_0), + ('vsamples', uint32_t), + ('fsamples', uint32_t), + ('vastc_srgb', uint16_t), + ('fastc_srgb', uint16_t), + ('vsampler_swizzles', (uint16_t * 16)), + ('fsampler_swizzles', (uint16_t * 16)), +] +class struct_ir3_ibo_mapping(Struct): pass +struct_ir3_ibo_mapping._fields_ = [ + ('ssbo_to_tex', (uint8_t * 32)), + ('image_to_tex', (uint8_t * 32)), + ('tex_to_image', (uint8_t * 32)), + ('num_tex', uint8_t), + ('tex_base', uint8_t), +] +class struct_ir3_disasm_info(Struct): pass +struct_ir3_disasm_info._fields_ = [ + ('write_disasm', ctypes.c_bool), + ('nir', ctypes.POINTER(ctypes.c_char)), + ('disasm', ctypes.POINTER(ctypes.c_char)), +] +class struct_ir3_shader_nir_options(Struct): pass +struct_ir3_shader_nir_options._fields_ = [ + ('robust_modes', nir_variable_mode), +] +class struct_ir3_shader_options(Struct): pass +struct_ir3_shader_options._fields_ = [ + ('api_wavesize', enum_ir3_wavesize_option), + ('real_wavesize', enum_ir3_wavesize_option), + ('push_consts_type', enum_ir3_push_consts_type), + ('push_consts_base', uint32_t), + ('push_consts_dwords', uint32_t), + ('const_allocs', struct_ir3_const_allocations), + ('nir_options', struct_ir3_shader_nir_options), + ('fragdata_dynamic_remap', ctypes.c_bool), +] +class struct_ir3_shader_output(Struct): pass +struct_ir3_shader_output._fields_ = [ + ('slot', uint8_t), + ('regid', uint8_t), + ('view', uint8_t), + ('aliased_components', uint8_t,4), + ('half', ctypes.c_bool,1), +] +class struct_fd_bo(Struct): pass +class struct_ir3(Struct): pass +class struct_ir3_instruction(Struct): pass +class struct_ir3_block(Struct): pass +struct_ir3_block._fields_ = [ + ('node', struct_list_head), + ('shader', ctypes.POINTER(struct_ir3)), + ('nblock', ctypes.POINTER(struct_nir_block)), + ('instr_list', struct_list_head), + ('successors', (ctypes.POINTER(struct_ir3_block) * 2)), + ('divergent_condition', ctypes.c_bool), + ('predecessors_count', ctypes.c_uint32), + ('predecessors_sz', ctypes.c_uint32), + ('predecessors', ctypes.POINTER(ctypes.POINTER(struct_ir3_block))), + ('physical_predecessors_count', ctypes.c_uint32), + ('physical_predecessors_sz', ctypes.c_uint32), + ('physical_predecessors', ctypes.POINTER(ctypes.POINTER(struct_ir3_block))), + ('physical_successors_count', ctypes.c_uint32), + ('physical_successors_sz', ctypes.c_uint32), + ('physical_successors', ctypes.POINTER(ctypes.POINTER(struct_ir3_block))), + ('start_ip', uint16_t), + ('end_ip', uint16_t), + ('reconvergence_point', ctypes.c_bool), + ('in_early_preamble', ctypes.c_bool), + ('keeps_count', ctypes.c_uint32), + ('keeps_sz', ctypes.c_uint32), + ('keeps', ctypes.POINTER(ctypes.POINTER(struct_ir3_instruction))), + ('data', ctypes.c_void_p), + ('index', uint32_t), + ('imm_dom', ctypes.POINTER(struct_ir3_block)), + ('dom_children_count', ctypes.c_uint32), + ('dom_children_sz', ctypes.c_uint32), + ('dom_children', ctypes.POINTER(ctypes.POINTER(struct_ir3_block))), + ('dom_pre_index', uint32_t), + ('dom_post_index', uint32_t), + ('loop_depth', uint32_t), +] +enum_ir3_instruction_flags = CEnum(ctypes.c_uint32) +IR3_INSTR_SY = enum_ir3_instruction_flags.define('IR3_INSTR_SY', 1) +IR3_INSTR_SS = enum_ir3_instruction_flags.define('IR3_INSTR_SS', 2) +IR3_INSTR_JP = enum_ir3_instruction_flags.define('IR3_INSTR_JP', 4) +IR3_INSTR_EQ = enum_ir3_instruction_flags.define('IR3_INSTR_EQ', 8) +IR3_INSTR_UL = enum_ir3_instruction_flags.define('IR3_INSTR_UL', 16) +IR3_INSTR_3D = enum_ir3_instruction_flags.define('IR3_INSTR_3D', 32) +IR3_INSTR_A = enum_ir3_instruction_flags.define('IR3_INSTR_A', 64) +IR3_INSTR_O = enum_ir3_instruction_flags.define('IR3_INSTR_O', 128) +IR3_INSTR_P = enum_ir3_instruction_flags.define('IR3_INSTR_P', 256) +IR3_INSTR_S = enum_ir3_instruction_flags.define('IR3_INSTR_S', 512) +IR3_INSTR_S2EN = enum_ir3_instruction_flags.define('IR3_INSTR_S2EN', 1024) +IR3_INSTR_SAT = enum_ir3_instruction_flags.define('IR3_INSTR_SAT', 2048) +IR3_INSTR_B = enum_ir3_instruction_flags.define('IR3_INSTR_B', 4096) +IR3_INSTR_NONUNIF = enum_ir3_instruction_flags.define('IR3_INSTR_NONUNIF', 8192) +IR3_INSTR_A1EN = enum_ir3_instruction_flags.define('IR3_INSTR_A1EN', 16384) +IR3_INSTR_U = enum_ir3_instruction_flags.define('IR3_INSTR_U', 32768) +IR3_INSTR_MARK = enum_ir3_instruction_flags.define('IR3_INSTR_MARK', 65536) +IR3_INSTR_SHARED_SPILL = enum_ir3_instruction_flags.define('IR3_INSTR_SHARED_SPILL', 65536) +IR3_INSTR_UNUSED = enum_ir3_instruction_flags.define('IR3_INSTR_UNUSED', 131072) +IR3_INSTR_NEEDS_HELPERS = enum_ir3_instruction_flags.define('IR3_INSTR_NEEDS_HELPERS', 262144) +IR3_INSTR_V = enum_ir3_instruction_flags.define('IR3_INSTR_V', 524288) +IR3_INSTR_INV_1D = enum_ir3_instruction_flags.define('IR3_INSTR_INV_1D', 1048576) +IR3_INSTR_IMM_OFFSET = enum_ir3_instruction_flags.define('IR3_INSTR_IMM_OFFSET', 2097152) + +class struct_ir3_register(Struct): pass +enum_ir3_register_flags = CEnum(ctypes.c_uint32) +IR3_REG_CONST = enum_ir3_register_flags.define('IR3_REG_CONST', 1) +IR3_REG_IMMED = enum_ir3_register_flags.define('IR3_REG_IMMED', 2) +IR3_REG_HALF = enum_ir3_register_flags.define('IR3_REG_HALF', 4) +IR3_REG_SHARED = enum_ir3_register_flags.define('IR3_REG_SHARED', 8) +IR3_REG_RELATIV = enum_ir3_register_flags.define('IR3_REG_RELATIV', 16) +IR3_REG_R = enum_ir3_register_flags.define('IR3_REG_R', 32) +IR3_REG_FNEG = enum_ir3_register_flags.define('IR3_REG_FNEG', 64) +IR3_REG_FABS = enum_ir3_register_flags.define('IR3_REG_FABS', 128) +IR3_REG_SNEG = enum_ir3_register_flags.define('IR3_REG_SNEG', 256) +IR3_REG_SABS = enum_ir3_register_flags.define('IR3_REG_SABS', 512) +IR3_REG_BNOT = enum_ir3_register_flags.define('IR3_REG_BNOT', 1024) +IR3_REG_EI = enum_ir3_register_flags.define('IR3_REG_EI', 2048) +IR3_REG_SSA = enum_ir3_register_flags.define('IR3_REG_SSA', 4096) +IR3_REG_ARRAY = enum_ir3_register_flags.define('IR3_REG_ARRAY', 8192) +IR3_REG_KILL = enum_ir3_register_flags.define('IR3_REG_KILL', 16384) +IR3_REG_FIRST_KILL = enum_ir3_register_flags.define('IR3_REG_FIRST_KILL', 32768) +IR3_REG_UNUSED = enum_ir3_register_flags.define('IR3_REG_UNUSED', 65536) +IR3_REG_EARLY_CLOBBER = enum_ir3_register_flags.define('IR3_REG_EARLY_CLOBBER', 131072) +IR3_REG_LAST_USE = enum_ir3_register_flags.define('IR3_REG_LAST_USE', 262144) +IR3_REG_PREDICATE = enum_ir3_register_flags.define('IR3_REG_PREDICATE', 524288) +IR3_REG_RT = enum_ir3_register_flags.define('IR3_REG_RT', 1048576) +IR3_REG_ALIAS = enum_ir3_register_flags.define('IR3_REG_ALIAS', 2097152) +IR3_REG_FIRST_ALIAS = enum_ir3_register_flags.define('IR3_REG_FIRST_ALIAS', 4194304) + +class struct_ir3_register_0(ctypes.Union): pass +class struct_ir3_register_0_array(Struct): pass +struct_ir3_register_0_array._fields_ = [ + ('id', uint16_t), + ('offset', int16_t), + ('base', uint16_t), +] +struct_ir3_register_0._fields_ = [ + ('iim_val', int32_t), + ('uim_val', uint32_t), + ('fim_val', ctypes.c_float), + ('array', struct_ir3_register_0_array), +] +class struct_ir3_merge_set(Struct): pass +struct_ir3_merge_set._fields_ = [ + ('preferred_reg', uint16_t), + ('size', uint16_t), + ('alignment', uint16_t), + ('interval_start', ctypes.c_uint32), + ('spill_slot', ctypes.c_uint32), + ('regs_count', ctypes.c_uint32), + ('regs', ctypes.POINTER(ctypes.POINTER(struct_ir3_register))), +] +struct_ir3_register._anonymous_ = ['_0'] +struct_ir3_register._fields_ = [ + ('flags', enum_ir3_register_flags), + ('name', ctypes.c_uint32), + ('wrmask', ctypes.c_uint32,16), + ('size', ctypes.c_uint32,16), + ('num', uint16_t), + ('_0', struct_ir3_register_0), + ('instr', ctypes.POINTER(struct_ir3_instruction)), + ('def', ctypes.POINTER(struct_ir3_register)), + ('tied', ctypes.POINTER(struct_ir3_register)), + ('spill_slot', ctypes.c_uint32), + ('next_use', ctypes.c_uint32), + ('merge_set_offset', ctypes.c_uint32), + ('merge_set', ctypes.POINTER(struct_ir3_merge_set)), + ('interval_start', ctypes.c_uint32), + ('interval_end', ctypes.c_uint32), +] +class struct_ir3_instruction_0(ctypes.Union): pass +class struct_ir3_instruction_0_cat0(Struct): pass +struct_ir3_instruction_0_cat0._fields_ = [ + ('inv1', ctypes.c_char), + ('inv2', ctypes.c_char), + ('immed', ctypes.c_int32), + ('target', ctypes.POINTER(struct_ir3_block)), + ('target_label', ctypes.POINTER(ctypes.c_char)), + ('idx', ctypes.c_uint32), +] +class struct_ir3_instruction_0_cat1(Struct): pass +round_t = CEnum(ctypes.c_uint32) +ROUND_ZERO = round_t.define('ROUND_ZERO', 0) +ROUND_EVEN = round_t.define('ROUND_EVEN', 1) +ROUND_POS_INF = round_t.define('ROUND_POS_INF', 2) +ROUND_NEG_INF = round_t.define('ROUND_NEG_INF', 3) + +reduce_op_t = CEnum(ctypes.c_uint32) +REDUCE_OP_ADD_U = reduce_op_t.define('REDUCE_OP_ADD_U', 0) +REDUCE_OP_ADD_F = reduce_op_t.define('REDUCE_OP_ADD_F', 1) +REDUCE_OP_MUL_U = reduce_op_t.define('REDUCE_OP_MUL_U', 2) +REDUCE_OP_MUL_F = reduce_op_t.define('REDUCE_OP_MUL_F', 3) +REDUCE_OP_MIN_U = reduce_op_t.define('REDUCE_OP_MIN_U', 4) +REDUCE_OP_MIN_S = reduce_op_t.define('REDUCE_OP_MIN_S', 5) +REDUCE_OP_MIN_F = reduce_op_t.define('REDUCE_OP_MIN_F', 6) +REDUCE_OP_MAX_U = reduce_op_t.define('REDUCE_OP_MAX_U', 7) +REDUCE_OP_MAX_S = reduce_op_t.define('REDUCE_OP_MAX_S', 8) +REDUCE_OP_MAX_F = reduce_op_t.define('REDUCE_OP_MAX_F', 9) +REDUCE_OP_AND_B = reduce_op_t.define('REDUCE_OP_AND_B', 10) +REDUCE_OP_OR_B = reduce_op_t.define('REDUCE_OP_OR_B', 11) +REDUCE_OP_XOR_B = reduce_op_t.define('REDUCE_OP_XOR_B', 12) + +struct_ir3_instruction_0_cat1._fields_ = [ + ('src_type', type_t), + ('dst_type', type_t), + ('round', round_t), + ('reduce_op', reduce_op_t), +] +class struct_ir3_instruction_0_cat2(Struct): pass +struct_ir3_instruction_0_cat2_condition = CEnum(ctypes.c_uint32) +IR3_COND_LT = struct_ir3_instruction_0_cat2_condition.define('IR3_COND_LT', 0) +IR3_COND_LE = struct_ir3_instruction_0_cat2_condition.define('IR3_COND_LE', 1) +IR3_COND_GT = struct_ir3_instruction_0_cat2_condition.define('IR3_COND_GT', 2) +IR3_COND_GE = struct_ir3_instruction_0_cat2_condition.define('IR3_COND_GE', 3) +IR3_COND_EQ = struct_ir3_instruction_0_cat2_condition.define('IR3_COND_EQ', 4) +IR3_COND_NE = struct_ir3_instruction_0_cat2_condition.define('IR3_COND_NE', 5) + +struct_ir3_instruction_0_cat2._fields_ = [ + ('condition', struct_ir3_instruction_0_cat2_condition), +] +class struct_ir3_instruction_0_cat3(Struct): pass +struct_ir3_instruction_0_cat3_signedness = CEnum(ctypes.c_uint32) +IR3_SRC_UNSIGNED = struct_ir3_instruction_0_cat3_signedness.define('IR3_SRC_UNSIGNED', 0) +IR3_SRC_MIXED = struct_ir3_instruction_0_cat3_signedness.define('IR3_SRC_MIXED', 1) + +struct_ir3_instruction_0_cat3_packed = CEnum(ctypes.c_uint32) +IR3_SRC_PACKED_LOW = struct_ir3_instruction_0_cat3_packed.define('IR3_SRC_PACKED_LOW', 0) +IR3_SRC_PACKED_HIGH = struct_ir3_instruction_0_cat3_packed.define('IR3_SRC_PACKED_HIGH', 1) + +struct_ir3_instruction_0_cat3._fields_ = [ + ('signedness', struct_ir3_instruction_0_cat3_signedness), + ('packed', struct_ir3_instruction_0_cat3_packed), + ('swapped', ctypes.c_bool), +] +class struct_ir3_instruction_0_cat5(Struct): pass +struct_ir3_instruction_0_cat5._fields_ = [ + ('samp', ctypes.c_uint32), + ('tex', ctypes.c_uint32), + ('tex_base', ctypes.c_uint32,3), + ('cluster_size', ctypes.c_uint32,4), + ('type', type_t), +] +class struct_ir3_instruction_0_cat6(Struct): pass +ir3_shfl_mode = CEnum(ctypes.c_uint32) +SHFL_XOR = ir3_shfl_mode.define('SHFL_XOR', 1) +SHFL_UP = ir3_shfl_mode.define('SHFL_UP', 2) +SHFL_DOWN = ir3_shfl_mode.define('SHFL_DOWN', 3) +SHFL_RUP = ir3_shfl_mode.define('SHFL_RUP', 6) +SHFL_RDOWN = ir3_shfl_mode.define('SHFL_RDOWN', 7) + +struct_ir3_instruction_0_cat6._fields_ = [ + ('type', type_t), + ('dst_offset', ctypes.c_int32), + ('iim_val', ctypes.c_int32), + ('d', ctypes.c_uint32,3), + ('typed', ctypes.c_bool,1), + ('base', ctypes.c_uint32,3), + ('shfl_mode', ir3_shfl_mode,3), +] +class struct_ir3_instruction_0_cat7(Struct): pass +ir3_alias_scope = CEnum(ctypes.c_uint32) +ALIAS_TEX = ir3_alias_scope.define('ALIAS_TEX', 0) +ALIAS_RT = ir3_alias_scope.define('ALIAS_RT', 1) +ALIAS_MEM = ir3_alias_scope.define('ALIAS_MEM', 2) + +struct_ir3_instruction_0_cat7._fields_ = [ + ('w', ctypes.c_uint32,1), + ('r', ctypes.c_uint32,1), + ('l', ctypes.c_uint32,1), + ('g', ctypes.c_uint32,1), + ('alias_scope', ir3_alias_scope), + ('alias_table_size_minus_one', ctypes.c_uint32), + ('alias_type_float', ctypes.c_bool), +] +class struct_ir3_instruction_0_split(Struct): pass +struct_ir3_instruction_0_split._fields_ = [ + ('off', ctypes.c_int32), +] +class struct_ir3_instruction_0_end(Struct): pass +struct_ir3_instruction_0_end._fields_ = [ + ('outidxs', ctypes.POINTER(ctypes.c_uint32)), +] +class struct_ir3_instruction_0_phi(Struct): pass +struct_ir3_instruction_0_phi._fields_ = [ + ('nphi', ctypes.c_void_p), + ('comp', ctypes.c_uint32), +] +class struct_ir3_instruction_0_prefetch(Struct): pass +struct_ir3_instruction_0_prefetch._fields_ = [ + ('samp', ctypes.c_uint32), + ('tex', ctypes.c_uint32), + ('input_offset', ctypes.c_uint32), + ('samp_base', ctypes.c_uint32,3), + ('tex_base', ctypes.c_uint32,3), +] +class struct_ir3_instruction_0_input(Struct): pass +struct_ir3_instruction_0_input._fields_ = [ + ('inidx', ctypes.c_int32), + ('sysval', gl_system_value), +] +class struct_ir3_instruction_0_push_consts(Struct): pass +struct_ir3_instruction_0_push_consts._fields_ = [ + ('src_base', ctypes.c_uint32), + ('src_size', ctypes.c_uint32), + ('dst_base', ctypes.c_uint32), +] +class struct_ir3_instruction_0_raw(Struct): pass +struct_ir3_instruction_0_raw._fields_ = [ + ('value', uint64_t), +] +struct_ir3_instruction_0._fields_ = [ + ('cat0', struct_ir3_instruction_0_cat0), + ('cat1', struct_ir3_instruction_0_cat1), + ('cat2', struct_ir3_instruction_0_cat2), + ('cat3', struct_ir3_instruction_0_cat3), + ('cat5', struct_ir3_instruction_0_cat5), + ('cat6', struct_ir3_instruction_0_cat6), + ('cat7', struct_ir3_instruction_0_cat7), + ('split', struct_ir3_instruction_0_split), + ('end', struct_ir3_instruction_0_end), + ('phi', struct_ir3_instruction_0_phi), + ('prefetch', struct_ir3_instruction_0_prefetch), + ('input', struct_ir3_instruction_0_input), + ('push_consts', struct_ir3_instruction_0_push_consts), + ('raw', struct_ir3_instruction_0_raw), +] +struct_ir3_instruction_barrier_class = CEnum(ctypes.c_uint32) +IR3_BARRIER_EVERYTHING = struct_ir3_instruction_barrier_class.define('IR3_BARRIER_EVERYTHING', 1) +IR3_BARRIER_SHARED_R = struct_ir3_instruction_barrier_class.define('IR3_BARRIER_SHARED_R', 2) +IR3_BARRIER_SHARED_W = struct_ir3_instruction_barrier_class.define('IR3_BARRIER_SHARED_W', 4) +IR3_BARRIER_IMAGE_R = struct_ir3_instruction_barrier_class.define('IR3_BARRIER_IMAGE_R', 8) +IR3_BARRIER_IMAGE_W = struct_ir3_instruction_barrier_class.define('IR3_BARRIER_IMAGE_W', 16) +IR3_BARRIER_BUFFER_R = struct_ir3_instruction_barrier_class.define('IR3_BARRIER_BUFFER_R', 32) +IR3_BARRIER_BUFFER_W = struct_ir3_instruction_barrier_class.define('IR3_BARRIER_BUFFER_W', 64) +IR3_BARRIER_ARRAY_R = struct_ir3_instruction_barrier_class.define('IR3_BARRIER_ARRAY_R', 128) +IR3_BARRIER_ARRAY_W = struct_ir3_instruction_barrier_class.define('IR3_BARRIER_ARRAY_W', 256) +IR3_BARRIER_PRIVATE_R = struct_ir3_instruction_barrier_class.define('IR3_BARRIER_PRIVATE_R', 512) +IR3_BARRIER_PRIVATE_W = struct_ir3_instruction_barrier_class.define('IR3_BARRIER_PRIVATE_W', 1024) +IR3_BARRIER_CONST_W = struct_ir3_instruction_barrier_class.define('IR3_BARRIER_CONST_W', 2048) +IR3_BARRIER_ACTIVE_FIBERS_R = struct_ir3_instruction_barrier_class.define('IR3_BARRIER_ACTIVE_FIBERS_R', 4096) +IR3_BARRIER_ACTIVE_FIBERS_W = struct_ir3_instruction_barrier_class.define('IR3_BARRIER_ACTIVE_FIBERS_W', 8192) + +struct_ir3_instruction._anonymous_ = ['_0'] +struct_ir3_instruction._fields_ = [ + ('block', ctypes.POINTER(struct_ir3_block)), + ('opc', opc_t), + ('flags', enum_ir3_instruction_flags), + ('repeat', uint8_t), + ('nop', uint8_t), + ('srcs_count', ctypes.c_uint32), + ('dsts_count', ctypes.c_uint32), + ('dsts', ctypes.POINTER(ctypes.POINTER(struct_ir3_register))), + ('srcs', ctypes.POINTER(ctypes.POINTER(struct_ir3_register))), + ('_0', struct_ir3_instruction_0), + ('ip', uint32_t), + ('data', ctypes.c_void_p), + ('uses', ctypes.POINTER(struct_set)), + ('use_count', ctypes.c_int32), + ('address', ctypes.POINTER(struct_ir3_register)), + ('deps_count', ctypes.c_uint32), + ('deps_sz', ctypes.c_uint32), + ('deps', ctypes.POINTER(ctypes.POINTER(struct_ir3_instruction))), + ('barrier_class', struct_ir3_instruction_barrier_class), + ('barrier_conflict', struct_ir3_instruction_barrier_class), + ('node', struct_list_head), + ('rpt_node', struct_list_head), + ('serialno', uint32_t), + ('line', ctypes.c_int32), +] +struct_ir3._fields_ = [ + ('compiler', ctypes.POINTER(struct_ir3_compiler)), + ('type', gl_shader_stage), + ('inputs_count', ctypes.c_uint32), + ('inputs_sz', ctypes.c_uint32), + ('inputs', ctypes.POINTER(ctypes.POINTER(struct_ir3_instruction))), + ('baryfs_count', ctypes.c_uint32), + ('baryfs_sz', ctypes.c_uint32), + ('baryfs', ctypes.POINTER(ctypes.POINTER(struct_ir3_instruction))), + ('a0_users_count', ctypes.c_uint32), + ('a0_users_sz', ctypes.c_uint32), + ('a0_users', ctypes.POINTER(ctypes.POINTER(struct_ir3_instruction))), + ('a1_users_count', ctypes.c_uint32), + ('a1_users_sz', ctypes.c_uint32), + ('a1_users', ctypes.POINTER(ctypes.POINTER(struct_ir3_instruction))), + ('astc_srgb_count', ctypes.c_uint32), + ('astc_srgb_sz', ctypes.c_uint32), + ('astc_srgb', ctypes.POINTER(ctypes.POINTER(struct_ir3_instruction))), + ('tg4_count', ctypes.c_uint32), + ('tg4_sz', ctypes.c_uint32), + ('tg4', ctypes.POINTER(ctypes.POINTER(struct_ir3_instruction))), + ('block_list', struct_list_head), + ('array_list', struct_list_head), + ('instr_count', ctypes.c_uint32), +] +class struct_ir3_info(Struct): pass +struct_ir3_info._fields_ = [ + ('size', uint32_t), + ('constant_data_offset', uint32_t), + ('sizedwords', uint16_t), + ('instrs_count', uint16_t), + ('preamble_instrs_count', uint16_t), + ('nops_count', uint16_t), + ('mov_count', uint16_t), + ('cov_count', uint16_t), + ('stp_count', uint16_t), + ('ldp_count', uint16_t), + ('max_reg', int8_t), + ('max_half_reg', int8_t), + ('max_const', int16_t), + ('max_waves', int8_t), + ('subgroup_size', uint8_t), + ('double_threadsize', ctypes.c_bool), + ('multi_dword_ldp_stp', ctypes.c_bool), + ('early_preamble', ctypes.c_bool), + ('uses_ray_intersection', ctypes.c_bool), + ('ss', uint16_t), + ('sy', uint16_t), + ('sstall', uint16_t), + ('systall', uint16_t), + ('last_baryf', uint16_t), + ('last_helper', uint16_t), + ('instrs_per_cat', (uint16_t * 8)), +] +class struct_ir3_shader_variant_input(Struct): pass +struct_ir3_shader_variant_input._fields_ = [ + ('slot', uint8_t), + ('regid', uint8_t), + ('compmask', uint8_t), + ('inloc', uint8_t), + ('sysval', ctypes.c_bool,1), + ('bary', ctypes.c_bool,1), + ('rasterflat', ctypes.c_bool,1), + ('half', ctypes.c_bool,1), + ('flat', ctypes.c_bool,1), +] +class struct_ir3_shader_variant_astc_srgb(Struct): pass +struct_ir3_shader_variant_astc_srgb._fields_ = [ + ('base', ctypes.c_uint32), + ('count', ctypes.c_uint32), + ('orig_idx', (ctypes.c_uint32 * 16)), +] +class struct_ir3_shader_variant_tg4(Struct): pass +struct_ir3_shader_variant_tg4._fields_ = [ + ('base', ctypes.c_uint32), + ('count', ctypes.c_uint32), + ('orig_idx', (ctypes.c_uint32 * 16)), +] +class struct_ir3_shader_variant_0(ctypes.Union): pass +class struct_ir3_shader_variant_0_tess(Struct): pass +enum_gl_tess_spacing = CEnum(ctypes.c_uint32) +TESS_SPACING_UNSPECIFIED = enum_gl_tess_spacing.define('TESS_SPACING_UNSPECIFIED', 0) +TESS_SPACING_EQUAL = enum_gl_tess_spacing.define('TESS_SPACING_EQUAL', 1) +TESS_SPACING_FRACTIONAL_ODD = enum_gl_tess_spacing.define('TESS_SPACING_FRACTIONAL_ODD', 2) +TESS_SPACING_FRACTIONAL_EVEN = enum_gl_tess_spacing.define('TESS_SPACING_FRACTIONAL_EVEN', 3) + +struct_ir3_shader_variant_0_tess._fields_ = [ + ('primitive_mode', enum_tess_primitive_mode), + ('tcs_vertices_out', uint8_t), + ('spacing', enum_gl_tess_spacing,2), + ('ccw', ctypes.c_bool,1), + ('point_mode', ctypes.c_bool,1), +] +class struct_ir3_shader_variant_0_gs(Struct): pass +struct_ir3_shader_variant_0_gs._fields_ = [ + ('output_primitive', uint16_t), + ('vertices_out', uint16_t), + ('invocations', uint8_t), + ('vertices_in', uint8_t,3), +] +class struct_ir3_shader_variant_0_fs(Struct): pass +struct_ir3_shader_variant_0_fs._fields_ = [ + ('early_fragment_tests', ctypes.c_bool,1), + ('color_is_dual_source', ctypes.c_bool,1), + ('uses_fbfetch_output', ctypes.c_bool,1), + ('fbfetch_coherent', ctypes.c_bool,1), + ('depth_layout', enum_gl_frag_depth_layout), +] +class struct_ir3_shader_variant_0_cs(Struct): pass +struct_ir3_shader_variant_0_cs._fields_ = [ + ('req_local_mem', ctypes.c_uint32), + ('force_linear_dispatch', ctypes.c_bool), + ('local_invocation_id', uint32_t), + ('work_group_id', uint32_t), +] +struct_ir3_shader_variant_0._fields_ = [ + ('tess', struct_ir3_shader_variant_0_tess), + ('gs', struct_ir3_shader_variant_0_gs), + ('fs', struct_ir3_shader_variant_0_fs), + ('cs', struct_ir3_shader_variant_0_cs), +] +struct_ir3_shader_variant._anonymous_ = ['_0'] +struct_ir3_shader_variant._fields_ = [ + ('bo', ctypes.POINTER(struct_fd_bo)), + ('id', uint32_t), + ('shader_id', uint32_t), + ('key', struct_ir3_shader_key), + ('binning_pass', ctypes.c_bool), + ('binning', ctypes.POINTER(struct_ir3_shader_variant)), + ('nonbinning', ctypes.POINTER(struct_ir3_shader_variant)), + ('ir', ctypes.POINTER(struct_ir3)), + ('next', ctypes.POINTER(struct_ir3_shader_variant)), + ('type', gl_shader_stage), + ('compiler', ctypes.POINTER(struct_ir3_compiler)), + ('name', ctypes.POINTER(ctypes.c_char)), + ('constant_data', ctypes.c_void_p), + ('disasm_info', struct_ir3_disasm_info), + ('bin', ctypes.POINTER(uint32_t)), + ('const_state', ctypes.POINTER(struct_ir3_const_state)), + ('imm_state', struct_ir3_imm_const_state), + ('info', struct_ir3_info), + ('sha1_str', (ctypes.c_char * 41)), + ('shader_options', struct_ir3_shader_options), + ('constant_data_size', uint32_t), + ('branchstack', ctypes.c_uint32), + ('loops', ctypes.c_uint32), + ('instrlen', ctypes.c_uint32), + ('constlen', ctypes.c_uint32), + ('pvtmem_size', ctypes.c_uint32), + ('pvtmem_per_wave', ctypes.c_bool), + ('multi_pos_output', ctypes.c_bool), + ('dual_src_blend', ctypes.c_bool), + ('early_preamble', ctypes.c_bool), + ('shared_size', ctypes.c_uint32), + ('frag_face', ctypes.c_bool), + ('color0_mrt', ctypes.c_bool), + ('fragcoord_compmask', uint8_t), + ('outputs_count', ctypes.c_uint32), + ('outputs', (struct_ir3_shader_output * 34)), + ('writes_pos', ctypes.c_bool), + ('writes_smask', ctypes.c_bool), + ('writes_psize', ctypes.c_bool), + ('writes_viewport', ctypes.c_bool), + ('writes_stencilref', ctypes.c_bool), + ('writes_shading_rate', ctypes.c_bool), + ('output_size', uint32_t), + ('input_size', uint32_t), + ('output_loc', (ctypes.c_uint32 * 45)), + ('inputs_count', ctypes.c_uint32), + ('inputs', (struct_ir3_shader_variant_input * 34)), + ('reads_primid', ctypes.c_bool), + ('reads_shading_rate', ctypes.c_bool), + ('reads_smask', ctypes.c_bool), + ('total_in', ctypes.c_uint32), + ('sysval_in', ctypes.c_uint32), + ('varying_in', ctypes.c_uint32), + ('image_mapping', struct_ir3_ibo_mapping), + ('num_samp', ctypes.c_int32), + ('fb_read', ctypes.c_bool), + ('has_ssbo', ctypes.c_bool), + ('bindless_tex', ctypes.c_bool), + ('bindless_samp', ctypes.c_bool), + ('bindless_ibo', ctypes.c_bool), + ('bindless_ubo', ctypes.c_bool), + ('need_pixlod', ctypes.c_bool), + ('need_full_quad', ctypes.c_bool), + ('need_driver_params', ctypes.c_bool), + ('no_earlyz', ctypes.c_bool), + ('has_kill', ctypes.c_bool), + ('per_samp', ctypes.c_bool), + ('post_depth_coverage', ctypes.c_bool), + ('empty', ctypes.c_bool), + ('writes_only_color', ctypes.c_bool), + ('mergedregs', ctypes.c_bool), + ('clip_mask', uint8_t), + ('cull_mask', uint8_t), + ('astc_srgb', struct_ir3_shader_variant_astc_srgb), + ('tg4', struct_ir3_shader_variant_tg4), + ('num_sampler_prefetch', uint32_t), + ('sampler_prefetch', (struct_ir3_sampler_prefetch * 4)), + ('prefetch_bary_type', enum_ir3_bary), + ('prefetch_end_of_quad', ctypes.c_bool), + ('local_size', (uint16_t * 3)), + ('local_size_variable', ctypes.c_bool), + ('has_barrier', ctypes.c_bool), + ('num_ssbos', ctypes.c_uint32), + ('num_uavs', ctypes.c_uint32), + ('_0', struct_ir3_shader_variant_0), + ('vtxid_base', uint32_t), + ('stream_output', struct_ir3_stream_output_info), +] +class struct_ir3_shader_0(ctypes.Union): pass +class struct_ir3_shader_0_cs(Struct): pass +struct_ir3_shader_0_cs._fields_ = [ + ('req_local_mem', ctypes.c_uint32), + ('force_linear_dispatch', ctypes.c_bool), +] +class struct_ir3_shader_0_vs(Struct): pass +struct_ir3_shader_0_vs._fields_ = [ + ('passthrough_tcs_compiled', ctypes.c_uint32), + ('passthrough_tcs', (ctypes.POINTER(struct_ir3_shader) * 32)), +] +struct_ir3_shader_0._fields_ = [ + ('cs', struct_ir3_shader_0_cs), + ('vs', struct_ir3_shader_0_vs), +] +class pthread_mutex_t(ctypes.Union): pass +mtx_t = pthread_mutex_t +class struct___pthread_mutex_s(Struct): pass +class struct___pthread_internal_list(Struct): pass +__pthread_list_t = struct___pthread_internal_list +struct___pthread_internal_list._fields_ = [ + ('__prev', ctypes.POINTER(struct___pthread_internal_list)), + ('__next', ctypes.POINTER(struct___pthread_internal_list)), +] +struct___pthread_mutex_s._fields_ = [ + ('__lock', ctypes.c_int32), + ('__count', ctypes.c_uint32), + ('__owner', ctypes.c_int32), + ('__nusers', ctypes.c_uint32), + ('__kind', ctypes.c_int32), + ('__spins', ctypes.c_int16), + ('__elision', ctypes.c_int16), + ('__list', struct___pthread_internal_list), +] +pthread_mutex_t._fields_ = [ + ('__data', struct___pthread_mutex_s), + ('__size', (ctypes.c_char * 40)), + ('__align', ctypes.c_int64), +] +cache_key = (ctypes.c_ubyte * 20) +struct_ir3_shader._anonymous_ = ['_0'] +struct_ir3_shader._fields_ = [ + ('type', gl_shader_stage), + ('id', uint32_t), + ('variant_count', uint32_t), + ('initial_variants_done', ctypes.c_bool), + ('compiler', ctypes.POINTER(struct_ir3_compiler)), + ('options', struct_ir3_shader_options), + ('nir_finalized', ctypes.c_bool), + ('nir', ctypes.POINTER(struct_nir_shader)), + ('stream_output', struct_ir3_stream_output_info), + ('_0', struct_ir3_shader_0), + ('variants', ctypes.POINTER(struct_ir3_shader_variant)), + ('variants_lock', mtx_t), + ('cache_key', cache_key), + ('key_mask', struct_ir3_shader_key), +] +try: (ir3_const_ensure_imm_size:=dll.ir3_const_ensure_imm_size).restype, ir3_const_ensure_imm_size.argtypes = ctypes.c_bool, [ctypes.POINTER(struct_ir3_shader_variant), ctypes.c_uint32] +except AttributeError: pass + +try: (ir3_const_imm_index_to_reg:=dll.ir3_const_imm_index_to_reg).restype, ir3_const_imm_index_to_reg.argtypes = uint16_t, [ctypes.POINTER(struct_ir3_const_state), ctypes.c_uint32] +except AttributeError: pass + +try: (ir3_const_find_imm:=dll.ir3_const_find_imm).restype, ir3_const_find_imm.argtypes = uint16_t, [ctypes.POINTER(struct_ir3_shader_variant), uint32_t] +except AttributeError: pass + +try: (ir3_const_add_imm:=dll.ir3_const_add_imm).restype, ir3_const_add_imm.argtypes = uint16_t, [ctypes.POINTER(struct_ir3_shader_variant), uint32_t] +except AttributeError: pass + +try: (ir3_shader_assemble:=dll.ir3_shader_assemble).restype, ir3_shader_assemble.argtypes = ctypes.c_void_p, [ctypes.POINTER(struct_ir3_shader_variant)] +except AttributeError: pass + +try: (ir3_shader_create_variant:=dll.ir3_shader_create_variant).restype, ir3_shader_create_variant.argtypes = ctypes.POINTER(struct_ir3_shader_variant), [ctypes.POINTER(struct_ir3_shader), ctypes.POINTER(struct_ir3_shader_key), ctypes.c_bool] +except AttributeError: pass + +try: (ir3_shader_get_variant:=dll.ir3_shader_get_variant).restype, ir3_shader_get_variant.argtypes = ctypes.POINTER(struct_ir3_shader_variant), [ctypes.POINTER(struct_ir3_shader), ctypes.POINTER(struct_ir3_shader_key), ctypes.c_bool, ctypes.c_bool, ctypes.POINTER(ctypes.c_bool)] +except AttributeError: pass + +try: (ir3_shader_from_nir:=dll.ir3_shader_from_nir).restype, ir3_shader_from_nir.argtypes = ctypes.POINTER(struct_ir3_shader), [ctypes.POINTER(struct_ir3_compiler), ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_shader_options), ctypes.POINTER(struct_ir3_stream_output_info)] +except AttributeError: pass + +try: (ir3_trim_constlen:=dll.ir3_trim_constlen).restype, ir3_trim_constlen.argtypes = uint32_t, [ctypes.POINTER(ctypes.POINTER(struct_ir3_shader_variant)), ctypes.POINTER(struct_ir3_compiler)] +except AttributeError: pass + +try: (ir3_shader_passthrough_tcs:=dll.ir3_shader_passthrough_tcs).restype, ir3_shader_passthrough_tcs.argtypes = ctypes.POINTER(struct_ir3_shader), [ctypes.POINTER(struct_ir3_shader), ctypes.c_uint32] +except AttributeError: pass + +try: (ir3_shader_destroy:=dll.ir3_shader_destroy).restype, ir3_shader_destroy.argtypes = None, [ctypes.POINTER(struct_ir3_shader)] +except AttributeError: pass + +try: (ir3_shader_disasm:=dll.ir3_shader_disasm).restype, ir3_shader_disasm.argtypes = None, [ctypes.POINTER(struct_ir3_shader_variant), ctypes.POINTER(uint32_t), ctypes.POINTER(FILE)] +except AttributeError: pass + +try: (ir3_shader_outputs:=dll.ir3_shader_outputs).restype, ir3_shader_outputs.argtypes = uint64_t, [ctypes.POINTER(struct_ir3_shader)] +except AttributeError: pass + +try: (ir3_glsl_type_size:=dll.ir3_glsl_type_size).restype, ir3_glsl_type_size.argtypes = ctypes.c_int32, [ctypes.POINTER(struct_glsl_type), ctypes.c_bool] +except AttributeError: pass + +try: (ir3_shader_get_subgroup_size:=dll.ir3_shader_get_subgroup_size).restype, ir3_shader_get_subgroup_size.argtypes = None, [ctypes.POINTER(struct_ir3_compiler), ctypes.POINTER(struct_ir3_shader_options), gl_shader_stage, ctypes.POINTER(ctypes.c_uint32), ctypes.POINTER(ctypes.c_uint32)] +except AttributeError: pass + +class struct_ir3_shader_linkage(Struct): pass +class struct_ir3_shader_linkage_var(Struct): pass +struct_ir3_shader_linkage_var._fields_ = [ + ('slot', uint8_t), + ('regid', uint8_t), + ('compmask', uint8_t), + ('loc', uint8_t), +] +struct_ir3_shader_linkage._fields_ = [ + ('max_loc', uint8_t), + ('cnt', uint8_t), + ('varmask', (uint32_t * 4)), + ('var', (struct_ir3_shader_linkage_var * 32)), + ('primid_loc', uint8_t), + ('viewid_loc', uint8_t), + ('clip0_loc', uint8_t), + ('clip1_loc', uint8_t), +] +try: (print_raw:=dll.print_raw).restype, print_raw.argtypes = None, [ctypes.POINTER(FILE), ctypes.POINTER(ctypes.c_uint32), size_t] +except AttributeError: pass + +try: (ir3_link_stream_out:=dll.ir3_link_stream_out).restype, ir3_link_stream_out.argtypes = None, [ctypes.POINTER(struct_ir3_shader_linkage), ctypes.POINTER(struct_ir3_shader_variant)] +except AttributeError: pass + +try: (ir3_nir_apply_trig_workarounds:=dll.ir3_nir_apply_trig_workarounds).restype, ir3_nir_apply_trig_workarounds.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_nir_lower_imul:=dll.ir3_nir_lower_imul).restype, ir3_nir_lower_imul.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_nir_lower_io_offsets:=dll.ir3_nir_lower_io_offsets).restype, ir3_nir_lower_io_offsets.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_nir_lower_load_barycentric_at_sample:=dll.ir3_nir_lower_load_barycentric_at_sample).restype, ir3_nir_lower_load_barycentric_at_sample.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_nir_lower_load_barycentric_at_offset:=dll.ir3_nir_lower_load_barycentric_at_offset).restype, ir3_nir_lower_load_barycentric_at_offset.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_nir_lower_push_consts_to_preamble:=dll.ir3_nir_lower_push_consts_to_preamble).restype, ir3_nir_lower_push_consts_to_preamble.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_shader_variant)] +except AttributeError: pass + +try: (ir3_nir_lower_driver_params_to_ubo:=dll.ir3_nir_lower_driver_params_to_ubo).restype, ir3_nir_lower_driver_params_to_ubo.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_shader_variant)] +except AttributeError: pass + +try: (ir3_nir_move_varying_inputs:=dll.ir3_nir_move_varying_inputs).restype, ir3_nir_move_varying_inputs.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_nir_coord_offset:=dll.ir3_nir_coord_offset).restype, ir3_nir_coord_offset.argtypes = ctypes.c_int32, [ctypes.POINTER(nir_def), ctypes.POINTER(gl_system_value)] +except AttributeError: pass + +try: (ir3_nir_lower_tex_prefetch:=dll.ir3_nir_lower_tex_prefetch).restype, ir3_nir_lower_tex_prefetch.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader), ctypes.POINTER(enum_ir3_bary)] +except AttributeError: pass + +try: (ir3_nir_lower_layer_id:=dll.ir3_nir_lower_layer_id).restype, ir3_nir_lower_layer_id.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_nir_lower_frag_shading_rate:=dll.ir3_nir_lower_frag_shading_rate).restype, ir3_nir_lower_frag_shading_rate.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_nir_lower_primitive_shading_rate:=dll.ir3_nir_lower_primitive_shading_rate).restype, ir3_nir_lower_primitive_shading_rate.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_nir_lower_to_explicit_output:=dll.ir3_nir_lower_to_explicit_output).restype, ir3_nir_lower_to_explicit_output.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_shader_variant), ctypes.c_uint32] +except AttributeError: pass + +try: (ir3_nir_lower_to_explicit_input:=dll.ir3_nir_lower_to_explicit_input).restype, ir3_nir_lower_to_explicit_input.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_shader_variant)] +except AttributeError: pass + +try: (ir3_nir_lower_tess_ctrl:=dll.ir3_nir_lower_tess_ctrl).restype, ir3_nir_lower_tess_ctrl.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_shader_variant), ctypes.c_uint32] +except AttributeError: pass + +try: (ir3_nir_lower_tess_eval:=dll.ir3_nir_lower_tess_eval).restype, ir3_nir_lower_tess_eval.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_shader_variant), ctypes.c_uint32] +except AttributeError: pass + +try: (ir3_nir_lower_gs:=dll.ir3_nir_lower_gs).restype, ir3_nir_lower_gs.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_supports_vectorized_nir_op:=dll.ir3_supports_vectorized_nir_op).restype, ir3_supports_vectorized_nir_op.argtypes = ctypes.c_bool, [nir_op] +except AttributeError: pass + +try: (ir3_nir_vectorize_filter:=dll.ir3_nir_vectorize_filter).restype, ir3_nir_vectorize_filter.argtypes = uint8_t, [ctypes.POINTER(nir_instr), ctypes.c_void_p] +except AttributeError: pass + +try: (ir3_nir_lower_64b_intrinsics:=dll.ir3_nir_lower_64b_intrinsics).restype, ir3_nir_lower_64b_intrinsics.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_nir_lower_64b_undef:=dll.ir3_nir_lower_64b_undef).restype, ir3_nir_lower_64b_undef.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_nir_lower_64b_global:=dll.ir3_nir_lower_64b_global).restype, ir3_nir_lower_64b_global.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_nir_lower_64b_regs:=dll.ir3_nir_lower_64b_regs).restype, ir3_nir_lower_64b_regs.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_mem_access_size_align:=dll.ir3_mem_access_size_align).restype, ir3_mem_access_size_align.argtypes = nir_mem_access_size_align, [nir_intrinsic_op, uint8_t, uint8_t, uint32_t, uint32_t, ctypes.c_bool, enum_gl_access_qualifier, ctypes.c_void_p] +except AttributeError: pass + +try: (ir3_nir_opt_branch_and_or_not:=dll.ir3_nir_opt_branch_and_or_not).restype, ir3_nir_opt_branch_and_or_not.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_nir_opt_triops_bitwise:=dll.ir3_nir_opt_triops_bitwise).restype, ir3_nir_opt_triops_bitwise.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_optimize_loop:=dll.ir3_optimize_loop).restype, ir3_optimize_loop.argtypes = ctypes.c_bool, [ctypes.POINTER(struct_ir3_compiler), ctypes.POINTER(struct_ir3_shader_nir_options), ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_nir_lower_io_vars_to_temporaries:=dll.ir3_nir_lower_io_vars_to_temporaries).restype, ir3_nir_lower_io_vars_to_temporaries.argtypes = None, [ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_finalize_nir:=dll.ir3_finalize_nir).restype, ir3_finalize_nir.argtypes = None, [ctypes.POINTER(struct_ir3_compiler), ctypes.POINTER(struct_ir3_shader_nir_options), ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_nir_post_finalize:=dll.ir3_nir_post_finalize).restype, ir3_nir_post_finalize.argtypes = None, [ctypes.POINTER(struct_ir3_shader)] +except AttributeError: pass + +try: (ir3_nir_lower_variant:=dll.ir3_nir_lower_variant).restype, ir3_nir_lower_variant.argtypes = None, [ctypes.POINTER(struct_ir3_shader_variant), ctypes.POINTER(struct_ir3_shader_nir_options), ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_setup_const_state:=dll.ir3_setup_const_state).restype, ir3_setup_const_state.argtypes = None, [ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_shader_variant), ctypes.POINTER(struct_ir3_const_state)] +except AttributeError: pass + +try: (ir3_const_state_get_free_space:=dll.ir3_const_state_get_free_space).restype, ir3_const_state_get_free_space.argtypes = uint32_t, [ctypes.POINTER(struct_ir3_shader_variant), ctypes.POINTER(struct_ir3_const_state), uint32_t] +except AttributeError: pass + +try: (ir3_const_alloc:=dll.ir3_const_alloc).restype, ir3_const_alloc.argtypes = None, [ctypes.POINTER(struct_ir3_const_allocations), enum_ir3_const_alloc_type, uint32_t, uint32_t] +except AttributeError: pass + +try: (ir3_const_reserve_space:=dll.ir3_const_reserve_space).restype, ir3_const_reserve_space.argtypes = None, [ctypes.POINTER(struct_ir3_const_allocations), enum_ir3_const_alloc_type, uint32_t, uint32_t] +except AttributeError: pass + +try: (ir3_const_free_reserved_space:=dll.ir3_const_free_reserved_space).restype, ir3_const_free_reserved_space.argtypes = None, [ctypes.POINTER(struct_ir3_const_allocations), enum_ir3_const_alloc_type] +except AttributeError: pass + +try: (ir3_const_alloc_all_reserved_space:=dll.ir3_const_alloc_all_reserved_space).restype, ir3_const_alloc_all_reserved_space.argtypes = None, [ctypes.POINTER(struct_ir3_const_allocations)] +except AttributeError: pass + +try: (ir3_nir_scan_driver_consts:=dll.ir3_nir_scan_driver_consts).restype, ir3_nir_scan_driver_consts.argtypes = uint32_t, [ctypes.POINTER(struct_ir3_compiler), ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_const_image_dims)] +except AttributeError: pass + +try: (ir3_alloc_driver_params:=dll.ir3_alloc_driver_params).restype, ir3_alloc_driver_params.argtypes = None, [ctypes.POINTER(struct_ir3_const_allocations), ctypes.POINTER(uint32_t), ctypes.POINTER(struct_ir3_compiler), enum_pipe_shader_type] +except AttributeError: pass + +try: (ir3_nir_lower_load_constant:=dll.ir3_nir_lower_load_constant).restype, ir3_nir_lower_load_constant.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_shader_variant)] +except AttributeError: pass + +try: (ir3_nir_analyze_ubo_ranges:=dll.ir3_nir_analyze_ubo_ranges).restype, ir3_nir_analyze_ubo_ranges.argtypes = None, [ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_shader_variant)] +except AttributeError: pass + +try: (ir3_nir_lower_ubo_loads:=dll.ir3_nir_lower_ubo_loads).restype, ir3_nir_lower_ubo_loads.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_shader_variant)] +except AttributeError: pass + +try: (ir3_nir_lower_const_global_loads:=dll.ir3_nir_lower_const_global_loads).restype, ir3_nir_lower_const_global_loads.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_shader_variant)] +except AttributeError: pass + +try: (ir3_nir_fixup_load_const_ir3:=dll.ir3_nir_fixup_load_const_ir3).restype, ir3_nir_fixup_load_const_ir3.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader)] +except AttributeError: pass + +try: (ir3_nir_opt_preamble:=dll.ir3_nir_opt_preamble).restype, ir3_nir_opt_preamble.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_shader_variant)] +except AttributeError: pass + +try: (ir3_nir_opt_prefetch_descriptors:=dll.ir3_nir_opt_prefetch_descriptors).restype, ir3_nir_opt_prefetch_descriptors.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_shader_variant)] +except AttributeError: pass + +try: (ir3_nir_lower_preamble:=dll.ir3_nir_lower_preamble).restype, ir3_nir_lower_preamble.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_shader_variant)] +except AttributeError: pass + +try: (ir3_nir_try_propagate_bit_shift:=dll.ir3_nir_try_propagate_bit_shift).restype, ir3_nir_try_propagate_bit_shift.argtypes = ctypes.POINTER(nir_def), [ctypes.POINTER(nir_builder), ctypes.POINTER(nir_def), int32_t] +except AttributeError: pass + +try: (ir3_nir_lower_subgroups_filter:=dll.ir3_nir_lower_subgroups_filter).restype, ir3_nir_lower_subgroups_filter.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_instr), ctypes.c_void_p] +except AttributeError: pass + +try: (ir3_nir_lower_shuffle:=dll.ir3_nir_lower_shuffle).restype, ir3_nir_lower_shuffle.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_shader)] +except AttributeError: pass + +try: (ir3_nir_opt_subgroups:=dll.ir3_nir_opt_subgroups).restype, ir3_nir_opt_subgroups.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_shader_variant)] +except AttributeError: pass + +try: (ir3_get_shared_driver_ubo:=dll.ir3_get_shared_driver_ubo).restype, ir3_get_shared_driver_ubo.argtypes = ctypes.POINTER(nir_def), [ctypes.POINTER(nir_builder), ctypes.POINTER(struct_ir3_driver_ubo)] +except AttributeError: pass + +try: (ir3_get_driver_ubo:=dll.ir3_get_driver_ubo).restype, ir3_get_driver_ubo.argtypes = ctypes.POINTER(nir_def), [ctypes.POINTER(nir_builder), ctypes.POINTER(struct_ir3_driver_ubo)] +except AttributeError: pass + +try: (ir3_get_driver_consts_ubo:=dll.ir3_get_driver_consts_ubo).restype, ir3_get_driver_consts_ubo.argtypes = ctypes.POINTER(nir_def), [ctypes.POINTER(nir_builder), ctypes.POINTER(struct_ir3_shader_variant)] +except AttributeError: pass + +try: (ir3_update_driver_ubo:=dll.ir3_update_driver_ubo).restype, ir3_update_driver_ubo.argtypes = None, [ctypes.POINTER(nir_shader), ctypes.POINTER(struct_ir3_driver_ubo), ctypes.POINTER(ctypes.c_char)] +except AttributeError: pass + +try: (ir3_load_shared_driver_ubo:=dll.ir3_load_shared_driver_ubo).restype, ir3_load_shared_driver_ubo.argtypes = ctypes.POINTER(nir_def), [ctypes.POINTER(nir_builder), ctypes.c_uint32, ctypes.POINTER(struct_ir3_driver_ubo), ctypes.c_uint32] +except AttributeError: pass + +try: (ir3_load_driver_ubo:=dll.ir3_load_driver_ubo).restype, ir3_load_driver_ubo.argtypes = ctypes.POINTER(nir_def), [ctypes.POINTER(nir_builder), ctypes.c_uint32, ctypes.POINTER(struct_ir3_driver_ubo), ctypes.c_uint32] +except AttributeError: pass + +try: (ir3_load_driver_ubo_indirect:=dll.ir3_load_driver_ubo_indirect).restype, ir3_load_driver_ubo_indirect.argtypes = ctypes.POINTER(nir_def), [ctypes.POINTER(nir_builder), ctypes.c_uint32, ctypes.POINTER(struct_ir3_driver_ubo), ctypes.c_uint32, ctypes.POINTER(nir_def), ctypes.c_uint32] +except AttributeError: pass + +try: (ir3_def_is_rematerializable_for_preamble:=dll.ir3_def_is_rematerializable_for_preamble).restype, ir3_def_is_rematerializable_for_preamble.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_def), ctypes.POINTER(ctypes.POINTER(nir_def))] +except AttributeError: pass + +try: (ir3_rematerialize_def_for_preamble:=dll.ir3_rematerialize_def_for_preamble).restype, ir3_rematerialize_def_for_preamble.argtypes = ctypes.POINTER(nir_def), [ctypes.POINTER(nir_builder), ctypes.POINTER(nir_def), ctypes.POINTER(struct_set), ctypes.POINTER(ctypes.POINTER(nir_def))] +except AttributeError: pass + +class struct_driver_param_info(Struct): pass +struct_driver_param_info._fields_ = [ + ('offset', uint32_t), + ('extra_size', uint32_t), +] +try: (ir3_get_driver_param_info:=dll.ir3_get_driver_param_info).restype, ir3_get_driver_param_info.argtypes = ctypes.c_bool, [ctypes.POINTER(nir_shader), ctypes.POINTER(nir_intrinsic_instr), ctypes.POINTER(struct_driver_param_info)] +except AttributeError: pass + +try: (ir3_nir_max_imm_offset:=dll.ir3_nir_max_imm_offset).restype, ir3_nir_max_imm_offset.argtypes = uint32_t, [ctypes.POINTER(nir_intrinsic_instr), ctypes.c_void_p] +except AttributeError: pass + +try: (ir3_nir_intrinsic_barycentric_sysval:=dll.ir3_nir_intrinsic_barycentric_sysval).restype, ir3_nir_intrinsic_barycentric_sysval.argtypes = gl_system_value, [ctypes.POINTER(nir_intrinsic_instr)] +except AttributeError: pass + try: (glsl_type_singleton_init_or_ref:=dll.glsl_type_singleton_init_or_ref).restype, glsl_type_singleton_init_or_ref.argtypes = None, [] except AttributeError: pass @@ -7651,6 +9378,47 @@ RALLOC_PRINT_INFO_SUMMARY_ONLY = _anonenum7.define('RALLOC_PRINT_INFO_SUMMARY_ON try: (ralloc_print_info:=dll.ralloc_print_info).restype, ralloc_print_info.argtypes = None, [ctypes.POINTER(FILE), ctypes.c_void_p, ctypes.c_uint32] except AttributeError: pass +class struct_isa_decode_options(Struct): pass +class struct_isa_decode_value(Struct): pass +struct_isa_decode_value._fields_ = [ + ('str', ctypes.POINTER(ctypes.c_char)), + ('num', uint64_t), +] +class struct_isa_print_state(Struct): pass +struct_isa_print_state._fields_ = [ + ('out', ctypes.POINTER(FILE)), + ('line_column', ctypes.c_uint32), +] +class struct_isa_entrypoint(Struct): pass +struct_isa_entrypoint._fields_ = [ + ('name', ctypes.POINTER(ctypes.c_char)), + ('offset', uint32_t), +] +struct_isa_decode_options._fields_ = [ + ('gpu_id', uint32_t), + ('show_errors', ctypes.c_bool), + ('max_errors', ctypes.c_uint32), + ('branch_labels', ctypes.c_bool), + ('stop', ctypes.c_bool), + ('cbdata', ctypes.c_void_p), + ('field_cb', ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.POINTER(ctypes.c_char), ctypes.POINTER(struct_isa_decode_value))), + ('field_print_cb', ctypes.CFUNCTYPE(None, ctypes.POINTER(struct_isa_print_state), ctypes.POINTER(ctypes.c_char), uint64_t)), + ('pre_instr_cb', ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p)), + ('post_instr_cb', ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p)), + ('no_match_cb', ctypes.CFUNCTYPE(None, ctypes.POINTER(FILE), ctypes.POINTER(ctypes.c_uint32), size_t)), + ('entrypoint_count', ctypes.c_uint32), + ('entrypoints', ctypes.POINTER(struct_isa_entrypoint)), +] +try: (ir3_isa_disasm:=dll.ir3_isa_disasm).restype, ir3_isa_disasm.argtypes = None, [ctypes.c_void_p, ctypes.c_int32, ctypes.POINTER(FILE), ctypes.POINTER(struct_isa_decode_options)] +except AttributeError: pass + +try: (ir3_isa_decode:=dll.ir3_isa_decode).restype, ir3_isa_decode.argtypes = ctypes.c_bool, [ctypes.c_void_p, ctypes.c_void_p, ctypes.POINTER(struct_isa_decode_options)] +except AttributeError: pass + +class struct_decode_scope(Struct): pass +try: (ir3_isa_get_gpu_id:=dll.ir3_isa_get_gpu_id).restype, ir3_isa_get_gpu_id.argtypes = uint32_t, [ctypes.POINTER(struct_decode_scope)] +except AttributeError: pass + try: glsl_type_builtin_error = struct_glsl_type.in_dll(dll, 'glsl_type_builtin_error') except (ValueError,AttributeError): pass try: glsl_type_builtin_void = struct_glsl_type.in_dll(dll, 'glsl_type_builtin_void') @@ -9304,6 +11072,30 @@ lp_jit_vertex_header_id = lambda _gallivm,_type,_ptr: lp_build_struct__get_ptr2( lp_jit_vertex_header_clip_pos = lambda _gallivm,_type,_ptr: lp_build_struct__get_ptr2(_gallivm, _type, _ptr, LP_JIT_VERTEX_HEADER_CLIP_POS, "clip_pos") lp_jit_vertex_header_data = lambda _gallivm,_type,_ptr: lp_build_struct__get_ptr2(_gallivm, _type, _ptr, LP_JIT_VERTEX_HEADER_DATA, "data") LP_MAX_TEX_FUNC_ARGS = 32 +A6XX_CCU_DEPTH_SIZE = (64 * 1024) +A6XX_CCU_GMEM_COLOR_SIZE = (16 * 1024) +dword_offsetof = lambda type,name: DIV_ROUND_UP(offsetof(type, name), 4) +dword_sizeof = lambda type: DIV_ROUND_UP(sizeof(type), 4) +IR3_DP_CS = lambda name: dword_offsetof(struct_ir3_driver_params_cs, name) +IR3_DP_VS = lambda name: dword_offsetof(struct_ir3_driver_params_vs, name) +IR3_DP_TCS = lambda name: dword_offsetof(struct_ir3_driver_params_tcs, name) +IR3_DP_FS = lambda name: dword_offsetof(struct_ir3_driver_params_fs, name) +IR3_MAX_SHADER_BUFFERS = 32 +IR3_MAX_SHADER_IMAGES = 32 +IR3_MAX_SO_BUFFERS = 4 +IR3_MAX_SO_STREAMS = 4 +IR3_MAX_SO_OUTPUTS = 128 +IR3_MAX_UBO_PUSH_RANGES = 32 +IR3_MAX_SAMPLER_PREFETCH = 4 +IR3_SAMPLER_PREFETCH_CMD = 0x4 +IR3_SAMPLER_BINDLESS_PREFETCH_CMD = 0x6 +IR3_TESS_NONE = 0 +IR3_TESS_QUADS = 1 +IR3_TESS_TRIANGLES = 2 +IR3_TESS_ISOLINES = 3 +UAV_INVALID = 0xff +UAV_SSBO = 0x80 +HALF_REG_ID = 0x100 gc_alloc = lambda ctx,type,count: gc_alloc_size(ctx, sizeof(type) * (count), alignof(type)) gc_zalloc = lambda ctx,type,count: gc_zalloc_size(ctx, sizeof(type) * (count), alignof(type)) gc_alloc_zla = lambda ctx,type,type2,count: gc_alloc_size(ctx, sizeof(type) + sizeof(type2) * (count), MAX2(alignof(type), alignof(type2))) @@ -9312,6 +11104,7 @@ DECLARE_RALLOC_CXX_OPERATORS = lambda type: DECLARE_RALLOC_CXX_OPERATORS_TEMPLAT DECLARE_RZALLOC_CXX_OPERATORS = lambda type: DECLARE_RALLOC_CXX_OPERATORS_TEMPLATE(type, rzalloc_size) DECLARE_LINEAR_ALLOC_CXX_OPERATORS = lambda type: DECLARE_LINEAR_ALLOC_CXX_OPERATORS_TEMPLATE(type, linear_alloc_child) DECLARE_LINEAR_ZALLOC_CXX_OPERATORS = lambda type: DECLARE_LINEAR_ALLOC_CXX_OPERATORS_TEMPLATE(type, linear_zalloc_child) +ISA_GPU_ID = lambda: ir3_isa_get_gpu_id(scope) __struct__cast = lambda X: (struct_X) A6XX_RBBM_INT_0_MASK_RBBM_GPU_IDLE = 0x00000001 A6XX_RBBM_INT_0_MASK_CP_AHB_ERROR = 0x00000002 diff --git a/tinygrad/runtime/ops_null.py b/tinygrad/runtime/ops_null.py index af7f0eb046..f7bad77230 100644 --- a/tinygrad/runtime/ops_null.py +++ b/tinygrad/runtime/ops_null.py @@ -5,6 +5,8 @@ from tinygrad.renderer.cstyle import Renderer, CStyleLanguage from tinygrad.renderer.llvmir import AMDLLVMRenderer from tinygrad.uop.ops import Ops from tinygrad.helpers import cpu_profile, EMULATE +from tinygrad.renderer.nir import IR3Renderer +from tinygrad.runtime.support.compiler_mesa import IR3Compiler class NullRenderer(CStyleLanguage): device = "NULL" @@ -37,4 +39,6 @@ class NullDevice(Compiled): case "AMD_RDNA4": renderer = functools.partial(AMDLLVMRenderer, "gfx1201") case "": renderer = NullRenderer case _: raise RuntimeError(f"can't EMULATE device: {EMULATE.value}") - super().__init__(device, NullAllocator(self), CompilerSet([CompilerPair(renderer, Compiler)]), functools.partial(NullProgram, device), NullGraph) + compilers = CompilerSet([CompilerPair(renderer, Compiler), + CompilerPair(functools.partial(IR3Renderer, self), functools.partial(IR3Compiler, 0x6030001))]) # adreno 630 + super().__init__(device, NullAllocator(self), compilers, functools.partial(NullProgram, device), NullGraph) diff --git a/tinygrad/runtime/ops_qcom.py b/tinygrad/runtime/ops_qcom.py index 4b944d97a7..b1284bbd1e 100644 --- a/tinygrad/runtime/ops_qcom.py +++ b/tinygrad/runtime/ops_qcom.py @@ -9,7 +9,10 @@ from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface from tinygrad.runtime.autogen import kgsl, mesa from tinygrad.runtime.ops_cl import CLCompiler, CLDevice from tinygrad.renderer.cstyle import QCOMRenderer +from tinygrad.renderer.nir import IR3Renderer +from tinygrad.runtime.support.compiler_mesa import IR3Compiler from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, prod, fromimport, cpu_profile, lo32, PROFILE, suppress_finalizing +from tinygrad.helpers import flatten, QCOM_IR3, QCOM_CC from tinygrad.runtime.support.system import System if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import @@ -34,9 +37,11 @@ def pkt7_hdr(opcode: int, cnt: int): return mesa.CP_TYPE7_PKT | cnt & 0x3FFF | p def pkt4_hdr(reg: int, cnt: int): return mesa.CP_TYPE4_PKT | cnt & 0x7F | parity(cnt) << 7 | (reg & 0x3FFFF) << 8 | parity(reg) << 27 +def _read_lib(lib, off) -> int: return struct.unpack("I", lib[off:off+4])[0] class QCOMCompiler(CLCompiler): def __init__(self, device:str=""): super().__init__(CLDevice(device), 'compile_qcom') - def disassemble(self, lib:bytes): fromimport('extra.disassemblers.adreno', 'disasm')(lib) + def disassemble(self, lib:bytes): + fromimport('tinygrad.runtime.support.compiler_mesa', 'disas_adreno')(lib[(ofs:=_read_lib(lib, 0xc0)):ofs+_read_lib(lib, 0x100)]) class QCOMSignal(HCQSignal): def __init__(self, *args, **kwargs): super().__init__(*args, **{**kwargs, 'timestamp_divider': 19.2}) @@ -120,9 +125,9 @@ class QCOMComputeQueue(HWQueue): self.reg(mesa.REG_A6XX_SP_UPDATE_CNTL, 0x0) self.reg(mesa.REG_A6XX_SP_CS_TSIZE, qreg.a6xx_sp_cs_tsize(0x80)) # is this right? mesa uses 1 self.reg(mesa.REG_A6XX_SP_CS_USIZE, qreg.a6xx_sp_cs_usize(0x40)) # mesa also uses 1 - self.reg(mesa.REG_A6XX_SP_MODE_CNTL, qreg.a6xx_sp_mode_cntl(isammode=mesa.ISAMMODE_CL)) + self.reg(mesa.REG_A6XX_SP_MODE_CNTL, qreg.a6xx_sp_mode_cntl(isammode=mesa.ISAMMODE_GL if prg.NIR else mesa.ISAMMODE_CL)) self.reg(mesa.REG_A6XX_SP_PERFCTR_SHADER_MASK, qreg.a6xx_sp_perfctr_shader_mask(cs=True)) - self.reg(mesa.REG_A6XX_TPL1_MODE_CNTL, qreg.a6xx_tpl1_mode_cntl(isammode=mesa.ISAMMODE_CL)) + self.reg(mesa.REG_A6XX_TPL1_MODE_CNTL, qreg.a6xx_tpl1_mode_cntl(isammode=mesa.ISAMMODE_GL if prg.NIR else mesa.ISAMMODE_CL)) self.reg(mesa.REG_A6XX_TPL1_DBG_ECO_CNTL, 0) self.cmd(mesa.CP_WAIT_FOR_IDLE) @@ -138,6 +143,7 @@ class QCOMComputeQueue(HWQueue): qreg.a6xx_sp_cs_pvt_mem_param(memsizeperitem=prg.pvtmem_size_per_item), *data64_le(prg.dev._stack.va_addr), qreg.a6xx_sp_cs_pvt_mem_size(totalpvtmemsize=prg.pvtmem_size_total)) + if prg.NIR and prg.wgsz != 0xfc: to_mv(args_state.buf.va_addr + prg.wgsz * 4, 12)[:] = struct.pack("III", *local_size) self.cmd(mesa.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=mesa.ST_CONSTANTS, state_src=mesa.SS6_INDIRECT, state_block=mesa.SB6_CS_SHADER, num_unit=1024 // 4), *data64_le(args_state.buf.va_addr)) @@ -150,20 +156,20 @@ class QCOMComputeQueue(HWQueue): self.reg(mesa.REG_A6XX_SP_CS_PVT_MEM_STACK_OFFSET, qreg.a6xx_sp_cs_pvt_mem_stack_offset(prg.hw_stack_offset)) self.reg(mesa.REG_A6XX_SP_CS_INSTR_SIZE, qreg.a6xx_sp_cs_instr_size(prg.image_size // 4)) - if args_state.prg.samp_cnt > 0: + if prg.samp_cnt > 0: self.cmd(mesa.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=mesa.ST_SHADER, state_src=mesa.SS6_INDIRECT, state_block=mesa.SB6_CS_TEX, num_unit=args_state.prg.samp_cnt), *data64_le(args_state.buf.va_addr + args_state.prg.samp_off)) self.reg(mesa.REG_A6XX_SP_CS_SAMPLER_BASE, *data64_le(args_state.buf.va_addr + args_state.prg.samp_off)) self.reg(mesa.REG_A6XX_TPL1_CS_BORDER_COLOR_BASE, *data64_le(prg.dev.border_color_buf.va_addr)) - if args_state.prg.tex_cnt > 0: + if prg.tex_cnt > 0: self.cmd(mesa.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=mesa.ST_CONSTANTS, state_src=mesa.SS6_INDIRECT, state_block=mesa.SB6_CS_TEX, num_unit=min(16, args_state.prg.tex_cnt)), *data64_le(args_state.buf.va_addr + args_state.prg.tex_off)) self.reg(mesa.REG_A6XX_SP_CS_TEXMEMOBJ_BASE, *data64_le(args_state.buf.va_addr + args_state.prg.tex_off)) - if args_state.prg.ibo_cnt > 0: + if prg.ibo_cnt > 0: self.cmd(mesa.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=mesa.ST6_UAV, state_src=mesa.SS6_INDIRECT, state_block=mesa.SB6_CS_SHADER, num_unit=args_state.prg.ibo_cnt), *data64_le(args_state.buf.va_addr + args_state.prg.ibo_off)) @@ -171,7 +177,15 @@ class QCOMComputeQueue(HWQueue): self.reg(mesa.REG_A6XX_SP_CS_CONFIG, qreg.a6xx_sp_cs_config(enabled=True, nsamp=args_state.prg.samp_cnt, ntex=args_state.prg.tex_cnt, nuav=args_state.prg.ibo_cnt)) - self.cmd(mesa.CP_RUN_OPENCL, 0) + + if prg.NIR: + self.reg(mesa.REG_A6XX_SP_CS_CONST_CONFIG_0, + qreg.a6xx_sp_cs_const_config_0(wgidconstid=prg.wgid, wgsizeconstid=prg.wgsz, wgoffsetconstid=0xfc, localidregid=prg.lid), + qreg.a6xx_sp_cs_wge_cntl(linearlocalidregid=0xfc, threadsize=mesa.THREAD64)) + self.cmd(mesa.CP_EXEC_CS, 0, + qreg.cp_exec_cs_1(ngroups_x=global_size[0]), qreg.cp_exec_cs_2(ngroups_y=global_size[1]), qreg.cp_exec_cs_3(_ngroups_z=global_size[2])) + else: self.cmd(mesa.CP_RUN_OPENCL, 0) + self._cache_flush(write_back=True, invalidate=False, sync=False, memsync=False) return self @@ -195,11 +209,45 @@ class QCOMArgsState(HCQArgsState): for i, v in enumerate(vals): self.bind_sints_to_buf(v, buf=self.buf, fmt='I', offset=self.args_info[i].offset) +class IR3ArgsState(HCQArgsState): + def __init__(self, buf:HCQBuffer, prg:QCOMProgram, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=()): + super().__init__(buf, prg, bufs, vals=vals) + ctypes.memset(cast(int, self.buf.va_addr), 0, prg.kernargs_alloc_size) + to_mv(self.buf.va_addr + prg.imm_off, len(prg.imm_vals))[:] = prg.imm_vals + + ubos, uavs = [b for b in bufs if b.texture_info is None], [b for b in bufs if b.texture_info is not None] + ibos, texs = (uavs, []) if prg.tex_cnt == 0 else (uavs[:-prg.tex_cnt], uavs[-prg.tex_cnt:]) # textures are at the end + + if prg.samp_cnt > 0: to_mv(self.buf.va_addr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers) + self.bind_sints_to_buf(*[b.va_addr for b in ubos], buf=self.buf, fmt='Q', offset=prg.buf_off) + self.bind_sints_to_buf(*vals, buf=self.buf, fmt='I', offset=prg.buf_off + len(ubos) * 8) + self.bind_sints_to_buf(*flatten([b.texture_info.desc + ([0] * 8) for b in texs]), buf=self.buf, fmt='I', offset=prg.tex_off) + self.bind_sints_to_buf(*flatten([b.texture_info.ibo + ([0] * 8) for b in ibos]), buf=self.buf, fmt='I', offset=prg.ibo_off) + class QCOMProgram(HCQProgram): def __init__(self, dev: QCOMDevice, name: str, lib: bytes): self.dev: QCOMDevice = dev - self.name, self.lib = name, lib - self._parse_lib() + self.name, self.lib, self.NIR = name, lib, isinstance(dev.compiler, IR3Compiler) + + if self.NIR: + from tinygrad.runtime.autogen import mesa + v, cs, self.imm_vals, self.image = IR3Compiler.unpack_lib(lib) + self.prg_offset, self.brnchstck, self.image_size, self.pvtmem, self.shmem = 0, v.branchstack, v.info.size, v.pvtmem_size, v.shared_size + self.wgsz = alloc.offset_vec4 * 4 + 8 if (alloc:=cs.allocs.consts[mesa.IR3_CONST_ALLOC_DRIVER_PARAMS]).size_vec4 else 0xfc + + self.wgid, self.lid = v.cs.work_group_id, v.cs.local_invocation_id # register ids + self.buf_off, self.imm_off = cs.ubo_state.range[0].offset, cs.allocs.max_const_offset_vec4 * 16 + + # see https://elixir.bootlin.com/mesa/mesa-25.3.0/source/src/freedreno/ir3/ir3_shader.h#L525 + # and https://elixir.bootlin.com/mesa/mesa-25.3.0/source/src/freedreno/ir3/ir3_compiler_nir.c#L5389 + self.samp_cnt, self.tex_cnt, self.ibo_cnt = (nt:=v.image_mapping.num_tex), nt, v.num_uavs - nt + # IR3 outputs a sampler for every texture (https://elixir.bootlin.com/mesa/mesa-25.3.0/source/src/freedreno/ir3/ir3_compiler_nir.c#L1714) + self.samplers = [qreg.a6xx_tex_samp_0(wrap_s=(clamp_mode:=mesa.A6XX_TEX_CLAMP_TO_BORDER), wrap_t=clamp_mode, wrap_r=clamp_mode), + qreg.a6xx_tex_samp_1(unnorm_coords=True, cubemapseamlessfiltoff=True), 0, 0] * self.samp_cnt + + self.tex_off, self.ibo_off, self.samp_off = 2048, 2048 + 0x40 * self.tex_cnt, 2048 + 0x40 * (self.tex_cnt + self.ibo_cnt) + self.fregs, self.hregs = v.info.max_reg + 1, v.info.max_half_reg + 1 + else: self._parse_lib() self.lib_gpu: HCQBuffer = self.dev.allocator.alloc(self.image_size, buf_spec:=BufferSpec(cpu_access=True, nolru=True)) to_mv(cast(int, self.lib_gpu.va_addr), self.image_size)[:] = self.image @@ -211,8 +259,8 @@ class QCOMProgram(HCQProgram): self.max_threads = min(1024, ((384 * 32) // (max(1, (self.fregs + round_up(self.hregs, 2) // 2)) * 128)) * 128) dev._ensure_stack_size(self.hw_stack_offset * 4) - kernargs_alloc_size = round_up(2048 + (self.tex_cnt + self.ibo_cnt) * 0x40 + self.samp_cnt * 0x10, 0x100) - super().__init__(QCOMArgsState, self.dev, self.name, kernargs_alloc_size=kernargs_alloc_size) + kernargs_alloc_size = round_up(2048 + (self.tex_cnt + self.ibo_cnt) * 0x40 + len(self.samplers) * 4, 0x100) + super().__init__(IR3ArgsState if self.NIR else QCOMArgsState, self.dev, self.name, kernargs_alloc_size=kernargs_alloc_size) weakref.finalize(self, self._fini, self.dev, self.lib_gpu, buf_spec) def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False): @@ -222,27 +270,26 @@ class QCOMProgram(HCQProgram): return super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait) def _parse_lib(self): - def _read_lib(off) -> int: return struct.unpack("I", self.lib[off:off+4])[0] - # Extract image binary - self.image_size = _read_lib(0x100) - self.image = bytearray(self.lib[(image_offset:=_read_lib(0xc0)):image_offset+self.image_size]) + self.image_size = _read_lib(self.lib, 0x100) + self.image = bytearray(self.lib[(image_offset:=_read_lib(self.lib, 0xc0)):image_offset+self.image_size]) # Parse image descriptors - image_desc_off = _read_lib(0x110) - self.prg_offset, self.brnchstck = _read_lib(image_desc_off+0xc4), _read_lib(image_desc_off+0x108) // 2 - self.pvtmem, self.shmem = _read_lib(image_desc_off+0xc8), _read_lib(image_desc_off+0xd8) + image_desc_off = _read_lib(self.lib, 0x110) + self.prg_offset, self.brnchstck = _read_lib(self.lib, image_desc_off+0xc4), _read_lib(self.lib, image_desc_off+0x108) // 2 + self.pvtmem, self.shmem = _read_lib(self.lib, image_desc_off+0xc8), _read_lib(self.lib, image_desc_off+0xd8) # Fill up constants and buffers info self.buf_info, self.consts_info = [], [] # Collect sampler info. - self.samp_cnt = samp_cnt_in_file = _read_lib(image_desc_off + 0xdc) + self.samp_cnt = samp_cnt_in_file = _read_lib(self.lib, image_desc_off + 0xdc) assert self.samp_cnt <= 1, "Up to one sampler supported" if self.samp_cnt: self.samp_cnt += 1 self.samplers = [qreg.a6xx_tex_samp_0(wrap_s=(clamp_mode:=mesa.A6XX_TEX_CLAMP_TO_BORDER), wrap_t=clamp_mode, wrap_r=clamp_mode), qreg.a6xx_tex_samp_1(unnorm_coords=True, cubemapseamlessfiltoff=True), 0, 0, 0, 0, 0, 0] + else: self.samplers = [] # Collect kernel arguments (buffers) info. bdoff = round_up(image_desc_off + 0x158 + len(self.name), 4) + 8 * samp_cnt_in_file @@ -260,16 +307,16 @@ class QCOMProgram(HCQProgram): if x.type is BUFTYPE_IBO: x.offset, cur_ibo_off = cur_ibo_off, cur_ibo_off + 0x40 elif x.type is BUFTYPE_TEX: x.offset, cur_tex_off = cur_tex_off, cur_tex_off + 0x40 - if _read_lib(0xb0) != 0: # check if we have constants. - cdoff = _read_lib(0xac) + if _read_lib(self.lib, 0xb0) != 0: # check if we have constants. + cdoff = _read_lib(self.lib, 0xac) while cdoff + 40 <= image_offset: cnst, offset_words, _, is32 = struct.unpack("I", self.lib[cdoff:cdoff+4])[0], *struct.unpack("III", self.lib[cdoff+16:cdoff+28]) self.consts_info.append((cnst, offset_words * (sz_bytes:=(2 << is32)), sz_bytes)) cdoff += 40 # Registers info - reg_desc_off = _read_lib(0x34) - self.fregs, self.hregs = _read_lib(reg_desc_off + 0x14), _read_lib(reg_desc_off + 0x18) + reg_desc_off = _read_lib(self.lib, 0x34) + self.fregs, self.hregs = _read_lib(self.lib, reg_desc_off + 0x14), _read_lib(self.lib, reg_desc_off + 0x18) class QCOMTextureInfo: def __init__(self, pitch:int, real_stride:int, desc:list[int], ibo:list[int]): @@ -354,8 +401,10 @@ class QCOMDevice(HCQCompiled): if PROFILE and self.gpu_id[:2] < (7, 3): System.write_sysfs("/sys/class/kgsl/kgsl-3d0/idle_timer", value="4000000000", msg="Failed to disable suspend mode", expected="4294967276") - super().__init__(device, QCOMAllocator(self), CompilerSet([CompilerPair(QCOMRenderer, functools.partial(QCOMCompiler, device))]), - functools.partial(QCOMProgram, self), QCOMSignal, functools.partial(QCOMComputeQueue, self), None) + compilers = CompilerSet(ctrl_var=QCOM_CC, cset=[CompilerPair(QCOMRenderer, functools.partial(QCOMCompiler, device)), + CompilerPair(functools.partial(IR3Renderer, self), functools.partial(IR3Compiler, info.chip_id), QCOM_IR3)]) + super().__init__(device, QCOMAllocator(self), compilers, functools.partial(QCOMProgram, self), QCOMSignal, + functools.partial(QCOMComputeQueue, self), None) def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False) -> HCQBuffer: flags |= flag("KGSL_MEMALIGN", alignment_hint:=12) | kgsl.KGSL_MEMFLAGS_USE_CPU_MAP diff --git a/tinygrad/runtime/support/compiler_mesa.py b/tinygrad/runtime/support/compiler_mesa.py index 8f97e12ed4..af7fa19d02 100644 --- a/tinygrad/runtime/support/compiler_mesa.py +++ b/tinygrad/runtime/support/compiler_mesa.py @@ -1,11 +1,16 @@ -import base64, ctypes, pathlib, tempfile, hashlib +import base64, ctypes, pathlib, tempfile, hashlib, sys from tinygrad.device import Compiler -from tinygrad.helpers import cpu_objdump, system +from tinygrad.helpers import cpu_objdump, system, data64 from tinygrad.runtime.autogen import mesa from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, expect, cerr try: from tinygrad.runtime.autogen import llvm except (ImportError, FileNotFoundError): llvm = None #type:ignore[assignment] +def rzalloc(typ, ctx=None, **kwargs): + s = ctypes.cast(mesa.rzalloc_size(ctypes.cast(ctx, ctypes.c_void_p), ctypes.sizeof(typ)), ctypes.POINTER(typ)) + for k,v in kwargs.items(): setattr(s.contents, k, v) + return s + def deserialize(enc_src, opts): blobreader = mesa.struct_blob_reader() mesa.blob_reader_init(blobreader, src:=base64.b64decode(enc_src), len(src)) @@ -84,3 +89,56 @@ class NAKCompiler(NIRCompiler): with open(fn, "wb") as f: f.write(lib[ctypes.sizeof(mesa.struct_nak_shader_info):]) print(system(f"nvdisasm -b SM{self.arch[3:]} {fn}")) except Exception as e: print("Failed to generate SASS", str(e), "Make sure your PATH contains nvdisasm binary of compatible version.") + +def disas_adreno(lib:bytes, gpu_id=630): + with tempfile.TemporaryFile('w+', buffering=1) as tf: + @ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p) + def hd(data, n, instr): + fst, snd = data64(ctypes.cast(instr, ctypes.POINTER(ctypes.c_uint64)).contents.value) + print(f"{n:04} [{fst:08x}_{snd:08x}] ", end="", flush=True, file=tf) + + ctypes.CDLL(None).setlinebuf(fp:=ctypes.cast(ctypes.CDLL(None).fdopen(tf.fileno(), b"w"), ctypes.POINTER(mesa.struct__IO_FILE))) + mesa.ir3_isa_disasm(lib, len(lib), fp, mesa.struct_isa_decode_options(gpu_id, True, 0, True, pre_instr_cb=hd)) + tf.seek(0) + print(tf.read()) + +class IR3Compiler(NIRCompiler): + def __init__(self, chip_id, cache_key="ir3"): + assert sys.version_info >= (3,14), "IR3 requires python 3.14's bitfield fixes" + self.dev_id = mesa.struct_fd_dev_id(((chip_id >> 24) & 0xFF) * 100 + ((chip_id >> 16) & 0xFF) * 10 + ((chip_id >> 8) & 0xFF), chip_id) + self.cc = mesa.ir3_compiler_create(None, self.dev_id, mesa.fd_dev_info(self.dev_id), + mesa.struct_ir3_compiler_options(disable_cache=True)).contents + self.cc.has_preamble = False + self.nir_options = bytes(mesa.ir3_get_compiler_options(self.cc).contents) + super().__init__(f"compile_{cache_key}") + + def __del__(self): + mesa.ir3_compiler_destroy(self.cc) + super().__del__() + + def __reduce__(self): return IR3Compiler, (self.dev_id.chip_id,) + + # ir3_shader_variant info: https://elixir.bootlin.com/mesa/mesa-25.3.0/source/src/freedreno/ir3/ir3_shader.c#L1099 + def compile(self, src) -> bytes: + nir_shader = deserialize(src, self.nir_options) + mesa.ir3_nir_lower_io_vars_to_temporaries(nir_shader) + mesa.ir3_finalize_nir(self.cc, mesa.struct_ir3_shader_nir_options(), nir_shader) + shader = rzalloc(mesa.struct_ir3_shader, compiler=ctypes.pointer(self.cc), type=mesa.MESA_SHADER_COMPUTE, nir=nir_shader).contents + mesa.ir3_nir_post_finalize(shader) + v = rzalloc(mesa.struct_ir3_shader_variant, type=shader.type, compiler=ctypes.pointer(self.cc), key=mesa.struct_ir3_shader_key()).contents + v.const_state, shader.variants, shader.variant_count = rzalloc(mesa.struct_ir3_const_state, ctypes.pointer(v)), ctypes.pointer(v), 1 + v.num_uavs = (info:=nir_shader.contents.info).num_ssbos + info.num_images + assert not mesa.ir3_compile_shader_nir(self.cc, shader, v), "compilation failed" + lib = ctypes.cast(mesa.ir3_shader_assemble(v), ctypes.POINTER(ctypes.c_uint32)) + # NB: bytes(v) means the pointers in v are no longer safe! a custom __reduce__ that supports pointers for c.Struct would make this simpler + ret = bytes(v) + bytes(v.const_state.contents) + ctypes.string_at(v.imm_state.values, v.imm_state.count * 4) + ctypes.string_at(lib, v.info.size) + mesa.ralloc_free(ctypes.pointer(v)) + return ret + + @staticmethod + def unpack_lib(lib: bytes) -> tuple[mesa.struct_ir3_shader_variant, mesa.struct_ir3_const_state, bytes, bytes]: + shifted = lib[ctypes.sizeof(v:=mesa.struct_ir3_shader_variant.from_buffer_copy(lib)):] + shifted = shifted[ctypes.sizeof(cs:=mesa.struct_ir3_const_state.from_buffer_copy(shifted)):] + return v, cs, shifted[:v.imm_state.count * 4], shifted[v.imm_state.count * 4:] + + def disassemble(self, lib: bytes): disas_adreno(self.unpack_lib(lib)[3], self.dev_id.gpu_id)