llvm jitted (#1652)

Author: nimlgen
Date: 2023-08-29 06:22:44 +03:00
Committed by: GitHub
Parent: 1c0449e190
Commit: 8844a0a822
5 changed files with 23 additions and 19 deletions

View File

@@ -49,7 +49,7 @@ class TestRealWorld(unittest.TestCase):
def test(t, t2): return model(t, 801, t2).realize()
helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 18.0, 967)
-@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT")
+@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["LLVM"], "needs JIT, too long on CI LLVM")
def test_llama(self):
old_type = Tensor.default_type
Tensor.default_type = dtypes.float16
@@ -63,7 +63,7 @@ class TestRealWorld(unittest.TestCase):
Tensor.default_type = old_type
-@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT")
+@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["LLVM"], "needs JIT, too long on CI LLVM")
def test_train_cifar(self):
# TODO: with default device
#old_default = Device.DEFAULT
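For reference, the gating pattern these decorators use, as a standalone sketch (the test class and method here are placeholders; JIT_SUPPORTED_DEVICE is assumed to come from tinygrad/jit.py as shown further down in this diff):

import unittest
from tinygrad.tensor import Device
from tinygrad.jit import JIT_SUPPORTED_DEVICE  # assumed import path, see the jit hunk below

class TestSomethingJitted(unittest.TestCase):  # hypothetical test class mirroring the ones above
  # run only on JIT-capable backends, but keep slow cases off the CI LLVM backend
  @unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["LLVM"], "needs JIT, too long on CI LLVM")
  def test_slow_model(self): pass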

View File

@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, Device
import numpy as np
@unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
@unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG", "CUDA"], f"{Device.DEFAULT} is not supported")
@unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG", "CUDA", "LLVM"], f"{Device.DEFAULT} is not supported")
class TestSymbolicJit(unittest.TestCase):
def test_plus1(self):
def f(a): return (a+1).realize()
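The "symbolic" these tests exercise is index and shape math carried as tinygrad symbolic Nodes rather than fixed ints; a tiny sketch of the expressions involved (tinygrad.shape.symbolic, as imported in the renderer hunk below):

from tinygrad.shape.symbolic import Variable

i = Variable("i", 1, 10)     # a dimension with known bounds instead of a fixed size
expr = i * 4 + 2             # index math stays symbolic until a backend renders it
print(expr)                  # the symbolic expression tree
print(expr.min, expr.max)    # bounds propagate: 6 42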

View File

@@ -8,7 +8,7 @@ from tinygrad.ops import RawBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable
-JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU"]
+JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU", "LLVM"]
class TinyJit:
def __init__(self, fxn:Callable):
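Since adding "LLVM" here is what turns the JIT on for that backend, a minimal usage sketch (assuming the TinyJit decorator API from this file, importable as tinygrad.jit; pick the backend with e.g. LLVM=1):

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit  # assumed import path for this file

@TinyJit
def mul(a, b): return (a * b).realize()  # a jitted function must realize its output

# the first calls run normally and record the launched kernels; later calls replay them
for _ in range(4):
  out = mul(Tensor.randn(4, 4), Tensor.randn(4, 4))
print(out.numpy())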

View File

@@ -6,13 +6,13 @@ from tinygrad.helpers import dtypes
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
def int_const(x): return ir.Constant(ir.IntType(64), x)
+def sym_render(a, ops=None, ctx=None): return ir.Constant(ir.IntType(64), a) if isinstance(a, int) else a.render(ops, ctx)
render_llvm = {
-NumNode: lambda self,ops,ctx: int_const(self.b),
-MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), int_const(self.b)),
-DivNode: lambda self,ops,ctx: ctx.sdiv(self.a.render(ops,ctx), int_const(self.b)),
-ModNode: lambda self,ops,ctx: ctx.srem(self.a.render(ops,ctx), int_const(self.b)),
-LtNode: lambda self,ops,ctx: ctx.icmp_signed("<", self.a.render(ops,ctx), int_const(self.b)),
+NumNode: lambda self,ops,ctx: sym_render(self.b,ops,ctx),
+MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
+DivNode: lambda self,ops,ctx: ctx.sdiv(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
+ModNode: lambda self,ops,ctx: ctx.srem(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
+LtNode: lambda self,ops,ctx: ctx.icmp_signed("<", self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.add(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.and_(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx))
}
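To see what sym_render buys over int_const in isolation, a standalone sketch (it restates the two helpers above so it runs without the rest of this module):

from llvmlite import ir

def int_const(x): return ir.Constant(ir.IntType(64), x)
# same idea as above: plain ints become i64 constants, symbolic Nodes defer to their own render()
def sym_render(a, ops=None, ctx=None): return int_const(a) if isinstance(a, int) else a.render(ops, ctx)

module = ir.Module(name="sym_render_demo")
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.IntType(64)]), name="demo")
bb = ir.IRBuilder(func.append_basic_block("entry"))
print(sym_render(4))                        # constant path, identical to int_const(4)
print(bb.mul(func.args[0], sym_render(4)))  # the MulNode case: emits a mul whose operand may be a rendered Node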
@@ -33,7 +33,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), y, z, flags=('fast',)),
}
-dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
+dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32)}
def cast(bb, val, input_type, output_type):
if input_type == output_type: return val
@@ -75,9 +75,10 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
# create llvm function
-func_dtypes = [dtype_to_llvm_dtype[dtype] for dtype in buf_to_dtype.values()]
-func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name=function_name)
-for a in func.args: a.add_attribute("noalias")
+func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values()]
+func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if dt!=dtypes._arg_int32 else x for x,dt in func_dtypes]), name=function_name)
+for a in func.args:
+if a.type.is_pointer: a.add_attribute("noalias")
# force llvmlite to allow us to add function attribute then add the attribute
func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
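A standalone sketch of the new signature handling: buffer parameters stay pointers and keep noalias, while an _arg_int32-style parameter is passed by value (names here are illustrative):

from llvmlite import ir

module = ir.Module(name="sig_demo")
# one float buffer (pointer) plus one by-value i32 argument, as built above
fnty = ir.FunctionType(ir.VoidType(), [ir.FloatType().as_pointer(), ir.IntType(32)])
func = ir.Function(module, fnty, name="kernel")
for a in func.args:
  if a.type.is_pointer: a.add_attribute("noalias")
print(module)  # prints roughly: declare void @"kernel"(float* noalias %".1", i32 %".2")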
@@ -90,6 +91,9 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
lvars: Dict[Optional[Token], Any] = {} # this Any is an llvm type
render_llvm[Variable] = lambda self,ops,ctx: lvars[self.expr]
for bufname,dtype in buf_to_dtype.items():
+if dtype == dtypes._arg_int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(64))
for uop,newvar,vin,args in uops:
if uop == UOps.LOOP:
for var in args[0]:
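The added sext line widens the by-value i32 argument to the i64 width used for all indexing; the same call in isolation:

from llvmlite import ir

module = ir.Module(name="sext_demo")
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.IntType(32)]), name="widen")
bb = ir.IRBuilder(func.append_basic_block("entry"))
widened = bb.sext(func.args[0], ir.IntType(64))  # matches bb[-1].sext(...) above
bb.ret_void()
print(module)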
@@ -106,16 +110,16 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
loop_blocks.append((bb[-1], phis))
lvars[var.expr] = bb[-1].phi(ir.IntType(64), name=var.expr)
-lvars[var.expr].add_incoming(int_const(var.min), bb[-2]._block)
+lvars[var.expr].add_incoming(sym_render(var.min), bb[-2]._block)
if uop == UOps.ENDLOOP:
for var in args[0][::-1]:
if isinstance(var, NumNode): continue
block, phis = loop_blocks.pop()
-idx_p1 = bb[-1].add(lvars[var.expr], int_const(1))
+idx_p1 = bb[-1].add(lvars[var.expr], sym_render(1))
lvars[var.expr].add_incoming(idx_p1, bb[-1]._block)
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block)
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{var.expr}")))
-bb[-2].cbranch(bb[-2].icmp_unsigned("==", idx_p1, int_const(var.max+1)), bb[-1]._block, block._block)
+bb[-2].cbranch(bb[-2].icmp_unsigned(">", idx_p1, sym_render(var.max, render_llvm, bb[-2])), bb[-1]._block, block._block)
if uop == UOps.LOAD:
assert newvar is not None and isinstance(args, (MemOp, ConstOp))
valid = args.valid.render(render_llvm, bb[-1])
@@ -130,7 +134,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
else:
idx = args.idx.render(render_llvm, bb[-1])
if args.valid.min == 0:
-aug_idx = bb[-1].select(valid, idx, int_const(0))
+aug_idx = bb[-1].select(valid, idx, sym_render(0))
val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [aug_idx], inbounds=True)), ir.Constant(dtype_to_llvm_dtype[args.memory_dtype], args.invalid_value))
else:
val = bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True))
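The surrounding LOAD case is a masked load: if the access may be invalid, the index is clamped and the loaded value is selected against the invalid_value constant. A standalone sketch of that pattern (0.0 stands in for args.invalid_value):

from llvmlite import ir

module = ir.Module(name="masked_load_demo")
f32, i64 = ir.FloatType(), ir.IntType(64)
func = ir.Function(module, ir.FunctionType(f32, [f32.as_pointer(), i64, ir.IntType(1)]), name="masked_load")
buf, idx, valid = func.args
bb = ir.IRBuilder(func.append_basic_block("entry"))
aug_idx = bb.select(valid, idx, ir.Constant(i64, 0))     # clamp the index, like sym_render(0) above
loaded = bb.load(bb.gep(buf, [aug_idx], inbounds=True))  # safe to load after clamping
bb.ret(bb.select(valid, loaded, ir.Constant(f32, 0.0)))
print(module)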

View File

@@ -61,7 +61,7 @@ class LLVMProgram:
def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False):
cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn)
if wait: st = time.monotonic()
-cfunc(*[x._buf for x in bufs])
+cfunc(*[x._buf if not isinstance(x, int) else x for x in bufs])
if wait: return time.monotonic()-st
LLVMBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), uops_to_llvm_ir, LLVMProgram)
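The runtime change lets plain ints ride along with the buffer pointers in the CFUNCTYPE call. A standalone ctypes sketch of that calling pattern, using libc memset purely as a stand-in for the jitted function (POSIX only):

import ctypes, ctypes.util

libc = ctypes.CDLL(ctypes.util.find_library("c"))
buf = (ctypes.c_uint8 * 4)()
# one pointer argument plus by-value int arguments in a single call, mirroring
# cfunc(*[x._buf if not isinstance(x, int) else x for x in bufs]) above
memset = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t)(("memset", libc))
memset(ctypes.addressof(buf), 0x7f, 4)
print(list(buf))  # [127, 127, 127, 127]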