verbose apply_matrix (#3333)

* verbose apply_matrix

* types

* not so verbose

* small comment change

* fix typo

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
David Hou
2024-02-12 03:06:12 -08:00
committed by GitHub
parent d55f99e881
commit 323393b650

View File

@@ -661,8 +661,14 @@ class Tensor:
# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
def apply_matrix(mat, t, dims=len(HW)):
t_ = t.reshape(t.shape[:dims]+(1,)*dims+t.shape[dims:]).expand(t.shape[:dims]+(len(mat),)*dims+t.shape[dims:])
matcols = [[Tensor.cat(*[Tensor.full(t_.shape[dims:dims+dim]+(1,)+t_.shape[dims+dim+1:], float(m[k]), device=t.device) for m in mat], dim=dim) for k in range(len(mat[0]))] for dim in range(dims)] # noqa: E501
# multiply mat_1 @ mat_2 @ t with foldable constants, where mat_i acts on vector t along dimension i; roughly kron(mat, mat) @ t
# due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic
t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:]) # add output dims
# precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...)
matcols = [[
Tensor.cat(*[Tensor.full(t_.shape[dims:dims + dim] + (1,) + t_.shape[dims + dim + 1:], float(m[k]), device=t.device) for m in mat], dim=dim)
for k in range(len(mat[0]))] for dim in range(dims)]
# multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t
return sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]