touchup apply_matrix (#3301)

This commit is contained in:
chenyu
2024-02-02 05:13:37 -05:00
committed by GitHub
parent 3a7c1eb383
commit f8563a7e9f

View File

@@ -655,8 +655,8 @@ class Tensor:
# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
def apply_matrix(mat, t, dims=len(HW)):
t_ = t.reshape(t.shape[:dims]+(1,)*dims+t.shape[dims:]).expand(t.shape[:dims]+(len(mat),)*dims+t.shape[dims:])
matcols = [[Tensor.cat(*[Tensor(float(m[k]), device=t.device).reshape((1,) * len(t.shape)).expand(t_.shape[dims:dims+dim]+(1,)+t_.shape[dims+dim+1:]) for m in mat], dim=dim) for k in range(len(mat[0]))] for dim in range(dims)] # noqa: E501
return sum(prod([matcols[dim][mat_is[dim]] for dim in range(dims)]) * t_[mat_is] for mat_is in itertools.product(*[range(len(mat[0])) for _ in range(dims)])) # noqa: E501
matcols = [[Tensor.cat(*[Tensor.full(t_.shape[dims:dims+dim]+(1,)+t_.shape[dims+dim+1:], float(m[k]), device=t.device) for m in mat], dim=dim) for k in range(len(mat[0]))] for dim in range(dims)] # noqa: E501
return sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]