diff --git a/tinygrad/codegen/optimizer.py b/tinygrad/codegen/optimizer.py index 4571bb9e39..f71a743380 100644 --- a/tinygrad/codegen/optimizer.py +++ b/tinygrad/codegen/optimizer.py @@ -311,8 +311,11 @@ class OptimizedKernel(Kernel): isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \ self.reduceop.src[0].src[0].op == BufferOps.MEM and self.reduceop.src[0].src[1].op == BufferOps.MEM: buf0 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[0]).arg) + buf1 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[1]).arg) buf0_strides = self.sts[buf0].real_strides() - if buf0_strides[self.first_reduce] == 1: + buf1_strides = self.sts[buf1].real_strides() + def has_expanded_axis(s, st): return any(x > 1 and y == 0 for x,y in zip(s,st)) + if buf0_strides[self.first_reduce] == 1 and not (has_expanded_axis(self.sts[buf0].shape, buf0_strides) and has_expanded_axis(self.sts[buf1].shape, buf1_strides)): for global_idx in range(self.global_dims): if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0: if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread{MV_ROWS_PER_THREAD}")