Fix bugs in MaxPool.

This commit is contained in:
Marcel Keller
2023-01-04 16:28:31 +11:00
parent 3e280abd24
commit a3acf6e8a3

View File

@@ -1080,6 +1080,7 @@ class MaxPool(NoVariableLayer):
self.nabla_Y = Tensor(output_shape, sfix)
self.N = shape[0]
self.comparisons = MultiArray([self.N, self.X.sizes[3],
output_shape[1], output_shape[2],
ksize[1] * ksize[2]], sint)
def __repr__(self):
@@ -1099,26 +1100,28 @@ class MaxPool(NoVariableLayer):
red = util.tree_reduce(m, [(x[0], [1] if training else [])
for x in pool])
self.Y[bi][i][j][k] = red[0]
for i, x in enumerate(red[1]):
self.comparisons[bi][k][i] = x
for ii, x in enumerate(red[1]):
self.comparisons[bi][k][i][j][ii] = x
self.traverse(batch, process)
def backward(self, compute_nabla_X=True, batch=None):
if compute_nabla_X:
self.nabla_X.alloc()
self.nabla_X.assign_all(0)
def process(pool, bi, k, i, j):
for (x, h_in, w_in, h, w), c in zip(pool,
self.comparisons[bi][k]):
for (x, h_in, w_in, h, w), c \
in zip(pool, self.comparisons[bi][k][i][j]):
hh = h * h_in
ww = w * w_in
self.nabla_X[bi][hh][ww][k] = \
util.if_else(h_in * w_in, c * self.nabla_Y[bi][i][j][k],
self.nabla_X[bi][hh][ww][k])
res = h_in * w_in * c * self.nabla_Y[bi][i][j][k]
self.nabla_X[bi][hh][ww][k] += res
self.traverse(batch, process)
def traverse(self, batch, process):
need_padding = [self.strides[i] * (self.Y.sizes[i] - 1) + self.ksize[i] >
self.X.sizes[i] for i in range(4)]
overlap = reduce(operator.or_,
(x < y for x, y in zip(self.strides, self.ksize)))
@for_range_opt_multithread(self.n_threads,
[len(batch), self.X.sizes[3]])
def _(l, k):
@@ -1128,6 +1131,8 @@ class MaxPool(NoVariableLayer):
h_base = self.strides[1] * i
@for_range_opt(self.Y.sizes[2])
def _(j):
if overlap:
break_point()
w_base = self.strides[2] * j
pool = []
for ii in range(self.ksize[1]):