diff --git a/flaml/automl/automl.py b/flaml/automl/automl.py
index 2d8ea04a1..81ec42458 100644
--- a/flaml/automl/automl.py
+++ b/flaml/automl/automl.py
@@ -594,6 +594,7 @@ class AutoML(BaseEstimator):
             return None
         X = self._state.task.preprocess(X, self._transformer)
         y_pred = estimator.predict(X, **pred_kwargs)
+
         if (
             isinstance(y_pred, np.ndarray)
             and y_pred.ndim > 1
diff --git a/flaml/automl/model.py b/flaml/automl/model.py
index 2144be8e4..7208674c5 100644
--- a/flaml/automl/model.py
+++ b/flaml/automl/model.py
@@ -1191,8 +1191,13 @@ class TransformersEstimator(BaseEstimator):
         test_dataset = Dataset.from_pandas(X_test)
 
         new_trainer = self._init_model_for_predict()
-        predictions = new_trainer.predict(test_dataset)
-        return predictions.predictions
+        try:
+            predictions = new_trainer.predict(test_dataset).predictions
+        except ZeroDivisionError:
+            # Fall back to placeholder scores so the caller still receives an array.
+            logger.warning("Zero division error appeared in HuggingFace Transformers.")
+            predictions = np.array([-0.05] * len(test_dataset))
+        return predictions
 
     def score(self, X_val: DataFrame, y_val: Series, **kwargs):
         import transformers
@@ -1222,13 +1227,14 @@ class TransformersEstimator(BaseEstimator):
 
         new_trainer = self._init_model_for_predict()
 
-        if self._task not in NLG_TASKS:
-            predictions = new_trainer.predict(test_dataset)
-        else:
-            predictions = new_trainer.predict(
-                test_dataset,
-                metric_key_prefix="predict",
-            )
+        kwargs = {} if self._task not in NLG_TASKS else {"metric_key_prefix": "predict"}
+        try:
+            predictions = new_trainer.predict(test_dataset, **kwargs)
+        except ZeroDivisionError:
+            logger.warning("Zero division error appeared in HuggingFace Transformers.")
+            # A bare ndarray has no `.predictions`, so return the placeholder directly.
+            return np.array([0] * len(test_dataset))
+
         post_y_pred, _ = postprocess_prediction_and_true(
             task=self._task,
             y_pred=predictions.predictions,
diff --git a/test/nlp/test_autohf.py b/test/nlp/test_autohf.py
index 8edadc200..d751200fd 100644
--- a/test/nlp/test_autohf.py
+++ b/test/nlp/test_autohf.py
@@ -62,7 +62,9 @@ def test_hf_data():
         **automl_settings
     )
     automl.predict(X_test, **{"per_device_eval_batch_size": 2})
-    automl.predict(["test test", "test test"])
+    automl.predict(["", ""])
+    automl.predict_proba(["", ""])
+
     automl.predict(
         [
             ["test test", "test test"],