mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 15:17:59 -05:00
Revert "Merge branch 'swiftyos/secrt-1709-store-provider-names-and-en… (#11225)
Changes to providers blocks to store in db ### Changes 🏗️ - revet change ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: <!-- Put your test plan here: --> - [x] I have reverted the merge
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
# Import the provider builder to ensure it's registered
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
|
||||
from .triggers import GenericWebhookTriggerBlock, generic_webhook
|
||||
|
||||
# Ensure the SDK registry is patched to include our webhook manager
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
__all__ = ["GenericWebhookTriggerBlock", "generic_webhook"]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import collections
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
@@ -10,7 +9,6 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Generic,
|
||||
Optional,
|
||||
Sequence,
|
||||
@@ -22,8 +20,7 @@ from typing import (
|
||||
|
||||
import jsonref
|
||||
import jsonschema
|
||||
from prisma import Json
|
||||
from prisma.models import AgentBlock, BlocksRegistry
|
||||
from prisma.models import AgentBlock
|
||||
from prisma.types import AgentBlockCreateInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -482,50 +479,19 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
return self.__class__.__name__
|
||||
|
||||
def to_dict(self):
|
||||
# Sort categories by their name to ensure consistent ordering
|
||||
sorted_categories = [
|
||||
category.dict()
|
||||
for category in sorted(self.categories, key=lambda c: c.name)
|
||||
]
|
||||
|
||||
# Sort dictionary keys recursively for consistent ordering
|
||||
def sort_dict(obj):
|
||||
if isinstance(obj, dict):
|
||||
return collections.OrderedDict(
|
||||
sorted((k, sort_dict(v)) for k, v in obj.items())
|
||||
)
|
||||
elif isinstance(obj, list):
|
||||
# Check if all items in the list are primitive types that can be sorted
|
||||
if obj and all(
|
||||
isinstance(item, (str, int, float, bool, type(None)))
|
||||
for item in obj
|
||||
):
|
||||
# Sort primitive lists for consistent ordering
|
||||
return sorted(obj, key=lambda x: (x is None, str(x)))
|
||||
else:
|
||||
# For lists of complex objects, process each item but maintain order
|
||||
return [sort_dict(item) for item in obj]
|
||||
return obj
|
||||
|
||||
return collections.OrderedDict(
|
||||
[
|
||||
("id", self.id),
|
||||
("name", self.name),
|
||||
("inputSchema", sort_dict(self.input_schema.jsonschema())),
|
||||
("outputSchema", sort_dict(self.output_schema.jsonschema())),
|
||||
("description", self.description),
|
||||
("categories", sorted_categories),
|
||||
(
|
||||
"contributors",
|
||||
sorted(
|
||||
[contributor.model_dump() for contributor in self.contributors],
|
||||
key=lambda c: (c.get("name", ""), c.get("username", "")),
|
||||
),
|
||||
),
|
||||
("staticOutput", self.static_output),
|
||||
("uiType", self.block_type.value),
|
||||
]
|
||||
)
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"inputSchema": self.input_schema.jsonschema(),
|
||||
"outputSchema": self.output_schema.jsonschema(),
|
||||
"description": self.description,
|
||||
"categories": [category.dict() for category in self.categories],
|
||||
"contributors": [
|
||||
contributor.model_dump() for contributor in self.contributors
|
||||
],
|
||||
"staticOutput": self.static_output,
|
||||
"uiType": self.block_type.value,
|
||||
}
|
||||
|
||||
def get_info(self) -> BlockInfo:
|
||||
from backend.data.credit import get_block_cost
|
||||
@@ -772,123 +738,3 @@ def get_io_block_ids() -> Sequence[str]:
|
||||
for id, B in get_blocks().items()
|
||||
if B().block_type in (BlockType.INPUT, BlockType.OUTPUT)
|
||||
]
|
||||
|
||||
|
||||
async def get_block_registry() -> Dict[str, BlocksRegistry]:
|
||||
"""
|
||||
Retrieves the BlocksRegistry from the database and returns a dictionary mapping
|
||||
block names to BlocksRegistry objects.
|
||||
|
||||
Returns:
|
||||
Dict[str, BlocksRegistry]: A dictionary where each key is a block name and
|
||||
each value is a BlocksRegistry instance.
|
||||
"""
|
||||
blocks = await BlocksRegistry.prisma().find_many()
|
||||
return {block.id: block for block in blocks}
|
||||
|
||||
|
||||
def recursive_json_compare(
|
||||
db_block_definition: Any, local_block_definition: Any
|
||||
) -> bool:
|
||||
"""
|
||||
Recursively compares two JSON objects for equality.
|
||||
|
||||
Args:
|
||||
db_block_definition (Any): The JSON object from the database.
|
||||
local_block_definition (Any): The local JSON object to compare against.
|
||||
|
||||
Returns:
|
||||
bool: True if the objects are equal, False otherwise.
|
||||
"""
|
||||
if isinstance(db_block_definition, dict) and isinstance(
|
||||
local_block_definition, dict
|
||||
):
|
||||
if set(db_block_definition.keys()) != set(local_block_definition.keys()):
|
||||
logger.error(
|
||||
f"Keys are not the same: {set(db_block_definition.keys())} != {set(local_block_definition.keys())}"
|
||||
)
|
||||
return False
|
||||
return all(
|
||||
recursive_json_compare(db_block_definition[k], local_block_definition[k])
|
||||
for k in db_block_definition
|
||||
)
|
||||
values_are_same = db_block_definition == local_block_definition
|
||||
if not values_are_same:
|
||||
logger.error(
|
||||
f"Values are not the same: {db_block_definition} != {local_block_definition}"
|
||||
)
|
||||
return values_are_same
|
||||
|
||||
|
||||
def check_block_same(db_block: BlocksRegistry, local_block: Block) -> bool:
|
||||
"""
|
||||
Compares a database block with a local block.
|
||||
|
||||
Args:
|
||||
db_block (BlocksRegistry): The block object from the database registry.
|
||||
local_block (Block[BlockSchema, BlockSchema]): The local block definition.
|
||||
|
||||
Returns:
|
||||
bool: True if the blocks are equal, False otherwise.
|
||||
"""
|
||||
local_block_instance = local_block() # type: ignore
|
||||
local_block_definition = local_block_instance.to_dict()
|
||||
db_block_definition = db_block.definition
|
||||
is_same = recursive_json_compare(db_block_definition, local_block_definition)
|
||||
return is_same
|
||||
|
||||
|
||||
def find_delta_blocks(
|
||||
db_blocks: Dict[str, BlocksRegistry], local_blocks: Dict[str, Block]
|
||||
) -> Dict[str, Block]:
|
||||
"""
|
||||
Finds the set of blocks that are new or changed compared to the database.
|
||||
|
||||
Args:
|
||||
db_blocks (Dict[str, BlocksRegistry]): Existing blocks from the database, keyed by name.
|
||||
local_blocks (Dict[str, Block]): Local block definitions, keyed by name.
|
||||
|
||||
Returns:
|
||||
Dict[str, Block]: Blocks that are missing from or different than the database, keyed by name.
|
||||
"""
|
||||
block_update: Dict[str, Block] = {}
|
||||
for block_id, block in local_blocks.items():
|
||||
if block_id not in db_blocks:
|
||||
block_update[block_id] = block
|
||||
else:
|
||||
if not check_block_same(db_blocks[block_id], block):
|
||||
block_update[block_id] = block
|
||||
return block_update
|
||||
|
||||
|
||||
async def upsert_blocks_change_bulk(blocks: Dict[str, Block]):
|
||||
"""
|
||||
Bulk upserts blocks into the database if changed.
|
||||
|
||||
- Compares the provided local blocks to those in the database via their definition.
|
||||
- Inserts new or updated blocks.
|
||||
|
||||
Args:
|
||||
blocks (Dict[str, Block]): Local block definitions to upsert.
|
||||
|
||||
Returns:
|
||||
Dict[str, Block]: Blocks that were new or changed and upserted.
|
||||
"""
|
||||
db_blocks = await get_block_registry()
|
||||
block_update = find_delta_blocks(db_blocks, blocks)
|
||||
for block_id, block in block_update.items():
|
||||
await BlocksRegistry.prisma().upsert(
|
||||
where={"id": block_id},
|
||||
data={
|
||||
"create": {
|
||||
"id": block_id,
|
||||
"name": block().__class__.__name__, # type: ignore
|
||||
"definition": Json(block.to_dict(block())), # type: ignore
|
||||
},
|
||||
"update": {
|
||||
"name": block().__class__.__name__, # type: ignore
|
||||
"definition": Json(block.to_dict(block())), # type: ignore
|
||||
},
|
||||
},
|
||||
)
|
||||
return block_update
|
||||
|
||||
@@ -1,191 +0,0 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from prisma.models import BlocksRegistry
|
||||
|
||||
from backend.blocks.basic import (
|
||||
FileStoreBlock,
|
||||
PrintToConsoleBlock,
|
||||
ReverseListOrderBlock,
|
||||
StoreValueBlock,
|
||||
)
|
||||
from backend.data.block import (
|
||||
check_block_same,
|
||||
find_delta_blocks,
|
||||
recursive_json_compare,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recursive_json_compare():
|
||||
db_block_definition = {
|
||||
"a": 1,
|
||||
"b": 2,
|
||||
"c": 3,
|
||||
}
|
||||
local_block_definition = {
|
||||
"a": 1,
|
||||
"b": 2,
|
||||
"c": 3,
|
||||
}
|
||||
assert recursive_json_compare(db_block_definition, local_block_definition)
|
||||
assert not recursive_json_compare(
|
||||
db_block_definition, {**local_block_definition, "d": 4}
|
||||
)
|
||||
assert not recursive_json_compare(
|
||||
db_block_definition, {**local_block_definition, "a": 2}
|
||||
)
|
||||
assert not recursive_json_compare(
|
||||
db_block_definition, {**local_block_definition, "b": 3}
|
||||
)
|
||||
assert not recursive_json_compare(
|
||||
db_block_definition, {**local_block_definition, "c": 4}
|
||||
)
|
||||
assert not recursive_json_compare(
|
||||
db_block_definition, {**local_block_definition, "a": 1, "b": 2, "c": 3, "d": 4}
|
||||
)
|
||||
assert recursive_json_compare({}, {})
|
||||
assert recursive_json_compare({"a": 1}, {"a": 1})
|
||||
assert not recursive_json_compare({"a": 1}, {"b": 1})
|
||||
assert not recursive_json_compare({"a": 1}, {"a": 2})
|
||||
assert not recursive_json_compare({"a": 1}, {"a": [1, 2]})
|
||||
assert not recursive_json_compare({"a": 1}, {"a": {"b": 1}})
|
||||
assert not recursive_json_compare({"a": 1}, {"a": {"b": 2}})
|
||||
assert not recursive_json_compare({"a": 1}, {"a": {"b": [1, 2]}})
|
||||
assert not recursive_json_compare({"a": 1}, {"a": {"b": {"c": 1}}})
|
||||
assert not recursive_json_compare({"a": 1}, {"a": {"b": {"c": 2}}})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_block_same():
|
||||
local_block_instance = PrintToConsoleBlock()
|
||||
db_block = BlocksRegistry(
|
||||
id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
|
||||
name=local_block_instance.__class__.__name__,
|
||||
definition=json.dumps(local_block_instance.to_dict()), # type: ignore To much type magic going on here
|
||||
updatedAt=datetime.now(),
|
||||
)
|
||||
assert check_block_same(db_block, PrintToConsoleBlock) # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_block_not_same():
|
||||
local_block_instance = PrintToConsoleBlock()
|
||||
local_block_data = local_block_instance.to_dict()
|
||||
local_block_data["description"] = "Hello, World!"
|
||||
|
||||
db_block = BlocksRegistry(
|
||||
id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
|
||||
name=local_block_instance.__class__.__name__,
|
||||
definition=json.dumps(local_block_data), # type: ignore To much type magic going on here
|
||||
updatedAt=datetime.now(),
|
||||
)
|
||||
assert not check_block_same(db_block, PrintToConsoleBlock) # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_delta_blocks():
|
||||
now = datetime.now()
|
||||
store_value_block = StoreValueBlock()
|
||||
local_blocks = {
|
||||
PrintToConsoleBlock().id: PrintToConsoleBlock,
|
||||
ReverseListOrderBlock().id: ReverseListOrderBlock,
|
||||
FileStoreBlock().id: FileStoreBlock,
|
||||
store_value_block.id: StoreValueBlock,
|
||||
}
|
||||
db_blocks = {
|
||||
PrintToConsoleBlock().id: BlocksRegistry(
|
||||
id=PrintToConsoleBlock().id,
|
||||
name=PrintToConsoleBlock().__class__.__name__,
|
||||
definition=json.dumps(PrintToConsoleBlock().to_dict()), # type: ignore To much type magic going on here
|
||||
updatedAt=now,
|
||||
),
|
||||
ReverseListOrderBlock().id: BlocksRegistry(
|
||||
id=ReverseListOrderBlock().id,
|
||||
name=ReverseListOrderBlock().__class__.__name__,
|
||||
definition=json.dumps(ReverseListOrderBlock().to_dict()), # type: ignore To much type magic going on here
|
||||
updatedAt=now,
|
||||
),
|
||||
FileStoreBlock().id: BlocksRegistry(
|
||||
id=FileStoreBlock().id,
|
||||
name=FileStoreBlock().__class__.__name__,
|
||||
definition=json.dumps(FileStoreBlock().to_dict()), # type: ignore To much type magic going on here
|
||||
updatedAt=now,
|
||||
),
|
||||
}
|
||||
delta_blocks = find_delta_blocks(db_blocks, local_blocks)
|
||||
assert len(delta_blocks) == 1
|
||||
assert store_value_block.id in delta_blocks.keys()
|
||||
assert delta_blocks[store_value_block.id] == StoreValueBlock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_delta_blocks_block_updated():
|
||||
now = datetime.now()
|
||||
store_value_block = StoreValueBlock()
|
||||
print_to_console_block_definition = PrintToConsoleBlock().to_dict()
|
||||
print_to_console_block_definition["description"] = "Hello, World!"
|
||||
local_blocks = {
|
||||
PrintToConsoleBlock().id: PrintToConsoleBlock,
|
||||
ReverseListOrderBlock().id: ReverseListOrderBlock,
|
||||
FileStoreBlock().id: FileStoreBlock,
|
||||
store_value_block.id: StoreValueBlock,
|
||||
}
|
||||
db_blocks = {
|
||||
PrintToConsoleBlock().id: BlocksRegistry(
|
||||
id=PrintToConsoleBlock().id,
|
||||
name=PrintToConsoleBlock().__class__.__name__,
|
||||
definition=json.dumps(print_to_console_block_definition), # type: ignore To much type magic going on here
|
||||
updatedAt=now,
|
||||
),
|
||||
ReverseListOrderBlock().id: BlocksRegistry(
|
||||
id=ReverseListOrderBlock().id,
|
||||
name=ReverseListOrderBlock().__class__.__name__,
|
||||
definition=json.dumps(ReverseListOrderBlock().to_dict()), # type: ignore To much type magic going on here
|
||||
updatedAt=now,
|
||||
),
|
||||
FileStoreBlock().id: BlocksRegistry(
|
||||
id=FileStoreBlock().id,
|
||||
name=FileStoreBlock().__class__.__name__,
|
||||
definition=json.dumps(FileStoreBlock().to_dict()), # type: ignore To much type magic going on here
|
||||
updatedAt=now,
|
||||
),
|
||||
}
|
||||
delta_blocks = find_delta_blocks(db_blocks, local_blocks)
|
||||
assert len(delta_blocks) == 2
|
||||
assert store_value_block.id in delta_blocks.keys()
|
||||
assert delta_blocks[store_value_block.id] == StoreValueBlock
|
||||
assert PrintToConsoleBlock().id in delta_blocks.keys()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_delta_block_no_diff():
|
||||
now = datetime.now()
|
||||
local_blocks = {
|
||||
PrintToConsoleBlock().id: PrintToConsoleBlock,
|
||||
ReverseListOrderBlock().id: ReverseListOrderBlock,
|
||||
FileStoreBlock().id: FileStoreBlock,
|
||||
}
|
||||
db_blocks = {
|
||||
PrintToConsoleBlock().id: BlocksRegistry(
|
||||
id=PrintToConsoleBlock().id,
|
||||
name=PrintToConsoleBlock().__class__.__name__,
|
||||
definition=json.dumps(PrintToConsoleBlock().to_dict()), # type: ignore To much type magic going on here
|
||||
updatedAt=now,
|
||||
),
|
||||
ReverseListOrderBlock().id: BlocksRegistry(
|
||||
id=ReverseListOrderBlock().id,
|
||||
name=ReverseListOrderBlock().__class__.__name__,
|
||||
definition=json.dumps(ReverseListOrderBlock().to_dict()), # type: ignore To much type magic going on here
|
||||
updatedAt=now,
|
||||
),
|
||||
FileStoreBlock().id: BlocksRegistry(
|
||||
id=FileStoreBlock().id,
|
||||
name=FileStoreBlock().__class__.__name__,
|
||||
definition=json.dumps(FileStoreBlock().to_dict()), # type: ignore To much type magic going on here
|
||||
updatedAt=now,
|
||||
),
|
||||
}
|
||||
delta_blocks = find_delta_blocks(db_blocks, local_blocks)
|
||||
assert len(delta_blocks) == 0
|
||||
@@ -346,8 +346,6 @@ class APIKeyCredentials(_BaseCredentials):
|
||||
)
|
||||
"""Unix timestamp (seconds) indicating when the API key expires (if at all)"""
|
||||
|
||||
api_key_env_var: Optional[str] = Field(default=None, exclude=True)
|
||||
|
||||
def auth_header(self) -> str:
|
||||
# Linear API keys should not have Bearer prefix
|
||||
if self.provider == "linear":
|
||||
@@ -526,13 +524,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
if hasattr(model_class, "allowed_providers") and hasattr(
|
||||
model_class, "allowed_cred_types"
|
||||
):
|
||||
allowed_providers = sorted(model_class.allowed_providers())
|
||||
allowed_providers = model_class.allowed_providers()
|
||||
# If no specific providers (None), allow any string
|
||||
if allowed_providers is None:
|
||||
schema["credentials_provider"] = ["string"] # Allow any string provider
|
||||
else:
|
||||
schema["credentials_provider"] = allowed_providers
|
||||
schema["credentials_types"] = sorted(model_class.allowed_cred_types())
|
||||
schema["credentials_types"] = model_class.allowed_cred_types()
|
||||
# Do not return anything, just mutate schema in place
|
||||
|
||||
model_config = ConfigDict(
|
||||
|
||||
@@ -26,7 +26,6 @@ ollama_credentials = APIKeyCredentials(
|
||||
id="744fdc56-071a-4761-b5a5-0af0ce10a2b5",
|
||||
provider="ollama",
|
||||
api_key=SecretStr("FAKE_API_KEY"),
|
||||
api_key_env_var=None,
|
||||
title="Use Credits for Ollama",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -35,7 +34,6 @@ revid_credentials = APIKeyCredentials(
|
||||
id="fdb7f412-f519-48d1-9b5f-d2f73d0e01fe",
|
||||
provider="revid",
|
||||
api_key=SecretStr(settings.secrets.revid_api_key),
|
||||
api_key_env_var="REVID_API_KEY",
|
||||
title="Use Credits for Revid",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -43,7 +41,6 @@ ideogram_credentials = APIKeyCredentials(
|
||||
id="760f84fc-b270-42de-91f6-08efe1b512d0",
|
||||
provider="ideogram",
|
||||
api_key=SecretStr(settings.secrets.ideogram_api_key),
|
||||
api_key_env_var="IDEOGRAM_API_KEY",
|
||||
title="Use Credits for Ideogram",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -51,7 +48,6 @@ replicate_credentials = APIKeyCredentials(
|
||||
id="6b9fc200-4726-4973-86c9-cd526f5ce5db",
|
||||
provider="replicate",
|
||||
api_key=SecretStr(settings.secrets.replicate_api_key),
|
||||
api_key_env_var="REPLICATE_API_KEY",
|
||||
title="Use Credits for Replicate",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -59,7 +55,6 @@ openai_credentials = APIKeyCredentials(
|
||||
id="53c25cb8-e3ee-465c-a4d1-e75a4c899c2a",
|
||||
provider="openai",
|
||||
api_key=SecretStr(settings.secrets.openai_api_key),
|
||||
api_key_env_var="OPENAI_API_KEY",
|
||||
title="Use Credits for OpenAI",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -67,7 +62,6 @@ aiml_api_credentials = APIKeyCredentials(
|
||||
id="aad82a89-9794-4ebb-977f-d736aa5260a3",
|
||||
provider="aiml_api",
|
||||
api_key=SecretStr(settings.secrets.aiml_api_key),
|
||||
api_key_env_var="AIML_API_KEY",
|
||||
title="Use Credits for AI/ML API",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -75,7 +69,6 @@ anthropic_credentials = APIKeyCredentials(
|
||||
id="24e5d942-d9e3-4798-8151-90143ee55629",
|
||||
provider="anthropic",
|
||||
api_key=SecretStr(settings.secrets.anthropic_api_key),
|
||||
api_key_env_var="ANTHROPIC_API_KEY",
|
||||
title="Use Credits for Anthropic",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -83,7 +76,6 @@ groq_credentials = APIKeyCredentials(
|
||||
id="4ec22295-8f97-4dd1-b42b-2c6957a02545",
|
||||
provider="groq",
|
||||
api_key=SecretStr(settings.secrets.groq_api_key),
|
||||
api_key_env_var="GROQ_API_KEY",
|
||||
title="Use Credits for Groq",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -91,7 +83,6 @@ did_credentials = APIKeyCredentials(
|
||||
id="7f7b0654-c36b-4565-8fa7-9a52575dfae2",
|
||||
provider="d_id",
|
||||
api_key=SecretStr(settings.secrets.did_api_key),
|
||||
api_key_env_var="DID_API_KEY",
|
||||
title="Use Credits for D-ID",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -99,7 +90,6 @@ jina_credentials = APIKeyCredentials(
|
||||
id="7f26de70-ba0d-494e-ba76-238e65e7b45f",
|
||||
provider="jina",
|
||||
api_key=SecretStr(settings.secrets.jina_api_key),
|
||||
api_key_env_var="JINA_API_KEY",
|
||||
title="Use Credits for Jina",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -107,7 +97,6 @@ unreal_credentials = APIKeyCredentials(
|
||||
id="66f20754-1b81-48e4-91d0-f4f0dd82145f",
|
||||
provider="unreal",
|
||||
api_key=SecretStr(settings.secrets.unreal_speech_api_key),
|
||||
api_key_env_var="UNREAL_SPEECH_API_KEY",
|
||||
title="Use Credits for Unreal",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -115,7 +104,6 @@ open_router_credentials = APIKeyCredentials(
|
||||
id="b5a0e27d-0c98-4df3-a4b9-10193e1f3c40",
|
||||
provider="open_router",
|
||||
api_key=SecretStr(settings.secrets.open_router_api_key),
|
||||
api_key_env_var="OPEN_ROUTER_API_KEY",
|
||||
title="Use Credits for Open Router",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -123,7 +111,6 @@ fal_credentials = APIKeyCredentials(
|
||||
id="6c0f5bd0-9008-4638-9d79-4b40b631803e",
|
||||
provider="fal",
|
||||
api_key=SecretStr(settings.secrets.fal_api_key),
|
||||
api_key_env_var="FAL_API_KEY",
|
||||
title="Use Credits for FAL",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -131,7 +118,6 @@ exa_credentials = APIKeyCredentials(
|
||||
id="96153e04-9c6c-4486-895f-5bb683b1ecec",
|
||||
provider="exa",
|
||||
api_key=SecretStr(settings.secrets.exa_api_key),
|
||||
api_key_env_var="EXA_API_KEY",
|
||||
title="Use Credits for Exa search",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -139,7 +125,6 @@ e2b_credentials = APIKeyCredentials(
|
||||
id="78d19fd7-4d59-4a16-8277-3ce310acf2b7",
|
||||
provider="e2b",
|
||||
api_key=SecretStr(settings.secrets.e2b_api_key),
|
||||
api_key_env_var="E2B_API_KEY",
|
||||
title="Use Credits for E2B",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -147,7 +132,6 @@ nvidia_credentials = APIKeyCredentials(
|
||||
id="96b83908-2789-4dec-9968-18f0ece4ceb3",
|
||||
provider="nvidia",
|
||||
api_key=SecretStr(settings.secrets.nvidia_api_key),
|
||||
api_key_env_var="NVIDIA_API_KEY",
|
||||
title="Use Credits for Nvidia",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -155,7 +139,6 @@ screenshotone_credentials = APIKeyCredentials(
|
||||
id="3b1bdd16-8818-4bc2-8cbb-b23f9a3439ed",
|
||||
provider="screenshotone",
|
||||
api_key=SecretStr(settings.secrets.screenshotone_api_key),
|
||||
api_key_env_var="SCREENSHOTONE_API_KEY",
|
||||
title="Use Credits for ScreenshotOne",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -163,7 +146,6 @@ mem0_credentials = APIKeyCredentials(
|
||||
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
||||
provider="mem0",
|
||||
api_key=SecretStr(settings.secrets.mem0_api_key),
|
||||
api_key_env_var="MEM0_API_KEY",
|
||||
title="Use Credits for Mem0",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -172,7 +154,6 @@ apollo_credentials = APIKeyCredentials(
|
||||
id="544c62b5-1d0f-4156-8fb4-9525f11656eb",
|
||||
provider="apollo",
|
||||
api_key=SecretStr(settings.secrets.apollo_api_key),
|
||||
api_key_env_var="APOLLO_API_KEY",
|
||||
title="Use Credits for Apollo",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -181,7 +162,6 @@ smartlead_credentials = APIKeyCredentials(
|
||||
id="3bcdbda3-84a3-46af-8fdb-bfd2472298b8",
|
||||
provider="smartlead",
|
||||
api_key=SecretStr(settings.secrets.smartlead_api_key),
|
||||
api_key_env_var="SMARTLEAD_API_KEY",
|
||||
title="Use Credits for SmartLead",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -190,7 +170,6 @@ google_maps_credentials = APIKeyCredentials(
|
||||
id="9aa1bde0-4947-4a70-a20c-84daa3850d52",
|
||||
provider="google_maps",
|
||||
api_key=SecretStr(settings.secrets.google_maps_api_key),
|
||||
api_key_env_var="GOOGLE_MAPS_API_KEY",
|
||||
title="Use Credits for Google Maps",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -199,7 +178,6 @@ zerobounce_credentials = APIKeyCredentials(
|
||||
id="63a6e279-2dc2-448e-bf57-85776f7176dc",
|
||||
provider="zerobounce",
|
||||
api_key=SecretStr(settings.secrets.zerobounce_api_key),
|
||||
api_key_env_var="ZEROBOUNCE_API_KEY",
|
||||
title="Use Credits for ZeroBounce",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -208,7 +186,6 @@ enrichlayer_credentials = APIKeyCredentials(
|
||||
id="d9fce73a-6c1d-4e8b-ba2e-12a456789def",
|
||||
provider="enrichlayer",
|
||||
api_key=SecretStr(settings.secrets.enrichlayer_api_key),
|
||||
api_key_env_var="ENRICHLAYER_API_KEY",
|
||||
title="Use Credits for Enrichlayer",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -218,7 +195,6 @@ llama_api_credentials = APIKeyCredentials(
|
||||
id="d44045af-1c33-4833-9e19-752313214de2",
|
||||
provider="llama_api",
|
||||
api_key=SecretStr(settings.secrets.llama_api_key),
|
||||
api_key_env_var="LLAMA_API_KEY",
|
||||
title="Use Credits for Llama API",
|
||||
expires_at=None,
|
||||
)
|
||||
@@ -227,7 +203,6 @@ v0_credentials = APIKeyCredentials(
|
||||
id="c4e6d1a0-3b5f-4789-a8e2-9b123456789f",
|
||||
provider="v0",
|
||||
api_key=SecretStr(settings.secrets.v0_api_key),
|
||||
api_key_env_var="V0_API_KEY",
|
||||
title="Use Credits for v0 by Vercel",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
@@ -17,8 +17,9 @@ from backend.data.model import (
|
||||
)
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
from backend.sdk.provider import OAuthConfig, Provider, ProviderRegister
|
||||
from backend.sdk.provider import OAuthConfig, Provider
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,7 +40,6 @@ class ProviderBuilder:
|
||||
self._client_id_env_var: Optional[str] = None
|
||||
self._client_secret_env_var: Optional[str] = None
|
||||
self._extra_config: dict = {}
|
||||
self._register: ProviderRegister = ProviderRegister(name=name)
|
||||
|
||||
def with_oauth(
|
||||
self,
|
||||
@@ -48,11 +48,6 @@ class ProviderBuilder:
|
||||
client_id_env_var: Optional[str] = None,
|
||||
client_secret_env_var: Optional[str] = None,
|
||||
) -> "ProviderBuilder":
|
||||
|
||||
self._register.with_oauth = True
|
||||
self._register.client_id_env_var = client_id_env_var
|
||||
self._register.client_secret_env_var = client_secret_env_var
|
||||
|
||||
"""Add OAuth support."""
|
||||
if not client_id_env_var or not client_secret_env_var:
|
||||
client_id_env_var = f"{self.name}_client_id".upper()
|
||||
@@ -78,8 +73,6 @@ class ProviderBuilder:
|
||||
|
||||
def with_api_key(self, env_var_name: str, title: str) -> "ProviderBuilder":
|
||||
"""Add API key support with environment variable name."""
|
||||
self._register.with_api_key = True
|
||||
self._register.api_key_env_var = env_var_name
|
||||
self._supported_auth_types.add("api_key")
|
||||
|
||||
# Register the API key mapping
|
||||
@@ -98,14 +91,30 @@ class ProviderBuilder:
|
||||
)
|
||||
return self
|
||||
|
||||
def with_api_key_from_settings(
|
||||
self, settings_attr: str, title: str
|
||||
) -> "ProviderBuilder":
|
||||
"""Use existing API key from settings."""
|
||||
self._supported_auth_types.add("api_key")
|
||||
|
||||
# Try to get the API key from settings
|
||||
settings = Settings()
|
||||
api_key = getattr(settings.secrets, settings_attr, None)
|
||||
if api_key:
|
||||
self._default_credentials.append(
|
||||
APIKeyCredentials(
|
||||
id=f"{self.name}-default",
|
||||
provider=self.name,
|
||||
api_key=api_key,
|
||||
title=title,
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def with_user_password(
|
||||
self, username_env_var: str, password_env_var: str, title: str
|
||||
) -> "ProviderBuilder":
|
||||
"""Add username/password support with environment variable names."""
|
||||
self._register.with_user_password = True
|
||||
self._register.username_env_var = username_env_var
|
||||
self._register.password_env_var = password_env_var
|
||||
|
||||
self._supported_auth_types.add("user_password")
|
||||
|
||||
# Check if credentials exist in environment
|
||||
@@ -165,7 +174,6 @@ class ProviderBuilder:
|
||||
supported_auth_types=self._supported_auth_types,
|
||||
api_client_factory=self._api_client_factory,
|
||||
error_handler=self._error_handler,
|
||||
register=self._register,
|
||||
**self._extra_config,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
from typing import Dict
|
||||
|
||||
from prisma import Prisma
|
||||
from prisma.models import ProviderRegistry as PrismaProviderRegistry
|
||||
|
||||
from backend.sdk.provider import ProviderRegister
|
||||
|
||||
|
||||
def is_providers_different(
|
||||
current_provider: PrismaProviderRegistry, new_provider: ProviderRegister
|
||||
) -> bool:
|
||||
"""
|
||||
Compare a current provider (as stored in the database) against a new provider registration
|
||||
and determine if they are different. This is done by converting the database model to a
|
||||
ProviderRegister and checking for equality (all fields compared).
|
||||
|
||||
Args:
|
||||
current_provider (PrismaProviderRegistry): The provider as stored in the database.
|
||||
new_provider (ProviderRegister): The provider specification to compare.
|
||||
|
||||
Returns:
|
||||
bool: True if the providers differ, False if they are effectively the same.
|
||||
"""
|
||||
current_provider_register = ProviderRegister(
|
||||
name=current_provider.name,
|
||||
with_oauth=current_provider.with_oauth,
|
||||
client_id_env_var=current_provider.client_id_env_var,
|
||||
client_secret_env_var=current_provider.client_secret_env_var,
|
||||
with_api_key=current_provider.with_api_key,
|
||||
api_key_env_var=current_provider.api_key_env_var,
|
||||
with_user_password=current_provider.with_user_password,
|
||||
username_env_var=current_provider.username_env_var,
|
||||
password_env_var=current_provider.password_env_var,
|
||||
)
|
||||
if current_provider_register == new_provider:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def find_delta_providers(
|
||||
current_providers: Dict[str, PrismaProviderRegistry],
|
||||
providers: Dict[str, ProviderRegister],
|
||||
) -> Dict[str, ProviderRegister]:
|
||||
"""
|
||||
Identify providers that are either new or updated compared to the current providers list.
|
||||
|
||||
Args:
|
||||
current_providers (Dict[str, PrismaProviderRegistry]): Dictionary of current provider models keyed by provider name.
|
||||
providers (Dict[str, ProviderRegister]): Dictionary of new provider registrations keyed by provider name.
|
||||
|
||||
Returns:
|
||||
Dict[str, ProviderRegister]: Providers that need to be added/updated in the registry.
|
||||
- Includes providers not in current_providers.
|
||||
- Includes providers where the data differs from what's in current_providers.
|
||||
"""
|
||||
provider_update = {}
|
||||
for name, provider in providers.items():
|
||||
if name not in current_providers:
|
||||
provider_update[name] = provider
|
||||
else:
|
||||
if is_providers_different(current_providers[name], provider):
|
||||
provider_update[name] = provider
|
||||
|
||||
return provider_update
|
||||
|
||||
|
||||
async def get_providers() -> Dict[str, PrismaProviderRegistry]:
|
||||
"""
|
||||
Retrieve all provider registries from the database.
|
||||
|
||||
Returns:
|
||||
Dict[str, PrismaProviderRegistry]: Dictionary of all current providers, keyed by provider name.
|
||||
"""
|
||||
async with Prisma() as prisma:
|
||||
providers = await prisma.providerregistry.find_many()
|
||||
return {
|
||||
provider.name: PrismaProviderRegistry(**provider.model_dump())
|
||||
for provider in providers
|
||||
}
|
||||
|
||||
|
||||
async def upsert_providers_change_bulk(providers: Dict[str, ProviderRegister]):
|
||||
"""
|
||||
Bulk upsert providers into the database after checking for changes.
|
||||
|
||||
Args:
|
||||
providers (Dict[str, ProviderRegister]): Dictionary of new provider registrations keyed by provider name.
|
||||
"""
|
||||
current_providers = await get_providers()
|
||||
provider_update = find_delta_providers(current_providers, providers)
|
||||
"""Async version of bulk upsert providers with all fields using transaction for atomicity"""
|
||||
async with Prisma() as prisma:
|
||||
async with prisma.tx() as tx:
|
||||
results = []
|
||||
for name, provider in provider_update.items():
|
||||
result = await tx.providerregistry.upsert(
|
||||
where={"name": name},
|
||||
data={
|
||||
"create": {
|
||||
"name": name,
|
||||
"with_oauth": provider.with_oauth,
|
||||
"client_id_env_var": provider.client_id_env_var,
|
||||
"client_secret_env_var": provider.client_secret_env_var,
|
||||
"with_api_key": provider.with_api_key,
|
||||
"api_key_env_var": provider.api_key_env_var,
|
||||
"with_user_password": provider.with_user_password,
|
||||
"username_env_var": provider.username_env_var,
|
||||
"password_env_var": provider.password_env_var,
|
||||
},
|
||||
"update": {
|
||||
"with_oauth": provider.with_oauth,
|
||||
"client_id_env_var": provider.client_id_env_var,
|
||||
"client_secret_env_var": provider.client_secret_env_var,
|
||||
"with_api_key": provider.with_api_key,
|
||||
"api_key_env_var": provider.api_key_env_var,
|
||||
"with_user_password": provider.with_user_password,
|
||||
"username_env_var": provider.username_env_var,
|
||||
"password_env_var": provider.password_env_var,
|
||||
},
|
||||
},
|
||||
)
|
||||
results.append(result)
|
||||
return results
|
||||
@@ -1,127 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from prisma.models import ProviderRegistry as PrismaProviderRegistry
|
||||
|
||||
from backend.sdk.db import find_delta_providers, is_providers_different
|
||||
from backend.sdk.provider import ProviderRegister
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_is_providers_different_same():
|
||||
current_provider = PrismaProviderRegistry(
|
||||
name="test_provider",
|
||||
with_oauth=True,
|
||||
client_id_env_var="TEST_CLIENT_ID",
|
||||
client_secret_env_var="TEST_CLIENT_SECRET",
|
||||
with_api_key=True,
|
||||
api_key_env_var="TEST_API_KEY",
|
||||
with_user_password=True,
|
||||
username_env_var="TEST_USERNAME",
|
||||
password_env_var="TEST_PASSWORD",
|
||||
updatedAt=datetime.now(),
|
||||
)
|
||||
new_provider = ProviderRegister(
|
||||
name="test_provider",
|
||||
with_oauth=True,
|
||||
client_id_env_var="TEST_CLIENT_ID",
|
||||
client_secret_env_var="TEST_CLIENT_SECRET",
|
||||
with_api_key=True,
|
||||
api_key_env_var="TEST_API_KEY",
|
||||
with_user_password=True,
|
||||
username_env_var="TEST_USERNAME",
|
||||
password_env_var="TEST_PASSWORD",
|
||||
)
|
||||
assert not is_providers_different(current_provider, new_provider)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_is_providers_different_different():
|
||||
current_provider = PrismaProviderRegistry(
|
||||
name="test_provider",
|
||||
with_oauth=True,
|
||||
client_id_env_var="TEST_CLIENT_ID",
|
||||
client_secret_env_var="TEST_CLIENT_SECRET",
|
||||
with_api_key=True,
|
||||
api_key_env_var="TEST_API_KEY",
|
||||
with_user_password=True,
|
||||
username_env_var="TEST_USERNAME",
|
||||
password_env_var="TEST_PASSWORD",
|
||||
updatedAt=datetime.now(),
|
||||
)
|
||||
new_provider = ProviderRegister(
|
||||
name="test_provider",
|
||||
with_oauth=False,
|
||||
with_api_key=True,
|
||||
api_key_env_var="TEST_API_KEY",
|
||||
with_user_password=True,
|
||||
username_env_var="TEST_USERNAME",
|
||||
password_env_var="TEST_PASSWORD",
|
||||
)
|
||||
assert is_providers_different(current_provider, new_provider)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_find_delta_providers():
|
||||
current_providers = {
|
||||
"test_provider": PrismaProviderRegistry(
|
||||
name="test_provider",
|
||||
with_oauth=True,
|
||||
client_id_env_var="TEST_CLIENT_ID",
|
||||
client_secret_env_var="TEST_CLIENT_SECRET",
|
||||
with_api_key=True,
|
||||
api_key_env_var="TEST_API_KEY",
|
||||
with_user_password=True,
|
||||
username_env_var="TEST_USERNAME",
|
||||
password_env_var="TEST_PASSWORD",
|
||||
updatedAt=datetime.now(),
|
||||
),
|
||||
"test_provider_2": PrismaProviderRegistry(
|
||||
name="test_provider_2",
|
||||
with_oauth=True,
|
||||
client_id_env_var="TEST_CLIENT_ID_2",
|
||||
client_secret_env_var="TEST_CLIENT_SECRET_2",
|
||||
with_api_key=True,
|
||||
api_key_env_var="TEST_API_KEY_2",
|
||||
with_user_password=True,
|
||||
username_env_var="TEST_USERNAME_2",
|
||||
password_env_var="TEST_PASSWORD_2",
|
||||
updatedAt=datetime.now(),
|
||||
),
|
||||
}
|
||||
new_providers = {
|
||||
"test_provider": ProviderRegister(
|
||||
name="test_provider",
|
||||
with_oauth=True,
|
||||
client_id_env_var="TEST_CLIENT_ID",
|
||||
client_secret_env_var="TEST_CLIENT_SECRET",
|
||||
with_api_key=True,
|
||||
api_key_env_var="TEST_API_KEY",
|
||||
with_user_password=True,
|
||||
username_env_var="TEST_USERNAME",
|
||||
password_env_var="TEST_PASSWORD",
|
||||
),
|
||||
"test_provider_2": ProviderRegister(
|
||||
name="test_provider_2",
|
||||
with_oauth=False,
|
||||
with_api_key=True,
|
||||
api_key_env_var="TEST_API_KEY_2",
|
||||
with_user_password=True,
|
||||
username_env_var="TEST_USERNAME_2",
|
||||
password_env_var="TEST_PASSWORD_2",
|
||||
),
|
||||
"test_provider_3": ProviderRegister(
|
||||
name="test_provider_3",
|
||||
with_oauth=True,
|
||||
client_id_env_var="TEST_CLIENT_ID_3",
|
||||
client_secret_env_var="TEST_CLIENT_SECRET_3",
|
||||
with_api_key=False,
|
||||
with_user_password=True,
|
||||
username_env_var="TEST_USERNAME_3",
|
||||
password_env_var="TEST_PASSWORD_3",
|
||||
),
|
||||
}
|
||||
assert find_delta_providers(current_providers, new_providers) == {
|
||||
"test_provider_2": new_providers["test_provider_2"],
|
||||
"test_provider_3": new_providers["test_provider_3"],
|
||||
}
|
||||
@@ -30,23 +30,6 @@ class OAuthConfig(BaseModel):
|
||||
client_secret_env_var: str
|
||||
|
||||
|
||||
class ProviderRegister(BaseModel):
|
||||
"""Provider log configuration for SDK providers."""
|
||||
|
||||
name: str
|
||||
|
||||
with_oauth: bool = False
|
||||
client_id_env_var: Optional[str] = None
|
||||
client_secret_env_var: Optional[str] = None
|
||||
|
||||
with_api_key: bool = False
|
||||
api_key_env_var: Optional[str] = None
|
||||
|
||||
with_user_password: bool = False
|
||||
username_env_var: Optional[str] = None
|
||||
password_env_var: Optional[str] = None
|
||||
|
||||
|
||||
class Provider:
|
||||
"""A configured provider that blocks can use.
|
||||
|
||||
@@ -65,7 +48,6 @@ class Provider:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
register: ProviderRegister,
|
||||
oauth_config: Optional[OAuthConfig] = None,
|
||||
webhook_manager: Optional[Type[BaseWebhooksManager]] = None,
|
||||
default_credentials: Optional[List[Credentials]] = None,
|
||||
@@ -83,7 +65,7 @@ class Provider:
|
||||
self.supported_auth_types = supported_auth_types or set()
|
||||
self._api_client_factory = api_client_factory
|
||||
self._error_handler = error_handler
|
||||
self.register = register
|
||||
|
||||
# Store any additional configuration
|
||||
self._extra_config = kwargs
|
||||
self.test_credentials_uuid = uuid.uuid4()
|
||||
|
||||
@@ -13,11 +13,9 @@ from backend.data.model import Credentials
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
from backend.sdk.db import upsert_providers_change_bulk
|
||||
from backend.sdk.provider import ProviderRegister
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.sdk.provider import Provider, ProviderRegister
|
||||
from backend.sdk.provider import Provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -59,7 +57,6 @@ class AutoRegistry:
|
||||
_webhook_managers: Dict[str, Type[BaseWebhooksManager]] = {}
|
||||
_block_configurations: Dict[Type[Block], BlockConfiguration] = {}
|
||||
_api_key_mappings: Dict[str, str] = {} # provider -> env_var_name
|
||||
_provider_registry: Dict[str, ProviderRegister] = {}
|
||||
|
||||
@classmethod
|
||||
def register_provider(cls, provider: "Provider") -> None:
|
||||
@@ -67,7 +64,6 @@ class AutoRegistry:
|
||||
with cls._lock:
|
||||
cls._providers[provider.name] = provider
|
||||
|
||||
cls._provider_registry[provider.name] = provider.register
|
||||
# Register OAuth handler if provided
|
||||
if provider.oauth_config:
|
||||
# Dynamically set PROVIDER_NAME if not already set
|
||||
@@ -167,7 +163,7 @@ class AutoRegistry:
|
||||
cls._api_key_mappings.clear()
|
||||
|
||||
@classmethod
|
||||
async def patch_integrations(cls) -> None:
|
||||
def patch_integrations(cls) -> None:
|
||||
"""Patch existing integration points to use AutoRegistry."""
|
||||
# OAuth handlers are handled by SDKAwareHandlersDict in oauth/__init__.py
|
||||
# No patching needed for OAuth handlers
|
||||
@@ -217,73 +213,6 @@ class AutoRegistry:
|
||||
|
||||
creds_store: Any = backend.integrations.credentials_store
|
||||
|
||||
if "backend.integrations.providers" in sys.modules:
|
||||
providers: Any = sys.modules["backend.integrations.providers"]
|
||||
else:
|
||||
import backend.integrations.providers
|
||||
|
||||
providers: Any = backend.integrations.providers
|
||||
|
||||
legacy_oauth_providers = {
|
||||
providers.ProviderName.DISCORD.value: ProviderRegister(
|
||||
name=providers.ProviderName.DISCORD.value,
|
||||
with_oauth=True,
|
||||
client_id_env_var="DISCORD_CLIENT_ID",
|
||||
client_secret_env_var="DISCORD_CLIENT_SECRET",
|
||||
),
|
||||
providers.ProviderName.GITHUB.value: ProviderRegister(
|
||||
name=providers.ProviderName.GITHUB.value,
|
||||
with_oauth=True,
|
||||
client_id_env_var="GITHUB_CLIENT_ID",
|
||||
client_secret_env_var="GITHUB_CLIENT_SECRET",
|
||||
),
|
||||
providers.ProviderName.GOOGLE.value: ProviderRegister(
|
||||
name=providers.ProviderName.GOOGLE.value,
|
||||
with_oauth=True,
|
||||
client_id_env_var="GOOGLE_CLIENT_ID",
|
||||
client_secret_env_var="GOOGLE_CLIENT_SECRET",
|
||||
),
|
||||
providers.ProviderName.NOTION.value: ProviderRegister(
|
||||
name=providers.ProviderName.NOTION.value,
|
||||
with_oauth=True,
|
||||
client_id_env_var="NOTION_CLIENT_ID",
|
||||
client_secret_env_var="NOTION_CLIENT_SECRET",
|
||||
),
|
||||
providers.ProviderName.TWITTER.value: ProviderRegister(
|
||||
name=providers.ProviderName.TWITTER.value,
|
||||
with_oauth=True,
|
||||
client_id_env_var="TWITTER_CLIENT_ID",
|
||||
client_secret_env_var="TWITTER_CLIENT_SECRET",
|
||||
),
|
||||
providers.ProviderName.TODOIST.value: ProviderRegister(
|
||||
name=providers.ProviderName.TODOIST.value,
|
||||
with_oauth=True,
|
||||
client_id_env_var="TODOIST_CLIENT_ID",
|
||||
client_secret_env_var="TODOIST_CLIENT_SECRET",
|
||||
),
|
||||
}
|
||||
|
||||
if hasattr(creds_store, "DEFAULT_CREDENTIALS"):
|
||||
DEFAULT_CREDENTIALS = creds_store.DEFAULT_CREDENTIALS
|
||||
for item in DEFAULT_CREDENTIALS:
|
||||
new_cred = ProviderRegister(
|
||||
name=item.provider,
|
||||
with_api_key=True,
|
||||
api_key_env_var=item.api_key_env_var,
|
||||
)
|
||||
if item.provider in legacy_oauth_providers:
|
||||
new_cred.with_oauth = True
|
||||
new_cred.client_id_env_var = legacy_oauth_providers[
|
||||
item.provider
|
||||
].client_id_env_var
|
||||
new_cred.client_secret_env_var = legacy_oauth_providers[
|
||||
item.provider
|
||||
].client_secret_env_var
|
||||
|
||||
cls._provider_registry[item.provider] = new_cred
|
||||
|
||||
await upsert_providers_change_bulk(providers=cls._provider_registry)
|
||||
|
||||
if hasattr(creds_store, "IntegrationCredentialsStore"):
|
||||
store_class = creds_store.IntegrationCredentialsStore
|
||||
if hasattr(store_class, "get_all_creds"):
|
||||
@@ -308,6 +237,5 @@ class AutoRegistry:
|
||||
logger.info(
|
||||
"Successfully patched IntegrationCredentialsStore.get_all_creds"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to patch credentials store: {e}")
|
||||
|
||||
@@ -16,7 +16,6 @@ from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.routing import APIRoute
|
||||
from prisma.errors import PrismaError
|
||||
|
||||
import backend.blocks
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
@@ -96,13 +95,10 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
# Ensure SDK auto-registration is patched before initializing blocks
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
|
||||
await AutoRegistry.patch_integrations()
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
blocks = backend.blocks.load_all_blocks()
|
||||
|
||||
await backend.data.block.upsert_blocks_change_bulk(blocks)
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
-- CreateTable
|
||||
CREATE TABLE "ProviderRegistry" (
|
||||
"name" TEXT NOT NULL,
|
||||
"with_oauth" BOOLEAN NOT NULL DEFAULT false,
|
||||
"client_id_env_var" TEXT,
|
||||
"client_secret_env_var" TEXT,
|
||||
"with_api_key" BOOLEAN NOT NULL DEFAULT false,
|
||||
"api_key_env_var" TEXT,
|
||||
"with_user_password" BOOLEAN NOT NULL DEFAULT false,
|
||||
"username_env_var" TEXT,
|
||||
"password_env_var" TEXT,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
|
||||
CONSTRAINT "ProviderRegistry_pkey" PRIMARY KEY ("name")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "BlocksRegistry" (
|
||||
"id" TEXT NOT NULL,
|
||||
"name" TEXT NOT NULL,
|
||||
"definition" JSONB NOT NULL,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
|
||||
CONSTRAINT "BlocksRegistry_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ProviderRegistry_updatedAt_idx" ON "ProviderRegistry"("updatedAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "BlocksRegistry_updatedAt_idx" ON "BlocksRegistry"("updatedAt");
|
||||
@@ -61,34 +61,6 @@ model User {
|
||||
NotificationBatches UserNotificationBatch[]
|
||||
}
|
||||
|
||||
// This model describes the providers that are available to the user.
|
||||
model ProviderRegistry {
|
||||
name String @id
|
||||
|
||||
with_oauth Boolean @default(false)
|
||||
client_id_env_var String?
|
||||
client_secret_env_var String?
|
||||
|
||||
with_api_key Boolean @default(false)
|
||||
api_key_env_var String?
|
||||
|
||||
with_user_password Boolean @default(false)
|
||||
username_env_var String?
|
||||
password_env_var String?
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
@@index([updatedAt])
|
||||
}
|
||||
|
||||
model BlocksRegistry {
|
||||
id String @id @default(uuid())
|
||||
name String
|
||||
definition Json
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
@@index([updatedAt])
|
||||
}
|
||||
|
||||
enum OnboardingStep {
|
||||
// Introductory onboarding (Library)
|
||||
WELCOME
|
||||
|
||||
@@ -52,8 +52,7 @@ class TestWebhookPatching:
|
||||
"""Clear registry."""
|
||||
AutoRegistry.clear()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_manager_patching(self):
|
||||
def test_webhook_manager_patching(self):
|
||||
"""Test that webhook managers are correctly patched."""
|
||||
|
||||
# Mock the original load_webhook_managers function
|
||||
@@ -76,7 +75,7 @@ class TestWebhookPatching:
|
||||
with patch.dict(
|
||||
"sys.modules", {"backend.integrations.webhooks": mock_webhooks_module}
|
||||
):
|
||||
await AutoRegistry.patch_integrations()
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Call the patched function
|
||||
result = mock_webhooks_module.load_webhook_managers()
|
||||
@@ -88,8 +87,7 @@ class TestWebhookPatching:
|
||||
assert "webhook_provider" in result
|
||||
assert result["webhook_provider"] == MockWebhookManager
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_patching_no_original_function(self):
|
||||
def test_webhook_patching_no_original_function(self):
|
||||
"""Test webhook patching when load_webhook_managers doesn't exist."""
|
||||
# Mock webhooks module without load_webhook_managers
|
||||
mock_webhooks_module = MagicMock(spec=[])
|
||||
@@ -105,7 +103,7 @@ class TestWebhookPatching:
|
||||
"sys.modules", {"backend.integrations.webhooks": mock_webhooks_module}
|
||||
):
|
||||
# Should not raise an error
|
||||
await AutoRegistry.patch_integrations()
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Function should not be added if it didn't exist
|
||||
assert not hasattr(mock_webhooks_module, "load_webhook_managers")
|
||||
@@ -118,8 +116,7 @@ class TestPatchingIntegration:
|
||||
"""Clear registry."""
|
||||
AutoRegistry.clear()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_provider_registration_and_patching(self):
|
||||
def test_complete_provider_registration_and_patching(self):
|
||||
"""Test the complete flow from provider registration to patching."""
|
||||
# Mock webhooks module
|
||||
mock_webhooks = MagicMock()
|
||||
@@ -141,7 +138,7 @@ class TestPatchingIntegration:
|
||||
"backend.integrations.webhooks": mock_webhooks,
|
||||
},
|
||||
):
|
||||
await AutoRegistry.patch_integrations()
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Verify webhook patching
|
||||
webhook_result = mock_webhooks.load_webhook_managers()
|
||||
|
||||
@@ -25,7 +25,6 @@ from backend.sdk import (
|
||||
Provider,
|
||||
ProviderBuilder,
|
||||
)
|
||||
from backend.sdk.provider import ProviderRegister
|
||||
|
||||
|
||||
class TestAutoRegistry:
|
||||
@@ -40,7 +39,6 @@ class TestAutoRegistry:
|
||||
# Create a test provider
|
||||
provider = Provider(
|
||||
name="test_provider",
|
||||
register=ProviderRegister(name="test_provider"),
|
||||
oauth_handler=None,
|
||||
webhook_manager=None,
|
||||
default_credentials=[],
|
||||
@@ -80,7 +78,6 @@ class TestAutoRegistry:
|
||||
default_credentials=[],
|
||||
base_costs=[],
|
||||
supported_auth_types={"oauth2"},
|
||||
register=ProviderRegister(name="oauth_provider"),
|
||||
)
|
||||
|
||||
AutoRegistry.register_provider(provider)
|
||||
@@ -98,7 +95,6 @@ class TestAutoRegistry:
|
||||
|
||||
provider = Provider(
|
||||
name="webhook_provider",
|
||||
register=ProviderRegister(name="webhook_provider"),
|
||||
oauth_handler=None,
|
||||
webhook_manager=TestWebhookManager,
|
||||
default_credentials=[],
|
||||
@@ -132,7 +128,6 @@ class TestAutoRegistry:
|
||||
|
||||
provider = Provider(
|
||||
name="test_provider",
|
||||
register=ProviderRegister(name="test_provider"),
|
||||
oauth_handler=None,
|
||||
webhook_manager=None,
|
||||
default_credentials=[cred1, cred2],
|
||||
@@ -199,7 +194,6 @@ class TestAutoRegistry:
|
||||
):
|
||||
provider1 = Provider(
|
||||
name="provider1",
|
||||
register=ProviderRegister(name="provider1"),
|
||||
oauth_config=OAuthConfig(
|
||||
oauth_handler=TestOAuth1,
|
||||
client_id_env_var="TEST_CLIENT_ID",
|
||||
@@ -213,7 +207,6 @@ class TestAutoRegistry:
|
||||
|
||||
provider2 = Provider(
|
||||
name="provider2",
|
||||
register=ProviderRegister(name="provider2"),
|
||||
oauth_config=OAuthConfig(
|
||||
oauth_handler=TestOAuth2,
|
||||
client_id_env_var="TEST_CLIENT_ID",
|
||||
@@ -260,7 +253,6 @@ class TestAutoRegistry:
|
||||
# Add some registrations
|
||||
provider = Provider(
|
||||
name="test_provider",
|
||||
register=ProviderRegister(name="test_provider"),
|
||||
oauth_handler=None,
|
||||
webhook_manager=None,
|
||||
default_credentials=[],
|
||||
@@ -290,8 +282,7 @@ class TestAutoRegistryPatching:
|
||||
AutoRegistry.clear()
|
||||
|
||||
@patch("backend.integrations.webhooks.load_webhook_managers")
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_manager_patching(self, mock_load_managers):
|
||||
def test_webhook_manager_patching(self, mock_load_managers):
|
||||
"""Test that webhook managers are patched into the system."""
|
||||
# Set up the mock to return an empty dict
|
||||
mock_load_managers.return_value = {}
|
||||
@@ -303,7 +294,6 @@ class TestAutoRegistryPatching:
|
||||
# Register a provider with webhooks
|
||||
provider = Provider(
|
||||
name="webhook_provider",
|
||||
register=ProviderRegister(name="webhook_provider"),
|
||||
oauth_handler=None,
|
||||
webhook_manager=TestWebhookManager,
|
||||
default_credentials=[],
|
||||
@@ -320,8 +310,8 @@ class TestAutoRegistryPatching:
|
||||
with patch.dict(
|
||||
"sys.modules", {"backend.integrations.webhooks": mock_webhooks}
|
||||
):
|
||||
# Apply patches - now async
|
||||
await AutoRegistry.patch_integrations()
|
||||
# Apply patches
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Call the patched function
|
||||
result = mock_webhooks.load_webhook_managers()
|
||||
|
||||
@@ -7,7 +7,7 @@ import {
|
||||
usePostV1CreateCredentials,
|
||||
} from "@/app/api/__generated__/endpoints/integrations/integrations";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import type { PostV1CreateCredentialsBody } from "@/app/api/__generated__/models/postV1CreateCredentialsBody";
|
||||
import { APIKeyCredentials } from "@/app/api/__generated__/models/aPIKeyCredentials";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useState } from "react";
|
||||
|
||||
@@ -89,7 +89,7 @@ export function useAPIKeyCredentialsModal({
|
||||
api_key: values.apiKey,
|
||||
title: values.title,
|
||||
expires_at: expiresAt,
|
||||
} as PostV1CreateCredentialsBody,
|
||||
} as APIKeyCredentials,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -207,7 +207,7 @@
|
||||
"schema": {
|
||||
"oneOf": [
|
||||
{ "$ref": "#/components/schemas/OAuth2Credentials" },
|
||||
{ "$ref": "#/components/schemas/APIKeyCredentials-Input" },
|
||||
{ "$ref": "#/components/schemas/APIKeyCredentials" },
|
||||
{ "$ref": "#/components/schemas/UserPasswordCredentials" },
|
||||
{ "$ref": "#/components/schemas/HostScopedCredentials-Input" }
|
||||
],
|
||||
@@ -215,7 +215,7 @@
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"oauth2": "#/components/schemas/OAuth2Credentials",
|
||||
"api_key": "#/components/schemas/APIKeyCredentials-Input",
|
||||
"api_key": "#/components/schemas/APIKeyCredentials",
|
||||
"user_password": "#/components/schemas/UserPasswordCredentials",
|
||||
"host_scoped": "#/components/schemas/HostScopedCredentials-Input"
|
||||
}
|
||||
@@ -233,7 +233,7 @@
|
||||
"schema": {
|
||||
"oneOf": [
|
||||
{ "$ref": "#/components/schemas/OAuth2Credentials" },
|
||||
{ "$ref": "#/components/schemas/APIKeyCredentials-Output" },
|
||||
{ "$ref": "#/components/schemas/APIKeyCredentials" },
|
||||
{ "$ref": "#/components/schemas/UserPasswordCredentials" },
|
||||
{
|
||||
"$ref": "#/components/schemas/HostScopedCredentials-Output"
|
||||
@@ -243,7 +243,7 @@
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"oauth2": "#/components/schemas/OAuth2Credentials",
|
||||
"api_key": "#/components/schemas/APIKeyCredentials-Output",
|
||||
"api_key": "#/components/schemas/APIKeyCredentials",
|
||||
"user_password": "#/components/schemas/UserPasswordCredentials",
|
||||
"host_scoped": "#/components/schemas/HostScopedCredentials-Output"
|
||||
}
|
||||
@@ -302,7 +302,7 @@
|
||||
"schema": {
|
||||
"oneOf": [
|
||||
{ "$ref": "#/components/schemas/OAuth2Credentials" },
|
||||
{ "$ref": "#/components/schemas/APIKeyCredentials-Output" },
|
||||
{ "$ref": "#/components/schemas/APIKeyCredentials" },
|
||||
{ "$ref": "#/components/schemas/UserPasswordCredentials" },
|
||||
{
|
||||
"$ref": "#/components/schemas/HostScopedCredentials-Output"
|
||||
@@ -312,7 +312,7 @@
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"oauth2": "#/components/schemas/OAuth2Credentials",
|
||||
"api_key": "#/components/schemas/APIKeyCredentials-Output",
|
||||
"api_key": "#/components/schemas/APIKeyCredentials",
|
||||
"user_password": "#/components/schemas/UserPasswordCredentials",
|
||||
"host_scoped": "#/components/schemas/HostScopedCredentials-Output"
|
||||
}
|
||||
@@ -4782,41 +4782,7 @@
|
||||
},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"APIKeyCredentials-Input": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
"provider": { "type": "string", "title": "Provider" },
|
||||
"title": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Title"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "api_key",
|
||||
"title": "Type",
|
||||
"default": "api_key"
|
||||
},
|
||||
"api_key": {
|
||||
"type": "string",
|
||||
"format": "password",
|
||||
"title": "Api Key",
|
||||
"writeOnly": true
|
||||
},
|
||||
"expires_at": {
|
||||
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
||||
"title": "Expires At",
|
||||
"description": "Unix timestamp (seconds) indicating when the API key expires (if at all)"
|
||||
},
|
||||
"api_key_env_var": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Api Key Env Var"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["provider", "api_key"],
|
||||
"title": "APIKeyCredentials"
|
||||
},
|
||||
"APIKeyCredentials-Output": {
|
||||
"APIKeyCredentials": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
"provider": { "type": "string", "title": "Provider" },
|
||||
|
||||
Reference in New Issue
Block a user