def einsum_mulacc(einsum, get_strides, expand):
  """Build a multiply-accumulate function on top of a backend ``einsum``.

  Args:
    einsum: callable invoked as ``einsum(subscripts, x, y)``.
    get_strides: callable returning the per-axis strides of an operand.
    expand: callable invoked as ``expand(tensor, shape)`` on the reshaped result.

  Returns:
    ``mulacc(a, b, out_shape)``, which multiplies ``a`` and ``b`` and sums over
    every axis whose size in ``out_shape`` is 1.
  """
  def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])

  def mulacc(a, b, out_shape):
    # axes that collapse to size 1 in the output are the ones being summed over
    sum_axes = {i for i, s in enumerate(out_shape) if s == 1}

    def kept_axes_and_view(t):
      # query strides once per operand; both the axis list and the slice derive from them
      strides = get_strides(t)
      keep = [stride != 0 or i in sum_axes for i, stride in enumerate(strides)]
      axes = tuple(i for i, k in enumerate(keep) if k)
      # zero-stride, non-summed axes (presumably broadcast axes) are indexed at 0
      # so they are absent from the einsum operand
      return axes, t[tuple(slice(None) if k else 0 for k in keep)]

    a_axes, a_input = kept_axes_and_view(a)
    b_axes, b_input = kept_axes_and_view(b)
    kept = set(a_axes) | set(b_axes)
    out_axes = [i for i in range(len(out_shape)) if i in kept and i not in sum_axes]
    ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out_axes)}", a_input, b_input)
    # axes dropped from both operands come back as size 1, then expand fills them to out_shape
    return expand(ret.reshape(tuple(s if i in kept else 1 for i, s in enumerate(out_shape))), out_shape)

  return mulacc