From 2aeebd70a6dfeb25b6ce3a8a553d1266a3d2b106 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Tue, 7 Feb 2023 15:43:44 -0600 Subject: [PATCH] mypy will compile the shapetracker, no speed up --- tinygrad/shape/__init__.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/tinygrad/shape/__init__.py b/tinygrad/shape/__init__.py index c19a2064b5..940fa2624f 100644 --- a/tinygrad/shape/__init__.py +++ b/tinygrad/shape/__init__.py @@ -39,7 +39,7 @@ class View: return Variable.sum(ret) @property - def expr(self): + def expr(self) -> str: return 'idx=' + str(self.expr_node(Variable('idx', 0, prod([x[0] for x in self.shape_strides])-1))) # generate an expression if you have a variable or expression for each index @@ -47,20 +47,13 @@ class View: return Variable.sum([Variable.num(self.offset+offset)] + [Variable(idxs[i], 0, sh-1)*st for i,(sh,st) in enumerate(zip(self.shape, self.strides)) if sh != 1 and st != 0]) class ZeroView: - __slots__ = ('old_shape', 'arg', 'shape') + __slots__ = ('old_shape', 'arg', 'shape', 'contiguous', 'strides', 'offset') def __init__(self, old_shape:Tuple[int, ...], arg): self.old_shape, self.arg = old_shape, arg self.shape : Tuple[int, ...] = tuple([y-x for x,y in self.arg]) - - @property - def strides(self): raise NotImplementedError("ZeroView doesn't have strides") - - @property - def offset(self): raise NotImplementedError("ZeroView doesn't have offset") - - @property - def contiguous(self): return False + # fake properties + self.strides, self.contiguous, self.offset = strides_for_shape(self.shape), False, 0 def expr_idxs(self, idxs, offset=0): raise NotImplementedError("ZeroView doesn't have expr_idxs") @@ -73,7 +66,7 @@ class ZeroView: return Variable.ands(expr) @property - def expr(self): + def expr(self) -> str: max_idx = prod([y-x for x,y in self.arg]) return 'valid=' + str(self.expr_node(Variable('valid', 0, 1), Variable('idx', 0, max_idx-1))) @@ -131,7 +124,7 @@ class ShapeTracker: else: idx = v.expr_node(idx) return idx, valid - def expr(self): + def expr(self) -> str: idx, valid = self.expr_node() if valid is not None and str(valid) != "valid": return f"valid={valid};idx={idx}" else: return f"idx={idx}"