minor cleanups of cstyle [pr] (#7391)

* minor cleanups of cstyle [pr]

* work
George Hotz, 2024-10-30 10:59:27 +07:00 (committed by GitHub)
parent 4ca53db604
commit 32dd2dcba5
2 changed files with 21 additions and 24 deletions

tinygrad/ops.py

@@ -835,6 +835,14 @@ def type_verify(uops:List[UOp]):
print_uops(uops)
raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}")
# *** uop helpers ***
+def cast_float_to_bf16(x: UOp) -> UOp:
+  assert x.dtype == dtypes.float, "cast float -> bf16 must start with float"
+  x = x.bitcast(dtypes.uint)
+  x = (-x & 0x7f800000).where(x + ((x >> 16) & 1) + 0x7fff, (x & 0xffff).where((x | 0x10000), x))
+  return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
# *** most of symbolic lives here now ***
def split_uop(x:UOp, sep:BinaryOps):
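
The helper moved here builds float -> bfloat16 with round-to-nearest-even out of pure integer UOps: the where-condition `-x & 0x7f800000` is nonzero for everything except +-0 and NaNs, so finite values (and infinities, whose zero mantissa cannot carry) get `((x >> 16) & 1) + 0x7fff` added before truncation to the top 16 bits, while a NaN whose payload lives only in the low half gets a bit forced into the kept mantissa so it does not truncate to infinity. As a sanity check, here is a minimal sketch of the same bit math on plain Python ints (the function name and asserts are mine, not part of the commit):

import struct

def f32_to_bf16_bits(f: float) -> int:
  x = struct.unpack("<I", struct.pack("<f", f))[0]  # float32 bit pattern as uint32
  if (-x & 0xffffffff) & 0x7f800000:                # everything except +-0 and NaN
    x = x + ((x >> 16) & 1) + 0x7fff                # round to nearest, ties to even
  elif x & 0xffff:                                  # NaN payload only in the low half
    x = x | 0x10000                                 # keep it a NaN after truncation
  return (x >> 16) & 0xffff                         # top 16 bits are the bfloat16

assert f32_to_bf16_bits(1.0) == 0x3f80
assert f32_to_bf16_bits(float("inf")) == 0x7f80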

tinygrad/renderer/cstyle.py

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast
import os, math
from collections import defaultdict, Counter
-from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp, PatternMatcher, UPat
+from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp, PatternMatcher, UPat, cast_float_to_bf16
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.renderer import Renderer, TensorCore
@@ -286,12 +286,6 @@ class MetalRenderer(CStyleLanguage):
return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
-code_for_op_half = {UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})",
-                    UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
-                    UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
-                    UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
-                    UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
_nms = "xyzwabcdefghijkl"
class CUDARenderer(CStyleLanguage):
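
`code_for_op_half` was a module-level table used only by CUDARenderer, so the cleanup deletes it here and inlines its entries at the single use site in the next hunk. Each entry maps a UnaryOp to a lambda that takes the already-rendered operand string plus the op's dtype and returns C source, selecting the half-precision intrinsic when the dtype is half or bfloat16. A toy illustration of one entry, called directly rather than through the renderer:

from tinygrad.dtype import dtypes

sqrt_src = lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})"
print(sqrt_src("val0", dtypes.half))   # hsqrt(val0)
print(sqrt_src("val0", dtypes.float))  # sqrt(val0)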
@@ -315,7 +309,12 @@ class CUDARenderer(CStyleLanguage):
float4 = "make_float4"
code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+int(x))}", "l": lambda x: f"threadIdx.{chr(120+int(x))}",
"i": lambda x: f"(blockIdx.{chr(120+int(x))}*blockDim.{chr(120+int(x))}+threadIdx.{chr(120+int(x))})"}
-  code_for_op = {**CStyleLanguage.code_for_op, **code_for_op_half}
+  code_for_op = {**CStyleLanguage.code_for_op,
+    UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})",
+    UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
+    UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
+    UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
+    UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
type_map = {dtypes.bfloat16: "nv_bfloat16"}
def render_vector_prefix(self, dt:DType) -> str:
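
In `code_for_workitem` above, `chr(120+int(x))` maps a grid dimension index to a CUDA component suffix, since `chr(120)` is 'x'. A one-liner to see the mapping:

for dim in range(3): print(f"blockIdx.{chr(120+dim)}")  # blockIdx.x, blockIdx.y, blockIdx.z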
@@ -353,20 +352,6 @@ class CUDARenderer(CStyleLanguage):
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
return f"__launch_bounds__({maxThreadsPerBlock}) "
-code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
-                    UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
-                    UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
-                    UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})"}
-
-def cast_float_bf16(x: UOp) -> UOp:
-  x = x.bitcast(dtypes.uint)
-  is_not_inf_nan = -x & 0x7f800000
-  has_mantissa = x & 0xffff
-  x = is_not_inf_nan.where(x + ((x >> 16) & 1) + 0x7fff, has_mantissa.where((x | 0x10000), x))
-  return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
class AMDRenderer(CStyleLanguage):
device = "AMD"
shared_max = 65536
@@ -385,7 +370,11 @@ class AMDRenderer(CStyleLanguage):
kernel_prefix += '\nextern "C" __attribute__((global))'
code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
"i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
-  code_for_op = { **CStyleLanguage.code_for_op, **code_for_op_hip }
+  code_for_op = { **CStyleLanguage.code_for_op,
+    UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+    UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+    UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+    UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})"}
smem_prefix = "__attribute__((shared))"
barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
'__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
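
The inlined `__ocml_*` entries choose the intrinsic's width suffix with an inline dict lookup: half selects 16, double selects 64, and everything else falls back to 32. A small sketch of just the suffix selection:

from tinygrad.dtype import dtypes

for dt in (dtypes.half, dtypes.float, dtypes.double):
  print(f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dt, 32)}(x)")
# __ocml_sqrt_f16(x), __ocml_sqrt_f32(x), __ocml_sqrt_f64(x)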
@@ -405,7 +394,7 @@ class AMDRenderer(CStyleLanguage):
# bfloat16 casting
(UPat(UOps.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
-    (UPat(UOps.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_bf16)]) + extra_pm
+    (UPat(UOps.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm
def render_vector_prefix(self, dtype:DType) -> str:
vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
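
The bf16 -> float rule above is exact (widen the ushort bits to uint and shift into the high half), so it round-trips with the rounding helper now shared from ops.py. A standalone sketch of that direction on plain ints (function name is mine, not part of the commit):

import struct

def bf16_bits_to_f32(b: int) -> float:
  # place the 16 bfloat16 bits in the high half of a float32 pattern
  return struct.unpack("<f", struct.pack("<I", (b & 0xffff) << 16))[0]

assert bf16_bits_to_f32(0x3f80) == 1.0
assert bf16_bits_to_f32(0xc000) == -2.0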