mirror of
https://github.com/microsoft/autogen.git
synced 2026-01-23 09:47:55 -05:00
Unify regression and classification for XGBoost (#276)
* scikit-learn API for XGBoostRegressor
This commit is contained in:
@@ -19,7 +19,6 @@ from sklearn.metrics import (
|
||||
)
|
||||
from sklearn.model_selection import RepeatedStratifiedKFold, GroupKFold, TimeSeriesSplit
|
||||
from .model import (
|
||||
XGBoostEstimator,
|
||||
XGBoostSklearnEstimator,
|
||||
RandomForestEstimator,
|
||||
LGBMEstimator,
|
||||
@@ -41,10 +40,7 @@ logger = logging.getLogger(__name__)
|
||||
def get_estimator_class(task, estimator_name):
|
||||
# when adding a new learner, need to add an elif branch
|
||||
if "xgboost" == estimator_name:
|
||||
if "regression" == task:
|
||||
estimator_class = XGBoostEstimator
|
||||
else:
|
||||
estimator_class = XGBoostSklearnEstimator
|
||||
estimator_class = XGBoostSklearnEstimator
|
||||
elif "rf" == estimator_name:
|
||||
estimator_class = RandomForestEstimator
|
||||
elif "lgbm" == estimator_name:
|
||||
|
||||
Reference in New Issue
Block a user