Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-29 08:48:15 -05:00.
minor cleanups of cstyle [pr] (#7391)
* minor cleanups of cstyle [pr] * work
This commit is contained in:
@@ -835,6 +835,14 @@ def type_verify(uops:List[UOp]):
|
||||
print_uops(uops)
|
||||
raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}")
|
||||
|
||||
# *** uop helpers ***
|
||||
|
||||
def cast_float_to_bf16(x: UOp) -> UOp:
  """Build the UOp graph that converts a float32 value to bfloat16 via its bits.

  Finite values are rounded to nearest-even before the low 16 mantissa bits are
  truncated; NaN inputs get a surviving mantissa bit forced on so truncation
  cannot collapse them to infinity.
  """
  assert x.dtype == dtypes.float, "cast float -> bf16 must start with float"
  x = x.bitcast(dtypes.uint)
  # falsy exactly for NaN and +/-0 bit patterns (named is_not_inf_nan in the
  # older helper; inf takes the rounding arm but 0x7f800000 + 0x7fff still
  # truncates back to the inf pattern)
  takes_rounding_path = -x & 0x7f800000
  # round to nearest even: add half of the dropped mantissa plus the tie-break bit
  rounded = x + ((x >> 16) & 1) + 0x7fff
  # NaN arm: if any dropped mantissa bit is set, force bit 16 so the result stays NaN
  nan_preserved = (x & 0xffff).where(x | 0x10000, x)
  x = takes_rounding_path.where(rounded, nan_preserved)
  return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
|
||||
|
||||
# *** most of symbolic lives here now ***
|
||||
|
||||
def split_uop(x:UOp, sep:BinaryOps):
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast
|
||||
import os, math
|
||||
from collections import defaultdict, Counter
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp, PatternMatcher, UPat
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp, PatternMatcher, UPat, cast_float_to_bf16
|
||||
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
|
||||
from tinygrad.renderer import Renderer, TensorCore
|
||||
@@ -286,12 +286,6 @@ class MetalRenderer(CStyleLanguage):
|
||||
return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
|
||||
|
||||
def _half_intrinsic(fast:str, slow:str):
  # Render with the h* intrinsic for 16-bit float dtypes, the plain form otherwise.
  return lambda x,dtype: (fast if dtype in (dtypes.half, dtypes.bfloat16) else slow).format(x=x)

# Unary-op renderers that prefer half-precision intrinsics on half/bfloat16 operands.
code_for_op_half = {UnaryOps.RECIP: _half_intrinsic("hrcp({x})", "(1/{x})"),
                    UnaryOps.SQRT: _half_intrinsic("hsqrt({x})", "sqrt({x})"),
                    UnaryOps.SIN: _half_intrinsic("hsin({x})", "sin({x})"),
                    UnaryOps.LOG2: _half_intrinsic("hlog2({x})", "log2({x})"),
                    UnaryOps.EXP2: _half_intrinsic("hexp2({x})", "exp2({x})")}
|
||||
|
||||
# NOTE(review): presumably per-lane suffixes for naming vector components (.x, .y, ...);
# usage sites are not visible in this chunk — confirm before relying on the order.
_nms = "xyzwabcdefghijkl"
|
||||
|
||||
class CUDARenderer(CStyleLanguage):
|
||||
@@ -315,7 +309,12 @@ class CUDARenderer(CStyleLanguage):
|
||||
float4 = "make_float4"
|
||||
code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+int(x))}", "l": lambda x: f"threadIdx.{chr(120+int(x))}",
|
||||
"i": lambda x: f"(blockIdx.{chr(120+int(x))}*blockDim.{chr(120+int(x))}+threadIdx.{chr(120+int(x))})"}
|
||||
code_for_op = {**CStyleLanguage.code_for_op, **code_for_op_half}
|
||||
code_for_op = {**CStyleLanguage.code_for_op,
|
||||
UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})",
|
||||
UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
|
||||
UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
|
||||
UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
|
||||
type_map = {dtypes.bfloat16: "nv_bfloat16"}
|
||||
|
||||
def render_vector_prefix(self, dt:DType) -> str:
|
||||
@@ -353,20 +352,6 @@ class CUDARenderer(CStyleLanguage):
|
||||
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
|
||||
return f"__launch_bounds__({maxThreadsPerBlock}) "
|
||||
|
||||
code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})"}
|
||||
|
||||
def cast_float_bf16(x: UOp) -> UOp:
  """Build the UOp graph casting a float32 value to bfloat16 through its raw bits.

  First where-arm rounds finite values to nearest-even (add half the dropped
  mantissa plus the tie-break bit); the second arm keeps NaNs NaN by forcing a
  mantissa bit that survives the 16-bit truncation.
  """
  x = x.bitcast(dtypes.uint)
  x = (-x & 0x7f800000).where(x + ((x >> 16) & 1) + 0x7fff, (x & 0xffff).where(x | 0x10000, x))
  # drop the low 16 bits: bf16 is the upper half of the float32 representation
  return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
|
||||
|
||||
class AMDRenderer(CStyleLanguage):
|
||||
device = "AMD"
|
||||
shared_max = 65536
|
||||
@@ -385,7 +370,11 @@ class AMDRenderer(CStyleLanguage):
|
||||
kernel_prefix += '\nextern "C" __attribute__((global))'
|
||||
code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
|
||||
"i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
|
||||
code_for_op = { **CStyleLanguage.code_for_op, **code_for_op_hip }
|
||||
code_for_op = { **CStyleLanguage.code_for_op,
|
||||
UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})"}
|
||||
smem_prefix = "__attribute__((shared))"
|
||||
barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
|
||||
'__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
|
||||
@@ -405,7 +394,7 @@ class AMDRenderer(CStyleLanguage):
|
||||
# bfloat16 casting
|
||||
(UPat(UOps.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
|
||||
lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
|
||||
(UPat(UOps.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_bf16)]) + extra_pm
|
||||
(UPat(UOps.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm
|
||||
|
||||
def render_vector_prefix(self, dtype:DType) -> str:
|
||||
vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
|
||||
|
||||
Reference in New Issue
Block a user