From 3e082d4a9d770e5cd014dd4d94b68588de79dfc3 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 6 Feb 2025 12:15:50 +0800 Subject: [PATCH] add float4 support to LLVM (#8920) * add float4 support to LLVM * is_bool --- tinygrad/dtype.py | 2 ++ tinygrad/renderer/llvmir.py | 19 ++++++++++++++++--- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index d39fc70244..2118b18b22 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -80,6 +80,8 @@ class dtypes: @functools.lru_cache(None) def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints @staticmethod + def is_bool(x: DType) -> bool: return x.scalar() == dtypes.bool + @staticmethod def from_py(x) -> DType: if x.__class__ is float: return dtypes.default_float if x.__class__ is int: return dtypes.default_int diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index b0cc826592..6a59245875 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -5,6 +5,7 @@ from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp from tinygrad.dtype import dtypes, DType, PtrDType, truncate def ldt(dt:DType): + if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>" if isinstance(dt, PtrDType): return ldt(dt.base) + "*" return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64", dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64", @@ -20,7 +21,7 @@ def lcast(input_type:DType, output_type:DType): if dtypes.is_float(input_type): if dtypes.is_float(output_type): return 'fpext' if output_type.itemsize > input_type.itemsize else 'fptrunc' if dtypes.is_int(output_type): return 'fptoui' if dtypes.is_unsigned(output_type) else 'fptosi' - if dtypes.is_unsigned(input_type) or input_type == dtypes.bool: + if dtypes.is_unsigned(input_type) or dtypes.is_bool(input_type): if dtypes.is_float(output_type): return 'uitofp' if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'zext' if dtypes.is_int(input_type): @@ -49,12 +50,24 @@ llvm_rewrite = PatternMatcher([ (UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"), (UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"), + # GEP/VECTORIZE/CAST for float4 support + (UPat(Ops.GEP, name="x"), lambda ctx,x: f" {ctx[x]} = extractelement {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {x.arg[0]}"), + (UPat(Ops.VECTORIZE, src=UPat.var('y'), name="x"), lambda ctx,x,y: + f" {ctx[x]}_z = insertelement <1 x {ldt(y.dtype)}> poison, {ldt(y.dtype)} {ctx[y]}, i32 0\n" + f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.dtype.count} x i32> zeroinitializer"), + (UPat(Ops.VECTORIZE, name="x"), lambda ctx,x: "\n".join([(f" {ctx[x]}_{i}" if i+1 != len(x.src) else f" {ctx[x]}")+ + f" = insertelement {ldt(x.dtype)} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+ + f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])), + (UPat(Ops.CAST, name="x"), lambda ctx,x: + f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}" if isinstance(x.dtype, PtrDType) else None), + # unary/binary/ternary ops (UPat(Ops.SQRT, name="x"), lambda ctx,x: f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"), (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"), (UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"), - (UPat(GroupOp.Binary, name="x"), lambda ctx,x: f" {ctx[x]} = {lop[x.src[0].dtype][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"), + (UPat(GroupOp.Binary, name="x"), lambda ctx,x: + f" {ctx[x]} = {lop[x.src[0].dtype.scalar()][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"), (UPat(Ops.WHERE, name="x"), lambda ctx,x: f" {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[2].dtype)} {ctx[x.src[2]]}"), @@ -79,7 +92,7 @@ def llvm_bf16_cast(buf:UOp, idx:UOp, root:UOp): class LLVMRenderer(Renderer): device = "LLVM" - supports_float4 = False + supports_float4 = True has_local = False has_shared = False global_max = None