mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
mypy==1.13.0 (#7990)
* explicit instantiation and narrowing asserts
* explicit cast
* bump
* one line assert
* handle case for no copy_queue_t
* Revert "handle case for no copy_queue_t"
This reverts commit 38347806ca.
* more readable control flow
---------
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
2
setup.py
2
setup.py
@@ -29,7 +29,7 @@ setup(name='tinygrad',
|
||||
'triton': ["triton-nightly>=2.1.0.dev20231014192330"],
|
||||
'linting': [
|
||||
"pylint",
|
||||
"mypy==1.11.2",
|
||||
"mypy==1.13.0",
|
||||
"typing-extensions",
|
||||
"pre-commit",
|
||||
"ruff",
|
||||
|
||||
@@ -52,7 +52,7 @@ class PtrDType(DType):
|
||||
def vec(self, sz:int) -> DType:
|
||||
assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
|
||||
if sz == 1: return self # sz=1 is a scalar
|
||||
return type(self)(*tuple(sz if f.name == 'v' else (self if f.name == '_scalar' else getattr(self, f.name)) for f in fields(self)))
|
||||
return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz)
|
||||
def ptr(self, local=False): raise RuntimeError("can't make a pointer from a pointer")
|
||||
@property
|
||||
def vcount(self): return self.v
|
||||
|
||||
@@ -53,8 +53,14 @@ class HCQGraph(MultiGraphRunner):
|
||||
for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)
|
||||
|
||||
for j,ji in enumerate(jit_cache):
|
||||
enqueue_dev = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
|
||||
enqueue_queue = self.comp_queues[enqueue_dev] if is_exec_prg else self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
|
||||
enqueue_dev: HCQCompiled = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
|
||||
|
||||
if is_exec_prg:
|
||||
enqueue_queue = self.comp_queues[enqueue_dev]
|
||||
else:
|
||||
assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
|
||||
enqueue_queue = self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
|
||||
|
||||
out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
|
||||
|
||||
# Get dependencies based on input and output buffers.
|
||||
|
||||
@@ -282,7 +282,7 @@ class Tensor(SimpleMathTrait):
|
||||
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
|
||||
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
|
||||
if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.base.fmt != "e"
|
||||
return self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape)
|
||||
return cast(memoryview, self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape))
|
||||
|
||||
def item(self) -> ConstType:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user