diff --git a/tinygrad/device.py b/tinygrad/device.py index 40ab4ad7ed..cd373bb859 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -171,7 +171,7 @@ class Buffer: return self.allocator._as_dmaref(self._buf) def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview: # zero copy with as_buffer (disabled by default due to use after free) - if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, '_as_buffer') and (self.options is None or self.options.image is None): + if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, '_as_buffer') and self.options is None: return self.allocator._as_buffer(self._buf) assert not force_zero_copy, "force zero copy was passed, but copy is required" return self.copyout(memoryview(bytearray(self.nbytes))) diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index bf9efb736b..ce4b655a64 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Callable, cast +from typing import Any, Callable, cast import functools from dataclasses import dataclass, field from tinygrad.helpers import to_function_name, dedup, prod @@ -64,7 +64,7 @@ class ProgramSpec: device:str ast:UOp # save the base ast (this is method cache key) uops:list[UOp]|None=None - aux:dict|None=None + aux:Any=None # filled in from uops (if we have uops) global_size:list[int]|None=None @@ -132,4 +132,4 @@ class Renderer: code_for_op: dict[Ops, Callable] = {} def __reduce__(self): return self.__class__, () - def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer") + def render(self, uops:list[UOp]) -> str|tuple[str,...]: raise NotImplementedError("needs a renderer") diff --git a/tinygrad/runtime/ops_qcom.py b/tinygrad/runtime/ops_qcom.py index 2f3bc89dfd..20022194d1 100644 --- a/tinygrad/runtime/ops_qcom.py +++ b/tinygrad/runtime/ops_qcom.py @@ -203,8 +203,8 @@ class QCOMArgsState(HCQArgsState): if prg.samp_cnt > 0: to_mv(self.buf.va_addr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers) for i, b in enumerate(bufs): - if prg.buf_info[i].type in {BUFTYPE_TEX, BUFTYPE_IBO}: - obj = prg.tex_infos[i].desc if prg.buf_info[i].type is BUFTYPE_TEX else prg.tex_infos[i].ibo + if (ti:=prg.tex_infos[i]) is not None: + obj = ti.desc if prg.buf_info[i].type is BUFTYPE_TEX else ti.ibo to_mv(self.buf.va_addr + prg.buf_info[i].offset, len(obj) * 4).cast('I')[:] = array.array('I', obj) self.bind_sints_to_buf(b.va_addr, buf=self.buf, fmt='Q', offset=self.buf_info[i].offset+(0 if self.buf_info[i].type is BUFTYPE_BUF else 16)) @@ -227,7 +227,7 @@ class IR3ArgsState(HCQArgsState): class QCOMProgram(HCQProgram): def __init__(self, dev: QCOMDevice, name: str, lib: bytes, aux_render=None): - self.tex_infos = [] + self.tex_infos:list[QCOMTextureInfo|None] = [] for dtype in aux_render: if isinstance(dtype, ImageDType): imgw, imgh, itemsize_log = dtype.shape[1], dtype.shape[0], int(math.log2(dtype.itemsize))