Mesa NIR backend (NAK/LLVMpipe) (#12089)

* nak works

* TestOps::test_add works

* testop has no crashes

* fix bool casts

* fix typo

* add disassemble

* RANGE and locals/regs

* simplify NAKCompiler

* disass cleanup

* cleanup nir codegen

* almost all tests passing

* cleanup notes in extra/

* old notes

* only import nak if NIR=1

* fix new SPECIAL syntax

* fix local/shared memory

* more tests passing

* add DEFINE_VAR support

* llvmpipe kinda works

* diskcache

* some mypy stuff

* lvp passing test_ops.py

* fix imports

* actually fix imports

* remove 'stdout'

* fix llvm import

* fix mypy issues

* nicer errors

* simpler test_dtype skips

* test lvp in CI

* fix github action syntax

* fix more actions typos

* switch to mesa 25.1.0

* diskcache_put

* better generation for lvp nir_options

* b64encode shader blobs

* Revert diskcache changes

This reverts commits 930fa3de8a and 8428c694b3.

* general cleanup

* better error messages

* fix llvm import

* fix windows tests

* link with libm and libgcc_s

* fix some errors

* don't check for 'float4'

* NIR uses pointer arithmetic

* use tinymesa

* bump tinymesa

* bump tinymesa again

* update lvp nir_options

* print nir shader with DEBUG

* simplify LVPCompiler

* more tests

* "gated" STORE

* NAK is cacheable

* more tests

* all tests pass locally for NAK

* test autogen in CI

* autogen deps

* more deps

* fix uop_gc

* fix macos

* mypy

* save 2 lines

* save two more lines

* save 1 line

* save 4 lines

* save more lines

* Revert "save more lines"

This reverts commit dd3a720c5a.

* save more lines

* fix LVP on windows

* refactor

* reorganize some code

* refactor lib_gpu

* move LVP check

* out of order loads

* remove support.mesa

* bump tinymesa version

* simplify LVP jit

* macos

* macos ci

* shell: bash

* testing

* more testing

* compute brew prefix

* stupid typo

* actually fix

* lib

* stdout on macos

* inline gallivm_compile_module

* Revert "inline gallivm_compile_module"

This reverts commit b65983b151.

* elf macos

* semicolon

* inherit from CPULLVMCompiler

* ruff

* disas test

* fix libm linking

* default is fine actually

* arm works

* add elf loader link test

* fix NAK beam

* pylint is too smart by half

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: nimlgen <138685161+nimlgen@users.noreply.github.com>
Authored by Christopher Milan, 2025-10-15 05:38:33 -04:00, committed by GitHub
parent f0268d13f6, commit 0aabc1e938
23 changed files with 20483 additions and 104 deletions


@@ -41,6 +41,10 @@ inputs:
description: "Install LLVM?"
required: false
default: 'false'
mesa:
description: "Install mesa"
required: false
default: 'false'
runs:
using: "composite"
steps:
@@ -289,3 +293,13 @@ runs:
if: inputs.llvm == 'true' && runner.os == 'macOS'
shell: bash
run: brew install llvm@20
# **** mesa ****
- name: Install mesa (linux)
if: inputs.mesa == 'true' && runner.os == 'Linux'
shell: bash
run: sudo curl -L https://github.com/sirhcm/tinymesa/releases/download/tinymesa-32dc66c/libtinymesa_cpu-mesa-25.2.4-linux-amd64.so -o /usr/lib/libtinymesa_cpu.so
- name: Install mesa (macOS)
if: inputs.mesa == 'true' && runner.os == 'macOS'
shell: bash
run: brew install sirhcm/tinymesa/tinymesa
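
For a local setup outside CI, a quick check that the library landed where the generated bindings will look for it (a minimal sketch in Python; assumes the Linux install step above, with MESA_PATH as the override):

import ctypes, os
# default Linux location used by the curl step above; the autogen loader also honors MESA_PATH
# and, on macOS, falls back to `brew --prefix tinymesa`
ctypes.CDLL(os.getenv("MESA_PATH", "/usr/lib") + "/libtinymesa_cpu.so")  # raises OSError if the install step was skipped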


@@ -36,8 +36,9 @@ jobs:
cuda: 'true'
webgpu: 'true'
llvm: 'true'
pydeps: 'pyyaml mako'
- name: Install autogen support packages
run: sudo apt-get install -y --no-install-recommends llvm-14-dev libclang-14-dev
run: sudo apt-get install -y --no-install-recommends llvm-14-dev libclang-14-dev llvm-20-dev
- name: Verify OpenCL autogen
run: |
cp tinygrad/runtime/autogen/opencl.py /tmp/opencl.py.bak
@@ -89,3 +90,8 @@ jobs:
cp tinygrad/runtime/autogen/llvm.py /tmp/llvm.py.bak
./autogen_stubs.sh llvm
diff /tmp/llvm.py.bak tinygrad/runtime/autogen/llvm.py
- name: Verify mesa autogen
run: |
cp tinygrad/runtime/autogen/mesa.py /tmp/mesa.py.bak
./autogen_stubs.sh mesa
diff /tmp/mesa.py.bak tinygrad/runtime/autogen/mesa.py


@@ -677,7 +677,7 @@ jobs:
strategy:
fail-fast: false
matrix:
backend: [llvm, cpu, opencl]
backend: [llvm, cpu, opencl, lvp]
name: Linux (${{ matrix.backend }})
runs-on: ubuntu-22.04
@@ -691,9 +691,10 @@ jobs:
key: ${{ matrix.backend }}-minimal
deps: testing_minimal
opencl: ${{ matrix.backend == 'opencl' && 'true' }}
llvm: ${{ matrix.backend == 'llvm' && 'true' }}
llvm: ${{ matrix.backend == 'llvm' || matrix.backend == 'lvp' }}
mesa: ${{ matrix.backend == 'lvp' && 'true' }}
- name: Set env
run: printf "${{ matrix.backend == 'llvm' && 'CPU=1\nCPU_LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_LLVM=0\nCPU_COUNT=2' || matrix.backend == 'opencl' && 'CL=1' }}" >> $GITHUB_ENV
run: printf "${{ matrix.backend == 'llvm' && 'CPU=1\nCPU_LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_LLVM=0\nCPU_COUNT=2' || matrix.backend == 'opencl' && 'CL=1' || matrix.backend == 'lvp' && 'CPU=1\nCPU_LVP=1' }}" >> $GITHUB_ENV
- name: Check Device.DEFAULT and print some source
run: |
python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['CPU','CL'], Device.DEFAULT"
@@ -895,7 +896,7 @@ jobs:
strategy:
fail-fast: false
matrix:
backend: [metal, llvm, cpu]
backend: [metal, llvm, cpu, lvp]
name: MacOS (${{ matrix.backend }})
runs-on: macos-15
timeout-minutes: 20
@@ -908,12 +909,13 @@ jobs:
key: macos-${{ matrix.backend }}-minimal
deps: testing_minimal
pydeps: "capstone"
llvm: ${{ matrix.backend == 'llvm' && 'true' }}
llvm: ${{ matrix.backend == 'llvm' || matrix.backend == 'lvp' }}
mesa: ${{ matrix.backend == 'lvp' && 'true' }}
- name: Set env
run: printf "${{ matrix.backend == 'llvm' && 'CPU=1\nCPU_LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_LLVM=0\nCPU_COUNT=2' || matrix.backend == 'metal' && 'METAL=1'}}" >> $GITHUB_ENV
run: printf "${{ matrix.backend == 'llvm' && 'CPU=1\nCPU_LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_LLVM=0\nCPU_COUNT=2' || matrix.backend == 'metal' && 'METAL=1' || matrix.backend == 'lvp' && 'CPU=1\nCPU_LVP=1' }}" >> $GITHUB_ENV
- name: Check Device.DEFAULT and print some source
run: |
python -c "from tinygrad import Device; assert Device.DEFAULT == {'LLVM':'CPU'}.get(x:='${{ matrix.backend }}'.upper(), x), Device.DEFAULT"
python -c "from tinygrad import Device; assert Device.DEFAULT == {'LLVM':'CPU','LVP':'CPU'}.get(x:='${{ matrix.backend }}'.upper(), x), Device.DEFAULT"
DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
- name: Run pytest (${{ matrix.backend }})
run: python3 -m pytest -n=auto test/ --ignore=test/models --ignore=test/unit --durations=20
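
The backend selection the matrix performs can be reproduced locally by exporting the same environment before importing tinygrad (a sketch; assumes this branch plus an installed libtinymesa):

import os
os.environ["CPU"], os.environ["CPU_LVP"] = "1", "1"  # same env as the lvp matrix entry
from tinygrad import Tensor, Device
from tinygrad.renderer.nir import LVPRenderer
assert Device.DEFAULT == "CPU"
assert isinstance(Device[Device.DEFAULT].renderer, LVPRenderer)  # CPU_LVP=1 picks the LVP renderer
print((Tensor([1.0, 2.0]) + Tensor([3.0, 4.0])).tolist())  # [4.0, 6.0], JITed through llvmpipe's gallivm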


@@ -461,6 +461,85 @@ generate_libusb() {
python3 -c "import tinygrad.runtime.autogen.libusb"
}
generate_mesa() {
MESA_TAG="mesa-25.2.4"
MESA_SRC=/tmp/mesa-$MESA_TAG
TINYMESA_TAG=tinymesa-32dc66c
TINYMESA_DIR=/tmp/tinymesa-$MESA_TAG-$TINYMESA_TAG/
TINYMESA_SO=$TINYMESA_DIR/libtinymesa_cpu.so
if [ ! -d "$MESA_SRC" ]; then
git clone --depth 1 --branch $MESA_TAG https://gitlab.freedesktop.org/mesa/mesa.git $MESA_SRC
pushd .
cd $MESA_SRC
git reset --hard $MESA_COMMIT_HASH
# clang 14 doesn't support packed enums
sed -i "s/enum \w\+ \(\w\+\);$/uint8_t \1;/" $MESA_SRC/src/nouveau/headers/nv_device_info.h
sed -i "s/enum \w\+ \(\w\+\);$/uint8_t \1;/" $MESA_SRC/src/nouveau/compiler/nak.h
sed -i "s/nir_instr_type \(\w\+\);/uint8_t \1;/" $MESA_SRC/src/compiler/nir/nir.h
mkdir -p gen/util/format
python3 src/util/format/u_format_table.py src/util/format/u_format.yaml --enums > gen/util/format/u_format_gen.h
python3 src/compiler/nir/nir_opcodes_h.py > gen/nir_opcodes.h
python3 src/compiler/nir/nir_intrinsics_h.py --outdir gen
python3 src/compiler/nir/nir_intrinsics_indices_h.py --outdir gen
python3 src/compiler/nir/nir_builder_opcodes_h.py > gen/nir_builder_opcodes.h
python3 src/compiler/nir/nir_intrinsics_h.py --outdir gen
python3 src/compiler/builtin_types_h.py gen/builtin_types.h
popd
fi
if [ ! -d "$TINYMESA_DIR" ]; then
mkdir $TINYMESA_DIR
curl -L https://github.com/sirhcm/tinymesa/releases/download/$TINYMESA_TAG/libtinymesa_cpu-$MESA_TAG-linux-amd64.so -o $TINYMESA_SO
fi
clang2py -k cdefstu \
$MESA_SRC/src/compiler/nir/nir.h \
$MESA_SRC/src/compiler/nir/nir_builder.h \
$MESA_SRC/src/compiler/nir/nir_shader_compiler_options.h \
$MESA_SRC/src/compiler/nir/nir_serialize.h \
$MESA_SRC/gen/nir_intrinsics.h \
$MESA_SRC/src/nouveau/headers/nv_device_info.h \
$MESA_SRC/src/nouveau/compiler/nak.h \
$MESA_SRC/src/gallium/auxiliary/gallivm/lp_bld.h \
$MESA_SRC/src/gallium/auxiliary/gallivm/lp_bld_passmgr.h \
$MESA_SRC/src/gallium/auxiliary/gallivm/lp_bld_misc.h \
$MESA_SRC/src/gallium/auxiliary/gallivm/lp_bld_type.h \
$MESA_SRC/src/gallium/auxiliary/gallivm/lp_bld_init.h \
$MESA_SRC/src/gallium/auxiliary/gallivm/lp_bld_nir.h \
$MESA_SRC/src/gallium/auxiliary/gallivm/lp_bld_struct.h \
$MESA_SRC/src/gallium/auxiliary/gallivm/lp_bld_jit_types.h \
$MESA_SRC/src/gallium/auxiliary/gallivm/lp_bld_flow.h \
$MESA_SRC/src/gallium/auxiliary/gallivm/lp_bld_const.h \
$MESA_SRC/src/compiler/glsl_types.h \
$MESA_SRC/src/util/blob.h \
$MESA_SRC/src/util/ralloc.h \
--clang-args="-DHAVE_ENDIAN_H -DHAVE_STRUCT_TIMESPEC -DHAVE_PTHREAD -I$MESA_SRC/src -I$MESA_SRC/include -I$MESA_SRC/gen -I$MESA_SRC/src/compiler/nir -I$MESA_SRC/src/gallium/auxiliary -I$MESA_SRC/src/gallium/include -I$(llvm-config-20 --includedir)" \
-l $TINYMESA_SO \
-o $BASE/mesa.py
LVP_NIR_OPTIONS=$(./extra/mesa/lvp_nir_options.sh $MESA_SRC)
fixup $BASE/mesa.py
patch_dlopen $BASE/mesa.py tinymesa_cpu "(BASE:=os.getenv('MESA_PATH', f\"/usr{'/local/' if helpers.OSX else '/'}lib\"))+'/libtinymesa_cpu'+(EXT:='.dylib' if helpers.OSX else '.so')" "f'{BASE}/libtinymesa{EXT}'" "f'{brew_prefix()}/lib/libtinymesa_cpu.dylib'"
echo "lvp_nir_options = gzip.decompress(base64.b64decode('$LVP_NIR_OPTIONS'))" >> $BASE/mesa.py
cat <<EOF | sed -i "/import ctypes.*/r /dev/stdin" $BASE/mesa.py
def brew_prefix():
try: return subprocess.check_output(['brew', '--prefix', 'tinymesa']).decode().strip()
except Exception: return ''
EOF
sed -i "/in_dll/s/.*/try: &\nexcept AttributeError: pass/" $BASE/mesa.py
sed -i "s/import ctypes/import ctypes, ctypes.util, os, gzip, base64, subprocess, tinygrad.helpers as helpers/" $BASE/mesa.py
sed -i "s/ctypes.CDLL('.\+')/(dll := _try_dlopen_tinymesa_cpu())/" $BASE/mesa.py
echo "def __getattr__(nm): raise AttributeError() if dll else FileNotFoundError(f'libtinymesa not found (MESA_PATH={BASE}). See https://github.com/sirhcm/tinymesa ($TINYMESA_TAG, $MESA_TAG)')" >> $BASE/mesa.py
sed -i "s/ctypes.glsl_base_type/glsl_base_type/" $BASE/mesa.py
# bitfield bug in clang2py
sed -i "s/('fp_fast_math', ctypes.c_bool, 9)/('fp_fast_math', ctypes.c_uint32, 9)/" $BASE/mesa.py
sed -i "s/('\(\w\+\)', pipe_shader_type, 8)/('\1', ctypes.c_ubyte)/" $BASE/mesa.py
sed -i "s/\([0-9]\+\)()/\1/" $BASE/mesa.py
sed -i "s/\(struct_nir_builder._pack_\) = 1/\1 = 0/" $BASE/mesa.py
python3 -c "import tinygrad.runtime.autogen.mesa"
}
if [ "$1" == "opencl" ]; then generate_opencl
elif [ "$1" == "hip" ]; then generate_hip
elif [ "$1" == "comgr" ]; then generate_comgr
@@ -484,6 +563,7 @@ elif [ "$1" == "pci" ]; then generate_pci
elif [ "$1" == "vfio" ]; then generate_vfio
elif [ "$1" == "webgpu" ]; then generate_webgpu
elif [ "$1" == "libusb" ]; then generate_libusb
elif [ "$1" == "all" ]; then generate_opencl; generate_hip; generate_comgr; generate_cuda; generate_nvrtc; generate_hsa; generate_kfd; generate_nv; generate_amd; generate_io_uring; generate_libc; generate_am; generate_webgpu
elif [ "$1" == "mesa" ]; then generate_mesa
elif [ "$1" == "all" ]; then generate_opencl; generate_hip; generate_comgr; generate_cuda; generate_nvrtc; generate_hsa; generate_kfd; generate_nv; generate_amd; generate_io_uring; generate_libc; generate_am; generate_webgpu; generate_mesa
else echo "usage: $0 <type>"
fi
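
Once generate_mesa has run, the bindings can be exercised directly, mirroring the import check on its last line (assumes libtinymesa_cpu is installed where patch_dlopen looks):

import tinygrad.runtime.autogen.mesa as mesa
mesa.glsl_type_singleton_init_or_ref()  # resolves a real symbol from libtinymesa_cpu
mesa.glsl_type_singleton_decref()
print(len(mesa.lvp_nir_options), "bytes of baked llvmpipe nir_options")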

extra/mesa/lvp_nir_options.sh Executable file

@@ -0,0 +1,23 @@
#!/bin/sh
if [ "$#" -ne 1 ] || ! [ -d $1 ]; then
echo "usage: $0 MESA_PREFIX"
exit 1
fi
TMP=$(mktemp)
trap 'rm -f "$TMP"' EXIT
(
cat <<EOF
#define HAVE_ENDIAN_H
#define HAVE_STRUCT_TIMESPEC
#define HAVE_PTHREAD
#include <unistd.h>
#include "nir_shader_compiler_options.h"
#include "compiler/shader_enums.h"
EOF
sed -n '/struct nir_shader_compiler_options/,/^}/{p;/^}/q}' $1/src/gallium/drivers/llvmpipe/lp_screen.c
echo "int main(void) { write(1, &gallivm_nir_options, sizeof(gallivm_nir_options)); }"
) | cc -x c -o $TMP - -I$1/src/compiler/nir -I$1/src -I$1/include && $TMP | gzip | base64 -w0
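
The base64 this script prints is baked into the autogen module as mesa.lvp_nir_options (see the echo in generate_mesa above); on the Python side it is simply the raw bytes of the C struct, roughly:

import ctypes
import tinygrad.runtime.autogen.mesa as mesa
# the blob was written with write(1, &gallivm_nir_options, sizeof(...)), so its layout is assumed
# to match the clang2py-generated nir_shader_compiler_options
opts = mesa.nir_shader_compiler_options.from_buffer_copy(mesa.lvp_nir_options)
print(ctypes.sizeof(opts), len(mesa.lvp_nir_options))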


@@ -1,7 +1,7 @@
import unittest, io
from contextlib import redirect_stdout
from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import OSX, CPU_LLVM
from tinygrad.helpers import OSX, CPU_LLVM, CPU_LVP
from tinygrad.engine.realize import lower_schedule
from tinygrad.device import is_dtype_supported
from tinygrad.engine.realize import get_program
@@ -19,7 +19,7 @@ class TestCompileFailures(unittest.TestCase):
class TestDisassembly(unittest.TestCase):
# TODO: fails on llvm. llvm.LLVMGetHostCPUName() returns "generic"
@unittest.skipUnless(Device.DEFAULT in ("CPU",) and not CPU_LLVM and OSX, "m series cpus support fp16 arithmetic")
@unittest.skipUnless(Device.DEFAULT in ("CPU",) and not (CPU_LLVM or CPU_LVP) and OSX, "m series cpus support fp16 arithmetic")
def test_float16_alu(self):
c = Tensor([1], dtype=dtypes.float16) + Tensor([1], dtype=dtypes.float16)
s = c.schedule()[-1]


@@ -6,6 +6,7 @@ from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, DEBUG, CI
from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8, _to_np_dtype, _to_torch_dtype, truncate
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer
from tinygrad import Device, Tensor, dtypes
from hypothesis import given, settings, strategies as strat
from test.helpers import rand_for_dtype
@@ -102,7 +103,7 @@ class TestDType(unittest.TestCase):
))
@unittest.skipIf(Device.DEFAULT == "PYTHON", "skip for now")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "skip for now")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)), "skip for now")
def test_uint_overflow(self):
if not dtypes.is_unsigned(self.DTYPE): raise unittest.SkipTest("only for unsigned")
v = dtypes.max(self.DTYPE)
@@ -261,7 +262,7 @@ class TestFloatDType(TestDType):
class TestDoubleDType(TestDType):
DTYPE = dtypes.double
@unittest.skipIf((CI and Device.DEFAULT in {"CUDA", "NV"}) or \
isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "conversion not supported on CI CUDA and PTX") # TODO: why not?
isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)), "conversion not supported on CI CUDA, PTX, and NIR") # TODO: why not?
def test_float64_increased_precision(self):
for func in [
lambda t: t.exp(),


@@ -6,6 +6,7 @@ from tinygrad.tensor import _to_np_dtype
from tinygrad.device import is_dtype_supported
from tinygrad.runtime.ops_python import from_storage_scalar
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer
import numpy as np
import pytest
from hypothesis import assume, given, strategies as strat, settings, HealthCheck
@@ -29,8 +30,8 @@ unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.
# TODO: enable this (this is a dtype issue)
#binary_operations.append(operator.truediv)
# TODO: CI CUDA segfaults on sin, WEBGPU sin is not precise enough for large numbers
if (getenv("MOCKGPU") and Device.DEFAULT in {"NV", "CUDA"}) or Device.DEFAULT == "WEBGPU":
# TODO: CI CUDA segfaults on sin, WEBGPU and NIR sines are not precise enough for large numbers
if (getenv("MOCKGPU") and Device.DEFAULT in {"NV", "CUDA"}) or Device.DEFAULT == "WEBGPU" or isinstance(Device[Device.DEFAULT].renderer, NIRRenderer):
unary_operations.remove((Tensor.sin, np.sin))
unary_operations.remove((Tensor.cos, np.cos))
@@ -184,8 +185,8 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.int32, ht.int32, ht.float32, strat.sampled_from(integer_binary_operations), strat.sampled_from(binary_operations))
def test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.int32, dtypes.float32)
# Metal and CUDA and HIP behave differently than numpy in CI for overflows
skip_overflow = CI and Device.DEFAULT in {"AMD", "NV", "CUDA"}
# Metal and CUDA and HIP and NIR behave differently than numpy in CI for overflows
skip_overflow = (CI and Device.DEFAULT in {"AMD", "NV", "CUDA"}) or isinstance(Device[Device.DEFAULT].renderer, NIRRenderer)
@given(strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations))


@@ -26,8 +26,9 @@ import unittest
import numpy as np
import torch
from tinygrad import Tensor, dtypes, nn
from tinygrad.device import is_dtype_supported
from tinygrad.device import Device, is_dtype_supported
from tinygrad.helpers import getenv
from tinygrad.renderer.nir import NIRRenderer
MOCKGPU = getenv("MOCKGPU")
@@ -206,7 +207,8 @@ class TestUOpValidationIssue(unittest.TestCase):
# these fail with UOp verification error.
# we want more of these with diverse errors!
@unittest.skipIf((not is_dtype_supported(dtypes.long)) or MOCKGPU, "hangs gpuocelot")
@unittest.skipIf((not is_dtype_supported(dtypes.long)) or MOCKGPU or isinstance(Device[Device.DEFAULT].renderer, NIRRenderer),
"hangs gpuocelot, NIR cannot render")
def test_tensor_index_overflow(self):
val = Tensor([1])
big = val.expand(2**31 + 3)


@@ -2,7 +2,7 @@ import time, math, unittest, functools, platform, warnings
import numpy as np
from typing import List, Callable
import torch
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, CPU_LLVM, AMD_LLVM
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, CPU_LLVM, CPU_LVP, AMD_LLVM
from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _to_np_dtype
from tinygrad.device import is_dtype_supported
@@ -698,8 +698,8 @@ class TestOps(unittest.TestCase):
def test_pow_zero_tensor(self):
helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [0.0]])
# TODO: fix WEBGPU
if Device.DEFAULT != "WEBGPU":
# TODO: fix WEBGPU and LVP
if Device.DEFAULT != "WEBGPU" and not CPU_LVP:
helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [0.3]])
helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [-0.3]])
def test_pow_zero_const(self):
@@ -830,6 +830,7 @@ class TestOps(unittest.TestCase):
self.assertEqual(a, b)
self.assertEqual(Tensor(-1).contiguous().idiv(4).item(), 0) # NOTE this is trunc-div behaviour
@unittest.skipIf(getenv("NV_NAK"), "MUFU.SIN is not accurate enough")
def test_sin(self):
helper_test_op([(45,65)], lambda x: x.sin())
helper_test_op([()], lambda x: x.sin())
@@ -839,6 +840,7 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x: x.sin(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]],
atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3)
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and platform.system() == "Windows", "Not accurate enough with DirectX backend")
@unittest.skipIf(getenv("NV_NAK"), "MUFU.SIN is not accurate enough")
def test_cos(self):
helper_test_op([(45,65)], lambda x: x.cos())
helper_test_op([()], lambda x: x.cos())
@@ -847,6 +849,7 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x: x.cos(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]],
atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3)
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and platform.system() == "Windows", "Not accurate enough with DirectX backend")
@unittest.skipIf(getenv("NV_NAK"), "MUFU.SIN is not accurate enough")
def test_tan(self):
# NOTE: backward has much higher diff with input close to pi/2 and -pi/2
helper_test_op([(45,65)], lambda x: x.tan(), low=-1.5, high=1.5)


@@ -1,6 +1,6 @@
import unittest
from tinygrad import Tensor, Device
from tinygrad.helpers import CPU_LLVM
from tinygrad.helpers import CPU_LLVM, CPU_LVP
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.engine.realize import get_program
@@ -12,7 +12,7 @@ class TestOpts(unittest.TestCase):
out = (a+b).contiguous(arg=opts)
s = out.schedule()
self.assertEqual(s[-1].ast.arg.opts_to_apply, opts)
if Device.DEFAULT in {"CPU", "CL", "METAL"} and not CPU_LLVM:
if Device.DEFAULT in {"CPU", "CL", "METAL"} and not CPU_LLVM and not CPU_LVP:
prg = get_program(s[-1].ast)
self.assertIn('float4', prg.src)


@@ -6,6 +6,7 @@ from tinygrad.helpers import getenv, CI, OSX
from tinygrad.device import is_dtype_supported
from tinygrad.engine.realize import lower_schedule, CompiledRunner
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer
from test.helpers import not_support_multi_device
import numpy as np
@@ -100,7 +101,7 @@ class TestRandomness(unittest.TestCase):
np.testing.assert_allclose(jr, r)
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "fails with PTX")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer)), "PTX and NIR use pointer arithmetic")
def test_threefry_doesnt_use_long(self):
for (_,ei) in lower_schedule(Tensor.rand(20).schedule()):
if isinstance(ei.prg, CompiledRunner):


@@ -9,6 +9,7 @@ from hypothesis import given, settings, strategies as strat
from tinygrad.device import is_dtype_supported
from tinygrad.uop.ops import Ops, UOp
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer
from tinygrad.codegen import full_rewrite
from tinygrad.dtype import DType
@@ -871,7 +872,8 @@ class TestIdxUpcast(unittest.TestCase):
store = next(uop for uop in uops if uop.op is Ops.STORE)
assert store.op is Ops.STORE
idx = self._find_op(store, Ops.INDEX)
if idx is not None: # PTX turns Ops.INDEX into pointer arithmetic earlier than cstyle, plus it's already cast to int64
# PTX and NIR turn Ops.INDEX into pointer arithmetic earlier than cstyle, plus it's already cast to int64
if not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)):
assert idx.op is Ops.INDEX
idx_val = idx.src[1]
assert idx_val.dtype is dtype
@@ -895,7 +897,7 @@ class TestIdxUpcast(unittest.TestCase):
def test_regular_sym(self):
self.do_op_then_assert(dtypes.int, 2048, 2048, UOp.variable("dim3", 1, 64).bind(32))
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX always convert Ops.INDEX to int64")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)), "PTX and NIR always converts Ops.INDEX to int64")
def test_symfold(self):
# This would cause an overflow, but after sym fold it's within int32
a = Tensor.arange(65535)


@@ -149,6 +149,7 @@ class TestTranscendentalVectorized(unittest.TestCase):
for vec_size in [1,2,3,4,5,127,128]: self._test_vectorized_op(Tensor.log2, np.log2, (0.001, 200), vec_size)
@unittest.skipIf(getenv("DSP"), "requires int division")
@unittest.skipIf(getenv("NV_NAK"), "MUFU.SIN is not accurate enough")
def test_sin_vectorized(self):
for vec_size in [1,2,3,4,5,127,128]: self._test_vectorized_op(Tensor.sin, np.sin, (-100, 100), vec_size)


@@ -24,6 +24,15 @@ class TestElfLoader(unittest.TestCase):
'''
with self.assertRaisesRegex(RuntimeError, 'evil_external_function'):
ClangJITCompiler().compile(src)
def test_link(self):
src = '''
float powf(float, float); // from libm
float test(float x, float y) { return powf(x, y); }
'''
args = ('-x', 'c', '-c', '-target', f'{platform.machine()}-none-unknown-elf', '-march=native', '-fPIC', '-O2', '-ffreestanding', '-nostdlib')
obj = subprocess.check_output(('clang',) + args + ('-', '-o', '-'), input=src.encode())
with self.assertRaisesRegex(RuntimeError, 'powf'): elf_loader(obj)
elf_loader(obj, link_libs=['m'])
if __name__ == '__main__':
unittest.main()


@@ -327,8 +327,8 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
if device is None: device = Device.DEFAULT
if dtype == dtypes.bfloat16:
if device == "METAL": return not CI
if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX")
if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"}
if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX") and not getenv("NV_NAK")
if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and not getenv("CPU_LVP")
return device in {"AMD", "PYTHON", "NULL"}
if dtype in dtypes.fp8s: return device in {"PYTHON", "NULL"}
if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,


@@ -155,7 +155,7 @@ ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), Conte
FUSE_ATTENTION = ContextVar("FUSE_ATTENTION", 0)
EMULATE = ContextVar("EMULATE", "")
CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else (os.cpu_count() or 1)))
CPU_LLVM, AMD_LLVM = ContextVar("CPU_LLVM", 0), ContextVar("AMD_LLVM", 1)
CPU_LLVM, CPU_LVP, AMD_LLVM = ContextVar("CPU_LLVM", 0), ContextVar("CPU_LVP", 0), ContextVar("AMD_LLVM", 1)
VIZ = PROFILE = ContextVar("VIZ", 0)
SPEC = ContextVar("SPEC", 0)
# TODO: disable by default due to speed

tinygrad/renderer/nir.py Normal file

@@ -0,0 +1,237 @@
from typing import Callable, cast
from tinygrad.dtype import AddrSpace, DType, PtrDType, dtypes
from tinygrad.helpers import DEBUG, OSX, unwrap
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
import tinygrad.runtime.autogen.mesa as mesa
import base64, ctypes, ctypes.util, struct, functools, inspect
def g(s:str): return getattr(mesa, s)
def nsrc(d:mesa.nir_def) -> mesa.nir_src: return mesa.nir_src(ssa=ctypes.pointer(d))
# this is a ridiculous hack, but I can't find a better way to grab the glsl_type objects
glsl_base = {**{d:g(f"GLSL_TYPE_{'U' if d in dtypes.uints else ''}INT{d.itemsize*8 if d.itemsize != 4 else ''}") for d in dtypes.ints},
**{getattr(dtypes,d):g(f"GLSL_TYPE_{d.upper()}") for d in ['double', 'float', 'float16']}, dtypes.bool: mesa.GLSL_TYPE_UINT8}
def glsl_type(t:DType) -> mesa.struct_glsl_type:
if isinstance(t, PtrDType): return mesa.glsl_array_type(glsl_type(t.base), t.size, 0).contents
return mesa.glsl_get_base_glsl_type(mesa.glsl_type(base_type=glsl_base[t])).contents
# alu ops, aop[<dtype>][<op>]
u_aop = { Ops.ADD: "iadd", Ops.MUL: "imul", Ops.IDIV: "udiv", Ops.MOD: "umod", Ops.CMPLT: "ult", Ops.CMPNE: "ine", Ops.CMPEQ: "ieq", Ops.OR: "ior",
Ops.AND: "iand", Ops.XOR: "ixor", Ops.WHERE: "bcsel", Ops.MAX: "umax"}
s_aop = {**u_aop, Ops.CMPLT: "ilt", Ops.IDIV: "idiv", Ops.MOD: "irem", Ops.MAX: "imax"}
f_aop = { Ops.ADD: "fadd", Ops.MUL: "fmul", Ops.CMPLT: "flt", Ops.CMPNE: "fneu", Ops.CMPEQ: "feq", Ops.FDIV: "fdiv", Ops.RECIP: "frcp",
Ops.MAX: "fmax", Ops.TRUNC: "ftrunc", Ops.SIN: "fsin", Ops.EXP2: "fexp2", Ops.LOG2: "flog2"}
aop = {**{x:u_aop for x in (dtypes.bool,)+dtypes.uints}, **{x:s_aop for x in dtypes.sints}, **{x:f_aop for x in dtypes.floats}}
def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else ("i" if t in dtypes.ints else ("f" if t in dtypes.floats else "b"))
def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def:
if isinstance(it, PtrDType) and ot == dtypes.long: return src
if ot == dtypes.bool: return nalu(b, c(it, False)+'ne'+('u' if c(it) == 'f' else ''), src, nimm(b, 0, it))
return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.itemsize*8}", src)
def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable):
nif = mesa.nir_push_if(b, cond)
t = then_fn()
mesa.nir_push_else(b, nif)
e = else_fn()
mesa.nir_pop_if(b, nif)
return t, e
def nalu(b:mesa.nir_builder, op:str, *srcs:mesa.nir_def) -> mesa.nir_def: return g(f"nir_build_alu{len(srcs)}")(b, g(f"nir_op_{op}"), *srcs).contents
def nir_instr(nc=1, bs=lambda: None, intrins=None, srcs=None, has_def=True, df=None, also=lambda: None, **contents):
def dec(f:Callable):
@functools.wraps(f)
def wrapper(*args, **kwargs) -> mesa.nir_def:
(ba:=inspect.signature(f).bind(*args, **kwargs)).apply_defaults()
def go(g): return g(**{nm: ba.arguments[nm] for nm in inspect.signature(g).parameters}) if callable(g) else g
instr = f(*args, **kwargs)
if has_def: mesa.nir_def_init(instr.contents.instr, getattr(instr.contents, "def"), go(nc), go(bs))
for k, v in go(intrins or {}).items():
idx = mesa.nir_intrinsic_infos[instr.contents.intrinsic].index_map[g(f"NIR_INTRINSIC_{k}")]
assert idx > 0
instr.contents.const_index[idx - 1] = go(v)
for i, src in enumerate(go(srcs or [])): ctypes.cast(instr.contents.src, ctypes.POINTER(mesa.nir_src))[i] = go(src)
for k,v in {k:vcomp for k,v in contents.items() if (vcomp:=go(v)) is not None}.items(): setattr(instr.contents, k, go(v))
mesa.nir_builder_instr_insert(ba.arguments['b'], instr.contents.instr)
go(also)
return getattr(instr.contents, "def") if has_def else (mesa.nir_def() if df is None else go(df))
return wrapper
return dec
@nir_instr(nc=1, bs=lambda src: src.bit_size, exact=lambda b:b.exact, fp_fast_math=lambda b:b.fp_fast_math)
def nchannel(b:mesa.nir_builder, src:mesa.nir_def, c:int):
alu_src = mesa.nir_alu_src(src=nsrc(src))
alu_src.swizzle[0] = c
mov = mesa.nir_alu_instr_create(b.shader, mesa.nir_op_mov)
ctypes.cast(mov.contents.src, ctypes.POINTER(mesa.nir_alu_src))[0] = alu_src
return mov
@nir_instr(nc=1, bs=lambda dtype: 1 if dtype == dtypes.bool else dtype.itemsize * 8)
def nimm(b:mesa.nir_builder, x, dtype:DType) -> mesa.nir_def:
instr = mesa.nir_load_const_instr_create(b.shader, 1, 1 if dtype == dtypes.bool else dtype.itemsize * 8)
struct.pack_into(unwrap(dtype.fmt), (ctypes.c_ubyte * dtype.itemsize).from_address(ctypes.addressof(instr.contents.value)), 0, x)
return instr
deref_var = nir_instr(nc=1, bs=32, modes=lambda var:var.data.mode, type=lambda var:var.type, var=lambda var:ctypes.pointer(var))( # pylint: disable=W0108
lambda b, var: mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_var))
def iointr(space): return {"ALIGN_MUL":lambda dtype:dtype.itemsize} if space != AddrSpace.REG else {}
def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if space == AddrSpace.LOCAL else 'deref')
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1, **iointr(space)},
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.itemsize*8//dtype.count, num_components=lambda dtype:dtype.count,
intrins=lambda space:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}), **iointr(space)}, srcs=lambda addr: [nsrc(addr)])(
lambda b, space, addr, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
ngid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_id))
nlid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_local_invocation_id))
nbarrier = nir_instr(has_def=False, intrins={"EXECUTION_SCOPE":mesa.SCOPE_WORKGROUP})(
lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_barrier))
@nir_instr(has_def=False, target=lambda tgt:tgt and ctypes.pointer(tgt), condition=lambda cond:cond and nsrc(cond),
else_target=lambda else_tgt: else_tgt and ctypes.pointer(else_tgt))
def njump(b:mesa.nir_builder, typ, tgt=None, cond=None, else_tgt=None): return mesa.nir_jump_instr_create(b.shader, typ)
def if_phi(b:mesa.nir_builder, cond, then_fn, else_fn): return mesa.nir_if_phi(b, *nif(b, cond, then_fn, else_fn)).contents
def nidx(b:mesa.nir_builder, buf, off, dtype, gate=None) -> mesa.nir_def:
@nir_instr(nc=1, bs=32, modes=lambda buf: buf.data.mode, type=lambda buf: mesa.glsl_get_array_element(buf.type))
def reg(b, buf):
deref = mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_array)
deref.contents.parent, deref.contents.arr.index = nsrc(deref_var(b, buf)), nsrc(off)
return deref
f = (functools.partial(reg, b, buf) if dtype.addrspace == AddrSpace.REG else
lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, dtype.itemsize, dtypes.long))))
return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
class NIRRenderer(Renderer):
suffix = "NAK"
global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
code_for_op = {**{k:lambda:None for k in u_aop.keys()}, **{k:lambda:None for k in s_aop.keys()}, **{k:lambda:None for k in f_aop.keys()}}
extra_matcher = PatternMatcher([
# handle negative unsigned CONST
(UPat.cvar("x", dtypes.uints), lambda x: UOp(Ops.CONST, dtype=x.dtype, arg=x.dtype.max+x.arg+1) if x.arg < 0 else None),
# from ptx
(UPat.var('x', dtype=dtypes.bool)<UPat.var('y'), lambda x,y: (x^True)&y),
# load/store bool -> uint8
(UPat(Ops.LOAD, dtypes.bool, name="x"),
lambda x: x.replace(dtype=dtypes.uint8, src=x.src[0:1]+((x.src[1].cast(dtypes.uint8),) if len(x.src)>=2 else ())+x.src[2:]).cast(dtypes.bool)),
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
lambda x: x.replace(src=x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])),
# load/store use pointer arithmetic, and the cast does nothing
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True, name="x"),
lambda x,buf,off: x.replace(src=(buf,off.cast(dtypes.long))+x.src[2:]) if buf.dtype.addrspace != AddrSpace.REG and off.op != Ops.CAST else None),
(UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) or x.src[0].dtype == dtypes.void else None),
])
def_rewrite = PatternMatcher([
(UPat(Ops.CONST, name="x"), lambda ctx,x: nimm(ctx.b, x.arg, x.dtype)),
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx,x: ctx.param(ctx.b, x.dtype, 8)),
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x.dtype, 4)),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, ngid(ctx.b) if x.arg[0] == 'g' else nlid(ctx.b), int(x.arg[-1]))),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val")), allow_any_len=True, name="x"),
lambda ctx,x,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"), UPat.var("gate"))), UPat.var("alt")), allow_any_len=True, name="x"),
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))),), allow_any_len=True, name="x"),
lambda ctx,x,buf,off: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x.dtype)),
(UPat(Ops.VECTORIZE, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.dtype.count}", *[ctx.r[src] for src in x.src])),
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])),
(UPat(Ops.CAST, name="x"), lambda ctx,x: ncast(ctx.b, ctx.r[x.src[0]], x.src[0].dtype, x.dtype)),
(UPat(Ops.BITCAST, src=(UPat.var("a"),), allow_any_len=True), lambda ctx,a: ctx.r[a]),
(UPat(Ops.GEP, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: nchannel(ctx.b, ctx.r[a], x.arg[0])),
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x:mesa.nir_local_variable_create(ctx.b.impl, glsl_type(x.dtype), f"acc{x.arg[0]}".encode()).contents),
(UPat(Ops.BARRIER), lambda ctx: nbarrier(ctx.b)),
(UPat(Ops.IF, name="x"), lambda ctx,x: mesa.nir_push_if(ctx.b, ctx.r[x.src[0]])),
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: (lambda _: mesa.nir_def())(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]])))
])
def __init__(self): mesa.glsl_type_singleton_init_or_ref()
def __del__(self):
try: mesa.glsl_type_singleton_decref()
except FileNotFoundError: pass
@property
def nir_options(self): raise NotImplementedError("needs nir_options")
def param(self, b:mesa.nir_builder, dtype:DType, sz:int) -> mesa.nir_def: raise NotImplementedError("needs param")
def prerender(self, uops:list[UOp]):
self.b = mesa.nir_builder_init_simple_shader(mesa.MESA_SHADER_COMPUTE, mesa.nir_shader_compiler_options.from_buffer_copy(self.nir_options), None)
def render(self, uops:list[UOp]):
self.prerender(uops)
for u in [u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]: self.b.shader.contents.info.workgroup_size[int(u.arg[-1])] = u.src[0].arg
self.r, self.param_idx, ranges = {}, 0, []
for u in uops:
if u.op == Ops.NOOP or u.op == Ops.INDEX: pass
elif u.op == Ops.SINK:
if u.arg is not None: self.b.shader.contents.info.name = mesa.char_pointer_cast(u.arg.function_name)
elif u.op == Ops.DEFINE_LOCAL:
self.r[u] = nimm(self.b, self.b.shader.contents.info.shared_size, dtypes.long)
self.b.shader.contents.info.shared_size += u.dtype.nbytes()
elif u.op == Ops.RANGE:
ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{u.arg[0]}".encode()).contents))
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
mesa.nir_push_loop(self.b)
self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype)
elif u.op == Ops.ENDRANGE:
nif(self.b, nalu(self.b, "ilt", x:=nalu(self.b, "iadd", self.r[u.src[0]], nimm(self.b, 1, u.src[0].dtype)), self.r[u.src[0].src[0]]),
functools.partial(nstore, self.b, AddrSpace.REG, ranges.pop(), x, u.src[0].dtype), lambda: njump(self.b, mesa.nir_jump_break))
mesa.nir_pop_loop(self.b, None)
else:
if (d:=self.def_rewrite.rewrite(u, ctx=self)) is None: raise RuntimeError(f"failed to render {u.op} srcs {[x.dtype for x in u.src]}")
self.r[u] = cast(mesa.nir_def, d)
mesa.nir_validate_shader(self.b.shader, b"after render")
if DEBUG >= 4: mesa.nir_print_shader(self.b.shader, ctypes.POINTER(mesa.struct__IO_FILE).in_dll(ctypes.CDLL(ctypes.util.find_library('c')),
"__stdoutp" if OSX else "stdout"))
mesa.nir_serialize(blob:=mesa.struct_blob(), self.b.shader, False)
ret = base64.b64encode(ctypes.string_at(blob.data, blob.size)).decode()
mesa.ralloc_free(self.b.shader)
ctypes.CDLL(None).free(blob.data)
del self.b, self.r
return ret
class NAKRenderer(NIRRenderer):
device = "NV"
def __init__(self, dev=None, nir_options=None):
self.dev, self._nir_options = dev, nir_options
super().__init__()
def __reduce__(self): return NAKRenderer, (None, self.nir_options,)
@property
def nir_options(self):
if self._nir_options is None: self._nir_options = self.dev.compiler.nir_options
return self._nir_options
param = nir_instr(nc=1, num_components=1, bs=lambda sz:sz*8, also=lambda self,sz: setattr(self, "param_idx", self.param_idx + sz),
intrins={"ALIGN_MUL":lambda sz:sz}, srcs=lambda self,b: [nsrc(nimm(b, 0, dtypes.int)), nsrc(nimm(b, self.param_idx, dtypes.int))])(
lambda self, b, dtype, sz: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_ldc_nv))
class LVPRenderer(NIRRenderer):
device = "CPU"
has_local = False
has_shared = False
global_max = (1, 0, 0)
nir_options = mesa.lvp_nir_options
param = nir_instr(nc=1, bs=lambda sz: sz * 8, num_components=1, intrins={"ALIGN_MUL":lambda sz: sz, "RANGE":lambda self: self.param_sz},
srcs=lambda b, self: [nsrc(nimm(b, 0, dtypes.int)), nsrc(nimm(b, self.param_idx, dtypes.int))], also=lambda self, sz:
setattr(self, "param_idx", self.param_idx+sz))(lambda self, b, dtype, sz: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_ubo))
def prerender(self, uops:list[UOp]):
super().prerender(uops)
self.param_sz = sum([8 if u.op == Ops.DEFINE_GLOBAL else u.dtype.itemsize for u in uops if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR)])
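
As a reading aid for the ALU tables at the top of the file: aop[<dtype>][<Op>] selects the NIR opcode name per dtype class, e.g. signed vs unsigned compares (a tiny check, table lookups only):

from tinygrad import dtypes
from tinygrad.uop.ops import Ops
from tinygrad.renderer.nir import aop
assert aop[dtypes.uint32][Ops.CMPLT] == "ult"   # unsigned compare
assert aop[dtypes.int32][Ops.CMPLT] == "ilt"    # signed compare
assert aop[dtypes.float32][Ops.MUL] == "fmul"   # float multiply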

File diff suppressed because it is too large


@@ -1,11 +1,15 @@
from __future__ import annotations
import platform, sys, ctypes, functools, time, mmap, threading, queue
from tinygrad.helpers import from_mv, to_mv, OSX, WIN, mv_address, wait_cond, cpu_profile, suppress_finalizing, unwrap
from tinygrad.device import BufferSpec, DMACPURef
from tinygrad.helpers import from_mv, to_mv, OSX, WIN, mv_address, wait_cond, cpu_profile, suppress_finalizing, unwrap, data64_le
from tinygrad.device import BufferSpec, DMACPURef, CompilerPairT
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocatorBase, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface
from tinygrad.runtime.support.hcq import CLikeArgsState
from tinygrad.renderer.cstyle import ClangRenderer
from tinygrad.renderer.llvmir import LLVMRenderer
from tinygrad.renderer.nir import LVPRenderer
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler
from tinygrad.runtime.support.compiler_mesa import LVPCompiler
from tinygrad.runtime.support.elf import jit_loader
from tinygrad.uop.ops import sint
class CPUSignal(HCQSignal):
@@ -46,12 +50,18 @@ class CPUComputeQueue(HWQueue):
def memory_barrier(self): return self
def exec(self, prg:CPUProgram, args_state:HCQArgsState, global_size, local_size):
if isinstance(args_state, LVPArgsState):
self.bind_args_state(args_state)
return self.cmd(self._exec, prg, 1, args_state.buf.va_addr)
return self.cmd(self._exec, prg, len(args_state.bufs), *[x.va_addr for x in args_state.bufs], *args_state.vals, threads=(global_size or (1,))[0])
def wait(self, signal, value=0): return self.cmd(self._wait, signal.value_addr, value)
def timestamp(self, signal): return self.cmd(self._timestamp, signal.timestamp_addr)
def signal(self, signal, value:sint=0): return self.cmd(self._signal, signal.value_addr, value)
def _submit(self, dev): dev.tasks.put(self._q[:])
class LVPArgsState(CLikeArgsState):
def __init__(self, buf, prg, bufs, vals=()): super().__init__(buf, prg, bufs, vals, [*data64_le(buf.va_addr + 12), (len(bufs) + len(vals)) * 2])
# NOTE: MAP_JIT is added to mmap module in python 3.13
MAP_JIT = 0x0800
@@ -61,6 +71,7 @@ class CPUProgram(HCQProgram):
except OSError: pass
def __init__(self, dev, name:str, lib:bytes):
LVP = isinstance(dev.compiler, LVPCompiler)
if sys.platform == "win32": # mypy doesn't understand when WIN is used here
PAGE_EXECUTE_READWRITE, MEM_COMMIT, MEM_RESERVE = 0x40, 0x1000, 0x2000
ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_void_p
@@ -76,6 +87,7 @@ class CPUProgram(HCQProgram):
self.mem = mmap.mmap(-1, len(lib), mmap.MAP_ANON|mmap.MAP_PRIVATE|(MAP_JIT if OSX else 0), mmap.PROT_READ|mmap.PROT_WRITE|mmap.PROT_EXEC)
if OSX: unwrap(CPUProgram.rt_lib).pthread_jit_write_protect_np(False)
if LVP: lib = jit_loader(lib, base=ctypes.addressof(ctypes.c_void_p.from_buffer(self.mem)), link_libs=['m'])
self.mem.write(lib)
if OSX: unwrap(CPUProgram.rt_lib).pthread_jit_write_protect_np(True)
@@ -92,7 +104,7 @@ class CPUProgram(HCQProgram):
self.fxn = ctypes.CFUNCTYPE(None)(mv_address(self.mem))
super().__init__(HCQArgsState, dev, name, kernargs_alloc_size=0)
super().__init__(LVPArgsState if LVP else HCQArgsState, dev, name, kernargs_alloc_size=12+256 if LVP else 0)
@suppress_finalizing
def __del__(self):
@@ -123,5 +135,5 @@ class CPUDevice(HCQCompiled):
def __init__(self, device:str=""):
self.tasks:queue.Queue = queue.Queue()
CPUWorker(self, self.tasks, thread_id=0).start()
compilers = [(ClangRenderer, ClangJITCompiler), (LLVMRenderer, CPULLVMCompiler)]
compilers:list[CompilerPairT] = [(ClangRenderer, ClangJITCompiler), (LLVMRenderer, CPULLVMCompiler), (LVPRenderer, LVPCompiler)]
super().__init__(device, CPUAllocator(self), compilers, functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue)


@@ -11,10 +11,12 @@ from tinygrad.helpers import getenv, mv_address, round_up, data64, data64_le, pr
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.cstyle import NVRenderer
from tinygrad.runtime.support.compiler_cuda import CUDACompiler, PTXCompiler, NVPTXCompiler, NVCompiler
from tinygrad.runtime.autogen import nv_gpu, pci
from tinygrad.runtime.support.compiler_mesa import NAKCompiler
from tinygrad.runtime.autogen import nv_gpu, pci, mesa
from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.support.nv.nvdev import NVDev, NVMemoryManager
from tinygrad.runtime.support.system import System, PCIIfaceBase, MAP_FIXED
from tinygrad.renderer.nir import NAKRenderer
if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401 # pylint: disable=unused-import
def get_error_str(status): return f"{status}: {nv_gpu.nv_status_codes.get(status, 'Unknown error')}"
@@ -185,68 +187,69 @@ class NVCopyQueue(NVCommandQueue):
class NVArgsState(CLikeArgsState):
def __init__(self, buf:HCQBuffer, prg:NVProgram, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=()):
if MOCKGPU: prg.constbuffer_0[80:82] = [len(bufs), len(vals)]
super().__init__(buf, prg, bufs, vals=vals, prefix=prg.constbuffer_0)
if MOCKGPU: prg.cbuf_0[80:82] = [len(bufs), len(vals)]
super().__init__(buf, prg, bufs, vals=vals, prefix=prg.cbuf_0 or None)
class NVProgram(HCQProgram):
def __init__(self, dev:NVDevice, name:str, lib:bytes):
self.dev, self.name, self.lib = dev, name, lib
# For MOCKGPU, the lib is PTX code, so some values are emulated.
cbuf0_size = 0 if not MOCKGPU else 0x160
if MOCKGPU: image, sections, relocs = memoryview(bytearray(lib) + b'\x00' * (4 - len(lib)%4)).cast("I"), [], [] # type: ignore
else: image, sections, relocs = elf_loader(self.lib, force_section_align=128)
# NOTE: Ensure at least 4KB of space after the program to mitigate prefetch memory faults.
self.lib_gpu = self.dev.allocator.alloc(round_up(image.nbytes, 0x1000) + 0x1000, buf_spec:=BufferSpec(cpu_access=True))
self.prog_addr, self.prog_sz, self.regs_usage, self.shmem_usage, self.lcmem_usage = self.lib_gpu.va_addr, image.nbytes, 0, 0x400, 0
self.constbufs: dict[int, tuple[int, int]] = {0: (0, 0x160)} # dict[constbuf index, tuple[va_addr, size]]
for sh in sections:
if sh.name == f".nv.shared.{self.name}": self.shmem_usage = round_up(0x400 + sh.header.sh_size, 128)
if sh.name == f".text.{self.name}": self.prog_addr, self.prog_sz = self.lib_gpu.va_addr+sh.header.sh_addr, sh.header.sh_size
elif m:=re.match(r'\.nv\.constant(\d+)', sh.name): self.constbufs[int(m.group(1))] = (self.lib_gpu.va_addr+sh.header.sh_addr, sh.header.sh_size)
elif sh.name.startswith(".nv.info"):
for typ, param, data in self._parse_elf_info(sh):
if sh.name == f".nv.info.{name}" and param == 0xa: cbuf0_size = struct.unpack_from("IH", data)[1] # EIATTR_PARAM_CBANK
elif sh.name == ".nv.info" and param == 0x12: self.lcmem_usage = struct.unpack_from("II", data)[1] + 0x240 # EIATTR_MIN_STACK_SIZE
elif sh.name == ".nv.info" and param == 0x2f: self.regs_usage = struct.unpack_from("II", data)[1] # EIATTR_REGCOUNT
if (NAK:=isinstance(dev.compiler, NAKCompiler)):
image, self.cbuf_0 = memoryview(bytearray(lib[ctypes.sizeof(info:=mesa.struct_nak_shader_info.from_buffer_copy(lib)):])), []
self.regs_usage, self.shmem_usage, self.lcmem_usage = info.num_gprs, round_up(info.cs.smem_size, 128), round_up(info.slm_size, 16)
elif MOCKGPU: image, sections, relocs = memoryview(bytearray(lib) + b'\x00' * (4 - len(lib)%4)).cast("I"), [], [] # type: ignore
else: image, sections, relocs = elf_loader(self.lib, force_section_align=128)
# NOTE: Ensure at least 4KB of space after the program to mitigate prefetch memory faults.
self.lib_gpu = self.dev.allocator.alloc(round_up((prog_sz:=image.nbytes), 0x1000) + 0x1000, buf_spec:=BufferSpec(cpu_access=True))
prog_addr = self.lib_gpu.va_addr
if not NAK:
# For MOCKGPU, the lib is PTX code, so some values are emulated.
self.regs_usage, self.shmem_usage, self.lcmem_usage, cbuf0_size = 0, 0x400, 0x240, 0 if not MOCKGPU else 0x160
for sh in sections: # pylint: disable=possibly-used-before-assignment
if sh.name == f".nv.shared.{self.name}": self.shmem_usage = round_up(0x400 + sh.header.sh_size, 128)
if sh.name == f".text.{self.name}": prog_addr, prog_sz = self.lib_gpu.va_addr+sh.header.sh_addr, sh.header.sh_size
elif m:=re.match(r'\.nv\.constant(\d+)', sh.name):
self.constbufs[int(m.group(1))] = (self.lib_gpu.va_addr+sh.header.sh_addr, sh.header.sh_size)
elif sh.name.startswith(".nv.info"):
for typ, param, data in self._parse_elf_info(sh):
if sh.name == f".nv.info.{name}" and param == 0xa: cbuf0_size = struct.unpack_from("IH", data)[1] # EIATTR_PARAM_CBANK
elif sh.name == ".nv.info" and param == 0x12: self.lcmem_usage = struct.unpack_from("II", data)[1] + 0x240 # EIATTR_MIN_STACK_SIZE
elif sh.name == ".nv.info" and param == 0x2f: self.regs_usage = struct.unpack_from("II", data)[1] # EIATTR_REGCOUNT
# Apply relocs
for apply_image_offset, rel_sym_offset, typ, _ in relocs: # pylint: disable=possibly-used-before-assignment
# These types are CUDA-specific, applying them here
if typ == 2: image[apply_image_offset:apply_image_offset+8] = struct.pack('<Q', self.lib_gpu.va_addr + rel_sym_offset) # R_CUDA_64
elif typ == 0x38: image[apply_image_offset+4:apply_image_offset+8] = struct.pack('<I', (self.lib_gpu.va_addr + rel_sym_offset) & 0xffffffff)
elif typ == 0x39: image[apply_image_offset+4:apply_image_offset+8] = struct.pack('<I', (self.lib_gpu.va_addr + rel_sym_offset) >> 32)
else: raise RuntimeError(f"unknown NV reloc {typ}")
self.cbuf_0 = [0] * (cbuf0_size // 4)
# Ensure device has enough local memory to run the program
self.dev._ensure_has_local_memory(self.lcmem_usage)
# Apply relocs
for apply_image_offset, rel_sym_offset, typ, _ in relocs:
# These types are CUDA-specific, applying them here
if typ == 2: image[apply_image_offset:apply_image_offset+8] = struct.pack('<Q', self.lib_gpu.va_addr + rel_sym_offset) # R_CUDA_64
elif typ == 0x38: image[apply_image_offset+4:apply_image_offset+8] = struct.pack('<I', (self.lib_gpu.va_addr + rel_sym_offset) & 0xffffffff)
elif typ == 0x39: image[apply_image_offset+4:apply_image_offset+8] = struct.pack('<I', (self.lib_gpu.va_addr + rel_sym_offset) >> 32)
else: raise RuntimeError(f"unknown NV reloc {typ}")
ctypes.memmove(self.lib_gpu.va_addr, mv_address(image), image.nbytes)
self.constbuffer_0 = [0] * (cbuf0_size // 4)
if dev.iface.compute_class >= nv_gpu.BLACKWELL_COMPUTE_A:
self.constbuffer_0[188:192], self.constbuffer_0[223] = [*data64_le(self.dev.shared_mem_window), *data64_le(self.dev.local_mem_window)], 0xfffdc0
qmd = {'qmd_major_version':5, 'qmd_type':nv_gpu.NVCEC0_QMDV05_00_QMD_TYPE_GRID_CTA, 'register_count':self.regs_usage,
'program_address_upper_shifted4':hi32(self.prog_addr>>4), 'program_address_lower_shifted4':lo32(self.prog_addr>>4),
'shared_memory_size_shifted7':self.shmem_usage>>7, 'shader_local_memory_high_size_shifted4':self.dev.slm_per_thread>>4}
if not NAK: self.cbuf_0[188:192], self.cbuf_0[223] = [*data64_le(self.dev.shared_mem_window), *data64_le(self.dev.local_mem_window)], 0xfffdc0
qmd = {'qmd_major_version':5, 'qmd_type':nv_gpu.NVCEC0_QMDV05_00_QMD_TYPE_GRID_CTA, 'program_address_upper_shifted4':hi32(prog_addr>>4),
'program_address_lower_shifted4':lo32(prog_addr>>4), 'register_count':self.regs_usage, 'shared_memory_size_shifted7':self.shmem_usage>>7,
'shader_local_memory_high_size_shifted4':self.lcmem_usage>>4 if NAK else self.dev.slm_per_thread>>4}
else:
self.constbuffer_0[6:12] = [*data64_le(self.dev.shared_mem_window), *data64_le(self.dev.local_mem_window), *data64_le(0xfffdc0)]
qmd = {'qmd_major_version':3, 'sm_global_caching_enable':1, 'shader_local_memory_high_size':self.dev.slm_per_thread,
'program_address_upper':hi32(self.prog_addr), 'program_address_lower':lo32(self.prog_addr), 'shared_memory_size':self.shmem_usage,
'register_count_v':self.regs_usage}
if not NAK: self.cbuf_0[6:12] = [*data64_le(self.dev.shared_mem_window), *data64_le(self.dev.local_mem_window), *data64_le(0xfffdc0)]
qmd = {'qmd_major_version':3, 'sm_global_caching_enable':1, 'program_address_upper':hi32(prog_addr), 'program_address_lower':lo32(prog_addr),
'shared_memory_size':self.shmem_usage, 'register_count_v':self.regs_usage,
**({'shader_local_memory_low_size':self.lcmem_usage} if NAK else {'shader_local_memory_high_size':self.dev.slm_per_thread})}
smem_cfg = min(shmem_conf * 1024 for shmem_conf in [32, 64, 100] if shmem_conf * 1024 >= self.shmem_usage) // 4096 + 1
self.qmd:QMD = QMD(dev, **qmd, qmd_group_id=0x3f, invalidate_texture_header_cache=1, invalidate_texture_sampler_cache=1,
invalidate_texture_data_cache=1, invalidate_shader_data_cache=1, api_visible_call_limit=1, sampler_index=1, barrier_count=1,
cwd_membar_type=nv_gpu.NVC6C0_QMDV03_00_CWD_MEMBAR_TYPE_L1_SYSMEMBAR, constant_buffer_invalidate_0=1,
min_sm_config_shared_mem_size=smem_cfg, target_sm_config_shared_mem_size=smem_cfg, max_sm_config_shared_mem_size=0x1a,
program_prefetch_size=min(self.prog_sz>>8, 0x1ff), sass_version=dev.sass_version,
program_prefetch_addr_upper_shifted=self.prog_addr>>40, program_prefetch_addr_lower_shifted=self.prog_addr>>8)
cwd_membar_type=nv_gpu.NVC6C0_QMDV03_00_CWD_MEMBAR_TYPE_L1_SYSMEMBAR, constant_buffer_invalidate_0=1, min_sm_config_shared_mem_size=smem_cfg,
target_sm_config_shared_mem_size=smem_cfg, max_sm_config_shared_mem_size=0x1a, program_prefetch_size=min(prog_sz>>8, 0x1ff),
sass_version=dev.sass_version, program_prefetch_addr_upper_shifted=prog_addr>>40, program_prefetch_addr_lower_shifted=prog_addr>>8)
for i,(addr,sz) in self.constbufs.items():
self.qmd.set_constant_buf_addr(i, addr)
@@ -526,7 +529,8 @@ class NVDevice(HCQCompiled[HCQSignal]):
self.sass_version = ((self.sm_version & 0xf00) >> 4) | (self.sm_version & 0xf)
compilers:list[CompilerPairT] = [(functools.partial(NVRenderer, self.arch),functools.partial(CUDACompiler if MOCKGPU else NVCompiler, self.arch)),
(functools.partial(PTXRenderer, self.arch, device="NV"), functools.partial(PTXCompiler if MOCKGPU else NVPTXCompiler, self.arch))]
(functools.partial(PTXRenderer, self.arch, device="NV"), functools.partial(PTXCompiler if MOCKGPU else NVPTXCompiler, self.arch)),
(functools.partial(NAKRenderer, dev=self), functools.partial(NAKCompiler, self.arch, self.max_warps_per_sm))]
super().__init__(device, NVAllocator(self), compilers, functools.partial(NVProgram, self), HCQSignal, NVComputeQueue, NVCopyQueue)
self._setup_gpfifos()
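
For the NAK branch above, a small hedged helper illustrating the lib layout it consumes: NAKCompiler.compile (compiler_mesa.py below) returns the packed nak_shader_info followed by the raw shader code, and NVProgram peels the header off the front.

import ctypes
import tinygrad.runtime.autogen.mesa as mesa

def split_nak_lib(lib: bytes):
  # lib is the output of NAKCompiler.compile: bytes(nak_shader_info) + shader code
  info = mesa.struct_nak_shader_info.from_buffer_copy(lib)
  code = lib[ctypes.sizeof(mesa.struct_nak_shader_info):]
  return info, code  # NVProgram reads num_gprs, cs.smem_size and slm_size from info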


@@ -0,0 +1,86 @@
import base64, ctypes, pathlib, tempfile, hashlib, subprocess
from tinygrad.device import Compiler
from tinygrad.helpers import cpu_objdump
import tinygrad.runtime.autogen.mesa as mesa
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, expect, cerr
try: import tinygrad.runtime.autogen.llvm as llvm
except (ImportError, FileNotFoundError): llvm = None #type:ignore[assignment]
def deserialize(enc_src, opts):
blobreader = mesa.struct_blob_reader()
mesa.blob_reader_init(blobreader, src:=base64.b64decode(enc_src), len(src))
return mesa.nir_deserialize(None, ctypes.cast(opts, ctypes.POINTER(mesa.nir_shader_compiler_options)), blobreader)
class NIRCompiler(Compiler):
def __init__(self, cache_key):
mesa.glsl_type_singleton_init_or_ref()
super().__init__(cache_key)
def __del__(self): mesa.glsl_type_singleton_decref()
class LVPCompiler(CPULLVMCompiler, NIRCompiler):
def __init__(self, cache_key="lvp"):
CPULLVMCompiler.__init__(self)
NIRCompiler.__init__(self, f"compile_{cache_key}")
def __del__(self):
NIRCompiler.__del__(self)
CPULLVMCompiler.__del__(self)
def compile(self, src) -> bytes:
shader, ctx = deserialize(src, mesa.lvp_nir_options), llvm.LLVMGetGlobalContext()
gallivm = mesa.gallivm_create(None, mesa.lp_context_ref(ctypes.cast(ctx, ctypes.POINTER(mesa.struct_LLVMOpaqueContext)), True), None).contents
module, builder = ctypes.cast(gallivm.module, llvm.LLVMModuleRef), ctypes.cast(gallivm.builder, llvm.LLVMBuilderRef)
params = mesa.struct_lp_build_tgsi_params(mesa.struct_lp_type(floating=True, sign=True, width=32, length=4),
resources_type=mesa.lp_build_jit_resources_type(gallivm), mask=ctypes.pointer(mesa.struct_lp_build_mask_context()))
pt = llvm.LLVMPointerType(ctypes.cast(params.resources_type, llvm.LLVMTypeRef), 0)
fn = llvm.LLVMAddFunction(module, shader.contents.info.name, llvm.LLVMFunctionType(llvm.LLVMVoidTypeInContext(ctx), pt, 1, 0))
llvm.LLVMPositionBuilderAtEnd(builder, llvm.LLVMAppendBasicBlockInContext(ctx, fn, b"entry"))
params.consts_ptr = mesa.lp_build_struct_get_ptr2(gallivm, params.resources_type,
ctypes.cast(llvm.LLVMGetParam(fn, 0), mesa.LLVMValueRef), mesa.LP_JIT_RES_CONSTANTS, b"constants")
mesa.lp_build_mask_begin(params.mask, gallivm, params.type, mesa.lp_build_one(gallivm, params.type))
mesa.lp_build_mask_end(params.mask)
mesa.lp_build_nir_soa(gallivm, shader, params, None)
llvm.LLVMBuildRetVoid(builder)
mesa.gallivm_verify_function(gallivm, ctypes.cast(fn, mesa.LLVMValueRef))
mesa.lp_passmgr_run(gallivm.passmgr, gallivm.module, ctypes.cast(self.target_machine, mesa.LLVMTargetMachineRef), gallivm.module_name)
obj_buf = expect(llvm.LLVMTargetMachineEmitToMemoryBuffer(self.target_machine, module, llvm.LLVMObjectFile, err:=cerr(),
ctypes.pointer(buf:=llvm.LLVMMemoryBufferRef())), err, buf)
obj = ctypes.string_at(llvm.LLVMGetBufferStart(obj_buf), llvm.LLVMGetBufferSize(obj_buf))
mesa.gallivm_destroy(gallivm)
mesa.ralloc_free(shader)
return obj
def disassemble(self, lib: bytes): cpu_objdump(lib)
class NAKCompiler(NIRCompiler):
def __init__(self, arch, warps_per_sm, cache_key="nak"):
self.arch, self.warps_per_sm = arch, warps_per_sm
self.cc = mesa.nak_compiler_create(mesa.struct_nv_device_info(sm=int(arch[3:]), max_warps_per_mp=warps_per_sm))
self.nir_options = bytes(mesa.nak_nir_options(self.cc).contents)
super().__init__(f"compile_{cache_key}_{arch}")
def __del__(self):
mesa.nak_compiler_destroy(self.cc)
super().__del__()
def __reduce__(self): return NAKCompiler, (self.arch, self.warps_per_sm)
def compile(self, src) -> bytes:
shader = deserialize(src, self.nir_options)
mesa.nak_preprocess_nir(shader, self.cc)
ret = bytes((out:=mesa.nak_compile_shader(shader, False, self.cc, 0, None).contents).info) + ctypes.string_at(out.code, out.code_size)
mesa.nak_shader_bin_destroy(out)
mesa.ralloc_free(shader)
return ret
def disassemble(self, lib: bytes):
try:
fn = (pathlib.Path(tempfile.gettempdir()) / f"tinynak_{hashlib.md5(lib).hexdigest()}").as_posix()
with open(fn, "wb") as f: f.write(lib[ctypes.sizeof(mesa.struct_nak_shader_info):])
print(subprocess.check_output(['nvdisasm', "-b", f"SM{self.arch[3:]}", fn]).decode('utf-8'))
except Exception as e: print("Failed to generate SASS", str(e), "Make sure your PATH contains nvdisasm binary of compatible version.")
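
Putting the LVP pieces together (a sketch; assumes CPU_LVP=1 plus installed libtinymesa and LLVM): the renderer serializes the NIR shader to base64, and LVPCompiler deserializes it and JITs it through gallivm.

import os
os.environ["CPU"], os.environ["CPU_LVP"] = "1", "1"
from tinygrad import Tensor
from tinygrad.engine.realize import get_program
sched = (Tensor([1.0, 2.0]) + Tensor([3.0, 4.0])).schedule()
prg = get_program(sched[-1].ast)  # rendered by the device's LVPRenderer
print(prg.src[:60], "...")        # base64-encoded nir_serialize blob, fed to LVPCompiler.compile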


@@ -1,12 +1,18 @@
import struct
import struct, ctypes, ctypes.util
from dataclasses import dataclass
from tinygrad.helpers import getbits, i2u
from tinygrad.helpers import getbits, i2u, unwrap
import tinygrad.runtime.autogen.libc as libc
@dataclass(frozen=True)
class ElfSection: name:str; header:libc.Elf64_Shdr; content:bytes # noqa: E702
def elf_loader(blob:bytes, force_section_align:int=1) -> tuple[memoryview, list[ElfSection], list[tuple]]:
def link_sym(sym:str, libs:list[str]) -> int:
for lib in libs:
try: return unwrap(ctypes.cast(getattr(ctypes.CDLL(ctypes.util.find_library(lib)), sym), ctypes.c_void_p).value)
except (OSError, AttributeError): pass
raise RuntimeError(f'Attempting to relocate against an undefined symbol {sym}')
def elf_loader(blob:bytes, force_section_align:int=1, link_libs:list[str]|None=None) -> tuple[memoryview, list[ElfSection], list[tuple]]:
def _strtab(blob: bytes, idx: int) -> str: return blob[idx:blob.find(b'\x00', idx)].decode('utf-8')
header = libc.Elf64_Ehdr.from_buffer_copy(blob)
@@ -31,33 +37,42 @@ def elf_loader(blob:bytes, force_section_align:int=1) -> tuple[memoryview, list[
# Relocations
relocs = []
for sh, trgt_sh_name, c_rels in rel + rela:
if trgt_sh_name == ".eh_frame": continue
target_image_off = next(tsh for tsh in sections if tsh.name == trgt_sh_name).header.sh_addr
rels = [(r.r_offset, symtab[libc.ELF64_R_SYM(r.r_info)], libc.ELF64_R_TYPE(r.r_info), getattr(r, "r_addend", 0)) for r in c_rels]
for _, sym, _, _ in rels:
if sym.st_shndx == 0: raise RuntimeError(f'Attempting to relocate against an undefined symbol {repr(_strtab(sh_strtab, sym.st_name))}')
relocs += [(target_image_off + roff, sections[sym.st_shndx].header.sh_addr + sym.st_value, rtype, raddend) for roff, sym, rtype, raddend in rels]
relocs += [(target_image_off + roff, link_sym(_strtab(sh_strtab, sym.st_name), link_libs or []) if sym.st_shndx == 0 else
sections[sym.st_shndx].header.sh_addr + sym.st_value, rtype, raddend) for roff, sym, rtype, raddend in rels]
return memoryview(image), sections, relocs
def relocate(instr: int, ploc: int, tgt: int, r_type: int):
match r_type:
# https://refspecs.linuxfoundation.org/elf/x86_64-abi-0.95.pdf
case libc.R_X86_64_PC32: return i2u(32, tgt-ploc)
# https://github.com/ARM-software/abi-aa/blob/main/aaelf64/aaelf64.rst for definitions of relocations
# https://www.scs.stanford.edu/~zyedidia/arm64/index.html for instruction encodings
case libc.R_AARCH64_ADR_PREL_PG_HI21:
rel_pg = (tgt & ~0xFFF) - (ploc & ~0xFFF)
return instr | (getbits(rel_pg, 12, 13) << 29) | (getbits(rel_pg, 14, 32) << 5)
case libc.R_AARCH64_ADD_ABS_LO12_NC: return instr | (getbits(tgt, 0, 11) << 10)
case libc.R_AARCH64_LDST16_ABS_LO12_NC: return instr | (getbits(tgt, 1, 11) << 10)
case libc.R_AARCH64_LDST32_ABS_LO12_NC: return instr | (getbits(tgt, 2, 11) << 10)
case libc.R_AARCH64_LDST64_ABS_LO12_NC: return instr | (getbits(tgt, 3, 11) << 10)
case libc.R_AARCH64_LDST128_ABS_LO12_NC: return instr | (getbits(tgt, 4, 11) << 10)
raise NotImplementedError(f"Encountered unknown relocation type {r_type}")
def jit_loader(obj: bytes, base:int=0, link_libs:list[str]|None=None) -> bytes:
image_, _, relocs = elf_loader(obj, link_libs=link_libs)
image = bytearray(image_)
def relocate(instr: int, base: int, ploc: int, tgt: int, r_type: int):
match r_type:
# https://refspecs.linuxfoundation.org/elf/x86_64-abi-0.95.pdf
case libc.R_X86_64_PC32: return i2u(32, tgt-ploc)
case libc.R_X86_64_PLT32: return i2u(32, tgt-ploc-base)
# https://github.com/ARM-software/abi-aa/blob/main/aaelf64/aaelf64.rst for definitions of relocations
# https://www.scs.stanford.edu/~zyedidia/arm64/index.html for instruction encodings
case libc.R_AARCH64_ADR_PREL_PG_HI21:
rel_pg = (tgt & ~0xFFF) - (ploc & ~0xFFF)
return instr | (getbits(rel_pg, 12, 13) << 29) | (getbits(rel_pg, 14, 32) << 5)
case libc.R_AARCH64_ADD_ABS_LO12_NC: return instr | (getbits(tgt, 0, 11) << 10)
case libc.R_AARCH64_LDST16_ABS_LO12_NC: return instr | (getbits(tgt, 1, 11) << 10)
case libc.R_AARCH64_LDST32_ABS_LO12_NC: return instr | (getbits(tgt, 2, 11) << 10)
case libc.R_AARCH64_LDST64_ABS_LO12_NC: return instr | (getbits(tgt, 3, 11) << 10)
case libc.R_AARCH64_LDST128_ABS_LO12_NC: return instr | (getbits(tgt, 4, 11) << 10)
case libc.R_AARCH64_CALL26:
if -(2**25) <= tgt-ploc-base and tgt-ploc-base <= (2**25 - 1) * 4: return instr | getbits(tgt-ploc-base, 2, 27)
nonlocal image
# create trampoline: LDR x17, 8 BR x17
image += struct.pack("<IIQ", 0x58000051, 0xD61F0220, tgt)
return instr | getbits(len(image)-ploc-16, 2, 27)
raise NotImplementedError(f"Encountered unknown relocation type {r_type}")
def jit_loader(obj: bytes) -> bytes:
image, _, relocs = elf_loader(obj)
# This is needed because we have an object file, not a .so that has all internal references (like loads of constants from .rodata) resolved.
for ploc,tgt,r_type,r_addend in relocs:
image[ploc:ploc+4] = struct.pack("<I", relocate(struct.unpack("<I", image[ploc:ploc+4])[0], ploc, tgt+r_addend, r_type))
image[ploc:ploc+4] = struct.pack("<I", relocate(struct.unpack("<I", image[ploc:ploc+4])[0], base, ploc, tgt+r_addend, r_type))
return bytes(image)
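
For reference, link_sym above resolves an undefined symbol against host libraries via ctypes; the powf/libm case from the new elf loader test looks like this when done by hand (a sketch of the same lookup):

import ctypes, ctypes.util
libm = ctypes.CDLL(ctypes.util.find_library("m"))
addr = ctypes.cast(libm.powf, ctypes.c_void_p).value  # absolute address of powf in the process
print(hex(addr))  # this is the value link_sym returns and the relocation is applied against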