mirror of
https://github.com/microsoft/autogen.git
synced 2026-02-03 23:05:02 -05:00
support for customized splitters (#333)
* add support for customized splitters * use the param split_type for feeding generators * use single API for customized splitter and add test * when task==TS_FORCAST, always set shuffle=False * update docstr Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
@@ -2,7 +2,7 @@ import unittest
|
||||
|
||||
from sklearn.datasets import fetch_openml
|
||||
from flaml.automl import AutoML
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.model_selection import train_test_split, KFold
|
||||
from sklearn.metrics import accuracy_score
|
||||
|
||||
|
||||
@@ -123,6 +123,45 @@ def test_rank():
|
||||
automl.fit(X, y, **automl_settings)
|
||||
|
||||
|
||||
def test_object():
|
||||
from sklearn.externals._arff import ArffException
|
||||
|
||||
try:
|
||||
X, y = fetch_openml(name=dataset, return_X_y=True)
|
||||
except (ArffException, ValueError):
|
||||
from sklearn.datasets import load_wine
|
||||
|
||||
X, y = load_wine(return_X_y=True)
|
||||
|
||||
import numpy as np
|
||||
|
||||
class TestKFold(KFold):
|
||||
def __init__(self, n_splits):
|
||||
self.n_splits = int(n_splits)
|
||||
|
||||
def split(self, X):
|
||||
rng = np.random.default_rng()
|
||||
train_num = int(len(X) * 0.8)
|
||||
for _ in range(self.n_splits):
|
||||
permu_idx = rng.permutation(len(X))
|
||||
yield permu_idx[:train_num], permu_idx[train_num:]
|
||||
|
||||
def get_n_splits(self, X=None, y=None, groups=None):
|
||||
return self.n_splits
|
||||
|
||||
automl = AutoML()
|
||||
automl_settings = {
|
||||
"time_budget": 2,
|
||||
# "metric": 'accuracy',
|
||||
"task": "classification",
|
||||
"log_file_name": "test/{}.log".format(dataset),
|
||||
"model_history": True,
|
||||
"log_training_metric": True,
|
||||
"split_type": TestKFold(5),
|
||||
}
|
||||
automl.fit(X, y, **automl_settings)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# unittest.main()
|
||||
test_groups()
|
||||
|
||||
Reference in New Issue
Block a user