diff --git a/flaml/automl.py b/flaml/automl.py index eb084c1e3..91eaf2215 100644 --- a/flaml/automl.py +++ b/flaml/automl.py @@ -1,10 +1,10 @@ # ! -# * Copyright (c) Microsoft Corporation. All rights reserved. +# * Copyright (c) FLAML authors. All rights reserved. # * Licensed under the MIT License. See LICENSE file in the # * project root for license information. import time import os -from typing import Callable, Optional +from typing import Callable, Optional, List, Union from functools import partial import numpy as np from scipy.sparse import issparse @@ -20,10 +20,7 @@ from sklearn.utils import shuffle from sklearn.base import BaseEstimator import pandas as pd import logging -from typing import List, Union -from pandas import DataFrame -from .data import _is_nlp_task - +import json from .ml import ( compute_estimator, train_estimator, @@ -40,8 +37,14 @@ from .config import ( N_SPLITS, SAMPLE_MULTIPLY_FACTOR, ) - -from .data import concat, CLASSIFICATION, TS_FORECAST, FORECAST, REGRESSION +from .data import ( + concat, + CLASSIFICATION, + TS_FORECAST, + FORECAST, + REGRESSION, + _is_nlp_task, +) from . 
def save_best_config(self, filename):
    """Save the best estimator and its hyperparameters to a JSON file.

    The file contains one JSON object with two keys: ``"class"`` (the value
    of ``self.best_estimator``) and ``"hyperparameters"`` (the value of
    ``self.best_config``).

    Args:
        filename: Path of the JSON file to write. Missing parent
            directories are created. A bare file name without a directory
            part is written relative to the current working directory.
    """
    best = {
        "class": self.best_estimator,
        "hyperparameters": self.best_config,
    }
    # os.path.dirname() returns "" for a bare file name and os.makedirs("")
    # raises FileNotFoundError, so only create the parent when there is one.
    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(filename, "w") as f:
        json.dump(best, f)