add optional compiler in Renderer (#13817)

* add optional compiler in Renderer [pr]

* fix

* late init

* remove precompiled

* cleanup
This commit is contained in:
George Hotz
2025-12-23 17:58:46 -05:00
committed by GitHub
parent 8eab6175ee
commit 43c6e973d8
13 changed files with 57 additions and 260 deletions

View File

@@ -1,225 +0,0 @@
# [<buf device:HIP size:1605632 dtype:dtypes.float>, <buf device:HIP size:301506 dtype:dtypes.float>, <buf device:HIP size:9408 dtype:dtypes.float>]
from tinygrad import Device, dtypes
from tinygrad.device import Buffer, CompiledRunner
import ctypes
import gpuctypes.hip as hip
from tinygrad.helpers import to_char_p_p, init_c_var
def get_bytes(arg, get_sz, get_str, check) -> bytes: return (sz := init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x)))), ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value))[1] # noqa: E501
def check(status):
if status != 0: raise RuntimeError(f"HIP Error {status}, {ctypes.string_at(hip.hipGetErrorString(status)).decode()}")
def compile_hip(prg:str, arch="gfx1100") -> bytes:
check(hip.hiprtcCreateProgram(ctypes.byref(prog := hip.hiprtcProgram()), prg.encode(), "<null>".encode(), 0, None, None))
compile_options = [f'--offload-arch={arch}', '-I/opt/rocm/include']
status = hip.hiprtcCompileProgram(prog, len(compile_options), to_char_p_p([o.encode() for o in compile_options]))
if status != 0: raise RuntimeError(f"compile failed: {get_bytes(prog, hip.hiprtcGetProgramLogSize, hip.hiprtcGetProgramLog, check).decode()}")
return get_bytes(prog, hip.hiprtcGetCodeSize, hip.hiprtcGetCode, check)
prefix = """
typedef long unsigned int size_t;
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
typedef float float2 __attribute__((ext_vector_type(2)));
static inline __attribute__((device)) float2 make_float2(float x, float y) { return {x, y}; }
"""
code = """
extern "C" __attribute__((global))void r_2_8_7_7_4_8_3_7_7_4_4_2_2(float* data0, const float* data1, const float* data2) {
int gidx0 = __ockl_get_group_id(2); /* 2 */
int gidx1 = __ockl_get_group_id(1); /* 8 */
int gidx2 = __ockl_get_group_id(0); /* 49 */
int lidx4 = __ockl_get_local_id(1); /* 4 */
int lidx5 = __ockl_get_local_id(0); /* 8 */
float2 acc0 = make_float2(0.0f,0.0f);
float2 acc1 = make_float2(0.0f,0.0f);
float2 acc2 = make_float2(0.0f,0.0f);
float2 acc3 = make_float2(0.0f,0.0f);
float2 acc4 = make_float2(0.0f,0.0f);
float2 acc5 = make_float2(0.0f,0.0f);
float2 acc6 = make_float2(0.0f,0.0f);
float2 acc7 = make_float2(0.0f,0.0f);
float2 acc8 = make_float2(0.0f,0.0f);
float2 acc9 = make_float2(0.0f,0.0f);
float2 acc10 = make_float2(0.0f,0.0f);
float2 acc11 = make_float2(0.0f,0.0f);
float2 acc12 = make_float2(0.0f,0.0f);
float2 acc13 = make_float2(0.0f,0.0f);
float2 acc14 = make_float2(0.0f,0.0f);
float2 acc15 = make_float2(0.0f,0.0f);
float2 acc16 = make_float2(0.0f,0.0f);
float2 acc17 = make_float2(0.0f,0.0f);
float2 acc18 = make_float2(0.0f,0.0f);
float2 acc19 = make_float2(0.0f,0.0f);
float2 acc20 = make_float2(0.0f,0.0f);
float2 acc21 = make_float2(0.0f,0.0f);
float2 acc22 = make_float2(0.0f,0.0f);
float2 acc23 = make_float2(0.0f,0.0f);
float2 acc24 = make_float2(0.0f,0.0f);
float2 acc25 = make_float2(0.0f,0.0f);
float2 acc26 = make_float2(0.0f,0.0f);
float2 acc27 = make_float2(0.0f,0.0f);
float2 acc28 = make_float2(0.0f,0.0f);
float2 acc29 = make_float2(0.0f,0.0f);
float2 acc30 = make_float2(0.0f,0.0f);
float2 acc31 = make_float2(0.0f,0.0f);
int alu0 = (gidx2/7);
int alu1 = (gidx2%7);
int alu2 = (alu1*32);
int alu3 = (lidx5*4);
int alu4 = ((gidx0*802816)+(gidx1*100352)+(alu0*1792)+(alu1*16)+(lidx4*448)+(lidx5*2));
for (int ridx0 = 0; ridx0 < 3; ridx0++) {
for (int ridx1 = 0; ridx1 < 7; ridx1++) {
int alu5 = ((alu0*(-32))+(lidx4*(-8))+(ridx1*(-1)));
bool alu6 = (alu5<(-2));
bool alu7 = (alu5<0);
bool alu8 = (((alu0*32)+(lidx4*8)+ridx1)<221);
for (int ridx2 = 0; ridx2 < 7; ridx2++) {
int alu9 = ((gidx0*150528)+(ridx0*50176)+(alu0*7168)+(lidx4*1792)+(ridx1*224)+alu2+alu3+ridx2);
int alu10 = ((alu1*(-32))+(lidx5*(-4))+(ridx2*(-1)));
bool alu11 = (alu10<(-2));
float val0 = 0.0f;
if ((alu6*alu11)) { val0 = data1[alu9+(-675)]; }
float val1 = 0.0f;
if ((alu7*alu11)) { val1 = data1[alu9+(-227)]; }
float val2 = 0.0f;
if (alu11) { val2 = data1[alu9+221]; }
float val3 = 0.0f;
if ((alu8*alu11)) { val3 = data1[alu9+669]; }
bool alu12 = (alu10<0);
bool alu13 = ((alu2+alu3+ridx2)<225);
float val4 = 0.0f;
if ((alu6*alu12*alu13)) { val4 = data1[alu9+(-673)]; }
float val5 = 0.0f;
if ((alu7*alu12*alu13)) { val5 = data1[alu9+(-225)]; }
float val6 = 0.0f;
if ((alu12*alu13)) { val6 = data1[alu9+223]; }
float val7 = 0.0f;
if ((alu8*alu12*alu13)) { val7 = data1[alu9+671]; }
int alu14 = ((gidx1*1176)+(ridx0*49)+(ridx1*7)+ridx2);
float val8 = data2[alu14];
float val9 = data2[alu14+147];
float val10 = data2[alu14+294];
float val11 = data2[alu14+441];
float val12 = data2[alu14+588];
float val13 = data2[alu14+735];
float val14 = data2[alu14+882];
float val15 = data2[alu14+1029];
(acc0).x = ((val0*val8)+(acc0).x);
(acc1).x = ((val0*val9)+(acc1).x);
(acc2).x = ((val0*val10)+(acc2).x);
(acc3).x = ((val0*val11)+(acc3).x);
(acc4).x = ((val1*val8)+(acc4).x);
(acc5).x = ((val1*val9)+(acc5).x);
(acc6).x = ((val1*val10)+(acc6).x);
(acc7).x = ((val1*val11)+(acc7).x);
(acc8).x = ((val2*val8)+(acc8).x);
(acc9).x = ((val2*val9)+(acc9).x);
(acc10).x = ((val2*val10)+(acc10).x);
(acc11).x = ((val2*val11)+(acc11).x);
(acc12).x = ((val3*val8)+(acc12).x);
(acc13).x = ((val3*val9)+(acc13).x);
(acc14).x = ((val3*val10)+(acc14).x);
(acc15).x = ((val3*val11)+(acc15).x);
(acc16).x = ((val0*val12)+(acc16).x);
(acc17).x = ((val0*val13)+(acc17).x);
(acc18).x = ((val0*val14)+(acc18).x);
(acc19).x = ((val0*val15)+(acc19).x);
(acc20).x = ((val1*val12)+(acc20).x);
(acc21).x = ((val1*val13)+(acc21).x);
(acc22).x = ((val1*val14)+(acc22).x);
(acc23).x = ((val1*val15)+(acc23).x);
(acc24).x = ((val2*val12)+(acc24).x);
(acc25).x = ((val2*val13)+(acc25).x);
(acc26).x = ((val2*val14)+(acc26).x);
(acc27).x = ((val2*val15)+(acc27).x);
(acc28).x = ((val3*val12)+(acc28).x);
(acc29).x = ((val3*val13)+(acc29).x);
(acc30).x = ((val3*val14)+(acc30).x);
(acc31).x = ((val3*val15)+(acc31).x);
(acc0).y = ((val4*val8)+(acc0).y);
(acc1).y = ((val4*val9)+(acc1).y);
(acc2).y = ((val4*val10)+(acc2).y);
(acc3).y = ((val4*val11)+(acc3).y);
(acc4).y = ((val5*val8)+(acc4).y);
(acc5).y = ((val5*val9)+(acc5).y);
(acc6).y = ((val5*val10)+(acc6).y);
(acc7).y = ((val5*val11)+(acc7).y);
(acc8).y = ((val6*val8)+(acc8).y);
(acc9).y = ((val6*val9)+(acc9).y);
(acc10).y = ((val6*val10)+(acc10).y);
(acc11).y = ((val6*val11)+(acc11).y);
(acc12).y = ((val7*val8)+(acc12).y);
(acc13).y = ((val7*val9)+(acc13).y);
(acc14).y = ((val7*val10)+(acc14).y);
(acc15).y = ((val7*val11)+(acc15).y);
(acc16).y = ((val4*val12)+(acc16).y);
(acc17).y = ((val4*val13)+(acc17).y);
(acc18).y = ((val4*val14)+(acc18).y);
(acc19).y = ((val4*val15)+(acc19).y);
(acc20).y = ((val5*val12)+(acc20).y);
(acc21).y = ((val5*val13)+(acc21).y);
(acc22).y = ((val5*val14)+(acc22).y);
(acc23).y = ((val5*val15)+(acc23).y);
(acc24).y = ((val6*val12)+(acc24).y);
(acc25).y = ((val6*val13)+(acc25).y);
(acc26).y = ((val6*val14)+(acc26).y);
(acc27).y = ((val6*val15)+(acc27).y);
(acc28).y = ((val7*val12)+(acc28).y);
(acc29).y = ((val7*val13)+(acc29).y);
(acc30).y = ((val7*val14)+(acc30).y);
(acc31).y = ((val7*val15)+(acc31).y);
}
}
}
*((float2*)(data0+alu4)) = acc0;
*((float2*)(data0+alu4+12544)) = acc1;
*((float2*)(data0+alu4+25088)) = acc2;
*((float2*)(data0+alu4+37632)) = acc3;
*((float2*)(data0+alu4+112)) = acc4;
*((float2*)(data0+alu4+12656)) = acc5;
*((float2*)(data0+alu4+25200)) = acc6;
*((float2*)(data0+alu4+37744)) = acc7;
*((float2*)(data0+alu4+224)) = acc8;
*((float2*)(data0+alu4+12768)) = acc9;
*((float2*)(data0+alu4+25312)) = acc10;
*((float2*)(data0+alu4+37856)) = acc11;
*((float2*)(data0+alu4+336)) = acc12;
*((float2*)(data0+alu4+12880)) = acc13;
*((float2*)(data0+alu4+25424)) = acc14;
*((float2*)(data0+alu4+37968)) = acc15;
*((float2*)(data0+alu4+50176)) = acc16;
*((float2*)(data0+alu4+62720)) = acc17;
*((float2*)(data0+alu4+75264)) = acc18;
*((float2*)(data0+alu4+87808)) = acc19;
*((float2*)(data0+alu4+50288)) = acc20;
*((float2*)(data0+alu4+62832)) = acc21;
*((float2*)(data0+alu4+75376)) = acc22;
*((float2*)(data0+alu4+87920)) = acc23;
*((float2*)(data0+alu4+50400)) = acc24;
*((float2*)(data0+alu4+62944)) = acc25;
*((float2*)(data0+alu4+75488)) = acc26;
*((float2*)(data0+alu4+88032)) = acc27;
*((float2*)(data0+alu4+50512)) = acc28;
*((float2*)(data0+alu4+63056)) = acc29;
*((float2*)(data0+alu4+75600)) = acc30;
*((float2*)(data0+alu4+88144)) = acc31;
}
"""
dev = "HIP"
lib = Device[dev].compiler.compile(prefix+code)
#lib = compile_hip(code)
b0 = Buffer(dev, 1605632, dtypes.float)
b1 = Buffer(dev, 301506, dtypes.float)
b2 = Buffer(dev, 9408, dtypes.float)
print(hex(b0._buf.value), hex(b0._buf.value+1605632*4))
print(hex(b1._buf.value))
print(hex(b2._buf.value))
#prg = CompiledRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [7, 1, 1], [8, 4, 1], precompiled=lib)
prg = CompiledRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [49, 8, 2], [8, 4, 1], precompiled=lib)
print("compiled")
prg([b0, b1, b2], {})
print("ran")
Device[dev].synchronize()
print("sync")

