From 43c6e973d802fe3c536262593cafff24f3e87948 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 23 Dec 2025 17:58:46 -0500 Subject: [PATCH] add optional compiler in Renderer (#13817) * add optional compiler in Renderer [pr] * fix * late init * remove precompiled * cleanup --- test/external/external_hip_compiler_bug.py | 225 --------------------- test/unit/test_elf.py | 2 +- tinygrad/codegen/__init__.py | 6 + tinygrad/codegen/opt/search.py | 4 +- tinygrad/device.py | 22 +- tinygrad/engine/realize.py | 18 +- tinygrad/helpers.py | 4 +- tinygrad/renderer/__init__.py | 12 +- tinygrad/renderer/cstyle.py | 6 + tinygrad/runtime/ops_cpu.py | 6 +- tinygrad/uop/__init__.py | 4 +- tinygrad/uop/ops.py | 2 +- tinygrad/uop/spec.py | 6 +- 13 files changed, 57 insertions(+), 260 deletions(-) delete mode 100644 test/external/external_hip_compiler_bug.py diff --git a/test/external/external_hip_compiler_bug.py b/test/external/external_hip_compiler_bug.py deleted file mode 100644 index 9dfcf01264..0000000000 --- a/test/external/external_hip_compiler_bug.py +++ /dev/null @@ -1,225 +0,0 @@ -# [, , ] -from tinygrad import Device, dtypes -from tinygrad.device import Buffer, CompiledRunner - -import ctypes -import gpuctypes.hip as hip -from tinygrad.helpers import to_char_p_p, init_c_var -def get_bytes(arg, get_sz, get_str, check) -> bytes: return (sz := init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x)))), ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value))[1] # noqa: E501 -def check(status): - if status != 0: raise RuntimeError(f"HIP Error {status}, {ctypes.string_at(hip.hipGetErrorString(status)).decode()}") -def compile_hip(prg:str, arch="gfx1100") -> bytes: - check(hip.hiprtcCreateProgram(ctypes.byref(prog := hip.hiprtcProgram()), prg.encode(), "".encode(), 0, None, None)) - compile_options = [f'--offload-arch={arch}', '-I/opt/rocm/include'] - status = hip.hiprtcCompileProgram(prog, len(compile_options), to_char_p_p([o.encode() for o in compile_options])) - if status != 0: raise RuntimeError(f"compile failed: {get_bytes(prog, hip.hiprtcGetProgramLogSize, hip.hiprtcGetProgramLog, check).decode()}") - return get_bytes(prog, hip.hiprtcGetCodeSize, hip.hiprtcGetCode, check) - -prefix = """ -typedef long unsigned int size_t; -extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int); -extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int); -extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int); -typedef float float2 __attribute__((ext_vector_type(2))); -static inline __attribute__((device)) float2 make_float2(float x, float y) { return {x, y}; } -""" - -code = """ -extern "C" __attribute__((global))void r_2_8_7_7_4_8_3_7_7_4_4_2_2(float* data0, const float* data1, const float* data2) { - int gidx0 = __ockl_get_group_id(2); /* 2 */ - int gidx1 = __ockl_get_group_id(1); /* 8 */ - int gidx2 = __ockl_get_group_id(0); /* 49 */ - int lidx4 = __ockl_get_local_id(1); /* 4 */ - int lidx5 = __ockl_get_local_id(0); /* 8 */ - float2 acc0 = make_float2(0.0f,0.0f); - float2 acc1 = make_float2(0.0f,0.0f); - float2 acc2 = make_float2(0.0f,0.0f); - float2 acc3 = make_float2(0.0f,0.0f); - float2 acc4 = make_float2(0.0f,0.0f); - float2 acc5 = make_float2(0.0f,0.0f); - float2 acc6 = make_float2(0.0f,0.0f); - float2 acc7 = make_float2(0.0f,0.0f); - float2 acc8 = make_float2(0.0f,0.0f); - float2 acc9 = make_float2(0.0f,0.0f); - float2 acc10 = make_float2(0.0f,0.0f); - float2 acc11 = make_float2(0.0f,0.0f); - float2 acc12 = make_float2(0.0f,0.0f); - float2 acc13 = make_float2(0.0f,0.0f); - float2 acc14 = make_float2(0.0f,0.0f); - float2 acc15 = make_float2(0.0f,0.0f); - float2 acc16 = make_float2(0.0f,0.0f); - float2 acc17 = make_float2(0.0f,0.0f); - float2 acc18 = make_float2(0.0f,0.0f); - float2 acc19 = make_float2(0.0f,0.0f); - float2 acc20 = make_float2(0.0f,0.0f); - float2 acc21 = make_float2(0.0f,0.0f); - float2 acc22 = make_float2(0.0f,0.0f); - float2 acc23 = make_float2(0.0f,0.0f); - float2 acc24 = make_float2(0.0f,0.0f); - float2 acc25 = make_float2(0.0f,0.0f); - float2 acc26 = make_float2(0.0f,0.0f); - float2 acc27 = make_float2(0.0f,0.0f); - float2 acc28 = make_float2(0.0f,0.0f); - float2 acc29 = make_float2(0.0f,0.0f); - float2 acc30 = make_float2(0.0f,0.0f); - float2 acc31 = make_float2(0.0f,0.0f); - int alu0 = (gidx2/7); - int alu1 = (gidx2%7); - int alu2 = (alu1*32); - int alu3 = (lidx5*4); - int alu4 = ((gidx0*802816)+(gidx1*100352)+(alu0*1792)+(alu1*16)+(lidx4*448)+(lidx5*2)); - for (int ridx0 = 0; ridx0 < 3; ridx0++) { - for (int ridx1 = 0; ridx1 < 7; ridx1++) { - int alu5 = ((alu0*(-32))+(lidx4*(-8))+(ridx1*(-1))); - bool alu6 = (alu5<(-2)); - bool alu7 = (alu5<0); - bool alu8 = (((alu0*32)+(lidx4*8)+ridx1)<221); - for (int ridx2 = 0; ridx2 < 7; ridx2++) { - int alu9 = ((gidx0*150528)+(ridx0*50176)+(alu0*7168)+(lidx4*1792)+(ridx1*224)+alu2+alu3+ridx2); - int alu10 = ((alu1*(-32))+(lidx5*(-4))+(ridx2*(-1))); - bool alu11 = (alu10<(-2)); - float val0 = 0.0f; - if ((alu6*alu11)) { val0 = data1[alu9+(-675)]; } - float val1 = 0.0f; - if ((alu7*alu11)) { val1 = data1[alu9+(-227)]; } - float val2 = 0.0f; - if (alu11) { val2 = data1[alu9+221]; } - float val3 = 0.0f; - if ((alu8*alu11)) { val3 = data1[alu9+669]; } - bool alu12 = (alu10<0); - bool alu13 = ((alu2+alu3+ridx2)<225); - float val4 = 0.0f; - if ((alu6*alu12*alu13)) { val4 = data1[alu9+(-673)]; } - float val5 = 0.0f; - if ((alu7*alu12*alu13)) { val5 = data1[alu9+(-225)]; } - float val6 = 0.0f; - if ((alu12*alu13)) { val6 = data1[alu9+223]; } - float val7 = 0.0f; - if ((alu8*alu12*alu13)) { val7 = data1[alu9+671]; } - int alu14 = ((gidx1*1176)+(ridx0*49)+(ridx1*7)+ridx2); - float val8 = data2[alu14]; - float val9 = data2[alu14+147]; - float val10 = data2[alu14+294]; - float val11 = data2[alu14+441]; - float val12 = data2[alu14+588]; - float val13 = data2[alu14+735]; - float val14 = data2[alu14+882]; - float val15 = data2[alu14+1029]; - (acc0).x = ((val0*val8)+(acc0).x); - (acc1).x = ((val0*val9)+(acc1).x); - (acc2).x = ((val0*val10)+(acc2).x); - (acc3).x = ((val0*val11)+(acc3).x); - (acc4).x = ((val1*val8)+(acc4).x); - (acc5).x = ((val1*val9)+(acc5).x); - (acc6).x = ((val1*val10)+(acc6).x); - (acc7).x = ((val1*val11)+(acc7).x); - (acc8).x = ((val2*val8)+(acc8).x); - (acc9).x = ((val2*val9)+(acc9).x); - (acc10).x = ((val2*val10)+(acc10).x); - (acc11).x = ((val2*val11)+(acc11).x); - (acc12).x = ((val3*val8)+(acc12).x); - (acc13).x = ((val3*val9)+(acc13).x); - (acc14).x = ((val3*val10)+(acc14).x); - (acc15).x = ((val3*val11)+(acc15).x); - (acc16).x = ((val0*val12)+(acc16).x); - (acc17).x = ((val0*val13)+(acc17).x); - (acc18).x = ((val0*val14)+(acc18).x); - (acc19).x = ((val0*val15)+(acc19).x); - (acc20).x = ((val1*val12)+(acc20).x); - (acc21).x = ((val1*val13)+(acc21).x); - (acc22).x = ((val1*val14)+(acc22).x); - (acc23).x = ((val1*val15)+(acc23).x); - (acc24).x = ((val2*val12)+(acc24).x); - (acc25).x = ((val2*val13)+(acc25).x); - (acc26).x = ((val2*val14)+(acc26).x); - (acc27).x = ((val2*val15)+(acc27).x); - (acc28).x = ((val3*val12)+(acc28).x); - (acc29).x = ((val3*val13)+(acc29).x); - (acc30).x = ((val3*val14)+(acc30).x); - (acc31).x = ((val3*val15)+(acc31).x); - (acc0).y = ((val4*val8)+(acc0).y); - (acc1).y = ((val4*val9)+(acc1).y); - (acc2).y = ((val4*val10)+(acc2).y); - (acc3).y = ((val4*val11)+(acc3).y); - (acc4).y = ((val5*val8)+(acc4).y); - (acc5).y = ((val5*val9)+(acc5).y); - (acc6).y = ((val5*val10)+(acc6).y); - (acc7).y = ((val5*val11)+(acc7).y); - (acc8).y = ((val6*val8)+(acc8).y); - (acc9).y = ((val6*val9)+(acc9).y); - (acc10).y = ((val6*val10)+(acc10).y); - (acc11).y = ((val6*val11)+(acc11).y); - (acc12).y = ((val7*val8)+(acc12).y); - (acc13).y = ((val7*val9)+(acc13).y); - (acc14).y = ((val7*val10)+(acc14).y); - (acc15).y = ((val7*val11)+(acc15).y); - (acc16).y = ((val4*val12)+(acc16).y); - (acc17).y = ((val4*val13)+(acc17).y); - (acc18).y = ((val4*val14)+(acc18).y); - (acc19).y = ((val4*val15)+(acc19).y); - (acc20).y = ((val5*val12)+(acc20).y); - (acc21).y = ((val5*val13)+(acc21).y); - (acc22).y = ((val5*val14)+(acc22).y); - (acc23).y = ((val5*val15)+(acc23).y); - (acc24).y = ((val6*val12)+(acc24).y); - (acc25).y = ((val6*val13)+(acc25).y); - (acc26).y = ((val6*val14)+(acc26).y); - (acc27).y = ((val6*val15)+(acc27).y); - (acc28).y = ((val7*val12)+(acc28).y); - (acc29).y = ((val7*val13)+(acc29).y); - (acc30).y = ((val7*val14)+(acc30).y); - (acc31).y = ((val7*val15)+(acc31).y); - } - } - } - *((float2*)(data0+alu4)) = acc0; - *((float2*)(data0+alu4+12544)) = acc1; - *((float2*)(data0+alu4+25088)) = acc2; - *((float2*)(data0+alu4+37632)) = acc3; - *((float2*)(data0+alu4+112)) = acc4; - *((float2*)(data0+alu4+12656)) = acc5; - *((float2*)(data0+alu4+25200)) = acc6; - *((float2*)(data0+alu4+37744)) = acc7; - *((float2*)(data0+alu4+224)) = acc8; - *((float2*)(data0+alu4+12768)) = acc9; - *((float2*)(data0+alu4+25312)) = acc10; - *((float2*)(data0+alu4+37856)) = acc11; - *((float2*)(data0+alu4+336)) = acc12; - *((float2*)(data0+alu4+12880)) = acc13; - *((float2*)(data0+alu4+25424)) = acc14; - *((float2*)(data0+alu4+37968)) = acc15; - *((float2*)(data0+alu4+50176)) = acc16; - *((float2*)(data0+alu4+62720)) = acc17; - *((float2*)(data0+alu4+75264)) = acc18; - *((float2*)(data0+alu4+87808)) = acc19; - *((float2*)(data0+alu4+50288)) = acc20; - *((float2*)(data0+alu4+62832)) = acc21; - *((float2*)(data0+alu4+75376)) = acc22; - *((float2*)(data0+alu4+87920)) = acc23; - *((float2*)(data0+alu4+50400)) = acc24; - *((float2*)(data0+alu4+62944)) = acc25; - *((float2*)(data0+alu4+75488)) = acc26; - *((float2*)(data0+alu4+88032)) = acc27; - *((float2*)(data0+alu4+50512)) = acc28; - *((float2*)(data0+alu4+63056)) = acc29; - *((float2*)(data0+alu4+75600)) = acc30; - *((float2*)(data0+alu4+88144)) = acc31; -} -""" - -dev = "HIP" -lib = Device[dev].compiler.compile(prefix+code) -#lib = compile_hip(code) -b0 = Buffer(dev, 1605632, dtypes.float) -b1 = Buffer(dev, 301506, dtypes.float) -b2 = Buffer(dev, 9408, dtypes.float) -print(hex(b0._buf.value), hex(b0._buf.value+1605632*4)) -print(hex(b1._buf.value)) -print(hex(b2._buf.value)) -#prg = CompiledRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [7, 1, 1], [8, 4, 1], precompiled=lib) -prg = CompiledRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [49, 8, 2], [8, 4, 1], precompiled=lib) -print("compiled") -prg([b0, b1, b2], {}) -print("ran") -Device[dev].synchronize() -print("sync") diff --git a/test/unit/test_elf.py b/test/unit/test_elf.py index 2dca88b077..9207a384fa 100644 --- a/test/unit/test_elf.py +++ b/test/unit/test_elf.py @@ -1,5 +1,5 @@ import unittest, subprocess, platform -from tinygrad.runtime.ops_cpu import ClangJITCompiler +from tinygrad.runtime.support.compiler_cpu import ClangJITCompiler from tinygrad.runtime.support.elf import elf_loader class TestElfLoader(unittest.TestCase): diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index b22c850fae..0a7ecfdad2 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -135,9 +135,15 @@ def do_render(ctx:Renderer, prg:UOp, lin:UOp) -> UOp: src = ctx.render(list(lin.src)) return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),)) +def do_compile(ctx:Renderer, prg:UOp, source:UOp) -> UOp|None: + if ctx.compiler is None: return None + lib = ctx.compiler.compile_cached(source.arg) + return prg.replace(src=prg.src + (UOp(Ops.BINARY, arg=lib),)) + pm_to_program = PatternMatcher([ (UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"), UPat(Ops.DEVICE)), name="prg"), do_linearize), (UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render), + (UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE, name="source")), name="prg"), do_compile), ]) @track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True) diff --git a/tinygrad/codegen/opt/search.py b/tinygrad/codegen/opt/search.py index aa33c14bbf..18d7ea49bc 100644 --- a/tinygrad/codegen/opt/search.py +++ b/tinygrad/codegen/opt/search.py @@ -41,7 +41,7 @@ def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:lis if allow_test_size and max_global_size is not None: global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals) p = replace(p, global_size=global_size) - try: car = CompiledRunner(p, precompiled=lib) + try: car = CompiledRunner(replace(p, lib=lib)) except AssertionError: return [math.inf] * cnt tms = [] input_bufs = [rawbufs[i] for i in car.p.globals] @@ -72,7 +72,7 @@ def _try_compile(x:tuple[int,Scheduler], compiler:Compiler) -> tuple[int, tuple[ if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many uops. {len(p.uops)=}, {uops_max=}") raise RuntimeError("too many uops") st = time.perf_counter() - prog = compiler.compile(p.src) + prog = p.lib if p.lib is not None else compiler.compile(p.src) et = time.perf_counter() - st ret = (p, prog, et) except RuntimeError: diff --git a/tinygrad/device.py b/tinygrad/device.py index f8e5b5e915..1a76a89b47 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -278,7 +278,7 @@ class Compiler: def disassemble(self, lib:bytes): pass @dataclass(frozen=True) -class CompilerPair: renderer:type[Renderer]|functools.partial; compiler:type[Compiler]|functools.partial; ctrl_var:ContextVar|None = None # noqa: E702 +class CompilerPair: renderer:type[Renderer]|functools.partial; compiler:type[Compiler]|functools.partial|None; ctrl_var:ContextVar|None = None # noqa: E702 @dataclass(frozen=True) class CompilerSet: cset:list[CompilerPair]; ctrl_var:ContextVar|None = None # noqa: E702 @@ -290,21 +290,25 @@ class Compiled: self.device, self.allocator, self.runtime, self.graph, self.group_id = device, allocator, runtime, graph, group_id self.comps_ctrl_var = compilers.ctrl_var if compilers is not None else None - self.comp_sets:dict[Any, tuple[ContextVar|None, tuple[type[Renderer]|functools.partial, type[Compiler]|functools.partial]]] = {} - self.cached_pair:dict[Any, tuple[Renderer, Compiler]] = {} + self.comp_sets:dict[Any, tuple[ContextVar|None, tuple[type[Renderer]|functools.partial, type[Compiler]|functools.partial|None]]] = {} + self.cached_pair:dict[Any, tuple[Renderer, Compiler|None]] = {} for cpair in (compilers.cset if compilers is not None else [CompilerPair(Renderer, Compiler)]): - self.comp_sets[self._compiler_name(cpair.compiler)] = (cpair.ctrl_var, (cpair.renderer, cpair.compiler)) + self.comp_sets[self._compiler_name(cpair.renderer, cpair.compiler)] = (cpair.ctrl_var, (cpair.renderer, cpair.compiler)) @property def renderer(self) -> Renderer: return self._select_compiler_pair()[0] @property - def compiler(self) -> Compiler: return self._select_compiler_pair()[1] + def compiler(self) -> Compiler: + if (ret:=self.renderer.compiler or self._select_compiler_pair()[1]) is None: raise RuntimeError(f"no compiler for {self.device}") + return ret - def _compiler_name(self, c:type[Compiler]|functools.partial) -> str: - return unwrap_class_type(c).__name__.upper().removesuffix("COMPILER").removeprefix(devname:=self.device.split(':')[0].upper()) or devname + def _compiler_name(self, r:type[Renderer]|functools.partial, c:type[Compiler]|functools.partial|None) -> str: + devname = self.device.split(':')[0].upper() + if c is None: return unwrap_class_type(r).__name__.upper().removesuffix("RENDERER").removeprefix(devname) or devname + return unwrap_class_type(c).__name__.upper().removesuffix("COMPILER").removeprefix(devname) or devname - def _select_compiler_pair(self) -> tuple[Renderer, Compiler]: + def _select_compiler_pair(self) -> tuple[Renderer, Compiler|None]: # select forced compiler from global env var. forced_comps = set([self.comp_sets[val][1]] if self.comps_ctrl_var is not None and (val:=self.comps_ctrl_var.value) else []) @@ -398,7 +402,7 @@ def enumerate_devices_str() -> Generator[str, None, None]: # d.renderer, d.compiler = r(), c() with Context(CACHELEVEL=0): test = (Tensor([1,2,3], device=device) * 2).tolist() if test != [2,4,6]: raise ValueError(f"got {test} instead of [2, 4, 6]") - set_text = f'({cc_ctrl_var.key}={d._compiler_name(c)} to make default)' if cc_ctrl_var is not None else '' + set_text = f'({cc_ctrl_var.key}={d._compiler_name(r, c)} to make default)' if cc_ctrl_var is not None else '' default_text = '(default)' if type(default_compiler) is type(d.compiler) else set_text compilers_results.append(f"{colored('+', 'green')} {unwrap_class_type(c).__name__} {default_text}") any_works = True diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index ab3bdac1ab..70234e7071 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -36,19 +36,19 @@ def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffe return ret[1] class CompiledRunner(Runner): - def __init__(self, p:ProgramSpec, precompiled:bytes|None=None, prg=None): + def __init__(self, p:ProgramSpec, prg=None): if DEBUG >= 3: print(p.applied_opts) if DEBUG >= 4: print(p.src) - self.p:ProgramSpec = p - if precompiled is not None: self.lib = precompiled - else: + if p.lib is None: with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"): - self.lib = Device[p.device].compiler.compile_cached(p.src) - if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib) - self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg + p = replace(p, lib=Device[p.device].compiler.compile_cached(p.src)) + self.p:ProgramSpec = p + assert self.p.lib is not None + if DEBUG >= 7: Device[p.device].compiler.disassemble(self.p.lib) + self._prg = Device[p.device].runtime(p.function_name, self.p.lib) if prg is None else prg super().__init__(p.name, p.device, p.estimates) - def __reduce__(self): return self.__class__, (self.p, self.lib) + def __reduce__(self): return self.__class__, (self.p,) def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False) -> float|None: if var_vals is None: var_vals = {} @@ -115,7 +115,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner: if cret:=method_cache.get(ckey): return cret bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True) if bret:=method_cache.get(bkey): - method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib) + method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device)) else: prg: ProgramSpec = get_program(ast, Device[device].renderer) method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device)) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 324764b296..8069ae842f 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -115,12 +115,12 @@ def suppress_finalizing(func): if not getattr(sys, 'is_finalizing', lambda: True)(): raise # re-raise if not finalizing return wrapper -def select_first_inited(candidates:Sequence[Callable[...,T]|Sequence[Callable[...,T]]], err_msg:str, cache:dict|None=None) -> tuple[T,...]|T: +def select_first_inited(candidates:Sequence[Callable[...,T]|Sequence[Callable[...,T]|None]], err_msg:str, cache:dict|None=None): excs = [] for typ in candidates: if cache is not None and typ in cache: return cache[typ] try: - x = tuple([cast(Callable, t)() for t in typ]) if isinstance(typ, Sequence) else cast(Callable, typ)() + x = tuple([cast(Callable, t)() if t is not None else None for t in typ]) if isinstance(typ, Sequence) else cast(Callable, typ)() if cache is not None: cache[typ] = x return x except Exception as e: excs.append(e) diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index c871281c65..0d69e1df7a 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Callable, cast +from typing import Callable, cast, TYPE_CHECKING import functools from dataclasses import dataclass, field from tinygrad.helpers import to_function_name, dedup, prod, DEBUG @@ -7,6 +7,7 @@ from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, Gro from tinygrad.dtype import AddrSpace, PtrDType from tinygrad.codegen.opt.tc import TensorCore from tinygrad.codegen.opt import Opt +if TYPE_CHECKING: from tinygrad.device import Compiler @dataclass(frozen=True) class Estimates: @@ -64,6 +65,7 @@ class ProgramSpec: device:str ast:UOp # save the base ast (this is method cache key) uops:list[UOp]|None=None + lib:bytes|None=None # filled in from uops (via from_uop) global_size:list[int]=field(default_factory=lambda: [1,1,1]) @@ -95,8 +97,9 @@ class ProgramSpec: def from_uop(prg:UOp) -> ProgramSpec: """Construct ProgramSpec from a PROGRAM UOp.""" assert prg.op is Ops.PROGRAM, f"expected PROGRAM, got {prg.op}" - # SINK/DEVICE/LINEAR/SOURCE - sink, device, linear, source = prg.src + # SINK/DEVICE/LINEAR/SOURCE/BINARY? + sink, device, linear, source = prg.src[:4] + lib = prg.src[4].arg if len(prg.src) > 4 else None uops = list(linear.src) if DEBUG >= 6: print_uops(uops) # LINEAR is src[2] @@ -120,7 +123,7 @@ class ProgramSpec: # TODO: this cast is wrong, u.src[0].ssimplify() can be sint if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify()) - return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, uops, global_size, local_size, + return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, uops, lib, global_size, local_size, sorted(_vars, key=lambda v: v.arg), sorted(dedup(_globals)), sorted(dedup(outs)), sorted(dedup(ins))) class Renderer: @@ -139,6 +142,7 @@ class Renderer: pre_matcher: PatternMatcher|None = None extra_matcher: PatternMatcher|None = None code_for_op: dict[Ops, Callable] = {} + compiler: Compiler|None = None def __reduce__(self): return self.__class__, () def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer") diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index cdd1b3c89b..0a26b02b65 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -8,6 +8,7 @@ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, trunc from tinygrad.renderer import Renderer from tinygrad.codegen.late.devectorizer import no_vectorized_alu + base_rewrite = PatternMatcher([ (UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"), (UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"), @@ -278,6 +279,11 @@ class ClangRenderer(CStyleLanguage): defines = '\n'.join(self._render_defines(uops)) return defines + "\n" + self._render_body(function_name, kernel, bufs, uops, prefix) + "\n" + self._render_entry(function_name, bufs) +class ClangJITRenderer(ClangRenderer): + def __init__(self): + from tinygrad.runtime.support.compiler_cpu import ClangJITCompiler + self.compiler = ClangJITCompiler() + class OpenCLRenderer(CStyleLanguage): device = "CL" diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py index 4c117b9b6f..2e76328e22 100644 --- a/tinygrad/runtime/ops_cpu.py +++ b/tinygrad/runtime/ops_cpu.py @@ -5,10 +5,10 @@ from tinygrad.helpers import CPU_CC, CPU_LVP, CPU_LLVM from tinygrad.device import BufferSpec, DMACPURef, CompilerSet, CompilerPair from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocatorBase, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface from tinygrad.runtime.support.hcq import CLikeArgsState -from tinygrad.renderer.cstyle import ClangRenderer +from tinygrad.renderer.cstyle import ClangJITRenderer from tinygrad.renderer.llvmir import LLVMRenderer from tinygrad.renderer.nir import LVPRenderer -from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler +from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler from tinygrad.runtime.support.compiler_mesa import LVPCompiler from tinygrad.runtime.support.elf import jit_loader from tinygrad.uop.ops import sint @@ -136,6 +136,6 @@ class CPUDevice(HCQCompiled): def __init__(self, device:str=""): self.tasks:queue.Queue = queue.Queue() CPUWorker(self, self.tasks, thread_id=0).start() - compilers = CompilerSet([CompilerPair(ClangRenderer, ClangJITCompiler), CompilerPair(LLVMRenderer, CPULLVMCompiler, ctrl_var=CPU_LLVM), + compilers = CompilerSet([CompilerPair(ClangJITRenderer, None), CompilerPair(LLVMRenderer, CPULLVMCompiler, ctrl_var=CPU_LLVM), CompilerPair(LVPRenderer, LVPCompiler, ctrl_var=CPU_LVP)], ctrl_var=CPU_CC) super().__init__(device, CPUAllocator(self), compilers, functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue) diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index 898c014986..ad6776a067 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -28,8 +28,8 @@ class Ops(FastEnum): NOOP = auto(); REWRITE_ERROR = auto() # renderer - # LINEAR is a list of UOps, SOURCE has a str arg that's human readable - PROGRAM = auto(); LINEAR = auto(); SOURCE = auto() + # LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has bytes arg that's compiled + PROGRAM = auto(); LINEAR = auto(); SOURCE = auto(); BINARY = auto() # AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:] # GROUP is a NOOP that just merges things together diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index f016166e74..ce4d0c2d3e 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -219,7 +219,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): # late ops don't have shape case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \ Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | \ - Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE: + Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY: return None case Ops.INDEX: diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 8988a5f5bf..c17c88822a 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -249,13 +249,15 @@ full_spec = PatternMatcher([ # in progress MSTACK may lose device (UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True), - # codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?) + # codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?) (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True), (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True), (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True), - # codegen: standalone LINEAR/SOURCE + (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True), + # codegen: standalone LINEAR/SOURCE/BINARY (UPat(Ops.LINEAR, dtypes.void), lambda: True), (UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True), + (UPat(Ops.BINARY, dtypes.void, src=()), lambda: True), # temp VECTORIZE/INDEX during rewrite have the wrong dtype (UPat(Ops.VECTORIZE), lambda: True),