Mirror of https://github.com/microsoft/autogen.git, synced 2026-02-12 18:34:56 -05:00
v1 of AutoGen Studio on AgentChat (#4097)
* add skeleton workflow manager
* add test notebook
* update test nb
* add sample team spec
* refactor requirements to agentchat and ext
* add base provider to return agentchat agents from json spec
* initial api refactor, update dbmanager
* api refactor
* refactor tests
* ags api tutorial update
* ui refactor
* general refactor
* minor refactor updates
* backend api refactor
* ui refactor and update
* implement v1 for streaming connection with ui updates
* backend refactor
* ui refactor
* minor ui tweak
* minor refactor and tweaks
* general refactor
* update tests
* sync uv.lock with main
* uv lock update
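For orientation, a minimal sketch (editorial, not part of the commit) of the flow the files below implement: a declarative JSON/YAML component spec is loaded by the new ComponentFactory and turned into a live AgentChat team. The import path, file name, and event loop setup are assumptions.

    import asyncio

    from autogenstudio.database import ComponentFactory  # assumed import path

    async def main():
        factory = ComponentFactory()
        # "team.json" is a hypothetical spec file matching TeamConfig below
        team = await factory.load("team.json")
        print(type(team).__name__)  # e.g. RoundRobinGroupChat

    asyncio.run(main())
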
--- __init__.py
+++ __init__.py
@@ -1,3 +1,3 @@
-# from .dbmanager import *
-from .dbmanager import *
-from .utils import *
+from .db_manager import DatabaseManager
+from .component_factory import ComponentFactory
+from .config_manager import ConfigurationManager
--- alembic.ini
+++ /dev/null
@@ -1,116 +0,0 @@
# A generic, single database configuration.

[alembic]
# path to migration scripts
script_location = migrations

# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .

# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =

# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false

# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false

# version location specification; This defaults
# to migrations/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions

# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os  # Use os.pathsep. Default configuration used for new projects.

# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false

# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8

sqlalchemy.url = driver://user:pass@localhost/dbname


[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples

# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME

# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME

# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

--- /dev/null
+++ component_factory.py
@@ -0,0 +1,355 @@
import os
from pathlib import Path
from typing import Callable, List, Literal, Union, Optional, Dict, Any, Type
from datetime import datetime
import json
from autogen_agentchat.task import MaxMessageTermination, TextMentionTermination, StopMessageTermination
import yaml
import logging
from packaging import version

from ..datamodel import (
    TeamConfig, AgentConfig, ModelConfig, ToolConfig,
    TeamTypes, AgentTypes, ModelTypes, ToolTypes,
    ComponentType, ComponentConfig, ComponentConfigInput, TerminationConfig, TerminationTypes, Response
)
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_core.components.tools import FunctionTool

logger = logging.getLogger(__name__)

# Type definitions for supported components
TeamComponent = Union[RoundRobinGroupChat, SelectorGroupChat]
AgentComponent = Union[AssistantAgent]  # Will grow with more agent types
ModelComponent = Union[OpenAIChatCompletionClient]  # Will grow with more model types
ToolComponent = Union[FunctionTool]  # Will grow with more tool types
TerminationComponent = Union[MaxMessageTermination, StopMessageTermination, TextMentionTermination]

# Config type definitions
Component = Union[TeamComponent, AgentComponent, ModelComponent, ToolComponent]

ReturnType = Literal['object', 'dict', 'config']


class ComponentFactory:
    """Creates and manages agent components with versioned configuration loading"""

    SUPPORTED_VERSIONS = {
        ComponentType.TEAM: ["1.0.0"],
        ComponentType.AGENT: ["1.0.0"],
        ComponentType.MODEL: ["1.0.0"],
        ComponentType.TOOL: ["1.0.0"],
        ComponentType.TERMINATION: ["1.0.0"]
    }

    def __init__(self):
        self._model_cache: Dict[str, OpenAIChatCompletionClient] = {}
        self._tool_cache: Dict[str, FunctionTool] = {}
        self._last_cache_clear = datetime.now()

    async def load(self, component: ComponentConfigInput, return_type: ReturnType = 'object') -> Union[Component, dict, ComponentConfig]:
        """
        Universal loader for any component type

        Args:
            component: Component configuration (file path, dict, or ComponentConfig)
            return_type: Type of return value ('object', 'dict', or 'config')

        Returns:
            Component instance, config dict, or ComponentConfig based on return_type

        Raises:
            ValueError: If component type is unknown or version unsupported
        """
        try:
            # Load and validate config
            if isinstance(component, (str, Path)):
                component_dict = await self._load_from_file(component)
                config = self._dict_to_config(component_dict)
            elif isinstance(component, dict):
                config = self._dict_to_config(component)
            else:
                config = component

            # Validate version
            if not self._is_version_supported(config.component_type, config.version):
                raise ValueError(
                    f"Unsupported version {config.version} for "
                    f"component type {config.component_type}. "
                    f"Supported versions: {self.SUPPORTED_VERSIONS[config.component_type]}"
                )

            # Return early if dict or config requested
            if return_type == 'dict':
                return config.model_dump()
            elif return_type == 'config':
                return config

            # Otherwise create and return component instance
            handlers = {
                ComponentType.TEAM: self.load_team,
                ComponentType.AGENT: self.load_agent,
                ComponentType.MODEL: self.load_model,
                ComponentType.TOOL: self.load_tool,
                ComponentType.TERMINATION: self.load_termination
            }

            handler = handlers.get(config.component_type)
            if not handler:
                raise ValueError(f"Unknown component type: {config.component_type}")

            return await handler(config)

        except Exception as e:
            logger.error(f"Failed to load component: {str(e)}")
            raise

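    # Editorial sketch (not part of the original commit): typical use of the
    # universal loader above. The spec file name is hypothetical; any JSON/YAML
    # file that parses into one of the config classes behaves the same way.
    #
    #   factory = ComponentFactory()
    #   team = await factory.load("team.json")                             # component instance
    #   team_cfg = await factory.load("team.json", return_type="config")   # TeamConfig
    #   team_dict = await factory.load(team_cfg, return_type="dict")       # plain dict
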
    async def load_directory(self, directory: Union[str, Path], check_exists: bool = False, return_type: ReturnType = 'object') -> List[Union[Component, dict, ComponentConfig]]:
        """
        Import all component configurations from a directory.
        """
        components = []
        try:
            directory = Path(directory)
            # Using Path.iterdir() instead of os.listdir
            for path in list(directory.glob("*")):
                if path.suffix.lower().endswith(('.json', '.yaml', '.yml')):
                    try:
                        component = await self.load(path, return_type)
                        components.append(component)
                    except Exception as e:
                        logger.info(f"Failed to load component: {str(e)}, {path}")

            return components
        except Exception as e:
            logger.info(f"Failed to load directory: {str(e)}")
            return components

    def _dict_to_config(self, config_dict: dict) -> ComponentConfig:
        """Convert dictionary to appropriate config type based on component_type"""
        if "component_type" not in config_dict:
            raise ValueError("component_type is required in configuration")

        config_types = {
            ComponentType.TEAM: TeamConfig,
            ComponentType.AGENT: AgentConfig,
            ComponentType.MODEL: ModelConfig,
            ComponentType.TOOL: ToolConfig,
            ComponentType.TERMINATION: TerminationConfig  # Add mapping for termination
        }

        component_type = ComponentType(config_dict["component_type"])
        config_class = config_types.get(component_type)

        if not config_class:
            raise ValueError(f"Unknown component type: {component_type}")

        return config_class(**config_dict)

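    # Editorial sketch (not part of the original commit): the kind of dict
    # _dict_to_config() expects. Values are hypothetical; "component_type"
    # selects the config class and "version" is checked against SUPPORTED_VERSIONS.
    # The enum string values below are assumptions about how ComponentType and
    # ModelTypes serialize.
    #
    #   model_config = {
    #       "component_type": "model",
    #       "version": "1.0.0",
    #       "model_type": "OpenAIChatCompletionClient",
    #       "model": "gpt-4o",
    #       "api_key": "sk-...",
    #       "base_url": None,
    #   }
    #   config = factory._dict_to_config(model_config)   # -> ModelConfig
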
    async def load_termination(self, config: TerminationConfig) -> TerminationComponent:
        """Create termination condition instance from configuration."""
        try:
            if config.termination_type == TerminationTypes.MAX_MESSAGES:
                return MaxMessageTermination(max_messages=config.max_messages)
            elif config.termination_type == TerminationTypes.STOP_MESSAGE:
                return StopMessageTermination()
            elif config.termination_type == TerminationTypes.TEXT_MENTION:
                if not config.text:
                    raise ValueError("text parameter required for TextMentionTermination")
                return TextMentionTermination(text=config.text)
            else:
                raise ValueError(f"Unsupported termination type: {config.termination_type}")
        except Exception as e:
            logger.error(f"Failed to create termination condition: {str(e)}")
            raise ValueError(f"Termination condition creation failed: {str(e)}")

    async def load_team(self, config: TeamConfig) -> TeamComponent:
        """Create team instance from configuration."""
        try:
            # Load participants (agents)
            participants = []
            for participant in config.participants:
                agent = await self.load(participant)
                participants.append(agent)

            # Load model client if specified
            model_client = None
            if config.model_client:
                model_client = await self.load(config.model_client)

            # Load termination condition if specified
            termination = None
            if config.termination_condition:
                # Now we can use the universal load() method since termination is a proper component
                termination = await self.load(config.termination_condition)

            # Create team based on type
            if config.team_type == TeamTypes.ROUND_ROBIN:
                return RoundRobinGroupChat(
                    participants=participants,
                    termination_condition=termination
                )
            elif config.team_type == TeamTypes.SELECTOR:
                if not model_client:
                    raise ValueError("SelectorGroupChat requires a model_client")
                return SelectorGroupChat(
                    participants=participants,
                    model_client=model_client,
                    termination_condition=termination
                )
            else:
                raise ValueError(f"Unsupported team type: {config.team_type}")

        except Exception as e:
            logger.error(f"Failed to create team {config.name}: {str(e)}")
            raise ValueError(f"Team creation failed: {str(e)}")

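    # Editorial sketch (not part of the original commit): shape of a team spec
    # consumed by load_team(), inferred from the attribute accesses above.
    # Enum string values and the nested agent spec are assumptions.
    #
    #   team_spec = {
    #       "component_type": "team",
    #       "version": "1.0.0",
    #       "name": "writing_team",
    #       "team_type": "round_robin",
    #       "participants": [assistant_spec],        # nested AgentConfig dicts
    #       "termination_condition": {
    #           "component_type": "termination",
    #           "version": "1.0.0",
    #           "termination_type": "max_messages",
    #           "max_messages": 10,
    #       },
    #   }
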
    async def load_agent(self, config: AgentConfig) -> AgentComponent:
        """Create agent instance from configuration."""
        try:
            # Load model client if specified
            model_client = None
            if config.model_client:
                model_client = await self.load(config.model_client)

            system_message = config.system_message if config.system_message else "You are a helpful assistant"

            # Load tools if specified
            tools = []
            if config.tools:
                for tool_config in config.tools:
                    tool = await self.load(tool_config)
                    tools.append(tool)

            if config.agent_type == AgentTypes.ASSISTANT:
                return AssistantAgent(
                    name=config.name,
                    model_client=model_client,
                    tools=tools,
                    system_message=system_message
                )
            else:
                raise ValueError(f"Unsupported agent type: {config.agent_type}")

        except Exception as e:
            logger.error(f"Failed to create agent {config.name}: {str(e)}")
            raise ValueError(f"Agent creation failed: {str(e)}")

    async def load_model(self, config: ModelConfig) -> ModelComponent:
        """Create model instance from configuration."""
        try:
            # Check cache first
            cache_key = str(config.model_dump())
            if cache_key in self._model_cache:
                logger.debug(f"Using cached model for {config.model}")
                return self._model_cache[cache_key]

            if config.model_type == ModelTypes.OPENAI:
                model = OpenAIChatCompletionClient(
                    model=config.model,
                    api_key=config.api_key,
                    base_url=config.base_url
                )
                self._model_cache[cache_key] = model
                return model
            else:
                raise ValueError(f"Unsupported model type: {config.model_type}")

        except Exception as e:
            logger.error(f"Failed to create model {config.model}: {str(e)}")
            raise ValueError(f"Model creation failed: {str(e)}")

    async def load_tool(self, config: ToolConfig) -> ToolComponent:
        """Create tool instance from configuration."""
        try:
            # Validate required fields
            if not all([config.name, config.description, config.content, config.tool_type]):
                raise ValueError("Tool configuration missing required fields")

            # Check cache first
            cache_key = str(config.model_dump())
            if cache_key in self._tool_cache:
                logger.debug(f"Using cached tool '{config.name}'")
                return self._tool_cache[cache_key]

            if config.tool_type == ToolTypes.PYTHON_FUNCTION:
                tool = FunctionTool(
                    name=config.name,
                    description=config.description,
                    func=self._func_from_string(config.content)
                )
                self._tool_cache[cache_key] = tool
                return tool
            else:
                raise ValueError(f"Unsupported tool type: {config.tool_type}")

        except Exception as e:
            logger.error(f"Failed to create tool '{config.name}': {str(e)}")
            raise

    # Helper methods remain largely the same

    async def _load_from_file(self, path: Union[str, Path]) -> dict:
        """Load configuration from JSON or YAML file."""
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Config file not found: {path}")

        try:
            with open(path) as f:
                if path.suffix == '.json':
                    return json.load(f)
                elif path.suffix in ('.yml', '.yaml'):
                    return yaml.safe_load(f)
                else:
                    raise ValueError(f"Unsupported file format: {path.suffix}")
        except Exception as e:
            raise ValueError(f"Failed to load file {path}: {str(e)}")

    def _func_from_string(self, content: str) -> Callable:
        """Convert function string to callable."""
        try:
            namespace = {}
            exec(content, namespace)
            # Return the first callable defined by the executed code (classes excluded)
            for item in namespace.values():
                if callable(item) and not isinstance(item, type):
                    return item
            raise ValueError("No function found in provided code")
        except Exception as e:
            raise ValueError(f"Failed to create function: {str(e)}")

    def _is_version_supported(self, component_type: ComponentType, ver: str) -> bool:
        """Check if version is supported for component type."""
        try:
            # Parsing validates the version string; the membership check itself
            # compares the raw string against the supported list.
            version.parse(ver)
            return ver in self.SUPPORTED_VERSIONS[component_type]
        except version.InvalidVersion:
            return False

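    # Editorial sketch (not part of the original commit): a tool spec whose
    # "content" field holds the source code that _func_from_string() exec()s.
    # The enum string value for tool_type is an assumption.
    #
    #   tool_spec = {
    #       "component_type": "tool",
    #       "version": "1.0.0",
    #       "tool_type": "python_function",
    #       "name": "add",
    #       "description": "Add two integers",
    #       "content": "def add(a: int, b: int) -> int:\n    return a + b",
    #   }
    #   tool = await factory.load(tool_spec)   # -> FunctionTool wrapping add()
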
    async def cleanup(self) -> None:
        """Cleanup resources and clear caches."""
        for model in self._model_cache.values():
            if hasattr(model, 'cleanup'):
                await model.cleanup()

        for tool in self._tool_cache.values():
            if hasattr(tool, 'cleanup'):
                await tool.cleanup()

        self._model_cache.clear()
        self._tool_cache.clear()
        self._last_cache_clear = datetime.now()
        logger.info("Cleared all component caches")

--- /dev/null
+++ config_manager.py
@@ -0,0 +1,322 @@
from typing import Optional, Union, Dict, Any, List
from pathlib import Path
from loguru import logger

from ..datamodel import (
    Model, Team, Agent, Tool,
    Response, ComponentTypes, LinkTypes,
    ComponentConfigInput
)
from .component_factory import ComponentFactory
from .db_manager import DatabaseManager


class ConfigurationManager:
    """Manages persistence and relationships of components using ComponentFactory for validation"""

    DEFAULT_UNIQUENESS_FIELDS = {
        ComponentTypes.MODEL: ['model_type', 'model'],
        ComponentTypes.TOOL: ['name'],
        ComponentTypes.AGENT: ['agent_type', 'name'],
        ComponentTypes.TEAM: ['team_type', 'name']
    }

    def __init__(self, db_manager: DatabaseManager, uniqueness_fields: Dict[ComponentTypes, List[str]] = None):
        self.db_manager = db_manager
        self.component_factory = ComponentFactory()
        self.uniqueness_fields = uniqueness_fields or self.DEFAULT_UNIQUENESS_FIELDS

    async def import_component(self, component_config: ComponentConfigInput, user_id: str, check_exists: bool = False) -> Response:
        """
        Import a component configuration, validate it, and store the resulting component.

        Args:
            component_config: Configuration for the component (file path, dict, or ComponentConfig)
            user_id: User ID to associate with imported component
            check_exists: Whether to check for existing components before storing (default: False)

        Returns:
            Response containing import results or error
        """
        try:
            # Get validated config as dict
            config = await self.component_factory.load(component_config, return_type='dict')

            # Get component type
            component_type = self._determine_component_type(config)
            if not component_type:
                raise ValueError("Unable to determine component type from config")

            # Check existence if requested
            if check_exists:
                existing = self._check_exists(component_type, config, user_id)
                if existing:
                    return Response(
                        message=self._format_exists_message(component_type, config),
                        status=True,
                        data={"id": existing.id}
                    )

            # Route to appropriate storage method
            if component_type == ComponentTypes.TEAM:
                return await self._store_team(config, user_id, check_exists)
            elif component_type == ComponentTypes.AGENT:
                return await self._store_agent(config, user_id, check_exists)
            elif component_type == ComponentTypes.MODEL:
                return await self._store_model(config, user_id)
            elif component_type == ComponentTypes.TOOL:
                return await self._store_tool(config, user_id)
            else:
                raise ValueError(f"Unsupported component type: {component_type}")

        except Exception as e:
            logger.error(f"Failed to import component: {str(e)}")
            return Response(message=str(e), status=False)

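    # Editorial sketch (not part of the original commit): persisting a spec and
    # deduplicating on re-import. The file name and user id are hypothetical.
    #
    #   config_manager = ConfigurationManager(db_manager)
    #   result = await config_manager.import_component(
    #       "team.yaml", user_id="user@example.com", check_exists=True)
    #   print(result.status, result.message, result.data)
    #   # e.g. True, "team with team_type='round_robin' and name='writing_team'
    #   # already exists", {"id": 1}
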
    async def import_directory(self, directory: Union[str, Path], user_id: str, check_exists: bool = False) -> Response:
        """
        Import all component configurations from a directory.

        Args:
            directory: Path to directory containing configuration files
            user_id: User ID to associate with imported components
            check_exists: Whether to check for existing components before storing (default: False)

        Returns:
            Response containing import results for all files
        """
        try:
            configs = await self.component_factory.load_directory(directory, return_type='dict')

            results = []
            for config in configs:
                result = await self.import_component(config, user_id, check_exists)
                results.append({
                    "component": self._get_component_type(config),
                    "status": result.status,
                    "message": result.message,
                    "id": result.data.get("id") if result.status else None
                })

            return Response(
                message="Directory import complete",
                status=True,
                data=results
            )

        except Exception as e:
            logger.error(f"Failed to import directory: {str(e)}")
            return Response(message=str(e), status=False)

    async def _store_team(self, config: dict, user_id: str, check_exists: bool = False) -> Response:
        """Store team component and manage its relationships with agents"""
        try:
            # Store the team
            team_db = Team(
                user_id=user_id,
                config=config
            )
            team_result = self.db_manager.upsert(team_db)
            if not team_result.status:
                return team_result

            team_id = team_result.data["id"]

            # Handle participants (agents)
            for participant in config.get("participants", []):
                if check_exists:
                    # Check for existing agent
                    agent_type = self._determine_component_type(participant)
                    existing_agent = self._check_exists(agent_type, participant, user_id)
                    if existing_agent:
                        # Link existing agent
                        self.db_manager.link(
                            LinkTypes.TEAM_AGENT,
                            team_id,
                            existing_agent.id
                        )
                        logger.info(f"Linked existing agent to team: {existing_agent}")
                        continue

                # Store and link new agent
                agent_result = await self._store_agent(participant, user_id, check_exists)
                if agent_result.status:
                    self.db_manager.link(
                        LinkTypes.TEAM_AGENT,
                        team_id,
                        agent_result.data["id"]
                    )

            return team_result

        except Exception as e:
            logger.error(f"Failed to store team: {str(e)}")
            return Response(message=str(e), status=False)

    async def _store_agent(self, config: dict, user_id: str, check_exists: bool = False) -> Response:
        """Store agent component and manage its relationships with tools and model"""
        try:
            # Store the agent
            agent_db = Agent(
                user_id=user_id,
                config=config
            )
            agent_result = self.db_manager.upsert(agent_db)
            if not agent_result.status:
                return agent_result

            agent_id = agent_result.data["id"]

            # Handle model client
            if "model_client" in config:
                if check_exists:
                    # Check for existing model
                    model_type = self._determine_component_type(config["model_client"])
                    existing_model = self._check_exists(model_type, config["model_client"], user_id)
                    if existing_model:
                        # Link existing model
                        self.db_manager.link(
                            LinkTypes.AGENT_MODEL,
                            agent_id,
                            existing_model.id
                        )
                        logger.info(f"Linked existing model to agent: {existing_model.config.model_type}")
                    else:
                        # Store and link new model
                        model_result = await self._store_model(config["model_client"], user_id)
                        if model_result.status:
                            self.db_manager.link(
                                LinkTypes.AGENT_MODEL,
                                agent_id,
                                model_result.data["id"]
                            )
                else:
                    # Store and link new model without checking
                    model_result = await self._store_model(config["model_client"], user_id)
                    if model_result.status:
                        self.db_manager.link(
                            LinkTypes.AGENT_MODEL,
                            agent_id,
                            model_result.data["id"]
                        )

            # Handle tools
            for tool_config in config.get("tools", []):
                if check_exists:
                    # Check for existing tool
                    tool_type = self._determine_component_type(tool_config)
                    existing_tool = self._check_exists(tool_type, tool_config, user_id)
                    if existing_tool:
                        # Link existing tool
                        self.db_manager.link(
                            LinkTypes.AGENT_TOOL,
                            agent_id,
                            existing_tool.id
                        )
                        logger.info(f"Linked existing tool to agent: {existing_tool.config.name}")
                        continue

                # Store and link new tool
                tool_result = await self._store_tool(tool_config, user_id)
                if tool_result.status:
                    self.db_manager.link(
                        LinkTypes.AGENT_TOOL,
                        agent_id,
                        tool_result.data["id"]
                    )

            return agent_result

        except Exception as e:
            logger.error(f"Failed to store agent: {str(e)}")
            return Response(message=str(e), status=False)

    async def _store_model(self, config: dict, user_id: str) -> Response:
        """Store model component (leaf node - no relationships)"""
        try:
            model_db = Model(
                user_id=user_id,
                config=config
            )
            return self.db_manager.upsert(model_db)

        except Exception as e:
            logger.error(f"Failed to store model: {str(e)}")
            return Response(message=str(e), status=False)

    async def _store_tool(self, config: dict, user_id: str) -> Response:
        """Store tool component (leaf node - no relationships)"""
        try:
            tool_db = Tool(
                user_id=user_id,
                config=config
            )
            return self.db_manager.upsert(tool_db)

        except Exception as e:
            logger.error(f"Failed to store tool: {str(e)}")
            return Response(message=str(e), status=False)

    def _check_exists(self, component_type: ComponentTypes, config: dict, user_id: str) -> Optional[Union[Model, Tool, Agent, Team]]:
        """Check if component exists based on configured uniqueness fields."""
        fields = self.uniqueness_fields.get(component_type, [])
        if not fields:
            return None

        component_class = {
            ComponentTypes.MODEL: Model,
            ComponentTypes.TOOL: Tool,
            ComponentTypes.AGENT: Agent,
            ComponentTypes.TEAM: Team
        }.get(component_type)

        components = self.db_manager.get(component_class, {"user_id": user_id}).data

        for component in components:
            matches = all(
                component.config.get(field) == config.get(field)
                for field in fields
            )
            if matches:
                return component

        return None

    def _format_exists_message(self, component_type: ComponentTypes, config: dict) -> str:
        """Format existence message with identifying fields."""
        fields = self.uniqueness_fields.get(component_type, [])
        field_values = [f"{field}='{config.get(field)}'" for field in fields]
        return f"{component_type.value} with {' and '.join(field_values)} already exists"

    def _determine_component_type(self, config: dict) -> Optional[ComponentTypes]:
        """Determine component type from configuration dictionary"""
        if "team_type" in config:
            return ComponentTypes.TEAM
        elif "agent_type" in config:
            return ComponentTypes.AGENT
        elif "model_type" in config:
            return ComponentTypes.MODEL
        elif "tool_type" in config:
            return ComponentTypes.TOOL
        return None

    def _get_component_type(self, config: dict) -> str:
        """Helper to get component type string from config"""
        component_type = self._determine_component_type(config)
        return component_type.value if component_type else "unknown"

    async def cleanup(self):
        """Cleanup resources"""
        await self.component_factory.cleanup()

--- /dev/null
+++ db_manager.py
@@ -0,0 +1,424 @@
import threading
from datetime import datetime
from typing import Optional

from loguru import logger
from sqlalchemy import exc, text, func
from sqlmodel import Session, SQLModel, and_, create_engine, select
from .schema_manager import SchemaManager

from ..datamodel import (
    Response,
    LinkTypes
)
# from .dbutils import init_db_samples


class DatabaseManager:
    """A class to manage database operations"""

    _init_lock = threading.Lock()

    def __init__(self, engine_uri: str, auto_upgrade: bool = True):
        connection_args = {"check_same_thread": True} if "sqlite" in engine_uri else {}
        self.engine = create_engine(engine_uri, connect_args=connection_args)
        self.schema_manager = SchemaManager(
            engine=self.engine,
            auto_upgrade=auto_upgrade,
        )

        # Check and upgrade on startup
        upgraded, status = self.schema_manager.check_and_upgrade()
        if upgraded:
            logger.info("Database schema was upgraded automatically")
        else:
            logger.info(f"Schema status: {status}")

    def reset_db(self, recreate_tables: bool = True):
        """
        Reset the database by dropping all tables and optionally recreating them.

        Args:
            recreate_tables (bool): If True, recreates the tables after dropping them.
                Set to False if you want to call create_db_and_tables() separately.
        """
        if not self._init_lock.acquire(blocking=False):
            logger.warning("Database reset already in progress")
            return Response(
                message="Database reset already in progress",
                status=False,
                data=None
            )

        try:
            # Dispose existing connections
            self.engine.dispose()
            with Session(self.engine) as session:
                try:
                    # Disable foreign key checks for SQLite
                    if 'sqlite' in str(self.engine.url):
                        session.exec(text('PRAGMA foreign_keys=OFF'))

                    # Drop all tables
                    SQLModel.metadata.drop_all(self.engine)
                    logger.info("All tables dropped successfully")

                    # Re-enable foreign key checks for SQLite
                    if 'sqlite' in str(self.engine.url):
                        session.exec(text('PRAGMA foreign_keys=ON'))

                    session.commit()

                except Exception as e:
                    session.rollback()
                    raise e
                finally:
                    session.close()
                    self._init_lock.release()

            if recreate_tables:
                logger.info("Recreating tables...")
                self.create_db_and_tables()

            return Response(
                message="Database reset successfully" if recreate_tables else "Database tables dropped successfully",
                status=True,
                data=None
            )

        except Exception as e:
            error_msg = f"Error while resetting database: {str(e)}"
            logger.error(error_msg)
            return Response(
                message=error_msg,
                status=False,
                data=None
            )
        finally:
            if self._init_lock.locked():
                self._init_lock.release()
                logger.info("Database reset lock released")

    def create_db_and_tables(self):
        """Create a new database and tables"""
        with self._init_lock:
            try:
                SQLModel.metadata.create_all(self.engine)
                logger.info("Database tables created successfully")
                try:
                    # init_db_samples(self)
                    pass
                except Exception as e:
                    logger.info("Error while initializing database samples: " + str(e))
            except Exception as e:
                logger.info("Error while creating database tables: " + str(e))

    def upsert(self, model: SQLModel, return_json: bool = True):
        """Create or update an entity

        Args:
            model (SQLModel): The model instance to create or update
            return_json (bool, optional): If True, returns the model as a dictionary.
                If False, returns the SQLModel instance. Defaults to True.

        Returns:
            Response: Contains status, message and data (either dict or SQLModel based on return_json)
        """
        status = True
        model_class = type(model)
        existing_model = None

        with Session(self.engine) as session:
            try:
                existing_model = session.exec(
                    select(model_class).where(model_class.id == model.id)).first()
                if existing_model:
                    model.updated_at = datetime.now()
                    for key, value in model.model_dump().items():
                        setattr(existing_model, key, value)
                    model = existing_model  # Use the updated existing model
                    session.add(model)
                else:
                    session.add(model)
                session.commit()
                session.refresh(model)
            except Exception as e:
                session.rollback()
                logger.error("Error while updating/creating " +
                             str(model_class.__name__) + ": " + str(e))
                status = False

        return Response(
            message=(
                f"{model_class.__name__} Updated Successfully"
                if existing_model
                else f"{model_class.__name__} Created Successfully"
            ),
            status=status,
            data=model.model_dump() if return_json else model,
        )

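    # Editorial sketch (not part of the original commit): upsert keys on the
    # model's primary key, so passing an existing id updates in place.
    # Values are hypothetical.
    #
    #   resp = db.upsert(Team(user_id="user@example.com", config=team_spec))
    #   team_id = resp.data["id"]   # dict payload since return_json=True
    #   db.upsert(Team(id=team_id, user_id="user@example.com", config=new_spec))  # update
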
    def _model_to_dict(self, model_obj):
        return {col.name: getattr(model_obj, col.name) for col in model_obj.__table__.columns}

    def get(
        self,
        model_class: SQLModel,
        filters: dict = None,
        return_json: bool = False,
        order: str = "desc",
    ):
        """List entities"""
        with Session(self.engine) as session:
            result = []
            status = True
            status_message = ""

            try:
                statement = select(model_class)
                if filters:
                    conditions = [getattr(model_class, col) == value
                                  for col, value in filters.items()]
                    statement = statement.where(and_(*conditions))

                if hasattr(model_class, "created_at") and order:
                    order_by_clause = getattr(model_class.created_at, order)()  # Dynamically apply asc/desc
                    statement = statement.order_by(order_by_clause)

                items = session.exec(statement).all()
                result = [self._model_to_dict(item) if return_json else item for item in items]
                status_message = f"{model_class.__name__} Retrieved Successfully"
            except Exception as e:
                session.rollback()
                status = False
                status_message = f"Error while fetching {model_class.__name__}"
                logger.error("Error while getting items: " +
                             str(model_class.__name__) + " " + str(e))

            return Response(message=status_message, status=status, data=result)

    def delete(self, model_class: SQLModel, filters: dict = None):
        """Delete an entity"""
        status_message = ""
        status = True

        with Session(self.engine) as session:
            try:
                statement = select(model_class)
                if filters:
                    conditions = [
                        getattr(model_class, col) == value for col, value in filters.items()]
                    statement = statement.where(and_(*conditions))

                rows = session.exec(statement).all()

                if rows:
                    for row in rows:
                        session.delete(row)
                    session.commit()
                    status_message = f"{model_class.__name__} Deleted Successfully"
                else:
                    status_message = "Row not found"
                    logger.info(f"Row with filters {filters} not found")

            except exc.IntegrityError as e:
                session.rollback()
                status = False
                status_message = f"Integrity error: The {model_class.__name__} is linked to another entity and cannot be deleted. {e}"
                # Log the specific integrity error
                logger.error(status_message)
            except Exception as e:
                session.rollback()
                status = False
                status_message = f"Error while deleting: {e}"
                logger.error(status_message)

        return Response(message=status_message, status=status, data=None)

    def link(
        self,
        link_type: LinkTypes,
        primary_id: int,
        secondary_id: int,
        sequence: Optional[int] = None,
    ):
        """Link two entities with automatic sequence handling."""
        with Session(self.engine) as session:
            try:
                # Get classes from LinkTypes
                primary_class = link_type.primary_class
                secondary_class = link_type.secondary_class
                link_table = link_type.link_table

                # Get entities
                primary_entity = session.get(primary_class, primary_id)
                secondary_entity = session.get(secondary_class, secondary_id)

                if not primary_entity or not secondary_entity:
                    return Response(message="One or both entities do not exist", status=False)

                # Get field names
                primary_id_field = f"{primary_class.__name__.lower()}_id"
                secondary_id_field = f"{secondary_class.__name__.lower()}_id"

                # Check for existing link
                existing_link = session.exec(
                    select(link_table).where(
                        and_(
                            getattr(link_table, primary_id_field) == primary_id,
                            getattr(link_table, secondary_id_field) == secondary_id
                        )
                    )
                ).first()

                if existing_link:
                    return Response(message="Link already exists", status=False)

                # Get the next sequence number if not provided
                if sequence is None:
                    max_seq_result = session.exec(
                        select(func.max(link_table.sequence)).where(
                            getattr(link_table, primary_id_field) == primary_id
                        )
                    ).first()
                    sequence = 0 if max_seq_result is None else max_seq_result + 1

                # Create new link
                new_link = link_table(**{
                    primary_id_field: primary_id,
                    secondary_id_field: secondary_id,
                    'sequence': sequence
                })
                session.add(new_link)
                session.commit()

                return Response(
                    message=f"Entities linked successfully with sequence {sequence}",
                    status=True
                )

            except Exception as e:
                session.rollback()
                return Response(message=f"Error linking entities: {str(e)}", status=False)

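    # Editorial sketch (not part of the original commit): sequence numbers are
    # assigned per primary entity, so repeated links enumerate participants in
    # insertion order. The ids are hypothetical.
    #
    #   db.link(LinkTypes.TEAM_AGENT, team_id, writer_id)      # sequence 0
    #   db.link(LinkTypes.TEAM_AGENT, team_id, critic_id)      # sequence 1
    #   db.get_linked_entities(LinkTypes.TEAM_AGENT, team_id)  # ordered by sequence
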
    def unlink(
        self,
        link_type: LinkTypes,
        primary_id: int,
        secondary_id: int,
        sequence: Optional[int] = None
    ):
        """Unlink two entities and reorder sequences if needed."""
        with Session(self.engine) as session:
            try:
                # Get classes from LinkTypes
                primary_class = link_type.primary_class
                secondary_class = link_type.secondary_class
                link_table = link_type.link_table

                # Get field names
                primary_id_field = f"{primary_class.__name__.lower()}_id"
                secondary_id_field = f"{secondary_class.__name__.lower()}_id"

                # Find existing link
                statement = select(link_table).where(
                    and_(
                        getattr(link_table, primary_id_field) == primary_id,
                        getattr(link_table, secondary_id_field) == secondary_id
                    )
                )

                if sequence is not None:
                    statement = statement.where(link_table.sequence == sequence)

                existing_link = session.exec(statement).first()

                if not existing_link:
                    return Response(message="Link does not exist", status=False)

                deleted_sequence = existing_link.sequence
                session.delete(existing_link)

                # Reorder sequences for remaining links
                remaining_links = session.exec(
                    select(link_table)
                    .where(getattr(link_table, primary_id_field) == primary_id)
                    .where(link_table.sequence > deleted_sequence)
                    .order_by(link_table.sequence)
                ).all()

                # Decrease sequence numbers to fill the gap
                for link in remaining_links:
                    link.sequence -= 1

                session.commit()

                return Response(
                    message="Entities unlinked successfully and sequences reordered",
                    status=True
                )

            except Exception as e:
                session.rollback()
                return Response(message=f"Error unlinking entities: {str(e)}", status=False)

    def get_linked_entities(
        self,
        link_type: LinkTypes,
        primary_id: int,
        return_json: bool = False,
    ):
        """Get linked entities based on link type and primary ID, ordered by sequence."""
        with Session(self.engine) as session:
            try:
                # Get classes from LinkTypes
                primary_class = link_type.primary_class
                secondary_class = link_type.secondary_class
                link_table = link_type.link_table

                # Get field names
                primary_id_field = f"{primary_class.__name__.lower()}_id"
                secondary_id_field = f"{secondary_class.__name__.lower()}_id"

                # Query both link and entity, ordered by sequence
                items = session.exec(
                    select(secondary_class)
                    .join(link_table, getattr(link_table, secondary_id_field) == secondary_class.id)
                    .where(getattr(link_table, primary_id_field) == primary_id)
                    .order_by(link_table.sequence)
                ).all()

                result = [item.model_dump() if return_json else item for item in items]

                return Response(
                    message="Linked entities retrieved successfully",
                    status=True,
                    data=result
                )

            except Exception as e:
                logger.error(f"Error getting linked entities: {str(e)}")
                return Response(
                    message=f"Error getting linked entities: {str(e)}",
                    status=False,
                    data=[]
                )

    # Add new close method
    async def close(self):
        """Close database connections and cleanup resources"""
        logger.info("Closing database connections...")
        try:
            # Dispose of the SQLAlchemy engine
            self.engine.dispose()
            logger.info("Database connections closed successfully")
        except Exception as e:
            logger.error(f"Error closing database connections: {str(e)}")
            raise

--- dbmanager.py
+++ /dev/null
@@ -1,491 +0,0 @@
import threading
from datetime import datetime
from typing import Optional

from loguru import logger
from sqlalchemy import exc
from sqlmodel import Session, SQLModel, and_, create_engine, select

from ..datamodel import (
    Agent,
    AgentLink,
    AgentModelLink,
    AgentSkillLink,
    Model,
    Response,
    Skill,
    Workflow,
    WorkflowAgentLink,
    WorkflowAgentType,
)
from .utils import init_db_samples

valid_link_types = ["agent_model", "agent_skill", "agent_agent", "workflow_agent"]


class WorkflowAgentMap(SQLModel):
    agent: Agent
    link: WorkflowAgentLink


class DBManager:
    """A class to manage database operations"""

    _init_lock = threading.Lock()  # Class-level lock

    def __init__(self, engine_uri: str):
        connection_args = {"check_same_thread": True} if "sqlite" in engine_uri else {}
        self.engine = create_engine(engine_uri, connect_args=connection_args)
        # run_migration(engine_uri=engine_uri)

    def create_db_and_tables(self):
        """Create a new database and tables"""
        with self._init_lock:  # Use the lock
            try:
                SQLModel.metadata.create_all(self.engine)
                try:
                    init_db_samples(self)
                except Exception as e:
                    logger.info("Error while initializing database samples: " + str(e))
            except Exception as e:
                logger.info("Error while creating database tables:" + str(e))

    def upsert(self, model: SQLModel):
        """Create a new entity"""
        # check if the model exists, update else add
        status = True
        model_class = type(model)
        existing_model = None

        with Session(self.engine) as session:
            try:
                existing_model = session.exec(select(model_class).where(model_class.id == model.id)).first()
                if existing_model:
                    model.updated_at = datetime.now()
                    for key, value in model.model_dump().items():
                        setattr(existing_model, key, value)
                    model = existing_model
                    session.add(model)
                else:
                    session.add(model)
                session.commit()
                session.refresh(model)
            except Exception as e:
                session.rollback()
                logger.error("Error while updating " + str(model_class.__name__) + ": " + str(e))
                status = False

        response = Response(
            message=(
                f"{model_class.__name__} Updated Successfully "
                if existing_model
                else f"{model_class.__name__} Created Successfully"
            ),
            status=status,
            data=model.model_dump(),
        )

        return response

    def _model_to_dict(self, model_obj):
        return {col.name: getattr(model_obj, col.name) for col in model_obj.__table__.columns}

    def get_items(
        self,
        model_class: SQLModel,
        session: Session,
        filters: dict = None,
        return_json: bool = False,
        order: str = "desc",
    ):
        """List all entities"""
        result = []
        status = True
        status_message = ""

        try:
            if filters:
                conditions = [getattr(model_class, col) == value for col, value in filters.items()]
                statement = select(model_class).where(and_(*conditions))

                if hasattr(model_class, "created_at") and order:
                    if order == "desc":
                        statement = statement.order_by(model_class.created_at.desc())
                    else:
                        statement = statement.order_by(model_class.created_at.asc())
            else:
                statement = select(model_class)

            if return_json:
                result = [self._model_to_dict(row) for row in session.exec(statement).all()]
            else:
                result = session.exec(statement).all()
            status_message = f"{model_class.__name__} Retrieved Successfully"
        except Exception as e:
            session.rollback()
            status = False
            status_message = f"Error while fetching {model_class.__name__}"
            logger.error("Error while getting items: " + str(model_class.__name__) + " " + str(e))

        response: Response = Response(
            message=status_message,
            status=status,
            data=result,
        )
        return response

    def get(
        self,
        model_class: SQLModel,
        filters: dict = None,
        return_json: bool = False,
        order: str = "desc",
    ):
        """List all entities"""

        with Session(self.engine) as session:
            response = self.get_items(model_class, session, filters, return_json, order)
        return response

    def delete(self, model_class: SQLModel, filters: dict = None):
        """Delete an entity"""
        row = None
        status_message = ""
        status = True

        with Session(self.engine) as session:
            try:
                if filters:
                    conditions = [getattr(model_class, col) == value for col, value in filters.items()]
                    row = session.exec(select(model_class).where(and_(*conditions))).all()
                else:
                    row = session.exec(select(model_class)).all()
                if row:
                    for row in row:
                        session.delete(row)
                    session.commit()
                    status_message = f"{model_class.__name__} Deleted Successfully"
                else:
                    print(f"Row with filters {filters} not found")
                    logger.info("Row with filters + filters + not found")
                    status_message = "Row not found"
            except exc.IntegrityError as e:
                session.rollback()
                logger.error("Integrity ... Error while deleting: " + str(e))
                status_message = f"The {model_class.__name__} is linked to another entity and cannot be deleted."
                status = False
            except Exception as e:
                session.rollback()
                logger.error("Error while deleting: " + str(e))
                status_message = f"Error while deleting: {e}"
                status = False
            response = Response(
                message=status_message,
                status=status,
                data=None,
            )
        return response

    def get_linked_entities(
        self,
        link_type: str,
        primary_id: int,
        return_json: bool = False,
        agent_type: Optional[str] = None,
        sequence_id: Optional[int] = None,
    ):
        """
        Get all entities linked to the primary entity.

        Args:
            link_type (str): The type of link to retrieve, e.g., "agent_model".
            primary_id (int): The identifier for the primary model.
            return_json (bool): Whether to return the result as a JSON object.

        Returns:
            List[SQLModel]: A list of linked entities.
        """

        linked_entities = []

        if link_type not in valid_link_types:
            return []

        status = True
        status_message = ""

        with Session(self.engine) as session:
            try:
                if link_type == "agent_model":
                    # get the agent
                    agent = self.get_items(Agent, filters={"id": primary_id}, session=session).data[0]
                    linked_entities = agent.models
                elif link_type == "agent_skill":
                    agent = self.get_items(Agent, filters={"id": primary_id}, session=session).data[0]
                    linked_entities = agent.skills
                elif link_type == "agent_agent":
                    agent = self.get_items(Agent, filters={"id": primary_id}, session=session).data[0]
                    linked_entities = agent.agents
                elif link_type == "workflow_agent":
                    linked_entities = session.exec(
                        select(WorkflowAgentLink, Agent)
                        .join(Agent, WorkflowAgentLink.agent_id == Agent.id)
                        .where(
                            WorkflowAgentLink.workflow_id == primary_id,
                        )
                    ).all()

                    linked_entities = [WorkflowAgentMap(agent=agent, link=link) for link, agent in linked_entities]
                    linked_entities = sorted(linked_entities, key=lambda x: x.link.sequence_id)  # type: ignore
            except Exception as e:
                logger.error("Error while getting linked entities: " + str(e))
                status_message = f"Error while getting linked entities: {e}"
                status = False
            if return_json:
                linked_entities = [row.model_dump() for row in linked_entities]

        response = Response(
            message=status_message,
            status=status,
            data=linked_entities,
        )

        return response

def link(
|
||||
self,
|
||||
link_type: str,
|
||||
primary_id: int,
|
||||
secondary_id: int,
|
||||
agent_type: Optional[str] = None,
|
||||
sequence_id: Optional[int] = None,
|
||||
) -> Response:
|
||||
"""
|
||||
Link two entities together.
|
||||
|
||||
Args:
|
||||
link_type (str): The type of link to create, e.g., "agent_model".
|
||||
primary_id (int): The identifier for the primary model.
|
||||
secondary_id (int): The identifier for the secondary model.
|
||||
            agent_type (Optional[str]): The type of agent, e.g., "sender" or "receiver".

        Returns:
            Response: The response of the linking operation, including success status and message.
        """

        # TBD: verify that the caller is the creator of the primary entity being linked
        status = True
        status_message = ""
        primary_model = None
        secondary_model = None

        if link_type not in valid_link_types:
            status = False
            status_message = f"Invalid link type: {link_type}. Valid link types are: {valid_link_types}"
        else:
            with Session(self.engine) as session:
                try:
                    if link_type == "agent_model":
                        primary_model = session.exec(select(Agent).where(Agent.id == primary_id)).first()
                        secondary_model = session.exec(select(Model).where(Model.id == secondary_id)).first()
                        if primary_model is None or secondary_model is None:
                            status = False
                            status_message = "One or both entity records do not exist."
                        else:
                            # check if the link already exists
                            existing_link = session.exec(
                                select(AgentModelLink).where(
                                    AgentModelLink.agent_id == primary_id,
                                    AgentModelLink.model_id == secondary_id,
                                )
                            ).first()
                            if existing_link:  # link already exists
                                return Response(
                                    message=(
                                        f"{secondary_model.__class__.__name__} already linked "
                                        f"to {primary_model.__class__.__name__}"
                                    ),
                                    status=False,
                                )
                            else:
                                primary_model.models.append(secondary_model)
                    elif link_type == "agent_agent":
                        primary_model = session.exec(select(Agent).where(Agent.id == primary_id)).first()
                        secondary_model = session.exec(select(Agent).where(Agent.id == secondary_id)).first()
                        if primary_model is None or secondary_model is None:
                            status = False
                            status_message = "One or both entity records do not exist."
                        else:
                            # check if the link already exists
                            existing_link = session.exec(
                                select(AgentLink).where(
                                    AgentLink.parent_id == primary_id,
                                    AgentLink.agent_id == secondary_id,
                                )
                            ).first()
                            if existing_link:
                                return Response(
                                    message=(
                                        f"{secondary_model.__class__.__name__} already linked "
                                        f"to {primary_model.__class__.__name__}"
                                    ),
                                    status=False,
                                )
                            else:
                                primary_model.agents.append(secondary_model)
                    elif link_type == "agent_skill":
                        primary_model = session.exec(select(Agent).where(Agent.id == primary_id)).first()
                        secondary_model = session.exec(select(Skill).where(Skill.id == secondary_id)).first()
                        if primary_model is None or secondary_model is None:
                            status = False
                            status_message = "One or both entity records do not exist."
                        else:
                            # check if the link already exists
                            existing_link = session.exec(
                                select(AgentSkillLink).where(
                                    AgentSkillLink.agent_id == primary_id,
                                    AgentSkillLink.skill_id == secondary_id,
                                )
                            ).first()
                            if existing_link:
                                return Response(
                                    message=(
                                        f"{secondary_model.__class__.__name__} already linked "
                                        f"to {primary_model.__class__.__name__}"
                                    ),
                                    status=False,
                                )
                            else:
                                primary_model.skills.append(secondary_model)
                    elif link_type == "workflow_agent":
                        primary_model = session.exec(select(Workflow).where(Workflow.id == primary_id)).first()
                        secondary_model = session.exec(select(Agent).where(Agent.id == secondary_id)).first()
                        if primary_model is None or secondary_model is None:
                            status = False
                            status_message = "One or both entity records do not exist."
                        else:
                            # check if the link already exists
                            existing_link = session.exec(
                                select(WorkflowAgentLink).where(
                                    WorkflowAgentLink.workflow_id == primary_id,
                                    WorkflowAgentLink.agent_id == secondary_id,
                                    WorkflowAgentLink.agent_type == agent_type,
                                    WorkflowAgentLink.sequence_id == sequence_id,
                                )
                            ).first()
                            if existing_link:
                                return Response(
                                    message=(
                                        f"{secondary_model.__class__.__name__} already linked "
                                        f"to {primary_model.__class__.__name__}"
                                    ),
                                    status=False,
                                )
                            else:
                                workflow_agent_link = WorkflowAgentLink(
                                    workflow_id=primary_id,
                                    agent_id=secondary_id,
                                    agent_type=agent_type,
                                    sequence_id=sequence_id,
                                )
                                session.add(workflow_agent_link)

                    # add and commit the link only when both records were found
                    if status and primary_model is not None and secondary_model is not None:
                        session.add(primary_model)
                        session.commit()
                        status_message = (
                            f"{secondary_model.__class__.__name__} successfully linked "
                            f"to {primary_model.__class__.__name__}"
                        )

                except Exception as e:
                    session.rollback()
                    logger.error(f"Error while linking: {e}")
                    status = False
                    status_message = f"Error while linking due to an exception: {e}"

        response = Response(
            message=status_message,
            status=status,
        )

        return response
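
    # Usage sketch (illustrative; `dbmanager` and the ids below are assumptions,
    # not part of this diff):
    #
    #   response = dbmanager.link("agent_model", primary_id=1, secondary_id=2)
    #   if not response.status:
    #       logger.warning(response.message)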

    def unlink(
        self,
        link_type: str,
        primary_id: int,
        secondary_id: int,
        agent_type: Optional[str] = None,
        sequence_id: Optional[int] = 0,
    ) -> Response:
        """
        Unlink two entities.

        Args:
            link_type (str): The type of link to remove, e.g., "agent_model".
            primary_id (int): The identifier for the primary model.
            secondary_id (int): The identifier for the secondary model.
            agent_type (Optional[str]): The type of agent, e.g., "sender" or "receiver".
            sequence_id (Optional[int]): The sequence position of a workflow agent link.

        Returns:
            Response: The response of the unlinking operation, including success status and message.
        """
        status = True
        status_message = ""
        logger.debug(
            f"Unlinking primary={primary_id}, secondary={secondary_id}, "
            f"sequence={sequence_id}, agent_type={agent_type}"
        )

        if link_type not in valid_link_types:
            status = False
            status_message = f"Invalid link type: {link_type}. Valid link types are: {valid_link_types}"
            return Response(message=status_message, status=status)

        with Session(self.engine) as session:
            try:
                existing_link = None
                if link_type == "agent_model":
                    existing_link = session.exec(
                        select(AgentModelLink).where(
                            AgentModelLink.agent_id == primary_id,
                            AgentModelLink.model_id == secondary_id,
                        )
                    ).first()
                elif link_type == "agent_skill":
                    existing_link = session.exec(
                        select(AgentSkillLink).where(
                            AgentSkillLink.agent_id == primary_id,
                            AgentSkillLink.skill_id == secondary_id,
                        )
                    ).first()
                elif link_type == "agent_agent":
                    existing_link = session.exec(
                        select(AgentLink).where(
                            AgentLink.parent_id == primary_id,
                            AgentLink.agent_id == secondary_id,
                        )
                    ).first()
                elif link_type == "workflow_agent":
                    existing_link = session.exec(
                        select(WorkflowAgentLink).where(
                            WorkflowAgentLink.workflow_id == primary_id,
                            WorkflowAgentLink.agent_id == secondary_id,
                            WorkflowAgentLink.agent_type == agent_type,
                            WorkflowAgentLink.sequence_id == sequence_id,
                        )
                    ).first()

                if existing_link:
                    session.delete(existing_link)
                    session.commit()
                    status_message = "Link removed successfully."
                else:
                    status = False
                    status_message = "Link does not exist."

            except Exception as e:
                session.rollback()
                logger.error(f"Error while unlinking: {e}")
                status = False
                status_message = f"Error while unlinking due to an exception: {e}"

        return Response(message=status_message, status=status)
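
    # Unlinking a workflow agent requires the same discriminators used to create
    # the link (a sketch; the ids are illustrative):
    #
    #   response = dbmanager.unlink(
    #       "workflow_agent", primary_id=1, secondary_id=2,
    #       agent_type="sender", sequence_id=0,
    #   )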
@@ -1 +0,0 @@
Generic single-database configuration.
@@ -1,80 +0,0 @@
import os
from logging.config import fileConfig

from alembic import context
from sqlalchemy import engine_from_config, pool
from sqlmodel import SQLModel

from autogenstudio.datamodel import *
from autogenstudio.utils import get_db_uri

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
config.set_main_option("sqlalchemy.url", get_db_uri())

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = SQLModel.metadata

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.
    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.
    """
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(connection=connection, target_metadata=target_metadata)

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
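
# Alembic imports this env.py whenever migrations run; a typical invocation
# (standard Alembic CLI, shown here only for context) is:
#
#   alembic -c alembic.ini upgrade head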
@@ -1,27 +0,0 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
import sqlmodel
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
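
# Alembic renders this Mako template into a new module under versions/ each time
# a revision is generated, e.g. (the message below is illustrative):
#
#   alembic revision --autogenerate -m "init autogenstudio tables"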
@@ -0,0 +1,505 @@
import os
from pathlib import Path
import shutil
from typing import Optional, Tuple, List
from loguru import logger
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from alembic.autogenerate import compare_metadata
from sqlalchemy import Engine
from sqlmodel import SQLModel


class SchemaManager:
    """
    Manages database schema validation and migrations using Alembic.
    Provides automatic schema validation, migrations, and safe upgrades.

    Args:
        engine: SQLAlchemy engine instance
        auto_upgrade: Whether to automatically upgrade the schema when differences are found
        init_mode: Controls initialization behavior:
            - "none": No automatic initialization (raises an error if not set up)
            - "auto": Initialize if not present (default)
            - "force": Always reinitialize, removing existing configuration
    """

    def __init__(
        self,
        engine: Engine,
        auto_upgrade: bool = True,
        init_mode: str = "auto",
    ):
        if init_mode not in ["none", "auto", "force"]:
            raise ValueError("init_mode must be one of: none, auto, force")

        self.engine = engine
        self.auto_upgrade = auto_upgrade

        # Set up paths relative to this file
        self.base_dir = Path(__file__).parent
        self.alembic_dir = self.base_dir / "alembic"
        self.alembic_ini_path = self.base_dir / "alembic.ini"

        # Handle initialization based on mode
        if init_mode == "none":
            self._validate_alembic_setup()
        else:
            self._ensure_alembic_setup(force=init_mode == "force")

    def _cleanup_existing_alembic(self) -> None:
        """
        Safely removes existing Alembic configuration while preserving the versions directory.
        """
        logger.info(
            "Cleaning up existing Alembic configuration while preserving versions...")

        # Note when a versions directory exists; it is preserved below
        if self.alembic_dir.exists() and (self.alembic_dir / "versions").exists():
            logger.info("Preserving existing versions directory")

        # Remove alembic directory contents EXCEPT versions
        if self.alembic_dir.exists():
            for item in self.alembic_dir.iterdir():
                if item.name != "versions":
                    try:
                        if item.is_dir():
                            shutil.rmtree(item)
                            logger.info(f"Removed directory: {item}")
                        else:
                            item.unlink()
                            logger.info(f"Removed file: {item}")
                    except Exception as e:
                        logger.error(f"Failed to remove {item}: {e}")

        # Remove alembic.ini if it exists
        if self.alembic_ini_path.exists():
            try:
                self.alembic_ini_path.unlink()
                logger.info(
                    f"Removed existing alembic.ini: {self.alembic_ini_path}")
            except Exception as e:
                logger.error(f"Failed to remove alembic.ini: {e}")

    def _ensure_alembic_setup(self, *, force: bool = False) -> None:
        """
        Ensures Alembic is properly set up, initializing if necessary.

        Args:
            force: If True, removes existing configuration and reinitializes
        """
        try:
            self._validate_alembic_setup()
            if force:
                logger.info(
                    "Force initialization requested. Cleaning up existing configuration...")
                self._cleanup_existing_alembic()
                self._initialize_alembic()
        except FileNotFoundError:
            logger.info("Alembic configuration not found. Initializing...")
            if self.alembic_dir.exists():
                logger.warning(
                    "Found existing alembic directory but missing configuration")
                self._cleanup_existing_alembic()
            self._initialize_alembic()
            logger.info("Alembic initialization complete")

    def _initialize_alembic(self) -> str:
        """Initializes Alembic configuration in the local directory."""
        logger.info("Initializing Alembic configuration...")

        # Check whether a versions directory already exists
        has_versions = (self.alembic_dir / "versions").exists()
        logger.info(f"Existing versions directory found: {has_versions}")

        # Create base directories
        self.alembic_dir.mkdir(exist_ok=True)
        if not has_versions:
            (self.alembic_dir / "versions").mkdir(exist_ok=True)

        # Write alembic.ini
        ini_content = self._generate_alembic_ini_content()
        with open(self.alembic_ini_path, "w") as f:
            f.write(ini_content)
        logger.info("Created alembic.ini")

        if not has_versions:
            # Only run init if there is no versions directory
            config = self.get_alembic_config()
            command.init(config, str(self.alembic_dir))
            logger.info("Initialized new Alembic directory structure")
        else:
            # Create a minimal env.py if it doesn't exist
            env_path = self.alembic_dir / "env.py"
            if not env_path.exists():
                self._create_minimal_env_py(env_path)
                logger.info("Created minimal env.py")
            else:
                # Update the existing env.py
                self._update_env_py(env_path)
                logger.info("Updated existing env.py")

        logger.info(f"Alembic setup completed at {self.base_dir}")
        return str(self.alembic_ini_path)

    def _create_minimal_env_py(self, env_path: Path) -> None:
        """Creates a minimal env.py file for Alembic."""
        content = '''
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from sqlmodel import SQLModel

config = context.config
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

target_metadata = SQLModel.metadata

def run_migrations_offline() -> None:
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
        compare_type=True
    )
    with context.begin_transaction():
        context.run_migrations()

def run_migrations_online() -> None:
    connectable = engine_from_config(
        config.get_section(config.config_ini_section),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )
    with connectable.connect() as connection:
        context.configure(
            connection=connection,
            target_metadata=target_metadata,
            compare_type=True
        )
        with context.begin_transaction():
            context.run_migrations()

if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()'''

        with open(env_path, "w") as f:
            f.write(content)

    def _generate_alembic_ini_content(self) -> str:
        """
        Generates the content for the alembic.ini file.
        """
        return f"""
[alembic]
script_location = {self.alembic_dir}
sqlalchemy.url = {self.engine.url}

[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
""".strip()

    def _update_env_py(self, env_path: Path) -> None:
        """
        Updates the env.py file to use SQLModel metadata.
        """
        try:
            with open(env_path, "r") as f:
                content = f.read()

            # Add the SQLModel import
            if "from sqlmodel import SQLModel" not in content:
                content = "from sqlmodel import SQLModel\n" + content

            # Replace target_metadata
            content = content.replace(
                "target_metadata = None",
                "target_metadata = SQLModel.metadata",
            )

            # Add compare_type=True to context.configure
            if "context.configure(" in content and "compare_type=True" not in content:
                content = content.replace(
                    "context.configure(",
                    "context.configure(compare_type=True,",
                )

            with open(env_path, "w") as f:
                f.write(content)

            logger.info("Updated env.py with SQLModel metadata")
        except Exception as e:
            logger.error(f"Failed to update env.py: {e}")
            raise

    def _validate_alembic_setup(self) -> None:
        """Validates that Alembic is properly configured."""
        if not self.alembic_ini_path.exists():
            raise FileNotFoundError("Alembic configuration not found")

    def get_alembic_config(self) -> Config:
        """
        Gets the Alembic configuration.

        Returns:
            Config: Alembic Config object

        Raises:
            FileNotFoundError: If alembic.ini cannot be found
        """
        if not self.alembic_ini_path.exists():
            raise FileNotFoundError("Could not find alembic.ini")

        return Config(str(self.alembic_ini_path))

    def get_current_revision(self) -> Optional[str]:
        """
        Gets the current database revision.

        Returns:
            str: Current revision string, or None if there is no revision
        """
        with self.engine.connect() as conn:
            context = MigrationContext.configure(conn)
            return context.get_current_revision()

    def get_head_revision(self) -> str:
        """
        Gets the latest available revision.

        Returns:
            str: Head revision string
        """
        config = self.get_alembic_config()
        script = ScriptDirectory.from_config(config)
        return script.get_current_head()

    def get_schema_differences(self) -> List[tuple]:
        """
        Detects differences between the current database and the models.

        Returns:
            List[tuple]: List of differences found
        """
        with self.engine.connect() as conn:
            context = MigrationContext.configure(conn)
            diff = compare_metadata(context, SQLModel.metadata)
            return list(diff)

    def check_schema_status(self) -> Tuple[bool, str]:
        """
        Checks whether the database schema matches the current models and migrations.

        Returns:
            Tuple[bool, str]: (needs_upgrade, status_message)
        """
        try:
            current_rev = self.get_current_revision()
            head_rev = self.get_head_revision()

            if current_rev != head_rev:
                return True, f"Database needs upgrade: {current_rev} -> {head_rev}"

            differences = self.get_schema_differences()
            if differences:
                changes_desc = "\n".join(str(diff) for diff in differences)
                return True, f"Unmigrated changes detected:\n{changes_desc}"

            return False, "Database schema is up to date"

        except Exception as e:
            logger.error(f"Error checking schema status: {e}")
            return True, f"Error checking schema: {e}"

    def upgrade_schema(self, revision: str = "head") -> bool:
        """
        Upgrades the database schema to the specified revision.

        Args:
            revision: Target revision (default: "head")

        Returns:
            bool: True if the upgrade succeeded
        """
        try:
            config = self.get_alembic_config()
            command.upgrade(config, revision)
            logger.info(f"Schema upgraded successfully to {revision}")
            return True

        except Exception as e:
            logger.error(f"Schema upgrade failed: {e}")
            return False

    def check_and_upgrade(self) -> Tuple[bool, str]:
        """
        Checks the schema status and upgrades if necessary (and auto_upgrade is True).

        Returns:
            Tuple[bool, str]: (action_taken, status_message)
        """
        needs_upgrade, status = self.check_schema_status()

        if needs_upgrade:
            if self.auto_upgrade:
                if self.upgrade_schema():
                    return True, "Schema was automatically upgraded"
                else:
                    return False, "Automatic schema upgrade failed"
            else:
                return False, f"Schema needs upgrade but auto_upgrade is disabled. Status: {status}"

        return False, status
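
    # Startup-flow sketch (illustrative; assumes a `manager` constructed as in
    # the class docstring example above):
    #
    #   action_taken, msg = manager.check_and_upgrade()
    #   if not action_taken:
    #       logger.info(msg)  # already up to date, upgrade disabled, or upgrade failed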

    def generate_revision(self, message: str = "auto") -> Optional[str]:
        """
        Generates a new migration revision for the current schema changes.

        Args:
            message: Revision message

        Returns:
            str: Revision ID if successful, None otherwise
        """
        try:
            config = self.get_alembic_config()
            command.revision(
                config,
                message=message,
                autogenerate=True,
            )
            return self.get_head_revision()

        except Exception as e:
            logger.error(f"Failed to generate revision: {e}")
            return None

    def get_pending_migrations(self) -> List[str]:
        """
        Gets the list of pending migrations that still need to be applied.

        Returns:
            List[str]: List of pending migration revision IDs
        """
        config = self.get_alembic_config()
        script = ScriptDirectory.from_config(config)

        current = self.get_current_revision()
        head = self.get_head_revision()

        if current == head:
            return []

        pending = []
        # iterate_revisions walks from the upper revision down to the lower one,
        # so head is passed first
        for rev in script.iterate_revisions(head, current):
            pending.append(rev.revision)

        return pending

    def print_status(self) -> None:
        """Logs the current migration status information."""
        current = self.get_current_revision()
        head = self.get_head_revision()
        differences = self.get_schema_differences()
        pending = self.get_pending_migrations()

        logger.info("=== Database Schema Status ===")
        logger.info(f"Current revision: {current}")
        logger.info(f"Head revision: {head}")
        logger.info(f"Pending migrations: {len(pending)}")
        for rev in pending:
            logger.info(f"  - {rev}")
        logger.info(f"Unmigrated changes: {len(differences)}")
        for diff in differences:
            logger.info(f"  - {diff}")

    def ensure_schema_up_to_date(self) -> bool:
        """
        Ensures the database schema is up to date, generating and applying migrations if needed.

        Returns:
            bool: True if the schema is up to date or was successfully updated
        """
        try:
            # Check for unmigrated changes
            differences = self.get_schema_differences()
            if differences:
                # Generate a new migration
                revision = self.generate_revision("auto-generated")
                if not revision:
                    return False
                logger.info(f"Generated new migration: {revision}")

            # Apply any pending migrations
            upgraded, status = self.check_and_upgrade()
            if not upgraded and "needs upgrade" in status.lower():
                return False

            return True

        except Exception as e:
            logger.error(f"Failed to ensure schema is up to date: {e}")
            return False
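

# A minimal end-to-end sketch (illustrative; the engine URL and variable names
# are assumptions, not part of this diff):
#
#   from sqlalchemy import create_engine
#
#   engine = create_engine("sqlite:///autogenstudio.db")
#   manager = SchemaManager(engine, auto_upgrade=True, init_mode="auto")
#   manager.print_status()                   # log current vs. head revision
#   ok = manager.ensure_schema_up_to_date()  # autogenerate and apply migrations if needed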
File diff suppressed because one or more lines are too long