View File

@@ -1,5 +1,5 @@
import unittest, subprocess, platform import unittest, subprocess, platform
from tinygrad.runtime.ops_cpu import ClangJITCompiler from tinygrad.runtime.support.compiler_cpu import ClangJITCompiler
from tinygrad.runtime.support.elf import elf_loader from tinygrad.runtime.support.elf import elf_loader
class TestElfLoader(unittest.TestCase): class TestElfLoader(unittest.TestCase):

View File

@@ -135,9 +135,15 @@ def do_render(ctx:Renderer, prg:UOp, lin:UOp) -> UOp:
src = ctx.render(list(lin.src)) src = ctx.render(list(lin.src))
return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),)) return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),))
def do_compile(ctx:Renderer, prg:UOp, source:UOp) -> UOp|None:
if ctx.compiler is None: return None
lib = ctx.compiler.compile_cached(source.arg)
return prg.replace(src=prg.src + (UOp(Ops.BINARY, arg=lib),))
pm_to_program = PatternMatcher([ pm_to_program = PatternMatcher([
(UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"), UPat(Ops.DEVICE)), name="prg"), do_linearize), (UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"), UPat(Ops.DEVICE)), name="prg"), do_linearize),
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render), (UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE, name="source")), name="prg"), do_compile),
]) ])
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True) @track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)

