add optional compiler in Renderer (#13817)

* add optional compiler in Renderer [pr]

* fix

* late init

* remove precompiled

* cleanup
This commit is contained in:
George Hotz
2025-12-23 17:58:46 -05:00
committed by GitHub
parent 8eab6175ee
commit 43c6e973d8
13 changed files with 57 additions and 260 deletions

View File

@@ -1,225 +0,0 @@
# [<buf device:HIP size:1605632 dtype:dtypes.float>, <buf device:HIP size:301506 dtype:dtypes.float>, <buf device:HIP size:9408 dtype:dtypes.float>]
from tinygrad import Device, dtypes
from tinygrad.device import Buffer, CompiledRunner
import ctypes
import gpuctypes.hip as hip
from tinygrad.helpers import to_char_p_p, init_c_var
def get_bytes(arg, get_sz, get_str, check) -> bytes: return (sz := init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x)))), ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value))[1] # noqa: E501
def check(status):
if status != 0: raise RuntimeError(f"HIP Error {status}, {ctypes.string_at(hip.hipGetErrorString(status)).decode()}")
def compile_hip(prg:str, arch="gfx1100") -> bytes:
check(hip.hiprtcCreateProgram(ctypes.byref(prog := hip.hiprtcProgram()), prg.encode(), "<null>".encode(), 0, None, None))
compile_options = [f'--offload-arch={arch}', '-I/opt/rocm/include']
status = hip.hiprtcCompileProgram(prog, len(compile_options), to_char_p_p([o.encode() for o in compile_options]))
if status != 0: raise RuntimeError(f"compile failed: {get_bytes(prog, hip.hiprtcGetProgramLogSize, hip.hiprtcGetProgramLog, check).decode()}")
return get_bytes(prog, hip.hiprtcGetCodeSize, hip.hiprtcGetCode, check)
prefix = """
typedef long unsigned int size_t;
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
typedef float float2 __attribute__((ext_vector_type(2)));
static inline __attribute__((device)) float2 make_float2(float x, float y) { return {x, y}; }
"""
code = """
extern "C" __attribute__((global))void r_2_8_7_7_4_8_3_7_7_4_4_2_2(float* data0, const float* data1, const float* data2) {
int gidx0 = __ockl_get_group_id(2); /* 2 */
int gidx1 = __ockl_get_group_id(1); /* 8 */
int gidx2 = __ockl_get_group_id(0); /* 49 */
int lidx4 = __ockl_get_local_id(1); /* 4 */
int lidx5 = __ockl_get_local_id(0); /* 8 */
float2 acc0 = make_float2(0.0f,0.0f);
float2 acc1 = make_float2(0.0f,0.0f);
float2 acc2 = make_float2(0.0f,0.0f);
float2 acc3 = make_float2(0.0f,0.0f);
float2 acc4 = make_float2(0.0f,0.0f);
float2 acc5 = make_float2(0.0f,0.0f);
float2 acc6 = make_float2(0.0f,0.0f);
float2 acc7 = make_float2(0.0f,0.0f);
float2 acc8 = make_float2(0.0f,0.0f);
float2 acc9 = make_float2(0.0f,0.0f);
float2 acc10 = make_float2(0.0f,0.0f);
float2 acc11 = make_float2(0.0f,0.0f);
float2 acc12 = make_float2(0.0f,0.0f);
float2 acc13 = make_float2(0.0f,0.0f);
float2 acc14 = make_float2(0.0f,0.0f);
float2 acc15 = make_float2(0.0f,0.0f);
float2 acc16 = make_float2(0.0f,0.0f);
float2 acc17 = make_float2(0.0f,0.0f);
float2 acc18 = make_float2(0.0f,0.0f);
float2 acc19 = make_float2(0.0f,0.0f);
float2 acc20 = make_float2(0.0f,0.0f);
float2 acc21 = make_float2(0.0f,0.0f);
float2 acc22 = make_float2(0.0f,0.0f);
float2 acc23 = make_float2(0.0f,0.0f);
float2 acc24 = make_float2(0.0f,0.0f);
float2 acc25 = make_float2(0.0f,0.0f);
float2 acc26 = make_float2(0.0f,0.0f);
float2 acc27 = make_float2(0.0f,0.0f);
float2 acc28 = make_float2(0.0f,0.0f);
float2 acc29 = make_float2(0.0f,0.0f);
float2 acc30 = make_float2(0.0f,0.0f);
float2 acc31 = make_float2(0.0f,0.0f);
int alu0 = (gidx2/7);
int alu1 = (gidx2%7);
int alu2 = (alu1*32);
int alu3 = (lidx5*4);
int alu4 = ((gidx0*802816)+(gidx1*100352)+(alu0*1792)+(alu1*16)+(lidx4*448)+(lidx5*2));
for (int ridx0 = 0; ridx0 < 3; ridx0++) {
for (int ridx1 = 0; ridx1 < 7; ridx1++) {
int alu5 = ((alu0*(-32))+(lidx4*(-8))+(ridx1*(-1)));
bool alu6 = (alu5<(-2));
bool alu7 = (alu5<0);
bool alu8 = (((alu0*32)+(lidx4*8)+ridx1)<221);
for (int ridx2 = 0; ridx2 < 7; ridx2++) {
int alu9 = ((gidx0*150528)+(ridx0*50176)+(alu0*7168)+(lidx4*1792)+(ridx1*224)+alu2+alu3+ridx2);
int alu10 = ((alu1*(-32))+(lidx5*(-4))+(ridx2*(-1)));
bool alu11 = (alu10<(-2));
float val0 = 0.0f;
if ((alu6*alu11)) { val0 = data1[alu9+(-675)]; }
float val1 = 0.0f;
if ((alu7*alu11)) { val1 = data1[alu9+(-227)]; }
float val2 = 0.0f;
if (alu11) { val2 = data1[alu9+221]; }
float val3 = 0.0f;
if ((alu8*alu11)) { val3 = data1[alu9+669]; }
bool alu12 = (alu10<0);
bool alu13 = ((alu2+alu3+ridx2)<225);
float val4 = 0.0f;
if ((alu6*alu12*alu13)) { val4 = data1[alu9+(-673)]; }
float val5 = 0.0f;
if ((alu7*alu12*alu13)) { val5 = data1[alu9+(-225)]; }
float val6 = 0.0f;
if ((alu12*alu13)) { val6 = data1[alu9+223]; }
float val7 = 0.0f;
if ((alu8*alu12*alu13)) { val7 = data1[alu9+671]; }
int alu14 = ((gidx1*1176)+(ridx0*49)+(ridx1*7)+ridx2);
float val8 = data2[alu14];
float val9 = data2[alu14+147];
float val10 = data2[alu14+294];
float val11 = data2[alu14+441];
float val12 = data2[alu14+588];
float val13 = data2[alu14+735];
float val14 = data2[alu14+882];
float val15 = data2[alu14+1029];
(acc0).x = ((val0*val8)+(acc0).x);
(acc1).x = ((val0*val9)+(acc1).x);
(acc2).x = ((val0*val10)+(acc2).x);
(acc3).x = ((val0*val11)+(acc3).x);
(acc4).x = ((val1*val8)+(acc4).x);
(acc5).x = ((val1*val9)+(acc5).x);
(acc6).x = ((val1*val10)+(acc6).x);
(acc7).x = ((val1*val11)+(acc7).x);
(acc8).x = ((val2*val8)+(acc8).x);
(acc9).x = ((val2*val9)+(acc9).x);
(acc10).x = ((val2*val10)+(acc10).x);
(acc11).x = ((val2*val11)+(acc11).x);
(acc12).x = ((val3*val8)+(acc12).x);
(acc13).x = ((val3*val9)+(acc13).x);
(acc14).x = ((val3*val10)+(acc14).x);
(acc15).x = ((val3*val11)+(acc15).x);
(acc16).x = ((val0*val12)+(acc16).x);
(acc17).x = ((val0*val13)+(acc17).x);
(acc18).x = ((val0*val14)+(acc18).x);
(acc19).x = ((val0*val15)+(acc19).x);
(acc20).x = ((val1*val12)+(acc20).x);
(acc21).x = ((val1*val13)+(acc21).x);
(acc22).x = ((val1*val14)+(acc22).x);
(acc23).x = ((val1*val15)+(acc23).x);
(acc24).x = ((val2*val12)+(acc24).x);
(acc25).x = ((val2*val13)+(acc25).x);
(acc26).x = ((val2*val14)+(acc26).x);
(acc27).x = ((val2*val15)+(acc27).x);
(acc28).x = ((val3*val12)+(acc28).x);
(acc29).x = ((val3*val13)+(acc29).x);
(acc30).x = ((val3*val14)+(acc30).x);
(acc31).x = ((val3*val15)+(acc31).x);
(acc0).y = ((val4*val8)+(acc0).y);
(acc1).y = ((val4*val9)+(acc1).y);
(acc2).y = ((val4*val10)+(acc2).y);
(acc3).y = ((val4*val11)+(acc3).y);
(acc4).y = ((val5*val8)+(acc4).y);
(acc5).y = ((val5*val9)+(acc5).y);
(acc6).y = ((val5*val10)+(acc6).y);
(acc7).y = ((val5*val11)+(acc7).y);
(acc8).y = ((val6*val8)+(acc8).y);
(acc9).y = ((val6*val9)+(acc9).y);
(acc10).y = ((val6*val10)+(acc10).y);
(acc11).y = ((val6*val11)+(acc11).y);
(acc12).y = ((val7*val8)+(acc12).y);
(acc13).y = ((val7*val9)+(acc13).y);
(acc14).y = ((val7*val10)+(acc14).y);
(acc15).y = ((val7*val11)+(acc15).y);
(acc16).y = ((val4*val12)+(acc16).y);
(acc17).y = ((val4*val13)+(acc17).y);
(acc18).y = ((val4*val14)+(acc18).y);
(acc19).y = ((val4*val15)+(acc19).y);
(acc20).y = ((val5*val12)+(acc20).y);
(acc21).y = ((val5*val13)+(acc21).y);
(acc22).y = ((val5*val14)+(acc22).y);
(acc23).y = ((val5*val15)+(acc23).y);
(acc24).y = ((val6*val12)+(acc24).y);
(acc25).y = ((val6*val13)+(acc25).y);
(acc26).y = ((val6*val14)+(acc26).y);
(acc27).y = ((val6*val15)+(acc27).y);
(acc28).y = ((val7*val12)+(acc28).y);
(acc29).y = ((val7*val13)+(acc29).y);
(acc30).y = ((val7*val14)+(acc30).y);
(acc31).y = ((val7*val15)+(acc31).y);
}
}
}
*((float2*)(data0+alu4)) = acc0;
*((float2*)(data0+alu4+12544)) = acc1;
*((float2*)(data0+alu4+25088)) = acc2;
*((float2*)(data0+alu4+37632)) = acc3;
*((float2*)(data0+alu4+112)) = acc4;
*((float2*)(data0+alu4+12656)) = acc5;
*((float2*)(data0+alu4+25200)) = acc6;
*((float2*)(data0+alu4+37744)) = acc7;
*((float2*)(data0+alu4+224)) = acc8;
*((float2*)(data0+alu4+12768)) = acc9;
*((float2*)(data0+alu4+25312)) = acc10;
*((float2*)(data0+alu4+37856)) = acc11;
*((float2*)(data0+alu4+336)) = acc12;
*((float2*)(data0+alu4+12880)) = acc13;
*((float2*)(data0+alu4+25424)) = acc14;
*((float2*)(data0+alu4+37968)) = acc15;
*((float2*)(data0+alu4+50176)) = acc16;
*((float2*)(data0+alu4+62720)) = acc17;
*((float2*)(data0+alu4+75264)) = acc18;
*((float2*)(data0+alu4+87808)) = acc19;
*((float2*)(data0+alu4+50288)) = acc20;
*((float2*)(data0+alu4+62832)) = acc21;
*((float2*)(data0+alu4+75376)) = acc22;
*((float2*)(data0+alu4+87920)) = acc23;
*((float2*)(data0+alu4+50400)) = acc24;
*((float2*)(data0+alu4+62944)) = acc25;
*((float2*)(data0+alu4+75488)) = acc26;
*((float2*)(data0+alu4+88032)) = acc27;
*((float2*)(data0+alu4+50512)) = acc28;
*((float2*)(data0+alu4+63056)) = acc29;
*((float2*)(data0+alu4+75600)) = acc30;
*((float2*)(data0+alu4+88144)) = acc31;
}
"""
dev = "HIP"
lib = Device[dev].compiler.compile(prefix+code)
#lib = compile_hip(code)
b0 = Buffer(dev, 1605632, dtypes.float)
b1 = Buffer(dev, 301506, dtypes.float)
b2 = Buffer(dev, 9408, dtypes.float)
print(hex(b0._buf.value), hex(b0._buf.value+1605632*4))
print(hex(b1._buf.value))
print(hex(b2._buf.value))
#prg = CompiledRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [7, 1, 1], [8, 4, 1], precompiled=lib)
prg = CompiledRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [49, 8, 2], [8, 4, 1], precompiled=lib)
print("compiled")
prg([b0, b1, b2], {})
print("ran")
Device[dev].synchronize()
print("sync")

