From 86cd1e9e81416a0665679024de1ec281ecc1515a Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Fri, 19 Dec 2025 17:41:18 -0400
Subject: [PATCH] remove UPatAny for typing fix [pr] (#13766)

* remove UPatAny for typing fix [pr]

* fix dtype
---
 tinygrad/schedule/rangeify.py |  2 +-
 tinygrad/tensor.py            |  9 +++++----
 tinygrad/uop/ops.py           | 23 +++++++++++------------
 tinygrad/uop/upat.py          |  4 ++--
 4 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index a15de5f020..9811450ea7 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -553,7 +553,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   tsink = graph_rewrite(tsink, pm_mops+earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites")
 
   # convert movement ops to ranges
-  tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY)
+  tsink, rctx = run_rangeify(tsink, bool(DEBUG_RANGEIFY))
 
   tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_remove_bufferize, name="symbolic+reduce_collapse+debuf")
   tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers")
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index c397bbd08b..290a36e568 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -2,7 +2,8 @@
 from __future__ import annotations
 import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
 from contextlib import ContextDecorator
-from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Generic
+from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Generic, TYPE_CHECKING
+if TYPE_CHECKING: import numpy
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.dtype import _from_np_dtype, _to_np_dtype
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
@@ -40,7 +41,7 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
 
 # **** Tensor helper functions ****
 
-def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821
+def _fromnp(x: 'numpy.ndarray') -> UOp:
   ret = UOp.new_buffer("NPY", x.size, _from_np_dtype(x.dtype))
   # fake realize
   ret.buffer.allocate(x)
@@ -110,7 +111,7 @@ class Tensor(OpMixin):
   __slots__ = "uop", "requires_grad", "grad"
   training: ClassVar[bool] = False
 
-  def __init__(self, data:ConstType|bytes|list|tuple|UOp|'np.ndarray'|pathlib.Path|None, # type: ignore [name-defined] # noqa: F821
+  def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
                device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None, _force_unique:bool=False):
     if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}"  # keep it on the disk if device is None
     _dtype:DType|None = to_dtype(dtype) if dtype is not None else None
@@ -345,7 +346,7 @@ class Tensor(OpMixin):
     if self.dtype in (dtypes.half, dtypes.bfloat16, *dtypes.fp8s): return self.cast(dtypes.float32).tolist()
     return self.data().tolist()
 
-  def numpy(self) -> 'np.ndarray': # type: ignore [name-defined] # noqa: F821
+  def numpy(self) -> 'numpy.ndarray':
     """
     Returns the value of this tensor as a `numpy.ndarray`.
 
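The tensor.py hunks above replace the `'np.ndarray'` forward references, which needed `# type: ignore [name-defined]` and `# noqa: F821` suppressions, with `'numpy.ndarray'` strings that resolve against an import guarded by `typing.TYPE_CHECKING`. Type checkers then see the real numpy types without numpy becoming a module-level runtime dependency. A minimal sketch of the pattern, independent of tinygrad (the function `as_list` is hypothetical):

```python
from typing import TYPE_CHECKING

# Only type checkers (mypy, pyright) evaluate this import; at runtime the
# block is skipped, so the module imports fine even without numpy installed.
if TYPE_CHECKING:
  import numpy

def as_list(x: 'numpy.ndarray') -> list:
  # the string annotation is resolved by the checker through the guarded
  # import, so no `# type: ignore` / `# noqa: F821` suppressions are needed
  return x.tolist()
```

The rangeify.py hunk is the same kind of typing fix: `DEBUG_RANGEIFY` is presumably an integer-valued `ContextVar`-style env flag, and wrapping it in `bool(...)` makes the argument match the boolean parameter `run_rangeify` declares instead of relying on implicit truthiness.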
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index d17ead39dc..b939fc6fa8 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -913,15 +913,16 @@ def get_location() -> tuple[str, int]:
   return frm.f_code.co_filename, frm.f_lineno
 
 class UPat(OpMixin):
-  __slots__ = ("op", "dtype", "arg", "name", "src")
-  def __init__(self, op:Ops|tuple[Ops, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|None=None,
+  __slots__ = ("op", "dtype", "arg", "name", "src", "is_any")
+  def __init__(self, op:Ops|tuple[Ops, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|set[DType]|None=None,
                src:tuple[UPat, ...]|list[UPat]|UPat|None=None, arg:Any=None,
-               name:str|None=None, allow_any_len:bool=False, custom_early_reject:set[Ops]|None=None, location=None):
+               name:str|None=None, allow_any_len:bool=False, custom_early_reject:set[Ops]|None=None, location=None, is_any:bool=False):
     assert op is None or isinstance(op, (Ops, tuple, set)), "op must be Ops or tuple of Ops"
     self.op: tuple[Ops, ...]|None = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
-    self.dtype: tuple[DType, ...]|None = (dtype,) if isinstance(dtype, DType) else dtype
+    self.dtype: tuple[DType, ...]|None = (dtype,) if isinstance(dtype, DType) else (tuple(dtype) if isinstance(dtype, set) else dtype)
     self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
     self.src: Any = None
+    self.is_any = is_any
 
     assert self.name != "ctx", "UPat can't be named ctx"
     assert dtype is None or isinstance(dtype, DType) or all(isinstance(x, DType) for x in dtype), f"invalid dtype {dtype}"
@@ -946,7 +947,7 @@ class UPat(OpMixin):
   def named(self, name:str): return UPat(self.op, self.dtype, self._in_src, self.arg, name, not self.strict_length, self.custom_early_reject)
 
   @staticmethod
-  def any(*src): return UPatAny(src=src)
+  def any(*src): return UPat(src=src, is_any=True)
   def or_casted(self, name:str|None=None): return UPat.any(self if name is None else self.named(name), UPat(Ops.CAST, name=name, src=(self,)))
   def or_after(self, name:str|None=None):
     return UPat.any(self if name is None else self.named(name), UPat(Ops.AFTER, name=name, src=(self,), allow_any_len=True))
@@ -986,6 +987,9 @@ class UPat(OpMixin):
     return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
 
   def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
+    if self.is_any:
+      matches = [x.match(uop, store.copy()) for x in self.src[0]]
+      return flatten([x for x in matches if x is not None])
     if (self.op is not None and uop.op not in self.op) or \
        (self.name is not None and store.setdefault(self.name, uop) is not uop) or \
        (self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
@@ -1002,11 +1006,6 @@ class UPat(OpMixin):
         res.extend(stores)
     return res
 
-class UPatAny(UPat):
-  def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
-    matches = [x.match(uop, store.copy()) for x in self.src[0]]
-    return flatten([x for x in matches if x is not None])
-
 def deconstruct_function(fxn:Callable) -> tuple:
   new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names}
   for co in fxn.__code__.co_consts:
@@ -1050,7 +1049,7 @@ class PatternMatcher:
   @functools.cache  # pylint: disable=method-cache-max-size-none
   def __add__(self, more:PatternMatcher) -> PatternMatcher: return PatternMatcher(self.patterns+more.patterns)
 
-  def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
+  def rewrite(self, uop:UOp, ctx=None):
     if len(pats:=self.pdict.get(uop.op, [])):
       ler = {u.op for u in uop.src}
       for _,match,early_reject in pats:
@@ -1137,7 +1136,7 @@ def profile_matches(fxn:Callable):
   return wrap_profile_matches
 
 class TrackedPatternMatcher(PatternMatcher):
-  def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
+  def rewrite(self, uop:UOp, ctx=None):
     if len(pats:=self.pdict.get(uop.op, [])):
       ret = None
       ler = {u.op for u in uop.src}
diff --git a/tinygrad/uop/upat.py b/tinygrad/uop/upat.py
index 32255ec1fa..3babe9e1de 100644
--- a/tinygrad/uop/upat.py
+++ b/tinygrad/uop/upat.py
@@ -1,14 +1,14 @@
 from typing import Any, Callable
 import itertools, inspect, functools, types
 from tinygrad.helpers import partition, dedup, Context
-from tinygrad.uop.ops import UPat, UPatAny, UOp, Ops, PatternMatcher, graph_rewrite, deconstruct_function
+from tinygrad.uop.ops import UPat, UOp, Ops, PatternMatcher, graph_rewrite, deconstruct_function
 
 class UPatCompileError(Exception): pass
 
 # **** UPat compiled ****
 
 def _get_clause(self:UPat, base:UOp, depth=0) -> UOp:
-  if isinstance(self, UPatAny):
+  if self.is_any:
     assert len(self.src) == 1
     return UOp(Ops.AND, src=(UOp(Ops.OR, src=tuple(_get_clause(s, base, depth) for s in self.src[0])),))
   # build the and_clause for acceptance
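The ops.py and upat.py hunks carry the change the patch is named for: `UPatAny` was a subclass whose only purpose was to override `match`, which made `UPat.any` return a different static type than every other constructor and forced upat.py to import the subclass just for an `isinstance` check. Folding it into an `is_any` flag means every pattern is a plain `UPat`, and alternation becomes an ordinary branch at the top of `match`. A self-contained sketch of the refactor's shape (the `Pat` class here is illustrative, not tinygrad's real `UPat`):

```python
class Pat:
  def __init__(self, value: int | None = None, src: tuple['Pat', ...] = (), is_any: bool = False):
    self.value, self.src, self.is_any = value, src, is_any  # flag replaces the former subclass

  @staticmethod
  def any(*src: 'Pat') -> 'Pat':
    return Pat(src=src, is_any=True)  # was: return PatAny(src=src)

  def match(self, x: int) -> bool:
    if self.is_any:  # was: an overridden match() on the PatAny subclass
      return any(s.match(x) for s in self.src)
    return self.value is None or self.value == x

assert Pat.any(Pat(1), Pat(2)).match(2)      # alternation succeeds on either branch
assert not Pat.any(Pat(1), Pat(2)).match(3)  # and fails when no branch matches
```

The two `rewrite` hunks drop the explicit `-> UOp|None` annotations from `PatternMatcher.rewrite` and its `TrackedPatternMatcher` override, leaving the return type to inference as part of the same typing cleanup, and `_get_clause` in upat.py now tests the `is_any` attribute instead of depending on the removed class.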