Remove ssa label (#4823)

* remove ssa label

* linting
This commit is contained in:
Szymon Ożóg
2024-06-04 16:51:05 +02:00
committed by GitHub
parent 052c928d06
commit b6895dabaa

View File

@@ -119,14 +119,6 @@ class PTXRenderer(Renderer):
if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
return f"%{prefix}{c[prefix]-1}"
c_label: DefaultDict[str, int] = defaultdict(int)
r_label: Dict[UOp, str] = {}
def ssa_label(prefix:str, u:UOp):
nonlocal c_label, r_label
c_label[prefix] += 1
r_label[u] = f"{self.label_prefix}{prefix}_{c_label[prefix]-1}"
return r_label[u]
def const(x:ConstType, dtype:DType, mov=False):
if mov or dtype in self.const_requires_mov:
kk(*self.render_const(x, dtype, mov=(out:=ssa('const', dtype=self.types[dtype]))))
@@ -144,14 +136,14 @@ class PTXRenderer(Renderer):
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
if uop is UOps.IF:
assert vin[0].dtype is not None
kk(*self.render_bra(ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True)))
kk(*self.render_bra(f"IF_{r[vin[0]][1:]}", _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True)))
elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
elif uop is UOps.ENDRANGE:
kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int]))
kk(*self.render_bra(r_label[vin[0]], pred))
kk(*self.render_bra(f"LOOP_{r[vin[0]][1:]}", pred))
elif uop is UOps.ENDIF:
kk(f"{r_label[vin[0]]}:")
kk(f"IF_{r[vin[0].vin[0]][1:]}:")
elif uop is UOps.STORE:
assert vin[0].dtype is not None and vin[2].dtype is not None
assert vin[0].dtype == dtypes.int64, "store isn't int64"
@@ -164,7 +156,7 @@ class PTXRenderer(Renderer):
kk(*self.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=mem_type, offset=vin[1].arg))
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.RANGE: kk(*self.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[vin[0]], "LOOP_"+loop[1:]))
elif uop is UOps.ALU:
assert vin[0].dtype is not None
if args is BinaryOps.CMPLT or args is BinaryOps.CMPNE: