train at least one iter when not trained (#336)

* train at least one iter when not trained

* bump version to 0.9.1
This commit is contained in:
Chi Wang
2021-12-12 20:05:18 -08:00
committed by GitHub
parent 1a3e01c352
commit 434586e2e2
2 changed files with 18 additions and 20 deletions

View File

@@ -868,26 +868,24 @@ class LGBMEstimator(BaseEstimator):
)
if trained and max_iter <= self.params[self.ITER_HP]:
return time.time() - start_time
self.params[self.ITER_HP] = max_iter
if self.params[self.ITER_HP] > 0:
if self.HAS_CALLBACK:
self._fit(
X_train,
y_train,
callbacks=self._callbacks(start_time, deadline),
**kwargs,
)
best_iteration = (
self._model.get_booster().best_iteration
if isinstance(self, XGBoostSklearnEstimator)
else self._model.best_iteration_
)
if best_iteration is not None:
self._model.set_params(n_estimators=best_iteration + 1)
else:
self._fit(X_train, y_train, **kwargs)
# when not trained, train at least one iter
self.params[self.ITER_HP] = max(max_iter, 1)
if self.HAS_CALLBACK:
self._fit(
X_train,
y_train,
callbacks=self._callbacks(start_time, deadline),
**kwargs,
)
best_iteration = (
self._model.get_booster().best_iteration
if isinstance(self, XGBoostSklearnEstimator)
else self._model.best_iteration_
)
if best_iteration is not None:
self._model.set_params(n_estimators=best_iteration + 1)
else:
self.params[self.ITER_HP] = self._model.n_estimators
self._fit(X_train, y_train, **kwargs)
train_time = time.time() - start_time
return train_time

View File

@@ -1 +1 @@
__version__ = "0.9.0"
__version__ = "0.9.1"