Files
gpt-pilot/core/cli/helpers.py
Senko Rasic 5b474ccc1f merge gpt-pilot 0.2 codebase
This is a complete rewrite of the GPT Pilot core, from the ground
up, making the agentic architecture front and center, and also
fixing some long-standing problems with the database architecture
that weren't feasible to solve without breaking compatibility.

As the database structure and config file syntax have changed,
we have automatic imports for projects and current configs;
see the README.md file for details.

This also relicenses the project to FSL-1.1-MIT license.
2024-05-22 21:42:25 +02:00

320 lines
11 KiB
Python

import json
import os
import os.path
import sys
from argparse import ArgumentParser, ArgumentTypeError, Namespace
from typing import Optional
from urllib.parse import urlparse
from uuid import UUID
from core.config import Config, LLMProvider, LocalIPCConfig, ProviderConfig, UIAdapter, get_config, loader
from core.config.env_importer import import_from_dotenv
from core.config.version import get_version
from core.db.session import SessionManager
from core.db.setup import run_migrations
from core.log import setup
from core.state.state_manager import StateManager
from core.ui.base import UIBase
from core.ui.console import PlainConsoleUI
from core.ui.ipc_client import IPCClientUI
def parse_llm_endpoint(value: str) -> Optional[tuple[LLMProvider, str]]:
    """
    Parse --llm-endpoint command-line option.

    Option syntax is: --llm-endpoint <provider>:<url>

    Only the first colon separates the provider from the URL, so the URL
    itself may contain further colons (scheme separator, port).

    :param value: Argument value.
    :return: Tuple with LLM provider and URL, or None if the option value is empty.
    :raises ArgumentTypeError: If the value has no colon, the provider is not
        a known `LLMProvider`, or the URL is not an http(s) URL with a host.
    """
    if not value:
        return None

    provider_name, sep, raw_url = value.partition(":")
    if not sep:
        raise ArgumentTypeError("Invalid LLM endpoint format; expected 'provider:url'")

    try:
        provider = LLMProvider(provider_name)
    except ValueError as err:
        raise ArgumentTypeError(f"Unsupported LLM provider: {err}") from err

    try:
        url = urlparse(raw_url)
    except ValueError as err:
        # urlparse rejects e.g. an unbalanced IPv6 bracket; report it as a
        # proper argparse error instead of a generic "invalid value".
        raise ArgumentTypeError(f"Invalid LLM endpoint URL: {raw_url}") from err

    # Require a host as well as the scheme: "http:" or "https://" alone would
    # otherwise be accepted and only fail on the first request.
    if url.scheme not in ("http", "https") or not url.netloc:
        raise ArgumentTypeError(f"Invalid LLM endpoint URL: {raw_url}")

    return provider, url.geturl()
def parse_llm_key(value: str) -> Optional[tuple[LLMProvider, str]]:
    """
    Parse --llm-key command-line option.

    Option syntax is: --llm-key <provider>:<key>

    Only the first colon separates the provider from the key; everything
    after it is returned verbatim as the key.

    :param value: Argument value.
    :return: Tuple with LLM provider and key, or None if the option value is empty.
    :raises ArgumentTypeError: If the value has no colon or the provider is
        not a known `LLMProvider`.
    """
    if not value:
        return None

    provider_name, sep, key = value.partition(":")
    if not sep:
        # This parser handles keys, not endpoints; the message used to say "endpoint".
        raise ArgumentTypeError("Invalid LLM key format; expected 'provider:key'")

    try:
        provider = LLMProvider(provider_name)
    except ValueError as err:
        raise ArgumentTypeError(f"Unsupported LLM provider: {err}") from err

    return provider, key
