mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-12 15:45:27 -05:00
metal bfloat as cast (#6773)
This commit is contained in:
@@ -271,12 +271,17 @@ class MetalRenderer(CStyleLanguage):
|
||||
# uint3 used for gid/lid - TODO: this should probably be `ushort3 lid [[thread_position_in_threadgroup]]`
|
||||
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
|
||||
type_map = {dtypes.bfloat16: "bfloat"}
|
||||
code_for_op = {**CStyleLanguage().code_for_op,
|
||||
BinaryOps.MAX: lambda a,b,dtype: f"(bfloat)max((float){a},(float){b})" if dtype == dtypes.bfloat16 else f"max({a},{b})",
|
||||
UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})",
|
||||
UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})",
|
||||
UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"(bfloat)precise::sin({x})" if dtype == dtypes.bfloat16 else f"precise::sin({x})",}
|
||||
|
||||
# precise::sin
|
||||
code_for_op = {**CStyleLanguage().code_for_op, UnaryOps.SIN: lambda x,dtype: f"precise::sin({x})"}
|
||||
|
||||
# upcast to float32 all the ops that don't support bfloat16
|
||||
extra_matcher = PatternMatcher([
|
||||
# NOTE: this is copied from PTX
|
||||
*[(UPat(UOps.ALU, arg=op, dtype=dtypes.bfloat16, name="x"),
|
||||
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)))
|
||||
for op in [BinaryOps.MAX, UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN]]
|
||||
]) + extra_pm
|
||||
|
||||
string_rewrite = PatternMatcher([
|
||||
(UPat(UOps.BITCAST, name="x"), lambda r,x: f"as_type<{r.render_dtype(x.dtype)}>({r[x.src[0]]})"),
|
||||
|
||||
Reference in New Issue
Block a user