mirror of
https://github.com/Pythagora-io/gpt-pilot.git
synced 2026-01-10 05:27:54 -05:00
check API connection in parallel
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from asyncio import run
|
||||
@@ -62,7 +63,7 @@ async def run_project(sm: StateManager, ui: UIBase) -> bool:
|
||||
|
||||
async def llm_api_check(ui: UIBase) -> bool:
|
||||
"""
|
||||
Check whether the configured LLMs are reachable.
|
||||
Check whether the configured LLMs are reachable in parallel.
|
||||
|
||||
:param ui: UI we'll use to report any issues
|
||||
:return: True if all the LLMs are reachable.
|
||||
@@ -73,29 +74,46 @@ async def llm_api_check(ui: UIBase) -> bool:
|
||||
async def handler(*args, **kwargs):
|
||||
pass
|
||||
|
||||
success = True
|
||||
checked_llms: set[LLMProvider] = set()
|
||||
for llm_config in config.all_llms():
|
||||
if llm_config.provider in checked_llms:
|
||||
continue
|
||||
tasks = []
|
||||
|
||||
async def check_llm(llm_config):
|
||||
if llm_config.provider + llm_config.model in checked_llms:
|
||||
return True
|
||||
|
||||
checked_llms.add(llm_config.provider + llm_config.model)
|
||||
client_class = BaseLLMClient.for_provider(llm_config.provider)
|
||||
llm_client = client_class(llm_config, stream_handler=handler, error_handler=handler)
|
||||
try:
|
||||
await ui.send_message(
|
||||
f"API check for {llm_config.provider.value} {llm_config.model} !",
|
||||
source=pythagora_source,
|
||||
)
|
||||
resp = await llm_client.api_check()
|
||||
if not resp:
|
||||
success = False
|
||||
log.warning(f"API check for {llm_config.provider.value} failed.")
|
||||
return False
|
||||
else:
|
||||
await ui.send_message(
|
||||
f"DONE {llm_config.provider.value} {llm_config.model} !",
|
||||
source=pythagora_source,
|
||||
)
|
||||
log.info(f"API check for {llm_config.provider.value} succeeded.")
|
||||
checked_llms.add(llm_config.provider)
|
||||
return True
|
||||
except APIError as err:
|
||||
await ui.send_message(
|
||||
f"API check for {llm_config.provider.value} failed with: {err}",
|
||||
source=pythagora_source,
|
||||
)
|
||||
log.warning(f"API check for {llm_config.provider.value} failed with: {err}")
|
||||
success = False
|
||||
return False
|
||||
|
||||
for llm_config in config.all_llms():
|
||||
tasks.append(check_llm(llm_config))
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
success = all(results)
|
||||
|
||||
if not success:
|
||||
telemetry.set("end_result", "failure:api-error")
|
||||
|
||||
Reference in New Issue
Block a user