mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-15 09:05:40 -05:00
clean up einsum_mulacc (#3312)
* clean up einsum_mulacc * push get_strides * stride * get_strides for ndim
This commit is contained in:
@@ -10,17 +10,15 @@ def reduce_axis(in_shape:Tuple[int, ...], out_shape:Tuple[int, ...]) -> Tuple[in
|
||||
|
||||
def einsum_mulacc(einsum, get_strides, expand):
|
||||
def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
|
||||
def sum_and_nonzero_strides_axes_slices(strides, sum_axes):
|
||||
axes = tuple(i for i,s in enumerate(strides) if s != 0 or i in sum_axes)
|
||||
slices = tuple(slice(None) if s != 0 or i in sum_axes else 0 for i,s in enumerate(strides))
|
||||
return axes, slices
|
||||
def mulacc(a, b, new_shape):
|
||||
sum_axes = tuple(i for i,s in enumerate(new_shape) if s == 1)
|
||||
(a_axes, a_slices) = sum_and_nonzero_strides_axes_slices(get_strides(a), sum_axes)
|
||||
(b_axes, b_slices) = sum_and_nonzero_strides_axes_slices(get_strides(b), sum_axes)
|
||||
out = [i for i in range(len(new_shape)) if a.shape[i] == new_shape[i] and (i in a_axes or i in b_axes)]
|
||||
ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out)}", a[a_slices], b[b_slices])
|
||||
return expand(ret.reshape(tuple(1 if i not in a_axes and i not in b_axes else s for i,s in enumerate(new_shape))), new_shape)
|
||||
def get_input_axes(t, sum_axes): return tuple(i for i,stride in enumerate(get_strides(t)) if stride != 0 or i in sum_axes)
|
||||
def get_sliced_input(t, axes): return t[tuple(slice(None) if i in axes else 0 for i in range(len(get_strides(t))))]
|
||||
def mulacc(a, b, out_shape):
|
||||
sum_axes = tuple(i for i,s in enumerate(out_shape) if s == 1)
|
||||
a_axes, b_axes = get_input_axes(a, sum_axes), get_input_axes(b, sum_axes)
|
||||
a_input, b_input = get_sliced_input(a, a_axes), get_sliced_input(b, b_axes)
|
||||
out_axes = [i for i in range(len(out_shape)) if (i in a_axes or i in b_axes) and i not in sum_axes]
|
||||
ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out_axes)}", a_input, b_input)
|
||||
return expand(ret.reshape(tuple(1 if (i not in a_axes and i not in b_axes) else s for i,s in enumerate(out_shape))), out_shape)
|
||||
return mulacc
|
||||
|
||||
def as_strided(x, arg):
|
||||
|
||||
Reference in New Issue
Block a user