Pass config everywhere in order to get rid of singleton (#4666)

Signed-off-by: Merwane Hamadi <merwanehamadi@gmail.com>
This commit is contained in:
merwanehamadi
2023-06-18 19:05:41 -07:00
committed by GitHub
parent 096d27f342
commit a7f805604c
44 changed files with 323 additions and 300 deletions

View File

@@ -22,7 +22,7 @@ def get_embedding(input: list[str] | list[TText]) -> list[Embedding]:
def get_embedding(
input: str | TText | list[str] | list[TText],
input: str | TText | list[str] | list[TText], config: Config
) -> Embedding | list[Embedding]:
"""Get an embedding from the ada model.
@@ -33,7 +33,6 @@ def get_embedding(
Returns:
List[float]: The embedding.
"""
cfg = Config()
multiple = isinstance(input, list) and all(not isinstance(i, int) for i in input)
if isinstance(input, str):
@@ -41,22 +40,22 @@ def get_embedding(
elif multiple and isinstance(input[0], str):
input = [text.replace("\n", " ") for text in input]
model = cfg.embedding_model
if cfg.use_azure:
kwargs = {"engine": cfg.get_azure_deployment_id_for_model(model)}
model = config.embedding_model
if config.use_azure:
kwargs = {"engine": config.get_azure_deployment_id_for_model(model)}
else:
kwargs = {"model": model}
logger.debug(
f"Getting embedding{f's for {len(input)} inputs' if multiple else ''}"
f" with model '{model}'"
+ (f" via Azure deployment '{kwargs['engine']}'" if cfg.use_azure else "")
+ (f" via Azure deployment '{kwargs['engine']}'" if config.use_azure else "")
)
embeddings = iopenai.create_embedding(
input,
**kwargs,
api_key=cfg.openai_api_key,
api_key=config.openai_api_key,
).data
if not multiple: