diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 30a337759a..957f529c1b 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -1219,7 +1219,7 @@ class TestLinearizer(unittest.TestCase): assert len(sched) == 1 lin = Kernel(sched[0].ast) - assert sum(u.op is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg + assert sum(u.op in {UnaryOps.RECIP, BinaryOps.FDIV} for u in lin.linearize().uops) == max_ops, msg a = Tensor.empty((4,4)) b = Tensor.empty((4,4)) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 06cbbeb952..031de098b1 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -140,7 +140,7 @@ class Ops(FastEnum): # BinaryOps ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702 - SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto() # noqa: E702 + SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto() # noqa: E702 # TernaryOps WHERE = auto(); MULACC = auto() # noqa: E702 @@ -168,7 +168,8 @@ class Ops(FastEnum): class GroupOp: Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG} - Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB} + Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, + Ops.SUB, Ops.FDIV} Ternary = {Ops.WHERE, Ops.MULACC} ALU = set.union(Unary, Binary, Ternary) diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 9e5061a5be..897341c7f2 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -1,50 +1,77 @@ -from typing import Dict, Callable, List, Optional -from llvmlite import ir -from tinygrad.dtype import DType, PtrDType, dtypes -from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, GroupOp +from 
typing import List, Dict, cast +import math, struct from tinygrad.renderer import Renderer +from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp +from tinygrad.dtype import dtypes, DType, PtrDType, truncate -MFLAGS = ('nsz', 'arcp', 'contract', 'afn') # All from fast math, but nnan and ninf and reassoc +def ldt(dt:DType): + if isinstance(dt, PtrDType): return ldt(dt.base) + "*" + return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64", + dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64", + dtypes.float16: "half", dtypes.float32: "float", dtypes.float64: "double", dtypes.bool: "i1", dtypes.void: "void"}[dt] -def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype) - -dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16), - dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64), - dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.float64:ir.DoubleType() } - -def cast(bb, val, input_type, output_type, bitcast=False): - if input_type == output_type: return val - llvm_type = dtype_to_llvm_dtype[output_type] - if bitcast: return bb[-1].bitcast(val, llvm_type) - - if input_type == dtypes.bfloat16: - val = bb[-1].bitcast(bb[-1].shl(bb[-1].sext(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)),val, ir.FloatType()) - input_type = dtypes.float32 - if output_type == dtypes.bfloat16: - val = cast(bb, val, input_type, dtypes.float32) - return bb[-1].trunc(bb[-1].lshr(bb[-1].bitcast(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)), ir.IntType(16)) +def lconst(x, dtype:DType): + if dtype in dtypes.floats: + if math.isinf(x) or math.isnan(x): return "0x%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1]) + return 
truncate[dtype](x) + return int(x) +def lcast(input_type:DType, output_type:DType): if dtypes.is_float(input_type): - if dtypes.is_float(output_type): - return bb[-1].fpext(val, llvm_type) if output_type.itemsize > input_type.itemsize else bb[-1].fptrunc(val, llvm_type) - if dtypes.is_int(output_type): return bb[-1].fptoui(val, llvm_type) if dtypes.is_unsigned(output_type) else bb[-1].fptosi(val, llvm_type) - if output_type == dtypes.bool: return bb[-1].fcmp_unordered('!=', cast(bb, val, input_type, dtypes.float32), ir.Constant(ir.FloatType(), 0)) - + if dtypes.is_float(output_type): return 'fpext' if output_type.itemsize > input_type.itemsize else 'fptrunc' + if dtypes.is_int(output_type): return 'fptoui' if dtypes.is_unsigned(output_type) else 'fptosi' if dtypes.is_unsigned(input_type) or input_type == dtypes.bool: - if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].uitofp(val, ir.FloatType()), ir.HalfType()) - if dtypes.is_float(output_type): return bb[-1].uitofp(val, dtype_to_llvm_dtype[output_type]) - if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].zext(val, llvm_type) - if output_type == dtypes.bool: return bb[-1].icmp_unsigned('!=', val, ir.Constant(val.type, 0)) - + if dtypes.is_float(output_type): return 'uitofp' + if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'zext' if dtypes.is_int(input_type): - if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].sitofp(val, ir.FloatType()), ir.HalfType()) - if dtypes.is_float(output_type): return bb[-1].sitofp(val, llvm_type) - if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].sext(val, llvm_type) - if output_type == dtypes.bool: return bb[-1].icmp_signed('!=', val, ir.Constant(val.type, 0)) - + if dtypes.is_float(output_type): return 'sitofp' + if dtypes.is_int(output_type): return 'trunc' if 
output_type.itemsize < input_type.itemsize else 'sext' raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented") -def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], args) +# llvm ops, lop[][] +unsigned_lop = { Ops.ADD: "add", Ops.MUL: "mul", Ops.IDIV: "udiv", Ops.MOD: "urem", + Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor", } +signed_lop = {**unsigned_lop, Ops.CMPLT: "icmp slt", Ops.IDIV: "sdiv", Ops.MOD: "srem"} +flags = " nsz arcp contract afn" +float_lop = {Ops.ADD: "fadd"+flags, Ops.MUL: "fmul"+flags, Ops.CMPLT: f"fcmp{flags} ult", Ops.CMPNE: f"fcmp{flags} une", Ops.FDIV: "fdiv"+flags} +lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop for x in dtypes.sints}, **{x:float_lop for x in dtypes.floats}} + +llvm_rewrite = PatternMatcher([ + # memory load/store + (UPat(Ops.INDEX, name="x"), lambda ctx,x: + f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {ctx[x.src[1]]}"), + (UPat(Ops.LOAD, src=(UPat.var('idx'), UPat.var('alt'), UPat.var('mask')), name="x"), lambda ctx,x,idx,alt,mask: + f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n" + f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n" + f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n" + f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n" + f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"), + (UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"), + (UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"), + + # unary/binary/ternary ops + (UPat(Ops.SQRT, name="x"), lambda ctx,x: + f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"), + 
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"), + (UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"), + (UPat(GroupOp.Binary, name="x"), lambda ctx,x: f" {ctx[x]} = {lop[x.src[0].dtype][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"), + (UPat(Ops.WHERE, name="x"), lambda ctx,x: + f" {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[2].dtype)} {ctx[x.src[2]]}"), + + # range + (UPat(Ops.RANGE, name="x"), lambda ctx,x: + f" br label %loop_entry_{x.arg[0]}\nloop_entry_{x.arg[0]}:\n" + f" br label %loop_body_{x.arg[0]}\nloop_body_{x.arg[0]}:\n" + f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg[0]}], [{ctx[x]}phi, %loop_latch_{x.arg[0]}]"), + (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x: + f" br label %loop_latch_{x.src[0].arg[0]}\nloop_latch_{x.src[0].arg[0]}:\n" + f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n" + f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg[0]}, label %loop_exit_{x.src[0].arg[0]}\nloop_exit_{x.src[0].arg[0]}:"), + + # if + (UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"), + (UPat(Ops.ENDIF, name="x"), lambda ctx,x: f" br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"), +]) class LLVMRenderer(Renderer): device = "LLVM" @@ -52,101 +79,64 @@ class LLVMRenderer(Renderer): has_local = False has_shared = False global_max = None - code_for_op: Dict[Ops, Callable] = { - UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS), - UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS), - BinaryOps.ADD: 
lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501 - BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501 - BinaryOps.IDIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y), - BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501 - BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS), # noqa: E501 - BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501 - BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501 - BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y), BinaryOps.AND: lambda builder, x, y, dtype: builder.and_(x, y), BinaryOps.OR: lambda builder, x, y, dtype: builder.or_(x, y), # noqa: E501 - BinaryOps.SHL: lambda builder, x, y, dtype: builder.shl(x, y), BinaryOps.SHR: lambda builder, x, y, dtype: builder.lshr(x, y) if dtypes.is_unsigned(dtype) else builder.ashr(x, y), # noqa: E501 - TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)} - def render(self, name:str, uops:List[UOp]) -> str: - # all llvm stuff goes into a module - module = ir.Module(name=__file__) + extra_matcher = 
PatternMatcher([ + # rewrite RECIP with FDIV + (UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))), + # rewrite cast to bool to CMPNE 0 + (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)), + # *** also in cstyle *** + # gate any stores that aren't gated with ifs + (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"), + lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))), + # rewrite MAX to CMPLT + WHERE + (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])), + ]) - # extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order) - buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}} - buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())} + def render(self, name: str, uops: List[UOp]) -> str: + r: Dict[UOp, str] = {} + args: List[str] = [] + kernel: List[str] = [] + end_lines: Dict[str, None] = {} + vc = -1 - # create llvm function - func_dtypes = [(dtype_to_llvm_dtype[dtype.base if isinstance(dtype, PtrDType) else dtype],dtype) for dtype in buf_to_dtype.values()] - func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name) - for a in func.args: - if a.type.is_pointer: a.add_attribute("noalias") - - bb = [ir.IRBuilder(func.append_basic_block("entry"))] - loop_blocks: List = [] - reduce_phis: List = [] - lvars: Dict[Optional[UOp], ir.Instruction] = {} - - for bufname,dtype in buf_to_dtype.items(): - if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32)) + # prealloc all assigns + acc_to_assign: Dict[UOp, UOp] = {} + for u in uops: + if u.op is Ops.ASSIGN: + vc += 1 + r[u] = r[u.src[1]] = f"%assign{vc}" + assert u.src[0] not in acc_to_assign, "can't 
assign to DEFINE_ACC twice" + acc_to_assign[u.src[0]] = u.src[1] for u in uops: - uop,dtype,src,args = u.op,u.dtype,u.src,u.arg - if uop is Ops.INDEX: - lvars[u] = bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True) - elif uop is Ops.STORE: - if len(src) > 2: - with bb[-1].if_then(lvars[src[2]]): bb[-1].store(lvars[src[1]], lvars[src[0]]) - else: - bb[-1].store(lvars[src[1]], lvars[src[0]]) - elif uop is Ops.ENDRANGE: - loop_entry_bb, phis = loop_blocks.pop() - idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1)) - lvars[src[0]].add_incoming(idx_p1, bb[-1].block) - for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block) - bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}"))) - bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block) + # hack for defining sqrt function (TODO: can we get a transcendental for this?) + if u.op is Ops.SQRT: end_lines[f'declare {ldt(u.dtype)} @llvm.sqrt.{ldt(u.dtype)}({ldt(u.dtype)} %".1")'] = None + + if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR): + r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}" + args.append(f"{ldt(u.dtype)}{' noalias' if isinstance(u.dtype, PtrDType) else ''} {r[u]}") + elif u.op is Ops.ASSIGN: pass # assign is already handled by the first pass + elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]] # a define acc can be used and never be assigned to + elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype) + elif u.op is Ops.CAST and ldt(u.dtype) == ldt(u.src[0].dtype): r[u] = r[u.src[0]] # cast from signed to unsigned of the same size is a noop else: - if uop is Ops.RANGE: - bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}"))) - bb[-2].branch(bb[-1].block) + # if it's an assign target, it's already preallocated + if u not in r: + vc += 1 + r[u] = f"%v{vc}" - phis = [] - for rp in reduce_phis: - incoming = lvars[rp] - lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype]) - 
lvars[rp].add_incoming(incoming, bb[-2].block) - phis.append((rp, lvars[rp])) + # do the rendering of the llvm ir code + if (l:=llvm_rewrite.rewrite(u, ctx=r)) is None: raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}") + kernel.append(cast(str, l)) - lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}") - lvars[u].add_incoming(lvars[src[0]], bb[-2].block) - loop_blocks.append((bb[-1].block, phis)) - elif uop is Ops.DEFINE_ACC: - lvars[u] = const(src[0].arg, dtype) - reduce_phis.append(u) - elif uop is Ops.LOAD: - if len(src) > 1: - with bb[-1].if_else(lvars[src[2]]) as (then, otherwise): - with then: - val1 = bb[-1].load(lvars[src[0]]) - then_blk = bb[-1].block - with otherwise: otherwise_blk = bb[-1].block - val = bb[-1].phi(val1.type) - val.add_incoming(val1, then_blk) - val.add_incoming(lvars[src[1]], otherwise_blk) - else: - val = bb[-1].load(lvars[src[0]]) - lvars[u] = val - elif uop is Ops.ASSIGN: - lvars[u] = lvars[src[1]] - # ASSIGN UOps can link to other ASSIGN Uops, backtrace this to DEFINE_ACC - backward = src[0] - while backward.op is Ops.ASSIGN: backward = backward.src[0] - lvars[backward] = lvars[u] - elif uop in GroupOp.ALU: - lvars[u] = self.code_for_op[uop](bb[-1], *[lvars[x] for x in src], src[0].dtype if uop in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype) - elif uop in {Ops.CAST, Ops.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is Ops.BITCAST) - elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]] - elif uop is Ops.CONST: lvars[u] = const(args, dtype) - else: raise RuntimeError(f"failed to render {uop}") + # generate the phi nodes for the assigns + if u.op is Ops.RANGE: + for x in acc_to_assign: + if u in x.src: # if this range is relevant for this acc + vc += 1 + kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg[0]}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg[0]}]") + r[x] = f"%acc{vc}" -
bb[-1].ret_void() - return str(module) + # output the function + return f"define void @{name}({','.join(args)}) {{\n" + '\n'.join(kernel) + "\n ret void\n}\n"+'\n'.join(end_lines.keys()) diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index c897b2e574..5a16f5650c 100644 --- a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -23,6 +23,7 @@ class LLVMProgram: self.name, self.lib = name, lib device.engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib)) self.fxn = device.engine.get_function_address(name) + assert self.fxn != 0, "LLVM failed to get function address" def __call__(self, *bufs, vals:Tuple[int, ...]=(), wait=False): if not hasattr(self, 'cfunc'):