From db76586780fcc3dc2bce52afb7e1f47d5dd35872 Mon Sep 17 00:00:00 2001 From: Ahmed Harmouche Date: Thu, 12 Dec 2024 15:12:14 +0100 Subject: [PATCH] Cast pattern touchup in AMDRenderer [pr] (#8185) --- tinygrad/renderer/cstyle.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index f641d2a3d7..5425e0966e 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -395,8 +395,7 @@ class AMDRenderer(CStyleLanguage): (UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None), # bfloat16 casting (UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))), - (UPat(Ops.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)), - lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)), + (UPat(Ops.CAST, dtypes.float, UPat.var("x", dtypes.bfloat16)), lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)), (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm def render_vector_prefix(self, dtype:DType) -> str: