discount running thread (#121)

* discount running thread

* version

* limit dir

* report result

* catch

* remove handler
This commit is contained in:
Chi Wang
2021-06-25 14:24:46 -07:00
committed by GitHub
parent da9fc51e62
commit 2dbf38da0a
5 changed files with 41 additions and 17 deletions

View File

@@ -372,7 +372,6 @@ class BlendSearch(Searcher):
choice, backup = self._select_thread()
if choice < 0: # timeout
return None
self._use_rs = False
config = self._search_thread_pool[choice].suggest(trial_id)
if choice and config is None:
# local search thread finishes
@@ -386,18 +385,16 @@ class BlendSearch(Searcher):
if choice:
return None
# use rs when BO fails to suggest a config
self._use_rs = True
for _, generated in generate_variants({'config': self._ls.space}):
config = generated['config']
break # get one random config
skip = self._should_skip(choice, trial_id, config)
skip = self._should_skip(-1, trial_id, config)
if skip:
return None
if choice or self._valid(config):
# LS or valid or no backup choice
self._trial_proposed_by[trial_id] = choice
else: # invalid config proposed by GS
self._use_rs = False
if choice == backup:
# use CFO's init point
init_config = self._ls.init_config
@@ -439,6 +436,7 @@ class BlendSearch(Searcher):
return None
self._init_used = True
self._trial_proposed_by[trial_id] = 0
self._search_thread_pool[0].running += 1
return config
def _should_skip(self, choice, trial_id, config) -> bool:
@@ -462,16 +460,16 @@ class BlendSearch(Searcher):
}
exists = True
break
if exists:
if not self._use_rs:
if exists: # suggested before
if choice >= 0: # not fallback to rs
result = self._result.get(config_signature)
if result:
if result: # finished
self._search_thread_pool[choice].on_trial_complete(
trial_id, result, error=False)
if choice:
# local search thread
self._clean(choice)
# else:
# else: # running
# # tell the thread there is an error
# self._search_thread_pool[choice].on_trial_complete(
# trial_id, {}, error=True)

View File

@@ -181,7 +181,7 @@ class FLOW2(Searcher):
if self.step > self.step_ub:
self.step = self.step_ub
# maximal # consecutive no improvements
self.dir = 2**(self.dim)
self.dir = 2**(min(9, self.dim))
self._configs = {} # dict from trial_id to (config, stepsize)
self._K = 0
self._iter_best_config = self.trial_count_proposed = self.trial_count_complete = 1

View File

@@ -39,6 +39,7 @@ class SearchThread:
self.eci = self.cost_best
self.priority = self.speed = 0
self._init_config = True
self.running = 0 # the number of running trials from the thread
@classmethod
def set_eps(cls, time_budget_s):
@@ -57,6 +58,8 @@ class SearchThread:
'The global search method raises FloatingPointError. '
'Ignoring for this iteration.')
config = None
if config is not None:
self.running += 1
return config
def update_priority(self, eci: Optional[float] = 0):
@@ -77,7 +80,8 @@ class SearchThread:
def _update_speed(self):
# calculate speed; use 0 for invalid speed temporarily
if self.obj_best2 > self.obj_best1:
self.speed = (self.obj_best2 - self.obj_best1) / (
# discount the speed if there are unfinished trials
self.speed = (self.obj_best2 - self.obj_best1) / self.running / (
max(self.cost_total - self.cost_best2, SearchThread._eps))
else:
self.speed = 0
@@ -92,7 +96,13 @@ class SearchThread:
not error and trial_id in self._search_alg._ot_trials):
# optuna doesn't handle error
if self._is_ls or not self._init_config:
self._search_alg.on_trial_complete(trial_id, result, error)
try:
self._search_alg.on_trial_complete(trial_id, result, error)
except RuntimeError as e:
# rs is used in place of optuna sometimes
if not str(e).endswith(
"has already finished and can not be updated."):
raise e
else:
# init config is not proposed by self._search_alg
# under this thread
@@ -111,6 +121,8 @@ class SearchThread:
self.obj_best1 = obj
self.cost_best = self.cost_last
self._update_speed()
self.running -= 1
assert self.running >= 0
def on_trial_result(self, trial_id: str, result: Dict):
''' TODO update the statistics of the thread with partial result?
@@ -118,8 +130,14 @@ class SearchThread:
if not self._search_alg:
return
if not hasattr(self._search_alg, '_ot_trials') or (
trial_id in self._search_alg._ot_trials):
self._search_alg.on_trial_result(trial_id, result)
trial_id in self._search_alg._ot_trials):
try:
self._search_alg.on_trial_result(trial_id, result)
except RuntimeError as e:
# rs is used in place of optuna sometimes
if not str(e).endswith(
"has already finished and can not be updated."):
raise e
if self.cost_attr in result and self.cost_last < result[self.cost_attr]:
self.cost_last = result[self.cost_attr]
# self._update_speed()

View File

@@ -31,7 +31,7 @@ class ExperimentAnalysis(EA):
super().__init__(self, None, trials, metric, mode)
except (TypeError, ValueError):
self.trials = trials
self.default_metric = metric
self.default_metric = metric or '_default_anonymous_metric'
self.default_mode = mode
@@ -257,7 +257,8 @@ def run(training_function,
if search_alg is None:
from ..searcher.blendsearch import BlendSearch
search_alg = BlendSearch(
metric=metric, mode=mode, space=config,
metric=metric or '_default_anonymous_metric', mode=mode,
space=config,
points_to_evaluate=points_to_evaluate,
low_cost_partial_config=low_cost_partial_config,
cat_hp_cost=cat_hp_cost,
@@ -325,6 +326,13 @@ def run(training_function,
num_trials += 1
if verbose:
logger.info(f'trial {num_trials} config: {trial_to_run.config}')
training_function(trial_to_run.config)
result = training_function(trial_to_run.config)
if result is not None:
if isinstance(result, dict):
tune.report(**result)
else:
tune.report(_metric=result)
_runner.stop_trial(trial_to_run)
if verbose > 0:
logger.handlers.clear()
return ExperimentAnalysis(_runner.get_trials(), metric=metric, mode=mode)

View File

@@ -1 +1 @@
__version__ = "0.5.4"
__version__ = "0.5.5"