Files
MP-SPDZ/Programs/Source/spect.mpc
2022-11-09 11:22:18 +11:00

50 lines
950 B
Plaintext

# Decision-tree training script for the MP-SPDZ compiler.
#
# The names program, sint, sfix, Array and MultiArray are not defined in this
# file; they are expected to be supplied by the compiler environment that
# executes it.
#
# Values read from program.args:
#   [1]       integer, passed as the third positional argument of TreeTrainer
#   [2]       integer, passed to TreeTrainer as ``binary``
#   [3]       optional integer thread count (falls back to None)
#   'debug'   shrink the training set to 7 samples and skip the test set
#   'combo'   enlarge the training set by n_test samples and skip the test set
#   'nearest' enable sfix.round_nearest

m = 22          # first dimension of both sample matrices
n_train = 80    # NOTE(review): presumably the SPECT data set sizes -- confirm
n_test = 187

debug = 'debug' in program.args
combo = 'combo' in program.args

if debug:
    n_train = 7

# Applied after the debug override, so both flags together give 7 + n_test.
if combo:
    n_train += n_test

# Switch off index checking for arrays and multi-dimensional arrays.
Array.check_indices = False
MultiArray.disable_index_checks()

# Each data set is a pair: (array of length n, m-by-n matrix).
train = sint.Array(n_train), sint.Matrix(m, n_train)
test = sint.Array(n_test), sint.Matrix(m, n_test)

# Fill both training containers from party 0's input.
for x in train:
    x.input_from(0)

# The separate test set is only read when neither debug nor combo is given.
if not (debug or combo):
    for x in test:
        x.input_from(0)

import decision_tree, util

#decision_tree.debug = True

if 'nearest' in program.args:
    sfix.round_nearest = True

sfix.set_precision_from_args(program, True)

# The thread count is optional: a missing argument raises IndexError and a
# non-numeric one raises ValueError; both mean "leave it unset".
try:
    n_threads = int(program.args[3])
except (IndexError, ValueError):
    n_threads = None

# TreeTrainer receives the matrix first and the array second.
trainer = decision_tree.TreeTrainer(
    train[1], train[0], int(program.args[1]), binary=int(program.args[2]),
    n_threads=n_threads)

if not (debug or combo):
    layers = trainer.train_with_testing(*test)
else:
    layers = trainer.train()
    # The original call used the bare names test_decision_tree, y and x, none
    # of which is bound to the training data at this point (y is never
    # assigned; x is only a leftover loop variable).  Evaluate on the training
    # pair via the imported module instead.
    # NOTE(review): assumes test_decision_tree takes (name, layers, y, x) with
    # y the array and x the matrix, matching the order of ``train`` -- confirm.
    # NOTE(review): indentation was lost in the source; this call is assumed
    # to belong to the else branch -- confirm.
    decision_tree.test_decision_tree('train', layers, *train)