Christopher Milan
2025-12-18 19:47:37 +00:00
parent 51106f9cf0
commit 44a1d0a1af
7 changed files with 19 additions and 16 deletions

View File

@@ -27,9 +27,10 @@ def to_uops_list(u:list[UOp], ren=None) -> list[UOp]:
 def _uops_to_prg(uops_list):
   uops = full_rewrite(ast:=UOp.sink(*uops_list), ren=Device[Device.DEFAULT].renderer)
   src = Device[Device.DEFAULT].renderer.render(uops)
+  aux = Device[Device.DEFAULT].renderer.aux(uops) if Device[Device.DEFAULT].renderer.has_aux else {}
   has_local = Device[Device.DEFAULT].renderer.has_local
   return CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, Device.DEFAULT, ast, uops=uops,
-                                    global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None))
+                                    global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None, aux=aux))

 def uop(uops:list[UOp], uop:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp:
   uops.append(UOp(uop, dtype, tuple(src), arg))
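The change to the test helper shows the whole new hand-off in miniature: the renderer is asked for an auxiliary dict (guarded by has_aux) and the dict rides along in ProgramSpec as aux. A minimal standalone sketch of that pattern, using ToyRenderer/ToyRuntime stand-ins rather than the real tinygrad classes:

# Sketch only: toy stand-ins for the renderer -> runtime hand-off, not tinygrad's real classes.
class ToyRenderer:
  has_aux = True
  def render(self, uops): return "__kernel void test() {}"        # kernel source string
  def aux(self, uops): return {"buf_dtypes": ["image", "buffer"]} # extra per-program keyword args

class ToyRuntime:
  # whatever keys aux() returned arrive here as keyword arguments
  def __init__(self, name, lib, buf_dtypes=()): self.name, self.lib, self.buf_dtypes = name, lib, buf_dtypes

ren, uops = ToyRenderer(), []
src = ren.render(uops)
aux = ren.aux(uops) if ren.has_aux else {}   # same guard as in _uops_to_prg above
prg = ToyRuntime("test", src.encode(), **aux)
assert prg.buf_dtypes == ["image", "buffer"]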

View File

@@ -6,7 +6,7 @@ import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re
 from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored
 from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup, ContextVar
 from tinygrad.helpers import unwrap_class_type, suppress_finalizing, select_first_inited, VIZ, CPU_LLVM, CPU_LVP, NV_PTX, CUDA_PTX, NV_NAK
-from tinygrad.dtype import DType, dtypes, _to_np_dtype
+from tinygrad.dtype import ImageDType, PtrDType, DType, dtypes, _to_np_dtype
 from tinygrad.renderer import Renderer

 # **************** Device ****************
@@ -93,7 +93,7 @@ class Buffer:
   profile_events:list[ProfileEvent] = []
   def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:BufferSpec|None=None, initial_value:bytes|None=None,
                uop_refcount=0, base:Buffer|None=None, offset:int=0, preallocate=False):
-    # assert isinstance(dtype, DType) and not isinstance(dtype, PtrDType)
+    assert isinstance(dtype, DType) and (isinstance(dtype, ImageDType) or not isinstance(dtype, PtrDType))
     self.device, self.size, self.dtype, self.options, self.offset, self.allocated_views = device, size, dtype, options, offset, 0
     if base is None:
       assert offset == 0, "base buffers can't have offset"
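Buffer now enforces the dtype check that used to be commented out, loosened so image dtypes get through while plain pointer dtypes stay rejected (hence the new ImageDType/PtrDType imports above). A small sketch of the predicate with stand-in classes, assuming, as in tinygrad's dtype module, that ImageDType is a subclass of PtrDType:

# Sketch only: stand-in classes mirroring the assumed DType -> PtrDType -> ImageDType hierarchy.
class DType: pass
class PtrDType(DType): pass
class ImageDType(PtrDType): pass

def buffer_dtype_ok(dtype) -> bool:
  # the same predicate as the new assert in Buffer.__init__
  return isinstance(dtype, DType) and (isinstance(dtype, ImageDType) or not isinstance(dtype, PtrDType))

assert buffer_dtype_ok(DType())          # ordinary dtypes pass
assert buffer_dtype_ok(ImageDType())     # image dtypes are now explicitly allowed
assert not buffer_dtype_ok(PtrDType())   # raw pointer dtypes still fail the assert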

View File

@@ -44,11 +44,11 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program
   # print and render
   if DEBUG >= 6: print_uops(uops)
-  src, aux = r if isinstance(r:=renderer.render(uops), tuple) else (r, None)
+  src = renderer.render(uops)
   return ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, renderer.device, ast, uops,
                      global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None,
-                     local_size=[1,1,1] if renderer.has_local else None, aux=aux)
+                     local_size=[1,1,1] if renderer.has_local else None, aux=renderer.aux(uops) if renderer.has_aux else {})

 # **************** Runners ****************
@@ -86,7 +86,7 @@ class CompiledRunner(Runner):
       with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"):
         self.lib = Device[p.device].compiler.compile_cached(p.src)
     if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib)
-    self._prg = Device[p.device].runtime(p.function_name, self.lib, **({"aux_render": p.aux} if p.aux is not None else {})) if prg is None else prg
+    self._prg = Device[p.device].runtime(p.function_name, self.lib, **p.aux) if prg is None else prg
     super().__init__(p.name, p.device, p.estimates)

   def __reduce__(self): return self.__class__, (self.p, self.lib)
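Together these two hunks replace the ad-hoc aux plumbing: get_program no longer sniffs whether render() returned a (src, aux) tuple, and CompiledRunner no longer wraps the payload in a one-off {"aux_render": ...} kwarg. Because aux is now always a dict (empty by default), **p.aux can be splatted unconditionally. A before/after sketch of just that call site, with runtime_cls, p and lib as hypothetical stand-ins for Device[p.device].runtime and the real spec:

from types import SimpleNamespace

def make_prg_old(runtime_cls, p, lib):
  # old behaviour: aux could be None or an arbitrary payload, so it needed a special-cased kwarg
  return runtime_cls(p.function_name, lib, **({"aux_render": p.aux} if p.aux is not None else {}))

def make_prg_new(runtime_cls, p, lib):
  # new behaviour: aux is always a dict of keyword arguments, so it is splatted directly
  return runtime_cls(p.function_name, lib, **p.aux)

p = SimpleNamespace(function_name="test", aux={"buf_dtypes": []})
print(make_prg_new(lambda name, lib, buf_dtypes=(): (name, lib, buf_dtypes), p, b""))  # ('test', b'', [])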

View File

@@ -64,7 +64,7 @@ class ProgramSpec:
   device:str
   ast:UOp # save the base ast (this is method cache key)
   uops:list[UOp]|None=None
-  aux:Any=None
+  aux:dict[str,...]=field(default_factory=dict)

   # filled in from uops (if we have uops)
   global_size:list[int]|None=None
@@ -122,6 +122,7 @@ class Renderer:
   has_local: bool = True
   has_threads: bool = False
   has_shared: bool = True
+  has_aux: bool = False # additional program info, eg. image shapes
   # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
   global_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
   local_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
@@ -133,3 +134,4 @@ class Renderer:
   def __reduce__(self): return self.__class__, ()
   def render(self, uops:list[UOp]) -> str|tuple[str,...]: raise NotImplementedError("needs a renderer")
+  def aux(self, uops:list[UOp]) -> dict[str,...]: raise NotImplementedError("needs aux")
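The interface change lives here: ProgramSpec.aux goes from Any=None to a dict default, and Renderer grows an opt-in has_aux flag plus an aux() hook instead of overloading render()'s return type. The dict default is what makes the unconditional **p.aux above safe; since dataclasses reject mutable defaults, it has to be spelled field(default_factory=dict). A small illustration of that rule on a toy dataclass (not ProgramSpec itself):

from dataclasses import dataclass, field
from typing import Any

@dataclass
class ToySpec:
  name: str
  # aux: dict[str, Any] = {}  would raise "mutable default ... is not allowed"
  aux: dict[str, Any] = field(default_factory=dict)  # every instance gets its own fresh dict

a, b = ToySpec("a"), ToySpec("b")
a.aux["buf_dtypes"] = []
assert b.aux == {}  # defaults are not shared between instances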

View File

@@ -280,6 +280,7 @@ class ClangRenderer(CStyleLanguage):

 class OpenCLRenderer(CStyleLanguage):
   device = "CL"
+  has_aux = True

   # language options
   kernel_typedef = "__kernel void"
@@ -308,7 +309,7 @@ class OpenCLRenderer(CStyleLanguage):
     if any(uop.dtype.base == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)

-  def render(self, uops:list[UOp]): return super().render(uops), [u.dtype for u in uops if u.op == Ops.DEFINE_GLOBAL]
+  def aux(self, uops:list[UOp]): return {"buf_dtypes": [u.dtype for u in uops if u.op == Ops.DEFINE_GLOBAL]}

 class IntelRenderer(OpenCLRenderer):
   device, suffix, kernel_typedef = "CL", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel void"
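OpenCLRenderer is the first renderer to opt in: instead of returning (src, dtypes) from render(), it publishes the per-buffer dtypes under the "buf_dtypes" key, which matches the keyword argument CLProgram now takes (next file). The list is built from DEFINE_GLOBAL uops, i.e. one dtype per kernel buffer argument, in argument order. A sketch of that filtering with toy uop records standing in for real UOps:

from types import SimpleNamespace

uops = [
  SimpleNamespace(op="DEFINE_GLOBAL", dtype="imagef((64,128,4))"),  # buffer 0: an image
  SimpleNamespace(op="DEFINE_GLOBAL", dtype="ptr(float)"),          # buffer 1: a plain buffer
  SimpleNamespace(op="ADD", dtype="float"),                         # not a buffer definition, skipped
]

def aux(uops):
  # same shape as OpenCLRenderer.aux: one dtype per kernel buffer argument, in order
  return {"buf_dtypes": [u.dtype for u in uops if u.op == "DEFINE_GLOBAL"]}

assert aux(uops) == {"buf_dtypes": ["imagef((64,128,4))", "ptr(float)"]}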

View File

@@ -34,8 +34,8 @@ class CLCompiler(Compiler):
     return bytes(binary)

 class CLProgram:
-  def __init__(self, device:CLDevice, name:str, lib:bytes, aux_render=None):
-    self.dev, self.name, self.lib, self.buf_dtypes = device, name, lib, aux_render
+  def __init__(self, device:CLDevice, name:str, lib:bytes, buf_dtypes=[]):
+    self.dev, self.name, self.lib, self.buf_dtypes = device, name, lib, buf_dtypes
     self.program = checked(cl.clCreateProgramWithBinary(device.context, 1, device.device_id, (ctypes.c_size_t * 1)(len(lib)),
                                                         to_char_p_p([lib], ctypes.c_ubyte), binary_status := ctypes.c_int32(),
                                                         errcode_ret := ctypes.c_int32()), errcode_ret)
@@ -53,10 +53,9 @@ class CLProgram:
                vals:tuple[int, ...]=(), wait=False) -> float|None:
     for i,(b,_) in enumerate(bufs):
       if isinstance(dt:=self.buf_dtypes[i], ImageDType):
-        try: b = checked(
-          cl.clCreateImage(self.dev.context, cl.CL_MEM_READ_WRITE, cl.cl_image_format(cl.CL_RGBA, {2:cl.CL_HALF_FLOAT, 4:cl.CL_FLOAT}[dt.itemsize]),
-            cl.cl_image_desc(cl.CL_MEM_OBJECT_IMAGE2D, dt.shape[1], dt.shape[0], buffer=b), None, status:=ctypes.c_int32()), status)
-        except RuntimeError as e: raise ValueError(f"{i=} {dt=}") from e
+        fmt = cl.cl_image_format(cl.CL_RGBA, {2:cl.CL_HALF_FLOAT, 4:cl.CL_FLOAT}[dt.itemsize])
+        desc = cl.cl_image_desc(cl.CL_MEM_OBJECT_IMAGE2D, dt.shape[1], dt.shape[0], buffer=b)
+        b = checked(cl.clCreateImage(self.dev.context, cl.CL_MEM_READ_WRITE, fmt, desc, None, status:=ctypes.c_int32()), status)
       check(cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b)))
     for i,v in enumerate(vals,start=len(bufs)): check(cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(v))))
     if local_size is not None: global_size = cast(tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size)))
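On the runtime side the kwarg is renamed from aux_render to buf_dtypes so that **p.aux binds it directly, and the clCreateImage call is untangled into named fmt/desc temporaries, dropping the try/except debugging wrapper. What those two arguments encode, sketched as plain Python on a hypothetical image dtype record (the real call goes through the ctypes cl bindings):

from types import SimpleNamespace

def describe_image(dt):
  # mirrors the fmt/desc pair above: RGBA channels, channel type from itemsize, 2D extent from shape
  return {"channel_order": "CL_RGBA",
          "channel_type": {2: "CL_HALF_FLOAT", 4: "CL_FLOAT"}[dt.itemsize],
          "width": dt.shape[1], "height": dt.shape[0]}

dt = SimpleNamespace(itemsize=2, shape=(64, 128))  # hypothetical half-float image, shape = (height, width)
assert describe_image(dt) == {"channel_order": "CL_RGBA", "channel_type": "CL_HALF_FLOAT", "width": 128, "height": 64}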

View File

@@ -226,9 +226,9 @@ class IR3ArgsState(HCQArgsState):
     self.bind_sints_to_buf(*flatten([b.texture_info.ibo + ([0] * 8) for b in ibos]), buf=self.buf, fmt='I', offset=prg.ibo_off)

 class QCOMProgram(HCQProgram):
-  def __init__(self, dev: QCOMDevice, name: str, lib: bytes, aux_render=None):
+  def __init__(self, dev: QCOMDevice, name: str, lib: bytes, buf_dtypes=[]):
     self.tex_infos:list[QCOMTextureInfo|None] = []
-    for dtype in aux_render:
+    for dtype in buf_dtypes:
       if isinstance(dtype, ImageDType):
         imgw, imgh = dtype.shape[1], dtype.shape[0]
         stride = imgw * 4 * dtype.itemsize
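QCOMProgram consumes the same buf_dtypes list: for each ImageDType it derives the texture geometry, with stride = imgw * 4 * dtype.itemsize giving bytes per row, presumably four components per pixel to match the CL_RGBA layout used on the OpenCL side. A quick worked check of that arithmetic on a hypothetical image dtype:

from types import SimpleNamespace

dt = SimpleNamespace(shape=(64, 128), itemsize=2)  # hypothetical: 64 rows, 128 pixels per row, half floats
imgw, imgh = dt.shape[1], dt.shape[0]
stride = imgw * 4 * dt.itemsize  # 4 components per pixel, itemsize bytes per component
assert (imgw, imgh, stride) == (128, 64, 1024)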