View File

@@ -41,7 +41,7 @@ def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:lis
if allow_test_size and max_global_size is not None: if allow_test_size and max_global_size is not None:
global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals) global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals)
p = replace(p, global_size=global_size) p = replace(p, global_size=global_size)
try: car = CompiledRunner(p, precompiled=lib) try: car = CompiledRunner(replace(p, lib=lib))
except AssertionError: return [math.inf] * cnt except AssertionError: return [math.inf] * cnt
tms = [] tms = []
input_bufs = [rawbufs[i] for i in car.p.globals] input_bufs = [rawbufs[i] for i in car.p.globals]
@@ -72,7 +72,7 @@ def _try_compile(x:tuple[int,Scheduler], compiler:Compiler) -> tuple[int, tuple[
if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many uops. {len(p.uops)=}, {uops_max=}") if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many uops. {len(p.uops)=}, {uops_max=}")
raise RuntimeError("too many uops") raise RuntimeError("too many uops")
st = time.perf_counter() st = time.perf_counter()
prog = compiler.compile(p.src) prog = p.lib if p.lib is not None else compiler.compile(p.src)
et = time.perf_counter() - st et = time.perf_counter() - st
ret = (p, prog, et) ret = (p, prog, et)
except RuntimeError: except RuntimeError:

View File

@@ -278,7 +278,7 @@ class Compiler:
def disassemble(self, lib:bytes): pass def disassemble(self, lib:bytes): pass
@dataclass(frozen=True) @dataclass(frozen=True)
class CompilerPair: renderer:type[Renderer]|functools.partial; compiler:type[Compiler]|functools.partial; ctrl_var:ContextVar|None = None # noqa: E702 class CompilerPair: renderer:type[Renderer]|functools.partial; compiler:type[Compiler]|functools.partial|None; ctrl_var:ContextVar|None = None # noqa: E702
@dataclass(frozen=True) @dataclass(frozen=True)
class CompilerSet: cset:list[CompilerPair]; ctrl_var:ContextVar|None = None # noqa: E702 class CompilerSet: cset:list[CompilerPair]; ctrl_var:ContextVar|None = None # noqa: E702
@@ -290,21 +290,25 @@ class Compiled:
self.device, self.allocator, self.runtime, self.graph, self.group_id = device, allocator, runtime, graph, group_id self.device, self.allocator, self.runtime, self.graph, self.group_id = device, allocator, runtime, graph, group_id
self.comps_ctrl_var = compilers.ctrl_var if compilers is not None else None self.comps_ctrl_var = compilers.ctrl_var if compilers is not None else None
self.comp_sets:dict[Any, tuple[ContextVar|None, tuple[type[Renderer]|functools.partial, type[Compiler]|functools.partial]]] = {} self.comp_sets:dict[Any, tuple[ContextVar|None, tuple[type[Renderer]|functools.partial, type[Compiler]|functools.partial|None]]] = {}
self.cached_pair:dict[Any, tuple[Renderer, Compiler]] = {} self.cached_pair:dict[Any, tuple[Renderer, Compiler|None]] = {}
for cpair in (compilers.cset if compilers is not None else [CompilerPair(Renderer, Compiler)]): for cpair in (compilers.cset if compilers is not None else [CompilerPair(Renderer, Compiler)]):
self.comp_sets[self._compiler_name(cpair.compiler)] = (cpair.ctrl_var, (cpair.renderer, cpair.compiler)) self.comp_sets[self._compiler_name(cpair.renderer, cpair.compiler)] = (cpair.ctrl_var, (cpair.renderer, cpair.compiler))
@property @property
def renderer(self) -> Renderer: return self._select_compiler_pair()[0] def renderer(self) -> Renderer: return self._select_compiler_pair()[0]
@property @property
def compiler(self) -> Compiler: return self._select_compiler_pair()[1] def compiler(self) -> Compiler:
if (ret:=self.renderer.compiler or self._select_compiler_pair()[1]) is None: raise RuntimeError(f"no compiler for {self.device}")
return ret
def _compiler_name(self, c:type[Compiler]|functools.partial) -> str: def _compiler_name(self, r:type[Renderer]|functools.partial, c:type[Compiler]|functools.partial|None) -> str:
return unwrap_class_type(c).__name__.upper().removesuffix("COMPILER").removeprefix(devname:=self.device.split(':')[0].upper()) or devname devname = self.device.split(':')[0].upper()
if c is None: return unwrap_class_type(r).__name__.upper().removesuffix("RENDERER").removeprefix(devname) or devname
return unwrap_class_type(c).__name__.upper().removesuffix("COMPILER").removeprefix(devname) or devname
def _select_compiler_pair(self) -> tuple[Renderer, Compiler]: def _select_compiler_pair(self) -> tuple[Renderer, Compiler|None]:
# select forced compiler from global env var. # select forced compiler from global env var.
forced_comps = set([self.comp_sets[val][1]] if self.comps_ctrl_var is not None and (val:=self.comps_ctrl_var.value) else []) forced_comps = set([self.comp_sets[val][1]] if self.comps_ctrl_var is not None and (val:=self.comps_ctrl_var.value) else [])
@@ -398,7 +402,7 @@ def enumerate_devices_str() -> Generator[str, None, None]:
# d.renderer, d.compiler = r(), c() # d.renderer, d.compiler = r(), c()
with Context(CACHELEVEL=0): test = (Tensor([1,2,3], device=device) * 2).tolist() with Context(CACHELEVEL=0): test = (Tensor([1,2,3], device=device) * 2).tolist()
if test != [2,4,6]: raise ValueError(f"got {test} instead of [2, 4, 6]") if test != [2,4,6]: raise ValueError(f"got {test} instead of [2, 4, 6]")
set_text = f'({cc_ctrl_var.key}={d._compiler_name(c)} to make default)' if cc_ctrl_var is not None else '' set_text = f'({cc_ctrl_var.key}={d._compiler_name(r, c)} to make default)' if cc_ctrl_var is not None else ''
default_text = '(default)' if type(default_compiler) is type(d.compiler) else set_text default_text = '(default)' if type(default_compiler) is type(d.compiler) else set_text
compilers_results.append(f"{colored('+', 'green')} {unwrap_class_type(c).__name__} {default_text}") compilers_results.append(f"{colored('+', 'green')} {unwrap_class_type(c).__name__} {default_text}")
any_works = True any_works = True

View File

@@ -36,19 +36,19 @@ def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffe
return ret[1] return ret[1]
class CompiledRunner(Runner): class CompiledRunner(Runner):
def __init__(self, p:ProgramSpec, precompiled:bytes|None=None, prg=None): def __init__(self, p:ProgramSpec, prg=None):
if DEBUG >= 3: print(p.applied_opts) if DEBUG >= 3: print(p.applied_opts)
if DEBUG >= 4: print(p.src) if DEBUG >= 4: print(p.src)
self.p:ProgramSpec = p if p.lib is None:
if precompiled is not None: self.lib = precompiled
else:
with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"): with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"):
self.lib = Device[p.device].compiler.compile_cached(p.src) p = replace(p, lib=Device[p.device].compiler.compile_cached(p.src))
if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib) self.p:ProgramSpec = p
self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg assert self.p.lib is not None
if DEBUG >= 7: Device[p.device].compiler.disassemble(self.p.lib)
self._prg = Device[p.device].runtime(p.function_name, self.p.lib) if prg is None else prg
super().__init__(p.name, p.device, p.estimates) super().__init__(p.name, p.device, p.estimates)
def __reduce__(self): return self.__class__, (self.p, self.lib) def __reduce__(self): return self.__class__, (self.p,)
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False) -> float|None: def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False) -> float|None:
if var_vals is None: var_vals = {} if var_vals is None: var_vals = {}
@@ -115,7 +115,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
if cret:=method_cache.get(ckey): return cret if cret:=method_cache.get(ckey): return cret
bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True) bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
if bret:=method_cache.get(bkey): if bret:=method_cache.get(bkey):
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib) method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device))
else: else:
prg: ProgramSpec = get_program(ast, Device[device].renderer) prg: ProgramSpec = get_program(ast, Device[device].renderer)
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device)) method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))

