add a list_compatible_loras() method

This commit is contained in:
Lincoln Stein
2023-04-13 00:11:26 -04:00
parent 6dfbd1c677
commit afa3cdce27

View File

@@ -55,9 +55,18 @@ class LoraManager:
return conditions
return None
def list_compatible_loras(self)->Dict[str, Path]:
'''
List all the LoRAs in the global lora directory that
are compatible with the current model. Return a dictionary
of the lora basename and its path.
'''
model_length = self.kohya.text_encoder.get_input_embeddings().weight.data[0].shape[0]
return self.list_loras(model_length)
@classmethod
def list_loras(self, token_vector_length:int=None)->Dict[str, Path]:
@staticmethod
def list_loras(token_vector_length:int=None)->Dict[str, Path]:
'''List the LoRAS in the global lora directory.
If token_vector_length is provided, then only return
LoRAS that have the indicated length:
@@ -79,3 +88,4 @@ class LoraManager:
models_found[name]=Path(root,x) # conditional on the base model matching
return models_found