add device to program (#13815)

* add device to program

* from_uop

* from_uop no renderer

* simpler global_size
This commit is contained in:
George Hotz
2025-12-23 16:15:33 -05:00
committed by GitHub
parent 90b217896f
commit 3d3c5b2fb9
7 changed files with 35 additions and 40 deletions

View File

@@ -1,7 +1,7 @@
from typing import cast
import itertools
from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat
from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, print_uops
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
from tinygrad.renderer import Renderer
from tinygrad.dtype import dtypes, PtrDType
@@ -133,11 +133,15 @@ def do_render(ctx:Renderer, prg:UOp, lin:UOp) -> UOp:
return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),))
pm_to_program = PatternMatcher([
(UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"),), name="prg"), do_linearize),
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
(UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"), UPat(Ops.DEVICE)), name="prg"), do_linearize),
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
])
def full_rewrite_to_program(sink:UOp, ren:Renderer) -> UOp:
from tinygrad.uop.ops import KernelInfo
if sink.arg is None: sink = sink.replace(arg=KernelInfo())
full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
sink = UOp(Ops.PROGRAM, src=(full_sink,))
return graph_rewrite(sink, pm_to_program, ctx=ren, name="linearize/render")
sink = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=ren.device)))
prg = graph_rewrite(sink, pm_to_program, ctx=ren, name="linearize/render")
if DEBUG >= 6: print_uops(list(prg.src[2].src)) # LINEAR is src[2]
return prg

View File

@@ -37,7 +37,7 @@ def get_test_global_size(global_size, max_global_size, var_vals):
def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:list[Buffer], early_stop:float|None=None,
allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
factor = 1
if allow_test_size and p.global_size is not None and max_global_size is not None:
if allow_test_size and max_global_size is not None:
global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals)
p = replace(p, global_size=global_size)
try: car = CompiledRunner(p, precompiled=lib)

View File

@@ -97,7 +97,7 @@ class GraphRunner(Runner):
global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
if global_dim_idx is not None or local_dim_idx is not None:
self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
assert ji.prg.p.global_size is not None and ji.prg.p.local_size is not None
assert ji.prg.p.local_size is not None
self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size))
# used in MultiGraphRunner. the ints are id() of _bufs

View File

@@ -4,7 +4,7 @@ from dataclasses import dataclass, replace, field
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context
from tinygrad.helpers import unwrap
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, track_rewrites, KernelInfo, pyrender
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
from tinygrad.codegen import full_rewrite_to_program
@@ -25,25 +25,15 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program
The ProgramSpec of the program.
"""
if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print(pyrender(ast))
# linearize
if opts is not None:
assert ast.arg is None, "can't apply opts if sink has an arg"
ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
if ast.arg is None: ast = ast.replace(arg=KernelInfo())
prg = full_rewrite_to_program(ast, renderer)
# SINK/LINEAR/SOURCE
sink, linear, source = prg.src
if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print(pyrender(ast))
# print
if DEBUG >= 6: print_uops(list(linear.src))
return ProgramSpec(sink.arg.name, source.arg, renderer.device, sink, list(linear.src),
global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None,
local_size=[1,1,1] if renderer.has_local else None)
return ProgramSpec.from_uop(full_rewrite_to_program(ast, renderer))
# **************** Runners ****************
@@ -88,20 +78,13 @@ class CompiledRunner(Runner):
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False) -> float|None:
if var_vals is None: var_vals = {}
has_local = Device[self.p.device].renderer.has_local
global_size, local_size = self.p.launch_dims(var_vals)
if has_local and global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
if Device[self.p.device].renderer.has_local and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
local_size = optimize_local_size(self._prg, global_size, rawbufs)
global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
self.p = replace(self.p, global_size=global_size, local_size=local_size)
lra = {}
if global_size:
lra['global_size'] = tuple(global_size)
assert len(global_size) == 3, "global size must have len 3"
if local_size:
lra['local_size'] = tuple(local_size)
assert len(local_size) == 3, "local size must have len 3"
return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k.expr] for k in self.p.vars), wait=wait)
return self._prg(*[x._buf for x in rawbufs], global_size=tuple(global_size), local_size=tuple(local_size) if local_size else None,
vals=tuple(var_vals[k.expr] for k in self.p.vars), wait=wait)
class ViewOp(Runner):
def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)

View File

@@ -66,7 +66,7 @@ class ProgramSpec:
uops:list[UOp]|None=None
# filled in from uops (if we have uops)
global_size:list[int]|None=None
global_size:list[int]=field(default_factory=lambda: [1,1,1])
local_size:list[int]|None=None
vars:list[Variable]=field(default_factory=list)
globals:list[int]=field(default_factory=list)
@@ -109,10 +109,18 @@ class ProgramSpec:
return self.uops[-1].arg.applied_opts
def launch_dims(self, var_vals:dict[str, int]):
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
global_size = [sym_infer(sz, var_vals) for sz in self.global_size]
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
return global_size, local_size
@staticmethod
def from_uop(prg:UOp) -> ProgramSpec:
"""Construct ProgramSpec from a PROGRAM UOp."""
assert prg.op is Ops.PROGRAM, f"expected PROGRAM, got {prg.op}"
# SINK/DEVICE/LINEAR/SOURCE
sink, device, linear, source = prg.src
return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, list(linear.src), global_size=[1,1,1], local_size=[1,1,1])
class Renderer:
device: str = ""
suffix: str = ""

View File

@@ -83,7 +83,7 @@ class DSPProgram:
def __init__(self, dev:DSPDevice, name:str, lib:bytes):
self.dev, self.lib = dev, lib
def __call__(self, *bufs, vals:tuple[int, ...]=(), wait=False):
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
if len(bufs) >= 16: raise RuntimeError(f"Too many buffers to execute: {len(bufs)}")
pra, fds, attrs, _ = rpc_prep_args(ins=[var_vals_mv:=memoryview(bytearray((len(bufs)+len(vals))*4)), off_mv:=memoryview(bytearray(len(bufs)*4))],
@@ -289,7 +289,7 @@ class MockDSPRenderer(DSPRenderer):
class MockDSPProgram:
def __init__(self, name:str, lib:bytes): self.lib = lib
def __call__(self, *bufs, vals:tuple[int, ...]=(), wait=False):
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
with tempfile.NamedTemporaryFile(suffix=".out") as dsp_lib:
dsp_lib.write(self.lib)
dsp_lib.flush()

View File

@@ -249,10 +249,10 @@ full_spec = PatternMatcher([
# in progress MSTACK may lose device
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
# codegen: PROGRAM with progressive sources through the pipeline
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK),)), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?)
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
# codegen: standalone LINEAR/SOURCE
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),