fix define global (#4383)

* fix define global

* remove name from DEFINE_GLOBAL

* fix fuzzing

* fix ptx

* fix python
This commit is contained in:
George Hotz
2024-05-01 19:32:56 -07:00
committed by GitHub
parent ad116dc5c6
commit f635c4d273
10 changed files with 48 additions and 42 deletions

View File

@@ -94,7 +94,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, _uops:UOpGraph) -> str
matcher.rewrite_graph(uops)
for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
uops.remove_childless(set(x for x in uops if x.uop in {UOps.DEFINE_GLOBAL, UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
uops.remove_childless(set(x for x in uops if x.uop in {UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
uops.optimize_loops()
def kk(*s: str): kernel.append("\n".join(s))
@@ -199,11 +199,11 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, _uops:UOpGraph) -> str
r[u] = f"%{args.expr}"
if lang.load_global: kk(*lang.render_load(args.expr, ssa('dat', u, lang.types[dtype]), dtype, ss=".param"))
elif uop is UOps.DEFINE_GLOBAL:
bufs.append((args[1], dtype))
r[u] = f"%{args[1]}"
bufs.append((nm:=f"data{args[0]}", dtype))
r[u] = f"%{nm}"
if lang.load_global:
dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
kk(*lang.render_load(args[1], ssa('dat', u, lang.types[dt]), dt, ss=".param"))
kk(*lang.render_load(nm, ssa('dat', u, lang.types[dt]), dt, ss=".param"))
elif uop is UOps.WMMA:
wmma = []
for vv in vin[:2]:

View File

@@ -161,9 +161,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
bufs.append((args.expr, (dtype,False)))
r[u] = args.expr
elif uop is UOps.DEFINE_GLOBAL:
assert len(bufs) == args[0], f"missed a global buffer {len(bufs)} {args}"
bufs.append((args[1], (dtype,args[2])))
r[u] = args[1]
bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
r[u] = nm
elif uop is UOps.WMMA: kk(f"{lang.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
elif uop is UOps.DEFINE_ACC: kk(f"{lang.render_dtype(dtype)} {ssa('acc',u)} = {lang.render_const(args, dtype)};")
elif uop is UOps.CONST: r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"