mypy will compile the shapetracker, no speed up

This commit is contained in:
George Hotz
2023-02-07 15:43:44 -06:00
parent 185d2e3678
commit 2aeebd70a6

View File

@@ -39,7 +39,7 @@ class View:
return Variable.sum(ret)
@property
def expr(self):
def expr(self) -> str:
return 'idx=' + str(self.expr_node(Variable('idx', 0, prod([x[0] for x in self.shape_strides])-1)))
# generate an expression if you have a variable or expression for each index
@@ -47,20 +47,13 @@ class View:
return Variable.sum([Variable.num(self.offset+offset)] + [Variable(idxs[i], 0, sh-1)*st for i,(sh,st) in enumerate(zip(self.shape, self.strides)) if sh != 1 and st != 0])
class ZeroView:
__slots__ = ('old_shape', 'arg', 'shape')
__slots__ = ('old_shape', 'arg', 'shape', 'contiguous', 'strides', 'offset')
def __init__(self, old_shape:Tuple[int, ...], arg):
self.old_shape, self.arg = old_shape, arg
self.shape : Tuple[int, ...] = tuple([y-x for x,y in self.arg])
@property
def strides(self): raise NotImplementedError("ZeroView doesn't have strides")
@property
def offset(self): raise NotImplementedError("ZeroView doesn't have offset")
@property
def contiguous(self): return False
# fake properties
self.strides, self.contiguous, self.offset = strides_for_shape(self.shape), False, 0
def expr_idxs(self, idxs, offset=0): raise NotImplementedError("ZeroView doesn't have expr_idxs")
@@ -73,7 +66,7 @@ class ZeroView:
return Variable.ands(expr)
@property
def expr(self):
def expr(self) -> str:
max_idx = prod([y-x for x,y in self.arg])
return 'valid=' + str(self.expr_node(Variable('valid', 0, 1), Variable('idx', 0, max_idx-1)))
@@ -131,7 +124,7 @@ class ShapeTracker:
else: idx = v.expr_node(idx)
return idx, valid
def expr(self):
def expr(self) -> str:
idx, valid = self.expr_node()
if valid is not None and str(valid) != "valid": return f"valid={valid};idx={idx}"
else: return f"idx={idx}"