From f8563a7e9f46efe038585e7a9561512bd633ad16 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 2 Feb 2024 05:13:37 -0500 Subject: [PATCH] touchup apply_matrix (#3301) --- tinygrad/tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index c8fda2cfe6..b6fbef4fa8 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -655,8 +655,8 @@ class Tensor: # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308 def apply_matrix(mat, t, dims=len(HW)): t_ = t.reshape(t.shape[:dims]+(1,)*dims+t.shape[dims:]).expand(t.shape[:dims]+(len(mat),)*dims+t.shape[dims:]) - matcols = [[Tensor.cat(*[Tensor(float(m[k]), device=t.device).reshape((1,) * len(t.shape)).expand(t_.shape[dims:dims+dim]+(1,)+t_.shape[dims+dim+1:]) for m in mat], dim=dim) for k in range(len(mat[0]))] for dim in range(dims)] # noqa: E501 - return sum(prod([matcols[dim][mat_is[dim]] for dim in range(dims)]) * t_[mat_is] for mat_is in itertools.product(*[range(len(mat[0])) for _ in range(dims)])) # noqa: E501 + matcols = [[Tensor.cat(*[Tensor.full(t_.shape[dims:dims+dim]+(1,)+t_.shape[dims+dim+1:], float(m[k]), device=t.device) for m in mat], dim=dim) for k in range(len(mat[0]))] for dim in range(dims)] # noqa: E501 + return sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims)) HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]] winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]