Allow custom GroupKFold object as split_type (#616)

* Allow custom GroupKFold object

* Handle the RecursionError raised when unpickling with prophet 1.1

* Assert that eval_method is 'cv' for a custom data splitter in test_object()
This commit is contained in:
Chi Wang
2022-06-29 21:04:25 -07:00
committed by GitHub
parent cbb85e2aab
commit 74cca60606
3 changed files with 50 additions and 13 deletions

View File

@@ -46,9 +46,11 @@ class TestScore:
automl.score(X_test, y_test)
automl.pickle("automl.pkl")
with open("automl.pkl", "rb") as f:
pickle.load(f)
except ImportError:
print("not using prophet due to ImportError")
pickle.load(f) # v1.1 of prophet raises RecursionError
except (ImportError, RecursionError):
print(
"not using prophet due to ImportError or RecursionError (when unpickling in v1.1)"
)
automl.fit(
dataframe=df,
**settings,

View File

@@ -1,6 +1,6 @@
from sklearn.datasets import fetch_openml
from flaml.automl import AutoML
from sklearn.model_selection import train_test_split, KFold
from sklearn.model_selection import GroupKFold, train_test_split, KFold
from sklearn.metrics import accuracy_score
@@ -80,6 +80,19 @@ def test_groups():
automl_settings["eval_method"] = "holdout"
automl.fit(X, y, **automl_settings)
automl_settings["split_type"] = GroupKFold(n_splits=3)
try:
automl.fit(X, y, **automl_settings)
raise RuntimeError(
"GroupKFold object as split_type should fail when eval_method is holdout"
)
except AssertionError:
# eval_method must be 'auto' or 'cv' for custom data splitter.
pass
automl_settings["eval_method"] = "cv"
automl.fit(X, y, **automl_settings)
def test_rank():
from sklearn.externals._arff import ArffException
@@ -150,7 +163,6 @@ def test_object():
automl = AutoML()
automl_settings = {
"time_budget": 2,
# "metric": 'accuracy',
"task": "classification",
"log_file_name": "test/{}.log".format(dataset),
"model_history": True,
@@ -158,6 +170,9 @@ def test_object():
"split_type": TestKFold(5),
}
automl.fit(X, y, **automl_settings)
assert (
automl._state.eval_method == "cv"
), "eval_method must be 'cv' for custom data splitter"
if __name__ == "__main__":