|
|
|
|
@@ -6,13 +6,13 @@ from tinygrad.helpers import dtypes
|
|
|
|
|
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
|
|
|
|
|
|
|
|
|
|
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
|
|
|
|
|
def int_const(x):
  """Wrap `x` in an LLVM IR constant of the 64-bit integer type."""
  i64 = ir.IntType(64)
  return ir.Constant(i64, x)
|
|
|
|
|
def sym_render(a, ops=None, ctx=None):
  """Render `a` to an LLVM IR value.

  A Python int becomes a 64-bit integer constant; anything else is asked to
  render itself through `a.render(ops, ctx)`.
  """
  if not isinstance(a, int):
    return a.render(ops, ctx)
  return ir.Constant(ir.IntType(64), a)
|
|
|
|
|
# Dispatch table: symbolic node class -> renderer called as f(node, ops, ctx).
# `ctx` is the object the IR-emitting methods (mul, sdiv, srem, icmp_signed,
# add, and_) are invoked on, and `ops` is forwarded unchanged to nested renders.
# The right-hand operand `b` always goes through sym_render, so it may be either
# a plain Python int (emitted as an i64 constant) or an object with render().
# Each node class is listed exactly once; a repeated key in a dict literal would
# silently override the earlier entry.
render_llvm = {
  NumNode: lambda self,ops,ctx: sym_render(self.b,ops,ctx),
  MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
  DivNode: lambda self,ops,ctx: ctx.sdiv(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
  ModNode: lambda self,ops,ctx: ctx.srem(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
  LtNode: lambda self,ops,ctx: ctx.icmp_signed("<", self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
  # Sum/And fold their children left to right, seeded with the first child's rendering.
  SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.add(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
  AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.and_(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx))
}
|
|
|
|
|
@@ -33,7 +33,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
|
|
|
|
|
TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), y, z, flags=('fast',)),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# Maps each supported tinygrad dtype to the LLVM IR type used to represent it.
# bfloat16 is carried as a 16-bit integer and bool as a 1-bit integer.
# dtypes._arg_int32 is a 32-bit integer; uops_to_llvm_ir declares arguments of
# this dtype by value, while every other dtype is declared as a pointer.
dtype_to_llvm_dtype = {
  dtypes.float64: ir.DoubleType(),
  dtypes.float16: ir.HalfType(),
  dtypes.bfloat16: ir.IntType(16),
  dtypes.float32: ir.FloatType(),
  dtypes.int8: ir.IntType(8),
  dtypes.uint8: ir.IntType(8),
  dtypes.bool: ir.IntType(1),
  dtypes.int64: ir.IntType(64),
  dtypes.int32: ir.IntType(32),
  dtypes._arg_int32: ir.IntType(32),
}
|
|
|
|
|
|
|
|
|
|
def cast(bb, val, input_type, output_type):
|
|
|
|
|
if input_type == output_type: return val
|
|
|
|
|
@@ -75,9 +75,10 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
|
|
|
|
|
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
|
|
|
|
|
|
|
|
|
|
# create llvm function
|
|
|
|
|
func_dtypes = [dtype_to_llvm_dtype[dtype] for dtype in buf_to_dtype.values()]
|
|
|
|
|
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name=function_name)
|
|
|
|
|
for a in func.args: a.add_attribute("noalias")
|
|
|
|
|
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values()]
|
|
|
|
|
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if dt!=dtypes._arg_int32 else x for x,dt in func_dtypes]), name=function_name)
|
|
|
|
|
for a in func.args:
|
|
|
|
|
if a.type.is_pointer: a.add_attribute("noalias")
|
|
|
|
|
|
|
|
|
|
# force llvmlite to allow us to add function attribute then add the attribute
|
|
|
|
|
func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
|
|
|
|
|
@@ -90,6 +91,9 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
|
|
|
|
|
lvars: Dict[Optional[Token], Any] = {} # this Any is an llvm type
|
|
|
|
|
render_llvm[Variable] = lambda self,ops,ctx: lvars[self.expr]
|
|
|
|
|
|
|
|
|
|
for bufname,dtype in buf_to_dtype.items():
|
|
|
|
|
if dtype == dtypes._arg_int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(64))
|
|
|
|
|
|
|
|
|
|
for uop,newvar,vin,args in uops:
|
|
|
|
|
if uop == UOps.LOOP:
|
|
|
|
|
for var in args[0]:
|
|
|
|
|
@@ -106,16 +110,16 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
|
|
|
|
|
loop_blocks.append((bb[-1], phis))
|
|
|
|
|
|
|
|
|
|
lvars[var.expr] = bb[-1].phi(ir.IntType(64), name=var.expr)
|
|
|
|
|
lvars[var.expr].add_incoming(int_const(var.min), bb[-2]._block)
|
|
|
|
|
lvars[var.expr].add_incoming(sym_render(var.min), bb[-2]._block)
|
|
|
|
|
if uop == UOps.ENDLOOP:
|
|
|
|
|
for var in args[0][::-1]:
|
|
|
|
|
if isinstance(var, NumNode): continue
|
|
|
|
|
block, phis = loop_blocks.pop()
|
|
|
|
|
idx_p1 = bb[-1].add(lvars[var.expr], int_const(1))
|
|
|
|
|
idx_p1 = bb[-1].add(lvars[var.expr], sym_render(1))
|
|
|
|
|
lvars[var.expr].add_incoming(idx_p1, bb[-1]._block)
|
|
|
|
|
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block)
|
|
|
|
|
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{var.expr}")))
|
|
|
|
|
bb[-2].cbranch(bb[-2].icmp_unsigned("==", idx_p1, int_const(var.max+1)), bb[-1]._block, block._block)
|
|
|
|
|
bb[-2].cbranch(bb[-2].icmp_unsigned(">", idx_p1, sym_render(var.max, render_llvm, bb[-2])), bb[-1]._block, block._block)
|
|
|
|
|
if uop == UOps.LOAD:
|
|
|
|
|
assert newvar is not None and isinstance(args, (MemOp, ConstOp))
|
|
|
|
|
valid = args.valid.render(render_llvm, bb[-1])
|
|
|
|
|
@@ -130,7 +134,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
|
|
|
|
|
else:
|
|
|
|
|
idx = args.idx.render(render_llvm, bb[-1])
|
|
|
|
|
if args.valid.min == 0:
|
|
|
|
|
aug_idx = bb[-1].select(valid, idx, int_const(0))
|
|
|
|
|
aug_idx = bb[-1].select(valid, idx, sym_render(0))
|
|
|
|
|
val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [aug_idx], inbounds=True)), ir.Constant(dtype_to_llvm_dtype[args.memory_dtype], args.invalid_value))
|
|
|
|
|
else:
|
|
|
|
|
val = bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True))
|
|
|
|
|
|