From 2dd0dd4ae024a4caa4b46f2900922120d1c2847f Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 14 Dec 2023 18:13:22 -0500 Subject: [PATCH] cleanup llvmir (#2770) --- tinygrad/renderer/llvmir.py | 33 ++++++++++++--------------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 271607d9e4..2382fd5083 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -6,8 +6,7 @@ from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf -def is_bool_or_unsigned(dtype: DType): - return dtype == dtypes.bool or dtypes.is_unsigned(dtype) +def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype) code_for_op: Final[Dict[Op, Callable]] = { UnaryOps.NEG: lambda builder, x, var_dtype: builder.xor(x, ir.Constant(ir.IntType(1), 1)) if var_dtype == dtypes.bool else builder.neg(x) if dtypes.is_int(var_dtype) else builder.fneg(x, flags=MFLAGS), # noqa: E501 @@ -24,17 +23,16 @@ code_for_op: Final[Dict[Op, Callable]] = { BinaryOps.CMPLT: lambda builder, x, y, var_dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501 BinaryOps.MAX: lambda builder, x, y, var_dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501 BinaryOps.MOD: lambda builder, x, y, var_dtype: - builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y), + builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y), BinaryOps.XOR: lambda builder, x, y, var_dtype: builder.xor(x, y), TernaryOps.MULACC: lambda builder, x, y, z, var_dtype: builder.fadd(builder.fmul(x, y, flags=MFLAGS), z, flags=MFLAGS), TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(builder.trunc(x, ir.IntType(1)) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=MFLAGS), y, z) # noqa: E501 } -dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), - dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), - dtypes.int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16), - dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)} +dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16), + dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64), + dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.float64:ir.DoubleType() } def cast(bb, val, input_type, output_type, bitcast=False): if input_type == output_type: return val @@ -50,9 +48,7 @@ def cast(bb, val, input_type, output_type, bitcast=False): if output_type == dtypes.bool: return bb[-1].fcmp_unordered('!=', cast(bb, val, input_type, dtypes.float32), ir.Constant(ir.FloatType(), 0)) if dtypes.is_unsigned(input_type) or input_type == dtypes.bool: - if output_type == dtypes.float16: - val = bb[-1].uitofp(val, ir.FloatType()) - return bb[-1].fptrunc(val, ir.HalfType()) + if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].uitofp(val, ir.FloatType()), ir.HalfType()) if dtypes.is_float(output_type): return bb[-1].uitofp(val, dtype_to_llvm_dtype[output_type]) if dtypes.is_int(output_type): if input_type.itemsize > output_type.itemsize: return bb[-1].trunc(val, dtype_to_llvm_dtype[output_type]) @@ -60,9 +56,7 @@ def cast(bb, val, input_type, output_type, bitcast=False): if output_type == dtypes.bool: return bb[-1].icmp_unsigned('!=', val, ir.Constant(val.type, 0)) if dtypes.is_int(input_type): - if output_type == dtypes.float16: - val = bb[-1].sitofp(val, ir.FloatType()) - return bb[-1].fptrunc(val, ir.HalfType()) + if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].sitofp(val, ir.FloatType()), ir.HalfType()) if dtypes.is_float(output_type): return bb[-1].sitofp(val, dtype_to_llvm_dtype[output_type]) if dtypes.is_int(output_type): if input_type.itemsize > output_type.itemsize: return bb[-1].trunc(val, dtype_to_llvm_dtype[output_type]) @@ -72,6 +66,7 @@ def cast(bb, val, input_type, output_type, bitcast=False): raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented") def const(args, dtype): + # TODO: remove int from int(args) once const args conform with dtype return ir.Constant(dtype_to_llvm_dtype[dtype], int(args) if dtypes.is_int(dtype) else bool(args) if dtype == dtypes.bool else args) def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]: @@ -124,15 +119,12 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]: for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block) bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}"))) bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), block._block, bb[-1]._block) - if uop == UOps.DEFINE_GLOBAL: - lvars[u] = func.args[buf_index[args]] + if uop == UOps.DEFINE_GLOBAL: lvars[u] = func.args[buf_index[args]] if uop == UOps.DEFINE_ACC: lvars[u] = const(args, dtype) reduce_phis.append(u) - if uop == UOps.SPECIAL: - lvars[u] = lvars[args.expr] - if uop == UOps.CONST: - lvars[u] = const(args, dtype) + if uop == UOps.SPECIAL: lvars[u] = lvars[args.expr] + if uop == UOps.CONST: lvars[u] = const(args, dtype) if uop == UOps.LOAD: assert dtype is not None if len(vin) > 2: @@ -157,8 +149,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]: if len(vin) > 3: with bb[-1].if_then(bb[-1].trunc(lvars[vin[3]], ir.IntType(1))): store_op() else: store_op() - if uop == UOps.ALU: - lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin] + [dtype if args != BinaryOps.CMPLT else vin[0].dtype]) + if uop == UOps.ALU: lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin] + [dtype if args != BinaryOps.CMPLT else vin[0].dtype]) if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=isinstance(args, tuple) and args[1]) bb[-1].ret_void()