mirror of
https://github.com/microsoft/autogen.git
synced 2026-02-04 04:45:04 -05:00
update config if n_estimators is modified (#225)
* update config if n_estimators is modified * prediction as int * handle the case n_estimators <= 0 * if trained and no budget to train more, return the trained model * split_type=group for classification & regression
This commit is contained in:
@@ -17,7 +17,7 @@ def _test(split_type):
|
||||
automl_settings = {
|
||||
"time_budget": 2,
|
||||
# "metric": 'accuracy',
|
||||
"task": 'classification',
|
||||
"task": "classification",
|
||||
"log_file_name": "test/{}.log".format(dataset),
|
||||
"model_history": True,
|
||||
"log_training_metric": True,
|
||||
@@ -28,13 +28,16 @@ def _test(split_type):
|
||||
X, y = fetch_openml(name=dataset, return_X_y=True)
|
||||
except (ArffException, ValueError):
|
||||
from sklearn.datasets import load_wine
|
||||
|
||||
X, y = load_wine(return_X_y=True)
|
||||
if split_type != 'time':
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33,
|
||||
random_state=42)
|
||||
if split_type != "time":
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=0.33, random_state=42
|
||||
)
|
||||
else:
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33,
|
||||
shuffle=False)
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=0.33, shuffle=False
|
||||
)
|
||||
automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
|
||||
|
||||
pred = automl.predict(X_test)
|
||||
@@ -53,36 +56,45 @@ def test_time():
|
||||
|
||||
def test_groups():
|
||||
from sklearn.externals._arff import ArffException
|
||||
|
||||
try:
|
||||
X, y = fetch_openml(name=dataset, return_X_y=True)
|
||||
except (ArffException, ValueError):
|
||||
from sklearn.datasets import load_wine
|
||||
|
||||
X, y = load_wine(return_X_y=True)
|
||||
|
||||
import numpy as np
|
||||
|
||||
automl = AutoML()
|
||||
automl_settings = {
|
||||
"time_budget": 2,
|
||||
"task": 'classification',
|
||||
"task": "classification",
|
||||
"log_file_name": "test/{}.log".format(dataset),
|
||||
"model_history": True,
|
||||
"eval_method": "cv",
|
||||
"groups": np.random.randint(low=0, high=10, size=len(y)),
|
||||
"estimator_list": ['lgbm', 'rf', 'xgboost', 'kneighbor'], # list of ML learners
|
||||
"estimator_list": ["lgbm", "rf", "xgboost", "kneighbor"],
|
||||
"learner_selector": "roundrobin",
|
||||
}
|
||||
automl.fit(X, y, **automl_settings)
|
||||
|
||||
automl_settings["eval_method"] = "holdout"
|
||||
automl.fit(X, y, **automl_settings)
|
||||
|
||||
|
||||
def test_rank():
|
||||
from sklearn.externals._arff import ArffException
|
||||
|
||||
try:
|
||||
X, y = fetch_openml(name=dataset, return_X_y=True)
|
||||
except (ArffException, ValueError):
|
||||
from sklearn.datasets import load_wine
|
||||
|
||||
X, y = load_wine(return_X_y=True)
|
||||
y = y.cat.codes
|
||||
import numpy as np
|
||||
|
||||
automl = AutoML()
|
||||
automl_settings = {
|
||||
"time_budget": 2,
|
||||
@@ -90,8 +102,9 @@ def test_rank():
|
||||
"log_file_name": "test/{}.log".format(dataset),
|
||||
"model_history": True,
|
||||
"eval_method": "cv",
|
||||
"groups": np.array( # group labels
|
||||
[0] * 200 + [1] * 200 + [2] * 200 + [3] * 200 + [4] * 100 + [5] * 100),
|
||||
"groups": np.array( # group labels
|
||||
[0] * 200 + [1] * 200 + [2] * 200 + [3] * 200 + [4] * 100 + [5] * 100
|
||||
),
|
||||
"learner_selector": "roundrobin",
|
||||
}
|
||||
automl.fit(X, y, **automl_settings)
|
||||
@@ -100,10 +113,10 @@ def test_rank():
|
||||
automl_settings = {
|
||||
"time_budget": 2,
|
||||
"task": "rank",
|
||||
"metric": "ndcg@5", # 5 can be replaced by any number
|
||||
"metric": "ndcg@5", # 5 can be replaced by any number
|
||||
"log_file_name": "test/{}.log".format(dataset),
|
||||
"model_history": True,
|
||||
"groups": [200] * 4 + [100] * 2, # alternative way: group counts
|
||||
"groups": [200] * 4 + [100] * 2, # alternative way: group counts
|
||||
# "estimator_list": ['lgbm', 'xgboost'], # list of ML learners
|
||||
"learner_selector": "roundrobin",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user