mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
fixing bug for ner (#463)
* fixing bug for ner
* removing global var
* adding class for trial counter
* adding notebook
* adding use_ray dict
* updating documentation for nlp
This commit is contained in:
@@ -4,12 +4,17 @@ from requests.exceptions import ChunkedEncodingError
|
||||
|
||||
def test_automl(budget=5, dataset_format="dataframe", hpo_method=None):
|
||||
from flaml.data import load_openml_dataset
|
||||
import urllib3
|
||||
|
||||
try:
|
||||
X_train, X_test, y_train, y_test = load_openml_dataset(
|
||||
dataset_id=1169, data_dir="test/", dataset_format=dataset_format
|
||||
)
|
||||
except (OpenMLServerException, ChunkedEncodingError) as e:
|
||||
except (
|
||||
OpenMLServerException,
|
||||
ChunkedEncodingError,
|
||||
urllib3.exceptions.ReadTimeoutError,
|
||||
) as e:
|
||||
print(e)
|
||||
return
|
||||
""" import AutoML class from flaml package """
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
def test_load_args_sub():
|
||||
from flaml.nlp.utils import HPOArgs
|
||||
from flaml.nlp.utils import HFArgs
|
||||
|
||||
HPOArgs.load_args()
|
||||
HFArgs.load_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -84,9 +84,10 @@ def test_hf_data():
|
||||
"task": "seq-classification",
|
||||
"metric": "accuracy",
|
||||
"log_file_name": "seqclass.log",
|
||||
"use_ray": False,
|
||||
}
|
||||
|
||||
automl_settings["custom_hpo_args"] = {
|
||||
automl_settings["hf_args"] = {
|
||||
"model_path": "google/electra-small-discriminator",
|
||||
"output_dir": "test/data/output/",
|
||||
"ckpt_per_epoch": 5,
|
||||
@@ -116,7 +117,6 @@ def test_hf_data():
|
||||
pickle.dump(automl, f, pickle.HIGHEST_PROTOCOL)
|
||||
with open("automl.pkl", "rb") as f:
|
||||
automl = pickle.load(f)
|
||||
shutil.rmtree("test/data/output/")
|
||||
automl.predict(X_test)
|
||||
automl.predict(["test test", "test test"])
|
||||
automl.predict(
|
||||
@@ -164,7 +164,7 @@ def _test_custom_data():
|
||||
"metric": "accuracy",
|
||||
}
|
||||
|
||||
automl_settings["custom_hpo_args"] = {
|
||||
automl_settings["hf_args"] = {
|
||||
"model_path": "google/electra-small-discriminator",
|
||||
"output_dir": "data/output/",
|
||||
"ckpt_per_epoch": 1,
|
||||
@@ -183,6 +183,16 @@ def _test_custom_data():
|
||||
]
|
||||
)
|
||||
|
||||
import pickle
|
||||
|
||||
automl.pickle("automl.pkl")
|
||||
|
||||
with open("automl.pkl", "rb") as f:
|
||||
automl = pickle.load(f)
|
||||
config = automl.best_config.copy()
|
||||
config["learner"] = automl.best_estimator
|
||||
automl.trainable(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_hf_data()
|
||||
|
||||
@@ -52,7 +52,7 @@ def test_classification_head():
|
||||
"metric": "accuracy",
|
||||
}
|
||||
|
||||
automl_settings["custom_hpo_args"] = {
|
||||
automl_settings["hf_args"] = {
|
||||
"model_path": "google/electra-small-discriminator",
|
||||
"output_dir": "test/data/output/",
|
||||
"ckpt_per_epoch": 1,
|
||||
|
||||
@@ -19,8 +19,7 @@ def custom_metric(
|
||||
from flaml.model import TransformersEstimator
|
||||
|
||||
if estimator._trainer is None:
|
||||
estimator._init_model_for_predict(X_test)
|
||||
trainer = estimator._trainer
|
||||
trainer, _, _ = estimator._init_model_for_predict(X_test)
|
||||
estimator._trainer = None
|
||||
else:
|
||||
trainer = estimator._trainer
|
||||
@@ -103,7 +102,7 @@ def test_custom_metric():
|
||||
"log_file_name": "seqclass.log",
|
||||
}
|
||||
|
||||
automl_settings["custom_hpo_args"] = {
|
||||
automl_settings["hf_args"] = {
|
||||
"model_path": "google/electra-small-discriminator",
|
||||
"output_dir": "data/output/",
|
||||
"ckpt_per_epoch": 1,
|
||||
|
||||
@@ -43,7 +43,7 @@ def test_cv():
|
||||
"n_splits": 3,
|
||||
}
|
||||
|
||||
automl_settings["custom_hpo_args"] = {
|
||||
automl_settings["hf_args"] = {
|
||||
"model_path": "google/electra-small-discriminator",
|
||||
"output_dir": "test/data/output/",
|
||||
"ckpt_per_epoch": 1,
|
||||
|
||||
@@ -216,7 +216,7 @@ def test_mcc():
|
||||
"log_file_name": "seqclass.log",
|
||||
}
|
||||
|
||||
automl_settings["custom_hpo_args"] = {
|
||||
automl_settings["hf_args"] = {
|
||||
"model_path": "google/electra-small-discriminator",
|
||||
"output_dir": "test/data/output/",
|
||||
"ckpt_per_epoch": 1,
|
||||
|
||||
@@ -6,6 +6,9 @@ import pytest
|
||||
def test_regression():
|
||||
try:
|
||||
import ray
|
||||
|
||||
if not ray.is_initialized():
|
||||
ray.init()
|
||||
except ImportError:
|
||||
return
|
||||
from flaml import AutoML
|
||||
@@ -65,10 +68,10 @@ def test_regression():
|
||||
"task": "seq-regression",
|
||||
"metric": "pearsonr",
|
||||
"starting_points": {"transformer": {"num_train_epochs": 1}},
|
||||
"use_ray": True,
|
||||
"use_ray": {"local_dir": "data/outut/"},
|
||||
}
|
||||
|
||||
automl_settings["custom_hpo_args"] = {
|
||||
automl_settings["hf_args"] = {
|
||||
"model_path": "google/electra-small-discriminator",
|
||||
"output_dir": "test/data/output/",
|
||||
"ckpt_per_epoch": 1,
|
||||
@@ -77,6 +80,7 @@ def test_regression():
|
||||
|
||||
ray.shutdown()
|
||||
ray.init()
|
||||
|
||||
automl.fit(
|
||||
X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, **automl_settings
|
||||
)
|
||||
|
||||
@@ -58,7 +58,7 @@ def test_summarization():
|
||||
"log_file_name": "seqclass.log",
|
||||
}
|
||||
|
||||
automl_settings["custom_hpo_args"] = {
|
||||
automl_settings["hf_args"] = {
|
||||
"model_path": "patrickvonplaten/t5-tiny-random",
|
||||
"output_dir": "test/data/output/",
|
||||
"ckpt_per_epoch": 1,
|
||||
|
||||
@@ -726,7 +726,7 @@ def test_tokenclassification():
|
||||
"metric": "seqeval",
|
||||
}
|
||||
|
||||
automl_settings["custom_hpo_args"] = {
|
||||
automl_settings["hf_args"] = {
|
||||
"model_path": "bert-base-uncased",
|
||||
"output_dir": "test/data/output/",
|
||||
"ckpt_per_epoch": 1,
|
||||
|
||||
@@ -81,7 +81,7 @@ def _test_hf_data():
|
||||
"use_ray": True,
|
||||
}
|
||||
|
||||
automl_settings["custom_hpo_args"] = {
|
||||
automl_settings["hf_args"] = {
|
||||
"model_path": "facebook/muppet-roberta-base",
|
||||
"output_dir": "test/data/output/",
|
||||
"ckpt_per_epoch": 5,
|
||||
|
||||
Reference in New Issue
Block a user