mirror of
https://github.com/microsoft/autogen.git
synced 2026-01-25 20:28:22 -05:00
@@ -6,7 +6,12 @@ from transformers.data.data_collator import (
|
||||
)
|
||||
from collections import OrderedDict
|
||||
|
||||
from flaml.data import TOKENCLASSIFICATION, MULTICHOICECLASSIFICATION, SUMMARIZATION
|
||||
from flaml.data import (
|
||||
TOKENCLASSIFICATION,
|
||||
MULTICHOICECLASSIFICATION,
|
||||
SUMMARIZATION,
|
||||
SEQCLASSIFICATION,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -45,5 +50,6 @@ task_to_datacollator_class = OrderedDict(
|
||||
(TOKENCLASSIFICATION, DataCollatorForTokenClassification),
|
||||
(MULTICHOICECLASSIFICATION, DataCollatorForMultipleChoiceClassification),
|
||||
(SUMMARIZATION, DataCollatorForSeq2Seq),
|
||||
(SEQCLASSIFICATION, DataCollatorWithPadding),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -56,7 +56,7 @@ def test_hf_data():
|
||||
record_id=0,
|
||||
**automl_settings
|
||||
)
|
||||
automl.predict(X_test)
|
||||
automl.predict(X_test, **{"per_device_eval_batch_size": 2})
|
||||
automl.predict(["test test", "test test"])
|
||||
automl.predict(
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user