updated code generation and intial chat session logic

This commit is contained in:
Swifty
2025-12-16 22:45:17 +01:00
parent 9e1354bfee
commit 858a8a818b
11 changed files with 317 additions and 48 deletions

View File

@@ -14,6 +14,10 @@ class ChatConfig(BaseSettings):
model: str = Field(
default="anthropic/claude-opus-4.5", description="Default model to use"
)
title_model: str = Field(
default="openai/gpt-4o-mini",
description="Model to use for generating session titles (should be fast/cheap)",
)
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
default="https://openrouter.ai/api/v1",

View File

@@ -178,6 +178,11 @@ async def get_user_chat_sessions(
)
async def get_user_session_count(user_id: str) -> int:
"""Get the total number of chat sessions for a user."""
return await PrismaChatSession.prisma().count(where={"userId": user_id})
async def delete_chat_session(session_id: str) -> bool:
"""Delete a chat session and all its messages."""
try:

View File

@@ -49,6 +49,7 @@ class Usage(BaseModel):
class ChatSession(BaseModel):
session_id: str
user_id: str | None
title: str | None = None
messages: list[ChatMessage]
usage: list[Usage]
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
@@ -62,6 +63,7 @@ class ChatSession(BaseModel):
return ChatSession(
session_id=str(uuid.uuid4()),
user_id=user_id,
title=None,
messages=[],
usage=[],
credentials={},
@@ -138,6 +140,7 @@ class ChatSession(BaseModel):
return ChatSession(
session_id=prisma_session.id,
user_id=prisma_session.userId,
title=prisma_session.title,
messages=messages,
usage=usage,
credentials=credentials,
@@ -246,7 +249,13 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
return None
try:
return ChatSession.model_validate_json(raw_session)
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"Loading session {session_id} from cache: "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
except Exception as e:
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
raise RedisError(f"Corrupted session data for {session_id}") from e
@@ -265,7 +274,15 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
if not prisma_session:
return None
return ChatSession.from_prisma(prisma_session, prisma_session.Messages)
messages = prisma_session.Messages
logger.info(
f"Loading session {session_id} from DB: "
f"has_messages={messages is not None}, "
f"message_count={len(messages) if messages else 0}, "
f"roles={[m.role for m in messages] if messages else []}"
)
return ChatSession.from_prisma(prisma_session, messages)
async def _save_session_to_db(
@@ -313,6 +330,11 @@ async def _save_session_to_db(
"function_call": msg.function_call,
}
)
logger.info(
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
await chat_db.add_chat_messages_batch(
session_id=session.session_id,
messages=messages_data,

View File

@@ -68,3 +68,50 @@ async def test_chatsession_redis_storage_user_id_mismatch():
s2 = await get_chat_session(s.session_id, None)
assert s2 is None
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_db_storage():
"""Test that messages are correctly saved to and loaded from DB (not cache)."""
from backend.data.redis_client import get_redis_async
# Create session with messages including assistant message
s = ChatSession.new(user_id=None)
s.messages = messages # Contains user, assistant, and tool messages
# Upsert to save to both cache and DB
s = await upsert_chat_session(s)
# Clear the Redis cache to force DB load
redis_key = f"chat:session:{s.session_id}"
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
# Load from DB (cache was cleared)
s2 = await get_chat_session(
session_id=s.session_id,
user_id=s.user_id,
)
assert s2 is not None, "Session not found after loading from DB"
assert len(s2.messages) == len(
s.messages
), f"Message count mismatch: expected {len(s.messages)}, got {len(s2.messages)}"
# Verify all roles are present
roles = [m.role for m in s2.messages]
assert "user" in roles, f"User message missing. Roles found: {roles}"
assert "assistant" in roles, f"Assistant message missing. Roles found: {roles}"
assert "tool" in roles, f"Tool message missing. Roles found: {roles}"
# Verify message content
for orig, loaded in zip(s.messages, s2.messages):
assert orig.role == loaded.role, f"Role mismatch: {orig.role} != {loaded.role}"
assert (
orig.content == loaded.content
), f"Content mismatch for {orig.role}: {orig.content} != {loaded.content}"
if orig.tool_calls:
assert (
loaded.tool_calls is not None
), f"Tool calls missing for {orig.role} message"
assert len(orig.tool_calls) == len(loaded.tool_calls)

View File

@@ -164,12 +164,20 @@ async def get_session(
session = await chat_service.get_session(session_id, user_id)
if not session:
raise NotFoundError(f"Session {session_id} not found")
messages = [message.model_dump() for message in session.messages]
logger.info(
f"Returning session {session_id}: "
f"message_count={len(messages)}, "
f"roles={[m.get('role') for m in messages]}"
)
return SessionDetailResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
user_id=session.user_id or None,
messages=[message.model_dump() for message in session.messages],
messages=messages,
)
@@ -367,12 +375,20 @@ async def get_onboarding_session(
session = await chat_service.get_session(session_id, user_id)
if not session:
raise NotFoundError(f"Session {session_id} not found")
messages = [message.model_dump() for message in session.messages]
logger.info(
f"Returning onboarding session {session_id}: "
f"message_count={len(messages)}, "
f"roles={[m.get('role') for m in messages]}"
)
return SessionDetailResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
user_id=session.user_id or None,
messages=[message.model_dump() for message in session.messages],
messages=messages,
)

View File

@@ -8,6 +8,7 @@ from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
import backend.server.v2.chat.config
import backend.server.v2.chat.db as chat_db
from backend.data.understanding import (
format_understanding_for_prompt,
get_business_understanding,
@@ -37,6 +38,19 @@ config = backend.server.v2.chat.config.ChatConfig()
client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
async def _is_first_session(user_id: str) -> bool:
"""Check if this is the user's first chat session.
Returns True if the user has 1 or fewer sessions (meaning this is their first).
"""
try:
session_count = await chat_db.get_user_session_count(user_id)
return session_count <= 1
except Exception as e:
logger.warning(f"Failed to check session count for user {user_id}: {e}")
return False # Default to non-onboarding if we can't check
async def _build_system_prompt(
user_id: str | None, prompt_type: str = "default"
) -> str:
@@ -45,12 +59,20 @@ async def _build_system_prompt(
Args:
user_id: The user ID for fetching business understanding
prompt_type: The type of prompt to load ("default" or "onboarding")
If "default" and this is the user's first session, will use "onboarding" instead.
Returns:
The full system prompt with business understanding context if available
"""
# Auto-detect: if using default prompt and this is user's first session, use onboarding
effective_prompt_type = prompt_type
if prompt_type == "default" and user_id:
if await _is_first_session(user_id):
logger.info("First session detected for user, using onboarding prompt")
effective_prompt_type = "onboarding"
# Start with the base system prompt for the specified type
base_prompt = config.get_system_prompt_for_type(prompt_type)
base_prompt = config.get_system_prompt_for_type(effective_prompt_type)
# If user is authenticated, try to fetch their business understanding
if user_id:
@@ -72,6 +94,46 @@ async def _build_system_prompt(
return base_prompt
async def _generate_session_title(message: str) -> str | None:
"""Generate a concise title for a chat session based on the first message.
Args:
message: The first user message in the session
Returns:
A short title (3-6 words) or None if generation fails
"""
try:
response = await client.chat.completions.create(
model=config.title_model,
messages=[
{
"role": "system",
"content": (
"Generate a very short title (3-6 words) for a chat conversation "
"based on the user's first message. The title should capture the "
"main topic or intent. Return ONLY the title, no quotes or punctuation."
),
},
{"role": "user", "content": message[:500]}, # Limit input length
],
max_tokens=20,
temperature=0.7,
)
title = response.choices[0].message.content
if title:
# Clean up the title
title = title.strip().strip("\"'")
# Limit length
if len(title) > 50:
title = title[:47] + "..."
return title
return None
except Exception as e:
logger.warning(f"Failed to generate session title: {e}")
return None
async def create_chat_session(
user_id: str | None = None,
) -> ChatSession:
@@ -202,6 +264,29 @@ async def stream_chat_completion(
session = await upsert_chat_session(session)
assert session, "Session not found"
# Generate title for new sessions on first user message (non-blocking)
# Check: is_user_message, no title yet, and this is the first user message
if is_user_message and message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
# First user message - generate title in background
import asyncio
async def _update_title():
try:
title = await _generate_session_title(message)
if title:
session.title = title
await upsert_chat_session(session)
logger.info(
f"Generated title for session {session_id}: {title}"
)
except Exception as e:
logger.warning(f"Failed to update session title: {e}")
# Fire and forget - don't block the chat response
asyncio.create_task(_update_title())
# Build system prompt with business understanding
system_prompt = await _build_system_prompt(user_id, prompt_type)
@@ -581,8 +666,12 @@ async def _yield_tool_call(
"""
logger.info(f"Yielding tool call: {tool_calls[yield_idx]}")
# Parse tool call arguments - exceptions will propagate to caller
arguments = orjson.loads(tool_calls[yield_idx]["function"]["arguments"])
# Parse tool call arguments - handle empty arguments gracefully
raw_arguments = tool_calls[yield_idx]["function"]["arguments"]
if raw_arguments:
arguments = orjson.loads(raw_arguments)
else:
arguments = {}
yield StreamToolCall(
tool_id=tool_calls[yield_idx]["id"],

View File

@@ -154,28 +154,67 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
)
def _reassign_node_ids(graph: Graph) -> None:
"""Reassign all node and link IDs to new UUIDs.
This is needed when creating a new version to avoid unique constraint violations.
"""
# Create mapping from old node IDs to new UUIDs
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
# Reassign node IDs
for node in graph.nodes:
node.id = id_map[node.id]
# Update link references to use new node IDs
for link in graph.links:
link.id = str(uuid.uuid4()) # Also give links new IDs
if link.source_id in id_map:
link.source_id = id_map[link.source_id]
if link.sink_id in id_map:
link.sink_id = id_map[link.sink_id]
async def save_agent_to_library(
agent_json: dict[str, Any], user_id: str
agent_json: dict[str, Any], user_id: str, is_update: bool = False
) -> tuple[Graph, Any]:
"""Save agent to database and user's library.
Args:
agent_json: Agent JSON dict
user_id: User ID
is_update: Whether this is an update to an existing agent
Returns:
Tuple of (created Graph, LibraryAgent)
"""
from backend.data.graph import get_graph_all_versions
graph = json_to_graph(agent_json)
# Ensure graph has a unique ID
if not graph.id or graph.id == "":
if is_update:
# For updates, keep the same graph ID but increment version
# and reassign node/link IDs to avoid conflicts
if graph.id:
existing_versions = await get_graph_all_versions(graph.id, user_id)
if existing_versions:
latest_version = max(v.version for v in existing_versions)
graph.version = latest_version + 1
# Reassign node IDs (but keep graph ID the same)
_reassign_node_ids(graph)
logger.info(f"Updating agent {graph.id} to version {graph.version}")
else:
# For new agents, always generate a fresh UUID to avoid collisions
graph.id = str(uuid.uuid4())
graph.version = 1
# Reassign all node IDs as well
_reassign_node_ids(graph)
logger.info(f"Creating new agent with ID {graph.id}")
# Save to database
created_graph = await create_graph(graph, user_id)
# Add to user's library
# Add to user's library (or update existing library agent)
library_agents = await library_db.create_library_agent(
graph=created_graph,
user_id=user_id,

View File

@@ -47,14 +47,49 @@ def is_valid_uuid(value: str) -> bool:
return isinstance(value, str) and UUID_REGEX.match(value) is not None
def _compact_schema(schema: dict) -> dict[str, str]:
"""Extract compact type info from a JSON schema properties dict.
Returns a dict of {field_name: type_string} for essential info only.
"""
props = schema.get("properties", {})
result = {}
for name, prop in props.items():
# Skip internal/complex fields
if name.startswith("_"):
continue
# Get type string
type_str = prop.get("type", "any")
# Handle anyOf/oneOf (optional types)
if "anyOf" in prop:
types = [t.get("type", "?") for t in prop["anyOf"] if t.get("type")]
type_str = "|".join(types) if types else "any"
elif "allOf" in prop:
type_str = "object"
# Add array item type if present
if type_str == "array" and "items" in prop:
items = prop["items"]
if isinstance(items, dict):
item_type = items.get("type", "any")
type_str = f"array[{item_type}]"
result[name] = type_str
return result
def get_block_summaries(include_schemas: bool = True) -> str:
"""Generate block summaries for prompts.
"""Generate compact block summaries for prompts.
Args:
include_schemas: Whether to include full input/output schemas
include_schemas: Whether to include input/output type info
Returns:
Formatted string of block summaries
Formatted string of block summaries (compact format)
"""
blocks = get_blocks()
summaries = []
@@ -64,51 +99,49 @@ def get_block_summaries(include_schemas: bool = True) -> str:
name = block.name
desc = getattr(block, "description", "") or ""
# Truncate description
if len(desc) > 150:
desc = desc[:147] + "..."
if not include_schemas:
# Simple format
if len(desc) > 200:
desc = desc[:197] + "..."
summaries.append(f"- {name} (id: {block_id}): {desc}")
else:
# Full format with schemas
input_schema = {}
output_schema = {}
# Compact format with type info only
inputs = {}
outputs = {}
required = []
if hasattr(block, "input_schema"):
try:
full_schema = block.input_schema.jsonschema()
input_schema = {
"properties": full_schema.get("properties", {}),
"required": full_schema.get("required", []),
}
schema = block.input_schema.jsonschema()
inputs = _compact_schema(schema)
required = schema.get("required", [])
except Exception:
pass
if hasattr(block, "output_schema"):
try:
full_schema = block.output_schema.jsonschema()
output_schema = {
"properties": full_schema.get("properties", {}),
}
schema = block.output_schema.jsonschema()
outputs = _compact_schema(schema)
except Exception:
pass
block_info = {
"name": name,
"id": block_id,
"description": desc[:500] if len(desc) > 500 else desc,
"inputSchema": input_schema,
"outputSchema": output_schema,
}
# Build compact line format
# Format: NAME (id): desc | in: {field:type, ...} [required] | out: {field:type}
in_str = ", ".join(f"{k}:{v}" for k, v in inputs.items())
out_str = ", ".join(f"{k}:{v}" for k, v in outputs.items())
req_str = f" req=[{','.join(required)}]" if required else ""
# Check for static output
if getattr(block, "static_output", False):
block_info["staticOutput"] = True
static = " [static]" if getattr(block, "static_output", False) else ""
summaries.append(json.dumps(block_info, indent=2))
line = f"- {name} (id: {block_id}): {desc}"
if in_str:
line += f"\n in: {{{in_str}}}{req_str}"
if out_str:
line += f"\n out: {{{out_str}}}{static}"
summaries.append(line)
if include_schemas:
return "\n\n".join(summaries)
return "\n".join(summaries)

View File

@@ -269,7 +269,7 @@ class EditAgentTool(BaseTool):
try:
created_graph, library_agent = await save_agent_to_library(
updated_agent, user_id
updated_agent, user_id, is_update=True
)
return AgentSavedResponse(

View File

@@ -172,6 +172,7 @@ class RunBlockTool(BaseTool):
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
# Check credentials
creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = await self._check_block_credentials(
user_id, block
)
@@ -205,16 +206,31 @@ class RunBlockTool(BaseTool):
)
try:
# Inject matched credentials into input_data
# Fetch actual credentials and prepare kwargs for block execution
exec_kwargs: dict[str, Any] = {"user_id": user_id}
for field_name, cred_meta in matched_credentials.items():
# Inject metadata into input_data (for validation)
if field_name not in input_data:
input_data[field_name] = cred_meta.model_dump()
# Fetch actual credentials and pass as kwargs (for execution)
actual_credentials = await creds_manager.get(
user_id, cred_meta.id, lock=False
)
if actual_credentials:
exec_kwargs[field_name] = actual_credentials
else:
return ErrorResponse(
message=f"Failed to retrieve credentials for {field_name}",
session_id=session_id,
)
# Execute the block and collect outputs
outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in block.execute(
input_data,
user_id=user_id,
**exec_kwargs,
):
outputs[output_name].append(output_data)

View File

@@ -60,9 +60,7 @@ async def main(batch_size: int = 100) -> int:
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Backfill embeddings for store agents"
)
parser = argparse.ArgumentParser(description="Backfill embeddings for store agents")
parser.add_argument(
"--batch-size",
type=int,