Compare commits
4 Commits
test-scree
...
fix/openro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
079902501e | ||
|
|
a042c84907 | ||
|
|
dfb7f327de | ||
|
|
0c4931b8f8 |
@@ -43,7 +43,6 @@ async def get_cost_dashboard(
|
||||
model: str | None = Query(None),
|
||||
block_name: str | None = Query(None),
|
||||
tracking_type: str | None = Query(None),
|
||||
graph_exec_id: str | None = Query(None),
|
||||
):
|
||||
logger.info("Admin %s fetching platform cost dashboard", admin_user_id)
|
||||
return await get_platform_cost_dashboard(
|
||||
@@ -54,7 +53,6 @@ async def get_cost_dashboard(
|
||||
model=model,
|
||||
block_name=block_name,
|
||||
tracking_type=tracking_type,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -74,7 +72,6 @@ async def get_cost_logs(
|
||||
model: str | None = Query(None),
|
||||
block_name: str | None = Query(None),
|
||||
tracking_type: str | None = Query(None),
|
||||
graph_exec_id: str | None = Query(None),
|
||||
):
|
||||
logger.info("Admin %s fetching platform cost logs", admin_user_id)
|
||||
logs, total = await get_platform_cost_logs(
|
||||
@@ -87,7 +84,6 @@ async def get_cost_logs(
|
||||
model=model,
|
||||
block_name=block_name,
|
||||
tracking_type=tracking_type,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
return PlatformCostLogsResponse(
|
||||
@@ -121,7 +117,6 @@ async def export_cost_logs(
|
||||
model: str | None = Query(None),
|
||||
block_name: str | None = Query(None),
|
||||
tracking_type: str | None = Query(None),
|
||||
graph_exec_id: str | None = Query(None),
|
||||
):
|
||||
logger.info("Admin %s exporting platform cost logs", admin_user_id)
|
||||
logs, truncated = await get_platform_cost_logs_for_export(
|
||||
@@ -132,7 +127,6 @@ async def export_cost_logs(
|
||||
model=model,
|
||||
block_name=block_name,
|
||||
tracking_type=tracking_type,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
return PlatformCostExportResponse(
|
||||
logs=logs,
|
||||
|
||||
@@ -298,6 +298,21 @@ class _TokenUsage:
|
||||
self.cost_usd = None
|
||||
|
||||
|
||||
def _apply_token_usage(acc: _TokenUsage, usage: dict) -> None:
|
||||
"""Accumulate token counts from a ResultMessage usage dict into *acc*.
|
||||
|
||||
Uses ``or 0`` instead of ``.get(key, 0)`` because OpenRouter may include
|
||||
cache token keys with a ``null`` value (rather than omitting them) during
|
||||
the initial streaming event before real counts are available. Plain
|
||||
``.get(key, 0)`` returns ``None`` when the key exists but is ``null``,
|
||||
causing ``int += None`` TypeError.
|
||||
"""
|
||||
acc.prompt_tokens += usage.get("input_tokens") or 0
|
||||
acc.cache_read_tokens += usage.get("cache_read_input_tokens") or 0
|
||||
acc.cache_creation_tokens += usage.get("cache_creation_input_tokens") or 0
|
||||
acc.completion_tokens += usage.get("output_tokens") or 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class _RetryState:
|
||||
"""Mutable state passed to `_run_stream_attempt` instead of closures.
|
||||
@@ -1912,21 +1927,7 @@ async def _run_stream_attempt(
|
||||
# cache_read_input_tokens = served from cache
|
||||
# cache_creation_input_tokens = written to cache
|
||||
if sdk_msg.usage:
|
||||
# Use `or 0` instead of a default in .get() because
|
||||
# OpenRouter may include the key with a null value (e.g.
|
||||
# {"cache_read_input_tokens": null}) for models that don't
|
||||
# yet report cache tokens, making .get("key", 0) return
|
||||
# None rather than the fallback 0.
|
||||
state.usage.prompt_tokens += sdk_msg.usage.get("input_tokens") or 0
|
||||
state.usage.cache_read_tokens += (
|
||||
sdk_msg.usage.get("cache_read_input_tokens") or 0
|
||||
)
|
||||
state.usage.cache_creation_tokens += (
|
||||
sdk_msg.usage.get("cache_creation_input_tokens") or 0
|
||||
)
|
||||
state.usage.completion_tokens += (
|
||||
sdk_msg.usage.get("output_tokens") or 0
|
||||
)
|
||||
_apply_token_usage(state.usage, sdk_msg.usage)
|
||||
logger.info(
|
||||
"%s Token usage: uncached=%d, cache_read=%d, "
|
||||
"cache_create=%d, output=%d",
|
||||
|
||||
@@ -17,6 +17,7 @@ from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import (
|
||||
_RETRY_TARGET_TOKENS,
|
||||
ReducedContext,
|
||||
_apply_token_usage,
|
||||
_is_prompt_too_long,
|
||||
_is_tool_only_message,
|
||||
_iter_sdk_messages,
|
||||
@@ -354,47 +355,6 @@ class TestIsParallelContinuation:
|
||||
assert _is_tool_only_message(msg) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _normalize_model_name — used by per-request model override
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNormalizeModelName:
|
||||
"""Unit tests for the model-name normalisation helper.
|
||||
|
||||
The per-request model toggle calls _normalize_model_name with either
|
||||
``"anthropic/claude-opus-4-6"`` (for 'advanced') or ``config.model`` (for
|
||||
'standard'). These tests verify the OpenRouter/provider-prefix stripping
|
||||
that keeps the value compatible with the Claude CLI.
|
||||
"""
|
||||
|
||||
def test_strips_anthropic_prefix(self):
|
||||
assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6"
|
||||
|
||||
def test_strips_openai_prefix(self):
|
||||
assert _normalize_model_name("openai/gpt-4o") == "gpt-4o"
|
||||
|
||||
def test_strips_google_prefix(self):
|
||||
assert _normalize_model_name("google/gemini-2.5-flash") == "gemini-2.5-flash"
|
||||
|
||||
def test_already_normalized_unchanged(self):
|
||||
assert (
|
||||
_normalize_model_name("claude-sonnet-4-20250514")
|
||||
== "claude-sonnet-4-20250514"
|
||||
)
|
||||
|
||||
def test_empty_string_unchanged(self):
|
||||
assert _normalize_model_name("") == ""
|
||||
|
||||
def test_opus_model_roundtrip(self):
|
||||
"""The exact string used for the 'opus' toggle strips correctly."""
|
||||
assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6"
|
||||
|
||||
def test_sonnet_openrouter_model(self):
|
||||
"""Sonnet model as stored in config (OpenRouter-prefixed) strips cleanly."""
|
||||
assert _normalize_model_name("anthropic/claude-sonnet-4") == "claude-sonnet-4"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _TokenUsage — null-safe accumulation (OpenRouter initial-stream-event bug)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -409,13 +369,6 @@ class TestTokenUsageNullSafety:
|
||||
when the key existed with a null value, causing 'int += None' TypeError.
|
||||
"""
|
||||
|
||||
def _apply_usage(self, usage: dict, acc: _TokenUsage) -> None:
|
||||
"""Mirror the production accumulation in sdk/service.py."""
|
||||
acc.prompt_tokens += usage.get("input_tokens") or 0
|
||||
acc.cache_read_tokens += usage.get("cache_read_input_tokens") or 0
|
||||
acc.cache_creation_tokens += usage.get("cache_creation_input_tokens") or 0
|
||||
acc.completion_tokens += usage.get("output_tokens") or 0
|
||||
|
||||
def test_null_cache_tokens_do_not_crash(self):
|
||||
"""OpenRouter initial event: cache keys present with null value."""
|
||||
usage = {
|
||||
@@ -425,7 +378,7 @@ class TestTokenUsageNullSafety:
|
||||
"cache_creation_input_tokens": None,
|
||||
}
|
||||
acc = _TokenUsage()
|
||||
self._apply_usage(usage, acc) # must not raise TypeError
|
||||
_apply_token_usage(acc, usage) # must not raise TypeError
|
||||
assert acc.prompt_tokens == 0
|
||||
assert acc.cache_read_tokens == 0
|
||||
assert acc.cache_creation_tokens == 0
|
||||
@@ -440,7 +393,7 @@ class TestTokenUsageNullSafety:
|
||||
"cache_creation_input_tokens": 512,
|
||||
}
|
||||
acc = _TokenUsage()
|
||||
self._apply_usage(usage, acc)
|
||||
_apply_token_usage(acc, usage)
|
||||
assert acc.prompt_tokens == 10
|
||||
assert acc.cache_read_tokens == 16600
|
||||
assert acc.cache_creation_tokens == 512
|
||||
@@ -450,7 +403,7 @@ class TestTokenUsageNullSafety:
|
||||
"""Minimal usage dict without cache keys defaults correctly."""
|
||||
usage = {"input_tokens": 5, "output_tokens": 20}
|
||||
acc = _TokenUsage()
|
||||
self._apply_usage(usage, acc)
|
||||
_apply_token_usage(acc, usage)
|
||||
assert acc.prompt_tokens == 5
|
||||
assert acc.cache_read_tokens == 0
|
||||
assert acc.cache_creation_tokens == 0
|
||||
@@ -471,9 +424,28 @@ class TestTokenUsageNullSafety:
|
||||
"cache_creation_input_tokens": 512,
|
||||
}
|
||||
acc = _TokenUsage()
|
||||
self._apply_usage(null_event, acc)
|
||||
self._apply_usage(real_event, acc)
|
||||
_apply_token_usage(acc, null_event)
|
||||
_apply_token_usage(acc, real_event)
|
||||
assert acc.prompt_tokens == 10
|
||||
assert acc.cache_read_tokens == 16600
|
||||
assert acc.cache_creation_tokens == 512
|
||||
assert acc.completion_tokens == 349
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"key,null_field,real_value,acc_attr",
|
||||
[
|
||||
("cache_read_input_tokens", None, 16600, "cache_read_tokens"),
|
||||
("cache_creation_input_tokens", None, 512, "cache_creation_tokens"),
|
||||
("input_tokens", None, 10, "prompt_tokens"),
|
||||
("output_tokens", None, 349, "completion_tokens"),
|
||||
],
|
||||
)
|
||||
def test_null_then_real_per_field(
|
||||
self, key: str, null_field: None, real_value: int, acc_attr: str
|
||||
) -> None:
|
||||
"""Each token field handles null → real transition independently."""
|
||||
acc = _TokenUsage()
|
||||
_apply_token_usage(acc, {key: null_field})
|
||||
assert getattr(acc, acc_attr) == 0
|
||||
_apply_token_usage(acc, {key: real_value})
|
||||
assert getattr(acc, acc_attr) == real_value
|
||||
|
||||
@@ -215,7 +215,6 @@ def _build_prisma_where(
|
||||
model: str | None = None,
|
||||
block_name: str | None = None,
|
||||
tracking_type: str | None = None,
|
||||
graph_exec_id: str | None = None,
|
||||
) -> PlatformCostLogWhereInput:
|
||||
"""Build a Prisma WhereInput for PlatformCostLog filters."""
|
||||
where: PlatformCostLogWhereInput = {}
|
||||
@@ -243,9 +242,6 @@ def _build_prisma_where(
|
||||
if tracking_type:
|
||||
where["trackingType"] = tracking_type
|
||||
|
||||
if graph_exec_id:
|
||||
where["graphExecId"] = graph_exec_id
|
||||
|
||||
return where
|
||||
|
||||
|
||||
@@ -257,7 +253,6 @@ def _build_raw_where(
|
||||
model: str | None = None,
|
||||
block_name: str | None = None,
|
||||
tracking_type: str | None = None,
|
||||
graph_exec_id: str | None = None,
|
||||
) -> tuple[str, list]:
|
||||
"""Build a parameterised WHERE clause for raw SQL queries.
|
||||
|
||||
@@ -307,11 +302,6 @@ def _build_raw_where(
|
||||
params.append(block_name)
|
||||
idx += 1
|
||||
|
||||
if graph_exec_id is not None:
|
||||
clauses.append(f'"graphExecId" = ${idx}')
|
||||
params.append(graph_exec_id)
|
||||
idx += 1
|
||||
|
||||
return (" AND ".join(clauses), params)
|
||||
|
||||
|
||||
@@ -324,7 +314,6 @@ async def get_platform_cost_dashboard(
|
||||
model: str | None = None,
|
||||
block_name: str | None = None,
|
||||
tracking_type: str | None = None,
|
||||
graph_exec_id: str | None = None,
|
||||
) -> PlatformCostDashboard:
|
||||
"""Aggregate platform cost logs for the admin dashboard.
|
||||
|
||||
@@ -341,7 +330,7 @@ async def get_platform_cost_dashboard(
|
||||
start = datetime.now(timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
|
||||
|
||||
where = _build_prisma_where(
|
||||
start, end, provider, user_id, model, block_name, tracking_type, graph_exec_id
|
||||
start, end, provider, user_id, model, block_name, tracking_type
|
||||
)
|
||||
|
||||
# For per-user tracking-type breakdown we intentionally omit the
|
||||
@@ -349,14 +338,7 @@ async def get_platform_cost_dashboard(
|
||||
# This ensures cost_bearing_request_count is correct even when the caller
|
||||
# is filtering the main view by a different tracking_type.
|
||||
where_no_tracking_type = _build_prisma_where(
|
||||
start,
|
||||
end,
|
||||
provider,
|
||||
user_id,
|
||||
model,
|
||||
block_name,
|
||||
tracking_type=None,
|
||||
graph_exec_id=graph_exec_id,
|
||||
start, end, provider, user_id, model, block_name, tracking_type=None
|
||||
)
|
||||
|
||||
sum_fields = {
|
||||
@@ -376,14 +358,7 @@ async def get_platform_cost_dashboard(
|
||||
# "cost_usd" — percentile and histogram queries only make sense on
|
||||
# cost-denominated rows, regardless of what the caller is filtering.
|
||||
raw_where, raw_params = _build_raw_where(
|
||||
start,
|
||||
end,
|
||||
provider,
|
||||
user_id,
|
||||
model,
|
||||
block_name,
|
||||
tracking_type=None,
|
||||
graph_exec_id=graph_exec_id,
|
||||
start, end, provider, user_id, model, block_name, tracking_type=None
|
||||
)
|
||||
|
||||
# Queries that always run regardless of tracking_type filter.
|
||||
@@ -672,13 +647,12 @@ async def get_platform_cost_logs(
|
||||
model: str | None = None,
|
||||
block_name: str | None = None,
|
||||
tracking_type: str | None = None,
|
||||
graph_exec_id: str | None = None,
|
||||
) -> tuple[list[CostLogRow], int]:
|
||||
if start is None:
|
||||
start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
|
||||
|
||||
where = _build_prisma_where(
|
||||
start, end, provider, user_id, model, block_name, tracking_type, graph_exec_id
|
||||
start, end, provider, user_id, model, block_name, tracking_type
|
||||
)
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
@@ -728,7 +702,6 @@ async def get_platform_cost_logs_for_export(
|
||||
model: str | None = None,
|
||||
block_name: str | None = None,
|
||||
tracking_type: str | None = None,
|
||||
graph_exec_id: str | None = None,
|
||||
) -> tuple[list[CostLogRow], bool]:
|
||||
"""Return all matching rows up to EXPORT_MAX_ROWS.
|
||||
|
||||
@@ -739,7 +712,7 @@ async def get_platform_cost_logs_for_export(
|
||||
start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
|
||||
|
||||
where = _build_prisma_where(
|
||||
start, end, provider, user_id, model, block_name, tracking_type, graph_exec_id
|
||||
start, end, provider, user_id, model, block_name, tracking_type
|
||||
)
|
||||
|
||||
rows = await PrismaLog.prisma().find_many(
|
||||
|
||||
@@ -19,7 +19,6 @@ interface Props {
|
||||
model?: string;
|
||||
block_name?: string;
|
||||
tracking_type?: string;
|
||||
graph_exec_id?: string;
|
||||
page?: string;
|
||||
tab?: string;
|
||||
};
|
||||
@@ -48,8 +47,6 @@ export function PlatformCostContent({ searchParams }: Props) {
|
||||
setBlockInput,
|
||||
typeInput,
|
||||
setTypeInput,
|
||||
executionIdInput,
|
||||
setExecutionIdInput,
|
||||
rateOverrides,
|
||||
handleRateOverride,
|
||||
updateUrl,
|
||||
@@ -238,22 +235,6 @@ export function PlatformCostContent({ searchParams }: Props) {
|
||||
onChange={(e) => setTypeInput(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
<label
|
||||
htmlFor="execution-id-filter"
|
||||
className="text-sm text-muted-foreground"
|
||||
>
|
||||
Execution ID
|
||||
</label>
|
||||
<input
|
||||
id="execution-id-filter"
|
||||
type="text"
|
||||
placeholder="Filter by execution"
|
||||
className="rounded border px-3 py-1.5 text-sm"
|
||||
value={executionIdInput}
|
||||
onChange={(e) => setExecutionIdInput(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
<button
|
||||
onClick={handleFilter}
|
||||
className="rounded bg-primary px-4 py-1.5 text-sm text-primary-foreground hover:bg-primary/90"
|
||||
@@ -269,7 +250,6 @@ export function PlatformCostContent({ searchParams }: Props) {
|
||||
setModelInput("");
|
||||
setBlockInput("");
|
||||
setTypeInput("");
|
||||
setExecutionIdInput("");
|
||||
updateUrl({
|
||||
start: "",
|
||||
end: "",
|
||||
@@ -278,7 +258,6 @@ export function PlatformCostContent({ searchParams }: Props) {
|
||||
model: "",
|
||||
block_name: "",
|
||||
tracking_type: "",
|
||||
graph_exec_id: "",
|
||||
page: "1",
|
||||
});
|
||||
}}
|
||||
|
||||
@@ -23,7 +23,6 @@ interface InitialSearchParams {
|
||||
model?: string;
|
||||
block_name?: string;
|
||||
tracking_type?: string;
|
||||
graph_exec_id?: string;
|
||||
page?: string;
|
||||
tab?: string;
|
||||
}
|
||||
@@ -44,8 +43,6 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
|
||||
urlParams.get("block_name") || searchParams.block_name || "";
|
||||
const typeFilter =
|
||||
urlParams.get("tracking_type") || searchParams.tracking_type || "";
|
||||
const executionIdFilter =
|
||||
urlParams.get("graph_exec_id") || searchParams.graph_exec_id || "";
|
||||
|
||||
const [startInput, setStartInput] = useState(toLocalInput(startDate));
|
||||
const [endInput, setEndInput] = useState(toLocalInput(endDate));
|
||||
@@ -54,7 +51,6 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
|
||||
const [modelInput, setModelInput] = useState(modelFilter);
|
||||
const [blockInput, setBlockInput] = useState(blockFilter);
|
||||
const [typeInput, setTypeInput] = useState(typeFilter);
|
||||
const [executionIdInput, setExecutionIdInput] = useState(executionIdFilter);
|
||||
const [rateOverrides, setRateOverrides] = useState<Record<string, number>>(
|
||||
{},
|
||||
);
|
||||
@@ -71,7 +67,6 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
|
||||
model: modelFilter || undefined,
|
||||
block_name: blockFilter || undefined,
|
||||
tracking_type: typeFilter || undefined,
|
||||
graph_exec_id: executionIdFilter || undefined,
|
||||
};
|
||||
|
||||
const {
|
||||
@@ -120,7 +115,6 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
|
||||
model: modelInput,
|
||||
block_name: blockInput,
|
||||
tracking_type: typeInput,
|
||||
graph_exec_id: executionIdInput,
|
||||
page: "1",
|
||||
});
|
||||
}
|
||||
@@ -191,8 +185,6 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
|
||||
setBlockInput,
|
||||
typeInput,
|
||||
setTypeInput,
|
||||
executionIdInput,
|
||||
setExecutionIdInput,
|
||||
rateOverrides,
|
||||
handleRateOverride,
|
||||
updateUrl,
|
||||
|
||||
@@ -7,10 +7,6 @@ type SearchParams = {
|
||||
end?: string;
|
||||
provider?: string;
|
||||
user_id?: string;
|
||||
model?: string;
|
||||
block_name?: string;
|
||||
tracking_type?: string;
|
||||
graph_exec_id?: string;
|
||||
page?: string;
|
||||
tab?: string;
|
||||
};
|
||||
|
||||
@@ -218,24 +218,18 @@ export function ChatInput({
|
||||
onFilesSelected={handleFilesSelected}
|
||||
disabled={isBusy}
|
||||
/>
|
||||
{showModeToggle &&
|
||||
!isStreaming &&
|
||||
(!hasSession || copilotChatMode === "extended_thinking") && (
|
||||
<ModeToggleButton
|
||||
mode={copilotChatMode}
|
||||
onToggle={handleToggleMode}
|
||||
readOnly={hasSession}
|
||||
/>
|
||||
)}
|
||||
{showModeToggle &&
|
||||
!isStreaming &&
|
||||
(!hasSession || copilotLlmModel === "advanced") && (
|
||||
<ModelToggleButton
|
||||
model={copilotLlmModel}
|
||||
onToggle={handleToggleModel}
|
||||
readOnly={hasSession}
|
||||
/>
|
||||
)}
|
||||
{showModeToggle && !isStreaming && (
|
||||
<ModeToggleButton
|
||||
mode={copilotChatMode}
|
||||
onToggle={handleToggleMode}
|
||||
/>
|
||||
)}
|
||||
{showModeToggle && !isStreaming && (
|
||||
<ModelToggleButton
|
||||
model={copilotLlmModel}
|
||||
onToggle={handleToggleModel}
|
||||
/>
|
||||
)}
|
||||
{showDryRunToggle && (!hasSession || isDryRun) && (
|
||||
<DryRunToggleButton
|
||||
isDryRun={isDryRun}
|
||||
|
||||
@@ -7,37 +7,28 @@ import type { CopilotMode } from "../../../store";
|
||||
interface Props {
|
||||
mode: CopilotMode;
|
||||
onToggle: () => void;
|
||||
readOnly?: boolean;
|
||||
}
|
||||
|
||||
export function ModeToggleButton({ mode, onToggle, readOnly = false }: Props) {
|
||||
export function ModeToggleButton({ mode, onToggle }: Props) {
|
||||
const isExtended = mode === "extended_thinking";
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
aria-pressed={isExtended}
|
||||
disabled={readOnly}
|
||||
onClick={readOnly ? undefined : onToggle}
|
||||
onClick={onToggle}
|
||||
className={cn(
|
||||
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
|
||||
isExtended
|
||||
? "bg-purple-100 text-purple-900 hover:bg-purple-200 disabled:hover:bg-purple-100"
|
||||
: "bg-amber-100 text-amber-900 hover:bg-amber-200 disabled:hover:bg-amber-100",
|
||||
readOnly && "cursor-default opacity-70",
|
||||
? "bg-purple-100 text-purple-900 hover:bg-purple-200"
|
||||
: "bg-amber-100 text-amber-900 hover:bg-amber-200",
|
||||
)}
|
||||
aria-label={
|
||||
readOnly
|
||||
? `${isExtended ? "Extended Thinking" : "Fast"} mode active for this session`
|
||||
: isExtended
|
||||
? "Switch to Fast mode"
|
||||
: "Switch to Extended Thinking mode"
|
||||
isExtended ? "Switch to Fast mode" : "Switch to Extended Thinking mode"
|
||||
}
|
||||
title={
|
||||
readOnly
|
||||
? `${isExtended ? "Extended Thinking" : "Fast"} mode active for this session`
|
||||
: isExtended
|
||||
? "Extended Thinking mode — deeper reasoning (click to switch to Fast mode)"
|
||||
: "Fast mode — quicker responses (click to switch to Extended Thinking)"
|
||||
isExtended
|
||||
? "Extended Thinking mode — deeper reasoning (click to switch to Fast mode)"
|
||||
: "Fast mode — quicker responses (click to switch to Extended Thinking)"
|
||||
}
|
||||
>
|
||||
{isExtended ? (
|
||||
|
||||
@@ -7,45 +7,32 @@ import type { CopilotLlmModel } from "../../../store";
|
||||
interface Props {
|
||||
model: CopilotLlmModel;
|
||||
onToggle: () => void;
|
||||
readOnly?: boolean;
|
||||
}
|
||||
|
||||
export function ModelToggleButton({
|
||||
model,
|
||||
onToggle,
|
||||
readOnly = false,
|
||||
}: Props) {
|
||||
export function ModelToggleButton({ model, onToggle }: Props) {
|
||||
const isAdvanced = model === "advanced";
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
aria-pressed={isAdvanced}
|
||||
disabled={readOnly}
|
||||
onClick={readOnly ? undefined : onToggle}
|
||||
onClick={onToggle}
|
||||
className={cn(
|
||||
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
|
||||
isAdvanced
|
||||
? "bg-sky-100 text-sky-900 hover:bg-sky-200 disabled:hover:bg-sky-100"
|
||||
: "bg-neutral-100 text-neutral-700 hover:bg-neutral-200 disabled:hover:bg-neutral-100",
|
||||
readOnly && "cursor-default opacity-70",
|
||||
? "bg-sky-100 text-sky-900 hover:bg-sky-200"
|
||||
: "text-neutral-500 hover:bg-neutral-100 hover:text-neutral-700",
|
||||
)}
|
||||
aria-label={
|
||||
readOnly
|
||||
? `${isAdvanced ? "Advanced" : "Standard"} model active for this session`
|
||||
: isAdvanced
|
||||
? "Switch to Standard model"
|
||||
: "Switch to Advanced model"
|
||||
isAdvanced ? "Switch to Standard model" : "Switch to Advanced model"
|
||||
}
|
||||
title={
|
||||
readOnly
|
||||
? `${isAdvanced ? "Advanced" : "Standard"} model active for this session`
|
||||
: isAdvanced
|
||||
? "Advanced model — highest capability (click to switch to Standard)"
|
||||
: "Standard model — click to switch to Advanced"
|
||||
isAdvanced
|
||||
? "Advanced model — highest capability (click to switch to Standard)"
|
||||
: "Standard model — click to switch to Advanced"
|
||||
}
|
||||
>
|
||||
<Cpu size={14} />
|
||||
{isAdvanced ? "Advanced" : "Standard"}
|
||||
{isAdvanced && "Advanced"}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -5,9 +5,9 @@ import { ModelToggleButton } from "../ModelToggleButton";
|
||||
afterEach(cleanup);
|
||||
|
||||
describe("ModelToggleButton", () => {
|
||||
it("shows Standard label when model is standard", () => {
|
||||
it("shows no label when model is standard", () => {
|
||||
render(<ModelToggleButton model="standard" onToggle={vi.fn()} />);
|
||||
expect(screen.getByText("Standard")).toBeTruthy();
|
||||
expect(screen.queryByText("Advanced")).toBeNull();
|
||||
});
|
||||
|
||||
it("shows Advanced label when model is advanced", () => {
|
||||
@@ -33,30 +33,4 @@ describe("ModelToggleButton", () => {
|
||||
const btn = screen.getByLabelText("Switch to Standard model");
|
||||
expect(btn.getAttribute("aria-pressed")).toBe("true");
|
||||
});
|
||||
|
||||
it("is disabled when readOnly", () => {
|
||||
render(<ModelToggleButton model="advanced" onToggle={vi.fn()} readOnly />);
|
||||
expect(screen.getByRole("button").hasAttribute("disabled")).toBe(true);
|
||||
});
|
||||
|
||||
it("does not call onToggle when readOnly", () => {
|
||||
const onToggle = vi.fn();
|
||||
render(<ModelToggleButton model="standard" onToggle={onToggle} readOnly />);
|
||||
fireEvent.click(screen.getByRole("button"));
|
||||
expect(onToggle).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("shows session-locked title when readOnly and advanced", () => {
|
||||
render(<ModelToggleButton model="advanced" onToggle={vi.fn()} readOnly />);
|
||||
expect(
|
||||
screen.getByTitle("Advanced model active for this session"),
|
||||
).toBeDefined();
|
||||
});
|
||||
|
||||
it("shows session-locked title when readOnly and standard", () => {
|
||||
render(<ModelToggleButton model="standard" onToggle={vi.fn()} readOnly />);
|
||||
expect(
|
||||
screen.getByTitle("Standard model active for this session"),
|
||||
).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -82,15 +82,6 @@
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Tracking Type"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "graph_exec_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Graph Exec Id"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
@@ -216,15 +207,6 @@
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Tracking Type"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "graph_exec_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Graph Exec Id"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
@@ -327,15 +309,6 @@
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Tracking Type"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "graph_exec_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Graph Exec Id"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
|
||||
|
Before Width: | Height: | Size: 74 KiB |
|
Before Width: | Height: | Size: 74 KiB |
|
Before Width: | Height: | Size: 38 KiB |
|
Before Width: | Height: | Size: 74 KiB |
|
Before Width: | Height: | Size: 83 KiB |
|
Before Width: | Height: | Size: 40 KiB |
|
Before Width: | Height: | Size: 75 KiB |
|
Before Width: | Height: | Size: 54 KiB |
|
Before Width: | Height: | Size: 109 KiB |
|
Before Width: | Height: | Size: 75 KiB |
|
Before Width: | Height: | Size: 83 KiB |
|
Before Width: | Height: | Size: 96 KiB |
|
Before Width: | Height: | Size: 83 KiB |
|
Before Width: | Height: | Size: 82 KiB |
|
Before Width: | Height: | Size: 109 KiB |
|
Before Width: | Height: | Size: 109 KiB |
|
Before Width: | Height: | Size: 109 KiB |
|
Before Width: | Height: | Size: 96 KiB |
|
Before Width: | Height: | Size: 111 KiB |
|
Before Width: | Height: | Size: 111 KiB |