From f502c9b08f8dd9bbbe3ab39c51e5c19bbce2ad7e Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 11 Jan 2024 13:05:54 -0500 Subject: [PATCH] minor cleanup of View.reshape (#3088) * minor cleanup of View.reshape removed some redundant logic * new_strides * revert that --- tinygrad/shape/view.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index 20e400278d..1784ab8cb7 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -181,18 +181,20 @@ class View: if self.contiguous: return View.create(new_shape) strides, r_new_shape = [], reversed(new_shape) - for merged_dim, s, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)): - acc, new_stride = 1, s + for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)): + acc = 1 + # TODO: this <= and != is for symbolic!? while acc <= merged_dim and acc != merged_dim and (new_dim := next(r_new_shape, None)): - strides.append(new_stride if new_dim != 1 else 0) - if new_dim == 1: continue - new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0) + strides.append(new_stride) + if new_dim != 1: new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0) if acc != merged_dim: break else: strides += [0,] * (len(new_shape) - len(strides)) - mask, extra = _reshape_mask(self, new_shape) - cstrides = canonicalize_strides(tuple(e-b for b,e in mask) if mask else new_shape, tuple(reversed(strides))) - extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - (sum(m[0] * s for m,s in zip(mask, cstrides)) if mask else 0) # noqa: E501 - if not extra: return View.create(new_shape, cstrides, self.offset + extra_offset, mask) + new_mask, extra = _reshape_mask(self, new_shape) + if not extra: + new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask) if new_mask else new_shape, tuple(reversed(strides))) + extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \ + (sum(m[0] * s for m,s in zip(new_mask, new_strides)) if new_mask else 0) + return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask) return None