diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 0c83c1cb91..917f14f958 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -407,7 +407,7 @@ class Tensor: order = list(range(len(self.shape))) order[ax1], order[ax2] = order[ax2], order[ax1] return self.permute(order) - def flatten(self, start_dim=0): return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1])) + def flatten(self, start_dim=0): return self.reshape(shape=self.shape[:start_dim] + (-1,)) # ***** reduce ops *****