mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fix define global (#4383)
* fix define global * remove name from DEFINE_GLOBAL * fix fuzzing * fix ptx * fix python
This commit is contained in:
@@ -94,7 +94,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, _uops:UOpGraph) -> str
|
||||
matcher.rewrite_graph(uops)
|
||||
|
||||
for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
|
||||
uops.remove_childless(set(x for x in uops if x.uop in {UOps.DEFINE_GLOBAL, UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
|
||||
uops.remove_childless(set(x for x in uops if x.uop in {UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
|
||||
uops.optimize_loops()
|
||||
|
||||
def kk(*s: str): kernel.append("\n".join(s))
|
||||
@@ -199,11 +199,11 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, _uops:UOpGraph) -> str
|
||||
r[u] = f"%{args.expr}"
|
||||
if lang.load_global: kk(*lang.render_load(args.expr, ssa('dat', u, lang.types[dtype]), dtype, ss=".param"))
|
||||
elif uop is UOps.DEFINE_GLOBAL:
|
||||
bufs.append((args[1], dtype))
|
||||
r[u] = f"%{args[1]}"
|
||||
bufs.append((nm:=f"data{args[0]}", dtype))
|
||||
r[u] = f"%{nm}"
|
||||
if lang.load_global:
|
||||
dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
|
||||
kk(*lang.render_load(args[1], ssa('dat', u, lang.types[dt]), dt, ss=".param"))
|
||||
kk(*lang.render_load(nm, ssa('dat', u, lang.types[dt]), dt, ss=".param"))
|
||||
elif uop is UOps.WMMA:
|
||||
wmma = []
|
||||
for vv in vin[:2]:
|
||||
|
||||
@@ -161,9 +161,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
|
||||
bufs.append((args.expr, (dtype,False)))
|
||||
r[u] = args.expr
|
||||
elif uop is UOps.DEFINE_GLOBAL:
|
||||
assert len(bufs) == args[0], f"missed a global buffer {len(bufs)} {args}"
|
||||
bufs.append((args[1], (dtype,args[2])))
|
||||
r[u] = args[1]
|
||||
bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
|
||||
r[u] = nm
|
||||
elif uop is UOps.WMMA: kk(f"{lang.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
|
||||
elif uop is UOps.DEFINE_ACC: kk(f"{lang.render_dtype(dtype)} {ssa('acc',u)} = {lang.render_const(args, dtype)};")
|
||||
elif uop is UOps.CONST: r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
|
||||
|
||||
Reference in New Issue
Block a user