mirror of
https://github.com/microsoft/autogen.git
synced 2026-02-04 05:05:09 -05:00
consider num_samples in bs thread priority (#207)
* consider num_samples in bs thread priority * continue search for bs
This commit is contained in:
@@ -6,7 +6,6 @@ import numpy as np
|
||||
from flaml.searcher.suggestion import ConcurrencyLimiter
|
||||
from flaml import tune
|
||||
from flaml import CFO
|
||||
from flaml import BlendSearch
|
||||
|
||||
|
||||
class AbstractWarmStartTest:
|
||||
@@ -27,28 +26,24 @@ class AbstractWarmStartTest:
|
||||
search_alg, cost = self.set_basic_conf()
|
||||
search_alg = ConcurrencyLimiter(search_alg, 1)
|
||||
results_exp_1 = tune.run(
|
||||
cost,
|
||||
num_samples=5,
|
||||
search_alg=search_alg,
|
||||
verbose=0,
|
||||
local_dir=self.tmpdir)
|
||||
cost, num_samples=5, search_alg=search_alg, verbose=0, local_dir=self.tmpdir
|
||||
)
|
||||
checkpoint_path = os.path.join(self.tmpdir, self.experiment_name)
|
||||
search_alg.save(checkpoint_path)
|
||||
return results_exp_1, np.random.get_state(), checkpoint_path
|
||||
|
||||
def run_explicit_restore(self, random_state, checkpoint_path):
|
||||
np.random.set_state(random_state)
|
||||
search_alg2, cost = self.set_basic_conf()
|
||||
search_alg2 = ConcurrencyLimiter(search_alg2, 1)
|
||||
search_alg2.restore(checkpoint_path)
|
||||
np.random.set_state(random_state)
|
||||
return tune.run(cost, num_samples=5, search_alg=search_alg2, verbose=0)
|
||||
|
||||
def run_full(self):
|
||||
np.random.seed(162)
|
||||
search_alg3, cost = self.set_basic_conf()
|
||||
search_alg3 = ConcurrencyLimiter(search_alg3, 1)
|
||||
return tune.run(
|
||||
cost, num_samples=10, search_alg=search_alg3, verbose=0)
|
||||
return tune.run(cost, num_samples=10, search_alg=search_alg3, verbose=0)
|
||||
|
||||
def testReproduce(self):
|
||||
results_exp_1, _, _ = self.run_part_from_scratch()
|
||||
@@ -75,7 +70,7 @@ class CFOWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
||||
}
|
||||
|
||||
def cost(param):
|
||||
tune.report(loss=(param["height"] - 14)**2 - abs(param["width"] - 3))
|
||||
tune.report(loss=(param["height"] - 14) ** 2 - abs(param["width"] - 3))
|
||||
|
||||
search_alg = CFO(
|
||||
space=space,
|
||||
@@ -86,6 +81,7 @@ class CFOWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
||||
|
||||
return search_alg, cost
|
||||
|
||||
|
||||
# # # Not doing test for BS because of problems with random seed in OptunaSearch
|
||||
# class BlendsearchWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
||||
# def set_basic_conf(self):
|
||||
|
||||
Reference in New Issue
Block a user