# Mirror of https://github.com/data61/MP-SPDZ.git (synced 2026-01-10 14:08:09 -05:00); 66 lines, 1.2 KiB, plaintext.
# this trains logistic regression in 0/1 distinction
# see https://github.com/csiro-mlai/mnist-mpc for data preparation
#
# Positional arguments read below (all optional):
#   program.args[1]  number of epochs       (default: 100)
#   program.args[2]  batch size             (default: all N training examples)
#   program.args[3]  number of threads      (default: left to the ml module)
# The literal argument 'gisette' anywhere in program.args selects the
# alternative data set sizes.

import ml

program.options_from_args()

# Forwarded unchanged to ml.Output(..., approx=approx) below.
approx = 3

# Data set sizes: N examples are used for training, n_test for evaluation.
if 'gisette' in program.args:
    print('Compiling for 4/9')
    N = 11791
    n_test = 1991
else:
    N = 12665
    n_test = 2115

n_examples = N
n_features = 28 ** 2  # presumably 28x28 pixel images -- confirm against data preparation

# A missing argument raises IndexError and a non-numeric one (for example
# 'gisette' sitting in this position) raises ValueError; both fall back to the
# default. Anything else is a real error and is allowed to propagate.
try:
    n_epochs = int(program.args[1])
except (IndexError, ValueError):
    n_epochs = 100

try:
    batch_size = int(program.args[2])
except (IndexError, ValueError):
    batch_size = N

assert batch_size <= N, \
    'batch size must not exceed the number of training examples'
ml.Layer.back_batch_size = batch_size

# Only the argument parsing is guarded; the thread count is applied in the
# else branch so that an error inside ml.set_n_threads() is not hidden.
try:
    n_threads = int(program.args[3])
except (IndexError, ValueError):
    pass  # no usable thread count given: keep the ml module's default
else:
    ml.set_n_threads(n_threads)

# One dense layer mapping n_features inputs to a single output, followed by
# the output layer.
layers = [ml.Dense(N, n_features, 1),
          ml.Output(N, approx=approx)]

# Training data is supplied via input_from(0): labels first, then features.
# NOTE(review): the statement order presumably fixes the order in which the
# values are consumed from the input -- keep it in sync with data preparation.
layers[1].Y.input_from(0)
layers[0].X.input_from(0)

# Test labels and test features, read after the training data in the same
# labels-then-features order.
Y = sint.Array(n_test)
X = sfix.Matrix(n_test, n_features)
Y.input_from(0)
X.input_from(0)

sgd = ml.SGD(layers, n_epochs, report_loss=True)
sgd.reset()

# Timer 1 brackets the training run only.
start_timer(1)
sgd.run(batch_size)
stop_timer(1)

# Evaluation: overwrite the first layer's input with the test features and
# run a forward pass over the n_test test examples.
layers[0].X.assign(X)
sgd.forward(n_test)

n_correct = cfix(0)

# (output < 0) is 1 exactly when the revealed dense-layer output is negative.
# XOR with the revealed 0/1 label therefore yields 1 for label 1 with a
# non-negative output and for label 0 with a negative output, i.e. for every
# correctly classified test example.
for i in range(n_test):
    n_correct += Y[i].reveal().bit_xor(layers[0].Y[i][0][0][0].reveal() < 0)

print_ln('acc: %s (%s/%s)', n_correct / n_test, n_correct, n_test)