mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
minor PTX matcher cleanup [run_process_replay] (#5336)
* minor PTX matcher cleanup [run_process_replay] uop.cast syntactic sugar and some newline/space cleanup * comment
This commit is contained in:
@@ -33,8 +33,8 @@ class PTXRenderer(Renderer):
|
||||
gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
|
||||
lid = [f'%tid.{chr(120+i)}' for i in range(3)]
|
||||
asm_for_op: Dict[Op, Callable] = {
|
||||
UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"sub.{name} {d}, 0, {a};" if dtypes.is_unsigned(dt) \
|
||||
else f"neg.{name} {d}, {a};",
|
||||
UnaryOps.NEG: lambda d,a,dt,name:
|
||||
f"not.pred {d}, {a};" if name == "pred" else f"sub.{name} {d}, 0, {a};" if dtypes.is_unsigned(dt) else f"neg.{name} {d}, {a};",
|
||||
UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
|
||||
UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
|
||||
UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
|
||||
@@ -42,18 +42,16 @@ class PTXRenderer(Renderer):
|
||||
BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
|
||||
BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
|
||||
BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
|
||||
BinaryOps.AND: lambda d, a, b, dt, name: f"and.pred {d}, {a}, {b};" if name == "pred" else f"and.b{name[1:]} {d}, {a}, {b};",
|
||||
BinaryOps.OR: lambda d, a, b, dt, name: f"or.pred {d}, {a}, {b};" if name == "pred" else f"or.b{name[1:]} {d}, {a}, {b};",
|
||||
BinaryOps.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if name == "pred" else f"and.b{name[1:]} {d}, {a}, {b};",
|
||||
BinaryOps.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if name == "pred" else f"or.b{name[1:]} {d}, {a}, {b};",
|
||||
BinaryOps.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
|
||||
BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
|
||||
BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
|
||||
BinaryOps.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
|
||||
BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", BinaryOps.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
|
||||
TernaryOps.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
|
||||
TernaryOps.WHERE: lambda d,a,b,c,dt,name:
|
||||
f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
|
||||
}
|
||||
supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
|
||||
TernaryOps.WHERE]
|
||||
supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
|
||||
# HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
|
||||
types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
|
||||
dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
|
||||
@@ -238,24 +236,24 @@ ptx_matcher = PatternMatcher([
|
||||
(UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
|
||||
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
|
||||
lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
|
||||
(UPat(UOps.ALU, BinaryOps.ADD,
|
||||
[UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, [UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
|
||||
lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
|
||||
*[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
|
||||
lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.op, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.src]), x.arg),)))
|
||||
lambda x: (UOp(x.op, dtypes.float32, tuple([vv.cast(dtypes.float32) for vv in x.src]), x.arg).cast(dtypes.half)))
|
||||
for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
|
||||
(UPat(UOps.ALU, name="x", dtype=dtypes.bool, arg=BinaryOps.MAX),
|
||||
lambda x: UOp(UOps.CAST, dtypes.bool, (UOp(UOps.ALU, dtypes.uint8, tuple(UOp(UOps.CAST, dtypes.uint8, (s,)) for s in x.src), x.arg),))),
|
||||
lambda x: UOp(UOps.ALU, dtypes.uint8, tuple(s.cast(dtypes.uint8) for s in x.src), x.arg).cast(dtypes.bool)),
|
||||
# TODO: this one looks sketchy, root.arg is applied to outer cast?
|
||||
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
|
||||
lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
|
||||
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
|
||||
lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.uint8, root.src, root.arg),))),
|
||||
lambda root: UOp(root.op, dtypes.uint8, root.src, root.arg).cast(dtypes.bool)),
|
||||
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
|
||||
lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
|
||||
lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (z.cast(dtypes.uint8),), root.arg)),
|
||||
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
|
||||
lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
|
||||
lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (z.cast(dtypes.uint8),), root.arg)),
|
||||
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
|
||||
lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
|
||||
lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (g.cast(dtypes.uint8),), root.arg)),
|
||||
# ptr_ar (load/store)
|
||||
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
|
||||
UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
|
||||
@@ -264,12 +262,12 @@ ptx_matcher = PatternMatcher([
|
||||
UOp.const(const.dtype, root.src[0].dtype.itemsize)*const)+root.src[2:])),
|
||||
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
|
||||
UPat(UOps.CONST, name="const"))),
|
||||
lambda root, const: UOp(root.op, root.dtype, (root.src[0].cast(dtypes.int64),
|
||||
UOp.const(dtypes.int64, const.arg * root.src[0].dtype.itemsize),
|
||||
)+root.src[2:])),
|
||||
lambda root, const: UOp(root.op, root.dtype,
|
||||
(root.src[0].cast(dtypes.int64),
|
||||
UOp.const(dtypes.int64, const.arg*root.src[0].dtype.itemsize),)+root.src[2:])),
|
||||
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
|
||||
UPat(name="alu"))), # no const here
|
||||
lambda root, alu: UOp(root.op, root.dtype,
|
||||
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
|
||||
UOp.const(dtypes.int64, 0))+root.src[2:])),
|
||||
UOp.const(dtypes.int64, 0))+root.src[2:])),
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user