def parse_arguments(argv: Optional[list[str]] = None) -> Namespace:
    """
    Parse command-line arguments.

    Available arguments:
        --help: Show the help message
        --config: Path to the configuration file (default: config.json)
        --show-config: Output the default configuration to stdout
        --level: Log level (debug,info,warning,error,critical)
        --database: Database URL
        --local-ipc-port: Local IPC port to connect to
        --local-ipc-host: Local IPC host to connect to (default: localhost)
        --version: Show the version and exit
        --list: List all projects
        --list-json: List all projects in JSON format
        --project: Load a specific project
        --branch: Load a specific branch
        --step: Load a specific step in a project/branch
        --delete: Delete a specific project
        --llm-endpoint: Use specific API endpoint for the given provider (repeatable)
        --llm-key: Use specific LLM key for the given provider (repeatable)
        --import-v0: Import data from a v0 (gpt-pilot) database with the given path
        --email: User's email address, if provided
        --extension-version: Version of the VSCode extension, if used

    :param argv: Arguments to parse; when None (the default), argparse
        parses ``sys.argv[1:]``.
    :return: Parsed arguments object.
    """
    version = get_version()
    parser = ArgumentParser()
    parser.add_argument("--config", help="Path to the configuration file", default="config.json")
    parser.add_argument("--show-config", help="Output the default configuration to stdout", action="store_true")
    parser.add_argument("--level", help="Log level (debug,info,warning,error,critical)", required=False)
    parser.add_argument("--database", help="Database URL", required=False)
    parser.add_argument("--local-ipc-port", help="Local IPC port to connect to", type=int, required=False)
    parser.add_argument("--local-ipc-host", help="Local IPC host to connect to", default="localhost", required=False)
    parser.add_argument("--version", action="version", version=version)
    parser.add_argument("--list", help="List all projects", action="store_true")
    parser.add_argument("--list-json", help="List all projects in JSON format", action="store_true")
    parser.add_argument("--project", help="Load a specific project", type=UUID, required=False)
    parser.add_argument("--branch", help="Load a specific branch", type=UUID, required=False)
    parser.add_argument("--step", help="Load a specific step in a project/branch", type=int, required=False)
    parser.add_argument("--delete", help="Delete a specific project", type=UUID, required=False)
    # Both LLM options may be repeated (action="append"), one per provider;
    # an empty value is parsed to None and still appended to the list.
    parser.add_argument(
        "--llm-endpoint",
        help="Use specific API endpoint for the given provider",
        type=parse_llm_endpoint,
        action="append",
        required=False,
    )
    parser.add_argument(
        "--llm-key",
        help="Use specific LLM key for the given provider",
        type=parse_llm_key,
        action="append",
        required=False,
    )
    parser.add_argument(
        "--import-v0",
        help="Import data from a v0 (gpt-pilot) database with the given path",
        required=False,
    )
    parser.add_argument("--email", help="User's email address", required=False)
    parser.add_argument("--extension-version", help="Version of the VSCode extension", required=False)
    return parser.parse_args(argv)
def _get_provider_config(config: Config, provider: LLMProvider) -> ProviderConfig:
    """
    Return ``config.llm[provider]``, first adding an empty `ProviderConfig`
    for that provider if it isn't configured yet.
    """
    if provider not in config.llm:
        config.llm[provider] = ProviderConfig()
    return config.llm[provider]


def load_config(args: Namespace) -> Optional[Config]:
    """
    Load Pythagora JSON configuration file and apply command-line arguments.

    If the configuration file doesn't exist, `import_from_dotenv` is tried
    first; if it reports that nothing was imported, the configuration
    returned by `get_config()` is used as-is.

    :param args: Command-line arguments (at least `config` must be present).
    :return: Configuration object, or None if config couldn't be loaded.
    """
    if not os.path.isfile(args.config):
        imported = import_from_dotenv(args.config)
        if not imported:
            print(f"Configuration file not found: {args.config}; using default", file=sys.stderr)
            # NOTE(review): this early return skips every command-line override
            # below (--level, --database, --local-ipc-port, --llm-*); it is also
            # what allows callers to pass only `config`. Confirm it's intended.
            return get_config()

    try:
        config = loader.load(args.config)
    except (ValueError, OSError) as err:
        # OSError: the file exists but can't be read (e.g. permissions).
        print(f"Error parsing config file {args.config}: {err}", file=sys.stderr)
        return None

    if args.level:
        config.log.level = args.level.upper()
    if args.database:
        config.db.url = args.database
    if args.local_ipc_port:
        config.ui = LocalIPCConfig(port=args.local_ipc_port, host=args.local_ipc_host)

    # `parse_llm_endpoint` / `parse_llm_key` return None for an empty value,
    # and with action="append" that None lands in the list. Skip such entries
    # instead of crashing while unpacking them.
    for entry in args.llm_endpoint or []:
        if entry is not None:
            provider, endpoint = entry
            _get_provider_config(config, provider).base_url = endpoint
    for entry in args.llm_key or []:
        if entry is not None:
            provider, key = entry
            _get_provider_config(config, provider).api_key = key

    try:
        # NOTE(review): under pydantic v2's default `revalidate_instances="never"`,
        # validating an existing Config instance returns it without re-running
        # validation, so this may not catch invalid overrides. Check Config's
        # model_config before relying on it.
        Config.model_validate(config)
    except ValueError as err:
        print(f"Configuration error: {err}", file=sys.stderr)
        return None

    return config
