add name uop (#9149)

* add name uop, TODO: refactor renderer to use it

* renderer uses name uop

* fix tests

* render

* ptx
This commit is contained in:
George Hotz
2025-02-18 15:26:58 +08:00
committed by GitHub
parent 2db8b4046a
commit a4dab3ec3f
18 changed files with 49 additions and 31 deletions

View File

@@ -52,7 +52,7 @@ class UOpsFuzzerRunner(CompiledRunner):
# setup prg
uops = list(path)
if DEBUG >= 5: print_uops(uops)
self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.device].renderer.render(name, uops), uops=uops)
self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.device].renderer.render(uops), uops=uops)
if DEBUG >= 4: print(self.p.src)
self.lib = Device[self.p.device].compiler.compile_cached(self.p.src)
self.clprg = Device[self.p.device].runtime(name, self.lib)

View File

@@ -40,7 +40,7 @@ def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str) ->
k = Kernel(ast, opts=opts)
for opt in applied_opts: k.apply_opt(opt)
# NOTE: replay with the captured renderer, not the one in master
return k.opts.render(name, cast(list,k.to_program().uops))
return k.opts.render(cast(list,k.to_program(name).uops))
# *** diff a "good" recreation against the generated version

View File

@@ -60,6 +60,6 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None):
allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + buf_dt.fmt, *data)))
g = UOp(Ops.DEFINE_GLOBAL, uop.dtype.ptr(), arg=0, src=())
rw = full_graph_rewrite(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(), PythonRenderer)
prog = PythonProgram("run", PythonCompiler().compile(PythonRenderer().render("run", linearize_uop(rw))))
prog = PythonProgram("run", PythonCompiler().compile(PythonRenderer().render(linearize_uop(rw))))
prog(out_buf:=allocator.alloc(uop.dtype.itemsize), *bufs)
return out_buf.cast(uop.dtype.fmt).tolist()[0]

View File

@@ -6,7 +6,7 @@ class TestDeviceSpeed(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.dev = Device[Device.DEFAULT]
cls.empty = Device[Device.DEFAULT].renderer.render("test", [])
cls.empty = Device[Device.DEFAULT].renderer.render([])
def test_empty_compile(self):
with Timing("compiler "):

View File

@@ -952,7 +952,7 @@ class TestLinearizer(unittest.TestCase):
sink = UOp(Ops.SINK, src=(store,))
lin = Kernel(sink)
lin.linearize()
assert len(lin.uops) <= 9, "too many uops"
assert len(lin.uops) <= 10, "too many uops"
def test_upcast_cse(self):
# when upcasting, within a subtree, there may be common expressions.

View File

@@ -22,7 +22,7 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
inbufs = [cast(UOp,x.lazydata).base.buffer for x in inputs]
src = Device[Device.DEFAULT].renderer.render("test", uops)
src = Device[Device.DEFAULT].renderer.render(uops)
ei = CompiledRunner(ProgramSpec("test", src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size))
ei.exec(outbufs+inbufs)
return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]

View File

@@ -802,7 +802,7 @@ class TestIdxUpcast(unittest.TestCase):
if s.ast.op is Ops.SINK:
renderer = Device[s.bufs[0].device].renderer
uops = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(s.ast, renderer), renderer))
renderer.render("test", uops)
renderer.render(uops)
return uops
def _assert(self, dtype: DType, a: Tensor):

View File

@@ -22,7 +22,7 @@ def to_uops_list(u:list[UOp], opts=None, skip_check=False) -> list[UOp]: return
def _uops_to_prg(uops_list):
uops = linearize_uop(full_graph_rewrite(ast:=UOp.sink(*uops_list), opts=Device[Device.DEFAULT].renderer))
src = Device[Device.DEFAULT].renderer.render("test", uops)
src = Device[Device.DEFAULT].renderer.render(uops)
has_local = Device[Device.DEFAULT].renderer.has_local
return CompiledRunner(ProgramSpec("test", src, Device.DEFAULT, ast, uops=uops,
global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None))
@@ -343,7 +343,7 @@ class TestAssembly(unittest.TestCase):
a1 = UOp(Ops.MUL, dtypes.int, (l1, c1))
a2 = UOp(Ops.MUL, dtypes.int, (l1, c2))
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
Device[Device.DEFAULT].renderer.render(uops)
ops = [x.op for x in uops]
self.assertIn(Ops.SHL, ops)
self.assertIn(Ops.MUL, ops)
@@ -356,7 +356,7 @@ class TestAssembly(unittest.TestCase):
a1 = UOp(Ops.IDIV, dtypes.uint, (l1, c1))
a2 = UOp(Ops.IDIV, dtypes.uint, (l1, c2))
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
Device[Device.DEFAULT].renderer.render(uops)
ops = [x.op for x in uops]
self.assertIn(Ops.SHR, ops)
self.assertIn(Ops.IDIV, ops)

