mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Add block registry and updated
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import collections
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
@@ -9,6 +10,7 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Generic,
|
||||
Optional,
|
||||
Sequence,
|
||||
@@ -20,7 +22,8 @@ from typing import (
|
||||
|
||||
import jsonref
|
||||
import jsonschema
|
||||
from prisma.models import AgentBlock
|
||||
from prisma import Json
|
||||
from prisma.models import AgentBlock, BlocksRegistry
|
||||
from prisma.types import AgentBlockCreateInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -479,19 +482,50 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
return self.__class__.__name__
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"inputSchema": self.input_schema.jsonschema(),
|
||||
"outputSchema": self.output_schema.jsonschema(),
|
||||
"description": self.description,
|
||||
"categories": [category.dict() for category in self.categories],
|
||||
"contributors": [
|
||||
contributor.model_dump() for contributor in self.contributors
|
||||
],
|
||||
"staticOutput": self.static_output,
|
||||
"uiType": self.block_type.value,
|
||||
}
|
||||
# Sort categories by their name to ensure consistent ordering
|
||||
sorted_categories = [
|
||||
category.dict()
|
||||
for category in sorted(self.categories, key=lambda c: c.name)
|
||||
]
|
||||
|
||||
# Sort dictionary keys recursively for consistent ordering
|
||||
def sort_dict(obj):
|
||||
if isinstance(obj, dict):
|
||||
return collections.OrderedDict(
|
||||
sorted((k, sort_dict(v)) for k, v in obj.items())
|
||||
)
|
||||
elif isinstance(obj, list):
|
||||
# Check if all items in the list are primitive types that can be sorted
|
||||
if obj and all(
|
||||
isinstance(item, (str, int, float, bool, type(None)))
|
||||
for item in obj
|
||||
):
|
||||
# Sort primitive lists for consistent ordering
|
||||
return sorted(obj, key=lambda x: (x is None, str(x)))
|
||||
else:
|
||||
# For lists of complex objects, process each item but maintain order
|
||||
return [sort_dict(item) for item in obj]
|
||||
return obj
|
||||
|
||||
return collections.OrderedDict(
|
||||
[
|
||||
("id", self.id),
|
||||
("name", self.name),
|
||||
("inputSchema", sort_dict(self.input_schema.jsonschema())),
|
||||
("outputSchema", sort_dict(self.output_schema.jsonschema())),
|
||||
("description", self.description),
|
||||
("categories", sorted_categories),
|
||||
(
|
||||
"contributors",
|
||||
sorted(
|
||||
[contributor.model_dump() for contributor in self.contributors],
|
||||
key=lambda c: (c.get("name", ""), c.get("username", "")),
|
||||
),
|
||||
),
|
||||
("staticOutput", self.static_output),
|
||||
("uiType", self.block_type.value),
|
||||
]
|
||||
)
|
||||
|
||||
def get_info(self) -> BlockInfo:
|
||||
from backend.data.credit import get_block_cost
|
||||
@@ -738,3 +772,127 @@ def get_io_block_ids() -> Sequence[str]:
|
||||
for id, B in get_blocks().items()
|
||||
if B().block_type in (BlockType.INPUT, BlockType.OUTPUT)
|
||||
]
|
||||
|
||||
|
||||
async def get_block_registry() -> Dict[str, BlocksRegistry]:
|
||||
"""
|
||||
Retrieves the BlocksRegistry from the database and returns a dictionary mapping
|
||||
block names to BlocksRegistry objects.
|
||||
|
||||
Returns:
|
||||
Dict[str, BlocksRegistry]: A dictionary where each key is a block name and
|
||||
each value is a BlocksRegistry instance.
|
||||
"""
|
||||
blocks = await BlocksRegistry.prisma().find_many()
|
||||
return {block.id: block for block in blocks}
|
||||
|
||||
|
||||
def recursive_json_compare(
|
||||
db_block_definition: Any, local_block_definition: Any
|
||||
) -> bool:
|
||||
"""
|
||||
Recursively compares two JSON objects for equality.
|
||||
|
||||
Args:
|
||||
db_block_definition (Any): The JSON object from the database.
|
||||
local_block_definition (Any): The local JSON object to compare against.
|
||||
|
||||
Returns:
|
||||
bool: True if the objects are equal, False otherwise.
|
||||
"""
|
||||
if isinstance(db_block_definition, dict) and isinstance(
|
||||
local_block_definition, dict
|
||||
):
|
||||
if set(db_block_definition.keys()) != set(local_block_definition.keys()):
|
||||
logger.error(
|
||||
f"Keys are not the same: {set(db_block_definition.keys())} != {set(local_block_definition.keys())}"
|
||||
)
|
||||
return False
|
||||
return all(
|
||||
recursive_json_compare(db_block_definition[k], local_block_definition[k])
|
||||
for k in db_block_definition
|
||||
)
|
||||
values_are_same = db_block_definition == local_block_definition
|
||||
if not values_are_same:
|
||||
logger.error(
|
||||
f"Values are not the same: {db_block_definition} != {local_block_definition}"
|
||||
)
|
||||
return values_are_same
|
||||
|
||||
|
||||
def check_block_same(db_block: BlocksRegistry, local_block: Block) -> bool:
|
||||
"""
|
||||
Compares a database block with a local block.
|
||||
|
||||
Args:
|
||||
db_block (BlocksRegistry): The block object from the database registry.
|
||||
local_block (Block[BlockSchema, BlockSchema]): The local block definition.
|
||||
|
||||
Returns:
|
||||
bool: True if the blocks are equal, False otherwise.
|
||||
"""
|
||||
local_block_instance = local_block() # type: ignore
|
||||
local_block_definition = local_block_instance.to_dict()
|
||||
db_block_definition = db_block.definition
|
||||
logger.info(
|
||||
f"Checking if block {local_block_instance.name} is the same as the database block {db_block.name}"
|
||||
)
|
||||
is_same = recursive_json_compare(db_block_definition, local_block_definition)
|
||||
return is_same
|
||||
|
||||
|
||||
def find_delta_blocks(
|
||||
db_blocks: Dict[str, BlocksRegistry], local_blocks: Dict[str, Block]
|
||||
) -> Dict[str, Block]:
|
||||
"""
|
||||
Finds the set of blocks that are new or changed compared to the database.
|
||||
|
||||
Args:
|
||||
db_blocks (Dict[str, BlocksRegistry]): Existing blocks from the database, keyed by name.
|
||||
local_blocks (Dict[str, Block]): Local block definitions, keyed by name.
|
||||
|
||||
Returns:
|
||||
Dict[str, Block]: Blocks that are missing from or different than the database, keyed by name.
|
||||
"""
|
||||
block_update: Dict[str, Block] = {}
|
||||
for block_id, block in local_blocks.items():
|
||||
if block_id not in db_blocks:
|
||||
block_update[block_id] = block
|
||||
else:
|
||||
if not check_block_same(db_blocks[block_id], block):
|
||||
block_update[block_id] = block
|
||||
return block_update
|
||||
|
||||
|
||||
async def upsert_blocks_change_bulk(blocks: Dict[str, Block]):
|
||||
"""
|
||||
Bulk upserts blocks into the database if changed.
|
||||
|
||||
- Compares the provided local blocks to those in the database via their definition.
|
||||
- Inserts new or updated blocks.
|
||||
|
||||
Args:
|
||||
blocks (Dict[str, Block]): Local block definitions to upsert.
|
||||
|
||||
Returns:
|
||||
Dict[str, Block]: Blocks that were new or changed and upserted.
|
||||
"""
|
||||
db_blocks = await get_block_registry()
|
||||
block_update = find_delta_blocks(db_blocks, blocks)
|
||||
logger.error(f"Upserting {len(block_update)} blocks of {len(blocks)} total blocks")
|
||||
for block_id, block in block_update.items():
|
||||
await BlocksRegistry.prisma().upsert(
|
||||
where={"id": block_id},
|
||||
data={
|
||||
"create": {
|
||||
"id": block_id,
|
||||
"name": block().__class__.__name__, # type: ignore
|
||||
"definition": Json(block.to_dict(block())), # type: ignore
|
||||
},
|
||||
"update": {
|
||||
"name": block().__class__.__name__, # type: ignore
|
||||
"definition": Json(block.to_dict(block())), # type: ignore
|
||||
},
|
||||
},
|
||||
)
|
||||
return block_update
|
||||
|
||||
191
autogpt_platform/backend/backend/data/block_test.py
Normal file
191
autogpt_platform/backend/backend/data/block_test.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from prisma.models import BlocksRegistry
|
||||
|
||||
from backend.blocks.basic import (
|
||||
FileStoreBlock,
|
||||
PrintToConsoleBlock,
|
||||
ReverseListOrderBlock,
|
||||
StoreValueBlock,
|
||||
)
|
||||
from backend.data.block import (
|
||||
check_block_same,
|
||||
find_delta_blocks,
|
||||
recursive_json_compare,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recursive_json_compare():
|
||||
db_block_definition = {
|
||||
"a": 1,
|
||||
"b": 2,
|
||||
"c": 3,
|
||||
}
|
||||
local_block_definition = {
|
||||
"a": 1,
|
||||
"b": 2,
|
||||
"c": 3,
|
||||
}
|
||||
assert recursive_json_compare(db_block_definition, local_block_definition)
|
||||
assert not recursive_json_compare(
|
||||
db_block_definition, {**local_block_definition, "d": 4}
|
||||
)
|
||||
assert not recursive_json_compare(
|
||||
db_block_definition, {**local_block_definition, "a": 2}
|
||||
)
|
||||
assert not recursive_json_compare(
|
||||
db_block_definition, {**local_block_definition, "b": 3}
|
||||
)
|
||||
assert not recursive_json_compare(
|
||||
db_block_definition, {**local_block_definition, "c": 4}
|
||||
)
|
||||
assert not recursive_json_compare(
|
||||
db_block_definition, {**local_block_definition, "a": 1, "b": 2, "c": 3, "d": 4}
|
||||
)
|
||||
assert recursive_json_compare({}, {})
|
||||
assert recursive_json_compare({"a": 1}, {"a": 1})
|
||||
assert not recursive_json_compare({"a": 1}, {"b": 1})
|
||||
assert not recursive_json_compare({"a": 1}, {"a": 2})
|
||||
assert not recursive_json_compare({"a": 1}, {"a": [1, 2]})
|
||||
assert not recursive_json_compare({"a": 1}, {"a": {"b": 1}})
|
||||
assert not recursive_json_compare({"a": 1}, {"a": {"b": 2}})
|
||||
assert not recursive_json_compare({"a": 1}, {"a": {"b": [1, 2]}})
|
||||
assert not recursive_json_compare({"a": 1}, {"a": {"b": {"c": 1}}})
|
||||
assert not recursive_json_compare({"a": 1}, {"a": {"b": {"c": 2}}})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_block_same():
|
||||
local_block = PrintToConsoleBlock()
|
||||
db_block = BlocksRegistry(
|
||||
id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
|
||||
name=local_block.__class__.__name__,
|
||||
definition=json.dumps(local_block.to_dict()), # type: ignore To much type magic going on here
|
||||
updatedAt=datetime.now(),
|
||||
)
|
||||
assert check_block_same(db_block, local_block)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_block_not_same():
|
||||
local_block = PrintToConsoleBlock()
|
||||
local_block_data = local_block.to_dict()
|
||||
local_block_data["description"] = "Hello, World!"
|
||||
|
||||
db_block = BlocksRegistry(
|
||||
id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
|
||||
name=local_block.__class__.__name__,
|
||||
definition=json.dumps(local_block_data), # type: ignore To much type magic going on here
|
||||
updatedAt=datetime.now(),
|
||||
)
|
||||
assert not check_block_same(db_block, local_block)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_delta_blocks():
|
||||
now = datetime.now()
|
||||
store_value_block = StoreValueBlock()
|
||||
local_blocks = {
|
||||
PrintToConsoleBlock().id: PrintToConsoleBlock(),
|
||||
ReverseListOrderBlock().id: ReverseListOrderBlock(),
|
||||
FileStoreBlock().id: FileStoreBlock(),
|
||||
store_value_block.id: store_value_block,
|
||||
}
|
||||
db_blocks = {
|
||||
PrintToConsoleBlock().id: BlocksRegistry(
|
||||
id=PrintToConsoleBlock().id,
|
||||
name=PrintToConsoleBlock().__class__.__name__,
|
||||
definition=json.dumps(PrintToConsoleBlock().to_dict()), # type: ignore To much type magic going on here
|
||||
updatedAt=now,
|
||||
),
|
||||
ReverseListOrderBlock().id: BlocksRegistry(
|
||||
id=ReverseListOrderBlock().id,
|
||||
name=ReverseListOrderBlock().__class__.__name__,
|
||||
definition=json.dumps(ReverseListOrderBlock().to_dict()), # type: ignore To much type magic going on here
|
||||
updatedAt=now,
|
||||
),
|
||||
FileStoreBlock().id: BlocksRegistry(
|
||||
id=FileStoreBlock().id,
|
||||
name=FileStoreBlock().__class__.__name__,
|
||||
definition=json.dumps(FileStoreBlock().to_dict()), # type: ignore To much type magic going on here
|
||||
updatedAt=now,
|
||||
),
|
||||
}
|
||||
delta_blocks = find_delta_blocks(db_blocks, local_blocks)
|
||||
assert len(delta_blocks) == 1
|
||||
assert store_value_block.id in delta_blocks.keys()
|
||||
assert delta_blocks[store_value_block.id] == store_value_block
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_delta_blocks_block_updated():
|
||||
now = datetime.now()
|
||||
store_value_block = StoreValueBlock()
|
||||
print_to_console_block_definition = PrintToConsoleBlock().to_dict()
|
||||
print_to_console_block_definition["description"] = "Hello, World!"
|
||||
local_blocks = {
|
||||
PrintToConsoleBlock().id: PrintToConsoleBlock(),
|
||||
ReverseListOrderBlock().id: ReverseListOrderBlock(),
|
||||
FileStoreBlock().id: FileStoreBlock(),
|
||||
store_value_block.id: store_value_block,
|
||||
}
|
||||
db_blocks = {
|
||||
PrintToConsoleBlock().id: BlocksRegistry(
|
||||
id=PrintToConsoleBlock().id,
|
||||
name=PrintToConsoleBlock().__class__.__name__,
|
||||
definition=json.dumps(print_to_console_block_definition), # type: ignore To much type magic going on here
|
||||
updatedAt=now,
|
||||
),
|
||||
ReverseListOrderBlock().id: BlocksRegistry(
|
||||
id=ReverseListOrderBlock().id,
|
||||
name=ReverseListOrderBlock().__class__.__name__,
|
||||
definition=json.dumps(ReverseListOrderBlock().to_dict()), # type: ignore To much type magic going on here
|
||||
updatedAt=now,
|
||||
),
|
||||
FileStoreBlock().id: BlocksRegistry(
|
||||
id=FileStoreBlock().id,
|
||||
name=FileStoreBlock().__class__.__name__,
|
||||
definition=json.dumps(FileStoreBlock().to_dict()), # type: ignore To much type magic going on here
|
||||
updatedAt=now,
|
||||
),
|
||||
}
|
||||
delta_blocks = find_delta_blocks(db_blocks, local_blocks)
|
||||
assert len(delta_blocks) == 2
|
||||
assert store_value_block.id in delta_blocks.keys()
|
||||
assert delta_blocks[store_value_block.id] == store_value_block
|
||||
assert PrintToConsoleBlock().id in delta_blocks.keys()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_delta_block_no_diff():
|
||||
now = datetime.now()
|
||||
local_blocks = {
|
||||
PrintToConsoleBlock().id: PrintToConsoleBlock(),
|
||||
ReverseListOrderBlock().id: ReverseListOrderBlock(),
|
||||
FileStoreBlock().id: FileStoreBlock(),
|
||||
}
|
||||
db_blocks = {
|
||||
PrintToConsoleBlock().id: BlocksRegistry(
|
||||
id=PrintToConsoleBlock().id,
|
||||
name=PrintToConsoleBlock().__class__.__name__,
|
||||
definition=json.dumps(PrintToConsoleBlock().to_dict()), # type: ignore To much type magic going on here
|
||||
updatedAt=now,
|
||||
),
|
||||
ReverseListOrderBlock().id: BlocksRegistry(
|
||||
id=ReverseListOrderBlock().id,
|
||||
name=ReverseListOrderBlock().__class__.__name__,
|
||||
definition=json.dumps(ReverseListOrderBlock().to_dict()), # type: ignore To much type magic going on here
|
||||
updatedAt=now,
|
||||
),
|
||||
FileStoreBlock().id: BlocksRegistry(
|
||||
id=FileStoreBlock().id,
|
||||
name=FileStoreBlock().__class__.__name__,
|
||||
definition=json.dumps(FileStoreBlock().to_dict()), # type: ignore To much type magic going on here
|
||||
updatedAt=now,
|
||||
),
|
||||
}
|
||||
delta_blocks = find_delta_blocks(db_blocks, local_blocks)
|
||||
assert len(delta_blocks) == 0
|
||||
@@ -523,13 +523,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
if hasattr(model_class, "allowed_providers") and hasattr(
|
||||
model_class, "allowed_cred_types"
|
||||
):
|
||||
allowed_providers = model_class.allowed_providers()
|
||||
allowed_providers = sorted(model_class.allowed_providers())
|
||||
# If no specific providers (None), allow any string
|
||||
if allowed_providers is None:
|
||||
schema["credentials_provider"] = ["string"] # Allow any string provider
|
||||
else:
|
||||
schema["credentials_provider"] = allowed_providers
|
||||
schema["credentials_types"] = model_class.allowed_cred_types()
|
||||
schema["credentials_types"] = sorted(model_class.allowed_cred_types())
|
||||
# Do not return anything, just mutate schema in place
|
||||
|
||||
model_config = ConfigDict(
|
||||
|
||||
@@ -1,16 +1,98 @@
|
||||
from typing import Dict
|
||||
|
||||
from prisma import Prisma
|
||||
from prisma.models import ProviderRegistry as PrismaProviderRegistry
|
||||
|
||||
from backend.sdk.provider import ProviderRegister
|
||||
|
||||
|
||||
async def upsert_providers_bulk(providers: Dict[str, ProviderRegister]):
|
||||
def is_providers_different(
|
||||
current_provider: PrismaProviderRegistry, new_provider: ProviderRegister
|
||||
) -> bool:
|
||||
"""
|
||||
Compare a current provider (as stored in the database) against a new provider registration
|
||||
and determine if they are different. This is done by converting the database model to a
|
||||
ProviderRegister and checking for equality (all fields compared).
|
||||
|
||||
Args:
|
||||
current_provider (PrismaProviderRegistry): The provider as stored in the database.
|
||||
new_provider (ProviderRegister): The provider specification to compare.
|
||||
|
||||
Returns:
|
||||
bool: True if the providers differ, False if they are effectively the same.
|
||||
"""
|
||||
current_provider_register = ProviderRegister(
|
||||
name=current_provider.name,
|
||||
with_oauth=current_provider.with_oauth,
|
||||
client_id_env_var=current_provider.client_id_env_var,
|
||||
client_secret_env_var=current_provider.client_secret_env_var,
|
||||
with_api_key=current_provider.with_api_key,
|
||||
api_key_env_var=current_provider.api_key_env_var,
|
||||
with_user_password=current_provider.with_user_password,
|
||||
username_env_var=current_provider.username_env_var,
|
||||
password_env_var=current_provider.password_env_var,
|
||||
)
|
||||
if current_provider_register == new_provider:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def find_delta_providers(
|
||||
current_providers: Dict[str, PrismaProviderRegistry],
|
||||
providers: Dict[str, ProviderRegister],
|
||||
) -> Dict[str, ProviderRegister]:
|
||||
"""
|
||||
Identify providers that are either new or updated compared to the current providers list.
|
||||
|
||||
Args:
|
||||
current_providers (Dict[str, PrismaProviderRegistry]): Dictionary of current provider models keyed by provider name.
|
||||
providers (Dict[str, ProviderRegister]): Dictionary of new provider registrations keyed by provider name.
|
||||
|
||||
Returns:
|
||||
Dict[str, ProviderRegister]: Providers that need to be added/updated in the registry.
|
||||
- Includes providers not in current_providers.
|
||||
- Includes providers where the data differs from what's in current_providers.
|
||||
"""
|
||||
provider_update = {}
|
||||
for name, provider in providers.items():
|
||||
if name not in current_providers:
|
||||
provider_update[name] = provider
|
||||
else:
|
||||
if is_providers_different(current_providers[name], provider):
|
||||
provider_update[name] = provider
|
||||
|
||||
return provider_update
|
||||
|
||||
|
||||
async def get_providers() -> Dict[str, PrismaProviderRegistry]:
|
||||
"""
|
||||
Retrieve all provider registries from the database.
|
||||
|
||||
Returns:
|
||||
Dict[str, PrismaProviderRegistry]: Dictionary of all current providers, keyed by provider name.
|
||||
"""
|
||||
async with Prisma() as prisma:
|
||||
providers = await prisma.providerregistry.find_many()
|
||||
return {
|
||||
provider.name: PrismaProviderRegistry(**provider.model_dump())
|
||||
for provider in providers
|
||||
}
|
||||
|
||||
|
||||
async def upsert_providers_change_bulk(providers: Dict[str, ProviderRegister]):
|
||||
"""
|
||||
Bulk upsert providers into the database after checking for changes.
|
||||
|
||||
Args:
|
||||
providers (Dict[str, ProviderRegister]): Dictionary of new provider registrations keyed by provider name.
|
||||
"""
|
||||
current_providers = await get_providers()
|
||||
provider_update = find_delta_providers(current_providers, providers)
|
||||
"""Async version of bulk upsert providers with all fields using transaction for atomicity"""
|
||||
async with Prisma() as prisma:
|
||||
async with prisma.tx() as tx:
|
||||
results = []
|
||||
for name, provider in providers.items():
|
||||
for name, provider in provider_update.items():
|
||||
result = await tx.providerregistry.upsert(
|
||||
where={"name": name},
|
||||
data={
|
||||
|
||||
127
autogpt_platform/backend/backend/sdk/db_test.py
Normal file
127
autogpt_platform/backend/backend/sdk/db_test.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from prisma.models import ProviderRegistry as PrismaProviderRegistry
|
||||
|
||||
from backend.sdk.db import find_delta_providers, is_providers_different
|
||||
from backend.sdk.provider import ProviderRegister
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_is_providers_different_same():
|
||||
current_provider = PrismaProviderRegistry(
|
||||
name="test_provider",
|
||||
with_oauth=True,
|
||||
client_id_env_var="TEST_CLIENT_ID",
|
||||
client_secret_env_var="TEST_CLIENT_SECRET",
|
||||
with_api_key=True,
|
||||
api_key_env_var="TEST_API_KEY",
|
||||
with_user_password=True,
|
||||
username_env_var="TEST_USERNAME",
|
||||
password_env_var="TEST_PASSWORD",
|
||||
updatedAt=datetime.now(),
|
||||
)
|
||||
new_provider = ProviderRegister(
|
||||
name="test_provider",
|
||||
with_oauth=True,
|
||||
client_id_env_var="TEST_CLIENT_ID",
|
||||
client_secret_env_var="TEST_CLIENT_SECRET",
|
||||
with_api_key=True,
|
||||
api_key_env_var="TEST_API_KEY",
|
||||
with_user_password=True,
|
||||
username_env_var="TEST_USERNAME",
|
||||
password_env_var="TEST_PASSWORD",
|
||||
)
|
||||
assert not is_providers_different(current_provider, new_provider)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_is_providers_different_different():
|
||||
current_provider = PrismaProviderRegistry(
|
||||
name="test_provider",
|
||||
with_oauth=True,
|
||||
client_id_env_var="TEST_CLIENT_ID",
|
||||
client_secret_env_var="TEST_CLIENT_SECRET",
|
||||
with_api_key=True,
|
||||
api_key_env_var="TEST_API_KEY",
|
||||
with_user_password=True,
|
||||
username_env_var="TEST_USERNAME",
|
||||
password_env_var="TEST_PASSWORD",
|
||||
updatedAt=datetime.now(),
|
||||
)
|
||||
new_provider = ProviderRegister(
|
||||
name="test_provider",
|
||||
with_oauth=False,
|
||||
with_api_key=True,
|
||||
api_key_env_var="TEST_API_KEY",
|
||||
with_user_password=True,
|
||||
username_env_var="TEST_USERNAME",
|
||||
password_env_var="TEST_PASSWORD",
|
||||
)
|
||||
assert is_providers_different(current_provider, new_provider)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_find_delta_providers():
|
||||
current_providers = {
|
||||
"test_provider": PrismaProviderRegistry(
|
||||
name="test_provider",
|
||||
with_oauth=True,
|
||||
client_id_env_var="TEST_CLIENT_ID",
|
||||
client_secret_env_var="TEST_CLIENT_SECRET",
|
||||
with_api_key=True,
|
||||
api_key_env_var="TEST_API_KEY",
|
||||
with_user_password=True,
|
||||
username_env_var="TEST_USERNAME",
|
||||
password_env_var="TEST_PASSWORD",
|
||||
updatedAt=datetime.now(),
|
||||
),
|
||||
"test_provider_2": PrismaProviderRegistry(
|
||||
name="test_provider_2",
|
||||
with_oauth=True,
|
||||
client_id_env_var="TEST_CLIENT_ID_2",
|
||||
client_secret_env_var="TEST_CLIENT_SECRET_2",
|
||||
with_api_key=True,
|
||||
api_key_env_var="TEST_API_KEY_2",
|
||||
with_user_password=True,
|
||||
username_env_var="TEST_USERNAME_2",
|
||||
password_env_var="TEST_PASSWORD_2",
|
||||
updatedAt=datetime.now(),
|
||||
),
|
||||
}
|
||||
new_providers = {
|
||||
"test_provider": ProviderRegister(
|
||||
name="test_provider",
|
||||
with_oauth=True,
|
||||
client_id_env_var="TEST_CLIENT_ID",
|
||||
client_secret_env_var="TEST_CLIENT_SECRET",
|
||||
with_api_key=True,
|
||||
api_key_env_var="TEST_API_KEY",
|
||||
with_user_password=True,
|
||||
username_env_var="TEST_USERNAME",
|
||||
password_env_var="TEST_PASSWORD",
|
||||
),
|
||||
"test_provider_2": ProviderRegister(
|
||||
name="test_provider_2",
|
||||
with_oauth=False,
|
||||
with_api_key=True,
|
||||
api_key_env_var="TEST_API_KEY_2",
|
||||
with_user_password=True,
|
||||
username_env_var="TEST_USERNAME_2",
|
||||
password_env_var="TEST_PASSWORD_2",
|
||||
),
|
||||
"test_provider_3": ProviderRegister(
|
||||
name="test_provider_3",
|
||||
with_oauth=True,
|
||||
client_id_env_var="TEST_CLIENT_ID_3",
|
||||
client_secret_env_var="TEST_CLIENT_SECRET_3",
|
||||
with_api_key=False,
|
||||
with_user_password=True,
|
||||
username_env_var="TEST_USERNAME_3",
|
||||
password_env_var="TEST_PASSWORD_3",
|
||||
),
|
||||
}
|
||||
assert find_delta_providers(current_providers, new_providers) == {
|
||||
"test_provider_2": new_providers["test_provider_2"],
|
||||
"test_provider_3": new_providers["test_provider_3"],
|
||||
}
|
||||
@@ -13,7 +13,7 @@ from backend.data.model import Credentials
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
from backend.sdk.db import upsert_providers_bulk
|
||||
from backend.sdk.db import upsert_providers_change_bulk
|
||||
from backend.sdk.provider import ProviderRegister
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -282,7 +282,7 @@ class AutoRegistry:
|
||||
|
||||
cls._provider_registry[item.provider] = new_cred
|
||||
|
||||
await upsert_providers_bulk(providers=cls._provider_registry)
|
||||
await upsert_providers_change_bulk(providers=cls._provider_registry)
|
||||
|
||||
if hasattr(creds_store, "IntegrationCredentialsStore"):
|
||||
store_class = creds_store.IntegrationCredentialsStore
|
||||
|
||||
@@ -16,6 +16,7 @@ from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.routing import APIRoute
|
||||
from prisma.errors import PrismaError
|
||||
|
||||
import backend.blocks
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
@@ -99,6 +100,9 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
blocks = backend.blocks.load_all_blocks()
|
||||
|
||||
await backend.data.block.upsert_blocks_change_bulk(blocks)
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
-- CreateTable
|
||||
CREATE TABLE "ProviderRegistry" (
|
||||
"name" TEXT NOT NULL,
|
||||
"with_oauth" BOOLEAN NOT NULL DEFAULT false,
|
||||
"client_id_env_var" TEXT,
|
||||
"client_secret_env_var" TEXT,
|
||||
"with_api_key" BOOLEAN NOT NULL DEFAULT false,
|
||||
"api_key_env_var" TEXT,
|
||||
"with_user_password" BOOLEAN NOT NULL DEFAULT false,
|
||||
"username_env_var" TEXT,
|
||||
"password_env_var" TEXT,
|
||||
|
||||
CONSTRAINT "ProviderRegistry_pkey" PRIMARY KEY ("name")
|
||||
);
|
||||
@@ -0,0 +1,31 @@
|
||||
-- CreateTable
|
||||
CREATE TABLE "ProviderRegistry" (
|
||||
"name" TEXT NOT NULL,
|
||||
"with_oauth" BOOLEAN NOT NULL DEFAULT false,
|
||||
"client_id_env_var" TEXT,
|
||||
"client_secret_env_var" TEXT,
|
||||
"with_api_key" BOOLEAN NOT NULL DEFAULT false,
|
||||
"api_key_env_var" TEXT,
|
||||
"with_user_password" BOOLEAN NOT NULL DEFAULT false,
|
||||
"username_env_var" TEXT,
|
||||
"password_env_var" TEXT,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
|
||||
CONSTRAINT "ProviderRegistry_pkey" PRIMARY KEY ("name")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "BlocksRegistry" (
|
||||
"id" TEXT NOT NULL,
|
||||
"name" TEXT NOT NULL,
|
||||
"definition" JSONB NOT NULL,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
|
||||
CONSTRAINT "BlocksRegistry_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ProviderRegistry_updatedAt_idx" ON "ProviderRegistry"("updatedAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "BlocksRegistry_updatedAt_idx" ON "BlocksRegistry"("updatedAt");
|
||||
@@ -74,6 +74,18 @@ model ProviderRegistry {
|
||||
with_user_password Boolean @default(false)
|
||||
username_env_var String?
|
||||
password_env_var String?
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
@@index([updatedAt])
|
||||
}
|
||||
|
||||
model BlocksRegistry {
|
||||
id String @id @default(uuid())
|
||||
name String
|
||||
definition Json
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
@@index([updatedAt])
|
||||
}
|
||||
|
||||
enum OnboardingStep {
|
||||
|
||||
@@ -7,7 +7,7 @@ import {
|
||||
usePostV1CreateCredentials,
|
||||
} from "@/app/api/__generated__/endpoints/integrations/integrations";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { APIKeyCredentials } from "@/app/api/__generated__/models/aPIKeyCredentials";
|
||||
import { APIKeyCredentialsInput } from "@/app/api/__generated__/models/aPIKeyCredentialsInput";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useState } from "react";
|
||||
|
||||
@@ -89,7 +89,7 @@ export function useAPIKeyCredentialsModal({
|
||||
api_key: values.apiKey,
|
||||
title: values.title,
|
||||
expires_at: expiresAt,
|
||||
} as APIKeyCredentials,
|
||||
} as APIKeyCredentialsInput,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user