View File

@@ -1,5 +1,5 @@
import unittest, subprocess, platform
from tinygrad.runtime.ops_cpu import ClangJITCompiler
from tinygrad.runtime.support.compiler_cpu import ClangJITCompiler
from tinygrad.runtime.support.elf import elf_loader
class TestElfLoader(unittest.TestCase):

View File

@@ -135,9 +135,15 @@ def do_render(ctx:Renderer, prg:UOp, lin:UOp) -> UOp:
src = ctx.render(list(lin.src))
return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),))
def do_compile(ctx:Renderer, prg:UOp, source:UOp) -> UOp|None:
if ctx.compiler is None: return None
lib = ctx.compiler.compile_cached(source.arg)
return prg.replace(src=prg.src + (UOp(Ops.BINARY, arg=lib),))
pm_to_program = PatternMatcher([
(UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"), UPat(Ops.DEVICE)), name="prg"), do_linearize),
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE, name="source")), name="prg"), do_compile),
])
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)

View File

@@ -41,7 +41,7 @@ def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:lis
if allow_test_size and max_global_size is not None:
global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals)
p = replace(p, global_size=global_size)
try: car = CompiledRunner(p, precompiled=lib)
try: car = CompiledRunner(replace(p, lib=lib))
except AssertionError: return [math.inf] * cnt
tms = []
input_bufs = [rawbufs[i] for i in car.p.globals]
@@ -72,7 +72,7 @@ def _try_compile(x:tuple[int,Scheduler], compiler:Compiler) -> tuple[int, tuple[
if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many uops. {len(p.uops)=}, {uops_max=}")
raise RuntimeError("too many uops")
st = time.perf_counter()
prog = compiler.compile(p.src)
prog = p.lib if p.lib is not None else compiler.compile(p.src)
et = time.perf_counter() - st
ret = (p, prog, et)
except RuntimeError:

View File

@@ -278,7 +278,7 @@ class Compiler:
def disassemble(self, lib:bytes): pass
@dataclass(frozen=True)
class CompilerPair: renderer:type[Renderer]|functools.partial; compiler:type[Compiler]|functools.partial; ctrl_var:ContextVar|None = None # noqa: E702
class CompilerPair: renderer:type[Renderer]|functools.partial; compiler:type[Compiler]|functools.partial|None; ctrl_var:ContextVar|None = None # noqa: E702
@dataclass(frozen=True)
class CompilerSet: cset:list[CompilerPair]; ctrl_var:ContextVar|None = None # noqa: E702
@@ -290,21 +290,25 @@ class Compiled:
self.device, self.allocator, self.runtime, self.graph, self.group_id = device, allocator, runtime, graph, group_id
self.comps_ctrl_var = compilers.ctrl_var if compilers is not None else None
self.comp_sets:dict[Any, tuple[ContextVar|None, tuple[type[Renderer]|functools.partial, type[Compiler]|functools.partial]]] = {}
self.cached_pair:dict[Any, tuple[Renderer, Compiler]] = {}
self.comp_sets:dict[Any, tuple[ContextVar|None, tuple[type[Renderer]|functools.partial, type[Compiler]|functools.partial|None]]] = {}
self.cached_pair:dict[Any, tuple[Renderer, Compiler|None]] = {}
for cpair in (compilers.cset if compilers is not None else [CompilerPair(Renderer, Compiler)]):
self.comp_sets[self._compiler_name(cpair.compiler)] = (cpair.ctrl_var, (cpair.renderer, cpair.compiler))
self.comp_sets[self._compiler_name(cpair.renderer, cpair.compiler)] = (cpair.ctrl_var, (cpair.renderer, cpair.compiler))
@property
def renderer(self) -> Renderer: return self._select_compiler_pair()[0]
@property
def compiler(self) -> Compiler: return self._select_compiler_pair()[1]
def compiler(self) -> Compiler:
if (ret:=self.renderer.compiler or self._select_compiler_pair()[1]) is None: raise RuntimeError(f"no compiler for {self.device}")
return ret
def _compiler_name(self, c:type[Compiler]|functools.partial) -> str:
return unwrap_class_type(c).__name__.upper().removesuffix("COMPILER").removeprefix(devname:=self.device.split(':')[0].upper()) or devname
def _compiler_name(self, r:type[Renderer]|functools.partial, c:type[Compiler]|functools.partial|None) -> str:
devname = self.device.split(':')[0].upper()
if c is None: return unwrap_class_type(r).__name__.upper().removesuffix("RENDERER").removeprefix(devname) or devname
return unwrap_class_type(c).__name__.upper().removesuffix("COMPILER").removeprefix(devname) or devname
def _select_compiler_pair(self) -> tuple[Renderer, Compiler]:
def _select_compiler_pair(self) -> tuple[Renderer, Compiler|None]:
# select forced compiler from global env var.
forced_comps = set([self.comp_sets[val][1]] if self.comps_ctrl_var is not None and (val:=self.comps_ctrl_var.value) else [])
@@ -398,7 +402,7 @@ def enumerate_devices_str() -> Generator[str, None, None]:
# d.renderer, d.compiler = r(), c()
with Context(CACHELEVEL=0): test = (Tensor([1,2,3], device=device) * 2).tolist()
if test != [2,4,6]: raise ValueError(f"got {test} instead of [2, 4, 6]")
set_text = f'({cc_ctrl_var.key}={d._compiler_name(c)} to make default)' if cc_ctrl_var is not None else ''
set_text = f'({cc_ctrl_var.key}={d._compiler_name(r, c)} to make default)' if cc_ctrl_var is not None else ''
default_text = '(default)' if type(default_compiler) is type(d.compiler) else set_text
compilers_results.append(f"{colored('+', 'green')} {unwrap_class_type(c).__name__} {default_text}")
any_works = True

View File

@@ -36,19 +36,19 @@ def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffe
return ret[1]
class CompiledRunner(Runner):
def __init__(self, p:ProgramSpec, precompiled:bytes|None=None, prg=None):
def __init__(self, p:ProgramSpec, prg=None):
if DEBUG >= 3: print(p.applied_opts)
if DEBUG >= 4: print(p.src)
self.p:ProgramSpec = p
if precompiled is not None: self.lib = precompiled
else:
if p.lib is None:
with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"):
self.lib = Device[p.device].compiler.compile_cached(p.src)
if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib)
self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg
p = replace(p, lib=Device[p.device].compiler.compile_cached(p.src))
self.p:ProgramSpec = p
assert self.p.lib is not None
if DEBUG >= 7: Device[p.device].compiler.disassemble(self.p.lib)
self._prg = Device[p.device].runtime(p.function_name, self.p.lib) if prg is None else prg
super().__init__(p.name, p.device, p.estimates)
def __reduce__(self): return self.__class__, (self.p, self.lib)
def __reduce__(self): return self.__class__, (self.p,)
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False) -> float|None:
if var_vals is None: var_vals = {}
@@ -115,7 +115,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
if cret:=method_cache.get(ckey): return cret
bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
if bret:=method_cache.get(bkey):
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device))
else:
prg: ProgramSpec = get_program(ast, Device[device].renderer)
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))

