Merge pull request #89 from Pythagora-io/bricks

Bricks
Authored by zvone187 · committed by GitHub · 2025-02-13 10:44:56 +01:00
13 changed files with 254 additions and 37 deletions

View File

@@ -117,6 +117,9 @@ class BaseAgent:
extra_info=extra_info,
placeholder=placeholder,
)
# Store the access token in the state manager
if hasattr(response, "access_token") and response.access_token:
self.state_manager.update_access_token(response.access_token)
await self.state_manager.log_user_input(question, response)
return response
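The guard added above only persists a token when the UI response actually carries one. A tiny self-contained illustration of that `hasattr`-and-truthy pattern (the response class and the storage dict are stand-ins, not the real BaseAgent/StateManager):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeResponse:
    text: str
    access_token: Optional[str] = None


captured = {}


def maybe_store_token(response) -> None:
    # Mirrors the guard added in BaseAgent: the attribute must exist and be non-empty.
    if hasattr(response, "access_token") and response.access_token:
        captured["token"] = response.access_token


maybe_store_token(FakeResponse(text="hi"))                        # no token -> nothing stored
maybe_store_token(FakeResponse(text="hi", access_token="t-123"))  # token stored
assert captured["token"] == "t-123"
```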
@@ -188,7 +191,13 @@ class BaseAgent:
llm_config = config.llm_for_agent(name)
client_class = BaseLLMClient.for_provider(llm_config.provider)
stream_handler = self.stream_handler if stream_output else None
-llm_client = client_class(llm_config, stream_handler=stream_handler, error_handler=self.error_handler)
+llm_client = client_class(
+    llm_config,
+    stream_handler=stream_handler,
+    error_handler=self.error_handler,
+    ui=self.ui,
+    state_manager=self.state_manager,
+)
async def client(convo, **kwargs) -> Any:
"""

View File

@@ -97,6 +97,7 @@ def parse_arguments() -> Namespace:
--extension-version: Version of the VSCode extension, if used
--no-check: Disable initial LLM API check
--use-git: Use Git for version control
--access-token: Access token
:return: Parsed arguments object.
"""
version = get_version()
@@ -138,6 +139,7 @@ def parse_arguments() -> Namespace:
parser.add_argument("--extension-version", help="Version of the VSCode extension", required=False)
parser.add_argument("--no-check", help="Disable initial LLM API check", action="store_true")
parser.add_argument("--use-git", help="Use Git for version control", action="store_true", required=False)
parser.add_argument("--access-token", help="Access token", required=False)
return parser.parse_args()
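The new flag surfaces on the parsed namespace as `args.access_token`, which `async_main` later forwards to the state manager. A minimal sketch of just this flag in isolation (the real parser defines many more options):

```python
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--access-token", help="Access token", required=False)

args = parser.parse_args(["--access-token", "example-token"])
assert args.access_token == "example-token"  # dashes become an underscore on the namespace
```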

View File

@@ -88,7 +88,7 @@ async def run_project(sm: StateManager, ui: UIBase, args) -> bool:
return success
-async def llm_api_check(ui: UIBase) -> bool:
+async def llm_api_check(ui: UIBase, sm: StateManager) -> bool:
"""
Check whether the configured LLMs are reachable in parallel.
@@ -110,7 +110,7 @@ async def llm_api_check(ui: UIBase) -> bool:
checked_llms.add(llm_config.provider + llm_config.model)
client_class = BaseLLMClient.for_provider(llm_config.provider)
-llm_client = client_class(llm_config, stream_handler=handler, error_handler=handler)
+llm_client = client_class(llm_config, stream_handler=handler, error_handler=handler, ui=ui, state_manager=sm)
try:
resp = await llm_client.api_check()
if not resp:
@@ -224,7 +224,7 @@ async def run_pythagora_session(sm: StateManager, ui: UIBase, args: Namespace):
"""
if not args.no_check:
-if not await llm_api_check(ui):
+if not await llm_api_check(ui, sm):
await ui.send_message(
"Pythagora cannot start because the LLM API is not reachable.",
source=pythagora_source,
@@ -283,6 +283,8 @@ async def async_main(
telemetry.set("extension_version", args.extension_version)
sm = StateManager(db, ui)
if args.access_token:
sm.update_access_token(args.access_token)
ui_started = await ui.start()
if not ui_started:
return False

View File

@@ -326,17 +326,17 @@ class Config(_StrictModel):
default={
DEFAULT_AGENT_NAME: AgentLLMConfig(),
CHECK_LOGS_AGENT_NAME: AgentLLMConfig(
-provider=LLMProvider.ANTHROPIC,
+provider=LLMProvider.OPENAI,
model="claude-3-5-sonnet-20241022",
temperature=0.5,
),
CODE_MONKEY_AGENT_NAME: AgentLLMConfig(
-provider=LLMProvider.ANTHROPIC,
+provider=LLMProvider.OPENAI,
model="claude-3-5-sonnet-20241022",
temperature=0.0,
),
CODE_REVIEW_AGENT_NAME: AgentLLMConfig(
-provider=LLMProvider.ANTHROPIC,
+provider=LLMProvider.OPENAI,
model="claude-3-5-sonnet-20240620",
temperature=0.0,
),
@@ -346,7 +346,7 @@ class Config(_StrictModel):
temperature=0.0,
),
FRONTEND_AGENT_NAME: AgentLLMConfig(
-provider=LLMProvider.ANTHROPIC,
+provider=LLMProvider.OPENAI,
model="claude-3-5-sonnet-20241022",
temperature=0.0,
),
@@ -356,7 +356,7 @@ class Config(_StrictModel):
temperature=0.5,
),
PARSE_TASK_AGENT_NAME: AgentLLMConfig(
-provider=LLMProvider.ANTHROPIC,
+provider=LLMProvider.OPENAI,
model="claude-3-5-sonnet-20241022",
temperature=0.0,
),
@@ -366,27 +366,27 @@ class Config(_StrictModel):
temperature=0.0,
),
TASK_BREAKDOWN_AGENT_NAME: AgentLLMConfig(
-provider=LLMProvider.ANTHROPIC,
+provider=LLMProvider.OPENAI,
model="claude-3-5-sonnet-20241022",
temperature=0.5,
),
TECH_LEAD_PLANNING: AgentLLMConfig(
-provider=LLMProvider.ANTHROPIC,
+provider=LLMProvider.OPENAI,
model="claude-3-5-sonnet-20240620",
temperature=0.5,
),
TECH_LEAD_EPIC_BREAKDOWN: AgentLLMConfig(
-provider=LLMProvider.ANTHROPIC,
+provider=LLMProvider.OPENAI,
model="claude-3-5-sonnet-20241022",
temperature=0.5,
),
TROUBLESHOOTER_BUG_REPORT: AgentLLMConfig(
-provider=LLMProvider.ANTHROPIC,
+provider=LLMProvider.OPENAI,
model="claude-3-5-sonnet-20240620",
temperature=0.5,
),
TROUBLESHOOTER_GET_RUN_COMMAND: AgentLLMConfig(
-provider=LLMProvider.ANTHROPIC,
+provider=LLMProvider.OPENAI,
model="claude-3-5-sonnet-20240620",
temperature=0.0,
),
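Taken together, these hunks switch the default per-agent routing from the Anthropic provider to the OpenAI-compatible one while keeping the Claude model names. A rough, hypothetical stand-in for the structure (the real `AgentLLMConfig` is a Pydantic model and the agent-name constants are defined elsewhere; the key strings below are placeholders):

```python
from dataclasses import dataclass
from enum import Enum


class LLMProvider(str, Enum):
    OPENAI = "openai"
    ANTHROPIC = "anthropic"


@dataclass
class AgentLLMConfig:
    # Field names mirror the diff; the defaults here are illustrative only.
    provider: LLMProvider = LLMProvider.OPENAI
    model: str = "claude-3-5-sonnet-20241022"
    temperature: float = 0.0


# Placeholder keys standing in for CODE_MONKEY_AGENT_NAME, TECH_LEAD_PLANNING, etc.
agent_defaults = {
    "code-monkey": AgentLLMConfig(temperature=0.0),
    "tech-lead.planning": AgentLLMConfig(model="claude-3-5-sonnet-20240620", temperature=0.5),
}
```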

View File

@@ -1,6 +1,7 @@
import asyncio
import datetime
import json
import sys
from enum import Enum
from time import time
from typing import Any, Callable, Optional, Tuple
@@ -11,6 +12,8 @@ from core.config import LLMConfig, LLMProvider
from core.llm.convo import Convo
from core.llm.request_log import LLMRequestLog, LLMRequestStatus
from core.log import get_logger
from core.state.state_manager import StateManager
from core.ui.base import UIBase, pythagora_source
log = get_logger(__name__)
@@ -48,9 +51,11 @@ class BaseLLMClient:
def __init__(
self,
config: LLMConfig,
state_manager: StateManager,
*,
stream_handler: Optional[Callable] = None,
error_handler: Optional[Callable] = None,
ui: Optional[UIBase] = None,
):
"""
Initialize the client with the given configuration.
@@ -61,6 +66,8 @@ class BaseLLMClient:
self.config = config
self.stream_handler = stream_handler
self.error_handler = error_handler
self.ui = ui
self.state_manager = state_manager
self._init_client()
def _init_client(self):
@@ -186,6 +193,30 @@ class BaseLLMClient:
response = None
try:
access_token = self.state_manager.get_access_token()
if access_token:
# Store the original client
original_client = self.client
# Copy client based on its type
if isinstance(original_client, openai.AsyncOpenAI):
self.client = openai.AsyncOpenAI(
api_key=original_client.api_key,
base_url=original_client.base_url,
default_headers={"Authorization": f"Bearer {access_token}"},
)
elif isinstance(original_client, anthropic.AsyncAnthropic):
# Create new Anthropic client with custom headers
self.client = anthropic.AsyncAnthropic(
api_key=original_client.api_key,
base_url=original_client.base_url,
default_headers={"Authorization": f"Bearer {access_token}"},
)
else:
# Handle other client types or raise exception
raise ValueError(f"Unsupported client type: {type(original_client)}")
response, prompt_tokens, completion_tokens = await self._make_request(
convo,
temperature=temperature,
@@ -244,6 +275,44 @@ class BaseLLMClient:
# so we can't be certain that's the problem in Anthropic case.
# Here we try to detect that and tell the user what happened.
log.info(f"API status error: {err}")
if getattr(err, "status_code", None) in (401, 403):
if self.ui:
try:
await self.ui.send_message("Token expired")
sys.exit(0)
# TODO implement this to not crash in parallel
# access_token = await self.ui.send_token_expired()
# self.state_manager.update_access_token(access_token)
# continue
except Exception:
raise APIError("Token expired")
if getattr(err, "status_code", None) == 400 and getattr(err, "message", None) == "not_enough_tokens":
if self.ui:
try:
await self.ui.ask_question(
"",
buttons={},
buttons_only=True,
extra_info="not_enough_tokens",
source=pythagora_source,
)
sys.exit(0)
# TODO implement this to not crash in parallel
# user_response = await self.ui.ask_question(
# 'Not enough tokens left, please top up your account and press "Continue".',
# buttons={"continue": "Continue", "exit": "Exit"},
# buttons_only=True,
# extra_info="not_enough_tokens",
# source=pythagora_source,
# )
# if user_response.button == "continue":
# continue
# else:
# raise APIError("Not enough tokens left")
except Exception:
raise APIError("Not enough tokens left")
try:
if hasattr(err, "response"):
if err.response.headers.get("Content-Type", "").startswith("application/json"):
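The core idea in this hunk is to rebuild the provider SDK client so that every request also carries the Pythagora access token as a bearer header. A minimal, self-contained sketch of that rebuild for the OpenAI case (placeholder key and base URL; in the real code the token comes from `StateManager.get_access_token()`):

```python
import openai

access_token = "example-access-token"  # placeholder; normally read from the state manager

original_client = openai.AsyncOpenAI(
    api_key="sk-placeholder",
    base_url="https://api.example.com/v1",  # placeholder endpoint
)

# Recreate the client with the same credentials plus an Authorization header,
# mirroring the approach in the diff above.
client_with_token = openai.AsyncOpenAI(
    api_key=original_client.api_key,
    base_url=original_client.base_url,
    default_headers={"Authorization": f"Bearer {access_token}"},
)
```

The Anthropic branch does the same with `anthropic.AsyncAnthropic`; any other client type raises `ValueError`.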

View File

@@ -53,6 +53,7 @@ class StateManager:
self.git_available = False
self.git_used = False
self.options = {}
self.access_token = None
@asynccontextmanager
async def db_blocker(self):
@@ -754,5 +755,17 @@ class StateManager:
return lines
def update_access_token(self, access_token: str):
"""
Store the access token in the state manager.
"""
self.access_token = access_token
def get_access_token(self) -> Optional[str]:
"""
Get the access token from the state manager.
"""
return self.access_token or None
__all__ = ["StateManager"]
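The two helpers are deliberately thin: the token lives only in memory on the `StateManager` instance. A small stand-in showing just the token-related behaviour added here (the real class takes a database and UI in its constructor, which are omitted):

```python
from typing import Optional


class TokenStoreSketch:
    """Stand-in mirroring only the access-token additions from this diff."""

    def __init__(self) -> None:
        self.access_token: Optional[str] = None

    def update_access_token(self, access_token: str) -> None:
        self.access_token = access_token

    def get_access_token(self) -> Optional[str]:
        return self.access_token or None


store = TokenStoreSketch()
assert store.get_access_token() is None
store.update_access_token("token-123")
assert store.get_access_token() == "token-123"
```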

View File

@@ -80,11 +80,13 @@ class UserInput(BaseModel):
* `text`: User-provided text (if any).
* `button`: Name (key) of the button the user selected (if any).
* `cancelled`: Whether the user cancelled the input.
* `access_token`: Access token (if any).
"""
text: Optional[str] = None
button: Optional[str] = None
cancelled: bool = False
access_token: Optional[str] = None
class UIBase:
@@ -141,6 +143,12 @@ class UIBase:
"""
raise NotImplementedError()
async def send_token_expired(self):
"""
Send the token expired message.
"""
raise NotImplementedError()
async def send_app_finished(
self,
app_id: Optional[str] = None,

View File

@@ -46,6 +46,9 @@ class PlainConsoleUI(UIBase):
if message:
await self.send_message(message)
async def send_token_expired(self):
await self.send_message("Access token expired")
async def send_app_finished(
self,
app_id: Optional[str] = None,

View File

@@ -55,6 +55,7 @@ class MessageType(str, Enum):
TEST_INSTRUCTIONS = "testInstructions"
KNOWLEDGE_BASE_UPDATE = "updatedKnowledgeBase"
STOP_APP = "stopApp"
TOKEN_EXPIRED = "tokenExpired"
class Message(BaseModel):
@@ -65,6 +66,9 @@ class Message(BaseModel):
* `type`: Message type (always "response" for VSC server responses)
* `category`: Message category (eg. "agent:product-owner"), optional
* `content`: Message content (eg. "Hello, how are you?"), optional
* `extra_info`: Additional information (eg. "This is a hint"), optional
* `placeholder`: Placeholder for user input, optional
* `access_token`: Access token for user input, optional
"""
type: MessageType
@@ -74,6 +78,7 @@ class Message(BaseModel):
extra_info: Optional[str] = None
content: Union[str, dict, None] = None
placeholder: Optional[str] = None
accessToken: Optional[str] = None
def to_bytes(self) -> bytes:
"""
@@ -228,6 +233,11 @@ class IPCClientUI(UIBase):
async def send_key_expired(self, message: Optional[str] = None):
await self._send(MessageType.KEY_EXPIRED)
async def send_token_expired(self):
await self._send(MessageType.TOKEN_EXPIRED)
response = await self._receive()
return response.accessToken
async def send_app_finished(
self,
app_id: Optional[str] = None,
@@ -337,6 +347,9 @@ class IPCClientUI(UIBase):
)
response = await self._receive()
access_token = response.accessToken
answer = response.content.strip()
if answer == "exitPythagoraCore":
raise KeyboardInterrupt()
@@ -347,17 +360,17 @@ class IPCClientUI(UIBase):
if buttons:
# Answer matches one of the buttons (or maybe the default if it's a button name)
if answer in buttons:
-return UserInput(button=answer, text=None)
+return UserInput(button=answer, text=None, access_token=access_token)
# VSCode extension only deals with values so we need to check them as well
value2key = {v: k for k, v in buttons.items()}
if answer in value2key:
-return UserInput(button=value2key[answer], text=None)
+return UserInput(button=value2key[answer], text=None, access_token=access_token)
if answer or allow_empty:
-return UserInput(button=None, text=answer)
+return UserInput(button=None, text=answer, access_token=access_token)
# Empty answer which we don't allow, treat as user cancelled the input
-return UserInput(cancelled=True)
+return UserInput(cancelled=True, access_token=access_token)
async def send_project_stage(self, data: dict):
await self._send(MessageType.INFO, content=json.dumps(data))
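With these changes, every answer coming back from the VSCode extension can carry an `accessToken`, and the IPC client threads it through to `UserInput.access_token`, where BaseAgent picks it up (first hunk of this commit). A hedged sketch of just the data shape, using the fields listed earlier in this diff:

```python
from typing import Optional

from pydantic import BaseModel


class UserInput(BaseModel):
    # Fields as listed in the UIBase diff above.
    text: Optional[str] = None
    button: Optional[str] = None
    cancelled: bool = False
    access_token: Optional[str] = None


# What the IPC client now returns when the extension answers with a button press
# and includes a refreshed token (values are illustrative).
answer = UserInput(button="continue", text=None, access_token="token-123")
assert answer.access_token == "token-123"
```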

View File

@@ -46,6 +46,9 @@ class VirtualUI(UIBase):
async def send_key_expired(self, message: Optional[str] = None):
pass
async def send_token_expired(self):
pass
async def send_app_finished(
self,
app_id: Optional[str] = None,

View File

@@ -58,6 +58,7 @@ def test_parse_arguments(mock_ArgumentParser):
"--extension-version",
"--no-check",
"--use-git",
"--access-token",
}
parser.parse_args.assert_called_once_with()

View File

@@ -7,6 +7,7 @@ from core.config import LLMConfig
from core.llm.base import APIError
from core.llm.convo import Convo
from core.llm.openai_client import OpenAIClient
from core.state.state_manager import StateManager
async def mock_response_generator(*content):
@@ -17,15 +18,28 @@ async def mock_response_generator(*content):
@pytest.mark.asyncio
@patch("core.cli.helpers.StateManager")
@patch("core.llm.openai_client.AsyncOpenAI")
-async def test_openai_calls_gpt(mock_AsyncOpenAI):
+async def test_openai_calls_gpt(mock_AsyncOpenAI, mock_state_manager):
cfg = LLMConfig(model="gpt-4-turbo")
convo = Convo("system hello").user("user hello")
# Create AsyncMock for the chat.completions.create method
stream = AsyncMock(return_value=mock_response_generator("hello", None, "world"))
mock_AsyncOpenAI.return_value.chat.completions.create = stream
llm = OpenAIClient(cfg)
# Set up the complete mock chain
mock_chat = AsyncMock()
mock_completions = AsyncMock()
mock_completions.create = stream
mock_chat.completions = mock_completions
# Configure the AsyncOpenAI mock
mock_client = AsyncMock()
mock_client.chat = mock_chat
mock_AsyncOpenAI.return_value = mock_client
sm = StateManager(mock_state_manager)
llm = OpenAIClient(cfg, state_manager=sm)
response, req_log = await llm(convo, json_mode=True)
assert response == "helloworld"
@@ -49,40 +63,67 @@ async def test_openai_calls_gpt(mock_AsyncOpenAI):
@pytest.mark.asyncio
@patch("core.cli.helpers.StateManager")
@patch("core.llm.openai_client.AsyncOpenAI")
-async def test_openai_stream_handler(mock_AsyncOpenAI):
+async def test_openai_stream_handler(mock_AsyncOpenAI, mock_state_manager):
cfg = LLMConfig(model="gpt-4-turbo")
convo = Convo("system hello").user("user hello")
stream_handler = AsyncMock()
# Create AsyncMock for the chat.completions.create method
stream = AsyncMock(return_value=mock_response_generator("hello", None, "world"))
mock_AsyncOpenAI.return_value.chat.completions.create = stream
llm = OpenAIClient(cfg, stream_handler=stream_handler)
# Set up the complete mock chain
mock_chat = AsyncMock()
mock_completions = AsyncMock()
mock_completions.create = stream
mock_chat.completions = mock_completions
# Configure the AsyncOpenAI mock
mock_client = AsyncMock()
mock_client.chat = mock_chat
mock_AsyncOpenAI.return_value = mock_client
sm = StateManager(mock_state_manager)
llm = OpenAIClient(cfg, stream_handler=stream_handler, state_manager=sm)
await llm(convo)
stream_handler.assert_has_awaits([call("hello"), call("world")])
@pytest.mark.asyncio
@patch("core.cli.helpers.StateManager")
@patch("core.llm.openai_client.AsyncOpenAI")
-async def test_openai_parser_with_retries(mock_AsyncOpenAI):
+async def test_openai_parser_with_retries(mock_AsyncOpenAI, mock_state_manager):
cfg = LLMConfig(model="gpt-4-turbo")
convo = Convo("system").user("user")
parser = MagicMock()
parser.side_effect = [ValueError("Try again"), "world"]
# Create AsyncMock for the chat.completions.create method with side effects
stream = AsyncMock(
side_effect=[
mock_response_generator("hello"),
mock_response_generator("world"),
]
)
mock_AsyncOpenAI.return_value.chat.completions.create = stream
llm = OpenAIClient(cfg)
# Set up the complete mock chain
mock_chat = AsyncMock()
mock_completions = AsyncMock()
mock_completions.create = stream
mock_chat.completions = mock_completions
# Configure the AsyncOpenAI mock
mock_client = AsyncMock()
mock_client.chat = mock_chat
mock_AsyncOpenAI.return_value = mock_client
# Create StateManager instance
sm = StateManager(mock_state_manager)
llm = OpenAIClient(cfg, state_manager=sm)
response, req_log = await llm(convo, parser=parser)
assert response == "world"
@@ -101,26 +142,41 @@ async def test_openai_parser_with_retries(mock_AsyncOpenAI):
@pytest.mark.asyncio
@patch("core.cli.helpers.StateManager")
@patch("core.llm.openai_client.AsyncOpenAI")
-async def test_openai_parser_fails(mock_AsyncOpenAI):
+async def test_openai_parser_fails(mock_AsyncOpenAI, mock_state_manager):
cfg = LLMConfig(model="gpt-4-turbo")
convo = Convo("system").user("user")
parser = MagicMock()
parser.side_effect = [ValueError("Try again")]
# Create AsyncMock for the chat.completions.create method
stream = AsyncMock(return_value=mock_response_generator("hello"))
mock_AsyncOpenAI.return_value.chat.completions.create = stream
llm = OpenAIClient(cfg)
# Set up the complete mock chain
mock_chat = AsyncMock()
mock_completions = AsyncMock()
mock_completions.create = stream
mock_chat.completions = mock_completions
# Configure the AsyncOpenAI mock
mock_client = AsyncMock()
mock_client.chat = mock_chat
mock_AsyncOpenAI.return_value = mock_client
# Create state manager
sm = StateManager(mock_state_manager)
llm = OpenAIClient(cfg, state_manager=sm)
with pytest.raises(APIError, match="Error parsing response"):
await llm(convo, parser=parser, max_retries=1)
@pytest.mark.asyncio
@patch("core.cli.helpers.StateManager")
@patch("core.llm.openai_client.AsyncOpenAI")
-async def test_openai_error_handler_success(mock_AsyncOpenAI):
+async def test_openai_error_handler_success(mock_AsyncOpenAI, mock_state_manager):
"""
Test that LLM client auto-retries up to max_retries, then calls
the error handler to decide what next.
@@ -137,7 +193,20 @@ async def test_openai_error_handler_success(mock_AsyncOpenAI):
assert message == expected_errors.pop(0)
return True
llm = OpenAIClient(cfg, error_handler=error_handler)
# Set up the complete mock chain
mock_chat = AsyncMock()
mock_completions = AsyncMock()
mock_chat.completions = mock_completions
# Configure the AsyncOpenAI mock
mock_client = AsyncMock()
mock_client.chat = mock_chat
mock_AsyncOpenAI.return_value = mock_client
# Create StateManager instance
sm = StateManager(mock_state_manager)
llm = OpenAIClient(cfg, error_handler=error_handler, state_manager=sm)
llm._make_request = AsyncMock(
side_effect=[
openai.APIConnectionError(message="first", request=None), # auto-retried
@@ -152,8 +221,9 @@ async def test_openai_error_handler_success(mock_AsyncOpenAI):
@pytest.mark.asyncio
@patch("core.cli.helpers.StateManager")
@patch("core.llm.openai_client.AsyncOpenAI")
-async def test_openai_error_handler_failure(mock_AsyncOpenAI):
+async def test_openai_error_handler_failure(mock_AsyncOpenAI, mock_state_manager):
"""
Test that LLM client raises an API error if error handler decides
not to retry.
@@ -161,9 +231,24 @@ async def test_openai_error_handler_failure(mock_AsyncOpenAI):
cfg = LLMConfig(model="gpt-4-turbo")
convo = Convo("system hello").user("user hello")
# Set up error handler mock
error_handler = AsyncMock(return_value=False)
llm = OpenAIClient(cfg, error_handler=error_handler)
llm._make_request = AsyncMock(side_effect=[openai.APIError("test error", None, body=None)])
# Set up the complete mock chain
mock_chat = AsyncMock()
mock_completions = AsyncMock()
mock_completions.create = AsyncMock(side_effect=[openai.APIError("test error", None, body=None)])
mock_chat.completions = mock_completions
# Configure the AsyncOpenAI mock
mock_client = AsyncMock()
mock_client.chat = mock_chat
mock_AsyncOpenAI.return_value = mock_client
# Create state manager
sm = StateManager(mock_state_manager)
llm = OpenAIClient(cfg, error_handler=error_handler, state_manager=sm)
with pytest.raises(APIError, match="test error"):
await llm(convo, max_retries=1)
@@ -181,8 +266,11 @@ async def test_openai_error_handler_failure(mock_AsyncOpenAI):
(1, "", "1h1m1s", 3661),
],
)
@patch("core.cli.helpers.StateManager")
@patch("core.llm.openai_client.AsyncOpenAI")
-def test_openai_rate_limit_parser(mock_AsyncOpenAI, remaining_tokens, reset_tokens, reset_requests, expected):
+def test_openai_rate_limit_parser(
+    mock_AsyncOpenAI, mock_state_manager, remaining_tokens, reset_tokens, reset_requests, expected
+):
headers = {
"x-ratelimit-remaining-tokens": remaining_tokens,
"x-ratelimit-reset-tokens": reset_tokens,
@@ -190,5 +278,6 @@ def test_openai_rate_limit_parser(mock_AsyncOpenAI, remaining_tokens, reset_toke
}
err = MagicMock(response=MagicMock(headers=headers))
-llm = OpenAIClient(LLMConfig(model="gpt-4"))
+sm = StateManager(mock_state_manager)
+llm = OpenAIClient(LLMConfig(model="gpt-4"), state_manager=sm)
assert int(llm.rate_limit_sleep(err).total_seconds()) == expected

View File

@@ -104,6 +104,7 @@ async def test_send_message():
"full_screen": False,
"extra_info": "test",
"placeholder": None,
"accessToken": None,
},
{
"type": "exit",
@@ -113,6 +114,7 @@ async def test_send_message():
"full_screen": False,
"extra_info": None,
"placeholder": None,
"accessToken": None,
},
]
@@ -142,6 +144,7 @@ async def test_stream():
"full_screen": False,
"extra_info": None,
"placeholder": None,
"accessToken": None,
},
{
"type": "stream",
@@ -151,6 +154,7 @@ async def test_stream():
"full_screen": False,
"extra_info": None,
"placeholder": None,
"accessToken": None,
},
{
"type": "exit",
@@ -160,6 +164,7 @@ async def test_stream():
"full_screen": False,
"extra_info": None,
"placeholder": None,
"accessToken": None,
},
]