llvmpipe kinda works

This commit is contained in:
Christopher Milan
2025-09-17 14:58:52 -07:00
parent c7bc3832dc
commit d37ac3ac16
10 changed files with 11360 additions and 115 deletions

View File

@@ -442,7 +442,7 @@ generate_libusb() {
python3 -c "import tinygrad.runtime.autogen.libusb"
}
generate_nak() {
generate_mesa() {
MESA_COMMIT_HASH=9e0991eff5aea2e064fc16d5c7fa0ee6cd52d894
MESA_SRC=/tmp/mesa-$MESA_COMMIT_HASH
if [ ! -d "$MESA_SRC" ]; then
@@ -461,39 +461,58 @@ generate_nak() {
python3 src/compiler/nir/nir_intrinsics_indices_h.py --outdir gen
python3 src/compiler/nir/nir_builder_opcodes_h.py > gen/nir_builder_opcodes.h
python3 src/compiler/nir/nir_intrinsics_h.py --outdir gen
python3 src/compiler/builtin_types_h.py gen/builtin_types.h
popd
fi
clang2py -k cdefstu \
$MESA_SRC/src/nouveau/headers/nv_device_info.h \
$MESA_SRC/src/nouveau/compiler/nak.h \
--clang-args="-DHAVE_ENDIAN_H -I$MESA_SRC/src -I$MESA_SRC/include -I$MESA_SRC/src/compiler/nir" \
$MESA_SRC/src/compiler/nir/nir_shader_compiler_options.h \
$MESA_SRC/src/compiler/nir/nir_serialize.h \
$MESA_SRC/src/util/blob.h \
$MESA_SRC/src/compiler/glsl_types.h \
--clang-args="-DHAVE_ENDIAN_H -DHAVE_STRUCT_TIMESPEC -DHAVE_PTHREAD -I$MESA_SRC/src -I$MESA_SRC/include -I$MESA_SRC/src/compiler/nir -I$MESA_SRC/gen" \
-o $BASE/nak.py
clang2py -v -k cdefstu \
clang2py -k cdefstu \
$MESA_SRC/src/compiler/list.h \
$MESA_SRC/src/compiler/nir/nir.h \
$MESA_SRC/src/compiler/nir/nir_builder.h \
$MESA_SRC/src/compiler/nir/nir_shader_compiler_options.h \
$MESA_SRC/src/compiler/nir/nir_serialize.h \
$MESA_SRC/gen/nir_intrinsics.h \
$MESA_SRC/src/compiler/glsl_types.h \
--clang-args="-DHAVE_ENDIAN_H -DHAVE_STRUCT_TIMESPEC -DHAVE_PTHREAD -I$MESA_SRC/src -I$MESA_SRC/include -I$MESA_SRC/src/compiler/nir -I$MESA_SRC/gen" \
-o $BASE/nir.py
fixup $BASE/nak.py
clang2py -k cdefstu \
$MESA_SRC/src/gallium/auxiliary/gallivm/lp_bld.h \
$MESA_SRC/src/gallium/auxiliary/gallivm/lp_bld_type.h \
$MESA_SRC/src/gallium/auxiliary/gallivm/lp_bld_init.h \
$MESA_SRC/src/gallium/auxiliary/gallivm/lp_bld_nir.h \
$MESA_SRC/src/gallium/auxiliary/gallivm/lp_bld_struct.h \
$MESA_SRC/src/gallium/auxiliary/gallivm/lp_bld_jit_types.h \
$MESA_SRC/src/gallium/auxiliary/gallivm/lp_bld_flow.h \
$MESA_SRC/src/gallium/auxiliary/gallivm/lp_bld_const.h \
$MESA_SRC/src/compiler/nir/nir_shader_compiler_options.h \
$MESA_SRC/src/compiler/nir/nir_serialize.h \
$MESA_SRC/src/util/blob.h \
--clang-args="-DHAVE_ENDIAN_H -DHAVE_STRUCT_TIMESPEC -DHAVE_PTHREAD -I$MESA_SRC/src -I$MESA_SRC/include -I$MESA_SRC/gen -I$MESA_SRC/src/compiler/nir -I$MESA_SRC/src/gallium/auxiliary -I$MESA_SRC/src/gallium/include -I$(llvm-config-14 --includedir)" \
-o $BASE/lvp.py
fixup $BASE/nir.py
sed -i "s\FunctionFactoryStub()\ctypes.CDLL('/usr/lib/x86_64-linux-gnu/libvulkan_nouveau.so')\g" $BASE/nir.py
sed -i "s\FunctionFactoryStub()\ctypes.CDLL('/usr/lib/x86_64-linux-gnu/libvulkan_nouveau.so')\g" $BASE/nak.py
sed -i "s\import ctypes\import ctypes, os\g" $BASE/nak.py
sed -i "s\import ctypes\import ctypes, os\g" $BASE/nir.py
sed -i "s\'/usr/\os.getenv('MESA_PATH', '/usr/')+'/\g" $BASE/nak.py
sed -i "s\'/usr/\os.getenv('MESA_PATH', '/usr/')+'/\g" $BASE/nir.py
sed -i "s/ctypes.glsl_base_type/glsl_base_type/" $BASE/nir.py
fixup $BASE/nak.py
fixup $BASE/lvp.py
sed -i "s\import ctypes\import ctypes, os\g" $BASE/nak.py $BASE/nir.py $BASE/lvp.py
sed -i "s\FunctionFactoryStub()\ctypes.CDLL(os.getenv('MESA_PATH', '/usr')+'/lib/x86_64-linux-gnu/libvulkan_lvp.so')\g" $BASE/nir.py $BASE/lvp.py
sed -i "s\FunctionFactoryStub()\ctypes.CDLL(os.getenv('MESA_PATH', '/usr')+'/lib/x86_64-linux-gnu/libvulkan_nouveau.so')\g" $BASE/nak.py
sed -i "s/ctypes.glsl_base_type/glsl_base_type/" $BASE/nak.py $BASE/nir.py $BASE/lvp.py
# bitfield bug in clang2py
sed -i "s/('fp_fast_math', ctypes.c_bool, 9)/('fp_fast_math', ctypes.c_uint32, 9)/" $BASE/nir.py
sed -i "s/\([0-9]\+\)()/\1/" $BASE/nir.py
sed -i "s/\(struct_nir_builder._pack_\) = 1/\1 = 0/" $BASE/nir.py
python3 -c "import tinygrad.runtime.autogen.nak, tinygrad.runtime.autogen.nir"
python3 -c "import tinygrad.runtime.autogen.nak, tinygrad.runtime.autogen.nir, tinygrad.runtime.autogen.lvp"
}
if [ "$1" == "opencl" ]; then generate_opencl
@@ -519,7 +538,7 @@ elif [ "$1" == "pci" ]; then generate_pci
elif [ "$1" == "vfio" ]; then generate_vfio
elif [ "$1" == "webgpu" ]; then generate_webgpu
elif [ "$1" == "libusb" ]; then generate_libusb
elif [ "$1" == "nak" ]; then generate_nak
elif [ "$1" == "all" ]; then generate_opencl; generate_hip; generate_comgr; generate_cuda; generate_nvrtc; generate_hsa; generate_kfd; generate_nv; generate_amd; generate_io_uring; generate_libc; generate_am; generate_webgpu; generate_nak
elif [ "$1" == "mesa" ]; then generate_mesa
elif [ "$1" == "all" ]; then generate_opencl; generate_hip; generate_comgr; generate_cuda; generate_nvrtc; generate_hsa; generate_kfd; generate_nv; generate_amd; generate_io_uring; generate_libc; generate_am; generate_webgpu; generate_mesa
else echo "usage: $0 <type>"
fi

View File

@@ -0,0 +1,69 @@
/* cc gallivm_nir_options.c -o gallivm_nir_options -I$MESA_SRC/src/compiler/nir -I$MESA_SRC/src -I$MESA_SRC/include
* ./gallivm_nir_options | gzip |base64
*/
#define HAVE_ENDIAN_H
#define HAVE_STRUCT_TIMESPEC
#define HAVE_PTHREAD
#include <unistd.h>
#include "nir_shader_compiler_options.h"
#include "compiler/shader_enums.h"
struct nir_shader_compiler_options gallivm_nir_options = {
.lower_scmp = true,
.lower_flrp32 = true,
.lower_flrp64 = true,
.lower_fsat = true,
.lower_bitfield_insert = true,
.lower_bitfield_extract8 = true,
.lower_bitfield_extract16 = true,
.lower_bitfield_extract = true,
.lower_fdph = true,
.lower_ffma16 = true,
.lower_ffma32 = true,
.lower_ffma64 = true,
.lower_flrp16 = true,
.lower_fmod = true,
.lower_hadd = true,
.lower_uadd_sat = true,
.lower_usub_sat = true,
.lower_iadd_sat = true,
.lower_ldexp = true,
.lower_pack_snorm_2x16 = true,
.lower_pack_snorm_4x8 = true,
.lower_pack_unorm_2x16 = true,
.lower_pack_unorm_4x8 = true,
.lower_pack_half_2x16 = true,
.lower_pack_64_4x16 = true,
.lower_pack_split = true,
.lower_unpack_snorm_2x16 = true,
.lower_unpack_snorm_4x8 = true,
.lower_unpack_unorm_2x16 = true,
.lower_unpack_unorm_4x8 = true,
.lower_unpack_half_2x16 = true,
.lower_extract_byte = true,
.lower_extract_word = true,
.lower_insert_byte = true,
.lower_insert_word = true,
.lower_uadd_carry = true,
.lower_usub_borrow = true,
.lower_mul_2x32_64 = true,
.lower_ifind_msb = true,
.lower_int64_options = nir_lower_imul_2x32_64 | nir_lower_bitfield_extract64,
.lower_doubles_options = nir_lower_dround_even,
.max_unroll_iterations = 32,
.lower_to_scalar = true,
.lower_uniforms_to_ubo = true,
.lower_vector_cmp = true,
.lower_device_index_to_zero = true,
.support_16bit_alu = true,
.lower_fisnormal = true,
.lower_fquantize2f16 = true,
.lower_fminmax_signed_zero = true,
.driver_functions = true,
.scalarize_ddx = true,
.support_indirect_inputs = (uint8_t)BITFIELD_MASK(MESA_SHADER_STAGES),
.support_indirect_outputs = (uint8_t)BITFIELD_MASK(MESA_SHADER_STAGES),
};
int main(void) { write(1, &gallivm_nir_options, sizeof(gallivm_nir_options)); }

View File

@@ -119,12 +119,12 @@ def nlid(b:nir.nir_builder) -> nir.nir_def:
def nreg_idx(b:nir.nir_builder, reg:nir.nir_variable, idx:nir.nir_def) -> nir.nir_def:
parent = nir.nir_deref_instr_create(b.shader, nir.nir_deref_type_var)
parent.contents.modes, parent.contents.type, parent.contents.var = reg.data.mode, reg.type, ctypes.pointer(reg)
nir.nir_def_init(parent.contents.instr, d(parent), 1, 64)
nir.nir_def_init(parent.contents.instr, d(parent), 1, 32)
nir.nir_builder_instr_insert(b, parent.contents.instr)
deref = nir.nir_deref_instr_create(b.shader, nir.nir_deref_type_array)
deref.contents.modes, deref.contents.type = reg.data.mode, nir.glsl_get_array_element(reg.type)
deref.contents.parent, deref.contents.arr.index = nir_src_for_ssa(d(parent)), nir_src_for_ssa(idx)
nir.nir_def_init(deref.contents.instr, d(deref), 1, 64)
nir.nir_def_init(deref.contents.instr, d(deref), 1, 32)
nir.nir_builder_instr_insert(b, deref.contents.instr)
return d(deref)
@@ -208,21 +208,12 @@ class NIRRenderer(Renderer):
lambda buf,idx,val,gate: UOp.store(buf.index(idx), val, gate)),
])
def nv_param(self, dtype:DType, sz:int) -> nir.nir_def:
intrin = nir.nir_intrinsic_instr_create(self.b.shader, nir.nir_intrinsic_ldc_nv)
intrin.contents.num_components = 1
nir.nir_def_init(intrin.contents.instr, d(intrin), 1, sz * 8)
arr = ctypes.cast(intrin.contents.src, ctypes.POINTER(nir.nir_src))
arr[0], arr[1] = nir_src_for_ssa(nimm(self.b, 0, dtypes.int)), nir_src_for_ssa(nimm(self.b, self.param_idx, dtypes.int))
nir_intrinsic_set(nir.NIR_INTRINSIC_ALIGN_MUL, intrin, sz)
nir.nir_builder_instr_insert(self.b, intrin.contents.instr)
self.param_idx += sz
return d(intrin)
def param(self, dtype:DType, sz:int) -> nir.nir_def: raise NotImplementedError("needs param")
def_rewrite = PatternMatcher([
(UPat(Ops.CONST, name="x"), lambda ctx,x: nimm(ctx.b, x.arg, x.dtype)),
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx,x: ctx.nv_param(x.dtype, 8)),
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.nv_param(x.dtype, 4)),
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx,x: ctx.param(x.dtype, 8)),
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(x.dtype, 4)),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, ngid(ctx.b) if x.arg[0] == 'g' else nlid(ctx.b), int(x.arg[-1]))),
(UPat(Ops.STORE, src=(UPat.var("loc"), UPat.var("val")), allow_any_len=True, name="x"),
lambda ctx,x,loc,val: nstore(ctx.b, AddrSpace(x.arg), ctx.r[loc], ctx.r[val], val.dtype)),
@@ -244,20 +235,22 @@ class NIRRenderer(Renderer):
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: ensure(nir.nir_pop_if(ctx.b, ctx.r[x.src[0]])))
])
def __init__(self, dev, device="NV"): self.device, self.dev = device, dev
def __init__(self, dev, device): self.device, self.dev = device, dev
def prerender(self, uops:list[UOp]):
nir.glsl_type_singleton_init_or_ref() # TODO: call glsl_type_singleton_decref somewhere
self.b = nir.nir_builder_init_simple_shader(nir.MESA_SHADER_COMPUTE, self.dev.compiler.nir_options, None)
def render(self, uops:list[UOp]) -> str:
self.b = nir.nir_builder_init_simple_shader(nir.MESA_SHADER_COMPUTE, self.dev.compiler.nir_options, None)
self.prerender(uops)
for u in [u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]: self.b.shader.contents.info.workgroup_size[int(u.arg[-1])] = u.src[0].arg
self.r, self.param_idx = {}, 0
ranges: list[Tuple[nir.nir_loop, nir.nir_phi_instr]] = []
self.r, self.param_idx, ranges = {}, 0, []
# import os
# input(f"pid: {os.getpid()}")
nir.glsl_type_singleton_init_or_ref() # TODO: call glsl_type_singleton_decref somewhere
for u in uops:
# print(u)
# nir.nir_print_shader(b.shader, stdout)
# nir.nir_print_shader(self.b.shader, stdout)
if u.op == Ops.NOOP: pass
elif u.op == Ops.SINK:
# why do we care about setting this?
@@ -283,4 +276,45 @@ class NIRRenderer(Renderer):
raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
self.r[u] = cast(nir.nir_def, d)
nir.nir_print_shader(self.b.shader, stdout)
return self.b.shader.contents
blob = nir.struct_blob()
nir.nir_serialize(blob, self.b.shader, False)
return ctypes.string_at(blob.data, blob.size)
class NAKRenderer(NIRRenderer):
def __init__(self, dev, device="NV"): super().__init__(dev, device)
def param(self, dtype:DType, sz:int) -> nir.nir_def:
intrin = nir.nir_intrinsic_instr_create(self.b.shader, nir.nir_intrinsic_ldc_nv)
intrin.contents.num_components = 1
nir.nir_def_init(intrin.contents.instr, d(intrin), 1, sz * 8)
arr = ctypes.cast(intrin.contents.src, ctypes.POINTER(nir.nir_src))
arr[0], arr[1] = nir_src_for_ssa(nimm(self.b, 0, dtypes.int)), nir_src_for_ssa(nimm(self.b, self.param_idx, dtypes.int))
nir_intrinsic_set(nir.NIR_INTRINSIC_ALIGN_MUL, intrin, sz)
nir.nir_builder_instr_insert(self.b, intrin.contents.instr)
self.param_idx += sz
return d(intrin)
class LVPRenderer(NIRRenderer):
has_local = False
global_max = (1, 0, 0)
def __init__(self, dev, device="CPU"): super().__init__(dev, device)
def param(self, dtype:DType, sz:int) -> nir.nir_def:
intrin = nir.nir_intrinsic_instr_create(self.b.shader, nir.nir_intrinsic_load_ubo)
intrin.contents.num_components = 1
nir.nir_def_init(intrin.contents.instr, d(intrin), 1, sz * 8)
arr = ctypes.cast(intrin.contents.src, ctypes.POINTER(nir.nir_src))
arr[0], arr[1] = nir_src_for_ssa(nimm(self.b, 0, dtypes.int)), nir_src_for_ssa(nimm(self.b, self.param_idx, dtypes.int))
nir_intrinsic_set(nir.NIR_INTRINSIC_ALIGN_MUL, intrin, sz)
nir_intrinsic_set(nir.NIR_INTRINSIC_RANGE, intrin, self.paramsz)
nir.nir_builder_instr_insert(self.b, intrin.contents.instr)
self.param_idx += sz
return d(intrin)
def prerender(self, uops:list[UOp]):
super().prerender(uops)
# TODO: does this do anything?
self.paramsz = sum([8 if u.op == Ops.DEFINE_GLOBAL else u.dtype.itemize for u in uops if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR)])
nir.nir_variable_create(self.b.shader, nir.nir_var_mem_ubo, glsl_type(dtypes.uchar.ptr(self.paramsz)), s("kernel_input"))

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -135,7 +135,7 @@ class FunctionFactoryStub:
# You can either re-run clan2py with -l /path/to/library.so
# Or manually fix this by comment the ctypes.CDLL loading
_libraries = {}
_libraries['FIXME_STUB'] = ctypes.CDLL(os.getenv('MESA_PATH', '/usr/')+'/lib/x86_64-linux-gnu/libvulkan_nouveau.so') # ctypes.CDLL('FIXME_STUB')
_libraries['FIXME_STUB'] = ctypes.CDLL(os.getenv('MESA_PATH', '/usr')+'/lib/x86_64-linux-gnu/libvulkan_lvp.so') # ctypes.CDLL('FIXME_STUB')
def string_cast(char_pointer, encoding='utf-8', errors='strict'):
value = ctypes.cast(char_pointer, ctypes.c_char_p).value
if value is not None and encoding is not None:
@@ -14729,6 +14729,30 @@ try:
nir_call_serialized.argtypes = [ctypes.POINTER(struct_nir_builder), ctypes.POINTER(ctypes.c_uint32), size_t, ctypes.POINTER(ctypes.POINTER(struct_nir_def))]
except AttributeError:
pass
try:
nir_serialize = _libraries['FIXME_STUB'].nir_serialize
nir_serialize.restype = None
nir_serialize.argtypes = [ctypes.POINTER(struct_blob), ctypes.POINTER(struct_nir_shader), ctypes.c_bool]
except AttributeError:
pass
try:
nir_deserialize = _libraries['FIXME_STUB'].nir_deserialize
nir_deserialize.restype = ctypes.POINTER(struct_nir_shader)
nir_deserialize.argtypes = [ctypes.POINTER(None), ctypes.POINTER(struct_nir_shader_compiler_options), ctypes.POINTER(struct_blob_reader)]
except AttributeError:
pass
try:
nir_serialize_function = _libraries['FIXME_STUB'].nir_serialize_function
nir_serialize_function.restype = None
nir_serialize_function.argtypes = [ctypes.POINTER(struct_blob), ctypes.POINTER(struct_nir_function)]
except AttributeError:
pass
try:
nir_deserialize_function = _libraries['FIXME_STUB'].nir_deserialize_function
nir_deserialize_function.restype = ctypes.POINTER(struct_nir_function)
nir_deserialize_function.argtypes = [ctypes.POINTER(None), ctypes.POINTER(struct_nir_shader_compiler_options), ctypes.POINTER(struct_blob_reader)]
except AttributeError:
pass
__all__ = \
['ACCESS_CAN_REORDER', 'ACCESS_CAN_SPECULATE', 'ACCESS_COHERENT',
'ACCESS_CP_GE_COHERENT_AMD', 'ACCESS_FMASK_LOWERED_AMD',
@@ -15510,7 +15534,8 @@ __all__ = \
'nir_deref_type__enumvalues', 'nir_deref_type_array',
'nir_deref_type_array_wildcard', 'nir_deref_type_cast',
'nir_deref_type_ptr_as_array', 'nir_deref_type_struct',
'nir_deref_type_var', 'nir_discard', 'nir_discard_if',
'nir_deref_type_var', 'nir_deserialize',
'nir_deserialize_function', 'nir_discard', 'nir_discard_if',
'nir_divergence_analysis', 'nir_divergence_analysis_impl',
'nir_divergence_ignore_undef_if_phi_srcs',
'nir_divergence_multiple_workgroup_per_compute_subgroup',
@@ -16811,7 +16836,8 @@ __all__ = \
'nir_selection_control_divergent_always_taken',
'nir_selection_control_dont_flatten',
'nir_selection_control_flatten', 'nir_selection_control_none',
'nir_shader', 'nir_shader_add_variable', 'nir_shader_alu_pass',
'nir_serialize', 'nir_serialize_function', 'nir_shader',
'nir_shader_add_variable', 'nir_shader_alu_pass',
'nir_shader_as_str', 'nir_shader_as_str_annotated',
'nir_shader_clear_pass_flags', 'nir_shader_clone',
'nir_shader_compiler_options', 'nir_shader_create',

View File

@@ -1,12 +1,18 @@
from __future__ import annotations
import platform, sys, ctypes, functools, time, mmap, threading, queue
from tinygrad.helpers import from_mv, to_mv, OSX, WIN, mv_address, wait_cond, cpu_profile, CPU_LLVM, suppress_finalizing
import platform, sys, ctypes, functools, time, mmap, threading, queue, struct
from tinygrad.helpers import from_mv, to_mv, OSX, WIN, mv_address, wait_cond, cpu_profile, CPU_LLVM, suppress_finalizing, getenv, data64_le, i2u
from tinygrad.device import BufferSpec, DMACPURef
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocatorBase, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface
from tinygrad.runtime.support.hcq import CLikeArgsState
from tinygrad.renderer.cstyle import ClangRenderer
from tinygrad.renderer.llvmir import LLVMRenderer
from tinygrad.runtime.support.compiler_cpu import HostLLVMCompiler, ClangJITCompiler
from tinygrad.uop.ops import sint
if (NIR := getenv("NIR")):
from tinygrad.renderer.nir import LVPRenderer
from tinygrad.runtime.support.lvp import LVPCompiler
from tinygrad.runtime.support.elf import elf_loader
import tinygrad.runtime.autogen.libc as libc
class CPUSignal(HCQSignal):
def _sleep(self, time_spent_waiting_ms:int):
@@ -46,6 +52,9 @@ class CPUComputeQueue(HWQueue):
def memory_barrier(self): return self
def exec(self, prg:CPUProgram, args_state:HCQArgsState, global_size, local_size):
if NIR:
self.bind_args_state(args_state)
return self.cmd(self._exec, prg, 1, ctypes.cast((ctypes.c_int * 3)(*data64_le(args_state.buf.va_addr), 0xFF), ctypes.c_void_p).value)
return self.cmd(self._exec, prg, len(args_state.bufs), *[x.va_addr for x in args_state.bufs], *args_state.vals, threads=(global_size or (1,))[0])
def wait(self, signal, value=0): return self.cmd(self._wait, signal.value_addr, value)
def timestamp(self, signal): return self.cmd(self._timestamp, signal.timestamp_addr)
@@ -74,7 +83,14 @@ class CPUProgram(HCQProgram):
self.mem = mmap.mmap(-1, len(lib), mmap.MAP_ANON|mmap.MAP_PRIVATE|(MAP_JIT if OSX else 0), mmap.PROT_READ|mmap.PROT_WRITE|mmap.PROT_EXEC)
if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(False)
self.mem.write(lib)
if NIR:
(image, _, relocs), addr = elf_loader(lib), ctypes.addressof(ctypes.c_void_p.from_buffer(self.mem))
for ploc,tgt,r_type,r_addend in relocs:
match r_type:
case libc.R_X86_64_64: image[ploc:ploc+8] = struct.pack("<Q", i2u(64, tgt+r_addend+addr))
case _: raise NotImplementedError(f"Encountered unknown relocation type {r_type}")
self.mem.write(image)
else: self.mem.write(lib)
if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(True)
# __clear_cache isn't a normal libc function, but a compiler support routine found in libgcc_s for gcc and compiler-rt for clang.
@@ -85,7 +101,7 @@ class CPUProgram(HCQProgram):
self.fxn = ctypes.CFUNCTYPE(None)(mv_address(self.mem))
super().__init__(HCQArgsState, dev, name, kernargs_alloc_size=0)
super().__init__(CLikeArgsState if NIR else HCQArgsState, dev, name, kernargs_alloc_size=0)
@suppress_finalizing
def __del__(self):
@@ -116,5 +132,5 @@ class CPUDevice(HCQCompiled):
def __init__(self, device:str=""):
self.tasks:queue.Queue = queue.Queue()
CPUWorker(self, self.tasks, thread_id=0).start()
super().__init__(device, CPUAllocator(self), LLVMRenderer() if CPU_LLVM else ClangRenderer(),
HostLLVMCompiler() if CPU_LLVM else ClangJITCompiler(), functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue)
super().__init__(device, CPUAllocator(self), LLVMRenderer() if CPU_LLVM else (LVPRenderer(self) if NIR else ClangRenderer()), HostLLVMCompiler()
if CPU_LLVM else (LVPCompiler() if NIR else ClangJITCompiler()), functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue)

View File

@@ -17,7 +17,7 @@ from tinygrad.runtime.support.nv.nvdev import NVDev, NVMemoryManager
from tinygrad.runtime.support.system import System, PCIIfaceBase, MAP_FIXED
if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401 # pylint: disable=unused-import
if (NIR := getenv("NIR")):
from tinygrad.renderer.nir import NIRRenderer
from tinygrad.renderer.nir import NAKRenderer
from tinygrad.runtime.support.nak import NAKCompiler, parse_nak_shader
def get_error_str(status): return f"{status}: {nv_gpu.nv_status_codes.get(status, 'Unknown error')}"
@@ -538,7 +538,7 @@ class NVDevice(HCQCompiled[HCQSignal]):
self.sass_version = ((self.sm_version & 0xf00) >> 4) | (self.sm_version & 0xf)
cc = NAKCompiler(self) if NIR else ((PTXCompiler if PTX else CUDACompiler) if MOCKGPU else (NVPTXCompiler if PTX else NVCompiler))(self.arch)
rr = PTXRenderer(self.arch, device="NV") if PTX else (NIRRenderer(self) if NIR else NVRenderer(self.arch))
rr = PTXRenderer(self.arch, device="NV") if PTX else (NAKRenderer(self) if NIR else NVRenderer(self.arch))
super().__init__(device, NVAllocator(self), rr, cc, functools.partial(NVProgram, self), HCQSignal, NVComputeQueue, NVCopyQueue)
self._setup_gpfifos()

View File

@@ -19,6 +19,7 @@ def elf_loader(blob:bytes, force_section_align:int=1) -> tuple[memoryview, list[
rela = [(sh, sh.name[5:], _to_carray(sh, libc.Elf64_Rela)) for sh in sections if sh.header.sh_type == libc.SHT_RELA]
symtab = [_to_carray(sh, libc.Elf64_Sym) for sh in sections if sh.header.sh_type == libc.SHT_SYMTAB][0]
progbits = [sh for sh in sections if sh.header.sh_type == libc.SHT_PROGBITS]
progbits.sort(key=lambda sh: sh.name != ".ltext") # LLVM 20 ORCJIT puts .ltext after .rodata
# Prealloc image for all fixed addresses.
image = bytearray(max([sh.header.sh_addr + sh.header.sh_size for sh in progbits if sh.header.sh_addr != 0] + [0]))
@@ -31,6 +32,7 @@ def elf_loader(blob:bytes, force_section_align:int=1) -> tuple[memoryview, list[
# Relocations
relocs = []
for sh, trgt_sh_name, c_rels in rel + rela:
if trgt_sh_name not in [sh.name for sh in progbits]: continue
target_image_off = next(tsh for tsh in sections if tsh.name == trgt_sh_name).header.sh_addr
rels = [(r.r_offset, symtab[libc.ELF64_R_SYM(r.r_info)], libc.ELF64_R_TYPE(r.r_info), getattr(r, "r_addend", 0)) for r in c_rels]
for roff, sym, r_type_, r_addend in rels:

View File

@@ -1,27 +1,29 @@
from typing import cast, Tuple
from typing import Tuple
from tinygrad.device import Compiler
from tinygrad.helpers import round_up
import tinygrad.runtime.autogen.nak as nak
import tinygrad.runtime.autogen.nir as nir
import ctypes, pathlib, tempfile, hashlib, subprocess, functools
from tinygrad.runtime.autogen.nir import nir_shader_compiler_options
import ctypes, pathlib, tempfile, hashlib, subprocess
class NAKCompiler(Compiler):
def __init__(self, dev, cache_key="nak"):
self.arch = dev.arch
self.cc = nak.nak_compiler_create(nak.struct_nv_device_info(sm=int(dev.arch[3:]), max_warps_per_mp=dev.max_warps_per_sm))
self.nir_options = ctypes.cast(nak.nak_nir_options(self.cc), ctypes.POINTER(nir_shader_compiler_options))
super().__init__(f"compile_{cache_key}_{dev.arch}")
def compile(self, src) -> bytes:
shader = nak.struct_nir_shader.from_buffer(cast(bytes, src))
nak.glsl_type_singleton_init_or_ref() # TODO: call glsl_type_singleton_decref somewhere
blobreader = nak.struct_blob_reader()
nak.blob_reader_init(blobreader, src, len(src))
shader = nak.nir_deserialize(None, ctypes.cast(self.nir_options, ctypes.POINTER(nak.nir_shader_compiler_options)), blobreader)
nak.nak_preprocess_nir(shader, self.cc)
return nak.nak_compile_shader(shader, False, self.cc, 0, None).contents
def disassemble(self, lib: bytes):
try:
fn = (pathlib.Path(tempfile.gettempdir()) / f"tinynir_{hashlib.md5(lib).hexdigest()}").as_posix()
fn = (pathlib.Path(tempfile.gettempdir()) / f"tinynak_{hashlib.md5(lib).hexdigest()}").as_posix()
with open(fn, "wb") as f: f.write(parse_nak_shader(lib)[0])
print(subprocess.check_output(['nvdisasm', "-b", f"SM{self.arch[3:]}", fn]).decode('utf-8'))
except Exception as e: print("Failed to generate SASS", str(e), "Make sure your PATH contains nvdisasm binary of compatible version.")
@functools.cached_property
def nir_options(self): return ctypes.cast(nak.nak_nir_options(self.cc), ctypes.POINTER(nir.nir_shader_compiler_options))
def parse_nak_shader(shader:bytes) -> Tuple[memoryview, int, int, int]:
sb = nak.struct_nak_shader_bin.from_buffer(shader)