diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py
index 0f96170030..8a71c64033 100644
--- a/test/models/test_real_world.py
+++ b/test/models/test_real_world.py
@@ -49,7 +49,7 @@ class TestRealWorld(unittest.TestCase):
     def test(t, t2): return model(t, 801, t2).realize()
     helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 18.0, 967)
 
-  @unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT")
+  @unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["LLVM"], "needs JIT, too long on CI LLVM")
   def test_llama(self):
     old_type = Tensor.default_type
     Tensor.default_type = dtypes.float16
@@ -63,7 +63,7 @@ class TestRealWorld(unittest.TestCase):
 
     Tensor.default_type = old_type
 
-  @unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT")
+  @unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["LLVM"], "needs JIT, too long on CI LLVM")
   def test_train_cifar(self):
     # TODO: with default device
     #old_default = Device.DEFAULT
diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py
index 5543845616..bea43d3365 100644
--- a/test/test_symbolic_jit.py
+++ b/test/test_symbolic_jit.py
@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, Device
 import numpy as np
 
 @unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
-@unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG", "CUDA"], f"{Device.DEFAULT} is not supported")
+@unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG", "CUDA", "LLVM"], f"{Device.DEFAULT} is not supported")
 class TestSymbolicJit(unittest.TestCase):
   def test_plus1(self):
     def f(a): return (a+1).realize()
diff --git a/tinygrad/jit.py b/tinygrad/jit.py
index 7f2a07109c..00888d8fb4 100644
--- a/tinygrad/jit.py
+++ b/tinygrad/jit.py
@@ -8,7 +8,7 @@ from tinygrad.ops import RawBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable
 
-JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU"]
+JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU", "LLVM"]
 
 class TinyJit:
   def __init__(self, fxn:Callable):
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index ab3737dc32..bcc0d07c66 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -6,13 +6,13 @@ from tinygrad.helpers import dtypes
 from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
 from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
 
-def int_const(x): return ir.Constant(ir.IntType(64), x)
+def sym_render(a, ops=None, ctx=None): return ir.Constant(ir.IntType(64), a) if isinstance(a, int) else a.render(ops, ctx)
 render_llvm = {
-  NumNode: lambda self,ops,ctx: int_const(self.b),
-  MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), int_const(self.b)),
-  DivNode: lambda self,ops,ctx: ctx.sdiv(self.a.render(ops,ctx), int_const(self.b)),
-  ModNode: lambda self,ops,ctx: ctx.srem(self.a.render(ops,ctx), int_const(self.b)),
-  LtNode: lambda self,ops,ctx: ctx.icmp_signed("<", self.a.render(ops,ctx), int_const(self.b)),
+  NumNode: lambda self,ops,ctx: sym_render(self.b,ops,ctx),
+  MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
+  DivNode: lambda self,ops,ctx: ctx.sdiv(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
+  ModNode: lambda self,ops,ctx: ctx.srem(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
+  LtNode: lambda self,ops,ctx: ctx.icmp_signed("<", self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
   SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.add(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
   AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.and_(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx))
 }
@@ -33,7 +33,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
   TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), y, z, flags=('fast',)),
 }
 
-dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
+dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32)}
 
 def cast(bb, val, input_type, output_type):
   if input_type == output_type: return val
@@ -75,9 +75,10 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
   buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
 
   # create llvm function
-  func_dtypes = [dtype_to_llvm_dtype[dtype] for dtype in buf_to_dtype.values()]
-  func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name=function_name)
-  for a in func.args: a.add_attribute("noalias")
+  func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values()]
+  func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if dt!=dtypes._arg_int32 else x for x,dt in func_dtypes]), name=function_name)
+  for a in func.args:
+    if a.type.is_pointer: a.add_attribute("noalias")
 
   # force llvmlite to allow us to add function attribute then add the attribute
   func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
@@ -90,6 +91,9 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
   lvars: Dict[Optional[Token], Any] = {} # this Any is an llvm type
   render_llvm[Variable] = lambda self,ops,ctx: lvars[self.expr]
 
+  for bufname,dtype in buf_to_dtype.items():
+    if dtype == dtypes._arg_int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(64))
+
   for uop,newvar,vin,args in uops:
     if uop == UOps.LOOP:
       for var in args[0]:
@@ -106,16 +110,16 @@
         loop_blocks.append((bb[-1], phis))
 
         lvars[var.expr] = bb[-1].phi(ir.IntType(64), name=var.expr)
-        lvars[var.expr].add_incoming(int_const(var.min), bb[-2]._block)
+        lvars[var.expr].add_incoming(sym_render(var.min), bb[-2]._block)
     if uop == UOps.ENDLOOP:
       for var in args[0][::-1]:
         if isinstance(var, NumNode): continue
        block, phis = loop_blocks.pop()
-        idx_p1 = bb[-1].add(lvars[var.expr], int_const(1))
+        idx_p1 = bb[-1].add(lvars[var.expr], sym_render(1))
         lvars[var.expr].add_incoming(idx_p1, bb[-1]._block)
         for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block)
         bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{var.expr}")))
-        bb[-2].cbranch(bb[-2].icmp_unsigned("==", idx_p1, int_const(var.max+1)), bb[-1]._block, block._block)
+        bb[-2].cbranch(bb[-2].icmp_unsigned(">", idx_p1, sym_render(var.max, render_llvm, bb[-2])), bb[-1]._block, block._block)
     if uop == UOps.LOAD:
       assert newvar is not None and isinstance(args, (MemOp, ConstOp))
       valid = args.valid.render(render_llvm, bb[-1])
@@ -130,7 +134,7 @@
       else:
        idx = args.idx.render(render_llvm, bb[-1])
        if args.valid.min == 0:
-          aug_idx = bb[-1].select(valid, idx, int_const(0))
+          aug_idx = bb[-1].select(valid, idx, sym_render(0))
          val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [aug_idx], inbounds=True)), ir.Constant(dtype_to_llvm_dtype[args.memory_dtype], args.invalid_value))
        else:
          val = bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True))
diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py
index 00d404a799..7a2b140829 100644
--- a/tinygrad/runtime/ops_llvm.py
+++ b/tinygrad/runtime/ops_llvm.py
@@ -61,7 +61,7 @@ class LLVMProgram:
   def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False):
     cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn)
     if wait: st = time.monotonic()
-    cfunc(*[x._buf for x in bufs])
+    cfunc(*[x._buf if not isinstance(x, int) else x for x in bufs])
     if wait: return time.monotonic()-st
 
 LLVMBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), uops_to_llvm_ir, LLVMProgram)
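
The renamed helper is the crux of the renderer change: int_const could only turn a Python int into an i64 constant, while sym_render keeps that path but lets anything else (for example a Variable whose value is only known at kernel-call time) fall through to its own render method. A minimal stand-alone sketch of the int path, using only llvmlite and copying the helper from the patch:

import llvmlite.ir as ir

# mirror of the sym_render helper added in tinygrad/renderer/llvmir.py:
# plain ints still become 64-bit constants, symbolic nodes defer to their own renderer
def sym_render(a, ops=None, ctx=None):
  return ir.Constant(ir.IntType(64), a) if isinstance(a, int) else a.render(ops, ctx)

print(sym_render(42))  # i64 42, identical to what the old int_const(42) produced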
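
The dtypes._arg_int32 handling changes the kernel calling convention: such arguments are plain i32 values rather than pointers (so they skip the noalias attribute), the prologue sign-extends them to i64 under the buffer's name so index math can use them, and on the runtime side a bare Python int is passed through ctypes instead of x._buf. A hedged llvmlite-only sketch of that prologue; the function and argument layout below ("kernel", one float buffer plus one i32 size) is an assumption for illustration, not taken from the patch:

import llvmlite.ir as ir

module = ir.Module(name="demo")
# one float* data argument and one plain i32 size argument, as _arg_int32 now allows
fnty = ir.FunctionType(ir.VoidType(), [ir.FloatType().as_pointer(), ir.IntType(32)])
func = ir.Function(module, fnty, name="kernel")
func.args[0].add_attribute("noalias")  # only pointer arguments keep noalias
entry = ir.IRBuilder(func.append_basic_block("entry"))
size_i64 = entry.sext(func.args[1], ir.IntType(64), name="size_i64")  # like bb[-1].sext(...) in the patch
entry.ret_void()
print(module)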
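
The loop-exit comparison flips from "==" against var.max+1 to ">" against var.max for the same reason: once the bound can be a rendered Variable instead of a Python int, max+1 is no longer a constant that can be folded at render time, while idx+1 > max works for both cases and keeps the inclusive min..max range. A small self-contained sketch of that loop shape with a runtime bound; loop_demo and its single i64 argument are invented for the example:

import llvmlite.ir as ir

module = ir.Module(name="demo")
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.IntType(64)]), name="loop_demo")
entry, body, exit_bb = (ir.IRBuilder(func.append_basic_block(n)) for n in ("entry", "body", "exit"))
entry.branch(body._block)

idx = body.phi(ir.IntType(64), name="idx")                  # plays the role of lvars[var.expr]
idx.add_incoming(ir.Constant(ir.IntType(64), 0), entry._block)
idx_p1 = body.add(idx, ir.Constant(ir.IntType(64), 1))
idx.add_incoming(idx_p1, body._block)
# leave the loop once idx+1 exceeds the runtime bound; func.args[0] stands in for var.max
body.cbranch(body.icmp_unsigned(">", idx_p1, func.args[0]), exit_bb._block, body._block)
exit_bb.ret_void()
print(module)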