mypy==1.13.0 (#7990)

* explicit instantiation and narrowing asserts

* explicit cast

* bump

* one line assert

* handle case for no copy_queue_t

* Revert "handle case for no copy_queue_t"

This reverts commit 38347806ca.

* more readable control flow

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
JaSpa99
2024-12-06 05:09:14 +01:00
committed by GitHub
parent 65b6696f3b
commit 3c5d5f9414
4 changed files with 11 additions and 5 deletions

View File

@@ -29,7 +29,7 @@ setup(name='tinygrad',
'triton': ["triton-nightly>=2.1.0.dev20231014192330"],
'linting': [
"pylint",
"mypy==1.11.2",
"mypy==1.13.0",
"typing-extensions",
"pre-commit",
"ruff",

View File

@@ -52,7 +52,7 @@ class PtrDType(DType):
def vec(self, sz:int) -> DType:
assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
if sz == 1: return self # sz=1 is a scalar
return type(self)(*tuple(sz if f.name == 'v' else (self if f.name == '_scalar' else getattr(self, f.name)) for f in fields(self)))
return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz)
def ptr(self, local=False): raise RuntimeError("can't make a pointer from a pointer")
@property
def vcount(self): return self.v

View File

@@ -53,8 +53,14 @@ class HCQGraph(MultiGraphRunner):
for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)
for j,ji in enumerate(jit_cache):
enqueue_dev = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
enqueue_queue = self.comp_queues[enqueue_dev] if is_exec_prg else self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
enqueue_dev: HCQCompiled = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
if is_exec_prg:
enqueue_queue = self.comp_queues[enqueue_dev]
else:
assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
enqueue_queue = self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
# Get dependencies based on input and output buffers.

View File

@@ -282,7 +282,7 @@ class Tensor(SimpleMathTrait):
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.base.fmt != "e"
return self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape)
return cast(memoryview, self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape))
def item(self) -> ConstType:
"""