mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
categorical choice can be ordered or unordered (#677)
* categorical choice can be ordered or unordered * ordered -> order * move choice into utils * version comparison * packaging -> setuptools * import version * version_parse * test order for choice
This commit is contained in:
@@ -6,6 +6,7 @@ import json
|
||||
from sklearn.preprocessing import RobustScaler
|
||||
from flaml.default import greedy
|
||||
from flaml.default.regret import load_result, build_regret
|
||||
from flaml.version import __version__
|
||||
|
||||
regret_bound = 0.01
|
||||
|
||||
@@ -113,7 +114,6 @@ def serialize(configs, regret, meta_features, output_file, config_path):
|
||||
)
|
||||
portfolio = [load_json(config_path.joinpath(m + ".json")) for m in configs]
|
||||
regret = regret.loc[configs]
|
||||
from flaml import __version__
|
||||
|
||||
meta_predictor = {
|
||||
"version": __version__,
|
||||
|
||||
@@ -5,12 +5,17 @@ import pathlib
|
||||
import json
|
||||
from flaml.data import CLASSIFICATION, DataTransformer
|
||||
from flaml.ml import get_estimator_class, get_classification_objective
|
||||
from flaml.version import __version__
|
||||
|
||||
LOCATION = pathlib.Path(__file__).parent.resolve()
|
||||
logger = logging.getLogger(__name__)
|
||||
CONFIG_PREDICTORS = {}
|
||||
|
||||
|
||||
def version_parse(version):
    """Convert a dotted version string into a tuple of ints for comparison.

    Comparing the resulting tuples orders versions numerically, e.g.
    ``version_parse("1.0.10") > version_parse("1.0.9")``, which a plain
    string comparison gets wrong.

    Purely numeric versions parse exactly as ``tuple(map(int, ...))`` would.
    If a component is not a plain integer (e.g. ``"0rc1"`` or ``"dev0"``),
    its leading digits, if any, are kept and parsing stops there, so
    ``"1.1.0rc1"`` -> ``(1, 1, 0)`` and ``"1.0.10.dev0"`` -> ``(1, 0, 10)``.

    Args:
        version: A version string such as ``"1.0.10"``.

    Returns:
        A tuple of ints holding the numeric release components.

    Raises:
        ValueError: If the string has no leading numeric component.
    """
    import re

    release = []
    for component in version.split("."):
        try:
            release.append(int(component))
            continue
        except ValueError:
            pass
        # Non-numeric suffix: keep the numeric prefix (if present) and stop,
        # since everything after it is pre-release/dev/local metadata.
        prefix = re.match(r"\s*(\d+)", component)
        if prefix is not None:
            release.append(int(prefix.group(1)))
        break
    if not release:
        raise ValueError(f"invalid version string: {version!r}")
    return tuple(release)
|
||||
|
||||
|
||||
def meta_feature(task, X_train, y_train, meta_feature_names):
|
||||
this_feature = []
|
||||
n_row = X_train.shape[0]
|
||||
@@ -72,11 +77,14 @@ def suggest_config(task, X, y, estimator_or_predictor, location=None, k=None):
|
||||
if isinstance(estimator_or_predictor, str)
|
||||
else estimator_or_predictor
|
||||
)
|
||||
from flaml import __version__
|
||||
|
||||
older_version = "1.0.2"
|
||||
# TODO: update older_version when the newer code can no longer handle the older version json file
|
||||
assert __version__ >= predictor["version"] >= older_version
|
||||
assert (
|
||||
version_parse(__version__)
|
||||
>= version_parse(predictor["version"])
|
||||
>= version_parse(older_version)
|
||||
)
|
||||
prep = predictor["preprocessing"]
|
||||
feature = meta_feature(
|
||||
task, X_train=X, y_train=y, meta_feature_names=predictor["meta_feature_names"]
|
||||
|
||||
@@ -32,7 +32,6 @@ from .data import (
|
||||
TOKENCLASSIFICATION,
|
||||
SUMMARIZATION,
|
||||
NLG_TASKS,
|
||||
MULTICHOICECLASSIFICATION,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -5,7 +5,6 @@ try:
|
||||
from ray.tune import (
|
||||
uniform,
|
||||
quniform,
|
||||
choice,
|
||||
randint,
|
||||
qrandint,
|
||||
randn,
|
||||
@@ -14,12 +13,12 @@ try:
|
||||
qloguniform,
|
||||
lograndint,
|
||||
qlograndint,
|
||||
sample,
|
||||
)
|
||||
except (ImportError, AssertionError):
|
||||
from .sample import (
|
||||
uniform,
|
||||
quniform,
|
||||
choice,
|
||||
randint,
|
||||
qrandint,
|
||||
randn,
|
||||
@@ -29,7 +28,9 @@ except (ImportError, AssertionError):
|
||||
lograndint,
|
||||
qlograndint,
|
||||
)
|
||||
from . import sample
|
||||
from .tune import run, report, INCUMBENT_RESULT
|
||||
from .sample import polynomial_expansion_set
|
||||
from .sample import PolynomialExpansionSet, Categorical, Float
|
||||
from .trial import Trial
|
||||
from .utils import choice
|
||||
|
||||
@@ -225,15 +225,18 @@ def add_cost_to_space(space: Dict, low_cost_point: Dict, choice_cost: Dict):
|
||||
domain.choice_cost = cost[ind]
|
||||
domain.const = [domain.const[i] for i in ind]
|
||||
domain.ordered = True
|
||||
elif all(
|
||||
isinstance(x, int) or isinstance(x, float) for x in domain.categories
|
||||
):
|
||||
# sort the choices by value
|
||||
ind = np.argsort(domain.categories)
|
||||
domain.categories = [domain.categories[i] for i in ind]
|
||||
domain.ordered = True
|
||||
else:
|
||||
domain.ordered = False
|
||||
ordered = getattr(domain, "ordered", None)
|
||||
if ordered is None:
|
||||
# automatically decide whether to order the choices based on the value type
|
||||
domain.ordered = ordered = all(
|
||||
isinstance(x, (int, float)) for x in domain.categories
|
||||
)
|
||||
if ordered:
|
||||
# sort the choices by value
|
||||
ind = np.argsort(domain.categories)
|
||||
domain.categories = [domain.categories[i] for i in ind]
|
||||
|
||||
if low_cost and low_cost not in domain.categories:
|
||||
assert isinstance(
|
||||
low_cost, list
|
||||
|
||||
28
flaml/tune/utils.py
Normal file
28
flaml/tune/utils.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import Sequence
|
||||
|
||||
# Use ray.tune's `sample` module when ray >= 1.10 is installed; otherwise fall
# back to the package-local `sample` module.
try:
    from ray import __version__ as ray_version

    # Compare (major, minor) numerically. As plain strings,
    # "1.9.0" >= "1.10.0" is True and "10.0.0" >= "1.10.0" is False, so a
    # string comparison accepts/rejects the wrong releases.
    # An explicit raise is used instead of `assert` so the check still runs
    # under `python -O`.
    if tuple(int(part) for part in ray_version.split(".")[:2]) < (1, 10):
        raise ImportError(f"ray>=1.10.0 is required, found {ray_version}")
    from ray.tune import sample
except (ImportError, AssertionError, ValueError):
    # ValueError: the ray version string could not be parsed as integers.
    from . import sample
|
||||
|
||||
|
||||
def choice(categories: Sequence, order=None):
    """Sample a categorical value.

    Sampling from ``tune.choice([1, 2])`` is equivalent to sampling from
    ``np.random.choice([1, 2])``

    Args:
        categories (Sequence): Sequence of categories to sample from.
        order (bool): Whether the categories have an order. If None, it is
            decided automatically: the categories are treated as ordered
            only when every one of them is an int or a float (so string
            categories are unordered).

    Returns:
        The domain built by ``sample.Categorical(categories).uniform()``,
        with its ``ordered`` attribute set as described above.
    """
    domain = sample.Categorical(categories).uniform()
    if order is None:
        # No explicit request: infer orderedness from the value types.
        order = all(isinstance(item, (int, float)) for item in categories)
    domain.ordered = order
    return domain
|
||||
@@ -1 +1 @@
|
||||
__version__ = "1.0.9"
|
||||
__version__ = "1.0.10"
|
||||
|
||||
@@ -4,7 +4,6 @@ from flaml.tune.sample import (
|
||||
Domain,
|
||||
uniform,
|
||||
quniform,
|
||||
choice,
|
||||
randint,
|
||||
qrandint,
|
||||
randn,
|
||||
@@ -14,6 +13,7 @@ from flaml.tune.sample import (
|
||||
lograndint,
|
||||
qlograndint,
|
||||
)
|
||||
from flaml.tune import choice
|
||||
|
||||
|
||||
def test_sampler():
|
||||
@@ -22,6 +22,8 @@ def test_sampler():
|
||||
print(qrandn(2, 10, 2).sample(size=2))
|
||||
c = choice([1, 2])
|
||||
print(c.domain_str, len(c), c.is_valid(3))
|
||||
c = choice([1, 2], order=False)
|
||||
print(c.domain_str, len(c), c.ordered)
|
||||
i = randint(1, 10)
|
||||
print(i.domain_str, i.is_valid(10))
|
||||
d = Domain()
|
||||
|
||||
Reference in New Issue
Block a user