mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
mypy
This commit is contained in:
@@ -171,7 +171,7 @@ class Buffer:
|
||||
return self.allocator._as_dmaref(self._buf)
|
||||
def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
|
||||
# zero copy with as_buffer (disabled by default due to use after free)
|
||||
if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, '_as_buffer') and (self.options is None or self.options.image is None):
|
||||
if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, '_as_buffer') and self.options is None:
|
||||
return self.allocator._as_buffer(self._buf)
|
||||
assert not force_zero_copy, "force zero copy was passed, but copy is required"
|
||||
return self.copyout(memoryview(bytearray(self.nbytes)))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from typing import Callable, cast
|
||||
from typing import Any, Callable, cast
|
||||
import functools
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.helpers import to_function_name, dedup, prod
|
||||
@@ -64,7 +64,7 @@ class ProgramSpec:
|
||||
device:str
|
||||
ast:UOp # save the base ast (this is method cache key)
|
||||
uops:list[UOp]|None=None
|
||||
aux:dict|None=None
|
||||
aux:Any=None
|
||||
|
||||
# filled in from uops (if we have uops)
|
||||
global_size:list[int]|None=None
|
||||
@@ -132,4 +132,4 @@ class Renderer:
|
||||
code_for_op: dict[Ops, Callable] = {}
|
||||
|
||||
def __reduce__(self): return self.__class__, ()
|
||||
def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")
|
||||
def render(self, uops:list[UOp]) -> str|tuple[str,...]: raise NotImplementedError("needs a renderer")
|
||||
|
||||
@@ -203,8 +203,8 @@ class QCOMArgsState(HCQArgsState):
|
||||
|
||||
if prg.samp_cnt > 0: to_mv(self.buf.va_addr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers)
|
||||
for i, b in enumerate(bufs):
|
||||
if prg.buf_info[i].type in {BUFTYPE_TEX, BUFTYPE_IBO}:
|
||||
obj = prg.tex_infos[i].desc if prg.buf_info[i].type is BUFTYPE_TEX else prg.tex_infos[i].ibo
|
||||
if (ti:=prg.tex_infos[i]) is not None:
|
||||
obj = ti.desc if prg.buf_info[i].type is BUFTYPE_TEX else ti.ibo
|
||||
to_mv(self.buf.va_addr + prg.buf_info[i].offset, len(obj) * 4).cast('I')[:] = array.array('I', obj)
|
||||
self.bind_sints_to_buf(b.va_addr, buf=self.buf, fmt='Q', offset=self.buf_info[i].offset+(0 if self.buf_info[i].type is BUFTYPE_BUF else 16))
|
||||
|
||||
@@ -227,7 +227,7 @@ class IR3ArgsState(HCQArgsState):
|
||||
|
||||
class QCOMProgram(HCQProgram):
|
||||
def __init__(self, dev: QCOMDevice, name: str, lib: bytes, aux_render=None):
|
||||
self.tex_infos = []
|
||||
self.tex_infos:list[QCOMTextureInfo|None] = []
|
||||
for dtype in aux_render:
|
||||
if isinstance(dtype, ImageDType):
|
||||
imgw, imgh, itemsize_log = dtype.shape[1], dtype.shape[0], int(math.log2(dtype.itemsize))
|
||||
|
||||
Reference in New Issue
Block a user