mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
style: else-after-return (#1216)
Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
This commit is contained in:
@@ -422,17 +422,16 @@ class Tensor:
|
||||
xup = xup.slice(slc_prefix + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_)))
|
||||
xup = xup.reshape(*prefix, *flatten((k,o) for k,o in zip(k_, o_)))
|
||||
return xup.permute(*range(len(prefix)), *[len(prefix)+i*2+1 for i in range(len(k_))], *[len(prefix)+i*2 for i in range(len(k_))])
|
||||
else:
|
||||
# TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
|
||||
o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
|
||||
xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
|
||||
xup = xup.reshape(*prefix, *([1]*len(_insert_dims)), *flatten(((o, s) for o,s in zip(o_, s_))))
|
||||
if len(_insert_dims):
|
||||
xup = xup.expand(*prefix, *_insert_dims, *flatten(((o, s) for o,s in zip(o_, s_))))
|
||||
prefix += _insert_dims
|
||||
slc_prefix += [(0,x) for x in _insert_dims]
|
||||
xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
|
||||
return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])
|
||||
# TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
|
||||
o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
|
||||
xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
|
||||
xup = xup.reshape(*prefix, *([1]*len(_insert_dims)), *flatten(((o, s) for o,s in zip(o_, s_))))
|
||||
if len(_insert_dims):
|
||||
xup = xup.expand(*prefix, *_insert_dims, *flatten(((o, s) for o,s in zip(o_, s_))))
|
||||
prefix += _insert_dims
|
||||
slc_prefix += [(0,x) for x in _insert_dims]
|
||||
xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
|
||||
return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])
|
||||
|
||||
# NOTE: these work for more than 2D
|
||||
def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
|
||||
|
||||
Reference in New Issue
Block a user