Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-07 22:23:55 -05:00
add device to program (#13815)
* add device to program
* from_uop
* from_uop no renderer
* simpler global_size
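The bullets above summarize the shape change: a PROGRAM UOp now starts life as PROGRAM(SINK, DEVICE), and the pm_to_program rewrites append LINEAR and then SOURCE, so a fully lowered kernel ends up as PROGRAM(SINK, DEVICE, LINEAR, SOURCE) and a ProgramSpec can be built from the UOp alone, without a renderer. A rough sketch of how the pieces fit together; the Tensor/scheduler calls used to obtain an AST are assumptions about the surrounding tinygrad API, not part of this diff:

from tinygrad import Tensor, Device
from tinygrad.codegen import full_rewrite_to_program
from tinygrad.renderer import ProgramSpec

# assumed way to get a kernel AST (a SINK UOp) out of the scheduler
ast = (Tensor.ones(16).contiguous() + 1).schedule()[-1].ast
renderer = Device[Device.DEFAULT].renderer

prg = full_rewrite_to_program(ast, renderer)  # PROGRAM with src = (SINK, DEVICE, LINEAR, SOURCE)
sink, device, linear, source = prg.src        # the DEVICE source carries renderer.device in its arg
spec = ProgramSpec.from_uop(prg)              # name, source string, device and uops all come from the UOp
print(spec.device, spec.global_size)          # e.g. your default device and [1, 1, 1]

Since the DEVICE source already names the device, get_program collapses to ProgramSpec.from_uop(full_rewrite_to_program(ast, renderer)), as the get_program hunk below shows.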
@@ -1,7 +1,7 @@
 from typing import cast
 import itertools
-from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC
-from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat
+from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG
+from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, print_uops
 from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
 from tinygrad.renderer import Renderer
 from tinygrad.dtype import dtypes, PtrDType
@@ -133,11 +133,15 @@ def do_render(ctx:Renderer, prg:UOp, lin:UOp) -> UOp:
   return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),))
 
 pm_to_program = PatternMatcher([
-  (UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"),), name="prg"), do_linearize),
-  (UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
+  (UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"), UPat(Ops.DEVICE)), name="prg"), do_linearize),
+  (UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
 ])
 
 def full_rewrite_to_program(sink:UOp, ren:Renderer) -> UOp:
   from tinygrad.uop.ops import KernelInfo
   if sink.arg is None: sink = sink.replace(arg=KernelInfo())
   full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
-  sink = UOp(Ops.PROGRAM, src=(full_sink,))
-  return graph_rewrite(sink, pm_to_program, ctx=ren, name="linearize/render")
+  sink = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=ren.device)))
+  prg = graph_rewrite(sink, pm_to_program, ctx=ren, name="linearize/render")
+  if DEBUG >= 6: print_uops(list(prg.src[2].src)) # LINEAR is src[2]
+  return prg
@@ -37,7 +37,7 @@ def get_test_global_size(global_size, max_global_size, var_vals):
 def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:list[Buffer], early_stop:float|None=None,
                   allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
   factor = 1
-  if allow_test_size and p.global_size is not None and max_global_size is not None:
+  if allow_test_size and max_global_size is not None:
     global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals)
     p = replace(p, global_size=global_size)
   try: car = CompiledRunner(p, precompiled=lib)
@@ -97,7 +97,7 @@ class GraphRunner(Runner):
       global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
       if global_dim_idx is not None or local_dim_idx is not None:
         self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
-        assert ji.prg.p.global_size is not None and ji.prg.p.local_size is not None
+        assert ji.prg.p.local_size is not None
         self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size))
 
     # used in MultiGraphRunner. the ints are id() of _bufs
@@ -4,7 +4,7 @@ from dataclasses import dataclass, replace, field
 from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
 from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context
 from tinygrad.helpers import unwrap
-from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
+from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, track_rewrites, KernelInfo, pyrender
 from tinygrad.device import Device, Buffer
 from tinygrad.renderer import Renderer, ProgramSpec, Estimates
 from tinygrad.codegen import full_rewrite_to_program
@@ -25,25 +25,15 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program
     The ProgramSpec of the program.
   """
-  if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
-  if DEBUG >= 5: print(pyrender(ast))
-
-  # linearize
   if opts is not None:
     assert ast.arg is None, "can't apply opts if sink has an arg"
     ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
   if ast.arg is None: ast = ast.replace(arg=KernelInfo())
 
-  prg = full_rewrite_to_program(ast, renderer)
-  # SINK/LINEAR/SOURCE
-  sink, linear, source = prg.src
+  if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
+  if DEBUG >= 5: print(pyrender(ast))
 
-  # print
-  if DEBUG >= 6: print_uops(list(linear.src))
-
-  return ProgramSpec(sink.arg.name, source.arg, renderer.device, sink, list(linear.src),
-                     global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None,
-                     local_size=[1,1,1] if renderer.has_local else None)
+  return ProgramSpec.from_uop(full_rewrite_to_program(ast, renderer))
 
 # **************** Runners ****************
@@ -88,20 +78,13 @@ class CompiledRunner(Runner):
 
   def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False) -> float|None:
     if var_vals is None: var_vals = {}
-    has_local = Device[self.p.device].renderer.has_local
     global_size, local_size = self.p.launch_dims(var_vals)
-    if has_local and global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
+    if Device[self.p.device].renderer.has_local and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
       local_size = optimize_local_size(self._prg, global_size, rawbufs)
       global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
       self.p = replace(self.p, global_size=global_size, local_size=local_size)
-    lra = {}
-    if global_size:
-      lra['global_size'] = tuple(global_size)
-      assert len(global_size) == 3, "global size must have len 3"
-    if local_size:
-      lra['local_size'] = tuple(local_size)
-      assert len(local_size) == 3, "local size must have len 3"
-    return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k.expr] for k in self.p.vars), wait=wait)
+    return self._prg(*[x._buf for x in rawbufs], global_size=tuple(global_size), local_size=tuple(local_size) if local_size else None,
+                     vals=tuple(var_vals[k.expr] for k in self.p.vars), wait=wait)
 
 class ViewOp(Runner):
   def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)
@@ -66,7 +66,7 @@ class ProgramSpec:
   uops:list[UOp]|None=None
 
   # filled in from uops (if we have uops)
-  global_size:list[int]|None=None
+  global_size:list[int]=field(default_factory=lambda: [1,1,1])
   local_size:list[int]|None=None
   vars:list[Variable]=field(default_factory=list)
   globals:list[int]=field(default_factory=list)
@@ -109,10 +109,18 @@ class ProgramSpec:
     return self.uops[-1].arg.applied_opts
 
   def launch_dims(self, var_vals:dict[str, int]):
-    global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
+    global_size = [sym_infer(sz, var_vals) for sz in self.global_size]
     local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
     return global_size, local_size
 
+  @staticmethod
+  def from_uop(prg:UOp) -> ProgramSpec:
+    """Construct ProgramSpec from a PROGRAM UOp."""
+    assert prg.op is Ops.PROGRAM, f"expected PROGRAM, got {prg.op}"
+    # SINK/DEVICE/LINEAR/SOURCE
+    sink, device, linear, source = prg.src
+    return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, list(linear.src), global_size=[1,1,1], local_size=[1,1,1])
+
 class Renderer:
   device: str = ""
   suffix: str = ""
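With global_size now defaulting to [1,1,1] (and from_uop filling in both sizes), launch_dims always returns a concrete global_size, which is what lets the None checks above disappear. A minimal sketch of the new defaults, assuming a hand-built spec with placeholder name/src/device/ast values rather than one produced by the pipeline:

from tinygrad.uop.ops import UOp, Ops
from tinygrad.renderer import ProgramSpec

# placeholder spec: positional order (name, src, device, ast) taken from the from_uop hunk above
p = ProgramSpec("test_kernel", "/* no source */", "CPU", UOp(Ops.SINK))
print(p.global_size)      # [1, 1, 1] -- field default_factory, never None anymore
print(p.launch_dims({}))  # ([1, 1, 1], None) -- only local_size can still be None

On the runner side this is why CompiledRunner (above) now always passes global_size and local_size keywords to the compiled program, and why the DSP __call__ signatures (below) gain those parameters with (1,1,1) defaults.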
@@ -83,7 +83,7 @@ class DSPProgram:
   def __init__(self, dev:DSPDevice, name:str, lib:bytes):
     self.dev, self.lib = dev, lib
 
-  def __call__(self, *bufs, vals:tuple[int, ...]=(), wait=False):
+  def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
     if len(bufs) >= 16: raise RuntimeError(f"Too many buffers to execute: {len(bufs)}")
 
     pra, fds, attrs, _ = rpc_prep_args(ins=[var_vals_mv:=memoryview(bytearray((len(bufs)+len(vals))*4)), off_mv:=memoryview(bytearray(len(bufs)*4))],
@@ -289,7 +289,7 @@ class MockDSPRenderer(DSPRenderer):
 
 class MockDSPProgram:
   def __init__(self, name:str, lib:bytes): self.lib = lib
-  def __call__(self, *bufs, vals:tuple[int, ...]=(), wait=False):
+  def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
     with tempfile.NamedTemporaryFile(suffix=".out") as dsp_lib:
       dsp_lib.write(self.lib)
       dsp_lib.flush()
@@ -249,10 +249,10 @@ full_spec = PatternMatcher([
   # in progress MSTACK may lose device
   (UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
 
-  # codegen: PROGRAM with progressive sources through the pipeline
-  (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK),)), lambda: True),
-  (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR))), lambda: True),
-  (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
+  # codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?)
+  (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
+  (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
+  (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
  # codegen: standalone LINEAR/SOURCE
  (UPat(Ops.LINEAR, dtypes.void), lambda: True),
  (UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),