catch TFT logger bugs (#833)

* catch logger bugs

* indentation issues

* fix logger issues

* specify exception

* document reason for exception

* update exceptions

* disable callbacks when `logger=False`

Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu>
This commit is contained in:
Kevin Chen
2022-12-02 16:56:59 -05:00
committed by GitHub
parent 2501b86444
commit d213ae8f39

View File

@@ -573,9 +573,11 @@ class TransformersEstimator(BaseEstimator):
if data_collator_class:
kwargs = {
"model": self._model_init(), # need to set model, or there's ValueError: Expected input batch_size (..) to match target batch_size (..)
"model": self._model_init(),
# need to set model, or there's ValueError: Expected input batch_size (..) to match target batch_size (..)
"label_pad_token_id": -100, # pad with token id -100
"pad_to_multiple_of": 8, # pad to multiple of 8 because quote Transformers: "This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta)"
"pad_to_multiple_of": 8,
# pad to multiple of 8 because quote Transformers: "This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta)"
"tokenizer": self.tokenizer,
}
@@ -692,7 +694,6 @@ class TransformersEstimator(BaseEstimator):
# if gpu_per_trial == 0:
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
if tmp_cuda_visible_devices.count(",") != math.ceil(gpu_per_trial) - 1:
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
[str(x) for x in range(math.ceil(gpu_per_trial))]
)
@@ -2287,36 +2288,47 @@ class TemporalFusionTransformerEstimator(SKLearnEstimator):
monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min"
)
lr_logger = LearningRateMonitor() # log the learning rate
logger = TensorBoardLogger(
kwargs.get("log_dir", "lightning_logs")
) # logging results to a tensorboard
default_trainer_kwargs = dict(
gpus=kwargs.get("gpu_per_trial", [0])
if torch.cuda.is_available()
else None,
max_epochs=max_epochs,
gradient_clip_val=gradient_clip_val,
callbacks=[lr_logger, early_stop_callback],
logger=logger,
)
trainer = pl.Trainer(
**default_trainer_kwargs,
)
tft = TemporalFusionTransformer.from_dataset(
training,
**params,
lstm_layers=2, # 2 is mostly optimal according to documentation
output_size=7, # 7 quantiles by default
loss=QuantileLoss(),
log_interval=10, # uncomment for learning rate finder and otherwise, e.g. to 10 for logging every 10 batches
reduce_on_plateau_patience=4,
)
# fit network
trainer.fit(
tft,
train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader,
)
def _fit(log):
    """Build a Lightning trainer and fit a TemporalFusionTransformer.

    Args:
        log: a logger instance (e.g. TensorBoardLogger) to pass to the
            trainer, or ``False`` to disable logging entirely. When
            ``False``, callbacks are disabled as well — workaround for a
            ``ValueError`` raised from pytorch-forecasting's
            ``log_prediction()`` (pytorch-forecasting issue #1145).

    Returns:
        The fitted ``pl.Trainer`` instance.
    """
    # NOTE(review): closes over kwargs, max_epochs, gradient_clip_val,
    # lr_logger, early_stop_callback, training, params,
    # train_dataloader and val_dataloader from the enclosing fit() scope.
    default_trainer_kwargs = dict(
        # use the requested GPUs only when CUDA is actually available
        gpus=kwargs.get("gpu_per_trial", [0])
        if torch.cuda.is_available()
        else None,
        max_epochs=max_epochs,
        gradient_clip_val=gradient_clip_val,
        # NOTE(review): callbacks=False (rather than None/[]) when logging
        # is disabled — this also drops the checkpoint callback, yet the
        # caller later reads trainer.checkpoint_callback.best_model_path;
        # confirm that path is safe when log=False.
        callbacks=[lr_logger, early_stop_callback] if log else False,
        logger=log,
    )
    trainer = pl.Trainer(
        **default_trainer_kwargs,
    )
    tft = TemporalFusionTransformer.from_dataset(
        training,
        **params,
        lstm_layers=2,  # 2 is mostly optimal according to documentation
        output_size=7,  # 7 quantiles by default
        loss=QuantileLoss(),
        log_interval=10,
        # uncomment for learning rate finder and otherwise, e.g. to 10 for logging every 10 batches
        reduce_on_plateau_patience=4,
    )
    # fit network
    trainer.fit(
        tft,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
    )
    return trainer
try:
logger = TensorBoardLogger(
kwargs.get("log_dir", "lightning_logs")
) # logging results to a tensorboard
trainer = _fit(log=logger)
except ValueError:
# issue with pytorch forecasting model log_prediction() function
# pytorch-forecasting issue #1145
trainer = _fit(log=False)
best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
train_time = time.time() - current_time