Files
gpt-pilot/core/agents/mixins.py
2024-07-18 20:58:18 +01:00

85 lines
2.4 KiB
Python

from enum import Enum
from typing import Annotated, Literal, Optional, Union
from pydantic import BaseModel, Field
from core.agents.convo import AgentConvo
class StepType(str, Enum):
COMMAND = "command"
SAVE_FILE = "save_file"
HUMAN_INTERVENTION = "human_intervention"
class CommandOptions(BaseModel):
command: str = Field(description="Command to run")
timeout: int = Field(description="Timeout in seconds")
success_message: str = ""
class SaveFileOptions(BaseModel):
path: str
class SaveFileStep(BaseModel):
type: Literal[StepType.SAVE_FILE] = StepType.SAVE_FILE
save_file: SaveFileOptions
class CommandStep(BaseModel):
type: Literal[StepType.COMMAND] = StepType.COMMAND
command: CommandOptions
class HumanInterventionStep(BaseModel):
type: Literal[StepType.HUMAN_INTERVENTION] = StepType.HUMAN_INTERVENTION
human_intervention_description: str
Step = Annotated[
Union[SaveFileStep, CommandStep, HumanInterventionStep],
Field(discriminator="type"),
]
class TaskSteps(BaseModel):
steps: list[Step]
class IterationPromptMixin:
"""
Provides a method to find a solution to a problem based on user feedback.
Used by ProblemSolver and Troubleshooter agents.
"""
async def find_solution(
self,
user_feedback: str,
*,
user_feedback_qa: Optional[list[str]] = None,
next_solution_to_try: Optional[str] = None,
logs_data: Optional[dict] = None,
) -> str:
"""
Generate a new solution for the problem the user reported.
:param user_feedback: User feedback about the problem.
:param user_feedback_qa: Additional q/a about the problem provided by the user (optional).
:param next_solution_to_try: Hint from ProblemSolver on which solution to try (optional).
:param logs_data: Data about logs that need to be added to the code (optional).
:return: The generated solution to the problem.
"""
llm = self.get_llm()
convo = AgentConvo(self).template(
"iteration",
current_task=self.current_state.current_task,
user_feedback=user_feedback,
user_feedback_qa=user_feedback_qa,
next_solution_to_try=next_solution_to_try,
logs_data=logs_data,
)
llm_solution: str = await llm(convo)
return llm_solution