# File: MP-SPDZ/Programs/Source/mnist_logreg.mpc
# 2021-09-17 14:31:25 +10:00 -- 66 lines, 1.2 KiB, plaintext

# This trains a logistic regression model to distinguish 0 from 1.
# See https://github.com/csiro-mlai/mnist-mpc for data preparation.
import ml
# Let the compiler pick up its own options from the command-line arguments
# (NOTE(review): which options are recognised is defined by MP-SPDZ, not here).
program.options_from_args()

# Forwarded to ml.Output(..., approx=approx) when the model is built.
approx = 3

# Training-set size (N) and test-set size (n_test).  The default sizes are
# used unless the 'gisette' keyword appears anywhere in the arguments.
if 'gisette' not in program.args:
    N, n_test = 12665, 2115
else:
    print('Compiling for 4/9')
    N, n_test = 11791, 1991

# One feature per pixel of a 28x28 image.
n_features = 28 ** 2
n_examples = N
def _int_arg(index, default):
    """Return ``program.args[index]`` parsed as an int, or ``default``.

    The default is used when the argument is absent (IndexError) or is not
    an integer literal (ValueError), for instance when a keyword such as
    'gisette' sits at that position.  Any other exception propagates
    instead of being silently swallowed.
    """
    try:
        return int(program.args[index])
    except (IndexError, ValueError):
        return default


# Optional positional arguments: <n_epochs> <batch_size> <n_threads>.
n_epochs = _int_arg(1, 100)
batch_size = _int_arg(2, N)

# A batch cannot be larger than the training set.
assert batch_size <= N, \
    'batch size %d exceeds number of training examples %d' % (batch_size, N)
ml.Layer.back_batch_size = batch_size

# Only override the thread count when one was actually supplied; errors
# raised by ml.set_n_threads itself are no longer hidden.
_n_threads = _int_arg(3, None)
if _n_threads is not None:
    ml.set_n_threads(_n_threads)
# Model: a dense layer built with (N, n_features, 1) followed by an output
# layer built with (N, approx=approx).
# NOTE(review): presumably N samples, n_features inputs, 1 output -- confirm
# against the ml module.
_dense_layer = ml.Dense(N, n_features, 1)
_output_layer = ml.Output(N, approx=approx)
layers = [_dense_layer, _output_layer]

# Training data from input 0: the output layer's Y first, then the dense
# layer's X.  The order of these calls fixes the order in which values are
# consumed, so it must not change.
_output_layer.Y.input_from(0)
_dense_layer.X.input_from(0)

# Test data, also from input 0 and again Y before X: n_test secret integers
# and an n_test x n_features secret fixed-point matrix.
Y = sint.Array(n_test)
X = sfix.Matrix(n_test, n_features)
Y.input_from(0)
X.input_from(0)
# Set up SGD over the layer list for n_epochs epochs; report_loss=True asks
# the optimizer to report the loss (exact output format is not visible here).
sgd = ml.SGD(layers, n_epochs, report_loss=True)
# NOTE(review): presumably (re-)initializes the model parameters -- confirm in ml.
sgd.reset()
# Timer 1 brackets only the training run, not data input or evaluation.
start_timer(1)
sgd.run(batch_size)
stop_timer(1)
# Load the test samples into the first layer's input and run a forward pass
# over n_test of them.
layers[0].X.assign(X)
sgd.forward(n_test)

# Count correct predictions on revealed values.  A sample scores 1 when
# label XOR (first-layer output < 0) is 1, i.e. (assuming 0/1 labels) label 1
# with a non-negative output or label 0 with a negative output.
n_correct = cfix(0)
for idx in range(n_test):
    true_label = Y[idx].reveal()
    output_is_negative = layers[0].Y[idx][0][0][0].reveal() < 0
    n_correct += true_label.bit_xor(output_is_negative)

# Report accuracy as a fraction together with the raw counts.
print_ln('acc: %s (%s/%s)', n_correct / n_test, n_correct, n_test)