mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
1 Commits
fix/copilo
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0b594a219c |
@@ -1,40 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.auth_activities
|
||||
-- Looker source alias: ds49 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Tracks authentication events (login, logout, SSO, password
|
||||
-- reset, etc.) from Supabase's internal audit log.
|
||||
-- Useful for monitoring sign-in patterns and detecting anomalies.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.audit_log_entries — Supabase internal auth event log
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- created_at TIMESTAMPTZ When the auth event occurred
|
||||
-- actor_id TEXT User ID who triggered the event
|
||||
-- actor_via_sso TEXT Whether the action was via SSO ('true'/'false')
|
||||
-- action TEXT Event type (e.g. 'login', 'logout', 'token_refreshed')
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days from current date
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Daily login counts
|
||||
-- SELECT DATE_TRUNC('day', created_at) AS day, COUNT(*) AS logins
|
||||
-- FROM analytics.auth_activities
|
||||
-- WHERE action = 'login'
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
--
|
||||
-- -- SSO vs password login breakdown
|
||||
-- SELECT actor_via_sso, COUNT(*) FROM analytics.auth_activities
|
||||
-- WHERE action = 'login' GROUP BY 1;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
created_at,
|
||||
payload->>'actor_id' AS actor_id,
|
||||
payload->>'actor_via_sso' AS actor_via_sso,
|
||||
payload->>'action' AS action
|
||||
FROM auth.audit_log_entries
|
||||
WHERE created_at >= NOW() - INTERVAL '90 days'
|
||||
@@ -1,105 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.graph_execution
|
||||
-- Looker source alias: ds16 | Charts: 21
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per agent graph execution (last 90 days).
|
||||
-- Unpacks the JSONB stats column into individual numeric columns
|
||||
-- and normalises the executionStatus — runs that failed due to
|
||||
-- insufficient credits are reclassified as 'NO_CREDITS' for
|
||||
-- easier filtering. Error messages are scrubbed of IDs and URLs
|
||||
-- to allow safe grouping.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records
|
||||
-- platform.AgentGraph — Agent graph metadata (for name)
|
||||
-- platform.LibraryAgent — To flag possibly-AI (safe-mode) agents
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- id TEXT Execution UUID
|
||||
-- agentGraphId TEXT Agent graph UUID
|
||||
-- agentGraphVersion INT Graph version number
|
||||
-- executionStatus TEXT COMPLETED | FAILED | NO_CREDITS | RUNNING | QUEUED | TERMINATED
|
||||
-- createdAt TIMESTAMPTZ When the execution was queued
|
||||
-- updatedAt TIMESTAMPTZ Last status update time
|
||||
-- userId TEXT Owner user UUID
|
||||
-- agentGraphName TEXT Human-readable agent name
|
||||
-- cputime DECIMAL Total CPU seconds consumed
|
||||
-- walltime DECIMAL Total wall-clock seconds
|
||||
-- node_count DECIMAL Number of nodes in the graph
|
||||
-- nodes_cputime DECIMAL CPU time across all nodes
|
||||
-- nodes_walltime DECIMAL Wall time across all nodes
|
||||
-- execution_cost DECIMAL Credit cost of this execution
|
||||
-- correctness_score FLOAT AI correctness score (if available)
|
||||
-- possibly_ai BOOLEAN True if agent has sensitive_action_safe_mode enabled
|
||||
-- groupedErrorMessage TEXT Scrubbed error string (IDs/URLs replaced with wildcards)
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Daily execution counts by status
|
||||
-- SELECT DATE_TRUNC('day', "createdAt") AS day, "executionStatus", COUNT(*)
|
||||
-- FROM analytics.graph_execution
|
||||
-- GROUP BY 1, 2 ORDER BY 1;
|
||||
--
|
||||
-- -- Average cost per execution by agent
|
||||
-- SELECT "agentGraphName", AVG("execution_cost") AS avg_cost, COUNT(*) AS runs
|
||||
-- FROM analytics.graph_execution
|
||||
-- WHERE "executionStatus" = 'COMPLETED'
|
||||
-- GROUP BY 1 ORDER BY avg_cost DESC;
|
||||
--
|
||||
-- -- Top error messages
|
||||
-- SELECT "groupedErrorMessage", COUNT(*) AS occurrences
|
||||
-- FROM analytics.graph_execution
|
||||
-- WHERE "executionStatus" = 'FAILED'
|
||||
-- GROUP BY 1 ORDER BY 2 DESC LIMIT 20;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
ge."id" AS id,
|
||||
ge."agentGraphId" AS agentGraphId,
|
||||
ge."agentGraphVersion" AS agentGraphVersion,
|
||||
CASE
|
||||
WHEN jsonb_exists(ge."stats"::jsonb, 'error')
|
||||
AND (
|
||||
(ge."stats"::jsonb->>'error') ILIKE '%insufficient balance%'
|
||||
OR (ge."stats"::jsonb->>'error') ILIKE '%you have no credits left%'
|
||||
)
|
||||
THEN 'NO_CREDITS'
|
||||
ELSE CAST(ge."executionStatus" AS TEXT)
|
||||
END AS executionStatus,
|
||||
ge."createdAt" AS createdAt,
|
||||
ge."updatedAt" AS updatedAt,
|
||||
ge."userId" AS userId,
|
||||
g."name" AS agentGraphName,
|
||||
(ge."stats"::jsonb->>'cputime')::decimal AS cputime,
|
||||
(ge."stats"::jsonb->>'walltime')::decimal AS walltime,
|
||||
(ge."stats"::jsonb->>'node_count')::decimal AS node_count,
|
||||
(ge."stats"::jsonb->>'nodes_cputime')::decimal AS nodes_cputime,
|
||||
(ge."stats"::jsonb->>'nodes_walltime')::decimal AS nodes_walltime,
|
||||
(ge."stats"::jsonb->>'cost')::decimal AS execution_cost,
|
||||
(ge."stats"::jsonb->>'correctness_score')::float AS correctness_score,
|
||||
COALESCE(la.possibly_ai, FALSE) AS possibly_ai,
|
||||
REGEXP_REPLACE(
|
||||
REGEXP_REPLACE(
|
||||
TRIM(BOTH '"' FROM ge."stats"::jsonb->>'error'),
|
||||
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
|
||||
'\1\2/...', 'gi'
|
||||
),
|
||||
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
|
||||
) AS groupedErrorMessage
|
||||
FROM platform."AgentGraphExecution" ge
|
||||
LEFT JOIN platform."AgentGraph" g
|
||||
ON ge."agentGraphId" = g."id"
|
||||
AND ge."agentGraphVersion" = g."version"
|
||||
LEFT JOIN (
|
||||
SELECT DISTINCT ON ("userId", "agentGraphId")
|
||||
"userId", "agentGraphId",
|
||||
("settings"::jsonb->>'sensitive_action_safe_mode')::boolean AS possibly_ai
|
||||
FROM platform."LibraryAgent"
|
||||
WHERE "isDeleted" = FALSE
|
||||
AND "isArchived" = FALSE
|
||||
ORDER BY "userId", "agentGraphId", "agentGraphVersion" DESC
|
||||
) la ON la."userId" = ge."userId" AND la."agentGraphId" = ge."agentGraphId"
|
||||
WHERE ge."createdAt" > CURRENT_DATE - INTERVAL '90 days'
|
||||
@@ -1,101 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.node_block_execution
|
||||
-- Looker source alias: ds14 | Charts: 11
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per node (block) execution (last 90 days).
|
||||
-- Unpacks stats JSONB and joins to identify which block type
|
||||
-- was run. For failed nodes, joins the error output and
|
||||
-- scrubs it for safe grouping.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentNodeExecution — Node execution records
|
||||
-- platform.AgentNode — Node → block mapping
|
||||
-- platform.AgentBlock — Block name/ID
|
||||
-- platform.AgentNodeExecutionInputOutput — Error output values
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- id TEXT Node execution UUID
|
||||
-- agentGraphExecutionId TEXT Parent graph execution UUID
|
||||
-- agentNodeId TEXT Node UUID within the graph
|
||||
-- executionStatus TEXT COMPLETED | FAILED | QUEUED | RUNNING | TERMINATED
|
||||
-- addedTime TIMESTAMPTZ When the node was queued
|
||||
-- queuedTime TIMESTAMPTZ When it entered the queue
|
||||
-- startedTime TIMESTAMPTZ When execution started
|
||||
-- endedTime TIMESTAMPTZ When execution finished
|
||||
-- inputSize BIGINT Input payload size in bytes
|
||||
-- outputSize BIGINT Output payload size in bytes
|
||||
-- walltime NUMERIC Wall-clock seconds for this node
|
||||
-- cputime NUMERIC CPU seconds for this node
|
||||
-- llmRetryCount INT Number of LLM retries
|
||||
-- llmCallCount INT Number of LLM API calls made
|
||||
-- inputTokenCount BIGINT LLM input tokens consumed
|
||||
-- outputTokenCount BIGINT LLM output tokens produced
|
||||
-- blockName TEXT Human-readable block name (e.g. 'OpenAIBlock')
|
||||
-- blockId TEXT Block UUID
|
||||
-- groupedErrorMessage TEXT Scrubbed error (IDs/URLs wildcarded)
|
||||
-- errorMessage TEXT Raw error output (only set when FAILED)
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days (addedTime > CURRENT_DATE - 90 days)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Most-used blocks by execution count
|
||||
-- SELECT "blockName", COUNT(*) AS executions,
|
||||
-- COUNT(*) FILTER (WHERE "executionStatus"='FAILED') AS failures
|
||||
-- FROM analytics.node_block_execution
|
||||
-- GROUP BY 1 ORDER BY executions DESC LIMIT 20;
|
||||
--
|
||||
-- -- Average LLM token usage per block
|
||||
-- SELECT "blockName",
|
||||
-- AVG("inputTokenCount") AS avg_input_tokens,
|
||||
-- AVG("outputTokenCount") AS avg_output_tokens
|
||||
-- FROM analytics.node_block_execution
|
||||
-- WHERE "llmCallCount" > 0
|
||||
-- GROUP BY 1 ORDER BY avg_input_tokens DESC;
|
||||
--
|
||||
-- -- Top failure reasons
|
||||
-- SELECT "blockName", "groupedErrorMessage", COUNT(*) AS count
|
||||
-- FROM analytics.node_block_execution
|
||||
-- WHERE "executionStatus" = 'FAILED'
|
||||
-- GROUP BY 1, 2 ORDER BY count DESC LIMIT 20;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
ne."id" AS id,
|
||||
ne."agentGraphExecutionId" AS agentGraphExecutionId,
|
||||
ne."agentNodeId" AS agentNodeId,
|
||||
CAST(ne."executionStatus" AS TEXT) AS executionStatus,
|
||||
ne."addedTime" AS addedTime,
|
||||
ne."queuedTime" AS queuedTime,
|
||||
ne."startedTime" AS startedTime,
|
||||
ne."endedTime" AS endedTime,
|
||||
(ne."stats"::jsonb->>'input_size')::bigint AS inputSize,
|
||||
(ne."stats"::jsonb->>'output_size')::bigint AS outputSize,
|
||||
(ne."stats"::jsonb->>'walltime')::numeric AS walltime,
|
||||
(ne."stats"::jsonb->>'cputime')::numeric AS cputime,
|
||||
(ne."stats"::jsonb->>'llm_retry_count')::int AS llmRetryCount,
|
||||
(ne."stats"::jsonb->>'llm_call_count')::int AS llmCallCount,
|
||||
(ne."stats"::jsonb->>'input_token_count')::bigint AS inputTokenCount,
|
||||
(ne."stats"::jsonb->>'output_token_count')::bigint AS outputTokenCount,
|
||||
b."name" AS blockName,
|
||||
b."id" AS blockId,
|
||||
REGEXP_REPLACE(
|
||||
REGEXP_REPLACE(
|
||||
TRIM(BOTH '"' FROM eio."data"::text),
|
||||
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
|
||||
'\1\2/...', 'gi'
|
||||
),
|
||||
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
|
||||
) AS groupedErrorMessage,
|
||||
eio."data" AS errorMessage
|
||||
FROM platform."AgentNodeExecution" ne
|
||||
LEFT JOIN platform."AgentNode" nd
|
||||
ON ne."agentNodeId" = nd."id"
|
||||
LEFT JOIN platform."AgentBlock" b
|
||||
ON nd."agentBlockId" = b."id"
|
||||
LEFT JOIN platform."AgentNodeExecutionInputOutput" eio
|
||||
ON eio."referencedByOutputExecId" = ne."id"
|
||||
AND eio."name" = 'error'
|
||||
AND ne."executionStatus" = 'FAILED'
|
||||
WHERE ne."addedTime" > CURRENT_DATE - INTERVAL '90 days'
|
||||
@@ -1,97 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_agent
|
||||
-- Looker source alias: ds35 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention broken down per individual agent.
|
||||
-- Cohort = week of a user's first use of THAT specific agent.
|
||||
-- Tells you which agents keep users coming back vs. one-shot
|
||||
-- use. Only includes cohorts from the last 180 days.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records (user × agent × time)
|
||||
-- platform.AgentGraph — Agent names
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- agent_id TEXT Agent graph UUID
|
||||
-- agent_label TEXT 'AgentName [first8chars]'
|
||||
-- agent_label_n TEXT 'AgentName [first8chars] (n=total_users)'
|
||||
-- cohort_week_start DATE Week users first ran this agent
|
||||
-- cohort_label TEXT ISO week label
|
||||
-- cohort_label_n TEXT ISO week label with cohort size
|
||||
-- user_lifetime_week INT Weeks since first use of this agent
|
||||
-- cohort_users BIGINT Users in this cohort for this agent
|
||||
-- active_users BIGINT Users who ran the agent again in week k
|
||||
-- retention_rate FLOAT active_users / cohort_users
|
||||
-- cohort_users_w0 BIGINT cohort_users only at week 0 (safe to SUM)
|
||||
-- agent_total_users BIGINT Total users across all cohorts for this agent
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Best-retained agents at week 2
|
||||
-- SELECT agent_label, AVG(retention_rate) AS w2_retention
|
||||
-- FROM analytics.retention_agent
|
||||
-- WHERE user_lifetime_week = 2 AND cohort_users >= 10
|
||||
-- GROUP BY 1 ORDER BY w2_retention DESC LIMIT 10;
|
||||
--
|
||||
-- -- Agents with most unique users
|
||||
-- SELECT DISTINCT agent_label, agent_total_users
|
||||
-- FROM analytics.retention_agent
|
||||
-- ORDER BY agent_total_users DESC LIMIT 20;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
|
||||
events AS (
|
||||
SELECT e."userId"::text AS user_id, e."agentGraphId" AS agent_id,
|
||||
e."createdAt"::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', e."createdAt")::date AS week_start
|
||||
FROM platform."AgentGraphExecution" e
|
||||
),
|
||||
first_use AS (
|
||||
SELECT user_id, agent_id, MIN(created_at) AS first_use_at,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1,2
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, agent_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, aw.agent_id, fu.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week',fu.first_use_at)::date)/7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_use fu USING (user_id, agent_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week',fu.first_use_at)::date
|
||||
),
|
||||
active_counts AS (
|
||||
SELECT agent_id, cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2,3
|
||||
),
|
||||
cohort_sizes AS (
|
||||
SELECT agent_id, cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_use GROUP BY 1,2
|
||||
),
|
||||
cohort_caps AS (
|
||||
SELECT cs.agent_id, cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.agent_id, cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
),
|
||||
agent_names AS (SELECT DISTINCT ON (g."id") g."id" AS agent_id, g."name" AS agent_name FROM platform."AgentGraph" g ORDER BY g."id", g."version" DESC),
|
||||
agent_total_users AS (SELECT agent_id, SUM(cohort_users) AS agent_total_users FROM cohort_sizes GROUP BY 1)
|
||||
SELECT
|
||||
g.agent_id,
|
||||
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||']' AS agent_label,
|
||||
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||'] (n='||COALESCE(atu.agent_total_users,0)||')' AS agent_label_n,
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(ac.active_users,0) AS active_users,
|
||||
COALESCE(ac.active_users,0)::float / NULLIF(g.cohort_users,0) AS retention_rate,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0,
|
||||
COALESCE(atu.agent_total_users,0) AS agent_total_users
|
||||
FROM grid g
|
||||
LEFT JOIN active_counts ac ON ac.agent_id=g.agent_id AND ac.cohort_week_start=g.cohort_week_start AND ac.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN agent_names an ON an.agent_id=g.agent_id
|
||||
LEFT JOIN agent_total_users atu ON atu.agent_id=g.agent_id
|
||||
ORDER BY agent_label, g.cohort_week_start, g.user_lifetime_week;
|
||||
@@ -1,81 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_execution_daily
|
||||
-- Looker source alias: ds111 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Daily cohort retention based on agent executions.
|
||||
-- Cohort anchor = day of user's FIRST ever execution.
|
||||
-- Only includes cohorts from the last 90 days, up to day 30.
|
||||
-- Great for early engagement analysis (did users run another
|
||||
-- agent the next day?).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- Same pattern as retention_login_daily.
|
||||
-- cohort_day_start = day of first execution (not first login)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Day-3 execution retention
|
||||
-- SELECT cohort_label, retention_rate_bounded AS d3_retention
|
||||
-- FROM analytics.retention_execution_daily
|
||||
-- WHERE user_lifetime_day = 3 ORDER BY cohort_day_start;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days') AS cohort_start),
|
||||
events AS (
|
||||
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
|
||||
DATE_TRUNC('day', e."createdAt")::date AS day_start
|
||||
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
|
||||
),
|
||||
first_exec AS (
|
||||
SELECT user_id, MIN(created_at) AS first_exec_at,
|
||||
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
|
||||
FROM events GROUP BY 1
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
|
||||
user_day_age AS (
|
||||
SELECT ad.user_id, fe.cohort_day_start,
|
||||
(ad.day_start - DATE_TRUNC('day',fe.first_exec_at)::date)::int AS user_lifetime_day
|
||||
FROM activity_days ad JOIN first_exec fe USING (user_id)
|
||||
WHERE ad.day_start >= DATE_TRUNC('day',fe.first_exec_at)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_day_start, cs.cohort_users,
|
||||
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_day_start,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_day, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
|
||||
ORDER BY g.cohort_day_start, g.user_lifetime_day;
|
||||
@@ -1,81 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_execution_weekly
|
||||
-- Looker source alias: ds92 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention based on agent executions.
|
||||
-- Cohort anchor = week of user's FIRST ever agent execution
|
||||
-- (not first login). Only includes cohorts from the last 180 days.
|
||||
-- Useful when you care about product engagement, not just visits.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- Same pattern as retention_login_weekly.
|
||||
-- cohort_week_start = week of first execution (not first login)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Week-2 execution retention
|
||||
-- SELECT cohort_label, retention_rate_bounded
|
||||
-- FROM analytics.retention_execution_weekly
|
||||
-- WHERE user_lifetime_week = 2 ORDER BY cohort_week_start;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
|
||||
events AS (
|
||||
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', e."createdAt")::date AS week_start
|
||||
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
|
||||
),
|
||||
first_exec AS (
|
||||
SELECT user_id, MIN(created_at) AS first_exec_at,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, fe.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week',fe.first_exec_at)::date)/7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_exec fe USING (user_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week',fe.first_exec_at)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||
ORDER BY g.cohort_week_start, g.user_lifetime_week;
|
||||
@@ -1,94 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_login_daily
|
||||
-- Looker source alias: ds112 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Daily cohort retention based on login sessions.
|
||||
-- Same logic as retention_login_weekly but at day granularity,
|
||||
-- showing up to day 30 for cohorts from the last 90 days.
|
||||
-- Useful for analysing early activation (days 1-7) in detail.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login session records
|
||||
--
|
||||
-- OUTPUT COLUMNS (same pattern as retention_login_weekly)
|
||||
-- cohort_day_start DATE First day the cohort logged in
|
||||
-- cohort_label TEXT Date string (e.g. '2025-03-01')
|
||||
-- cohort_label_n TEXT Date + cohort size (e.g. '2025-03-01 (n=12)')
|
||||
-- user_lifetime_day INT Days since first login (0 = signup day)
|
||||
-- cohort_users BIGINT Total users in cohort
|
||||
-- active_users_bounded BIGINT Users active on exactly day k
|
||||
-- retained_users_unbounded BIGINT Users active any time on/after day k
|
||||
-- retention_rate_bounded FLOAT bounded / cohort_users
|
||||
-- retention_rate_unbounded FLOAT unbounded / cohort_users
|
||||
-- cohort_users_d0 BIGINT cohort_users only at day 0, else 0 (safe to SUM)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Day-1 retention rate (came back next day)
|
||||
-- SELECT cohort_label, retention_rate_bounded AS d1_retention
|
||||
-- FROM analytics.retention_login_daily
|
||||
-- WHERE user_lifetime_day = 1 ORDER BY cohort_day_start;
|
||||
--
|
||||
-- -- Average retention curve across all cohorts
|
||||
-- SELECT user_lifetime_day,
|
||||
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_d0), 0) AS avg_retention
|
||||
-- FROM analytics.retention_login_daily
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days')::date AS cohort_start),
|
||||
events AS (
|
||||
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||
DATE_TRUNC('day', s.created_at)::date AS day_start
|
||||
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||
),
|
||||
first_login AS (
|
||||
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
|
||||
FROM events GROUP BY 1
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
|
||||
user_day_age AS (
|
||||
SELECT ad.user_id, fl.cohort_day_start,
|
||||
(ad.day_start - DATE_TRUNC('day', fl.first_login_time)::date)::int AS user_lifetime_day
|
||||
FROM activity_days ad JOIN first_login fl USING (user_id)
|
||||
WHERE ad.day_start >= DATE_TRUNC('day', fl.first_login_time)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_day_start, cs.cohort_users,
|
||||
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_day_start,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_day, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
|
||||
ORDER BY g.cohort_day_start, g.user_lifetime_day;
|
||||
@@ -1,96 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_login_onboarded_weekly
|
||||
-- Looker source alias: ds101 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention from login sessions, restricted to
|
||||
-- users who "onboarded" — defined as running at least one
|
||||
-- agent within 365 days of their first login.
|
||||
-- Filters out users who signed up but never activated,
|
||||
-- giving a cleaner view of engaged-user retention.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login session records
|
||||
-- platform.AgentGraphExecution — Used to identify onboarders
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- Same as retention_login_weekly (cohort_week_start, user_lifetime_week,
|
||||
-- retention_rate_bounded, retention_rate_unbounded, etc.)
|
||||
-- Only difference: cohort is filtered to onboarded users only.
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Compare week-4 retention: all users vs onboarded only
|
||||
-- SELECT 'all_users' AS segment, AVG(retention_rate_bounded) AS w4_retention
|
||||
-- FROM analytics.retention_login_weekly WHERE user_lifetime_week = 4
|
||||
-- UNION ALL
|
||||
-- SELECT 'onboarded', AVG(retention_rate_bounded)
|
||||
-- FROM analytics.retention_login_onboarded_weekly WHERE user_lifetime_week = 4;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks, 365::int AS onboarding_window_days),
|
||||
events AS (
|
||||
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', s.created_at)::date AS week_start
|
||||
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||
),
|
||||
first_login_all AS (
|
||||
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1
|
||||
),
|
||||
onboarders AS (
|
||||
SELECT fl.user_id FROM first_login_all fl
|
||||
WHERE EXISTS (
|
||||
SELECT 1 FROM platform."AgentGraphExecution" e
|
||||
WHERE e."userId"::text = fl.user_id
|
||||
AND e."createdAt" >= fl.first_login_time
|
||||
AND e."createdAt" < fl.first_login_time
|
||||
+ make_interval(days => (SELECT onboarding_window_days FROM params))
|
||||
)
|
||||
),
|
||||
first_login AS (SELECT * FROM first_login_all WHERE user_id IN (SELECT user_id FROM onboarders)),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, fl.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week',fl.first_login_time)::date)/7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_login fl USING (user_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week',fl.first_login_time)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||
ORDER BY g.cohort_week_start, g.user_lifetime_week;
|
||||
@@ -1,103 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_login_weekly
|
||||
-- Looker source alias: ds83 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention based on login sessions.
|
||||
-- Users are grouped by the ISO week of their first ever login.
|
||||
-- For each cohort × lifetime-week combination, outputs both:
|
||||
-- - bounded rate: % active in exactly that week
|
||||
-- - unbounded rate: % who were ever active on or after that week
|
||||
-- Weeks are capped to the cohort's actual age (no future data points).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login session records
|
||||
--
|
||||
-- HOW TO READ THE OUTPUT
|
||||
-- cohort_week_start The Monday of the week users first logged in
|
||||
-- user_lifetime_week 0 = signup week, 1 = one week later, etc.
|
||||
-- retention_rate_bounded = active_users_bounded / cohort_users
|
||||
-- retention_rate_unbounded = retained_users_unbounded / cohort_users
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- cohort_week_start DATE First day of the cohort's signup week
|
||||
-- cohort_label TEXT ISO week label (e.g. '2025-W01')
|
||||
-- cohort_label_n TEXT ISO week label with cohort size (e.g. '2025-W01 (n=42)')
|
||||
-- user_lifetime_week INT Weeks since first login (0 = signup week)
|
||||
-- cohort_users BIGINT Total users in this cohort (denominator)
|
||||
-- active_users_bounded BIGINT Users active in exactly week k
|
||||
-- retained_users_unbounded BIGINT Users active any time on/after week k
|
||||
-- retention_rate_bounded FLOAT bounded active / cohort_users
|
||||
-- retention_rate_unbounded FLOAT unbounded retained / cohort_users
|
||||
-- cohort_users_w0 BIGINT cohort_users only at week 0, else 0 (safe to SUM in pivot tables)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Week-1 retention rate per cohort
|
||||
-- SELECT cohort_label, retention_rate_bounded AS w1_retention
|
||||
-- FROM analytics.retention_login_weekly
|
||||
-- WHERE user_lifetime_week = 1
|
||||
-- ORDER BY cohort_week_start;
|
||||
--
|
||||
-- -- Overall average retention curve (all cohorts combined)
|
||||
-- SELECT user_lifetime_week,
|
||||
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_w0), 0) AS avg_retention
|
||||
-- FROM analytics.retention_login_weekly
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks),
|
||||
events AS (
|
||||
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', s.created_at)::date AS week_start
|
||||
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||
),
|
||||
first_login AS (
|
||||
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1
|
||||
),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, fl.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week', fl.first_login_time)::date) / 7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_login fl USING (user_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week', fl.first_login_time)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date - cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||
ORDER BY g.cohort_week_start, g.user_lifetime_week
|
||||
@@ -1,71 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_block_spending
|
||||
-- Looker source alias: ds6 | Charts: 5
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per credit transaction (last 90 days).
|
||||
-- Shows how users spend credits broken down by block type,
|
||||
-- LLM provider and model. Joins node execution stats for
|
||||
-- token-level detail.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.CreditTransaction — Credit debit/credit records
|
||||
-- platform.AgentNodeExecution — Node execution stats (for token counts)
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- transactionKey TEXT Unique transaction identifier
|
||||
-- userId TEXT User who was charged
|
||||
-- amount DECIMAL Credit amount (positive = credit, negative = debit)
|
||||
-- negativeAmount DECIMAL amount * -1 (convenience for spend charts)
|
||||
-- transactionType TEXT Transaction type (e.g. 'USAGE', 'REFUND', 'TOP_UP')
|
||||
-- transactionTime TIMESTAMPTZ When the transaction was recorded
|
||||
-- blockId TEXT Block UUID that triggered the spend
|
||||
-- blockName TEXT Human-readable block name
|
||||
-- llm_provider TEXT LLM provider (e.g. 'openai', 'anthropic')
|
||||
-- llm_model TEXT Model name (e.g. 'gpt-4o', 'claude-3-5-sonnet')
|
||||
-- node_exec_id TEXT Linked node execution UUID
|
||||
-- llm_call_count INT LLM API calls made in that execution
|
||||
-- llm_retry_count INT LLM retries in that execution
|
||||
-- llm_input_token_count INT Input tokens consumed
|
||||
-- llm_output_token_count INT Output tokens produced
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Total spend per user (last 90 days)
|
||||
-- SELECT "userId", SUM("negativeAmount") AS total_spent
|
||||
-- FROM analytics.user_block_spending
|
||||
-- WHERE "transactionType" = 'USAGE'
|
||||
-- GROUP BY 1 ORDER BY total_spent DESC;
|
||||
--
|
||||
-- -- Spend by LLM provider + model
|
||||
-- SELECT "llm_provider", "llm_model",
|
||||
-- SUM("negativeAmount") AS total_cost,
|
||||
-- SUM("llm_input_token_count") AS input_tokens,
|
||||
-- SUM("llm_output_token_count") AS output_tokens
|
||||
-- FROM analytics.user_block_spending
|
||||
-- WHERE "llm_provider" IS NOT NULL
|
||||
-- GROUP BY 1, 2 ORDER BY total_cost DESC;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
c."transactionKey" AS transactionKey,
|
||||
c."userId" AS userId,
|
||||
c."amount" AS amount,
|
||||
c."amount" * -1 AS negativeAmount,
|
||||
c."type" AS transactionType,
|
||||
c."createdAt" AS transactionTime,
|
||||
c.metadata->>'block_id' AS blockId,
|
||||
c.metadata->>'block' AS blockName,
|
||||
c.metadata->'input'->'credentials'->>'provider' AS llm_provider,
|
||||
c.metadata->'input'->>'model' AS llm_model,
|
||||
c.metadata->>'node_exec_id' AS node_exec_id,
|
||||
(ne."stats"->>'llm_call_count')::int AS llm_call_count,
|
||||
(ne."stats"->>'llm_retry_count')::int AS llm_retry_count,
|
||||
(ne."stats"->>'input_token_count')::int AS llm_input_token_count,
|
||||
(ne."stats"->>'output_token_count')::int AS llm_output_token_count
|
||||
FROM platform."CreditTransaction" c
|
||||
LEFT JOIN platform."AgentNodeExecution" ne
|
||||
ON (c.metadata->>'node_exec_id') = ne."id"::text
|
||||
WHERE c."createdAt" > CURRENT_DATE - INTERVAL '90 days'
|
||||
@@ -1,45 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_onboarding
|
||||
-- Looker source alias: ds68 | Charts: 3
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per user onboarding record. Contains the user's
|
||||
-- stated usage reason, selected integrations, completed
|
||||
-- onboarding steps and optional first agent selection.
|
||||
-- Full history (no date filter) since onboarding happens
|
||||
-- once per user.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.UserOnboarding — Onboarding state per user
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- id TEXT Onboarding record UUID
|
||||
-- createdAt TIMESTAMPTZ When onboarding started
|
||||
-- updatedAt TIMESTAMPTZ Last update to onboarding state
|
||||
-- usageReason TEXT Why user signed up (e.g. 'work', 'personal')
|
||||
-- integrations TEXT[] Array of integration names the user selected
|
||||
-- userId TEXT User UUID
|
||||
-- completedSteps TEXT[] Array of onboarding step enums completed
|
||||
-- selectedStoreListingVersionId TEXT First marketplace agent the user chose (if any)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Usage reason breakdown
|
||||
-- SELECT "usageReason", COUNT(*) FROM analytics.user_onboarding GROUP BY 1;
|
||||
--
|
||||
-- -- Completion rate per step
|
||||
-- SELECT step, COUNT(*) AS users_completed
|
||||
-- FROM analytics.user_onboarding
|
||||
-- CROSS JOIN LATERAL UNNEST("completedSteps") AS step
|
||||
-- GROUP BY 1 ORDER BY users_completed DESC;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
id,
|
||||
"createdAt",
|
||||
"updatedAt",
|
||||
"usageReason",
|
||||
integrations,
|
||||
"userId",
|
||||
"completedSteps",
|
||||
"selectedStoreListingVersionId"
|
||||
FROM platform."UserOnboarding"
|
||||
@@ -1,100 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_onboarding_funnel
|
||||
-- Looker source alias: ds74 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Pre-aggregated onboarding funnel showing how many users
|
||||
-- completed each step and the drop-off percentage from the
|
||||
-- previous step. One row per onboarding step (all 22 steps
|
||||
-- always present, even with 0 completions — prevents sparse
|
||||
-- gaps from making LAG compare the wrong predecessors).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.UserOnboarding — Onboarding records with completedSteps array
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- step TEXT Onboarding step enum name (e.g. 'WELCOME', 'CONGRATS')
|
||||
-- step_order INT Numeric position in the funnel (1=first, 22=last)
|
||||
-- users_completed BIGINT Distinct users who completed this step
|
||||
-- pct_from_prev NUMERIC % of users from the previous step who reached this one
|
||||
--
|
||||
-- STEP ORDER
|
||||
-- 1 WELCOME 9 MARKETPLACE_VISIT 17 SCHEDULE_AGENT
|
||||
-- 2 USAGE_REASON 10 MARKETPLACE_ADD_AGENT 18 RUN_AGENTS
|
||||
-- 3 INTEGRATIONS 11 MARKETPLACE_RUN_AGENT 19 RUN_3_DAYS
|
||||
-- 4 AGENT_CHOICE 12 BUILDER_OPEN 20 TRIGGER_WEBHOOK
|
||||
-- 5 AGENT_NEW_RUN 13 BUILDER_SAVE_AGENT 21 RUN_14_DAYS
|
||||
-- 6 AGENT_INPUT 14 BUILDER_RUN_AGENT 22 RUN_AGENTS_100
|
||||
-- 7 CONGRATS 15 VISIT_COPILOT
|
||||
-- 8 GET_RESULTS 16 RE_RUN_AGENT
|
||||
--
|
||||
-- WINDOW
|
||||
-- Users who started onboarding in the last 90 days
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Full funnel
|
||||
-- SELECT * FROM analytics.user_onboarding_funnel ORDER BY step_order;
|
||||
--
|
||||
-- -- Biggest drop-off point
|
||||
-- SELECT step, pct_from_prev FROM analytics.user_onboarding_funnel
|
||||
-- ORDER BY pct_from_prev ASC LIMIT 3;
|
||||
-- =============================================================
|
||||
|
||||
WITH all_steps AS (
|
||||
-- Complete ordered grid of all 22 steps so zero-completion steps
|
||||
-- are always present, keeping LAG comparisons correct.
|
||||
SELECT step_name, step_order
|
||||
FROM (VALUES
|
||||
('WELCOME', 1),
|
||||
('USAGE_REASON', 2),
|
||||
('INTEGRATIONS', 3),
|
||||
('AGENT_CHOICE', 4),
|
||||
('AGENT_NEW_RUN', 5),
|
||||
('AGENT_INPUT', 6),
|
||||
('CONGRATS', 7),
|
||||
('GET_RESULTS', 8),
|
||||
('MARKETPLACE_VISIT', 9),
|
||||
('MARKETPLACE_ADD_AGENT', 10),
|
||||
('MARKETPLACE_RUN_AGENT', 11),
|
||||
('BUILDER_OPEN', 12),
|
||||
('BUILDER_SAVE_AGENT', 13),
|
||||
('BUILDER_RUN_AGENT', 14),
|
||||
('VISIT_COPILOT', 15),
|
||||
('RE_RUN_AGENT', 16),
|
||||
('SCHEDULE_AGENT', 17),
|
||||
('RUN_AGENTS', 18),
|
||||
('RUN_3_DAYS', 19),
|
||||
('TRIGGER_WEBHOOK', 20),
|
||||
('RUN_14_DAYS', 21),
|
||||
('RUN_AGENTS_100', 22)
|
||||
) AS t(step_name, step_order)
|
||||
),
|
||||
raw AS (
|
||||
SELECT
|
||||
u."userId",
|
||||
step_txt::text AS step
|
||||
FROM platform."UserOnboarding" u
|
||||
CROSS JOIN LATERAL UNNEST(u."completedSteps") AS step_txt
|
||||
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
|
||||
),
|
||||
step_counts AS (
|
||||
SELECT step, COUNT(DISTINCT "userId") AS users_completed
|
||||
FROM raw GROUP BY step
|
||||
),
|
||||
funnel AS (
|
||||
SELECT
|
||||
a.step_name AS step,
|
||||
a.step_order,
|
||||
COALESCE(sc.users_completed, 0) AS users_completed,
|
||||
ROUND(
|
||||
100.0 * COALESCE(sc.users_completed, 0)
|
||||
/ NULLIF(
|
||||
LAG(COALESCE(sc.users_completed, 0)) OVER (ORDER BY a.step_order),
|
||||
0
|
||||
),
|
||||
2
|
||||
) AS pct_from_prev
|
||||
FROM all_steps a
|
||||
LEFT JOIN step_counts sc ON sc.step = a.step_name
|
||||
)
|
||||
SELECT * FROM funnel ORDER BY step_order
|
||||
@@ -1,41 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_onboarding_integration
|
||||
-- Looker source alias: ds75 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Pre-aggregated count of users who selected each integration
|
||||
-- during onboarding. One row per integration type, sorted
|
||||
-- by popularity.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.UserOnboarding — integrations array column
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- integration TEXT Integration name (e.g. 'github', 'slack', 'notion')
|
||||
-- users_with_integration BIGINT Distinct users who selected this integration
|
||||
--
|
||||
-- WINDOW
|
||||
-- Users who started onboarding in the last 90 days
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Full integration popularity ranking
|
||||
-- SELECT * FROM analytics.user_onboarding_integration;
|
||||
--
|
||||
-- -- Top 5 integrations
|
||||
-- SELECT * FROM analytics.user_onboarding_integration LIMIT 5;
|
||||
-- =============================================================
|
||||
|
||||
WITH exploded AS (
|
||||
SELECT
|
||||
u."userId" AS user_id,
|
||||
UNNEST(u."integrations") AS integration
|
||||
FROM platform."UserOnboarding" u
|
||||
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
|
||||
)
|
||||
SELECT
|
||||
integration,
|
||||
COUNT(DISTINCT user_id) AS users_with_integration
|
||||
FROM exploded
|
||||
WHERE integration IS NOT NULL AND integration <> ''
|
||||
GROUP BY integration
|
||||
ORDER BY users_with_integration DESC
|
||||
@@ -1,145 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.users_activities
|
||||
-- Looker source alias: ds56 | Charts: 5
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per user with lifetime activity summary.
|
||||
-- Joins login sessions with agent graphs, executions and
|
||||
-- node-level runs to give a full picture of how engaged
|
||||
-- each user is. Includes a convenience flag for 7-day
|
||||
-- activation (did the user return at least 7 days after
|
||||
-- their first login?).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login/session records
|
||||
-- platform.AgentGraph — Graphs (agents) built by the user
|
||||
-- platform.AgentGraphExecution — Agent run history
|
||||
-- platform.AgentNodeExecution — Individual block execution history
|
||||
--
|
||||
-- PERFORMANCE NOTE
|
||||
-- Each CTE aggregates its own table independently by userId.
|
||||
-- This avoids the fan-out that occurs when driving every join
|
||||
-- from user_logins across the two largest tables
|
||||
-- (AgentGraphExecution and AgentNodeExecution).
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- user_id TEXT Supabase user UUID
|
||||
-- first_login_time TIMESTAMPTZ First ever session created_at
|
||||
-- last_login_time TIMESTAMPTZ Most recent session created_at
|
||||
-- last_visit_time TIMESTAMPTZ Max of last refresh or login
|
||||
-- last_agent_save_time TIMESTAMPTZ Last time user saved an agent graph
|
||||
-- agent_count BIGINT Number of distinct active graphs built (0 if none)
|
||||
-- first_agent_run_time TIMESTAMPTZ First ever graph execution
|
||||
-- last_agent_run_time TIMESTAMPTZ Most recent graph execution
|
||||
-- unique_agent_runs BIGINT Distinct agent graphs ever run (0 if none)
|
||||
-- agent_runs BIGINT Total graph execution count (0 if none)
|
||||
-- node_execution_count BIGINT Total node executions across all runs
|
||||
-- node_execution_failed BIGINT Node executions with FAILED status
|
||||
-- node_execution_completed BIGINT Node executions with COMPLETED status
|
||||
-- node_execution_terminated BIGINT Node executions with TERMINATED status
|
||||
-- node_execution_queued BIGINT Node executions with QUEUED status
|
||||
-- node_execution_running BIGINT Node executions with RUNNING status
|
||||
-- is_active_after_7d INT 1=returned after day 7, 0=did not, NULL=too early to tell
|
||||
-- node_execution_incomplete BIGINT Node executions with INCOMPLETE status
|
||||
-- node_execution_review BIGINT Node executions with REVIEW status
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Users who ran at least one agent and returned after 7 days
|
||||
-- SELECT COUNT(*) FROM analytics.users_activities
|
||||
-- WHERE agent_runs > 0 AND is_active_after_7d = 1;
|
||||
--
|
||||
-- -- Top 10 most active users by agent runs
|
||||
-- SELECT user_id, agent_runs, node_execution_count
|
||||
-- FROM analytics.users_activities
|
||||
-- ORDER BY agent_runs DESC LIMIT 10;
|
||||
--
|
||||
-- -- 7-day activation rate
|
||||
-- SELECT
|
||||
-- SUM(CASE WHEN is_active_after_7d = 1 THEN 1 ELSE 0 END)::float
|
||||
-- / NULLIF(COUNT(CASE WHEN is_active_after_7d IS NOT NULL THEN 1 END), 0)
|
||||
-- AS activation_rate
|
||||
-- FROM analytics.users_activities;
|
||||
-- =============================================================
|
||||
|
||||
WITH user_logins AS (
|
||||
SELECT
|
||||
user_id::text AS user_id,
|
||||
MIN(created_at) AS first_login_time,
|
||||
MAX(created_at) AS last_login_time,
|
||||
GREATEST(
|
||||
MAX(refreshed_at)::timestamptz,
|
||||
MAX(created_at)::timestamptz
|
||||
) AS last_visit_time
|
||||
FROM auth.sessions
|
||||
GROUP BY user_id
|
||||
),
|
||||
user_agents AS (
|
||||
-- Aggregate AgentGraph directly by userId (no fan-out from user_logins)
|
||||
SELECT
|
||||
"userId"::text AS user_id,
|
||||
MAX("updatedAt") AS last_agent_save_time,
|
||||
COUNT(DISTINCT "id") AS agent_count
|
||||
FROM platform."AgentGraph"
|
||||
WHERE "isActive"
|
||||
GROUP BY "userId"
|
||||
),
|
||||
user_graph_runs AS (
|
||||
-- Aggregate AgentGraphExecution directly by userId
|
||||
SELECT
|
||||
"userId"::text AS user_id,
|
||||
MIN("createdAt") AS first_agent_run_time,
|
||||
MAX("createdAt") AS last_agent_run_time,
|
||||
COUNT(DISTINCT "agentGraphId") AS unique_agent_runs,
|
||||
COUNT("id") AS agent_runs
|
||||
FROM platform."AgentGraphExecution"
|
||||
GROUP BY "userId"
|
||||
),
|
||||
user_node_runs AS (
|
||||
-- Aggregate AgentNodeExecution directly; resolve userId via a
|
||||
-- single join to AgentGraphExecution instead of fanning out from
|
||||
-- user_logins through both large tables.
|
||||
SELECT
|
||||
g."userId"::text AS user_id,
|
||||
COUNT(*) AS node_execution_count,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'FAILED') AS node_execution_failed,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'COMPLETED') AS node_execution_completed,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'TERMINATED') AS node_execution_terminated,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'QUEUED') AS node_execution_queued,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'RUNNING') AS node_execution_running,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'INCOMPLETE') AS node_execution_incomplete,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'REVIEW') AS node_execution_review
|
||||
FROM platform."AgentNodeExecution" n
|
||||
JOIN platform."AgentGraphExecution" g
|
||||
ON g."id" = n."agentGraphExecutionId"
|
||||
GROUP BY g."userId"
|
||||
)
|
||||
SELECT
|
||||
ul.user_id,
|
||||
ul.first_login_time,
|
||||
ul.last_login_time,
|
||||
ul.last_visit_time,
|
||||
ua.last_agent_save_time,
|
||||
COALESCE(ua.agent_count, 0) AS agent_count,
|
||||
gr.first_agent_run_time,
|
||||
gr.last_agent_run_time,
|
||||
COALESCE(gr.unique_agent_runs, 0) AS unique_agent_runs,
|
||||
COALESCE(gr.agent_runs, 0) AS agent_runs,
|
||||
COALESCE(nr.node_execution_count, 0) AS node_execution_count,
|
||||
COALESCE(nr.node_execution_failed, 0) AS node_execution_failed,
|
||||
COALESCE(nr.node_execution_completed, 0) AS node_execution_completed,
|
||||
COALESCE(nr.node_execution_terminated, 0) AS node_execution_terminated,
|
||||
COALESCE(nr.node_execution_queued, 0) AS node_execution_queued,
|
||||
COALESCE(nr.node_execution_running, 0) AS node_execution_running,
|
||||
CASE
|
||||
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
|
||||
AND ul.last_visit_time >= ul.first_login_time + INTERVAL '7 days' THEN 1
|
||||
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
|
||||
AND ul.last_visit_time < ul.first_login_time + INTERVAL '7 days' THEN 0
|
||||
ELSE NULL
|
||||
END AS is_active_after_7d,
|
||||
COALESCE(nr.node_execution_incomplete, 0) AS node_execution_incomplete,
|
||||
COALESCE(nr.node_execution_review, 0) AS node_execution_review
|
||||
FROM user_logins ul
|
||||
LEFT JOIN user_agents ua ON ul.user_id = ua.user_id
|
||||
LEFT JOIN user_graph_runs gr ON ul.user_id = gr.user_id
|
||||
LEFT JOIN user_node_runs nr ON ul.user_id = nr.user_id
|
||||
@@ -37,10 +37,6 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
|
||||
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
|
||||
|
||||
## ===== SIGNUP / INVITE GATE ===== ##
|
||||
# Set to true to require an invite before users can sign up
|
||||
ENABLE_INVITE_GATE=false
|
||||
|
||||
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
|
||||
# Platform URLs (set these for webhooks and OAuth to work)
|
||||
PLATFORM_BASE_URL=http://localhost:8000
|
||||
|
||||
@@ -1,17 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
import prisma.enums
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import UserTransaction
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
|
||||
|
||||
|
||||
class UserHistoryResponse(BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
@@ -23,70 +14,3 @@ class UserHistoryResponse(BaseModel):
|
||||
class AddUserCreditsResponse(BaseModel):
|
||||
new_balance: int
|
||||
transaction_key: str
|
||||
|
||||
|
||||
class CreateInvitedUserRequest(BaseModel):
|
||||
email: EmailStr
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class InvitedUserResponse(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
status: prisma.enums.InvitedUserStatus
|
||||
auth_user_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tally_understanding: Optional[dict[str, Any]] = None
|
||||
tally_status: prisma.enums.TallyComputationStatus
|
||||
tally_computed_at: Optional[datetime] = None
|
||||
tally_error: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_record(cls, record: InvitedUserRecord) -> InvitedUserResponse:
|
||||
return cls.model_validate(record.model_dump())
|
||||
|
||||
|
||||
class InvitedUsersResponse(BaseModel):
|
||||
invited_users: list[InvitedUserResponse]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class BulkInvitedUserRowResponse(BaseModel):
|
||||
row_number: int
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
||||
message: str
|
||||
invited_user: Optional[InvitedUserResponse] = None
|
||||
|
||||
|
||||
class BulkInvitedUsersResponse(BaseModel):
|
||||
created_count: int
|
||||
skipped_count: int
|
||||
error_count: int
|
||||
results: list[BulkInvitedUserRowResponse]
|
||||
|
||||
@classmethod
|
||||
def from_result(cls, result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
|
||||
return cls(
|
||||
created_count=result.created_count,
|
||||
skipped_count=result.skipped_count,
|
||||
error_count=result.error_count,
|
||||
results=[
|
||||
BulkInvitedUserRowResponse(
|
||||
row_number=row.row_number,
|
||||
email=row.email,
|
||||
name=row.name,
|
||||
status=row.status,
|
||||
message=row.message,
|
||||
invited_user=(
|
||||
InvitedUserResponse.from_record(row.invited_user)
|
||||
if row.invited_user is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for row in result.results
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
import logging
|
||||
import math
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, File, Query, Security, UploadFile
|
||||
|
||||
from backend.data.invited_user import (
|
||||
bulk_create_invited_users_from_file,
|
||||
create_invited_user,
|
||||
list_invited_users,
|
||||
retry_invited_user_tally,
|
||||
revoke_invited_user,
|
||||
)
|
||||
from backend.data.tally import mask_email
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .model import (
|
||||
BulkInvitedUsersResponse,
|
||||
CreateInvitedUserRequest,
|
||||
InvitedUserResponse,
|
||||
InvitedUsersResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["users", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/invited-users",
|
||||
response_model=InvitedUsersResponse,
|
||||
summary="List Invited Users",
|
||||
)
|
||||
async def get_invited_users(
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
) -> InvitedUsersResponse:
|
||||
logger.info("Admin user %s requested invited users", admin_user_id)
|
||||
invited_users, total = await list_invited_users(page=page, page_size=page_size)
|
||||
return InvitedUsersResponse(
|
||||
invited_users=[InvitedUserResponse.from_record(iu) for iu in invited_users],
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=max(1, math.ceil(total / page_size)),
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Create Invited User",
|
||||
)
|
||||
async def create_invited_user_route(
|
||||
request: CreateInvitedUserRequest,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s creating invited user for %s",
|
||||
admin_user_id,
|
||||
mask_email(request.email),
|
||||
)
|
||||
invited_user = await create_invited_user(request.email, request.name)
|
||||
logger.info(
|
||||
"Admin user %s created invited user %s",
|
||||
admin_user_id,
|
||||
invited_user.id,
|
||||
)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/bulk",
|
||||
response_model=BulkInvitedUsersResponse,
|
||||
summary="Bulk Create Invited Users",
|
||||
operation_id="postV2BulkCreateInvitedUsers",
|
||||
)
|
||||
async def bulk_create_invited_users_route(
|
||||
file: UploadFile = File(...),
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> BulkInvitedUsersResponse:
|
||||
logger.info(
|
||||
"Admin user %s bulk invited users from %s",
|
||||
admin_user_id,
|
||||
file.filename or "<unnamed>",
|
||||
)
|
||||
content = await file.read()
|
||||
result = await bulk_create_invited_users_from_file(file.filename, content)
|
||||
return BulkInvitedUsersResponse.from_result(result)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/revoke",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Revoke Invited User",
|
||||
)
|
||||
async def revoke_invited_user_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s revoking invited user %s", admin_user_id, invited_user_id
|
||||
)
|
||||
invited_user = await revoke_invited_user(invited_user_id)
|
||||
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/retry-tally",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Retry Invited User Tally",
|
||||
)
|
||||
async def retry_invited_user_tally_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s retrying Tally seed for invited user %s",
|
||||
admin_user_id,
|
||||
invited_user_id,
|
||||
)
|
||||
invited_user = await retry_invited_user_tally(invited_user_id)
|
||||
logger.info(
|
||||
"Admin user %s retried Tally seed for invited user %s",
|
||||
admin_user_id,
|
||||
invited_user_id,
|
||||
)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
@@ -1,168 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import prisma.enums
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.invited_user import (
|
||||
BulkInvitedUserRowResult,
|
||||
BulkInvitedUsersResult,
|
||||
InvitedUserRecord,
|
||||
)
|
||||
|
||||
from .user_admin_routes import router as user_admin_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(user_admin_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _sample_invited_user() -> InvitedUserRecord:
|
||||
now = datetime.now(timezone.utc)
|
||||
return InvitedUserRecord(
|
||||
id="invite-1",
|
||||
email="invited@example.com",
|
||||
status=prisma.enums.InvitedUserStatus.INVITED,
|
||||
auth_user_id=None,
|
||||
name="Invited User",
|
||||
tally_understanding=None,
|
||||
tally_status=prisma.enums.TallyComputationStatus.PENDING,
|
||||
tally_computed_at=None,
|
||||
tally_error=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
|
||||
return BulkInvitedUsersResult(
|
||||
created_count=1,
|
||||
skipped_count=1,
|
||||
error_count=0,
|
||||
results=[
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=1,
|
||||
email="invited@example.com",
|
||||
name=None,
|
||||
status="CREATED",
|
||||
message="Invite created",
|
||||
invited_user=_sample_invited_user(),
|
||||
),
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=2,
|
||||
email="duplicate@example.com",
|
||||
name=None,
|
||||
status="SKIPPED",
|
||||
message="An invited user with this email already exists",
|
||||
invited_user=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_get_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.list_invited_users",
|
||||
AsyncMock(return_value=([_sample_invited_user()], 1)),
|
||||
)
|
||||
|
||||
response = client.get("/admin/invited-users")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["invited_users"]) == 1
|
||||
assert data["invited_users"][0]["email"] == "invited@example.com"
|
||||
assert data["invited_users"][0]["status"] == "INVITED"
|
||||
assert data["pagination"]["total_items"] == 1
|
||||
assert data["pagination"]["current_page"] == 1
|
||||
assert data["pagination"]["page_size"] == 50
|
||||
|
||||
|
||||
def test_create_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.create_invited_user",
|
||||
AsyncMock(return_value=_sample_invited_user()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users",
|
||||
json={"email": "invited@example.com", "name": "Invited User"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == "invited@example.com"
|
||||
assert data["name"] == "Invited User"
|
||||
|
||||
|
||||
def test_bulk_create_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
|
||||
AsyncMock(return_value=_sample_bulk_invited_users_result()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users/bulk",
|
||||
files={
|
||||
"file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["created_count"] == 1
|
||||
assert data["skipped_count"] == 1
|
||||
assert data["results"][0]["status"] == "CREATED"
|
||||
assert data["results"][1]["status"] == "SKIPPED"
|
||||
|
||||
|
||||
def test_revoke_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
revoked = _sample_invited_user().model_copy(
|
||||
update={"status": prisma.enums.InvitedUserStatus.REVOKED}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.revoke_invited_user",
|
||||
AsyncMock(return_value=revoked),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/revoke")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "REVOKED"
|
||||
|
||||
|
||||
def test_retry_invited_user_tally(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
retried = _sample_invited_user().model_copy(
|
||||
update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
|
||||
AsyncMock(return_value=retried),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/retry-tally")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["tally_status"] == "RUNNING"
|
||||
@@ -27,12 +27,6 @@ from backend.copilot.model import (
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
check_rate_limit,
|
||||
get_usage_status,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
from backend.copilot.tools.models import (
|
||||
@@ -60,7 +54,6 @@ from backend.copilot.tools.models import (
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import get_business_understanding
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
@@ -126,8 +119,6 @@ class SessionDetailResponse(BaseModel):
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
total_prompt_tokens: int = 0
|
||||
total_completion_tokens: int = 0
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
@@ -397,10 +388,6 @@ async def get_session(
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
# Sum token usage from session
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
@@ -408,26 +395,6 @@ async def get_session(
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=active_stream_info,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/usage")
|
||||
async def get_copilot_usage(
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> CoPilotUsageStatus:
|
||||
"""Get CoPilot usage status for the authenticated user.
|
||||
|
||||
Returns current token usage vs limits for daily and weekly windows.
|
||||
"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
return await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=config.daily_token_limit,
|
||||
weekly_token_limit=config.weekly_token_limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -528,17 +495,6 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Pre-turn rate limit check (token-based)
|
||||
if user_id and (config.daily_token_limit > 0 or config.weekly_token_limit > 0):
|
||||
try:
|
||||
await check_rate_limit(
|
||||
user_id=user_id,
|
||||
daily_token_limit=config.daily_token_limit,
|
||||
weekly_token_limit=config.weekly_token_limit,
|
||||
)
|
||||
except RateLimitExceeded as e:
|
||||
raise HTTPException(status_code=429, detail=str(e)) from e
|
||||
|
||||
# Enrich message with file metadata if file_ids are provided.
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
@@ -897,36 +853,6 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Suggested Prompts ==========
|
||||
|
||||
|
||||
class SuggestedPromptsResponse(BaseModel):
|
||||
"""Response model for user-specific suggested prompts."""
|
||||
|
||||
prompts: list[str]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/suggested-prompts",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def get_suggested_prompts(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> SuggestedPromptsResponse:
|
||||
"""
|
||||
Get LLM-generated suggested prompts for the authenticated user.
|
||||
|
||||
Returns personalized quick-action prompts based on the user's
|
||||
business understanding. Returns an empty list if no custom prompts
|
||||
are available.
|
||||
"""
|
||||
understanding = await get_business_understanding(user_id)
|
||||
if understanding is None:
|
||||
return SuggestedPromptsResponse(prompts=[])
|
||||
|
||||
return SuggestedPromptsResponse(prompts=understanding.suggested_prompts)
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Tests for chat API routes: session title update, file attachment validation, usage, and suggested prompts."""
|
||||
"""Tests for chat API routes: session title update and file attachment validation."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
@@ -250,130 +249,3 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
||||
assert call_kwargs["where"]["isDeleted"] is False
|
||||
|
||||
|
||||
# ─── Usage endpoint ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_usage(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
daily_used: int = 500,
|
||||
weekly_used: int = 2000,
|
||||
) -> AsyncMock:
|
||||
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
|
||||
resets_at = datetime.now(UTC) + timedelta(days=1)
|
||||
status = CoPilotUsageStatus(
|
||||
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
|
||||
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
|
||||
)
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=status,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_returns_daily_and_weekly(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""GET /usage returns daily and weekly usage."""
|
||||
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily"]["used"] == 500
|
||||
assert data["weekly"]["used"] == 2000
|
||||
|
||||
mock_get.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=10000,
|
||||
weekly_token_limit=50000,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_uses_config_limits(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
|
||||
mock_get = _mock_usage(mocker)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_get.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=99999,
|
||||
weekly_token_limit=77777,
|
||||
)
|
||||
|
||||
|
||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_get_business_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
return_value=None,
|
||||
):
|
||||
"""Mock get_business_understanding."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=return_value,
|
||||
)
|
||||
|
||||
|
||||
def test_suggested_prompts_returns_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding and prompts gets them back."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = ["Do X", "Do Y", "Do Z"]
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": ["Do X", "Do Y", "Do Z"]}
|
||||
|
||||
|
||||
def test_suggested_prompts_no_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with no understanding gets empty list."""
|
||||
_mock_get_business_understanding(mocker, return_value=None)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": []}
|
||||
|
||||
|
||||
def test_suggested_prompts_empty_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding but no prompts gets empty list."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = []
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": []}
|
||||
|
||||
@@ -55,7 +55,6 @@ from backend.data.credit import (
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.graph import GraphSettings
|
||||
from backend.data.invited_user import get_or_activate_user
|
||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.data.onboarding import (
|
||||
@@ -71,6 +70,7 @@ from backend.data.onboarding import (
|
||||
update_user_onboarding,
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_or_create_user,
|
||||
get_user_by_id,
|
||||
get_user_notification_preference,
|
||||
update_user_email,
|
||||
@@ -136,10 +136,12 @@ _tally_background_tasks: set[asyncio.Task] = set()
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
user = await get_or_activate_user(user_data)
|
||||
user = await get_or_create_user(user_data)
|
||||
|
||||
# Fire-and-forget: backfill Tally understanding when invite pre-seeding did
|
||||
# not produce a stored result before first activation.
|
||||
# Fire-and-forget: populate business understanding from Tally form.
|
||||
# We use created_at proximity instead of an is_new flag because
|
||||
# get_or_create_user is cached — a separate is_new return value would be
|
||||
# unreliable on repeated calls within the cache TTL.
|
||||
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
|
||||
if age_seconds < 30:
|
||||
try:
|
||||
@@ -163,8 +165,7 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_email_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
email: str = Body(...),
|
||||
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
|
||||
) -> dict[str, str]:
|
||||
await update_user_email(user_id, email)
|
||||
|
||||
@@ -178,16 +179,10 @@ async def update_user_email_route(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
user_data: dict = Security(get_jwt_payload),
|
||||
) -> TimezoneResponse:
|
||||
"""Get user timezone setting."""
|
||||
try:
|
||||
user = await get_user_by_id(user_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
detail="User not found. Please complete activation via /auth/user first.",
|
||||
)
|
||||
user = await get_or_create_user(user_data)
|
||||
return TimezoneResponse(timezone=user.timezone)
|
||||
|
||||
|
||||
@@ -198,8 +193,7 @@ async def get_user_timezone_route(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
request: UpdateTimezoneRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
|
||||
) -> TimezoneResponse:
|
||||
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
|
||||
user = await update_user_timezone(user_id, str(request.timezone))
|
||||
|
||||
@@ -51,7 +51,7 @@ def test_get_or_create_user_route(
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_or_activate_user",
|
||||
"backend.api.features.v1.get_or_create_user",
|
||||
return_value=mock_user,
|
||||
)
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from prisma.errors import PrismaError
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.admin.user_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
import backend.api.features.chat.routes as chat_routes
|
||||
@@ -312,11 +311,6 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/executions",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.user_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/users",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.executions.review.routes.router,
|
||||
tags=["v2", "executions", "review"],
|
||||
|
||||
@@ -18,13 +18,11 @@ from langfuse import propagate_attributes
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.rate_limit import record_token_usage
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
@@ -38,7 +36,6 @@ from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from backend.copilot.service import (
|
||||
_build_system_prompt,
|
||||
@@ -49,11 +46,7 @@ from backend.copilot.service import (
|
||||
from backend.copilot.tools import execute_tool, get_available_tools
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.prompt import (
|
||||
compress_context,
|
||||
estimate_token_count,
|
||||
estimate_token_count_str,
|
||||
)
|
||||
from backend.util.prompt import compress_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -228,9 +221,6 @@ async def stream_chat_completion_baseline(
|
||||
text_block_id = str(uuid.uuid4())
|
||||
text_started = False
|
||||
step_open = False
|
||||
# Token usage accumulators — populated from streaming chunks
|
||||
turn_prompt_tokens = 0
|
||||
turn_completion_tokens = 0
|
||||
try:
|
||||
for _round in range(_MAX_TOOL_ROUNDS):
|
||||
# Open a new step for each LLM round
|
||||
@@ -242,7 +232,6 @@ async def stream_chat_completion_baseline(
|
||||
model=config.model,
|
||||
messages=openai_messages,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
if tools:
|
||||
create_kwargs["tools"] = tools
|
||||
@@ -253,18 +242,7 @@ async def stream_chat_completion_baseline(
|
||||
tool_calls_by_index: dict[int, dict[str, str]] = {}
|
||||
|
||||
async for chunk in response:
|
||||
# Capture token usage from the streaming chunk.
|
||||
# OpenRouter normalises all providers into OpenAI format
|
||||
# where prompt_tokens already includes cached tokens
|
||||
# (unlike Anthropic's native API). Use += to sum all
|
||||
# tool-call rounds since each API call is independent.
|
||||
if chunk.usage:
|
||||
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
|
||||
turn_completion_tokens += chunk.usage.completion_tokens or 0
|
||||
|
||||
if not chunk.choices:
|
||||
continue
|
||||
delta = chunk.choices[0].delta
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
if not delta:
|
||||
continue
|
||||
|
||||
@@ -433,53 +411,6 @@ async def stream_chat_completion_baseline(
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context teardown failed")
|
||||
|
||||
# Fallback: estimate tokens via tiktoken when the provider does
|
||||
# not honour stream_options={"include_usage": True}.
|
||||
# Count the full message list (system + history + turn) since
|
||||
# each API call sends the complete context window.
|
||||
if turn_prompt_tokens == 0 and turn_completion_tokens == 0:
|
||||
turn_prompt_tokens = max(
|
||||
estimate_token_count(openai_messages, model=config.model), 0
|
||||
)
|
||||
turn_completion_tokens = max(
|
||||
estimate_token_count_str(assistant_text, model=config.model), 0
|
||||
)
|
||||
logger.info(
|
||||
"[Baseline] No streaming usage reported; estimated tokens: "
|
||||
"prompt=%d, completion=%d",
|
||||
turn_prompt_tokens,
|
||||
turn_completion_tokens,
|
||||
)
|
||||
|
||||
# Emit token usage and update session for persistence
|
||||
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
|
||||
total_tokens = turn_prompt_tokens + turn_completion_tokens
|
||||
session.usage.append(
|
||||
Usage(
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"[Baseline] Turn usage: prompt=%d, completion=%d, total=%d",
|
||||
turn_prompt_tokens,
|
||||
turn_completion_tokens,
|
||||
total_tokens,
|
||||
)
|
||||
# Record for rate limiting counters
|
||||
if user_id:
|
||||
try:
|
||||
await record_token_usage(
|
||||
user_id=user_id,
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
)
|
||||
except Exception as usage_err:
|
||||
logger.warning(
|
||||
"[Baseline] Failed to record token usage: %s", usage_err
|
||||
)
|
||||
|
||||
# Persist assistant response
|
||||
if assistant_text:
|
||||
session.messages.append(
|
||||
@@ -490,16 +421,4 @@ async def stream_chat_completion_baseline(
|
||||
except Exception as persist_err:
|
||||
logger.error("[Baseline] Failed to persist session: %s", persist_err)
|
||||
|
||||
# Yield usage and finish AFTER try/finally (not inside finally).
|
||||
# PEP 525 prohibits yielding from finally in async generators during
|
||||
# aclose() — doing so raises RuntimeError on client disconnect.
|
||||
# On GeneratorExit the client is already gone, so unreachable yields
|
||||
# are harmless; on normal completion they reach the SSE stream.
|
||||
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
|
||||
yield StreamUsage(
|
||||
promptTokens=turn_prompt_tokens,
|
||||
completionTokens=turn_completion_tokens,
|
||||
totalTokens=turn_prompt_tokens + turn_completion_tokens,
|
||||
)
|
||||
|
||||
yield StreamFinish()
|
||||
|
||||
@@ -70,20 +70,6 @@ class ChatConfig(BaseSettings):
|
||||
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
|
||||
)
|
||||
|
||||
# Rate limiting — token-based limits per day and per week.
|
||||
# Each CoPilot turn consumes ~10-15K tokens (system prompt + tool schemas + response),
|
||||
# so 2.5M daily allows ~170-250 turns/day which is reasonable for normal use.
|
||||
# TODO: These are global deploy-time constants. For per-user or per-plan limits,
|
||||
# move to the database (e.g. UserPlan table) and look up in get_usage_status.
|
||||
daily_token_limit: int = Field(
|
||||
default=2_500_000,
|
||||
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
|
||||
)
|
||||
weekly_token_limit: int = Field(
|
||||
default=12_500_000,
|
||||
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
default=True,
|
||||
@@ -129,7 +115,7 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B sandbox template to use for copilot sessions.",
|
||||
)
|
||||
e2b_sandbox_timeout: int = Field(
|
||||
default=300, # 5 min safety net — explicit per-turn pause is the primary mechanism
|
||||
default=10800, # 3 hours — wall-clock timeout, not idle; explicit pause is primary
|
||||
description="E2B sandbox running-time timeout (seconds). "
|
||||
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
|
||||
"mechanism; this is the safety net.",
|
||||
|
||||
@@ -73,9 +73,6 @@ class Usage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
# Cache breakdown (Anthropic-specific; zero for non-Anthropic models)
|
||||
cache_read_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
|
||||
|
||||
class ChatSessionInfo(BaseModel):
|
||||
|
||||
@@ -52,11 +52,6 @@ Examples:
|
||||
You can embed a reference inside any string argument, or use it as the entire
|
||||
value. Multiple references in one argument are all expanded.
|
||||
|
||||
**Type coercion**: The platform automatically coerces expanded string values
|
||||
to match the block's expected input types. For example, if a block expects
|
||||
`list[list[str]]` and you pass a string containing a JSON array (e.g. from
|
||||
an @@agptfile: expansion), the string will be parsed into the correct type.
|
||||
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
|
||||
@@ -1,253 +0,0 @@
|
||||
"""CoPilot rate limiting based on token usage.
|
||||
|
||||
Uses Redis fixed-window counters to track per-user token consumption
|
||||
with configurable daily and weekly limits. Daily windows reset at
|
||||
midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00
|
||||
UTC). Fails open when Redis is unavailable to avoid blocking users.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis key prefixes
|
||||
_PREFIX = "copilot:usage"
|
||||
|
||||
|
||||
class UsageWindow(BaseModel):
|
||||
"""Usage within a single time window."""
|
||||
|
||||
used: int
|
||||
limit: int = Field(
|
||||
description="Maximum tokens allowed in this window. 0 means unlimited."
|
||||
)
|
||||
resets_at: datetime
|
||||
|
||||
|
||||
class CoPilotUsageStatus(BaseModel):
|
||||
"""Current usage status for a user across all windows."""
|
||||
|
||||
daily: UsageWindow
|
||||
weekly: UsageWindow
|
||||
|
||||
|
||||
class RateLimitExceeded(Exception):
|
||||
"""Raised when a user exceeds their CoPilot usage limit."""
|
||||
|
||||
def __init__(self, window: str, resets_at: datetime):
|
||||
self.window = window
|
||||
self.resets_at = resets_at
|
||||
delta = resets_at - datetime.now(UTC)
|
||||
total_secs = delta.total_seconds()
|
||||
if total_secs <= 0:
|
||||
time_str = "now"
|
||||
else:
|
||||
hours = int(total_secs // 3600)
|
||||
minutes = int((total_secs % 3600) // 60)
|
||||
time_str = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m"
|
||||
super().__init__(
|
||||
f"You've reached your {window} usage limit. Resets in {time_str}."
|
||||
)
|
||||
|
||||
|
||||
def _daily_key(user_id: str, now: datetime | None = None) -> str:
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
return f"{_PREFIX}:daily:{user_id}:{now.strftime('%Y-%m-%d')}"
|
||||
|
||||
|
||||
def _weekly_key(user_id: str, now: datetime | None = None) -> str:
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
year, week, _ = now.isocalendar()
|
||||
return f"{_PREFIX}:weekly:{user_id}:{year}-W{week:02d}"
|
||||
|
||||
|
||||
def _daily_reset_time(now: datetime | None = None) -> datetime:
|
||||
"""Calculate when the current daily window resets (next midnight UTC)."""
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
|
||||
|
||||
|
||||
def _weekly_reset_time(now: datetime | None = None) -> datetime:
|
||||
"""Calculate when the current weekly window resets (next Monday 00:00 UTC).
|
||||
|
||||
On Monday itself, ``(7 - weekday) % 7`` is 0; the ``or 7`` fallback
|
||||
pushes to *next* Monday so the current week's window stays open.
|
||||
"""
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
days_until_monday = (7 - now.weekday()) % 7 or 7
|
||||
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(
|
||||
days=days_until_monday
|
||||
)
|
||||
|
||||
|
||||
async def _fetch_counters(user_id: str, now: datetime) -> tuple[int, int]:
|
||||
"""Fetch daily and weekly token counters from Redis.
|
||||
|
||||
Returns (daily_used, weekly_used). Returns (0, 0) if Redis is unavailable.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
daily_raw, weekly_raw = await asyncio.gather(
|
||||
redis.get(_daily_key(user_id, now=now)),
|
||||
redis.get(_weekly_key(user_id, now=now)),
|
||||
)
|
||||
return int(daily_raw or 0), int(weekly_raw or 0)
|
||||
|
||||
|
||||
async def get_usage_status(
|
||||
user_id: str,
|
||||
daily_token_limit: int,
|
||||
weekly_token_limit: int,
|
||||
) -> CoPilotUsageStatus:
|
||||
"""Get current usage status for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
daily_token_limit: Max tokens per day (0 = unlimited).
|
||||
weekly_token_limit: Max tokens per week (0 = unlimited).
|
||||
|
||||
Returns:
|
||||
CoPilotUsageStatus with current usage and limits.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
daily_used, weekly_used = await _fetch_counters(user_id, now)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Redis unavailable for usage status, returning zeros", exc_info=True
|
||||
)
|
||||
daily_used, weekly_used = 0, 0
|
||||
|
||||
return CoPilotUsageStatus(
|
||||
daily=UsageWindow(
|
||||
used=daily_used,
|
||||
limit=daily_token_limit,
|
||||
resets_at=_daily_reset_time(now=now),
|
||||
),
|
||||
weekly=UsageWindow(
|
||||
used=weekly_used,
|
||||
limit=weekly_token_limit,
|
||||
resets_at=_weekly_reset_time(now=now),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def check_rate_limit(
|
||||
user_id: str,
|
||||
daily_token_limit: int,
|
||||
weekly_token_limit: int,
|
||||
) -> None:
|
||||
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
|
||||
|
||||
This is a pre-turn soft check. The authoritative usage counter is updated
|
||||
by ``record_token_usage()`` after the turn completes. Under concurrency,
|
||||
two parallel turns may both pass this check against the same snapshot.
|
||||
This is acceptable because token-based limits are approximate by nature
|
||||
(the exact token count is unknown until after generation).
|
||||
|
||||
Fails open: if Redis is unavailable, allows the request.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
daily_used, weekly_used = await _fetch_counters(user_id, now)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Redis unavailable for rate limit check, allowing request", exc_info=True
|
||||
)
|
||||
return
|
||||
|
||||
if daily_token_limit > 0 and daily_used >= daily_token_limit:
|
||||
raise RateLimitExceeded("daily", _daily_reset_time(now=now))
|
||||
|
||||
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
|
||||
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
|
||||
|
||||
|
||||
async def record_token_usage(
|
||||
user_id: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
*,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
) -> None:
|
||||
"""Record token usage for a user across all windows.
|
||||
|
||||
Uses cost-weighted counting so cached tokens don't unfairly penalise
|
||||
multi-turn conversations. Anthropic's pricing:
|
||||
- uncached input: 100%
|
||||
- cache creation: 25%
|
||||
- cache read: 10%
|
||||
- output: 100%
|
||||
|
||||
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
|
||||
from the API response). Cache counts are passed separately.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
prompt_tokens: Uncached input tokens.
|
||||
completion_tokens: Output tokens.
|
||||
cache_read_tokens: Tokens served from prompt cache (10% cost).
|
||||
cache_creation_tokens: Tokens written to prompt cache (25% cost).
|
||||
"""
|
||||
weighted_input = (
|
||||
prompt_tokens
|
||||
+ round(cache_creation_tokens * 0.25)
|
||||
+ round(cache_read_tokens * 0.1)
|
||||
)
|
||||
total = weighted_input + completion_tokens
|
||||
if total <= 0:
|
||||
return
|
||||
|
||||
raw_total = (
|
||||
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
|
||||
)
|
||||
logger.info(
|
||||
"Recording token usage for %s: raw=%d, weighted=%d "
|
||||
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
|
||||
user_id[:8],
|
||||
raw_total,
|
||||
total,
|
||||
prompt_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_tokens,
|
||||
completion_tokens,
|
||||
)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
pipe = redis.pipeline(transaction=False)
|
||||
|
||||
# Daily counter (expires at next midnight UTC)
|
||||
d_key = _daily_key(user_id, now=now)
|
||||
pipe.incrby(d_key, total)
|
||||
seconds_until_daily_reset = int(
|
||||
(_daily_reset_time(now=now) - now).total_seconds()
|
||||
)
|
||||
pipe.expire(d_key, max(seconds_until_daily_reset, 1))
|
||||
|
||||
# Weekly counter (expires end of week)
|
||||
w_key = _weekly_key(user_id, now=now)
|
||||
pipe.incrby(w_key, total)
|
||||
seconds_until_weekly_reset = int(
|
||||
(_weekly_reset_time(now=now) - now).total_seconds()
|
||||
)
|
||||
pipe.expire(w_key, max(seconds_until_weekly_reset, 1))
|
||||
|
||||
await pipe.execute()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Redis unavailable for recording token usage (tokens=%d)",
|
||||
total,
|
||||
exc_info=True,
|
||||
)
|
||||
@@ -1,334 +0,0 @@
|
||||
"""Unit tests for CoPilot rate limiting."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from .rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
check_rate_limit,
|
||||
get_usage_status,
|
||||
record_token_usage,
|
||||
)
|
||||
|
||||
_USER = "test-user-rl"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RateLimitExceeded
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRateLimitExceeded:
|
||||
def test_message_contains_window_name(self):
|
||||
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1))
|
||||
assert "daily" in str(exc)
|
||||
|
||||
def test_message_contains_reset_time(self):
|
||||
exc = RateLimitExceeded(
|
||||
"weekly", datetime.now(UTC) + timedelta(hours=2, minutes=30)
|
||||
)
|
||||
msg = str(exc)
|
||||
# Allow for slight timing drift (29m or 30m)
|
||||
assert "2h " in msg
|
||||
assert "Resets in" in msg
|
||||
|
||||
def test_message_minutes_only_when_under_one_hour(self):
|
||||
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(minutes=15))
|
||||
msg = str(exc)
|
||||
assert "Resets in" in msg
|
||||
# Should not have "0h"
|
||||
assert "0h" not in msg
|
||||
|
||||
def test_message_says_now_when_resets_at_is_in_the_past(self):
|
||||
"""Negative delta (clock skew / stale TTL) should say 'now', not '-1h -30m'."""
|
||||
exc = RateLimitExceeded("daily", datetime.now(UTC) - timedelta(minutes=5))
|
||||
assert "Resets in now" in str(exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_usage_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetUsageStatus:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_redis_values(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["500", "2000"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert isinstance(status, CoPilotUsageStatus)
|
||||
assert status.daily.used == 500
|
||||
assert status.daily.limit == 10000
|
||||
assert status.weekly.used == 2000
|
||||
assert status.weekly.limit == 50000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_zeros_when_redis_unavailable(self):
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 0
|
||||
assert status.weekly.used == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_none_daily_counter(self):
|
||||
"""Daily counter is None (new day), weekly has usage."""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=[None, "3000"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 0
|
||||
assert status.weekly.used == 3000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_none_weekly_counter(self):
|
||||
"""Weekly counter is None (start of week), daily has usage."""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["500", None])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 500
|
||||
assert status.weekly.used == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resets_at_daily_is_next_midnight_utc(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["0", "0"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
# Daily reset should be within 24h
|
||||
assert status.daily.resets_at > now
|
||||
assert status.daily.resets_at <= now + timedelta(hours=24, seconds=5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_rate_limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRateLimit:
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_when_under_limit(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["100", "200"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_daily_limit_exceeded(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["10000", "200"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
assert exc_info.value.window == "daily"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_weekly_limit_exceeded(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["100", "50000"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
assert exc_info.value.window == "weekly"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_when_redis_unavailable(self):
|
||||
"""Fail-open: allow requests when Redis is down."""
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
# Should not raise
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_check_when_limit_is_zero(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["999999", "999999"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise — limits of 0 mean unlimited
|
||||
await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# record_token_usage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecordTokenUsage:
|
||||
@staticmethod
|
||||
def _make_pipeline_mock() -> MagicMock:
|
||||
"""Create a pipeline mock with sync methods and async execute."""
|
||||
pipe = MagicMock()
|
||||
pipe.execute = AsyncMock(return_value=[])
|
||||
return pipe
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_increments_redis_counters(self):
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
# Should call incrby twice (daily + weekly) with total=150
|
||||
incrby_calls = mock_pipe.incrby.call_args_list
|
||||
assert len(incrby_calls) == 2
|
||||
assert incrby_calls[0].args[1] == 150 # daily
|
||||
assert incrby_calls[1].args[1] == 150 # weekly
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_zero_tokens(self):
|
||||
mock_redis = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0)
|
||||
|
||||
# Should not call pipeline at all
|
||||
mock_redis.pipeline.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sets_expire_on_both_keys(self):
|
||||
"""Pipeline should call expire for both daily and weekly keys."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
expire_calls = mock_pipe.expire.call_args_list
|
||||
assert len(expire_calls) == 2
|
||||
|
||||
# Daily key TTL should be positive (seconds until next midnight)
|
||||
daily_ttl = expire_calls[0].args[1]
|
||||
assert daily_ttl >= 1
|
||||
|
||||
# Weekly key TTL should be positive (seconds until next Monday)
|
||||
weekly_ttl = expire_calls[1].args[1]
|
||||
assert weekly_ttl >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_redis_failure_gracefully(self):
|
||||
"""Should not raise when Redis is unavailable."""
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
# Should not raise
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_weighted_counting(self):
|
||||
"""Cached tokens should be weighted: cache_read=10%, cache_create=25%."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(
|
||||
_USER,
|
||||
prompt_tokens=100, # uncached → 100
|
||||
completion_tokens=50, # output → 50
|
||||
cache_read_tokens=10000, # 10% → 1000
|
||||
cache_creation_tokens=400, # 25% → 100
|
||||
)
|
||||
|
||||
# Expected weighted total: 100 + 1000 + 100 + 50 = 1250
|
||||
incrby_calls = mock_pipe.incrby.call_args_list
|
||||
assert len(incrby_calls) == 2
|
||||
assert incrby_calls[0].args[1] == 1250 # daily
|
||||
assert incrby_calls[1].args[1] == 1250 # weekly
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_redis_error_during_pipeline_execute(self):
|
||||
"""Should not raise when pipeline.execute() fails with RedisError."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_pipe.execute = AsyncMock(side_effect=RedisError("Pipeline failed"))
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise — fail-open
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
@@ -186,29 +186,12 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
|
||||
|
||||
class StreamUsage(StreamBaseResponse):
|
||||
"""Token usage statistics.
|
||||
|
||||
Emitted as an SSE comment so the Vercel AI SDK parser ignores it
|
||||
(it uses z.strictObject() and rejects unknown event types).
|
||||
Usage data is recorded server-side (session DB + Redis counters).
|
||||
"""
|
||||
"""Token usage statistics."""
|
||||
|
||||
type: ResponseType = ResponseType.USAGE
|
||||
promptTokens: int = Field(..., description="Number of uncached prompt tokens")
|
||||
promptTokens: int = Field(..., description="Number of prompt tokens")
|
||||
completionTokens: int = Field(..., description="Number of completion tokens")
|
||||
totalTokens: int = Field(
|
||||
..., description="Total number of tokens (raw, not weighted)"
|
||||
)
|
||||
cacheReadTokens: int = Field(
|
||||
default=0, description="Prompt tokens served from cache (10% cost)"
|
||||
)
|
||||
cacheCreationTokens: int = Field(
|
||||
default=0, description="Prompt tokens written to cache (25% cost)"
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Emit as SSE comment so the AI SDK parser ignores it."""
|
||||
return f": usage {self.model_dump_json(exclude_none=True)}\n\n"
|
||||
totalTokens: int = Field(..., description="Total number of tokens")
|
||||
|
||||
|
||||
class StreamError(StreamBaseResponse):
|
||||
|
||||
@@ -198,7 +198,6 @@ class CompactionTracker:
|
||||
|
||||
def reset_for_query(self) -> None:
|
||||
"""Reset per-query state before a new SDK query."""
|
||||
self._compact_start.clear()
|
||||
self._done = False
|
||||
self._start_emitted = False
|
||||
self._tool_call_id = ""
|
||||
|
||||
@@ -1,546 +0,0 @@
|
||||
"""End-to-end compaction flow test.
|
||||
|
||||
Simulates the full service.py compaction lifecycle using real-format
|
||||
JSONL session files — no SDK subprocess needed. Exercises:
|
||||
|
||||
1. TranscriptBuilder loads a "downloaded" transcript
|
||||
2. User query appended, assistant response streamed
|
||||
3. PreCompact hook fires → CompactionTracker.on_compact()
|
||||
4. Next message → emit_start_if_ready() yields spinner events
|
||||
5. Message after that → emit_end_if_ready() returns end events
|
||||
6. _read_compacted_entries() reads the CLI session file
|
||||
7. TranscriptBuilder.replace_entries() syncs state
|
||||
8. More messages appended post-compaction
|
||||
9. to_jsonl() exports full state for upload
|
||||
10. Fresh builder loads the export — roundtrip verified
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.response_model import (
|
||||
StreamFinishStep,
|
||||
StreamStartStep,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.sdk.compaction import CompactionTracker
|
||||
from backend.copilot.sdk.transcript import strip_progress_entries
|
||||
from backend.copilot.sdk.transcript_builder import TranscriptBuilder
|
||||
from backend.util import json
|
||||
|
||||
|
||||
def _make_jsonl(*entries: dict) -> str:
|
||||
return "\n".join(json.dumps(e) for e in entries) + "\n"
|
||||
|
||||
|
||||
def _run(coro):
|
||||
"""Run an async coroutine synchronously."""
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _read_compacted_entries(path: str) -> tuple[list[dict], str] | None:
|
||||
"""Test-only: read compacted entries from a session JSONL file.
|
||||
|
||||
Returns (parsed_dicts, jsonl_string) from the first ``isCompactSummary``
|
||||
entry onward, or ``None`` if no summary is found.
|
||||
"""
|
||||
content = Path(path).read_text()
|
||||
lines = content.strip().split("\n")
|
||||
compact_idx: int | None = None
|
||||
parsed: list[dict] = []
|
||||
raw_lines: list[str] = []
|
||||
for line in lines:
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
parsed.append(entry)
|
||||
raw_lines.append(line.strip())
|
||||
if compact_idx is None and entry.get("isCompactSummary"):
|
||||
compact_idx = len(parsed) - 1
|
||||
if compact_idx is None:
|
||||
return None
|
||||
return parsed[compact_idx:], "\n".join(raw_lines[compact_idx:]) + "\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: realistic CLI session file content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Pre-compaction conversation
|
||||
USER_1 = {
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "What files are in this project?"},
|
||||
}
|
||||
ASST_1_THINKING = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1-think",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_aaa",
|
||||
"type": "message",
|
||||
"content": [{"type": "thinking", "thinking": "Let me look at the files..."}],
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
ASST_1_TOOL = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1-tool",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_aaa",
|
||||
"type": "message",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "tu1",
|
||||
"name": "Bash",
|
||||
"input": {"command": "ls"},
|
||||
}
|
||||
],
|
||||
"stop_reason": "tool_use",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
TOOL_RESULT_1 = {
|
||||
"type": "user",
|
||||
"uuid": "tr1",
|
||||
"parentUuid": "a1-tool",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "tu1",
|
||||
"content": "file1.py\nfile2.py",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
ASST_1_TEXT = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1-text",
|
||||
"parentUuid": "tr1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_bbb",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "I found file1.py and file2.py."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
# Progress entries (should be stripped during upload)
|
||||
PROGRESS_1 = {
|
||||
"type": "progress",
|
||||
"uuid": "prog1",
|
||||
"parentUuid": "a1-tool",
|
||||
"data": {"type": "bash_progress", "stdout": "running ls..."},
|
||||
}
|
||||
# Second user message
|
||||
USER_2 = {
|
||||
"type": "user",
|
||||
"uuid": "u2",
|
||||
"parentUuid": "a1-text",
|
||||
"message": {"role": "user", "content": "Show me file1.py"},
|
||||
}
|
||||
ASST_2 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a2",
|
||||
"parentUuid": "u2",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_ccc",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "Here is file1.py content..."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
|
||||
# --- Compaction summary (written by CLI after context compaction) ---
|
||||
COMPACT_SUMMARY = {
|
||||
"type": "summary",
|
||||
"uuid": "cs1",
|
||||
"isCompactSummary": True,
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Summary: User asked about project files. Found file1.py and file2.py. "
|
||||
"User then asked to see file1.py."
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
# Post-compaction assistant response
|
||||
POST_COMPACT_ASST = {
|
||||
"type": "assistant",
|
||||
"uuid": "a3",
|
||||
"parentUuid": "cs1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_ddd",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "Here is the content of file1.py..."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
|
||||
# Post-compaction user follow-up
|
||||
USER_3 = {
|
||||
"type": "user",
|
||||
"uuid": "u3",
|
||||
"parentUuid": "a3",
|
||||
"message": {"role": "user", "content": "Now show file2.py"},
|
||||
}
|
||||
ASST_3 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a4",
|
||||
"parentUuid": "u3",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_eee",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "Here is file2.py..."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# E2E test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactionE2E:
|
||||
def _write_session_file(self, session_dir, entries):
|
||||
"""Write a CLI session JSONL file."""
|
||||
path = session_dir / "session.jsonl"
|
||||
path.write_text(_make_jsonl(*entries))
|
||||
return path
|
||||
|
||||
def test_full_compaction_lifecycle(self, tmp_path):
|
||||
"""Simulate the complete service.py compaction flow.
|
||||
|
||||
Timeline:
|
||||
1. Previous turn uploaded transcript with [USER_1, ASST_1, USER_2, ASST_2]
|
||||
2. Current turn: download → load_previous
|
||||
3. User sends "Now show file2.py" → append_user
|
||||
4. SDK starts streaming response
|
||||
5. Mid-stream: PreCompact hook fires (context too large)
|
||||
6. CLI writes compaction summary to session file
|
||||
7. Next SDK message → emit_start (spinner)
|
||||
8. Following message → emit_end (end events)
|
||||
9. _read_compacted_entries reads the session file
|
||||
10. replace_entries syncs TranscriptBuilder
|
||||
11. More assistant messages appended
|
||||
12. Export → upload → next turn downloads it
|
||||
"""
|
||||
session_dir = tmp_path / "session"
|
||||
session_dir.mkdir(parents=True)
|
||||
|
||||
# --- Step 1-2: Load "downloaded" transcript from previous turn ---
|
||||
previous_transcript = _make_jsonl(
|
||||
USER_1,
|
||||
ASST_1_THINKING,
|
||||
ASST_1_TOOL,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
)
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(previous_transcript)
|
||||
assert builder.entry_count == 7
|
||||
|
||||
# --- Step 3: User sends new query ---
|
||||
builder.append_user("Now show file2.py")
|
||||
assert builder.entry_count == 8
|
||||
|
||||
# --- Step 4: SDK starts streaming ---
|
||||
builder.append_assistant(
|
||||
[{"type": "thinking", "thinking": "Let me read file2.py..."}],
|
||||
model="claude-sonnet-4-20250514",
|
||||
)
|
||||
assert builder.entry_count == 9
|
||||
|
||||
# --- Step 5-6: PreCompact fires, CLI writes session file ---
|
||||
session_file = self._write_session_file(
|
||||
session_dir,
|
||||
[
|
||||
USER_1,
|
||||
ASST_1_THINKING,
|
||||
ASST_1_TOOL,
|
||||
PROGRESS_1,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
COMPACT_SUMMARY,
|
||||
POST_COMPACT_ASST,
|
||||
USER_3,
|
||||
ASST_3,
|
||||
],
|
||||
)
|
||||
|
||||
# --- Step 7: CompactionTracker receives PreCompact hook ---
|
||||
tracker = CompactionTracker()
|
||||
session = ChatSession.new(user_id="test-user")
|
||||
# on_compact is a property returning Event.set callable
|
||||
tracker.on_compact()
|
||||
|
||||
# --- Step 8: Next SDK message arrives → emit_start ---
|
||||
start_events = tracker.emit_start_if_ready()
|
||||
assert len(start_events) == 3
|
||||
assert isinstance(start_events[0], StreamStartStep)
|
||||
assert isinstance(start_events[1], StreamToolInputStart)
|
||||
assert isinstance(start_events[2], StreamToolInputAvailable)
|
||||
|
||||
# Verify tool_call_id is set
|
||||
tool_call_id = start_events[1].toolCallId
|
||||
assert tool_call_id.startswith("compaction-")
|
||||
|
||||
# --- Step 9: Following message → emit_end ---
|
||||
end_events = _run(tracker.emit_end_if_ready(session))
|
||||
assert len(end_events) == 2
|
||||
assert isinstance(end_events[0], StreamToolOutputAvailable)
|
||||
assert isinstance(end_events[1], StreamFinishStep)
|
||||
# Verify same tool_call_id
|
||||
assert end_events[0].toolCallId == tool_call_id
|
||||
|
||||
# Session should have compaction messages persisted
|
||||
assert len(session.messages) == 2
|
||||
assert session.messages[0].role == "assistant"
|
||||
assert session.messages[1].role == "tool"
|
||||
|
||||
# --- Step 10: _read_compacted_entries + replace_entries ---
|
||||
result = _read_compacted_entries(str(session_file))
|
||||
assert result is not None
|
||||
compacted_dicts, compacted_jsonl = result
|
||||
# Should have: COMPACT_SUMMARY + POST_COMPACT_ASST + USER_3 + ASST_3
|
||||
assert len(compacted_dicts) == 4
|
||||
assert compacted_dicts[0]["uuid"] == "cs1"
|
||||
assert compacted_dicts[0]["isCompactSummary"] is True
|
||||
|
||||
# Replace builder state with compacted JSONL
|
||||
old_count = builder.entry_count
|
||||
builder.replace_entries(compacted_jsonl)
|
||||
assert builder.entry_count == 4 # Only compacted entries
|
||||
assert builder.entry_count < old_count # Compaction reduced entries
|
||||
|
||||
# --- Step 11: More assistant messages after compaction ---
|
||||
builder.append_assistant(
|
||||
[{"type": "text", "text": "Here is file2.py:\n\ndef hello():\n pass"}],
|
||||
model="claude-sonnet-4-20250514",
|
||||
stop_reason="end_turn",
|
||||
)
|
||||
assert builder.entry_count == 5
|
||||
|
||||
# --- Step 12: Export for upload ---
|
||||
output = builder.to_jsonl()
|
||||
assert output # Not empty
|
||||
output_entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||
assert len(output_entries) == 5
|
||||
|
||||
# Verify structure:
|
||||
# [COMPACT_SUMMARY, POST_COMPACT_ASST, USER_3, ASST_3, new_assistant]
|
||||
assert output_entries[0]["type"] == "summary"
|
||||
assert output_entries[0].get("isCompactSummary") is True
|
||||
assert output_entries[0]["uuid"] == "cs1"
|
||||
assert output_entries[1]["uuid"] == "a3"
|
||||
assert output_entries[2]["uuid"] == "u3"
|
||||
assert output_entries[3]["uuid"] == "a4"
|
||||
assert output_entries[4]["type"] == "assistant"
|
||||
|
||||
# Verify parent chain is intact
|
||||
assert output_entries[1]["parentUuid"] == "cs1" # a3 → cs1
|
||||
assert output_entries[2]["parentUuid"] == "a3" # u3 → a3
|
||||
assert output_entries[3]["parentUuid"] == "u3" # a4 → u3
|
||||
assert output_entries[4]["parentUuid"] == "a4" # new → a4
|
||||
|
||||
# --- Step 13: Roundtrip — next turn loads this export ---
|
||||
builder2 = TranscriptBuilder()
|
||||
builder2.load_previous(output)
|
||||
assert builder2.entry_count == 5
|
||||
|
||||
# isCompactSummary survives roundtrip
|
||||
output2 = builder2.to_jsonl()
|
||||
first_entry = json.loads(output2.strip().split("\n")[0])
|
||||
assert first_entry.get("isCompactSummary") is True
|
||||
|
||||
# Can append more messages
|
||||
builder2.append_user("What about file3.py?")
|
||||
assert builder2.entry_count == 6
|
||||
final_output = builder2.to_jsonl()
|
||||
last_entry = json.loads(final_output.strip().split("\n")[-1])
|
||||
assert last_entry["type"] == "user"
|
||||
# Parented to the last entry from previous turn
|
||||
assert last_entry["parentUuid"] == output_entries[-1]["uuid"]
|
||||
|
||||
def test_double_compaction_within_session(self, tmp_path):
|
||||
"""Two compactions in the same session (across reset_for_query)."""
|
||||
session_dir = tmp_path / "session"
|
||||
session_dir.mkdir(parents=True)
|
||||
|
||||
tracker = CompactionTracker()
|
||||
session = ChatSession.new(user_id="test")
|
||||
builder = TranscriptBuilder()
|
||||
|
||||
# --- First query with compaction ---
|
||||
builder.append_user("first question")
|
||||
builder.append_assistant([{"type": "text", "text": "first answer"}])
|
||||
|
||||
# Write session file for first compaction
|
||||
first_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "cs-first",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "First compaction summary"},
|
||||
}
|
||||
first_post = {
|
||||
"type": "assistant",
|
||||
"uuid": "a-first",
|
||||
"parentUuid": "cs-first",
|
||||
"message": {"role": "assistant", "content": "first post-compact"},
|
||||
}
|
||||
file1 = session_dir / "session1.jsonl"
|
||||
file1.write_text(_make_jsonl(first_summary, first_post))
|
||||
|
||||
tracker.on_compact()
|
||||
tracker.emit_start_if_ready()
|
||||
end_events1 = _run(tracker.emit_end_if_ready(session))
|
||||
assert len(end_events1) == 2 # output + finish
|
||||
|
||||
result1_entries = _read_compacted_entries(str(file1))
|
||||
assert result1_entries is not None
|
||||
_, compacted1_jsonl = result1_entries
|
||||
builder.replace_entries(compacted1_jsonl)
|
||||
assert builder.entry_count == 2
|
||||
|
||||
# --- Reset for second query ---
|
||||
tracker.reset_for_query()
|
||||
|
||||
# --- Second query with compaction ---
|
||||
builder.append_user("second question")
|
||||
builder.append_assistant([{"type": "text", "text": "second answer"}])
|
||||
|
||||
second_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "cs-second",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "Second compaction summary"},
|
||||
}
|
||||
second_post = {
|
||||
"type": "assistant",
|
||||
"uuid": "a-second",
|
||||
"parentUuid": "cs-second",
|
||||
"message": {"role": "assistant", "content": "second post-compact"},
|
||||
}
|
||||
file2 = session_dir / "session2.jsonl"
|
||||
file2.write_text(_make_jsonl(second_summary, second_post))
|
||||
|
||||
tracker.on_compact()
|
||||
tracker.emit_start_if_ready()
|
||||
end_events2 = _run(tracker.emit_end_if_ready(session))
|
||||
assert len(end_events2) == 2 # output + finish
|
||||
|
||||
result2_entries = _read_compacted_entries(str(file2))
|
||||
assert result2_entries is not None
|
||||
_, compacted2_jsonl = result2_entries
|
||||
builder.replace_entries(compacted2_jsonl)
|
||||
assert builder.entry_count == 2 # Only second compaction entries
|
||||
|
||||
# Export and verify
|
||||
output = builder.to_jsonl()
|
||||
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||
assert entries[0]["uuid"] == "cs-second"
|
||||
assert entries[0].get("isCompactSummary") is True
|
||||
|
||||
def test_strip_progress_then_load_then_compact_roundtrip(self, tmp_path):
|
||||
"""Full pipeline: strip → load → compact → replace → export → reload.
|
||||
|
||||
This tests the exact sequence that happens across two turns:
|
||||
Turn 1: SDK produces transcript with progress entries
|
||||
Upload: strip_progress_entries removes progress, upload to cloud
|
||||
Turn 2: Download → load_previous → compaction fires → replace → export
|
||||
Turn 3: Download the Turn 2 export → load_previous (roundtrip)
|
||||
"""
|
||||
session_dir = tmp_path / "session"
|
||||
session_dir.mkdir(parents=True)
|
||||
|
||||
# --- Turn 1: SDK produces raw transcript ---
|
||||
raw_content = _make_jsonl(
|
||||
USER_1,
|
||||
ASST_1_THINKING,
|
||||
ASST_1_TOOL,
|
||||
PROGRESS_1,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
)
|
||||
|
||||
# Strip progress for upload
|
||||
stripped = strip_progress_entries(raw_content)
|
||||
stripped_entries = [
|
||||
json.loads(line) for line in stripped.strip().split("\n") if line.strip()
|
||||
]
|
||||
# Progress should be gone
|
||||
assert not any(e.get("type") == "progress" for e in stripped_entries)
|
||||
assert len(stripped_entries) == 7 # 8 - 1 progress
|
||||
|
||||
# --- Turn 2: Download stripped, load, compaction happens ---
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(stripped)
|
||||
assert builder.entry_count == 7
|
||||
|
||||
builder.append_user("Now show file2.py")
|
||||
builder.append_assistant(
|
||||
[{"type": "text", "text": "Reading file2.py..."}],
|
||||
model="claude-sonnet-4-20250514",
|
||||
)
|
||||
|
||||
# CLI writes session file with compaction
|
||||
session_file = self._write_session_file(
|
||||
session_dir,
|
||||
[
|
||||
USER_1,
|
||||
ASST_1_TOOL,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
COMPACT_SUMMARY,
|
||||
POST_COMPACT_ASST,
|
||||
],
|
||||
)
|
||||
|
||||
result = _read_compacted_entries(str(session_file))
|
||||
assert result is not None
|
||||
_, compacted_jsonl = result
|
||||
builder.replace_entries(compacted_jsonl)
|
||||
|
||||
# Append post-compaction message
|
||||
builder.append_user("Thanks!")
|
||||
output = builder.to_jsonl()
|
||||
|
||||
# --- Turn 3: Fresh load of Turn 2 export ---
|
||||
builder3 = TranscriptBuilder()
|
||||
builder3.load_previous(output)
|
||||
# Should have: compact_summary + post_compact_asst + "Thanks!"
|
||||
assert builder3.entry_count == 3
|
||||
|
||||
# Compact summary survived the full pipeline
|
||||
first = json.loads(builder3.to_jsonl().strip().split("\n")[0])
|
||||
assert first.get("isCompactSummary") is True
|
||||
assert first["type"] == "summary"
|
||||
@@ -221,12 +221,12 @@ class SDKResponseAdapter:
|
||||
responses.append(StreamFinish())
|
||||
else:
|
||||
logger.warning(
|
||||
"Unexpected ResultMessage subtype: %s", sdk_message.subtype
|
||||
f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
|
||||
)
|
||||
responses.append(StreamFinish())
|
||||
|
||||
else:
|
||||
logger.debug("Unhandled SDK message type: %s", type(sdk_message).__name__)
|
||||
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
|
||||
|
||||
return responses
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ def _validate_workspace_path(
|
||||
if is_allowed_local_path(path, sdk_cwd):
|
||||
return {}
|
||||
|
||||
logger.warning("Blocked %s outside workspace: %s", tool_name, path)
|
||||
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
|
||||
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
|
||||
@@ -71,7 +71,7 @@ def _validate_tool_access(
|
||||
"""
|
||||
# Block forbidden tools
|
||||
if tool_name in BLOCKED_TOOLS:
|
||||
logger.warning("Blocked tool access attempt: %s", tool_name)
|
||||
logger.warning(f"Blocked tool access attempt: {tool_name}")
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
|
||||
"This is enforced by the platform and cannot be bypassed. "
|
||||
@@ -111,9 +111,7 @@ def _validate_user_isolation(
|
||||
# the tool itself via _validate_ephemeral_path.
|
||||
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
||||
if path and ".." in path:
|
||||
logger.warning(
|
||||
"Blocked path traversal attempt: %s by user %s", path, user_id
|
||||
)
|
||||
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
@@ -171,7 +169,7 @@ def create_security_hooks(
|
||||
# Block background task execution first — denied calls
|
||||
# should not consume a subtask slot.
|
||||
if tool_input.get("run_in_background"):
|
||||
logger.info("[SDK] Blocked background Task, user=%s", user_id)
|
||||
logger.info(f"[SDK] Blocked background Task, user={user_id}")
|
||||
return cast(
|
||||
SyncHookJSONOutput,
|
||||
_deny(
|
||||
@@ -213,7 +211,7 @@ def create_security_hooks(
|
||||
if tool_name == "Task" and tool_use_id is not None:
|
||||
task_tool_use_ids.add(tool_use_id)
|
||||
|
||||
logger.debug("[SDK] Tool start: %s, user=%s", tool_name, user_id)
|
||||
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
|
||||
|
||||
@@ -40,13 +40,11 @@ from ..constants import COPILOT_ERROR_PREFIX, COPILOT_SYSTEM_PREFIX
|
||||
from ..model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from ..prompting import get_sdk_supplement
|
||||
from ..rate_limit import record_token_usage
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
@@ -56,7 +54,6 @@ from ..response_model import (
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from ..service import (
|
||||
_build_system_prompt,
|
||||
@@ -78,12 +75,8 @@ from .tool_adapter import (
|
||||
wait_for_stash,
|
||||
)
|
||||
from .transcript import (
|
||||
COMPACT_THRESHOLD_BYTES,
|
||||
TranscriptDownload,
|
||||
cleanup_cli_project_dir,
|
||||
compact_transcript,
|
||||
download_transcript,
|
||||
read_cli_session_file,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
@@ -301,7 +294,7 @@ def _cleanup_sdk_tool_results(cwd: str) -> None:
|
||||
"""
|
||||
normalized = os.path.normpath(cwd)
|
||||
if not normalized.startswith(_SDK_CWD_PREFIX):
|
||||
logger.warning("[SDK] Rejecting cleanup for path outside workspace: %s", cwd)
|
||||
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
|
||||
return
|
||||
|
||||
# Clean the CLI's project directory (transcripts + tool-results).
|
||||
@@ -395,7 +388,7 @@ async def _compress_messages(
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[SDK] Context compression with LLM failed: %s", e)
|
||||
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
|
||||
# Fall back to truncation-only (no LLM summarization)
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
@@ -631,56 +624,6 @@ async def _prepare_file_attachments(
|
||||
return PreparedAttachments(hint=hint, image_blocks=image_blocks)
|
||||
|
||||
|
||||
async def _maybe_compact_and_upload(
|
||||
dl: TranscriptDownload,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> str:
|
||||
"""Compact an oversized transcript and upload the compacted version.
|
||||
|
||||
Returns the (possibly compacted) transcript content, or an empty string
|
||||
if compaction was needed but failed.
|
||||
"""
|
||||
content = dl.content
|
||||
if len(content) <= COMPACT_THRESHOLD_BYTES:
|
||||
return content
|
||||
|
||||
logger.warning(
|
||||
"%s Transcript oversized (%dB > %dB), compacting",
|
||||
log_prefix,
|
||||
len(content),
|
||||
COMPACT_THRESHOLD_BYTES,
|
||||
)
|
||||
compacted = await compact_transcript(content, log_prefix=log_prefix)
|
||||
if not compacted:
|
||||
logger.warning(
|
||||
"%s Compaction failed, skipping resume for this turn", log_prefix
|
||||
)
|
||||
return ""
|
||||
|
||||
# Keep the original message_count: it reflects the number of
|
||||
# session.messages covered by this transcript, which the gap-fill
|
||||
# logic uses as a slice index. Counting JSONL lines would give a
|
||||
# smaller number (compacted messages != session message count) and
|
||||
# cause already-covered messages to be re-injected.
|
||||
try:
|
||||
await upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
content=compacted,
|
||||
message_count=dl.message_count,
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"%s Failed to upload compacted transcript",
|
||||
log_prefix,
|
||||
exc_info=True,
|
||||
)
|
||||
return compacted
|
||||
|
||||
|
||||
async def stream_chat_completion_sdk(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
@@ -792,14 +735,6 @@ async def stream_chat_completion_sdk(
|
||||
_otel_ctx: Any = None
|
||||
|
||||
# Make sure there is no more code between the lock acquisition and try-block.
|
||||
# Token usage accumulators — populated from ResultMessage at end of turn
|
||||
turn_prompt_tokens = 0 # uncached input tokens only
|
||||
turn_completion_tokens = 0
|
||||
turn_cache_read_tokens = 0
|
||||
turn_cache_creation_tokens = 0
|
||||
total_tokens = 0 # computed once before StreamUsage, reused in finally
|
||||
turn_cost_usd: float | None = None
|
||||
|
||||
try:
|
||||
# Build system prompt (reuses non-SDK path with Langfuse support).
|
||||
# Pre-compute the cwd here so the exact working directory path can be
|
||||
@@ -892,33 +827,20 @@ async def stream_chat_completion_sdk(
|
||||
is_valid,
|
||||
)
|
||||
if is_valid:
|
||||
transcript_content = await _maybe_compact_and_upload(
|
||||
dl,
|
||||
user_id=user_id or "",
|
||||
session_id=session_id,
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
# Load previous context into builder (empty string is a no-op)
|
||||
if transcript_content:
|
||||
transcript_builder.load_previous(
|
||||
transcript_content, log_prefix=log_prefix
|
||||
)
|
||||
resume_file = (
|
||||
write_transcript_to_tempfile(
|
||||
transcript_content, session_id, sdk_cwd
|
||||
)
|
||||
if transcript_content
|
||||
else None
|
||||
# Load previous FULL context into builder
|
||||
transcript_builder.load_previous(dl.content, log_prefix=log_prefix)
|
||||
resume_file = write_transcript_to_tempfile(
|
||||
dl.content, session_id, sdk_cwd
|
||||
)
|
||||
if resume_file:
|
||||
use_resume = True
|
||||
transcript_msg_count = dl.message_count
|
||||
logger.debug(
|
||||
f"{log_prefix} Using --resume ({len(transcript_content)}B, "
|
||||
f"{log_prefix} Using --resume ({len(dl.content)}B, "
|
||||
f"msg_count={transcript_msg_count})"
|
||||
)
|
||||
else:
|
||||
logger.warning("%s Transcript downloaded but invalid", log_prefix)
|
||||
logger.warning(f"{log_prefix} Transcript downloaded but invalid")
|
||||
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
||||
logger.warning(
|
||||
f"{log_prefix} No transcript available "
|
||||
@@ -1188,7 +1110,7 @@ async def stream_chat_completion_sdk(
|
||||
- len(adapter.resolved_tool_calls),
|
||||
)
|
||||
|
||||
# Log ResultMessage details and capture token usage
|
||||
# Log ResultMessage details for debugging
|
||||
if isinstance(sdk_msg, ResultMessage):
|
||||
logger.info(
|
||||
"%s Received: ResultMessage %s "
|
||||
@@ -1207,46 +1129,9 @@ async def stream_chat_completion_sdk(
|
||||
sdk_msg.result or "(no error message provided)",
|
||||
)
|
||||
|
||||
# Capture token usage from ResultMessage.
|
||||
# Anthropic reports cached tokens separately:
|
||||
# input_tokens = uncached only
|
||||
# cache_read_input_tokens = served from cache
|
||||
# cache_creation_input_tokens = written to cache
|
||||
if sdk_msg.usage:
|
||||
turn_prompt_tokens += sdk_msg.usage.get("input_tokens", 0)
|
||||
turn_cache_read_tokens += sdk_msg.usage.get(
|
||||
"cache_read_input_tokens", 0
|
||||
)
|
||||
turn_cache_creation_tokens += sdk_msg.usage.get(
|
||||
"cache_creation_input_tokens", 0
|
||||
)
|
||||
turn_completion_tokens += sdk_msg.usage.get(
|
||||
"output_tokens", 0
|
||||
)
|
||||
logger.info(
|
||||
"%s Token usage: uncached=%d, cache_read=%d, cache_create=%d, output=%d",
|
||||
log_prefix,
|
||||
turn_prompt_tokens,
|
||||
turn_cache_read_tokens,
|
||||
turn_cache_creation_tokens,
|
||||
turn_completion_tokens,
|
||||
)
|
||||
if sdk_msg.total_cost_usd is not None:
|
||||
turn_cost_usd = sdk_msg.total_cost_usd
|
||||
|
||||
# Emit compaction end if SDK finished compacting.
|
||||
# When compaction ends, sync TranscriptBuilder with
|
||||
# the CLI's compacted session file so the uploaded
|
||||
# transcript reflects compaction.
|
||||
compaction_events = await compaction.emit_end_if_ready(session)
|
||||
for ev in compaction_events:
|
||||
# Emit compaction end if SDK finished compacting
|
||||
for ev in await compaction.emit_end_if_ready(session):
|
||||
yield ev
|
||||
if compaction_events and sdk_cwd:
|
||||
cli_content = await read_cli_session_file(sdk_cwd)
|
||||
if cli_content:
|
||||
transcript_builder.replace_entries(
|
||||
cli_content, log_prefix=log_prefix
|
||||
)
|
||||
|
||||
for response in adapter.convert_message(sdk_msg):
|
||||
if isinstance(response, StreamStart):
|
||||
@@ -1440,27 +1325,6 @@ async def stream_chat_completion_sdk(
|
||||
) and not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
|
||||
# Emit token usage to the client (must be in try to reach SSE stream).
|
||||
# Session persistence of usage is in finally to stay consistent with
|
||||
# rate-limit recording even if an exception interrupts between here
|
||||
# and the finally block.
|
||||
# Compute total_tokens once; reused in the finally block for
|
||||
# session persistence and rate-limit recording.
|
||||
total_tokens = (
|
||||
turn_prompt_tokens
|
||||
+ turn_cache_read_tokens
|
||||
+ turn_cache_creation_tokens
|
||||
+ turn_completion_tokens
|
||||
)
|
||||
if total_tokens > 0:
|
||||
yield StreamUsage(
|
||||
promptTokens=turn_prompt_tokens,
|
||||
completionTokens=turn_completion_tokens,
|
||||
totalTokens=total_tokens,
|
||||
cacheReadTokens=turn_cache_read_tokens,
|
||||
cacheCreationTokens=turn_cache_creation_tokens,
|
||||
)
|
||||
|
||||
# Transcript upload is handled exclusively in the finally block
|
||||
# to avoid double-uploads (the success path used to upload the
|
||||
# old resume file, then the finally block overwrote it with the
|
||||
@@ -1525,48 +1389,6 @@ async def stream_chat_completion_sdk(
|
||||
except Exception:
|
||||
logger.warning("OTEL context teardown failed", exc_info=True)
|
||||
|
||||
# --- Persist token usage to session + rate-limit counters ---
|
||||
# Both must live in finally so they stay consistent even when an
|
||||
# exception interrupts the try block after StreamUsage was yielded.
|
||||
# total_tokens is computed once before StreamUsage yield above.
|
||||
if total_tokens > 0:
|
||||
if session is not None:
|
||||
session.usage.append(
|
||||
Usage(
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cache_read_tokens=turn_cache_read_tokens,
|
||||
cache_creation_tokens=turn_cache_creation_tokens,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"%s Turn usage: uncached=%d, cache_read=%d, cache_create=%d, "
|
||||
"output=%d, total=%d, cost_usd=%s",
|
||||
log_prefix,
|
||||
turn_prompt_tokens,
|
||||
turn_cache_read_tokens,
|
||||
turn_cache_creation_tokens,
|
||||
turn_completion_tokens,
|
||||
total_tokens,
|
||||
turn_cost_usd,
|
||||
)
|
||||
if user_id and total_tokens > 0:
|
||||
try:
|
||||
await record_token_usage(
|
||||
user_id=user_id,
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
cache_read_tokens=turn_cache_read_tokens,
|
||||
cache_creation_tokens=turn_cache_creation_tokens,
|
||||
)
|
||||
except Exception as usage_err:
|
||||
logger.warning(
|
||||
"%s Failed to record token usage: %s",
|
||||
log_prefix,
|
||||
usage_err,
|
||||
)
|
||||
|
||||
# --- Persist session messages ---
|
||||
# This MUST run in finally to persist messages even when the generator
|
||||
# is stopped early (e.g., user clicks stop, processor breaks stream loop).
|
||||
@@ -1662,6 +1484,6 @@ async def _update_title_async(
|
||||
)
|
||||
if title and user_id:
|
||||
await update_session_title(session_id, user_id, title, only_if_empty=True)
|
||||
logger.debug("[SDK] Generated title for %s: %s", session_id, title)
|
||||
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
|
||||
except Exception as e:
|
||||
logger.warning("[SDK] Failed to update session title: %s", e)
|
||||
logger.warning(f"[SDK] Failed to update session title: {e}")
|
||||
|
||||
@@ -234,9 +234,7 @@ def create_tool_handler(base_tool: BaseTool):
|
||||
try:
|
||||
return await _execute_tool_sync(base_tool, user_id, session, args)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error executing tool %s: %s", base_tool.name, e, exc_info=True
|
||||
)
|
||||
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
|
||||
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
|
||||
|
||||
return tool_handler
|
||||
|
||||
@@ -13,17 +13,10 @@ filesystem for self-hosted) — no DB column needed.
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import openai
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.util import json
|
||||
from backend.util.prompt import CompressResult, compress_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -41,11 +34,6 @@ STRIPPABLE_TYPES = frozenset(
|
||||
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
|
||||
)
|
||||
|
||||
# JSONL protocol values used in transcript serialization.
|
||||
STOP_REASON_END_TURN = "end_turn"
|
||||
COMPACT_MSG_ID_PREFIX = "msg_compact_"
|
||||
ENTRY_TYPE_MESSAGE = "message"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptDownload:
|
||||
@@ -94,11 +82,7 @@ def strip_progress_entries(content: str) -> str:
|
||||
parent = entry.get("parentUuid", "")
|
||||
if uid:
|
||||
uuid_to_parent[uid] = parent
|
||||
if (
|
||||
entry.get("type", "") in STRIPPABLE_TYPES
|
||||
and uid
|
||||
and not entry.get("isCompactSummary")
|
||||
):
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and uid:
|
||||
stripped_uuids.add(uid)
|
||||
|
||||
# Second pass: keep non-stripped entries, reparenting where needed.
|
||||
@@ -109,9 +93,7 @@ def strip_progress_entries(content: str) -> str:
|
||||
continue
|
||||
parent = entry.get("parentUuid", "")
|
||||
original_parent = parent
|
||||
seen_parents: set[str] = set()
|
||||
while parent in stripped_uuids and parent not in seen_parents:
|
||||
seen_parents.add(parent)
|
||||
while parent in stripped_uuids:
|
||||
parent = uuid_to_parent.get(parent, "")
|
||||
if parent != original_parent:
|
||||
entry["parentUuid"] = parent
|
||||
@@ -124,9 +106,7 @@ def strip_progress_entries(content: str) -> str:
|
||||
if not isinstance(entry, dict):
|
||||
result_lines.append(line)
|
||||
continue
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
|
||||
"isCompactSummary"
|
||||
):
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES:
|
||||
continue
|
||||
uid = entry.get("uuid", "")
|
||||
if uid in reparented:
|
||||
@@ -157,78 +137,32 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
||||
|
||||
|
||||
def _cli_project_dir(sdk_cwd: str) -> str | None:
|
||||
"""Return the CLI's project directory for a given working directory.
|
||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""Remove the CLI's project directory for a specific working directory.
|
||||
|
||||
Returns ``None`` if the path would escape the projects base.
|
||||
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
|
||||
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
|
||||
safe to remove entirely after the transcript has been uploaded.
|
||||
"""
|
||||
import shutil
|
||||
|
||||
# Encode cwd the same way CLI does (replaces non-alphanumeric with -)
|
||||
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
|
||||
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
|
||||
|
||||
if not project_dir.startswith(projects_base + os.sep):
|
||||
logger.warning("[Transcript] Project dir escaped base: %s", project_dir)
|
||||
return None
|
||||
return project_dir
|
||||
|
||||
|
||||
async def read_cli_session_file(sdk_cwd: str) -> str | None:
|
||||
"""Read the CLI's own session file, which reflects any mid-stream compaction.
|
||||
|
||||
After the CLI compacts context, its session file contains the compacted
|
||||
conversation. Reading this file lets ``TranscriptBuilder`` replace its
|
||||
uncompacted entries with the CLI's compacted version.
|
||||
"""
|
||||
import aiofiles
|
||||
|
||||
project_dir = _cli_project_dir(sdk_cwd)
|
||||
if not project_dir or not os.path.isdir(project_dir):
|
||||
return None
|
||||
jsonl_files = list(Path(project_dir).glob("*.jsonl"))
|
||||
if not jsonl_files:
|
||||
logger.debug("[Transcript] No CLI session file in %s", project_dir)
|
||||
return None
|
||||
# Pick the most recently modified file (there should only be one per turn).
|
||||
# Guard against races where a file is deleted between glob and stat.
|
||||
candidates: list[tuple[float, Path]] = []
|
||||
for p in jsonl_files:
|
||||
try:
|
||||
candidates.append((p.stat().st_mtime, p))
|
||||
except OSError:
|
||||
continue
|
||||
if not candidates:
|
||||
logger.debug("[Transcript] No readable CLI session file in %s", project_dir)
|
||||
return None
|
||||
# Resolve + prefix check to prevent symlink escapes.
|
||||
session_file = max(candidates, key=lambda item: item[0])[1]
|
||||
real_path = str(session_file.resolve())
|
||||
if not real_path.startswith(project_dir + os.sep):
|
||||
logger.warning("[Transcript] Session file escaped project dir: %s", real_path)
|
||||
return None
|
||||
try:
|
||||
async with aiofiles.open(real_path) as f:
|
||||
content = await f.read()
|
||||
logger.info(
|
||||
"[Transcript] Read CLI session file: %s (%d bytes)",
|
||||
real_path,
|
||||
len(content),
|
||||
logger.warning(
|
||||
f"[Transcript] Cleanup path escaped projects base: {project_dir}"
|
||||
)
|
||||
return content
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to read CLI session file: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""Remove the CLI's project directory for a specific working directory."""
|
||||
project_dir = _cli_project_dir(sdk_cwd)
|
||||
if not project_dir:
|
||||
return
|
||||
|
||||
if os.path.isdir(project_dir):
|
||||
shutil.rmtree(project_dir, ignore_errors=True)
|
||||
logger.debug("[Transcript] Cleaned up CLI project dir: %s", project_dir)
|
||||
logger.debug(f"[Transcript] Cleaned up CLI project dir: {project_dir}")
|
||||
else:
|
||||
logger.debug("[Transcript] Project dir not found: %s", project_dir)
|
||||
logger.debug(f"[Transcript] Project dir not found: {project_dir}")
|
||||
|
||||
|
||||
def write_transcript_to_tempfile(
|
||||
@@ -246,7 +180,7 @@ def write_transcript_to_tempfile(
|
||||
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
|
||||
real_cwd = os.path.realpath(cwd)
|
||||
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
|
||||
logger.warning("[Transcript] cwd outside sandbox: %s", cwd)
|
||||
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -256,17 +190,17 @@ def write_transcript_to_tempfile(
|
||||
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
|
||||
)
|
||||
if not jsonl_path.startswith(real_cwd):
|
||||
logger.warning("[Transcript] Path escaped cwd: %s", jsonl_path)
|
||||
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
|
||||
return None
|
||||
|
||||
with open(jsonl_path, "w") as f:
|
||||
f.write(transcript_content)
|
||||
|
||||
logger.info("[Transcript] Wrote resume file: %s", jsonl_path)
|
||||
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
|
||||
return jsonl_path
|
||||
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to write resume file: %s", e)
|
||||
logger.warning(f"[Transcript] Failed to write resume file: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -410,14 +344,11 @@ async def upload_transcript(
|
||||
content=json.dumps(meta).encode("utf-8"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
|
||||
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
|
||||
|
||||
logger.info(
|
||||
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
|
||||
log_prefix,
|
||||
len(encoded),
|
||||
len(content),
|
||||
message_count,
|
||||
f"{log_prefix} Uploaded {len(encoded)}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count})"
|
||||
)
|
||||
|
||||
|
||||
@@ -440,10 +371,10 @@ async def download_transcript(
|
||||
data = await storage.retrieve(path)
|
||||
content = data.decode("utf-8")
|
||||
except FileNotFoundError:
|
||||
logger.debug("%s No transcript in storage", log_prefix)
|
||||
logger.debug(f"{log_prefix} No transcript in storage")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
|
||||
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
|
||||
return None
|
||||
|
||||
# Try to load metadata (best-effort — old transcripts won't have it)
|
||||
@@ -463,14 +394,10 @@ async def download_transcript(
|
||||
meta = json.loads(meta_data.decode("utf-8"), fallback={})
|
||||
message_count = meta.get("message_count", 0)
|
||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||
except FileNotFoundError:
|
||||
except (FileNotFoundError, Exception):
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
except Exception as e:
|
||||
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
|
||||
|
||||
logger.info(
|
||||
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
|
||||
)
|
||||
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
|
||||
return TranscriptDownload(
|
||||
content=content,
|
||||
message_count=message_count,
|
||||
@@ -478,171 +405,15 @@ async def download_transcript(
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transcript compaction
|
||||
# ---------------------------------------------------------------------------
|
||||
async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
"""Delete transcript from bucket storage (e.g. after resume failure)."""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
# Transcripts above this byte threshold are compacted at download time.
|
||||
COMPACT_THRESHOLD_BYTES = 400_000
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
|
||||
def _flatten_assistant_content(blocks: list) -> str:
|
||||
"""Flatten assistant content blocks into a single plain-text string."""
|
||||
parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict):
|
||||
if block.get("type") == "text":
|
||||
parts.append(block.get("text", ""))
|
||||
elif block.get("type") == "tool_use":
|
||||
parts.append(f"[tool_use: {block.get('name', '?')}]")
|
||||
elif isinstance(block, str):
|
||||
parts.append(block)
|
||||
return "\n".join(parts) if parts else ""
|
||||
|
||||
|
||||
def _flatten_tool_result_content(blocks: list) -> str:
|
||||
"""Flatten tool_result and other content blocks into plain text.
|
||||
|
||||
Handles nested tool_result structures, text blocks, and raw strings.
|
||||
Uses ``json.dumps`` as fallback for dict blocks without a ``text`` key
|
||||
or where ``text`` is ``None``.
|
||||
"""
|
||||
str_parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
||||
inner = block.get("content", "")
|
||||
if isinstance(inner, list):
|
||||
for sub in inner:
|
||||
if isinstance(sub, dict):
|
||||
text = sub.get("text")
|
||||
str_parts.append(
|
||||
str(text) if text is not None else json.dumps(sub)
|
||||
)
|
||||
else:
|
||||
str_parts.append(str(sub))
|
||||
else:
|
||||
str_parts.append(str(inner))
|
||||
elif isinstance(block, dict) and block.get("type") == "text":
|
||||
str_parts.append(str(block.get("text", "")))
|
||||
elif isinstance(block, str):
|
||||
str_parts.append(block)
|
||||
return "\n".join(str_parts) if str_parts else ""
|
||||
|
||||
|
||||
def _transcript_to_messages(content: str) -> list[dict]:
|
||||
"""Convert JSONL transcript entries to message dicts for compress_context."""
|
||||
messages: list[dict] = []
|
||||
for line in content.strip().split("\n"):
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
|
||||
"isCompactSummary"
|
||||
):
|
||||
continue
|
||||
msg = entry.get("message", {})
|
||||
role = msg.get("role", "")
|
||||
if not role:
|
||||
continue
|
||||
msg_dict: dict = {"role": role}
|
||||
raw_content = msg.get("content")
|
||||
if role == "assistant" and isinstance(raw_content, list):
|
||||
msg_dict["content"] = _flatten_assistant_content(raw_content)
|
||||
elif isinstance(raw_content, list):
|
||||
msg_dict["content"] = _flatten_tool_result_content(raw_content)
|
||||
else:
|
||||
msg_dict["content"] = raw_content or ""
|
||||
messages.append(msg_dict)
|
||||
return messages
|
||||
|
||||
|
||||
def _messages_to_transcript(messages: list[dict]) -> str:
|
||||
"""Convert compressed message dicts back to JSONL transcript format."""
|
||||
lines: list[str] = []
|
||||
last_uuid: str | None = None
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
uid = str(uuid4())
|
||||
content = msg.get("content", "")
|
||||
if role == "assistant":
|
||||
message: dict = {
|
||||
"role": "assistant",
|
||||
"model": "",
|
||||
"id": f"{COMPACT_MSG_ID_PREFIX}{uuid4().hex[:24]}",
|
||||
"type": ENTRY_TYPE_MESSAGE,
|
||||
"content": [{"type": "text", "text": content}] if content else [],
|
||||
"stop_reason": STOP_REASON_END_TURN,
|
||||
"stop_sequence": None,
|
||||
}
|
||||
else:
|
||||
message = {"role": role, "content": content}
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": last_uuid,
|
||||
"message": message,
|
||||
}
|
||||
lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
last_uuid = uid
|
||||
return "\n".join(lines) + "\n" if lines else ""
|
||||
|
||||
|
||||
async def _run_compression(
|
||||
messages: list[dict],
|
||||
model: str,
|
||||
cfg: ChatConfig,
|
||||
log_prefix: str,
|
||||
) -> CompressResult:
|
||||
"""Run LLM-based compression with truncation fallback."""
|
||||
try:
|
||||
async with openai.AsyncOpenAI(
|
||||
api_key=cfg.api_key, base_url=cfg.base_url, timeout=30.0
|
||||
) as client:
|
||||
return await compress_context(messages=messages, model=model, client=client)
|
||||
await storage.delete(path)
|
||||
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
|
||||
except Exception as e:
|
||||
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
|
||||
return await compress_context(messages=messages, model=model, client=None)
|
||||
|
||||
|
||||
async def compact_transcript(
|
||||
content: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> str | None:
|
||||
"""Compact an oversized JSONL transcript using LLM summarization.
|
||||
|
||||
Converts transcript entries to plain messages, runs ``compress_context``
|
||||
(the same compressor used for pre-query history), and rebuilds JSONL.
|
||||
|
||||
Returns the compacted JSONL string, or ``None`` on failure.
|
||||
"""
|
||||
cfg = ChatConfig()
|
||||
messages = _transcript_to_messages(content)
|
||||
if len(messages) < 2:
|
||||
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
|
||||
return None
|
||||
try:
|
||||
result = await _run_compression(messages, cfg.model, cfg, log_prefix)
|
||||
if not result.was_compacted:
|
||||
logger.info("%s Transcript already within token budget", log_prefix)
|
||||
return content
|
||||
logger.info(
|
||||
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
|
||||
log_prefix,
|
||||
result.original_token_count,
|
||||
result.token_count,
|
||||
result.messages_summarized,
|
||||
result.messages_dropped,
|
||||
)
|
||||
compacted = _messages_to_transcript(result.messages)
|
||||
if not validate_transcript(compacted):
|
||||
logger.warning("%s Compacted transcript failed validation", log_prefix)
|
||||
return None
|
||||
return compacted
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
|
||||
)
|
||||
return None
|
||||
logger.warning(f"[Transcript] Failed to delete transcript: {e}")
|
||||
|
||||
@@ -31,7 +31,6 @@ class TranscriptEntry(BaseModel):
|
||||
uuid: str
|
||||
parentUuid: str | None
|
||||
message: dict[str, Any]
|
||||
isCompactSummary: bool | None = None
|
||||
|
||||
|
||||
class TranscriptBuilder:
|
||||
@@ -79,12 +78,10 @@ class TranscriptBuilder:
|
||||
)
|
||||
continue
|
||||
|
||||
# Skip STRIPPABLE_TYPES unless the entry is a compaction summary.
|
||||
# Compaction summaries may have type "summary" but must be preserved
|
||||
# so --resume can reconstruct the compacted conversation.
|
||||
# Load all non-strippable entries (user/assistant/system/etc.)
|
||||
# Skip only STRIPPABLE_TYPES to match strip_progress_entries() behavior
|
||||
entry_type = data.get("type", "")
|
||||
is_compact = data.get("isCompactSummary", False)
|
||||
if entry_type in STRIPPABLE_TYPES and not is_compact:
|
||||
if entry_type in STRIPPABLE_TYPES:
|
||||
continue
|
||||
|
||||
entry = TranscriptEntry(
|
||||
@@ -92,7 +89,6 @@ class TranscriptBuilder:
|
||||
uuid=data.get("uuid") or str(uuid4()),
|
||||
parentUuid=data.get("parentUuid"),
|
||||
message=data.get("message", {}),
|
||||
isCompactSummary=True if is_compact else None,
|
||||
)
|
||||
self._entries.append(entry)
|
||||
self._last_uuid = entry.uuid
|
||||
@@ -181,33 +177,6 @@ class TranscriptBuilder:
|
||||
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
def replace_entries(self, content: str, log_prefix: str = "[Transcript]") -> None:
|
||||
"""Replace all entries with compacted JSONL content.
|
||||
|
||||
Called after the CLI performs mid-stream compaction so the builder's
|
||||
state reflects the compacted conversation instead of the full
|
||||
pre-compaction history.
|
||||
"""
|
||||
prev_count = len(self._entries)
|
||||
temp = TranscriptBuilder()
|
||||
try:
|
||||
temp.load_previous(content, log_prefix=log_prefix)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"%s Failed to parse compacted transcript; keeping %d existing entries",
|
||||
log_prefix,
|
||||
prev_count,
|
||||
)
|
||||
return
|
||||
self._entries = temp._entries
|
||||
self._last_uuid = temp._last_uuid
|
||||
logger.info(
|
||||
"%s Replaced %d entries with %d compacted entries",
|
||||
log_prefix,
|
||||
prev_count,
|
||||
len(self._entries),
|
||||
)
|
||||
|
||||
@property
|
||||
def entry_count(self) -> int:
|
||||
"""Total number of entries in the complete context."""
|
||||
|
||||
@@ -2,25 +2,14 @@
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .transcript import (
|
||||
COMPACT_MSG_ID_PREFIX,
|
||||
STRIPPABLE_TYPES,
|
||||
_cli_project_dir,
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
read_cli_session_file,
|
||||
strip_progress_entries,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
)
|
||||
from .transcript_builder import TranscriptBuilder
|
||||
|
||||
|
||||
def _make_jsonl(*entries: dict) -> str:
|
||||
@@ -46,14 +35,6 @@ PROGRESS_ENTRY = {
|
||||
"data": {"type": "bash_progress", "stdout": "running..."},
|
||||
}
|
||||
|
||||
COMPACT_SUMMARY = {
|
||||
"type": "summary",
|
||||
"uuid": "cs1",
|
||||
"parentUuid": None,
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "Summary of previous conversation..."},
|
||||
}
|
||||
|
||||
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
|
||||
|
||||
|
||||
@@ -256,121 +237,6 @@ class TestStripProgressEntries:
|
||||
# Should return just a newline (empty content stripped)
|
||||
assert result.strip() == ""
|
||||
|
||||
|
||||
# --- _cli_project_dir ---
|
||||
|
||||
|
||||
class TestCliProjectDir:
|
||||
def test_returns_path_for_valid_cwd(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||
projects = tmp_path / "projects"
|
||||
projects.mkdir()
|
||||
result = _cli_project_dir("/tmp/copilot-abc")
|
||||
assert result is not None
|
||||
assert "projects" in result
|
||||
|
||||
def test_returns_none_for_path_traversal(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||
projects = tmp_path / "projects"
|
||||
projects.mkdir()
|
||||
# A cwd that encodes to something with .. shouldn't escape
|
||||
result = _cli_project_dir("/tmp/copilot-test")
|
||||
# Should return a valid path (no traversal possible with alphanum encoding)
|
||||
assert result is None or result.startswith(str(projects))
|
||||
|
||||
|
||||
# --- read_cli_session_file ---
|
||||
|
||||
|
||||
class TestReadCliSessionFile:
|
||||
@pytest.mark.asyncio
|
||||
async def test_reads_session_file(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||
# Create the CLI project directory structure
|
||||
cwd = "/tmp/copilot-testread"
|
||||
import re
|
||||
|
||||
encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
project_dir = tmp_path / "projects" / encoded
|
||||
project_dir.mkdir(parents=True)
|
||||
# Write a session file
|
||||
session_file = project_dir / "test-session.jsonl"
|
||||
session_file.write_text(json.dumps(ASST_MSG) + "\n")
|
||||
|
||||
result = await read_cli_session_file(cwd)
|
||||
assert result is not None
|
||||
assert "assistant" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_no_files(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||
cwd = "/tmp/copilot-nofiles"
|
||||
import re
|
||||
|
||||
encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
project_dir = tmp_path / "projects" / encoded
|
||||
project_dir.mkdir(parents=True)
|
||||
# No jsonl files
|
||||
result = await read_cli_session_file(cwd)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_dir_missing(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||
(tmp_path / "projects").mkdir()
|
||||
result = await read_cli_session_file("/tmp/copilot-nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- _transcript_to_messages / _messages_to_transcript ---
|
||||
|
||||
|
||||
class TestTranscriptMessageConversion:
|
||||
def test_roundtrip_preserves_roles(self):
|
||||
transcript = _make_jsonl(USER_MSG, ASST_MSG)
|
||||
messages = _transcript_to_messages(transcript)
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["role"] == "user"
|
||||
assert messages[1]["role"] == "assistant"
|
||||
|
||||
def test_messages_to_transcript_produces_valid_jsonl(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": "hello"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
assert validate_transcript(result) is True
|
||||
|
||||
def test_strips_strippable_types(self):
|
||||
transcript = _make_jsonl(
|
||||
{"type": "progress", "uuid": "p1", "message": {"role": "user"}},
|
||||
USER_MSG,
|
||||
ASST_MSG,
|
||||
)
|
||||
messages = _transcript_to_messages(transcript)
|
||||
assert len(messages) == 2 # progress entry skipped
|
||||
|
||||
def test_flattens_assistant_content_blocks(self):
|
||||
asst_with_blocks = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{"type": "tool_use", "name": "bash"},
|
||||
],
|
||||
},
|
||||
}
|
||||
messages = _transcript_to_messages(_make_jsonl(asst_with_blocks))
|
||||
assert len(messages) == 1
|
||||
assert "hello" in messages[0]["content"]
|
||||
assert "[tool_use: bash]" in messages[0]["content"]
|
||||
|
||||
def test_empty_messages_returns_empty(self):
|
||||
result = _messages_to_transcript([])
|
||||
assert result == ""
|
||||
|
||||
def test_no_strippable_entries(self):
|
||||
"""When there's nothing to strip, output matches input structure."""
|
||||
content = _make_jsonl(USER_MSG, ASST_MSG)
|
||||
@@ -416,654 +282,3 @@ class TestTranscriptMessageConversion:
|
||||
lines = result.strip().split("\n")
|
||||
asst_entry = json.loads(lines[-1])
|
||||
assert asst_entry["parentUuid"] == "u1" # reparented
|
||||
|
||||
|
||||
# --- TranscriptBuilder ---
|
||||
|
||||
|
||||
class TestTranscriptBuilderReplaceEntries:
|
||||
"""Tests for TranscriptBuilder.replace_entries — the compaction sync path."""
|
||||
|
||||
def test_replace_entries_with_valid_content(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hello")
|
||||
builder.append_assistant([{"type": "text", "text": "world"}])
|
||||
assert builder.entry_count == 2
|
||||
|
||||
# Replace with compacted content (one user + one assistant)
|
||||
compacted = _make_jsonl(USER_MSG, ASST_MSG)
|
||||
builder.replace_entries(compacted)
|
||||
assert builder.entry_count == 2
|
||||
|
||||
def test_replace_entries_keeps_old_on_corrupt_content(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hello")
|
||||
assert builder.entry_count == 1
|
||||
|
||||
# Corrupt content that fails to parse
|
||||
builder.replace_entries("not valid json at all\n")
|
||||
# Should still have old entries (load_previous skips invalid lines,
|
||||
# but if ALL lines are invalid, temp builder is empty → exception path)
|
||||
assert builder.entry_count >= 0 # doesn't crash
|
||||
|
||||
def test_replace_entries_with_empty_content(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hello")
|
||||
assert builder.entry_count == 1
|
||||
|
||||
builder.replace_entries("")
|
||||
# Empty content → load_previous returns early → temp is empty
|
||||
# replace_entries swaps to empty (0 entries)
|
||||
assert builder.entry_count == 0
|
||||
|
||||
def test_replace_entries_filters_strippable_types(self):
|
||||
"""Strippable types (progress, file-history-snapshot) are filtered out."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hello")
|
||||
|
||||
content = _make_jsonl(
|
||||
{"type": "progress", "uuid": "p1", "message": {}},
|
||||
USER_MSG,
|
||||
ASST_MSG,
|
||||
)
|
||||
builder.replace_entries(content)
|
||||
# Only user + assistant should remain (progress filtered)
|
||||
assert builder.entry_count == 2
|
||||
|
||||
def test_replace_entries_preserves_uuids(self):
|
||||
builder = TranscriptBuilder()
|
||||
content = _make_jsonl(USER_MSG, ASST_MSG)
|
||||
builder.replace_entries(content)
|
||||
|
||||
jsonl = builder.to_jsonl()
|
||||
lines = jsonl.strip().split("\n")
|
||||
first = json.loads(lines[0])
|
||||
assert first["uuid"] == "u1"
|
||||
|
||||
|
||||
class TestTranscriptBuilderBasic:
|
||||
def test_append_user_and_assistant(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hi")
|
||||
builder.append_assistant([{"type": "text", "text": "hello"}])
|
||||
assert builder.entry_count == 2
|
||||
assert not builder.is_empty
|
||||
|
||||
def test_to_jsonl_empty(self):
|
||||
builder = TranscriptBuilder()
|
||||
assert builder.to_jsonl() == ""
|
||||
assert builder.is_empty
|
||||
|
||||
def test_load_previous_and_append(self):
|
||||
builder = TranscriptBuilder()
|
||||
content = _make_jsonl(USER_MSG, ASST_MSG)
|
||||
builder.load_previous(content)
|
||||
assert builder.entry_count == 2
|
||||
builder.append_user("new message")
|
||||
assert builder.entry_count == 3
|
||||
|
||||
def test_consecutive_assistant_entries_share_message_id(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hi")
|
||||
builder.append_assistant([{"type": "text", "text": "part1"}])
|
||||
builder.append_assistant([{"type": "text", "text": "part2"}])
|
||||
|
||||
jsonl = builder.to_jsonl()
|
||||
lines = jsonl.strip().split("\n")
|
||||
asst1 = json.loads(lines[1])
|
||||
asst2 = json.loads(lines[2])
|
||||
assert asst1["message"]["id"] == asst2["message"]["id"]
|
||||
|
||||
def test_non_consecutive_assistant_entries_get_new_id(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hi")
|
||||
builder.append_assistant([{"type": "text", "text": "response1"}])
|
||||
builder.append_user("followup")
|
||||
builder.append_assistant([{"type": "text", "text": "response2"}])
|
||||
|
||||
jsonl = builder.to_jsonl()
|
||||
lines = jsonl.strip().split("\n")
|
||||
asst1 = json.loads(lines[1])
|
||||
asst2 = json.loads(lines[3])
|
||||
assert asst1["message"]["id"] != asst2["message"]["id"]
|
||||
|
||||
|
||||
class TestCompactSummaryRoundtrip:
|
||||
"""Verify isCompactSummary survives export→reload roundtrip."""
|
||||
|
||||
def test_load_previous_preserves_compact_summary(self):
|
||||
"""Compaction summary with type 'summary' should not be stripped."""
|
||||
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(content)
|
||||
# summary type is in STRIPPABLE_TYPES, but isCompactSummary keeps it
|
||||
assert builder.entry_count == 3
|
||||
|
||||
def test_export_reload_preserves_compact_summary(self):
|
||||
"""Critical: isCompactSummary must survive to_jsonl → load_previous."""
|
||||
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
|
||||
builder1 = TranscriptBuilder()
|
||||
builder1.load_previous(content)
|
||||
assert builder1.entry_count == 3
|
||||
|
||||
exported = builder1.to_jsonl()
|
||||
# Verify isCompactSummary is in the exported JSONL
|
||||
first_line = json.loads(exported.strip().split("\n")[0])
|
||||
assert first_line.get("isCompactSummary") is True
|
||||
|
||||
# Reload and verify it's still preserved
|
||||
builder2 = TranscriptBuilder()
|
||||
builder2.load_previous(exported)
|
||||
assert builder2.entry_count == 3
|
||||
|
||||
def test_strip_progress_preserves_compact_summary(self):
|
||||
"""strip_progress_entries should keep isCompactSummary entries."""
|
||||
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
|
||||
stripped = strip_progress_entries(content)
|
||||
entries = [json.loads(line) for line in stripped.strip().split("\n")]
|
||||
types = [e.get("type") for e in entries]
|
||||
assert "summary" in types # Not stripped despite being in STRIPPABLE_TYPES
|
||||
compact = [e for e in entries if e.get("isCompactSummary")]
|
||||
assert len(compact) == 1
|
||||
|
||||
def test_regular_summary_still_stripped(self):
|
||||
"""Non-compact summaries should still be stripped."""
|
||||
regular_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "rs1",
|
||||
"summary": "Session summary",
|
||||
}
|
||||
content = _make_jsonl(regular_summary, USER_MSG, ASST_MSG)
|
||||
stripped = strip_progress_entries(content)
|
||||
entries = [json.loads(line) for line in stripped.strip().split("\n")]
|
||||
types = [e.get("type") for e in entries]
|
||||
assert "summary" not in types
|
||||
|
||||
def test_replace_entries_preserves_compact_summary(self):
|
||||
"""replace_entries should preserve isCompactSummary entries."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("old")
|
||||
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
|
||||
builder.replace_entries(content)
|
||||
assert builder.entry_count == 3
|
||||
|
||||
# Verify by re-exporting
|
||||
exported = builder.to_jsonl()
|
||||
first = json.loads(exported.strip().split("\n")[0])
|
||||
assert first.get("isCompactSummary") is True
|
||||
|
||||
|
||||
# --- _flatten_assistant_content ---
|
||||
|
||||
|
||||
class TestFlattenAssistantContent:
|
||||
def test_text_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "text", "text": "World"},
|
||||
]
|
||||
assert _flatten_assistant_content(blocks) == "Hello\nWorld"
|
||||
|
||||
def test_tool_use_blocks(self):
|
||||
blocks = [{"type": "tool_use", "name": "read_file", "id": "t1", "input": {}}]
|
||||
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
|
||||
|
||||
def test_mixed_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Let me read that."},
|
||||
{"type": "tool_use", "name": "read", "id": "t1", "input": {}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "Let me read that." in result
|
||||
assert "[tool_use: read]" in result
|
||||
|
||||
def test_string_blocks(self):
|
||||
"""Plain strings in the list should be included."""
|
||||
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
|
||||
|
||||
def test_empty_list(self):
|
||||
assert _flatten_assistant_content([]) == ""
|
||||
|
||||
def test_tool_use_missing_name(self):
|
||||
blocks = [{"type": "tool_use", "id": "t1", "input": {}}]
|
||||
assert _flatten_assistant_content(blocks) == "[tool_use: ?]"
|
||||
|
||||
|
||||
# --- _flatten_tool_result_content ---
|
||||
|
||||
|
||||
class TestFlattenToolResultContent:
|
||||
def test_tool_result_with_text(self):
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "t1",
|
||||
"content": [{"type": "text", "text": "file contents here"}],
|
||||
}
|
||||
]
|
||||
assert _flatten_tool_result_content(blocks) == "file contents here"
|
||||
|
||||
def test_tool_result_with_string_content(self):
|
||||
blocks = [
|
||||
{"type": "tool_result", "tool_use_id": "t1", "content": "simple result"}
|
||||
]
|
||||
assert _flatten_tool_result_content(blocks) == "simple result"
|
||||
|
||||
def test_tool_result_with_nested_list(self):
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "t1",
|
||||
"content": [
|
||||
{"type": "text", "text": "line 1"},
|
||||
{"type": "text", "text": "line 2"},
|
||||
],
|
||||
}
|
||||
]
|
||||
assert _flatten_tool_result_content(blocks) == "line 1\nline 2"
|
||||
|
||||
def test_text_blocks(self):
|
||||
blocks = [{"type": "text", "text": "some text"}]
|
||||
assert _flatten_tool_result_content(blocks) == "some text"
|
||||
|
||||
def test_string_items(self):
|
||||
assert _flatten_tool_result_content(["raw string"]) == "raw string"
|
||||
|
||||
def test_empty_list(self):
|
||||
assert _flatten_tool_result_content([]) == ""
|
||||
|
||||
def test_tool_result_none_text_uses_json(self):
|
||||
"""Dicts without text key fall back to json.dumps."""
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "t1",
|
||||
"content": [{"type": "image", "source": "data:..."}],
|
||||
}
|
||||
]
|
||||
result = _flatten_tool_result_content(blocks)
|
||||
assert "image" in result # json.dumps fallback includes the key
|
||||
|
||||
|
||||
# --- _transcript_to_messages ---
|
||||
|
||||
|
||||
class TestTranscriptToMessages:
|
||||
def test_basic_conversation(self):
|
||||
content = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hello"},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "hi there"}],
|
||||
},
|
||||
},
|
||||
)
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0] == {"role": "user", "content": "hello"}
|
||||
assert msgs[1] == {"role": "assistant", "content": "hi there"}
|
||||
|
||||
def test_strips_progress_entries(self):
|
||||
content = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
},
|
||||
{
|
||||
"type": "progress",
|
||||
"uuid": "p1",
|
||||
"message": {"role": "user", "content": "..."},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "ok"}],
|
||||
},
|
||||
},
|
||||
)
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0]["role"] == "user"
|
||||
assert msgs[1]["role"] == "assistant"
|
||||
|
||||
def test_preserves_compact_summaries(self):
|
||||
content = _make_jsonl(
|
||||
{
|
||||
"type": "summary",
|
||||
"uuid": "cs1",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "Summary of previous..."},
|
||||
},
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
},
|
||||
)
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0]["content"] == "Summary of previous..."
|
||||
|
||||
def test_strips_regular_summary(self):
|
||||
content = _make_jsonl(
|
||||
{
|
||||
"type": "summary",
|
||||
"uuid": "s1",
|
||||
"message": {"role": "user", "content": "Session summary"},
|
||||
},
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
},
|
||||
)
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0]["content"] == "hi"
|
||||
|
||||
def test_skips_entries_without_role(self):
|
||||
content = _make_jsonl(
|
||||
{"type": "user", "uuid": "u1", "message": {}},
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u2",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
},
|
||||
)
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 1
|
||||
|
||||
def test_tool_result_content(self):
|
||||
content = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "t1",
|
||||
"content": "file contents",
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 1
|
||||
assert "file contents" in msgs[0]["content"]
|
||||
|
||||
def test_empty_content(self):
|
||||
assert _transcript_to_messages("") == []
|
||||
assert _transcript_to_messages(" \n ") == []
|
||||
|
||||
def test_invalid_json_lines_skipped(self):
|
||||
content = '{"type":"user","uuid":"u1","message":{"role":"user","content":"hi"}}\nnot json\n'
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 1
|
||||
|
||||
|
||||
# --- _messages_to_transcript ---
|
||||
|
||||
|
||||
class TestMessagesToTranscript:
|
||||
def test_basic_roundtrip_structure(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi there"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
assert result.endswith("\n")
|
||||
lines = [json.loads(line) for line in result.strip().split("\n")]
|
||||
assert len(lines) == 2
|
||||
|
||||
# User entry
|
||||
assert lines[0]["type"] == "user"
|
||||
assert lines[0]["message"]["role"] == "user"
|
||||
assert lines[0]["message"]["content"] == "hello"
|
||||
assert lines[0]["parentUuid"] is None
|
||||
|
||||
# Assistant entry
|
||||
assert lines[1]["type"] == "assistant"
|
||||
assert lines[1]["message"]["role"] == "assistant"
|
||||
assert lines[1]["message"]["content"] == [{"type": "text", "text": "hi there"}]
|
||||
assert lines[1]["message"]["id"].startswith(COMPACT_MSG_ID_PREFIX)
|
||||
assert lines[1]["parentUuid"] == lines[0]["uuid"]
|
||||
|
||||
def test_parent_uuid_chain(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "q1"},
|
||||
{"role": "assistant", "content": "a1"},
|
||||
{"role": "user", "content": "q2"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
lines = [json.loads(line) for line in result.strip().split("\n")]
|
||||
assert lines[0]["parentUuid"] is None
|
||||
assert lines[1]["parentUuid"] == lines[0]["uuid"]
|
||||
assert lines[2]["parentUuid"] == lines[1]["uuid"]
|
||||
|
||||
def test_empty_messages(self):
|
||||
assert _messages_to_transcript([]) == ""
|
||||
|
||||
def test_assistant_empty_content(self):
|
||||
messages = [{"role": "assistant", "content": ""}]
|
||||
result = _messages_to_transcript(messages)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["message"]["content"] == []
|
||||
|
||||
def test_output_is_valid_transcript(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "world"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
assert validate_transcript(result)
|
||||
|
||||
|
||||
# --- _transcript_to_messages + _messages_to_transcript roundtrip ---
|
||||
|
||||
|
||||
class TestTranscriptCompactionRoundtrip:
|
||||
def test_content_preserved_through_roundtrip(self):
|
||||
"""Messages→transcript→messages preserves content."""
|
||||
original = [
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
{"role": "assistant", "content": "4"},
|
||||
{"role": "user", "content": "Thanks"},
|
||||
]
|
||||
transcript = _messages_to_transcript(original)
|
||||
recovered = _transcript_to_messages(transcript)
|
||||
assert len(recovered) == len(original)
|
||||
for orig, rec in zip(original, recovered):
|
||||
assert orig["role"] == rec["role"]
|
||||
assert orig["content"] == rec["content"]
|
||||
|
||||
def test_full_transcript_to_messages_and_back(self):
|
||||
"""Real-ish JSONL → messages → transcript → messages roundtrip."""
|
||||
source = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "explain python"},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Python is a programming language."}
|
||||
],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u2",
|
||||
"parentUuid": "a1",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "t1",
|
||||
"content": "output of ls",
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
msgs1 = _transcript_to_messages(source)
|
||||
assert len(msgs1) == 3
|
||||
|
||||
rebuilt = _messages_to_transcript(msgs1)
|
||||
msgs2 = _transcript_to_messages(rebuilt)
|
||||
assert len(msgs2) == len(msgs1)
|
||||
for m1, m2 in zip(msgs1, msgs2):
|
||||
assert m1["role"] == m2["role"]
|
||||
# Content may differ in format (list vs string) but text is preserved
|
||||
assert m1["content"] in m2["content"] or m2["content"] in m1["content"]
|
||||
|
||||
|
||||
# --- compact_transcript ---
|
||||
|
||||
|
||||
class TestCompactTranscript:
|
||||
@pytest.mark.asyncio
|
||||
async def test_too_few_messages_returns_none(self):
|
||||
"""Transcripts with < 2 messages can't be compacted."""
|
||||
single = _make_jsonl(
|
||||
{"type": "user", "uuid": "u1", "message": {"role": "user", "content": "hi"}}
|
||||
)
|
||||
result = await compact_transcript(single)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_transcript_returns_none(self):
|
||||
result = await compact_transcript("")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_produces_valid_transcript(self, monkeypatch):
|
||||
"""When compress_context compacts, result should be valid JSONL."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from backend.util.prompt import CompressResult
|
||||
|
||||
mock_result = CompressResult(
|
||||
messages=[
|
||||
{"role": "user", "content": "Summary of conversation"},
|
||||
{"role": "assistant", "content": "Acknowledged"},
|
||||
],
|
||||
token_count=50,
|
||||
was_compacted=True,
|
||||
original_token_count=5000,
|
||||
messages_summarized=10,
|
||||
messages_dropped=5,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
AsyncMock(return_value=mock_result),
|
||||
)
|
||||
|
||||
source = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "msg1"},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "reply1"}],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u2",
|
||||
"message": {"role": "user", "content": "msg2"},
|
||||
},
|
||||
)
|
||||
result = await compact_transcript(source)
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
|
||||
# Verify compacted content
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0]["content"] == "Summary of conversation"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_compaction_needed_returns_original(self, monkeypatch):
|
||||
"""When compress_context says no compaction needed, return original."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from backend.util.prompt import CompressResult
|
||||
|
||||
mock_result = CompressResult(
|
||||
messages=[], token_count=100, was_compacted=False, original_token_count=100
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
AsyncMock(return_value=mock_result),
|
||||
)
|
||||
|
||||
source = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "hello"}],
|
||||
},
|
||||
},
|
||||
)
|
||||
result = await compact_transcript(source)
|
||||
assert result == source # Unchanged
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compression_failure_returns_none(self, monkeypatch):
|
||||
"""When _run_compression raises, compact_transcript returns None."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
AsyncMock(side_effect=RuntimeError("LLM unavailable")),
|
||||
)
|
||||
|
||||
source = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "hello"}],
|
||||
},
|
||||
},
|
||||
)
|
||||
result = await compact_transcript(source)
|
||||
assert result is None
|
||||
|
||||
@@ -21,11 +21,9 @@ Lifecycle
|
||||
Cost control
|
||||
------------
|
||||
Sandboxes are created with a configurable ``on_timeout`` lifecycle action
|
||||
(default: ``"pause"``) and ``auto_resume`` (default: ``True``). The explicit
|
||||
per-turn ``pause_sandbox()`` call is the primary mechanism; the lifecycle
|
||||
timeout is a safety net (default: 5 min). ``auto_resume`` ensures that paused
|
||||
sandboxes wake transparently on SDK activity, making the aggressive safety-net
|
||||
timeout safe. Paused sandboxes are free.
|
||||
(default: ``"pause"``). The explicit per-turn ``pause_sandbox()`` call is the
|
||||
primary mechanism; the lifecycle setting is a safety net. Paused sandboxes are
|
||||
free.
|
||||
|
||||
The sandbox_id is stored in Redis. The same key doubles as a creation lock:
|
||||
a ``"creating"`` sentinel value is written with a short TTL while a new sandbox
|
||||
@@ -42,7 +40,6 @@ import logging
|
||||
from typing import Any, Awaitable, Callable, Literal
|
||||
|
||||
from e2b import AsyncSandbox
|
||||
from e2b.sandbox.sandbox_api import SandboxLifecycle
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
@@ -119,10 +116,9 @@ async def get_or_create_sandbox(
|
||||
removes the need for a separate lock key.
|
||||
|
||||
*timeout* controls how long the e2b sandbox may run continuously before
|
||||
the ``on_timeout`` lifecycle rule fires (default: 5 min).
|
||||
the ``on_timeout`` lifecycle rule fires (default: 3 h).
|
||||
*on_timeout* controls what happens on timeout: ``"pause"`` (default, free)
|
||||
or ``"kill"``. When ``"pause"``, ``auto_resume`` is enabled so paused
|
||||
sandboxes wake transparently on SDK activity.
|
||||
or ``"kill"``.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
key = _sandbox_key(session_id)
|
||||
@@ -160,15 +156,11 @@ async def get_or_create_sandbox(
|
||||
|
||||
# We hold the slot — create the sandbox.
|
||||
try:
|
||||
lifecycle = SandboxLifecycle(
|
||||
on_timeout=on_timeout,
|
||||
auto_resume=on_timeout == "pause",
|
||||
)
|
||||
sandbox = await AsyncSandbox.create(
|
||||
template=template,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
lifecycle=lifecycle,
|
||||
lifecycle={"on_timeout": on_timeout},
|
||||
)
|
||||
try:
|
||||
await _set_stored_sandbox_id(session_id, sandbox.sandbox_id)
|
||||
|
||||
@@ -157,17 +157,14 @@ class TestGetOrCreateSandbox:
|
||||
|
||||
assert result is new_sb
|
||||
mock_cls.create.assert_awaited_once()
|
||||
# Verify lifecycle: pause + auto_resume enabled
|
||||
# Verify lifecycle param is set
|
||||
_, kwargs = mock_cls.create.call_args
|
||||
assert kwargs.get("lifecycle") == {
|
||||
"on_timeout": "pause",
|
||||
"auto_resume": True,
|
||||
}
|
||||
assert kwargs.get("lifecycle") == {"on_timeout": "pause"}
|
||||
# sandbox_id should be saved to Redis
|
||||
redis.set.assert_awaited()
|
||||
|
||||
def test_create_with_on_timeout_kill(self):
|
||||
"""on_timeout='kill' disables auto_resume automatically."""
|
||||
"""on_timeout='kill' is passed through to AsyncSandbox.create."""
|
||||
new_sb = _mock_sandbox("sb-new")
|
||||
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
|
||||
with (
|
||||
@@ -182,10 +179,7 @@ class TestGetOrCreateSandbox:
|
||||
)
|
||||
|
||||
_, kwargs = mock_cls.create.call_args
|
||||
assert kwargs.get("lifecycle") == {
|
||||
"on_timeout": "kill",
|
||||
"auto_resume": False,
|
||||
}
|
||||
assert kwargs.get("lifecycle") == {"on_timeout": "kill"}
|
||||
|
||||
def test_create_failure_releases_slot(self):
|
||||
"""If sandbox creation fails, the Redis creation slot is deleted."""
|
||||
|
||||
@@ -8,16 +8,11 @@ from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
|
||||
from backend.data import db
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.executor.utils import block_usage_cost
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.exceptions import BlockError, InsufficientBalanceError
|
||||
from backend.util.type import coerce_inputs_to_schema
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
|
||||
from .utils import match_credentials_to_requirements
|
||||
@@ -25,26 +20,6 @@ from .utils import match_credentials_to_requirements
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_credits(user_id: str) -> int:
|
||||
"""Get user credits using the adapter pattern (RPC when Prisma unavailable)."""
|
||||
if not db.is_connected():
|
||||
return await get_database_manager_async_client().get_credits(user_id)
|
||||
credit_model = await get_user_credit_model(user_id)
|
||||
return await credit_model.get_credits(user_id)
|
||||
|
||||
|
||||
async def _spend_credits(
|
||||
user_id: str, cost: int, metadata: UsageTransactionMetadata
|
||||
) -> int:
|
||||
"""Spend user credits using the adapter pattern (RPC when Prisma unavailable)."""
|
||||
if not db.is_connected():
|
||||
return await get_database_manager_async_client().spend_credits(
|
||||
user_id, cost, metadata
|
||||
)
|
||||
credit_model = await get_user_credit_model(user_id)
|
||||
return await credit_model.spend_credits(user_id, cost, metadata)
|
||||
|
||||
|
||||
def get_inputs_from_schema(
|
||||
input_schema: dict[str, Any],
|
||||
exclude_fields: set[str] | None = None,
|
||||
@@ -136,23 +111,6 @@ async def execute_block(
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Coerce non-matching data types to the expected input schema.
|
||||
coerce_inputs_to_schema(input_data, block.input_schema)
|
||||
|
||||
# Pre-execution credit check
|
||||
cost, cost_filter = block_usage_cost(block, input_data)
|
||||
has_cost = cost > 0
|
||||
if has_cost:
|
||||
balance = await _get_credits(user_id)
|
||||
if balance < cost:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Insufficient credits to run '{block.name}'. "
|
||||
"Please top up your credits to continue."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
@@ -161,37 +119,6 @@ async def execute_block(
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
# Charge credits for block execution
|
||||
if has_cost:
|
||||
try:
|
||||
await _spend_credits(
|
||||
user_id=user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
block_id=block_id,
|
||||
block=block.name,
|
||||
input=cost_filter,
|
||||
reason="copilot_block_execution",
|
||||
),
|
||||
)
|
||||
except InsufficientBalanceError:
|
||||
logger.warning(
|
||||
"Post-exec credit charge failed for block %s (cost=%d)",
|
||||
block.name,
|
||||
cost,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Insufficient credits to complete '{block.name}'. "
|
||||
"Please top up your credits to continue."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=block_id,
|
||||
@@ -202,16 +129,16 @@ async def execute_block(
|
||||
)
|
||||
|
||||
except BlockError as e:
|
||||
logger.warning("Block execution failed: %s", e)
|
||||
logger.warning(f"Block execution failed: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Block execution failed: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error executing block: %s", e, exc_info=True)
|
||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="An unexpected error occurred while executing the block",
|
||||
message=f"Failed to execute block: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -1,506 +0,0 @@
|
||||
"""Tests for execute_block — credit charging and type coercion."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks._base import BlockType
|
||||
from backend.copilot.tools.helpers import execute_block
|
||||
from backend.copilot.tools.models import BlockOutputResponse, ErrorResponse
|
||||
|
||||
_USER = "test-user-helpers"
|
||||
_SESSION = "test-session-helpers"
|
||||
|
||||
|
||||
def _make_block(block_id: str = "block-1", name: str = "TestBlock"):
|
||||
"""Create a minimal mock block for execute_block()."""
|
||||
mock = MagicMock()
|
||||
mock.id = block_id
|
||||
mock.name = name
|
||||
mock.block_type = BlockType.STANDARD
|
||||
|
||||
mock.input_schema = MagicMock()
|
||||
mock.input_schema.get_credentials_fields_info.return_value = {}
|
||||
|
||||
async def _execute(
|
||||
input_data: dict, **kwargs: Any
|
||||
) -> AsyncIterator[tuple[str, Any]]:
|
||||
yield "result", "ok"
|
||||
|
||||
mock.execute = _execute
|
||||
return mock
|
||||
|
||||
|
||||
def _patch_workspace():
|
||||
"""Patch workspace_db to return a mock workspace."""
|
||||
mock_workspace = MagicMock()
|
||||
mock_workspace.id = "ws-1"
|
||||
mock_ws_db = MagicMock()
|
||||
mock_ws_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
|
||||
return patch("backend.copilot.tools.helpers.workspace_db", return_value=mock_ws_db)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Credit charging tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestExecuteBlockCreditCharging:
|
||||
async def test_charges_credits_when_cost_is_positive(self):
|
||||
"""Block with cost > 0 should call spend_credits after execution."""
|
||||
block = _make_block()
|
||||
mock_spend = AsyncMock()
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(10, {"key": "val"}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._get_credits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=100,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._spend_credits",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=mock_spend,
|
||||
),
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={"text": "hello"},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(result, BlockOutputResponse)
|
||||
assert result.success is True
|
||||
mock_spend.assert_awaited_once()
|
||||
call_kwargs = mock_spend.call_args.kwargs
|
||||
assert call_kwargs["cost"] == 10
|
||||
assert call_kwargs["metadata"].reason == "copilot_block_execution"
|
||||
|
||||
async def test_returns_error_when_insufficient_credits_before_exec(self):
|
||||
"""Pre-execution check should return ErrorResponse when balance < cost."""
|
||||
block = _make_block()
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(10, {}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._get_credits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=5, # balance < cost (10)
|
||||
),
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "Insufficient credits" in result.message
|
||||
|
||||
async def test_no_charge_when_cost_is_zero(self):
|
||||
"""Block with cost 0 should not call spend_credits."""
|
||||
block = _make_block()
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(0, {}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._get_credits",
|
||||
) as mock_get_credits,
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._spend_credits",
|
||||
) as mock_spend_credits,
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(result, BlockOutputResponse)
|
||||
assert result.success is True
|
||||
# Credit functions should not be called at all for zero-cost blocks
|
||||
mock_get_credits.assert_not_awaited()
|
||||
mock_spend_credits.assert_not_awaited()
|
||||
|
||||
async def test_returns_error_on_post_exec_insufficient_balance(self):
|
||||
"""If charging fails after execution, return ErrorResponse."""
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
|
||||
block = _make_block()
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(10, {}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._get_credits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=15, # passes pre-check
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._spend_credits",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=InsufficientBalanceError(
|
||||
"Low balance", _USER, 5, 10
|
||||
), # fails during actual charge (race with concurrent spend)
|
||||
),
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "Insufficient credits" in result.message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type coercion tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
|
||||
"""Create a mock input_schema with model_fields matching the given annotations."""
|
||||
schema = MagicMock()
|
||||
model_fields = {}
|
||||
for name, ann in annotations.items():
|
||||
field = MagicMock()
|
||||
field.annotation = ann
|
||||
model_fields[name] = field
|
||||
schema.model_fields = model_fields
|
||||
return schema
|
||||
|
||||
|
||||
def _make_coerce_block(
|
||||
block_id: str,
|
||||
name: str,
|
||||
annotations: dict[str, Any],
|
||||
outputs: dict[str, list[Any]] | None = None,
|
||||
) -> MagicMock:
|
||||
"""Create a mock block with typed annotations and a simple execute method."""
|
||||
block = MagicMock()
|
||||
block.id = block_id
|
||||
block.name = name
|
||||
block.input_schema = _make_block_schema(annotations)
|
||||
|
||||
captured_inputs: dict[str, Any] = {}
|
||||
|
||||
async def mock_execute(input_data: dict, **_kwargs: Any):
|
||||
captured_inputs.update(input_data)
|
||||
for output_name, values in (outputs or {"result": ["ok"]}).items():
|
||||
for v in values:
|
||||
yield output_name, v
|
||||
|
||||
block.execute = mock_execute
|
||||
block._captured_inputs = captured_inputs
|
||||
return block
|
||||
|
||||
|
||||
_TEST_SESSION_ID = "test-session-coerce"
|
||||
_TEST_USER_ID = "test-user-coerce"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_nested_list():
|
||||
"""JSON string → list[list[str]] (Google Sheets CSV import case)."""
|
||||
block = _make_coerce_block(
|
||||
"sheets-write",
|
||||
"Google Sheets Write",
|
||||
{"values": list[list[str]], "spreadsheet_id": str},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="sheets-write",
|
||||
input_data={
|
||||
"values": '[["Name","Score"],["Alice","90"],["Bob","85"]]',
|
||||
"spreadsheet_id": "abc123",
|
||||
},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert response.success is True
|
||||
assert block._captured_inputs["values"] == [
|
||||
["Name", "Score"],
|
||||
["Alice", "90"],
|
||||
["Bob", "85"],
|
||||
]
|
||||
assert isinstance(block._captured_inputs["values"], list)
|
||||
assert isinstance(block._captured_inputs["values"][0], list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_list():
|
||||
"""JSON string → list[str]."""
|
||||
block = _make_coerce_block(
|
||||
"list-block",
|
||||
"List Block",
|
||||
{"items": list[str]},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="list-block",
|
||||
input_data={"items": '["a","b","c"]'},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-2",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert block._captured_inputs["items"] == ["a", "b", "c"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_dict():
|
||||
"""JSON string → dict[str, str]."""
|
||||
block = _make_coerce_block(
|
||||
"dict-block",
|
||||
"Dict Block",
|
||||
{"config": dict[str, str]},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="dict-block",
|
||||
input_data={"config": '{"key": "value", "foo": "bar"}'},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-3",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert block._captured_inputs["config"] == {"key": "value", "foo": "bar"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_coercion_when_type_matches():
|
||||
"""Already-correct types pass through without coercion."""
|
||||
block = _make_coerce_block(
|
||||
"pass-through",
|
||||
"Pass Through",
|
||||
{"values": list[list[str]], "name": str},
|
||||
)
|
||||
|
||||
original_values = [["a", "b"], ["c", "d"]]
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="pass-through",
|
||||
input_data={"values": original_values, "name": "test"},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-4",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert block._captured_inputs["values"] == original_values
|
||||
assert block._captured_inputs["name"] == "test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_string_to_int():
|
||||
"""String number → int."""
|
||||
block = _make_coerce_block(
|
||||
"int-block",
|
||||
"Int Block",
|
||||
{"count": int},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="int-block",
|
||||
input_data={"count": "42"},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-5",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert block._captured_inputs["count"] == 42
|
||||
assert isinstance(block._captured_inputs["count"], int)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_skips_none_values():
|
||||
"""None values are not coerced (they may be optional fields)."""
|
||||
block = _make_coerce_block(
|
||||
"optional-block",
|
||||
"Optional Block",
|
||||
{"data": list[str], "label": str},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="optional-block",
|
||||
input_data={"label": "test"},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-6",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert "data" not in block._captured_inputs
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_union_type_preserves_valid_member():
|
||||
"""Union-typed fields should not be coerced when the value matches a member."""
|
||||
block = _make_coerce_block(
|
||||
"union-block",
|
||||
"Union Block",
|
||||
{"content": str | list[str]},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="union-block",
|
||||
input_data={"content": ["a", "b"]},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-7",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert block._captured_inputs["content"] == ["a", "b"]
|
||||
assert isinstance(block._captured_inputs["content"], list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_inner_elements_of_generic():
|
||||
"""Inner elements of generic containers are recursively coerced."""
|
||||
block = _make_coerce_block(
|
||||
"inner-coerce",
|
||||
"Inner Coerce",
|
||||
{"values": list[str]},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="inner-coerce",
|
||||
input_data={"values": [1, 2, 3]},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-8",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert block._captured_inputs["values"] == ["1", "2", "3"]
|
||||
assert all(isinstance(v, str) for v in block._captured_inputs["values"])
|
||||
@@ -512,10 +512,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
list_workspace_files = d.list_workspace_files
|
||||
soft_delete_workspace_file = d.soft_delete_workspace_file
|
||||
|
||||
# ============ Credits ============ #
|
||||
spend_credits = d.spend_credits
|
||||
get_credits = d.get_credits
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = d.get_business_understanding
|
||||
upsert_business_understanding = d.upsert_business_understanding
|
||||
|
||||
@@ -1,750 +0,0 @@
|
||||
import asyncio
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
from prisma.errors import UniqueViolationError
|
||||
from pydantic import BaseModel, EmailStr, TypeAdapter, ValidationError
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.data.model import User
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.tally import get_business_understanding_input_from_tally, mask_email
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
merge_business_understanding_data,
|
||||
)
|
||||
from backend.data.user import get_user_by_email, get_user_by_id
|
||||
from backend.executor.cluster_lock import AsyncClusterLock
|
||||
from backend.util.exceptions import (
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
PreconditionFailed,
|
||||
)
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_settings = Settings()
|
||||
|
||||
_WORKER_ID = f"{socket.gethostname()}:{os.getpid()}"
|
||||
|
||||
_tally_seed_tasks: dict[str, asyncio.Task] = {}
|
||||
_TALLY_STALE_SECONDS = 300
|
||||
_MAX_TALLY_ERROR_LENGTH = 200
|
||||
_email_adapter = TypeAdapter(EmailStr)
|
||||
|
||||
MAX_BULK_INVITE_FILE_BYTES = 1024 * 1024
|
||||
MAX_BULK_INVITE_ROWS = 500
|
||||
|
||||
|
||||
class InvitedUserRecord(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
status: prisma.enums.InvitedUserStatus
|
||||
auth_user_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tally_understanding: Optional[dict[str, Any]] = None
|
||||
tally_status: prisma.enums.TallyComputationStatus
|
||||
tally_computed_at: Optional[datetime] = None
|
||||
tally_error: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, invited_user: "prisma.models.InvitedUser") -> "InvitedUserRecord":
|
||||
payload = (
|
||||
invited_user.tallyUnderstanding
|
||||
if isinstance(invited_user.tallyUnderstanding, dict)
|
||||
else None
|
||||
)
|
||||
return cls(
|
||||
id=invited_user.id,
|
||||
email=invited_user.email,
|
||||
status=invited_user.status,
|
||||
auth_user_id=invited_user.authUserId,
|
||||
name=invited_user.name,
|
||||
tally_understanding=payload,
|
||||
tally_status=invited_user.tallyStatus,
|
||||
tally_computed_at=invited_user.tallyComputedAt,
|
||||
tally_error=invited_user.tallyError,
|
||||
created_at=invited_user.createdAt,
|
||||
updated_at=invited_user.updatedAt,
|
||||
)
|
||||
|
||||
|
||||
class BulkInvitedUserRowResult(BaseModel):
|
||||
row_number: int
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
||||
message: str
|
||||
invited_user: Optional[InvitedUserRecord] = None
|
||||
|
||||
|
||||
class BulkInvitedUsersResult(BaseModel):
|
||||
created_count: int
|
||||
skipped_count: int
|
||||
error_count: int
|
||||
results: list[BulkInvitedUserRowResult]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ParsedInviteRow:
|
||||
row_number: int
|
||||
email: str
|
||||
name: Optional[str]
|
||||
|
||||
|
||||
def normalize_email(email: str) -> str:
|
||||
return email.strip().lower()
|
||||
|
||||
|
||||
def _normalize_name(name: Optional[str]) -> Optional[str]:
|
||||
if name is None:
|
||||
return None
|
||||
normalized = name.strip()
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _default_profile_name(email: str, preferred_name: Optional[str]) -> str:
|
||||
if preferred_name:
|
||||
return preferred_name
|
||||
local_part = email.split("@", 1)[0].strip()
|
||||
return local_part or "user"
|
||||
|
||||
|
||||
def _sanitize_username_base(email: str) -> str:
|
||||
local_part = email.split("@", 1)[0].lower()
|
||||
sanitized = re.sub(r"[^a-z0-9-]", "", local_part)
|
||||
sanitized = sanitized.strip("-")
|
||||
return sanitized[:40] or "user"
|
||||
|
||||
|
||||
async def _generate_unique_profile_username(email: str, tx) -> str:
|
||||
base = _sanitize_username_base(email)
|
||||
|
||||
for _ in range(2):
|
||||
candidate = f"{base}-{uuid4().hex[:6]}"
|
||||
existing = await prisma.models.Profile.prisma(tx).find_unique(
|
||||
where={"username": candidate}
|
||||
)
|
||||
if existing is None:
|
||||
return candidate
|
||||
|
||||
raise RuntimeError(f"Unable to generate unique username for {email}")
|
||||
|
||||
|
||||
async def _ensure_default_profile(
|
||||
user_id: str,
|
||||
email: str,
|
||||
preferred_name: Optional[str],
|
||||
tx,
|
||||
) -> None:
|
||||
existing_profile = await prisma.models.Profile.prisma(tx).find_unique(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
if existing_profile is not None:
|
||||
return
|
||||
|
||||
username = await _generate_unique_profile_username(email, tx)
|
||||
await prisma.models.Profile.prisma(tx).create(
|
||||
data=prisma.types.ProfileCreateInput(
|
||||
userId=user_id,
|
||||
name=_default_profile_name(email, preferred_name),
|
||||
username=username,
|
||||
description="I'm new here",
|
||||
links=[],
|
||||
avatarUrl="",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _ensure_default_onboarding(user_id: str, tx) -> None:
|
||||
await prisma.models.UserOnboarding.prisma(tx).upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": prisma.types.UserOnboardingCreateInput(userId=user_id),
|
||||
"update": {},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _apply_tally_understanding(
|
||||
user_id: str,
|
||||
invited_user: "prisma.models.InvitedUser",
|
||||
tx,
|
||||
) -> None:
|
||||
if not isinstance(invited_user.tallyUnderstanding, dict):
|
||||
return
|
||||
|
||||
try:
|
||||
input_data = BusinessUnderstandingInput.model_validate(
|
||||
invited_user.tallyUnderstanding
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Malformed tallyUnderstanding for invited user %s; skipping",
|
||||
invited_user.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
payload = merge_business_understanding_data({}, input_data)
|
||||
await prisma.models.CoPilotUnderstanding.prisma(tx).upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "data": SafeJson(payload)},
|
||||
"update": {"data": SafeJson(payload)},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def list_invited_users(
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> tuple[list[InvitedUserRecord], int]:
|
||||
total = await prisma.models.InvitedUser.prisma().count()
|
||||
invited_users = await prisma.models.InvitedUser.prisma().find_many(
|
||||
order={"createdAt": "desc"},
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
return [InvitedUserRecord.from_db(iu) for iu in invited_users], total
|
||||
|
||||
|
||||
async def create_invited_user(
|
||||
email: str, name: Optional[str] = None
|
||||
) -> InvitedUserRecord:
|
||||
normalized_email = normalize_email(email)
|
||||
normalized_name = _normalize_name(name)
|
||||
|
||||
existing_user = await prisma.models.User.prisma().find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if existing_user is not None:
|
||||
raise PreconditionFailed("An active user with this email already exists")
|
||||
|
||||
existing_invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if existing_invited_user is not None:
|
||||
raise PreconditionFailed("An invited user with this email already exists")
|
||||
|
||||
try:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().create(
|
||||
data={
|
||||
"email": normalized_email,
|
||||
"name": normalized_name,
|
||||
"status": prisma.enums.InvitedUserStatus.INVITED,
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
|
||||
}
|
||||
)
|
||||
except UniqueViolationError:
|
||||
raise PreconditionFailed("An invited user with this email already exists")
|
||||
schedule_invited_user_tally_precompute(invited_user.id)
|
||||
return InvitedUserRecord.from_db(invited_user)
|
||||
|
||||
|
||||
async def revoke_invited_user(invited_user_id: str) -> InvitedUserRecord:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"id": invited_user_id}
|
||||
)
|
||||
if invited_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.CLAIMED:
|
||||
raise PreconditionFailed("Claimed invited users cannot be revoked")
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
||||
return InvitedUserRecord.from_db(invited_user)
|
||||
|
||||
revoked_user = await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={"status": prisma.enums.InvitedUserStatus.REVOKED},
|
||||
)
|
||||
if revoked_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
return InvitedUserRecord.from_db(revoked_user)
|
||||
|
||||
|
||||
async def retry_invited_user_tally(invited_user_id: str) -> InvitedUserRecord:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"id": invited_user_id}
|
||||
)
|
||||
if invited_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
||||
raise PreconditionFailed("Revoked invited users cannot retry Tally seeding")
|
||||
|
||||
refreshed_user = await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyUnderstanding": None,
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
|
||||
"tallyComputedAt": None,
|
||||
"tallyError": None,
|
||||
},
|
||||
)
|
||||
if refreshed_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
schedule_invited_user_tally_precompute(invited_user_id)
|
||||
return InvitedUserRecord.from_db(refreshed_user)
|
||||
|
||||
|
||||
def _decode_bulk_invite_file(content: bytes) -> str:
|
||||
if len(content) > MAX_BULK_INVITE_FILE_BYTES:
|
||||
raise ValueError("Invite file exceeds the maximum size of 1 MB")
|
||||
|
||||
try:
|
||||
return content.decode("utf-8-sig")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise ValueError("Invite file must be UTF-8 encoded") from exc
|
||||
|
||||
|
||||
def _parse_bulk_invite_csv(text: str) -> list[_ParsedInviteRow]:
|
||||
indexed_rows: list[tuple[int, list[str]]] = []
|
||||
|
||||
for row_number, row in enumerate(csv.reader(io.StringIO(text)), start=1):
|
||||
normalized_row = [cell.strip() for cell in row]
|
||||
if any(normalized_row):
|
||||
indexed_rows.append((row_number, normalized_row))
|
||||
|
||||
if not indexed_rows:
|
||||
return []
|
||||
|
||||
header = [cell.lower() for cell in indexed_rows[0][1]]
|
||||
has_header = "email" in header
|
||||
email_index = header.index("email") if has_header else 0
|
||||
name_index: Optional[int] = (
|
||||
header.index("name")
|
||||
if has_header and "name" in header
|
||||
else (1 if not has_header else None)
|
||||
)
|
||||
data_rows = indexed_rows[1:] if has_header else indexed_rows
|
||||
|
||||
parsed_rows: list[_ParsedInviteRow] = []
|
||||
for row_number, row in data_rows:
|
||||
if len(parsed_rows) >= MAX_BULK_INVITE_ROWS:
|
||||
break
|
||||
email = row[email_index].strip() if len(row) > email_index else ""
|
||||
name = (
|
||||
row[name_index].strip()
|
||||
if name_index is not None and len(row) > name_index
|
||||
else ""
|
||||
)
|
||||
parsed_rows.append(
|
||||
_ParsedInviteRow(
|
||||
row_number=row_number,
|
||||
email=email,
|
||||
name=name or None,
|
||||
)
|
||||
)
|
||||
|
||||
return parsed_rows
|
||||
|
||||
|
||||
def _parse_bulk_invite_text(text: str) -> list[_ParsedInviteRow]:
|
||||
parsed_rows: list[_ParsedInviteRow] = []
|
||||
|
||||
for row_number, raw_line in enumerate(text.splitlines(), start=1):
|
||||
if len(parsed_rows) >= MAX_BULK_INVITE_ROWS:
|
||||
break
|
||||
line = raw_line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
parsed_rows.append(
|
||||
_ParsedInviteRow(
|
||||
row_number=row_number,
|
||||
email=line,
|
||||
name=None,
|
||||
)
|
||||
)
|
||||
|
||||
return parsed_rows
|
||||
|
||||
|
||||
def _parse_bulk_invite_file(
|
||||
filename: Optional[str],
|
||||
content: bytes,
|
||||
) -> list[_ParsedInviteRow]:
|
||||
text = _decode_bulk_invite_file(content)
|
||||
file_name = filename.lower() if filename else ""
|
||||
parsed_rows = (
|
||||
_parse_bulk_invite_csv(text)
|
||||
if file_name.endswith(".csv")
|
||||
else _parse_bulk_invite_text(text)
|
||||
)
|
||||
|
||||
if not parsed_rows:
|
||||
raise ValueError("Invite file did not contain any emails")
|
||||
|
||||
return parsed_rows
|
||||
|
||||
|
||||
async def bulk_create_invited_users_from_file(
|
||||
filename: Optional[str],
|
||||
content: bytes,
|
||||
) -> BulkInvitedUsersResult:
|
||||
parsed_rows = _parse_bulk_invite_file(filename, content)
|
||||
|
||||
created_count = 0
|
||||
skipped_count = 0
|
||||
error_count = 0
|
||||
results: list[BulkInvitedUserRowResult] = []
|
||||
seen_emails: set[str] = set()
|
||||
|
||||
for row in parsed_rows:
|
||||
row_name = _normalize_name(row.name)
|
||||
|
||||
try:
|
||||
validated_email = _email_adapter.validate_python(row.email)
|
||||
except ValidationError:
|
||||
error_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=row.email or None,
|
||||
name=row_name,
|
||||
status="ERROR",
|
||||
message="Invalid email address",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
normalized_email = normalize_email(str(validated_email))
|
||||
if normalized_email in seen_emails:
|
||||
skipped_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="SKIPPED",
|
||||
message="Duplicate email in upload file",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
seen_emails.add(normalized_email)
|
||||
|
||||
try:
|
||||
invited_user = await create_invited_user(normalized_email, row_name)
|
||||
except PreconditionFailed as exc:
|
||||
skipped_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="SKIPPED",
|
||||
message=str(exc),
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
masked = mask_email(normalized_email)
|
||||
logger.exception(
|
||||
"Failed to create bulk invite for row %s (%s)",
|
||||
row.row_number,
|
||||
masked,
|
||||
)
|
||||
error_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="ERROR",
|
||||
message="Unexpected error creating invite",
|
||||
)
|
||||
)
|
||||
else:
|
||||
created_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="CREATED",
|
||||
message="Invite created",
|
||||
invited_user=invited_user,
|
||||
)
|
||||
)
|
||||
|
||||
return BulkInvitedUsersResult(
|
||||
created_count=created_count,
|
||||
skipped_count=skipped_count,
|
||||
error_count=error_count,
|
||||
results=results,
|
||||
)
|
||||
|
||||
|
||||
async def _compute_invited_user_tally_seed(invited_user_id: str) -> None:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"id": invited_user_id}
|
||||
)
|
||||
if invited_user is None:
|
||||
return
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
||||
return
|
||||
|
||||
try:
|
||||
r = await get_redis_async()
|
||||
except Exception:
|
||||
r = None
|
||||
|
||||
lock: AsyncClusterLock | None = None
|
||||
|
||||
if r is not None:
|
||||
lock = AsyncClusterLock(
|
||||
redis=r,
|
||||
key=f"tally_seed:{invited_user_id}",
|
||||
owner_id=_WORKER_ID,
|
||||
timeout=_TALLY_STALE_SECONDS,
|
||||
)
|
||||
current_owner = await lock.try_acquire()
|
||||
|
||||
if current_owner is None:
|
||||
logger.warn("Redis unvailable for tally lock - skipping tally enrichement")
|
||||
return
|
||||
elif current_owner != _WORKER_ID:
|
||||
logger.debug(
|
||||
"Tally seed for %s already locked by %s, skipping",
|
||||
invited_user_id,
|
||||
current_owner,
|
||||
)
|
||||
return
|
||||
if (
|
||||
invited_user.tallyStatus == prisma.enums.TallyComputationStatus.RUNNING
|
||||
and invited_user.updatedAt is not None
|
||||
):
|
||||
age = (datetime.now(timezone.utc) - invited_user.updatedAt).total_seconds()
|
||||
if age < _TALLY_STALE_SECONDS:
|
||||
logger.debug(
|
||||
"Tally task for %s still RUNNING (age=%ds), skipping",
|
||||
invited_user_id,
|
||||
int(age),
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"Tally task for %s is stale (age=%ds), re-running",
|
||||
invited_user_id,
|
||||
int(age),
|
||||
)
|
||||
|
||||
await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.RUNNING,
|
||||
"tallyError": None,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
input_data = await get_business_understanding_input_from_tally(
|
||||
invited_user.email,
|
||||
require_api_key=True,
|
||||
)
|
||||
payload = (
|
||||
SafeJson(input_data.model_dump(exclude_none=True))
|
||||
if input_data is not None
|
||||
else None
|
||||
)
|
||||
await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyUnderstanding": payload,
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.READY,
|
||||
"tallyComputedAt": datetime.now(timezone.utc),
|
||||
"tallyError": None,
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Failed to compute Tally understanding for invited user %s",
|
||||
invited_user_id,
|
||||
)
|
||||
sanitized_error = re.sub(
|
||||
r"https?://\S+", "<url>", f"{type(exc).__name__}: {exc}"
|
||||
)[:_MAX_TALLY_ERROR_LENGTH]
|
||||
await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.FAILED,
|
||||
"tallyError": sanitized_error,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def schedule_invited_user_tally_precompute(invited_user_id: str) -> None:
|
||||
existing = _tally_seed_tasks.get(invited_user_id)
|
||||
if existing is not None and not existing.done():
|
||||
logger.debug("Tally task already running for %s, skipping", invited_user_id)
|
||||
return
|
||||
|
||||
task = asyncio.create_task(_compute_invited_user_tally_seed(invited_user_id))
|
||||
_tally_seed_tasks[invited_user_id] = task
|
||||
|
||||
def _on_done(t: asyncio.Task, _id: str = invited_user_id) -> None:
|
||||
if _tally_seed_tasks.get(_id) is t:
|
||||
del _tally_seed_tasks[_id]
|
||||
|
||||
task.add_done_callback(_on_done)
|
||||
|
||||
|
||||
async def _open_signup_create_user(
|
||||
auth_user_id: str,
|
||||
normalized_email: str,
|
||||
metadata_name: Optional[str],
|
||||
) -> User:
|
||||
"""Create a user without requiring an invite (open signup mode)."""
|
||||
preferred_name = _normalize_name(metadata_name)
|
||||
try:
|
||||
async with transaction() as tx:
|
||||
user = await prisma.models.User.prisma(tx).create(
|
||||
data=prisma.types.UserCreateInput(
|
||||
id=auth_user_id,
|
||||
email=normalized_email,
|
||||
name=preferred_name,
|
||||
)
|
||||
)
|
||||
await _ensure_default_profile(
|
||||
auth_user_id, normalized_email, preferred_name, tx
|
||||
)
|
||||
await _ensure_default_onboarding(auth_user_id, tx)
|
||||
except UniqueViolationError:
|
||||
existing = await prisma.models.User.prisma().find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if existing is not None:
|
||||
return User.from_db(existing)
|
||||
raise
|
||||
|
||||
return User.from_db(user)
|
||||
|
||||
|
||||
# TODO: We need to change this functions logic before going live
|
||||
async def get_or_activate_user(user_data: dict) -> User:
|
||||
auth_user_id = user_data.get("sub")
|
||||
if not auth_user_id:
|
||||
raise NotAuthorizedError("User ID not found in token")
|
||||
|
||||
auth_email = user_data.get("email")
|
||||
if not auth_email:
|
||||
raise NotAuthorizedError("Email not found in token")
|
||||
|
||||
normalized_email = normalize_email(auth_email)
|
||||
user_metadata = user_data.get("user_metadata")
|
||||
metadata_name = (
|
||||
user_metadata.get("name") if isinstance(user_metadata, dict) else None
|
||||
)
|
||||
|
||||
existing_user = None
|
||||
try:
|
||||
existing_user = await get_user_by_id(auth_user_id)
|
||||
except ValueError:
|
||||
existing_user = None
|
||||
except Exception:
|
||||
logger.exception("Error on get user by id during tally enrichment process")
|
||||
raise
|
||||
|
||||
if existing_user is not None:
|
||||
return existing_user
|
||||
|
||||
if not _settings.config.enable_invite_gate or normalized_email.endswith("@agpt.co"):
|
||||
return await _open_signup_create_user(
|
||||
auth_user_id, normalized_email, metadata_name
|
||||
)
|
||||
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if invited_user is None:
|
||||
raise NotAuthorizedError("Your email is not allowed to access the platform")
|
||||
|
||||
if invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
|
||||
raise NotAuthorizedError("Your invitation is no longer active")
|
||||
|
||||
try:
|
||||
async with transaction() as tx:
|
||||
current_user = await prisma.models.User.prisma(tx).find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if current_user is not None:
|
||||
return User.from_db(current_user)
|
||||
|
||||
current_invited_user = await prisma.models.InvitedUser.prisma(
|
||||
tx
|
||||
).find_unique(where={"email": normalized_email})
|
||||
if current_invited_user is None:
|
||||
raise NotAuthorizedError(
|
||||
"Your email is not allowed to access the platform"
|
||||
)
|
||||
|
||||
if current_invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
|
||||
raise NotAuthorizedError("Your invitation is no longer active")
|
||||
|
||||
if current_invited_user.authUserId not in (None, auth_user_id):
|
||||
raise NotAuthorizedError("Your invitation has already been claimed")
|
||||
|
||||
preferred_name = current_invited_user.name or _normalize_name(metadata_name)
|
||||
await prisma.models.User.prisma(tx).create(
|
||||
data=prisma.types.UserCreateInput(
|
||||
id=auth_user_id,
|
||||
email=normalized_email,
|
||||
name=preferred_name,
|
||||
)
|
||||
)
|
||||
|
||||
await prisma.models.InvitedUser.prisma(tx).update(
|
||||
where={"id": current_invited_user.id},
|
||||
data={
|
||||
"status": prisma.enums.InvitedUserStatus.CLAIMED,
|
||||
"authUserId": auth_user_id,
|
||||
},
|
||||
)
|
||||
|
||||
await _ensure_default_profile(
|
||||
auth_user_id,
|
||||
normalized_email,
|
||||
preferred_name,
|
||||
tx,
|
||||
)
|
||||
await _ensure_default_onboarding(auth_user_id, tx)
|
||||
await _apply_tally_understanding(auth_user_id, current_invited_user, tx)
|
||||
except UniqueViolationError:
|
||||
logger.info("Concurrent activation for user %s; re-fetching", auth_user_id)
|
||||
already_created = await prisma.models.User.prisma().find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if already_created is not None:
|
||||
return User.from_db(already_created)
|
||||
raise RuntimeError(
|
||||
f"UniqueViolationError during activation but user {auth_user_id} not found"
|
||||
)
|
||||
|
||||
get_user_by_id.cache_delete(auth_user_id)
|
||||
get_user_by_email.cache_delete(normalized_email)
|
||||
|
||||
activated_user = await prisma.models.User.prisma().find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if activated_user is None:
|
||||
raise RuntimeError(
|
||||
f"Activated user {auth_user_id} was not found after creation"
|
||||
)
|
||||
|
||||
return User.from_db(activated_user)
|
||||
@@ -1,335 +0,0 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.util.exceptions import NotAuthorizedError, PreconditionFailed
|
||||
|
||||
from .invited_user import (
|
||||
InvitedUserRecord,
|
||||
bulk_create_invited_users_from_file,
|
||||
create_invited_user,
|
||||
get_or_activate_user,
|
||||
retry_invited_user_tally,
|
||||
)
|
||||
|
||||
|
||||
def _invited_user_db_record(
|
||||
*,
|
||||
status: prisma.enums.InvitedUserStatus = prisma.enums.InvitedUserStatus.INVITED,
|
||||
tally_understanding: dict | None = None,
|
||||
):
|
||||
now = datetime.now(timezone.utc)
|
||||
return SimpleNamespace(
|
||||
id="invite-1",
|
||||
email="invited@example.com",
|
||||
status=status,
|
||||
authUserId=None,
|
||||
name="Invited User",
|
||||
tallyUnderstanding=tally_understanding,
|
||||
tallyStatus=prisma.enums.TallyComputationStatus.PENDING,
|
||||
tallyComputedAt=None,
|
||||
tallyError=None,
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
)
|
||||
|
||||
|
||||
def _invited_user_record(
|
||||
*,
|
||||
status: prisma.enums.InvitedUserStatus = prisma.enums.InvitedUserStatus.INVITED,
|
||||
tally_understanding: dict | None = None,
|
||||
):
|
||||
return InvitedUserRecord.from_db(
|
||||
cast(
|
||||
prisma.models.InvitedUser,
|
||||
_invited_user_db_record(
|
||||
status=status,
|
||||
tally_understanding=tally_understanding,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _user_db_record():
|
||||
now = datetime.now(timezone.utc)
|
||||
return SimpleNamespace(
|
||||
id="auth-user-1",
|
||||
email="invited@example.com",
|
||||
emailVerified=True,
|
||||
name="Invited User",
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
metadata={},
|
||||
integrations="",
|
||||
stripeCustomerId=None,
|
||||
topUpConfig=None,
|
||||
maxEmailsPerDay=3,
|
||||
notifyOnAgentRun=True,
|
||||
notifyOnZeroBalance=True,
|
||||
notifyOnLowBalance=True,
|
||||
notifyOnBlockExecutionFailed=True,
|
||||
notifyOnContinuousAgentError=True,
|
||||
notifyOnDailySummary=True,
|
||||
notifyOnWeeklySummary=True,
|
||||
notifyOnMonthlySummary=True,
|
||||
notifyOnAgentApproved=True,
|
||||
notifyOnAgentRejected=True,
|
||||
timezone="not-set",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invited_user_rejects_existing_active_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
user_repo = Mock()
|
||||
user_repo.find_unique = AsyncMock(return_value=_user_db_record())
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock()
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
with pytest.raises(PreconditionFailed):
|
||||
await create_invited_user("Invited@example.com")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invited_user_schedules_tally_seed(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
user_repo = Mock()
|
||||
user_repo.find_unique = AsyncMock(return_value=None)
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock(return_value=None)
|
||||
invited_user_repo.create = AsyncMock(return_value=_invited_user_db_record())
|
||||
schedule = mocker.patch(
|
||||
"backend.data.invited_user.schedule_invited_user_tally_precompute"
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
invited_user = await create_invited_user("Invited@example.com", "Invited User")
|
||||
|
||||
assert invited_user.email == "invited@example.com"
|
||||
invited_user_repo.create.assert_awaited_once()
|
||||
schedule.assert_called_once_with("invite-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_invited_user_tally_resets_state_and_schedules(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock(return_value=_invited_user_db_record())
|
||||
invited_user_repo.update = AsyncMock(return_value=_invited_user_db_record())
|
||||
schedule = mocker.patch(
|
||||
"backend.data.invited_user.schedule_invited_user_tally_precompute"
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
invited_user = await retry_invited_user_tally("invite-1")
|
||||
|
||||
assert invited_user.id == "invite-1"
|
||||
invited_user_repo.update.assert_awaited_once()
|
||||
schedule.assert_called_once_with("invite-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_activate_user_requires_invite(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock(return_value=None)
|
||||
|
||||
mock_get_user_by_id = AsyncMock(side_effect=ValueError("User not found"))
|
||||
mock_get_user_by_id.cache_delete = Mock()
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.get_user_by_id",
|
||||
mock_get_user_by_id,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user._settings.config.enable_invite_gate",
|
||||
True,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
with pytest.raises(NotAuthorizedError):
|
||||
await get_or_activate_user(
|
||||
{"sub": "auth-user-1", "email": "invited@example.com"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_activate_user_creates_user_from_invite(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
tx = object()
|
||||
invited_user = _invited_user_db_record(
|
||||
tally_understanding={"user_name": "Invited User", "industry": "Automation"}
|
||||
)
|
||||
created_user = _user_db_record()
|
||||
|
||||
outside_user_repo = Mock()
|
||||
# Only called once at post-transaction verification (line 741);
|
||||
# get_user_by_id (line 657) uses prisma.user.find_unique, not this mock.
|
||||
outside_user_repo.find_unique = AsyncMock(return_value=created_user)
|
||||
|
||||
inside_user_repo = Mock()
|
||||
inside_user_repo.find_unique = AsyncMock(return_value=None)
|
||||
inside_user_repo.create = AsyncMock(return_value=created_user)
|
||||
|
||||
outside_invited_repo = Mock()
|
||||
outside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
|
||||
|
||||
inside_invited_repo = Mock()
|
||||
inside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
|
||||
inside_invited_repo.update = AsyncMock(return_value=invited_user)
|
||||
|
||||
def user_prisma(client=None):
|
||||
return inside_user_repo if client is tx else outside_user_repo
|
||||
|
||||
def invited_user_prisma(client=None):
|
||||
return inside_invited_repo if client is tx else outside_invited_repo
|
||||
|
||||
@asynccontextmanager
|
||||
async def fake_transaction():
|
||||
yield tx
|
||||
|
||||
# Mock get_user_by_id since it uses prisma.user.find_unique (global client),
|
||||
# not prisma.models.User.prisma().find_unique which we mock above.
|
||||
mock_get_user_by_id = AsyncMock(side_effect=ValueError("User not found"))
|
||||
mock_get_user_by_id.cache_delete = Mock()
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.get_user_by_id",
|
||||
mock_get_user_by_id,
|
||||
)
|
||||
mock_get_user_by_email = AsyncMock()
|
||||
mock_get_user_by_email.cache_delete = Mock()
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.get_user_by_email",
|
||||
mock_get_user_by_email,
|
||||
)
|
||||
ensure_profile = mocker.patch(
|
||||
"backend.data.invited_user._ensure_default_profile",
|
||||
AsyncMock(),
|
||||
)
|
||||
ensure_onboarding = mocker.patch(
|
||||
"backend.data.invited_user._ensure_default_onboarding",
|
||||
AsyncMock(),
|
||||
)
|
||||
apply_tally = mocker.patch(
|
||||
"backend.data.invited_user._apply_tally_understanding",
|
||||
AsyncMock(),
|
||||
)
|
||||
mocker.patch("backend.data.invited_user.transaction", fake_transaction)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.User.prisma", side_effect=user_prisma
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
side_effect=invited_user_prisma,
|
||||
)
|
||||
|
||||
user = await get_or_activate_user(
|
||||
{
|
||||
"sub": "auth-user-1",
|
||||
"email": "Invited@example.com",
|
||||
"user_metadata": {"name": "Invited User"},
|
||||
}
|
||||
)
|
||||
|
||||
assert user.id == "auth-user-1"
|
||||
inside_user_repo.create.assert_awaited_once()
|
||||
inside_invited_repo.update.assert_awaited_once()
|
||||
ensure_profile.assert_awaited_once()
|
||||
ensure_onboarding.assert_awaited_once_with("auth-user-1", tx)
|
||||
apply_tally.assert_awaited_once_with("auth-user-1", invited_user, tx)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_create_invited_users_from_text_file(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
create_invited = mocker.patch(
|
||||
"backend.data.invited_user.create_invited_user",
|
||||
AsyncMock(
|
||||
side_effect=[
|
||||
_invited_user_record(),
|
||||
_invited_user_record(),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
result = await bulk_create_invited_users_from_file(
|
||||
"invites.txt",
|
||||
b"Invited@example.com\nsecond@example.com\n",
|
||||
)
|
||||
|
||||
assert result.created_count == 2
|
||||
assert result.skipped_count == 0
|
||||
assert result.error_count == 0
|
||||
assert [row.status for row in result.results] == ["CREATED", "CREATED"]
|
||||
assert create_invited.await_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_create_invited_users_handles_csv_duplicates_and_invalid_rows(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
create_invited = mocker.patch(
|
||||
"backend.data.invited_user.create_invited_user",
|
||||
AsyncMock(
|
||||
side_effect=[
|
||||
_invited_user_record(),
|
||||
PreconditionFailed("An invited user with this email already exists"),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
result = await bulk_create_invited_users_from_file(
|
||||
"invites.csv",
|
||||
(
|
||||
"email,name\n"
|
||||
"valid@example.com,Valid User\n"
|
||||
"not-an-email,Bad Row\n"
|
||||
"valid@example.com,Duplicate In File\n"
|
||||
"existing@example.com,Existing User\n"
|
||||
).encode("utf-8"),
|
||||
)
|
||||
|
||||
assert result.created_count == 1
|
||||
assert result.skipped_count == 2
|
||||
assert result.error_count == 1
|
||||
assert [row.status for row in result.results] == [
|
||||
"CREATED",
|
||||
"ERROR",
|
||||
"SKIPPED",
|
||||
"SKIPPED",
|
||||
]
|
||||
assert create_invited.await_count == 2
|
||||
@@ -41,7 +41,7 @@ _MAX_PAGES = 100
|
||||
_LLM_TIMEOUT = 30
|
||||
|
||||
|
||||
def mask_email(email: str) -> str:
|
||||
def _mask_email(email: str) -> str:
|
||||
"""Mask an email for safe logging: 'alice@example.com' -> 'a***e@example.com'."""
|
||||
try:
|
||||
local, domain = email.rsplit("@", 1)
|
||||
@@ -196,7 +196,8 @@ async def _refresh_cache(form_id: str) -> tuple[dict, list]:
|
||||
|
||||
Returns (email_index, questions).
|
||||
"""
|
||||
client = _make_tally_client(_settings.secrets.tally_api_key)
|
||||
settings = Settings()
|
||||
client = _make_tally_client(settings.secrets.tally_api_key)
|
||||
|
||||
redis = await get_redis_async()
|
||||
last_fetch_key = _LAST_FETCH_KEY.format(form_id=form_id)
|
||||
@@ -331,9 +332,6 @@ Fields:
|
||||
- current_software (list of strings): software/tools currently used
|
||||
- existing_automation (list of strings): existing automations
|
||||
- additional_notes (string): any additional context
|
||||
- suggested_prompts (list of 5 strings): short action prompts (each under 20 words) that would help \
|
||||
this person get started with automating their work. Should be specific to their industry, role, and \
|
||||
pain points; actionable and conversational in tone; focused on automation opportunities.
|
||||
|
||||
Form data:
|
||||
"""
|
||||
@@ -341,21 +339,21 @@ Form data:
|
||||
_EXTRACTION_SUFFIX = "\n\nReturn ONLY valid JSON."
|
||||
|
||||
|
||||
async def extract_business_understanding_from_tally(
|
||||
async def extract_business_understanding(
|
||||
formatted_text: str,
|
||||
) -> BusinessUnderstandingInput:
|
||||
"""
|
||||
Use an LLM to extract structured business understanding from form text.
|
||||
"""Use an LLM to extract structured business understanding from form text.
|
||||
|
||||
Raises on timeout or unparseable response so the caller can handle it.
|
||||
"""
|
||||
api_key = _settings.secrets.open_router_api_key
|
||||
settings = Settings()
|
||||
api_key = settings.secrets.open_router_api_key
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=OPENROUTER_BASE_URL)
|
||||
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
client.chat.completions.create(
|
||||
model=_settings.config.tally_extraction_llm_model,
|
||||
model="openai/gpt-4o-mini",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
@@ -380,57 +378,9 @@ async def extract_business_understanding_from_tally(
|
||||
|
||||
# Filter out null values before constructing
|
||||
cleaned = {k: v for k, v in data.items() if v is not None}
|
||||
|
||||
# Validate suggested_prompts: filter >20 words, keep top 3
|
||||
raw_prompts = cleaned.get("suggested_prompts", [])
|
||||
if isinstance(raw_prompts, list):
|
||||
valid = [
|
||||
p.strip()
|
||||
for p in raw_prompts
|
||||
if isinstance(p, str) and len(p.strip().split()) <= 20
|
||||
]
|
||||
# This will keep up to 3 suggestions
|
||||
short_prompts = valid[:3] if valid else None
|
||||
if short_prompts:
|
||||
cleaned["suggested_prompts"] = short_prompts
|
||||
else:
|
||||
# We dont want to add a None value suggested_prompts field
|
||||
cleaned.pop("suggested_prompts", None)
|
||||
else:
|
||||
# suggested_prompts must be a list - removing it as its not here
|
||||
cleaned.pop("suggested_prompts", None)
|
||||
|
||||
return BusinessUnderstandingInput(**cleaned)
|
||||
|
||||
|
||||
async def get_business_understanding_input_from_tally(
|
||||
email: str,
|
||||
*,
|
||||
require_api_key: bool = False,
|
||||
) -> Optional[BusinessUnderstandingInput]:
|
||||
if not _settings.secrets.tally_api_key:
|
||||
if require_api_key:
|
||||
raise RuntimeError("Tally API key is not configured")
|
||||
logger.debug("Tally: no API key configured, skipping")
|
||||
return None
|
||||
|
||||
masked = mask_email(email)
|
||||
result = await find_submission_by_email(TALLY_FORM_ID, email)
|
||||
if result is None:
|
||||
logger.debug(f"Tally: no submission found for {masked}")
|
||||
return None
|
||||
|
||||
submission, questions = result
|
||||
logger.info(f"Tally: found submission for {masked}, extracting understanding")
|
||||
|
||||
formatted = format_submission_for_llm(submission, questions)
|
||||
if not formatted.strip():
|
||||
logger.warning("Tally: formatted submission was empty, skipping")
|
||||
return None
|
||||
|
||||
return await extract_business_understanding_from_tally(formatted)
|
||||
|
||||
|
||||
async def populate_understanding_from_tally(user_id: str, email: str) -> None:
|
||||
"""Main orchestrator: check Tally for a matching submission and populate understanding.
|
||||
|
||||
@@ -445,10 +395,30 @@ async def populate_understanding_from_tally(user_id: str, email: str) -> None:
|
||||
)
|
||||
return
|
||||
|
||||
understanding_input = await get_business_understanding_input_from_tally(email)
|
||||
if understanding_input is None:
|
||||
# Check API key is configured
|
||||
settings = Settings()
|
||||
if not settings.secrets.tally_api_key:
|
||||
logger.debug("Tally: no API key configured, skipping")
|
||||
return
|
||||
|
||||
# Look up submission by email
|
||||
masked = _mask_email(email)
|
||||
result = await find_submission_by_email(TALLY_FORM_ID, email)
|
||||
if result is None:
|
||||
logger.debug(f"Tally: no submission found for {masked}")
|
||||
return
|
||||
|
||||
submission, questions = result
|
||||
logger.info(f"Tally: found submission for {masked}, extracting understanding")
|
||||
|
||||
# Format and extract
|
||||
formatted = format_submission_for_llm(submission, questions)
|
||||
if not formatted.strip():
|
||||
logger.warning("Tally: formatted submission was empty, skipping")
|
||||
return
|
||||
|
||||
understanding_input = await extract_business_understanding(formatted)
|
||||
|
||||
# Upsert into database
|
||||
await upsert_business_understanding(user_id, understanding_input)
|
||||
logger.info(f"Tally: successfully populated understanding for user {user_id}")
|
||||
|
||||
@@ -12,11 +12,11 @@ from backend.data.tally import (
|
||||
_build_email_index,
|
||||
_format_answer,
|
||||
_make_tally_client,
|
||||
_mask_email,
|
||||
_refresh_cache,
|
||||
extract_business_understanding_from_tally,
|
||||
extract_business_understanding,
|
||||
find_submission_by_email,
|
||||
format_submission_for_llm,
|
||||
mask_email,
|
||||
populate_understanding_from_tally,
|
||||
)
|
||||
|
||||
@@ -248,7 +248,7 @@ async def test_populate_understanding_skips_no_api_key():
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
@@ -284,7 +284,6 @@ async def test_populate_understanding_full_flow():
|
||||
],
|
||||
}
|
||||
mock_input = MagicMock()
|
||||
mock_input.suggested_prompts = ["Prompt 1", "Prompt 2", "Prompt 3"]
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -292,14 +291,14 @@ async def test_populate_understanding_full_flow():
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(submission, SAMPLE_QUESTIONS),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.extract_business_understanding_from_tally",
|
||||
"backend.data.tally.extract_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_input,
|
||||
) as mock_extract,
|
||||
@@ -332,14 +331,14 @@ async def test_populate_understanding_handles_llm_timeout():
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(submission, SAMPLE_QUESTIONS),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.extract_business_understanding_from_tally",
|
||||
"backend.data.tally.extract_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=asyncio.TimeoutError(),
|
||||
),
|
||||
@@ -357,13 +356,13 @@ async def test_populate_understanding_handles_llm_timeout():
|
||||
|
||||
|
||||
def test_mask_email():
|
||||
assert mask_email("alice@example.com") == "a***e@example.com"
|
||||
assert mask_email("ab@example.com") == "a***@example.com"
|
||||
assert mask_email("a@example.com") == "a***@example.com"
|
||||
assert _mask_email("alice@example.com") == "a***e@example.com"
|
||||
assert _mask_email("ab@example.com") == "a***@example.com"
|
||||
assert _mask_email("a@example.com") == "a***@example.com"
|
||||
|
||||
|
||||
def test_mask_email_invalid():
|
||||
assert mask_email("no-at-sign") == "***"
|
||||
assert _mask_email("no-at-sign") == "***"
|
||||
|
||||
|
||||
# ── Prompt construction (curly-brace safety) ─────────────────────────────────
|
||||
@@ -394,11 +393,11 @@ def test_extraction_prompt_no_format_placeholders():
|
||||
assert single_braces == [], f"Found format placeholders: {single_braces}"
|
||||
|
||||
|
||||
# ── extract_business_understanding_from_tally ────────────────────────────────────────────
|
||||
# ── extract_business_understanding ────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_from_tally_success():
|
||||
async def test_extract_business_understanding_success():
|
||||
"""Happy path: LLM returns valid JSON that maps to BusinessUnderstandingInput."""
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
@@ -407,13 +406,6 @@ async def test_extract_business_understanding_from_tally_success():
|
||||
"business_name": "Acme Corp",
|
||||
"industry": "Technology",
|
||||
"pain_points": ["manual reporting"],
|
||||
"suggested_prompts": [
|
||||
"Automate weekly reports",
|
||||
"Set up invoice processing",
|
||||
"Create a customer onboarding flow",
|
||||
"Track project deadlines automatically",
|
||||
"Send follow-up emails after meetings",
|
||||
],
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
@@ -423,56 +415,16 @@ async def test_extract_business_understanding_from_tally_success():
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
|
||||
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
result = await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
assert result.user_name == "Alice"
|
||||
assert result.business_name == "Acme Corp"
|
||||
assert result.industry == "Technology"
|
||||
assert result.pain_points == ["manual reporting"]
|
||||
# suggested_prompts validated and sliced to top 3
|
||||
assert result.suggested_prompts == [
|
||||
"Automate weekly reports",
|
||||
"Set up invoice processing",
|
||||
"Create a customer onboarding flow",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_from_tally_filters_long_prompts():
|
||||
"""Prompts exceeding 20 words are excluded and only top 3 are kept."""
|
||||
long_prompt = " ".join(["word"] * 21)
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
{
|
||||
"user_name": "Alice",
|
||||
"suggested_prompts": [
|
||||
long_prompt,
|
||||
"Short prompt one",
|
||||
long_prompt,
|
||||
"Short prompt two",
|
||||
"Short prompt three",
|
||||
"Short prompt four",
|
||||
],
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
|
||||
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
|
||||
assert result.suggested_prompts == [
|
||||
"Short prompt one",
|
||||
"Short prompt two",
|
||||
"Short prompt three",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_from_tally_filters_nulls():
|
||||
async def test_extract_business_understanding_filters_nulls():
|
||||
"""Null values from LLM should be excluded from the result."""
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
@@ -485,7 +437,7 @@ async def test_extract_business_understanding_from_tally_filters_nulls():
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
|
||||
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
result = await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
assert result.user_name == "Alice"
|
||||
assert result.business_name is None
|
||||
@@ -493,7 +445,7 @@ async def test_extract_business_understanding_from_tally_filters_nulls():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_from_tally_invalid_json():
|
||||
async def test_extract_business_understanding_invalid_json():
|
||||
"""Invalid JSON from LLM should raise JSONDecodeError."""
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = "not valid json {"
|
||||
@@ -507,11 +459,11 @@ async def test_extract_business_understanding_from_tally_invalid_json():
|
||||
patch("backend.data.tally.AsyncOpenAI", return_value=mock_client),
|
||||
pytest.raises(json.JSONDecodeError),
|
||||
):
|
||||
await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_from_tally_timeout():
|
||||
async def test_extract_business_understanding_timeout():
|
||||
"""LLM timeout should propagate as asyncio.TimeoutError."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.side_effect = asyncio.TimeoutError()
|
||||
@@ -521,7 +473,7 @@ async def test_extract_business_understanding_from_tally_timeout():
|
||||
patch("backend.data.tally._LLM_TIMEOUT", 0.001),
|
||||
pytest.raises(asyncio.TimeoutError),
|
||||
):
|
||||
await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
|
||||
# ── _refresh_cache ───────────────────────────────────────────────────────────
|
||||
@@ -540,7 +492,7 @@ async def test_refresh_cache_full_fetch():
|
||||
submissions = SAMPLE_SUBMISSIONS
|
||||
|
||||
with (
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
@@ -588,7 +540,7 @@ async def test_refresh_cache_incremental_fetch():
|
||||
new_submissions = [SAMPLE_SUBMISSIONS[0]] # Just Alice
|
||||
|
||||
with (
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
|
||||
@@ -86,11 +86,6 @@ class BusinessUnderstandingInput(pydantic.BaseModel):
|
||||
None, description="Any additional context"
|
||||
)
|
||||
|
||||
# Suggested prompts (UI-only, not included in system prompt)
|
||||
suggested_prompts: Optional[list[str]] = pydantic.Field(
|
||||
None, description="LLM-generated suggested prompts based on business context"
|
||||
)
|
||||
|
||||
|
||||
class BusinessUnderstanding(pydantic.BaseModel):
|
||||
"""Full business understanding model returned from database."""
|
||||
@@ -127,9 +122,6 @@ class BusinessUnderstanding(pydantic.BaseModel):
|
||||
# Additional context
|
||||
additional_notes: Optional[str] = None
|
||||
|
||||
# Suggested prompts (UI-only, not included in system prompt)
|
||||
suggested_prompts: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_record: CoPilotUnderstanding) -> "BusinessUnderstanding":
|
||||
"""Convert database record to Pydantic model."""
|
||||
@@ -157,7 +149,6 @@ class BusinessUnderstanding(pydantic.BaseModel):
|
||||
current_software=_json_to_list(business.get("current_software")),
|
||||
existing_automation=_json_to_list(business.get("existing_automation")),
|
||||
additional_notes=business.get("additional_notes"),
|
||||
suggested_prompts=_json_to_list(data.get("suggested_prompts")),
|
||||
)
|
||||
|
||||
|
||||
@@ -175,62 +166,6 @@ def _merge_lists(existing: list | None, new: list | None) -> list | None:
|
||||
return merged
|
||||
|
||||
|
||||
def merge_business_understanding_data(
|
||||
existing_data: dict[str, Any],
|
||||
input_data: BusinessUnderstandingInput,
|
||||
) -> dict[str, Any]:
|
||||
merged_data = dict(existing_data)
|
||||
|
||||
merged_business: dict[str, Any] = {}
|
||||
if isinstance(merged_data.get("business"), dict):
|
||||
merged_business = dict(merged_data["business"])
|
||||
|
||||
business_string_fields = [
|
||||
"job_title",
|
||||
"business_name",
|
||||
"industry",
|
||||
"business_size",
|
||||
"user_role",
|
||||
"additional_notes",
|
||||
]
|
||||
business_list_fields = [
|
||||
"key_workflows",
|
||||
"daily_activities",
|
||||
"pain_points",
|
||||
"bottlenecks",
|
||||
"manual_tasks",
|
||||
"automation_goals",
|
||||
"current_software",
|
||||
"existing_automation",
|
||||
]
|
||||
|
||||
if input_data.user_name is not None:
|
||||
merged_data["name"] = input_data.user_name
|
||||
|
||||
for field in business_string_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
merged_business[field] = value
|
||||
|
||||
for field in business_list_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_list = _json_to_list(merged_business.get(field))
|
||||
merged_list = _merge_lists(existing_list, value)
|
||||
merged_business[field] = merged_list
|
||||
|
||||
merged_business["version"] = 1
|
||||
merged_data["business"] = merged_business
|
||||
|
||||
# suggested_prompts lives at the top level (not under `business`) because
|
||||
# it's a UI-only artifact consumed by the frontend, not business understanding
|
||||
# data. The `business` sub-dict feeds the system prompt.
|
||||
if input_data.suggested_prompts is not None:
|
||||
merged_data["suggested_prompts"] = input_data.suggested_prompts
|
||||
|
||||
return merged_data
|
||||
|
||||
|
||||
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
|
||||
"""Get business understanding from Redis cache."""
|
||||
try:
|
||||
@@ -310,18 +245,63 @@ async def upsert_business_understanding(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
|
||||
# Get existing data structure or start fresh
|
||||
existing_data: dict[str, Any] = {}
|
||||
if existing and isinstance(existing.data, dict):
|
||||
existing_data = dict(existing.data)
|
||||
|
||||
merged_data = merge_business_understanding_data(existing_data, input_data)
|
||||
existing_business: dict[str, Any] = {}
|
||||
if isinstance(existing_data.get("business"), dict):
|
||||
existing_business = dict(existing_data["business"])
|
||||
|
||||
# Business fields (stored inside business object)
|
||||
business_string_fields = [
|
||||
"job_title",
|
||||
"business_name",
|
||||
"industry",
|
||||
"business_size",
|
||||
"user_role",
|
||||
"additional_notes",
|
||||
]
|
||||
business_list_fields = [
|
||||
"key_workflows",
|
||||
"daily_activities",
|
||||
"pain_points",
|
||||
"bottlenecks",
|
||||
"manual_tasks",
|
||||
"automation_goals",
|
||||
"current_software",
|
||||
"existing_automation",
|
||||
]
|
||||
|
||||
# Handle top-level name field
|
||||
if input_data.user_name is not None:
|
||||
existing_data["name"] = input_data.user_name
|
||||
|
||||
# Business string fields - overwrite if provided
|
||||
for field in business_string_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_business[field] = value
|
||||
|
||||
# Business list fields - merge with existing
|
||||
for field in business_list_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_list = _json_to_list(existing_business.get(field))
|
||||
merged = _merge_lists(existing_list, value)
|
||||
existing_business[field] = merged
|
||||
|
||||
# Set version and nest business data
|
||||
existing_business["version"] = 1
|
||||
existing_data["business"] = existing_business
|
||||
|
||||
# Upsert with the merged data
|
||||
record = await CoPilotUnderstanding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "data": SafeJson(merged_data)},
|
||||
"update": {"data": SafeJson(merged_data)},
|
||||
"create": {"userId": user_id, "data": SafeJson(existing_data)},
|
||||
"update": {"data": SafeJson(existing_data)},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
"""Tests for business understanding merge and format logic."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstanding,
|
||||
BusinessUnderstandingInput,
|
||||
format_understanding_for_prompt,
|
||||
merge_business_understanding_data,
|
||||
)
|
||||
|
||||
|
||||
def _make_input(**kwargs: Any) -> BusinessUnderstandingInput:
|
||||
"""Create a BusinessUnderstandingInput with only the specified fields."""
|
||||
return BusinessUnderstandingInput.model_validate(kwargs)
|
||||
|
||||
|
||||
# ─── merge_business_understanding_data: suggested_prompts ─────────────
|
||||
|
||||
|
||||
def test_merge_suggested_prompts_overwrites_existing():
|
||||
"""New suggested_prompts should fully replace existing ones (not append)."""
|
||||
existing = {
|
||||
"name": "Alice",
|
||||
"business": {"industry": "Tech", "version": 1},
|
||||
"suggested_prompts": ["Old prompt 1", "Old prompt 2"],
|
||||
}
|
||||
input_data = _make_input(
|
||||
suggested_prompts=["New prompt A", "New prompt B", "New prompt C"],
|
||||
)
|
||||
|
||||
result = merge_business_understanding_data(existing, input_data)
|
||||
|
||||
assert result["suggested_prompts"] == [
|
||||
"New prompt A",
|
||||
"New prompt B",
|
||||
"New prompt C",
|
||||
]
|
||||
|
||||
|
||||
def test_merge_suggested_prompts_none_preserves_existing():
|
||||
"""When input has suggested_prompts=None, existing prompts are preserved."""
|
||||
existing = {
|
||||
"name": "Alice",
|
||||
"business": {"industry": "Tech", "version": 1},
|
||||
"suggested_prompts": ["Keep me"],
|
||||
}
|
||||
input_data = _make_input(industry="Finance")
|
||||
|
||||
result = merge_business_understanding_data(existing, input_data)
|
||||
|
||||
assert result["suggested_prompts"] == ["Keep me"]
|
||||
assert result["business"]["industry"] == "Finance"
|
||||
|
||||
|
||||
def test_merge_suggested_prompts_added_to_empty_data():
|
||||
"""Suggested prompts are set at top level even when starting from empty data."""
|
||||
existing: dict[str, Any] = {}
|
||||
input_data = _make_input(suggested_prompts=["Prompt 1"])
|
||||
|
||||
result = merge_business_understanding_data(existing, input_data)
|
||||
|
||||
assert result["suggested_prompts"] == ["Prompt 1"]
|
||||
|
||||
|
||||
def test_merge_suggested_prompts_empty_list_overwrites():
|
||||
"""An explicit empty list should overwrite existing prompts."""
|
||||
existing: dict[str, Any] = {
|
||||
"suggested_prompts": ["Old prompt"],
|
||||
"business": {"version": 1},
|
||||
}
|
||||
input_data = _make_input(suggested_prompts=[])
|
||||
|
||||
result = merge_business_understanding_data(existing, input_data)
|
||||
|
||||
assert result["suggested_prompts"] == []
|
||||
|
||||
|
||||
# ─── format_understanding_for_prompt: excludes suggested_prompts ──────
|
||||
|
||||
|
||||
def test_format_understanding_excludes_suggested_prompts():
|
||||
"""suggested_prompts is UI-only and must NOT appear in the system prompt."""
|
||||
understanding = BusinessUnderstanding(
|
||||
id="test-id",
|
||||
user_id="user-1",
|
||||
created_at=datetime.now(tz=timezone.utc),
|
||||
updated_at=datetime.now(tz=timezone.utc),
|
||||
user_name="Alice",
|
||||
industry="Technology",
|
||||
suggested_prompts=["Automate reports", "Set up alerts", "Track KPIs"],
|
||||
)
|
||||
|
||||
formatted = format_understanding_for_prompt(understanding)
|
||||
|
||||
assert "Alice" in formatted
|
||||
assert "Technology" in formatted
|
||||
assert "suggested_prompts" not in formatted
|
||||
assert "Automate reports" not in formatted
|
||||
assert "Set up alerts" not in formatted
|
||||
assert "Track KPIs" not in formatted
|
||||
@@ -46,7 +46,7 @@ from backend.util.exceptions import (
|
||||
)
|
||||
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
|
||||
from backend.util.settings import Config
|
||||
from backend.util.type import coerce_inputs_to_schema
|
||||
from backend.util.type import convert
|
||||
|
||||
config = Config()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
|
||||
@@ -213,8 +213,11 @@ def validate_exec(
|
||||
if resolve_input:
|
||||
data = merge_execution_input(data)
|
||||
|
||||
# Coerce non-matching data types to the expected input schema.
|
||||
coerce_inputs_to_schema(data, schema)
|
||||
# Convert non-matching data types to the expected input schema.
|
||||
for name, data_type in schema.__annotations__.items():
|
||||
value = data.get(name)
|
||||
if (value is not None) and (type(value) is not data_type):
|
||||
data[name] = convert(value, data_type)
|
||||
|
||||
# Input data post-merge should contain all required fields from the schema.
|
||||
if missing_input := schema.get_missing_input(data):
|
||||
|
||||
@@ -70,9 +70,6 @@ def _msg_tokens(msg: dict, enc) -> int:
|
||||
# Count tool result tokens
|
||||
tool_call_tokens += _tok_len(item.get("tool_use_id", ""), enc)
|
||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||
elif isinstance(item, dict) and item.get("type") == "text":
|
||||
# Count text block tokens
|
||||
tool_call_tokens += _tok_len(item.get("text", ""), enc)
|
||||
elif isinstance(item, dict) and "content" in item:
|
||||
# Other content types with content field
|
||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||
@@ -148,14 +145,10 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
|
||||
if len(ids) <= max_tok:
|
||||
return text # nothing to do
|
||||
|
||||
# Need at least 3 tokens (head + ellipsis + tail) for meaningful truncation
|
||||
mid = enc.encode(" … ")
|
||||
if max_tok < 3:
|
||||
return enc.decode(mid)
|
||||
|
||||
# Split the allowance between the two ends:
|
||||
head = max_tok // 2 - 1 # -1 for the ellipsis
|
||||
tail = max_tok - head - 1
|
||||
mid = enc.encode(" … ")
|
||||
return enc.decode(ids[:head] + mid + ids[-tail:])
|
||||
|
||||
|
||||
@@ -403,7 +396,7 @@ def validate_and_remove_orphan_tool_responses(
|
||||
|
||||
if log_warning:
|
||||
logger.warning(
|
||||
"Removing %d orphan tool response(s): %s", len(orphan_ids), orphan_ids
|
||||
f"Removing {len(orphan_ids)} orphan tool response(s): {orphan_ids}"
|
||||
)
|
||||
|
||||
return _remove_orphan_tool_responses(messages, orphan_ids)
|
||||
@@ -495,9 +488,8 @@ def _ensure_tool_pairs_intact(
|
||||
# Some tool_call_ids couldn't be resolved - remove those tool responses
|
||||
# This shouldn't happen in normal operation but handles edge cases
|
||||
logger.warning(
|
||||
"Could not find assistant messages for tool_call_ids: %s. "
|
||||
"Removing orphan tool responses.",
|
||||
orphan_tool_call_ids,
|
||||
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
|
||||
"Removing orphan tool responses."
|
||||
)
|
||||
recent_messages = _remove_orphan_tool_responses(
|
||||
recent_messages, orphan_tool_call_ids
|
||||
@@ -505,8 +497,8 @@ def _ensure_tool_pairs_intact(
|
||||
|
||||
if messages_to_prepend:
|
||||
logger.info(
|
||||
"Extended recent messages by %d to preserve tool_call/tool_response pairs",
|
||||
len(messages_to_prepend),
|
||||
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
|
||||
f"tool_call/tool_response pairs"
|
||||
)
|
||||
return messages_to_prepend + recent_messages
|
||||
|
||||
@@ -694,15 +686,11 @@ async def compress_context(
|
||||
msgs = [summary_msg] + recent_msgs
|
||||
|
||||
logger.info(
|
||||
"Context summarized: %d -> %d tokens, summarized %d messages",
|
||||
original_count,
|
||||
total_tokens(),
|
||||
messages_summarized,
|
||||
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
|
||||
f"summarized {messages_summarized} messages"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Summarization failed, continuing with truncation: %s", e
|
||||
)
|
||||
logger.warning(f"Summarization failed, continuing with truncation: {e}")
|
||||
# Fall through to content truncation
|
||||
|
||||
# ---- STEP 2: Normalize content ----------------------------------------
|
||||
@@ -740,12 +728,6 @@ async def compress_context(
|
||||
# This is more granular than dropping all old messages at once.
|
||||
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
|
||||
deletable: list[int] = []
|
||||
# Count assistant messages to ensure we keep at least one
|
||||
assistant_indices: set[int] = {
|
||||
i
|
||||
for i in range(len(msgs))
|
||||
if msgs[i] is not None and msgs[i].get("role") == "assistant"
|
||||
}
|
||||
for i in range(1, len(msgs) - 1):
|
||||
msg = msgs[i]
|
||||
if (
|
||||
@@ -753,9 +735,6 @@ async def compress_context(
|
||||
and not _is_tool_message(msg)
|
||||
and not _is_objective_message(msg)
|
||||
):
|
||||
# Skip if this is the last remaining assistant message
|
||||
if msg.get("role") == "assistant" and len(assistant_indices) <= 1:
|
||||
continue
|
||||
deletable.append(i)
|
||||
if not deletable:
|
||||
break
|
||||
|
||||
@@ -89,10 +89,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
le=500,
|
||||
description="Thread pool size for FastAPI sync operations. All sync endpoints and dependencies automatically use this pool. Higher values support more concurrent sync operations but use more memory.",
|
||||
)
|
||||
tally_extraction_llm_model: str = Field(
|
||||
default="openai/gpt-4o-mini",
|
||||
description="OpenRouter model ID used for extracting business understanding from Tally form data",
|
||||
)
|
||||
ollama_host: str = Field(
|
||||
default="localhost:11434",
|
||||
description="Default Ollama host; exempted from SSRF checks.",
|
||||
@@ -121,10 +117,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default=True,
|
||||
description="If authentication is enabled or not",
|
||||
)
|
||||
enable_invite_gate: bool = Field(
|
||||
default=True,
|
||||
description="If the invite-only signup gate is enforced",
|
||||
)
|
||||
enable_credit: bool = Field(
|
||||
default=False,
|
||||
description="If user credit system is enabled or not",
|
||||
|
||||
@@ -249,87 +249,6 @@ def convert(value: Any, target_type: Any) -> Any:
|
||||
raise ConversionError(f"Failed to convert {value} to {target_type}") from e
|
||||
|
||||
|
||||
def _value_satisfies_type(value: Any, target: Any) -> bool:
|
||||
"""Check whether *value* already satisfies *target*, including inner elements.
|
||||
|
||||
For union types this checks each member; for generic container types it
|
||||
recursively checks that inner elements satisfy the type args (e.g. every
|
||||
element in a ``list[str]`` is a ``str``). Returns ``False`` when uncertain
|
||||
so the caller falls through to :func:`convert`.
|
||||
"""
|
||||
# typing.Any cannot be used with isinstance(); treat as always satisfied.
|
||||
if target is Any:
|
||||
return True
|
||||
|
||||
origin = get_origin(target)
|
||||
|
||||
if origin is Union or origin is types.UnionType:
|
||||
non_none = [a for a in get_args(target) if a is not type(None)]
|
||||
return any(_value_satisfies_type(value, member) for member in non_none)
|
||||
|
||||
# Generic container type (e.g. list[str], dict[str, int])
|
||||
if origin is not None:
|
||||
# Guard: origin may not be a runtime type (e.g. Literal)
|
||||
if not isinstance(origin, type):
|
||||
return False
|
||||
if not isinstance(value, origin):
|
||||
return False
|
||||
args = get_args(target)
|
||||
if not args:
|
||||
return True
|
||||
# Check inner elements satisfy the type args
|
||||
if _is_type_or_subclass(origin, list):
|
||||
return all(_value_satisfies_type(v, args[0]) for v in value)
|
||||
if _is_type_or_subclass(origin, dict) and len(args) >= 2:
|
||||
return all(
|
||||
_value_satisfies_type(k, args[0]) and _value_satisfies_type(v, args[1])
|
||||
for k, v in value.items()
|
||||
)
|
||||
if (
|
||||
_is_type_or_subclass(origin, set) or _is_type_or_subclass(origin, frozenset)
|
||||
) and args:
|
||||
return all(_value_satisfies_type(v, args[0]) for v in value)
|
||||
if _is_type_or_subclass(origin, tuple):
|
||||
# Homogeneous tuple[T, ...] — single type + Ellipsis
|
||||
if len(args) == 2 and args[1] is Ellipsis:
|
||||
return all(_value_satisfies_type(v, args[0]) for v in value)
|
||||
# Heterogeneous tuple[T1, T2, ...] — positional types
|
||||
if len(value) != len(args):
|
||||
return False
|
||||
return all(_value_satisfies_type(v, t) for v, t in zip(value, args))
|
||||
# Unhandled generic origin — fall through to convert()
|
||||
return False
|
||||
|
||||
# Simple type (e.g. str, int)
|
||||
if isinstance(target, type):
|
||||
return isinstance(value, target)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def coerce_inputs_to_schema(data: dict[str, Any], schema: type) -> None:
|
||||
"""Coerce *data* values in-place to match *schema*'s field types.
|
||||
|
||||
Uses ``model_fields`` (not ``__annotations__``) so inherited fields are
|
||||
included. Skips coercion when the value already satisfies the target
|
||||
type — in particular for union-typed fields where the value matches one
|
||||
member but differs from the annotation object itself.
|
||||
|
||||
This is the single authoritative coercion step shared by the executor
|
||||
(``validate_exec``) and the CoPilot (``execute_block``).
|
||||
"""
|
||||
for name, field_info in schema.model_fields.items():
|
||||
value = data.get(name)
|
||||
if value is None:
|
||||
continue
|
||||
target = field_info.annotation
|
||||
if target is None:
|
||||
continue
|
||||
if _value_satisfies_type(value, target):
|
||||
continue
|
||||
data[name] = convert(value, target)
|
||||
|
||||
|
||||
class FormattedStringType(str):
|
||||
string_format: str
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from typing import Any, List, Literal, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.type import _value_satisfies_type, coerce_inputs_to_schema, convert
|
||||
from backend.util.type import convert
|
||||
|
||||
|
||||
def test_type_conversion():
|
||||
@@ -48,343 +46,3 @@ def test_type_conversion():
|
||||
# Test other empty list conversions
|
||||
assert convert([], int) == 0 # len([]) = 0
|
||||
assert convert([], Optional[int]) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _value_satisfies_type
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValueSatisfiesType:
|
||||
# --- simple types ---
|
||||
def test_simple_match(self):
|
||||
assert _value_satisfies_type("hello", str) is True
|
||||
assert _value_satisfies_type(42, int) is True
|
||||
assert _value_satisfies_type(3.14, float) is True
|
||||
assert _value_satisfies_type(True, bool) is True
|
||||
|
||||
def test_simple_mismatch(self):
|
||||
assert _value_satisfies_type("hello", int) is False
|
||||
assert _value_satisfies_type(42, str) is False
|
||||
assert _value_satisfies_type([1, 2], str) is False
|
||||
|
||||
# --- Any ---
|
||||
def test_any_always_satisfied(self):
|
||||
assert _value_satisfies_type("hello", Any) is True
|
||||
assert _value_satisfies_type(42, Any) is True
|
||||
assert _value_satisfies_type([1, 2], Any) is True
|
||||
assert _value_satisfies_type(None, Any) is True
|
||||
|
||||
# --- Optional / Union ---
|
||||
def test_optional_with_value(self):
|
||||
assert _value_satisfies_type("hello", Optional[str]) is True
|
||||
assert _value_satisfies_type(42, Optional[int]) is True
|
||||
|
||||
def test_optional_mismatch(self):
|
||||
assert _value_satisfies_type(42, Optional[str]) is False
|
||||
|
||||
def test_union_matches_first_member(self):
|
||||
assert _value_satisfies_type("hello", str | list[str]) is True
|
||||
|
||||
def test_union_matches_second_member(self):
|
||||
assert _value_satisfies_type(["a", "b"], str | list[str]) is True
|
||||
|
||||
def test_union_no_match(self):
|
||||
assert _value_satisfies_type(42, str | list[str]) is False
|
||||
|
||||
# --- list[T] ---
|
||||
def test_list_str_all_match(self):
|
||||
assert _value_satisfies_type(["a", "b", "c"], list[str]) is True
|
||||
|
||||
def test_list_str_inner_mismatch(self):
|
||||
assert _value_satisfies_type([1, 2, 3], list[str]) is False
|
||||
|
||||
def test_list_int_all_match(self):
|
||||
assert _value_satisfies_type([1, 2, 3], list[int]) is True
|
||||
|
||||
def test_list_int_inner_mismatch(self):
|
||||
assert _value_satisfies_type(["1", "2"], list[int]) is False
|
||||
|
||||
def test_empty_list_satisfies_any_list_type(self):
|
||||
assert _value_satisfies_type([], list[str]) is True
|
||||
assert _value_satisfies_type([], list[int]) is True
|
||||
|
||||
def test_string_does_not_satisfy_list(self):
|
||||
assert _value_satisfies_type("hello", list[str]) is False
|
||||
|
||||
# --- nested list[list[str]] ---
|
||||
def test_nested_list_all_match(self):
|
||||
assert _value_satisfies_type([["a", "b"], ["c"]], list[list[str]]) is True
|
||||
|
||||
def test_nested_list_inner_mismatch(self):
|
||||
assert _value_satisfies_type([["a", 1], ["c"]], list[list[str]]) is False
|
||||
|
||||
def test_nested_list_outer_mismatch(self):
|
||||
assert _value_satisfies_type(["a", "b"], list[list[str]]) is False
|
||||
|
||||
# --- dict[K, V] ---
|
||||
def test_dict_str_int_match(self):
|
||||
assert _value_satisfies_type({"a": 1, "b": 2}, dict[str, int]) is True
|
||||
|
||||
def test_dict_str_int_value_mismatch(self):
|
||||
assert _value_satisfies_type({"a": "1", "b": "2"}, dict[str, int]) is False
|
||||
|
||||
def test_dict_str_int_key_mismatch(self):
|
||||
assert _value_satisfies_type({1: 1, 2: 2}, dict[str, int]) is False
|
||||
|
||||
def test_empty_dict_satisfies(self):
|
||||
assert _value_satisfies_type({}, dict[str, int]) is True
|
||||
|
||||
# --- set[T] / tuple[T] ---
|
||||
def test_set_match(self):
|
||||
assert _value_satisfies_type({1, 2, 3}, set[int]) is True
|
||||
|
||||
def test_set_mismatch(self):
|
||||
assert _value_satisfies_type({"a", "b"}, set[int]) is False
|
||||
|
||||
def test_tuple_homogeneous_match(self):
|
||||
assert _value_satisfies_type((1, 2, 3), tuple[int, ...]) is True
|
||||
|
||||
def test_tuple_homogeneous_mismatch(self):
|
||||
assert _value_satisfies_type((1, "2", 3), tuple[int, ...]) is False
|
||||
|
||||
def test_tuple_heterogeneous_match(self):
|
||||
assert _value_satisfies_type(("a", 1, True), tuple[str, int, bool]) is True
|
||||
|
||||
def test_tuple_heterogeneous_mismatch(self):
|
||||
assert _value_satisfies_type(("a", "b", True), tuple[str, int, bool]) is False
|
||||
|
||||
def test_tuple_heterogeneous_wrong_length(self):
|
||||
assert _value_satisfies_type(("a", 1), tuple[str, int, bool]) is False
|
||||
|
||||
# --- bare generics (no args) ---
|
||||
def test_bare_list(self):
|
||||
assert _value_satisfies_type([1, "a"], list) is True
|
||||
|
||||
def test_bare_dict(self):
|
||||
assert _value_satisfies_type({"a": 1}, dict) is True
|
||||
|
||||
# --- union with generic inner mismatch ---
|
||||
def test_union_list_with_wrong_inner_falls_through(self):
|
||||
# [1, 2] doesn't satisfy list[str] (inner mismatch), and not str either
|
||||
assert _value_satisfies_type([1, 2], str | list[str]) is False
|
||||
|
||||
# --- Literal (non-runtime origin) ---
|
||||
def test_literal_does_not_crash(self):
|
||||
"""Literal origins are not runtime types — should return False, not crash."""
|
||||
assert _value_satisfies_type("active", Literal["active", "inactive"]) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# coerce_inputs_to_schema — using real Pydantic models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SampleSchema(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
items: list[str]
|
||||
config: dict[str, int] = {}
|
||||
|
||||
|
||||
class NestedSchema(BaseModel):
|
||||
rows: list[list[str]]
|
||||
|
||||
|
||||
class UnionSchema(BaseModel):
|
||||
content: str | list[str]
|
||||
|
||||
|
||||
class OptionalSchema(BaseModel):
|
||||
label: Optional[str] = None
|
||||
value: int = 0
|
||||
|
||||
|
||||
class AnyFieldSchema(BaseModel):
|
||||
data: Any
|
||||
|
||||
|
||||
class TestCoerceInputsToSchema:
|
||||
def test_string_to_int(self):
|
||||
data: dict[str, Any] = {"name": "test", "count": "42", "items": ["a"]}
|
||||
coerce_inputs_to_schema(data, SampleSchema)
|
||||
assert data["count"] == 42
|
||||
assert isinstance(data["count"], int)
|
||||
|
||||
def test_json_string_to_list(self):
|
||||
data: dict[str, Any] = {"name": "test", "count": 1, "items": '["a","b","c"]'}
|
||||
coerce_inputs_to_schema(data, SampleSchema)
|
||||
assert data["items"] == ["a", "b", "c"]
|
||||
|
||||
def test_already_correct_types_unchanged(self):
|
||||
data: dict[str, Any] = {
|
||||
"name": "test",
|
||||
"count": 42,
|
||||
"items": ["a", "b"],
|
||||
"config": {"x": 1},
|
||||
}
|
||||
coerce_inputs_to_schema(data, SampleSchema)
|
||||
assert data == {
|
||||
"name": "test",
|
||||
"count": 42,
|
||||
"items": ["a", "b"],
|
||||
"config": {"x": 1},
|
||||
}
|
||||
|
||||
def test_inner_element_coercion(self):
|
||||
"""list[str] with int inner elements → coerced to strings."""
|
||||
data: dict[str, Any] = {"name": "test", "count": 1, "items": [1, 2, 3]}
|
||||
coerce_inputs_to_schema(data, SampleSchema)
|
||||
assert data["items"] == ["1", "2", "3"]
|
||||
|
||||
def test_dict_value_coercion(self):
|
||||
"""dict[str, int] with string values → coerced to ints."""
|
||||
data: dict[str, Any] = {
|
||||
"name": "test",
|
||||
"count": 1,
|
||||
"items": [],
|
||||
"config": {"x": "10", "y": "20"},
|
||||
}
|
||||
coerce_inputs_to_schema(data, SampleSchema)
|
||||
assert data["config"] == {"x": 10, "y": 20}
|
||||
|
||||
def test_nested_list_from_json_string(self):
|
||||
data: dict[str, Any] = {
|
||||
"rows": '[["Name","Score"],["Alice","90"]]',
|
||||
}
|
||||
coerce_inputs_to_schema(data, NestedSchema)
|
||||
assert data["rows"] == [["Name", "Score"], ["Alice", "90"]]
|
||||
|
||||
def test_nested_list_already_correct(self):
|
||||
original = [["a", "b"], ["c", "d"]]
|
||||
data: dict[str, Any] = {"rows": original}
|
||||
coerce_inputs_to_schema(data, NestedSchema)
|
||||
assert data["rows"] == original
|
||||
|
||||
def test_union_preserves_valid_list(self):
|
||||
"""list[str] value should NOT be stringified for str | list[str]."""
|
||||
data: dict[str, Any] = {"content": ["a", "b"]}
|
||||
coerce_inputs_to_schema(data, UnionSchema)
|
||||
assert data["content"] == ["a", "b"]
|
||||
assert isinstance(data["content"], list)
|
||||
|
||||
def test_union_preserves_valid_string(self):
|
||||
data: dict[str, Any] = {"content": "hello"}
|
||||
coerce_inputs_to_schema(data, UnionSchema)
|
||||
assert data["content"] == "hello"
|
||||
|
||||
def test_union_list_with_wrong_inner_gets_coerced(self):
|
||||
"""[1, 2] for str | list[str] — inner ints don't match list[str],
|
||||
so convert() is called. convert tries str first → stringifies."""
|
||||
data: dict[str, Any] = {"content": [1, 2]}
|
||||
coerce_inputs_to_schema(data, UnionSchema)
|
||||
# convert([1,2], str | list[str]) tries str first → "[1, 2]"
|
||||
# This is convert()'s union behavior — str wins over list[str]
|
||||
assert isinstance(data["content"], (str, list))
|
||||
|
||||
def test_skips_none_values(self):
|
||||
data: dict[str, Any] = {"label": None, "value": "5"}
|
||||
coerce_inputs_to_schema(data, OptionalSchema)
|
||||
assert data["label"] is None
|
||||
assert data["value"] == 5
|
||||
|
||||
def test_skips_missing_fields(self):
|
||||
data: dict[str, Any] = {"value": "10"}
|
||||
coerce_inputs_to_schema(data, OptionalSchema)
|
||||
assert "label" not in data
|
||||
assert data["value"] == 10
|
||||
|
||||
def test_any_field_skipped(self):
|
||||
"""Fields typed as Any should pass through without coercion."""
|
||||
data: dict[str, Any] = {"data": [1, "mixed", {"nested": True}]}
|
||||
coerce_inputs_to_schema(data, AnyFieldSchema)
|
||||
assert data["data"] == [1, "mixed", {"nested": True}]
|
||||
|
||||
def test_preserves_all_convert_capabilities(self):
|
||||
"""Verify coerce_inputs_to_schema doesn't lose any convert() capability
|
||||
that existed before the _value_satisfies_type gate was added."""
|
||||
|
||||
class FullSchema(BaseModel):
|
||||
as_int: int
|
||||
as_float: float
|
||||
as_bool: bool
|
||||
as_str: str
|
||||
as_list: list[int]
|
||||
as_dict: dict[str, str]
|
||||
|
||||
data: dict[str, Any] = {
|
||||
"as_int": "42",
|
||||
"as_float": "3.14",
|
||||
"as_bool": "True",
|
||||
"as_str": 123,
|
||||
"as_list": "[1,2,3]",
|
||||
"as_dict": '{"a": "b"}',
|
||||
}
|
||||
coerce_inputs_to_schema(data, FullSchema)
|
||||
assert data["as_int"] == 42
|
||||
assert data["as_float"] == 3.14
|
||||
assert data["as_bool"] is True
|
||||
assert data["as_str"] == "123"
|
||||
assert data["as_list"] == [1, 2, 3]
|
||||
assert data["as_dict"] == {"a": "b"}
|
||||
|
||||
def test_inherited_fields_are_coerced(self):
|
||||
"""model_fields includes inherited fields; __annotations__ does not.
|
||||
This verifies that fields from a parent schema are still coerced."""
|
||||
|
||||
class ParentSchema(BaseModel):
|
||||
base_count: int
|
||||
|
||||
class ChildSchema(ParentSchema):
|
||||
name: str
|
||||
|
||||
# base_count is inherited — __annotations__ wouldn't include it
|
||||
assert "base_count" not in ChildSchema.__annotations__
|
||||
assert "base_count" in ChildSchema.model_fields
|
||||
|
||||
data: dict[str, Any] = {"base_count": "42", "name": "test"}
|
||||
coerce_inputs_to_schema(data, ChildSchema)
|
||||
assert data["base_count"] == 42
|
||||
assert isinstance(data["base_count"], int)
|
||||
|
||||
def test_nested_pydantic_model_field(self):
|
||||
"""dict input for a Pydantic model-typed field passes through.
|
||||
convert() doesn't construct Pydantic models — Pydantic validation
|
||||
handles that downstream. This test documents the behavior."""
|
||||
|
||||
class InnerModel(BaseModel):
|
||||
x: int
|
||||
|
||||
class OuterModel(BaseModel):
|
||||
inner: InnerModel
|
||||
|
||||
data: dict[str, Any] = {"inner": {"x": 1}}
|
||||
coerce_inputs_to_schema(data, OuterModel)
|
||||
# dict stays as dict — convert() doesn't construct Pydantic models
|
||||
assert data["inner"] == {"x": 1}
|
||||
assert isinstance(data["inner"], dict)
|
||||
|
||||
def test_literal_field_passes_through(self):
|
||||
"""Literal-typed fields should not crash coercion."""
|
||||
|
||||
class LiteralSchema(BaseModel):
|
||||
status: Literal["active", "inactive"]
|
||||
|
||||
data: dict[str, Any] = {"status": "active"}
|
||||
coerce_inputs_to_schema(data, LiteralSchema)
|
||||
assert data["status"] == "active"
|
||||
|
||||
def test_list_of_pydantic_model_field(self):
|
||||
"""list[dict] for list[PydanticModel] passes through unchanged."""
|
||||
|
||||
class ItemModel(BaseModel):
|
||||
name: str
|
||||
|
||||
class ContainerModel(BaseModel):
|
||||
items: list[ItemModel]
|
||||
|
||||
data: dict[str, Any] = {"items": [{"name": "a"}, {"name": "b"}]}
|
||||
coerce_inputs_to_schema(data, ContainerModel)
|
||||
# Dicts stay as dicts — Pydantic validation handles construction
|
||||
assert data["items"] == [{"name": "a"}, {"name": "b"}]
|
||||
assert isinstance(data["items"][0], dict)
|
||||
|
||||
@@ -1,246 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
AutoGPT Analytics — View Generator
|
||||
====================================
|
||||
Reads every .sql file in analytics/queries/ and registers it as a
|
||||
CREATE OR REPLACE VIEW in the analytics schema.
|
||||
|
||||
Quick start (from autogpt_platform/backend/):
|
||||
|
||||
Step 1 — one-time setup (creates schema, role, grants):
|
||||
|
||||
poetry run analytics-setup
|
||||
|
||||
Step 2 — create / refresh all 14 analytics views:
|
||||
|
||||
poetry run analytics-views
|
||||
|
||||
Both commands auto-detect credentials from .env (DB_* vars).
|
||||
Use --db-url to override.
|
||||
|
||||
Step 3 (optional) — enable login and set a password for the read-only
|
||||
role so external tools (Supabase MCP, PostHog Data Warehouse) can connect.
|
||||
The role is created as NOLOGIN, so you must grant LOGIN at the same time.
|
||||
Run in Supabase SQL Editor:
|
||||
|
||||
ALTER ROLE analytics_readonly WITH LOGIN PASSWORD 'your-password';
|
||||
|
||||
Usage
|
||||
-----
|
||||
poetry run analytics-setup # apply setup to DB
|
||||
poetry run analytics-setup --dry-run # print setup SQL only
|
||||
poetry run analytics-views # apply all views to DB
|
||||
poetry run analytics-views --dry-run # print all view SQL only
|
||||
poetry run analytics-views --only graph_execution,retention_login_weekly
|
||||
|
||||
Environment variables
|
||||
---------------------
|
||||
DATABASE_URL Postgres connection string (checked before .env)
|
||||
|
||||
Notes
|
||||
-----
|
||||
- .env DB_* vars are read automatically as a fallback.
|
||||
- Safe to re-run: uses CREATE OR REPLACE VIEW.
|
||||
- Looker, PostHog Data Warehouse, and Supabase MCP all read from the
|
||||
same analytics.* views — no raw tables exposed.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote
|
||||
|
||||
QUERIES_DIR = Path(__file__).parent.parent / "analytics" / "queries"
|
||||
ENV_FILE = Path(__file__).parent / ".env"
|
||||
SCHEMA = "analytics"
|
||||
|
||||
SETUP_SQL = """\
|
||||
-- =============================================================
|
||||
-- AutoGPT Analytics Schema Setup
|
||||
-- Run ONCE as the postgres superuser (e.g. via Supabase SQL Editor).
|
||||
-- After this, run: poetry run analytics-views
|
||||
-- =============================================================
|
||||
|
||||
-- 1. Create the analytics schema
|
||||
CREATE SCHEMA IF NOT EXISTS analytics;
|
||||
|
||||
-- 2. Create the read-only role (skip if already exists)
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT FROM pg_roles WHERE rolname = 'analytics_readonly') THEN
|
||||
CREATE ROLE analytics_readonly NOLOGIN;
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
|
||||
-- 3. Analytics schema grants only.
|
||||
-- Views use security_invoker = false so they execute as their
|
||||
-- owner (postgres). analytics_readonly never needs direct access
|
||||
-- to the platform or auth schemas.
|
||||
GRANT USAGE ON SCHEMA analytics TO analytics_readonly;
|
||||
GRANT SELECT ON ALL TABLES IN SCHEMA analytics TO analytics_readonly;
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA analytics
|
||||
GRANT SELECT ON TABLES TO analytics_readonly;
|
||||
"""
|
||||
|
||||
|
||||
def load_db_url_from_env() -> str | None:
|
||||
"""Read DB_* vars from .env and build a psycopg2 connection string."""
|
||||
if not ENV_FILE.exists():
|
||||
return None
|
||||
env: dict[str, str] = {}
|
||||
for line in ENV_FILE.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#") or "=" not in line:
|
||||
continue
|
||||
key, _, value = line.partition("=")
|
||||
env[key.strip()] = value.strip().strip('"').strip("'")
|
||||
host = env.get("DB_HOST", "localhost")
|
||||
port = env.get("DB_PORT", "5432")
|
||||
user = env.get("DB_USER", "postgres")
|
||||
password = env.get("DB_PASS", "")
|
||||
dbname = env.get("DB_NAME", "postgres")
|
||||
if not password:
|
||||
return None
|
||||
return (
|
||||
"postgresql://"
|
||||
f"{quote(user, safe='')}:{quote(password, safe='')}"
|
||||
f"@{host}:{port}/{quote(dbname, safe='')}"
|
||||
)
|
||||
|
||||
|
||||
def get_db_url(args: argparse.Namespace) -> str | None:
|
||||
return args.db_url or os.environ.get("DATABASE_URL") or load_db_url_from_env()
|
||||
|
||||
|
||||
def connect(db_url: str):
|
||||
try:
|
||||
import psycopg2
|
||||
except ImportError:
|
||||
print("psycopg2 not found. Run: poetry install", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
return psycopg2.connect(db_url)
|
||||
|
||||
|
||||
def run_sql(db_url: str, statements: list[tuple[str, str]]) -> None:
|
||||
"""Execute a list of (label, sql) pairs in a single transaction."""
|
||||
conn = connect(db_url)
|
||||
conn.autocommit = False
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
for label, sql in statements:
|
||||
print(f" {label} ...", end=" ")
|
||||
cur.execute(sql)
|
||||
print("OK")
|
||||
conn.commit()
|
||||
print(f"\n✓ {len(statements)} statement(s) applied.")
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
print(f"\n✗ Error: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
finally:
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
|
||||
def build_view_sql(name: str, query_body: str) -> str:
|
||||
body = query_body.strip().rstrip(";")
|
||||
# security_invoker = false → view runs as its owner (postgres), not the
|
||||
# caller, so analytics_readonly only needs analytics schema access.
|
||||
return f"CREATE OR REPLACE VIEW {SCHEMA}.{name} WITH (security_invoker = false) AS\n{body};\n"
|
||||
|
||||
|
||||
def load_views(only: list[str] | None = None) -> list[tuple[str, str]]:
|
||||
"""Return [(label, sql)] for all views, in alphabetical order."""
|
||||
files = sorted(QUERIES_DIR.glob("*.sql"))
|
||||
if not files:
|
||||
print(f"No .sql files found in {QUERIES_DIR}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
known = {f.stem for f in files}
|
||||
if only:
|
||||
unknown = [n for n in only if n not in known]
|
||||
if unknown:
|
||||
print(
|
||||
f"Unknown view name(s): {', '.join(unknown)}\n"
|
||||
f"Available: {', '.join(sorted(known))}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
result = []
|
||||
for f in files:
|
||||
name = f.stem
|
||||
if only and name not in only:
|
||||
continue
|
||||
result.append((f"view analytics.{name}", build_view_sql(name, f.read_text())))
|
||||
return result
|
||||
|
||||
|
||||
def no_db_url_error() -> None:
|
||||
print(
|
||||
"No database URL found.\n"
|
||||
"Tried: --db-url, DATABASE_URL env var, and .env (DB_* vars).\n"
|
||||
"Use --dry-run to just print the SQL.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def cmd_setup(args: argparse.Namespace) -> None:
|
||||
if args.dry_run:
|
||||
print(SETUP_SQL)
|
||||
return
|
||||
db_url = get_db_url(args)
|
||||
if not db_url:
|
||||
no_db_url_error()
|
||||
assert db_url
|
||||
print("Applying analytics setup...")
|
||||
run_sql(db_url, [("schema / role / grants", SETUP_SQL)])
|
||||
|
||||
|
||||
def cmd_views(args: argparse.Namespace) -> None:
|
||||
only = [v.strip() for v in args.only.split(",")] if args.only else None
|
||||
views = load_views(only=only)
|
||||
if not views:
|
||||
print("No matching views found.")
|
||||
sys.exit(0)
|
||||
|
||||
if args.dry_run:
|
||||
print(f"-- {len(views)} views\n")
|
||||
for label, sql in views:
|
||||
print(f"-- {label}")
|
||||
print(sql)
|
||||
return
|
||||
|
||||
db_url = get_db_url(args)
|
||||
if not db_url:
|
||||
no_db_url_error()
|
||||
assert db_url
|
||||
print(f"Applying {len(views)} view(s)...")
|
||||
# Append grant refresh so the readonly role sees any new views
|
||||
grant = f"GRANT SELECT ON ALL TABLES IN SCHEMA {SCHEMA} TO analytics_readonly;"
|
||||
run_sql(db_url, views + [("grant analytics_readonly", grant)])
|
||||
|
||||
|
||||
def main_setup() -> None:
|
||||
parser = argparse.ArgumentParser(description="Apply analytics schema setup to DB")
|
||||
parser.add_argument(
|
||||
"--dry-run", action="store_true", help="Print SQL, don't execute"
|
||||
)
|
||||
parser.add_argument("--db-url", help="Postgres connection string")
|
||||
cmd_setup(parser.parse_args())
|
||||
|
||||
|
||||
def main_views() -> None:
|
||||
parser = argparse.ArgumentParser(description="Apply analytics views to DB")
|
||||
parser.add_argument(
|
||||
"--dry-run", action="store_true", help="Print SQL, don't execute"
|
||||
)
|
||||
parser.add_argument("--db-url", help="Postgres connection string")
|
||||
parser.add_argument("--only", help="Comma-separated view names to update")
|
||||
cmd_views(parser.parse_args())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Default: apply views (backwards-compatible with direct python invocation)
|
||||
main_views()
|
||||
@@ -1,46 +0,0 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to drop the column `search` on the `StoreListingVersion` table. All the data in the column will be lost.
|
||||
|
||||
*/-- CreateEnum
|
||||
CREATE TYPE "InvitedUserStatus" AS ENUM('INVITED',
|
||||
'CLAIMED',
|
||||
'REVOKED');
|
||||
-- CreateEnum
|
||||
CREATE TYPE "TallyComputationStatus" AS ENUM('PENDING',
|
||||
'RUNNING',
|
||||
'READY',
|
||||
'FAILED');
|
||||
-- CreateTable
|
||||
CREATE TABLE "InvitedUser"(
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"email" TEXT NOT NULL,
|
||||
"status" "InvitedUserStatus" NOT NULL DEFAULT 'INVITED',
|
||||
"authUserId" TEXT,
|
||||
"name" TEXT,
|
||||
"tallyUnderstanding" JSONB,
|
||||
"tallyStatus" "TallyComputationStatus" NOT NULL DEFAULT 'PENDING',
|
||||
"tallyComputedAt" TIMESTAMP(3),
|
||||
"tallyError" TEXT,
|
||||
CONSTRAINT "InvitedUser_pkey" PRIMARY KEY("id")
|
||||
);
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "InvitedUser_email_key"
|
||||
ON "InvitedUser"("email");
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "InvitedUser_authUserId_key"
|
||||
ON "InvitedUser"("authUserId");
|
||||
-- CreateIndex
|
||||
CREATE INDEX "InvitedUser_status_idx"
|
||||
ON "InvitedUser"("status");
|
||||
-- CreateIndex
|
||||
CREATE INDEX "InvitedUser_tallyStatus_idx"
|
||||
ON "InvitedUser"("tallyStatus");
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "InvitedUser" ADD CONSTRAINT "InvitedUser_authUserId_fkey" FOREIGN KEY("authUserId") REFERENCES "User"("id")
|
||||
ON DELETE
|
||||
SET NULL
|
||||
ON UPDATE CASCADE;
|
||||
@@ -1,15 +0,0 @@
|
||||
-- Drop the trigger that auto-creates User + Profile on auth.users INSERT.
|
||||
-- The invite activation flow in get_or_activate_user() now handles this.
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1 FROM information_schema.tables
|
||||
WHERE table_schema = 'auth' AND table_name = 'users'
|
||||
) THEN
|
||||
DROP TRIGGER IF EXISTS user_added_to_platform ON auth.users;
|
||||
END IF;
|
||||
END $$;
|
||||
|
||||
DROP FUNCTION IF EXISTS add_user_and_profile_to_platform();
|
||||
DROP FUNCTION IF EXISTS add_user_to_platform();
|
||||
-- Keep generate_username() — used by backfill migration 20250205110132
|
||||
@@ -1,7 +0,0 @@
|
||||
-- DropIndex
|
||||
DROP INDEX "InvitedUser_status_idx";
|
||||
-- DropIndex
|
||||
DROP INDEX "InvitedUser_tallyStatus_idx";
|
||||
-- CreateIndex
|
||||
CREATE INDEX "InvitedUser_createdAt_idx"
|
||||
ON "InvitedUser"("createdAt");
|
||||
8
autogpt_platform/backend/poetry.lock
generated
8
autogpt_platform/backend/poetry.lock
generated
@@ -1282,14 +1282,14 @@ pgp = ["gpg"]
|
||||
|
||||
[[package]]
|
||||
name = "e2b"
|
||||
version = "2.15.2"
|
||||
version = "2.15.1"
|
||||
description = "E2B SDK that give agents cloud environments"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "e2b-2.15.2-py3-none-any.whl", hash = "sha256:19a56fbdea25974dc81426ed48337eae6cea91d404f5bcf8861a5a2c6e4d982a"},
|
||||
{file = "e2b-2.15.2.tar.gz", hash = "sha256:414379d2421d6827eeb2eb50a4d6b3fdb7d691b39ff73b5ea05ca4b532819831"},
|
||||
{file = "e2b-2.15.1-py3-none-any.whl", hash = "sha256:a3bc4e004eab51fb05bae44e9ee4fe821e4637260f4ce3064c8f7c6ed7f5a2a0"},
|
||||
{file = "e2b-2.15.1.tar.gz", hash = "sha256:a4f1bbc8b5180a8a1098079257fcb73e42503ed546098f676f722f11f0d68c09"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -8882,4 +8882,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "4e4365721cd3b68c58c237353b74adae1c64233fd4446904c335f23eb866fdca"
|
||||
content-hash = "618d61b0586ab82fec1e28d1feb549a198e0b5c9d152e808862e55efc00a65b9"
|
||||
|
||||
@@ -20,7 +20,7 @@ claude-agent-sdk = "0.1.45" # see copilot/sdk/sdk_compat_test.py for capability
|
||||
click = "^8.2.0"
|
||||
cryptography = "^46.0"
|
||||
discord-py = "^2.5.2"
|
||||
e2b = "^2.15.2"
|
||||
e2b = "^2.0"
|
||||
e2b-code-interpreter = "^2.0"
|
||||
elevenlabs = "^1.50.0"
|
||||
fastapi = "^0.128.6"
|
||||
@@ -120,8 +120,6 @@ ws = "backend.ws:main"
|
||||
scheduler = "backend.scheduler:main"
|
||||
notification = "backend.notification:main"
|
||||
executor = "backend.exec:main"
|
||||
analytics-setup = "generate_views:main_setup"
|
||||
analytics-views = "generate_views:main_views"
|
||||
copilot-executor = "backend.copilot.executor.__main__:main"
|
||||
cli = "backend.cli:main"
|
||||
format = "linter:format"
|
||||
|
||||
@@ -65,7 +65,6 @@ model User {
|
||||
NotificationBatches UserNotificationBatch[]
|
||||
PendingHumanReviews PendingHumanReview[]
|
||||
Workspace UserWorkspace?
|
||||
ClaimedInvite InvitedUser? @relation("InvitedUserAuthUser")
|
||||
|
||||
// OAuth Provider relations
|
||||
OAuthApplications OAuthApplication[]
|
||||
@@ -74,38 +73,6 @@ model User {
|
||||
OAuthRefreshTokens OAuthRefreshToken[]
|
||||
}
|
||||
|
||||
enum InvitedUserStatus {
|
||||
INVITED
|
||||
CLAIMED
|
||||
REVOKED
|
||||
}
|
||||
|
||||
enum TallyComputationStatus {
|
||||
PENDING
|
||||
RUNNING
|
||||
READY
|
||||
FAILED
|
||||
}
|
||||
|
||||
model InvitedUser {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
email String @unique
|
||||
status InvitedUserStatus @default(INVITED)
|
||||
authUserId String? @unique
|
||||
AuthUser User? @relation("InvitedUserAuthUser", fields: [authUserId], references: [id], onDelete: SetNull)
|
||||
name String?
|
||||
|
||||
tallyUnderstanding Json?
|
||||
tallyStatus TallyComputationStatus @default(PENDING)
|
||||
tallyComputedAt DateTime?
|
||||
tallyError String?
|
||||
|
||||
@@index([createdAt])
|
||||
}
|
||||
|
||||
enum OnboardingStep {
|
||||
// Introductory onboarding (Library)
|
||||
WELCOME
|
||||
@@ -1025,7 +992,7 @@ model StoreListing {
|
||||
ActiveVersion StoreListingVersion? @relation("ActiveVersion", fields: [activeVersionId], references: [id])
|
||||
|
||||
// The agent link here is only so we can do lookup on agentId
|
||||
agentGraphId String @unique
|
||||
agentGraphId String @unique
|
||||
|
||||
owningUserId String
|
||||
OwningUser User @relation(fields: [owningUserId], references: [id])
|
||||
|
||||
@@ -34,7 +34,7 @@ from backend.data.auth.api_key import create_api_key
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.db import prisma
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.data.invited_user import get_or_activate_user
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.util.clients import get_supabase
|
||||
|
||||
faker = Faker()
|
||||
@@ -151,7 +151,7 @@ class TestDataCreator:
|
||||
}
|
||||
|
||||
# Use the API function to create user in local database
|
||||
user = await get_or_activate_user(user_data)
|
||||
user = await get_or_create_user(user_data)
|
||||
users.append(user.model_dump())
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,14 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { Sidebar } from "@/components/__legacy__/Sidebar";
|
||||
import {
|
||||
UsersIcon,
|
||||
CurrencyDollarSimpleIcon,
|
||||
UserPlusIcon,
|
||||
MagnifyingGlassIcon,
|
||||
FileTextIcon,
|
||||
SlidersHorizontalIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
|
||||
|
||||
import { IconSliders } from "@/components/__legacy__/ui/icons";
|
||||
|
||||
const sidebarLinkGroups = [
|
||||
{
|
||||
@@ -16,32 +9,27 @@ const sidebarLinkGroups = [
|
||||
{
|
||||
text: "Marketplace Management",
|
||||
href: "/admin/marketplace",
|
||||
icon: <UsersIcon size={24} />,
|
||||
icon: <Users className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "User Spending",
|
||||
href: "/admin/spending",
|
||||
icon: <CurrencyDollarSimpleIcon size={24} />,
|
||||
},
|
||||
{
|
||||
text: "Beta Invites",
|
||||
href: "/admin/users",
|
||||
icon: <UserPlusIcon size={24} />,
|
||||
icon: <DollarSign className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "User Impersonation",
|
||||
href: "/admin/impersonation",
|
||||
icon: <MagnifyingGlassIcon size={24} />,
|
||||
icon: <UserSearch className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "Execution Analytics",
|
||||
href: "/admin/execution-analytics",
|
||||
icon: <FileTextIcon size={24} />,
|
||||
icon: <FileText className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "Admin User Management",
|
||||
href: "/admin/settings",
|
||||
icon: <SlidersHorizontalIcon size={24} />,
|
||||
icon: <IconSliders className="h-6 w-6" />,
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Card } from "@/components/atoms/Card/Card";
|
||||
import { BulkInviteForm } from "../BulkInviteForm/BulkInviteForm";
|
||||
import { InviteUserForm } from "../InviteUserForm/InviteUserForm";
|
||||
import { InvitedUsersTable } from "../InvitedUsersTable/InvitedUsersTable";
|
||||
import { useAdminUsersPage } from "../../useAdminUsersPage";
|
||||
|
||||
export function AdminUsersPage() {
|
||||
const {
|
||||
email,
|
||||
name,
|
||||
bulkInviteFile,
|
||||
bulkInviteInputKey,
|
||||
lastBulkInviteResult,
|
||||
invitedUsers,
|
||||
isLoadingInvitedUsers,
|
||||
isRefreshingInvitedUsers,
|
||||
isCreatingInvite,
|
||||
isBulkInviting,
|
||||
pendingInviteAction,
|
||||
setEmail,
|
||||
setName,
|
||||
handleBulkInviteFileChange,
|
||||
handleBulkInviteSubmit,
|
||||
handleCreateInvite,
|
||||
handleRetryTally,
|
||||
handleRevoke,
|
||||
} = useAdminUsersPage();
|
||||
|
||||
return (
|
||||
<div className="mx-auto flex max-w-7xl flex-col gap-6 p-6">
|
||||
<div className="flex flex-col gap-2">
|
||||
<h1 className="text-3xl font-bold text-zinc-900">Beta Invites</h1>
|
||||
<p className="max-w-3xl text-sm text-zinc-600">
|
||||
Pre-provision beta users before they sign up. Invites store the
|
||||
platform-side record, run Tally understanding extraction, and activate
|
||||
the real account on the user's first authenticated request.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-6 xl:grid-cols-[24rem,1fr]">
|
||||
<div className="flex flex-col gap-6">
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<InviteUserForm
|
||||
email={email}
|
||||
name={name}
|
||||
isSubmitting={isCreatingInvite}
|
||||
onEmailChange={setEmail}
|
||||
onNameChange={setName}
|
||||
onSubmit={handleCreateInvite}
|
||||
/>
|
||||
</Card>
|
||||
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<BulkInviteForm
|
||||
selectedFile={bulkInviteFile}
|
||||
inputKey={bulkInviteInputKey}
|
||||
isSubmitting={isBulkInviting}
|
||||
lastResult={lastBulkInviteResult}
|
||||
onFileChange={handleBulkInviteFileChange}
|
||||
onSubmit={handleBulkInviteSubmit}
|
||||
/>
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<InvitedUsersTable
|
||||
invitedUsers={invitedUsers}
|
||||
isLoading={isLoadingInvitedUsers}
|
||||
isRefreshing={isRefreshingInvitedUsers}
|
||||
pendingInviteAction={pendingInviteAction}
|
||||
onRetryTally={handleRetryTally}
|
||||
onRevoke={handleRevoke}
|
||||
/>
|
||||
</Card>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,135 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { BulkInvitedUsersResponse } from "@/app/api/__generated__/models/bulkInvitedUsersResponse";
|
||||
import { Badge } from "@/components/atoms/Badge/Badge";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { FormEvent } from "react";
|
||||
|
||||
interface Props {
|
||||
selectedFile: File | null;
|
||||
inputKey: number;
|
||||
isSubmitting: boolean;
|
||||
lastResult: BulkInvitedUsersResponse | null;
|
||||
onFileChange: (file: File | null) => void;
|
||||
onSubmit: (event: FormEvent<HTMLFormElement>) => void;
|
||||
}
|
||||
|
||||
function getStatusVariant(status: "CREATED" | "SKIPPED" | "ERROR") {
|
||||
if (status === "CREATED") {
|
||||
return "success";
|
||||
}
|
||||
|
||||
if (status === "ERROR") {
|
||||
return "error";
|
||||
}
|
||||
|
||||
return "info";
|
||||
}
|
||||
|
||||
export function BulkInviteForm({
|
||||
selectedFile,
|
||||
inputKey,
|
||||
isSubmitting,
|
||||
lastResult,
|
||||
onFileChange,
|
||||
onSubmit,
|
||||
}: Props) {
|
||||
return (
|
||||
<form className="flex flex-col gap-4" onSubmit={onSubmit}>
|
||||
<div className="flex flex-col gap-1">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">Bulk invite</h2>
|
||||
<p className="text-sm text-zinc-600">
|
||||
Upload a <span className="font-medium text-zinc-800">.txt</span> file
|
||||
with one email per line, or a{" "}
|
||||
<span className="font-medium text-zinc-800">.csv</span> with
|
||||
<span className="font-medium text-zinc-800"> email</span> and optional
|
||||
<span className="font-medium text-zinc-800"> name</span> columns.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<label
|
||||
htmlFor="bulk-invite-file-input"
|
||||
className="flex cursor-pointer flex-col gap-2 rounded-2xl border border-dashed border-zinc-300 bg-zinc-50 px-4 py-5 text-sm text-zinc-600 transition-colors focus-within:ring-2 focus-within:ring-zinc-500 focus-within:ring-offset-2 hover:border-zinc-400 hover:bg-zinc-100"
|
||||
>
|
||||
<span className="font-medium text-zinc-900">
|
||||
{selectedFile ? selectedFile.name : "Choose invite file"}
|
||||
</span>
|
||||
<span>Maximum 500 rows, UTF-8 encoded.</span>
|
||||
<input
|
||||
id="bulk-invite-file-input"
|
||||
key={inputKey}
|
||||
type="file"
|
||||
accept=".txt,.csv,text/plain,text/csv"
|
||||
disabled={isSubmitting}
|
||||
className="sr-only"
|
||||
onChange={(event) =>
|
||||
onFileChange(event.target.files?.item(0) ?? null)
|
||||
}
|
||||
/>
|
||||
</label>
|
||||
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
loading={isSubmitting}
|
||||
disabled={!selectedFile}
|
||||
className="w-full"
|
||||
>
|
||||
{isSubmitting ? "Uploading invites..." : "Upload invite file"}
|
||||
</Button>
|
||||
|
||||
{lastResult ? (
|
||||
<div className="flex flex-col gap-3 rounded-2xl border border-zinc-200 bg-zinc-50 p-4">
|
||||
<div className="grid grid-cols-3 gap-2 text-center">
|
||||
<div className="rounded-xl bg-white px-3 py-2">
|
||||
<div className="text-lg font-semibold text-zinc-900">
|
||||
{lastResult.created_count}
|
||||
</div>
|
||||
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
|
||||
Created
|
||||
</div>
|
||||
</div>
|
||||
<div className="rounded-xl bg-white px-3 py-2">
|
||||
<div className="text-lg font-semibold text-zinc-900">
|
||||
{lastResult.skipped_count}
|
||||
</div>
|
||||
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
|
||||
Skipped
|
||||
</div>
|
||||
</div>
|
||||
<div className="rounded-xl bg-white px-3 py-2">
|
||||
<div className="text-lg font-semibold text-zinc-900">
|
||||
{lastResult.error_count}
|
||||
</div>
|
||||
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
|
||||
Errors
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="max-h-64 overflow-y-auto rounded-xl border border-zinc-200 bg-white">
|
||||
<div className="flex flex-col divide-y divide-zinc-100">
|
||||
{lastResult.results.map((row) => (
|
||||
<div
|
||||
key={`${row.row_number}-${row.email ?? row.message}`}
|
||||
className="flex items-start gap-3 px-3 py-3"
|
||||
>
|
||||
<Badge variant={getStatusVariant(row.status)} size="small">
|
||||
{row.status}
|
||||
</Badge>
|
||||
<div className="flex min-w-0 flex-1 flex-col gap-1">
|
||||
<span className="text-sm font-medium text-zinc-900">
|
||||
Row {row.row_number}
|
||||
{row.email ? ` · ${row.email}` : ""}
|
||||
</span>
|
||||
<span className="text-xs text-zinc-500">{row.message}</span>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
</form>
|
||||
);
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import type { FormEvent } from "react";
|
||||
|
||||
interface Props {
|
||||
email: string;
|
||||
name: string;
|
||||
isSubmitting: boolean;
|
||||
onEmailChange: (value: string) => void;
|
||||
onNameChange: (value: string) => void;
|
||||
onSubmit: (event: FormEvent<HTMLFormElement>) => void;
|
||||
}
|
||||
|
||||
export function InviteUserForm({
|
||||
email,
|
||||
name,
|
||||
isSubmitting,
|
||||
onEmailChange,
|
||||
onNameChange,
|
||||
onSubmit,
|
||||
}: Props) {
|
||||
return (
|
||||
<form className="flex flex-col gap-4" onSubmit={onSubmit}>
|
||||
<div className="flex flex-col gap-1">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">Create invite</h2>
|
||||
<p className="text-sm text-zinc-600">
|
||||
The invite is stored immediately, then Tally pre-seeding starts in the
|
||||
background.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<Input
|
||||
id="invite-email"
|
||||
label="Email"
|
||||
type="email"
|
||||
value={email}
|
||||
placeholder="jane@example.com"
|
||||
autoComplete="email"
|
||||
disabled={isSubmitting}
|
||||
onChange={(event) => onEmailChange(event.target.value)}
|
||||
/>
|
||||
|
||||
<Input
|
||||
id="invite-name"
|
||||
label="Name"
|
||||
type="text"
|
||||
value={name}
|
||||
placeholder="Jane Doe"
|
||||
disabled={isSubmitting}
|
||||
onChange={(event) => onNameChange(event.target.value)}
|
||||
/>
|
||||
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
loading={isSubmitting}
|
||||
disabled={!email.trim()}
|
||||
className="w-full"
|
||||
>
|
||||
{isSubmitting ? "Creating invite..." : "Create invite"}
|
||||
</Button>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
@@ -1,209 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { InvitedUserResponse } from "@/app/api/__generated__/models/invitedUserResponse";
|
||||
import { Badge } from "@/components/atoms/Badge/Badge";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/__legacy__/ui/table";
|
||||
|
||||
interface Props {
|
||||
invitedUsers: InvitedUserResponse[];
|
||||
isLoading: boolean;
|
||||
isRefreshing: boolean;
|
||||
pendingInviteAction: string | null;
|
||||
onRetryTally: (invitedUserId: string) => void;
|
||||
onRevoke: (invitedUserId: string) => void;
|
||||
}
|
||||
|
||||
function getInviteBadgeVariant(status: InvitedUserResponse["status"]) {
|
||||
if (status === "CLAIMED") {
|
||||
return "success";
|
||||
}
|
||||
|
||||
if (status === "REVOKED") {
|
||||
return "error";
|
||||
}
|
||||
|
||||
return "info";
|
||||
}
|
||||
|
||||
function getTallyBadgeVariant(status: InvitedUserResponse["tally_status"]) {
|
||||
if (status === "READY") {
|
||||
return "success";
|
||||
}
|
||||
|
||||
if (status === "FAILED") {
|
||||
return "error";
|
||||
}
|
||||
|
||||
return "info";
|
||||
}
|
||||
|
||||
function formatDate(value: Date | undefined) {
|
||||
if (!value) {
|
||||
return "-";
|
||||
}
|
||||
|
||||
return value.toLocaleString();
|
||||
}
|
||||
|
||||
function getTallySummary(invitedUser: InvitedUserResponse) {
|
||||
if (invitedUser.tally_status === "FAILED" && invitedUser.tally_error) {
|
||||
return invitedUser.tally_error;
|
||||
}
|
||||
|
||||
if (invitedUser.tally_status === "READY" && invitedUser.tally_understanding) {
|
||||
return "Stored and ready for activation";
|
||||
}
|
||||
|
||||
if (invitedUser.tally_status === "READY") {
|
||||
return "No matching Tally submission found";
|
||||
}
|
||||
|
||||
if (invitedUser.tally_status === "RUNNING") {
|
||||
return "Extraction in progress";
|
||||
}
|
||||
|
||||
return "Waiting to run";
|
||||
}
|
||||
|
||||
function isActionPending(
|
||||
pendingInviteAction: string | null,
|
||||
action: "retry" | "revoke",
|
||||
invitedUserId: string,
|
||||
) {
|
||||
return pendingInviteAction === `${action}:${invitedUserId}`;
|
||||
}
|
||||
|
||||
export function InvitedUsersTable({
|
||||
invitedUsers,
|
||||
isLoading,
|
||||
isRefreshing,
|
||||
pendingInviteAction,
|
||||
onRetryTally,
|
||||
onRevoke,
|
||||
}: Props) {
|
||||
return (
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<div className="flex flex-col gap-1">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">Invited users</h2>
|
||||
<p className="text-sm text-zinc-600">
|
||||
Live invite state, claim status, and Tally pre-seeding progress.
|
||||
</p>
|
||||
</div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
{isRefreshing ? "Refreshing" : `${invitedUsers.length} total`}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="overflow-hidden rounded-2xl border border-zinc-200">
|
||||
<Table>
|
||||
<TableHeader className="bg-zinc-50">
|
||||
<TableRow>
|
||||
<TableHead>Email</TableHead>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>Invite</TableHead>
|
||||
<TableHead>Tally</TableHead>
|
||||
<TableHead>Claimed User</TableHead>
|
||||
<TableHead>Created</TableHead>
|
||||
<TableHead className="text-right">Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{isLoading ? (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={7}
|
||||
className="py-10 text-center text-zinc-500"
|
||||
>
|
||||
Loading invited users...
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
) : invitedUsers.length === 0 ? (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={7}
|
||||
className="py-10 text-center text-zinc-500"
|
||||
>
|
||||
No invited users yet
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
) : (
|
||||
invitedUsers.map((invitedUser) => (
|
||||
<TableRow key={invitedUser.id} className="align-top">
|
||||
<TableCell className="font-medium text-zinc-900">
|
||||
{invitedUser.email}
|
||||
</TableCell>
|
||||
<TableCell>{invitedUser.name || "-"}</TableCell>
|
||||
<TableCell>
|
||||
<Badge variant={getInviteBadgeVariant(invitedUser.status)}>
|
||||
{invitedUser.status}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex max-w-xs flex-col gap-2">
|
||||
<Badge
|
||||
variant={getTallyBadgeVariant(invitedUser.tally_status)}
|
||||
>
|
||||
{invitedUser.tally_status}
|
||||
</Badge>
|
||||
<span className="text-xs text-zinc-500">
|
||||
{getTallySummary(invitedUser)}
|
||||
</span>
|
||||
<span className="text-xs text-zinc-400">
|
||||
{formatDate(invitedUser.tally_computed_at ?? undefined)}
|
||||
</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell className="font-mono text-xs text-zinc-500">
|
||||
{invitedUser.auth_user_id || "-"}
|
||||
</TableCell>
|
||||
<TableCell className="text-sm text-zinc-500">
|
||||
{formatDate(invitedUser.created_at)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex justify-end gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
disabled={invitedUser.status === "REVOKED"}
|
||||
loading={isActionPending(
|
||||
pendingInviteAction,
|
||||
"retry",
|
||||
invitedUser.id,
|
||||
)}
|
||||
onClick={() => onRetryTally(invitedUser.id)}
|
||||
>
|
||||
Retry Tally
|
||||
</Button>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="small"
|
||||
disabled={invitedUser.status !== "INVITED"}
|
||||
loading={isActionPending(
|
||||
pendingInviteAction,
|
||||
"revoke",
|
||||
invitedUser.id,
|
||||
)}
|
||||
onClick={() => onRevoke(invitedUser.id)}
|
||||
>
|
||||
Revoke
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))
|
||||
)}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,11 +1,16 @@
|
||||
import { withRoleAccess } from "@/lib/withRoleAccess";
|
||||
import { AdminUsersPage } from "./components/AdminUsersPage/AdminUsersPage";
|
||||
import React from "react";
|
||||
|
||||
function AdminUsers() {
|
||||
return <AdminUsersPage />;
|
||||
return (
|
||||
<div>
|
||||
<h1>Users Dashboard</h1>
|
||||
{/* Add your admin-only content here */}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default async function AdminUsersRoute() {
|
||||
export default async function AdminUsersPage() {
|
||||
"use server";
|
||||
const withAdminAccess = await withRoleAccess(["admin"]);
|
||||
const ProtectedAdminUsers = await withAdminAccess(AdminUsers);
|
||||
|
||||
@@ -1,197 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { BulkInvitedUsersResponse } from "@/app/api/__generated__/models/bulkInvitedUsersResponse";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import {
|
||||
getGetV2ListInvitedUsersQueryKey,
|
||||
useGetV2ListInvitedUsers,
|
||||
usePostV2BulkCreateInvitedUsers,
|
||||
usePostV2CreateInvitedUser,
|
||||
usePostV2RetryInvitedUserTally,
|
||||
usePostV2RevokeInvitedUser,
|
||||
} from "@/app/api/__generated__/endpoints/admin/admin";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { type FormEvent, useState } from "react";
|
||||
|
||||
function getErrorMessage(error: unknown) {
|
||||
if (error instanceof Error) {
|
||||
return error.message;
|
||||
}
|
||||
|
||||
return "Something went wrong";
|
||||
}
|
||||
|
||||
export function useAdminUsersPage() {
|
||||
const queryClient = useQueryClient();
|
||||
const { toast } = useToast();
|
||||
const [email, setEmail] = useState("");
|
||||
const [name, setName] = useState("");
|
||||
const [bulkInviteFile, setBulkInviteFile] = useState<File | null>(null);
|
||||
const [bulkInviteInputKey, setBulkInviteInputKey] = useState(0);
|
||||
const [lastBulkInviteResult, setLastBulkInviteResult] =
|
||||
useState<BulkInvitedUsersResponse | null>(null);
|
||||
const [pendingInviteAction, setPendingInviteAction] = useState<string | null>(
|
||||
null,
|
||||
);
|
||||
|
||||
const invitedUsersQuery = useGetV2ListInvitedUsers(undefined, {
|
||||
query: {
|
||||
select: okData,
|
||||
refetchInterval: 30_000,
|
||||
},
|
||||
});
|
||||
|
||||
const createInvitedUserMutation = usePostV2CreateInvitedUser({
|
||||
mutation: {
|
||||
onSuccess: async () => {
|
||||
setEmail("");
|
||||
setName("");
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListInvitedUsersQueryKey(),
|
||||
});
|
||||
toast({
|
||||
title: "Invited user created",
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const bulkCreateInvitedUsersMutation = usePostV2BulkCreateInvitedUsers({
|
||||
mutation: {
|
||||
onSuccess: async (response) => {
|
||||
const result = okData(response) ?? null;
|
||||
setBulkInviteFile(null);
|
||||
setBulkInviteInputKey((currentValue) => currentValue + 1);
|
||||
setLastBulkInviteResult(result);
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListInvitedUsersQueryKey(),
|
||||
});
|
||||
toast({
|
||||
title: result
|
||||
? `${result.created_count} invites created`
|
||||
: "Bulk invite upload complete",
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const retryInvitedUserTallyMutation = usePostV2RetryInvitedUserTally({
|
||||
mutation: {
|
||||
onSuccess: async () => {
|
||||
setPendingInviteAction(null);
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListInvitedUsersQueryKey(),
|
||||
});
|
||||
toast({
|
||||
title: "Tally pre-seeding restarted",
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
setPendingInviteAction(null);
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const revokeInvitedUserMutation = usePostV2RevokeInvitedUser({
|
||||
mutation: {
|
||||
onSuccess: async () => {
|
||||
setPendingInviteAction(null);
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListInvitedUsersQueryKey(),
|
||||
});
|
||||
toast({
|
||||
title: "Invite revoked",
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
setPendingInviteAction(null);
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
function handleCreateInvite(event: FormEvent<HTMLFormElement>) {
|
||||
event.preventDefault();
|
||||
|
||||
createInvitedUserMutation.mutate({
|
||||
data: {
|
||||
email,
|
||||
name: name.trim() || null,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function handleRetryTally(invitedUserId: string) {
|
||||
setPendingInviteAction(`retry:${invitedUserId}`);
|
||||
retryInvitedUserTallyMutation.mutate({ invitedUserId });
|
||||
}
|
||||
|
||||
function handleBulkInviteFileChange(file: File | null) {
|
||||
setBulkInviteFile(file);
|
||||
}
|
||||
|
||||
function handleBulkInviteSubmit(event: FormEvent<HTMLFormElement>) {
|
||||
event.preventDefault();
|
||||
|
||||
if (!bulkInviteFile) {
|
||||
return;
|
||||
}
|
||||
|
||||
bulkCreateInvitedUsersMutation.mutate({
|
||||
data: {
|
||||
file: bulkInviteFile,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function handleRevoke(invitedUserId: string) {
|
||||
setPendingInviteAction(`revoke:${invitedUserId}`);
|
||||
revokeInvitedUserMutation.mutate({ invitedUserId });
|
||||
}
|
||||
|
||||
return {
|
||||
email,
|
||||
name,
|
||||
bulkInviteFile,
|
||||
bulkInviteInputKey,
|
||||
lastBulkInviteResult,
|
||||
invitedUsers: invitedUsersQuery.data?.invited_users ?? [],
|
||||
invitedUsersError: invitedUsersQuery.error,
|
||||
isLoadingInvitedUsers: invitedUsersQuery.isLoading,
|
||||
isRefreshingInvitedUsers: invitedUsersQuery.isFetching,
|
||||
isCreatingInvite: createInvitedUserMutation.isPending,
|
||||
isBulkInviting: bulkCreateInvitedUsersMutation.isPending,
|
||||
pendingInviteAction,
|
||||
setEmail,
|
||||
setName,
|
||||
handleBulkInviteFileChange,
|
||||
handleBulkInviteSubmit,
|
||||
handleCreateInvite,
|
||||
handleRetryTally,
|
||||
handleRevoke,
|
||||
};
|
||||
}
|
||||
@@ -1,8 +1,14 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/molecules/DropdownMenu/DropdownMenu";
|
||||
import { SidebarProvider } from "@/components/ui/sidebar";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { UploadSimple } from "@phosphor-icons/react";
|
||||
import { DotsThree, UploadSimple } from "@phosphor-icons/react";
|
||||
import { useCallback, useRef, useState } from "react";
|
||||
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
||||
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
|
||||
@@ -86,6 +92,7 @@ export function CopilotPage() {
|
||||
// Delete functionality
|
||||
sessionToDelete,
|
||||
isDeleting,
|
||||
handleDeleteClick,
|
||||
handleConfirmDelete,
|
||||
handleCancelDelete,
|
||||
} = useCopilotPage();
|
||||
@@ -141,6 +148,38 @@ export function CopilotPage() {
|
||||
isUploadingFiles={isUploadingFiles}
|
||||
droppedFiles={droppedFiles}
|
||||
onDroppedFilesConsumed={handleDroppedFilesConsumed}
|
||||
headerSlot={
|
||||
isMobile && sessionId ? (
|
||||
<div className="flex justify-end">
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<button
|
||||
className="rounded p-1.5 hover:bg-neutral-100"
|
||||
aria-label="More actions"
|
||||
>
|
||||
<DotsThree className="h-5 w-5 text-neutral-600" />
|
||||
</button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
onClick={() => {
|
||||
const session = sessions.find(
|
||||
(s) => s.id === sessionId,
|
||||
);
|
||||
if (session) {
|
||||
handleDeleteClick(session.id, session.title);
|
||||
}
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
className="text-red-600 focus:bg-red-50 focus:text-red-600"
|
||||
>
|
||||
Delete chat
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
) : undefined
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
|
||||
import { UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { LayoutGroup, motion } from "framer-motion";
|
||||
import { ReactNode } from "react";
|
||||
import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
|
||||
import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
|
||||
import { EmptySession } from "../EmptySession/EmptySession";
|
||||
@@ -20,6 +21,7 @@ export interface ChatContainerProps {
|
||||
onSend: (message: string, files?: File[]) => void | Promise<void>;
|
||||
onStop: () => void;
|
||||
isUploadingFiles?: boolean;
|
||||
headerSlot?: ReactNode;
|
||||
/** Files dropped onto the chat window. */
|
||||
droppedFiles?: File[];
|
||||
/** Called after droppedFiles have been consumed by ChatInput. */
|
||||
@@ -38,6 +40,7 @@ export const ChatContainer = ({
|
||||
onSend,
|
||||
onStop,
|
||||
isUploadingFiles,
|
||||
headerSlot,
|
||||
droppedFiles,
|
||||
onDroppedFilesConsumed,
|
||||
}: ChatContainerProps) => {
|
||||
@@ -60,6 +63,7 @@ export const ChatContainer = ({
|
||||
status={status}
|
||||
error={error}
|
||||
isLoading={isLoadingSession}
|
||||
headerSlot={headerSlot}
|
||||
sessionID={sessionId}
|
||||
/>
|
||||
<motion.div
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { useCopilotUIStore } from "@/app/(platform)/copilot/store";
|
||||
import { ChangeEvent, FormEvent, useEffect, useState } from "react";
|
||||
|
||||
interface Args {
|
||||
@@ -16,6 +17,16 @@ export function useChatInput({
|
||||
}: Args) {
|
||||
const [value, setValue] = useState("");
|
||||
const [isSending, setIsSending] = useState(false);
|
||||
const { initialPrompt, setInitialPrompt } = useCopilotUIStore();
|
||||
|
||||
useEffect(
|
||||
function consumeInitialPrompt() {
|
||||
if (!initialPrompt) return;
|
||||
setValue((prev) => (prev.length === 0 ? initialPrompt : prev));
|
||||
setInitialPrompt(null);
|
||||
},
|
||||
[initialPrompt, setInitialPrompt],
|
||||
);
|
||||
|
||||
useEffect(
|
||||
function focusOnMount() {
|
||||
|
||||
@@ -30,6 +30,7 @@ interface Props {
|
||||
status: string;
|
||||
error: Error | undefined;
|
||||
isLoading: boolean;
|
||||
headerSlot?: React.ReactNode;
|
||||
sessionID?: string | null;
|
||||
}
|
||||
|
||||
@@ -101,6 +102,7 @@ export function ChatMessagesContainer({
|
||||
status,
|
||||
error,
|
||||
isLoading,
|
||||
headerSlot,
|
||||
sessionID,
|
||||
}: Props) {
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
@@ -133,6 +135,7 @@ export function ChatMessagesContainer({
|
||||
return (
|
||||
<Conversation className="min-h-0 flex-1">
|
||||
<ConversationContent className="flex flex-1 flex-col gap-6 px-3 py-6">
|
||||
{headerSlot}
|
||||
{isLoading && messages.length === 0 && (
|
||||
<div
|
||||
className="flex flex-1 items-center justify-center"
|
||||
|
||||
@@ -37,7 +37,6 @@ import { useCopilotUIStore } from "../../store";
|
||||
import { NotificationToggle } from "./components/NotificationToggle/NotificationToggle";
|
||||
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
|
||||
import { PulseLoader } from "../PulseLoader/PulseLoader";
|
||||
import { UsageLimits } from "../UsageLimits/UsageLimits";
|
||||
|
||||
export function ChatSidebar() {
|
||||
const { state } = useSidebar();
|
||||
@@ -257,10 +256,11 @@ export function ChatSidebar() {
|
||||
<Text variant="h3" size="body-medium">
|
||||
Your chats
|
||||
</Text>
|
||||
<div className="flex items-center">
|
||||
<UsageLimits />
|
||||
<div className="relative left-5 flex items-center gap-1">
|
||||
<NotificationToggle />
|
||||
<SidebarTrigger />
|
||||
<div className="relative left-1">
|
||||
<SidebarTrigger />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{sessionId ? (
|
||||
|
||||
@@ -7,7 +7,6 @@ import {
|
||||
PopoverTrigger,
|
||||
} from "@/components/molecules/Popover/Popover";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Bell, BellRinging, BellSlash } from "@phosphor-icons/react";
|
||||
import { useCopilotUIStore } from "../../../../store";
|
||||
@@ -49,7 +48,10 @@ export function NotificationToggle() {
|
||||
return (
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<Button variant="ghost" size="icon" aria-label="Notification settings">
|
||||
<button
|
||||
className="rounded p-1 text-black transition-colors hover:bg-zinc-50"
|
||||
aria-label="Notification settings"
|
||||
>
|
||||
{!isNotificationsEnabled ? (
|
||||
<BellSlash className="!size-5" />
|
||||
) : isSoundEnabled ? (
|
||||
@@ -57,7 +59,7 @@ export function NotificationToggle() {
|
||||
) : (
|
||||
<Bell className="!size-5" />
|
||||
)}
|
||||
</Button>
|
||||
</button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent align="start" className="w-56 p-3">
|
||||
<div className="flex flex-col gap-3">
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { useGetV2GetSuggestedPrompts } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { SpinnerGapIcon } from "@phosphor-icons/react";
|
||||
@@ -35,38 +33,15 @@ export function EmptySession({
|
||||
}: Props) {
|
||||
const { user } = useSupabase();
|
||||
const greetingName = getGreetingName(user);
|
||||
|
||||
const { data: suggestedPromptsResponse, isLoading: isLoadingPrompts } =
|
||||
useGetV2GetSuggestedPrompts({
|
||||
query: { staleTime: Infinity },
|
||||
});
|
||||
const customPrompts =
|
||||
suggestedPromptsResponse?.status === 200
|
||||
? suggestedPromptsResponse.data.prompts
|
||||
: undefined;
|
||||
const quickActions = getQuickActions(customPrompts);
|
||||
const quickActions = getQuickActions();
|
||||
const [loadingAction, setLoadingAction] = useState<string | null>(null);
|
||||
const [inputPlaceholder, setInputPlaceholder] = useState(
|
||||
getInputPlaceholder(),
|
||||
);
|
||||
|
||||
// Use matchMedia instead of resize event — fires only when crossing
|
||||
// the 500px and 1081px breakpoints defined in getInputPlaceholder(),
|
||||
// rather than dozens of times per second during a window drag.
|
||||
useEffect(() => {
|
||||
function update() {
|
||||
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
|
||||
}
|
||||
const mq500 = window.matchMedia("(min-width: 500px)");
|
||||
const mq1081 = window.matchMedia("(min-width: 1081px)");
|
||||
update();
|
||||
mq500.addEventListener("change", update);
|
||||
mq1081.addEventListener("change", update);
|
||||
return () => {
|
||||
mq500.removeEventListener("change", update);
|
||||
mq1081.removeEventListener("change", update);
|
||||
};
|
||||
}, []);
|
||||
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
|
||||
}, [window.innerWidth]);
|
||||
|
||||
async function handleQuickActionClick(action: string) {
|
||||
if (isCreatingSession || loadingAction) return;
|
||||
@@ -116,32 +91,28 @@ export function EmptySession({
|
||||
</div>
|
||||
|
||||
<div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
||||
{isLoadingPrompts
|
||||
? Array.from({ length: 3 }, (_, i) => (
|
||||
<Skeleton key={i} className="h-10 w-64 shrink-0 rounded-full" />
|
||||
))
|
||||
: quickActions.map((action) => (
|
||||
<Button
|
||||
key={action}
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={() => void handleQuickActionClick(action)}
|
||||
disabled={isCreatingSession || loadingAction !== null}
|
||||
aria-busy={loadingAction === action}
|
||||
leftIcon={
|
||||
loadingAction === action ? (
|
||||
<SpinnerGapIcon
|
||||
className="h-4 w-4 animate-spin"
|
||||
weight="bold"
|
||||
/>
|
||||
) : null
|
||||
}
|
||||
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
|
||||
>
|
||||
{action}
|
||||
</Button>
|
||||
))}
|
||||
{quickActions.map((action) => (
|
||||
<Button
|
||||
key={action}
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={() => void handleQuickActionClick(action)}
|
||||
disabled={isCreatingSession || loadingAction !== null}
|
||||
aria-busy={loadingAction === action}
|
||||
leftIcon={
|
||||
loadingAction === action ? (
|
||||
<SpinnerGapIcon
|
||||
className="h-4 w-4 animate-spin"
|
||||
weight="bold"
|
||||
/>
|
||||
) : null
|
||||
}
|
||||
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
|
||||
>
|
||||
{action}
|
||||
</Button>
|
||||
))}
|
||||
</div>
|
||||
</motion.div>
|
||||
</div>
|
||||
|
||||
@@ -12,17 +12,12 @@ export function getInputPlaceholder(width?: number) {
|
||||
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
|
||||
}
|
||||
|
||||
const DEFAULT_QUICK_ACTIONS = [
|
||||
"I don't know where to start, just ask me stuff",
|
||||
"I do the same thing every week and it's killing me",
|
||||
"Help me find where I'm wasting my time",
|
||||
];
|
||||
|
||||
export function getQuickActions(customPrompts?: string[]) {
|
||||
if (customPrompts && customPrompts.length > 0) {
|
||||
return customPrompts;
|
||||
}
|
||||
return DEFAULT_QUICK_ACTIONS;
|
||||
export function getQuickActions() {
|
||||
return [
|
||||
"I don't know where to start, just ask me stuff",
|
||||
"I do the same thing every week and it's killing me",
|
||||
"Help me find where I'm wasting my time",
|
||||
];
|
||||
}
|
||||
|
||||
export function getGreetingName(user?: User | null) {
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/molecules/Popover/Popover";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { ChartBar } from "@phosphor-icons/react";
|
||||
import { useUsageLimits } from "./useUsageLimits";
|
||||
|
||||
const MS_PER_MINUTE = 60_000;
|
||||
const MS_PER_HOUR = 3_600_000;
|
||||
|
||||
function formatResetTime(resetsAt: Date | string): string {
|
||||
const resetDate =
|
||||
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
|
||||
const now = new Date();
|
||||
const diffMs = resetDate.getTime() - now.getTime();
|
||||
if (diffMs <= 0) return "now";
|
||||
|
||||
const hours = Math.floor(diffMs / MS_PER_HOUR);
|
||||
|
||||
// Under 24h: show relative time ("in 4h 23m")
|
||||
if (hours < 24) {
|
||||
const minutes = Math.floor((diffMs % MS_PER_HOUR) / MS_PER_MINUTE);
|
||||
if (hours > 0) return `in ${hours}h ${minutes}m`;
|
||||
return `in ${minutes}m`;
|
||||
}
|
||||
|
||||
// Over 24h: show day and time in local timezone ("Mon 12:00 AM PST")
|
||||
return resetDate.toLocaleString(undefined, {
|
||||
weekday: "short",
|
||||
hour: "numeric",
|
||||
minute: "2-digit",
|
||||
timeZoneName: "short",
|
||||
});
|
||||
}
|
||||
|
||||
function UsageBar({
|
||||
label,
|
||||
used,
|
||||
limit,
|
||||
resetsAt,
|
||||
}: {
|
||||
label: string;
|
||||
used: number;
|
||||
limit: number;
|
||||
resetsAt: Date | string;
|
||||
}) {
|
||||
if (limit <= 0) return null;
|
||||
|
||||
const rawPercent = (used / limit) * 100;
|
||||
const percent = Math.min(100, Math.round(rawPercent));
|
||||
const isHigh = percent >= 80;
|
||||
const percentLabel =
|
||||
used > 0 && percent === 0 ? "<1% used" : `${percent}% used`;
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="flex items-baseline justify-between">
|
||||
<span className="text-xs font-medium text-neutral-700">{label}</span>
|
||||
<span className="text-[11px] tabular-nums text-neutral-500">
|
||||
{percentLabel}
|
||||
</span>
|
||||
</div>
|
||||
<div className="text-[10px] text-neutral-400">
|
||||
Resets {formatResetTime(resetsAt)}
|
||||
</div>
|
||||
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
|
||||
<div
|
||||
className={`h-full rounded-full transition-[width] duration-300 ease-out ${
|
||||
isHigh ? "bg-orange-500" : "bg-blue-500"
|
||||
}`}
|
||||
style={{ width: `${Math.max(used > 0 ? 1 : 0, percent)}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function UsagePanelContent({
|
||||
usage,
|
||||
showBillingLink = true,
|
||||
}: {
|
||||
usage: CoPilotUsageStatus;
|
||||
showBillingLink?: boolean;
|
||||
}) {
|
||||
const hasDailyLimit = usage.daily.limit > 0;
|
||||
const hasWeeklyLimit = usage.weekly.limit > 0;
|
||||
|
||||
if (!hasDailyLimit && !hasWeeklyLimit) {
|
||||
return (
|
||||
<div className="text-xs text-neutral-500">No usage limits configured</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-3">
|
||||
<div className="text-xs font-semibold text-neutral-800">Usage limits</div>
|
||||
{hasDailyLimit && (
|
||||
<UsageBar
|
||||
label="Today"
|
||||
used={usage.daily.used}
|
||||
limit={usage.daily.limit}
|
||||
resetsAt={usage.daily.resets_at}
|
||||
/>
|
||||
)}
|
||||
{hasWeeklyLimit && (
|
||||
<UsageBar
|
||||
label="This week"
|
||||
used={usage.weekly.used}
|
||||
limit={usage.weekly.limit}
|
||||
resetsAt={usage.weekly.resets_at}
|
||||
/>
|
||||
)}
|
||||
{showBillingLink && (
|
||||
<a
|
||||
href="/profile/credits"
|
||||
className="text-[11px] text-blue-600 hover:underline"
|
||||
>
|
||||
Learn more about usage limits
|
||||
</a>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function UsageLimits() {
|
||||
const { data: usage, isLoading } = useUsageLimits();
|
||||
|
||||
if (isLoading || !usage) return null;
|
||||
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
|
||||
|
||||
return (
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<Button variant="ghost" size="icon" aria-label="Usage limits">
|
||||
<ChartBar className="!size-5" weight="light" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent align="start" className="w-64 p-3">
|
||||
<UsagePanelContent usage={usage} />
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
@@ -1,121 +0,0 @@
|
||||
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { UsageLimits } from "../UsageLimits";
|
||||
|
||||
// Mock the useUsageLimits hook
|
||||
const mockUseUsageLimits = vi.fn();
|
||||
vi.mock("../useUsageLimits", () => ({
|
||||
useUsageLimits: () => mockUseUsageLimits(),
|
||||
}));
|
||||
|
||||
// Mock Popover to render children directly (Radix portals don't work in happy-dom)
|
||||
vi.mock("@/components/molecules/Popover/Popover", () => ({
|
||||
Popover: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
PopoverTrigger: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
PopoverContent: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
mockUseUsageLimits.mockReset();
|
||||
});
|
||||
|
||||
function makeUsage({
|
||||
dailyUsed = 500,
|
||||
dailyLimit = 10000,
|
||||
weeklyUsed = 2000,
|
||||
weeklyLimit = 50000,
|
||||
}: {
|
||||
dailyUsed?: number;
|
||||
dailyLimit?: number;
|
||||
weeklyUsed?: number;
|
||||
weeklyLimit?: number;
|
||||
} = {}) {
|
||||
const future = new Date(Date.now() + 3600 * 1000); // 1h from now
|
||||
return {
|
||||
daily: { used: dailyUsed, limit: dailyLimit, resets_at: future },
|
||||
weekly: { used: weeklyUsed, limit: weeklyLimit, resets_at: future },
|
||||
};
|
||||
}
|
||||
|
||||
describe("UsageLimits", () => {
|
||||
it("renders nothing while loading", () => {
|
||||
mockUseUsageLimits.mockReturnValue({ data: undefined, isLoading: true });
|
||||
const { container } = render(<UsageLimits />);
|
||||
expect(container.innerHTML).toBe("");
|
||||
});
|
||||
|
||||
it("renders nothing when no limits are configured", () => {
|
||||
mockUseUsageLimits.mockReturnValue({
|
||||
data: makeUsage({ dailyLimit: 0, weeklyLimit: 0 }),
|
||||
isLoading: false,
|
||||
});
|
||||
const { container } = render(<UsageLimits />);
|
||||
expect(container.innerHTML).toBe("");
|
||||
});
|
||||
|
||||
it("renders the usage button when limits exist", () => {
|
||||
mockUseUsageLimits.mockReturnValue({
|
||||
data: makeUsage(),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
expect(screen.getByRole("button", { name: /usage limits/i })).toBeDefined();
|
||||
});
|
||||
|
||||
it("displays daily and weekly usage percentages", () => {
|
||||
mockUseUsageLimits.mockReturnValue({
|
||||
data: makeUsage({ dailyUsed: 5000, dailyLimit: 10000 }),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
expect(screen.getByText("50% used")).toBeDefined();
|
||||
expect(screen.getByText("Today")).toBeDefined();
|
||||
expect(screen.getByText("This week")).toBeDefined();
|
||||
expect(screen.getByText("Usage limits")).toBeDefined();
|
||||
});
|
||||
|
||||
it("shows only weekly bar when daily limit is 0", () => {
|
||||
mockUseUsageLimits.mockReturnValue({
|
||||
data: makeUsage({
|
||||
dailyLimit: 0,
|
||||
weeklyUsed: 25000,
|
||||
weeklyLimit: 50000,
|
||||
}),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
expect(screen.getByText("This week")).toBeDefined();
|
||||
expect(screen.queryByText("Today")).toBeNull();
|
||||
});
|
||||
|
||||
it("caps percentage at 100% when over limit", () => {
|
||||
mockUseUsageLimits.mockReturnValue({
|
||||
data: makeUsage({ dailyUsed: 15000, dailyLimit: 10000 }),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
expect(screen.getByText("100% used")).toBeDefined();
|
||||
});
|
||||
|
||||
it("shows learn more link to credits page", () => {
|
||||
mockUseUsageLimits.mockReturnValue({
|
||||
data: makeUsage(),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
const link = screen.getByText("Learn more about usage limits");
|
||||
expect(link).toBeDefined();
|
||||
expect(link.closest("a")?.getAttribute("href")).toBe("/profile/credits");
|
||||
});
|
||||
});
|
||||
@@ -1,12 +0,0 @@
|
||||
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
|
||||
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
|
||||
export function useUsageLimits() {
|
||||
return useGetV2GetCopilotUsage({
|
||||
query: {
|
||||
select: (res) => res.data as CoPilotUsageStatus,
|
||||
refetchInterval: 30000,
|
||||
staleTime: 10000,
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -7,6 +7,10 @@ export interface DeleteTarget {
|
||||
}
|
||||
|
||||
interface CopilotUIState {
|
||||
/** Prompt extracted from URL hash (e.g. /copilot#prompt=...) for input prefill. */
|
||||
initialPrompt: string | null;
|
||||
setInitialPrompt: (prompt: string | null) => void;
|
||||
|
||||
sessionToDelete: DeleteTarget | null;
|
||||
setSessionToDelete: (target: DeleteTarget | null) => void;
|
||||
|
||||
@@ -31,6 +35,9 @@ interface CopilotUIState {
|
||||
}
|
||||
|
||||
export const useCopilotUIStore = create<CopilotUIState>((set) => ({
|
||||
initialPrompt: null,
|
||||
setInitialPrompt: (prompt) => set({ initialPrompt: prompt }),
|
||||
|
||||
sessionToDelete: null,
|
||||
setSessionToDelete: (target) => set({ sessionToDelete: target }),
|
||||
|
||||
|
||||
@@ -19,6 +19,42 @@ import { useCopilotStream } from "./useCopilotStream";
|
||||
const TITLE_POLL_INTERVAL_MS = 2_000;
|
||||
const TITLE_POLL_MAX_ATTEMPTS = 5;
|
||||
|
||||
/**
|
||||
* Extract a prompt from the URL hash fragment.
|
||||
* Supports: /copilot#prompt=URL-encoded-text
|
||||
* Optionally auto-submits if ?autosubmit=true is in the query string.
|
||||
* Returns null if no prompt is present.
|
||||
*/
|
||||
function extractPromptFromUrl(): {
|
||||
prompt: string;
|
||||
autosubmit: boolean;
|
||||
} | null {
|
||||
if (typeof window === "undefined") return null;
|
||||
|
||||
const hash = window.location.hash;
|
||||
if (!hash) return null;
|
||||
|
||||
const hashParams = new URLSearchParams(hash.slice(1));
|
||||
const prompt = hashParams.get("prompt");
|
||||
|
||||
if (!prompt || !prompt.trim()) return null;
|
||||
|
||||
const searchParams = new URLSearchParams(window.location.search);
|
||||
const autosubmit = searchParams.get("autosubmit") === "true";
|
||||
|
||||
// Clean up hash + autosubmit param only (preserve other query params)
|
||||
const cleanURL = new URL(window.location.href);
|
||||
cleanURL.hash = "";
|
||||
cleanURL.searchParams.delete("autosubmit");
|
||||
window.history.replaceState(
|
||||
null,
|
||||
"",
|
||||
`${cleanURL.pathname}${cleanURL.search}`,
|
||||
);
|
||||
|
||||
return { prompt: prompt.trim(), autosubmit };
|
||||
}
|
||||
|
||||
interface UploadedFile {
|
||||
file_id: string;
|
||||
name: string;
|
||||
@@ -127,6 +163,28 @@ export function useCopilotPage() {
|
||||
}
|
||||
}, [sessionId, pendingMessage, sendMessage]);
|
||||
|
||||
// --- Extract prompt from URL hash on mount (e.g. /copilot#prompt=Hello) ---
|
||||
const { setInitialPrompt } = useCopilotUIStore();
|
||||
const hasProcessedUrlPrompt = useRef(false);
|
||||
useEffect(() => {
|
||||
if (hasProcessedUrlPrompt.current) return;
|
||||
|
||||
const urlPrompt = extractPromptFromUrl();
|
||||
if (!urlPrompt) return;
|
||||
|
||||
hasProcessedUrlPrompt.current = true;
|
||||
|
||||
if (urlPrompt.autosubmit) {
|
||||
setPendingMessage(urlPrompt.prompt);
|
||||
void createSession().catch(() => {
|
||||
setPendingMessage(null);
|
||||
setInitialPrompt(urlPrompt.prompt);
|
||||
});
|
||||
} else {
|
||||
setInitialPrompt(urlPrompt.prompt);
|
||||
}
|
||||
}, [createSession, setInitialPrompt]);
|
||||
|
||||
async function uploadFiles(
|
||||
files: File[],
|
||||
sid: string,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import {
|
||||
getGetV2GetCopilotUsageQueryKey,
|
||||
getGetV2GetSessionQueryKey,
|
||||
postV2CancelSessionTask,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
@@ -308,9 +307,6 @@ export function useCopilotStream({
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
||||
});
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetCopilotUsageQueryKey(),
|
||||
});
|
||||
if (status === "ready") {
|
||||
reconnectAttemptsRef.current = 0;
|
||||
hasShownDisconnectToast.current = false;
|
||||
|
||||
@@ -7,6 +7,7 @@ import { LibraryActionSubHeader } from "../LibraryActionSubHeader/LibraryActionS
|
||||
import { LibraryAgentCard } from "../LibraryAgentCard/LibraryAgentCard";
|
||||
import { LibraryFolder } from "../LibraryFolder/LibraryFolder";
|
||||
import { LibrarySubSection } from "../LibrarySubSection/LibrarySubSection";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { ArrowLeftIcon, HeartIcon } from "@phosphor-icons/react";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Tab } from "../LibraryTabs/LibraryTabs";
|
||||
@@ -135,21 +136,22 @@ export function LibraryAgentList({
|
||||
<div>
|
||||
{selectedFolderId && (
|
||||
<div className="mb-4 flex items-center gap-2">
|
||||
<button
|
||||
type="button"
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={() => onFolderSelect(null)}
|
||||
className="inline-flex items-center gap-1 text-sm text-zinc-500 hover:text-zinc-900"
|
||||
className="gap-1 text-zinc-500 hover:text-zinc-900"
|
||||
>
|
||||
<ArrowLeftIcon className="h-4 w-4" />
|
||||
My Library
|
||||
</button>
|
||||
</Button>
|
||||
{currentFolder && (
|
||||
<>
|
||||
<Text variant="body" className="text-zinc-400">
|
||||
<Text variant="small" className="text-zinc-400">
|
||||
/
|
||||
</Text>
|
||||
<Text variant="large" className="text-zinc-700">
|
||||
{currentFolder.name}
|
||||
<Text variant="h4" className="text-zinc-700">
|
||||
{currentFolder.icon} {currentFolder.name}
|
||||
</Text>
|
||||
</>
|
||||
)}
|
||||
|
||||
@@ -4,7 +4,6 @@ import { useGetV2ListLibraryAgentsInfinite } from "@/app/api/__generated__/endpo
|
||||
import { getGetV2ListLibraryAgentsQueryKey } from "@/app/api/__generated__/endpoints/library/library";
|
||||
import {
|
||||
useGetV2ListLibraryFolders,
|
||||
useGetV2GetFolder,
|
||||
usePostV2BulkMoveAgents,
|
||||
getGetV2ListLibraryFoldersQueryKey,
|
||||
} from "@/app/api/__generated__/endpoints/folders/folders";
|
||||
@@ -107,12 +106,9 @@ export function useLibraryAgentList({
|
||||
fetchNextPage: fetchNextPage,
|
||||
};
|
||||
|
||||
const { data: rawFoldersData } = useGetV2ListLibraryFolders(
|
||||
{ parent_id: selectedFolderId ?? undefined },
|
||||
{
|
||||
query: { select: okData },
|
||||
},
|
||||
);
|
||||
const { data: rawFoldersData } = useGetV2ListLibraryFolders(undefined, {
|
||||
query: { select: okData },
|
||||
});
|
||||
|
||||
const foldersData = searchTerm ? undefined : rawFoldersData;
|
||||
|
||||
@@ -189,15 +185,11 @@ export function useLibraryAgentList({
|
||||
});
|
||||
}
|
||||
|
||||
const { data: currentFolderData } = useGetV2GetFolder(
|
||||
selectedFolderId ?? "",
|
||||
{
|
||||
query: { select: okData, enabled: !!selectedFolderId },
|
||||
},
|
||||
);
|
||||
const currentFolder = selectedFolderId ? currentFolderData : null;
|
||||
const currentFolder = selectedFolderId
|
||||
? foldersData?.folders.find((f) => f.id === selectedFolderId)
|
||||
: null;
|
||||
|
||||
const showFolders = !isFavoritesTab;
|
||||
const showFolders = !isFavoritesTab && !selectedFolderId;
|
||||
|
||||
function handleFolderDeleted() {
|
||||
if (selectedFolderId === deletingFolder?.id) {
|
||||
|
||||
@@ -11,8 +11,6 @@ import {
|
||||
|
||||
import { RefundModal } from "./RefundModal";
|
||||
import { CreditTransaction } from "@/lib/autogpt-server-api";
|
||||
import { UsagePanelContent } from "@/app/(platform)/copilot/components/UsageLimits/UsageLimits";
|
||||
import { useUsageLimits } from "@/app/(platform)/copilot/components/UsageLimits/useUsageLimits";
|
||||
|
||||
import {
|
||||
Table,
|
||||
@@ -23,26 +21,6 @@ import {
|
||||
TableRow,
|
||||
} from "@/components/__legacy__/ui/table";
|
||||
|
||||
function CoPilotUsageSection() {
|
||||
const { data: usage, isLoading } = useUsageLimits();
|
||||
const router = useRouter();
|
||||
|
||||
if (isLoading || !usage) return null;
|
||||
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
|
||||
|
||||
return (
|
||||
<div className="my-6 space-y-4">
|
||||
<h3 className="text-lg font-medium">CoPilot Usage Limits</h3>
|
||||
<div className="rounded-lg border border-neutral-200 p-4 dark:border-neutral-700">
|
||||
<UsagePanelContent usage={usage} showBillingLink={false} />
|
||||
</div>
|
||||
<Button className="w-full" onClick={() => router.push("/copilot")}>
|
||||
Open CoPilot
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default function CreditsPage() {
|
||||
const api = useBackendAPI();
|
||||
const {
|
||||
@@ -259,13 +237,11 @@ export default function CreditsPage() {
|
||||
</Button>
|
||||
)}
|
||||
</form>
|
||||
|
||||
{/* CoPilot Usage Limits */}
|
||||
<CoPilotUsageSection />
|
||||
</div>
|
||||
|
||||
<div className="my-6 space-y-4">
|
||||
{/* Payment Portal */}
|
||||
|
||||
<h3 className="text-lg font-medium">Manage Your Payment Methods</h3>
|
||||
<p className="text-neutral-600">
|
||||
You can manage your cards and see your payment history in the
|
||||
|
||||
@@ -1358,52 +1358,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/suggested-prompts": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Suggested Prompts",
|
||||
"description": "Get LLM-generated suggested prompts for the authenticated user.\n\nReturns personalized quick-action prompts based on the user's\nbusiness understanding. Returns an empty list if no custom prompts\nare available.",
|
||||
"operationId": "getV2GetSuggestedPrompts",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/SuggestedPromptsResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/chat/usage": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Copilot Usage",
|
||||
"description": "Get CoPilot usage status for the authenticated user.\n\nReturns current token usage vs limits for daily and weekly windows.",
|
||||
"operationId": "getV2GetCopilotUsage",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/CoPilotUsageStatus" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/credits": {
|
||||
"get": {
|
||||
"tags": ["v1", "credits"],
|
||||
@@ -6692,214 +6646,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/users/admin/invited-users": {
|
||||
"get": {
|
||||
"tags": ["v2", "admin", "users", "admin"],
|
||||
"summary": "List Invited Users",
|
||||
"operationId": "getV2List invited users",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "page",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"default": 1,
|
||||
"title": "Page"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "page_size",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "integer",
|
||||
"maximum": 200,
|
||||
"minimum": 1,
|
||||
"default": 50,
|
||||
"title": "Page Size"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/InvitedUsersResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"post": {
|
||||
"tags": ["v2", "admin", "users", "admin"],
|
||||
"summary": "Create Invited User",
|
||||
"operationId": "postV2Create invited user",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/CreateInvitedUserRequest"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/InvitedUserResponse" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/users/admin/invited-users/bulk": {
|
||||
"post": {
|
||||
"tags": ["v2", "admin", "users", "admin"],
|
||||
"summary": "Bulk Create Invited Users",
|
||||
"operationId": "postV2BulkCreateInvitedUsers",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"multipart/form-data": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Body_postV2BulkCreateInvitedUsers"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/BulkInvitedUsersResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/users/admin/invited-users/{invited_user_id}/retry-tally": {
|
||||
"post": {
|
||||
"tags": ["v2", "admin", "users", "admin"],
|
||||
"summary": "Retry Invited User Tally",
|
||||
"operationId": "postV2Retry invited user tally",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "invited_user_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Invited User Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/InvitedUserResponse" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/users/admin/invited-users/{invited_user_id}/revoke": {
|
||||
"post": {
|
||||
"tags": ["v2", "admin", "users", "admin"],
|
||||
"summary": "Revoke Invited User",
|
||||
"operationId": "postV2Revoke invited user",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "invited_user_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Invited User Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/InvitedUserResponse" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/workspace/files/upload": {
|
||||
"post": {
|
||||
"tags": ["workspace"],
|
||||
@@ -8286,14 +8032,6 @@
|
||||
"required": ["store_listing_version_id"],
|
||||
"title": "Body_postV2Add marketplace agent"
|
||||
},
|
||||
"Body_postV2BulkCreateInvitedUsers": {
|
||||
"properties": {
|
||||
"file": { "type": "string", "format": "binary", "title": "File" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["file"],
|
||||
"title": "Body_postV2BulkCreateInvitedUsers"
|
||||
},
|
||||
"Body_postV2Execute_a_preset": {
|
||||
"properties": {
|
||||
"inputs": {
|
||||
@@ -8328,56 +8066,6 @@
|
||||
"required": ["file"],
|
||||
"title": "Body_postWorkspaceUpload file to workspace"
|
||||
},
|
||||
"BulkInvitedUserRowResponse": {
|
||||
"properties": {
|
||||
"row_number": { "type": "integer", "title": "Row Number" },
|
||||
"email": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Email"
|
||||
},
|
||||
"name": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Name"
|
||||
},
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["CREATED", "SKIPPED", "ERROR"],
|
||||
"title": "Status"
|
||||
},
|
||||
"message": { "type": "string", "title": "Message" },
|
||||
"invited_user": {
|
||||
"anyOf": [
|
||||
{ "$ref": "#/components/schemas/InvitedUserResponse" },
|
||||
{ "type": "null" }
|
||||
]
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["row_number", "status", "message"],
|
||||
"title": "BulkInvitedUserRowResponse"
|
||||
},
|
||||
"BulkInvitedUsersResponse": {
|
||||
"properties": {
|
||||
"created_count": { "type": "integer", "title": "Created Count" },
|
||||
"skipped_count": { "type": "integer", "title": "Skipped Count" },
|
||||
"error_count": { "type": "integer", "title": "Error Count" },
|
||||
"results": {
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/BulkInvitedUserRowResponse"
|
||||
},
|
||||
"type": "array",
|
||||
"title": "Results"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"created_count",
|
||||
"skipped_count",
|
||||
"error_count",
|
||||
"results"
|
||||
],
|
||||
"title": "BulkInvitedUsersResponse"
|
||||
},
|
||||
"BulkMoveAgentsRequest": {
|
||||
"properties": {
|
||||
"agent_ids": {
|
||||
@@ -8477,16 +8165,6 @@
|
||||
"title": "ClarifyingQuestion",
|
||||
"description": "A question that needs user clarification."
|
||||
},
|
||||
"CoPilotUsageStatus": {
|
||||
"properties": {
|
||||
"daily": { "$ref": "#/components/schemas/UsageWindow" },
|
||||
"weekly": { "$ref": "#/components/schemas/UsageWindow" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["daily", "weekly"],
|
||||
"title": "CoPilotUsageStatus",
|
||||
"description": "Current usage status for a user across all windows."
|
||||
},
|
||||
"ContentType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -8564,18 +8242,6 @@
|
||||
"required": ["graph"],
|
||||
"title": "CreateGraph"
|
||||
},
|
||||
"CreateInvitedUserRequest": {
|
||||
"properties": {
|
||||
"email": { "type": "string", "format": "email", "title": "Email" },
|
||||
"name": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Name"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["email"],
|
||||
"title": "CreateInvitedUserRequest"
|
||||
},
|
||||
"CreateSessionResponse": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
@@ -10040,80 +9706,6 @@
|
||||
"title": "InputValidationErrorResponse",
|
||||
"description": "Response when run_agent receives unknown input fields."
|
||||
},
|
||||
"InvitedUserResponse": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
"email": { "type": "string", "title": "Email" },
|
||||
"status": { "$ref": "#/components/schemas/InvitedUserStatus" },
|
||||
"auth_user_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Auth User Id"
|
||||
},
|
||||
"name": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Name"
|
||||
},
|
||||
"tally_understanding": {
|
||||
"anyOf": [
|
||||
{ "additionalProperties": true, "type": "object" },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Tally Understanding"
|
||||
},
|
||||
"tally_status": {
|
||||
"$ref": "#/components/schemas/TallyComputationStatus"
|
||||
},
|
||||
"tally_computed_at": {
|
||||
"anyOf": [
|
||||
{ "type": "string", "format": "date-time" },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Tally Computed At"
|
||||
},
|
||||
"tally_error": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Tally Error"
|
||||
},
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"title": "Created At"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"title": "Updated At"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"id",
|
||||
"email",
|
||||
"status",
|
||||
"tally_status",
|
||||
"created_at",
|
||||
"updated_at"
|
||||
],
|
||||
"title": "InvitedUserResponse"
|
||||
},
|
||||
"InvitedUserStatus": {
|
||||
"type": "string",
|
||||
"enum": ["INVITED", "CLAIMED", "REVOKED"],
|
||||
"title": "InvitedUserStatus"
|
||||
},
|
||||
"InvitedUsersResponse": {
|
||||
"properties": {
|
||||
"invited_users": {
|
||||
"items": { "$ref": "#/components/schemas/InvitedUserResponse" },
|
||||
"type": "array",
|
||||
"title": "Invited Users"
|
||||
},
|
||||
"pagination": { "$ref": "#/components/schemas/Pagination" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["invited_users", "pagination"],
|
||||
"title": "InvitedUsersResponse"
|
||||
},
|
||||
"LibraryAgent": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
@@ -12222,16 +11814,6 @@
|
||||
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
|
||||
{ "type": "null" }
|
||||
]
|
||||
},
|
||||
"total_prompt_tokens": {
|
||||
"type": "integer",
|
||||
"title": "Total Prompt Tokens",
|
||||
"default": 0
|
||||
},
|
||||
"total_completion_tokens": {
|
||||
"type": "integer",
|
||||
"title": "Total Completion Tokens",
|
||||
"default": 0
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
@@ -13076,19 +12658,6 @@
|
||||
"title": "SuggestedGoalResponse",
|
||||
"description": "Response when the goal needs refinement with a suggested alternative."
|
||||
},
|
||||
"SuggestedPromptsResponse": {
|
||||
"properties": {
|
||||
"prompts": {
|
||||
"items": { "type": "string" },
|
||||
"type": "array",
|
||||
"title": "Prompts"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["prompts"],
|
||||
"title": "SuggestedPromptsResponse",
|
||||
"description": "Response model for user-specific suggested prompts."
|
||||
},
|
||||
"SuggestionsResponse": {
|
||||
"properties": {
|
||||
"recent_searches": {
|
||||
@@ -13114,11 +12683,6 @@
|
||||
"required": ["recent_searches", "providers", "top_blocks"],
|
||||
"title": "SuggestionsResponse"
|
||||
},
|
||||
"TallyComputationStatus": {
|
||||
"type": "string",
|
||||
"enum": ["PENDING", "RUNNING", "READY", "FAILED"],
|
||||
"title": "TallyComputationStatus"
|
||||
},
|
||||
"TimezoneResponse": {
|
||||
"properties": {
|
||||
"timezone": {
|
||||
@@ -14629,25 +14193,6 @@
|
||||
"required": ["timezone"],
|
||||
"title": "UpdateTimezoneRequest"
|
||||
},
|
||||
"UsageWindow": {
|
||||
"properties": {
|
||||
"used": { "type": "integer", "title": "Used" },
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"title": "Limit",
|
||||
"description": "Maximum tokens allowed in this window. 0 means unlimited."
|
||||
},
|
||||
"resets_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"title": "Resets At"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["used", "limit", "resets_at"],
|
||||
"title": "UsageWindow",
|
||||
"description": "Usage within a single time window."
|
||||
},
|
||||
"UserHistoryResponse": {
|
||||
"properties": {
|
||||
"history": {
|
||||
|
||||
@@ -288,7 +288,6 @@ const SidebarTrigger = React.forwardRef<
|
||||
ref={ref}
|
||||
data-sidebar="trigger"
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={(event) => {
|
||||
onClick?.(event);
|
||||
toggleSidebar();
|
||||
|
||||
Reference in New Issue
Block a user