View File

@@ -584,7 +584,7 @@ class Kernel:
num = f"n{Kernel.kernel_cnt[function_name]-1}" if Kernel.kernel_cnt[function_name] > 1 else ""
return name + colored(num, 'BLACK')
def get_optimized_ast(self) -> UOp:
def get_optimized_ast(self, name_override:Optional[str]=None) -> UOp:
@functools.lru_cache(None)
def fixup_ast(op:UOp) -> UOp:
ret = op.replace(src=tuple(fixup_ast(x) for x in op.src))
@@ -594,7 +594,9 @@ class Kernel:
if op.op is Ops.CONST and any(v.mask is not None for v in unwrap(st_uop.st).views): return op.valid(unwrap(st_uop.st))
# otherwise we just replace the VIEW source
return ret.replace(src=(st_uop,)) if len(op.src) == 1 else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
if op.op is Ops.SINK: return ret.replace(arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals))
if op.op is Ops.SINK:
return ret.replace(arg = KernelInfo(to_function_name(self.name) if name_override is None else name_override,
self.local_dims, self.upcasted, self.dont_use_locals))
if op.op is Ops.REDUCE_AXIS:
reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
@@ -664,8 +666,8 @@ class Kernel:
# **** this is the lowerer ****
@track_rewrites()
def linearize(self) -> Kernel:
modified_ast = self.get_optimized_ast()
def linearize(self, name_override:Optional[str]=None) -> Kernel:
modified_ast = self.get_optimized_ast(name_override)
if DEBUG >= 3:
print(self.name)
@@ -683,16 +685,17 @@ class Kernel:
return self
def to_program(self, name_override:Optional[str]=None) -> ProgramSpec:
self.linearize()
src = self.opts.render(name:=to_function_name(ansiname:=(name_override if name_override is not None else self.name)), self.uops)
self.linearize(name_override)
assert self.uops[0].op is Ops.NAME, "first uop must be name"
src = self.opts.render(self.uops)
if CAPTURE_PROCESS_REPLAY:
diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, ContextVar._cache, src))
diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, self.uops[0].arg, ContextVar._cache, src))
# group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
# TODO: these max and min don't work on symbolic, and results are very wrong.
mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
for _, group in itertools.groupby([x for x in self.ast.toposort if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
key=lambda x: (x.op, x.src[0].arg)))
return ProgramSpec(ansiname, src, self.opts.device, self.ast, self.uops, mem_estimate=mem_bytes,
return ProgramSpec(self.uops[0].arg, src, self.opts.device, self.ast, self.uops, mem_estimate=mem_bytes,
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)

View File

@@ -6,7 +6,7 @@ from tinygrad.spec import type_verify
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.helpers import dedup, flatten, partition
DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, *GroupOp.Block}
DONT_PLACE_IN_BLOCK = {Ops.NAME, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, *GroupOp.Block}
def disp(y:UOp) -> str:
if y.op is Ops.BLOCKSTART: return "w"+disp(y.src[0])
@@ -70,7 +70,8 @@ def append_to_block(ctx:tuple[dict[UOp, tuple[UOp, ...]], dict[UOp, list[UOp]]],
return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(list(old_blocks.values())+new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst))
make_basic_blocks = PatternMatcher([
(UPat(Ops.SINK, name="x"), lambda x: UOp(Ops.BLOCK, src=x.src, arg=BasicBlock((), (x,)))),
(UPat(Ops.SINK, name="x"),
lambda x: UOp(Ops.BLOCK, src=x.src+((UOp(Ops.NAME, arg=x.arg.name),) if x.arg is not None else ()), arg=BasicBlock((), (x,)))),
(UPat(Ops.BLOCK, name="x"), append_to_block),
])
@@ -117,7 +118,8 @@ def block_finalize(block:UOp):
_uops = sorted(dedup(block.src), key=lambda x: x.tuplize)
assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
_uops += block.arg.lst
assert _uops[-1].op is Ops.SINK, "block doesn't end with SINK"
# strip the SINK
assert _uops[-1].op is Ops.SINK, "doesn't end with SINK"
return UOp(Ops.BLOCK, arg=BasicBlock((), tuple(_uops[:-1])))
pm_block_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="block"), block_finalize)])

View File

