Compare commits

..

20 Commits

Author SHA1 Message Date
Bentlybro
9b93a956b4 style: fix trailing whitespace in registry.py 2026-03-16 15:44:45 +00:00
Bentlybro
b236719bbf fix(startup): handle missing AgentNode table in migrate_llm_models
Tests fail with 'relation "platform.AgentNode" does not exist' because
migrate_llm_models() runs during startup and queries a table that doesn't
exist in fresh test databases.

This is an existing bug in the codebase - the function has no error handling.

Wrap the call in try/except to gracefully handle test environments where
the AgentNode table hasn't been created yet.
2026-03-16 15:12:25 +00:00
Bentlybro
4f286f510f refactor: address CodeRabbit/Majdyz review feedback
- Fix ModelMetadata duplicate type collision by importing from blocks.llm
- Remove _json_to_dict helper, use dict() inline
- Add warning when Provider relation is missing (data corruption indicator)
- Optimize get_default_model_slug with next() (single sort pass)
- Optimize _build_schema_options to use list comprehension
- Move llm_registry import to top-level in rest_api.py
- Ensure max_output_tokens falls back to context_window when null

All critical and quick-win issues addressed.
2026-03-16 14:55:39 +00:00
Bentlybro
b1595d871d fix: address Sentry/CodeRabbit critical and major issues
**CRITICAL FIX - ModelMetadata instantiation:**
- Removed non-existent 'supports_vision' argument
- Added required fields: display_name, provider_name, creator_name, price_tier
- Handle nullable DB fields (Creator, priceTier, maxOutputTokens) safely
- Fallback: creator_name='Unknown' if no Creator, price_tier=1 if invalid

**MAJOR FIX - Preserve pricing unit:**
- Added 'unit' field to RegistryModelCost dataclass
- Prevents RUN vs TOKENS ambiguity in cached costs
- Convert Prisma enum to string when building cost objects

**MAJOR FIX - Deterministic default model:**
- Sort recommended models by display_name before selection
- Prevents non-deterministic results when multiple models are recommended
- Ensures consistent default across refreshes

**STARTUP IMPROVEMENT:**
- Added comment: graceful fallback OK for now (no blocks use registry yet)
- Will be stricter in PR #5 when block integration lands
- Added success log message for registry refresh

Fixes identified by Sentry (critical TypeError) and CodeRabbit review.
2026-03-16 14:55:39 +00:00
Bentlybro
29ab7f2d9c feat(platform): Add LLM registry core - DB layer + in-memory cache
Implements the registry core for dynamic LLM model management:

**DB Layer:**
- Fetch models with provider, costs, and creator relations
- Prisma query with includes for related data
- Convert DB records to typed dataclasses

**In-memory Cache:**
- Global dict for fast model lookups
- Atomic cache refresh with lock protection
- Schema options generation for UI dropdowns

**Public API:**
- get_model(slug) - lookup by slug
- get_all_models() - all models (including disabled)
- get_enabled_models() - enabled models only
- get_schema_options() - UI dropdown data
- get_default_model_slug() - recommended or first enabled
- refresh_llm_registry() - manual refresh trigger

**Integration:**
- Refresh at API startup (before block init)
- Graceful fallback if registry unavailable
- Enables blocks to consume registry data

**Models:**
- RegistryModel - full model with metadata
- RegistryModelCost - pricing configuration
- RegistryModelCreator - model creator info
- ModelMetadata - context window, capabilities

**Next PRs:**
- PR #3: Public read API (GET endpoints)
- PR #4: Admin write API (POST/PATCH/DELETE)
- PR #5: Block integration (update LLM block)
- PR #6: Redis cache (solve thundering herd)

Lines: ~230 (registry.py ~210, __init__.py ~30, model.py from draft)
Files: 4 (3 new, 1 modified)
2026-03-16 14:55:39 +00:00
Bentlybro
784936b323 revert: undo changes to graph.py
Reverting migrate_llm_models modifications per request.
Back to dev baseline for this file.
2026-03-16 14:55:39 +00:00
Bentlybro
f2ae38a1a7 fix(schema): address Majdyz review feedback
- Add FK constraints on LlmModelMigration (sourceModelSlug, targetModelSlug → LlmModel.slug)
- Remove unused @@index([credentialProvider]) on LlmModelCost
- Remove redundant @@index([isReverted]) on LlmModelMigration (covered by composite)
- Add documentation for credentialProvider field explaining its purpose
- Add reverse relation fields to LlmModel (SourceMigrations, TargetMigrations)

Fixes data integrity: typos in migration slugs now caught at DB level.
2026-03-16 14:52:19 +00:00
Bently
2ccfb4e4c1 Merge branch 'dev' into feat/llm-registry-schema 2026-03-10 17:52:01 +00:00
Bentlybro
c65e5c957a fix: isort import order 2026-03-10 16:43:34 +00:00
Bentlybro
54355a691b fix: use execute_raw_with_schema for proper multi-schema support
Per Sentry feedback: db.execute_raw ignores connection string's ?schema=
parameter and defaults to 'public' schema. This breaks in multi-schema setups.

Changes:
- Import execute_raw_with_schema from .db
- Use {schema_prefix} placeholder in query
- Call execute_raw_with_schema instead of db.execute_raw

This matches the pattern used in fix_llm_provider_credentials and other
schema-aware migrations. Works in both CI (public schema) and local
(platform schema from connection string).
2026-03-10 16:25:12 +00:00
Bentlybro
3cafa49c4c fix: remove hardcoded schema prefix from migrate_llm_models query
The raw SQL query in migrate_llm_models() hardcoded platform."AgentNode"
which fails in CI where tables are in 'public' schema (not 'platform').

This code exists in dev but only runs when LLM registry has data. With our
new schema, the migration tries to run at startup and fails in CI.

Changed: UPDATE platform."AgentNode" -> UPDATE "AgentNode"

Matches pattern of all other migrations - let connection string's default
schema handle routing.
2026-03-10 16:19:57 +00:00
Bentlybro
ded002a406 fix: remove CREATE SCHEMA to match CI environment
CI uses schema "public" as default (not "platform"), so creating
a platform schema then tables without prefix puts tables in public
but Prisma looks in platform.

Existing migrations don't create schema - they rely on connection
string's default. Remove CREATE SCHEMA IF NOT EXISTS to match.
2026-03-10 15:57:26 +00:00
Bentlybro
4fdf89c3be fix: remove schema prefix from migration SQL to match existing pattern
CI failing with 'relation "platform.AgentNode" does not exist' because
Prisma generates queries differently when tables are created with
explicit schema prefixes.

Existing AutoGPT migrations use:
  CREATE TABLE "AgentNode" (...)

Not:
  CREATE TABLE "platform"."AgentNode" (...)

The connection string's ?schema=platform handles schema selection,
so explicit prefixes aren't needed and cause compatibility issues.

Changes:
- Remove all "platform". prefixes from:
  * CREATE TYPE statements
  * CREATE TABLE statements
  * CREATE INDEX statements
  * ALTER TABLE statements
  * REFERENCES clauses in foreign keys

Now matches existing migration pattern exactly.
2026-03-10 15:38:41 +00:00
Bentlybro
d816bd739f fix: add partial unique indexes for data integrity
Per CodeRabbit feedback - fix 2 actual bugs:

1. Prevent multiple active migrations per source model
   - Add partial unique index: UNIQUE (sourceModelSlug) WHERE isReverted = false
   - Prevents ambiguous routing when resolving migrations

2. Allow both default and credential-specific costs
   - Remove @@unique([llmModelId, credentialProvider, unit])
   - Add 2 partial unique indexes:
     * UNIQUE (llmModelId, provider, unit) WHERE credentialId IS NULL (defaults)
     * UNIQUE (llmModelId, provider, credentialId, unit) WHERE credentialId IS NOT NULL (overrides)
   - Enables provider-level default costs + per-credential overrides

Schema comments document that these constraints exist in migration SQL.
2026-03-10 15:08:44 +00:00
Bentlybro
6a16376323 fix: remove multiSchema - follow existing AutoGPT pattern
Remove unnecessary multiSchema configuration that broke existing models.

AutoGPT uses connection string's ?schema=platform parameter as default,
not Prisma's multiSchema feature. Existing models (User, AgentGraph, etc.)
have no @@schema() directives and work fine.

Changes:
- Remove schemas = ["platform", "public"] from datasource
- Remove "multiSchema" from previewFeatures
- Remove all @@schema() directives from LLM models and enum

Migration SQL already creates tables in platform schema explicitly
(CREATE TABLE "platform"."LlmProvider" etc.) which is correct.

This matches the existing pattern used throughout the codebase.
2026-03-10 14:49:23 +00:00
Bentlybro
ed7b02ffb1 fix: address CodeRabbit design feedback
Per CodeRabbit review:

1. **Safety: Change capability defaults false → safer for partial seeding**
   - supportsTools: true → false
   - supportsJsonOutput: true → false
   - Prevents partially-seeded rows from being assumed capable

2. **Clarity: Rename supportsParallelTool → supportsParallelToolCalls**
   - More explicit about what the field represents

3. **Performance: Remove redundant indexes**
   - Drop @@index([llmModelId]) - covered by unique constraint
   - Drop @@index([sourceModelSlug]) - covered by composite index
   - Reduces write overhead and storage

4. **Documentation: Acknowledge customCreditCost limitation**
   - It's unit-agnostic (doesn't distinguish RUN vs TOKENS)
   - Noted as TODO for follow-up PR with proper unit-aware override

Schema + migration both updated to match.
2026-03-10 14:27:42 +00:00
Bentlybro
d064198dd1 fix: add @@schema("platform") to LlmCostUnit enum
Sentry caught this - enums also need @@schema directive with multiSchema enabled.
Without it, Prisma looks for enum in public schema but it's created in platform.
2026-03-10 14:23:24 +00:00
Bentlybro
01ad033b2b feat: add database CHECK constraints for data integrity
Per CodeRabbit feedback - enforce numeric domain rules at DB level:

Migration:
- priceTier: CHECK (priceTier BETWEEN 1 AND 3)
- creditCost: CHECK (creditCost >= 0)
- nodeCount: CHECK (nodeCount >= 0)
- customCreditCost: CHECK (customCreditCost IS NULL OR customCreditCost >= 0)

Schema comments:
- Document constraints inline for developer visibility

Prevents invalid data (negative costs, out-of-range tiers) from
entering the database, matching backend/blocks/llm.py contract.
2026-03-10 14:20:07 +00:00
Bentlybro
56bcbda054 fix: use @@schema() instead of @@map() for platform schema + create schema in migration
Critical fixes from PR review:

1. Replace @@map("platform.ModelName") with @@schema("platform")
   - Sentry correctly identified: Prisma was looking for literal table "platform.LlmProvider" with dot
   - Proper syntax: enable multiSchema feature + use @@schema directive

2. Create platform schema in migration
   - CI failed: schema "platform" does not exist
   - Add CREATE SCHEMA IF NOT EXISTS at start of migration

Schema changes:
- datasource: add schemas = ["platform", "public"]
- generator: add "multiSchema" to previewFeatures
- All 5 models: @@map() → @@schema("platform")

Migration changes:
- Add CREATE SCHEMA IF NOT EXISTS "platform" before enum creation

Fixes CI failure and Sentry-identified bug.
2026-03-10 14:15:35 +00:00
Bentlybro
d40efc6056 feat(platform): Add LLM registry database schema
Add Prisma schema and migration for dynamic LLM model registry:

Schema additions:
- LlmProvider: Registry of LLM providers (OpenAI, Anthropic, etc.)
- LlmModel: Individual models with capabilities and metadata
- LlmModelCost: Per-model pricing configuration
- LlmModelCreator: Model creators/trainers (OpenAI, Meta, etc.)
- LlmModelMigration: Track model migrations and reverts
- LlmCostUnit enum: RUN vs TOKENS pricing units

Key features:
- Model-specific capabilities (tools, JSON, reasoning, parallel calls)
- Flexible creator/provider separation (e.g., Meta model via Hugging Face)
- Migration tracking with custom pricing overrides
- Indexes for performance on common queries

Part 1 of incremental LLM registry implementation.
Refs: Draft PR #11699
2026-03-10 13:22:05 +00:00
187 changed files with 2340 additions and 14497 deletions

View File

@@ -1,40 +0,0 @@
-- =============================================================
-- View: analytics.auth_activities
-- Looker source alias: ds49 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Tracks authentication events (login, logout, SSO, password
-- reset, etc.) from Supabase's internal audit log.
-- Useful for monitoring sign-in patterns and detecting anomalies.
--
-- SOURCE TABLES
-- auth.audit_log_entries — Supabase internal auth event log
--
-- OUTPUT COLUMNS
-- created_at TIMESTAMPTZ When the auth event occurred
-- actor_id TEXT User ID who triggered the event
-- actor_via_sso TEXT Whether the action was via SSO ('true'/'false')
-- action TEXT Event type (e.g. 'login', 'logout', 'token_refreshed')
--
-- WINDOW
-- Rolling 90 days from current date
--
-- EXAMPLE QUERIES
-- -- Daily login counts
-- SELECT DATE_TRUNC('day', created_at) AS day, COUNT(*) AS logins
-- FROM analytics.auth_activities
-- WHERE action = 'login'
-- GROUP BY 1 ORDER BY 1;
--
-- -- SSO vs password login breakdown
-- SELECT actor_via_sso, COUNT(*) FROM analytics.auth_activities
-- WHERE action = 'login' GROUP BY 1;
-- =============================================================
SELECT
created_at,
payload->>'actor_id' AS actor_id,
payload->>'actor_via_sso' AS actor_via_sso,
payload->>'action' AS action
FROM auth.audit_log_entries
WHERE created_at >= NOW() - INTERVAL '90 days'

View File

@@ -1,105 +0,0 @@
-- =============================================================
-- View: analytics.graph_execution
-- Looker source alias: ds16 | Charts: 21
-- =============================================================
-- DESCRIPTION
-- One row per agent graph execution (last 90 days).
-- Unpacks the JSONB stats column into individual numeric columns
-- and normalises the executionStatus — runs that failed due to
-- insufficient credits are reclassified as 'NO_CREDITS' for
-- easier filtering. Error messages are scrubbed of IDs and URLs
-- to allow safe grouping.
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records
-- platform.AgentGraph — Agent graph metadata (for name)
-- platform.LibraryAgent — To flag possibly-AI (safe-mode) agents
--
-- OUTPUT COLUMNS
-- id TEXT Execution UUID
-- agentGraphId TEXT Agent graph UUID
-- agentGraphVersion INT Graph version number
-- executionStatus TEXT COMPLETED | FAILED | NO_CREDITS | RUNNING | QUEUED | TERMINATED
-- createdAt TIMESTAMPTZ When the execution was queued
-- updatedAt TIMESTAMPTZ Last status update time
-- userId TEXT Owner user UUID
-- agentGraphName TEXT Human-readable agent name
-- cputime DECIMAL Total CPU seconds consumed
-- walltime DECIMAL Total wall-clock seconds
-- node_count DECIMAL Number of nodes in the graph
-- nodes_cputime DECIMAL CPU time across all nodes
-- nodes_walltime DECIMAL Wall time across all nodes
-- execution_cost DECIMAL Credit cost of this execution
-- correctness_score FLOAT AI correctness score (if available)
-- possibly_ai BOOLEAN True if agent has sensitive_action_safe_mode enabled
-- groupedErrorMessage TEXT Scrubbed error string (IDs/URLs replaced with wildcards)
--
-- WINDOW
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Daily execution counts by status
-- SELECT DATE_TRUNC('day', "createdAt") AS day, "executionStatus", COUNT(*)
-- FROM analytics.graph_execution
-- GROUP BY 1, 2 ORDER BY 1;
--
-- -- Average cost per execution by agent
-- SELECT "agentGraphName", AVG("execution_cost") AS avg_cost, COUNT(*) AS runs
-- FROM analytics.graph_execution
-- WHERE "executionStatus" = 'COMPLETED'
-- GROUP BY 1 ORDER BY avg_cost DESC;
--
-- -- Top error messages
-- SELECT "groupedErrorMessage", COUNT(*) AS occurrences
-- FROM analytics.graph_execution
-- WHERE "executionStatus" = 'FAILED'
-- GROUP BY 1 ORDER BY 2 DESC LIMIT 20;
-- =============================================================
SELECT
ge."id" AS id,
ge."agentGraphId" AS agentGraphId,
ge."agentGraphVersion" AS agentGraphVersion,
CASE
WHEN jsonb_exists(ge."stats"::jsonb, 'error')
AND (
(ge."stats"::jsonb->>'error') ILIKE '%insufficient balance%'
OR (ge."stats"::jsonb->>'error') ILIKE '%you have no credits left%'
)
THEN 'NO_CREDITS'
ELSE CAST(ge."executionStatus" AS TEXT)
END AS executionStatus,
ge."createdAt" AS createdAt,
ge."updatedAt" AS updatedAt,
ge."userId" AS userId,
g."name" AS agentGraphName,
(ge."stats"::jsonb->>'cputime')::decimal AS cputime,
(ge."stats"::jsonb->>'walltime')::decimal AS walltime,
(ge."stats"::jsonb->>'node_count')::decimal AS node_count,
(ge."stats"::jsonb->>'nodes_cputime')::decimal AS nodes_cputime,
(ge."stats"::jsonb->>'nodes_walltime')::decimal AS nodes_walltime,
(ge."stats"::jsonb->>'cost')::decimal AS execution_cost,
(ge."stats"::jsonb->>'correctness_score')::float AS correctness_score,
COALESCE(la.possibly_ai, FALSE) AS possibly_ai,
REGEXP_REPLACE(
REGEXP_REPLACE(
TRIM(BOTH '"' FROM ge."stats"::jsonb->>'error'),
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
'\1\2/...', 'gi'
),
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
) AS groupedErrorMessage
FROM platform."AgentGraphExecution" ge
LEFT JOIN platform."AgentGraph" g
ON ge."agentGraphId" = g."id"
AND ge."agentGraphVersion" = g."version"
LEFT JOIN (
SELECT DISTINCT ON ("userId", "agentGraphId")
"userId", "agentGraphId",
("settings"::jsonb->>'sensitive_action_safe_mode')::boolean AS possibly_ai
FROM platform."LibraryAgent"
WHERE "isDeleted" = FALSE
AND "isArchived" = FALSE
ORDER BY "userId", "agentGraphId", "agentGraphVersion" DESC
) la ON la."userId" = ge."userId" AND la."agentGraphId" = ge."agentGraphId"
WHERE ge."createdAt" > CURRENT_DATE - INTERVAL '90 days'

View File

@@ -1,101 +0,0 @@
-- =============================================================
-- View: analytics.node_block_execution
-- Looker source alias: ds14 | Charts: 11
-- =============================================================
-- DESCRIPTION
-- One row per node (block) execution (last 90 days).
-- Unpacks stats JSONB and joins to identify which block type
-- was run. For failed nodes, joins the error output and
-- scrubs it for safe grouping.
--
-- SOURCE TABLES
-- platform.AgentNodeExecution — Node execution records
-- platform.AgentNode — Node → block mapping
-- platform.AgentBlock — Block name/ID
-- platform.AgentNodeExecutionInputOutput — Error output values
--
-- OUTPUT COLUMNS
-- id TEXT Node execution UUID
-- agentGraphExecutionId TEXT Parent graph execution UUID
-- agentNodeId TEXT Node UUID within the graph
-- executionStatus TEXT COMPLETED | FAILED | QUEUED | RUNNING | TERMINATED
-- addedTime TIMESTAMPTZ When the node was queued
-- queuedTime TIMESTAMPTZ When it entered the queue
-- startedTime TIMESTAMPTZ When execution started
-- endedTime TIMESTAMPTZ When execution finished
-- inputSize BIGINT Input payload size in bytes
-- outputSize BIGINT Output payload size in bytes
-- walltime NUMERIC Wall-clock seconds for this node
-- cputime NUMERIC CPU seconds for this node
-- llmRetryCount INT Number of LLM retries
-- llmCallCount INT Number of LLM API calls made
-- inputTokenCount BIGINT LLM input tokens consumed
-- outputTokenCount BIGINT LLM output tokens produced
-- blockName TEXT Human-readable block name (e.g. 'OpenAIBlock')
-- blockId TEXT Block UUID
-- groupedErrorMessage TEXT Scrubbed error (IDs/URLs wildcarded)
-- errorMessage TEXT Raw error output (only set when FAILED)
--
-- WINDOW
-- Rolling 90 days (addedTime > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Most-used blocks by execution count
-- SELECT "blockName", COUNT(*) AS executions,
-- COUNT(*) FILTER (WHERE "executionStatus"='FAILED') AS failures
-- FROM analytics.node_block_execution
-- GROUP BY 1 ORDER BY executions DESC LIMIT 20;
--
-- -- Average LLM token usage per block
-- SELECT "blockName",
-- AVG("inputTokenCount") AS avg_input_tokens,
-- AVG("outputTokenCount") AS avg_output_tokens
-- FROM analytics.node_block_execution
-- WHERE "llmCallCount" > 0
-- GROUP BY 1 ORDER BY avg_input_tokens DESC;
--
-- -- Top failure reasons
-- SELECT "blockName", "groupedErrorMessage", COUNT(*) AS count
-- FROM analytics.node_block_execution
-- WHERE "executionStatus" = 'FAILED'
-- GROUP BY 1, 2 ORDER BY count DESC LIMIT 20;
-- =============================================================
SELECT
ne."id" AS id,
ne."agentGraphExecutionId" AS agentGraphExecutionId,
ne."agentNodeId" AS agentNodeId,
CAST(ne."executionStatus" AS TEXT) AS executionStatus,
ne."addedTime" AS addedTime,
ne."queuedTime" AS queuedTime,
ne."startedTime" AS startedTime,
ne."endedTime" AS endedTime,
(ne."stats"::jsonb->>'input_size')::bigint AS inputSize,
(ne."stats"::jsonb->>'output_size')::bigint AS outputSize,
(ne."stats"::jsonb->>'walltime')::numeric AS walltime,
(ne."stats"::jsonb->>'cputime')::numeric AS cputime,
(ne."stats"::jsonb->>'llm_retry_count')::int AS llmRetryCount,
(ne."stats"::jsonb->>'llm_call_count')::int AS llmCallCount,
(ne."stats"::jsonb->>'input_token_count')::bigint AS inputTokenCount,
(ne."stats"::jsonb->>'output_token_count')::bigint AS outputTokenCount,
b."name" AS blockName,
b."id" AS blockId,
REGEXP_REPLACE(
REGEXP_REPLACE(
TRIM(BOTH '"' FROM eio."data"::text),
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
'\1\2/...', 'gi'
),
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
) AS groupedErrorMessage,
eio."data" AS errorMessage
FROM platform."AgentNodeExecution" ne
LEFT JOIN platform."AgentNode" nd
ON ne."agentNodeId" = nd."id"
LEFT JOIN platform."AgentBlock" b
ON nd."agentBlockId" = b."id"
LEFT JOIN platform."AgentNodeExecutionInputOutput" eio
ON eio."referencedByOutputExecId" = ne."id"
AND eio."name" = 'error'
AND ne."executionStatus" = 'FAILED'
WHERE ne."addedTime" > CURRENT_DATE - INTERVAL '90 days'

View File

@@ -1,97 +0,0 @@
-- =============================================================
-- View: analytics.retention_agent
-- Looker source alias: ds35 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention broken down per individual agent.
-- Cohort = week of a user's first use of THAT specific agent.
-- Tells you which agents keep users coming back vs. one-shot
-- use. Only includes cohorts from the last 180 days.
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records (user × agent × time)
-- platform.AgentGraph — Agent names
--
-- OUTPUT COLUMNS
-- agent_id TEXT Agent graph UUID
-- agent_label TEXT 'AgentName [first8chars]'
-- agent_label_n TEXT 'AgentName [first8chars] (n=total_users)'
-- cohort_week_start DATE Week users first ran this agent
-- cohort_label TEXT ISO week label
-- cohort_label_n TEXT ISO week label with cohort size
-- user_lifetime_week INT Weeks since first use of this agent
-- cohort_users BIGINT Users in this cohort for this agent
-- active_users BIGINT Users who ran the agent again in week k
-- retention_rate FLOAT active_users / cohort_users
-- cohort_users_w0 BIGINT cohort_users only at week 0 (safe to SUM)
-- agent_total_users BIGINT Total users across all cohorts for this agent
--
-- EXAMPLE QUERIES
-- -- Best-retained agents at week 2
-- SELECT agent_label, AVG(retention_rate) AS w2_retention
-- FROM analytics.retention_agent
-- WHERE user_lifetime_week = 2 AND cohort_users >= 10
-- GROUP BY 1 ORDER BY w2_retention DESC LIMIT 10;
--
-- -- Agents with most unique users
-- SELECT DISTINCT agent_label, agent_total_users
-- FROM analytics.retention_agent
-- ORDER BY agent_total_users DESC LIMIT 20;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
events AS (
SELECT e."userId"::text AS user_id, e."agentGraphId" AS agent_id,
e."createdAt"::timestamptz AS created_at,
DATE_TRUNC('week', e."createdAt")::date AS week_start
FROM platform."AgentGraphExecution" e
),
first_use AS (
SELECT user_id, agent_id, MIN(created_at) AS first_use_at,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1,2
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_weeks AS (SELECT DISTINCT user_id, agent_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, aw.agent_id, fu.cohort_week_start,
((aw.week_start - DATE_TRUNC('week',fu.first_use_at)::date)/7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_use fu USING (user_id, agent_id)
WHERE aw.week_start >= DATE_TRUNC('week',fu.first_use_at)::date
),
active_counts AS (
SELECT agent_id, cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2,3
),
cohort_sizes AS (
SELECT agent_id, cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_use GROUP BY 1,2
),
cohort_caps AS (
SELECT cs.agent_id, cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.agent_id, cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
),
agent_names AS (SELECT DISTINCT ON (g."id") g."id" AS agent_id, g."name" AS agent_name FROM platform."AgentGraph" g ORDER BY g."id", g."version" DESC),
agent_total_users AS (SELECT agent_id, SUM(cohort_users) AS agent_total_users FROM cohort_sizes GROUP BY 1)
SELECT
g.agent_id,
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||']' AS agent_label,
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||'] (n='||COALESCE(atu.agent_total_users,0)||')' AS agent_label_n,
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(ac.active_users,0) AS active_users,
COALESCE(ac.active_users,0)::float / NULLIF(g.cohort_users,0) AS retention_rate,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0,
COALESCE(atu.agent_total_users,0) AS agent_total_users
FROM grid g
LEFT JOIN active_counts ac ON ac.agent_id=g.agent_id AND ac.cohort_week_start=g.cohort_week_start AND ac.user_lifetime_week=g.user_lifetime_week
LEFT JOIN agent_names an ON an.agent_id=g.agent_id
LEFT JOIN agent_total_users atu ON atu.agent_id=g.agent_id
ORDER BY agent_label, g.cohort_week_start, g.user_lifetime_week;

View File

@@ -1,81 +0,0 @@
-- =============================================================
-- View: analytics.retention_execution_daily
-- Looker source alias: ds111 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Daily cohort retention based on agent executions.
-- Cohort anchor = day of user's FIRST ever execution.
-- Only includes cohorts from the last 90 days, up to day 30.
-- Great for early engagement analysis (did users run another
-- agent the next day?).
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records
--
-- OUTPUT COLUMNS
-- Same pattern as retention_login_daily.
-- cohort_day_start = day of first execution (not first login)
--
-- EXAMPLE QUERIES
-- -- Day-3 execution retention
-- SELECT cohort_label, retention_rate_bounded AS d3_retention
-- FROM analytics.retention_execution_daily
-- WHERE user_lifetime_day = 3 ORDER BY cohort_day_start;
-- =============================================================
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days') AS cohort_start),
events AS (
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
DATE_TRUNC('day', e."createdAt")::date AS day_start
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
),
first_exec AS (
SELECT user_id, MIN(created_at) AS first_exec_at,
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
FROM events GROUP BY 1
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
user_day_age AS (
SELECT ad.user_id, fe.cohort_day_start,
(ad.day_start - DATE_TRUNC('day',fe.first_exec_at)::date)::int AS user_lifetime_day
FROM activity_days ad JOIN first_exec fe USING (user_id)
WHERE ad.day_start >= DATE_TRUNC('day',fe.first_exec_at)::date
),
bounded_counts AS (
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_day_start, cs.cohort_users,
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
)
SELECT
g.cohort_day_start,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_day, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
ORDER BY g.cohort_day_start, g.user_lifetime_day;

View File

@@ -1,81 +0,0 @@
-- =============================================================
-- View: analytics.retention_execution_weekly
-- Looker source alias: ds92 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention based on agent executions.
-- Cohort anchor = week of user's FIRST ever agent execution
-- (not first login). Only includes cohorts from the last 180 days.
-- Useful when you care about product engagement, not just visits.
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records
--
-- OUTPUT COLUMNS
-- Same pattern as retention_login_weekly.
-- cohort_week_start = week of first execution (not first login)
--
-- EXAMPLE QUERIES
-- -- Week-2 execution retention
-- SELECT cohort_label, retention_rate_bounded
-- FROM analytics.retention_execution_weekly
-- WHERE user_lifetime_week = 2 ORDER BY cohort_week_start;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
events AS (
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
DATE_TRUNC('week', e."createdAt")::date AS week_start
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
),
first_exec AS (
SELECT user_id, MIN(created_at) AS first_exec_at,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, fe.cohort_week_start,
((aw.week_start - DATE_TRUNC('week',fe.first_exec_at)::date)/7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_exec fe USING (user_id)
WHERE aw.week_start >= DATE_TRUNC('week',fe.first_exec_at)::date
),
bounded_counts AS (
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week;

View File

@@ -1,94 +0,0 @@
-- =============================================================
-- View: analytics.retention_login_daily
-- Looker source alias: ds112 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Daily cohort retention based on login sessions.
-- Same logic as retention_login_weekly but at day granularity,
-- showing up to day 30 for cohorts from the last 90 days.
-- Useful for analysing early activation (days 1-7) in detail.
--
-- SOURCE TABLES
-- auth.sessions — Login session records
--
-- OUTPUT COLUMNS (same pattern as retention_login_weekly)
-- cohort_day_start DATE First day the cohort logged in
-- cohort_label TEXT Date string (e.g. '2025-03-01')
-- cohort_label_n TEXT Date + cohort size (e.g. '2025-03-01 (n=12)')
-- user_lifetime_day INT Days since first login (0 = signup day)
-- cohort_users BIGINT Total users in cohort
-- active_users_bounded BIGINT Users active on exactly day k
-- retained_users_unbounded BIGINT Users active any time on/after day k
-- retention_rate_bounded FLOAT bounded / cohort_users
-- retention_rate_unbounded FLOAT unbounded / cohort_users
-- cohort_users_d0 BIGINT cohort_users only at day 0, else 0 (safe to SUM)
--
-- EXAMPLE QUERIES
-- -- Day-1 retention rate (came back next day)
-- SELECT cohort_label, retention_rate_bounded AS d1_retention
-- FROM analytics.retention_login_daily
-- WHERE user_lifetime_day = 1 ORDER BY cohort_day_start;
--
-- -- Average retention curve across all cohorts
-- SELECT user_lifetime_day,
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_d0), 0) AS avg_retention
-- FROM analytics.retention_login_daily
-- GROUP BY 1 ORDER BY 1;
-- =============================================================
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days')::date AS cohort_start),
events AS (
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
DATE_TRUNC('day', s.created_at)::date AS day_start
FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login AS (
SELECT user_id, MIN(created_at) AS first_login_time,
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
FROM events GROUP BY 1
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
user_day_age AS (
SELECT ad.user_id, fl.cohort_day_start,
(ad.day_start - DATE_TRUNC('day', fl.first_login_time)::date)::int AS user_lifetime_day
FROM activity_days ad JOIN first_login fl USING (user_id)
WHERE ad.day_start >= DATE_TRUNC('day', fl.first_login_time)::date
),
bounded_counts AS (
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_day_start, cs.cohort_users,
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
)
SELECT
g.cohort_day_start,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_day, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
ORDER BY g.cohort_day_start, g.user_lifetime_day;

View File

@@ -1,96 +0,0 @@
-- =============================================================
-- View: analytics.retention_login_onboarded_weekly
-- Looker source alias: ds101 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention from login sessions, restricted to
-- users who "onboarded" — defined as running at least one
-- agent within 365 days of their first login.
-- Filters out users who signed up but never activated,
-- giving a cleaner view of engaged-user retention.
--
-- SOURCE TABLES
-- auth.sessions — Login session records
-- platform.AgentGraphExecution — Used to identify onboarders
--
-- OUTPUT COLUMNS
-- Same as retention_login_weekly (cohort_week_start, user_lifetime_week,
-- retention_rate_bounded, retention_rate_unbounded, etc.)
-- Only difference: cohort is filtered to onboarded users only.
--
-- EXAMPLE QUERIES
-- -- Compare week-4 retention: all users vs onboarded only
-- SELECT 'all_users' AS segment, AVG(retention_rate_bounded) AS w4_retention
-- FROM analytics.retention_login_weekly WHERE user_lifetime_week = 4
-- UNION ALL
-- SELECT 'onboarded', AVG(retention_rate_bounded)
-- FROM analytics.retention_login_onboarded_weekly WHERE user_lifetime_week = 4;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks, 365::int AS onboarding_window_days),
events AS (
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
DATE_TRUNC('week', s.created_at)::date AS week_start
FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login_all AS (
SELECT user_id, MIN(created_at) AS first_login_time,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1
),
onboarders AS (
SELECT fl.user_id FROM first_login_all fl
WHERE EXISTS (
SELECT 1 FROM platform."AgentGraphExecution" e
WHERE e."userId"::text = fl.user_id
AND e."createdAt" >= fl.first_login_time
AND e."createdAt" < fl.first_login_time
+ make_interval(days => (SELECT onboarding_window_days FROM params))
)
),
first_login AS (SELECT * FROM first_login_all WHERE user_id IN (SELECT user_id FROM onboarders)),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, fl.cohort_week_start,
((aw.week_start - DATE_TRUNC('week',fl.first_login_time)::date)/7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_login fl USING (user_id)
WHERE aw.week_start >= DATE_TRUNC('week',fl.first_login_time)::date
),
bounded_counts AS (
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week;

View File

@@ -1,103 +0,0 @@
-- =============================================================
-- View: analytics.retention_login_weekly
-- Looker source alias: ds83 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention based on login sessions.
-- Users are grouped by the ISO week of their first ever login.
-- For each cohort × lifetime-week combination, outputs both:
-- - bounded rate: % active in exactly that week
-- - unbounded rate: % who were ever active on or after that week
-- Weeks are capped to the cohort's actual age (no future data points).
--
-- SOURCE TABLES
-- auth.sessions — Login session records
--
-- HOW TO READ THE OUTPUT
-- cohort_week_start The Monday of the week users first logged in
-- user_lifetime_week 0 = signup week, 1 = one week later, etc.
-- retention_rate_bounded = active_users_bounded / cohort_users
-- retention_rate_unbounded = retained_users_unbounded / cohort_users
--
-- OUTPUT COLUMNS
-- cohort_week_start DATE First day of the cohort's signup week
-- cohort_label TEXT ISO week label (e.g. '2025-W01')
-- cohort_label_n TEXT ISO week label with cohort size (e.g. '2025-W01 (n=42)')
-- user_lifetime_week INT Weeks since first login (0 = signup week)
-- cohort_users BIGINT Total users in this cohort (denominator)
-- active_users_bounded BIGINT Users active in exactly week k
-- retained_users_unbounded BIGINT Users active any time on/after week k
-- retention_rate_bounded FLOAT bounded active / cohort_users
-- retention_rate_unbounded FLOAT unbounded retained / cohort_users
-- cohort_users_w0 BIGINT cohort_users only at week 0, else 0 (safe to SUM in pivot tables)
--
-- EXAMPLE QUERIES
-- -- Week-1 retention rate per cohort
-- SELECT cohort_label, retention_rate_bounded AS w1_retention
-- FROM analytics.retention_login_weekly
-- WHERE user_lifetime_week = 1
-- ORDER BY cohort_week_start;
--
-- -- Overall average retention curve (all cohorts combined)
-- SELECT user_lifetime_week,
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_w0), 0) AS avg_retention
-- FROM analytics.retention_login_weekly
-- GROUP BY 1 ORDER BY 1;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks),
events AS (
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
DATE_TRUNC('week', s.created_at)::date AS week_start
FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login AS (
SELECT user_id, MIN(created_at) AS first_login_time,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1
),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, fl.cohort_week_start,
((aw.week_start - DATE_TRUNC('week', fl.first_login_time)::date) / 7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_login fl USING (user_id)
WHERE aw.week_start >= DATE_TRUNC('week', fl.first_login_time)::date
),
bounded_counts AS (
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date - cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week

View File

@@ -1,71 +0,0 @@
-- =============================================================
-- View: analytics.user_block_spending
-- Looker source alias: ds6 | Charts: 5
-- =============================================================
-- DESCRIPTION
-- One row per credit transaction (last 90 days).
-- Shows how users spend credits broken down by block type,
-- LLM provider and model. Joins node execution stats for
-- token-level detail.
--
-- SOURCE TABLES
-- platform.CreditTransaction — Credit debit/credit records
-- platform.AgentNodeExecution — Node execution stats (for token counts)
--
-- OUTPUT COLUMNS
-- transactionKey TEXT Unique transaction identifier
-- userId TEXT User who was charged
-- amount DECIMAL Credit amount (positive = credit, negative = debit)
-- negativeAmount DECIMAL amount * -1 (convenience for spend charts)
-- transactionType TEXT Transaction type (e.g. 'USAGE', 'REFUND', 'TOP_UP')
-- transactionTime TIMESTAMPTZ When the transaction was recorded
-- blockId TEXT Block UUID that triggered the spend
-- blockName TEXT Human-readable block name
-- llm_provider TEXT LLM provider (e.g. 'openai', 'anthropic')
-- llm_model TEXT Model name (e.g. 'gpt-4o', 'claude-3-5-sonnet')
-- node_exec_id TEXT Linked node execution UUID
-- llm_call_count INT LLM API calls made in that execution
-- llm_retry_count INT LLM retries in that execution
-- llm_input_token_count INT Input tokens consumed
-- llm_output_token_count INT Output tokens produced
--
-- WINDOW
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Total spend per user (last 90 days)
-- SELECT "userId", SUM("negativeAmount") AS total_spent
-- FROM analytics.user_block_spending
-- WHERE "transactionType" = 'USAGE'
-- GROUP BY 1 ORDER BY total_spent DESC;
--
-- -- Spend by LLM provider + model
-- SELECT "llm_provider", "llm_model",
-- SUM("negativeAmount") AS total_cost,
-- SUM("llm_input_token_count") AS input_tokens,
-- SUM("llm_output_token_count") AS output_tokens
-- FROM analytics.user_block_spending
-- WHERE "llm_provider" IS NOT NULL
-- GROUP BY 1, 2 ORDER BY total_cost DESC;
-- =============================================================
SELECT
c."transactionKey" AS transactionKey,
c."userId" AS userId,
c."amount" AS amount,
c."amount" * -1 AS negativeAmount,
c."type" AS transactionType,
c."createdAt" AS transactionTime,
c.metadata->>'block_id' AS blockId,
c.metadata->>'block' AS blockName,
c.metadata->'input'->'credentials'->>'provider' AS llm_provider,
c.metadata->'input'->>'model' AS llm_model,
c.metadata->>'node_exec_id' AS node_exec_id,
(ne."stats"->>'llm_call_count')::int AS llm_call_count,
(ne."stats"->>'llm_retry_count')::int AS llm_retry_count,
(ne."stats"->>'input_token_count')::int AS llm_input_token_count,
(ne."stats"->>'output_token_count')::int AS llm_output_token_count
FROM platform."CreditTransaction" c
LEFT JOIN platform."AgentNodeExecution" ne
ON (c.metadata->>'node_exec_id') = ne."id"::text
WHERE c."createdAt" > CURRENT_DATE - INTERVAL '90 days'

View File

@@ -1,45 +0,0 @@
-- =============================================================
-- View: analytics.user_onboarding
-- Looker source alias: ds68 | Charts: 3
-- =============================================================
-- DESCRIPTION
-- One row per user onboarding record. Contains the user's
-- stated usage reason, selected integrations, completed
-- onboarding steps and optional first agent selection.
-- Full history (no date filter) since onboarding happens
-- once per user.
--
-- SOURCE TABLES
-- platform.UserOnboarding — Onboarding state per user
--
-- OUTPUT COLUMNS
-- id TEXT Onboarding record UUID
-- createdAt TIMESTAMPTZ When onboarding started
-- updatedAt TIMESTAMPTZ Last update to onboarding state
-- usageReason TEXT Why user signed up (e.g. 'work', 'personal')
-- integrations TEXT[] Array of integration names the user selected
-- userId TEXT User UUID
-- completedSteps TEXT[] Array of onboarding step enums completed
-- selectedStoreListingVersionId TEXT First marketplace agent the user chose (if any)
--
-- EXAMPLE QUERIES
-- -- Usage reason breakdown
-- SELECT "usageReason", COUNT(*) FROM analytics.user_onboarding GROUP BY 1;
--
-- -- Completion rate per step
-- SELECT step, COUNT(*) AS users_completed
-- FROM analytics.user_onboarding
-- CROSS JOIN LATERAL UNNEST("completedSteps") AS step
-- GROUP BY 1 ORDER BY users_completed DESC;
-- =============================================================
SELECT
id,
"createdAt",
"updatedAt",
"usageReason",
integrations,
"userId",
"completedSteps",
"selectedStoreListingVersionId"
FROM platform."UserOnboarding"

View File

@@ -1,100 +0,0 @@
-- =============================================================
-- View: analytics.user_onboarding_funnel
-- Looker source alias: ds74 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Pre-aggregated onboarding funnel showing how many users
-- completed each step and the drop-off percentage from the
-- previous step. One row per onboarding step (all 22 steps
-- always present, even with 0 completions — prevents sparse
-- gaps from making LAG compare the wrong predecessors).
--
-- SOURCE TABLES
-- platform.UserOnboarding — Onboarding records with completedSteps array
--
-- OUTPUT COLUMNS
-- step TEXT Onboarding step enum name (e.g. 'WELCOME', 'CONGRATS')
-- step_order INT Numeric position in the funnel (1=first, 22=last)
-- users_completed BIGINT Distinct users who completed this step
-- pct_from_prev NUMERIC % of users from the previous step who reached this one
--
-- STEP ORDER
-- 1 WELCOME 9 MARKETPLACE_VISIT 17 SCHEDULE_AGENT
-- 2 USAGE_REASON 10 MARKETPLACE_ADD_AGENT 18 RUN_AGENTS
-- 3 INTEGRATIONS 11 MARKETPLACE_RUN_AGENT 19 RUN_3_DAYS
-- 4 AGENT_CHOICE 12 BUILDER_OPEN 20 TRIGGER_WEBHOOK
-- 5 AGENT_NEW_RUN 13 BUILDER_SAVE_AGENT 21 RUN_14_DAYS
-- 6 AGENT_INPUT 14 BUILDER_RUN_AGENT 22 RUN_AGENTS_100
-- 7 CONGRATS 15 VISIT_COPILOT
-- 8 GET_RESULTS 16 RE_RUN_AGENT
--
-- WINDOW
-- Users who started onboarding in the last 90 days
--
-- EXAMPLE QUERIES
-- -- Full funnel
-- SELECT * FROM analytics.user_onboarding_funnel ORDER BY step_order;
--
-- -- Biggest drop-off point
-- SELECT step, pct_from_prev FROM analytics.user_onboarding_funnel
-- ORDER BY pct_from_prev ASC LIMIT 3;
-- =============================================================
WITH all_steps AS (
-- Complete ordered grid of all 22 steps so zero-completion steps
-- are always present, keeping LAG comparisons correct.
SELECT step_name, step_order
FROM (VALUES
('WELCOME', 1),
('USAGE_REASON', 2),
('INTEGRATIONS', 3),
('AGENT_CHOICE', 4),
('AGENT_NEW_RUN', 5),
('AGENT_INPUT', 6),
('CONGRATS', 7),
('GET_RESULTS', 8),
('MARKETPLACE_VISIT', 9),
('MARKETPLACE_ADD_AGENT', 10),
('MARKETPLACE_RUN_AGENT', 11),
('BUILDER_OPEN', 12),
('BUILDER_SAVE_AGENT', 13),
('BUILDER_RUN_AGENT', 14),
('VISIT_COPILOT', 15),
('RE_RUN_AGENT', 16),
('SCHEDULE_AGENT', 17),
('RUN_AGENTS', 18),
('RUN_3_DAYS', 19),
('TRIGGER_WEBHOOK', 20),
('RUN_14_DAYS', 21),
('RUN_AGENTS_100', 22)
) AS t(step_name, step_order)
),
raw AS (
SELECT
u."userId",
step_txt::text AS step
FROM platform."UserOnboarding" u
CROSS JOIN LATERAL UNNEST(u."completedSteps") AS step_txt
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
),
step_counts AS (
SELECT step, COUNT(DISTINCT "userId") AS users_completed
FROM raw GROUP BY step
),
funnel AS (
SELECT
a.step_name AS step,
a.step_order,
COALESCE(sc.users_completed, 0) AS users_completed,
ROUND(
100.0 * COALESCE(sc.users_completed, 0)
/ NULLIF(
LAG(COALESCE(sc.users_completed, 0)) OVER (ORDER BY a.step_order),
0
),
2
) AS pct_from_prev
FROM all_steps a
LEFT JOIN step_counts sc ON sc.step = a.step_name
)
SELECT * FROM funnel ORDER BY step_order

View File

@@ -1,41 +0,0 @@
-- =============================================================
-- View: analytics.user_onboarding_integration
-- Looker source alias: ds75 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Pre-aggregated count of users who selected each integration
-- during onboarding. One row per integration type, sorted
-- by popularity.
--
-- SOURCE TABLES
-- platform.UserOnboarding — integrations array column
--
-- OUTPUT COLUMNS
-- integration TEXT Integration name (e.g. 'github', 'slack', 'notion')
-- users_with_integration BIGINT Distinct users who selected this integration
--
-- WINDOW
-- Users who started onboarding in the last 90 days
--
-- EXAMPLE QUERIES
-- -- Full integration popularity ranking
-- SELECT * FROM analytics.user_onboarding_integration;
--
-- -- Top 5 integrations
-- SELECT * FROM analytics.user_onboarding_integration LIMIT 5;
-- =============================================================
WITH exploded AS (
SELECT
u."userId" AS user_id,
UNNEST(u."integrations") AS integration
FROM platform."UserOnboarding" u
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
)
SELECT
integration,
COUNT(DISTINCT user_id) AS users_with_integration
FROM exploded
WHERE integration IS NOT NULL AND integration <> ''
GROUP BY integration
ORDER BY users_with_integration DESC

View File

@@ -1,145 +0,0 @@
-- =============================================================
-- View: analytics.users_activities
-- Looker source alias: ds56 | Charts: 5
-- =============================================================
-- DESCRIPTION
-- One row per user with lifetime activity summary.
-- Joins login sessions with agent graphs, executions and
-- node-level runs to give a full picture of how engaged
-- each user is. Includes a convenience flag for 7-day
-- activation (did the user return at least 7 days after
-- their first login?).
--
-- SOURCE TABLES
-- auth.sessions — Login/session records
-- platform.AgentGraph — Graphs (agents) built by the user
-- platform.AgentGraphExecution — Agent run history
-- platform.AgentNodeExecution — Individual block execution history
--
-- PERFORMANCE NOTE
-- Each CTE aggregates its own table independently by userId.
-- This avoids the fan-out that occurs when driving every join
-- from user_logins across the two largest tables
-- (AgentGraphExecution and AgentNodeExecution).
--
-- OUTPUT COLUMNS
-- user_id TEXT Supabase user UUID
-- first_login_time TIMESTAMPTZ First ever session created_at
-- last_login_time TIMESTAMPTZ Most recent session created_at
-- last_visit_time TIMESTAMPTZ Max of last refresh or login
-- last_agent_save_time TIMESTAMPTZ Last time user saved an agent graph
-- agent_count BIGINT Number of distinct active graphs built (0 if none)
-- first_agent_run_time TIMESTAMPTZ First ever graph execution
-- last_agent_run_time TIMESTAMPTZ Most recent graph execution
-- unique_agent_runs BIGINT Distinct agent graphs ever run (0 if none)
-- agent_runs BIGINT Total graph execution count (0 if none)
-- node_execution_count BIGINT Total node executions across all runs
-- node_execution_failed BIGINT Node executions with FAILED status
-- node_execution_completed BIGINT Node executions with COMPLETED status
-- node_execution_terminated BIGINT Node executions with TERMINATED status
-- node_execution_queued BIGINT Node executions with QUEUED status
-- node_execution_running BIGINT Node executions with RUNNING status
-- is_active_after_7d INT 1=returned after day 7, 0=did not, NULL=too early to tell
-- node_execution_incomplete BIGINT Node executions with INCOMPLETE status
-- node_execution_review BIGINT Node executions with REVIEW status
--
-- EXAMPLE QUERIES
-- -- Users who ran at least one agent and returned after 7 days
-- SELECT COUNT(*) FROM analytics.users_activities
-- WHERE agent_runs > 0 AND is_active_after_7d = 1;
--
-- -- Top 10 most active users by agent runs
-- SELECT user_id, agent_runs, node_execution_count
-- FROM analytics.users_activities
-- ORDER BY agent_runs DESC LIMIT 10;
--
-- -- 7-day activation rate
-- SELECT
-- SUM(CASE WHEN is_active_after_7d = 1 THEN 1 ELSE 0 END)::float
-- / NULLIF(COUNT(CASE WHEN is_active_after_7d IS NOT NULL THEN 1 END), 0)
-- AS activation_rate
-- FROM analytics.users_activities;
-- =============================================================
WITH user_logins AS (
SELECT
user_id::text AS user_id,
MIN(created_at) AS first_login_time,
MAX(created_at) AS last_login_time,
GREATEST(
MAX(refreshed_at)::timestamptz,
MAX(created_at)::timestamptz
) AS last_visit_time
FROM auth.sessions
GROUP BY user_id
),
user_agents AS (
-- Aggregate AgentGraph directly by userId (no fan-out from user_logins)
SELECT
"userId"::text AS user_id,
MAX("updatedAt") AS last_agent_save_time,
COUNT(DISTINCT "id") AS agent_count
FROM platform."AgentGraph"
WHERE "isActive"
GROUP BY "userId"
),
user_graph_runs AS (
-- Aggregate AgentGraphExecution directly by userId
SELECT
"userId"::text AS user_id,
MIN("createdAt") AS first_agent_run_time,
MAX("createdAt") AS last_agent_run_time,
COUNT(DISTINCT "agentGraphId") AS unique_agent_runs,
COUNT("id") AS agent_runs
FROM platform."AgentGraphExecution"
GROUP BY "userId"
),
user_node_runs AS (
-- Aggregate AgentNodeExecution directly; resolve userId via a
-- single join to AgentGraphExecution instead of fanning out from
-- user_logins through both large tables.
SELECT
g."userId"::text AS user_id,
COUNT(*) AS node_execution_count,
COUNT(*) FILTER (WHERE n."executionStatus" = 'FAILED') AS node_execution_failed,
COUNT(*) FILTER (WHERE n."executionStatus" = 'COMPLETED') AS node_execution_completed,
COUNT(*) FILTER (WHERE n."executionStatus" = 'TERMINATED') AS node_execution_terminated,
COUNT(*) FILTER (WHERE n."executionStatus" = 'QUEUED') AS node_execution_queued,
COUNT(*) FILTER (WHERE n."executionStatus" = 'RUNNING') AS node_execution_running,
COUNT(*) FILTER (WHERE n."executionStatus" = 'INCOMPLETE') AS node_execution_incomplete,
COUNT(*) FILTER (WHERE n."executionStatus" = 'REVIEW') AS node_execution_review
FROM platform."AgentNodeExecution" n
JOIN platform."AgentGraphExecution" g
ON g."id" = n."agentGraphExecutionId"
GROUP BY g."userId"
)
SELECT
ul.user_id,
ul.first_login_time,
ul.last_login_time,
ul.last_visit_time,
ua.last_agent_save_time,
COALESCE(ua.agent_count, 0) AS agent_count,
gr.first_agent_run_time,
gr.last_agent_run_time,
COALESCE(gr.unique_agent_runs, 0) AS unique_agent_runs,
COALESCE(gr.agent_runs, 0) AS agent_runs,
COALESCE(nr.node_execution_count, 0) AS node_execution_count,
COALESCE(nr.node_execution_failed, 0) AS node_execution_failed,
COALESCE(nr.node_execution_completed, 0) AS node_execution_completed,
COALESCE(nr.node_execution_terminated, 0) AS node_execution_terminated,
COALESCE(nr.node_execution_queued, 0) AS node_execution_queued,
COALESCE(nr.node_execution_running, 0) AS node_execution_running,
CASE
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
AND ul.last_visit_time >= ul.first_login_time + INTERVAL '7 days' THEN 1
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
AND ul.last_visit_time < ul.first_login_time + INTERVAL '7 days' THEN 0
ELSE NULL
END AS is_active_after_7d,
COALESCE(nr.node_execution_incomplete, 0) AS node_execution_incomplete,
COALESCE(nr.node_execution_review, 0) AS node_execution_review
FROM user_logins ul
LEFT JOIN user_agents ua ON ul.user_id = ua.user_id
LEFT JOIN user_graph_runs gr ON ul.user_id = gr.user_id
LEFT JOIN user_node_runs nr ON ul.user_id = nr.user_id

View File

@@ -37,10 +37,6 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
## ===== SIGNUP / INVITE GATE ===== ##
# Set to true to require an invite before users can sign up
ENABLE_INVITE_GATE=false
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
# Platform URLs (set these for webhooks and OAuth to work)
PLATFORM_BASE_URL=http://localhost:8000

View File

@@ -1,17 +1,8 @@
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, Optional
import prisma.enums
from pydantic import BaseModel, EmailStr
from pydantic import BaseModel
from backend.data.model import UserTransaction
from backend.util.models import Pagination
if TYPE_CHECKING:
from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
class UserHistoryResponse(BaseModel):
"""Response model for listings with version history"""
@@ -23,70 +14,3 @@ class UserHistoryResponse(BaseModel):
class AddUserCreditsResponse(BaseModel):
new_balance: int
transaction_key: str
class CreateInvitedUserRequest(BaseModel):
email: EmailStr
name: Optional[str] = None
class InvitedUserResponse(BaseModel):
id: str
email: str
status: prisma.enums.InvitedUserStatus
auth_user_id: Optional[str] = None
name: Optional[str] = None
tally_understanding: Optional[dict[str, Any]] = None
tally_status: prisma.enums.TallyComputationStatus
tally_computed_at: Optional[datetime] = None
tally_error: Optional[str] = None
created_at: datetime
updated_at: datetime
@classmethod
def from_record(cls, record: InvitedUserRecord) -> InvitedUserResponse:
return cls.model_validate(record.model_dump())
class InvitedUsersResponse(BaseModel):
invited_users: list[InvitedUserResponse]
pagination: Pagination
class BulkInvitedUserRowResponse(BaseModel):
row_number: int
email: Optional[str] = None
name: Optional[str] = None
status: Literal["CREATED", "SKIPPED", "ERROR"]
message: str
invited_user: Optional[InvitedUserResponse] = None
class BulkInvitedUsersResponse(BaseModel):
created_count: int
skipped_count: int
error_count: int
results: list[BulkInvitedUserRowResponse]
@classmethod
def from_result(cls, result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
return cls(
created_count=result.created_count,
skipped_count=result.skipped_count,
error_count=result.error_count,
results=[
BulkInvitedUserRowResponse(
row_number=row.row_number,
email=row.email,
name=row.name,
status=row.status,
message=row.message,
invited_user=(
InvitedUserResponse.from_record(row.invited_user)
if row.invited_user is not None
else None
),
)
for row in result.results
],
)

View File

@@ -1,137 +0,0 @@
import logging
import math
from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, File, Query, Security, UploadFile
from backend.data.invited_user import (
bulk_create_invited_users_from_file,
create_invited_user,
list_invited_users,
retry_invited_user_tally,
revoke_invited_user,
)
from backend.data.tally import mask_email
from backend.util.models import Pagination
from .model import (
BulkInvitedUsersResponse,
CreateInvitedUserRequest,
InvitedUserResponse,
InvitedUsersResponse,
)
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/admin",
tags=["users", "admin"],
dependencies=[Security(requires_admin_user)],
)
@router.get(
"/invited-users",
response_model=InvitedUsersResponse,
summary="List Invited Users",
)
async def get_invited_users(
admin_user_id: str = Security(get_user_id),
page: int = Query(1, ge=1),
page_size: int = Query(50, ge=1, le=200),
) -> InvitedUsersResponse:
logger.info("Admin user %s requested invited users", admin_user_id)
invited_users, total = await list_invited_users(page=page, page_size=page_size)
return InvitedUsersResponse(
invited_users=[InvitedUserResponse.from_record(iu) for iu in invited_users],
pagination=Pagination(
total_items=total,
total_pages=max(1, math.ceil(total / page_size)),
current_page=page,
page_size=page_size,
),
)
@router.post(
"/invited-users",
response_model=InvitedUserResponse,
summary="Create Invited User",
)
async def create_invited_user_route(
request: CreateInvitedUserRequest,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info(
"Admin user %s creating invited user for %s",
admin_user_id,
mask_email(request.email),
)
invited_user = await create_invited_user(request.email, request.name)
logger.info(
"Admin user %s created invited user %s",
admin_user_id,
invited_user.id,
)
return InvitedUserResponse.from_record(invited_user)
@router.post(
"/invited-users/bulk",
response_model=BulkInvitedUsersResponse,
summary="Bulk Create Invited Users",
operation_id="postV2BulkCreateInvitedUsers",
)
async def bulk_create_invited_users_route(
file: UploadFile = File(...),
admin_user_id: str = Security(get_user_id),
) -> BulkInvitedUsersResponse:
logger.info(
"Admin user %s bulk invited users from %s",
admin_user_id,
file.filename or "<unnamed>",
)
content = await file.read()
result = await bulk_create_invited_users_from_file(file.filename, content)
return BulkInvitedUsersResponse.from_result(result)
@router.post(
"/invited-users/{invited_user_id}/revoke",
response_model=InvitedUserResponse,
summary="Revoke Invited User",
)
async def revoke_invited_user_route(
invited_user_id: str,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info(
"Admin user %s revoking invited user %s", admin_user_id, invited_user_id
)
invited_user = await revoke_invited_user(invited_user_id)
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
return InvitedUserResponse.from_record(invited_user)
@router.post(
"/invited-users/{invited_user_id}/retry-tally",
response_model=InvitedUserResponse,
summary="Retry Invited User Tally",
)
async def retry_invited_user_tally_route(
invited_user_id: str,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info(
"Admin user %s retrying Tally seed for invited user %s",
admin_user_id,
invited_user_id,
)
invited_user = await retry_invited_user_tally(invited_user_id)
logger.info(
"Admin user %s retried Tally seed for invited user %s",
admin_user_id,
invited_user_id,
)
return InvitedUserResponse.from_record(invited_user)

View File

@@ -1,168 +0,0 @@
from datetime import datetime, timezone
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import prisma.enums
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from backend.data.invited_user import (
BulkInvitedUserRowResult,
BulkInvitedUsersResult,
InvitedUserRecord,
)
from .user_admin_routes import router as user_admin_router
app = fastapi.FastAPI()
app.include_router(user_admin_router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def _sample_invited_user() -> InvitedUserRecord:
now = datetime.now(timezone.utc)
return InvitedUserRecord(
id="invite-1",
email="invited@example.com",
status=prisma.enums.InvitedUserStatus.INVITED,
auth_user_id=None,
name="Invited User",
tally_understanding=None,
tally_status=prisma.enums.TallyComputationStatus.PENDING,
tally_computed_at=None,
tally_error=None,
created_at=now,
updated_at=now,
)
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
return BulkInvitedUsersResult(
created_count=1,
skipped_count=1,
error_count=0,
results=[
BulkInvitedUserRowResult(
row_number=1,
email="invited@example.com",
name=None,
status="CREATED",
message="Invite created",
invited_user=_sample_invited_user(),
),
BulkInvitedUserRowResult(
row_number=2,
email="duplicate@example.com",
name=None,
status="SKIPPED",
message="An invited user with this email already exists",
invited_user=None,
),
],
)
def test_get_invited_users(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.list_invited_users",
AsyncMock(return_value=([_sample_invited_user()], 1)),
)
response = client.get("/admin/invited-users")
assert response.status_code == 200
data = response.json()
assert len(data["invited_users"]) == 1
assert data["invited_users"][0]["email"] == "invited@example.com"
assert data["invited_users"][0]["status"] == "INVITED"
assert data["pagination"]["total_items"] == 1
assert data["pagination"]["current_page"] == 1
assert data["pagination"]["page_size"] == 50
def test_create_invited_user(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.create_invited_user",
AsyncMock(return_value=_sample_invited_user()),
)
response = client.post(
"/admin/invited-users",
json={"email": "invited@example.com", "name": "Invited User"},
)
assert response.status_code == 200
data = response.json()
assert data["email"] == "invited@example.com"
assert data["name"] == "Invited User"
def test_bulk_create_invited_users(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
AsyncMock(return_value=_sample_bulk_invited_users_result()),
)
response = client.post(
"/admin/invited-users/bulk",
files={
"file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
},
)
assert response.status_code == 200
data = response.json()
assert data["created_count"] == 1
assert data["skipped_count"] == 1
assert data["results"][0]["status"] == "CREATED"
assert data["results"][1]["status"] == "SKIPPED"
def test_revoke_invited_user(
mocker: pytest_mock.MockerFixture,
) -> None:
revoked = _sample_invited_user().model_copy(
update={"status": prisma.enums.InvitedUserStatus.REVOKED}
)
mocker.patch(
"backend.api.features.admin.user_admin_routes.revoke_invited_user",
AsyncMock(return_value=revoked),
)
response = client.post("/admin/invited-users/invite-1/revoke")
assert response.status_code == 200
assert response.json()["status"] == "REVOKED"
def test_retry_invited_user_tally(
mocker: pytest_mock.MockerFixture,
) -> None:
retried = _sample_invited_user().model_copy(
update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
)
mocker.patch(
"backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
AsyncMock(return_value=retried),
)
response = client.post("/admin/invited-users/invite-1/retry-tally")
assert response.status_code == 200
assert response.json()["tally_status"] == "RUNNING"

View File

@@ -27,12 +27,6 @@ from backend.copilot.model import (
get_user_sessions,
update_session_title,
)
from backend.copilot.rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
check_rate_limit,
get_usage_status,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (
@@ -59,8 +53,6 @@ from backend.copilot.tools.models import (
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from backend.data.redis_client import get_redis_async
from backend.data.understanding import get_business_understanding
from backend.data.workspace import get_or_create_workspace
from backend.util.exceptions import NotFoundError
@@ -126,8 +118,6 @@ class SessionDetailResponse(BaseModel):
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
class SessionSummaryResponse(BaseModel):
@@ -137,7 +127,6 @@ class SessionSummaryResponse(BaseModel):
created_at: str
updated_at: str
title: str | None = None
is_processing: bool
class ListSessionsResponse(BaseModel):
@@ -196,28 +185,6 @@ async def list_sessions(
"""
sessions, total_count = await get_user_sessions(user_id, limit, offset)
# Batch-check Redis for active stream status on each session
processing_set: set[str] = set()
if sessions:
try:
redis = await get_redis_async()
pipe = redis.pipeline(transaction=False)
for session in sessions:
pipe.hget(
f"{config.session_meta_prefix}{session.session_id}",
"status",
)
statuses = await pipe.execute()
processing_set = {
session.session_id
for session, st in zip(sessions, statuses)
if st == "running"
}
except Exception:
logger.warning(
"Failed to fetch processing status from Redis; " "defaulting to empty"
)
return ListSessionsResponse(
sessions=[
SessionSummaryResponse(
@@ -225,7 +192,6 @@ async def list_sessions(
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
title=session.title,
is_processing=session.session_id in processing_set,
)
for session in sessions
],
@@ -397,10 +363,6 @@ async def get_session(
last_message_id=last_message_id,
)
# Sum token usage from session
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
return SessionDetailResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
@@ -408,26 +370,6 @@ async def get_session(
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
@router.get("/usage")
async def get_copilot_usage(
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CoPilotUsageStatus:
"""Get CoPilot usage status for the authenticated user.
Returns current token usage vs limits for daily and weekly windows.
"""
if not user_id:
raise HTTPException(status_code=401, detail="Authentication required")
return await get_usage_status(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
@@ -528,17 +470,6 @@ async def stream_chat_post(
},
)
# Pre-turn rate limit check (token-based)
if user_id and (config.daily_token_limit > 0 or config.weekly_token_limit > 0):
try:
await check_rate_limit(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
except RateLimitExceeded as e:
raise HTTPException(status_code=429, detail=str(e)) from e
# Enrich message with file metadata if file_ids are provided.
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
@@ -897,36 +828,6 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Suggested Prompts ==========
class SuggestedPromptsResponse(BaseModel):
"""Response model for user-specific suggested prompts."""
prompts: list[str]
@router.get(
"/suggested-prompts",
dependencies=[Security(auth.requires_user)],
)
async def get_suggested_prompts(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> SuggestedPromptsResponse:
"""
Get LLM-generated suggested prompts for the authenticated user.
Returns personalized quick-action prompts based on the user's
business understanding. Returns an empty list if no custom prompts
are available.
"""
understanding = await get_business_understanding(user_id)
if understanding is None:
return SuggestedPromptsResponse(prompts=[])
return SuggestedPromptsResponse(prompts=understanding.suggested_prompts)
# ========== Configuration ==========

View File

@@ -1,7 +1,6 @@
"""Tests for chat API routes: session title update, file attachment validation, usage, and suggested prompts."""
"""Tests for chat API routes: session title update and file attachment validation."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
@@ -250,130 +249,3 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
call_kwargs = mock_prisma.find_many.call_args[1]
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
assert call_kwargs["where"]["isDeleted"] is False
# ─── Usage endpoint ───────────────────────────────────────────────────
def _mock_usage(
mocker: pytest_mock.MockerFixture,
*,
daily_used: int = 500,
weekly_used: int = 2000,
) -> AsyncMock:
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
resets_at = datetime.now(UTC) + timedelta(days=1)
status = CoPilotUsageStatus(
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
)
return mocker.patch(
"backend.api.features.chat.routes.get_usage_status",
new_callable=AsyncMock,
return_value=status,
)
def test_usage_returns_daily_and_weekly(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""GET /usage returns daily and weekly usage."""
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
response = client.get("/usage")
assert response.status_code == 200
data = response.json()
assert data["daily"]["used"] == 500
assert data["weekly"]["used"] == 2000
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=10000,
weekly_token_limit=50000,
)
def test_usage_uses_config_limits(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
mock_get = _mock_usage(mocker)
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
response = client.get("/usage")
assert response.status_code == 200
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=99999,
weekly_token_limit=77777,
)
# ─── Suggested prompts endpoint ──────────────────────────────────────
def _mock_get_business_understanding(
mocker: pytest_mock.MockerFixture,
*,
return_value=None,
):
"""Mock get_business_understanding."""
return mocker.patch(
"backend.api.features.chat.routes.get_business_understanding",
new_callable=AsyncMock,
return_value=return_value,
)
def test_suggested_prompts_returns_prompts(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with understanding and prompts gets them back."""
mock_understanding = MagicMock()
mock_understanding.suggested_prompts = ["Do X", "Do Y", "Do Z"]
_mock_get_business_understanding(mocker, return_value=mock_understanding)
response = client.get("/suggested-prompts")
assert response.status_code == 200
assert response.json() == {"prompts": ["Do X", "Do Y", "Do Z"]}
def test_suggested_prompts_no_understanding(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with no understanding gets empty list."""
_mock_get_business_understanding(mocker, return_value=None)
response = client.get("/suggested-prompts")
assert response.status_code == 200
assert response.json() == {"prompts": []}
def test_suggested_prompts_empty_prompts(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with understanding but no prompts gets empty list."""
mock_understanding = MagicMock()
mock_understanding.suggested_prompts = []
_mock_get_business_understanding(mocker, return_value=mock_understanding)
response = client.get("/suggested-prompts")
assert response.status_code == 200
assert response.json() == {"prompts": []}

View File

@@ -638,7 +638,7 @@ async def test_process_review_action_auto_approve_creates_auto_approval_records(
# Mock get_node_executions to return node_id mapping
mock_get_node_executions = mocker.patch(
"backend.api.features.executions.review.routes.get_node_executions"
"backend.data.execution.get_node_executions"
)
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
mock_node_exec.node_exec_id = "test_node_123"
@@ -936,7 +936,7 @@ async def test_process_review_action_auto_approve_only_applies_to_approved_revie
# Mock get_node_executions to return node_id mapping
mock_get_node_executions = mocker.patch(
"backend.api.features.executions.review.routes.get_node_executions"
"backend.data.execution.get_node_executions"
)
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
mock_node_exec.node_exec_id = "node_exec_approved"
@@ -1148,7 +1148,7 @@ async def test_process_review_action_per_review_auto_approve_granularity(
# Mock get_node_executions to return batch node data
mock_get_node_executions = mocker.patch(
"backend.api.features.executions.review.routes.get_node_executions"
"backend.data.execution.get_node_executions"
)
# Create mock node executions for each review
mock_node_execs = []

View File

@@ -6,15 +6,10 @@ import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, HTTPException, Query, Security, status
from prisma.enums import ReviewStatus
from backend.copilot.constants import (
is_copilot_synthetic_id,
parse_node_id_from_exec_id,
)
from backend.data.execution import (
ExecutionContext,
ExecutionStatus,
get_graph_execution_meta,
get_node_executions,
)
from backend.data.graph import get_graph_settings
from backend.data.human_review import (
@@ -41,38 +36,6 @@ router = APIRouter(
)
async def _resolve_node_ids(
node_exec_ids: list[str],
graph_exec_id: str,
is_copilot: bool,
) -> dict[str, str]:
"""Resolve node_exec_id -> node_id for auto-approval records.
CoPilot synthetic IDs encode node_id in the format "{node_id}:{random}".
Graph executions look up node_id from NodeExecution records.
"""
if not node_exec_ids:
return {}
if is_copilot:
return {neid: parse_node_id_from_exec_id(neid) for neid in node_exec_ids}
node_execs = await get_node_executions(
graph_exec_id=graph_exec_id, include_exec_data=False
)
node_exec_map = {ne.node_exec_id: ne.node_id for ne in node_execs}
result = {}
for neid in node_exec_ids:
if neid in node_exec_map:
result[neid] = node_exec_map[neid]
else:
logger.error(
f"Failed to resolve node_id for {neid}: Node execution not found."
)
return result
@router.get(
"/pending",
summary="Get Pending Reviews",
@@ -147,16 +110,14 @@ async def list_pending_reviews_for_execution(
"""
# Verify user owns the graph execution before returning reviews
# (CoPilot synthetic IDs don't have graph execution records)
if not is_copilot_synthetic_id(graph_exec_id):
graph_exec = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
graph_exec = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
if not graph_exec:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
@@ -199,26 +160,30 @@ async def process_review_action(
)
graph_exec_id = next(iter(graph_exec_ids))
is_copilot = is_copilot_synthetic_id(graph_exec_id)
# Validate execution status for graph executions (skip for CoPilot synthetic IDs)
if not is_copilot:
graph_exec_meta = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
# Validate execution status before processing reviews
graph_exec_meta = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec_meta:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
# Only allow processing reviews if execution is paused for review
# or incomplete (partial execution with some reviews already processed)
if graph_exec_meta.status not in (
ExecutionStatus.REVIEW,
ExecutionStatus.INCOMPLETE,
):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
f"Reviews can only be processed when execution is paused (REVIEW status). "
f"Current status: {graph_exec_meta.status}",
)
if not graph_exec_meta:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
if graph_exec_meta.status not in (
ExecutionStatus.REVIEW,
ExecutionStatus.INCOMPLETE,
):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}",
)
# Build review decisions map and track which reviews requested auto-approval
# Auto-approved reviews use original data (no modifications allowed)
@@ -271,7 +236,7 @@ async def process_review_action(
)
return (node_id, False)
# Collect node_exec_ids that need auto-approval and resolve their node_ids
# Collect node_exec_ids that need auto-approval
node_exec_ids_needing_auto_approval = [
node_exec_id
for node_exec_id, review_result in updated_reviews.items()
@@ -279,16 +244,29 @@ async def process_review_action(
and auto_approve_requests.get(node_exec_id, False)
]
node_id_map = await _resolve_node_ids(
node_exec_ids_needing_auto_approval, graph_exec_id, is_copilot
)
# Deduplicate by node_id — one auto-approval per node
# Batch-fetch node executions to get node_ids
nodes_needing_auto_approval: dict[str, Any] = {}
for node_exec_id in node_exec_ids_needing_auto_approval:
node_id = node_id_map.get(node_exec_id)
if node_id and node_id not in nodes_needing_auto_approval:
nodes_needing_auto_approval[node_id] = updated_reviews[node_exec_id]
if node_exec_ids_needing_auto_approval:
from backend.data.execution import get_node_executions
node_execs = await get_node_executions(
graph_exec_id=graph_exec_id, include_exec_data=False
)
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
for node_exec_id in node_exec_ids_needing_auto_approval:
node_exec = node_exec_map.get(node_exec_id)
if node_exec:
review_result = updated_reviews[node_exec_id]
# Use the first approved review for this node (deduplicate by node_id)
if node_exec.node_id not in nodes_needing_auto_approval:
nodes_needing_auto_approval[node_exec.node_id] = review_result
else:
logger.error(
f"Failed to create auto-approval record for {node_exec_id}: "
f"Node execution not found. This may indicate a race condition "
f"or data inconsistency."
)
# Execute all auto-approval creations in parallel (deduplicated by node_id)
auto_approval_results = await asyncio.gather(
@@ -303,11 +281,13 @@ async def process_review_action(
auto_approval_failed_count = 0
for result in auto_approval_results:
if isinstance(result, Exception):
# Unexpected exception during auto-approval creation
auto_approval_failed_count += 1
logger.error(
f"Unexpected exception during auto-approval creation: {result}"
)
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
# Auto-approval creation failed (returned False)
auto_approval_failed_count += 1
# Count results
@@ -322,20 +302,22 @@ async def process_review_action(
if review.status == ReviewStatus.REJECTED
)
# Resume graph execution only for real graph executions (not CoPilot)
# CoPilot sessions are resumed by the LLM retrying run_block with review_id
if not is_copilot and updated_reviews:
# Resume execution only if ALL pending reviews for this execution have been processed
if updated_reviews:
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
if not still_has_pending:
# Get the graph_id from any processed review
first_review = next(iter(updated_reviews.values()))
try:
# Fetch user and settings to build complete execution context
user = await get_user_by_id(user_id)
settings = await get_graph_settings(
user_id=user_id, graph_id=first_review.graph_id
)
# Preserve user's timezone preference when resuming execution
user_timezone = (
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
)

View File

@@ -165,6 +165,7 @@ class LibraryAgent(pydantic.BaseModel):
id: str
graph_id: str
graph_version: int
owner_user_id: str
image_url: str | None
@@ -205,9 +206,7 @@ class LibraryAgent(pydantic.BaseModel):
default_factory=list,
description="List of recent executions with status, score, and summary",
)
can_access_graph: bool = pydantic.Field(
description="Indicates whether the same user owns the corresponding graph"
)
can_access_graph: bool
is_latest_version: bool
is_favorite: bool
folder_id: str | None = None
@@ -325,6 +324,7 @@ class LibraryAgent(pydantic.BaseModel):
id=agent.id,
graph_id=agent.agentGraphId,
graph_version=agent.agentGraphVersion,
owner_user_id=agent.userId,
image_url=agent.imageUrl,
creator_name=creator_name,
creator_image_url=creator_image_url,

View File

@@ -42,6 +42,7 @@ async def test_get_library_agents_success(
id="test-agent-1",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 1",
description="Test Description 1",
image_url=None,
@@ -66,6 +67,7 @@ async def test_get_library_agents_success(
id="test-agent-2",
graph_id="test-agent-2",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 2",
description="Test Description 2",
image_url=None,
@@ -129,6 +131,7 @@ async def test_get_favorite_library_agents_success(
id="test-agent-1",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Favorite Agent 1",
description="Test Favorite Description 1",
image_url=None,
@@ -181,6 +184,7 @@ def test_add_agent_to_library_success(
id="test-library-agent-id",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 1",
description="Test Description 1",
image_url=None,

View File

@@ -55,7 +55,6 @@ from backend.data.credit import (
set_auto_top_up,
)
from backend.data.graph import GraphSettings
from backend.data.invited_user import get_or_activate_user
from backend.data.model import CredentialsMetaInput, UserOnboarding
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.data.onboarding import (
@@ -71,6 +70,7 @@ from backend.data.onboarding import (
update_user_onboarding,
)
from backend.data.user import (
get_or_create_user,
get_user_by_id,
get_user_notification_preference,
update_user_email,
@@ -136,10 +136,12 @@ _tally_background_tasks: set[asyncio.Task] = set()
dependencies=[Security(requires_user)],
)
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
user = await get_or_activate_user(user_data)
user = await get_or_create_user(user_data)
# Fire-and-forget: backfill Tally understanding when invite pre-seeding did
# not produce a stored result before first activation.
# Fire-and-forget: populate business understanding from Tally form.
# We use created_at proximity instead of an is_new flag because
# get_or_create_user is cached — a separate is_new return value would be
# unreliable on repeated calls within the cache TTL.
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
if age_seconds < 30:
try:
@@ -163,8 +165,7 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
dependencies=[Security(requires_user)],
)
async def update_user_email_route(
user_id: Annotated[str, Security(get_user_id)],
email: str = Body(...),
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
) -> dict[str, str]:
await update_user_email(user_id, email)
@@ -178,16 +179,10 @@ async def update_user_email_route(
dependencies=[Security(requires_user)],
)
async def get_user_timezone_route(
user_id: Annotated[str, Security(get_user_id)],
user_data: dict = Security(get_jwt_payload),
) -> TimezoneResponse:
"""Get user timezone setting."""
try:
user = await get_user_by_id(user_id)
except ValueError:
raise HTTPException(
status_code=HTTP_404_NOT_FOUND,
detail="User not found. Please complete activation via /auth/user first.",
)
user = await get_or_create_user(user_data)
return TimezoneResponse(timezone=user.timezone)
@@ -198,8 +193,7 @@ async def get_user_timezone_route(
dependencies=[Security(requires_user)],
)
async def update_user_timezone_route(
user_id: Annotated[str, Security(get_user_id)],
request: UpdateTimezoneRequest,
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
) -> TimezoneResponse:
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
user = await update_user_timezone(user_id, str(request.timezone))

View File

@@ -51,7 +51,7 @@ def test_get_or_create_user_route(
}
mocker.patch(
"backend.api.features.v1.get_or_activate_user",
"backend.api.features.v1.get_or_create_user",
return_value=mock_user,
)

View File

@@ -94,8 +94,3 @@ class NotificationPayload(pydantic.BaseModel):
class OnboardingNotificationPayload(NotificationPayload):
step: OnboardingStep | None
class CopilotCompletionPayload(NotificationPayload):
session_id: str
status: Literal["completed", "failed"]

View File

@@ -19,7 +19,6 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.admin.user_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
import backend.api.features.chat.routes as chat_routes
@@ -38,6 +37,7 @@ import backend.api.features.workspace.routes as workspace_routes
import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.llm_registry
import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
@@ -118,11 +118,30 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Refresh LLM registry before initializing blocks so blocks can use registry data
# Note: Graceful fallback for now since no blocks consume registry yet (comes in PR #5)
# When block integration lands, this should fail hard or skip block initialization
try:
await backend.data.llm_registry.refresh_llm_registry()
logger.info("LLM registry refreshed successfully at startup")
except Exception as e:
logger.warning(
f"Failed to refresh LLM registry at startup: {e}. "
"Blocks will initialize with empty registry."
)
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
try:
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
except Exception as e:
logger.warning(
f"Failed to migrate LLM models at startup: {e}. "
"This is expected in test environments without AgentNode table."
)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
with launch_darkly_context():
@@ -312,11 +331,6 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/executions",
)
app.include_router(
backend.api.features.admin.user_admin_routes.router,
tags=["v2", "admin"],
prefix="/api/users",
)
app.include_router(
backend.api.features.executions.review.routes.router,
tags=["v2", "executions", "review"],

View File

@@ -624,7 +624,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
graph_id: str,
graph_version: int,
execution_context: "ExecutionContext",
is_graph_execution: bool = True,
**kwargs,
) -> tuple[bool, BlockInput]:
"""
@@ -653,7 +652,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
graph_version=graph_version,
block_name=self.name,
editable=True,
is_graph_execution=is_graph_execution,
)
if decision is None:

View File

@@ -126,7 +126,7 @@ class PrintToConsoleBlock(Block):
output_schema=PrintToConsoleBlock.Output,
test_input={"text": "Hello, World!"},
is_sensitive_action=True,
disabled=True,
disabled=True, # Disabled per Nick Tindle's request (OPEN-3000)
test_output=[
("output", "Hello, World!"),
("status", "printed"),

View File

@@ -96,7 +96,6 @@ class SendEmailBlock(Block):
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Email sent successfully")],
test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
is_sensitive_action=True,
)
@staticmethod

View File

@@ -1,3 +0,0 @@
def github_repo_path(repo_url: str) -> str:
"""Extract 'owner/repo' from a GitHub repository URL."""
return repo_url.replace("https://github.com/", "")

View File

@@ -1,374 +0,0 @@
import asyncio
from enum import StrEnum
from urllib.parse import quote
from typing_extensions import TypedDict
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
from ._utils import github_repo_path
class GithubListCommitsBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Branch name to list commits from",
default="main",
)
per_page: int = SchemaField(
description="Number of commits to return (max 100)",
default=30,
ge=1,
le=100,
)
page: int = SchemaField(
description="Page number for pagination",
default=1,
ge=1,
)
class Output(BlockSchemaOutput):
class CommitItem(TypedDict):
sha: str
message: str
author: str
date: str
url: str
commit: CommitItem = SchemaField(
title="Commit", description="A commit with its details"
)
commits: list[CommitItem] = SchemaField(
description="List of commits with their details"
)
error: str = SchemaField(description="Error message if listing commits failed")
def __init__(self):
super().__init__(
id="8b13f579-d8b6-4dc2-a140-f770428805de",
description="This block lists commits on a branch in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListCommitsBlock.Input,
output_schema=GithubListCommitsBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "main",
"per_page": 30,
"page": 1,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"commits",
[
{
"sha": "abc123",
"message": "Initial commit",
"author": "octocat",
"date": "2024-01-01T00:00:00Z",
"url": "https://github.com/owner/repo/commit/abc123",
}
],
),
(
"commit",
{
"sha": "abc123",
"message": "Initial commit",
"author": "octocat",
"date": "2024-01-01T00:00:00Z",
"url": "https://github.com/owner/repo/commit/abc123",
},
),
],
test_mock={
"list_commits": lambda *args, **kwargs: [
{
"sha": "abc123",
"message": "Initial commit",
"author": "octocat",
"date": "2024-01-01T00:00:00Z",
"url": "https://github.com/owner/repo/commit/abc123",
}
]
},
)
@staticmethod
async def list_commits(
credentials: GithubCredentials,
repo_url: str,
branch: str,
per_page: int,
page: int,
) -> list[Output.CommitItem]:
api = get_api(credentials)
commits_url = repo_url + "/commits"
params = {"sha": branch, "per_page": str(per_page), "page": str(page)}
response = await api.get(commits_url, params=params)
data = response.json()
repo_path = github_repo_path(repo_url)
return [
GithubListCommitsBlock.Output.CommitItem(
sha=c["sha"],
message=c["commit"]["message"],
author=(c["commit"].get("author") or {}).get("name", "Unknown"),
date=(c["commit"].get("author") or {}).get("date", ""),
url=f"https://github.com/{repo_path}/commit/{c['sha']}",
)
for c in data
]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
commits = await self.list_commits(
credentials,
input_data.repo_url,
input_data.branch,
input_data.per_page,
input_data.page,
)
yield "commits", commits
for commit in commits:
yield "commit", commit
except Exception as e:
yield "error", str(e)
class FileOperation(StrEnum):
"""File operations for GithubMultiFileCommitBlock.
UPSERT creates or overwrites a file (the Git Trees API does not distinguish
between creation and update — the blob is placed at the given path regardless
of whether a file already exists there).
DELETE removes a file from the tree.
"""
UPSERT = "upsert"
DELETE = "delete"
class FileOperationInput(TypedDict):
path: str
content: str
operation: FileOperation
class GithubMultiFileCommitBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Branch to commit to",
placeholder="feature-branch",
)
commit_message: str = SchemaField(
description="Commit message",
placeholder="Add new feature",
)
files: list[FileOperationInput] = SchemaField(
description=(
"List of file operations. Each item has: "
"'path' (file path), 'content' (file content, ignored for delete), "
"'operation' (upsert/delete)"
),
)
class Output(BlockSchemaOutput):
sha: str = SchemaField(description="SHA of the new commit")
url: str = SchemaField(description="URL of the new commit")
error: str = SchemaField(description="Error message if the commit failed")
def __init__(self):
super().__init__(
id="389eee51-a95e-4230-9bed-92167a327802",
description=(
"This block creates a single commit with multiple file "
"upsert/delete operations using the Git Trees API."
),
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMultiFileCommitBlock.Input,
output_schema=GithubMultiFileCommitBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "feature",
"commit_message": "Add files",
"files": [
{
"path": "src/new.py",
"content": "print('hello')",
"operation": "upsert",
},
{
"path": "src/old.py",
"content": "",
"operation": "delete",
},
],
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("sha", "newcommitsha"),
("url", "https://github.com/owner/repo/commit/newcommitsha"),
],
test_mock={
"multi_file_commit": lambda *args, **kwargs: (
"newcommitsha",
"https://github.com/owner/repo/commit/newcommitsha",
)
},
)
@staticmethod
async def multi_file_commit(
credentials: GithubCredentials,
repo_url: str,
branch: str,
commit_message: str,
files: list[FileOperationInput],
) -> tuple[str, str]:
api = get_api(credentials)
safe_branch = quote(branch, safe="")
# 1. Get the latest commit SHA for the branch
ref_url = repo_url + f"/git/refs/heads/{safe_branch}"
response = await api.get(ref_url)
ref_data = response.json()
latest_commit_sha = ref_data["object"]["sha"]
# 2. Get the tree SHA of the latest commit
commit_url = repo_url + f"/git/commits/{latest_commit_sha}"
response = await api.get(commit_url)
commit_data = response.json()
base_tree_sha = commit_data["tree"]["sha"]
# 3. Build tree entries for each file operation (blobs created concurrently)
async def _create_blob(content: str) -> str:
blob_url = repo_url + "/git/blobs"
blob_response = await api.post(
blob_url,
json={"content": content, "encoding": "utf-8"},
)
return blob_response.json()["sha"]
tree_entries: list[dict] = []
upsert_files = []
for file_op in files:
path = file_op["path"]
operation = FileOperation(file_op.get("operation", "upsert"))
if operation == FileOperation.DELETE:
tree_entries.append(
{
"path": path,
"mode": "100644",
"type": "blob",
"sha": None, # null SHA = delete
}
)
else:
upsert_files.append((path, file_op.get("content", "")))
# Create all blobs concurrently
if upsert_files:
blob_shas = await asyncio.gather(
*[_create_blob(content) for _, content in upsert_files]
)
for (path, _), blob_sha in zip(upsert_files, blob_shas):
tree_entries.append(
{
"path": path,
"mode": "100644",
"type": "blob",
"sha": blob_sha,
}
)
# 4. Create a new tree
tree_url = repo_url + "/git/trees"
tree_response = await api.post(
tree_url,
json={"base_tree": base_tree_sha, "tree": tree_entries},
)
new_tree_sha = tree_response.json()["sha"]
# 5. Create a new commit
new_commit_url = repo_url + "/git/commits"
commit_response = await api.post(
new_commit_url,
json={
"message": commit_message,
"tree": new_tree_sha,
"parents": [latest_commit_sha],
},
)
new_commit_sha = commit_response.json()["sha"]
# 6. Update the branch reference
try:
await api.patch(
ref_url,
json={"sha": new_commit_sha},
)
except Exception as e:
raise RuntimeError(
f"Commit {new_commit_sha} was created but failed to update "
f"ref heads/{branch}: {e}. "
f"You can recover by manually updating the branch to {new_commit_sha}."
) from e
repo_path = github_repo_path(repo_url)
commit_web_url = f"https://github.com/{repo_path}/commit/{new_commit_sha}"
return new_commit_sha, commit_web_url
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
sha, url = await self.multi_file_commit(
credentials,
input_data.repo_url,
input_data.branch,
input_data.commit_message,
input_data.files,
)
yield "sha", sha
yield "url", url
except Exception as e:
yield "error", str(e)

View File

@@ -1,5 +1,4 @@
import re
from typing import Literal
from typing_extensions import TypedDict
@@ -21,8 +20,6 @@ from ._auth import (
GithubCredentialsInput,
)
MergeMethod = Literal["merge", "squash", "rebase"]
class GithubListPullRequestsBlock(Block):
class Input(BlockSchemaInput):
@@ -561,109 +558,12 @@ class GithubListPRReviewersBlock(Block):
yield "reviewer", reviewer
class GithubMergePullRequestBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
pr_url: str = SchemaField(
description="URL of the GitHub pull request",
placeholder="https://github.com/owner/repo/pull/1",
)
merge_method: MergeMethod = SchemaField(
description="Merge method to use: merge, squash, or rebase",
default="merge",
)
commit_title: str = SchemaField(
description="Title for the merge commit (optional, used for merge and squash)",
default="",
)
commit_message: str = SchemaField(
description="Message for the merge commit (optional, used for merge and squash)",
default="",
)
class Output(BlockSchemaOutput):
sha: str = SchemaField(description="SHA of the merge commit")
merged: bool = SchemaField(description="Whether the PR was merged")
message: str = SchemaField(description="Merge status message")
error: str = SchemaField(description="Error message if the merge failed")
def __init__(self):
super().__init__(
id="77456c22-33d8-4fd4-9eef-50b46a35bb48",
description="This block merges a pull request using merge, squash, or rebase.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMergePullRequestBlock.Input,
output_schema=GithubMergePullRequestBlock.Output,
test_input={
"pr_url": "https://github.com/owner/repo/pull/1",
"merge_method": "squash",
"commit_title": "",
"commit_message": "",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("sha", "abc123"),
("merged", True),
("message", "Pull Request successfully merged"),
],
test_mock={
"merge_pr": lambda *args, **kwargs: (
"abc123",
True,
"Pull Request successfully merged",
)
},
is_sensitive_action=True,
)
@staticmethod
async def merge_pr(
credentials: GithubCredentials,
pr_url: str,
merge_method: MergeMethod,
commit_title: str,
commit_message: str,
) -> tuple[str, bool, str]:
api = get_api(credentials)
merge_url = prepare_pr_api_url(pr_url=pr_url, path="merge")
data: dict[str, str] = {"merge_method": merge_method}
if commit_title:
data["commit_title"] = commit_title
if commit_message:
data["commit_message"] = commit_message
response = await api.put(merge_url, json=data)
result = response.json()
return result["sha"], result["merged"], result["message"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
sha, merged, message = await self.merge_pr(
credentials,
input_data.pr_url,
input_data.merge_method,
input_data.commit_title,
input_data.commit_message,
)
yield "sha", sha
yield "merged", merged
yield "message", message
except Exception as e:
yield "error", str(e)
def prepare_pr_api_url(pr_url: str, path: str) -> str:
# Pattern to capture the base repository URL and the pull request number
pattern = r"^(?:(https?)://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
pattern = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
match = re.match(pattern, pr_url)
if not match:
return pr_url
scheme, base_url, pr_number = match.groups()
return f"{scheme or 'https'}://{base_url}/pulls/{pr_number}/{path}"
base_url, pr_number = match.groups()
return f"{base_url}/pulls/{pr_number}/{path}"

View File

@@ -1,3 +1,5 @@
import base64
from typing_extensions import TypedDict
from backend.blocks._base import (
@@ -17,7 +19,6 @@ from ._auth import (
GithubCredentialsField,
GithubCredentialsInput,
)
from ._utils import github_repo_path
class GithubListTagsBlock(Block):
@@ -88,7 +89,7 @@ class GithubListTagsBlock(Block):
tags_url = repo_url + "/tags"
response = await api.get(tags_url)
data = response.json()
repo_path = github_repo_path(repo_url)
repo_path = repo_url.replace("https://github.com/", "")
tags: list[GithubListTagsBlock.Output.TagItem] = [
{
"name": tag["name"],
@@ -114,6 +115,101 @@ class GithubListTagsBlock(Block):
yield "tag", tag
class GithubListBranchesBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
class Output(BlockSchemaOutput):
class BranchItem(TypedDict):
name: str
url: str
branch: BranchItem = SchemaField(
title="Branch",
description="Branches with their name and file tree browser URL",
)
branches: list[BranchItem] = SchemaField(
description="List of branches with their name and file tree browser URL"
)
def __init__(self):
super().__init__(
id="74243e49-2bec-4916-8bf4-db43d44aead5",
description="This block lists all branches for a specified GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListBranchesBlock.Input,
output_schema=GithubListBranchesBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"branches",
[
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
],
),
(
"branch",
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
},
),
],
test_mock={
"list_branches": lambda *args, **kwargs: [
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
]
},
)
@staticmethod
async def list_branches(
credentials: GithubCredentials, repo_url: str
) -> list[Output.BranchItem]:
api = get_api(credentials)
branches_url = repo_url + "/branches"
response = await api.get(branches_url)
data = response.json()
repo_path = repo_url.replace("https://github.com/", "")
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
{
"name": branch["name"],
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
}
for branch in data
]
return branches
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
branches = await self.list_branches(
credentials,
input_data.repo_url,
)
yield "branches", branches
for branch in branches:
yield "branch", branch
class GithubListDiscussionsBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
@@ -187,7 +283,7 @@ class GithubListDiscussionsBlock(Block):
) -> list[Output.DiscussionItem]:
api = get_api(credentials)
# GitHub GraphQL API endpoint is different; we'll use api.post with custom URL
repo_path = github_repo_path(repo_url)
repo_path = repo_url.replace("https://github.com/", "")
owner, repo = repo_path.split("/")
query = """
query($owner: String!, $repo: String!, $num: Int!) {
@@ -320,6 +416,564 @@ class GithubListReleasesBlock(Block):
yield "release", release
class GithubReadFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file in the repository",
placeholder="path/to/file",
)
branch: str = SchemaField(
description="Branch to read from",
placeholder="branch_name",
default="master",
)
class Output(BlockSchemaOutput):
text_content: str = SchemaField(
description="Content of the file (decoded as UTF-8 text)"
)
raw_content: str = SchemaField(
description="Raw base64-encoded content of the file"
)
size: int = SchemaField(description="The size of the file (in bytes)")
def __init__(self):
super().__init__(
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
description="This block reads the content of a specified file from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFileBlock.Input,
output_schema=GithubReadFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "path/to/file",
"branch": "master",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("raw_content", "RmlsZSBjb250ZW50"),
("text_content", "File content"),
("size", 13),
],
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
)
@staticmethod
async def read_file(
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
) -> tuple[str, int]:
api = get_api(credentials)
content_url = repo_url + f"/contents/{file_path}?ref={branch}"
response = await api.get(content_url)
data = response.json()
if isinstance(data, list):
# Multiple entries of different types exist at this path
if not (file := next((f for f in data if f["type"] == "file"), None)):
raise TypeError("Not a file")
data = file
if data["type"] != "file":
raise TypeError("Not a file")
return data["content"], data["size"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
content, size = await self.read_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.branch,
)
yield "raw_content", content
yield "text_content", base64.b64decode(content).decode("utf-8")
yield "size", size
class GithubReadFolderBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
folder_path: str = SchemaField(
description="Path to the folder in the repository",
placeholder="path/to/folder",
)
branch: str = SchemaField(
description="Branch name to read from (defaults to master)",
placeholder="branch_name",
default="master",
)
class Output(BlockSchemaOutput):
class DirEntry(TypedDict):
name: str
path: str
class FileEntry(TypedDict):
name: str
path: str
size: int
file: FileEntry = SchemaField(description="Files in the folder")
dir: DirEntry = SchemaField(description="Directories in the folder")
error: str = SchemaField(
description="Error message if reading the folder failed"
)
def __init__(self):
super().__init__(
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
description="This block reads the content of a specified folder from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFolderBlock.Input,
output_schema=GithubReadFolderBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"folder_path": "path/to/folder",
"branch": "master",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"file",
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
},
),
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
],
test_mock={
"read_folder": lambda *args, **kwargs: (
[
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
}
],
[{"name": "dir2", "path": "path/to/folder/dir2"}],
)
},
)
@staticmethod
async def read_folder(
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{folder_path}?ref={branch}"
response = await api.get(contents_url)
data = response.json()
if not isinstance(data, list):
raise TypeError("Not a folder")
files: list[GithubReadFolderBlock.Output.FileEntry] = [
GithubReadFolderBlock.Output.FileEntry(
name=entry["name"],
path=entry["path"],
size=entry["size"],
)
for entry in data
if entry["type"] == "file"
]
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
GithubReadFolderBlock.Output.DirEntry(
name=entry["name"],
path=entry["path"],
)
for entry in data
if entry["type"] == "dir"
]
return files, dirs
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
files, dirs = await self.read_folder(
credentials,
input_data.repo_url,
input_data.folder_path.lstrip("/"),
input_data.branch,
)
for file in files:
yield "file", file
for dir in dirs:
yield "dir", dir
class GithubMakeBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
new_branch: str = SchemaField(
description="Name of the new branch",
placeholder="new_branch_name",
)
source_branch: str = SchemaField(
description="Name of the source branch",
placeholder="source_branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch creation operation")
error: str = SchemaField(
description="Error message if the branch creation failed"
)
def __init__(self):
super().__init__(
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
description="This block creates a new branch from a specified source branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMakeBranchBlock.Input,
output_schema=GithubMakeBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"new_branch": "new_branch_name",
"source_branch": "source_branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch created successfully")],
test_mock={
"create_branch": lambda *args, **kwargs: "Branch created successfully"
},
)
@staticmethod
async def create_branch(
credentials: GithubCredentials,
repo_url: str,
new_branch: str,
source_branch: str,
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{source_branch}"
response = await api.get(ref_url)
data = response.json()
sha = data["object"]["sha"]
# Create the new branch
new_ref_url = repo_url + "/git/refs"
data = {
"ref": f"refs/heads/{new_branch}",
"sha": sha,
}
response = await api.post(new_ref_url, json=data)
return "Branch created successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
status = await self.create_branch(
credentials,
input_data.repo_url,
input_data.new_branch,
input_data.source_branch,
)
yield "status", status
class GithubDeleteBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Name of the branch to delete",
placeholder="branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch deletion operation")
error: str = SchemaField(
description="Error message if the branch deletion failed"
)
def __init__(self):
super().__init__(
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
description="This block deletes a specified branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubDeleteBranchBlock.Input,
output_schema=GithubDeleteBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch deleted successfully")],
test_mock={
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
},
)
@staticmethod
async def delete_branch(
credentials: GithubCredentials, repo_url: str, branch: str
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{branch}"
await api.delete(ref_url)
return "Branch deleted successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
status = await self.delete_branch(
credentials,
input_data.repo_url,
input_data.branch,
)
yield "status", status
class GithubCreateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path where the file should be created",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="Content to write to the file",
placeholder="File content here",
)
branch: str = SchemaField(
description="Branch where the file should be created",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Create new file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the created file")
sha: str = SchemaField(description="SHA of the commit")
error: str = SchemaField(
description="Error message if the file creation failed"
)
def __init__(self):
super().__init__(
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
description="This block creates a new file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCreateFileBlock.Input,
output_schema=GithubCreateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Test content",
"branch": "main",
"commit_message": "Create test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "abc123"),
],
test_mock={
"create_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"abc123",
)
},
)
@staticmethod
async def create_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{file_path}"
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.create_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubUpdateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file to update",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="New content for the file",
placeholder="Updated content here",
)
branch: str = SchemaField(
description="Branch containing the file",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Update file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the updated file")
sha: str = SchemaField(description="SHA of the commit")
def __init__(self):
super().__init__(
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
description="This block updates an existing file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubUpdateFileBlock.Input,
output_schema=GithubUpdateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Updated content",
"branch": "main",
"commit_message": "Update test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "def456"),
],
test_mock={
"update_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"def456",
)
},
)
@staticmethod
async def update_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{file_path}"
params = {"ref": branch}
response = await api.get(contents_url, params=params)
data = response.json()
# Convert new content to base64
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"sha": data["sha"],
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.update_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubCreateRepositoryBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
@@ -449,7 +1103,7 @@ class GithubListStargazersBlock(Block):
def __init__(self):
super().__init__(
id="e96d01ec-b55e-4a99-8ce8-c8776dce850b", # Generated unique UUID
id="a4b9c2d1-e5f6-4g7h-8i9j-0k1l2m3n4o5p", # Generated unique UUID
description="This block lists all users who have starred a specified GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListStargazersBlock.Input,
@@ -518,230 +1172,3 @@ class GithubListStargazersBlock(Block):
yield "stargazers", stargazers
for stargazer in stargazers:
yield "stargazer", stargazer
class GithubGetRepositoryInfoBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
class Output(BlockSchemaOutput):
name: str = SchemaField(description="Repository name")
full_name: str = SchemaField(description="Full repository name (owner/repo)")
description: str = SchemaField(description="Repository description")
default_branch: str = SchemaField(description="Default branch name (e.g. main)")
private: bool = SchemaField(description="Whether the repository is private")
html_url: str = SchemaField(description="Web URL of the repository")
clone_url: str = SchemaField(description="Git clone URL")
stars: int = SchemaField(description="Number of stars")
forks: int = SchemaField(description="Number of forks")
open_issues: int = SchemaField(description="Number of open issues")
error: str = SchemaField(
description="Error message if fetching repo info failed"
)
def __init__(self):
super().__init__(
id="59d4f241-968a-4040-95da-348ac5c5ce27",
description="This block retrieves metadata about a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubGetRepositoryInfoBlock.Input,
output_schema=GithubGetRepositoryInfoBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("name", "repo"),
("full_name", "owner/repo"),
("description", "A test repo"),
("default_branch", "main"),
("private", False),
("html_url", "https://github.com/owner/repo"),
("clone_url", "https://github.com/owner/repo.git"),
("stars", 42),
("forks", 5),
("open_issues", 3),
],
test_mock={
"get_repo_info": lambda *args, **kwargs: {
"name": "repo",
"full_name": "owner/repo",
"description": "A test repo",
"default_branch": "main",
"private": False,
"html_url": "https://github.com/owner/repo",
"clone_url": "https://github.com/owner/repo.git",
"stargazers_count": 42,
"forks_count": 5,
"open_issues_count": 3,
}
},
)
@staticmethod
async def get_repo_info(credentials: GithubCredentials, repo_url: str) -> dict:
api = get_api(credentials)
response = await api.get(repo_url)
return response.json()
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
data = await self.get_repo_info(credentials, input_data.repo_url)
yield "name", data["name"]
yield "full_name", data["full_name"]
yield "description", data.get("description", "") or ""
yield "default_branch", data["default_branch"]
yield "private", data["private"]
yield "html_url", data["html_url"]
yield "clone_url", data["clone_url"]
yield "stars", data["stargazers_count"]
yield "forks", data["forks_count"]
yield "open_issues", data["open_issues_count"]
except Exception as e:
yield "error", str(e)
class GithubForkRepositoryBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository to fork",
placeholder="https://github.com/owner/repo",
)
organization: str = SchemaField(
description="Organization to fork into (leave empty to fork to your account)",
default="",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the forked repository")
clone_url: str = SchemaField(description="Git clone URL of the fork")
full_name: str = SchemaField(description="Full name of the fork (owner/repo)")
error: str = SchemaField(description="Error message if the fork failed")
def __init__(self):
super().__init__(
id="a439f2f4-835f-4dae-ba7b-0205ffa70be6",
description="This block forks a GitHub repository to your account or an organization.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubForkRepositoryBlock.Input,
output_schema=GithubForkRepositoryBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"organization": "",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/myuser/repo"),
("clone_url", "https://github.com/myuser/repo.git"),
("full_name", "myuser/repo"),
],
test_mock={
"fork_repo": lambda *args, **kwargs: (
"https://github.com/myuser/repo",
"https://github.com/myuser/repo.git",
"myuser/repo",
)
},
)
@staticmethod
async def fork_repo(
credentials: GithubCredentials,
repo_url: str,
organization: str,
) -> tuple[str, str, str]:
api = get_api(credentials)
forks_url = repo_url + "/forks"
data: dict[str, str] = {}
if organization:
data["organization"] = organization
response = await api.post(forks_url, json=data)
result = response.json()
return result["html_url"], result["clone_url"], result["full_name"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, clone_url, full_name = await self.fork_repo(
credentials,
input_data.repo_url,
input_data.organization,
)
yield "url", url
yield "clone_url", clone_url
yield "full_name", full_name
except Exception as e:
yield "error", str(e)
class GithubStarRepositoryBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository to star",
placeholder="https://github.com/owner/repo",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the star operation")
error: str = SchemaField(description="Error message if starring failed")
def __init__(self):
super().__init__(
id="bd700764-53e3-44dd-a969-d1854088458f",
description="This block stars a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubStarRepositoryBlock.Input,
output_schema=GithubStarRepositoryBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Repository starred successfully")],
test_mock={
"star_repo": lambda *args, **kwargs: "Repository starred successfully"
},
)
@staticmethod
async def star_repo(credentials: GithubCredentials, repo_url: str) -> str:
api = get_api(credentials, convert_urls=False)
repo_path = github_repo_path(repo_url)
owner, repo = repo_path.split("/")
await api.put(
f"https://api.github.com/user/starred/{owner}/{repo}",
headers={"Content-Length": "0"},
)
return "Repository starred successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = await self.star_repo(credentials, input_data.repo_url)
yield "status", status
except Exception as e:
yield "error", str(e)

View File

@@ -1,452 +0,0 @@
from urllib.parse import quote
from typing_extensions import TypedDict
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
from ._utils import github_repo_path
class GithubListBranchesBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
per_page: int = SchemaField(
description="Number of branches to return per page (max 100)",
default=30,
ge=1,
le=100,
)
page: int = SchemaField(
description="Page number for pagination",
default=1,
ge=1,
)
class Output(BlockSchemaOutput):
class BranchItem(TypedDict):
name: str
url: str
branch: BranchItem = SchemaField(
title="Branch",
description="Branches with their name and file tree browser URL",
)
branches: list[BranchItem] = SchemaField(
description="List of branches with their name and file tree browser URL"
)
error: str = SchemaField(description="Error message if listing branches failed")
def __init__(self):
super().__init__(
id="74243e49-2bec-4916-8bf4-db43d44aead5",
description="This block lists all branches for a specified GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListBranchesBlock.Input,
output_schema=GithubListBranchesBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"per_page": 30,
"page": 1,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"branches",
[
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
],
),
(
"branch",
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
},
),
],
test_mock={
"list_branches": lambda *args, **kwargs: [
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
]
},
)
@staticmethod
async def list_branches(
credentials: GithubCredentials, repo_url: str, per_page: int, page: int
) -> list[Output.BranchItem]:
api = get_api(credentials)
branches_url = repo_url + "/branches"
response = await api.get(
branches_url, params={"per_page": str(per_page), "page": str(page)}
)
data = response.json()
repo_path = github_repo_path(repo_url)
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
{
"name": branch["name"],
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
}
for branch in data
]
return branches
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
branches = await self.list_branches(
credentials,
input_data.repo_url,
input_data.per_page,
input_data.page,
)
yield "branches", branches
for branch in branches:
yield "branch", branch
except Exception as e:
yield "error", str(e)
class GithubMakeBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
new_branch: str = SchemaField(
description="Name of the new branch",
placeholder="new_branch_name",
)
source_branch: str = SchemaField(
description="Name of the source branch",
placeholder="source_branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch creation operation")
error: str = SchemaField(
description="Error message if the branch creation failed"
)
def __init__(self):
super().__init__(
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
description="This block creates a new branch from a specified source branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMakeBranchBlock.Input,
output_schema=GithubMakeBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"new_branch": "new_branch_name",
"source_branch": "source_branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch created successfully")],
test_mock={
"create_branch": lambda *args, **kwargs: "Branch created successfully"
},
)
@staticmethod
async def create_branch(
credentials: GithubCredentials,
repo_url: str,
new_branch: str,
source_branch: str,
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{quote(source_branch, safe='')}"
response = await api.get(ref_url)
data = response.json()
sha = data["object"]["sha"]
# Create the new branch
new_ref_url = repo_url + "/git/refs"
data = {
"ref": f"refs/heads/{new_branch}",
"sha": sha,
}
response = await api.post(new_ref_url, json=data)
return "Branch created successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = await self.create_branch(
credentials,
input_data.repo_url,
input_data.new_branch,
input_data.source_branch,
)
yield "status", status
except Exception as e:
yield "error", str(e)
class GithubDeleteBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Name of the branch to delete",
placeholder="branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch deletion operation")
error: str = SchemaField(
description="Error message if the branch deletion failed"
)
def __init__(self):
super().__init__(
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
description="This block deletes a specified branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubDeleteBranchBlock.Input,
output_schema=GithubDeleteBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch deleted successfully")],
test_mock={
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
},
is_sensitive_action=True,
)
@staticmethod
async def delete_branch(
credentials: GithubCredentials, repo_url: str, branch: str
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{quote(branch, safe='')}"
await api.delete(ref_url)
return "Branch deleted successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = await self.delete_branch(
credentials,
input_data.repo_url,
input_data.branch,
)
yield "status", status
except Exception as e:
yield "error", str(e)
class GithubCompareBranchesBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
base: str = SchemaField(
description="Base branch or commit SHA",
placeholder="main",
)
head: str = SchemaField(
description="Head branch or commit SHA to compare against base",
placeholder="feature-branch",
)
class Output(BlockSchemaOutput):
class FileChange(TypedDict):
filename: str
status: str
additions: int
deletions: int
patch: str
status: str = SchemaField(
description="Comparison status: ahead, behind, diverged, or identical"
)
ahead_by: int = SchemaField(
description="Number of commits head is ahead of base"
)
behind_by: int = SchemaField(
description="Number of commits head is behind base"
)
total_commits: int = SchemaField(
description="Total number of commits in the comparison"
)
diff: str = SchemaField(description="Unified diff of all file changes")
file: FileChange = SchemaField(
title="Changed File", description="A changed file with its diff"
)
files: list[FileChange] = SchemaField(
description="List of changed files with their diffs"
)
error: str = SchemaField(description="Error message if comparison failed")
def __init__(self):
super().__init__(
id="2e4faa8c-6086-4546-ba77-172d1d560186",
description="This block compares two branches or commits in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCompareBranchesBlock.Input,
output_schema=GithubCompareBranchesBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"base": "main",
"head": "feature",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("status", "ahead"),
("ahead_by", 2),
("behind_by", 0),
("total_commits", 2),
("diff", "+++ b/file.py\n+new line"),
(
"files",
[
{
"filename": "file.py",
"status": "modified",
"additions": 1,
"deletions": 0,
"patch": "+new line",
}
],
),
(
"file",
{
"filename": "file.py",
"status": "modified",
"additions": 1,
"deletions": 0,
"patch": "+new line",
},
),
],
test_mock={
"compare_branches": lambda *args, **kwargs: {
"status": "ahead",
"ahead_by": 2,
"behind_by": 0,
"total_commits": 2,
"files": [
{
"filename": "file.py",
"status": "modified",
"additions": 1,
"deletions": 0,
"patch": "+new line",
}
],
}
},
)
@staticmethod
async def compare_branches(
credentials: GithubCredentials,
repo_url: str,
base: str,
head: str,
) -> dict:
api = get_api(credentials)
safe_base = quote(base, safe="")
safe_head = quote(head, safe="")
compare_url = repo_url + f"/compare/{safe_base}...{safe_head}"
response = await api.get(compare_url)
return response.json()
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
data = await self.compare_branches(
credentials,
input_data.repo_url,
input_data.base,
input_data.head,
)
yield "status", data["status"]
yield "ahead_by", data["ahead_by"]
yield "behind_by", data["behind_by"]
yield "total_commits", data["total_commits"]
files: list[GithubCompareBranchesBlock.Output.FileChange] = [
GithubCompareBranchesBlock.Output.FileChange(
filename=f["filename"],
status=f["status"],
additions=f["additions"],
deletions=f["deletions"],
patch=f.get("patch", ""),
)
for f in data.get("files", [])
]
# Build unified diff
diff_parts = []
for f in data.get("files", []):
patch = f.get("patch", "")
if patch:
diff_parts.append(f"+++ b/{f['filename']}\n{patch}")
yield "diff", "\n".join(diff_parts)
yield "files", files
for file in files:
yield "file", file
except Exception as e:
yield "error", str(e)

View File

@@ -1,720 +0,0 @@
import base64
from urllib.parse import quote
from typing_extensions import TypedDict
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
class GithubReadFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file in the repository",
placeholder="path/to/file",
)
branch: str = SchemaField(
description="Branch to read from",
placeholder="branch_name",
default="main",
)
class Output(BlockSchemaOutput):
text_content: str = SchemaField(
description="Content of the file (decoded as UTF-8 text)"
)
raw_content: str = SchemaField(
description="Raw base64-encoded content of the file"
)
size: int = SchemaField(description="The size of the file (in bytes)")
error: str = SchemaField(description="Error message if reading the file failed")
def __init__(self):
super().__init__(
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
description="This block reads the content of a specified file from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFileBlock.Input,
output_schema=GithubReadFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "path/to/file",
"branch": "main",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("raw_content", "RmlsZSBjb250ZW50"),
("text_content", "File content"),
("size", 13),
],
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
)
@staticmethod
async def read_file(
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
) -> tuple[str, int]:
api = get_api(credentials)
content_url = (
repo_url
+ f"/contents/{quote(file_path, safe='')}?ref={quote(branch, safe='')}"
)
response = await api.get(content_url)
data = response.json()
if isinstance(data, list):
# Multiple entries of different types exist at this path
if not (file := next((f for f in data if f["type"] == "file"), None)):
raise TypeError("Not a file")
data = file
if data["type"] != "file":
raise TypeError("Not a file")
return data["content"], data["size"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
content, size = await self.read_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.branch,
)
yield "raw_content", content
yield "text_content", base64.b64decode(content).decode("utf-8")
yield "size", size
except Exception as e:
yield "error", str(e)
class GithubReadFolderBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
folder_path: str = SchemaField(
description="Path to the folder in the repository",
placeholder="path/to/folder",
)
branch: str = SchemaField(
description="Branch name to read from (defaults to main)",
placeholder="branch_name",
default="main",
)
class Output(BlockSchemaOutput):
class DirEntry(TypedDict):
name: str
path: str
class FileEntry(TypedDict):
name: str
path: str
size: int
file: FileEntry = SchemaField(description="Files in the folder")
dir: DirEntry = SchemaField(description="Directories in the folder")
error: str = SchemaField(
description="Error message if reading the folder failed"
)
def __init__(self):
super().__init__(
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
description="This block reads the content of a specified folder from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFolderBlock.Input,
output_schema=GithubReadFolderBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"folder_path": "path/to/folder",
"branch": "main",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"file",
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
},
),
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
],
test_mock={
"read_folder": lambda *args, **kwargs: (
[
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
}
],
[{"name": "dir2", "path": "path/to/folder/dir2"}],
)
},
)
@staticmethod
async def read_folder(
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
api = get_api(credentials)
contents_url = (
repo_url
+ f"/contents/{quote(folder_path, safe='/')}?ref={quote(branch, safe='')}"
)
response = await api.get(contents_url)
data = response.json()
if not isinstance(data, list):
raise TypeError("Not a folder")
files: list[GithubReadFolderBlock.Output.FileEntry] = [
GithubReadFolderBlock.Output.FileEntry(
name=entry["name"],
path=entry["path"],
size=entry["size"],
)
for entry in data
if entry["type"] == "file"
]
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
GithubReadFolderBlock.Output.DirEntry(
name=entry["name"],
path=entry["path"],
)
for entry in data
if entry["type"] == "dir"
]
return files, dirs
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
files, dirs = await self.read_folder(
credentials,
input_data.repo_url,
input_data.folder_path.lstrip("/"),
input_data.branch,
)
for file in files:
yield "file", file
for dir in dirs:
yield "dir", dir
except Exception as e:
yield "error", str(e)
class GithubCreateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path where the file should be created",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="Content to write to the file",
placeholder="File content here",
)
branch: str = SchemaField(
description="Branch where the file should be created",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Create new file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the created file")
sha: str = SchemaField(description="SHA of the commit")
error: str = SchemaField(
description="Error message if the file creation failed"
)
def __init__(self):
super().__init__(
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
description="This block creates a new file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCreateFileBlock.Input,
output_schema=GithubCreateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Test content",
"branch": "main",
"commit_message": "Create test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "abc123"),
],
test_mock={
"create_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"abc123",
)
},
)
@staticmethod
async def create_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.create_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubUpdateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file to update",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="New content for the file",
placeholder="Updated content here",
)
branch: str = SchemaField(
description="Branch containing the file",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Update file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the updated file")
sha: str = SchemaField(description="SHA of the commit")
def __init__(self):
super().__init__(
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
description="This block updates an existing file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubUpdateFileBlock.Input,
output_schema=GithubUpdateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Updated content",
"branch": "main",
"commit_message": "Update test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "def456"),
],
test_mock={
"update_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"def456",
)
},
)
@staticmethod
async def update_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
params = {"ref": branch}
response = await api.get(contents_url, params=params)
data = response.json()
# Convert new content to base64
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"sha": data["sha"],
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.update_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubSearchCodeBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
query: str = SchemaField(
description="Search query (GitHub code search syntax)",
placeholder="className language:python",
)
repo: str = SchemaField(
description="Restrict search to a repository (owner/repo format, optional)",
default="",
placeholder="owner/repo",
)
per_page: int = SchemaField(
description="Number of results to return (max 100)",
default=30,
ge=1,
le=100,
)
class Output(BlockSchemaOutput):
class SearchResult(TypedDict):
name: str
path: str
repository: str
url: str
score: float
result: SearchResult = SchemaField(
title="Result", description="A code search result"
)
results: list[SearchResult] = SchemaField(
description="List of code search results"
)
total_count: int = SchemaField(description="Total number of matching results")
error: str = SchemaField(description="Error message if search failed")
def __init__(self):
super().__init__(
id="47f94891-a2b1-4f1c-b5f2-573c043f721e",
description="This block searches for code in GitHub repositories.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubSearchCodeBlock.Input,
output_schema=GithubSearchCodeBlock.Output,
test_input={
"query": "addClass",
"repo": "owner/repo",
"per_page": 30,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("total_count", 1),
(
"results",
[
{
"name": "file.py",
"path": "src/file.py",
"repository": "owner/repo",
"url": "https://github.com/owner/repo/blob/main/src/file.py",
"score": 1.0,
}
],
),
(
"result",
{
"name": "file.py",
"path": "src/file.py",
"repository": "owner/repo",
"url": "https://github.com/owner/repo/blob/main/src/file.py",
"score": 1.0,
},
),
],
test_mock={
"search_code": lambda *args, **kwargs: (
1,
[
{
"name": "file.py",
"path": "src/file.py",
"repository": "owner/repo",
"url": "https://github.com/owner/repo/blob/main/src/file.py",
"score": 1.0,
}
],
)
},
)
@staticmethod
async def search_code(
credentials: GithubCredentials,
query: str,
repo: str,
per_page: int,
) -> tuple[int, list[Output.SearchResult]]:
api = get_api(credentials, convert_urls=False)
full_query = f"{query} repo:{repo}" if repo else query
params = {"q": full_query, "per_page": str(per_page)}
response = await api.get("https://api.github.com/search/code", params=params)
data = response.json()
results: list[GithubSearchCodeBlock.Output.SearchResult] = [
GithubSearchCodeBlock.Output.SearchResult(
name=item["name"],
path=item["path"],
repository=item["repository"]["full_name"],
url=item["html_url"],
score=item["score"],
)
for item in data["items"]
]
return data["total_count"], results
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
total_count, results = await self.search_code(
credentials,
input_data.query,
input_data.repo,
input_data.per_page,
)
yield "total_count", total_count
yield "results", results
for result in results:
yield "result", result
except Exception as e:
yield "error", str(e)
class GithubGetRepositoryTreeBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Branch name to get the tree from",
default="main",
)
recursive: bool = SchemaField(
description="Whether to recursively list the entire tree",
default=True,
)
class Output(BlockSchemaOutput):
class TreeEntry(TypedDict):
path: str
type: str
size: int
sha: str
entry: TreeEntry = SchemaField(
title="Tree Entry", description="A file or directory in the tree"
)
entries: list[TreeEntry] = SchemaField(
description="List of all files and directories in the tree"
)
truncated: bool = SchemaField(
description="Whether the tree was truncated due to size"
)
error: str = SchemaField(description="Error message if getting tree failed")
def __init__(self):
super().__init__(
id="89c5c0ec-172e-4001-a32c-bdfe4d0c9e81",
description="This block lists the entire file tree of a GitHub repository recursively.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubGetRepositoryTreeBlock.Input,
output_schema=GithubGetRepositoryTreeBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "main",
"recursive": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("truncated", False),
(
"entries",
[
{
"path": "src/main.py",
"type": "blob",
"size": 1234,
"sha": "abc123",
}
],
),
(
"entry",
{
"path": "src/main.py",
"type": "blob",
"size": 1234,
"sha": "abc123",
},
),
],
test_mock={
"get_tree": lambda *args, **kwargs: (
False,
[
{
"path": "src/main.py",
"type": "blob",
"size": 1234,
"sha": "abc123",
}
],
)
},
)
@staticmethod
async def get_tree(
credentials: GithubCredentials,
repo_url: str,
branch: str,
recursive: bool,
) -> tuple[bool, list[Output.TreeEntry]]:
api = get_api(credentials)
tree_url = repo_url + f"/git/trees/{quote(branch, safe='')}"
params = {"recursive": "1"} if recursive else {}
response = await api.get(tree_url, params=params)
data = response.json()
entries: list[GithubGetRepositoryTreeBlock.Output.TreeEntry] = [
GithubGetRepositoryTreeBlock.Output.TreeEntry(
path=item["path"],
type=item["type"],
size=item.get("size", 0),
sha=item["sha"],
)
for item in data["tree"]
]
return data.get("truncated", False), entries
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
truncated, entries = await self.get_tree(
credentials,
input_data.repo_url,
input_data.branch,
input_data.recursive,
)
yield "truncated", truncated
yield "entries", entries
for entry in entries:
yield "entry", entry
except Exception as e:
yield "error", str(e)

View File

@@ -1,120 +0,0 @@
import inspect
import pytest
from backend.blocks.github._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
from backend.blocks.github.commits import FileOperation, GithubMultiFileCommitBlock
from backend.blocks.github.pull_requests import (
GithubMergePullRequestBlock,
prepare_pr_api_url,
)
from backend.util.exceptions import BlockExecutionError
# ── prepare_pr_api_url tests ──
class TestPreparePrApiUrl:
def test_https_scheme_preserved(self):
result = prepare_pr_api_url("https://github.com/owner/repo/pull/42", "merge")
assert result == "https://github.com/owner/repo/pulls/42/merge"
def test_http_scheme_preserved(self):
result = prepare_pr_api_url("http://github.com/owner/repo/pull/1", "files")
assert result == "http://github.com/owner/repo/pulls/1/files"
def test_no_scheme_defaults_to_https(self):
result = prepare_pr_api_url("github.com/owner/repo/pull/5", "merge")
assert result == "https://github.com/owner/repo/pulls/5/merge"
def test_reviewers_path(self):
result = prepare_pr_api_url(
"https://github.com/owner/repo/pull/99", "requested_reviewers"
)
assert result == "https://github.com/owner/repo/pulls/99/requested_reviewers"
def test_invalid_url_returned_as_is(self):
url = "https://example.com/not-a-pr"
assert prepare_pr_api_url(url, "merge") == url
def test_empty_string(self):
assert prepare_pr_api_url("", "merge") == ""
# ── Error-path block tests ──
# When a block's run() yields ("error", msg), _execute() converts it to a
# BlockExecutionError. We call block.execute() directly (not execute_block_test,
# which returns early on empty test_output).
def _mock_block(block, mocks: dict):
"""Apply mocks to a block's static methods, wrapping sync mocks as async."""
for name, mock_fn in mocks.items():
original = getattr(block, name)
if inspect.iscoroutinefunction(original):
async def async_mock(*args, _fn=mock_fn, **kwargs):
return _fn(*args, **kwargs)
setattr(block, name, async_mock)
else:
setattr(block, name, mock_fn)
def _raise(exc: Exception):
"""Helper that returns a callable which raises the given exception."""
def _raiser(*args, **kwargs):
raise exc
return _raiser
@pytest.mark.asyncio
async def test_merge_pr_error_path():
block = GithubMergePullRequestBlock()
_mock_block(block, {"merge_pr": _raise(RuntimeError("PR not mergeable"))})
input_data = {
"pr_url": "https://github.com/owner/repo/pull/1",
"merge_method": "squash",
"commit_title": "",
"commit_message": "",
"credentials": TEST_CREDENTIALS_INPUT,
}
with pytest.raises(BlockExecutionError, match="PR not mergeable"):
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_multi_file_commit_error_path():
block = GithubMultiFileCommitBlock()
_mock_block(block, {"multi_file_commit": _raise(RuntimeError("ref update failed"))})
input_data = {
"repo_url": "https://github.com/owner/repo",
"branch": "feature",
"commit_message": "test",
"files": [{"path": "a.py", "content": "x", "operation": "upsert"}],
"credentials": TEST_CREDENTIALS_INPUT,
}
with pytest.raises(BlockExecutionError, match="ref update failed"):
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
pass
# ── FileOperation enum tests ──
class TestFileOperation:
def test_upsert_value(self):
assert FileOperation.UPSERT == "upsert"
def test_delete_value(self):
assert FileOperation.DELETE == "delete"
def test_invalid_value_raises(self):
with pytest.raises(ValueError):
FileOperation("create")
def test_invalid_value_raises_typo(self):
with pytest.raises(ValueError):
FileOperation("upser")

View File

@@ -241,8 +241,8 @@ class GmailBase(Block, ABC):
h.ignore_links = False
h.ignore_images = True
return h.handle(html_content)
except Exception:
# Keep extraction resilient if html2text is unavailable or fails.
except ImportError:
# Fallback: return raw HTML if html2text is not available
return html_content
# Handle content stored as attachment

View File

@@ -67,7 +67,6 @@ class HITLReviewHelper:
graph_version: int,
block_name: str = "Block",
editable: bool = False,
is_graph_execution: bool = True,
) -> Optional[ReviewResult]:
"""
Handle a review request for a block that requires human review.
@@ -144,11 +143,10 @@ class HITLReviewHelper:
logger.info(
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
)
if is_graph_execution:
await HITLReviewHelper.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
await HITLReviewHelper.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
return None # Signal that execution should pause
# Mark review as processed if not already done
@@ -170,7 +168,6 @@ class HITLReviewHelper:
graph_version: int,
block_name: str = "Block",
editable: bool = False,
is_graph_execution: bool = True,
) -> Optional[ReviewDecision]:
"""
Handle a review request and return the decision in a single call.
@@ -200,7 +197,6 @@ class HITLReviewHelper:
graph_version=graph_version,
block_name=block_name,
editable=editable,
is_graph_execution=is_graph_execution,
)
if review_result is None:

View File

@@ -140,31 +140,19 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
# OpenRouter models
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
GEMINI_2_5_PRO_PREVIEW = "google/gemini-2.5-pro-preview-03-25"
GEMINI_2_5_PRO = "google/gemini-2.5-pro"
GEMINI_3_1_PRO_PREVIEW = "google/gemini-3.1-pro-preview"
GEMINI_3_FLASH_PREVIEW = "google/gemini-3-flash-preview"
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
GEMINI_3_1_FLASH_LITE_PREVIEW = "google/gemini-3.1-flash-lite-preview"
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
MISTRAL_NEMO = "mistralai/mistral-nemo"
MISTRAL_LARGE_3 = "mistralai/mistral-large-2512"
MISTRAL_MEDIUM_3_1 = "mistralai/mistral-medium-3.1"
MISTRAL_SMALL_3_2 = "mistralai/mistral-small-3.2-24b-instruct"
CODESTRAL = "mistralai/codestral-2508"
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
COHERE_COMMAND_A_03_2025 = "cohere/command-a-03-2025"
COHERE_COMMAND_A_TRANSLATE_08_2025 = "cohere/command-a-translate-08-2025"
COHERE_COMMAND_A_REASONING_08_2025 = "cohere/command-a-reasoning-08-2025"
COHERE_COMMAND_A_VISION_07_2025 = "cohere/command-a-vision-07-2025"
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
PERPLEXITY_SONAR = "perplexity/sonar"
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
PERPLEXITY_SONAR_REASONING_PRO = "perplexity/sonar-reasoning-pro"
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
@@ -172,11 +160,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
MICROSOFT_PHI_4 = "microsoft/phi-4"
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
GROK_3 = "x-ai/grok-3"
GROK_4 = "x-ai/grok-4"
GROK_4_FAST = "x-ai/grok-4-fast"
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
@@ -354,41 +340,17 @@ MODEL_METADATA = {
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
),
# https://openrouter.ai/models
LlmModel.GEMINI_2_5_PRO_PREVIEW: ModelMetadata(
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
"open_router",
1048576,
65536,
1050000,
8192,
"Gemini 2.5 Pro Preview 03.25",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 2.5 Pro",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_3_1_PRO_PREVIEW: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 3.1 Pro Preview",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_3_FLASH_PREVIEW: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 3 Flash Preview",
"OpenRouter",
"Google",
1,
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
),
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
@@ -396,15 +358,6 @@ MODEL_METADATA = {
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
),
LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 3.1 Flash Lite Preview",
"OpenRouter",
"Google",
1,
),
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
"open_router",
1048576,
@@ -426,78 +379,12 @@ MODEL_METADATA = {
LlmModel.MISTRAL_NEMO: ModelMetadata(
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
),
LlmModel.MISTRAL_LARGE_3: ModelMetadata(
"open_router",
262144,
None,
"Mistral Large 3 2512",
"OpenRouter",
"Mistral AI",
2,
),
LlmModel.MISTRAL_MEDIUM_3_1: ModelMetadata(
"open_router",
131072,
None,
"Mistral Medium 3.1",
"OpenRouter",
"Mistral AI",
2,
),
LlmModel.MISTRAL_SMALL_3_2: ModelMetadata(
"open_router",
131072,
131072,
"Mistral Small 3.2 24B",
"OpenRouter",
"Mistral AI",
1,
),
LlmModel.CODESTRAL: ModelMetadata(
"open_router",
256000,
None,
"Codestral 2508",
"OpenRouter",
"Mistral AI",
1,
),
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
),
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
),
LlmModel.COHERE_COMMAND_A_03_2025: ModelMetadata(
"open_router", 256000, 8192, "Command A 03.2025", "OpenRouter", "Cohere", 2
),
LlmModel.COHERE_COMMAND_A_TRANSLATE_08_2025: ModelMetadata(
"open_router",
128000,
8192,
"Command A Translate 08.2025",
"OpenRouter",
"Cohere",
2,
),
LlmModel.COHERE_COMMAND_A_REASONING_08_2025: ModelMetadata(
"open_router",
256000,
32768,
"Command A Reasoning 08.2025",
"OpenRouter",
"Cohere",
3,
),
LlmModel.COHERE_COMMAND_A_VISION_07_2025: ModelMetadata(
"open_router",
128000,
8192,
"Command A Vision 07.2025",
"OpenRouter",
"Cohere",
2,
),
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
),
@@ -510,15 +397,6 @@ MODEL_METADATA = {
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
),
LlmModel.PERPLEXITY_SONAR_REASONING_PRO: ModelMetadata(
"open_router",
128000,
8000,
"Sonar Reasoning Pro",
"OpenRouter",
"Perplexity",
2,
),
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
"open_router",
128000,
@@ -564,9 +442,6 @@ MODEL_METADATA = {
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
),
LlmModel.MICROSOFT_PHI_4: ModelMetadata(
"open_router", 16384, 16384, "Phi-4", "OpenRouter", "Microsoft", 1
),
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
),
@@ -576,15 +451,6 @@ MODEL_METADATA = {
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
),
LlmModel.GROK_3: ModelMetadata(
"open_router",
131072,
131072,
"Grok 3",
"OpenRouter",
"xAI",
2,
),
LlmModel.GROK_4: ModelMetadata(
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
),

View File

@@ -4,7 +4,7 @@ from enum import Enum
from typing import Any, Literal
import openai
from pydantic import SecretStr, field_validator
from pydantic import SecretStr
from backend.blocks._base import (
Block,
@@ -13,7 +13,6 @@ from backend.blocks._base import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.block import BlockInput
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -36,20 +35,6 @@ class PerplexityModel(str, Enum):
SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
def _sanitize_perplexity_model(value: Any) -> PerplexityModel:
"""Return a valid PerplexityModel, falling back to SONAR for invalid values."""
if isinstance(value, PerplexityModel):
return value
try:
return PerplexityModel(value)
except ValueError:
logger.warning(
f"Invalid PerplexityModel '{value}', "
f"falling back to {PerplexityModel.SONAR.value}"
)
return PerplexityModel.SONAR
PerplexityCredentials = CredentialsMetaInput[
Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
]
@@ -88,25 +73,6 @@ class PerplexityBlock(Block):
advanced=False,
)
credentials: PerplexityCredentials = PerplexityCredentialsField()
@field_validator("model", mode="before")
@classmethod
def fallback_invalid_model(cls, v: Any) -> PerplexityModel:
"""Fall back to SONAR if the model value is not a valid
PerplexityModel (e.g. an OpenAI model ID set by the agent
generator)."""
return _sanitize_perplexity_model(v)
@classmethod
def validate_data(cls, data: BlockInput) -> str | None:
"""Sanitize the model field before JSON schema validation so that
invalid values are replaced with the default instead of raising a
BlockInputError."""
model_value = data.get("model")
if model_value is not None:
data["model"] = _sanitize_perplexity_model(model_value).value
return super().validate_data(data)
system_prompt: str = SchemaField(
title="System Prompt",
default="",

View File

@@ -2232,7 +2232,6 @@ class DeleteRedditPostBlock(Block):
("post_id", "abc123"),
],
test_mock={"delete_post": lambda creds, post_id: True},
is_sensitive_action=True,
)
@staticmethod
@@ -2291,7 +2290,6 @@ class DeleteRedditCommentBlock(Block):
("comment_id", "xyz789"),
],
test_mock={"delete_comment": lambda creds, comment_id: True},
is_sensitive_action=True,
)
@staticmethod

View File

@@ -72,7 +72,6 @@ class Slant3DCreateOrderBlock(Slant3DBlockBase):
"_make_request": lambda *args, **kwargs: {"orderId": "314144241"},
"_convert_to_color": lambda *args, **kwargs: "black",
},
is_sensitive_action=True,
)
async def run(

View File

@@ -1,81 +0,0 @@
"""Unit tests for PerplexityBlock model fallback behavior."""
import pytest
from backend.blocks.perplexity import (
TEST_CREDENTIALS_INPUT,
PerplexityBlock,
PerplexityModel,
)
def _make_input(**overrides) -> dict:
defaults = {
"prompt": "test query",
"credentials": TEST_CREDENTIALS_INPUT,
}
defaults.update(overrides)
return defaults
class TestPerplexityModelFallback:
"""Tests for fallback_invalid_model field_validator."""
def test_invalid_model_falls_back_to_sonar(self):
inp = PerplexityBlock.Input(**_make_input(model="gpt-5.2-2025-12-11"))
assert inp.model == PerplexityModel.SONAR
def test_another_invalid_model_falls_back_to_sonar(self):
inp = PerplexityBlock.Input(**_make_input(model="gpt-4o"))
assert inp.model == PerplexityModel.SONAR
def test_valid_model_string_is_kept(self):
inp = PerplexityBlock.Input(**_make_input(model="perplexity/sonar-pro"))
assert inp.model == PerplexityModel.SONAR_PRO
def test_valid_enum_value_is_kept(self):
inp = PerplexityBlock.Input(
**_make_input(model=PerplexityModel.SONAR_DEEP_RESEARCH)
)
assert inp.model == PerplexityModel.SONAR_DEEP_RESEARCH
def test_default_model_when_omitted(self):
inp = PerplexityBlock.Input(**_make_input())
assert inp.model == PerplexityModel.SONAR
@pytest.mark.parametrize(
"model_value",
[
"perplexity/sonar",
"perplexity/sonar-pro",
"perplexity/sonar-deep-research",
],
)
def test_all_valid_models_accepted(self, model_value: str):
inp = PerplexityBlock.Input(**_make_input(model=model_value))
assert inp.model.value == model_value
class TestPerplexityValidateData:
"""Tests for validate_data which runs during block execution (before
Pydantic instantiation). Invalid models must be sanitized here so
JSON schema validation does not reject them."""
def test_invalid_model_sanitized_before_schema_validation(self):
data = _make_input(model="gpt-5.2-2025-12-11")
error = PerplexityBlock.Input.validate_data(data)
assert error is None
assert data["model"] == PerplexityModel.SONAR.value
def test_valid_model_unchanged_by_validate_data(self):
data = _make_input(model="perplexity/sonar-pro")
error = PerplexityBlock.Input.validate_data(data)
assert error is None
assert data["model"] == "perplexity/sonar-pro"
def test_missing_model_uses_default(self):
data = _make_input() # no model key
error = PerplexityBlock.Input.validate_data(data)
assert error is None
inp = PerplexityBlock.Input(**data)
assert inp.model == PerplexityModel.SONAR

View File

@@ -18,13 +18,11 @@ from langfuse import propagate_attributes
from backend.copilot.model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.rate_limit import record_token_usage
from backend.copilot.response_model import (
StreamBaseResponse,
StreamError,
@@ -38,7 +36,6 @@ from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
from backend.copilot.service import (
_build_system_prompt,
@@ -49,11 +46,7 @@ from backend.copilot.service import (
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
from backend.util.prompt import (
compress_context,
estimate_token_count,
estimate_token_count_str,
)
from backend.util.prompt import compress_context
logger = logging.getLogger(__name__)
@@ -228,9 +221,6 @@ async def stream_chat_completion_baseline(
text_block_id = str(uuid.uuid4())
text_started = False
step_open = False
# Token usage accumulators — populated from streaming chunks
turn_prompt_tokens = 0
turn_completion_tokens = 0
try:
for _round in range(_MAX_TOOL_ROUNDS):
# Open a new step for each LLM round
@@ -242,7 +232,6 @@ async def stream_chat_completion_baseline(
model=config.model,
messages=openai_messages,
stream=True,
stream_options={"include_usage": True},
)
if tools:
create_kwargs["tools"] = tools
@@ -253,18 +242,7 @@ async def stream_chat_completion_baseline(
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
# Capture token usage from the streaming chunk.
# OpenRouter normalises all providers into OpenAI format
# where prompt_tokens already includes cached tokens
# (unlike Anthropic's native API). Use += to sum all
# tool-call rounds since each API call is independent.
if chunk.usage:
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
turn_completion_tokens += chunk.usage.completion_tokens or 0
if not chunk.choices:
continue
delta = chunk.choices[0].delta
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
@@ -433,53 +411,6 @@ async def stream_chat_completion_baseline(
except Exception:
logger.warning("[Baseline] Langfuse trace context teardown failed")
# Fallback: estimate tokens via tiktoken when the provider does
# not honour stream_options={"include_usage": True}.
# Count the full message list (system + history + turn) since
# each API call sends the complete context window.
if turn_prompt_tokens == 0 and turn_completion_tokens == 0:
turn_prompt_tokens = max(
estimate_token_count(openai_messages, model=config.model), 0
)
turn_completion_tokens = max(
estimate_token_count_str(assistant_text, model=config.model), 0
)
logger.info(
"[Baseline] No streaming usage reported; estimated tokens: "
"prompt=%d, completion=%d",
turn_prompt_tokens,
turn_completion_tokens,
)
# Emit token usage and update session for persistence
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
total_tokens = turn_prompt_tokens + turn_completion_tokens
session.usage.append(
Usage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=total_tokens,
)
)
logger.info(
"[Baseline] Turn usage: prompt=%d, completion=%d, total=%d",
turn_prompt_tokens,
turn_completion_tokens,
total_tokens,
)
# Record for rate limiting counters
if user_id:
try:
await record_token_usage(
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
)
except Exception as usage_err:
logger.warning(
"[Baseline] Failed to record token usage: %s", usage_err
)
# Persist assistant response
if assistant_text:
session.messages.append(
@@ -490,16 +421,4 @@ async def stream_chat_completion_baseline(
except Exception as persist_err:
logger.error("[Baseline] Failed to persist session: %s", persist_err)
# Yield usage and finish AFTER try/finally (not inside finally).
# PEP 525 prohibits yielding from finally in async generators during
# aclose() — doing so raises RuntimeError on client disconnect.
# On GeneratorExit the client is already gone, so unreachable yields
# are harmless; on normal completion they reach the SSE stream.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
yield StreamUsage(
promptTokens=turn_prompt_tokens,
completionTokens=turn_completion_tokens,
totalTokens=turn_prompt_tokens + turn_completion_tokens,
)
yield StreamFinish()

View File

@@ -70,20 +70,6 @@ class ChatConfig(BaseSettings):
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
)
# Rate limiting — token-based limits per day and per week.
# Each CoPilot turn consumes ~10-15K tokens (system prompt + tool schemas + response),
# so 2.5M daily allows ~170-250 turns/day which is reasonable for normal use.
# TODO: These are global deploy-time constants. For per-user or per-plan limits,
# move to the database (e.g. UserPlan table) and look up in get_usage_status.
daily_token_limit: int = Field(
default=2_500_000,
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
)
weekly_token_limit: int = Field(
default=12_500_000,
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,
@@ -129,7 +115,7 @@ class ChatConfig(BaseSettings):
description="E2B sandbox template to use for copilot sessions.",
)
e2b_sandbox_timeout: int = Field(
default=300, # 5 min safety net explicit per-turn pause is the primary mechanism
default=10800, # 3 hours — wall-clock timeout, not idle; explicit pause is primary
description="E2B sandbox running-time timeout (seconds). "
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
"mechanism; this is the safety net.",

View File

@@ -6,32 +6,6 @@
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
# Prefix for all synthetic IDs generated by CoPilot block execution.
# Used to distinguish CoPilot-generated records from real graph execution records
# in PendingHumanReview and other tables.
COPILOT_SYNTHETIC_ID_PREFIX = "copilot-"
# Sub-prefixes for session-scoped and node-scoped synthetic IDs.
COPILOT_SESSION_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}session-"
COPILOT_NODE_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}node-"
# Separator used in synthetic node_exec_id to encode node_id.
# Format: "{node_id}:{random_hex}" — extract node_id via rsplit(":", 1)[0]
COPILOT_NODE_EXEC_ID_SEPARATOR = ":"
# Compaction notice messages shown to users.
COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context limits."
COMPACTION_TOOL_NAME = "context_compaction"
def is_copilot_synthetic_id(id_value: str) -> bool:
"""Check if an ID is a CoPilot synthetic ID (not from a real graph execution)."""
return id_value.startswith(COPILOT_SYNTHETIC_ID_PREFIX)
def parse_node_id_from_exec_id(node_exec_id: str) -> str:
"""Extract node_id from a synthetic node_exec_id.
Format: "{node_id}:{random_hex}" → returns "{node_id}".
"""
return node_exec_id.rsplit(COPILOT_NODE_EXEC_ID_SEPARATOR, 1)[0]

View File

@@ -73,9 +73,6 @@ class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
# Cache breakdown (Anthropic-specific; zero for non-Anthropic models)
cache_read_tokens: int = 0
cache_creation_tokens: int = 0
class ChatSessionInfo(BaseModel):

View File

@@ -52,11 +52,6 @@ Examples:
You can embed a reference inside any string argument, or use it as the entire
value. Multiple references in one argument are all expanded.
**Type coercion**: The platform automatically coerces expanded string values
to match the block's expected input types. For example, if a block expects
`list[list[str]]` and you pass a string containing a JSON array (e.g. from
an @@agptfile: expansion), the string will be parsed into the correct type.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.

View File

@@ -1,253 +0,0 @@
"""CoPilot rate limiting based on token usage.
Uses Redis fixed-window counters to track per-user token consumption
with configurable daily and weekly limits. Daily windows reset at
midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00
UTC). Fails open when Redis is unavailable to avoid blocking users.
"""
import asyncio
import logging
from datetime import UTC, datetime, timedelta
from pydantic import BaseModel, Field
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
# Redis key prefixes
_PREFIX = "copilot:usage"
class UsageWindow(BaseModel):
"""Usage within a single time window."""
used: int
limit: int = Field(
description="Maximum tokens allowed in this window. 0 means unlimited."
)
resets_at: datetime
class CoPilotUsageStatus(BaseModel):
"""Current usage status for a user across all windows."""
daily: UsageWindow
weekly: UsageWindow
class RateLimitExceeded(Exception):
"""Raised when a user exceeds their CoPilot usage limit."""
def __init__(self, window: str, resets_at: datetime):
self.window = window
self.resets_at = resets_at
delta = resets_at - datetime.now(UTC)
total_secs = delta.total_seconds()
if total_secs <= 0:
time_str = "now"
else:
hours = int(total_secs // 3600)
minutes = int((total_secs % 3600) // 60)
time_str = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m"
super().__init__(
f"You've reached your {window} usage limit. Resets in {time_str}."
)
def _daily_key(user_id: str, now: datetime | None = None) -> str:
if now is None:
now = datetime.now(UTC)
return f"{_PREFIX}:daily:{user_id}:{now.strftime('%Y-%m-%d')}"
def _weekly_key(user_id: str, now: datetime | None = None) -> str:
if now is None:
now = datetime.now(UTC)
year, week, _ = now.isocalendar()
return f"{_PREFIX}:weekly:{user_id}:{year}-W{week:02d}"
def _daily_reset_time(now: datetime | None = None) -> datetime:
"""Calculate when the current daily window resets (next midnight UTC)."""
if now is None:
now = datetime.now(UTC)
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
def _weekly_reset_time(now: datetime | None = None) -> datetime:
"""Calculate when the current weekly window resets (next Monday 00:00 UTC).
On Monday itself, ``(7 - weekday) % 7`` is 0; the ``or 7`` fallback
pushes to *next* Monday so the current week's window stays open.
"""
if now is None:
now = datetime.now(UTC)
days_until_monday = (7 - now.weekday()) % 7 or 7
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(
days=days_until_monday
)
async def _fetch_counters(user_id: str, now: datetime) -> tuple[int, int]:
"""Fetch daily and weekly token counters from Redis.
Returns (daily_used, weekly_used). Returns (0, 0) if Redis is unavailable.
"""
redis = await get_redis_async()
daily_raw, weekly_raw = await asyncio.gather(
redis.get(_daily_key(user_id, now=now)),
redis.get(_weekly_key(user_id, now=now)),
)
return int(daily_raw or 0), int(weekly_raw or 0)
async def get_usage_status(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
) -> CoPilotUsageStatus:
"""Get current usage status for a user.
Args:
user_id: The user's ID.
daily_token_limit: Max tokens per day (0 = unlimited).
weekly_token_limit: Max tokens per week (0 = unlimited).
Returns:
CoPilotUsageStatus with current usage and limits.
"""
now = datetime.now(UTC)
try:
daily_used, weekly_used = await _fetch_counters(user_id, now)
except Exception:
logger.warning(
"Redis unavailable for usage status, returning zeros", exc_info=True
)
daily_used, weekly_used = 0, 0
return CoPilotUsageStatus(
daily=UsageWindow(
used=daily_used,
limit=daily_token_limit,
resets_at=_daily_reset_time(now=now),
),
weekly=UsageWindow(
used=weekly_used,
limit=weekly_token_limit,
resets_at=_weekly_reset_time(now=now),
),
)
async def check_rate_limit(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
) -> None:
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
This is a pre-turn soft check. The authoritative usage counter is updated
by ``record_token_usage()`` after the turn completes. Under concurrency,
two parallel turns may both pass this check against the same snapshot.
This is acceptable because token-based limits are approximate by nature
(the exact token count is unknown until after generation).
Fails open: if Redis is unavailable, allows the request.
"""
now = datetime.now(UTC)
try:
daily_used, weekly_used = await _fetch_counters(user_id, now)
except Exception:
logger.warning(
"Redis unavailable for rate limit check, allowing request", exc_info=True
)
return
if daily_token_limit > 0 and daily_used >= daily_token_limit:
raise RateLimitExceeded("daily", _daily_reset_time(now=now))
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
async def record_token_usage(
user_id: str,
prompt_tokens: int,
completion_tokens: int,
*,
cache_read_tokens: int = 0,
cache_creation_tokens: int = 0,
) -> None:
"""Record token usage for a user across all windows.
Uses cost-weighted counting so cached tokens don't unfairly penalise
multi-turn conversations. Anthropic's pricing:
- uncached input: 100%
- cache creation: 25%
- cache read: 10%
- output: 100%
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
from the API response). Cache counts are passed separately.
Args:
user_id: The user's ID.
prompt_tokens: Uncached input tokens.
completion_tokens: Output tokens.
cache_read_tokens: Tokens served from prompt cache (10% cost).
cache_creation_tokens: Tokens written to prompt cache (25% cost).
"""
weighted_input = (
prompt_tokens
+ round(cache_creation_tokens * 0.25)
+ round(cache_read_tokens * 0.1)
)
total = weighted_input + completion_tokens
if total <= 0:
return
raw_total = (
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
)
logger.info(
"Recording token usage for %s: raw=%d, weighted=%d "
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
user_id[:8],
raw_total,
total,
prompt_tokens,
cache_read_tokens,
cache_creation_tokens,
completion_tokens,
)
now = datetime.now(UTC)
try:
redis = await get_redis_async()
pipe = redis.pipeline(transaction=False)
# Daily counter (expires at next midnight UTC)
d_key = _daily_key(user_id, now=now)
pipe.incrby(d_key, total)
seconds_until_daily_reset = int(
(_daily_reset_time(now=now) - now).total_seconds()
)
pipe.expire(d_key, max(seconds_until_daily_reset, 1))
# Weekly counter (expires end of week)
w_key = _weekly_key(user_id, now=now)
pipe.incrby(w_key, total)
seconds_until_weekly_reset = int(
(_weekly_reset_time(now=now) - now).total_seconds()
)
pipe.expire(w_key, max(seconds_until_weekly_reset, 1))
await pipe.execute()
except Exception:
logger.warning(
"Redis unavailable for recording token usage (tokens=%d)",
total,
exc_info=True,
)

View File

@@ -1,334 +0,0 @@
"""Unit tests for CoPilot rate limiting."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from redis.exceptions import RedisError
from .rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
check_rate_limit,
get_usage_status,
record_token_usage,
)
_USER = "test-user-rl"
# ---------------------------------------------------------------------------
# RateLimitExceeded
# ---------------------------------------------------------------------------
class TestRateLimitExceeded:
def test_message_contains_window_name(self):
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1))
assert "daily" in str(exc)
def test_message_contains_reset_time(self):
exc = RateLimitExceeded(
"weekly", datetime.now(UTC) + timedelta(hours=2, minutes=30)
)
msg = str(exc)
# Allow for slight timing drift (29m or 30m)
assert "2h " in msg
assert "Resets in" in msg
def test_message_minutes_only_when_under_one_hour(self):
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(minutes=15))
msg = str(exc)
assert "Resets in" in msg
# Should not have "0h"
assert "0h" not in msg
def test_message_says_now_when_resets_at_is_in_the_past(self):
"""Negative delta (clock skew / stale TTL) should say 'now', not '-1h -30m'."""
exc = RateLimitExceeded("daily", datetime.now(UTC) - timedelta(minutes=5))
assert "Resets in now" in str(exc)
# ---------------------------------------------------------------------------
# get_usage_status
# ---------------------------------------------------------------------------
class TestGetUsageStatus:
@pytest.mark.asyncio
async def test_returns_redis_values(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["500", "2000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert isinstance(status, CoPilotUsageStatus)
assert status.daily.used == 500
assert status.daily.limit == 10000
assert status.weekly.used == 2000
assert status.weekly.limit == 50000
@pytest.mark.asyncio
async def test_returns_zeros_when_redis_unavailable(self):
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 0
assert status.weekly.used == 0
@pytest.mark.asyncio
async def test_partial_none_daily_counter(self):
"""Daily counter is None (new day), weekly has usage."""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=[None, "3000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 0
assert status.weekly.used == 3000
@pytest.mark.asyncio
async def test_partial_none_weekly_counter(self):
"""Weekly counter is None (start of week), daily has usage."""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["500", None])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 500
assert status.weekly.used == 0
@pytest.mark.asyncio
async def test_resets_at_daily_is_next_midnight_utc(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["0", "0"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
now = datetime.now(UTC)
# Daily reset should be within 24h
assert status.daily.resets_at > now
assert status.daily.resets_at <= now + timedelta(hours=24, seconds=5)
# ---------------------------------------------------------------------------
# check_rate_limit
# ---------------------------------------------------------------------------
class TestCheckRateLimit:
@pytest.mark.asyncio
async def test_allows_when_under_limit(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["100", "200"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
@pytest.mark.asyncio
async def test_raises_when_daily_limit_exceeded(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["10000", "200"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert exc_info.value.window == "daily"
@pytest.mark.asyncio
async def test_raises_when_weekly_limit_exceeded(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["100", "50000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert exc_info.value.window == "weekly"
@pytest.mark.asyncio
async def test_allows_when_redis_unavailable(self):
"""Fail-open: allow requests when Redis is down."""
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
# Should not raise
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
@pytest.mark.asyncio
async def test_skips_check_when_limit_is_zero(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["999999", "999999"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise — limits of 0 mean unlimited
await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0)
# ---------------------------------------------------------------------------
# record_token_usage
# ---------------------------------------------------------------------------
class TestRecordTokenUsage:
@staticmethod
def _make_pipeline_mock() -> MagicMock:
"""Create a pipeline mock with sync methods and async execute."""
pipe = MagicMock()
pipe.execute = AsyncMock(return_value=[])
return pipe
@pytest.mark.asyncio
async def test_increments_redis_counters(self):
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
# Should call incrby twice (daily + weekly) with total=150
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 150 # daily
assert incrby_calls[1].args[1] == 150 # weekly
@pytest.mark.asyncio
async def test_skips_when_zero_tokens(self):
mock_redis = AsyncMock()
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0)
# Should not call pipeline at all
mock_redis.pipeline.assert_not_called()
@pytest.mark.asyncio
async def test_sets_expire_on_both_keys(self):
"""Pipeline should call expire for both daily and weekly keys."""
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
expire_calls = mock_pipe.expire.call_args_list
assert len(expire_calls) == 2
# Daily key TTL should be positive (seconds until next midnight)
daily_ttl = expire_calls[0].args[1]
assert daily_ttl >= 1
# Weekly key TTL should be positive (seconds until next Monday)
weekly_ttl = expire_calls[1].args[1]
assert weekly_ttl >= 1
@pytest.mark.asyncio
async def test_handles_redis_failure_gracefully(self):
"""Should not raise when Redis is unavailable."""
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
# Should not raise
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
@pytest.mark.asyncio
async def test_cost_weighted_counting(self):
"""Cached tokens should be weighted: cache_read=10%, cache_create=25%."""
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(
_USER,
prompt_tokens=100, # uncached → 100
completion_tokens=50, # output → 50
cache_read_tokens=10000, # 10% → 1000
cache_creation_tokens=400, # 25% → 100
)
# Expected weighted total: 100 + 1000 + 100 + 50 = 1250
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 1250 # daily
assert incrby_calls[1].args[1] == 1250 # weekly
@pytest.mark.asyncio
async def test_handles_redis_error_during_pipeline_execute(self):
"""Should not raise when pipeline.execute() fails with RedisError."""
mock_pipe = self._make_pipeline_mock()
mock_pipe.execute = AsyncMock(side_effect=RedisError("Pipeline failed"))
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise — fail-open
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)

View File

@@ -186,29 +186,12 @@ class StreamToolOutputAvailable(StreamBaseResponse):
class StreamUsage(StreamBaseResponse):
"""Token usage statistics.
Emitted as an SSE comment so the Vercel AI SDK parser ignores it
(it uses z.strictObject() and rejects unknown event types).
Usage data is recorded server-side (session DB + Redis counters).
"""
"""Token usage statistics."""
type: ResponseType = ResponseType.USAGE
promptTokens: int = Field(..., description="Number of uncached prompt tokens")
promptTokens: int = Field(..., description="Number of prompt tokens")
completionTokens: int = Field(..., description="Number of completion tokens")
totalTokens: int = Field(
..., description="Total number of tokens (raw, not weighted)"
)
cacheReadTokens: int = Field(
default=0, description="Prompt tokens served from cache (10% cost)"
)
cacheCreationTokens: int = Field(
default=0, description="Prompt tokens written to cache (25% cost)"
)
def to_sse(self) -> str:
"""Emit as SSE comment so the AI SDK parser ignores it."""
return f": usage {self.model_dump_json(exclude_none=True)}\n\n"
totalTokens: int = Field(..., description="Total number of tokens")
class StreamError(StreamBaseResponse):

View File

@@ -198,7 +198,6 @@ class CompactionTracker:
def reset_for_query(self) -> None:
"""Reset per-query state before a new SDK query."""
self._compact_start.clear()
self._done = False
self._start_emitted = False
self._tool_call_id = ""

View File

@@ -1,546 +0,0 @@
"""End-to-end compaction flow test.
Simulates the full service.py compaction lifecycle using real-format
JSONL session files — no SDK subprocess needed. Exercises:
1. TranscriptBuilder loads a "downloaded" transcript
2. User query appended, assistant response streamed
3. PreCompact hook fires → CompactionTracker.on_compact()
4. Next message → emit_start_if_ready() yields spinner events
5. Message after that → emit_end_if_ready() returns end events
6. _read_compacted_entries() reads the CLI session file
7. TranscriptBuilder.replace_entries() syncs state
8. More messages appended post-compaction
9. to_jsonl() exports full state for upload
10. Fresh builder loads the export — roundtrip verified
"""
import asyncio
from pathlib import Path
from backend.copilot.model import ChatSession
from backend.copilot.response_model import (
StreamFinishStep,
StreamStartStep,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from backend.copilot.sdk.compaction import CompactionTracker
from backend.copilot.sdk.transcript import strip_progress_entries
from backend.copilot.sdk.transcript_builder import TranscriptBuilder
from backend.util import json
def _make_jsonl(*entries: dict) -> str:
return "\n".join(json.dumps(e) for e in entries) + "\n"
def _run(coro):
"""Run an async coroutine synchronously."""
return asyncio.run(coro)
def _read_compacted_entries(path: str) -> tuple[list[dict], str] | None:
"""Test-only: read compacted entries from a session JSONL file.
Returns (parsed_dicts, jsonl_string) from the first ``isCompactSummary``
entry onward, or ``None`` if no summary is found.
"""
content = Path(path).read_text()
lines = content.strip().split("\n")
compact_idx: int | None = None
parsed: list[dict] = []
raw_lines: list[str] = []
for line in lines:
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
continue
parsed.append(entry)
raw_lines.append(line.strip())
if compact_idx is None and entry.get("isCompactSummary"):
compact_idx = len(parsed) - 1
if compact_idx is None:
return None
return parsed[compact_idx:], "\n".join(raw_lines[compact_idx:]) + "\n"
# ---------------------------------------------------------------------------
# Fixtures: realistic CLI session file content
# ---------------------------------------------------------------------------
# Pre-compaction conversation
USER_1 = {
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "What files are in this project?"},
}
ASST_1_THINKING = {
"type": "assistant",
"uuid": "a1-think",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_sdk_aaa",
"type": "message",
"content": [{"type": "thinking", "thinking": "Let me look at the files..."}],
"stop_reason": None,
"stop_sequence": None,
},
}
ASST_1_TOOL = {
"type": "assistant",
"uuid": "a1-tool",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_sdk_aaa",
"type": "message",
"content": [
{
"type": "tool_use",
"id": "tu1",
"name": "Bash",
"input": {"command": "ls"},
}
],
"stop_reason": "tool_use",
"stop_sequence": None,
},
}
TOOL_RESULT_1 = {
"type": "user",
"uuid": "tr1",
"parentUuid": "a1-tool",
"message": {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "tu1",
"content": "file1.py\nfile2.py",
}
],
},
}
ASST_1_TEXT = {
"type": "assistant",
"uuid": "a1-text",
"parentUuid": "tr1",
"message": {
"role": "assistant",
"id": "msg_sdk_bbb",
"type": "message",
"content": [{"type": "text", "text": "I found file1.py and file2.py."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# Progress entries (should be stripped during upload)
PROGRESS_1 = {
"type": "progress",
"uuid": "prog1",
"parentUuid": "a1-tool",
"data": {"type": "bash_progress", "stdout": "running ls..."},
}
# Second user message
USER_2 = {
"type": "user",
"uuid": "u2",
"parentUuid": "a1-text",
"message": {"role": "user", "content": "Show me file1.py"},
}
ASST_2 = {
"type": "assistant",
"uuid": "a2",
"parentUuid": "u2",
"message": {
"role": "assistant",
"id": "msg_sdk_ccc",
"type": "message",
"content": [{"type": "text", "text": "Here is file1.py content..."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# --- Compaction summary (written by CLI after context compaction) ---
COMPACT_SUMMARY = {
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {
"role": "user",
"content": (
"Summary: User asked about project files. Found file1.py and file2.py. "
"User then asked to see file1.py."
),
},
}
# Post-compaction assistant response
POST_COMPACT_ASST = {
"type": "assistant",
"uuid": "a3",
"parentUuid": "cs1",
"message": {
"role": "assistant",
"id": "msg_sdk_ddd",
"type": "message",
"content": [{"type": "text", "text": "Here is the content of file1.py..."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# Post-compaction user follow-up
USER_3 = {
"type": "user",
"uuid": "u3",
"parentUuid": "a3",
"message": {"role": "user", "content": "Now show file2.py"},
}
ASST_3 = {
"type": "assistant",
"uuid": "a4",
"parentUuid": "u3",
"message": {
"role": "assistant",
"id": "msg_sdk_eee",
"type": "message",
"content": [{"type": "text", "text": "Here is file2.py..."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# ---------------------------------------------------------------------------
# E2E test
# ---------------------------------------------------------------------------
class TestCompactionE2E:
def _write_session_file(self, session_dir, entries):
"""Write a CLI session JSONL file."""
path = session_dir / "session.jsonl"
path.write_text(_make_jsonl(*entries))
return path
def test_full_compaction_lifecycle(self, tmp_path):
"""Simulate the complete service.py compaction flow.
Timeline:
1. Previous turn uploaded transcript with [USER_1, ASST_1, USER_2, ASST_2]
2. Current turn: download → load_previous
3. User sends "Now show file2.py" → append_user
4. SDK starts streaming response
5. Mid-stream: PreCompact hook fires (context too large)
6. CLI writes compaction summary to session file
7. Next SDK message → emit_start (spinner)
8. Following message → emit_end (end events)
9. _read_compacted_entries reads the session file
10. replace_entries syncs TranscriptBuilder
11. More assistant messages appended
12. Export → upload → next turn downloads it
"""
session_dir = tmp_path / "session"
session_dir.mkdir(parents=True)
# --- Step 1-2: Load "downloaded" transcript from previous turn ---
previous_transcript = _make_jsonl(
USER_1,
ASST_1_THINKING,
ASST_1_TOOL,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
)
builder = TranscriptBuilder()
builder.load_previous(previous_transcript)
assert builder.entry_count == 7
# --- Step 3: User sends new query ---
builder.append_user("Now show file2.py")
assert builder.entry_count == 8
# --- Step 4: SDK starts streaming ---
builder.append_assistant(
[{"type": "thinking", "thinking": "Let me read file2.py..."}],
model="claude-sonnet-4-20250514",
)
assert builder.entry_count == 9
# --- Step 5-6: PreCompact fires, CLI writes session file ---
session_file = self._write_session_file(
session_dir,
[
USER_1,
ASST_1_THINKING,
ASST_1_TOOL,
PROGRESS_1,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
COMPACT_SUMMARY,
POST_COMPACT_ASST,
USER_3,
ASST_3,
],
)
# --- Step 7: CompactionTracker receives PreCompact hook ---
tracker = CompactionTracker()
session = ChatSession.new(user_id="test-user")
# on_compact is a property returning Event.set callable
tracker.on_compact()
# --- Step 8: Next SDK message arrives → emit_start ---
start_events = tracker.emit_start_if_ready()
assert len(start_events) == 3
assert isinstance(start_events[0], StreamStartStep)
assert isinstance(start_events[1], StreamToolInputStart)
assert isinstance(start_events[2], StreamToolInputAvailable)
# Verify tool_call_id is set
tool_call_id = start_events[1].toolCallId
assert tool_call_id.startswith("compaction-")
# --- Step 9: Following message → emit_end ---
end_events = _run(tracker.emit_end_if_ready(session))
assert len(end_events) == 2
assert isinstance(end_events[0], StreamToolOutputAvailable)
assert isinstance(end_events[1], StreamFinishStep)
# Verify same tool_call_id
assert end_events[0].toolCallId == tool_call_id
# Session should have compaction messages persisted
assert len(session.messages) == 2
assert session.messages[0].role == "assistant"
assert session.messages[1].role == "tool"
# --- Step 10: _read_compacted_entries + replace_entries ---
result = _read_compacted_entries(str(session_file))
assert result is not None
compacted_dicts, compacted_jsonl = result
# Should have: COMPACT_SUMMARY + POST_COMPACT_ASST + USER_3 + ASST_3
assert len(compacted_dicts) == 4
assert compacted_dicts[0]["uuid"] == "cs1"
assert compacted_dicts[0]["isCompactSummary"] is True
# Replace builder state with compacted JSONL
old_count = builder.entry_count
builder.replace_entries(compacted_jsonl)
assert builder.entry_count == 4 # Only compacted entries
assert builder.entry_count < old_count # Compaction reduced entries
# --- Step 11: More assistant messages after compaction ---
builder.append_assistant(
[{"type": "text", "text": "Here is file2.py:\n\ndef hello():\n pass"}],
model="claude-sonnet-4-20250514",
stop_reason="end_turn",
)
assert builder.entry_count == 5
# --- Step 12: Export for upload ---
output = builder.to_jsonl()
assert output # Not empty
output_entries = [json.loads(line) for line in output.strip().split("\n")]
assert len(output_entries) == 5
# Verify structure:
# [COMPACT_SUMMARY, POST_COMPACT_ASST, USER_3, ASST_3, new_assistant]
assert output_entries[0]["type"] == "summary"
assert output_entries[0].get("isCompactSummary") is True
assert output_entries[0]["uuid"] == "cs1"
assert output_entries[1]["uuid"] == "a3"
assert output_entries[2]["uuid"] == "u3"
assert output_entries[3]["uuid"] == "a4"
assert output_entries[4]["type"] == "assistant"
# Verify parent chain is intact
assert output_entries[1]["parentUuid"] == "cs1" # a3 → cs1
assert output_entries[2]["parentUuid"] == "a3" # u3 → a3
assert output_entries[3]["parentUuid"] == "u3" # a4 → u3
assert output_entries[4]["parentUuid"] == "a4" # new → a4
# --- Step 13: Roundtrip — next turn loads this export ---
builder2 = TranscriptBuilder()
builder2.load_previous(output)
assert builder2.entry_count == 5
# isCompactSummary survives roundtrip
output2 = builder2.to_jsonl()
first_entry = json.loads(output2.strip().split("\n")[0])
assert first_entry.get("isCompactSummary") is True
# Can append more messages
builder2.append_user("What about file3.py?")
assert builder2.entry_count == 6
final_output = builder2.to_jsonl()
last_entry = json.loads(final_output.strip().split("\n")[-1])
assert last_entry["type"] == "user"
# Parented to the last entry from previous turn
assert last_entry["parentUuid"] == output_entries[-1]["uuid"]
def test_double_compaction_within_session(self, tmp_path):
"""Two compactions in the same session (across reset_for_query)."""
session_dir = tmp_path / "session"
session_dir.mkdir(parents=True)
tracker = CompactionTracker()
session = ChatSession.new(user_id="test")
builder = TranscriptBuilder()
# --- First query with compaction ---
builder.append_user("first question")
builder.append_assistant([{"type": "text", "text": "first answer"}])
# Write session file for first compaction
first_summary = {
"type": "summary",
"uuid": "cs-first",
"isCompactSummary": True,
"message": {"role": "user", "content": "First compaction summary"},
}
first_post = {
"type": "assistant",
"uuid": "a-first",
"parentUuid": "cs-first",
"message": {"role": "assistant", "content": "first post-compact"},
}
file1 = session_dir / "session1.jsonl"
file1.write_text(_make_jsonl(first_summary, first_post))
tracker.on_compact()
tracker.emit_start_if_ready()
end_events1 = _run(tracker.emit_end_if_ready(session))
assert len(end_events1) == 2 # output + finish
result1_entries = _read_compacted_entries(str(file1))
assert result1_entries is not None
_, compacted1_jsonl = result1_entries
builder.replace_entries(compacted1_jsonl)
assert builder.entry_count == 2
# --- Reset for second query ---
tracker.reset_for_query()
# --- Second query with compaction ---
builder.append_user("second question")
builder.append_assistant([{"type": "text", "text": "second answer"}])
second_summary = {
"type": "summary",
"uuid": "cs-second",
"isCompactSummary": True,
"message": {"role": "user", "content": "Second compaction summary"},
}
second_post = {
"type": "assistant",
"uuid": "a-second",
"parentUuid": "cs-second",
"message": {"role": "assistant", "content": "second post-compact"},
}
file2 = session_dir / "session2.jsonl"
file2.write_text(_make_jsonl(second_summary, second_post))
tracker.on_compact()
tracker.emit_start_if_ready()
end_events2 = _run(tracker.emit_end_if_ready(session))
assert len(end_events2) == 2 # output + finish
result2_entries = _read_compacted_entries(str(file2))
assert result2_entries is not None
_, compacted2_jsonl = result2_entries
builder.replace_entries(compacted2_jsonl)
assert builder.entry_count == 2 # Only second compaction entries
# Export and verify
output = builder.to_jsonl()
entries = [json.loads(line) for line in output.strip().split("\n")]
assert entries[0]["uuid"] == "cs-second"
assert entries[0].get("isCompactSummary") is True
def test_strip_progress_then_load_then_compact_roundtrip(self, tmp_path):
"""Full pipeline: strip → load → compact → replace → export → reload.
This tests the exact sequence that happens across two turns:
Turn 1: SDK produces transcript with progress entries
Upload: strip_progress_entries removes progress, upload to cloud
Turn 2: Download → load_previous → compaction fires → replace → export
Turn 3: Download the Turn 2 export → load_previous (roundtrip)
"""
session_dir = tmp_path / "session"
session_dir.mkdir(parents=True)
# --- Turn 1: SDK produces raw transcript ---
raw_content = _make_jsonl(
USER_1,
ASST_1_THINKING,
ASST_1_TOOL,
PROGRESS_1,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
)
# Strip progress for upload
stripped = strip_progress_entries(raw_content)
stripped_entries = [
json.loads(line) for line in stripped.strip().split("\n") if line.strip()
]
# Progress should be gone
assert not any(e.get("type") == "progress" for e in stripped_entries)
assert len(stripped_entries) == 7 # 8 - 1 progress
# --- Turn 2: Download stripped, load, compaction happens ---
builder = TranscriptBuilder()
builder.load_previous(stripped)
assert builder.entry_count == 7
builder.append_user("Now show file2.py")
builder.append_assistant(
[{"type": "text", "text": "Reading file2.py..."}],
model="claude-sonnet-4-20250514",
)
# CLI writes session file with compaction
session_file = self._write_session_file(
session_dir,
[
USER_1,
ASST_1_TOOL,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
COMPACT_SUMMARY,
POST_COMPACT_ASST,
],
)
result = _read_compacted_entries(str(session_file))
assert result is not None
_, compacted_jsonl = result
builder.replace_entries(compacted_jsonl)
# Append post-compaction message
builder.append_user("Thanks!")
output = builder.to_jsonl()
# --- Turn 3: Fresh load of Turn 2 export ---
builder3 = TranscriptBuilder()
builder3.load_previous(output)
# Should have: compact_summary + post_compact_asst + "Thanks!"
assert builder3.entry_count == 3
# Compact summary survived the full pipeline
first = json.loads(builder3.to_jsonl().strip().split("\n")[0])
assert first.get("isCompactSummary") is True
assert first["type"] == "summary"

View File

@@ -26,17 +26,3 @@ For other services, search the MCP registry at https://registry.modelcontextprot
If the server requires credentials, a `SetupRequirementsResponse` is returned with an OAuth
login prompt. Once the user completes the flow and confirms, retry the same call immediately.
### Communication style
Avoid technical jargon like "MCP server", "OAuth", or "credentials" when talking to the user.
Use plain, friendly language instead:
| Instead of… | Say… |
|---|---|
| "Let me connect to Sentry's MCP server and discover what tools are available." | "I can connect to Sentry and help identify important issues." |
| "Let me connect to Sentry's MCP server now." | "Next, I'll connect to Sentry." |
| "The MCP server at mcp.sentry.dev requires authentication. Please connect your credentials to continue." | "To continue, sign in to Sentry and approve access." |
| "Sentry's MCP server needs OAuth authentication. You should see a prompt to connect your Sentry account…" | "You should see a prompt to sign in to Sentry. Once connected, I can help surface critical issues right away." |
Use **"connect to [Service]"** or **"sign in to [Service]"** — never "MCP server", "OAuth", or "credentials".

View File

@@ -221,12 +221,12 @@ class SDKResponseAdapter:
responses.append(StreamFinish())
else:
logger.warning(
"Unexpected ResultMessage subtype: %s", sdk_message.subtype
f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
)
responses.append(StreamFinish())
else:
logger.debug("Unhandled SDK message type: %s", type(sdk_message).__name__)
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
return responses

View File

@@ -52,7 +52,7 @@ def _validate_workspace_path(
if is_allowed_local_path(path, sdk_cwd):
return {}
logger.warning("Blocked %s outside workspace: %s", tool_name, path)
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
return _deny(
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
@@ -71,7 +71,7 @@ def _validate_tool_access(
"""
# Block forbidden tools
if tool_name in BLOCKED_TOOLS:
logger.warning("Blocked tool access attempt: %s", tool_name)
logger.warning(f"Blocked tool access attempt: {tool_name}")
return _deny(
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
"This is enforced by the platform and cannot be bypassed. "
@@ -111,9 +111,7 @@ def _validate_user_isolation(
# the tool itself via _validate_ephemeral_path.
path = tool_input.get("path", "") or tool_input.get("file_path", "")
if path and ".." in path:
logger.warning(
"Blocked path traversal attempt: %s by user %s", path, user_id
)
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
@@ -171,7 +169,7 @@ def create_security_hooks(
# Block background task execution first — denied calls
# should not consume a subtask slot.
if tool_input.get("run_in_background"):
logger.info("[SDK] Blocked background Task, user=%s", user_id)
logger.info(f"[SDK] Blocked background Task, user={user_id}")
return cast(
SyncHookJSONOutput,
_deny(
@@ -213,7 +211,7 @@ def create_security_hooks(
if tool_name == "Task" and tool_use_id is not None:
task_tool_use_ids.add(tool_use_id)
logger.debug("[SDK] Tool start: %s, user=%s", tool_name, user_id)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
return cast(SyncHookJSONOutput, {})
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:

View File

@@ -40,13 +40,11 @@ from ..constants import COPILOT_ERROR_PREFIX, COPILOT_SYSTEM_PREFIX
from ..model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from ..prompting import get_sdk_supplement
from ..rate_limit import record_token_usage
from ..response_model import (
StreamBaseResponse,
StreamError,
@@ -56,7 +54,6 @@ from ..response_model import (
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
StreamUsage,
)
from ..service import (
_build_system_prompt,
@@ -78,12 +75,8 @@ from .tool_adapter import (
wait_for_stash,
)
from .transcript import (
COMPACT_THRESHOLD_BYTES,
TranscriptDownload,
cleanup_cli_project_dir,
compact_transcript,
download_transcript,
read_cli_session_file,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
@@ -301,7 +294,7 @@ def _cleanup_sdk_tool_results(cwd: str) -> None:
"""
normalized = os.path.normpath(cwd)
if not normalized.startswith(_SDK_CWD_PREFIX):
logger.warning("[SDK] Rejecting cleanup for path outside workspace: %s", cwd)
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
return
# Clean the CLI's project directory (transcripts + tool-results).
@@ -395,7 +388,7 @@ async def _compress_messages(
client=client,
)
except Exception as e:
logger.warning("[SDK] Context compression with LLM failed: %s", e)
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
# Fall back to truncation-only (no LLM summarization)
result = await compress_context(
messages=messages_dict,
@@ -631,56 +624,6 @@ async def _prepare_file_attachments(
return PreparedAttachments(hint=hint, image_blocks=image_blocks)
async def _maybe_compact_and_upload(
dl: TranscriptDownload,
user_id: str,
session_id: str,
log_prefix: str = "[Transcript]",
) -> str:
"""Compact an oversized transcript and upload the compacted version.
Returns the (possibly compacted) transcript content, or an empty string
if compaction was needed but failed.
"""
content = dl.content
if len(content) <= COMPACT_THRESHOLD_BYTES:
return content
logger.warning(
"%s Transcript oversized (%dB > %dB), compacting",
log_prefix,
len(content),
COMPACT_THRESHOLD_BYTES,
)
compacted = await compact_transcript(content, log_prefix=log_prefix)
if not compacted:
logger.warning(
"%s Compaction failed, skipping resume for this turn", log_prefix
)
return ""
# Keep the original message_count: it reflects the number of
# session.messages covered by this transcript, which the gap-fill
# logic uses as a slice index. Counting JSONL lines would give a
# smaller number (compacted messages != session message count) and
# cause already-covered messages to be re-injected.
try:
await upload_transcript(
user_id=user_id,
session_id=session_id,
content=compacted,
message_count=dl.message_count,
log_prefix=log_prefix,
)
except Exception:
logger.warning(
"%s Failed to upload compacted transcript",
log_prefix,
exc_info=True,
)
return compacted
async def stream_chat_completion_sdk(
session_id: str,
message: str | None = None,
@@ -792,14 +735,6 @@ async def stream_chat_completion_sdk(
_otel_ctx: Any = None
# Make sure there is no more code between the lock acquisition and try-block.
# Token usage accumulators — populated from ResultMessage at end of turn
turn_prompt_tokens = 0 # uncached input tokens only
turn_completion_tokens = 0
turn_cache_read_tokens = 0
turn_cache_creation_tokens = 0
total_tokens = 0 # computed once before StreamUsage, reused in finally
turn_cost_usd: float | None = None
try:
# Build system prompt (reuses non-SDK path with Langfuse support).
# Pre-compute the cwd here so the exact working directory path can be
@@ -892,33 +827,20 @@ async def stream_chat_completion_sdk(
is_valid,
)
if is_valid:
transcript_content = await _maybe_compact_and_upload(
dl,
user_id=user_id or "",
session_id=session_id,
log_prefix=log_prefix,
)
# Load previous context into builder (empty string is a no-op)
if transcript_content:
transcript_builder.load_previous(
transcript_content, log_prefix=log_prefix
)
resume_file = (
write_transcript_to_tempfile(
transcript_content, session_id, sdk_cwd
)
if transcript_content
else None
# Load previous FULL context into builder
transcript_builder.load_previous(dl.content, log_prefix=log_prefix)
resume_file = write_transcript_to_tempfile(
dl.content, session_id, sdk_cwd
)
if resume_file:
use_resume = True
transcript_msg_count = dl.message_count
logger.debug(
f"{log_prefix} Using --resume ({len(transcript_content)}B, "
f"{log_prefix} Using --resume ({len(dl.content)}B, "
f"msg_count={transcript_msg_count})"
)
else:
logger.warning("%s Transcript downloaded but invalid", log_prefix)
logger.warning(f"{log_prefix} Transcript downloaded but invalid")
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
logger.warning(
f"{log_prefix} No transcript available "
@@ -1188,7 +1110,7 @@ async def stream_chat_completion_sdk(
- len(adapter.resolved_tool_calls),
)
# Log ResultMessage details and capture token usage
# Log ResultMessage details for debugging
if isinstance(sdk_msg, ResultMessage):
logger.info(
"%s Received: ResultMessage %s "
@@ -1207,46 +1129,9 @@ async def stream_chat_completion_sdk(
sdk_msg.result or "(no error message provided)",
)
# Capture token usage from ResultMessage.
# Anthropic reports cached tokens separately:
# input_tokens = uncached only
# cache_read_input_tokens = served from cache
# cache_creation_input_tokens = written to cache
if sdk_msg.usage:
turn_prompt_tokens += sdk_msg.usage.get("input_tokens", 0)
turn_cache_read_tokens += sdk_msg.usage.get(
"cache_read_input_tokens", 0
)
turn_cache_creation_tokens += sdk_msg.usage.get(
"cache_creation_input_tokens", 0
)
turn_completion_tokens += sdk_msg.usage.get(
"output_tokens", 0
)
logger.info(
"%s Token usage: uncached=%d, cache_read=%d, cache_create=%d, output=%d",
log_prefix,
turn_prompt_tokens,
turn_cache_read_tokens,
turn_cache_creation_tokens,
turn_completion_tokens,
)
if sdk_msg.total_cost_usd is not None:
turn_cost_usd = sdk_msg.total_cost_usd
# Emit compaction end if SDK finished compacting.
# When compaction ends, sync TranscriptBuilder with
# the CLI's compacted session file so the uploaded
# transcript reflects compaction.
compaction_events = await compaction.emit_end_if_ready(session)
for ev in compaction_events:
# Emit compaction end if SDK finished compacting
for ev in await compaction.emit_end_if_ready(session):
yield ev
if compaction_events and sdk_cwd:
cli_content = await read_cli_session_file(sdk_cwd)
if cli_content:
transcript_builder.replace_entries(
cli_content, log_prefix=log_prefix
)
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
@@ -1440,27 +1325,6 @@ async def stream_chat_completion_sdk(
) and not has_appended_assistant:
session.messages.append(assistant_response)
# Emit token usage to the client (must be in try to reach SSE stream).
# Session persistence of usage is in finally to stay consistent with
# rate-limit recording even if an exception interrupts between here
# and the finally block.
# Compute total_tokens once; reused in the finally block for
# session persistence and rate-limit recording.
total_tokens = (
turn_prompt_tokens
+ turn_cache_read_tokens
+ turn_cache_creation_tokens
+ turn_completion_tokens
)
if total_tokens > 0:
yield StreamUsage(
promptTokens=turn_prompt_tokens,
completionTokens=turn_completion_tokens,
totalTokens=total_tokens,
cacheReadTokens=turn_cache_read_tokens,
cacheCreationTokens=turn_cache_creation_tokens,
)
# Transcript upload is handled exclusively in the finally block
# to avoid double-uploads (the success path used to upload the
# old resume file, then the finally block overwrote it with the
@@ -1525,48 +1389,6 @@ async def stream_chat_completion_sdk(
except Exception:
logger.warning("OTEL context teardown failed", exc_info=True)
# --- Persist token usage to session + rate-limit counters ---
# Both must live in finally so they stay consistent even when an
# exception interrupts the try block after StreamUsage was yielded.
# total_tokens is computed once before StreamUsage yield above.
if total_tokens > 0:
if session is not None:
session.usage.append(
Usage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=total_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
)
)
logger.info(
"%s Turn usage: uncached=%d, cache_read=%d, cache_create=%d, "
"output=%d, total=%d, cost_usd=%s",
log_prefix,
turn_prompt_tokens,
turn_cache_read_tokens,
turn_cache_creation_tokens,
turn_completion_tokens,
total_tokens,
turn_cost_usd,
)
if user_id and total_tokens > 0:
try:
await record_token_usage(
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
)
except Exception as usage_err:
logger.warning(
"%s Failed to record token usage: %s",
log_prefix,
usage_err,
)
# --- Persist session messages ---
# This MUST run in finally to persist messages even when the generator
# is stopped early (e.g., user clicks stop, processor breaks stream loop).
@@ -1662,6 +1484,6 @@ async def _update_title_async(
)
if title and user_id:
await update_session_title(session_id, user_id, title, only_if_empty=True)
logger.debug("[SDK] Generated title for %s: %s", session_id, title)
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
except Exception as e:
logger.warning("[SDK] Failed to update session title: %s", e)
logger.warning(f"[SDK] Failed to update session title: {e}")

View File

@@ -234,9 +234,7 @@ def create_tool_handler(base_tool: BaseTool):
try:
return await _execute_tool_sync(base_tool, user_id, session, args)
except Exception as e:
logger.error(
"Error executing tool %s: %s", base_tool.name, e, exc_info=True
)
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
return tool_handler

View File

@@ -13,17 +13,10 @@ filesystem for self-hosted) — no DB column needed.
import logging
import os
import re
import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from uuid import uuid4
import openai
from backend.copilot.config import ChatConfig
from backend.util import json
from backend.util.prompt import CompressResult, compress_context
logger = logging.getLogger(__name__)
@@ -41,11 +34,6 @@ STRIPPABLE_TYPES = frozenset(
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
)
# JSONL protocol values used in transcript serialization.
STOP_REASON_END_TURN = "end_turn"
COMPACT_MSG_ID_PREFIX = "msg_compact_"
ENTRY_TYPE_MESSAGE = "message"
@dataclass
class TranscriptDownload:
@@ -94,11 +82,7 @@ def strip_progress_entries(content: str) -> str:
parent = entry.get("parentUuid", "")
if uid:
uuid_to_parent[uid] = parent
if (
entry.get("type", "") in STRIPPABLE_TYPES
and uid
and not entry.get("isCompactSummary")
):
if entry.get("type", "") in STRIPPABLE_TYPES and uid:
stripped_uuids.add(uid)
# Second pass: keep non-stripped entries, reparenting where needed.
@@ -109,9 +93,7 @@ def strip_progress_entries(content: str) -> str:
continue
parent = entry.get("parentUuid", "")
original_parent = parent
seen_parents: set[str] = set()
while parent in stripped_uuids and parent not in seen_parents:
seen_parents.add(parent)
while parent in stripped_uuids:
parent = uuid_to_parent.get(parent, "")
if parent != original_parent:
entry["parentUuid"] = parent
@@ -124,9 +106,7 @@ def strip_progress_entries(content: str) -> str:
if not isinstance(entry, dict):
result_lines.append(line)
continue
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
"isCompactSummary"
):
if entry.get("type", "") in STRIPPABLE_TYPES:
continue
uid = entry.get("uuid", "")
if uid in reparented:
@@ -157,78 +137,32 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
def _cli_project_dir(sdk_cwd: str) -> str | None:
"""Return the CLI's project directory for a given working directory.
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""Remove the CLI's project directory for a specific working directory.
Returns ``None`` if the path would escape the projects base.
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
safe to remove entirely after the transcript has been uploaded.
"""
import shutil
# Encode cwd the same way CLI does (replaces non-alphanumeric with -)
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
if not project_dir.startswith(projects_base + os.sep):
logger.warning("[Transcript] Project dir escaped base: %s", project_dir)
return None
return project_dir
async def read_cli_session_file(sdk_cwd: str) -> str | None:
"""Read the CLI's own session file, which reflects any mid-stream compaction.
After the CLI compacts context, its session file contains the compacted
conversation. Reading this file lets ``TranscriptBuilder`` replace its
uncompacted entries with the CLI's compacted version.
"""
import aiofiles
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir or not os.path.isdir(project_dir):
return None
jsonl_files = list(Path(project_dir).glob("*.jsonl"))
if not jsonl_files:
logger.debug("[Transcript] No CLI session file in %s", project_dir)
return None
# Pick the most recently modified file (there should only be one per turn).
# Guard against races where a file is deleted between glob and stat.
candidates: list[tuple[float, Path]] = []
for p in jsonl_files:
try:
candidates.append((p.stat().st_mtime, p))
except OSError:
continue
if not candidates:
logger.debug("[Transcript] No readable CLI session file in %s", project_dir)
return None
# Resolve + prefix check to prevent symlink escapes.
session_file = max(candidates, key=lambda item: item[0])[1]
real_path = str(session_file.resolve())
if not real_path.startswith(project_dir + os.sep):
logger.warning("[Transcript] Session file escaped project dir: %s", real_path)
return None
try:
async with aiofiles.open(real_path) as f:
content = await f.read()
logger.info(
"[Transcript] Read CLI session file: %s (%d bytes)",
real_path,
len(content),
logger.warning(
f"[Transcript] Cleanup path escaped projects base: {project_dir}"
)
return content
except OSError as e:
logger.warning("[Transcript] Failed to read CLI session file: %s", e)
return None
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""Remove the CLI's project directory for a specific working directory."""
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir:
return
if os.path.isdir(project_dir):
shutil.rmtree(project_dir, ignore_errors=True)
logger.debug("[Transcript] Cleaned up CLI project dir: %s", project_dir)
logger.debug(f"[Transcript] Cleaned up CLI project dir: {project_dir}")
else:
logger.debug("[Transcript] Project dir not found: %s", project_dir)
logger.debug(f"[Transcript] Project dir not found: {project_dir}")
def write_transcript_to_tempfile(
@@ -246,7 +180,7 @@ def write_transcript_to_tempfile(
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
real_cwd = os.path.realpath(cwd)
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
logger.warning("[Transcript] cwd outside sandbox: %s", cwd)
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
return None
try:
@@ -256,17 +190,17 @@ def write_transcript_to_tempfile(
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
)
if not jsonl_path.startswith(real_cwd):
logger.warning("[Transcript] Path escaped cwd: %s", jsonl_path)
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
return None
with open(jsonl_path, "w") as f:
f.write(transcript_content)
logger.info("[Transcript] Wrote resume file: %s", jsonl_path)
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
return jsonl_path
except OSError as e:
logger.warning("[Transcript] Failed to write resume file: %s", e)
logger.warning(f"[Transcript] Failed to write resume file: {e}")
return None
@@ -410,14 +344,11 @@ async def upload_transcript(
content=json.dumps(meta).encode("utf-8"),
)
except Exception as e:
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
logger.info(
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
log_prefix,
len(encoded),
len(content),
message_count,
f"{log_prefix} Uploaded {len(encoded)}B "
f"(stripped from {len(content)}B, msg_count={message_count})"
)
@@ -440,10 +371,10 @@ async def download_transcript(
data = await storage.retrieve(path)
content = data.decode("utf-8")
except FileNotFoundError:
logger.debug("%s No transcript in storage", log_prefix)
logger.debug(f"{log_prefix} No transcript in storage")
return None
except Exception as e:
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
return None
# Try to load metadata (best-effort — old transcripts won't have it)
@@ -463,14 +394,10 @@ async def download_transcript(
meta = json.loads(meta_data.decode("utf-8"), fallback={})
message_count = meta.get("message_count", 0)
uploaded_at = meta.get("uploaded_at", 0.0)
except FileNotFoundError:
except (FileNotFoundError, Exception):
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
except Exception as e:
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
logger.info(
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
)
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
return TranscriptDownload(
content=content,
message_count=message_count,
@@ -478,171 +405,15 @@ async def download_transcript(
)
# ---------------------------------------------------------------------------
# Transcript compaction
# ---------------------------------------------------------------------------
async def delete_transcript(user_id: str, session_id: str) -> None:
"""Delete transcript from bucket storage (e.g. after resume failure)."""
from backend.util.workspace_storage import get_workspace_storage
# Transcripts above this byte threshold are compacted at download time.
COMPACT_THRESHOLD_BYTES = 400_000
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
def _flatten_assistant_content(blocks: list) -> str:
"""Flatten assistant content blocks into a single plain-text string."""
parts: list[str] = []
for block in blocks:
if isinstance(block, dict):
if block.get("type") == "text":
parts.append(block.get("text", ""))
elif block.get("type") == "tool_use":
parts.append(f"[tool_use: {block.get('name', '?')}]")
elif isinstance(block, str):
parts.append(block)
return "\n".join(parts) if parts else ""
def _flatten_tool_result_content(blocks: list) -> str:
"""Flatten tool_result and other content blocks into plain text.
Handles nested tool_result structures, text blocks, and raw strings.
Uses ``json.dumps`` as fallback for dict blocks without a ``text`` key
or where ``text`` is ``None``.
"""
str_parts: list[str] = []
for block in blocks:
if isinstance(block, dict) and block.get("type") == "tool_result":
inner = block.get("content", "")
if isinstance(inner, list):
for sub in inner:
if isinstance(sub, dict):
text = sub.get("text")
str_parts.append(
str(text) if text is not None else json.dumps(sub)
)
else:
str_parts.append(str(sub))
else:
str_parts.append(str(inner))
elif isinstance(block, dict) and block.get("type") == "text":
str_parts.append(str(block.get("text", "")))
elif isinstance(block, str):
str_parts.append(block)
return "\n".join(str_parts) if str_parts else ""
def _transcript_to_messages(content: str) -> list[dict]:
"""Convert JSONL transcript entries to message dicts for compress_context."""
messages: list[dict] = []
for line in content.strip().split("\n"):
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
continue
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
"isCompactSummary"
):
continue
msg = entry.get("message", {})
role = msg.get("role", "")
if not role:
continue
msg_dict: dict = {"role": role}
raw_content = msg.get("content")
if role == "assistant" and isinstance(raw_content, list):
msg_dict["content"] = _flatten_assistant_content(raw_content)
elif isinstance(raw_content, list):
msg_dict["content"] = _flatten_tool_result_content(raw_content)
else:
msg_dict["content"] = raw_content or ""
messages.append(msg_dict)
return messages
def _messages_to_transcript(messages: list[dict]) -> str:
"""Convert compressed message dicts back to JSONL transcript format."""
lines: list[str] = []
last_uuid: str | None = None
for msg in messages:
role = msg.get("role", "user")
entry_type = "assistant" if role == "assistant" else "user"
uid = str(uuid4())
content = msg.get("content", "")
if role == "assistant":
message: dict = {
"role": "assistant",
"model": "",
"id": f"{COMPACT_MSG_ID_PREFIX}{uuid4().hex[:24]}",
"type": ENTRY_TYPE_MESSAGE,
"content": [{"type": "text", "text": content}] if content else [],
"stop_reason": STOP_REASON_END_TURN,
"stop_sequence": None,
}
else:
message = {"role": role, "content": content}
entry = {
"type": entry_type,
"uuid": uid,
"parentUuid": last_uuid,
"message": message,
}
lines.append(json.dumps(entry, separators=(",", ":")))
last_uuid = uid
return "\n".join(lines) + "\n" if lines else ""
async def _run_compression(
messages: list[dict],
model: str,
cfg: ChatConfig,
log_prefix: str,
) -> CompressResult:
"""Run LLM-based compression with truncation fallback."""
try:
async with openai.AsyncOpenAI(
api_key=cfg.api_key, base_url=cfg.base_url, timeout=30.0
) as client:
return await compress_context(messages=messages, model=model, client=client)
await storage.delete(path)
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
except Exception as e:
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
return await compress_context(messages=messages, model=model, client=None)
async def compact_transcript(
content: str,
log_prefix: str = "[Transcript]",
) -> str | None:
"""Compact an oversized JSONL transcript using LLM summarization.
Converts transcript entries to plain messages, runs ``compress_context``
(the same compressor used for pre-query history), and rebuilds JSONL.
Returns the compacted JSONL string, or ``None`` on failure.
"""
cfg = ChatConfig()
messages = _transcript_to_messages(content)
if len(messages) < 2:
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
return None
try:
result = await _run_compression(messages, cfg.model, cfg, log_prefix)
if not result.was_compacted:
logger.info("%s Transcript already within token budget", log_prefix)
return content
logger.info(
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
log_prefix,
result.original_token_count,
result.token_count,
result.messages_summarized,
result.messages_dropped,
)
compacted = _messages_to_transcript(result.messages)
if not validate_transcript(compacted):
logger.warning("%s Compacted transcript failed validation", log_prefix)
return None
return compacted
except Exception as e:
logger.error(
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
)
return None
logger.warning(f"[Transcript] Failed to delete transcript: {e}")

View File

@@ -31,7 +31,6 @@ class TranscriptEntry(BaseModel):
uuid: str
parentUuid: str | None
message: dict[str, Any]
isCompactSummary: bool | None = None
class TranscriptBuilder:
@@ -79,12 +78,10 @@ class TranscriptBuilder:
)
continue
# Skip STRIPPABLE_TYPES unless the entry is a compaction summary.
# Compaction summaries may have type "summary" but must be preserved
# so --resume can reconstruct the compacted conversation.
# Load all non-strippable entries (user/assistant/system/etc.)
# Skip only STRIPPABLE_TYPES to match strip_progress_entries() behavior
entry_type = data.get("type", "")
is_compact = data.get("isCompactSummary", False)
if entry_type in STRIPPABLE_TYPES and not is_compact:
if entry_type in STRIPPABLE_TYPES:
continue
entry = TranscriptEntry(
@@ -92,7 +89,6 @@ class TranscriptBuilder:
uuid=data.get("uuid") or str(uuid4()),
parentUuid=data.get("parentUuid"),
message=data.get("message", {}),
isCompactSummary=True if is_compact else None,
)
self._entries.append(entry)
self._last_uuid = entry.uuid
@@ -181,33 +177,6 @@ class TranscriptBuilder:
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
return "\n".join(lines) + "\n"
def replace_entries(self, content: str, log_prefix: str = "[Transcript]") -> None:
"""Replace all entries with compacted JSONL content.
Called after the CLI performs mid-stream compaction so the builder's
state reflects the compacted conversation instead of the full
pre-compaction history.
"""
prev_count = len(self._entries)
temp = TranscriptBuilder()
try:
temp.load_previous(content, log_prefix=log_prefix)
except Exception:
logger.exception(
"%s Failed to parse compacted transcript; keeping %d existing entries",
log_prefix,
prev_count,
)
return
self._entries = temp._entries
self._last_uuid = temp._last_uuid
logger.info(
"%s Replaced %d entries with %d compacted entries",
log_prefix,
prev_count,
len(self._entries),
)
@property
def entry_count(self) -> int:
"""Total number of entries in the complete context."""

View File

@@ -2,25 +2,14 @@
import os
import pytest
from backend.util import json
from .transcript import (
COMPACT_MSG_ID_PREFIX,
STRIPPABLE_TYPES,
_cli_project_dir,
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
_transcript_to_messages,
compact_transcript,
read_cli_session_file,
strip_progress_entries,
validate_transcript,
write_transcript_to_tempfile,
)
from .transcript_builder import TranscriptBuilder
def _make_jsonl(*entries: dict) -> str:
@@ -46,14 +35,6 @@ PROGRESS_ENTRY = {
"data": {"type": "bash_progress", "stdout": "running..."},
}
COMPACT_SUMMARY = {
"type": "summary",
"uuid": "cs1",
"parentUuid": None,
"isCompactSummary": True,
"message": {"role": "user", "content": "Summary of previous conversation..."},
}
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
@@ -256,121 +237,6 @@ class TestStripProgressEntries:
# Should return just a newline (empty content stripped)
assert result.strip() == ""
# --- _cli_project_dir ---
class TestCliProjectDir:
def test_returns_path_for_valid_cwd(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
projects = tmp_path / "projects"
projects.mkdir()
result = _cli_project_dir("/tmp/copilot-abc")
assert result is not None
assert "projects" in result
def test_returns_none_for_path_traversal(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
projects = tmp_path / "projects"
projects.mkdir()
# A cwd that encodes to something with .. shouldn't escape
result = _cli_project_dir("/tmp/copilot-test")
# Should return a valid path (no traversal possible with alphanum encoding)
assert result is None or result.startswith(str(projects))
# --- read_cli_session_file ---
class TestReadCliSessionFile:
@pytest.mark.asyncio
async def test_reads_session_file(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
# Create the CLI project directory structure
cwd = "/tmp/copilot-testread"
import re
encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
project_dir = tmp_path / "projects" / encoded
project_dir.mkdir(parents=True)
# Write a session file
session_file = project_dir / "test-session.jsonl"
session_file.write_text(json.dumps(ASST_MSG) + "\n")
result = await read_cli_session_file(cwd)
assert result is not None
assert "assistant" in result
@pytest.mark.asyncio
async def test_returns_none_when_no_files(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
cwd = "/tmp/copilot-nofiles"
import re
encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
project_dir = tmp_path / "projects" / encoded
project_dir.mkdir(parents=True)
# No jsonl files
result = await read_cli_session_file(cwd)
assert result is None
@pytest.mark.asyncio
async def test_returns_none_when_dir_missing(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
(tmp_path / "projects").mkdir()
result = await read_cli_session_file("/tmp/copilot-nonexistent")
assert result is None
# --- _transcript_to_messages / _messages_to_transcript ---
class TestTranscriptMessageConversion:
def test_roundtrip_preserves_roles(self):
transcript = _make_jsonl(USER_MSG, ASST_MSG)
messages = _transcript_to_messages(transcript)
assert len(messages) == 2
assert messages[0]["role"] == "user"
assert messages[1]["role"] == "assistant"
def test_messages_to_transcript_produces_valid_jsonl(self):
messages = [
{"role": "user", "content": "hi"},
{"role": "assistant", "content": "hello"},
]
result = _messages_to_transcript(messages)
assert validate_transcript(result) is True
def test_strips_strippable_types(self):
transcript = _make_jsonl(
{"type": "progress", "uuid": "p1", "message": {"role": "user"}},
USER_MSG,
ASST_MSG,
)
messages = _transcript_to_messages(transcript)
assert len(messages) == 2 # progress entry skipped
def test_flattens_assistant_content_blocks(self):
asst_with_blocks = {
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [
{"type": "text", "text": "hello"},
{"type": "tool_use", "name": "bash"},
],
},
}
messages = _transcript_to_messages(_make_jsonl(asst_with_blocks))
assert len(messages) == 1
assert "hello" in messages[0]["content"]
assert "[tool_use: bash]" in messages[0]["content"]
def test_empty_messages_returns_empty(self):
result = _messages_to_transcript([])
assert result == ""
def test_no_strippable_entries(self):
"""When there's nothing to strip, output matches input structure."""
content = _make_jsonl(USER_MSG, ASST_MSG)
@@ -416,654 +282,3 @@ class TestTranscriptMessageConversion:
lines = result.strip().split("\n")
asst_entry = json.loads(lines[-1])
assert asst_entry["parentUuid"] == "u1" # reparented
# --- TranscriptBuilder ---
class TestTranscriptBuilderReplaceEntries:
"""Tests for TranscriptBuilder.replace_entries — the compaction sync path."""
def test_replace_entries_with_valid_content(self):
builder = TranscriptBuilder()
builder.append_user("hello")
builder.append_assistant([{"type": "text", "text": "world"}])
assert builder.entry_count == 2
# Replace with compacted content (one user + one assistant)
compacted = _make_jsonl(USER_MSG, ASST_MSG)
builder.replace_entries(compacted)
assert builder.entry_count == 2
def test_replace_entries_keeps_old_on_corrupt_content(self):
builder = TranscriptBuilder()
builder.append_user("hello")
assert builder.entry_count == 1
# Corrupt content that fails to parse
builder.replace_entries("not valid json at all\n")
# Should still have old entries (load_previous skips invalid lines,
# but if ALL lines are invalid, temp builder is empty → exception path)
assert builder.entry_count >= 0 # doesn't crash
def test_replace_entries_with_empty_content(self):
builder = TranscriptBuilder()
builder.append_user("hello")
assert builder.entry_count == 1
builder.replace_entries("")
# Empty content → load_previous returns early → temp is empty
# replace_entries swaps to empty (0 entries)
assert builder.entry_count == 0
def test_replace_entries_filters_strippable_types(self):
"""Strippable types (progress, file-history-snapshot) are filtered out."""
builder = TranscriptBuilder()
builder.append_user("hello")
content = _make_jsonl(
{"type": "progress", "uuid": "p1", "message": {}},
USER_MSG,
ASST_MSG,
)
builder.replace_entries(content)
# Only user + assistant should remain (progress filtered)
assert builder.entry_count == 2
def test_replace_entries_preserves_uuids(self):
builder = TranscriptBuilder()
content = _make_jsonl(USER_MSG, ASST_MSG)
builder.replace_entries(content)
jsonl = builder.to_jsonl()
lines = jsonl.strip().split("\n")
first = json.loads(lines[0])
assert first["uuid"] == "u1"
class TestTranscriptBuilderBasic:
def test_append_user_and_assistant(self):
builder = TranscriptBuilder()
builder.append_user("hi")
builder.append_assistant([{"type": "text", "text": "hello"}])
assert builder.entry_count == 2
assert not builder.is_empty
def test_to_jsonl_empty(self):
builder = TranscriptBuilder()
assert builder.to_jsonl() == ""
assert builder.is_empty
def test_load_previous_and_append(self):
builder = TranscriptBuilder()
content = _make_jsonl(USER_MSG, ASST_MSG)
builder.load_previous(content)
assert builder.entry_count == 2
builder.append_user("new message")
assert builder.entry_count == 3
def test_consecutive_assistant_entries_share_message_id(self):
builder = TranscriptBuilder()
builder.append_user("hi")
builder.append_assistant([{"type": "text", "text": "part1"}])
builder.append_assistant([{"type": "text", "text": "part2"}])
jsonl = builder.to_jsonl()
lines = jsonl.strip().split("\n")
asst1 = json.loads(lines[1])
asst2 = json.loads(lines[2])
assert asst1["message"]["id"] == asst2["message"]["id"]
def test_non_consecutive_assistant_entries_get_new_id(self):
builder = TranscriptBuilder()
builder.append_user("hi")
builder.append_assistant([{"type": "text", "text": "response1"}])
builder.append_user("followup")
builder.append_assistant([{"type": "text", "text": "response2"}])
jsonl = builder.to_jsonl()
lines = jsonl.strip().split("\n")
asst1 = json.loads(lines[1])
asst2 = json.loads(lines[3])
assert asst1["message"]["id"] != asst2["message"]["id"]
class TestCompactSummaryRoundtrip:
"""Verify isCompactSummary survives export→reload roundtrip."""
def test_load_previous_preserves_compact_summary(self):
"""Compaction summary with type 'summary' should not be stripped."""
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
builder = TranscriptBuilder()
builder.load_previous(content)
# summary type is in STRIPPABLE_TYPES, but isCompactSummary keeps it
assert builder.entry_count == 3
def test_export_reload_preserves_compact_summary(self):
"""Critical: isCompactSummary must survive to_jsonl → load_previous."""
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
builder1 = TranscriptBuilder()
builder1.load_previous(content)
assert builder1.entry_count == 3
exported = builder1.to_jsonl()
# Verify isCompactSummary is in the exported JSONL
first_line = json.loads(exported.strip().split("\n")[0])
assert first_line.get("isCompactSummary") is True
# Reload and verify it's still preserved
builder2 = TranscriptBuilder()
builder2.load_previous(exported)
assert builder2.entry_count == 3
def test_strip_progress_preserves_compact_summary(self):
"""strip_progress_entries should keep isCompactSummary entries."""
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
stripped = strip_progress_entries(content)
entries = [json.loads(line) for line in stripped.strip().split("\n")]
types = [e.get("type") for e in entries]
assert "summary" in types # Not stripped despite being in STRIPPABLE_TYPES
compact = [e for e in entries if e.get("isCompactSummary")]
assert len(compact) == 1
def test_regular_summary_still_stripped(self):
"""Non-compact summaries should still be stripped."""
regular_summary = {
"type": "summary",
"uuid": "rs1",
"summary": "Session summary",
}
content = _make_jsonl(regular_summary, USER_MSG, ASST_MSG)
stripped = strip_progress_entries(content)
entries = [json.loads(line) for line in stripped.strip().split("\n")]
types = [e.get("type") for e in entries]
assert "summary" not in types
def test_replace_entries_preserves_compact_summary(self):
"""replace_entries should preserve isCompactSummary entries."""
builder = TranscriptBuilder()
builder.append_user("old")
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
builder.replace_entries(content)
assert builder.entry_count == 3
# Verify by re-exporting
exported = builder.to_jsonl()
first = json.loads(exported.strip().split("\n")[0])
assert first.get("isCompactSummary") is True
# --- _flatten_assistant_content ---
class TestFlattenAssistantContent:
def test_text_blocks(self):
blocks = [
{"type": "text", "text": "Hello"},
{"type": "text", "text": "World"},
]
assert _flatten_assistant_content(blocks) == "Hello\nWorld"
def test_tool_use_blocks(self):
blocks = [{"type": "tool_use", "name": "read_file", "id": "t1", "input": {}}]
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
def test_mixed_blocks(self):
blocks = [
{"type": "text", "text": "Let me read that."},
{"type": "tool_use", "name": "read", "id": "t1", "input": {}},
]
result = _flatten_assistant_content(blocks)
assert "Let me read that." in result
assert "[tool_use: read]" in result
def test_string_blocks(self):
"""Plain strings in the list should be included."""
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
def test_empty_list(self):
assert _flatten_assistant_content([]) == ""
def test_tool_use_missing_name(self):
blocks = [{"type": "tool_use", "id": "t1", "input": {}}]
assert _flatten_assistant_content(blocks) == "[tool_use: ?]"
# --- _flatten_tool_result_content ---
class TestFlattenToolResultContent:
def test_tool_result_with_text(self):
blocks = [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": [{"type": "text", "text": "file contents here"}],
}
]
assert _flatten_tool_result_content(blocks) == "file contents here"
def test_tool_result_with_string_content(self):
blocks = [
{"type": "tool_result", "tool_use_id": "t1", "content": "simple result"}
]
assert _flatten_tool_result_content(blocks) == "simple result"
def test_tool_result_with_nested_list(self):
blocks = [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": [
{"type": "text", "text": "line 1"},
{"type": "text", "text": "line 2"},
],
}
]
assert _flatten_tool_result_content(blocks) == "line 1\nline 2"
def test_text_blocks(self):
blocks = [{"type": "text", "text": "some text"}]
assert _flatten_tool_result_content(blocks) == "some text"
def test_string_items(self):
assert _flatten_tool_result_content(["raw string"]) == "raw string"
def test_empty_list(self):
assert _flatten_tool_result_content([]) == ""
def test_tool_result_none_text_uses_json(self):
"""Dicts without text key fall back to json.dumps."""
blocks = [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": [{"type": "image", "source": "data:..."}],
}
]
result = _flatten_tool_result_content(blocks)
assert "image" in result # json.dumps fallback includes the key
# --- _transcript_to_messages ---
class TestTranscriptToMessages:
def test_basic_conversation(self):
content = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hello"},
},
{
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "hi there"}],
},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 2
assert msgs[0] == {"role": "user", "content": "hello"}
assert msgs[1] == {"role": "assistant", "content": "hi there"}
def test_strips_progress_entries(self):
content = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
{
"type": "progress",
"uuid": "p1",
"message": {"role": "user", "content": "..."},
},
{
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "ok"}],
},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 2
assert msgs[0]["role"] == "user"
assert msgs[1]["role"] == "assistant"
def test_preserves_compact_summaries(self):
content = _make_jsonl(
{
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {"role": "user", "content": "Summary of previous..."},
},
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 2
assert msgs[0]["content"] == "Summary of previous..."
def test_strips_regular_summary(self):
content = _make_jsonl(
{
"type": "summary",
"uuid": "s1",
"message": {"role": "user", "content": "Session summary"},
},
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 1
assert msgs[0]["content"] == "hi"
def test_skips_entries_without_role(self):
content = _make_jsonl(
{"type": "user", "uuid": "u1", "message": {}},
{
"type": "user",
"uuid": "u2",
"message": {"role": "user", "content": "hi"},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 1
def test_tool_result_content(self):
content = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": "file contents",
}
],
},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 1
assert "file contents" in msgs[0]["content"]
def test_empty_content(self):
assert _transcript_to_messages("") == []
assert _transcript_to_messages(" \n ") == []
def test_invalid_json_lines_skipped(self):
content = '{"type":"user","uuid":"u1","message":{"role":"user","content":"hi"}}\nnot json\n'
msgs = _transcript_to_messages(content)
assert len(msgs) == 1
# --- _messages_to_transcript ---
class TestMessagesToTranscript:
def test_basic_roundtrip_structure(self):
messages = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi there"},
]
result = _messages_to_transcript(messages)
assert result.endswith("\n")
lines = [json.loads(line) for line in result.strip().split("\n")]
assert len(lines) == 2
# User entry
assert lines[0]["type"] == "user"
assert lines[0]["message"]["role"] == "user"
assert lines[0]["message"]["content"] == "hello"
assert lines[0]["parentUuid"] is None
# Assistant entry
assert lines[1]["type"] == "assistant"
assert lines[1]["message"]["role"] == "assistant"
assert lines[1]["message"]["content"] == [{"type": "text", "text": "hi there"}]
assert lines[1]["message"]["id"].startswith(COMPACT_MSG_ID_PREFIX)
assert lines[1]["parentUuid"] == lines[0]["uuid"]
def test_parent_uuid_chain(self):
messages = [
{"role": "user", "content": "q1"},
{"role": "assistant", "content": "a1"},
{"role": "user", "content": "q2"},
]
result = _messages_to_transcript(messages)
lines = [json.loads(line) for line in result.strip().split("\n")]
assert lines[0]["parentUuid"] is None
assert lines[1]["parentUuid"] == lines[0]["uuid"]
assert lines[2]["parentUuid"] == lines[1]["uuid"]
def test_empty_messages(self):
assert _messages_to_transcript([]) == ""
def test_assistant_empty_content(self):
messages = [{"role": "assistant", "content": ""}]
result = _messages_to_transcript(messages)
entry = json.loads(result.strip())
assert entry["message"]["content"] == []
def test_output_is_valid_transcript(self):
messages = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "world"},
]
result = _messages_to_transcript(messages)
assert validate_transcript(result)
# --- _transcript_to_messages + _messages_to_transcript roundtrip ---
class TestTranscriptCompactionRoundtrip:
def test_content_preserved_through_roundtrip(self):
"""Messages→transcript→messages preserves content."""
original = [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
{"role": "user", "content": "Thanks"},
]
transcript = _messages_to_transcript(original)
recovered = _transcript_to_messages(transcript)
assert len(recovered) == len(original)
for orig, rec in zip(original, recovered):
assert orig["role"] == rec["role"]
assert orig["content"] == rec["content"]
def test_full_transcript_to_messages_and_back(self):
"""Real-ish JSONL → messages → transcript → messages roundtrip."""
source = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "explain python"},
},
{
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {
"role": "assistant",
"content": [
{"type": "text", "text": "Python is a programming language."}
],
},
},
{
"type": "user",
"uuid": "u2",
"parentUuid": "a1",
"message": {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": "output of ls",
}
],
},
},
)
msgs1 = _transcript_to_messages(source)
assert len(msgs1) == 3
rebuilt = _messages_to_transcript(msgs1)
msgs2 = _transcript_to_messages(rebuilt)
assert len(msgs2) == len(msgs1)
for m1, m2 in zip(msgs1, msgs2):
assert m1["role"] == m2["role"]
# Content may differ in format (list vs string) but text is preserved
assert m1["content"] in m2["content"] or m2["content"] in m1["content"]
# --- compact_transcript ---
class TestCompactTranscript:
@pytest.mark.asyncio
async def test_too_few_messages_returns_none(self):
"""Transcripts with < 2 messages can't be compacted."""
single = _make_jsonl(
{"type": "user", "uuid": "u1", "message": {"role": "user", "content": "hi"}}
)
result = await compact_transcript(single)
assert result is None
@pytest.mark.asyncio
async def test_empty_transcript_returns_none(self):
result = await compact_transcript("")
assert result is None
@pytest.mark.asyncio
async def test_compaction_produces_valid_transcript(self, monkeypatch):
"""When compress_context compacts, result should be valid JSONL."""
from unittest.mock import AsyncMock
from backend.util.prompt import CompressResult
mock_result = CompressResult(
messages=[
{"role": "user", "content": "Summary of conversation"},
{"role": "assistant", "content": "Acknowledged"},
],
token_count=50,
was_compacted=True,
original_token_count=5000,
messages_summarized=10,
messages_dropped=5,
)
monkeypatch.setattr(
"backend.copilot.sdk.transcript._run_compression",
AsyncMock(return_value=mock_result),
)
source = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "msg1"},
},
{
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "reply1"}],
},
},
{
"type": "user",
"uuid": "u2",
"message": {"role": "user", "content": "msg2"},
},
)
result = await compact_transcript(source)
assert result is not None
assert validate_transcript(result)
# Verify compacted content
msgs = _transcript_to_messages(result)
assert len(msgs) == 2
assert msgs[0]["content"] == "Summary of conversation"
@pytest.mark.asyncio
async def test_no_compaction_needed_returns_original(self, monkeypatch):
"""When compress_context says no compaction needed, return original."""
from unittest.mock import AsyncMock
from backend.util.prompt import CompressResult
mock_result = CompressResult(
messages=[], token_count=100, was_compacted=False, original_token_count=100
)
monkeypatch.setattr(
"backend.copilot.sdk.transcript._run_compression",
AsyncMock(return_value=mock_result),
)
source = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
{
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "hello"}],
},
},
)
result = await compact_transcript(source)
assert result == source # Unchanged
@pytest.mark.asyncio
async def test_compression_failure_returns_none(self, monkeypatch):
"""When _run_compression raises, compact_transcript returns None."""
from unittest.mock import AsyncMock
monkeypatch.setattr(
"backend.copilot.sdk.transcript._run_compression",
AsyncMock(side_effect=RuntimeError("LLM unavailable")),
)
source = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
{
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "hello"}],
},
},
)
result = await compact_transcript(source)
assert result is None

View File

@@ -23,11 +23,6 @@ from typing import Any, Literal
import orjson
from backend.api.model import CopilotCompletionPayload
from backend.data.notification_bus import (
AsyncRedisNotificationEventBus,
NotificationEvent,
)
from backend.data.redis_client import get_redis_async
from .config import ChatConfig
@@ -43,7 +38,6 @@ from .response_model import (
logger = logging.getLogger(__name__)
config = ChatConfig()
_notification_bus = AsyncRedisNotificationEventBus()
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
_local_sessions: dict[str, asyncio.Task] = {}
@@ -751,29 +745,6 @@ async def mark_session_completed(
# Clean up local session reference if exists
_local_sessions.pop(session_id, None)
# Publish copilot completion notification via WebSocket
if meta:
parsed = _parse_session_meta(meta, session_id)
if parsed.user_id:
try:
await _notification_bus.publish(
NotificationEvent(
user_id=parsed.user_id,
payload=CopilotCompletionPayload(
type="copilot_completion",
event="session_completed",
session_id=session_id,
status=status,
),
)
)
except Exception as e:
logger.warning(
f"Failed to publish copilot completion notification "
f"for session {session_id}: {e}"
)
return True

View File

@@ -12,7 +12,6 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .continue_run_block import ContinueRunBlockTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
@@ -69,7 +68,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"move_agents_to_folder": MoveAgentsToFolderTool(),
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"continue_run_block": ContinueRunBlockTool(),
"run_mcp_tool": RunMCPToolTool(),
"get_mcp_guide": GetMCPGuideTool(),
"view_agent_output": AgentOutputTool(),

View File

@@ -829,12 +829,8 @@ class AgentFixer:
For nodes whose block has category "AI", this function ensures that the
input_default has a "model" parameter set to one of the allowed models.
If missing or set to an unsupported value, it is replaced with the
appropriate default.
Blocks that define their own ``enum`` constraint on the ``model`` field
in their inputSchema (e.g. PerplexityBlock) are validated against that
enum instead of the generic allowed set.
If missing or set to an unsupported value, it is replaced with
default_model.
Args:
agent: The agent dictionary to fix
@@ -844,7 +840,7 @@ class AgentFixer:
Returns:
The fixed agent dictionary
"""
generic_allowed_models = {"gpt-4o", "claude-opus-4-6"}
allowed_models = {"gpt-4o", "claude-opus-4-6"}
# Create a mapping of block_id to block for quick lookup
block_map = {block.get("id"): block for block in blocks}
@@ -872,36 +868,20 @@ class AgentFixer:
input_default = node.get("input_default", {})
current_model = input_default.get("model")
# Determine allowed models and default from the block's schema.
# Blocks with a block-specific enum on the model field (e.g.
# PerplexityBlock) use their own enum values; others use the
# generic set.
model_schema = (
block.get("inputSchema", {}).get("properties", {}).get("model", {})
)
block_model_enum = model_schema.get("enum")
if block_model_enum:
allowed_models = set(block_model_enum)
fallback_model = model_schema.get("default", block_model_enum[0])
else:
allowed_models = generic_allowed_models
fallback_model = default_model
if current_model not in allowed_models:
block_name = block.get("name", "Unknown AI Block")
if current_model is None:
self.add_fix_log(
f"Added model parameter '{fallback_model}' to AI "
f"Added model parameter '{default_model}' to AI "
f"block node {node_id} ({block_name})"
)
else:
self.add_fix_log(
f"Replaced unsupported model '{current_model}' "
f"with '{fallback_model}' on AI block node "
f"with '{default_model}' on AI block node "
f"{node_id} ({block_name})"
)
input_default["model"] = fallback_model
input_default["model"] = default_model
node["input_default"] = input_default
fixed_count += 1

View File

@@ -475,111 +475,6 @@ class TestFixAiModelParameter:
assert result["nodes"][0]["input_default"]["model"] == "claude-opus-4-6"
def test_block_specific_enum_uses_block_default(self):
"""Blocks with their own model enum (e.g. PerplexityBlock) should use
the block's allowed models and default, not the generic ones."""
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=block_id,
input_default={"model": "gpt-5.2-2025-12-11"},
)
agent = _make_agent(nodes=[node])
blocks = [
{
"id": block_id,
"name": "PerplexityBlock",
"categories": [{"category": "AI"}],
"inputSchema": {
"properties": {
"model": {
"type": "string",
"enum": [
"perplexity/sonar",
"perplexity/sonar-pro",
"perplexity/sonar-deep-research",
],
"default": "perplexity/sonar",
}
},
},
}
]
result = fixer.fix_ai_model_parameter(agent, blocks)
assert result["nodes"][0]["input_default"]["model"] == "perplexity/sonar"
def test_block_specific_enum_valid_model_unchanged(self):
"""A valid block-specific model should not be replaced."""
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=block_id,
input_default={"model": "perplexity/sonar-pro"},
)
agent = _make_agent(nodes=[node])
blocks = [
{
"id": block_id,
"name": "PerplexityBlock",
"categories": [{"category": "AI"}],
"inputSchema": {
"properties": {
"model": {
"type": "string",
"enum": [
"perplexity/sonar",
"perplexity/sonar-pro",
"perplexity/sonar-deep-research",
],
"default": "perplexity/sonar",
}
},
},
}
]
result = fixer.fix_ai_model_parameter(agent, blocks)
assert result["nodes"][0]["input_default"]["model"] == "perplexity/sonar-pro"
def test_block_specific_enum_missing_model_gets_block_default(self):
"""Missing model on a block with enum should use the block's default."""
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(node_id="n1", block_id=block_id, input_default={})
agent = _make_agent(nodes=[node])
blocks = [
{
"id": block_id,
"name": "PerplexityBlock",
"categories": [{"category": "AI"}],
"inputSchema": {
"properties": {
"model": {
"type": "string",
"enum": [
"perplexity/sonar",
"perplexity/sonar-pro",
"perplexity/sonar-deep-research",
],
"default": "perplexity/sonar",
}
},
},
}
]
result = fixer.fix_ai_model_parameter(agent, blocks)
assert result["nodes"][0]["input_default"]["model"] == "perplexity/sonar"
class TestFixAgentExecutorBlocks:
"""Tests for fix_agent_executor_blocks."""

View File

@@ -1,157 +0,0 @@
"""Tool for continuing block execution after human review approval."""
import logging
from typing import Any
from prisma.enums import ReviewStatus
from backend.blocks import get_block
from backend.copilot.constants import (
COPILOT_NODE_PREFIX,
COPILOT_SESSION_PREFIX,
parse_node_id_from_exec_id,
)
from backend.copilot.model import ChatSession
from backend.data.db_accessors import review_db
from .base import BaseTool
from .helpers import execute_block, resolve_block_credentials
from .models import ErrorResponse, ToolResponseBase
logger = logging.getLogger(__name__)
class ContinueRunBlockTool(BaseTool):
"""Tool for continuing a block execution after human review approval."""
@property
def name(self) -> str:
return "continue_run_block"
@property
def description(self) -> str:
return (
"Continue executing a block after human review approval. "
"Use this after a run_block call returned review_required. "
"Pass the review_id from the review_required response. "
"The block will execute with the original pre-approved input data."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"review_id": {
"type": "string",
"description": (
"The review_id from a previous review_required response. "
"This resumes execution with the pre-approved input data."
),
},
},
"required": ["review_id"],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
review_id = (
kwargs.get("review_id", "").strip() if kwargs.get("review_id") else ""
)
session_id = session.session_id
if not review_id:
return ErrorResponse(
message="Please provide a review_id", session_id=session_id
)
if not user_id:
return ErrorResponse(
message="Authentication required", session_id=session_id
)
# Look up and validate the review record via adapter
reviews = await review_db().get_reviews_by_node_exec_ids([review_id], user_id)
review = reviews.get(review_id)
if not review:
return ErrorResponse(
message=(
f"Review '{review_id}' not found or already executed. "
"It may have been consumed by a previous continue_run_block call."
),
session_id=session_id,
)
# Validate the review belongs to this session
expected_graph_exec_id = f"{COPILOT_SESSION_PREFIX}{session_id}"
if review.graph_exec_id != expected_graph_exec_id:
return ErrorResponse(
message="Review does not belong to this session.",
session_id=session_id,
)
if review.status == ReviewStatus.WAITING:
return ErrorResponse(
message="Review has not been approved yet. "
"Please wait for the user to approve the review first.",
session_id=session_id,
)
if review.status == ReviewStatus.REJECTED:
return ErrorResponse(
message="Review was rejected. The block will not execute.",
session_id=session_id,
)
# Extract block_id from review_id: copilot-node-{block_id}:{random_hex}
block_id = parse_node_id_from_exec_id(review_id).removeprefix(
COPILOT_NODE_PREFIX
)
block = get_block(block_id)
if not block:
return ErrorResponse(
message=f"Block '{block_id}' not found", session_id=session_id
)
input_data: dict[str, Any] = (
review.payload if isinstance(review.payload, dict) else {}
)
logger.info(
f"Continuing block {block.name} ({block_id}) for user {user_id} "
f"with review_id={review_id}"
)
matched_creds, missing_creds = await resolve_block_credentials(
user_id, block, input_data
)
if missing_creds:
return ErrorResponse(
message=f"Block '{block.name}' requires credentials that are not configured.",
session_id=session_id,
)
result = await execute_block(
block=block,
block_id=block_id,
input_data=input_data,
user_id=user_id,
session_id=session_id,
node_exec_id=review_id,
matched_credentials=matched_creds,
)
# Delete review record after successful execution (one-time use)
if result.type != "error":
await review_db().delete_review_by_node_exec_id(review_id, user_id)
return result

View File

@@ -1,186 +0,0 @@
"""Tests for ContinueRunBlockTool."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from prisma.enums import ReviewStatus
from ._test_data import make_session
from .continue_run_block import ContinueRunBlockTool
from .models import BlockOutputResponse, ErrorResponse
_TEST_USER_ID = "test-user-continue"
def _make_review_model(
node_exec_id: str,
status: ReviewStatus = ReviewStatus.APPROVED,
payload: dict | None = None,
graph_exec_id: str = "",
):
"""Create a mock PendingHumanReviewModel."""
mock = MagicMock()
mock.node_exec_id = node_exec_id
mock.status = status
mock.payload = payload or {"text": "hello"}
mock.graph_exec_id = graph_exec_id
return mock
class TestContinueRunBlock:
@pytest.mark.asyncio(loop_scope="session")
async def test_missing_review_id_returns_error(self):
tool = ContinueRunBlockTool()
session = make_session(user_id=_TEST_USER_ID)
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
review_id="",
)
assert isinstance(response, ErrorResponse)
assert "review_id" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_review_not_found_returns_error(self):
tool = ContinueRunBlockTool()
session = make_session(user_id=_TEST_USER_ID)
mock_db = MagicMock()
mock_db.get_reviews_by_node_exec_ids = AsyncMock(return_value={})
with patch(
"backend.copilot.tools.continue_run_block.review_db",
return_value=mock_db,
):
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
review_id="copilot-node-some-block:abc12345",
)
assert isinstance(response, ErrorResponse)
assert "not found" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_waiting_review_returns_error(self):
tool = ContinueRunBlockTool()
session = make_session(user_id=_TEST_USER_ID)
review_id = "copilot-node-some-block:abc12345"
graph_exec_id = f"copilot-session-{session.session_id}"
review = _make_review_model(
review_id, status=ReviewStatus.WAITING, graph_exec_id=graph_exec_id
)
mock_db = MagicMock()
mock_db.get_reviews_by_node_exec_ids = AsyncMock(
return_value={review_id: review}
)
with patch(
"backend.copilot.tools.continue_run_block.review_db",
return_value=mock_db,
):
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
review_id=review_id,
)
assert isinstance(response, ErrorResponse)
assert "not been approved" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_rejected_review_returns_error(self):
tool = ContinueRunBlockTool()
session = make_session(user_id=_TEST_USER_ID)
review_id = "copilot-node-some-block:abc12345"
graph_exec_id = f"copilot-session-{session.session_id}"
review = _make_review_model(
review_id, status=ReviewStatus.REJECTED, graph_exec_id=graph_exec_id
)
mock_db = MagicMock()
mock_db.get_reviews_by_node_exec_ids = AsyncMock(
return_value={review_id: review}
)
with patch(
"backend.copilot.tools.continue_run_block.review_db",
return_value=mock_db,
):
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
review_id=review_id,
)
assert isinstance(response, ErrorResponse)
assert "rejected" in response.message.lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_approved_review_executes_block(self):
tool = ContinueRunBlockTool()
session = make_session(user_id=_TEST_USER_ID)
review_id = "copilot-node-delete-branch-id:abc12345"
graph_exec_id = f"copilot-session-{session.session_id}"
input_data = {"repo_url": "https://github.com/test/repo", "branch": "main"}
review = _make_review_model(
review_id,
status=ReviewStatus.APPROVED,
payload=input_data,
graph_exec_id=graph_exec_id,
)
mock_block = MagicMock()
mock_block.name = "Delete Branch"
async def mock_execute(data, **kwargs):
yield "result", "Branch deleted"
mock_block.execute = mock_execute
mock_block.input_schema.get_credentials_fields_info.return_value = []
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="test-workspace-id")
)
mock_db = MagicMock()
mock_db.get_reviews_by_node_exec_ids = AsyncMock(
return_value={review_id: review}
)
mock_db.delete_review_by_node_exec_id = AsyncMock(return_value=1)
with (
patch(
"backend.copilot.tools.continue_run_block.review_db",
return_value=mock_db,
),
patch(
"backend.copilot.tools.continue_run_block.get_block",
return_value=mock_block,
),
patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
):
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
review_id=review_id,
)
assert isinstance(response, BlockOutputResponse)
assert response.success is True
assert response.block_name == "Delete Branch"
# Verify review was deleted (one-time use)
mock_db.delete_review_by_node_exec_id.assert_called_once_with(
review_id, _TEST_USER_ID
)

View File

@@ -21,11 +21,9 @@ Lifecycle
Cost control
------------
Sandboxes are created with a configurable ``on_timeout`` lifecycle action
(default: ``"pause"``) and ``auto_resume`` (default: ``True``). The explicit
per-turn ``pause_sandbox()`` call is the primary mechanism; the lifecycle
timeout is a safety net (default: 5 min). ``auto_resume`` ensures that paused
sandboxes wake transparently on SDK activity, making the aggressive safety-net
timeout safe. Paused sandboxes are free.
(default: ``"pause"``). The explicit per-turn ``pause_sandbox()`` call is the
primary mechanism; the lifecycle setting is a safety net. Paused sandboxes are
free.
The sandbox_id is stored in Redis. The same key doubles as a creation lock:
a ``"creating"`` sentinel value is written with a short TTL while a new sandbox
@@ -42,7 +40,6 @@ import logging
from typing import Any, Awaitable, Callable, Literal
from e2b import AsyncSandbox
from e2b.sandbox.sandbox_api import SandboxLifecycle
from backend.data.redis_client import get_redis_async
@@ -119,10 +116,9 @@ async def get_or_create_sandbox(
removes the need for a separate lock key.
*timeout* controls how long the e2b sandbox may run continuously before
the ``on_timeout`` lifecycle rule fires (default: 5 min).
the ``on_timeout`` lifecycle rule fires (default: 3 h).
*on_timeout* controls what happens on timeout: ``"pause"`` (default, free)
or ``"kill"``. When ``"pause"``, ``auto_resume`` is enabled so paused
sandboxes wake transparently on SDK activity.
or ``"kill"``.
"""
redis = await get_redis_async()
key = _sandbox_key(session_id)
@@ -160,15 +156,11 @@ async def get_or_create_sandbox(
# We hold the slot — create the sandbox.
try:
lifecycle = SandboxLifecycle(
on_timeout=on_timeout,
auto_resume=on_timeout == "pause",
)
sandbox = await AsyncSandbox.create(
template=template,
api_key=api_key,
timeout=timeout,
lifecycle=lifecycle,
lifecycle={"on_timeout": on_timeout},
)
try:
await _set_stored_sandbox_id(session_id, sandbox.sandbox_id)

View File

@@ -157,17 +157,14 @@ class TestGetOrCreateSandbox:
assert result is new_sb
mock_cls.create.assert_awaited_once()
# Verify lifecycle: pause + auto_resume enabled
# Verify lifecycle param is set
_, kwargs = mock_cls.create.call_args
assert kwargs.get("lifecycle") == {
"on_timeout": "pause",
"auto_resume": True,
}
assert kwargs.get("lifecycle") == {"on_timeout": "pause"}
# sandbox_id should be saved to Redis
redis.set.assert_awaited()
def test_create_with_on_timeout_kill(self):
"""on_timeout='kill' disables auto_resume automatically."""
"""on_timeout='kill' is passed through to AsyncSandbox.create."""
new_sb = _mock_sandbox("sb-new")
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
with (
@@ -182,10 +179,7 @@ class TestGetOrCreateSandbox:
)
_, kwargs = mock_cls.create.call_args
assert kwargs.get("lifecycle") == {
"on_timeout": "kill",
"auto_resume": False,
}
assert kwargs.get("lifecycle") == {"on_timeout": "kill"}
def test_create_failure_releases_slot(self):
"""If sandbox creation fails, the Redis creation slot is deleted."""

View File

@@ -1,49 +1,7 @@
"""Shared helpers for chat tools."""
import logging
from collections import defaultdict
from typing import Any
from pydantic_core import PydanticUndefined
from backend.blocks._base import AnyBlockSchema
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.db_accessors import workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.executor.utils import block_usage_cost
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.clients import get_database_manager_async_client
from backend.util.exceptions import BlockError, InsufficientBalanceError
from backend.util.type import coerce_inputs_to_schema
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
from .utils import match_credentials_to_requirements
logger = logging.getLogger(__name__)
async def _get_credits(user_id: str) -> int:
"""Get user credits using the adapter pattern (RPC when Prisma unavailable)."""
if not db.is_connected():
return await get_database_manager_async_client().get_credits(user_id)
credit_model = await get_user_credit_model(user_id)
return await credit_model.get_credits(user_id)
async def _spend_credits(
user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
"""Spend user credits using the adapter pattern (RPC when Prisma unavailable)."""
if not db.is_connected():
return await get_database_manager_async_client().spend_credits(
user_id, cost, metadata
)
credit_model = await get_user_credit_model(user_id)
return await credit_model.spend_credits(user_id, cost, metadata)
def get_inputs_from_schema(
input_schema: dict[str, Any],
@@ -69,207 +27,3 @@ def get_inputs_from_schema(
for name, schema in properties.items()
if name not in exclude
]
async def execute_block(
*,
block: AnyBlockSchema,
block_id: str,
input_data: dict[str, Any],
user_id: str,
session_id: str,
node_exec_id: str,
matched_credentials: dict[str, CredentialsMetaInput],
sensitive_action_safe_mode: bool = False,
) -> ToolResponseBase:
"""Execute a block with full context setup, credential injection, and error handling.
This is the shared execution path used by both ``run_block`` (after review
check) and ``continue_run_block`` (after approval).
Returns:
BlockOutputResponse on success, ErrorResponse on failure.
"""
try:
workspace = await workspace_db().get_or_create_workspace(user_id)
synthetic_graph_id = f"{COPILOT_SESSION_PREFIX}{session_id}"
synthetic_node_id = f"{COPILOT_NODE_PREFIX}{block_id}"
execution_context = ExecutionContext(
user_id=user_id,
graph_id=synthetic_graph_id,
graph_exec_id=synthetic_graph_id,
graph_version=1,
node_id=synthetic_node_id,
node_exec_id=node_exec_id,
workspace_id=workspace.id,
session_id=session_id,
sensitive_action_safe_mode=sensitive_action_safe_mode,
)
exec_kwargs: dict[str, Any] = {
"user_id": user_id,
"execution_context": execution_context,
"workspace_id": workspace.id,
"graph_exec_id": synthetic_graph_id,
"node_exec_id": node_exec_id,
"node_id": synthetic_node_id,
"graph_version": 1,
"graph_id": synthetic_graph_id,
}
# Inject credentials
creds_manager = IntegrationCredentialsManager()
for field_name, cred_meta in matched_credentials.items():
if field_name not in input_data:
input_data[field_name] = cred_meta.model_dump()
actual_credentials = await creds_manager.get(
user_id, cred_meta.id, lock=False
)
if actual_credentials:
exec_kwargs[field_name] = actual_credentials
else:
return ErrorResponse(
message=f"Failed to retrieve credentials for {field_name}",
session_id=session_id,
)
# Coerce non-matching data types to the expected input schema.
coerce_inputs_to_schema(input_data, block.input_schema)
# Pre-execution credit check
cost, cost_filter = block_usage_cost(block, input_data)
has_cost = cost > 0
if has_cost:
balance = await _get_credits(user_id)
if balance < cost:
return ErrorResponse(
message=(
f"Insufficient credits to run '{block.name}'. "
"Please top up your credits to continue."
),
session_id=session_id,
)
# Execute the block and collect outputs
outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in block.execute(
input_data,
**exec_kwargs,
):
outputs[output_name].append(output_data)
# Charge credits for block execution
if has_cost:
try:
await _spend_credits(
user_id=user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=synthetic_graph_id,
graph_id=synthetic_graph_id,
node_id=synthetic_node_id,
node_exec_id=node_exec_id,
block_id=block_id,
block=block.name,
input=cost_filter,
reason="copilot_block_execution",
),
)
except InsufficientBalanceError:
logger.warning(
"Post-exec credit charge failed for block %s (cost=%d)",
block.name,
cost,
)
return ErrorResponse(
message=(
f"Insufficient credits to complete '{block.name}'. "
"Please top up your credits to continue."
),
session_id=session_id,
)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
block_id=block_id,
block_name=block.name,
outputs=dict(outputs),
success=True,
session_id=session_id,
)
except BlockError as e:
logger.warning("Block execution failed: %s", e)
return ErrorResponse(
message=f"Block execution failed: {e}",
error=str(e),
session_id=session_id,
)
except Exception as e:
logger.error("Unexpected error executing block: %s", e, exc_info=True)
return ErrorResponse(
message="An unexpected error occurred while executing the block",
error=str(e),
session_id=session_id,
)
async def resolve_block_credentials(
user_id: str,
block: AnyBlockSchema,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""Resolve credentials for a block by matching user's available credentials.
Handles discriminated credentials (e.g. provider selection based on model).
Returns:
(matched_credentials, missing_credentials)
"""
input_data = input_data or {}
requirements = _resolve_discriminated_credentials(block, input_data)
if not requirements:
return {}, []
return await match_credentials_to_requirements(user_id, requirements)
def _resolve_discriminated_credentials(
block: AnyBlockSchema,
input_data: dict[str, Any],
) -> dict[str, CredentialsFieldInfo]:
"""Resolve credential requirements, applying discriminator logic where needed."""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return {}
resolved: dict[str, CredentialsFieldInfo] = {}
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(field_info.discriminator)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
resolved[field_name] = effective_field_info
return resolved

View File

@@ -1,506 +0,0 @@
"""Tests for execute_block — credit charging and type coercion."""
from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks._base import BlockType
from backend.copilot.tools.helpers import execute_block
from backend.copilot.tools.models import BlockOutputResponse, ErrorResponse
_USER = "test-user-helpers"
_SESSION = "test-session-helpers"
def _make_block(block_id: str = "block-1", name: str = "TestBlock"):
"""Create a minimal mock block for execute_block()."""
mock = MagicMock()
mock.id = block_id
mock.name = name
mock.block_type = BlockType.STANDARD
mock.input_schema = MagicMock()
mock.input_schema.get_credentials_fields_info.return_value = {}
async def _execute(
input_data: dict, **kwargs: Any
) -> AsyncIterator[tuple[str, Any]]:
yield "result", "ok"
mock.execute = _execute
return mock
def _patch_workspace():
"""Patch workspace_db to return a mock workspace."""
mock_workspace = MagicMock()
mock_workspace.id = "ws-1"
mock_ws_db = MagicMock()
mock_ws_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
return patch("backend.copilot.tools.helpers.workspace_db", return_value=mock_ws_db)
# ---------------------------------------------------------------------------
# Credit charging tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestExecuteBlockCreditCharging:
async def test_charges_credits_when_cost_is_positive(self):
"""Block with cost > 0 should call spend_credits after execution."""
block = _make_block()
mock_spend = AsyncMock()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {"key": "val"}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
new_callable=AsyncMock,
return_value=100,
),
patch(
"backend.copilot.tools.helpers._spend_credits",
new_callable=AsyncMock,
side_effect=mock_spend,
),
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={"text": "hello"},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, BlockOutputResponse)
assert result.success is True
mock_spend.assert_awaited_once()
call_kwargs = mock_spend.call_args.kwargs
assert call_kwargs["cost"] == 10
assert call_kwargs["metadata"].reason == "copilot_block_execution"
async def test_returns_error_when_insufficient_credits_before_exec(self):
"""Pre-execution check should return ErrorResponse when balance < cost."""
block = _make_block()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
new_callable=AsyncMock,
return_value=5, # balance < cost (10)
),
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, ErrorResponse)
assert "Insufficient credits" in result.message
async def test_no_charge_when_cost_is_zero(self):
"""Block with cost 0 should not call spend_credits."""
block = _make_block()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(0, {}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
) as mock_get_credits,
patch(
"backend.copilot.tools.helpers._spend_credits",
) as mock_spend_credits,
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, BlockOutputResponse)
assert result.success is True
# Credit functions should not be called at all for zero-cost blocks
mock_get_credits.assert_not_awaited()
mock_spend_credits.assert_not_awaited()
async def test_returns_error_on_post_exec_insufficient_balance(self):
"""If charging fails after execution, return ErrorResponse."""
from backend.util.exceptions import InsufficientBalanceError
block = _make_block()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
new_callable=AsyncMock,
return_value=15, # passes pre-check
),
patch(
"backend.copilot.tools.helpers._spend_credits",
new_callable=AsyncMock,
side_effect=InsufficientBalanceError(
"Low balance", _USER, 5, 10
), # fails during actual charge (race with concurrent spend)
),
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, ErrorResponse)
assert "Insufficient credits" in result.message
# ---------------------------------------------------------------------------
# Type coercion tests
# ---------------------------------------------------------------------------
def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
"""Create a mock input_schema with model_fields matching the given annotations."""
schema = MagicMock()
model_fields = {}
for name, ann in annotations.items():
field = MagicMock()
field.annotation = ann
model_fields[name] = field
schema.model_fields = model_fields
return schema
def _make_coerce_block(
block_id: str,
name: str,
annotations: dict[str, Any],
outputs: dict[str, list[Any]] | None = None,
) -> MagicMock:
"""Create a mock block with typed annotations and a simple execute method."""
block = MagicMock()
block.id = block_id
block.name = name
block.input_schema = _make_block_schema(annotations)
captured_inputs: dict[str, Any] = {}
async def mock_execute(input_data: dict, **_kwargs: Any):
captured_inputs.update(input_data)
for output_name, values in (outputs or {"result": ["ok"]}).items():
for v in values:
yield output_name, v
block.execute = mock_execute
block._captured_inputs = captured_inputs
return block
_TEST_SESSION_ID = "test-session-coerce"
_TEST_USER_ID = "test-user-coerce"
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_nested_list():
"""JSON string → list[list[str]] (Google Sheets CSV import case)."""
block = _make_coerce_block(
"sheets-write",
"Google Sheets Write",
{"values": list[list[str]], "spreadsheet_id": str},
)
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
):
response = await execute_block(
block=block,
block_id="sheets-write",
input_data={
"values": '[["Name","Score"],["Alice","90"],["Bob","85"]]',
"spreadsheet_id": "abc123",
},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(response, BlockOutputResponse)
assert response.success is True
assert block._captured_inputs["values"] == [
["Name", "Score"],
["Alice", "90"],
["Bob", "85"],
]
assert isinstance(block._captured_inputs["values"], list)
assert isinstance(block._captured_inputs["values"][0], list)
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_list():
"""JSON string → list[str]."""
block = _make_coerce_block(
"list-block",
"List Block",
{"items": list[str]},
)
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
):
response = await execute_block(
block=block,
block_id="list-block",
input_data={"items": '["a","b","c"]'},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
node_exec_id="exec-2",
matched_credentials={},
)
assert isinstance(response, BlockOutputResponse)
assert block._captured_inputs["items"] == ["a", "b", "c"]
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_dict():
"""JSON string → dict[str, str]."""
block = _make_coerce_block(
"dict-block",
"Dict Block",
{"config": dict[str, str]},
)
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
):
response = await execute_block(
block=block,
block_id="dict-block",
input_data={"config": '{"key": "value", "foo": "bar"}'},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
node_exec_id="exec-3",
matched_credentials={},
)
assert isinstance(response, BlockOutputResponse)
assert block._captured_inputs["config"] == {"key": "value", "foo": "bar"}
@pytest.mark.asyncio(loop_scope="session")
async def test_no_coercion_when_type_matches():
"""Already-correct types pass through without coercion."""
block = _make_coerce_block(
"pass-through",
"Pass Through",
{"values": list[list[str]], "name": str},
)
original_values = [["a", "b"], ["c", "d"]]
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
):
response = await execute_block(
block=block,
block_id="pass-through",
input_data={"values": original_values, "name": "test"},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
node_exec_id="exec-4",
matched_credentials={},
)
assert isinstance(response, BlockOutputResponse)
assert block._captured_inputs["values"] == original_values
assert block._captured_inputs["name"] == "test"
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_string_to_int():
"""String number → int."""
block = _make_coerce_block(
"int-block",
"Int Block",
{"count": int},
)
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
):
response = await execute_block(
block=block,
block_id="int-block",
input_data={"count": "42"},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
node_exec_id="exec-5",
matched_credentials={},
)
assert isinstance(response, BlockOutputResponse)
assert block._captured_inputs["count"] == 42
assert isinstance(block._captured_inputs["count"], int)
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_skips_none_values():
"""None values are not coerced (they may be optional fields)."""
block = _make_coerce_block(
"optional-block",
"Optional Block",
{"data": list[str], "label": str},
)
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
):
response = await execute_block(
block=block,
block_id="optional-block",
input_data={"label": "test"},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
node_exec_id="exec-6",
matched_credentials={},
)
assert isinstance(response, BlockOutputResponse)
assert "data" not in block._captured_inputs
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_union_type_preserves_valid_member():
"""Union-typed fields should not be coerced when the value matches a member."""
block = _make_coerce_block(
"union-block",
"Union Block",
{"content": str | list[str]},
)
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
):
response = await execute_block(
block=block,
block_id="union-block",
input_data={"content": ["a", "b"]},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
node_exec_id="exec-7",
matched_credentials={},
)
assert isinstance(response, BlockOutputResponse)
assert block._captured_inputs["content"] == ["a", "b"]
assert isinstance(block._captured_inputs["content"], list)
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_inner_elements_of_generic():
"""Inner elements of generic containers are recursively coerced."""
block = _make_coerce_block(
"inner-coerce",
"Inner Coerce",
{"values": list[str]},
)
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
):
response = await execute_block(
block=block,
block_id="inner-coerce",
input_data={"values": [1, 2, 3]},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
node_exec_id="exec-8",
matched_credentials={},
)
assert isinstance(response, BlockOutputResponse)
assert block._captured_inputs["values"] == ["1", "2", "3"]
assert all(isinstance(v, str) for v in block._captured_inputs["values"])

View File

@@ -39,7 +39,6 @@ class ResponseType(str, Enum):
BLOCK_LIST = "block_list"
BLOCK_DETAILS = "block_details"
BLOCK_OUTPUT = "block_output"
REVIEW_REQUIRED = "review_required"
# MCP
MCP_GUIDE = "mcp_guide"
@@ -459,21 +458,6 @@ class BlockOutputResponse(ToolResponseBase):
success: bool = True
class ReviewRequiredResponse(ToolResponseBase):
"""Response when a block requires human review before execution."""
type: ResponseType = ResponseType.REVIEW_REQUIRED
block_id: str
block_name: str
review_id: str = Field(description="The review ID for tracking approval status")
graph_exec_id: str = Field(
description="The graph execution ID for fetching review status"
)
input_data: dict[str, Any] = Field(
description="The input data that requires review"
)
class WebFetchResponse(ToolResponseBase):
"""Response for web_fetch tool."""

View File

@@ -534,9 +534,7 @@ class RunAgentTool(BaseTool):
return ExecutionStartedResponse(
message=(
f"Agent '{library_agent.name}' is awaiting human review. "
f"The user can approve or reject inline. After approval, "
f"the execution resumes automatically. Use view_agent_output "
f"with execution_id='{execution.id}' to check the result."
f"Check at {library_agent_link}."
),
session_id=session_id,
execution_id=execution.id,

View File

@@ -2,34 +2,38 @@
import logging
import uuid
from collections import defaultdict
from typing import Any
from pydantic_core import PydanticUndefined
from backend.blocks import BlockType, get_block
from backend.blocks._base import AnyBlockSchema
from backend.copilot.constants import (
COPILOT_NODE_EXEC_ID_SEPARATOR,
COPILOT_NODE_PREFIX,
COPILOT_SESSION_PREFIX,
)
from backend.copilot.model import ChatSession
from backend.data.db_accessors import review_db
from backend.data.db_accessors import workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from .base import BaseTool
from .find_block import COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES
from .helpers import execute_block, get_inputs_from_schema, resolve_block_credentials
from .helpers import get_inputs_from_schema
from .models import (
BlockDetails,
BlockDetailsResponse,
BlockOutputResponse,
ErrorResponse,
InputValidationErrorResponse,
ReviewRequiredResponse,
SetupInfo,
SetupRequirementsResponse,
ToolResponseBase,
UserReadiness,
)
from .utils import build_missing_credentials_from_field_info
from .utils import (
build_missing_credentials_from_field_info,
match_credentials_to_requirements,
)
logger = logging.getLogger(__name__)
@@ -48,9 +52,7 @@ class RunBlockTool(BaseTool):
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
"do NOT guess or make up block IDs. "
"On first attempt (without input_data), returns detailed schema showing "
"required inputs and outputs. Then call again with proper input_data to execute. "
"If a block requires human review, use continue_run_block with the "
"review_id after the user approves."
"required inputs and outputs. Then call again with proper input_data to execute."
)
@property
@@ -164,10 +166,11 @@ class RunBlockTool(BaseTool):
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
creds_manager = IntegrationCredentialsManager()
(
matched_credentials,
missing_credentials,
) = await resolve_block_credentials(user_id, block, input_data)
) = await self._resolve_block_credentials(user_id, block, input_data)
# Get block schemas for details/validation
try:
@@ -276,97 +279,169 @@ class RunBlockTool(BaseTool):
user_authenticated=True,
)
# Generate synthetic IDs for CoPilot context.
# Encode node_id in node_exec_id so it can be extracted later
# (e.g. for auto-approve, where we need node_id but have no NodeExecution row).
synthetic_graph_id = f"{COPILOT_SESSION_PREFIX}{session.session_id}"
synthetic_node_id = f"{COPILOT_NODE_PREFIX}{block_id}"
try:
# Get or create user's workspace for CoPilot file operations
workspace = await workspace_db().get_or_create_workspace(user_id)
# Check for an existing WAITING review for this block with the same input.
# If the LLM retries run_block with identical input, we reuse the existing
# review instead of creating duplicates. Different inputs = new execution.
existing_reviews = await review_db().get_pending_reviews_for_execution(
synthetic_graph_id, user_id
)
existing_review = next(
(
r
for r in existing_reviews
if r.node_id == synthetic_node_id
and r.status.value == "WAITING"
and r.payload == input_data
),
None,
)
if existing_review:
return ReviewRequiredResponse(
message=(
f"Block '{block.name}' requires human review. "
f"After the user approves, call continue_run_block with "
f"review_id='{existing_review.node_exec_id}' to execute."
),
session_id=session_id,
block_id=block_id,
block_name=block.name,
review_id=existing_review.node_exec_id,
graph_exec_id=synthetic_graph_id,
input_data=input_data,
# Generate synthetic IDs for CoPilot context
# Each chat session is treated as its own agent with one continuous run
# This means:
# - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
# - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
# - node_exec_id = unique per block execution
synthetic_graph_id = f"copilot-session-{session.session_id}"
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
synthetic_node_id = f"copilot-node-{block_id}"
synthetic_node_exec_id = (
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
)
synthetic_node_exec_id = (
f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}"
f"{uuid.uuid4().hex[:8]}"
)
# Check for HITL review before execution.
# This creates the review record in the DB for CoPilot flows.
review_context = ExecutionContext(
user_id=user_id,
graph_id=synthetic_graph_id,
graph_exec_id=synthetic_graph_id,
graph_version=1,
node_id=synthetic_node_id,
node_exec_id=synthetic_node_exec_id,
sensitive_action_safe_mode=True,
)
should_pause, input_data = await block.is_block_exec_need_review(
input_data,
user_id=user_id,
node_id=synthetic_node_id,
node_exec_id=synthetic_node_exec_id,
graph_exec_id=synthetic_graph_id,
graph_id=synthetic_graph_id,
graph_version=1,
execution_context=review_context,
is_graph_execution=False,
)
if should_pause:
return ReviewRequiredResponse(
message=(
f"Block '{block.name}' requires human review. "
f"After the user approves, call continue_run_block with "
f"review_id='{synthetic_node_exec_id}' to execute."
),
session_id=session_id,
block_id=block_id,
block_name=block.name,
review_id=synthetic_node_exec_id,
graph_exec_id=synthetic_graph_id,
input_data=input_data,
# Create unified execution context with all required fields
execution_context = ExecutionContext(
# Execution identity
user_id=user_id,
graph_id=synthetic_graph_id,
graph_exec_id=synthetic_graph_exec_id,
graph_version=1, # Versions are 1-indexed
node_id=synthetic_node_id,
node_exec_id=synthetic_node_exec_id,
# Workspace with session scoping
workspace_id=workspace.id,
session_id=session.session_id,
)
return await execute_block(
block=block,
block_id=block_id,
input_data=input_data,
user_id=user_id,
session_id=session_id,
node_exec_id=synthetic_node_exec_id,
matched_credentials=matched_credentials,
)
# Prepare kwargs for block execution
# Keep individual kwargs for backwards compatibility with existing blocks
exec_kwargs: dict[str, Any] = {
"user_id": user_id,
"execution_context": execution_context,
# Legacy: individual kwargs for blocks not yet using execution_context
"workspace_id": workspace.id,
"graph_exec_id": synthetic_graph_exec_id,
"node_exec_id": synthetic_node_exec_id,
"node_id": synthetic_node_id,
"graph_version": 1, # Versions are 1-indexed
"graph_id": synthetic_graph_id,
}
for field_name, cred_meta in matched_credentials.items():
# Inject metadata into input_data (for validation)
if field_name not in input_data:
input_data[field_name] = cred_meta.model_dump()
# Fetch actual credentials and pass as kwargs (for execution)
actual_credentials = await creds_manager.get(
user_id, cred_meta.id, lock=False
)
if actual_credentials:
exec_kwargs[field_name] = actual_credentials
else:
return ErrorResponse(
message=f"Failed to retrieve credentials for {field_name}",
session_id=session_id,
)
# Execute the block and collect outputs
outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in block.execute(
input_data,
**exec_kwargs,
):
outputs[output_name].append(output_data)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
block_id=block_id,
block_name=block.name,
outputs=dict(outputs),
success=True,
session_id=session_id,
)
except BlockError as e:
logger.warning(f"Block execution failed: {e}")
return ErrorResponse(
message=f"Block execution failed: {e}",
error=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to execute block: {str(e)}",
error=str(e),
session_id=session_id,
)
async def _resolve_block_credentials(
self,
user_id: str,
block: AnyBlockSchema,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Resolve credentials for a block by matching user's available credentials.
Args:
user_id: User ID
block: Block to resolve credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple of (matched_credentials, missing_credentials) - matched credentials
are used for block execution, missing ones indicate setup requirements.
"""
input_data = input_data or {}
requirements = self._resolve_discriminated_credentials(block, input_data)
if not requirements:
return {}, []
return await match_credentials_to_requirements(user_id, requirements)
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
"""Extract non-credential inputs from block schema."""
schema = block.input_schema.jsonschema()
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
def _resolve_discriminated_credentials(
self,
block: AnyBlockSchema,
input_data: dict[str, Any],
) -> dict[str, CredentialsFieldInfo]:
"""Resolve credential requirements, applying discriminator logic where needed."""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return {}
resolved: dict[str, CredentialsFieldInfo] = {}
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
# For host-scoped credentials, add the discriminator value
# (e.g., URL) so _credential_is_for_host can match it
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
resolved[field_name] = effective_field_info
return resolved

View File

@@ -12,7 +12,6 @@ from .models import (
BlockOutputResponse,
ErrorResponse,
InputValidationErrorResponse,
ReviewRequiredResponse,
)
from .run_block import RunBlockTool
@@ -28,16 +27,9 @@ def make_mock_block(
mock.name = name
mock.block_type = block_type
mock.disabled = disabled
mock.is_sensitive_action = False
mock.input_schema = MagicMock()
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
mock.input_schema.get_credentials_fields_info.return_value = {}
mock.input_schema.get_credentials_fields.return_value = {}
async def _no_review(input_data, **kwargs):
return False, input_data
mock.is_block_exec_need_review = _no_review
mock.input_schema.get_credentials_fields_info.return_value = []
return mock
@@ -54,7 +46,6 @@ def make_mock_block_with_schema(
mock.name = name
mock.block_type = BlockType.STANDARD
mock.disabled = False
mock.is_sensitive_action = False
mock.description = f"Test block: {name}"
input_schema = {
@@ -72,12 +63,6 @@ def make_mock_block_with_schema(
mock.output_schema = MagicMock()
mock.output_schema.jsonschema.return_value = output_schema
# Default: no review needed, pass through input_data unchanged
async def _no_review(input_data, **kwargs):
return False, input_data
mock.is_block_exec_need_review = _no_review
return mock
@@ -141,15 +126,9 @@ class TestRunBlockFiltering:
"standard-id", "HTTP Request", BlockType.STANDARD
)
with (
patch(
"backend.copilot.tools.run_block.get_block",
return_value=standard_block,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
with patch(
"backend.copilot.tools.run_block.get_block",
return_value=standard_block,
):
tool = RunBlockTool()
response = await tool._execute(
@@ -175,7 +154,12 @@ class TestRunBlockInputValidation:
@pytest.mark.asyncio(loop_scope="session")
async def test_unknown_input_fields_are_rejected(self):
"""run_block rejects unknown input fields instead of silently ignoring them."""
"""run_block rejects unknown input fields instead of silently ignoring them.
Scenario: The AI Text Generator block has a field called 'model' (for LLM model
selection), but the LLM calling the tool guesses wrong and sends 'LLM_Model'
instead. The block should reject the request and return the valid schema.
"""
session = make_session(user_id=_TEST_USER_ID)
mock_block = make_mock_block_with_schema(
@@ -198,31 +182,27 @@ class TestRunBlockInputValidation:
output_properties={"response": {"type": "string"}},
)
with (
patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
with patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
):
tool = RunBlockTool()
# Provide 'prompt' (correct) but 'LLM_Model' instead of 'model' (wrong key)
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="ai-text-gen-id",
input_data={
"prompt": "Write a haiku about coding",
"LLM_Model": "claude-opus-4-6",
"LLM_Model": "claude-opus-4-6", # WRONG KEY - should be 'model'
},
)
assert isinstance(response, InputValidationErrorResponse)
assert "LLM_Model" in response.unrecognized_fields
assert "Block was not executed" in response.message
assert "inputs" in response.model_dump()
assert "inputs" in response.model_dump() # valid schema included
@pytest.mark.asyncio(loop_scope="session")
async def test_multiple_wrong_keys_are_all_reported(self):
@@ -241,26 +221,21 @@ class TestRunBlockInputValidation:
required_fields=["prompt"],
)
with (
patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
with patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="ai-text-gen-id",
input_data={
"prompt": "Hello",
"llm_model": "claude-opus-4-6",
"system_prompt": "Be helpful",
"retries": 5,
"prompt": "Hello", # correct
"llm_model": "claude-opus-4-6", # WRONG - should be 'model'
"system_prompt": "Be helpful", # WRONG - should be 'sys_prompt'
"retries": 5, # WRONG - should be 'retry'
},
)
@@ -287,26 +262,23 @@ class TestRunBlockInputValidation:
required_fields=["prompt"],
)
with (
patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
with patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
):
tool = RunBlockTool()
# 'prompt' is missing AND 'LLM_Model' is an unknown field
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="ai-text-gen-id",
input_data={
"LLM_Model": "claude-opus-4-6",
"LLM_Model": "claude-opus-4-6", # wrong key, and 'prompt' is missing
},
)
# Unknown fields are caught first
assert isinstance(response, InputValidationErrorResponse)
assert "LLM_Model" in response.unrecognized_fields
@@ -341,11 +313,7 @@ class TestRunBlockInputValidation:
return_value=mock_block,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
patch(
"backend.copilot.tools.helpers.workspace_db",
"backend.copilot.tools.run_block.workspace_db",
return_value=mock_workspace_db,
),
):
@@ -357,7 +325,7 @@ class TestRunBlockInputValidation:
block_id="ai-text-gen-id",
input_data={
"prompt": "Write a haiku",
"model": "gpt-4o-mini",
"model": "gpt-4o-mini", # correct field name
},
)
@@ -379,191 +347,20 @@ class TestRunBlockInputValidation:
required_fields=["prompt"],
)
with (
patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
with patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
):
tool = RunBlockTool()
# Only provide valid optional field, missing required 'prompt'
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="ai-text-gen-id",
input_data={
"model": "gpt-4o-mini",
"model": "gpt-4o-mini", # valid but optional
},
)
assert isinstance(response, BlockDetailsResponse)
class TestRunBlockSensitiveAction:
"""Tests for sensitive action HITL review in RunBlockTool.
run_block calls is_block_exec_need_review() explicitly before execution.
When review is needed (should_pause=True), ReviewRequiredResponse is returned.
"""
@pytest.mark.asyncio(loop_scope="session")
async def test_sensitive_block_paused_returns_review_required(self):
"""When is_block_exec_need_review returns should_pause=True, ReviewRequiredResponse is returned."""
session = make_session(user_id=_TEST_USER_ID)
input_data = {
"repo_url": "https://github.com/test/repo",
"branch": "feature-branch",
}
mock_block = make_mock_block_with_schema(
block_id="delete-branch-id",
name="Delete Branch",
input_properties={
"repo_url": {"type": "string"},
"branch": {"type": "string"},
},
required_fields=["repo_url", "branch"],
)
mock_block.is_sensitive_action = True
mock_block.is_block_exec_need_review = AsyncMock(
return_value=(True, input_data)
)
with (
patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="delete-branch-id",
input_data=input_data,
)
assert isinstance(response, ReviewRequiredResponse)
assert "requires human review" in response.message
assert "continue_run_block" in response.message
assert response.block_name == "Delete Branch"
@pytest.mark.asyncio(loop_scope="session")
async def test_sensitive_block_executes_after_approval(self):
"""After approval (should_pause=False), sensitive blocks execute and return outputs."""
session = make_session(user_id=_TEST_USER_ID)
input_data = {
"repo_url": "https://github.com/test/repo",
"branch": "feature-branch",
}
mock_block = make_mock_block_with_schema(
block_id="delete-branch-id",
name="Delete Branch",
input_properties={
"repo_url": {"type": "string"},
"branch": {"type": "string"},
},
required_fields=["repo_url", "branch"],
)
mock_block.is_sensitive_action = True
mock_block.is_block_exec_need_review = AsyncMock(
return_value=(False, input_data)
)
async def mock_execute(input_data, **kwargs):
yield "result", "Branch deleted successfully"
mock_block.execute = mock_execute
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="test-workspace-id")
)
with (
patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
),
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="delete-branch-id",
input_data=input_data,
)
assert isinstance(response, BlockOutputResponse)
assert response.success is True
@pytest.mark.asyncio(loop_scope="session")
async def test_non_sensitive_block_executes_normally(self):
"""Non-sensitive blocks skip review and execute directly."""
session = make_session(user_id=_TEST_USER_ID)
input_data = {"url": "https://example.com"}
mock_block = make_mock_block_with_schema(
block_id="http-request-id",
name="HTTP Request",
input_properties={
"url": {"type": "string"},
},
required_fields=["url"],
)
mock_block.is_sensitive_action = False
mock_block.is_block_exec_need_review = AsyncMock(
return_value=(False, input_data)
)
async def mock_execute(input_data, **kwargs):
yield "response", {"status": 200}
mock_block.execute = mock_execute
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="test-workspace-id")
)
with (
patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
),
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="http-request-id",
input_data=input_data,
)
assert isinstance(response, BlockOutputResponse)
assert response.success is True

View File

@@ -34,11 +34,6 @@ logger = logging.getLogger(__name__)
_AUTH_STATUS_CODES = {401, 403}
def _service_name(host: str) -> str:
"""Strip the 'mcp.' prefix from an MCP hostname: 'mcp.sentry.dev''sentry.dev'"""
return host[4:] if host.startswith("mcp.") else host
class RunMCPToolTool(BaseTool):
"""
Tool for discovering and executing tools on any MCP server.
@@ -308,8 +303,8 @@ class RunMCPToolTool(BaseTool):
)
return ErrorResponse(
message=(
f"Unable to connect to {_service_name(server_host(server_url))} "
" no credentials configured."
f"The MCP server at {server_host(server_url)} requires authentication, "
"but no credential configuration was found."
),
session_id=session_id,
)
@@ -317,13 +312,15 @@ class RunMCPToolTool(BaseTool):
missing_creds_list = list(missing_creds_dict.values())
host = server_host(server_url)
service = _service_name(host)
return SetupRequirementsResponse(
message=(f"To continue, sign in to {service} and approve access."),
message=(
f"The MCP server at {host} requires authentication. "
"Please connect your credentials to continue."
),
session_id=session_id,
setup_info=SetupInfo(
agent_id=server_url,
agent_name=service,
agent_name=f"MCP: {host}",
user_readiness=UserReadiness(
has_all_credentials=False,
missing_credentials=missing_creds_dict,

View File

@@ -65,8 +65,9 @@ async def test_run_block_returns_details_when_no_input_provided():
return_value=http_block,
):
# Mock credentials check to return no missing credentials
with patch(
"backend.copilot.tools.run_block.resolve_block_credentials",
with patch.object(
RunBlockTool,
"_resolve_block_credentials",
new_callable=AsyncMock,
return_value=({}, []), # (matched_credentials, missing_credentials)
):
@@ -122,8 +123,9 @@ async def test_run_block_returns_details_when_only_credentials_provided():
"backend.copilot.tools.run_block.get_block",
return_value=mock,
):
with patch(
"backend.copilot.tools.run_block.resolve_block_credentials",
with patch.object(
RunBlockTool,
"_resolve_block_credentials",
new_callable=AsyncMock,
return_value=(
{

View File

@@ -756,4 +756,4 @@ async def test_build_setup_requirements_returns_setup_response():
)
assert isinstance(result, SetupRequirementsResponse)
assert result.setup_info.agent_id == _SERVER_URL
assert "sign in" in result.message.lower()
assert "authentication" in result.message.lower()

View File

@@ -100,31 +100,19 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.OLLAMA_DOLPHIN: 1,
LlmModel.OPENAI_GPT_OSS_120B: 1,
LlmModel.OPENAI_GPT_OSS_20B: 1,
LlmModel.GEMINI_2_5_PRO_PREVIEW: 4,
LlmModel.GEMINI_2_5_PRO: 4,
LlmModel.GEMINI_3_1_PRO_PREVIEW: 5,
LlmModel.GEMINI_3_FLASH_PREVIEW: 2,
LlmModel.GEMINI_3_PRO_PREVIEW: 5,
LlmModel.GEMINI_2_5_FLASH: 1,
LlmModel.GEMINI_2_0_FLASH: 1,
LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW: 1,
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
LlmModel.MISTRAL_NEMO: 1,
LlmModel.MISTRAL_LARGE_3: 2,
LlmModel.MISTRAL_MEDIUM_3_1: 2,
LlmModel.MISTRAL_SMALL_3_2: 1,
LlmModel.CODESTRAL: 1,
LlmModel.COHERE_COMMAND_R_08_2024: 1,
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: 3,
LlmModel.COHERE_COMMAND_A_03_2025: 3,
LlmModel.COHERE_COMMAND_A_TRANSLATE_08_2025: 3,
LlmModel.COHERE_COMMAND_A_REASONING_08_2025: 6,
LlmModel.COHERE_COMMAND_A_VISION_07_2025: 3,
LlmModel.DEEPSEEK_CHAT: 2,
LlmModel.DEEPSEEK_R1_0528: 1,
LlmModel.PERPLEXITY_SONAR: 1,
LlmModel.PERPLEXITY_SONAR_PRO: 5,
LlmModel.PERPLEXITY_SONAR_REASONING_PRO: 5,
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
@@ -132,7 +120,6 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.AMAZON_NOVA_MICRO_V1: 1,
LlmModel.AMAZON_NOVA_PRO_V1: 1,
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
LlmModel.MICROSOFT_PHI_4: 1,
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
LlmModel.META_LLAMA_4_SCOUT: 1,
LlmModel.META_LLAMA_4_MAVERICK: 1,
@@ -140,7 +127,6 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.LLAMA_API_LLAMA4_MAVERICK: 1,
LlmModel.LLAMA_API_LLAMA3_3_8B: 1,
LlmModel.LLAMA_API_LLAMA3_3_70B: 1,
LlmModel.GROK_3: 3,
LlmModel.GROK_4: 9,
LlmModel.GROK_4_FAST: 1,
LlmModel.GROK_4_1_FAST: 1,

View File

@@ -116,16 +116,3 @@ def workspace_db():
workspace_db = get_database_manager_async_client()
return workspace_db
def review_db():
if db.is_connected():
from backend.data import human_review as _review_db
review_db = _review_db
else:
from backend.util.clients import get_database_manager_async_client
review_db = get_database_manager_async_client()
return review_db

View File

@@ -79,10 +79,7 @@ from backend.data.graph import (
from backend.data.human_review import (
cancel_pending_reviews_for_execution,
check_approval,
delete_review_by_node_exec_id,
get_or_create_human_review,
get_pending_reviews_for_execution,
get_reviews_by_node_exec_ids,
has_pending_reviews_for_graph_exec,
update_review_processed_status,
)
@@ -249,10 +246,7 @@ class DatabaseManager(AppService):
# ============ Human In The Loop ============ #
cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
check_approval = _(check_approval)
delete_review_by_node_exec_id = _(delete_review_by_node_exec_id)
get_or_create_human_review = _(get_or_create_human_review)
get_pending_reviews_for_execution = _(get_pending_reviews_for_execution)
get_reviews_by_node_exec_ids = _(get_reviews_by_node_exec_ids)
has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec)
update_review_processed_status = _(update_review_processed_status)
@@ -439,10 +433,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
# ============ Human In The Loop ============ #
cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
check_approval = d.check_approval
delete_review_by_node_exec_id = d.delete_review_by_node_exec_id
get_or_create_human_review = d.get_or_create_human_review
get_pending_reviews_for_execution = d.get_pending_reviews_for_execution
get_reviews_by_node_exec_ids = d.get_reviews_by_node_exec_ids
update_review_processed_status = d.update_review_processed_status
# ============ User Comms ============ #
@@ -512,10 +503,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
list_workspace_files = d.list_workspace_files
soft_delete_workspace_file = d.soft_delete_workspace_file
# ============ Credits ============ #
spend_credits = d.spend_credits
get_credits = d.get_credits
# ============ Understanding ============ #
get_business_understanding = d.get_business_understanding
upsert_business_understanding = d.upsert_business_understanding

View File

@@ -17,10 +17,6 @@ from backend.api.features.executions.review.model import (
PendingHumanReviewModel,
SafeJsonData,
)
from backend.copilot.constants import (
is_copilot_synthetic_id,
parse_node_id_from_exec_id,
)
from backend.data.execution import get_graph_execution_meta
from backend.util.json import SafeJson
@@ -127,13 +123,11 @@ async def create_auto_approval_record(
Raises:
ValueError: If the graph execution doesn't belong to the user
"""
# Validate ownership: if a graph execution record exists, it must belong
# to this user. Non-graph executions (e.g. CoPilot) won't have a record.
if not is_copilot_synthetic_id(
graph_exec_id
) and not await get_graph_execution_meta(
# Validate that the graph execution belongs to this user (defense in depth)
graph_exec = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
):
)
if not graph_exec:
raise ValueError(
f"Graph execution {graph_exec_id} not found or doesn't belong to user {user_id}"
)
@@ -271,7 +265,7 @@ async def get_pending_review_by_node_exec_id(
async def get_reviews_by_node_exec_ids(
node_exec_ids: list[str], user_id: str
) -> dict[str, PendingHumanReviewModel]:
) -> dict[str, "PendingHumanReviewModel"]:
"""
Get multiple reviews by their node execution IDs regardless of status.
@@ -298,26 +292,21 @@ async def get_reviews_by_node_exec_ids(
if not reviews:
return {}
# Split into synthetic (CoPilot) and real IDs for different resolution paths
synthetic_ids = {
r.nodeExecId for r in reviews if is_copilot_synthetic_id(r.nodeExecId)
}
real_ids = [r.nodeExecId for r in reviews if r.nodeExecId not in synthetic_ids]
# Batch fetch all node executions to avoid N+1 queries
node_exec_ids_to_fetch = [review.nodeExecId for review in reviews]
node_execs = await AgentNodeExecution.prisma().find_many(
where={"id": {"in": node_exec_ids_to_fetch}},
include={"Node": True},
)
# Batch fetch real node executions to avoid N+1 queries
node_exec_id_to_node_id: dict[str, str] = {}
if real_ids:
node_execs = await AgentNodeExecution.prisma().find_many(
where={"id": {"in": real_ids}},
)
node_exec_id_to_node_id = {ne.id: ne.agentNodeId for ne in node_execs}
# Create mapping from node_exec_id to node_id
node_exec_id_to_node_id = {
node_exec.id: node_exec.agentNodeId for node_exec in node_execs
}
result = {}
for review in reviews:
if review.nodeExecId in synthetic_ids:
node_id = parse_node_id_from_exec_id(review.nodeExecId)
else:
node_id = node_exec_id_to_node_id.get(review.nodeExecId, review.nodeExecId)
node_id = node_exec_id_to_node_id.get(review.nodeExecId, review.nodeExecId)
result[review.nodeExecId] = PendingHumanReviewModel.from_db(
review, node_id=node_id
)
@@ -342,19 +331,6 @@ async def has_pending_reviews_for_graph_exec(graph_exec_id: str) -> bool:
return count > 0
async def _resolve_node_id(node_exec_id: str, get_node_execution) -> str:
"""Resolve node_id from a node_exec_id.
For CoPilot synthetic IDs (e.g. copilot-node-block-id:abc12345),
extract the node_id portion (copilot-node-block-id).
For real graph executions, look up the NodeExecution record.
"""
if is_copilot_synthetic_id(node_exec_id):
return parse_node_id_from_exec_id(node_exec_id)
node_exec = await get_node_execution(node_exec_id)
return node_exec.node_id if node_exec else node_exec_id
async def get_pending_reviews_for_user(
user_id: str, page: int = 1, page_size: int = 25
) -> list["PendingHumanReviewModel"]:
@@ -385,7 +361,8 @@ async def get_pending_reviews_for_user(
# Fetch node_id for each review from NodeExecution
result = []
for review in reviews:
node_id = await _resolve_node_id(review.nodeExecId, get_node_execution)
node_exec = await get_node_execution(review.nodeExecId)
node_id = node_exec.node_id if node_exec else review.nodeExecId
result.append(PendingHumanReviewModel.from_db(review, node_id=node_id))
return result
@@ -393,7 +370,7 @@ async def get_pending_reviews_for_user(
async def get_pending_reviews_for_execution(
graph_exec_id: str, user_id: str
) -> list[PendingHumanReviewModel]:
) -> list["PendingHumanReviewModel"]:
"""
Get all pending reviews for a specific graph execution.
@@ -419,7 +396,8 @@ async def get_pending_reviews_for_execution(
# Fetch node_id for each review from NodeExecution
result = []
for review in reviews:
node_id = await _resolve_node_id(review.nodeExecId, get_node_execution)
node_exec = await get_node_execution(review.nodeExecId)
node_id = node_exec.node_id if node_exec else review.nodeExecId
result.append(PendingHumanReviewModel.from_db(review, node_id=node_id))
return result
@@ -531,12 +509,8 @@ async def process_all_reviews_for_execution(
result = {}
for review in all_result_reviews:
if is_copilot_synthetic_id(review.nodeExecId):
# CoPilot synthetic node_exec_ids encode node_id as "{node_id}:{random}"
node_id = parse_node_id_from_exec_id(review.nodeExecId)
else:
node_exec = await get_node_execution(review.nodeExecId)
node_id = node_exec.node_id if node_exec else review.nodeExecId
node_exec = await get_node_execution(review.nodeExecId)
node_id = node_exec.node_id if node_exec else review.nodeExecId
result[review.nodeExecId] = PendingHumanReviewModel.from_db(
review, node_id=node_id
)
@@ -590,21 +564,3 @@ async def cancel_pending_reviews_for_execution(graph_exec_id: str, user_id: str)
},
)
return result
async def delete_review_by_node_exec_id(node_exec_id: str, user_id: str) -> int:
"""Delete a review record by node execution ID after it has been consumed.
Used by CoPilot's continue_run_block to clean up one-time-use review records
after successful execution.
Args:
node_exec_id: The node execution ID of the review to delete
user_id: User ID for authorization
Returns:
Number of records deleted
"""
return await PendingHumanReview.prisma().delete_many(
where={"nodeExecId": node_exec_id, "userId": user_id}
)

View File

@@ -1,750 +0,0 @@
import asyncio
import csv
import io
import logging
import os
import re
import socket
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Literal, Optional
from uuid import uuid4
import prisma.enums
import prisma.models
import prisma.types
from prisma.errors import UniqueViolationError
from pydantic import BaseModel, EmailStr, TypeAdapter, ValidationError
from backend.data.db import transaction
from backend.data.model import User
from backend.data.redis_client import get_redis_async
from backend.data.tally import get_business_understanding_input_from_tally, mask_email
from backend.data.understanding import (
BusinessUnderstandingInput,
merge_business_understanding_data,
)
from backend.data.user import get_user_by_email, get_user_by_id
from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import (
NotAuthorizedError,
NotFoundError,
PreconditionFailed,
)
from backend.util.json import SafeJson
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
_settings = Settings()
_WORKER_ID = f"{socket.gethostname()}:{os.getpid()}"
_tally_seed_tasks: dict[str, asyncio.Task] = {}
_TALLY_STALE_SECONDS = 300
_MAX_TALLY_ERROR_LENGTH = 200
_email_adapter = TypeAdapter(EmailStr)
MAX_BULK_INVITE_FILE_BYTES = 1024 * 1024
MAX_BULK_INVITE_ROWS = 500
class InvitedUserRecord(BaseModel):
id: str
email: str
status: prisma.enums.InvitedUserStatus
auth_user_id: Optional[str] = None
name: Optional[str] = None
tally_understanding: Optional[dict[str, Any]] = None
tally_status: prisma.enums.TallyComputationStatus
tally_computed_at: Optional[datetime] = None
tally_error: Optional[str] = None
created_at: datetime
updated_at: datetime
@classmethod
def from_db(cls, invited_user: "prisma.models.InvitedUser") -> "InvitedUserRecord":
payload = (
invited_user.tallyUnderstanding
if isinstance(invited_user.tallyUnderstanding, dict)
else None
)
return cls(
id=invited_user.id,
email=invited_user.email,
status=invited_user.status,
auth_user_id=invited_user.authUserId,
name=invited_user.name,
tally_understanding=payload,
tally_status=invited_user.tallyStatus,
tally_computed_at=invited_user.tallyComputedAt,
tally_error=invited_user.tallyError,
created_at=invited_user.createdAt,
updated_at=invited_user.updatedAt,
)
class BulkInvitedUserRowResult(BaseModel):
row_number: int
email: Optional[str] = None
name: Optional[str] = None
status: Literal["CREATED", "SKIPPED", "ERROR"]
message: str
invited_user: Optional[InvitedUserRecord] = None
class BulkInvitedUsersResult(BaseModel):
created_count: int
skipped_count: int
error_count: int
results: list[BulkInvitedUserRowResult]
@dataclass(frozen=True)
class _ParsedInviteRow:
row_number: int
email: str
name: Optional[str]
def normalize_email(email: str) -> str:
return email.strip().lower()
def _normalize_name(name: Optional[str]) -> Optional[str]:
if name is None:
return None
normalized = name.strip()
return normalized or None
def _default_profile_name(email: str, preferred_name: Optional[str]) -> str:
if preferred_name:
return preferred_name
local_part = email.split("@", 1)[0].strip()
return local_part or "user"
def _sanitize_username_base(email: str) -> str:
local_part = email.split("@", 1)[0].lower()
sanitized = re.sub(r"[^a-z0-9-]", "", local_part)
sanitized = sanitized.strip("-")
return sanitized[:40] or "user"
async def _generate_unique_profile_username(email: str, tx) -> str:
base = _sanitize_username_base(email)
for _ in range(2):
candidate = f"{base}-{uuid4().hex[:6]}"
existing = await prisma.models.Profile.prisma(tx).find_unique(
where={"username": candidate}
)
if existing is None:
return candidate
raise RuntimeError(f"Unable to generate unique username for {email}")
async def _ensure_default_profile(
user_id: str,
email: str,
preferred_name: Optional[str],
tx,
) -> None:
existing_profile = await prisma.models.Profile.prisma(tx).find_unique(
where={"userId": user_id}
)
if existing_profile is not None:
return
username = await _generate_unique_profile_username(email, tx)
await prisma.models.Profile.prisma(tx).create(
data=prisma.types.ProfileCreateInput(
userId=user_id,
name=_default_profile_name(email, preferred_name),
username=username,
description="I'm new here",
links=[],
avatarUrl="",
)
)
async def _ensure_default_onboarding(user_id: str, tx) -> None:
await prisma.models.UserOnboarding.prisma(tx).upsert(
where={"userId": user_id},
data={
"create": prisma.types.UserOnboardingCreateInput(userId=user_id),
"update": {},
},
)
async def _apply_tally_understanding(
user_id: str,
invited_user: "prisma.models.InvitedUser",
tx,
) -> None:
if not isinstance(invited_user.tallyUnderstanding, dict):
return
try:
input_data = BusinessUnderstandingInput.model_validate(
invited_user.tallyUnderstanding
)
except Exception:
logger.warning(
"Malformed tallyUnderstanding for invited user %s; skipping",
invited_user.id,
exc_info=True,
)
return
payload = merge_business_understanding_data({}, input_data)
await prisma.models.CoPilotUnderstanding.prisma(tx).upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "data": SafeJson(payload)},
"update": {"data": SafeJson(payload)},
},
)
async def list_invited_users(
page: int = 1,
page_size: int = 50,
) -> tuple[list[InvitedUserRecord], int]:
total = await prisma.models.InvitedUser.prisma().count()
invited_users = await prisma.models.InvitedUser.prisma().find_many(
order={"createdAt": "desc"},
skip=(page - 1) * page_size,
take=page_size,
)
return [InvitedUserRecord.from_db(iu) for iu in invited_users], total
async def create_invited_user(
email: str, name: Optional[str] = None
) -> InvitedUserRecord:
normalized_email = normalize_email(email)
normalized_name = _normalize_name(name)
existing_user = await prisma.models.User.prisma().find_unique(
where={"email": normalized_email}
)
if existing_user is not None:
raise PreconditionFailed("An active user with this email already exists")
existing_invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"email": normalized_email}
)
if existing_invited_user is not None:
raise PreconditionFailed("An invited user with this email already exists")
try:
invited_user = await prisma.models.InvitedUser.prisma().create(
data={
"email": normalized_email,
"name": normalized_name,
"status": prisma.enums.InvitedUserStatus.INVITED,
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
}
)
except UniqueViolationError:
raise PreconditionFailed("An invited user with this email already exists")
schedule_invited_user_tally_precompute(invited_user.id)
return InvitedUserRecord.from_db(invited_user)
async def revoke_invited_user(invited_user_id: str) -> InvitedUserRecord:
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"id": invited_user_id}
)
if invited_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
if invited_user.status == prisma.enums.InvitedUserStatus.CLAIMED:
raise PreconditionFailed("Claimed invited users cannot be revoked")
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
return InvitedUserRecord.from_db(invited_user)
revoked_user = await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={"status": prisma.enums.InvitedUserStatus.REVOKED},
)
if revoked_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
return InvitedUserRecord.from_db(revoked_user)
async def retry_invited_user_tally(invited_user_id: str) -> InvitedUserRecord:
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"id": invited_user_id}
)
if invited_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
raise PreconditionFailed("Revoked invited users cannot retry Tally seeding")
refreshed_user = await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyUnderstanding": None,
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
"tallyComputedAt": None,
"tallyError": None,
},
)
if refreshed_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
schedule_invited_user_tally_precompute(invited_user_id)
return InvitedUserRecord.from_db(refreshed_user)
def _decode_bulk_invite_file(content: bytes) -> str:
if len(content) > MAX_BULK_INVITE_FILE_BYTES:
raise ValueError("Invite file exceeds the maximum size of 1 MB")
try:
return content.decode("utf-8-sig")
except UnicodeDecodeError as exc:
raise ValueError("Invite file must be UTF-8 encoded") from exc
def _parse_bulk_invite_csv(text: str) -> list[_ParsedInviteRow]:
indexed_rows: list[tuple[int, list[str]]] = []
for row_number, row in enumerate(csv.reader(io.StringIO(text)), start=1):
normalized_row = [cell.strip() for cell in row]
if any(normalized_row):
indexed_rows.append((row_number, normalized_row))
if not indexed_rows:
return []
header = [cell.lower() for cell in indexed_rows[0][1]]
has_header = "email" in header
email_index = header.index("email") if has_header else 0
name_index: Optional[int] = (
header.index("name")
if has_header and "name" in header
else (1 if not has_header else None)
)
data_rows = indexed_rows[1:] if has_header else indexed_rows
parsed_rows: list[_ParsedInviteRow] = []
for row_number, row in data_rows:
if len(parsed_rows) >= MAX_BULK_INVITE_ROWS:
break
email = row[email_index].strip() if len(row) > email_index else ""
name = (
row[name_index].strip()
if name_index is not None and len(row) > name_index
else ""
)
parsed_rows.append(
_ParsedInviteRow(
row_number=row_number,
email=email,
name=name or None,
)
)
return parsed_rows
def _parse_bulk_invite_text(text: str) -> list[_ParsedInviteRow]:
parsed_rows: list[_ParsedInviteRow] = []
for row_number, raw_line in enumerate(text.splitlines(), start=1):
if len(parsed_rows) >= MAX_BULK_INVITE_ROWS:
break
line = raw_line.strip()
if not line or line.startswith("#"):
continue
parsed_rows.append(
_ParsedInviteRow(
row_number=row_number,
email=line,
name=None,
)
)
return parsed_rows
def _parse_bulk_invite_file(
filename: Optional[str],
content: bytes,
) -> list[_ParsedInviteRow]:
text = _decode_bulk_invite_file(content)
file_name = filename.lower() if filename else ""
parsed_rows = (
_parse_bulk_invite_csv(text)
if file_name.endswith(".csv")
else _parse_bulk_invite_text(text)
)
if not parsed_rows:
raise ValueError("Invite file did not contain any emails")
return parsed_rows
async def bulk_create_invited_users_from_file(
filename: Optional[str],
content: bytes,
) -> BulkInvitedUsersResult:
parsed_rows = _parse_bulk_invite_file(filename, content)
created_count = 0
skipped_count = 0
error_count = 0
results: list[BulkInvitedUserRowResult] = []
seen_emails: set[str] = set()
for row in parsed_rows:
row_name = _normalize_name(row.name)
try:
validated_email = _email_adapter.validate_python(row.email)
except ValidationError:
error_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=row.email or None,
name=row_name,
status="ERROR",
message="Invalid email address",
)
)
continue
normalized_email = normalize_email(str(validated_email))
if normalized_email in seen_emails:
skipped_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="SKIPPED",
message="Duplicate email in upload file",
)
)
continue
seen_emails.add(normalized_email)
try:
invited_user = await create_invited_user(normalized_email, row_name)
except PreconditionFailed as exc:
skipped_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="SKIPPED",
message=str(exc),
)
)
except Exception:
masked = mask_email(normalized_email)
logger.exception(
"Failed to create bulk invite for row %s (%s)",
row.row_number,
masked,
)
error_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="ERROR",
message="Unexpected error creating invite",
)
)
else:
created_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="CREATED",
message="Invite created",
invited_user=invited_user,
)
)
return BulkInvitedUsersResult(
created_count=created_count,
skipped_count=skipped_count,
error_count=error_count,
results=results,
)
async def _compute_invited_user_tally_seed(invited_user_id: str) -> None:
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"id": invited_user_id}
)
if invited_user is None:
return
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
return
try:
r = await get_redis_async()
except Exception:
r = None
lock: AsyncClusterLock | None = None
if r is not None:
lock = AsyncClusterLock(
redis=r,
key=f"tally_seed:{invited_user_id}",
owner_id=_WORKER_ID,
timeout=_TALLY_STALE_SECONDS,
)
current_owner = await lock.try_acquire()
if current_owner is None:
logger.warn("Redis unvailable for tally lock - skipping tally enrichement")
return
elif current_owner != _WORKER_ID:
logger.debug(
"Tally seed for %s already locked by %s, skipping",
invited_user_id,
current_owner,
)
return
if (
invited_user.tallyStatus == prisma.enums.TallyComputationStatus.RUNNING
and invited_user.updatedAt is not None
):
age = (datetime.now(timezone.utc) - invited_user.updatedAt).total_seconds()
if age < _TALLY_STALE_SECONDS:
logger.debug(
"Tally task for %s still RUNNING (age=%ds), skipping",
invited_user_id,
int(age),
)
return
logger.info(
"Tally task for %s is stale (age=%ds), re-running",
invited_user_id,
int(age),
)
await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyStatus": prisma.enums.TallyComputationStatus.RUNNING,
"tallyError": None,
},
)
try:
input_data = await get_business_understanding_input_from_tally(
invited_user.email,
require_api_key=True,
)
payload = (
SafeJson(input_data.model_dump(exclude_none=True))
if input_data is not None
else None
)
await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyUnderstanding": payload,
"tallyStatus": prisma.enums.TallyComputationStatus.READY,
"tallyComputedAt": datetime.now(timezone.utc),
"tallyError": None,
},
)
except Exception as exc:
logger.exception(
"Failed to compute Tally understanding for invited user %s",
invited_user_id,
)
sanitized_error = re.sub(
r"https?://\S+", "<url>", f"{type(exc).__name__}: {exc}"
)[:_MAX_TALLY_ERROR_LENGTH]
await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyStatus": prisma.enums.TallyComputationStatus.FAILED,
"tallyError": sanitized_error,
},
)
def schedule_invited_user_tally_precompute(invited_user_id: str) -> None:
existing = _tally_seed_tasks.get(invited_user_id)
if existing is not None and not existing.done():
logger.debug("Tally task already running for %s, skipping", invited_user_id)
return
task = asyncio.create_task(_compute_invited_user_tally_seed(invited_user_id))
_tally_seed_tasks[invited_user_id] = task
def _on_done(t: asyncio.Task, _id: str = invited_user_id) -> None:
if _tally_seed_tasks.get(_id) is t:
del _tally_seed_tasks[_id]
task.add_done_callback(_on_done)
async def _open_signup_create_user(
auth_user_id: str,
normalized_email: str,
metadata_name: Optional[str],
) -> User:
"""Create a user without requiring an invite (open signup mode)."""
preferred_name = _normalize_name(metadata_name)
try:
async with transaction() as tx:
user = await prisma.models.User.prisma(tx).create(
data=prisma.types.UserCreateInput(
id=auth_user_id,
email=normalized_email,
name=preferred_name,
)
)
await _ensure_default_profile(
auth_user_id, normalized_email, preferred_name, tx
)
await _ensure_default_onboarding(auth_user_id, tx)
except UniqueViolationError:
existing = await prisma.models.User.prisma().find_unique(
where={"id": auth_user_id}
)
if existing is not None:
return User.from_db(existing)
raise
return User.from_db(user)
# TODO: We need to change this functions logic before going live
async def get_or_activate_user(user_data: dict) -> User:
auth_user_id = user_data.get("sub")
if not auth_user_id:
raise NotAuthorizedError("User ID not found in token")
auth_email = user_data.get("email")
if not auth_email:
raise NotAuthorizedError("Email not found in token")
normalized_email = normalize_email(auth_email)
user_metadata = user_data.get("user_metadata")
metadata_name = (
user_metadata.get("name") if isinstance(user_metadata, dict) else None
)
existing_user = None
try:
existing_user = await get_user_by_id(auth_user_id)
except ValueError:
existing_user = None
except Exception:
logger.exception("Error on get user by id during tally enrichment process")
raise
if existing_user is not None:
return existing_user
if not _settings.config.enable_invite_gate or normalized_email.endswith("@agpt.co"):
return await _open_signup_create_user(
auth_user_id, normalized_email, metadata_name
)
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"email": normalized_email}
)
if invited_user is None:
raise NotAuthorizedError("Your email is not allowed to access the platform")
if invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
raise NotAuthorizedError("Your invitation is no longer active")
try:
async with transaction() as tx:
current_user = await prisma.models.User.prisma(tx).find_unique(
where={"id": auth_user_id}
)
if current_user is not None:
return User.from_db(current_user)
current_invited_user = await prisma.models.InvitedUser.prisma(
tx
).find_unique(where={"email": normalized_email})
if current_invited_user is None:
raise NotAuthorizedError(
"Your email is not allowed to access the platform"
)
if current_invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
raise NotAuthorizedError("Your invitation is no longer active")
if current_invited_user.authUserId not in (None, auth_user_id):
raise NotAuthorizedError("Your invitation has already been claimed")
preferred_name = current_invited_user.name or _normalize_name(metadata_name)
await prisma.models.User.prisma(tx).create(
data=prisma.types.UserCreateInput(
id=auth_user_id,
email=normalized_email,
name=preferred_name,
)
)
await prisma.models.InvitedUser.prisma(tx).update(
where={"id": current_invited_user.id},
data={
"status": prisma.enums.InvitedUserStatus.CLAIMED,
"authUserId": auth_user_id,
},
)
await _ensure_default_profile(
auth_user_id,
normalized_email,
preferred_name,
tx,
)
await _ensure_default_onboarding(auth_user_id, tx)
await _apply_tally_understanding(auth_user_id, current_invited_user, tx)
except UniqueViolationError:
logger.info("Concurrent activation for user %s; re-fetching", auth_user_id)
already_created = await prisma.models.User.prisma().find_unique(
where={"id": auth_user_id}
)
if already_created is not None:
return User.from_db(already_created)
raise RuntimeError(
f"UniqueViolationError during activation but user {auth_user_id} not found"
)
get_user_by_id.cache_delete(auth_user_id)
get_user_by_email.cache_delete(normalized_email)
activated_user = await prisma.models.User.prisma().find_unique(
where={"id": auth_user_id}
)
if activated_user is None:
raise RuntimeError(
f"Activated user {auth_user_id} was not found after creation"
)
return User.from_db(activated_user)

View File

@@ -1,335 +0,0 @@
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from types import SimpleNamespace
from typing import cast
from unittest.mock import AsyncMock, Mock
import prisma.enums
import prisma.models
import pytest
import pytest_mock
from backend.util.exceptions import NotAuthorizedError, PreconditionFailed
from .invited_user import (
InvitedUserRecord,
bulk_create_invited_users_from_file,
create_invited_user,
get_or_activate_user,
retry_invited_user_tally,
)
def _invited_user_db_record(
*,
status: prisma.enums.InvitedUserStatus = prisma.enums.InvitedUserStatus.INVITED,
tally_understanding: dict | None = None,
):
now = datetime.now(timezone.utc)
return SimpleNamespace(
id="invite-1",
email="invited@example.com",
status=status,
authUserId=None,
name="Invited User",
tallyUnderstanding=tally_understanding,
tallyStatus=prisma.enums.TallyComputationStatus.PENDING,
tallyComputedAt=None,
tallyError=None,
createdAt=now,
updatedAt=now,
)
def _invited_user_record(
*,
status: prisma.enums.InvitedUserStatus = prisma.enums.InvitedUserStatus.INVITED,
tally_understanding: dict | None = None,
):
return InvitedUserRecord.from_db(
cast(
prisma.models.InvitedUser,
_invited_user_db_record(
status=status,
tally_understanding=tally_understanding,
),
)
)
def _user_db_record():
now = datetime.now(timezone.utc)
return SimpleNamespace(
id="auth-user-1",
email="invited@example.com",
emailVerified=True,
name="Invited User",
createdAt=now,
updatedAt=now,
metadata={},
integrations="",
stripeCustomerId=None,
topUpConfig=None,
maxEmailsPerDay=3,
notifyOnAgentRun=True,
notifyOnZeroBalance=True,
notifyOnLowBalance=True,
notifyOnBlockExecutionFailed=True,
notifyOnContinuousAgentError=True,
notifyOnDailySummary=True,
notifyOnWeeklySummary=True,
notifyOnMonthlySummary=True,
notifyOnAgentApproved=True,
notifyOnAgentRejected=True,
timezone="not-set",
)
@pytest.mark.asyncio
async def test_create_invited_user_rejects_existing_active_user(
mocker: pytest_mock.MockerFixture,
) -> None:
user_repo = Mock()
user_repo.find_unique = AsyncMock(return_value=_user_db_record())
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock()
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
with pytest.raises(PreconditionFailed):
await create_invited_user("Invited@example.com")
@pytest.mark.asyncio
async def test_create_invited_user_schedules_tally_seed(
mocker: pytest_mock.MockerFixture,
) -> None:
user_repo = Mock()
user_repo.find_unique = AsyncMock(return_value=None)
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock(return_value=None)
invited_user_repo.create = AsyncMock(return_value=_invited_user_db_record())
schedule = mocker.patch(
"backend.data.invited_user.schedule_invited_user_tally_precompute"
)
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
invited_user = await create_invited_user("Invited@example.com", "Invited User")
assert invited_user.email == "invited@example.com"
invited_user_repo.create.assert_awaited_once()
schedule.assert_called_once_with("invite-1")
@pytest.mark.asyncio
async def test_retry_invited_user_tally_resets_state_and_schedules(
mocker: pytest_mock.MockerFixture,
) -> None:
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock(return_value=_invited_user_db_record())
invited_user_repo.update = AsyncMock(return_value=_invited_user_db_record())
schedule = mocker.patch(
"backend.data.invited_user.schedule_invited_user_tally_precompute"
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
invited_user = await retry_invited_user_tally("invite-1")
assert invited_user.id == "invite-1"
invited_user_repo.update.assert_awaited_once()
schedule.assert_called_once_with("invite-1")
@pytest.mark.asyncio
async def test_get_or_activate_user_requires_invite(
mocker: pytest_mock.MockerFixture,
) -> None:
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock(return_value=None)
mock_get_user_by_id = AsyncMock(side_effect=ValueError("User not found"))
mock_get_user_by_id.cache_delete = Mock()
mocker.patch(
"backend.data.invited_user.get_user_by_id",
mock_get_user_by_id,
)
mocker.patch(
"backend.data.invited_user._settings.config.enable_invite_gate",
True,
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
with pytest.raises(NotAuthorizedError):
await get_or_activate_user(
{"sub": "auth-user-1", "email": "invited@example.com"}
)
@pytest.mark.asyncio
async def test_get_or_activate_user_creates_user_from_invite(
mocker: pytest_mock.MockerFixture,
) -> None:
tx = object()
invited_user = _invited_user_db_record(
tally_understanding={"user_name": "Invited User", "industry": "Automation"}
)
created_user = _user_db_record()
outside_user_repo = Mock()
# Only called once at post-transaction verification (line 741);
# get_user_by_id (line 657) uses prisma.user.find_unique, not this mock.
outside_user_repo.find_unique = AsyncMock(return_value=created_user)
inside_user_repo = Mock()
inside_user_repo.find_unique = AsyncMock(return_value=None)
inside_user_repo.create = AsyncMock(return_value=created_user)
outside_invited_repo = Mock()
outside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
inside_invited_repo = Mock()
inside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
inside_invited_repo.update = AsyncMock(return_value=invited_user)
def user_prisma(client=None):
return inside_user_repo if client is tx else outside_user_repo
def invited_user_prisma(client=None):
return inside_invited_repo if client is tx else outside_invited_repo
@asynccontextmanager
async def fake_transaction():
yield tx
# Mock get_user_by_id since it uses prisma.user.find_unique (global client),
# not prisma.models.User.prisma().find_unique which we mock above.
mock_get_user_by_id = AsyncMock(side_effect=ValueError("User not found"))
mock_get_user_by_id.cache_delete = Mock()
mocker.patch(
"backend.data.invited_user.get_user_by_id",
mock_get_user_by_id,
)
mock_get_user_by_email = AsyncMock()
mock_get_user_by_email.cache_delete = Mock()
mocker.patch(
"backend.data.invited_user.get_user_by_email",
mock_get_user_by_email,
)
ensure_profile = mocker.patch(
"backend.data.invited_user._ensure_default_profile",
AsyncMock(),
)
ensure_onboarding = mocker.patch(
"backend.data.invited_user._ensure_default_onboarding",
AsyncMock(),
)
apply_tally = mocker.patch(
"backend.data.invited_user._apply_tally_understanding",
AsyncMock(),
)
mocker.patch("backend.data.invited_user.transaction", fake_transaction)
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", side_effect=user_prisma
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
side_effect=invited_user_prisma,
)
user = await get_or_activate_user(
{
"sub": "auth-user-1",
"email": "Invited@example.com",
"user_metadata": {"name": "Invited User"},
}
)
assert user.id == "auth-user-1"
inside_user_repo.create.assert_awaited_once()
inside_invited_repo.update.assert_awaited_once()
ensure_profile.assert_awaited_once()
ensure_onboarding.assert_awaited_once_with("auth-user-1", tx)
apply_tally.assert_awaited_once_with("auth-user-1", invited_user, tx)
@pytest.mark.asyncio
async def test_bulk_create_invited_users_from_text_file(
mocker: pytest_mock.MockerFixture,
) -> None:
create_invited = mocker.patch(
"backend.data.invited_user.create_invited_user",
AsyncMock(
side_effect=[
_invited_user_record(),
_invited_user_record(),
]
),
)
result = await bulk_create_invited_users_from_file(
"invites.txt",
b"Invited@example.com\nsecond@example.com\n",
)
assert result.created_count == 2
assert result.skipped_count == 0
assert result.error_count == 0
assert [row.status for row in result.results] == ["CREATED", "CREATED"]
assert create_invited.await_count == 2
@pytest.mark.asyncio
async def test_bulk_create_invited_users_handles_csv_duplicates_and_invalid_rows(
mocker: pytest_mock.MockerFixture,
) -> None:
create_invited = mocker.patch(
"backend.data.invited_user.create_invited_user",
AsyncMock(
side_effect=[
_invited_user_record(),
PreconditionFailed("An invited user with this email already exists"),
]
),
)
result = await bulk_create_invited_users_from_file(
"invites.csv",
(
"email,name\n"
"valid@example.com,Valid User\n"
"not-an-email,Bad Row\n"
"valid@example.com,Duplicate In File\n"
"existing@example.com,Existing User\n"
).encode("utf-8"),
)
assert result.created_count == 1
assert result.skipped_count == 2
assert result.error_count == 1
assert [row.status for row in result.results] == [
"CREATED",
"ERROR",
"SKIPPED",
"SKIPPED",
]
assert create_invited.await_count == 2

View File

@@ -0,0 +1,31 @@
"""LLM Registry - Dynamic model management system."""
from .model import ModelMetadata
from .registry import (
RegistryModel,
RegistryModelCost,
RegistryModelCreator,
get_all_model_slugs_for_validation,
get_all_models,
get_default_model_slug,
get_enabled_models,
get_model,
get_schema_options,
refresh_llm_registry,
)
__all__ = [
# Models
"ModelMetadata",
"RegistryModel",
"RegistryModelCost",
"RegistryModelCreator",
# Functions
"refresh_llm_registry",
"get_model",
"get_all_models",
"get_enabled_models",
"get_schema_options",
"get_default_model_slug",
"get_all_model_slugs_for_validation",
]

View File

@@ -0,0 +1,9 @@
"""Type definitions for LLM model metadata.
Re-exports ModelMetadata from blocks.llm to avoid type collision.
In PR #5 (block integration), this will become the canonical location.
"""
from backend.blocks.llm import ModelMetadata
__all__ = ["ModelMetadata"]

View File

@@ -0,0 +1,240 @@
"""Core LLM registry implementation for managing models dynamically."""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any
import prisma.models
from backend.data.llm_registry.model import ModelMetadata
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class RegistryModelCost:
"""Cost configuration for an LLM model."""
unit: str # "RUN" or "TOKENS"
credit_cost: int
credential_provider: str
credential_id: str | None
credential_type: str | None
currency: str | None
metadata: dict[str, Any]
@dataclass(frozen=True)
class RegistryModelCreator:
"""Creator information for an LLM model."""
id: str
name: str
display_name: str
description: str | None
website_url: str | None
logo_url: str | None
@dataclass(frozen=True)
class RegistryModel:
"""Represents a model in the LLM registry."""
slug: str
display_name: str
description: str | None
metadata: ModelMetadata
capabilities: dict[str, Any]
extra_metadata: dict[str, Any]
provider_display_name: str
is_enabled: bool
is_recommended: bool = False
costs: tuple[RegistryModelCost, ...] = field(default_factory=tuple)
creator: RegistryModelCreator | None = None
# In-memory cache (will be replaced with Redis in PR #6)
_dynamic_models: dict[str, RegistryModel] = {}
_schema_options: list[dict[str, str]] = []
_lock = asyncio.Lock()
async def refresh_llm_registry() -> None:
"""
Refresh the LLM registry from the database.
Fetches all models with their costs, providers, and creators,
then updates the in-memory cache.
"""
async with _lock:
try:
records = await prisma.models.LlmModel.prisma().find_many(
include={
"Provider": True,
"Costs": True,
"Creator": True,
}
)
logger.info(f"Fetched {len(records)} LLM models from database")
# Build model instances
new_models: dict[str, RegistryModel] = {}
for record in records:
# Parse costs
costs = tuple(
RegistryModelCost(
unit=str(cost.unit), # Convert enum to string
credit_cost=cost.creditCost,
credential_provider=cost.credentialProvider,
credential_id=cost.credentialId,
credential_type=cost.credentialType,
currency=cost.currency,
metadata=dict(cost.metadata or {}),
)
for cost in (record.Costs or [])
)
# Parse creator
creator = None
if record.Creator:
creator = RegistryModelCreator(
id=record.Creator.id,
name=record.Creator.name,
display_name=record.Creator.displayName,
description=record.Creator.description,
website_url=record.Creator.websiteUrl,
logo_url=record.Creator.logoUrl,
)
# Parse capabilities
capabilities = dict(record.capabilities or {})
# Build metadata from record
# Warn if Provider relation is missing (indicates data corruption)
if not record.Provider:
logger.warning(
f"LlmModel {record.slug} has no Provider despite NOT NULL FK - "
f"falling back to providerId {record.providerId}"
)
provider_name = (
record.Provider.name if record.Provider else record.providerId
)
provider_display = (
record.Provider.displayName
if record.Provider
else record.providerId
)
# Extract creator name (fallback to "Unknown" if no creator)
creator_name = (
record.Creator.displayName if record.Creator else "Unknown"
)
# Price tier defaults to 1 if not set
price_tier = record.priceTier if record.priceTier in (1, 2, 3) else 1
metadata = ModelMetadata(
provider=provider_name,
context_window=record.contextWindow,
max_output_tokens=(
record.maxOutputTokens
if record.maxOutputTokens is not None
else record.contextWindow
),
display_name=record.displayName,
provider_name=provider_display,
creator_name=creator_name,
price_tier=price_tier,
)
# Create model instance
model = RegistryModel(
slug=record.slug,
display_name=record.displayName,
description=record.description,
metadata=metadata,
capabilities=capabilities,
extra_metadata=dict(record.metadata or {}),
provider_display_name=provider_display,
is_enabled=record.isEnabled,
is_recommended=record.isRecommended,
costs=costs,
creator=creator,
)
new_models[record.slug] = model
# Atomic swap
global _dynamic_models, _schema_options
_dynamic_models = new_models
_schema_options = _build_schema_options()
logger.info(
f"LLM registry refreshed: {len(_dynamic_models)} models, "
f"{len(_schema_options)} schema options"
)
except Exception as e:
logger.error(f"Failed to refresh LLM registry: {e}", exc_info=True)
raise
def _build_schema_options() -> list[dict[str, str]]:
"""Build schema options for model selection dropdown. Only includes enabled models."""
return [
{
"label": model.display_name,
"value": model.slug,
"group": model.metadata.provider,
"description": model.description or "",
}
for model in sorted(
_dynamic_models.values(), key=lambda m: m.display_name.lower()
)
if model.is_enabled
]
def get_model(slug: str) -> RegistryModel | None:
"""Get a model by slug from the registry."""
return _dynamic_models.get(slug)
def get_all_models() -> list[RegistryModel]:
"""Get all models from the registry (including disabled)."""
return list(_dynamic_models.values())
def get_enabled_models() -> list[RegistryModel]:
"""Get only enabled models from the registry."""
return [model for model in _dynamic_models.values() if model.is_enabled]
def get_schema_options() -> list[dict[str, str]]:
"""Get schema options for model selection dropdown (enabled models only)."""
return _schema_options
def get_default_model_slug() -> str | None:
"""Get the default model slug (first recommended, or first enabled)."""
# Sort once and use next() to short-circuit on first match
models = sorted(_dynamic_models.values(), key=lambda m: m.display_name)
# Prefer recommended models
recommended = next(
(m.slug for m in models if m.is_recommended and m.is_enabled), None
)
if recommended:
return recommended
# Fallback to first enabled model
return next((m.slug for m in models if m.is_enabled), None)
def get_all_model_slugs_for_validation() -> list[str]:
"""
Get all model slugs for validation (enables migrate_llm_models to work).
Returns slugs for enabled models only.
"""
return [model.slug for model in _dynamic_models.values() if model.is_enabled]

View File

@@ -8,8 +8,6 @@ from backend.api.model import NotificationPayload
from backend.data.event_bus import AsyncRedisEventBus
from backend.util.settings import Settings
_settings = Settings()
class NotificationEvent(BaseModel):
"""Generic notification event destined for websocket delivery."""
@@ -28,7 +26,7 @@ class AsyncRedisNotificationEventBus(AsyncRedisEventBus[NotificationEvent]):
@property
def event_bus_name(self) -> str:
return _settings.config.notification_event_bus_name
return Settings().config.notification_event_bus_name
async def publish(self, event: NotificationEvent) -> None:
await self.publish_event(event, event.user_id)

View File

@@ -41,7 +41,7 @@ _MAX_PAGES = 100
_LLM_TIMEOUT = 30
def mask_email(email: str) -> str:
def _mask_email(email: str) -> str:
"""Mask an email for safe logging: 'alice@example.com' -> 'a***e@example.com'."""
try:
local, domain = email.rsplit("@", 1)
@@ -196,7 +196,8 @@ async def _refresh_cache(form_id: str) -> tuple[dict, list]:
Returns (email_index, questions).
"""
client = _make_tally_client(_settings.secrets.tally_api_key)
settings = Settings()
client = _make_tally_client(settings.secrets.tally_api_key)
redis = await get_redis_async()
last_fetch_key = _LAST_FETCH_KEY.format(form_id=form_id)
@@ -331,9 +332,6 @@ Fields:
- current_software (list of strings): software/tools currently used
- existing_automation (list of strings): existing automations
- additional_notes (string): any additional context
- suggested_prompts (list of 5 strings): short action prompts (each under 20 words) that would help \
this person get started with automating their work. Should be specific to their industry, role, and \
pain points; actionable and conversational in tone; focused on automation opportunities.
Form data:
"""
@@ -341,21 +339,21 @@ Form data:
_EXTRACTION_SUFFIX = "\n\nReturn ONLY valid JSON."
async def extract_business_understanding_from_tally(
async def extract_business_understanding(
formatted_text: str,
) -> BusinessUnderstandingInput:
"""
Use an LLM to extract structured business understanding from form text.
"""Use an LLM to extract structured business understanding from form text.
Raises on timeout or unparseable response so the caller can handle it.
"""
api_key = _settings.secrets.open_router_api_key
settings = Settings()
api_key = settings.secrets.open_router_api_key
client = AsyncOpenAI(api_key=api_key, base_url=OPENROUTER_BASE_URL)
try:
response = await asyncio.wait_for(
client.chat.completions.create(
model=_settings.config.tally_extraction_llm_model,
model="openai/gpt-4o-mini",
messages=[
{
"role": "user",
@@ -380,57 +378,9 @@ async def extract_business_understanding_from_tally(
# Filter out null values before constructing
cleaned = {k: v for k, v in data.items() if v is not None}
# Validate suggested_prompts: filter >20 words, keep top 3
raw_prompts = cleaned.get("suggested_prompts", [])
if isinstance(raw_prompts, list):
valid = [
p.strip()
for p in raw_prompts
if isinstance(p, str) and len(p.strip().split()) <= 20
]
# This will keep up to 3 suggestions
short_prompts = valid[:3] if valid else None
if short_prompts:
cleaned["suggested_prompts"] = short_prompts
else:
# We dont want to add a None value suggested_prompts field
cleaned.pop("suggested_prompts", None)
else:
# suggested_prompts must be a list - removing it as its not here
cleaned.pop("suggested_prompts", None)
return BusinessUnderstandingInput(**cleaned)
async def get_business_understanding_input_from_tally(
email: str,
*,
require_api_key: bool = False,
) -> Optional[BusinessUnderstandingInput]:
if not _settings.secrets.tally_api_key:
if require_api_key:
raise RuntimeError("Tally API key is not configured")
logger.debug("Tally: no API key configured, skipping")
return None
masked = mask_email(email)
result = await find_submission_by_email(TALLY_FORM_ID, email)
if result is None:
logger.debug(f"Tally: no submission found for {masked}")
return None
submission, questions = result
logger.info(f"Tally: found submission for {masked}, extracting understanding")
formatted = format_submission_for_llm(submission, questions)
if not formatted.strip():
logger.warning("Tally: formatted submission was empty, skipping")
return None
return await extract_business_understanding_from_tally(formatted)
async def populate_understanding_from_tally(user_id: str, email: str) -> None:
"""Main orchestrator: check Tally for a matching submission and populate understanding.
@@ -445,10 +395,30 @@ async def populate_understanding_from_tally(user_id: str, email: str) -> None:
)
return
understanding_input = await get_business_understanding_input_from_tally(email)
if understanding_input is None:
# Check API key is configured
settings = Settings()
if not settings.secrets.tally_api_key:
logger.debug("Tally: no API key configured, skipping")
return
# Look up submission by email
masked = _mask_email(email)
result = await find_submission_by_email(TALLY_FORM_ID, email)
if result is None:
logger.debug(f"Tally: no submission found for {masked}")
return
submission, questions = result
logger.info(f"Tally: found submission for {masked}, extracting understanding")
# Format and extract
formatted = format_submission_for_llm(submission, questions)
if not formatted.strip():
logger.warning("Tally: formatted submission was empty, skipping")
return
understanding_input = await extract_business_understanding(formatted)
# Upsert into database
await upsert_business_understanding(user_id, understanding_input)
logger.info(f"Tally: successfully populated understanding for user {user_id}")

View File

@@ -12,11 +12,11 @@ from backend.data.tally import (
_build_email_index,
_format_answer,
_make_tally_client,
_mask_email,
_refresh_cache,
extract_business_understanding_from_tally,
extract_business_understanding,
find_submission_by_email,
format_submission_for_llm,
mask_email,
populate_understanding_from_tally,
)
@@ -248,7 +248,7 @@ async def test_populate_understanding_skips_no_api_key():
new_callable=AsyncMock,
return_value=None,
),
patch("backend.data.tally._settings", mock_settings),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.find_submission_by_email",
new_callable=AsyncMock,
@@ -284,7 +284,6 @@ async def test_populate_understanding_full_flow():
],
}
mock_input = MagicMock()
mock_input.suggested_prompts = ["Prompt 1", "Prompt 2", "Prompt 3"]
with (
patch(
@@ -292,14 +291,14 @@ async def test_populate_understanding_full_flow():
new_callable=AsyncMock,
return_value=None,
),
patch("backend.data.tally._settings", mock_settings),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.find_submission_by_email",
new_callable=AsyncMock,
return_value=(submission, SAMPLE_QUESTIONS),
),
patch(
"backend.data.tally.extract_business_understanding_from_tally",
"backend.data.tally.extract_business_understanding",
new_callable=AsyncMock,
return_value=mock_input,
) as mock_extract,
@@ -332,14 +331,14 @@ async def test_populate_understanding_handles_llm_timeout():
new_callable=AsyncMock,
return_value=None,
),
patch("backend.data.tally._settings", mock_settings),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.find_submission_by_email",
new_callable=AsyncMock,
return_value=(submission, SAMPLE_QUESTIONS),
),
patch(
"backend.data.tally.extract_business_understanding_from_tally",
"backend.data.tally.extract_business_understanding",
new_callable=AsyncMock,
side_effect=asyncio.TimeoutError(),
),
@@ -357,13 +356,13 @@ async def test_populate_understanding_handles_llm_timeout():
def test_mask_email():
assert mask_email("alice@example.com") == "a***e@example.com"
assert mask_email("ab@example.com") == "a***@example.com"
assert mask_email("a@example.com") == "a***@example.com"
assert _mask_email("alice@example.com") == "a***e@example.com"
assert _mask_email("ab@example.com") == "a***@example.com"
assert _mask_email("a@example.com") == "a***@example.com"
def test_mask_email_invalid():
assert mask_email("no-at-sign") == "***"
assert _mask_email("no-at-sign") == "***"
# ── Prompt construction (curly-brace safety) ─────────────────────────────────
@@ -394,11 +393,11 @@ def test_extraction_prompt_no_format_placeholders():
assert single_braces == [], f"Found format placeholders: {single_braces}"
# ── extract_business_understanding_from_tally ────────────────────────────────────────────
# ── extract_business_understanding ────────────────────────────────────────────
@pytest.mark.asyncio
async def test_extract_business_understanding_from_tally_success():
async def test_extract_business_understanding_success():
"""Happy path: LLM returns valid JSON that maps to BusinessUnderstandingInput."""
mock_choice = MagicMock()
mock_choice.message.content = json.dumps(
@@ -407,13 +406,6 @@ async def test_extract_business_understanding_from_tally_success():
"business_name": "Acme Corp",
"industry": "Technology",
"pain_points": ["manual reporting"],
"suggested_prompts": [
"Automate weekly reports",
"Set up invoice processing",
"Create a customer onboarding flow",
"Track project deadlines automatically",
"Send follow-up emails after meetings",
],
}
)
mock_response = MagicMock()
@@ -423,56 +415,16 @@ async def test_extract_business_understanding_from_tally_success():
mock_client.chat.completions.create.return_value = mock_response
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
result = await extract_business_understanding("Q: Name?\nA: Alice")
assert result.user_name == "Alice"
assert result.business_name == "Acme Corp"
assert result.industry == "Technology"
assert result.pain_points == ["manual reporting"]
# suggested_prompts validated and sliced to top 3
assert result.suggested_prompts == [
"Automate weekly reports",
"Set up invoice processing",
"Create a customer onboarding flow",
]
@pytest.mark.asyncio
async def test_extract_business_understanding_from_tally_filters_long_prompts():
"""Prompts exceeding 20 words are excluded and only top 3 are kept."""
long_prompt = " ".join(["word"] * 21)
mock_choice = MagicMock()
mock_choice.message.content = json.dumps(
{
"user_name": "Alice",
"suggested_prompts": [
long_prompt,
"Short prompt one",
long_prompt,
"Short prompt two",
"Short prompt three",
"Short prompt four",
],
}
)
mock_response = MagicMock()
mock_response.choices = [mock_choice]
mock_client = AsyncMock()
mock_client.chat.completions.create.return_value = mock_response
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
assert result.suggested_prompts == [
"Short prompt one",
"Short prompt two",
"Short prompt three",
]
@pytest.mark.asyncio
async def test_extract_business_understanding_from_tally_filters_nulls():
async def test_extract_business_understanding_filters_nulls():
"""Null values from LLM should be excluded from the result."""
mock_choice = MagicMock()
mock_choice.message.content = json.dumps(
@@ -485,7 +437,7 @@ async def test_extract_business_understanding_from_tally_filters_nulls():
mock_client.chat.completions.create.return_value = mock_response
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
result = await extract_business_understanding("Q: Name?\nA: Alice")
assert result.user_name == "Alice"
assert result.business_name is None
@@ -493,7 +445,7 @@ async def test_extract_business_understanding_from_tally_filters_nulls():
@pytest.mark.asyncio
async def test_extract_business_understanding_from_tally_invalid_json():
async def test_extract_business_understanding_invalid_json():
"""Invalid JSON from LLM should raise JSONDecodeError."""
mock_choice = MagicMock()
mock_choice.message.content = "not valid json {"
@@ -507,11 +459,11 @@ async def test_extract_business_understanding_from_tally_invalid_json():
patch("backend.data.tally.AsyncOpenAI", return_value=mock_client),
pytest.raises(json.JSONDecodeError),
):
await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
await extract_business_understanding("Q: Name?\nA: Alice")
@pytest.mark.asyncio
async def test_extract_business_understanding_from_tally_timeout():
async def test_extract_business_understanding_timeout():
"""LLM timeout should propagate as asyncio.TimeoutError."""
mock_client = AsyncMock()
mock_client.chat.completions.create.side_effect = asyncio.TimeoutError()
@@ -521,7 +473,7 @@ async def test_extract_business_understanding_from_tally_timeout():
patch("backend.data.tally._LLM_TIMEOUT", 0.001),
pytest.raises(asyncio.TimeoutError),
):
await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
await extract_business_understanding("Q: Name?\nA: Alice")
# ── _refresh_cache ───────────────────────────────────────────────────────────
@@ -540,7 +492,7 @@ async def test_refresh_cache_full_fetch():
submissions = SAMPLE_SUBMISSIONS
with (
patch("backend.data.tally._settings", mock_settings),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.get_redis_async",
new_callable=AsyncMock,
@@ -588,7 +540,7 @@ async def test_refresh_cache_incremental_fetch():
new_submissions = [SAMPLE_SUBMISSIONS[0]] # Just Alice
with (
patch("backend.data.tally._settings", mock_settings),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.get_redis_async",
new_callable=AsyncMock,

View File

@@ -86,11 +86,6 @@ class BusinessUnderstandingInput(pydantic.BaseModel):
None, description="Any additional context"
)
# Suggested prompts (UI-only, not included in system prompt)
suggested_prompts: Optional[list[str]] = pydantic.Field(
None, description="LLM-generated suggested prompts based on business context"
)
class BusinessUnderstanding(pydantic.BaseModel):
"""Full business understanding model returned from database."""
@@ -127,9 +122,6 @@ class BusinessUnderstanding(pydantic.BaseModel):
# Additional context
additional_notes: Optional[str] = None
# Suggested prompts (UI-only, not included in system prompt)
suggested_prompts: list[str] = pydantic.Field(default_factory=list)
@classmethod
def from_db(cls, db_record: CoPilotUnderstanding) -> "BusinessUnderstanding":
"""Convert database record to Pydantic model."""
@@ -157,7 +149,6 @@ class BusinessUnderstanding(pydantic.BaseModel):
current_software=_json_to_list(business.get("current_software")),
existing_automation=_json_to_list(business.get("existing_automation")),
additional_notes=business.get("additional_notes"),
suggested_prompts=_json_to_list(data.get("suggested_prompts")),
)
@@ -175,62 +166,6 @@ def _merge_lists(existing: list | None, new: list | None) -> list | None:
return merged
def merge_business_understanding_data(
existing_data: dict[str, Any],
input_data: BusinessUnderstandingInput,
) -> dict[str, Any]:
merged_data = dict(existing_data)
merged_business: dict[str, Any] = {}
if isinstance(merged_data.get("business"), dict):
merged_business = dict(merged_data["business"])
business_string_fields = [
"job_title",
"business_name",
"industry",
"business_size",
"user_role",
"additional_notes",
]
business_list_fields = [
"key_workflows",
"daily_activities",
"pain_points",
"bottlenecks",
"manual_tasks",
"automation_goals",
"current_software",
"existing_automation",
]
if input_data.user_name is not None:
merged_data["name"] = input_data.user_name
for field in business_string_fields:
value = getattr(input_data, field)
if value is not None:
merged_business[field] = value
for field in business_list_fields:
value = getattr(input_data, field)
if value is not None:
existing_list = _json_to_list(merged_business.get(field))
merged_list = _merge_lists(existing_list, value)
merged_business[field] = merged_list
merged_business["version"] = 1
merged_data["business"] = merged_business
# suggested_prompts lives at the top level (not under `business`) because
# it's a UI-only artifact consumed by the frontend, not business understanding
# data. The `business` sub-dict feeds the system prompt.
if input_data.suggested_prompts is not None:
merged_data["suggested_prompts"] = input_data.suggested_prompts
return merged_data
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
"""Get business understanding from Redis cache."""
try:
@@ -310,18 +245,63 @@ async def upsert_business_understanding(
where={"userId": user_id}
)
# Get existing data structure or start fresh
existing_data: dict[str, Any] = {}
if existing and isinstance(existing.data, dict):
existing_data = dict(existing.data)
merged_data = merge_business_understanding_data(existing_data, input_data)
existing_business: dict[str, Any] = {}
if isinstance(existing_data.get("business"), dict):
existing_business = dict(existing_data["business"])
# Business fields (stored inside business object)
business_string_fields = [
"job_title",
"business_name",
"industry",
"business_size",
"user_role",
"additional_notes",
]
business_list_fields = [
"key_workflows",
"daily_activities",
"pain_points",
"bottlenecks",
"manual_tasks",
"automation_goals",
"current_software",
"existing_automation",
]
# Handle top-level name field
if input_data.user_name is not None:
existing_data["name"] = input_data.user_name
# Business string fields - overwrite if provided
for field in business_string_fields:
value = getattr(input_data, field)
if value is not None:
existing_business[field] = value
# Business list fields - merge with existing
for field in business_list_fields:
value = getattr(input_data, field)
if value is not None:
existing_list = _json_to_list(existing_business.get(field))
merged = _merge_lists(existing_list, value)
existing_business[field] = merged
# Set version and nest business data
existing_business["version"] = 1
existing_data["business"] = existing_business
# Upsert with the merged data
record = await CoPilotUnderstanding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "data": SafeJson(merged_data)},
"update": {"data": SafeJson(merged_data)},
"create": {"userId": user_id, "data": SafeJson(existing_data)},
"update": {"data": SafeJson(existing_data)},
},
)

View File

@@ -1,102 +0,0 @@
"""Tests for business understanding merge and format logic."""
from datetime import datetime, timezone
from typing import Any
from backend.data.understanding import (
BusinessUnderstanding,
BusinessUnderstandingInput,
format_understanding_for_prompt,
merge_business_understanding_data,
)
def _make_input(**kwargs: Any) -> BusinessUnderstandingInput:
"""Create a BusinessUnderstandingInput with only the specified fields."""
return BusinessUnderstandingInput.model_validate(kwargs)
# ─── merge_business_understanding_data: suggested_prompts ─────────────
def test_merge_suggested_prompts_overwrites_existing():
"""New suggested_prompts should fully replace existing ones (not append)."""
existing = {
"name": "Alice",
"business": {"industry": "Tech", "version": 1},
"suggested_prompts": ["Old prompt 1", "Old prompt 2"],
}
input_data = _make_input(
suggested_prompts=["New prompt A", "New prompt B", "New prompt C"],
)
result = merge_business_understanding_data(existing, input_data)
assert result["suggested_prompts"] == [
"New prompt A",
"New prompt B",
"New prompt C",
]
def test_merge_suggested_prompts_none_preserves_existing():
"""When input has suggested_prompts=None, existing prompts are preserved."""
existing = {
"name": "Alice",
"business": {"industry": "Tech", "version": 1},
"suggested_prompts": ["Keep me"],
}
input_data = _make_input(industry="Finance")
result = merge_business_understanding_data(existing, input_data)
assert result["suggested_prompts"] == ["Keep me"]
assert result["business"]["industry"] == "Finance"
def test_merge_suggested_prompts_added_to_empty_data():
"""Suggested prompts are set at top level even when starting from empty data."""
existing: dict[str, Any] = {}
input_data = _make_input(suggested_prompts=["Prompt 1"])
result = merge_business_understanding_data(existing, input_data)
assert result["suggested_prompts"] == ["Prompt 1"]
def test_merge_suggested_prompts_empty_list_overwrites():
"""An explicit empty list should overwrite existing prompts."""
existing: dict[str, Any] = {
"suggested_prompts": ["Old prompt"],
"business": {"version": 1},
}
input_data = _make_input(suggested_prompts=[])
result = merge_business_understanding_data(existing, input_data)
assert result["suggested_prompts"] == []
# ─── format_understanding_for_prompt: excludes suggested_prompts ──────
def test_format_understanding_excludes_suggested_prompts():
"""suggested_prompts is UI-only and must NOT appear in the system prompt."""
understanding = BusinessUnderstanding(
id="test-id",
user_id="user-1",
created_at=datetime.now(tz=timezone.utc),
updated_at=datetime.now(tz=timezone.utc),
user_name="Alice",
industry="Technology",
suggested_prompts=["Automate reports", "Set up alerts", "Track KPIs"],
)
formatted = format_understanding_for_prompt(understanding)
assert "Alice" in formatted
assert "Technology" in formatted
assert "suggested_prompts" not in formatted
assert "Automate reports" not in formatted
assert "Set up alerts" not in formatted
assert "Track KPIs" not in formatted

View File

@@ -46,7 +46,7 @@ from backend.util.exceptions import (
)
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
from backend.util.settings import Config
from backend.util.type import coerce_inputs_to_schema
from backend.util.type import convert
config = Config()
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
@@ -213,8 +213,11 @@ def validate_exec(
if resolve_input:
data = merge_execution_input(data)
# Coerce non-matching data types to the expected input schema.
coerce_inputs_to_schema(data, schema)
# Convert non-matching data types to the expected input schema.
for name, data_type in schema.__annotations__.items():
value = data.get(name)
if (value is not None) and (type(value) is not data_type):
data[name] = convert(value, data_type)
# Input data post-merge should contain all required fields from the schema.
if missing_input := schema.get_missing_input(data):

View File

@@ -70,9 +70,6 @@ def _msg_tokens(msg: dict, enc) -> int:
# Count tool result tokens
tool_call_tokens += _tok_len(item.get("tool_use_id", ""), enc)
tool_call_tokens += _tok_len(item.get("content", ""), enc)
elif isinstance(item, dict) and item.get("type") == "text":
# Count text block tokens
tool_call_tokens += _tok_len(item.get("text", ""), enc)
elif isinstance(item, dict) and "content" in item:
# Other content types with content field
tool_call_tokens += _tok_len(item.get("content", ""), enc)
@@ -148,14 +145,10 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
if len(ids) <= max_tok:
return text # nothing to do
# Need at least 3 tokens (head + ellipsis + tail) for meaningful truncation
mid = enc.encode("")
if max_tok < 3:
return enc.decode(mid)
# Split the allowance between the two ends:
head = max_tok // 2 - 1 # -1 for the ellipsis
tail = max_tok - head - 1
mid = enc.encode("")
return enc.decode(ids[:head] + mid + ids[-tail:])
@@ -403,7 +396,7 @@ def validate_and_remove_orphan_tool_responses(
if log_warning:
logger.warning(
"Removing %d orphan tool response(s): %s", len(orphan_ids), orphan_ids
f"Removing {len(orphan_ids)} orphan tool response(s): {orphan_ids}"
)
return _remove_orphan_tool_responses(messages, orphan_ids)
@@ -495,9 +488,8 @@ def _ensure_tool_pairs_intact(
# Some tool_call_ids couldn't be resolved - remove those tool responses
# This shouldn't happen in normal operation but handles edge cases
logger.warning(
"Could not find assistant messages for tool_call_ids: %s. "
"Removing orphan tool responses.",
orphan_tool_call_ids,
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
"Removing orphan tool responses."
)
recent_messages = _remove_orphan_tool_responses(
recent_messages, orphan_tool_call_ids
@@ -505,8 +497,8 @@ def _ensure_tool_pairs_intact(
if messages_to_prepend:
logger.info(
"Extended recent messages by %d to preserve tool_call/tool_response pairs",
len(messages_to_prepend),
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
f"tool_call/tool_response pairs"
)
return messages_to_prepend + recent_messages
@@ -694,15 +686,11 @@ async def compress_context(
msgs = [summary_msg] + recent_msgs
logger.info(
"Context summarized: %d -> %d tokens, summarized %d messages",
original_count,
total_tokens(),
messages_summarized,
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
f"summarized {messages_summarized} messages"
)
except Exception as e:
logger.warning(
"Summarization failed, continuing with truncation: %s", e
)
logger.warning(f"Summarization failed, continuing with truncation: {e}")
# Fall through to content truncation
# ---- STEP 2: Normalize content ----------------------------------------
@@ -740,12 +728,6 @@ async def compress_context(
# This is more granular than dropping all old messages at once.
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
deletable: list[int] = []
# Count assistant messages to ensure we keep at least one
assistant_indices: set[int] = {
i
for i in range(len(msgs))
if msgs[i] is not None and msgs[i].get("role") == "assistant"
}
for i in range(1, len(msgs) - 1):
msg = msgs[i]
if (
@@ -753,9 +735,6 @@ async def compress_context(
and not _is_tool_message(msg)
and not _is_objective_message(msg)
):
# Skip if this is the last remaining assistant message
if msg.get("role") == "assistant" and len(assistant_indices) <= 1:
continue
deletable.append(i)
if not deletable:
break

View File

@@ -89,10 +89,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
le=500,
description="Thread pool size for FastAPI sync operations. All sync endpoints and dependencies automatically use this pool. Higher values support more concurrent sync operations but use more memory.",
)
tally_extraction_llm_model: str = Field(
default="openai/gpt-4o-mini",
description="OpenRouter model ID used for extracting business understanding from Tally form data",
)
ollama_host: str = Field(
default="localhost:11434",
description="Default Ollama host; exempted from SSRF checks.",
@@ -121,10 +117,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default=True,
description="If authentication is enabled or not",
)
enable_invite_gate: bool = Field(
default=True,
description="If the invite-only signup gate is enforced",
)
enable_credit: bool = Field(
default=False,
description="If user credit system is enabled or not",

View File

@@ -249,87 +249,6 @@ def convert(value: Any, target_type: Any) -> Any:
raise ConversionError(f"Failed to convert {value} to {target_type}") from e
def _value_satisfies_type(value: Any, target: Any) -> bool:
"""Check whether *value* already satisfies *target*, including inner elements.
For union types this checks each member; for generic container types it
recursively checks that inner elements satisfy the type args (e.g. every
element in a ``list[str]`` is a ``str``). Returns ``False`` when uncertain
so the caller falls through to :func:`convert`.
"""
# typing.Any cannot be used with isinstance(); treat as always satisfied.
if target is Any:
return True
origin = get_origin(target)
if origin is Union or origin is types.UnionType:
non_none = [a for a in get_args(target) if a is not type(None)]
return any(_value_satisfies_type(value, member) for member in non_none)
# Generic container type (e.g. list[str], dict[str, int])
if origin is not None:
# Guard: origin may not be a runtime type (e.g. Literal)
if not isinstance(origin, type):
return False
if not isinstance(value, origin):
return False
args = get_args(target)
if not args:
return True
# Check inner elements satisfy the type args
if _is_type_or_subclass(origin, list):
return all(_value_satisfies_type(v, args[0]) for v in value)
if _is_type_or_subclass(origin, dict) and len(args) >= 2:
return all(
_value_satisfies_type(k, args[0]) and _value_satisfies_type(v, args[1])
for k, v in value.items()
)
if (
_is_type_or_subclass(origin, set) or _is_type_or_subclass(origin, frozenset)
) and args:
return all(_value_satisfies_type(v, args[0]) for v in value)
if _is_type_or_subclass(origin, tuple):
# Homogeneous tuple[T, ...] — single type + Ellipsis
if len(args) == 2 and args[1] is Ellipsis:
return all(_value_satisfies_type(v, args[0]) for v in value)
# Heterogeneous tuple[T1, T2, ...] — positional types
if len(value) != len(args):
return False
return all(_value_satisfies_type(v, t) for v, t in zip(value, args))
# Unhandled generic origin — fall through to convert()
return False
# Simple type (e.g. str, int)
if isinstance(target, type):
return isinstance(value, target)
return False
def coerce_inputs_to_schema(data: dict[str, Any], schema: type) -> None:
"""Coerce *data* values in-place to match *schema*'s field types.
Uses ``model_fields`` (not ``__annotations__``) so inherited fields are
included. Skips coercion when the value already satisfies the target
type — in particular for union-typed fields where the value matches one
member but differs from the annotation object itself.
This is the single authoritative coercion step shared by the executor
(``validate_exec``) and the CoPilot (``execute_block``).
"""
for name, field_info in schema.model_fields.items():
value = data.get(name)
if value is None:
continue
target = field_info.annotation
if target is None:
continue
if _value_satisfies_type(value, target):
continue
data[name] = convert(value, target)
class FormattedStringType(str):
string_format: str

View File

@@ -1,8 +1,6 @@
from typing import Any, List, Literal, Optional
from typing import List, Optional
from pydantic import BaseModel
from backend.util.type import _value_satisfies_type, coerce_inputs_to_schema, convert
from backend.util.type import convert
def test_type_conversion():
@@ -48,343 +46,3 @@ def test_type_conversion():
# Test other empty list conversions
assert convert([], int) == 0 # len([]) = 0
assert convert([], Optional[int]) == 0
# ---------------------------------------------------------------------------
# _value_satisfies_type
# ---------------------------------------------------------------------------
class TestValueSatisfiesType:
# --- simple types ---
def test_simple_match(self):
assert _value_satisfies_type("hello", str) is True
assert _value_satisfies_type(42, int) is True
assert _value_satisfies_type(3.14, float) is True
assert _value_satisfies_type(True, bool) is True
def test_simple_mismatch(self):
assert _value_satisfies_type("hello", int) is False
assert _value_satisfies_type(42, str) is False
assert _value_satisfies_type([1, 2], str) is False
# --- Any ---
def test_any_always_satisfied(self):
assert _value_satisfies_type("hello", Any) is True
assert _value_satisfies_type(42, Any) is True
assert _value_satisfies_type([1, 2], Any) is True
assert _value_satisfies_type(None, Any) is True
# --- Optional / Union ---
def test_optional_with_value(self):
assert _value_satisfies_type("hello", Optional[str]) is True
assert _value_satisfies_type(42, Optional[int]) is True
def test_optional_mismatch(self):
assert _value_satisfies_type(42, Optional[str]) is False
def test_union_matches_first_member(self):
assert _value_satisfies_type("hello", str | list[str]) is True
def test_union_matches_second_member(self):
assert _value_satisfies_type(["a", "b"], str | list[str]) is True
def test_union_no_match(self):
assert _value_satisfies_type(42, str | list[str]) is False
# --- list[T] ---
def test_list_str_all_match(self):
assert _value_satisfies_type(["a", "b", "c"], list[str]) is True
def test_list_str_inner_mismatch(self):
assert _value_satisfies_type([1, 2, 3], list[str]) is False
def test_list_int_all_match(self):
assert _value_satisfies_type([1, 2, 3], list[int]) is True
def test_list_int_inner_mismatch(self):
assert _value_satisfies_type(["1", "2"], list[int]) is False
def test_empty_list_satisfies_any_list_type(self):
assert _value_satisfies_type([], list[str]) is True
assert _value_satisfies_type([], list[int]) is True
def test_string_does_not_satisfy_list(self):
assert _value_satisfies_type("hello", list[str]) is False
# --- nested list[list[str]] ---
def test_nested_list_all_match(self):
assert _value_satisfies_type([["a", "b"], ["c"]], list[list[str]]) is True
def test_nested_list_inner_mismatch(self):
assert _value_satisfies_type([["a", 1], ["c"]], list[list[str]]) is False
def test_nested_list_outer_mismatch(self):
assert _value_satisfies_type(["a", "b"], list[list[str]]) is False
# --- dict[K, V] ---
def test_dict_str_int_match(self):
assert _value_satisfies_type({"a": 1, "b": 2}, dict[str, int]) is True
def test_dict_str_int_value_mismatch(self):
assert _value_satisfies_type({"a": "1", "b": "2"}, dict[str, int]) is False
def test_dict_str_int_key_mismatch(self):
assert _value_satisfies_type({1: 1, 2: 2}, dict[str, int]) is False
def test_empty_dict_satisfies(self):
assert _value_satisfies_type({}, dict[str, int]) is True
# --- set[T] / tuple[T] ---
def test_set_match(self):
assert _value_satisfies_type({1, 2, 3}, set[int]) is True
def test_set_mismatch(self):
assert _value_satisfies_type({"a", "b"}, set[int]) is False
def test_tuple_homogeneous_match(self):
assert _value_satisfies_type((1, 2, 3), tuple[int, ...]) is True
def test_tuple_homogeneous_mismatch(self):
assert _value_satisfies_type((1, "2", 3), tuple[int, ...]) is False
def test_tuple_heterogeneous_match(self):
assert _value_satisfies_type(("a", 1, True), tuple[str, int, bool]) is True
def test_tuple_heterogeneous_mismatch(self):
assert _value_satisfies_type(("a", "b", True), tuple[str, int, bool]) is False
def test_tuple_heterogeneous_wrong_length(self):
assert _value_satisfies_type(("a", 1), tuple[str, int, bool]) is False
# --- bare generics (no args) ---
def test_bare_list(self):
assert _value_satisfies_type([1, "a"], list) is True
def test_bare_dict(self):
assert _value_satisfies_type({"a": 1}, dict) is True
# --- union with generic inner mismatch ---
def test_union_list_with_wrong_inner_falls_through(self):
# [1, 2] doesn't satisfy list[str] (inner mismatch), and not str either
assert _value_satisfies_type([1, 2], str | list[str]) is False
# --- Literal (non-runtime origin) ---
def test_literal_does_not_crash(self):
"""Literal origins are not runtime types — should return False, not crash."""
assert _value_satisfies_type("active", Literal["active", "inactive"]) is False
# ---------------------------------------------------------------------------
# coerce_inputs_to_schema — using real Pydantic models
# ---------------------------------------------------------------------------
class SampleSchema(BaseModel):
name: str
count: int
items: list[str]
config: dict[str, int] = {}
class NestedSchema(BaseModel):
rows: list[list[str]]
class UnionSchema(BaseModel):
content: str | list[str]
class OptionalSchema(BaseModel):
label: Optional[str] = None
value: int = 0
class AnyFieldSchema(BaseModel):
data: Any
class TestCoerceInputsToSchema:
def test_string_to_int(self):
data: dict[str, Any] = {"name": "test", "count": "42", "items": ["a"]}
coerce_inputs_to_schema(data, SampleSchema)
assert data["count"] == 42
assert isinstance(data["count"], int)
def test_json_string_to_list(self):
data: dict[str, Any] = {"name": "test", "count": 1, "items": '["a","b","c"]'}
coerce_inputs_to_schema(data, SampleSchema)
assert data["items"] == ["a", "b", "c"]
def test_already_correct_types_unchanged(self):
data: dict[str, Any] = {
"name": "test",
"count": 42,
"items": ["a", "b"],
"config": {"x": 1},
}
coerce_inputs_to_schema(data, SampleSchema)
assert data == {
"name": "test",
"count": 42,
"items": ["a", "b"],
"config": {"x": 1},
}
def test_inner_element_coercion(self):
"""list[str] with int inner elements → coerced to strings."""
data: dict[str, Any] = {"name": "test", "count": 1, "items": [1, 2, 3]}
coerce_inputs_to_schema(data, SampleSchema)
assert data["items"] == ["1", "2", "3"]
def test_dict_value_coercion(self):
"""dict[str, int] with string values → coerced to ints."""
data: dict[str, Any] = {
"name": "test",
"count": 1,
"items": [],
"config": {"x": "10", "y": "20"},
}
coerce_inputs_to_schema(data, SampleSchema)
assert data["config"] == {"x": 10, "y": 20}
def test_nested_list_from_json_string(self):
data: dict[str, Any] = {
"rows": '[["Name","Score"],["Alice","90"]]',
}
coerce_inputs_to_schema(data, NestedSchema)
assert data["rows"] == [["Name", "Score"], ["Alice", "90"]]
def test_nested_list_already_correct(self):
original = [["a", "b"], ["c", "d"]]
data: dict[str, Any] = {"rows": original}
coerce_inputs_to_schema(data, NestedSchema)
assert data["rows"] == original
def test_union_preserves_valid_list(self):
"""list[str] value should NOT be stringified for str | list[str]."""
data: dict[str, Any] = {"content": ["a", "b"]}
coerce_inputs_to_schema(data, UnionSchema)
assert data["content"] == ["a", "b"]
assert isinstance(data["content"], list)
def test_union_preserves_valid_string(self):
data: dict[str, Any] = {"content": "hello"}
coerce_inputs_to_schema(data, UnionSchema)
assert data["content"] == "hello"
def test_union_list_with_wrong_inner_gets_coerced(self):
"""[1, 2] for str | list[str] — inner ints don't match list[str],
so convert() is called. convert tries str first → stringifies."""
data: dict[str, Any] = {"content": [1, 2]}
coerce_inputs_to_schema(data, UnionSchema)
# convert([1,2], str | list[str]) tries str first → "[1, 2]"
# This is convert()'s union behavior — str wins over list[str]
assert isinstance(data["content"], (str, list))
def test_skips_none_values(self):
data: dict[str, Any] = {"label": None, "value": "5"}
coerce_inputs_to_schema(data, OptionalSchema)
assert data["label"] is None
assert data["value"] == 5
def test_skips_missing_fields(self):
data: dict[str, Any] = {"value": "10"}
coerce_inputs_to_schema(data, OptionalSchema)
assert "label" not in data
assert data["value"] == 10
def test_any_field_skipped(self):
"""Fields typed as Any should pass through without coercion."""
data: dict[str, Any] = {"data": [1, "mixed", {"nested": True}]}
coerce_inputs_to_schema(data, AnyFieldSchema)
assert data["data"] == [1, "mixed", {"nested": True}]
def test_preserves_all_convert_capabilities(self):
"""Verify coerce_inputs_to_schema doesn't lose any convert() capability
that existed before the _value_satisfies_type gate was added."""
class FullSchema(BaseModel):
as_int: int
as_float: float
as_bool: bool
as_str: str
as_list: list[int]
as_dict: dict[str, str]
data: dict[str, Any] = {
"as_int": "42",
"as_float": "3.14",
"as_bool": "True",
"as_str": 123,
"as_list": "[1,2,3]",
"as_dict": '{"a": "b"}',
}
coerce_inputs_to_schema(data, FullSchema)
assert data["as_int"] == 42
assert data["as_float"] == 3.14
assert data["as_bool"] is True
assert data["as_str"] == "123"
assert data["as_list"] == [1, 2, 3]
assert data["as_dict"] == {"a": "b"}
def test_inherited_fields_are_coerced(self):
"""model_fields includes inherited fields; __annotations__ does not.
This verifies that fields from a parent schema are still coerced."""
class ParentSchema(BaseModel):
base_count: int
class ChildSchema(ParentSchema):
name: str
# base_count is inherited — __annotations__ wouldn't include it
assert "base_count" not in ChildSchema.__annotations__
assert "base_count" in ChildSchema.model_fields
data: dict[str, Any] = {"base_count": "42", "name": "test"}
coerce_inputs_to_schema(data, ChildSchema)
assert data["base_count"] == 42
assert isinstance(data["base_count"], int)
def test_nested_pydantic_model_field(self):
"""dict input for a Pydantic model-typed field passes through.
convert() doesn't construct Pydantic models — Pydantic validation
handles that downstream. This test documents the behavior."""
class InnerModel(BaseModel):
x: int
class OuterModel(BaseModel):
inner: InnerModel
data: dict[str, Any] = {"inner": {"x": 1}}
coerce_inputs_to_schema(data, OuterModel)
# dict stays as dict — convert() doesn't construct Pydantic models
assert data["inner"] == {"x": 1}
assert isinstance(data["inner"], dict)
def test_literal_field_passes_through(self):
"""Literal-typed fields should not crash coercion."""
class LiteralSchema(BaseModel):
status: Literal["active", "inactive"]
data: dict[str, Any] = {"status": "active"}
coerce_inputs_to_schema(data, LiteralSchema)
assert data["status"] == "active"
def test_list_of_pydantic_model_field(self):
"""list[dict] for list[PydanticModel] passes through unchanged."""
class ItemModel(BaseModel):
name: str
class ContainerModel(BaseModel):
items: list[ItemModel]
data: dict[str, Any] = {"items": [{"name": "a"}, {"name": "b"}]}
coerce_inputs_to_schema(data, ContainerModel)
# Dicts stay as dicts — Pydantic validation handles construction
assert data["items"] == [{"name": "a"}, {"name": "b"}]
assert isinstance(data["items"][0], dict)

View File

@@ -1,246 +0,0 @@
#!/usr/bin/env python3
"""
AutoGPT Analytics — View Generator
====================================
Reads every .sql file in analytics/queries/ and registers it as a
CREATE OR REPLACE VIEW in the analytics schema.
Quick start (from autogpt_platform/backend/):
Step 1 — one-time setup (creates schema, role, grants):
poetry run analytics-setup
Step 2 — create / refresh all 14 analytics views:
poetry run analytics-views
Both commands auto-detect credentials from .env (DB_* vars).
Use --db-url to override.
Step 3 (optional) — enable login and set a password for the read-only
role so external tools (Supabase MCP, PostHog Data Warehouse) can connect.
The role is created as NOLOGIN, so you must grant LOGIN at the same time.
Run in Supabase SQL Editor:
ALTER ROLE analytics_readonly WITH LOGIN PASSWORD 'your-password';
Usage
-----
poetry run analytics-setup # apply setup to DB
poetry run analytics-setup --dry-run # print setup SQL only
poetry run analytics-views # apply all views to DB
poetry run analytics-views --dry-run # print all view SQL only
poetry run analytics-views --only graph_execution,retention_login_weekly
Environment variables
---------------------
DATABASE_URL Postgres connection string (checked before .env)
Notes
-----
- .env DB_* vars are read automatically as a fallback.
- Safe to re-run: uses CREATE OR REPLACE VIEW.
- Looker, PostHog Data Warehouse, and Supabase MCP all read from the
same analytics.* views — no raw tables exposed.
"""
import argparse
import os
import sys
from pathlib import Path
from urllib.parse import quote
QUERIES_DIR = Path(__file__).parent.parent / "analytics" / "queries"
ENV_FILE = Path(__file__).parent / ".env"
SCHEMA = "analytics"
SETUP_SQL = """\
-- =============================================================
-- AutoGPT Analytics Schema Setup
-- Run ONCE as the postgres superuser (e.g. via Supabase SQL Editor).
-- After this, run: poetry run analytics-views
-- =============================================================
-- 1. Create the analytics schema
CREATE SCHEMA IF NOT EXISTS analytics;
-- 2. Create the read-only role (skip if already exists)
DO $$
BEGIN
IF NOT EXISTS (SELECT FROM pg_roles WHERE rolname = 'analytics_readonly') THEN
CREATE ROLE analytics_readonly NOLOGIN;
END IF;
END
$$;
-- 3. Analytics schema grants only.
-- Views use security_invoker = false so they execute as their
-- owner (postgres). analytics_readonly never needs direct access
-- to the platform or auth schemas.
GRANT USAGE ON SCHEMA analytics TO analytics_readonly;
GRANT SELECT ON ALL TABLES IN SCHEMA analytics TO analytics_readonly;
ALTER DEFAULT PRIVILEGES IN SCHEMA analytics
GRANT SELECT ON TABLES TO analytics_readonly;
"""
def load_db_url_from_env() -> str | None:
"""Read DB_* vars from .env and build a psycopg2 connection string."""
if not ENV_FILE.exists():
return None
env: dict[str, str] = {}
for line in ENV_FILE.read_text().splitlines():
line = line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, _, value = line.partition("=")
env[key.strip()] = value.strip().strip('"').strip("'")
host = env.get("DB_HOST", "localhost")
port = env.get("DB_PORT", "5432")
user = env.get("DB_USER", "postgres")
password = env.get("DB_PASS", "")
dbname = env.get("DB_NAME", "postgres")
if not password:
return None
return (
"postgresql://"
f"{quote(user, safe='')}:{quote(password, safe='')}"
f"@{host}:{port}/{quote(dbname, safe='')}"
)
def get_db_url(args: argparse.Namespace) -> str | None:
return args.db_url or os.environ.get("DATABASE_URL") or load_db_url_from_env()
def connect(db_url: str):
try:
import psycopg2
except ImportError:
print("psycopg2 not found. Run: poetry install", file=sys.stderr)
sys.exit(1)
return psycopg2.connect(db_url)
def run_sql(db_url: str, statements: list[tuple[str, str]]) -> None:
"""Execute a list of (label, sql) pairs in a single transaction."""
conn = connect(db_url)
conn.autocommit = False
cur = conn.cursor()
try:
for label, sql in statements:
print(f" {label} ...", end=" ")
cur.execute(sql)
print("OK")
conn.commit()
print(f"\n{len(statements)} statement(s) applied.")
except Exception as e:
conn.rollback()
print(f"\n✗ Error: {e}", file=sys.stderr)
sys.exit(1)
finally:
cur.close()
conn.close()
def build_view_sql(name: str, query_body: str) -> str:
body = query_body.strip().rstrip(";")
# security_invoker = false → view runs as its owner (postgres), not the
# caller, so analytics_readonly only needs analytics schema access.
return f"CREATE OR REPLACE VIEW {SCHEMA}.{name} WITH (security_invoker = false) AS\n{body};\n"
def load_views(only: list[str] | None = None) -> list[tuple[str, str]]:
"""Return [(label, sql)] for all views, in alphabetical order."""
files = sorted(QUERIES_DIR.glob("*.sql"))
if not files:
print(f"No .sql files found in {QUERIES_DIR}", file=sys.stderr)
sys.exit(1)
known = {f.stem for f in files}
if only:
unknown = [n for n in only if n not in known]
if unknown:
print(
f"Unknown view name(s): {', '.join(unknown)}\n"
f"Available: {', '.join(sorted(known))}",
file=sys.stderr,
)
sys.exit(1)
result = []
for f in files:
name = f.stem
if only and name not in only:
continue
result.append((f"view analytics.{name}", build_view_sql(name, f.read_text())))
return result
def no_db_url_error() -> None:
print(
"No database URL found.\n"
"Tried: --db-url, DATABASE_URL env var, and .env (DB_* vars).\n"
"Use --dry-run to just print the SQL.",
file=sys.stderr,
)
sys.exit(1)
def cmd_setup(args: argparse.Namespace) -> None:
if args.dry_run:
print(SETUP_SQL)
return
db_url = get_db_url(args)
if not db_url:
no_db_url_error()
assert db_url
print("Applying analytics setup...")
run_sql(db_url, [("schema / role / grants", SETUP_SQL)])
def cmd_views(args: argparse.Namespace) -> None:
only = [v.strip() for v in args.only.split(",")] if args.only else None
views = load_views(only=only)
if not views:
print("No matching views found.")
sys.exit(0)
if args.dry_run:
print(f"-- {len(views)} views\n")
for label, sql in views:
print(f"-- {label}")
print(sql)
return
db_url = get_db_url(args)
if not db_url:
no_db_url_error()
assert db_url
print(f"Applying {len(views)} view(s)...")
# Append grant refresh so the readonly role sees any new views
grant = f"GRANT SELECT ON ALL TABLES IN SCHEMA {SCHEMA} TO analytics_readonly;"
run_sql(db_url, views + [("grant analytics_readonly", grant)])
def main_setup() -> None:
parser = argparse.ArgumentParser(description="Apply analytics schema setup to DB")
parser.add_argument(
"--dry-run", action="store_true", help="Print SQL, don't execute"
)
parser.add_argument("--db-url", help="Postgres connection string")
cmd_setup(parser.parse_args())
def main_views() -> None:
parser = argparse.ArgumentParser(description="Apply analytics views to DB")
parser.add_argument(
"--dry-run", action="store_true", help="Print SQL, don't execute"
)
parser.add_argument("--db-url", help="Postgres connection string")
parser.add_argument("--only", help="Comma-separated view names to update")
cmd_views(parser.parse_args())
if __name__ == "__main__":
# Default: apply views (backwards-compatible with direct python invocation)
main_views()

Some files were not shown because too many files have changed in this diff Show More