ptx indexing (#7359)

* ptx indexing

* shorter

* fix load/store
Author: George Hotz
Date: 2024-10-29 17:29:44 +07:00
Committed by: GitHub
Parent: 572499c71a
Commit: 0beb2d8f84
2 changed files with 19 additions and 30 deletions


@@ -487,7 +487,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
   if ignore_indexing:
     for u in uops:
       if u.op in {UOps.LOAD, UOps.STORE}:
-        offset = 0 if u.src[0].op not in {UOps.INDEX, UOps.CAST} else -1
+        offset = 0 if u.src[0].op not in {UOps.INDEX, UOps.CAST} and u.src[0].dtype != dtypes.int64 else -1
         dont_count = dont_count.union(u.src[offset+1].sparents)
         if len(u.src) > offset+3: dont_count = dont_count.union(u.src[offset+3].sparents)
       elif u.op is UOps.IF:
@@ -503,7 +503,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
     elif u.op is UOps.LOAD:
       mem += u.dtype.itemsize * mults
     elif u.op is UOps.STORE:
-      mem += u.src[2 if u.src[0].op not in {UOps.INDEX, UOps.CAST} else 1].dtype.itemsize * mults
+      mem += u.src[2 if u.src[0].op not in {UOps.INDEX, UOps.CAST} and u.src[0].dtype != dtypes.int64 else 1].dtype.itemsize * mults
     elif u.op is UOps.ALU and u not in dont_count:
       flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
     elif u.op is UOps.WMMA and u not in dont_count:
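Note: the flops_mem change above is needed because PTX loads/stores now carry a flat int64 address as src[0] instead of a separate (buffer, index) pair, so when ignore_indexing is set the address computation has to be excluded the same way an INDEX/CAST source already was. A standalone sketch of the offset selection (plain Python, not tinygrad code; op and dtype names are stand-ins):

```python
# Standalone sketch (not tinygrad code): how the source offset is picked when
# ignoring indexing math. A LOAD/STORE whose first source is already a flat
# int64 address (what the PTX matcher now produces) is treated like one rooted
# at INDEX/CAST: the address expression itself is what gets excluded.
INDEXING_OPS = {"INDEX", "CAST"}  # stand-ins for UOps.INDEX / UOps.CAST

def index_src_offset(src0_op: str, src0_dtype: str) -> int:
    """Return 0 when src[1] holds the index, -1 when src[0] is the address itself."""
    return 0 if src0_op not in INDEXING_OPS and src0_dtype != "int64" else -1

assert index_src_offset("DEFINE_GLOBAL", "ptr") == 0  # classic (buf, idx, ...) layout
assert index_src_offset("INDEX", "ptr") == -1         # address folded into src[0]
assert index_src_offset("ALU", "int64") == -1         # new PTX-style flat address
```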


@@ -33,12 +33,6 @@ asm_for_op: Dict[Op, Callable] = {
f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
}
def load_store_ptr_arithmetic(x:UOp, buf:UOp, alu:Optional[UOp]=None, const:Optional[UOp]=None) -> UOp:
src = list(x.src)
src[0] = buf.cast(dtypes.int64) if alu is None else (buf.cast(dtypes.int64) + alu.cast(dtypes.int64)*buf.dtype.itemsize)
src[1] = UOp.const(dtypes.int64, 0 if const is None else const.arg*buf.dtype.itemsize)
return x.replace(src=tuple(src))
supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
ptx_matcher = sym+PatternMatcher([
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
@@ -49,13 +43,13 @@ ptx_matcher = sym+PatternMatcher([
      lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half)))
     for op in asm_for_op.keys() if op not in supports_half],
   # load/store bool -> uint8
-  (UPat(UOps.LOAD, dtypes.bool, name="x"),
-   lambda x: UOp(x.op, dtypes.uint8, x.src[0:2] + ((x.src[2].cast(dtypes.uint8),) if len(x.src) >= 3 else ()) + x.src[3:]).cast(dtypes.bool)),
-  (UPat(UOps.STORE, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
-   lambda x: UOp(x.op, dtypes.void, x.src[0:2] + (x.src[2].cast(dtypes.uint8),) + x.src[3:])),
-  # load/store use pointer arithmetic
-  (UPat((UOps.LOAD, UOps.STORE), name="x", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL), name="buf"),
-    UPat.any(UPat.var("alu")+UPat.cvar("const"), UPat.cvar("const"), UPat.var("alu")))), load_store_ptr_arithmetic),
+  (UPat(UOps.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True),
+   lambda x: UOp(x.op, dtypes.uint8, x.src[0:1] + ((x.src[1].cast(dtypes.uint8),) if len(x.src) >= 2 else ()) + x.src[2:]).cast(dtypes.bool)),
+  (UPat(UOps.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
+   lambda x: UOp(x.op, dtypes.void, x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])),
+  # load/store use pointer arithmetic, and the cast does nothing
+  (UPat(UOps.INDEX, name="x"), lambda x: x.src[0].cast(dtypes.int64) + x.src[1].cast(dtypes.int64)*x.src[0].dtype.itemsize),
+  (UPat(UOps.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
 ])
 class PTXRenderer(Renderer):
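Note: the removed load_store_ptr_arithmetic helper folded the base pointer plus an ALU index into src[0] and a constant byte offset into src[1]. The new patterns instead rewrite INDEX directly into int64 pointer arithmetic and drop the now-redundant pointer CAST. A standalone sketch of the arithmetic the INDEX rule performs (not tinygrad code; addresses and sizes are illustrative):

```python
# Standalone sketch (not tinygrad code): INDEX(buf, idx) on a buffer with element
# size `itemsize` becomes a flat int64 byte address; a later CAST that only
# changes the pointer type is dropped by the CAST rule.
def lower_index(buf_addr: int, idx: int, itemsize: int) -> int:
    """Byte address of element `idx` in a buffer starting at `buf_addr`."""
    return buf_addr + idx * itemsize

# e.g. element 5 of a float32 buffer at address 0x1000 lives at 0x1000 + 5*4
assert lower_index(0x1000, 5, 4) == 0x1014
```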
@@ -65,6 +59,7 @@ class PTXRenderer(Renderer):
   tensor_cores = [tc for tc in CUDARenderer.tensor_cores if tc.dtype_in == dtypes.half]
   code_for_op = asm_for_op
   extra_matcher = ptx_matcher
+  indexing = True
   def __init__(self, arch:str, device="CUDA"):
     self.device, self.tensor_cores, self.arch = device, PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else [], arch
   def __reduce__(self): return self.__class__, (self.arch, self.device)
@@ -104,9 +99,6 @@ class PTXRenderer(Renderer):
     if gate: return [f"@{gate} ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
     return [f"ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];"]
-  def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
-    return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_types[dtype]} [{loc}+{offset}], {val};"]
   def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
     if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"]
     if atype == dtypes.bool: return [f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"]
@@ -165,14 +157,12 @@ class PTXRenderer(Renderer):
kk(f"IF_{r[src[0].src[0]][1:]}_{uops.index(src[0])}:")
elif uop is UOps.STORE:
assert src[0].dtype == dtypes.int64, "store isn't int64"
assert src[1].op is UOps.CONST, f"store isn't const {u}"
mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
if src[2].dtype.count > 1:
kk((f"@{r[src[3]]} " if len(src)>3 else "") + \
f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};")
gate = f"@{r[src[2]]} " if len(src)>2 and src[2].op is not UOps.IF else ""
if src[1].dtype.count > 1:
kk(gate + f"st{mem_type}.v{src[1].dtype.count}.{self.mem_types[src[1].dtype.scalar()]} [{r[src[0]]}+0], {{{', '.join(r[src[1]])}}};")
else:
kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype,
gate=r[src[3]] if len(src)>3 and src[3].op is not UOps.IF else None, ss=mem_type, offset=src[1].arg))
kk(gate + f"st{mem_type}.{self.mem_types[src[1].dtype]} [{r[src[0]]}+0], {r[src[1]]};")
else:
if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
elif uop is UOps.ALU:
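Note: with the address already folded into src[0], the STORE branch now emits the st instruction inline (the render_store helper above is gone), the byte offset is always +0, and the gate comes from src[2] for both the scalar and vector forms. A standalone sketch of the strings this branch produces, with made-up register names:

```python
# Standalone sketch (not tinygrad code): the PTX store lines the new branch emits.
# Register names and memory spaces are illustrative only.
from typing import List, Optional

def ptx_store(mem_type: str, mem_dtype: str, addr: str, val: str, gate: Optional[str] = None) -> str:
    g = f"@{gate} " if gate else ""                              # optional predicate guard
    return g + f"st{mem_type}.{mem_dtype} [{addr}+0], {val};"    # offset is always +0 now

def ptx_vector_store(mem_type: str, scalar_dtype: str, addr: str, vals: List[str], gate: Optional[str] = None) -> str:
    g = f"@{gate} " if gate else ""
    return g + f"st{mem_type}.v{len(vals)}.{scalar_dtype} [{addr}+0], {{{', '.join(vals)}}};"

assert ptx_store(".global", "f32", "%rd4", "%f7") == "st.global.f32 [%rd4+0], %f7;"
assert ptx_vector_store(".shared", "u32", "%rd2", ["%r1", "%r2"], gate="%p1") == \
    "@%p1 st.shared.v2.u32 [%rd2+0], {%r1, %r2};"
```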
@@ -198,18 +188,17 @@ class PTXRenderer(Renderer):
           r[u] = r[src[0]][u.arg[0]]
         elif uop is UOps.LOAD:
           assert src[0].dtype == dtypes.int64, "load isn't int64"
-          assert src[1].op is UOps.CONST, f"load isn't const {u}"
           mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
-          has_gate = len(src) > 3 and src[3].op is UOps.ALU
+          has_gate = len(src) > 2 and src[2].op is UOps.ALU
           if dtype.count > 1:
             r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
             if has_gate:
               for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
-            kk((f"@{r[src[3]]}"if has_gate else "")
-              + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+{src[1].arg}];")
+            kk((f"@{r[src[2]]}" if has_gate else "")
+              + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+0];")
           else:
-            kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[3]] if has_gate else None,
-                                 alt=r[src[2]] if has_gate else None, ss=mem_type, offset=src[1].arg))
+            kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if has_gate else None,
+                                 alt=r[src[1]] if has_gate else None, ss=mem_type, offset=0))
         elif uop is UOps.ASSIGN:
           if dtype.count > 1:
             for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
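Note: the LOAD path mirrors the store change: render_load is still used, but the gate moves from src[3] to src[2], the alt value from src[2] to src[1], and the offset is fixed at 0. A standalone sketch of the gated scalar load it renders (not tinygrad code; register names are illustrative):

```python
# Standalone sketch (not tinygrad code): a gated scalar load. When a gate is
# present, the alternative value is moved in on the false path.
from typing import List, Optional

def ptx_load(mem_type: str, mem_dtype: str, bits: int, dest: str, addr: str,
             gate: Optional[str] = None, alt: Optional[str] = None) -> List[str]:
    if gate:
        return [f"@{gate} ld{mem_type}.{mem_dtype} {dest}, [{addr}+0];",
                f"@!{gate} mov.b{bits} {dest}, {alt};"]
    return [f"ld{mem_type}.{mem_dtype} {dest}, [{addr}+0];"]

assert ptx_load(".global", "f32", 32, "%f3", "%rd4", gate="%p1", alt="%f9") == [
    "@%p1 ld.global.f32 %f3, [%rd4+0];", "@!%p1 mov.b32 %f3, %f9;"]
```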