mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
38 lines
756 B
Plaintext
38 lines
756 B
Plaintext
from Compiler import ml
|
|
|
|
# Benchmark script: runs ml.SGD over a single ml.Dense layer (dim inputs,
# one output) followed by an ml.Output layer, on 12800 samples supplied as
# two 6400-row matrices.
#
# Command-line arguments as read below (via program.args):
#   args[1]  -- dim, the number of input features (required, int)
#   args[2]  -- batch, the batch size (required, int)
#   args[3]  -- number of threads (optional, int)
# Optional flags, looked up anywhere in program.args:
#   'split'  -- call program.use_split(3)
#   'split2' -- call program.use_split(2)
#   'approx' -- forwarded as approx=True to ml.Output

debug = False

# Compiler-level protocol options (see the MP-SPDZ Program documentation).
program.use_edabit(True)
program.use_trunc_pr = True

if 'split' in program.args:
    program.use_split(3)

if 'split2' in program.args:
    program.use_split(2)

# Same fixed-point parameters for secret (sfix) and clear (cfix) values.
sfix.set_precision(16, 31)
cfix.set_precision(16, 31)

dim = int(program.args[1])
batch = int(program.args[2])

# The thread count is optional. Only "argument missing" (IndexError) and
# "argument is not a number", e.g. a flag such as 'approx' sitting in that
# position (ValueError), select the fallback; any other error propagates
# instead of being silently swallowed.
try:
    n_threads = int(program.args[3])
except (IndexError, ValueError):
    n_threads = None
ml.set_n_threads(n_threads)

# Two 6400 x dim input matrices; together they account for the 12800
# samples the layers below are sized for.
# NOTE(review): presumably one matrix per class label -- confirm against
# ml.SGD.reset().
X_normal = sfix.Matrix(6400, dim)
X_pos = sfix.Matrix(6400, dim)

dense = ml.Dense(12800, dim, 1)
layers = [dense, ml.Output(12800, debug=debug, approx='approx' in program.args)]

# The second positional argument scales with the batch size: 10 per full
# multiple of 128 (integer division, so a batch below 128 yields 0).
# NOTE(review): presumably this is the epoch count -- confirm against ml.SGD.
sgd = ml.SGD(layers, batch // 128 * 10, debug=debug, report_loss=False)
sgd.reset([X_normal, X_pos])
sgd.run(batch_size=batch)

# @for_range(1000)
# def _(i):
#     sgd.backward()
|