Mirror of https://github.com/Pythagora-io/gpt-pilot.git (synced 2026-01-09 21:27:53 -05:00).
@@ -117,6 +117,9 @@ class BaseAgent:
             extra_info=extra_info,
             placeholder=placeholder,
         )
+        # Store the access token in the state manager
+        if hasattr(response, "access_token") and response.access_token:
+            self.state_manager.update_access_token(response.access_token)
         await self.state_manager.log_user_input(question, response)
         return response
@@ -188,7 +191,13 @@ class BaseAgent:
         llm_config = config.llm_for_agent(name)
         client_class = BaseLLMClient.for_provider(llm_config.provider)
         stream_handler = self.stream_handler if stream_output else None
-        llm_client = client_class(llm_config, stream_handler=stream_handler, error_handler=self.error_handler)
+        llm_client = client_class(
+            llm_config,
+            stream_handler=stream_handler,
+            error_handler=self.error_handler,
+            ui=self.ui,
+            state_manager=self.state_manager,
+        )

         async def client(convo, **kwargs) -> Any:
             """
@@ -97,6 +97,7 @@ def parse_arguments() -> Namespace:
         --extension-version: Version of the VSCode extension, if used
         --no-check: Disable initial LLM API check
         --use-git: Use Git for version control
+        --access-token: Access token
     :return: Parsed arguments object.
     """
     version = get_version()
@@ -138,6 +139,7 @@ def parse_arguments() -> Namespace:
     parser.add_argument("--extension-version", help="Version of the VSCode extension", required=False)
     parser.add_argument("--no-check", help="Disable initial LLM API check", action="store_true")
     parser.add_argument("--use-git", help="Use Git for version control", action="store_true", required=False)
+    parser.add_argument("--access-token", help="Access token", required=False)

     return parser.parse_args()
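
A hedged sketch of how the new flag is consumed after parsing; the invocation shape is an assumption, and `sm` is the StateManager that async_main() creates further down in this diff:

# Hedged usage sketch, not part of the diff.
# Assumed invocation: python main.py --access-token <TOKEN> --no-check
args = parse_arguments()
if args.access_token:
    sm.update_access_token(args.access_token)  # mirrors the async_main() change below
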
@@ -88,7 +88,7 @@ async def run_project(sm: StateManager, ui: UIBase, args) -> bool:
     return success


-async def llm_api_check(ui: UIBase) -> bool:
+async def llm_api_check(ui: UIBase, sm: StateManager) -> bool:
     """
     Check whether the configured LLMs are reachable in parallel.
@@ -110,7 +110,7 @@ async def llm_api_check(ui: UIBase) -> bool:

         checked_llms.add(llm_config.provider + llm_config.model)
         client_class = BaseLLMClient.for_provider(llm_config.provider)
-        llm_client = client_class(llm_config, stream_handler=handler, error_handler=handler)
+        llm_client = client_class(llm_config, stream_handler=handler, error_handler=handler, ui=ui, state_manager=sm)
         try:
             resp = await llm_client.api_check()
             if not resp:
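
The docstring above promises the checks run in parallel. A hedged sketch of how this loop can be gathered; only the constructor call is taken verbatim from the diff, while the function name and the gather wrapper are assumptions:

import asyncio

async def check_all(configs, handler, ui, sm):
    # One client per unique provider+model, all checked concurrently.
    checked_llms, tasks = set(), []
    for llm_config in configs:
        key = llm_config.provider + llm_config.model
        if key in checked_llms:
            continue
        checked_llms.add(key)
        client_class = BaseLLMClient.for_provider(llm_config.provider)
        llm_client = client_class(llm_config, stream_handler=handler, error_handler=handler, ui=ui, state_manager=sm)
        tasks.append(llm_client.api_check())
    results = await asyncio.gather(*tasks, return_exceptions=True)
    return all(r is True for r in results)
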
@@ -224,7 +224,7 @@ async def run_pythagora_session(sm: StateManager, ui: UIBase, args: Namespace):
     """

     if not args.no_check:
-        if not await llm_api_check(ui):
+        if not await llm_api_check(ui, sm):
             await ui.send_message(
                 "Pythagora cannot start because the LLM API is not reachable.",
                 source=pythagora_source,
@@ -283,6 +283,8 @@ async def async_main(
         telemetry.set("extension_version", args.extension_version)

     sm = StateManager(db, ui)
+    if args.access_token:
+        sm.update_access_token(args.access_token)
     ui_started = await ui.start()
     if not ui_started:
         return False
@@ -326,17 +326,17 @@ class Config(_StrictModel):
         default={
             DEFAULT_AGENT_NAME: AgentLLMConfig(),
             CHECK_LOGS_AGENT_NAME: AgentLLMConfig(
-                provider=LLMProvider.ANTHROPIC,
+                provider=LLMProvider.OPENAI,
                 model="claude-3-5-sonnet-20241022",
                 temperature=0.5,
             ),
             CODE_MONKEY_AGENT_NAME: AgentLLMConfig(
-                provider=LLMProvider.ANTHROPIC,
+                provider=LLMProvider.OPENAI,
                 model="claude-3-5-sonnet-20241022",
                 temperature=0.0,
             ),
             CODE_REVIEW_AGENT_NAME: AgentLLMConfig(
-                provider=LLMProvider.ANTHROPIC,
+                provider=LLMProvider.OPENAI,
                 model="claude-3-5-sonnet-20240620",
                 temperature=0.0,
             ),
@@ -346,7 +346,7 @@ class Config(_StrictModel):
                 temperature=0.0,
             ),
             FRONTEND_AGENT_NAME: AgentLLMConfig(
-                provider=LLMProvider.ANTHROPIC,
+                provider=LLMProvider.OPENAI,
                 model="claude-3-5-sonnet-20241022",
                 temperature=0.0,
             ),
@@ -356,7 +356,7 @@ class Config(_StrictModel):
                 temperature=0.5,
             ),
             PARSE_TASK_AGENT_NAME: AgentLLMConfig(
-                provider=LLMProvider.ANTHROPIC,
+                provider=LLMProvider.OPENAI,
                 model="claude-3-5-sonnet-20241022",
                 temperature=0.0,
             ),
@@ -366,27 +366,27 @@
                 temperature=0.0,
             ),
             TASK_BREAKDOWN_AGENT_NAME: AgentLLMConfig(
-                provider=LLMProvider.ANTHROPIC,
+                provider=LLMProvider.OPENAI,
                 model="claude-3-5-sonnet-20241022",
                 temperature=0.5,
             ),
             TECH_LEAD_PLANNING: AgentLLMConfig(
-                provider=LLMProvider.ANTHROPIC,
+                provider=LLMProvider.OPENAI,
                 model="claude-3-5-sonnet-20240620",
                 temperature=0.5,
             ),
             TECH_LEAD_EPIC_BREAKDOWN: AgentLLMConfig(
-                provider=LLMProvider.ANTHROPIC,
+                provider=LLMProvider.OPENAI,
                 model="claude-3-5-sonnet-20241022",
                 temperature=0.5,
             ),
             TROUBLESHOOTER_BUG_REPORT: AgentLLMConfig(
-                provider=LLMProvider.ANTHROPIC,
+                provider=LLMProvider.OPENAI,
                 model="claude-3-5-sonnet-20240620",
                 temperature=0.5,
             ),
             TROUBLESHOOTER_GET_RUN_COMMAND: AgentLLMConfig(
-                provider=LLMProvider.ANTHROPIC,
+                provider=LLMProvider.OPENAI,
                 model="claude-3-5-sonnet-20240620",
                 temperature=0.0,
             ),
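
Every change in the four config hunks above is the same one-line flip, spelled out here for a single agent. This is a sketch with names taken from the diff; the routing rationale is not stated in the diff itself:

# What the new default amounts to for one agent (illustrative constant name).
CODE_MONKEY_DEFAULT = AgentLLMConfig(
    provider=LLMProvider.OPENAI,          # was LLMProvider.ANTHROPIC
    model="claude-3-5-sonnet-20241022",   # model string unchanged; presumably served through an OpenAI-compatible endpoint (assumption)
    temperature=0.0,
)
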
@@ -1,6 +1,7 @@
 import asyncio
 import datetime
 import json
+import sys
 from enum import Enum
 from time import time
 from typing import Any, Callable, Optional, Tuple
@@ -11,6 +12,8 @@ from core.config import LLMConfig, LLMProvider
 from core.llm.convo import Convo
 from core.llm.request_log import LLMRequestLog, LLMRequestStatus
 from core.log import get_logger
+from core.state.state_manager import StateManager
+from core.ui.base import UIBase, pythagora_source

 log = get_logger(__name__)

@@ -48,9 +51,11 @@ class BaseLLMClient:
     def __init__(
         self,
         config: LLMConfig,
+        state_manager: StateManager,
         *,
         stream_handler: Optional[Callable] = None,
         error_handler: Optional[Callable] = None,
+        ui: Optional[UIBase] = None,
     ):
         """
         Initialize the client with the given configuration.
@@ -61,6 +66,8 @@ class BaseLLMClient:
         self.config = config
         self.stream_handler = stream_handler
         self.error_handler = error_handler
+        self.ui = ui
+        self.state_manager = state_manager
         self._init_client()

     def _init_client(self):
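
With the two hunks above, state_manager becomes a required constructor argument and ui an optional keyword-only one. A hedged construction sketch; the surrounding setup of db, ui and config is assumed and not part of the diff:

sm = StateManager(db, ui)                        # construction of db/ui elided
client_class = BaseLLMClient.for_provider(config.provider)
llm_client = client_class(
    config,
    state_manager=sm,      # now required
    stream_handler=None,
    error_handler=None,
    ui=ui,                 # optional; used to surface token/billing prompts to the user
)
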
@@ -186,6 +193,30 @@ class BaseLLMClient:
             response = None

             try:
+                access_token = self.state_manager.get_access_token()
+
+                if access_token:
+                    # Store the original client
+                    original_client = self.client
+
+                    # Copy client based on its type
+                    if isinstance(original_client, openai.AsyncOpenAI):
+                        self.client = openai.AsyncOpenAI(
+                            api_key=original_client.api_key,
+                            base_url=original_client.base_url,
+                            default_headers={"Authorization": f"Bearer {access_token}"},
+                        )
+                    elif isinstance(original_client, anthropic.AsyncAnthropic):
+                        # Create new Anthropic client with custom headers
+                        self.client = anthropic.AsyncAnthropic(
+                            api_key=original_client.api_key,
+                            base_url=original_client.base_url,
+                            default_headers={"Authorization": f"Bearer {access_token}"},
+                        )
+                    else:
+                        # Handle other client types or raise exception
+                        raise ValueError(f"Unsupported client type: {type(original_client)}")
+
                 response, prompt_tokens, completion_tokens = await self._make_request(
                     convo,
                     temperature=temperature,
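
The hunk above only shows the swap to a token-bearing client; presumably the original client is put back once the request completes, which is not visible in this hunk. A hedged sketch of that surrounding pattern (authed_client is a placeholder name, not from the diff):

# Hypothetical surrounding pattern; only the header-override construction is shown in the diff.
original_client = self.client
try:
    self.client = authed_client   # the re-created client carrying "Authorization: Bearer <access_token>"
    response, prompt_tokens, completion_tokens = await self._make_request(convo, temperature=temperature)
finally:
    self.client = original_client  # assumption: restore the unmodified client afterwards
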
@@ -244,6 +275,44 @@
                 # so we can't be certain that's the problem in Anthropic case.
                 # Here we try to detect that and tell the user what happened.
                 log.info(f"API status error: {err}")
+                if getattr(err, "status_code", None) in (401, 403):
+                    if self.ui:
+                        try:
+                            await self.ui.send_message("Token expired")
+                            sys.exit(0)
+                            # TODO implement this to not crash in parallel
+                            # access_token = await self.ui.send_token_expired()
+                            # self.state_manager.update_access_token(access_token)
+                            # continue
+                        except Exception:
+                            raise APIError("Token expired")
+
+                if getattr(err, "status_code", None) == 400 and getattr(err, "message", None) == "not_enough_tokens":
+                    if self.ui:
+                        try:
+                            await self.ui.ask_question(
+                                "",
+                                buttons={},
+                                buttons_only=True,
+                                extra_info="not_enough_tokens",
+                                source=pythagora_source,
+                            )
+                            sys.exit(0)
+                            # TODO implement this to not crash in parallel
+                            # user_response = await self.ui.ask_question(
+                            #     'Not enough tokens left, please top up your account and press "Continue".',
+                            #     buttons={"continue": "Continue", "exit": "Exit"},
+                            #     buttons_only=True,
+                            #     extra_info="not_enough_tokens",
+                            #     source=pythagora_source,
+                            # )
+                            # if user_response.button == "continue":
+                            #     continue
+                            # else:
+                            #     raise APIError("Not enough tokens left")
+                        except Exception:
+                            raise APIError("Not enough tokens left")
+
                 try:
                     if hasattr(err, "response"):
                         if err.response.headers.get("Content-Type", "").startswith("application/json"):
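
For now a 401/403 only notifies the user and exits; the commented-out TODO sketches the intended non-crashing flow. Pulled together, and still hypothetical because it is reconstructed from the comments above, the retry-loop body would look roughly like:

# Hypothetical future handling, assembled from the TODO comments; `continue` refers to the
# surrounding retry loop in __call__.
if getattr(err, "status_code", None) in (401, 403) and self.ui:
    access_token = await self.ui.send_token_expired()      # ask the extension for a fresh token
    self.state_manager.update_access_token(access_token)   # remember it for the re-created client
    continue                                                # retry the request with the new token
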
@@ -53,6 +53,7 @@ class StateManager:
         self.git_available = False
         self.git_used = False
         self.options = {}
+        self.access_token = None

     @asynccontextmanager
     async def db_blocker(self):
@@ -754,5 +755,17 @@ class StateManager:

         return lines

+    def update_access_token(self, access_token: str):
+        """
+        Store the access token in the state manager.
+        """
+        self.access_token = access_token
+
+    def get_access_token(self) -> Optional[str]:
+        """
+        Get the access token from the state manager.
+        """
+        return self.access_token or None
+

 __all__ = ["StateManager"]
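
A tiny usage sketch of the new accessor pair; behaviour follows directly from the two methods above, while the StateManager construction is assumed from async_main() earlier in this diff:

sm = StateManager(db, ui)               # assumed constructor arguments
sm.update_access_token("pyth-token")
assert sm.get_access_token() == "pyth-token"

sm.access_token = ""                    # falsy values fall through to None because of the `or None`
assert sm.get_access_token() is None
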
@@ -80,11 +80,13 @@ class UserInput(BaseModel):
     * `text`: User-provided text (if any).
     * `button`: Name (key) of the button the user selected (if any).
     * `cancelled`: Whether the user cancelled the input.
+    * `access_token`: Access token (if any).
     """

     text: Optional[str] = None
     button: Optional[str] = None
     cancelled: bool = False
+    access_token: Optional[str] = None


 class UIBase:
@@ -141,6 +143,12 @@ class UIBase:
         """
         raise NotImplementedError()

+    async def send_token_expired(self):
+        """
+        Send the token expired message.
+        """
+        raise NotImplementedError()
+
     async def send_app_finished(
         self,
         app_id: Optional[str] = None,
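
On the consuming side, any answer can now carry a token. A hedged sketch mirroring the BaseAgent change at the top of this diff; the question text and buttons are illustrative only:

response = await ui.ask_question("Continue?", buttons={"yes": "Yes", "no": "No"})
if response.access_token:
    state_manager.update_access_token(response.access_token)
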
@@ -46,6 +46,9 @@ class PlainConsoleUI(UIBase):
         if message:
             await self.send_message(message)

+    async def send_token_expired(self):
+        await self.send_message("Access token expired")
+
     async def send_app_finished(
         self,
         app_id: Optional[str] = None,
@@ -55,6 +55,7 @@ class MessageType(str, Enum):
     TEST_INSTRUCTIONS = "testInstructions"
     KNOWLEDGE_BASE_UPDATE = "updatedKnowledgeBase"
     STOP_APP = "stopApp"
+    TOKEN_EXPIRED = "tokenExpired"


 class Message(BaseModel):
@@ -65,6 +66,9 @@
     * `type`: Message type (always "response" for VSC server responses)
     * `category`: Message category (eg. "agent:product-owner"), optional
     * `content`: Message content (eg. "Hello, how are you?"), optional
+    * `extra_info`: Additional information (eg. "This is a hint"), optional
+    * `placeholder`: Placeholder for user input, optional
+    * `access_token`: Access token for user input, optional
     """

     type: MessageType
@@ -74,6 +78,7 @@
     extra_info: Optional[str] = None
     content: Union[str, dict, None] = None
     placeholder: Optional[str] = None
+    accessToken: Optional[str] = None

     def to_bytes(self) -> bytes:
         """
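
A hedged round-trip sketch for the new message type and field; only the names are taken from the diff, and the reply handling is what IPCClientUI.send_token_expired below relies on:

# Outbound: core signals the extension that the Pythagora access token has expired.
payload = Message(type=MessageType.TOKEN_EXPIRED).to_bytes()

# Inbound: the extension's reply Message is expected to carry the fresh token in `accessToken`,
# which send_token_expired() returns after self._receive().
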
@@ -228,6 +233,11 @@ class IPCClientUI(UIBase):
     async def send_key_expired(self, message: Optional[str] = None):
         await self._send(MessageType.KEY_EXPIRED)

+    async def send_token_expired(self):
+        await self._send(MessageType.TOKEN_EXPIRED)
+        response = await self._receive()
+        return response.accessToken
+
     async def send_app_finished(
         self,
         app_id: Optional[str] = None,
@@ -337,6 +347,9 @@
         )

         response = await self._receive()
+
+        access_token = response.accessToken
+
         answer = response.content.strip()
         if answer == "exitPythagoraCore":
             raise KeyboardInterrupt()
@@ -347,17 +360,17 @@
         if buttons:
             # Answer matches one of the buttons (or maybe the default if it's a button name)
             if answer in buttons:
-                return UserInput(button=answer, text=None)
+                return UserInput(button=answer, text=None, access_token=access_token)
             # VSCode extension only deals with values so we need to check them as well
             value2key = {v: k for k, v in buttons.items()}
             if answer in value2key:
-                return UserInput(button=value2key[answer], text=None)
+                return UserInput(button=value2key[answer], text=None, access_token=access_token)

         if answer or allow_empty:
-            return UserInput(button=None, text=answer)
+            return UserInput(button=None, text=answer, access_token=access_token)

         # Empty answer which we don't allow, treat as user cancelled the input
-        return UserInput(cancelled=True)
+        return UserInput(cancelled=True, access_token=access_token)

     async def send_project_stage(self, data: dict):
         await self._send(MessageType.INFO, content=json.dumps(data))
@@ -46,6 +46,9 @@ class VirtualUI(UIBase):
     async def send_key_expired(self, message: Optional[str] = None):
         pass

+    async def send_token_expired(self):
+        pass
+
     async def send_app_finished(
         self,
         app_id: Optional[str] = None,
@@ -58,6 +58,7 @@ def test_parse_arguments(mock_ArgumentParser):
         "--extension-version",
         "--no-check",
         "--use-git",
+        "--access-token",
     }

     parser.parse_args.assert_called_once_with()
@@ -7,6 +7,7 @@ from core.config import LLMConfig
 from core.llm.base import APIError
 from core.llm.convo import Convo
 from core.llm.openai_client import OpenAIClient
+from core.state.state_manager import StateManager


 async def mock_response_generator(*content):
@@ -17,15 +18,28 @@ async def mock_response_generator(*content):


 @pytest.mark.asyncio
+@patch("core.cli.helpers.StateManager")
 @patch("core.llm.openai_client.AsyncOpenAI")
-async def test_openai_calls_gpt(mock_AsyncOpenAI):
+async def test_openai_calls_gpt(mock_AsyncOpenAI, mock_state_manager):
     cfg = LLMConfig(model="gpt-4-turbo")
     convo = Convo("system hello").user("user hello")

     # Create AsyncMock for the chat.completions.create method
     stream = AsyncMock(return_value=mock_response_generator("hello", None, "world"))
-    mock_AsyncOpenAI.return_value.chat.completions.create = stream

-    llm = OpenAIClient(cfg)
+    # Set up the complete mock chain
+    mock_chat = AsyncMock()
+    mock_completions = AsyncMock()
+    mock_completions.create = stream
+    mock_chat.completions = mock_completions
+
+    # Configure the AsyncOpenAI mock
+    mock_client = AsyncMock()
+    mock_client.chat = mock_chat
+    mock_AsyncOpenAI.return_value = mock_client
+
+    sm = StateManager(mock_state_manager)
+    llm = OpenAIClient(cfg, state_manager=sm)
     response, req_log = await llm(convo, json_mode=True)
     assert response == "helloworld"
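
The same mock chain is repeated in every test below; a hedged helper that captures the pattern (not part of the diff, purely a factoring suggestion):

from unittest.mock import AsyncMock

def make_mock_openai_client(stream):
    """Build the AsyncOpenAI stand-in used by these tests, with chat.completions.create set to `stream`."""
    mock_completions = AsyncMock()
    mock_completions.create = stream
    mock_chat = AsyncMock()
    mock_chat.completions = mock_completions
    mock_client = AsyncMock()
    mock_client.chat = mock_chat
    return mock_client
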
@@ -49,40 +63,67 @@ async def test_openai_calls_gpt(mock_AsyncOpenAI):


 @pytest.mark.asyncio
+@patch("core.cli.helpers.StateManager")
 @patch("core.llm.openai_client.AsyncOpenAI")
-async def test_openai_stream_handler(mock_AsyncOpenAI):
+async def test_openai_stream_handler(mock_AsyncOpenAI, mock_state_manager):
     cfg = LLMConfig(model="gpt-4-turbo")
     convo = Convo("system hello").user("user hello")

     stream_handler = AsyncMock()

     # Create AsyncMock for the chat.completions.create method
     stream = AsyncMock(return_value=mock_response_generator("hello", None, "world"))
-    mock_AsyncOpenAI.return_value.chat.completions.create = stream

-    llm = OpenAIClient(cfg, stream_handler=stream_handler)
+    # Set up the complete mock chain
+    mock_chat = AsyncMock()
+    mock_completions = AsyncMock()
+    mock_completions.create = stream
+    mock_chat.completions = mock_completions
+
+    # Configure the AsyncOpenAI mock
+    mock_client = AsyncMock()
+    mock_client.chat = mock_chat
+    mock_AsyncOpenAI.return_value = mock_client
+
+    sm = StateManager(mock_state_manager)
+    llm = OpenAIClient(cfg, stream_handler=stream_handler, state_manager=sm)
     await llm(convo)

     stream_handler.assert_has_awaits([call("hello"), call("world")])


 @pytest.mark.asyncio
+@patch("core.cli.helpers.StateManager")
 @patch("core.llm.openai_client.AsyncOpenAI")
-async def test_openai_parser_with_retries(mock_AsyncOpenAI):
+async def test_openai_parser_with_retries(mock_AsyncOpenAI, mock_state_manager):
     cfg = LLMConfig(model="gpt-4-turbo")
     convo = Convo("system").user("user")

     parser = MagicMock()
     parser.side_effect = [ValueError("Try again"), "world"]

     # Create AsyncMock for the chat.completions.create method with side effects
     stream = AsyncMock(
         side_effect=[
             mock_response_generator("hello"),
             mock_response_generator("world"),
         ]
     )
-    mock_AsyncOpenAI.return_value.chat.completions.create = stream

-    llm = OpenAIClient(cfg)
+    # Set up the complete mock chain
+    mock_chat = AsyncMock()
+    mock_completions = AsyncMock()
+    mock_completions.create = stream
+    mock_chat.completions = mock_completions
+
+    # Configure the AsyncOpenAI mock
+    mock_client = AsyncMock()
+    mock_client.chat = mock_chat
+    mock_AsyncOpenAI.return_value = mock_client
+
+    # Create StateManager instance
+    sm = StateManager(mock_state_manager)
+    llm = OpenAIClient(cfg, state_manager=sm)
     response, req_log = await llm(convo, parser=parser)

     assert response == "world"
@@ -101,26 +142,41 @@ async def test_openai_parser_with_retries(mock_AsyncOpenAI):


 @pytest.mark.asyncio
+@patch("core.cli.helpers.StateManager")
 @patch("core.llm.openai_client.AsyncOpenAI")
-async def test_openai_parser_fails(mock_AsyncOpenAI):
+async def test_openai_parser_fails(mock_AsyncOpenAI, mock_state_manager):
     cfg = LLMConfig(model="gpt-4-turbo")
     convo = Convo("system").user("user")

     parser = MagicMock()
     parser.side_effect = [ValueError("Try again")]

     # Create AsyncMock for the chat.completions.create method
     stream = AsyncMock(return_value=mock_response_generator("hello"))
-    mock_AsyncOpenAI.return_value.chat.completions.create = stream

-    llm = OpenAIClient(cfg)
+    # Set up the complete mock chain
+    mock_chat = AsyncMock()
+    mock_completions = AsyncMock()
+    mock_completions.create = stream
+    mock_chat.completions = mock_completions
+
+    # Configure the AsyncOpenAI mock
+    mock_client = AsyncMock()
+    mock_client.chat = mock_chat
+    mock_AsyncOpenAI.return_value = mock_client
+
+    # Create state manager
+    sm = StateManager(mock_state_manager)
+    llm = OpenAIClient(cfg, state_manager=sm)

     with pytest.raises(APIError, match="Error parsing response"):
         await llm(convo, parser=parser, max_retries=1)


 @pytest.mark.asyncio
+@patch("core.cli.helpers.StateManager")
 @patch("core.llm.openai_client.AsyncOpenAI")
-async def test_openai_error_handler_success(mock_AsyncOpenAI):
+async def test_openai_error_handler_success(mock_AsyncOpenAI, mock_state_manager):
     """
     Test that LLM client auto-retries up to max_retries, then calls
     the error handler to decide what next.
@@ -137,7 +193,20 @@ async def test_openai_error_handler_success(mock_AsyncOpenAI):
         assert message == expected_errors.pop(0)
         return True

-    llm = OpenAIClient(cfg, error_handler=error_handler)
+    # Set up the complete mock chain
+    mock_chat = AsyncMock()
+    mock_completions = AsyncMock()
+    mock_chat.completions = mock_completions
+
+    # Configure the AsyncOpenAI mock
+    mock_client = AsyncMock()
+    mock_client.chat = mock_chat
+    mock_AsyncOpenAI.return_value = mock_client
+
+    # Create StateManager instance
+    sm = StateManager(mock_state_manager)
+
+    llm = OpenAIClient(cfg, error_handler=error_handler, state_manager=sm)
     llm._make_request = AsyncMock(
         side_effect=[
             openai.APIConnectionError(message="first", request=None),  # auto-retried
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("core.cli.helpers.StateManager")
|
||||
@patch("core.llm.openai_client.AsyncOpenAI")
|
||||
async def test_openai_error_handler_failure(mock_AsyncOpenAI):
|
||||
async def test_openai_error_handler_failure(mock_AsyncOpenAI, mock_state_manager):
|
||||
"""
|
||||
Test that LLM client raises an API error if error handler decides
|
||||
not to retry.
|
||||
@@ -161,9 +231,24 @@ async def test_openai_error_handler_failure(mock_AsyncOpenAI):
|
||||
cfg = LLMConfig(model="gpt-4-turbo")
|
||||
convo = Convo("system hello").user("user hello")
|
||||
|
||||
# Set up error handler mock
|
||||
error_handler = AsyncMock(return_value=False)
|
||||
llm = OpenAIClient(cfg, error_handler=error_handler)
|
||||
llm._make_request = AsyncMock(side_effect=[openai.APIError("test error", None, body=None)])
|
||||
|
||||
# Set up the complete mock chain
|
||||
mock_chat = AsyncMock()
|
||||
mock_completions = AsyncMock()
|
||||
mock_completions.create = AsyncMock(side_effect=[openai.APIError("test error", None, body=None)])
|
||||
mock_chat.completions = mock_completions
|
||||
|
||||
# Configure the AsyncOpenAI mock
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat = mock_chat
|
||||
mock_AsyncOpenAI.return_value = mock_client
|
||||
|
||||
# Create state manager
|
||||
sm = StateManager(mock_state_manager)
|
||||
|
||||
llm = OpenAIClient(cfg, error_handler=error_handler, state_manager=sm)
|
||||
|
||||
with pytest.raises(APIError, match="test error"):
|
||||
await llm(convo, max_retries=1)
|
||||
@@ -181,8 +266,11 @@ async def test_openai_error_handler_failure(mock_AsyncOpenAI):
         (1, "", "1h1m1s", 3661),
     ],
 )
+@patch("core.cli.helpers.StateManager")
 @patch("core.llm.openai_client.AsyncOpenAI")
-def test_openai_rate_limit_parser(mock_AsyncOpenAI, remaining_tokens, reset_tokens, reset_requests, expected):
+def test_openai_rate_limit_parser(
+    mock_AsyncOpenAI, mock_state_manager, remaining_tokens, reset_tokens, reset_requests, expected
+):
     headers = {
         "x-ratelimit-remaining-tokens": remaining_tokens,
         "x-ratelimit-reset-tokens": reset_tokens,
@@ -190,5 +278,6 @@ def test_openai_rate_limit_parser(mock_AsyncOpenAI, remaining_tokens, reset_toke
     }
     err = MagicMock(response=MagicMock(headers=headers))

-    llm = OpenAIClient(LLMConfig(model="gpt-4"))
+    sm = StateManager(mock_state_manager)
+    llm = OpenAIClient(LLMConfig(model="gpt-4"), state_manager=sm)
     assert int(llm.rate_limit_sleep(err).total_seconds()) == expected
@@ -104,6 +104,7 @@ async def test_send_message():
             "full_screen": False,
             "extra_info": "test",
             "placeholder": None,
+            "accessToken": None,
         },
         {
             "type": "exit",
@@ -113,6 +114,7 @@ async def test_send_message():
             "full_screen": False,
             "extra_info": None,
             "placeholder": None,
+            "accessToken": None,
         },
     ]

@@ -142,6 +144,7 @@ async def test_stream():
             "full_screen": False,
             "extra_info": None,
             "placeholder": None,
+            "accessToken": None,
         },
         {
             "type": "stream",
@@ -151,6 +154,7 @@ async def test_stream():
             "full_screen": False,
             "extra_info": None,
             "placeholder": None,
+            "accessToken": None,
         },
         {
             "type": "exit",
@@ -160,6 +164,7 @@ async def test_stream():
             "full_screen": False,
             "extra_info": None,
             "placeholder": None,
+            "accessToken": None,
         },
     ]