cleanup llvmir (#2770)

This commit is contained in:
chenyu
2023-12-14 18:13:22 -05:00
committed by GitHub
parent 66d9eb10b6
commit 2dd0dd4ae0

View File

@@ -6,8 +6,7 @@ from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # all LLVM fast-math flags except nnan and ninf
def is_bool_or_unsigned(dtype: DType):
  """Return whether `dtype` is `dtypes.bool` or is reported unsigned by `dtypes.is_unsigned`.

  The op table uses this to choose the unsigned LLVM instruction variants
  (icmp_unsigned / urem) over the signed ones.
  """
  return dtype == dtypes.bool or dtypes.is_unsigned(dtype)
code_for_op: Final[Dict[Op, Callable]] = {
UnaryOps.NEG: lambda builder, x, var_dtype: builder.xor(x, ir.Constant(ir.IntType(1), 1)) if var_dtype == dtypes.bool else builder.neg(x) if dtypes.is_int(var_dtype) else builder.fneg(x, flags=MFLAGS), # noqa: E501
@@ -24,17 +23,16 @@ code_for_op: Final[Dict[Op, Callable]] = {
BinaryOps.CMPLT: lambda builder, x, y, var_dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
BinaryOps.MAX: lambda builder, x, y, var_dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
BinaryOps.MOD: lambda builder, x, y, var_dtype:
builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y),
builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y),
BinaryOps.XOR: lambda builder, x, y, var_dtype: builder.xor(x, y),
TernaryOps.MULACC: lambda builder, x, y, z, var_dtype: builder.fadd(builder.fmul(x, y, flags=MFLAGS), z, flags=MFLAGS),
TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(builder.trunc(x, ir.IntType(1)) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=MFLAGS), y, z) # noqa: E501
}
# Maps each dtype to the `ir` type object used to represent it in the generated LLVM IR.
# Signed and unsigned integers of equal width share the same IntType; bfloat16 is mapped
# to a 16-bit integer type rather than to a floating-point type.
dtype_to_llvm_dtype = {
  dtypes.bool: ir.IntType(1),
  dtypes.int8: ir.IntType(8), dtypes.uint8: ir.IntType(8),
  dtypes.int16: ir.IntType(16), dtypes.uint16: ir.IntType(16),
  dtypes.int32: ir.IntType(32), dtypes.uint32: ir.IntType(32),
  dtypes.int64: ir.IntType(64), dtypes.uint64: ir.IntType(64),
  dtypes.float16: ir.HalfType(), dtypes.bfloat16: ir.IntType(16),
  dtypes.float32: ir.FloatType(), dtypes.float64: ir.DoubleType(),
}
def cast(bb, val, input_type, output_type, bitcast=False):
if input_type == output_type: return val
@@ -50,9 +48,7 @@ def cast(bb, val, input_type, output_type, bitcast=False):
if output_type == dtypes.bool: return bb[-1].fcmp_unordered('!=', cast(bb, val, input_type, dtypes.float32), ir.Constant(ir.FloatType(), 0))
if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
if output_type == dtypes.float16:
val = bb[-1].uitofp(val, ir.FloatType())
return bb[-1].fptrunc(val, ir.HalfType())
if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].uitofp(val, ir.FloatType()), ir.HalfType())
if dtypes.is_float(output_type): return bb[-1].uitofp(val, dtype_to_llvm_dtype[output_type])
if dtypes.is_int(output_type):
if input_type.itemsize > output_type.itemsize: return bb[-1].trunc(val, dtype_to_llvm_dtype[output_type])
@@ -60,9 +56,7 @@ def cast(bb, val, input_type, output_type, bitcast=False):
if output_type == dtypes.bool: return bb[-1].icmp_unsigned('!=', val, ir.Constant(val.type, 0))
if dtypes.is_int(input_type):
if output_type == dtypes.float16:
val = bb[-1].sitofp(val, ir.FloatType())
return bb[-1].fptrunc(val, ir.HalfType())
if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].sitofp(val, ir.FloatType()), ir.HalfType())
if dtypes.is_float(output_type): return bb[-1].sitofp(val, dtype_to_llvm_dtype[output_type])
if dtypes.is_int(output_type):
if input_type.itemsize > output_type.itemsize: return bb[-1].trunc(val, dtype_to_llvm_dtype[output_type])
@@ -72,6 +66,7 @@ def cast(bb, val, input_type, output_type, bitcast=False):
raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
def const(args, dtype):
  """Build an `ir.Constant` of the LLVM type that `dtype_to_llvm_dtype` maps `dtype` to.

  The value is coerced with int() when `dtypes.is_int(dtype)` holds, with bool() when
  `dtype` equals `dtypes.bool`, and is passed through unchanged otherwise.
  """
  # Look the type up first so an unmapped dtype fails before any value coercion.
  llvm_type = dtype_to_llvm_dtype[dtype]
  # TODO: remove int from int(args) once const args conform with dtype
  if dtypes.is_int(dtype):
    value = int(args)
  elif dtype == dtypes.bool:
    value = bool(args)
  else:
    value = args
  return ir.Constant(llvm_type, value)
def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
@@ -124,15 +119,12 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block)
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), block._block, bb[-1]._block)
if uop == UOps.DEFINE_GLOBAL:
lvars[u] = func.args[buf_index[args]]
if uop == UOps.DEFINE_GLOBAL: lvars[u] = func.args[buf_index[args]]
if uop == UOps.DEFINE_ACC:
lvars[u] = const(args, dtype)
reduce_phis.append(u)
if uop == UOps.SPECIAL:
lvars[u] = lvars[args.expr]
if uop == UOps.CONST:
lvars[u] = const(args, dtype)
if uop == UOps.SPECIAL: lvars[u] = lvars[args.expr]
if uop == UOps.CONST: lvars[u] = const(args, dtype)
if uop == UOps.LOAD:
assert dtype is not None
if len(vin) > 2:
@@ -157,8 +149,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
if len(vin) > 3:
with bb[-1].if_then(bb[-1].trunc(lvars[vin[3]], ir.IntType(1))): store_op()
else: store_op()
if uop == UOps.ALU:
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin] + [dtype if args != BinaryOps.CMPLT else vin[0].dtype])
if uop == UOps.ALU: lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin] + [dtype if args != BinaryOps.CMPLT else vin[0].dtype])
if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=isinstance(args, tuple) and args[1])
bb[-1].ret_void()