mirror of
https://github.com/Pythagora-io/gpt-pilot.git
synced 2026-01-08 12:53:50 -05:00
refactor and optimize get_task_conversation_project_states
This commit is contained in:
@@ -374,7 +374,9 @@ class Developer(ChatWithBreakdownMixin, RelevantFilesMixin, BaseAgent):
|
||||
previous_task = tasks_done[-1] if tasks_done else None
|
||||
if previous_task:
|
||||
e_i, t_i = get_epic_task_number(self.current_state, previous_task)
|
||||
task_convo = await self.state_manager.get_task_conversation_project_states(UUID(previous_task["id"]))
|
||||
task_convo = await self.state_manager.get_task_conversation_project_states(
|
||||
UUID(previous_task["id"]), first_last_only=True
|
||||
)
|
||||
await self.ui.send_back_logs(
|
||||
[
|
||||
{
|
||||
|
||||
@@ -743,44 +743,62 @@ class ProjectState(Base):
|
||||
|
||||
@staticmethod
|
||||
async def get_task_conversation_project_states(
|
||||
session: "AsyncSession", branch_id: UUID, task_id: UUID
|
||||
session: "AsyncSession",
|
||||
branch_id: UUID,
|
||||
task_id: UUID,
|
||||
first_last_only: bool = False,
|
||||
limit: Optional[int] = 25,
|
||||
) -> Optional[list["ProjectState"]]:
|
||||
"""
|
||||
Retrieve the conversation for the task in the project state.
|
||||
|
||||
:param session: The SQLAlchemy async session.
|
||||
:param state_id: The UUID of the project state.
|
||||
:param branch_id: The UUID of the branch.
|
||||
:param task_id: The UUID of the task.
|
||||
:param first_last_only: If True, return only first and last states.
|
||||
:param limit: Maximum number of states to return (default 25).
|
||||
:return: List of conversation messages if found, None otherwise.
|
||||
"""
|
||||
query = select(ProjectState).where(
|
||||
and_(
|
||||
ProjectState.branch_id == branch_id,
|
||||
or_(ProjectState.action.like("%Task #%"), ProjectState.action.like("%Create a development plan%")),
|
||||
)
|
||||
log.debug(
|
||||
f"Getting task conversation project states for task {task_id} in branch {branch_id} with first_last_only {first_last_only} and limit {limit}"
|
||||
)
|
||||
# First, we need to find the start and end step indices
|
||||
# Use a more efficient query that only loads necessary fields
|
||||
query = (
|
||||
select(ProjectState)
|
||||
.options(load_only(ProjectState.id, ProjectState.step_index, ProjectState.tasks, ProjectState.action))
|
||||
.where(
|
||||
and_(
|
||||
ProjectState.branch_id == branch_id,
|
||||
or_(ProjectState.action.like("%Task #%"), ProjectState.action.like("%Create a development plan%")),
|
||||
)
|
||||
)
|
||||
.order_by(ProjectState.step_index)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
states = result.scalars().all()
|
||||
|
||||
log.debug(f"Found {len(states)} states with custom action")
|
||||
|
||||
start = -1
|
||||
end = -1
|
||||
start_step_index = None
|
||||
end_step_index = None
|
||||
|
||||
# for the FIRST task, it is todo in the same state as Create a development plan, while other tasks are "Task #N start" (action)
|
||||
|
||||
# this is done solely to be able to reload to the first task, due to the fact that we need the same project_state_id for the send_back_logs
|
||||
# for the first task, we need to start from the FIRST state that has that task in TODO status
|
||||
# for all other tasks, we need to start from LAST state that has that task in TODO status
|
||||
for i, state in enumerate(states):
|
||||
for state in states:
|
||||
for task in state.tasks:
|
||||
if UUID(task["id"]) == task_id and task.get("status", "") == TaskStatus.TODO:
|
||||
if UUID(task["id"]) == UUID(state.tasks[0]["id"]):
|
||||
# First task: set start only once (first occurrence)
|
||||
if start == -1:
|
||||
start = i
|
||||
if start_step_index is None:
|
||||
start_step_index = state.step_index
|
||||
else:
|
||||
# Other tasks: update start every time (last occurrence)
|
||||
start = i
|
||||
start_step_index = state.step_index
|
||||
|
||||
if UUID(task["id"]) == task_id and task.get("status", "") in [
|
||||
TaskStatus.SKIPPED,
|
||||
@@ -788,37 +806,88 @@ class ProjectState(Base):
|
||||
TaskStatus.REVIEWED,
|
||||
TaskStatus.DONE,
|
||||
]:
|
||||
end = i
|
||||
end_step_index = state.step_index
|
||||
|
||||
if end == -1:
|
||||
query = select(ProjectState).where(
|
||||
and_(
|
||||
ProjectState.branch_id == branch_id,
|
||||
ProjectState.step_index >= states[start].step_index,
|
||||
if start_step_index is None:
|
||||
return []
|
||||
|
||||
# Now build the optimized query based on what we need
|
||||
if first_last_only:
|
||||
# For first_last_only, we only need the first and last states
|
||||
# Get first state
|
||||
first_query = (
|
||||
select(ProjectState)
|
||||
.where(
|
||||
and_(
|
||||
ProjectState.branch_id == branch_id,
|
||||
ProjectState.step_index >= start_step_index,
|
||||
ProjectState.step_index < end_step_index if end_step_index else True,
|
||||
)
|
||||
)
|
||||
.order_by(ProjectState.step_index.asc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# Get last state (excluding the uncommitted one)
|
||||
last_query = (
|
||||
select(ProjectState)
|
||||
.where(
|
||||
and_(
|
||||
ProjectState.branch_id == branch_id,
|
||||
ProjectState.step_index >= start_step_index,
|
||||
ProjectState.step_index < end_step_index if end_step_index else True,
|
||||
)
|
||||
)
|
||||
.order_by(ProjectState.step_index.desc())
|
||||
.limit(2)
|
||||
) # Get last 2 to exclude uncommitted
|
||||
|
||||
first_result = await session.execute(first_query)
|
||||
last_result = await session.execute(last_query)
|
||||
|
||||
first_state = first_result.scalars().first()
|
||||
last_states = last_result.scalars().all()
|
||||
|
||||
# Remove the last state (uncommitted) and get the actual last
|
||||
if len(last_states) > 1:
|
||||
last_state = last_states[1] # Second to last is the actual last committed
|
||||
else:
|
||||
last_state = first_state # Only one state
|
||||
|
||||
if first_state and last_state and first_state.id != last_state.id:
|
||||
return [first_state, last_state]
|
||||
elif first_state:
|
||||
return [first_state]
|
||||
else:
|
||||
return []
|
||||
|
||||
else:
|
||||
query = select(ProjectState).where(
|
||||
and_(
|
||||
ProjectState.branch_id == branch_id,
|
||||
ProjectState.step_index >= states[start].step_index,
|
||||
ProjectState.step_index < states[end].step_index,
|
||||
# For regular queries, apply limit at the database level
|
||||
query = (
|
||||
select(ProjectState)
|
||||
.where(
|
||||
and_(
|
||||
ProjectState.branch_id == branch_id,
|
||||
ProjectState.step_index >= start_step_index,
|
||||
ProjectState.step_index < end_step_index if end_step_index else True,
|
||||
)
|
||||
)
|
||||
.order_by(ProjectState.step_index.asc())
|
||||
)
|
||||
result = await session.execute(query)
|
||||
results = result.scalars().all()
|
||||
|
||||
# Remove the last state from the list because that state is not yet committed in the database!
|
||||
results = results[:-1]
|
||||
if limit:
|
||||
# Apply limit + 1 to account for removing the last uncommitted state
|
||||
query = query.limit(limit + 1)
|
||||
|
||||
# index = -1
|
||||
# for i, state in enumerate(results):
|
||||
# if state.action and "Task #" in state.action and "start" in state.action:
|
||||
# index = i
|
||||
# break
|
||||
#
|
||||
# return results[index:]
|
||||
return results
|
||||
result = await session.execute(query)
|
||||
results = result.scalars().all()
|
||||
|
||||
log.debug(f"Found {len(results)} states with custom action")
|
||||
# Remove the last state from the list because that state is not yet committed in the database!
|
||||
if results:
|
||||
results = results[:-1]
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
async def get_fe_states(session: "AsyncSession", branch_id: UUID) -> Optional["ProjectState"]:
|
||||
|
||||
@@ -137,13 +137,15 @@ class StateManager:
|
||||
async def get_project_state_for_convo_id(self, convo_id) -> Optional["ProjectState"]:
|
||||
return await ChatConvo.get_project_state_for_convo_id(self.current_session, convo_id)
|
||||
|
||||
async def get_task_conversation_project_states(self, task_id: UUID) -> Optional[list[ProjectState]]:
|
||||
async def get_task_conversation_project_states(
|
||||
self, task_id: UUID, first_last_only: bool = False, limit: Optional[int] = 25
|
||||
) -> Optional[list[ProjectState]]:
|
||||
"""
|
||||
Get all project states for a specific task conversation.
|
||||
This retrieves all project states that are associated with a specific task
|
||||
"""
|
||||
return await ProjectState.get_task_conversation_project_states(
|
||||
self.current_session, self.current_state.branch_id, task_id
|
||||
self.current_session, self.current_state.branch_id, task_id, first_last_only, limit
|
||||
)
|
||||
|
||||
async def get_project_states_in_between(self, start_state_id: UUID, end_state_id: UUID) -> list[ProjectState]:
|
||||
|
||||
Reference in New Issue
Block a user