diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 71975c5c2c..6bd32d2afd 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -128,8 +128,8 @@ def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movement
 LAZY = int(os.getenv("LAZY", "1"))
 
 class LazyBuffer:
-  lazycache : weakref.WeakValueDictionary[LazyOp, LazyBuffer] = weakref.WeakValueDictionary()
-  def __new__(cls, device, shape, optype, op):
+  lazycache : weakref.WeakValueDictionary[Tuple[str, OpType, LazyOp], LazyBuffer] = weakref.WeakValueDictionary()
+  def __new__(cls, device:str, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp):
     # fromcpu aren't cached
     if optype == LoadOps and op.op == LoadOps.FROMCPU:
       return super().__new__(cls)
@@ -139,7 +139,7 @@ class LazyBuffer:
       LazyBuffer.lazycache[wop] = ret = super().__new__(cls) # noqa: F841, pylint: disable=W0612
     return LazyBuffer.lazycache[wop]
 
-  def __init__(self, device, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp):
+  def __init__(self, device:str, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp):
     if hasattr(self, 'device'): return # cache hit, we return and don't reinit
     self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
diff --git a/tinygrad/shape/__init__.py b/tinygrad/shape/__init__.py
index a2da527be8..6e68955754 100644
--- a/tinygrad/shape/__init__.py
+++ b/tinygrad/shape/__init__.py
@@ -21,14 +21,14 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
   return ret
 
 class View:
-  def __init__(self, shape, strides, offset:int=0):
+  def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0):
     self.shape, self.strides, self.offset = tuple(shape), tuple(strides), offset
     self.shape_strides = to_shape_strides(self.shape, self.strides)
 
   def __repr__(self): return f"View({self.shape}, {self.strides}, {self.offset})"
 
   @functools.cached_property
-  def contiguous(self):
+  def contiguous(self) -> bool:
     return self.offset == 0 and all(s1 == s2 or s == 1 for s,s1,s2 in zip(self.shape, self.strides, strides_for_shape(self.shape)))
 
   def expr_node(self, idx):
@@ -49,17 +49,25 @@ class View:
     return Variable.sum([Variable.num(self.offset)] + [Variable(idxs[i], 0, sh-1)*st for i,(sh,st) in enumerate(zip(self.shape, self.strides)) if sh != 1 and st != 0])
 
 class ZeroView:
-  def __init__(self, old_shape, arg):
-    self.old_shape, self.arg, self.shape = old_shape, arg, []
+  def __init__(self, old_shape:Tuple[int, ...], arg):
+    self.old_shape, self.arg = old_shape, arg
+    self.shape : Tuple[int, ...] = tuple([y-x for x,y in self.arg])
+
+  @property
+  def strides(self): raise Exception("ZeroView doesn't have strides")
+
+  @property
+  def offset(self): raise Exception("ZeroView doesn't have offset")
+
+  @property
+  def contiguous(self): return False
 
   def expr_node(self, valid, idx):
     expr, acc = [valid] if valid is not None else [], 1
-    for s,(x,y) in list(zip(self.old_shape, self.arg))[::-1]:
-      self.shape = [y-x] + self.shape
-      base = idx//acc
-      base = (base % self.shape[0]) + x
-      expr += ([base >= 0] if x < 0 else []) + ([base < s] if y > s else [])
-      acc *= self.shape[0]
+    for os,ns,(x,y) in list(zip(self.old_shape, self.shape, self.arg))[::-1]:
+      base = ((idx//acc) % ns) + x
+      expr += ([base >= 0] if x < 0 else []) + ([base < os] if y > os else [])
+      acc *= ns
     return Variable.ands(expr)
 
   @functools.cached_property
@@ -89,16 +97,16 @@ class ShapeTracker:
   def __repr__(self): return f"ShapeTracker(shape={self.shape}, views={self.views})"
 
   @property
-  def contiguous(self): return len(self.views) == 1 and self.views[-1].contiguous
+  def contiguous(self) -> bool: return len(self.views) == 1 and self.views[-1].contiguous
 
   @property
-  def shape(self): return self.views[-1].shape
+  def shape(self) -> Tuple[int, ...]: return self.views[-1].shape
 
   @property
-  def strides(self): return self.views[-1].strides
+  def strides(self) -> Tuple[int, ...]: return self.views[-1].strides
 
   @property
-  def offset(self): return self.views[-1].offset
+  def offset(self) -> int: return self.views[-1].offset
 
   def expr_node(self):
     idx = Variable('idx', 0, prod(self.shape)-1)
diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py
index 7f583c3977..0f240862f4 100644
--- a/tinygrad/shape/symbolic.py
+++ b/tinygrad/shape/symbolic.py
@@ -5,6 +5,7 @@ from tinygrad.helpers import partition, modn, all_same
 
 class Node:
   b, min, max = 0, -math.inf, math.inf # make mypy happy
+  expr: str
   def __str__(self):
     if self.min == self.max: return str(self.min) # this is universal
     return self.expr
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 129db2e757..c87ec7359e 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -11,7 +11,7 @@ from tinygrad.lazy import Device, LazyBuffer
 
 class Tensor:
   training, no_grad = False, False
-  def __init__(self, data, device=Device.DEFAULT, requires_grad=None):
+  def __init__(self, data, device=Device.DEFAULT, requires_grad:Optional[bool]=None):
     if isinstance(data, list):
       data = np.array(data, dtype=np.float32)
     elif isinstance(data, LazyBuffer) and data.device != device:
@@ -326,7 +326,7 @@ class Function:
     self.device, self.parents = device, tensors
     self.needs_input_grad = [t.requires_grad for t in self.parents]
     self.requires_grad = True if any(self.needs_input_grad) else (None if any(x is None for x in self.needs_input_grad) else False)
-    self.saved_tensors : List[Tensor] = []
+    self.saved_tensors : List[LazyBuffer] = []
 
   def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
   def backward(self, *args, **kwargs): raise NotImplementedError(f"backward not implemented for {type(self)}")
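
Note on the lazy.py hunk: keying lazycache by the (device, optype, weak op) tuple instead of the LazyOp alone means two buffers only alias when they come from the same op on the same device, and the WeakValueDictionary lets entries disappear once no strong reference to the buffer remains. A minimal, self-contained sketch of that dedup-in-__new__ pattern (class and field names here are illustrative, not tinygrad's):

import weakref
from typing import Tuple

class CachedBuffer:
  # instances built from the same (device, key) alias to one object while any strong ref is alive
  _cache: "weakref.WeakValueDictionary[Tuple[str, str], CachedBuffer]" = weakref.WeakValueDictionary()

  def __new__(cls, device: str, key: str):
    wop = (device, key)
    if wop not in CachedBuffer._cache:
      CachedBuffer._cache[wop] = ret = super().__new__(cls)  # keep "ret" so the new entry isn't collected immediately
    return CachedBuffer._cache[wop]

  def __init__(self, device: str, key: str):
    if hasattr(self, 'device'): return  # cache hit: don't reinitialize
    self.device, self.key = device, key

a = CachedBuffer("CPU", "ADD")
b = CachedBuffer("CPU", "ADD")
c = CachedBuffer("GPU", "ADD")
assert a is b and a is not c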
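
Note on the ZeroView hunk: the padded shape is now fixed once in __init__ as y-x per axis, so expr_node no longer mutates self.shape while walking the axes; it only emits a bounds check where padding can actually push the index out of range (x < 0 on the left, y greater than the old size on the right). A concrete sketch of the same arithmetic using plain ints and bools instead of symbolic Variable nodes (illustrative only, not tinygrad code):

from typing import List, Tuple

def in_bounds(old_shape: Tuple[int, ...], arg: Tuple[Tuple[int, int], ...], idx: int) -> bool:
  # the padded shape per axis is y-x, matching ZeroView.__init__ after this patch
  new_shape = tuple(y - x for x, y in arg)
  checks: List[bool] = []
  acc = 1
  # walk axes from innermost to outermost, exactly like expr_node
  for os, ns, (x, y) in list(zip(old_shape, new_shape, arg))[::-1]:
    base = ((idx // acc) % ns) + x        # coordinate in the old tensor's frame
    if x < 0: checks.append(base >= 0)    # left padding can push us below 0
    if y > os: checks.append(base < os)   # right padding can push us past the end
    acc *= ns
  return all(checks)

# a 1x3 row padded by one element on each side of the last axis -> padded shape (1, 5)
assert [in_bounds((1, 3), ((0, 1), (-1, 4)), i) for i in range(5)] == [False, True, True, True, False]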