From 9b314c6342e0d018a1119eaa557d9c4e90ce3f2c Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 20 Mar 2023 08:19:48 -0700 Subject: [PATCH] factor uops transformers into functions --- tinygrad/codegen/cstyle.py | 250 +++++++++++++++++++------------------ tinygrad/codegen/llvmir.py | 164 ++++++++++++------------ 2 files changed, 214 insertions(+), 200 deletions(-) diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py index 4fa2d35d77..624da8ea2c 100644 --- a/tinygrad/codegen/cstyle.py +++ b/tinygrad/codegen/cstyle.py @@ -1,10 +1,11 @@ -from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set +from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Any import math, collections from tinygrad.codegen.linearizer import Linearizer, UOps from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps from tinygrad.helpers import getenv, all_same, partition, ImageDType, DEBUG, dtypes from tinygrad.runtime.lib import RawConst from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode +from tinygrad.lazy import LazyBuffer # div is different in cl than python render_cl = render_python.copy() @@ -46,6 +47,132 @@ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=F if DEBUG >= 5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy) return idx, idy +code_for_op: Final[Dict[Op, Callable]] = { + UnaryOps.EXP: lambda x: f"native_exp({x})" if NATIVE_EXPLOG else f"exp({x})", + UnaryOps.LOG: lambda x: f"native_log({x})" if NATIVE_EXPLOG else f"log({x})", + BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})", + BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})", + BinaryOps.POW: lambda a,b: f"pow({a},{b})", BinaryOps.MAX: lambda a,b: f"max({a},{b})", + BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({b}*{c})+{a})" +} + +def uops_to_cstyle(uops:List[Tuple[UOps, Optional[str], Any]], bufs:List[LazyBuffer], bufnames:List[str], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]: + def group_float4(grp:List[str]) -> str: + if all(g.endswith(e) for g,e in zip(grp, [".x", ".y", ".z", ".w"])) and all_same([g.split(".")[0] for g in grp]): return grp[0].split(".")[0] + else: return f"{lang.float4}({','.join(g for g in grp)})" + + prekernel: Set[str] = set() + kernel = [] + global_size = [] + local_size = [] + pend_close = None + + depth = 0 + def kk(s): kernel.append(" "*depth+s) + + for uop,newvar,args in uops: + if uop == UOps.LOOP: + root = None + for i,var in enumerate(args[0]): + if isinstance(var, NumNode): + if args[1] == "global" and lang.gid: global_size.append(1) + if args[1] == "local" and lang.lid: local_size.append(1) + # one number, not an index + kk("{") + else: + if args[1] == "global" and lang.gid: + if len(args[0]) >= 4 and len(args[0])-i > 2: + # sometimes, there's more dimensions. 
compact all the dimensions into the last CL dimension + # TODO: these compactions should be searchable (they sort of are with reshapes and permutes) + if i == 0: + kk(f"{{ int {var.expr} = {lang.gid[-1]}; /* {var.max+1} */") + root = var.expr + global_size.append(var.max+1) + else: + kk(f"{{ int {var.expr} = {root} % {var.max+1}; {root} /= {var.max+1};") + global_size[-1] *= var.max+1 + else: + kk(f"{{ int {var.expr} = {lang.gid[len(args[0])-1-i]}; /* {var.max+1} */") + global_size.append(var.max+1) + elif args[1] == "local" and lang.lid: + assert len(args[0]) <= len(lang.lid) + kk(f"{{ int {var.expr} = {lang.lid[len(args[0])-1-i]}; /* {var.max+1} */") + local_size.append(var.max+1) + else: + kk(f"for (int {var.expr} = {var.min}; {var.expr} <= {var.max}; ++{var.expr}) {{") + depth += 1 + if uop == UOps.ENDLOOP: + if args[1] == "local" and len(lang.lid): + # TODO: this is a bit of a hack. the local loop isn't real on the GPU + kk(lang.barrier) + kk(f"if ({Variable.sum(args[0]).render(render_cl)} == 0) {{") + pend_close = "}"*(len(args[0])+1) + f" /* {args[1]} */" + else: + if args[1] == "global" and pend_close: + depth -= 1 + kk(pend_close) + pend_close = None + depth -= 1 + kk("}"*len(args[0]) + f" /* {args[1]} */") + if uop == UOps.CONST: + if args[0] == -math.inf: + kk(f"float {newvar} = -INFINITY;") + else: + kk(f"float {newvar} = {args[0]}f;") + if uop == UOps.ALU: + if newvar is None: + kk(f"{args[2]} = {code_for_op[args[0]](*args[1])};") + else: + kk(f"float {newvar} = {code_for_op[args[0]](*args[1])};") + # TODO: refactor the next 14 lines + if uop == UOps.LOAD: + # TODO: merge with CONST? + if bufs[args[0]] is not None and isinstance(bufs[args[0]].realized, RawConst): + # nan? inf? + val = f"{bufs[args[0]].realized._buf}f" + else: + if lang.uses_vload and bufs[args[0]] is not None and bufs[args[0]].dtype == dtypes.float16: + val = f"vload_half({args[1].render(render_cl)}, {bufnames[args[0]]})" + else: + val = f"{bufnames[args[0]]}[{args[1].render(render_cl)}]" + # NOTE: if min and max are both 0, it should be a CONST in the Linearizer + if args[2].min == 1: kk(f"float {newvar} = {val};") + else: kk(f"float {newvar} = ({args[2].render(render_cl)}) ? ({val}) : 0.0f;") + if uop == UOps.LOAD4: + if bufs[args[0]] is not None and isinstance(bufs[args[0]].dtype, ImageDType): + prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n") + idx, idy = to_image_idx(bufs[args[0]].dtype.shape, args[1], args[2]) + val = f"read_imagef({bufnames[args[0]]}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}))" + else: + val = f"(({lang.buffer_prefix if bufs[args[0]] is not None else lang.smem_prefix}float4*){bufnames[args[0]]})[{(args[1]//4).render(render_cl)}]" + # NOTE: if min and max are both 0, it should be a CONST in the Linearizer + if args[2].min == 1: kk(f"float4 {newvar} = {val};") + else: kk(f"float4 {newvar} = ({args[2].render(render_cl)}) ? 
({val}) : {group_float4(['0.0f']*4)};") + if uop == UOps.STORE: + assert args[2].min == 1, "store must be valid" + if lang.uses_vload and bufs[args[0]] is not None and bufs[args[0]].dtype == dtypes.float16: + kk(f"vstore_half({args[3]}, {args[1].render(render_cl)}, {bufnames[args[0]]});") + else: + kk(f"{bufnames[args[0]]}[{args[1].render(render_cl)}] = {args[3]};") + if uop == UOps.STORE4: + assert args[2].min == 1, "store must be valid" + if bufs[args[0]] is not None and isinstance(bufs[args[0]].dtype, ImageDType): + idx, idy = to_image_idx(bufs[args[0]].dtype.shape, args[1], args[2]) + kk(f"write_imagef({bufnames[args[0]]}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {group_float4(args[3])});") + else: + kk(f"(({lang.buffer_prefix if bufs[args[0]] is not None else lang.smem_prefix}float4*){bufnames[args[0]]})[{(args[1]//4).render(render_cl)}] = {group_float4(args[3])};") + if uop == UOps.DEFINE_LOCAL: + kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];") + + buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else + ("const " if i > 0 else "")+lang.buffer_prefix+x.dtype.name+"*"+lang.buffer_suffix) for i,x in enumerate(bufs) + if x is not None and not isinstance(x.realized, RawConst)] + prg = ''.join([f"{lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] + + [', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + lang.extra_args)] + + [") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"]) + + return prg, global_size, local_size + class CStyleCodegen(Linearizer): lang: ClassVar[CStyleLanguage] = CStyleLanguage() supports_constant_folding: bool = True @@ -55,19 +182,6 @@ class CStyleCodegen(Linearizer): kernel_cnt: Final[DefaultDict[str, int]] = collections.defaultdict(int) kernel_name_cache: Final[Dict[str, str]] = {} - code_for_op: Final[Dict[Op, Callable]] = { - UnaryOps.EXP: lambda x: f"native_exp({x})" if NATIVE_EXPLOG else f"exp({x})", - UnaryOps.LOG: lambda x: f"native_log({x})" if NATIVE_EXPLOG else f"log({x})", - BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})", - BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})", - BinaryOps.POW: lambda a,b: f"pow({a},{b})", BinaryOps.MAX: lambda a,b: f"max({a},{b})", - BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({b}*{c})+{a})" - } - - def group_float4(self, grp:List[str]) -> str: - if all(g.endswith(e) for g,e in zip(grp, [".x", ".y", ".z", ".w"])) and all_same([g.split(".")[0] for g in grp]): return grp[0].split(".")[0] - else: return f"{self.lang.float4}({','.join(g for g in grp)})" - def codegen(self): self.process() @@ -86,113 +200,7 @@ class CStyleCodegen(Linearizer): self.hand_coded_optimizations() self.linearize() - prekernel: Set[str] = set() - kernel = [] - global_size = [] - local_size = [] - pend_close = None - - depth = 0 - def kk(s): kernel.append(" "*depth+s) - - for uop,newvar,args in self.uops: - if uop == UOps.LOOP: - root = None - for i,var in enumerate(args[0]): - if isinstance(var, NumNode): - if args[1] == "global" and self.lang.gid: global_size.append(1) - if args[1] == "local" and self.lang.lid: local_size.append(1) - # one number, not an index - kk("{") - else: - if args[1] == "global" and self.lang.gid: - if len(args[0]) >= 4 and len(args[0])-i > 2: - # sometimes, there's more dimensions. 
compact all the dimensions into the last CL dimension - # TODO: these compactions should be searchable (they sort of are with reshapes and permutes) - if i == 0: - kk(f"{{ int {var.expr} = {self.lang.gid[-1]}; /* {var.max+1} */") - root = var.expr - global_size.append(var.max+1) - else: - kk(f"{{ int {var.expr} = {root} % {var.max+1}; {root} /= {var.max+1};") - global_size[-1] *= var.max+1 - else: - kk(f"{{ int {var.expr} = {self.lang.gid[len(args[0])-1-i]}; /* {var.max+1} */") - global_size.append(var.max+1) - elif args[1] == "local" and self.lang.lid: - assert len(args[0]) <= len(self.lang.lid) - kk(f"{{ int {var.expr} = {self.lang.lid[len(args[0])-1-i]}; /* {var.max+1} */") - local_size.append(var.max+1) - else: - kk(f"for (int {var.expr} = {var.min}; {var.expr} <= {var.max}; ++{var.expr}) {{") - depth += 1 - if uop == UOps.ENDLOOP: - if args[1] == "local" and len(self.lang.lid): - # TODO: this is a bit of a hack. the local loop isn't real on the GPU - kk(self.lang.barrier) - kk(f"if ({Variable.sum(args[0]).render(render_cl)} == 0) {{") - pend_close = "}"*(len(args[0])+1) + f" /* {args[1]} */" - else: - if args[1] == "global" and pend_close: - depth -= 1 - kk(pend_close) - pend_close = None - depth -= 1 - kk("}"*len(args[0]) + f" /* {args[1]} */") - if uop == UOps.CONST: - if args[0] == -math.inf: - kk(f"float {newvar} = -INFINITY;") - else: - kk(f"float {newvar} = {args[0]}f;") - if uop == UOps.ALU: - if newvar is None: - kk(f"{args[2]} = {self.code_for_op[args[0]](*args[1])};") - else: - kk(f"float {newvar} = {self.code_for_op[args[0]](*args[1])};") - # TODO: refactor the next 14 lines - if uop == UOps.LOAD: - # TODO: merge with CONST? - if self.bufs[args[0]] is not None and isinstance(self.bufs[args[0]].realized, RawConst): - # nan? inf? - val = f"{self.bufs[args[0]].realized._buf}f" - else: - if self.lang.uses_vload and self.bufs[args[0]] is not None and self.bufs[args[0]].dtype == dtypes.float16: - val = f"vload_half({args[1].render(render_cl)}, {self.registers[args[0]].name})" - else: - val = f"{self.registers[args[0]].name}[{args[1].render(render_cl)}]" - # NOTE: if min and max are both 0, it should be a CONST in the Linearizer - if args[2].min == 1: kk(f"float {newvar} = {val};") - else: kk(f"float {newvar} = ({args[2].render(render_cl)}) ? ({val}) : 0.0f;") - if uop == UOps.LOAD4: - if self.bufs[args[0]] is not None and isinstance(self.bufs[args[0]].dtype, ImageDType): - prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n") - idx, idy = to_image_idx(self.bufs[args[0]].dtype.shape, args[1], args[2]) - val = f"read_imagef({self.registers[args[0]].name}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}))" - else: - val = f"(({self.lang.buffer_prefix if self.bufs[args[0]] is not None else self.lang.smem_prefix}float4*){self.registers[args[0]].name})[{(args[1]//4).render(render_cl)}]" - # NOTE: if min and max are both 0, it should be a CONST in the Linearizer - if args[2].min == 1: kk(f"float4 {newvar} = {val};") - else: kk(f"float4 {newvar} = ({args[2].render(render_cl)}) ? 
({val}) : {self.group_float4(['0.0f']*4)};") - if uop == UOps.STORE: - assert args[2].min == 1, "store must be valid" - if self.lang.uses_vload and self.bufs[args[0]] is not None and self.bufs[args[0]].dtype == dtypes.float16: - kk(f"vstore_half({args[3]}, {args[1].render(render_cl)}, {self.registers[args[0]].name});") - else: - kk(f"{self.registers[args[0]].name}[{args[1].render(render_cl)}] = {args[3]};") - if uop == UOps.STORE4: - assert args[2].min == 1, "store must be valid" - if self.bufs[args[0]] is not None and isinstance(self.bufs[args[0]].dtype, ImageDType): - idx, idy = to_image_idx(self.bufs[args[0]].dtype.shape, args[1], args[2]) - kk(f"write_imagef({self.registers[args[0]].name}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {self.group_float4(args[3])});") - else: - kk(f"(({self.lang.buffer_prefix if self.bufs[args[0]] is not None else self.lang.smem_prefix}float4*){self.registers[args[0]].name})[{(args[1]//4).render(render_cl)}] = {self.group_float4(args[3])};") - if uop == UOps.DEFINE_LOCAL: - kk(self.lang.smem_prefix + f"float {args[0]}[{args[1]}];") - - buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else self.lang.buffer_prefix+x.dtype.name+"*"+self.lang.buffer_suffix) for i,x in enumerate(self.bufs) if x is not None and not isinstance(x.realized, RawConst)] - prg = ''.join([f"{self.lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] + - [', '.join([f'{"const" if i > 0 else ""} {t} data{i}' for i,t in buftypes] + self.lang.extra_args)] + - [") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"]) + prg, global_size, local_size = uops_to_cstyle(self.uops, self.bufs, [x.name for x in self.registers], self.lang) # if we have local_sizes, we have to correct the global_size for i,s in enumerate(local_size): global_size[i] *= s diff --git a/tinygrad/codegen/llvmir.py b/tinygrad/codegen/llvmir.py index 512b96036d..942438f6a0 100644 --- a/tinygrad/codegen/llvmir.py +++ b/tinygrad/codegen/llvmir.py @@ -1,9 +1,10 @@ -from typing import Final, Dict, Callable, Any, List, Optional +from typing import Final, Dict, Callable, Any, List, Optional, Tuple import functools from llvmlite import ir # type: ignore from tinygrad.codegen.linearizer import Linearizer, UOps from tinygrad.helpers import dtypes from tinygrad.ops import Op, ASTRunner, UnaryOps, BinaryOps, FusedOps +from tinygrad.lazy import LazyBuffer from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, GeNode, LtNode, SumNode, AndNode def int_const(x): return ir.Constant(ir.IntType(64), x) @@ -18,86 +19,91 @@ render_llvm = { AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.and_(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)) } +code_for_op: Final[Dict[Op, Callable]] = { + UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)), + UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x], fastmath=('fast',)), + BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)), + BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)), + BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y, flags=('fast',)), + BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y, flags=('fast',)), + BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y], fastmath=('fast',)), + 
BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y, flags=('fast',)), ir.FloatType()), + BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=('fast',)), x, y, flags=('fast',)), + FusedOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(y,z, flags=('fast',)), x, flags=('fast',)), +} + +def uops_to_llvm_ir(uops:List[Tuple[UOps, Optional[str], Any]], bufs:List[LazyBuffer]) -> str: + # all llvm stuff goes into a module + module = ir.Module(name=__file__) + + # create llvm function + func_dtypes = [{dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType()}[buf.dtype] for buf in bufs] + func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name='exec') + + # force llvmlite to allow us to add function attribute then add the attribute + func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"'])) + func.attributes.add('"no-nans-fp-math"="true"') + + bb = [ir.IRBuilder(func.append_basic_block("entry"))] + loop_blocks = [] + reduce_phis: List = [] + # TODO: newvar probably shouldn't be optional + lvars: Dict[Optional[str], Any] = {} # this Any is an llvm type + render_llvm[Variable] = lambda self,ops,ctx: lvars[self.expr] + + for uop,newvar,args in uops: + if uop == UOps.CONST: + lvars[newvar] = ir.Constant(ir.FloatType(), args[0]) + reduce_phis.append(newvar) + if uop == UOps.LOOP: + for var in args[0]: + if isinstance(var, NumNode): continue + bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{var.expr}"))) + bb[-2].branch(bb[-1]._block) + + phis = [] + for rp in reduce_phis: + incoming = lvars[rp] + lvars[rp] = bb[-1].phi(ir.FloatType()) + lvars[rp].add_incoming(incoming, bb[-2]._block) + phis.append((rp, lvars[rp])) + loop_blocks.append((bb[-1], phis)) + + lvars[var.expr] = bb[-1].phi(ir.IntType(64), name=var.expr) + lvars[var.expr].add_incoming(int_const(var.min), bb[-2]._block) + if uop == UOps.ENDLOOP: + for var in args[0][::-1]: + if isinstance(var, NumNode): continue + block, phis = loop_blocks.pop() + idx_p1 = bb[-1].add(lvars[var.expr], int_const(1)) + lvars[var.expr].add_incoming(idx_p1, bb[-1]._block) + for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block) + bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{var.expr}"))) + bb[-2].cbranch(bb[-2].icmp_unsigned("==", idx_p1, int_const(var.max+1)), bb[-1]._block, block._block) + if uop == UOps.LOAD: + idx, valid = args[1].render(render_llvm, bb[-1]), args[2].render(render_llvm, bb[-1]) + if args[2].min == 0: + aug_idx = bb[-1].select(valid, idx, int_const(0)) + val= bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[args[0]], [aug_idx], inbounds=True)), ir.Constant(func_dtypes[args[0]], 0)) + else: + val = bb[-1].load(bb[-1].gep(func.args[args[0]], [idx], inbounds=True)) + if func_dtypes[args[0]] != ir.FloatType(): val = bb[-1].fpext(val, ir.FloatType()) + lvars[newvar] = val + if uop == UOps.STORE: + assert args[2].min == 1, "store must be valid" + idx = args[1].render(render_llvm, bb[-1]) + element = lvars[args[3]] + if func_dtypes[0] != ir.FloatType(): element = bb[-1].fptrunc(element, func_dtypes[0]) + bb[-1].store(element, bb[-1].gep(func.args[args[0]], [idx], inbounds=True)) + if uop == UOps.ALU: + lvars[newvar if newvar is not None else args[2]] = code_for_op[args[0]](bb[-1], *[lvars[x] for x in args[1]]) + + bb[-1].ret_void() + return str(module) + class LLVMIRCodegen(Linearizer): - code_for_op: Final[Dict[Op, Callable]] = { - UnaryOps.EXP: lambda 
builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)), - UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x], fastmath=('fast',)), - BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)), - BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)), - BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y, flags=('fast',)), - BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y, flags=('fast',)), - BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y], fastmath=('fast',)), - BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y, flags=('fast',)), ir.FloatType()), - BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=('fast',)), x, y, flags=('fast',)), - FusedOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(y,z, flags=('fast',)), x, flags=('fast',)), - } def codegen(self): self.process() # no optimize, this doesn't support local self.linearize() - - # create llvm function - module = ir.Module(name=__file__) - func_dtypes = [{dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType()}[buf.dtype] for buf in self.bufs] - func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name='exec') - - # force llvmlite to allow us to add function attribute then add the attribute - func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"'])) - func.attributes.add('"no-nans-fp-math"="true"') - - bb = [ir.IRBuilder(func.append_basic_block("entry"))] - loop_blocks = [] - reduce_phis: List = [] - # TODO: newvar probably shouldn't be optional - lvars: Dict[Optional[str], Any] = {} # this Any is an llvm type - render_llvm[Variable] = lambda self,ops,ctx: lvars[self.expr] - - for uop,newvar,args in self.uops: - if uop == UOps.CONST: - lvars[newvar] = ir.Constant(ir.FloatType(), args[0]) - reduce_phis.append(newvar) - if uop == UOps.LOOP: - for var in args[0]: - if isinstance(var, NumNode): continue - bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{var.expr}"))) - bb[-2].branch(bb[-1]._block) - - phis = [] - for rp in reduce_phis: - incoming = lvars[rp] - lvars[rp] = bb[-1].phi(ir.FloatType()) - lvars[rp].add_incoming(incoming, bb[-2]._block) - phis.append((rp, lvars[rp])) - loop_blocks.append((bb[-1], phis)) - - lvars[var.expr] = bb[-1].phi(ir.IntType(64), name=var.expr) - lvars[var.expr].add_incoming(int_const(var.min), bb[-2]._block) - if uop == UOps.ENDLOOP: - for var in args[0][::-1]: - if isinstance(var, NumNode): continue - block, phis = loop_blocks.pop() - idx_p1 = bb[-1].add(lvars[var.expr], int_const(1)) - lvars[var.expr].add_incoming(idx_p1, bb[-1]._block) - for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block) - bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{var.expr}"))) - bb[-2].cbranch(bb[-2].icmp_unsigned("==", idx_p1, int_const(var.max+1)), bb[-1]._block, block._block) - if uop == UOps.LOAD: - idx, valid = args[1].render(render_llvm, bb[-1]), args[2].render(render_llvm, bb[-1]) - if args[2].min == 0: - aug_idx = bb[-1].select(valid, idx, int_const(0)) - val= bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[args[0]], [aug_idx], inbounds=True)), ir.Constant(func_dtypes[args[0]], 0)) - else: - val = bb[-1].load(bb[-1].gep(func.args[args[0]], [idx], inbounds=True)) - if 
func_dtypes[args[0]] != ir.FloatType(): val = bb[-1].fpext(val, ir.FloatType()) - lvars[newvar] = val - if uop == UOps.STORE: - assert args[2].min == 1, "store must be valid" - idx = args[1].render(render_llvm, bb[-1]) - element = lvars[args[3]] - if func_dtypes[0] != ir.FloatType(): element = bb[-1].fptrunc(element, func_dtypes[0]) - bb[-1].store(element, bb[-1].gep(func.args[args[0]], [idx], inbounds=True)) - if uop == UOps.ALU: - lvars[newvar if newvar is not None else args[2]] = self.code_for_op[args[0]](bb[-1], *[lvars[x] for x in args[1]]) - - bb[-1].ret_void() - return ASTRunner('exec', str(module), op_estimate=self.info.flops, mem_estimate=self.mem_estimate) + return ASTRunner('exec', uops_to_llvm_ir(self.uops, self.bufs), op_estimate=self.info.flops, mem_estimate=self.mem_estimate)
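
For reference, a minimal usage sketch of the newly factored-out renderers, mirroring the call sites in `CStyleCodegen.codegen` and `LLVMIRCodegen.codegen` above. This is illustrative only, not part of the patch: `render_cstyle`, `render_llvm`, and the `lin` argument are hypothetical names for an already-linearized kernel object, and importing `CStyleLanguage` directly from `tinygrad.codegen.cstyle` is assumed from the class definitions in this diff.

# sketch, assuming `lin` is a Linearizer subclass instance that has already run
# process()/linearize(), so lin.uops, lin.bufs, and lin.registers are populated
from tinygrad.codegen.cstyle import uops_to_cstyle, CStyleLanguage
from tinygrad.codegen.llvmir import uops_to_llvm_ir

def render_cstyle(lin):
    # bufnames maps buffer index -> identifier used in the generated source;
    # the patch passes the Linearizer's register names for this
    bufnames = [reg.name for reg in lin.registers]
    prg, global_size, local_size = uops_to_cstyle(lin.uops, lin.bufs, bufnames, CStyleLanguage())
    return prg, global_size, local_size

def render_llvm(lin):
    # the LLVM path needs only the uops and buffers; it returns the IR module as text
    return uops_to_llvm_ir(lin.uops, lin.bufs)

Because the transformers are now plain functions of (uops, buffers, names, language) rather than methods on the codegen classes, they can in principle be exercised or tested without constructing a full codegen subclass.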