diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 6e8bfedb3a..6f4ce533f5 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -835,6 +835,14 @@ def type_verify(uops:List[UOp]):
       print_uops(uops)
       raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}")
 
+# *** uop helpers ***
+
+def cast_float_to_bf16(x: UOp) -> UOp:
+  assert x.dtype == dtypes.float, "cast float -> bf16 must start with float"
+  x = x.bitcast(dtypes.uint)
+  x = (-x & 0x7f800000).where(x + ((x >> 16) & 1) + 0x7fff, (x & 0xffff).where((x | 0x10000), x))
+  return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
+
 # *** most of symbolic lives here now ***
 
 def split_uop(x:UOp, sep:BinaryOps):
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 557d415057..a69df196fd 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast
 import os, math
 from collections import defaultdict, Counter
-from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp, PatternMatcher, UPat
+from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp, PatternMatcher, UPat, cast_float_to_bf16
 from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
 from tinygrad.renderer import Renderer, TensorCore
@@ -286,12 +286,6 @@ class MetalRenderer(CStyleLanguage):
   return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 
-code_for_op_half = {UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})",
-                    UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
-                    UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
-                    UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
-                    UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
-
 _nms = "xyzwabcdefghijkl"
 
 class CUDARenderer(CStyleLanguage):
@@ -315,7 +309,12 @@ class CUDARenderer(CStyleLanguage):
   float4 = "make_float4"
   code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+int(x))}", "l": lambda x: f"threadIdx.{chr(120+int(x))}",
                        "i": lambda x: f"(blockIdx.{chr(120+int(x))}*blockDim.{chr(120+int(x))}+threadIdx.{chr(120+int(x))})"}
-  code_for_op = {**CStyleLanguage.code_for_op, **code_for_op_half}
+  code_for_op = {**CStyleLanguage.code_for_op,
+    UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})",
+    UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
+    UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
+    UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
+    UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
   type_map = {dtypes.bfloat16: "nv_bfloat16"}
 
   def render_vector_prefix(self, dt:DType) -> str:
@@ -353,20 +352,6 @@ class CUDARenderer(CStyleLanguage):
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
     return f"__launch_bounds__({maxThreadsPerBlock}) "
 
-code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
-                    UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
-                    UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
-                    UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})"}
-
-def cast_float_bf16(x: UOp) -> UOp:
-  x = x.bitcast(dtypes.uint)
-
-  is_not_inf_nan = -x & 0x7f800000
-  has_mantissa = x & 0xffff
-  x = is_not_inf_nan.where(x + ((x >> 16) & 1) + 0x7fff, has_mantissa.where((x | 0x10000), x))
-
-  return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
-
 class AMDRenderer(CStyleLanguage):
   device = "AMD"
   shared_max = 65536
@@ -385,7 +370,11 @@ class AMDRenderer(CStyleLanguage):
   kernel_prefix += '\nextern "C" __attribute__((global))'
   code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
                        "i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
-  code_for_op = { **CStyleLanguage.code_for_op, **code_for_op_hip }
+  code_for_op = { **CStyleLanguage.code_for_op,
+    UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+    UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+    UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+    UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})"}
   smem_prefix = "__attribute__((shared))"
   barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
             '__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
@@ -405,7 +394,7 @@ class AMDRenderer(CStyleLanguage):
     # bfloat16 casting
     (UPat(UOps.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
      lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
-    (UPat(UOps.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_bf16)]) + extra_pm
+    (UPat(UOps.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm
 
   def render_vector_prefix(self, dtype:DType) -> str:
     vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
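Note on the new helper: cast_float_to_bf16 lowers a float32 -> bfloat16 cast to integer bit twiddling, rounding to nearest (ties to even) and keeping NaNs NaN after truncation. The sketch below is not part of the patch and the name f32_to_bf16_bits is just for illustration; it replays the same bit trick with plain Python ints (struct only used to grab the float32 bit pattern) as a way to sanity-check the logic.

import struct

def f32_to_bf16_bits(f: float) -> int:
  x = struct.unpack("<I", struct.pack("<f", f))[0]   # reinterpret the float32 as a uint32
  if (-x) & 0x7f800000:                              # true for finite nonzero values and +-inf, false for +-0 and NaNs
    x = (x + ((x >> 16) & 1) + 0x7fff) & 0xffffffff  # round to nearest, ties to even (inf unchanged: its low 16 bits are 0)
  elif x & 0xffff:                                   # NaN whose payload sits only in the low 16 bits
    x |= 0x10000                                     # set a high mantissa bit so it stays NaN after truncation
  return (x >> 16) & 0xffff                          # the top 16 bits are the bfloat16 encoding

assert f32_to_bf16_bits(1.0) == 0x3f80
assert f32_to_bf16_bits(-2.0) == 0xc000
assert f32_to_bf16_bits(float("inf")) == 0x7f80
assert f32_to_bf16_bits(1.00390625) == 0x3f80        # exactly halfway between 1.0 and 1.0078125, ties to the even bf16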