diff --git a/train.py b/train.py
index f8aebd3..f6d252f 100644
--- a/train.py
+++ b/train.py
@@ -375,6 +375,9 @@ try:
     if training_run_args.test_dataset:
         trainer.evaluate_all()
 
+    if trainer.is_fsdp_enabled:
+        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
+
     if training_run_args.use_lora and training_run_args.lora_merge:
         trainer.save_model()  # save lora
 
@@ -388,7 +391,7 @@ try:
         tokenizer.save_pretrained(model_dir)
 
 except Exception as ex:
-    if torch.cuda.device_count() > 1:
+    if trainer.is_fsdp_enabled:
         raise ex  # this doesn't play nice with FSDP so don't even try
 
     print("Something bad happened! Try and save it?")