cleanup of lines / unused / types (#2336)
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 import os, math, itertools
 from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
 from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, Device, Compiled
-from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int, ansilen, getenv, prod, DEBUG
+from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, ansilen, getenv, prod, DEBUG
 from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
 from tinygrad.shape.symbolic import sint
 from tinygrad.shape.view import View, strides_for_shape
@@ -129,11 +129,6 @@ class Kernel:
   @property
   def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)]
 
-  def has_variable_shape(self) -> bool:
-    for b in self.bufs:
-      if not isinstance(b, LocalBuffer) and not all_int(b.st.views[-1].shape): return True
-    return False
-
   def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
   def float4_axis(self, i): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0]
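An aside on the unchanged context here: shape_offsets enumerates every upcast offset by taking the cartesian product over the reversed upcasted dims. A standalone sketch of that expression, with assumed dims (3, 2):

import itertools
# mirrors the shape_offsets expression for hypothetical upcasted dims (3, 2);
# the dims are reversed before the product, so the first tuple slot varies slowest
offsets = list(itertools.product(*[list(range(s)) for s in (3, 2)[::-1]]))
assert offsets == [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]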
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats
 import numpy as np
-from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING
+from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable
 if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
   from typing_extensions import TypeGuard
@@ -29,9 +29,9 @@ def round_up(num, amt): return num if num%amt == 0 else num+(amt-(num%amt))
 def merge_dicts(ds:Iterable[Dict]) -> Dict:
   assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
   return {k:v for d in ds for k,v in d.items()}
-def partition(lst, fxn):
-  a: list[Any] = []
-  b: list[Any] = []
+def partition(lst:List[T], fxn:Callable[[T],bool]):
+  a:List[T] = []
+  b:List[T] = []
   for s in lst: (a if fxn(s) else b).append(s)
   return a,b
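For reference, the newly typed partition in use; this is a standalone copy of the function from the hunk above, plus a toy call:

from typing import Callable, List, TypeVar
T = TypeVar("T")

def partition(lst:List[T], fxn:Callable[[T],bool]):
  a:List[T] = []
  b:List[T] = []
  for s in lst: (a if fxn(s) else b).append(s)
  return a,b

evens, odds = partition([1, 2, 3, 4, 5], lambda x: x%2 == 0)
assert (evens, odds) == ([2, 4], [1, 3, 5])

With the TypeVar, a checker binds T per call site (both returned lists are List[int] here), where the old untyped signature erased that information.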
@@ -200,8 +200,7 @@ def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
     res = cur.execute(f"SELECT val FROM {table} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
   except sqlite3.OperationalError:
     return None # table doesn't exist
-  if (val:=res.fetchone()) is not None:
-    return pickle.loads(val[0])
+  if (val:=res.fetchone()) is not None: return pickle.loads(val[0])
   return None
 
 _db_tables = set()
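A hedged round-trip sketch of the cache helpers; it assumes the companion writer diskcache_put(table, key, val) from the same file, which this diff doesn't show:

from tinygrad.helpers import diskcache_get, diskcache_put

diskcache_put("example_table", {"name": "add"}, b"blob")           # assumed companion writer
assert diskcache_get("example_table", {"name": "add"}) == b"blob"  # hits the collapsed walrus one-liner
assert diskcache_get("never_created_table", "k") is None           # OperationalError path: table doesn't exist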
diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py
@@ -1,5 +1,4 @@
 from __future__ import annotations
-from abc import abstractmethod
 import functools
 from math import gcd
 from itertools import product
@@ -187,8 +186,7 @@ class OpNode(Node):
     self.a, self.b = a, b
     self.min, self.max = self.get_bounds()
   def vars(self): return self.a.vars() + (self.b.vars() if isinstance(self.b, Node) else [])
-  @abstractmethod
-  def get_bounds(self) -> Tuple[int, int]: pass
+  def get_bounds(self) -> Tuple[int, int]: raise NotImplementedError("must be implemented")
 
 class LtNode(OpNode):
   def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
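Design note on this swap: dropping @abstractmethod removes the abc import entirely, at the cost of moving the failure from instantiation time to call time. A minimal standalone sketch (OpNodeSketch is a hypothetical stand-in, not tinygrad code):

class OpNodeSketch:
  def get_bounds(self): raise NotImplementedError("must be implemented")

node = OpNodeSketch()  # instantiating is now fine; an ABC would have raised here
try:
  node.get_bounds()    # the stub only errors when actually called
except NotImplementedError:
  pass

In practice every concrete subclass overrides get_bounds (it is called from __init__ to set min and max), so nothing can reach the stub.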
@@ -211,8 +209,7 @@ class MulNode(OpNode):
   def __mod__(self, b: Union[Node, int]):
     a = (self.a * (self.b%b))
     return Node.__mod__(a, b)
-  def get_bounds(self) -> Tuple[int, int]:
-    return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
+  def get_bounds(self) -> Tuple[int, int]: return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
   def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
 
 class DivNode(OpNode):
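A worked check of the branch being collapsed in MulNode.get_bounds: multiplying a range by a negative constant flips which endpoint is the minimum (toy numbers assumed):

a_min, a_max, b = 0, 3, -2  # hypothetical node bounds: a in [0, 3], b = -2
bounds = (a_min*b, a_max*b) if b >= 0 else (a_max*b, a_min*b)
assert bounds == (-6, 0)    # a*b spans [-6, 0], so the endpoints swap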
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
@@ -41,9 +41,7 @@ class Tensor:
   training: ClassVar[bool] = False
   class train:
     def __init__(self, val=True): self.val = val
-    def __enter__(self):
-      self.prev = Tensor.training
-      Tensor.training = self.val
+    def __enter__(self): self.prev, Tensor.training = Tensor.training, self.val
     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): Tensor.training = self.prev
 
   no_grad: ClassVar[bool] = False
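The tuple-assignment __enter__ behaves exactly like the three-line version it replaces. Usage of the context manager, as a quick sketch:

from tinygrad.tensor import Tensor

assert not Tensor.training
with Tensor.train():        # __enter__ saves the old flag, then sets Tensor.training = val
  assert Tensor.training    # layers like dropout branch on this flag
assert not Tensor.training  # __exit__ restores the saved flag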
@@ -171,8 +169,7 @@ class Tensor:
   @staticmethod
   def eye(dim:int, **kwargs): return Tensor.full((dim,1),1,**kwargs).pad(((0,0),(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)
 
-  def full_like(self, fill_value, **kwargs):
-    return Tensor.full(self.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
+  def full_like(self, fill_value, **kwargs): return Tensor.full(self.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
   def zeros_like(self, **kwargs): return self.full_like(0, **kwargs)
   def ones_like(self, **kwargs): return self.full_like(1, **kwargs)
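The eye one-liner in the context above is dense enough to deserve a trace; here it is unrolled for dim=2 (comments show the intermediate values):

from tinygrad.tensor import Tensor

# Tensor.full((2,1), 1)  -> [[1], [1]]
# .pad(((0,0), (0,2)))   -> [[1, 0, 0], [1, 0, 0]]
# .reshape(6)            -> [1, 0, 0, 1, 0, 0]  (stride dim+1 puts the ones on the diagonal)
# .shrink(((0,4),))      -> [1, 0, 0, 1]
# .reshape(2,2)          -> [[1, 0], [0, 1]]
print(Tensor.eye(2).numpy())  # [[1. 0.] [0. 1.]]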
@@ -382,8 +379,7 @@ class Tensor:
     shapes = [s.shape[dim] for s in catargs]
     shape_cumsum = [0, *accumulate(shapes)]
     slc = [[(0, 0) for _ in self.shape] for _ in catargs]
-    for shp,k,s in zip(shapes, shape_cumsum[:-1], slc):
-      s[dim] = (k, shape_cumsum[-1] - k - shp)
+    for shp,k,s in zip(shapes, shape_cumsum[:-1], slc): s[dim] = (k, shape_cumsum[-1] - k - shp)
     return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
 
   @staticmethod
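The collapsed loop computes, per argument, the left padding (its cumulative start k) and the right padding (total length minus start minus own length), so cat reduces to pad-and-add. A small sketch along dim 0:

from tinygrad.tensor import Tensor

# a = [1, 2]    -> pad (0, 3) -> [1, 2, 0, 0, 0]
# b = [3, 4, 5] -> pad (2, 0) -> [0, 0, 3, 4, 5]
# elementwise sum            -> [1, 2, 3, 4, 5]
print(Tensor([1, 2]).cat(Tensor([3, 4, 5])).numpy())  # [1 2 3 4 5]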