mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
* close #249 * admissible region * best_config can be None * optional dependency on lgbm and xgb resolve #252
This commit is contained in:
@@ -13,10 +13,9 @@ import pandas as pd
|
||||
from datetime import datetime
|
||||
|
||||
from flaml import AutoML
|
||||
from flaml.data import get_output_from_log
|
||||
from flaml.data import CLASSIFICATION, get_output_from_log
|
||||
|
||||
from flaml.model import LGBMEstimator, SKLearnEstimator, XGBoostEstimator
|
||||
from rgf.sklearn import RGFClassifier, RGFRegressor
|
||||
from flaml import tune
|
||||
from flaml.training_log import training_log_reader
|
||||
|
||||
@@ -26,9 +25,13 @@ class MyRegularizedGreedyForest(SKLearnEstimator):
|
||||
|
||||
super().__init__(task, **config)
|
||||
|
||||
if task in ("binary", "multi"):
|
||||
if task in CLASSIFICATION:
|
||||
from rgf.sklearn import RGFClassifier
|
||||
|
||||
self.estimator_class = RGFClassifier
|
||||
else:
|
||||
from rgf.sklearn import RGFRegressor
|
||||
|
||||
self.estimator_class = RGFRegressor
|
||||
|
||||
@classmethod
|
||||
@@ -628,7 +631,7 @@ class TestAutoML(unittest.TestCase):
|
||||
"log_file_name": "test/california.log",
|
||||
"log_type": "all",
|
||||
"n_jobs": 1,
|
||||
"n_concurrent_trials": 2,
|
||||
"n_concurrent_trials": 10,
|
||||
"hpo_method": hpo_method,
|
||||
}
|
||||
X_train, y_train = fetch_california_housing(return_X_y=True)
|
||||
|
||||
@@ -109,4 +109,4 @@ def test_mlflow():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_automl(300)
|
||||
test_automl(120)
|
||||
|
||||
@@ -64,18 +64,20 @@ class TestLogging(unittest.TestCase):
|
||||
automl.search_space, automl.low_cost_partial_config, automl.cat_hp_cost
|
||||
)
|
||||
logger.info(automl.search_space["ml"].categories)
|
||||
config = automl.best_config.copy()
|
||||
config["learner"] = automl.best_estimator
|
||||
automl.trainable({"ml": config})
|
||||
if automl.best_config:
|
||||
config = automl.best_config.copy()
|
||||
config["learner"] = automl.best_estimator
|
||||
automl.trainable({"ml": config})
|
||||
from flaml import tune, BlendSearch
|
||||
from flaml.automl import size
|
||||
from functools import partial
|
||||
|
||||
low_cost_partial_config = automl.low_cost_partial_config
|
||||
search_alg = BlendSearch(
|
||||
metric="val_loss",
|
||||
mode="min",
|
||||
space=automl.search_space,
|
||||
low_cost_partial_config=automl.low_cost_partial_config,
|
||||
low_cost_partial_config=low_cost_partial_config,
|
||||
points_to_evaluate=automl.points_to_evaluate,
|
||||
cat_hp_cost=automl.cat_hp_cost,
|
||||
prune_attr=automl.prune_attr,
|
||||
@@ -95,6 +97,14 @@ class TestLogging(unittest.TestCase):
|
||||
print(min(trial.last_result["val_loss"] for trial in analysis.trials))
|
||||
config = analysis.trials[-1].last_result["config"]["ml"]
|
||||
automl._state._train_with_config(config["learner"], config)
|
||||
for _ in range(3):
|
||||
print(
|
||||
search_alg._ls.complete_config(
|
||||
low_cost_partial_config,
|
||||
search_alg._ls_bound_min,
|
||||
search_alg._ls_bound_max,
|
||||
)
|
||||
)
|
||||
# Check if the log buffer is populated.
|
||||
self.assertTrue(len(buf.getvalue()) > 0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user