mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Fix dot product on buffers with zero strides (#3303)
* skip matacc opt if the all src buffers of mul op are const buffers * add noqa directive for long test * unskip MALACC opt * ensure that a_axes at least includes summation axes in order to perform np.einsum correctly * add regression test for mulacc op * compute a_slices using a_axes * refactor helper of function to retrieve axes and slices for nonzero strides as well as summation axes * include a regression test that uses and to test the behaviour indirectly
This commit is contained in:
@@ -10,9 +10,14 @@ def reduce_axis(in_shape:Tuple[int, ...], out_shape:Tuple[int, ...]) -> Tuple[in
|
||||
|
||||
def einsum_mulacc(einsum, get_strides, expand):
|
||||
def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
|
||||
def axes_slice(strides): return tuple(i for i,s in enumerate(strides) if s != 0), tuple(slice(None) if s != 0 else 0 for s in strides)
|
||||
def sum_and_nonzero_strides_axes_slices(strides, sum_axes):
|
||||
axes = tuple(i for i,s in enumerate(strides) if s != 0 or i in sum_axes)
|
||||
slices = tuple(slice(None) if s != 0 or i in sum_axes else 0 for i,s in enumerate(strides))
|
||||
return axes, slices
|
||||
def mulacc(a, b, new_shape):
|
||||
(a_axes, a_slices), (b_axes, b_slices) = axes_slice(get_strides(a)), axes_slice(get_strides(b))
|
||||
sum_axes = tuple(i for i,s in enumerate(new_shape) if s == 1)
|
||||
(a_axes, a_slices) = sum_and_nonzero_strides_axes_slices(get_strides(a), sum_axes)
|
||||
(b_axes, b_slices) = sum_and_nonzero_strides_axes_slices(get_strides(b), sum_axes)
|
||||
out = [i for i in range(len(new_shape)) if a.shape[i] == new_shape[i] and (i in a_axes or i in b_axes)]
|
||||
ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out)}", a[a_slices], b[b_slices])
|
||||
return expand(ret.reshape(tuple(1 if i not in a_axes and i not in b_axes else s for i,s in enumerate(new_shape))), new_shape)
|
||||
|
||||
Reference in New Issue
Block a user