mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
train at least one iter when not trained (#336)
* train at least one iter when not trained * bump version to 0.9.1
This commit is contained in:
@@ -868,26 +868,24 @@ class LGBMEstimator(BaseEstimator):
|
||||
)
|
||||
if trained and max_iter <= self.params[self.ITER_HP]:
|
||||
return time.time() - start_time
|
||||
self.params[self.ITER_HP] = max_iter
|
||||
if self.params[self.ITER_HP] > 0:
|
||||
if self.HAS_CALLBACK:
|
||||
self._fit(
|
||||
X_train,
|
||||
y_train,
|
||||
callbacks=self._callbacks(start_time, deadline),
|
||||
**kwargs,
|
||||
)
|
||||
best_iteration = (
|
||||
self._model.get_booster().best_iteration
|
||||
if isinstance(self, XGBoostSklearnEstimator)
|
||||
else self._model.best_iteration_
|
||||
)
|
||||
if best_iteration is not None:
|
||||
self._model.set_params(n_estimators=best_iteration + 1)
|
||||
else:
|
||||
self._fit(X_train, y_train, **kwargs)
|
||||
# when not trained, train at least one iter
|
||||
self.params[self.ITER_HP] = max(max_iter, 1)
|
||||
if self.HAS_CALLBACK:
|
||||
self._fit(
|
||||
X_train,
|
||||
y_train,
|
||||
callbacks=self._callbacks(start_time, deadline),
|
||||
**kwargs,
|
||||
)
|
||||
best_iteration = (
|
||||
self._model.get_booster().best_iteration
|
||||
if isinstance(self, XGBoostSklearnEstimator)
|
||||
else self._model.best_iteration_
|
||||
)
|
||||
if best_iteration is not None:
|
||||
self._model.set_params(n_estimators=best_iteration + 1)
|
||||
else:
|
||||
self.params[self.ITER_HP] = self._model.n_estimators
|
||||
self._fit(X_train, y_train, **kwargs)
|
||||
train_time = time.time() - start_time
|
||||
return train_time
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "0.9.0"
|
||||
__version__ = "0.9.1"
|
||||
|
||||
Reference in New Issue
Block a user