mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-10 05:57:57 -05:00
Fix bugs in MaxPool.
This commit is contained in:
@@ -1080,6 +1080,7 @@ class MaxPool(NoVariableLayer):
|
||||
self.nabla_Y = Tensor(output_shape, sfix)
|
||||
self.N = shape[0]
|
||||
self.comparisons = MultiArray([self.N, self.X.sizes[3],
|
||||
output_shape[1], output_shape[2],
|
||||
ksize[1] * ksize[2]], sint)
|
||||
|
||||
def __repr__(self):
|
||||
@@ -1099,26 +1100,28 @@ class MaxPool(NoVariableLayer):
|
||||
red = util.tree_reduce(m, [(x[0], [1] if training else [])
|
||||
for x in pool])
|
||||
self.Y[bi][i][j][k] = red[0]
|
||||
for i, x in enumerate(red[1]):
|
||||
self.comparisons[bi][k][i] = x
|
||||
for ii, x in enumerate(red[1]):
|
||||
self.comparisons[bi][k][i][j][ii] = x
|
||||
self.traverse(batch, process)
|
||||
|
||||
def backward(self, compute_nabla_X=True, batch=None):
|
||||
if compute_nabla_X:
|
||||
self.nabla_X.alloc()
|
||||
self.nabla_X.assign_all(0)
|
||||
def process(pool, bi, k, i, j):
|
||||
for (x, h_in, w_in, h, w), c in zip(pool,
|
||||
self.comparisons[bi][k]):
|
||||
for (x, h_in, w_in, h, w), c \
|
||||
in zip(pool, self.comparisons[bi][k][i][j]):
|
||||
hh = h * h_in
|
||||
ww = w * w_in
|
||||
self.nabla_X[bi][hh][ww][k] = \
|
||||
util.if_else(h_in * w_in, c * self.nabla_Y[bi][i][j][k],
|
||||
self.nabla_X[bi][hh][ww][k])
|
||||
res = h_in * w_in * c * self.nabla_Y[bi][i][j][k]
|
||||
self.nabla_X[bi][hh][ww][k] += res
|
||||
self.traverse(batch, process)
|
||||
|
||||
def traverse(self, batch, process):
|
||||
need_padding = [self.strides[i] * (self.Y.sizes[i] - 1) + self.ksize[i] >
|
||||
self.X.sizes[i] for i in range(4)]
|
||||
overlap = reduce(operator.or_,
|
||||
(x < y for x, y in zip(self.strides, self.ksize)))
|
||||
@for_range_opt_multithread(self.n_threads,
|
||||
[len(batch), self.X.sizes[3]])
|
||||
def _(l, k):
|
||||
@@ -1128,6 +1131,8 @@ class MaxPool(NoVariableLayer):
|
||||
h_base = self.strides[1] * i
|
||||
@for_range_opt(self.Y.sizes[2])
|
||||
def _(j):
|
||||
if overlap:
|
||||
break_point()
|
||||
w_base = self.strides[2] * j
|
||||
pool = []
|
||||
for ii in range(self.ksize[1]):
|
||||
|
||||
Reference in New Issue
Block a user