clean up einsum_mulacc (#3312)

* clean up einsum_mulacc

* push get_strides

* stride

* get_strides for ndim
This commit is contained in:
chenyu
2024-02-04 06:21:19 -05:00
committed by GitHub
parent d459956966
commit ca7973f61c

View File

@@ -10,17 +10,15 @@ def reduce_axis(in_shape:Tuple[int, ...], out_shape:Tuple[int, ...]) -> Tuple[in
def einsum_mulacc(einsum, get_strides, expand):
def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
def sum_and_nonzero_strides_axes_slices(strides, sum_axes):
axes = tuple(i for i,s in enumerate(strides) if s != 0 or i in sum_axes)
slices = tuple(slice(None) if s != 0 or i in sum_axes else 0 for i,s in enumerate(strides))
return axes, slices
def mulacc(a, b, new_shape):
sum_axes = tuple(i for i,s in enumerate(new_shape) if s == 1)
(a_axes, a_slices) = sum_and_nonzero_strides_axes_slices(get_strides(a), sum_axes)
(b_axes, b_slices) = sum_and_nonzero_strides_axes_slices(get_strides(b), sum_axes)
out = [i for i in range(len(new_shape)) if a.shape[i] == new_shape[i] and (i in a_axes or i in b_axes)]
ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out)}", a[a_slices], b[b_slices])
return expand(ret.reshape(tuple(1 if i not in a_axes and i not in b_axes else s for i,s in enumerate(new_shape))), new_shape)
def get_input_axes(t, sum_axes): return tuple(i for i,stride in enumerate(get_strides(t)) if stride != 0 or i in sum_axes)
def get_sliced_input(t, axes): return t[tuple(slice(None) if i in axes else 0 for i in range(len(get_strides(t))))]
def mulacc(a, b, out_shape):
sum_axes = tuple(i for i,s in enumerate(out_shape) if s == 1)
a_axes, b_axes = get_input_axes(a, sum_axes), get_input_axes(b, sum_axes)
a_input, b_input = get_sliced_input(a, a_axes), get_sliced_input(b, b_axes)
out_axes = [i for i in range(len(out_shape)) if (i in a_axes or i in b_axes) and i not in sum_axes]
ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out_axes)}", a_input, b_input)
return expand(ret.reshape(tuple(1 if (i not in a_axes and i not in b_axes) else s for i,s in enumerate(out_shape))), out_shape)
return mulacc
def as_strided(x, arg):