From 9109a8ae7f4f81f9d147a686b468f2df5d8bef0a Mon Sep 17 00:00:00 2001 From: Goran Peretin Date: Sat, 25 May 2024 19:01:39 +0200 Subject: [PATCH] Check LLM API connection before doing anything else. (#949) --- core/agents/orchestrator.py | 57 ------------------- core/cli/main.py | 91 +++++++++++++++++++++++++------ core/config/__init__.py | 7 +++ core/llm/base.py | 23 ++++++-- tests/agents/test_orchestrator.py | 48 +--------------- tests/cli/test_cli.py | 8 ++- 6 files changed, 108 insertions(+), 126 deletions(-) diff --git a/core/agents/orchestrator.py b/core/agents/orchestrator.py index 8d2c0dc1..4e2b5172 100644 --- a/core/agents/orchestrator.py +++ b/core/agents/orchestrator.py @@ -15,8 +15,6 @@ from core.agents.task_reviewer import TaskReviewer from core.agents.tech_lead import TechLead from core.agents.tech_writer import TechnicalWriter from core.agents.troubleshooter import Troubleshooter -from core.config import LLMProvider, get_config -from core.llm.convo import Convo from core.log import get_logger from core.telemetry import telemetry from core.ui.base import ProjectStage @@ -53,10 +51,6 @@ class Orchestrator(BaseAgent): await self.init_ui() await self.offline_changes_check() - llm_api_check = await self.test_llm_access() - if not llm_api_check: - return False - # TODO: consider refactoring this into two loop; the outer with one iteration per comitted step, # and the inner which runs the agents for the current step until they're done. This would simplify # handle_done() and let us do other per-step processing (eg. describing files) in between agent runs. @@ -78,57 +72,6 @@ class Orchestrator(BaseAgent): # TODO: rollback changes to "next" so they aren't accidentally committed? return True - async def test_llm_access(self) -> bool: - """ - Make sure the LLMs for all the defined agents are reachable. - - Each LLM provider is only checked once. - Returns True if the check for successful for all LLMs. - """ - - config = get_config() - defined_agents = config.agent.keys() - - convo = Convo() - convo.user( - " ".join( - [ - "This is a connection test. If you can see this,", - "please respond only with 'START' and nothing else.", - ] - ) - ) - - success = True - tested_llms: set[LLMProvider] = set() - for agent_name in defined_agents: - llm = self.get_llm(agent_name) - llm_config = config.llm_for_agent(agent_name) - - if llm_config.provider in tested_llms: - continue - - tested_llms.add(llm_config.provider) - provider_model_combo = f"{llm_config.provider.value} {llm_config.model}" - try: - resp = await llm(convo) - except Exception as err: - log.warning(f"API check for {provider_model_combo} failed: {err}") - success = False - await self.ui.send_message(f"Error connecting to the {provider_model_combo} API: {err}") - continue - - if resp and len(resp) > 0: - log.debug(f"API check for {provider_model_combo} passed.") - else: - log.warning(f"API check for {provider_model_combo} failed.") - await self.ui.send_message( - f"Error connecting to the {provider_model_combo} API. Please check your settings and internet connection." - ) - success = False - - return success - async def offline_changes_check(self): """ Check for changes outside of Pythagora. 
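Note: the `test_llm_access` method removed above pinged each provider once with a throwaway prompt and treated any non-empty reply as a pass; the next file moves that pattern into `core/cli/main.py` so it runs before any project state is touched. Below is a minimal, self-contained sketch of the pattern — `fake_ping`, `preflight`, and the model names are hypothetical stand-ins, not code from this patch:

```python
import asyncio


async def fake_ping(provider: str, model: str, prompt: str) -> str:
    """Hypothetical stand-in for a real LLM request; always answers here."""
    return "START"


async def preflight(agent_llms: list[tuple[str, str]]) -> bool:
    """Ping each provider once, even when several agents share it."""
    prompt = (
        "This is a connection test. If you can see this, "
        "please respond only with 'START' and nothing else."
    )
    ok = True
    checked: set[str] = set()
    for provider, model in agent_llms:
        if provider in checked:
            continue
        checked.add(provider)
        try:
            reply = await fake_ping(provider, model, prompt)
        except Exception as err:
            print(f"API check for {provider} {model} failed: {err}")
            ok = False
            continue
        if not reply:
            print(f"API check for {provider} {model} failed: empty response.")
            ok = False
    return ok


if __name__ == "__main__":
    # "openai" appears twice but is only pinged once.
    combos = [("openai", "gpt-4o"), ("openai", "gpt-4o-mini"), ("anthropic", "claude-3-5-sonnet")]
    print(asyncio.run(preflight(combos)))
```

Deduplicating on the provider keeps startup cost at one request per API, no matter how many agents share it.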
diff --git a/core/cli/main.py b/core/cli/main.py
index af8ddc2a..dcd086bb 100644
--- a/core/cli/main.py
+++ b/core/cli/main.py
@@ -4,9 +4,10 @@ from asyncio import run
 
 from core.agents.orchestrator import Orchestrator
 from core.cli.helpers import delete_project, init, list_projects, list_projects_json, load_project, show_config
+from core.config import LLMProvider, get_config
 from core.db.session import SessionManager
 from core.db.v0importer import LegacyDatabaseImporter
-from core.llm.base import APIError
+from core.llm.base import APIError, BaseLLMClient
 from core.log import get_logger
 from core.state.state_manager import StateManager
 from core.telemetry import telemetry
@@ -58,6 +59,46 @@ async def run_project(sm: StateManager, ui: UIBase) -> bool:
     return success
 
 
+async def llm_api_check(ui: UIBase) -> bool:
+    """
+    Check whether the configured LLMs are reachable.
+
+    :param ui: UI we'll use to report any issues.
+    :return: True if all the LLMs are reachable.
+    """
+
+    config = get_config()
+
+    async def handler(*args, **kwargs):
+        pass
+
+    success = True
+    checked_llms: set[LLMProvider] = set()
+    for llm_config in config.all_llms():
+        if llm_config.provider in checked_llms:
+            continue
+
+        checked_llms.add(llm_config.provider)
+        client_class = BaseLLMClient.for_provider(llm_config.provider)
+        llm_client = client_class(llm_config, stream_handler=handler, error_handler=handler)
+        try:
+            resp = await llm_client.api_check()
+            if not resp:
+                success = False
+                log.warning(f"API check for {llm_config.provider.value} failed.")
+            else:
+                log.info(f"API check for {llm_config.provider.value} succeeded.")
+        except APIError as err:
+            await ui.send_message(f"API check for {llm_config.provider.value} failed with: {err}")
+            log.warning(f"API check for {llm_config.provider.value} failed with: {err}")
+            success = False
+
+    if not success:
+        telemetry.set("end_result", "failure:api-error")
+
+    return success
+
+
 async def start_new_project(sm: StateManager, ui: UIBase) -> bool:
     """
     Start a new project.
@@ -74,6 +115,36 @@ async def start_new_project(sm: StateManager, ui: UIBase) -> bool:
     return project_state is not None
 
 
+async def run_pythagora_session(sm: StateManager, ui: UIBase, args: Namespace) -> bool:
+    """
+    Run a Pythagora session.
+
+    :param sm: State manager.
+    :param ui: User interface.
+    :param args: Command-line arguments.
+    :return: True if the application ran successfully, False otherwise.
+ """ + + if not await llm_api_check(ui): + return False + + if args.project or args.branch or args.step: + telemetry.set("is_continuation", True) + # FIXME: we should send the project stage and other runtime info to the UI + success = await load_project(sm, args.project, args.branch, args.step) + if not success: + return False + elif args.delete: + success = await delete_project(sm, args.delete) + return success + else: + success = await start_new_project(sm, ui) + if not success: + return False + + return await run_project(sm, ui) + + async def async_main( ui: UIBase, db: SessionManager, @@ -112,21 +182,10 @@ async def async_main( if not ui_started: return False - if args.project or args.branch or args.step: - telemetry.set("is_continuation", True) - # FIXME: we should send the project stage and other runtime info to the UI - success = await load_project(sm, args.project, args.branch, args.step) - if not success: - return False - elif args.delete: - success = await delete_project(sm, args.delete) - return success - else: - success = await start_new_project(sm, ui) - if not success: - return False - - return await run_project(sm, ui) + telemetry.start() + success = await run_pythagora_session(sm, ui, args) + await telemetry.send() + return success def run_pythagora(): diff --git a/core/config/__init__.py b/core/config/__init__.py index 69a63846..533689de 100644 --- a/core/config/__init__.py +++ b/core/config/__init__.py @@ -299,6 +299,13 @@ class Config(_StrictModel): provider_config = self.llm[agent_config.provider] return LLMConfig.from_provider_and_agent_configs(provider_config, agent_config) + def all_llms(self) -> list[LLMConfig]: + """ + Get configuration for all defined LLMs. + """ + + return [self.llm_for_agent(agent) for agent in self.agent] + class ConfigLoader: """ diff --git a/core/llm/base.py b/core/llm/base.py index 4a4a6bfa..656d72b4 100644 --- a/core/llm/base.py +++ b/core/llm/base.py @@ -192,7 +192,8 @@ class BaseLLMClient: wait_time = self.rate_limit_sleep(err) if wait_time: message = f"We've hit {self.config.provider.value} rate limit. Sleeping for {wait_time.seconds} seconds..." - await self.error_handler(LLMError.RATE_LIMITED, message) + if self.error_handler: + await self.error_handler(LLMError.RATE_LIMITED, message) await asyncio.sleep(wait_time.seconds) continue else: @@ -207,9 +208,10 @@ class BaseLLMClient: err_msg = err.response.json().get("error", {}).get("message", "Incorrect API key") if "[BricksLLM]" in err_msg: # We only want to show the key expired message if it's from Bricks - should_retry = await self.error_handler(LLMError.KEY_EXPIRED) - if should_retry: - continue + if self.error_handler: + should_retry = await self.error_handler(LLMError.KEY_EXPIRED) + if should_retry: + continue raise APIError(err_msg) from err except (openai.APIStatusError, anthropic.APIStatusError, groq.APIStatusError) as err: @@ -268,6 +270,19 @@ class BaseLLMClient: return response, request_log + async def api_check(self) -> bool: + """ + Perform an LLM API check. + + :return: True if the check was successful, False otherwise. + """ + + convo = Convo() + msg = "This is a connection test. If you can see this, please respond only with 'START' and nothing else." 
+        convo.user(msg)
+        resp, _log = await self(convo)
+        return bool(resp)
+
     @staticmethod
     def for_provider(provider: LLMProvider) -> type["BaseLLMClient"]:
         """
diff --git a/tests/agents/test_orchestrator.py b/tests/agents/test_orchestrator.py
index 24c24d83..416835b4 100644
--- a/tests/agents/test_orchestrator.py
+++ b/tests/agents/test_orchestrator.py
@@ -1,55 +1,9 @@
-from unittest.mock import AsyncMock, Mock, patch
+from unittest.mock import AsyncMock, Mock
 
 import pytest
 
 from core.agents.orchestrator import Orchestrator
 from core.state.state_manager import StateManager
-from core.ui.console import PlainConsoleUI
-
-
-@pytest.mark.asyncio
-@patch("core.agents.base.BaseLLMClient")
-@patch("core.state.state_manager.StateManager")
-async def test_check_llms_are_accessible(mock_StateManager, mock_BaseLLMClient):
-    mock_sm = mock_StateManager.return_value
-    mock_sm.log_llm_request = AsyncMock()
-
-    mock_OpenAIClient = mock_BaseLLMClient.for_provider.return_value
-    mock_client = AsyncMock(return_value=("START", "log"))
-    mock_OpenAIClient.return_value = mock_client
-
-    orca = Orchestrator(mock_sm, PlainConsoleUI())
-    assert await orca.test_llm_access()
-
-
-@pytest.mark.asyncio
-@patch("core.agents.base.BaseLLMClient")
-@patch("core.state.state_manager.StateManager")
-async def test_check_llms_returns_fail_if_one_fails(mock_StateManager, mock_BaseLLMClient):
-    mock_sm = mock_StateManager.return_value
-    mock_sm.log_llm_request = AsyncMock()
-
-    mock_OpenAIClient = mock_BaseLLMClient.for_provider.return_value
-    mock_client = AsyncMock(return_value=(None, "log"))
-    mock_OpenAIClient.return_value = mock_client
-
-    orca = Orchestrator(mock_sm, PlainConsoleUI())
-    assert await orca.test_llm_access() is False
-
-
-@pytest.mark.asyncio
-@patch("core.agents.base.BaseLLMClient")
-@patch("core.state.state_manager.StateManager")
-async def test_check_llms_returns_fail_if_llm_throws_exception(mock_StateManager, mock_BaseLLMClient):
-    mock_sm = mock_StateManager.return_value
-    mock_sm.log_llm_request = AsyncMock()
-
-    mock_OpenAIClient = mock_BaseLLMClient.for_provider.return_value
-    mock_client = AsyncMock(side_effect=ValueError("Invalid API key"))
-    mock_OpenAIClient.return_value = mock_client
-
-    orca = Orchestrator(mock_sm, PlainConsoleUI())
-    assert await orca.test_llm_access() is False
 
 
 @pytest.mark.asyncio
diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py
index 3ca5e910..ba4cc488 100644
--- a/tests/cli/test_cli.py
+++ b/tests/cli/test_cli.py
@@ -287,8 +287,10 @@ def test_init(tmp_path):
         ([], True, True),
     ],
 )
+@patch("core.cli.main.llm_api_check")
 @patch("core.cli.main.Orchestrator")
-async def test_main(mock_Orchestrator, args, run_orchestrator, retval, tmp_path):
+async def test_main(mock_Orchestrator, mock_llm_check, args, run_orchestrator, retval, tmp_path):
+    mock_llm_check.return_value = True
     config_file = write_test_config(tmp_path)
 
     class MockArgumentParser(ArgumentParser):
@@ -313,8 +315,10 @@
 
 @pytest.mark.asyncio
+@patch("core.cli.main.llm_api_check")
 @patch("core.cli.main.Orchestrator")
-async def test_main_handles_crash(mock_Orchestrator, tmp_path, capsys):
+async def test_main_handles_crash(mock_Orchestrator, mock_llm_check, tmp_path, capsys):
+    mock_llm_check.return_value = True
     config_file = write_test_config(tmp_path)
 
     class MockArgumentParser(ArgumentParser):
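Note: taken together, the new pieces compose into a standalone pre-flight check. The sketch below is hedged: it uses only names visible in this patch (`get_config`, `Config.all_llms`, `BaseLLMClient.for_provider`, `api_check`, `APIError`, and the `stream_handler`/`error_handler` constructor arguments), but `check_providers` is a hypothetical helper, and running it outside the normal CLI bootstrap — with a config already loadable via `get_config()` — is an assumption:

```python
import asyncio

from core.config import get_config
from core.llm.base import APIError, BaseLLMClient


async def check_providers() -> int:
    """Run one api_check() per configured provider; exit 0 only if all pass."""
    config = get_config()

    async def noop(*args, **kwargs):
        # The check runs before any UI exists, so both callbacks do nothing.
        pass

    ok = True
    checked = set()
    for llm_config in config.all_llms():
        if llm_config.provider in checked:
            continue
        checked.add(llm_config.provider)

        client_class = BaseLLMClient.for_provider(llm_config.provider)
        client = client_class(llm_config, stream_handler=noop, error_handler=noop)
        try:
            if not await client.api_check():
                print(f"API check for {llm_config.provider.value} failed.")
                ok = False
        except APIError as err:
            print(f"API check for {llm_config.provider.value} failed with: {err}")
            ok = False

    return 0 if ok else 1


if __name__ == "__main__":
    raise SystemExit(asyncio.run(check_providers()))
```

The two failure branches mirror `llm_api_check()` above: `api_check()` raises `APIError` on hard failures and returns a falsy value on an empty response, so a caller has to handle both.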