mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
Make interface of Dropout backwards compatible
This commit is contained in:
@@ -1051,7 +1051,14 @@ class Dropout(NoVariableLayer):
|
||||
:param shape: list [N, ...] where N is the number of examples and an arbitrary amount of further dimensions
|
||||
:param alpha: probability (power of two)
|
||||
"""
|
||||
def __init__(self, shape, alpha=0.5):
|
||||
def __init__(self, N, d1=None, d2=1, alpha=0.5):
|
||||
if isinstance(N, list) or isinstance(N, tuple):
|
||||
shape = N
|
||||
assert d1 is None, ("If shape is given as list/tuple, d1 must be None. "
|
||||
"Alpha must be passed explicitly for backwards compatibility.")
|
||||
else:
|
||||
assert d1 is not None, "At least one non-batch dimension must be set"
|
||||
shape = [N, d1] if d2 == 1 else [N, d1, d2]
|
||||
self.N = shape[0]
|
||||
self.X = Tensor(shape, sfix)
|
||||
self.Y = Tensor(shape, sfix)
|
||||
|
||||
Reference in New Issue
Block a user