llvm diagnostic error (#10267)

* llvm diagnostic info

* use decorator

* better error reporting

* fix mypy

* collect all diag msgs

* test diag error

---------

Co-authored-by: b1tg <b1tg@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
b1tg
2025-05-16 14:03:20 +08:00
committed by GitHub
parent a4a25720b2
commit caded2f413
2 changed files with 31 additions and 0 deletions

View File

@@ -1,6 +1,7 @@
import unittest
import numpy as np
from tinygrad import Device
from tinygrad.device import CompileError
from tinygrad.helpers import flat_mv
if Device.DEFAULT=="AMD":
from tinygrad.runtime.ops_amd import AMDAllocator, AMDDevice, AMDProgram
@@ -28,5 +29,24 @@ entry:
allocator._copyout(flat_mv(na.data), a)
assert na == [0x1234567800000005]
def test_compiler_diag_error(self):
  # Compile the same kernel twice: once with a local array of 65536//8 entries, which must succeed,
  # and once with a single extra entry, which must be rejected with CompileError.
  ir_template = """
@local_temp0 = internal unnamed_addr addrspace(3) global [{N} x float*] undef, align 16
define amdgpu_kernel void @test(float* noalias align 32 %data0, half* noalias align 32 %data1, float* noalias align 32 %data2) #0
{{
%local_temp0 = addrspacecast [{N} x float*] addrspace(3)* @local_temp0 to [{N} x float*]*
%v178 = getelementptr inbounds float, float* %local_temp0, i32 1
%v133 = getelementptr inbounds float, float* %data2, i32 1
%v134 = load float, float* %v133
store float %v134, float* %v178
ret void
}}
"""
  max_elems = 65536//8
  fits_src = ir_template.format(N=max_elems)
  overflow_src = ir_template.format(N=max_elems+1)
  amd_compiler = AMDLLVMCompiler("gfx1100")
  # baseline: this size is expected to compile without raising
  amd_compiler.compile(fits_src)
  # llvm diagnostic: <unknown>:0:0: local memory (65544) exceeds limit (65536) in function 'test'
  with self.assertRaises(CompileError):
    amd_compiler.compile(overflow_src)
if __name__ == '__main__':
unittest.main()

View File

@@ -33,11 +33,21 @@ class LLVMCompiler(Compiler):
else:
self.passes = b'default<O0>'
self.diag_msgs: list[str] = []
@ctypes.CFUNCTYPE(None, llvm.LLVMDiagnosticInfoRef, ctypes.c_void_p)
def handle_diag(diag_ref, _arg):
  # Diagnostic callback: appends the description of every error-severity diagnostic to self.diag_msgs.
  # Diagnostics of any other severity are not recorded, so their description is never requested.
  if llvm.LLVMGetDiagInfoSeverity(diag_ref) != llvm.LLVMDSError: return
  # Per the LLVM-C API, the description string is owned by the caller and must be released with
  # LLVMDisposeMessage; copy it into a Python str first, then free it even if decoding fails.
  desc = llvm.LLVMGetDiagInfoDescription(diag_ref)
  try: self.diag_msgs.append(ctypes.string_at(desc).decode())
  finally: llvm.LLVMDisposeMessage(desc)
self.handle_diag = handle_diag
llvm.LLVMContextSetDiagnosticHandler(llvm.LLVMGetGlobalContext(), handle_diag, None)
super().__init__(f"compile_llvm_{self.target_arch}{'_jit' if self.jit else ''}{'_opt' if opt else ''}")
def __del__(self):
  """Finalizer: dispose the object stored in self.pbo via LLVMDisposePassBuilderOptions."""
  pass_builder_opts = self.pbo
  llvm.LLVMDisposePassBuilderOptions(pass_builder_opts)
def compile(self, src:str) -> bytes:
self.diag_msgs.clear()
src_buf = llvm.LLVMCreateMemoryBufferWithMemoryRangeCopy(ctypes.create_string_buffer(src_bytes:=src.encode()), len(src_bytes), b'src')
mod = expect(llvm.LLVMParseIRInContext(llvm.LLVMGetGlobalContext(), src_buf, ctypes.pointer(m:=llvm.LLVMModuleRef()), err:=cerr()), err, m)
expect(llvm.LLVMVerifyModule(mod, llvm.LLVMReturnStatusAction, err:=cerr()), err)
@@ -48,6 +58,7 @@ class LLVMCompiler(Compiler):
llvm.LLVMDisposeModule(mod)
obj = ctypes.string_at(llvm.LLVMGetBufferStart(obj_buf), llvm.LLVMGetBufferSize(obj_buf))
llvm.LLVMDisposeMemoryBuffer(obj_buf)
if self.diag_msgs: raise RuntimeError("llvm diagnostic: " + "\n".join(self.diag_msgs))
return jit_loader(obj) if self.jit else obj
def disassemble(self, lib:bytes):
  """Forward the compiled bytes in `lib` to capstone_flatdump; nothing is returned."""
  capstone_flatdump(lib)