mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
add optional compiler in Renderer (#13817)
* add optional compiler in Renderer [pr] * fix * late init * remove precompiled * cleanup
This commit is contained in:
225
test/external/external_hip_compiler_bug.py
vendored
225
test/external/external_hip_compiler_bug.py
vendored
@@ -1,225 +0,0 @@
|
||||
# [<buf device:HIP size:1605632 dtype:dtypes.float>, <buf device:HIP size:301506 dtype:dtypes.float>, <buf device:HIP size:9408 dtype:dtypes.float>]
|
||||
from tinygrad import Device, dtypes
|
||||
from tinygrad.device import Buffer, CompiledRunner
|
||||
|
||||
import ctypes
|
||||
import gpuctypes.hip as hip
|
||||
from tinygrad.helpers import to_char_p_p, init_c_var
|
||||
def get_bytes(arg, get_sz, get_str, check) -> bytes: return (sz := init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x)))), ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value))[1] # noqa: E501
|
||||
def check(status):
|
||||
if status != 0: raise RuntimeError(f"HIP Error {status}, {ctypes.string_at(hip.hipGetErrorString(status)).decode()}")
|
||||
def compile_hip(prg:str, arch="gfx1100") -> bytes:
|
||||
check(hip.hiprtcCreateProgram(ctypes.byref(prog := hip.hiprtcProgram()), prg.encode(), "<null>".encode(), 0, None, None))
|
||||
compile_options = [f'--offload-arch={arch}', '-I/opt/rocm/include']
|
||||
status = hip.hiprtcCompileProgram(prog, len(compile_options), to_char_p_p([o.encode() for o in compile_options]))
|
||||
if status != 0: raise RuntimeError(f"compile failed: {get_bytes(prog, hip.hiprtcGetProgramLogSize, hip.hiprtcGetProgramLog, check).decode()}")
|
||||
return get_bytes(prog, hip.hiprtcGetCodeSize, hip.hiprtcGetCode, check)
|
||||
|
||||
prefix = """
|
||||
typedef long unsigned int size_t;
|
||||
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
|
||||
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
|
||||
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
|
||||
typedef float float2 __attribute__((ext_vector_type(2)));
|
||||
static inline __attribute__((device)) float2 make_float2(float x, float y) { return {x, y}; }
|
||||
"""
|
||||
|
||||
code = """
|
||||
extern "C" __attribute__((global))void r_2_8_7_7_4_8_3_7_7_4_4_2_2(float* data0, const float* data1, const float* data2) {
|
||||
int gidx0 = __ockl_get_group_id(2); /* 2 */
|
||||
int gidx1 = __ockl_get_group_id(1); /* 8 */
|
||||
int gidx2 = __ockl_get_group_id(0); /* 49 */
|
||||
int lidx4 = __ockl_get_local_id(1); /* 4 */
|
||||
int lidx5 = __ockl_get_local_id(0); /* 8 */
|
||||
float2 acc0 = make_float2(0.0f,0.0f);
|
||||
float2 acc1 = make_float2(0.0f,0.0f);
|
||||
float2 acc2 = make_float2(0.0f,0.0f);
|
||||
float2 acc3 = make_float2(0.0f,0.0f);
|
||||
float2 acc4 = make_float2(0.0f,0.0f);
|
||||
float2 acc5 = make_float2(0.0f,0.0f);
|
||||
float2 acc6 = make_float2(0.0f,0.0f);
|
||||
float2 acc7 = make_float2(0.0f,0.0f);
|
||||
float2 acc8 = make_float2(0.0f,0.0f);
|
||||
float2 acc9 = make_float2(0.0f,0.0f);
|
||||
float2 acc10 = make_float2(0.0f,0.0f);
|
||||
float2 acc11 = make_float2(0.0f,0.0f);
|
||||
float2 acc12 = make_float2(0.0f,0.0f);
|
||||
float2 acc13 = make_float2(0.0f,0.0f);
|
||||
float2 acc14 = make_float2(0.0f,0.0f);
|
||||
float2 acc15 = make_float2(0.0f,0.0f);
|
||||
float2 acc16 = make_float2(0.0f,0.0f);
|
||||
float2 acc17 = make_float2(0.0f,0.0f);
|
||||
float2 acc18 = make_float2(0.0f,0.0f);
|
||||
float2 acc19 = make_float2(0.0f,0.0f);
|
||||
float2 acc20 = make_float2(0.0f,0.0f);
|
||||
float2 acc21 = make_float2(0.0f,0.0f);
|
||||
float2 acc22 = make_float2(0.0f,0.0f);
|
||||
float2 acc23 = make_float2(0.0f,0.0f);
|
||||
float2 acc24 = make_float2(0.0f,0.0f);
|
||||
float2 acc25 = make_float2(0.0f,0.0f);
|
||||
float2 acc26 = make_float2(0.0f,0.0f);
|
||||
float2 acc27 = make_float2(0.0f,0.0f);
|
||||
float2 acc28 = make_float2(0.0f,0.0f);
|
||||
float2 acc29 = make_float2(0.0f,0.0f);
|
||||
float2 acc30 = make_float2(0.0f,0.0f);
|
||||
float2 acc31 = make_float2(0.0f,0.0f);
|
||||
int alu0 = (gidx2/7);
|
||||
int alu1 = (gidx2%7);
|
||||
int alu2 = (alu1*32);
|
||||
int alu3 = (lidx5*4);
|
||||
int alu4 = ((gidx0*802816)+(gidx1*100352)+(alu0*1792)+(alu1*16)+(lidx4*448)+(lidx5*2));
|
||||
for (int ridx0 = 0; ridx0 < 3; ridx0++) {
|
||||
for (int ridx1 = 0; ridx1 < 7; ridx1++) {
|
||||
int alu5 = ((alu0*(-32))+(lidx4*(-8))+(ridx1*(-1)));
|
||||
bool alu6 = (alu5<(-2));
|
||||
bool alu7 = (alu5<0);
|
||||
bool alu8 = (((alu0*32)+(lidx4*8)+ridx1)<221);
|
||||
for (int ridx2 = 0; ridx2 < 7; ridx2++) {
|
||||
int alu9 = ((gidx0*150528)+(ridx0*50176)+(alu0*7168)+(lidx4*1792)+(ridx1*224)+alu2+alu3+ridx2);
|
||||
int alu10 = ((alu1*(-32))+(lidx5*(-4))+(ridx2*(-1)));
|
||||
bool alu11 = (alu10<(-2));
|
||||
float val0 = 0.0f;
|
||||
if ((alu6*alu11)) { val0 = data1[alu9+(-675)]; }
|
||||
float val1 = 0.0f;
|
||||
if ((alu7*alu11)) { val1 = data1[alu9+(-227)]; }
|
||||
float val2 = 0.0f;
|
||||
if (alu11) { val2 = data1[alu9+221]; }
|
||||
float val3 = 0.0f;
|
||||
if ((alu8*alu11)) { val3 = data1[alu9+669]; }
|
||||
bool alu12 = (alu10<0);
|
||||
bool alu13 = ((alu2+alu3+ridx2)<225);
|
||||
float val4 = 0.0f;
|
||||
if ((alu6*alu12*alu13)) { val4 = data1[alu9+(-673)]; }
|
||||
float val5 = 0.0f;
|
||||
if ((alu7*alu12*alu13)) { val5 = data1[alu9+(-225)]; }
|
||||
float val6 = 0.0f;
|
||||
if ((alu12*alu13)) { val6 = data1[alu9+223]; }
|
||||
float val7 = 0.0f;
|
||||
if ((alu8*alu12*alu13)) { val7 = data1[alu9+671]; }
|
||||
int alu14 = ((gidx1*1176)+(ridx0*49)+(ridx1*7)+ridx2);
|
||||
float val8 = data2[alu14];
|
||||
float val9 = data2[alu14+147];
|
||||
float val10 = data2[alu14+294];
|
||||
float val11 = data2[alu14+441];
|
||||
float val12 = data2[alu14+588];
|
||||
float val13 = data2[alu14+735];
|
||||
float val14 = data2[alu14+882];
|
||||
float val15 = data2[alu14+1029];
|
||||
(acc0).x = ((val0*val8)+(acc0).x);
|
||||
(acc1).x = ((val0*val9)+(acc1).x);
|
||||
(acc2).x = ((val0*val10)+(acc2).x);
|
||||
(acc3).x = ((val0*val11)+(acc3).x);
|
||||
(acc4).x = ((val1*val8)+(acc4).x);
|
||||
(acc5).x = ((val1*val9)+(acc5).x);
|
||||
(acc6).x = ((val1*val10)+(acc6).x);
|
||||
(acc7).x = ((val1*val11)+(acc7).x);
|
||||
(acc8).x = ((val2*val8)+(acc8).x);
|
||||
(acc9).x = ((val2*val9)+(acc9).x);
|
||||
(acc10).x = ((val2*val10)+(acc10).x);
|
||||
(acc11).x = ((val2*val11)+(acc11).x);
|
||||
(acc12).x = ((val3*val8)+(acc12).x);
|
||||
(acc13).x = ((val3*val9)+(acc13).x);
|
||||
(acc14).x = ((val3*val10)+(acc14).x);
|
||||
(acc15).x = ((val3*val11)+(acc15).x);
|
||||
(acc16).x = ((val0*val12)+(acc16).x);
|
||||
(acc17).x = ((val0*val13)+(acc17).x);
|
||||
(acc18).x = ((val0*val14)+(acc18).x);
|
||||
(acc19).x = ((val0*val15)+(acc19).x);
|
||||
(acc20).x = ((val1*val12)+(acc20).x);
|
||||
(acc21).x = ((val1*val13)+(acc21).x);
|
||||
(acc22).x = ((val1*val14)+(acc22).x);
|
||||
(acc23).x = ((val1*val15)+(acc23).x);
|
||||
(acc24).x = ((val2*val12)+(acc24).x);
|
||||
(acc25).x = ((val2*val13)+(acc25).x);
|
||||
(acc26).x = ((val2*val14)+(acc26).x);
|
||||
(acc27).x = ((val2*val15)+(acc27).x);
|
||||
(acc28).x = ((val3*val12)+(acc28).x);
|
||||
(acc29).x = ((val3*val13)+(acc29).x);
|
||||
(acc30).x = ((val3*val14)+(acc30).x);
|
||||
(acc31).x = ((val3*val15)+(acc31).x);
|
||||
(acc0).y = ((val4*val8)+(acc0).y);
|
||||
(acc1).y = ((val4*val9)+(acc1).y);
|
||||
(acc2).y = ((val4*val10)+(acc2).y);
|
||||
(acc3).y = ((val4*val11)+(acc3).y);
|
||||
(acc4).y = ((val5*val8)+(acc4).y);
|
||||
(acc5).y = ((val5*val9)+(acc5).y);
|
||||
(acc6).y = ((val5*val10)+(acc6).y);
|
||||
(acc7).y = ((val5*val11)+(acc7).y);
|
||||
(acc8).y = ((val6*val8)+(acc8).y);
|
||||
(acc9).y = ((val6*val9)+(acc9).y);
|
||||
(acc10).y = ((val6*val10)+(acc10).y);
|
||||
(acc11).y = ((val6*val11)+(acc11).y);
|
||||
(acc12).y = ((val7*val8)+(acc12).y);
|
||||
(acc13).y = ((val7*val9)+(acc13).y);
|
||||
(acc14).y = ((val7*val10)+(acc14).y);
|
||||
(acc15).y = ((val7*val11)+(acc15).y);
|
||||
(acc16).y = ((val4*val12)+(acc16).y);
|
||||
(acc17).y = ((val4*val13)+(acc17).y);
|
||||
(acc18).y = ((val4*val14)+(acc18).y);
|
||||
(acc19).y = ((val4*val15)+(acc19).y);
|
||||
(acc20).y = ((val5*val12)+(acc20).y);
|
||||
(acc21).y = ((val5*val13)+(acc21).y);
|
||||
(acc22).y = ((val5*val14)+(acc22).y);
|
||||
(acc23).y = ((val5*val15)+(acc23).y);
|
||||
(acc24).y = ((val6*val12)+(acc24).y);
|
||||
(acc25).y = ((val6*val13)+(acc25).y);
|
||||
(acc26).y = ((val6*val14)+(acc26).y);
|
||||
(acc27).y = ((val6*val15)+(acc27).y);
|
||||
(acc28).y = ((val7*val12)+(acc28).y);
|
||||
(acc29).y = ((val7*val13)+(acc29).y);
|
||||
(acc30).y = ((val7*val14)+(acc30).y);
|
||||
(acc31).y = ((val7*val15)+(acc31).y);
|
||||
}
|
||||
}
|
||||
}
|
||||
*((float2*)(data0+alu4)) = acc0;
|
||||
*((float2*)(data0+alu4+12544)) = acc1;
|
||||
*((float2*)(data0+alu4+25088)) = acc2;
|
||||
*((float2*)(data0+alu4+37632)) = acc3;
|
||||
*((float2*)(data0+alu4+112)) = acc4;
|
||||
*((float2*)(data0+alu4+12656)) = acc5;
|
||||
*((float2*)(data0+alu4+25200)) = acc6;
|
||||
*((float2*)(data0+alu4+37744)) = acc7;
|
||||
*((float2*)(data0+alu4+224)) = acc8;
|
||||
*((float2*)(data0+alu4+12768)) = acc9;
|
||||
*((float2*)(data0+alu4+25312)) = acc10;
|
||||
*((float2*)(data0+alu4+37856)) = acc11;
|
||||
*((float2*)(data0+alu4+336)) = acc12;
|
||||
*((float2*)(data0+alu4+12880)) = acc13;
|
||||
*((float2*)(data0+alu4+25424)) = acc14;
|
||||
*((float2*)(data0+alu4+37968)) = acc15;
|
||||
*((float2*)(data0+alu4+50176)) = acc16;
|
||||
*((float2*)(data0+alu4+62720)) = acc17;
|
||||
*((float2*)(data0+alu4+75264)) = acc18;
|
||||
*((float2*)(data0+alu4+87808)) = acc19;
|
||||
*((float2*)(data0+alu4+50288)) = acc20;
|
||||
*((float2*)(data0+alu4+62832)) = acc21;
|
||||
*((float2*)(data0+alu4+75376)) = acc22;
|
||||
*((float2*)(data0+alu4+87920)) = acc23;
|
||||
*((float2*)(data0+alu4+50400)) = acc24;
|
||||
*((float2*)(data0+alu4+62944)) = acc25;
|
||||
*((float2*)(data0+alu4+75488)) = acc26;
|
||||
*((float2*)(data0+alu4+88032)) = acc27;
|
||||
*((float2*)(data0+alu4+50512)) = acc28;
|
||||
*((float2*)(data0+alu4+63056)) = acc29;
|
||||
*((float2*)(data0+alu4+75600)) = acc30;
|
||||
*((float2*)(data0+alu4+88144)) = acc31;
|
||||
}
|
||||
"""
|
||||
|
||||
dev = "HIP"
|
||||
lib = Device[dev].compiler.compile(prefix+code)
|
||||
#lib = compile_hip(code)
|
||||
b0 = Buffer(dev, 1605632, dtypes.float)
|
||||
b1 = Buffer(dev, 301506, dtypes.float)
|
||||
b2 = Buffer(dev, 9408, dtypes.float)
|
||||
print(hex(b0._buf.value), hex(b0._buf.value+1605632*4))
|
||||
print(hex(b1._buf.value))
|
||||
print(hex(b2._buf.value))
|
||||
#prg = CompiledRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [7, 1, 1], [8, 4, 1], precompiled=lib)
|
||||
prg = CompiledRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [49, 8, 2], [8, 4, 1], precompiled=lib)
|
||||
print("compiled")
|
||||
prg([b0, b1, b2], {})
|
||||
print("ran")
|
||||
Device[dev].synchronize()
|
||||
print("sync")
|
||||
@@ -1,5 +1,5 @@
|
||||
import unittest, subprocess, platform
|
||||
from tinygrad.runtime.ops_cpu import ClangJITCompiler
|
||||
from tinygrad.runtime.support.compiler_cpu import ClangJITCompiler
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
|
||||
class TestElfLoader(unittest.TestCase):
|
||||
|
||||
@@ -135,9 +135,15 @@ def do_render(ctx:Renderer, prg:UOp, lin:UOp) -> UOp:
|
||||
src = ctx.render(list(lin.src))
|
||||
return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),))
|
||||
|
||||
def do_compile(ctx:Renderer, prg:UOp, source:UOp) -> UOp|None:
|
||||
if ctx.compiler is None: return None
|
||||
lib = ctx.compiler.compile_cached(source.arg)
|
||||
return prg.replace(src=prg.src + (UOp(Ops.BINARY, arg=lib),))
|
||||
|
||||
pm_to_program = PatternMatcher([
|
||||
(UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"), UPat(Ops.DEVICE)), name="prg"), do_linearize),
|
||||
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
|
||||
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE, name="source")), name="prg"), do_compile),
|
||||
])
|
||||
|
||||
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
|
||||
|
||||
@@ -41,7 +41,7 @@ def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:lis
|
||||
if allow_test_size and max_global_size is not None:
|
||||
global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals)
|
||||
p = replace(p, global_size=global_size)
|
||||
try: car = CompiledRunner(p, precompiled=lib)
|
||||
try: car = CompiledRunner(replace(p, lib=lib))
|
||||
except AssertionError: return [math.inf] * cnt
|
||||
tms = []
|
||||
input_bufs = [rawbufs[i] for i in car.p.globals]
|
||||
@@ -72,7 +72,7 @@ def _try_compile(x:tuple[int,Scheduler], compiler:Compiler) -> tuple[int, tuple[
|
||||
if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many uops. {len(p.uops)=}, {uops_max=}")
|
||||
raise RuntimeError("too many uops")
|
||||
st = time.perf_counter()
|
||||
prog = compiler.compile(p.src)
|
||||
prog = p.lib if p.lib is not None else compiler.compile(p.src)
|
||||
et = time.perf_counter() - st
|
||||
ret = (p, prog, et)
|
||||
except RuntimeError:
|
||||
|
||||
@@ -278,7 +278,7 @@ class Compiler:
|
||||
def disassemble(self, lib:bytes): pass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CompilerPair: renderer:type[Renderer]|functools.partial; compiler:type[Compiler]|functools.partial; ctrl_var:ContextVar|None = None # noqa: E702
|
||||
class CompilerPair: renderer:type[Renderer]|functools.partial; compiler:type[Compiler]|functools.partial|None; ctrl_var:ContextVar|None = None # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CompilerSet: cset:list[CompilerPair]; ctrl_var:ContextVar|None = None # noqa: E702
|
||||
@@ -290,21 +290,25 @@ class Compiled:
|
||||
self.device, self.allocator, self.runtime, self.graph, self.group_id = device, allocator, runtime, graph, group_id
|
||||
|
||||
self.comps_ctrl_var = compilers.ctrl_var if compilers is not None else None
|
||||
self.comp_sets:dict[Any, tuple[ContextVar|None, tuple[type[Renderer]|functools.partial, type[Compiler]|functools.partial]]] = {}
|
||||
self.cached_pair:dict[Any, tuple[Renderer, Compiler]] = {}
|
||||
self.comp_sets:dict[Any, tuple[ContextVar|None, tuple[type[Renderer]|functools.partial, type[Compiler]|functools.partial|None]]] = {}
|
||||
self.cached_pair:dict[Any, tuple[Renderer, Compiler|None]] = {}
|
||||
for cpair in (compilers.cset if compilers is not None else [CompilerPair(Renderer, Compiler)]):
|
||||
self.comp_sets[self._compiler_name(cpair.compiler)] = (cpair.ctrl_var, (cpair.renderer, cpair.compiler))
|
||||
self.comp_sets[self._compiler_name(cpair.renderer, cpair.compiler)] = (cpair.ctrl_var, (cpair.renderer, cpair.compiler))
|
||||
|
||||
@property
|
||||
def renderer(self) -> Renderer: return self._select_compiler_pair()[0]
|
||||
|
||||
@property
|
||||
def compiler(self) -> Compiler: return self._select_compiler_pair()[1]
|
||||
def compiler(self) -> Compiler:
|
||||
if (ret:=self.renderer.compiler or self._select_compiler_pair()[1]) is None: raise RuntimeError(f"no compiler for {self.device}")
|
||||
return ret
|
||||
|
||||
def _compiler_name(self, c:type[Compiler]|functools.partial) -> str:
|
||||
return unwrap_class_type(c).__name__.upper().removesuffix("COMPILER").removeprefix(devname:=self.device.split(':')[0].upper()) or devname
|
||||
def _compiler_name(self, r:type[Renderer]|functools.partial, c:type[Compiler]|functools.partial|None) -> str:
|
||||
devname = self.device.split(':')[0].upper()
|
||||
if c is None: return unwrap_class_type(r).__name__.upper().removesuffix("RENDERER").removeprefix(devname) or devname
|
||||
return unwrap_class_type(c).__name__.upper().removesuffix("COMPILER").removeprefix(devname) or devname
|
||||
|
||||
def _select_compiler_pair(self) -> tuple[Renderer, Compiler]:
|
||||
def _select_compiler_pair(self) -> tuple[Renderer, Compiler|None]:
|
||||
# select forced compiler from global env var.
|
||||
forced_comps = set([self.comp_sets[val][1]] if self.comps_ctrl_var is not None and (val:=self.comps_ctrl_var.value) else [])
|
||||
|
||||
@@ -398,7 +402,7 @@ def enumerate_devices_str() -> Generator[str, None, None]:
|
||||
# d.renderer, d.compiler = r(), c()
|
||||
with Context(CACHELEVEL=0): test = (Tensor([1,2,3], device=device) * 2).tolist()
|
||||
if test != [2,4,6]: raise ValueError(f"got {test} instead of [2, 4, 6]")
|
||||
set_text = f'({cc_ctrl_var.key}={d._compiler_name(c)} to make default)' if cc_ctrl_var is not None else ''
|
||||
set_text = f'({cc_ctrl_var.key}={d._compiler_name(r, c)} to make default)' if cc_ctrl_var is not None else ''
|
||||
default_text = '(default)' if type(default_compiler) is type(d.compiler) else set_text
|
||||
compilers_results.append(f"{colored('+', 'green')} {unwrap_class_type(c).__name__} {default_text}")
|
||||
any_works = True
|
||||
|
||||
@@ -36,19 +36,19 @@ def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffe
|
||||
return ret[1]
|
||||
|
||||
class CompiledRunner(Runner):
|
||||
def __init__(self, p:ProgramSpec, precompiled:bytes|None=None, prg=None):
|
||||
def __init__(self, p:ProgramSpec, prg=None):
|
||||
if DEBUG >= 3: print(p.applied_opts)
|
||||
if DEBUG >= 4: print(p.src)
|
||||
self.p:ProgramSpec = p
|
||||
if precompiled is not None: self.lib = precompiled
|
||||
else:
|
||||
if p.lib is None:
|
||||
with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"):
|
||||
self.lib = Device[p.device].compiler.compile_cached(p.src)
|
||||
if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib)
|
||||
self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg
|
||||
p = replace(p, lib=Device[p.device].compiler.compile_cached(p.src))
|
||||
self.p:ProgramSpec = p
|
||||
assert self.p.lib is not None
|
||||
if DEBUG >= 7: Device[p.device].compiler.disassemble(self.p.lib)
|
||||
self._prg = Device[p.device].runtime(p.function_name, self.p.lib) if prg is None else prg
|
||||
super().__init__(p.name, p.device, p.estimates)
|
||||
|
||||
def __reduce__(self): return self.__class__, (self.p, self.lib)
|
||||
def __reduce__(self): return self.__class__, (self.p,)
|
||||
|
||||
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False) -> float|None:
|
||||
if var_vals is None: var_vals = {}
|
||||
@@ -115,7 +115,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
|
||||
if cret:=method_cache.get(ckey): return cret
|
||||
bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
|
||||
if bret:=method_cache.get(bkey):
|
||||
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
|
||||
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device))
|
||||
else:
|
||||
prg: ProgramSpec = get_program(ast, Device[device].renderer)
|
||||
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))
|
||||
|
||||
@@ -115,12 +115,12 @@ def suppress_finalizing(func):
|
||||
if not getattr(sys, 'is_finalizing', lambda: True)(): raise # re-raise if not finalizing
|
||||
return wrapper
|
||||
|
||||
def select_first_inited(candidates:Sequence[Callable[...,T]|Sequence[Callable[...,T]]], err_msg:str, cache:dict|None=None) -> tuple[T,...]|T:
|
||||
def select_first_inited(candidates:Sequence[Callable[...,T]|Sequence[Callable[...,T]|None]], err_msg:str, cache:dict|None=None):
|
||||
excs = []
|
||||
for typ in candidates:
|
||||
if cache is not None and typ in cache: return cache[typ]
|
||||
try:
|
||||
x = tuple([cast(Callable, t)() for t in typ]) if isinstance(typ, Sequence) else cast(Callable, typ)()
|
||||
x = tuple([cast(Callable, t)() if t is not None else None for t in typ]) if isinstance(typ, Sequence) else cast(Callable, typ)()
|
||||
if cache is not None: cache[typ] = x
|
||||
return x
|
||||
except Exception as e: excs.append(e)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from typing import Callable, cast
|
||||
from typing import Callable, cast, TYPE_CHECKING
|
||||
import functools
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.helpers import to_function_name, dedup, prod, DEBUG
|
||||
@@ -7,6 +7,7 @@ from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, Gro
|
||||
from tinygrad.dtype import AddrSpace, PtrDType
|
||||
from tinygrad.codegen.opt.tc import TensorCore
|
||||
from tinygrad.codegen.opt import Opt
|
||||
if TYPE_CHECKING: from tinygrad.device import Compiler
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Estimates:
|
||||
@@ -64,6 +65,7 @@ class ProgramSpec:
|
||||
device:str
|
||||
ast:UOp # save the base ast (this is method cache key)
|
||||
uops:list[UOp]|None=None
|
||||
lib:bytes|None=None
|
||||
|
||||
# filled in from uops (via from_uop)
|
||||
global_size:list[int]=field(default_factory=lambda: [1,1,1])
|
||||
@@ -95,8 +97,9 @@ class ProgramSpec:
|
||||
def from_uop(prg:UOp) -> ProgramSpec:
|
||||
"""Construct ProgramSpec from a PROGRAM UOp."""
|
||||
assert prg.op is Ops.PROGRAM, f"expected PROGRAM, got {prg.op}"
|
||||
# SINK/DEVICE/LINEAR/SOURCE
|
||||
sink, device, linear, source = prg.src
|
||||
# SINK/DEVICE/LINEAR/SOURCE/BINARY?
|
||||
sink, device, linear, source = prg.src[:4]
|
||||
lib = prg.src[4].arg if len(prg.src) > 4 else None
|
||||
uops = list(linear.src)
|
||||
if DEBUG >= 6: print_uops(uops) # LINEAR is src[2]
|
||||
|
||||
@@ -120,7 +123,7 @@ class ProgramSpec:
|
||||
# TODO: this cast is wrong, u.src[0].ssimplify() can be sint
|
||||
if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
|
||||
|
||||
return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, uops, global_size, local_size,
|
||||
return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, uops, lib, global_size, local_size,
|
||||
sorted(_vars, key=lambda v: v.arg), sorted(dedup(_globals)), sorted(dedup(outs)), sorted(dedup(ins)))
|
||||
|
||||
class Renderer:
|
||||
@@ -139,6 +142,7 @@ class Renderer:
|
||||
pre_matcher: PatternMatcher|None = None
|
||||
extra_matcher: PatternMatcher|None = None
|
||||
code_for_op: dict[Ops, Callable] = {}
|
||||
compiler: Compiler|None = None
|
||||
|
||||
def __reduce__(self): return self.__class__, ()
|
||||
def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")
|
||||
|
||||
@@ -8,6 +8,7 @@ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, trunc
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.codegen.late.devectorizer import no_vectorized_alu
|
||||
|
||||
|
||||
base_rewrite = PatternMatcher([
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
|
||||
(UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
|
||||
@@ -278,6 +279,11 @@ class ClangRenderer(CStyleLanguage):
|
||||
defines = '\n'.join(self._render_defines(uops))
|
||||
return defines + "\n" + self._render_body(function_name, kernel, bufs, uops, prefix) + "\n" + self._render_entry(function_name, bufs)
|
||||
|
||||
class ClangJITRenderer(ClangRenderer):
|
||||
def __init__(self):
|
||||
from tinygrad.runtime.support.compiler_cpu import ClangJITCompiler
|
||||
self.compiler = ClangJITCompiler()
|
||||
|
||||
class OpenCLRenderer(CStyleLanguage):
|
||||
device = "CL"
|
||||
|
||||
|
||||
@@ -5,10 +5,10 @@ from tinygrad.helpers import CPU_CC, CPU_LVP, CPU_LLVM
|
||||
from tinygrad.device import BufferSpec, DMACPURef, CompilerSet, CompilerPair
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocatorBase, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface
|
||||
from tinygrad.runtime.support.hcq import CLikeArgsState
|
||||
from tinygrad.renderer.cstyle import ClangRenderer
|
||||
from tinygrad.renderer.cstyle import ClangJITRenderer
|
||||
from tinygrad.renderer.llvmir import LLVMRenderer
|
||||
from tinygrad.renderer.nir import LVPRenderer
|
||||
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler
|
||||
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler
|
||||
from tinygrad.runtime.support.compiler_mesa import LVPCompiler
|
||||
from tinygrad.runtime.support.elf import jit_loader
|
||||
from tinygrad.uop.ops import sint
|
||||
@@ -136,6 +136,6 @@ class CPUDevice(HCQCompiled):
|
||||
def __init__(self, device:str=""):
|
||||
self.tasks:queue.Queue = queue.Queue()
|
||||
CPUWorker(self, self.tasks, thread_id=0).start()
|
||||
compilers = CompilerSet([CompilerPair(ClangRenderer, ClangJITCompiler), CompilerPair(LLVMRenderer, CPULLVMCompiler, ctrl_var=CPU_LLVM),
|
||||
compilers = CompilerSet([CompilerPair(ClangJITRenderer, None), CompilerPair(LLVMRenderer, CPULLVMCompiler, ctrl_var=CPU_LLVM),
|
||||
CompilerPair(LVPRenderer, LVPCompiler, ctrl_var=CPU_LVP)], ctrl_var=CPU_CC)
|
||||
super().__init__(device, CPUAllocator(self), compilers, functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue)
|
||||
|
||||
@@ -28,8 +28,8 @@ class Ops(FastEnum):
|
||||
NOOP = auto(); REWRITE_ERROR = auto()
|
||||
|
||||
# renderer
|
||||
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable
|
||||
PROGRAM = auto(); LINEAR = auto(); SOURCE = auto()
|
||||
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has bytes arg that's compiled
|
||||
PROGRAM = auto(); LINEAR = auto(); SOURCE = auto(); BINARY = auto()
|
||||
|
||||
# AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:]
|
||||
# GROUP is a NOOP that just merges things together
|
||||
|
||||
@@ -219,7 +219,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | \
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE:
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY:
|
||||
return None
|
||||
|
||||
case Ops.INDEX:
|
||||
|
||||
@@ -249,13 +249,15 @@ full_spec = PatternMatcher([
|
||||
# in progress MSTACK may lose device
|
||||
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
|
||||
|
||||
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?)
|
||||
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
|
||||
# codegen: standalone LINEAR/SOURCE
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
|
||||
# codegen: standalone LINEAR/SOURCE/BINARY
|
||||
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
|
||||
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
|
||||
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
|
||||
|
||||
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
|
||||
(UPat(Ops.VECTORIZE), lambda: True),
|
||||
|
||||
Reference in New Issue
Block a user