Make interface of Dropout backwards compatible

This commit is contained in:
Hidde L
2025-11-06 16:16:21 -05:00
parent e17e175b16
commit bfe3e7e038
4 changed files with 20 additions and 13 deletions

View File

@@ -1051,7 +1051,14 @@ class Dropout(NoVariableLayer):
:param shape: list [N, ...] where N is the number of examples and an arbitrary amount of further dimensions
:param alpha: probability (power of two)
"""
def __init__(self, shape, alpha=0.5):
def __init__(self, N, d1=None, d2=1, alpha=0.5):
if isinstance(N, list) or isinstance(N, tuple):
shape = N
assert d1 is None, ("If shape is given as list/tuple, d1 must be None. "
"Alpha must be passed explicitly for backwards compatibility.")
else:
assert d1 is not None, "At least one non-batch dimension must be set"
shape = [N, d1] if d2 == 1 else [N, d1, d2]
self.N = shape[0]
self.X = Tensor(shape, sfix)
self.Y = Tensor(shape, sfix)