mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
Parameter for multithreading.
This commit is contained in:
@@ -549,10 +549,12 @@ class TreeClassifier:
|
||||
:py:class:`TreeTrainer` internally.
|
||||
|
||||
:param max_depth: the depth of the decision tree
|
||||
:param n_threads: number of threads used in training
|
||||
|
||||
"""
|
||||
def __init__(self, max_depth):
|
||||
def __init__(self, max_depth, n_threads=None):
|
||||
self.max_depth = max_depth
|
||||
self.n_threads = n_threads
|
||||
|
||||
@staticmethod
|
||||
def get_attr_lengths(attr_types):
|
||||
@@ -570,7 +572,8 @@ class TreeClassifier:
|
||||
"""
|
||||
self.tree = TreeTrainer(
|
||||
X.transpose(), y, self.max_depth,
|
||||
attr_lengths=self.get_attr_lengths(attr_types)).train()
|
||||
attr_lengths=self.get_attr_lengths(attr_types),
|
||||
n_threads=self.n_threads).train()
|
||||
|
||||
def fit_with_testing(self, X_train, y_train, X_test, y_test,
|
||||
attr_types=None, output_trees=False, debug=False):
|
||||
@@ -587,7 +590,8 @@ class TreeClassifier:
|
||||
|
||||
"""
|
||||
trainer = TreeTrainer(X_train.transpose(), y_train, self.max_depth,
|
||||
attr_lengths=self.get_attr_lengths(attr_types))
|
||||
attr_lengths=self.get_attr_lengths(attr_types),
|
||||
n_threads=self.n_threads)
|
||||
trainer.debug = debug
|
||||
trainer.debug_gini = debug
|
||||
trainer.debug_threading = debug > 1
|
||||
|
||||
@@ -18,7 +18,7 @@ sfix.set_precision_from_args(program)
|
||||
|
||||
from Compiler.decision_tree import TreeClassifier
|
||||
|
||||
tree = TreeClassifier(max_depth=5)
|
||||
tree = TreeClassifier(max_depth=5, n_threads=2)
|
||||
|
||||
# plain training
|
||||
tree.fit(X_train, y_train)
|
||||
|
||||
Reference in New Issue
Block a user