mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
llvm diagnostic error (#10267)
* llvm diagnostic info * use decorator * better error reporting * fix mypy * collect all diag msgs * test diag error --------- Co-authored-by: b1tg <b1tg@users.noreply.github.com> Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Device
|
||||
from tinygrad.device import CompileError
|
||||
from tinygrad.helpers import flat_mv
|
||||
if Device.DEFAULT=="AMD":
|
||||
from tinygrad.runtime.ops_amd import AMDAllocator, AMDDevice, AMDProgram
|
||||
@@ -28,5 +29,24 @@ entry:
|
||||
allocator._copyout(flat_mv(na.data), a)
|
||||
assert na == [0x1234567800000005]
|
||||
|
||||
def test_compiler_diag_error(self):
|
||||
src = """
|
||||
@local_temp0 = internal unnamed_addr addrspace(3) global [{N} x float*] undef, align 16
|
||||
define amdgpu_kernel void @test(float* noalias align 32 %data0, half* noalias align 32 %data1, float* noalias align 32 %data2) #0
|
||||
{{
|
||||
%local_temp0 = addrspacecast [{N} x float*] addrspace(3)* @local_temp0 to [{N} x float*]*
|
||||
%v178 = getelementptr inbounds float, float* %local_temp0, i32 1
|
||||
%v133 = getelementptr inbounds float, float* %data2, i32 1
|
||||
%v134 = load float, float* %v133
|
||||
store float %v134, float* %v178
|
||||
ret void
|
||||
}}
|
||||
"""
|
||||
compiler = AMDLLVMCompiler("gfx1100")
|
||||
compiler.compile(src.format(N=65536//8))
|
||||
with self.assertRaises(CompileError):
|
||||
# llvm diagnostic: <unknown>:0:0: local memory (65544) exceeds limit (65536) in function 'test'
|
||||
compiler.compile(src.format(N=65536//8+1))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -33,11 +33,21 @@ class LLVMCompiler(Compiler):
|
||||
else:
|
||||
self.passes = b'default<O0>'
|
||||
|
||||
self.diag_msgs: list[str] = []
|
||||
@ctypes.CFUNCTYPE(None, llvm.LLVMDiagnosticInfoRef, ctypes.c_void_p)
|
||||
def handle_diag(diag_ref, _arg):
|
||||
severity = llvm.LLVMGetDiagInfoSeverity(diag_ref)
|
||||
msg = ctypes.string_at(llvm.LLVMGetDiagInfoDescription(diag_ref)).decode()
|
||||
if severity == llvm.LLVMDSError:
|
||||
self.diag_msgs.append(msg)
|
||||
self.handle_diag = handle_diag
|
||||
llvm.LLVMContextSetDiagnosticHandler(llvm.LLVMGetGlobalContext(), handle_diag, None)
|
||||
super().__init__(f"compile_llvm_{self.target_arch}{'_jit' if self.jit else ''}{'_opt' if opt else ''}")
|
||||
|
||||
def __del__(self): llvm.LLVMDisposePassBuilderOptions(self.pbo)
|
||||
|
||||
def compile(self, src:str) -> bytes:
|
||||
self.diag_msgs.clear()
|
||||
src_buf = llvm.LLVMCreateMemoryBufferWithMemoryRangeCopy(ctypes.create_string_buffer(src_bytes:=src.encode()), len(src_bytes), b'src')
|
||||
mod = expect(llvm.LLVMParseIRInContext(llvm.LLVMGetGlobalContext(), src_buf, ctypes.pointer(m:=llvm.LLVMModuleRef()), err:=cerr()), err, m)
|
||||
expect(llvm.LLVMVerifyModule(mod, llvm.LLVMReturnStatusAction, err:=cerr()), err)
|
||||
@@ -48,6 +58,7 @@ class LLVMCompiler(Compiler):
|
||||
llvm.LLVMDisposeModule(mod)
|
||||
obj = ctypes.string_at(llvm.LLVMGetBufferStart(obj_buf), llvm.LLVMGetBufferSize(obj_buf))
|
||||
llvm.LLVMDisposeMemoryBuffer(obj_buf)
|
||||
if self.diag_msgs: raise RuntimeError("llvm diagnostic: " + "\n".join(self.diag_msgs))
|
||||
return jit_loader(obj) if self.jit else obj
|
||||
|
||||
def disassemble(self, lib:bytes): capstone_flatdump(lib)
|
||||
|
||||
Reference in New Issue
Block a user