cleanup of lines / unused / types (#2336)

This commit is contained in:
chenyu
2023-11-16 21:15:32 -05:00
committed by GitHub
parent 3971259832
commit aa01a63b3f
4 changed files with 11 additions and 24 deletions

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import os, math, itertools
from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, Device, Compiled
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int, ansilen, getenv, prod, DEBUG
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, ansilen, getenv, prod, DEBUG
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import sint
from tinygrad.shape.view import View, strides_for_shape
@@ -129,11 +129,6 @@ class Kernel:
@property
def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)]
def has_variable_shape(self) -> bool:
for b in self.bufs:
if not isinstance(b, LocalBuffer) and not all_int(b.st.views[-1].shape): return True
return False
def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
def float4_axis(self, i): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0]

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats
import numpy as np
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
from typing_extensions import TypeGuard
@@ -29,9 +29,9 @@ def round_up(num, amt): return num if num%amt == 0 else num+(amt-(num%amt))
def merge_dicts(ds:Iterable[Dict]) -> Dict:
assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
return {k:v for d in ds for k,v in d.items()}
def partition(lst, fxn):
a: list[Any] = []
b: list[Any] = []
def partition(lst:List[T], fxn:Callable[[T],bool]):
a:List[T] = []
b:List[T] = []
for s in lst: (a if fxn(s) else b).append(s)
return a,b
@@ -200,8 +200,7 @@ def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
res = cur.execute(f"SELECT val FROM {table} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
except sqlite3.OperationalError:
return None # table doesn't exist
if (val:=res.fetchone()) is not None:
return pickle.loads(val[0])
if (val:=res.fetchone()) is not None: return pickle.loads(val[0])
return None
_db_tables = set()

View File

@@ -1,5 +1,4 @@
from __future__ import annotations
from abc import abstractmethod
import functools
from math import gcd
from itertools import product
@@ -187,8 +186,7 @@ class OpNode(Node):
self.a, self.b = a, b
self.min, self.max = self.get_bounds()
def vars(self): return self.a.vars() + (self.b.vars() if isinstance(self.b, Node) else [])
@abstractmethod
def get_bounds(self) -> Tuple[int, int]: pass
def get_bounds(self) -> Tuple[int, int]: raise NotImplementedError("must be implemented")
class LtNode(OpNode):
def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
@@ -211,8 +209,7 @@ class MulNode(OpNode):
def __mod__(self, b: Union[Node, int]):
a = (self.a * (self.b%b))
return Node.__mod__(a, b)
def get_bounds(self) -> Tuple[int, int]:
return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
def get_bounds(self) -> Tuple[int, int]: return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
class DivNode(OpNode):

View File

@@ -41,9 +41,7 @@ class Tensor:
training: ClassVar[bool] = False
class train:
def __init__(self, val=True): self.val = val
def __enter__(self):
self.prev = Tensor.training
Tensor.training = self.val
def __enter__(self): self.prev, Tensor.training = Tensor.training, self.val
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): Tensor.training = self.prev
no_grad: ClassVar[bool] = False
@@ -171,8 +169,7 @@ class Tensor:
@staticmethod
def eye(dim:int, **kwargs): return Tensor.full((dim,1),1,**kwargs).pad(((0,0),(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)
def full_like(self, fill_value, **kwargs):
return Tensor.full(self.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
def full_like(self, fill_value, **kwargs): return Tensor.full(self.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
def zeros_like(self, **kwargs): return self.full_like(0, **kwargs)
def ones_like(self, **kwargs): return self.full_like(1, **kwargs)
@@ -382,8 +379,7 @@ class Tensor:
shapes = [s.shape[dim] for s in catargs]
shape_cumsum = [0, *accumulate(shapes)]
slc = [[(0, 0) for _ in self.shape] for _ in catargs]
for shp,k,s in zip(shapes, shape_cumsum[:-1], slc):
s[dim] = (k, shape_cumsum[-1] - k - shp)
for shp,k,s in zip(shapes, shape_cumsum[:-1], slc): s[dim] = (k, shape_cumsum[-1] - k - shp)
return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
@staticmethod