View File

@@ -115,12 +115,12 @@ def suppress_finalizing(func):
if not getattr(sys, 'is_finalizing', lambda: True)(): raise # re-raise if not finalizing if not getattr(sys, 'is_finalizing', lambda: True)(): raise # re-raise if not finalizing
return wrapper return wrapper
def select_first_inited(candidates:Sequence[Callable[...,T]|Sequence[Callable[...,T]]], err_msg:str, cache:dict|None=None) -> tuple[T,...]|T: def select_first_inited(candidates:Sequence[Callable[...,T]|Sequence[Callable[...,T]|None]], err_msg:str, cache:dict|None=None):
excs = [] excs = []
for typ in candidates: for typ in candidates:
if cache is not None and typ in cache: return cache[typ] if cache is not None and typ in cache: return cache[typ]
try: try:
x = tuple([cast(Callable, t)() for t in typ]) if isinstance(typ, Sequence) else cast(Callable, typ)() x = tuple([cast(Callable, t)() if t is not None else None for t in typ]) if isinstance(typ, Sequence) else cast(Callable, typ)()
if cache is not None: cache[typ] = x if cache is not None: cache[typ] = x
return x return x
except Exception as e: excs.append(e) except Exception as e: excs.append(e)

View File

@@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from typing import Callable, cast from typing import Callable, cast, TYPE_CHECKING
import functools import functools
from dataclasses import dataclass, field from dataclasses import dataclass, field
from tinygrad.helpers import to_function_name, dedup, prod, DEBUG from tinygrad.helpers import to_function_name, dedup, prod, DEBUG
@@ -7,6 +7,7 @@ from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, Gro
from tinygrad.dtype import AddrSpace, PtrDType from tinygrad.dtype import AddrSpace, PtrDType
from tinygrad.codegen.opt.tc import TensorCore from tinygrad.codegen.opt.tc import TensorCore
from tinygrad.codegen.opt import Opt from tinygrad.codegen.opt import Opt
if TYPE_CHECKING: from tinygrad.device import Compiler
@dataclass(frozen=True) @dataclass(frozen=True)
class Estimates: class Estimates:
@@ -64,6 +65,7 @@ class ProgramSpec:
device:str device:str
ast:UOp # save the base ast (this is method cache key) ast:UOp # save the base ast (this is method cache key)
uops:list[UOp]|None=None uops:list[UOp]|None=None
lib:bytes|None=None
# filled in from uops (via from_uop) # filled in from uops (via from_uop)
global_size:list[int]=field(default_factory=lambda: [1,1,1]) global_size:list[int]=field(default_factory=lambda: [1,1,1])
@@ -95,8 +97,9 @@ class ProgramSpec:
def from_uop(prg:UOp) -> ProgramSpec: def from_uop(prg:UOp) -> ProgramSpec:
"""Construct ProgramSpec from a PROGRAM UOp.""" """Construct ProgramSpec from a PROGRAM UOp."""
assert prg.op is Ops.PROGRAM, f"expected PROGRAM, got {prg.op}" assert prg.op is Ops.PROGRAM, f"expected PROGRAM, got {prg.op}"
# SINK/DEVICE/LINEAR/SOURCE # SINK/DEVICE/LINEAR/SOURCE/BINARY?
sink, device, linear, source = prg.src sink, device, linear, source = prg.src[:4]
lib = prg.src[4].arg if len(prg.src) > 4 else None
uops = list(linear.src) uops = list(linear.src)
if DEBUG >= 6: print_uops(uops) # LINEAR is src[2] if DEBUG >= 6: print_uops(uops) # LINEAR is src[2]
@@ -120,7 +123,7 @@ class ProgramSpec:
# TODO: this cast is wrong, u.src[0].ssimplify() can be sint # TODO: this cast is wrong, u.src[0].ssimplify() can be sint
if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify()) if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, uops, global_size, local_size, return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, uops, lib, global_size, local_size,
sorted(_vars, key=lambda v: v.arg), sorted(dedup(_globals)), sorted(dedup(outs)), sorted(dedup(ins))) sorted(_vars, key=lambda v: v.arg), sorted(dedup(_globals)), sorted(dedup(outs)), sorted(dedup(ins)))
class Renderer: class Renderer:
@@ -139,6 +142,7 @@ class Renderer:
pre_matcher: PatternMatcher|None = None pre_matcher: PatternMatcher|None = None
extra_matcher: PatternMatcher|None = None extra_matcher: PatternMatcher|None = None
code_for_op: dict[Ops, Callable] = {} code_for_op: dict[Ops, Callable] = {}
compiler: Compiler|None = None
def __reduce__(self): return self.__class__, () def __reduce__(self): return self.__class__, ()
def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer") def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")

