Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 23:48:01 -05:00)
replace llvm with new llvm (#7616)
* replace llvm with new llvm
* fix test_linearizer
* minor fixups
* fix alloca
* don't use alloca
* fix DEFINE_ACC
* lines
* comments and lines
* a little tighter
@@ -1219,7 +1219,7 @@ class TestLinearizer(unittest.TestCase):
     assert len(sched) == 1
     lin = Kernel(sched[0].ast)
-    assert sum(u.op is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg
+    assert sum(u.op in {UnaryOps.RECIP, BinaryOps.FDIV} for u in lin.linearize().uops) == max_ops, msg
     a = Tensor.empty((4,4))
     b = Tensor.empty((4,4))
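The assertion widens from RECIP alone to {RECIP, FDIV} because the new LLVM renderer rewrites RECIP into an FDIV (see the extra_matcher later in this diff), so the linearized kernel may contain either op. A minimal sketch of the equivalence the test relies on, with plain floats standing in for real UOps:

# RECIP x and FDIV(1, x) compute the same value, so counting either op
# preserves the op-count invariant; floats stand in for real UOps here
def recip(x: float) -> float: return 1.0 / x         # UnaryOps.RECIP
def fdiv(x: float, y: float) -> float: return x / y  # BinaryOps.FDIV
assert recip(4.0) == fdiv(1.0, 4.0)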
@@ -140,7 +140,7 @@ class Ops(FastEnum):
   # BinaryOps
   ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
-  SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto() # noqa: E702
+  SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto() # noqa: E702
   # TernaryOps
   WHERE = auto(); MULACC = auto() # noqa: E702
@@ -168,7 +168,8 @@ class Ops(FastEnum):
 class GroupOp:
   Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
-  Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB}
+  Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY,
+            Ops.SUB, Ops.FDIV}
   Ternary = {Ops.WHERE, Ops.MULACC}
   ALU = set.union(Unary, Binary, Ternary)
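Because the renderer below dispatches every op in GroupOp.Binary through one generic pattern, adding FDIV to this set is all that is needed for it to render; a quick membership sketch using the names defined above:

assert Ops.FDIV in GroupOp.Binary
assert Ops.FDIV in GroupOp.ALU  # ALU is the union of Unary, Binary, and Ternary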
@@ -1,50 +1,77 @@
-from typing import Dict, Callable, List, Optional
-from llvmlite import ir
-from tinygrad.dtype import DType, PtrDType, dtypes
-from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, GroupOp
+from typing import List, Dict, cast
+import math, struct
 from tinygrad.renderer import Renderer
+from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
+from tinygrad.dtype import dtypes, DType, PtrDType, truncate

-MFLAGS = ('nsz', 'arcp', 'contract', 'afn') # All from fast math, but nnan and ninf and reassoc
+def ldt(dt:DType):
+  if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
+  return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
+          dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
+          dtypes.float16: "half", dtypes.float32: "float", dtypes.float64: "double", dtypes.bool: "i1", dtypes.void: "void"}[dt]
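# Illustration (not part of the diff): ldt spells dtypes the way LLVM IR wants.
# These values follow directly from the dict above; the PtrDType branch recurses
# on .base and appends "*", so a pointer to float32 renders as "float*".
assert ldt(dtypes.float32) == "float"
assert ldt(dtypes.uint8) == "i8"  # LLVM integer types are signless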
-def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)
-
-dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16),
-  dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64),
-  dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.float64:ir.DoubleType() }
-
-def cast(bb, val, input_type, output_type, bitcast=False):
-  if input_type == output_type: return val
-  llvm_type = dtype_to_llvm_dtype[output_type]
-  if bitcast: return bb[-1].bitcast(val, llvm_type)
-
-  if input_type == dtypes.bfloat16:
-    val = bb[-1].bitcast(bb[-1].shl(bb[-1].sext(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)), ir.FloatType())
-    input_type = dtypes.float32
-  if output_type == dtypes.bfloat16:
-    val = cast(bb, val, input_type, dtypes.float32)
-    return bb[-1].trunc(bb[-1].lshr(bb[-1].bitcast(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)), ir.IntType(16))
+def lconst(x, dtype:DType):
+  if dtype in dtypes.floats:
+    if math.isinf(x) or math.isnan(x): return "0x%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
+    return truncate[dtype](x)
+  return int(x)

+def lcast(input_type:DType, output_type:DType):
   if dtypes.is_float(input_type):
-    if dtypes.is_float(output_type):
-      return bb[-1].fpext(val, llvm_type) if output_type.itemsize > input_type.itemsize else bb[-1].fptrunc(val, llvm_type)
-    if dtypes.is_int(output_type): return bb[-1].fptoui(val, llvm_type) if dtypes.is_unsigned(output_type) else bb[-1].fptosi(val, llvm_type)
-    if output_type == dtypes.bool: return bb[-1].fcmp_unordered('!=', cast(bb, val, input_type, dtypes.float32), ir.Constant(ir.FloatType(), 0))
+    if dtypes.is_float(output_type): return 'fpext' if output_type.itemsize > input_type.itemsize else 'fptrunc'
+    if dtypes.is_int(output_type): return 'fptoui' if dtypes.is_unsigned(output_type) else 'fptosi'
   if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
-    if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].uitofp(val, ir.FloatType()), ir.HalfType())
-    if dtypes.is_float(output_type): return bb[-1].uitofp(val, dtype_to_llvm_dtype[output_type])
-    if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].zext(val, llvm_type)
-    if output_type == dtypes.bool: return bb[-1].icmp_unsigned('!=', val, ir.Constant(val.type, 0))
+    if dtypes.is_float(output_type): return 'uitofp'
+    if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'zext'
   if dtypes.is_int(input_type):
-    if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].sitofp(val, ir.FloatType()), ir.HalfType())
-    if dtypes.is_float(output_type): return bb[-1].sitofp(val, llvm_type)
-    if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].sext(val, llvm_type)
-    if output_type == dtypes.bool: return bb[-1].icmp_signed('!=', val, ir.Constant(val.type, 0))
+    if dtypes.is_float(output_type): return 'sitofp'
+    if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'sext'
   raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
-
-def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], args)
+
+# llvm ops, lop[<dtype>][<op>]
+unsigned_lop = { Ops.ADD: "add", Ops.MUL: "mul", Ops.IDIV: "udiv", Ops.MOD: "urem",
+                 Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor", }
+signed_lop = {**unsigned_lop, Ops.CMPLT: "icmp slt", Ops.IDIV: "sdiv", Ops.MOD: "srem"}
+flags = " nsz arcp contract afn"
+float_lop = {Ops.ADD: "fadd"+flags, Ops.MUL: "fmul"+flags, Ops.CMPLT: f"fcmp{flags} ult", Ops.CMPNE: f"fcmp{flags} une", Ops.FDIV: "fdiv"+flags}
+lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop for x in dtypes.sints}, **{x:float_lop for x in dtypes.floats}}
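# Illustration (not part of the diff): how the lop table resolves an op to an
# LLVM instruction mnemonic. These follow directly from the dicts above; note
# the fast-math flags baked into every float op.
assert lop[dtypes.int32][Ops.ADD] == "add"
assert lop[dtypes.uint32][Ops.CMPLT] == "icmp ult" and lop[dtypes.int32][Ops.CMPLT] == "icmp slt"
assert lop[dtypes.float32][Ops.ADD] == "fadd nsz arcp contract afn"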
+llvm_rewrite = PatternMatcher([
+  # memory load/store
+  (UPat(Ops.INDEX, name="x"), lambda ctx,x:
+    f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {ctx[x.src[1]]}"),
+  (UPat(Ops.LOAD, src=(UPat.var('idx'), UPat.var('alt'), UPat.var('mask')), name="x"), lambda ctx,x,idx,alt,mask:
+    f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
+    f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
+    f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
+    f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
+    f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
+  (UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
+  (UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
+
+  # unary/binary/ternary ops
+  (UPat(Ops.SQRT, name="x"), lambda ctx,x:
+    f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
+  (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
+  (UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
+  (UPat(GroupOp.Binary, name="x"), lambda ctx,x: f" {ctx[x]} = {lop[x.src[0].dtype][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
+  (UPat(Ops.WHERE, name="x"), lambda ctx,x:
+    f" {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[2].dtype)} {ctx[x.src[2]]}"),
+
+  # range
+  (UPat(Ops.RANGE, name="x"), lambda ctx,x:
+    f" br label %loop_entry_{x.arg[0]}\nloop_entry_{x.arg[0]}:\n"
+    f" br label %loop_body_{x.arg[0]}\nloop_body_{x.arg[0]}:\n"
+    f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg[0]}], [{ctx[x]}phi, %loop_latch_{x.arg[0]}]"),
+  (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
+    f" br label %loop_latch_{x.src[0].arg[0]}\nloop_latch_{x.src[0].arg[0]}:\n"
+    f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n"
+    f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg[0]}, label %loop_exit_{x.src[0].arg[0]}\nloop_exit_{x.src[0].arg[0]}:"),
+
+  # if
+  (UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
+  (UPat(Ops.ENDIF, name="x"), lambda ctx,x: f" br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"),
+])
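# Illustration (not part of the diff): what the GroupOp.Binary pattern emits.
# The register names are stand-ins for what ctx would hold for each UOp.
out, lhs, rhs = "%v2", "%v0", "%v1"
line = f" {out} = {lop[dtypes.float32][Ops.ADD]} {ldt(dtypes.float32)} {lhs}, {rhs}"
assert line == " %v2 = fadd nsz arcp contract afn float %v0, %v1"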

 class LLVMRenderer(Renderer):
   device = "LLVM"
@@ -52,101 +79,64 @@ class LLVMRenderer(Renderer):
   has_local = False
   has_shared = False
   global_max = None
-  code_for_op: Dict[Ops, Callable] = {
-    UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
-    UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
-    BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
-    BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501
-    BinaryOps.IDIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y),
-    BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
-    BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS), # noqa: E501
-    BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
-    BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501
-    BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y), BinaryOps.AND: lambda builder, x, y, dtype: builder.and_(x, y), BinaryOps.OR: lambda builder, x, y, dtype: builder.or_(x, y), # noqa: E501
-    BinaryOps.SHL: lambda builder, x, y, dtype: builder.shl(x, y), BinaryOps.SHR: lambda builder, x, y, dtype: builder.lshr(x, y) if dtypes.is_unsigned(dtype) else builder.ashr(x, y), # noqa: E501
-    TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)}
-  def render(self, name:str, uops:List[UOp]) -> str:
-    # all llvm stuff goes into a module
-    module = ir.Module(name=__file__)
+  extra_matcher = PatternMatcher([
+    # rewrite RECIP with FDIV
+    (UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))),
+    # rewrite cast to bool to CMPNE 0
+    (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
+    # *** also in cstyle ***
+    # gate any stores that aren't gated with ifs
+    (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
+      lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
+    # rewrite MAX to CMPLT + WHERE
+    (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+  ])
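# Illustration (not part of the diff): plain-Python semantics of two of the
# rewrites above, which is why the renderer needs no dedicated RECIP or MAX handling.
def recip_as_fdiv(x): return 1 / x                # RECIP x -> FDIV(1, x)
def max_as_where(a, b): return b if a < b else a  # MAX -> CMPLT + WHERE
assert max_as_where(2, 5) == 5 and max_as_where(5, 2) == 5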

-    # extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
-    buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}}
-    buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
+  def render(self, name: str, uops: List[UOp]) -> str:
+    r: Dict[UOp, str] = {}
+    args: List[str] = []
+    kernel: List[str] = []
+    end_lines: Dict[str, None] = {}
+    vc = -1

-    # create llvm function
-    func_dtypes = [(dtype_to_llvm_dtype[dtype.base if isinstance(dtype, PtrDType) else dtype],dtype) for dtype in buf_to_dtype.values()]
-    func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name)
-    for a in func.args:
-      if a.type.is_pointer: a.add_attribute("noalias")
-
-    bb = [ir.IRBuilder(func.append_basic_block("entry"))]
-    loop_blocks: List = []
-    reduce_phis: List = []
-    lvars: Dict[Optional[UOp], ir.Instruction] = {}
-
-    for bufname,dtype in buf_to_dtype.items():
-      if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
+    # prealloc all assigns
+    acc_to_assign: Dict[UOp, UOp] = {}
+    for u in uops:
+      if u.op is Ops.ASSIGN:
+        vc += 1
+        r[u] = r[u.src[1]] = f"%assign{vc}"
+        assert u.src[0] not in acc_to_assign, "can't assign to DEFINE_ACC twice"
+        acc_to_assign[u.src[0]] = u.src[1]

     for u in uops:
-      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
-      if uop is Ops.INDEX:
-        lvars[u] = bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True)
-      elif uop is Ops.STORE:
-        if len(src) > 2:
-          with bb[-1].if_then(lvars[src[2]]): bb[-1].store(lvars[src[1]], lvars[src[0]])
-        else:
-          bb[-1].store(lvars[src[1]], lvars[src[0]])
-      elif uop is Ops.ENDRANGE:
-        loop_entry_bb, phis = loop_blocks.pop()
-        idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
-        lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
-        for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
-        bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
-        bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
+      # hack for defining sqrt function (TODO: can we get a transcendental for this?)
+      if u.op is Ops.SQRT: end_lines[f'declare {ldt(u.dtype)} @llvm.sqrt.{ldt(u.dtype)}({ldt(u.dtype)} %".1")'] = None
+
+      if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
+        r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
+        args.append(f"{ldt(u.dtype)}{' noalias' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
+      elif u.op is Ops.ASSIGN: pass # assign is already handled by the first pass
+      elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]] # a define acc can be used and never be assigned to
+      elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
+      elif u.op is Ops.CAST and ldt(u.dtype) == ldt(u.src[0].dtype): r[u] = r[u.src[0]] # cast from signed to unsigned of the same size is a noop
       else:
-        if uop is Ops.RANGE:
-          bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
-          bb[-2].branch(bb[-1].block)
+        # if it's an assign target, it's already preallocated
+        if u not in r:
+          vc += 1
+          r[u] = f"%v{vc}"

-        phis = []
-        for rp in reduce_phis:
-          incoming = lvars[rp]
-          lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
-          lvars[rp].add_incoming(incoming, bb[-2].block)
-          phis.append((rp, lvars[rp]))
+        # do the rendering of the llvm ir code
+        if (l:=llvm_rewrite.rewrite(u, ctx=r)) is None: raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
+        kernel.append(cast(str, l))

-        lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
-        lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
-        loop_blocks.append((bb[-1].block, phis))
-      elif uop is Ops.DEFINE_ACC:
-        lvars[u] = const(src[0].arg, dtype)
-        reduce_phis.append(u)
-      elif uop is Ops.LOAD:
-        if len(src) > 1:
-          with bb[-1].if_else(lvars[src[2]]) as (then, otherwise):
-            with then:
-              val1 = bb[-1].load(lvars[src[0]])
-              then_blk = bb[-1].block
-            with otherwise: otherwise_blk = bb[-1].block
-          val = bb[-1].phi(val1.type)
-          val.add_incoming(val1, then_blk)
-          val.add_incoming(lvars[src[1]], otherwise_blk)
-        else:
-          val = bb[-1].load(lvars[src[0]])
-        lvars[u] = val
-      elif uop is Ops.ASSIGN:
-        lvars[u] = lvars[src[1]]
-        # ASSIGN UOps can link to other ASSIGN UOps, backtrace this to DEFINE_ACC
-        backward = src[0]
-        while backward.op is Ops.ASSIGN: backward = backward.src[0]
-        lvars[backward] = lvars[u]
-      elif uop in GroupOp.ALU:
-        lvars[u] = self.code_for_op[uop](bb[-1], *[lvars[x] for x in src], src[0].dtype if uop in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype)
-      elif uop in {Ops.CAST, Ops.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is Ops.BITCAST)
-      elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
-      elif uop is Ops.CONST: lvars[u] = const(args, dtype)
-      else: raise RuntimeError(f"failed to render {uop}")
+      # generate the phi nodes for the assigns
+      if u.op is Ops.RANGE:
+        for x in acc_to_assign:
+          if u in x.src: # if this range is relevant for this acc
+            vc += 1
+            kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg[0]}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg[0]}]")
+            r[x] = f"%acc{vc}"

-    bb[-1].ret_void()
-    return str(module)
+    # output the function
+    return f"define void @{name}({','.join(args)}) {{\n" + '\n'.join(kernel) + "\n ret void\n}\n"+'\n'.join(end_lines.keys())

@@ -23,6 +23,7 @@ class LLVMProgram:
     self.name, self.lib = name, lib
     device.engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib))
     self.fxn = device.engine.get_function_address(name)
+    assert self.fxn != 0, "LLVM failed to get function address"

   def __call__(self, *bufs, vals:Tuple[int, ...]=(), wait=False):
     if not hasattr(self, 'cfunc'):