Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-02-18 10:31:41 -05:00)

Replace sign with gt0

* Replace sign with gt0
* GT0 works on GPU
* Fix brackets

Co-authored-by: Tom Finet <tom.codeninja@gmail.com>
363 lines
15 KiB
Python
from __future__ import annotations
import hashlib
import math
import time
from typing import Tuple, Union, Dict, Any, List
from tinygrad.helpers import prod
from tinygrad.shape import ShapeTracker, ZeroView
from tinygrad.ops import LazyOp
from tinygrad.ast import ASTKernel
import ctypes
import numpy as np
from ctypes import CFUNCTYPE
from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, ExplicitExecAST, GlobalCounters

from llvmlite import ir # type: ignore
import llvmlite.binding as llvm # type: ignore

def int_const(x): return ir.Constant(ir.IntType(64), x)

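# NOTE: buffers whose ShapeTracker still has more than one view can't be indexed
# with a single strided expression, so idx_deref below rebuilds the index math
# (div/mod per axis) in IR and masks out-of-range ZeroView reads so they return 0.0.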
# this is only used on the crappy path
def idx_deref(builder, buf, ptr, idx):
  if DEBUG >= 1:
    print("viewcount:", len(buf.st.views), buf.st.expr(), ptr, "on", buf.shape)
  # TODO: unify this with expr in ShapeTracker
  valid = None
  for v in buf.st.views[0:-1][::-1]:
    if isinstance(v, ZeroView):
      if valid is None:
        valid = ir.Constant(ir.IntType(1), 1)
      acc = 1
      for s,(x,y) in list(zip(v.old_shape, v.arg))[::-1]:
        if x < 0 or y > s:
          lr = idx
          if acc != 1:
            lr = builder.sdiv(lr, int_const(acc))
          lr = builder.srem(lr, int_const(y-x))
          if x != 0:
            lr = builder.add(lr, int_const(x))
          if x < 0:
            valid = builder.and_(valid, builder.icmp_signed(">=", lr, int_const(0)))
          if y > s:
            valid = builder.and_(valid, builder.icmp_signed("<", lr, int_const(s)))
        acc *= y-x
    else:
      acc = 1
      ret = int_const(v.offset)
      if DEBUG >= 2:
        print(f"expanding index {v.shape_strides}")
      for i,(d,s) in enumerate(v.shape_strides[::-1]):
        if d != 1 and s != 0:
          # slow path
          lr = idx
          if acc != 1:
            lr = builder.sdiv(lr, int_const(acc))
          if acc*d != prod(buf.shape):
            lr = builder.srem(lr, int_const(d))
          if s != 1:
            lr = builder.mul(lr, int_const(s))
          ret = builder.add(ret, lr)
        acc *= d
      idx = ret
  if valid is not None:
    # this always does the load, so we have it load *0 if the arg won't be used
    # TODO: would control flow be faster?
    aug_idx = builder.select(valid, idx, int_const(0))
    return builder.select(valid, builder.load(builder.gep(ptr, [aug_idx], inbounds=True)), ir.Constant(ir.FloatType(), 0))
  else:
    return builder.load(builder.gep(ptr, [idx], inbounds=True))

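# LLVM wraps a single process-wide MCJIT compiler: the target machine, the module
# pass manager, and the execution engine live in class attributes, so constructing
# LLVM() more than once is a no-op after the first call.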
class LLVM:
  target_machine = None
  engine = None
  optimizer = None

  def __init__(self):
    if LLVM.engine is not None:
      return
    llvm.initialize()
    llvm.initialize_native_target()
    llvm.initialize_native_asmprinter()
    llvm.initialize_native_asmparser()
    target = llvm.Target.from_triple(llvm.get_process_triple())
    LLVM.optimizer = llvm.create_module_pass_manager()
    LLVM.target_machine = target.create_target_machine(opt=2) # this opt actually can change things. ex: opt=3 means no FMA, opt=2 means FMA
    LLVM.target_machine.add_analysis_passes(LLVM.optimizer)

    llvm.set_option('', '-force-vector-interleave=4') # this makes sum the same speed as torch, it also doubles the (slow) conv speed
    if DEBUG >= 4:
      llvm.set_option('', '--debug-only=loop-vectorize')
    #llvm.set_option('', '--debug')

    # does this do anything?
    builder = llvm.create_pass_manager_builder()
    builder.opt_level = 3
    builder.size_level = 0
    builder.loop_vectorize = True
    builder.slp_vectorize = True
    builder.populate(LLVM.optimizer)

    LLVM.target_machine.set_asm_verbosity(True)
    backing_mod = llvm.parse_assembly("")
    backing_mod.triple = llvm.get_process_triple()
    LLVM.engine = llvm.create_mcjit_compiler(backing_mod, LLVM.target_machine)

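  # exec compiles one ir.Module, runs its 'exec' function once over the given
  # buffers (timing it and updating GlobalCounters), then unloads the module and
  # returns the ctypes function pointer so exec_ast can cache and reuse it.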
  def exec(self, module, bufs, op_estimate=0, mem_estimate=0):
    module.triple = llvm.get_process_triple()
    module.data_layout = self.engine.target_data
    llvm_ir = str(module)

    if DEBUG >= 2:
      print(llvm_ir)

    mod = llvm.parse_assembly(llvm_ir)
    mod.verify()
    LLVM.optimizer.run(mod)
    if DEBUG >= 4:
      print("Optimized IR:")
      print(str(mod))
    mod.name = hashlib.sha1(llvm_ir.encode('utf-8')).hexdigest()
    if DEBUG >= 3:
      print(LLVM.target_machine.emit_assembly(mod))
    LLVM.engine.add_module(mod)
    LLVM.engine.finalize_object()

    # call function (NOTE: if the types don't match, there's likely something wrong with the cache)
    #cfunc = CFUNCTYPE(ctypes.c_int, *[type(x._buf) for x in bufs])(LLVM.engine.get_function_address('exec'))

    # why is this needed without the types. fixed tests below
    # LLVM=1 OPT=2 python3 test/test_ops.py TestOps.test_cat TestOps.test_multicat
    cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.POINTER(ctypes.c_float) for x in bufs])(LLVM.engine.get_function_address('exec'))

    st = time.monotonic()
    cfunc(*[x._buf for x in bufs])
    et = time.monotonic() - st
    if DEBUG >= 1:
      print(f"**LLVM** time {et*1000:7.2f} ms OPs {op_estimate/1e6:7.2f}M -- {(op_estimate/1e9)/et:5.2f} GFLOPS -- {mem_estimate:10d} reads -- {(mem_estimate*4/1e9)/et:5.2f} GB/s")
    GlobalCounters.global_ops += op_estimate
    GlobalCounters.global_mem += mem_estimate

    # we are done
    LLVM.engine.remove_module(mod)
    return cfunc

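# LLVMBuffer is the CPU/LLVM backend buffer: op_lookup maps each elementwise op to
# the llvmlite builder call that emits it (e.g. GT0 becomes an fcmp '>' 0 feeding a
# select of 1.0/0.0), and start_for_op gives the identity value a reduce starts from.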
# TODO: Refactor LLVMBuffer and GPUBuffer into ShapeTrackedBuffer
class LLVMBuffer(ExplicitExecAST):
  op_lookup = {
    UnaryOps.NOOP: lambda builder,x: x,
    UnaryOps.NEG: lambda builder,x: builder.fneg(x, flags=('fast',)),
    UnaryOps.RELU: lambda builder,x: builder.select(builder.fcmp_ordered("<=", ir.Constant(ir.FloatType(), 0), x, flags=('fast',)), x, ir.Constant(ir.FloatType(), 0)),
    UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)),
    UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x], fastmath=('fast',)),
    UnaryOps.GT0: lambda builder,x: builder.select(builder.fcmp_ordered(">", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), ir.Constant(ir.FloatType(), 1), ir.Constant(ir.FloatType(), 0)),
    UnaryOps.RECIPROCAL: lambda builder,x: builder.fdiv(ir.Constant(ir.FloatType(), 1), x, flags=('fast',)),
    BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)),
    BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)),
    BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y, flags=('fast',)),
    BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y, flags=('fast',)),
    BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y], fastmath=('fast',)),
    BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y, flags=('fast',)), ir.FloatType())
  }
  start_for_op = {
    ReduceOps.SUM: ir.Constant(ir.FloatType(), 0),
    ReduceOps.MAX: ir.Constant(ir.FloatType(), -math.inf)
  }

  def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf=None, force_create=False):
    super().__init__(shape, hostbuf)
    # TODO: force alignment?
    self._buf = (ctypes.c_float * (prod(self.shape)))() if hostbuf is None else hostbuf._buf
    #assert ctypes.addressof(self._buf) & 0x1F == 0

  def __repr__(self): return f"LLVMBuffer {str(self.st)}"

  @staticmethod
  def fromCPU(x):
    x = x.astype(np.float32)
    ret = LLVMBuffer(x.shape)
    ctypes.memmove(ret._buf, x.ctypes.data, prod(ret.shape)*4)
    return ret

  def toCPU(x): return np.ctypeslib.as_array(x.contiguous()._buf)[:prod(x.shape)].reshape(x.shape).copy()

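  # generated kernels are cached by ASTKernel key: exec_ast only rebuilds and
  # recompiles the IR the first time a given key is seen.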
  func_cache : Dict[str, Any] = {}
  @classmethod
  def exec_ast(cls, ast:LazyOp) -> LLVMBuffer:
    k = ASTKernel(ast)

    # cached kernel
    if k.key in LLVMBuffer.func_cache:
      LLVMBuffer.func_cache[k.key](*[x._buf for x in k.bufs])
      return k.ret

    if DEBUG >= 2:
      print(k.ast)
      print("old:", [x.shape for x in k.sts])
      print("old:", [x.views[-1].strides for x in k.sts])

    # this stuff can't be hand coded
    kernel_output_axis : List[int] = []
    """
    CACHE_DIM = 32
    if len(k.shapes[0]) == 2:
      # cache tiling, makes permute fast
      k.reshape_and_permute(
        lambda shape: (shape[0]//CACHE_DIM, CACHE_DIM, shape[1]//CACHE_DIM, CACHE_DIM),
        (0,2,1,3))
    elif len(k.shapes[0]) == 3:
      if k.reduceop:
        if k.strides[1][-1] == 1 and k.strides[2][-1] == 1:
          DY, DX = 8, 8
        elif k.strides[1][-1] in [1,0] and k.strides[1][-2] in [1,0]:
          DY, DX = 4, 16
        else:
          DY, DX = 16, 4
        # matmul: YyXxK -> YXKyx
        k.reshape_and_permute(
          lambda shape: (shape[0]//DY, DY, shape[1]//DX, DX, shape[2]),
          (0,2,4,1,3))
        kernel_output_axis = [-2, -1]
      else:
        CACHE_L2_DIM = 256
        k.reshape_and_permute(
          lambda shape: (shape[0], shape[1]//CACHE_L2_DIM, CACHE_L2_DIM, shape[2]),
          (1,0,2,3))
        kernel_output_axis = [-1]
    elif len(k.shapes[0]) == 7:
      # conv: split chans and X
      DY, DX = 4, 16
      k.reshape_and_permute(
        lambda shape: (shape[0], shape[1]//DY, DY, shape[2], shape[3]//DX, DX, shape[4], shape[5], shape[6]),
        (0,1,3,4,6,7,8,2,5))
      kernel_output_axis = [-2, -1]

    if DEBUG >= 2:
      print("new:", k.shapes)
      print("new:", k.strides)
    """

    # the 4x4 need to go all the way at the end, even after reduce
    output_shape = k.sts[0].shape
    full_shape_options = [x.shape for x in k.sts if x.shape != output_shape]
    full_shape = output_shape if len(full_shape_options) == 0 else full_shape_options[0]

    full_shape = full_shape if not kernel_output_axis else full_shape[:-len(kernel_output_axis)]
    kernel_output_dim = prod([k.sts[0].shape[a] for a in kernel_output_axis])
    kernel_output_type = ir.FloatType() if kernel_output_dim == 1 else ir.VectorType(ir.FloatType(), kernel_output_dim)

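    # get_idxs expands a single loop index into one index per element of the kernel
    # output vector, offsetting by the buffer's stride along each kernel_output_axis
    # (with no output axes it just returns [idx]).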
    def get_idxs(builder, idx, buf_index):
      idx_offsets = [0]
      for axis in kernel_output_axis:
        new_idx_offsets = []
        for s in range(k.shapes[buf_index][axis]):
          for i in idx_offsets:
            new_idx_offsets.append(i + s * k.strides[buf_index][axis])
        idx_offsets = new_idx_offsets
      return [builder.add(idx, int_const(i)) for i in idx_offsets]

    # *** llvm specific below this line ***

    # create llvm function
    module = ir.Module(name=__file__)
    func = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.FloatType().as_pointer()]*(len(k.bufs))), name='exec')

    # force llvmlite to allow us to add function attribute then add the attribute
    func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
    func.attributes.add('"no-nans-fp-math"="true"')

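    # one loop per axis of full_shape: entry blocks loop_0..loop_{n-1} nest downward
    # and the matching loopexit blocks unwind back out to 'exit'; loop_exit is
    # reversed so loop_exit[i+1] pairs with loop_entry[i+1].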
    # construct the structure of the loops
    loop_entry = [ir.IRBuilder(func.append_basic_block(name="entry"))]
    loop_exit = []
    for i,_ in enumerate(full_shape): loop_entry.append(ir.IRBuilder(func.append_basic_block(name=f"loop_{i}")))
    for i,_ in enumerate(full_shape): loop_exit.append(ir.IRBuilder(func.append_basic_block(name=f"loopexit_{len(full_shape)-1-i}")))
    loop_exit.append(ir.IRBuilder(func.append_basic_block(name="exit")))
    loop_exit = loop_exit[::-1]

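    # each buffer keeps a flat index per loop level as a phi node: it inherits the
    # outer level's value on entry and is bumped by that buffer's stride for the
    # axis on the way out, so the inner loop needs no index multiplies.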
    # add the buffer indexing
    idx_level = [[int_const(st.offset)] for st in k.sts]
    for i in range(len(full_shape)):
      for j in range(len(k.bufs)):
        # stride
        si = loop_entry[i+1].phi(ir.IntType(64), name=f"idx_{j}_{i}")
        si.add_incoming(idx_level[j][-1], loop_entry[i]._block)
        si_ps = loop_exit[i+1].add(si, int_const(k.sts[j].views[-1].strides[i]))
        si.add_incoming(si_ps, loop_exit[i+1]._block)
        idx_level[j].append(si)

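    # ast_parse walks the LazyOp tree: non-LazyOp leaves are buffer loads (through
    # idx_deref on the slow multi-view path), a ReduceOps node is replaced by the
    # already computed reduce_result, and everything else is emitted via op_lookup,
    # element by element when the kernel output is a vector.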
    # the ast parser
    def ast_parse(builder, x, level, reduce_result=None):
      if not isinstance(x, LazyOp):
        m = kernel_output_type(ir.Undefined)
        buf_index = k.bufs.index(x)
        for i, idx in enumerate(get_idxs(builder, idx_level[buf_index][level], buf_index)):
          if len(x.st.views) > 1:
            if DEBUG >= 1: print(f"WARNING: {x} has buffers with more than 1 view, can't optimize")
            element = idx_deref(builder, x, func.args[buf_index], idx)
          else:
            element = builder.load(builder.gep(func.args[buf_index], [idx], inbounds=True))
          m = element if kernel_output_dim == 1 else builder.insert_element(m, element, int_const(i))
        return m
      if isinstance(x.op, ReduceOps):
        if reduce_result is None:
          raise Exception("no reduce")
        return reduce_result
      values = [ast_parse(builder, v, level, reduce_result) for v in x.src]

      m = kernel_output_type(ir.Undefined)
      if kernel_output_dim == 1:
        return LLVMBuffer.op_lookup[x.op](builder, *values)
      else:
        # TODO: this only has to be done for certain ops
        for i in range(kernel_output_dim):
          value = [builder.extract_element(v, int_const(i)) for v in values]
          element = LLVMBuffer.op_lookup[x.op](builder, *value)
          m = builder.insert_element(m, element, int_const(i))
        return m

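    # store_loop is the first loop level whose output dimension is 1 (or the
    # innermost level, -1, if there is none): reduce phis accumulate in the loops
    # inside it and the final store is emitted in its exit block.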
    # add the ast + final store
    store_loop = output_shape.index(1) if 1 in output_shape else -1

    # do the early ast
    reduce_result = None
    if k.reduceop:
      reduce_input = ast_parse(loop_exit[-1], k.reduceop.src[0], -1)
      phis = [LLVMBuffer.start_for_op[k.reduceop.op]] # type: ignore
      if kernel_output_dim > 1:
        phis = [kernel_output_type(phis * kernel_output_dim)]
      for i in range(store_loop+1, len(loop_entry)):
        val = loop_entry[i].phi(kernel_output_type, f"reduce_phi_{i}")
        val.add_incoming(phis[-1], loop_entry[i-1]._block)
        phis.append(val)

      if k.reduceop.op == ReduceOps.SUM:
        reduce_result = loop_exit[-1].fadd(reduce_input, val, flags=('fast',))
      elif k.reduceop.op == ReduceOps.MAX:
        reduce_result = loop_exit[i].select(loop_exit[-1].fcmp_unordered(">", val, reduce_input, flags=('fast',)), val, reduce_input, flags=('fast',))

      for i,phi in enumerate(phis[1:]):
        phi.add_incoming(reduce_result, loop_exit[store_loop+1+i]._block)

    # do the late ast
    result = ast_parse(loop_exit[store_loop], k.ast, store_loop, reduce_result=reduce_result)

    # store result
    builder = loop_exit[store_loop]
    for i, idx in enumerate(get_idxs(builder, idx_level[0][store_loop], 0)):
      element = result if kernel_output_dim == 1 else builder.extract_element(result, int_const(i))
      builder.store(element, builder.gep(func.args[0], [idx], inbounds=True))

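    # wire up the loop control flow: each level gets a 64-bit phi counter starting
    # at 0, incremented in its exit block, with a cbranch either back into the loop
    # body or out to the enclosing exit once it reaches the axis size.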
    # add the looping
    for i,s in enumerate(full_shape):
      loop_entry[i].branch(loop_entry[i+1]._block)
      idx = loop_entry[i+1].phi(ir.IntType(64), name=f"loopvar_{i}")
      idx.add_incoming(int_const(0), loop_entry[i]._block)
      idx_p1 = loop_exit[i+1].add(idx, int_const(1))
      idx.add_incoming(idx_p1, loop_exit[i+1]._block)
      loop_exit[i+1].cbranch(loop_exit[i+1].icmp_unsigned("==", idx_p1, int_const(s)), loop_exit[i]._block, loop_entry[i+1]._block)

    loop_entry[-1].branch(loop_exit[-1]._block)
    loop_exit[0].ret_void()
    LLVMBuffer.func_cache[k.key] = LLVM().exec(module, k.bufs, k.info.flops, sum(len(x._buf) for x in k.bufs))
    return k.ret