diff --git a/test/test_ops.py b/test/test_ops.py index 928e533eb2..f224566a43 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -587,6 +587,26 @@ class TestOps(unittest.TestCase): with self.assertRaises(AssertionError): a = Tensor(3.14) a.matmul(a) + def test_mulacc_with_zero_strides(self): + helper_test_op( + [], + lambda: torch.tensor(1.0).reshape((1,1,1)).expand(2,4,3).mul(torch.tensor(1.0).reshape((1,1,1)).expand(2,4,3)).sum(-1), + lambda: Tensor(1.0).reshape((1,1,1)).expand(2,4,3).mul(Tensor(1.0).reshape((1,1,1)).expand(2,4,3)).sum(-1), + forward_only=True + ) + a = [[1.,1.,1.,1.], [1.,1.,1.,1.]] + b = [1.,1.,1.,1.] + helper_test_op( + [], + lambda: torch.tensor(a).reshape((2,4,1)).expand(2,4,3).mul(torch.tensor(b).reshape((1,4,1)).expand(2,4,3)).sum([0,2]), + lambda: Tensor(a).reshape((2,4,1)).expand(2,4,3).mul(Tensor(b).reshape((1,4,1)).expand(2,4,3)).sum([0,2]), + forward_only=True + ) + helper_test_op( + [], + lambda: torch.ones((1,2)).matmul(torch.ones((2,3))), lambda: Tensor.ones((1,2)).dot(Tensor.ones((2,3))), + forward_only=True + ) def test_matmul_simple(self): helper_test_op([(4), (4,4)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py index 262eecb2c2..7ca5542f1f 100644 --- a/tinygrad/runtime/ops_cpu.py +++ b/tinygrad/runtime/ops_cpu.py @@ -10,9 +10,14 @@ def reduce_axis(in_shape:Tuple[int, ...], out_shape:Tuple[int, ...]) -> Tuple[in def einsum_mulacc(einsum, get_strides, expand): def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x]) - def axes_slice(strides): return tuple(i for i,s in enumerate(strides) if s != 0), tuple(slice(None) if s != 0 else 0 for s in strides) + def sum_and_nonzero_strides_axes_slices(strides, sum_axes): + axes = tuple(i for i,s in enumerate(strides) if s != 0 or i in sum_axes) + slices = tuple(slice(None) if s != 0 or i in sum_axes else 0 for i,s in enumerate(strides)) + return axes, slices def mulacc(a, b, new_shape): - (a_axes, a_slices), (b_axes, b_slices) = axes_slice(get_strides(a)), axes_slice(get_strides(b)) + sum_axes = tuple(i for i,s in enumerate(new_shape) if s == 1) + (a_axes, a_slices) = sum_and_nonzero_strides_axes_slices(get_strides(a), sum_axes) + (b_axes, b_slices) = sum_and_nonzero_strides_axes_slices(get_strides(b), sum_axes) out = [i for i in range(len(new_shape)) if a.shape[i] == new_shape[i] and (i in a_axes or i in b_axes)] ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out)}", a[a_slices], b[b_slices]) return expand(ret.reshape(tuple(1 if i not in a_axes and i not in b_axes else s for i,s in enumerate(new_shape))), new_shape)