diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 0d9ef047e7..3633a506f5 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -22,10 +22,10 @@ base_rewrite = PatternMatcher([ (UPat(Ops.CAST, name="x"), lambda ctx,x: f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None), (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"), - (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"), + (UPat(Ops.BITCAST, name="x"), lambda ctx,x: + f"__builtin_bit_cast({ctx.render_dtype(x.dtype)}, ({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"), (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"), (UPat(Ops.BARRIER), lambda ctx: ctx.barrier), - (UPat(Ops.PRECAST, name="x"), lambda ctx,x: ctx[x.src[0]]), (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0]](x.arg[-1])}; /* {(x.src[0]).render()} */"), # const (UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, ctx.infinity)})"), @@ -60,9 +60,6 @@ base_rewrite = PatternMatcher([ ]) extra_pm = PatternMatcher([ - # insert a PRECAST before BITCAST to force it to be rendered. not needed on all backends? - (UPat(Ops.BITCAST, name="x"), lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.PRECAST, x.src[0].dtype, x.src),)) - if x.src[0].op not in {Ops.PRECAST, Ops.LOAD, Ops.CUSTOM} else None), # devectorize any bools (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu), # CAST (from bool) can't be vectorized @@ -181,7 +178,7 @@ class CStyleLanguage(Renderer): elif u.op is Ops.RANGE: r[u] = f"{axis_letters[u.arg[-1]]}idx"+range_str(u) else: prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const", - Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.PRECAST: "precast", + Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu") r[u] = f"{prefix}{c[prefix]}" @@ -278,7 +275,7 @@ class OpenCLRenderer(CStyleLanguage): dtypes.bfloat16: "ushort" } string_rewrite = PatternMatcher([ - (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"), + (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}(({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"), # load/store image (OpenCL) (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("var"))), lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"), @@ -338,7 +335,7 @@ class MetalRenderer(CStyleLanguage): ]) + extra_pm string_rewrite = PatternMatcher([ - (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>({ctx[x.src[0]]})"), + (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>(({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"), ]) + base_rewrite def render_kernel(self, function_name, kernel, bufs, uops, prefix=None): @@ -385,6 +382,10 @@ class CUDARenderer(CStyleLanguage): extra_matcher = create_non_native_float_pats(dtypes.fp8s, casting=False) + PatternMatcher([ (UPat(Ops.CAST, dtypes.fp8s, UPat.var("x", dtypes.fp8s), name='y'), lambda x,y: x.cast(dtypes.float).cast(y.dtype) if x.dtype!=y.dtype else None), ]) + extra_pm + string_rewrite = PatternMatcher([ + (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"tg_bitcast<{ctx.render_dtype(x.dtype)}>(({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"), + ]) + base_rewrite + def render_vector_prefix(self, dt:DType) -> str: vec, scal = self.render_dtype(dt), self.render_dtype(dt.scalar()), elems, header = ', '.join(_nms[:dt.count]), ', '.join([f"{scal} {x}" for x in _nms[:dt.count]]) @@ -392,8 +393,8 @@ class CUDARenderer(CStyleLanguage): def render_kernel(self, function_name, kernel, bufs, uops, prefix=None): # TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name - prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"] - + prefix = ["#define INFINITY (__int_as_float(0x7f800000))", "#define NAN (__int_as_float(0x7fffffff))", + "template __device__ __forceinline__ T tg_bitcast(F v) { union U { F f; T t; }; U u; u.f = v; return u.t; }"] used_dtypes = uops_to_dtypes(uops) if any(dt.scalar() in dtypes.fp8s for dt in used_dtypes): prefix.append("#include ") if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#include ") diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index 322cd2323f..9d9e1054ba 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -25,7 +25,7 @@ class Ops(FastEnum): # ** 2 -- non op uops ** # uops that aren't rendered - NOOP = auto(); SINK = auto(); PRECAST = auto() + NOOP = auto(); SINK = auto() # AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:] AFTER = auto() diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 253d83902c..6d1160db33 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -200,7 +200,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): match self.op: # late ops don't have shape case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \ - Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST | Ops.CONTRACT: + Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT: return None case Ops.INDEX: diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 8fcd09a041..ecb915b5a1 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -147,8 +147,8 @@ shared_codegen_spec = PatternMatcher([ (UPat().index(UPat()).or_casted().load(), lambda: True), (UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True), - # all CUSTOM + PRECAST - (UPat((Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True), + # CUSTOM (inline and non inline) + (UPat((Ops.CUSTOMI, Ops.CUSTOM)), lambda: True), # INDEX (UPat(GroupOp.Defines|{Ops.AFTER}, name="buf").index(UPat.var("idx")), validate_index),