diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 53c91c1a04..2f2242862c 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 import os, math, itertools
 from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
 from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, Device, Compiled
-from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int, ansilen, getenv, prod, DEBUG
+from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, ansilen, getenv, prod, DEBUG
 from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
 from tinygrad.shape.symbolic import sint
 from tinygrad.shape.view import View, strides_for_shape
@@ -129,11 +129,6 @@ class Kernel:
   @property
   def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)]
 
-  def has_variable_shape(self) -> bool:
-    for b in self.bufs:
-      if not isinstance(b, LocalBuffer) and not all_int(b.st.views[-1].shape): return True
-    return False
-
   def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
   def float4_axis(self, i): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0]
 
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index fa99c5edd1..0b010cd307 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats
 import numpy as np
-from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING
+from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable
 if TYPE_CHECKING:  # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
   from typing_extensions import TypeGuard
 
@@ -29,9 +29,9 @@ def round_up(num, amt): return num if num%amt == 0 else num+(amt-(num%amt))
 def merge_dicts(ds:Iterable[Dict]) -> Dict:
   assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
   return {k:v for d in ds for k,v in d.items()}
-def partition(lst, fxn):
-  a: list[Any] = []
-  b: list[Any] = []
+def partition(lst:List[T], fxn:Callable[[T],bool]):
+  a:List[T] = []
+  b:List[T] = []
   for s in lst: (a if fxn(s) else b).append(s)
   return a,b
 
@@ -200,8 +200,7 @@ def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
     res = cur.execute(f"SELECT val FROM {table} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
   except sqlite3.OperationalError:
     return None  # table doesn't exist
-  if (val:=res.fetchone()) is not None:
-    return pickle.loads(val[0])
+  if (val:=res.fetchone()) is not None: return pickle.loads(val[0])
   return None
 
 _db_tables = set()
diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py
index 7661f447a6..c61a362467 100644
--- a/tinygrad/shape/symbolic.py
+++ b/tinygrad/shape/symbolic.py
@@ -1,5 +1,4 @@
 from __future__ import annotations
-from abc import abstractmethod
 import functools
 from math import gcd
 from itertools import product
@@ -187,8 +186,7 @@ class OpNode(Node):
     self.a, self.b = a, b
     self.min, self.max = self.get_bounds()
   def vars(self): return self.a.vars() + (self.b.vars() if isinstance(self.b, Node) else [])
-  @abstractmethod
-  def get_bounds(self) -> Tuple[int, int]: pass
+  def get_bounds(self) -> Tuple[int, int]: raise NotImplementedError("must be implemented")
 
 class LtNode(OpNode):
   def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
@@ -211,8 +209,7 @@ class MulNode(OpNode):
   def __mod__(self, b: Union[Node, int]):
     a = (self.a * (self.b%b))
     return Node.__mod__(a, b)
-  def get_bounds(self) -> Tuple[int, int]:
-    return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
+  def get_bounds(self) -> Tuple[int, int]: return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
   def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
 
 class DivNode(OpNode):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 0dc1104496..45bb7f9a73 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -41,9 +41,7 @@ class Tensor:
   training: ClassVar[bool] = False
   class train:
     def __init__(self, val=True): self.val = val
-    def __enter__(self):
-      self.prev = Tensor.training
-      Tensor.training = self.val
+    def __enter__(self): self.prev, Tensor.training = Tensor.training, self.val
     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): Tensor.training = self.prev
 
   no_grad: ClassVar[bool] = False
@@ -171,8 +169,7 @@ class Tensor:
   @staticmethod
   def eye(dim:int, **kwargs): return Tensor.full((dim,1),1,**kwargs).pad(((0,0),(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)
 
-  def full_like(self, fill_value, **kwargs):
-    return Tensor.full(self.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
+  def full_like(self, fill_value, **kwargs): return Tensor.full(self.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
   def zeros_like(self, **kwargs): return self.full_like(0, **kwargs)
   def ones_like(self, **kwargs): return self.full_like(1, **kwargs)
 
@@ -382,8 +379,7 @@ class Tensor:
     shapes = [s.shape[dim] for s in catargs]
     shape_cumsum = [0, *accumulate(shapes)]
     slc = [[(0, 0) for _ in self.shape] for _ in catargs]
-    for shp,k,s in zip(shapes, shape_cumsum[:-1], slc):
-      s[dim] = (k, shape_cumsum[-1] - k - shp)
+    for shp,k,s in zip(shapes, shape_cumsum[:-1], slc): s[dim] = (k, shape_cumsum[-1] - k - shp)
     return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
 
   @staticmethod
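Aside (not part of the patch): a minimal sketch of what the new generic signature of partition buys, assuming T = TypeVar("T") is already defined in tinygrad/helpers.py (the hunk above only adds the Callable import).

# Illustrative sketch only, not part of the diff above.
# Exercises the now-generic partition(); assumes tinygrad is importable.
from tinygrad.helpers import partition

# split a list by a predicate: matches go left, the rest go right
evens, odds = partition([1, 2, 3, 4, 5], lambda x: x % 2 == 0)
assert (evens, odds) == ([2, 4], [1, 3, 5])
# With lst:List[T] and fxn:Callable[[T],bool], a type checker now infers
# both results as List[int] instead of the old untyped List[Any].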