mirror of
https://github.com/microsoft/autogen.git
synced 2026-02-15 22:55:03 -05:00
handle non-flaml scheduler in flaml.tune (#532)
* handle non-flaml scheduler in flaml.tune * revise time budget * Update website/docs/Use-Cases/Tune-User-Defined-Function.md Co-authored-by: Chi Wang <wang.chi@microsoft.com> * Update website/docs/Use-Cases/Tune-User-Defined-Function.md Co-authored-by: Chi Wang <wang.chi@microsoft.com> * Update flaml/tune/tune.py Co-authored-by: Chi Wang <wang.chi@microsoft.com> * add docstr * remove random seed * StopIteration * StopIteration format * format * Update flaml/tune/tune.py Co-authored-by: Chi Wang <wang.chi@microsoft.com> * revise docstr Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
107
test/tune/example_scheduler.py
Normal file
107
test/tune/example_scheduler.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from functools import partial
|
||||
import time
|
||||
|
||||
|
||||
def evaluation_fn(step, width, height):
|
||||
return (0.1 + width * step / 100) ** (-1) + height * 0.1
|
||||
|
||||
|
||||
def easy_objective(use_raytune, config):
|
||||
if use_raytune:
|
||||
from ray import tune
|
||||
else:
|
||||
from flaml import tune
|
||||
# Hyperparameters
|
||||
width, height = config["width"], config["height"]
|
||||
|
||||
for step in range(config["steps"]):
|
||||
# Iterative training function - can be any arbitrary training procedure
|
||||
intermediate_score = evaluation_fn(step, width, height)
|
||||
# Feed the score back back to Tune.
|
||||
try:
|
||||
tune.report(iterations=step, mean_loss=intermediate_score)
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
|
||||
def test_tune_scheduler(smoke_test=True, use_ray=True, use_raytune=False):
|
||||
import numpy as np
|
||||
from flaml.searcher.blendsearch import BlendSearch
|
||||
|
||||
np.random.seed(100)
|
||||
easy_objective_custom_tune = partial(easy_objective, use_raytune)
|
||||
if use_raytune:
|
||||
try:
|
||||
from ray import tune
|
||||
except ImportError:
|
||||
print("ray[tune] is not installed, skipping test")
|
||||
return
|
||||
searcher = BlendSearch(
|
||||
space={
|
||||
"steps": 100,
|
||||
"width": tune.uniform(0, 20),
|
||||
"height": tune.uniform(-100, 100),
|
||||
# This is an ignored parameter.
|
||||
"activation": tune.choice(["relu", "tanh"]),
|
||||
"test4": np.zeros((3, 1)),
|
||||
}
|
||||
)
|
||||
analysis = tune.run(
|
||||
easy_objective_custom_tune,
|
||||
search_alg=searcher,
|
||||
metric="mean_loss",
|
||||
mode="min",
|
||||
num_samples=10 if smoke_test else 100,
|
||||
scheduler="asynchyperband",
|
||||
config={
|
||||
"steps": 100,
|
||||
"width": tune.uniform(0, 20),
|
||||
"height": tune.uniform(-100, 100),
|
||||
# This is an ignored parameter.
|
||||
"activation": tune.choice(["relu", "tanh"]),
|
||||
"test4": np.zeros((3, 1)),
|
||||
},
|
||||
)
|
||||
else:
|
||||
from flaml import tune
|
||||
|
||||
searcher = BlendSearch(
|
||||
space={
|
||||
"steps": 100,
|
||||
"width": tune.uniform(0, 20),
|
||||
"height": tune.uniform(-100, 100),
|
||||
# This is an ignored parameter.
|
||||
"activation": tune.choice(["relu", "tanh"]),
|
||||
"test4": np.zeros((3, 1)),
|
||||
}
|
||||
)
|
||||
analysis = tune.run(
|
||||
easy_objective_custom_tune,
|
||||
search_alg=searcher,
|
||||
metric="mean_loss",
|
||||
mode="min",
|
||||
num_samples=10 if smoke_test else 100,
|
||||
scheduler="asynchyperband",
|
||||
resource_attr="iterations",
|
||||
max_resource=99,
|
||||
# min_resource=1,
|
||||
# reduction_factor=4,
|
||||
config={
|
||||
"steps": 100,
|
||||
"width": tune.uniform(0, 20),
|
||||
"height": tune.uniform(-100, 100),
|
||||
# This is an ignored parameter.
|
||||
"activation": tune.choice(["relu", "tanh"]),
|
||||
"test4": np.zeros((3, 1)),
|
||||
},
|
||||
use_ray=use_ray,
|
||||
)
|
||||
|
||||
print("Best hyperparameters found were: ", analysis.best_config)
|
||||
print("best results", analysis.best_result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_tune_scheduler(smoke_test=True, use_ray=True, use_raytune=True)
|
||||
test_tune_scheduler(smoke_test=True, use_ray=True)
|
||||
test_tune_scheduler(smoke_test=True, use_ray=False)
|
||||
@@ -58,7 +58,6 @@ def _test_flaml_raytune_consistency(
|
||||
"skip _test_flaml_raytune_consistency because ray tune cannot be imported."
|
||||
)
|
||||
return
|
||||
np.random.seed(100)
|
||||
searcher = setup_searcher(searcher_name)
|
||||
analysis = tune.run(
|
||||
evaluate_config, # the function to evaluate a config
|
||||
@@ -78,7 +77,6 @@ def _test_flaml_raytune_consistency(
|
||||
flaml_time_in_results = [v["time_total_s"] for v in analysis.results.values()]
|
||||
print(analysis.best_trial.last_result) # the best trial's result
|
||||
|
||||
np.random.seed(100)
|
||||
searcher = setup_searcher(searcher_name)
|
||||
from ray.tune.suggest import ConcurrencyLimiter
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ def rand_vector_unit_sphere(dim):
|
||||
return vec / mag
|
||||
|
||||
|
||||
def simple_obj(config, resource=10000):
|
||||
def simple_obj(resource, config):
|
||||
config_value_vector = np.array([config["x"], config["y"], config["z"]])
|
||||
score_sequence = []
|
||||
for i in range(resource):
|
||||
@@ -41,23 +41,29 @@ def obj_w_intermediate_report(resource, config):
|
||||
score_avg = np.mean(np.array(score_sequence))
|
||||
score_std = np.std(np.array(score_sequence))
|
||||
score_lb = score_avg - 1.96 * score_std / np.sqrt(i + 1)
|
||||
tune.report(samplesize=i + 1, sphere_projection=score_lb)
|
||||
try:
|
||||
tune.report(samplesize=i + 1, sphere_projection=score_lb)
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
|
||||
def obj_w_suggested_resource(resource_attr, config):
|
||||
resource = config[resource_attr]
|
||||
simple_obj(config, resource)
|
||||
simple_obj(resource, config)
|
||||
|
||||
|
||||
def test_scheduler(scheduler=None):
|
||||
def test_scheduler(scheduler=None, use_ray=False, time_budget_s=1):
|
||||
from functools import partial
|
||||
|
||||
resource_attr = "samplesize"
|
||||
max_resource = 10000
|
||||
|
||||
min_resource = 1000
|
||||
reduction_factor = 2
|
||||
time_budget_s = time_budget_s
|
||||
# specify the objective functions
|
||||
if scheduler is None:
|
||||
evaluation_obj = simple_obj
|
||||
evaluation_obj = partial(simple_obj, max_resource)
|
||||
min_resource = max_resource = reduction_factor = None
|
||||
elif scheduler == "flaml":
|
||||
evaluation_obj = partial(obj_w_suggested_resource, resource_attr)
|
||||
elif scheduler == "asha" or isinstance(scheduler, TrialScheduler):
|
||||
@@ -89,14 +95,17 @@ def test_scheduler(scheduler=None):
|
||||
resource_attr=resource_attr,
|
||||
scheduler=scheduler,
|
||||
max_resource=max_resource,
|
||||
min_resource=100,
|
||||
reduction_factor=2,
|
||||
time_budget_s=1,
|
||||
min_resource=min_resource,
|
||||
reduction_factor=reduction_factor,
|
||||
time_budget_s=time_budget_s,
|
||||
num_samples=500,
|
||||
use_ray=use_ray,
|
||||
)
|
||||
|
||||
print("Best hyperparameters found were: ", analysis.best_config)
|
||||
# print(analysis.get_best_trial)
|
||||
print(
|
||||
f"{len(analysis.results)} trials finished \
|
||||
in {time_budget_s} seconds with {str(scheduler)} scheduler"
|
||||
)
|
||||
return analysis.best_config
|
||||
|
||||
|
||||
@@ -105,13 +114,15 @@ def test_no_scheduler():
|
||||
print("No scheduler, test error:", abs(10 / 2 - best_config["z"] / 2))
|
||||
|
||||
|
||||
def test_asha_scheduler():
|
||||
def test_asha_scheduler(use_ray=False, time_budget_s=1):
|
||||
try:
|
||||
from ray.tune.schedulers import ASHAScheduler
|
||||
except ImportError:
|
||||
print("skip the test as ray tune cannot be imported.")
|
||||
return
|
||||
best_config = test_scheduler(scheduler="asha")
|
||||
best_config = test_scheduler(
|
||||
scheduler="asha", use_ray=use_ray, time_budget_s=time_budget_s
|
||||
)
|
||||
print("Auto ASHA scheduler, test error:", abs(10 / 2 - best_config["z"] / 2))
|
||||
|
||||
|
||||
@@ -150,6 +161,7 @@ def test_flaml_scheduler():
|
||||
if __name__ == "__main__":
|
||||
test_no_scheduler()
|
||||
test_asha_scheduler()
|
||||
test_asha_scheduler(use_ray=True, time_budget_s=3)
|
||||
test_custom_scheduler()
|
||||
test_custom_scheduler_default_time_attr()
|
||||
test_flaml_scheduler()
|
||||
|
||||
Reference in New Issue
Block a user