async def list_projects_json(db: SessionManager):
    """
    Print all projects in the database, with their branches and steps, as JSON.

    Steps of each branch are listed in the order the state manager returns
    them, and the last one is labelled "Latest step".

    :param db: Database session manager.
    """
    projects = await StateManager(db).list_projects()

    def describe_branch(branch) -> dict:
        branch_info = {"name": branch.name, "id": branch.id.hex}
        steps = [{"name": f"Step #{state.step_index}", "step": state.step_index} for state in branch.states]
        if steps:
            steps[-1]["name"] = "Latest step"
        branch_info["steps"] = steps
        return branch_info

    output = [
        {
            "name": project.name,
            "id": project.id.hex,
            "branches": [describe_branch(branch) for branch in project.branches],
        }
        for project in projects
    ]
    print(json.dumps(output, indent=2))
async def list_projects(db: SessionManager):
    """
    Print a human-readable list of all projects in the database.

    Each project is followed by its branches, along with the highest step
    index found among the branch's states ("none" if it has no states).

    :param db: Database session manager.
    """
    sm = StateManager(db)
    projects = await sm.list_projects()

    print(f"Available projects ({len(projects)}):")
    for project in projects:
        print(f"* {project.name} ({project.id})")
        for branch in project.branches:
            # A bare max() raises ValueError for a branch without states and
            # would abort the whole listing.
            last_step = max((state.step_index for state in branch.states), default=None)
            if last_step is None:
                last_step = "none"
            print(f" - {branch.name} ({branch.id}) - last step: {last_step}")
async def load_project(
    sm: StateManager,
    project_id: Optional[UUID] = None,
    branch_id: Optional[UUID] = None,
    step_index: Optional[int] = None,
) -> bool:
    """
    Load a project from the database.

    If both IDs are given, `branch_id` takes precedence. If neither is given,
    nothing is loaded and False is returned.

    :param sm: State manager.
    :param project_id: Project ID (optional, loads the last step in the main branch).
    :param branch_id: Branch ID (optional, loads the last step in the branch).
    :param step_index: Step index (optional, loads the state at the given step).
    :return: True if the project was loaded successfully, False otherwise.
    """
    if branch_id:
        target = f"Branch {branch_id}"
        lookup = {"branch_id": branch_id}
    elif project_id:
        target = f"Project {project_id}"
        lookup = {"project_id": project_id}
    else:
        return False

    project_state = await sm.load_project(step_index=step_index, **lookup)
    if project_state:
        return True

    # Compare against None so an explicit step 0 still appears in the message.
    step_txt = f" step {step_index}" if step_index is not None else ""
    print(f"{target}{step_txt} not found; use --list to list all projects", file=sys.stderr)
    return False
async def delete_project(sm: StateManager, project_id: UUID) -> bool:
    """
    Remove a project from the database through the state manager.

    :param sm: State manager that performs the deletion.
    :param project_id: ID of the project to remove.
    :return: The result of `StateManager.delete_project`: True if the project
        was deleted, False otherwise.
    """
    deleted = await sm.delete_project(project_id)
    return deleted
def show_config():
    """
    Write the configuration returned by `get_config()` to stdout as
    JSON indented by two spaces.
    """
    print(get_config().model_dump_json(indent=2))
def init() -> tuple[Optional[UIBase], Optional[SessionManager], Namespace]:
    """
    Initialize the application.

    Parses command-line arguments, loads configuration, sets up logging and
    UI, runs database migrations and creates the db session manager.

    :return: Tuple with UI, db session manager and command-line arguments.
        If the configuration couldn't be loaded, UI and session manager are None.
    """
    args = parse_arguments()
    config = load_config(args)
    if config is None:
        return (None, None, args)

    setup(config.log, force=True)

    if config.ui.type == UIAdapter.IPC_CLIENT:
        ui = IPCClientUI(config.ui)
    else:
        ui = PlainConsoleUI()

    # Migrations run before the session manager is created.
    run_migrations(config.db)
    db = SessionManager(config.db)

    return (ui, db, args)
# Public API of this module; includes every helper meant to be used by the CLI entry point.
__all__ = [
    "parse_arguments",
    "load_config",
    "list_projects_json",
    "list_projects",
    "load_project",
    "delete_project",
    "show_config",
    "init",
]