one compiler path

This commit is contained in:
George Hotz
2025-12-22 19:20:04 -05:00
parent 339dadf056
commit 2aec58654a
3 changed files with 6 additions and 12 deletions

View File

@@ -23,8 +23,9 @@ def _test_uop_result(inputs:list[Tensor], stores:list[UOp], local_size=None):
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE] initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
inbufs = [x.uop.base.buffer for x in inputs] inbufs = [x.uop.base.buffer for x in inputs]
src = Device[Device.DEFAULT].renderer.render(uops) src = Device[Device.DEFAULT].renderer.render(uops)
lib = Device[Device.DEFAULT].compiler.compile_cached(src)
ei = CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", ei = CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test",
src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size)) src, Device.DEFAULT, uops[-1], uops=uops, lib=lib, local_size=local_size))
ei.exec(outbufs+inbufs) ei.exec(outbufs+inbufs)
return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs] return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]

View File

@@ -7,7 +7,6 @@ from tinygrad.dtype import dtypes, DType, AddrSpace
from tinygrad.device import Buffer, Device from tinygrad.device import Buffer, Device
from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu, AxisType from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu, AxisType
from tinygrad.uop.spec import shared_spec from tinygrad.uop.spec import shared_spec
from tinygrad.renderer import ProgramSpec
from tinygrad.renderer.cstyle import CStyleLanguage from tinygrad.renderer.cstyle import CStyleLanguage
from tinygrad.engine.realize import CompiledRunner, get_program, get_runner from tinygrad.engine.realize import CompiledRunner, get_program, get_runner
from tinygrad.engine.schedule import ExecItem from tinygrad.engine.schedule import ExecItem
@@ -26,11 +25,7 @@ def to_uops_list(u:list[UOp], ren=None) -> list[UOp]:
return ret[:-1] return ret[:-1]
def _uops_to_prg(uops_list): def _uops_to_prg(uops_list):
uops = full_rewrite(ast:=UOp.sink(*uops_list), ren=Device[Device.DEFAULT].renderer) return CompiledRunner(get_program(UOp.sink(*uops_list), Device[Device.DEFAULT].renderer, Device.DEFAULT))
src = Device[Device.DEFAULT].renderer.render(uops)
has_local = Device[Device.DEFAULT].renderer.has_local
return CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, Device.DEFAULT, ast, uops=uops,
global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None))
def uop(uops:list[UOp], uop:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp: def uop(uops:list[UOp], uop:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp:
uops.append(UOp(uop, dtype, tuple(src), arg)) uops.append(UOp(uop, dtype, tuple(src), arg))

View File

@@ -2,7 +2,7 @@ from typing import cast, Callable
import time, pprint, random, itertools, math import time, pprint, random, itertools, math
from dataclasses import dataclass, replace, field from dataclasses import dataclass, replace, field
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, PROFILE, ProfilePointEvent, cpu_events, prod, Context
from tinygrad.helpers import unwrap from tinygrad.helpers import unwrap
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, track_rewrites, KernelInfo, pyrender from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, track_rewrites, KernelInfo, pyrender
from tinygrad.device import Device, Buffer from tinygrad.device import Device, Buffer
@@ -75,11 +75,9 @@ class CompiledRunner(Runner):
def __init__(self, p:ProgramSpec, prg=None): def __init__(self, p:ProgramSpec, prg=None):
if DEBUG >= 3: print(p.applied_opts) if DEBUG >= 3: print(p.applied_opts)
if DEBUG >= 4: print(p.src) if DEBUG >= 4: print(p.src)
if p.lib is None: assert p.lib is not None, "lib must be provided"
with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"):
p = replace(p, lib=Device[p.device].compiler.compile_cached(p.src))
self.p:ProgramSpec = p self.p:ProgramSpec = p
if DEBUG >= 7: Device[p.device].compiler.disassemble(unwrap(p.lib)) if DEBUG >= 7: Device[p.device].compiler.disassemble(p.lib)
self._prg = Device[p.device].runtime(p.function_name, p.lib) if prg is None else prg self._prg = Device[p.device].runtime(p.function_name, p.lib) if prg is None else prg
super().__init__(p.name, p.device, p.estimates) super().__init__(p.name, p.device, p.estimates)