mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(copilot): add mode toggle and baseline transcript support
- Add transcript support to baseline autopilot (download/upload/build) for feature parity with the SDK path, enabling seamless mode switching
- Thread the `mode` field through the full stack: StreamChatRequest → queue → executor → service selection (fast=baseline, extended_thinking=SDK)
- Add a mode toggle button to the ChatInput UI with brain/lightning icons
- Persist the mode preference in localStorage via the Zustand store
This commit is contained in:
@@ -110,6 +110,11 @@ class StreamChatRequest(BaseModel):
|
||||
file_ids: list[str] | None = Field(
|
||||
default=None, max_length=20
|
||||
) # Workspace file IDs attached to this message
|
||||
mode: str | None = Field(
|
||||
default=None,
|
||||
description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
|
||||
"If None, uses the server default (extended_thinking).",
|
||||
)
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
@@ -811,6 +816,7 @@ async def stream_chat_post(
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
file_ids=sanitized_file_ids,
|
||||
mode=request.mode,
|
||||
)
|
||||
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
|
||||
@@ -41,6 +41,12 @@ from backend.copilot.response_model import (
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from backend.copilot.sdk.transcript import (
|
||||
download_transcript,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
from backend.copilot.sdk.transcript_builder import TranscriptBuilder
|
||||
from backend.copilot.service import (
|
||||
_build_system_prompt,
|
||||
_generate_session_title,
|
||||
@@ -51,11 +57,7 @@ from backend.copilot.token_tracking import persist_and_record_usage
|
||||
from backend.copilot.tools import execute_tool, get_available_tools
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.prompt import (
|
||||
compress_context,
|
||||
estimate_token_count,
|
||||
estimate_token_count_str,
|
||||
)
|
||||
from backend.util.prompt import compress_context
|
||||
from backend.util.tool_call_loop import (
|
||||
LLMLoopResponse,
|
||||
LLMToolCall,
|
||||
@@ -196,6 +198,7 @@ async def _baseline_tool_executor(
|
||||
state: _BaselineStreamState,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
transcript_builder: TranscriptBuilder,
|
||||
) -> ToolCallResult:
|
||||
"""Execute a tool via the copilot tool registry.
|
||||
|
||||
@@ -218,6 +221,10 @@ async def _baseline_tool_executor(
|
||||
success=False,
|
||||
)
|
||||
)
|
||||
transcript_builder.append_tool_result(
|
||||
tool_use_id=tool_call_id,
|
||||
content=parse_error,
|
||||
)
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
@@ -248,6 +255,10 @@ async def _baseline_tool_executor(
|
||||
tool_output = (
|
||||
result.output if isinstance(result.output, str) else str(result.output)
|
||||
)
|
||||
transcript_builder.append_tool_result(
|
||||
tool_use_id=tool_call_id,
|
||||
content=tool_output,
|
||||
)
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
@@ -269,6 +280,10 @@ async def _baseline_tool_executor(
|
||||
success=False,
|
||||
)
|
||||
)
|
||||
transcript_builder.append_tool_result(
|
||||
tool_use_id=tool_call_id,
|
||||
content=error_output,
|
||||
)
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
@@ -281,6 +296,9 @@ def _baseline_conversation_updater(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
*,
|
||||
transcript_builder: TranscriptBuilder,
|
||||
model: str = "",
|
||||
) -> None:
|
||||
"""Update OpenAI message list with assistant response + tool results.
|
||||
|
||||
@@ -300,6 +318,27 @@ def _baseline_conversation_updater(
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages.append(assistant_msg)
|
||||
# Record assistant message (with tool_calls) to transcript
|
||||
content_blocks: list[dict[str, Any]] = []
|
||||
if response.response_text:
|
||||
content_blocks.append({"type": "text", "text": response.response_text})
|
||||
for tc in response.tool_calls:
|
||||
try:
|
||||
args = orjson.loads(tc.arguments) if tc.arguments else {}
|
||||
except Exception:
|
||||
args = {}
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tc.id,
|
||||
"name": tc.name,
|
||||
"input": args,
|
||||
}
|
||||
)
|
||||
if content_blocks:
|
||||
transcript_builder.append_assistant(
|
||||
content_blocks=content_blocks, model=model
|
||||
)
|
||||
for tr in tool_results:
|
||||
messages.append(
|
||||
{
|
||||
@@ -311,6 +350,11 @@ def _baseline_conversation_updater(
|
||||
else:
|
||||
if response.response_text:
|
||||
messages.append({"role": "assistant", "content": response.response_text})
|
||||
# Record final text to transcript
|
||||
transcript_builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": response.response_text}],
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
@@ -415,6 +459,34 @@ async def stream_chat_completion_baseline(
|
||||
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# --- Transcript support (feature parity with SDK path) ---
|
||||
transcript_builder = TranscriptBuilder()
|
||||
transcript_covers_prefix = True
|
||||
|
||||
if user_id and len(session.messages) > 1:
|
||||
try:
|
||||
dl = await download_transcript(user_id, session_id, log_prefix="[Baseline]")
|
||||
if dl and validate_transcript(dl.content):
|
||||
transcript_builder.load_previous(dl.content, log_prefix="[Baseline]")
|
||||
logger.info(
|
||||
"[Baseline] Loaded transcript: %dB, msg_count=%d",
|
||||
len(dl.content),
|
||||
dl.message_count,
|
||||
)
|
||||
elif dl:
|
||||
logger.warning("[Baseline] Downloaded transcript but invalid")
|
||||
transcript_covers_prefix = False
|
||||
else:
|
||||
logger.debug("[Baseline] No transcript available")
|
||||
transcript_covers_prefix = False
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Transcript download failed: %s", e)
|
||||
transcript_covers_prefix = False
|
||||
|
||||
# Append user message to transcript
|
||||
if message:
|
||||
transcript_builder.append_user(content=message)
|
||||
|
||||
# Generate title for new sessions
|
||||
if is_user_message and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
@@ -480,7 +552,17 @@ async def stream_chat_completion_baseline(
|
||||
# using functools.partial so they satisfy the Protocol signatures.
|
||||
_bound_llm_caller = partial(_baseline_llm_caller, state=state)
|
||||
_bound_tool_executor = partial(
|
||||
_baseline_tool_executor, state=state, user_id=user_id, session=session
|
||||
_baseline_tool_executor,
|
||||
state=state,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
transcript_builder=transcript_builder,
|
||||
)
|
||||
|
||||
_bound_conversation_updater = partial(
|
||||
_baseline_conversation_updater,
|
||||
transcript_builder=transcript_builder,
|
||||
model=config.model,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -490,7 +572,7 @@ async def stream_chat_completion_baseline(
|
||||
tools=tools,
|
||||
llm_call=_bound_llm_caller,
|
||||
execute_tool=_bound_tool_executor,
|
||||
update_conversation=_baseline_conversation_updater,
|
||||
update_conversation=_bound_conversation_updater,
|
||||
max_iterations=_MAX_TOOL_ROUNDS,
|
||||
):
|
||||
# Drain buffered events after each iteration (real-time streaming)
|
||||
@@ -558,6 +640,11 @@ async def stream_chat_completion_baseline(
|
||||
and state.turn_completion_tokens == 0
|
||||
and not (_stream_error and not state.assistant_text)
|
||||
):
|
||||
from backend.util.prompt import (
|
||||
estimate_token_count,
|
||||
estimate_token_count_str,
|
||||
)
|
||||
|
||||
state.turn_prompt_tokens = max(
|
||||
estimate_token_count(openai_messages, model=config.model), 1
|
||||
)
|
||||
@@ -593,6 +680,23 @@ async def stream_chat_completion_baseline(
|
||||
except Exception as persist_err:
|
||||
logger.error("[Baseline] Failed to persist session: %s", persist_err)
|
||||
|
||||
# --- Upload transcript for next-turn continuity ---
|
||||
if user_id and transcript_covers_prefix:
|
||||
try:
|
||||
_transcript_content = transcript_builder.to_jsonl()
|
||||
if _transcript_content and validate_transcript(_transcript_content):
|
||||
await upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
content=_transcript_content,
|
||||
message_count=len(session.messages),
|
||||
log_prefix="[Baseline]",
|
||||
)
|
||||
else:
|
||||
logger.debug("[Baseline] No valid transcript to upload")
|
||||
except Exception as upload_err:
|
||||
logger.error("[Baseline] Transcript upload failed: %s", upload_err)
|
||||
|
||||
# Yield usage and finish AFTER try/finally (not inside finally).
|
||||
# PEP 525 prohibits yielding from finally in async generators during
|
||||
# aclose() — doing so raises RuntimeError on client disconnect.
|
||||
|
||||
@@ -251,20 +251,31 @@ class CoPilotProcessor:
|
||||
stream_fn = stream_chat_completion_dummy
|
||||
log.warning("Using DUMMY service (CHAT_TEST_MODE=true)")
|
||||
else:
|
||||
use_sdk = (
|
||||
config.use_claude_code_subscription
|
||||
or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
# Per-request mode override from the frontend takes priority.
|
||||
# 'fast' → baseline (OpenAI-compatible), 'extended_thinking' → SDK.
|
||||
if entry.mode == "fast":
|
||||
use_sdk = False
|
||||
elif entry.mode == "extended_thinking":
|
||||
use_sdk = True
|
||||
else:
|
||||
# No mode specified — fall back to feature flag / config.
|
||||
use_sdk = (
|
||||
config.use_claude_code_subscription
|
||||
or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
)
|
||||
)
|
||||
)
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else stream_chat_completion_baseline
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
|
||||
log.info(
|
||||
f"Using {'SDK' if use_sdk else 'baseline'} service "
|
||||
f"(mode={entry.mode or 'default'})"
|
||||
)
|
||||
|
||||
# Stream chat completion and publish chunks to Redis.
|
||||
# stream_and_publish wraps the raw stream with registry
|
||||
|
||||
@@ -156,6 +156,9 @@ class CoPilotExecutionEntry(BaseModel):
|
||||
file_ids: list[str] | None = None
|
||||
"""Workspace file IDs attached to the user's message"""
|
||||
|
||||
mode: str | None = None
|
||||
"""Autopilot mode override: 'fast' or 'extended_thinking'. None = server default."""
|
||||
|
||||
|
||||
class CancelCoPilotEvent(BaseModel):
|
||||
"""Event to cancel a CoPilot operation."""
|
||||
@@ -175,6 +178,7 @@ async def enqueue_copilot_turn(
|
||||
is_user_message: bool = True,
|
||||
context: dict[str, str] | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
mode: str | None = None,
|
||||
) -> None:
|
||||
"""Enqueue a CoPilot task for processing by the executor service.
|
||||
|
||||
@@ -197,6 +201,7 @@ async def enqueue_copilot_turn(
|
||||
is_user_message=is_user_message,
|
||||
context=context,
|
||||
file_ids=file_ids,
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
queue_client = await get_async_copilot_queue()
|
||||
|
||||
@@ -12,6 +12,7 @@ import { AttachmentMenu } from "./components/AttachmentMenu";
|
||||
import { FileChips } from "./components/FileChips";
|
||||
import { RecordingButton } from "./components/RecordingButton";
|
||||
import { RecordingIndicator } from "./components/RecordingIndicator";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
import { useChatInput } from "./useChatInput";
|
||||
import { useVoiceRecording } from "./useVoiceRecording";
|
||||
|
||||
@@ -42,6 +43,7 @@ export function ChatInput({
|
||||
droppedFiles,
|
||||
onDroppedFilesConsumed,
|
||||
}: Props) {
|
||||
const { copilotMode, setCopilotMode } = useCopilotUIStore();
|
||||
const [files, setFiles] = useState<File[]>([]);
|
||||
|
||||
// Merge files dropped onto the chat window into internal state.
|
||||
@@ -157,6 +159,71 @@ export function ChatInput({
|
||||
onFilesSelected={handleFilesSelected}
|
||||
disabled={isBusy}
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
setCopilotMode(
|
||||
copilotMode === "extended_thinking"
|
||||
? "fast"
|
||||
: "extended_thinking",
|
||||
)
|
||||
}
|
||||
className={cn(
|
||||
"inline-flex items-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
|
||||
copilotMode === "extended_thinking"
|
||||
? "bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-300"
|
||||
: "bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-300",
|
||||
)}
|
||||
title={
|
||||
copilotMode === "extended_thinking"
|
||||
? "Extended Thinking mode — deeper reasoning (click to switch to Fast mode)"
|
||||
: "Fast mode — quicker responses (click to switch to Extended Thinking)"
|
||||
}
|
||||
>
|
||||
{copilotMode === "extended_thinking" ? (
|
||||
<>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="14"
|
||||
height="14"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
>
|
||||
<path d="M12 5a3 3 0 1 0-5.997.125 4 4 0 0 0-2.526 5.77 4 4 0 0 0 .556 6.588A4 4 0 1 0 12 18Z" />
|
||||
<path d="M12 5a3 3 0 1 1 5.997.125 4 4 0 0 1 2.526 5.77 4 4 0 0 1-.556 6.588A4 4 0 1 1 12 18Z" />
|
||||
<path d="M15 13a4.5 4.5 0 0 1-3-4 4.5 4.5 0 0 1-3 4" />
|
||||
<path d="M17.599 6.5a3 3 0 0 0 .399-1.375" />
|
||||
<path d="M6.003 5.125A3 3 0 0 0 6.401 6.5" />
|
||||
<path d="M3.477 10.896a4 4 0 0 1 .585-.396" />
|
||||
<path d="M19.938 10.5a4 4 0 0 1 .585.396" />
|
||||
<path d="M6 18a4 4 0 0 1-1.967-.516" />
|
||||
<path d="M19.967 17.484A4 4 0 0 1 18 18" />
|
||||
</svg>
|
||||
Thinking
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="14"
|
||||
height="14"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
>
|
||||
<path d="M4 14a1 1 0 0 1-.78-1.63l9.9-10.2a.5.5 0 0 1 .86.46l-1.92 6.02A1 1 0 0 0 13 10h7a1 1 0 0 1 .78 1.63l-9.9 10.2a.5.5 0 0 1-.86-.46l1.92-6.02A1 1 0 0 0 11 14z" />
|
||||
</svg>
|
||||
Fast
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
</PromptInputTools>
|
||||
|
||||
<div className="flex items-center gap-4">
|
||||
|
||||
@@ -31,6 +31,10 @@ interface CopilotUIState {
|
||||
showNotificationDialog: boolean;
|
||||
setShowNotificationDialog: (show: boolean) => void;
|
||||
|
||||
/** Autopilot mode: 'extended_thinking' (default) or 'fast'. */
|
||||
copilotMode: "extended_thinking" | "fast";
|
||||
setCopilotMode: (mode: "extended_thinking" | "fast") => void;
|
||||
|
||||
clearCopilotLocalData: () => void;
|
||||
}
|
||||
|
||||
@@ -80,15 +84,25 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
|
||||
showNotificationDialog: false,
|
||||
setShowNotificationDialog: (show) => set({ showNotificationDialog: show }),
|
||||
|
||||
copilotMode:
|
||||
(storage.get(Key.COPILOT_MODE) as "extended_thinking" | "fast") ||
|
||||
"extended_thinking",
|
||||
setCopilotMode: (mode) => {
|
||||
storage.set(Key.COPILOT_MODE, mode);
|
||||
set({ copilotMode: mode });
|
||||
},
|
||||
|
||||
clearCopilotLocalData: () => {
|
||||
storage.clean(Key.COPILOT_NOTIFICATIONS_ENABLED);
|
||||
storage.clean(Key.COPILOT_SOUND_ENABLED);
|
||||
storage.clean(Key.COPILOT_NOTIFICATION_BANNER_DISMISSED);
|
||||
storage.clean(Key.COPILOT_NOTIFICATION_DIALOG_DISMISSED);
|
||||
storage.clean(Key.COPILOT_MODE);
|
||||
set({
|
||||
completedSessionIDs: new Set<string>(),
|
||||
isNotificationsEnabled: false,
|
||||
isSoundEnabled: true,
|
||||
copilotMode: "extended_thinking",
|
||||
});
|
||||
document.title = "AutoGPT";
|
||||
},
|
||||
|
||||
@@ -32,8 +32,13 @@ export function useCopilotPage() {
|
||||
const [pendingMessage, setPendingMessage] = useState<string | null>(null);
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
const { sessionToDelete, setSessionToDelete, isDrawerOpen, setDrawerOpen } =
|
||||
useCopilotUIStore();
|
||||
const {
|
||||
sessionToDelete,
|
||||
setSessionToDelete,
|
||||
isDrawerOpen,
|
||||
setDrawerOpen,
|
||||
copilotMode,
|
||||
} = useCopilotUIStore();
|
||||
|
||||
const {
|
||||
sessionId,
|
||||
@@ -64,6 +69,7 @@ export function useCopilotPage() {
|
||||
hydratedMessages,
|
||||
hasActiveStream,
|
||||
refetchSession,
|
||||
copilotMode,
|
||||
});
|
||||
|
||||
useCopilotNotifications(sessionId);
|
||||
|
||||
@@ -38,6 +38,8 @@ interface UseCopilotStreamArgs {
|
||||
hydratedMessages: UIMessage[] | undefined;
|
||||
hasActiveStream: boolean;
|
||||
refetchSession: () => Promise<{ data?: unknown }>;
|
||||
/** Autopilot mode to use for requests. */
|
||||
copilotMode: "extended_thinking" | "fast";
|
||||
}
|
||||
|
||||
export function useCopilotStream({
|
||||
@@ -45,6 +47,7 @@ export function useCopilotStream({
|
||||
hydratedMessages,
|
||||
hasActiveStream,
|
||||
refetchSession,
|
||||
copilotMode,
|
||||
}: UseCopilotStreamArgs) {
|
||||
const queryClient = useQueryClient();
|
||||
const [rateLimitMessage, setRateLimitMessage] = useState<string | null>(null);
|
||||
@@ -79,6 +82,7 @@ export function useCopilotStream({
|
||||
is_user_message: last.role === "user",
|
||||
context: null,
|
||||
file_ids: fileIds && fileIds.length > 0 ? fileIds : null,
|
||||
mode: copilotMode,
|
||||
},
|
||||
headers: await getAuthHeaders(),
|
||||
};
|
||||
@@ -89,7 +93,7 @@ export function useCopilotStream({
|
||||
}),
|
||||
})
|
||||
: null,
|
||||
[sessionId],
|
||||
[sessionId, copilotMode],
|
||||
);
|
||||
|
||||
// Reconnect state — use refs for values read inside callbacks to avoid
|
||||
|
||||
@@ -12930,6 +12930,11 @@
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "File Ids"
|
||||
},
|
||||
"mode": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Mode",
|
||||
"description": "Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. If None, uses the server default (extended_thinking)."
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
|
||||
@@ -15,6 +15,7 @@ export enum Key {
|
||||
COPILOT_NOTIFICATIONS_ENABLED = "copilot-notifications-enabled",
|
||||
COPILOT_NOTIFICATION_BANNER_DISMISSED = "copilot-notification-banner-dismissed",
|
||||
COPILOT_NOTIFICATION_DIALOG_DISMISSED = "copilot-notification-dialog-dismissed",
|
||||
COPILOT_MODE = "copilot-mode",
|
||||
}
|
||||
|
||||
function get(key: Key) {
|
||||
|
||||
Reference in New Issue
Block a user