Files
MP-SPDZ/Programs/Source/logreg.mpc
2020-04-02 09:09:45 +02:00

38 lines
756 B
Plaintext

from Compiler import ml
# Logistic-regression benchmark script for the MP-SPDZ compiler.
#
# Reads from the compile-time arguments in ``program.args``:
#   args[1]  -> dim    (column count of the input matrices / Dense layer)
#   args[2]  -> batch  (batch size passed to sgd.run)
#   args[3]  -> optional thread count for Compiler.ml
# and the optional flags 'split', 'split2' and 'approx', which are looked up
# anywhere in ``program.args``.

debug = False

# Compiler-level options.
# NOTE(review): presumably edaBit-based protocols and probabilistic
# truncation -- confirm against the MP-SPDZ documentation.
program.use_edabit(True)
program.use_trunc_pr = True

# The two flags are checked independently; if both are given, use_split(2)
# is the later call.
if 'split' in program.args:
    program.use_split(3)
if 'split2' in program.args:
    program.use_split(2)

# Same precision parameters for secret (sfix) and clear (cfix) fixed-point.
sfix.set_precision(16, 31)
cfix.set_precision(16, 31)

dim = int(program.args[1])
batch = int(program.args[2])

# The thread count is optional. Fall back to None only when the argument is
# absent (IndexError) or is not an integer, e.g. when a flag such as 'split'
# occupies that position (ValueError). Other errors are no longer swallowed.
try:
    n_threads = int(program.args[3])
except (IndexError, ValueError):
    n_threads = None
ml.set_n_threads(n_threads)

# Two input matrices of 6400 rows each; the layers are sized with
# 12800 (= 2 * 6400).
rows_per_matrix = 6400
total_rows = 2 * rows_per_matrix

X_normal = sfix.Matrix(rows_per_matrix, dim)
X_pos = sfix.Matrix(rows_per_matrix, dim)

dense = ml.Dense(total_rows, dim, 1)
output = ml.Output(total_rows, debug=debug, approx='approx' in program.args)
layers = [dense, output]

# Second positional argument of ml.SGD scales with the batch size.
# NOTE(review): presumably the number of epochs -- confirm in Compiler.ml.
sgd = ml.SGD(layers, batch // 128 * 10, debug=debug, report_loss=False)
sgd.reset([X_normal, X_pos])
sgd.run(batch_size=batch)
# Disabled (commented-out) loop that calls sgd.backward() inside for_range(1000):
# @for_range(1000)
# def _(i):
#     sgd.backward()