fix some mypy cast [pr] (#13331)

This commit is contained in:
chenyu
2025-11-18 09:23:42 -05:00
committed by GitHub
parent 5623e765c8
commit 05294bc648

View File

@@ -89,8 +89,8 @@ class UOpMetaClass(type):
if SPEC > 1:
from tinygrad.uop.spec import full_spec, test_pyrender
if SPEC > 2: test_pyrender(created)
with Context(IGNORE_OOB=1): ret = full_spec.rewrite(created)
if cast(bool|None, ret) is not True: raise RuntimeError(f"SPEC ISSUE {ret}: {created}")
with Context(IGNORE_OOB=1): fret = cast(bool|None, full_spec.rewrite(created))
if fret is not True: raise RuntimeError(f"SPEC ISSUE {fret}: {created}")
return created
# some uops map to other stuff
@@ -583,7 +583,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def new_buffer(device:str|tuple[str, ...], size:int, dtype:DType, num=None):
return UOp(Ops.BUFFER, dtype, (UOp.unique(num), UOp(Ops.DEVICE, arg=device)), size)
@property
def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
def device(self) -> str|tuple[str, ...]: return unwrap(self._device)
@recursive_property
def _device(self) -> str|tuple[str, ...]|None:
if self.op is Ops.DEVICE: return self.arg
@@ -1164,12 +1164,12 @@ class RewriteContext:
def cached_pm_rewrite(self, x:UOp):
if (ret:=self.pm_cache.get(x,SENTINEL)) is not SENTINEL: return ret
ret = self.pm_cache[x] = cast(PatternMatcher, self.pm).rewrite(x, self.ctx)
ret = self.pm_cache[x] = unwrap(self.pm).rewrite(x, self.ctx)
return ret
def cached_bpm_rewrite(self, x:UOp):
if (ret:=self.bpm_cache.get(x,SENTINEL)) is not SENTINEL: return ret
ret = self.bpm_cache[x] = cast(PatternMatcher, self.bpm).rewrite(x, self.ctx)
ret = self.bpm_cache[x] = unwrap(self.bpm).rewrite(x, self.ctx)
return ret
def unified_rewrite(self, root:UOp) -> UOp: