mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-31 01:38:20 -05:00
@@ -119,14 +119,6 @@ class PTXRenderer(Renderer):
|
||||
if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
|
||||
return f"%{prefix}{c[prefix]-1}"
|
||||
|
||||
c_label: DefaultDict[str, int] = defaultdict(int)
|
||||
r_label: Dict[UOp, str] = {}
|
||||
def ssa_label(prefix:str, u:UOp):
|
||||
nonlocal c_label, r_label
|
||||
c_label[prefix] += 1
|
||||
r_label[u] = f"{self.label_prefix}{prefix}_{c_label[prefix]-1}"
|
||||
return r_label[u]
|
||||
|
||||
def const(x:ConstType, dtype:DType, mov=False):
|
||||
if mov or dtype in self.const_requires_mov:
|
||||
kk(*self.render_const(x, dtype, mov=(out:=ssa('const', dtype=self.types[dtype]))))
|
||||
@@ -144,14 +136,14 @@ class PTXRenderer(Renderer):
|
||||
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
|
||||
if uop is UOps.IF:
|
||||
assert vin[0].dtype is not None
|
||||
kk(*self.render_bra(ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True)))
|
||||
kk(*self.render_bra(f"IF_{r[vin[0]][1:]}", _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True)))
|
||||
elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
|
||||
elif uop is UOps.ENDRANGE:
|
||||
kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
|
||||
self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int]))
|
||||
kk(*self.render_bra(r_label[vin[0]], pred))
|
||||
kk(*self.render_bra(f"LOOP_{r[vin[0]][1:]}", pred))
|
||||
elif uop is UOps.ENDIF:
|
||||
kk(f"{r_label[vin[0]]}:")
|
||||
kk(f"IF_{r[vin[0].vin[0]][1:]}:")
|
||||
elif uop is UOps.STORE:
|
||||
assert vin[0].dtype is not None and vin[2].dtype is not None
|
||||
assert vin[0].dtype == dtypes.int64, "store isn't int64"
|
||||
@@ -164,7 +156,7 @@ class PTXRenderer(Renderer):
|
||||
kk(*self.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=mem_type, offset=vin[1].arg))
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.RANGE: kk(*self.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
|
||||
if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[vin[0]], "LOOP_"+loop[1:]))
|
||||
elif uop is UOps.ALU:
|
||||
assert vin[0].dtype is not None
|
||||
if args is BinaryOps.CMPLT or args is BinaryOps.CMPNE:
|
||||
|
||||
Reference in New Issue
Block a user