mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 14:58:46 -05:00
touchup apply_matrix (#3301)
This commit is contained in:
@@ -655,8 +655,8 @@ class Tensor:
|
||||
# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
|
||||
def apply_matrix(mat, t, dims=len(HW)):
|
||||
t_ = t.reshape(t.shape[:dims]+(1,)*dims+t.shape[dims:]).expand(t.shape[:dims]+(len(mat),)*dims+t.shape[dims:])
|
||||
matcols = [[Tensor.cat(*[Tensor(float(m[k]), device=t.device).reshape((1,) * len(t.shape)).expand(t_.shape[dims:dims+dim]+(1,)+t_.shape[dims+dim+1:]) for m in mat], dim=dim) for k in range(len(mat[0]))] for dim in range(dims)] # noqa: E501
|
||||
return sum(prod([matcols[dim][mat_is[dim]] for dim in range(dims)]) * t_[mat_is] for mat_is in itertools.product(*[range(len(mat[0])) for _ in range(dims)])) # noqa: E501
|
||||
matcols = [[Tensor.cat(*[Tensor.full(t_.shape[dims:dims+dim]+(1,)+t_.shape[dims+dim+1:], float(m[k]), device=t.device) for m in mat], dim=dim) for k in range(len(mat[0]))] for dim in range(dims)] # noqa: E501
|
||||
return sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
|
||||
HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
|
||||
winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
|
||||
winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
|
||||
|
||||
Reference in New Issue
Block a user