mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix some mypy cast [pr] (#13331)
This commit is contained in:
@@ -89,8 +89,8 @@ class UOpMetaClass(type):
|
||||
if SPEC > 1:
|
||||
from tinygrad.uop.spec import full_spec, test_pyrender
|
||||
if SPEC > 2: test_pyrender(created)
|
||||
with Context(IGNORE_OOB=1): ret = full_spec.rewrite(created)
|
||||
if cast(bool|None, ret) is not True: raise RuntimeError(f"SPEC ISSUE {ret}: {created}")
|
||||
with Context(IGNORE_OOB=1): fret = cast(bool|None, full_spec.rewrite(created))
|
||||
if fret is not True: raise RuntimeError(f"SPEC ISSUE {fret}: {created}")
|
||||
return created
|
||||
|
||||
# some uops map to other stuff
|
||||
@@ -583,7 +583,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
def new_buffer(device:str|tuple[str, ...], size:int, dtype:DType, num=None):
|
||||
return UOp(Ops.BUFFER, dtype, (UOp.unique(num), UOp(Ops.DEVICE, arg=device)), size)
|
||||
@property
|
||||
def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
|
||||
def device(self) -> str|tuple[str, ...]: return unwrap(self._device)
|
||||
@recursive_property
|
||||
def _device(self) -> str|tuple[str, ...]|None:
|
||||
if self.op is Ops.DEVICE: return self.arg
|
||||
@@ -1164,12 +1164,12 @@ class RewriteContext:
|
||||
|
||||
def cached_pm_rewrite(self, x:UOp):
|
||||
if (ret:=self.pm_cache.get(x,SENTINEL)) is not SENTINEL: return ret
|
||||
ret = self.pm_cache[x] = cast(PatternMatcher, self.pm).rewrite(x, self.ctx)
|
||||
ret = self.pm_cache[x] = unwrap(self.pm).rewrite(x, self.ctx)
|
||||
return ret
|
||||
|
||||
def cached_bpm_rewrite(self, x:UOp):
|
||||
if (ret:=self.bpm_cache.get(x,SENTINEL)) is not SENTINEL: return ret
|
||||
ret = self.bpm_cache[x] = cast(PatternMatcher, self.bpm).rewrite(x, self.ctx)
|
||||
ret = self.bpm_cache[x] = unwrap(self.bpm).rewrite(x, self.ctx)
|
||||
return ret
|
||||
|
||||
def unified_rewrite(self, root:UOp) -> UOp:
|
||||
|
||||
Reference in New Issue
Block a user