diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index bdb0719fbc..bcf76684d1 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -32,6 +32,12 @@ asm_for_op: Dict[Op, Callable] = { f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};" } +def load_store_ptr_arithmetic(x:UOp, buf:UOp, alu:Optional[UOp]=None, const:Optional[UOp]=None) -> UOp: + src = list(x.src) + src[0] = buf.cast(dtypes.int64) if alu is None else (buf.cast(dtypes.int64) + alu.cast(dtypes.int64)*buf.dtype.itemsize) + src[1] = UOp.const(dtypes.int64, 0 if const is None else const.arg*buf.dtype.itemsize) + return x.replace(src=tuple(src)) + supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE] ptx_matcher = constant_folder+PatternMatcher([ # bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only) @@ -41,31 +47,14 @@ ptx_matcher = constant_folder+PatternMatcher([ *[(UPat(UOps.ALU, arg=op, dtype=dtypes.half, name="x"), lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half))) for op in asm_for_op.keys() if op not in supports_half], - # fix the gates for load/store (low quality!) - (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat.var("x"),UPat.var("y"),UPat.var("z"),UPat.var("k"))), - lambda root,x,y,z,k: UOp(root.op, dtypes.uint8, (x,y,z.cast(dtypes.uint8),k)).cast(dtypes.bool)), - (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())), - lambda root: UOp(root.op, dtypes.uint8, root.src, root.arg).cast(dtypes.bool)), - (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat.var("z", dtypes.bool), UPat())), - lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (z.cast(dtypes.uint8),), root.arg)), - (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat.var("z", dtypes.bool))), - lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (z.cast(dtypes.uint8),), root.arg)), - (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat.var("g", dtypes.int))), - lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (g.cast(dtypes.uint8),), root.arg)), - # ptr_ar (load/store) - (UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), - UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat.var("alu"), UPat.cvar("const")]))), - lambda root, alu, const: UOp(root.op, root.dtype, - (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64), - const*root.src[0].dtype.itemsize)+root.src[2:])), - (UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), UPat.cvar("const"))), - lambda root, const: UOp(root.op, root.dtype, - (root.src[0].cast(dtypes.int64), - UOp.const(dtypes.int64, const.arg*root.src[0].dtype.itemsize),)+root.src[2:])), - (UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), UPat.var("alu"))), - lambda root, alu: UOp(root.op, root.dtype, - (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64), - UOp.const(dtypes.int64, 0))+root.src[2:])), + # load/store bool -> uint8 + (UPat(UOps.LOAD, dtypes.bool, name="x"), + lambda x: UOp(x.op, dtypes.uint8, x.src[0:2] + ((x.src[2].cast(dtypes.uint8),) if len(x.src) >= 3 else ()) + x.src[3:]).cast(dtypes.bool)), + (UPat(UOps.STORE, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True), + lambda x: UOp(x.op, dtypes.void, x.src[0:2] + (x.src[2].cast(dtypes.uint8),) + x.src[3:])), + # load/store use pointer arithmetic + (UPat((UOps.LOAD, UOps.STORE), name="x", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL), name="buf"), + UPat.any(UPat.var("alu")+UPat.cvar("const"), UPat.cvar("const"), UPat.var("alu")))), load_store_ptr_arithmetic), ]) class PTXRenderer(Renderer):