mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-28 00:08:16 -05:00
cleaner ptx match rules [run_process_replay] (#6770)
* cleaner ptx match rules [run_process_replay] * clean up load/store rules * now that's clean * oops, typo * cast back to bool
This commit is contained in:
@@ -32,6 +32,12 @@ asm_for_op: Dict[Op, Callable] = {
|
||||
f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
|
||||
}
|
||||
|
||||
def load_store_ptr_arithmetic(x:UOp, buf:UOp, alu:Optional[UOp]=None, const:Optional[UOp]=None) -> UOp:
|
||||
src = list(x.src)
|
||||
src[0] = buf.cast(dtypes.int64) if alu is None else (buf.cast(dtypes.int64) + alu.cast(dtypes.int64)*buf.dtype.itemsize)
|
||||
src[1] = UOp.const(dtypes.int64, 0 if const is None else const.arg*buf.dtype.itemsize)
|
||||
return x.replace(src=tuple(src))
|
||||
|
||||
supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
|
||||
ptx_matcher = constant_folder+PatternMatcher([
|
||||
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
|
||||
@@ -41,31 +47,14 @@ ptx_matcher = constant_folder+PatternMatcher([
|
||||
*[(UPat(UOps.ALU, arg=op, dtype=dtypes.half, name="x"),
|
||||
lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half)))
|
||||
for op in asm_for_op.keys() if op not in supports_half],
|
||||
# fix the gates for load/store (low quality!)
|
||||
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat.var("x"),UPat.var("y"),UPat.var("z"),UPat.var("k"))),
|
||||
lambda root,x,y,z,k: UOp(root.op, dtypes.uint8, (x,y,z.cast(dtypes.uint8),k)).cast(dtypes.bool)),
|
||||
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
|
||||
lambda root: UOp(root.op, dtypes.uint8, root.src, root.arg).cast(dtypes.bool)),
|
||||
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat.var("z", dtypes.bool), UPat())),
|
||||
lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (z.cast(dtypes.uint8),), root.arg)),
|
||||
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat.var("z", dtypes.bool))),
|
||||
lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (z.cast(dtypes.uint8),), root.arg)),
|
||||
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat.var("g", dtypes.int))),
|
||||
lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (g.cast(dtypes.uint8),), root.arg)),
|
||||
# ptr_ar (load/store)
|
||||
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)),
|
||||
UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat.var("alu"), UPat.cvar("const")]))),
|
||||
lambda root, alu, const: UOp(root.op, root.dtype,
|
||||
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
|
||||
const*root.src[0].dtype.itemsize)+root.src[2:])),
|
||||
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), UPat.cvar("const"))),
|
||||
lambda root, const: UOp(root.op, root.dtype,
|
||||
(root.src[0].cast(dtypes.int64),
|
||||
UOp.const(dtypes.int64, const.arg*root.src[0].dtype.itemsize),)+root.src[2:])),
|
||||
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), UPat.var("alu"))),
|
||||
lambda root, alu: UOp(root.op, root.dtype,
|
||||
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
|
||||
UOp.const(dtypes.int64, 0))+root.src[2:])),
|
||||
# load/store bool -> uint8
|
||||
(UPat(UOps.LOAD, dtypes.bool, name="x"),
|
||||
lambda x: UOp(x.op, dtypes.uint8, x.src[0:2] + ((x.src[2].cast(dtypes.uint8),) if len(x.src) >= 3 else ()) + x.src[3:]).cast(dtypes.bool)),
|
||||
(UPat(UOps.STORE, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
|
||||
lambda x: UOp(x.op, dtypes.void, x.src[0:2] + (x.src[2].cast(dtypes.uint8),) + x.src[3:])),
|
||||
# load/store use pointer arithmetic
|
||||
(UPat((UOps.LOAD, UOps.STORE), name="x", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL), name="buf"),
|
||||
UPat.any(UPat.var("alu")+UPat.cvar("const"), UPat.cvar("const"), UPat.var("alu")))), load_store_ptr_arithmetic),
|
||||
])
|
||||
|
||||
class PTXRenderer(Renderer):
|
||||
|
||||
Reference in New Issue
Block a user