From 3d3c5b2fb984ad0feee32fdedd7fee52c0fe4c3f Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 23 Dec 2025 16:15:33 -0500 Subject: [PATCH] add device to program (#13815) * add device to program * from_uop * from_uop no renderer * simpler global_size --- tinygrad/codegen/__init__.py | 16 ++++++++++------ tinygrad/codegen/opt/search.py | 2 +- tinygrad/engine/jit.py | 2 +- tinygrad/engine/realize.py | 31 +++++++------------------------ tinygrad/renderer/__init__.py | 12 ++++++++++-- tinygrad/runtime/ops_dsp.py | 4 ++-- tinygrad/uop/spec.py | 8 ++++---- 7 files changed, 35 insertions(+), 40 deletions(-) diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index f59aee5592..038e4d683d 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -1,7 +1,7 @@ from typing import cast import itertools -from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC -from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat +from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG +from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, print_uops from tinygrad.uop.spec import type_verify, program_spec, kernel_spec from tinygrad.renderer import Renderer from tinygrad.dtype import dtypes, PtrDType @@ -133,11 +133,15 @@ def do_render(ctx:Renderer, prg:UOp, lin:UOp) -> UOp: return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),)) pm_to_program = PatternMatcher([ - (UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"),), name="prg"), do_linearize), - (UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render), + (UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"), UPat(Ops.DEVICE)), name="prg"), do_linearize), + (UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render), ]) def full_rewrite_to_program(sink:UOp, ren:Renderer) -> UOp: + from tinygrad.uop.ops import KernelInfo + if sink.arg is None: sink = sink.replace(arg=KernelInfo()) full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None) - sink = UOp(Ops.PROGRAM, src=(full_sink,)) - return graph_rewrite(sink, pm_to_program, ctx=ren, name="linearize/render") + sink = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=ren.device))) + prg = graph_rewrite(sink, pm_to_program, ctx=ren, name="linearize/render") + if DEBUG >= 6: print_uops(list(prg.src[2].src)) # LINEAR is src[2] + return prg diff --git a/tinygrad/codegen/opt/search.py b/tinygrad/codegen/opt/search.py index 084352c486..6dc6203ccf 100644 --- a/tinygrad/codegen/opt/search.py +++ b/tinygrad/codegen/opt/search.py @@ -37,7 +37,7 @@ def get_test_global_size(global_size, max_global_size, var_vals): def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:list[Buffer], early_stop:float|None=None, allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test") -> list[float]: factor = 1 - if allow_test_size and p.global_size is not None and max_global_size is not None: + if allow_test_size and max_global_size is not None: global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals) p = replace(p, global_size=global_size) try: car = CompiledRunner(p, precompiled=lib) diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 89a652ed96..146472bcd0 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -97,7 +97,7 @@ class GraphRunner(Runner): global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size) if global_dim_idx is not None or local_dim_idx is not None: self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx) - assert ji.prg.p.global_size is not None and ji.prg.p.local_size is not None + assert ji.prg.p.local_size is not None self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size)) # used in MultiGraphRunner. the ints are id() of _bufs diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index 52898f88a8..55755d8b4d 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, replace, field from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context from tinygrad.helpers import unwrap -from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender +from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, track_rewrites, KernelInfo, pyrender from tinygrad.device import Device, Buffer from tinygrad.renderer import Renderer, ProgramSpec, Estimates from tinygrad.codegen import full_rewrite_to_program @@ -25,25 +25,15 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program The ProgramSpec of the program. """ - if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST") - if DEBUG >= 5: print(pyrender(ast)) - # linearize if opts is not None: assert ast.arg is None, "can't apply opts if sink has an arg" ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts))) - if ast.arg is None: ast = ast.replace(arg=KernelInfo()) - prg = full_rewrite_to_program(ast, renderer) - # SINK/LINEAR/SOURCE - sink, linear, source = prg.src + if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST") + if DEBUG >= 5: print(pyrender(ast)) - # print - if DEBUG >= 6: print_uops(list(linear.src)) - - return ProgramSpec(sink.arg.name, source.arg, renderer.device, sink, list(linear.src), - global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None, - local_size=[1,1,1] if renderer.has_local else None) + return ProgramSpec.from_uop(full_rewrite_to_program(ast, renderer)) # **************** Runners **************** @@ -88,20 +78,13 @@ class CompiledRunner(Runner): def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False) -> float|None: if var_vals is None: var_vals = {} - has_local = Device[self.p.device].renderer.has_local global_size, local_size = self.p.launch_dims(var_vals) - if has_local and global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type] + if Device[self.p.device].renderer.has_local and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type] local_size = optimize_local_size(self._prg, global_size, rawbufs) global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)] self.p = replace(self.p, global_size=global_size, local_size=local_size) - lra = {} - if global_size: - lra['global_size'] = tuple(global_size) - assert len(global_size) == 3, "global size must have len 3" - if local_size: - lra['local_size'] = tuple(local_size) - assert len(local_size) == 3, "local size must have len 3" - return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k.expr] for k in self.p.vars), wait=wait) + return self._prg(*[x._buf for x in rawbufs], global_size=tuple(global_size), local_size=tuple(local_size) if local_size else None, + vals=tuple(var_vals[k.expr] for k in self.p.vars), wait=wait) class ViewOp(Runner): def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device) diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index c63dbff3df..62416ee4e8 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -66,7 +66,7 @@ class ProgramSpec: uops:list[UOp]|None=None # filled in from uops (if we have uops) - global_size:list[int]|None=None + global_size:list[int]=field(default_factory=lambda: [1,1,1]) local_size:list[int]|None=None vars:list[Variable]=field(default_factory=list) globals:list[int]=field(default_factory=list) @@ -109,10 +109,18 @@ class ProgramSpec: return self.uops[-1].arg.applied_opts def launch_dims(self, var_vals:dict[str, int]): - global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None + global_size = [sym_infer(sz, var_vals) for sz in self.global_size] local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None return global_size, local_size + @staticmethod + def from_uop(prg:UOp) -> ProgramSpec: + """Construct ProgramSpec from a PROGRAM UOp.""" + assert prg.op is Ops.PROGRAM, f"expected PROGRAM, got {prg.op}" + # SINK/DEVICE/LINEAR/SOURCE + sink, device, linear, source = prg.src + return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, list(linear.src), global_size=[1,1,1], local_size=[1,1,1]) + class Renderer: device: str = "" suffix: str = "" diff --git a/tinygrad/runtime/ops_dsp.py b/tinygrad/runtime/ops_dsp.py index b41ef5e78a..bc6010571c 100644 --- a/tinygrad/runtime/ops_dsp.py +++ b/tinygrad/runtime/ops_dsp.py @@ -83,7 +83,7 @@ class DSPProgram: def __init__(self, dev:DSPDevice, name:str, lib:bytes): self.dev, self.lib = dev, lib - def __call__(self, *bufs, vals:tuple[int, ...]=(), wait=False): + def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False): if len(bufs) >= 16: raise RuntimeError(f"Too many buffers to execute: {len(bufs)}") pra, fds, attrs, _ = rpc_prep_args(ins=[var_vals_mv:=memoryview(bytearray((len(bufs)+len(vals))*4)), off_mv:=memoryview(bytearray(len(bufs)*4))], @@ -289,7 +289,7 @@ class MockDSPRenderer(DSPRenderer): class MockDSPProgram: def __init__(self, name:str, lib:bytes): self.lib = lib - def __call__(self, *bufs, vals:tuple[int, ...]=(), wait=False): + def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False): with tempfile.NamedTemporaryFile(suffix=".out") as dsp_lib: dsp_lib.write(self.lib) dsp_lib.flush() diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 0c8eee5cbc..8988a5f5bf 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -249,10 +249,10 @@ full_spec = PatternMatcher([ # in progress MSTACK may lose device (UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True), - # codegen: PROGRAM with progressive sources through the pipeline - (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK),)), lambda: True), - (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR))), lambda: True), - (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True), + # codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?) + (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True), + (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True), + (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True), # codegen: standalone LINEAR/SOURCE (UPat(Ops.LINEAR, dtypes.void), lambda: True), (UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),