load t5 model in the same format as it is saved, seems to load as float32 on Macs

This commit is contained in:
David Burnett
2024-10-17 13:25:25 +01:00
committed by psychedelicious
parent d85733f22b
commit 54b3aa1d01

View File

@@ -175,7 +175,7 @@ class T5EncoderCheckpointModel(ModelLoader):
case SubModelType.Tokenizer2:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2")
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype='auto')
raise ValueError(
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"