View File

@@ -115,12 +115,12 @@ def suppress_finalizing(func):
if not getattr(sys, 'is_finalizing', lambda: True)(): raise # re-raise if not finalizing
return wrapper
def select_first_inited(candidates:Sequence[Callable[...,T]|Sequence[Callable[...,T]]], err_msg:str, cache:dict|None=None) -> tuple[T,...]|T:
def select_first_inited(candidates:Sequence[Callable[...,T]|Sequence[Callable[...,T]|None]], err_msg:str, cache:dict|None=None):
excs = []
for typ in candidates:
if cache is not None and typ in cache: return cache[typ]
try:
x = tuple([cast(Callable, t)() for t in typ]) if isinstance(typ, Sequence) else cast(Callable, typ)()
x = tuple([cast(Callable, t)() if t is not None else None for t in typ]) if isinstance(typ, Sequence) else cast(Callable, typ)()
if cache is not None: cache[typ] = x
return x
except Exception as e: excs.append(e)

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Callable, cast
from typing import Callable, cast, TYPE_CHECKING
import functools
from dataclasses import dataclass, field
from tinygrad.helpers import to_function_name, dedup, prod, DEBUG
@@ -7,6 +7,7 @@ from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, Gro
from tinygrad.dtype import AddrSpace, PtrDType
from tinygrad.codegen.opt.tc import TensorCore
from tinygrad.codegen.opt import Opt
if TYPE_CHECKING: from tinygrad.device import Compiler
@dataclass(frozen=True)
class Estimates:
@@ -64,6 +65,7 @@ class ProgramSpec:
device:str
ast:UOp # save the base ast (this is method cache key)
uops:list[UOp]|None=None
lib:bytes|None=None
# filled in from uops (via from_uop)
global_size:list[int]=field(default_factory=lambda: [1,1,1])
@@ -95,8 +97,9 @@ class ProgramSpec:
def from_uop(prg:UOp) -> ProgramSpec:
"""Construct ProgramSpec from a PROGRAM UOp."""
assert prg.op is Ops.PROGRAM, f"expected PROGRAM, got {prg.op}"
# SINK/DEVICE/LINEAR/SOURCE
sink, device, linear, source = prg.src
# SINK/DEVICE/LINEAR/SOURCE/BINARY?
sink, device, linear, source = prg.src[:4]
lib = prg.src[4].arg if len(prg.src) > 4 else None
uops = list(linear.src)
if DEBUG >= 6: print_uops(uops) # LINEAR is src[2]
@@ -120,7 +123,7 @@ class ProgramSpec:
# TODO: this cast is wrong, u.src[0].ssimplify() can be sint
if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, uops, global_size, local_size,
return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, uops, lib, global_size, local_size,
sorted(_vars, key=lambda v: v.arg), sorted(dedup(_globals)), sorted(dedup(outs)), sorted(dedup(ins)))
class Renderer:
@@ -139,6 +142,7 @@ class Renderer:
pre_matcher: PatternMatcher|None = None
extra_matcher: PatternMatcher|None = None
code_for_op: dict[Ops, Callable] = {}
compiler: Compiler|None = None
def __reduce__(self): return self.__class__, ()
def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")

