mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
fix ner bug; refactor post processing of TransformersEstimator prediction (#615)
* fix ner bug; refactor post processing * fix too many values to unpack * supporting id/token label for NER
This commit is contained in:
@@ -25,10 +25,10 @@ def custom_metric(
|
||||
else:
|
||||
trainer = estimator._trainer
|
||||
if y_test is not None:
|
||||
X_test, _ = estimator._preprocess(X_test)
|
||||
X_test = estimator._preprocess(X_test)
|
||||
eval_dataset = Dataset.from_pandas(TransformersEstimator._join(X_test, y_test))
|
||||
else:
|
||||
X_test, _ = estimator._preprocess(X_test)
|
||||
X_test = estimator._preprocess(X_test)
|
||||
eval_dataset = Dataset.from_pandas(X_test)
|
||||
|
||||
estimator_metric_backup = estimator._metric
|
||||
|
||||
@@ -6,6 +6,7 @@ from utils import get_toy_data_summarization, get_automl_settings
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "darwin", reason="do not run on mac os")
|
||||
def test_summarization():
|
||||
# TODO: manual test for how effective postprocess_seq2seq_prediction_label is
|
||||
from flaml import AutoML
|
||||
|
||||
X_train, y_train, X_val, y_val, X_test = get_toy_data_summarization()
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
import sys
|
||||
import pytest
|
||||
import requests
|
||||
from utils import get_toy_data_tokenclassification, get_automl_settings
|
||||
from utils import (
|
||||
get_toy_data_tokenclassification_idlabel,
|
||||
get_toy_data_tokenclassification_tokenlabel,
|
||||
get_automl_settings,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform == "darwin" or sys.version < "3.7",
|
||||
reason="do not run on mac os or py<3.7",
|
||||
)
|
||||
def test_tokenclassification():
|
||||
def test_tokenclassification_idlabel():
|
||||
from flaml import AutoML
|
||||
|
||||
X_train, y_train, X_val, y_val = get_toy_data_tokenclassification()
|
||||
X_train, y_train, X_val, y_val = get_toy_data_tokenclassification_idlabel()
|
||||
automl = AutoML()
|
||||
|
||||
automl_settings = get_automl_settings()
|
||||
@@ -42,6 +46,66 @@ def test_tokenclassification():
|
||||
except requests.exceptions.HTTPError:
|
||||
return
|
||||
|
||||
# perf test
|
||||
import json
|
||||
|
||||
with open("seqclass.log", "r") as fin:
|
||||
for line in fin:
|
||||
each_log = json.loads(line.strip("\n"))
|
||||
if "validation_loss" in each_log:
|
||||
val_loss = each_log["validation_loss"]
|
||||
min_inter_result = min(
|
||||
each_dict.get("eval_automl_metric", sys.maxsize)
|
||||
for each_dict in each_log["logged_metric"]["intermediate_results"]
|
||||
)
|
||||
|
||||
if min_inter_result != sys.maxsize:
|
||||
assert val_loss == min_inter_result
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform == "darwin" or sys.version < "3.7",
|
||||
reason="do not run on mac os or py<3.7",
|
||||
)
|
||||
def test_tokenclassification_tokenlabel():
|
||||
from flaml import AutoML
|
||||
|
||||
X_train, y_train, X_val, y_val = get_toy_data_tokenclassification_tokenlabel()
|
||||
automl = AutoML()
|
||||
|
||||
automl_settings = get_automl_settings()
|
||||
automl_settings["task"] = "token-classification"
|
||||
automl_settings[
|
||||
"metric"
|
||||
] = "seqeval:overall_f1" # evaluating based on the overall_f1 of seqeval
|
||||
|
||||
try:
|
||||
automl.fit(
|
||||
X_train=X_train,
|
||||
y_train=y_train,
|
||||
X_val=X_val,
|
||||
y_val=y_val,
|
||||
**automl_settings
|
||||
)
|
||||
except requests.exceptions.HTTPError:
|
||||
return
|
||||
|
||||
# perf test
|
||||
import json
|
||||
|
||||
with open("seqclass.log", "r") as fin:
|
||||
for line in fin:
|
||||
each_log = json.loads(line.strip("\n"))
|
||||
if "validation_loss" in each_log:
|
||||
val_loss = each_log["validation_loss"]
|
||||
min_inter_result = min(
|
||||
each_dict.get("eval_automl_metric", sys.maxsize)
|
||||
for each_dict in each_log["logged_metric"]["intermediate_results"]
|
||||
)
|
||||
|
||||
if min_inter_result != sys.maxsize:
|
||||
assert val_loss == min_inter_result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_tokenclassification()
|
||||
test_tokenclassification_idlabel()
|
||||
|
||||
@@ -406,7 +406,8 @@ def get_toy_data_summarization():
|
||||
return X_train, y_train, X_val, y_val, X_test
|
||||
|
||||
|
||||
def get_toy_data_tokenclassification():
|
||||
def get_toy_data_tokenclassification_idlabel():
|
||||
# test token classification when the labels are ids
|
||||
train_data = {
|
||||
"chunk_tags": [
|
||||
[11, 21, 11, 12, 21, 22, 11, 12, 0],
|
||||
@@ -1116,6 +1117,391 @@ def get_toy_data_tokenclassification():
|
||||
return X_train, y_train, X_val, y_val
|
||||
|
||||
|
||||
def get_toy_data_tokenclassification_tokenlabel():
|
||||
# test token classification when the labels are tokens
|
||||
train_data = {
|
||||
"id": ["0", "1", "2", "3"],
|
||||
"ner_tags": [
|
||||
["B-ORG", "O", "B-MISC", "O", "O", "O", "B-MISC", "O", "O"],
|
||||
["B-PER", "I-PER"],
|
||||
["B-LOC", "O"],
|
||||
[
|
||||
"O",
|
||||
"B-ORG",
|
||||
"I-ORG",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"B-MISC",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"B-MISC",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
],
|
||||
],
|
||||
"tokens": [
|
||||
[
|
||||
"EU",
|
||||
"rejects",
|
||||
"German",
|
||||
"call",
|
||||
"to",
|
||||
"boycott",
|
||||
"British",
|
||||
"lamb",
|
||||
".",
|
||||
],
|
||||
["Peter", "Blackburn"],
|
||||
["BRUSSELS", "1996-08-22"],
|
||||
[
|
||||
"The",
|
||||
"European",
|
||||
"Commission",
|
||||
"said",
|
||||
"on",
|
||||
"Thursday",
|
||||
"it",
|
||||
"disagreed",
|
||||
"with",
|
||||
"German",
|
||||
"advice",
|
||||
"to",
|
||||
"consumers",
|
||||
"to",
|
||||
"shun",
|
||||
"British",
|
||||
"lamb",
|
||||
"until",
|
||||
"scientists",
|
||||
"determine",
|
||||
"whether",
|
||||
"mad",
|
||||
"cow",
|
||||
"disease",
|
||||
"can",
|
||||
"be",
|
||||
"transmitted",
|
||||
"to",
|
||||
"sheep",
|
||||
".",
|
||||
],
|
||||
],
|
||||
}
|
||||
|
||||
dev_data = {
|
||||
"id": ["4", "5", "6", "7"],
|
||||
"ner_tags": [
|
||||
[
|
||||
"B-LOC",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"B-ORG",
|
||||
"I-ORG",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"B-PER",
|
||||
"I-PER",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"B-LOC",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
],
|
||||
[
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"B-ORG",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"B-PER",
|
||||
"I-PER",
|
||||
"I-PER",
|
||||
"I-PER",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
],
|
||||
[
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"B-ORG",
|
||||
"I-ORG",
|
||||
"O",
|
||||
],
|
||||
[
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"B-ORG",
|
||||
"O",
|
||||
"O",
|
||||
"B-PER",
|
||||
"I-PER",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
"O",
|
||||
],
|
||||
],
|
||||
"tokens": [
|
||||
[
|
||||
"Germany",
|
||||
"'s",
|
||||
"representative",
|
||||
"to",
|
||||
"the",
|
||||
"European",
|
||||
"Union",
|
||||
"'s",
|
||||
"veterinary",
|
||||
"committee",
|
||||
"Werner",
|
||||
"Zwingmann",
|
||||
"said",
|
||||
"on",
|
||||
"Wednesday",
|
||||
"consumers",
|
||||
"should",
|
||||
"buy",
|
||||
"sheepmeat",
|
||||
"from",
|
||||
"countries",
|
||||
"other",
|
||||
"than",
|
||||
"Britain",
|
||||
"until",
|
||||
"the",
|
||||
"scientific",
|
||||
"advice",
|
||||
"was",
|
||||
"clearer",
|
||||
".",
|
||||
],
|
||||
[
|
||||
'"',
|
||||
"We",
|
||||
"do",
|
||||
"n't",
|
||||
"support",
|
||||
"any",
|
||||
"such",
|
||||
"recommendation",
|
||||
"because",
|
||||
"we",
|
||||
"do",
|
||||
"n't",
|
||||
"see",
|
||||
"any",
|
||||
"grounds",
|
||||
"for",
|
||||
"it",
|
||||
",",
|
||||
'"',
|
||||
"the",
|
||||
"Commission",
|
||||
"'s",
|
||||
"chief",
|
||||
"spokesman",
|
||||
"Nikolaus",
|
||||
"van",
|
||||
"der",
|
||||
"Pas",
|
||||
"told",
|
||||
"a",
|
||||
"news",
|
||||
"briefing",
|
||||
".",
|
||||
],
|
||||
[
|
||||
"He",
|
||||
"said",
|
||||
"further",
|
||||
"scientific",
|
||||
"study",
|
||||
"was",
|
||||
"required",
|
||||
"and",
|
||||
"if",
|
||||
"it",
|
||||
"was",
|
||||
"found",
|
||||
"that",
|
||||
"action",
|
||||
"was",
|
||||
"needed",
|
||||
"it",
|
||||
"should",
|
||||
"be",
|
||||
"taken",
|
||||
"by",
|
||||
"the",
|
||||
"European",
|
||||
"Union",
|
||||
".",
|
||||
],
|
||||
[
|
||||
"He",
|
||||
"said",
|
||||
"a",
|
||||
"proposal",
|
||||
"last",
|
||||
"month",
|
||||
"by",
|
||||
"EU",
|
||||
"Farm",
|
||||
"Commissioner",
|
||||
"Franz",
|
||||
"Fischler",
|
||||
"to",
|
||||
"ban",
|
||||
"sheep",
|
||||
"brains",
|
||||
",",
|
||||
"spleens",
|
||||
"and",
|
||||
"spinal",
|
||||
"cords",
|
||||
"from",
|
||||
"the",
|
||||
"human",
|
||||
"and",
|
||||
"animal",
|
||||
"food",
|
||||
"chains",
|
||||
"was",
|
||||
"a",
|
||||
"highly",
|
||||
"specific",
|
||||
"and",
|
||||
"precautionary",
|
||||
"move",
|
||||
"to",
|
||||
"protect",
|
||||
"human",
|
||||
"health",
|
||||
".",
|
||||
],
|
||||
],
|
||||
}
|
||||
train_dataset = pd.DataFrame(train_data)
|
||||
dev_dataset = pd.DataFrame(dev_data)
|
||||
|
||||
custom_sent_keys = ["tokens"]
|
||||
label_key = "ner_tags"
|
||||
|
||||
X_train = train_dataset[custom_sent_keys]
|
||||
y_train = train_dataset[label_key]
|
||||
|
||||
X_val = dev_dataset[custom_sent_keys]
|
||||
y_val = dev_dataset[label_key]
|
||||
return X_train, y_train, X_val, y_val
|
||||
|
||||
|
||||
def get_automl_settings(estimator_name="transformer"):
|
||||
|
||||
automl_settings = {
|
||||
|
||||
Reference in New Issue
Block a user