Files
MP-SPDZ/Programs/Source/adult.mpc
2022-11-09 11:22:18 +11:00

55 lines
1.1 KiB
Plaintext

# Dataset configuration. The row counts (32561 train / 16281 test) match the
# UCI Adult data set, presumably what this program (adult.mpc) trains on.
# m is the number of attribute rows in the data matrices allocated below.
m = 6
n_train = 32561
n_test = 16281
# Flags selected by the presence of a keyword among the program arguments.
combo = 'combo' in program.args
binary = 'binary' in program.args
mixed = 'mixed' in program.args
nocap = 'nocap' in program.args
# Optional third program argument: number of threads. If it is absent
# (IndexError) or not an integer (ValueError, e.g. when a flag such as
# 'binary' occupies that position), fall back to None. Only these two errors
# are caught, so interrupts and unrelated failures are not silently swallowed.
try:
    n_threads = int(program.args[2])
except (IndexError, ValueError):
    n_threads = None
if combo:
    # The training set grows by n_test rows (presumably train and test data
    # combined); the test set is still read separately afterwards.
    n_train += n_test
# attr_lengths has one entry per attribute. Presumably 0 marks a continuous
# attribute and 1 a binary one (cf. `cont` and the 'binary' flag). None means
# no per-attribute description is passed to the trainer.
if binary:
    m = 60
    attr_lengths = [1] * m
elif mixed or nocap:
    # 'mixed' takes precedence over 'nocap' when both are given.
    cont = 6 if mixed else 3
    m = 60 + cont
    attr_lengths = [0] * cont + [1] * 60
else:
    attr_lengths = None
# Compiler setup: 32-bit integer length, then apply command-line options.
program.set_bit_length(32)
program.options_from_args()
# Each data set is a pair (secret array of length n, secret m x n matrix):
# m rows, one per attribute (matching attr_lengths), and n columns.
train = sint.Array(n_train), sint.Matrix(m, n_train)
test = sint.Array(n_test), sint.Matrix(m, n_test)
# Party 0 supplies all inputs, in this order: train array, train matrix,
# test array, test matrix.
for container in (*train, *test):
    container.input_from(0)
import decision_tree
import util
# Debug toggle: uncomment to enable per-layer debugging output.
#decision_tree.debug_layers = True
decision_tree.max_leaves = 3000
# Optional 'nearest' flag switches sfix to round-to-nearest instead of its
# default rounding mode.
if 'nearest' in program.args:
    sfix.round_nearest = True
sfix.set_precision_from_args(program, True)
# Second program argument (required): passed to TreeTrainer as its third
# positional parameter, presumably the tree depth.
depth = int(program.args[1])
trainer = decision_tree.TreeTrainer(
    train[1], train[0], depth,
    attr_lengths=attr_lengths,
    n_threads=n_threads,
)
trainer.debug_selection = 'debug_selection' in program.args
trainer.debug_gini = True
# Train, then evaluate on the test pair (array, matrix) in allocation order.
layers = trainer.train_with_testing(*test)
# Uncomment to output the trained tree.
#decision_tree.output_decision_tree(layers)