mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 13:28:06 -05:00
fix AMD LLVM compile error for bf16 cifar (#10254)
* fix AMD LLVM compile error * remove llvm_bf16_cast --------- Co-authored-by: b1tg <b1tg@users.noreply.github.com>
This commit is contained in:
@@ -108,10 +108,6 @@ base_rewrite = PatternMatcher([
|
||||
(UPat(Ops.BARRIER), lambda ctx: "")
|
||||
])
|
||||
|
||||
def llvm_bf16_cast(buf:UOp, idx:UOp, root:UOp):
|
||||
u16_buf = buf.replace(dtype=dtypes.ushort.ptr(size=cast(PtrDType,buf.dtype).size))
|
||||
return UOp.load(UOp.index(u16_buf, idx), dtype=dtypes.ushort).cast(dtypes.uint).mul(1<<16).bitcast(dtypes.float32).cast(root.dtype)
|
||||
|
||||
class LLVMRenderer(Renderer):
|
||||
device = "LLVM"
|
||||
abi = 'win64cc' if sys.platform == 'win32' else None
|
||||
@@ -128,8 +124,6 @@ class LLVMRenderer(Renderer):
|
||||
(UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
|
||||
# rewrite MAX to CMPLT + WHERE
|
||||
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
|
||||
# rewrite bf16 CAST(LOAD) to CAST(BITCAST)
|
||||
(UPat(Ops.CAST, name="root", src=(UPat.load(UPat.index(UPat.var("buf"), UPat.var("idx")), dtype=dtypes.bfloat16),)), llvm_bf16_cast),
|
||||
# copied from cstyle.py, upcast to float32 all the ops that don't support bfloat16
|
||||
(UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"),
|
||||
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))),
|
||||
|
||||
Reference in New Issue
Block a user