Simple chonker (#431)

* chonker will make llvm fast
* work
* better speed tests, we will make them fast
* with the cache add is the same speed
* relu and neg are fast
* fix sum speed
* maximum maxnum?
* hack for gemm opt
* gemm very slow
* zeros like
* test_permute
* shapetracker returns self
* fix shapetracker factorization
* err, int strides
* permutes are faster now in tinygrad than pytorch
* support -1 in expand
* gemm unrolled
* improve final test case
* WIP GEMM
* why isn't GEMM fast?
* revert cache dim
* ffp contract works on clang, not llvm?
* ignore llvm ir
* this makes fma work at least, but no faster
* USE_4x4
* 63 GFLOPS
* 87 GFLOPS
* that wasn't matmul, 44 GFLOPS now
* 82 GFLOPS permuted
* this permute too
* a little speed for the convs
* 45 GFLOPS
* speed tests pass again
* clean up prints
* fix FMA WHAT A WASTE OF TIME
* colors
* moar fair
* GPU
* useless on chonker
* cleanups
* improve factorized shapetracker
* better threshold
* label conv
* work
* ops test pass again
* hot load the index
* run the last view, no need to create
* ZeroView needs a repr for the key to work
* fix segfault on out of bounds
* one more test
* start amx, and llvm.initialize_native_asmparser
* amx works
* nice AMX class
* nicer AMX class
* refactor get_idxs
* amx working
* is slower...
* useless flip
* cache
* SZ_X
* AMX_SZ_X/Y work alone
* Contiguous mlop
* test gemm packed
* PREPARE in packed
* use_amx factor
* prefetch isn't faster
* loop
* same 3ms
* 2.24 ms
* allow double on store in TG
* amx reduce is the same speed as non amx reduce
* include memory bandwidth
* clean up shapetracker
* flip returns stride
* prepare for upstream
* Update ops_llvm.py (#426)
* permutes are yellow and green now
* faster conv
* llvm cleanups
* Show optimised IR under debug 4 (#428)
* ASTKernel class
* Make tinygrad work with older python version (#427)
* Make tinygrad work with older python version
* Use partialmethod instead of partial
* smiple chonker is chonking
* remove junk from test speed vs torch
* fix linker and types
* AMX is only here now
* add LLVM tests, it's a valid backend now
* oops, run llvm test
* contiguous_op
* fix loadops compare
* dedup reduceops

Co-authored-by: calledit <1573053+calledit@users.noreply.github.com>
@@ -1,30 +1,28 @@
 from __future__ import annotations
 import os
 import hashlib
 import math
-from typing import Tuple, Union
+import time
+from typing import Tuple, Union, Dict, Any
 from tinygrad.helpers import prod
 from tinygrad.shapetracker import ShapeTracker, ZeroView
-from tinygrad.ops import LazyOp
+from tinygrad.ops import LazyOp, ASTKernel
 import ctypes
 import numpy as np
 from ctypes import CFUNCTYPE
-from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, get_buffers, get_lazyops, ExplicitExecAST, get_lazyop_info
+from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, ExplicitExecAST

 from llvmlite import ir  # type: ignore
 import llvmlite.binding as llvm  # type: ignore

 def int_const(x): return ir.Constant(ir.IntType(64), x)

-def idx_deref(builder, buf, ptr, eidx):
-  if eidx[2] == 1 and eidx[3] is None:
-    idx = eidx[1]
-  else:
-    idx = builder.add(builder.mul(eidx[1], int_const(eidx[2])), eidx[3], name="idx")
+# this is only used on the crappy path
+def idx_deref(builder, buf, ptr, idx):
   if DEBUG >= 1:
     print("viewcount:", len(buf.st.views), buf.st.expr(), ptr, "on", buf.shape)
   # TODO: unify this with expr in ShapeTracker
   valid = None
-  for v in buf.st.views[::-1]:
+  for v in buf.st.views[0:-1][::-1]:
     if isinstance(v, ZeroView):
       if valid is None:
         valid = ir.Constant(ir.IntType(1), 1)
@@ -49,93 +47,53 @@ def idx_deref(builder, buf, ptr, eidx):
         print(f"expanding index {v.shape_strides}")
       for i,(d,s) in enumerate(v.shape_strides[::-1]):
         if d != 1 and s != 0:
-          if acc%eidx[2] == 0 and len(buf.st.views) == 1:
-            # the inner one doesn't matter
-            lr = eidx[1]
-            if acc//eidx[2] != 1:
-              lr = builder.sdiv(lr, int_const(acc//eidx[2]))
-            if (acc//eidx[2])*d != eidx[0]:
-              lr = builder.srem(lr, int_const(d))
-          elif acc*d <= eidx[2] and eidx[3] is not None and len(buf.st.views) == 1:
-            # the outer one doesn't matter
-            lr = eidx[3]
-            if acc != 1:
-              lr = builder.sdiv(lr, int_const(acc))
-            if acc*d != eidx[2]:
-              lr = builder.srem(lr, int_const(d))
-          else:
-            # slow path
-            lr = idx
-            if acc != 1:
-              lr = builder.sdiv(lr, int_const(acc))
-            if acc*d != (eidx[0]*eidx[2]):
-              lr = builder.srem(lr, int_const(d))
+          # slow path
+          lr = idx
+          if acc != 1:
+            lr = builder.sdiv(lr, int_const(acc))
+          if acc*d != prod(buf.shape):
+            lr = builder.srem(lr, int_const(d))
           if s != 1:
             lr = builder.mul(lr, int_const(s))
           ret = builder.add(ret, lr)
         acc *= d
     idx = ret
   if valid is not None:
-    return builder.select(valid, builder.load(builder.gep(ptr, [idx], inbounds=True)), ir.Constant(ir.FloatType(), 0))
+    # this always does the load, so we have it load *0 if the arg won't be used
+    # TODO: would control flow be faster?
+    aug_idx = builder.select(valid, idx, int_const(0))
+    return builder.select(valid, builder.load(builder.gep(ptr, [aug_idx], inbounds=True)), ir.Constant(ir.FloatType(), 0))
   else:
    return builder.load(builder.gep(ptr, [idx], inbounds=True))
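
What this function emits is plain strided address arithmetic: for each view dimension of size d and stride s, it peels that dimension's coordinate out of the flat index with a div and a mod, scales it by the stride, and accumulates. A Python model of the slow path follows (illustrative names, not tinygrad API); the emitted IR just skips the sdiv when acc == 1 and the srem when the dimension already spans the rest of the buffer:

```python
# Python model of the address math idx_deref builds as LLVM IR.
# shape_strides is a list of (size, stride) pairs, outermost dimension first.
def deref_offset(idx, shape_strides):
  ret, acc = 0, 1
  for d, s in reversed(shape_strides):  # walk innermost dimension first
    if d != 1 and s != 0:
      lr = (idx // acc) % d  # recover this dimension's coordinate
      ret += lr * s          # scale by the view's stride
    acc *= d
  return ret

# a (3,2) transposed view of a row-major 2x3 buffer has strides (1,3):
# flat index 3 is coordinate (1,1), which lives at buffer offset 1*1 + 1*3 = 4
assert deref_offset(3, [(3, 1), (2, 3)]) == 4
```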

 # https://blog.christianperone.com/2022/09/tutorial-on-using-llvm-to-jit-pytorch-fx-graphs-to-native-code-x86-arm-risc-v-wasm-part-i-scalars/
 class LLVM:
   target_machine = None
   engine = None
   optimizer = None
-  # if it can't vectorize
-  # OPT=2 DEBUG=3 LLVM=1 FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_mul
-  # if can't vectorize anything
-
-  # looks like we have two options, either use clang or handle vectorization in tinygrad
-  # for the sake of the GPU, we should probably do in tinygrad
-
-  # ARM NEON is 128b wide, aka <4 x float> (similar to most GPUs)
-  # Firestorm (big M1 core) can do up to 4 ops per cycle @ 3.2 GHz = 3.2*4*4*2 = 102.4 GFLOPS (fma)
-
-  # There's also AMX https://github.com/corsix/amx/blob/main/README.md
-  # It seems like torch CPU must be using this? I'm seeing ~150 GFLOPS with convs
-  # Calling nnp_s4gemm_only_3x3__neon and nnp_owt8x8_3x3_with_bias__neon which don't seem like AMX
-  # Could this be a winograd conv? Yes, nnp_owt8x8_3x3_with_bias__neon is in NNPACK 2d-winograd-8x8-3x3.c
-
-  # 2048x2048 matmul in 9.88 ms (17.18 GOPS) = 1739 GFLOPS (so much! this has to be the AMX)
-  # calling libBLAS.dylib`SGEMM
-  #  0x1c3ac5070: 0x0020100d  .long 0x0020100d  ; AMX instruction 0 = ldx
-  #  0x1c3ac5074: 0x0020102b  .long 0x0020102b  ; AMX instruction 1 = ldy (presumed typo in ldst.md)
-  #  0x1c3ac5078: 0x0020119f  .long 0x0020119f  ; AMX instruction 12 = fma32
-  #  0x1c3ac507c: 0x0020118e  .long 0x0020118e  ; AMX instruction 12 = fma32
-  #  0x1c3ac5080: 0x9144410f  add x15, x8, #0x110, lsl #12  ; =0x110000
-  #  0x1c3ac5084: 0x00201188  .long 0x00201188  ; AMX instruction 12 = fma32
-  #  0x1c3ac5088: 0x0020118f  .long 0x0020118f  ; AMX instruction 12 = fma32
-  #  0x1c3ac508c: 0x8b0a016b  add x11, x11, x10
-  #  0x1c3ac5090: 0x8b0c01ad  add x13, x13, x12
-  #  0x1c3ac5094: 0xf1000529  subs x9, x9, #0x1
-  #  0x1c3ac5098: 0x54fffec1  b.ne 0x1c3ac5070  ; <+140>
-  # z is 16x16 float32s. 1.64 TFLOPS is one dispatch per clock cycle. 3.2*16*16*2 = 1638.4
-
-  # From HN: "On M1, for single-precision, one AMX P-unit is ~1.64 TFLOPs, one P-core is ~102 GFLOPS." which matches this

   def __init__(self):
     if LLVM.engine is not None:
       return
     llvm.initialize()
     llvm.initialize_native_target()
-    llvm.initialize_native_asmprinter()  # yes, even this one
+    llvm.initialize_native_asmprinter()
+    llvm.initialize_native_asmparser()
     target = llvm.Target.from_triple(llvm.get_process_triple())
     LLVM.optimizer = llvm.create_module_pass_manager()
-    LLVM.target_machine = target.create_target_machine(opt=3)  # this opt actually can change things
+    LLVM.target_machine = target.create_target_machine(opt=2)  # this opt actually can change things. ex: opt=3 means no FMA, opt=2 means FMA
     LLVM.target_machine.add_analysis_passes(LLVM.optimizer)

     llvm.set_option('', '-force-vector-interleave=4')  # this makes sum the same speed as torch, it also doubles the (slow) conv speed
     if DEBUG >= 4:
       llvm.set_option('', '--debug-only=loop-vectorize')
     #llvm.set_option('', '--debug')

     # does this do anything?
     builder = llvm.create_pass_manager_builder()
     builder.opt_level = 3
     builder.size_level = 0
     builder.loop_vectorize = True
-    builder.slp_vectorize = 1
+    builder.slp_vectorize = True
     builder.populate(LLVM.optimizer)

     LLVM.target_machine.set_asm_verbosity(True)
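
The peak-rate figures in the AMX and NEON comments above (dropped from this file by this commit) follow directly from clock × lanes × issue rate, with a factor of 2 because an fma counts as two FLOPs. A quick sanity check of that arithmetic in plain Python, numbers taken from the comments:

```python
# recompute the peak-FLOPS figures quoted in the AMX/NEON comments
GHZ = 3.2                       # Firestorm clock
neon = GHZ * 4 * 4 * 2          # 4 pipes * <4 x float> * 2 (fma = mul+add)
amx = GHZ * 16 * 16 * 2         # one 16x16 float32 z-tile fma per cycle
print(f"NEON peak: {neon:.1f} GFLOPS")  # 102.4
print(f"AMX peak: {amx:.1f} GFLOPS")    # 1638.4

ops = 2 * 2048**3               # FLOPs in a 2048x2048 matmul
print(f"SGEMM: {ops/1e9:.2f} GOP in 9.88 ms = {ops/9.88e-3/1e12:.2f} TFLOPS")  # 17.18 GOP, 1.74 TFLOPS
```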
@@ -143,23 +101,7 @@ class LLVM:
     backing_mod.triple = llvm.get_process_triple()
     LLVM.engine = llvm.create_mcjit_compiler(backing_mod, LLVM.target_machine)

-    # cache
-    def notify_func(module, buffer):
-      #print("notify", module.name)
-      with open(f"/tmp/llvmcache/{module.name}", "wb") as f:
-        f.write(buffer)
-    def getbuffer_func(module):
-      #print("getbuffer", module.name)
-      try:
-        with open(f"/tmp/llvmcache/{module.name}", "rb") as f:
-          return f.read()
-      except FileNotFoundError:
-        return None
-    # enable cache
-    if int(os.getenv("LLVMCACHE", "0")):
-      LLVM.engine.set_object_cache(notify_func, getbuffer_func)

-  def exec(self, module, bufs):
+  def exec(self, module, bufs, op_estimate=0, mem_estimate=0):
     module.triple = llvm.get_process_triple()
     module.data_layout = self.engine.target_data
     llvm_ir = str(module)
@@ -170,43 +112,63 @@ class LLVM:
     mod = llvm.parse_assembly(llvm_ir)
     mod.verify()
     LLVM.optimizer.run(mod)
     if DEBUG >= 4:
       print("Optimized IR:")
       print(str(mod))
     mod.name = hashlib.sha1(llvm_ir.encode('utf-8')).hexdigest()
     if DEBUG >= 3:
       print(LLVM.target_machine.emit_assembly(mod))
     LLVM.engine.add_module(mod)
     LLVM.engine.finalize_object()

-    # call function
-    cfunc = CFUNCTYPE(ctypes.c_int, *[type(x._buf) for x in bufs])(LLVM.engine.get_function_address('exec'))
+    # call function (NOTE: if the types don't match, there's likely something wrong with the cache)
+    #cfunc = CFUNCTYPE(ctypes.c_int, *[type(x._buf) for x in bufs])(LLVM.engine.get_function_address('exec'))
+
+    # why is this needed without the types. fixed tests below
+    # LLVM=1 OPT=2 python3 test/test_ops.py TestOps.test_cat TestOps.test_multicat
+    cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.POINTER(ctypes.c_float) for x in bufs])(LLVM.engine.get_function_address('exec'))

+    st = time.monotonic()
     cfunc(*[x._buf for x in bufs])
+    et = time.monotonic() - st
+    if DEBUG >= 1:
+      print(f"**LLVM** time {et*1000:7.2f} ms OPs {op_estimate/1e6:7.2f}M -- {(op_estimate/1e9)/et:5.2f} GFLOPS -- {mem_estimate:10d} reads -- {(mem_estimate*4/1e9)/et:5.2f} GB/s")

     # we are done
     LLVM.engine.remove_module(mod)
     return cfunc
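
The whole compile-and-call cycle exec() performs can be reproduced standalone. Below is a minimal sketch against the llvmlite of this era (typed pointers, MCJIT) using a made-up trivial kernel, out[0] = a[0] + b[0]; only standard llvmlite calls are used. It also shows why the CFUNCTYPE above switched to explicit float* types: a prototype built from type(x._buf) changes with each buffer's ctypes array length, which breaks reuse of a cached cfunc.

```python
import ctypes
from ctypes import CFUNCTYPE
from llvmlite import ir
import llvmlite.binding as llvm

llvm.initialize(); llvm.initialize_native_target(); llvm.initialize_native_asmprinter()
tm = llvm.Target.from_triple(llvm.get_process_triple()).create_target_machine(opt=2)
engine = llvm.create_mcjit_compiler(llvm.parse_assembly(""), tm)

# build exec(out, a, b): out[0] = a[0] + b[0]
module = ir.Module(name="demo")
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.FloatType().as_pointer()]*3), name='exec')
builder = ir.IRBuilder(func.append_basic_block(name="entry"))
out, a, b = func.args
builder.store(builder.fadd(builder.load(a), builder.load(b)), out)
builder.ret_void()
module.triple = llvm.get_process_triple()

mod = llvm.parse_assembly(str(module))
mod.verify()
engine.add_module(mod)
engine.finalize_object()

# explicit float* argument types, independent of the ctypes array lengths
cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.POINTER(ctypes.c_float)]*3)(engine.get_function_address('exec'))
bufs = [(ctypes.c_float * 1)() for _ in range(3)]
bufs[1][0], bufs[2][0] = 2.0, 3.0
cfunc(*bufs)
assert bufs[0][0] == 5.0
```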

 # TODO: Refactor LLVMBuffer and GPUBuffer into ShapeTrackedBuffer
 class LLVMBuffer(ExplicitExecAST):
   op_lookup = {
     UnaryOps.NOOP: lambda builder,x: x,
-    UnaryOps.NEG: lambda builder,x: builder.fneg(x),
-    UnaryOps.RELU: lambda builder,x: builder.select(builder.fcmp_ordered("<=", ir.Constant(ir.FloatType(), 0), x), x, ir.Constant(ir.FloatType(), 0)),
-    UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x]),
-    UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x]),
-    UnaryOps.SIGN: lambda builder,x: builder.select(builder.fcmp_ordered("==", x, ir.Constant(ir.FloatType(), 0)), ir.Constant(ir.FloatType(), 0),
-      builder.select(builder.fcmp_ordered("<=", ir.Constant(ir.FloatType(), 0), x), ir.Constant(ir.FloatType(), 1), ir.Constant(ir.FloatType(), -1))),
-    UnaryOps.RECIPROCAL: lambda builder,x: builder.fdiv(ir.Constant(ir.FloatType(), 1), x),
-    BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y),
-    BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y),
-    BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y),
-    BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y),
-    BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y]),
-    BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y), ir.FloatType())
+    UnaryOps.NEG: lambda builder,x: builder.fneg(x, flags=('fast',)),
+    UnaryOps.RELU: lambda builder,x: builder.select(builder.fcmp_ordered("<=", ir.Constant(ir.FloatType(), 0), x, flags=('fast',)), x, ir.Constant(ir.FloatType(), 0)),
+    UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)),
+    UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x], fastmath=('fast',)),
+    UnaryOps.SIGN: lambda builder,x: builder.select(builder.fcmp_ordered("==", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), ir.Constant(ir.FloatType(), 0),
+      builder.select(builder.fcmp_ordered("<=", ir.Constant(ir.FloatType(), 0), x, flags=('fast',)), ir.Constant(ir.FloatType(), 1), ir.Constant(ir.FloatType(), -1))),
+    UnaryOps.RECIPROCAL: lambda builder,x: builder.fdiv(ir.Constant(ir.FloatType(), 1), x, flags=('fast',)),
+    BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)),
+    BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)),
+    BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y, flags=('fast',)),
+    BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y, flags=('fast',)),
+    BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y], fastmath=('fast',)),
+    BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y, flags=('fast',)), ir.FloatType())
   }
   start_for_op = {
     ReduceOps.SUM: ir.Constant(ir.FloatType(), 0),
     ReduceOps.MAX: ir.Constant(ir.FloatType(), -math.inf)
   }
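
Every rewritten entry threads fast-math through llvmlite: float instructions take flags=('fast',) and intrinsic calls take fastmath=('fast',). That is what licenses LLVM to reassociate reductions and contract mul+add into fma (see the FMA notes in the commit message). A minimal illustration of what the flag changes in the emitted IR:

```python
# an fadd with fast-math flags, as llvmlite emits it
from llvmlite import ir

m = ir.Module(name="fastmath_demo")
f = ir.Function(m, ir.FunctionType(ir.FloatType(), [ir.FloatType()]*2), name="add")
b = ir.IRBuilder(f.append_basic_block(name="entry"))
x, y = f.args
b.ret(b.fadd(x, y, flags=('fast',)))  # renders as: fadd fast float %".1", %".2"
print(m)
```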

   def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf=None):
     super().__init__(shape, hostbuf)
     # TODO: force alignment?
     self._buf = (ctypes.c_float * (prod(self.shape)))() if hostbuf is None else hostbuf._buf
     #assert ctypes.addressof(self._buf) & 0x1F == 0

-  def __repr__(self): return f"LLVMBuffer {str(self.shape)}"
+  def __repr__(self): return f"LLVMBuffer {str(self.st)}"

   @staticmethod
   def fromCPU(x):
@@ -217,80 +179,118 @@ class LLVMBuffer(ExplicitExecAST):

   def toCPU(x): return np.ctypeslib.as_array(x.contiguous_op()._buf)[:prod(x.shape)].reshape(x.shape).copy()

   # ast can contain one ReduceOp with arbitrary Binary/Unary ops
+  func_cache : Dict[str, Any] = {}
   @classmethod
   def exec_ast(cls, ast:LazyOp) -> LLVMBuffer:
-    # get the real buffers from the ast
-    bufs = get_buffers(ast)
-    reduceops = [x for x in get_lazyops(ast) if isinstance(x.op, ReduceOps)]
-    assert len(reduceops) <= 1, "max one reduce op in an ast"
-    earlybufs = get_buffers(reduceops[0]) if len(reduceops) > 0 else []
-    ret = cls(get_lazyop_info(ast).shape)
+    k = ASTKernel(ast)
+
+    # cached kernel
+    key = str(ast)  # TODO: does this uniquely determine the AST? No! The shapetracker can change. Do this better.
+    if key in LLVMBuffer.func_cache:
+      LLVMBuffer.func_cache[key](*[x._buf for x in k.bufs])
+      return k.ret
+
+    # cache miss, we have to process the kernel
+    k.process()
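
The cache above memoizes the compiled ctypes function under the printed AST, so a repeated kernel skips codegen and JIT entirely; only the buffer pointers change between calls. The pattern in miniature (run_cached and compile_fn are hypothetical stand-ins), with the same caveat the TODO raises, since two ASTs can print identically while their shapetrackers differ:

```python
# miniature of the kernel cache: compile once per key, then just call
func_cache = {}
def run_cached(key, compile_fn, bufs):
  if key not in func_cache:
    func_cache[key] = compile_fn()   # slow path: codegen + JIT compile
  return func_cache[key](*bufs)      # fast path: a bare ctypes call
```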

     if DEBUG >= 2:
       print(ast)
+      print(k.shapes)
+      print(k.strides)

+    # the 4x4 need to go all the way at the end, even after reduce
+    output_shape = k.shapes[0]
+    full_shape = [x for x in k.shapes if x != output_shape]
+    full_shape = output_shape if len(full_shape) == 0 else full_shape[0]

+    # *** llvm specific below this line ***

     # create llvm function
     module = ir.Module(name=__file__)
-    func = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.PointerType(ir.FloatType())]*(1+len(bufs))), name='exec')
+    func = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.FloatType().as_pointer()]*(len(k.bufs))), name='exec')

-    # enter
-    start_builder = ir.IRBuilder(func.append_basic_block(name="entry"))
-    body_builder = ir.IRBuilder(func.append_basic_block(name="inner_loop"))
-    start_builder.branch(body_builder._block)
+    # force llvmlite to allow us to add function attribute then add the attribute
+    func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
+    func.attributes.add('"no-nans-fp-math"="true"')

-    idx = body_builder.phi(ir.IntType(64))
-    idx.add_incoming(int_const(0), start_builder._block)
+    # construct the structure of the loops
+    loop_entry = [ir.IRBuilder(func.append_basic_block(name="entry"))]
+    loop_exit = []
+    for i,_ in enumerate(full_shape):
+      loop_entry.append(ir.IRBuilder(func.append_basic_block(name=f"loop_{i}")))
+    for i,_ in enumerate(full_shape):
+      loop_exit.append(ir.IRBuilder(func.append_basic_block(name=f"loopexit_{len(full_shape)-1-i}")))
+    loop_exit.append(ir.IRBuilder(func.append_basic_block(name="exit")))
+    loop_exit = loop_exit[::-1]

-    reduce_builder = ir.IRBuilder(func.append_basic_block(name="reduce_loop"))
-    store_builder = ir.IRBuilder(func.append_basic_block(name="store_block"))
+    # add the buffer indexing
+    idx_level = [[int_const(o)] for o in k.offsets]
+    for i in range(len(full_shape)):
+      for j in range(len(k.bufs)):
+        # stride
+        si = loop_entry[i+1].phi(ir.IntType(64), name=f"idx_{j}_{i}")
+        si.add_incoming(idx_level[j][-1], loop_entry[i]._block)
+        si_ps = loop_exit[i+1].add(si, int_const(k.strides[j][i]))
+        si.add_incoming(si_ps, loop_exit[i+1]._block)
+        idx_level[j].append(si)
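
This indexing scheme is the heart of the chonker: instead of one flat loop that recovers buffer offsets with div/mod on every access (the idx_deref path), there is one counter per shape dimension, and each buffer carries a running offset that the phi pairs above advance by that buffer's stride at each loop level. A Python model of the resulting traversal (illustrative, not the real kernel):

```python
# model of the loop nest + per-buffer running offsets built above
def run_loops(full_shape, strides, nbufs, body):
  offs = [0]*nbufs                    # like the idx_level phi chains
  def loop(level):
    if level == len(full_shape):
      body(list(offs))                # innermost: offsets are exact, no div/mod
      return
    for _ in range(full_shape[level]):
      loop(level+1)
      for j in range(nbufs):
        offs[j] += strides[j][level]  # the si_ps add in the loop exit block
    for j in range(nbufs):            # phi re-enters with the outer level's value
      offs[j] -= strides[j][level]*full_shape[level]
  loop(0)

# 2x3 elementwise over two contiguous buffers: both offsets sweep 0..5 in order
run_loops([2, 3], [[3, 1], [3, 1]], 2, print)
```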

-    def ast_parse(builder, x, idx, reduce_result=None, depth=0):
-      if DEBUG >= 1:
-        print("  "*depth+"ast:", reduce_result, x)
+    # the ast parser
+    def ast_parse(builder, x, level, reduce_result=None):
       if not isinstance(x, LazyOp):
-        return idx_deref(builder, x, func.args[bufs.index(x)+1], idx)
+        buf_index = k.bufs.index(x)
+        idx = idx_level[buf_index][level]
+        # load 1x1
+        if len(x.st.views) > 1:
+          if DEBUG >= 1:
+            print(f"WARNING: {x} has buffers with more than 1 view, can't optimize")
+          return idx_deref(builder, x, func.args[buf_index], idx)
+        else:
+          return builder.load(builder.gep(func.args[buf_index], [idx], inbounds=True))
       if isinstance(x.op, ReduceOps):
         if reduce_result is None:
           raise Exception("no reduce")
         return reduce_result
-      values = [ast_parse(builder, v, idx, reduce_result, depth=depth+1) for v in x.src]
+      values = [ast_parse(builder, v, level, reduce_result) for v in x.src]
       return LLVMBuffer.op_lookup[x.op](builder, *values)
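
ast_parse is a post-order walk: leaf buffers resolve to loads at the current loop level, a ReduceOp node is replaced by the reduce_result computed in the inner loops, and everything else dispatches through op_lookup on its parsed sources. The recursion in miniature (hypothetical node shape, for illustration only):

```python
# post-order AST evaluation, the pattern ast_parse follows
def eval_ast(node, load, ops, reduce_result=None):
  if not hasattr(node, "op"):          # leaf: a buffer, emit a load
    return load(node)
  if node.op == "REDUCE":              # already reduced in the inner loops
    return reduce_result
  vals = [eval_ast(v, load, ops, reduce_result) for v in node.src]
  return ops[node.op](*vals)           # e.g. BinaryOps.ADD -> fadd
```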

-    if len(reduceops) > 0:
-      assert len(earlybufs[0].shape) == len(reduceops[0].arg), "reduce only possible on matching shapes"
-      if DEBUG >= 1:
-        print(f"reduce {earlybufs[0].shape} -> {reduceops[0].arg}")
-      red = prod([s for s,n in zip(earlybufs[0].shape, reduceops[0].arg) if n == 1])
-      red_idx = reduce_builder.phi(ir.IntType(64))
-      red_idx.add_incoming(int_const(0), body_builder._block)
-      val = reduce_builder.phi(ir.FloatType())
-      reduce_input = ast_parse(reduce_builder, reduceops[0].src[0], (prod(reduceops[0].arg), idx, red, red_idx))
+    # add the ast + final store
+    store_loop = output_shape.index(1) if 1 in output_shape else -1

-      if reduceops[0].op == ReduceOps.SUM:
-        val.add_incoming(ir.Constant(ir.FloatType(), 0), body_builder._block)
-        reduce_result = reduce_builder.fadd(reduce_input, val)
-      elif reduceops[0].op == ReduceOps.MAX:
-        val.add_incoming(ir.Constant(ir.FloatType(), -math.inf), body_builder._block)
-        reduce_result = reduce_builder.call(ir.Function(module, ir.FunctionType(ir.FloatType(), [ir.FloatType(), ir.FloatType()]), name="llvm.maxnum.f32"), [reduce_input, val])
-      else:
-        raise Exception(f"unknown ReduceOps {ast.op}")
-      val.add_incoming(reduce_result, reduce_builder._block)
+    # do the early ast
+    reduce_result = None
+    if k.reduceop:
+      reduce_input = ast_parse(loop_exit[-1], k.reduceop.src[0], -1)
+      phis = [LLVMBuffer.start_for_op[k.reduceop.op]]  # type: ignore
+      for i in range(store_loop+1, len(loop_entry)):
+        val = loop_entry[i].phi(ir.FloatType(), f"reduce_phi_{i}")
+        val.add_incoming(phis[-1], loop_entry[i-1]._block)
+        phis.append(val)

-      red_idx_p1 = reduce_builder.add(red_idx, int_const(1))
-      red_idx.add_incoming(red_idx_p1, reduce_builder._block)
-      reduce_builder.cbranch(reduce_builder.icmp_unsigned("==", red_idx_p1, int_const(red)), store_builder._block, reduce_builder._block)
-    else:
-      reduce_result = None
-      reduce_builder.branch(store_builder._block)
+      if k.reduceop.op == ReduceOps.SUM:
+        reduce_result = loop_exit[-1].fadd(reduce_input, val, flags=('fast',))
+      elif k.reduceop.op == ReduceOps.MAX:
+        reduce_result = loop_exit[i].select(loop_exit[-1].fcmp_unordered(">", val, reduce_input, flags=('fast',)), val, reduce_input, flags=('fast',))
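
The reduce now threads a single accumulator through the reduce loop levels as a phi chain, seeded from start_for_op (0 for SUM, -inf for MAX) and updated once per inner iteration; MAX becomes a compare plus select instead of the old call to llvm.maxnum. The scalar semantics of that accumulator, in Python:

```python
# scalar semantics of the phi-chained reduce accumulator
import math

def reduce_scalar(op, xs):
  acc = 0.0 if op == "SUM" else -math.inf   # the start_for_op seed
  for x in xs:                              # one update per inner-loop trip
    acc = acc + x if op == "SUM" else (acc if acc > x else x)
  return acc

assert reduce_scalar("SUM", [1.0, 2.0, 3.0]) == 6.0
assert reduce_scalar("MAX", [1.0, 5.0, 3.0]) == 5.0
```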

-    body_builder.branch(reduce_builder._block)
-    result = ast_parse(store_builder, ast, (prod(ret.shape), idx, 1, None), reduce_result)
-    store_builder.store(result, store_builder.gep(func.args[0], [idx], inbounds=True))
-    idx_p1 = store_builder.add(idx, int_const(1))
-    idx.add_incoming(idx_p1, store_builder._block)
+      for i,phi in enumerate(phis[1:]):
+        if reduce_result != "AMX_Z":
+          phi.add_incoming(reduce_result, loop_exit[store_loop+1+i]._block)

-    exit_builder = ir.IRBuilder(func.append_basic_block(name="exit"))
-    exit_builder.ret_void()
+    # do the late ast
+    result = ast_parse(loop_exit[store_loop], ast, store_loop, reduce_result=reduce_result)

-    store_builder.cbranch(store_builder.icmp_unsigned("==", idx_p1, int_const(prod(ret.shape))), exit_builder._block, body_builder._block)
+    # store result
+    loop_exit[store_loop].store(result, loop_exit[store_loop].gep(func.args[0], [idx_level[0][store_loop]], inbounds=True))

+    # add the looping
+    for i,s in enumerate(full_shape):
+      loop_entry[i].branch(loop_entry[i+1]._block)
+      idx = loop_entry[i+1].phi(ir.IntType(64), name=f"loopvar_{i}")
+      idx.add_incoming(int_const(0), loop_entry[i]._block)
+      idx_p1 = loop_exit[i+1].add(idx, int_const(1))
+      idx.add_incoming(idx_p1, loop_exit[i+1]._block)
+      loop_exit[i+1].cbranch(loop_exit[i+1].icmp_unsigned("==", idx_p1, int_const(s)), loop_exit[i]._block, loop_entry[i+1]._block)

     # **** llvm running ****
-    LLVM().exec(module, [ret] + bufs)
-    return ret
+    loop_entry[-1].branch(loop_exit[-1]._block)
+    loop_exit[0].ret_void()
+    LLVMBuffer.func_cache[key] = LLVM().exec(module, k.bufs, k.info.flops, sum(len(x._buf) for x in k.bufs))
+    return k.ret