This commit is contained in:
Christopher Milan
2025-12-10 03:11:26 +00:00
parent 89ed801aaf
commit 530eb6e682
3 changed files with 7 additions and 7 deletions

View File

@@ -171,7 +171,7 @@ class Buffer:
return self.allocator._as_dmaref(self._buf)
def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
# zero copy with as_buffer (disabled by default due to use after free)
if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, '_as_buffer') and (self.options is None or self.options.image is None):
if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, '_as_buffer') and self.options is None:
return self.allocator._as_buffer(self._buf)
assert not force_zero_copy, "force zero copy was passed, but copy is required"
return self.copyout(memoryview(bytearray(self.nbytes)))

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Callable, cast
from typing import Any, Callable, cast
import functools
from dataclasses import dataclass, field
from tinygrad.helpers import to_function_name, dedup, prod
@@ -64,7 +64,7 @@ class ProgramSpec:
device:str
ast:UOp # save the base ast (this is method cache key)
uops:list[UOp]|None=None
aux:dict|None=None
aux:Any=None
# filled in from uops (if we have uops)
global_size:list[int]|None=None
@@ -132,4 +132,4 @@ class Renderer:
code_for_op: dict[Ops, Callable] = {}
def __reduce__(self): return self.__class__, ()
def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")
def render(self, uops:list[UOp]) -> str|tuple[str,...]: raise NotImplementedError("needs a renderer")

View File

@@ -203,8 +203,8 @@ class QCOMArgsState(HCQArgsState):
if prg.samp_cnt > 0: to_mv(self.buf.va_addr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers)
for i, b in enumerate(bufs):
if prg.buf_info[i].type in {BUFTYPE_TEX, BUFTYPE_IBO}:
obj = prg.tex_infos[i].desc if prg.buf_info[i].type is BUFTYPE_TEX else prg.tex_infos[i].ibo
if (ti:=prg.tex_infos[i]) is not None:
obj = ti.desc if prg.buf_info[i].type is BUFTYPE_TEX else ti.ibo
to_mv(self.buf.va_addr + prg.buf_info[i].offset, len(obj) * 4).cast('I')[:] = array.array('I', obj)
self.bind_sints_to_buf(b.va_addr, buf=self.buf, fmt='Q', offset=self.buf_info[i].offset+(0 if self.buf_info[i].type is BUFTYPE_BUF else 16))
@@ -227,7 +227,7 @@ class IR3ArgsState(HCQArgsState):
class QCOMProgram(HCQProgram):
def __init__(self, dev: QCOMDevice, name: str, lib: bytes, aux_render=None):
self.tex_infos = []
self.tex_infos:list[QCOMTextureInfo|None] = []
for dtype in aux_render:
if isinstance(dtype, ImageDType):
imgw, imgh, itemsize_log = dtype.shape[1], dtype.shape[0], int(math.log2(dtype.itemsize))