View File

@@ -8,6 +8,7 @@ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, trunc
from tinygrad.renderer import Renderer from tinygrad.renderer import Renderer
from tinygrad.codegen.late.devectorizer import no_vectorized_alu from tinygrad.codegen.late.devectorizer import no_vectorized_alu
base_rewrite = PatternMatcher([ base_rewrite = PatternMatcher([
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"), (UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
(UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"), (UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
@@ -278,6 +279,11 @@ class ClangRenderer(CStyleLanguage):
defines = '\n'.join(self._render_defines(uops)) defines = '\n'.join(self._render_defines(uops))
return defines + "\n" + self._render_body(function_name, kernel, bufs, uops, prefix) + "\n" + self._render_entry(function_name, bufs) return defines + "\n" + self._render_body(function_name, kernel, bufs, uops, prefix) + "\n" + self._render_entry(function_name, bufs)
class ClangJITRenderer(ClangRenderer):
def __init__(self):
from tinygrad.runtime.support.compiler_cpu import ClangJITCompiler
self.compiler = ClangJITCompiler()
class OpenCLRenderer(CStyleLanguage): class OpenCLRenderer(CStyleLanguage):
device = "CL" device = "CL"

View File

@@ -5,10 +5,10 @@ from tinygrad.helpers import CPU_CC, CPU_LVP, CPU_LLVM
from tinygrad.device import BufferSpec, DMACPURef, CompilerSet, CompilerPair from tinygrad.device import BufferSpec, DMACPURef, CompilerSet, CompilerPair
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocatorBase, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocatorBase, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface
from tinygrad.runtime.support.hcq import CLikeArgsState from tinygrad.runtime.support.hcq import CLikeArgsState
from tinygrad.renderer.cstyle import ClangRenderer from tinygrad.renderer.cstyle import ClangJITRenderer
from tinygrad.renderer.llvmir import LLVMRenderer from tinygrad.renderer.llvmir import LLVMRenderer
from tinygrad.renderer.nir import LVPRenderer from tinygrad.renderer.nir import LVPRenderer
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler
from tinygrad.runtime.support.compiler_mesa import LVPCompiler from tinygrad.runtime.support.compiler_mesa import LVPCompiler
from tinygrad.runtime.support.elf import jit_loader from tinygrad.runtime.support.elf import jit_loader
from tinygrad.uop.ops import sint from tinygrad.uop.ops import sint
@@ -136,6 +136,6 @@ class CPUDevice(HCQCompiled):
def __init__(self, device:str=""): def __init__(self, device:str=""):
self.tasks:queue.Queue = queue.Queue() self.tasks:queue.Queue = queue.Queue()
CPUWorker(self, self.tasks, thread_id=0).start() CPUWorker(self, self.tasks, thread_id=0).start()
compilers = CompilerSet([CompilerPair(ClangRenderer, ClangJITCompiler), CompilerPair(LLVMRenderer, CPULLVMCompiler, ctrl_var=CPU_LLVM), compilers = CompilerSet([CompilerPair(ClangJITRenderer, None), CompilerPair(LLVMRenderer, CPULLVMCompiler, ctrl_var=CPU_LLVM),
CompilerPair(LVPRenderer, LVPCompiler, ctrl_var=CPU_LVP)], ctrl_var=CPU_CC) CompilerPair(LVPRenderer, LVPCompiler, ctrl_var=CPU_LVP)], ctrl_var=CPU_CC)
super().__init__(device, CPUAllocator(self), compilers, functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue) super().__init__(device, CPUAllocator(self), compilers, functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue)

View File

@@ -28,8 +28,8 @@ class Ops(FastEnum):
NOOP = auto(); REWRITE_ERROR = auto() NOOP = auto(); REWRITE_ERROR = auto()
# renderer # renderer
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable # LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has bytes arg that's compiled
PROGRAM = auto(); LINEAR = auto(); SOURCE = auto() PROGRAM = auto(); LINEAR = auto(); SOURCE = auto(); BINARY = auto()
# AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:] # AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:]
# GROUP is a NOOP that just merges things together # GROUP is a NOOP that just merges things together

View File

@@ -219,7 +219,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# late ops don't have shape # late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \ case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | \ Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE: Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY:
return None return None
case Ops.INDEX: case Ops.INDEX:

View File

@@ -249,13 +249,15 @@ full_spec = PatternMatcher([
# in progress MSTACK may lose device # in progress MSTACK may lose device
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True), (UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?) # codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True), (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True), (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True), (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
# codegen: standalone LINEAR/SOURCE (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
# codegen: standalone LINEAR/SOURCE/BINARY
(UPat(Ops.LINEAR, dtypes.void), lambda: True), (UPat(Ops.LINEAR, dtypes.void), lambda: True),
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True), (UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
# temp VECTORIZE/INDEX during rewrite have the wrong dtype # temp VECTORIZE/INDEX during rewrite have the wrong dtype
(UPat(Ops.VECTORIZE), lambda: True), (UPat(Ops.VECTORIZE), lambda: True),