check API connection in parallel

This commit is contained in:
LeonOstrez
2024-09-01 19:27:19 +02:00
parent 544b40d6a1
commit 1ede2a5826

View File

@@ -1,3 +1,4 @@
import asyncio
import sys
from argparse import Namespace
from asyncio import run
@@ -62,7 +63,7 @@ async def run_project(sm: StateManager, ui: UIBase) -> bool:
async def llm_api_check(ui: UIBase) -> bool:
"""
Check whether the configured LLMs are reachable.
Check whether the configured LLMs are reachable in parallel.
:param ui: UI we'll use to report any issues
:return: True if all the LLMs are reachable.
@@ -73,29 +74,46 @@ async def llm_api_check(ui: UIBase) -> bool:
async def handler(*args, **kwargs):
pass
success = True
checked_llms: set[LLMProvider] = set()
for llm_config in config.all_llms():
if llm_config.provider in checked_llms:
continue
tasks = []
async def check_llm(llm_config):
if llm_config.provider + llm_config.model in checked_llms:
return True
checked_llms.add(llm_config.provider + llm_config.model)
client_class = BaseLLMClient.for_provider(llm_config.provider)
llm_client = client_class(llm_config, stream_handler=handler, error_handler=handler)
try:
await ui.send_message(
f"API check for {llm_config.provider.value} {llm_config.model} !",
source=pythagora_source,
)
resp = await llm_client.api_check()
if not resp:
success = False
log.warning(f"API check for {llm_config.provider.value} failed.")
return False
else:
await ui.send_message(
f"DONE {llm_config.provider.value} {llm_config.model} !",
source=pythagora_source,
)
log.info(f"API check for {llm_config.provider.value} succeeded.")
checked_llms.add(llm_config.provider)
return True
except APIError as err:
await ui.send_message(
f"API check for {llm_config.provider.value} failed with: {err}",
source=pythagora_source,
)
log.warning(f"API check for {llm_config.provider.value} failed with: {err}")
success = False
return False
for llm_config in config.all_llms():
tasks.append(check_llm(llm_config))
results = await asyncio.gather(*tasks)
success = all(results)
if not success:
telemetry.set("end_result", "failure:api-error")