replace llvm with new llvm (#7616)

* replace llvm with new llvm

* fix test_linearizer

* minor fixups

* fix alloca

* don't use alloca

* fix DEFINE_ACC

* lines

* comments and lines

* a little tighter
George Hotz
2024-11-10 11:28:52 +08:00
committed by GitHub
parent b61266eb97
commit 0a411b4f68
4 changed files with 122 additions and 130 deletions

test/test_linearizer.py

@@ -1219,7 +1219,7 @@ class TestLinearizer(unittest.TestCase):
       assert len(sched) == 1
       lin = Kernel(sched[0].ast)
-      assert sum(u.op is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg
+      assert sum(u.op in {UnaryOps.RECIP, BinaryOps.FDIV} for u in lin.linearize().uops) == max_ops, msg
 
     a = Tensor.empty((4,4))
     b = Tensor.empty((4,4))
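The widened assertion tracks the new backend behavior: the LLVM renderer's extra_matcher (in the llvmir.py diff below) rewrites RECIP into an FDIV uop, so a division can legitimately linearize to either op. A minimal sketch of the same count, assuming a tinygrad checkout at this commit (import paths of that era):

    from tinygrad import Tensor
    from tinygrad.ops import UnaryOps, BinaryOps
    from tinygrad.codegen.kernel import Kernel

    a, b = Tensor.empty(4, 4), Tensor.empty(4, 4)
    lin = Kernel((a / b).schedule()[-1].ast)
    # one elementwise divide should render exactly one RECIP-or-FDIV uop
    print(sum(u.op in {UnaryOps.RECIP, BinaryOps.FDIV} for u in lin.linearize().uops))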

tinygrad/ops.py

@@ -140,7 +140,7 @@ class Ops(FastEnum):
   # BinaryOps
   ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
-  SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto() # noqa: E702
+  SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto() # noqa: E702
   # TernaryOps
   WHERE = auto(); MULACC = auto() # noqa: E702
@@ -168,7 +168,8 @@ class Ops(FastEnum):
 class GroupOp:
   Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
-  Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB}
+  Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY,
+            Ops.SUB, Ops.FDIV}
   Ternary = {Ops.WHERE, Ops.MULACC}
   ALU = set.union(Unary, Binary, Ternary)
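FDIV joins GroupOp.Binary so anything that dispatches on op families picks it up for free; the renderer below leans on exactly this, matching UPat(GroupOp.Binary) with a single rule because every LLVM binary instruction shares the "%out = <opcode> <type> %a, %b" shape. A quick check, assuming this commit's tinygrad:

    from tinygrad.ops import Ops, GroupOp

    assert Ops.FDIV in GroupOp.Binary and Ops.FDIV in GroupOp.ALU
    print(sorted(op.name for op in GroupOp.Binary))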

tinygrad/renderer/llvmir.py

@@ -1,50 +1,77 @@
-from typing import Dict, Callable, List, Optional
-from llvmlite import ir
-from tinygrad.dtype import DType, PtrDType, dtypes
-from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, GroupOp
+from typing import List, Dict, cast
+import math, struct
 from tinygrad.renderer import Renderer
+from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
+from tinygrad.dtype import dtypes, DType, PtrDType, truncate
-MFLAGS = ('nsz', 'arcp', 'contract', 'afn') # All from fast math, but nnan and ninf and reassoc
+def ldt(dt:DType):
+  if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
+  return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
+          dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
+          dtypes.float16: "half", dtypes.float32: "float", dtypes.float64: "double", dtypes.bool: "i1", dtypes.void: "void"}[dt]
-def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)
-dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16),
-  dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64),
-  dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.float64:ir.DoubleType() }
-def cast(bb, val, input_type, output_type, bitcast=False):
-  if input_type == output_type: return val
-  llvm_type = dtype_to_llvm_dtype[output_type]
-  if bitcast: return bb[-1].bitcast(val, llvm_type)
-  if input_type == dtypes.bfloat16:
-    val = bb[-1].bitcast(bb[-1].shl(bb[-1].sext(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)), ir.FloatType())
-    input_type = dtypes.float32
-  if output_type == dtypes.bfloat16:
-    val = cast(bb, val, input_type, dtypes.float32)
-    return bb[-1].trunc(bb[-1].lshr(bb[-1].bitcast(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)), ir.IntType(16))
+def lconst(x, dtype:DType):
+  if dtype in dtypes.floats:
+    if math.isinf(x) or math.isnan(x): return "0x%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
+    return truncate[dtype](x)
+  return int(x)
+def lcast(input_type:DType, output_type:DType):
   if dtypes.is_float(input_type):
-    if dtypes.is_float(output_type):
-      return bb[-1].fpext(val, llvm_type) if output_type.itemsize > input_type.itemsize else bb[-1].fptrunc(val, llvm_type)
-    if dtypes.is_int(output_type): return bb[-1].fptoui(val, llvm_type) if dtypes.is_unsigned(output_type) else bb[-1].fptosi(val, llvm_type)
-    if output_type == dtypes.bool: return bb[-1].fcmp_unordered('!=', cast(bb, val, input_type, dtypes.float32), ir.Constant(ir.FloatType(), 0))
+    if dtypes.is_float(output_type): return 'fpext' if output_type.itemsize > input_type.itemsize else 'fptrunc'
+    if dtypes.is_int(output_type): return 'fptoui' if dtypes.is_unsigned(output_type) else 'fptosi'
   if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
-    if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].uitofp(val, ir.FloatType()), ir.HalfType())
-    if dtypes.is_float(output_type): return bb[-1].uitofp(val, dtype_to_llvm_dtype[output_type])
-    if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].zext(val, llvm_type)
-    if output_type == dtypes.bool: return bb[-1].icmp_unsigned('!=', val, ir.Constant(val.type, 0))
+    if dtypes.is_float(output_type): return 'uitofp'
+    if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'zext'
   if dtypes.is_int(input_type):
-    if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].sitofp(val, ir.FloatType()), ir.HalfType())
-    if dtypes.is_float(output_type): return bb[-1].sitofp(val, llvm_type)
-    if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].sext(val, llvm_type)
-    if output_type == dtypes.bool: return bb[-1].icmp_signed('!=', val, ir.Constant(val.type, 0))
+    if dtypes.is_float(output_type): return 'sitofp'
+    if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'sext'
   raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
-def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], args)
+# llvm ops, lop[<dtype>][<op>]
+unsigned_lop = { Ops.ADD: "add", Ops.MUL: "mul", Ops.IDIV: "udiv", Ops.MOD: "urem",
+                 Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor", }
+signed_lop = {**unsigned_lop, Ops.CMPLT: "icmp slt", Ops.IDIV: "sdiv", Ops.MOD: "srem"}
+flags = " nsz arcp contract afn"
+float_lop = {Ops.ADD: "fadd"+flags, Ops.MUL: "fmul"+flags, Ops.CMPLT: f"fcmp{flags} ult", Ops.CMPNE: f"fcmp{flags} une", Ops.FDIV: "fdiv"+flags}
+lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop for x in dtypes.sints}, **{x:float_lop for x in dtypes.floats}}
+llvm_rewrite = PatternMatcher([
+  # memory load/store
+  (UPat(Ops.INDEX, name="x"), lambda ctx,x:
+   f"  {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {ctx[x.src[1]]}"),
+  (UPat(Ops.LOAD, src=(UPat.var('idx'), UPat.var('alt'), UPat.var('mask')), name="x"), lambda ctx,x,idx,alt,mask:
+   f"  br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
+   f"  br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
+   f"  {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
+   f"  br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
+   f"  {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
+  (UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f"  {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
+  (UPat(Ops.STORE, name="x"), lambda ctx,x: f"  store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
+  # unary/binary/ternary ops
+  (UPat(Ops.SQRT, name="x"), lambda ctx,x:
+   f"  {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
+  (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"  {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
+  (UPat(Ops.CAST, name="x"), lambda ctx,x: f"  {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
+  (UPat(GroupOp.Binary, name="x"), lambda ctx,x: f"  {ctx[x]} = {lop[x.src[0].dtype][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
+  (UPat(Ops.WHERE, name="x"), lambda ctx,x:
+   f"  {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[2].dtype)} {ctx[x.src[2]]}"),
+  # range
+  (UPat(Ops.RANGE, name="x"), lambda ctx,x:
+   f"  br label %loop_entry_{x.arg[0]}\nloop_entry_{x.arg[0]}:\n"
+   f"  br label %loop_body_{x.arg[0]}\nloop_body_{x.arg[0]}:\n"
+   f"  {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg[0]}], [{ctx[x]}phi, %loop_latch_{x.arg[0]}]"),
+  (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
+   f"  br label %loop_latch_{x.src[0].arg[0]}\nloop_latch_{x.src[0].arg[0]}:\n"
+   f"  {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n  {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n"
+   f"  br i1 {ctx[x]}, label %loop_body_{x.src[0].arg[0]}, label %loop_exit_{x.src[0].arg[0]}\nloop_exit_{x.src[0].arg[0]}:"),
+  # if
+  (UPat(Ops.IF, name="x"), lambda ctx,x: f"  br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
+  (UPat(Ops.ENDIF, name="x"), lambda ctx,x: f"  br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"),
+])
 
 class LLVMRenderer(Renderer):
   device = "LLVM"
@@ -52,101 +79,64 @@ class LLVMRenderer(Renderer):
   has_local = False
   has_shared = False
   global_max = None
-  code_for_op: Dict[Ops, Callable] = {
-    UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
-    UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
-    BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
-    BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501
-    BinaryOps.IDIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y),
-    BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
-    BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS), # noqa: E501
-    BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
-    BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501
-    BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y), BinaryOps.AND: lambda builder, x, y, dtype: builder.and_(x, y), BinaryOps.OR: lambda builder, x, y, dtype: builder.or_(x, y), # noqa: E501
-    BinaryOps.SHL: lambda builder, x, y, dtype: builder.shl(x, y), BinaryOps.SHR: lambda builder, x, y, dtype: builder.lshr(x, y) if dtypes.is_unsigned(dtype) else builder.ashr(x, y), # noqa: E501
-    TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)}
-  def render(self, name:str, uops:List[UOp]) -> str:
-    # all llvm stuff goes into a module
-    module = ir.Module(name=__file__)
+  extra_matcher = PatternMatcher([
+    # rewrite RECIP with FDIV
+    (UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))),
+    # rewrite cast to bool to CMPNE 0
+    (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
+    # *** also in cstyle ***
+    # gate any stores that aren't gated with ifs
+    (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
+     lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
+    # rewrite MAX to CMPLT + WHERE
+    (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+  ])
-    # extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
-    buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}}
-    buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
+  def render(self, name: str, uops: List[UOp]) -> str:
+    r: Dict[UOp, str] = {}
+    args: List[str] = []
+    kernel: List[str] = []
+    end_lines: Dict[str, None] = {}
+    vc = -1
-    # create llvm function
-    func_dtypes = [(dtype_to_llvm_dtype[dtype.base if isinstance(dtype, PtrDType) else dtype],dtype) for dtype in buf_to_dtype.values()]
-    func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name)
-    for a in func.args:
-      if a.type.is_pointer: a.add_attribute("noalias")
-    bb = [ir.IRBuilder(func.append_basic_block("entry"))]
-    loop_blocks: List = []
-    reduce_phis: List = []
-    lvars: Dict[Optional[UOp], ir.Instruction] = {}
-    for bufname,dtype in buf_to_dtype.items():
-      if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
+    # prealloc all assigns
+    acc_to_assign: Dict[UOp, UOp] = {}
+    for u in uops:
+      if u.op is Ops.ASSIGN:
+        vc += 1
+        r[u] = r[u.src[1]] = f"%assign{vc}"
+        assert u.src[0] not in acc_to_assign, "can't assign to DEFINE_ACC twice"
+        acc_to_assign[u.src[0]] = u.src[1]
     for u in uops:
-      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
-      if uop is Ops.INDEX:
-        lvars[u] = bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True)
-      elif uop is Ops.STORE:
-        if len(src) > 2:
-          with bb[-1].if_then(lvars[src[2]]): bb[-1].store(lvars[src[1]], lvars[src[0]])
-        else:
-          bb[-1].store(lvars[src[1]], lvars[src[0]])
-      elif uop is Ops.ENDRANGE:
-        loop_entry_bb, phis = loop_blocks.pop()
-        idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
-        lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
-        for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
-        bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
-        bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
+      # hack for defining sqrt function (TODO: can we get a transcendental for this?)
+      if u.op is Ops.SQRT: end_lines[f'declare {ldt(u.dtype)} @llvm.sqrt.{ldt(u.dtype)}({ldt(u.dtype)} %".1")'] = None
+      if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
+        r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
+        args.append(f"{ldt(u.dtype)}{' noalias' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
+      elif u.op is Ops.ASSIGN: pass  # assign is already handled by the first pass
+      elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]]  # a define acc can be used and never be assigned to
+      elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
+      elif u.op is Ops.CAST and ldt(u.dtype) == ldt(u.src[0].dtype): r[u] = r[u.src[0]]  # cast from signed to unsigned of the same size is a noop
       else:
-        if uop is Ops.RANGE:
-          bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
-          bb[-2].branch(bb[-1].block)
+        # if it's an assign target, it's already preallocated
+        if u not in r:
+          vc += 1
+          r[u] = f"%v{vc}"
-          phis = []
-          for rp in reduce_phis:
-            incoming = lvars[rp]
-            lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
-            lvars[rp].add_incoming(incoming, bb[-2].block)
-            phis.append((rp, lvars[rp]))
+        # do the rendering of the llvm ir code
+        if (l:=llvm_rewrite.rewrite(u, ctx=r)) is None: raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
+        kernel.append(cast(str, l))
-          lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
-          lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
-          loop_blocks.append((bb[-1].block, phis))
-        elif uop is Ops.DEFINE_ACC:
-          lvars[u] = const(src[0].arg, dtype)
-          reduce_phis.append(u)
-        elif uop is Ops.LOAD:
-          if len(src) > 1:
-            with bb[-1].if_else(lvars[src[2]]) as (then, otherwise):
-              with then:
-                val1 = bb[-1].load(lvars[src[0]])
-                then_blk = bb[-1].block
-              with otherwise: otherwise_blk = bb[-1].block
-            val = bb[-1].phi(val1.type)
-            val.add_incoming(val1, then_blk)
-            val.add_incoming(lvars[src[1]], otherwise_blk)
-          else:
-            val = bb[-1].load(lvars[src[0]])
-          lvars[u] = val
-        elif uop is Ops.ASSIGN:
-          lvars[u] = lvars[src[1]]
-          # ASSIGN UOps can link to other ASSIGN Uops, backtrace this to DEFINE_ACC
-          backward = src[0]
-          while backward.op is Ops.ASSIGN: backward = backward.src[0]
-          lvars[backward] = lvars[u]
-        elif uop in GroupOp.ALU:
-          lvars[u] = self.code_for_op[uop](bb[-1], *[lvars[x] for x in src], src[0].dtype if uop in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype)
-        elif uop in {Ops.CAST, Ops.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is Ops.BITCAST)
-        elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
-        elif uop is Ops.CONST: lvars[u] = const(args, dtype)
-        else: raise RuntimeError(f"failed to render {uop}")
+        # generate the phi nodes for the assigns
+        if u.op is Ops.RANGE:
+          for x in acc_to_assign:
+            if u in x.src: # if this range is relevent for this acc
+              vc += 1
+              kernel.append(f"  %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg[0]}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg[0]}]")
+              r[x] = f"%acc{vc}"
-    bb[-1].ret_void()
-    return str(module)
+    # output the function
+    return f"define void @{name}({','.join(args)}) {{\n" + '\n'.join(kernel) + "\n  ret void\n}\n" + '\n'.join(end_lines.keys())

tinygrad/runtime/ops_llvm.py

@@ -23,6 +23,7 @@ class LLVMProgram:
     self.name, self.lib = name, lib
     device.engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib))
     self.fxn = device.engine.get_function_address(name)
+    assert self.fxn != 0, "LLVM failed to get function address"
 
   def __call__(self, *bufs, vals:Tuple[int, ...]=(), wait=False):
     if not hasattr(self, 'cfunc'):
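The added assert turns a silent JIT failure (get_function_address returning 0 for a missing symbol) into an immediate error at load time instead of a crash on the first call. For context, the cfunc built lazily in __call__ wraps that raw address with ctypes; a hedged sketch, with the helper name illustrative rather than tinygrad's exact code:

    import ctypes

    def make_cfunc(addr: int, nbufs: int, nvals: int):
      # kernels return void, taking raw buffer pointers then int32 variable values
      return ctypes.CFUNCTYPE(None, *([ctypes.c_void_p]*nbufs + [ctypes.c_int32]*nvals))(addr)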