Simple chonker (#431)

* chonker will make llvm fast

* work

* better speed tests, we will make them fast

* with the cache add is the same speed

* relu and neg are fast

* fix sum speed

* maximum maxnum?

* hack for gemm opt

* gemm very slow

* zeros like

* test_permute

* shapetracker returns self

* fix shapetracker factorization

* err, int strides

* permutes are faster now in tinygrad than pytorch

* support -1 in expand

* gemm unrolled

* improve final test case

* WIP GEMM

* why isn't GEMM fast?

* revert cache dim

* ffp contract works on clang, not llvm?

* ignore llvm ir

* this makes fma work at least, but no faster

* USE_4x4

* 63 GFLOPS

* 87 GFLOPS

* that wasn't matmul, 44 GFLOPS now

* 82 GFLOPS permuted

* this permute too

* a little speed for the convs

* 45 GFLOPS

* speed tests pass again

* clean up prints

* fix FMA WHAT A WASTE OF TIME

* colors

* moar fair

* GPU

* useless on chonker

* cleanups

* improve factorized shapetracker

* better threshold

* label conv

* work

* ops test pass again

* hot load the index

* run the last view, no need to create

* ZeroView needs a repr for the key to work

* fix segfault on out of bounds

* one more test

* start amx, and llvm.initialize_native_asmparser

* amx works

* nice AMX class

* nicer AMX class

* refactor get_idxs

* amx working

* is slower...

* useless flip

* cache

* SZ_X

* AMX_SZ_X/Y work alone

* Contiguous mlop

* test gemm packed

* PREPARE in packed

* use_amx factor

* prefetch isn't faster

* loop

* same 3ms

* 2.24 ms

* allow double on store in TG

* amx reduce is the same speed as non amx reduce

* include memory bandwidth

* clean up shapetracker

* flip returns stride

* prepare for upstream

* Update ops_llvm.py (#426)

* permutes are yellow and green now

* faster conv

* llvm cleanups

* Show optimised IR under debug 4 (#428)

* ASTKernel class

* Make tinygrad work with older python version (#427)

* Make tinygrad work with older python version

* Use partialmethod instead of partial

* simple chonker is chonking

* remove junk from test speed vs torch

* fix linker and types

* AMX is only here now

* add LLVM tests, it's a valid backend now

* oops, run llvm test

* contiguous_op

* fix loadops compare

* dedup reduceops

Co-authored-by: calledit <1573053+calledit@users.noreply.github.com>
George Hotz authored on 2022-11-10 23:17:09 -08:00 (committed by GitHub)
parent bff47e9dc1 · commit b8c94a67c9
9 changed files with 451 additions and 166 deletions


@@ -1,30 +1,28 @@
from __future__ import annotations
import os
import hashlib
import math
-from typing import Tuple, Union
+import time
+from typing import Tuple, Union, Dict, Any
from tinygrad.helpers import prod
from tinygrad.shapetracker import ShapeTracker, ZeroView
-from tinygrad.ops import LazyOp
+from tinygrad.ops import LazyOp, ASTKernel
import ctypes
import numpy as np
from ctypes import CFUNCTYPE
-from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, get_buffers, get_lazyops, ExplicitExecAST, get_lazyop_info
+from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, ExplicitExecAST
from llvmlite import ir # type: ignore
import llvmlite.binding as llvm # type: ignore
def int_const(x): return ir.Constant(ir.IntType(64), x)
-def idx_deref(builder, buf, ptr, eidx):
-if eidx[2] == 1 and eidx[3] is None:
-idx = eidx[1]
-else:
-idx = builder.add(builder.mul(eidx[1], int_const(eidx[2])), eidx[3], name="idx")
+# this is only used on the crappy path
+def idx_deref(builder, buf, ptr, idx):
if DEBUG >= 1:
print("viewcount:", len(buf.st.views), buf.st.expr(), ptr, "on", buf.shape)
# TODO: unify this with expr in ShapeTracker
valid = None
-for v in buf.st.views[::-1]:
+for v in buf.st.views[0:-1][::-1]:
if isinstance(v, ZeroView):
if valid is None:
valid = ir.Constant(ir.IntType(1), 1)
@@ -49,93 +47,53 @@ def idx_deref(builder, buf, ptr, eidx):
print(f"expanding index {v.shape_strides}")
for i,(d,s) in enumerate(v.shape_strides[::-1]):
if d != 1 and s != 0:
-if acc%eidx[2] == 0 and len(buf.st.views) == 1:
-# the inner one doesn't matter
-lr = eidx[1]
-if acc//eidx[2] != 1:
-lr = builder.sdiv(lr, int_const(acc//eidx[2]))
-if (acc//eidx[2])*d != eidx[0]:
-lr = builder.srem(lr, int_const(d))
-elif acc*d <= eidx[2] and eidx[3] is not None and len(buf.st.views) == 1:
-# the outer one doesn't matter
-lr = eidx[3]
-if acc != 1:
-lr = builder.sdiv(lr, int_const(acc))
-if acc*d != eidx[2]:
-lr = builder.srem(lr, int_const(d))
-else:
-# slow path
-lr = idx
-if acc != 1:
-lr = builder.sdiv(lr, int_const(acc))
-if acc*d != (eidx[0]*eidx[2]):
-lr = builder.srem(lr, int_const(d))
+# slow path
+lr = idx
+if acc != 1:
+lr = builder.sdiv(lr, int_const(acc))
+if acc*d != prod(buf.shape):
+lr = builder.srem(lr, int_const(d))
if s != 1:
lr = builder.mul(lr, int_const(s))
ret = builder.add(ret, lr)
acc *= d
idx = ret
if valid is not None:
-return builder.select(valid, builder.load(builder.gep(ptr, [idx], inbounds=True)), ir.Constant(ir.FloatType(), 0))
+# this always does the load, so we have it load *0 if the arg won't be used
+# TODO: would control flow be faster?
+aug_idx = builder.select(valid, idx, int_const(0))
+return builder.select(valid, builder.load(builder.gep(ptr, [aug_idx], inbounds=True)), ir.Constant(ir.FloatType(), 0))
else:
return builder.load(builder.gep(ptr, [idx], inbounds=True))
# https://blog.christianperone.com/2022/09/tutorial-on-using-llvm-to-jit-pytorch-fx-graphs-to-native-code-x86-arm-risc-v-wasm-part-i-scalars/
class LLVM:
target_machine = None
engine = None
optimizer = None
# if it can't vectorize
# OPT=2 DEBUG=3 LLVM=1 FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_mul
# if can't vectorize anything
# looks like we have two options, either use clang or handle vectorization in tinygrad
# for the sake of the GPU, we should probably do in tinygrad
# ARM NEON is 128b wide, aka <4 x float> (similar to most GPUs)
# Firestorm (big M1 core) can do up to 4 ops per cycle @ 3.2 GHz = 3.2*4*4*2 = 102.4 GFLOPS (fma)
# There's also AMX https://github.com/corsix/amx/blob/main/README.md
# It seems like torch CPU must be using this? I'm seeing ~150 GFLOPS with convs
# Calling nnp_s4gemm_only_3x3__neon and nnp_owt8x8_3x3_with_bias__neon which don't seem like AMX
# Could this be a winograd conv? Yes, nnp_owt8x8_3x3_with_bias__neon is in NNPACK 2d-winograd-8x8-3x3.c
# 2048x2048 matmul in 9.88 ms (17.18 GOPS) = 1739 GFLOPS (so much! this has to be the AMX)
# calling libBLAS.dylib`SGEMM
# 0x1c3ac5070: 0x0020100d .long 0x0020100d ; AMX instruction 0 = ldx
# 0x1c3ac5074: 0x0020102b .long 0x0020102b ; AMX instruction 1 = ldy (presumed typo in ldst.md)
# 0x1c3ac5078: 0x0020119f .long 0x0020119f ; AMX instruction 12 = fma32
# 0x1c3ac507c: 0x0020118e .long 0x0020118e ; AMX instruction 12 = fma32
# 0x1c3ac5080: 0x9144410f add x15, x8, #0x110, lsl #12 ; =0x110000
# 0x1c3ac5084: 0x00201188 .long 0x00201188 ; AMX instruction 12 = fma32
# 0x1c3ac5088: 0x0020118f .long 0x0020118f ; AMX instruction 12 = fma32
# 0x1c3ac508c: 0x8b0a016b add x11, x11, x10
# 0x1c3ac5090: 0x8b0c01ad add x13, x13, x12
# 0x1c3ac5094: 0xf1000529 subs x9, x9, #0x1
# 0x1c3ac5098: 0x54fffec1 b.ne 0x1c3ac5070 ; <+140>
# z is 16x16 float32s. 1.64 TFLOPS is one dispatch per clock cycle. 3.2*16*16*2 = 1638.4
# From HN: "On M1, for single-precision, one AMX P-unit is ~1.64 TFLOPs, one P-core is ~102 GFLOPS." which matches this
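# a quick back-of-envelope check of the figures above (editor's sketch, assuming 2 FLOPs per FMA and the quoted 3.2 GHz clock):
#   3.2 * 4 * 4 * 2   = 102.4 GFLOPS   (Firestorm: 4 NEON pipes * <4 x float> * fma)
#   3.2 * 16 * 16 * 2 = 1638.4 GFLOPS  (AMX: one 16x16 float32 z-tile fma32 per cycle)
#   2 * 2048**3 / 9.88e-3 / 1e9 ~= 1739 GFLOPS  (the measured SGEMM, matching the AMX estimate)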
def __init__(self):
if LLVM.engine is not None:
return
llvm.initialize()
llvm.initialize_native_target()
-llvm.initialize_native_asmprinter() # yes, even this one
+llvm.initialize_native_asmprinter()
+llvm.initialize_native_asmparser()
target = llvm.Target.from_triple(llvm.get_process_triple())
LLVM.optimizer = llvm.create_module_pass_manager()
-LLVM.target_machine = target.create_target_machine(opt=3) # this opt actually can change things
+LLVM.target_machine = target.create_target_machine(opt=2) # this opt actually can change things. ex: opt=3 means no FMA, opt=2 means FMA
LLVM.target_machine.add_analysis_passes(LLVM.optimizer)
llvm.set_option('', '-force-vector-interleave=4') # this makes sum the same speed as torch, it also doubles the (slow) conv speed
if DEBUG >= 4:
llvm.set_option('', '--debug-only=loop-vectorize')
#llvm.set_option('', '--debug')
# does this do anything?
builder = llvm.create_pass_manager_builder()
builder.opt_level = 3
builder.size_level = 0
builder.loop_vectorize = True
-builder.slp_vectorize = 1
+builder.slp_vectorize = True
builder.populate(LLVM.optimizer)
LLVM.target_machine.set_asm_verbosity(True)
@@ -143,23 +101,7 @@ class LLVM:
backing_mod.triple = llvm.get_process_triple()
LLVM.engine = llvm.create_mcjit_compiler(backing_mod, LLVM.target_machine)
# cache
-def notify_func(module, buffer):
-#print("notify", module.name)
-with open(f"/tmp/llvmcache/{module.name}", "wb") as f:
-f.write(buffer)
-def getbuffer_func(module):
-#print("getbuffer", module.name)
-try:
-with open(f"/tmp/llvmcache/{module.name}", "rb") as f:
-return f.read()
-except FileNotFoundError:
-return None
-# enable cache
-if int(os.getenv("LLVMCACHE", "0")):
-LLVM.engine.set_object_cache(notify_func, getbuffer_func)
-def exec(self, module, bufs):
+def exec(self, module, bufs, op_estimate=0, mem_estimate=0):
module.triple = llvm.get_process_triple()
module.data_layout = self.engine.target_data
llvm_ir = str(module)
@@ -170,43 +112,63 @@ class LLVM:
mod = llvm.parse_assembly(llvm_ir)
mod.verify()
LLVM.optimizer.run(mod)
if DEBUG >= 4:
print("Optimized IR:")
print(str(mod))
mod.name = hashlib.sha1(llvm_ir.encode('utf-8')).hexdigest()
if DEBUG >= 3:
print(LLVM.target_machine.emit_assembly(mod))
LLVM.engine.add_module(mod)
LLVM.engine.finalize_object()
-# call function
-cfunc = CFUNCTYPE(ctypes.c_int, *[type(x._buf) for x in bufs])(LLVM.engine.get_function_address('exec'))
+# call function (NOTE: if the types don't match, there's likely something wrong with the cache)
+#cfunc = CFUNCTYPE(ctypes.c_int, *[type(x._buf) for x in bufs])(LLVM.engine.get_function_address('exec'))
+# why is this needed without the types. fixed tests below
+# LLVM=1 OPT=2 python3 test/test_ops.py TestOps.test_cat TestOps.test_multicat
+cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.POINTER(ctypes.c_float) for x in bufs])(LLVM.engine.get_function_address('exec'))
st = time.monotonic()
cfunc(*[x._buf for x in bufs])
et = time.monotonic() - st
if DEBUG >= 1:
print(f"**LLVM** time {et*1000:7.2f} ms OPs {op_estimate/1e6:7.2f}M -- {(op_estimate/1e9)/et:5.2f} GFLOPS -- {mem_estimate:10d} reads -- {(mem_estimate*4/1e9)/et:5.2f} GB/s")
# we are done
LLVM.engine.remove_module(mod)
return cfunc
# TODO: Refactor LLVMBuffer and GPUBuffer into ShapeTrackedBuffer
class LLVMBuffer(ExplicitExecAST):
op_lookup = {
UnaryOps.NOOP: lambda builder,x: x,
-UnaryOps.NEG: lambda builder,x: builder.fneg(x),
-UnaryOps.RELU: lambda builder,x: builder.select(builder.fcmp_ordered("<=", ir.Constant(ir.FloatType(), 0), x), x, ir.Constant(ir.FloatType(), 0)),
-UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x]),
-UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x]),
-UnaryOps.SIGN: lambda builder,x: builder.select(builder.fcmp_ordered("==", x, ir.Constant(ir.FloatType(), 0)), ir.Constant(ir.FloatType(), 0),
-builder.select(builder.fcmp_ordered("<=", ir.Constant(ir.FloatType(), 0), x), ir.Constant(ir.FloatType(), 1), ir.Constant(ir.FloatType(), -1))),
-UnaryOps.RECIPROCAL: lambda builder,x: builder.fdiv(ir.Constant(ir.FloatType(), 1), x),
-BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y),
-BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y),
-BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y),
-BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y),
-BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y]),
-BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y), ir.FloatType())
+UnaryOps.NEG: lambda builder,x: builder.fneg(x, flags=('fast',)),
+UnaryOps.RELU: lambda builder,x: builder.select(builder.fcmp_ordered("<=", ir.Constant(ir.FloatType(), 0), x, flags=('fast',)), x, ir.Constant(ir.FloatType(), 0)),
+UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)),
+UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x], fastmath=('fast',)),
+UnaryOps.SIGN: lambda builder,x: builder.select(builder.fcmp_ordered("==", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), ir.Constant(ir.FloatType(), 0),
+builder.select(builder.fcmp_ordered("<=", ir.Constant(ir.FloatType(), 0), x, flags=('fast',)), ir.Constant(ir.FloatType(), 1), ir.Constant(ir.FloatType(), -1))),
+UnaryOps.RECIPROCAL: lambda builder,x: builder.fdiv(ir.Constant(ir.FloatType(), 1), x, flags=('fast',)),
+BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)),
+BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)),
+BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y, flags=('fast',)),
+BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y, flags=('fast',)),
+BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y], fastmath=('fast',)),
+BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y, flags=('fast',)), ir.FloatType())
}
start_for_op = {
ReduceOps.SUM: ir.Constant(ir.FloatType(), 0),
ReduceOps.MAX: ir.Constant(ir.FloatType(), -math.inf)
}
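# NOTE: identity elements that seed the reduce accumulator phi nodes below (0 for SUM, -inf for MAX)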
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf=None):
super().__init__(shape, hostbuf)
# TODO: force alignment?
self._buf = (ctypes.c_float * (prod(self.shape)))() if hostbuf is None else hostbuf._buf
#assert ctypes.addressof(self._buf) & 0x1F == 0
-def __repr__(self): return f"LLVMBuffer {str(self.shape)}"
+def __repr__(self): return f"LLVMBuffer {str(self.st)}"
@staticmethod
def fromCPU(x):
@@ -217,80 +179,118 @@ class LLVMBuffer(ExplicitExecAST):
def toCPU(x): return np.ctypeslib.as_array(x.contiguous_op()._buf)[:prod(x.shape)].reshape(x.shape).copy()
# ast can contain one ReduceOp with arbitrary Binary/Unary ops
func_cache : Dict[str, Any] = {}
@classmethod
def exec_ast(cls, ast:LazyOp) -> LLVMBuffer:
-# get the real buffers from the ast
-bufs = get_buffers(ast)
-reduceops = [x for x in get_lazyops(ast) if isinstance(x.op, ReduceOps)]
-assert len(reduceops) <= 1, "max one reduce op in an ast"
-earlybufs = get_buffers(reduceops[0]) if len(reduceops) > 0 else []
-ret = cls(get_lazyop_info(ast).shape)
+k = ASTKernel(ast)
+# cached kernel
+key = str(ast) # TODO: does this uniquely determine the AST? No! The shapetracker can change. Do this better.
+if key in LLVMBuffer.func_cache:
+LLVMBuffer.func_cache[key](*[x._buf for x in k.bufs])
+return k.ret
# cache miss, we have to process the kernel
k.process()
if DEBUG >= 2:
print(ast)
print(k.shapes)
print(k.strides)
# the 4x4 need to go all the way at the end, even after reduce
output_shape = k.shapes[0]
full_shape = [x for x in k.shapes if x != output_shape]
full_shape = output_shape if len(full_shape) == 0 else full_shape[0]
# *** llvm specific below this line ***
# create llvm function
module = ir.Module(name=__file__)
-func = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.PointerType(ir.FloatType())]*(1+len(bufs))), name='exec')
+func = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.FloatType().as_pointer()]*(len(k.bufs))), name='exec')
-# enter
-start_builder = ir.IRBuilder(func.append_basic_block(name="entry"))
-body_builder = ir.IRBuilder(func.append_basic_block(name="inner_loop"))
-start_builder.branch(body_builder._block)
+# force llvmlite to allow us to add function attribute then add the attribute
+func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
+func.attributes.add('"no-nans-fp-math"="true"')
-idx = body_builder.phi(ir.IntType(64))
-idx.add_incoming(int_const(0), start_builder._block)
# construct the structure of the loops
loop_entry = [ir.IRBuilder(func.append_basic_block(name="entry"))]
loop_exit = []
for i,_ in enumerate(full_shape):
loop_entry.append(ir.IRBuilder(func.append_basic_block(name=f"loop_{i}")))
for i,_ in enumerate(full_shape):
loop_exit.append(ir.IRBuilder(func.append_basic_block(name=f"loopexit_{len(full_shape)-1-i}")))
loop_exit.append(ir.IRBuilder(func.append_basic_block(name="exit")))
loop_exit = loop_exit[::-1]
reduce_builder = ir.IRBuilder(func.append_basic_block(name="reduce_loop"))
store_builder = ir.IRBuilder(func.append_basic_block(name="store_block"))
# add the buffer indexing
idx_level = [[int_const(o)] for o in k.offsets]
for i in range(len(full_shape)):
for j in range(len(k.bufs)):
# stride
si = loop_entry[i+1].phi(ir.IntType(64), name=f"idx_{j}_{i}")
si.add_incoming(idx_level[j][-1], loop_entry[i]._block)
si_ps = loop_exit[i+1].add(si, int_const(k.strides[j][i]))
si.add_incoming(si_ps, loop_exit[i+1]._block)
idx_level[j].append(si)
-def ast_parse(builder, x, idx, reduce_result=None, depth=0):
-if DEBUG >= 1:
-print(" "*depth+"ast:", reduce_result, x)
+# the ast parser
+def ast_parse(builder, x, level, reduce_result=None):
if not isinstance(x, LazyOp):
-return idx_deref(builder, x, func.args[bufs.index(x)+1], idx)
+buf_index = k.bufs.index(x)
+idx = idx_level[buf_index][level]
+# load 1x1
+if len(x.st.views) > 1:
+if DEBUG >= 1:
+print(f"WARNING: {x} has buffers with more than 1 view, can't optimize")
+return idx_deref(builder, x, func.args[buf_index], idx)
+else:
+return builder.load(builder.gep(func.args[buf_index], [idx], inbounds=True))
if isinstance(x.op, ReduceOps):
if reduce_result is None:
raise Exception("no reduce")
return reduce_result
-values = [ast_parse(builder, v, idx, reduce_result, depth=depth+1) for v in x.src]
+values = [ast_parse(builder, v, level, reduce_result) for v in x.src]
return LLVMBuffer.op_lookup[x.op](builder, *values)
-if len(reduceops) > 0:
-assert len(earlybufs[0].shape) == len(reduceops[0].arg), "reduce only possible on matching shapes"
-if DEBUG >= 1:
-print(f"reduce {earlybufs[0].shape} -> {reduceops[0].arg}")
-red = prod([s for s,n in zip(earlybufs[0].shape, reduceops[0].arg) if n == 1])
-red_idx = reduce_builder.phi(ir.IntType(64))
-red_idx.add_incoming(int_const(0), body_builder._block)
-val = reduce_builder.phi(ir.FloatType())
-reduce_input = ast_parse(reduce_builder, reduceops[0].src[0], (prod(reduceops[0].arg), idx, red, red_idx))
+# add the ast + final store
+store_loop = output_shape.index(1) if 1 in output_shape else -1
-if reduceops[0].op == ReduceOps.SUM:
-val.add_incoming(ir.Constant(ir.FloatType(), 0), body_builder._block)
-reduce_result = reduce_builder.fadd(reduce_input, val)
-elif reduceops[0].op == ReduceOps.MAX:
-val.add_incoming(ir.Constant(ir.FloatType(), -math.inf), body_builder._block)
-reduce_result = reduce_builder.call(ir.Function(module, ir.FunctionType(ir.FloatType(), [ir.FloatType(), ir.FloatType()]), name="llvm.maxnum.f32"), [reduce_input, val])
-else:
-raise Exception(f"unknown ReduceOps {ast.op}")
-val.add_incoming(reduce_result, reduce_builder._block)
+# do the early ast
+reduce_result = None
+if k.reduceop:
+reduce_input = ast_parse(loop_exit[-1], k.reduceop.src[0], -1)
+phis = [LLVMBuffer.start_for_op[k.reduceop.op]] # type: ignore
+for i in range(store_loop+1, len(loop_entry)):
+val = loop_entry[i].phi(ir.FloatType(), f"reduce_phi_{i}")
+val.add_incoming(phis[-1], loop_entry[i-1]._block)
+phis.append(val)
-red_idx_p1 = reduce_builder.add(red_idx, int_const(1))
-red_idx.add_incoming(red_idx_p1, reduce_builder._block)
-reduce_builder.cbranch(reduce_builder.icmp_unsigned("==", red_idx_p1, int_const(red)), store_builder._block, reduce_builder._block)
-else:
-reduce_result = None
-reduce_builder.branch(store_builder._block)
+if k.reduceop.op == ReduceOps.SUM:
+reduce_result = loop_exit[-1].fadd(reduce_input, val, flags=('fast',))
+elif k.reduceop.op == ReduceOps.MAX:
+reduce_result = loop_exit[i].select(loop_exit[-1].fcmp_unordered(">", val, reduce_input, flags=('fast',)), val, reduce_input, flags=('fast',))
-body_builder.branch(reduce_builder._block)
-result = ast_parse(store_builder, ast, (prod(ret.shape), idx, 1, None), reduce_result)
-store_builder.store(result, store_builder.gep(func.args[0], [idx], inbounds=True))
-idx_p1 = store_builder.add(idx, int_const(1))
-idx.add_incoming(idx_p1, store_builder._block)
+for i,phi in enumerate(phis[1:]):
+if reduce_result != "AMX_Z":
+phi.add_incoming(reduce_result, loop_exit[store_loop+1+i]._block)
-exit_builder = ir.IRBuilder(func.append_basic_block(name="exit"))
-exit_builder.ret_void()
+# do the late ast
+result = ast_parse(loop_exit[store_loop], ast, store_loop, reduce_result=reduce_result)
-store_builder.cbranch(store_builder.icmp_unsigned("==", idx_p1, int_const(prod(ret.shape))), exit_builder._block, body_builder._block)
+# store result
+loop_exit[store_loop].store(result, loop_exit[store_loop].gep(func.args[0], [idx_level[0][store_loop]], inbounds=True))
+# add the looping
+for i,s in enumerate(full_shape):
+loop_entry[i].branch(loop_entry[i+1]._block)
+idx = loop_entry[i+1].phi(ir.IntType(64), name=f"loopvar_{i}")
+idx.add_incoming(int_const(0), loop_entry[i]._block)
+idx_p1 = loop_exit[i+1].add(idx, int_const(1))
+idx.add_incoming(idx_p1, loop_exit[i+1]._block)
+loop_exit[i+1].cbranch(loop_exit[i+1].icmp_unsigned("==", idx_p1, int_const(s)), loop_exit[i]._block, loop_entry[i+1]._block)
# **** llvm running ****
-LLVM().exec(module, [ret] + bufs)
-return ret
+loop_entry[-1].branch(loop_exit[-1]._block)
+loop_exit[0].ret_void()
+LLVMBuffer.func_cache[key] = LLVM().exec(module, k.bufs, k.info.flops, sum(len(x._buf) for x in k.bufs))
+return k.ret