metal bfloat as cast (#6773)

This commit is contained in:
George Hotz
2024-09-26 19:31:40 +08:00
committed by GitHub
parent ed2f28388f
commit 249af24f18

View File

@@ -271,12 +271,17 @@ class MetalRenderer(CStyleLanguage):
# uint3 used for gid/lid - TODO: this should probably be `ushort3 lid [[thread_position_in_threadgroup]]`
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
type_map = {dtypes.bfloat16: "bfloat"}
code_for_op = {**CStyleLanguage().code_for_op,
BinaryOps.MAX: lambda a,b,dtype: f"(bfloat)max((float){a},(float){b})" if dtype == dtypes.bfloat16 else f"max({a},{b})",
UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})",
UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})",
UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
UnaryOps.SIN: lambda x,dtype: f"(bfloat)precise::sin({x})" if dtype == dtypes.bfloat16 else f"precise::sin({x})",}
# precise::sin
code_for_op = {**CStyleLanguage().code_for_op, UnaryOps.SIN: lambda x,dtype: f"precise::sin({x})"}
# upcast to float32 all the ops that don't support bfloat16
extra_matcher = PatternMatcher([
# NOTE: this is copied from PTX
*[(UPat(UOps.ALU, arg=op, dtype=dtypes.bfloat16, name="x"),
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)))
for op in [BinaryOps.MAX, UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN]]
]) + extra_pm
string_rewrite = PatternMatcher([
(UPat(UOps.BITCAST, name="x"), lambda r,x: f"as_type<{r.render_dtype(x.dtype)}>({r[x.src[0]]})"),