Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-07 22:23:55 -05:00
remove UPatAny for typing fix [pr] (#13766)
* remove UPatAny for typing fix [pr]
* fix dtype
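At a high level, the commit folds the UPatAny subclass into UPat itself: instead of a subclass that overrides match, UPat now carries an is_any flag and branches on it, so UPat.any returns a plain UPat and type checkers see a single class everywhere (presumably the "typing fix" in the title). A simplified sketch of that refactor pattern, using generic Python rather than the real tinygrad classes:

class Pat:
  def __init__(self, src=(), is_any=False): self.src, self.is_any = src, is_any
  @staticmethod
  def any(*src): return Pat(src=src, is_any=True)  # before: a PatAny subclass instance was returned here
  def match(self, value) -> bool:
    if self.is_any: return any(p.match(value) for p in self.src)  # former subclass override, now a branch
    return True  # placeholder for the real per-node matching logic

# usage: an "any" pattern matches if either alternative matches
p = Pat.any(Pat(), Pat())
assert p.match(object()) and p.is_any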
@@ -553,7 +553,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   tsink = graph_rewrite(tsink, pm_mops+earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites")
 
   # convert movement ops to ranges
-  tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY)
+  tsink, rctx = run_rangeify(tsink, bool(DEBUG_RANGEIFY))
 
   tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_remove_bufferize, name="symbolic+reduce_collapse+debuf")
   tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers")
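The only change in this hunk is wrapping DEBUG_RANGEIFY in bool(). Assuming DEBUG_RANGEIFY is an integer-valued debug flag while run_rangeify declares a bool parameter, the explicit cast keeps the call site consistent with the annotation. A tiny illustration of the pattern, with hypothetical names and signatures:

import os

DEBUG_RANGEIFY = int(os.getenv("DEBUG_RANGEIFY", "0"))  # hypothetical: an int-valued env flag

def run_rangeify(sink, debug: bool):  # hypothetical signature with a bool parameter
  if debug: print("rangeify debug on")
  return sink

run_rangeify(None, bool(DEBUG_RANGEIFY))  # cast so the argument matches the bool annotation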
@@ -2,7 +2,8 @@
 from __future__ import annotations
 import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
 from contextlib import ContextDecorator
-from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Generic
+from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Generic, TYPE_CHECKING
+if TYPE_CHECKING: import numpy
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.dtype import _from_np_dtype, _to_np_dtype
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
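The TYPE_CHECKING guard makes numpy visible to static type checkers without importing it at runtime, so annotations below can reference numpy.ndarray even when numpy is not installed. A small standalone example of the same pattern (hypothetical module, not from this diff):

from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING: import numpy  # only evaluated by type checkers, never at runtime

def as_list(x: 'numpy.ndarray') -> list:
  # the string annotation is a forward reference the checker resolves via the guarded import
  return x.tolist()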
@@ -40,7 +41,7 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
 
 # **** Tensor helper functions ****
 
-def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821
+def _fromnp(x: 'numpy.ndarray') -> UOp:
   ret = UOp.new_buffer("NPY", x.size, _from_np_dtype(x.dtype))
   # fake realize
   ret.buffer.allocate(x)
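With the guarded import in place, 'numpy.ndarray' resolves for checkers and linters, so the # type: ignore and # noqa suppressions for the previously undefined np name can be dropped; the same cleanup is repeated for Tensor.__init__ and Tensor.numpy below. A minimal before/after sketch of why the suppressions were needed (hypothetical functions, not from the diff):

# before: 'np' is never defined at module level, so mypy/flake8 flag the annotation
def f_old(x: 'np.ndarray'): ...  # type: ignore [name-defined]  # noqa: F821

# after: the TYPE_CHECKING import of numpy makes the name resolvable for checkers
from typing import TYPE_CHECKING
if TYPE_CHECKING: import numpy
def f_new(x: 'numpy.ndarray'): ...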
@@ -110,7 +111,7 @@ class Tensor(OpMixin):
   __slots__ = "uop", "requires_grad", "grad"
   training: ClassVar[bool] = False
 
-  def __init__(self, data:ConstType|bytes|list|tuple|UOp|'np.ndarray'|pathlib.Path|None, # type: ignore [name-defined] # noqa: F821
+  def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
                device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None, _force_unique:bool=False):
     if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
     _dtype:DType|None = to_dtype(dtype) if dtype is not None else None
@@ -345,7 +346,7 @@ class Tensor(OpMixin):
     if self.dtype in (dtypes.half, dtypes.bfloat16, *dtypes.fp8s): return self.cast(dtypes.float32).tolist()
     return self.data().tolist()
 
-  def numpy(self) -> 'np.ndarray': # type: ignore [name-defined] # noqa: F821
+  def numpy(self) -> 'numpy.ndarray':
     """
     Returns the value of this tensor as a `numpy.ndarray`.
 
@@ -913,15 +913,16 @@ def get_location() -> tuple[str, int]:
   return frm.f_code.co_filename, frm.f_lineno
 
 class UPat(OpMixin):
-  __slots__ = ("op", "dtype", "arg", "name", "src")
-  def __init__(self, op:Ops|tuple[Ops, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|None=None,
+  __slots__ = ("op", "dtype", "arg", "name", "src", "is_any")
+  def __init__(self, op:Ops|tuple[Ops, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|set[DType]|None=None,
                src:tuple[UPat, ...]|list[UPat]|UPat|None=None, arg:Any=None,
-               name:str|None=None, allow_any_len:bool=False, custom_early_reject:set[Ops]|None=None, location=None):
+               name:str|None=None, allow_any_len:bool=False, custom_early_reject:set[Ops]|None=None, location=None, is_any:bool=False):
     assert op is None or isinstance(op, (Ops, tuple, set)), "op must be Ops or tuple of Ops"
     self.op: tuple[Ops, ...]|None = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
-    self.dtype: tuple[DType, ...]|None = (dtype,) if isinstance(dtype, DType) else dtype
+    self.dtype: tuple[DType, ...]|None = (dtype,) if isinstance(dtype, DType) else (tuple(dtype) if isinstance(dtype, set) else dtype)
     self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
     self.src: Any = None
+    self.is_any = is_any
     assert self.name != "ctx", "UPat can't be named ctx"
     assert dtype is None or isinstance(dtype, DType) or all(isinstance(x, DType) for x in dtype), f"invalid dtype {dtype}"
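Besides the is_any flag, UPat.__init__ now also accepts a set of DTypes and normalizes it to a tuple, mirroring what was already done for op (presumably so the tuple[DType, ...]|None annotation on self.dtype holds). A standalone sketch of that normalization, with plain Python stand-ins rather than real DTypes:

def normalize(dtype):
  # single value -> 1-tuple, set -> tuple (fixed container type), tuple/None pass through
  if dtype is None or isinstance(dtype, tuple): return dtype
  return tuple(dtype) if isinstance(dtype, set) else (dtype,)

assert normalize("float") == ("float",)
assert sorted(normalize({"float", "half"})) == ["float", "half"]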
@@ -946,7 +947,7 @@ class UPat(OpMixin):
   def named(self, name:str): return UPat(self.op, self.dtype, self._in_src, self.arg, name, not self.strict_length, self.custom_early_reject)
 
   @staticmethod
-  def any(*src): return UPatAny(src=src)
+  def any(*src): return UPat(src=src, is_any=True)
   def or_casted(self, name:str|None=None): return UPat.any(self if name is None else self.named(name), UPat(Ops.CAST, name=name, src=(self,)))
   def or_after(self, name:str|None=None):
     return UPat.any(self if name is None else self.named(name), UPat(Ops.AFTER, name=name, src=(self,), allow_any_len=True))
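Because UPat.any now builds a plain UPat (with is_any=True) rather than a UPatAny instance, callers and annotations deal with a single class. A hedged usage sketch, assuming a tinygrad checkout that includes this change:

from tinygrad.uop.ops import UPat, Ops

x = UPat.var("x")                            # named wildcard pattern
pat = UPat.any(x, UPat(Ops.CAST, src=(x,)))  # matches x directly, or x wrapped in a CAST
print(type(pat).__name__, pat.is_any)        # "UPat" True -- no subclass involved anymore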
@@ -986,6 +987,9 @@ class UPat(OpMixin):
     return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
 
   def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
+    if self.is_any:
+      matches = [x.match(uop, store.copy()) for x in self.src[0]]
+      return flatten([x for x in matches if x is not None])
     if (self.op is not None and uop.op not in self.op) or \
        (self.name is not None and store.setdefault(self.name, uop) is not uop) or \
        (self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
@@ -1002,11 +1006,6 @@ class UPat(OpMixin):
       res.extend(stores)
     return res
 
-class UPatAny(UPat):
-  def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
-    matches = [x.match(uop, store.copy()) for x in self.src[0]]
-    return flatten([x for x in matches if x is not None])
-
 def deconstruct_function(fxn:Callable) -> tuple:
   new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names}
   for co in fxn.__code__.co_consts:
@@ -1050,7 +1049,7 @@ class PatternMatcher:
   @functools.cache # pylint: disable=method-cache-max-size-none
   def __add__(self, more:PatternMatcher) -> PatternMatcher: return PatternMatcher(self.patterns+more.patterns)
 
-  def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
+  def rewrite(self, uop:UOp, ctx=None):
     if len(pats:=self.pdict.get(uop.op, [])):
       ler = {u.op for u in uop.src}
       for _,match,early_reject in pats:
@@ -1137,7 +1136,7 @@ def profile_matches(fxn:Callable):
   return wrap_profile_matches
 
 class TrackedPatternMatcher(PatternMatcher):
-  def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
+  def rewrite(self, uop:UOp, ctx=None):
     if len(pats:=self.pdict.get(uop.op, [])):
       ret = None
       ler = {u.op for u in uop.src}
@@ -1,14 +1,14 @@
 from typing import Any, Callable
 import itertools, inspect, functools, types
 from tinygrad.helpers import partition, dedup, Context
-from tinygrad.uop.ops import UPat, UPatAny, UOp, Ops, PatternMatcher, graph_rewrite, deconstruct_function
+from tinygrad.uop.ops import UPat, UOp, Ops, PatternMatcher, graph_rewrite, deconstruct_function
 
 class UPatCompileError(Exception): pass
 
 # **** UPat compiled ****
 
 def _get_clause(self:UPat, base:UOp, depth=0) -> UOp:
-  if isinstance(self, UPatAny):
+  if self.is_any:
     assert len(self.src) == 1
     return UOp(Ops.AND, src=(UOp(Ops.OR, src=tuple(_get_clause(s, base, depth) for s in self.src[0])),))
   # build the and_clause for acceptance
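In the compiled-pattern path the same flag replaces the isinstance check: an "any" pattern still lowers to an AND node wrapping a single OR over the clauses of its alternatives. A rough stand-in for that shape, using tagged tuples instead of real UOps and a hypothetical helper:

# hypothetical sketch: compile an alternation into one AND( OR(clause_1, ..., clause_n) )
def get_clause(pat, clause_of):
  if pat.get("is_any"): return ("AND", ("OR", tuple(clause_of(p) for p in pat["alternatives"])))
  return clause_of(pat)

clause = get_clause({"is_any": True, "alternatives": ["a", "b"]}, lambda p: ("IS", p))
assert clause == ("AND", ("OR", (("IS", "a"), ("IS", "b"))))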