mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
delete the PRECAST op [p] (#13415)
* don't use PRECAST in cstyle renderer [p]
* fix in metal
* fix opencl
* __builtin_bit_cast
* precast is unused
* cuda is c99?
* lambda_union_bitcast
* helper function
* delete precast op
This commit is contained in:
@@ -22,10 +22,10 @@ base_rewrite = PatternMatcher([
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x:
|
||||
f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x:
|
||||
f"__builtin_bit_cast({ctx.render_dtype(x.dtype)}, ({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
|
||||
(UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
|
||||
(UPat(Ops.PRECAST, name="x"), lambda ctx,x: ctx[x.src[0]]),
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0]](x.arg[-1])}; /* {(x.src[0]).render()} */"),
|
||||
# const
|
||||
(UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, ctx.infinity)})"),
|
||||
@@ -60,9 +60,6 @@ base_rewrite = PatternMatcher([
|
||||
])
|
||||
|
||||
extra_pm = PatternMatcher([
|
||||
# insert a PRECAST before BITCAST to force it to be rendered. not needed on all backends?
|
||||
(UPat(Ops.BITCAST, name="x"), lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.PRECAST, x.src[0].dtype, x.src),))
|
||||
if x.src[0].op not in {Ops.PRECAST, Ops.LOAD, Ops.CUSTOM} else None),
|
||||
# devectorize any bools
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
|
||||
# CAST (from bool) can't be vectorized
|
||||
@@ -181,7 +178,7 @@ class CStyleLanguage(Renderer):
|
||||
elif u.op is Ops.RANGE: r[u] = f"{axis_letters[u.arg[-1]]}idx"+range_str(u)
|
||||
else:
|
||||
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
|
||||
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.PRECAST: "precast",
|
||||
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast",
|
||||
Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
|
||||
r[u] = f"{prefix}{c[prefix]}"
|
||||
|
||||
@@ -278,7 +275,7 @@ class OpenCLRenderer(CStyleLanguage):
|
||||
dtypes.bfloat16: "ushort" }
|
||||
|
||||
string_rewrite = PatternMatcher([
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}(({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
|
||||
# load/store image (OpenCL)
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("var"))),
|
||||
lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
|
||||
@@ -338,7 +335,7 @@ class MetalRenderer(CStyleLanguage):
|
||||
]) + extra_pm
|
||||
|
||||
string_rewrite = PatternMatcher([
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>({ctx[x.src[0]]})"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>(({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
|
||||
]) + base_rewrite
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
|
||||
@@ -385,6 +382,10 @@ class CUDARenderer(CStyleLanguage):
|
||||
extra_matcher = create_non_native_float_pats(dtypes.fp8s, casting=False) + PatternMatcher([
|
||||
(UPat(Ops.CAST, dtypes.fp8s, UPat.var("x", dtypes.fp8s), name='y'), lambda x,y: x.cast(dtypes.float).cast(y.dtype) if x.dtype!=y.dtype else None),
|
||||
]) + extra_pm
|
||||
string_rewrite = PatternMatcher([
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"tg_bitcast<{ctx.render_dtype(x.dtype)}>(({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
|
||||
]) + base_rewrite
|
||||
|
||||
def render_vector_prefix(self, dt:DType) -> str:
|
||||
vec, scal = self.render_dtype(dt), self.render_dtype(dt.scalar()),
|
||||
elems, header = ', '.join(_nms[:dt.count]), ', '.join([f"{scal} {x}" for x in _nms[:dt.count]])
|
||||
@@ -392,8 +393,8 @@ class CUDARenderer(CStyleLanguage):
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
|
||||
# TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
|
||||
prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
|
||||
|
||||
prefix = ["#define INFINITY (__int_as_float(0x7f800000))", "#define NAN (__int_as_float(0x7fffffff))",
|
||||
"template <class T, class F> __device__ __forceinline__ T tg_bitcast(F v) { union U { F f; T t; }; U u; u.f = v; return u.t; }"]
|
||||
used_dtypes = uops_to_dtypes(uops)
|
||||
if any(dt.scalar() in dtypes.fp8s for dt in used_dtypes): prefix.append("#include <cuda_fp8.h>")
|
||||
if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#include <cuda_fp16.h>")
|
||||
|
||||
@@ -25,7 +25,7 @@ class Ops(FastEnum):
|
||||
# ** 2 -- non op uops **
|
||||
|
||||
# uops that aren't rendered
|
||||
NOOP = auto(); SINK = auto(); PRECAST = auto()
|
||||
NOOP = auto(); SINK = auto()
|
||||
|
||||
# AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:]
|
||||
AFTER = auto()
|
||||
|
||||
@@ -200,7 +200,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
match self.op:
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST | Ops.CONTRACT:
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT:
|
||||
return None
|
||||
|
||||
case Ops.INDEX:
|
||||
|
||||
@@ -147,8 +147,8 @@ shared_codegen_spec = PatternMatcher([
|
||||
(UPat().index(UPat()).or_casted().load(), lambda: True),
|
||||
(UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),
|
||||
|
||||
# all CUSTOM + PRECAST
|
||||
(UPat((Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
|
||||
# CUSTOM (inline and non inline)
|
||||
(UPat((Ops.CUSTOMI, Ops.CUSTOM)), lambda: True),
|
||||
|
||||
# INDEX
|
||||
(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf").index(UPat.var("idx")), validate_index),
|
||||
|
||||
Reference in New Issue
Block a user