diff --git a/azure.yaml.template b/azure.yaml.template index 852645ca0d..74ca797b2d 100644 --- a/azure.yaml.template +++ b/azure.yaml.template @@ -1,3 +1,4 @@ +azure_api_type: azure_ad azure_api_base: your-base-url-for-azure azure_api_version: api-version-for-azure azure_model_map: diff --git a/scripts/config.py b/scripts/config.py index 82f5851462..e966cce258 100644 --- a/scripts/config.py +++ b/scripts/config.py @@ -45,13 +45,12 @@ class Config(metaclass=Singleton): self.openai_api_key = os.getenv("OPENAI_API_KEY") self.temperature = float(os.getenv("TEMPERATURE", "1")) - self.use_azure = False self.use_azure = os.getenv("USE_AZURE") == 'True' self.execute_local_commands = os.getenv('EXECUTE_LOCAL_COMMANDS', 'False') == 'True' if self.use_azure: self.load_azure_config() - openai.api_type = "azure" + openai.api_type = self.openai_api_type openai.api_base = self.openai_api_base openai.api_version = self.openai_api_version @@ -121,8 +120,9 @@ class Config(metaclass=Singleton): config_params = yaml.load(file, Loader=yaml.FullLoader) except FileNotFoundError: config_params = {} - self.openai_api_base = config_params.get("azure_api_base", "") - self.openai_api_version = config_params.get("azure_api_version", "") + self.openai_api_type = os.getenv("OPENAI_API_TYPE", config_params.get("azure_api_type", "azure")) + self.openai_api_base = os.getenv("OPENAI_AZURE_API_BASE", config_params.get("azure_api_base", "")) + self.openai_api_version = os.getenv("OPENAI_AZURE_API_VERSION", config_params.get("azure_api_version", "")) self.azure_model_to_deployment_id_map = config_params.get("azure_model_map", []) def set_continuous_mode(self, value: bool):