mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-04-20 03:01:31 -04:00
55 lines
1.1 KiB
Plaintext
55 lines
1.1 KiB
Plaintext
# Default data layout: m attributes per sample, with fixed sizes for the
# training and test sets.
m = 6
n_train = 32561
n_test = 16281

# Variant switches: each one is enabled by passing the keyword anywhere in the
# program arguments.
combo = 'combo' in program.args
binary = 'binary' in program.args
mixed = 'mixed' in program.args
nocap = 'nocap' in program.args

# The thread count is an optional integer in argument position 2.  A missing
# argument (IndexError) or a non-numeric one such as a variant keyword
# (ValueError) means "not specified"; any other exception is a real error and
# is allowed to propagate instead of being silently swallowed.
try:
    n_threads = int(program.args[2])
except (IndexError, ValueError):
    n_threads = None

# 'combo' enlarges the training set by the size of the test set.
if combo:
    n_train += n_test

# attr_lengths has one entry per attribute.
# NOTE(review): presumably 0 marks a continuous attribute and 1 a binary one;
# confirm against decision_tree.TreeTrainer.
if binary:
    m = 60
    attr_lengths = [1] * m
elif mixed or nocap:
    # 'mixed' uses 6 leading zero-length entries, 'nocap' alone uses 3.
    cont = 6 if mixed else 3
    m = 60 + cont
    attr_lengths = [0] * cont + [1] * 60
else:
    attr_lengths = None

# Set the program bit length to 32, then apply the options given in the
# program arguments.
program.set_bit_length(32)
program.options_from_args()

# Each data set is a pair: a secret array with one entry per sample, and a
# secret m-by-samples matrix.
# NOTE(review): presumably (labels, attributes); confirm against the input
# file layout.
train = sint.Array(n_train), sint.Matrix(m, n_train)
test = sint.Array(n_test), sint.Matrix(m, n_test)

# Fill all four containers from party 0, in the order train array,
# train matrix, test array, test matrix.
for container in (*train, *test):
    container.input_from(0)

import decision_tree
import util

# Debug switch, left disabled.
#decision_tree.debug_layers = True
decision_tree.max_leaves = 3000

# Passing 'nearest' in the program arguments turns on sfix.round_nearest.
if 'nearest' in program.args:
    sfix.round_nearest = True

# Fixed-point precision is taken from the program arguments.
sfix.set_precision_from_args(program, True)

# The trainer receives the training matrix first and the training array
# second.  Argument 1 is a required integer.
# NOTE(review): presumably the tree height; confirm against
# decision_tree.TreeTrainer.
train_array, train_matrix = train
trainer = decision_tree.TreeTrainer(
    train_matrix,
    train_array,
    int(program.args[1]),
    attr_lengths=attr_lengths,
    n_threads=n_threads,
)
trainer.debug_selection = 'debug_selection' in program.args
trainer.debug_gini = True

# Train, passing the test pair through to the trainer.
layers = trainer.train_with_testing(*test)

# Optional output of the resulting tree, left disabled.
#decision_tree.output_decision_tree(layers)