consider num_samples in bs thread priority (#207)

* consider num_samples in bs thread priority

* continue search for bs
This commit is contained in:
Chi Wang
2021-09-14 18:36:10 -07:00
committed by GitHub
parent ea6c6ded2f
commit a9d39b71da
3 changed files with 376 additions and 231 deletions

View File

@@ -6,7 +6,6 @@ import numpy as np
from flaml.searcher.suggestion import ConcurrencyLimiter
from flaml import tune
from flaml import CFO
from flaml import BlendSearch
class AbstractWarmStartTest:
@@ -27,28 +26,24 @@ class AbstractWarmStartTest:
search_alg, cost = self.set_basic_conf()
search_alg = ConcurrencyLimiter(search_alg, 1)
results_exp_1 = tune.run(
cost,
num_samples=5,
search_alg=search_alg,
verbose=0,
local_dir=self.tmpdir)
cost, num_samples=5, search_alg=search_alg, verbose=0, local_dir=self.tmpdir
)
checkpoint_path = os.path.join(self.tmpdir, self.experiment_name)
search_alg.save(checkpoint_path)
return results_exp_1, np.random.get_state(), checkpoint_path
def run_explicit_restore(self, random_state, checkpoint_path):
np.random.set_state(random_state)
search_alg2, cost = self.set_basic_conf()
search_alg2 = ConcurrencyLimiter(search_alg2, 1)
search_alg2.restore(checkpoint_path)
np.random.set_state(random_state)
return tune.run(cost, num_samples=5, search_alg=search_alg2, verbose=0)
def run_full(self):
np.random.seed(162)
search_alg3, cost = self.set_basic_conf()
search_alg3 = ConcurrencyLimiter(search_alg3, 1)
return tune.run(
cost, num_samples=10, search_alg=search_alg3, verbose=0)
return tune.run(cost, num_samples=10, search_alg=search_alg3, verbose=0)
def testReproduce(self):
results_exp_1, _, _ = self.run_part_from_scratch()
@@ -75,7 +70,7 @@ class CFOWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
}
def cost(param):
tune.report(loss=(param["height"] - 14)**2 - abs(param["width"] - 3))
tune.report(loss=(param["height"] - 14) ** 2 - abs(param["width"] - 3))
search_alg = CFO(
space=space,
@@ -86,6 +81,7 @@ class CFOWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
return search_alg, cost
# # # Not doing test for BS because of problems with random seed in OptunaSearch
# class BlendsearchWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
# def set_basic_conf(self):