metric constraint (#90)

* penalty change

* metric modification

* catboost init
This commit is contained in:
Chi Wang
2021-05-22 08:51:38 -07:00
committed by GitHub
parent 0925e2b308
commit b206363c9a
3 changed files with 55 additions and 8 deletions

View File

@@ -922,6 +922,7 @@ class AutoML:
# set up learner search space
for estimator_name in estimator_list:
estimator_class = self._state.learner_classes[estimator_name]
estimator_class.init()
self._search_states[estimator_name] = SearchState(
learner_class=estimator_class,
data_size=self._state.data_size, task=self._state.task,

View File

@@ -163,6 +163,11 @@ class BaseEstimator:
'''[optional method] relative cost compared to lightgbm'''
return 1.0
@classmethod
def init(cls):
'''[optional method] initialize the class'''
pass
class SKLearnEstimator(BaseEstimator):
@@ -632,6 +637,11 @@ class CatBoostEstimator(BaseEstimator):
def cost_relative2lgbm(cls):
return 15
@classmethod
def init(cls):
CatBoostEstimator._time_per_iter = None
CatBoostEstimator._train_size = 0
def __init__(
self, task='binary:logistic', n_jobs=1,
n_estimators=8192, learning_rate=0.1, early_stopping_rounds=4, **params

View File

@@ -27,6 +27,8 @@ class BlendSearch(Searcher):
'''
cost_attr = "time_total_s" # cost attribute in result
lagrange = '_lagrange' # suffix for lagrange-modified metric
penalty = 1e+10 # penalty term for constraints
def __init__(self,
metric: Optional[str] = None,
@@ -106,6 +108,11 @@ class BlendSearch(Searcher):
self._metric, self._mode = metric, mode
init_config = low_cost_partial_config or {}
self._points_to_evaluate = points_to_evaluate or []
self._config_constraints = config_constraints
self._metric_constraints = metric_constraints
if self._metric_constraints:
# metric modified by lagrange
metric += self.lagrange
if global_search_alg is not None:
self._gs = global_search_alg
elif getattr(self, '__name__', None) != 'CFO':
@@ -115,8 +122,6 @@ class BlendSearch(Searcher):
self._ls = LocalSearch(
init_config, metric, mode, cat_hp_cost, space,
prune_attr, min_resource, max_resource, reduction_factor, seed)
self._config_constraints = config_constraints
self._metric_constraints = metric_constraints
self._init_search()
def set_search_properties(self,
@@ -131,6 +136,11 @@ class BlendSearch(Searcher):
else:
if metric:
self._metric = metric
if self._metric_constraints:
# metric modified by lagrange
metric += self.lagrange
# TODO: don't change metric for global search methods that
# can handle constraints already
if mode:
self._mode = mode
self._ls.set_search_properties(metric, mode, config)
@@ -156,6 +166,13 @@ class BlendSearch(Searcher):
self._gs_admissible_max = self._ls_bound_max.copy()
self._result = {} # config_signature: tuple -> result: Dict
self._deadline = np.inf
if self._metric_constraints:
self._metric_constraint_satisfied = False
self._metric_constraint_penalty = [
self.penalty for _ in self._metric_constraints]
else:
self._metric_constraint_satisfied = True
self._metric_constraint_penalty = None
def save(self, checkpoint_path: str):
save_object = self
@@ -182,6 +199,8 @@ class BlendSearch(Searcher):
self._ls = state._ls
self._config_constraints = state._config_constraints
self._metric_constraints = state._metric_constraints
self._metric_constraint_satisfied = state._metric_constraint_satisfied
self._metric_constraint_penalty = state._metric_constraint_penalty
def restore_from_dir(self, checkpoint_dir: str):
super.restore_from_dir(checkpoint_dir)
@@ -190,10 +209,11 @@ class BlendSearch(Searcher):
error: bool = False):
''' search thread updater and cleaner
'''
metric_constraint_satisfied = True
if result and not error and self._metric_constraints:
# accout for metric constraints if any
# account for metric constraints if any
objective = result[self._metric]
for constraint in self._metric_constraints:
for i, constraint in enumerate(self._metric_constraints):
metric_constraint, sign, threshold = constraint
value = result.get(metric_constraint)
if value:
@@ -202,8 +222,16 @@ class BlendSearch(Searcher):
violation = (value - threshold) * sign_op
if violation > 0:
# add penalty term to the metric
objective += 1e+10 * violation * self._ls.metric_op
result[self._metric] = objective
objective += self._metric_constraint_penalty[
i] * violation * self._ls.metric_op
metric_constraint_satisfied = False
if self._metric_constraint_penalty[i] < self.penalty:
self._metric_constraint_penalty[i] += violation
result[self._metric + self.lagrange] = objective
if metric_constraint_satisfied and not self._metric_constraint_satisfied:
# found a feasible point
self._metric_constraint_penalty = [1 for _ in self._metric_constraints]
self._metric_constraint_satisfied |= metric_constraint_satisfied
thread_id = self._trial_proposed_by.get(trial_id)
if thread_id in self._search_thread_pool:
self._search_thread_pool[thread_id].on_trial_complete(
@@ -219,10 +247,13 @@ class BlendSearch(Searcher):
else: # add to result cache
self._result[self._ls.config_signature(config)] = result
# update target metric if improved
objective = result[self._metric]
objective = result[
self._metric + self.lagrange] if self._metric_constraints \
else result[self._metric]
if (objective - self._metric_target) * self._ls.metric_op < 0:
self._metric_target = objective
if not thread_id and self._create_condition(result):
if not thread_id and metric_constraint_satisfied \
and self._create_condition(result):
# thread creator
self._search_thread_pool[self._thread_count] = SearchThread(
self._ls.mode,
@@ -233,6 +264,9 @@ class BlendSearch(Searcher):
self._thread_count += 1
self._update_admissible_region(
config, self._ls_bound_min, self._ls_bound_max)
elif thread_id and not self._metric_constraint_satisfied:
# no point has been found to satisfy metric constraint
self._expand_admissible_region()
# reset admissible region to ls bounding box
self._gs_admissible_min.update(self._ls_bound_min)
self._gs_admissible_max.update(self._ls_bound_max)
@@ -306,6 +340,8 @@ class BlendSearch(Searcher):
thread_id = self._trial_proposed_by[trial_id]
if thread_id not in self._search_thread_pool:
return
if result and self._metric_constraints:
result[self._metric + self.lagrange] = result[self._metric]
self._search_thread_pool[thread_id].on_trial_result(trial_id, result)
def suggest(self, trial_id: str) -> Optional[Dict]: