diff --git a/flaml/tune/tune.py b/flaml/tune/tune.py index 3468bff60..d112da932 100644 --- a/flaml/tune/tune.py +++ b/flaml/tune/tune.py @@ -6,6 +6,7 @@ from typing import Optional, Union, List, Callable, Tuple import numpy as np import datetime import time +import os try: from ray import __version__ as ray_version @@ -147,6 +148,7 @@ def run( max_failure: Optional[int] = 100, use_ray: Optional[bool] = False, use_incumbent_result_in_evaluation: Optional[bool] = None, + log_file_name: Optional[str] = None, **ray_args, ): """The trigger for HPO. @@ -298,6 +300,11 @@ def run( max_failure: int | the maximal consecutive number of failures to sample a trial before the tuning is terminated. use_ray: A boolean of whether to use ray as the backend. + log_file_name: A string of the log file name. Default to None. + When set to None: + if local_dir is not given, no log file is created; + if local_dir is given, the log file name will be autogenerated under local_dir. + Only valid when verbose > 0 or use_ray is True. **ray_args: keyword arguments to pass to ray.tune.run(). Only valid when use_ray=True. """ @@ -309,11 +316,19 @@ def run( old_verbose = _verbose old_running_trial = _running_trial old_training_iteration = _training_iteration + if local_dir and not log_file_name and verbose > 0: + os.makedirs(local_dir, exist_ok=True) + log_file_name = os.path.join( + local_dir, "tune_" + str(datetime.datetime.now()).replace(":", "-") + ".log" + ) if not use_ray: _verbose = verbose old_handlers = logger.handlers old_level = logger.getEffectiveLevel() logger.handlers = [] + global _runner + old_runner = _runner + assert not ray_args, "ray_args is only valid when use_ray=True" if ( old_handlers and isinstance(old_handlers[0], logging.StreamHandler) @@ -322,18 +337,8 @@ def run( # Add the console handler. logger.addHandler(old_handlers[0]) if verbose > 0: - if local_dir: - import os - - os.makedirs(local_dir, exist_ok=True) - logger.addHandler( - logging.FileHandler( - local_dir - + "/tune_" - + str(datetime.datetime.now()).replace(":", "-") - + ".log" - ) - ) + if log_file_name: + logger.addHandler(logging.FileHandler(log_file_name)) elif not logger.hasHandlers(): # Add the console handler. _ch = logging.StreamHandler() @@ -466,6 +471,10 @@ def run( resources_per_trial=resources_per_trial, **ray_args, ) + if log_file_name: + with open(log_file_name, "w") as f: + for trial in analysis.trials: + f.write(f"result: {trial.last_result}\n") return analysis finally: _use_ray = old_use_ray @@ -480,8 +489,6 @@ def run( scheduler.set_search_properties(metric=metric, mode=mode) from .trial_runner import SequentialTrialRunner - global _runner - old_runner = _runner try: _runner = SequentialTrialRunner( search_alg=search_alg, @@ -530,7 +537,7 @@ def run( _verbose = old_verbose _running_trial = old_running_trial _training_iteration = old_training_iteration - _runner = old_runner if not use_ray: + _runner = old_runner logger.handlers = old_handlers logger.setLevel(old_level) diff --git a/test/tune/test_constraints.py b/test/tune/test_constraints.py index 54043837a..0f6b18f75 100644 --- a/test/tune/test_constraints.py +++ b/test/tune/test_constraints.py @@ -22,6 +22,7 @@ def test_config_constraint(): metric="metric", mode="max", num_samples=100, + log_file_name="logs/config_constraint.log", ) assert analysis.best_config["x"] > analysis.best_config["y"] diff --git a/test/tune/test_searcher.py b/test/tune/test_searcher.py index d3002c3b2..c378eb706 100644 --- a/test/tune/test_searcher.py +++ b/test/tune/test_searcher.py @@ -295,7 +295,7 @@ def test_searcher(): print(searcher.suggest("t1")) from flaml import tune - tune.run(lambda x: 1, config={}, use_ray=use_ray) + tune.run(lambda x: 1, config={}, use_ray=use_ray, log_file_name="logs/searcher.log") def test_no_optuna(): diff --git a/test/tune/test_tune.py b/test/tune/test_tune.py index 5c00bf9b7..3ba47b9c8 100644 --- a/test/tune/test_tune.py +++ b/test/tune/test_tune.py @@ -47,6 +47,8 @@ def test_nested_run(): mode="min", num_samples=5, local_dir="logs", + log_file_name="logs/nested.log", + verbose=3, ) print(analysis.best_result)