Merge pull request #260 from Pythagora-io/optimizations

Optimizations
This commit is contained in:
LeonOstrez
2025-07-08 22:06:01 +02:00
committed by GitHub
5 changed files with 146 additions and 49 deletions

View File

@@ -374,7 +374,9 @@ class Developer(ChatWithBreakdownMixin, RelevantFilesMixin, BaseAgent):
previous_task = tasks_done[-1] if tasks_done else None
if previous_task:
e_i, t_i = get_epic_task_number(self.current_state, previous_task)
task_convo = await self.state_manager.get_task_conversation_project_states(UUID(previous_task["id"]))
task_convo = await self.state_manager.get_task_conversation_project_states(
UUID(previous_task["id"]), first_last_only=True
)
await self.ui.send_back_logs(
[
{

View File

@@ -1,5 +1,6 @@
import asyncio
import atexit
import gc
import signal
import sys
import traceback
@@ -273,6 +274,9 @@ async def run_pythagora_session(sm: StateManager, ui: UIBase, args: Namespace):
convo = await load_convo(sm, project_states=fe_states)
await print_convo(ui=ui, convo=convo, fake=False)
# Clear fe_states from memory after conversation is loaded
del fe_states
gc.collect() # Force garbage collection to free memory immediately
if last_task_in_db:
# if there is a task in the db (we are at backend stage), print backend convo history and add task back logs and front logs headers
await ui.send_front_logs_headers(
@@ -297,6 +301,10 @@ async def run_pythagora_session(sm: StateManager, ui: UIBase, args: Namespace):
convo = await load_convo(sm, project_states=be_states)
await print_convo(ui=ui, convo=convo, fake=False)
# Clear be_states from memory after conversation is loaded
del be_states
gc.collect() # Force garbage collection to free memory immediately
else:
sm.fe_auto_debug = True
success = await start_new_project(sm, ui, args)

View File

@@ -708,7 +708,9 @@ class ProjectState(Base):
return epics_and_tasks
@staticmethod
async def get_project_states_in_between(session: "AsyncSession", branch_id: UUID, start_id: UUID, end_id: UUID):
async def get_project_states_in_between(
session: "AsyncSession", branch_id: UUID, start_id: UUID, end_id: UUID, limit: Optional[int] = 100
):
query = select(ProjectState).where(
and_(
ProjectState.branch_id == branch_id,
@@ -731,56 +733,85 @@ class ProjectState(Base):
log.error(f"Could not find states with IDs {start_id} and {end_id} in branch {branch_id}")
return []
query = select(ProjectState).where(
and_(
ProjectState.branch_id == branch_id,
ProjectState.step_index >= start_state.step_index,
ProjectState.step_index <= end_state.step_index,
query = (
select(ProjectState)
.where(
and_(
ProjectState.branch_id == branch_id,
ProjectState.step_index >= start_state.step_index,
ProjectState.step_index <= end_state.step_index,
)
)
.order_by(ProjectState.step_index.desc())
)
if limit:
query = query.limit(limit)
result = await session.execute(query)
return result.scalars().all()
states = result.scalars().all()
# Since we always order by step_index desc, we need to reverse to get chronological order
return list(reversed(states))
@staticmethod
async def get_task_conversation_project_states(
session: "AsyncSession", branch_id: UUID, task_id: UUID
session: "AsyncSession",
branch_id: UUID,
task_id: UUID,
first_last_only: bool = False,
limit: Optional[int] = 25,
) -> Optional[list["ProjectState"]]:
"""
Retrieve the conversation for the task in the project state.
:param session: The SQLAlchemy async session.
:param state_id: The UUID of the project state.
:param branch_id: The UUID of the branch.
:param task_id: The UUID of the task.
:param first_last_only: If True, return only first and last states.
:param limit: Maximum number of states to return (default 25).
:return: List of project states belonging to the task conversation; an empty list if no starting state for the task is found.
"""
query = select(ProjectState).where(
and_(
ProjectState.branch_id == branch_id,
or_(ProjectState.action.like("%Task #%"), ProjectState.action.like("%Create a development plan%")),
)
log.debug(
f"Getting task conversation project states for task {task_id} in branch {branch_id} with first_last_only {first_last_only} and limit {limit}"
)
# First, we need to find the start and end step indices
# Use a more efficient query that only loads necessary fields
query = (
select(ProjectState)
.options(load_only(ProjectState.id, ProjectState.step_index, ProjectState.tasks, ProjectState.action))
.where(
and_(
ProjectState.branch_id == branch_id,
or_(ProjectState.action.like("%Task #%"), ProjectState.action.like("%Create a development plan%")),
)
)
.order_by(ProjectState.step_index)
)
result = await session.execute(query)
states = result.scalars().all()
log.debug(f"Found {len(states)} states with custom action")
start = -1
end = -1
start_step_index = None
end_step_index = None
# for the FIRST task, it is todo in the same state as Create a development plan, while other tasks are "Task #N start" (action)
# this is done solely to be able to reload to the first task, due to the fact that we need the same project_state_id for the send_back_logs
# for the first task, we need to start from the FIRST state that has that task in TODO status
# for all other tasks, we need to start from LAST state that has that task in TODO status
for i, state in enumerate(states):
for state in states:
for task in state.tasks:
if UUID(task["id"]) == task_id and task.get("status", "") == TaskStatus.TODO:
if UUID(task["id"]) == UUID(state.tasks[0]["id"]):
# First task: set start only once (first occurrence)
if start == -1:
start = i
if start_step_index is None:
start_step_index = state.step_index
else:
# Other tasks: update start every time (last occurrence)
start = i
start_step_index = state.step_index
if UUID(task["id"]) == task_id and task.get("status", "") in [
TaskStatus.SKIPPED,
@@ -788,37 +819,88 @@ class ProjectState(Base):
TaskStatus.REVIEWED,
TaskStatus.DONE,
]:
end = i
end_step_index = state.step_index
if end == -1:
query = select(ProjectState).where(
and_(
ProjectState.branch_id == branch_id,
ProjectState.step_index >= states[start].step_index,
if start_step_index is None:
return []
# Now build the optimized query based on what we need
if first_last_only:
# For first_last_only, we only need the first and last states
# Get first state
first_query = (
select(ProjectState)
.where(
and_(
ProjectState.branch_id == branch_id,
ProjectState.step_index >= start_step_index,
ProjectState.step_index < end_step_index if end_step_index else True,
)
)
.order_by(ProjectState.step_index.asc())
.limit(1)
)
# Get last state (excluding the uncommitted one)
last_query = (
select(ProjectState)
.where(
and_(
ProjectState.branch_id == branch_id,
ProjectState.step_index >= start_step_index,
ProjectState.step_index < end_step_index if end_step_index else True,
)
)
.order_by(ProjectState.step_index.desc())
.limit(2)
) # Get last 2 to exclude uncommitted
first_result = await session.execute(first_query)
last_result = await session.execute(last_query)
first_state = first_result.scalars().first()
last_states = last_result.scalars().all()
# Remove the last state (uncommitted) and get the actual last
if len(last_states) > 1:
last_state = last_states[1] # Second to last is the actual last committed
else:
last_state = first_state # Only one state
if first_state and last_state and first_state.id != last_state.id:
return [first_state, last_state]
elif first_state:
return [first_state]
else:
return []
else:
query = select(ProjectState).where(
and_(
ProjectState.branch_id == branch_id,
ProjectState.step_index >= states[start].step_index,
ProjectState.step_index < states[end].step_index,
# For regular queries, apply limit at the database level
query = (
select(ProjectState)
.where(
and_(
ProjectState.branch_id == branch_id,
ProjectState.step_index >= start_step_index,
ProjectState.step_index < end_step_index if end_step_index else True,
)
)
.order_by(ProjectState.step_index.asc())
)
result = await session.execute(query)
results = result.scalars().all()
# Remove the last state from the list because that state is not yet committed in the database!
results = results[:-1]
if limit:
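# Apply limit + 1 to account for removing the last uncommitted state
# NOTE(review): the query is ordered by step_index ascending, so when more than limit + 1 states match, the row dropped below is not the newest (uncommitted) state - confirm this is intended.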
query = query.limit(limit + 1)
# index = -1
# for i, state in enumerate(results):
# if state.action and "Task #" in state.action and "start" in state.action:
# index = i
# break
#
# return results[index:]
return results
result = await session.execute(query)
results = result.scalars().all()
log.debug(f"Found {len(results)} states with custom action")
# Remove the last state from the list because that state is not yet committed in the database!
if results:
results = results[:-1]
return results
@staticmethod
async def get_fe_states(session: "AsyncSession", branch_id: UUID) -> Optional["ProjectState"]:

View File

@@ -137,22 +137,26 @@ class StateManager:
async def get_project_state_for_convo_id(self, convo_id) -> Optional["ProjectState"]:
    """
    Delegate to ``ChatConvo.get_project_state_for_convo_id`` using the current session.

    :param convo_id: Conversation identifier, forwarded unchanged to the lookup.
    :return: The result of the ChatConvo lookup (annotated as an optional ProjectState).
    """
    project_state = await ChatConvo.get_project_state_for_convo_id(self.current_session, convo_id)
    return project_state
async def get_task_conversation_project_states(self, task_id: UUID) -> Optional[list[ProjectState]]:
async def get_task_conversation_project_states(
self, task_id: UUID, first_last_only: bool = False, limit: Optional[int] = 25
) -> Optional[list[ProjectState]]:
"""
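Get the project states for a specific task conversation.
This retrieves the project states associated with a specific task in the current branch; `first_last_only` and `limit` are forwarded to ProjectState.get_task_conversation_project_states to restrict the result.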
"""
return await ProjectState.get_task_conversation_project_states(
self.current_session, self.current_state.branch_id, task_id
self.current_session, self.current_state.branch_id, task_id, first_last_only, limit
)
async def get_project_states_in_between(self, start_state_id: UUID, end_state_id: UUID) -> list[ProjectState]:
async def get_project_states_in_between(
self, start_state_id: UUID, end_state_id: UUID, limit: Optional[int] = 100
) -> list[ProjectState]:
"""
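Get the project states in between two states.
This retrieves the project states of the current branch that lie between the two given states; `limit` is forwarded to ProjectState.get_project_states_in_between to cap the number of states returned.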
"""
return await ProjectState.get_project_states_in_between(
self.current_session, self.current_state.branch_id, start_state_id, end_state_id
self.current_session, self.current_state.branch_id, start_state_id, end_state_id, limit
)
async def get_fe_states(self) -> Optional[ProjectState]:

View File

@@ -665,13 +665,14 @@ class IPCServer:
:param message: Request message.
:param writer: Stream writer to send response.
"""
log.debug("Got _handle_task_convo request with message: %s", message)
log.debug("Got _handle_task_convo request")
try:
task_id = message.content.get("task_id", "")
if task_id:
task_id = uuid.UUID(task_id)
start_project_id = uuid.UUID(message.content.get("start_id", ""))
end_project_id = uuid.UUID(message.content.get("end_id", ""))
log.debug(f"task_id: {task_id}, start_project_id: {start_project_id}, end_project_id: {end_project_id}")
if start_project_id and end_project_id:
project_states = await self.state_manager.get_project_states_in_between(