remove UPatAny for typing fix [pr] (#13766)

* remove UPatAny for typing fix [pr]

* fix dtype
This commit is contained in:
George Hotz
2025-12-19 17:41:18 -04:00
committed by GitHub
parent 4702da41d5
commit 86cd1e9e81
4 changed files with 19 additions and 19 deletions

View File

@@ -553,7 +553,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
tsink = graph_rewrite(tsink, pm_mops+earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites") tsink = graph_rewrite(tsink, pm_mops+earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites")
# convert movement ops to ranges # convert movement ops to ranges
tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY) tsink, rctx = run_rangeify(tsink, bool(DEBUG_RANGEIFY))
tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_remove_bufferize, name="symbolic+reduce_collapse+debuf") tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_remove_bufferize, name="symbolic+reduce_collapse+debuf")
tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers") tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers")

View File

@@ -2,7 +2,8 @@
from __future__ import annotations from __future__ import annotations
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
from contextlib import ContextDecorator from contextlib import ContextDecorator
from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Generic from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Generic, TYPE_CHECKING
if TYPE_CHECKING: import numpy
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.dtype import _from_np_dtype, _to_np_dtype from tinygrad.dtype import _from_np_dtype, _to_np_dtype
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
@@ -40,7 +41,7 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
# **** Tensor helper functions **** # **** Tensor helper functions ****
def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821 def _fromnp(x: 'numpy.ndarray') -> UOp:
ret = UOp.new_buffer("NPY", x.size, _from_np_dtype(x.dtype)) ret = UOp.new_buffer("NPY", x.size, _from_np_dtype(x.dtype))
# fake realize # fake realize
ret.buffer.allocate(x) ret.buffer.allocate(x)
@@ -110,7 +111,7 @@ class Tensor(OpMixin):
__slots__ = "uop", "requires_grad", "grad" __slots__ = "uop", "requires_grad", "grad"
training: ClassVar[bool] = False training: ClassVar[bool] = False
def __init__(self, data:ConstType|bytes|list|tuple|UOp|'np.ndarray'|pathlib.Path|None, # type: ignore [name-defined] # noqa: F821 def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None, _force_unique:bool=False): device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None, _force_unique:bool=False):
if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
_dtype:DType|None = to_dtype(dtype) if dtype is not None else None _dtype:DType|None = to_dtype(dtype) if dtype is not None else None
@@ -345,7 +346,7 @@ class Tensor(OpMixin):
if self.dtype in (dtypes.half, dtypes.bfloat16, *dtypes.fp8s): return self.cast(dtypes.float32).tolist() if self.dtype in (dtypes.half, dtypes.bfloat16, *dtypes.fp8s): return self.cast(dtypes.float32).tolist()
return self.data().tolist() return self.data().tolist()
def numpy(self) -> 'np.ndarray': # type: ignore [name-defined] # noqa: F821 def numpy(self) -> 'numpy.ndarray':
""" """
Returns the value of this tensor as a `numpy.ndarray`. Returns the value of this tensor as a `numpy.ndarray`.

View File

@@ -913,15 +913,16 @@ def get_location() -> tuple[str, int]:
return frm.f_code.co_filename, frm.f_lineno return frm.f_code.co_filename, frm.f_lineno
class UPat(OpMixin): class UPat(OpMixin):
__slots__ = ("op", "dtype", "arg", "name", "src") __slots__ = ("op", "dtype", "arg", "name", "src", "is_any")
def __init__(self, op:Ops|tuple[Ops, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|None=None, def __init__(self, op:Ops|tuple[Ops, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|set[DType]|None=None,
src:tuple[UPat, ...]|list[UPat]|UPat|None=None, arg:Any=None, src:tuple[UPat, ...]|list[UPat]|UPat|None=None, arg:Any=None,
name:str|None=None, allow_any_len:bool=False, custom_early_reject:set[Ops]|None=None, location=None): name:str|None=None, allow_any_len:bool=False, custom_early_reject:set[Ops]|None=None, location=None, is_any:bool=False):
assert op is None or isinstance(op, (Ops, tuple, set)), "op must be Ops or tuple of Ops" assert op is None or isinstance(op, (Ops, tuple, set)), "op must be Ops or tuple of Ops"
self.op: tuple[Ops, ...]|None = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op) self.op: tuple[Ops, ...]|None = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
self.dtype: tuple[DType, ...]|None = (dtype,) if isinstance(dtype, DType) else dtype self.dtype: tuple[DType, ...]|None = (dtype,) if isinstance(dtype, DType) else (tuple(dtype) if isinstance(dtype, set) else dtype)
self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
self.src: Any = None self.src: Any = None
self.is_any = is_any
assert self.name != "ctx", "UPat can't be named ctx" assert self.name != "ctx", "UPat can't be named ctx"
assert dtype is None or isinstance(dtype, DType) or all(isinstance(x, DType) for x in dtype), f"invalid dtype {dtype}" assert dtype is None or isinstance(dtype, DType) or all(isinstance(x, DType) for x in dtype), f"invalid dtype {dtype}"
@@ -946,7 +947,7 @@ class UPat(OpMixin):
def named(self, name:str): return UPat(self.op, self.dtype, self._in_src, self.arg, name, not self.strict_length, self.custom_early_reject) def named(self, name:str): return UPat(self.op, self.dtype, self._in_src, self.arg, name, not self.strict_length, self.custom_early_reject)
@staticmethod @staticmethod
def any(*src): return UPatAny(src=src) def any(*src): return UPat(src=src, is_any=True)
def or_casted(self, name:str|None=None): return UPat.any(self if name is None else self.named(name), UPat(Ops.CAST, name=name, src=(self,))) def or_casted(self, name:str|None=None): return UPat.any(self if name is None else self.named(name), UPat(Ops.CAST, name=name, src=(self,)))
def or_after(self, name:str|None=None): def or_after(self, name:str|None=None):
return UPat.any(self if name is None else self.named(name), UPat(Ops.AFTER, name=name, src=(self,), allow_any_len=True)) return UPat.any(self if name is None else self.named(name), UPat(Ops.AFTER, name=name, src=(self,), allow_any_len=True))
@@ -986,6 +987,9 @@ class UPat(OpMixin):
return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc) return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]: def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
if self.is_any:
matches = [x.match(uop, store.copy()) for x in self.src[0]]
return flatten([x for x in matches if x is not None])
if (self.op is not None and uop.op not in self.op) or \ if (self.op is not None and uop.op not in self.op) or \
(self.name is not None and store.setdefault(self.name, uop) is not uop) or \ (self.name is not None and store.setdefault(self.name, uop) is not uop) or \
(self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \ (self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
@@ -1002,11 +1006,6 @@ class UPat(OpMixin):
res.extend(stores) res.extend(stores)
return res return res
class UPatAny(UPat):
def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
matches = [x.match(uop, store.copy()) for x in self.src[0]]
return flatten([x for x in matches if x is not None])
def deconstruct_function(fxn:Callable) -> tuple: def deconstruct_function(fxn:Callable) -> tuple:
new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names} new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names}
for co in fxn.__code__.co_consts: for co in fxn.__code__.co_consts:
@@ -1050,7 +1049,7 @@ class PatternMatcher:
@functools.cache # pylint: disable=method-cache-max-size-none @functools.cache # pylint: disable=method-cache-max-size-none
def __add__(self, more:PatternMatcher) -> PatternMatcher: return PatternMatcher(self.patterns+more.patterns) def __add__(self, more:PatternMatcher) -> PatternMatcher: return PatternMatcher(self.patterns+more.patterns)
def rewrite(self, uop:UOp, ctx=None) -> UOp|None: def rewrite(self, uop:UOp, ctx=None):
if len(pats:=self.pdict.get(uop.op, [])): if len(pats:=self.pdict.get(uop.op, [])):
ler = {u.op for u in uop.src} ler = {u.op for u in uop.src}
for _,match,early_reject in pats: for _,match,early_reject in pats:
@@ -1137,7 +1136,7 @@ def profile_matches(fxn:Callable):
return wrap_profile_matches return wrap_profile_matches
class TrackedPatternMatcher(PatternMatcher): class TrackedPatternMatcher(PatternMatcher):
def rewrite(self, uop:UOp, ctx=None) -> UOp|None: def rewrite(self, uop:UOp, ctx=None):
if len(pats:=self.pdict.get(uop.op, [])): if len(pats:=self.pdict.get(uop.op, [])):
ret = None ret = None
ler = {u.op for u in uop.src} ler = {u.op for u in uop.src}

View File

@@ -1,14 +1,14 @@
from typing import Any, Callable from typing import Any, Callable
import itertools, inspect, functools, types import itertools, inspect, functools, types
from tinygrad.helpers import partition, dedup, Context from tinygrad.helpers import partition, dedup, Context
from tinygrad.uop.ops import UPat, UPatAny, UOp, Ops, PatternMatcher, graph_rewrite, deconstruct_function from tinygrad.uop.ops import UPat, UOp, Ops, PatternMatcher, graph_rewrite, deconstruct_function
class UPatCompileError(Exception): pass class UPatCompileError(Exception): pass
# **** UPat compiled **** # **** UPat compiled ****
def _get_clause(self:UPat, base:UOp, depth=0) -> UOp: def _get_clause(self:UPat, base:UOp, depth=0) -> UOp:
if isinstance(self, UPatAny): if self.is_any:
assert len(self.src) == 1 assert len(self.src) == 1
return UOp(Ops.AND, src=(UOp(Ops.OR, src=tuple(_get_clause(s, base, depth) for s in self.src[0])),)) return UOp(Ops.AND, src=(UOp(Ops.OR, src=tuple(_get_clause(s, base, depth) for s in self.src[0])),))
# build the and_clause for acceptance # build the and_clause for acceptance