mirror of
https://github.com/microsoft/autogen.git
synced 2026-02-15 18:55:03 -05:00
Allow custom GroupKFold object as split_type (#616)
* Allow custom GroupKFold object * handle unpickle error for prophet 1.1 * eval_method in test_object()
This commit is contained in:
@@ -46,9 +46,11 @@ class TestScore:
|
||||
automl.score(X_test, y_test)
|
||||
automl.pickle("automl.pkl")
|
||||
with open("automl.pkl", "rb") as f:
|
||||
pickle.load(f)
|
||||
except ImportError:
|
||||
print("not using prophet due to ImportError")
|
||||
pickle.load(f) # v1.1 of prophet raises RecursionError
|
||||
except (ImportError, RecursionError):
|
||||
print(
|
||||
"not using prophet due to ImportError or RecursionError (when unpickling in v1.1)"
|
||||
)
|
||||
automl.fit(
|
||||
dataframe=df,
|
||||
**settings,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from sklearn.datasets import fetch_openml
|
||||
from flaml.automl import AutoML
|
||||
from sklearn.model_selection import train_test_split, KFold
|
||||
from sklearn.model_selection import GroupKFold, train_test_split, KFold
|
||||
from sklearn.metrics import accuracy_score
|
||||
|
||||
|
||||
@@ -80,6 +80,19 @@ def test_groups():
|
||||
automl_settings["eval_method"] = "holdout"
|
||||
automl.fit(X, y, **automl_settings)
|
||||
|
||||
automl_settings["split_type"] = GroupKFold(n_splits=3)
|
||||
try:
|
||||
automl.fit(X, y, **automl_settings)
|
||||
raise RuntimeError(
|
||||
"GroupKFold object as split_type should fail when eval_method is holdout"
|
||||
)
|
||||
except AssertionError:
|
||||
# eval_method must be 'auto' or 'cv' for custom data splitter.
|
||||
pass
|
||||
|
||||
automl_settings["eval_method"] = "cv"
|
||||
automl.fit(X, y, **automl_settings)
|
||||
|
||||
|
||||
def test_rank():
|
||||
from sklearn.externals._arff import ArffException
|
||||
@@ -150,7 +163,6 @@ def test_object():
|
||||
automl = AutoML()
|
||||
automl_settings = {
|
||||
"time_budget": 2,
|
||||
# "metric": 'accuracy',
|
||||
"task": "classification",
|
||||
"log_file_name": "test/{}.log".format(dataset),
|
||||
"model_history": True,
|
||||
@@ -158,6 +170,9 @@ def test_object():
|
||||
"split_type": TestKFold(5),
|
||||
}
|
||||
automl.fit(X, y, **automl_settings)
|
||||
assert (
|
||||
automl._state.eval_method == "cv"
|
||||
), "eval_method must be 'cv' for custom data splitter"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user