Merge branch 'master' into zamilmajdy/code-validation

Reinier van der Leer
2024-06-03 21:43:59 +02:00
413 changed files with 6467 additions and 12095 deletions

7
forge/.env.example Normal file

@@ -0,0 +1,7 @@
# Your OpenAI API key. If GPT-4 is available it will be used, otherwise GPT-3.5-turbo
OPENAI_API_KEY=abc
# Control log level
LOG_LEVEL=INFO
DATABASE_STRING="sqlite:///agent.db"
PORT=8000

11
forge/.flake8 Normal file

@@ -0,0 +1,11 @@
[flake8]
max-line-length = 88
# Ignore rules that conflict with Black code style
extend-ignore = E203, W503
exclude =
.git,
__pycache__/,
*.pyc,
.pytest_cache/,
venv*/,
.venv/,

175
forge/.gitignore vendored Normal file

@@ -0,0 +1,175 @@
## Original ignores
autogpt/keys.py
autogpt/*.json
*.mpeg
.env
azure.yaml
.vscode
.idea/*
auto-gpt.json
log.txt
log-ingestion.txt
logs
*.log
*.mp3
mem.sqlite3
venvAutoGPT
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
plugins/
plugins_config.yaml
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
site/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.direnv/
.env
.venv
env/
venv*/
ENV/
env.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
llama-*
vicuna-*
# mac
.DS_Store
openai/
# news
CURRENT_BULLETIN.md
agbenchmark_config/workspace
agbenchmark_config/reports
*.sqlite*
*.db
.agbench
.agbenchmark
.benchmarks
.mypy_cache
.pytest_cache
.vscode
ig_*
agbenchmark_config/updates.json
agbenchmark_config/challenges_already_beaten.json
agbenchmark_config/temp_folder/*
test_workspace/

40
forge/Dockerfile Normal file

@@ -0,0 +1,40 @@
# Use an official Python runtime as a parent image
FROM python:3.11-slim-buster as base
# Set work directory in the container
WORKDIR /app
# Install system dependencies
RUN apt-get update \
&& apt-get install -y build-essential curl ffmpeg \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Install Poetry - respects $POETRY_VERSION & $POETRY_HOME
ENV POETRY_VERSION=1.1.8 \
POETRY_HOME="/opt/poetry" \
POETRY_NO_INTERACTION=1 \
POETRY_VIRTUALENVS_CREATE=false \
PATH="$POETRY_HOME/bin:$PATH"
RUN pip3 install poetry
COPY pyproject.toml poetry.lock* /app/
# Project initialization:
RUN poetry install --no-interaction --no-ansi
ENV PYTHONPATH="/app:$PYTHONPATH"
FROM base as dependencies
# Copy project
COPY . /app
# Make port 8000 available to the world outside this container
EXPOSE 8000
# Run the application when the container launches
CMD ["poetry", "run", "python", "autogpt/__main__.py"]

24
forge/README.md Normal file

@@ -0,0 +1,24 @@
# 🚀 **AutoGPT-Forge**: Build Your Own AutoGPT Agent! 🧠
### 🌌 Dive into the Universe of AutoGPT Creation! 🌌
Ever dreamt of becoming the genius behind an AI agent? Dive into the *Forge*, where **you** become the creator!
---
### 🛠️ **Why AutoGPT-Forge?**
- 💤 **No More Boilerplate!** Don't let the mundane tasks stop you. Fork and build without the headache of starting from scratch!
- 🧠 **Brain-centric Development!** All the tools you need so you can spend 100% of your time on what matters - crafting the brain of your AI!
- 🛠️ **Tooling ecosystem!** We work with the best in class tools to bring you the best experience possible!
---
### 🚀 **Get Started!**
The getting started [tutorial series](https://aiedge.medium.com/autogpt-forge-e3de53cc58ec) will guide you through the process of setting up your project all the way through to building a generalist agent.
1. [AutoGPT Forge: A Comprehensive Guide to Your First Steps](https://aiedge.medium.com/autogpt-forge-a-comprehensive-guide-to-your-first-steps-a1dfdf46e3b4)
2. [AutoGPT Forge: The Blueprint of an AI Agent](https://aiedge.medium.com/autogpt-forge-the-blueprint-of-an-ai-agent-75cd72ffde6)
3. [AutoGPT Forge: Interacting with your Agent](https://aiedge.medium.com/autogpt-forge-interacting-with-your-agent-1214561b06b)
4. [AutoGPT Forge: Crafting Intelligent Agent Logic](https://medium.com/@aiedge/autogpt-forge-crafting-intelligent-agent-logic-bc5197b14cb4)


@@ -0,0 +1,4 @@
{
"workspace": {"input": "agbenchmark_config/workspace", "output": "agbenchmark_config/workspace"},
"host": "http://localhost:8000"
}

0
forge/forge/__init__.py Normal file

54
forge/forge/__main__.py Normal file

@@ -0,0 +1,54 @@
import logging
import os
import uvicorn
from dotenv import load_dotenv
from forge.logging.config import configure_logging
logger = logging.getLogger(__name__)
logo = """\n\n
d8888 888 .d8888b. 8888888b. 88888888888
d88P888 888 888 888 888 888 888
d88P 888 888 888 888888 .d88b. 888 888 d88P 888
d88P 888 888 888 888 d88""88b 888 88888 8888888P" 888
d88P 888 888 888 888 888 888 888 888 888 888
d8888888888 Y88b 888 Y88b. Y88..88P Y88b d88P 888 888
d88P 888 "Y88888 "Y888 "Y88P" "Y8888P88 888 888
8888888888
888
888 .d88b. 888d888 .d88b. .d88b.
888888 d88""88b 888P" d88P"88b d8P Y8b
888 888 888 888 888 888 88888888
888 Y88..88P 888 Y88b 888 Y8b.
888 "Y88P" 888 "Y88888 "Y8888
888
Y8b d88P
"Y88P" v0.1.0
\n"""
if __name__ == "__main__":
print(logo)
load_dotenv()
port = os.getenv("PORT", 8000)
configure_logging()
logger.info(f"Agent server starting on http://localhost:{port}")
uvicorn.run(
"forge.app:app",
host="localhost",
port=int(port),
log_level="error",
# Reload on changes to code or .env
reload=True,
reload_dirs=os.path.dirname(os.path.dirname(__file__)),
reload_excludes="*.py", # Cancel default *.py include pattern
reload_includes=[
f"{os.path.basename(os.path.dirname(__file__))}/**/*.py",
".*",
".env",
],
)


@@ -0,0 +1,7 @@
from .base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
__all__ = [
"BaseAgent",
"BaseAgentConfiguration",
"BaseAgentSettings",
]

193
forge/forge/agent/agent.py Normal file

@@ -0,0 +1,193 @@
import logging
import os
import pathlib
from io import BytesIO
from uuid import uuid4
import uvicorn
from fastapi import APIRouter, FastAPI, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from forge.agent_protocol.api_router import base_router
from forge.agent_protocol.database.db import AgentDB
from forge.agent_protocol.middlewares import AgentMiddleware
from forge.agent_protocol.models.task import (
Artifact,
Step,
StepRequestBody,
Task,
TaskArtifactsListResponse,
TaskListResponse,
TaskRequestBody,
TaskStepsListResponse,
)
from forge.file_storage.base import FileStorage
logger = logging.getLogger(__name__)
class Agent:
def __init__(self, database: AgentDB, workspace: FileStorage):
self.db = database
self.workspace = workspace
def get_agent_app(self, router: APIRouter = base_router):
"""
Start the agent server.
"""
app = FastAPI(
title="AutoGPT Forge",
description="Modified version of The Agent Protocol.",
version="v0.4",
)
# Add CORS middleware
origins = [
"http://localhost:5000",
"http://127.0.0.1:5000",
"http://localhost:8000",
"http://127.0.0.1:8000",
"http://localhost:8080",
"http://127.0.0.1:8080",
# Add any other origins you want to whitelist
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(router, prefix="/ap/v1")
script_dir = os.path.dirname(os.path.realpath(__file__))
frontend_path = pathlib.Path(
os.path.join(script_dir, "../../../frontend/build/web")
).resolve()
if os.path.exists(frontend_path):
app.mount("/app", StaticFiles(directory=frontend_path), name="app")
@app.get("/", include_in_schema=False)
async def root():
return RedirectResponse(url="/app/index.html", status_code=307)
else:
logger.warning(
f"Frontend not found. {frontend_path} does not exist. "
"The frontend will not be served."
)
app.add_middleware(AgentMiddleware, agent=self)
return app
def start(self, port):
uvicorn.run(
"forge.app:app", host="localhost", port=port, log_level="error", reload=True
)
async def create_task(self, task_request: TaskRequestBody) -> Task:
"""
Create a task for the agent.
"""
task = await self.db.create_task(
input=task_request.input,
additional_input=task_request.additional_input,
)
return task
async def list_tasks(self, page: int = 1, pageSize: int = 10) -> TaskListResponse:
"""
List all tasks that the agent has created.
"""
tasks, pagination = await self.db.list_tasks(page, pageSize)
response = TaskListResponse(tasks=tasks, pagination=pagination)
return response
async def get_task(self, task_id: str) -> Task:
"""
Get a task by ID.
"""
task = await self.db.get_task(task_id)
return task
async def list_steps(
self, task_id: str, page: int = 1, pageSize: int = 10
) -> TaskStepsListResponse:
"""
List the IDs of all steps that the task has created.
"""
steps, pagination = await self.db.list_steps(task_id, page, pageSize)
response = TaskStepsListResponse(steps=steps, pagination=pagination)
return response
async def execute_step(self, task_id: str, step_request: StepRequestBody) -> Step:
"""
Create a step for the task.
"""
raise NotImplementedError
async def get_step(self, task_id: str, step_id: str) -> Step:
"""
Get a step by ID.
"""
step = await self.db.get_step(task_id, step_id)
return step
async def list_artifacts(
self, task_id: str, page: int = 1, pageSize: int = 10
) -> TaskArtifactsListResponse:
"""
List the artifacts that the task has created.
"""
artifacts, pagination = await self.db.list_artifacts(task_id, page, pageSize)
return TaskArtifactsListResponse(artifacts=artifacts, pagination=pagination)
async def create_artifact(
self, task_id: str, file: UploadFile, relative_path: str = ""
) -> Artifact:
"""
Create an artifact for the task.
"""
file_name = file.filename or str(uuid4())
data = b""
while contents := file.file.read(1024 * 1024):
data += contents
# Check if relative path ends with filename
if relative_path.endswith(file_name):
file_path = relative_path
else:
file_path = os.path.join(relative_path, file_name)
await self.workspace.write_file(file_path, data)
artifact = await self.db.create_artifact(
task_id=task_id,
file_name=file_name,
relative_path=relative_path,
agent_created=False,
)
return artifact
async def get_artifact(self, task_id: str, artifact_id: str) -> StreamingResponse:
"""
Get an artifact by ID.
"""
artifact = await self.db.get_artifact(artifact_id)
if artifact.file_name not in artifact.relative_path:
file_path = os.path.join(artifact.relative_path, artifact.file_name)
else:
file_path = artifact.relative_path
retrieved_artifact = self.workspace.read_file(file_path, binary=True)
return StreamingResponse(
BytesIO(retrieved_artifact),
media_type="application/octet-stream",
headers={
"Content-Disposition": f"attachment; filename={artifact.file_name}"
},
)
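
The base Agent above leaves execute_step unimplemented (it raises NotImplementedError); a concrete agent is expected to override it. Below is a minimal, illustrative sketch of such an override, assuming the AgentDB.create_step/update_step methods defined later in this commit; the EchoAgent name and the echo behaviour are examples only, not part of the commit.

from forge.agent.agent import Agent
from forge.agent_protocol.models.task import Step, StepRequestBody


class EchoAgent(Agent):
    """Illustrative agent that completes every step by echoing its input."""

    async def execute_step(self, task_id: str, step_request: StepRequestBody) -> Step:
        # Persist the new step and mark it as the last one for this task
        step = await self.db.create_step(
            task_id=task_id, input=step_request, is_last=True
        )
        # Record a trivial result and return the finished step
        return await self.db.update_step(
            task_id,
            step.step_id,
            status="completed",
            output=f"Echo: {step_request.input}",
        )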


@@ -0,0 +1,137 @@
from pathlib import Path
import pytest
from fastapi import UploadFile
from forge.agent_protocol.database.db import AgentDB
from forge.agent_protocol.models.task import (
StepRequestBody,
Task,
TaskListResponse,
TaskRequestBody,
)
from forge.file_storage.base import FileStorageConfiguration
from forge.file_storage.local import LocalFileStorage
from .agent import Agent
@pytest.fixture
def agent(test_workspace: Path):
db = AgentDB("sqlite:///test.db")
config = FileStorageConfiguration(root=test_workspace)
workspace = LocalFileStorage(config)
return Agent(db, workspace)
@pytest.fixture
def file_upload():
this_file = Path(__file__)
file_handle = this_file.open("rb")
yield UploadFile(file_handle, filename=this_file.name)
file_handle.close()
@pytest.mark.asyncio
async def test_create_task(agent: Agent):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
task: Task = await agent.create_task(task_request)
assert task.input == "test_input"
@pytest.mark.asyncio
async def test_list_tasks(agent: Agent):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
await agent.create_task(task_request)
tasks = await agent.list_tasks()
assert isinstance(tasks, TaskListResponse)
@pytest.mark.asyncio
async def test_get_task(agent: Agent):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
task = await agent.create_task(task_request)
retrieved_task = await agent.get_task(task.task_id)
assert retrieved_task.task_id == task.task_id
@pytest.mark.xfail(reason="execute_step is not implemented")
@pytest.mark.asyncio
async def test_execute_step(agent: Agent):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
task = await agent.create_task(task_request)
step_request = StepRequestBody(
input="step_input", additional_input={"input": "additional_test_input"}
)
step = await agent.execute_step(task.task_id, step_request)
assert step.input == "step_input"
assert step.additional_input == {"input": "additional_test_input"}
@pytest.mark.xfail(reason="execute_step is not implemented")
@pytest.mark.asyncio
async def test_get_step(agent: Agent):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
task = await agent.create_task(task_request)
step_request = StepRequestBody(
input="step_input", additional_input={"input": "additional_test_input"}
)
step = await agent.execute_step(task.task_id, step_request)
retrieved_step = await agent.get_step(task.task_id, step.step_id)
assert retrieved_step.step_id == step.step_id
@pytest.mark.asyncio
async def test_list_artifacts(agent: Agent):
tasks = await agent.list_tasks()
assert tasks.tasks, "No tasks in test.db"
artifacts = await agent.list_artifacts(tasks.tasks[0].task_id)
assert isinstance(artifacts.artifacts, list)
@pytest.mark.asyncio
async def test_create_artifact(agent: Agent, file_upload: UploadFile):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
task = await agent.create_task(task_request)
artifact = await agent.create_artifact(
task_id=task.task_id,
file=file_upload,
relative_path=f"a_dir/{file_upload.filename}",
)
assert artifact.file_name == file_upload.filename
assert artifact.relative_path == f"a_dir/{file_upload.filename}"
@pytest.mark.asyncio
async def test_create_and_get_artifact(agent: Agent, file_upload: UploadFile):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
task = await agent.create_task(task_request)
artifact = await agent.create_artifact(
task_id=task.task_id,
file=file_upload,
relative_path=f"b_dir/{file_upload.filename}",
)
await file_upload.seek(0)
file_upload_content = await file_upload.read()
retrieved_artifact = await agent.get_artifact(task.task_id, artifact.artifact_id)
retrieved_artifact_content = bytearray()
async for b in retrieved_artifact.body_iterator:
retrieved_artifact_content.extend(b) # type: ignore
assert retrieved_artifact_content == file_upload_content

333
forge/forge/agent/base.py Normal file

@@ -0,0 +1,333 @@
from __future__ import annotations
import copy
import inspect
import logging
from abc import ABCMeta, abstractmethod
from typing import (
Any,
Awaitable,
Callable,
Generic,
Iterator,
Optional,
ParamSpec,
TypeVar,
cast,
overload,
)
from colorama import Fore
from pydantic import BaseModel, Field, validator
from forge.agent import protocols
from forge.agent.components import (
AgentComponent,
ComponentEndpointError,
EndpointPipelineError,
)
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.llm.providers import CHAT_MODELS, ModelName, OpenAIModelName
from forge.llm.providers.schema import ChatModelInfo
from forge.models.action import ActionResult, AnyProposal
from forge.models.config import SystemConfiguration, SystemSettings, UserConfigurable
logger = logging.getLogger(__name__)
T = TypeVar("T")
P = ParamSpec("P")
DEFAULT_TRIGGERING_PROMPT = (
"Determine exactly one command to use next based on the given goals "
"and the progress you have made so far, "
"and respond using the JSON schema specified previously:"
)
class BaseAgentConfiguration(SystemConfiguration):
allow_fs_access: bool = UserConfigurable(default=False)
fast_llm: ModelName = UserConfigurable(default=OpenAIModelName.GPT3_16k)
smart_llm: ModelName = UserConfigurable(default=OpenAIModelName.GPT4)
use_functions_api: bool = UserConfigurable(default=False)
default_cycle_instruction: str = DEFAULT_TRIGGERING_PROMPT
"""The default instruction passed to the AI for a thinking cycle."""
big_brain: bool = UserConfigurable(default=True)
"""
Whether this agent uses the configured smart LLM (default) to think,
as opposed to the configured fast LLM. Enabling this disables hybrid mode.
"""
cycle_budget: Optional[int] = 1
"""
The number of cycles that the agent is allowed to run unsupervised.
`None` for unlimited continuous execution,
`1` to require user approval for every step,
`0` to stop the agent.
"""
cycles_remaining = cycle_budget
"""The number of cycles remaining within the `cycle_budget`."""
cycle_count = 0
"""The number of cycles that the agent has run since its initialization."""
send_token_limit: Optional[int] = None
"""
The token limit for prompt construction. Should leave room for the completion;
defaults to 75% of `llm.max_tokens`.
"""
summary_max_tlength: Optional[int] = None
# TODO: move to ActionHistoryConfiguration
@validator("use_functions_api")
def validate_openai_functions(cls, v: bool, values: dict[str, Any]):
if v:
smart_llm = values["smart_llm"]
fast_llm = values["fast_llm"]
assert all(
[
not any(s in name for s in {"-0301", "-0314"})
for name in {smart_llm, fast_llm}
]
), (
f"Model {smart_llm} does not support OpenAI Functions. "
"Please disable OPENAI_FUNCTIONS or choose a suitable model."
)
return v
class BaseAgentSettings(SystemSettings):
agent_id: str = ""
ai_profile: AIProfile = Field(default_factory=lambda: AIProfile(ai_name="AutoGPT"))
"""The AI profile or "personality" of the agent."""
directives: AIDirectives = Field(default_factory=AIDirectives)
"""Directives (general instructional guidelines) for the agent."""
task: str = "Terminate immediately" # FIXME: placeholder for forge.sdk.schema.Task
"""The user-given task that the agent is working on."""
config: BaseAgentConfiguration = Field(default_factory=BaseAgentConfiguration)
"""The configuration for this BaseAgent subsystem instance."""
class AgentMeta(ABCMeta):
def __call__(cls, *args, **kwargs):
# Create instance of the class (Agent or BaseAgent)
instance = super().__call__(*args, **kwargs)
# Automatically collect modules after the instance is created
instance._collect_components()
return instance
class BaseAgent(Generic[AnyProposal], metaclass=AgentMeta):
def __init__(
self,
settings: BaseAgentSettings,
):
self.state = settings
self.components: list[AgentComponent] = []
self.config = settings.config
# Execution data for debugging
self._trace: list[str] = []
logger.debug(f"Created {__class__} '{self.state.ai_profile.ai_name}'")
@property
def trace(self) -> list[str]:
return self._trace
@property
def llm(self) -> ChatModelInfo:
"""The LLM that the agent uses to think."""
llm_name = (
self.config.smart_llm if self.config.big_brain else self.config.fast_llm
)
return CHAT_MODELS[llm_name]
@property
def send_token_limit(self) -> int:
return self.config.send_token_limit or self.llm.max_tokens * 3 // 4
@abstractmethod
async def propose_action(self) -> AnyProposal:
...
@abstractmethod
async def execute(
self,
proposal: AnyProposal,
user_feedback: str = "",
) -> ActionResult:
...
@abstractmethod
async def do_not_execute(
self,
denied_proposal: AnyProposal,
user_feedback: str,
) -> ActionResult:
...
def reset_trace(self):
self._trace = []
@overload
async def run_pipeline(
self, protocol_method: Callable[P, Iterator[T]], *args, retry_limit: int = 3
) -> list[T]:
...
@overload
async def run_pipeline(
self,
protocol_method: Callable[P, None | Awaitable[None]],
*args,
retry_limit: int = 3,
) -> list[None]:
...
async def run_pipeline(
self,
protocol_method: Callable[P, Iterator[T] | None | Awaitable[None]],
*args,
retry_limit: int = 3,
) -> list[T] | list[None]:
method_name = protocol_method.__name__
protocol_name = protocol_method.__qualname__.split(".")[0]
protocol_class = getattr(protocols, protocol_name)
if not issubclass(protocol_class, AgentComponent):
raise TypeError(f"{repr(protocol_method)} is not a protocol method")
# Clone parameters to revert on failure
original_args = self._selective_copy(args)
pipeline_attempts = 0
method_result: list[T] = []
self._trace.append(f"⬇️ {Fore.BLUE}{method_name}{Fore.RESET}")
while pipeline_attempts < retry_limit:
try:
for component in self.components:
# Skip other protocols
if not isinstance(component, protocol_class):
continue
# Skip disabled components
if not component.enabled:
self._trace.append(
f" {Fore.LIGHTBLACK_EX}"
f"{component.__class__.__name__}{Fore.RESET}"
)
continue
method = cast(
Callable[..., Iterator[T] | None | Awaitable[None]] | None,
getattr(component, method_name, None),
)
if not callable(method):
continue
component_attempts = 0
while component_attempts < retry_limit:
try:
component_args = self._selective_copy(args)
result = method(*component_args)
if inspect.isawaitable(result):
result = await result
if result is not None:
method_result.extend(result)
args = component_args
self._trace.append(f"{component.__class__.__name__}")
except ComponentEndpointError:
self._trace.append(
f"{Fore.YELLOW}{component.__class__.__name__}: "
f"ComponentEndpointError{Fore.RESET}"
)
# Retry the same component on ComponentEndpointError
component_attempts += 1
continue
# Successful component execution
break
# Successful pipeline execution
break
except EndpointPipelineError as e:
self._trace.append(
f"{Fore.LIGHTRED_EX}{e.triggerer.__class__.__name__}: "
f"EndpointPipelineError{Fore.RESET}"
)
# Restart from the beginning on EndpointPipelineError
# Revert to original parameters
args = self._selective_copy(original_args)
pipeline_attempts += 1
continue # Start the loop over
except Exception as e:
raise e
return method_result
def _collect_components(self):
components = [
getattr(self, attr)
for attr in dir(self)
if isinstance(getattr(self, attr), AgentComponent)
]
if self.components:
# Check if any component is missing (added to Agent but not to components)
for component in components:
if component not in self.components:
logger.warning(
f"Component {component.__class__.__name__} "
"is attached to an agent but not added to components list"
)
# Skip collecting and sorting if ordering is explicit
return
self.components = self._topological_sort(components)
def _topological_sort(
self, components: list[AgentComponent]
) -> list[AgentComponent]:
visited = set()
stack = []
def visit(node: AgentComponent):
if node in visited:
return
visited.add(node)
for neighbor_class in node._run_after:
neighbor = next(
(m for m in components if isinstance(m, neighbor_class)), None
)
if neighbor and neighbor not in visited:
visit(neighbor)
stack.append(node)
for component in components:
visit(component)
return stack
def _selective_copy(self, args: tuple[Any, ...]) -> tuple[Any, ...]:
copied_args = []
for item in args:
if isinstance(item, list):
# Shallow copy for lists
copied_item = item[:]
elif isinstance(item, dict):
# Shallow copy for dicts
copied_item = item.copy()
elif isinstance(item, BaseModel):
# Deep copy for Pydantic models (deep=True to also copy nested models)
copied_item = item.copy(deep=True)
else:
# Deep copy for other objects
copied_item = copy.deepcopy(item)
copied_args.append(copied_item)
return tuple(copied_args)


@@ -0,0 +1,51 @@
from __future__ import annotations
from abc import ABC
from typing import Callable, TypeVar
T = TypeVar("T", bound="AgentComponent")
class AgentComponent(ABC):
"""Base class for all agent components."""
_run_after: list[type[AgentComponent]] = []
_enabled: Callable[[], bool] | bool = True
_disabled_reason: str = ""
@property
def enabled(self) -> bool:
if callable(self._enabled):
return self._enabled()
return self._enabled
@property
def disabled_reason(self) -> str:
"""Return the reason this component is disabled."""
return self._disabled_reason
def run_after(self: T, *components: type[AgentComponent] | AgentComponent) -> T:
"""Set the components that this component should run after."""
for component in components:
t = component if isinstance(component, type) else type(component)
if t not in self._run_after and t is not self.__class__:
self._run_after.append(t)
return self
class ComponentEndpointError(Exception):
"""Error of a single protocol method on a component."""
def __init__(self, message: str, component: AgentComponent):
self.message = message
self.triggerer = component
super().__init__(message)
class EndpointPipelineError(ComponentEndpointError):
"""Error of an entire pipeline of one endpoint."""
class ComponentSystemError(EndpointPipelineError):
"""Error of a group of pipelines;
multiple different endpoints."""


@@ -0,0 +1,51 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Awaitable, Generic, Iterator
from forge.models.action import ActionResult, AnyProposal
from .components import AgentComponent
if TYPE_CHECKING:
from forge.command.command import Command
from forge.llm.providers import ChatMessage
class DirectiveProvider(AgentComponent):
def get_constraints(self) -> Iterator[str]:
return iter([])
def get_resources(self) -> Iterator[str]:
return iter([])
def get_best_practices(self) -> Iterator[str]:
return iter([])
class CommandProvider(AgentComponent):
@abstractmethod
def get_commands(self) -> Iterator["Command"]:
...
class MessageProvider(AgentComponent):
@abstractmethod
def get_messages(self) -> Iterator["ChatMessage"]:
...
class AfterParse(AgentComponent, Generic[AnyProposal]):
@abstractmethod
def after_parse(self, result: AnyProposal) -> None | Awaitable[None]:
...
class ExecutionFailure(AgentComponent):
@abstractmethod
def execution_failure(self, error: Exception) -> None | Awaitable[None]:
...
class AfterExecute(AgentComponent):
@abstractmethod
def after_execute(self, result: "ActionResult") -> None | Awaitable[None]:
...
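
The protocol classes above are plain AgentComponent subclasses; BaseAgent.run_pipeline (in base.py above) dispatches on them by matching each attached component against the protocol named in the method's qualified name. A minimal sketch of a custom component and a pipeline call, assuming a BaseAgent subclass that holds the component as an attribute (the WebSearchHints name and the directive strings are illustrative only):

from typing import Iterator

from forge.agent.protocols import DirectiveProvider


class WebSearchHints(DirectiveProvider):
    """Contributes extra resources and best practices to the agent's directives."""

    def get_resources(self) -> Iterator[str]:
        yield "Internet access for searches and information gathering."

    def get_best_practices(self) -> Iterator[str]:
        yield "Prefer primary sources over second-hand summaries."


# Components attached as attributes of a BaseAgent subclass are collected
# automatically by AgentMeta; inside the agent, a pipeline call then
# aggregates every component's contribution, e.g.:
#
#     resources = await self.run_pipeline(DirectiveProvider.get_resources)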


@@ -0,0 +1,476 @@
"""
Routes for the Agent Service.
This module defines the API routes for the Agent service.
Developers and contributors should be especially careful when making modifications
to these routes to ensure consistency and correctness in the system's behavior.
"""
import logging
from typing import TYPE_CHECKING, Optional
from fastapi import APIRouter, HTTPException, Query, Request, Response, UploadFile
from fastapi.responses import StreamingResponse
from .models import (
Artifact,
Step,
StepRequestBody,
Task,
TaskArtifactsListResponse,
TaskListResponse,
TaskRequestBody,
TaskStepsListResponse,
)
if TYPE_CHECKING:
from forge.agent.agent import Agent
base_router = APIRouter()
logger = logging.getLogger(__name__)
@base_router.get("/", tags=["root"])
async def root():
"""
Root endpoint that returns a welcome message.
"""
return Response(content="Welcome to the AutoGPT Forge")
@base_router.get("/heartbeat", tags=["server"])
async def check_server_status():
"""
Check if the server is running.
"""
return Response(content="Server is running.", status_code=200)
@base_router.post("/agent/tasks", tags=["agent"], response_model=Task)
async def create_agent_task(request: Request, task_request: TaskRequestBody) -> Task:
"""
Creates a new task using the provided TaskRequestBody and returns a Task.
Args:
request (Request): FastAPI request object.
task_request (TaskRequestBody): The task request containing input data.
Returns:
Task: A new task with task_id, input, and additional_input set.
Example:
Request (TaskRequestBody defined in schema.py):
{
"input": "Write the words you receive to the file 'output.txt'.",
"additional_input": "python/code"
}
Response (Task defined in schema.py):
{
"task_id": "50da533e-3904-4401-8a07-c49adf88b5eb",
"input": "Write the word 'Washington' to a .txt file",
"additional_input": "python/code",
"artifacts": [],
}
"""
agent: "Agent" = request["agent"]
try:
task = await agent.create_task(task_request)
return task
except Exception:
logger.exception(f"Error whilst trying to create a task: {task_request}")
raise
@base_router.get("/agent/tasks", tags=["agent"], response_model=TaskListResponse)
async def list_agent_tasks(
request: Request,
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, alias="pageSize"),
) -> TaskListResponse:
"""
Retrieves a paginated list of all tasks.
Args:
request (Request): FastAPI request object.
page (int, optional): Page number for pagination. Default: 1
page_size (int, optional): Number of tasks per page for pagination. Default: 10
Returns:
TaskListResponse: A list of tasks, and pagination details.
Example:
Request:
GET /agent/tasks?page=1&pageSize=10
Response (TaskListResponse defined in schema.py):
{
"items": [
{
"input": "Write the word 'Washington' to a .txt file",
"additional_input": null,
"task_id": "50da533e-3904-4401-8a07-c49adf88b5eb",
"artifacts": [],
"steps": []
},
...
],
"pagination": {
"total": 100,
"pages": 10,
"current": 1,
"pageSize": 10
}
}
"""
agent: "Agent" = request["agent"]
try:
tasks = await agent.list_tasks(page, page_size)
return tasks
except Exception:
logger.exception("Error whilst trying to list tasks")
raise
@base_router.get("/agent/tasks/{task_id}", tags=["agent"], response_model=Task)
async def get_agent_task(request: Request, task_id: str) -> Task:
"""
Gets the details of a task by ID.
Args:
request (Request): FastAPI request object.
task_id (str): The ID of the task.
Returns:
Task: The task with the given ID.
Example:
Request:
GET /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb
Response (Task defined in schema.py):
{
"input": "Write the word 'Washington' to a .txt file",
"additional_input": null,
"task_id": "50da533e-3904-4401-8a07-c49adf88b5eb",
"artifacts": [
{
"artifact_id": "7a49f31c-f9c6-4346-a22c-e32bc5af4d8e",
"file_name": "output.txt",
"agent_created": true,
"relative_path": "file://50da533e-3904-4401-8a07-c49adf88b5eb/output.txt"
}
],
"steps": [
{
"task_id": "50da533e-3904-4401-8a07-c49adf88b5eb",
"step_id": "6bb1801a-fd80-45e8-899a-4dd723cc602e",
"input": "Write the word 'Washington' to a .txt file",
"additional_input": "challenge:write_to_file",
"name": "Write to file",
"status": "completed",
"output": "I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>",
"additional_output": "Do you want me to continue?",
"artifacts": [
{
"artifact_id": "7a49f31c-f9c6-4346-a22c-e32bc5af4d8e",
"file_name": "output.txt",
"agent_created": true,
"relative_path": "file://50da533e-3904-4401-8a07-c49adf88b5eb/output.txt"
}
],
"is_last": true
}
]
}
""" # noqa: E501
agent: "Agent" = request["agent"]
try:
task = await agent.get_task(task_id)
return task
except Exception:
logger.exception(f"Error whilst trying to get task: {task_id}")
raise
@base_router.get(
"/agent/tasks/{task_id}/steps",
tags=["agent"],
response_model=TaskStepsListResponse,
)
async def list_agent_task_steps(
request: Request,
task_id: str,
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, alias="pageSize"),
) -> TaskStepsListResponse:
"""
Retrieves a paginated list of steps associated with a specific task.
Args:
request (Request): FastAPI request object.
task_id (str): The ID of the task.
page (int, optional): The page number for pagination. Defaults to 1.
page_size (int, optional): Number of steps per page for pagination. Default: 10.
Returns:
TaskStepsListResponse: A list of steps, and pagination details.
Example:
Request:
GET /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb/steps?page=1&pageSize=10
Response (TaskStepsListResponse defined in schema.py):
{
"items": [
{
"task_id": "50da533e-3904-4401-8a07-c49adf88b5eb",
"step_id": "step1_id",
...
},
...
],
"pagination": {
"total": 100,
"pages": 10,
"current": 1,
"pageSize": 10
}
}
""" # noqa: E501
agent: "Agent" = request["agent"]
try:
steps = await agent.list_steps(task_id, page, page_size)
return steps
except Exception:
logger.exception("Error whilst trying to list steps")
raise
@base_router.post("/agent/tasks/{task_id}/steps", tags=["agent"], response_model=Step)
async def execute_agent_task_step(
request: Request, task_id: str, step_request: Optional[StepRequestBody] = None
) -> Step:
"""
Executes the next step for a specified task based on the current task status and
returns the executed step with additional feedback fields.
This route is significant because this is where the agent actually performs work.
The function handles executing the next step for a task based on its current state,
and it requires careful implementation to ensure all scenarios (like the presence
or absence of steps or a step marked as `last_step`) are handled correctly.
Depending on the current state of the task, the following scenarios are possible:
1. No steps exist for the task.
2. There is at least one step already for the task, and the task does not have a
completed step marked as `last_step`.
3. There is a completed step marked as `last_step` already on the task.
In each of these scenarios, a step object will be returned with two additional
fields: `output` and `additional_output`.
- `output`: Provides the primary response or feedback to the user.
- `additional_output`: Supplementary information or data. Its specific content is
not strictly defined and can vary based on the step or agent's implementation.
Args:
request (Request): FastAPI request object.
task_id (str): The ID of the task.
step_request (StepRequestBody): The details for executing the step.
Returns:
Step: Details of the executed step with additional feedback.
Example:
Request:
POST /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb/steps
{
"input": "Step input details...",
...
}
Response:
{
"task_id": "50da533e-3904-4401-8a07-c49adf88b5eb",
"step_id": "step1_id",
"output": "Primary feedback...",
"additional_output": "Supplementary details...",
...
}
"""
agent: "Agent" = request["agent"]
try:
# An empty step request represents a yes to continue command
if not step_request:
step_request = StepRequestBody(input="y")
step = await agent.execute_step(task_id, step_request)
return step
except Exception:
logger.exception(f"Error whilst trying to execute a task step: {task_id}")
raise
@base_router.get(
"/agent/tasks/{task_id}/steps/{step_id}", tags=["agent"], response_model=Step
)
async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> Step:
"""
Retrieves the details of a specific step for a given task.
Args:
request (Request): FastAPI request object.
task_id (str): The ID of the task.
step_id (str): The ID of the step.
Returns:
Step: Details of the specific step.
Example:
Request:
GET /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb/steps/step1_id
Response:
{
"task_id": "50da533e-3904-4401-8a07-c49adf88b5eb",
"step_id": "step1_id",
...
}
"""
agent: "Agent" = request["agent"]
try:
step = await agent.get_step(task_id, step_id)
return step
except Exception:
logger.exception(f"Error whilst trying to get step: {step_id}")
raise
@base_router.get(
"/agent/tasks/{task_id}/artifacts",
tags=["agent"],
response_model=TaskArtifactsListResponse,
)
async def list_agent_task_artifacts(
request: Request,
task_id: str,
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, alias="pageSize"),
) -> TaskArtifactsListResponse:
"""
Retrieves a paginated list of artifacts associated with a specific task.
Args:
request (Request): FastAPI request object.
task_id (str): The ID of the task.
page (int, optional): The page number for pagination. Defaults to 1.
page_size (int, optional): Number of items per page for pagination. Default: 10.
Returns:
TaskArtifactsListResponse: A list of artifacts, and pagination details.
Example:
Request:
GET /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb/artifacts?page=1&pageSize=10
Response (TaskArtifactsListResponse defined in schema.py):
{
"items": [
{"artifact_id": "artifact1_id", ...},
{"artifact_id": "artifact2_id", ...},
...
],
"pagination": {
"total": 100,
"pages": 10,
"current": 1,
"pageSize": 10
}
}
""" # noqa: E501
agent: "Agent" = request["agent"]
try:
artifacts = await agent.list_artifacts(task_id, page, page_size)
return artifacts
except Exception:
logger.exception("Error whilst trying to list artifacts")
raise
@base_router.post(
"/agent/tasks/{task_id}/artifacts", tags=["agent"], response_model=Artifact
)
async def upload_agent_task_artifacts(
request: Request, task_id: str, file: UploadFile, relative_path: str = ""
) -> Artifact:
"""
This endpoint is used to upload an artifact (file) associated with a specific task.
Args:
request (Request): The FastAPI request object.
task_id (str): The ID of the task for which the artifact is being uploaded.
file (UploadFile): The file being uploaded as an artifact.
relative_path (str): The relative path for the file. This is a query parameter.
Returns:
Artifact: Metadata object for the uploaded artifact, including its ID and path.
Example:
Request:
POST /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb/artifacts?relative_path=my_folder/my_other_folder
File: <uploaded_file>
Response:
{
"artifact_id": "b225e278-8b4c-4f99-a696-8facf19f0e56",
"created_at": "2023-01-01T00:00:00Z",
"modified_at": "2023-01-01T00:00:00Z",
"agent_created": false,
"relative_path": "/my_folder/my_other_folder/",
"file_name": "main.py"
}
""" # noqa: E501
agent: "Agent" = request["agent"]
if file is None:
raise HTTPException(status_code=400, detail="File must be specified")
try:
artifact = await agent.create_artifact(task_id, file, relative_path)
return artifact
except Exception:
logger.exception(f"Error whilst trying to upload artifact: {task_id}")
raise
@base_router.get(
"/agent/tasks/{task_id}/artifacts/{artifact_id}",
tags=["agent"],
response_model=str,
)
async def download_agent_task_artifact(
request: Request, task_id: str, artifact_id: str
) -> StreamingResponse:
"""
Downloads an artifact associated with a specific task.
Args:
request (Request): FastAPI request object.
task_id (str): The ID of the task.
artifact_id (str): The ID of the artifact.
Returns:
FileResponse: The downloaded artifact file.
Example:
Request:
GET /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb/artifacts/artifact1_id
Response:
<file_content_of_artifact>
"""
agent: "Agent" = request["agent"]
try:
return await agent.get_artifact(task_id, artifact_id)
except Exception:
logger.exception(f"Error whilst trying to download artifact: {task_id}")
raise
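
Taken together, these routes (mounted under the /ap/v1 prefix in agent.py above) can be exercised with any HTTP client. Below is a rough sketch using the requests library against the default http://localhost:8000 host from the agbenchmark config; the payload mirrors the docstring examples, error handling is omitted, and a concrete agent with execute_step implemented is assumed.

import requests

BASE = "http://localhost:8000/ap/v1"

# Create a task
task = requests.post(
    f"{BASE}/agent/tasks",
    json={"input": "Write the word 'Washington' to a .txt file"},
).json()

# Execute one step; sending no body is treated as "y" (continue)
step = requests.post(f"{BASE}/agent/tasks/{task['task_id']}/steps").json()
print(step["output"])

# List the artifacts the task has produced so far
artifacts = requests.get(
    f"{BASE}/agent/tasks/{task['task_id']}/artifacts",
    params={"page": 1, "pageSize": 10},
).json()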


@@ -0,0 +1,3 @@
from .db import AgentDB
__all__ = ["AgentDB"]


@@ -0,0 +1,502 @@
"""
This is an example implementation of the Agent Protocol DB for development purposes.
It uses SQLite as the database and file store backend.
IT IS NOT ADVISED TO USE THIS IN PRODUCTION!
"""
import logging
import math
import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Tuple
from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, create_engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
joinedload,
mapped_column,
relationship,
sessionmaker,
)
from forge.utils.exceptions import NotFoundError
from ..models.artifact import Artifact
from ..models.pagination import Pagination
from ..models.task import Step, StepRequestBody, StepStatus, Task
logger = logging.getLogger(__name__)
class Base(DeclarativeBase):
type_annotation_map = {
dict[str, Any]: JSON,
}
class TaskModel(Base):
__tablename__ = "tasks"
task_id: Mapped[str] = mapped_column(primary_key=True, index=True)
input: Mapped[str]
additional_input: Mapped[dict[str, Any]] = mapped_column(default=dict)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
modified_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)
artifacts = relationship("ArtifactModel", back_populates="task")
class StepModel(Base):
__tablename__ = "steps"
step_id: Mapped[str] = mapped_column(primary_key=True, index=True)
task_id: Mapped[str] = mapped_column(ForeignKey("tasks.task_id"))
name: Mapped[str]
input: Mapped[str]
status: Mapped[str]
output: Mapped[Optional[str]]
is_last: Mapped[bool] = mapped_column(Boolean, default=False)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
modified_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)
additional_input: Mapped[dict[str, Any]] = mapped_column(default=dict)
additional_output: Mapped[Optional[dict[str, Any]]]
artifacts = relationship("ArtifactModel", back_populates="step")
class ArtifactModel(Base):
__tablename__ = "artifacts"
artifact_id: Mapped[str] = mapped_column(primary_key=True, index=True)
task_id: Mapped[str] = mapped_column(ForeignKey("tasks.task_id"))
step_id: Mapped[Optional[str]] = mapped_column(ForeignKey("steps.step_id"))
agent_created: Mapped[bool] = mapped_column(default=False)
file_name: Mapped[str]
relative_path: Mapped[str]
created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
modified_at: Mapped[datetime] = mapped_column(
default=datetime.utcnow, onupdate=datetime.utcnow
)
step = relationship("StepModel", back_populates="artifacts")
task = relationship("TaskModel", back_populates="artifacts")
def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
if debug_enabled:
logger.debug(f"Converting TaskModel to Task for task_id: {task_obj.task_id}")
task_artifacts = [convert_to_artifact(artifact) for artifact in task_obj.artifacts]
return Task(
task_id=task_obj.task_id,
created_at=task_obj.created_at,
modified_at=task_obj.modified_at,
input=task_obj.input,
additional_input=task_obj.additional_input,
artifacts=task_artifacts,
)
def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
if debug_enabled:
logger.debug(f"Converting StepModel to Step for step_id: {step_model.step_id}")
step_artifacts = [
convert_to_artifact(artifact) for artifact in step_model.artifacts
]
status = (
StepStatus.completed if step_model.status == "completed" else StepStatus.created
)
return Step(
task_id=step_model.task_id,
step_id=step_model.step_id,
created_at=step_model.created_at,
modified_at=step_model.modified_at,
name=step_model.name,
input=step_model.input,
status=status,
output=step_model.output,
artifacts=step_artifacts,
is_last=step_model.is_last == 1,
additional_input=step_model.additional_input,
additional_output=step_model.additional_output,
)
def convert_to_artifact(artifact_model: ArtifactModel) -> Artifact:
return Artifact(
artifact_id=artifact_model.artifact_id,
created_at=artifact_model.created_at,
modified_at=artifact_model.modified_at,
agent_created=artifact_model.agent_created,
relative_path=artifact_model.relative_path,
file_name=artifact_model.file_name,
)
# sqlite:///{database_name}
class AgentDB:
def __init__(self, database_string, debug_enabled: bool = False) -> None:
super().__init__()
self.debug_enabled = debug_enabled
if self.debug_enabled:
logger.debug(
f"Initializing AgentDB with database_string: {database_string}"
)
self.engine = create_engine(database_string)
Base.metadata.create_all(self.engine)
self.Session = sessionmaker(bind=self.engine)
def close(self) -> None:
self.Session.close_all()
self.engine.dispose()
async def create_task(
self, input: Optional[str], additional_input: Optional[dict] = {}
) -> Task:
if self.debug_enabled:
logger.debug("Creating new task")
try:
with self.Session() as session:
new_task = TaskModel(
task_id=str(uuid.uuid4()),
input=input,
additional_input=additional_input if additional_input else {},
)
session.add(new_task)
session.commit()
session.refresh(new_task)
if self.debug_enabled:
logger.debug(f"Created new task with task_id: {new_task.task_id}")
return convert_to_task(new_task, self.debug_enabled)
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while creating task: {e}")
raise
except Exception as e:
logger.error(f"Unexpected error while creating task: {e}")
raise
async def create_step(
self,
task_id: str,
input: StepRequestBody,
is_last: bool = False,
additional_input: Optional[Dict[str, Any]] = {},
) -> Step:
if self.debug_enabled:
logger.debug(f"Creating new step for task_id: {task_id}")
try:
with self.Session() as session:
new_step = StepModel(
task_id=task_id,
step_id=str(uuid.uuid4()),
name=input.input,
input=input.input,
status="created",
is_last=is_last,
additional_input=additional_input,
)
session.add(new_step)
session.commit()
session.refresh(new_step)
if self.debug_enabled:
logger.debug(f"Created new step with step_id: {new_step.step_id}")
return convert_to_step(new_step, self.debug_enabled)
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while creating step: {e}")
raise
except Exception as e:
logger.error(f"Unexpected error while creating step: {e}")
raise
async def create_artifact(
self,
task_id: str,
file_name: str,
relative_path: str,
agent_created: bool = False,
step_id: str | None = None,
) -> Artifact:
if self.debug_enabled:
logger.debug(f"Creating new artifact for task_id: {task_id}")
try:
with self.Session() as session:
if (
existing_artifact := session.query(ArtifactModel)
.filter_by(
task_id=task_id,
file_name=file_name,
relative_path=relative_path,
)
.first()
):
session.close()
if self.debug_enabled:
logger.debug(
f"Artifact {file_name} already exists at {relative_path}/"
)
return convert_to_artifact(existing_artifact)
new_artifact = ArtifactModel(
artifact_id=str(uuid.uuid4()),
task_id=task_id,
step_id=step_id,
agent_created=agent_created,
file_name=file_name,
relative_path=relative_path,
)
session.add(new_artifact)
session.commit()
session.refresh(new_artifact)
if self.debug_enabled:
logger.debug(
f"Created new artifact with ID: {new_artifact.artifact_id}"
)
return convert_to_artifact(new_artifact)
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while creating step: {e}")
raise
except Exception as e:
logger.error(f"Unexpected error while creating step: {e}")
raise
async def get_task(self, task_id: str) -> Task:
"""Get a task by its id"""
if self.debug_enabled:
logger.debug(f"Getting task with task_id: {task_id}")
try:
with self.Session() as session:
if task_obj := (
session.query(TaskModel)
.options(joinedload(TaskModel.artifacts))
.filter_by(task_id=task_id)
.first()
):
return convert_to_task(task_obj, self.debug_enabled)
else:
logger.error(f"Task not found with task_id: {task_id}")
raise NotFoundError("Task not found")
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while getting task: {e}")
raise
except Exception as e:
logger.error(f"Unexpected error while getting task: {e}")
raise
async def get_step(self, task_id: str, step_id: str) -> Step:
if self.debug_enabled:
logger.debug(f"Getting step with task_id: {task_id} and step_id: {step_id}")
try:
with self.Session() as session:
if step := (
session.query(StepModel)
.options(joinedload(StepModel.artifacts))
.filter(StepModel.step_id == step_id)
.first()
):
return convert_to_step(step, self.debug_enabled)
else:
logger.error(
f"Step not found with task_id: {task_id} and step_id: {step_id}"
)
raise NotFoundError("Step not found")
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while getting step: {e}")
raise
except Exception as e:
logger.error(f"Unexpected error while getting step: {e}")
raise
async def get_artifact(self, artifact_id: str) -> Artifact:
if self.debug_enabled:
logger.debug(f"Getting artifact with and artifact_id: {artifact_id}")
try:
with self.Session() as session:
if (
artifact_model := session.query(ArtifactModel)
.filter_by(artifact_id=artifact_id)
.first()
):
return convert_to_artifact(artifact_model)
else:
logger.error(
f"Artifact not found with and artifact_id: {artifact_id}"
)
raise NotFoundError("Artifact not found")
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while getting artifact: {e}")
raise
except Exception as e:
logger.error(f"Unexpected error while getting artifact: {e}")
raise
async def update_step(
self,
task_id: str,
step_id: str,
status: Optional[str] = None,
output: Optional[str] = None,
additional_input: Optional[Dict[str, Any]] = None,
additional_output: Optional[Dict[str, Any]] = None,
) -> Step:
if self.debug_enabled:
logger.debug(
f"Updating step with task_id: {task_id} and step_id: {step_id}"
)
try:
with self.Session() as session:
if (
step := session.query(StepModel)
.filter_by(task_id=task_id, step_id=step_id)
.first()
):
if status is not None:
step.status = status
if additional_input is not None:
step.additional_input = additional_input
if output is not None:
step.output = output
if additional_output is not None:
step.additional_output = additional_output
session.commit()
return await self.get_step(task_id, step_id)
else:
logger.error(
"Can't update non-existent Step with "
f"task_id: {task_id} and step_id: {step_id}"
)
raise NotFoundError("Step not found")
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while getting step: {e}")
raise
except Exception as e:
logger.error(f"Unexpected error while getting step: {e}")
raise
async def update_artifact(
self,
artifact_id: str,
*,
file_name: str = "",
relative_path: str = "",
agent_created: Optional[Literal[True]] = None,
) -> Artifact:
logger.debug(f"Updating artifact with artifact_id: {artifact_id}")
with self.Session() as session:
if (
artifact := session.query(ArtifactModel)
.filter_by(artifact_id=artifact_id)
.first()
):
if file_name:
artifact.file_name = file_name
if relative_path:
artifact.relative_path = relative_path
if agent_created:
artifact.agent_created = agent_created
session.commit()
return await self.get_artifact(artifact_id)
else:
logger.error(f"Artifact not found with artifact_id: {artifact_id}")
raise NotFoundError("Artifact not found")
async def list_tasks(
self, page: int = 1, per_page: int = 10
) -> Tuple[List[Task], Pagination]:
if self.debug_enabled:
logger.debug("Listing tasks")
try:
with self.Session() as session:
tasks = (
session.query(TaskModel)
.offset((page - 1) * per_page)
.limit(per_page)
.all()
)
total = session.query(TaskModel).count()
pages = math.ceil(total / per_page)
pagination = Pagination(
total_items=total,
total_pages=pages,
current_page=page,
page_size=per_page,
)
return [
convert_to_task(task, self.debug_enabled) for task in tasks
], pagination
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while listing tasks: {e}")
raise
except Exception as e:
logger.error(f"Unexpected error while listing tasks: {e}")
raise
async def list_steps(
self, task_id: str, page: int = 1, per_page: int = 10
) -> Tuple[List[Step], Pagination]:
if self.debug_enabled:
logger.debug(f"Listing steps for task_id: {task_id}")
try:
with self.Session() as session:
steps = (
session.query(StepModel)
.filter_by(task_id=task_id)
.offset((page - 1) * per_page)
.limit(per_page)
.all()
)
total = session.query(StepModel).filter_by(task_id=task_id).count()
pages = math.ceil(total / per_page)
pagination = Pagination(
total_items=total,
total_pages=pages,
current_page=page,
page_size=per_page,
)
return [
convert_to_step(step, self.debug_enabled) for step in steps
], pagination
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while listing steps: {e}")
raise
except Exception as e:
logger.error(f"Unexpected error while listing steps: {e}")
raise
async def list_artifacts(
self, task_id: str, page: int = 1, per_page: int = 10
) -> Tuple[List[Artifact], Pagination]:
if self.debug_enabled:
logger.debug(f"Listing artifacts for task_id: {task_id}")
try:
with self.Session() as session:
artifacts = (
session.query(ArtifactModel)
.filter_by(task_id=task_id)
.offset((page - 1) * per_page)
.limit(per_page)
.all()
)
total = session.query(ArtifactModel).filter_by(task_id=task_id).count()
pages = math.ceil(total / per_page)
pagination = Pagination(
total_items=total,
total_pages=pages,
current_page=page,
page_size=per_page,
)
return [
convert_to_artifact(artifact) for artifact in artifacts
], pagination
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while listing artifacts: {e}")
raise
except Exception as e:
logger.error(f"Unexpected error while listing artifacts: {e}")
raise


@@ -0,0 +1,313 @@
import os
import sqlite3
from datetime import datetime
import pytest
from forge.agent_protocol.database.db import (
AgentDB,
ArtifactModel,
StepModel,
TaskModel,
convert_to_artifact,
convert_to_step,
convert_to_task,
)
from forge.agent_protocol.models import (
Artifact,
Step,
StepRequestBody,
StepStatus,
Task,
)
from forge.utils.exceptions import NotFoundError as DataNotFoundError
TEST_DB_FILENAME = "test_db.sqlite3"
TEST_DB_URL = f"sqlite:///{TEST_DB_FILENAME}"
@pytest.fixture
def agent_db():
db = AgentDB(TEST_DB_URL)
yield db
db.close()
os.remove(TEST_DB_FILENAME)
@pytest.fixture
def raw_db_connection(agent_db: AgentDB):
connection = sqlite3.connect(TEST_DB_FILENAME)
yield connection
connection.close()
def test_table_creation(raw_db_connection: sqlite3.Connection):
cursor = raw_db_connection.cursor()
# Test for tasks table existence
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='tasks'")
assert cursor.fetchone() is not None
# Test for steps table existence
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='steps'")
assert cursor.fetchone() is not None
# Test for artifacts table existence
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='artifacts'"
)
assert cursor.fetchone() is not None
@pytest.mark.asyncio
async def test_task_schema():
now = datetime.now()
task = Task(
task_id="50da533e-3904-4401-8a07-c49adf88b5eb",
input="Write the words you receive to the file 'output.txt'.",
created_at=now,
modified_at=now,
artifacts=[
Artifact(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
agent_created=True,
file_name="main.py",
relative_path="python/code/",
created_at=now,
modified_at=now,
)
],
)
assert task.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb"
assert task.input == "Write the words you receive to the file 'output.txt'."
assert len(task.artifacts) == 1
assert task.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
@pytest.mark.asyncio
async def test_step_schema():
now = datetime.now()
step = Step(
task_id="50da533e-3904-4401-8a07-c49adf88b5eb",
step_id="6bb1801a-fd80-45e8-899a-4dd723cc602e",
created_at=now,
modified_at=now,
name="Write to file",
input="Write the words you receive to the file 'output.txt'.",
status=StepStatus.created,
output=(
"I am going to use the write_to_file command and write Washington "
"to a file called output.txt <write_to_file('output.txt', 'Washington')>"
),
artifacts=[
Artifact(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
file_name="main.py",
relative_path="python/code/",
created_at=now,
modified_at=now,
agent_created=True,
)
],
is_last=False,
)
assert step.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb"
assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e"
assert step.name == "Write to file"
assert step.status == StepStatus.created
assert step.output == (
"I am going to use the write_to_file command and write Washington "
"to a file called output.txt <write_to_file('output.txt', 'Washington')>"
)
assert len(step.artifacts) == 1
assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
assert step.is_last is False
@pytest.mark.asyncio
async def test_convert_to_task():
now = datetime.now()
task_model = TaskModel(
task_id="50da533e-3904-4401-8a07-c49adf88b5eb",
created_at=now,
modified_at=now,
input="Write the words you receive to the file 'output.txt'.",
additional_input={},
artifacts=[
ArtifactModel(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
created_at=now,
modified_at=now,
relative_path="file:///path/to/main.py",
agent_created=True,
file_name="main.py",
)
],
)
task = convert_to_task(task_model)
assert task.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb"
assert task.input == "Write the words you receive to the file 'output.txt'."
assert len(task.artifacts) == 1
assert task.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
@pytest.mark.asyncio
async def test_convert_to_step():
now = datetime.now()
step_model = StepModel(
task_id="50da533e-3904-4401-8a07-c49adf88b5eb",
step_id="6bb1801a-fd80-45e8-899a-4dd723cc602e",
created_at=now,
modified_at=now,
name="Write to file",
status="created",
input="Write the words you receive to the file 'output.txt'.",
additional_input={},
artifacts=[
ArtifactModel(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
created_at=now,
modified_at=now,
relative_path="file:///path/to/main.py",
agent_created=True,
file_name="main.py",
)
],
is_last=False,
)
step = convert_to_step(step_model)
assert step.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb"
assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e"
assert step.name == "Write to file"
assert step.status == StepStatus.created
assert len(step.artifacts) == 1
assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
assert step.is_last is False
@pytest.mark.asyncio
async def test_convert_to_artifact():
now = datetime.now()
artifact_model = ArtifactModel(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
created_at=now,
modified_at=now,
relative_path="file:///path/to/main.py",
agent_created=True,
file_name="main.py",
)
artifact = convert_to_artifact(artifact_model)
assert artifact.artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
assert artifact.relative_path == "file:///path/to/main.py"
assert artifact.agent_created is True
@pytest.mark.asyncio
async def test_create_task(agent_db: AgentDB):
task = await agent_db.create_task("task_input")
assert task.input == "task_input"
@pytest.mark.asyncio
async def test_create_and_get_task(agent_db: AgentDB):
task = await agent_db.create_task("test_input")
fetched_task = await agent_db.get_task(task.task_id)
assert fetched_task.input == "test_input"
@pytest.mark.asyncio
async def test_get_task_not_found(agent_db: AgentDB):
with pytest.raises(DataNotFoundError):
await agent_db.get_task("9999")
@pytest.mark.asyncio
async def test_create_and_get_step(agent_db: AgentDB):
task = await agent_db.create_task("task_input")
step_input = {"type": "python/code"}
request = StepRequestBody(input="test_input debug", additional_input=step_input)
step = await agent_db.create_step(task.task_id, request)
step = await agent_db.get_step(task.task_id, step.step_id)
assert step.input == "test_input debug"
@pytest.mark.asyncio
async def test_updating_step(agent_db: AgentDB):
created_task = await agent_db.create_task("task_input")
step_input = {"type": "python/code"}
request = StepRequestBody(input="test_input debug", additional_input=step_input)
created_step = await agent_db.create_step(created_task.task_id, request)
await agent_db.update_step(created_task.task_id, created_step.step_id, "completed")
step = await agent_db.get_step(created_task.task_id, created_step.step_id)
assert step.status.value == "completed"
@pytest.mark.asyncio
async def test_get_step_not_found(agent_db: AgentDB):
with pytest.raises(DataNotFoundError):
await agent_db.get_step("9999", "9999")
@pytest.mark.asyncio
async def test_get_artifact(agent_db: AgentDB):
# Given: A task and its corresponding artifact
task = await agent_db.create_task("test_input debug")
step_input = {"type": "python/code"}
request = StepRequestBody(input="test_input debug", additional_input=step_input)
step = await agent_db.create_step(task.task_id, request)
# Create an artifact
artifact = await agent_db.create_artifact(
task_id=task.task_id,
file_name="test_get_artifact_sample_file.txt",
relative_path="file:///path/to/test_get_artifact_sample_file.txt",
agent_created=True,
step_id=step.step_id,
)
# When: The artifact is fetched by its ID
fetched_artifact = await agent_db.get_artifact(artifact.artifact_id)
# Then: The fetched artifact matches the original
assert fetched_artifact.artifact_id == artifact.artifact_id
assert (
fetched_artifact.relative_path
== "file:///path/to/test_get_artifact_sample_file.txt"
)
@pytest.mark.asyncio
async def test_list_tasks(agent_db: AgentDB):
# Given: Multiple tasks in the database
task1 = await agent_db.create_task("test_input_1")
task2 = await agent_db.create_task("test_input_2")
# When: All tasks are fetched
fetched_tasks, pagination = await agent_db.list_tasks()
# Then: The fetched tasks list includes the created tasks
task_ids = [task.task_id for task in fetched_tasks]
assert task1.task_id in task_ids
assert task2.task_id in task_ids
@pytest.mark.asyncio
async def test_list_steps(agent_db: AgentDB):
step_input = {"type": "python/code"}
request = StepRequestBody(input="test_input debug", additional_input=step_input)
# Given: A task and multiple steps for that task
task = await agent_db.create_task("test_input")
step1 = await agent_db.create_step(task.task_id, request)
request = StepRequestBody(input="step two")
step2 = await agent_db.create_step(task.task_id, request)
# When: All steps for the task are fetched
fetched_steps, pagination = await agent_db.list_steps(task.task_id)
# Then: The fetched steps list includes the created steps
step_ids = [step.step_id for step in fetched_steps]
assert step1.step_id in step_ids
assert step2.step_id in step_ids

View File

@@ -0,0 +1,34 @@
from starlette.types import ASGIApp
class AgentMiddleware:
"""
Middleware that injects the agent instance into the request scope.
"""
def __init__(self, app: ASGIApp, agent):
"""
Args:
app: The FastAPI app - automatically injected by FastAPI.
agent: The agent instance to inject into the request scope.
Examples:
>>> from fastapi import FastAPI, Request
>>> from agent_protocol.agent import Agent
>>> from agent_protocol.middlewares import AgentMiddleware
>>> app = FastAPI()
>>> @app.get("/")
>>> async def root(request: Request):
>>> agent = request["agent"]
>>> task = agent.db.create_task("Do something.")
>>> return {"task_id": task.task_id}
>>> agent = Agent()
>>> app.add_middleware(AgentMiddleware, agent=agent)
"""
self.app = app
self.agent = agent
async def __call__(self, scope, receive, send):
scope["agent"] = self.agent
await self.app(scope, receive, send)

View File

@@ -0,0 +1,25 @@
from .artifact import Artifact
from .pagination import Pagination
from .task import (
Step,
StepRequestBody,
StepStatus,
Task,
TaskArtifactsListResponse,
TaskListResponse,
TaskRequestBody,
TaskStepsListResponse,
)
__all__ = [
"Artifact",
"Pagination",
"Step",
"StepRequestBody",
"StepStatus",
"Task",
"TaskArtifactsListResponse",
"TaskListResponse",
"TaskRequestBody",
"TaskStepsListResponse",
]

View File

@@ -0,0 +1,38 @@
from datetime import datetime
from pydantic import BaseModel, Field
class Artifact(BaseModel):
created_at: datetime = Field(
...,
description="The creation datetime of the task.",
example="2023-01-01T00:00:00Z",
json_encoders={datetime: lambda v: v.isoformat()},
)
modified_at: datetime = Field(
...,
description="The modification datetime of the task.",
example="2023-01-01T00:00:00Z",
json_encoders={datetime: lambda v: v.isoformat()},
)
artifact_id: str = Field(
...,
description="ID of the artifact.",
example="b225e278-8b4c-4f99-a696-8facf19f0e56",
)
agent_created: bool = Field(
...,
description="Whether the artifact has been created by the agent.",
example=False,
)
relative_path: str = Field(
...,
description="Relative path of the artifact in the agents workspace.",
example="/my_folder/my_other_folder/",
)
file_name: str = Field(
...,
description="Filename of the artifact.",
example="main.py",
)

View File

@@ -0,0 +1,8 @@
from pydantic import BaseModel, Field
class Pagination(BaseModel):
total_items: int = Field(..., description="Total number of items.", example=42)
total_pages: int = Field(..., description="Total number of pages.", example=97)
current_page: int = Field(..., description="Current page number.", example=1)
page_size: int = Field(..., description="Number of items per page.", example=25)
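
The `Pagination` model only carries metadata about a result set. A minimal sketch of how its fields could be derived alongside a page of results; the `paginate` helper and its defaults are illustrative and not part of this diff:

```py
import math

from forge.agent_protocol.models import Pagination


def paginate(items: list, page: int = 1, page_size: int = 25) -> tuple[list, Pagination]:
    """Slice one page out of `items` and build the matching Pagination metadata."""
    total_items = len(items)
    total_pages = max(1, math.ceil(total_items / page_size))
    start = (page - 1) * page_size
    return items[start : start + page_size], Pagination(
        total_items=total_items,
        total_pages=total_pages,
        current_page=page,
        page_size=page_size,
    )
```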

View File

@@ -0,0 +1,126 @@
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Any, List, Optional
from pydantic import BaseModel, Field
from .artifact import Artifact
from .pagination import Pagination
class TaskRequestBody(BaseModel):
input: str = Field(
...,
min_length=1,
description="Input prompt for the task.",
example="Write the words you receive to the file 'output.txt'.",
)
additional_input: dict[str, Any] = Field(default_factory=dict)
class Task(TaskRequestBody):
created_at: datetime = Field(
...,
description="The creation datetime of the task.",
example="2023-01-01T00:00:00Z",
json_encoders={datetime: lambda v: v.isoformat()},
)
modified_at: datetime = Field(
...,
description="The modification datetime of the task.",
example="2023-01-01T00:00:00Z",
json_encoders={datetime: lambda v: v.isoformat()},
)
task_id: str = Field(
...,
description="The ID of the task.",
example="50da533e-3904-4401-8a07-c49adf88b5eb",
)
artifacts: list[Artifact] = Field(
default_factory=list,
description="A list of artifacts that the task has produced.",
example=[
"7a49f31c-f9c6-4346-a22c-e32bc5af4d8e",
"ab7b4091-2560-4692-a4fe-d831ea3ca7d6",
],
)
class StepRequestBody(BaseModel):
name: Optional[str] = Field(
default=None, description="The name of the task step.", example="Write to file"
)
input: str = Field(
..., description="Input prompt for the step.", example="Washington"
)
additional_input: dict[str, Any] = Field(default_factory=dict)
class StepStatus(Enum):
created = "created"
running = "running"
completed = "completed"
class Step(StepRequestBody):
created_at: datetime = Field(
...,
description="The creation datetime of the task.",
example="2023-01-01T00:00:00Z",
json_encoders={datetime: lambda v: v.isoformat()},
)
modified_at: datetime = Field(
...,
description="The modification datetime of the task.",
example="2023-01-01T00:00:00Z",
json_encoders={datetime: lambda v: v.isoformat()},
)
task_id: str = Field(
...,
description="The ID of the task this step belongs to.",
example="50da533e-3904-4401-8a07-c49adf88b5eb",
)
step_id: str = Field(
...,
description="The ID of the task step.",
example="6bb1801a-fd80-45e8-899a-4dd723cc602e",
)
name: Optional[str] = Field(
default=None, description="The name of the task step.", example="Write to file"
)
status: StepStatus = Field(
..., description="The status of the task step.", example="created"
)
output: Optional[str] = Field(
default=None,
description="Output of the task step.",
example=(
"I am going to use the write_to_file command and write Washington "
"to a file called output.txt <write_to_file('output.txt', 'Washington')"
),
)
additional_output: Optional[dict[str, Any]] = None
artifacts: list[Artifact] = Field(
default_factory=list,
description="A list of artifacts that the step has produced.",
)
is_last: bool = Field(
..., description="Whether this is the last step in the task.", example=True
)
class TaskListResponse(BaseModel):
tasks: Optional[List[Task]] = None
pagination: Optional[Pagination] = None
class TaskStepsListResponse(BaseModel):
steps: Optional[List[Step]] = None
pagination: Optional[Pagination] = None
class TaskArtifactsListResponse(BaseModel):
artifacts: Optional[List[Artifact]] = None
pagination: Optional[Pagination] = None

View File

@@ -0,0 +1,5 @@
from .command import Command
from .decorator import command
from .parameter import CommandParameter
__all__ = ["Command", "CommandParameter", "command"]

View File

@@ -0,0 +1,95 @@
from __future__ import annotations
import inspect
from typing import Callable, Concatenate, Generic, ParamSpec, TypeVar, cast
from forge.agent.protocols import CommandProvider
from .parameter import CommandParameter
P = ParamSpec("P")
CO = TypeVar("CO") # command output
_CP = TypeVar("_CP", bound=CommandProvider)
class Command(Generic[P, CO]):
"""A class representing a command.
Attributes:
name (str): The name of the command.
description (str): A brief description of what the command does.
parameters (list): The parameters of the function that the command executes.
"""
def __init__(
self,
names: list[str],
description: str,
method: Callable[Concatenate[_CP, P], CO],
parameters: list[CommandParameter],
):
# Check if all parameters are provided
if not self._parameters_match(method, parameters):
raise ValueError(
f"Command {names[0]} has different parameters than provided schema"
)
self.names = names
self.description = description
# Method technically has a `self` parameter, but we can ignore that
# since Python passes it internally.
self.method = cast(Callable[P, CO], method)
self.parameters = parameters
@property
def is_async(self) -> bool:
return inspect.iscoroutinefunction(self.method)
@property
def return_type(self) -> str | None:
annotation = inspect.signature(self.method).return_annotation
if annotation == inspect.Signature.empty:
return None
return annotation.__name__
def _parameters_match(
self, func: Callable, parameters: list[CommandParameter]
) -> bool:
# Get the function's signature
signature = inspect.signature(func)
# Extract parameter names, ignoring 'self' for methods
func_param_names = [
param.name
for param in signature.parameters.values()
if param.name != "self"
]
names = [param.name for param in parameters]
# Check if sorted lists of names/keys are equal
return sorted(func_param_names) == sorted(names)
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> CO:
return self.method(*args, **kwargs)
def __str__(self) -> str:
params = [
f"{param.name}: "
+ ("%s" if param.spec.required else "Optional[%s]")
% (param.spec.type.value if param.spec.type else "Any")
for param in self.parameters
]
return (
f"{self.names[0]}: {self.description.rstrip('.')}. "
f"Params: ({', '.join(params)})"
)
def __get__(self, instance, owner):
if instance is None:
# Accessed on the class, not an instance
return self
# Bind the method to the instance
return Command(
self.names,
self.description,
self.method.__get__(instance, owner),
self.parameters,
)

View File

@@ -0,0 +1,60 @@
import re
from typing import Callable, Concatenate, Optional, TypeVar
from forge.agent.protocols import CommandProvider
from forge.models.json_schema import JSONSchema
from .command import CO, Command, CommandParameter, P
_CP = TypeVar("_CP", bound=CommandProvider)
def command(
names: list[str] = [],
description: Optional[str] = None,
parameters: dict[str, JSONSchema] = {},
) -> Callable[[Callable[Concatenate[_CP, P], CO]], Command[P, CO]]:
"""
The command decorator is used to make a Command from a function.
Args:
names (list[str]): The names of the command.
If not provided, the function name will be used.
description (str): A brief description of what the command does.
If not provided, the docstring until double line break will be used
(or entire docstring if no double line break is found)
parameters (dict[str, JSONSchema]): The parameters of the function
that the command executes.
"""
def decorator(func: Callable[Concatenate[_CP, P], CO]) -> Command[P, CO]:
doc = func.__doc__ or ""
# If names is not provided, use the function name
command_names = names or [func.__name__]
# If description is not provided, use the first part of the docstring
if not (command_description := description):
if not func.__doc__:
raise ValueError("Description is required if function has no docstring")
# Return the part of the docstring before double line break or everything
command_description = re.sub(r"\s+", " ", doc.split("\n\n")[0].strip())
# Parameters
typed_parameters = [
CommandParameter(
name=param_name,
spec=spec,
)
for param_name, spec in parameters.items()
]
# Wrap func with Command
command = Command(
names=command_names,
description=command_description,
method=func,
parameters=typed_parameters,
)
return command
return decorator
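
A minimal usage sketch for the decorator above, inside a `CommandProvider` as used elsewhere in this diff; the `GreetingComponent` and its schema values are made up for illustration:

```py
from typing import Iterator

from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.models.json_schema import JSONSchema


class GreetingComponent(CommandProvider):
    """Hypothetical component exposing a single command."""

    def get_commands(self) -> Iterator[Command]:
        yield self.greet

    @command(
        names=["greet", "say_hello"],
        parameters={
            "name": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="Who to greet",
                required=True,
            )
        },
    )
    def greet(self, name: str) -> str:
        """Greet somebody by name."""
        return f"Hello, {name}!"
```

Since no `description` is passed, the docstring is used as the command description, and the parameter names in the schema must match the function signature (minus `self`) or `Command.__init__` raises a `ValueError`.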

View File

@@ -0,0 +1,16 @@
from pydantic import BaseModel
from forge.models.json_schema import JSONSchema
class CommandParameter(BaseModel):
name: str
spec: JSONSchema
def __repr__(self):
return "CommandParameter('%s', '%s', '%s', %s)" % (
self.name,
self.spec.type,
self.spec.description,
self.spec.required,
)

View File

@@ -0,0 +1,137 @@
# 🧩 Components
Components are the building blocks of [🤖 Agents](./agents.md). They are classes inheriting `AgentComponent` or implementing one or more [⚙️ Protocols](./protocols.md) that give the agent additional abilities or processing.
Components can be used to implement various functionalities like providing messages to the prompt, executing code, or interacting with external services.
They can be enabled or disabled, ordered, and can rely on each other.
Components assigned in the agent's `__init__` via `self` are automatically detected upon the agent's instantiation.
For example inside `__init__`: `self.my_component = MyComponent()`.
You can use any valid Python variable name; what matters for the component to be detected is its type (`AgentComponent` or any protocol inheriting from it).
Visit [Built-in Components](./built-in-components.md) to see what components are available out of the box.
```py
from forge.agent import BaseAgent
from forge.agent.components import AgentComponent
class HelloComponent(AgentComponent):
pass
class SomeComponent(AgentComponent):
def __init__(self, hello_component: HelloComponent):
self.hello_component = hello_component
class MyAgent(BaseAgent):
def __init__(self):
# These components will be automatically discovered and used
self.hello_component = HelloComponent()
# We pass HelloComponent to SomeComponent
self.some_component = SomeComponent(self.hello_component)
```
## Ordering components
The execution order of components is important because some may depend on the results of the previous ones.
**By default, components are ordered alphabetically.**
### Ordering individual components
You can order a single component by passing other components (or their types) to the `run_after` method. This way you can ensure that the component will be executed after the specified one.
The `run_after` method returns the component itself, so you can call it when assigning the component to a variable:
```py
class MyAgent(Agent):
def __init__(self):
self.hello_component = HelloComponent()
self.calculator_component = CalculatorComponent().run_after(self.hello_component)
# This is equivalent to passing a type:
# self.calculator_component = CalculatorComponent().run_after(HelloComponent)
```
!!! warning
Be sure not to make circular dependencies when ordering components!
### Ordering all components
You can also order all components by setting the `self.components` list in the agent's `__init__` method.
This ensures there are no circular dependencies; any `run_after` calls are ignored.
!!! warning
Be sure to include all components - by setting the `self.components` list, you're overriding the default behavior of discovering components automatically. Since skipping components is usually unintended, the agent will inform you in the terminal if any components were skipped.
```py
class MyAgent(Agent):
def __init__(self):
self.hello_component = HelloComponent()
self.calculator_component = CalculatorComponent()
# Explicitly set components list
self.components = [self.hello_component, self.calculator_component]
```
## Disabling components
You can control which components are enabled by setting their `_enabled` attribute.
Either provide a `bool` value or a `Callable[[], bool]`; it will be checked each time
the component is about to be executed. This way you can dynamically enable or disable
components based on some conditions.
You can also provide a reason for disabling the component by setting `_disabled_reason`.
The reason will be visible in the debug information.
```py
class DisabledComponent(MessageProvider):
def __init__(self):
# Disable this component
self._enabled = False
self._disabled_reason = "This component is disabled because of reasons."
# Or disable based on some condition, either statically...:
self._enabled = self.some_property is not None
# ... or dynamically:
self._enabled = lambda: self.some_property is not None
# This method will never be called
def get_messages(self) -> Iterator[ChatMessage]:
yield ChatMessage.user("This message won't be seen!")
def some_condition(self) -> bool:
return False
```
If you don't want the component at all, you can just remove it from the agent's `__init__` method. If you want to remove components that you inherit from the parent class, you can set the relevant attribute to `None`:
!!! Warning
Be careful when removing components that are required by other components. This may lead to errors and unexpected behavior.
```py
class MyAgent(Agent):
def __init__(self):
super().__init__(...)
# Disable WatchdogComponent that is in the parent class
self.watchdog = None
```
## Exceptions
Custom errors are provided to control the execution flow when something goes wrong. All of these errors can be raised in protocol methods and will be caught by the agent.
By default, the agent will retry three times and then re-raise the exception if it's still not resolved. All passed arguments are handled automatically and their values are reverted when needed.
All errors accept an optional `str` message. The following errors are ordered by increasing broadness:
1. `ComponentEndpointError`: A single endpoint method failed to execute. The agent will retry the execution of this endpoint on the component.
2. `EndpointPipelineError`: A pipeline failed to execute. The agent will retry the execution of the endpoint for all components.
3. `ComponentSystemError`: Multiple pipelines failed.
**Example**
```py
from forge.agent.components import ComponentEndpointError
from forge.agent.protocols import MessageProvider
# Example of raising an error
class MyComponent(MessageProvider):
def get_messages(self) -> Iterator[ChatMessage]:
# This will cause the component to always fail
# and retry 3 times before re-raising the exception
raise ComponentEndpointError("Endpoint error!")
```

View File

@@ -0,0 +1,4 @@
from .action_history import ActionHistoryComponent
from .model import Episode, EpisodicActionHistory
__all__ = ["ActionHistoryComponent", "Episode", "EpisodicActionHistory"]

View File

@@ -0,0 +1,81 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable, Iterator, Optional
from forge.agent.protocols import AfterExecute, AfterParse, MessageProvider
from forge.llm.prompting.utils import indent
from forge.llm.providers import ChatMessage, MultiProvider
if TYPE_CHECKING:
from forge.config.config import Config
from .model import ActionResult, AnyProposal, Episode, EpisodicActionHistory
class ActionHistoryComponent(MessageProvider, AfterParse[AnyProposal], AfterExecute):
"""Keeps track of the event history and provides a summary of the steps."""
def __init__(
self,
event_history: EpisodicActionHistory[AnyProposal],
max_tokens: int,
count_tokens: Callable[[str], int],
legacy_config: Config,
llm_provider: MultiProvider,
) -> None:
self.event_history = event_history
self.max_tokens = max_tokens
self.count_tokens = count_tokens
self.legacy_config = legacy_config
self.llm_provider = llm_provider
def get_messages(self) -> Iterator[ChatMessage]:
if progress := self._compile_progress(
self.event_history.episodes,
self.max_tokens,
self.count_tokens,
):
yield ChatMessage.system(
f"## Progress on your Task so far\nThis is the list of the steps that you have executed previously, use this as your consideration on considering the next action!\n{progress}"
)
def after_parse(self, result: AnyProposal) -> None:
self.event_history.register_action(result)
async def after_execute(self, result: ActionResult) -> None:
self.event_history.register_result(result)
await self.event_history.handle_compression(
self.llm_provider, self.legacy_config
)
def _compile_progress(
self,
episode_history: list[Episode[AnyProposal]],
max_tokens: Optional[int] = None,
count_tokens: Optional[Callable[[str], int]] = None,
) -> str:
if max_tokens and not count_tokens:
raise ValueError("count_tokens is required if max_tokens is set")
steps: list[str] = []
tokens: int = 0
n_episodes = len(episode_history)
for i, episode in enumerate(reversed(episode_history)):
# Use full format for the latest 4 steps; use the summary for older steps (or full format if no summary exists)
if i < 4 or episode.summary is None:
step_content = indent(episode.format(), 2).strip()
else:
step_content = episode.summary
step = f"* Step {n_episodes - i}: {step_content}"
if max_tokens and count_tokens:
step_tokens = count_tokens(step)
if tokens + step_tokens > max_tokens:
break
tokens += step_tokens
steps.insert(0, step)
return "\n\n".join(steps)
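
A rough sketch of wiring the component above into an agent. The import paths follow the package layout in this diff, while `MyAgent`, the constructor arguments, and the whitespace-based token counter are illustrative assumptions:

```py
from forge.agent import BaseAgent
from forge.components.action_history import (
    ActionHistoryComponent,
    EpisodicActionHistory,
)


class MyAgent(BaseAgent):
    def __init__(self, config, llm_provider):
        # Shared history container; the component registers proposals after
        # parsing and results (plus summaries) after execution.
        self.event_history = EpisodicActionHistory()
        self.history = ActionHistoryComponent(
            event_history=self.event_history,
            max_tokens=1024,  # budget for the progress message in the prompt
            count_tokens=lambda text: len(text.split()),  # crude stand-in tokenizer
            legacy_config=config,
            llm_provider=llm_provider,
        )
```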

View File

@@ -0,0 +1,155 @@
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING, Generic
from pydantic import Field
from pydantic.generics import GenericModel
from forge.content_processing.text import summarize_text
from forge.llm.prompting.utils import format_numbered_list, indent
from forge.models.action import ActionResult, AnyProposal
from forge.models.utils import ModelWithSummary
if TYPE_CHECKING:
from forge.config.config import Config
from forge.llm.providers import MultiProvider
class Episode(GenericModel, Generic[AnyProposal]):
action: AnyProposal
result: ActionResult | None
summary: str | None = None
def format(self):
step = f"Executed `{self.action.use_tool}`\n"
reasoning = (
_r.summary()
if isinstance(_r := self.action.thoughts, ModelWithSummary)
else _r
)
step += f'- **Reasoning:** "{reasoning}"\n'
step += (
"- **Status:** "
f"`{self.result.status if self.result else 'did_not_finish'}`\n"
)
if self.result:
if self.result.status == "success":
result = str(self.result)
result = "\n" + indent(result) if "\n" in result else result
step += f"- **Output:** {result}"
elif self.result.status == "error":
step += f"- **Reason:** {self.result.reason}\n"
if self.result.error:
step += f"- **Error:** {self.result.error}\n"
elif self.result.status == "interrupted_by_human":
step += f"- **Feedback:** {self.result.feedback}\n"
return step
def __str__(self) -> str:
executed_action = f"Executed `{self.action.use_tool}`"
action_result = f": {self.result}" if self.result else "."
return executed_action + action_result
class EpisodicActionHistory(GenericModel, Generic[AnyProposal]):
"""Utility container for an action history"""
episodes: list[Episode[AnyProposal]] = Field(default_factory=list)
cursor: int = 0
_lock = asyncio.Lock()
@property
def current_episode(self) -> Episode[AnyProposal] | None:
if self.cursor == len(self):
return None
return self[self.cursor]
def __getitem__(self, key: int) -> Episode[AnyProposal]:
return self.episodes[key]
def __len__(self) -> int:
return len(self.episodes)
def __bool__(self) -> bool:
return len(self.episodes) > 0
def register_action(self, action: AnyProposal) -> None:
if not self.current_episode:
self.episodes.append(Episode(action=action, result=None))
assert self.current_episode
elif self.current_episode.action:
raise ValueError("Action for current cycle already set")
def register_result(self, result: ActionResult) -> None:
if not self.current_episode:
raise RuntimeError("Cannot register result for cycle without action")
elif self.current_episode.result:
raise ValueError("Result for current cycle already set")
self.current_episode.result = result
self.cursor = len(self.episodes)
def rewind(self, number_of_episodes: int = 0) -> None:
"""Resets the history to an earlier state.
Params:
number_of_episodes (int): The number of episodes to rewind. Default is 0.
When set to 0, it will only reset the current episode.
"""
# Remove partial record of current cycle
if self.current_episode:
if self.current_episode.action and not self.current_episode.result:
self.episodes.pop(self.cursor)
# Rewind the specified number of cycles
if number_of_episodes > 0:
self.episodes = self.episodes[:-number_of_episodes]
self.cursor = len(self.episodes)
async def handle_compression(
self, llm_provider: MultiProvider, app_config: Config
) -> None:
"""Compresses each episode in the action history using an LLM.
This method iterates over all episodes in the action history without a summary,
and generates a summary for them using an LLM.
"""
compress_instruction = (
"The text represents an action, the reason for its execution, "
"and its result. "
"Condense the action taken and its result into one line. "
"Preserve any specific factual information gathered by the action."
)
async with self._lock:
# Gather all episodes without a summary
episodes_to_summarize = [ep for ep in self.episodes if ep.summary is None]
# Parallelize summarization calls
summarize_coroutines = [
summarize_text(
episode.format(),
instruction=compress_instruction,
llm_provider=llm_provider,
config=app_config,
)
for episode in episodes_to_summarize
]
summaries = await asyncio.gather(*summarize_coroutines)
# Assign summaries to episodes
for episode, (summary, _) in zip(episodes_to_summarize, summaries):
episode.summary = summary
def fmt_list(self) -> str:
return format_numbered_list(self.episodes)
def fmt_paragraph(self) -> str:
steps: list[str] = []
for i, episode in enumerate(self.episodes, 1):
step = f"### Step {i}: {episode.format()}\n"
steps.append(step)
return "\n\n".join(steps)

View File

@@ -0,0 +1,13 @@
from .code_executor import (
ALLOWLIST_CONTROL,
DENYLIST_CONTROL,
CodeExecutionError,
CodeExecutorComponent,
)
__all__ = [
"ALLOWLIST_CONTROL",
"DENYLIST_CONTROL",
"CodeExecutionError",
"CodeExecutorComponent",
]

View File

@@ -0,0 +1,410 @@
import logging
import os
import random
import shlex
import string
import subprocess
from pathlib import Path
from typing import Iterator
import docker
from docker.errors import DockerException, ImageNotFound, NotFound
from docker.models.containers import Container as DockerContainer
from forge.agent import BaseAgentSettings
from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.file_storage import FileStorage
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import (
CommandExecutionError,
InvalidArgumentError,
OperationNotAllowedError,
)
logger = logging.getLogger(__name__)
ALLOWLIST_CONTROL = "allowlist"
DENYLIST_CONTROL = "denylist"
def we_are_running_in_a_docker_container() -> bool:
"""Check if we are running in a Docker container
Returns:
bool: True if we are running in a Docker container, False otherwise
"""
return os.path.exists("/.dockerenv")
def is_docker_available() -> bool:
"""Check if Docker is available and supports Linux containers
Returns:
bool: True if Docker is available and supports Linux containers, False otherwise
"""
try:
client = docker.from_env()
docker_info = client.info()
return docker_info["OSType"] == "linux"
except Exception:
return False
class CodeExecutionError(CommandExecutionError):
"""The operation (an attempt to run arbitrary code) returned an error"""
class CodeExecutorComponent(CommandProvider):
"""Provides commands to execute Python code and shell commands."""
def __init__(
self, workspace: FileStorage, state: BaseAgentSettings, config: Config
):
self.workspace = workspace
self.state = state
self.legacy_config = config
if not we_are_running_in_a_docker_container() and not is_docker_available():
logger.info(
"Docker is not available or does not support Linux containers. "
"The code execution commands will not be available."
)
if not self.legacy_config.execute_local_commands:
logger.info(
"Local shell commands are disabled. To enable them,"
" set EXECUTE_LOCAL_COMMANDS to 'True' in your config file."
)
def get_commands(self) -> Iterator[Command]:
if we_are_running_in_a_docker_container() or is_docker_available():
yield self.execute_python_code
yield self.execute_python_file
if self.legacy_config.execute_local_commands:
yield self.execute_shell
yield self.execute_shell_popen
@command(
["execute_python_code"],
"Executes the given Python code inside a single-use Docker container"
" with access to your workspace folder",
{
"code": JSONSchema(
type=JSONSchema.Type.STRING,
description="The Python code to run",
required=True,
),
},
)
async def execute_python_code(self, code: str) -> str:
"""
Create and execute a Python file in a Docker container
and return the STDOUT of the executed code.
If the code generates any data that needs to be captured,
use a print statement.
Args:
code (str): The Python code to run.
agent (Agent): The Agent executing the command.
Returns:
str: The STDOUT captured from the code when it ran.
"""
temp_path = ""
while True:
temp_path = f"temp{self._generate_random_string()}.py"
if not self.workspace.exists(temp_path):
break
await self.workspace.write_file(temp_path, code)
try:
return self.execute_python_file(temp_path)
except Exception as e:
raise CommandExecutionError(*e.args)
finally:
self.workspace.delete_file(temp_path)
@command(
["execute_python_file"],
"Execute an existing Python file inside a single-use Docker container"
" with access to your workspace folder",
{
"filename": JSONSchema(
type=JSONSchema.Type.STRING,
description="The name of the file to execute",
required=True,
),
"args": JSONSchema(
type=JSONSchema.Type.ARRAY,
description="The (command line) arguments to pass to the script",
required=False,
items=JSONSchema(type=JSONSchema.Type.STRING),
),
},
)
def execute_python_file(self, filename: str | Path, args: list[str] = []) -> str:
"""Execute a Python file in a Docker container and return the output
Args:
filename (Path): The name of the file to execute
args (list, optional): The arguments with which to run the python script
Returns:
str: The output of the file
"""
logger.info(f"Executing python file '{filename}'")
if not str(filename).endswith(".py"):
raise InvalidArgumentError("Invalid file type. Only .py files are allowed.")
file_path = self.workspace.get_path(filename)
if not self.workspace.exists(file_path):
# Mimic the response that you get from the command line to make it
# intuitively understandable for the LLM
raise FileNotFoundError(
f"python: can't open file '{filename}': "
f"[Errno 2] No such file or directory"
)
if we_are_running_in_a_docker_container():
logger.debug(
"App is running in a Docker container; "
f"executing {file_path} directly..."
)
with self.workspace.mount() as local_path:
result = subprocess.run(
["python", "-B", str(file_path.relative_to(self.workspace.root))]
+ args,
capture_output=True,
encoding="utf8",
cwd=str(local_path),
)
if result.returncode == 0:
return result.stdout
else:
raise CodeExecutionError(result.stderr)
logger.debug("App is not running in a Docker container")
return self._run_python_code_in_docker(file_path, args)
def validate_command(self, command_line: str, config: Config) -> tuple[bool, bool]:
"""Check whether a command is allowed and whether it may be executed in a shell.
If shell command control is enabled, we disallow executing in a shell, because
otherwise the model could circumvent the command filter using shell features.
Args:
command_line (str): The command line to validate
config (Config): The app config including shell command control settings
Returns:
bool: True if the command is allowed, False otherwise
bool: True if the command may be executed in a shell, False otherwise
"""
if not command_line:
return False, False
command_name = shlex.split(command_line)[0]
if config.shell_command_control == ALLOWLIST_CONTROL:
return command_name in config.shell_allowlist, False
elif config.shell_command_control == DENYLIST_CONTROL:
return command_name not in config.shell_denylist, False
else:
return True, True
@command(
["execute_shell"],
"Execute a Shell Command, non-interactive commands only",
{
"command_line": JSONSchema(
type=JSONSchema.Type.STRING,
description="The command line to execute",
required=True,
)
},
)
def execute_shell(self, command_line: str) -> str:
"""Execute a shell command and return the output
Args:
command_line (str): The command line to execute
Returns:
str: The output of the command
"""
allow_execute, allow_shell = self.validate_command(
command_line, self.legacy_config
)
if not allow_execute:
logger.info(f"Command '{command_line}' not allowed")
raise OperationNotAllowedError("This shell command is not allowed.")
current_dir = Path.cwd()
# Change dir into workspace if necessary
if not current_dir.is_relative_to(self.workspace.root):
os.chdir(self.workspace.root)
logger.info(
f"Executing command '{command_line}' in working directory '{os.getcwd()}'"
)
result = subprocess.run(
command_line if allow_shell else shlex.split(command_line),
capture_output=True,
shell=allow_shell,
)
output = f"STDOUT:\n{result.stdout.decode()}\nSTDERR:\n{result.stderr.decode()}"
# Change back to whatever the prior working dir was
os.chdir(current_dir)
return output
@command(
["execute_shell_popen"],
"Execute a Shell Command, non-interactive commands only",
{
"command_line": JSONSchema(
type=JSONSchema.Type.STRING,
description="The command line to execute",
required=True,
)
},
)
def execute_shell_popen(self, command_line: str) -> str:
"""Execute a shell command with Popen and returns an english description
of the event and the process id
Args:
command_line (str): The command line to execute
Returns:
str: Description of the fact that the process started and its id
"""
allow_execute, allow_shell = self.validate_command(
command_line, self.legacy_config
)
if not allow_execute:
logger.info(f"Command '{command_line}' not allowed")
raise OperationNotAllowedError("This shell command is not allowed.")
current_dir = Path.cwd()
# Change dir into workspace if necessary
if not current_dir.is_relative_to(self.workspace.root):
os.chdir(self.workspace.root)
logger.info(
f"Executing command '{command_line}' in working directory '{os.getcwd()}'"
)
do_not_show_output = subprocess.DEVNULL
process = subprocess.Popen(
command_line if allow_shell else shlex.split(command_line),
shell=allow_shell,
stdout=do_not_show_output,
stderr=do_not_show_output,
)
# Change back to whatever the prior working dir was
os.chdir(current_dir)
return f"Subprocess started with PID:'{str(process.pid)}'"
def _run_python_code_in_docker(self, filename: str | Path, args: list[str]) -> str:
"""Run a Python script in a Docker container"""
file_path = self.workspace.get_path(filename)
try:
assert self.state.agent_id, "Need Agent ID to attach Docker container"
client = docker.from_env()
image_name = "python:3-alpine"
container_is_fresh = False
container_name = f"{self.state.agent_id}_sandbox"
with self.workspace.mount() as local_path:
try:
container: DockerContainer = client.containers.get(
container_name
) # type: ignore
except NotFound:
try:
client.images.get(image_name)
logger.debug(f"Image '{image_name}' found locally")
except ImageNotFound:
logger.info(
f"Image '{image_name}' not found locally,"
" pulling from Docker Hub..."
)
# Use the low-level API to stream the pull response
low_level_client = docker.APIClient()
for line in low_level_client.pull(
image_name, stream=True, decode=True
):
# Print the status and progress, if available
status = line.get("status")
progress = line.get("progress")
if status and progress:
logger.info(f"{status}: {progress}")
elif status:
logger.info(status)
logger.debug(f"Creating new {image_name} container...")
container: DockerContainer = client.containers.run(
image_name,
["sleep", "60"], # Max 60 seconds to prevent permanent hangs
volumes={
str(local_path.resolve()): {
"bind": "/workspace",
"mode": "rw",
}
},
working_dir="/workspace",
stderr=True,
stdout=True,
detach=True,
name=container_name,
) # type: ignore
container_is_fresh = True
if container.status != "running":
container.start()
elif not container_is_fresh:
container.restart()
logger.debug(f"Running {file_path} in container {container.name}...")
exec_result = container.exec_run(
[
"python",
"-B",
file_path.relative_to(self.workspace.root).as_posix(),
]
+ args,
stderr=True,
stdout=True,
)
if exec_result.exit_code != 0:
raise CodeExecutionError(exec_result.output.decode("utf-8"))
return exec_result.output.decode("utf-8")
except DockerException as e:
logger.warning(
"Could not run the script in a container. "
"If you haven't already, please install Docker: "
"https://docs.docker.com/get-docker/"
)
raise CommandExecutionError(f"Could not run the script in a container: {e}")
def _generate_random_string(self, length: int = 8):
# Create a string of all letters and digits
characters = string.ascii_letters + string.digits
# Use random.choices to generate a random string
random_string = "".join(random.choices(characters, k=length))
return random_string
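
To make the `validate_command` contract concrete, here is a small standalone mirror of its decision logic (an illustration, not the component itself) showing how the allow/deny modes and the shell flag interact:

```py
import shlex


def validate(command_line: str, control: str, allow: list[str], deny: list[str]) -> tuple[bool, bool]:
    """Standalone mirror of CodeExecutorComponent.validate_command's decision logic."""
    if not command_line:
        return False, False
    name = shlex.split(command_line)[0]
    if control == "allowlist":
        return name in allow, False  # filtered mode: never run through a real shell
    if control == "denylist":
        return name not in deny, False
    return True, True  # no filtering: the command may run through a shell


assert validate("ls -la", "allowlist", ["ls", "cat"], []) == (True, False)
assert validate("rm -rf /", "allowlist", ["ls", "cat"], []) == (False, False)
assert validate("echo hi && rm x", "denylist", [], ["rm"]) == (True, False)  # only "echo" is checked
assert validate("anything goes", "off", [], []) == (True, True)
```

The second flag is why `execute_shell` only passes `shell=True` when no filtering is configured: with a real shell, the model could chain disallowed commands behind an allowed one.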

View File

@@ -0,0 +1,3 @@
from .code_flow_executor import (
CodeFlowExecutionComponent
)

View File

@@ -0,0 +1,80 @@
"""Commands to generate images based on text input"""
import inspect
import logging
from typing import Iterator
from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.models.json_schema import JSONSchema
MAX_RESULT_LENGTH = 1000
logger = logging.getLogger(__name__)
class CodeFlowExecutionComponent(CommandProvider):
"""A component that provides commands to execute code flow."""
def __init__(self):
self._enabled = True
self.available_functions = {}
def set_available_functions(self, functions: list[Command]):
self.available_functions = {
name: function for function in functions for name in function.names
}
def get_commands(self) -> Iterator[Command]:
yield self.execute_code_flow
@command(
parameters={
"python_code": JSONSchema(
type=JSONSchema.Type.STRING,
description="The Python code to execute",
required=True,
),
"plan_text": JSONSchema(
type=JSONSchema.Type.STRING,
description="The plan to written in a natural language",
required=False,
),
},
)
async def execute_code_flow(self, python_code: str, plan_text: str) -> str:
"""Execute the code flow.
Args:
python_code (str): The Python code to execute
plan_text (str): The plan for the code, written in natural language
Returns:
str: The result of the code execution
"""
code_header = "import inspect\n" + "\n".join(
[
f"""
async def {name}(*args, **kwargs):
result = {name}_func(*args, **kwargs)
if inspect.isawaitable(result):
result = await result
return result
"""
for name in self.available_functions.keys()
]
)
exec_globals = {
name + "_func": func for name, func in self.available_functions.items()
}
code = f"{code_header}\n{python_code}\n\nexec_output = main()"
logger.debug(f"Code-Flow Execution code:\n{python_code}")
exec(code, exec_globals)
result = await exec_globals["exec_output"]
logger.debug(f"Code-Flow Execution result:\n{result}")
if inspect.isawaitable(result):
result = await result
# limit the result to limit the characters
if len(result) > MAX_RESULT_LENGTH:
result = result[:MAX_RESULT_LENGTH] + "...[Truncated, Content is too long]"
return f"Execution Plan:\n{plan_text}\n\nExecution Output:\n{result}"

View File

@@ -0,0 +1,15 @@
from .context import ContextComponent
from .context_item import (
ContextItem,
FileContextItem,
FolderContextItem,
StaticContextItem,
)
__all__ = [
"ContextComponent",
"ContextItem",
"FileContextItem",
"FolderContextItem",
"StaticContextItem",
]

View File

@@ -0,0 +1,163 @@
import contextlib
from pathlib import Path
from typing import Iterator
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from forge.agent.protocols import CommandProvider, MessageProvider
from forge.command import Command, command
from forge.file_storage.base import FileStorage
from forge.llm.providers import ChatMessage
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import InvalidArgumentError
from .context_item import ContextItem, FileContextItem, FolderContextItem
class AgentContext(BaseModel):
items: list[Annotated[ContextItem, Field(discriminator="type")]] = Field(
default_factory=list
)
def __bool__(self) -> bool:
return len(self.items) > 0
def __contains__(self, item: ContextItem) -> bool:
return any([i.source == item.source for i in self.items])
def add(self, item: ContextItem) -> None:
self.items.append(item)
def close(self, index: int) -> None:
self.items.pop(index - 1)
def clear(self) -> None:
self.items.clear()
def format_numbered(self, workspace: FileStorage) -> str:
return "\n\n".join(
[f"{i}. {c.fmt(workspace)}" for i, c in enumerate(self.items, 1)]
)
class ContextComponent(MessageProvider, CommandProvider):
"""Adds ability to keep files and folders open in the context (prompt)."""
def __init__(self, workspace: FileStorage, context: AgentContext):
self.context = context
self.workspace = workspace
def get_messages(self) -> Iterator[ChatMessage]:
if self.context:
yield ChatMessage.system(
"## Context\n"
f"{self.context.format_numbered(self.workspace)}\n\n"
"When a context item is no longer needed and you are not done yet, "
"you can hide the item by specifying its number in the list above "
"to `hide_context_item`.",
)
def get_commands(self) -> Iterator[Command]:
yield self.open_file
yield self.open_folder
if self.context:
yield self.close_context_item
@command(
parameters={
"file_path": JSONSchema(
type=JSONSchema.Type.STRING,
description="The path of the file to open",
required=True,
)
}
)
async def open_file(self, file_path: str | Path) -> str:
"""Opens a file for editing or continued viewing;
creates it if it does not exist yet.
Note: If you only need to read or write a file once,
use `write_file` instead.
Args:
file_path (str | Path): The path of the file to open
Returns:
str: A status message indicating what happened
"""
if not isinstance(file_path, Path):
file_path = Path(file_path)
created = False
if not self.workspace.exists(file_path):
await self.workspace.write_file(file_path, "")
created = True
# Try to make the file path relative
with contextlib.suppress(ValueError):
file_path = file_path.relative_to(self.workspace.root)
file = FileContextItem(path=file_path)
self.context.add(file)
return (
f"File {file_path}{' created,' if created else ''} has been opened"
" and added to the context ✅"
)
@command(
parameters={
"path": JSONSchema(
type=JSONSchema.Type.STRING,
description="The path of the folder to open",
required=True,
)
}
)
def open_folder(self, path: str | Path) -> str:
"""Open a folder to keep track of its content
Args:
path (str | Path): The path of the folder to open
Returns:
str: A status message indicating what happened
"""
if not isinstance(path, Path):
path = Path(path)
if not self.workspace.exists(path):
raise FileNotFoundError(
f"open_folder {path} failed: no such file or directory"
)
# Try to make the path relative
with contextlib.suppress(ValueError):
path = path.relative_to(self.workspace.root)
folder = FolderContextItem(path=path)
self.context.add(folder)
return f"Folder {path} has been opened and added to the context ✅"
@command(
parameters={
"number": JSONSchema(
type=JSONSchema.Type.INTEGER,
description="The 1-based index of the context item to hide",
required=True,
)
}
)
def close_context_item(self, number: int) -> str:
"""Hide an open file, folder or other context item, to save tokens.
Args:
number (int): The 1-based index of the context item to hide
Returns:
str: A status message indicating what happened
"""
if number > len(self.context.items) or number == 0:
raise InvalidArgumentError(f"Index {number} out of range")
self.context.close(number)
return f"Context item {number} hidden ✅"

View File

@@ -0,0 +1,85 @@
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Literal, Optional
from pydantic import BaseModel, Field
from forge.file_storage.base import FileStorage
from forge.utils.file_operations import decode_textual_file
logger = logging.getLogger(__name__)
class BaseContextItem(ABC):
@property
@abstractmethod
def description(self) -> str:
"""Description of the context item"""
...
@property
@abstractmethod
def source(self) -> Optional[str]:
"""A string indicating the source location of the context item"""
...
@abstractmethod
def get_content(self, workspace: FileStorage) -> str:
"""The content represented by the context item"""
...
def fmt(self, workspace: FileStorage) -> str:
return (
f"{self.description} (source: {self.source})\n"
"```\n"
f"{self.get_content(workspace)}\n"
"```"
)
class FileContextItem(BaseModel, BaseContextItem):
path: Path
type: Literal["file"] = "file"
@property
def description(self) -> str:
return f"The current content of the file '{self.path}'"
@property
def source(self) -> str:
return str(self.path)
def get_content(self, workspace: FileStorage) -> str:
with workspace.open_file(self.path, "r", True) as file:
return decode_textual_file(file, self.path.suffix, logger)
class FolderContextItem(BaseModel, BaseContextItem):
path: Path
type: Literal["folder"] = "folder"
@property
def description(self) -> str:
return f"The contents of the folder '{self.path}' in the workspace"
@property
def source(self) -> str:
return str(self.path)
def get_content(self, workspace: FileStorage) -> str:
files = [str(p) for p in workspace.list_files(self.path)]
folders = [f"{str(p)}/" for p in workspace.list_folders(self.path)]
items = folders + files
items.sort()
return "\n".join(items)
class StaticContextItem(BaseModel, BaseContextItem):
item_description: str = Field(alias="description")
item_source: Optional[str] = Field(alias="source")
item_content: str = Field(alias="content")
type: Literal["static"] = "static"
ContextItem = FileContextItem | FolderContextItem | StaticContextItem

View File

@@ -0,0 +1,3 @@
from .file_manager import FileManagerComponent
__all__ = ["FileManagerComponent"]

View File

@@ -0,0 +1,160 @@
import logging
import os
from pathlib import Path
from typing import Iterator, Optional
from forge.agent import BaseAgentSettings
from forge.agent.protocols import CommandProvider, DirectiveProvider
from forge.command import Command, command
from forge.file_storage.base import FileStorage
from forge.models.json_schema import JSONSchema
from forge.utils.file_operations import decode_textual_file
logger = logging.getLogger(__name__)
class FileManagerComponent(DirectiveProvider, CommandProvider):
"""
Adds general file manager (e.g. Agent state),
workspace manager (e.g. Agent output files) support and
commands to perform operations on files and folders.
"""
files: FileStorage
"""Agent-related files, e.g. state, logs.
Use `workspace` to access the agent's workspace files."""
workspace: FileStorage
"""Workspace that the agent has access to, e.g. for reading/writing files.
Use `files` to access agent-related files, e.g. state, logs."""
STATE_FILE = "state.json"
"""The name of the file where the agent's state is stored."""
def __init__(self, state: BaseAgentSettings, file_storage: FileStorage):
self.state = state
if not state.agent_id:
raise ValueError("Agent must have an ID.")
self.files = file_storage.clone_with_subroot(f"agents/{state.agent_id}/")
self.workspace = file_storage.clone_with_subroot(
f"agents/{state.agent_id}/workspace"
)
self._file_storage = file_storage
async def save_state(self, save_as: Optional[str] = None) -> None:
"""Save the agent's state to the state file."""
state: BaseAgentSettings = getattr(self, "state")
if save_as:
temp_id = state.agent_id
state.agent_id = save_as
self._file_storage.make_dir(f"agents/{save_as}")
# Save state
await self._file_storage.write_file(
f"agents/{save_as}/{self.STATE_FILE}", state.json()
)
# Copy workspace
self._file_storage.copy(
f"agents/{temp_id}/workspace",
f"agents/{save_as}/workspace",
)
state.agent_id = temp_id
else:
await self.files.write_file(self.files.root / self.STATE_FILE, state.json())
def change_agent_id(self, new_id: str):
"""Change the agent's ID and update the file storage accordingly."""
state: BaseAgentSettings = getattr(self, "state")
# Rename the agent's files and workspace
self._file_storage.rename(f"agents/{state.agent_id}", f"agents/{new_id}")
# Update the file storage objects
self.files = self._file_storage.clone_with_subroot(f"agents/{new_id}/")
self.workspace = self._file_storage.clone_with_subroot(
f"agents/{new_id}/workspace"
)
state.agent_id = new_id
def get_resources(self) -> Iterator[str]:
yield "The ability to read and write files."
def get_commands(self) -> Iterator[Command]:
yield self.read_file
yield self.write_to_file
yield self.list_folder
@command(
parameters={
"filename": JSONSchema(
type=JSONSchema.Type.STRING,
description="The path of the file to read",
required=True,
)
},
)
def read_file(self, filename: str | Path) -> str:
"""Read a file and return the contents
Args:
filename (str): The name of the file to read
Returns:
str: The contents of the file
"""
file = self.workspace.open_file(filename, binary=True)
content = decode_textual_file(file, os.path.splitext(filename)[1], logger)
return content
@command(
["write_file", "create_file"],
"Write a file, creating it if necessary. "
"If the file exists, it is overwritten.",
{
"filename": JSONSchema(
type=JSONSchema.Type.STRING,
description="The name of the file to write to",
required=True,
),
"contents": JSONSchema(
type=JSONSchema.Type.STRING,
description="The contents to write to the file",
required=True,
),
},
)
async def write_to_file(self, filename: str | Path, contents: str) -> str:
"""Write contents to a file
Args:
filename (str): The name of the file to write to
contents (str): The contents to write to the file
Returns:
str: A message indicating success or failure
"""
if directory := os.path.dirname(filename):
self.workspace.make_dir(directory)
await self.workspace.write_file(filename, contents)
return f"File {filename} has been written successfully."
@command(
parameters={
"folder": JSONSchema(
type=JSONSchema.Type.STRING,
description="The folder to list files in. "
"Pass an empty string to list files in the workspace.",
required=True,
)
},
)
def list_folder(self, folder: str | Path) -> list[str]:
"""Lists files in a folder recursively
Args:
folder (str): The folder to search in
Returns:
list[str]: A list of files found in the folder
"""
return [str(p) for p in self.workspace.list_files(folder)]
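
A rough usage sketch for the component above; `local_storage` stands in for any concrete `FileStorage` implementation, `state` for a `BaseAgentSettings` with `agent_id` set, and the calls are assumed to run inside an async context:

```py
file_manager = FileManagerComponent(state=state, file_storage=local_storage)

# Agent output lands under agents/<agent_id>/workspace/ ...
await file_manager.write_to_file("notes/plan.txt", "Step 1: ...")

# ...while the serialized state is written to agents/<agent_id>/state.json.
await file_manager.save_state()

# save_as writes the state under a new agent id and copies the workspace there too.
await file_manager.save_state(save_as="my-checkpoint")
```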

View File

@@ -0,0 +1,3 @@
from .git_operations import GitOperationsComponent
__all__ = ["GitOperationsComponent"]

View File

@@ -0,0 +1,60 @@
from pathlib import Path
from typing import Iterator
from git.repo import Repo
from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import CommandExecutionError
from forge.utils.url_validator import validate_url
class GitOperationsComponent(CommandProvider):
"""Provides commands to perform Git operations."""
def __init__(self, config: Config):
self._enabled = bool(config.github_username and config.github_api_key)
self._disabled_reason = "Configure github_username and github_api_key."
self.legacy_config = config
def get_commands(self) -> Iterator[Command]:
yield self.clone_repository
@command(
parameters={
"url": JSONSchema(
type=JSONSchema.Type.STRING,
description="The URL of the repository to clone",
required=True,
),
"clone_path": JSONSchema(
type=JSONSchema.Type.STRING,
description="The path to clone the repository to",
required=True,
),
},
)
@validate_url
def clone_repository(self, url: str, clone_path: Path) -> str:
"""Clone a GitHub repository locally.
Args:
url (str): The URL of the repository to clone.
clone_path (Path): The path to clone the repository to.
Returns:
str: The result of the clone operation.
"""
split_url = url.split("//")
auth_repo_url = (
f"//{self.legacy_config.github_username}:"
f"{self.legacy_config.github_api_key}@".join(split_url)
)
try:
Repo.clone_from(url=auth_repo_url, to_path=clone_path)
except Exception as e:
raise CommandExecutionError(f"Could not clone repo: {e}")
return f"""Cloned {url} to {clone_path}"""

View File

@@ -0,0 +1,3 @@
from .image_gen import ImageGeneratorComponent
__all__ = ["ImageGeneratorComponent"]

View File

@@ -0,0 +1,239 @@
"""Commands to generate images based on text input"""
import io
import json
import logging
import time
import uuid
from base64 import b64decode
from pathlib import Path
from typing import Iterator
import requests
from openai import OpenAI
from PIL import Image
from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.file_storage import FileStorage
from forge.models.json_schema import JSONSchema
logger = logging.getLogger(__name__)
class ImageGeneratorComponent(CommandProvider):
"""A component that provides commands to generate images from text prompts."""
def __init__(self, workspace: FileStorage, config: Config):
self._enabled = bool(config.image_provider)
self._disabled_reason = "No image provider set."
self.workspace = workspace
self.legacy_config = config
def get_commands(self) -> Iterator[Command]:
if (
self.legacy_config.openai_credentials
or self.legacy_config.huggingface_api_token
or self.legacy_config.sd_webui_auth
):
yield self.generate_image
@command(
parameters={
"prompt": JSONSchema(
type=JSONSchema.Type.STRING,
description="The prompt used to generate the image",
required=True,
),
"size": JSONSchema(
type=JSONSchema.Type.INTEGER,
description="The size of the image",
required=False,
),
},
)
def generate_image(self, prompt: str, size: int = 256) -> str:
"""Generate an image from a prompt.
Args:
prompt (str): The prompt to use
size (int, optional): The size of the image. Defaults to 256.
Not supported by HuggingFace.
Returns:
str: The filename of the image
"""
filename = self.workspace.root / f"{str(uuid.uuid4())}.jpg"
cfg = self.legacy_config
if cfg.openai_credentials and (
cfg.image_provider == "dalle"
or not (cfg.huggingface_api_token or cfg.sd_webui_url)
):
return self.generate_image_with_dalle(prompt, filename, size)
elif cfg.huggingface_api_token and (
cfg.image_provider == "huggingface"
or not (cfg.openai_credentials or cfg.sd_webui_url)
):
return self.generate_image_with_hf(prompt, filename)
elif cfg.sd_webui_url and (
cfg.image_provider == "sdwebui" or cfg.sd_webui_auth
):
return self.generate_image_with_sd_webui(prompt, filename, size)
return "Error: No image generation provider available"
def generate_image_with_hf(self, prompt: str, output_file: Path) -> str:
"""Generate an image with HuggingFace's API.
Args:
prompt (str): The prompt to use
output_file (Path): The path to save the image to
Returns:
str: The filename of the image
"""
API_URL = f"https://api-inference.huggingface.co/models/{self.legacy_config.huggingface_image_model}" # noqa: E501
if self.legacy_config.huggingface_api_token is None:
raise ValueError(
"You need to set your Hugging Face API token in the config file."
)
headers = {
"Authorization": f"Bearer {self.legacy_config.huggingface_api_token}",
"X-Use-Cache": "false",
}
retry_count = 0
while retry_count < 10:
response = requests.post(
API_URL,
headers=headers,
json={
"inputs": prompt,
},
)
if response.ok:
try:
image = Image.open(io.BytesIO(response.content))
logger.info(f"Image Generated for prompt:{prompt}")
image.save(output_file)
return f"Saved to disk: {output_file}"
except Exception as e:
logger.error(e)
break
else:
try:
error = json.loads(response.text)
if "estimated_time" in error:
delay = error["estimated_time"]
logger.debug(response.text)
logger.info("Retrying in", delay)
time.sleep(delay)
else:
break
except Exception as e:
logger.error(e)
break
retry_count += 1
return "Error creating image."
def generate_image_with_dalle(
self, prompt: str, output_file: Path, size: int
) -> str:
"""Generate an image with DALL-E.
Args:
prompt (str): The prompt to use
output_file (Path): The path to save the image to
size (int): The size of the image
Returns:
str: The filename of the image
"""
assert self.legacy_config.openai_credentials # otherwise this tool is disabled
# Check for supported image sizes
if size not in [256, 512, 1024]:
closest = min([256, 512, 1024], key=lambda x: abs(x - size))
logger.info(
"DALL-E only supports image sizes of 256x256, 512x512, or 1024x1024. "
f"Setting to {closest}, was {size}."
)
size = closest
# TODO: integrate in `forge.llm.providers`(?)
response = OpenAI(
api_key=self.legacy_config.openai_credentials.api_key.get_secret_value()
).images.generate(
prompt=prompt,
n=1,
# TODO: improve typing of size config item(s)
size=f"{size}x{size}", # type: ignore
response_format="b64_json",
)
assert response.data[0].b64_json is not None # response_format = "b64_json"
logger.info(f"Image Generated for prompt: {prompt}")
image_data = b64decode(response.data[0].b64_json)
with open(output_file, mode="wb") as png:
png.write(image_data)
return f"Saved to disk: {output_file}"
def generate_image_with_sd_webui(
self,
prompt: str,
output_file: Path,
size: int = 512,
negative_prompt: str = "",
extra: dict = {},
) -> str:
"""Generate an image with Stable Diffusion webui.
Args:
prompt (str): The prompt to use
output_file (Path): The path to save the image to
size (int, optional): The size of the image. Defaults to 512.
negative_prompt (str, optional): The negative prompt to use. Defaults to "".
extra (dict, optional): Extra parameters to pass to the API. Defaults to {}.
Returns:
str: The filename of the image
"""
# Create a session and set the basic auth if needed
s = requests.Session()
if self.legacy_config.sd_webui_auth:
username, password = self.legacy_config.sd_webui_auth.split(":")
s.auth = (username, password or "")
# Generate the images
response = requests.post(
f"{self.legacy_config.sd_webui_url}/sdapi/v1/txt2img",
json={
"prompt": prompt,
"negative_prompt": negative_prompt,
"sampler_index": "DDIM",
"steps": 20,
"config_scale": 7.0,
"width": size,
"height": size,
"n_iter": 1,
**extra,
},
)
logger.info(f"Image Generated for prompt: '{prompt}'")
# Save the image to disk
response = response.json()
b64 = b64decode(response["images"][0].split(",", 1)[0])
image = Image.open(io.BytesIO(b64))
image.save(output_file)
return f"Saved to disk: {output_file}"

View File

@@ -0,0 +1,3 @@
from .system import SystemComponent
__all__ = ["SystemComponent"]

View File

@@ -0,0 +1,79 @@
import logging
import time
from typing import Iterator
from forge.agent.protocols import CommandProvider, DirectiveProvider, MessageProvider
from forge.command import Command, command
from forge.llm.providers import ChatMessage
from forge.models.json_schema import JSONSchema
from forge.utils.const import FINISH_COMMAND
from forge.utils.exceptions import AgentFinished
logger = logging.getLogger(__name__)
class SystemComponent(DirectiveProvider, MessageProvider, CommandProvider):
"""Component for system messages and commands."""
def get_constraints(self) -> Iterator[str]:
yield "Exclusively use the commands listed below."
yield (
"You can only act proactively, and are unable to start background jobs or "
"set up webhooks for yourself. "
"Take this into account when planning your actions."
)
yield (
"You are unable to interact with physical objects. "
"If this is absolutely necessary to fulfill a task or objective or "
"to complete a step, you must ask the user to do it for you. "
"If the user refuses this, and there is no other way to achieve your "
"goals, you must terminate to avoid wasting time and energy."
)
def get_resources(self) -> Iterator[str]:
yield (
"You are a Large Language Model, trained on millions of pages of text, "
"including a lot of factual knowledge. Make use of this factual knowledge "
"to avoid unnecessary gathering of information."
)
def get_best_practices(self) -> Iterator[str]:
yield (
"Continuously review and analyze your actions to ensure "
"you are performing to the best of your abilities."
)
yield "Constructively self-criticize your big-picture behavior constantly."
yield "Reflect on past decisions and strategies to refine your approach."
yield (
"Every command has a cost, so be smart and efficient. "
"Aim to complete tasks in the least number of steps."
)
yield (
"Only make use of your information gathering abilities to find "
"information that you don't yet have knowledge of."
)
def get_messages(self) -> Iterator[ChatMessage]:
# Clock
yield ChatMessage.system(
f"## Clock\nThe current time and date is {time.strftime('%c')}"
)
def get_commands(self) -> Iterator[Command]:
yield self.finish
@command(
names=[FINISH_COMMAND],
parameters={
"reason": JSONSchema(
type=JSONSchema.Type.STRING,
description="A summary to the user of how the goals were accomplished",
required=True,
),
},
)
def finish(self, reason: str):
"""Use this to shut down once you have completed your task,
or when there are insurmountable problems that make it impossible
for you to finish your task."""
raise AgentFinished(reason)

View File

@@ -0,0 +1,3 @@
from .user_interaction import UserInteractionComponent
__all__ = ["UserInteractionComponent"]

View File

@@ -0,0 +1,36 @@
from typing import Iterator
import click
from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.models.json_schema import JSONSchema
from forge.utils.const import ASK_COMMAND
class UserInteractionComponent(CommandProvider):
"""Provides commands to interact with the user."""
def __init__(self, config: Config):
self._enabled = not config.noninteractive_mode
def get_commands(self) -> Iterator[Command]:
yield self.ask_user
@command(
names=[ASK_COMMAND],
parameters={
"question": JSONSchema(
type=JSONSchema.Type.STRING,
description="The question or prompt to the user",
required=True,
)
},
)
def ask_user(self, question: str) -> str:
"""If you need more details or information regarding the given task,
you can ask the user for input."""
print(f"\nQ: {question}")
resp = click.prompt("A")
return f"The user's answer: '{resp}'"

View File

@@ -0,0 +1,3 @@
from .watchdog import WatchdogComponent
__all__ = ["WatchdogComponent"]

View File

@@ -0,0 +1,63 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from forge.agent.components import ComponentSystemError
from forge.agent.protocols import AfterParse
from forge.components.action_history import EpisodicActionHistory
from forge.models.action import AnyProposal
if TYPE_CHECKING:
from forge.agent.base import BaseAgentConfiguration
logger = logging.getLogger(__name__)
class WatchdogComponent(AfterParse[AnyProposal]):
"""
Adds a watchdog feature to an agent class. Whenever the agent starts
looping, the watchdog will switch from the FAST_LLM to the SMART_LLM and re-think.
"""
def __init__(
self,
config: "BaseAgentConfiguration",
event_history: EpisodicActionHistory[AnyProposal],
):
self.config = config
self.event_history = event_history
self.revert_big_brain = False
def after_parse(self, result: AnyProposal) -> None:
if self.revert_big_brain:
self.config.big_brain = False
self.revert_big_brain = False
if not self.config.big_brain and self.config.fast_llm != self.config.smart_llm:
previous_command, previous_command_args = None, None
if len(self.event_history) > 1:
# Detect repetitive commands
previous_cycle = self.event_history.episodes[
self.event_history.cursor - 1
]
previous_command = previous_cycle.action.use_tool.name
previous_command_args = previous_cycle.action.use_tool.arguments
rethink_reason = ""
if not result.use_tool:
rethink_reason = "AI did not specify a command"
elif (
result.use_tool.name == previous_command
and result.use_tool.arguments == previous_command_args
):
rethink_reason = f"Repititive command detected ({result.use_tool.name})"
if rethink_reason:
logger.info(f"{rethink_reason}, re-thinking with SMART_LLM...")
self.event_history.rewind()
self.config.big_brain = True
self.revert_big_brain = True
# Trigger retry of all pipelines prior to this component
raise ComponentSystemError(rethink_reason, self)
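The watchdog treats a proposal as looping when its tool name and arguments match the previous episode's. A simplified standalone sketch of that comparison (not the forge component itself, which reads episodes from EpisodicActionHistory):
def is_repetitive(current: tuple[str, dict], previous: tuple[str, dict] | None) -> bool:
    # True when the proposed tool call exactly repeats the previous one
    return previous is not None and current == previous
previous = ("read_file", {"path": "data.csv"})  # hypothetical previous action
proposal = ("read_file", {"path": "data.csv"})  # hypothetical new proposal
assert is_repetitive(proposal, previous)  # -> would trigger a SMART_LLM re-think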

View File

@@ -0,0 +1,4 @@
from .search import WebSearchComponent
from .selenium import BrowsingError, WebSeleniumComponent
__all__ = ["WebSearchComponent", "BrowsingError", "WebSeleniumComponent"]

View File

@@ -0,0 +1,194 @@
import json
import logging
import time
from typing import Iterator
from duckduckgo_search import DDGS
from forge.agent.protocols import CommandProvider, DirectiveProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import ConfigurationError
DUCKDUCKGO_MAX_ATTEMPTS = 3
logger = logging.getLogger(__name__)
class WebSearchComponent(DirectiveProvider, CommandProvider):
"""Provides commands to search the web."""
def __init__(self, config: Config):
self.legacy_config = config
if (
not self.legacy_config.google_api_key
or not self.legacy_config.google_custom_search_engine_id
):
logger.info(
"Configure google_api_key and custom_search_engine_id "
"to use Google API search."
)
def get_resources(self) -> Iterator[str]:
yield "Internet access for searches and information gathering."
def get_commands(self) -> Iterator[Command]:
yield self.web_search
if (
self.legacy_config.google_api_key
and self.legacy_config.google_custom_search_engine_id
):
yield self.google
@command(
["web_search", "search"],
"Searches the web",
{
"query": JSONSchema(
type=JSONSchema.Type.STRING,
description="The search query",
required=True,
),
"num_results": JSONSchema(
type=JSONSchema.Type.INTEGER,
description="The number of results to return",
minimum=1,
maximum=10,
required=False,
),
},
)
def web_search(self, query: str, num_results: int = 8) -> str:
"""Return the results of a Google search
Args:
query (str): The search query.
num_results (int): The number of results to return.
Returns:
str: The results of the search.
"""
search_results = []
attempts = 0
while attempts < DUCKDUCKGO_MAX_ATTEMPTS:
if not query:
return json.dumps(search_results)
search_results = DDGS().text(query, max_results=num_results)
if search_results:
break
time.sleep(1)
attempts += 1
search_results = [
{
"title": r["title"],
"url": r["href"],
**({"exerpt": r["body"]} if r.get("body") else {}),
}
for r in search_results
]
results = ("## Search results\n") + "\n\n".join(
f"### \"{r['title']}\"\n"
f"**URL:** {r['url']} \n"
"**Excerpt:** " + (f'"{exerpt}"' if (exerpt := r.get("exerpt")) else "N/A")
for r in search_results
)
return self.safe_google_results(results)
@command(
["google"],
"Google Search",
{
"query": JSONSchema(
type=JSONSchema.Type.STRING,
description="The search query",
required=True,
),
"num_results": JSONSchema(
type=JSONSchema.Type.INTEGER,
description="The number of results to return",
minimum=1,
maximum=10,
required=False,
),
},
)
def google(self, query: str, num_results: int = 8) -> str | list[str]:
"""Return the results of a Google search using the official Google API
Args:
query (str): The search query.
num_results (int): The number of results to return.
Returns:
str: The results of the search.
"""
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
try:
# Get the Google API key and Custom Search Engine ID from the config file
api_key = self.legacy_config.google_api_key
custom_search_engine_id = self.legacy_config.google_custom_search_engine_id
# Initialize the Custom Search API service
service = build("customsearch", "v1", developerKey=api_key)
# Send the search query and retrieve the results
result = (
service.cse()
.list(q=query, cx=custom_search_engine_id, num=num_results)
.execute()
)
# Extract the search result items from the response
search_results = result.get("items", [])
# Create a list of only the URLs from the search results
search_results_links = [item["link"] for item in search_results]
except HttpError as e:
# Handle errors in the API call
error_details = json.loads(e.content.decode())
# Check if the error is related to an invalid or missing API key
if error_details.get("error", {}).get(
"code"
) == 403 and "invalid API key" in error_details.get("error", {}).get(
"message", ""
):
raise ConfigurationError(
"The provided Google API key is invalid or missing."
)
raise
# google_result can be a list or a string depending on the search results
# Return the list of search result URLs
return self.safe_google_results(search_results_links)
def safe_google_results(self, results: str | list) -> str:
"""
Return the results of a Google search in a safe format.
Args:
results (str | list): The search results.
Returns:
str: The results of the search.
"""
if isinstance(results, list):
safe_message = json.dumps(
[result.encode("utf-8", "ignore").decode("utf-8") for result in results]
)
else:
safe_message = results.encode("utf-8", "ignore").decode("utf-8")
return safe_message
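safe_google_results strips any bytes that cannot be encoded as UTF-8 before returning a JSON string. A minimal standalone illustration of that encoding step (the results list is hypothetical):
import json
results = ["https://example.com/a", "https://example.com/b\udce9"]  # hypothetical; second entry contains a stray surrogate
safe_message = json.dumps(
    [r.encode("utf-8", "ignore").decode("utf-8") for r in results]
)
# the undecodable character is dropped from the second URL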

View File

@@ -0,0 +1,374 @@
import asyncio
import logging
import re
from pathlib import Path
from sys import platform
from typing import Iterator, Type
from urllib.request import urlretrieve
from bs4 import BeautifulSoup
from selenium.common.exceptions import WebDriverException
from selenium.webdriver.chrome.options import Options as ChromeOptions
from selenium.webdriver.chrome.service import Service as ChromeDriverService
from selenium.webdriver.chrome.webdriver import WebDriver as ChromeDriver
from selenium.webdriver.common.by import By
from selenium.webdriver.edge.options import Options as EdgeOptions
from selenium.webdriver.edge.service import Service as EdgeDriverService
from selenium.webdriver.edge.webdriver import WebDriver as EdgeDriver
from selenium.webdriver.firefox.options import Options as FirefoxOptions
from selenium.webdriver.firefox.service import Service as GeckoDriverService
from selenium.webdriver.firefox.webdriver import WebDriver as FirefoxDriver
from selenium.webdriver.remote.webdriver import WebDriver
from selenium.webdriver.safari.options import Options as SafariOptions
from selenium.webdriver.safari.webdriver import WebDriver as SafariDriver
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.wait import WebDriverWait
from webdriver_manager.chrome import ChromeDriverManager
from webdriver_manager.firefox import GeckoDriverManager
from webdriver_manager.microsoft import EdgeChromiumDriverManager as EdgeDriverManager
from forge.agent.protocols import CommandProvider, DirectiveProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.content_processing.html import extract_hyperlinks, format_hyperlinks
from forge.content_processing.text import extract_information, summarize_text
from forge.llm.providers import ChatModelInfo, MultiProvider
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import CommandExecutionError, TooMuchOutputError
from forge.utils.url_validator import validate_url
logger = logging.getLogger(__name__)
FILE_DIR = Path(__file__).parent.parent
MAX_RAW_CONTENT_LENGTH = 500
LINKS_TO_RETURN = 20
BrowserOptions = ChromeOptions | EdgeOptions | FirefoxOptions | SafariOptions
class BrowsingError(CommandExecutionError):
"""An error occurred while trying to browse the page"""
class WebSeleniumComponent(DirectiveProvider, CommandProvider):
"""Provides commands to browse the web using Selenium."""
def __init__(
self,
config: Config,
llm_provider: MultiProvider,
model_info: ChatModelInfo,
):
self.legacy_config = config
self.llm_provider = llm_provider
self.model_info = model_info
def get_resources(self) -> Iterator[str]:
yield "Ability to read websites."
def get_commands(self) -> Iterator[Command]:
yield self.read_webpage
@command(
["read_webpage"],
(
"Read a webpage, and extract specific information from it."
" You must specify either topics_of_interest,"
" a question, or get_raw_content."
),
{
"url": JSONSchema(
type=JSONSchema.Type.STRING,
description="The URL to visit",
required=True,
),
"topics_of_interest": JSONSchema(
type=JSONSchema.Type.ARRAY,
items=JSONSchema(type=JSONSchema.Type.STRING),
description=(
"A list of topics about which you want to extract information "
"from the page."
),
required=False,
),
"question": JSONSchema(
type=JSONSchema.Type.STRING,
description=(
"A question you want to answer using the content of the webpage."
),
required=False,
),
"get_raw_content": JSONSchema(
type=JSONSchema.Type.BOOLEAN,
description=(
"If true, the unprocessed content of the webpage will be returned. "
"This consumes a lot of tokens, so use it with caution."
),
required=False,
),
},
)
@validate_url
async def read_webpage(
self,
url: str,
*,
topics_of_interest: list[str] = [],
get_raw_content: bool = False,
question: str = "",
) -> str:
"""Browse a website and return the answer and links to the user
Args:
url (str): The url of the website to browse
question (str): The question to answer using the content of the webpage
Returns:
str: The answer and links to the user and the webdriver
"""
driver = None
try:
driver = await self.open_page_in_browser(url, self.legacy_config)
text = self.scrape_text_with_selenium(driver)
links = self.scrape_links_with_selenium(driver, url)
return_literal_content = True
summarized = False
if not text:
return f"Website did not contain any text.\n\nLinks: {links}"
elif get_raw_content:
if (
output_tokens := self.llm_provider.count_tokens(
text, self.model_info.name
)
) > MAX_RAW_CONTENT_LENGTH:
oversize_factor = round(output_tokens / MAX_RAW_CONTENT_LENGTH, 1)
raise TooMuchOutputError(
f"Page content is {oversize_factor}x the allowed length "
"for `get_raw_content=true`"
)
return text + (f"\n\nLinks: {links}" if links else "")
else:
text = await self.summarize_webpage(
text, question or None, topics_of_interest
)
return_literal_content = bool(question)
summarized = True
# Limit links to LINKS_TO_RETURN
if len(links) > LINKS_TO_RETURN:
links = links[:LINKS_TO_RETURN]
text_fmt = f"'''{text}'''" if "\n" in text else f"'{text}'"
links_fmt = "\n".join(f"- {link}" for link in links)
return (
f"Page content{' (summary)' if summarized else ''}:"
if return_literal_content
else "Answer gathered from webpage:"
) + f" {text_fmt}\n\nLinks:\n{links_fmt}"
except WebDriverException as e:
# These errors are often quite long and include lots of context.
# Just grab the first line.
msg = e.msg.split("\n")[0] if e.msg else str(e)
if "net::" in msg:
raise BrowsingError(
"A networking error occurred while trying to load the page: %s"
% re.sub(r"^unknown error: ", "", msg)
)
raise CommandExecutionError(msg)
finally:
if driver:
driver.close()
def scrape_text_with_selenium(self, driver: WebDriver) -> str:
"""Scrape text from a browser window using selenium
Args:
driver (WebDriver): A driver object representing
the browser window to scrape
Returns:
str: the text scraped from the website
"""
# Get the HTML content directly from the browser's DOM
page_source = driver.execute_script("return document.body.outerHTML;")
soup = BeautifulSoup(page_source, "html.parser")
for script in soup(["script", "style"]):
script.extract()
text = soup.get_text()
lines = (line.strip() for line in text.splitlines())
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
text = "\n".join(chunk for chunk in chunks if chunk)
return text
def scrape_links_with_selenium(self, driver: WebDriver, base_url: str) -> list[str]:
"""Scrape links from a website using selenium
Args:
driver (WebDriver): A driver object representing
the browser window to scrape
base_url (str): The base URL to use for resolving relative links
Returns:
List[str]: The links scraped from the website
"""
page_source = driver.page_source
soup = BeautifulSoup(page_source, "html.parser")
for script in soup(["script", "style"]):
script.extract()
hyperlinks = extract_hyperlinks(soup, base_url)
return format_hyperlinks(hyperlinks)
async def open_page_in_browser(self, url: str, config: Config) -> WebDriver:
"""Open a browser window and load a web page using Selenium
Params:
url (str): The URL of the page to load
config (Config): The applicable application configuration
Returns:
driver (WebDriver): A driver object representing
the browser window to scrape
"""
logging.getLogger("selenium").setLevel(logging.CRITICAL)
options_available: dict[str, Type[BrowserOptions]] = {
"chrome": ChromeOptions,
"edge": EdgeOptions,
"firefox": FirefoxOptions,
"safari": SafariOptions,
}
options: BrowserOptions = options_available[config.selenium_web_browser]()
options.add_argument(f"user-agent={config.user_agent}")
if isinstance(options, FirefoxOptions):
if config.selenium_headless:
options.headless = True # type: ignore
options.add_argument("--disable-gpu")
driver = FirefoxDriver(
service=GeckoDriverService(GeckoDriverManager().install()),
options=options,
)
elif isinstance(options, EdgeOptions):
driver = EdgeDriver(
service=EdgeDriverService(EdgeDriverManager().install()),
options=options,
)
elif isinstance(options, SafariOptions):
# Requires a bit more setup on the user's end.
# See https://developer.apple.com/documentation/webkit/testing_with_webdriver_in_safari # noqa: E501
driver = SafariDriver(options=options)
elif isinstance(options, ChromeOptions):
if platform == "linux" or platform == "linux2":
options.add_argument("--disable-dev-shm-usage")
options.add_argument("--remote-debugging-port=9222")
options.add_argument("--no-sandbox")
if config.selenium_headless:
options.add_argument("--headless=new")
options.add_argument("--disable-gpu")
self._sideload_chrome_extensions(
options, config.app_data_dir / "assets" / "crx"
)
if (chromium_driver_path := Path("/usr/bin/chromedriver")).exists():
chrome_service = ChromeDriverService(str(chromium_driver_path))
else:
try:
chrome_driver = ChromeDriverManager().install()
except AttributeError as e:
if "'NoneType' object has no attribute 'split'" in str(e):
# https://github.com/SergeyPirogov/webdriver_manager/issues/649
logger.critical(
"Connecting to browser failed:"
" is Chrome or Chromium installed?"
)
raise
chrome_service = ChromeDriverService(chrome_driver)
driver = ChromeDriver(service=chrome_service, options=options)
driver.get(url)
# Wait for page to be ready, sleep 2 seconds, wait again until page ready.
# This allows the cookiewall squasher time to get rid of cookie walls.
WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.TAG_NAME, "body"))
)
await asyncio.sleep(2)
WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.TAG_NAME, "body"))
)
return driver
def _sideload_chrome_extensions(
self, options: ChromeOptions, dl_folder: Path
) -> None:
crx_download_url_template = "https://clients2.google.com/service/update2/crx?response=redirect&prodversion=49.0&acceptformat=crx3&x=id%3D{crx_id}%26installsource%3Dondemand%26uc" # noqa
cookiewall_squasher_crx_id = "edibdbjcniadpccecjdfdjjppcpchdlm"
adblocker_crx_id = "cjpalhdlnbpafiamejdnhcphjbkeiagm"
# Make sure the target folder exists
dl_folder.mkdir(parents=True, exist_ok=True)
for crx_id in (cookiewall_squasher_crx_id, adblocker_crx_id):
crx_path = dl_folder / f"{crx_id}.crx"
if not crx_path.exists():
logger.debug(f"Downloading CRX {crx_id}...")
crx_download_url = crx_download_url_template.format(crx_id=crx_id)
urlretrieve(crx_download_url, crx_path)
logger.debug(f"Downloaded {crx_path.name}")
options.add_extension(str(crx_path))
async def summarize_webpage(
self,
text: str,
question: str | None,
topics_of_interest: list[str],
) -> str:
"""Summarize text using the OpenAI API
Args:
url (str): The url of the text
text (str): The text to summarize
question (str): The question to ask the model
driver (WebDriver): The webdriver to use to scroll the page
Returns:
str: The summary of the text
"""
if not text:
raise ValueError("No text to summarize")
text_length = len(text)
logger.debug(f"Web page content length: {text_length} characters")
result = None
information = None
if topics_of_interest:
information = await extract_information(
text,
topics_of_interest=topics_of_interest,
llm_provider=self.llm_provider,
config=self.legacy_config,
)
return "\n".join(f"* {i}" for i in information)
else:
result, _ = await summarize_text(
text,
question=question,
llm_provider=self.llm_provider,
config=self.legacy_config,
)
return result

View File

@@ -0,0 +1,14 @@
"""
This module contains configuration models and helpers for AutoGPT Forge.
"""
from .ai_directives import AIDirectives
from .ai_profile import AIProfile
from .config import Config, ConfigBuilder, assert_config_has_openai_api_key
__all__ = [
"assert_config_has_openai_api_key",
"AIProfile",
"AIDirectives",
"Config",
"ConfigBuilder",
]

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
import logging
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class AIDirectives(BaseModel):
"""An object that contains the basic directives for the AI prompt.
Attributes:
constraints (list): A list of constraints that the AI should adhere to.
resources (list): A list of resources that the AI can utilize.
best_practices (list): A list of best practices that the AI should follow.
"""
resources: list[str] = Field(default_factory=list)
constraints: list[str] = Field(default_factory=list)
best_practices: list[str] = Field(default_factory=list)
def __add__(self, other: AIDirectives) -> AIDirectives:
return AIDirectives(
resources=self.resources + other.resources,
constraints=self.constraints + other.constraints,
best_practices=self.best_practices + other.best_practices,
).copy(deep=True)
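Adding two AIDirectives objects concatenates their lists and returns a deep copy, so directives from multiple sources can be merged without mutating either side. A usage sketch, assuming the forge package from this diff is installed:
from forge.config import AIDirectives
base = AIDirectives(constraints=["Stay within the workspace."])
extra = AIDirectives(best_practices=["Prefer small, verifiable steps."])
merged = base + extra  # union of resources, constraints and best_practices
print(merged.constraints, merged.best_practices)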

View File

@@ -0,0 +1,27 @@
from pydantic import BaseModel, Field
DEFAULT_AI_NAME = "AutoGPT"
DEFAULT_AI_ROLE = (
"a seasoned digital assistant: "
"capable, intelligent, considerate and assertive. "
"You have extensive research and development skills, and you don't shy "
"away from writing some code to solve a problem. "
"You are pragmatic and make the most out of the tools available to you."
)
class AIProfile(BaseModel):
"""
Object to hold the AI's personality.
Attributes:
ai_name (str): The name of the AI.
ai_role (str): The description of the AI's role.
ai_goals (list): The list of objectives the AI is supposed to complete.
"""
ai_name: str = DEFAULT_AI_NAME
ai_role: str = DEFAULT_AI_ROLE
"""`ai_role` should fit in the following format: `You are {ai_name}, {ai_role}`"""
ai_goals: list[str] = Field(default_factory=list[str])
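Per the note above, ai_role is written to complete the sentence `You are {ai_name}, {ai_role}`. A usage sketch, assuming the forge package from this diff is installed (the goal is hypothetical):
from forge.config import AIProfile
profile = AIProfile(ai_goals=["Summarize the weekly report"])  # hypothetical goal
system_intro = f"You are {profile.ai_name}, {profile.ai_role}"
# -> "You are AutoGPT, a seasoned digital assistant: ..."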

View File

@@ -0,0 +1,266 @@
"""Configuration class to store the state of bools for different scripts access."""
from __future__ import annotations
import logging
import os
import re
from pathlib import Path
from typing import Any, Optional, Union
import click
from colorama import Fore
from pydantic import SecretStr, validator
import forge
from forge.file_storage import FileStorageBackendName
from forge.llm.providers import CHAT_MODELS, ModelName
from forge.llm.providers.openai import OpenAICredentials, OpenAIModelName
from forge.logging.config import LoggingConfig
from forge.models.config import Configurable, SystemSettings, UserConfigurable
from forge.speech.say import TTSConfig
logger = logging.getLogger(__name__)
PROJECT_ROOT = Path(forge.__file__).parent.parent
AZURE_CONFIG_FILE = Path("azure.yaml")
GPT_4_MODEL = OpenAIModelName.GPT4
GPT_3_MODEL = OpenAIModelName.GPT3
class Config(SystemSettings, arbitrary_types_allowed=True):
name: str = "Auto-GPT configuration"
description: str = "Default configuration for the Auto-GPT application."
########################
# Application Settings #
########################
project_root: Path = PROJECT_ROOT
app_data_dir: Path = project_root / "data"
skip_news: bool = False
skip_reprompt: bool = False
authorise_key: str = UserConfigurable(default="y", from_env="AUTHORISE_COMMAND_KEY")
exit_key: str = UserConfigurable(default="n", from_env="EXIT_KEY")
noninteractive_mode: bool = False
# TTS configuration
logging: LoggingConfig = LoggingConfig()
tts_config: TTSConfig = TTSConfig()
# File storage
file_storage_backend: FileStorageBackendName = UserConfigurable(
default=FileStorageBackendName.LOCAL, from_env="FILE_STORAGE_BACKEND"
)
##########################
# Agent Control Settings #
##########################
# Model configuration
fast_llm: ModelName = UserConfigurable(
default=OpenAIModelName.GPT3,
from_env="FAST_LLM",
)
smart_llm: ModelName = UserConfigurable(
default=OpenAIModelName.GPT4_TURBO,
from_env="SMART_LLM",
)
temperature: float = UserConfigurable(default=0, from_env="TEMPERATURE")
openai_functions: bool = UserConfigurable(
default=False, from_env=lambda: os.getenv("OPENAI_FUNCTIONS", "False") == "True"
)
embedding_model: str = UserConfigurable(
default="text-embedding-3-small", from_env="EMBEDDING_MODEL"
)
browse_spacy_language_model: str = UserConfigurable(
default="en_core_web_sm", from_env="BROWSE_SPACY_LANGUAGE_MODEL"
)
# Run loop configuration
continuous_mode: bool = False
continuous_limit: int = 0
############
# Commands #
############
# General
disabled_commands: list[str] = UserConfigurable(
default_factory=list,
from_env=lambda: _safe_split(os.getenv("DISABLED_COMMANDS")),
)
# File ops
restrict_to_workspace: bool = UserConfigurable(
default=True,
from_env=lambda: os.getenv("RESTRICT_TO_WORKSPACE", "True") == "True",
)
allow_downloads: bool = False
# Shell commands
shell_command_control: str = UserConfigurable(
default="denylist", from_env="SHELL_COMMAND_CONTROL"
)
execute_local_commands: bool = UserConfigurable(
default=False,
from_env=lambda: os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True",
)
shell_denylist: list[str] = UserConfigurable(
default_factory=lambda: ["sudo", "su"],
from_env=lambda: _safe_split(
os.getenv("SHELL_DENYLIST", os.getenv("DENY_COMMANDS"))
),
)
shell_allowlist: list[str] = UserConfigurable(
default_factory=list,
from_env=lambda: _safe_split(
os.getenv("SHELL_ALLOWLIST", os.getenv("ALLOW_COMMANDS"))
),
)
# Text to image
image_provider: Optional[str] = UserConfigurable(from_env="IMAGE_PROVIDER")
huggingface_image_model: str = UserConfigurable(
default="CompVis/stable-diffusion-v1-4", from_env="HUGGINGFACE_IMAGE_MODEL"
)
sd_webui_url: Optional[str] = UserConfigurable(
default="http://localhost:7860", from_env="SD_WEBUI_URL"
)
image_size: int = UserConfigurable(default=256, from_env="IMAGE_SIZE")
# Audio to text
audio_to_text_provider: str = UserConfigurable(
default="huggingface", from_env="AUDIO_TO_TEXT_PROVIDER"
)
huggingface_audio_to_text_model: Optional[str] = UserConfigurable(
from_env="HUGGINGFACE_AUDIO_TO_TEXT_MODEL"
)
# Web browsing
selenium_web_browser: str = UserConfigurable("chrome", from_env="USE_WEB_BROWSER")
selenium_headless: bool = UserConfigurable(
default=True, from_env=lambda: os.getenv("HEADLESS_BROWSER", "True") == "True"
)
user_agent: str = UserConfigurable(
default="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36", # noqa: E501
from_env="USER_AGENT",
)
###############
# Credentials #
###############
# OpenAI
openai_credentials: Optional[OpenAICredentials] = None
azure_config_file: Optional[Path] = UserConfigurable(
default=AZURE_CONFIG_FILE, from_env="AZURE_CONFIG_FILE"
)
# Github
github_api_key: Optional[str] = UserConfigurable(from_env="GITHUB_API_KEY")
github_username: Optional[str] = UserConfigurable(from_env="GITHUB_USERNAME")
# Google
google_api_key: Optional[str] = UserConfigurable(from_env="GOOGLE_API_KEY")
google_custom_search_engine_id: Optional[str] = UserConfigurable(
from_env="GOOGLE_CUSTOM_SEARCH_ENGINE_ID",
)
# Huggingface
huggingface_api_token: Optional[str] = UserConfigurable(
from_env="HUGGINGFACE_API_TOKEN"
)
# Stable Diffusion
sd_webui_auth: Optional[str] = UserConfigurable(from_env="SD_WEBUI_AUTH")
@validator("openai_functions")
def validate_openai_functions(cls, v: bool, values: dict[str, Any]):
if v:
smart_llm = values["smart_llm"]
assert CHAT_MODELS[smart_llm].has_function_call_api, (
f"Model {smart_llm} does not support tool calling. "
"Please disable OPENAI_FUNCTIONS or choose a suitable model."
)
return v
class ConfigBuilder(Configurable[Config]):
default_settings = Config()
@classmethod
def build_config_from_env(cls, project_root: Path = PROJECT_ROOT) -> Config:
"""Initialize the Config class"""
config = cls.build_agent_configuration()
config.project_root = project_root
# Make relative paths absolute
for k in {
"azure_config_file", # TODO: move from project root
}:
setattr(config, k, project_root / getattr(config, k))
if (
config.openai_credentials
and config.openai_credentials.api_type == SecretStr("azure")
and (config_file := config.azure_config_file)
):
config.openai_credentials.load_azure_config(config_file)
return config
def assert_config_has_openai_api_key(config: Config) -> None:
"""Check if the OpenAI API key is set in config.py or as an environment variable."""
key_pattern = r"^sk-(proj-)?\w{48}"
openai_api_key = (
config.openai_credentials.api_key.get_secret_value()
if config.openai_credentials
else ""
)
# If there's no credentials or empty API key, prompt the user to set it
if not openai_api_key:
logger.error(
"Please set your OpenAI API key in .env or as an environment variable."
)
logger.info(
"You can get your key from https://platform.openai.com/account/api-keys"
)
openai_api_key = click.prompt(
"Please enter your OpenAI API key if you have it",
default="",
show_default=False,
)
openai_api_key = openai_api_key.strip()
if re.search(key_pattern, openai_api_key):
os.environ["OPENAI_API_KEY"] = openai_api_key
if config.openai_credentials:
config.openai_credentials.api_key = SecretStr(openai_api_key)
else:
config.openai_credentials = OpenAICredentials(
api_key=SecretStr(openai_api_key)
)
print("OpenAI API key successfully set!")
print(
f"{Fore.YELLOW}NOTE: The API key you've set is only temporary. "
f"For longer sessions, please set it in the .env file{Fore.RESET}"
)
else:
print(f"{Fore.RED}Invalid OpenAI API key{Fore.RESET}")
exit(1)
# If key is set, but it looks invalid
elif not re.search(key_pattern, openai_api_key):
logger.error(
"Invalid OpenAI API key! "
"Please set your OpenAI API key in .env or as an environment variable."
)
logger.info(
"You can get your key from https://platform.openai.com/account/api-keys"
)
exit(1)
def _safe_split(s: Union[str, None], sep: str = ",") -> list[str]:
"""Split a string by a separator. Return an empty list if the string is None."""
if s is None:
return []
return s.split(sep)
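A quick usage sketch for _safe_split(), which the list-type settings above use to parse comma-separated environment variables (assumes the module path forge.config.config from this diff):
from forge.config.config import _safe_split
assert _safe_split(None) == []
assert _safe_split("sudo,su") == ["sudo", "su"]
assert _safe_split("a|b", sep="|") == ["a", "b"]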

8
forge/forge/conftest.py Normal file
View File

@@ -0,0 +1,8 @@
from pathlib import Path
import pytest
@pytest.fixture()
def test_workspace(tmp_path: Path) -> Path:
return tmp_path

View File

@@ -0,0 +1,33 @@
"""HTML processing functions"""
from __future__ import annotations
from bs4 import BeautifulSoup
from requests.compat import urljoin
def extract_hyperlinks(soup: BeautifulSoup, base_url: str) -> list[tuple[str, str]]:
"""Extract hyperlinks from a BeautifulSoup object
Args:
soup (BeautifulSoup): The BeautifulSoup object
base_url (str): The base URL
Returns:
List[Tuple[str, str]]: The extracted hyperlinks
"""
return [
(link.text, urljoin(base_url, link["href"]))
for link in soup.find_all("a", href=True)
]
def format_hyperlinks(hyperlinks: list[tuple[str, str]]) -> list[str]:
"""Format hyperlinks to be displayed to the user
Args:
hyperlinks (List[Tuple[str, str]]): The hyperlinks to format
Returns:
List[str]: The formatted hyperlinks
"""
return [f"{link_text.strip()} ({link_url})" for link_text, link_url in hyperlinks]

View File

@@ -0,0 +1,317 @@
"""Text processing functions"""
from __future__ import annotations
import logging
import math
from typing import TYPE_CHECKING, Iterator, Optional, TypeVar
import spacy
if TYPE_CHECKING:
from forge.config.config import Config
from forge.json.parsing import extract_list_from_json
from forge.llm.prompting import ChatPrompt
from forge.llm.providers import ChatMessage, ModelTokenizer, MultiProvider
logger = logging.getLogger(__name__)
T = TypeVar("T")
def batch(
sequence: list[T], max_batch_length: int, overlap: int = 0
) -> Iterator[list[T]]:
"""
Batch data from iterable into slices of length N. The last batch may be shorter.
Example: `batch('ABCDEFGHIJ', 3)` --> `ABC DEF GHI J`
"""
if max_batch_length < 1:
raise ValueError("n must be at least one")
for i in range(0, len(sequence), max_batch_length - overlap):
yield sequence[i : i + max_batch_length]
def chunk_content(
content: str,
max_chunk_length: int,
tokenizer: ModelTokenizer,
with_overlap: bool = True,
) -> Iterator[tuple[str, int]]:
"""Split content into chunks of approximately equal token length."""
MAX_OVERLAP = 200 # limit overlap to save tokens
tokenized_text = tokenizer.encode(content)
total_length = len(tokenized_text)
n_chunks = math.ceil(total_length / max_chunk_length)
chunk_length = math.ceil(total_length / n_chunks)
overlap = min(max_chunk_length - chunk_length, MAX_OVERLAP) if with_overlap else 0
for token_batch in batch(tokenized_text, chunk_length + overlap, overlap):
yield tokenizer.decode(token_batch), len(token_batch)
async def summarize_text(
text: str,
llm_provider: MultiProvider,
config: Config,
question: Optional[str] = None,
instruction: Optional[str] = None,
) -> tuple[str, list[tuple[str, str]]]:
if question:
if instruction:
raise ValueError(
"Parameters 'question' and 'instructions' cannot both be set"
)
instruction = (
f'From the text, answer the question: "{question}". '
"If the answer is not in the text, indicate this clearly "
"and concisely state why the text is not suitable to answer the question."
)
elif not instruction:
instruction = (
"Summarize or describe the text clearly and concisely, "
"whichever seems more appropriate."
)
return await _process_text( # type: ignore
text=text,
instruction=instruction,
llm_provider=llm_provider,
config=config,
)
async def extract_information(
source_text: str,
topics_of_interest: list[str],
llm_provider: MultiProvider,
config: Config,
) -> list[str]:
fmt_topics_list = "\n".join(f"* {topic}." for topic in topics_of_interest)
instruction = (
"Extract relevant pieces of information about the following topics:\n"
f"{fmt_topics_list}\n"
"Reword pieces of information if needed to make them self-explanatory. "
"Be concise.\n\n"
"Respond with an `Array<string>` in JSON format AND NOTHING ELSE. "
'If the text contains no relevant information, return "[]".'
)
return await _process_text( # type: ignore
text=source_text,
instruction=instruction,
output_type=list[str],
llm_provider=llm_provider,
config=config,
)
async def _process_text(
text: str,
instruction: str,
llm_provider: MultiProvider,
config: Config,
output_type: type[str | list[str]] = str,
) -> tuple[str, list[tuple[str, str]]] | list[str]:
"""Process text using the OpenAI API for summarization or information extraction
Params:
text (str): The text to process.
instruction (str): Additional instruction for processing.
llm_provider: LLM provider to use.
config (Config): The global application config.
output_type: `str` for summaries or `list[str]` for piece-wise info extraction.
Returns:
For summarization: tuple[str, None | list[(summary, chunk)]]
For piece-wise information extraction: list[str]
"""
if not text.strip():
raise ValueError("No content")
model = config.fast_llm
text_tlength = llm_provider.count_tokens(text, model)
logger.debug(f"Text length: {text_tlength} tokens")
max_result_tokens = 500
max_chunk_length = llm_provider.get_token_limit(model) - max_result_tokens - 50
logger.debug(f"Max chunk length: {max_chunk_length} tokens")
if text_tlength < max_chunk_length:
prompt = ChatPrompt(
messages=[
ChatMessage.system(
"The user is going to give you a text enclosed in triple quotes. "
f"{instruction}"
),
ChatMessage.user(f'"""{text}"""'),
]
)
logger.debug(f"PROCESSING:\n{prompt}")
response = await llm_provider.create_chat_completion(
model_prompt=prompt.messages,
model_name=model,
temperature=0.5,
max_output_tokens=max_result_tokens,
completion_parser=lambda s: (
extract_list_from_json(s.content) if output_type is not str else None
),
)
if isinstance(response.parsed_result, list):
logger.debug(f"Raw LLM response: {repr(response.response.content)}")
fmt_result_bullet_list = "\n".join(f"* {r}" for r in response.parsed_result)
logger.debug(
f"\n{'-'*11} EXTRACTION RESULT {'-'*12}\n"
f"{fmt_result_bullet_list}\n"
f"{'-'*42}\n"
)
return response.parsed_result
else:
summary = response.response.content
logger.debug(f"\n{'-'*16} SUMMARY {'-'*17}\n{summary}\n{'-'*42}\n")
return summary.strip(), [(summary, text)]
else:
chunks = list(
split_text(
text,
config=config,
max_chunk_length=max_chunk_length,
tokenizer=llm_provider.get_tokenizer(model),
)
)
processed_results = []
for i, (chunk, _) in enumerate(chunks):
logger.info(f"Processing chunk {i + 1} / {len(chunks)}")
chunk_result = await _process_text(
text=chunk,
instruction=instruction,
output_type=output_type,
llm_provider=llm_provider,
config=config,
)
processed_results.extend(
chunk_result if output_type == list[str] else [chunk_result]
)
if output_type == list[str]:
return processed_results
else:
summary, _ = await _process_text(
"\n\n".join([result[0] for result in processed_results]),
instruction=(
"The text consists of multiple partial summaries. "
"Combine these partial summaries into one."
),
llm_provider=llm_provider,
config=config,
)
return summary.strip(), [
(processed_results[i], chunks[i][0]) for i in range(0, len(chunks))
]
def split_text(
text: str,
config: Config,
max_chunk_length: int,
tokenizer: ModelTokenizer,
with_overlap: bool = True,
) -> Iterator[tuple[str, int]]:
"""
Split text into chunks of sentences, with each chunk not exceeding the max length.
Args:
text (str): The text to split.
config (Config): Config object containing the Spacy model setting.
max_chunk_length (int, optional): The maximum length of a chunk.
tokenizer (ModelTokenizer): Tokenizer to use for determining chunk length.
with_overlap (bool, optional): Whether to allow overlap between chunks.
Yields:
str: The next chunk of text
Raises:
ValueError: when a sentence is longer than the maximum length
"""
text_length = len(tokenizer.encode(text))
if text_length < max_chunk_length:
yield text, text_length
return
n_chunks = math.ceil(text_length / max_chunk_length)
target_chunk_length = math.ceil(text_length / n_chunks)
nlp: spacy.language.Language = spacy.load(config.browse_spacy_language_model)
nlp.add_pipe("sentencizer")
doc = nlp(text)
sentences = [sentence.text.strip() for sentence in doc.sents]
current_chunk: list[str] = []
current_chunk_length = 0
last_sentence = None
last_sentence_length = 0
i = 0
while i < len(sentences):
sentence = sentences[i]
sentence_length = len(tokenizer.encode(sentence))
expected_chunk_length = current_chunk_length + 1 + sentence_length
if (
expected_chunk_length < max_chunk_length
# try to create chunks of approximately equal size
and expected_chunk_length - (sentence_length / 2) < target_chunk_length
):
current_chunk.append(sentence)
current_chunk_length = expected_chunk_length
elif sentence_length < max_chunk_length:
if last_sentence:
yield " ".join(current_chunk), current_chunk_length
current_chunk = []
current_chunk_length = 0
if with_overlap:
overlap_max_length = max_chunk_length - sentence_length - 1
if last_sentence_length < overlap_max_length:
current_chunk += [last_sentence]
current_chunk_length += last_sentence_length + 1
elif overlap_max_length > 5:
# add as much from the end of the last sentence as fits
current_chunk += [
list(
chunk_content(
content=last_sentence,
max_chunk_length=overlap_max_length,
tokenizer=tokenizer,
)
).pop()[0],
]
current_chunk_length += overlap_max_length + 1
current_chunk += [sentence]
current_chunk_length += sentence_length
else: # sentence longer than maximum length -> chop up and try again
sentences[i : i + 1] = [
chunk
for chunk, _ in chunk_content(sentence, target_chunk_length, tokenizer)
]
continue
i += 1
last_sentence = sentence
last_sentence_length = sentence_length
if current_chunk:
yield " ".join(current_chunk), current_chunk_length

View File

@@ -0,0 +1,37 @@
import enum
from pathlib import Path
from .base import FileStorage
class FileStorageBackendName(str, enum.Enum):
LOCAL = "local"
GCS = "gcs"
S3 = "s3"
def get_storage(
backend: FileStorageBackendName,
root_path: Path = Path("."),
restrict_to_root: bool = True,
) -> FileStorage:
match backend:
case FileStorageBackendName.LOCAL:
from .local import FileStorageConfiguration, LocalFileStorage
config = FileStorageConfiguration.from_env()
config.root = root_path
config.restrict_to_root = restrict_to_root
return LocalFileStorage(config)
case FileStorageBackendName.S3:
from .s3 import S3FileStorage, S3FileStorageConfiguration
config = S3FileStorageConfiguration.from_env()
config.root = root_path
return S3FileStorage(config)
case FileStorageBackendName.GCS:
from .gcs import GCSFileStorage, GCSFileStorageConfiguration
config = GCSFileStorageConfiguration.from_env()
config.root = root_path
return GCSFileStorage(config)
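A usage sketch for get_storage() above, assuming the forge package from this diff is installed; the root path is hypothetical and the LOCAL backend reads its FileStorageConfiguration from the environment:
from pathlib import Path
from forge.file_storage import FileStorageBackendName, get_storage
storage = get_storage(
    FileStorageBackendName.LOCAL,
    root_path=Path("./data/agents/my-agent"),  # hypothetical root
    restrict_to_root=True,
)
storage.initialize()  # per the base class docstring, brings the backend to a ready-to-use state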

View File

@@ -0,0 +1,283 @@
"""
The FileStorage class provides an interface for interacting with a file storage.
"""
from __future__ import annotations
import asyncio
import logging
import os
import shutil
import tempfile
from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Any, BinaryIO, Callable, Generator, Literal, TextIO, overload
from watchdog.events import FileSystemEvent, FileSystemEventHandler
from watchdog.observers import Observer
from forge.models.config import SystemConfiguration
logger = logging.getLogger(__name__)
class FileStorageConfiguration(SystemConfiguration):
restrict_to_root: bool = True
root: Path = Path("/")
class FileStorage(ABC):
"""A class that represents a file storage."""
on_write_file: Callable[[Path], Any] | None = None
"""
Event hook, executed after writing a file.
Params:
Path: The path of the file that was written, relative to the storage root.
"""
@property
@abstractmethod
def root(self) -> Path:
"""The root path of the file storage."""
@property
@abstractmethod
def restrict_to_root(self) -> bool:
"""Whether to restrict file access to within the storage's root path."""
@property
@abstractmethod
def is_local(self) -> bool:
"""Whether the storage is local (i.e. on the same machine, not cloud-based)."""
@abstractmethod
def initialize(self) -> None:
"""
Calling `initialize()` should bring the storage to a ready-to-use state.
For example, it can create the resource in which files will be stored, if it
doesn't exist yet. E.g. a folder on disk, or an S3 Bucket.
"""
@overload
@abstractmethod
def open_file(
self,
path: str | Path,
mode: Literal["r", "w"] = "r",
binary: Literal[False] = False,
) -> TextIO:
"""Returns a readable text file-like object representing the file."""
@overload
@abstractmethod
def open_file(
self, path: str | Path, mode: Literal["r", "w"], binary: Literal[True]
) -> BinaryIO:
"""Returns a binary file-like object representing the file."""
@overload
@abstractmethod
def open_file(self, path: str | Path, *, binary: Literal[True]) -> BinaryIO:
"""Returns a readable binary file-like object representing the file."""
@overload
@abstractmethod
def open_file(
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
) -> TextIO | BinaryIO:
"""Returns a file-like object representing the file."""
@overload
@abstractmethod
def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
"""Read a file in the storage as text."""
...
@overload
@abstractmethod
def read_file(self, path: str | Path, binary: Literal[True]) -> bytes:
"""Read a file in the storage as binary."""
...
@overload
@abstractmethod
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
"""Read a file in the storage."""
...
@abstractmethod
async def write_file(self, path: str | Path, content: str | bytes) -> None:
"""Write to a file in the storage."""
@abstractmethod
def list_files(self, path: str | Path = ".") -> list[Path]:
"""List all files (recursively) in a directory in the storage."""
@abstractmethod
def list_folders(
self, path: str | Path = ".", recursive: bool = False
) -> list[Path]:
"""List all folders in a directory in the storage."""
@abstractmethod
def delete_file(self, path: str | Path) -> None:
"""Delete a file in the storage."""
@abstractmethod
def delete_dir(self, path: str | Path) -> None:
"""Delete an empty folder in the storage."""
@abstractmethod
def exists(self, path: str | Path) -> bool:
"""Check if a file or folder exists in the storage."""
@abstractmethod
def rename(self, old_path: str | Path, new_path: str | Path) -> None:
"""Rename a file or folder in the storage."""
@abstractmethod
def copy(self, source: str | Path, destination: str | Path) -> None:
"""Copy a file or folder with all contents in the storage."""
@abstractmethod
def make_dir(self, path: str | Path) -> None:
"""Create a directory in the storage if doesn't exist."""
@abstractmethod
def clone_with_subroot(self, subroot: str | Path) -> FileStorage:
"""Create a new FileStorage with a subroot of the current storage."""
def get_path(self, relative_path: str | Path) -> Path:
"""Get the full path for an item in the storage.
Parameters:
relative_path: The relative path to resolve in the storage.
Returns:
Path: The resolved path relative to the storage.
"""
return self._sanitize_path(relative_path)
@contextmanager
def mount(self, path: str | Path = ".") -> Generator[Path, Any, None]:
"""Mount the file storage and provide a local path."""
local_path = tempfile.mkdtemp(dir=path)
observer = Observer()
try:
# Copy all files to the local directory
files = self.list_files()
for file in files:
file_path = Path(local_path) / file
file_path.parent.mkdir(parents=True, exist_ok=True)
content = self.read_file(file, binary=True)
file_path.write_bytes(content)
# Sync changes
event_handler = FileSyncHandler(self, local_path)
observer.schedule(event_handler, local_path, recursive=True)
observer.start()
yield Path(local_path)
finally:
observer.stop()
observer.join()
shutil.rmtree(local_path)
def _sanitize_path(
self,
path: str | Path,
) -> Path:
"""Resolve the relative path within the given root if possible.
Parameters:
relative_path: The relative path to resolve.
Returns:
Path: The resolved path.
Raises:
ValueError: If the path is absolute and a root is provided.
ValueError: If the path is outside the root and the root is restricted.
"""
# Posix systems disallow null bytes in paths. Windows is agnostic about it.
# Do an explicit check here for all sorts of null byte representations.
if "\0" in str(path):
raise ValueError("Embedded null byte")
logger.debug(f"Resolving path '{path}' in storage '{self.root}'")
relative_path = Path(path)
# Allow absolute paths if they are contained in the storage.
if (
relative_path.is_absolute()
and self.restrict_to_root
and not relative_path.is_relative_to(self.root)
):
raise ValueError(
f"Attempted to access absolute path '{relative_path}' "
f"in storage '{self.root}'"
)
full_path = self.root / relative_path
if self.is_local:
full_path = full_path.resolve()
else:
full_path = Path(os.path.normpath(full_path))
logger.debug(f"Joined paths as '{full_path}'")
if self.restrict_to_root and not full_path.is_relative_to(self.root):
raise ValueError(
f"Attempted to access path '{full_path}' "
f"outside of storage '{self.root}'."
)
return full_path
class FileSyncHandler(FileSystemEventHandler):
def __init__(self, storage: FileStorage, path: str | Path = "."):
self.storage = storage
self.path = Path(path)
def on_modified(self, event: FileSystemEvent):
if event.is_directory:
return
file_path = Path(event.src_path).relative_to(self.path)
content = file_path.read_bytes()
# Must execute write_file synchronously because the hook is synchronous
# TODO: Schedule write operation using asyncio.create_task (non-blocking)
asyncio.get_event_loop().run_until_complete(
self.storage.write_file(file_path, content)
)
def on_created(self, event: FileSystemEvent):
if event.is_directory:
self.storage.make_dir(event.src_path)
return
file_path = Path(event.src_path).relative_to(self.path)
content = file_path.read_bytes()
# Must execute write_file synchronously because the hook is synchronous
# TODO: Schedule write operation using asyncio.create_task (non-blocking)
asyncio.get_event_loop().run_until_complete(
self.storage.write_file(file_path, content)
)
def on_deleted(self, event: FileSystemEvent):
if event.is_directory:
self.storage.delete_dir(event.src_path)
return
file_path = event.src_path
self.storage.delete_file(file_path)
def on_moved(self, event: FileSystemEvent):
self.storage.rename(event.src_path, event.dest_path)
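A usage sketch for FileStorage.mount() above: files are copied into a temporary local directory and a watchdog observer (FileSyncHandler) writes local changes back to the storage. Assumes the forge package from this diff, with storage construction as in the get_storage() sketch above:
from pathlib import Path
from forge.file_storage import FileStorageBackendName, get_storage
storage = get_storage(FileStorageBackendName.LOCAL, root_path=Path("./workspace"))
storage.initialize()
with storage.mount() as local_path:
    (local_path / "notes.txt").write_text("hello")  # synced back via the event handler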

View File

@@ -0,0 +1,267 @@
"""
The GCSFileStorage class provides an interface for interacting with a file storage, and
stores the files in a Google Cloud Storage bucket.
"""
from __future__ import annotations
import inspect
import logging
from io import TextIOWrapper
from pathlib import Path
from typing import Literal, overload
from google.cloud import storage
from google.cloud.exceptions import NotFound
from google.cloud.storage.fileio import BlobReader, BlobWriter
from forge.models.config import UserConfigurable
from .base import FileStorage, FileStorageConfiguration
logger = logging.getLogger(__name__)
class GCSFileStorageConfiguration(FileStorageConfiguration):
bucket: str = UserConfigurable("autogpt", from_env="STORAGE_BUCKET")
class GCSFileStorage(FileStorage):
"""A class that represents a Google Cloud Storage."""
_bucket: storage.Bucket
def __init__(self, config: GCSFileStorageConfiguration):
self._bucket_name = config.bucket
self._root = config.root
# Add / at the beginning of the root path
if not self._root.is_absolute():
self._root = Path("/").joinpath(self._root)
self._gcs = storage.Client()
super().__init__()
@property
def root(self) -> Path:
"""The root directory of the file storage."""
return self._root
@property
def restrict_to_root(self) -> bool:
"""Whether to restrict generated paths to the root."""
return True
@property
def is_local(self) -> bool:
"""Whether the storage is local (i.e. on the same machine, not cloud-based)."""
return False
def initialize(self) -> None:
logger.debug(f"Initializing {repr(self)}...")
try:
self._bucket = self._gcs.get_bucket(self._bucket_name)
except NotFound:
logger.info(f"Bucket '{self._bucket_name}' does not exist; creating it...")
self._bucket = self._gcs.create_bucket(self._bucket_name)
def get_path(self, relative_path: str | Path) -> Path:
# We set GCS root with "/" at the beginning
# but relative_to("/") will remove it
# because we don't actually want it in the storage filenames
return super().get_path(relative_path).relative_to("/")
def _get_blob(self, path: str | Path) -> storage.Blob:
path = self.get_path(path)
return self._bucket.blob(str(path))
@overload
def open_file(
self,
path: str | Path,
mode: Literal["r", "w"] = "r",
binary: Literal[False] = False,
) -> TextIOWrapper:
...
@overload
def open_file(
self, path: str | Path, mode: Literal["r"], binary: Literal[True]
) -> BlobReader:
...
@overload
def open_file(
self, path: str | Path, mode: Literal["w"], binary: Literal[True]
) -> BlobWriter:
...
@overload
def open_file(
self, path: str | Path, mode: Literal["r", "w"], binary: Literal[True]
) -> BlobWriter | BlobReader:
...
@overload
def open_file(self, path: str | Path, *, binary: Literal[True]) -> BlobReader:
...
@overload
def open_file(
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
) -> BlobReader | BlobWriter | TextIOWrapper:
...
# https://github.com/microsoft/pyright/issues/8007
def open_file( # pyright: ignore[reportIncompatibleMethodOverride]
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
) -> BlobReader | BlobWriter | TextIOWrapper:
"""Open a file in the storage."""
blob = self._get_blob(path)
blob.reload() # pin revision number to prevent version mixing while reading
return blob.open(f"{mode}b" if binary else mode)
@overload
def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
"""Read a file in the storage as text."""
...
@overload
def read_file(self, path: str | Path, binary: Literal[True]) -> bytes:
"""Read a file in the storage as binary."""
...
@overload
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
"""Read a file in the storage."""
...
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
"""Read a file in the storage."""
return self.open_file(path, "r", binary).read()
async def write_file(self, path: str | Path, content: str | bytes) -> None:
"""Write to a file in the storage."""
blob = self._get_blob(path)
blob.upload_from_string(
data=content,
content_type=(
"text/plain"
if type(content) is str
# TODO: get MIME type from file extension or binary content
else "application/octet-stream"
),
)
if self.on_write_file:
path = Path(path)
if path.is_absolute():
path = path.relative_to(self.root)
res = self.on_write_file(path)
if inspect.isawaitable(res):
await res
def list_files(self, path: str | Path = ".") -> list[Path]:
"""List all files (recursively) in a directory in the storage."""
path = self.get_path(path)
return [
Path(blob.name).relative_to(path)
for blob in self._bucket.list_blobs(
prefix=f"{path}/" if path != Path(".") else None
)
]
def list_folders(
self, path: str | Path = ".", recursive: bool = False
) -> list[Path]:
"""List 'directories' directly in a given path or recursively in the storage."""
path = self.get_path(path)
folder_names = set()
# List objects with the specified prefix and delimiter
for blob in self._bucket.list_blobs(prefix=str(path)):
# Remove path prefix and the object name (last part)
folder = Path(blob.name).relative_to(path).parent
if not folder or folder == Path("."):
continue
# For non-recursive, only add the first level of folders
if not recursive:
folder_names.add(folder.parts[0])
else:
# For recursive, need to add all nested folders
for i in range(len(folder.parts)):
folder_names.add("/".join(folder.parts[: i + 1]))
return [Path(f) for f in folder_names]
def delete_file(self, path: str | Path) -> None:
"""Delete a file in the storage."""
path = self.get_path(path)
blob = self._bucket.blob(str(path))
blob.delete()
def delete_dir(self, path: str | Path) -> None:
"""Delete an empty folder in the storage."""
# Since GCS does not have directories, we don't need to do anything
pass
def exists(self, path: str | Path) -> bool:
"""Check if a file or folder exists in GCS storage."""
path = self.get_path(path)
# Check for exact blob match (file)
blob = self._bucket.blob(str(path))
if blob.exists():
return True
# Check for any blobs with prefix (folder)
prefix = f"{str(path).rstrip('/')}/"
blobs = self._bucket.list_blobs(prefix=prefix, max_results=1)
return next(blobs, None) is not None
def make_dir(self, path: str | Path) -> None:
"""Create a directory in the storage if it doesn't exist."""
# GCS does not have directories, so we don't need to do anything
pass
def rename(self, old_path: str | Path, new_path: str | Path) -> None:
"""Rename a file or folder in the storage."""
old_path = self.get_path(old_path)
new_path = self.get_path(new_path)
blob = self._bucket.blob(str(old_path))
# If the blob with exact name exists, rename it
if blob.exists():
self._bucket.rename_blob(blob, new_name=str(new_path))
return
# Otherwise, rename all blobs with the prefix (folder)
for blob in self._bucket.list_blobs(prefix=f"{old_path}/"):
new_name = str(blob.name).replace(str(old_path), str(new_path), 1)
self._bucket.rename_blob(blob, new_name=new_name)
def copy(self, source: str | Path, destination: str | Path) -> None:
"""Copy a file or folder with all contents in the storage."""
source = self.get_path(source)
destination = self.get_path(destination)
# If the source is a file, copy it
if self._bucket.blob(str(source)).exists():
self._bucket.copy_blob(
self._bucket.blob(str(source)), self._bucket, str(destination)
)
return
# Otherwise, copy all blobs with the prefix (folder)
for blob in self._bucket.list_blobs(prefix=f"{source}/"):
new_name = str(blob.name).replace(str(source), str(destination), 1)
self._bucket.copy_blob(blob, self._bucket, new_name)
def clone_with_subroot(self, subroot: str | Path) -> GCSFileStorage:
"""Create a new GCSFileStorage with a subroot of the current storage."""
file_storage = GCSFileStorage(
GCSFileStorageConfiguration(
root=Path("/").joinpath(self.get_path(subroot)),
bucket=self._bucket_name,
)
)
file_storage._gcs = self._gcs
file_storage._bucket = self._bucket
return file_storage
def __repr__(self) -> str:
return f"{__class__.__name__}(bucket='{self._bucket_name}', root={self._root})"

View File

@@ -0,0 +1,188 @@
"""
The LocalFileStorage class implements a FileStorage that works with local files.
"""
from __future__ import annotations
import inspect
import logging
from contextlib import contextmanager
from pathlib import Path
from typing import Any, BinaryIO, Generator, Literal, TextIO, overload
from .base import FileStorage, FileStorageConfiguration
logger = logging.getLogger(__name__)
class LocalFileStorage(FileStorage):
"""A class that represents a local file storage."""
def __init__(self, config: FileStorageConfiguration):
self._root = config.root.resolve()
self._restrict_to_root = config.restrict_to_root
self.make_dir(self.root)
super().__init__()
@property
def root(self) -> Path:
"""The root directory of the file storage."""
return self._root
@property
def restrict_to_root(self) -> bool:
"""Whether to restrict generated paths to the root."""
return self._restrict_to_root
@property
def is_local(self) -> bool:
"""Whether the storage is local (i.e. on the same machine, not cloud-based)."""
return True
def initialize(self) -> None:
self.root.mkdir(exist_ok=True, parents=True)
@overload
def open_file(
self,
path: str | Path,
mode: Literal["w", "r"] = "r",
binary: Literal[False] = False,
) -> TextIO:
...
@overload
def open_file(
self, path: str | Path, mode: Literal["w", "r"], binary: Literal[True]
) -> BinaryIO:
...
@overload
def open_file(self, path: str | Path, *, binary: Literal[True]) -> BinaryIO:
...
@overload
def open_file(
self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
) -> TextIO | BinaryIO:
...
def open_file(
self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
) -> TextIO | BinaryIO:
"""Open a file in the storage."""
return self._open_file(path, f"{mode}b" if binary else mode)
def _open_file(self, path: str | Path, mode: str) -> TextIO | BinaryIO:
full_path = self.get_path(path)
if any(m in mode for m in ("w", "a", "x")):
full_path.parent.mkdir(parents=True, exist_ok=True)
return open(full_path, mode) # type: ignore
@overload
def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
"""Read a file in the storage as text."""
...
@overload
def read_file(self, path: str | Path, binary: Literal[True]) -> bytes:
"""Read a file in the storage as binary."""
...
@overload
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
"""Read a file in the storage."""
...
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
"""Read a file in the storage."""
with self._open_file(path, "rb" if binary else "r") as file:
return file.read()
async def write_file(self, path: str | Path, content: str | bytes) -> None:
"""Write to a file in the storage."""
with self._open_file(path, "wb" if type(content) is bytes else "w") as file:
file.write(content) # type: ignore
if self.on_write_file:
path = Path(path)
if path.is_absolute():
path = path.relative_to(self.root)
res = self.on_write_file(path)
if inspect.isawaitable(res):
await res
def list_files(self, path: str | Path = ".") -> list[Path]:
"""List all files (recursively) in a directory in the storage."""
path = self.get_path(path)
return [file.relative_to(path) for file in path.rglob("*") if file.is_file()]
def list_folders(
self, path: str | Path = ".", recursive: bool = False
) -> list[Path]:
"""List directories directly in a given path or recursively."""
path = self.get_path(path)
if recursive:
return [
folder.relative_to(path)
for folder in path.rglob("*")
if folder.is_dir()
]
else:
return [
folder.relative_to(path) for folder in path.iterdir() if folder.is_dir()
]
def delete_file(self, path: str | Path) -> None:
"""Delete a file in the storage."""
full_path = self.get_path(path)
full_path.unlink()
def delete_dir(self, path: str | Path) -> None:
"""Delete an empty folder in the storage."""
full_path = self.get_path(path)
full_path.rmdir()
def exists(self, path: str | Path) -> bool:
"""Check if a file or folder exists in the storage."""
return self.get_path(path).exists()
def make_dir(self, path: str | Path) -> None:
"""Create a directory in the storage if it doesn't exist."""
full_path = self.get_path(path)
full_path.mkdir(exist_ok=True, parents=True)
def rename(self, old_path: str | Path, new_path: str | Path) -> None:
"""Rename a file or folder in the storage."""
old_path = self.get_path(old_path)
new_path = self.get_path(new_path)
old_path.rename(new_path)
def copy(self, source: str | Path, destination: str | Path) -> None:
"""Copy a file or folder with all contents in the storage."""
source = self.get_path(source)
destination = self.get_path(destination)
if source.is_file():
destination.write_bytes(source.read_bytes())
else:
destination.mkdir(exist_ok=True, parents=True)
for file in source.rglob("*"):
if file.is_file():
target = destination / file.relative_to(source)
target.parent.mkdir(exist_ok=True, parents=True)
target.write_bytes(file.read_bytes())
def clone_with_subroot(self, subroot: str | Path) -> FileStorage:
"""Create a new LocalFileStorage with a subroot of the current storage."""
return LocalFileStorage(
FileStorageConfiguration(
root=self.get_path(subroot),
restrict_to_root=self.restrict_to_root,
)
)
@contextmanager
def mount(self, path: str | Path = ".") -> Generator[Path, Any, None]:
"""Mount the file storage and provide a local path."""
# No need to do anything for local storage
yield Path(self.get_path(".")).absolute()
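
# Usage sketch: exercising the local backend end to end. The temporary
# workspace directory below is created just for this example.
import asyncio
import tempfile

workspace_root = Path(tempfile.mkdtemp()) / "agent_workspace"
local_storage = LocalFileStorage(
    FileStorageConfiguration(root=workspace_root, restrict_to_root=True)
)
asyncio.run(local_storage.write_file("plans/step1.md", "# Step 1"))
print(local_storage.read_file("plans/step1.md"))  # -> "# Step 1"
print(local_storage.list_folders())  # -> [PosixPath('plans')]
plans_storage = local_storage.clone_with_subroot("plans")
print(plans_storage.list_files())  # -> [PosixPath('step1.md')]
try:
    local_storage.get_path("../outside.txt")  # path escapes the root
except ValueError as err:
    print(err)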

View File

@@ -0,0 +1,353 @@
"""
The S3FileStorage class provides an interface for interacting with a file storage, and
stores the files in an S3 bucket.
"""
from __future__ import annotations
import contextlib
import inspect
import logging
from io import TextIOWrapper
from pathlib import Path
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, overload
import boto3
import botocore.exceptions
from pydantic import SecretStr
from forge.models.config import UserConfigurable
from .base import FileStorage, FileStorageConfiguration
if TYPE_CHECKING:
import mypy_boto3_s3
from botocore.response import StreamingBody
logger = logging.getLogger(__name__)
class S3FileStorageConfiguration(FileStorageConfiguration):
bucket: str = UserConfigurable("autogpt", from_env="STORAGE_BUCKET")
s3_endpoint_url: Optional[SecretStr] = UserConfigurable(from_env="S3_ENDPOINT_URL")
class S3FileStorage(FileStorage):
"""A class that represents an S3 storage."""
_bucket: mypy_boto3_s3.service_resource.Bucket
def __init__(self, config: S3FileStorageConfiguration):
self._bucket_name = config.bucket
self._root = config.root
# Add / at the beginning of the root path
if not self._root.is_absolute():
self._root = Path("/").joinpath(self._root)
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
self._s3 = boto3.resource(
"s3",
endpoint_url=(
config.s3_endpoint_url.get_secret_value()
if config.s3_endpoint_url
else None
),
)
super().__init__()
@property
def root(self) -> Path:
"""The root directory of the file storage."""
return self._root
@property
def restrict_to_root(self) -> bool:
"""Whether to restrict generated paths to the root."""
return True
@property
def is_local(self) -> bool:
"""Whether the storage is local (i.e. on the same machine, not cloud-based)."""
return False
def initialize(self) -> None:
logger.debug(f"Initializing {repr(self)}...")
try:
self._s3.meta.client.head_bucket(Bucket=self._bucket_name)
self._bucket = self._s3.Bucket(self._bucket_name)
except botocore.exceptions.ClientError as e:
if "(404)" not in str(e):
raise
logger.info(f"Bucket '{self._bucket_name}' does not exist; creating it...")
self._bucket = self._s3.create_bucket(Bucket=self._bucket_name)
def get_path(self, relative_path: str | Path) -> Path:
# We set S3 root with "/" at the beginning
# but relative_to("/") will remove it
# because we don't actually want it in the storage filenames
return super().get_path(relative_path).relative_to("/")
def _get_obj(self, path: str | Path) -> mypy_boto3_s3.service_resource.Object:
"""Get an S3 object."""
obj = self._bucket.Object(str(path))
with contextlib.suppress(botocore.exceptions.ClientError):
obj.load()
return obj
@overload
def open_file(
self,
path: str | Path,
mode: Literal["r", "w"] = "r",
binary: Literal[False] = False,
) -> TextIOWrapper:
...
@overload
def open_file(
self, path: str | Path, mode: Literal["r", "w"], binary: Literal[True]
) -> S3BinaryIOWrapper:
...
@overload
def open_file(
self, path: str | Path, *, binary: Literal[True]
) -> S3BinaryIOWrapper:
...
@overload
def open_file(
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
) -> S3BinaryIOWrapper | TextIOWrapper:
...
def open_file(
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
) -> TextIOWrapper | S3BinaryIOWrapper:
"""Open a file in the storage."""
path = self.get_path(path)
body = S3BinaryIOWrapper(self._get_obj(path).get()["Body"], str(path))
return body if binary else TextIOWrapper(body)
@overload
def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
"""Read a file in the storage as text."""
...
@overload
def read_file(self, path: str | Path, binary: Literal[True]) -> bytes:
"""Read a file in the storage as binary."""
...
@overload
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
"""Read a file in the storage."""
...
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
"""Read a file in the storage."""
return self.open_file(path, binary=binary).read()
async def write_file(self, path: str | Path, content: str | bytes) -> None:
"""Write to a file in the storage."""
obj = self._get_obj(self.get_path(path))
obj.put(Body=content)
if self.on_write_file:
path = Path(path)
if path.is_absolute():
path = path.relative_to(self.root)
res = self.on_write_file(path)
if inspect.isawaitable(res):
await res
def list_files(self, path: str | Path = ".") -> list[Path]:
"""List all files (recursively) in a directory in the storage."""
path = self.get_path(path)
if path == Path("."): # root level of bucket
return [Path(obj.key) for obj in self._bucket.objects.all()]
else:
return [
Path(obj.key).relative_to(path)
for obj in self._bucket.objects.filter(Prefix=f"{path}/")
]
def list_folders(
self, path: str | Path = ".", recursive: bool = False
) -> list[Path]:
"""List 'directories' directly in a given path or recursively in the storage."""
path = self.get_path(path)
folder_names = set()
# List objects with the specified prefix and delimiter
for obj_summary in self._bucket.objects.filter(Prefix=str(path)):
# Remove path prefix and the object name (last part)
folder = Path(obj_summary.key).relative_to(path).parent
if not folder or folder == Path("."):
continue
# For non-recursive, only add the first level of folders
if not recursive:
folder_names.add(folder.parts[0])
else:
# For recursive, need to add all nested folders
for i in range(len(folder.parts)):
folder_names.add("/".join(folder.parts[: i + 1]))
return [Path(f) for f in folder_names]
def delete_file(self, path: str | Path) -> None:
"""Delete a file in the storage."""
path = self.get_path(path)
obj = self._s3.Object(self._bucket_name, str(path))
obj.delete()
def delete_dir(self, path: str | Path) -> None:
"""Delete an empty folder in the storage."""
# S3 does not have directories, so we don't need to do anything
pass
def exists(self, path: str | Path) -> bool:
"""Check if a file or folder exists in S3 storage."""
path = self.get_path(path)
try:
# Check for exact object match (file)
self._s3.meta.client.head_object(Bucket=self._bucket_name, Key=str(path))
return True
except botocore.exceptions.ClientError as e:
if e.response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 404:
# If the object does not exist,
# check for objects with the prefix (folder)
prefix = f"{str(path).rstrip('/')}/"
objs = list(self._bucket.objects.filter(Prefix=prefix, MaxKeys=1))
return len(objs) > 0 # True if any objects exist with the prefix
else:
raise # Re-raise for any other client errors
def make_dir(self, path: str | Path) -> None:
"""Create a directory in the storage if it doesn't exist."""
# S3 does not have directories, so we don't need to do anything
pass
def rename(self, old_path: str | Path, new_path: str | Path) -> None:
"""Rename a file or folder in the storage."""
old_path = str(self.get_path(old_path))
new_path = str(self.get_path(new_path))
try:
# If file exists, rename it
self._s3.meta.client.head_object(Bucket=self._bucket_name, Key=old_path)
self._s3.meta.client.copy_object(
CopySource={"Bucket": self._bucket_name, "Key": old_path},
Bucket=self._bucket_name,
Key=new_path,
)
self._s3.meta.client.delete_object(Bucket=self._bucket_name, Key=old_path)
except botocore.exceptions.ClientError as e:
if e.response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 404:
# If the object does not exist,
# it may be a folder
prefix = f"{old_path.rstrip('/')}/"
objs = list(self._bucket.objects.filter(Prefix=prefix))
for obj in objs:
new_key = new_path + obj.key[len(old_path) :]
self._s3.meta.client.copy_object(
CopySource={"Bucket": self._bucket_name, "Key": obj.key},
Bucket=self._bucket_name,
Key=new_key,
)
self._s3.meta.client.delete_object(
Bucket=self._bucket_name, Key=obj.key
)
else:
raise # Re-raise for any other client errors
def copy(self, source: str | Path, destination: str | Path) -> None:
"""Copy a file or folder with all contents in the storage."""
source = str(self.get_path(source))
destination = str(self.get_path(destination))
try:
# If source is a file, copy it
self._s3.meta.client.head_object(Bucket=self._bucket_name, Key=source)
self._s3.meta.client.copy_object(
CopySource={"Bucket": self._bucket_name, "Key": source},
Bucket=self._bucket_name,
Key=destination,
)
except botocore.exceptions.ClientError as e:
if e.response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 404:
# If the object does not exist,
# it may be a folder
prefix = f"{source.rstrip('/')}/"
objs = list(self._bucket.objects.filter(Prefix=prefix))
for obj in objs:
new_key = destination + obj.key[len(source) :]
self._s3.meta.client.copy_object(
CopySource={"Bucket": self._bucket_name, "Key": obj.key},
Bucket=self._bucket_name,
Key=new_key,
)
else:
raise
def clone_with_subroot(self, subroot: str | Path) -> S3FileStorage:
"""Create a new S3FileStorage with a subroot of the current storage."""
file_storage = S3FileStorage(
S3FileStorageConfiguration(
bucket=self._bucket_name,
root=Path("/").joinpath(self.get_path(subroot)),
s3_endpoint_url=SecretStr(self._s3.meta.client.meta.endpoint_url),
)
)
file_storage._s3 = self._s3
file_storage._bucket = self._bucket
return file_storage
def __repr__(self) -> str:
return f"{__class__.__name__}(bucket='{self._bucket_name}', root={self._root})"
class S3BinaryIOWrapper(BinaryIO):
def __init__(self, body: StreamingBody, name: str):
self.body = body
self._name = name
@property
def name(self) -> str:
return self._name
def read(self, size: int = -1) -> bytes:
return self.body.read(size if size > 0 else None)
def readinto(self, b: bytearray) -> int:
data = self.read(len(b))
b[: len(data)] = data
return len(data)
def close(self) -> None:
self.body.close()
def fileno(self) -> int:
return self.body.fileno()
def flush(self) -> None:
self.body.flush()
def isatty(self) -> bool:
return self.body.isatty()
def readable(self) -> bool:
return self.body.readable()
def seekable(self) -> bool:
return self.body.seekable()
def writable(self) -> bool:
return False
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.body.close()
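
# Configuration sketch: pointing S3FileStorage at an S3-compatible endpoint
# such as a local MinIO instance. The endpoint URL, bucket name and root path
# are placeholders; boto3 picks up AWS-style credentials from the environment.
s3_config = S3FileStorageConfiguration(
    root=Path("/agents/agent-1"),
    bucket="my-autogpt-bucket",
    s3_endpoint_url=SecretStr("http://localhost:9000"),
)
s3_storage = S3FileStorage(s3_config)
s3_storage.initialize()  # creates the bucket if head_bucket reports a 404
print(s3_storage.exists("notes/hello.txt"))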

View File

View File

@@ -0,0 +1,93 @@
import logging
import re
from typing import Any
import demjson3
logger = logging.getLogger(__name__)
def json_loads(json_str: str) -> Any:
"""Parse a JSON string, tolerating minor syntax issues:
- Missing, extra and trailing commas
- Extraneous newlines and whitespace outside of string literals
- Inconsistent spacing after colons and commas
- Missing closing brackets or braces
- Numbers: binary, hex, octal, trailing and prefixed decimal points
- Different encodings
- Surrounding markdown code block
- Comments
Args:
json_str: The JSON string to parse.
Returns:
The parsed JSON object, same as built-in json.loads.
"""
# Remove possible code block
pattern = r"```(?:json|JSON)*([\s\S]*?)```"
match = re.search(pattern, json_str)
if match:
json_str = match.group(1).strip()
json_result = demjson3.decode(json_str, return_errors=True)
assert json_result is not None # by virtue of return_errors=True
if json_result.errors:
logger.debug(
"JSON parse errors:\n" + "\n".join(str(e) for e in json_result.errors)
)
if json_result.object in (demjson3.syntax_error, demjson3.undefined):
raise ValueError(
f"Failed to parse JSON string: {json_str}", *json_result.errors
)
return json_result.object
def extract_dict_from_json(json_str: str) -> dict[str, Any]:
# Sometimes the response includes the JSON in a code block with ```
pattern = r"```(?:json|JSON)*([\s\S]*?)```"
match = re.search(pattern, json_str)
if match:
json_str = match.group(1).strip()
else:
# The string may contain JSON.
json_pattern = r"{[\s\S]*}"
match = re.search(json_pattern, json_str)
if match:
json_str = match.group()
result = json_loads(json_str)
if not isinstance(result, dict):
raise ValueError(
f"Response '''{json_str}''' evaluated to non-dict value {repr(result)}"
)
return result
def extract_list_from_json(json_str: str) -> list[Any]:
# Sometimes the response includes the JSON in a code block with ```
pattern = r"```(?:json|JSON)*([\s\S]*?)```"
match = re.search(pattern, json_str)
if match:
json_str = match.group(1).strip()
else:
# The string may contain JSON.
json_pattern = r"\[[\s\S]*\]"
match = re.search(json_pattern, json_str)
if match:
json_str = match.group()
result = json_loads(json_str)
if not isinstance(result, list):
raise ValueError(
f"Response '''{json_str}''' evaluated to non-list value {repr(result)}"
)
return result
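
# Example: extract_dict_from_json handles the kind of imperfect output an LLM
# often produces, e.g. a fenced code block containing JSON with trailing commas.
messy_response = """```json
{
    "thoughts": "write the file",
    "ability": {"name": "write_file",},
}
```"""
print(extract_dict_from_json(messy_response))
# -> {'thoughts': 'write the file', 'ability': {'name': 'write_file'}}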

View File

View File

@@ -0,0 +1,8 @@
from .base import PromptStrategy
from .schema import ChatPrompt, LanguageModelClassification
__all__ = [
"LanguageModelClassification",
"ChatPrompt",
"PromptStrategy",
]

View File

@@ -0,0 +1,22 @@
import abc
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from forge.llm.providers import AssistantChatMessage
from .schema import ChatPrompt, LanguageModelClassification
class PromptStrategy(abc.ABC):
@property
@abc.abstractmethod
def model_classification(self) -> LanguageModelClassification:
...
@abc.abstractmethod
def build_prompt(self, *_, **kwargs) -> ChatPrompt:
...
@abc.abstractmethod
def parse_response_content(self, response: "AssistantChatMessage") -> Any:
...

View File

@@ -0,0 +1,9 @@
{% extends "techniques/expert.j2" %}
{% block expert %}Human Resources{% endblock %}
{% block prompt %}
Generate a profile for an expert who can help with the task '{{ task }}'. Please provide the following details:
Name: Enter the expert's name
Expertise: Specify the area in which the expert specializes
Goals: List 4 goals that the expert aims to achieve in order to help with the task
Assessment: Describe how the expert will assess whether they have successfully completed the task
{% endblock %}

View File

@@ -0,0 +1,17 @@
Reply only in json with the following format:
{
\"thoughts\": {
\"text\": \"thoughts\",
\"reasoning\": \"reasoning behind thoughts\",
\"plan\": \"- short bulleted\\n- list that conveys\\n- long-term plan\",
\"criticism\": \"constructive self-criticism\",
\"speak\": \"thoughts summary to say to user\",
},
\"ability\": {
\"name\": \"ability name\",
\"args\": {
\"arg1\": \"value1\", etc...
}
}
}

View File

@@ -0,0 +1,50 @@
{% extends "techniques/expert.j2" %}
{% block expert %}Planner{% endblock %}
{% block prompt %}
Your task is:
{{ task }}
Answer in the provided format.
Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and
pursue simple strategies with no legal complications.
{% if constraints %}
## Constraints
You operate within the following constraints:
{% for constraint in constraints %}
- {{ constraint }}
{% endfor %}
{% endif %}
{% if resources %}
## Resources
You can leverage access to the following resources:
{% for resource in resources %}
- {{ resource }}
{% endfor %}
{% endif %}
{% if abilities %}
## Abilities
You have access to the following abilities you can call:
{% for ability in abilities %}
- {{ ability }}
{% endfor %}
{% endif %}
{% if best_practices %}
## Best practices
{% for best_practice in best_practices %}
- {{ best_practice }}
{% endfor %}
{% endif %}
{% if previous_actions %}
## History of Abilities Used
{% for action in previous_actions %}
- {{ action }}
{% endfor %}
{% endif %}
{% endblock %}

View File

@@ -0,0 +1,35 @@
import enum
from pydantic import BaseModel, Field
from forge.llm.providers.schema import (
ChatMessage,
ChatMessageDict,
CompletionModelFunction,
)
class LanguageModelClassification(str, enum.Enum):
"""The LanguageModelClassification is a functional description of the model.
This is used to determine what kind of model to use for a given prompt.
Sometimes we prefer a faster or cheaper model to accomplish a task when
possible.
"""
FAST_MODEL = "fast_model"
SMART_MODEL = "smart_model"
class ChatPrompt(BaseModel):
messages: list[ChatMessage]
functions: list[CompletionModelFunction] = Field(default_factory=list)
prefill_response: str = ""
def raw(self) -> list[ChatMessageDict]:
return [m.dict() for m in self.messages] # type: ignore
def __str__(self):
return "\n\n".join(
f"{m.role.value.upper()}: {m.content}" for m in self.messages
)

View File

@@ -0,0 +1,2 @@
{% block prompt %} {% endblock %}
Let's work this out in a step by step way to be sure we have the right answer.

View File

@@ -0,0 +1 @@
Answer as an expert in {% block expert %} {% endblock %}. {% block prompt %}{% endblock %}

View File

@@ -0,0 +1,5 @@
{% block prompt %} {% endblock %}
Examples:
{% for example in examples %}
- {{ example }}
{% endfor %}

View File

@@ -0,0 +1,43 @@
from math import ceil, floor
from typing import Any
from forge.llm.prompting.schema import ChatPrompt
SEPARATOR_LENGTH = 42
def dump_prompt(prompt: ChatPrompt) -> str:
def separator(text: str):
half_sep_len = (SEPARATOR_LENGTH - 2 - len(text)) / 2
return f"{floor(half_sep_len)*'-'} {text.upper()} {ceil(half_sep_len)*'-'}"
formatted_messages = "\n".join(
[f"{separator(m.role)}\n{m.content}" for m in prompt.messages]
)
return f"""
============== {prompt.__class__.__name__} ==============
Length: {len(prompt.messages)} messages
{formatted_messages}
==========================================
"""
def format_numbered_list(items: list[Any], start_at: int = 1) -> str:
return "\n".join(f"{i}. {str(item)}" for i, item in enumerate(items, start_at))
def indent(content: str, indentation: int | str = 4) -> str:
if type(indentation) is int:
indentation = " " * indentation
return indentation + content.replace("\n", f"\n{indentation}") # type: ignore
def to_numbered_list(
items: list[str], no_items_response: str = "", **template_args
) -> str:
if items:
return "\n".join(
f"{i+1}. {item.format(**template_args)}" for i, item in enumerate(items)
)
else:
return no_items_response
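
# Example: the formatting helpers above in action.
steps = ["Read the task", "Write a plan", "Execute the plan"]
print(format_numbered_list(steps))
# 1. Read the task
# 2. Write a plan
# 3. Execute the plan
print(to_numbered_list([], no_items_response="(none)"))  # -> "(none)"
print(indent("first line\nsecond line", 2))
#   first line
#   second line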

View File

@@ -0,0 +1,73 @@
from .multi import (
CHAT_MODELS,
ChatModelProvider,
EmbeddingModelProvider,
ModelName,
MultiProvider,
)
from .openai import (
OPEN_AI_CHAT_MODELS,
OPEN_AI_EMBEDDING_MODELS,
OPEN_AI_MODELS,
OpenAIModelName,
OpenAIProvider,
OpenAISettings,
)
from .schema import (
AssistantChatMessage,
AssistantChatMessageDict,
AssistantFunctionCall,
AssistantFunctionCallDict,
ChatMessage,
ChatModelInfo,
ChatModelResponse,
CompletionModelFunction,
Embedding,
EmbeddingModelInfo,
EmbeddingModelResponse,
ModelInfo,
ModelProviderBudget,
ModelProviderCredentials,
ModelProviderName,
ModelProviderService,
ModelProviderSettings,
ModelProviderUsage,
ModelResponse,
ModelTokenizer,
)
from .utils import function_specs_from_commands
__all__ = [
"AssistantChatMessage",
"AssistantChatMessageDict",
"AssistantFunctionCall",
"AssistantFunctionCallDict",
"ChatMessage",
"ChatModelInfo",
"ChatModelResponse",
"CompletionModelFunction",
"CHAT_MODELS",
"Embedding",
"EmbeddingModelInfo",
"EmbeddingModelProvider",
"EmbeddingModelResponse",
"ModelInfo",
"ModelName",
"ChatModelProvider",
"ModelProviderBudget",
"ModelProviderCredentials",
"ModelProviderName",
"ModelProviderService",
"ModelProviderSettings",
"ModelProviderUsage",
"ModelResponse",
"ModelTokenizer",
"MultiProvider",
"OPEN_AI_MODELS",
"OPEN_AI_CHAT_MODELS",
"OPEN_AI_EMBEDDING_MODELS",
"OpenAIModelName",
"OpenAIProvider",
"OpenAISettings",
"function_specs_from_commands",
]

View File

@@ -0,0 +1,517 @@
import inspect
import logging
from typing import (
Any,
Awaitable,
Callable,
ClassVar,
Mapping,
Optional,
ParamSpec,
Sequence,
TypeVar,
cast,
)
import sentry_sdk
import tenacity
from openai._exceptions import APIConnectionError, APIStatusError
from openai.types import CreateEmbeddingResponse, EmbeddingCreateParams
from openai.types.chat import (
ChatCompletion,
ChatCompletionAssistantMessageParam,
ChatCompletionMessage,
ChatCompletionMessageParam,
CompletionCreateParams,
)
from openai.types.shared_params import FunctionDefinition
from forge.json.parsing import json_loads
from .schema import (
AssistantChatMessage,
AssistantFunctionCall,
AssistantToolCall,
BaseChatModelProvider,
BaseEmbeddingModelProvider,
BaseModelProvider,
ChatMessage,
ChatModelInfo,
ChatModelResponse,
CompletionModelFunction,
Embedding,
EmbeddingModelInfo,
EmbeddingModelResponse,
ModelProviderService,
_ModelName,
_ModelProviderSettings,
)
from .utils import validate_tool_calls
_T = TypeVar("_T")
_P = ParamSpec("_P")
class _BaseOpenAIProvider(BaseModelProvider[_ModelName, _ModelProviderSettings]):
"""Base class for LLM providers with OpenAI-like APIs"""
MODELS: ClassVar[
Mapping[_ModelName, ChatModelInfo[_ModelName] | EmbeddingModelInfo[_ModelName]] # type: ignore # noqa
]
def __init__(
self,
settings: Optional[_ModelProviderSettings] = None,
logger: Optional[logging.Logger] = None,
):
if not getattr(self, "MODELS", None):
raise ValueError(f"{self.__class__.__name__}.MODELS is not set")
if not settings:
settings = self.default_settings.copy(deep=True)
if not settings.credentials:
settings.credentials = self.default_settings.__fields__[
"credentials"
].type_.from_env()
super(_BaseOpenAIProvider, self).__init__(settings=settings, logger=logger)
if not getattr(self, "_client", None):
from openai import AsyncOpenAI
self._client = AsyncOpenAI(
**self._credentials.get_api_access_kwargs() # type: ignore
)
async def get_available_models(
self,
) -> Sequence[ChatModelInfo[_ModelName] | EmbeddingModelInfo[_ModelName]]:
_models = (await self._client.models.list()).data
return [
self.MODELS[cast(_ModelName, m.id)] for m in _models if m.id in self.MODELS
]
def get_token_limit(self, model_name: _ModelName) -> int:
"""Get the maximum number of input tokens for a given model"""
return self.MODELS[model_name].max_tokens
def count_tokens(self, text: str, model_name: _ModelName) -> int:
return len(self.get_tokenizer(model_name).encode(text))
def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
return tenacity.retry(
retry=(
tenacity.retry_if_exception_type(APIConnectionError)
| tenacity.retry_if_exception(
lambda e: isinstance(e, APIStatusError) and e.status_code >= 500
)
),
wait=tenacity.wait_exponential(),
stop=tenacity.stop_after_attempt(self._configuration.retries_per_request),
after=tenacity.after_log(self._logger, logging.DEBUG),
)(func)
def __repr__(self):
return f"{self.__class__.__name__}()"
class BaseOpenAIChatProvider(
_BaseOpenAIProvider[_ModelName, _ModelProviderSettings],
BaseChatModelProvider[_ModelName, _ModelProviderSettings],
):
CHAT_MODELS: ClassVar[dict[_ModelName, ChatModelInfo[_ModelName]]] # type: ignore
def __init__(
self,
settings: Optional[_ModelProviderSettings] = None,
logger: Optional[logging.Logger] = None,
):
if not getattr(self, "CHAT_MODELS", None):
raise ValueError(f"{self.__class__.__name__}.CHAT_MODELS is not set")
super(BaseOpenAIChatProvider, self).__init__(settings=settings, logger=logger)
async def get_available_chat_models(self) -> Sequence[ChatModelInfo[_ModelName]]:
all_available_models = await self.get_available_models()
return [
model
for model in all_available_models
if model.service == ModelProviderService.CHAT
]
def count_message_tokens(
self,
messages: ChatMessage | list[ChatMessage],
model_name: _ModelName,
) -> int:
if isinstance(messages, ChatMessage):
messages = [messages]
return self.count_tokens(
"\n\n".join(f"{m.role.upper()}: {m.content}" for m in messages), model_name
)
async def create_chat_completion(
self,
model_prompt: list[ChatMessage],
model_name: _ModelName,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
prefill_response: str = "",
**kwargs,
) -> ChatModelResponse[_T]:
"""Create a chat completion using the API."""
(
openai_messages,
completion_kwargs,
parse_kwargs,
) = self._get_chat_completion_args(
prompt_messages=model_prompt,
model=model_name,
functions=functions,
max_output_tokens=max_output_tokens,
**kwargs,
)
total_cost = 0.0
attempts = 0
while True:
completion_kwargs["messages"] = openai_messages
_response, _cost, t_input, t_output = await self._create_chat_completion(
model=model_name,
completion_kwargs=completion_kwargs,
)
total_cost += _cost
# If parsing the response fails, append the error to the prompt, and let the
# LLM fix its mistake(s).
attempts += 1
parse_errors: list[Exception] = []
_assistant_msg = _response.choices[0].message
tool_calls, _errors = self._parse_assistant_tool_calls(
_assistant_msg, **parse_kwargs
)
parse_errors += _errors
# Validate tool calls
if not parse_errors and tool_calls and functions:
parse_errors += validate_tool_calls(tool_calls, functions)
assistant_msg = AssistantChatMessage(
content=_assistant_msg.content or "",
tool_calls=tool_calls or None,
)
parsed_result: _T = None # type: ignore
if not parse_errors:
try:
parsed_result = completion_parser(assistant_msg)
if inspect.isawaitable(parsed_result):
parsed_result = await parsed_result
except Exception as e:
parse_errors.append(e)
if not parse_errors:
if attempts > 1:
self._logger.debug(
f"Total cost for {attempts} attempts: ${round(total_cost, 5)}"
)
return ChatModelResponse(
response=AssistantChatMessage(
content=_assistant_msg.content or "",
tool_calls=tool_calls or None,
),
parsed_result=parsed_result,
model_info=self.CHAT_MODELS[model_name],
prompt_tokens_used=t_input,
completion_tokens_used=t_output,
)
else:
self._logger.debug(
f"Parsing failed on response: '''{_assistant_msg}'''"
)
parse_errors_fmt = "\n\n".join(
f"{e.__class__.__name__}: {e}" for e in parse_errors
)
self._logger.warning(
f"Parsing attempt #{attempts} failed: {parse_errors_fmt}"
)
for e in parse_errors:
sentry_sdk.capture_exception(
error=e,
extras={"assistant_msg": _assistant_msg, "i_attempt": attempts},
)
if attempts < self._configuration.fix_failed_parse_tries:
openai_messages.append(
cast(
ChatCompletionAssistantMessageParam,
_assistant_msg.dict(exclude_none=True),
)
)
openai_messages.append(
{
"role": "system",
"content": (
f"ERROR PARSING YOUR RESPONSE:\n\n{parse_errors_fmt}"
),
}
)
continue
else:
raise parse_errors[0]
def _get_chat_completion_args(
self,
prompt_messages: list[ChatMessage],
model: _ModelName,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> tuple[
list[ChatCompletionMessageParam], CompletionCreateParams, dict[str, Any]
]:
"""Prepare keyword arguments for a chat completion API call
Args:
prompt_messages: List of ChatMessages
model: The model to use
functions (optional): List of functions available to the LLM
max_output_tokens (optional): Maximum number of tokens to generate
Returns:
list[ChatCompletionMessageParam]: Prompt messages for the API call
CompletionCreateParams: Mapping of other kwargs for the API call
Mapping[str, Any]: Any keyword arguments to pass on to the completion parser
"""
kwargs = cast(CompletionCreateParams, kwargs)
if max_output_tokens:
kwargs["max_tokens"] = max_output_tokens
if functions:
kwargs["tools"] = [ # pyright: ignore - it fails to infer the dict type
{"type": "function", "function": format_function_def_for_openai(f)}
for f in functions
]
if len(functions) == 1:
# force the model to call the only specified function
kwargs["tool_choice"] = { # pyright: ignore - type inference failure
"type": "function",
"function": {"name": functions[0].name},
}
if extra_headers := self._configuration.extra_request_headers:
# 'extra_headers' is not on CompletionCreateParams, but is on chat.create()
kwargs["extra_headers"] = kwargs.get("extra_headers", {}) # type: ignore
kwargs["extra_headers"].update(extra_headers.copy()) # type: ignore
prepped_messages: list[ChatCompletionMessageParam] = [
message.dict( # type: ignore
include={"role", "content", "tool_calls", "tool_call_id", "name"},
exclude_none=True,
)
for message in prompt_messages
]
if "messages" in kwargs:
prepped_messages += kwargs["messages"]
del kwargs["messages"] # type: ignore - messages are added back later
return prepped_messages, kwargs, {}
async def _create_chat_completion(
self,
model: _ModelName,
completion_kwargs: CompletionCreateParams,
) -> tuple[ChatCompletion, float, int, int]:
"""
Create a chat completion using an OpenAI-like API with retry handling
Params:
model: The model to use for the completion
completion_kwargs: All other arguments for the completion call
Returns:
ChatCompletion: The chat completion response object
float: The cost ($) of this completion
int: Number of prompt tokens used
int: Number of completion tokens used
"""
completion_kwargs["model"] = completion_kwargs.get("model") or model
@self._retry_api_request
async def _create_chat_completion_with_retry() -> ChatCompletion:
return await self._client.chat.completions.create(
**completion_kwargs, # type: ignore
)
completion = await _create_chat_completion_with_retry()
if completion.usage:
prompt_tokens_used = completion.usage.prompt_tokens
completion_tokens_used = completion.usage.completion_tokens
else:
prompt_tokens_used = completion_tokens_used = 0
if self._budget:
cost = self._budget.update_usage_and_cost(
model_info=self.CHAT_MODELS[model],
input_tokens_used=prompt_tokens_used,
output_tokens_used=completion_tokens_used,
)
else:
cost = 0
self._logger.debug(
f"{model} completion usage: {prompt_tokens_used} input, "
f"{completion_tokens_used} output - ${round(cost, 5)}"
)
return completion, cost, prompt_tokens_used, completion_tokens_used
def _parse_assistant_tool_calls(
self, assistant_message: ChatCompletionMessage, **kwargs
) -> tuple[list[AssistantToolCall], list[Exception]]:
tool_calls: list[AssistantToolCall] = []
parse_errors: list[Exception] = []
if assistant_message.tool_calls:
for _tc in assistant_message.tool_calls:
try:
parsed_arguments = json_loads(_tc.function.arguments)
except Exception as e:
err_message = (
f"Decoding arguments for {_tc.function.name} failed: "
+ str(e.args[0])
)
parse_errors.append(
type(e)(err_message, *e.args[1:]).with_traceback(
e.__traceback__
)
)
continue
tool_calls.append(
AssistantToolCall(
id=_tc.id,
type=_tc.type,
function=AssistantFunctionCall(
name=_tc.function.name,
arguments=parsed_arguments,
),
)
)
# If parsing of all tool calls succeeds in the end, we ignore any issues
if len(tool_calls) == len(assistant_message.tool_calls):
parse_errors = []
return tool_calls, parse_errors
class BaseOpenAIEmbeddingProvider(
_BaseOpenAIProvider[_ModelName, _ModelProviderSettings],
BaseEmbeddingModelProvider[_ModelName, _ModelProviderSettings],
):
EMBEDDING_MODELS: ClassVar[
dict[_ModelName, EmbeddingModelInfo[_ModelName]] # type: ignore
]
def __init__(
self,
settings: Optional[_ModelProviderSettings] = None,
logger: Optional[logging.Logger] = None,
):
if not getattr(self, "EMBEDDING_MODELS", None):
raise ValueError(f"{self.__class__.__name__}.EMBEDDING_MODELS is not set")
super(BaseOpenAIEmbeddingProvider, self).__init__(
settings=settings, logger=logger
)
async def get_available_embedding_models(
self,
) -> Sequence[EmbeddingModelInfo[_ModelName]]:
all_available_models = await self.get_available_models()
return [
model
for model in all_available_models
if model.service == ModelProviderService.EMBEDDING
]
async def create_embedding(
self,
text: str,
model_name: _ModelName,
embedding_parser: Callable[[Embedding], Embedding],
**kwargs,
) -> EmbeddingModelResponse:
"""Create an embedding using an OpenAI-like API"""
embedding_kwargs = self._get_embedding_kwargs(
input=text, model=model_name, **kwargs
)
response = await self._create_embedding(embedding_kwargs)
return EmbeddingModelResponse(
embedding=embedding_parser(response.data[0].embedding),
model_info=self.EMBEDDING_MODELS[model_name],
prompt_tokens_used=response.usage.prompt_tokens,
)
def _get_embedding_kwargs(
self, input: str | list[str], model: _ModelName, **kwargs
) -> EmbeddingCreateParams:
"""Get kwargs for an embedding API call
Params:
input: Text body or list of text bodies to create embedding(s) from
model: Embedding model to use
Returns:
The kwargs for the embedding API call
"""
kwargs = cast(EmbeddingCreateParams, kwargs)
kwargs["input"] = input
kwargs["model"] = model
if extra_headers := self._configuration.extra_request_headers:
# 'extra_headers' is not on CompletionCreateParams, but is on embedding.create() # noqa
kwargs["extra_headers"] = kwargs.get("extra_headers", {}) # type: ignore
kwargs["extra_headers"].update(extra_headers.copy()) # type: ignore
return kwargs
def _create_embedding(
self, embedding_kwargs: EmbeddingCreateParams
) -> Awaitable[CreateEmbeddingResponse]:
"""Create an embedding using an OpenAI-like API with retry handling."""
@self._retry_api_request
async def _create_embedding_with_retry() -> CreateEmbeddingResponse:
return await self._client.embeddings.create(**embedding_kwargs)
return _create_embedding_with_retry()
def format_function_def_for_openai(self: CompletionModelFunction) -> FunctionDefinition:
"""Returns an OpenAI-consumable function definition"""
return {
"name": self.name,
"description": self.description,
"parameters": {
"type": "object",
"properties": {
name: param.to_dict() for name, param in self.parameters.items()
},
"required": [
name for name, param in self.parameters.items() if param.required
],
},
}
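
# Illustration: for a hypothetical CompletionModelFunction named "read_file"
# with a single required "path" parameter, the helper above would return a
# mapping shaped roughly like this (the exact "properties" entries depend on
# each parameter's to_dict() output):
example_function_def: FunctionDefinition = {
    "name": "read_file",
    "description": "Read a file from the agent workspace",
    "parameters": {
        "type": "object",
        "properties": {
            "path": {"type": "string", "description": "Path of the file to read"},
        },
        "required": ["path"],
    },
}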

View File

@@ -0,0 +1,488 @@
from __future__ import annotations
import enum
import logging
from typing import TYPE_CHECKING, Any, Callable, Optional, ParamSpec, Sequence, TypeVar
import sentry_sdk
import tenacity
import tiktoken
from anthropic import APIConnectionError, APIStatusError
from pydantic import SecretStr
from forge.models.config import UserConfigurable
from .schema import (
AssistantChatMessage,
AssistantFunctionCall,
AssistantToolCall,
BaseChatModelProvider,
ChatMessage,
ChatModelInfo,
ChatModelResponse,
CompletionModelFunction,
ModelProviderBudget,
ModelProviderConfiguration,
ModelProviderCredentials,
ModelProviderName,
ModelProviderSettings,
ModelTokenizer,
ToolResultMessage,
)
from .utils import validate_tool_calls
if TYPE_CHECKING:
from anthropic.types.beta.tools import MessageCreateParams
from anthropic.types.beta.tools import ToolsBetaMessage as Message
from anthropic.types.beta.tools import ToolsBetaMessageParam as MessageParam
_T = TypeVar("_T")
_P = ParamSpec("_P")
class AnthropicModelName(str, enum.Enum):
CLAUDE3_OPUS_v1 = "claude-3-opus-20240229"
CLAUDE3_SONNET_v1 = "claude-3-sonnet-20240229"
CLAUDE3_HAIKU_v1 = "claude-3-haiku-20240307"
ANTHROPIC_CHAT_MODELS = {
info.name: info
for info in [
ChatModelInfo(
name=AnthropicModelName.CLAUDE3_OPUS_v1,
provider_name=ModelProviderName.ANTHROPIC,
prompt_token_cost=15 / 1e6,
completion_token_cost=75 / 1e6,
max_tokens=200000,
has_function_call_api=True,
),
ChatModelInfo(
name=AnthropicModelName.CLAUDE3_SONNET_v1,
provider_name=ModelProviderName.ANTHROPIC,
prompt_token_cost=3 / 1e6,
completion_token_cost=15 / 1e6,
max_tokens=200000,
has_function_call_api=True,
),
ChatModelInfo(
name=AnthropicModelName.CLAUDE3_HAIKU_v1,
provider_name=ModelProviderName.ANTHROPIC,
prompt_token_cost=0.25 / 1e6,
completion_token_cost=1.25 / 1e6,
max_tokens=200000,
has_function_call_api=True,
),
]
}
class AnthropicCredentials(ModelProviderCredentials):
"""Credentials for Anthropic."""
api_key: SecretStr = UserConfigurable(from_env="ANTHROPIC_API_KEY") # type: ignore
api_base: Optional[SecretStr] = UserConfigurable(
default=None, from_env="ANTHROPIC_API_BASE_URL"
)
def get_api_access_kwargs(self) -> dict[str, str]:
return {
k: v.get_secret_value()
for k, v in {
"api_key": self.api_key,
"base_url": self.api_base,
}.items()
if v is not None
}
class AnthropicSettings(ModelProviderSettings):
credentials: Optional[AnthropicCredentials] # type: ignore
budget: ModelProviderBudget # type: ignore
class AnthropicProvider(BaseChatModelProvider[AnthropicModelName, AnthropicSettings]):
default_settings = AnthropicSettings(
name="anthropic_provider",
description="Provides access to Anthropic's API.",
configuration=ModelProviderConfiguration(),
credentials=None,
budget=ModelProviderBudget(),
)
_settings: AnthropicSettings
_credentials: AnthropicCredentials
_budget: ModelProviderBudget
def __init__(
self,
settings: Optional[AnthropicSettings] = None,
logger: Optional[logging.Logger] = None,
):
if not settings:
settings = self.default_settings.copy(deep=True)
if not settings.credentials:
settings.credentials = AnthropicCredentials.from_env()
super(AnthropicProvider, self).__init__(settings=settings, logger=logger)
from anthropic import AsyncAnthropic
self._client = AsyncAnthropic(
**self._credentials.get_api_access_kwargs() # type: ignore
)
async def get_available_models(self) -> Sequence[ChatModelInfo[AnthropicModelName]]:
return await self.get_available_chat_models()
async def get_available_chat_models(
self,
) -> Sequence[ChatModelInfo[AnthropicModelName]]:
return list(ANTHROPIC_CHAT_MODELS.values())
def get_token_limit(self, model_name: AnthropicModelName) -> int:
"""Get the token limit for a given model."""
return ANTHROPIC_CHAT_MODELS[model_name].max_tokens
def get_tokenizer(self, model_name: AnthropicModelName) -> ModelTokenizer[Any]:
# HACK: No official tokenizer is available for Claude 3
return tiktoken.encoding_for_model(model_name)
def count_tokens(self, text: str, model_name: AnthropicModelName) -> int:
return 0 # HACK: No official tokenizer is available for Claude 3
def count_message_tokens(
self,
messages: ChatMessage | list[ChatMessage],
model_name: AnthropicModelName,
) -> int:
return 0 # HACK: No official tokenizer is available for Claude 3
async def create_chat_completion(
self,
model_prompt: list[ChatMessage],
model_name: AnthropicModelName,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
prefill_response: str = "",
**kwargs,
) -> ChatModelResponse[_T]:
"""Create a completion using the Anthropic API."""
anthropic_messages, completion_kwargs = self._get_chat_completion_args(
prompt_messages=model_prompt,
model=model_name,
functions=functions,
max_output_tokens=max_output_tokens,
**kwargs,
)
total_cost = 0.0
attempts = 0
while True:
completion_kwargs["messages"] = anthropic_messages.copy()
if prefill_response:
completion_kwargs["messages"].append(
{"role": "assistant", "content": prefill_response}
)
(
_assistant_msg,
cost,
t_input,
t_output,
) = await self._create_chat_completion(model_name, completion_kwargs)
total_cost += cost
self._logger.debug(
f"Completion usage: {t_input} input, {t_output} output "
f"- ${round(cost, 5)}"
)
# Merge prefill into generated response
if prefill_response:
first_text_block = next(
b for b in _assistant_msg.content if b.type == "text"
)
first_text_block.text = prefill_response + first_text_block.text
assistant_msg = AssistantChatMessage(
content="\n\n".join(
b.text for b in _assistant_msg.content if b.type == "text"
),
tool_calls=self._parse_assistant_tool_calls(_assistant_msg),
)
# If parsing the response fails, append the error to the prompt, and let the
# LLM fix its mistake(s).
attempts += 1
tool_call_errors = []
try:
# Validate tool calls
if assistant_msg.tool_calls and functions:
tool_call_errors = validate_tool_calls(
assistant_msg.tool_calls, functions
)
if tool_call_errors:
raise ValueError(
"Invalid tool use(s):\n"
+ "\n".join(str(e) for e in tool_call_errors)
)
parsed_result = completion_parser(assistant_msg)
break
except Exception as e:
self._logger.debug(
f"Parsing failed on response: '''{_assistant_msg}'''"
)
self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
sentry_sdk.capture_exception(
error=e,
extras={"assistant_msg": _assistant_msg, "i_attempt": attempts},
)
if attempts < self._configuration.fix_failed_parse_tries:
anthropic_messages.append(
_assistant_msg.dict(include={"role", "content"}) # type: ignore
)
anthropic_messages.append(
{
"role": "user",
"content": [
*(
# tool_result is required if last assistant message
# had tool_use block(s)
{
"type": "tool_result",
"tool_use_id": tc.id,
"is_error": True,
"content": [
{
"type": "text",
"text": "Not executed because parsing "
"of your last message failed"
if not tool_call_errors
else str(e)
if (
e := next(
(
tce
for tce in tool_call_errors
if tce.name
== tc.function.name
),
None,
)
)
else "Not executed because validation "
"of tool input failed",
}
],
}
for tc in assistant_msg.tool_calls or []
),
{
"type": "text",
"text": (
"ERROR PARSING YOUR RESPONSE:\n\n"
f"{e.__class__.__name__}: {e}"
),
},
],
}
)
else:
raise
if attempts > 1:
self._logger.debug(
f"Total cost for {attempts} attempts: ${round(total_cost, 5)}"
)
return ChatModelResponse(
response=assistant_msg,
parsed_result=parsed_result,
model_info=ANTHROPIC_CHAT_MODELS[model_name],
prompt_tokens_used=t_input,
completion_tokens_used=t_output,
)
def _get_chat_completion_args(
self,
prompt_messages: list[ChatMessage],
model: AnthropicModelName,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> tuple[list[MessageParam], MessageCreateParams]:
"""Prepare arguments for message completion API call.
Args:
prompt_messages: List of ChatMessages.
model: The model to use for the completion.
functions: Optional list of functions available to the LLM.
kwargs: Additional keyword arguments.
Returns:
list[MessageParam]: Prompt messages for the Anthropic call
dict[str, Any]: Any other kwargs for the Anthropic call
"""
if functions:
kwargs["tools"] = [
{
"name": f.name,
"description": f.description,
"input_schema": {
"type": "object",
"properties": {
name: param.to_dict()
for name, param in f.parameters.items()
},
"required": [
name
for name, param in f.parameters.items()
if param.required
],
},
}
for f in functions
]
kwargs["max_tokens"] = max_output_tokens or 4096
if extra_headers := self._configuration.extra_request_headers:
kwargs["extra_headers"] = kwargs.get("extra_headers", {})
kwargs["extra_headers"].update(extra_headers.copy())
system_messages = [
m for m in prompt_messages if m.role == ChatMessage.Role.SYSTEM
]
if (_n := len(system_messages)) > 1:
self._logger.warning(
f"Prompt has {_n} system messages; Anthropic supports only 1. "
"They will be merged, and removed from the rest of the prompt."
)
kwargs["system"] = "\n\n".join(sm.content for sm in system_messages)
messages: list[MessageParam] = []
for message in prompt_messages:
if message.role == ChatMessage.Role.SYSTEM:
continue
elif message.role == ChatMessage.Role.USER:
# Merge subsequent user messages
if messages and (prev_msg := messages[-1])["role"] == "user":
if isinstance(prev_msg["content"], str):
prev_msg["content"] += f"\n\n{message.content}"
else:
assert isinstance(prev_msg["content"], list)
prev_msg["content"].append(
{"type": "text", "text": message.content}
)
else:
messages.append({"role": "user", "content": message.content})
# TODO: add support for image blocks
elif message.role == ChatMessage.Role.ASSISTANT:
if isinstance(message, AssistantChatMessage) and message.tool_calls:
messages.append(
{
"role": "assistant",
"content": [
*(
[{"type": "text", "text": message.content}]
if message.content
else []
),
*(
{
"type": "tool_use",
"id": tc.id,
"name": tc.function.name,
"input": tc.function.arguments,
}
for tc in message.tool_calls
),
],
}
)
elif message.content:
messages.append(
{
"role": "assistant",
"content": message.content,
}
)
elif isinstance(message, ToolResultMessage):
messages.append(
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": message.tool_call_id,
"content": [{"type": "text", "text": message.content}],
"is_error": message.is_error,
}
],
}
)
return messages, kwargs # type: ignore
async def _create_chat_completion(
self, model: AnthropicModelName, completion_kwargs: MessageCreateParams
) -> tuple[Message, float, int, int]:
"""
Create a chat completion using the Anthropic API with retry handling.
Params:
completion_kwargs: Keyword arguments for an Anthropic Messages API call
Returns:
Message: The message completion object
float: The cost ($) of this completion
int: Number of input tokens used
int: Number of output tokens used
"""
@self._retry_api_request
async def _create_chat_completion_with_retry() -> Message:
return await self._client.beta.tools.messages.create(
model=model, **completion_kwargs # type: ignore
)
response = await _create_chat_completion_with_retry()
cost = self._budget.update_usage_and_cost(
model_info=ANTHROPIC_CHAT_MODELS[model],
input_tokens_used=response.usage.input_tokens,
output_tokens_used=response.usage.output_tokens,
)
return response, cost, response.usage.input_tokens, response.usage.output_tokens
def _parse_assistant_tool_calls(
self, assistant_message: Message
) -> list[AssistantToolCall]:
return [
AssistantToolCall(
id=c.id,
type="function",
function=AssistantFunctionCall(
name=c.name,
arguments=c.input, # type: ignore
),
)
for c in assistant_message.content
if c.type == "tool_use"
]
def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
return tenacity.retry(
retry=(
tenacity.retry_if_exception_type(APIConnectionError)
| tenacity.retry_if_exception(
lambda e: isinstance(e, APIStatusError) and e.status_code >= 500
)
),
wait=tenacity.wait_exponential(),
stop=tenacity.stop_after_attempt(self._configuration.retries_per_request),
after=tenacity.after_log(self._logger, logging.DEBUG),
)(func)
def __repr__(self):
return "AnthropicProvider()"
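
# Quick look at the Claude 3 models registered above: token limits and pricing
# (costs are configured in dollars per token, shown here per million tokens).
for model_info in ANTHROPIC_CHAT_MODELS.values():
    print(
        f"{model_info.name.value}: max_tokens={model_info.max_tokens}, "
        f"${model_info.prompt_token_cost * 1e6:.2f}/1M prompt, "
        f"${model_info.completion_token_cost * 1e6:.2f}/1M completion"
    )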

View File

@@ -0,0 +1,126 @@
from __future__ import annotations
import enum
import logging
from typing import Any, Optional
import tiktoken
from pydantic import SecretStr
from forge.models.config import UserConfigurable
from ._openai_base import BaseOpenAIChatProvider
from .schema import (
ChatModelInfo,
ModelProviderBudget,
ModelProviderConfiguration,
ModelProviderCredentials,
ModelProviderName,
ModelProviderSettings,
ModelTokenizer,
)
class GroqModelName(str, enum.Enum):
LLAMA3_8B = "llama3-8b-8192"
LLAMA3_70B = "llama3-70b-8192"
MIXTRAL_8X7B = "mixtral-8x7b-32768"
GEMMA_7B = "gemma-7b-it"
GROQ_CHAT_MODELS = {
info.name: info
for info in [
ChatModelInfo(
name=GroqModelName.LLAMA3_8B,
provider_name=ModelProviderName.GROQ,
prompt_token_cost=0.05 / 1e6,
completion_token_cost=0.10 / 1e6,
max_tokens=8192,
has_function_call_api=True,
),
ChatModelInfo(
name=GroqModelName.LLAMA3_70B,
provider_name=ModelProviderName.GROQ,
prompt_token_cost=0.59 / 1e6,
completion_token_cost=0.79 / 1e6,
max_tokens=8192,
has_function_call_api=True,
),
ChatModelInfo(
name=GroqModelName.MIXTRAL_8X7B,
provider_name=ModelProviderName.GROQ,
prompt_token_cost=0.27 / 1e6,
completion_token_cost=0.27 / 1e6,
max_tokens=32768,
has_function_call_api=True,
),
ChatModelInfo(
name=GroqModelName.GEMMA_7B,
provider_name=ModelProviderName.GROQ,
prompt_token_cost=0.10 / 1e6,
completion_token_cost=0.10 / 1e6,
max_tokens=8192,
has_function_call_api=True,
),
]
}
class GroqCredentials(ModelProviderCredentials):
"""Credentials for Groq."""
api_key: SecretStr = UserConfigurable(from_env="GROQ_API_KEY") # type: ignore
api_base: Optional[SecretStr] = UserConfigurable(
default=None, from_env="GROQ_API_BASE_URL"
)
def get_api_access_kwargs(self) -> dict[str, str]:
return {
k: v.get_secret_value()
for k, v in {
"api_key": self.api_key,
"base_url": self.api_base,
}.items()
if v is not None
}
class GroqSettings(ModelProviderSettings):
credentials: Optional[GroqCredentials] # type: ignore
budget: ModelProviderBudget # type: ignore
class GroqProvider(BaseOpenAIChatProvider[GroqModelName, GroqSettings]):
CHAT_MODELS = GROQ_CHAT_MODELS
MODELS = CHAT_MODELS
default_settings = GroqSettings(
name="groq_provider",
description="Provides access to Groq's API.",
configuration=ModelProviderConfiguration(),
credentials=None,
budget=ModelProviderBudget(),
)
_settings: GroqSettings
_configuration: ModelProviderConfiguration
_credentials: GroqCredentials
_budget: ModelProviderBudget
def __init__(
self,
settings: Optional[GroqSettings] = None,
logger: Optional[logging.Logger] = None,
):
super(GroqProvider, self).__init__(settings=settings, logger=logger)
from groq import AsyncGroq
self._client = AsyncGroq(
**self._credentials.get_api_access_kwargs() # type: ignore
)
def get_tokenizer(self, model_name: GroqModelName) -> ModelTokenizer[Any]:
# HACK: No official tokenizer is available for Groq
return tiktoken.encoding_for_model("gpt-3.5-turbo")

View File

@@ -0,0 +1,165 @@
from __future__ import annotations
import logging
from typing import Any, Callable, Iterator, Optional, Sequence, TypeVar
from pydantic import ValidationError
from .anthropic import ANTHROPIC_CHAT_MODELS, AnthropicModelName, AnthropicProvider
from .groq import GROQ_CHAT_MODELS, GroqModelName, GroqProvider
from .openai import OPEN_AI_CHAT_MODELS, OpenAIModelName, OpenAIProvider
from .schema import (
AssistantChatMessage,
BaseChatModelProvider,
ChatMessage,
ChatModelInfo,
ChatModelResponse,
CompletionModelFunction,
ModelProviderBudget,
ModelProviderConfiguration,
ModelProviderName,
ModelProviderSettings,
ModelTokenizer,
)
_T = TypeVar("_T")
ModelName = AnthropicModelName | GroqModelName | OpenAIModelName
EmbeddingModelProvider = OpenAIProvider
CHAT_MODELS = {**ANTHROPIC_CHAT_MODELS, **GROQ_CHAT_MODELS, **OPEN_AI_CHAT_MODELS}
class MultiProvider(BaseChatModelProvider[ModelName, ModelProviderSettings]):
default_settings = ModelProviderSettings(
name="multi_provider",
description=(
"Provides access to all of the available models, regardless of provider."
),
configuration=ModelProviderConfiguration(
retries_per_request=7,
),
budget=ModelProviderBudget(),
)
_budget: ModelProviderBudget
_provider_instances: dict[ModelProviderName, ChatModelProvider]
def __init__(
self,
settings: Optional[ModelProviderSettings] = None,
logger: Optional[logging.Logger] = None,
):
super(MultiProvider, self).__init__(settings=settings, logger=logger)
self._budget = self._settings.budget or ModelProviderBudget()
self._provider_instances = {}
async def get_available_models(self) -> Sequence[ChatModelInfo[ModelName]]:
# TODO: support embeddings
return await self.get_available_chat_models()
async def get_available_chat_models(self) -> Sequence[ChatModelInfo[ModelName]]:
models = []
for provider in self.get_available_providers():
models.extend(await provider.get_available_chat_models())
return models
def get_token_limit(self, model_name: ModelName) -> int:
"""Get the token limit for a given model."""
return self.get_model_provider(model_name).get_token_limit(
model_name # type: ignore
)
def get_tokenizer(self, model_name: ModelName) -> ModelTokenizer[Any]:
return self.get_model_provider(model_name).get_tokenizer(
model_name # type: ignore
)
def count_tokens(self, text: str, model_name: ModelName) -> int:
return self.get_model_provider(model_name).count_tokens(
text=text, model_name=model_name # type: ignore
)
def count_message_tokens(
self, messages: ChatMessage | list[ChatMessage], model_name: ModelName
) -> int:
return self.get_model_provider(model_name).count_message_tokens(
messages=messages, model_name=model_name # type: ignore
)
async def create_chat_completion(
self,
model_prompt: list[ChatMessage],
model_name: ModelName,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
prefill_response: str = "",
**kwargs,
) -> ChatModelResponse[_T]:
"""Create a completion using the Anthropic API."""
return await self.get_model_provider(model_name).create_chat_completion(
model_prompt=model_prompt,
model_name=model_name, # type: ignore
completion_parser=completion_parser,
functions=functions,
max_output_tokens=max_output_tokens,
prefill_response=prefill_response,
**kwargs,
)
def get_model_provider(self, model: ModelName) -> ChatModelProvider:
model_info = CHAT_MODELS[model]
return self._get_provider(model_info.provider_name)
def get_available_providers(self) -> Iterator[ChatModelProvider]:
for provider_name in ModelProviderName:
try:
yield self._get_provider(provider_name)
except Exception:
pass
def _get_provider(self, provider_name: ModelProviderName) -> ChatModelProvider:
_provider = self._provider_instances.get(provider_name)
if not _provider:
Provider = self._get_provider_class(provider_name)
settings = Provider.default_settings.copy(deep=True)
settings.budget = self._budget
settings.configuration.extra_request_headers.update(
self._settings.configuration.extra_request_headers
)
if settings.credentials is None:
try:
Credentials = settings.__fields__["credentials"].type_
settings.credentials = Credentials.from_env()
except ValidationError as e:
raise ValueError(
f"{provider_name} is unavailable: can't load credentials"
) from e
self._provider_instances[provider_name] = _provider = Provider(
settings=settings, logger=self._logger # type: ignore
)
_provider._budget = self._budget # Object binding not preserved by Pydantic
return _provider
@classmethod
def _get_provider_class(
cls, provider_name: ModelProviderName
) -> type[AnthropicProvider | GroqProvider | OpenAIProvider]:
try:
return {
ModelProviderName.ANTHROPIC: AnthropicProvider,
ModelProviderName.GROQ: GroqProvider,
ModelProviderName.OPENAI: OpenAIProvider,
}[provider_name]
except KeyError:
raise ValueError(f"{provider_name} is not a known provider") from None
def __repr__(self):
return f"{self.__class__.__name__}()"
ChatModelProvider = AnthropicProvider | GroqProvider | OpenAIProvider | MultiProvider
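
A sketch of the name-based routing above, assuming the module is importable as `forge.llm.providers.multi` and that `OPENAI_API_KEY` is set so `_get_provider` can build credentials from the environment.

```python
from forge.llm.providers.multi import MultiProvider  # assumed module name
from forge.llm.providers.openai import OpenAIModelName

multi = MultiProvider()

# CHAT_MODELS maps each model name to its provider; the matching provider is
# instantiated lazily and cached on first use.
provider = multi.get_model_provider(OpenAIModelName.GPT4_O)
print(provider)                                       # OpenAIProvider()
print(multi.get_token_limit(OpenAIModelName.GPT4_O))  # 128000 per OPEN_AI_CHAT_MODELS
```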

View File

@@ -0,0 +1,629 @@
import enum
import inspect
import logging
import os
from pathlib import Path
from typing import Any, Callable, Iterator, Mapping, Optional, ParamSpec, TypeVar, cast
import tenacity
import tiktoken
import yaml
from openai._exceptions import APIStatusError, RateLimitError
from openai.types import EmbeddingCreateParams
from openai.types.chat import (
ChatCompletionMessage,
ChatCompletionMessageParam,
CompletionCreateParams,
)
from pydantic import SecretStr
from forge.json.parsing import json_loads
from forge.models.config import UserConfigurable
from forge.models.json_schema import JSONSchema
from ._openai_base import BaseOpenAIChatProvider, BaseOpenAIEmbeddingProvider
from .schema import (
AssistantToolCall,
AssistantToolCallDict,
ChatMessage,
ChatModelInfo,
CompletionModelFunction,
Embedding,
EmbeddingModelInfo,
ModelProviderBudget,
ModelProviderConfiguration,
ModelProviderCredentials,
ModelProviderName,
ModelProviderSettings,
ModelTokenizer,
)
_T = TypeVar("_T")
_P = ParamSpec("_P")
OpenAIEmbeddingParser = Callable[[Embedding], Embedding]
class OpenAIModelName(str, enum.Enum):
EMBEDDING_v2 = "text-embedding-ada-002"
EMBEDDING_v3_S = "text-embedding-3-small"
EMBEDDING_v3_L = "text-embedding-3-large"
GPT3_v1 = "gpt-3.5-turbo-0301"
GPT3_v2 = "gpt-3.5-turbo-0613"
GPT3_v2_16k = "gpt-3.5-turbo-16k-0613"
GPT3_v3 = "gpt-3.5-turbo-1106"
GPT3_v4 = "gpt-3.5-turbo-0125"
GPT3_ROLLING = "gpt-3.5-turbo"
GPT3_ROLLING_16k = "gpt-3.5-turbo-16k"
GPT3 = GPT3_ROLLING
GPT3_16k = GPT3_ROLLING_16k
GPT4_v1 = "gpt-4-0314"
GPT4_v1_32k = "gpt-4-32k-0314"
GPT4_v2 = "gpt-4-0613"
GPT4_v2_32k = "gpt-4-32k-0613"
GPT4_v3 = "gpt-4-1106-preview"
GPT4_v3_VISION = "gpt-4-1106-vision-preview"
GPT4_v4 = "gpt-4-0125-preview"
GPT4_v5 = "gpt-4-turbo-2024-04-09"
GPT4_ROLLING = "gpt-4"
GPT4_ROLLING_32k = "gpt-4-32k"
GPT4_TURBO = "gpt-4-turbo"
GPT4_TURBO_PREVIEW = "gpt-4-turbo-preview"
GPT4_VISION = "gpt-4-vision-preview"
GPT4_O_v1 = "gpt-4o-2024-05-13"
GPT4_O_ROLLING = "gpt-4o"
GPT4 = GPT4_ROLLING
GPT4_32k = GPT4_ROLLING_32k
GPT4_O = GPT4_O_ROLLING
OPEN_AI_EMBEDDING_MODELS = {
info.name: info
for info in [
EmbeddingModelInfo(
name=OpenAIModelName.EMBEDDING_v2,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.0001 / 1000,
max_tokens=8191,
embedding_dimensions=1536,
),
EmbeddingModelInfo(
name=OpenAIModelName.EMBEDDING_v3_S,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.00002 / 1000,
max_tokens=8191,
embedding_dimensions=1536,
),
EmbeddingModelInfo(
name=OpenAIModelName.EMBEDDING_v3_L,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.00013 / 1000,
max_tokens=8191,
embedding_dimensions=3072,
),
]
}
OPEN_AI_CHAT_MODELS = {
info.name: info
for info in [
ChatModelInfo(
name=OpenAIModelName.GPT3_v1,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.0015 / 1000,
completion_token_cost=0.002 / 1000,
max_tokens=4096,
has_function_call_api=True,
),
ChatModelInfo(
name=OpenAIModelName.GPT3_v2_16k,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.003 / 1000,
completion_token_cost=0.004 / 1000,
max_tokens=16384,
has_function_call_api=True,
),
ChatModelInfo(
name=OpenAIModelName.GPT3_v3,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.001 / 1000,
completion_token_cost=0.002 / 1000,
max_tokens=16384,
has_function_call_api=True,
),
ChatModelInfo(
name=OpenAIModelName.GPT3_v4,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.0005 / 1000,
completion_token_cost=0.0015 / 1000,
max_tokens=16384,
has_function_call_api=True,
),
ChatModelInfo(
name=OpenAIModelName.GPT4_v1,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.03 / 1000,
completion_token_cost=0.06 / 1000,
max_tokens=8191,
has_function_call_api=True,
),
ChatModelInfo(
name=OpenAIModelName.GPT4_v1_32k,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.06 / 1000,
completion_token_cost=0.12 / 1000,
max_tokens=32768,
has_function_call_api=True,
),
ChatModelInfo(
name=OpenAIModelName.GPT4_TURBO,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.01 / 1000,
completion_token_cost=0.03 / 1000,
max_tokens=128000,
has_function_call_api=True,
),
ChatModelInfo(
name=OpenAIModelName.GPT4_O,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=5 / 1_000_000,
completion_token_cost=15 / 1_000_000,
max_tokens=128_000,
has_function_call_api=True,
),
]
}
# Copy entries for models with equivalent specs
chat_model_mapping = {
OpenAIModelName.GPT3_v1: [OpenAIModelName.GPT3_v2],
OpenAIModelName.GPT3_v2_16k: [OpenAIModelName.GPT3_16k],
OpenAIModelName.GPT3_v4: [OpenAIModelName.GPT3_ROLLING],
OpenAIModelName.GPT4_v1: [OpenAIModelName.GPT4_v2, OpenAIModelName.GPT4_ROLLING],
OpenAIModelName.GPT4_v1_32k: [
OpenAIModelName.GPT4_v2_32k,
OpenAIModelName.GPT4_32k,
],
OpenAIModelName.GPT4_TURBO: [
OpenAIModelName.GPT4_v3,
OpenAIModelName.GPT4_v3_VISION,
OpenAIModelName.GPT4_VISION,
OpenAIModelName.GPT4_v4,
OpenAIModelName.GPT4_TURBO_PREVIEW,
OpenAIModelName.GPT4_v5,
],
OpenAIModelName.GPT4_O: [OpenAIModelName.GPT4_O_v1],
}
for base, copies in chat_model_mapping.items():
for copy in copies:
copy_info = OPEN_AI_CHAT_MODELS[base].copy(update={"name": copy})
OPEN_AI_CHAT_MODELS[copy] = copy_info
if copy.endswith(("-0301", "-0314")):
copy_info.has_function_call_api = False
OPEN_AI_MODELS: Mapping[
OpenAIModelName,
ChatModelInfo[OpenAIModelName] | EmbeddingModelInfo[OpenAIModelName],
] = {
**OPEN_AI_CHAT_MODELS,
**OPEN_AI_EMBEDDING_MODELS,
}
class OpenAICredentials(ModelProviderCredentials):
"""Credentials for OpenAI."""
api_key: SecretStr = UserConfigurable(from_env="OPENAI_API_KEY") # type: ignore
api_base: Optional[SecretStr] = UserConfigurable(
default=None, from_env="OPENAI_API_BASE_URL"
)
organization: Optional[SecretStr] = UserConfigurable(from_env="OPENAI_ORGANIZATION")
api_type: Optional[SecretStr] = UserConfigurable(
default=None,
from_env=lambda: cast(
SecretStr | None,
"azure"
if os.getenv("USE_AZURE") == "True"
else os.getenv("OPENAI_API_TYPE"),
),
)
api_version: Optional[SecretStr] = UserConfigurable(
default=None, from_env="OPENAI_API_VERSION"
)
azure_endpoint: Optional[SecretStr] = None
azure_model_to_deploy_id_map: Optional[dict[str, str]] = None
def get_api_access_kwargs(self) -> dict[str, str]:
kwargs = {
k: v.get_secret_value()
for k, v in {
"api_key": self.api_key,
"base_url": self.api_base,
"organization": self.organization,
"api_version": self.api_version,
}.items()
if v is not None
}
if self.api_type == SecretStr("azure"):
assert self.azure_endpoint, "Azure endpoint not configured"
kwargs["azure_endpoint"] = self.azure_endpoint.get_secret_value()
return kwargs
def get_model_access_kwargs(self, model: str) -> dict[str, str]:
kwargs = {"model": model}
if self.api_type == SecretStr("azure") and model:
azure_kwargs = self._get_azure_access_kwargs(model)
kwargs.update(azure_kwargs)
return kwargs
def load_azure_config(self, config_file: Path) -> None:
with open(config_file) as file:
config_params = yaml.load(file, Loader=yaml.SafeLoader) or {}
try:
assert config_params.get(
"azure_model_map", {}
), "Azure model->deployment_id map is empty"
except AssertionError as e:
raise ValueError(*e.args)
self.api_type = config_params.get("azure_api_type", "azure")
self.api_version = config_params.get("azure_api_version", None)
self.azure_endpoint = config_params.get("azure_endpoint")
self.azure_model_to_deploy_id_map = config_params.get("azure_model_map")
def _get_azure_access_kwargs(self, model: str) -> dict[str, str]:
"""Get the kwargs for the Azure API."""
if not self.azure_model_to_deploy_id_map:
raise ValueError("Azure model deployment map not configured")
if model not in self.azure_model_to_deploy_id_map:
raise ValueError(f"No Azure deployment ID configured for model '{model}'")
deployment_id = self.azure_model_to_deploy_id_map[model]
return {"model": deployment_id}
class OpenAISettings(ModelProviderSettings):
credentials: Optional[OpenAICredentials] # type: ignore
budget: ModelProviderBudget # type: ignore
class OpenAIProvider(
BaseOpenAIChatProvider[OpenAIModelName, OpenAISettings],
BaseOpenAIEmbeddingProvider[OpenAIModelName, OpenAISettings],
):
MODELS = OPEN_AI_MODELS
CHAT_MODELS = OPEN_AI_CHAT_MODELS
EMBEDDING_MODELS = OPEN_AI_EMBEDDING_MODELS
default_settings = OpenAISettings(
name="openai_provider",
description="Provides access to OpenAI's API.",
configuration=ModelProviderConfiguration(),
credentials=None,
budget=ModelProviderBudget(),
)
_settings: OpenAISettings
_configuration: ModelProviderConfiguration
_credentials: OpenAICredentials
_budget: ModelProviderBudget
def __init__(
self,
settings: Optional[OpenAISettings] = None,
logger: Optional[logging.Logger] = None,
):
super(OpenAIProvider, self).__init__(settings=settings, logger=logger)
if self._credentials.api_type == SecretStr("azure"):
from openai import AsyncAzureOpenAI
# API key and org (if configured) are passed, the rest of the required
# credentials is loaded from the environment by the AzureOpenAI client.
self._client = AsyncAzureOpenAI(
**self._credentials.get_api_access_kwargs() # type: ignore
)
else:
from openai import AsyncOpenAI
self._client = AsyncOpenAI(
**self._credentials.get_api_access_kwargs() # type: ignore
)
def get_tokenizer(self, model_name: OpenAIModelName) -> ModelTokenizer[int]:
return tiktoken.encoding_for_model(model_name)
def count_message_tokens(
self,
messages: ChatMessage | list[ChatMessage],
model_name: OpenAIModelName,
) -> int:
if isinstance(messages, ChatMessage):
messages = [messages]
if model_name.startswith("gpt-3.5-turbo"):
tokens_per_message = (
4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
)
tokens_per_name = -1 # if there's a name, the role is omitted
# TODO: check if this is still valid for gpt-4o
elif model_name.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"count_message_tokens() is not implemented for model {model_name}.\n"
"See https://github.com/openai/openai-python/blob/120d225b91a8453e15240a49fb1c6794d8119326/chatml.md " # noqa
"for information on how messages are converted to tokens."
)
tokenizer = self.get_tokenizer(model_name)
num_tokens = 0
for message in messages:
num_tokens += tokens_per_message
for key, value in message.dict().items():
num_tokens += len(tokenizer.encode(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens
def _get_chat_completion_args(
self,
prompt_messages: list[ChatMessage],
model: OpenAIModelName,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> tuple[
list[ChatCompletionMessageParam], CompletionCreateParams, dict[str, Any]
]:
"""Prepare keyword arguments for an OpenAI chat completion call
Args:
prompt_messages: List of ChatMessages
model: The model to use
functions (optional): List of functions available to the LLM
max_output_tokens (optional): Maximum number of tokens to generate
Returns:
list[ChatCompletionMessageParam]: Prompt messages for the OpenAI call
CompletionCreateParams: Mapping of other kwargs for the OpenAI call
Mapping[str, Any]: Any keyword arguments to pass on to the completion parser
"""
tools_compat_mode = False
if functions:
if not OPEN_AI_CHAT_MODELS[model].has_function_call_api:
# Provide compatibility with older models
_functions_compat_fix_kwargs(functions, prompt_messages)
tools_compat_mode = True
functions = None
openai_messages, kwargs, parse_kwargs = super()._get_chat_completion_args(
prompt_messages=prompt_messages,
model=model,
functions=functions,
max_output_tokens=max_output_tokens,
**kwargs,
)
kwargs.update(self._credentials.get_model_access_kwargs(model)) # type: ignore
if tools_compat_mode:
parse_kwargs["compat_mode"] = True
return openai_messages, kwargs, parse_kwargs
def _parse_assistant_tool_calls(
self,
assistant_message: ChatCompletionMessage,
compat_mode: bool = False,
**kwargs,
) -> tuple[list[AssistantToolCall], list[Exception]]:
tool_calls: list[AssistantToolCall] = []
parse_errors: list[Exception] = []
if not compat_mode:
return super()._parse_assistant_tool_calls(
assistant_message=assistant_message, compat_mode=compat_mode, **kwargs
)
elif assistant_message.content:
try:
tool_calls = list(
_tool_calls_compat_extract_calls(assistant_message.content)
)
except Exception as e:
parse_errors.append(e)
return tool_calls, parse_errors
def _get_embedding_kwargs(
self, input: str | list[str], model: OpenAIModelName, **kwargs
) -> EmbeddingCreateParams:
kwargs = super()._get_embedding_kwargs(input=input, model=model, **kwargs)
kwargs.update(self._credentials.get_model_access_kwargs(model)) # type: ignore
return kwargs
_get_embedding_kwargs.__doc__ = (
BaseOpenAIEmbeddingProvider._get_embedding_kwargs.__doc__
)
def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
_log_retry_debug_message = tenacity.after_log(self._logger, logging.DEBUG)
def _log_on_fail(retry_state: tenacity.RetryCallState) -> None:
_log_retry_debug_message(retry_state)
if (
retry_state.attempt_number == 1  # tenacity attempt numbers start at 1
and retry_state.outcome
and isinstance(retry_state.outcome.exception(), RateLimitError)
):
self._logger.warning(
"Please double check that you have setup a PAID OpenAI API Account."
" You can read more here: "
"https://docs.agpt.co/setup/#getting-an-openai-api-key"
)
return tenacity.retry(
retry=(
tenacity.retry_if_exception_type(RateLimitError)
| tenacity.retry_if_exception(
lambda e: isinstance(e, APIStatusError) and e.status_code == 502
)
),
wait=tenacity.wait_exponential(),
stop=tenacity.stop_after_attempt(self._configuration.retries_per_request),
after=_log_on_fail,
)(func)
def __repr__(self):
return "OpenAIProvider()"
def format_function_specs_as_typescript_ns(
functions: list[CompletionModelFunction],
) -> str:
"""Returns a function signature block in the format used by OpenAI internally:
https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
For use with `count_tokens` to determine token usage of provided functions.
Example:
```ts
namespace functions {
// Get the current weather in a given location
type get_current_weather = (_: {
// The city and state, e.g. San Francisco, CA
location: string,
unit?: "celsius" | "fahrenheit",
}) => any;
} // namespace functions
```
"""
return (
"namespace functions {\n\n"
+ "\n\n".join(format_openai_function_for_prompt(f) for f in functions)
+ "\n\n} // namespace functions"
)
def format_openai_function_for_prompt(func: CompletionModelFunction) -> str:
"""Returns the function formatted similarly to the way OpenAI does it internally:
https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
Example:
```ts
// Get the current weather in a given location
type get_current_weather = (_: {
// The city and state, e.g. San Francisco, CA
location: string,
unit?: "celsius" | "fahrenheit",
}) => any;
```
"""
def param_signature(name: str, spec: JSONSchema) -> str:
return (
f"// {spec.description}\n" if spec.description else ""
) + f"{name}{'' if spec.required else '?'}: {spec.typescript_type},"
return "\n".join(
[
f"// {func.description}",
f"type {func.name} = (_ :{{",
*[param_signature(name, p) for name, p in func.parameters.items()],
"}) => any;",
]
)
def count_openai_functions_tokens(
functions: list[CompletionModelFunction], count_tokens: Callable[[str], int]
) -> int:
"""Returns the number of tokens taken up by a set of function definitions
Reference: https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18 # noqa: E501
"""
return count_tokens(
"# Tools\n\n"
"## functions\n\n"
f"{format_function_specs_as_typescript_ns(functions)}"
)
def _functions_compat_fix_kwargs(
functions: list[CompletionModelFunction],
prompt_messages: list[ChatMessage],
):
function_definitions = format_function_specs_as_typescript_ns(functions)
function_call_schema = JSONSchema(
type=JSONSchema.Type.OBJECT,
properties={
"name": JSONSchema(
description="The name of the function to call",
enum=[f.name for f in functions],
required=True,
),
"arguments": JSONSchema(
description="The arguments for the function call",
type=JSONSchema.Type.OBJECT,
required=True,
),
},
)
tool_calls_schema = JSONSchema(
type=JSONSchema.Type.ARRAY,
items=JSONSchema(
type=JSONSchema.Type.OBJECT,
properties={
"type": JSONSchema(
type=JSONSchema.Type.STRING,
enum=["function"],
),
"function": function_call_schema,
},
),
)
prompt_messages.append(
ChatMessage.system(
"# tool usage instructions\n\n"
"Specify a '```tool_calls' block in your response,"
" with a valid JSON object that adheres to the following schema:\n\n"
f"{tool_calls_schema.to_dict()}\n\n"
"Specify any tools that you need to use through this JSON object.\n\n"
"Put the tool_calls block at the end of your response"
" and include its fences if it is not the only content.\n\n"
"## functions\n\n"
"For the function call itself, use one of the following"
f" functions:\n\n{function_definitions}"
),
)
def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCall]:
import re
import uuid
logging.debug(f"Trying to extract tool calls from response:\n{response}")
if response[0] == "[":
tool_calls: list[AssistantToolCallDict] = json_loads(response)
else:
block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", response, re.DOTALL)
if not block:
raise ValueError("Could not find tool_calls block in response")
tool_calls: list[AssistantToolCallDict] = json_loads(block.group(1))
for t in tool_calls:
t["id"] = str(uuid.uuid4())
yield AssistantToolCall.parse_obj(t)
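
A hypothetical demonstration of the compatibility parser above: a response ending in a fenced `tool_calls` block is turned into `AssistantToolCall` objects with generated ids. The fence characters are assembled at runtime here only to avoid nesting code fences in this document; the import path is assumed.

```python
from forge.llm.providers.openai import _tool_calls_compat_extract_calls  # assumed

fence = "`" * 3
response_text = (
    "I will check the weather first.\n\n"
    f"{fence}tool_calls\n"
    '[{"type": "function", "function": {"name": "get_current_weather", '
    '"arguments": {"location": "San Francisco, CA"}}}]\n'
    f"{fence}"
)

for call in _tool_calls_compat_extract_calls(response_text):
    print(call.function.name, call.function.arguments)
    # -> get_current_weather {'location': 'San Francisco, CA'}
```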

View File

@@ -0,0 +1,460 @@
import abc
import enum
import logging
import math
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Generic,
Literal,
Optional,
Protocol,
Sequence,
TypedDict,
TypeVar,
)
from pydantic import BaseModel, Field, SecretStr
from forge.logging.utils import fmt_kwargs
from forge.models.config import (
Configurable,
SystemConfiguration,
SystemSettings,
UserConfigurable,
)
from forge.models.json_schema import JSONSchema
from forge.models.providers import (
Embedding,
ProviderBudget,
ProviderCredentials,
ResourceType,
)
if TYPE_CHECKING:
from jsonschema import ValidationError
_T = TypeVar("_T")
_ModelName = TypeVar("_ModelName", bound=str)
class ModelProviderService(str, enum.Enum):
"""A ModelService describes what kind of service the model provides."""
EMBEDDING = "embedding"
CHAT = "chat_completion"
TEXT = "text_completion"
class ModelProviderName(str, enum.Enum):
OPENAI = "openai"
ANTHROPIC = "anthropic"
GROQ = "groq"
class ChatMessage(BaseModel):
class Role(str, enum.Enum):
USER = "user"
SYSTEM = "system"
ASSISTANT = "assistant"
TOOL = "tool"
"""May be used for the result of tool calls"""
FUNCTION = "function"
"""May be used for the return value of function calls"""
role: Role
content: str
@staticmethod
def user(content: str) -> "ChatMessage":
return ChatMessage(role=ChatMessage.Role.USER, content=content)
@staticmethod
def system(content: str) -> "ChatMessage":
return ChatMessage(role=ChatMessage.Role.SYSTEM, content=content)
class ChatMessageDict(TypedDict):
role: str
content: str
class AssistantFunctionCall(BaseModel):
name: str
arguments: dict[str, Any]
def __str__(self) -> str:
return f"{self.name}({fmt_kwargs(self.arguments)})"
class AssistantFunctionCallDict(TypedDict):
name: str
arguments: dict[str, Any]
class AssistantToolCall(BaseModel):
id: str
type: Literal["function"]
function: AssistantFunctionCall
class AssistantToolCallDict(TypedDict):
id: str
type: Literal["function"]
function: AssistantFunctionCallDict
class AssistantChatMessage(ChatMessage):
role: Literal[ChatMessage.Role.ASSISTANT] = ChatMessage.Role.ASSISTANT # type: ignore # noqa
content: str = ""
tool_calls: Optional[list[AssistantToolCall]] = None
class ToolResultMessage(ChatMessage):
role: Literal[ChatMessage.Role.TOOL] = ChatMessage.Role.TOOL # type: ignore
is_error: bool = False
tool_call_id: str
class AssistantChatMessageDict(TypedDict, total=False):
role: str
content: str
tool_calls: list[AssistantToolCallDict]
class CompletionModelFunction(BaseModel):
"""General representation object for LLM-callable functions."""
name: str
description: str
parameters: dict[str, "JSONSchema"]
return_type: str | None = None
is_async: bool = False
def fmt_line(self) -> str:
params = ", ".join(
f"{name}{'?' if not p.required else ''}: " f"{p.typescript_type}"
for name, p in self.parameters.items()
)
return f"{self.name}: {self.description}. Params: ({params})"
def fmt_header(self, impl="pass", force_async=False) -> str:
"""
Formats and returns the function header as a string with types and descriptions.
Returns:
str: The formatted function header.
"""
def indent(content: str, spaces: int = 4):
return " " * spaces + content.replace("\n", "\n" + " " * spaces)
params = ", ".join(
f"{name}: {p.python_type}{f'= {str(p.default)}' if p.default else ' = None' if not p.required else ''}"
for name, p in self.parameters.items()
)
func = "async def" if self.is_async or force_async else "def"
return_str = f" -> {self.return_type}" if self.return_type else ""
return f"{func} {self.name}({params}){return_str}:\n" + indent(
(
'"""\n'
f"{self.description}\n\n"
"Params:\n"
+ indent(
"\n".join(
f"{name}: {param.description}"
for name, param in self.parameters.items()
if param.description
)
)
+ "\n"
'"""\n'
f"{impl}"
),
)
def validate_call(
self, function_call: AssistantFunctionCall
) -> tuple[bool, list["ValidationError"]]:
"""
Validates the given function call against the function's parameter specs
Returns:
bool: Whether the given set of arguments is valid for this command
list[ValidationError]: Issues with the set of arguments (if any)
Raises:
ValueError: If the function_call doesn't call this function
"""
if function_call.name != self.name:
raise ValueError(
f"Can't validate {function_call.name} call using {self.name} spec"
)
params_schema = JSONSchema(
type=JSONSchema.Type.OBJECT,
properties={name: spec for name, spec in self.parameters.items()},
)
return params_schema.validate_object(function_call.arguments)
class ModelInfo(BaseModel, Generic[_ModelName]):
"""Struct for model information.
Would be lovely to eventually get this directly from APIs, but needs to be
scraped from websites for now.
"""
name: _ModelName
service: ClassVar[ModelProviderService]
provider_name: ModelProviderName
prompt_token_cost: float = 0.0
completion_token_cost: float = 0.0
class ModelResponse(BaseModel):
"""Standard response struct for a response from a model."""
prompt_tokens_used: int
completion_tokens_used: int
model_info: ModelInfo
class ModelProviderConfiguration(SystemConfiguration):
retries_per_request: int = UserConfigurable(7)
fix_failed_parse_tries: int = UserConfigurable(3)
extra_request_headers: dict[str, str] = Field(default_factory=dict)
class ModelProviderCredentials(ProviderCredentials):
"""Credentials for a model provider."""
api_key: SecretStr | None = UserConfigurable(default=None)
api_type: SecretStr | None = UserConfigurable(default=None)
api_base: SecretStr | None = UserConfigurable(default=None)
api_version: SecretStr | None = UserConfigurable(default=None)
deployment_id: SecretStr | None = UserConfigurable(default=None)
class Config(ProviderCredentials.Config):
extra = "ignore"
class ModelProviderUsage(BaseModel):
"""Usage for a particular model from a model provider."""
class ModelUsage(BaseModel):
completion_tokens: int = 0
prompt_tokens: int = 0
usage_per_model: dict[str, ModelUsage] = defaultdict(ModelUsage)
@property
def completion_tokens(self) -> int:
return sum(model.completion_tokens for model in self.usage_per_model.values())
@property
def prompt_tokens(self) -> int:
return sum(model.prompt_tokens for model in self.usage_per_model.values())
def update_usage(
self,
model: str,
input_tokens_used: int,
output_tokens_used: int = 0,
) -> None:
self.usage_per_model[model].prompt_tokens += input_tokens_used
self.usage_per_model[model].completion_tokens += output_tokens_used
class ModelProviderBudget(ProviderBudget[ModelProviderUsage]):
usage: ModelProviderUsage = Field(default_factory=ModelProviderUsage)
def update_usage_and_cost(
self,
model_info: ModelInfo,
input_tokens_used: int,
output_tokens_used: int = 0,
) -> float:
"""Update the usage and cost of the provider.
Returns:
float: The (calculated) cost of the given model response.
"""
self.usage.update_usage(model_info.name, input_tokens_used, output_tokens_used)
incurred_cost = (
output_tokens_used * model_info.completion_token_cost
+ input_tokens_used * model_info.prompt_token_cost
)
self.total_cost += incurred_cost
self.remaining_budget -= incurred_cost
return incurred_cost
class ModelProviderSettings(SystemSettings):
resource_type: ClassVar[ResourceType] = ResourceType.MODEL
configuration: ModelProviderConfiguration
credentials: Optional[ModelProviderCredentials] = None
budget: Optional[ModelProviderBudget] = None
_ModelProviderSettings = TypeVar("_ModelProviderSettings", bound=ModelProviderSettings)
# TODO: either use MultiProvider throughout codebase as type for `llm_provider`, or
# replace `_ModelName` by `str` to eliminate type checking difficulties
class BaseModelProvider(
abc.ABC,
Generic[_ModelName, _ModelProviderSettings],
Configurable[_ModelProviderSettings],
):
"""A ModelProvider abstracts the details of a particular provider of models."""
default_settings: ClassVar[_ModelProviderSettings] # type: ignore
_settings: _ModelProviderSettings
_logger: logging.Logger
def __init__(
self,
settings: Optional[_ModelProviderSettings] = None,
logger: Optional[logging.Logger] = None,
):
if not settings:
settings = self.default_settings.copy(deep=True)
self._settings = settings
self._configuration = settings.configuration
self._credentials = settings.credentials
self._budget = settings.budget
self._logger = logger or logging.getLogger(self.__module__)
@abc.abstractmethod
async def get_available_models(
self,
) -> Sequence["ChatModelInfo[_ModelName] | EmbeddingModelInfo[_ModelName]"]:
...
@abc.abstractmethod
def count_tokens(self, text: str, model_name: _ModelName) -> int:
...
@abc.abstractmethod
def get_tokenizer(self, model_name: _ModelName) -> "ModelTokenizer[Any]":
...
@abc.abstractmethod
def get_token_limit(self, model_name: _ModelName) -> int:
...
def get_incurred_cost(self) -> float:
if self._budget:
return self._budget.total_cost
return 0
def get_remaining_budget(self) -> float:
if self._budget:
return self._budget.remaining_budget
return math.inf
class ModelTokenizer(Protocol, Generic[_T]):
"""A ModelTokenizer provides tokenization specific to a model."""
@abc.abstractmethod
def encode(self, text: str) -> list[_T]:
...
@abc.abstractmethod
def decode(self, tokens: list[_T]) -> str:
...
####################
# Embedding Models #
####################
class EmbeddingModelInfo(ModelInfo[_ModelName]):
"""Struct for embedding model information."""
service: Literal[ModelProviderService.EMBEDDING] = ModelProviderService.EMBEDDING # type: ignore # noqa
max_tokens: int
embedding_dimensions: int
class EmbeddingModelResponse(ModelResponse):
"""Standard response struct for a response from an embedding model."""
embedding: Embedding = Field(default_factory=list)
completion_tokens_used: int = Field(default=0, const=True)
class BaseEmbeddingModelProvider(BaseModelProvider[_ModelName, _ModelProviderSettings]):
@abc.abstractmethod
async def get_available_embedding_models(
self,
) -> Sequence[EmbeddingModelInfo[_ModelName]]:
...
@abc.abstractmethod
async def create_embedding(
self,
text: str,
model_name: _ModelName,
embedding_parser: Callable[[Embedding], Embedding],
**kwargs,
) -> EmbeddingModelResponse:
...
###############
# Chat Models #
###############
class ChatModelInfo(ModelInfo[_ModelName]):
"""Struct for language model information."""
service: Literal[ModelProviderService.CHAT] = ModelProviderService.CHAT # type: ignore # noqa
max_tokens: int
has_function_call_api: bool = False
class ChatModelResponse(ModelResponse, Generic[_T]):
"""Standard response struct for a response from a language model."""
response: AssistantChatMessage
parsed_result: _T
class BaseChatModelProvider(BaseModelProvider[_ModelName, _ModelProviderSettings]):
@abc.abstractmethod
async def get_available_chat_models(self) -> Sequence[ChatModelInfo[_ModelName]]:
...
@abc.abstractmethod
def count_message_tokens(
self,
messages: ChatMessage | list[ChatMessage],
model_name: _ModelName,
) -> int:
...
@abc.abstractmethod
async def create_chat_completion(
self,
model_prompt: list[ChatMessage],
model_name: _ModelName,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
prefill_response: str = "",
**kwargs,
) -> ChatModelResponse[_T]:
...
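
A small sketch of the structures above: declaring a `CompletionModelFunction` and validating an `AssistantFunctionCall` against it. The `forge.llm.providers.schema` path matches the import used elsewhere in this diff; the printed outputs are indicative.

```python
from forge.llm.providers.schema import (
    AssistantFunctionCall,
    CompletionModelFunction,
)
from forge.models.json_schema import JSONSchema

get_weather = CompletionModelFunction(
    name="get_current_weather",
    description="Get the current weather in a given location",
    parameters={
        "location": JSONSchema(
            type=JSONSchema.Type.STRING,
            description="The city and state, e.g. San Francisco, CA",
            required=True,
        ),
    },
)

print(get_weather.fmt_line())
# e.g.: get_current_weather: Get the current weather in a given location. Params: (location: string)

ok, errors = get_weather.validate_call(
    AssistantFunctionCall(name="get_current_weather", arguments={"location": "Paris"})
)
print(ok, errors)  # expected: True []
```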

View File

@@ -0,0 +1,88 @@
from typing import TYPE_CHECKING, Any, Iterable
if TYPE_CHECKING:
from forge.command.command import Command
from .schema import AssistantToolCall, CompletionModelFunction
class InvalidFunctionCallError(Exception):
def __init__(self, name: str, arguments: dict[str, Any], message: str):
self.message = message
self.name = name
self.arguments = arguments
super().__init__(message)
def __str__(self) -> str:
return f"Invalid function call for {self.name}: {self.message}"
def validate_tool_calls(
tool_calls: list[AssistantToolCall], functions: list[CompletionModelFunction]
) -> list[InvalidFunctionCallError]:
"""
Validates a list of tool calls against a list of functions.
1. Tries to find a function matching each tool call
2. If a matching function is found, validates the tool call's arguments,
reporting any resulting errors
3. If no matching function is found, an error "Unknown function X" is reported
4. A list of all errors encountered during validation is returned
Params:
tool_calls: A list of tool calls to validate.
functions: A list of functions to validate against.
Returns:
list[InvalidFunctionCallError]: All errors encountered during validation.
"""
errors: list[InvalidFunctionCallError] = []
for tool_call in tool_calls:
function_call = tool_call.function
if function := next(
(f for f in functions if f.name == function_call.name),
None,
):
is_valid, validation_errors = function.validate_call(function_call)
if not is_valid:
fmt_errors = [
f"{'.'.join(str(p) for p in f.path)}: {f.message}"
if f.path
else f.message
for f in validation_errors
]
errors.append(
InvalidFunctionCallError(
name=function_call.name,
arguments=function_call.arguments,
message=(
"The set of arguments supplied is invalid:\n"
+ "\n".join(fmt_errors)
),
)
)
else:
errors.append(
InvalidFunctionCallError(
name=function_call.name,
arguments=function_call.arguments,
message=f"Unknown function {function_call.name}",
)
)
return errors
def function_specs_from_commands(
commands: Iterable["Command"],
) -> list[CompletionModelFunction]:
"""Get LLM-consumable function specs for the agent's available commands."""
return [
CompletionModelFunction(
name=command.names[0],
description=command.description,
parameters={param.name: param.spec for param in command.parameters},
)
for command in commands
]
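
A sketch of `validate_tool_calls` in action, assuming this helper module sits next to `schema` in `forge.llm.providers`: a call to an unknown function produces a single `InvalidFunctionCallError`.

```python
from forge.llm.providers.schema import (
    AssistantFunctionCall,
    AssistantToolCall,
    CompletionModelFunction,
)
from forge.llm.providers.utils import validate_tool_calls  # assumed module name
from forge.models.json_schema import JSONSchema

functions = [
    CompletionModelFunction(
        name="read_file",
        description="Read a file and return its contents",
        parameters={"path": JSONSchema(type=JSONSchema.Type.STRING, required=True)},
    )
]

proposed = [
    AssistantToolCall(
        id="call_1",
        type="function",
        function=AssistantFunctionCall(name="delete_file", arguments={"path": "x"}),
    )
]

for error in validate_tool_calls(proposed, functions):
    print(error)
    # Invalid function call for delete_file: Unknown function delete_file
```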

View File

@@ -0,0 +1,9 @@
from .config import configure_logging
from .filters import BelowLevelFilter
from .formatters import FancyConsoleFormatter
__all__ = [
"configure_logging",
"BelowLevelFilter",
"FancyConsoleFormatter",
]

View File

@@ -0,0 +1,200 @@
"""Logging module for Auto-GPT."""
from __future__ import annotations
import enum
import logging
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from openai._base_client import log as openai_logger
from forge.models.config import SystemConfiguration, UserConfigurable
if TYPE_CHECKING:
from forge.speech import TTSConfig
from .filters import BelowLevelFilter
from .formatters import ForgeFormatter, StructuredLoggingFormatter
from .handlers import TTSHandler
LOG_DIR = Path(__file__).parent.parent.parent / "logs"
LOG_FILE = "activity.log"
DEBUG_LOG_FILE = "debug.log"
ERROR_LOG_FILE = "error.log"
SIMPLE_LOG_FORMAT = "%(asctime)s %(levelname)s %(title)s%(message)s"
DEBUG_LOG_FORMAT = (
"%(asctime)s %(levelname)s %(filename)s:%(lineno)d" " %(title)s%(message)s"
)
SPEECH_OUTPUT_LOGGER = "VOICE"
USER_FRIENDLY_OUTPUT_LOGGER = "USER_FRIENDLY_OUTPUT"
class LogFormatName(str, enum.Enum):
SIMPLE = "simple"
DEBUG = "debug"
STRUCTURED = "structured_google_cloud"
TEXT_LOG_FORMAT_MAP = {
LogFormatName.DEBUG: DEBUG_LOG_FORMAT,
LogFormatName.SIMPLE: SIMPLE_LOG_FORMAT,
}
class LoggingConfig(SystemConfiguration):
level: int = UserConfigurable(
default=logging.INFO,
from_env=lambda: logging.getLevelName(os.getenv("LOG_LEVEL", "INFO")),
)
# Console output
log_format: LogFormatName = UserConfigurable(
default=LogFormatName.SIMPLE, from_env="LOG_FORMAT"
)
plain_console_output: bool = UserConfigurable(
default=False,
from_env=lambda: os.getenv("PLAIN_OUTPUT", "False") == "True",
)
# File output
log_dir: Path = LOG_DIR
log_file_format: Optional[LogFormatName] = UserConfigurable(
default=LogFormatName.SIMPLE,
from_env=lambda: os.getenv( # type: ignore
"LOG_FILE_FORMAT", os.getenv("LOG_FORMAT", "simple")
),
)
def configure_logging(
debug: bool = False,
level: Optional[int | str] = None,
log_dir: Optional[Path] = None,
log_format: Optional[LogFormatName | str] = None,
log_file_format: Optional[LogFormatName | str] = None,
plain_console_output: Optional[bool] = None,
config: Optional[LoggingConfig] = None,
tts_config: Optional[TTSConfig] = None,
) -> None:
"""Configure the native logging module, based on the environment config and any
specified overrides.
Arguments override values specified in the environment.
Overrides are also applied to `config`, if passed.
Should be usable as `configure_logging(**config.logging.dict())`, where
`config.logging` is a `LoggingConfig` object.
"""
if debug and level:
raise ValueError("Only one of either 'debug' and 'level' arguments may be set")
# Parse arguments
if isinstance(level, str):
if type(_level := logging.getLevelName(level.upper())) is int:
level = _level
else:
raise ValueError(f"Unknown log level '{level}'")
if isinstance(log_format, str):
if log_format in LogFormatName._value2member_map_:
log_format = LogFormatName(log_format)
elif not isinstance(log_format, LogFormatName):
raise ValueError(f"Unknown log format '{log_format}'")
if isinstance(log_file_format, str):
if log_file_format in LogFormatName._value2member_map_:
log_file_format = LogFormatName(log_file_format)
elif not isinstance(log_file_format, LogFormatName):
raise ValueError(f"Unknown log format '{log_format}'")
config = config or LoggingConfig.from_env()
# Aggregate env config + arguments
config.level = logging.DEBUG if debug else level or config.level
config.log_dir = log_dir or config.log_dir
config.log_format = log_format or (
LogFormatName.DEBUG if debug else config.log_format
)
config.log_file_format = log_file_format or log_format or config.log_file_format
config.plain_console_output = (
plain_console_output
if plain_console_output is not None
else config.plain_console_output
)
# Structured logging is used for cloud environments,
# where logging to a file makes no sense.
if config.log_format == LogFormatName.STRUCTURED:
config.plain_console_output = True
config.log_file_format = None
# create log directory if it doesn't exist
if not config.log_dir.exists():
config.log_dir.mkdir()
log_handlers: list[logging.Handler] = []
if config.log_format in (LogFormatName.DEBUG, LogFormatName.SIMPLE):
console_format_template = TEXT_LOG_FORMAT_MAP[config.log_format]
console_formatter = ForgeFormatter(console_format_template)
else:
console_formatter = StructuredLoggingFormatter()
console_format_template = SIMPLE_LOG_FORMAT
# Console output handlers
stdout = logging.StreamHandler(stream=sys.stdout)
stdout.setLevel(config.level)
stdout.addFilter(BelowLevelFilter(logging.WARNING))
stdout.setFormatter(console_formatter)
stderr = logging.StreamHandler()
stderr.setLevel(logging.WARNING)
stderr.setFormatter(console_formatter)
log_handlers += [stdout, stderr]
# File output handlers
if config.log_file_format is not None:
if config.level < logging.ERROR:
file_output_format_template = TEXT_LOG_FORMAT_MAP[config.log_file_format]
file_output_formatter = ForgeFormatter(
file_output_format_template, no_color=True
)
# INFO log file handler
activity_log_handler = logging.FileHandler(
config.log_dir / LOG_FILE, "a", "utf-8"
)
activity_log_handler.setLevel(config.level)
activity_log_handler.setFormatter(file_output_formatter)
log_handlers += [activity_log_handler]
# ERROR log file handler
error_log_handler = logging.FileHandler(
config.log_dir / ERROR_LOG_FILE, "a", "utf-8"
)
error_log_handler.setLevel(logging.ERROR)
error_log_handler.setFormatter(ForgeFormatter(DEBUG_LOG_FORMAT, no_color=True))
log_handlers += [error_log_handler]
# Configure the root logger
logging.basicConfig(
format=console_format_template,
level=config.level,
handlers=log_handlers,
)
# Speech output
speech_output_logger = logging.getLogger(SPEECH_OUTPUT_LOGGER)
speech_output_logger.setLevel(logging.INFO)
if tts_config:
speech_output_logger.addHandler(TTSHandler(tts_config))
speech_output_logger.propagate = False
# JSON logger with better formatting
json_logger = logging.getLogger("JSON_LOGGER")
json_logger.setLevel(logging.DEBUG)
json_logger.propagate = False
# Disable debug logging from OpenAI library
openai_logger.setLevel(logging.WARNING)
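
A minimal usage sketch, assuming the package is importable as `forge.logging` (consistent with the `__init__` shown earlier in this diff):

```python
import logging

from forge.logging import configure_logging

configure_logging(level="DEBUG", plain_console_output=True)

logger = logging.getLogger(__name__)
logger.info("Agent started", extra={"title": "STATUS"})  # stdout + activity.log
logger.error("Budget exceeded")                          # stderr + error.log
```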

View File

@@ -0,0 +1,12 @@
import logging
class BelowLevelFilter(logging.Filter):
"""Filter for logging levels below a certain threshold."""
def __init__(self, below_level: int):
super().__init__()
self.below_level = below_level
def filter(self, record: logging.LogRecord):
return record.levelno < self.below_level

View File

@@ -0,0 +1,95 @@
import logging
from colorama import Fore, Style
from google.cloud.logging_v2.handlers import CloudLoggingFilter, StructuredLogHandler
from .utils import remove_color_codes
class FancyConsoleFormatter(logging.Formatter):
"""
A custom logging formatter designed for console output.
This formatter enhances the standard logging output with color coding. The color
coding is based on the level of the log message, making it easier to distinguish
between different types of messages in the console output.
The color for each level is defined in the LEVEL_COLOR_MAP class attribute.
"""
# level -> (level & text color, title color)
LEVEL_COLOR_MAP = {
logging.DEBUG: Fore.LIGHTBLACK_EX,
logging.INFO: Fore.BLUE,
logging.WARNING: Fore.YELLOW,
logging.ERROR: Fore.RED,
logging.CRITICAL: Fore.RED + Style.BRIGHT,
}
def format(self, record: logging.LogRecord) -> str:
# Make sure `msg` is a string
if not hasattr(record, "msg"):
record.msg = ""
elif not type(record.msg) is str:
record.msg = str(record.msg)
# Determine default color based on error level
level_color = ""
if record.levelno in self.LEVEL_COLOR_MAP:
level_color = self.LEVEL_COLOR_MAP[record.levelno]
record.levelname = f"{level_color}{record.levelname}{Style.RESET_ALL}"
# Determine color for message
color = getattr(record, "color", level_color)
color_is_specified = hasattr(record, "color")
# Don't color INFO messages unless the color is explicitly specified.
if color and (record.levelno != logging.INFO or color_is_specified):
record.msg = f"{color}{record.msg}{Style.RESET_ALL}"
return super().format(record)
class ForgeFormatter(FancyConsoleFormatter):
def __init__(self, *args, no_color: bool = False, **kwargs):
super().__init__(*args, **kwargs)
self.no_color = no_color
def format(self, record: logging.LogRecord) -> str:
# Make sure `msg` is a string
if not hasattr(record, "msg"):
record.msg = ""
elif not type(record.msg) is str:
record.msg = str(record.msg)
# Strip color from the message to prevent color spoofing
if record.msg and not getattr(record, "preserve_color", False):
record.msg = remove_color_codes(record.msg)
# Determine color for title
title = getattr(record, "title", "")
title_color = getattr(record, "title_color", "") or self.LEVEL_COLOR_MAP.get(
record.levelno, ""
)
if title and title_color:
title = f"{title_color + Style.BRIGHT}{title}{Style.RESET_ALL}"
# Make sure record.title is set, and padded with a space if not empty
record.title = f"{title} " if title else ""
if self.no_color:
return remove_color_codes(super().format(record))
else:
return super().format(record)
class StructuredLoggingFormatter(StructuredLogHandler, logging.Formatter):
def __init__(self):
# Set up CloudLoggingFilter to add diagnostic info to the log records
self.cloud_logging_filter = CloudLoggingFilter()
# Init StructuredLogHandler
super().__init__()
def format(self, record: logging.LogRecord) -> str:
self.cloud_logging_filter.filter(record)
return super().format(record)
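
A small usage sketch (module path assumed): the formatter colors the level name per `LEVEL_COLOR_MAP`, and an explicit `color` extra overrides the message color.

```python
import logging
import sys

from colorama import Fore

from forge.logging.formatters import FancyConsoleFormatter  # assumed module path

handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(FancyConsoleFormatter("%(levelname)s: %(message)s"))

logger = logging.getLogger("demo")
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)

logger.warning("disk almost full")  # level name rendered in yellow
logger.info("remaining budget: $1.23", extra={"color": Fore.GREEN})  # green message
```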

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
import json
import logging
from typing import TYPE_CHECKING
from forge.logging.utils import remove_color_codes
from forge.speech import TextToSpeechProvider
if TYPE_CHECKING:
from forge.speech import TTSConfig
class TTSHandler(logging.Handler):
"""Output messages to the configured TTS engine (if any)"""
def __init__(self, config: TTSConfig):
super().__init__()
self.config = config
self.tts_provider = TextToSpeechProvider(config)
def format(self, record: logging.LogRecord) -> str:
if getattr(record, "title", ""):
msg = f"{getattr(record, 'title')} {record.msg}"
else:
msg = f"{record.msg}"
return remove_color_codes(msg)
def emit(self, record: logging.LogRecord) -> None:
if not self.config.speak_mode:
return
message = self.format(record)
self.tts_provider.say(message)
class JsonFileHandler(logging.FileHandler):
def format(self, record: logging.LogRecord) -> str:
record.json_data = json.loads(record.getMessage())
return json.dumps(getattr(record, "json_data"), ensure_ascii=False, indent=4)
def emit(self, record: logging.LogRecord) -> None:
with open(self.baseFilename, "w", encoding="utf-8") as f:
f.write(self.format(record))
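
A hypothetical sketch of `JsonFileHandler`: each emit overwrites the target file with a pretty-printed copy of the logged JSON payload (note the "w" mode above). The module path is assumed.

```python
import json
import logging

from forge.logging.handlers import JsonFileHandler  # assumed module path

json_logger = logging.getLogger("JSON_LOGGER")
json_logger.setLevel(logging.DEBUG)
json_logger.propagate = False
json_logger.addHandler(JsonFileHandler("last_payload.json"))

json_logger.debug(json.dumps({"step": 1, "status": "ok"}))
# last_payload.json now holds the indented JSON object; a later call replaces it.
```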

View File

@@ -0,0 +1,33 @@
import logging
import re
from typing import Any
from colorama import Fore
def remove_color_codes(s: str) -> str:
return re.sub(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])", "", s)
def fmt_kwargs(kwargs: dict) -> str:
return ", ".join(f"{n}={repr(v)}" for n, v in kwargs.items())
def print_attribute(
title: str, value: Any, title_color: str = Fore.GREEN, value_color: str = ""
) -> None:
logger = logging.getLogger()
logger.info(
str(value),
extra={
"title": f"{title.rstrip(':')}:",
"title_color": title_color,
"color": value_color,
},
)
def speak(message: str, level: int = logging.INFO) -> None:
from .config import SPEECH_OUTPUT_LOGGER
logging.getLogger(SPEECH_OUTPUT_LOGGER).log(level, message)
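
A quick sketch of the helpers above (module path assumed); `print_attribute` relies on a console formatter that understands the `title`/`color` extras, such as the one installed by `configure_logging`.

```python
from forge.logging.utils import fmt_kwargs, print_attribute  # assumed module path

print(fmt_kwargs({"path": "agent.db", "timeout": 30}))
# -> path='agent.db', timeout=30

print_attribute("Configured model", "gpt-4o")
# logs "Configured model: gpt-4o" on the root logger with a green title
```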

View File

@@ -0,0 +1,79 @@
from __future__ import annotations
from typing import Any, Literal, Optional, TypeVar
from pydantic import BaseModel
from forge.llm.providers.schema import AssistantFunctionCall
from .utils import ModelWithSummary
class ActionProposal(BaseModel):
thoughts: str | ModelWithSummary
use_tool: AssistantFunctionCall
AnyProposal = TypeVar("AnyProposal", bound=ActionProposal)
class ActionSuccessResult(BaseModel):
outputs: Any
status: Literal["success"] = "success"
def __str__(self) -> str:
outputs = str(self.outputs).replace("```", r"\```")
multiline = "\n" in outputs
return f"```\n{self.outputs}\n```" if multiline else str(self.outputs)
class ErrorInfo(BaseModel):
args: tuple
message: str
exception_type: str
repr: str
@staticmethod
def from_exception(exception: Exception) -> ErrorInfo:
return ErrorInfo(
args=exception.args,
message=getattr(exception, "message", exception.args[0]),
exception_type=exception.__class__.__name__,
repr=repr(exception),
)
def __str__(self):
return repr(self)
def __repr__(self):
return self.repr
class ActionErrorResult(BaseModel):
reason: str
error: Optional[ErrorInfo] = None
status: Literal["error"] = "error"
@staticmethod
def from_exception(exception: Exception) -> ActionErrorResult:
return ActionErrorResult(
reason=getattr(exception, "message", exception.args[0]),
error=ErrorInfo.from_exception(exception),
)
def __str__(self) -> str:
return f"Action failed: '{self.reason}'"
class ActionInterruptedByHuman(BaseModel):
feedback: str
status: Literal["interrupted_by_human"] = "interrupted_by_human"
def __str__(self) -> str:
return (
'The user interrupted the action with the following feedback: "%s"'
% self.feedback
)
ActionResult = ActionSuccessResult | ActionErrorResult | ActionInterruptedByHuman
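
A brief sketch of how these result types stringify (module path assumed):

```python
from forge.models.action import ActionErrorResult, ActionSuccessResult  # assumed path

print(ActionSuccessResult(outputs="3 files written"))
# -> 3 files written

try:
    raise RuntimeError("command not allowed")
except RuntimeError as e:
    result = ActionErrorResult.from_exception(e)
    print(result)                       # -> Action failed: 'command not allowed'
    print(result.error.exception_type)  # -> RuntimeError
```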

View File

@@ -0,0 +1,350 @@
import os
import typing
from typing import Any, Callable, Generic, Optional, Type, TypeVar, get_args
from pydantic import BaseModel, Field, ValidationError
from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.main import ModelMetaclass
T = TypeVar("T")
M = TypeVar("M", bound=BaseModel)
def UserConfigurable(
default: T | UndefinedType = Undefined,
*args,
default_factory: Optional[Callable[[], T]] = None,
from_env: Optional[str | Callable[[], T | None]] = None,
description: str = "",
**kwargs,
) -> T:
# TODO: use this to auto-generate docs for the application configuration
return Field(
default,
*args,
default_factory=default_factory,
from_env=from_env,
description=description,
**kwargs,
user_configurable=True,
)
class SystemConfiguration(BaseModel):
def get_user_config(self) -> dict[str, Any]:
return _recurse_user_config_values(self)
@classmethod
def from_env(cls):
"""
Initializes the config object from environment variables.
Environment variables are mapped to UserConfigurable fields using the from_env
attribute that can be passed to UserConfigurable.
"""
def infer_field_value(field: ModelField):
field_info = field.field_info
default_value = (
field.default
if field.default not in (None, Undefined)
else (field.default_factory() if field.default_factory else Undefined)
)
if from_env := field_info.extra.get("from_env"):
val_from_env = (
os.getenv(from_env) if type(from_env) is str else from_env()
)
if val_from_env is not None:
return val_from_env
return default_value
return _recursive_init_model(cls, infer_field_value)
class Config:
extra = "forbid"
use_enum_values = True
validate_assignment = True
SC = TypeVar("SC", bound=SystemConfiguration)
class SystemSettings(BaseModel):
"""A base class for all system settings."""
name: str
description: str
class Config:
extra = "forbid"
use_enum_values = True
validate_assignment = True
S = TypeVar("S", bound=SystemSettings)
class Configurable(Generic[S]):
"""A base class for all configurable objects."""
prefix: str = ""
default_settings: typing.ClassVar[S] # type: ignore
@classmethod
def get_user_config(cls) -> dict[str, Any]:
return _recurse_user_config_values(cls.default_settings)
@classmethod
def build_agent_configuration(cls, overrides: dict = {}) -> S:
"""Process the configuration for this object."""
base_config = _update_user_config_from_env(cls.default_settings)
final_configuration = deep_update(base_config, overrides)
return cls.default_settings.__class__.parse_obj(final_configuration)
def _update_user_config_from_env(instance: BaseModel) -> dict[str, Any]:
"""
Update config fields of a Pydantic model instance from environment variables.
Precedence:
1. Non-default value already on the instance
2. Value returned by `from_env()`
3. Default value for the field
Params:
instance: The Pydantic model instance.
Returns:
The user config fields of the instance.
"""
def infer_field_value(field: ModelField, value):
field_info = field.field_info
default_value = (
field.default
if field.default not in (None, Undefined)
else (field.default_factory() if field.default_factory else None)
)
if value == default_value and (from_env := field_info.extra.get("from_env")):
val_from_env = os.getenv(from_env) if type(from_env) is str else from_env()
if val_from_env is not None:
return val_from_env
return value
def init_sub_config(model: Type[SC]) -> SC | None:
try:
return model.from_env()
except ValidationError as e:
# Gracefully handle missing fields
if all(e["type"] == "value_error.missing" for e in e.errors()):
return None
raise
return _recurse_user_config_fields(instance, infer_field_value, init_sub_config)
def _recursive_init_model(
model: Type[M],
infer_field_value: Callable[[ModelField], Any],
) -> M:
"""
Recursively initialize the user configuration fields of a Pydantic model.
Parameters:
model: The Pydantic model type.
infer_field_value: A callback function to infer the value of each field.
Parameters:
ModelField: The Pydantic ModelField object describing the field.
Returns:
BaseModel: An instance of the model with the initialized configuration.
"""
user_config_fields = {}
for name, field in model.__fields__.items():
if "user_configurable" in field.field_info.extra:
user_config_fields[name] = infer_field_value(field)
elif type(field.outer_type_) is ModelMetaclass and issubclass(
field.outer_type_, SystemConfiguration
):
try:
user_config_fields[name] = _recursive_init_model(
model=field.outer_type_,
infer_field_value=infer_field_value,
)
except ValidationError as e:
# Gracefully handle missing fields
if all(e["type"] == "value_error.missing" for e in e.errors()):
user_config_fields[name] = None
raise
user_config_fields = remove_none_items(user_config_fields)
return model.parse_obj(user_config_fields)
def _recurse_user_config_fields(
model: BaseModel,
infer_field_value: Callable[[ModelField, Any], Any],
init_sub_config: Optional[
Callable[[Type[SystemConfiguration]], SystemConfiguration | None]
] = None,
) -> dict[str, Any]:
"""
Recursively process the user configuration fields of a Pydantic model instance.
Params:
model: The Pydantic model to iterate over.
infer_field_value: A callback function to process each field.
Params:
ModelField: The Pydantic ModelField object describing the field.
Any: The current value of the field.
init_sub_config: An optional callback function to initialize a sub-config.
Params:
Type[SystemConfiguration]: The type of the sub-config to initialize.
Returns:
dict[str, Any]: The processed user configuration fields of the instance.
"""
user_config_fields = {}
for name, field in model.__fields__.items():
value = getattr(model, name)
# Handle individual field
if "user_configurable" in field.field_info.extra:
user_config_fields[name] = infer_field_value(field, value)
# Recurse into nested config object
elif isinstance(value, SystemConfiguration):
user_config_fields[name] = _recurse_user_config_fields(
model=value,
infer_field_value=infer_field_value,
init_sub_config=init_sub_config,
)
# Recurse into optional nested config object
elif value is None and init_sub_config:
field_type = get_args(field.annotation)[0] # Optional[T] -> T
if type(field_type) is ModelMetaclass and issubclass(
field_type, SystemConfiguration
):
sub_config = init_sub_config(field_type)
if sub_config:
user_config_fields[name] = _recurse_user_config_fields(
model=sub_config,
infer_field_value=infer_field_value,
init_sub_config=init_sub_config,
)
elif isinstance(value, list) and all(
isinstance(i, SystemConfiguration) for i in value
):
user_config_fields[name] = [
_recurse_user_config_fields(i, infer_field_value, init_sub_config)
for i in value
]
elif isinstance(value, dict) and all(
isinstance(i, SystemConfiguration) for i in value.values()
):
user_config_fields[name] = {
k: _recurse_user_config_fields(v, infer_field_value, init_sub_config)
for k, v in value.items()
}
return user_config_fields
def _recurse_user_config_values(
instance: BaseModel,
get_field_value: Callable[[ModelField, T], T] = lambda _, v: v,
) -> dict[str, Any]:
"""
This function recursively traverses the user configuration values in a Pydantic
model instance.
Params:
instance: A Pydantic model instance.
get_field_value: A callback function to process each field. Parameters:
ModelField: The Pydantic ModelField object that describes the field.
Any: The current value of the field.
Returns:
A dictionary containing the processed user configuration fields of the instance.
"""
user_config_values = {}
for name, value in instance.__dict__.items():
field = instance.__fields__[name]
if "user_configurable" in field.field_info.extra:
user_config_values[name] = get_field_value(field, value)
elif isinstance(value, SystemConfiguration):
user_config_values[name] = _recurse_user_config_values(
instance=value, get_field_value=get_field_value
)
elif isinstance(value, list) and all(
isinstance(i, SystemConfiguration) for i in value
):
user_config_values[name] = [
_recurse_user_config_values(i, get_field_value) for i in value
]
elif isinstance(value, dict) and all(
isinstance(i, SystemConfiguration) for i in value.values()
):
user_config_values[name] = {
k: _recurse_user_config_values(v, get_field_value)
for k, v in value.items()
}
return user_config_values
def _get_non_default_user_config_values(instance: BaseModel) -> dict[str, Any]:
"""
Get the non-default user config fields of a Pydantic model instance.
Params:
instance: The Pydantic model instance.
Returns:
dict[str, Any]: The non-default user config values on the instance.
"""
def get_field_value(field: ModelField, value):
default = field.default_factory() if field.default_factory else field.default
if value != default:
return value
return remove_none_items(_recurse_user_config_values(instance, get_field_value))
def deep_update(original_dict: dict, update_dict: dict) -> dict:
"""
Recursively update a dictionary.
Params:
original_dict (dict): The dictionary to be updated.
update_dict (dict): The dictionary to update with.
Returns:
dict: The updated dictionary.
"""
for key, value in update_dict.items():
if (
key in original_dict
and isinstance(original_dict[key], dict)
and isinstance(value, dict)
):
original_dict[key] = deep_update(original_dict[key], value)
else:
original_dict[key] = value
return original_dict
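# Example: nested mappings are merged key by key rather than replaced wholesale.
#
#     deep_update({"a": {"x": 1, "y": 2}, "b": 3}, {"a": {"y": 20}, "c": 4})
#     # -> {"a": {"x": 1, "y": 20}, "b": 3, "c": 4}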
def remove_none_items(d):
if isinstance(d, dict):
return {
k: remove_none_items(v) for k, v in d.items() if v not in (None, Undefined)
}
return d

View File

@@ -0,0 +1,217 @@
import ast
import enum
from textwrap import indent
from typing import Any, Optional, overload
from jsonschema import Draft7Validator, ValidationError
from pydantic import BaseModel
class JSONSchema(BaseModel):
class Type(str, enum.Enum):
STRING = "string"
ARRAY = "array"
OBJECT = "object"
NUMBER = "number"
INTEGER = "integer"
BOOLEAN = "boolean"
TYPE = "type"
# TODO: add docstrings
description: Optional[str] = None
type: Optional[Type] = None
enum: Optional[list] = None
required: bool = False
default: Any = None
items: Optional["JSONSchema"] = None
properties: Optional[dict[str, "JSONSchema"]] = None
minimum: Optional[int | float] = None
maximum: Optional[int | float] = None
minItems: Optional[int] = None
maxItems: Optional[int] = None
def to_dict(self) -> dict:
schema: dict = {
"type": self.type.value if self.type else None,
"description": self.description,
"default": repr(self.default),
}
if self.type == "array":
if self.items:
schema["items"] = self.items.to_dict()
schema["minItems"] = self.minItems
schema["maxItems"] = self.maxItems
elif self.type == "object":
if self.properties:
schema["properties"] = {
name: prop.to_dict() for name, prop in self.properties.items()
}
schema["required"] = [
name for name, prop in self.properties.items() if prop.required
]
elif self.enum:
schema["enum"] = self.enum
else:
schema["minumum"] = self.minimum
schema["maximum"] = self.maximum
schema = {k: v for k, v in schema.items() if v is not None}
return schema
@staticmethod
def from_dict(schema: dict) -> "JSONSchema":
definitions = schema.get("definitions", {})
schema = _resolve_type_refs_in_schema(schema, definitions)
return JSONSchema(
description=schema.get("description"),
type=schema["type"],
default=ast.literal_eval(d) if (d := schema.get("default")) else None,
enum=schema.get("enum"),
items=JSONSchema.from_dict(schema["items"]) if "items" in schema else None,
properties=JSONSchema.parse_properties(schema)
if schema["type"] == "object"
else None,
minimum=schema.get("minimum"),
maximum=schema.get("maximum"),
minItems=schema.get("minItems"),
maxItems=schema.get("maxItems"),
)
@staticmethod
def parse_properties(schema_node: dict) -> dict[str, "JSONSchema"]:
properties = (
{k: JSONSchema.from_dict(v) for k, v in schema_node["properties"].items()}
if "properties" in schema_node
else {}
)
if "required" in schema_node:
for k, v in properties.items():
v.required = k in schema_node["required"]
return properties
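    # Illustrative sketch of building a schema from a plain dict; the field names
    # ("name", "age") are hypothetical.
    #
    #     schema = JSONSchema.from_dict({
    #         "type": "object",
    #         "properties": {
    #             "name": {"type": "string", "description": "The user's name"},
    #             "age": {"type": "integer", "minimum": 0},
    #         },
    #         "required": ["name"],
    #     })
    #     schema.properties["name"].required  # -> True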
def validate_object(self, object: object) -> tuple[bool, list[ValidationError]]:
"""
Validates an object or a value against the JSONSchema.
Params:
object: The value/object to validate.
Returns:
bool: Indicates whether the given value or object is valid for the schema.
list[ValidationError]: The issues with the value or object (if any).
"""
validator = Draft7Validator(self.to_dict())
if errors := sorted(validator.iter_errors(object), key=lambda e: e.path):
return False, errors
return True, []
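    # Illustrative sketch, reusing the hypothetical schema from the comment above:
    #
    #     schema.validate_object({"name": "Ada", "age": 36})  # -> (True, [])
    #     schema.validate_object({"age": -1})
    #     # -> (False, [...]): "name" is missing and "age" violates `minimum: 0`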
def to_typescript_object_interface(self, interface_name: str = "") -> str:
if self.type != JSONSchema.Type.OBJECT:
raise NotImplementedError("Only `object` schemas are supported")
if self.properties:
attributes: list[str] = []
for name, property in self.properties.items():
if property.description:
attributes.append(f"// {property.description}")
attributes.append(f"{name}: {property.typescript_type};")
attributes_string = "\n".join(attributes)
else:
attributes_string = "[key: string]: any"
return (
f"interface {interface_name} " if interface_name else ""
) + f"{{\n{indent(attributes_string, ' ')}\n}}"
@property
def python_type(self) -> str:
if self.type == JSONSchema.Type.BOOLEAN:
return "bool"
elif self.type in {JSONSchema.Type.INTEGER}:
return "int"
elif self.type == JSONSchema.Type.NUMBER:
return "float"
elif self.type == JSONSchema.Type.STRING:
return "str"
elif self.type == JSONSchema.Type.ARRAY:
return f"list[{self.items.python_type}]" if self.items else "list"
elif self.type == JSONSchema.Type.OBJECT:
if not self.properties:
return "dict"
raise NotImplementedError(
"JSONSchema.python_type doesn't support TypedDicts yet"
)
elif self.enum:
return "Union[" + ", ".join(repr(v) for v in self.enum) + "]"
elif self.type == JSONSchema.Type.TYPE:
return "type"
elif self.type is None:
return "Any"
else:
raise NotImplementedError(
f"JSONSchema.python_type does not support Type.{self.type.name} yet"
)
@property
def typescript_type(self) -> str:
if not self.type:
return "any"
if self.type == JSONSchema.Type.BOOLEAN:
return "boolean"
if self.type in {JSONSchema.Type.INTEGER, JSONSchema.Type.NUMBER}:
return "number"
if self.type == JSONSchema.Type.STRING:
return "string"
if self.type == JSONSchema.Type.ARRAY:
return f"Array<{self.items.typescript_type}>" if self.items else "Array"
if self.type == JSONSchema.Type.OBJECT:
if not self.properties:
return "Record<string, any>"
return self.to_typescript_object_interface()
if self.enum:
return " | ".join(repr(v) for v in self.enum)
elif self.type == JSONSchema.Type.TYPE:
return "type"
elif self.type is None:
return "any"
raise NotImplementedError(
f"JSONSchema.typescript_type does not support Type.{self.type.name} yet"
)
@overload
def _resolve_type_refs_in_schema(schema: dict, definitions: dict) -> dict:
...
@overload
def _resolve_type_refs_in_schema(schema: list, definitions: dict) -> list:
...
def _resolve_type_refs_in_schema(schema: dict | list, definitions: dict) -> dict | list:
"""
Recursively resolve type $refs in the JSON schema with their definitions.
"""
if isinstance(schema, dict):
if "$ref" in schema:
ref_path = schema["$ref"].split("/")[2:] # Split and remove '#/definitions'
ref_value = definitions
for key in ref_path:
ref_value = ref_value[key]
return _resolve_type_refs_in_schema(ref_value, definitions)
else:
return {
k: _resolve_type_refs_in_schema(v, definitions)
for k, v in schema.items()
}
elif isinstance(schema, list):
return [_resolve_type_refs_in_schema(item, definitions) for item in schema]
else:
return schema
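# Example: a "$ref" pointing into "definitions" is replaced by the referenced
# sub-schema (the "Item" definition here is hypothetical).
#
#     _resolve_type_refs_in_schema(
#         {"type": "array", "items": {"$ref": "#/definitions/Item"}},
#         definitions={"Item": {"type": "string"}},
#     )
#     # -> {"type": "array", "items": {"type": "string"}}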

View File

@@ -0,0 +1,61 @@
import abc
import enum
import math
from typing import Callable, Generic, TypeVar
from pydantic import BaseModel, SecretBytes, SecretField, SecretStr
from forge.models.config import SystemConfiguration, UserConfigurable
_T = TypeVar("_T")
class ResourceType(str, enum.Enum):
"""An enumeration of resource types."""
MODEL = "model"
class ProviderBudget(SystemConfiguration, Generic[_T]):
total_budget: float = UserConfigurable(math.inf)
total_cost: float = 0
remaining_budget: float = math.inf
usage: _T
@abc.abstractmethod
def update_usage_and_cost(self, *args, **kwargs) -> float:
"""Update the usage and cost of the provider.
Returns:
float: The (calculated) cost of the given model response.
"""
...
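# Illustrative sketch of a concrete budget that counts calls; `CallUsage`,
# `CallBudget`, and the flat per-call cost are hypothetical.
#
#     class CallUsage(BaseModel):
#         calls: int = 0
#
#     class CallBudget(ProviderBudget):
#         usage: CallUsage = CallUsage()
#
#         def update_usage_and_cost(self, cost_per_call: float = 0.01) -> float:
#             self.usage.calls += 1
#             self.total_cost += cost_per_call
#             self.remaining_budget = self.total_budget - self.total_cost
#             return cost_per_call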
class ProviderCredentials(SystemConfiguration):
"""Struct for credentials."""
def unmasked(self) -> dict:
return unmask(self)
class Config(SystemConfiguration.Config):
json_encoders: dict[type[SecretField], Callable[[SecretField], str | None]] = {
SecretStr: lambda v: v.get_secret_value() if v else None,
SecretBytes: lambda v: v.get_secret_value() if v else None,
SecretField: lambda v: v.get_secret_value() if v else None,
}
def unmask(model: BaseModel):
unmasked_fields = {}
for field_name, _ in model.__fields__.items():
value = getattr(model, field_name)
if isinstance(value, SecretStr):
unmasked_fields[field_name] = value.get_secret_value()
else:
unmasked_fields[field_name] = value
return unmasked_fields
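# Illustrative sketch; `ApiCredentials` and the key value are hypothetical.
#
#     class ApiCredentials(ProviderCredentials):
#         api_key: SecretStr = SecretStr("dummy-key")
#
#     ApiCredentials().unmasked()  # -> {"api_key": "dummy-key"}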
# Used both by model providers and memory providers
Embedding = list[float]

View File

@@ -0,0 +1,10 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
class ModelWithSummary(BaseModel, ABC):
@abstractmethod
def summary(self) -> str:
"""Should produce a human readable summary of the model content."""
pass
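# Illustrative sketch; `TaskResult` is a hypothetical subclass.
#
#     class TaskResult(ModelWithSummary):
#         task: str
#         success: bool
#
#         def summary(self) -> str:
#             return f"{'Completed' if self.success else 'Failed'}: {self.task}"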

View File

@@ -0,0 +1,4 @@
"""This module contains the (speech recognition and) speech synthesis functions."""
from .say import TextToSpeechProvider, TTSConfig
__all__ = ["TextToSpeechProvider", "TTSConfig"]
