# MP-SPDZ/Programs/Source/logreg.mpc
# Snapshot of 2020-08-24 23:29:03 +10:00 (42 lines, 911 B).

from Compiler import ml
debug = False

# Use edaBits for non-linear functionality, then let options given as
# program arguments adjust the compiler settings.
program.use_edabit(True)
program.options_from_args()

# Fixed-point format: 16 fractional bits out of 31 bits in total. sfix and
# cfix are set identically because operations mixing different precisions
# are not well-defined in MP-SPDZ.
sfix.set_precision(16, 31)
cfix.set_precision(16, 31)

# Positional program arguments (program.args[0] is presumably the program
# name): feature dimension and mini-batch size.
dim = int(program.args[1])
batch = int(program.args[2])

# Optional thread count in program.args[3]. That slot may be absent or hold
# one of the mode flags checked below ('forward', 'backward', 'approx'); in
# either case fall back to ml.set_n_threads(None). Only the parse is guarded
# so that genuine errors from ml.set_n_threads are not swallowed.
try:
    n_threads = int(program.args[3])
except (IndexError, ValueError):
    n_threads = None
ml.set_n_threads(n_threads)
# Training data grouped by label (passed to sgd.reset as [X_normal, X_pos]):
# n_per_class samples of dimension dim for each of the two labels. Nothing
# in this program fills these matrices with input, so it presumably serves
# as a benchmark -- confirm before using it for real training.
n_per_class = 6400
# Total number of examples; the layers below must agree with the data size.
n_samples = 2 * n_per_class

X_normal = sfix.Matrix(n_per_class, dim)
X_pos = sfix.Matrix(n_per_class, dim)

# Logistic regression: one dense layer (dim -> 1) followed by the logistic
# output layer; the 'approx' program argument is forwarded to ml.Output's
# approx option (approximate sigmoid).
dense = ml.Dense(n_samples, dim, 1)
layers = [dense, ml.Output(n_samples, debug=debug,
                           approx='approx' in program.args)]

# Epoch count grows with the batch size. Assuming ml.SGD takes about
# n_samples / batch steps per epoch, a batch size that is a multiple of 128
# gives 12800 / batch * (batch / 128 * 10) = 1000 steps in total.
# NOTE(review): batch < 128 yields 0 epochs -- confirm that is intended.
n_epochs = batch // 128 * 10
sgd = ml.SGD(layers, n_epochs, debug=debug, report_loss=False)
# Mode selection via program arguments:
#   - neither 'forward' nor 'backward': reset the model with X_normal/X_pos
#     as per-label training data and run SGD with mini-batches of `batch`;
#   - 'forward' and/or 'backward': run only that pass (both may be given),
#     repeated n_bench_iterations times in a for_range loop.
run_forward = 'forward' in program.args
run_backward = 'backward' in program.args
n_bench_iterations = 1000

if not (run_forward or run_backward):
    sgd.reset([X_normal, X_pos])
    sgd.run(batch_size=batch)

if run_forward:
    @for_range(n_bench_iterations)
    def _(i):
        sgd.forward(N=batch)

if run_backward:
    # Fixed mini-batch made of sample indices 0 .. batch-1.
    b = regint.Array(batch)
    b.assign(regint.inc(batch))

    @for_range(n_bench_iterations)
    def _(i):
        sgd.backward(batch=b)