fix AMD LLVM compile error for bf16 cifar (#10254)

* fix AMD LLVM compile error

* remove llvm_bf16_cast

---------

Co-authored-by: b1tg <b1tg@users.noreply.github.com>
This commit is contained in:
b1tg
2025-05-12 13:57:07 +08:00
committed by GitHub
parent a0ed1ec1ae
commit 7eeb35ba6f

View File

@@ -108,10 +108,6 @@ base_rewrite = PatternMatcher([
(UPat(Ops.BARRIER), lambda ctx: "")
])
def llvm_bf16_cast(buf:UOp, idx:UOp, root:UOp):
u16_buf = buf.replace(dtype=dtypes.ushort.ptr(size=cast(PtrDType,buf.dtype).size))
return UOp.load(UOp.index(u16_buf, idx), dtype=dtypes.ushort).cast(dtypes.uint).mul(1<<16).bitcast(dtypes.float32).cast(root.dtype)
class LLVMRenderer(Renderer):
device = "LLVM"
abi = 'win64cc' if sys.platform == 'win32' else None
@@ -128,8 +124,6 @@ class LLVMRenderer(Renderer):
(UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
# rewrite MAX to CMPLT + WHERE
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
# rewrite bf16 CAST(LOAD) to CAST(BITCAST)
(UPat(Ops.CAST, name="root", src=(UPat.load(UPat.index(UPat.var("buf"), UPat.var("idx")), dtype=dtypes.bfloat16),)), llvm_bf16_cast),
# copied from cstyle.py, upcast to float32 all the ops that don't support bfloat16
(UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"),
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))),