Update model.py (#739)

* Update model.py

* Fix review comment
This commit is contained in:
Vijaya Lakshmi Venkatraman
2022-10-02 05:55:53 +05:30
committed by GitHub
parent 74557472ec
commit 0df54ec3f3

View File

@@ -2287,7 +2287,7 @@ class TemporalFusionTransformerEstimator(SKLearnEstimator):
kwargs.get("log_dir", "lightning_logs")
) # logging results to a tensorboard
default_trainer_kwargs = dict(
gpus=self._kwargs.get("gpu_per_trial", [0])
gpus=kwargs.get("gpu_per_trial", [0])
if torch.cuda.is_available()
else None,
max_epochs=max_epochs,