new types and fixup ShapeTracker type mismatches

Author: George Hotz
Date: 2023-01-25 19:39:36 -08:00
parent 1b624a5051
commit b1dec64815
4 changed files with 28 additions and 19 deletions

View File

@@ -128,8 +128,8 @@ def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movement
LAZY = int(os.getenv("LAZY", "1"))
class LazyBuffer:
-  lazycache : weakref.WeakValueDictionary[LazyOp, LazyBuffer] = weakref.WeakValueDictionary()
-  def __new__(cls, device, shape, optype, op):
+  lazycache : weakref.WeakValueDictionary[Tuple[str, OpType, LazyOp], LazyBuffer] = weakref.WeakValueDictionary()
+  def __new__(cls, device:str, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp):
    # fromcpu aren't cached
    if optype == LoadOps and op.op == LoadOps.FROMCPU:
      return super().__new__(cls)
@@ -139,7 +139,7 @@ class LazyBuffer:
      LazyBuffer.lazycache[wop] = ret = super().__new__(cls) # noqa: F841, pylint: disable=W0612
    return LazyBuffer.lazycache[wop]
-  def __init__(self, device, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp):
+  def __init__(self, device:str, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp):
    if hasattr(self, 'device'):
      return # cache hit, we return and don't reinit
    self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
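
The cache key is now a (device, optype, op) tuple rather than the op alone, so identical ops built for a different device or op type no longer collide. Below is a minimal standalone sketch of the weakref caching pattern this hunk annotates; the names (Buf, cache) are illustrative, not the tinygrad API.

import weakref
from typing import Tuple

class Buf:
  cache: "weakref.WeakValueDictionary[Tuple[str, str, str], Buf]" = weakref.WeakValueDictionary()
  def __new__(cls, device:str, optype:str, op:str):
    key = (device, optype, op)
    if key not in cls.cache:
      cls.cache[key] = ret = super().__new__(cls)  # noqa: F841, "ret" keeps a strong reference until we return
    return cls.cache[key]
  def __init__(self, device:str, optype:str, op:str):
    if hasattr(self, 'device'): return  # cache hit: already initialized, don't reinit
    self.device, self.optype, self.op = device, optype, op

a = Buf("CPU", "BinaryOps", "ADD")
b = Buf("CPU", "BinaryOps", "ADD")
assert a is b                                    # equal keys dedupe to one instance
assert Buf("GPU", "BinaryOps", "ADD") is not a   # device is part of the key

A WeakValueDictionary is used so a cached buffer disappears as soon as nothing else references it; without the intermediate ret binding, the fresh object could be collected before the final lookup returns it.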

View File

@@ -21,14 +21,14 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
  return ret
class View:
-  def __init__(self, shape, strides, offset:int=0):
+  def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0):
    self.shape, self.strides, self.offset = tuple(shape), tuple(strides), offset
    self.shape_strides = to_shape_strides(self.shape, self.strides)
  def __repr__(self): return f"View({self.shape}, {self.strides}, {self.offset})"
  @functools.cached_property
-  def contiguous(self):
+  def contiguous(self) -> bool:
    return self.offset == 0 and all(s1 == s2 or s == 1 for s,s1,s2 in zip(self.shape, self.strides, strides_for_shape(self.shape)))
  def expr_node(self, idx):
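
The contiguous property annotated above encodes: a view is contiguous iff it has no offset and every axis either uses the default row-major stride or has size 1 (where the stride is irrelevant). A self-contained sketch of that test, with a simplified stand-in for strides_for_shape:

from typing import Tuple

def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
  # simplified stand-in: plain row-major strides
  strides, acc = [], 1
  for d in reversed(shape):
    strides.append(acc)
    acc *= d
  return tuple(reversed(strides))

def is_contiguous(shape, strides, offset=0) -> bool:
  return offset == 0 and all(s1 == s2 or s == 1 for s,s1,s2 in zip(shape, strides, strides_for_shape(shape)))

assert strides_for_shape((2, 3, 4)) == (12, 4, 1)
assert is_contiguous((2, 3, 4), (12, 4, 1))
assert not is_contiguous((2, 3, 4), (12, 4, 1), offset=5)   # same layout, shifted start
assert is_contiguous((1, 3, 4), (0, 4, 1))                  # a size-1 axis may carry any stride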
@@ -49,17 +49,25 @@ class View:
    return Variable.sum([Variable.num(self.offset)] + [Variable(idxs[i], 0, sh-1)*st for i,(sh,st) in enumerate(zip(self.shape, self.strides)) if sh != 1 and st != 0])
class ZeroView:
-  def __init__(self, old_shape, arg):
-    self.old_shape, self.arg, self.shape = old_shape, arg, []
+  def __init__(self, old_shape:Tuple[int, ...], arg):
+    self.old_shape, self.arg = old_shape, arg
+    self.shape : Tuple[int, ...] = tuple([y-x for x,y in self.arg])
+  @property
+  def strides(self): raise Exception("ZeroView doesn't have strides")
+  @property
+  def offset(self): raise Exception("ZeroView doesn't have offset")
+  @property
+  def contiguous(self): return False
  def expr_node(self, valid, idx):
    expr, acc = [valid] if valid is not None else [], 1
-    for s,(x,y) in list(zip(self.old_shape, self.arg))[::-1]:
-      self.shape = [y-x] + self.shape
-      base = idx//acc
-      base = (base % self.shape[0]) + x
-      expr += ([base >= 0] if x < 0 else []) + ([base < s] if y > s else [])
-      acc *= self.shape[0]
+    for os,ns,(x,y) in list(zip(self.old_shape, self.shape, self.arg))[::-1]:
+      base = ((idx//acc) % ns) + x
+      expr += ([base >= 0] if x < 0 else []) + ([base < os] if y > os else [])
+      acc *= ns
    return Variable.ands(expr)
  @functools.cached_property
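
The rewritten expr_node computes self.shape once in __init__ (now a tuple) and then, axis by axis, checks whether a flat index into the padded shape falls inside the original buffer; only sides that were actually padded get a bounds check. A small numeric sketch of the same logic, using plain ints and bools where the real code builds symbolic Variable nodes:

def valid(idx:int, old_shape, arg) -> bool:
  new_shape = tuple(y-x for x,y in arg)       # same derivation as ZeroView.shape
  checks, acc = [], 1
  for os, ns, (x, y) in list(zip(old_shape, new_shape, arg))[::-1]:
    base = (idx // acc) % ns + x              # coordinate expressed in old-shape terms
    if x < 0: checks.append(base >= 0)        # padded before the start
    if y > os: checks.append(base < os)       # padded past the end
    acc *= ns
  return all(checks)

# pad a (2, 2) buffer by one element on every side -> padded shape (4, 4)
old, arg = (2, 2), ((-1, 3), (-1, 3))
grid = [[valid(r*4 + c, old, arg) for c in range(4)] for r in range(4)]
assert grid[0] == [False, False, False, False]  # top padding row
assert grid[1] == [False, True, True, False]    # only the interior 2x2 maps to real data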
@@ -89,16 +97,16 @@ class ShapeTracker:
  def __repr__(self): return f"ShapeTracker(shape={self.shape}, views={self.views})"
  @property
-  def contiguous(self): return len(self.views) == 1 and self.views[-1].contiguous
+  def contiguous(self) -> bool: return len(self.views) == 1 and self.views[-1].contiguous
  @property
-  def shape(self): return self.views[-1].shape
+  def shape(self) -> Tuple[int, ...]: return self.views[-1].shape
  @property
-  def strides(self): return self.views[-1].strides
+  def strides(self) -> Tuple[int, ...]: return self.views[-1].strides
  @property
-  def offset(self): return self.views[-1].offset
+  def offset(self) -> int: return self.views[-1].offset
  def expr_node(self):
    idx = Variable('idx', 0, prod(self.shape)-1)
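
The typed properties above all delegate to the most recent view, and the tracker as a whole only counts as contiguous when a single contiguous view is left. A minimal sketch of that delegation pattern (SimpleView and SimpleTracker are illustrative names, not the tinygrad classes):

from typing import List, Tuple

class SimpleView:
  def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0, contiguous:bool=True):
    self.shape, self.strides, self.offset, self.contiguous = shape, strides, offset, contiguous

class SimpleTracker:
  def __init__(self, views:List[SimpleView]): self.views = views
  @property
  def shape(self) -> Tuple[int, ...]: return self.views[-1].shape
  @property
  def strides(self) -> Tuple[int, ...]: return self.views[-1].strides
  @property
  def offset(self) -> int: return self.views[-1].offset
  @property
  def contiguous(self) -> bool: return len(self.views) == 1 and self.views[-1].contiguous

st = SimpleTracker([SimpleView((2, 3), (3, 1))])
assert st.shape == (2, 3) and st.contiguous
st.views.append(SimpleView((3, 2), (1, 3), contiguous=False))   # a second, non-contiguous view
assert st.shape == (3, 2) and not st.contiguous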

View File

@@ -5,6 +5,7 @@ from tinygrad.helpers import partition, modn, all_same
class Node:
  b, min, max = 0, -math.inf, math.inf # make mypy happy
+  expr: str
  def __str__(self):
    if self.min == self.max: return str(self.min) # this is universal
    return self.expr
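
The __str__ shown in context folds a node to a literal whenever its range pins it to a single value; the new expr: str annotation just tells mypy that every concrete Node carries an expression string. A tiny illustrative stand-in (not the tinygrad Node class):

import math

class DemoNode:
  b, min, max = 0, -math.inf, math.inf
  expr: str = "(idx*4)"
  def __str__(self):
    if self.min == self.max: return str(self.min)  # value is fully determined
    return self.expr

n = DemoNode()
assert str(n) == "(idx*4)"   # range unknown, print the expression
n.min = n.max = 12
assert str(n) == "12"        # collapsed range, folded to a constant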

View File

@@ -11,7 +11,7 @@ from tinygrad.lazy import Device, LazyBuffer
class Tensor:
  training, no_grad = False, False
-  def __init__(self, data, device=Device.DEFAULT, requires_grad=None):
+  def __init__(self, data, device=Device.DEFAULT, requires_grad:Optional[bool]=None):
    if isinstance(data, list):
      data = np.array(data, dtype=np.float32)
    elif isinstance(data, LazyBuffer) and data.device != device:
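
In the constructor above, plain Python lists are normalized to float32 numpy arrays before a backing buffer is built, and requires_grad is now typed Optional[bool], with None meaning "not yet determined" (the propagation rule in the Function hunk below treats it that way). A hedged sketch of just the list normalization:

import numpy as np

def normalize(data):
  # lists become float32 ndarrays; anything else is passed through in this sketch
  if isinstance(data, list):
    data = np.array(data, dtype=np.float32)
  return data

x = normalize([[1, 2], [3, 4]])
assert x.dtype == np.float32 and x.shape == (2, 2)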
@@ -326,7 +326,7 @@ class Function:
    self.device, self.parents = device, tensors
    self.needs_input_grad = [t.requires_grad for t in self.parents]
    self.requires_grad = True if any(self.needs_input_grad) else (None if any(x is None for x in self.needs_input_grad) else False)
-    self.saved_tensors : List[Tensor] = []
+    self.saved_tensors : List[LazyBuffer] = []
  def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
  def backward(self, *args, **kwargs): raise NotImplementedError(f"backward not implemented for {type(self)}")
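
The corrected annotation reflects that what a Function saves for backward are LazyBuffers, not Tensors. The requires_grad line above also encodes a simple three-way propagation rule; a standalone sketch:

from typing import List, Optional

def propagate_requires_grad(needs_input_grad:List[Optional[bool]]) -> Optional[bool]:
  # any input needs a gradient -> True; none do but some are undecided -> None; else False
  if any(needs_input_grad): return True
  if any(x is None for x in needs_input_grad): return None
  return False

assert propagate_requires_grad([False, True]) is True
assert propagate_requires_grad([False, None]) is None
assert propagate_requires_grad([False, False]) is False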