diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 746e8c7ff2..9fcdccdc00 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -10,22 +10,22 @@ MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype) code_for_op: Final[Dict[Op, Callable]] = { - UnaryOps.NEG: lambda builder, x, var_dtype: builder.neg(x) if dtypes.is_int(var_dtype) else \ - (builder.not_(x) if var_dtype is dtypes.bool else builder.fneg(x, flags=MFLAGS)), - UnaryOps.EXP2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS), - UnaryOps.LOG2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS), - UnaryOps.SIN: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS), - UnaryOps.SQRT: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS), - BinaryOps.ADD: lambda builder, x, y, var_dtype: builder.or_(x, y) if var_dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(var_dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501 - BinaryOps.SUB: lambda builder, x, y, var_dtype: builder.sub(x, y) if dtypes.is_int(var_dtype) else builder.fsub(x, y, flags=MFLAGS), - BinaryOps.MUL: lambda builder, x, y, var_dtype: builder.mul(x, y) if is_bool_or_unsigned(var_dtype) or dtypes.is_int(var_dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501 - BinaryOps.DIV: lambda builder, x, y, var_dtype: builder.udiv(x, y) if is_bool_or_unsigned(var_dtype) else builder.sdiv(x, y) if dtypes.is_int(var_dtype) else builder.fdiv(x, y, flags=MFLAGS), # noqa: E501 - BinaryOps.CMPLT: lambda builder, x, y, var_dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501 - BinaryOps.CMPEQ: lambda builder, x, y, var_dtype: builder.icmp_unsigned("==", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed("==", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered("==", x, y, flags=MFLAGS), # noqa: E501 - BinaryOps.MAX: lambda builder, x, y, var_dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501 - BinaryOps.MOD: lambda builder, x, y, var_dtype: builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y), # noqa: E501 - BinaryOps.XOR: lambda builder, x, y, var_dtype: builder.xor(x, y), - TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(x, y, z)} + UnaryOps.NEG: lambda builder, x, dtype: builder.neg(x) if dtypes.is_int(dtype) else \ + (builder.not_(x) if dtype is dtypes.bool else builder.fneg(x, flags=MFLAGS)), + UnaryOps.EXP2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS), + UnaryOps.LOG2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS), + UnaryOps.SIN: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS), + UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS), + BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501 + BinaryOps.SUB: lambda builder, x, y, dtype: builder.sub(x, y) if dtypes.is_int(dtype) else builder.fsub(x, y, flags=MFLAGS), + BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501 + BinaryOps.DIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y) if dtypes.is_int(dtype) else builder.fdiv(x, y, flags=MFLAGS), # noqa: E501 + BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501 + BinaryOps.CMPEQ: lambda builder, x, y, dtype: builder.icmp_unsigned("==", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("==", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("==", x, y, flags=MFLAGS), # noqa: E501 + BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501 + BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501 + BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y), + TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)} dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64), @@ -105,25 +105,25 @@ def uops_to_llvm_ir(function_name:str, uops:UOpGraph) -> str: elif uop is UOps.ENDLOOP: block, phis = loop_blocks.pop() idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1)) - lvars[vin[0]].add_incoming(idx_p1, bb[-1]._block) - for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block) + lvars[vin[0]].add_incoming(idx_p1, bb[-1].block) + for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block) bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}"))) - bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), block._block, bb[-1]._block) + bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), block.block, bb[-1].block) else: assert dtype is not None, f"None dtype for uop {uop}" if uop == UOps.LOOP: bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}"))) - bb[-2].branch(bb[-1]._block) + bb[-2].branch(bb[-1].block) phis = [] for rp in reduce_phis: incoming = lvars[rp] lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype]) - lvars[rp].add_incoming(incoming, bb[-2]._block) + lvars[rp].add_incoming(incoming, bb[-2].block) phis.append((rp, lvars[rp])) lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}") - lvars[u].add_incoming(lvars[vin[0]], bb[-2]._block) + lvars[u].add_incoming(lvars[vin[0]], bb[-2].block) loop_blocks.append((bb[-1], phis)) elif uop is UOps.DEFINE_ACC: lvars[u] = const(args, dtype)