From 7eeb35ba6f562f71fe26add444eaba923b47d496 Mon Sep 17 00:00:00 2001 From: b1tg <33436708+b1tg@users.noreply.github.com> Date: Mon, 12 May 2025 13:57:07 +0800 Subject: [PATCH] fix AMD LLVM compile error for bf16 cifar (#10254) * fix AMD LLVM compile error * remove llvm_bf16_cast --------- Co-authored-by: b1tg --- tinygrad/renderer/llvmir.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 1deea12870..0292663d3b 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -108,10 +108,6 @@ base_rewrite = PatternMatcher([ (UPat(Ops.BARRIER), lambda ctx: "") ]) -def llvm_bf16_cast(buf:UOp, idx:UOp, root:UOp): - u16_buf = buf.replace(dtype=dtypes.ushort.ptr(size=cast(PtrDType,buf.dtype).size)) - return UOp.load(UOp.index(u16_buf, idx), dtype=dtypes.ushort).cast(dtypes.uint).mul(1<<16).bitcast(dtypes.float32).cast(root.dtype) - class LLVMRenderer(Renderer): device = "LLVM" abi = 'win64cc' if sys.platform == 'win32' else None @@ -128,8 +124,6 @@ class LLVMRenderer(Renderer): (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)), # rewrite MAX to CMPLT + WHERE (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])), - # rewrite bf16 CAST(LOAD) to CAST(BITCAST) - (UPat(Ops.CAST, name="root", src=(UPat.load(UPat.index(UPat.var("buf"), UPat.var("idx")), dtype=dtypes.bfloat16),)), llvm_bf16_cast), # copied from cstyle.py, upcast to float32 all the ops that don't support bfloat16 (UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"), lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))),