v1 of AutoGen Studio on AgentChat (#4097)

* add skeleton workflow manager

* add test notebook

* update test nb

* add sample team spec

* refactor requirements to agentchat and ext

* add base provider to return agentchat agents from json spec

* initial api refactor, update dbmanager

* api refactor

* refactor tests

* ags api tutorial update

* ui refactor

* general refactor

* minor refactor updates

* backend api refactor

* ui refactor and update

* implement v1 for streaming connection with ui updates

* backend refactor

* ui refactor

* minor ui tweak

* minor refactor and tweaks

* general refactor

* update tests

* sync uv.lock with main

* uv lock update
Author: Victor Dibia
Date: 2024-11-09 14:32:24 -08:00
Committed by: GitHub
Parent: f40b0c2730
Commit: 0e985d4b40
117 changed files with 20736 additions and 13600 deletions

View File

@@ -1,3 +1,3 @@
-# from .dbmanager import *
-from .dbmanager import *
-from .utils import *
+from .db_manager import DatabaseManager
+from .component_factory import ComponentFactory
+from .config_manager import ConfigurationManager
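The re-exported classes above replace the old star imports. A minimal usage sketch (illustrative only; it assumes this file is autogenstudio/database/__init__.py and a local SQLite database):

from autogenstudio.database import ComponentFactory, ConfigurationManager, DatabaseManager

db = DatabaseManager(engine_uri="sqlite:///autogen.db")  # auto-upgrades the schema by default
factory = ComponentFactory()  # loads teams/agents/models/tools from JSON, YAML, or dicts
config_manager = ConfigurationManager(db_manager=db)  # validates specs and persists them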

View File

@@ -1,116 +0,0 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to migrations/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -0,0 +1,355 @@
import os
from pathlib import Path
from typing import List, Literal, Union, Optional, Dict, Any, Type
from datetime import datetime
import json
from autogen_agentchat.task import MaxMessageTermination, TextMentionTermination, StopMessageTermination
import yaml
import logging
from packaging import version
from ..datamodel import (
TeamConfig, AgentConfig, ModelConfig, ToolConfig,
TeamTypes, AgentTypes, ModelTypes, ToolTypes,
ComponentType, ComponentConfig, ComponentConfigInput, TerminationConfig, TerminationTypes, Response
)
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_core.components.tools import FunctionTool
logger = logging.getLogger(__name__)
# Type definitions for supported components
TeamComponent = Union[RoundRobinGroupChat, SelectorGroupChat]
AgentComponent = Union[AssistantAgent] # Will grow with more agent types
# Will grow with more model types
ModelComponent = Union[OpenAIChatCompletionClient]
ToolComponent = Union[FunctionTool] # Will grow with more tool types
TerminationComponent = Union[MaxMessageTermination,
StopMessageTermination, TextMentionTermination]
# Config type definitions
Component = Union[TeamComponent, AgentComponent, ModelComponent, ToolComponent]
ReturnType = Literal['object', 'dict', 'config']
class ComponentFactory:
"""Creates and manages agent components with versioned configuration loading"""
SUPPORTED_VERSIONS = {
ComponentType.TEAM: ["1.0.0"],
ComponentType.AGENT: ["1.0.0"],
ComponentType.MODEL: ["1.0.0"],
ComponentType.TOOL: ["1.0.0"],
ComponentType.TERMINATION: ["1.0.0"]
}
def __init__(self):
self._model_cache: Dict[str, OpenAIChatCompletionClient] = {}
self._tool_cache: Dict[str, FunctionTool] = {}
self._last_cache_clear = datetime.now()
async def load(self, component: ComponentConfigInput, return_type: ReturnType = 'object') -> Union[Component, dict, ComponentConfig]:
"""
Universal loader for any component type
Args:
component: Component configuration (file path, dict, or ComponentConfig)
return_type: Type of return value ('object', 'dict', or 'config')
Returns:
Component instance, config dict, or ComponentConfig based on return_type
Raises:
ValueError: If component type is unknown or version unsupported
"""
try:
# Load and validate config
if isinstance(component, (str, Path)):
component_dict = await self._load_from_file(component)
config = self._dict_to_config(component_dict)
elif isinstance(component, dict):
config = self._dict_to_config(component)
else:
config = component
# Validate version
if not self._is_version_supported(config.component_type, config.version):
raise ValueError(
f"Unsupported version {config.version} for "
f"component type {config.component_type}. "
f"Supported versions: {self.SUPPORTED_VERSIONS[config.component_type]}"
)
# Return early if dict or config requested
if return_type == 'dict':
return config.model_dump()
elif return_type == 'config':
return config
# Otherwise create and return component instance
handlers = {
ComponentType.TEAM: self.load_team,
ComponentType.AGENT: self.load_agent,
ComponentType.MODEL: self.load_model,
ComponentType.TOOL: self.load_tool,
ComponentType.TERMINATION: self.load_termination
}
handler = handlers.get(config.component_type)
if not handler:
raise ValueError(
f"Unknown component type: {config.component_type}")
return await handler(config)
except Exception as e:
logger.error(f"Failed to load component: {str(e)}")
raise
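# Illustrative usage (not part of this commit): load() accepts a file path,
# a dict, or a typed config, and can return the instantiated component, the
# raw dict, or the validated config object:
#   factory = ComponentFactory()
#   team = await factory.load("team.json")  # component instance
#   team_config = await factory.load("team.json", return_type='config')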
async def load_directory(self, directory: Union[str, Path], check_exists: bool = False, return_type: ReturnType = 'object') -> List[Union[Component, dict, ComponentConfig]]:
"""
Import all component configurations from a directory.
"""
components = []
try:
directory = Path(directory)
# Iterate over directory entries with Path.glob rather than os.listdir
for path in directory.glob("*"):
if path.suffix.lower() in ('.json', '.yaml', '.yml'):
try:
component = await self.load(path, return_type)
components.append(component)
except Exception as e:
logger.info(
f"Failed to load component: {str(e)}, {path}")
return components
except Exception as e:
logger.info(f"Failed to load directory: {str(e)}")
return components
def _dict_to_config(self, config_dict: dict) -> ComponentConfig:
"""Convert dictionary to appropriate config type based on component_type"""
if "component_type" not in config_dict:
raise ValueError("component_type is required in configuration")
config_types = {
ComponentType.TEAM: TeamConfig,
ComponentType.AGENT: AgentConfig,
ComponentType.MODEL: ModelConfig,
ComponentType.TOOL: ToolConfig,
ComponentType.TERMINATION: TerminationConfig # Add mapping for termination
}
component_type = ComponentType(config_dict["component_type"])
config_class = config_types.get(component_type)
if not config_class:
raise ValueError(f"Unknown component type: {component_type}")
return config_class(**config_dict)
async def load_termination(self, config: TerminationConfig) -> TerminationComponent:
"""Create termination condition instance from configuration."""
try:
if config.termination_type == TerminationTypes.MAX_MESSAGES:
return MaxMessageTermination(max_messages=config.max_messages)
elif config.termination_type == TerminationTypes.STOP_MESSAGE:
return StopMessageTermination()
elif config.termination_type == TerminationTypes.TEXT_MENTION:
if not config.text:
raise ValueError(
"text parameter required for TextMentionTermination")
return TextMentionTermination(text=config.text)
else:
raise ValueError(
f"Unsupported termination type: {config.termination_type}")
except Exception as e:
logger.error(f"Failed to create termination condition: {str(e)}")
raise ValueError(
f"Termination condition creation failed: {str(e)}")
async def load_team(self, config: TeamConfig) -> TeamComponent:
"""Create team instance from configuration."""
try:
# Load participants (agents)
participants = []
for participant in config.participants:
agent = await self.load(participant)
participants.append(agent)
# Load model client if specified
model_client = None
if config.model_client:
model_client = await self.load(config.model_client)
# Load termination condition if specified
termination = None
if config.termination_condition:
# Now we can use the universal load() method since termination is a proper component
termination = await self.load(config.termination_condition)
# Create team based on type
if config.team_type == TeamTypes.ROUND_ROBIN:
return RoundRobinGroupChat(
participants=participants,
termination_condition=termination
)
elif config.team_type == TeamTypes.SELECTOR:
if not model_client:
raise ValueError(
"SelectorGroupChat requires a model_client")
return SelectorGroupChat(
participants=participants,
model_client=model_client,
termination_condition=termination
)
else:
raise ValueError(f"Unsupported team type: {config.team_type}")
except Exception as e:
logger.error(f"Failed to create team {config.name}: {str(e)}")
raise ValueError(f"Team creation failed: {str(e)}")
async def load_agent(self, config: AgentConfig) -> AgentComponent:
"""Create agent instance from configuration."""
try:
# Load model client if specified
model_client = None
if config.model_client:
model_client = await self.load(config.model_client)
system_message = config.system_message if config.system_message else "You are a helpful assistant"
# Load tools if specified
tools = []
if config.tools:
for tool_config in config.tools:
tool = await self.load(tool_config)
tools.append(tool)
if config.agent_type == AgentTypes.ASSISTANT:
return AssistantAgent(
name=config.name,
model_client=model_client,
tools=tools,
system_message=system_message
)
else:
raise ValueError(
f"Unsupported agent type: {config.agent_type}")
except Exception as e:
logger.error(f"Failed to create agent {config.name}: {str(e)}")
raise ValueError(f"Agent creation failed: {str(e)}")
async def load_model(self, config: ModelConfig) -> ModelComponent:
"""Create model instance from configuration."""
try:
# Check cache first
cache_key = str(config.model_dump())
if cache_key in self._model_cache:
logger.debug(f"Using cached model for {config.model}")
return self._model_cache[cache_key]
if config.model_type == ModelTypes.OPENAI:
model = OpenAIChatCompletionClient(
model=config.model,
api_key=config.api_key,
base_url=config.base_url
)
self._model_cache[cache_key] = model
return model
else:
raise ValueError(
f"Unsupported model type: {config.model_type}")
except Exception as e:
logger.error(f"Failed to create model {config.model}: {str(e)}")
raise ValueError(f"Model creation failed: {str(e)}")
async def load_tool(self, config: ToolConfig) -> ToolComponent:
"""Create tool instance from configuration."""
try:
# Validate required fields
if not all([config.name, config.description, config.content, config.tool_type]):
raise ValueError("Tool configuration missing required fields")
# Check cache first
cache_key = str(config.model_dump())
if cache_key in self._tool_cache:
logger.debug(f"Using cached tool '{config.name}'")
return self._tool_cache[cache_key]
if config.tool_type == ToolTypes.PYTHON_FUNCTION:
tool = FunctionTool(
name=config.name,
description=config.description,
func=self._func_from_string(config.content)
)
self._tool_cache[cache_key] = tool
return tool
else:
raise ValueError(f"Unsupported tool type: {config.tool_type}")
except Exception as e:
logger.error(f"Failed to create tool '{config.name}': {str(e)}")
raise
# Helper methods remain largely the same
async def _load_from_file(self, path: Union[str, Path]) -> dict:
"""Load configuration from JSON or YAML file."""
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"Config file not found: {path}")
try:
with open(path) as f:
if path.suffix == '.json':
return json.load(f)
elif path.suffix in ('.yml', '.yaml'):
return yaml.safe_load(f)
else:
raise ValueError(f"Unsupported file format: {path.suffix}")
except Exception as e:
raise ValueError(f"Failed to load file {path}: {str(e)}")
def _func_from_string(self, content: str) -> callable:
"""Convert function string to callable."""
try:
namespace = {}
exec(content, namespace)
for item in namespace.values():
if callable(item) and not isinstance(item, type):
return item
raise ValueError("No function found in provided code")
except Exception as e:
raise ValueError(f"Failed to create function: {str(e)}")
def _is_version_supported(self, component_type: ComponentType, ver: str) -> bool:
"""Check if version is supported for component type."""
try:
# Parsing validates the version string format; support is a membership check
version.parse(ver)
return ver in self.SUPPORTED_VERSIONS[component_type]
except version.InvalidVersion:
return False
async def cleanup(self) -> None:
"""Cleanup resources and clear caches."""
for model in self._model_cache.values():
if hasattr(model, 'cleanup'):
await model.cleanup()
for tool in self._tool_cache.values():
if hasattr(tool, 'cleanup'):
await tool.cleanup()
self._model_cache.clear()
self._tool_cache.clear()
self._last_cache_clear = datetime.now()
logger.info("Cleared all component caches")

View File

@@ -0,0 +1,322 @@
import logging
from typing import Optional, Union, Dict, Any, List
from pathlib import Path
from loguru import logger
from ..datamodel import (
Model, Team, Agent, Tool,
Response, ComponentTypes, LinkTypes,
ComponentConfigInput
)
from .component_factory import ComponentFactory
from .db_manager import DatabaseManager
class ConfigurationManager:
"""Manages persistence and relationships of components using ComponentFactory for validation"""
DEFAULT_UNIQUENESS_FIELDS = {
ComponentTypes.MODEL: ['model_type', 'model'],
ComponentTypes.TOOL: ['name'],
ComponentTypes.AGENT: ['agent_type', 'name'],
ComponentTypes.TEAM: ['team_type', 'name']
}
def __init__(self, db_manager: DatabaseManager, uniqueness_fields: Optional[Dict[ComponentTypes, List[str]]] = None):
self.db_manager = db_manager
self.component_factory = ComponentFactory()
self.uniqueness_fields = uniqueness_fields or self.DEFAULT_UNIQUENESS_FIELDS
async def import_component(self, component_config: ComponentConfigInput, user_id: str, check_exists: bool = False) -> Response:
"""
Import a component configuration, validate it, and store the resulting component.
Args:
component_config: Configuration for the component (file path, dict, or ComponentConfig)
user_id: User ID to associate with imported component
check_exists: Whether to check for existing components before storing (default: False)
Returns:
Response containing import results or error
"""
try:
# Get validated config as dict
config = await self.component_factory.load(component_config, return_type='dict')
# Get component type
component_type = self._determine_component_type(config)
if not component_type:
raise ValueError(
f"Unable to determine component type from config")
# Check existence if requested
if check_exists:
existing = self._check_exists(component_type, config, user_id)
if existing:
return Response(
message=self._format_exists_message(
component_type, config),
status=True,
data={"id": existing.id}
)
# Route to appropriate storage method
if component_type == ComponentTypes.TEAM:
return await self._store_team(config, user_id, check_exists)
elif component_type == ComponentTypes.AGENT:
return await self._store_agent(config, user_id, check_exists)
elif component_type == ComponentTypes.MODEL:
return await self._store_model(config, user_id)
elif component_type == ComponentTypes.TOOL:
return await self._store_tool(config, user_id)
else:
raise ValueError(
f"Unsupported component type: {component_type}")
except Exception as e:
logger.error(f"Failed to import component: {str(e)}")
return Response(message=str(e), status=False)
async def import_directory(self, directory: Union[str, Path], user_id: str, check_exists: bool = False) -> Response:
"""
Import all component configurations from a directory.
Args:
directory: Path to directory containing configuration files
user_id: User ID to associate with imported components
check_exists: Whether to check for existing components before storing (default: False)
Returns:
Response containing import results for all files
"""
try:
configs = await self.component_factory.load_directory(directory, return_type='dict')
results = []
for config in configs:
result = await self.import_component(config, user_id, check_exists)
results.append({
"component": self._get_component_type(config),
"status": result.status,
"message": result.message,
"id": result.data.get("id") if result.status else None
})
return Response(
message="Directory import complete",
status=True,
data=results
)
except Exception as e:
logger.error(f"Failed to import directory: {str(e)}")
return Response(message=str(e), status=False)
async def _store_team(self, config: dict, user_id: str, check_exists: bool = False) -> Response:
"""Store team component and manage its relationships with agents"""
try:
# Store the team
team_db = Team(
user_id=user_id,
config=config
)
team_result = self.db_manager.upsert(team_db)
if not team_result.status:
return team_result
team_id = team_result.data["id"]
# Handle participants (agents)
for participant in config.get("participants", []):
if check_exists:
# Check for existing agent
agent_type = self._determine_component_type(participant)
existing_agent = self._check_exists(
agent_type, participant, user_id)
if existing_agent:
# Link existing agent
self.db_manager.link(
LinkTypes.TEAM_AGENT,
team_id,
existing_agent.id
)
logger.info(
f"Linked existing agent to team: {existing_agent}")
continue
# Store and link new agent
agent_result = await self._store_agent(participant, user_id, check_exists)
if agent_result.status:
self.db_manager.link(
LinkTypes.TEAM_AGENT,
team_id,
agent_result.data["id"]
)
return team_result
except Exception as e:
logger.error(f"Failed to store team: {str(e)}")
return Response(message=str(e), status=False)
async def _store_agent(self, config: dict, user_id: str, check_exists: bool = False) -> Response:
"""Store agent component and manage its relationships with tools and model"""
try:
# Store the agent
agent_db = Agent(
user_id=user_id,
config=config
)
agent_result = self.db_manager.upsert(agent_db)
if not agent_result.status:
return agent_result
agent_id = agent_result.data["id"]
# Handle model client
if "model_client" in config:
if check_exists:
# Check for existing model
model_type = self._determine_component_type(
config["model_client"])
existing_model = self._check_exists(
model_type, config["model_client"], user_id)
if existing_model:
# Link existing model
self.db_manager.link(
LinkTypes.AGENT_MODEL,
agent_id,
existing_model.id
)
logger.info(
f"Linked existing model to agent: {existing_model.config.model_type}")
else:
# Store and link new model
model_result = await self._store_model(config["model_client"], user_id)
if model_result.status:
self.db_manager.link(
LinkTypes.AGENT_MODEL,
agent_id,
model_result.data["id"]
)
else:
# Store and link new model without checking
model_result = await self._store_model(config["model_client"], user_id)
if model_result.status:
self.db_manager.link(
LinkTypes.AGENT_MODEL,
agent_id,
model_result.data["id"]
)
# Handle tools
for tool_config in config.get("tools", []):
if check_exists:
# Check for existing tool
tool_type = self._determine_component_type(tool_config)
existing_tool = self._check_exists(
tool_type, tool_config, user_id)
if existing_tool:
# Link existing tool
self.db_manager.link(
LinkTypes.AGENT_TOOL,
agent_id,
existing_tool.id
)
logger.info(
f"Linked existing tool to agent: {existing_tool.config.name}")
continue
# Store and link new tool
tool_result = await self._store_tool(tool_config, user_id)
if tool_result.status:
self.db_manager.link(
LinkTypes.AGENT_TOOL,
agent_id,
tool_result.data["id"]
)
return agent_result
except Exception as e:
logger.error(f"Failed to store agent: {str(e)}")
return Response(message=str(e), status=False)
async def _store_model(self, config: dict, user_id: str) -> Response:
"""Store model component (leaf node - no relationships)"""
try:
model_db = Model(
user_id=user_id,
config=config
)
return self.db_manager.upsert(model_db)
except Exception as e:
logger.error(f"Failed to store model: {str(e)}")
return Response(message=str(e), status=False)
async def _store_tool(self, config: dict, user_id: str) -> Response:
"""Store tool component (leaf node - no relationships)"""
try:
tool_db = Tool(
user_id=user_id,
config=config
)
return self.db_manager.upsert(tool_db)
except Exception as e:
logger.error(f"Failed to store tool: {str(e)}")
return Response(message=str(e), status=False)
def _check_exists(self, component_type: ComponentTypes, config: dict, user_id: str) -> Optional[Union[Model, Tool, Agent, Team]]:
"""Check if component exists based on configured uniqueness fields."""
fields = self.uniqueness_fields.get(component_type, [])
if not fields:
return None
component_class = {
ComponentTypes.MODEL: Model,
ComponentTypes.TOOL: Tool,
ComponentTypes.AGENT: Agent,
ComponentTypes.TEAM: Team
}.get(component_type)
components = self.db_manager.get(
component_class, {"user_id": user_id}).data
for component in components:
matches = all(
component.config.get(field) == config.get(field)
for field in fields
)
if matches:
return component
return None
def _format_exists_message(self, component_type: ComponentTypes, config: dict) -> str:
"""Format existence message with identifying fields."""
fields = self.uniqueness_fields.get(component_type, [])
field_values = [f"{field}='{config.get(field)}'" for field in fields]
return f"{component_type.value} with {' and '.join(field_values)} already exists"
def _determine_component_type(self, config: dict) -> Optional[ComponentTypes]:
"""Determine component type from configuration dictionary"""
if "team_type" in config:
return ComponentTypes.TEAM
elif "agent_type" in config:
return ComponentTypes.AGENT
elif "model_type" in config:
return ComponentTypes.MODEL
elif "tool_type" in config:
return ComponentTypes.TOOL
return None
def _get_component_type(self, config: dict) -> str:
"""Helper to get component type string from config"""
component_type = self._determine_component_type(config)
return component_type.value if component_type else "unknown"
async def cleanup(self):
"""Cleanup resources"""
await self.component_factory.cleanup()

View File

@@ -0,0 +1,424 @@
import threading
from datetime import datetime
from typing import Optional
from loguru import logger
from sqlalchemy import exc, text, func
from sqlmodel import Session, SQLModel, and_, create_engine, select
from .schema_manager import SchemaManager
from ..datamodel import (
Response,
LinkTypes
)
# from .dbutils import init_db_samples
class DatabaseManager:
"""A class to manage database operations"""
_init_lock = threading.Lock()
def __init__(self, engine_uri: str, auto_upgrade: bool = True):
connection_args = {
"check_same_thread": True} if "sqlite" in engine_uri else {}
self.engine = create_engine(engine_uri, connect_args=connection_args)
self.schema_manager = SchemaManager(
engine=self.engine,
auto_upgrade=auto_upgrade,
)
# Check and upgrade on startup
upgraded, status = self.schema_manager.check_and_upgrade()
if upgraded:
logger.info("Database schema was upgraded automatically")
else:
logger.info(f"Schema status: {status}")
def reset_db(self, recreate_tables: bool = True):
"""
Reset the database by dropping all tables and optionally recreating them.
Args:
recreate_tables (bool): If True, recreates the tables after dropping them.
Set to False if you want to call create_db_and_tables() separately.
"""
if not self._init_lock.acquire(blocking=False):
logger.warning("Database reset already in progress")
return Response(
message="Database reset already in progress",
status=False,
data=None
)
try:
# Dispose existing connections
self.engine.dispose()
with Session(self.engine) as session:
try:
# Disable foreign key checks for SQLite
if 'sqlite' in str(self.engine.url):
session.exec(text('PRAGMA foreign_keys=OFF'))
# Drop all tables
SQLModel.metadata.drop_all(self.engine)
logger.info("All tables dropped successfully")
# Re-enable foreign key checks for SQLite
if 'sqlite' in str(self.engine.url):
session.exec(text('PRAGMA foreign_keys=ON'))
session.commit()
except Exception as e:
session.rollback()
raise e
finally:
session.close()
self._init_lock.release()
if recreate_tables:
logger.info("Recreating tables...")
self.create_db_and_tables()
return Response(
message="Database reset successfully" if recreate_tables else "Database tables dropped successfully",
status=True,
data=None
)
except Exception as e:
error_msg = f"Error while resetting database: {str(e)}"
logger.error(error_msg)
return Response(
message=error_msg,
status=False,
data=None
)
finally:
if self._init_lock.locked():
self._init_lock.release()
logger.info("Database reset lock released")
def create_db_and_tables(self):
"""Create a new database and tables"""
with self._init_lock:
try:
SQLModel.metadata.create_all(self.engine)
logger.info("Database tables created successfully")
try:
# init_db_samples(self)
pass
except Exception as e:
logger.info(
"Error while initializing database samples: " + str(e))
except Exception as e:
logger.info("Error while creating database tables:" + str(e))
def upsert(self, model: SQLModel, return_json: bool = True):
"""Create or update an entity
Args:
model (SQLModel): The model instance to create or update
return_json (bool, optional): If True, returns the model as a dictionary.
If False, returns the SQLModel instance. Defaults to True.
Returns:
Response: Contains status, message and data (either dict or SQLModel based on return_json)
"""
status = True
model_class = type(model)
existing_model = None
with Session(self.engine) as session:
try:
existing_model = session.exec(
select(model_class).where(model_class.id == model.id)).first()
if existing_model:
model.updated_at = datetime.now()
for key, value in model.model_dump().items():
setattr(existing_model, key, value)
model = existing_model # Use the updated existing model
session.add(model)
else:
session.add(model)
session.commit()
session.refresh(model)
except Exception as e:
session.rollback()
logger.error("Error while updating/creating " +
str(model_class.__name__) + ": " + str(e))
status = False
return Response(
message=(
f"{model_class.__name__} Updated Successfully"
if existing_model
else f"{model_class.__name__} Created Successfully"
),
status=status,
data=model.model_dump() if return_json else model,
)
def _model_to_dict(self, model_obj):
return {col.name: getattr(model_obj, col.name) for col in model_obj.__table__.columns}
def get(
self,
model_class: SQLModel,
filters: dict = None,
return_json: bool = False,
order: str = "desc",
):
"""List entities"""
with Session(self.engine) as session:
result = []
status = True
status_message = ""
try:
statement = select(model_class)
if filters:
conditions = [getattr(model_class, col) ==
value for col, value in filters.items()]
statement = statement.where(and_(*conditions))
if hasattr(model_class, "created_at") and order:
order_by_clause = getattr(
model_class.created_at, order)() # Dynamically apply asc/desc
statement = statement.order_by(order_by_clause)
items = session.exec(statement).all()
result = [self._model_to_dict(
item) if return_json else item for item in items]
status_message = f"{model_class.__name__} Retrieved Successfully"
except Exception as e:
session.rollback()
status = False
status_message = f"Error while fetching {model_class.__name__}"
logger.error("Error while getting items: " +
str(model_class.__name__) + " " + str(e))
return Response(message=status_message, status=status, data=result)
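# Illustrative usage (not part of this commit), e.g. with the Team model:
#   db.get(Team, filters={"user_id": "user@example.com"}, return_json=True)
#   -> Response(status=True, data=[{...}, ...]) ordered by created_at desc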
def delete(self, model_class: SQLModel, filters: dict = None):
"""Delete an entity"""
status_message = ""
status = True
with Session(self.engine) as session:
try:
statement = select(model_class)
if filters:
conditions = [
getattr(model_class, col) == value for col, value in filters.items()]
statement = statement.where(and_(*conditions))
rows = session.exec(statement).all()
if rows:
for row in rows:
session.delete(row)
session.commit()
status_message = f"{model_class.__name__} Deleted Successfully"
else:
status_message = "Row not found"
logger.info(f"Row with filters {filters} not found")
except exc.IntegrityError as e:
session.rollback()
status = False
status_message = f"Integrity error: The {model_class.__name__} is linked to another entity and cannot be deleted. {e}"
# Log the specific integrity error
logger.error(status_message)
except Exception as e:
session.rollback()
status = False
status_message = f"Error while deleting: {e}"
logger.error(status_message)
return Response(message=status_message, status=status, data=None)
def link(
self,
link_type: LinkTypes,
primary_id: int,
secondary_id: int,
sequence: Optional[int] = None,
):
"""Link two entities with automatic sequence handling."""
with Session(self.engine) as session:
try:
# Get classes from LinkTypes
primary_class = link_type.primary_class
secondary_class = link_type.secondary_class
link_table = link_type.link_table
# Get entities
primary_entity = session.get(primary_class, primary_id)
secondary_entity = session.get(secondary_class, secondary_id)
if not primary_entity or not secondary_entity:
return Response(message="One or both entities do not exist", status=False)
# Get field names
primary_id_field = f"{primary_class.__name__.lower()}_id"
secondary_id_field = f"{secondary_class.__name__.lower()}_id"
# Check for existing link
existing_link = session.exec(
select(link_table).where(
and_(
getattr(link_table, primary_id_field) == primary_id,
getattr(
link_table, secondary_id_field) == secondary_id
)
)
).first()
if existing_link:
return Response(message="Link already exists", status=False)
# Get the next sequence number if not provided
if sequence is None:
max_seq_result = session.exec(
select(func.max(link_table.sequence)).where(
getattr(link_table, primary_id_field) == primary_id
)
).first()
sequence = 0 if max_seq_result is None else max_seq_result + 1
# Create new link
new_link = link_table(**{
primary_id_field: primary_id,
secondary_id_field: secondary_id,
'sequence': sequence
})
session.add(new_link)
session.commit()
return Response(
message=f"Entities linked successfully with sequence {sequence}",
status=True
)
except Exception as e:
session.rollback()
return Response(message=f"Error linking entities: {str(e)}", status=False)
def unlink(
self,
link_type: LinkTypes,
primary_id: int,
secondary_id: int,
sequence: Optional[int] = None
):
"""Unlink two entities and reorder sequences if needed."""
with Session(self.engine) as session:
try:
# Get classes from LinkTypes
primary_class = link_type.primary_class
secondary_class = link_type.secondary_class
link_table = link_type.link_table
# Get field names
primary_id_field = f"{primary_class.__name__.lower()}_id"
secondary_id_field = f"{secondary_class.__name__.lower()}_id"
# Find existing link
statement = select(link_table).where(
and_(
getattr(link_table, primary_id_field) == primary_id,
getattr(link_table, secondary_id_field) == secondary_id
)
)
if sequence is not None:
statement = statement.where(
link_table.sequence == sequence)
existing_link = session.exec(statement).first()
if not existing_link:
return Response(message="Link does not exist", status=False)
deleted_sequence = existing_link.sequence
session.delete(existing_link)
# Reorder sequences for remaining links
remaining_links = session.exec(
select(link_table)
.where(getattr(link_table, primary_id_field) == primary_id)
.where(link_table.sequence > deleted_sequence)
.order_by(link_table.sequence)
).all()
# Decrease sequence numbers to fill the gap
for link in remaining_links:
link.sequence -= 1
session.commit()
return Response(
message="Entities unlinked successfully and sequences reordered",
status=True
)
except Exception as e:
session.rollback()
return Response(message=f"Error unlinking entities: {str(e)}", status=False)
def get_linked_entities(
self,
link_type: LinkTypes,
primary_id: int,
return_json: bool = False,
):
"""Get linked entities based on link type and primary ID, ordered by sequence."""
with Session(self.engine) as session:
try:
# Get classes from LinkTypes
primary_class = link_type.primary_class
secondary_class = link_type.secondary_class
link_table = link_type.link_table
# Get field names
primary_id_field = f"{primary_class.__name__.lower()}_id"
secondary_id_field = f"{secondary_class.__name__.lower()}_id"
# Query both link and entity, ordered by sequence
items = session.exec(
select(secondary_class)
.join(link_table, getattr(link_table, secondary_id_field) == secondary_class.id)
.where(getattr(link_table, primary_id_field) == primary_id)
.order_by(link_table.sequence)
).all()
result = [
item.model_dump() if return_json else item for item in items]
return Response(
message="Linked entities retrieved successfully",
status=True,
data=result
)
except Exception as e:
logger.error(f"Error getting linked entities: {str(e)}")
return Response(
message=f"Error getting linked entities: {str(e)}",
status=False,
data=[]
)
# Add new close method
async def close(self):
"""Close database connections and cleanup resources"""
logger.info("Closing database connections...")
try:
# Dispose of the SQLAlchemy engine
self.engine.dispose()
logger.info("Database connections closed successfully")
except Exception as e:
logger.error(f"Error closing database connections: {str(e)}")
raise
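An end-to-end sketch of the new manager (illustrative; Team and LinkTypes come from ..datamodel, team_spec is the dict sketched earlier, and close() must be awaited from async code):

db = DatabaseManager(engine_uri="sqlite:///autogen.db", auto_upgrade=True)
db.create_db_and_tables()

team_id = db.upsert(Team(user_id="user@example.com", config=team_spec)).data["id"]
agents = db.get_linked_entities(LinkTypes.TEAM_AGENT, team_id).data  # ordered by sequence

await db.close()  # dispose engine connections on shutdown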

View File

@@ -1,491 +0,0 @@
import threading
from datetime import datetime
from typing import Optional
from loguru import logger
from sqlalchemy import exc
from sqlmodel import Session, SQLModel, and_, create_engine, select
from ..datamodel import (
Agent,
AgentLink,
AgentModelLink,
AgentSkillLink,
Model,
Response,
Skill,
Workflow,
WorkflowAgentLink,
WorkflowAgentType,
)
from .utils import init_db_samples
valid_link_types = ["agent_model", "agent_skill", "agent_agent", "workflow_agent"]
class WorkflowAgentMap(SQLModel):
agent: Agent
link: WorkflowAgentLink
class DBManager:
"""A class to manage database operations"""
_init_lock = threading.Lock() # Class-level lock
def __init__(self, engine_uri: str):
connection_args = {"check_same_thread": True} if "sqlite" in engine_uri else {}
self.engine = create_engine(engine_uri, connect_args=connection_args)
# run_migration(engine_uri=engine_uri)
def create_db_and_tables(self):
"""Create a new database and tables"""
with self._init_lock: # Use the lock
try:
SQLModel.metadata.create_all(self.engine)
try:
init_db_samples(self)
except Exception as e:
logger.info("Error while initializing database samples: " + str(e))
except Exception as e:
logger.info("Error while creating database tables:" + str(e))
def upsert(self, model: SQLModel):
"""Create a new entity"""
# check if the model exists, update else add
status = True
model_class = type(model)
existing_model = None
with Session(self.engine) as session:
try:
existing_model = session.exec(select(model_class).where(model_class.id == model.id)).first()
if existing_model:
model.updated_at = datetime.now()
for key, value in model.model_dump().items():
setattr(existing_model, key, value)
model = existing_model
session.add(model)
else:
session.add(model)
session.commit()
session.refresh(model)
except Exception as e:
session.rollback()
logger.error("Error while updating " + str(model_class.__name__) + ": " + str(e))
status = False
response = Response(
message=(
f"{model_class.__name__} Updated Successfully "
if existing_model
else f"{model_class.__name__} Created Successfully"
),
status=status,
data=model.model_dump(),
)
return response
def _model_to_dict(self, model_obj):
return {col.name: getattr(model_obj, col.name) for col in model_obj.__table__.columns}
def get_items(
self,
model_class: SQLModel,
session: Session,
filters: dict = None,
return_json: bool = False,
order: str = "desc",
):
"""List all entities"""
result = []
status = True
status_message = ""
try:
if filters:
conditions = [getattr(model_class, col) == value for col, value in filters.items()]
statement = select(model_class).where(and_(*conditions))
if hasattr(model_class, "created_at") and order:
if order == "desc":
statement = statement.order_by(model_class.created_at.desc())
else:
statement = statement.order_by(model_class.created_at.asc())
else:
statement = select(model_class)
if return_json:
result = [self._model_to_dict(row) for row in session.exec(statement).all()]
else:
result = session.exec(statement).all()
status_message = f"{model_class.__name__} Retrieved Successfully"
except Exception as e:
session.rollback()
status = False
status_message = f"Error while fetching {model_class.__name__}"
logger.error("Error while getting items: " + str(model_class.__name__) + " " + str(e))
response: Response = Response(
message=status_message,
status=status,
data=result,
)
return response
def get(
self,
model_class: SQLModel,
filters: dict = None,
return_json: bool = False,
order: str = "desc",
):
"""List all entities"""
with Session(self.engine) as session:
response = self.get_items(model_class, session, filters, return_json, order)
return response
def delete(self, model_class: SQLModel, filters: dict = None):
"""Delete an entity"""
row = None
status_message = ""
status = True
with Session(self.engine) as session:
try:
if filters:
conditions = [getattr(model_class, col) == value for col, value in filters.items()]
row = session.exec(select(model_class).where(and_(*conditions))).all()
else:
row = session.exec(select(model_class)).all()
if row:
for row in row:
session.delete(row)
session.commit()
status_message = f"{model_class.__name__} Deleted Successfully"
else:
print(f"Row with filters {filters} not found")
logger.info("Row with filters + filters + not found")
status_message = "Row not found"
except exc.IntegrityError as e:
session.rollback()
logger.error("Integrity ... Error while deleting: " + str(e))
status_message = f"The {model_class.__name__} is linked to another entity and cannot be deleted."
status = False
except Exception as e:
session.rollback()
logger.error("Error while deleting: " + str(e))
status_message = f"Error while deleting: {e}"
status = False
response = Response(
message=status_message,
status=status,
data=None,
)
return response
def get_linked_entities(
self,
link_type: str,
primary_id: int,
return_json: bool = False,
agent_type: Optional[str] = None,
sequence_id: Optional[int] = None,
):
"""
Get all entities linked to the primary entity.
Args:
link_type (str): The type of link to retrieve, e.g., "agent_model".
primary_id (int): The identifier for the primary model.
return_json (bool): Whether to return the result as a JSON object.
Returns:
List[SQLModel]: A list of linked entities.
"""
linked_entities = []
if link_type not in valid_link_types:
return []
status = True
status_message = ""
with Session(self.engine) as session:
try:
if link_type == "agent_model":
# get the agent
agent = self.get_items(Agent, filters={"id": primary_id}, session=session).data[0]
linked_entities = agent.models
elif link_type == "agent_skill":
agent = self.get_items(Agent, filters={"id": primary_id}, session=session).data[0]
linked_entities = agent.skills
elif link_type == "agent_agent":
agent = self.get_items(Agent, filters={"id": primary_id}, session=session).data[0]
linked_entities = agent.agents
elif link_type == "workflow_agent":
linked_entities = session.exec(
select(WorkflowAgentLink, Agent)
.join(Agent, WorkflowAgentLink.agent_id == Agent.id)
.where(
WorkflowAgentLink.workflow_id == primary_id,
)
).all()
linked_entities = [WorkflowAgentMap(agent=agent, link=link) for link, agent in linked_entities]
linked_entities = sorted(linked_entities, key=lambda x: x.link.sequence_id) # type: ignore
except Exception as e:
logger.error("Error while getting linked entities: " + str(e))
status_message = f"Error while getting linked entities: {e}"
status = False
if return_json:
linked_entities = [row.model_dump() for row in linked_entities]
response = Response(
message=status_message,
status=status,
data=linked_entities,
)
return response
def link(
self,
link_type: str,
primary_id: int,
secondary_id: int,
agent_type: Optional[str] = None,
sequence_id: Optional[int] = None,
) -> Response:
"""
Link two entities together.
Args:
link_type (str): The type of link to create, e.g., "agent_model".
primary_id (int): The identifier for the primary model.
secondary_id (int): The identifier for the secondary model.
agent_type (Optional[str]): The type of agent, e.g., "sender" or receiver.
Returns:
Response: The response of the linking operation, including success status and message.
"""
# TBD verify that is creator of the primary entity being linked
status = True
status_message = ""
primary_model = None
secondary_model = None
if link_type not in valid_link_types:
status = False
status_message = f"Invalid link type: {link_type}. Valid link types are: {valid_link_types}"
else:
with Session(self.engine) as session:
try:
if link_type == "agent_model":
primary_model = session.exec(select(Agent).where(Agent.id == primary_id)).first()
secondary_model = session.exec(select(Model).where(Model.id == secondary_id)).first()
if primary_model is None or secondary_model is None:
status = False
status_message = "One or both entity records do not exist."
else:
# check if the link already exists
existing_link = session.exec(
select(AgentModelLink).where(
AgentModelLink.agent_id == primary_id,
AgentModelLink.model_id == secondary_id,
)
).first()
if existing_link: # link already exists
return Response(
message=(
f"{secondary_model.__class__.__name__} already linked "
f"to {primary_model.__class__.__name__}"
),
status=False,
)
else:
primary_model.models.append(secondary_model)
elif link_type == "agent_agent":
primary_model = session.exec(select(Agent).where(Agent.id == primary_id)).first()
secondary_model = session.exec(select(Agent).where(Agent.id == secondary_id)).first()
if primary_model is None or secondary_model is None:
status = False
status_message = "One or both entity records do not exist."
else:
# check if the link already exists
existing_link = session.exec(
select(AgentLink).where(
AgentLink.parent_id == primary_id,
AgentLink.agent_id == secondary_id,
)
).first()
if existing_link:
return Response(
message=(
f"{secondary_model.__class__.__name__} already linked "
f"to {primary_model.__class__.__name__}"
),
status=False,
)
else:
primary_model.agents.append(secondary_model)
elif link_type == "agent_skill":
primary_model = session.exec(select(Agent).where(Agent.id == primary_id)).first()
secondary_model = session.exec(select(Skill).where(Skill.id == secondary_id)).first()
if primary_model is None or secondary_model is None:
status = False
status_message = "One or both entity records do not exist."
else:
# check if the link already exists
existing_link = session.exec(
select(AgentSkillLink).where(
AgentSkillLink.agent_id == primary_id,
AgentSkillLink.skill_id == secondary_id,
)
).first()
if existing_link:
return Response(
message=(
f"{secondary_model.__class__.__name__} already linked "
f"to {primary_model.__class__.__name__}"
),
status=False,
)
else:
primary_model.skills.append(secondary_model)
elif link_type == "workflow_agent":
primary_model = session.exec(select(Workflow).where(Workflow.id == primary_id)).first()
secondary_model = session.exec(select(Agent).where(Agent.id == secondary_id)).first()
if primary_model is None or secondary_model is None:
status = False
status_message = "One or both entity records do not exist."
else:
# check if the link already exists
existing_link = session.exec(
select(WorkflowAgentLink).where(
WorkflowAgentLink.workflow_id == primary_id,
WorkflowAgentLink.agent_id == secondary_id,
WorkflowAgentLink.agent_type == agent_type,
WorkflowAgentLink.sequence_id == sequence_id,
)
).first()
if existing_link:
return Response(
message=(
f"{secondary_model.__class__.__name__} already linked "
f"to {primary_model.__class__.__name__}"
),
status=False,
)
else:
# primary_model.agents.append(secondary_model)
workflow_agent_link = WorkflowAgentLink(
workflow_id=primary_id,
agent_id=secondary_id,
agent_type=agent_type,
sequence_id=sequence_id,
)
session.add(workflow_agent_link)
# add and commit the link
session.add(primary_model)
session.commit()
status_message = (
f"{secondary_model.__class__.__name__} successfully linked "
f"to {primary_model.__class__.__name__}"
)
except Exception as e:
session.rollback()
logger.error("Error while linking: " + str(e))
status = False
status_message = f"Error while linking due to an exception: {e}"
response = Response(
message=status_message,
status=status,
)
return response
def unlink(
self,
link_type: str,
primary_id: int,
secondary_id: int,
agent_type: Optional[str] = None,
sequence_id: Optional[int] = 0,
) -> Response:
"""
Unlink two entities.
Args:
link_type (str): The type of link to remove, e.g., "agent_model".
primary_id (int): The identifier for the primary model.
secondary_id (int): The identifier for the secondary model.
agent_type (Optional[str]): The type of agent, e.g., "sender" or receiver.
Returns:
Response: The response of the unlinking operation, including success status and message.
"""
status = True
status_message = ""
print("primary", primary_id, "secondary", secondary_id, "sequence", sequence_id, "agent_type", agent_type)
if link_type not in valid_link_types:
status = False
status_message = f"Invalid link type: {link_type}. Valid link types are: {valid_link_types}"
return Response(message=status_message, status=status)
with Session(self.engine) as session:
try:
if link_type == "agent_model":
existing_link = session.exec(
select(AgentModelLink).where(
AgentModelLink.agent_id == primary_id,
AgentModelLink.model_id == secondary_id,
)
).first()
elif link_type == "agent_skill":
existing_link = session.exec(
select(AgentSkillLink).where(
AgentSkillLink.agent_id == primary_id,
AgentSkillLink.skill_id == secondary_id,
)
).first()
elif link_type == "agent_agent":
existing_link = session.exec(
select(AgentLink).where(
AgentLink.parent_id == primary_id,
AgentLink.agent_id == secondary_id,
)
).first()
elif link_type == "workflow_agent":
existing_link = session.exec(
select(WorkflowAgentLink).where(
WorkflowAgentLink.workflow_id == primary_id,
WorkflowAgentLink.agent_id == secondary_id,
WorkflowAgentLink.agent_type == agent_type,
WorkflowAgentLink.sequence_id == sequence_id,
)
).first()
if existing_link:
session.delete(existing_link)
session.commit()
status_message = "Link removed successfully."
else:
status = False
status_message = "Link does not exist."
except Exception as e:
session.rollback()
logger.error("Error while unlinking: " + str(e))
status = False
status_message = f"Error while unlinking due to an exception: {e}"
return Response(message=status_message, status=status)

View File

@@ -1 +0,0 @@
Generic single-database configuration.

View File

@@ -1,80 +0,0 @@
import os
from logging.config import fileConfig
from alembic import context
from sqlalchemy import engine_from_config, pool
from sqlmodel import SQLModel
from autogenstudio.datamodel import *
from autogenstudio.utils import get_db_uri
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
config.set_main_option("sqlalchemy.url", get_db_uri())
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = SQLModel.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -1,27 +0,0 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,505 @@
import os
from pathlib import Path
import shutil
from typing import Optional, Tuple, List
from loguru import logger
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from alembic.autogenerate import compare_metadata
from sqlalchemy import Engine
from sqlmodel import SQLModel
class SchemaManager:
"""
Manages database schema validation and migrations using Alembic.
Provides automatic schema validation, migrations, and safe upgrades.
Args:
engine: SQLAlchemy engine instance
auto_upgrade: Whether to automatically upgrade schema when differences found
init_mode: Controls initialization behavior:
- "none": No automatic initialization (raises error if not set up)
- "auto": Initialize if not present (default)
- "force": Always reinitialize, removing existing configuration
"""
def __init__(
self,
engine: Engine,
auto_upgrade: bool = True,
init_mode: str = "auto"
):
if init_mode not in ["none", "auto", "force"]:
raise ValueError("init_mode must be one of: none, auto, force")
self.engine = engine
self.auto_upgrade = auto_upgrade
# Set up paths relative to this file
self.base_dir = Path(__file__).parent
self.alembic_dir = self.base_dir / 'alembic'
self.alembic_ini_path = self.base_dir / 'alembic.ini'
# Handle initialization based on mode
if init_mode == "none":
self._validate_alembic_setup()
else:
self._ensure_alembic_setup(force=init_mode == "force")
def _cleanup_existing_alembic(self) -> None:
"""
Safely removes existing Alembic configuration while preserving versions directory.
"""
logger.info(
"Cleaning up existing Alembic configuration while preserving versions...")
# Create a backup of versions directory if it exists
if self.alembic_dir.exists() and (self.alembic_dir / 'versions').exists():
logger.info("Preserving existing versions directory")
# Remove alembic directory contents EXCEPT versions
if self.alembic_dir.exists():
for item in self.alembic_dir.iterdir():
if item.name != 'versions':
try:
if item.is_dir():
shutil.rmtree(item)
logger.info(f"Removed directory: {item}")
else:
item.unlink()
logger.info(f"Removed file: {item}")
except Exception as e:
logger.error(f"Failed to remove {item}: {e}")
# Remove alembic.ini if it exists
if self.alembic_ini_path.exists():
try:
self.alembic_ini_path.unlink()
logger.info(
f"Removed existing alembic.ini: {self.alembic_ini_path}")
except Exception as e:
logger.error(f"Failed to remove alembic.ini: {e}")
def _ensure_alembic_setup(self, *, force: bool = False) -> None:
"""
Ensures Alembic is properly set up, initializing if necessary.
Args:
force: If True, removes existing configuration and reinitializes
"""
try:
self._validate_alembic_setup()
if force:
logger.info(
"Force initialization requested. Cleaning up existing configuration...")
self._cleanup_existing_alembic()
self._initialize_alembic()
except FileNotFoundError:
logger.info("Alembic configuration not found. Initializing...")
if self.alembic_dir.exists():
logger.warning(
"Found existing alembic directory but missing configuration")
self._cleanup_existing_alembic()
self._initialize_alembic()
logger.info("Alembic initialization complete")
def _initialize_alembic(self) -> str:
"""Initializes Alembic configuration in the local directory."""
logger.info("Initializing Alembic configuration...")
# Check if versions exists
has_versions = (self.alembic_dir / 'versions').exists()
logger.info(f"Existing versions directory found: {has_versions}")
# Create base directories
self.alembic_dir.mkdir(exist_ok=True)
if not has_versions:
(self.alembic_dir / 'versions').mkdir(exist_ok=True)
# Write alembic.ini
ini_content = self._generate_alembic_ini_content()
with open(self.alembic_ini_path, 'w') as f:
f.write(ini_content)
logger.info("Created alembic.ini")
if not has_versions:
# Only run init if no versions directory
config = self.get_alembic_config()
command.init(config, str(self.alembic_dir))
logger.info("Initialized new Alembic directory structure")
else:
# Create minimal env.py if it doesn't exist
env_path = self.alembic_dir / 'env.py'
if not env_path.exists():
self._create_minimal_env_py(env_path)
logger.info("Created minimal env.py")
else:
# Update existing env.py
self._update_env_py(env_path)
logger.info("Updated existing env.py")
logger.info(f"Alembic setup completed at {self.base_dir}")
return str(self.alembic_ini_path)
def _create_minimal_env_py(self, env_path: Path) -> None:
"""Creates a minimal env.py file for Alembic."""
content = '''
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from sqlmodel import SQLModel
config = context.config
if config.config_file_name is not None:
fileConfig(config.config_file_name)
target_metadata = SQLModel.metadata
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode, emitting SQL without a live connection."""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=True
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode using a connection from the engine."""
connectable = engine_from_config(
config.get_section(config.config_ini_section),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()'''
with open(env_path, 'w') as f:
f.write(content)
def _generate_alembic_ini_content(self) -> str:
"""
Generates content for the alembic.ini file.
"""
# str(engine.url) masks any password as "***" in SQLAlchemy 1.4+, so
# render the URL explicitly to keep migrations working against
# password-protected databases
db_url = self.engine.url.render_as_string(hide_password=False)
return f"""
[alembic]
script_location = {self.alembic_dir}
sqlalchemy.url = {db_url}
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
""".strip()
def _update_env_py(self, env_path: Path) -> None:
"""
Updates the env.py file to use SQLModel metadata.
"""
try:
with open(env_path, 'r') as f:
content = f.read()
# Add SQLModel import
if "from sqlmodel import SQLModel" not in content:
content = "from sqlmodel import SQLModel\n" + content
# Replace target_metadata
content = content.replace(
"target_metadata = None",
"target_metadata = SQLModel.metadata"
)
# Prepend compare_type=True to every context.configure(...) call
# (a naive string replace, so both the offline and online blocks get it)
if "context.configure(" in content and "compare_type=True" not in content:
content = content.replace(
"context.configure(",
"context.configure(compare_type=True,"
)
with open(env_path, 'w') as f:
f.write(content)
logger.info("Updated env.py with SQLModel metadata")
except Exception as e:
logger.error(f"Failed to update env.py: {e}")
raise
def _validate_alembic_setup(self) -> None:
"""Validates that Alembic is properly configured."""
if not self.alembic_ini_path.exists():
raise FileNotFoundError("Alembic configuration not found")
def get_alembic_config(self) -> Config:
"""
Gets Alembic configuration.
Returns:
Config: Alembic Config object
Raises:
FileNotFoundError: If alembic.ini cannot be found
"""
if not self.alembic_ini_path.exists():
raise FileNotFoundError("Could not find alembic.ini")
return Config(str(self.alembic_ini_path))
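# Usage sketch with a hypothetical `manager` instance (this is what
# upgrade_schema() and generate_revision() below do internally):
#
#   config = manager.get_alembic_config()
#   command.upgrade(config, "head")   # same effect as manager.upgrade_schema()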
def get_current_revision(self) -> Optional[str]:
"""
Gets the current database revision.
Returns:
str: Current revision string or None if no revision
"""
with self.engine.connect() as conn:
context = MigrationContext.configure(conn)
return context.get_current_revision()
def get_head_revision(self) -> Optional[str]:
"""
Gets the latest available revision.
Returns:
Optional[str]: Head revision string, or None if no revisions exist
"""
config = self.get_alembic_config()
script = ScriptDirectory.from_config(config)
return script.get_current_head()
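# A minimal comparison sketch using the two getters above (revision ids are
# illustrative):
#
#   current = manager.get_current_revision()   # e.g. "ae1027a6acf" or None
#   head = manager.get_head_revision()         # e.g. "bf2138b7bd0"
#   out_of_date = current != head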
def get_schema_differences(self) -> List[tuple]:
"""
Detects differences between current database and models.
Returns:
List[tuple]: List of differences found
"""
with self.engine.connect() as conn:
context = MigrationContext.configure(conn)
diff = compare_metadata(context, SQLModel.metadata)
return list(diff)
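# Each entry is an Alembic autogenerate diff tuple, e.g. (illustrative):
#
#   ('add_column', None, 'team', Column('description', String(), ...))
#
# so callers can inspect or log pending model changes before migrating.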
def check_schema_status(self) -> Tuple[bool, str]:
"""
Checks if database schema matches current models and migrations.
Returns:
Tuple[bool, str]: (needs_upgrade, status_message)
"""
try:
current_rev = self.get_current_revision()
head_rev = self.get_head_revision()
if current_rev != head_rev:
return True, f"Database needs upgrade: {current_rev} -> {head_rev}"
differences = self.get_schema_differences()
if differences:
changes_desc = "\n".join(str(diff) for diff in differences)
return True, f"Unmigrated changes detected:\n{changes_desc}"
return False, "Database schema is up to date"
except Exception as e:
logger.error(f"Error checking schema status: {str(e)}")
return True, f"Error checking schema: {str(e)}"
def upgrade_schema(self, revision: str = "head") -> bool:
"""
Upgrades database schema to specified revision.
Args:
revision: Target revision (default: "head")
Returns:
bool: True if upgrade successful
"""
try:
config = self.get_alembic_config()
command.upgrade(config, revision)
logger.info(f"Schema upgraded successfully to {revision}")
return True
except Exception as e:
logger.error(f"Schema upgrade failed: {str(e)}")
return False
def check_and_upgrade(self) -> Tuple[bool, str]:
"""
Checks schema status and upgrades if necessary (and auto_upgrade is True).
Returns:
Tuple[bool, str]: (action_taken, status_message)
"""
needs_upgrade, status = self.check_schema_status()
if needs_upgrade:
if self.auto_upgrade:
if self.upgrade_schema():
return True, "Schema was automatically upgraded"
else:
return False, "Automatic schema upgrade failed"
else:
return False, f"Schema needs upgrade but auto_upgrade is disabled. Status: {status}"
return False, status
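# Typical startup wiring (a sketch; assumes auto_upgrade was set on the
# instance at construction time):
#
#   acted, msg = manager.check_and_upgrade()
#   logger.info(msg)   # e.g. "Schema was automatically upgraded"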
def generate_revision(self, message: str = "auto") -> Optional[str]:
"""
Generates new migration revision for current schema changes.
Args:
message: Revision message
Returns:
str: Revision ID if successful, None otherwise
"""
try:
config = self.get_alembic_config()
command.revision(
config,
message=message,
autogenerate=True
)
return self.get_head_revision()
except Exception as e:
logger.error(f"Failed to generate revision: {str(e)}")
return None
def get_pending_migrations(self) -> List[str]:
"""
Gets list of pending migrations that need to be applied.
Returns:
List[str]: List of pending migration revision IDs
"""
config = self.get_alembic_config()
script = ScriptDirectory.from_config(config)
current = self.get_current_revision()
head = self.get_head_revision()
if current == head:
return []
pending = []
# ScriptDirectory.iterate_revisions() walks from the upper revision down
# to the lower one, so the head revision must come first
for rev in script.iterate_revisions(head, current):
pending.append(rev.revision)
return pending
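# Example return value when two migrations are unapplied (ids illustrative):
#
#   manager.get_pending_migrations()   # -> ["c3249c8ce1d", "bf2138b7bd0"]
#
# An empty list means the database is already at head.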
def print_status(self) -> None:
"""Prints current migration status information to logger."""
current = self.get_current_revision()
head = self.get_head_revision()
differences = self.get_schema_differences()
pending = self.get_pending_migrations()
logger.info("=== Database Schema Status ===")
logger.info(f"Current revision: {current}")
logger.info(f"Head revision: {head}")
logger.info(f"Pending migrations: {len(pending)}")
for rev in pending:
logger.info(f" - {rev}")
logger.info(f"Unmigrated changes: {len(differences)}")
for diff in differences:
logger.info(f" - {diff}")
def ensure_schema_up_to_date(self) -> bool:
"""
Ensures the database schema is up to date, generating and applying migrations if needed.
Returns:
bool: True if schema is up to date or was successfully updated
"""
try:
# Check for unmigrated changes
differences = self.get_schema_differences()
if differences:
# Generate new migration
revision = self.generate_revision("auto-generated")
if not revision:
return False
logger.info(f"Generated new migration: {revision}")
# Apply any pending migrations. check_and_upgrade() returns
# action_taken=False both when nothing needed doing and when
# auto_upgrade is disabled; only the latter still reports
# "needs upgrade" in its status message.
upgraded, status = self.check_and_upgrade()
if not upgraded and "needs upgrade" in status.lower():
return False
return True
except Exception as e:
logger.error(f"Failed to ensure schema is up to date: {e}")
return False
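# End-to-end sketch for first-run setup (hypothetical `manager`; mirrors how
# the methods above compose):
#
#   if not manager.ensure_schema_up_to_date():
#       manager.print_status()   # log current/head revisions and diffs
#       raise RuntimeError("Database schema could not be brought up to date")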
