diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index 0dff0e4de3..cd22d83dff 100644 --- a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -11,12 +11,15 @@ def expect(x, err, ret=None): if x: raise RuntimeError(llvm.string_cast(err.contents) if not isinstance(err, str) else err) return ret -HOST_ARCH = {'arm64': 'AArch64', 'aarch64': 'AArch64', 'x86_64': 'X86', 'AMD64': 'X86'}[platform.machine()] -HOST_TRIPLE = {'AArch64': 'aarch64', 'X86': 'x86_64'}[HOST_ARCH] -REQUIRED_COMPONENTS = ['Target', 'TargetInfo', 'TargetMC', 'AsmPrinter'] - class LLVMCompiler(Compiler): - def __init__(self, target_machine, opt): + def __init__(self, host_arch:str, opt:bool): + for component in ['Target', 'TargetInfo', 'TargetMC', 'AsmPrinter']: getattr(llvm, f'LLVMInitialize{host_arch}{component}')() + triple = ({'AArch64': 'aarch64', 'X86': 'x86_64'}[host_arch]+'-none-unknown-elf').encode() + + target = expect(llvm.LLVMGetTargetFromTriple(triple, ctypes.pointer(tgt:=llvm.LLVMTargetRef()), err:=cerr()), err, tgt) + target_machine = llvm.LLVMCreateTargetMachine(target, triple, b'', b'+reserve-x18' if host_arch == 'arm64' else b'', + llvm.LLVMCodeGenLevelDefault, llvm.LLVMRelocPIC, llvm.LLVMCodeModelDefault) + self.pbo = llvm.LLVMCreatePassBuilderOptions() if opt: self.passes = b'default' @@ -48,14 +51,5 @@ class LLVMCompiler(Compiler): class LLVMDevice(Compiled): def __init__(self, device:str): - for component in REQUIRED_COMPONENTS: - getattr(llvm, f'LLVMInitialize{HOST_ARCH}{component}')() - - triple = f'{HOST_TRIPLE}-none-unknown-elf'.encode() - target = expect(llvm.LLVMGetTargetFromTriple(triple, ctypes.pointer(tgt:=llvm.LLVMTargetRef()), err:=cerr()), err, tgt) - features = b'+reserve-x18' if platform.machine() == 'arm64' else b'' - target_machine = llvm.LLVMCreateTargetMachine(target, triple, b'', features, llvm.LLVMCodeGenLevelDefault, llvm.LLVMRelocPIC, - llvm.LLVMCodeModelDefault) - - super().__init__(device, MallocAllocator, LLVMRenderer('win64cc' if sys.platform == 'win32' else None), - LLVMCompiler(target_machine, getenv("LLVMOPT")), CPUProgram) + compiler = LLVMCompiler({'arm64': 'AArch64', 'aarch64': 'AArch64', 'x86_64': 'X86', 'AMD64': 'X86'}[platform.machine()], bool(getenv("LLVMOPT"))) + super().__init__(device, MallocAllocator, LLVMRenderer('win64cc' if sys.platform == 'win32' else None), compiler, CPUProgram)