mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 14:28:09 -05:00
mypy will compile the shapetracker, no speed up
This commit is contained in:
@@ -39,7 +39,7 @@ class View:
|
||||
return Variable.sum(ret)
|
||||
|
||||
@property
|
||||
def expr(self):
|
||||
def expr(self) -> str:
|
||||
return 'idx=' + str(self.expr_node(Variable('idx', 0, prod([x[0] for x in self.shape_strides])-1)))
|
||||
|
||||
# generate an expression if you have a variable or expression for each index
|
||||
@@ -47,20 +47,13 @@ class View:
|
||||
return Variable.sum([Variable.num(self.offset+offset)] + [Variable(idxs[i], 0, sh-1)*st for i,(sh,st) in enumerate(zip(self.shape, self.strides)) if sh != 1 and st != 0])
|
||||
|
||||
class ZeroView:
|
||||
__slots__ = ('old_shape', 'arg', 'shape')
|
||||
__slots__ = ('old_shape', 'arg', 'shape', 'contiguous', 'strides', 'offset')
|
||||
|
||||
def __init__(self, old_shape:Tuple[int, ...], arg):
|
||||
self.old_shape, self.arg = old_shape, arg
|
||||
self.shape : Tuple[int, ...] = tuple([y-x for x,y in self.arg])
|
||||
|
||||
@property
|
||||
def strides(self): raise NotImplementedError("ZeroView doesn't have strides")
|
||||
|
||||
@property
|
||||
def offset(self): raise NotImplementedError("ZeroView doesn't have offset")
|
||||
|
||||
@property
|
||||
def contiguous(self): return False
|
||||
# fake properties
|
||||
self.strides, self.contiguous, self.offset = strides_for_shape(self.shape), False, 0
|
||||
|
||||
def expr_idxs(self, idxs, offset=0): raise NotImplementedError("ZeroView doesn't have expr_idxs")
|
||||
|
||||
@@ -73,7 +66,7 @@ class ZeroView:
|
||||
return Variable.ands(expr)
|
||||
|
||||
@property
|
||||
def expr(self):
|
||||
def expr(self) -> str:
|
||||
max_idx = prod([y-x for x,y in self.arg])
|
||||
return 'valid=' + str(self.expr_node(Variable('valid', 0, 1), Variable('idx', 0, max_idx-1)))
|
||||
|
||||
@@ -131,7 +124,7 @@ class ShapeTracker:
|
||||
else: idx = v.expr_node(idx)
|
||||
return idx, valid
|
||||
|
||||
def expr(self):
|
||||
def expr(self) -> str:
|
||||
idx, valid = self.expr_node()
|
||||
if valid is not None and str(valid) != "valid": return f"valid={valid};idx={idx}"
|
||||
else: return f"idx={idx}"
|
||||
|
||||
Reference in New Issue
Block a user