Changes for Azure embedding handling

This commit is contained in:
Peter Edwards
2023-04-11 13:45:37 +02:00
parent 5a6e565c52
commit 9d33a75083
4 changed files with 9 additions and 2 deletions

View File

@@ -49,6 +49,7 @@ class Config(metaclass=Singleton):
self.openai_api_base = os.getenv("OPENAI_AZURE_API_BASE")
self.openai_api_version = os.getenv("OPENAI_AZURE_API_VERSION")
self.openai_deployment_id = os.getenv("OPENAI_AZURE_DEPLOYMENT_ID")
self.openai_embedding_deployment_id = os.getenv("OPENAI_AZURE_EMBEDDING_DEPLOYMENT_ID")
openai.api_type = "azure"
openai.api_base = self.openai_api_base
openai.api_version = self.openai_api_version

View File

@@ -1,12 +1,17 @@
"""Base class for memory providers."""
import abc
from config import AbstractSingleton
from config import Config
import openai
cfg = Config()
def get_ada_embedding(text):
text = text.replace("\n", " ")
return openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"]
if cfg.use_azure:
return openai.Embedding.create(input=[text], engine=cfg.openai_embedding_deployment_id)["data"][0]["embedding"]
else:
return openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"]
class MemoryProviderSingleton(AbstractSingleton):