mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
feat(copilot): standard/advanced model toggle with Opus rate-limit multiplier
Add a per-request model tier toggle to CoPilot. Users can switch between Standard (Sonnet) and Advanced (Opus) from the chat input toolbar. Opus turns consume rate-limit quota 5× faster (matching Anthropic pricing), so no separate entitlement gate is needed — usage self-limits via the budget. - Add CopilotLlmModel = Literal["standard", "advanced"] type - ModelToggleButton (sky-blue, star icon, label only when active) - localStorage persistence via Key.COPILOT_MODEL - Backend: model tier passed in request body, resolved to actual model name - Rate-limit multiplier 5.0 for Opus in record_token_usage (Redis only, does not affect PlatformCostLog or cost_usd — those use real API values) - Reduce claude_agent_max_budget_usd default from $15 to $10
This commit is contained in:
@@ -15,7 +15,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.config import ChatConfig, CopilotMode
|
||||
from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
|
||||
from backend.copilot.db import get_chat_messages_paginated
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
|
||||
from backend.copilot.model import (
|
||||
@@ -139,6 +139,11 @@ class StreamChatRequest(BaseModel):
|
||||
description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
|
||||
"If None, uses the server default (extended_thinking).",
|
||||
)
|
||||
model: CopilotLlmModel | None = Field(
|
||||
default=None,
|
||||
description="Model tier: 'standard' for the default model, 'advanced' for the highest-capability model. "
|
||||
"If None, the server applies per-user LD targeting then falls back to config.",
|
||||
)
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
@@ -891,6 +896,7 @@ async def stream_chat_post(
|
||||
context=request.context,
|
||||
file_ids=sanitized_file_ids,
|
||||
mode=request.mode,
|
||||
model=request.model,
|
||||
)
|
||||
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
|
||||
@@ -16,6 +16,13 @@ from backend.util.clients import OPENROUTER_BASE_URL
|
||||
# subscription flag → LaunchDarkly COPILOT_SDK → config.use_claude_agent_sdk.
|
||||
CopilotMode = Literal["fast", "extended_thinking"]
|
||||
|
||||
# Per-request model tier set by the frontend model toggle.
|
||||
# 'standard' uses the global config default (currently Sonnet).
|
||||
# 'advanced' forces the highest-capability model (currently Opus).
|
||||
# None means no preference — falls through to LD per-user targeting, then config.
|
||||
# Using tier names instead of model names keeps the contract model-agnostic.
|
||||
CopilotLlmModel = Literal["standard", "advanced"]
|
||||
|
||||
|
||||
class ChatConfig(BaseSettings):
|
||||
"""Configuration for the chat system."""
|
||||
@@ -163,12 +170,12 @@ class ChatConfig(BaseSettings):
|
||||
"CHAT_CLAUDE_AGENT_MAX_TURNS env var if your workflows need more.",
|
||||
)
|
||||
claude_agent_max_budget_usd: float = Field(
|
||||
default=15.0,
|
||||
default=10.0,
|
||||
ge=0.01,
|
||||
le=1000.0,
|
||||
description="Maximum spend in USD per SDK query. The CLI attempts "
|
||||
"to wrap up gracefully when this budget is reached. "
|
||||
"Set to $15 to allow most tasks to complete (p50=$5.37, p75=$13.07). "
|
||||
"Set to $10 to allow most tasks to complete (p50=$5.37, p75=$13.07). "
|
||||
"Override via CHAT_CLAUDE_AGENT_MAX_BUDGET_USD env var.",
|
||||
)
|
||||
claude_agent_max_thinking_tokens: int = Field(
|
||||
|
||||
@@ -351,6 +351,7 @@ class CoPilotProcessor:
|
||||
context=entry.context,
|
||||
file_ids=entry.file_ids,
|
||||
mode=effective_mode,
|
||||
model=entry.model,
|
||||
)
|
||||
async for chunk in stream_registry.stream_and_publish(
|
||||
session_id=entry.session_id,
|
||||
|
||||
@@ -9,7 +9,7 @@ import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.config import CopilotMode
|
||||
from backend.copilot.config import CopilotLlmModel, CopilotMode
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
|
||||
|
||||
@@ -160,6 +160,9 @@ class CoPilotExecutionEntry(BaseModel):
|
||||
mode: CopilotMode | None = None
|
||||
"""Autopilot mode override: 'fast' or 'extended_thinking'. None = server default."""
|
||||
|
||||
model: CopilotLlmModel | None = None
|
||||
"""Per-request model tier: 'standard' or 'advanced'. None = server default."""
|
||||
|
||||
|
||||
class CancelCoPilotEvent(BaseModel):
|
||||
"""Event to cancel a CoPilot operation."""
|
||||
@@ -180,6 +183,7 @@ async def enqueue_copilot_turn(
|
||||
context: dict[str, str] | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
mode: CopilotMode | None = None,
|
||||
model: CopilotLlmModel | None = None,
|
||||
) -> None:
|
||||
"""Enqueue a CoPilot task for processing by the executor service.
|
||||
|
||||
@@ -192,6 +196,7 @@ async def enqueue_copilot_turn(
|
||||
context: Optional context for the message (e.g., {url: str, content: str})
|
||||
file_ids: Optional workspace file IDs attached to the user's message
|
||||
mode: Autopilot mode override ('fast' or 'extended_thinking'). None = server default.
|
||||
model: Per-request model tier ('standard' or 'advanced'). None = server default.
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
@@ -204,6 +209,7 @@ async def enqueue_copilot_turn(
|
||||
context=context,
|
||||
file_ids=file_ids,
|
||||
mode=mode,
|
||||
model=model,
|
||||
)
|
||||
|
||||
queue_client = await get_async_copilot_queue()
|
||||
|
||||
@@ -302,6 +302,7 @@ async def record_token_usage(
|
||||
*,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
model_cost_multiplier: float = 1.0,
|
||||
) -> None:
|
||||
"""Record token usage for a user across all windows.
|
||||
|
||||
@@ -315,12 +316,17 @@ async def record_token_usage(
|
||||
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
|
||||
from the API response). Cache counts are passed separately.
|
||||
|
||||
``model_cost_multiplier`` scales the final weighted total to reflect
|
||||
relative model cost. Use 5.0 for Opus (5× more expensive than Sonnet)
|
||||
so that Opus turns deplete the rate limit faster, proportional to cost.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
prompt_tokens: Uncached input tokens.
|
||||
completion_tokens: Output tokens.
|
||||
cache_read_tokens: Tokens served from prompt cache (10% cost).
|
||||
cache_creation_tokens: Tokens written to prompt cache (25% cost).
|
||||
model_cost_multiplier: Relative model cost factor (1.0 = Sonnet, 5.0 = Opus); values below 1.0 are clamped to 1.0.
|
||||
"""
|
||||
prompt_tokens = max(0, prompt_tokens)
|
||||
completion_tokens = max(0, completion_tokens)
|
||||
@@ -332,7 +338,9 @@ async def record_token_usage(
|
||||
+ round(cache_creation_tokens * 0.25)
|
||||
+ round(cache_read_tokens * 0.1)
|
||||
)
|
||||
total = weighted_input + completion_tokens
|
||||
total = round(
|
||||
(weighted_input + completion_tokens) * max(1.0, model_cost_multiplier)
|
||||
)
|
||||
if total <= 0:
|
||||
return
|
||||
|
||||
@@ -340,11 +348,12 @@ async def record_token_usage(
|
||||
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
|
||||
)
|
||||
logger.info(
|
||||
"Recording token usage for %s: raw=%d, weighted=%d "
|
||||
"Recording token usage for %s: raw=%d, weighted=%d, multiplier=%.1fx "
|
||||
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
|
||||
user_id[:8],
|
||||
raw_total,
|
||||
total,
|
||||
model_cost_multiplier,
|
||||
prompt_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_tokens,
|
||||
|
||||
@@ -60,7 +60,7 @@ from backend.util.feature_flag import (
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from ..config import ChatConfig, CopilotMode
|
||||
from ..config import ChatConfig, CopilotLlmModel, CopilotMode
|
||||
from ..constants import (
|
||||
COPILOT_ERROR_PREFIX,
|
||||
COPILOT_RETRYABLE_ERROR_PREFIX,
|
||||
@@ -1958,6 +1958,7 @@ async def stream_chat_completion_sdk(
|
||||
file_ids: list[str] | None = None,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
mode: CopilotMode | None = None,
|
||||
model: CopilotLlmModel | None = None,
|
||||
**_kwargs: Any,
|
||||
) -> AsyncIterator[StreamBaseResponse]:
|
||||
"""Stream chat completion using Claude Agent SDK.
|
||||
@@ -1968,6 +1969,9 @@ async def stream_chat_completion_sdk(
|
||||
saved to the SDK working directory for the Read tool.
|
||||
mode: Accepted for signature compatibility with the baseline path.
|
||||
The SDK path does not currently branch on this value.
|
||||
model: Per-request model preference from the frontend toggle.
|
||||
'advanced' → Claude Opus; 'standard' → global config default.
|
||||
Takes priority over per-user LaunchDarkly targeting.
|
||||
"""
|
||||
_ = mode # SDK path ignores the requested mode.
|
||||
|
||||
@@ -2298,6 +2302,36 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
sdk_model = user_model_override
|
||||
|
||||
# Explicit per-request model tier from frontend toggle — highest priority,
|
||||
# overrides both global config and per-user LD targeting.
|
||||
# 'advanced' → claude-opus-4-6 (highest capability today).
|
||||
# 'standard' → config.model (current Sonnet default).
|
||||
# Rate-limit multiplier (5×) ensures Opus turns deplete quota faster —
|
||||
# no separate entitlement gate needed; users self-limit via rate limiting.
|
||||
if model == "advanced":
|
||||
sdk_model = _normalize_model_name("anthropic/claude-opus-4-6")
|
||||
logger.info(
|
||||
"[SDK] [%s] Per-request model override: advanced (%s)",
|
||||
session_id[:12] if session_id else "?",
|
||||
sdk_model,
|
||||
)
|
||||
elif model == "standard":
|
||||
sdk_model = _normalize_model_name(config.model)
|
||||
logger.info(
|
||||
"[SDK] [%s] Per-request model override: standard (%s)",
|
||||
session_id[:12] if session_id else "?",
|
||||
sdk_model,
|
||||
)
|
||||
|
||||
# Compute rate-limit cost multiplier based on the final model.
|
||||
# Opus costs 5× more than Sonnet (Anthropic pricing: $15/$75 vs $3/$15 per M tokens).
|
||||
# This multiplier scales the token counter in record_token_usage so that
|
||||
# Opus turns deplete the rate limit proportionally faster.
|
||||
_OPUS_COST_MULTIPLIER = 5.0
|
||||
model_cost_multiplier = (
|
||||
_OPUS_COST_MULTIPLIER if sdk_model and "opus" in sdk_model else 1.0
|
||||
)
|
||||
|
||||
# Track SDK-internal compaction (PreCompact hook → start, next msg → end)
|
||||
compaction = CompactionTracker()
|
||||
|
||||
@@ -2944,6 +2978,7 @@ async def stream_chat_completion_sdk(
|
||||
cost_usd=turn_cost_usd,
|
||||
model=config.model,
|
||||
provider="anthropic",
|
||||
model_cost_multiplier=model_cost_multiplier,
|
||||
)
|
||||
|
||||
# --- Persist session messages ---
|
||||
|
||||
@@ -19,6 +19,7 @@ from .service import (
|
||||
_is_prompt_too_long,
|
||||
_is_tool_only_message,
|
||||
_iter_sdk_messages,
|
||||
_normalize_model_name,
|
||||
_reduce_context,
|
||||
_resolve_user_model_override,
|
||||
)
|
||||
@@ -427,3 +428,44 @@ class TestResolveUserModelOverride:
|
||||
with patch("backend.copilot.sdk.service.get_feature_flag_value", new=ld_mock):
|
||||
await _resolve_user_model_override("user-abc")
|
||||
ld_mock.assert_called_once_with("copilot-model", "user-abc", default=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _normalize_model_name — used by per-request model override
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNormalizeModelName:
|
||||
"""Unit tests for the model-name normalisation helper.
|
||||
|
||||
The per-request model toggle calls _normalize_model_name with either
|
||||
``"anthropic/claude-opus-4-6"`` (for 'advanced') or ``config.model`` (for
|
||||
'standard'). These tests verify the OpenRouter/provider-prefix stripping
|
||||
that keeps the value compatible with the Claude CLI.
|
||||
"""
|
||||
|
||||
def test_strips_anthropic_prefix(self):
|
||||
assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6"
|
||||
|
||||
def test_strips_openai_prefix(self):
|
||||
assert _normalize_model_name("openai/gpt-4o") == "gpt-4o"
|
||||
|
||||
def test_strips_google_prefix(self):
|
||||
assert _normalize_model_name("google/gemini-2.5-flash") == "gemini-2.5-flash"
|
||||
|
||||
def test_already_normalized_unchanged(self):
|
||||
assert (
|
||||
_normalize_model_name("claude-sonnet-4-20250514")
|
||||
== "claude-sonnet-4-20250514"
|
||||
)
|
||||
|
||||
def test_empty_string_unchanged(self):
|
||||
assert _normalize_model_name("") == ""
|
||||
|
||||
def test_opus_model_roundtrip(self):
|
||||
"""The exact string used for the 'advanced' toggle strips correctly."""
|
||||
assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6"
|
||||
|
||||
def test_sonnet_openrouter_model(self):
|
||||
"""Sonnet model as stored in config (OpenRouter-prefixed) strips cleanly."""
|
||||
assert _normalize_model_name("anthropic/claude-sonnet-4") == "claude-sonnet-4"
|
||||
|
||||
@@ -96,6 +96,7 @@ async def persist_and_record_usage(
|
||||
cost_usd: float | str | None = None,
|
||||
model: str | None = None,
|
||||
provider: str = "open_router",
|
||||
model_cost_multiplier: float = 1.0,
|
||||
) -> int:
|
||||
"""Persist token usage to session and record for rate limiting.
|
||||
|
||||
@@ -109,6 +110,9 @@ async def persist_and_record_usage(
|
||||
log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]").
|
||||
cost_usd: Optional cost for logging (float from SDK, str otherwise).
|
||||
provider: Cost provider name (e.g. "anthropic", "open_router").
|
||||
model_cost_multiplier: Relative model cost factor for rate limiting
|
||||
(1.0 = Sonnet/default, 5.0 = Opus). Scales the token counter so
|
||||
more expensive models deplete the rate limit proportionally faster.
|
||||
|
||||
Returns:
|
||||
The computed total_tokens (prompt + completion; cache excluded).
|
||||
@@ -163,6 +167,7 @@ async def persist_and_record_usage(
|
||||
completion_tokens=completion_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
model_cost_multiplier=model_cost_multiplier,
|
||||
)
|
||||
except Exception as usage_err:
|
||||
logger.warning("%s Failed to record token usage: %s", log_prefix, usage_err)
|
||||
|
||||
@@ -230,6 +230,7 @@ class TestRateLimitRecording:
|
||||
completion_tokens=50,
|
||||
cache_read_tokens=1000,
|
||||
cache_creation_tokens=200,
|
||||
model_cost_multiplier=1.0,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -13,6 +13,7 @@ import { ChangeEvent, useEffect, useState } from "react";
|
||||
import { AttachmentMenu } from "./components/AttachmentMenu";
|
||||
import { DryRunToggleButton } from "./components/DryRunToggleButton";
|
||||
import { FileChips } from "./components/FileChips";
|
||||
import { ModelToggleButton } from "./components/ModelToggleButton";
|
||||
import { ModeToggleButton } from "./components/ModeToggleButton";
|
||||
import { RecordingButton } from "./components/RecordingButton";
|
||||
import { RecordingIndicator } from "./components/RecordingIndicator";
|
||||
@@ -50,8 +51,14 @@ export function ChatInput({
|
||||
onDroppedFilesConsumed,
|
||||
hasSession = false,
|
||||
}: Props) {
|
||||
const { copilotMode, setCopilotMode, isDryRun, setIsDryRun } =
|
||||
useCopilotUIStore();
|
||||
const {
|
||||
copilotMode,
|
||||
setCopilotMode,
|
||||
copilotModel,
|
||||
setCopilotLlmModel,
|
||||
isDryRun,
|
||||
setIsDryRun,
|
||||
} = useCopilotUIStore();
|
||||
const showModeToggle = useGetFlag(Flag.CHAT_MODE_OPTION);
|
||||
const showDryRunToggle = showModeToggle;
|
||||
const [files, setFiles] = useState<File[]>([]);
|
||||
@@ -72,6 +79,21 @@ export function ChatInput({
|
||||
});
|
||||
}
|
||||
|
||||
function handleToggleModel() {
|
||||
const next = copilotModel === "advanced" ? "standard" : "advanced";
|
||||
setCopilotLlmModel(next);
|
||||
toast({
|
||||
title:
|
||||
next === "advanced"
|
||||
? "Switched to Advanced model"
|
||||
: "Switched to Standard model",
|
||||
description:
|
||||
next === "advanced"
|
||||
? "Using the highest-capability model."
|
||||
: "Using the balanced standard model.",
|
||||
});
|
||||
}
|
||||
|
||||
function handleToggleDryRun() {
|
||||
const next = !isDryRun;
|
||||
setIsDryRun(next);
|
||||
@@ -202,6 +224,12 @@ export function ChatInput({
|
||||
onToggle={handleToggleMode}
|
||||
/>
|
||||
)}
|
||||
{showModeToggle && !isStreaming && (
|
||||
<ModelToggleButton
|
||||
model={copilotModel}
|
||||
onToggle={handleToggleModel}
|
||||
/>
|
||||
)}
|
||||
{showDryRunToggle && (!hasSession || isDryRun) && (
|
||||
<DryRunToggleButton
|
||||
isDryRun={isDryRun}
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
"use client";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Star } from "@phosphor-icons/react";
|
||||
import type { CopilotLlmModel } from "../../../store";
|
||||
|
||||
interface Props {
|
||||
model: CopilotLlmModel;
|
||||
onToggle: () => void;
|
||||
}
|
||||
|
||||
export function ModelToggleButton({ model, onToggle }: Props) {
|
||||
const isAdvanced = model === "advanced";
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
aria-pressed={isAdvanced}
|
||||
onClick={onToggle}
|
||||
className={cn(
|
||||
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
|
||||
isAdvanced
|
||||
? "bg-sky-100 text-sky-900 hover:bg-sky-200"
|
||||
: "text-neutral-500 hover:bg-neutral-100 hover:text-neutral-700",
|
||||
)}
|
||||
aria-label={
|
||||
isAdvanced ? "Switch to Standard model" : "Switch to Advanced model"
|
||||
}
|
||||
title={
|
||||
isAdvanced
|
||||
? "Advanced model — highest capability (click to switch to Standard)"
|
||||
: "Standard model — click to switch to Advanced"
|
||||
}
|
||||
>
|
||||
<Star size={14} />
|
||||
{isAdvanced && "Advanced"}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
import { render, screen, fireEvent, cleanup } from "@testing-library/react";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { ModelToggleButton } from "../ModelToggleButton";
|
||||
|
||||
afterEach(cleanup);
|
||||
|
||||
describe("ModelToggleButton", () => {
|
||||
it("shows no label when model is standard", () => {
|
||||
render(<ModelToggleButton model="standard" onToggle={vi.fn()} />);
|
||||
expect(screen.queryByText("Advanced")).toBeNull();
|
||||
});
|
||||
|
||||
it("shows Advanced label when model is advanced", () => {
|
||||
render(<ModelToggleButton model="advanced" onToggle={vi.fn()} />);
|
||||
expect(screen.getByText("Advanced")).toBeTruthy();
|
||||
});
|
||||
|
||||
it("calls onToggle when clicked", () => {
|
||||
const onToggle = vi.fn();
|
||||
render(<ModelToggleButton model="standard" onToggle={onToggle} />);
|
||||
fireEvent.click(screen.getByRole("button"));
|
||||
expect(onToggle).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("sets aria-pressed=false for standard", () => {
|
||||
render(<ModelToggleButton model="standard" onToggle={vi.fn()} />);
|
||||
const btn = screen.getByLabelText("Switch to Advanced model");
|
||||
expect(btn.getAttribute("aria-pressed")).toBe("false");
|
||||
});
|
||||
|
||||
it("sets aria-pressed=true for advanced", () => {
|
||||
render(<ModelToggleButton model="advanced" onToggle={vi.fn()} />);
|
||||
const btn = screen.getByLabelText("Switch to Standard model");
|
||||
expect(btn.getAttribute("aria-pressed")).toBe("true");
|
||||
});
|
||||
});
|
||||
@@ -52,6 +52,9 @@ export const DEFAULT_PANEL_WIDTH = 600;
|
||||
/** Autopilot response mode. */
|
||||
export type CopilotMode = "extended_thinking" | "fast";
|
||||
|
||||
/** Per-request model tier. 'standard' = current default; 'advanced' = highest-capability. */
|
||||
export type CopilotLlmModel = "standard" | "advanced";
|
||||
|
||||
const isClient = typeof window !== "undefined";
|
||||
|
||||
function getPersistedWidth(): number {
|
||||
@@ -131,6 +134,10 @@ interface CopilotUIState {
|
||||
copilotMode: CopilotMode;
|
||||
setCopilotMode: (mode: CopilotMode) => void;
|
||||
|
||||
/** Model tier: 'standard' (default) or 'advanced' (highest-capability). */
|
||||
copilotModel: CopilotLlmModel;
|
||||
setCopilotLlmModel: (model: CopilotLlmModel) => void;
|
||||
|
||||
/** Developer dry-run mode: sessions created with dry_run=true. */
|
||||
isDryRun: boolean;
|
||||
setIsDryRun: (enabled: boolean) => void;
|
||||
@@ -280,6 +287,15 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
|
||||
set({ copilotMode: mode });
|
||||
},
|
||||
|
||||
copilotModel: (() => {
|
||||
const saved = isClient ? storage.get(Key.COPILOT_MODEL) : null;
|
||||
return saved === "advanced" ? "advanced" : "standard";
|
||||
})(),
|
||||
setCopilotLlmModel: (model) => {
|
||||
storage.set(Key.COPILOT_MODEL, model);
|
||||
set({ copilotModel: model });
|
||||
},
|
||||
|
||||
isDryRun: isClient && storage.get(Key.COPILOT_DRY_RUN) === "true",
|
||||
setIsDryRun: (enabled) => {
|
||||
if (enabled) {
|
||||
@@ -299,6 +315,7 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
|
||||
storage.clean(Key.COPILOT_ARTIFACT_PANEL_WIDTH);
|
||||
storage.clean(Key.COPILOT_COMPLETED_SESSIONS);
|
||||
storage.clean(Key.COPILOT_DRY_RUN);
|
||||
storage.clean(Key.COPILOT_MODEL);
|
||||
set({
|
||||
completedSessionIDs: new Set<string>(),
|
||||
isNotificationsEnabled: false,
|
||||
@@ -312,6 +329,7 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
|
||||
history: [],
|
||||
},
|
||||
copilotMode: "extended_thinking",
|
||||
copilotModel: "standard",
|
||||
isDryRun: false,
|
||||
});
|
||||
if (isClient) {
|
||||
|
||||
@@ -43,6 +43,7 @@ export function useCopilotPage() {
|
||||
isDrawerOpen,
|
||||
setDrawerOpen,
|
||||
copilotMode,
|
||||
copilotModel,
|
||||
isDryRun,
|
||||
} = useCopilotUIStore();
|
||||
|
||||
@@ -79,6 +80,7 @@ export function useCopilotPage() {
|
||||
hasActiveStream,
|
||||
refetchSession,
|
||||
copilotMode: isModeToggleEnabled ? copilotMode : undefined,
|
||||
copilotModel: isModeToggleEnabled ? copilotModel : undefined,
|
||||
});
|
||||
|
||||
const { olderMessages, hasMore, isLoadingMore, loadMore } =
|
||||
|
||||
@@ -18,7 +18,7 @@ import {
|
||||
resolveInProgressTools,
|
||||
getSendSuppressionReason,
|
||||
} from "./helpers";
|
||||
import type { CopilotMode } from "./store";
|
||||
import type { CopilotLlmModel, CopilotMode } from "./store";
|
||||
|
||||
const RECONNECT_BASE_DELAY_MS = 1_000;
|
||||
const RECONNECT_MAX_ATTEMPTS = 3;
|
||||
@@ -33,6 +33,8 @@ interface UseCopilotStreamArgs {
|
||||
refetchSession: () => Promise<{ data?: unknown }>;
|
||||
/** Autopilot mode to use for requests. `undefined` = let backend decide via feature flags. */
|
||||
copilotMode: CopilotMode | undefined;
|
||||
/** Model tier override. `undefined` = let backend decide. */
|
||||
copilotModel: CopilotLlmModel | undefined;
|
||||
}
|
||||
|
||||
export function useCopilotStream({
|
||||
@@ -41,17 +43,20 @@ export function useCopilotStream({
|
||||
hasActiveStream,
|
||||
refetchSession,
|
||||
copilotMode,
|
||||
copilotModel,
|
||||
}: UseCopilotStreamArgs) {
|
||||
const queryClient = useQueryClient();
|
||||
const [rateLimitMessage, setRateLimitMessage] = useState<string | null>(null);
|
||||
function dismissRateLimit() {
|
||||
setRateLimitMessage(null);
|
||||
}
|
||||
// Use a ref for copilotMode so the transport closure always reads the
|
||||
// latest value without recreating the DefaultChatTransport (which would
|
||||
// Use refs for copilotMode and copilotModel so the transport closure always reads
|
||||
// the latest value without recreating the DefaultChatTransport (which would
|
||||
// reset useChat's internal Chat instance and break mid-session streaming).
|
||||
const copilotModeRef = useRef(copilotMode);
|
||||
copilotModeRef.current = copilotMode;
|
||||
const copilotModelRef = useRef(copilotModel);
|
||||
copilotModelRef.current = copilotModel;
|
||||
|
||||
// Connect directly to the Python backend for SSE, bypassing the Next.js
|
||||
// serverless proxy. This eliminates the Vercel 800s function timeout that
|
||||
@@ -83,6 +88,7 @@ export function useCopilotStream({
|
||||
context: null,
|
||||
file_ids: fileIds && fileIds.length > 0 ? fileIds : null,
|
||||
mode: copilotModeRef.current ?? null,
|
||||
model: copilotModelRef.current ?? null,
|
||||
},
|
||||
headers: await getCopilotAuthHeaders(),
|
||||
};
|
||||
|
||||
@@ -17,6 +17,7 @@ export enum Key {
|
||||
COPILOT_NOTIFICATION_DIALOG_DISMISSED = "copilot-notification-dialog-dismissed",
|
||||
COPILOT_ARTIFACT_PANEL_WIDTH = "copilot-artifact-panel-width",
|
||||
COPILOT_MODE = "copilot-mode",
|
||||
COPILOT_MODEL = "copilot-model",
|
||||
COPILOT_COMPLETED_SESSIONS = "copilot-completed-sessions",
|
||||
COPILOT_DRY_RUN = "copilot-dry-run",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user