@@ -93,7 +93,7 @@ class MathTrait(SimpleMathTrait):
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
# uops that aren't rendered
SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); KERNEL = auto() # noqa: E702
NAME = auto(); SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); KERNEL = auto() # noqa: E702
# TODO: empty continues to exist because of tensor
EMPTY = auto()
@@ -669,6 +669,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@dataclass(frozen=True)
class KernelInfo:
name: str = "test" # name of the kernel
local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
upcasted: int = 0 # count that are upcasted (this is remapping RANGE to UNROLL)
dont_use_locals: bool = False # don't use local indexing

View File

@@ -131,4 +131,4 @@ class Renderer:
code_for_op: dict[Ops, Callable] = {}
def __reduce__(self): return self.__class__, ()
def render(self, name:str, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")
def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")

View File

@@ -120,7 +120,7 @@ class CStyleLanguage(Renderer):
return self.type_map.get(scalar:=dt.scalar(), scalar.name)
def __getitem__(self, key): return self.r[key] # hacky helper
def render(self, name:str, uops:list[UOp]) -> str:
def render(self, uops:list[UOp]) -> str:
r: dict[UOp, str] = {}
self.r = r
@@ -129,7 +129,11 @@ class CStyleLanguage(Renderer):
kernel = []
depth = 1
c: defaultdict[str, int] = defaultdict(int)
name = "test"
for u in uops:
if u.op is Ops.NAME:
name = u.arg
continue
if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
bufs[u] = (r[u], (u.dtype, False))

View File

@@ -127,7 +127,7 @@ class LLVMRenderer(Renderer):
(UPat(Ops.CAST, name="root", src=(UPat.load(UPat.index(UPat.var("buf"), UPat.var("idx")), dtype=dtypes.bfloat16),)), llvm_bf16_cast),
])
def render(self, name: str, uops: list[UOp]) -> str:
def render(self, uops: list[UOp]) -> str:
r: dict[UOp, str] = {}
args: list[str] = []
kernel: list[str] = []
@@ -148,7 +148,11 @@ class LLVMRenderer(Renderer):
kernel += [f" {r[u]}_amx{i} = alloca {ldt(dtype)}, align {dtype.itemsize}",
f" {r[u]}_ptr_amx{i} = ptrtoint {ldt(dtype.ptr())} {r[u]}_amx{i} to i64"]
name = "test"
for u in uops:
if u.op is Ops.NAME:
name = u.arg
continue
if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
# NOTE: MallocAllocator promises 0x20 alignment

View File

@@ -154,7 +154,7 @@ class PTXRenderer(Renderer):
params = ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs])
return f"{self.kernel_prefix} {function_name}(\n\t{params}\n)\n{{\n{kernel}\n}}"
def render(self, name:str, uops:list[UOp]) -> str:
def render(self, uops:list[UOp]) -> str:
kernel:list[str] = []
bufs = []
@@ -169,7 +169,11 @@ class PTXRenderer(Renderer):
c[prefix] += 1
return f"%{prefix}{c[prefix]-1}"
name = "test"
for u in uops:
if u.op is Ops.NAME:
name = u.arg
continue
if u.op is Ops.VECTORIZE:
r[u] = [cast(str,r[x]) for x in u.src]
continue

View File

@@ -40,7 +40,7 @@ class PythonProgram:
loop_ends: dict[int, int] = {}
while i < len(self.uops):
uop, dtype, idp, arg = self.uops[i]
void_ops = {Ops.STORE, Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF}
void_ops = {Ops.STORE, Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.NAME}
if uop is Ops.DEFINE_ACC: idp = [idp[0]]
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
@@ -60,7 +60,7 @@ class PythonProgram:
loop_ends[idp[0]] = i
i = idp[0]
continue
if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF):
if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.NAME):
# in the python emulator, the warp is always in sync
i += 1
continue
@@ -196,7 +196,7 @@ class PythonRenderer(Renderer):
if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", IntelRenderer.tensor_cores
if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CLANG", ClangRenderer.tensor_cores
def render(self, name:str, uops:list[UOp]) -> str:
def render(self, uops:list[UOp]) -> str:
lops = [(u.op, u.dtype, [uops.index(v) for v in u.src], u.arg) for u in uops]
return base64.b64encode(pickle.dumps(lops)).decode()

View File

@@ -110,7 +110,7 @@ spec = PatternMatcher([
# NOTE: for testing, we let sinks be anything
#(UPat(Ops.SINK, src=UPat(Ops.STORE)), lambda: True),
(UPat(Ops.SINK, dtypes.void), lambda: True),
(UPat((Ops.NAME, Ops.SINK), dtypes.void), lambda: True),
(UPat((Ops.NOOP, Ops.CUSTOM)), lambda: True),
# PTX LOAD/STORE

View File

@@ -14,7 +14,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0"}
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.NAME:"#808080"}
# VIZ API