diff --git a/flaml/ml.py b/flaml/ml.py index 02c523d25..fdafec752 100644 --- a/flaml/ml.py +++ b/flaml/ml.py @@ -19,7 +19,6 @@ from sklearn.metrics import ( ) from sklearn.model_selection import RepeatedStratifiedKFold, GroupKFold, TimeSeriesSplit from .model import ( - XGBoostEstimator, XGBoostSklearnEstimator, RandomForestEstimator, LGBMEstimator, @@ -41,10 +40,7 @@ logger = logging.getLogger(__name__) def get_estimator_class(task, estimator_name): # when adding a new learner, need to add an elif branch if "xgboost" == estimator_name: - if "regression" == task: - estimator_class = XGBoostEstimator - else: - estimator_class = XGBoostSklearnEstimator + estimator_class = XGBoostSklearnEstimator elif "rf" == estimator_name: estimator_class = RandomForestEstimator elif "lgbm" == estimator_name: