mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
Use generators instead of lists in anys and alls (#1111)
* Use generators in any(...) instead of lists for better best-case performance * Use generators in all(...) instead of lists * enable R1729 in .pylintrc * revert import sorting --------- Co-authored-by: Anselm Coogan <anselm@scandit.com>
This commit is contained in:
@@ -234,8 +234,8 @@ class Tensor:
|
||||
def expand(self, shape, *args) -> Tensor:
  # Broadcast to `shape`; a -1 entry keeps the corresponding existing dimension size.
  # Generator fed straight to tuple() — no intermediate list (flake8-comprehensions C409).
  return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
|
||||
def permute(self, order, *args) -> Tensor:
  # Reorder the tensor's axes according to the given permutation.
  new_order = argfix(order, *args)
  return mlops.Permute.apply(self, order=new_order)
|
||||
def flip(self, axis, *args) -> Tensor:
  # Reverse along the given axes; negative axis indices wrap to their positive equivalents.
  ndim = len(self.shape)
  return mlops.Flip.apply(self, axis=[a + ndim if a < 0 else a for a in argfix(axis, *args)])
|
||||
def pad(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor:
  # No-op fast path: return self untouched when every (before, after) pair is (0,0).
  # Generator instead of a list inside any() — short-circuits without materializing (pylint R1729).
  return mlops.Pad.apply(self, arg=arg) if any(x != (0,0) for x in arg) else self
|
||||
def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor:
  # No-op fast path: return self when every (start, end) pair already spans its full dimension.
  # Generator instead of a list inside any() — short-circuits without materializing (pylint R1729).
  return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
|
||||
def pad(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor:
  """Apply mlops.Pad with per-dimension (before, after) amounts; identity when all pairs are zero."""
  if all(amount == (0,0) for amount in arg): return self
  return mlops.Pad.apply(self, arg=arg)
|
||||
def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor:
  """Apply mlops.Shrink with per-dimension (start, end) windows; identity when each window covers the full dimension."""
  if all(window == (0,size) for window,size in zip(arg, self.shape)): return self
  return mlops.Shrink.apply(self, arg=arg)
|
||||
|
||||
# ***** movement hlops *****
|
||||
|
||||
@@ -314,7 +314,7 @@ class Tensor:
|
||||
|
||||
def cat(self, *args, dim=0):
|
||||
dim = (dim + len(self.shape)) if dim < 0 else dim
|
||||
assert all([len(y.shape) == len(self.shape) and all([y.shape[i] == s for i,s in enumerate(self.shape) if i != dim]) for y in args])
|
||||
assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
|
||||
catargs = [self] + list(args)
|
||||
assert all(len(t.shape) != 0 for t in catargs), "zero-dimensional tensor cannot be concatenated"
|
||||
shape_cumsum = [0, *accumulate([y.shape[dim] for y in catargs])]
|
||||
@@ -438,7 +438,7 @@ class Tensor:
|
||||
HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
|
||||
x, w = self, weight.reshape(groups, weight.shape[0]//groups, weight.shape[1], *weight.shape[2:]).permute(0,2,1,*trailing).flip(trailing)
|
||||
stride = make_pair(stride, len(HW))
|
||||
if any([s>1 for s in stride]):
|
||||
if any(s>1 for s in stride):
|
||||
x = x.reshape(*x.shape[:2], *flatten((k,1) for k in x.shape[2:]))
|
||||
x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
|
||||
x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])
|
||||
|
||||
Reference in New Issue
Block a user