Fix dot product on buffers with zero strides (#3303)

* skip mulacc opt if all the src buffers of the mul op are const buffers

* add noqa directive for long test

* unskip MULACC opt

* ensure that a_axes at least includes summation axes in order to perform np.einsum correctly

* add regression test for mulacc op

* compute a_slices using a_axes

* refactor the `axes_slice` helper of the `einsum_mulacc` function to retrieve axes and slices for nonzero strides as well as summation axes

* include a regression test that uses  and  to test the behaviour indirectly
This commit is contained in:
Obada Khalili
2024-02-04 12:15:06 +02:00
committed by GitHub
parent 30a3288c4a
commit b4ea0e18e3
2 changed files with 27 additions and 2 deletions

View File

@@ -10,9 +10,14 @@ def reduce_axis(in_shape:Tuple[int, ...], out_shape:Tuple[int, ...]) -> Tuple[in
def einsum_mulacc(einsum, get_strides, expand):
def einscripts(x):
    """Turn axis indices into einsum subscript letters (0 -> 'a', 1 -> 'b', ...).

    Each index in ``x`` selects one character of the lowercase ASCII alphabet;
    the selected characters are concatenated in iteration order.
    """
    alphabet = "abcdefghijklmnopqrstuvwxyz"
    return ''.join(alphabet[axis] for axis in x)
def axes_slice(strides):
    """Split a strides sequence into its nonzero-stride axes and a matching index.

    Returns a pair ``(axes, slices)``:
      * ``axes``   -- tuple of the positions whose stride is not 0
      * ``slices`` -- tuple with one entry per stride: ``slice(None)`` where the
                      stride is not 0, and the integer ``0`` where it is
    """
    nonzero_axes = tuple(
        axis for axis, stride in enumerate(strides) if stride != 0
    )
    index = tuple(
        slice(None) if stride != 0 else 0
        for stride in strides
    )
    return nonzero_axes, index
def sum_and_nonzero_strides_axes_slices(strides, sum_axes):
    """Pick the axes to keep for an operand, plus the index that selects them.

    An axis is kept when its stride is nonzero or when its position is listed
    in ``sum_axes``; every other axis is indexed at position 0.

    Args:
        strides: per-axis strides of the operand.
        sum_axes: collection of axis positions that must be kept even when the
            stride at that position is 0.

    Returns:
        ``(axes, slices)`` where ``axes`` is a tuple of the kept axis positions
        and ``slices`` is a tuple with one entry per axis: ``slice(None)`` for
        a kept axis, the integer ``0`` otherwise.
    """
    # Decide keep/drop once per axis so that `axes` and `slices` cannot drift
    # apart and `strides` is only traversed a single time.
    keep = tuple(s != 0 or i in sum_axes for i, s in enumerate(strides))
    axes = tuple(i for i, kept in enumerate(keep) if kept)
    slices = tuple(slice(None) if kept else 0 for kept in keep)
    return axes, slices
def mulacc(a, b, new_shape):
    """Multiply ``a`` and ``b`` and reduce to ``new_shape`` with a single einsum call.

    Axes where ``new_shape`` is 1 are collected as summation axes and are kept
    on both operands even when an operand's stride there is 0. Any other
    zero-stride axis is indexed at 0 before the einsum call; the result is then
    reshaped and passed to ``expand`` together with ``new_shape``.

    ``einsum``, ``get_strides``, ``expand``, ``einscripts`` and
    ``sum_and_nonzero_strides_axes_slices`` are free names resolved from the
    surrounding scope.

    NOTE(review): indexing ``a.shape[i]`` for every axis of ``new_shape``
    presumes ``a`` has at least as many axes as ``new_shape`` -- confirm
    against callers.
    """
    sum_axes = tuple(i for i, s in enumerate(new_shape) if s == 1)
    # The earlier `axes_slice(get_strides(...))` computation was dropped: its
    # results were overwritten by the two assignments below before being used.
    a_axes, a_slices = sum_and_nonzero_strides_axes_slices(get_strides(a), sum_axes)
    b_axes, b_slices = sum_and_nonzero_strides_axes_slices(get_strides(b), sum_axes)
    # Output subscripts: axes where `a` already has the output extent and that
    # at least one operand still carries after slicing.
    out = [i for i, s in enumerate(new_shape) if a.shape[i] == s and (i in a_axes or i in b_axes)]
    ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out)}", a[a_slices], b[b_slices])
    # Axes carried by neither operand are given extent 1 in the reshape target;
    # all other axes take their extent from `new_shape`.
    reshaped = tuple(1 if i not in a_axes and i not in b_axes else s for i, s in enumerate(new_shape))
    return expand(ret.reshape(reshaped), new_shape)