delete the PRECAST op [p] (#13415)

* don't use PRECAST in cstyle renderer [p]

* fix in metal

* fix opencl

* __builtin_bit_cast

* precast is unused

* cuda is c99?

* lambda_union_bitcast

* helper function

* delete precast op
George Hotz
2025-11-21 21:47:14 -08:00
committed by GitHub
parent 1f648bb1ba
commit 4c0f4226b9
4 changed files with 15 additions and 14 deletions

View File

@@ -22,10 +22,10 @@ base_rewrite = PatternMatcher([
(UPat(Ops.CAST, name="x"), lambda ctx,x:
f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
-(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
+(UPat(Ops.BITCAST, name="x"), lambda ctx,x:
+  f"__builtin_bit_cast({ctx.render_dtype(x.dtype)}, ({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
(UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
-(UPat(Ops.PRECAST, name="x"), lambda ctx,x: ctx[x.src[0]]),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0]](x.arg[-1])}; /* {(x.src[0]).render()} */"),
# const
(UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, ctx.infinity)})"),
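For a clang-style backend the rendered bitcast changes from a pointer pun on a PRECAST temporary to a single __builtin_bit_cast expression. Below is a minimal host-compilable C++ sketch of the two forms, assuming a float-to-unsigned-int bitcast and a compiler that provides __builtin_bit_cast (clang 9+, gcc 11+); the names precast0/cast0/cast1 are illustrative, not real kernel output.

#include <cstdio>

int main() {
  // old path: extra_pm wrapped the BITCAST source in a PRECAST, so the renderer
  // first materialized a named temporary, then took its address and pointer-punned it
  float precast0 = 1.0f + 2.0f;
  unsigned int cast0 = (*((unsigned int*)&precast0));

  // new path: the source is cast to its own dtype inline and handed straight to
  // __builtin_bit_cast, so no temporary (and no PRECAST op) is needed
  unsigned int cast1 = __builtin_bit_cast(unsigned int, (float)(1.0f + 2.0f));

  printf("%08x %08x\n", cast0, cast1);  // both print 40400000
  return 0;
}

The old form needs an addressable value, which is presumably why the BITCAST source had to be forced through PRECAST into its own statement; __builtin_bit_cast accepts any expression directly, and the explicit (float)(...) cast keeps the source dtype unambiguous.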
@@ -60,9 +60,6 @@ base_rewrite = PatternMatcher([
])
extra_pm = PatternMatcher([
-# insert a PRECAST before BITCAST to force it to be rendered. not needed on all backends?
-(UPat(Ops.BITCAST, name="x"), lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.PRECAST, x.src[0].dtype, x.src),))
-if x.src[0].op not in {Ops.PRECAST, Ops.LOAD, Ops.CUSTOM} else None),
# devectorize any bools
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
# CAST (from bool) can't be vectorized
@@ -181,7 +178,7 @@ class CStyleLanguage(Renderer):
elif u.op is Ops.RANGE: r[u] = f"{axis_letters[u.arg[-1]]}idx"+range_str(u)
else:
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
-Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.PRECAST: "precast",
+Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast",
Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
r[u] = f"{prefix}{c[prefix]}"
@@ -278,7 +275,7 @@ class OpenCLRenderer(CStyleLanguage):
dtypes.bfloat16: "ushort" }
string_rewrite = PatternMatcher([
-(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
+(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}(({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
# load/store image (OpenCL)
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("var"))),
lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
@@ -338,7 +335,7 @@ class MetalRenderer(CStyleLanguage):
]) + extra_pm
string_rewrite = PatternMatcher([
-(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>({ctx[x.src[0]]})"),
+(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>(({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
]) + base_rewrite
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
@@ -385,6 +382,10 @@ class CUDARenderer(CStyleLanguage):
extra_matcher = create_non_native_float_pats(dtypes.fp8s, casting=False) + PatternMatcher([
(UPat(Ops.CAST, dtypes.fp8s, UPat.var("x", dtypes.fp8s), name='y'), lambda x,y: x.cast(dtypes.float).cast(y.dtype) if x.dtype!=y.dtype else None),
]) + extra_pm
+string_rewrite = PatternMatcher([
+(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"tg_bitcast<{ctx.render_dtype(x.dtype)}>(({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
+]) + base_rewrite
def render_vector_prefix(self, dt:DType) -> str:
vec, scal = self.render_dtype(dt), self.render_dtype(dt.scalar()),
elems, header = ', '.join(_nms[:dt.count]), ', '.join([f"{scal} {x}" for x in _nms[:dt.count]])
@@ -392,8 +393,8 @@ class CUDARenderer(CStyleLanguage):
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
# TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
-prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
+prefix = ["#define INFINITY (__int_as_float(0x7f800000))", "#define NAN (__int_as_float(0x7fffffff))",
+"template <class T, class F> __device__ __forceinline__ T tg_bitcast(F v) { union U { F f; T t; }; U u; u.f = v; return u.t; }"]
used_dtypes = uops_to_dtypes(uops)
if any(dt.scalar() in dtypes.fp8s for dt in used_dtypes): prefix.append("#include <cuda_fp8.h>")
if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#include <cuda_fp16.h>")
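CUDA gets the same treatment but through the tg_bitcast helper added to the kernel prefix above, matching the "cuda is c99?" and "helper function" notes in the commit message. Below is a host-side sketch of that helper with the __device__ __forceinline__ qualifiers dropped so it builds as plain C++; the union member copy does the bit-level reinterpretation that __builtin_bit_cast would otherwise do (a type-punning idiom nvcc, gcc, and clang all accept).

#include <cstdio>

// same body as the prefix string in the diff, minus the CUDA qualifiers
template <class T, class F> T tg_bitcast(F v) { union U { F f; T t; }; U u; u.f = v; return u.t; }

int main() {
  // the renderer emits calls like tg_bitcast<unsigned int>((float)(...)); shown here with a literal
  unsigned int bits = tg_bitcast<unsigned int>(3.0f);
  printf("%08x\n", bits);  // prints 40400000
  return 0;
}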

View File

@@ -25,7 +25,7 @@ class Ops(FastEnum):
# ** 2 -- non op uops **
# uops that aren't rendered
-NOOP = auto(); SINK = auto(); PRECAST = auto()
+NOOP = auto(); SINK = auto()
# AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:]
AFTER = auto()

View File

@@ -200,7 +200,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
-Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST | Ops.CONTRACT:
+Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT:
return None
case Ops.INDEX:

View File

@@ -147,8 +147,8 @@ shared_codegen_spec = PatternMatcher([
(UPat().index(UPat()).or_casted().load(), lambda: True),
(UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),
-# all CUSTOM + PRECAST
-(UPat((Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
+# CUSTOM (inline and non inline)
+(UPat((Ops.CUSTOMI, Ops.CUSTOM)), lambda: True),
# INDEX
(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf").index(UPat.var("idx")), validate_index),