mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-22 21:48:12 -05:00
Compare commits
1 Commits
master
...
fix/pgvect
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
12690ad0a9 |
@@ -154,16 +154,15 @@ async def store_content_embedding(
|
|||||||
|
|
||||||
# Upsert the embedding
|
# Upsert the embedding
|
||||||
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
|
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
|
||||||
# Use unqualified ::vector - pgvector is in search_path on all environments
|
|
||||||
await execute_raw_with_schema(
|
await execute_raw_with_schema(
|
||||||
"""
|
"""
|
||||||
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
|
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
|
||||||
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
|
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
|
||||||
)
|
)
|
||||||
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
|
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::{schema}.vector, $5, $6::jsonb, NOW(), NOW())
|
||||||
ON CONFLICT ("contentType", "contentId", "userId")
|
ON CONFLICT ("contentType", "contentId", "userId")
|
||||||
DO UPDATE SET
|
DO UPDATE SET
|
||||||
"embedding" = $4::vector,
|
"embedding" = $4::{schema}.vector,
|
||||||
"searchableText" = $5,
|
"searchableText" = $5,
|
||||||
"metadata" = $6::jsonb,
|
"metadata" = $6::jsonb,
|
||||||
"updatedAt" = NOW()
|
"updatedAt" = NOW()
|
||||||
@@ -879,7 +878,6 @@ async def semantic_search(
|
|||||||
min_similarity_idx = len(params) + 1
|
min_similarity_idx = len(params) + 1
|
||||||
params.append(min_similarity)
|
params.append(min_similarity)
|
||||||
|
|
||||||
# Use unqualified ::vector and <=> operator - pgvector is in search_path on all environments
|
|
||||||
sql = (
|
sql = (
|
||||||
"""
|
"""
|
||||||
SELECT
|
SELECT
|
||||||
@@ -889,7 +887,7 @@ async def semantic_search(
|
|||||||
metadata,
|
metadata,
|
||||||
1 - (embedding <=> '"""
|
1 - (embedding <=> '"""
|
||||||
+ embedding_str
|
+ embedding_str
|
||||||
+ """'::vector) as similarity
|
+ """'::{schema}.vector) as similarity
|
||||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||||
WHERE "contentType" IN ("""
|
WHERE "contentType" IN ("""
|
||||||
+ content_type_placeholders
|
+ content_type_placeholders
|
||||||
@@ -899,7 +897,7 @@ async def semantic_search(
|
|||||||
+ """
|
+ """
|
||||||
AND 1 - (embedding <=> '"""
|
AND 1 - (embedding <=> '"""
|
||||||
+ embedding_str
|
+ embedding_str
|
||||||
+ """'::vector) >= $"""
|
+ """'::{schema}.vector) >= $"""
|
||||||
+ str(min_similarity_idx)
|
+ str(min_similarity_idx)
|
||||||
+ """
|
+ """
|
||||||
ORDER BY similarity DESC
|
ORDER BY similarity DESC
|
||||||
|
|||||||
@@ -295,7 +295,7 @@ async def unified_hybrid_search(
|
|||||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||||
WHERE uce."contentType" = ANY({content_types_param}::{{schema_prefix}}"ContentType"[])
|
WHERE uce."contentType" = ANY({content_types_param}::{{schema_prefix}}"ContentType"[])
|
||||||
{user_filter}
|
{user_filter}
|
||||||
ORDER BY uce.embedding <=> {embedding_param}::vector
|
ORDER BY uce.embedding <=> {embedding_param}::{{schema}}.vector
|
||||||
LIMIT 200
|
LIMIT 200
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
@@ -307,7 +307,7 @@ async def unified_hybrid_search(
|
|||||||
uce.metadata,
|
uce.metadata,
|
||||||
uce."updatedAt" as updated_at,
|
uce."updatedAt" as updated_at,
|
||||||
-- Semantic score: cosine similarity (1 - distance)
|
-- Semantic score: cosine similarity (1 - distance)
|
||||||
COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
|
COALESCE(1 - (uce.embedding <=> {embedding_param}::{{schema}}.vector), 0) as semantic_score,
|
||||||
-- Lexical score: ts_rank_cd
|
-- Lexical score: ts_rank_cd
|
||||||
COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
||||||
-- Category match from metadata
|
-- Category match from metadata
|
||||||
@@ -583,7 +583,7 @@ async def hybrid_search(
|
|||||||
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||||
AND uce."userId" IS NULL
|
AND uce."userId" IS NULL
|
||||||
AND {where_clause}
|
AND {where_clause}
|
||||||
ORDER BY uce.embedding <=> {embedding_param}::vector
|
ORDER BY uce.embedding <=> {embedding_param}::{{schema}}.vector
|
||||||
LIMIT 200
|
LIMIT 200
|
||||||
) uce
|
) uce
|
||||||
),
|
),
|
||||||
@@ -605,7 +605,7 @@ async def hybrid_search(
|
|||||||
-- Searchable text for BM25 reranking
|
-- Searchable text for BM25 reranking
|
||||||
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
|
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
|
||||||
-- Semantic score
|
-- Semantic score
|
||||||
COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
|
COALESCE(1 - (uce.embedding <=> {embedding_param}::{{schema}}.vector), 0) as semantic_score,
|
||||||
-- Lexical score (raw, will normalize)
|
-- Lexical score (raw, will normalize)
|
||||||
COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
||||||
-- Category match
|
-- Category match
|
||||||
|
|||||||
@@ -120,12 +120,7 @@ async def _raw_with_schema(
|
|||||||
|
|
||||||
Supports placeholders:
|
Supports placeholders:
|
||||||
- {schema_prefix}: Table/type prefix (e.g., "platform".)
|
- {schema_prefix}: Table/type prefix (e.g., "platform".)
|
||||||
- {schema}: Raw schema name for application tables (e.g., platform)
|
- {schema}: Raw schema name (e.g., platform) for pgvector types
|
||||||
|
|
||||||
Note on pgvector types:
|
|
||||||
Use unqualified ::vector and <=> operator in queries. PostgreSQL resolves
|
|
||||||
these via search_path, which includes the schema where pgvector is installed
|
|
||||||
on all environments (local, CI, dev).
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_template: SQL query with {schema_prefix} and/or {schema} placeholders
|
query_template: SQL query with {schema_prefix} and/or {schema} placeholders
|
||||||
@@ -137,19 +132,16 @@ async def _raw_with_schema(
|
|||||||
- list[dict] if execute=False (query results)
|
- list[dict] if execute=False (query results)
|
||||||
- int if execute=True (number of affected rows)
|
- int if execute=True (number of affected rows)
|
||||||
|
|
||||||
Example with vector type:
|
Example:
|
||||||
await execute_raw_with_schema(
|
await execute_raw_with_schema(
|
||||||
'INSERT INTO {schema_prefix}"Embedding" (vec) VALUES ($1::vector)',
|
'INSERT INTO {schema_prefix}"Embedding" (vec) VALUES ($1::{schema}.vector)',
|
||||||
embedding_data
|
embedding_data
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
schema = get_database_schema()
|
schema = get_database_schema()
|
||||||
schema_prefix = f'"{schema}".' if schema != "public" else ""
|
schema_prefix = f'"{schema}".' if schema != "public" else ""
|
||||||
|
|
||||||
formatted_query = query_template.format(
|
formatted_query = query_template.format(schema_prefix=schema_prefix, schema=schema)
|
||||||
schema_prefix=schema_prefix,
|
|
||||||
schema=schema,
|
|
||||||
)
|
|
||||||
|
|
||||||
import prisma as prisma_module
|
import prisma as prisma_module
|
||||||
|
|
||||||
|
|||||||
@@ -103,18 +103,8 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):
|
|||||||
return redis.get_redis()
|
return redis.get_redis()
|
||||||
|
|
||||||
def publish_event(self, event: M, channel_key: str):
|
def publish_event(self, event: M, channel_key: str):
|
||||||
"""
|
message, full_channel_name = self._serialize_message(event, channel_key)
|
||||||
Publish an event to Redis. Gracefully handles connection failures
|
self.connection.publish(full_channel_name, message)
|
||||||
by logging the error instead of raising exceptions.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
message, full_channel_name = self._serialize_message(event, channel_key)
|
|
||||||
self.connection.publish(full_channel_name, message)
|
|
||||||
except Exception:
|
|
||||||
logger.exception(
|
|
||||||
f"Failed to publish event to Redis channel {channel_key}. "
|
|
||||||
"Event bus operation will continue without Redis connectivity."
|
|
||||||
)
|
|
||||||
|
|
||||||
def listen_events(self, channel_key: str) -> Generator[M, None, None]:
|
def listen_events(self, channel_key: str) -> Generator[M, None, None]:
|
||||||
pubsub, full_channel_name = self._get_pubsub_channel(
|
pubsub, full_channel_name = self._get_pubsub_channel(
|
||||||
@@ -138,19 +128,9 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
|
|||||||
return await redis.get_redis_async()
|
return await redis.get_redis_async()
|
||||||
|
|
||||||
async def publish_event(self, event: M, channel_key: str):
|
async def publish_event(self, event: M, channel_key: str):
|
||||||
"""
|
message, full_channel_name = self._serialize_message(event, channel_key)
|
||||||
Publish an event to Redis. Gracefully handles connection failures
|
connection = await self.connection
|
||||||
by logging the error instead of raising exceptions.
|
await connection.publish(full_channel_name, message)
|
||||||
"""
|
|
||||||
try:
|
|
||||||
message, full_channel_name = self._serialize_message(event, channel_key)
|
|
||||||
connection = await self.connection
|
|
||||||
await connection.publish(full_channel_name, message)
|
|
||||||
except Exception:
|
|
||||||
logger.exception(
|
|
||||||
f"Failed to publish event to Redis channel {channel_key}. "
|
|
||||||
"Event bus operation will continue without Redis connectivity."
|
|
||||||
)
|
|
||||||
|
|
||||||
async def listen_events(self, channel_key: str) -> AsyncGenerator[M, None]:
|
async def listen_events(self, channel_key: str) -> AsyncGenerator[M, None]:
|
||||||
pubsub, full_channel_name = self._get_pubsub_channel(
|
pubsub, full_channel_name = self._get_pubsub_channel(
|
||||||
|
|||||||
@@ -1,56 +0,0 @@
|
|||||||
"""
|
|
||||||
Tests for event_bus graceful degradation when Redis is unavailable.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from backend.data.event_bus import AsyncRedisEventBus
|
|
||||||
|
|
||||||
|
|
||||||
class TestEvent(BaseModel):
|
|
||||||
"""Test event model."""
|
|
||||||
|
|
||||||
message: str
|
|
||||||
|
|
||||||
|
|
||||||
class TestNotificationBus(AsyncRedisEventBus[TestEvent]):
|
|
||||||
"""Test implementation of AsyncRedisEventBus."""
|
|
||||||
|
|
||||||
Model = TestEvent
|
|
||||||
|
|
||||||
@property
|
|
||||||
def event_bus_name(self) -> str:
|
|
||||||
return "test_event_bus"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_publish_event_handles_connection_failure_gracefully():
|
|
||||||
"""Test that publish_event logs exception instead of raising when Redis is unavailable."""
|
|
||||||
bus = TestNotificationBus()
|
|
||||||
event = TestEvent(message="test message")
|
|
||||||
|
|
||||||
# Mock get_redis_async to raise connection error
|
|
||||||
with patch(
|
|
||||||
"backend.data.event_bus.redis.get_redis_async",
|
|
||||||
side_effect=ConnectionError("Authentication required."),
|
|
||||||
):
|
|
||||||
# Should not raise exception
|
|
||||||
await bus.publish_event(event, "test_channel")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_publish_event_works_with_redis_available():
|
|
||||||
"""Test that publish_event works normally when Redis is available."""
|
|
||||||
bus = TestNotificationBus()
|
|
||||||
event = TestEvent(message="test message")
|
|
||||||
|
|
||||||
# Mock successful Redis connection
|
|
||||||
mock_redis = AsyncMock()
|
|
||||||
mock_redis.publish = AsyncMock()
|
|
||||||
|
|
||||||
with patch("backend.data.event_bus.redis.get_redis_async", return_value=mock_redis):
|
|
||||||
await bus.publish_event(event, "test_channel")
|
|
||||||
mock_redis.publish.assert_called_once()
|
|
||||||
@@ -81,8 +81,6 @@ class ExecutionContext(BaseModel):
|
|||||||
This includes information needed by blocks, sub-graphs, and execution management.
|
This includes information needed by blocks, sub-graphs, and execution management.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_config = {"extra": "ignore"}
|
|
||||||
|
|
||||||
human_in_the_loop_safe_mode: bool = True
|
human_in_the_loop_safe_mode: bool = True
|
||||||
sensitive_action_safe_mode: bool = False
|
sensitive_action_safe_mode: bool = False
|
||||||
user_timezone: str = "UTC"
|
user_timezone: str = "UTC"
|
||||||
|
|||||||
@@ -64,8 +64,6 @@ logger = logging.getLogger(__name__)
|
|||||||
class GraphSettings(BaseModel):
|
class GraphSettings(BaseModel):
|
||||||
# Use Annotated with BeforeValidator to coerce None to default values.
|
# Use Annotated with BeforeValidator to coerce None to default values.
|
||||||
# This handles cases where the database has null values for these fields.
|
# This handles cases where the database has null values for these fields.
|
||||||
model_config = {"extra": "ignore"}
|
|
||||||
|
|
||||||
human_in_the_loop_safe_mode: Annotated[
|
human_in_the_loop_safe_mode: Annotated[
|
||||||
bool, BeforeValidator(lambda v: v if v is not None else True)
|
bool, BeforeValidator(lambda v: v if v is not None else True)
|
||||||
] = True
|
] = True
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
-- CreateExtension
|
-- CreateExtension
|
||||||
-- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first
|
-- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first
|
||||||
-- Creates extension in current schema (determined by search_path from DATABASE_URL ?schema= param)
|
-- Create in public schema so vector type is available across all schemas
|
||||||
-- This ensures vector type is in the same schema as tables, making ::vector work without explicit qualification
|
|
||||||
DO $$
|
DO $$
|
||||||
BEGIN
|
BEGIN
|
||||||
CREATE EXTENSION IF NOT EXISTS "vector";
|
CREATE EXTENSION IF NOT EXISTS "vector" WITH SCHEMA "public";
|
||||||
EXCEPTION WHEN OTHERS THEN
|
EXCEPTION WHEN OTHERS THEN
|
||||||
RAISE NOTICE 'vector extension not available or already exists, skipping';
|
RAISE NOTICE 'vector extension not available or already exists, skipping';
|
||||||
END $$;
|
END $$;
|
||||||
@@ -20,7 +19,7 @@ CREATE TABLE "UnifiedContentEmbedding" (
|
|||||||
"contentType" "ContentType" NOT NULL,
|
"contentType" "ContentType" NOT NULL,
|
||||||
"contentId" TEXT NOT NULL,
|
"contentId" TEXT NOT NULL,
|
||||||
"userId" TEXT,
|
"userId" TEXT,
|
||||||
"embedding" vector(1536) NOT NULL,
|
"embedding" public.vector(1536) NOT NULL,
|
||||||
"searchableText" TEXT NOT NULL,
|
"searchableText" TEXT NOT NULL,
|
||||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||||
|
|
||||||
@@ -46,4 +45,4 @@ CREATE UNIQUE INDEX "UnifiedContentEmbedding_contentType_contentId_userId_key" O
|
|||||||
-- Uses cosine distance operator (<=>), which matches the query in hybrid_search.py
|
-- Uses cosine distance operator (<=>), which matches the query in hybrid_search.py
|
||||||
-- Note: Drop first in case Prisma created a btree index (Prisma doesn't support HNSW)
|
-- Note: Drop first in case Prisma created a btree index (Prisma doesn't support HNSW)
|
||||||
DROP INDEX IF EXISTS "UnifiedContentEmbedding_embedding_idx";
|
DROP INDEX IF EXISTS "UnifiedContentEmbedding_embedding_idx";
|
||||||
CREATE INDEX "UnifiedContentEmbedding_embedding_idx" ON "UnifiedContentEmbedding" USING hnsw ("embedding" vector_cosine_ops);
|
CREATE INDEX "UnifiedContentEmbedding_embedding_idx" ON "UnifiedContentEmbedding" USING hnsw ("embedding" public.vector_cosine_ops);
|
||||||
|
|||||||
@@ -366,12 +366,12 @@ def generate_block_markdown(
|
|||||||
lines.append("")
|
lines.append("")
|
||||||
|
|
||||||
# What it is (full description)
|
# What it is (full description)
|
||||||
lines.append("### What it is")
|
lines.append(f"### What it is")
|
||||||
lines.append(block.description or "No description available.")
|
lines.append(block.description or "No description available.")
|
||||||
lines.append("")
|
lines.append("")
|
||||||
|
|
||||||
# How it works (manual section)
|
# How it works (manual section)
|
||||||
lines.append("### How it works")
|
lines.append(f"### How it works")
|
||||||
how_it_works = manual_content.get(
|
how_it_works = manual_content.get(
|
||||||
"how_it_works", "_Add technical explanation here._"
|
"how_it_works", "_Add technical explanation here._"
|
||||||
)
|
)
|
||||||
@@ -383,7 +383,7 @@ def generate_block_markdown(
|
|||||||
# Inputs table (auto-generated)
|
# Inputs table (auto-generated)
|
||||||
visible_inputs = [f for f in block.inputs if not f.hidden]
|
visible_inputs = [f for f in block.inputs if not f.hidden]
|
||||||
if visible_inputs:
|
if visible_inputs:
|
||||||
lines.append("### Inputs")
|
lines.append(f"### Inputs")
|
||||||
lines.append("")
|
lines.append("")
|
||||||
lines.append("| Input | Description | Type | Required |")
|
lines.append("| Input | Description | Type | Required |")
|
||||||
lines.append("|-------|-------------|------|----------|")
|
lines.append("|-------|-------------|------|----------|")
|
||||||
@@ -400,7 +400,7 @@ def generate_block_markdown(
|
|||||||
# Outputs table (auto-generated)
|
# Outputs table (auto-generated)
|
||||||
visible_outputs = [f for f in block.outputs if not f.hidden]
|
visible_outputs = [f for f in block.outputs if not f.hidden]
|
||||||
if visible_outputs:
|
if visible_outputs:
|
||||||
lines.append("### Outputs")
|
lines.append(f"### Outputs")
|
||||||
lines.append("")
|
lines.append("")
|
||||||
lines.append("| Output | Description | Type |")
|
lines.append("| Output | Description | Type |")
|
||||||
lines.append("|--------|-------------|------|")
|
lines.append("|--------|-------------|------|")
|
||||||
@@ -414,7 +414,7 @@ def generate_block_markdown(
|
|||||||
lines.append("")
|
lines.append("")
|
||||||
|
|
||||||
# Possible use case (manual section)
|
# Possible use case (manual section)
|
||||||
lines.append("### Possible use case")
|
lines.append(f"### Possible use case")
|
||||||
use_case = manual_content.get("use_case", "_Add practical use case examples here._")
|
use_case = manual_content.get("use_case", "_Add practical use case examples here._")
|
||||||
lines.append("<!-- MANUAL: use_case -->")
|
lines.append("<!-- MANUAL: use_case -->")
|
||||||
lines.append(use_case)
|
lines.append(use_case)
|
||||||
|
|||||||
@@ -1,15 +1,8 @@
|
|||||||
"use client";
|
"use client";
|
||||||
import React, {
|
import React, { useCallback, useEffect, useMemo, useState } from "react";
|
||||||
useCallback,
|
|
||||||
useContext,
|
|
||||||
useEffect,
|
|
||||||
useMemo,
|
|
||||||
useState,
|
|
||||||
} from "react";
|
|
||||||
|
|
||||||
import {
|
import {
|
||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
CredentialsType,
|
|
||||||
GraphExecutionID,
|
GraphExecutionID,
|
||||||
GraphMeta,
|
GraphMeta,
|
||||||
LibraryAgentPreset,
|
LibraryAgentPreset,
|
||||||
@@ -36,11 +29,7 @@ import {
|
|||||||
} from "@/components/__legacy__/ui/icons";
|
} from "@/components/__legacy__/ui/icons";
|
||||||
import { Input } from "@/components/__legacy__/ui/input";
|
import { Input } from "@/components/__legacy__/ui/input";
|
||||||
import { Button } from "@/components/atoms/Button/Button";
|
import { Button } from "@/components/atoms/Button/Button";
|
||||||
import { CredentialsGroupedView } from "@/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView";
|
import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput";
|
||||||
import {
|
|
||||||
findSavedCredentialByProviderAndType,
|
|
||||||
findSavedUserCredentialByProviderAndType,
|
|
||||||
} from "@/components/contextual/CredentialsInput/components/CredentialsGroupedView/helpers";
|
|
||||||
import { InformationTooltip } from "@/components/molecules/InformationTooltip/InformationTooltip";
|
import { InformationTooltip } from "@/components/molecules/InformationTooltip/InformationTooltip";
|
||||||
import {
|
import {
|
||||||
useToast,
|
useToast,
|
||||||
@@ -48,7 +37,6 @@ import {
|
|||||||
} from "@/components/molecules/Toast/use-toast";
|
} from "@/components/molecules/Toast/use-toast";
|
||||||
import { humanizeCronExpression } from "@/lib/cron-expression-utils";
|
import { humanizeCronExpression } from "@/lib/cron-expression-utils";
|
||||||
import { cn, isEmpty } from "@/lib/utils";
|
import { cn, isEmpty } from "@/lib/utils";
|
||||||
import { CredentialsProvidersContext } from "@/providers/agent-credentials/credentials-provider";
|
|
||||||
import { ClockIcon, CopyIcon, InfoIcon } from "@phosphor-icons/react";
|
import { ClockIcon, CopyIcon, InfoIcon } from "@phosphor-icons/react";
|
||||||
import { CalendarClockIcon, Trash2Icon } from "lucide-react";
|
import { CalendarClockIcon, Trash2Icon } from "lucide-react";
|
||||||
|
|
||||||
@@ -102,7 +90,6 @@ export function AgentRunDraftView({
|
|||||||
const api = useBackendAPI();
|
const api = useBackendAPI();
|
||||||
const { toast } = useToast();
|
const { toast } = useToast();
|
||||||
const toastOnFail = useToastOnFail();
|
const toastOnFail = useToastOnFail();
|
||||||
const allProviders = useContext(CredentialsProvidersContext);
|
|
||||||
|
|
||||||
const [inputValues, setInputValues] = useState<Record<string, any>>({});
|
const [inputValues, setInputValues] = useState<Record<string, any>>({});
|
||||||
const [inputCredentials, setInputCredentials] = useState<
|
const [inputCredentials, setInputCredentials] = useState<
|
||||||
@@ -141,77 +128,6 @@ export function AgentRunDraftView({
|
|||||||
() => graph.credentials_input_schema.properties,
|
() => graph.credentials_input_schema.properties,
|
||||||
[graph],
|
[graph],
|
||||||
);
|
);
|
||||||
const credentialFields = useMemo(
|
|
||||||
function getCredentialFields() {
|
|
||||||
return Object.entries(agentCredentialsInputFields);
|
|
||||||
},
|
|
||||||
[agentCredentialsInputFields],
|
|
||||||
);
|
|
||||||
const requiredCredentials = useMemo(
|
|
||||||
function getRequiredCredentials() {
|
|
||||||
return new Set(
|
|
||||||
(graph.credentials_input_schema?.required as string[]) || [],
|
|
||||||
);
|
|
||||||
},
|
|
||||||
[graph.credentials_input_schema?.required],
|
|
||||||
);
|
|
||||||
|
|
||||||
useEffect(
|
|
||||||
function initializeDefaultCredentials() {
|
|
||||||
if (!allProviders) return;
|
|
||||||
if (!graph.credentials_input_schema?.properties) return;
|
|
||||||
if (requiredCredentials.size === 0) return;
|
|
||||||
|
|
||||||
setInputCredentials(function updateCredentials(currentCreds) {
|
|
||||||
const next = { ...currentCreds };
|
|
||||||
let didAdd = false;
|
|
||||||
|
|
||||||
for (const key of requiredCredentials) {
|
|
||||||
if (next[key]) continue;
|
|
||||||
const schema = graph.credentials_input_schema.properties[key];
|
|
||||||
if (!schema) continue;
|
|
||||||
|
|
||||||
const providerNames = schema.credentials_provider || [];
|
|
||||||
const credentialTypes = schema.credentials_types || [];
|
|
||||||
const requiredScopes = schema.credentials_scopes;
|
|
||||||
|
|
||||||
const userCredential = findSavedUserCredentialByProviderAndType(
|
|
||||||
providerNames,
|
|
||||||
credentialTypes,
|
|
||||||
requiredScopes,
|
|
||||||
allProviders,
|
|
||||||
);
|
|
||||||
|
|
||||||
const savedCredential =
|
|
||||||
userCredential ||
|
|
||||||
findSavedCredentialByProviderAndType(
|
|
||||||
providerNames,
|
|
||||||
credentialTypes,
|
|
||||||
requiredScopes,
|
|
||||||
allProviders,
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!savedCredential) continue;
|
|
||||||
|
|
||||||
next[key] = {
|
|
||||||
id: savedCredential.id,
|
|
||||||
provider: savedCredential.provider,
|
|
||||||
type: savedCredential.type as CredentialsType,
|
|
||||||
title: savedCredential.title,
|
|
||||||
};
|
|
||||||
didAdd = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!didAdd) return currentCreds;
|
|
||||||
return next;
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[
|
|
||||||
allProviders,
|
|
||||||
graph.credentials_input_schema?.properties,
|
|
||||||
requiredCredentials,
|
|
||||||
],
|
|
||||||
);
|
|
||||||
|
|
||||||
const [allRequiredInputsAreSet, missingInputs] = useMemo(() => {
|
const [allRequiredInputsAreSet, missingInputs] = useMemo(() => {
|
||||||
const nonEmptyInputs = new Set(
|
const nonEmptyInputs = new Set(
|
||||||
@@ -229,35 +145,18 @@ export function AgentRunDraftView({
|
|||||||
);
|
);
|
||||||
return [isSuperset, difference];
|
return [isSuperset, difference];
|
||||||
}, [agentInputSchema.required, inputValues]);
|
}, [agentInputSchema.required, inputValues]);
|
||||||
const [allCredentialsAreSet, missingCredentials] = useMemo(
|
const [allCredentialsAreSet, missingCredentials] = useMemo(() => {
|
||||||
function getCredentialStatus() {
|
const availableCredentials = new Set(Object.keys(inputCredentials));
|
||||||
const missing = Array.from(requiredCredentials).filter((key) => {
|
const allCredentials = new Set(Object.keys(agentCredentialsInputFields));
|
||||||
const cred = inputCredentials[key];
|
// Backwards-compatible implementation of isSupersetOf and difference
|
||||||
return !cred || !cred.id;
|
const isSuperset = Array.from(allCredentials).every((item) =>
|
||||||
});
|
availableCredentials.has(item),
|
||||||
return [missing.length === 0, missing];
|
);
|
||||||
},
|
const difference = Array.from(allCredentials).filter(
|
||||||
[requiredCredentials, inputCredentials],
|
(item) => !availableCredentials.has(item),
|
||||||
);
|
);
|
||||||
function addChangedCredentials(prev: Set<keyof LibraryAgentPresetUpdatable>) {
|
return [isSuperset, difference];
|
||||||
const next = new Set(prev);
|
}, [agentCredentialsInputFields, inputCredentials]);
|
||||||
next.add("credentials");
|
|
||||||
return next;
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleCredentialChange(key: string, value?: CredentialsMetaInput) {
|
|
||||||
setInputCredentials(function updateInputCredentials(currentCreds) {
|
|
||||||
const next = { ...currentCreds };
|
|
||||||
if (value === undefined) {
|
|
||||||
delete next[key];
|
|
||||||
return next;
|
|
||||||
}
|
|
||||||
next[key] = value;
|
|
||||||
return next;
|
|
||||||
});
|
|
||||||
setChangedPresetAttributes(addChangedCredentials);
|
|
||||||
}
|
|
||||||
|
|
||||||
const notifyMissingInputs = useCallback(
|
const notifyMissingInputs = useCallback(
|
||||||
(needPresetName: boolean = true) => {
|
(needPresetName: boolean = true) => {
|
||||||
const allMissingFields = (
|
const allMissingFields = (
|
||||||
@@ -750,6 +649,35 @@ export function AgentRunDraftView({
|
|||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{/* Credentials inputs */}
|
||||||
|
{Object.entries(agentCredentialsInputFields).map(
|
||||||
|
([key, inputSubSchema]) => (
|
||||||
|
<CredentialsInput
|
||||||
|
key={key}
|
||||||
|
schema={{ ...inputSubSchema, discriminator: undefined }}
|
||||||
|
selectedCredentials={
|
||||||
|
inputCredentials[key] ?? inputSubSchema.default
|
||||||
|
}
|
||||||
|
onSelectCredentials={(value) => {
|
||||||
|
setInputCredentials((obj) => {
|
||||||
|
const newObj = { ...obj };
|
||||||
|
if (value === undefined) {
|
||||||
|
delete newObj[key];
|
||||||
|
return newObj;
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
...obj,
|
||||||
|
[key]: value,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
setChangedPresetAttributes((prev) =>
|
||||||
|
prev.add("credentials"),
|
||||||
|
);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
),
|
||||||
|
)}
|
||||||
|
|
||||||
{/* Regular inputs */}
|
{/* Regular inputs */}
|
||||||
{Object.entries(agentInputFields).map(([key, inputSubSchema]) => (
|
{Object.entries(agentInputFields).map(([key, inputSubSchema]) => (
|
||||||
<RunAgentInputs
|
<RunAgentInputs
|
||||||
@@ -767,17 +695,6 @@ export function AgentRunDraftView({
|
|||||||
data-testid={`agent-input-${key}`}
|
data-testid={`agent-input-${key}`}
|
||||||
/>
|
/>
|
||||||
))}
|
))}
|
||||||
|
|
||||||
{/* Credentials inputs */}
|
|
||||||
{credentialFields.length > 0 && (
|
|
||||||
<CredentialsGroupedView
|
|
||||||
credentialFields={credentialFields}
|
|
||||||
requiredCredentials={requiredCredentials}
|
|
||||||
inputCredentials={inputCredentials}
|
|
||||||
inputValues={inputValues}
|
|
||||||
onCredentialChange={handleCredentialChange}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import { CredentialsProvidersContextType } from "@/providers/agent-credentials/credentials-provider";
|
import { CredentialsProvidersContextType } from "@/providers/agent-credentials/credentials-provider";
|
||||||
import { filterSystemCredentials, getSystemCredentials } from "../../helpers";
|
import { getSystemCredentials } from "../../helpers";
|
||||||
|
|
||||||
export type CredentialField = [string, any];
|
export type CredentialField = [string, any];
|
||||||
|
|
||||||
@@ -208,42 +208,3 @@ export function findSavedCredentialByProviderAndType(
|
|||||||
|
|
||||||
return undefined;
|
return undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function findSavedUserCredentialByProviderAndType(
|
|
||||||
providerNames: string[],
|
|
||||||
credentialTypes: string[],
|
|
||||||
requiredScopes: string[] | undefined,
|
|
||||||
allProviders: CredentialsProvidersContextType | null,
|
|
||||||
): SavedCredential | undefined {
|
|
||||||
for (const providerName of providerNames) {
|
|
||||||
const providerData = allProviders?.[providerName];
|
|
||||||
if (!providerData) continue;
|
|
||||||
|
|
||||||
const userCredentials = filterSystemCredentials(
|
|
||||||
providerData.savedCredentials ?? [],
|
|
||||||
);
|
|
||||||
|
|
||||||
const matchingCredentials: SavedCredential[] = [];
|
|
||||||
|
|
||||||
for (const credential of userCredentials) {
|
|
||||||
const typeMatches =
|
|
||||||
credentialTypes.length === 0 ||
|
|
||||||
credentialTypes.includes(credential.type);
|
|
||||||
const scopesMatch = hasRequiredScopes(credential, requiredScopes);
|
|
||||||
|
|
||||||
if (!typeMatches) continue;
|
|
||||||
if (!scopesMatch) continue;
|
|
||||||
|
|
||||||
matchingCredentials.push(credential as SavedCredential);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (matchingCredentials.length === 1) {
|
|
||||||
return matchingCredentials[0];
|
|
||||||
}
|
|
||||||
if (matchingCredentials.length > 1) {
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -98,20 +98,24 @@ export function useCredentialsInput({
|
|||||||
|
|
||||||
// Auto-select the first available credential on initial mount
|
// Auto-select the first available credential on initial mount
|
||||||
// Once a user has made a selection, we don't override it
|
// Once a user has made a selection, we don't override it
|
||||||
useEffect(
|
useEffect(() => {
|
||||||
function autoSelectCredential() {
|
if (readOnly) return;
|
||||||
if (readOnly) return;
|
if (!credentials || !("savedCredentials" in credentials)) return;
|
||||||
if (!credentials || !("savedCredentials" in credentials)) return;
|
|
||||||
if (selectedCredential?.id) return;
|
|
||||||
|
|
||||||
const savedCreds = credentials.savedCredentials;
|
// If already selected, don't auto-select
|
||||||
if (savedCreds.length === 0) return;
|
if (selectedCredential?.id) return;
|
||||||
|
|
||||||
if (hasAttemptedAutoSelect.current) return;
|
// Only attempt auto-selection once
|
||||||
hasAttemptedAutoSelect.current = true;
|
if (hasAttemptedAutoSelect.current) return;
|
||||||
|
hasAttemptedAutoSelect.current = true;
|
||||||
|
|
||||||
if (isOptional) return;
|
// If optional, don't auto-select (user can choose "None")
|
||||||
|
if (isOptional) return;
|
||||||
|
|
||||||
|
const savedCreds = credentials.savedCredentials;
|
||||||
|
|
||||||
|
// Auto-select the first credential if any are available
|
||||||
|
if (savedCreds.length > 0) {
|
||||||
const cred = savedCreds[0];
|
const cred = savedCreds[0];
|
||||||
onSelectCredential({
|
onSelectCredential({
|
||||||
id: cred.id,
|
id: cred.id,
|
||||||
@@ -119,15 +123,14 @@ export function useCredentialsInput({
|
|||||||
provider: credentials.provider,
|
provider: credentials.provider,
|
||||||
title: (cred as any).title,
|
title: (cred as any).title,
|
||||||
});
|
});
|
||||||
},
|
}
|
||||||
[
|
}, [
|
||||||
credentials,
|
credentials,
|
||||||
selectedCredential?.id,
|
selectedCredential?.id,
|
||||||
readOnly,
|
readOnly,
|
||||||
isOptional,
|
isOptional,
|
||||||
onSelectCredential,
|
onSelectCredential,
|
||||||
],
|
]);
|
||||||
);
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
!credentials ||
|
!credentials ||
|
||||||
|
|||||||
@@ -106,14 +106,9 @@ export function getTimezoneDisplayName(timezone: string): string {
|
|||||||
const parts = timezone.split("/");
|
const parts = timezone.split("/");
|
||||||
const city = parts[parts.length - 1].replace(/_/g, " ");
|
const city = parts[parts.length - 1].replace(/_/g, " ");
|
||||||
const abbr = getTimezoneAbbreviation(timezone);
|
const abbr = getTimezoneAbbreviation(timezone);
|
||||||
if (abbr && abbr !== timezone) {
|
return abbr ? `${city} (${abbr})` : city;
|
||||||
return `${city} (${abbr})`;
|
|
||||||
}
|
|
||||||
// If abbreviation is same as timezone or not found, show timezone with underscores replaced
|
|
||||||
const timezoneDisplay = timezone.replace(/_/g, " ");
|
|
||||||
return `${city} (${timezoneDisplay})`;
|
|
||||||
} catch {
|
} catch {
|
||||||
return timezone.replace(/_/g, " ");
|
return timezone;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import { LoginPage } from "./pages/login.page";
|
|||||||
import { MarketplacePage } from "./pages/marketplace.page";
|
import { MarketplacePage } from "./pages/marketplace.page";
|
||||||
import { hasMinCount, hasUrl, isVisible, matchesUrl } from "./utils/assertion";
|
import { hasMinCount, hasUrl, isVisible, matchesUrl } from "./utils/assertion";
|
||||||
|
|
||||||
// Marketplace tests for store agent search functionality
|
|
||||||
test.describe("Marketplace – Basic Functionality", () => {
|
test.describe("Marketplace – Basic Functionality", () => {
|
||||||
test("User can access marketplace page when logged out", async ({ page }) => {
|
test("User can access marketplace page when logged out", async ({ page }) => {
|
||||||
const marketplacePage = new MarketplacePage(page);
|
const marketplacePage = new MarketplacePage(page);
|
||||||
|
|||||||
Reference in New Issue
Block a user