mirror of
https://github.com/Pythagora-io/gpt-pilot.git
synced 2026-01-09 21:27:53 -05:00
Add a function for restoring convo
This commit is contained in:
@@ -9,6 +9,12 @@ from core.agents.convo import AgentConvo
|
||||
from core.agents.mixins import ChatWithBreakdownMixin, TestSteps
|
||||
from core.agents.response import AgentResponse
|
||||
from core.config import CHECK_LOGS_AGENT_NAME, magic_words
|
||||
from core.config.actions import (
|
||||
BH_START_BUG_HUNT,
|
||||
BH_START_USER_TEST,
|
||||
BH_STARTING_PAIR_PROGRAMMING,
|
||||
BH_WAIT_BUG_REP_INSTRUCTIONS,
|
||||
)
|
||||
from core.config.constants import CONVO_ITERATIONS_LIMIT
|
||||
from core.db.models.project_state import IterationStatus
|
||||
from core.llm.parser import JSONParser
|
||||
@@ -45,12 +51,6 @@ class ImportantLogsForDebugging(BaseModel):
|
||||
logs: list[ImportantLog] = Field(description="Important logs that will help the human debug the current bug.")
|
||||
|
||||
|
||||
BH_STARTING_PAIR_PROGRAMMING = "Start pair programming for task #{}"
|
||||
BH_START_USER_TEST = "Start user testing for task #{}"
|
||||
BH_WAIT_BUG_REP_INSTRUCTIONS = "Awaiting bug reproduction instructions for task #{}"
|
||||
BH_START_BUG_HUNT = "Start bug hunt for task #{}"
|
||||
|
||||
|
||||
class BugHunter(ChatWithBreakdownMixin, BaseAgent):
|
||||
agent_type = "bug-hunter"
|
||||
display_name = "Bug Hunter"
|
||||
|
||||
@@ -11,6 +11,7 @@ from core.agents.convo import AgentConvo
|
||||
from core.agents.mixins import FileDiffMixin
|
||||
from core.agents.response import AgentResponse, ResponseType
|
||||
from core.config import CODE_MONKEY_AGENT_NAME, CODE_REVIEW_AGENT_NAME, DESCRIBE_FILES_AGENT_NAME
|
||||
from core.config.actions import CM_UPDATE_FILES
|
||||
from core.db.models import File
|
||||
from core.llm.parser import JSONParser, OptionalCodeBlockParser
|
||||
from core.log import get_logger
|
||||
@@ -57,9 +58,6 @@ class FileDescription(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
CM_UPDATE_FILES = "Updating files"
|
||||
|
||||
|
||||
class CodeMonkey(FileDiffMixin, BaseAgent):
|
||||
agent_type = "code-monkey"
|
||||
display_name = "Code Monkey"
|
||||
|
||||
@@ -10,6 +10,13 @@ from core.agents.convo import AgentConvo
|
||||
from core.agents.mixins import ChatWithBreakdownMixin, RelevantFilesMixin
|
||||
from core.agents.response import AgentResponse
|
||||
from core.config import PARSE_TASK_AGENT_NAME, TASK_BREAKDOWN_AGENT_NAME
|
||||
from core.config.actions import (
|
||||
DEV_TASK_BREAKDOWN,
|
||||
DEV_TASK_REVIEW_FEEDBACK,
|
||||
DEV_TASK_STARTING,
|
||||
DEV_TROUBLESHOOT,
|
||||
DEV_WAIT_TEST,
|
||||
)
|
||||
from core.db.models.project_state import IterationStatus, TaskStatus
|
||||
from core.db.models.specification import Complexity
|
||||
from core.llm.parser import JSONParser
|
||||
|
||||
@@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
|
||||
from core.agents.base import BaseAgent
|
||||
from core.agents.convo import AgentConvo
|
||||
from core.agents.response import AgentResponse
|
||||
from core.config.actions import EX_RUN_COMMAND, EX_SKIP_COMMAND
|
||||
from core.llm.parser import JSONParser
|
||||
from core.log import get_logger
|
||||
from core.proc.exec_log import ExecLog
|
||||
@@ -32,10 +33,6 @@ class CommandResult(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
EX_SKIP_COMMAND = 'Skip "{}"'
|
||||
EX_RUN_COMMAND = 'Run "{}"'
|
||||
|
||||
|
||||
class Executor(BaseAgent):
|
||||
agent_type = "executor"
|
||||
display_name = "Executor"
|
||||
|
||||
@@ -8,6 +8,7 @@ from core.agents.git import GitMixin
|
||||
from core.agents.mixins import FileDiffMixin
|
||||
from core.agents.response import AgentResponse
|
||||
from core.config import FRONTEND_AGENT_NAME
|
||||
from core.config.actions import FE_CONTINUE, FE_INIT, FE_ITERATION, FE_ITERATION_DONE, FE_START
|
||||
from core.llm.parser import DescriptiveCodeBlockParser
|
||||
from core.log import get_logger
|
||||
from core.telemetry import telemetry
|
||||
@@ -16,12 +17,6 @@ from core.ui.base import ProjectStage
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
FE_INIT = "Frontend init"
|
||||
FE_START = "Frontend start"
|
||||
FE_CONTINUE = "Frontend continue"
|
||||
FE_ITERATION = "Frontend iteration"
|
||||
FE_ITERATION_DONE = "Frontend iteration done"
|
||||
|
||||
|
||||
class Frontend(FileDiffMixin, GitMixin, BaseAgent):
|
||||
agent_type = "frontend"
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import asyncio
|
||||
import json
|
||||
from difflib import unified_diff
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.agents.convo import AgentConvo
|
||||
from core.agents.response import AgentResponse
|
||||
from core.cli.helpers import get_line_changes
|
||||
from core.config import GET_RELEVANT_FILES_AGENT_NAME, TASK_BREAKDOWN_AGENT_NAME, TROUBLESHOOTER_BUG_REPORT
|
||||
from core.config.constants import CONVO_ITERATIONS_LIMIT
|
||||
from core.llm.parser import JSONParser
|
||||
@@ -191,18 +191,4 @@ class FileDiffMixin:
|
||||
:return: a tuple (added_lines, deleted_lines)
|
||||
"""
|
||||
|
||||
from_lines = old_content.splitlines(keepends=True)
|
||||
to_lines = new_content.splitlines(keepends=True)
|
||||
|
||||
diff_gen = unified_diff(from_lines, to_lines)
|
||||
|
||||
added_lines = 0
|
||||
deleted_lines = 0
|
||||
|
||||
for line in diff_gen:
|
||||
if line.startswith("+") and not line.startswith("+++"): # Exclude the file headers
|
||||
added_lines += 1
|
||||
elif line.startswith("-") and not line.startswith("---"): # Exclude the file headers
|
||||
deleted_lines += 1
|
||||
|
||||
return added_lines, deleted_lines
|
||||
return get_line_changes(old_content, new_content)
|
||||
|
||||
@@ -2,6 +2,7 @@ from core.agents.base import BaseAgent
|
||||
from core.agents.convo import AgentConvo
|
||||
from core.agents.response import AgentResponse, ResponseType
|
||||
from core.config import SPEC_WRITER_AGENT_NAME
|
||||
from core.config.actions import SPEC_CHANGE_FEATURE_STEP_NAME, SPEC_CHANGE_STEP_NAME, SPEC_CREATE_STEP_NAME
|
||||
from core.db.models import Complexity
|
||||
from core.db.models.project_state import IterationStatus
|
||||
from core.llm.parser import StringParser
|
||||
@@ -14,10 +15,6 @@ ANALYZE_THRESHOLD = 1500
|
||||
INITIAL_PROJECT_HOWTO_URL = (
|
||||
"https://github.com/Pythagora-io/gpt-pilot/wiki/How-to-write-a-good-initial-project-description"
|
||||
)
|
||||
SPEC_CREATE_STEP_NAME = "Create specification"
|
||||
SPEC_CHANGE_STEP_NAME = "Change specification"
|
||||
SPEC_CHANGE_FEATURE_STEP_NAME = "Change specification due to new feature"
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
from core.agents.base import BaseAgent
|
||||
from core.agents.git import GitMixin
|
||||
from core.agents.response import AgentResponse
|
||||
from core.config.actions import TC_TASK_DONE
|
||||
from core.log import get_logger
|
||||
from core.telemetry import telemetry
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
TC_TASK_DONE = "Task #{} complete"
|
||||
|
||||
|
||||
class TaskCompleter(BaseAgent, GitMixin):
|
||||
agent_type = "pythagora"
|
||||
|
||||
@@ -9,6 +9,7 @@ from core.agents.convo import AgentConvo
|
||||
from core.agents.mixins import RelevantFilesMixin
|
||||
from core.agents.response import AgentResponse
|
||||
from core.config import TECH_LEAD_EPIC_BREAKDOWN, TECH_LEAD_PLANNING
|
||||
from core.config.actions import TL_CREATE_INITIAL_EPIC, TL_CREATE_PLAN, TL_INITIAL_PROJECT_NAME, TL_START_FEATURE
|
||||
from core.db.models import Complexity
|
||||
from core.db.models.project_state import TaskStatus
|
||||
from core.llm.parser import JSONParser
|
||||
@@ -46,11 +47,6 @@ class EpicPlan(BaseModel):
|
||||
plan: list[Task] = Field(description="List of tasks that need to be done to implement the entire epic.")
|
||||
|
||||
|
||||
TL_CREATE_INITIAL_EPIC = "Create initial project epic"
|
||||
TL_CREATE_PLAN = "Create a development plan for epic: {}"
|
||||
TL_START_FEATURE = "Start of feature #{}"
|
||||
|
||||
|
||||
class TechLead(RelevantFilesMixin, BaseAgent):
|
||||
agent_type = "tech-lead"
|
||||
display_name = "Tech Lead"
|
||||
@@ -94,7 +90,7 @@ class TechLead(RelevantFilesMixin, BaseAgent):
|
||||
self.next_state.epics = self.current_state.epics + [
|
||||
{
|
||||
"id": uuid4().hex,
|
||||
"name": "Initial Project",
|
||||
"name": TL_INITIAL_PROJECT_NAME,
|
||||
"source": "app",
|
||||
"description": self.current_state.specification.description,
|
||||
"test_instructions": None,
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
from core.agents.base import BaseAgent
|
||||
from core.agents.convo import AgentConvo
|
||||
from core.agents.response import AgentResponse
|
||||
from core.config.actions import TW_WRITE
|
||||
from core.db.models.project_state import TaskStatus
|
||||
from core.log import get_logger
|
||||
from core.ui.base import success_source
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
TW_WRITE = "Write documentation"
|
||||
|
||||
|
||||
class TechnicalWriter(BaseAgent):
|
||||
agent_type = "tech-writer"
|
||||
|
||||
@@ -9,6 +9,7 @@ from core.agents.convo import AgentConvo
|
||||
from core.agents.mixins import ChatWithBreakdownMixin, IterationPromptMixin, RelevantFilesMixin, TestSteps
|
||||
from core.agents.response import AgentResponse
|
||||
from core.config import TROUBLESHOOTER_GET_RUN_COMMAND
|
||||
from core.config.actions import TS_ALT_SOLUTION, TS_TASK_REVIEWED
|
||||
from core.db.models.file import File
|
||||
from core.db.models.project_state import IterationStatus, TaskStatus
|
||||
from core.llm.parser import JSONParser, OptionalCodeBlockParser
|
||||
@@ -31,10 +32,6 @@ class RouteFilePaths(BaseModel):
|
||||
files: list[str] = Field(description="List of paths for files that contain routes")
|
||||
|
||||
|
||||
TS_TASK_REVIEWED = "Task #{} reviewed"
|
||||
TS_ALT_SOLUTION = "Alternative solution (attempt #{})"
|
||||
|
||||
|
||||
class Troubleshooter(ChatWithBreakdownMixin, IterationPromptMixin, RelevantFilesMixin, BaseAgent):
|
||||
agent_type = "troubleshooter"
|
||||
display_name = "Troubleshooter"
|
||||
|
||||
@@ -3,22 +3,36 @@ import os
|
||||
import os.path
|
||||
import sys
|
||||
from argparse import ArgumentParser, ArgumentTypeError, Namespace
|
||||
from difflib import unified_diff
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
from uuid import UUID
|
||||
|
||||
from core.config import Config, LLMProvider, LocalIPCConfig, ProviderConfig, UIAdapter, get_config, loader
|
||||
from core.config.actions import (
|
||||
BH_START_BUG_HUNT,
|
||||
BH_START_USER_TEST,
|
||||
BH_STARTING_PAIR_PROGRAMMING,
|
||||
BH_WAIT_BUG_REP_INSTRUCTIONS,
|
||||
CM_UPDATE_FILES,
|
||||
DEV_TASK_BREAKDOWN,
|
||||
DEV_TASK_STARTING,
|
||||
DEV_TROUBLESHOOT,
|
||||
TC_TASK_DONE,
|
||||
)
|
||||
from core.config.env_importer import import_from_dotenv
|
||||
from core.config.version import get_version
|
||||
from core.db.session import SessionManager
|
||||
from core.db.setup import run_migrations
|
||||
from core.log import setup
|
||||
from core.log import get_logger, setup
|
||||
from core.state.state_manager import StateManager
|
||||
from core.ui.base import UIBase
|
||||
from core.ui.console import PlainConsoleUI
|
||||
from core.ui.ipc_client import IPCClientUI
|
||||
from core.ui.virtual import VirtualUI
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_llm_endpoint(value: str) -> Optional[tuple[LLMProvider, str]]:
|
||||
"""
|
||||
@@ -47,6 +61,35 @@ def parse_llm_endpoint(value: str) -> Optional[tuple[LLMProvider, str]]:
|
||||
return provider, url.geturl()
|
||||
|
||||
|
||||
def get_line_changes(old_content: str, new_content: str) -> tuple[int, int]:
|
||||
"""
|
||||
Get the number of added and deleted lines between two files.
|
||||
|
||||
This uses Python difflib to produce a unified diff, then counts
|
||||
the number of added and deleted lines.
|
||||
|
||||
:param old_content: old file content
|
||||
:param new_content: new file content
|
||||
:return: a tuple (added_lines, deleted_lines)
|
||||
"""
|
||||
|
||||
from_lines = old_content.splitlines(keepends=True)
|
||||
to_lines = new_content.splitlines(keepends=True)
|
||||
|
||||
diff_gen = unified_diff(from_lines, to_lines)
|
||||
|
||||
added_lines = 0
|
||||
deleted_lines = 0
|
||||
|
||||
for line in diff_gen:
|
||||
if line.startswith("+") and not line.startswith("+++"): # Exclude the file headers
|
||||
added_lines += 1
|
||||
elif line.startswith("-") and not line.startswith("---"): # Exclude the file headers
|
||||
deleted_lines += 1
|
||||
|
||||
return added_lines, deleted_lines
|
||||
|
||||
|
||||
def parse_llm_key(value: str) -> Optional[tuple[LLMProvider, str]]:
|
||||
"""
|
||||
Parse --llm-key command-line option.
|
||||
@@ -228,6 +271,125 @@ async def list_projects_json(db: SessionManager):
|
||||
print(json.dumps(data, indent=2))
|
||||
|
||||
|
||||
async def load_convo(
|
||||
sm: StateManager,
|
||||
project_id: Optional[UUID] = None,
|
||||
step_index: Optional[int] = None,
|
||||
) -> list:
|
||||
"""
|
||||
List all projects in the database.
|
||||
"""
|
||||
convo = []
|
||||
|
||||
project_states = await sm.get_project_states(project_id)
|
||||
project_states = [state for state in project_states if 0 <= state.step_index <= step_index]
|
||||
|
||||
branches = await sm.get_branches_for_project_id(project_id)
|
||||
branch_id = branches[0].id
|
||||
|
||||
task_counter = 1
|
||||
|
||||
for i, state in enumerate(project_states):
|
||||
prev_state = project_states[i - 1] if i > 0 else None
|
||||
|
||||
convo_el = {}
|
||||
ui = await sm.find_user_input(state, branch_id)
|
||||
|
||||
if ui is not None and ui.question is not None:
|
||||
convo_el["question"] = ui.question
|
||||
if ui.answer_text is not None:
|
||||
convo_el["answer"] = str(ui.answer_text)
|
||||
else:
|
||||
convo_el["answer"] = str(ui.answer_button)
|
||||
|
||||
if state.action is not None:
|
||||
if DEV_TASK_STARTING[:-2] in state.action:
|
||||
task_counter = int(state.action.split("#")[-1])
|
||||
|
||||
elif state.action == DEV_TROUBLESHOOT.format(task_counter):
|
||||
if state.iterations is not None:
|
||||
si = state.iterations[-1]
|
||||
if si is not None:
|
||||
if si["user_feedback"] is not None:
|
||||
convo_el["user_feedback"] = si["user_feedback"]
|
||||
if si["description"] is not None:
|
||||
convo_el["description"] = si["description"]
|
||||
|
||||
elif state.action == DEV_TASK_BREAKDOWN.format(task_counter):
|
||||
task = state.tasks[task_counter - 1]
|
||||
if "description" in task and task["description"] is not None:
|
||||
convo_el["description"] = task["description"]
|
||||
|
||||
if "instructions" in task and task["instructions"] is not None:
|
||||
convo_el["instructions"] = task["instructions"]
|
||||
|
||||
elif state.action == TC_TASK_DONE.format(task_counter):
|
||||
if len(state.tasks) > 0:
|
||||
task = state.tasks[task_counter - 1]
|
||||
if "test_instructions" in task and task["test_instructions"] is not None:
|
||||
convo_el["test_instructions"] = task["test_instructions"]
|
||||
|
||||
elif state.action == CM_UPDATE_FILES:
|
||||
files = {}
|
||||
for steps in state.steps:
|
||||
if "save_file" in steps and "path" in steps["save_file"]:
|
||||
path = steps["save_file"]["path"]
|
||||
files["path"] = path
|
||||
|
||||
current_file = await sm.get_file_for_project(state.id, path)
|
||||
prev_file = await sm.get_file_for_project(prev_state.id, path)
|
||||
|
||||
if current_file and prev_file:
|
||||
files["diff"] = get_line_changes(
|
||||
old_content=prev_file.content.content, new_content=current_file.content.content
|
||||
)
|
||||
|
||||
convo_el["files"] = files
|
||||
|
||||
elif state.action == BH_START_BUG_HUNT.format(task_counter):
|
||||
si = state.iterations[-1]
|
||||
if si is not None:
|
||||
if "user_feedback" in si and si["user_feedback"] is not None:
|
||||
convo_el["user_feedback"] = si["user_feedback"]
|
||||
|
||||
if "description" in si and si["description"] is not None:
|
||||
convo_el["description"] = si["description"]
|
||||
|
||||
elif state.action == BH_WAIT_BUG_REP_INSTRUCTIONS.format(task_counter):
|
||||
if state.iterations is not None:
|
||||
for si in state.iterations:
|
||||
if "bug_reproduction_description" in si and si["bug_reproduction_description"] is not None:
|
||||
convo_el["bug_reproduction_description"] = si["bug_reproduction_description"]
|
||||
|
||||
elif state.action == BH_START_USER_TEST.format(task_counter):
|
||||
si = state.iterations[-1]
|
||||
if si is not None:
|
||||
if "bug_hunting_cycles" in si and si["bug_hunting_cycles"] is not None:
|
||||
cycle = si["bug_hunting_cycles"][-1]
|
||||
if cycle is not None:
|
||||
if "user_feedback" in cycle and cycle["user_feedback"] is not None:
|
||||
convo_el["user_feedback"] = cycle["user_feedback"]
|
||||
if (
|
||||
"human_readable_instructions" in cycle
|
||||
and cycle["human_readable_instructions"] is not None
|
||||
):
|
||||
convo_el["human_readable_instructions"] = cycle["human_readable_instructions"]
|
||||
|
||||
elif state.action == BH_STARTING_PAIR_PROGRAMMING.format(task_counter):
|
||||
si = state.iterations[-1]
|
||||
if si is not None:
|
||||
if "user_feedback" in si and si["user_feedback"] is not None:
|
||||
convo_el["user_feedback"] = si["user_feedback"]
|
||||
if "initial_explanation" in si and si["initial_explanation"] is not None:
|
||||
convo_el["initial_explanation"] = si["initial_explanation"]
|
||||
|
||||
if len(convo_el.keys()) > 0:
|
||||
convo_el["action"] = state.action
|
||||
convo.append(convo_el)
|
||||
|
||||
return convo
|
||||
|
||||
|
||||
async def list_projects(db: SessionManager):
|
||||
"""
|
||||
List all projects in the database.
|
||||
|
||||
@@ -14,7 +14,15 @@ except ImportError:
|
||||
SENTRY_AVAILABLE = False
|
||||
|
||||
from core.agents.orchestrator import Orchestrator
|
||||
from core.cli.helpers import delete_project, init, list_projects, list_projects_json, load_project, show_config
|
||||
from core.cli.helpers import (
|
||||
delete_project,
|
||||
init,
|
||||
list_projects,
|
||||
list_projects_json,
|
||||
load_convo,
|
||||
load_project,
|
||||
show_config,
|
||||
)
|
||||
from core.db.session import SessionManager
|
||||
from core.db.v0importer import LegacyDatabaseImporter
|
||||
from core.llm.anthropic_client import CustomAssertionError
|
||||
@@ -196,6 +204,10 @@ async def run_pythagora_session(sm: StateManager, ui: UIBase, args: Namespace):
|
||||
:return: True if the application ran successfully, False otherwise.
|
||||
"""
|
||||
|
||||
if args.project and args.step:
|
||||
convo = await load_convo(sm, args.project, args.step)
|
||||
log.debug(f"Convo exists: {len(convo) > 0}")
|
||||
|
||||
if args.project or args.branch or args.step:
|
||||
telemetry.set("is_continuation", True)
|
||||
success = await load_project(sm, args.project, args.branch, args.step)
|
||||
|
||||
41
core/config/actions.py
Normal file
41
core/config/actions.py
Normal file
@@ -0,0 +1,41 @@
|
||||
BH_START_BUG_HUNT = "Start bug hunt for task #{}"
|
||||
BH_WAIT_BUG_REP_INSTRUCTIONS = "Awaiting bug reproduction instructions for task #{}"
|
||||
BH_START_USER_TEST = "Start user testing for task #{}"
|
||||
BH_STARTING_PAIR_PROGRAMMING = "Start pair programming for task #{}"
|
||||
|
||||
CM_UPDATE_FILES = "Updating files"
|
||||
|
||||
|
||||
DEV_WAIT_TEST = "Awaiting user test"
|
||||
DEV_TASK_STARTING = "Starting task #{}"
|
||||
DEV_TASK_BREAKDOWN = "Task #{} breakdown"
|
||||
DEV_TROUBLESHOOT = "Troubleshooting #{}"
|
||||
DEV_TASK_REVIEW_FEEDBACK = "Task review feedback"
|
||||
|
||||
TC_TASK_DONE = "Task #{} complete"
|
||||
|
||||
|
||||
FE_INIT = "Frontend init"
|
||||
FE_START = "Frontend start"
|
||||
FE_CONTINUE = "Frontend continue"
|
||||
FE_ITERATION = "Frontend iteration"
|
||||
FE_ITERATION_DONE = "Frontend iteration done"
|
||||
|
||||
TL_CREATE_INITIAL_EPIC = "Create initial project epic"
|
||||
TL_CREATE_PLAN = "Create a development plan for epic: {}"
|
||||
TL_START_FEATURE = "Start of feature #{}"
|
||||
TL_INITIAL_PROJECT_NAME = "Initial Project"
|
||||
|
||||
TW_WRITE = "Write documentation"
|
||||
|
||||
EX_SKIP_COMMAND = 'Skip "{}"'
|
||||
EX_RUN_COMMAND = 'Run "{}"'
|
||||
|
||||
SPEC_CREATE_STEP_NAME = "Create specification"
|
||||
SPEC_CHANGE_STEP_NAME = "Change specification"
|
||||
SPEC_CHANGE_FEATURE_STEP_NAME = "Change specification due to new feature"
|
||||
|
||||
TS_TASK_REVIEWED = "Task #{} reviewed"
|
||||
TS_ALT_SOLUTION = "Alternative solution (attempt #{})"
|
||||
|
||||
PS_EPIC_COMPLETE = "Epic {} completed"
|
||||
@@ -9,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from core.db.models import Base
|
||||
from core.db.models import Base, File
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.db.models import Branch
|
||||
@@ -66,6 +66,20 @@ class Project(Base):
|
||||
result = await session.execute(select(Branch).where(Branch.project_id == self.id, Branch.name == name))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
async def get_file_for_project(session: AsyncSession, project_state_id: UUID, path: str) -> Optional["File"]:
|
||||
file_result = await session.execute(
|
||||
select(File).where(File.project_state_id == project_state_id, File.path == path)
|
||||
)
|
||||
return file_result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
async def get_branches_for_project_id(session: AsyncSession, project_id: UUID) -> list["Branch"]:
|
||||
from core.db.models import Branch
|
||||
|
||||
branch_result = await session.execute(select(Branch).where(Branch.project_id == project_id))
|
||||
return branch_result.scalars().all()
|
||||
|
||||
@staticmethod
|
||||
async def get_all_projects(session: "AsyncSession") -> list["Project"]:
|
||||
"""
|
||||
|
||||
@@ -3,12 +3,13 @@ from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import ForeignKey, UniqueConstraint, delete, inspect
|
||||
from sqlalchemy import ForeignKey, UniqueConstraint, delete, inspect, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from core.config.actions import PS_EPIC_COMPLETE
|
||||
from core.db.models import Base, FileContent
|
||||
from core.log import get_logger
|
||||
|
||||
@@ -46,10 +47,6 @@ class IterationStatus:
|
||||
DONE = "done"
|
||||
|
||||
|
||||
PS_EPIC_COMPLETE = "Epic {} completed"
|
||||
PS_TASK_COMPLETE = "Task {} completed"
|
||||
|
||||
|
||||
class ProjectState(Base):
|
||||
__tablename__ = "project_states"
|
||||
__table_args__ = (
|
||||
@@ -218,6 +215,19 @@ class ProjectState(Base):
|
||||
step_index=1,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_project_states(
|
||||
session: "AsyncSession",
|
||||
project_id: UUID,
|
||||
) -> list["ProjectState"]:
|
||||
from core.db.models import Branch, ProjectState
|
||||
|
||||
branch = await session.execute(select(Branch).where(Branch.project_id == project_id))
|
||||
branch = branch.scalar_one_or_none()
|
||||
|
||||
project_states_result = await session.execute(select(ProjectState).where(ProjectState.branch_id == branch.id))
|
||||
return project_states_result.scalars().all()
|
||||
|
||||
async def create_next_state(self) -> "ProjectState":
|
||||
"""
|
||||
Create the next project state for the branch.
|
||||
|
||||
@@ -2,7 +2,8 @@ from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import ForeignKey, inspect
|
||||
from sqlalchemy import ForeignKey, and_, inspect, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
@@ -57,3 +58,24 @@ class UserInput(Base):
|
||||
)
|
||||
session.add(obj)
|
||||
return obj
|
||||
|
||||
@staticmethod
|
||||
async def find_user_input(session: AsyncSession, project_state, branch_id) -> Optional["UserInput"]:
|
||||
from core.db.models import UserInput
|
||||
|
||||
user_input = await session.execute(
|
||||
select(UserInput).where(
|
||||
and_(UserInput.branch_id == branch_id, UserInput.project_state_id == project_state.id)
|
||||
)
|
||||
)
|
||||
user_input = user_input.scalars().all()
|
||||
|
||||
if user_input is None:
|
||||
user_input = await session.execute(
|
||||
select(UserInput).where(
|
||||
and_(UserInput.branch_id == branch_id, UserInput.project_state_id == project_state.prev_state_id)
|
||||
)
|
||||
)
|
||||
user_input = user_input.scalars().all()
|
||||
|
||||
return user_input[0] if len(user_input) > 0 else None
|
||||
|
||||
@@ -78,6 +78,26 @@ class StateManager:
|
||||
async with self.session_manager as session:
|
||||
return await Project.get_all_projects(session)
|
||||
|
||||
async def get_convo(self):
|
||||
async with self.session_manager as session:
|
||||
return await Project.get_convo(session)
|
||||
|
||||
async def get_project_states(self, project_id: UUID) -> list[ProjectState]:
|
||||
async with self.session_manager as session:
|
||||
return await ProjectState.get_project_states(session, project_id)
|
||||
|
||||
async def get_branches_for_project_id(self, project_id: UUID) -> list[Branch]:
|
||||
async with self.session_manager as session:
|
||||
return await Project.get_branches_for_project_id(session, project_id)
|
||||
|
||||
async def find_user_input(self, project_state, branch_id) -> Optional["UserInput"]:
|
||||
async with self.session_manager as session:
|
||||
return await UserInput.find_user_input(session, project_state, branch_id)
|
||||
|
||||
async def get_file_for_project(self, state_id: UUID, path: str):
|
||||
async with self.session_manager as session:
|
||||
return await Project.get_file_for_project(session, state_id, path)
|
||||
|
||||
async def create_project(self, name: str, folder_name: Optional[str] = None) -> Project:
|
||||
"""
|
||||
Create a new project and set it as the current one.
|
||||
|
||||
Reference in New Issue
Block a user