Parameter for multithreading.

This commit is contained in:
Marcel Keller
2023-08-08 11:40:21 +10:00
parent d8a1e18e26
commit 13cd9420f9
2 changed files with 8 additions and 4 deletions

View File

@@ -549,10 +549,12 @@ class TreeClassifier:
:py:class:`TreeTrainer` internally.
:param max_depth: the depth of the decision tree
:param n_threads: number of threads used in training
"""
def __init__(self, max_depth):
def __init__(self, max_depth, n_threads=None):
self.max_depth = max_depth
self.n_threads = n_threads
@staticmethod
def get_attr_lengths(attr_types):
@@ -570,7 +572,8 @@ class TreeClassifier:
"""
self.tree = TreeTrainer(
X.transpose(), y, self.max_depth,
attr_lengths=self.get_attr_lengths(attr_types)).train()
attr_lengths=self.get_attr_lengths(attr_types),
n_threads=self.n_threads).train()
def fit_with_testing(self, X_train, y_train, X_test, y_test,
attr_types=None, output_trees=False, debug=False):
@@ -587,7 +590,8 @@ class TreeClassifier:
"""
trainer = TreeTrainer(X_train.transpose(), y_train, self.max_depth,
attr_lengths=self.get_attr_lengths(attr_types))
attr_lengths=self.get_attr_lengths(attr_types),
n_threads=self.n_threads)
trainer.debug = debug
trainer.debug_gini = debug
trainer.debug_threading = debug > 1

View File

@@ -18,7 +18,7 @@ sfix.set_precision_from_args(program)
from Compiler.decision_tree import TreeClassifier
tree = TreeClassifier(max_depth=5)
tree = TreeClassifier(max_depth=5, n_threads=2)
# plain training
tree.fit(X_train, y_train)