Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-10 06:45:28 -05:00)

Merge branch 'master' into zamilmajdy/code-validation
7 forge/.env.example Normal file
@@ -0,0 +1,7 @@
# Your OpenAI API Key. If GPT-4 is available it will use that; otherwise it will use 3.5-turbo
OPENAI_API_KEY=abc

# Control log level
LOG_LEVEL=INFO
DATABASE_STRING="sqlite:///agent.db"
PORT=8000
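These settings are read at startup with python-dotenv (see forge/forge/__main__.py below). A minimal sketch of how they might be consumed; the variable names come from the file above, while the fallback defaults shown here are illustrative assumptions:

import os

from dotenv import load_dotenv

load_dotenv()  # loads OPENAI_API_KEY, LOG_LEVEL, DATABASE_STRING, PORT from .env

openai_api_key = os.getenv("OPENAI_API_KEY")
log_level = os.getenv("LOG_LEVEL", "INFO")  # fallback values here are assumptions
database_string = os.getenv("DATABASE_STRING", "sqlite:///agent.db")
port = int(os.getenv("PORT", "8000"))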
11 forge/.flake8 Normal file
@@ -0,0 +1,11 @@
[flake8]
max-line-length = 88
# Ignore rules that conflict with Black code style
extend-ignore = E203, W503
exclude =
    .git,
    __pycache__/,
    *.pyc,
    .pytest_cache/,
    venv*/,
    .venv/,
175 forge/.gitignore vendored Normal file
@@ -0,0 +1,175 @@
## Original ignores
autogpt/keys.py
autogpt/*.json
*.mpeg
.env
azure.yaml
.vscode
.idea/*
auto-gpt.json
log.txt
log-ingestion.txt
logs
*.log
*.mp3
mem.sqlite3
venvAutoGPT

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
plugins/
plugins_config.yaml
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
site/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.direnv/
.env
.venv
env/
venv*/
ENV/
env.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
llama-*
vicuna-*

# mac
.DS_Store

openai/

# news
CURRENT_BULLETIN.md

agbenchmark_config/workspace
agbenchmark_config/reports
*.sqlite*
*.db
.agbench
.agbenchmark
.benchmarks
.mypy_cache
.pytest_cache
.vscode
ig_*
agbenchmark_config/updates.json
agbenchmark_config/challenges_already_beaten.json
agbenchmark_config/temp_folder/*
test_workspace/
40 forge/Dockerfile Normal file
@@ -0,0 +1,40 @@
# Use an official Python runtime as a parent image
FROM python:3.11-slim-buster as base

# Set work directory in the container
WORKDIR /app

# Install system dependencies
RUN apt-get update \
    && apt-get install -y build-essential curl ffmpeg \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*


# Install Poetry - respects $POETRY_VERSION & $POETRY_HOME
ENV POETRY_VERSION=1.1.8 \
    POETRY_HOME="/opt/poetry" \
    POETRY_NO_INTERACTION=1 \
    POETRY_VIRTUALENVS_CREATE=false \
    PATH="$POETRY_HOME/bin:$PATH"

RUN pip3 install poetry

COPY pyproject.toml poetry.lock* /app/

# Project initialization:
RUN poetry install --no-interaction --no-ansi

ENV PYTHONPATH="/app:$PYTHONPATH"

FROM base as dependencies

# Copy project
COPY . /app


# Make port 8000 available to the world outside this container
EXPOSE 8000

# Run the application when the container launches
CMD ["poetry", "run", "python", "autogpt/__main__.py"]
24 forge/README.md Normal file
@@ -0,0 +1,24 @@
# 🚀 **AutoGPT-Forge**: Build Your Own AutoGPT Agent! 🧠

### 🌌 Dive into the Universe of AutoGPT Creation! 🌌

Ever dreamt of becoming the genius behind an AI agent? Dive into the *Forge*, where **you** become the creator!

---

### 🛠️ **Why AutoGPT-Forge?**
- 💤 **No More Boilerplate!** Don't let mundane tasks stop you. Fork and build without the headache of starting from scratch!
- 🧠 **Brain-centric Development!** All the tools you need, so you can spend 100% of your time on what matters - crafting the brain of your AI!
- 🛠️ **Tooling ecosystem!** We work with best-in-class tools to bring you the best experience possible!
---

### 🚀 **Get Started!**

The getting started [tutorial series](https://aiedge.medium.com/autogpt-forge-e3de53cc58ec) will guide you through the process of setting up your project, all the way through to building a generalist agent.

1. [AutoGPT Forge: A Comprehensive Guide to Your First Steps](https://aiedge.medium.com/autogpt-forge-a-comprehensive-guide-to-your-first-steps-a1dfdf46e3b4)
2. [AutoGPT Forge: The Blueprint of an AI Agent](https://aiedge.medium.com/autogpt-forge-the-blueprint-of-an-ai-agent-75cd72ffde6)
3. [AutoGPT Forge: Interacting with your Agent](https://aiedge.medium.com/autogpt-forge-interacting-with-your-agent-1214561b06b)
4. [AutoGPT Forge: Crafting Intelligent Agent Logic](https://medium.com/@aiedge/autogpt-forge-crafting-intelligent-agent-logic-bc5197b14cb4)
4 forge/agbenchmark_config/config.json Normal file
@@ -0,0 +1,4 @@
{
  "workspace": {"input": "agbenchmark_config/workspace", "output": "agbenchmark_config/workspace"},
  "host": "http://localhost:8000"
}
0 forge/forge/__init__.py Normal file
54 forge/forge/__main__.py Normal file
@@ -0,0 +1,54 @@
import logging
import os

import uvicorn
from dotenv import load_dotenv

from forge.logging.config import configure_logging

logger = logging.getLogger(__name__)

logo = """\n\n
d8888 888 .d8888b. 8888888b. 88888888888
d88P888 888 888 888 888 888 888
d88P 888 888 888 888888 .d88b. 888 888 d88P 888
d88P 888 888 888 888 d88""88b 888 88888 8888888P" 888
d88P 888 888 888 888 888 888 888 888 888 888
d8888888888 Y88b 888 Y88b. Y88..88P Y88b d88P 888 888
d88P 888 "Y88888 "Y888 "Y88P" "Y8888P88 888 888


8888888888
888
888 .d88b. 888d888 .d88b. .d88b.
888888 d88""88b 888P" d88P"88b d8P Y8b
888 888 888 888 888 888 88888888
888 Y88..88P 888 Y88b 888 Y8b.
888 "Y88P" 888 "Y88888 "Y8888
888
Y8b d88P
"Y88P" v0.1.0
\n"""

if __name__ == "__main__":
    print(logo)
    port = os.getenv("PORT", 8000)
    configure_logging()
    logger.info(f"Agent server starting on http://localhost:{port}")
    load_dotenv()

    uvicorn.run(
        "forge.app:app",
        host="localhost",
        port=int(port),
        log_level="error",
        # Reload on changes to code or .env
        reload=True,
        reload_dirs=os.path.dirname(os.path.dirname(__file__)),
        reload_excludes="*.py",  # Cancel default *.py include pattern
        reload_includes=[
            f"{os.path.basename(os.path.dirname(__file__))}/**/*.py",
            ".*",
            ".env",
        ],
    )
7 forge/forge/agent/__init__.py Normal file
@@ -0,0 +1,7 @@
from .base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings

__all__ = [
    "BaseAgent",
    "BaseAgentConfiguration",
    "BaseAgentSettings",
]
193 forge/forge/agent/agent.py Normal file
@@ -0,0 +1,193 @@
import logging
import os
import pathlib
from io import BytesIO
from uuid import uuid4

import uvicorn
from fastapi import APIRouter, FastAPI, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles

from forge.agent_protocol.api_router import base_router
from forge.agent_protocol.database.db import AgentDB
from forge.agent_protocol.middlewares import AgentMiddleware
from forge.agent_protocol.models.task import (
    Artifact,
    Step,
    StepRequestBody,
    Task,
    TaskArtifactsListResponse,
    TaskListResponse,
    TaskRequestBody,
    TaskStepsListResponse,
)
from forge.file_storage.base import FileStorage

logger = logging.getLogger(__name__)


class Agent:
    def __init__(self, database: AgentDB, workspace: FileStorage):
        self.db = database
        self.workspace = workspace

    def get_agent_app(self, router: APIRouter = base_router):
        """
        Build and return the FastAPI app for the agent server.
        """

        app = FastAPI(
            title="AutoGPT Forge",
            description="Modified version of The Agent Protocol.",
            version="v0.4",
        )

        # Add CORS middleware
        origins = [
            "http://localhost:5000",
            "http://127.0.0.1:5000",
            "http://localhost:8000",
            "http://127.0.0.1:8000",
            "http://localhost:8080",
            "http://127.0.0.1:8080",
            # Add any other origins you want to whitelist
        ]

        app.add_middleware(
            CORSMiddleware,
            allow_origins=origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        app.include_router(router, prefix="/ap/v1")
        script_dir = os.path.dirname(os.path.realpath(__file__))
        frontend_path = pathlib.Path(
            os.path.join(script_dir, "../../../frontend/build/web")
        ).resolve()

        if os.path.exists(frontend_path):
            app.mount("/app", StaticFiles(directory=frontend_path), name="app")

            @app.get("/", include_in_schema=False)
            async def root():
                return RedirectResponse(url="/app/index.html", status_code=307)

        else:
            logger.warning(
                f"Frontend not found. {frontend_path} does not exist. "
                "The frontend will not be served."
            )
        app.add_middleware(AgentMiddleware, agent=self)

        return app

    def start(self, port):
        uvicorn.run(
            "forge.app:app", host="localhost", port=port, log_level="error", reload=True
        )

    async def create_task(self, task_request: TaskRequestBody) -> Task:
        """
        Create a task for the agent.
        """
        task = await self.db.create_task(
            input=task_request.input,
            additional_input=task_request.additional_input,
        )
        return task

    async def list_tasks(self, page: int = 1, pageSize: int = 10) -> TaskListResponse:
        """
        List all tasks that the agent has created.
        """
        tasks, pagination = await self.db.list_tasks(page, pageSize)
        response = TaskListResponse(tasks=tasks, pagination=pagination)
        return response

    async def get_task(self, task_id: str) -> Task:
        """
        Get a task by ID.
        """
        task = await self.db.get_task(task_id)
        return task

    async def list_steps(
        self, task_id: str, page: int = 1, pageSize: int = 10
    ) -> TaskStepsListResponse:
        """
        List the IDs of all steps that the task has created.
        """
        steps, pagination = await self.db.list_steps(task_id, page, pageSize)
        response = TaskStepsListResponse(steps=steps, pagination=pagination)
        return response

    async def execute_step(self, task_id: str, step_request: StepRequestBody) -> Step:
        """
        Create a step for the task.
        """
        raise NotImplementedError

    async def get_step(self, task_id: str, step_id: str) -> Step:
        """
        Get a step by ID.
        """
        step = await self.db.get_step(task_id, step_id)
        return step

    async def list_artifacts(
        self, task_id: str, page: int = 1, pageSize: int = 10
    ) -> TaskArtifactsListResponse:
        """
        List the artifacts that the task has created.
        """
        artifacts, pagination = await self.db.list_artifacts(task_id, page, pageSize)
        return TaskArtifactsListResponse(artifacts=artifacts, pagination=pagination)

    async def create_artifact(
        self, task_id: str, file: UploadFile, relative_path: str = ""
    ) -> Artifact:
        """
        Create an artifact for the task.
        """
        file_name = file.filename or str(uuid4())
        data = b""
        while contents := file.file.read(1024 * 1024):
            data += contents
        # Check if relative path ends with filename
        if relative_path.endswith(file_name):
            file_path = relative_path
        else:
            file_path = os.path.join(relative_path, file_name)

        await self.workspace.write_file(file_path, data)

        artifact = await self.db.create_artifact(
            task_id=task_id,
            file_name=file_name,
            relative_path=relative_path,
            agent_created=False,
        )
        return artifact

    async def get_artifact(self, task_id: str, artifact_id: str) -> StreamingResponse:
        """
        Get an artifact by ID.
        """
        artifact = await self.db.get_artifact(artifact_id)
        if artifact.file_name not in artifact.relative_path:
            file_path = os.path.join(artifact.relative_path, artifact.file_name)
        else:
            file_path = artifact.relative_path
        retrieved_artifact = self.workspace.read_file(file_path, binary=True)

        return StreamingResponse(
            BytesIO(retrieved_artifact),
            media_type="application/octet-stream",
            headers={
                "Content-Disposition": f"attachment; filename={artifact.file_name}"
            },
        )
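The uvicorn target "forge.app:app" used above implies a small forge/app.py module that wires these pieces together. That module is not part of this diff, so the following is only a sketch under that assumption; the database string mirrors .env.example and the storage setup mirrors the test fixture below:

# Hypothetical forge/app.py (not in this diff): wiring an Agent into a served app.
import os
from pathlib import Path

from forge.agent.agent import Agent
from forge.agent_protocol.database.db import AgentDB
from forge.file_storage.base import FileStorageConfiguration
from forge.file_storage.local import LocalFileStorage

database = AgentDB(os.getenv("DATABASE_STRING", "sqlite:///agent.db"))
workspace = LocalFileStorage(
    FileStorageConfiguration(root=Path("agbenchmark_config/workspace"))
)
agent = Agent(database, workspace)
app = agent.get_agent_app()  # what uvicorn serves as "forge.app:app"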
137 forge/forge/agent/agent_test.py Normal file
@@ -0,0 +1,137 @@
from pathlib import Path

import pytest
from fastapi import UploadFile

from forge.agent_protocol.database.db import AgentDB
from forge.agent_protocol.models.task import (
    StepRequestBody,
    Task,
    TaskListResponse,
    TaskRequestBody,
)
from forge.file_storage.base import FileStorageConfiguration
from forge.file_storage.local import LocalFileStorage

from .agent import Agent


@pytest.fixture
def agent(test_workspace: Path):
    db = AgentDB("sqlite:///test.db")
    config = FileStorageConfiguration(root=test_workspace)
    workspace = LocalFileStorage(config)
    return Agent(db, workspace)


@pytest.fixture
def file_upload():
    this_file = Path(__file__)
    file_handle = this_file.open("rb")
    yield UploadFile(file_handle, filename=this_file.name)
    file_handle.close()


@pytest.mark.asyncio
async def test_create_task(agent: Agent):
    task_request = TaskRequestBody(
        input="test_input", additional_input={"input": "additional_test_input"}
    )
    task: Task = await agent.create_task(task_request)
    assert task.input == "test_input"


@pytest.mark.asyncio
async def test_list_tasks(agent: Agent):
    task_request = TaskRequestBody(
        input="test_input", additional_input={"input": "additional_test_input"}
    )
    await agent.create_task(task_request)
    tasks = await agent.list_tasks()
    assert isinstance(tasks, TaskListResponse)


@pytest.mark.asyncio
async def test_get_task(agent: Agent):
    task_request = TaskRequestBody(
        input="test_input", additional_input={"input": "additional_test_input"}
    )
    task = await agent.create_task(task_request)
    retrieved_task = await agent.get_task(task.task_id)
    assert retrieved_task.task_id == task.task_id


@pytest.mark.xfail(reason="execute_step is not implemented")
@pytest.mark.asyncio
async def test_execute_step(agent: Agent):
    task_request = TaskRequestBody(
        input="test_input", additional_input={"input": "additional_test_input"}
    )
    task = await agent.create_task(task_request)
    step_request = StepRequestBody(
        input="step_input", additional_input={"input": "additional_test_input"}
    )
    step = await agent.execute_step(task.task_id, step_request)
    assert step.input == "step_input"
    assert step.additional_input == {"input": "additional_test_input"}


@pytest.mark.xfail(reason="execute_step is not implemented")
@pytest.mark.asyncio
async def test_get_step(agent: Agent):
    task_request = TaskRequestBody(
        input="test_input", additional_input={"input": "additional_test_input"}
    )
    task = await agent.create_task(task_request)
    step_request = StepRequestBody(
        input="step_input", additional_input={"input": "additional_test_input"}
    )
    step = await agent.execute_step(task.task_id, step_request)
    retrieved_step = await agent.get_step(task.task_id, step.step_id)
    assert retrieved_step.step_id == step.step_id


@pytest.mark.asyncio
async def test_list_artifacts(agent: Agent):
    tasks = await agent.list_tasks()
    assert tasks.tasks, "No tasks in test.db"

    artifacts = await agent.list_artifacts(tasks.tasks[0].task_id)
    assert isinstance(artifacts.artifacts, list)


@pytest.mark.asyncio
async def test_create_artifact(agent: Agent, file_upload: UploadFile):
    task_request = TaskRequestBody(
        input="test_input", additional_input={"input": "additional_test_input"}
    )
    task = await agent.create_task(task_request)
    artifact = await agent.create_artifact(
        task_id=task.task_id,
        file=file_upload,
        relative_path=f"a_dir/{file_upload.filename}",
    )
    assert artifact.file_name == file_upload.filename
    assert artifact.relative_path == f"a_dir/{file_upload.filename}"


@pytest.mark.asyncio
async def test_create_and_get_artifact(agent: Agent, file_upload: UploadFile):
    task_request = TaskRequestBody(
        input="test_input", additional_input={"input": "additional_test_input"}
    )
    task = await agent.create_task(task_request)

    artifact = await agent.create_artifact(
        task_id=task.task_id,
        file=file_upload,
        relative_path=f"b_dir/{file_upload.filename}",
    )
    await file_upload.seek(0)
    file_upload_content = await file_upload.read()

    retrieved_artifact = await agent.get_artifact(task.task_id, artifact.artifact_id)
    retrieved_artifact_content = bytearray()
    async for b in retrieved_artifact.body_iterator:
        retrieved_artifact_content.extend(b)  # type: ignore
    assert retrieved_artifact_content == file_upload_content
333 forge/forge/agent/base.py Normal file
@@ -0,0 +1,333 @@
from __future__ import annotations

import copy
import inspect
import logging
from abc import ABCMeta, abstractmethod
from typing import (
    Any,
    Awaitable,
    Callable,
    Generic,
    Iterator,
    Optional,
    ParamSpec,
    TypeVar,
    cast,
    overload,
)

from colorama import Fore
from pydantic import BaseModel, Field, validator

from forge.agent import protocols
from forge.agent.components import (
    AgentComponent,
    ComponentEndpointError,
    EndpointPipelineError,
)
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.llm.providers import CHAT_MODELS, ModelName, OpenAIModelName
from forge.llm.providers.schema import ChatModelInfo
from forge.models.action import ActionResult, AnyProposal
from forge.models.config import SystemConfiguration, SystemSettings, UserConfigurable

logger = logging.getLogger(__name__)

T = TypeVar("T")
P = ParamSpec("P")

DEFAULT_TRIGGERING_PROMPT = (
    "Determine exactly one command to use next based on the given goals "
    "and the progress you have made so far, "
    "and respond using the JSON schema specified previously:"
)


class BaseAgentConfiguration(SystemConfiguration):
    allow_fs_access: bool = UserConfigurable(default=False)

    fast_llm: ModelName = UserConfigurable(default=OpenAIModelName.GPT3_16k)
    smart_llm: ModelName = UserConfigurable(default=OpenAIModelName.GPT4)
    use_functions_api: bool = UserConfigurable(default=False)

    default_cycle_instruction: str = DEFAULT_TRIGGERING_PROMPT
    """The default instruction passed to the AI for a thinking cycle."""

    big_brain: bool = UserConfigurable(default=True)
    """
    Whether this agent uses the configured smart LLM (default) to think,
    as opposed to the configured fast LLM. Enabling this disables hybrid mode.
    """

    cycle_budget: Optional[int] = 1
    """
    The number of cycles that the agent is allowed to run unsupervised.

    `None` for unlimited continuous execution,
    `1` to require user approval for every step,
    `0` to stop the agent.
    """

    cycles_remaining = cycle_budget
    """The number of cycles remaining within the `cycle_budget`."""

    cycle_count = 0
    """The number of cycles that the agent has run since its initialization."""

    send_token_limit: Optional[int] = None
    """
    The token limit for prompt construction. Should leave room for the completion;
    defaults to 75% of `llm.max_tokens`.
    """

    summary_max_tlength: Optional[int] = None
    # TODO: move to ActionHistoryConfiguration

    @validator("use_functions_api")
    def validate_openai_functions(cls, v: bool, values: dict[str, Any]):
        if v:
            smart_llm = values["smart_llm"]
            fast_llm = values["fast_llm"]
            assert all(
                [
                    not any(s in name for s in {"-0301", "-0314"})
                    for name in {smart_llm, fast_llm}
                ]
            ), (
                f"Model {smart_llm} does not support OpenAI Functions. "
                "Please disable OPENAI_FUNCTIONS or choose a suitable model."
            )
        return v


class BaseAgentSettings(SystemSettings):
    agent_id: str = ""

    ai_profile: AIProfile = Field(default_factory=lambda: AIProfile(ai_name="AutoGPT"))
    """The AI profile or "personality" of the agent."""

    directives: AIDirectives = Field(default_factory=AIDirectives)
    """Directives (general instructional guidelines) for the agent."""

    task: str = "Terminate immediately"  # FIXME: placeholder for forge.sdk.schema.Task
    """The user-given task that the agent is working on."""

    config: BaseAgentConfiguration = Field(default_factory=BaseAgentConfiguration)
    """The configuration for this BaseAgent subsystem instance."""


class AgentMeta(ABCMeta):
    def __call__(cls, *args, **kwargs):
        # Create instance of the class (Agent or BaseAgent)
        instance = super().__call__(*args, **kwargs)
        # Automatically collect modules after the instance is created
        instance._collect_components()
        return instance


class BaseAgent(Generic[AnyProposal], metaclass=AgentMeta):
    def __init__(
        self,
        settings: BaseAgentSettings,
    ):
        self.state = settings
        self.components: list[AgentComponent] = []
        self.config = settings.config
        # Execution data for debugging
        self._trace: list[str] = []

        logger.debug(f"Created {__class__} '{self.state.ai_profile.ai_name}'")

    @property
    def trace(self) -> list[str]:
        return self._trace

    @property
    def llm(self) -> ChatModelInfo:
        """The LLM that the agent uses to think."""
        llm_name = (
            self.config.smart_llm if self.config.big_brain else self.config.fast_llm
        )
        return CHAT_MODELS[llm_name]

    @property
    def send_token_limit(self) -> int:
        return self.config.send_token_limit or self.llm.max_tokens * 3 // 4

    @abstractmethod
    async def propose_action(self) -> AnyProposal:
        ...

    @abstractmethod
    async def execute(
        self,
        proposal: AnyProposal,
        user_feedback: str = "",
    ) -> ActionResult:
        ...

    @abstractmethod
    async def do_not_execute(
        self,
        denied_proposal: AnyProposal,
        user_feedback: str,
    ) -> ActionResult:
        ...

    def reset_trace(self):
        self._trace = []

    @overload
    async def run_pipeline(
        self, protocol_method: Callable[P, Iterator[T]], *args, retry_limit: int = 3
    ) -> list[T]:
        ...

    @overload
    async def run_pipeline(
        self,
        protocol_method: Callable[P, None | Awaitable[None]],
        *args,
        retry_limit: int = 3,
    ) -> list[None]:
        ...

    async def run_pipeline(
        self,
        protocol_method: Callable[P, Iterator[T] | None | Awaitable[None]],
        *args,
        retry_limit: int = 3,
    ) -> list[T] | list[None]:
        method_name = protocol_method.__name__
        protocol_name = protocol_method.__qualname__.split(".")[0]
        protocol_class = getattr(protocols, protocol_name)
        if not issubclass(protocol_class, AgentComponent):
            raise TypeError(f"{repr(protocol_method)} is not a protocol method")

        # Clone parameters to revert on failure
        original_args = self._selective_copy(args)
        pipeline_attempts = 0
        method_result: list[T] = []
        self._trace.append(f"⬇️ {Fore.BLUE}{method_name}{Fore.RESET}")

        while pipeline_attempts < retry_limit:
            try:
                for component in self.components:
                    # Skip other protocols
                    if not isinstance(component, protocol_class):
                        continue

                    # Skip disabled components
                    if not component.enabled:
                        self._trace.append(
                            f"   {Fore.LIGHTBLACK_EX}"
                            f"{component.__class__.__name__}{Fore.RESET}"
                        )
                        continue

                    method = cast(
                        Callable[..., Iterator[T] | None | Awaitable[None]] | None,
                        getattr(component, method_name, None),
                    )
                    if not callable(method):
                        continue

                    component_attempts = 0
                    while component_attempts < retry_limit:
                        try:
                            component_args = self._selective_copy(args)
                            result = method(*component_args)
                            if inspect.isawaitable(result):
                                result = await result
                            if result is not None:
                                method_result.extend(result)
                            args = component_args
                            self._trace.append(f"✅ {component.__class__.__name__}")

                        except ComponentEndpointError:
                            self._trace.append(
                                f"❌ {Fore.YELLOW}{component.__class__.__name__}: "
                                f"ComponentEndpointError{Fore.RESET}"
                            )
                            # Retry the same component on ComponentEndpointError
                            component_attempts += 1
                            continue
                        # Successful component execution
                        break
                # Successful pipeline execution
                break
            except EndpointPipelineError as e:
                self._trace.append(
                    f"❌ {Fore.LIGHTRED_EX}{e.triggerer.__class__.__name__}: "
                    f"EndpointPipelineError{Fore.RESET}"
                )
                # Restart from the beginning on EndpointPipelineError
                # Revert to original parameters
                args = self._selective_copy(original_args)
                pipeline_attempts += 1
                continue  # Start the loop over
            except Exception as e:
                raise e
        return method_result

    def _collect_components(self):
        components = [
            getattr(self, attr)
            for attr in dir(self)
            if isinstance(getattr(self, attr), AgentComponent)
        ]

        if self.components:
            # Check if any component is missing (added to Agent but not to components)
            for component in components:
                if component not in self.components:
                    logger.warning(
                        f"Component {component.__class__.__name__} "
                        "is attached to an agent but not added to components list"
                    )
            # Skip collecting and sorting if ordering is explicit
            return
        self.components = self._topological_sort(components)

    def _topological_sort(
        self, components: list[AgentComponent]
    ) -> list[AgentComponent]:
        visited = set()
        stack = []

        def visit(node: AgentComponent):
            if node in visited:
                return
            visited.add(node)
            for neighbor_class in node._run_after:
                neighbor = next(
                    (m for m in components if isinstance(m, neighbor_class)), None
                )
                if neighbor and neighbor not in visited:
                    visit(neighbor)
            stack.append(node)

        for component in components:
            visit(component)

        return stack

    def _selective_copy(self, args: tuple[Any, ...]) -> tuple[Any, ...]:
        copied_args = []
        for item in args:
            if isinstance(item, list):
                # Shallow copy for lists
                copied_item = item[:]
            elif isinstance(item, dict):
                # Shallow copy for dicts
                copied_item = item.copy()
            elif isinstance(item, BaseModel):
                # Deep copy for Pydantic models (deep=True to also copy nested models)
                copied_item = item.copy(deep=True)
            else:
                # Deep copy for other objects
                copied_item = copy.deepcopy(item)
            copied_args.append(copied_item)
        return tuple(copied_args)
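run_pipeline dispatches a single protocol method across every enabled component: a component is retried in place on ComponentEndpointError, and the whole pass restarts with the original arguments on EndpointPipelineError. A hedged sketch of how a subclass might drive it, assuming an agent with a DirectiveProvider component attached (DirectiveProvider is defined in forge/forge/agent/protocols.py below):

from forge.agent.protocols import DirectiveProvider


async def gather_constraints(agent: BaseAgent) -> list[str]:
    # Runs get_constraints on every enabled DirectiveProvider component,
    # in the agent's (topologically sorted) component order.
    return await agent.run_pipeline(DirectiveProvider.get_constraints)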
51 forge/forge/agent/components.py Normal file
@@ -0,0 +1,51 @@
from __future__ import annotations

from abc import ABC
from typing import Callable, TypeVar

T = TypeVar("T", bound="AgentComponent")


class AgentComponent(ABC):
    """Base class for all agent components."""

    _run_after: list[type[AgentComponent]] = []
    _enabled: Callable[[], bool] | bool = True
    _disabled_reason: str = ""

    @property
    def enabled(self) -> bool:
        if callable(self._enabled):
            return self._enabled()
        return self._enabled

    @property
    def disabled_reason(self) -> str:
        """Return the reason this component is disabled."""
        return self._disabled_reason

    def run_after(self: T, *components: type[AgentComponent] | AgentComponent) -> T:
        """Set the components that this component should run after."""
        for component in components:
            t = component if isinstance(component, type) else type(component)
            if t not in self._run_after and t is not self.__class__:
                self._run_after.append(t)
        return self


class ComponentEndpointError(Exception):
    """Error of a single protocol method on a component."""

    def __init__(self, message: str, component: AgentComponent):
        self.message = message
        self.triggerer = component
        super().__init__(message)


class EndpointPipelineError(ComponentEndpointError):
    """Error of an entire pipeline of one endpoint."""


class ComponentSystemError(EndpointPipelineError):
    """Error of a group of pipelines;
    multiple different endpoints."""
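A hedged illustration of the hooks above; the component class and its ordering constraint are invented for the example:

class WeatherComponent(AgentComponent):  # hypothetical example component
    _disabled_reason = "No API key configured"

    def __init__(self, api_key: str | None):
        self.api_key = api_key
        # `_enabled` may be a callable, so the check is re-evaluated on each access
        self._enabled = lambda: self.api_key is not None


# Ordering: run only after a (hypothetical) ConfigComponent has executed.
# weather = WeatherComponent(api_key=None).run_after(ConfigComponent)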
51 forge/forge/agent/protocols.py Normal file
@@ -0,0 +1,51 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Awaitable, Generic, Iterator

from forge.models.action import ActionResult, AnyProposal

from .components import AgentComponent

if TYPE_CHECKING:
    from forge.command.command import Command
    from forge.llm.providers import ChatMessage


class DirectiveProvider(AgentComponent):
    def get_constraints(self) -> Iterator[str]:
        return iter([])

    def get_resources(self) -> Iterator[str]:
        return iter([])

    def get_best_practices(self) -> Iterator[str]:
        return iter([])


class CommandProvider(AgentComponent):
    @abstractmethod
    def get_commands(self) -> Iterator["Command"]:
        ...


class MessageProvider(AgentComponent):
    @abstractmethod
    def get_messages(self) -> Iterator["ChatMessage"]:
        ...


class AfterParse(AgentComponent, Generic[AnyProposal]):
    @abstractmethod
    def after_parse(self, result: AnyProposal) -> None | Awaitable[None]:
        ...


class ExecutionFailure(AgentComponent):
    @abstractmethod
    def execution_failure(self, error: Exception) -> None | Awaitable[None]:
        ...


class AfterExecute(AgentComponent):
    @abstractmethod
    def after_execute(self, result: "ActionResult") -> None | Awaitable[None]:
        ...
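A component opts into a pipeline simply by subclassing the matching protocol; a minimal sketch (the directive text is a placeholder):

from typing import Iterator


class SafetyDirectives(DirectiveProvider):  # hypothetical example component
    def get_constraints(self) -> Iterator[str]:
        yield "Do not modify files outside the workspace."  # placeholder directive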
0 forge/forge/agent_protocol/__init__.py Normal file
476 forge/forge/agent_protocol/api_router.py Normal file
@@ -0,0 +1,476 @@
"""
Routes for the Agent Service.

This module defines the API routes for the Agent service.

Developers and contributors should be especially careful when making modifications
to these routes to ensure consistency and correctness in the system's behavior.
"""
import logging
from typing import TYPE_CHECKING, Optional

from fastapi import APIRouter, HTTPException, Query, Request, Response, UploadFile
from fastapi.responses import StreamingResponse

from .models import (
    Artifact,
    Step,
    StepRequestBody,
    Task,
    TaskArtifactsListResponse,
    TaskListResponse,
    TaskRequestBody,
    TaskStepsListResponse,
)

if TYPE_CHECKING:
    from forge.agent.agent import Agent

base_router = APIRouter()
logger = logging.getLogger(__name__)


@base_router.get("/", tags=["root"])
async def root():
    """
    Root endpoint that returns a welcome message.
    """
    return Response(content="Welcome to the AutoGPT Forge")


@base_router.get("/heartbeat", tags=["server"])
async def check_server_status():
    """
    Check if the server is running.
    """
    return Response(content="Server is running.", status_code=200)


@base_router.post("/agent/tasks", tags=["agent"], response_model=Task)
async def create_agent_task(request: Request, task_request: TaskRequestBody) -> Task:
    """
    Creates a new task using the provided TaskRequestBody and returns a Task.

    Args:
        request (Request): FastAPI request object.
        task (TaskRequestBody): The task request containing input data.

    Returns:
        Task: A new task with task_id, input, and additional_input set.

    Example:
        Request (TaskRequestBody defined in schema.py):
            {
                "input": "Write the words you receive to the file 'output.txt'.",
                "additional_input": "python/code"
            }

        Response (Task defined in schema.py):
            {
                "task_id": "50da533e-3904-4401-8a07-c49adf88b5eb",
                "input": "Write the word 'Washington' to a .txt file",
                "additional_input": "python/code",
                "artifacts": [],
            }
    """
    agent: "Agent" = request["agent"]

    try:
        task = await agent.create_task(task_request)
        return task
    except Exception:
        logger.exception(f"Error whilst trying to create a task: {task_request}")
        raise


@base_router.get("/agent/tasks", tags=["agent"], response_model=TaskListResponse)
async def list_agent_tasks(
    request: Request,
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1),
) -> TaskListResponse:
    """
    Retrieves a paginated list of all tasks.

    Args:
        request (Request): FastAPI request object.
        page (int, optional): Page number for pagination. Default: 1
        page_size (int, optional): Number of tasks per page for pagination. Default: 10

    Returns:
        TaskListResponse: A list of tasks, and pagination details.

    Example:
        Request:
            GET /agent/tasks?page=1&pageSize=10

        Response (TaskListResponse defined in schema.py):
            {
                "items": [
                    {
                        "input": "Write the word 'Washington' to a .txt file",
                        "additional_input": null,
                        "task_id": "50da533e-3904-4401-8a07-c49adf88b5eb",
                        "artifacts": [],
                        "steps": []
                    },
                    ...
                ],
                "pagination": {
                    "total": 100,
                    "pages": 10,
                    "current": 1,
                    "pageSize": 10
                }
            }
    """
    agent: "Agent" = request["agent"]
    try:
        tasks = await agent.list_tasks(page, page_size)
        return tasks
    except Exception:
        logger.exception("Error whilst trying to list tasks")
        raise


@base_router.get("/agent/tasks/{task_id}", tags=["agent"], response_model=Task)
async def get_agent_task(request: Request, task_id: str) -> Task:
    """
    Gets the details of a task by ID.

    Args:
        request (Request): FastAPI request object.
        task_id (str): The ID of the task.

    Returns:
        Task: The task with the given ID.

    Example:
        Request:
            GET /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb

        Response (Task defined in schema.py):
            {
                "input": "Write the word 'Washington' to a .txt file",
                "additional_input": null,
                "task_id": "50da533e-3904-4401-8a07-c49adf88b5eb",
                "artifacts": [
                    {
                        "artifact_id": "7a49f31c-f9c6-4346-a22c-e32bc5af4d8e",
                        "file_name": "output.txt",
                        "agent_created": true,
                        "relative_path": "file://50da533e-3904-4401-8a07-c49adf88b5eb/output.txt"
                    }
                ],
                "steps": [
                    {
                        "task_id": "50da533e-3904-4401-8a07-c49adf88b5eb",
                        "step_id": "6bb1801a-fd80-45e8-899a-4dd723cc602e",
                        "input": "Write the word 'Washington' to a .txt file",
                        "additional_input": "challenge:write_to_file",
                        "name": "Write to file",
                        "status": "completed",
                        "output": "I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>",
                        "additional_output": "Do you want me to continue?",
                        "artifacts": [
                            {
                                "artifact_id": "7a49f31c-f9c6-4346-a22c-e32bc5af4d8e",
                                "file_name": "output.txt",
                                "agent_created": true,
                                "relative_path": "file://50da533e-3904-4401-8a07-c49adf88b5eb/output.txt"
                            }
                        ],
                        "is_last": true
                    }
                ]
            }
    """  # noqa: E501
    agent: "Agent" = request["agent"]
    try:
        task = await agent.get_task(task_id)
        return task
    except Exception:
        logger.exception(f"Error whilst trying to get task: {task_id}")
        raise


@base_router.get(
    "/agent/tasks/{task_id}/steps",
    tags=["agent"],
    response_model=TaskStepsListResponse,
)
async def list_agent_task_steps(
    request: Request,
    task_id: str,
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, alias="pageSize"),
) -> TaskStepsListResponse:
    """
    Retrieves a paginated list of steps associated with a specific task.

    Args:
        request (Request): FastAPI request object.
        task_id (str): The ID of the task.
        page (int, optional): The page number for pagination. Defaults to 1.
        page_size (int, optional): Number of steps per page for pagination. Default: 10.

    Returns:
        TaskStepsListResponse: A list of steps, and pagination details.

    Example:
        Request:
            GET /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb/steps?page=1&pageSize=10

        Response (TaskStepsListResponse defined in schema.py):
            {
                "items": [
                    {
                        "task_id": "50da533e-3904-4401-8a07-c49adf88b5eb",
                        "step_id": "step1_id",
                        ...
                    },
                    ...
                ],
                "pagination": {
                    "total": 100,
                    "pages": 10,
                    "current": 1,
                    "pageSize": 10
                }
            }
    """  # noqa: E501
    agent: "Agent" = request["agent"]
    try:
        steps = await agent.list_steps(task_id, page, page_size)
        return steps
    except Exception:
        logger.exception("Error whilst trying to list steps")
        raise


@base_router.post("/agent/tasks/{task_id}/steps", tags=["agent"], response_model=Step)
async def execute_agent_task_step(
    request: Request, task_id: str, step_request: Optional[StepRequestBody] = None
) -> Step:
    """
    Executes the next step for a specified task based on the current task status and
    returns the executed step with additional feedback fields.

    This route is significant because this is where the agent actually performs work.
    The function handles executing the next step for a task based on its current state,
    and it requires careful implementation to ensure all scenarios (like the presence
    or absence of steps or a step marked as `last_step`) are handled correctly.

    Depending on the current state of the task, the following scenarios are possible:
    1. No steps exist for the task.
    2. There is at least one step already for the task, and the task does not have a
       completed step marked as `last_step`.
    3. There is a completed step marked as `last_step` already on the task.

    In each of these scenarios, a step object will be returned with two additional
    fields: `output` and `additional_output`.
    - `output`: Provides the primary response or feedback to the user.
    - `additional_output`: Supplementary information or data. Its specific content is
      not strictly defined and can vary based on the step or agent's implementation.

    Args:
        request (Request): FastAPI request object.
        task_id (str): The ID of the task.
        step (StepRequestBody): The details for executing the step.

    Returns:
        Step: Details of the executed step with additional feedback.

    Example:
        Request:
            POST /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb/steps
            {
                "input": "Step input details...",
                ...
            }

        Response:
            {
                "task_id": "50da533e-3904-4401-8a07-c49adf88b5eb",
                "step_id": "step1_id",
                "output": "Primary feedback...",
                "additional_output": "Supplementary details...",
                ...
            }
    """
    agent: "Agent" = request["agent"]
    try:
        # An empty step request represents a yes to continue command
        if not step_request:
            step_request = StepRequestBody(input="y")

        step = await agent.execute_step(task_id, step_request)
        return step
    except Exception:
        logger.exception(f"Error whilst trying to execute a task step: {task_id}")
        raise


@base_router.get(
    "/agent/tasks/{task_id}/steps/{step_id}", tags=["agent"], response_model=Step
)
async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> Step:
    """
    Retrieves the details of a specific step for a given task.

    Args:
        request (Request): FastAPI request object.
        task_id (str): The ID of the task.
        step_id (str): The ID of the step.

    Returns:
        Step: Details of the specific step.

    Example:
        Request:
            GET /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb/steps/step1_id

        Response:
            {
                "task_id": "50da533e-3904-4401-8a07-c49adf88b5eb",
                "step_id": "step1_id",
                ...
            }
    """
    agent: "Agent" = request["agent"]
    try:
        step = await agent.get_step(task_id, step_id)
        return step
    except Exception:
        logger.exception(f"Error whilst trying to get step: {step_id}")
        raise


@base_router.get(
    "/agent/tasks/{task_id}/artifacts",
    tags=["agent"],
    response_model=TaskArtifactsListResponse,
)
async def list_agent_task_artifacts(
    request: Request,
    task_id: str,
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, alias="pageSize"),
) -> TaskArtifactsListResponse:
    """
    Retrieves a paginated list of artifacts associated with a specific task.

    Args:
        request (Request): FastAPI request object.
        task_id (str): The ID of the task.
        page (int, optional): The page number for pagination. Defaults to 1.
        page_size (int, optional): Number of items per page for pagination. Default: 10.

    Returns:
        TaskArtifactsListResponse: A list of artifacts, and pagination details.

    Example:
        Request:
            GET /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb/artifacts?page=1&pageSize=10

        Response (TaskArtifactsListResponse defined in schema.py):
            {
                "items": [
                    {"artifact_id": "artifact1_id", ...},
                    {"artifact_id": "artifact2_id", ...},
                    ...
                ],
                "pagination": {
                    "total": 100,
                    "pages": 10,
                    "current": 1,
                    "pageSize": 10
                }
            }
    """  # noqa: E501
    agent: "Agent" = request["agent"]
    try:
        artifacts = await agent.list_artifacts(task_id, page, page_size)
        return artifacts
    except Exception:
        logger.exception("Error whilst trying to list artifacts")
        raise


@base_router.post(
    "/agent/tasks/{task_id}/artifacts", tags=["agent"], response_model=Artifact
)
async def upload_agent_task_artifacts(
    request: Request, task_id: str, file: UploadFile, relative_path: str = ""
) -> Artifact:
    """
    This endpoint is used to upload an artifact (file) associated with a specific task.

    Args:
        request (Request): The FastAPI request object.
        task_id (str): The ID of the task for which the artifact is being uploaded.
        file (UploadFile): The file being uploaded as an artifact.
        relative_path (str): The relative path for the file. This is a query parameter.

    Returns:
        Artifact: Metadata object for the uploaded artifact, including its ID and path.

    Example:
        Request:
            POST /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb/artifacts?relative_path=my_folder/my_other_folder
            File: <uploaded_file>

        Response:
            {
                "artifact_id": "b225e278-8b4c-4f99-a696-8facf19f0e56",
                "created_at": "2023-01-01T00:00:00Z",
                "modified_at": "2023-01-01T00:00:00Z",
                "agent_created": false,
                "relative_path": "/my_folder/my_other_folder/",
                "file_name": "main.py"
            }
    """  # noqa: E501
    agent: "Agent" = request["agent"]

    if file is None:
        raise HTTPException(status_code=400, detail="File must be specified")
    try:
        artifact = await agent.create_artifact(task_id, file, relative_path)
        return artifact
    except Exception:
        logger.exception(f"Error whilst trying to upload artifact: {task_id}")
        raise


@base_router.get(
    "/agent/tasks/{task_id}/artifacts/{artifact_id}",
    tags=["agent"],
    response_model=str,
)
async def download_agent_task_artifact(
    request: Request, task_id: str, artifact_id: str
) -> StreamingResponse:
    """
    Downloads an artifact associated with a specific task.

    Args:
        request (Request): FastAPI request object.
        task_id (str): The ID of the task.
        artifact_id (str): The ID of the artifact.

    Returns:
        FileResponse: The downloaded artifact file.

    Example:
        Request:
            GET /agent/tasks/50da533e-3904-4401-8a07-c49adf88b5eb/artifacts/artifact1_id

        Response:
            <file_content_of_artifact>
    """
    agent: "Agent" = request["agent"]
    try:
        return await agent.get_artifact(task_id, artifact_id)
    except Exception:
        logger.exception(f"Error whilst trying to download artifact: {task_id}")
        raise
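Because agent.py mounts this router under the /ap/v1 prefix, a client drives the protocol over plain HTTP. A minimal sketch using httpx (the choice of client library is an assumption; the port matches .env.example's PORT=8000):

import httpx

BASE = "http://localhost:8000/ap/v1"

with httpx.Client() as client:
    # Create a task, then trigger its next step.
    task = client.post(
        f"{BASE}/agent/tasks",
        json={"input": "Write the word 'Washington' to a .txt file"},
    ).json()
    # An empty/"y" step request means "continue" (see execute_agent_task_step)
    step = client.post(
        f"{BASE}/agent/tasks/{task['task_id']}/steps", json={"input": "y"}
    ).json()
    print(step.get("output"))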
3 forge/forge/agent_protocol/database/__init__.py Normal file
@@ -0,0 +1,3 @@
from .db import AgentDB

__all__ = ["AgentDB"]
502
forge/forge/agent_protocol/database/db.py
Normal file
502
forge/forge/agent_protocol/database/db.py
Normal file
@@ -0,0 +1,502 @@
|
||||
"""
|
||||
This is an example implementation of the Agent Protocol DB for development Purposes
|
||||
It uses SQLite as the database and file store backend.
|
||||
IT IS NOT ADVISED TO USE THIS IN PRODUCTION!
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, create_engine
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import (
|
||||
DeclarativeBase,
|
||||
Mapped,
|
||||
joinedload,
|
||||
mapped_column,
|
||||
relationship,
|
||||
sessionmaker,
|
||||
)
|
||||
|
||||
from forge.utils.exceptions import NotFoundError
|
||||
|
||||
from ..models.artifact import Artifact
|
||||
from ..models.pagination import Pagination
|
||||
from ..models.task import Step, StepRequestBody, StepStatus, Task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
type_annotation_map = {
|
||||
dict[str, Any]: JSON,
|
||||
}
|
||||
|
||||
|
||||
class TaskModel(Base):
|
||||
__tablename__ = "tasks"
|
||||
|
||||
    task_id: Mapped[str] = mapped_column(primary_key=True, index=True)
    input: Mapped[str]
    additional_input: Mapped[dict[str, Any]] = mapped_column(default=dict)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
    modified_at: Mapped[datetime] = mapped_column(
        DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
    )

    artifacts = relationship("ArtifactModel", back_populates="task")


class StepModel(Base):
    __tablename__ = "steps"

    step_id: Mapped[str] = mapped_column(primary_key=True, index=True)
    task_id: Mapped[str] = mapped_column(ForeignKey("tasks.task_id"))
    name: Mapped[str]
    input: Mapped[str]
    status: Mapped[str]
    output: Mapped[Optional[str]]
    is_last: Mapped[bool] = mapped_column(Boolean, default=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
    modified_at: Mapped[datetime] = mapped_column(
        DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
    )

    additional_input: Mapped[dict[str, Any]] = mapped_column(default=dict)
    additional_output: Mapped[Optional[dict[str, Any]]]
    artifacts = relationship("ArtifactModel", back_populates="step")


class ArtifactModel(Base):
    __tablename__ = "artifacts"

    artifact_id: Mapped[str] = mapped_column(primary_key=True, index=True)
    task_id: Mapped[str] = mapped_column(ForeignKey("tasks.task_id"))
    step_id: Mapped[Optional[str]] = mapped_column(ForeignKey("steps.step_id"))
    agent_created: Mapped[bool] = mapped_column(default=False)
    file_name: Mapped[str]
    relative_path: Mapped[str]
    created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
    modified_at: Mapped[datetime] = mapped_column(
        default=datetime.utcnow, onupdate=datetime.utcnow
    )

    step = relationship("StepModel", back_populates="artifacts")
    task = relationship("TaskModel", back_populates="artifacts")


def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
    if debug_enabled:
        logger.debug(f"Converting TaskModel to Task for task_id: {task_obj.task_id}")
    task_artifacts = [convert_to_artifact(artifact) for artifact in task_obj.artifacts]
    return Task(
        task_id=task_obj.task_id,
        created_at=task_obj.created_at,
        modified_at=task_obj.modified_at,
        input=task_obj.input,
        additional_input=task_obj.additional_input,
        artifacts=task_artifacts,
    )


def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
    if debug_enabled:
        logger.debug(f"Converting StepModel to Step for step_id: {step_model.step_id}")
    step_artifacts = [
        convert_to_artifact(artifact) for artifact in step_model.artifacts
    ]
    status = (
        StepStatus.completed if step_model.status == "completed" else StepStatus.created
    )
    return Step(
        task_id=step_model.task_id,
        step_id=step_model.step_id,
        created_at=step_model.created_at,
        modified_at=step_model.modified_at,
        name=step_model.name,
        input=step_model.input,
        status=status,
        output=step_model.output,
        artifacts=step_artifacts,
        is_last=step_model.is_last == 1,
        additional_input=step_model.additional_input,
        additional_output=step_model.additional_output,
    )


def convert_to_artifact(artifact_model: ArtifactModel) -> Artifact:
    return Artifact(
        artifact_id=artifact_model.artifact_id,
        created_at=artifact_model.created_at,
        modified_at=artifact_model.modified_at,
        agent_created=artifact_model.agent_created,
        relative_path=artifact_model.relative_path,
        file_name=artifact_model.file_name,
    )
# Expects an SQLAlchemy connection string, e.g. sqlite:///{database_name}
class AgentDB:
    def __init__(self, database_string, debug_enabled: bool = False) -> None:
        super().__init__()
        self.debug_enabled = debug_enabled
        if self.debug_enabled:
            logger.debug(
                f"Initializing AgentDB with database_string: {database_string}"
            )
        self.engine = create_engine(database_string)
        Base.metadata.create_all(self.engine)
        self.Session = sessionmaker(bind=self.engine)

    def close(self) -> None:
        self.Session.close_all()
        self.engine.dispose()
    async def create_task(
        self, input: Optional[str], additional_input: Optional[dict] = None
    ) -> Task:
        if self.debug_enabled:
            logger.debug("Creating new task")

        try:
            with self.Session() as session:
                new_task = TaskModel(
                    task_id=str(uuid.uuid4()),
                    input=input,
                    additional_input=additional_input if additional_input else {},
                )
                session.add(new_task)
                session.commit()
                session.refresh(new_task)
                if self.debug_enabled:
                    logger.debug(f"Created new task with task_id: {new_task.task_id}")
                return convert_to_task(new_task, self.debug_enabled)
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while creating task: {e}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while creating task: {e}")
            raise
    async def create_step(
        self,
        task_id: str,
        input: StepRequestBody,
        is_last: bool = False,
        additional_input: Optional[Dict[str, Any]] = None,
    ) -> Step:
        if self.debug_enabled:
            logger.debug(f"Creating new step for task_id: {task_id}")
        try:
            with self.Session() as session:
                new_step = StepModel(
                    task_id=task_id,
                    step_id=str(uuid.uuid4()),
                    name=input.input,
                    input=input.input,
                    status="created",
                    is_last=is_last,
                    additional_input=additional_input or {},
                )
                session.add(new_step)
                session.commit()
                session.refresh(new_step)
                if self.debug_enabled:
                    logger.debug(f"Created new step with step_id: {new_step.step_id}")
                return convert_to_step(new_step, self.debug_enabled)
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while creating step: {e}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while creating step: {e}")
            raise
    async def create_artifact(
        self,
        task_id: str,
        file_name: str,
        relative_path: str,
        agent_created: bool = False,
        step_id: str | None = None,
    ) -> Artifact:
        if self.debug_enabled:
            logger.debug(f"Creating new artifact for task_id: {task_id}")
        try:
            with self.Session() as session:
                if (
                    existing_artifact := session.query(ArtifactModel)
                    .filter_by(
                        task_id=task_id,
                        file_name=file_name,
                        relative_path=relative_path,
                    )
                    .first()
                ):
                    session.close()
                    if self.debug_enabled:
                        logger.debug(
                            f"Artifact {file_name} already exists at {relative_path}"
                        )
                    return convert_to_artifact(existing_artifact)

                new_artifact = ArtifactModel(
                    artifact_id=str(uuid.uuid4()),
                    task_id=task_id,
                    step_id=step_id,
                    agent_created=agent_created,
                    file_name=file_name,
                    relative_path=relative_path,
                )
                session.add(new_artifact)
                session.commit()
                session.refresh(new_artifact)
                if self.debug_enabled:
                    logger.debug(
                        f"Created new artifact with ID: {new_artifact.artifact_id}"
                    )
                return convert_to_artifact(new_artifact)
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while creating artifact: {e}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while creating artifact: {e}")
            raise
    async def get_task(self, task_id: str) -> Task:
        """Get a task by its id"""
        if self.debug_enabled:
            logger.debug(f"Getting task with task_id: {task_id}")
        try:
            with self.Session() as session:
                if task_obj := (
                    session.query(TaskModel)
                    .options(joinedload(TaskModel.artifacts))
                    .filter_by(task_id=task_id)
                    .first()
                ):
                    return convert_to_task(task_obj, self.debug_enabled)
                else:
                    logger.error(f"Task not found with task_id: {task_id}")
                    raise NotFoundError("Task not found")
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while getting task: {e}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while getting task: {e}")
            raise
    async def get_step(self, task_id: str, step_id: str) -> Step:
        if self.debug_enabled:
            logger.debug(f"Getting step with task_id: {task_id} and step_id: {step_id}")
        try:
            with self.Session() as session:
                if step := (
                    session.query(StepModel)
                    .options(joinedload(StepModel.artifacts))
                    .filter(StepModel.step_id == step_id)
                    .first()
                ):
                    return convert_to_step(step, self.debug_enabled)
                else:
                    logger.error(
                        f"Step not found with task_id: {task_id} and step_id: {step_id}"
                    )
                    raise NotFoundError("Step not found")
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while getting step: {e}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while getting step: {e}")
            raise
    async def get_artifact(self, artifact_id: str) -> Artifact:
        if self.debug_enabled:
            logger.debug(f"Getting artifact with artifact_id: {artifact_id}")
        try:
            with self.Session() as session:
                if (
                    artifact_model := session.query(ArtifactModel)
                    .filter_by(artifact_id=artifact_id)
                    .first()
                ):
                    return convert_to_artifact(artifact_model)
                else:
                    logger.error(
                        f"Artifact not found with artifact_id: {artifact_id}"
                    )
                    raise NotFoundError("Artifact not found")
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while getting artifact: {e}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while getting artifact: {e}")
            raise
    async def update_step(
        self,
        task_id: str,
        step_id: str,
        status: Optional[str] = None,
        output: Optional[str] = None,
        additional_input: Optional[Dict[str, Any]] = None,
        additional_output: Optional[Dict[str, Any]] = None,
    ) -> Step:
        if self.debug_enabled:
            logger.debug(
                f"Updating step with task_id: {task_id} and step_id: {step_id}"
            )
        try:
            with self.Session() as session:
                if (
                    step := session.query(StepModel)
                    .filter_by(task_id=task_id, step_id=step_id)
                    .first()
                ):
                    if status is not None:
                        step.status = status
                    if additional_input is not None:
                        step.additional_input = additional_input
                    if output is not None:
                        step.output = output
                    if additional_output is not None:
                        step.additional_output = additional_output
                    session.commit()
                    return await self.get_step(task_id, step_id)
                else:
                    logger.error(
                        "Can't update non-existent Step with "
                        f"task_id: {task_id} and step_id: {step_id}"
                    )
                    raise NotFoundError("Step not found")
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while updating step: {e}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while updating step: {e}")
            raise
    async def update_artifact(
        self,
        artifact_id: str,
        *,
        file_name: str = "",
        relative_path: str = "",
        agent_created: Optional[Literal[True]] = None,
    ) -> Artifact:
        logger.debug(f"Updating artifact with artifact_id: {artifact_id}")
        with self.Session() as session:
            if (
                artifact := session.query(ArtifactModel)
                .filter_by(artifact_id=artifact_id)
                .first()
            ):
                if file_name:
                    artifact.file_name = file_name
                if relative_path:
                    artifact.relative_path = relative_path
                if agent_created:
                    artifact.agent_created = agent_created
                session.commit()
                return await self.get_artifact(artifact_id)
            else:
                logger.error(f"Artifact not found with artifact_id: {artifact_id}")
                raise NotFoundError("Artifact not found")
    async def list_tasks(
        self, page: int = 1, per_page: int = 10
    ) -> Tuple[List[Task], Pagination]:
        if self.debug_enabled:
            logger.debug("Listing tasks")
        try:
            with self.Session() as session:
                tasks = (
                    session.query(TaskModel)
                    .offset((page - 1) * per_page)
                    .limit(per_page)
                    .all()
                )
                total = session.query(TaskModel).count()
                pages = math.ceil(total / per_page)
                pagination = Pagination(
                    total_items=total,
                    total_pages=pages,
                    current_page=page,
                    page_size=per_page,
                )
                return [
                    convert_to_task(task, self.debug_enabled) for task in tasks
                ], pagination
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while listing tasks: {e}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while listing tasks: {e}")
            raise

    async def list_steps(
        self, task_id: str, page: int = 1, per_page: int = 10
    ) -> Tuple[List[Step], Pagination]:
        if self.debug_enabled:
            logger.debug(f"Listing steps for task_id: {task_id}")
        try:
            with self.Session() as session:
                steps = (
                    session.query(StepModel)
                    .filter_by(task_id=task_id)
                    .offset((page - 1) * per_page)
                    .limit(per_page)
                    .all()
                )
                total = session.query(StepModel).filter_by(task_id=task_id).count()
                pages = math.ceil(total / per_page)
                pagination = Pagination(
                    total_items=total,
                    total_pages=pages,
                    current_page=page,
                    page_size=per_page,
                )
                return [
                    convert_to_step(step, self.debug_enabled) for step in steps
                ], pagination
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while listing steps: {e}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while listing steps: {e}")
            raise

    async def list_artifacts(
        self, task_id: str, page: int = 1, per_page: int = 10
    ) -> Tuple[List[Artifact], Pagination]:
        if self.debug_enabled:
            logger.debug(f"Listing artifacts for task_id: {task_id}")
        try:
            with self.Session() as session:
                artifacts = (
                    session.query(ArtifactModel)
                    .filter_by(task_id=task_id)
                    .offset((page - 1) * per_page)
                    .limit(per_page)
                    .all()
                )
                total = session.query(ArtifactModel).filter_by(task_id=task_id).count()
                pages = math.ceil(total / per_page)
                pagination = Pagination(
                    total_items=total,
                    total_pages=pages,
                    current_page=page,
                    page_size=per_page,
                )
                return [
                    convert_to_artifact(artifact) for artifact in artifacts
                ], pagination
        except SQLAlchemyError as e:
            logger.error(f"SQLAlchemy error while listing artifacts: {e}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while listing artifacts: {e}")
            raise
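A minimal usage sketch for `AgentDB`, assuming an async context; `create_task`, `create_step`, and `update_step` are coroutines as defined above, and the connection string follows the `sqlite:///{database_name}` format noted on the class:

```py
import asyncio

from forge.agent_protocol.database.db import AgentDB
from forge.agent_protocol.models import StepRequestBody


async def main():
    db = AgentDB("sqlite:///agent.db")  # tables are created on construction
    task = await db.create_task("Write 'Washington' to output.txt")
    step = await db.create_step(
        task.task_id, StepRequestBody(input="Write to file"), is_last=True
    )
    await db.update_step(task.task_id, step.step_id, status="completed")
    db.close()


asyncio.run(main())
```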
313
forge/forge/agent_protocol/database/db_test.py
Normal file
@@ -0,0 +1,313 @@
import os
import sqlite3
from datetime import datetime

import pytest

from forge.agent_protocol.database.db import (
    AgentDB,
    ArtifactModel,
    StepModel,
    TaskModel,
    convert_to_artifact,
    convert_to_step,
    convert_to_task,
)
from forge.agent_protocol.models import (
    Artifact,
    Step,
    StepRequestBody,
    StepStatus,
    Task,
)
from forge.utils.exceptions import NotFoundError as DataNotFoundError

TEST_DB_FILENAME = "test_db.sqlite3"
TEST_DB_URL = f"sqlite:///{TEST_DB_FILENAME}"


@pytest.fixture
def agent_db():
    db = AgentDB(TEST_DB_URL)
    yield db
    db.close()
    os.remove(TEST_DB_FILENAME)


@pytest.fixture
def raw_db_connection(agent_db: AgentDB):
    connection = sqlite3.connect(TEST_DB_FILENAME)
    yield connection
    connection.close()


def test_table_creation(raw_db_connection: sqlite3.Connection):
    cursor = raw_db_connection.cursor()

    # Test for tasks table existence
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='tasks'")
    assert cursor.fetchone() is not None

    # Test for steps table existence
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='steps'")
    assert cursor.fetchone() is not None

    # Test for artifacts table existence
    cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='artifacts'"
    )
    assert cursor.fetchone() is not None


@pytest.mark.asyncio
async def test_task_schema():
    now = datetime.now()
    task = Task(
        task_id="50da533e-3904-4401-8a07-c49adf88b5eb",
        input="Write the words you receive to the file 'output.txt'.",
        created_at=now,
        modified_at=now,
        artifacts=[
            Artifact(
                artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
                agent_created=True,
                file_name="main.py",
                relative_path="python/code/",
                created_at=now,
                modified_at=now,
            )
        ],
    )
    assert task.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb"
    assert task.input == "Write the words you receive to the file 'output.txt'."
    assert len(task.artifacts) == 1
    assert task.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"


@pytest.mark.asyncio
async def test_step_schema():
    now = datetime.now()
    step = Step(
        task_id="50da533e-3904-4401-8a07-c49adf88b5eb",
        step_id="6bb1801a-fd80-45e8-899a-4dd723cc602e",
        created_at=now,
        modified_at=now,
        name="Write to file",
        input="Write the words you receive to the file 'output.txt'.",
        status=StepStatus.created,
        output=(
            "I am going to use the write_to_file command and write Washington "
            "to a file called output.txt <write_to_file('output.txt', 'Washington')>"
        ),
        artifacts=[
            Artifact(
                artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
                file_name="main.py",
                relative_path="python/code/",
                created_at=now,
                modified_at=now,
                agent_created=True,
            )
        ],
        is_last=False,
    )
    assert step.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb"
    assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e"
    assert step.name == "Write to file"
    assert step.status == StepStatus.created
    assert step.output == (
        "I am going to use the write_to_file command and write Washington "
        "to a file called output.txt <write_to_file('output.txt', 'Washington')>"
    )
    assert len(step.artifacts) == 1
    assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
    assert step.is_last is False


@pytest.mark.asyncio
async def test_convert_to_task():
    now = datetime.now()
    task_model = TaskModel(
        task_id="50da533e-3904-4401-8a07-c49adf88b5eb",
        created_at=now,
        modified_at=now,
        input="Write the words you receive to the file 'output.txt'.",
        additional_input={},
        artifacts=[
            ArtifactModel(
                artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
                created_at=now,
                modified_at=now,
                relative_path="file:///path/to/main.py",
                agent_created=True,
                file_name="main.py",
            )
        ],
    )
    task = convert_to_task(task_model)
    assert task.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb"
    assert task.input == "Write the words you receive to the file 'output.txt'."
    assert len(task.artifacts) == 1
    assert task.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"


@pytest.mark.asyncio
async def test_convert_to_step():
    now = datetime.now()
    step_model = StepModel(
        task_id="50da533e-3904-4401-8a07-c49adf88b5eb",
        step_id="6bb1801a-fd80-45e8-899a-4dd723cc602e",
        created_at=now,
        modified_at=now,
        name="Write to file",
        status="created",
        input="Write the words you receive to the file 'output.txt'.",
        additional_input={},
        artifacts=[
            ArtifactModel(
                artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
                created_at=now,
                modified_at=now,
                relative_path="file:///path/to/main.py",
                agent_created=True,
                file_name="main.py",
            )
        ],
        is_last=False,
    )
    step = convert_to_step(step_model)
    assert step.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb"
    assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e"
    assert step.name == "Write to file"
    assert step.status == StepStatus.created
    assert len(step.artifacts) == 1
    assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
    assert step.is_last is False


@pytest.mark.asyncio
async def test_convert_to_artifact():
    now = datetime.now()
    artifact_model = ArtifactModel(
        artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
        created_at=now,
        modified_at=now,
        relative_path="file:///path/to/main.py",
        agent_created=True,
        file_name="main.py",
    )
    artifact = convert_to_artifact(artifact_model)
    assert artifact.artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
    assert artifact.relative_path == "file:///path/to/main.py"
    assert artifact.agent_created is True


@pytest.mark.asyncio
async def test_create_task(agent_db: AgentDB):
    task = await agent_db.create_task("task_input")
    assert task.input == "task_input"


@pytest.mark.asyncio
async def test_create_and_get_task(agent_db: AgentDB):
    task = await agent_db.create_task("test_input")
    fetched_task = await agent_db.get_task(task.task_id)
    assert fetched_task.input == "test_input"


@pytest.mark.asyncio
async def test_get_task_not_found(agent_db: AgentDB):
    with pytest.raises(DataNotFoundError):
        await agent_db.get_task("9999")


@pytest.mark.asyncio
async def test_create_and_get_step(agent_db: AgentDB):
    task = await agent_db.create_task("task_input")
    step_input = {"type": "python/code"}
    request = StepRequestBody(input="test_input debug", additional_input=step_input)
    step = await agent_db.create_step(task.task_id, request)
    step = await agent_db.get_step(task.task_id, step.step_id)
    assert step.input == "test_input debug"


@pytest.mark.asyncio
async def test_updating_step(agent_db: AgentDB):
    created_task = await agent_db.create_task("task_input")
    step_input = {"type": "python/code"}
    request = StepRequestBody(input="test_input debug", additional_input=step_input)
    created_step = await agent_db.create_step(created_task.task_id, request)
    await agent_db.update_step(created_task.task_id, created_step.step_id, "completed")

    step = await agent_db.get_step(created_task.task_id, created_step.step_id)
    assert step.status.value == "completed"


@pytest.mark.asyncio
async def test_get_step_not_found(agent_db: AgentDB):
    with pytest.raises(DataNotFoundError):
        await agent_db.get_step("9999", "9999")


@pytest.mark.asyncio
async def test_get_artifact(agent_db: AgentDB):
    # Given: A task and its corresponding artifact
    task = await agent_db.create_task("test_input debug")
    step_input = {"type": "python/code"}
    request = StepRequestBody(input="test_input debug", additional_input=step_input)

    step = await agent_db.create_step(task.task_id, request)

    # Create an artifact
    artifact = await agent_db.create_artifact(
        task_id=task.task_id,
        file_name="test_get_artifact_sample_file.txt",
        relative_path="file:///path/to/test_get_artifact_sample_file.txt",
        agent_created=True,
        step_id=step.step_id,
    )

    # When: The artifact is fetched by its ID
    fetched_artifact = await agent_db.get_artifact(artifact.artifact_id)

    # Then: The fetched artifact matches the original
    assert fetched_artifact.artifact_id == artifact.artifact_id
    assert (
        fetched_artifact.relative_path
        == "file:///path/to/test_get_artifact_sample_file.txt"
    )


@pytest.mark.asyncio
async def test_list_tasks(agent_db: AgentDB):
    # Given: Multiple tasks in the database
    task1 = await agent_db.create_task("test_input_1")
    task2 = await agent_db.create_task("test_input_2")

    # When: All tasks are fetched
    fetched_tasks, pagination = await agent_db.list_tasks()

    # Then: The fetched tasks list includes the created tasks
    task_ids = [task.task_id for task in fetched_tasks]
    assert task1.task_id in task_ids
    assert task2.task_id in task_ids


@pytest.mark.asyncio
async def test_list_steps(agent_db: AgentDB):
    step_input = {"type": "python/code"}
    request = StepRequestBody(input="test_input debug", additional_input=step_input)

    # Given: A task and multiple steps for that task
    task = await agent_db.create_task("test_input")
    step1 = await agent_db.create_step(task.task_id, request)
    request = StepRequestBody(input="step two")
    step2 = await agent_db.create_step(task.task_id, request)

    # When: All steps for the task are fetched
    fetched_steps, pagination = await agent_db.list_steps(task.task_id)

    # Then: The fetched steps list includes the created steps
    step_ids = [step.step_id for step in fetched_steps]
    assert step1.step_id in step_ids
    assert step2.step_id in step_ids
34
forge/forge/agent_protocol/middlewares.py
Normal file
@@ -0,0 +1,34 @@
from starlette.types import ASGIApp


class AgentMiddleware:
    """
    Middleware that injects the agent instance into the request scope.
    """

    def __init__(self, app: ASGIApp, agent):
        """

        Args:
            app: The FastAPI app - automatically injected by FastAPI.
            agent: The agent instance to inject into the request scope.

        Examples:
            >>> from fastapi import FastAPI, Request
            >>> from agent_protocol.agent import Agent
            >>> from agent_protocol.middlewares import AgentMiddleware
            >>> app = FastAPI()
            >>> @app.get("/")
            >>> async def root(request: Request):
            >>>     agent = request["agent"]
            >>>     task = await agent.db.create_task("Do something.")
            >>>     return {"task_id": task.task_id}
            >>> agent = Agent()
            >>> app.add_middleware(AgentMiddleware, agent=agent)
        """
        self.app = app
        self.agent = agent

    async def __call__(self, scope, receive, send):
        scope["agent"] = self.agent
        await self.app(scope, receive, send)
25
forge/forge/agent_protocol/models/__init__.py
Normal file
@@ -0,0 +1,25 @@
from .artifact import Artifact
from .pagination import Pagination
from .task import (
    Step,
    StepRequestBody,
    StepStatus,
    Task,
    TaskArtifactsListResponse,
    TaskListResponse,
    TaskRequestBody,
    TaskStepsListResponse,
)

__all__ = [
    "Artifact",
    "Pagination",
    "Step",
    "StepRequestBody",
    "StepStatus",
    "Task",
    "TaskArtifactsListResponse",
    "TaskListResponse",
    "TaskRequestBody",
    "TaskStepsListResponse",
]
38
forge/forge/agent_protocol/models/artifact.py
Normal file
@@ -0,0 +1,38 @@
from datetime import datetime

from pydantic import BaseModel, Field


class Artifact(BaseModel):
    created_at: datetime = Field(
        ...,
        description="The creation datetime of the artifact.",
        example="2023-01-01T00:00:00Z",
        json_encoders={datetime: lambda v: v.isoformat()},
    )
    modified_at: datetime = Field(
        ...,
        description="The modification datetime of the artifact.",
        example="2023-01-01T00:00:00Z",
        json_encoders={datetime: lambda v: v.isoformat()},
    )
    artifact_id: str = Field(
        ...,
        description="ID of the artifact.",
        example="b225e278-8b4c-4f99-a696-8facf19f0e56",
    )
    agent_created: bool = Field(
        ...,
        description="Whether the artifact has been created by the agent.",
        example=False,
    )
    relative_path: str = Field(
        ...,
        description="Relative path of the artifact in the agent's workspace.",
        example="/my_folder/my_other_folder/",
    )
    file_name: str = Field(
        ...,
        description="Filename of the artifact.",
        example="main.py",
    )
8
forge/forge/agent_protocol/models/pagination.py
Normal file
@@ -0,0 +1,8 @@
from pydantic import BaseModel, Field


class Pagination(BaseModel):
    total_items: int = Field(..., description="Total number of items.", example=42)
    total_pages: int = Field(..., description="Total number of pages.", example=97)
    current_page: int = Field(..., description="Current page number.", example=1)
    page_size: int = Field(..., description="Number of items per page.", example=25)
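A short sketch of how `AgentDB.list_tasks` above fills this model: with 42 items and a page size of 25, `total_pages` comes out to 2.

```py
import math

from forge.agent_protocol.models import Pagination

total, per_page, page = 42, 25, 1
pagination = Pagination(
    total_items=total,
    total_pages=math.ceil(total / per_page),  # ceil(42 / 25) == 2
    current_page=page,
    page_size=per_page,
)
```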
126
forge/forge/agent_protocol/models/task.py
Normal file
@@ -0,0 +1,126 @@
from __future__ import annotations

from datetime import datetime
from enum import Enum
from typing import Any, List, Optional

from pydantic import BaseModel, Field

from .artifact import Artifact
from .pagination import Pagination


class TaskRequestBody(BaseModel):
    input: str = Field(
        ...,
        min_length=1,
        description="Input prompt for the task.",
        example="Write the words you receive to the file 'output.txt'.",
    )
    additional_input: dict[str, Any] = Field(default_factory=dict)


class Task(TaskRequestBody):
    created_at: datetime = Field(
        ...,
        description="The creation datetime of the task.",
        example="2023-01-01T00:00:00Z",
        json_encoders={datetime: lambda v: v.isoformat()},
    )
    modified_at: datetime = Field(
        ...,
        description="The modification datetime of the task.",
        example="2023-01-01T00:00:00Z",
        json_encoders={datetime: lambda v: v.isoformat()},
    )
    task_id: str = Field(
        ...,
        description="The ID of the task.",
        example="50da533e-3904-4401-8a07-c49adf88b5eb",
    )
    artifacts: list[Artifact] = Field(
        default_factory=list,
        description="A list of artifacts that the task has produced.",
        example=[
            "7a49f31c-f9c6-4346-a22c-e32bc5af4d8e",
            "ab7b4091-2560-4692-a4fe-d831ea3ca7d6",
        ],
    )


class StepRequestBody(BaseModel):
    name: Optional[str] = Field(
        default=None, description="The name of the task step.", example="Write to file"
    )
    input: str = Field(
        ..., description="Input prompt for the step.", example="Washington"
    )
    additional_input: dict[str, Any] = Field(default_factory=dict)


class StepStatus(Enum):
    created = "created"
    running = "running"
    completed = "completed"


class Step(StepRequestBody):
    created_at: datetime = Field(
        ...,
        description="The creation datetime of the task step.",
        example="2023-01-01T00:00:00Z",
        json_encoders={datetime: lambda v: v.isoformat()},
    )
    modified_at: datetime = Field(
        ...,
        description="The modification datetime of the task step.",
        example="2023-01-01T00:00:00Z",
        json_encoders={datetime: lambda v: v.isoformat()},
    )
    task_id: str = Field(
        ...,
        description="The ID of the task this step belongs to.",
        example="50da533e-3904-4401-8a07-c49adf88b5eb",
    )
    step_id: str = Field(
        ...,
        description="The ID of the task step.",
        example="6bb1801a-fd80-45e8-899a-4dd723cc602e",
    )
    name: Optional[str] = Field(
        default=None, description="The name of the task step.", example="Write to file"
    )
    status: StepStatus = Field(
        ..., description="The status of the task step.", example="created"
    )
    output: Optional[str] = Field(
        default=None,
        description="Output of the task step.",
        example=(
            "I am going to use the write_to_file command and write Washington "
            "to a file called output.txt <write_to_file('output.txt', 'Washington')>"
        ),
    )
    additional_output: Optional[dict[str, Any]] = None
    artifacts: list[Artifact] = Field(
        default_factory=list,
        description="A list of artifacts that the step has produced.",
    )
    is_last: bool = Field(
        ..., description="Whether this is the last step in the task.", example=True
    )


class TaskListResponse(BaseModel):
    tasks: Optional[List[Task]] = None
    pagination: Optional[Pagination] = None


class TaskStepsListResponse(BaseModel):
    steps: Optional[List[Step]] = None
    pagination: Optional[Pagination] = None


class TaskArtifactsListResponse(BaseModel):
    artifacts: Optional[List[Artifact]] = None
    pagination: Optional[Pagination] = None
5
forge/forge/command/__init__.py
Normal file
@@ -0,0 +1,5 @@
from .command import Command
from .decorator import command
from .parameter import CommandParameter

__all__ = ["Command", "CommandParameter", "command"]
95
forge/forge/command/command.py
Normal file
@@ -0,0 +1,95 @@
from __future__ import annotations

import inspect
from typing import Callable, Concatenate, Generic, ParamSpec, TypeVar, cast

from forge.agent.protocols import CommandProvider

from .parameter import CommandParameter

P = ParamSpec("P")
CO = TypeVar("CO")  # command output

_CP = TypeVar("_CP", bound=CommandProvider)


class Command(Generic[P, CO]):
    """A class representing a command.

    Attributes:
        name (str): The name of the command.
        description (str): A brief description of what the command does.
        parameters (list): The parameters of the function that the command executes.
    """

    def __init__(
        self,
        names: list[str],
        description: str,
        method: Callable[Concatenate[_CP, P], CO],
        parameters: list[CommandParameter],
    ):
        # Check if all parameters are provided
        if not self._parameters_match(method, parameters):
            raise ValueError(
                f"Command {names[0]} has different parameters than provided schema"
            )
        self.names = names
        self.description = description
        # Method technically has a `self` parameter, but we can ignore that
        # since Python passes it internally.
        self.method = cast(Callable[P, CO], method)
        self.parameters = parameters

    @property
    def is_async(self) -> bool:
        return inspect.iscoroutinefunction(self.method)

    @property
    def return_type(self) -> str | None:
        annotation = inspect.signature(self.method).return_annotation
        if annotation == inspect.Signature.empty:
            return None
        return annotation.__name__

    def _parameters_match(
        self, func: Callable, parameters: list[CommandParameter]
    ) -> bool:
        # Get the function's signature
        signature = inspect.signature(func)
        # Extract parameter names, ignoring 'self' for methods
        func_param_names = [
            param.name
            for param in signature.parameters.values()
            if param.name != "self"
        ]
        names = [param.name for param in parameters]
        # Check if sorted lists of names/keys are equal
        return sorted(func_param_names) == sorted(names)

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> CO:
        return self.method(*args, **kwargs)

    def __str__(self) -> str:
        params = [
            f"{param.name}: "
            + ("%s" if param.spec.required else "Optional[%s]")
            % (param.spec.type.value if param.spec.type else "Any")
            for param in self.parameters
        ]
        return (
            f"{self.names[0]}: {self.description.rstrip('.')}. "
            f"Params: ({', '.join(params)})"
        )

    def __get__(self, instance, owner):
        if instance is None:
            # Accessed on the class, not an instance
            return self
        # Bind the method to the instance
        return Command(
            self.names,
            self.description,
            self.method.__get__(instance, owner),
            self.parameters,
        )
||||
60
forge/forge/command/decorator.py
Normal file
60
forge/forge/command/decorator.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import re
|
||||
from typing import Callable, Concatenate, Optional, TypeVar
|
||||
|
||||
from forge.agent.protocols import CommandProvider
|
||||
from forge.models.json_schema import JSONSchema
|
||||
|
||||
from .command import CO, Command, CommandParameter, P
|
||||
|
||||
_CP = TypeVar("_CP", bound=CommandProvider)
|
||||
|
||||
|
||||
def command(
|
||||
names: list[str] = [],
|
||||
description: Optional[str] = None,
|
||||
parameters: dict[str, JSONSchema] = {},
|
||||
) -> Callable[[Callable[Concatenate[_CP, P], CO]], Command[P, CO]]:
|
||||
"""
|
||||
The command decorator is used to make a Command from a function.
|
||||
|
||||
Args:
|
||||
names (list[str]): The names of the command.
|
||||
If not provided, the function name will be used.
|
||||
description (str): A brief description of what the command does.
|
||||
If not provided, the docstring until double line break will be used
|
||||
(or entire docstring if no double line break is found)
|
||||
parameters (dict[str, JSONSchema]): The parameters of the function
|
||||
that the command executes.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[Concatenate[_CP, P], CO]) -> Command[P, CO]:
|
||||
doc = func.__doc__ or ""
|
||||
# If names is not provided, use the function name
|
||||
command_names = names or [func.__name__]
|
||||
# If description is not provided, use the first part of the docstring
|
||||
if not (command_description := description):
|
||||
if not func.__doc__:
|
||||
raise ValueError("Description is required if function has no docstring")
|
||||
# Return the part of the docstring before double line break or everything
|
||||
command_description = re.sub(r"\s+", " ", doc.split("\n\n")[0].strip())
|
||||
|
||||
# Parameters
|
||||
typed_parameters = [
|
||||
CommandParameter(
|
||||
name=param_name,
|
||||
spec=spec,
|
||||
)
|
||||
for param_name, spec in parameters.items()
|
||||
]
|
||||
|
||||
# Wrap func with Command
|
||||
command = Command(
|
||||
names=command_names,
|
||||
description=command_description,
|
||||
method=func,
|
||||
parameters=typed_parameters,
|
||||
)
|
||||
|
||||
return command
|
||||
|
||||
return decorator
|
||||
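A minimal usage sketch of the decorator, assuming a concrete `CommandProvider` subclass; the `GreetProvider` and `greet` names below are hypothetical illustrations, not part of the codebase. Note that the keys of `parameters` must match the function's parameter names, since `Command._parameters_match` validates them:

```py
from typing import Iterator

from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.models.json_schema import JSONSchema


class GreetProvider(CommandProvider):
    def get_commands(self) -> Iterator[Command]:
        yield self.greet

    @command(
        ["greet"],
        "Greet a user by name",
        {
            "name": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The name to greet",
                required=True,
            ),
        },
    )
    def greet(self, name: str) -> str:
        return f"Hello, {name}!"
```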
16
forge/forge/command/parameter.py
Normal file
@@ -0,0 +1,16 @@
from pydantic import BaseModel

from forge.models.json_schema import JSONSchema


class CommandParameter(BaseModel):
    name: str
    spec: JSONSchema

    def __repr__(self):
        return "CommandParameter('%s', '%s', '%s', %s)" % (
            self.name,
            self.spec.type,
            self.spec.description,
            self.spec.required,
        )
137
forge/forge/components/README.md
Normal file
@@ -0,0 +1,137 @@
# 🧩 Components

Components are the building blocks of [🤖 Agents](./agents.md). They are classes inheriting `AgentComponent` or implementing one or more [⚙️ Protocols](./protocols.md) that give the agent additional abilities or processing.

Components can be used to implement various functionalities like providing messages to the prompt, executing code, or interacting with external services.
They can be enabled or disabled, ordered, and can rely on each other.

Components assigned in the agent's `__init__` via `self` are automatically detected upon the agent's instantiation.
For example, inside `__init__`: `self.my_component = MyComponent()`.
You can use any valid Python variable name; what matters for the component to be detected is its type (`AgentComponent` or any protocol inheriting from it).

Visit [Built-in Components](./built-in-components.md) to see what components are available out of the box.

```py
from forge.agent import BaseAgent
from forge.agent.components import AgentComponent

class HelloComponent(AgentComponent):
    pass

class SomeComponent(AgentComponent):
    def __init__(self, hello_component: HelloComponent):
        self.hello_component = hello_component

class MyAgent(BaseAgent):
    def __init__(self):
        # These components will be automatically discovered and used
        self.hello_component = HelloComponent()
        # We pass HelloComponent to SomeComponent
        self.some_component = SomeComponent(self.hello_component)
```

## Ordering components

The execution order of components is important because some may depend on the results of the previous ones.
**By default, components are ordered alphabetically.**

### Ordering individual components

You can order a single component by passing other components (or their types) to the `run_after` method. This way you can ensure that the component will be executed after the specified one.
The `run_after` method returns the component itself, so you can call it when assigning the component to a variable:

```py
class MyAgent(Agent):
    def __init__(self):
        self.hello_component = HelloComponent()
        self.calculator_component = CalculatorComponent().run_after(self.hello_component)
        # This is equivalent to passing a type:
        # self.calculator_component = CalculatorComponent().run_after(HelloComponent)
```

!!! warning
    Be sure not to make circular dependencies when ordering components!

### Ordering all components

You can also order all components by setting the `self.components` list in the agent's `__init__` method.
Ordering this way guarantees there are no circular dependencies; any `run_after` calls are ignored.

!!! warning
    Be sure to include all components - by setting the `self.components` list, you're overriding the default behavior of discovering components automatically. Since skipping components is usually unintended, the agent will inform you in the terminal if any components were skipped.

```py
class MyAgent(Agent):
    def __init__(self):
        self.hello_component = HelloComponent()
        self.calculator_component = CalculatorComponent()
        # Explicitly set components list
        self.components = [self.hello_component, self.calculator_component]
```

## Disabling components

You can control which components are enabled by setting their `_enabled` attribute.
Provide either a `bool` value or a `Callable[[], bool]`, which will be checked each time
the component is about to be executed. This way you can dynamically enable or disable
components based on some conditions.
You can also provide a reason for disabling the component by setting `_disabled_reason`.
The reason will be visible in the debug information.

```py
class DisabledComponent(MessageProvider):
    def __init__(self):
        # Disable this component
        self._enabled = False
        self._disabled_reason = "This component is disabled because of reasons."

        # Or disable based on some condition, either statically...:
        self._enabled = self.some_property is not None
        # ... or dynamically:
        self._enabled = lambda: self.some_property is not None

    # This method will never be called
    def get_messages(self) -> Iterator[ChatMessage]:
        yield ChatMessage.user("This message won't be seen!")

    def some_condition(self) -> bool:
        return False
```

If you don't want the component at all, you can just remove it from the agent's `__init__` method. If you want to remove components you inherit from the parent class, you can set the relevant attribute to `None`:

!!! Warning
    Be careful when removing components that are required by other components. This may lead to errors and unexpected behavior.

```py
class MyAgent(Agent):
    def __init__(self):
        super().__init__(...)
        # Disable WatchdogComponent that is in the parent class
        self.watchdog = None
```

## Exceptions

Custom errors are provided which can be used to control the execution flow in case something went wrong. All of these errors can be raised in protocol methods and will be caught by the agent.
By default, the agent will retry three times and then re-raise the exception if it's still unresolved. All passed arguments are automatically handled and the values are reverted when needed.
All errors accept an optional `str` message. The following errors are ordered by increasing broadness:

1. `ComponentEndpointError`: A single endpoint method failed to execute. The agent will retry the execution of this endpoint on the component.
2. `EndpointPipelineError`: A pipeline failed to execute. The agent will retry the execution of the endpoint for all components.
3. `ComponentSystemError`: Multiple pipelines failed.

**Example**

```py
from forge.agent.components import ComponentEndpointError
from forge.agent.protocols import MessageProvider

# Example of raising an error
class MyComponent(MessageProvider):
    def get_messages(self) -> Iterator[ChatMessage]:
        # This will cause the component to always fail
        # and retry 3 times before re-raising the exception
        raise ComponentEndpointError("Endpoint error!")
```
4
forge/forge/components/action_history/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .action_history import ActionHistoryComponent
from .model import Episode, EpisodicActionHistory

__all__ = ["ActionHistoryComponent", "Episode", "EpisodicActionHistory"]
81
forge/forge/components/action_history/action_history.py
Normal file
@@ -0,0 +1,81 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Iterator, Optional

from forge.agent.protocols import AfterExecute, AfterParse, MessageProvider
from forge.llm.prompting.utils import indent
from forge.llm.providers import ChatMessage, MultiProvider

if TYPE_CHECKING:
    from forge.config.config import Config

from .model import ActionResult, AnyProposal, Episode, EpisodicActionHistory


class ActionHistoryComponent(MessageProvider, AfterParse[AnyProposal], AfterExecute):
    """Keeps track of the event history and provides a summary of the steps."""

    def __init__(
        self,
        event_history: EpisodicActionHistory[AnyProposal],
        max_tokens: int,
        count_tokens: Callable[[str], int],
        legacy_config: Config,
        llm_provider: MultiProvider,
    ) -> None:
        self.event_history = event_history
        self.max_tokens = max_tokens
        self.count_tokens = count_tokens
        self.legacy_config = legacy_config
        self.llm_provider = llm_provider

    def get_messages(self) -> Iterator[ChatMessage]:
        if progress := self._compile_progress(
            self.event_history.episodes,
            self.max_tokens,
            self.count_tokens,
        ):
            yield ChatMessage.system(
                "## Progress on your Task so far\n"
                "This is a list of the steps that you have executed previously; "
                f"take it into consideration when choosing your next action!\n{progress}"
            )

    def after_parse(self, result: AnyProposal) -> None:
        self.event_history.register_action(result)

    async def after_execute(self, result: ActionResult) -> None:
        self.event_history.register_result(result)
        await self.event_history.handle_compression(
            self.llm_provider, self.legacy_config
        )

    def _compile_progress(
        self,
        episode_history: list[Episode[AnyProposal]],
        max_tokens: Optional[int] = None,
        count_tokens: Optional[Callable[[str], int]] = None,
    ) -> str:
        if max_tokens and not count_tokens:
            raise ValueError("count_tokens is required if max_tokens is set")

        steps: list[str] = []
        tokens: int = 0
        n_episodes = len(episode_history)

        for i, episode in enumerate(reversed(episode_history)):
            # Use full format for the latest 4 steps, summary format for older steps
            if i < 4 or episode.summary is None:
                step_content = indent(episode.format(), 2).strip()
            else:
                step_content = episode.summary

            step = f"* Step {n_episodes - i}: {step_content}"

            if max_tokens and count_tokens:
                step_tokens = count_tokens(step)
                if tokens + step_tokens > max_tokens:
                    break
                tokens += step_tokens

            steps.insert(0, step)

        return "\n\n".join(steps)
155
forge/forge/components/action_history/model.py
Normal file
@@ -0,0 +1,155 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Generic

from pydantic import Field
from pydantic.generics import GenericModel

from forge.content_processing.text import summarize_text
from forge.llm.prompting.utils import format_numbered_list, indent
from forge.models.action import ActionResult, AnyProposal
from forge.models.utils import ModelWithSummary

if TYPE_CHECKING:
    from forge.config.config import Config
    from forge.llm.providers import MultiProvider


class Episode(GenericModel, Generic[AnyProposal]):
    action: AnyProposal
    result: ActionResult | None
    summary: str | None = None

    def format(self):
        step = f"Executed `{self.action.use_tool}`\n"
        reasoning = (
            _r.summary()
            if isinstance(_r := self.action.thoughts, ModelWithSummary)
            else _r
        )
        step += f'- **Reasoning:** "{reasoning}"\n'
        step += (
            "- **Status:** "
            f"`{self.result.status if self.result else 'did_not_finish'}`\n"
        )
        if self.result:
            if self.result.status == "success":
                result = str(self.result)
                result = "\n" + indent(result) if "\n" in result else result
                step += f"- **Output:** {result}"
            elif self.result.status == "error":
                step += f"- **Reason:** {self.result.reason}\n"
                if self.result.error:
                    step += f"- **Error:** {self.result.error}\n"
            elif self.result.status == "interrupted_by_human":
                step += f"- **Feedback:** {self.result.feedback}\n"
        return step

    def __str__(self) -> str:
        executed_action = f"Executed `{self.action.use_tool}`"
        action_result = f": {self.result}" if self.result else "."
        return executed_action + action_result


class EpisodicActionHistory(GenericModel, Generic[AnyProposal]):
    """Utility container for an action history"""

    episodes: list[Episode[AnyProposal]] = Field(default_factory=list)
    cursor: int = 0
    _lock = asyncio.Lock()

    @property
    def current_episode(self) -> Episode[AnyProposal] | None:
        if self.cursor == len(self):
            return None
        return self[self.cursor]

    def __getitem__(self, key: int) -> Episode[AnyProposal]:
        return self.episodes[key]

    def __len__(self) -> int:
        return len(self.episodes)

    def __bool__(self) -> bool:
        return len(self.episodes) > 0

    def register_action(self, action: AnyProposal) -> None:
        if not self.current_episode:
            self.episodes.append(Episode(action=action, result=None))
            assert self.current_episode
        elif self.current_episode.action:
            raise ValueError("Action for current cycle already set")

    def register_result(self, result: ActionResult) -> None:
        if not self.current_episode:
            raise RuntimeError("Cannot register result for cycle without action")
        elif self.current_episode.result:
            raise ValueError("Result for current cycle already set")

        self.current_episode.result = result
        self.cursor = len(self.episodes)

    def rewind(self, number_of_episodes: int = 0) -> None:
        """Resets the history to an earlier state.

        Params:
            number_of_episodes (int): The number of episodes to rewind. Default is 0.
                When set to 0, it will only reset the current cycle.
        """
        # Remove partial record of current cycle
        if self.current_episode:
            if self.current_episode.action and not self.current_episode.result:
                self.episodes.pop(self.cursor)

        # Rewind the specified number of cycles
        if number_of_episodes > 0:
            self.episodes = self.episodes[:-number_of_episodes]
        self.cursor = len(self.episodes)

    async def handle_compression(
        self, llm_provider: MultiProvider, app_config: Config
    ) -> None:
        """Compresses each episode in the action history using an LLM.

        This method iterates over all episodes in the action history without a summary,
        and generates a summary for them using an LLM.
        """
        compress_instruction = (
            "The text represents an action, the reason for its execution, "
            "and its result. "
            "Condense the action taken and its result into one line. "
            "Preserve any specific factual information gathered by the action."
        )
        async with self._lock:
            # Gather all episodes without a summary
            episodes_to_summarize = [ep for ep in self.episodes if ep.summary is None]

            # Parallelize summarization calls
            summarize_coroutines = [
                summarize_text(
                    episode.format(),
                    instruction=compress_instruction,
                    llm_provider=llm_provider,
                    config=app_config,
                )
                for episode in episodes_to_summarize
            ]
            summaries = await asyncio.gather(*summarize_coroutines)

            # Assign summaries to episodes
            for episode, (summary, _) in zip(episodes_to_summarize, summaries):
                episode.summary = summary

    def fmt_list(self) -> str:
        return format_numbered_list(self.episodes)

    def fmt_paragraph(self) -> str:
        steps: list[str] = []

        for i, episode in enumerate(self.episodes, 1):
            step = f"### Step {i}: {episode.format()}\n"

            steps.append(step)

        return "\n\n".join(steps)
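A sketch of the intended episode lifecycle, assuming a concrete proposal type; `MyProposal`, `proposal`, and `result` below are placeholders, not names from the codebase:

```py
# MyProposal is a placeholder type satisfying AnyProposal
history = EpisodicActionHistory[MyProposal]()

history.register_action(proposal)  # opens the current episode
history.register_result(result)    # closes it and advances the cursor

# Drop the last two finished episodes (and any partial current one):
history.rewind(number_of_episodes=2)
```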
13
forge/forge/components/code_executor/__init__.py
Normal file
@@ -0,0 +1,13 @@
from .code_executor import (
    ALLOWLIST_CONTROL,
    DENYLIST_CONTROL,
    CodeExecutionError,
    CodeExecutorComponent,
)

__all__ = [
    "ALLOWLIST_CONTROL",
    "DENYLIST_CONTROL",
    "CodeExecutionError",
    "CodeExecutorComponent",
]
410
forge/forge/components/code_executor/code_executor.py
Normal file
@@ -0,0 +1,410 @@
import logging
import os
import random
import shlex
import string
import subprocess
from pathlib import Path
from typing import Iterator

import docker
from docker.errors import DockerException, ImageNotFound, NotFound
from docker.models.containers import Container as DockerContainer

from forge.agent import BaseAgentSettings
from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.file_storage import FileStorage
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import (
    CommandExecutionError,
    InvalidArgumentError,
    OperationNotAllowedError,
)

logger = logging.getLogger(__name__)

ALLOWLIST_CONTROL = "allowlist"
DENYLIST_CONTROL = "denylist"


def we_are_running_in_a_docker_container() -> bool:
    """Check if we are running in a Docker container

    Returns:
        bool: True if we are running in a Docker container, False otherwise
    """
    return os.path.exists("/.dockerenv")


def is_docker_available() -> bool:
    """Check if Docker is available and supports Linux containers

    Returns:
        bool: True if Docker is available and supports Linux containers, False otherwise
    """
    try:
        client = docker.from_env()
        docker_info = client.info()
        return docker_info["OSType"] == "linux"
    except Exception:
        return False


class CodeExecutionError(CommandExecutionError):
    """The operation (an attempt to run arbitrary code) returned an error"""


class CodeExecutorComponent(CommandProvider):
    """Provides commands to execute Python code and shell commands."""

    def __init__(
        self, workspace: FileStorage, state: BaseAgentSettings, config: Config
    ):
        self.workspace = workspace
        self.state = state
        self.legacy_config = config

        if not we_are_running_in_a_docker_container() and not is_docker_available():
            logger.info(
                "Docker is not available or does not support Linux containers. "
                "The code execution commands will not be available."
            )

        if not self.legacy_config.execute_local_commands:
            logger.info(
                "Local shell commands are disabled. To enable them,"
                " set EXECUTE_LOCAL_COMMANDS to 'True' in your config file."
            )

    def get_commands(self) -> Iterator[Command]:
        if we_are_running_in_a_docker_container() or is_docker_available():
            yield self.execute_python_code
            yield self.execute_python_file

        if self.legacy_config.execute_local_commands:
            yield self.execute_shell
            yield self.execute_shell_popen

    @command(
        ["execute_python_code"],
        "Executes the given Python code inside a single-use Docker container"
        " with access to your workspace folder",
        {
            "code": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The Python code to run",
                required=True,
            ),
        },
    )
    async def execute_python_code(self, code: str) -> str:
        """
        Create and execute a Python file in a Docker container
        and return the STDOUT of the executed code.

        If the code generates any data that needs to be captured,
        use a print statement.

        Args:
            code (str): The Python code to run.

        Returns:
            str: The STDOUT captured from the code when it ran.
        """

        temp_path = ""
        while True:
            temp_path = f"temp{self._generate_random_string()}.py"
            if not self.workspace.exists(temp_path):
                break
        await self.workspace.write_file(temp_path, code)

        try:
            return self.execute_python_file(temp_path)
        except Exception as e:
            raise CommandExecutionError(*e.args)
        finally:
            self.workspace.delete_file(temp_path)

    @command(
        ["execute_python_file"],
        "Execute an existing Python file inside a single-use Docker container"
        " with access to your workspace folder",
        {
            "filename": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The name of the file to execute",
                required=True,
            ),
            "args": JSONSchema(
                type=JSONSchema.Type.ARRAY,
                description="The (command line) arguments to pass to the script",
                required=False,
                items=JSONSchema(type=JSONSchema.Type.STRING),
            ),
        },
    )
    def execute_python_file(self, filename: str | Path, args: list[str] = []) -> str:
        """Execute a Python file in a Docker container and return the output

        Args:
            filename (Path): The name of the file to execute
            args (list, optional): The arguments with which to run the python script

        Returns:
            str: The output of the file
        """
        logger.info(f"Executing python file '{filename}'")

        if not str(filename).endswith(".py"):
            raise InvalidArgumentError("Invalid file type. Only .py files are allowed.")

        file_path = self.workspace.get_path(filename)
        if not self.workspace.exists(file_path):
            # Mimic the response that you get from the command line to make it
# intuitively understandable for the LLM
|
||||
raise FileNotFoundError(
|
||||
f"python: can't open file '{filename}': "
|
||||
f"[Errno 2] No such file or directory"
|
||||
)
|
||||
|
||||
if we_are_running_in_a_docker_container():
|
||||
logger.debug(
|
||||
"App is running in a Docker container; "
|
||||
f"executing {file_path} directly..."
|
||||
)
|
||||
with self.workspace.mount() as local_path:
|
||||
result = subprocess.run(
|
||||
["python", "-B", str(file_path.relative_to(self.workspace.root))]
|
||||
+ args,
|
||||
capture_output=True,
|
||||
encoding="utf8",
|
||||
cwd=str(local_path),
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout
|
||||
else:
|
||||
raise CodeExecutionError(result.stderr)
|
||||
|
||||
logger.debug("App is not running in a Docker container")
|
||||
return self._run_python_code_in_docker(file_path, args)
|
||||
|
||||
def validate_command(self, command_line: str, config: Config) -> tuple[bool, bool]:
|
||||
"""Check whether a command is allowed and whether it may be executed in a shell.
|
||||
|
||||
If shell command control is enabled, we disallow executing in a shell, because
|
||||
otherwise the model could circumvent the command filter using shell features.
|
||||
|
||||
Args:
|
||||
command_line (str): The command line to validate
|
||||
config (Config): The app config including shell command control settings
|
||||
|
||||
Returns:
|
||||
bool: True if the command is allowed, False otherwise
|
||||
bool: True if the command may be executed in a shell, False otherwise
|
||||
"""
|
||||
if not command_line:
|
||||
return False, False
|
||||
|
||||
command_name = shlex.split(command_line)[0]
|
||||
|
||||
if config.shell_command_control == ALLOWLIST_CONTROL:
|
||||
return command_name in config.shell_allowlist, False
|
||||
elif config.shell_command_control == DENYLIST_CONTROL:
|
||||
return command_name not in config.shell_denylist, False
|
||||
else:
|
||||
return True, True
|
||||
|
||||
@command(
|
||||
["execute_shell"],
|
||||
"Execute a Shell Command, non-interactive commands only",
|
||||
{
|
||||
"command_line": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The command line to execute",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
)
|
||||
def execute_shell(self, command_line: str) -> str:
|
||||
"""Execute a shell command and return the output
|
||||
|
||||
Args:
|
||||
command_line (str): The command line to execute
|
||||
|
||||
Returns:
|
||||
str: The output of the command
|
||||
"""
|
||||
allow_execute, allow_shell = self.validate_command(
|
||||
command_line, self.legacy_config
|
||||
)
|
||||
if not allow_execute:
|
||||
logger.info(f"Command '{command_line}' not allowed")
|
||||
raise OperationNotAllowedError("This shell command is not allowed.")
|
||||
|
||||
current_dir = Path.cwd()
|
||||
# Change dir into workspace if necessary
|
||||
if not current_dir.is_relative_to(self.workspace.root):
|
||||
os.chdir(self.workspace.root)
|
||||
|
||||
logger.info(
|
||||
f"Executing command '{command_line}' in working directory '{os.getcwd()}'"
|
||||
)
|
||||
|
||||
result = subprocess.run(
|
||||
command_line if allow_shell else shlex.split(command_line),
|
||||
capture_output=True,
|
||||
shell=allow_shell,
|
||||
)
|
||||
output = f"STDOUT:\n{result.stdout.decode()}\nSTDERR:\n{result.stderr.decode()}"
|
||||
|
||||
# Change back to whatever the prior working dir was
|
||||
os.chdir(current_dir)
|
||||
|
||||
return output
|
||||
|
||||
@command(
|
||||
["execute_shell_popen"],
|
||||
"Execute a Shell Command, non-interactive commands only",
|
||||
{
|
||||
"command_line": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The command line to execute",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
)
|
||||
def execute_shell_popen(self, command_line: str) -> str:
|
||||
"""Execute a shell command with Popen and returns an english description
|
||||
of the event and the process id
|
||||
|
||||
Args:
|
||||
command_line (str): The command line to execute
|
||||
|
||||
Returns:
|
||||
str: Description of the fact that the process started and its id
|
||||
"""
|
||||
allow_execute, allow_shell = self.validate_command(
|
||||
command_line, self.legacy_config
|
||||
)
|
||||
if not allow_execute:
|
||||
logger.info(f"Command '{command_line}' not allowed")
|
||||
raise OperationNotAllowedError("This shell command is not allowed.")
|
||||
|
||||
current_dir = Path.cwd()
|
||||
# Change dir into workspace if necessary
|
||||
if not current_dir.is_relative_to(self.workspace.root):
|
||||
os.chdir(self.workspace.root)
|
||||
|
||||
logger.info(
|
||||
f"Executing command '{command_line}' in working directory '{os.getcwd()}'"
|
||||
)
|
||||
|
||||
do_not_show_output = subprocess.DEVNULL
|
||||
process = subprocess.Popen(
|
||||
command_line if allow_shell else shlex.split(command_line),
|
||||
shell=allow_shell,
|
||||
stdout=do_not_show_output,
|
||||
stderr=do_not_show_output,
|
||||
)
|
||||
|
||||
# Change back to whatever the prior working dir was
|
||||
os.chdir(current_dir)
|
||||
|
||||
return f"Subprocess started with PID:'{str(process.pid)}'"
|
||||
|
||||
def _run_python_code_in_docker(self, filename: str | Path, args: list[str]) -> str:
|
||||
"""Run a Python script in a Docker container"""
|
||||
file_path = self.workspace.get_path(filename)
|
||||
try:
|
||||
assert self.state.agent_id, "Need Agent ID to attach Docker container"
|
||||
|
||||
client = docker.from_env()
|
||||
image_name = "python:3-alpine"
|
||||
container_is_fresh = False
|
||||
container_name = f"{self.state.agent_id}_sandbox"
|
||||
with self.workspace.mount() as local_path:
|
||||
try:
|
||||
container: DockerContainer = client.containers.get(
|
||||
container_name
|
||||
) # type: ignore
|
||||
except NotFound:
|
||||
try:
|
||||
client.images.get(image_name)
|
||||
logger.debug(f"Image '{image_name}' found locally")
|
||||
except ImageNotFound:
|
||||
logger.info(
|
||||
f"Image '{image_name}' not found locally,"
|
||||
" pulling from Docker Hub..."
|
||||
)
|
||||
# Use the low-level API to stream the pull response
|
||||
low_level_client = docker.APIClient()
|
||||
for line in low_level_client.pull(
|
||||
image_name, stream=True, decode=True
|
||||
):
|
||||
# Print the status and progress, if available
|
||||
status = line.get("status")
|
||||
progress = line.get("progress")
|
||||
if status and progress:
|
||||
logger.info(f"{status}: {progress}")
|
||||
elif status:
|
||||
logger.info(status)
|
||||
|
||||
logger.debug(f"Creating new {image_name} container...")
|
||||
container: DockerContainer = client.containers.run(
|
||||
image_name,
|
||||
["sleep", "60"], # Max 60 seconds to prevent permanent hangs
|
||||
volumes={
|
||||
str(local_path.resolve()): {
|
||||
"bind": "/workspace",
|
||||
"mode": "rw",
|
||||
}
|
||||
},
|
||||
working_dir="/workspace",
|
||||
stderr=True,
|
||||
stdout=True,
|
||||
detach=True,
|
||||
name=container_name,
|
||||
) # type: ignore
|
||||
container_is_fresh = True
|
||||
|
||||
if not container.status == "running":
|
||||
container.start()
|
||||
elif not container_is_fresh:
|
||||
container.restart()
|
||||
|
||||
logger.debug(f"Running {file_path} in container {container.name}...")
|
||||
|
||||
exec_result = container.exec_run(
|
||||
[
|
||||
"python",
|
||||
"-B",
|
||||
file_path.relative_to(self.workspace.root).as_posix(),
|
||||
]
|
||||
+ args,
|
||||
stderr=True,
|
||||
stdout=True,
|
||||
)
|
||||
|
||||
if exec_result.exit_code != 0:
|
||||
raise CodeExecutionError(exec_result.output.decode("utf-8"))
|
||||
|
||||
return exec_result.output.decode("utf-8")
|
||||
|
||||
except DockerException as e:
|
||||
logger.warning(
|
||||
"Could not run the script in a container. "
|
||||
"If you haven't already, please install Docker: "
|
||||
"https://docs.docker.com/get-docker/"
|
||||
)
|
||||
raise CommandExecutionError(f"Could not run the script in a container: {e}")
|
||||
|
||||
def _generate_random_string(self, length: int = 8):
|
||||
# Create a string of all letters and digits
|
||||
characters = string.ascii_letters + string.digits
|
||||
# Use random.choices to generate a random string
|
||||
random_string = "".join(random.choices(characters, k=length))
|
||||
return random_string
|
||||
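`validate_command` returns two flags: whether the command may run at all, and whether it may run through a shell. Shell access is denied whenever allow/deny filtering is active, since shell features (pipes, `;`, subshells) could otherwise be used to bypass the filter on the executable name. A minimal standalone sketch of the same decision logic, with hypothetical config values:

import shlex

ALLOWLIST_CONTROL = "allowlist"
DENYLIST_CONTROL = "denylist"

def validate(command_line: str, control: str, allow: list[str], deny: list[str]) -> tuple[bool, bool]:
    if not command_line:
        return False, False
    name = shlex.split(command_line)[0]  # filter on the executable name only
    if control == ALLOWLIST_CONTROL:
        return name in allow, False   # filtered mode: never use a shell
    if control == DENYLIST_CONTROL:
        return name not in deny, False  # filtered mode: never use a shell
    return True, True  # no filtering configured: shell allowed

print(validate("ls -la", ALLOWLIST_CONTROL, ["ls"], []))   # (True, False)
print(validate("rm -rf /", DENYLIST_CONTROL, [], ["rm"]))  # (False, False)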
3
forge/forge/components/code_flow_executor/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .code_flow_executor import (
    CodeFlowExecutionComponent
)
80
forge/forge/components/code_flow_executor/code_flow_executor.py
Normal file
@@ -0,0 +1,80 @@
"""Commands to generate images based on text input"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Iterator
|
||||
|
||||
from forge.agent.protocols import CommandProvider
|
||||
from forge.command import Command, command
|
||||
from forge.models.json_schema import JSONSchema
|
||||
|
||||
MAX_RESULT_LENGTH = 1000
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CodeFlowExecutionComponent(CommandProvider):
|
||||
"""A component that provides commands to execute code flow."""
|
||||
|
||||
def __init__(self):
|
||||
self._enabled = True
|
||||
self.available_functions = {}
|
||||
|
||||
def set_available_functions(self, functions: list[Command]):
|
||||
self.available_functions = {
|
||||
name: function for function in functions for name in function.names
|
||||
}
|
||||
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.execute_code_flow
|
||||
|
||||
@command(
|
||||
parameters={
|
||||
"python_code": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The Python code to execute",
|
||||
required=True,
|
||||
),
|
||||
"plan_text": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The plan to written in a natural language",
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
)
|
||||
async def execute_code_flow(self, python_code: str, plan_text: str) -> str:
|
||||
"""Execute the code flow.
|
||||
|
||||
Args:
|
||||
python_code (str): The Python code to execute
|
||||
callables (dict[str, Callable]): The dictionary of [name, callable] pairs to use in the code
|
||||
|
||||
Returns:
|
||||
str: The result of the code execution
|
||||
"""
|
||||
code_header = "import inspect\n" + "\n".join(
|
||||
[
|
||||
f"""
|
||||
async def {name}(*args, **kwargs):
|
||||
result = {name}_func(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
return result
|
||||
"""
|
||||
for name in self.available_functions.keys()
|
||||
]
|
||||
)
|
||||
result = {
|
||||
name + "_func": func for name, func in self.available_functions.items()
|
||||
}
|
||||
code = f"{code_header}\n{python_code}\n\nexec_output = main()"
|
||||
logger.debug(f"Code-Flow Execution code:\n{python_code}")
|
||||
exec(code, result)
|
||||
result = await result["exec_output"]
|
||||
logger.debug(f"Code-Flow Execution result:\n{result}")
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
|
||||
# limit the result to limit the characters
|
||||
if len(result) > MAX_RESULT_LENGTH:
|
||||
result = result[:MAX_RESULT_LENGTH] + "...[Truncated, Content is too long]"
|
||||
return f"Execution Plan:\n{plan_text}\n\nExecution Output:\n{result}"
|
||||
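`execute_code_flow` wraps every available command in an async shim, `exec`s the model-written code into a namespace holding those shims, and then awaits the `exec_output` coroutine produced by `main()`. A stripped-down sketch of that exec-then-await mechanism; the `add` function is a hypothetical stand-in for an agent command:

import asyncio
import inspect

async def run_flow(python_code: str) -> object:
    namespace = {"add_func": lambda a, b: a + b}  # hypothetical command
    header = (
        "import inspect\n"
        "async def add(*args, **kwargs):\n"
        "    result = add_func(*args, **kwargs)\n"
        "    if inspect.isawaitable(result):\n"
        "        result = await result\n"
        "    return result\n"
    )
    # main() is async, so exec_output holds a coroutine we can await
    exec(f"{header}\n{python_code}\n\nexec_output = main()", namespace)
    return await namespace["exec_output"]

code = "async def main():\n    return await add(2, 3)"
print(asyncio.run(run_flow(code)))  # 5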
15
forge/forge/components/context/__init__.py
Normal file
@@ -0,0 +1,15 @@
from .context import ContextComponent
from .context_item import (
    ContextItem,
    FileContextItem,
    FolderContextItem,
    StaticContextItem,
)

__all__ = [
    "ContextComponent",
    "ContextItem",
    "FileContextItem",
    "FolderContextItem",
    "StaticContextItem",
]
163
forge/forge/components/context/context.py
Normal file
@@ -0,0 +1,163 @@
import contextlib
from pathlib import Path
from typing import Iterator

from pydantic import BaseModel, Field
from typing_extensions import Annotated

from forge.agent.protocols import CommandProvider, MessageProvider
from forge.command import Command, command
from forge.file_storage.base import FileStorage
from forge.llm.providers import ChatMessage
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import InvalidArgumentError

from .context_item import ContextItem, FileContextItem, FolderContextItem


class AgentContext(BaseModel):
    items: list[Annotated[ContextItem, Field(discriminator="type")]] = Field(
        default_factory=list
    )

    def __bool__(self) -> bool:
        return len(self.items) > 0

    def __contains__(self, item: ContextItem) -> bool:
        return any(i.source == item.source for i in self.items)

    def add(self, item: ContextItem) -> None:
        self.items.append(item)

    def close(self, index: int) -> None:
        self.items.pop(index - 1)

    def clear(self) -> None:
        self.items.clear()

    def format_numbered(self, workspace: FileStorage) -> str:
        return "\n\n".join(
            [f"{i}. {c.fmt(workspace)}" for i, c in enumerate(self.items, 1)]
        )


class ContextComponent(MessageProvider, CommandProvider):
    """Adds the ability to keep files and folders open in the context (prompt)."""

    def __init__(self, workspace: FileStorage, context: AgentContext):
        self.context = context
        self.workspace = workspace

    def get_messages(self) -> Iterator[ChatMessage]:
        if self.context:
            yield ChatMessage.system(
                "## Context\n"
                f"{self.context.format_numbered(self.workspace)}\n\n"
                "When a context item is no longer needed and you are not done yet, "
                "you can hide the item by specifying its number in the list above "
                "to `hide_context_item`.",
            )

    def get_commands(self) -> Iterator[Command]:
        yield self.open_file
        yield self.open_folder
        if self.context:
            yield self.close_context_item

    @command(
        parameters={
            "file_path": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The path of the file to open",
                required=True,
            )
        }
    )
    async def open_file(self, file_path: str | Path) -> str:
        """Opens a file for editing or continued viewing;
        creates it if it does not exist yet.
        Note: If you only need to read or write a file once,
        use `write_to_file` instead.

        Args:
            file_path (str | Path): The path of the file to open

        Returns:
            str: A status message indicating what happened
        """
        if not isinstance(file_path, Path):
            file_path = Path(file_path)

        created = False
        if not self.workspace.exists(file_path):
            await self.workspace.write_file(file_path, "")
            created = True

        # Try to make the file path relative
        with contextlib.suppress(ValueError):
            file_path = file_path.relative_to(self.workspace.root)

        file = FileContextItem(path=file_path)
        self.context.add(file)
        return (
            f"File {file_path}{' created,' if created else ''} has been opened"
            " and added to the context ✅"
        )

    @command(
        parameters={
            "path": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The path of the folder to open",
                required=True,
            )
        }
    )
    def open_folder(self, path: str | Path) -> str:
        """Open a folder to keep track of its content

        Args:
            path (str | Path): The path of the folder to open

        Returns:
            str: A status message indicating what happened
        """
        if not isinstance(path, Path):
            path = Path(path)

        if not self.workspace.exists(path):
            raise FileNotFoundError(
                f"open_folder {path} failed: no such file or directory"
            )

        # Try to make the path relative
        with contextlib.suppress(ValueError):
            path = path.relative_to(self.workspace.root)

        folder = FolderContextItem(path=path)
        self.context.add(folder)
        return f"Folder {path} has been opened and added to the context ✅"

    @command(
        parameters={
            "number": JSONSchema(
                type=JSONSchema.Type.INTEGER,
                description="The 1-based index of the context item to hide",
                required=True,
            )
        }
    )
    def close_context_item(self, number: int) -> str:
        """Hide an open file, folder or other context item, to save tokens.

        Args:
            number (int): The 1-based index of the context item to hide

        Returns:
            str: A status message indicating what happened
        """
        if not 0 < number <= len(self.context.items):
            raise InvalidArgumentError(f"Index {number} out of range")

        self.context.close(number)
        return f"Context item {number} hidden ✅"
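Context items are presented to the LLM as a 1-based numbered list, and `close(index)` pops `index - 1`, so the number the model sees maps directly onto the backing list. A small sketch of that indexing contract, using plain strings in place of ContextItem:

class Context:
    def __init__(self) -> None:
        self.items: list[str] = []

    def format_numbered(self) -> str:
        return "\n".join(f"{i}. {item}" for i, item in enumerate(self.items, 1))

    def close(self, number: int) -> None:
        if not 0 < number <= len(self.items):
            raise ValueError(f"Index {number} out of range")
        self.items.pop(number - 1)  # 1-based number -> 0-based list index

ctx = Context()
ctx.items += ["file: a.py", "folder: src/"]
print(ctx.format_numbered())  # "1. file: a.py" / "2. folder: src/"
ctx.close(1)
print(ctx.items)  # ['folder: src/']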
85
forge/forge/components/context/context_item.py
Normal file
@@ -0,0 +1,85 @@
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Literal, Optional

from pydantic import BaseModel, Field

from forge.file_storage.base import FileStorage
from forge.utils.file_operations import decode_textual_file

logger = logging.getLogger(__name__)


class BaseContextItem(ABC):
    @property
    @abstractmethod
    def description(self) -> str:
        """Description of the context item"""
        ...

    @property
    @abstractmethod
    def source(self) -> Optional[str]:
        """A string indicating the source location of the context item"""
        ...

    @abstractmethod
    def get_content(self, workspace: FileStorage) -> str:
        """The content represented by the context item"""
        ...

    def fmt(self, workspace: FileStorage) -> str:
        return (
            f"{self.description} (source: {self.source})\n"
            "```\n"
            f"{self.get_content(workspace)}\n"
            "```"
        )


class FileContextItem(BaseModel, BaseContextItem):
    path: Path
    type: Literal["file"] = "file"

    @property
    def description(self) -> str:
        return f"The current content of the file '{self.path}'"

    @property
    def source(self) -> str:
        return str(self.path)

    def get_content(self, workspace: FileStorage) -> str:
        with workspace.open_file(self.path, "r", True) as file:
            return decode_textual_file(file, self.path.suffix, logger)


class FolderContextItem(BaseModel, BaseContextItem):
    path: Path
    type: Literal["folder"] = "folder"

    @property
    def description(self) -> str:
        return f"The contents of the folder '{self.path}' in the workspace"

    @property
    def source(self) -> str:
        return str(self.path)

    def get_content(self, workspace: FileStorage) -> str:
        files = [str(p) for p in workspace.list_files(self.path)]
        folders = [f"{str(p)}/" for p in workspace.list_folders(self.path)]
        items = folders + files
        items.sort()
        return "\n".join(items)


class StaticContextItem(BaseModel, BaseContextItem):
    item_description: str = Field(alias="description")
    item_source: Optional[str] = Field(alias="source")
    item_content: str = Field(alias="content")
    type: Literal["static"] = "static"


ContextItem = FileContextItem | FolderContextItem | StaticContextItem
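Each item model carries a `Literal` `type` field that `AgentContext` uses as a pydantic discriminator, so serialized context items round-trip to the correct concrete class. A minimal sketch of the same pattern, assuming the pydantic v1-style API that matches the `.json()` calls elsewhere in this diff (discriminated unions need pydantic >= 1.9):

from typing import Literal, Union
from pydantic import BaseModel, Field, parse_obj_as
from typing_extensions import Annotated

class FileItem(BaseModel):
    type: Literal["file"] = "file"
    path: str

class FolderItem(BaseModel):
    type: Literal["folder"] = "folder"
    path: str

Item = Annotated[Union[FileItem, FolderItem], Field(discriminator="type")]

# The "type" value selects the concrete model during parsing
item = parse_obj_as(Item, {"type": "folder", "path": "src"})
print(type(item).__name__)  # FolderItem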
3
forge/forge/components/file_manager/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .file_manager import FileManagerComponent

__all__ = ["FileManagerComponent"]
160
forge/forge/components/file_manager/file_manager.py
Normal file
@@ -0,0 +1,160 @@
import logging
import os
from pathlib import Path
from typing import Iterator, Optional

from forge.agent import BaseAgentSettings
from forge.agent.protocols import CommandProvider, DirectiveProvider
from forge.command import Command, command
from forge.file_storage.base import FileStorage
from forge.models.json_schema import JSONSchema
from forge.utils.file_operations import decode_textual_file

logger = logging.getLogger(__name__)


class FileManagerComponent(DirectiveProvider, CommandProvider):
    """
    Adds general file manager support (e.g. for Agent state),
    workspace manager support (e.g. for Agent output files), and
    commands to perform operations on files and folders.
    """

    files: FileStorage
    """Agent-related files, e.g. state, logs.
    Use `workspace` to access the agent's workspace files."""

    workspace: FileStorage
    """Workspace that the agent has access to, e.g. for reading/writing files.
    Use `files` to access agent-related files, e.g. state, logs."""

    STATE_FILE = "state.json"
    """The name of the file where the agent's state is stored."""

    def __init__(self, state: BaseAgentSettings, file_storage: FileStorage):
        self.state = state

        if not state.agent_id:
            raise ValueError("Agent must have an ID.")

        self.files = file_storage.clone_with_subroot(f"agents/{state.agent_id}/")
        self.workspace = file_storage.clone_with_subroot(
            f"agents/{state.agent_id}/workspace"
        )
        self._file_storage = file_storage

    async def save_state(self, save_as: Optional[str] = None) -> None:
        """Save the agent's state to the state file."""
        state: BaseAgentSettings = getattr(self, "state")
        if save_as:
            temp_id = state.agent_id
            state.agent_id = save_as
            self._file_storage.make_dir(f"agents/{save_as}")
            # Save state
            await self._file_storage.write_file(
                f"agents/{save_as}/{self.STATE_FILE}", state.json()
            )
            # Copy workspace
            self._file_storage.copy(
                f"agents/{temp_id}/workspace",
                f"agents/{save_as}/workspace",
            )
            state.agent_id = temp_id
        else:
            await self.files.write_file(self.files.root / self.STATE_FILE, state.json())

    def change_agent_id(self, new_id: str):
        """Change the agent's ID and update the file storage accordingly."""
        state: BaseAgentSettings = getattr(self, "state")
        # Rename the agent's files and workspace
        self._file_storage.rename(f"agents/{state.agent_id}", f"agents/{new_id}")
        # Update the file storage objects
        self.files = self._file_storage.clone_with_subroot(f"agents/{new_id}/")
        self.workspace = self._file_storage.clone_with_subroot(
            f"agents/{new_id}/workspace"
        )
        state.agent_id = new_id

    def get_resources(self) -> Iterator[str]:
        yield "The ability to read and write files."

    def get_commands(self) -> Iterator[Command]:
        yield self.read_file
        yield self.write_to_file
        yield self.list_folder

    @command(
        parameters={
            "filename": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The path of the file to read",
                required=True,
            )
        },
    )
    def read_file(self, filename: str | Path) -> str:
        """Read a file and return the contents

        Args:
            filename (str): The name of the file to read

        Returns:
            str: The contents of the file
        """
        file = self.workspace.open_file(filename, binary=True)
        content = decode_textual_file(file, os.path.splitext(filename)[1], logger)

        return content

    @command(
        ["write_file", "create_file"],
        "Write a file, creating it if necessary. "
        "If the file exists, it is overwritten.",
        {
            "filename": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The name of the file to write to",
                required=True,
            ),
            "contents": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The contents to write to the file",
                required=True,
            ),
        },
    )
    async def write_to_file(self, filename: str | Path, contents: str) -> str:
        """Write contents to a file

        Args:
            filename (str): The name of the file to write to
            contents (str): The contents to write to the file

        Returns:
            str: A message indicating success or failure
        """
        if directory := os.path.dirname(filename):
            self.workspace.make_dir(directory)
        await self.workspace.write_file(filename, contents)
        return f"File {filename} has been written successfully."

    @command(
        parameters={
            "folder": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The folder to list files in. "
                "Pass an empty string to list files in the workspace.",
                required=True,
            )
        },
    )
    def list_folder(self, folder: str | Path) -> list[str]:
        """Lists files in a folder recursively

        Args:
            folder (str): The folder to search in

        Returns:
            list[str]: A list of files found in the folder
        """
        return [str(p) for p in self.workspace.list_files(folder)]
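`write_to_file` creates the parent directory first when the target path has one, using an assignment expression so the empty-dirname case (a file at the workspace root) is skipped. The same guard in isolation, using only the standard library:

import os
from pathlib import Path

def write_file(filename: str, contents: str) -> None:
    # os.path.dirname("a.txt") == "" -> falsy -> nothing to create
    if directory := os.path.dirname(filename):
        Path(directory).mkdir(parents=True, exist_ok=True)
    Path(filename).write_text(contents)

write_file("out/notes/todo.txt", "hello")  # creates out/notes/ as needed
write_file("top_level.txt", "hi")          # no mkdir needed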
3
forge/forge/components/git_operations/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .git_operations import GitOperationsComponent

__all__ = ["GitOperationsComponent"]
60
forge/forge/components/git_operations/git_operations.py
Normal file
@@ -0,0 +1,60 @@
from pathlib import Path
from typing import Iterator

from git.repo import Repo

from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import CommandExecutionError
from forge.utils.url_validator import validate_url


class GitOperationsComponent(CommandProvider):
    """Provides commands to perform Git operations."""

    def __init__(self, config: Config):
        self._enabled = bool(config.github_username and config.github_api_key)
        self._disabled_reason = "Configure github_username and github_api_key."
        self.legacy_config = config

    def get_commands(self) -> Iterator[Command]:
        yield self.clone_repository

    @command(
        parameters={
            "url": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The URL of the repository to clone",
                required=True,
            ),
            "clone_path": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The path to clone the repository to",
                required=True,
            ),
        },
    )
    @validate_url
    def clone_repository(self, url: str, clone_path: Path) -> str:
        """Clone a GitHub repository locally.

        Args:
            url (str): The URL of the repository to clone.
            clone_path (Path): The path to clone the repository to.

        Returns:
            str: The result of the clone operation.
        """
        split_url = url.split("//")
        auth_repo_url = (
            f"//{self.legacy_config.github_username}:"
            f"{self.legacy_config.github_api_key}@".join(split_url)
        )
        try:
            Repo.clone_from(url=auth_repo_url, to_path=clone_path)
        except Exception as e:
            raise CommandExecutionError(f"Could not clone repo: {e}")

        return f"Cloned {url} to {clone_path}"
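`clone_repository` splices credentials into the clone URL by splitting on "//" and re-joining with "username:token@" in between (adjacent string literals concatenate before `.join` is applied). The transformation in isolation, with obviously fake credentials:

def with_auth(url: str, username: str, token: str) -> str:
    split_url = url.split("//")
    # ["https:", "github.com/owner/repo"] joined with "//user:token@"
    return f"//{username}:{token}@".join(split_url)

print(with_auth("https://github.com/owner/repo", "user", "t0ken"))
# https://user:t0ken@github.com/owner/repo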
3
forge/forge/components/image_gen/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .image_gen import ImageGeneratorComponent

__all__ = ["ImageGeneratorComponent"]
239
forge/forge/components/image_gen/image_gen.py
Normal file
@@ -0,0 +1,239 @@
"""Commands to generate images based on text input"""

import io
import json
import logging
import time
import uuid
from base64 import b64decode
from pathlib import Path
from typing import Iterator

import requests
from openai import OpenAI
from PIL import Image

from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.file_storage import FileStorage
from forge.models.json_schema import JSONSchema

logger = logging.getLogger(__name__)


class ImageGeneratorComponent(CommandProvider):
    """A component that provides commands to generate images from text prompts."""

    def __init__(self, workspace: FileStorage, config: Config):
        self._enabled = bool(config.image_provider)
        self._disabled_reason = "No image provider set."
        self.workspace = workspace
        self.legacy_config = config

    def get_commands(self) -> Iterator[Command]:
        if (
            self.legacy_config.openai_credentials
            or self.legacy_config.huggingface_api_token
            or self.legacy_config.sd_webui_auth
        ):
            yield self.generate_image

    @command(
        parameters={
            "prompt": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The prompt used to generate the image",
                required=True,
            ),
            "size": JSONSchema(
                type=JSONSchema.Type.INTEGER,
                description="The size of the image",
                required=False,
            ),
        },
    )
    def generate_image(self, prompt: str, size: int) -> str:
        """Generate an image from a prompt.

        Args:
            prompt (str): The prompt to use
            size (int, optional): The size of the image. Defaults to 256.
                Not supported by HuggingFace.

        Returns:
            str: The filename of the image
        """
        filename = self.workspace.root / f"{str(uuid.uuid4())}.jpg"
        cfg = self.legacy_config

        if cfg.openai_credentials and (
            cfg.image_provider == "dalle"
            or not (cfg.huggingface_api_token or cfg.sd_webui_url)
        ):
            return self.generate_image_with_dalle(prompt, filename, size)

        elif cfg.huggingface_api_token and (
            cfg.image_provider == "huggingface"
            or not (cfg.openai_credentials or cfg.sd_webui_url)
        ):
            return self.generate_image_with_hf(prompt, filename)

        elif cfg.sd_webui_url and (
            cfg.image_provider == "sdwebui" or cfg.sd_webui_auth
        ):
            return self.generate_image_with_sd_webui(prompt, filename, size)

        return "Error: No image generation provider available"

    def generate_image_with_hf(self, prompt: str, output_file: Path) -> str:
        """Generate an image with HuggingFace's API.

        Args:
            prompt (str): The prompt to use
            output_file (Path): The file to save the image to

        Returns:
            str: The filename of the image
        """
        API_URL = f"https://api-inference.huggingface.co/models/{self.legacy_config.huggingface_image_model}"  # noqa: E501
        if self.legacy_config.huggingface_api_token is None:
            raise ValueError(
                "You need to set your Hugging Face API token in the config file."
            )
        headers = {
            "Authorization": f"Bearer {self.legacy_config.huggingface_api_token}",
            "X-Use-Cache": "false",
        }

        retry_count = 0
        while retry_count < 10:
            response = requests.post(
                API_URL,
                headers=headers,
                json={
                    "inputs": prompt,
                },
            )

            if response.ok:
                try:
                    image = Image.open(io.BytesIO(response.content))
                    logger.info(f"Image Generated for prompt: {prompt}")
                    image.save(output_file)
                    return f"Saved to disk: {output_file}"
                except Exception as e:
                    logger.error(e)
                    break
            else:
                try:
                    error = json.loads(response.text)
                    if "estimated_time" in error:
                        delay = error["estimated_time"]
                        logger.debug(response.text)
                        logger.info(f"Retrying in {delay} seconds...")
                        time.sleep(delay)
                    else:
                        break
                except Exception as e:
                    logger.error(e)
                    break

            retry_count += 1

        return "Error creating image."

    def generate_image_with_dalle(
        self, prompt: str, output_file: Path, size: int
    ) -> str:
        """Generate an image with DALL-E.

        Args:
            prompt (str): The prompt to use
            output_file (Path): The file to save the image to
            size (int): The size of the image

        Returns:
            str: The filename of the image
        """
        assert self.legacy_config.openai_credentials  # otherwise this tool is disabled

        # Check for supported image sizes
        if size not in [256, 512, 1024]:
            closest = min([256, 512, 1024], key=lambda x: abs(x - size))
            logger.info(
                "DALL-E only supports image sizes of 256x256, 512x512, or 1024x1024. "
                f"Setting to {closest}, was {size}."
            )
            size = closest

        # TODO: integrate in `forge.llm.providers`(?)
        response = OpenAI(
            api_key=self.legacy_config.openai_credentials.api_key.get_secret_value()
        ).images.generate(
            prompt=prompt,
            n=1,
            # TODO: improve typing of size config item(s)
            size=f"{size}x{size}",  # type: ignore
            response_format="b64_json",
        )
        assert response.data[0].b64_json is not None  # response_format = "b64_json"

        logger.info(f"Image Generated for prompt: {prompt}")

        image_data = b64decode(response.data[0].b64_json)

        with open(output_file, mode="wb") as png:
            png.write(image_data)

        return f"Saved to disk: {output_file}"

    def generate_image_with_sd_webui(
        self,
        prompt: str,
        output_file: Path,
        size: int = 512,
        negative_prompt: str = "",
        extra: dict = {},
    ) -> str:
        """Generate an image with Stable Diffusion WebUI.

        Args:
            prompt (str): The prompt to use
            output_file (Path): The file to save the image to
            size (int, optional): The size of the image. Defaults to 512.
            negative_prompt (str, optional): The negative prompt to use. Defaults to "".
            extra (dict, optional): Extra parameters to pass to the API. Defaults to {}.

        Returns:
            str: The filename of the image
        """
        # Create a session and set the basic auth if needed
        s = requests.Session()
        if self.legacy_config.sd_webui_auth:
            username, password = self.legacy_config.sd_webui_auth.split(":")
            s.auth = (username, password or "")

        # Generate the images
        response = requests.post(
            f"{self.legacy_config.sd_webui_url}/sdapi/v1/txt2img",
            json={
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "sampler_index": "DDIM",
                "steps": 20,
                "cfg_scale": 7.0,
                "width": size,
                "height": size,
                "n_iter": 1,
                **extra,
            },
        )

        logger.info(f"Image Generated for prompt: '{prompt}'")

        # Save the image to disk
        response = response.json()
        b64 = b64decode(response["images"][0].split(",", 1)[0])
        image = Image.open(io.BytesIO(b64))
        image.save(output_file)

        return f"Saved to disk: {output_file}"
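When an unsupported size is requested, `generate_image_with_dalle` snaps it to the nearest supported square resolution with `min(..., key=...)` over the absolute distance. The snapping logic by itself:

SUPPORTED = [256, 512, 1024]

def snap_size(size: int) -> int:
    # Pick the supported size with the smallest absolute distance to the request
    return min(SUPPORTED, key=lambda x: abs(x - size))

print(snap_size(300))   # 256
print(snap_size(400))   # 512
print(snap_size(2000))  # 1024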
3
forge/forge/components/system/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .system import SystemComponent

__all__ = ["SystemComponent"]
79
forge/forge/components/system/system.py
Normal file
@@ -0,0 +1,79 @@
import logging
import time
from typing import Iterator

from forge.agent.protocols import CommandProvider, DirectiveProvider, MessageProvider
from forge.command import Command, command
from forge.llm.providers import ChatMessage
from forge.models.json_schema import JSONSchema
from forge.utils.const import FINISH_COMMAND
from forge.utils.exceptions import AgentFinished

logger = logging.getLogger(__name__)


class SystemComponent(DirectiveProvider, MessageProvider, CommandProvider):
    """Component for system messages and commands."""

    def get_constraints(self) -> Iterator[str]:
        yield "Exclusively use the commands listed below."
        yield (
            "You can only act proactively, and are unable to start background jobs or "
            "set up webhooks for yourself. "
            "Take this into account when planning your actions."
        )
        yield (
            "You are unable to interact with physical objects. "
            "If this is absolutely necessary to fulfill a task or objective or "
            "to complete a step, you must ask the user to do it for you. "
            "If the user refuses this, and there is no other way to achieve your "
            "goals, you must terminate to avoid wasting time and energy."
        )

    def get_resources(self) -> Iterator[str]:
        yield (
            "You are a Large Language Model, trained on millions of pages of text, "
            "including a lot of factual knowledge. Make use of this factual knowledge "
            "to avoid unnecessary gathering of information."
        )

    def get_best_practices(self) -> Iterator[str]:
        yield (
            "Continuously review and analyze your actions to ensure "
            "you are performing to the best of your abilities."
        )
        yield "Constructively self-criticize your big-picture behavior constantly."
        yield "Reflect on past decisions and strategies to refine your approach."
        yield (
            "Every command has a cost, so be smart and efficient. "
            "Aim to complete tasks in the least number of steps."
        )
        yield (
            "Only make use of your information gathering abilities to find "
            "information that you don't yet have knowledge of."
        )

    def get_messages(self) -> Iterator[ChatMessage]:
        # Clock
        yield ChatMessage.system(
            f"## Clock\nThe current time and date is {time.strftime('%c')}"
        )

    def get_commands(self) -> Iterator[Command]:
        yield self.finish

    @command(
        names=[FINISH_COMMAND],
        parameters={
            "reason": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="A summary to the user of how the goals were accomplished",
                required=True,
            ),
        },
    )
    def finish(self, reason: str):
        """Use this to shut down once you have completed your task,
        or when there are insurmountable problems that make it impossible
        for you to finish your task."""
        raise AgentFinished(reason)
3
forge/forge/components/user_interaction/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .user_interaction import UserInteractionComponent

__all__ = ["UserInteractionComponent"]
36
forge/forge/components/user_interaction/user_interaction.py
Normal file
@@ -0,0 +1,36 @@
from typing import Iterator

import click

from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.models.json_schema import JSONSchema
from forge.utils.const import ASK_COMMAND


class UserInteractionComponent(CommandProvider):
    """Provides commands to interact with the user."""

    def __init__(self, config: Config):
        self._enabled = not config.noninteractive_mode

    def get_commands(self) -> Iterator[Command]:
        yield self.ask_user

    @command(
        names=[ASK_COMMAND],
        parameters={
            "question": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The question or prompt to the user",
                required=True,
            )
        },
    )
    def ask_user(self, question: str) -> str:
        """If you need more details or information regarding the given task,
        you can ask the user for input."""
        print(f"\nQ: {question}")
        resp = click.prompt("A")
        return f"The user's answer: '{resp}'"
3
forge/forge/components/watchdog/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .watchdog import WatchdogComponent

__all__ = ["WatchdogComponent"]
63
forge/forge/components/watchdog/watchdog.py
Normal file
@@ -0,0 +1,63 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

from forge.agent.components import ComponentSystemError
from forge.agent.protocols import AfterParse
from forge.components.action_history import EpisodicActionHistory
from forge.models.action import AnyProposal

if TYPE_CHECKING:
    from forge.agent.base import BaseAgentConfiguration

logger = logging.getLogger(__name__)


class WatchdogComponent(AfterParse[AnyProposal]):
    """
    Adds a watchdog feature to an agent class. Whenever the agent starts
    looping, the watchdog will switch from the FAST_LLM to the SMART_LLM and re-think.
    """

    def __init__(
        self,
        config: "BaseAgentConfiguration",
        event_history: EpisodicActionHistory[AnyProposal],
    ):
        self.config = config
        self.event_history = event_history
        self.revert_big_brain = False

    def after_parse(self, result: AnyProposal) -> None:
        if self.revert_big_brain:
            self.config.big_brain = False
            self.revert_big_brain = False

        if not self.config.big_brain and self.config.fast_llm != self.config.smart_llm:
            previous_command, previous_command_args = None, None
            if len(self.event_history) > 1:
                # Detect repetitive commands
                previous_cycle = self.event_history.episodes[
                    self.event_history.cursor - 1
                ]
                previous_command = previous_cycle.action.use_tool.name
                previous_command_args = previous_cycle.action.use_tool.arguments

            rethink_reason = ""

            if not result.use_tool:
                rethink_reason = "AI did not specify a command"
            elif (
                result.use_tool.name == previous_command
                and result.use_tool.arguments == previous_command_args
            ):
                rethink_reason = f"Repetitive command detected ({result.use_tool.name})"

            if rethink_reason:
                logger.info(f"{rethink_reason}, re-thinking with SMART_LLM...")
                self.event_history.rewind()
                self.config.big_brain = True
                self.revert_big_brain = True
                # Trigger retry of all pipelines prior to this component
                raise ComponentSystemError(rethink_reason, self)
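The watchdog treats a proposal as looping when both the tool name and its arguments match the previous cycle, and then rewinds the history and retries with the smarter model. The detection predicate on its own, over plain (name, args) pairs:

def is_repetition(
    current: tuple[str, dict], previous: tuple[str, dict] | None
) -> bool:
    # A repeat means both the command name and its arguments are unchanged
    return previous is not None and current == previous

history = [("web_search", {"query": "foo"}), ("web_search", {"query": "foo"})]
print(is_repetition(history[-1], history[-2]))     # True
print(is_repetition(("finish", {}), history[-1]))  # False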
4
forge/forge/components/web/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .search import WebSearchComponent
from .selenium import BrowsingError, WebSeleniumComponent

__all__ = ["WebSearchComponent", "BrowsingError", "WebSeleniumComponent"]
194
forge/forge/components/web/search.py
Normal file
@@ -0,0 +1,194 @@
import json
import logging
import time
from typing import Iterator

from duckduckgo_search import DDGS

from forge.agent.protocols import CommandProvider, DirectiveProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import ConfigurationError

DUCKDUCKGO_MAX_ATTEMPTS = 3

logger = logging.getLogger(__name__)


class WebSearchComponent(DirectiveProvider, CommandProvider):
    """Provides commands to search the web."""

    def __init__(self, config: Config):
        self.legacy_config = config

        if (
            not self.legacy_config.google_api_key
            or not self.legacy_config.google_custom_search_engine_id
        ):
            logger.info(
                "Configure google_api_key and custom_search_engine_id "
                "to use Google API search."
            )

    def get_resources(self) -> Iterator[str]:
        yield "Internet access for searches and information gathering."

    def get_commands(self) -> Iterator[Command]:
        yield self.web_search

        if (
            self.legacy_config.google_api_key
            and self.legacy_config.google_custom_search_engine_id
        ):
            yield self.google

    @command(
        ["web_search", "search"],
        "Searches the web",
        {
            "query": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The search query",
                required=True,
            ),
            "num_results": JSONSchema(
                type=JSONSchema.Type.INTEGER,
                description="The number of results to return",
                minimum=1,
                maximum=10,
                required=False,
            ),
        },
    )
    def web_search(self, query: str, num_results: int = 8) -> str:
        """Return the results of a DuckDuckGo search

        Args:
            query (str): The search query.
            num_results (int): The number of results to return.

        Returns:
            str: The results of the search.
        """
        search_results = []
        attempts = 0

        while attempts < DUCKDUCKGO_MAX_ATTEMPTS:
            if not query:
                return json.dumps(search_results)

            search_results = DDGS().text(query, max_results=num_results)

            if search_results:
                break

            time.sleep(1)
            attempts += 1

        search_results = [
            {
                "title": r["title"],
                "url": r["href"],
                **({"excerpt": r["body"]} if r.get("body") else {}),
            }
            for r in search_results
        ]

        results = ("## Search results\n") + "\n\n".join(
            f"### \"{r['title']}\"\n"
            f"**URL:** {r['url']} \n"
            "**Excerpt:** " + (f'"{excerpt}"' if (excerpt := r.get("excerpt")) else "N/A")
            for r in search_results
        )
        return self.safe_google_results(results)

    @command(
        ["google"],
        "Google Search",
        {
            "query": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The search query",
                required=True,
            ),
            "num_results": JSONSchema(
                type=JSONSchema.Type.INTEGER,
                description="The number of results to return",
                minimum=1,
                maximum=10,
                required=False,
            ),
        },
    )
    def google(self, query: str, num_results: int = 8) -> str | list[str]:
        """Return the results of a Google search using the official Google API

        Args:
            query (str): The search query.
            num_results (int): The number of results to return.

        Returns:
            str: The results of the search.
        """

        from googleapiclient.discovery import build
        from googleapiclient.errors import HttpError

        try:
            # Get the Google API key and Custom Search Engine ID from the config file
            api_key = self.legacy_config.google_api_key
            custom_search_engine_id = self.legacy_config.google_custom_search_engine_id

            # Initialize the Custom Search API service
            service = build("customsearch", "v1", developerKey=api_key)

            # Send the search query and retrieve the results
            result = (
                service.cse()
                .list(q=query, cx=custom_search_engine_id, num=num_results)
                .execute()
            )

            # Extract the search result items from the response
            search_results = result.get("items", [])

            # Create a list of only the URLs from the search results
            search_results_links = [item["link"] for item in search_results]

        except HttpError as e:
            # Handle errors in the API call
            error_details = json.loads(e.content.decode())

            # Check if the error is related to an invalid or missing API key
            if error_details.get("error", {}).get(
                "code"
            ) == 403 and "invalid API key" in error_details.get("error", {}).get(
                "message", ""
            ):
                raise ConfigurationError(
                    "The provided Google API key is invalid or missing."
                )
            raise

        # Return the list of search result URLs
        return self.safe_google_results(search_results_links)

    def safe_google_results(self, results: str | list) -> str:
        """
        Return the results of a Google search in a safe format.

        Args:
            results (str | list): The search results.

        Returns:
            str: The results of the search.
        """
        if isinstance(results, list):
            safe_message = json.dumps(
                [result.encode("utf-8", "ignore").decode("utf-8") for result in results]
            )
        else:
            safe_message = results.encode("utf-8", "ignore").decode("utf-8")
        return safe_message
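`safe_google_results` sanitizes results by round-tripping every string through UTF-8 with errors ignored, which silently drops characters that cannot be encoded (such as lone surrogates) before the output reaches the LLM. The same sanitizing round-trip in isolation:

import json

def safe_results(results: str | list) -> str:
    if isinstance(results, list):
        return json.dumps(
            [r.encode("utf-8", "ignore").decode("utf-8") for r in results]
        )
    return results.encode("utf-8", "ignore").decode("utf-8")

# errors="ignore" drops the lone surrogate instead of raising
print(safe_results(["ok", "bad\ud800char"]))  # ["ok", "badchar"]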
374
forge/forge/components/web/selenium.py
Normal file
374
forge/forge/components/web/selenium.py
Normal file
@@ -0,0 +1,374 @@
import asyncio
import logging
import re
from pathlib import Path
from sys import platform
from typing import Iterator, Type
from urllib.request import urlretrieve

from bs4 import BeautifulSoup
from selenium.common.exceptions import WebDriverException
from selenium.webdriver.chrome.options import Options as ChromeOptions
from selenium.webdriver.chrome.service import Service as ChromeDriverService
from selenium.webdriver.chrome.webdriver import WebDriver as ChromeDriver
from selenium.webdriver.common.by import By
from selenium.webdriver.edge.options import Options as EdgeOptions
from selenium.webdriver.edge.service import Service as EdgeDriverService
from selenium.webdriver.edge.webdriver import WebDriver as EdgeDriver
from selenium.webdriver.firefox.options import Options as FirefoxOptions
from selenium.webdriver.firefox.service import Service as GeckoDriverService
from selenium.webdriver.firefox.webdriver import WebDriver as FirefoxDriver
from selenium.webdriver.remote.webdriver import WebDriver
from selenium.webdriver.safari.options import Options as SafariOptions
from selenium.webdriver.safari.webdriver import WebDriver as SafariDriver
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.wait import WebDriverWait
from webdriver_manager.chrome import ChromeDriverManager
from webdriver_manager.firefox import GeckoDriverManager
from webdriver_manager.microsoft import EdgeChromiumDriverManager as EdgeDriverManager

from forge.agent.protocols import CommandProvider, DirectiveProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.content_processing.html import extract_hyperlinks, format_hyperlinks
from forge.content_processing.text import extract_information, summarize_text
from forge.llm.providers import ChatModelInfo, MultiProvider
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import CommandExecutionError, TooMuchOutputError
from forge.utils.url_validator import validate_url

logger = logging.getLogger(__name__)

FILE_DIR = Path(__file__).parent.parent
MAX_RAW_CONTENT_LENGTH = 500
LINKS_TO_RETURN = 20


BrowserOptions = ChromeOptions | EdgeOptions | FirefoxOptions | SafariOptions


class BrowsingError(CommandExecutionError):
    """An error occurred while trying to browse the page"""


class WebSeleniumComponent(DirectiveProvider, CommandProvider):
    """Provides commands to browse the web using Selenium."""

    def __init__(
        self,
        config: Config,
        llm_provider: MultiProvider,
        model_info: ChatModelInfo,
    ):
        self.legacy_config = config
        self.llm_provider = llm_provider
        self.model_info = model_info

    def get_resources(self) -> Iterator[str]:
        yield "Ability to read websites."

    def get_commands(self) -> Iterator[Command]:
        yield self.read_webpage

    @command(
        ["read_webpage"],
        (
            "Read a webpage, and extract specific information from it."
            " You must specify either topics_of_interest,"
            " a question, or get_raw_content."
        ),
        {
            "url": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The URL to visit",
                required=True,
            ),
            "topics_of_interest": JSONSchema(
                type=JSONSchema.Type.ARRAY,
                items=JSONSchema(type=JSONSchema.Type.STRING),
                description=(
                    "A list of topics about which you want to extract information "
                    "from the page."
                ),
                required=False,
            ),
            "question": JSONSchema(
                type=JSONSchema.Type.STRING,
                description=(
                    "A question you want to answer using the content of the webpage."
                ),
                required=False,
            ),
            "get_raw_content": JSONSchema(
                type=JSONSchema.Type.BOOLEAN,
                description=(
                    "If true, the unprocessed content of the webpage will be returned. "
                    "This consumes a lot of tokens, so use it with caution."
                ),
                required=False,
            ),
        },
    )
    @validate_url
    async def read_webpage(
        self,
        url: str,
        *,
        topics_of_interest: list[str] = [],
        get_raw_content: bool = False,
        question: str = "",
    ) -> str:
        """Browse a website and return the requested information and links to the user

        Args:
            url (str): The URL of the website to browse
            topics_of_interest (list[str]): Topics to extract information about
            get_raw_content (bool): Whether to return the unprocessed page content
            question (str): The question to answer using the content of the webpage

        Returns:
            str: The extracted information or summary, followed by scraped links
        """
        driver = None
        try:
            driver = await self.open_page_in_browser(url, self.legacy_config)

            text = self.scrape_text_with_selenium(driver)
            links = self.scrape_links_with_selenium(driver, url)

            return_literal_content = True
            summarized = False
            if not text:
                return f"Website did not contain any text.\n\nLinks: {links}"
            elif get_raw_content:
                if (
                    output_tokens := self.llm_provider.count_tokens(
                        text, self.model_info.name
                    )
                ) > MAX_RAW_CONTENT_LENGTH:
                    oversize_factor = round(output_tokens / MAX_RAW_CONTENT_LENGTH, 1)
                    raise TooMuchOutputError(
                        f"Page content is {oversize_factor}x the allowed length "
                        "for `get_raw_content=true`"
                    )
                return text + (f"\n\nLinks: {links}" if links else "")
            else:
                text = await self.summarize_webpage(
                    text, question or None, topics_of_interest
                )
                return_literal_content = bool(question)
                summarized = True

            # Limit links to LINKS_TO_RETURN
            if len(links) > LINKS_TO_RETURN:
                links = links[:LINKS_TO_RETURN]

            text_fmt = f"'''{text}'''" if "\n" in text else f"'{text}'"
            links_fmt = "\n".join(f"- {link}" for link in links)
            return (
                f"Page content{' (summary)' if summarized else ''}:"
                if return_literal_content
                else "Answer gathered from webpage:"
            ) + f" {text_fmt}\n\nLinks:\n{links_fmt}"

        except WebDriverException as e:
            # These errors are often quite long and include lots of context.
            # Just grab the first line.
            msg = e.msg.split("\n")[0] if e.msg else str(e)
            if "net::" in msg:
                raise BrowsingError(
                    "A networking error occurred while trying to load the page: %s"
                    % re.sub(r"^unknown error: ", "", msg)
                )
            raise CommandExecutionError(msg)
        finally:
            if driver:
                driver.close()

    def scrape_text_with_selenium(self, driver: WebDriver) -> str:
        """Scrape text from a browser window using selenium

        Args:
            driver (WebDriver): A driver object representing
                the browser window to scrape

        Returns:
            str: the text scraped from the website
        """

        # Get the HTML content directly from the browser's DOM
        page_source = driver.execute_script("return document.body.outerHTML;")
        soup = BeautifulSoup(page_source, "html.parser")

        for script in soup(["script", "style"]):
            script.extract()

        text = soup.get_text()
        lines = (line.strip() for line in text.splitlines())
        chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
        text = "\n".join(chunk for chunk in chunks if chunk)
        return text

    def scrape_links_with_selenium(self, driver: WebDriver, base_url: str) -> list[str]:
        """Scrape links from a website using selenium

        Args:
            driver (WebDriver): A driver object representing
                the browser window to scrape
            base_url (str): The base URL to use for resolving relative links

        Returns:
            List[str]: The links scraped from the website
        """
        page_source = driver.page_source
        soup = BeautifulSoup(page_source, "html.parser")

        for script in soup(["script", "style"]):
            script.extract()

        hyperlinks = extract_hyperlinks(soup, base_url)

        return format_hyperlinks(hyperlinks)

    async def open_page_in_browser(self, url: str, config: Config) -> WebDriver:
        """Open a browser window and load a web page using Selenium

        Params:
            url (str): The URL of the page to load
            config (Config): The applicable application configuration

        Returns:
            driver (WebDriver): A driver object representing
                the browser window to scrape
        """
        logging.getLogger("selenium").setLevel(logging.CRITICAL)

        options_available: dict[str, Type[BrowserOptions]] = {
            "chrome": ChromeOptions,
            "edge": EdgeOptions,
            "firefox": FirefoxOptions,
            "safari": SafariOptions,
        }

        options: BrowserOptions = options_available[config.selenium_web_browser]()
        options.add_argument(f"user-agent={config.user_agent}")

        if isinstance(options, FirefoxOptions):
            if config.selenium_headless:
                options.headless = True  # type: ignore
                options.add_argument("--disable-gpu")
            driver = FirefoxDriver(
                service=GeckoDriverService(GeckoDriverManager().install()),
                options=options,
            )
        elif isinstance(options, EdgeOptions):
            driver = EdgeDriver(
                service=EdgeDriverService(EdgeDriverManager().install()),
                options=options,
            )
        elif isinstance(options, SafariOptions):
            # Requires a bit more setup on the users end.
            # See https://developer.apple.com/documentation/webkit/testing_with_webdriver_in_safari  # noqa: E501
            driver = SafariDriver(options=options)
        elif isinstance(options, ChromeOptions):
            if platform == "linux" or platform == "linux2":
                options.add_argument("--disable-dev-shm-usage")
                options.add_argument("--remote-debugging-port=9222")

            options.add_argument("--no-sandbox")
            if config.selenium_headless:
                options.add_argument("--headless=new")
                options.add_argument("--disable-gpu")

            self._sideload_chrome_extensions(
                options, config.app_data_dir / "assets" / "crx"
            )

            if (chromium_driver_path := Path("/usr/bin/chromedriver")).exists():
                chrome_service = ChromeDriverService(str(chromium_driver_path))
            else:
                try:
                    chrome_driver = ChromeDriverManager().install()
                except AttributeError as e:
                    if "'NoneType' object has no attribute 'split'" in str(e):
                        # https://github.com/SergeyPirogov/webdriver_manager/issues/649
                        logger.critical(
                            "Connecting to browser failed:"
                            " is Chrome or Chromium installed?"
                        )
                    raise
                chrome_service = ChromeDriverService(chrome_driver)
            driver = ChromeDriver(service=chrome_service, options=options)

        driver.get(url)

        # Wait for page to be ready, sleep 2 seconds, wait again until page ready.
        # This allows the cookiewall squasher time to get rid of cookie walls.
        WebDriverWait(driver, 10).until(
            EC.presence_of_element_located((By.TAG_NAME, "body"))
        )
        await asyncio.sleep(2)
        WebDriverWait(driver, 10).until(
            EC.presence_of_element_located((By.TAG_NAME, "body"))
        )

        return driver

    def _sideload_chrome_extensions(
        self, options: ChromeOptions, dl_folder: Path
    ) -> None:
        crx_download_url_template = "https://clients2.google.com/service/update2/crx?response=redirect&prodversion=49.0&acceptformat=crx3&x=id%3D{crx_id}%26installsource%3Dondemand%26uc"  # noqa
        cookiewall_squasher_crx_id = "edibdbjcniadpccecjdfdjjppcpchdlm"
        adblocker_crx_id = "cjpalhdlnbpafiamejdnhcphjbkeiagm"

        # Make sure the target folder exists
        dl_folder.mkdir(parents=True, exist_ok=True)

        for crx_id in (cookiewall_squasher_crx_id, adblocker_crx_id):
            crx_path = dl_folder / f"{crx_id}.crx"
            if not crx_path.exists():
                logger.debug(f"Downloading CRX {crx_id}...")
                crx_download_url = crx_download_url_template.format(crx_id=crx_id)
                urlretrieve(crx_download_url, crx_path)
                logger.debug(f"Downloaded {crx_path.name}")
            options.add_extension(str(crx_path))

    async def summarize_webpage(
        self,
        text: str,
        question: str | None,
        topics_of_interest: list[str],
    ) -> str:
        """Summarize text using the LLM provider

        Args:
            text (str): The text to summarize
            question (str): The question to ask the model
            topics_of_interest (list[str]): Topics to extract information about

        Returns:
            str: The summary of the text
        """
        if not text:
            raise ValueError("No text to summarize")

        text_length = len(text)
        logger.debug(f"Web page content length: {text_length} characters")

        result = None
        information = None
        if topics_of_interest:
            information = await extract_information(
                text,
                topics_of_interest=topics_of_interest,
                llm_provider=self.llm_provider,
                config=self.legacy_config,
            )
            return "\n".join(f"* {i}" for i in information)
        else:
            result, _ = await summarize_text(
                text,
                question=question,
                llm_provider=self.llm_provider,
                config=self.legacy_config,
            )
            return result
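A hedged usage sketch for the command above; the component construction (Config, MultiProvider, ChatModelInfo) is assumed to happen elsewhere in the application and is not shown in this diff:

import asyncio

async def demo(component):  # component: a configured WebSeleniumComponent
    # Targeted extraction is much cheaper in tokens than get_raw_content=True
    result = await component.read_webpage(
        url="https://example.com",
        topics_of_interest=["pricing", "contact details"],
    )
    print(result)

# asyncio.run(demo(component))  # run with a constructed component instance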
14
forge/forge/config/__init__.py
Normal file
@@ -0,0 +1,14 @@
"""
This module contains configuration models and helpers for AutoGPT Forge.
"""
from .ai_directives import AIDirectives
from .ai_profile import AIProfile
from .config import Config, ConfigBuilder, assert_config_has_openai_api_key

__all__ = [
    "assert_config_has_openai_api_key",
    "AIProfile",
    "AIDirectives",
    "Config",
    "ConfigBuilder",
]
28
forge/forge/config/ai_directives.py
Normal file
@@ -0,0 +1,28 @@
from __future__ import annotations

import logging

from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)


class AIDirectives(BaseModel):
    """An object that contains the basic directives for the AI prompt.

    Attributes:
        constraints (list): A list of constraints that the AI should adhere to.
        resources (list): A list of resources that the AI can utilize.
        best_practices (list): A list of best practices that the AI should follow.
    """

    resources: list[str] = Field(default_factory=list)
    constraints: list[str] = Field(default_factory=list)
    best_practices: list[str] = Field(default_factory=list)

    def __add__(self, other: AIDirectives) -> AIDirectives:
        return AIDirectives(
            resources=self.resources + other.resources,
            constraints=self.constraints + other.constraints,
            best_practices=self.best_practices + other.best_practices,
        ).copy(deep=True)
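For illustration, the overloaded `+` above makes merging directive sets from multiple sources a one-liner; a short sketch assuming the import path shown in this diff:

from forge.config.ai_directives import AIDirectives

base = AIDirectives(constraints=["Stay within the workspace"])
extra = AIDirectives(best_practices=["Prefer small, verifiable steps"])

merged = base + extra  # returns a deep copy combining both directive sets
print(merged.constraints, merged.best_practices)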
27
forge/forge/config/ai_profile.py
Normal file
@@ -0,0 +1,27 @@
from pydantic import BaseModel, Field

DEFAULT_AI_NAME = "AutoGPT"
DEFAULT_AI_ROLE = (
    "a seasoned digital assistant: "
    "capable, intelligent, considerate and assertive. "
    "You have extensive research and development skills, and you don't shy "
    "away from writing some code to solve a problem. "
    "You are pragmatic and make the most out of the tools available to you."
)


class AIProfile(BaseModel):
    """
    Object to hold the AI's personality.

    Attributes:
        ai_name (str): The name of the AI.
        ai_role (str): The description of the AI's role.
        ai_goals (list): The list of objectives the AI is supposed to complete.
    """

    ai_name: str = DEFAULT_AI_NAME
    ai_role: str = DEFAULT_AI_ROLE
    """`ai_role` should fit in the following format: `You are {ai_name}, {ai_role}`"""
    ai_goals: list[str] = Field(default_factory=list[str])
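A small sketch of how the profile fields compose into a system prompt intro, per the `ai_role` format note above:

from forge.config.ai_profile import AIProfile

profile = AIProfile(ai_goals=["Summarize the latest release notes"])
print(f"You are {profile.ai_name}, {profile.ai_role}")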
266
forge/forge/config/config.py
Normal file
@@ -0,0 +1,266 @@
"""Configuration classes that store application state shared across scripts."""
from __future__ import annotations

import logging
import os
import re
from pathlib import Path
from typing import Any, Optional, Union

import click
from colorama import Fore
from pydantic import SecretStr, validator

import forge
from forge.file_storage import FileStorageBackendName
from forge.llm.providers import CHAT_MODELS, ModelName
from forge.llm.providers.openai import OpenAICredentials, OpenAIModelName
from forge.logging.config import LoggingConfig
from forge.models.config import Configurable, SystemSettings, UserConfigurable
from forge.speech.say import TTSConfig

logger = logging.getLogger(__name__)

PROJECT_ROOT = Path(forge.__file__).parent.parent
AZURE_CONFIG_FILE = Path("azure.yaml")

GPT_4_MODEL = OpenAIModelName.GPT4
GPT_3_MODEL = OpenAIModelName.GPT3


class Config(SystemSettings, arbitrary_types_allowed=True):
    name: str = "Auto-GPT configuration"
    description: str = "Default configuration for the Auto-GPT application."

    ########################
    # Application Settings #
    ########################
    project_root: Path = PROJECT_ROOT
    app_data_dir: Path = project_root / "data"
    skip_news: bool = False
    skip_reprompt: bool = False
    authorise_key: str = UserConfigurable(default="y", from_env="AUTHORISE_COMMAND_KEY")
    exit_key: str = UserConfigurable(default="n", from_env="EXIT_KEY")
    noninteractive_mode: bool = False

    # TTS configuration
    logging: LoggingConfig = LoggingConfig()
    tts_config: TTSConfig = TTSConfig()

    # File storage
    file_storage_backend: FileStorageBackendName = UserConfigurable(
        default=FileStorageBackendName.LOCAL, from_env="FILE_STORAGE_BACKEND"
    )

    ##########################
    # Agent Control Settings #
    ##########################
    # Model configuration
    fast_llm: ModelName = UserConfigurable(
        default=OpenAIModelName.GPT3,
        from_env="FAST_LLM",
    )
    smart_llm: ModelName = UserConfigurable(
        default=OpenAIModelName.GPT4_TURBO,
        from_env="SMART_LLM",
    )
    temperature: float = UserConfigurable(default=0, from_env="TEMPERATURE")
    openai_functions: bool = UserConfigurable(
        default=False, from_env=lambda: os.getenv("OPENAI_FUNCTIONS", "False") == "True"
    )
    embedding_model: str = UserConfigurable(
        default="text-embedding-3-small", from_env="EMBEDDING_MODEL"
    )
    browse_spacy_language_model: str = UserConfigurable(
        default="en_core_web_sm", from_env="BROWSE_SPACY_LANGUAGE_MODEL"
    )

    # Run loop configuration
    continuous_mode: bool = False
    continuous_limit: int = 0

    ############
    # Commands #
    ############
    # General
    disabled_commands: list[str] = UserConfigurable(
        default_factory=list,
        from_env=lambda: _safe_split(os.getenv("DISABLED_COMMANDS")),
    )

    # File ops
    restrict_to_workspace: bool = UserConfigurable(
        default=True,
        from_env=lambda: os.getenv("RESTRICT_TO_WORKSPACE", "True") == "True",
    )
    allow_downloads: bool = False

    # Shell commands
    shell_command_control: str = UserConfigurable(
        default="denylist", from_env="SHELL_COMMAND_CONTROL"
    )
    execute_local_commands: bool = UserConfigurable(
        default=False,
        from_env=lambda: os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True",
    )
    shell_denylist: list[str] = UserConfigurable(
        default_factory=lambda: ["sudo", "su"],
        from_env=lambda: _safe_split(
            os.getenv("SHELL_DENYLIST", os.getenv("DENY_COMMANDS"))
        ),
    )
    shell_allowlist: list[str] = UserConfigurable(
        default_factory=list,
        from_env=lambda: _safe_split(
            os.getenv("SHELL_ALLOWLIST", os.getenv("ALLOW_COMMANDS"))
        ),
    )

    # Text to image
    image_provider: Optional[str] = UserConfigurable(from_env="IMAGE_PROVIDER")
    huggingface_image_model: str = UserConfigurable(
        default="CompVis/stable-diffusion-v1-4", from_env="HUGGINGFACE_IMAGE_MODEL"
    )
    sd_webui_url: Optional[str] = UserConfigurable(
        default="http://localhost:7860", from_env="SD_WEBUI_URL"
    )
    image_size: int = UserConfigurable(default=256, from_env="IMAGE_SIZE")

    # Audio to text
    audio_to_text_provider: str = UserConfigurable(
        default="huggingface", from_env="AUDIO_TO_TEXT_PROVIDER"
    )
    huggingface_audio_to_text_model: Optional[str] = UserConfigurable(
        from_env="HUGGINGFACE_AUDIO_TO_TEXT_MODEL"
    )

    # Web browsing
    selenium_web_browser: str = UserConfigurable("chrome", from_env="USE_WEB_BROWSER")
    selenium_headless: bool = UserConfigurable(
        default=True, from_env=lambda: os.getenv("HEADLESS_BROWSER", "True") == "True"
    )
    user_agent: str = UserConfigurable(
        default="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",  # noqa: E501
        from_env="USER_AGENT",
    )

    ###############
    # Credentials #
    ###############
    # OpenAI
    openai_credentials: Optional[OpenAICredentials] = None
    azure_config_file: Optional[Path] = UserConfigurable(
        default=AZURE_CONFIG_FILE, from_env="AZURE_CONFIG_FILE"
    )

    # Github
    github_api_key: Optional[str] = UserConfigurable(from_env="GITHUB_API_KEY")
    github_username: Optional[str] = UserConfigurable(from_env="GITHUB_USERNAME")

    # Google
    google_api_key: Optional[str] = UserConfigurable(from_env="GOOGLE_API_KEY")
    google_custom_search_engine_id: Optional[str] = UserConfigurable(
        from_env="GOOGLE_CUSTOM_SEARCH_ENGINE_ID",
    )

    # Huggingface
    huggingface_api_token: Optional[str] = UserConfigurable(
        from_env="HUGGINGFACE_API_TOKEN"
    )

    # Stable Diffusion
    sd_webui_auth: Optional[str] = UserConfigurable(from_env="SD_WEBUI_AUTH")

    @validator("openai_functions")
    def validate_openai_functions(cls, v: bool, values: dict[str, Any]):
        if v:
            smart_llm = values["smart_llm"]
            assert CHAT_MODELS[smart_llm].has_function_call_api, (
                f"Model {smart_llm} does not support tool calling. "
                "Please disable OPENAI_FUNCTIONS or choose a suitable model."
            )
        return v


class ConfigBuilder(Configurable[Config]):
    default_settings = Config()

    @classmethod
    def build_config_from_env(cls, project_root: Path = PROJECT_ROOT) -> Config:
        """Initialize the Config class"""

        config = cls.build_agent_configuration()
        config.project_root = project_root

        # Make relative paths absolute
        for k in {
            "azure_config_file",  # TODO: move from project root
        }:
            setattr(config, k, project_root / getattr(config, k))

        if (
            config.openai_credentials
            and config.openai_credentials.api_type == SecretStr("azure")
            and (config_file := config.azure_config_file)
        ):
            config.openai_credentials.load_azure_config(config_file)

        return config


def assert_config_has_openai_api_key(config: Config) -> None:
    """Check if the OpenAI API key is set in config.py or as an environment variable."""
    key_pattern = r"^sk-(proj-)?\w{48}"
    openai_api_key = (
        config.openai_credentials.api_key.get_secret_value()
        if config.openai_credentials
        else ""
    )

    # If there's no credentials or empty API key, prompt the user to set it
    if not openai_api_key:
        logger.error(
            "Please set your OpenAI API key in .env or as an environment variable."
        )
        logger.info(
            "You can get your key from https://platform.openai.com/account/api-keys"
        )
        openai_api_key = click.prompt(
            "Please enter your OpenAI API key if you have it",
            default="",
            show_default=False,
        )
        openai_api_key = openai_api_key.strip()
        if re.search(key_pattern, openai_api_key):
            os.environ["OPENAI_API_KEY"] = openai_api_key
            if config.openai_credentials:
                config.openai_credentials.api_key = SecretStr(openai_api_key)
            else:
                config.openai_credentials = OpenAICredentials(
                    api_key=SecretStr(openai_api_key)
                )
            print("OpenAI API key successfully set!")
            print(
                f"{Fore.YELLOW}NOTE: The API key you've set is only temporary. "
                f"For longer sessions, please set it in the .env file{Fore.RESET}"
            )
        else:
            print(f"{Fore.RED}Invalid OpenAI API key{Fore.RESET}")
            exit(1)
    # If key is set, but it looks invalid
    elif not re.search(key_pattern, openai_api_key):
        logger.error(
            "Invalid OpenAI API key! "
            "Please set your OpenAI API key in .env or as an environment variable."
        )
        logger.info(
            "You can get your key from https://platform.openai.com/account/api-keys"
        )
        exit(1)


def _safe_split(s: Union[str, None], sep: str = ",") -> list[str]:
    """Split a string by a separator. Return an empty list if the string is None."""
    if s is None:
        return []
    return s.split(sep)
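A hedged sketch of the intended entry point, assuming environment variables (e.g. from a .env file) are already loaded into the process:

from forge.config.config import ConfigBuilder, assert_config_has_openai_api_key

config = ConfigBuilder.build_config_from_env()
assert_config_has_openai_api_key(config)  # prompts interactively if unset
print(config.smart_llm, config.selenium_web_browser)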
8
forge/forge/conftest.py
Normal file
@@ -0,0 +1,8 @@
from pathlib import Path

import pytest


@pytest.fixture()
def test_workspace(tmp_path: Path) -> Path:
    return tmp_path
0
forge/forge/content_processing/__init__.py
Normal file
33
forge/forge/content_processing/html.py
Normal file
@@ -0,0 +1,33 @@
"""HTML processing functions"""
from __future__ import annotations

from bs4 import BeautifulSoup
from requests.compat import urljoin


def extract_hyperlinks(soup: BeautifulSoup, base_url: str) -> list[tuple[str, str]]:
    """Extract hyperlinks from a BeautifulSoup object

    Args:
        soup (BeautifulSoup): The BeautifulSoup object
        base_url (str): The base URL

    Returns:
        List[Tuple[str, str]]: The extracted hyperlinks
    """
    return [
        (link.text, urljoin(base_url, link["href"]))
        for link in soup.find_all("a", href=True)
    ]


def format_hyperlinks(hyperlinks: list[tuple[str, str]]) -> list[str]:
    """Format hyperlinks to be displayed to the user

    Args:
        hyperlinks (List[Tuple[str, str]]): The hyperlinks to format

    Returns:
        List[str]: The formatted hyperlinks
    """
    return [f"{link_text.strip()} ({link_url})" for link_text, link_url in hyperlinks]
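Both helpers are pure functions over a parsed document, so they are easy to exercise in isolation; a runnable sketch:

from bs4 import BeautifulSoup

from forge.content_processing.html import extract_hyperlinks, format_hyperlinks

html = '<a href="/docs">Docs</a> <a href="https://example.com">Home</a>'
soup = BeautifulSoup(html, "html.parser")
links = extract_hyperlinks(soup, base_url="https://example.com")
print(format_hyperlinks(links))
# ['Docs (https://example.com/docs)', 'Home (https://example.com)']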
317
forge/forge/content_processing/text.py
Normal file
@@ -0,0 +1,317 @@
"""Text processing functions"""
from __future__ import annotations

import logging
import math
from typing import TYPE_CHECKING, Iterator, Optional, TypeVar

import spacy

if TYPE_CHECKING:
    from forge.config.config import Config

from forge.json.parsing import extract_list_from_json
from forge.llm.prompting import ChatPrompt
from forge.llm.providers import ChatMessage, ModelTokenizer, MultiProvider

logger = logging.getLogger(__name__)

T = TypeVar("T")


def batch(
    sequence: list[T], max_batch_length: int, overlap: int = 0
) -> Iterator[list[T]]:
    """
    Batch data from iterable into slices of length N. The last batch may be shorter.

    Example: `batch('ABCDEFGHIJ', 3)` --> `ABC DEF GHI J`
    """
    if max_batch_length < 1:
        raise ValueError("max_batch_length must be at least one")
    for i in range(0, len(sequence), max_batch_length - overlap):
        yield sequence[i : i + max_batch_length]


def chunk_content(
    content: str,
    max_chunk_length: int,
    tokenizer: ModelTokenizer,
    with_overlap: bool = True,
) -> Iterator[tuple[str, int]]:
    """Split content into chunks of approximately equal token length."""

    MAX_OVERLAP = 200  # limit overlap to save tokens

    tokenized_text = tokenizer.encode(content)
    total_length = len(tokenized_text)
    n_chunks = math.ceil(total_length / max_chunk_length)

    chunk_length = math.ceil(total_length / n_chunks)
    overlap = min(max_chunk_length - chunk_length, MAX_OVERLAP) if with_overlap else 0

    for token_batch in batch(tokenized_text, chunk_length + overlap, overlap):
        yield tokenizer.decode(token_batch), len(token_batch)


async def summarize_text(
    text: str,
    llm_provider: MultiProvider,
    config: Config,
    question: Optional[str] = None,
    instruction: Optional[str] = None,
) -> tuple[str, list[tuple[str, str]]]:
    if question:
        if instruction:
            raise ValueError(
                "Parameters 'question' and 'instruction' cannot both be set"
            )

        instruction = (
            f'From the text, answer the question: "{question}". '
            "If the answer is not in the text, indicate this clearly "
            "and concisely state why the text is not suitable to answer the question."
        )
    elif not instruction:
        instruction = (
            "Summarize or describe the text clearly and concisely, "
            "whichever seems more appropriate."
        )

    return await _process_text(  # type: ignore
        text=text,
        instruction=instruction,
        llm_provider=llm_provider,
        config=config,
    )


async def extract_information(
    source_text: str,
    topics_of_interest: list[str],
    llm_provider: MultiProvider,
    config: Config,
) -> list[str]:
    fmt_topics_list = "\n".join(f"* {topic}." for topic in topics_of_interest)
    instruction = (
        "Extract relevant pieces of information about the following topics:\n"
        f"{fmt_topics_list}\n"
        "Reword pieces of information if needed to make them self-explanatory. "
        "Be concise.\n\n"
        "Respond with an `Array<string>` in JSON format AND NOTHING ELSE. "
        'If the text contains no relevant information, return "[]".'
    )
    return await _process_text(  # type: ignore
        text=source_text,
        instruction=instruction,
        output_type=list[str],
        llm_provider=llm_provider,
        config=config,
    )


async def _process_text(
    text: str,
    instruction: str,
    llm_provider: MultiProvider,
    config: Config,
    output_type: type[str | list[str]] = str,
) -> tuple[str, list[tuple[str, str]]] | list[str]:
    """Process text using the LLM provider for summarization or information extraction

    Params:
        text (str): The text to process.
        instruction (str): Additional instruction for processing.
        llm_provider: LLM provider to use.
        config (Config): The global application config.
        output_type: `str` for summaries or `list[str]` for piece-wise info extraction.

    Returns:
        For summarization: tuple[str, None | list[(summary, chunk)]]
        For piece-wise information extraction: list[str]
    """
    if not text.strip():
        raise ValueError("No content")

    model = config.fast_llm

    text_tlength = llm_provider.count_tokens(text, model)
    logger.debug(f"Text length: {text_tlength} tokens")

    max_result_tokens = 500
    max_chunk_length = llm_provider.get_token_limit(model) - max_result_tokens - 50
    logger.debug(f"Max chunk length: {max_chunk_length} tokens")

    if text_tlength < max_chunk_length:
        prompt = ChatPrompt(
            messages=[
                ChatMessage.system(
                    "The user is going to give you a text enclosed in triple quotes. "
                    f"{instruction}"
                ),
                ChatMessage.user(f'"""{text}"""'),
            ]
        )

        logger.debug(f"PROCESSING:\n{prompt}")

        response = await llm_provider.create_chat_completion(
            model_prompt=prompt.messages,
            model_name=model,
            temperature=0.5,
            max_output_tokens=max_result_tokens,
            completion_parser=lambda s: (
                extract_list_from_json(s.content) if output_type is not str else None
            ),
        )

        if isinstance(response.parsed_result, list):
            logger.debug(f"Raw LLM response: {repr(response.response.content)}")
            fmt_result_bullet_list = "\n".join(f"* {r}" for r in response.parsed_result)
            logger.debug(
                f"\n{'-'*11} EXTRACTION RESULT {'-'*12}\n"
                f"{fmt_result_bullet_list}\n"
                f"{'-'*42}\n"
            )
            return response.parsed_result
        else:
            summary = response.response.content
            logger.debug(f"\n{'-'*16} SUMMARY {'-'*17}\n{summary}\n{'-'*42}\n")
            return summary.strip(), [(summary, text)]
    else:
        chunks = list(
            split_text(
                text,
                config=config,
                max_chunk_length=max_chunk_length,
                tokenizer=llm_provider.get_tokenizer(model),
            )
        )

        processed_results = []
        for i, (chunk, _) in enumerate(chunks):
            logger.info(f"Processing chunk {i + 1} / {len(chunks)}")
            chunk_result = await _process_text(
                text=chunk,
                instruction=instruction,
                output_type=output_type,
                llm_provider=llm_provider,
                config=config,
            )
            processed_results.extend(
                chunk_result if output_type == list[str] else [chunk_result]
            )

        if output_type == list[str]:
            return processed_results
        else:
            summary, _ = await _process_text(
                "\n\n".join([result[0] for result in processed_results]),
                instruction=(
                    "The text consists of multiple partial summaries. "
                    "Combine these partial summaries into one."
                ),
                llm_provider=llm_provider,
                config=config,
            )
            return summary.strip(), [
                (processed_results[i], chunks[i][0]) for i in range(0, len(chunks))
            ]


def split_text(
    text: str,
    config: Config,
    max_chunk_length: int,
    tokenizer: ModelTokenizer,
    with_overlap: bool = True,
) -> Iterator[tuple[str, int]]:
    """
    Split text into chunks of sentences, with each chunk not exceeding the max length.

    Args:
        text (str): The text to split.
        config (Config): Config object containing the Spacy model setting.
        max_chunk_length (int, optional): The maximum length of a chunk.
        tokenizer (ModelTokenizer): Tokenizer to use for determining chunk length.
        with_overlap (bool, optional): Whether to allow overlap between chunks.

    Yields:
        str: The next chunk of text

    Raises:
        ValueError: when a sentence is longer than the maximum length
    """
    text_length = len(tokenizer.encode(text))

    if text_length < max_chunk_length:
        yield text, text_length
        return

    n_chunks = math.ceil(text_length / max_chunk_length)
    target_chunk_length = math.ceil(text_length / n_chunks)

    nlp: spacy.language.Language = spacy.load(config.browse_spacy_language_model)
    nlp.add_pipe("sentencizer")
    doc = nlp(text)
    sentences = [sentence.text.strip() for sentence in doc.sents]

    current_chunk: list[str] = []
    current_chunk_length = 0
    last_sentence = None
    last_sentence_length = 0

    i = 0
    while i < len(sentences):
        sentence = sentences[i]
        sentence_length = len(tokenizer.encode(sentence))
        expected_chunk_length = current_chunk_length + 1 + sentence_length

        if (
            expected_chunk_length < max_chunk_length
            # try to create chunks of approximately equal size
            and expected_chunk_length - (sentence_length / 2) < target_chunk_length
        ):
            current_chunk.append(sentence)
            current_chunk_length = expected_chunk_length

        elif sentence_length < max_chunk_length:
            if last_sentence:
                yield " ".join(current_chunk), current_chunk_length
                current_chunk = []
                current_chunk_length = 0

                if with_overlap:
                    overlap_max_length = max_chunk_length - sentence_length - 1
                    if last_sentence_length < overlap_max_length:
                        current_chunk += [last_sentence]
                        current_chunk_length += last_sentence_length + 1
                    elif overlap_max_length > 5:
                        # add as much from the end of the last sentence as fits
                        current_chunk += [
                            list(
                                chunk_content(
                                    content=last_sentence,
                                    max_chunk_length=overlap_max_length,
                                    tokenizer=tokenizer,
                                )
                            ).pop()[0],
                        ]
                        current_chunk_length += overlap_max_length + 1

            current_chunk += [sentence]
            current_chunk_length += sentence_length

        else:  # sentence longer than maximum length -> chop up and try again
            sentences[i : i + 1] = [
                chunk
                for chunk, _ in chunk_content(sentence, target_chunk_length, tokenizer)
            ]
            continue

        i += 1
        last_sentence = sentence
        last_sentence_length = sentence_length

    if current_chunk:
        yield " ".join(current_chunk), current_chunk_length
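To see how `overlap` feeds into chunk_content, here is a small sketch of `batch` on its own (import path as in this diff):

from forge.content_processing.text import batch

# With max_batch_length=4 and overlap=1, each slice starts on the last
# element of the previous one: ['ABCD', 'DEFG', 'GHIJ', 'J']
print(["".join(b) for b in batch(list("ABCDEFGHIJ"), 4, overlap=1)])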
37
forge/forge/file_storage/__init__.py
Normal file
@@ -0,0 +1,37 @@
import enum
from pathlib import Path

from .base import FileStorage


class FileStorageBackendName(str, enum.Enum):
    LOCAL = "local"
    GCS = "gcs"
    S3 = "s3"


def get_storage(
    backend: FileStorageBackendName,
    root_path: Path = Path("."),
    restrict_to_root: bool = True,
) -> FileStorage:
    match backend:
        case FileStorageBackendName.LOCAL:
            from .local import FileStorageConfiguration, LocalFileStorage

            config = FileStorageConfiguration.from_env()
            config.root = root_path
            config.restrict_to_root = restrict_to_root
            return LocalFileStorage(config)
        case FileStorageBackendName.S3:
            from .s3 import S3FileStorage, S3FileStorageConfiguration

            config = S3FileStorageConfiguration.from_env()
            config.root = root_path
            return S3FileStorage(config)
        case FileStorageBackendName.GCS:
            from .gcs import GCSFileStorage, GCSFileStorageConfiguration

            config = GCSFileStorageConfiguration.from_env()
            config.root = root_path
            return GCSFileStorage(config)
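A usage sketch for the factory above, using the local backend (S3 and GCS require their respective credentials):

from pathlib import Path

from forge.file_storage import FileStorageBackendName, get_storage

storage = get_storage(FileStorageBackendName.LOCAL, root_path=Path("./data"))
storage.initialize()  # for the local backend, creates ./data if missing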
283
forge/forge/file_storage/base.py
Normal file
@@ -0,0 +1,283 @@
"""
The FileStorage class provides an interface for interacting with a file storage.
"""

from __future__ import annotations

import asyncio
import logging
import os
import shutil
import tempfile
from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Any, BinaryIO, Callable, Generator, Literal, TextIO, overload

from watchdog.events import FileSystemEvent, FileSystemEventHandler
from watchdog.observers import Observer

from forge.models.config import SystemConfiguration

logger = logging.getLogger(__name__)


class FileStorageConfiguration(SystemConfiguration):
    restrict_to_root: bool = True
    root: Path = Path("/")


class FileStorage(ABC):
    """A class that represents a file storage."""

    on_write_file: Callable[[Path], Any] | None = None
    """
    Event hook, executed after writing a file.

    Params:
        Path: The path of the file that was written, relative to the storage root.
    """

    @property
    @abstractmethod
    def root(self) -> Path:
        """The root path of the file storage."""

    @property
    @abstractmethod
    def restrict_to_root(self) -> bool:
        """Whether to restrict file access to within the storage's root path."""

    @property
    @abstractmethod
    def is_local(self) -> bool:
        """Whether the storage is local (i.e. on the same machine, not cloud-based)."""

    @abstractmethod
    def initialize(self) -> None:
        """
        Calling `initialize()` should bring the storage to a ready-to-use state.
        For example, it can create the resource in which files will be stored, if it
        doesn't exist yet. E.g. a folder on disk, or an S3 Bucket.
        """

    @overload
    @abstractmethod
    def open_file(
        self,
        path: str | Path,
        mode: Literal["r", "w"] = "r",
        binary: Literal[False] = False,
    ) -> TextIO:
        """Returns a readable text file-like object representing the file."""

    @overload
    @abstractmethod
    def open_file(
        self, path: str | Path, mode: Literal["r", "w"], binary: Literal[True]
    ) -> BinaryIO:
        """Returns a binary file-like object representing the file."""

    @overload
    @abstractmethod
    def open_file(self, path: str | Path, *, binary: Literal[True]) -> BinaryIO:
        """Returns a readable binary file-like object representing the file."""

    @overload
    @abstractmethod
    def open_file(
        self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
    ) -> TextIO | BinaryIO:
        """Returns a file-like object representing the file."""

    @overload
    @abstractmethod
    def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
        """Read a file in the storage as text."""
        ...

    @overload
    @abstractmethod
    def read_file(self, path: str | Path, binary: Literal[True]) -> bytes:
        """Read a file in the storage as binary."""
        ...

    @overload
    @abstractmethod
    def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
        """Read a file in the storage."""
        ...

    @abstractmethod
    async def write_file(self, path: str | Path, content: str | bytes) -> None:
        """Write to a file in the storage."""

    @abstractmethod
    def list_files(self, path: str | Path = ".") -> list[Path]:
        """List all files (recursively) in a directory in the storage."""

    @abstractmethod
    def list_folders(
        self, path: str | Path = ".", recursive: bool = False
    ) -> list[Path]:
        """List all folders in a directory in the storage."""

    @abstractmethod
    def delete_file(self, path: str | Path) -> None:
        """Delete a file in the storage."""

    @abstractmethod
    def delete_dir(self, path: str | Path) -> None:
        """Delete an empty folder in the storage."""

    @abstractmethod
    def exists(self, path: str | Path) -> bool:
        """Check if a file or folder exists in the storage."""

    @abstractmethod
    def rename(self, old_path: str | Path, new_path: str | Path) -> None:
        """Rename a file or folder in the storage."""

    @abstractmethod
    def copy(self, source: str | Path, destination: str | Path) -> None:
        """Copy a file or folder with all contents in the storage."""

    @abstractmethod
    def make_dir(self, path: str | Path) -> None:
        """Create a directory in the storage if it doesn't exist."""

    @abstractmethod
    def clone_with_subroot(self, subroot: str | Path) -> FileStorage:
        """Create a new FileStorage with a subroot of the current storage."""

    def get_path(self, relative_path: str | Path) -> Path:
        """Get the full path for an item in the storage.

        Parameters:
            relative_path: The relative path to resolve in the storage.

        Returns:
            Path: The resolved path relative to the storage.
        """
        return self._sanitize_path(relative_path)

    @contextmanager
    def mount(self, path: str | Path = ".") -> Generator[Path, Any, None]:
        """Mount the file storage and provide a local path."""
        local_path = Path(tempfile.mkdtemp(dir=path))

        observer = Observer()
        try:
            # Copy all files to the local directory
            files = self.list_files()
            for file in files:
                file_path = local_path / file
                file_path.parent.mkdir(parents=True, exist_ok=True)
                content = self.read_file(file, binary=True)
                file_path.write_bytes(content)

            # Sync changes
            event_handler = FileSyncHandler(self, local_path)
            observer.schedule(event_handler, local_path, recursive=True)
            observer.start()

            yield local_path
        finally:
            observer.stop()
            observer.join()
            shutil.rmtree(local_path)

    def _sanitize_path(
        self,
        path: str | Path,
    ) -> Path:
        """Resolve the relative path within the given root if possible.

        Parameters:
            path: The relative path to resolve.

        Returns:
            Path: The resolved path.

        Raises:
            ValueError: If the path is absolute and a root is provided.
            ValueError: If the path is outside the root and the root is restricted.
        """

        # Posix systems disallow null bytes in paths. Windows is agnostic about it.
        # Do an explicit check here for all sorts of null byte representations.
        if "\0" in str(path):
            raise ValueError("Embedded null byte")

        logger.debug(f"Resolving path '{path}' in storage '{self.root}'")

        relative_path = Path(path)

        # Allow absolute paths if they are contained in the storage.
        if (
            relative_path.is_absolute()
            and self.restrict_to_root
            and not relative_path.is_relative_to(self.root)
        ):
            raise ValueError(
                f"Attempted to access absolute path '{relative_path}' "
                f"in storage '{self.root}'"
            )

        full_path = self.root / relative_path
        if self.is_local:
            full_path = full_path.resolve()
        else:
            full_path = Path(os.path.normpath(full_path))

        logger.debug(f"Joined paths as '{full_path}'")

        if self.restrict_to_root and not full_path.is_relative_to(self.root):
            raise ValueError(
                f"Attempted to access path '{full_path}' "
                f"outside of storage '{self.root}'."
            )

        return full_path


class FileSyncHandler(FileSystemEventHandler):
    def __init__(self, storage: FileStorage, path: str | Path = "."):
        self.storage = storage
        self.path = Path(path)

    def on_modified(self, event: FileSystemEvent):
        if event.is_directory:
            return

        # Read from the absolute source path; write back under the relative path
        source_path = Path(event.src_path)
        file_path = source_path.relative_to(self.path)
        content = source_path.read_bytes()
        # Must execute write_file synchronously because the hook is synchronous
        # TODO: Schedule write operation using asyncio.create_task (non-blocking)
        asyncio.get_event_loop().run_until_complete(
            self.storage.write_file(file_path, content)
        )

    def on_created(self, event: FileSystemEvent):
        if event.is_directory:
            self.storage.make_dir(event.src_path)
            return

        source_path = Path(event.src_path)
        file_path = source_path.relative_to(self.path)
        content = source_path.read_bytes()
        # Must execute write_file synchronously because the hook is synchronous
        # TODO: Schedule write operation using asyncio.create_task (non-blocking)
        asyncio.get_event_loop().run_until_complete(
            self.storage.write_file(file_path, content)
        )

    def on_deleted(self, event: FileSystemEvent):
        if event.is_directory:
            self.storage.delete_dir(event.src_path)
            return

        file_path = event.src_path
        self.storage.delete_file(file_path)

    def on_moved(self, event: FileSystemEvent):
        self.storage.rename(event.src_path, event.dest_path)
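The path sanitization above is the security boundary of the storage layer; a hedged sketch of its observable behavior through LocalFileStorage (defined later in this diff):

from pathlib import Path

from forge.file_storage.base import FileStorageConfiguration
from forge.file_storage.local import LocalFileStorage

storage = LocalFileStorage(FileStorageConfiguration(root=Path("./data")))
print(storage.get_path("notes/a.txt"))  # resolved inside ./data
try:
    storage.get_path("../escape.txt")  # outside the root -> rejected
except ValueError as e:
    print(e)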
267
forge/forge/file_storage/gcs.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""
|
||||
The GCSWorkspace class provides an interface for interacting with a file workspace, and
|
||||
stores the files in a Google Cloud Storage bucket.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from io import TextIOWrapper
|
||||
from pathlib import Path
|
||||
from typing import Literal, overload
|
||||
|
||||
from google.cloud import storage
|
||||
from google.cloud.exceptions import NotFound
|
||||
from google.cloud.storage.fileio import BlobReader, BlobWriter
|
||||
|
||||
from forge.models.config import UserConfigurable
|
||||
|
||||
from .base import FileStorage, FileStorageConfiguration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GCSFileStorageConfiguration(FileStorageConfiguration):
|
||||
bucket: str = UserConfigurable("autogpt", from_env="STORAGE_BUCKET")
|
||||
|
||||
|
||||
class GCSFileStorage(FileStorage):
|
||||
"""A class that represents a Google Cloud Storage."""
|
||||
|
||||
_bucket: storage.Bucket
|
||||
|
||||
def __init__(self, config: GCSFileStorageConfiguration):
|
||||
self._bucket_name = config.bucket
|
||||
self._root = config.root
|
||||
# Add / at the beginning of the root path
|
||||
if not self._root.is_absolute():
|
||||
self._root = Path("/").joinpath(self._root)
|
||||
|
||||
self._gcs = storage.Client()
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def root(self) -> Path:
|
||||
"""The root directory of the file storage."""
|
||||
return self._root
|
||||
|
||||
@property
|
||||
def restrict_to_root(self) -> bool:
|
||||
"""Whether to restrict generated paths to the root."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_local(self) -> bool:
|
||||
"""Whether the storage is local (i.e. on the same machine, not cloud-based)."""
|
||||
return False
|
||||
|
||||
def initialize(self) -> None:
|
||||
logger.debug(f"Initializing {repr(self)}...")
|
||||
try:
|
||||
self._bucket = self._gcs.get_bucket(self._bucket_name)
|
||||
except NotFound:
|
||||
logger.info(f"Bucket '{self._bucket_name}' does not exist; creating it...")
|
||||
self._bucket = self._gcs.create_bucket(self._bucket_name)
|
||||
|
||||
def get_path(self, relative_path: str | Path) -> Path:
|
||||
# We set GCS root with "/" at the beginning
|
||||
# but relative_to("/") will remove it
|
||||
# because we don't actually want it in the storage filenames
|
||||
return super().get_path(relative_path).relative_to("/")
|
||||
|
||||
def _get_blob(self, path: str | Path) -> storage.Blob:
|
||||
path = self.get_path(path)
|
||||
return self._bucket.blob(str(path))
|
||||
|
||||
@overload
|
||||
def open_file(
|
||||
self,
|
||||
path: str | Path,
|
||||
mode: Literal["r", "w"] = "r",
|
||||
binary: Literal[False] = False,
|
||||
) -> TextIOWrapper:
|
||||
...
|
||||
|
||||
@overload
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["r"], binary: Literal[True]
|
||||
) -> BlobReader:
|
||||
...
|
||||
|
||||
@overload
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["w"], binary: Literal[True]
|
||||
) -> BlobWriter:
|
||||
...
|
||||
|
||||
@overload
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["r", "w"], binary: Literal[True]
|
||||
) -> BlobWriter | BlobReader:
|
||||
...
|
||||
|
||||
@overload
|
||||
def open_file(self, path: str | Path, *, binary: Literal[True]) -> BlobReader:
|
||||
...
|
||||
|
||||
@overload
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
|
||||
) -> BlobReader | BlobWriter | TextIOWrapper:
|
||||
...
|
||||
|
||||
# https://github.com/microsoft/pyright/issues/8007
|
||||
def open_file( # pyright: ignore[reportIncompatibleMethodOverride]
|
||||
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
|
||||
) -> BlobReader | BlobWriter | TextIOWrapper:
|
||||
"""Open a file in the storage."""
|
||||
blob = self._get_blob(path)
|
||||
blob.reload() # pin revision number to prevent version mixing while reading
|
||||
return blob.open(f"{mode}b" if binary else mode)
|
||||
|
||||
@overload
|
||||
def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
|
||||
"""Read a file in the storage as text."""
|
||||
...
|
||||
|
||||
@overload
|
||||
def read_file(self, path: str | Path, binary: Literal[True]) -> bytes:
|
||||
"""Read a file in the storage as binary."""
|
||||
...
|
||||
|
||||
@overload
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the storage."""
|
||||
...
|
||||
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the storage."""
|
||||
return self.open_file(path, "r", binary).read()
|
||||
|
||||
async def write_file(self, path: str | Path, content: str | bytes) -> None:
|
||||
"""Write to a file in the storage."""
|
||||
blob = self._get_blob(path)
|
||||
|
||||
blob.upload_from_string(
|
||||
data=content,
|
||||
content_type=(
|
||||
"text/plain"
|
||||
if type(content) is str
|
||||
# TODO: get MIME type from file extension or binary content
|
||||
else "application/octet-stream"
|
||||
),
|
||||
)
|
||||
|
||||
if self.on_write_file:
|
||||
path = Path(path)
|
||||
if path.is_absolute():
|
||||
path = path.relative_to(self.root)
|
||||
res = self.on_write_file(path)
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
|
||||
def list_files(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files (recursively) in a directory in the storage."""
|
||||
        path = self.get_path(path)
        return [
            Path(blob.name).relative_to(path)
            for blob in self._bucket.list_blobs(
                prefix=f"{path}/" if path != Path(".") else None
            )
        ]

    def list_folders(
        self, path: str | Path = ".", recursive: bool = False
    ) -> list[Path]:
        """List 'directories' directly in a given path or recursively in the storage."""
        path = self.get_path(path)
        folder_names = set()

        # List objects with the specified prefix and delimiter
        for blob in self._bucket.list_blobs(prefix=path):
            # Remove path prefix and the object name (last part)
            folder = Path(blob.name).relative_to(path).parent
            if not folder or folder == Path("."):
                continue
            # For non-recursive, only add the first level of folders
            if not recursive:
                folder_names.add(folder.parts[0])
            else:
                # For recursive, need to add all nested folders
                for i in range(len(folder.parts)):
                    folder_names.add("/".join(folder.parts[: i + 1]))

        return [Path(f) for f in folder_names]

    def delete_file(self, path: str | Path) -> None:
        """Delete a file in the storage."""
        path = self.get_path(path)
        blob = self._bucket.blob(str(path))
        blob.delete()

    def delete_dir(self, path: str | Path) -> None:
        """Delete an empty folder in the storage."""
        # Since GCS does not have directories, we don't need to do anything
        pass

    def exists(self, path: str | Path) -> bool:
        """Check if a file or folder exists in GCS storage."""
        path = self.get_path(path)
        # Check for exact blob match (file)
        blob = self._bucket.blob(str(path))
        if blob.exists():
            return True
        # Check for any blobs with prefix (folder)
        prefix = f"{str(path).rstrip('/')}/"
        blobs = self._bucket.list_blobs(prefix=prefix, max_results=1)
        return next(blobs, None) is not None

    def make_dir(self, path: str | Path) -> None:
        """Create a directory in the storage if doesn't exist."""
        # GCS does not have directories, so we don't need to do anything
        pass

    def rename(self, old_path: str | Path, new_path: str | Path) -> None:
        """Rename a file or folder in the storage."""
        old_path = self.get_path(old_path)
        new_path = self.get_path(new_path)
        blob = self._bucket.blob(str(old_path))
        # If the blob with exact name exists, rename it
        if blob.exists():
            self._bucket.rename_blob(blob, new_name=str(new_path))
            return
        # Otherwise, rename all blobs with the prefix (folder)
        for blob in self._bucket.list_blobs(prefix=f"{old_path}/"):
            new_name = str(blob.name).replace(str(old_path), str(new_path), 1)
            self._bucket.rename_blob(blob, new_name=new_name)

    def copy(self, source: str | Path, destination: str | Path) -> None:
        """Copy a file or folder with all contents in the storage."""
        source = self.get_path(source)
        destination = self.get_path(destination)
        # If the source is a file, copy it
        if self._bucket.blob(str(source)).exists():
            self._bucket.copy_blob(
                self._bucket.blob(str(source)), self._bucket, str(destination)
            )
            return
        # Otherwise, copy all blobs with the prefix (folder)
        for blob in self._bucket.list_blobs(prefix=f"{source}/"):
            new_name = str(blob.name).replace(str(source), str(destination), 1)
            self._bucket.copy_blob(blob, self._bucket, new_name)

    def clone_with_subroot(self, subroot: str | Path) -> GCSFileStorage:
        """Create a new GCSFileStorage with a subroot of the current storage."""
        file_storage = GCSFileStorage(
            GCSFileStorageConfiguration(
                root=Path("/").joinpath(self.get_path(subroot)),
                bucket=self._bucket_name,
            )
        )
        file_storage._gcs = self._gcs
        file_storage._bucket = self._bucket
        return file_storage

    def __repr__(self) -> str:
        return f"{__class__.__name__}(bucket='{self._bucket_name}', root={self._root})"
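Note: since GCS has no real directories, list_folders above derives folder names
from blob key prefixes. A minimal standalone sketch of that derivation (pure
Python, no GCS client; the blob names are made up for illustration):

    from pathlib import Path

    blob_names = ["a/b/file1.txt", "a/file2.txt", "file3.txt"]
    folders: set[str] = set()
    for name in blob_names:
        folder = Path(name).parent
        if folder == Path("."):
            continue  # top-level object, no folder component
        for i in range(len(folder.parts)):
            folders.add("/".join(folder.parts[: i + 1]))
    print(sorted(folders))  # ['a', 'a/b']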
188
forge/forge/file_storage/local.py
Normal file
@@ -0,0 +1,188 @@
"""
The LocalFileStorage class implements a FileStorage that works with local files.
"""

from __future__ import annotations

import inspect
import logging
from contextlib import contextmanager
from pathlib import Path
from typing import Any, BinaryIO, Generator, Literal, TextIO, overload

from .base import FileStorage, FileStorageConfiguration

logger = logging.getLogger(__name__)


class LocalFileStorage(FileStorage):
    """A class that represents a file storage."""

    def __init__(self, config: FileStorageConfiguration):
        self._root = config.root.resolve()
        self._restrict_to_root = config.restrict_to_root
        self.make_dir(self.root)
        super().__init__()

    @property
    def root(self) -> Path:
        """The root directory of the file storage."""
        return self._root

    @property
    def restrict_to_root(self) -> bool:
        """Whether to restrict generated paths to the root."""
        return self._restrict_to_root

    @property
    def is_local(self) -> bool:
        """Whether the storage is local (i.e. on the same machine, not cloud-based)."""
        return True

    def initialize(self) -> None:
        self.root.mkdir(exist_ok=True, parents=True)

    @overload
    def open_file(
        self,
        path: str | Path,
        mode: Literal["w", "r"] = "r",
        binary: Literal[False] = False,
    ) -> TextIO:
        ...

    @overload
    def open_file(
        self, path: str | Path, mode: Literal["w", "r"], binary: Literal[True]
    ) -> BinaryIO:
        ...

    @overload
    def open_file(self, path: str | Path, *, binary: Literal[True]) -> BinaryIO:
        ...

    @overload
    def open_file(
        self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
    ) -> TextIO | BinaryIO:
        ...

    def open_file(
        self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
    ) -> TextIO | BinaryIO:
        """Open a file in the storage."""
        return self._open_file(path, f"{mode}b" if binary else mode)

    def _open_file(self, path: str | Path, mode: str) -> TextIO | BinaryIO:
        full_path = self.get_path(path)
        if any(m in mode for m in ("w", "a", "x")):
            full_path.parent.mkdir(parents=True, exist_ok=True)
        return open(full_path, mode)  # type: ignore

    @overload
    def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
        """Read a file in the storage as text."""
        ...

    @overload
    def read_file(self, path: str | Path, binary: Literal[True]) -> bytes:
        """Read a file in the storage as binary."""
        ...

    @overload
    def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
        """Read a file in the storage."""
        ...

    def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
        """Read a file in the storage."""
        with self._open_file(path, "rb" if binary else "r") as file:
            return file.read()

    async def write_file(self, path: str | Path, content: str | bytes) -> None:
        """Write to a file in the storage."""
        with self._open_file(path, "wb" if type(content) is bytes else "w") as file:
            file.write(content)  # type: ignore

        if self.on_write_file:
            path = Path(path)
            if path.is_absolute():
                path = path.relative_to(self.root)
            res = self.on_write_file(path)
            if inspect.isawaitable(res):
                await res

    def list_files(self, path: str | Path = ".") -> list[Path]:
        """List all files (recursively) in a directory in the storage."""
        path = self.get_path(path)
        return [file.relative_to(path) for file in path.rglob("*") if file.is_file()]

    def list_folders(
        self, path: str | Path = ".", recursive: bool = False
    ) -> list[Path]:
        """List directories directly in a given path or recursively."""
        path = self.get_path(path)
        if recursive:
            return [
                folder.relative_to(path)
                for folder in path.rglob("*")
                if folder.is_dir()
            ]
        else:
            return [
                folder.relative_to(path) for folder in path.iterdir() if folder.is_dir()
            ]

    def delete_file(self, path: str | Path) -> None:
        """Delete a file in the storage."""
        full_path = self.get_path(path)
        full_path.unlink()

    def delete_dir(self, path: str | Path) -> None:
        """Delete an empty folder in the storage."""
        full_path = self.get_path(path)
        full_path.rmdir()

    def exists(self, path: str | Path) -> bool:
        """Check if a file or folder exists in the storage."""
        return self.get_path(path).exists()

    def make_dir(self, path: str | Path) -> None:
        """Create a directory in the storage if doesn't exist."""
        full_path = self.get_path(path)
        full_path.mkdir(exist_ok=True, parents=True)

    def rename(self, old_path: str | Path, new_path: str | Path) -> None:
        """Rename a file or folder in the storage."""
        old_path = self.get_path(old_path)
        new_path = self.get_path(new_path)
        old_path.rename(new_path)

    def copy(self, source: str | Path, destination: str | Path) -> None:
        """Copy a file or folder with all contents in the storage."""
        source = self.get_path(source)
        destination = self.get_path(destination)
        if source.is_file():
            destination.write_bytes(source.read_bytes())
        else:
            destination.mkdir(exist_ok=True, parents=True)
            for file in source.rglob("*"):
                if file.is_file():
                    target = destination / file.relative_to(source)
                    target.parent.mkdir(exist_ok=True, parents=True)
                    target.write_bytes(file.read_bytes())

    def clone_with_subroot(self, subroot: str | Path) -> FileStorage:
        """Create a new LocalFileStorage with a subroot of the current storage."""
        return LocalFileStorage(
            FileStorageConfiguration(
                root=self.get_path(subroot),
                restrict_to_root=self.restrict_to_root,
            )
        )

    @contextmanager
    def mount(self, path: str | Path = ".") -> Generator[Path, Any, None]:
        """Mount the file storage and provide a local path."""
        # No need to do anything for local storage
        yield Path(self.get_path(".")).absolute()
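A minimal usage sketch for LocalFileStorage, assuming forge is importable and
that FileStorageConfiguration accepts these fields as constructor kwargs; the
root path is a placeholder. write_file is async, hence asyncio.run:

    import asyncio
    from pathlib import Path

    from forge.file_storage.base import FileStorageConfiguration
    from forge.file_storage.local import LocalFileStorage

    storage = LocalFileStorage(
        FileStorageConfiguration(root=Path("/tmp/agent_ws"), restrict_to_root=True)
    )
    asyncio.run(storage.write_file("notes/hello.txt", "hi"))
    print(storage.read_file("notes/hello.txt"))  # "hi"
    print(storage.list_files())  # [PosixPath('notes/hello.txt')]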
353
forge/forge/file_storage/s3.py
Normal file
@@ -0,0 +1,353 @@
"""
The S3FileStorage class provides an interface for interacting with a file workspace,
and stores the files in an S3 bucket.
"""

from __future__ import annotations

import contextlib
import inspect
import logging
from io import TextIOWrapper
from pathlib import Path
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, overload

import boto3
import botocore.exceptions
from pydantic import SecretStr

from forge.models.config import UserConfigurable

from .base import FileStorage, FileStorageConfiguration

if TYPE_CHECKING:
    import mypy_boto3_s3
    from botocore.response import StreamingBody

logger = logging.getLogger(__name__)


class S3FileStorageConfiguration(FileStorageConfiguration):
    bucket: str = UserConfigurable("autogpt", from_env="STORAGE_BUCKET")
    s3_endpoint_url: Optional[SecretStr] = UserConfigurable(from_env="S3_ENDPOINT_URL")


class S3FileStorage(FileStorage):
    """A class that represents an S3 storage."""

    _bucket: mypy_boto3_s3.service_resource.Bucket

    def __init__(self, config: S3FileStorageConfiguration):
        self._bucket_name = config.bucket
        self._root = config.root
        # Add / at the beginning of the root path
        if not self._root.is_absolute():
            self._root = Path("/").joinpath(self._root)

        # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
        self._s3 = boto3.resource(
            "s3",
            endpoint_url=(
                config.s3_endpoint_url.get_secret_value()
                if config.s3_endpoint_url
                else None
            ),
        )

        super().__init__()

    @property
    def root(self) -> Path:
        """The root directory of the file storage."""
        return self._root

    @property
    def restrict_to_root(self):
        """Whether to restrict generated paths to the root."""
        return True

    @property
    def is_local(self) -> bool:
        """Whether the storage is local (i.e. on the same machine, not cloud-based)."""
        return False

    def initialize(self) -> None:
        logger.debug(f"Initializing {repr(self)}...")
        try:
            self._s3.meta.client.head_bucket(Bucket=self._bucket_name)
            self._bucket = self._s3.Bucket(self._bucket_name)
        except botocore.exceptions.ClientError as e:
            if "(404)" not in str(e):
                raise
            logger.info(f"Bucket '{self._bucket_name}' does not exist; creating it...")
            self._bucket = self._s3.create_bucket(Bucket=self._bucket_name)

    def get_path(self, relative_path: str | Path) -> Path:
        # We set S3 root with "/" at the beginning
        # but relative_to("/") will remove it
        # because we don't actually want it in the storage filenames
        return super().get_path(relative_path).relative_to("/")

    def _get_obj(self, path: str | Path) -> mypy_boto3_s3.service_resource.Object:
        """Get an S3 object."""
        obj = self._bucket.Object(str(path))
        with contextlib.suppress(botocore.exceptions.ClientError):
            obj.load()
        return obj

    @overload
    def open_file(
        self,
        path: str | Path,
        mode: Literal["r", "w"] = "r",
        binary: Literal[False] = False,
    ) -> TextIOWrapper:
        ...

    @overload
    def open_file(
        self, path: str | Path, mode: Literal["r", "w"], binary: Literal[True]
    ) -> S3BinaryIOWrapper:
        ...

    @overload
    def open_file(
        self, path: str | Path, *, binary: Literal[True]
    ) -> S3BinaryIOWrapper:
        ...

    @overload
    def open_file(
        self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
    ) -> S3BinaryIOWrapper | TextIOWrapper:
        ...

    def open_file(
        self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
    ) -> TextIOWrapper | S3BinaryIOWrapper:
        """Open a file in the storage."""
        path = self.get_path(path)
        body = S3BinaryIOWrapper(self._get_obj(path).get()["Body"], str(path))
        return body if binary else TextIOWrapper(body)

    @overload
    def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
        """Read a file in the storage as text."""
        ...

    @overload
    def read_file(self, path: str | Path, binary: Literal[True]) -> bytes:
        """Read a file in the storage as binary."""
        ...

    @overload
    def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
        """Read a file in the storage."""
        ...

    def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
        """Read a file in the storage."""
        return self.open_file(path, binary=binary).read()

    async def write_file(self, path: str | Path, content: str | bytes) -> None:
        """Write to a file in the storage."""
        obj = self._get_obj(self.get_path(path))
        obj.put(Body=content)

        if self.on_write_file:
            path = Path(path)
            if path.is_absolute():
                path = path.relative_to(self.root)
            res = self.on_write_file(path)
            if inspect.isawaitable(res):
                await res

    def list_files(self, path: str | Path = ".") -> list[Path]:
        """List all files (recursively) in a directory in the storage."""
        path = self.get_path(path)
        if path == Path("."):  # root level of bucket
            return [Path(obj.key) for obj in self._bucket.objects.all()]
        else:
            return [
                Path(obj.key).relative_to(path)
                for obj in self._bucket.objects.filter(Prefix=f"{path}/")
            ]

    def list_folders(
        self, path: str | Path = ".", recursive: bool = False
    ) -> list[Path]:
        """List 'directories' directly in a given path or recursively in the storage."""
        path = self.get_path(path)
        folder_names = set()

        # List objects with the specified prefix and delimiter
        for obj_summary in self._bucket.objects.filter(Prefix=str(path)):
            # Remove path prefix and the object name (last part)
            folder = Path(obj_summary.key).relative_to(path).parent
            if not folder or folder == Path("."):
                continue
            # For non-recursive, only add the first level of folders
            if not recursive:
                folder_names.add(folder.parts[0])
            else:
                # For recursive, need to add all nested folders
                for i in range(len(folder.parts)):
                    folder_names.add("/".join(folder.parts[: i + 1]))

        return [Path(f) for f in folder_names]

    def delete_file(self, path: str | Path) -> None:
        """Delete a file in the storage."""
        path = self.get_path(path)
        obj = self._s3.Object(self._bucket_name, str(path))
        obj.delete()

    def delete_dir(self, path: str | Path) -> None:
        """Delete an empty folder in the storage."""
        # S3 does not have directories, so we don't need to do anything
        pass

    def exists(self, path: str | Path) -> bool:
        """Check if a file or folder exists in S3 storage."""
        path = self.get_path(path)
        try:
            # Check for exact object match (file)
            self._s3.meta.client.head_object(Bucket=self._bucket_name, Key=str(path))
            return True
        except botocore.exceptions.ClientError as e:
            if e.response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 404:
                # If the object does not exist,
                # check for objects with the prefix (folder)
                prefix = f"{str(path).rstrip('/')}/"
                objs = list(self._bucket.objects.filter(Prefix=prefix, MaxKeys=1))
                return len(objs) > 0  # True if any objects exist with the prefix
            else:
                raise  # Re-raise for any other client errors

    def make_dir(self, path: str | Path) -> None:
        """Create a directory in the storage if doesn't exist."""
        # S3 does not have directories, so we don't need to do anything
        pass

    def rename(self, old_path: str | Path, new_path: str | Path) -> None:
        """Rename a file or folder in the storage."""
        old_path = str(self.get_path(old_path))
        new_path = str(self.get_path(new_path))

        try:
            # If file exists, rename it
            self._s3.meta.client.head_object(Bucket=self._bucket_name, Key=old_path)
            self._s3.meta.client.copy_object(
                CopySource={"Bucket": self._bucket_name, "Key": old_path},
                Bucket=self._bucket_name,
                Key=new_path,
            )
            self._s3.meta.client.delete_object(Bucket=self._bucket_name, Key=old_path)
        except botocore.exceptions.ClientError as e:
            if e.response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 404:
                # If the object does not exist,
                # it may be a folder
                prefix = f"{old_path.rstrip('/')}/"
                objs = list(self._bucket.objects.filter(Prefix=prefix))
                for obj in objs:
                    new_key = new_path + obj.key[len(old_path) :]
                    self._s3.meta.client.copy_object(
                        CopySource={"Bucket": self._bucket_name, "Key": obj.key},
                        Bucket=self._bucket_name,
                        Key=new_key,
                    )
                    self._s3.meta.client.delete_object(
                        Bucket=self._bucket_name, Key=obj.key
                    )
            else:
                raise  # Re-raise for any other client errors

    def copy(self, source: str | Path, destination: str | Path) -> None:
        """Copy a file or folder with all contents in the storage."""
        source = str(self.get_path(source))
        destination = str(self.get_path(destination))

        try:
            # If source is a file, copy it
            self._s3.meta.client.head_object(Bucket=self._bucket_name, Key=source)
            self._s3.meta.client.copy_object(
                CopySource={"Bucket": self._bucket_name, "Key": source},
                Bucket=self._bucket_name,
                Key=destination,
            )
        except botocore.exceptions.ClientError as e:
            if e.response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 404:
                # If the object does not exist,
                # it may be a folder
                prefix = f"{source.rstrip('/')}/"
                objs = list(self._bucket.objects.filter(Prefix=prefix))
                for obj in objs:
                    new_key = destination + obj.key[len(source) :]
                    self._s3.meta.client.copy_object(
                        CopySource={"Bucket": self._bucket_name, "Key": obj.key},
                        Bucket=self._bucket_name,
                        Key=new_key,
                    )
            else:
                raise

    def clone_with_subroot(self, subroot: str | Path) -> S3FileStorage:
        """Create a new S3FileStorage with a subroot of the current storage."""
        file_storage = S3FileStorage(
            S3FileStorageConfiguration(
                bucket=self._bucket_name,
                root=Path("/").joinpath(self.get_path(subroot)),
                s3_endpoint_url=SecretStr(self._s3.meta.client.meta.endpoint_url),
            )
        )
        file_storage._s3 = self._s3
        file_storage._bucket = self._bucket
        return file_storage

    def __repr__(self) -> str:
        return f"{__class__.__name__}(bucket='{self._bucket_name}', root={self._root})"


class S3BinaryIOWrapper(BinaryIO):
    def __init__(self, body: StreamingBody, name: str):
        self.body = body
        self._name = name

    @property
    def name(self) -> str:
        return self._name

    def read(self, size: int = -1) -> bytes:
        return self.body.read(size if size > 0 else None)

    def readinto(self, b: bytearray) -> int:
        data = self.read(len(b))
        b[: len(data)] = data
        return len(data)

    def close(self) -> None:
        self.body.close()

    def fileno(self) -> int:
        return self.body.fileno()

    def flush(self) -> None:
        self.body.flush()

    def isatty(self) -> bool:
        return self.body.isatty()

    def readable(self) -> bool:
        return self.body.readable()

    def seekable(self) -> bool:
        return self.body.seekable()

    def writable(self) -> bool:
        return False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.body.close()
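A hedged sketch of pointing S3FileStorage at a local MinIO server via the env
vars read by S3FileStorageConfiguration; the endpoint and credentials are
placeholders, and the from_env() helper is assumed from forge.models.config:

    import os
    from pathlib import Path

    os.environ["STORAGE_BUCKET"] = "autogpt-test"  # placeholder bucket
    os.environ["S3_ENDPOINT_URL"] = "http://localhost:9000"  # placeholder MinIO
    os.environ["AWS_ACCESS_KEY_ID"] = "minioadmin"  # placeholder credentials
    os.environ["AWS_SECRET_ACCESS_KEY"] = "minioadmin"  # placeholder credentials

    from forge.file_storage.s3 import S3FileStorage, S3FileStorageConfiguration

    config = S3FileStorageConfiguration.from_env()
    config.root = Path("/workspaces/run-1")
    storage = S3FileStorage(config)
    storage.initialize()  # creates the bucket if it does not exist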
0
forge/forge/json/__init__.py
Normal file
93
forge/forge/json/parsing.py
Normal file
@@ -0,0 +1,93 @@
import logging
import re
from typing import Any

import demjson3

logger = logging.getLogger(__name__)


def json_loads(json_str: str) -> Any:
    """Parse a JSON string, tolerating minor syntax issues:
    - Missing, extra and trailing commas
    - Extraneous newlines and whitespace outside of string literals
    - Inconsistent spacing after colons and commas
    - Missing closing brackets or braces
    - Numbers: binary, hex, octal, trailing and prefixed decimal points
    - Different encodings
    - Surrounding markdown code block
    - Comments

    Args:
        json_str: The JSON string to parse.

    Returns:
        The parsed JSON object, same as built-in json.loads.
    """
    # Remove possible code block
    pattern = r"```(?:json|JSON)*([\s\S]*?)```"
    match = re.search(pattern, json_str)

    if match:
        json_str = match.group(1).strip()

    json_result = demjson3.decode(json_str, return_errors=True)
    assert json_result is not None  # by virtue of return_errors=True

    if json_result.errors:
        logger.debug(
            "JSON parse errors:\n" + "\n".join(str(e) for e in json_result.errors)
        )

    if json_result.object in (demjson3.syntax_error, demjson3.undefined):
        raise ValueError(
            f"Failed to parse JSON string: {json_str}", *json_result.errors
        )

    return json_result.object


def extract_dict_from_json(json_str: str) -> dict[str, Any]:
    # Sometimes the response includes the JSON in a code block with ```
    pattern = r"```(?:json|JSON)*([\s\S]*?)```"
    match = re.search(pattern, json_str)

    if match:
        json_str = match.group(1).strip()
    else:
        # The string may contain JSON.
        json_pattern = r"{[\s\S]*}"
        match = re.search(json_pattern, json_str)

        if match:
            json_str = match.group()

    result = json_loads(json_str)
    if not isinstance(result, dict):
        raise ValueError(
            f"Response '''{json_str}''' evaluated to non-dict value {repr(result)}"
        )
    return result


def extract_list_from_json(json_str: str) -> list[Any]:
    # Sometimes the response includes the JSON in a code block with ```
    pattern = r"```(?:json|JSON)*([\s\S]*?)```"
    match = re.search(pattern, json_str)

    if match:
        json_str = match.group(1).strip()
    else:
        # The string may contain JSON.
        json_pattern = r"\[[\s\S]*\]"
        match = re.search(json_pattern, json_str)

        if match:
            json_str = match.group()

    result = json_loads(json_str)
    if not isinstance(result, list):
        raise ValueError(
            f"Response '''{json_str}''' evaluated to non-list value {repr(result)}"
        )
    return result
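For illustration, json_loads tolerates the kind of damage LLM output often has
(code fences, trailing commas); a small usage sketch:

    from forge.json.parsing import json_loads

    messy = '```json\n{"thoughts": "ok", "plan": ["a", "b",], }\n```'
    print(json_loads(messy))  # {'thoughts': 'ok', 'plan': ['a', 'b']}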
0
forge/forge/llm/__init__.py
Normal file
8
forge/forge/llm/prompting/__init__.py
Normal file
@@ -0,0 +1,8 @@
from .base import PromptStrategy
from .schema import ChatPrompt, LanguageModelClassification

__all__ = [
    "LanguageModelClassification",
    "ChatPrompt",
    "PromptStrategy",
]
22
forge/forge/llm/prompting/base.py
Normal file
@@ -0,0 +1,22 @@
import abc
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from forge.llm.providers import AssistantChatMessage

from .schema import ChatPrompt, LanguageModelClassification


class PromptStrategy(abc.ABC):
    @property
    @abc.abstractmethod
    def model_classification(self) -> LanguageModelClassification:
        ...

    @abc.abstractmethod
    def build_prompt(self, *_, **kwargs) -> ChatPrompt:
        ...

    @abc.abstractmethod
    def parse_response_content(self, response: "AssistantChatMessage") -> Any:
        ...
@@ -0,0 +1,9 @@
{% extends "techniques/expert.j2" %}
{% block expert %}Human Resources{% endblock %}
{% block prompt %}
Generate a profile for an expert who can help with the task '{{ task }}'. Please provide the following details:
Name: Enter the expert's name
Expertise: Specify the area in which the expert specializes
Goals: List 4 goals that the expert aims to achieve in order to help with the task
Assessment: Describe how the expert will assess whether they have successfully completed the task
{% endblock %}
17
forge/forge/llm/prompting/gpt-3.5-turbo/system-format.j2
Normal file
@@ -0,0 +1,17 @@
Reply only in json with the following format:

{
    \"thoughts\": {
        \"text\": \"thoughts\",
        \"reasoning\": \"reasoning behind thoughts\",
        \"plan\": \"- short bulleted\\n- list that conveys\\n- long-term plan\",
        \"criticism\": \"constructive self-criticism\",
        \"speak\": \"thoughts summary to say to user\",
    },
    \"ability\": {
        \"name\": \"ability name\",
        \"args\": {
            \"arg1\": \"value1\", etc...
        }
    }
}
50
forge/forge/llm/prompting/gpt-3.5-turbo/task-step.j2
Normal file
@@ -0,0 +1,50 @@
{% extends "techniques/expert.j2" %}
{% block expert %}Planner{% endblock %}
{% block prompt %}
Your task is:

{{ task }}

Answer in the provided format.

Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and
pursue simple strategies with no legal complications.

{% if constraints %}
## Constraints
You operate within the following constraints:
{% for constraint in constraints %}
- {{ constraint }}
{% endfor %}
{% endif %}

{% if resources %}
## Resources
You can leverage access to the following resources:
{% for resource in resources %}
- {{ resource }}
{% endfor %}
{% endif %}

{% if abilities %}
## Abilities
You have access to the following abilities you can call:
{% for ability in abilities %}
- {{ ability }}
{% endfor %}
{% endif %}

{% if best_practices %}
## Best practices
{% for best_practice in best_practices %}
- {{ best_practice }}
{% endfor %}
{% endif %}

{% if previous_actions %}
## History of Abilities Used
{% for action in previous_actions %}
- {{ action }}
{% endfor %}
{% endif %}
{% endblock %}
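These templates are plain Jinja2; a hedged sketch of rendering task-step.j2,
assuming the loader is rooted at the prompting directory (so the
"techniques/expert.j2" parent template resolves); all variable values are
illustrative:

    from jinja2 import Environment, FileSystemLoader

    env = Environment(loader=FileSystemLoader("forge/forge/llm/prompting"))
    template = env.get_template("gpt-3.5-turbo/task-step.j2")
    print(
        template.render(
            task="Write a haiku about autumn",
            constraints=["No user assistance"],
            resources=[],
            abilities=["write_file(path, contents)"],
            best_practices=[],
            previous_actions=[],
        )
    )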
35
forge/forge/llm/prompting/schema.py
Normal file
@@ -0,0 +1,35 @@
import enum

from pydantic import BaseModel, Field

from forge.llm.providers.schema import (
    ChatMessage,
    ChatMessageDict,
    CompletionModelFunction,
)


class LanguageModelClassification(str, enum.Enum):
    """The LanguageModelClassification is a functional description of the model.

    This is used to determine what kind of model to use for a given prompt.
    Sometimes we prefer a faster or cheaper model to accomplish a task when
    possible.
    """

    FAST_MODEL = "fast_model"
    SMART_MODEL = "smart_model"


class ChatPrompt(BaseModel):
    messages: list[ChatMessage]
    functions: list[CompletionModelFunction] = Field(default_factory=list)
    prefill_response: str = ""

    def raw(self) -> list[ChatMessageDict]:
        return [m.dict() for m in self.messages]  # type: ignore

    def __str__(self):
        return "\n\n".join(
            f"{m.role.value.upper()}: {m.content}" for m in self.messages
        )
2
forge/forge/llm/prompting/techniques/chain-of-thought.j2
Normal file
@@ -0,0 +1,2 @@
{% block prompt %} {% endblock %}
Let's work this out in a step by step way to be sure we have the right answer.
1
forge/forge/llm/prompting/techniques/expert.j2
Normal file
@@ -0,0 +1 @@
Answer as an expert in {% block expert %} {% endblock %}. {% block prompt %}{% endblock %}
5
forge/forge/llm/prompting/techniques/few-shot.j2
Normal file
@@ -0,0 +1,5 @@
{% block prompt %} {% endblock %}
Examples:
{% for example in examples %}
- {{ example }}
{% endfor %}
43
forge/forge/llm/prompting/utils.py
Normal file
@@ -0,0 +1,43 @@
from math import ceil, floor
from typing import Any

from forge.llm.prompting.schema import ChatPrompt

SEPARATOR_LENGTH = 42


def dump_prompt(prompt: ChatPrompt) -> str:
    def separator(text: str):
        half_sep_len = (SEPARATOR_LENGTH - 2 - len(text)) / 2
        return f"{floor(half_sep_len)*'-'} {text.upper()} {ceil(half_sep_len)*'-'}"

    formatted_messages = "\n".join(
        [f"{separator(m.role)}\n{m.content}" for m in prompt.messages]
    )
    return f"""
============== {prompt.__class__.__name__} ==============
Length: {len(prompt.messages)} messages
{formatted_messages}
==========================================
"""


def format_numbered_list(items: list[Any], start_at: int = 1) -> str:
    return "\n".join(f"{i}. {str(item)}" for i, item in enumerate(items, start_at))


def indent(content: str, indentation: int | str = 4) -> str:
    if type(indentation) is int:
        indentation = " " * indentation
    return indentation + content.replace("\n", f"\n{indentation}")  # type: ignore


def to_numbered_list(
    items: list[str], no_items_response: str = "", **template_args
) -> str:
    if items:
        return "\n".join(
            f"{i+1}. {item.format(**template_args)}" for i, item in enumerate(items)
        )
    else:
        return no_items_response
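These helpers are pure functions; two quick, deterministic examples:

    from forge.llm.prompting.utils import format_numbered_list, indent

    print(format_numbered_list(["plan", "act", "review"]))
    # 1. plan
    # 2. act
    # 3. review

    print(indent("line1\nline2", 2))
    #   line1
    #   line2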
73
forge/forge/llm/providers/__init__.py
Normal file
@@ -0,0 +1,73 @@
from .multi import (
    CHAT_MODELS,
    ChatModelProvider,
    EmbeddingModelProvider,
    ModelName,
    MultiProvider,
)
from .openai import (
    OPEN_AI_CHAT_MODELS,
    OPEN_AI_EMBEDDING_MODELS,
    OPEN_AI_MODELS,
    OpenAIModelName,
    OpenAIProvider,
    OpenAISettings,
)
from .schema import (
    AssistantChatMessage,
    AssistantChatMessageDict,
    AssistantFunctionCall,
    AssistantFunctionCallDict,
    ChatMessage,
    ChatModelInfo,
    ChatModelResponse,
    CompletionModelFunction,
    Embedding,
    EmbeddingModelInfo,
    EmbeddingModelResponse,
    ModelInfo,
    ModelProviderBudget,
    ModelProviderCredentials,
    ModelProviderName,
    ModelProviderService,
    ModelProviderSettings,
    ModelProviderUsage,
    ModelResponse,
    ModelTokenizer,
)
from .utils import function_specs_from_commands

__all__ = [
    "AssistantChatMessage",
    "AssistantChatMessageDict",
    "AssistantFunctionCall",
    "AssistantFunctionCallDict",
    "ChatMessage",
    "ChatModelInfo",
    "ChatModelResponse",
    "CompletionModelFunction",
    "CHAT_MODELS",
    "Embedding",
    "EmbeddingModelInfo",
    "EmbeddingModelProvider",
    "EmbeddingModelResponse",
    "ModelInfo",
    "ModelName",
    "ChatModelProvider",
    "ModelProviderBudget",
    "ModelProviderCredentials",
    "ModelProviderName",
    "ModelProviderService",
    "ModelProviderSettings",
    "ModelProviderUsage",
    "ModelResponse",
    "ModelTokenizer",
    "MultiProvider",
    "OPEN_AI_MODELS",
    "OPEN_AI_CHAT_MODELS",
    "OPEN_AI_EMBEDDING_MODELS",
    "OpenAIModelName",
    "OpenAIProvider",
    "OpenAISettings",
    "function_specs_from_commands",
]
517
forge/forge/llm/providers/_openai_base.py
Normal file
@@ -0,0 +1,517 @@
import inspect
import logging
from typing import (
    Any,
    Awaitable,
    Callable,
    ClassVar,
    Mapping,
    Optional,
    ParamSpec,
    Sequence,
    TypeVar,
    cast,
)

import sentry_sdk
import tenacity
from openai._exceptions import APIConnectionError, APIStatusError
from openai.types import CreateEmbeddingResponse, EmbeddingCreateParams
from openai.types.chat import (
    ChatCompletion,
    ChatCompletionAssistantMessageParam,
    ChatCompletionMessage,
    ChatCompletionMessageParam,
    CompletionCreateParams,
)
from openai.types.shared_params import FunctionDefinition

from forge.json.parsing import json_loads

from .schema import (
    AssistantChatMessage,
    AssistantFunctionCall,
    AssistantToolCall,
    BaseChatModelProvider,
    BaseEmbeddingModelProvider,
    BaseModelProvider,
    ChatMessage,
    ChatModelInfo,
    ChatModelResponse,
    CompletionModelFunction,
    Embedding,
    EmbeddingModelInfo,
    EmbeddingModelResponse,
    ModelProviderService,
    _ModelName,
    _ModelProviderSettings,
)
from .utils import validate_tool_calls

_T = TypeVar("_T")
_P = ParamSpec("_P")


class _BaseOpenAIProvider(BaseModelProvider[_ModelName, _ModelProviderSettings]):
    """Base class for LLM providers with OpenAI-like APIs"""

    MODELS: ClassVar[
        Mapping[_ModelName, ChatModelInfo[_ModelName] | EmbeddingModelInfo[_ModelName]]  # type: ignore # noqa
    ]

    def __init__(
        self,
        settings: Optional[_ModelProviderSettings] = None,
        logger: Optional[logging.Logger] = None,
    ):
        if not getattr(self, "MODELS", None):
            raise ValueError(f"{self.__class__.__name__}.MODELS is not set")

        if not settings:
            settings = self.default_settings.copy(deep=True)
        if not settings.credentials:
            settings.credentials = self.default_settings.__fields__[
                "credentials"
            ].type_.from_env()

        super(_BaseOpenAIProvider, self).__init__(settings=settings, logger=logger)

        if not getattr(self, "_client", None):
            from openai import AsyncOpenAI

            self._client = AsyncOpenAI(
                **self._credentials.get_api_access_kwargs()  # type: ignore
            )

    async def get_available_models(
        self,
    ) -> Sequence[ChatModelInfo[_ModelName] | EmbeddingModelInfo[_ModelName]]:
        _models = (await self._client.models.list()).data
        return [
            self.MODELS[cast(_ModelName, m.id)] for m in _models if m.id in self.MODELS
        ]

    def get_token_limit(self, model_name: _ModelName) -> int:
        """Get the maximum number of input tokens for a given model"""
        return self.MODELS[model_name].max_tokens

    def count_tokens(self, text: str, model_name: _ModelName) -> int:
        return len(self.get_tokenizer(model_name).encode(text))

    def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
        return tenacity.retry(
            retry=(
                tenacity.retry_if_exception_type(APIConnectionError)
                | tenacity.retry_if_exception(
                    lambda e: isinstance(e, APIStatusError) and e.status_code >= 500
                )
            ),
            wait=tenacity.wait_exponential(),
            stop=tenacity.stop_after_attempt(self._configuration.retries_per_request),
            after=tenacity.after_log(self._logger, logging.DEBUG),
        )(func)

    def __repr__(self):
        return f"{self.__class__.__name__}()"


class BaseOpenAIChatProvider(
    _BaseOpenAIProvider[_ModelName, _ModelProviderSettings],
    BaseChatModelProvider[_ModelName, _ModelProviderSettings],
):
    CHAT_MODELS: ClassVar[dict[_ModelName, ChatModelInfo[_ModelName]]]  # type: ignore

    def __init__(
        self,
        settings: Optional[_ModelProviderSettings] = None,
        logger: Optional[logging.Logger] = None,
    ):
        if not getattr(self, "CHAT_MODELS", None):
            raise ValueError(f"{self.__class__.__name__}.CHAT_MODELS is not set")

        super(BaseOpenAIChatProvider, self).__init__(settings=settings, logger=logger)

    async def get_available_chat_models(self) -> Sequence[ChatModelInfo[_ModelName]]:
        all_available_models = await self.get_available_models()
        return [
            model
            for model in all_available_models
            if model.service == ModelProviderService.CHAT
        ]

    def count_message_tokens(
        self,
        messages: ChatMessage | list[ChatMessage],
        model_name: _ModelName,
    ) -> int:
        if isinstance(messages, ChatMessage):
            messages = [messages]
        return self.count_tokens(
            "\n\n".join(f"{m.role.upper()}: {m.content}" for m in messages), model_name
        )

    async def create_chat_completion(
        self,
        model_prompt: list[ChatMessage],
        model_name: _ModelName,
        completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
        functions: Optional[list[CompletionModelFunction]] = None,
        max_output_tokens: Optional[int] = None,
        prefill_response: str = "",
        **kwargs,
    ) -> ChatModelResponse[_T]:
        """Create a chat completion using the API."""

        (
            openai_messages,
            completion_kwargs,
            parse_kwargs,
        ) = self._get_chat_completion_args(
            prompt_messages=model_prompt,
            model=model_name,
            functions=functions,
            max_output_tokens=max_output_tokens,
            **kwargs,
        )

        total_cost = 0.0
        attempts = 0
        while True:
            completion_kwargs["messages"] = openai_messages
            _response, _cost, t_input, t_output = await self._create_chat_completion(
                model=model_name,
                completion_kwargs=completion_kwargs,
            )
            total_cost += _cost

            # If parsing the response fails, append the error to the prompt, and let the
            # LLM fix its mistake(s).
            attempts += 1
            parse_errors: list[Exception] = []

            _assistant_msg = _response.choices[0].message

            tool_calls, _errors = self._parse_assistant_tool_calls(
                _assistant_msg, **parse_kwargs
            )
            parse_errors += _errors

            # Validate tool calls
            if not parse_errors and tool_calls and functions:
                parse_errors += validate_tool_calls(tool_calls, functions)

            assistant_msg = AssistantChatMessage(
                content=_assistant_msg.content or "",
                tool_calls=tool_calls or None,
            )

            parsed_result: _T = None  # type: ignore
            if not parse_errors:
                try:
                    parsed_result = completion_parser(assistant_msg)
                    if inspect.isawaitable(parsed_result):
                        parsed_result = await parsed_result
                except Exception as e:
                    parse_errors.append(e)

            if not parse_errors:
                if attempts > 1:
                    self._logger.debug(
                        f"Total cost for {attempts} attempts: ${round(total_cost, 5)}"
                    )

                return ChatModelResponse(
                    response=AssistantChatMessage(
                        content=_assistant_msg.content or "",
                        tool_calls=tool_calls or None,
                    ),
                    parsed_result=parsed_result,
                    model_info=self.CHAT_MODELS[model_name],
                    prompt_tokens_used=t_input,
                    completion_tokens_used=t_output,
                )

            else:
                self._logger.debug(
                    f"Parsing failed on response: '''{_assistant_msg}'''"
                )
                parse_errors_fmt = "\n\n".join(
                    f"{e.__class__.__name__}: {e}" for e in parse_errors
                )
                self._logger.warning(
                    f"Parsing attempt #{attempts} failed: {parse_errors_fmt}"
                )
                for e in parse_errors:
                    sentry_sdk.capture_exception(
                        error=e,
                        extras={"assistant_msg": _assistant_msg, "i_attempt": attempts},
                    )

                if attempts < self._configuration.fix_failed_parse_tries:
                    openai_messages.append(
                        cast(
                            ChatCompletionAssistantMessageParam,
                            _assistant_msg.dict(exclude_none=True),
                        )
                    )
                    openai_messages.append(
                        {
                            "role": "system",
                            "content": (
                                f"ERROR PARSING YOUR RESPONSE:\n\n{parse_errors_fmt}"
                            ),
                        }
                    )
                    continue
                else:
                    raise parse_errors[0]

    def _get_chat_completion_args(
        self,
        prompt_messages: list[ChatMessage],
        model: _ModelName,
        functions: Optional[list[CompletionModelFunction]] = None,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> tuple[
        list[ChatCompletionMessageParam], CompletionCreateParams, dict[str, Any]
    ]:
        """Prepare keyword arguments for a chat completion API call

        Args:
            prompt_messages: List of ChatMessages
            model: The model to use
            functions (optional): List of functions available to the LLM
            max_output_tokens (optional): Maximum number of tokens to generate

        Returns:
            list[ChatCompletionMessageParam]: Prompt messages for the API call
            CompletionCreateParams: Mapping of other kwargs for the API call
            Mapping[str, Any]: Any keyword arguments to pass on to the completion parser
        """
        kwargs = cast(CompletionCreateParams, kwargs)

        if max_output_tokens:
            kwargs["max_tokens"] = max_output_tokens

        if functions:
            kwargs["tools"] = [  # pyright: ignore - it fails to infer the dict type
                {"type": "function", "function": format_function_def_for_openai(f)}
                for f in functions
            ]
            if len(functions) == 1:
                # force the model to call the only specified function
                kwargs["tool_choice"] = {  # pyright: ignore - type inference failure
                    "type": "function",
                    "function": {"name": functions[0].name},
                }

        if extra_headers := self._configuration.extra_request_headers:
            # 'extra_headers' is not on CompletionCreateParams, but is on chat.create()
            kwargs["extra_headers"] = kwargs.get("extra_headers", {})  # type: ignore
            kwargs["extra_headers"].update(extra_headers.copy())  # type: ignore

        prepped_messages: list[ChatCompletionMessageParam] = [
            message.dict(  # type: ignore
                include={"role", "content", "tool_calls", "tool_call_id", "name"},
                exclude_none=True,
            )
            for message in prompt_messages
        ]

        if "messages" in kwargs:
            prepped_messages += kwargs["messages"]
            del kwargs["messages"]  # type: ignore - messages are added back later

        return prepped_messages, kwargs, {}

    async def _create_chat_completion(
        self,
        model: _ModelName,
        completion_kwargs: CompletionCreateParams,
    ) -> tuple[ChatCompletion, float, int, int]:
        """
        Create a chat completion using an OpenAI-like API with retry handling

        Params:
            model: The model to use for the completion
            completion_kwargs: All other arguments for the completion call

        Returns:
            ChatCompletion: The chat completion response object
            float: The cost ($) of this completion
            int: Number of prompt tokens used
            int: Number of completion tokens used
        """
        completion_kwargs["model"] = completion_kwargs.get("model") or model

        @self._retry_api_request
        async def _create_chat_completion_with_retry() -> ChatCompletion:
            return await self._client.chat.completions.create(
                **completion_kwargs,  # type: ignore
            )

        completion = await _create_chat_completion_with_retry()

        if completion.usage:
            prompt_tokens_used = completion.usage.prompt_tokens
            completion_tokens_used = completion.usage.completion_tokens
        else:
            prompt_tokens_used = completion_tokens_used = 0

        if self._budget:
            cost = self._budget.update_usage_and_cost(
                model_info=self.CHAT_MODELS[model],
                input_tokens_used=prompt_tokens_used,
                output_tokens_used=completion_tokens_used,
            )
        else:
            cost = 0

        self._logger.debug(
            f"{model} completion usage: {prompt_tokens_used} input, "
            f"{completion_tokens_used} output - ${round(cost, 5)}"
        )
        return completion, cost, prompt_tokens_used, completion_tokens_used

    def _parse_assistant_tool_calls(
        self, assistant_message: ChatCompletionMessage, **kwargs
    ) -> tuple[list[AssistantToolCall], list[Exception]]:
        tool_calls: list[AssistantToolCall] = []
        parse_errors: list[Exception] = []

        if assistant_message.tool_calls:
            for _tc in assistant_message.tool_calls:
                try:
                    parsed_arguments = json_loads(_tc.function.arguments)
                except Exception as e:
                    err_message = (
                        f"Decoding arguments for {_tc.function.name} failed: "
                        + str(e.args[0])
                    )
                    parse_errors.append(
                        type(e)(err_message, *e.args[1:]).with_traceback(
                            e.__traceback__
                        )
                    )
                    continue

                tool_calls.append(
                    AssistantToolCall(
                        id=_tc.id,
                        type=_tc.type,
                        function=AssistantFunctionCall(
                            name=_tc.function.name,
                            arguments=parsed_arguments,
                        ),
                    )
                )

            # If parsing of all tool calls succeeds in the end, we ignore any issues
            if len(tool_calls) == len(assistant_message.tool_calls):
                parse_errors = []

        return tool_calls, parse_errors


class BaseOpenAIEmbeddingProvider(
    _BaseOpenAIProvider[_ModelName, _ModelProviderSettings],
    BaseEmbeddingModelProvider[_ModelName, _ModelProviderSettings],
):
    EMBEDDING_MODELS: ClassVar[
        dict[_ModelName, EmbeddingModelInfo[_ModelName]]  # type: ignore
    ]

    def __init__(
        self,
        settings: Optional[_ModelProviderSettings] = None,
        logger: Optional[logging.Logger] = None,
    ):
        if not getattr(self, "EMBEDDING_MODELS", None):
            raise ValueError(f"{self.__class__.__name__}.EMBEDDING_MODELS is not set")

        super(BaseOpenAIEmbeddingProvider, self).__init__(
            settings=settings, logger=logger
        )

    async def get_available_embedding_models(
        self,
    ) -> Sequence[EmbeddingModelInfo[_ModelName]]:
        all_available_models = await self.get_available_models()
        return [
            model
            for model in all_available_models
            if model.service == ModelProviderService.EMBEDDING
        ]

    async def create_embedding(
        self,
        text: str,
        model_name: _ModelName,
        embedding_parser: Callable[[Embedding], Embedding],
        **kwargs,
    ) -> EmbeddingModelResponse:
        """Create an embedding using an OpenAI-like API"""
        embedding_kwargs = self._get_embedding_kwargs(
            input=text, model=model_name, **kwargs
        )
        response = await self._create_embedding(embedding_kwargs)

        return EmbeddingModelResponse(
            embedding=embedding_parser(response.data[0].embedding),
            model_info=self.EMBEDDING_MODELS[model_name],
            prompt_tokens_used=response.usage.prompt_tokens,
        )

    def _get_embedding_kwargs(
        self, input: str | list[str], model: _ModelName, **kwargs
    ) -> EmbeddingCreateParams:
        """Get kwargs for an embedding API call

        Params:
            input: Text body or list of text bodies to create embedding(s) from
            model: Embedding model to use

        Returns:
            The kwargs for the embedding API call
        """
        kwargs = cast(EmbeddingCreateParams, kwargs)

        kwargs["input"] = input
        kwargs["model"] = model

        if extra_headers := self._configuration.extra_request_headers:
            # 'extra_headers' is not on CompletionCreateParams, but is on embedding.create()  # noqa
            kwargs["extra_headers"] = kwargs.get("extra_headers", {})  # type: ignore
            kwargs["extra_headers"].update(extra_headers.copy())  # type: ignore

        return kwargs

    def _create_embedding(
        self, embedding_kwargs: EmbeddingCreateParams
    ) -> Awaitable[CreateEmbeddingResponse]:
        """Create an embedding using an OpenAI-like API with retry handling."""

        @self._retry_api_request
        async def _create_embedding_with_retry() -> CreateEmbeddingResponse:
            return await self._client.embeddings.create(**embedding_kwargs)

        return _create_embedding_with_retry()


def format_function_def_for_openai(self: CompletionModelFunction) -> FunctionDefinition:
    """Returns an OpenAI-consumable function definition"""

    return {
        "name": self.name,
        "description": self.description,
        "parameters": {
            "type": "object",
            "properties": {
                name: param.to_dict() for name, param in self.parameters.items()
            },
            "required": [
                name for name, param in self.parameters.items() if param.required
            ],
        },
    }
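For reference, format_function_def_for_openai maps a CompletionModelFunction
onto the plain dict shape the OpenAI tools API expects. A hedged sketch of the
output for a hypothetical read_file spec (illustrative values only):

    {
        "name": "read_file",
        "description": "Read a file and return its contents",
        "parameters": {
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "File to read"},
            },
            "required": ["path"],
        },
    }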
488
forge/forge/llm/providers/anthropic.py
Normal file
@@ -0,0 +1,488 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, ParamSpec, Sequence, TypeVar
|
||||
|
||||
import sentry_sdk
|
||||
import tenacity
|
||||
import tiktoken
|
||||
from anthropic import APIConnectionError, APIStatusError
|
||||
from pydantic import SecretStr
|
||||
|
||||
from forge.models.config import UserConfigurable
|
||||
|
||||
from .schema import (
|
||||
AssistantChatMessage,
|
||||
AssistantFunctionCall,
|
||||
AssistantToolCall,
|
||||
BaseChatModelProvider,
|
||||
ChatMessage,
|
||||
ChatModelInfo,
|
||||
ChatModelResponse,
|
||||
CompletionModelFunction,
|
||||
ModelProviderBudget,
|
||||
ModelProviderConfiguration,
|
||||
ModelProviderCredentials,
|
||||
ModelProviderName,
|
||||
ModelProviderSettings,
|
||||
ModelTokenizer,
|
||||
ToolResultMessage,
|
||||
)
|
||||
from .utils import validate_tool_calls
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from anthropic.types.beta.tools import MessageCreateParams
|
||||
from anthropic.types.beta.tools import ToolsBetaMessage as Message
|
||||
from anthropic.types.beta.tools import ToolsBetaMessageParam as MessageParam
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
class AnthropicModelName(str, enum.Enum):
|
||||
CLAUDE3_OPUS_v1 = "claude-3-opus-20240229"
|
||||
CLAUDE3_SONNET_v1 = "claude-3-sonnet-20240229"
|
||||
CLAUDE3_HAIKU_v1 = "claude-3-haiku-20240307"
|
||||
|
||||
|
||||
ANTHROPIC_CHAT_MODELS = {
|
||||
info.name: info
|
||||
for info in [
|
||||
ChatModelInfo(
|
||||
name=AnthropicModelName.CLAUDE3_OPUS_v1,
|
||||
provider_name=ModelProviderName.ANTHROPIC,
|
||||
prompt_token_cost=15 / 1e6,
|
||||
completion_token_cost=75 / 1e6,
|
||||
max_tokens=200000,
|
||||
has_function_call_api=True,
|
||||
),
|
||||
ChatModelInfo(
|
||||
name=AnthropicModelName.CLAUDE3_SONNET_v1,
|
||||
provider_name=ModelProviderName.ANTHROPIC,
|
||||
prompt_token_cost=3 / 1e6,
|
||||
completion_token_cost=15 / 1e6,
|
||||
max_tokens=200000,
|
||||
has_function_call_api=True,
|
||||
),
|
||||
ChatModelInfo(
|
||||
name=AnthropicModelName.CLAUDE3_HAIKU_v1,
|
||||
provider_name=ModelProviderName.ANTHROPIC,
|
||||
prompt_token_cost=0.25 / 1e6,
|
||||
completion_token_cost=1.25 / 1e6,
|
||||
max_tokens=200000,
|
||||
has_function_call_api=True,
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class AnthropicCredentials(ModelProviderCredentials):
|
||||
"""Credentials for Anthropic."""
|
||||
|
||||
api_key: SecretStr = UserConfigurable(from_env="ANTHROPIC_API_KEY") # type: ignore
|
||||
api_base: Optional[SecretStr] = UserConfigurable(
|
||||
default=None, from_env="ANTHROPIC_API_BASE_URL"
|
||||
)
|
||||
|
||||
def get_api_access_kwargs(self) -> dict[str, str]:
|
||||
return {
|
||||
k: v.get_secret_value()
|
||||
for k, v in {
|
||||
"api_key": self.api_key,
|
||||
"base_url": self.api_base,
|
||||
}.items()
|
||||
if v is not None
|
||||
}


class AnthropicSettings(ModelProviderSettings):
    credentials: Optional[AnthropicCredentials]  # type: ignore
    budget: ModelProviderBudget  # type: ignore


class AnthropicProvider(BaseChatModelProvider[AnthropicModelName, AnthropicSettings]):
    default_settings = AnthropicSettings(
        name="anthropic_provider",
        description="Provides access to Anthropic's API.",
        configuration=ModelProviderConfiguration(),
        credentials=None,
        budget=ModelProviderBudget(),
    )

    _settings: AnthropicSettings
    _credentials: AnthropicCredentials
    _budget: ModelProviderBudget

    def __init__(
        self,
        settings: Optional[AnthropicSettings] = None,
        logger: Optional[logging.Logger] = None,
    ):
        if not settings:
            settings = self.default_settings.copy(deep=True)
        if not settings.credentials:
            settings.credentials = AnthropicCredentials.from_env()

        super(AnthropicProvider, self).__init__(settings=settings, logger=logger)

        from anthropic import AsyncAnthropic

        self._client = AsyncAnthropic(
            **self._credentials.get_api_access_kwargs()  # type: ignore
        )

    async def get_available_models(self) -> Sequence[ChatModelInfo[AnthropicModelName]]:
        return await self.get_available_chat_models()

    async def get_available_chat_models(
        self,
    ) -> Sequence[ChatModelInfo[AnthropicModelName]]:
        return list(ANTHROPIC_CHAT_MODELS.values())

    def get_token_limit(self, model_name: AnthropicModelName) -> int:
        """Get the token limit for a given model."""
        return ANTHROPIC_CHAT_MODELS[model_name].max_tokens

    def get_tokenizer(self, model_name: AnthropicModelName) -> ModelTokenizer[Any]:
        # HACK: No official tokenizer is available for Claude 3
        return tiktoken.encoding_for_model(model_name)

    def count_tokens(self, text: str, model_name: AnthropicModelName) -> int:
        return 0  # HACK: No official tokenizer is available for Claude 3

    def count_message_tokens(
        self,
        messages: ChatMessage | list[ChatMessage],
        model_name: AnthropicModelName,
    ) -> int:
        return 0  # HACK: No official tokenizer is available for Claude 3

    async def create_chat_completion(
        self,
        model_prompt: list[ChatMessage],
        model_name: AnthropicModelName,
        completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
        functions: Optional[list[CompletionModelFunction]] = None,
        max_output_tokens: Optional[int] = None,
        prefill_response: str = "",
        **kwargs,
    ) -> ChatModelResponse[_T]:
        """Create a completion using the Anthropic API."""
        anthropic_messages, completion_kwargs = self._get_chat_completion_args(
            prompt_messages=model_prompt,
            model=model_name,
            functions=functions,
            max_output_tokens=max_output_tokens,
            **kwargs,
        )

        total_cost = 0.0
        attempts = 0
        while True:
            completion_kwargs["messages"] = anthropic_messages.copy()
            if prefill_response:
                completion_kwargs["messages"].append(
                    {"role": "assistant", "content": prefill_response}
                )

            (
                _assistant_msg,
                cost,
                t_input,
                t_output,
            ) = await self._create_chat_completion(model_name, completion_kwargs)
            total_cost += cost
            self._logger.debug(
                f"Completion usage: {t_input} input, {t_output} output "
                f"- ${round(cost, 5)}"
            )

            # Merge prefill into generated response
            if prefill_response:
                first_text_block = next(
                    b for b in _assistant_msg.content if b.type == "text"
                )
                first_text_block.text = prefill_response + first_text_block.text

            assistant_msg = AssistantChatMessage(
                content="\n\n".join(
                    b.text for b in _assistant_msg.content if b.type == "text"
                ),
                tool_calls=self._parse_assistant_tool_calls(_assistant_msg),
            )

            # If parsing the response fails, append the error to the prompt,
            # and let the LLM fix its mistake(s).
            attempts += 1
            tool_call_errors = []
            try:
                # Validate tool calls
                if assistant_msg.tool_calls and functions:
                    tool_call_errors = validate_tool_calls(
                        assistant_msg.tool_calls, functions
                    )
                    if tool_call_errors:
                        raise ValueError(
                            "Invalid tool use(s):\n"
                            + "\n".join(str(e) for e in tool_call_errors)
                        )

                parsed_result = completion_parser(assistant_msg)
                break
            except Exception as e:
                self._logger.debug(
                    f"Parsing failed on response: '''{_assistant_msg}'''"
                )
                self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
                sentry_sdk.capture_exception(
                    error=e,
                    extras={"assistant_msg": _assistant_msg, "i_attempt": attempts},
                )
                if attempts < self._configuration.fix_failed_parse_tries:
                    anthropic_messages.append(
                        _assistant_msg.dict(include={"role", "content"})  # type: ignore
                    )
                    anthropic_messages.append(
                        {
                            "role": "user",
                            "content": [
                                *(
                                    # tool_result is required if last assistant message
                                    # had tool_use block(s)
                                    {
                                        "type": "tool_result",
                                        "tool_use_id": tc.id,
                                        "is_error": True,
                                        "content": [
                                            {
                                                "type": "text",
                                                "text": "Not executed because parsing "
                                                "of your last message failed"
                                                if not tool_call_errors
                                                else str(e)
                                                if (
                                                    e := next(
                                                        (
                                                            tce
                                                            for tce in tool_call_errors
                                                            if tce.name
                                                            == tc.function.name
                                                        ),
                                                        None,
                                                    )
                                                )
                                                else "Not executed because validation "
                                                "of tool input failed",
                                            }
                                        ],
                                    }
                                    for tc in assistant_msg.tool_calls or []
                                ),
                                {
                                    "type": "text",
                                    "text": (
                                        "ERROR PARSING YOUR RESPONSE:\n\n"
                                        f"{e.__class__.__name__}: {e}"
                                    ),
                                },
                            ],
                        }
                    )
                else:
                    raise

        if attempts > 1:
            self._logger.debug(
                f"Total cost for {attempts} attempts: ${round(total_cost, 5)}"
            )

        return ChatModelResponse(
            response=assistant_msg,
            parsed_result=parsed_result,
            model_info=ANTHROPIC_CHAT_MODELS[model_name],
            prompt_tokens_used=t_input,
            completion_tokens_used=t_output,
        )
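
    # Illustrative usage sketch (the model member name is an assumption):
    #   response = await provider.create_chat_completion(
    #       model_prompt=[ChatMessage.user("Say hi")],
    #       model_name=AnthropicModelName.CLAUDE3_OPUS_v1,  # any AnthropicModelName
    #       completion_parser=lambda msg: msg.content.upper(),
    #   )
    # The parser runs inside the loop above, so a parser that raises has its
    # error fed back to the model for up to `fix_failed_parse_tries` attempts.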

    def _get_chat_completion_args(
        self,
        prompt_messages: list[ChatMessage],
        model: AnthropicModelName,
        functions: Optional[list[CompletionModelFunction]] = None,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> tuple[list[MessageParam], MessageCreateParams]:
        """Prepare arguments for message completion API call.

        Args:
            prompt_messages: List of ChatMessages.
            model: The model to use.
            functions: Optional list of functions available to the LLM.
            kwargs: Additional keyword arguments.

        Returns:
            list[MessageParam]: Prompt messages for the Anthropic call
            dict[str, Any]: Any other kwargs for the Anthropic call
        """
        if functions:
            kwargs["tools"] = [
                {
                    "name": f.name,
                    "description": f.description,
                    "input_schema": {
                        "type": "object",
                        "properties": {
                            name: param.to_dict()
                            for name, param in f.parameters.items()
                        },
                        "required": [
                            name
                            for name, param in f.parameters.items()
                            if param.required
                        ],
                    },
                }
                for f in functions
            ]

        kwargs["max_tokens"] = max_output_tokens or 4096

        if extra_headers := self._configuration.extra_request_headers:
            kwargs["extra_headers"] = kwargs.get("extra_headers", {})
            kwargs["extra_headers"].update(extra_headers.copy())

        system_messages = [
            m for m in prompt_messages if m.role == ChatMessage.Role.SYSTEM
        ]
        if (_n := len(system_messages)) > 1:
            self._logger.warning(
                f"Prompt has {_n} system messages; Anthropic supports only 1. "
                "They will be merged, and removed from the rest of the prompt."
            )
        kwargs["system"] = "\n\n".join(sm.content for sm in system_messages)

        messages: list[MessageParam] = []
        for message in prompt_messages:
            if message.role == ChatMessage.Role.SYSTEM:
                continue
            elif message.role == ChatMessage.Role.USER:
                # Merge subsequent user messages
                if messages and (prev_msg := messages[-1])["role"] == "user":
                    if isinstance(prev_msg["content"], str):
                        prev_msg["content"] += f"\n\n{message.content}"
                    else:
                        assert isinstance(prev_msg["content"], list)
                        prev_msg["content"].append(
                            {"type": "text", "text": message.content}
                        )
                else:
                    messages.append({"role": "user", "content": message.content})
                # TODO: add support for image blocks
            elif message.role == ChatMessage.Role.ASSISTANT:
                if isinstance(message, AssistantChatMessage) and message.tool_calls:
                    messages.append(
                        {
                            "role": "assistant",
                            "content": [
                                *(
                                    [{"type": "text", "text": message.content}]
                                    if message.content
                                    else []
                                ),
                                *(
                                    {
                                        "type": "tool_use",
                                        "id": tc.id,
                                        "name": tc.function.name,
                                        "input": tc.function.arguments,
                                    }
                                    for tc in message.tool_calls
                                ),
                            ],
                        }
                    )
                elif message.content:
                    messages.append(
                        {
                            "role": "assistant",
                            "content": message.content,
                        }
                    )
            elif isinstance(message, ToolResultMessage):
                messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "tool_result",
                                "tool_use_id": message.tool_call_id,
                                "content": [{"type": "text", "text": message.content}],
                                "is_error": message.is_error,
                            }
                        ],
                    }
                )

        return messages, kwargs  # type: ignore
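
    # Illustrative behavior note: consecutive user messages are merged, so
    #   [ChatMessage.user("a"), ChatMessage.user("b")]
    # becomes one Anthropic message {"role": "user", "content": "a\n\nb"},
    # matching the API's expectation of alternating user/assistant turns.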

    async def _create_chat_completion(
        self, model: AnthropicModelName, completion_kwargs: MessageCreateParams
    ) -> tuple[Message, float, int, int]:
        """
        Create a chat completion using the Anthropic API with retry handling.

        Params:
            completion_kwargs: Keyword arguments for an Anthropic Messages API call

        Returns:
            Message: The message completion object
            float: The cost ($) of this completion
            int: Number of input tokens used
            int: Number of output tokens used
        """

        @self._retry_api_request
        async def _create_chat_completion_with_retry() -> Message:
            return await self._client.beta.tools.messages.create(
                model=model, **completion_kwargs  # type: ignore
            )

        response = await _create_chat_completion_with_retry()

        cost = self._budget.update_usage_and_cost(
            model_info=ANTHROPIC_CHAT_MODELS[model],
            input_tokens_used=response.usage.input_tokens,
            output_tokens_used=response.usage.output_tokens,
        )
        return response, cost, response.usage.input_tokens, response.usage.output_tokens

    def _parse_assistant_tool_calls(
        self, assistant_message: Message
    ) -> list[AssistantToolCall]:
        return [
            AssistantToolCall(
                id=c.id,
                type="function",
                function=AssistantFunctionCall(
                    name=c.name,
                    arguments=c.input,  # type: ignore
                ),
            )
            for c in assistant_message.content
            if c.type == "tool_use"
        ]

    def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
        return tenacity.retry(
            retry=(
                tenacity.retry_if_exception_type(APIConnectionError)
                | tenacity.retry_if_exception(
                    lambda e: isinstance(e, APIStatusError) and e.status_code >= 500
                )
            ),
            wait=tenacity.wait_exponential(),
            stop=tenacity.stop_after_attempt(self._configuration.retries_per_request),
            after=tenacity.after_log(self._logger, logging.DEBUG),
        )(func)
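
    # Illustrative usage sketch: _retry_api_request is applied as a decorator,
    #   @self._retry_api_request
    #   async def _call() -> Message: ...
    # Connection errors and 5xx responses are retried with exponential backoff
    # up to `retries_per_request` attempts; other errors propagate immediately.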

    def __repr__(self):
        return "AnthropicProvider()"
126
forge/forge/llm/providers/groq.py
Normal file
@@ -0,0 +1,126 @@
from __future__ import annotations

import enum
import logging
from typing import Any, Optional

import tiktoken
from pydantic import SecretStr

from forge.models.config import UserConfigurable

from ._openai_base import BaseOpenAIChatProvider
from .schema import (
    ChatModelInfo,
    ModelProviderBudget,
    ModelProviderConfiguration,
    ModelProviderCredentials,
    ModelProviderName,
    ModelProviderSettings,
    ModelTokenizer,
)


class GroqModelName(str, enum.Enum):
    LLAMA3_8B = "llama3-8b-8192"
    LLAMA3_70B = "llama3-70b-8192"
    MIXTRAL_8X7B = "mixtral-8x7b-32768"
    GEMMA_7B = "gemma-7b-it"


GROQ_CHAT_MODELS = {
    info.name: info
    for info in [
        ChatModelInfo(
            name=GroqModelName.LLAMA3_8B,
            provider_name=ModelProviderName.GROQ,
            prompt_token_cost=0.05 / 1e6,
            completion_token_cost=0.10 / 1e6,
            max_tokens=8192,
            has_function_call_api=True,
        ),
        ChatModelInfo(
            name=GroqModelName.LLAMA3_70B,
            provider_name=ModelProviderName.GROQ,
            prompt_token_cost=0.59 / 1e6,
            completion_token_cost=0.79 / 1e6,
            max_tokens=8192,
            has_function_call_api=True,
        ),
        ChatModelInfo(
            name=GroqModelName.MIXTRAL_8X7B,
            provider_name=ModelProviderName.GROQ,
            prompt_token_cost=0.27 / 1e6,
            completion_token_cost=0.27 / 1e6,
            max_tokens=32768,
            has_function_call_api=True,
        ),
        ChatModelInfo(
            name=GroqModelName.GEMMA_7B,
            provider_name=ModelProviderName.GROQ,
            prompt_token_cost=0.10 / 1e6,
            completion_token_cost=0.10 / 1e6,
            max_tokens=8192,
            has_function_call_api=True,
        ),
    ]
}
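
# Illustrative arithmetic: token costs are stored in $ per token, so a
# llama3-70b-8192 call with 1,000 prompt and 500 completion tokens costs
#   1000 * 0.59 / 1e6 + 500 * 0.79 / 1e6 = $0.000985
# (see ModelProviderBudget.update_usage_and_cost in schema.py).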


class GroqCredentials(ModelProviderCredentials):
    """Credentials for Groq."""

    api_key: SecretStr = UserConfigurable(from_env="GROQ_API_KEY")  # type: ignore
    api_base: Optional[SecretStr] = UserConfigurable(
        default=None, from_env="GROQ_API_BASE_URL"
    )

    def get_api_access_kwargs(self) -> dict[str, str]:
        return {
            k: v.get_secret_value()
            for k, v in {
                "api_key": self.api_key,
                "base_url": self.api_base,
            }.items()
            if v is not None
        }


class GroqSettings(ModelProviderSettings):
    credentials: Optional[GroqCredentials]  # type: ignore
    budget: ModelProviderBudget  # type: ignore


class GroqProvider(BaseOpenAIChatProvider[GroqModelName, GroqSettings]):
    CHAT_MODELS = GROQ_CHAT_MODELS
    MODELS = CHAT_MODELS

    default_settings = GroqSettings(
        name="groq_provider",
        description="Provides access to Groq's API.",
        configuration=ModelProviderConfiguration(),
        credentials=None,
        budget=ModelProviderBudget(),
    )

    _settings: GroqSettings
    _configuration: ModelProviderConfiguration
    _credentials: GroqCredentials
    _budget: ModelProviderBudget

    def __init__(
        self,
        settings: Optional[GroqSettings] = None,
        logger: Optional[logging.Logger] = None,
    ):
        super(GroqProvider, self).__init__(settings=settings, logger=logger)

        from groq import AsyncGroq

        self._client = AsyncGroq(
            **self._credentials.get_api_access_kwargs()  # type: ignore
        )

    def get_tokenizer(self, model_name: GroqModelName) -> ModelTokenizer[Any]:
        # HACK: No official tokenizer is available for Groq
        return tiktoken.encoding_for_model("gpt-3.5-turbo")
165
forge/forge/llm/providers/multi.py
Normal file
@@ -0,0 +1,165 @@
from __future__ import annotations

import logging
from typing import Any, Callable, Iterator, Optional, Sequence, TypeVar

from pydantic import ValidationError

from .anthropic import ANTHROPIC_CHAT_MODELS, AnthropicModelName, AnthropicProvider
from .groq import GROQ_CHAT_MODELS, GroqModelName, GroqProvider
from .openai import OPEN_AI_CHAT_MODELS, OpenAIModelName, OpenAIProvider
from .schema import (
    AssistantChatMessage,
    BaseChatModelProvider,
    ChatMessage,
    ChatModelInfo,
    ChatModelResponse,
    CompletionModelFunction,
    ModelProviderBudget,
    ModelProviderConfiguration,
    ModelProviderName,
    ModelProviderSettings,
    ModelTokenizer,
)

_T = TypeVar("_T")

ModelName = AnthropicModelName | GroqModelName | OpenAIModelName
EmbeddingModelProvider = OpenAIProvider

CHAT_MODELS = {**ANTHROPIC_CHAT_MODELS, **GROQ_CHAT_MODELS, **OPEN_AI_CHAT_MODELS}


class MultiProvider(BaseChatModelProvider[ModelName, ModelProviderSettings]):
    default_settings = ModelProviderSettings(
        name="multi_provider",
        description=(
            "Provides access to all of the available models, regardless of provider."
        ),
        configuration=ModelProviderConfiguration(
            retries_per_request=7,
        ),
        budget=ModelProviderBudget(),
    )

    _budget: ModelProviderBudget

    _provider_instances: dict[ModelProviderName, ChatModelProvider]

    def __init__(
        self,
        settings: Optional[ModelProviderSettings] = None,
        logger: Optional[logging.Logger] = None,
    ):
        super(MultiProvider, self).__init__(settings=settings, logger=logger)
        self._budget = self._settings.budget or ModelProviderBudget()

        self._provider_instances = {}

    async def get_available_models(self) -> Sequence[ChatModelInfo[ModelName]]:
        # TODO: support embeddings
        return await self.get_available_chat_models()

    async def get_available_chat_models(self) -> Sequence[ChatModelInfo[ModelName]]:
        models = []
        for provider in self.get_available_providers():
            models.extend(await provider.get_available_chat_models())
        return models

    def get_token_limit(self, model_name: ModelName) -> int:
        """Get the token limit for a given model."""
        return self.get_model_provider(model_name).get_token_limit(
            model_name  # type: ignore
        )

    def get_tokenizer(self, model_name: ModelName) -> ModelTokenizer[Any]:
        return self.get_model_provider(model_name).get_tokenizer(
            model_name  # type: ignore
        )

    def count_tokens(self, text: str, model_name: ModelName) -> int:
        return self.get_model_provider(model_name).count_tokens(
            text=text, model_name=model_name  # type: ignore
        )

    def count_message_tokens(
        self, messages: ChatMessage | list[ChatMessage], model_name: ModelName
    ) -> int:
        return self.get_model_provider(model_name).count_message_tokens(
            messages=messages, model_name=model_name  # type: ignore
        )

    async def create_chat_completion(
        self,
        model_prompt: list[ChatMessage],
        model_name: ModelName,
        completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
        functions: Optional[list[CompletionModelFunction]] = None,
        max_output_tokens: Optional[int] = None,
        prefill_response: str = "",
        **kwargs,
    ) -> ChatModelResponse[_T]:
        """Create a completion using the API of the provider that serves the model."""
        return await self.get_model_provider(model_name).create_chat_completion(
            model_prompt=model_prompt,
            model_name=model_name,  # type: ignore
            completion_parser=completion_parser,
            functions=functions,
            max_output_tokens=max_output_tokens,
            prefill_response=prefill_response,
            **kwargs,
        )

    def get_model_provider(self, model: ModelName) -> ChatModelProvider:
        model_info = CHAT_MODELS[model]
        return self._get_provider(model_info.provider_name)

    def get_available_providers(self) -> Iterator[ChatModelProvider]:
        for provider_name in ModelProviderName:
            try:
                yield self._get_provider(provider_name)
            except Exception:
                pass

    def _get_provider(self, provider_name: ModelProviderName) -> ChatModelProvider:
        _provider = self._provider_instances.get(provider_name)
        if not _provider:
            Provider = self._get_provider_class(provider_name)
            settings = Provider.default_settings.copy(deep=True)
            settings.budget = self._budget
            settings.configuration.extra_request_headers.update(
                self._settings.configuration.extra_request_headers
            )
            if settings.credentials is None:
                try:
                    Credentials = settings.__fields__["credentials"].type_
                    settings.credentials = Credentials.from_env()
                except ValidationError as e:
                    raise ValueError(
                        f"{provider_name} is unavailable: can't load credentials"
                    ) from e

            self._provider_instances[provider_name] = _provider = Provider(
                settings=settings, logger=self._logger  # type: ignore
            )
            _provider._budget = self._budget  # Object binding not preserved by Pydantic
        return _provider

    @classmethod
    def _get_provider_class(
        cls, provider_name: ModelProviderName
    ) -> type[AnthropicProvider | GroqProvider | OpenAIProvider]:
        try:
            return {
                ModelProviderName.ANTHROPIC: AnthropicProvider,
                ModelProviderName.GROQ: GroqProvider,
                ModelProviderName.OPENAI: OpenAIProvider,
            }[provider_name]
        except KeyError:
            raise ValueError(f"{provider_name} is not a known provider") from None

    def __repr__(self):
        return f"{self.__class__.__name__}()"
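
# Illustrative dispatch sketch (model choice is an assumption): MultiProvider
# looks the model up in CHAT_MODELS and lazily instantiates the right backend:
#   provider = MultiProvider()
#   response = await provider.create_chat_completion(
#       model_prompt=[ChatMessage.user("Hello")],
#       model_name=GroqModelName.LLAMA3_8B,
#   )
# All backends share the MultiProvider's budget, so spend is tracked centrally.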


ChatModelProvider = AnthropicProvider | GroqProvider | OpenAIProvider | MultiProvider
629
forge/forge/llm/providers/openai.py
Normal file
@@ -0,0 +1,629 @@
import enum
import inspect
import logging
import os
from pathlib import Path
from typing import Any, Callable, Iterator, Mapping, Optional, ParamSpec, TypeVar, cast

import tenacity
import tiktoken
import yaml
from openai._exceptions import APIStatusError, RateLimitError
from openai.types import EmbeddingCreateParams
from openai.types.chat import (
    ChatCompletionMessage,
    ChatCompletionMessageParam,
    CompletionCreateParams,
)
from pydantic import SecretStr

from forge.json.parsing import json_loads
from forge.models.config import UserConfigurable
from forge.models.json_schema import JSONSchema

from ._openai_base import BaseOpenAIChatProvider, BaseOpenAIEmbeddingProvider
from .schema import (
    AssistantToolCall,
    AssistantToolCallDict,
    ChatMessage,
    ChatModelInfo,
    CompletionModelFunction,
    Embedding,
    EmbeddingModelInfo,
    ModelProviderBudget,
    ModelProviderConfiguration,
    ModelProviderCredentials,
    ModelProviderName,
    ModelProviderSettings,
    ModelTokenizer,
)

_T = TypeVar("_T")
_P = ParamSpec("_P")

OpenAIEmbeddingParser = Callable[[Embedding], Embedding]


class OpenAIModelName(str, enum.Enum):
    EMBEDDING_v2 = "text-embedding-ada-002"
    EMBEDDING_v3_S = "text-embedding-3-small"
    EMBEDDING_v3_L = "text-embedding-3-large"

    GPT3_v1 = "gpt-3.5-turbo-0301"
    GPT3_v2 = "gpt-3.5-turbo-0613"
    GPT3_v2_16k = "gpt-3.5-turbo-16k-0613"
    GPT3_v3 = "gpt-3.5-turbo-1106"
    GPT3_v4 = "gpt-3.5-turbo-0125"
    GPT3_ROLLING = "gpt-3.5-turbo"
    GPT3_ROLLING_16k = "gpt-3.5-turbo-16k"
    GPT3 = GPT3_ROLLING
    GPT3_16k = GPT3_ROLLING_16k

    GPT4_v1 = "gpt-4-0314"
    GPT4_v1_32k = "gpt-4-32k-0314"
    GPT4_v2 = "gpt-4-0613"
    GPT4_v2_32k = "gpt-4-32k-0613"
    GPT4_v3 = "gpt-4-1106-preview"
    GPT4_v3_VISION = "gpt-4-1106-vision-preview"
    GPT4_v4 = "gpt-4-0125-preview"
    GPT4_v5 = "gpt-4-turbo-2024-04-09"
    GPT4_ROLLING = "gpt-4"
    GPT4_ROLLING_32k = "gpt-4-32k"
    GPT4_TURBO = "gpt-4-turbo"
    GPT4_TURBO_PREVIEW = "gpt-4-turbo-preview"
    GPT4_VISION = "gpt-4-vision-preview"
    GPT4_O_v1 = "gpt-4o-2024-05-13"
    GPT4_O_ROLLING = "gpt-4o"
    GPT4 = GPT4_ROLLING
    GPT4_32k = GPT4_ROLLING_32k
    GPT4_O = GPT4_O_ROLLING


OPEN_AI_EMBEDDING_MODELS = {
    info.name: info
    for info in [
        EmbeddingModelInfo(
            name=OpenAIModelName.EMBEDDING_v2,
            provider_name=ModelProviderName.OPENAI,
            prompt_token_cost=0.0001 / 1000,
            max_tokens=8191,
            embedding_dimensions=1536,
        ),
        EmbeddingModelInfo(
            name=OpenAIModelName.EMBEDDING_v3_S,
            provider_name=ModelProviderName.OPENAI,
            prompt_token_cost=0.00002 / 1000,
            max_tokens=8191,
            embedding_dimensions=1536,
        ),
        EmbeddingModelInfo(
            name=OpenAIModelName.EMBEDDING_v3_L,
            provider_name=ModelProviderName.OPENAI,
            prompt_token_cost=0.00013 / 1000,
            max_tokens=8191,
            embedding_dimensions=3072,
        ),
    ]
}


OPEN_AI_CHAT_MODELS = {
    info.name: info
    for info in [
        ChatModelInfo(
            name=OpenAIModelName.GPT3_v1,
            provider_name=ModelProviderName.OPENAI,
            prompt_token_cost=0.0015 / 1000,
            completion_token_cost=0.002 / 1000,
            max_tokens=4096,
            has_function_call_api=True,
        ),
        ChatModelInfo(
            name=OpenAIModelName.GPT3_v2_16k,
            provider_name=ModelProviderName.OPENAI,
            prompt_token_cost=0.003 / 1000,
            completion_token_cost=0.004 / 1000,
            max_tokens=16384,
            has_function_call_api=True,
        ),
        ChatModelInfo(
            name=OpenAIModelName.GPT3_v3,
            provider_name=ModelProviderName.OPENAI,
            prompt_token_cost=0.001 / 1000,
            completion_token_cost=0.002 / 1000,
            max_tokens=16384,
            has_function_call_api=True,
        ),
        ChatModelInfo(
            name=OpenAIModelName.GPT3_v4,
            provider_name=ModelProviderName.OPENAI,
            prompt_token_cost=0.0005 / 1000,
            completion_token_cost=0.0015 / 1000,
            max_tokens=16384,
            has_function_call_api=True,
        ),
        ChatModelInfo(
            name=OpenAIModelName.GPT4_v1,
            provider_name=ModelProviderName.OPENAI,
            prompt_token_cost=0.03 / 1000,
            completion_token_cost=0.06 / 1000,
            max_tokens=8191,
            has_function_call_api=True,
        ),
        ChatModelInfo(
            name=OpenAIModelName.GPT4_v1_32k,
            provider_name=ModelProviderName.OPENAI,
            prompt_token_cost=0.06 / 1000,
            completion_token_cost=0.12 / 1000,
            max_tokens=32768,
            has_function_call_api=True,
        ),
        ChatModelInfo(
            name=OpenAIModelName.GPT4_TURBO,
            provider_name=ModelProviderName.OPENAI,
            prompt_token_cost=0.01 / 1000,
            completion_token_cost=0.03 / 1000,
            max_tokens=128000,
            has_function_call_api=True,
        ),
        ChatModelInfo(
            name=OpenAIModelName.GPT4_O,
            provider_name=ModelProviderName.OPENAI,
            prompt_token_cost=5 / 1_000_000,
            completion_token_cost=15 / 1_000_000,
            max_tokens=128_000,
            has_function_call_api=True,
        ),
    ]
}
# Copy entries for models with equivalent specs
chat_model_mapping = {
    OpenAIModelName.GPT3_v1: [OpenAIModelName.GPT3_v2],
    OpenAIModelName.GPT3_v2_16k: [OpenAIModelName.GPT3_16k],
    OpenAIModelName.GPT3_v4: [OpenAIModelName.GPT3_ROLLING],
    OpenAIModelName.GPT4_v1: [OpenAIModelName.GPT4_v2, OpenAIModelName.GPT4_ROLLING],
    OpenAIModelName.GPT4_v1_32k: [
        OpenAIModelName.GPT4_v2_32k,
        OpenAIModelName.GPT4_32k,
    ],
    OpenAIModelName.GPT4_TURBO: [
        OpenAIModelName.GPT4_v3,
        OpenAIModelName.GPT4_v3_VISION,
        OpenAIModelName.GPT4_VISION,
        OpenAIModelName.GPT4_v4,
        OpenAIModelName.GPT4_TURBO_PREVIEW,
        OpenAIModelName.GPT4_v5,
    ],
    OpenAIModelName.GPT4_O: [OpenAIModelName.GPT4_O_v1],
}
for base, copies in chat_model_mapping.items():
    for copy in copies:
        copy_info = OPEN_AI_CHAT_MODELS[base].copy(update={"name": copy})
        OPEN_AI_CHAT_MODELS[copy] = copy_info
        if copy.endswith(("-0301", "-0314")):
            copy_info.has_function_call_api = False


OPEN_AI_MODELS: Mapping[
    OpenAIModelName,
    ChatModelInfo[OpenAIModelName] | EmbeddingModelInfo[OpenAIModelName],
] = {
    **OPEN_AI_CHAT_MODELS,
    **OPEN_AI_EMBEDDING_MODELS,
}


class OpenAICredentials(ModelProviderCredentials):
    """Credentials for OpenAI."""

    api_key: SecretStr = UserConfigurable(from_env="OPENAI_API_KEY")  # type: ignore
    api_base: Optional[SecretStr] = UserConfigurable(
        default=None, from_env="OPENAI_API_BASE_URL"
    )
    organization: Optional[SecretStr] = UserConfigurable(from_env="OPENAI_ORGANIZATION")

    api_type: Optional[SecretStr] = UserConfigurable(
        default=None,
        from_env=lambda: cast(
            SecretStr | None,
            "azure"
            if os.getenv("USE_AZURE") == "True"
            else os.getenv("OPENAI_API_TYPE"),
        ),
    )
    api_version: Optional[SecretStr] = UserConfigurable(
        default=None, from_env="OPENAI_API_VERSION"
    )
    azure_endpoint: Optional[SecretStr] = None
    azure_model_to_deploy_id_map: Optional[dict[str, str]] = None

    def get_api_access_kwargs(self) -> dict[str, str]:
        kwargs = {
            k: v.get_secret_value()
            for k, v in {
                "api_key": self.api_key,
                "base_url": self.api_base,
                "organization": self.organization,
                "api_version": self.api_version,
            }.items()
            if v is not None
        }
        if self.api_type == SecretStr("azure"):
            assert self.azure_endpoint, "Azure endpoint not configured"
            kwargs["azure_endpoint"] = self.azure_endpoint.get_secret_value()
        return kwargs

    def get_model_access_kwargs(self, model: str) -> dict[str, str]:
        kwargs = {"model": model}
        if self.api_type == SecretStr("azure") and model:
            azure_kwargs = self._get_azure_access_kwargs(model)
            kwargs.update(azure_kwargs)
        return kwargs

    def load_azure_config(self, config_file: Path) -> None:
        with open(config_file) as file:
            config_params = yaml.load(file, Loader=yaml.SafeLoader) or {}

        try:
            assert config_params.get(
                "azure_model_map", {}
            ), "Azure model->deployment_id map is empty"
        except AssertionError as e:
            raise ValueError(*e.args)

        self.api_type = config_params.get("azure_api_type", "azure")
        self.api_version = config_params.get("azure_api_version", None)
        self.azure_endpoint = config_params.get("azure_endpoint")
        self.azure_model_to_deploy_id_map = config_params.get("azure_model_map")

    def _get_azure_access_kwargs(self, model: str) -> dict[str, str]:
        """Get the kwargs for the Azure API."""

        if not self.azure_model_to_deploy_id_map:
            raise ValueError("Azure model deployment map not configured")

        if model not in self.azure_model_to_deploy_id_map:
            raise ValueError(f"No Azure deployment ID configured for model '{model}'")
        deployment_id = self.azure_model_to_deploy_id_map[model]

        return {"model": deployment_id}


class OpenAISettings(ModelProviderSettings):
    credentials: Optional[OpenAICredentials]  # type: ignore
    budget: ModelProviderBudget  # type: ignore


class OpenAIProvider(
    BaseOpenAIChatProvider[OpenAIModelName, OpenAISettings],
    BaseOpenAIEmbeddingProvider[OpenAIModelName, OpenAISettings],
):
    MODELS = OPEN_AI_MODELS
    CHAT_MODELS = OPEN_AI_CHAT_MODELS
    EMBEDDING_MODELS = OPEN_AI_EMBEDDING_MODELS

    default_settings = OpenAISettings(
        name="openai_provider",
        description="Provides access to OpenAI's API.",
        configuration=ModelProviderConfiguration(),
        credentials=None,
        budget=ModelProviderBudget(),
    )

    _settings: OpenAISettings
    _configuration: ModelProviderConfiguration
    _credentials: OpenAICredentials
    _budget: ModelProviderBudget

    def __init__(
        self,
        settings: Optional[OpenAISettings] = None,
        logger: Optional[logging.Logger] = None,
    ):
        super(OpenAIProvider, self).__init__(settings=settings, logger=logger)

        if self._credentials.api_type == SecretStr("azure"):
            from openai import AsyncAzureOpenAI

            # API key and org (if configured) are passed; the rest of the required
            # credentials are loaded from the environment by the AzureOpenAI client.
            self._client = AsyncAzureOpenAI(
                **self._credentials.get_api_access_kwargs()  # type: ignore
            )
        else:
            from openai import AsyncOpenAI

            self._client = AsyncOpenAI(
                **self._credentials.get_api_access_kwargs()  # type: ignore
            )

    def get_tokenizer(self, model_name: OpenAIModelName) -> ModelTokenizer[int]:
        return tiktoken.encoding_for_model(model_name)

    def count_message_tokens(
        self,
        messages: ChatMessage | list[ChatMessage],
        model_name: OpenAIModelName,
    ) -> int:
        if isinstance(messages, ChatMessage):
            messages = [messages]

        if model_name.startswith("gpt-3.5-turbo"):
            tokens_per_message = (
                4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
            )
            tokens_per_name = -1  # if there's a name, the role is omitted
            # TODO: check if this is still valid for gpt-4o
        elif model_name.startswith("gpt-4"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"count_message_tokens() is not implemented for model {model_name}.\n"
                "See https://github.com/openai/openai-python/blob/120d225b91a8453e15240a49fb1c6794d8119326/chatml.md "  # noqa
                "for information on how messages are converted to tokens."
            )
        tokenizer = self.get_tokenizer(model_name)

        num_tokens = 0
        for message in messages:
            num_tokens += tokens_per_message
            for key, value in message.dict().items():
                num_tokens += len(tokenizer.encode(value))
                if key == "name":
                    num_tokens += tokens_per_name
        num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
        return num_tokens
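
    # Worked example (token counts are illustrative): for one GPT-4 user
    # message with content "hello", the total is
    #   3 (per-message) + encode("user") + encode("hello") + 3 (reply priming)
    # i.e. roughly 3 + 1 + 1 + 3 = 8 tokens.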

    def _get_chat_completion_args(
        self,
        prompt_messages: list[ChatMessage],
        model: OpenAIModelName,
        functions: Optional[list[CompletionModelFunction]] = None,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> tuple[
        list[ChatCompletionMessageParam], CompletionCreateParams, dict[str, Any]
    ]:
        """Prepare keyword arguments for an OpenAI chat completion call

        Args:
            prompt_messages: List of ChatMessages
            model: The model to use
            functions (optional): List of functions available to the LLM
            max_output_tokens (optional): Maximum number of tokens to generate

        Returns:
            list[ChatCompletionMessageParam]: Prompt messages for the OpenAI call
            CompletionCreateParams: Mapping of other kwargs for the OpenAI call
            Mapping[str, Any]: Any keyword arguments to pass on to the completion parser
        """
        tools_compat_mode = False
        if functions:
            if not OPEN_AI_CHAT_MODELS[model].has_function_call_api:
                # Provide compatibility with older models
                _functions_compat_fix_kwargs(functions, prompt_messages)
                tools_compat_mode = True
                functions = None

        openai_messages, kwargs, parse_kwargs = super()._get_chat_completion_args(
            prompt_messages=prompt_messages,
            model=model,
            functions=functions,
            max_output_tokens=max_output_tokens,
            **kwargs,
        )
        kwargs.update(self._credentials.get_model_access_kwargs(model))  # type: ignore

        if tools_compat_mode:
            parse_kwargs["compat_mode"] = True

        return openai_messages, kwargs, parse_kwargs

    def _parse_assistant_tool_calls(
        self,
        assistant_message: ChatCompletionMessage,
        compat_mode: bool = False,
        **kwargs,
    ) -> tuple[list[AssistantToolCall], list[Exception]]:
        tool_calls: list[AssistantToolCall] = []
        parse_errors: list[Exception] = []

        if not compat_mode:
            return super()._parse_assistant_tool_calls(
                assistant_message=assistant_message, compat_mode=compat_mode, **kwargs
            )
        elif assistant_message.content:
            try:
                tool_calls = list(
                    _tool_calls_compat_extract_calls(assistant_message.content)
                )
            except Exception as e:
                parse_errors.append(e)

        return tool_calls, parse_errors

    def _get_embedding_kwargs(
        self, input: str | list[str], model: OpenAIModelName, **kwargs
    ) -> EmbeddingCreateParams:
        kwargs = super()._get_embedding_kwargs(input=input, model=model, **kwargs)
        kwargs.update(self._credentials.get_model_access_kwargs(model))  # type: ignore
        return kwargs

    _get_embedding_kwargs.__doc__ = (
        BaseOpenAIEmbeddingProvider._get_embedding_kwargs.__doc__
    )

    def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
        _log_retry_debug_message = tenacity.after_log(self._logger, logging.DEBUG)

        def _log_on_fail(retry_state: tenacity.RetryCallState) -> None:
            _log_retry_debug_message(retry_state)

            if (
                retry_state.attempt_number == 0
                and retry_state.outcome
                and isinstance(retry_state.outcome.exception(), RateLimitError)
            ):
                self._logger.warning(
                    "Please double check that you have setup a PAID OpenAI API Account."
                    " You can read more here: "
                    "https://docs.agpt.co/setup/#getting-an-openai-api-key"
                )

        return tenacity.retry(
            retry=(
                tenacity.retry_if_exception_type(RateLimitError)
                | tenacity.retry_if_exception(
                    lambda e: isinstance(e, APIStatusError) and e.status_code == 502
                )
            ),
            wait=tenacity.wait_exponential(),
            stop=tenacity.stop_after_attempt(self._configuration.retries_per_request),
            after=_log_on_fail,
        )(func)

    def __repr__(self):
        return "OpenAIProvider()"


def format_function_specs_as_typescript_ns(
    functions: list[CompletionModelFunction],
) -> str:
    """Returns a function signature block in the format used by OpenAI internally:
    https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18

    For use with `count_tokens` to determine token usage of provided functions.

    Example:
    ```ts
    namespace functions {

    // Get the current weather in a given location
    type get_current_weather = (_: {
    // The city and state, e.g. San Francisco, CA
    location: string,
    unit?: "celsius" | "fahrenheit",
    }) => any;

    } // namespace functions
    ```
    """

    return (
        "namespace functions {\n\n"
        + "\n\n".join(format_openai_function_for_prompt(f) for f in functions)
        + "\n\n} // namespace functions"
    )


def format_openai_function_for_prompt(func: CompletionModelFunction) -> str:
    """Returns the function formatted similarly to the way OpenAI does it internally:
    https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18

    Example:
    ```ts
    // Get the current weather in a given location
    type get_current_weather = (_: {
    // The city and state, e.g. San Francisco, CA
    location: string,
    unit?: "celsius" | "fahrenheit",
    }) => any;
    ```
    """

    def param_signature(name: str, spec: JSONSchema) -> str:
        return (
            f"// {spec.description}\n" if spec.description else ""
        ) + f"{name}{'' if spec.required else '?'}: {spec.typescript_type},"

    return "\n".join(
        [
            f"// {func.description}",
            f"type {func.name} = (_ :{{",
            *[param_signature(name, p) for name, p in func.parameters.items()],
            "}) => any;",
        ]
    )


def count_openai_functions_tokens(
    functions: list[CompletionModelFunction], count_tokens: Callable[[str], int]
) -> int:
    """Returns the number of tokens taken up by a set of function definitions

    Reference: https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18  # noqa: E501
    """
    return count_tokens(
        "# Tools\n\n"
        "## functions\n\n"
        f"{format_function_specs_as_typescript_ns(functions)}"
    )


def _functions_compat_fix_kwargs(
    functions: list[CompletionModelFunction],
    prompt_messages: list[ChatMessage],
):
    function_definitions = format_function_specs_as_typescript_ns(functions)
    function_call_schema = JSONSchema(
        type=JSONSchema.Type.OBJECT,
        properties={
            "name": JSONSchema(
                description="The name of the function to call",
                enum=[f.name for f in functions],
                required=True,
            ),
            "arguments": JSONSchema(
                description="The arguments for the function call",
                type=JSONSchema.Type.OBJECT,
                required=True,
            ),
        },
    )
    tool_calls_schema = JSONSchema(
        type=JSONSchema.Type.ARRAY,
        items=JSONSchema(
            type=JSONSchema.Type.OBJECT,
            properties={
                "type": JSONSchema(
                    type=JSONSchema.Type.STRING,
                    enum=["function"],
                ),
                "function": function_call_schema,
            },
        ),
    )
    prompt_messages.append(
        ChatMessage.system(
            "# tool usage instructions\n\n"
            "Specify a '```tool_calls' block in your response,"
            " with a valid JSON object that adheres to the following schema:\n\n"
            f"{tool_calls_schema.to_dict()}\n\n"
            "Specify any tools that you need to use through this JSON object.\n\n"
            "Put the tool_calls block at the end of your response"
            " and include its fences if it is not the only content.\n\n"
            "## functions\n\n"
            "For the function call itself, use one of the following"
            f" functions:\n\n{function_definitions}"
        ),
    )


def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCall]:
    import re
    import uuid

    logging.debug(f"Trying to extract tool calls from response:\n{response}")

    if response[0] == "[":
        tool_calls: list[AssistantToolCallDict] = json_loads(response)
    else:
        block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", response, re.DOTALL)
        if not block:
            raise ValueError("Could not find tool_calls block in response")
        tool_calls: list[AssistantToolCallDict] = json_loads(block.group(1))

    for t in tool_calls:
        t["id"] = str(uuid.uuid4())
        yield AssistantToolCall.parse_obj(t)
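
# Illustrative input this compat path accepts (function name is an assumption):
#   I'll look that up.
#   ```tool_calls
#   [{"type": "function",
#     "function": {"name": "web_search", "arguments": {"query": "weather"}}}]
#   ```
# Each extracted call is assigned a fresh uuid4 id, since models without a
# native tool-call API don't supply one.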
460
forge/forge/llm/providers/schema.py
Normal file
@@ -0,0 +1,460 @@
import abc
import enum
import logging
import math
from collections import defaultdict
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Generic,
    Literal,
    Optional,
    Protocol,
    Sequence,
    TypedDict,
    TypeVar,
)

from pydantic import BaseModel, Field, SecretStr

from forge.logging.utils import fmt_kwargs
from forge.models.config import (
    Configurable,
    SystemConfiguration,
    SystemSettings,
    UserConfigurable,
)
from forge.models.json_schema import JSONSchema
from forge.models.providers import (
    Embedding,
    ProviderBudget,
    ProviderCredentials,
    ResourceType,
)

if TYPE_CHECKING:
    from jsonschema import ValidationError


_T = TypeVar("_T")

_ModelName = TypeVar("_ModelName", bound=str)


class ModelProviderService(str, enum.Enum):
    """A ModelService describes what kind of service the model provides."""

    EMBEDDING = "embedding"
    CHAT = "chat_completion"
    TEXT = "text_completion"


class ModelProviderName(str, enum.Enum):
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    GROQ = "groq"


class ChatMessage(BaseModel):
    class Role(str, enum.Enum):
        USER = "user"
        SYSTEM = "system"
        ASSISTANT = "assistant"

        TOOL = "tool"
        """May be used for the result of tool calls"""
        FUNCTION = "function"
        """May be used for the return value of function calls"""

    role: Role
    content: str

    @staticmethod
    def user(content: str) -> "ChatMessage":
        return ChatMessage(role=ChatMessage.Role.USER, content=content)

    @staticmethod
    def system(content: str) -> "ChatMessage":
        return ChatMessage(role=ChatMessage.Role.SYSTEM, content=content)
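
    # Illustrative equivalence: the static constructors are shorthand, e.g.
    #   ChatMessage.user("hi") == ChatMessage(role=ChatMessage.Role.USER, content="hi")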


class ChatMessageDict(TypedDict):
    role: str
    content: str


class AssistantFunctionCall(BaseModel):
    name: str
    arguments: dict[str, Any]

    def __str__(self) -> str:
        return f"{self.name}({fmt_kwargs(self.arguments)})"


class AssistantFunctionCallDict(TypedDict):
    name: str
    arguments: dict[str, Any]


class AssistantToolCall(BaseModel):
    id: str
    type: Literal["function"]
    function: AssistantFunctionCall


class AssistantToolCallDict(TypedDict):
    id: str
    type: Literal["function"]
    function: AssistantFunctionCallDict


class AssistantChatMessage(ChatMessage):
    role: Literal[ChatMessage.Role.ASSISTANT] = ChatMessage.Role.ASSISTANT  # type: ignore # noqa
    content: str = ""
    tool_calls: Optional[list[AssistantToolCall]] = None


class ToolResultMessage(ChatMessage):
    role: Literal[ChatMessage.Role.TOOL] = ChatMessage.Role.TOOL  # type: ignore
    is_error: bool = False
    tool_call_id: str


class AssistantChatMessageDict(TypedDict, total=False):
    role: str
    content: str
    tool_calls: list[AssistantToolCallDict]


class CompletionModelFunction(BaseModel):
    """General representation object for LLM-callable functions."""

    name: str
    description: str
    parameters: dict[str, "JSONSchema"]
    return_type: str | None = None
    is_async: bool = False

    def fmt_line(self) -> str:
        params = ", ".join(
            f"{name}{'?' if not p.required else ''}: " f"{p.typescript_type}"
            for name, p in self.parameters.items()
        )
        return f"{self.name}: {self.description}. Params: ({params})"

    def fmt_header(self, impl="pass", force_async=False) -> str:
        """
        Formats and returns the function header as a string with types and descriptions.

        Returns:
            str: The formatted function header.
        """
        def indent(content: str, spaces: int = 4):
            return " " * spaces + content.replace("\n", "\n" + " " * spaces)

        params = ", ".join(
            f"{name}: {p.python_type}{f'= {str(p.default)}' if p.default else ' = None' if not p.required else ''}"
            for name, p in self.parameters.items()
        )
        func = "async def" if self.is_async or force_async else "def"
        return_str = f" -> {self.return_type}" if self.return_type else ""
        return f"{func} {self.name}({params}){return_str}:\n" + indent(
            (
                '"""\n'
                f"{self.description}\n\n"
                "Params:\n"
                + indent(
                    "\n".join(
                        f"{name}: {param.description}"
                        for name, param in self.parameters.items()
                        if param.description
                    )
                )
                + "\n"
                '"""\n'
                f"{impl}"
            ),
        )

    def validate_call(
        self, function_call: AssistantFunctionCall
    ) -> tuple[bool, list["ValidationError"]]:
        """
        Validates the given function call against the function's parameter specs

        Returns:
            bool: Whether the given set of arguments is valid for this command
            list[ValidationError]: Issues with the set of arguments (if any)

        Raises:
            ValueError: If the function_call doesn't call this function
        """
        if function_call.name != self.name:
            raise ValueError(
                f"Can't validate {function_call.name} call using {self.name} spec"
            )

        params_schema = JSONSchema(
            type=JSONSchema.Type.OBJECT,
            properties={name: spec for name, spec in self.parameters.items()},
        )
        return params_schema.validate_object(function_call.arguments)


class ModelInfo(BaseModel, Generic[_ModelName]):
    """Struct for model information.

    Would be lovely to eventually get this directly from APIs, but needs to be
    scraped from websites for now.
    """

    name: _ModelName
    service: ClassVar[ModelProviderService]
    provider_name: ModelProviderName
    prompt_token_cost: float = 0.0
    completion_token_cost: float = 0.0


class ModelResponse(BaseModel):
    """Standard response struct for a response from a model."""

    prompt_tokens_used: int
    completion_tokens_used: int
    model_info: ModelInfo


class ModelProviderConfiguration(SystemConfiguration):
    retries_per_request: int = UserConfigurable(7)
    fix_failed_parse_tries: int = UserConfigurable(3)
    extra_request_headers: dict[str, str] = Field(default_factory=dict)


class ModelProviderCredentials(ProviderCredentials):
    """Credentials for a model provider."""

    api_key: SecretStr | None = UserConfigurable(default=None)
    api_type: SecretStr | None = UserConfigurable(default=None)
    api_base: SecretStr | None = UserConfigurable(default=None)
    api_version: SecretStr | None = UserConfigurable(default=None)
    deployment_id: SecretStr | None = UserConfigurable(default=None)

    class Config(ProviderCredentials.Config):
        extra = "ignore"


class ModelProviderUsage(BaseModel):
    """Usage for a particular model from a model provider."""

    class ModelUsage(BaseModel):
        completion_tokens: int = 0
        prompt_tokens: int = 0

    usage_per_model: dict[str, ModelUsage] = defaultdict(ModelUsage)

    @property
    def completion_tokens(self) -> int:
        return sum(model.completion_tokens for model in self.usage_per_model.values())

    @property
    def prompt_tokens(self) -> int:
        return sum(model.prompt_tokens for model in self.usage_per_model.values())

    def update_usage(
        self,
        model: str,
        input_tokens_used: int,
        output_tokens_used: int = 0,
    ) -> None:
        self.usage_per_model[model].prompt_tokens += input_tokens_used
        self.usage_per_model[model].completion_tokens += output_tokens_used


class ModelProviderBudget(ProviderBudget[ModelProviderUsage]):
    usage: ModelProviderUsage = Field(default_factory=ModelProviderUsage)

    def update_usage_and_cost(
        self,
        model_info: ModelInfo,
        input_tokens_used: int,
        output_tokens_used: int = 0,
    ) -> float:
        """Update the usage and cost of the provider.

        Returns:
            float: The (calculated) cost of the given model response.
        """
        self.usage.update_usage(model_info.name, input_tokens_used, output_tokens_used)
        incurred_cost = (
            output_tokens_used * model_info.completion_token_cost
            + input_tokens_used * model_info.prompt_token_cost
        )
        self.total_cost += incurred_cost
        self.remaining_budget -= incurred_cost
        return incurred_cost
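
    # Worked example (illustrative prices): with prompt_token_cost=0.01/1000
    # and completion_token_cost=0.03/1000, a call using 2,000 input and 500
    # output tokens incurs
    #   2000 * 0.00001 + 500 * 0.00003 = $0.035,
    # added to total_cost and subtracted from remaining_budget.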


class ModelProviderSettings(SystemSettings):
    resource_type: ClassVar[ResourceType] = ResourceType.MODEL
    configuration: ModelProviderConfiguration
    credentials: Optional[ModelProviderCredentials] = None
    budget: Optional[ModelProviderBudget] = None


_ModelProviderSettings = TypeVar("_ModelProviderSettings", bound=ModelProviderSettings)


# TODO: either use MultiProvider throughout codebase as type for `llm_provider`, or
# replace `_ModelName` by `str` to eliminate type checking difficulties
class BaseModelProvider(
    abc.ABC,
    Generic[_ModelName, _ModelProviderSettings],
    Configurable[_ModelProviderSettings],
):
    """A ModelProvider abstracts the details of a particular provider of models."""

    default_settings: ClassVar[_ModelProviderSettings]  # type: ignore

    _settings: _ModelProviderSettings
    _logger: logging.Logger

    def __init__(
        self,
        settings: Optional[_ModelProviderSettings] = None,
        logger: Optional[logging.Logger] = None,
    ):
        if not settings:
            settings = self.default_settings.copy(deep=True)

        self._settings = settings
        self._configuration = settings.configuration
        self._credentials = settings.credentials
        self._budget = settings.budget

        self._logger = logger or logging.getLogger(self.__module__)

    @abc.abstractmethod
    async def get_available_models(
        self,
    ) -> Sequence["ChatModelInfo[_ModelName] | EmbeddingModelInfo[_ModelName]"]:
        ...

    @abc.abstractmethod
    def count_tokens(self, text: str, model_name: _ModelName) -> int:
        ...

    @abc.abstractmethod
    def get_tokenizer(self, model_name: _ModelName) -> "ModelTokenizer[Any]":
        ...

    @abc.abstractmethod
    def get_token_limit(self, model_name: _ModelName) -> int:
        ...

    def get_incurred_cost(self) -> float:
        if self._budget:
            return self._budget.total_cost
        return 0

    def get_remaining_budget(self) -> float:
        if self._budget:
            return self._budget.remaining_budget
        return math.inf


class ModelTokenizer(Protocol, Generic[_T]):
    """A ModelTokenizer provides tokenization specific to a model."""

    @abc.abstractmethod
    def encode(self, text: str) -> list[_T]:
        ...

    @abc.abstractmethod
    def decode(self, tokens: list[_T]) -> str:
        ...


####################
# Embedding Models #
####################


class EmbeddingModelInfo(ModelInfo[_ModelName]):
    """Struct for embedding model information."""

    service: Literal[ModelProviderService.EMBEDDING] = ModelProviderService.EMBEDDING  # type: ignore # noqa
    max_tokens: int
    embedding_dimensions: int


class EmbeddingModelResponse(ModelResponse):
    """Standard response struct for a response from an embedding model."""

    embedding: Embedding = Field(default_factory=list)
    completion_tokens_used: int = Field(default=0, const=True)


class BaseEmbeddingModelProvider(BaseModelProvider[_ModelName, _ModelProviderSettings]):
    @abc.abstractmethod
    async def get_available_embedding_models(
        self,
    ) -> Sequence[EmbeddingModelInfo[_ModelName]]:
        ...

    @abc.abstractmethod
    async def create_embedding(
        self,
        text: str,
        model_name: _ModelName,
        embedding_parser: Callable[[Embedding], Embedding],
        **kwargs,
    ) -> EmbeddingModelResponse:
        ...


###############
# Chat Models #
###############


class ChatModelInfo(ModelInfo[_ModelName]):
    """Struct for language model information."""

    service: Literal[ModelProviderService.CHAT] = ModelProviderService.CHAT  # type: ignore # noqa
    max_tokens: int
    has_function_call_api: bool = False


class ChatModelResponse(ModelResponse, Generic[_T]):
    """Standard response struct for a response from a language model."""

    response: AssistantChatMessage
    parsed_result: _T


class BaseChatModelProvider(BaseModelProvider[_ModelName, _ModelProviderSettings]):
    @abc.abstractmethod
    async def get_available_chat_models(self) -> Sequence[ChatModelInfo[_ModelName]]:
        ...

    @abc.abstractmethod
    def count_message_tokens(
        self,
        messages: ChatMessage | list[ChatMessage],
        model_name: _ModelName,
    ) -> int:
        ...

    @abc.abstractmethod
    async def create_chat_completion(
        self,
        model_prompt: list[ChatMessage],
        model_name: _ModelName,
        completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
        functions: Optional[list[CompletionModelFunction]] = None,
        max_output_tokens: Optional[int] = None,
        prefill_response: str = "",
        **kwargs,
    ) -> ChatModelResponse[_T]:
        ...
88
forge/forge/llm/providers/utils.py
Normal file
@@ -0,0 +1,88 @@
from typing import TYPE_CHECKING, Any, Iterable

if TYPE_CHECKING:
    from forge.command.command import Command

from .schema import AssistantToolCall, CompletionModelFunction


class InvalidFunctionCallError(Exception):
    def __init__(self, name: str, arguments: dict[str, Any], message: str):
        self.message = message
        self.name = name
        self.arguments = arguments
        super().__init__(message)

    def __str__(self) -> str:
        return f"Invalid function call for {self.name}: {self.message}"


def validate_tool_calls(
    tool_calls: list[AssistantToolCall], functions: list[CompletionModelFunction]
) -> list[InvalidFunctionCallError]:
    """
    Validates a list of tool calls against a list of functions.

    1. Tries to find a function matching each tool call
    2. If a matching function is found, validates the tool call's arguments,
       reporting any resulting errors
    3. If no matching function is found, an error "Unknown function X" is reported
    4. A list of all errors encountered during validation is returned

    Params:
        tool_calls: A list of tool calls to validate.
        functions: A list of functions to validate against.

    Returns:
        list[InvalidFunctionCallError]: All errors encountered during validation.
    """
    errors: list[InvalidFunctionCallError] = []
    for tool_call in tool_calls:
        function_call = tool_call.function

        if function := next(
            (f for f in functions if f.name == function_call.name),
            None,
        ):
            is_valid, validation_errors = function.validate_call(function_call)
            if not is_valid:
                fmt_errors = [
                    f"{'.'.join(str(p) for p in f.path)}: {f.message}"
                    if f.path
                    else f.message
                    for f in validation_errors
                ]
                errors.append(
                    InvalidFunctionCallError(
                        name=function_call.name,
                        arguments=function_call.arguments,
                        message=(
                            "The set of arguments supplied is invalid:\n"
                            + "\n".join(fmt_errors)
                        ),
                    )
                )
        else:
            errors.append(
                InvalidFunctionCallError(
                    name=function_call.name,
                    arguments=function_call.arguments,
                    message=f"Unknown function {function_call.name}",
                )
            )

    return errors
|
||||
|
||||
|
||||
def function_specs_from_commands(
|
||||
commands: Iterable["Command"],
|
||||
) -> list[CompletionModelFunction]:
|
||||
"""Get LLM-consumable function specs for the agent's available commands."""
|
||||
return [
|
||||
CompletionModelFunction(
|
||||
name=command.names[0],
|
||||
description=command.description,
|
||||
parameters={param.name: param.spec for param in command.parameters},
|
||||
)
|
||||
for command in commands
|
||||
]
|
||||
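A brief sketch of how these two helpers might be combined; `agent.commands` and `assistant_reply.tool_calls` are invented placeholders for illustration.

# Hedged sketch; names are placeholders, not part of the diff.
functions = function_specs_from_commands(agent.commands)
if errors := validate_tool_calls(assistant_reply.tool_calls, functions):
    raise ValueError("\n\n".join(str(e) for e in errors))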
9
forge/forge/logging/__init__.py
Normal file
@@ -0,0 +1,9 @@
from .config import configure_logging
from .filters import BelowLevelFilter
from .formatters import FancyConsoleFormatter

__all__ = [
    "configure_logging",
    "BelowLevelFilter",
    "FancyConsoleFormatter",
]
200
forge/forge/logging/config.py
Normal file
@@ -0,0 +1,200 @@
"""Logging module for Auto-GPT."""
from __future__ import annotations

import enum
import logging
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Optional

from openai._base_client import log as openai_logger

from forge.models.config import SystemConfiguration, UserConfigurable

if TYPE_CHECKING:
    from forge.speech import TTSConfig

from .filters import BelowLevelFilter
from .formatters import ForgeFormatter, StructuredLoggingFormatter
from .handlers import TTSHandler

LOG_DIR = Path(__file__).parent.parent.parent / "logs"
LOG_FILE = "activity.log"
DEBUG_LOG_FILE = "debug.log"
ERROR_LOG_FILE = "error.log"

SIMPLE_LOG_FORMAT = "%(asctime)s %(levelname)s %(title)s%(message)s"
DEBUG_LOG_FORMAT = (
    "%(asctime)s %(levelname)s %(filename)s:%(lineno)d %(title)s%(message)s"
)

SPEECH_OUTPUT_LOGGER = "VOICE"
USER_FRIENDLY_OUTPUT_LOGGER = "USER_FRIENDLY_OUTPUT"


class LogFormatName(str, enum.Enum):
    SIMPLE = "simple"
    DEBUG = "debug"
    STRUCTURED = "structured_google_cloud"


TEXT_LOG_FORMAT_MAP = {
    LogFormatName.DEBUG: DEBUG_LOG_FORMAT,
    LogFormatName.SIMPLE: SIMPLE_LOG_FORMAT,
}


class LoggingConfig(SystemConfiguration):
    level: int = UserConfigurable(
        default=logging.INFO,
        from_env=lambda: logging.getLevelName(os.getenv("LOG_LEVEL", "INFO")),
    )

    # Console output
    log_format: LogFormatName = UserConfigurable(
        default=LogFormatName.SIMPLE, from_env="LOG_FORMAT"
    )
    plain_console_output: bool = UserConfigurable(
        default=False,
        from_env=lambda: os.getenv("PLAIN_OUTPUT", "False") == "True",
    )

    # File output
    log_dir: Path = LOG_DIR
    log_file_format: Optional[LogFormatName] = UserConfigurable(
        default=LogFormatName.SIMPLE,
        from_env=lambda: os.getenv(  # type: ignore
            "LOG_FILE_FORMAT", os.getenv("LOG_FORMAT", "simple")
        ),
    )


def configure_logging(
    debug: bool = False,
    level: Optional[int | str] = None,
    log_dir: Optional[Path] = None,
    log_format: Optional[LogFormatName | str] = None,
    log_file_format: Optional[LogFormatName | str] = None,
    plain_console_output: Optional[bool] = None,
    config: Optional[LoggingConfig] = None,
    tts_config: Optional[TTSConfig] = None,
) -> None:
    """Configure the native logging module, based on the environment config and any
    specified overrides.

    Arguments override values specified in the environment.
    Overrides are also applied to `config`, if passed.

    Should be usable as `configure_logging(**config.logging.dict())`, where
    `config.logging` is a `LoggingConfig` object.
    """
    if debug and level:
        raise ValueError("Only one of either 'debug' and 'level' arguments may be set")

    # Parse arguments
    if isinstance(level, str):
        if type(_level := logging.getLevelName(level.upper())) is int:
            level = _level
        else:
            raise ValueError(f"Unknown log level '{level}'")
    if isinstance(log_format, str):
        if log_format in LogFormatName._value2member_map_:
            log_format = LogFormatName(log_format)
        elif not isinstance(log_format, LogFormatName):
            raise ValueError(f"Unknown log format '{log_format}'")
    if isinstance(log_file_format, str):
        if log_file_format in LogFormatName._value2member_map_:
            log_file_format = LogFormatName(log_file_format)
        elif not isinstance(log_file_format, LogFormatName):
            raise ValueError(f"Unknown log format '{log_file_format}'")

    config = config or LoggingConfig.from_env()

    # Aggregate env config + arguments
    config.level = logging.DEBUG if debug else level or config.level
    config.log_dir = log_dir or config.log_dir
    config.log_format = log_format or (
        LogFormatName.DEBUG if debug else config.log_format
    )
    config.log_file_format = log_file_format or log_format or config.log_file_format
    config.plain_console_output = (
        plain_console_output
        if plain_console_output is not None
        else config.plain_console_output
    )

    # Structured logging is used for cloud environments,
    # where logging to a file makes no sense.
    if config.log_format == LogFormatName.STRUCTURED:
        config.plain_console_output = True
        config.log_file_format = None

    # Create the log directory if it doesn't exist
    if not config.log_dir.exists():
        config.log_dir.mkdir()

    log_handlers: list[logging.Handler] = []

    if config.log_format in (LogFormatName.DEBUG, LogFormatName.SIMPLE):
        console_format_template = TEXT_LOG_FORMAT_MAP[config.log_format]
        console_formatter = ForgeFormatter(console_format_template)
    else:
        console_formatter = StructuredLoggingFormatter()
        console_format_template = SIMPLE_LOG_FORMAT

    # Console output handlers
    stdout = logging.StreamHandler(stream=sys.stdout)
    stdout.setLevel(config.level)
    stdout.addFilter(BelowLevelFilter(logging.WARNING))
    stdout.setFormatter(console_formatter)
    stderr = logging.StreamHandler()
    stderr.setLevel(logging.WARNING)
    stderr.setFormatter(console_formatter)
    log_handlers += [stdout, stderr]

    # File output handlers
    if config.log_file_format is not None:
        if config.level < logging.ERROR:
            file_output_format_template = TEXT_LOG_FORMAT_MAP[config.log_file_format]
            file_output_formatter = ForgeFormatter(
                file_output_format_template, no_color=True
            )

            # INFO log file handler
            activity_log_handler = logging.FileHandler(
                config.log_dir / LOG_FILE, "a", "utf-8"
            )
            activity_log_handler.setLevel(config.level)
            activity_log_handler.setFormatter(file_output_formatter)
            log_handlers += [activity_log_handler]

        # ERROR log file handler
        error_log_handler = logging.FileHandler(
            config.log_dir / ERROR_LOG_FILE, "a", "utf-8"
        )
        error_log_handler.setLevel(logging.ERROR)
        error_log_handler.setFormatter(ForgeFormatter(DEBUG_LOG_FORMAT, no_color=True))
        log_handlers += [error_log_handler]

    # Configure the root logger
    logging.basicConfig(
        format=console_format_template,
        level=config.level,
        handlers=log_handlers,
    )

    # Speech output
    speech_output_logger = logging.getLogger(SPEECH_OUTPUT_LOGGER)
    speech_output_logger.setLevel(logging.INFO)
    if tts_config:
        speech_output_logger.addHandler(TTSHandler(tts_config))
    speech_output_logger.propagate = False

    # JSON logger with better formatting
    json_logger = logging.getLogger("JSON_LOGGER")
    json_logger.setLevel(logging.DEBUG)
    json_logger.propagate = False

    # Disable debug logging from OpenAI library
    openai_logger.setLevel(logging.WARNING)
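As the docstring above indicates, configuration can come from the environment, a `LoggingConfig`, or keyword overrides. A minimal illustrative invocation (argument values are assumptions):

# Illustrative invocation; values are assumptions, not from the diff.
configure_logging(level="debug", plain_console_output=True)

logger = logging.getLogger(__name__)
logger.info("Agent ready", extra={"title": "STATUS"})  # `title` feeds the formatters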
12
forge/forge/logging/filters.py
Normal file
@@ -0,0 +1,12 @@
import logging


class BelowLevelFilter(logging.Filter):
    """Filter for logging levels below a certain threshold."""

    def __init__(self, below_level: int):
        super().__init__()
        self.below_level = below_level

    def filter(self, record: logging.LogRecord):
        return record.levelno < self.below_level
95
forge/forge/logging/formatters.py
Normal file
@@ -0,0 +1,95 @@
import logging

from colorama import Fore, Style
from google.cloud.logging_v2.handlers import CloudLoggingFilter, StructuredLogHandler

from .utils import remove_color_codes


class FancyConsoleFormatter(logging.Formatter):
    """
    A custom logging formatter designed for console output.

    This formatter enhances the standard logging output with color coding. The color
    coding is based on the level of the log message, making it easier to distinguish
    between different types of messages in the console output.

    The color for each level is defined in the LEVEL_COLOR_MAP class attribute.
    """

    # level -> (level & text color, title color)
    LEVEL_COLOR_MAP = {
        logging.DEBUG: Fore.LIGHTBLACK_EX,
        logging.INFO: Fore.BLUE,
        logging.WARNING: Fore.YELLOW,
        logging.ERROR: Fore.RED,
        logging.CRITICAL: Fore.RED + Style.BRIGHT,
    }

    def format(self, record: logging.LogRecord) -> str:
        # Make sure `msg` is a string
        if not hasattr(record, "msg"):
            record.msg = ""
        elif type(record.msg) is not str:
            record.msg = str(record.msg)

        # Determine default color based on error level
        level_color = ""
        if record.levelno in self.LEVEL_COLOR_MAP:
            level_color = self.LEVEL_COLOR_MAP[record.levelno]
            record.levelname = f"{level_color}{record.levelname}{Style.RESET_ALL}"

        # Determine color for message
        color = getattr(record, "color", level_color)
        color_is_specified = hasattr(record, "color")

        # Don't color INFO messages unless the color is explicitly specified.
        if color and (record.levelno != logging.INFO or color_is_specified):
            record.msg = f"{color}{record.msg}{Style.RESET_ALL}"

        return super().format(record)


class ForgeFormatter(FancyConsoleFormatter):
    def __init__(self, *args, no_color: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self.no_color = no_color

    def format(self, record: logging.LogRecord) -> str:
        # Make sure `msg` is a string
        if not hasattr(record, "msg"):
            record.msg = ""
        elif type(record.msg) is not str:
            record.msg = str(record.msg)

        # Strip color from the message to prevent color spoofing
        if record.msg and not getattr(record, "preserve_color", False):
            record.msg = remove_color_codes(record.msg)

        # Determine color for title
        title = getattr(record, "title", "")
        title_color = getattr(record, "title_color", "") or self.LEVEL_COLOR_MAP.get(
            record.levelno, ""
        )
        if title and title_color:
            title = f"{title_color + Style.BRIGHT}{title}{Style.RESET_ALL}"
        # Make sure record.title is set, and padded with a space if not empty
        record.title = f"{title} " if title else ""

        if self.no_color:
            return remove_color_codes(super().format(record))
        else:
            return super().format(record)


class StructuredLoggingFormatter(StructuredLogHandler, logging.Formatter):
    def __init__(self):
        # Set up CloudLoggingFilter to add diagnostic info to the log records
        self.cloud_logging_filter = CloudLoggingFilter()

        # Init StructuredLogHandler
        super().__init__()

    def format(self, record: logging.LogRecord) -> str:
        self.cloud_logging_filter.filter(record)
        return super().format(record)
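A short sketch of how the `title`, `title_color`, and `color` record attributes drive `ForgeFormatter`; the handler wiring here is assumed for illustration.

# Hedged wiring example; not part of the diff.
import logging
from colorama import Fore

handler = logging.StreamHandler()
handler.setFormatter(ForgeFormatter("%(levelname)s %(title)s%(message)s"))
log = logging.getLogger("demo")
log.addHandler(handler)
log.warning("disk almost full", extra={"title": "SYSTEM:", "title_color": Fore.CYAN})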
45
forge/forge/logging/handlers.py
Normal file
@@ -0,0 +1,45 @@
from __future__ import annotations

import json
import logging
from typing import TYPE_CHECKING

from forge.logging.utils import remove_color_codes
from forge.speech import TextToSpeechProvider

if TYPE_CHECKING:
    from forge.speech import TTSConfig


class TTSHandler(logging.Handler):
    """Output messages to the configured TTS engine (if any)"""

    def __init__(self, config: TTSConfig):
        super().__init__()
        self.config = config
        self.tts_provider = TextToSpeechProvider(config)

    def format(self, record: logging.LogRecord) -> str:
        if getattr(record, "title", ""):
            msg = f"{getattr(record, 'title')} {record.msg}"
        else:
            msg = f"{record.msg}"

        return remove_color_codes(msg)

    def emit(self, record: logging.LogRecord) -> None:
        if not self.config.speak_mode:
            return

        message = self.format(record)
        self.tts_provider.say(message)


class JsonFileHandler(logging.FileHandler):
    def format(self, record: logging.LogRecord) -> str:
        record.json_data = json.loads(record.getMessage())
        return json.dumps(getattr(record, "json_data"), ensure_ascii=False, indent=4)

    def emit(self, record: logging.LogRecord) -> None:
        with open(self.baseFilename, "w", encoding="utf-8") as f:
            f.write(self.format(record))
33
forge/forge/logging/utils.py
Normal file
@@ -0,0 +1,33 @@
import logging
import re
from typing import Any

from colorama import Fore


def remove_color_codes(s: str) -> str:
    return re.sub(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])", "", s)


def fmt_kwargs(kwargs: dict) -> str:
    return ", ".join(f"{n}={repr(v)}" for n, v in kwargs.items())


def print_attribute(
    title: str, value: Any, title_color: str = Fore.GREEN, value_color: str = ""
) -> None:
    logger = logging.getLogger()
    logger.info(
        str(value),
        extra={
            "title": f"{title.rstrip(':')}:",
            "title_color": title_color,
            "color": value_color,
        },
    )


def speak(message: str, level: int = logging.INFO) -> None:
    from .config import SPEECH_OUTPUT_LOGGER

    logging.getLogger(SPEECH_OUTPUT_LOGGER).log(level, message)
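The regex in `remove_color_codes` matches both two-character ANSI escapes and CSI sequences such as `\x1b[31m`; a quick illustrative check:

# Illustrative check of remove_color_codes:
assert remove_color_codes("\x1b[31mERROR\x1b[0m: boom") == "ERROR: boom"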
79
forge/forge/models/action.py
Normal file
@@ -0,0 +1,79 @@
from __future__ import annotations

from typing import Any, Literal, Optional, TypeVar

from pydantic import BaseModel

from forge.llm.providers.schema import AssistantFunctionCall

from .utils import ModelWithSummary


class ActionProposal(BaseModel):
    thoughts: str | ModelWithSummary
    use_tool: AssistantFunctionCall


AnyProposal = TypeVar("AnyProposal", bound=ActionProposal)


class ActionSuccessResult(BaseModel):
    outputs: Any
    status: Literal["success"] = "success"

    def __str__(self) -> str:
        outputs = str(self.outputs).replace("```", r"\```")
        multiline = "\n" in outputs
        # Use the escaped string inside the fence so embedded ``` can't break it
        return f"```\n{outputs}\n```" if multiline else str(self.outputs)


class ErrorInfo(BaseModel):
    args: tuple
    message: str
    exception_type: str
    repr: str

    @staticmethod
    def from_exception(exception: Exception) -> ErrorInfo:
        return ErrorInfo(
            args=exception.args,
            message=getattr(exception, "message", exception.args[0]),
            exception_type=exception.__class__.__name__,
            repr=repr(exception),
        )

    def __str__(self):
        return repr(self)

    def __repr__(self):
        return self.repr


class ActionErrorResult(BaseModel):
    reason: str
    error: Optional[ErrorInfo] = None
    status: Literal["error"] = "error"

    @staticmethod
    def from_exception(exception: Exception) -> ActionErrorResult:
        return ActionErrorResult(
            reason=getattr(exception, "message", exception.args[0]),
            error=ErrorInfo.from_exception(exception),
        )

    def __str__(self) -> str:
        return f"Action failed: '{self.reason}'"


class ActionInterruptedByHuman(BaseModel):
    feedback: str
    status: Literal["interrupted_by_human"] = "interrupted_by_human"

    def __str__(self) -> str:
        return (
            'The user interrupted the action with the following feedback: "%s"'
            % self.feedback
        )


ActionResult = ActionSuccessResult | ActionErrorResult | ActionInterruptedByHuman
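A small illustrative round-trip through `ActionErrorResult.from_exception`; the exception is contrived.

# Contrived example, for illustration only.
try:
    raise FileNotFoundError("no such file: plan.md")
except Exception as e:
    result = ActionErrorResult.from_exception(e)

print(result)                       # Action failed: 'no such file: plan.md'
print(result.error.exception_type)  # FileNotFoundError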
350
forge/forge/models/config.py
Normal file
@@ -0,0 +1,350 @@
import os
import typing
from typing import Any, Callable, Generic, Optional, Type, TypeVar, get_args

from pydantic import BaseModel, Field, ValidationError
from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.main import ModelMetaclass

T = TypeVar("T")
M = TypeVar("M", bound=BaseModel)


def UserConfigurable(
    default: T | UndefinedType = Undefined,
    *args,
    default_factory: Optional[Callable[[], T]] = None,
    from_env: Optional[str | Callable[[], T | None]] = None,
    description: str = "",
    **kwargs,
) -> T:
    # TODO: use this to auto-generate docs for the application configuration
    return Field(
        default,
        *args,
        default_factory=default_factory,
        from_env=from_env,
        description=description,
        **kwargs,
        user_configurable=True,
    )


class SystemConfiguration(BaseModel):
    def get_user_config(self) -> dict[str, Any]:
        return _recurse_user_config_values(self)

    @classmethod
    def from_env(cls):
        """
        Initializes the config object from environment variables.

        Environment variables are mapped to UserConfigurable fields using the from_env
        attribute that can be passed to UserConfigurable.
        """

        def infer_field_value(field: ModelField):
            field_info = field.field_info
            default_value = (
                field.default
                if field.default not in (None, Undefined)
                else (field.default_factory() if field.default_factory else Undefined)
            )
            if from_env := field_info.extra.get("from_env"):
                val_from_env = (
                    os.getenv(from_env) if type(from_env) is str else from_env()
                )
                if val_from_env is not None:
                    return val_from_env
            return default_value

        return _recursive_init_model(cls, infer_field_value)

    class Config:
        extra = "forbid"
        use_enum_values = True
        validate_assignment = True


SC = TypeVar("SC", bound=SystemConfiguration)


class SystemSettings(BaseModel):
    """A base class for all system settings."""

    name: str
    description: str

    class Config:
        extra = "forbid"
        use_enum_values = True
        validate_assignment = True


S = TypeVar("S", bound=SystemSettings)


class Configurable(Generic[S]):
    """A base class for all configurable objects."""

    prefix: str = ""
    default_settings: typing.ClassVar[S]  # type: ignore

    @classmethod
    def get_user_config(cls) -> dict[str, Any]:
        return _recurse_user_config_values(cls.default_settings)

    @classmethod
    def build_agent_configuration(cls, overrides: dict = {}) -> S:
        """Process the configuration for this object."""

        base_config = _update_user_config_from_env(cls.default_settings)
        final_configuration = deep_update(base_config, overrides)

        return cls.default_settings.__class__.parse_obj(final_configuration)


def _update_user_config_from_env(instance: BaseModel) -> dict[str, Any]:
    """
    Update config fields of a Pydantic model instance from environment variables.

    Precedence:
    1. Non-default value already on the instance
    2. Value returned by `from_env()`
    3. Default value for the field

    Params:
        instance: The Pydantic model instance.

    Returns:
        The user config fields of the instance.
    """

    def infer_field_value(field: ModelField, value):
        field_info = field.field_info
        default_value = (
            field.default
            if field.default not in (None, Undefined)
            else (field.default_factory() if field.default_factory else None)
        )
        if value == default_value and (from_env := field_info.extra.get("from_env")):
            val_from_env = os.getenv(from_env) if type(from_env) is str else from_env()
            if val_from_env is not None:
                return val_from_env
        return value

    def init_sub_config(model: Type[SC]) -> SC | None:
        try:
            return model.from_env()
        except ValidationError as e:
            # Gracefully handle missing fields
            if all(e["type"] == "value_error.missing" for e in e.errors()):
                return None
            raise

    return _recurse_user_config_fields(instance, infer_field_value, init_sub_config)


def _recursive_init_model(
    model: Type[M],
    infer_field_value: Callable[[ModelField], Any],
) -> M:
    """
    Recursively initialize the user configuration fields of a Pydantic model.

    Parameters:
        model: The Pydantic model type.
        infer_field_value: A callback function to infer the value of each field.
            Parameters:
                ModelField: The Pydantic ModelField object describing the field.

    Returns:
        BaseModel: An instance of the model with the initialized configuration.
    """
    user_config_fields = {}
    for name, field in model.__fields__.items():
        if "user_configurable" in field.field_info.extra:
            user_config_fields[name] = infer_field_value(field)
        elif type(field.outer_type_) is ModelMetaclass and issubclass(
            field.outer_type_, SystemConfiguration
        ):
            try:
                user_config_fields[name] = _recursive_init_model(
                    model=field.outer_type_,
                    infer_field_value=infer_field_value,
                )
            except ValidationError as e:
                # Gracefully handle missing fields
                if all(e["type"] == "value_error.missing" for e in e.errors()):
                    user_config_fields[name] = None
                else:
                    raise

    user_config_fields = remove_none_items(user_config_fields)

    return model.parse_obj(user_config_fields)


def _recurse_user_config_fields(
    model: BaseModel,
    infer_field_value: Callable[[ModelField, Any], Any],
    init_sub_config: Optional[
        Callable[[Type[SystemConfiguration]], SystemConfiguration | None]
    ] = None,
) -> dict[str, Any]:
    """
    Recursively process the user configuration fields of a Pydantic model instance.

    Params:
        model: The Pydantic model to iterate over.
        infer_field_value: A callback function to process each field.
            Params:
                ModelField: The Pydantic ModelField object describing the field.
                Any: The current value of the field.
        init_sub_config: An optional callback function to initialize a sub-config.
            Params:
                Type[SystemConfiguration]: The type of the sub-config to initialize.

    Returns:
        dict[str, Any]: The processed user configuration fields of the instance.
    """
    user_config_fields = {}

    for name, field in model.__fields__.items():
        value = getattr(model, name)

        # Handle individual field
        if "user_configurable" in field.field_info.extra:
            user_config_fields[name] = infer_field_value(field, value)

        # Recurse into nested config object
        elif isinstance(value, SystemConfiguration):
            user_config_fields[name] = _recurse_user_config_fields(
                model=value,
                infer_field_value=infer_field_value,
                init_sub_config=init_sub_config,
            )

        # Recurse into optional nested config object
        elif value is None and init_sub_config:
            field_type = get_args(field.annotation)[0]  # Optional[T] -> T
            if type(field_type) is ModelMetaclass and issubclass(
                field_type, SystemConfiguration
            ):
                sub_config = init_sub_config(field_type)
                if sub_config:
                    user_config_fields[name] = _recurse_user_config_fields(
                        model=sub_config,
                        infer_field_value=infer_field_value,
                        init_sub_config=init_sub_config,
                    )

        elif isinstance(value, list) and all(
            isinstance(i, SystemConfiguration) for i in value
        ):
            user_config_fields[name] = [
                _recurse_user_config_fields(i, infer_field_value, init_sub_config)
                for i in value
            ]
        elif isinstance(value, dict) and all(
            isinstance(i, SystemConfiguration) for i in value.values()
        ):
            user_config_fields[name] = {
                k: _recurse_user_config_fields(v, infer_field_value, init_sub_config)
                for k, v in value.items()
            }

    return user_config_fields


def _recurse_user_config_values(
    instance: BaseModel,
    get_field_value: Callable[[ModelField, T], T] = lambda _, v: v,
) -> dict[str, Any]:
    """
    This function recursively traverses the user configuration values in a Pydantic
    model instance.

    Params:
        instance: A Pydantic model instance.
        get_field_value: A callback function to process each field. Parameters:
            ModelField: The Pydantic ModelField object that describes the field.
            Any: The current value of the field.

    Returns:
        A dictionary containing the processed user configuration fields of the instance.
    """
    user_config_values = {}

    for name, value in instance.__dict__.items():
        field = instance.__fields__[name]
        if "user_configurable" in field.field_info.extra:
            user_config_values[name] = get_field_value(field, value)
        elif isinstance(value, SystemConfiguration):
            user_config_values[name] = _recurse_user_config_values(
                instance=value, get_field_value=get_field_value
            )
        elif isinstance(value, list) and all(
            isinstance(i, SystemConfiguration) for i in value
        ):
            user_config_values[name] = [
                _recurse_user_config_values(i, get_field_value) for i in value
            ]
        elif isinstance(value, dict) and all(
            isinstance(i, SystemConfiguration) for i in value.values()
        ):
            user_config_values[name] = {
                k: _recurse_user_config_values(v, get_field_value)
                for k, v in value.items()
            }

    return user_config_values


def _get_non_default_user_config_values(instance: BaseModel) -> dict[str, Any]:
    """
    Get the non-default user config fields of a Pydantic model instance.

    Params:
        instance: The Pydantic model instance.

    Returns:
        dict[str, Any]: The non-default user config values on the instance.
    """

    def get_field_value(field: ModelField, value):
        default = field.default_factory() if field.default_factory else field.default
        if value != default:
            return value

    return remove_none_items(_recurse_user_config_values(instance, get_field_value))


def deep_update(original_dict: dict, update_dict: dict) -> dict:
    """
    Recursively update a dictionary.

    Params:
        original_dict (dict): The dictionary to be updated.
        update_dict (dict): The dictionary to update with.

    Returns:
        dict: The updated dictionary.
    """
    for key, value in update_dict.items():
        if (
            key in original_dict
            and isinstance(original_dict[key], dict)
            and isinstance(value, dict)
        ):
            original_dict[key] = deep_update(original_dict[key], value)
        else:
            original_dict[key] = value
    return original_dict


def remove_none_items(d):
    if isinstance(d, dict):
        return {
            k: remove_none_items(v) for k, v in d.items() if v not in (None, Undefined)
        }
    return d
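To make the machinery above concrete, a hedged sketch of a config class using `UserConfigurable` with `from_env`, plus the `deep_update` merge semantics; the class name, env var, and values are invented for illustration.

import os

# Hypothetical config class; not part of the diff.
class DatabaseConfig(SystemConfiguration):
    connection_string: str = UserConfigurable(
        default="sqlite:///local.db", from_env="MY_DATABASE_STRING"
    )

os.environ["MY_DATABASE_STRING"] = "sqlite:///test.db"
assert DatabaseConfig.from_env().connection_string == "sqlite:///test.db"

# deep_update merges nested override dicts key by key:
assert deep_update({"a": {"x": 1, "y": 2}}, {"a": {"y": 3}}) == {"a": {"x": 1, "y": 3}}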
217
forge/forge/models/json_schema.py
Normal file
@@ -0,0 +1,217 @@
import ast
import enum
from textwrap import indent
from typing import Any, Optional, overload

from jsonschema import Draft7Validator, ValidationError
from pydantic import BaseModel


class JSONSchema(BaseModel):
    class Type(str, enum.Enum):
        STRING = "string"
        ARRAY = "array"
        OBJECT = "object"
        NUMBER = "number"
        INTEGER = "integer"
        BOOLEAN = "boolean"
        TYPE = "type"

    # TODO: add docstrings
    description: Optional[str] = None
    type: Optional[Type] = None
    enum: Optional[list] = None
    required: bool = False
    default: Any = None
    items: Optional["JSONSchema"] = None
    properties: Optional[dict[str, "JSONSchema"]] = None
    minimum: Optional[int | float] = None
    maximum: Optional[int | float] = None
    minItems: Optional[int] = None
    maxItems: Optional[int] = None

    def to_dict(self) -> dict:
        schema: dict = {
            "type": self.type.value if self.type else None,
            "description": self.description,
            "default": repr(self.default),
        }
        if self.type == "array":
            if self.items:
                schema["items"] = self.items.to_dict()
            schema["minItems"] = self.minItems
            schema["maxItems"] = self.maxItems
        elif self.type == "object":
            if self.properties:
                schema["properties"] = {
                    name: prop.to_dict() for name, prop in self.properties.items()
                }
                schema["required"] = [
                    name for name, prop in self.properties.items() if prop.required
                ]
        elif self.enum:
            schema["enum"] = self.enum
        else:
            schema["minimum"] = self.minimum
            schema["maximum"] = self.maximum

        schema = {k: v for k, v in schema.items() if v is not None}

        return schema

    @staticmethod
    def from_dict(schema: dict) -> "JSONSchema":
        definitions = schema.get("definitions", {})
        schema = _resolve_type_refs_in_schema(schema, definitions)

        return JSONSchema(
            description=schema.get("description"),
            type=schema["type"],
            default=ast.literal_eval(d) if (d := schema.get("default")) else None,
            enum=schema.get("enum"),
            items=JSONSchema.from_dict(schema["items"]) if "items" in schema else None,
            properties=JSONSchema.parse_properties(schema)
            if schema["type"] == "object"
            else None,
            minimum=schema.get("minimum"),
            maximum=schema.get("maximum"),
            minItems=schema.get("minItems"),
            maxItems=schema.get("maxItems"),
        )

    @staticmethod
    def parse_properties(schema_node: dict) -> dict[str, "JSONSchema"]:
        properties = (
            {k: JSONSchema.from_dict(v) for k, v in schema_node["properties"].items()}
            if "properties" in schema_node
            else {}
        )
        if "required" in schema_node:
            for k, v in properties.items():
                v.required = k in schema_node["required"]
        return properties

    def validate_object(self, object: object) -> tuple[bool, list[ValidationError]]:
        """
        Validates an object or a value against the JSONSchema.

        Params:
            object: The value/object to validate.

        Returns:
            bool: Indicates whether the given value or object is valid for the schema.
            list[ValidationError]: The issues with the value or object (if any).
        """
        validator = Draft7Validator(self.to_dict())

        if errors := sorted(validator.iter_errors(object), key=lambda e: e.path):
            return False, errors

        return True, []

    def to_typescript_object_interface(self, interface_name: str = "") -> str:
        if self.type != JSONSchema.Type.OBJECT:
            raise NotImplementedError("Only `object` schemas are supported")

        if self.properties:
            attributes: list[str] = []
            for name, property in self.properties.items():
                if property.description:
                    attributes.append(f"// {property.description}")
                attributes.append(f"{name}: {property.typescript_type};")
            attributes_string = "\n".join(attributes)
        else:
            attributes_string = "[key: string]: any"

        return (
            f"interface {interface_name} " if interface_name else ""
        ) + f"{{\n{indent(attributes_string, ' ')}\n}}"

    @property
    def python_type(self) -> str:
        if self.type == JSONSchema.Type.BOOLEAN:
            return "bool"
        elif self.type == JSONSchema.Type.INTEGER:
            return "int"
        elif self.type == JSONSchema.Type.NUMBER:
            return "float"
        elif self.type == JSONSchema.Type.STRING:
            return "str"
        elif self.type == JSONSchema.Type.ARRAY:
            return f"list[{self.items.python_type}]" if self.items else "list"
        elif self.type == JSONSchema.Type.OBJECT:
            if not self.properties:
                return "dict"
            raise NotImplementedError(
                "JSONSchema.python_type doesn't support TypedDicts yet"
            )
        elif self.enum:
            return "Union[" + ", ".join(repr(v) for v in self.enum) + "]"
        elif self.type == JSONSchema.Type.TYPE:
            return "type"
        elif self.type is None:
            return "Any"
        else:
            raise NotImplementedError(
                f"JSONSchema.python_type does not support Type.{self.type.name} yet"
            )

    @property
    def typescript_type(self) -> str:
        if not self.type:
            return "any"
        if self.type == JSONSchema.Type.BOOLEAN:
            return "boolean"
        if self.type in {JSONSchema.Type.INTEGER, JSONSchema.Type.NUMBER}:
            return "number"
        if self.type == JSONSchema.Type.STRING:
            return "string"
        if self.type == JSONSchema.Type.ARRAY:
            return f"Array<{self.items.typescript_type}>" if self.items else "Array"
        if self.type == JSONSchema.Type.OBJECT:
            if not self.properties:
                return "Record<string, any>"
            return self.to_typescript_object_interface()
        if self.enum:
            return " | ".join(repr(v) for v in self.enum)
        elif self.type == JSONSchema.Type.TYPE:
            return "type"
        elif self.type is None:
            return "any"

        raise NotImplementedError(
            f"JSONSchema.typescript_type does not support Type.{self.type.name} yet"
        )


@overload
def _resolve_type_refs_in_schema(schema: dict, definitions: dict) -> dict:
    ...


@overload
def _resolve_type_refs_in_schema(schema: list, definitions: dict) -> list:
    ...


def _resolve_type_refs_in_schema(schema: dict | list, definitions: dict) -> dict | list:
    """
    Recursively resolve type $refs in the JSON schema with their definitions.
    """
    if isinstance(schema, dict):
        if "$ref" in schema:
            ref_path = schema["$ref"].split("/")[2:]  # Split and remove '#/definitions'
            ref_value = definitions
            for key in ref_path:
                ref_value = ref_value[key]
            return _resolve_type_refs_in_schema(ref_value, definitions)
        else:
            return {
                k: _resolve_type_refs_in_schema(v, definitions)
                for k, v in schema.items()
            }
    elif isinstance(schema, list):
        return [_resolve_type_refs_in_schema(item, definitions) for item in schema]
    else:
        return schema
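A compact illustrative round-trip: build a schema by hand and validate a value against it. The field names and values are invented.

# Illustrative example; field names/values are assumptions.
schema = JSONSchema(
    type=JSONSchema.Type.OBJECT,
    properties={
        "name": JSONSchema(type=JSONSchema.Type.STRING, required=True),
        "age": JSONSchema(type=JSONSchema.Type.INTEGER, minimum=0),
    },
)
is_valid, errors = schema.validate_object({"name": "Ada", "age": 36})
assert is_valid and not errors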
61
forge/forge/models/providers.py
Normal file
@@ -0,0 +1,61 @@
import abc
import enum
import math
from typing import Callable, Generic, TypeVar

from pydantic import BaseModel, SecretBytes, SecretField, SecretStr

from forge.models.config import SystemConfiguration, UserConfigurable

_T = TypeVar("_T")


class ResourceType(str, enum.Enum):
    """An enumeration of resource types."""

    MODEL = "model"


class ProviderBudget(SystemConfiguration, Generic[_T]):
    total_budget: float = UserConfigurable(math.inf)
    total_cost: float = 0
    remaining_budget: float = math.inf
    usage: _T

    @abc.abstractmethod
    def update_usage_and_cost(self, *args, **kwargs) -> float:
        """Update the usage and cost of the provider.

        Returns:
            float: The (calculated) cost of the given model response.
        """
        ...


class ProviderCredentials(SystemConfiguration):
    """Struct for credentials."""

    def unmasked(self) -> dict:
        return unmask(self)

    class Config(SystemConfiguration.Config):
        json_encoders: dict[type[SecretField], Callable[[SecretField], str | None]] = {
            SecretStr: lambda v: v.get_secret_value() if v else None,
            SecretBytes: lambda v: v.get_secret_value() if v else None,
            SecretField: lambda v: v.get_secret_value() if v else None,
        }


def unmask(model: BaseModel):
    unmasked_fields = {}
    for field_name, _ in model.__fields__.items():
        value = getattr(model, field_name)
        if isinstance(value, SecretStr):
            unmasked_fields[field_name] = value.get_secret_value()
        else:
            unmasked_fields[field_name] = value
    return unmasked_fields


# Used both by model providers and memory providers
Embedding = list[float]
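A hedged sketch of how a credentials subclass stays masked until `unmasked()` is called; the class name, field, and env var are invented for illustration.

from pydantic import SecretStr

# Hypothetical credentials class; not part of the diff.
class MyProviderCredentials(ProviderCredentials):
    api_key: SecretStr = UserConfigurable(from_env="MY_PROVIDER_API_KEY")

creds = MyProviderCredentials(api_key=SecretStr("sk-123"))
print(creds.api_key)                # **********  (masked by pydantic)
print(creds.unmasked()["api_key"])  # sk-123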
10
forge/forge/models/utils.py
Normal file
@@ -0,0 +1,10 @@
from abc import ABC, abstractmethod

from pydantic import BaseModel


class ModelWithSummary(BaseModel, ABC):
    @abstractmethod
    def summary(self) -> str:
        """Should produce a human readable summary of the model content."""
        pass
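For reference, the smallest possible concrete subclass might look like this; the class and its fields are invented.

# Invented minimal subclass, for illustration.
class Thoughts(ModelWithSummary):
    text: str
    reasoning: str

    def summary(self) -> str:
        return self.text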
4
forge/forge/speech/__init__.py
Normal file
@@ -0,0 +1,4 @@
"""This module contains the (speech recognition and) speech synthesis functions."""
from .say import TextToSpeechProvider, TTSConfig

__all__ = ["TextToSpeechProvider", "TTSConfig"]
Some files were not shown because too many files have changed in this diff.