View File

@@ -8,6 +8,7 @@ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, trunc
from tinygrad.renderer import Renderer
from tinygrad.codegen.late.devectorizer import no_vectorized_alu
base_rewrite = PatternMatcher([
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
(UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
@@ -278,6 +279,11 @@ class ClangRenderer(CStyleLanguage):
defines = '\n'.join(self._render_defines(uops))
return defines + "\n" + self._render_body(function_name, kernel, bufs, uops, prefix) + "\n" + self._render_entry(function_name, bufs)
class ClangJITRenderer(ClangRenderer):
def __init__(self):
from tinygrad.runtime.support.compiler_cpu import ClangJITCompiler
self.compiler = ClangJITCompiler()
class OpenCLRenderer(CStyleLanguage):
device = "CL"

View File

@@ -5,10 +5,10 @@ from tinygrad.helpers import CPU_CC, CPU_LVP, CPU_LLVM
from tinygrad.device import BufferSpec, DMACPURef, CompilerSet, CompilerPair
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocatorBase, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface
from tinygrad.runtime.support.hcq import CLikeArgsState
from tinygrad.renderer.cstyle import ClangRenderer
from tinygrad.renderer.cstyle import ClangJITRenderer
from tinygrad.renderer.llvmir import LLVMRenderer
from tinygrad.renderer.nir import LVPRenderer
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler
from tinygrad.runtime.support.compiler_mesa import LVPCompiler
from tinygrad.runtime.support.elf import jit_loader
from tinygrad.uop.ops import sint
@@ -136,6 +136,6 @@ class CPUDevice(HCQCompiled):
def __init__(self, device:str=""):
self.tasks:queue.Queue = queue.Queue()
CPUWorker(self, self.tasks, thread_id=0).start()
compilers = CompilerSet([CompilerPair(ClangRenderer, ClangJITCompiler), CompilerPair(LLVMRenderer, CPULLVMCompiler, ctrl_var=CPU_LLVM),
compilers = CompilerSet([CompilerPair(ClangJITRenderer, None), CompilerPair(LLVMRenderer, CPULLVMCompiler, ctrl_var=CPU_LLVM),
CompilerPair(LVPRenderer, LVPCompiler, ctrl_var=CPU_LVP)], ctrl_var=CPU_CC)
super().__init__(device, CPUAllocator(self), compilers, functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue)

View File

@@ -28,8 +28,8 @@ class Ops(FastEnum):
NOOP = auto(); REWRITE_ERROR = auto()
# renderer
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable
PROGRAM = auto(); LINEAR = auto(); SOURCE = auto()
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has bytes arg that's compiled
PROGRAM = auto(); LINEAR = auto(); SOURCE = auto(); BINARY = auto()
# AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:]
# GROUP is a NOOP that just merges things together

View File

@@ -219,7 +219,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE:
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY:
return None
case Ops.INDEX:

View File

@@ -249,13 +249,15 @@ full_spec = PatternMatcher([
# in progress MSTACK may lose device
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?)
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
# codegen: standalone LINEAR/SOURCE
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
# codegen: standalone LINEAR/SOURCE/BINARY
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
(UPat(Ops.VECTORIZE), lambda: True),