mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-10 22:17:57 -05:00
50 lines
950 B
Plaintext
50 lines
950 B
Plaintext
# Decision-tree training script for the MP-SPDZ compiler.
#
# Command-line arguments read from program.args:
#   [1]  integer passed as the third positional argument of TreeTrainer
#        (presumably the tree height -- confirm against decision_tree)
#   [2]  integer passed as TreeTrainer's ``binary`` keyword
#   [3]  optional number of threads
# Optional flags anywhere in program.args: 'debug', 'combo', 'nearest'.

# First dimension of the sample matrices and the default set sizes.
m = 22
n_train = 80
n_test = 187

debug = 'debug' in program.args
combo = 'combo' in program.args

if debug:
    # Shrink the training set for quick debug runs.
    n_train = 7

if combo:
    # Enlarge the training set by the size of the test set; no separate
    # test input is read in this mode (see below).
    n_train += n_test

# Turn off index checking on containers.
Array.check_indices = False
MultiArray.disable_index_checks()

# Each data set is a pair (Array of length n, m-by-n Matrix).
# NOTE(review): presumably (labels, attributes), given how they are passed
# to TreeTrainer below -- confirm.
train = sint.Array(n_train), sint.Matrix(m, n_train)
test = sint.Array(n_test), sint.Matrix(m, n_test)

# Both containers of the training pair are read from party 0.
for x in train:
    x.input_from(0)

# The test pair is only read when it is actually used for testing.
if not (debug or combo):
    for x in test:
        x.input_from(0)

import decision_tree, util

#decision_tree.debug = True

if 'nearest' in program.args:
    sfix.round_nearest = True

sfix.set_precision_from_args(program, True)

# The thread count is optional: fall back to None when the argument is
# missing (IndexError) or is not an integer, e.g. a flag (ValueError).
try:
    n_threads = int(program.args[3])
except (IndexError, ValueError):
    n_threads = None

trainer = decision_tree.TreeTrainer(
    train[1], train[0], int(program.args[1]), binary=int(program.args[2]),
    n_threads=n_threads)

if not (debug or combo):
    layers = trainer.train_with_testing(*test)
else:
    layers = trainer.train()
    # Evaluate on the training data. The original call used the bare name
    # test_decision_tree and the undefined variable y; use the imported
    # module and the training pair instead.
    # NOTE(review): assumes the signature
    # test_decision_tree(name, layers, y, x) -- confirm.
    decision_tree.test_decision_tree('train', layers, train[0], train[1])