mirror of
https://github.com/microsoft/autogen.git
synced 2026-01-23 04:08:04 -05:00
discount running thread (#121)
* discount running thread * version * limit dir * report result * catch * remove handler
This commit is contained in:
@@ -372,7 +372,6 @@ class BlendSearch(Searcher):
|
||||
choice, backup = self._select_thread()
|
||||
if choice < 0: # timeout
|
||||
return None
|
||||
self._use_rs = False
|
||||
config = self._search_thread_pool[choice].suggest(trial_id)
|
||||
if choice and config is None:
|
||||
# local search thread finishes
|
||||
@@ -386,18 +385,16 @@ class BlendSearch(Searcher):
|
||||
if choice:
|
||||
return None
|
||||
# use rs when BO fails to suggest a config
|
||||
self._use_rs = True
|
||||
for _, generated in generate_variants({'config': self._ls.space}):
|
||||
config = generated['config']
|
||||
break # get one random config
|
||||
skip = self._should_skip(choice, trial_id, config)
|
||||
skip = self._should_skip(-1, trial_id, config)
|
||||
if skip:
|
||||
return None
|
||||
if choice or self._valid(config):
|
||||
# LS or valid or no backup choice
|
||||
self._trial_proposed_by[trial_id] = choice
|
||||
else: # invalid config proposed by GS
|
||||
self._use_rs = False
|
||||
if choice == backup:
|
||||
# use CFO's init point
|
||||
init_config = self._ls.init_config
|
||||
@@ -439,6 +436,7 @@ class BlendSearch(Searcher):
|
||||
return None
|
||||
self._init_used = True
|
||||
self._trial_proposed_by[trial_id] = 0
|
||||
self._search_thread_pool[0].running += 1
|
||||
return config
|
||||
|
||||
def _should_skip(self, choice, trial_id, config) -> bool:
|
||||
@@ -462,16 +460,16 @@ class BlendSearch(Searcher):
|
||||
}
|
||||
exists = True
|
||||
break
|
||||
if exists:
|
||||
if not self._use_rs:
|
||||
if exists: # suggested before
|
||||
if choice >= 0: # not fallback to rs
|
||||
result = self._result.get(config_signature)
|
||||
if result:
|
||||
if result: # finished
|
||||
self._search_thread_pool[choice].on_trial_complete(
|
||||
trial_id, result, error=False)
|
||||
if choice:
|
||||
# local search thread
|
||||
self._clean(choice)
|
||||
# else:
|
||||
# else: # running
|
||||
# # tell the thread there is an error
|
||||
# self._search_thread_pool[choice].on_trial_complete(
|
||||
# trial_id, {}, error=True)
|
||||
|
||||
@@ -181,7 +181,7 @@ class FLOW2(Searcher):
|
||||
if self.step > self.step_ub:
|
||||
self.step = self.step_ub
|
||||
# maximal # consecutive no improvements
|
||||
self.dir = 2**(self.dim)
|
||||
self.dir = 2**(min(9, self.dim))
|
||||
self._configs = {} # dict from trial_id to (config, stepsize)
|
||||
self._K = 0
|
||||
self._iter_best_config = self.trial_count_proposed = self.trial_count_complete = 1
|
||||
|
||||
@@ -39,6 +39,7 @@ class SearchThread:
|
||||
self.eci = self.cost_best
|
||||
self.priority = self.speed = 0
|
||||
self._init_config = True
|
||||
self.running = 0 # the number of running trials from the thread
|
||||
|
||||
@classmethod
|
||||
def set_eps(cls, time_budget_s):
|
||||
@@ -57,6 +58,8 @@ class SearchThread:
|
||||
'The global search method raises FloatingPointError. '
|
||||
'Ignoring for this iteration.')
|
||||
config = None
|
||||
if config is not None:
|
||||
self.running += 1
|
||||
return config
|
||||
|
||||
def update_priority(self, eci: Optional[float] = 0):
|
||||
@@ -77,7 +80,8 @@ class SearchThread:
|
||||
def _update_speed(self):
|
||||
# calculate speed; use 0 for invalid speed temporarily
|
||||
if self.obj_best2 > self.obj_best1:
|
||||
self.speed = (self.obj_best2 - self.obj_best1) / (
|
||||
# discount the speed if there are unfinished trials
|
||||
self.speed = (self.obj_best2 - self.obj_best1) / self.running / (
|
||||
max(self.cost_total - self.cost_best2, SearchThread._eps))
|
||||
else:
|
||||
self.speed = 0
|
||||
@@ -92,7 +96,13 @@ class SearchThread:
|
||||
not error and trial_id in self._search_alg._ot_trials):
|
||||
# optuna doesn't handle error
|
||||
if self._is_ls or not self._init_config:
|
||||
self._search_alg.on_trial_complete(trial_id, result, error)
|
||||
try:
|
||||
self._search_alg.on_trial_complete(trial_id, result, error)
|
||||
except RuntimeError as e:
|
||||
# rs is used in place of optuna sometimes
|
||||
if not str(e).endswith(
|
||||
"has already finished and can not be updated."):
|
||||
raise e
|
||||
else:
|
||||
# init config is not proposed by self._search_alg
|
||||
# under this thread
|
||||
@@ -111,6 +121,8 @@ class SearchThread:
|
||||
self.obj_best1 = obj
|
||||
self.cost_best = self.cost_last
|
||||
self._update_speed()
|
||||
self.running -= 1
|
||||
assert self.running >= 0
|
||||
|
||||
def on_trial_result(self, trial_id: str, result: Dict):
|
||||
''' TODO update the statistics of the thread with partial result?
|
||||
@@ -118,8 +130,14 @@ class SearchThread:
|
||||
if not self._search_alg:
|
||||
return
|
||||
if not hasattr(self._search_alg, '_ot_trials') or (
|
||||
trial_id in self._search_alg._ot_trials):
|
||||
self._search_alg.on_trial_result(trial_id, result)
|
||||
trial_id in self._search_alg._ot_trials):
|
||||
try:
|
||||
self._search_alg.on_trial_result(trial_id, result)
|
||||
except RuntimeError as e:
|
||||
# rs is used in place of optuna sometimes
|
||||
if not str(e).endswith(
|
||||
"has already finished and can not be updated."):
|
||||
raise e
|
||||
if self.cost_attr in result and self.cost_last < result[self.cost_attr]:
|
||||
self.cost_last = result[self.cost_attr]
|
||||
# self._update_speed()
|
||||
|
||||
@@ -31,7 +31,7 @@ class ExperimentAnalysis(EA):
|
||||
super().__init__(self, None, trials, metric, mode)
|
||||
except (TypeError, ValueError):
|
||||
self.trials = trials
|
||||
self.default_metric = metric
|
||||
self.default_metric = metric or '_default_anonymous_metric'
|
||||
self.default_mode = mode
|
||||
|
||||
|
||||
@@ -257,7 +257,8 @@ def run(training_function,
|
||||
if search_alg is None:
|
||||
from ..searcher.blendsearch import BlendSearch
|
||||
search_alg = BlendSearch(
|
||||
metric=metric, mode=mode, space=config,
|
||||
metric=metric or '_default_anonymous_metric', mode=mode,
|
||||
space=config,
|
||||
points_to_evaluate=points_to_evaluate,
|
||||
low_cost_partial_config=low_cost_partial_config,
|
||||
cat_hp_cost=cat_hp_cost,
|
||||
@@ -325,6 +326,13 @@ def run(training_function,
|
||||
num_trials += 1
|
||||
if verbose:
|
||||
logger.info(f'trial {num_trials} config: {trial_to_run.config}')
|
||||
training_function(trial_to_run.config)
|
||||
result = training_function(trial_to_run.config)
|
||||
if result is not None:
|
||||
if isinstance(result, dict):
|
||||
tune.report(**result)
|
||||
else:
|
||||
tune.report(_metric=result)
|
||||
_runner.stop_trial(trial_to_run)
|
||||
if verbose > 0:
|
||||
logger.handlers.clear()
|
||||
return ExperimentAnalysis(_runner.get_trials(), metric=metric, mode=mode)
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "0.5.4"
|
||||
__version__ = "0.5.5"
|
||||
|
||||
Reference in New Issue
Block a user