fix ner bug; refactor post processing of TransformersEstimator prediction (#615)

* fix ner bug; refactor post processing

* fix too many values to unpack

* supporting id/token label for NER
This commit is contained in:
Xueqing Liu
2022-07-05 13:38:21 -04:00
committed by GitHub
parent 6dd93bc939
commit 6108493e0b
11 changed files with 847 additions and 133 deletions

View File

@@ -25,10 +25,10 @@ def custom_metric(
else:
trainer = estimator._trainer
if y_test is not None:
X_test, _ = estimator._preprocess(X_test)
X_test = estimator._preprocess(X_test)
eval_dataset = Dataset.from_pandas(TransformersEstimator._join(X_test, y_test))
else:
X_test, _ = estimator._preprocess(X_test)
X_test = estimator._preprocess(X_test)
eval_dataset = Dataset.from_pandas(X_test)
estimator_metric_backup = estimator._metric

View File

@@ -6,6 +6,7 @@ from utils import get_toy_data_summarization, get_automl_settings
@pytest.mark.skipif(sys.platform == "darwin", reason="do not run on mac os")
def test_summarization():
# TODO: manual test for how effective postprocess_seq2seq_prediction_label is
from flaml import AutoML
X_train, y_train, X_val, y_val, X_test = get_toy_data_summarization()

View File

@@ -1,17 +1,21 @@
import sys
import pytest
import requests
from utils import get_toy_data_tokenclassification, get_automl_settings
from utils import (
get_toy_data_tokenclassification_idlabel,
get_toy_data_tokenclassification_tokenlabel,
get_automl_settings,
)
@pytest.mark.skipif(
sys.platform == "darwin" or sys.version < "3.7",
reason="do not run on mac os or py<3.7",
)
def test_tokenclassification():
def test_tokenclassification_idlabel():
from flaml import AutoML
X_train, y_train, X_val, y_val = get_toy_data_tokenclassification()
X_train, y_train, X_val, y_val = get_toy_data_tokenclassification_idlabel()
automl = AutoML()
automl_settings = get_automl_settings()
@@ -42,6 +46,66 @@ def test_tokenclassification():
except requests.exceptions.HTTPError:
return
# perf test
import json
with open("seqclass.log", "r") as fin:
for line in fin:
each_log = json.loads(line.strip("\n"))
if "validation_loss" in each_log:
val_loss = each_log["validation_loss"]
min_inter_result = min(
each_dict.get("eval_automl_metric", sys.maxsize)
for each_dict in each_log["logged_metric"]["intermediate_results"]
)
if min_inter_result != sys.maxsize:
assert val_loss == min_inter_result
@pytest.mark.skipif(
    sys.platform == "darwin" or sys.version < "3.7",
    reason="do not run on mac os or py<3.7",
)
def test_tokenclassification_tokenlabel():
    """Run AutoML token classification where labels are tag strings (e.g. "B-ORG").

    After fitting, cross-checks the log file: for every trial that reports a
    validation_loss, that loss must equal the best intermediate metric logged
    for the trial.
    """
    from flaml import AutoML

    X_train, y_train, X_val, y_val = get_toy_data_tokenclassification_tokenlabel()

    automl = AutoML()
    settings = get_automl_settings()
    settings["task"] = "token-classification"
    # Evaluate based on the overall_f1 of seqeval.
    settings["metric"] = "seqeval:overall_f1"

    try:
        automl.fit(
            X_train=X_train,
            y_train=y_train,
            X_val=X_val,
            y_val=y_val,
            **settings
        )
    except requests.exceptions.HTTPError:
        # Network hiccup while downloading the model — skip the perf check.
        return

    # perf test
    import json

    with open("seqclass.log", "r") as fin:
        for raw_line in fin:
            record = json.loads(raw_line.strip("\n"))
            if "validation_loss" not in record:
                continue
            reported_loss = record["validation_loss"]
            best_intermediate = min(
                step.get("eval_automl_metric", sys.maxsize)
                for step in record["logged_metric"]["intermediate_results"]
            )
            if best_intermediate != sys.maxsize:
                assert reported_loss == best_intermediate
if __name__ == "__main__":
    # Exercise both NER label formats when run as a script; previously only
    # the id-label variant was invoked, leaving the token-label test dead
    # in script mode.
    test_tokenclassification_idlabel()
    test_tokenclassification_tokenlabel()

View File

@@ -406,7 +406,8 @@ def get_toy_data_summarization():
return X_train, y_train, X_val, y_val, X_test
def get_toy_data_tokenclassification():
def get_toy_data_tokenclassification_idlabel():
# test token classification when the labels are ids
train_data = {
"chunk_tags": [
[11, 21, 11, 12, 21, 22, 11, 12, 0],
@@ -1116,6 +1117,391 @@ def get_toy_data_tokenclassification():
return X_train, y_train, X_val, y_val
def get_toy_data_tokenclassification_tokenlabel():
    """Toy CoNLL-style NER data whose labels are tag strings (e.g. "B-ORG"),
    not integer ids.

    Returns:
        (X_train, y_train, X_val, y_val): X_* are single-column ("tokens")
        DataFrames of token lists; y_* are Series of the parallel
        "ner_tags" label lists.
    """
    # Long runs of the "O" (outside) tag are built with list arithmetic, and
    # token lists are built via str.split() where no token contains
    # whitespace — both produce lists identical to spelling out each element.
    o = ["O"]

    train_data = {
        "id": ["0", "1", "2", "3"],
        "ner_tags": [
            ["B-ORG", "O", "B-MISC", "O", "O", "O", "B-MISC", "O", "O"],
            ["B-PER", "I-PER"],
            ["B-LOC", "O"],
            o
            + ["B-ORG", "I-ORG"]
            + o * 6
            + ["B-MISC"]
            + o * 5
            + ["B-MISC"]
            + o * 14,
        ],
        "tokens": [
            "EU rejects German call to boycott British lamb .".split(),
            ["Peter", "Blackburn"],
            ["BRUSSELS", "1996-08-22"],
            (
                "The European Commission said on Thursday it disagreed with "
                "German advice to consumers to shun British lamb until "
                "scientists determine whether mad cow disease can be "
                "transmitted to sheep ."
            ).split(),
        ],
    }

    dev_data = {
        "id": ["4", "5", "6", "7"],
        "ner_tags": [
            ["B-LOC"]
            + o * 4
            + ["B-ORG", "I-ORG"]
            + o * 3
            + ["B-PER", "I-PER"]
            + o * 11
            + ["B-LOC"]
            + o * 7,
            o * 20 + ["B-ORG"] + o * 3 + ["B-PER"] + ["I-PER"] * 3 + o * 5,
            o * 22 + ["B-ORG", "I-ORG"] + o,
            o * 7 + ["B-ORG"] + o * 2 + ["B-PER", "I-PER"] + o * 28,
        ],
        "tokens": [
            (
                "Germany 's representative to the European Union 's "
                "veterinary committee Werner Zwingmann said on Wednesday "
                "consumers should buy sheepmeat from countries other than "
                "Britain until the scientific advice was clearer ."
            ).split(),
            (
                '" We do n\'t support any such recommendation because we do '
                'n\'t see any grounds for it , " the Commission \'s chief '
                "spokesman Nikolaus van der Pas told a news briefing ."
            ).split(),
            (
                "He said further scientific study was required and if it was "
                "found that action was needed it should be taken by the "
                "European Union ."
            ).split(),
            (
                "He said a proposal last month by EU Farm Commissioner Franz "
                "Fischler to ban sheep brains , spleens and spinal cords "
                "from the human and animal food chains was a highly specific "
                "and precautionary move to protect human health ."
            ).split(),
        ],
    }

    train_df = pd.DataFrame(train_data)
    dev_df = pd.DataFrame(dev_data)

    feature_columns = ["tokens"]
    label_column = "ner_tags"

    X_train = train_df[feature_columns]
    y_train = train_df[label_column]
    X_val = dev_df[feature_columns]
    y_val = dev_df[label_column]
    return X_train, y_train, X_val, y_val
def get_automl_settings(estimator_name="transformer"):
automl_settings = {