mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-04-20 03:01:31 -04:00
51 lines
1.0 KiB
Plaintext
51 lines
1.0 KiB
Plaintext
# Benchmark script: one epoch of logistic regression on all-zero data.
#
# Positional compile-time arguments (read from program.args):
#   args[1]  dim           number of input features
#   args[2]  batch         batch size
#   args[3]  n_iterations  optional, defaults to 1
#   args[4]  n_threads     optional, defaults to None
# Flag words looked up anywhere in program.args:
#   'approx'    passed on as the `approx` option of the output layer
#   'forward'   only run the forward pass, n_iterations times
#   'backward'  only run the backward pass, n_iterations times
# Without 'forward'/'backward' the full SGD run is executed.
from Compiler import ml

debug = False

program.options_from_args()

# Same precision parameters for secret and clear fixed-point numbers.
# NOTE(review): (16, 31) is presumably (fractional bits, total bit length) --
# confirm against the sfix.set_precision documentation.
sfix.set_precision(16, 31)
cfix.set_precision(16, 31)

dim = int(program.args[1])
batch = int(program.args[2])
ml.Layer.back_batch_size = batch

# Optional iteration count.  A missing argument (IndexError) or a non-numeric
# one such as a flag word in this slot (ValueError) falls back to 1.
try:
    n_iterations = int(program.args[3])
except (IndexError, ValueError):
    n_iterations = 1

# Optional thread count, with the same fallback rules.  Only the parsing is
# guarded so that an error raised by set_n_threads itself is not hidden.
try:
    n_threads = int(program.args[4])
except (IndexError, ValueError):
    n_threads = None
ml.set_n_threads(n_threads)

# Total number of examples processed.
N = batch * n_iterations

print('run 1 epoch of logistic regression with %d examples' % N)

dense = ml.Dense(N, dim, 1)
sigmoid = ml.Output(N, debug=debug, approx='approx' in program.args)

# Fill the inputs of the dense layer and the labels of the output layer
# with zeros.
for x in dense.X, sigmoid.Y:
    x.assign_all(0)

# NOTE(review): the second argument (1) is presumably the number of epochs,
# matching the message printed above -- confirm against ml.SGD.
sgd = ml.SGD([dense, sigmoid], 1, debug=debug, report_loss=False)
sgd.reset()

run_forward_only = 'forward' in program.args
run_backward_only = 'backward' in program.args

if not (run_forward_only or run_backward_only):
    sgd.run(batch_size=batch)

if run_forward_only:
    @for_range(n_iterations)
    def _(i):
        sgd.forward(N=batch)

if run_backward_only:
    # Index array 0..batch-1 handed to the backward pass as the batch.
    b = regint.Array(batch)
    b.assign(regint.inc(batch))

    @for_range(n_iterations)
    def _(i):
        sgd.backward(batch=b)