Compare commits

..

1 Commits

Author SHA1 Message Date
Otto
7f7a7067ec refactor(copilot): use Pydantic models and match/case in customize_agent
Addresses review feedback from ntindle:

1. Use typed parameters instead of kwargs.get():
   - Added CustomizeAgentInput Pydantic model with field_validator for stripping strings
   - Tool now uses params = CustomizeAgentInput(**kwargs) pattern

2. Use match/case for cleaner pattern matching:
   - Extracted response handling to _handle_customization_result method
   - Uses match result_type: case 'error' | 'clarifying_questions' | _

3. Improved code organization:
   - Split monolithic _execute into smaller focused methods
   - _handle_customization_result for response type handling
   - _save_or_preview_agent for final save/preview logic
2026-02-04 08:53:02 +00:00
8 changed files with 157 additions and 139 deletions

View File

@@ -3,6 +3,8 @@
import logging
from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError
@@ -27,6 +29,23 @@ from .models import (
logger = logging.getLogger(__name__)
class CustomizeAgentInput(BaseModel):
"""Input parameters for the customize_agent tool."""
agent_id: str = ""
modifications: str = ""
context: str = ""
save: bool = True
@field_validator("agent_id", "modifications", "context", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str:
"""Strip whitespace from string fields."""
if isinstance(v, str):
return v.strip()
return v if v is not None else ""
class CustomizeAgentTool(BaseTool):
"""Tool for customizing marketplace/template agents using natural language."""
@@ -92,7 +111,7 @@ class CustomizeAgentTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs,
**kwargs: Any,
) -> ToolResponseBase:
"""Execute the customize_agent tool.
@@ -102,20 +121,17 @@ class CustomizeAgentTool(BaseTool):
3. Call customize_template with the modification request
4. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
modifications = kwargs.get("modifications", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
params = CustomizeAgentInput(**kwargs)
session_id = session.session_id if session else None
if not agent_id:
if not params.agent_id:
return ErrorResponse(
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
error="missing_agent_id",
session_id=session_id,
)
if not modifications:
if not params.modifications:
return ErrorResponse(
message="Please describe how you want to customize this agent.",
error="missing_modifications",
@@ -123,11 +139,11 @@ class CustomizeAgentTool(BaseTool):
)
# Parse agent_id in format "creator/slug"
parts = [p.strip() for p in agent_id.split("/")]
parts = params.agent_id.split("/")
if len(parts) != 2 or not parts[0] or not parts[1]:
return ErrorResponse(
message=(
f"Invalid agent ID format: '{agent_id}'. "
f"Invalid agent ID format: '{params.agent_id}'. "
"Expected format is 'creator/agent-name' "
"(e.g., 'autogpt/newsletter-writer')."
),
@@ -145,14 +161,14 @@ class CustomizeAgentTool(BaseTool):
except AgentNotFoundError:
return ErrorResponse(
message=(
f"Could not find marketplace agent '{agent_id}'. "
f"Could not find marketplace agent '{params.agent_id}'. "
"Please check the agent ID and try again."
),
error="agent_not_found",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
logger.error(f"Error fetching marketplace agent {params.agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the marketplace agent. Please try again.",
error="fetch_error",
@@ -162,7 +178,7 @@ class CustomizeAgentTool(BaseTool):
if not agent_details.store_listing_version_id:
return ErrorResponse(
message=(
f"The agent '{agent_id}' does not have an available version. "
f"The agent '{params.agent_id}' does not have an available version. "
"Please try a different agent."
),
error="no_version_available",
@@ -174,7 +190,7 @@ class CustomizeAgentTool(BaseTool):
graph = await store_db.get_agent(agent_details.store_listing_version_id)
template_agent = graph_to_json(graph)
except Exception as e:
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
logger.error(f"Error fetching agent graph for {params.agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the agent configuration. Please try again.",
error="graph_fetch_error",
@@ -185,8 +201,8 @@ class CustomizeAgentTool(BaseTool):
try:
result = await customize_template(
template_agent=template_agent,
modification_request=modifications,
context=context,
modification_request=params.modifications,
context=params.context,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
@@ -198,7 +214,7 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id,
)
except Exception as e:
logger.error(f"Error calling customize_template for {agent_id}: {e}")
logger.error(f"Error calling customize_template for {params.agent_id}: {e}")
return ErrorResponse(
message=(
"Failed to customize the agent due to a service error. "
@@ -219,55 +235,25 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id,
)
# Handle error response
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="customize the agent",
llm_parse_message=(
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
# Handle response using match/case for cleaner pattern matching
return await self._handle_customization_result(
result=result,
params=params,
agent_details=agent_details,
user_id=user_id,
session_id=session_id,
)
# Handle clarifying questions
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
questions = result.get("questions") or []
if not isinstance(questions, list):
logger.error(
f"Unexpected clarifying questions format: {type(questions)}"
)
questions = []
return ClarificationNeededResponse(
message=(
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
if isinstance(q, dict)
],
session_id=session_id,
)
# Result should be the customized agent JSON
async def _handle_customization_result(
self,
result: dict[str, Any],
params: CustomizeAgentInput,
agent_details: Any,
user_id: str | None,
session_id: str | None,
) -> ToolResponseBase:
"""Handle the result from customize_template using pattern matching."""
# Ensure result is a dict
if not isinstance(result, dict):
logger.error(f"Unexpected customize_template response type: {type(result)}")
return ErrorResponse(
@@ -276,8 +262,77 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id,
)
customized_agent = result
result_type = result.get("type")
match result_type:
case "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="customize the agent",
llm_parse_message=(
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
case "clarifying_questions":
questions_data = result.get("questions") or []
if not isinstance(questions_data, list):
logger.error(
f"Unexpected clarifying questions format: {type(questions_data)}"
)
questions_data = []
questions = [
ClarifyingQuestion(
question=q.get("question", "") if isinstance(q, dict) else "",
keyword=q.get("keyword", "") if isinstance(q, dict) else "",
example=q.get("example") if isinstance(q, dict) else None,
)
for q in questions_data
if isinstance(q, dict)
]
return ClarificationNeededResponse(
message=(
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=questions,
session_id=session_id,
)
case _:
# Default case: result is the customized agent JSON
return await self._save_or_preview_agent(
customized_agent=result,
params=params,
agent_details=agent_details,
user_id=user_id,
session_id=session_id,
)
async def _save_or_preview_agent(
self,
customized_agent: dict[str, Any],
params: CustomizeAgentInput,
agent_details: Any,
user_id: str | None,
session_id: str | None,
) -> ToolResponseBase:
"""Save or preview the customized agent based on params.save."""
agent_name = customized_agent.get(
"name", f"Customized {agent_details.agent_name}"
)
@@ -287,7 +342,7 @@ class CustomizeAgentTool(BaseTool):
node_count = len(nodes) if isinstance(nodes, list) else 0
link_count = len(links) if isinstance(links, list) else 0
if not save:
if not params.save:
return AgentPreviewResponse(
message=(
f"I've customized the agent '{agent_details.agent_name}'. "

View File

@@ -1,17 +1,6 @@
import { OAuthPopupResultMessage } from "./types";
import { NextResponse } from "next/server";
/**
* Safely encode a value as JSON for embedding in a script tag.
* Escapes characters that could break out of the script context to prevent XSS.
*/
function safeJsonStringify(value: unknown): string {
return JSON.stringify(value)
.replace(/</g, "\\u003c")
.replace(/>/g, "\\u003e")
.replace(/&/g, "\\u0026");
}
// This route is intended to be used as the callback for integration OAuth flows,
// controlled by the CredentialsInput component. The CredentialsInput opens the login
// page in a pop-up window, which then redirects to this route to close the loop.
@@ -34,13 +23,12 @@ export async function GET(request: Request) {
console.debug("Sending message to opener:", message);
// Return a response with the message as JSON and a script to close the window
// Use safeJsonStringify to prevent XSS by escaping <, >, and & characters
return new NextResponse(
`
<html>
<body>
<script>
window.opener.postMessage(${safeJsonStringify(message)});
window.opener.postMessage(${JSON.stringify(message)});
window.close();
</script>
</body>

View File

@@ -26,20 +26,8 @@ export function buildCopilotChatUrl(prompt: string): string {
export function getQuickActions(): string[] {
return [
"I don't know where to start, just ask me stuff",
"I do the same thing every week and it's killing me",
"Help me find where I'm wasting my time",
"Show me what I can automate",
"Design a custom workflow",
"Help me with content creation",
];
}
export function getInputPlaceholder(width?: number) {
if (!width) return "What's your role and what eats up most of your day?";
if (width < 500) {
return "I'm a chef and I hate...";
}
if (width <= 1080) {
return "What's your role and what eats up most of your day?";
}
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
}

View File

@@ -6,9 +6,7 @@ import { Text } from "@/components/atoms/Text/Text";
import { Chat } from "@/components/contextual/Chat/Chat";
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useEffect, useState } from "react";
import { useCopilotStore } from "./copilot-page-store";
import { getInputPlaceholder } from "./helpers";
import { useCopilotPage } from "./useCopilotPage";
export default function CopilotPage() {
@@ -16,25 +14,8 @@ export default function CopilotPage() {
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
const [inputPlaceholder, setInputPlaceholder] = useState(
getInputPlaceholder(),
);
useEffect(() => {
const handleResize = () => {
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
};
handleResize();
window.addEventListener("resize", handleResize);
return () => window.removeEventListener("resize", handleResize);
}, []);
const { greetingName, quickActions, isLoading, hasSession, initialPrompt } =
state;
const {
handleQuickAction,
startChatWithPrompt,
@@ -92,7 +73,7 @@ export default function CopilotPage() {
}
return (
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-3 py-5 md:px-6 md:py-10">
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
<div className="w-full text-center">
{isLoading ? (
<div className="mx-auto max-w-2xl">
@@ -109,25 +90,25 @@ export default function CopilotPage() {
</div>
) : (
<>
<div className="mx-auto max-w-3xl">
<div className="mx-auto max-w-2xl">
<Text
variant="h3"
className="mb-1 !text-[1.375rem] text-zinc-700"
className="mb-3 !text-[1.375rem] text-zinc-700"
>
Hey, <span className="text-violet-600">{greetingName}</span>
</Text>
<Text variant="h3" className="mb-8 !font-normal">
Tell me about your work I&apos;ll find what to automate.
What do you want to automate?
</Text>
<div className="mb-6">
<ChatInput
onSend={startChatWithPrompt}
placeholder={inputPlaceholder}
placeholder='You can search or just ask - e.g. "create a blog post outline"'
/>
</div>
</div>
<div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
<div className="flex flex-nowrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
{quickActions.map((action) => (
<Button
key={action}
@@ -135,7 +116,7 @@ export default function CopilotPage() {
variant="outline"
size="small"
onClick={() => handleQuickAction(action)}
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
className="h-auto shrink-0 border-zinc-600 !px-4 !py-2 text-[1rem] text-zinc-600"
>
{action}
</Button>

View File

@@ -2,6 +2,7 @@ import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessi
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { cn } from "@/lib/utils";
import { GlobeHemisphereEastIcon } from "@phosphor-icons/react";
import { useEffect } from "react";
@@ -55,6 +56,10 @@ export function ChatContainer({
onStreamingChange?.(isStreaming);
}, [isStreaming, onStreamingChange]);
const breakpoint = useBreakpoint();
const isMobile =
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
return (
<div
className={cn(
@@ -122,7 +127,11 @@ export function ChatContainer({
disabled={isStreaming || !sessionId}
isStreaming={isStreaming}
onStop={stopStreaming}
placeholder="What else can I help with?"
placeholder={
isMobile
? "You can search or just ask"
: 'You can search or just ask — e.g. "create a blog post outline"'
}
/>
</div>
</div>

View File

@@ -74,20 +74,19 @@ export function ChatInput({
hasMultipleLines ? "rounded-xlarge" : "rounded-full",
)}
>
{!value && !isRecording && (
<div
className="pointer-events-none absolute inset-0 top-0.5 flex items-center justify-start pl-14 text-[1rem] text-zinc-400"
aria-hidden="true"
>
{isTranscribing ? "Transcribing..." : placeholder}
</div>
)}
<textarea
id={inputId}
aria-label="Chat message input"
value={value}
onChange={handleChange}
onKeyDown={handleKeyDown}
placeholder={
isTranscribing
? "Transcribing..."
: isRecording
? ""
: placeholder
}
disabled={isInputDisabled}
rows={1}
className={cn(
@@ -123,14 +122,13 @@ export function ChatInput({
size="icon"
aria-label={isRecording ? "Stop recording" : "Start recording"}
onClick={toggleRecording}
disabled={disabled || isTranscribing || isStreaming}
disabled={disabled || isTranscribing}
className={cn(
isRecording
? "animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600"
: isTranscribing
? "border-zinc-300 bg-zinc-100 text-zinc-400"
: "border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
isStreaming && "opacity-40",
)}
>
{isTranscribing ? (

View File

@@ -38,8 +38,8 @@ export function AudioWaveform({
// Create audio context and analyser
const audioContext = new AudioContext();
const analyser = audioContext.createAnalyser();
analyser.fftSize = 256;
analyser.smoothingTimeConstant = 0.3;
analyser.fftSize = 512;
analyser.smoothingTimeConstant = 0.8;
// Connect the stream to the analyser
const source = audioContext.createMediaStreamSource(stream);
@@ -73,11 +73,10 @@ export function AudioWaveform({
maxAmplitude = Math.max(maxAmplitude, amplitude);
}
// Normalize amplitude (0-128 range) to 0-1
const normalized = maxAmplitude / 128;
// Apply sensitivity boost (multiply by 4) and use sqrt curve to amplify quiet sounds
const boosted = Math.min(1, Math.sqrt(normalized) * 4);
const height = minBarHeight + boosted * (maxBarHeight - minBarHeight);
// Map amplitude (0-128) to bar height
const normalized = (maxAmplitude / 128) * 255;
const height =
minBarHeight + (normalized / 255) * (maxBarHeight - minBarHeight);
newBars.push(height);
}

View File

@@ -224,7 +224,7 @@ export function useVoiceRecording({
[value, isTranscribing, toggleRecording, baseHandleKeyDown],
);
const showMicButton = isSupported;
const showMicButton = isSupported && !isStreaming;
const isInputDisabled = disabled || isStreaming || isTranscribing;
// Cleanup on unmount