mirror of
https://github.com/Pythagora-io/gpt-pilot.git
synced 2026-01-09 21:27:53 -05:00
Add relace and continue_frontend properly after iterate_frontend
This commit is contained in:
@@ -8,7 +8,7 @@ from core.agents.git import GitMixin
|
||||
from core.agents.mixins import FileDiffMixin
|
||||
from core.agents.response import AgentResponse
|
||||
from core.cli.helpers import capture_exception
|
||||
from core.config import FRONTEND_AGENT_NAME, SWAGGER_EMBEDDINGS_API
|
||||
from core.config import FRONTEND_AGENT_NAME, IMPLEMENT_CHANGES_AGENT_NAME, SWAGGER_EMBEDDINGS_API
|
||||
from core.config.actions import (
|
||||
FE_CHANGE_REQ,
|
||||
FE_CONTINUE,
|
||||
@@ -17,7 +17,8 @@ from core.config.actions import (
|
||||
FE_ITERATION_DONE,
|
||||
FE_START,
|
||||
)
|
||||
from core.llm.parser import DescriptiveCodeBlockParser
|
||||
from core.llm.convo import Convo
|
||||
from core.llm.parser import DescriptiveCodeBlockParser, OptionalCodeBlockParser
|
||||
from core.log import get_logger
|
||||
from core.telemetry import telemetry
|
||||
from core.ui.base import ProjectStage
|
||||
@@ -25,6 +26,13 @@ from core.ui.base import ProjectStage
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
def has_correct_num_of_backticks(response: str) -> bool:
|
||||
"""
|
||||
Checks if the response has the correct number of backticks.
|
||||
"""
|
||||
return response.count("```") % 2 == 0 and response.count("```") > 0
|
||||
|
||||
|
||||
class Frontend(FileDiffMixin, GitMixin, BaseAgent):
|
||||
agent_type = "frontend"
|
||||
display_name = "Frontend"
|
||||
@@ -104,19 +112,21 @@ class Frontend(FileDiffMixin, GitMixin, BaseAgent):
|
||||
response_blocks = response.blocks
|
||||
convo.assistant(response.original_response)
|
||||
|
||||
await self.process_response(response_blocks)
|
||||
|
||||
self.next_state.epics[-1]["messages"] = convo.messages
|
||||
use_relace = self.current_state.epics[-1].get("use_relace", False)
|
||||
await self.process_response(response_blocks, relace=use_relace)
|
||||
|
||||
if self.next_state.epics[-1].get("manual_iteration", False):
|
||||
# If manual iteration is True, we assume only one iteration of continue is needed
|
||||
# If we want multiple iterations, we should use response.original_response.count("```") % 2 == 0
|
||||
self.next_state.epics[-1]["fe_iteration_done"] = True
|
||||
self.next_state.epics[-1]["fe_iteration_done"] = (
|
||||
has_correct_num_of_backticks(response.original_response)
|
||||
or self.current_state.epics[-1]["retry_count"] >= 2
|
||||
)
|
||||
self.next_state.epics[-1]["retry_count"] = self.current_state.epics[-1].get("retry_count", 0) + 1
|
||||
else:
|
||||
self.next_state.epics[-1]["fe_iteration_done"] = (
|
||||
"done" in response.original_response[-20:].lower().strip() or len(convo.messages) > 15
|
||||
)
|
||||
|
||||
self.next_state.epics[-1]["messages"] = convo.messages
|
||||
self.next_state.flag_epics_as_modified()
|
||||
|
||||
return False
|
||||
@@ -127,6 +137,7 @@ class Frontend(FileDiffMixin, GitMixin, BaseAgent):
|
||||
|
||||
:return: True if the frontend is fully built, False otherwise.
|
||||
"""
|
||||
self.next_state.epics[-1]["retry_count"] = 0
|
||||
frontend_only = self.current_state.branch.project.project_type == "swagger"
|
||||
self.next_state.action = FE_ITERATION
|
||||
# update the pages in the knowledge base
|
||||
@@ -209,25 +220,41 @@ class Frontend(FileDiffMixin, GitMixin, BaseAgent):
|
||||
|
||||
llm = self.get_llm(FRONTEND_AGENT_NAME, stream_output=True)
|
||||
|
||||
# try relace first
|
||||
convo = AgentConvo(self).template(
|
||||
"build_frontend",
|
||||
"iterate_frontend",
|
||||
description=self.current_state.epics[0]["description"],
|
||||
user_feedback=answer.text,
|
||||
relevant_api_documentation=relevant_api_documentation,
|
||||
first_time_build=False,
|
||||
)
|
||||
|
||||
# replace system prompt because of relace
|
||||
convo.messages[0]["content"] = AgentConvo(self).render("system_relace")
|
||||
|
||||
response = await llm(convo, parser=DescriptiveCodeBlockParser())
|
||||
|
||||
await self.process_response(response.blocks)
|
||||
relace_finished = await self.process_response(response.blocks, relace=True)
|
||||
|
||||
if not relace_finished:
|
||||
log.debug("Relace didn't finish, reverting to build_frontend")
|
||||
convo = AgentConvo(self).template(
|
||||
"build_frontend",
|
||||
description=self.current_state.epics[0]["description"],
|
||||
user_feedback=answer.text,
|
||||
relevant_api_documentation=relevant_api_documentation,
|
||||
first_time_build=False,
|
||||
)
|
||||
|
||||
response = await llm(convo, parser=DescriptiveCodeBlockParser())
|
||||
|
||||
await self.process_response(response.blocks)
|
||||
convo.assistant(response.original_response)
|
||||
|
||||
# Store the conversation in the epic messages for potential continuation
|
||||
self.next_state.epics[-1]["messages"] = convo.messages
|
||||
|
||||
self.next_state.epics[-1]["fe_iteration_done"] = response.original_response.count("```") % 2 == 0
|
||||
self.next_state.epics[-1]["use_relace"] = relace_finished
|
||||
self.next_state.epics[-1]["fe_iteration_done"] = has_correct_num_of_backticks(response.original_response)
|
||||
self.next_state.epics[-1]["manual_iteration"] = True
|
||||
|
||||
self.next_state.flag_epics_as_modified()
|
||||
|
||||
return False
|
||||
@@ -268,7 +295,7 @@ class Frontend(FileDiffMixin, GitMixin, BaseAgent):
|
||||
|
||||
return AgentResponse.done(self)
|
||||
|
||||
async def process_response(self, response_blocks: list, removed_mock: bool = False) -> list[str]:
|
||||
async def process_response(self, response_blocks: list, removed_mock: bool = False, relace: bool = False) -> bool:
|
||||
"""
|
||||
Processes the response blocks from the LLM.
|
||||
|
||||
@@ -292,6 +319,21 @@ class Frontend(FileDiffMixin, GitMixin, BaseAgent):
|
||||
continue
|
||||
new_content = content
|
||||
old_content = self.current_state.get_file_content_by_path(file_path)
|
||||
|
||||
if relace:
|
||||
llm = self.get_llm(IMPLEMENT_CHANGES_AGENT_NAME)
|
||||
convo = Convo().user(
|
||||
{
|
||||
"initialCode": old_content,
|
||||
"editSnippet": new_content,
|
||||
}
|
||||
)
|
||||
|
||||
new_content = await llm(convo, temperature=0, parser=OptionalCodeBlockParser())
|
||||
|
||||
if not new_content or new_content == ("", 0, 0):
|
||||
return False
|
||||
|
||||
n_new_lines, n_del_lines = self.get_line_changes(old_content, new_content)
|
||||
await self.ui.send_file_status(file_path, "done", source=self.ui_source)
|
||||
await self.ui.generate_diff(
|
||||
@@ -327,7 +369,7 @@ class Frontend(FileDiffMixin, GitMixin, BaseAgent):
|
||||
else:
|
||||
log.info(f"Unknown block description: {description}")
|
||||
|
||||
return AgentResponse.done(self)
|
||||
return True
|
||||
|
||||
async def remove_mock(self):
|
||||
"""
|
||||
|
||||
@@ -17,7 +17,7 @@ class RelaceClient(BaseLLMClient):
|
||||
def _init_client(self):
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.state_manager.get_access_token()}",
|
||||
"Authorization": f"Bearer {self.state_manager.get_access_token() if self.state_manager.get_access_token() is not None else self.config.api_key if self.config.api_key is not None else ''}",
|
||||
}
|
||||
self.client = AsyncClient()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user