diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 0000000000..b3abaf8b6c --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,26 @@ +name: Code formatting + +# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows +on: # Trigger the workflow on push or pull request, but only for the main branch + push: + branches: [main] + pull_request: {} + +defaults: + run: + shell: bash + +jobs: + + pre-commit-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + - name: Set $PY environment variable + run: echo "PY=$(python -VV | sha256sum | cut -d' ' -f1)" >> $GITHUB_ENV + - uses: actions/cache@v3 + with: + path: ~/.cache/pre-commit + key: pre-commit|${{ env.PY }}|${{ hashFiles('.pre-commit-config.yaml') }} + - uses: pre-commit/action@v3.0.0 diff --git a/flaml/default/suggest.py b/flaml/default/suggest.py index 9e17687e7b..e5f99569b1 100644 --- a/flaml/default/suggest.py +++ b/flaml/default/suggest.py @@ -61,8 +61,15 @@ def load_config_predictor(estimator_name, task, location=None): return predictor -def suggest_config(task, X, y, estimator_or_predictor, location=None, k=None, meta_feature_fn=meta_feature): - +def suggest_config( + task, + X, + y, + estimator_or_predictor, + location=None, + k=None, + meta_feature_fn=meta_feature, +): """Suggest a list of configs for the given task and training data. The returned configs can be used as starting points for AutoML.fit(). diff --git a/flaml/tune/spark/utils.py b/flaml/tune/spark/utils.py index f9377fc61a..6b8b46166e 100644 --- a/flaml/tune/spark/utils.py +++ b/flaml/tune/spark/utils.py @@ -281,7 +281,8 @@ class PySparkOvertimeMonitor: def __enter__(self): """Enter the context manager. - This will start a monitor thread if spark is available and force_cancel is True. + This will start a monitor thread if spark is available and force_cancel is True.
+ """ if self._force_cancel and _have_spark: self._monitor_daemon = threading.Thread(target=self._monitor_overtime) # logger.setLevel("INFO")