From b6aaf12df7015c7e7d1e292c969e9ea316293224 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Fri, 10 Nov 2023 13:42:39 -0500
Subject: [PATCH] Internal cast 2 with more tests (#2257)

* Change linearizer to parse CAST

* Oneliner renders for cstyle and triton

* LLVM cast and ALU implementation

* pylint fixes

* cast in gep

* remove printbufs

* use cast for post-load ops

* get rid of parse_cast

* partially supported vectorized dtypes for initial dev

* render phi as the dtype

* Revert "partially supported vectorized dtypes for initial dev"

This reverts commit 1bf1a818a3350d74314806f00f5aaacb075bdf51.

* Revert "render phi as the dtype"

This reverts commit d08cb270b42266f06e4a78b199f9937cb9dc4711.

* reenable triton tests

* no vstore_half if dtype is already half

* upcast max
---
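
[Note, illustrative and not part of the commit: the llvmir.py change below
replaces the float32-centric cast logic with a three-way dispatch on the input
kind. A minimal self-contained Python sketch of that decision tree; dtypes are
modeled as hypothetical (kind, itemsize) pairs and the return value names the
LLVM instruction the patched cast() would emit. float16 additionally routes
through float32, as the diff shows.]

def llvm_cast_op(src, dst):
  skind, ssize = src
  dkind, dsize = dst
  if src == dst: return "noop"
  if skind == "float":
    if dkind == "float": return "fpext" if dsize > ssize else "fptrunc"
    if dkind in ("int", "uint"): return "fptoui" if dkind == "uint" else "fptosi"
    if dkind == "bool": return "fcmp une"   # compare against 0.0
  if skind in ("uint", "bool"):
    if dkind == "float": return "uitofp"    # half: uitofp to float32, then fptrunc
    if dkind in ("int", "uint"): return "trunc" if ssize > dsize else "zext"
    if dkind == "bool": return "icmp ne"    # unsigned compare against 0
  if skind == "int":
    if dkind == "float": return "sitofp"    # half: sitofp to float32, then fptrunc
    if dkind in ("int", "uint"): return "trunc" if ssize > dsize else "sext"
    if dkind == "bool": return "icmp ne"    # signed compare against 0
  raise NotImplementedError(f"cast {src} -> {dst}")

assert llvm_cast_op(("float", 4), ("float", 8)) == "fpext"
assert llvm_cast_op(("uint", 1), ("int", 4)) == "zext"
assert llvm_cast_op(("int", 8), ("int", 4)) == "trunc"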
 test/test_dtype.py             |  2 --
 tinygrad/codegen/linearizer.py |  6 ++--
 tinygrad/renderer/cstyle.py    |  6 ++--
 tinygrad/renderer/llvmir.py    | 58 ++++++++++++++++++----------------
 tinygrad/renderer/triton.py    |  1 +
 5 files changed, 38 insertions(+), 35 deletions(-)

diff --git a/test/test_dtype.py b/test/test_dtype.py
index 1a4231fd30..b78e52163b 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -19,8 +19,6 @@ def is_dtype_supported(dtype: DType):
   if dtype == dtypes.bool:
     # host-shareablity is a requirement for storage buffers, but 'bool' type is not host-shareable
     if Device.DEFAULT == "WEBGPU": return False
-  # TODO remove triton from here once internal casting is fixed. CAST of fp32s between 0-1 is broken in triton
-  if getenv("TRITON") == 1: return False
   return True

 def get_available_cast_dtypes(dtype: DType) -> List[DType]: return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")]  # dont cast internal dtypes
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 110f338f2c..4de88e48f6 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -5,7 +5,7 @@ from collections import defaultdict
 from enum import Enum, auto
 from dataclasses import dataclass

-from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, all_same, getenv
+from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, getenv
 from tinygrad.ops import LazyOp, UnaryOps, ConstBuffer, MemBuffer, BufferOps
 from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -422,7 +422,6 @@ class Linearizer(Kernel):
   def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=True) -> UOp:
     key = (uop, dtype, vin, arg)
     if uop == UOps.PHI and len(vin) == 2 and vin[0] == vin[1]: return vin[0]   # self phi is noop
-    if uop == UOps.CAST and all(x.uop == UOps.GEP for x in vin) and all_same([x.vin[0] for x in vin]) and all(x.arg == i for i,x in enumerate(vin)): return vin[0].vin[0]
     if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype)
     if uop == UOps.ALU:
       # rewrites. NOTE: the rewritten NEG op is still around...
@@ -445,7 +444,8 @@ class Linearizer(Kernel):
   def ast_parse(self, x, acc, offs, loaded_buffers, do_reduce=False, loop_ctx=tuple()) -> List[UOp]:
     if x.__class__ is not LazyOp: return loaded_buffers[x]    # for LOCAL_BUFFER
     if x.op in BufferOps: return loaded_buffers[x.arg]
-    if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, offs, loaded_buffers)  # cast isn't an ALU op
+    if x.op == UnaryOps.NOOP: return self.ast_parse(x.src[0], acc, offs, loaded_buffers)
+    if x.op == UnaryOps.CAST: return [self.uop(UOps.CAST, x.arg[0], (u,), x.arg) if not isinstance(x.arg[0], ImageDType) else u for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
     if x.op in ReduceOps and not do_reduce:
       assert offs is None, "not available if we aren't doing reduce"
       return acc
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 4a9ef73f44..4fa7290fb0 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -102,7 +102,7 @@ class CStyleLanguage(NamedTuple):
     if isinstance(buf_dtype, ImageDType):
       assert var_dtype == dtypes._float4, "images must be float4"
       return f"write_imagef({buf_name}, {idx}, {var_name});"
-    if self.uses_vload and buf_dtype == dtypes.float16:
+    if self.uses_vload and buf_dtype == dtypes.float16 and var_dtype != dtypes.float16:
       return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
     if var_dtype.sz > 1:
       return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
@@ -163,6 +163,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
       # remove parens if ALU types are the same. TODO: can do more here
       if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL}:
         val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]])
+      elif args == BinaryOps.MAX:
+        val = lang.code_for_op[args](*[lang.render_cast([r[x]], dtype) if x.dtype != dtype else r[x] for x in vin])
       else: val = lang.code_for_op[args](*[r[x] for x in vin])
       assert child_count[u] != 0, f"childless ALU op found {u}"
@@ -191,7 +193,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
     elif uop == UOps.STORE:
       assert vin[0].dtype is not None and vin[2].dtype is not None
       kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL))
-    elif uop == UOps.CAST and dtype is not None and dtype.sz > 1:
+    elif uop == UOps.CAST and dtype is not None:
       val = lang.render_cast([r[x] for x in vin], dtype)
       if child_count[u] <= 1: r[u] = val
       else: kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'cast')} = {val};")
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index de81e831dd..bfe2d75d13 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -19,7 +19,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
   # TODO: this should be casted
   BinaryOps.CMPLT: lambda builder,x,y: builder.zext(builder.icmp_signed("<", x, y),ir.IntType(32)) if isinstance(x.type, ir.IntType) else builder.uitofp(builder.fcmp_ordered("<", x, y, flags=LLVM_FAST_MATH_FLAGS), ir.FloatType()),
   BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=LLVM_FAST_MATH_FLAGS), x, y, flags=LLVM_FAST_MATH_FLAGS),
-  BinaryOps.MOD: lambda builder,x,y: builder.srem(x,y),
+  BinaryOps.MOD: lambda builder,x,y: builder.srem(x,y) if isinstance(x.type, ir.IntType) else builder.frem(x,y) if isinstance(x.type, ir.FloatType) else builder.urem(x,y),
   TernaryOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS), z, flags=LLVM_FAST_MATH_FLAGS),
   TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS) if isinstance(x.type, ir.FloatType) else builder.trunc(x, ir.IntType(1)), y, z, flags=LLVM_FAST_MATH_FLAGS),
 }
@@ -29,33 +29,34 @@ dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfTyp

 def cast(bb, val, input_type, output_type):
   if input_type == output_type: return val
-  if output_type == dtypes.float32:
-    if dtypes.is_int(input_type) or input_type == dtypes.bool:
-      val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(input_type) or input_type == dtypes.bool else bb[-1].sitofp(val, ir.FloatType())
-    elif input_type == dtypes.bfloat16:
-      val = bb[-1].sext(val, ir.IntType(32))
-      val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
-      val = bb[-1].bitcast(val, ir.FloatType())
-    elif input_type == dtypes.float64:
-      val = bb[-1].fptrunc(val, ir.FloatType())
-    else:
-      val = bb[-1].fpext(val, ir.FloatType())
-    return val
+  if dtypes.is_float(input_type):
+    if dtypes.is_float(output_type):
+      if output_type.itemsize > input_type.itemsize: return bb[-1].fpext(val, dtype_to_llvm_dtype[output_type])
+      return bb[-1].fptrunc(val, dtype_to_llvm_dtype[output_type])
+    if dtypes.is_int(output_type):
+      if dtypes.is_unsigned(output_type): return bb[-1].fptoui(val, dtype_to_llvm_dtype[output_type])
+      return bb[-1].fptosi(val, dtype_to_llvm_dtype[output_type])
+    if output_type == dtypes.bool: return bb[-1].fcmp_unordered('!=', cast(bb, val, input_type, dtypes.float32), ir.Constant(ir.FloatType(), 0))

-  if input_type == dtypes.float32:
-    if dtypes.is_int(output_type) or output_type == dtypes.bool:
-      if dtypes.is_unsigned(output_type): val = bb[-1].fptoui(val, dtype_to_llvm_dtype[output_type])
-      elif output_type == dtypes.bool: val = bb[-1].fcmp_ordered("!=", val, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS)
-      else: val = bb[-1].fptosi(val, dtype_to_llvm_dtype[output_type])
-    elif output_type == dtypes.bfloat16:
-      val = bb[-1].bitcast(val, ir.IntType(32))
-      val = bb[-1].lshr(val, ir.Constant(ir.IntType(32), 16))
-      val = bb[-1].trunc(val, ir.IntType(16))
-    elif output_type == dtypes.float64:
-      val = bb[-1].fpext(val, ir.DoubleType())
-    else:
-      val = bb[-1].fptrunc(val, dtype_to_llvm_dtype[output_type])
-    return val
+  if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
+    if output_type == dtypes.float16:
+      val = bb[-1].uitofp(val, ir.FloatType())
+      return bb[-1].fptrunc(val, ir.HalfType())
+    if dtypes.is_float(output_type): return bb[-1].uitofp(val, dtype_to_llvm_dtype[output_type])
+    if dtypes.is_int(output_type):
+      if input_type.itemsize > output_type.itemsize: return bb[-1].trunc(val, dtype_to_llvm_dtype[output_type])
+      return bb[-1].zext(val, dtype_to_llvm_dtype[output_type])
+    if output_type == dtypes.bool: return bb[-1].icmp_unsigned('!=', val, ir.Constant(val.type, 0))
+
+  if dtypes.is_int(input_type):
+    if output_type == dtypes.float16:
+      val = bb[-1].sitofp(val, ir.FloatType())
+      return bb[-1].fptrunc(val, ir.HalfType())
+    if dtypes.is_float(output_type): return bb[-1].sitofp(val, dtype_to_llvm_dtype[output_type])
+    if dtypes.is_int(output_type):
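
[Note, illustrative and not part of the commit: with this patch every LLVM ALU
op computes in float -- each operand is cast to dtypes.float, the op runs, and
the result is cast back to the uop's dtype, which also explains the new frem
arm of BinaryOps.MOD. A minimal llvmlite sketch of that shape for an int32
ADD; the function name add_i32 is made up:]

from llvmlite import ir

module = ir.Module()
fn = ir.Function(module, ir.FunctionType(ir.IntType(32), [ir.IntType(32), ir.IntType(32)]), "add_i32")
bb = ir.IRBuilder(fn.append_basic_block())
x = bb.sitofp(fn.args[0], ir.FloatType())        # cast(bb, lvars[x], x.dtype, dtypes.float)
y = bb.sitofp(fn.args[1], ir.FloatType())
res = bb.fptosi(bb.fadd(x, y), ir.IntType(32))   # op runs in float, result cast back to dtype
bb.ret(res)
print(module)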
+      if input_type.itemsize > output_type.itemsize: return bb[-1].trunc(val, dtype_to_llvm_dtype[output_type])
+      return bb[-1].sext(val, dtype_to_llvm_dtype[output_type])
+    if output_type == dtypes.bool: return bb[-1].icmp_signed('!=', val, ir.Constant(val.type, 0))

   raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")

@@ -141,7 +142,8 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
       element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
       bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
     if uop == UOps.ALU:
-      lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
+      lvars[u] = cast(bb, code_for_op[args](bb[-1], *[cast(bb, lvars[x], x.dtype, dtypes.float) for x in vin]), dtypes.float, dtype)
+    if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype)

   bb[-1].ret_void()
   return str(module), {}
diff --git a/tinygrad/renderer/triton.py b/tinygrad/renderer/triton.py
index 31beecba6a..93f92db3c7 100644
--- a/tinygrad/renderer/triton.py
+++ b/tinygrad/renderer/triton.py
@@ -108,6 +108,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
       kk(f"{args[1]} = tl.arange({0}, {next_power_of_2(args[2])})")
       local_size.append(args[2])
       r[u] = args[1]
+    elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype)
     else: raise NotImplementedError(f"unimplemented: {uop}")

   prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(f"{buf[0]}" for buf in bufs)+"):\n"
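
[Note, illustrative and not part of the commit: once the backends handle
internal casts, a cast between any two supported dtypes should work on every
device, which is what the widened test_dtype.py matrix now exercises with the
TRITON exclusion removed. A hypothetical smoke test in the same spirit; the
test name and values are made up, relying on float-to-int casts truncating
toward zero:]

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

def test_cast_float_to_int32():
  t = Tensor([0.0, 0.5, 1.0, 1.9]).cast(dtypes.int32)
  assert t.numpy().tolist() == [0, 0, 1, 1]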