mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
metric constraint (#90)
* penalty change * metric modification * catboost init
This commit is contained in:
@@ -922,6 +922,7 @@ class AutoML:
|
||||
# set up learner search space
|
||||
for estimator_name in estimator_list:
|
||||
estimator_class = self._state.learner_classes[estimator_name]
|
||||
estimator_class.init()
|
||||
self._search_states[estimator_name] = SearchState(
|
||||
learner_class=estimator_class,
|
||||
data_size=self._state.data_size, task=self._state.task,
|
||||
|
||||
@@ -163,6 +163,11 @@ class BaseEstimator:
|
||||
'''[optional method] relative cost compared to lightgbm'''
|
||||
return 1.0
|
||||
|
||||
@classmethod
def init(cls):
    '''[optional method] class-level initialization hook

    The base implementation does nothing and returns None; a subclass
    may override it to set up class-level state.
    '''
|
||||
|
||||
|
||||
class SKLearnEstimator(BaseEstimator):
|
||||
|
||||
@@ -632,6 +637,11 @@ class CatBoostEstimator(BaseEstimator):
|
||||
def cost_relative2lgbm(cls):
|
||||
return 15
|
||||
|
||||
@classmethod
def init(cls):
    '''reset the `_train_size` and `_time_per_iter` class attributes

    Both attributes are assigned on CatBoostEstimator itself (not on
    `cls`), so the reset values are shared through the base class.
    '''
    CatBoostEstimator._train_size = 0
    CatBoostEstimator._time_per_iter = None
|
||||
|
||||
def __init__(
|
||||
self, task='binary:logistic', n_jobs=1,
|
||||
n_estimators=8192, learning_rate=0.1, early_stopping_rounds=4, **params
|
||||
|
||||
@@ -27,6 +27,8 @@ class BlendSearch(Searcher):
|
||||
'''
|
||||
|
||||
cost_attr = "time_total_s" # cost attribute in result
|
||||
lagrange = '_lagrange' # suffix for lagrange-modified metric
|
||||
penalty = 1e+10 # penalty term for constraints
|
||||
|
||||
def __init__(self,
|
||||
metric: Optional[str] = None,
|
||||
@@ -106,6 +108,11 @@ class BlendSearch(Searcher):
|
||||
self._metric, self._mode = metric, mode
|
||||
init_config = low_cost_partial_config or {}
|
||||
self._points_to_evaluate = points_to_evaluate or []
|
||||
self._config_constraints = config_constraints
|
||||
self._metric_constraints = metric_constraints
|
||||
if self._metric_constraints:
|
||||
# metric modified by lagrange
|
||||
metric += self.lagrange
|
||||
if global_search_alg is not None:
|
||||
self._gs = global_search_alg
|
||||
elif getattr(self, '__name__', None) != 'CFO':
|
||||
@@ -115,8 +122,6 @@ class BlendSearch(Searcher):
|
||||
self._ls = LocalSearch(
|
||||
init_config, metric, mode, cat_hp_cost, space,
|
||||
prune_attr, min_resource, max_resource, reduction_factor, seed)
|
||||
self._config_constraints = config_constraints
|
||||
self._metric_constraints = metric_constraints
|
||||
self._init_search()
|
||||
|
||||
def set_search_properties(self,
|
||||
@@ -131,6 +136,11 @@ class BlendSearch(Searcher):
|
||||
else:
|
||||
if metric:
|
||||
self._metric = metric
|
||||
if self._metric_constraints:
|
||||
# metric modified by lagrange
|
||||
metric += self.lagrange
|
||||
# TODO: don't change metric for global search methods that
|
||||
# can handle constraints already
|
||||
if mode:
|
||||
self._mode = mode
|
||||
self._ls.set_search_properties(metric, mode, config)
|
||||
@@ -156,6 +166,13 @@ class BlendSearch(Searcher):
|
||||
self._gs_admissible_max = self._ls_bound_max.copy()
|
||||
self._result = {} # config_signature: tuple -> result: Dict
|
||||
self._deadline = np.inf
|
||||
if self._metric_constraints:
|
||||
self._metric_constraint_satisfied = False
|
||||
self._metric_constraint_penalty = [
|
||||
self.penalty for _ in self._metric_constraints]
|
||||
else:
|
||||
self._metric_constraint_satisfied = True
|
||||
self._metric_constraint_penalty = None
|
||||
|
||||
def save(self, checkpoint_path: str):
|
||||
save_object = self
|
||||
@@ -182,6 +199,8 @@ class BlendSearch(Searcher):
|
||||
self._ls = state._ls
|
||||
self._config_constraints = state._config_constraints
|
||||
self._metric_constraints = state._metric_constraints
|
||||
self._metric_constraint_satisfied = state._metric_constraint_satisfied
|
||||
self._metric_constraint_penalty = state._metric_constraint_penalty
|
||||
|
||||
def restore_from_dir(self, checkpoint_dir: str):
    '''restore the searcher from a checkpoint directory

    Delegates to the parent class's `restore_from_dir`.

    Args:
        checkpoint_dir: path of the checkpoint directory, passed through
            unchanged to the parent implementation.
    '''
    # `super` has to be *called* to obtain the parent-class proxy; the
    # bare builtin type has no `restore_from_dir` attribute and would
    # raise AttributeError here.
    super().restore_from_dir(checkpoint_dir)
|
||||
@@ -190,10 +209,11 @@ class BlendSearch(Searcher):
|
||||
error: bool = False):
|
||||
''' search thread updater and cleaner
|
||||
'''
|
||||
metric_constraint_satisfied = True
|
||||
if result and not error and self._metric_constraints:
|
||||
# accout for metric constraints if any
|
||||
# account for metric constraints if any
|
||||
objective = result[self._metric]
|
||||
for constraint in self._metric_constraints:
|
||||
for i, constraint in enumerate(self._metric_constraints):
|
||||
metric_constraint, sign, threshold = constraint
|
||||
value = result.get(metric_constraint)
|
||||
if value:
|
||||
@@ -202,8 +222,16 @@ class BlendSearch(Searcher):
|
||||
violation = (value - threshold) * sign_op
|
||||
if violation > 0:
|
||||
# add penalty term to the metric
|
||||
objective += 1e+10 * violation * self._ls.metric_op
|
||||
result[self._metric] = objective
|
||||
objective += self._metric_constraint_penalty[
|
||||
i] * violation * self._ls.metric_op
|
||||
metric_constraint_satisfied = False
|
||||
if self._metric_constraint_penalty[i] < self.penalty:
|
||||
self._metric_constraint_penalty[i] += violation
|
||||
result[self._metric + self.lagrange] = objective
|
||||
if metric_constraint_satisfied and not self._metric_constraint_satisfied:
|
||||
# found a feasible point
|
||||
self._metric_constraint_penalty = [1 for _ in self._metric_constraints]
|
||||
self._metric_constraint_satisfied |= metric_constraint_satisfied
|
||||
thread_id = self._trial_proposed_by.get(trial_id)
|
||||
if thread_id in self._search_thread_pool:
|
||||
self._search_thread_pool[thread_id].on_trial_complete(
|
||||
@@ -219,10 +247,13 @@ class BlendSearch(Searcher):
|
||||
else: # add to result cache
|
||||
self._result[self._ls.config_signature(config)] = result
|
||||
# update target metric if improved
|
||||
objective = result[self._metric]
|
||||
objective = result[
|
||||
self._metric + self.lagrange] if self._metric_constraints \
|
||||
else result[self._metric]
|
||||
if (objective - self._metric_target) * self._ls.metric_op < 0:
|
||||
self._metric_target = objective
|
||||
if not thread_id and self._create_condition(result):
|
||||
if not thread_id and metric_constraint_satisfied \
|
||||
and self._create_condition(result):
|
||||
# thread creator
|
||||
self._search_thread_pool[self._thread_count] = SearchThread(
|
||||
self._ls.mode,
|
||||
@@ -233,6 +264,9 @@ class BlendSearch(Searcher):
|
||||
self._thread_count += 1
|
||||
self._update_admissible_region(
|
||||
config, self._ls_bound_min, self._ls_bound_max)
|
||||
elif thread_id and not self._metric_constraint_satisfied:
|
||||
# no point has been found to satisfy metric constraint
|
||||
self._expand_admissible_region()
|
||||
# reset admissible region to ls bounding box
|
||||
self._gs_admissible_min.update(self._ls_bound_min)
|
||||
self._gs_admissible_max.update(self._ls_bound_max)
|
||||
@@ -306,6 +340,8 @@ class BlendSearch(Searcher):
|
||||
thread_id = self._trial_proposed_by[trial_id]
|
||||
if thread_id not in self._search_thread_pool:
|
||||
return
|
||||
if result and self._metric_constraints:
|
||||
result[self._metric + self.lagrange] = result[self._metric]
|
||||
self._search_thread_pool[thread_id].on_trial_result(trial_id, result)
|
||||
|
||||
def suggest(self, trial_id: str) -> Optional[Dict]:
|
||||
|
||||
Reference in New Issue
Block a user