Mirror of https://github.com/microsoft/autogen.git, synced 2026-01-24 20:48:20 -05:00
Skip transform (#665)
* Skip transform
* Fix logic and docstring, add test
* Add period ending to skip_transform doc
* Add skip_transform to retrain_from_log method
* Update test/automl/test_classification.py

Co-authored-by: Xueqing Liu <liususan091219@users.noreply.github.com>
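For orientation, here is a minimal sketch of how the new flag is intended to be used once this commit is in place, based on the `fit` signature change below. It assumes the `flaml` package that this `AutoML` class ships in; the dataset, time budget, and estimator list are illustrative, not part of the commit.

```python
# Minimal illustration of the skip_transform flag introduced in this commit.
# Assumes X and y are already numeric/encoded, because FLAML's own
# preprocessing is bypassed when the flag is set.
import numpy as np
from flaml import AutoML

X = np.random.rand(100, 5)           # already-numeric features
y = np.random.randint(2, size=100)   # already-encoded labels

automl = AutoML()
automl.fit(
    X,
    y,
    task="classification",
    time_budget=3,
    estimator_list=["lgbm"],
    skip_transform=True,  # hand X/y to the estimators without FLAML preprocessing
)
```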
@@ -678,6 +678,7 @@ class AutoML(BaseEstimator):
                     }
                 }
                 ```
+            skip_transform: boolean, default=False | Whether to pre-process data prior to modeling.
             fit_kwargs_by_estimator: dict, default=None | The user specified keywords arguments, grouped by estimator name.
                 e.g.,

@@ -732,6 +733,7 @@ class AutoML(BaseEstimator):
             "fit_kwargs_by_estimator", {}
         )
         settings["custom_hp"] = settings.get("custom_hp", {})
+        settings["skip_transform"] = settings.get("skip_transform", False)

         self._estimator_type = (
             "classifier" if settings["task"] in CLASSIFICATION else "regressor"
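Because the flag is read from the constructor settings here, it can also be fixed once on the `AutoML` instance instead of being passed to every `fit` call. A small sketch, with the other settings purely illustrative:

```python
# Sketch: fixing skip_transform at construction time; __init__ stores it in the
# settings dict with a default of False (see the hunk above), so fit() inherits it.
import numpy as np
from flaml import AutoML

X = np.random.rand(60, 4)
y = np.random.randint(2, size=60)

automl = AutoML(task="classification", skip_transform=True)
automl.fit(X, y, time_budget=3)  # X, y assumed to be pre-encoded by the caller
```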
@@ -1119,7 +1121,7 @@ class AutoML(BaseEstimator):
                 "or all columns of X are integer ids (tokenized)"
             )

-        if issparse(X_train_all):
+        if issparse(X_train_all) or self._skip_transform:
             self._transformer = self._label_transformer = False
             self._X_train_all, self._y_train_all = X, y
         else:
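Since the new flag rides on the same branch that already handled sparse input, setting it makes the caller responsible for everything the transformer would normally do (label encoding, categorical handling, and so on). A hedged sketch of caller-side preprocessing, using pandas and scikit-learn purely for illustration:

```python
# Sketch: preparing data by hand before fitting with skip_transform=True.
# Any encoding that yields numeric features and labels works; pandas and
# scikit-learn are just one convenient option.
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from flaml import AutoML

raw = pd.DataFrame(
    {
        "color": ["red", "blue", "green", "red"] * 10,
        "size": [1.0, 2.5, 3.0, 0.5] * 10,
    }
)
X = pd.get_dummies(raw, columns=["color"])             # one-hot encode categoricals ourselves
y = LabelEncoder().fit_transform(["cat", "dog"] * 20)  # integer-encode labels ourselves

automl = AutoML()
automl.fit(X, y, task="classification", time_budget=1, skip_transform=True)
```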
@@ -1540,6 +1542,7 @@ class AutoML(BaseEstimator):
         record_id=-1,
         auto_augment=None,
         custom_hp=None,
+        skip_transform=None,
         fit_kwargs_by_estimator=None,
         **fit_kwargs,
     ):
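As the commit message notes, the flag is also threaded through `retrain_from_log`, so a logged configuration found on pre-encoded data can be replayed without FLAML re-fitting its transformer. A hedged sketch, reusing the pre-encoded `X`, `y` from the preprocessing sketch above and assuming an earlier run that wrote `automl.log` on that same data:

```python
# Sketch: replaying a logged configuration on caller-encoded data.
# Assumes "automl.log" came from a previous fit(..., log_file_name="automl.log")
# run on the same pre-encoded X and y used above.
from flaml import AutoML

automl = AutoML()
automl.retrain_from_log(
    log_file_name="automl.log",
    X_train=X,
    y_train=y,
    task="classification",
    skip_transform=True,  # parameter added to retrain_from_log in this commit
)
```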
@@ -1649,6 +1652,7 @@ class AutoML(BaseEstimator):

         self._state.fit_kwargs = fit_kwargs
         self._state.custom_hp = custom_hp or self._settings.get("custom_hp")
+        self._skip_transform = self._settings.get("skip_transform") if skip_transform is None else skip_transform
         self._state.fit_kwargs_by_estimator = (
             fit_kwargs_by_estimator or self._settings.get("fit_kwargs_by_estimator")
         )
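The resolution order used on this line (and again in `fit` further down) is: an explicitly passed `skip_transform` argument wins, `None` falls back to the constructor settings, and the settings default is `False`. Condensed into a hypothetical helper just to make the rule explicit (`resolve_skip_transform` is not part of the codebase):

```python
# Hypothetical helper that mirrors the one-line precedence rule from the diff.
def resolve_skip_transform(settings: dict, skip_transform=None) -> bool:
    # Explicit argument beats settings; settings default to False (set in __init__).
    return settings.get("skip_transform", False) if skip_transform is None else skip_transform

assert resolve_skip_transform({}) is False
assert resolve_skip_transform({"skip_transform": True}) is True
assert resolve_skip_transform({"skip_transform": True}, skip_transform=False) is False
```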
@@ -2070,6 +2074,7 @@ class AutoML(BaseEstimator):
         use_ray=None,
         metric_constraints=None,
         custom_hp=None,
+        skip_transform=None,
         fit_kwargs_by_estimator=None,
         **fit_kwargs,
     ):
@@ -2274,6 +2279,8 @@ class AutoML(BaseEstimator):
                 Each key is the estimator name, each value is a dict of the custom search space for that estimator. Notice the
                 domain of the custom search space can either be a value of a sample.Domain object.



                 ```python
                 custom_hp = {
                     "transformer_ms": {
@@ -2287,6 +2294,7 @@ class AutoML(BaseEstimator):
                 }
                 ```

+            skip_transform: boolean, default=False | Whether to pre-process data prior to modeling.
             fit_kwargs_by_estimator: dict, default=None | The user specified keywords arguments, grouped by estimator name.
                 For TransformersEstimator, available fit_kwargs can be found from
                 [TrainingArgumentsForAuto](nlp/huggingface/training_args).
@@ -2418,6 +2426,7 @@ class AutoML(BaseEstimator):

         self._state.fit_kwargs = fit_kwargs
         custom_hp = custom_hp or self._settings.get("custom_hp")
+        self._skip_transform = self._settings.get("skip_transform") if skip_transform is None else skip_transform
         fit_kwargs_by_estimator = fit_kwargs_by_estimator or self._settings.get(
             "fit_kwargs_by_estimator"
         )
@@ -155,6 +155,25 @@ class TestClassification(unittest.TestCase):
             # "verbose": 4,
             "ensemble": True,
         }
+        automl_settings["keep_search_state"] = True
+        automl.fit(X, y, **automl_settings)
+        X, y = automl._X_train_all, automl._y_train_all
+        del automl
+
+        automl = AutoML()
+        automl_settings = {
+            "time_budget": 3,
+            "task": "classification",
+            "n_jobs": 1,
+            "estimator_list": ["kneighbor"],
+            "eval_method": "cv",
+            "n_splits": 3,
+            "metric": "accuracy",
+            "log_training_metric": True,
+            # "verbose": 4,
+            "ensemble": True,
+            "skip_transform": True,
+        }
         automl.fit(X, y, **automl_settings)
         del automl
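The new test leans on the private search state: the first run sets `keep_search_state=True` so the already-transformed `_X_train_all`/`_y_train_all` can be pulled out, and the second run feeds that data back in with `"skip_transform": True`, exercising the patched branch above. If one wanted to assert the skip explicitly, a hedged addition before the final `del automl` might look like this (private attributes, illustration only):

```python
# Illustration only: checks the effect of the patched branch, which sets both
# transformers to False when skip_transform is requested.
assert automl._transformer is False
assert automl._label_transformer is False
```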