mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
Make interface of Dropout backwards compatible
This commit is contained in:
@@ -77,16 +77,16 @@ if 'batchnorm' in program.args:
|
||||
|
||||
if 'dropout' in program.args:
|
||||
for i in range(len(layers) - 1, 0, -1):
|
||||
layers.insert(i, ml.Dropout([N, n_inner]))
|
||||
layers.insert(i, ml.Dropout(N, n_inner))
|
||||
|
||||
if 'dropout-late' in program.args:
|
||||
layers.insert(-1, ml.Dropout([N, n_inner]))
|
||||
layers.insert(-1, ml.Dropout(N, n_inner))
|
||||
|
||||
if 'dropout-early' in program.args:
|
||||
layers.insert(0, ml.Dropout([n_examples, n_features]))
|
||||
layers.insert(0, ml.Dropout(n_examples, n_features))
|
||||
|
||||
if 'dropout-early.25' in program.args:
|
||||
layers.insert(0, ml.Dropout([n_examples, n_features], alpha=.25))
|
||||
layers.insert(0, ml.Dropout(n_examples, n_features, alpha=.25))
|
||||
|
||||
layers += [ml.MultiOutput.from_args(program, n_examples, 10)]
|
||||
|
||||
|
||||
@@ -73,16 +73,16 @@ if 'batchnorm' in program.args:
|
||||
layers.insert(1, ml.BatchNorm([N, 24, 24, 20], args=program.args))
|
||||
|
||||
if 'dropout' in program.args or 'dropout2' in program.args:
|
||||
layers.insert(8, ml.Dropout([N, 500]))
|
||||
layers.insert(8, ml.Dropout(N, 500))
|
||||
elif 'dropout.25' in program.args:
|
||||
layers.insert(8, ml.Dropout([N, 500], alpha=0.25))
|
||||
layers.insert(8, ml.Dropout(N, 500, alpha=0.25))
|
||||
elif 'dropout.125' in program.args:
|
||||
layers.insert(8, ml.Dropout([N, 500], alpha=0.125))
|
||||
layers.insert(8, ml.Dropout(N, 500, alpha=0.125))
|
||||
|
||||
if 'dropout2' in program.args:
|
||||
layers.insert(6, ml.Dropout([N, 800], alpha=0.125))
|
||||
layers.insert(6, ml.Dropout(N, 800, alpha=0.125))
|
||||
elif 'dropout1' in program.args:
|
||||
layers.insert(6, ml.Dropout([N, 800], alpha=0.5))
|
||||
layers.insert(6, ml.Dropout(N, 800, alpha=0.5))
|
||||
|
||||
if 'no_relu' in program.args:
|
||||
for x in layers:
|
||||
|
||||
@@ -79,18 +79,18 @@ dropout = 'dropout' in program.args
|
||||
|
||||
if '1dense' in program.args:
|
||||
if dropout:
|
||||
layers += [ml.Dropout([N, n_inner])]
|
||||
layers += [ml.Dropout(N, n_inner)]
|
||||
layers += [ml.Dense(N, n_inner, 10),]
|
||||
elif '2dense' in program.args:
|
||||
if dropout:
|
||||
layers += [ml.Dropout([N, n_inner])]
|
||||
layers += [ml.Dropout(N, n_inner)]
|
||||
layers += [
|
||||
ml.Dense(N, n_inner, 100),
|
||||
ml.Relu([N, 100]),
|
||||
ml.Dense(N, 100, 10),
|
||||
]
|
||||
if dropout or 'dropout1' in program.args:
|
||||
layers.insert(-1, ml.Dropout([N, 100]))
|
||||
layers.insert(-1, ml.Dropout(N, 100))
|
||||
else:
|
||||
raise Exception('need to specify number of dense layers')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user