mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Set up unified pre-commit + CI w/ linting + type checking & FIX EVERYTHING (#7171)
- **FIX ALL LINT/TYPE ERRORS IN AUTOGPT, FORGE, AND BENCHMARK** ### Linting - Clean up linter configs for `autogpt`, `forge`, and `benchmark` - Add type checking with Pyright - Create unified pre-commit config - Create unified linting and type checking CI workflow ### Testing - Synchronize CI test setups for `autogpt`, `forge`, and `benchmark` - Add missing pytest-cov to benchmark dependencies - Mark GCS tests as slow to speed up pre-commit test runs - Repair `forge` test suite - Add `AgentDB.close()` method for test DB teardown in db_test.py - Use actual temporary dir instead of forge/test_workspace/ - Move left-behind dependencies for moved `forge`-code to from autogpt to forge ### Notable type changes - Replace uses of `ChatModelProvider` by `MultiProvider` - Removed unnecessary exports from various __init__.py - Simplify `FileStorage.open_file` signature by removing `IOBase` from return type union - Implement `S3BinaryIOWrapper(BinaryIO)` type interposer for `S3FileStorage` - Expand overloads of `GCSFileStorage.open_file` for improved typing of read and write modes Had to silence type checking for the extra overloads, because (I think) Pyright is reporting a false-positive: https://github.com/microsoft/pyright/issues/8007 - Change `count_tokens`, `get_tokenizer`, `count_message_tokens` methods on `ModelProvider`s from class methods to instance methods - Move `CompletionModelFunction.schema` method -> helper function `format_function_def_for_openai` in `forge.llm.providers.openai` - Rename `ModelProvider` -> `BaseModelProvider` - Rename `ChatModelProvider` -> `BaseChatModelProvider` - Add type `ChatModelProvider` which is a union of all subclasses of `BaseChatModelProvider` ### Removed rather than fixed - Remove deprecated and broken autogpt/agbenchmark_config/benchmarks.py - Various base classes and properties on base classes in `forge.llm.providers.schema` and `forge.models.providers` ### Fixes for other issues that came to light - Clean up `forge.agent_protocol.api_router`, `forge.agent_protocol.database`, and `forge.agent.agent` - Add fallback behavior to `ImageGeneratorComponent` - Remove test for deprecated failure behavior - Fix `agbenchmark.challenges.builtin` challenge exclusion mechanism on Windows - Fix `_tool_calls_compat_extract_calls` in `forge.llm.providers.openai` - Add support for `any` (= no type specified) in `JSONSchema.typescript_type`
This commit is contained in:
committed by
GitHub
parent
2c13a2706c
commit
f107ff8cf0
@@ -1,15 +1,11 @@
|
||||
[flake8]
|
||||
max-line-length = 88
|
||||
select = "E303, W293, W292, E305, E231, E302"
|
||||
# Ignore rules that conflict with Black code style
|
||||
extend-ignore = E203, W503
|
||||
exclude =
|
||||
.tox,
|
||||
__pycache__,
|
||||
.git,
|
||||
__pycache__/,
|
||||
*.pyc,
|
||||
.env
|
||||
venv*/*,
|
||||
.venv/*,
|
||||
reports/*,
|
||||
dist/*,
|
||||
agent/*,
|
||||
code,
|
||||
agbenchmark/challenges/*
|
||||
.pytest_cache/,
|
||||
venv*/,
|
||||
.venv/,
|
||||
|
||||
5
forge/.gitignore
vendored
5
forge/.gitignore
vendored
@@ -160,7 +160,8 @@ CURRENT_BULLETIN.md
|
||||
|
||||
agbenchmark_config/workspace
|
||||
agbenchmark_config/reports
|
||||
*.sqlite
|
||||
*.sqlite*
|
||||
*.db
|
||||
.agbench
|
||||
.agbenchmark
|
||||
.benchmarks
|
||||
@@ -168,7 +169,7 @@ agbenchmark_config/reports
|
||||
.pytest_cache
|
||||
.vscode
|
||||
ig_*
|
||||
agent.db
|
||||
agbenchmark_config/updates.json
|
||||
agbenchmark_config/challenges_already_beaten.json
|
||||
agbenchmark_config/temp_folder/*
|
||||
test_workspace/
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
- id: check-added-large-files
|
||||
args: ['--maxkb=500']
|
||||
- id: check-byte-order-marker
|
||||
- id: check-case-conflict
|
||||
- id: check-merge-conflict
|
||||
- id: check-symlinks
|
||||
- id: debug-statements
|
||||
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
language_version: python3.11
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3.11
|
||||
|
||||
# - repo: https://github.com/pre-commit/mirrors-mypy
|
||||
# rev: 'v1.3.0'
|
||||
# hooks:
|
||||
# - id: mypy
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: autoflake
|
||||
name: autoflake
|
||||
entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring forge/autogpt
|
||||
language: python
|
||||
types: [ python ]
|
||||
# Mono repo has bronken this TODO: fix
|
||||
# - id: pytest-check
|
||||
# name: pytest-check
|
||||
# entry: pytest
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
# always_run: true
|
||||
@@ -9,27 +9,24 @@ from forge.logging.config import configure_logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logo = """\n\n
|
||||
d8888 888 .d8888b. 8888888b. 88888888888
|
||||
d88888 888 d88P Y88b 888 Y88b 888
|
||||
d88P888 888 888 888 888 888 888
|
||||
d88P 888 888 888 888888 .d88b. 888 888 d88P 888
|
||||
d88P 888 888 888 888 d88""88b 888 88888 8888888P" 888
|
||||
d88P 888 888 888 888 888 888 888 888 888 888
|
||||
d8888888888 Y88b 888 Y88b. Y88..88P Y88b d88P 888 888
|
||||
d88P 888 "Y88888 "Y888 "Y88P" "Y8888P88 888 888
|
||||
|
||||
|
||||
|
||||
8888888888
|
||||
888
|
||||
888
|
||||
8888888 .d88b. 888d888 .d88b. .d88b.
|
||||
888 d88""88b 888P" d88P"88b d8P Y8b
|
||||
888 888 888 888 888 888 88888888
|
||||
888 Y88..88P 888 Y88b 888 Y8b.
|
||||
888 "Y88P" 888 "Y88888 "Y8888
|
||||
888
|
||||
Y8b d88P
|
||||
d8888 888 .d8888b. 8888888b. 88888888888
|
||||
d88P888 888 888 888 888 888 888
|
||||
d88P 888 888 888 888888 .d88b. 888 888 d88P 888
|
||||
d88P 888 888 888 888 d88""88b 888 88888 8888888P" 888
|
||||
d88P 888 888 888 888 888 888 888 888 888 888
|
||||
d8888888888 Y88b 888 Y88b. Y88..88P Y88b d88P 888 888
|
||||
d88P 888 "Y88888 "Y888 "Y88P" "Y8888P88 888 888
|
||||
|
||||
|
||||
8888888888
|
||||
888
|
||||
888 .d88b. 888d888 .d88b. .d88b.
|
||||
888888 d88""88b 888P" d88P"88b d8P Y8b
|
||||
888 888 888 888 888 888 88888888
|
||||
888 Y88..88P 888 Y88b 888 Y8b.
|
||||
888 "Y88P" 888 "Y88888 "Y8888
|
||||
888
|
||||
Y8b d88P
|
||||
"Y88P" v0.1.0
|
||||
\n"""
|
||||
|
||||
|
||||
@@ -1,15 +1,7 @@
|
||||
from .base import AgentMeta, BaseAgent, BaseAgentConfiguration, BaseAgentSettings
|
||||
from .components import (
|
||||
AgentComponent,
|
||||
ComponentEndpointError,
|
||||
ComponentSystemError,
|
||||
EndpointPipelineError,
|
||||
)
|
||||
from .protocols import (
|
||||
AfterExecute,
|
||||
AfterParse,
|
||||
CommandProvider,
|
||||
DirectiveProvider,
|
||||
ExecutionFailure,
|
||||
MessageProvider,
|
||||
)
|
||||
from .base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
|
||||
|
||||
__all__ = [
|
||||
"BaseAgent",
|
||||
"BaseAgentConfiguration",
|
||||
"BaseAgentSettings",
|
||||
]
|
||||
|
||||
@@ -24,7 +24,6 @@ from forge.agent_protocol.models.task import (
|
||||
TaskStepsListResponse,
|
||||
)
|
||||
from forge.file_storage.base import FileStorage
|
||||
from forge.utils.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -79,7 +78,8 @@ class Agent:
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
f"Frontend not found. {frontend_path} does not exist. The frontend will not be served"
|
||||
f"Frontend not found. {frontend_path} does not exist. "
|
||||
"The frontend will not be served."
|
||||
)
|
||||
app.add_middleware(AgentMiddleware, agent=self)
|
||||
|
||||
@@ -94,34 +94,25 @@ class Agent:
|
||||
"""
|
||||
Create a task for the agent.
|
||||
"""
|
||||
try:
|
||||
task = await self.db.create_task(
|
||||
input=task_request.input,
|
||||
additional_input=task_request.additional_input,
|
||||
)
|
||||
return task
|
||||
except Exception as e:
|
||||
raise
|
||||
task = await self.db.create_task(
|
||||
input=task_request.input,
|
||||
additional_input=task_request.additional_input,
|
||||
)
|
||||
return task
|
||||
|
||||
async def list_tasks(self, page: int = 1, pageSize: int = 10) -> TaskListResponse:
|
||||
"""
|
||||
List all tasks that the agent has created.
|
||||
"""
|
||||
try:
|
||||
tasks, pagination = await self.db.list_tasks(page, pageSize)
|
||||
response = TaskListResponse(tasks=tasks, pagination=pagination)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise
|
||||
tasks, pagination = await self.db.list_tasks(page, pageSize)
|
||||
response = TaskListResponse(tasks=tasks, pagination=pagination)
|
||||
return response
|
||||
|
||||
async def get_task(self, task_id: str) -> Task:
|
||||
"""
|
||||
Get a task by ID.
|
||||
"""
|
||||
try:
|
||||
task = await self.db.get_task(task_id)
|
||||
except Exception as e:
|
||||
raise
|
||||
task = await self.db.get_task(task_id)
|
||||
return task
|
||||
|
||||
async def list_steps(
|
||||
@@ -130,12 +121,9 @@ class Agent:
|
||||
"""
|
||||
List the IDs of all steps that the task has created.
|
||||
"""
|
||||
try:
|
||||
steps, pagination = await self.db.list_steps(task_id, page, pageSize)
|
||||
response = TaskStepsListResponse(steps=steps, pagination=pagination)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise
|
||||
steps, pagination = await self.db.list_steps(task_id, page, pageSize)
|
||||
response = TaskStepsListResponse(steps=steps, pagination=pagination)
|
||||
return response
|
||||
|
||||
async def execute_step(self, task_id: str, step_request: StepRequestBody) -> Step:
|
||||
"""
|
||||
@@ -147,11 +135,8 @@ class Agent:
|
||||
"""
|
||||
Get a step by ID.
|
||||
"""
|
||||
try:
|
||||
step = await self.db.get_step(task_id, step_id)
|
||||
return step
|
||||
except Exception as e:
|
||||
raise
|
||||
step = await self.db.get_step(task_id, step_id)
|
||||
return step
|
||||
|
||||
async def list_artifacts(
|
||||
self, task_id: str, page: int = 1, pageSize: int = 10
|
||||
@@ -159,62 +144,45 @@ class Agent:
|
||||
"""
|
||||
List the artifacts that the task has created.
|
||||
"""
|
||||
try:
|
||||
artifacts, pagination = await self.db.list_artifacts(
|
||||
task_id, page, pageSize
|
||||
)
|
||||
return TaskArtifactsListResponse(artifacts=artifacts, pagination=pagination)
|
||||
|
||||
except Exception as e:
|
||||
raise
|
||||
artifacts, pagination = await self.db.list_artifacts(task_id, page, pageSize)
|
||||
return TaskArtifactsListResponse(artifacts=artifacts, pagination=pagination)
|
||||
|
||||
async def create_artifact(
|
||||
self, task_id: str, file: UploadFile, relative_path: str
|
||||
self, task_id: str, file: UploadFile, relative_path: str = ""
|
||||
) -> Artifact:
|
||||
"""
|
||||
Create an artifact for the task.
|
||||
"""
|
||||
data = None
|
||||
file_name = file.filename or str(uuid4())
|
||||
try:
|
||||
data = b""
|
||||
while contents := file.file.read(1024 * 1024):
|
||||
data += contents
|
||||
# Check if relative path ends with filename
|
||||
if relative_path.endswith(file_name):
|
||||
file_path = relative_path
|
||||
else:
|
||||
file_path = os.path.join(relative_path, file_name)
|
||||
data = b""
|
||||
while contents := file.file.read(1024 * 1024):
|
||||
data += contents
|
||||
# Check if relative path ends with filename
|
||||
if relative_path.endswith(file_name):
|
||||
file_path = relative_path
|
||||
else:
|
||||
file_path = os.path.join(relative_path, file_name)
|
||||
|
||||
await self.workspace.write_file(file_path, data)
|
||||
await self.workspace.write_file(file_path, data)
|
||||
|
||||
artifact = await self.db.create_artifact(
|
||||
task_id=task_id,
|
||||
file_name=file_name,
|
||||
relative_path=relative_path,
|
||||
agent_created=False,
|
||||
)
|
||||
except Exception as e:
|
||||
raise
|
||||
artifact = await self.db.create_artifact(
|
||||
task_id=task_id,
|
||||
file_name=file_name,
|
||||
relative_path=relative_path,
|
||||
agent_created=False,
|
||||
)
|
||||
return artifact
|
||||
|
||||
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
|
||||
async def get_artifact(self, task_id: str, artifact_id: str) -> StreamingResponse:
|
||||
"""
|
||||
Get an artifact by ID.
|
||||
"""
|
||||
try:
|
||||
artifact = await self.db.get_artifact(artifact_id)
|
||||
if artifact.file_name not in artifact.relative_path:
|
||||
file_path = os.path.join(artifact.relative_path, artifact.file_name)
|
||||
else:
|
||||
file_path = artifact.relative_path
|
||||
retrieved_artifact = self.workspace.read_file(file_path)
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except FileNotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise
|
||||
artifact = await self.db.get_artifact(artifact_id)
|
||||
if artifact.file_name not in artifact.relative_path:
|
||||
file_path = os.path.join(artifact.relative_path, artifact.file_name)
|
||||
else:
|
||||
file_path = artifact.relative_path
|
||||
retrieved_artifact = self.workspace.read_file(file_path, binary=True)
|
||||
|
||||
return StreamingResponse(
|
||||
BytesIO(retrieved_artifact),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi import UploadFile
|
||||
|
||||
from forge.agent_protocol.database.db import AgentDB
|
||||
from forge.agent_protocol.models.task import (
|
||||
@@ -16,16 +17,23 @@ from .agent import Agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent():
|
||||
def agent(test_workspace: Path):
|
||||
db = AgentDB("sqlite:///test.db")
|
||||
config = FileStorageConfiguration(root=Path("./test_workspace"))
|
||||
config = FileStorageConfiguration(root=test_workspace)
|
||||
workspace = LocalFileStorage(config)
|
||||
return Agent(db, workspace)
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.fixture
|
||||
def file_upload():
|
||||
this_file = Path(__file__)
|
||||
file_handle = this_file.open("rb")
|
||||
yield UploadFile(file_handle, filename=this_file.name)
|
||||
file_handle.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task(agent):
|
||||
async def test_create_task(agent: Agent):
|
||||
task_request = TaskRequestBody(
|
||||
input="test_input", additional_input={"input": "additional_test_input"}
|
||||
)
|
||||
@@ -33,20 +41,18 @@ async def test_create_task(agent):
|
||||
assert task.input == "test_input"
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks(agent):
|
||||
async def test_list_tasks(agent: Agent):
|
||||
task_request = TaskRequestBody(
|
||||
input="test_input", additional_input={"input": "additional_test_input"}
|
||||
)
|
||||
task = await agent.create_task(task_request)
|
||||
await agent.create_task(task_request)
|
||||
tasks = await agent.list_tasks()
|
||||
assert isinstance(tasks, TaskListResponse)
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_task(agent):
|
||||
async def test_get_task(agent: Agent):
|
||||
task_request = TaskRequestBody(
|
||||
input="test_input", additional_input={"input": "additional_test_input"}
|
||||
)
|
||||
@@ -55,9 +61,9 @@ async def test_get_task(agent):
|
||||
assert retrieved_task.task_id == task.task_id
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.xfail(reason="execute_step is not implemented")
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_execute_step(agent):
|
||||
async def test_execute_step(agent: Agent):
|
||||
task_request = TaskRequestBody(
|
||||
input="test_input", additional_input={"input": "additional_test_input"}
|
||||
)
|
||||
@@ -65,14 +71,14 @@ async def test_create_and_execute_step(agent):
|
||||
step_request = StepRequestBody(
|
||||
input="step_input", additional_input={"input": "additional_test_input"}
|
||||
)
|
||||
step = await agent.create_and_execute_step(task.task_id, step_request)
|
||||
step = await agent.execute_step(task.task_id, step_request)
|
||||
assert step.input == "step_input"
|
||||
assert step.additional_input == {"input": "additional_test_input"}
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.xfail(reason="execute_step is not implemented")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_step(agent):
|
||||
async def test_get_step(agent: Agent):
|
||||
task_request = TaskRequestBody(
|
||||
input="test_input", additional_input={"input": "additional_test_input"}
|
||||
)
|
||||
@@ -80,38 +86,52 @@ async def test_get_step(agent):
|
||||
step_request = StepRequestBody(
|
||||
input="step_input", additional_input={"input": "additional_test_input"}
|
||||
)
|
||||
step = await agent.create_and_execute_step(task.task_id, step_request)
|
||||
step = await agent.execute_step(task.task_id, step_request)
|
||||
retrieved_step = await agent.get_step(task.task_id, step.step_id)
|
||||
assert retrieved_step.step_id == step.step_id
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_artifacts(agent):
|
||||
artifacts = await agent.list_artifacts()
|
||||
assert isinstance(artifacts, list)
|
||||
async def test_list_artifacts(agent: Agent):
|
||||
tasks = await agent.list_tasks()
|
||||
assert tasks.tasks, "No tasks in test.db"
|
||||
|
||||
artifacts = await agent.list_artifacts(tasks.tasks[0].task_id)
|
||||
assert isinstance(artifacts.artifacts, list)
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_artifact(agent):
|
||||
async def test_create_artifact(agent: Agent, file_upload: UploadFile):
|
||||
task_request = TaskRequestBody(
|
||||
input="test_input", additional_input={"input": "additional_test_input"}
|
||||
)
|
||||
task = await agent.create_task(task_request)
|
||||
artifact_request = ArtifactRequestBody(file=None, uri="test_uri")
|
||||
artifact = await agent.create_artifact(task.task_id, artifact_request)
|
||||
assert artifact.uri == "test_uri"
|
||||
artifact = await agent.create_artifact(
|
||||
task_id=task.task_id,
|
||||
file=file_upload,
|
||||
relative_path=f"a_dir/{file_upload.filename}",
|
||||
)
|
||||
assert artifact.file_name == file_upload.filename
|
||||
assert artifact.relative_path == f"a_dir/{file_upload.filename}"
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_artifact(agent):
|
||||
async def test_create_and_get_artifact(agent: Agent, file_upload: UploadFile):
|
||||
task_request = TaskRequestBody(
|
||||
input="test_input", additional_input={"input": "additional_test_input"}
|
||||
)
|
||||
task = await agent.create_task(task_request)
|
||||
artifact_request = ArtifactRequestBody(file=None, uri="test_uri")
|
||||
artifact = await agent.create_artifact(task.task_id, artifact_request)
|
||||
|
||||
artifact = await agent.create_artifact(
|
||||
task_id=task.task_id,
|
||||
file=file_upload,
|
||||
relative_path=f"b_dir/{file_upload.filename}",
|
||||
)
|
||||
await file_upload.seek(0)
|
||||
file_upload_content = await file_upload.read()
|
||||
|
||||
retrieved_artifact = await agent.get_artifact(task.task_id, artifact.artifact_id)
|
||||
assert retrieved_artifact.artifact_id == artifact.artifact_id
|
||||
retrieved_artifact_content = bytearray()
|
||||
async for b in retrieved_artifact.body_iterator:
|
||||
retrieved_artifact_content.extend(b) # type: ignore
|
||||
assert retrieved_artifact_content == file_upload_content
|
||||
|
||||
@@ -5,22 +5,21 @@ import inspect
|
||||
import logging
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Generic,
|
||||
Iterator,
|
||||
Optional,
|
||||
ParamSpec,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from colorama import Fore
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from forge.models.action import ActionProposal, ActionResult
|
||||
|
||||
from forge.agent import protocols
|
||||
from forge.agent.components import (
|
||||
AgentComponent,
|
||||
@@ -29,15 +28,10 @@ from forge.agent.components import (
|
||||
)
|
||||
from forge.config.ai_directives import AIDirectives
|
||||
from forge.config.ai_profile import AIProfile
|
||||
from forge.config.config import ConfigBuilder
|
||||
from forge.llm.providers import CHAT_MODELS, ModelName, OpenAIModelName
|
||||
from forge.llm.providers.schema import ChatModelInfo
|
||||
from forge.models.config import (
|
||||
Configurable,
|
||||
SystemConfiguration,
|
||||
SystemSettings,
|
||||
UserConfigurable,
|
||||
)
|
||||
from forge.models.action import ActionResult, AnyProposal
|
||||
from forge.models.config import SystemConfiguration, SystemSettings, UserConfigurable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -133,17 +127,7 @@ class AgentMeta(ABCMeta):
|
||||
return instance
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
|
||||
C = TypeVar("C", bound=AgentComponent)
|
||||
|
||||
default_settings = BaseAgentSettings(
|
||||
name="BaseAgent",
|
||||
description=__doc__ if __doc__ else "",
|
||||
)
|
||||
|
||||
class BaseAgent(Generic[AnyProposal], metaclass=AgentMeta):
|
||||
def __init__(
|
||||
self,
|
||||
settings: BaseAgentSettings,
|
||||
@@ -173,13 +157,13 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
|
||||
return self.config.send_token_limit or self.llm.max_tokens * 3 // 4
|
||||
|
||||
@abstractmethod
|
||||
async def propose_action(self) -> ActionProposal:
|
||||
async def propose_action(self) -> AnyProposal:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def execute(
|
||||
self,
|
||||
proposal: ActionProposal,
|
||||
proposal: AnyProposal,
|
||||
user_feedback: str = "",
|
||||
) -> ActionResult:
|
||||
...
|
||||
@@ -187,7 +171,7 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
|
||||
@abstractmethod
|
||||
async def do_not_execute(
|
||||
self,
|
||||
denied_proposal: ActionProposal,
|
||||
denied_proposal: AnyProposal,
|
||||
user_feedback: str,
|
||||
) -> ActionResult:
|
||||
...
|
||||
@@ -203,13 +187,16 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
|
||||
|
||||
@overload
|
||||
async def run_pipeline(
|
||||
self, protocol_method: Callable[P, None], *args, retry_limit: int = 3
|
||||
self,
|
||||
protocol_method: Callable[P, None | Awaitable[None]],
|
||||
*args,
|
||||
retry_limit: int = 3,
|
||||
) -> list[None]:
|
||||
...
|
||||
|
||||
async def run_pipeline(
|
||||
self,
|
||||
protocol_method: Callable[P, Iterator[T] | None],
|
||||
protocol_method: Callable[P, Iterator[T] | None | Awaitable[None]],
|
||||
*args,
|
||||
retry_limit: int = 3,
|
||||
) -> list[T] | list[None]:
|
||||
@@ -240,7 +227,10 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
|
||||
)
|
||||
continue
|
||||
|
||||
method = getattr(component, method_name, None)
|
||||
method = cast(
|
||||
Callable[..., Iterator[T] | None | Awaitable[None]] | None,
|
||||
getattr(component, method_name, None),
|
||||
)
|
||||
if not callable(method):
|
||||
continue
|
||||
|
||||
@@ -248,10 +238,9 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
|
||||
while component_attempts < retry_limit:
|
||||
try:
|
||||
component_args = self._selective_copy(args)
|
||||
if inspect.iscoroutinefunction(method):
|
||||
result = await method(*component_args)
|
||||
else:
|
||||
result = method(*component_args)
|
||||
result = method(*component_args)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
if result is not None:
|
||||
method_result.extend(result)
|
||||
args = component_args
|
||||
@@ -269,9 +258,9 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
|
||||
break
|
||||
# Successful pipeline execution
|
||||
break
|
||||
except EndpointPipelineError:
|
||||
except EndpointPipelineError as e:
|
||||
self._trace.append(
|
||||
f"❌ {Fore.LIGHTRED_EX}{component.__class__.__name__}: "
|
||||
f"❌ {Fore.LIGHTRED_EX}{e.triggerer.__class__.__name__}: "
|
||||
f"EndpointPipelineError{Fore.RESET}"
|
||||
)
|
||||
# Restart from the beginning on EndpointPipelineError
|
||||
|
||||
@@ -36,8 +36,9 @@ class AgentComponent(ABC):
|
||||
class ComponentEndpointError(Exception):
|
||||
"""Error of a single protocol method on a component."""
|
||||
|
||||
def __init__(self, message: str = ""):
|
||||
def __init__(self, message: str, component: AgentComponent):
|
||||
self.message = message
|
||||
self.triggerer = component
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Iterator
|
||||
from typing import TYPE_CHECKING, Awaitable, Generic, Iterator
|
||||
|
||||
from forge.models.action import ActionResult, AnyProposal
|
||||
|
||||
from .components import AgentComponent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from forge.command.command import Command
|
||||
from forge.llm.providers import ChatMessage
|
||||
from forge.models.action import ActionResult
|
||||
|
||||
from .base import ActionProposal
|
||||
|
||||
|
||||
class DirectiveProvider(AgentComponent):
|
||||
@@ -34,19 +33,19 @@ class MessageProvider(AgentComponent):
|
||||
...
|
||||
|
||||
|
||||
class AfterParse(AgentComponent):
|
||||
class AfterParse(AgentComponent, Generic[AnyProposal]):
|
||||
@abstractmethod
|
||||
def after_parse(self, result: "ActionProposal") -> None:
|
||||
def after_parse(self, result: AnyProposal) -> None | Awaitable[None]:
|
||||
...
|
||||
|
||||
|
||||
class ExecutionFailure(AgentComponent):
|
||||
@abstractmethod
|
||||
def execution_failure(self, error: Exception) -> None:
|
||||
def execution_failure(self, error: Exception) -> None | Awaitable[None]:
|
||||
...
|
||||
|
||||
|
||||
class AfterExecute(AgentComponent):
|
||||
@abstractmethod
|
||||
def after_execute(self, result: "ActionResult") -> None:
|
||||
def after_execute(self, result: "ActionResult") -> None | Awaitable[None]:
|
||||
...
|
||||
|
||||
@@ -1,39 +1,16 @@
|
||||
"""
|
||||
Routes for the Agent Service.
|
||||
|
||||
This module defines the API routes for the Agent service. While there are multiple endpoints provided by the service,
|
||||
the ones that require special attention due to their complexity are:
|
||||
This module defines the API routes for the Agent service.
|
||||
|
||||
1. `execute_agent_task_step`:
|
||||
This route is significant because this is where the agent actually performs the work. The function handles
|
||||
executing the next step for a task based on its current state, and it requires careful implementation to ensure
|
||||
all scenarios (like the presence or absence of steps or a step marked as `last_step`) are handled correctly.
|
||||
|
||||
2. `upload_agent_task_artifacts`:
|
||||
This route allows for the upload of artifacts, supporting various URI types (e.g., s3, gcs, ftp, http).
|
||||
The support for different URI types makes it a bit more complex, and it's important to ensure that all
|
||||
supported URI types are correctly managed. NOTE: The AutoGPT team will eventually handle the most common
|
||||
uri types for you.
|
||||
|
||||
3. `create_agent_task`:
|
||||
While this is a simpler route, it plays a crucial role in the workflow, as it's responsible for the creation
|
||||
of a new task.
|
||||
|
||||
Developers and contributors should be especially careful when making modifications to these routes to ensure
|
||||
consistency and correctness in the system's behavior.
|
||||
Developers and contributors should be especially careful when making modifications
|
||||
to these routes to ensure consistency and correctness in the system's behavior.
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from fastapi import APIRouter, Query, Request, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from forge.utils.exceptions import (
|
||||
NotFoundError,
|
||||
get_detailed_traceback,
|
||||
get_exception_message,
|
||||
)
|
||||
from fastapi import APIRouter, HTTPException, Query, Request, Response, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from .models import (
|
||||
Artifact,
|
||||
@@ -46,6 +23,9 @@ from .models import (
|
||||
TaskStepsListResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from forge.agent.agent import Agent
|
||||
|
||||
base_router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -73,10 +53,10 @@ async def create_agent_task(request: Request, task_request: TaskRequestBody) ->
|
||||
|
||||
Args:
|
||||
request (Request): FastAPI request object.
|
||||
task (TaskRequestBody): The task request containing input and additional input data.
|
||||
task (TaskRequestBody): The task request containing input data.
|
||||
|
||||
Returns:
|
||||
Task: A new task with task_id, input, additional_input, and empty lists for artifacts and steps.
|
||||
Task: A new task with task_id, input, and additional_input set.
|
||||
|
||||
Example:
|
||||
Request (TaskRequestBody defined in schema.py):
|
||||
@@ -93,46 +73,32 @@ async def create_agent_task(request: Request, task_request: TaskRequestBody) ->
|
||||
"artifacts": [],
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
agent: "Agent" = request["agent"]
|
||||
|
||||
try:
|
||||
task_request = await agent.create_task(task_request)
|
||||
return Response(
|
||||
content=task_request.json(),
|
||||
status_code=200,
|
||||
media_type="application/json",
|
||||
)
|
||||
task = await agent.create_task(task_request)
|
||||
return task
|
||||
except Exception:
|
||||
logger.exception(f"Error whilst trying to create a task: {task_request}")
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"error": "Internal server error",
|
||||
"exception": get_exception_message(),
|
||||
"traceback": get_detailed_traceback(),
|
||||
}
|
||||
),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@base_router.get("/agent/tasks", tags=["agent"], response_model=TaskListResponse)
|
||||
async def list_agent_tasks(
|
||||
request: Request,
|
||||
page: Optional[int] = Query(1, ge=1),
|
||||
page_size: Optional[int] = Query(10, ge=1),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(10, ge=1),
|
||||
) -> TaskListResponse:
|
||||
"""
|
||||
Retrieves a paginated list of all tasks.
|
||||
|
||||
Args:
|
||||
request (Request): FastAPI request object.
|
||||
page (int, optional): The page number for pagination. Defaults to 1.
|
||||
page_size (int, optional): The number of tasks per page for pagination. Defaults to 10.
|
||||
page (int, optional): Page number for pagination. Default: 1
|
||||
page_size (int, optional): Number of tasks per page for pagination. Default: 10
|
||||
|
||||
Returns:
|
||||
TaskListResponse: A response object containing a list of tasks and pagination details.
|
||||
TaskListResponse: A list of tasks, and pagination details.
|
||||
|
||||
Example:
|
||||
Request:
|
||||
@@ -158,34 +124,13 @@ async def list_agent_tasks(
|
||||
}
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
agent: "Agent" = request["agent"]
|
||||
try:
|
||||
tasks = await agent.list_tasks(page, page_size)
|
||||
return Response(
|
||||
content=tasks.json(),
|
||||
status_code=200,
|
||||
media_type="application/json",
|
||||
)
|
||||
except NotFoundError:
|
||||
logger.exception("Error whilst trying to list tasks")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Tasks not found"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
return tasks
|
||||
except Exception:
|
||||
logger.exception("Error whilst trying to list tasks")
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"error": "Internal server error",
|
||||
"exception": get_exception_message(),
|
||||
"traceback": get_detailed_traceback(),
|
||||
}
|
||||
),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@base_router.get("/agent/tasks/{task_id}", tags=["agent"], response_model=Task)
|
||||
@@ -239,36 +184,14 @@ async def get_agent_task(request: Request, task_id: str) -> Task:
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
""" # noqa: E501
|
||||
agent: "Agent" = request["agent"]
|
||||
try:
|
||||
task = await agent.get_task(task_id)
|
||||
|
||||
return Response(
|
||||
content=task.json(),
|
||||
status_code=200,
|
||||
media_type="application/json",
|
||||
)
|
||||
except NotFoundError:
|
||||
logger.exception(f"Error whilst trying to get task: {task_id}")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Task not found"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
return task
|
||||
except Exception:
|
||||
logger.exception(f"Error whilst trying to get task: {task_id}")
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"error": "Internal server error",
|
||||
"exception": get_exception_message(),
|
||||
"traceback": get_detailed_traceback(),
|
||||
}
|
||||
),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@base_router.get(
|
||||
@@ -279,8 +202,8 @@ async def get_agent_task(request: Request, task_id: str) -> Task:
|
||||
async def list_agent_task_steps(
|
||||
request: Request,
|
||||
task_id: str,
|
||||
page: Optional[int] = Query(1, ge=1),
|
||||
page_size: Optional[int] = Query(10, ge=1, alias="pageSize"),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(10, ge=1, alias="pageSize"),
|
||||
) -> TaskStepsListResponse:
|
||||
"""
|
||||
Retrieves a paginated list of steps associated with a specific task.
|
||||
@@ -289,10 +212,10 @@ async def list_agent_task_steps(
|
||||
request (Request): FastAPI request object.
|
||||
task_id (str): The ID of the task.
|
||||
page (int, optional): The page number for pagination. Defaults to 1.
|
||||
page_size (int, optional): The number of steps per page for pagination. Defaults to 10.
|
||||
page_size (int, optional): Number of steps per page for pagination. Default: 10.
|
||||
|
||||
Returns:
|
||||
TaskStepsListResponse: A response object containing a list of steps and pagination details.
|
||||
TaskStepsListResponse: A list of steps, and pagination details.
|
||||
|
||||
Example:
|
||||
Request:
|
||||
@@ -315,54 +238,40 @@ async def list_agent_task_steps(
|
||||
"pageSize": 10
|
||||
}
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
""" # noqa: E501
|
||||
agent: "Agent" = request["agent"]
|
||||
try:
|
||||
steps = await agent.list_steps(task_id, page, page_size)
|
||||
return Response(
|
||||
content=steps.json(),
|
||||
status_code=200,
|
||||
media_type="application/json",
|
||||
)
|
||||
except NotFoundError:
|
||||
logger.exception("Error whilst trying to list steps")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Steps not found"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
return steps
|
||||
except Exception:
|
||||
logger.exception("Error whilst trying to list steps")
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"error": "Internal server error",
|
||||
"exception": get_exception_message(),
|
||||
"traceback": get_detailed_traceback(),
|
||||
}
|
||||
),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@base_router.post("/agent/tasks/{task_id}/steps", tags=["agent"], response_model=Step)
|
||||
async def execute_agent_task_step(
|
||||
request: Request, task_id: str, step: Optional[StepRequestBody] = None
|
||||
request: Request, task_id: str, step_request: Optional[StepRequestBody] = None
|
||||
) -> Step:
|
||||
"""
|
||||
Executes the next step for a specified task based on the current task status and returns the
|
||||
executed step with additional feedback fields.
|
||||
Executes the next step for a specified task based on the current task status and
|
||||
returns the executed step with additional feedback fields.
|
||||
|
||||
Depending on the current state of the task, the following scenarios are supported:
|
||||
This route is significant because this is where the agent actually performs work.
|
||||
The function handles executing the next step for a task based on its current state,
|
||||
and it requires careful implementation to ensure all scenarios (like the presence
|
||||
or absence of steps or a step marked as `last_step`) are handled correctly.
|
||||
|
||||
Depending on the current state of the task, the following scenarios are possible:
|
||||
1. No steps exist for the task.
|
||||
2. There is at least one step already for the task, and the task does not have a completed step marked as `last_step`.
|
||||
2. There is at least one step already for the task, and the task does not have a
|
||||
completed step marked as `last_step`.
|
||||
3. There is a completed step marked as `last_step` already on the task.
|
||||
|
||||
In each of these scenarios, a step object will be returned with two additional fields: `output` and `additional_output`.
|
||||
In each of these scenarios, a step object will be returned with two additional
|
||||
fields: `output` and `additional_output`.
|
||||
- `output`: Provides the primary response or feedback to the user.
|
||||
- `additional_output`: Supplementary information or data. Its specific content is not strictly defined and can vary based on the step or agent's implementation.
|
||||
- `additional_output`: Supplementary information or data. Its specific content is
|
||||
not strictly defined and can vary based on the step or agent's implementation.
|
||||
|
||||
Args:
|
||||
request (Request): FastAPI request object.
|
||||
@@ -389,39 +298,17 @@ async def execute_agent_task_step(
|
||||
...
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
agent: "Agent" = request["agent"]
|
||||
try:
|
||||
# An empty step request represents a yes to continue command
|
||||
if not step:
|
||||
step = StepRequestBody(input="y")
|
||||
if not step_request:
|
||||
step_request = StepRequestBody(input="y")
|
||||
|
||||
step = await agent.execute_step(task_id, step)
|
||||
|
||||
return Response(
|
||||
content=step.json(),
|
||||
status_code=200,
|
||||
media_type="application/json",
|
||||
)
|
||||
except NotFoundError:
|
||||
logger.exception(f"Error whilst trying to execute a task step: {task_id}")
|
||||
return Response(
|
||||
content=json.dumps({"error": f"Task not found {task_id}"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
step = await agent.execute_step(task_id, step_request)
|
||||
return step
|
||||
except Exception:
|
||||
logger.exception(f"Error whilst trying to execute a task step: {task_id}")
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"error": "Internal server error",
|
||||
"exception": get_exception_message(),
|
||||
"traceback": get_detailed_traceback(),
|
||||
}
|
||||
),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@base_router.get(
|
||||
@@ -450,31 +337,13 @@ async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> S
|
||||
...
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
agent: "Agent" = request["agent"]
|
||||
try:
|
||||
step = await agent.get_step(task_id, step_id)
|
||||
|
||||
return Response(content=step.json(), status_code=200)
|
||||
except NotFoundError:
|
||||
logger.exception(f"Error whilst trying to get step: {step_id}")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Step not found"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
return step
|
||||
except Exception:
|
||||
logger.exception(f"Error whilst trying to get step: {step_id}")
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"error": "Internal server error",
|
||||
"exception": get_exception_message(),
|
||||
"traceback": get_detailed_traceback(),
|
||||
}
|
||||
),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@base_router.get(
|
||||
@@ -485,8 +354,8 @@ async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> S
|
||||
async def list_agent_task_artifacts(
|
||||
request: Request,
|
||||
task_id: str,
|
||||
page: Optional[int] = Query(1, ge=1),
|
||||
page_size: Optional[int] = Query(10, ge=1, alias="pageSize"),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(10, ge=1, alias="pageSize"),
|
||||
) -> TaskArtifactsListResponse:
|
||||
"""
|
||||
Retrieves a paginated list of artifacts associated with a specific task.
|
||||
@@ -495,10 +364,10 @@ async def list_agent_task_artifacts(
|
||||
request (Request): FastAPI request object.
|
||||
task_id (str): The ID of the task.
|
||||
page (int, optional): The page number for pagination. Defaults to 1.
|
||||
page_size (int, optional): The number of items per page for pagination. Defaults to 10.
|
||||
page_size (int, optional): Number of items per page for pagination. Default: 10.
|
||||
|
||||
Returns:
|
||||
TaskArtifactsListResponse: A response object containing a list of artifacts and pagination details.
|
||||
TaskArtifactsListResponse: A list of artifacts, and pagination details.
|
||||
|
||||
Example:
|
||||
Request:
|
||||
@@ -518,52 +387,33 @@ async def list_agent_task_artifacts(
|
||||
"pageSize": 10
|
||||
}
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
""" # noqa: E501
|
||||
agent: "Agent" = request["agent"]
|
||||
try:
|
||||
artifacts: TaskArtifactsListResponse = await agent.list_artifacts(
|
||||
task_id, page, page_size
|
||||
)
|
||||
artifacts = await agent.list_artifacts(task_id, page, page_size)
|
||||
return artifacts
|
||||
except NotFoundError:
|
||||
logger.exception("Error whilst trying to list artifacts")
|
||||
return Response(
|
||||
content=json.dumps({"error": "Artifacts not found for task_id"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error whilst trying to list artifacts")
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"error": "Internal server error",
|
||||
"exception": get_exception_message(),
|
||||
"traceback": get_detailed_traceback(),
|
||||
}
|
||||
),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@base_router.post(
|
||||
"/agent/tasks/{task_id}/artifacts", tags=["agent"], response_model=Artifact
|
||||
)
|
||||
async def upload_agent_task_artifacts(
|
||||
request: Request, task_id: str, file: UploadFile, relative_path: Optional[str] = ""
|
||||
request: Request, task_id: str, file: UploadFile, relative_path: str = ""
|
||||
) -> Artifact:
|
||||
"""
|
||||
This endpoint is used to upload an artifact associated with a specific task. The artifact is provided as a file.
|
||||
This endpoint is used to upload an artifact (file) associated with a specific task.
|
||||
|
||||
Args:
|
||||
request (Request): The FastAPI request object.
|
||||
task_id (str): The unique identifier of the task for which the artifact is being uploaded.
|
||||
task_id (str): The ID of the task for which the artifact is being uploaded.
|
||||
file (UploadFile): The file being uploaded as an artifact.
|
||||
relative_path (str): The relative path for the file. This is a query parameter.
|
||||
|
||||
Returns:
|
||||
Artifact: An object containing metadata of the uploaded artifact, including its unique identifier.
|
||||
Artifact: Metadata object for the uploaded artifact, including its ID and path.
|
||||
|
||||
Example:
|
||||
Request:
|
||||
@@ -579,35 +429,17 @@ async def upload_agent_task_artifacts(
|
||||
"relative_path": "/my_folder/my_other_folder/",
|
||||
"file_name": "main.py"
|
||||
}
|
||||
"""
|
||||
agent = request["agent"]
|
||||
""" # noqa: E501
|
||||
agent: "Agent" = request["agent"]
|
||||
|
||||
if file is None:
|
||||
return Response(
|
||||
content=json.dumps({"error": "File must be specified"}),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="File must be specified")
|
||||
try:
|
||||
artifact = await agent.create_artifact(task_id, file, relative_path)
|
||||
return Response(
|
||||
content=artifact.json(),
|
||||
status_code=200,
|
||||
media_type="application/json",
|
||||
)
|
||||
return artifact
|
||||
except Exception:
|
||||
logger.exception(f"Error whilst trying to upload artifact: {task_id}")
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"error": "Internal server error",
|
||||
"exception": get_exception_message(),
|
||||
"traceback": get_detailed_traceback(),
|
||||
}
|
||||
),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@base_router.get(
|
||||
@@ -617,7 +449,7 @@ async def upload_agent_task_artifacts(
|
||||
)
|
||||
async def download_agent_task_artifact(
|
||||
request: Request, task_id: str, artifact_id: str
|
||||
) -> FileResponse:
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Downloads an artifact associated with a specific task.
|
||||
|
||||
@@ -636,32 +468,9 @@ async def download_agent_task_artifact(
|
||||
Response:
|
||||
<file_content_of_artifact>
|
||||
"""
|
||||
agent = request["agent"]
|
||||
agent: "Agent" = request["agent"]
|
||||
try:
|
||||
return await agent.get_artifact(task_id, artifact_id)
|
||||
except NotFoundError:
|
||||
logger.exception(f"Error whilst trying to download artifact: {task_id}")
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"error": f"Artifact not found "
|
||||
"- task_id: {task_id}, artifact_id: {artifact_id}"
|
||||
}
|
||||
),
|
||||
status_code=404,
|
||||
media_type="application/json",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Error whilst trying to download artifact: {task_id}")
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"error": f"Internal server error "
|
||||
"- task_id: {task_id}, artifact_id: {artifact_id}",
|
||||
"exception": get_exception_message(),
|
||||
"traceback": get_detailed_traceback(),
|
||||
}
|
||||
),
|
||||
status_code=500,
|
||||
media_type="application/json",
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
from .db import AgentDB
|
||||
|
||||
__all__ = ["AgentDB"]
|
||||
|
||||
@@ -4,23 +4,22 @@ It uses SQLite as the database and file store backend.
|
||||
IT IS NOT ADVISED TO USE THIS IN PRODUCTION!
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import math
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
String,
|
||||
create_engine,
|
||||
)
|
||||
from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, create_engine
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import DeclarativeBase, joinedload, relationship, sessionmaker
|
||||
from sqlalchemy.orm import (
|
||||
DeclarativeBase,
|
||||
Mapped,
|
||||
joinedload,
|
||||
mapped_column,
|
||||
relationship,
|
||||
sessionmaker,
|
||||
)
|
||||
|
||||
from forge.utils.exceptions import NotFoundError
|
||||
|
||||
@@ -32,18 +31,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
type_annotation_map = {
|
||||
dict[str, Any]: JSON,
|
||||
}
|
||||
|
||||
|
||||
class TaskModel(Base):
|
||||
__tablename__ = "tasks"
|
||||
|
||||
task_id = Column(String, primary_key=True, index=True)
|
||||
input = Column(String)
|
||||
additional_input = Column(JSON)
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow)
|
||||
modified_at = Column(
|
||||
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
|
||||
task_id: Mapped[str] = mapped_column(primary_key=True, index=True)
|
||||
input: Mapped[str]
|
||||
additional_input: Mapped[dict[str, Any]] = mapped_column(default=dict)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
modified_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
|
||||
)
|
||||
|
||||
artifacts = relationship("ArtifactModel", back_populates="task")
|
||||
@@ -52,35 +53,35 @@ class TaskModel(Base):
|
||||
class StepModel(Base):
|
||||
__tablename__ = "steps"
|
||||
|
||||
step_id = Column(String, primary_key=True, index=True)
|
||||
task_id = Column(String, ForeignKey("tasks.task_id"))
|
||||
name = Column(String)
|
||||
input = Column(String)
|
||||
status = Column(String)
|
||||
output = Column(String)
|
||||
is_last = Column(Boolean, default=False)
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow)
|
||||
modified_at = Column(
|
||||
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
|
||||
step_id: Mapped[str] = mapped_column(primary_key=True, index=True)
|
||||
task_id: Mapped[str] = mapped_column(ForeignKey("tasks.task_id"))
|
||||
name: Mapped[str]
|
||||
input: Mapped[str]
|
||||
status: Mapped[str]
|
||||
output: Mapped[Optional[str]]
|
||||
is_last: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
modified_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
|
||||
)
|
||||
|
||||
additional_input = Column(JSON)
|
||||
additional_output = Column(JSON)
|
||||
additional_input: Mapped[dict[str, Any]] = mapped_column(default=dict)
|
||||
additional_output: Mapped[Optional[dict[str, Any]]]
|
||||
artifacts = relationship("ArtifactModel", back_populates="step")
|
||||
|
||||
|
||||
class ArtifactModel(Base):
|
||||
__tablename__ = "artifacts"
|
||||
|
||||
artifact_id = Column(String, primary_key=True, index=True)
|
||||
task_id = Column(String, ForeignKey("tasks.task_id"))
|
||||
step_id = Column(String, ForeignKey("steps.step_id"))
|
||||
agent_created = Column(Boolean, default=False)
|
||||
file_name = Column(String)
|
||||
relative_path = Column(String)
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow)
|
||||
modified_at = Column(
|
||||
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
|
||||
artifact_id: Mapped[str] = mapped_column(primary_key=True, index=True)
|
||||
task_id: Mapped[str] = mapped_column(ForeignKey("tasks.task_id"))
|
||||
step_id: Mapped[Optional[str]] = mapped_column(ForeignKey("steps.step_id"))
|
||||
agent_created: Mapped[bool] = mapped_column(default=False)
|
||||
file_name: Mapped[str]
|
||||
relative_path: Mapped[str]
|
||||
created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
|
||||
modified_at: Mapped[datetime] = mapped_column(
|
||||
default=datetime.utcnow, onupdate=datetime.utcnow
|
||||
)
|
||||
|
||||
step = relationship("StepModel", back_populates="artifacts")
|
||||
@@ -150,6 +151,10 @@ class AgentDB:
|
||||
Base.metadata.create_all(self.engine)
|
||||
self.Session = sessionmaker(bind=self.engine)
|
||||
|
||||
def close(self) -> None:
|
||||
self.Session.close_all()
|
||||
self.engine.dispose()
|
||||
|
||||
async def create_task(
|
||||
self, input: Optional[str], additional_input: Optional[dict] = {}
|
||||
) -> Task:
|
||||
@@ -172,8 +177,6 @@ class AgentDB:
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"SQLAlchemy error while creating task: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while creating task: {e}")
|
||||
raise
|
||||
@@ -207,8 +210,6 @@ class AgentDB:
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"SQLAlchemy error while creating step: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while creating step: {e}")
|
||||
raise
|
||||
@@ -237,7 +238,7 @@ class AgentDB:
|
||||
session.close()
|
||||
if self.debug_enabled:
|
||||
logger.debug(
|
||||
f"Artifact already exists with relative_path: {relative_path}"
|
||||
f"Artifact {file_name} already exists at {relative_path}/"
|
||||
)
|
||||
return convert_to_artifact(existing_artifact)
|
||||
|
||||
@@ -254,14 +255,12 @@ class AgentDB:
|
||||
session.refresh(new_artifact)
|
||||
if self.debug_enabled:
|
||||
logger.debug(
|
||||
f"Created new artifact with artifact_id: {new_artifact.artifact_id}"
|
||||
f"Created new artifact with ID: {new_artifact.artifact_id}"
|
||||
)
|
||||
return convert_to_artifact(new_artifact)
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"SQLAlchemy error while creating step: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while creating step: {e}")
|
||||
raise
|
||||
@@ -285,8 +284,6 @@ class AgentDB:
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"SQLAlchemy error while getting task: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while getting task: {e}")
|
||||
raise
|
||||
@@ -312,8 +309,6 @@ class AgentDB:
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"SQLAlchemy error while getting step: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while getting step: {e}")
|
||||
raise
|
||||
@@ -337,8 +332,6 @@ class AgentDB:
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"SQLAlchemy error while getting artifact: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while getting artifact: {e}")
|
||||
raise
|
||||
@@ -375,14 +368,13 @@ class AgentDB:
|
||||
return await self.get_step(task_id, step_id)
|
||||
else:
|
||||
logger.error(
|
||||
f"Step not found for update with task_id: {task_id} and step_id: {step_id}"
|
||||
"Can't update non-existent Step with "
|
||||
f"task_id: {task_id} and step_id: {step_id}"
|
||||
)
|
||||
raise NotFoundError("Step not found")
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"SQLAlchemy error while getting step: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while getting step: {e}")
|
||||
raise
|
||||
@@ -441,8 +433,6 @@ class AgentDB:
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"SQLAlchemy error while listing tasks: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while listing tasks: {e}")
|
||||
raise
|
||||
@@ -475,8 +465,6 @@ class AgentDB:
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"SQLAlchemy error while listing steps: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while listing steps: {e}")
|
||||
raise
|
||||
@@ -509,8 +497,6 @@ class AgentDB:
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"SQLAlchemy error while listing artifacts: {e}")
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while listing artifacts: {e}")
|
||||
raise
|
||||
|
||||
@@ -22,14 +22,27 @@ from forge.agent_protocol.models import (
|
||||
)
|
||||
from forge.utils.exceptions import NotFoundError as DataNotFoundError
|
||||
|
||||
TEST_DB_FILENAME = "test_db.sqlite3"
|
||||
TEST_DB_URL = f"sqlite:///{TEST_DB_FILENAME}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_table_creation():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
agent_db = AgentDB(db_name)
|
||||
|
||||
conn = sqlite3.connect("test_db.sqlite3")
|
||||
cursor = conn.cursor()
|
||||
@pytest.fixture
|
||||
def agent_db():
|
||||
db = AgentDB(TEST_DB_URL)
|
||||
yield db
|
||||
db.close()
|
||||
os.remove(TEST_DB_FILENAME)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def raw_db_connection(agent_db: AgentDB):
|
||||
connection = sqlite3.connect(TEST_DB_FILENAME)
|
||||
yield connection
|
||||
connection.close()
|
||||
|
||||
|
||||
def test_table_creation(raw_db_connection: sqlite3.Connection):
|
||||
cursor = raw_db_connection.cursor()
|
||||
|
||||
# Test for tasks table existence
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='tasks'")
|
||||
@@ -45,8 +58,6 @@ def test_table_creation():
|
||||
)
|
||||
assert cursor.fetchone() is not None
|
||||
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_schema():
|
||||
@@ -84,7 +95,10 @@ async def test_step_schema():
|
||||
name="Write to file",
|
||||
input="Write the words you receive to the file 'output.txt'.",
|
||||
status=StepStatus.created,
|
||||
output="I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>",
|
||||
output=(
|
||||
"I am going to use the write_to_file command and write Washington "
|
||||
"to a file called output.txt <write_to_file('output.txt', 'Washington')>"
|
||||
),
|
||||
artifacts=[
|
||||
Artifact(
|
||||
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
|
||||
@@ -101,13 +115,13 @@ async def test_step_schema():
|
||||
assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e"
|
||||
assert step.name == "Write to file"
|
||||
assert step.status == StepStatus.created
|
||||
assert (
|
||||
step.output
|
||||
== "I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>"
|
||||
assert step.output == (
|
||||
"I am going to use the write_to_file command and write Washington "
|
||||
"to a file called output.txt <write_to_file('output.txt', 'Washington')>"
|
||||
)
|
||||
assert len(step.artifacts) == 1
|
||||
assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
|
||||
assert step.is_last == False
|
||||
assert step.is_last is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -118,6 +132,7 @@ async def test_convert_to_task():
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
input="Write the words you receive to the file 'output.txt'.",
|
||||
additional_input={},
|
||||
artifacts=[
|
||||
ArtifactModel(
|
||||
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
|
||||
@@ -147,6 +162,7 @@ async def test_convert_to_step():
|
||||
name="Write to file",
|
||||
status="created",
|
||||
input="Write the words you receive to the file 'output.txt'.",
|
||||
additional_input={},
|
||||
artifacts=[
|
||||
ArtifactModel(
|
||||
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
|
||||
@@ -166,7 +182,7 @@ async def test_convert_to_step():
|
||||
assert step.status == StepStatus.created
|
||||
assert len(step.artifacts) == 1
|
||||
assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
|
||||
assert step.is_last == False
|
||||
assert step.is_last is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -183,91 +199,67 @@ async def test_convert_to_artifact():
|
||||
artifact = convert_to_artifact(artifact_model)
|
||||
assert artifact.artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
|
||||
assert artifact.relative_path == "file:///path/to/main.py"
|
||||
assert artifact.agent_created == True
|
||||
assert artifact.agent_created is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task():
|
||||
# Having issues with pytest fixture so added setup and teardown in each test as a rapid workaround
|
||||
# TODO: Fix this!
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
agent_db = AgentDB(db_name)
|
||||
|
||||
async def test_create_task(agent_db: AgentDB):
|
||||
task = await agent_db.create_task("task_input")
|
||||
assert task.input == "task_input"
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_get_task():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
agent_db = AgentDB(db_name)
|
||||
async def test_create_and_get_task(agent_db: AgentDB):
|
||||
task = await agent_db.create_task("test_input")
|
||||
fetched_task = await agent_db.get_task(task.task_id)
|
||||
assert fetched_task.input == "test_input"
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_task_not_found():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
agent_db = AgentDB(db_name)
|
||||
async def test_get_task_not_found(agent_db: AgentDB):
|
||||
with pytest.raises(DataNotFoundError):
|
||||
await agent_db.get_task(9999)
|
||||
os.remove(db_name.split("///")[1])
|
||||
await agent_db.get_task("9999")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_get_step():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
agent_db = AgentDB(db_name)
|
||||
async def test_create_and_get_step(agent_db: AgentDB):
|
||||
task = await agent_db.create_task("task_input")
|
||||
step_input = StepInput(type="python/code")
|
||||
step_input = {"type": "python/code"}
|
||||
request = StepRequestBody(input="test_input debug", additional_input=step_input)
|
||||
step = await agent_db.create_step(task.task_id, request)
|
||||
step = await agent_db.get_step(task.task_id, step.step_id)
|
||||
assert step.input == "test_input debug"
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_updating_step():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
agent_db = AgentDB(db_name)
|
||||
async def test_updating_step(agent_db: AgentDB):
|
||||
created_task = await agent_db.create_task("task_input")
|
||||
step_input = StepInput(type="python/code")
|
||||
step_input = {"type": "python/code"}
|
||||
request = StepRequestBody(input="test_input debug", additional_input=step_input)
|
||||
created_step = await agent_db.create_step(created_task.task_id, request)
|
||||
await agent_db.update_step(created_task.task_id, created_step.step_id, "completed")
|
||||
|
||||
step = await agent_db.get_step(created_task.task_id, created_step.step_id)
|
||||
assert step.status.value == "completed"
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_step_not_found():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
agent_db = AgentDB(db_name)
|
||||
async def test_get_step_not_found(agent_db: AgentDB):
|
||||
with pytest.raises(DataNotFoundError):
|
||||
await agent_db.get_step(9999, 9999)
|
||||
os.remove(db_name.split("///")[1])
|
||||
await agent_db.get_step("9999", "9999")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_artifact():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
db = AgentDB(db_name)
|
||||
|
||||
async def test_get_artifact(agent_db: AgentDB):
|
||||
# Given: A task and its corresponding artifact
|
||||
task = await db.create_task("test_input debug")
|
||||
step_input = StepInput(type="python/code")
|
||||
task = await agent_db.create_task("test_input debug")
|
||||
step_input = {"type": "python/code"}
|
||||
requst = StepRequestBody(input="test_input debug", additional_input=step_input)
|
||||
|
||||
step = await db.create_step(task.task_id, requst)
|
||||
step = await agent_db.create_step(task.task_id, requst)
|
||||
|
||||
# Create an artifact
|
||||
artifact = await db.create_artifact(
|
||||
artifact = await agent_db.create_artifact(
|
||||
task_id=task.task_id,
|
||||
file_name="test_get_artifact_sample_file.txt",
|
||||
relative_path="file:///path/to/test_get_artifact_sample_file.txt",
|
||||
@@ -276,7 +268,7 @@ async def test_get_artifact():
|
||||
)
|
||||
|
||||
# When: The artifact is fetched by its ID
|
||||
fetched_artifact = await db.get_artifact(artifact.artifact_id)
|
||||
fetched_artifact = await agent_db.get_artifact(artifact.artifact_id)
|
||||
|
||||
# Then: The fetched artifact matches the original
|
||||
assert fetched_artifact.artifact_id == artifact.artifact_id
|
||||
@@ -285,47 +277,37 @@ async def test_get_artifact():
|
||||
== "file:///path/to/test_get_artifact_sample_file.txt"
|
||||
)
|
||||
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
db = AgentDB(db_name)
|
||||
|
||||
async def test_list_tasks(agent_db: AgentDB):
|
||||
# Given: Multiple tasks in the database
|
||||
task1 = await db.create_task("test_input_1")
|
||||
task2 = await db.create_task("test_input_2")
|
||||
task1 = await agent_db.create_task("test_input_1")
|
||||
task2 = await agent_db.create_task("test_input_2")
|
||||
|
||||
# When: All tasks are fetched
|
||||
fetched_tasks, pagination = await db.list_tasks()
|
||||
fetched_tasks, pagination = await agent_db.list_tasks()
|
||||
|
||||
# Then: The fetched tasks list includes the created tasks
|
||||
task_ids = [task.task_id for task in fetched_tasks]
|
||||
assert task1.task_id in task_ids
|
||||
assert task2.task_id in task_ids
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_steps():
|
||||
db_name = "sqlite:///test_db.sqlite3"
|
||||
db = AgentDB(db_name)
|
||||
|
||||
step_input = StepInput(type="python/code")
|
||||
requst = StepRequestBody(input="test_input debug", additional_input=step_input)
|
||||
async def test_list_steps(agent_db: AgentDB):
|
||||
step_input = {"type": "python/code"}
|
||||
request = StepRequestBody(input="test_input debug", additional_input=step_input)
|
||||
|
||||
# Given: A task and multiple steps for that task
|
||||
task = await db.create_task("test_input")
|
||||
step1 = await db.create_step(task.task_id, requst)
|
||||
requst = StepRequestBody(input="step two", additional_input=step_input)
|
||||
step2 = await db.create_step(task.task_id, requst)
|
||||
task = await agent_db.create_task("test_input")
|
||||
step1 = await agent_db.create_step(task.task_id, request)
|
||||
request = StepRequestBody(input="step two")
|
||||
step2 = await agent_db.create_step(task.task_id, request)
|
||||
|
||||
# When: All steps for the task are fetched
|
||||
fetched_steps, pagination = await db.list_steps(task.task_id)
|
||||
fetched_steps, pagination = await agent_db.list_steps(task.task_id)
|
||||
|
||||
# Then: The fetched steps list includes the created steps
|
||||
step_ids = [step.step_id for step in fetched_steps]
|
||||
assert step1.step_id in step_ids
|
||||
assert step2.step_id in step_ids
|
||||
os.remove(db_name.split("///")[1])
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .artifact import Artifact, ArtifactUpload
|
||||
from .artifact import Artifact
|
||||
from .pagination import Pagination
|
||||
from .task import (
|
||||
Step,
|
||||
@@ -10,3 +10,16 @@ from .task import (
|
||||
TaskRequestBody,
|
||||
TaskStepsListResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Artifact",
|
||||
"Pagination",
|
||||
"Step",
|
||||
"StepRequestBody",
|
||||
"StepStatus",
|
||||
"Task",
|
||||
"TaskArtifactsListResponse",
|
||||
"TaskListResponse",
|
||||
"TaskRequestBody",
|
||||
"TaskStepsListResponse",
|
||||
]
|
||||
|
||||
@@ -3,15 +3,6 @@ from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ArtifactUpload(BaseModel):
|
||||
file: str = Field(..., description="File to upload.", format="binary")
|
||||
relative_path: str = Field(
|
||||
...,
|
||||
description="Relative path of the artifact in the agent's workspace.",
|
||||
example="python/code",
|
||||
)
|
||||
|
||||
|
||||
class Artifact(BaseModel):
|
||||
created_at: datetime = Field(
|
||||
...,
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -17,7 +17,7 @@ class TaskRequestBody(BaseModel):
|
||||
description="Input prompt for the task.",
|
||||
example="Write the words you receive to the file 'output.txt'.",
|
||||
)
|
||||
additional_input: Optional[dict] = None
|
||||
additional_input: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class Task(TaskRequestBody):
|
||||
@@ -38,8 +38,8 @@ class Task(TaskRequestBody):
|
||||
description="The ID of the task.",
|
||||
example="50da533e-3904-4401-8a07-c49adf88b5eb",
|
||||
)
|
||||
artifacts: Optional[List[Artifact]] = Field(
|
||||
[],
|
||||
artifacts: list[Artifact] = Field(
|
||||
default_factory=list,
|
||||
description="A list of artifacts that the task has produced.",
|
||||
example=[
|
||||
"7a49f31c-f9c6-4346-a22c-e32bc5af4d8e",
|
||||
@@ -50,14 +50,12 @@ class Task(TaskRequestBody):
|
||||
|
||||
class StepRequestBody(BaseModel):
|
||||
name: Optional[str] = Field(
|
||||
None, description="The name of the task step.", example="Write to file"
|
||||
default=None, description="The name of the task step.", example="Write to file"
|
||||
)
|
||||
input: Optional[str] = Field(
|
||||
None,
|
||||
description="Input prompt for the step.",
|
||||
example="Washington",
|
||||
input: str = Field(
|
||||
..., description="Input prompt for the step.", example="Washington"
|
||||
)
|
||||
additional_input: Optional[dict] = None
|
||||
additional_input: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class StepStatus(Enum):
|
||||
@@ -90,19 +88,23 @@ class Step(StepRequestBody):
|
||||
example="6bb1801a-fd80-45e8-899a-4dd723cc602e",
|
||||
)
|
||||
name: Optional[str] = Field(
|
||||
None, description="The name of the task step.", example="Write to file"
|
||||
default=None, description="The name of the task step.", example="Write to file"
|
||||
)
|
||||
status: StepStatus = Field(
|
||||
..., description="The status of the task step.", example="created"
|
||||
)
|
||||
output: Optional[str] = Field(
|
||||
None,
|
||||
default=None,
|
||||
description="Output of the task step.",
|
||||
example="I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')",
|
||||
example=(
|
||||
"I am going to use the write_to_file command and write Washington "
|
||||
"to a file called output.txt <write_to_file('output.txt', 'Washington')"
|
||||
),
|
||||
)
|
||||
additional_output: Optional[dict] = None
|
||||
artifacts: Optional[List[Artifact]] = Field(
|
||||
[], description="A list of artifacts that the step has produced."
|
||||
additional_output: Optional[dict[str, Any]] = None
|
||||
artifacts: list[Artifact] = Field(
|
||||
default_factory=list,
|
||||
description="A list of artifacts that the step has produced.",
|
||||
)
|
||||
is_last: bool = Field(
|
||||
..., description="Whether this is the last step in the task.", example=True
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from .command import Command, CommandOutput, CommandParameter
|
||||
from .command import Command
|
||||
from .decorator import command
|
||||
from .parameter import CommandParameter
|
||||
|
||||
__all__ = ["Command", "CommandParameter", "command"]
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Generic, ParamSpec, TypeVar
|
||||
from typing import Callable, Concatenate, Generic, ParamSpec, TypeVar, cast
|
||||
|
||||
from forge.agent.protocols import CommandProvider
|
||||
|
||||
from .parameter import CommandParameter
|
||||
|
||||
CommandOutput = Any
|
||||
|
||||
P = ParamSpec("P")
|
||||
CO = TypeVar("CO", bound=CommandOutput)
|
||||
CO = TypeVar("CO") # command output
|
||||
|
||||
_CP = TypeVar("_CP", bound=CommandProvider)
|
||||
|
||||
|
||||
class Command(Generic[P, CO]):
|
||||
@@ -24,7 +26,7 @@ class Command(Generic[P, CO]):
|
||||
self,
|
||||
names: list[str],
|
||||
description: str,
|
||||
method: Callable[P, CO],
|
||||
method: Callable[Concatenate[_CP, P], CO],
|
||||
parameters: list[CommandParameter],
|
||||
):
|
||||
# Check if all parameters are provided
|
||||
@@ -34,7 +36,9 @@ class Command(Generic[P, CO]):
|
||||
)
|
||||
self.names = names
|
||||
self.description = description
|
||||
self.method = method
|
||||
# Method technically has a `self` parameter, but we can ignore that
|
||||
# since Python passes it internally.
|
||||
self.method = cast(Callable[P, CO], method)
|
||||
self.parameters = parameters
|
||||
|
||||
@property
|
||||
@@ -62,7 +66,8 @@ class Command(Generic[P, CO]):
|
||||
def __str__(self) -> str:
|
||||
params = [
|
||||
f"{param.name}: "
|
||||
+ ("%s" if param.spec.required else "Optional[%s]") % param.spec.type.value
|
||||
+ ("%s" if param.spec.required else "Optional[%s]")
|
||||
% (param.spec.type.value if param.spec.type else "Any")
|
||||
for param in self.parameters
|
||||
]
|
||||
return (
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
from .action_history import ActionHistoryComponent
|
||||
from .model import Episode, EpisodicActionHistory
|
||||
|
||||
__all__ = ["ActionHistoryComponent", "Episode", "EpisodicActionHistory"]
|
||||
|
||||
@@ -1,27 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Callable, Generic, Iterator, Optional
|
||||
from typing import TYPE_CHECKING, Callable, Iterator, Optional
|
||||
|
||||
from forge.agent.protocols import AfterExecute, AfterParse, MessageProvider
|
||||
from forge.llm.prompting.utils import indent
|
||||
from forge.llm.providers import ChatMessage, ChatModelProvider
|
||||
from forge.llm.providers import ChatMessage, MultiProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from forge.config.config import Config
|
||||
|
||||
from .model import AP, ActionResult, Episode, EpisodicActionHistory
|
||||
from .model import ActionResult, AnyProposal, Episode, EpisodicActionHistory
|
||||
|
||||
|
||||
class ActionHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[AP]):
|
||||
class ActionHistoryComponent(MessageProvider, AfterParse[AnyProposal], AfterExecute):
|
||||
"""Keeps track of the event history and provides a summary of the steps."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_history: EpisodicActionHistory[AP],
|
||||
event_history: EpisodicActionHistory[AnyProposal],
|
||||
max_tokens: int,
|
||||
count_tokens: Callable[[str], int],
|
||||
legacy_config: Config,
|
||||
llm_provider: ChatModelProvider,
|
||||
llm_provider: MultiProvider,
|
||||
) -> None:
|
||||
self.event_history = event_history
|
||||
self.max_tokens = max_tokens
|
||||
@@ -37,7 +37,7 @@ class ActionHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[
|
||||
):
|
||||
yield ChatMessage.system(f"## Progress on your Task so far\n\n{progress}")
|
||||
|
||||
def after_parse(self, result: AP) -> None:
|
||||
def after_parse(self, result: AnyProposal) -> None:
|
||||
self.event_history.register_action(result)
|
||||
|
||||
async def after_execute(self, result: ActionResult) -> None:
|
||||
@@ -48,7 +48,7 @@ class ActionHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[
|
||||
|
||||
def _compile_progress(
|
||||
self,
|
||||
episode_history: list[Episode],
|
||||
episode_history: list[Episode[AnyProposal]],
|
||||
max_tokens: Optional[int] = None,
|
||||
count_tokens: Optional[Callable[[str], int]] = None,
|
||||
) -> str:
|
||||
|
||||
@@ -1,25 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Generic, Iterator, TypeVar
|
||||
from typing import TYPE_CHECKING, Generic
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
from forge.content_processing.text import summarize_text
|
||||
from forge.llm.prompting.utils import format_numbered_list, indent
|
||||
from forge.models.action import ActionProposal, ActionResult
|
||||
from forge.models.action import ActionResult, AnyProposal
|
||||
from forge.models.utils import ModelWithSummary
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from forge.config.config import Config
|
||||
from forge.llm.providers import ChatModelProvider
|
||||
|
||||
AP = TypeVar("AP", bound=ActionProposal)
|
||||
from forge.llm.providers import MultiProvider
|
||||
|
||||
|
||||
class Episode(GenericModel, Generic[AP]):
|
||||
action: AP
|
||||
class Episode(GenericModel, Generic[AnyProposal]):
|
||||
action: AnyProposal
|
||||
result: ActionResult | None
|
||||
summary: str | None = None
|
||||
|
||||
@@ -54,32 +52,29 @@ class Episode(GenericModel, Generic[AP]):
|
||||
return executed_action + action_result
|
||||
|
||||
|
||||
class EpisodicActionHistory(GenericModel, Generic[AP]):
|
||||
class EpisodicActionHistory(GenericModel, Generic[AnyProposal]):
|
||||
"""Utility container for an action history"""
|
||||
|
||||
episodes: list[Episode[AP]] = Field(default_factory=list)
|
||||
episodes: list[Episode[AnyProposal]] = Field(default_factory=list)
|
||||
cursor: int = 0
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def current_episode(self) -> Episode[AP] | None:
|
||||
def current_episode(self) -> Episode[AnyProposal] | None:
|
||||
if self.cursor == len(self):
|
||||
return None
|
||||
return self[self.cursor]
|
||||
|
||||
def __getitem__(self, key: int) -> Episode[AP]:
|
||||
def __getitem__(self, key: int) -> Episode[AnyProposal]:
|
||||
return self.episodes[key]
|
||||
|
||||
def __iter__(self) -> Iterator[Episode[AP]]:
|
||||
return iter(self.episodes)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.episodes)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return len(self.episodes) > 0
|
||||
|
||||
def register_action(self, action: AP) -> None:
|
||||
def register_action(self, action: AnyProposal) -> None:
|
||||
if not self.current_episode:
|
||||
self.episodes.append(Episode(action=action, result=None))
|
||||
assert self.current_episode
|
||||
@@ -113,7 +108,7 @@ class EpisodicActionHistory(GenericModel, Generic[AP]):
|
||||
self.cursor = len(self.episodes)
|
||||
|
||||
async def handle_compression(
|
||||
self, llm_provider: ChatModelProvider, app_config: Config
|
||||
self, llm_provider: MultiProvider, app_config: Config
|
||||
) -> None:
|
||||
"""Compresses each episode in the action history using an LLM.
|
||||
|
||||
|
||||
@@ -3,6 +3,11 @@ from .code_executor import (
|
||||
DENYLIST_CONTROL,
|
||||
CodeExecutionError,
|
||||
CodeExecutorComponent,
|
||||
is_docker_available,
|
||||
we_are_running_in_a_docker_container,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ALLOWLIST_CONTROL",
|
||||
"DENYLIST_CONTROL",
|
||||
"CodeExecutionError",
|
||||
"CodeExecutorComponent",
|
||||
]
|
||||
|
||||
@@ -11,7 +11,8 @@ import docker
|
||||
from docker.errors import DockerException, ImageNotFound, NotFound
|
||||
from docker.models.containers import Container as DockerContainer
|
||||
|
||||
from forge.agent import BaseAgentSettings, CommandProvider
|
||||
from forge.agent import BaseAgentSettings
|
||||
from forge.agent.protocols import CommandProvider
|
||||
from forge.command import Command, command
|
||||
from forge.config.config import Config
|
||||
from forge.file_storage import FileStorage
|
||||
|
||||
@@ -5,3 +5,11 @@ from .context_item import (
|
||||
FolderContextItem,
|
||||
StaticContextItem,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ContextComponent",
|
||||
"ContextItem",
|
||||
"FileContextItem",
|
||||
"FolderContextItem",
|
||||
"StaticContextItem",
|
||||
]
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
from .file_manager import FileManagerComponent
|
||||
|
||||
__all__ = ["FileManagerComponent"]
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
from .git_operations import GitOperationsComponent
|
||||
|
||||
__all__ = ["GitOperationsComponent"]
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
from .image_gen import ImageGeneratorComponent
|
||||
|
||||
__all__ = ["ImageGeneratorComponent"]
|
||||
|
||||
@@ -32,7 +32,12 @@ class ImageGeneratorComponent(CommandProvider):
|
||||
self.legacy_config = config
|
||||
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.generate_image
|
||||
if (
|
||||
self.legacy_config.openai_credentials
|
||||
or self.legacy_config.huggingface_api_token
|
||||
or self.legacy_config.sd_webui_auth
|
||||
):
|
||||
yield self.generate_image
|
||||
|
||||
@command(
|
||||
parameters={
|
||||
@@ -60,17 +65,26 @@ class ImageGeneratorComponent(CommandProvider):
|
||||
str: The filename of the image
|
||||
"""
|
||||
filename = self.workspace.root / f"{str(uuid.uuid4())}.jpg"
|
||||
cfg = self.legacy_config
|
||||
|
||||
# DALL-E
|
||||
if self.legacy_config.image_provider == "dalle":
|
||||
if cfg.openai_credentials and (
|
||||
cfg.image_provider == "dalle"
|
||||
or not (cfg.huggingface_api_token or cfg.sd_webui_url)
|
||||
):
|
||||
return self.generate_image_with_dalle(prompt, filename, size)
|
||||
# HuggingFace
|
||||
elif self.legacy_config.image_provider == "huggingface":
|
||||
|
||||
elif cfg.huggingface_api_token and (
|
||||
cfg.image_provider == "huggingface"
|
||||
or not (cfg.openai_credentials or cfg.sd_webui_url)
|
||||
):
|
||||
return self.generate_image_with_hf(prompt, filename)
|
||||
# SD WebUI
|
||||
elif self.legacy_config.image_provider == "sdwebui":
|
||||
|
||||
elif cfg.sd_webui_url and (
|
||||
cfg.image_provider == "sdwebui" or cfg.sd_webui_auth
|
||||
):
|
||||
return self.generate_image_with_sd_webui(prompt, filename, size)
|
||||
return "No Image Provider Set"
|
||||
|
||||
return "Error: No image generation provider available"
|
||||
|
||||
def generate_image_with_hf(self, prompt: str, output_file: Path) -> str:
|
||||
"""Generate an image with HuggingFace's API.
|
||||
@@ -142,6 +156,7 @@ class ImageGeneratorComponent(CommandProvider):
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
assert self.legacy_config.openai_credentials # otherwise this tool is disabled
|
||||
|
||||
# Check for supported image sizes
|
||||
if size not in [256, 512, 1024]:
|
||||
@@ -152,16 +167,19 @@ class ImageGeneratorComponent(CommandProvider):
|
||||
)
|
||||
size = closest
|
||||
|
||||
# TODO: integrate in `forge.llm.providers`(?)
|
||||
response = OpenAI(
|
||||
api_key=self.legacy_config.openai_credentials.api_key.get_secret_value()
|
||||
).images.generate(
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
size=f"{size}x{size}",
|
||||
# TODO: improve typing of size config item(s)
|
||||
size=f"{size}x{size}", # type: ignore
|
||||
response_format="b64_json",
|
||||
)
|
||||
assert response.data[0].b64_json is not None # response_format = "b64_json"
|
||||
|
||||
logger.info(f"Image Generated for prompt:{prompt}")
|
||||
logger.info(f"Image Generated for prompt: {prompt}")
|
||||
|
||||
image_data = b64decode(response.data[0].b64_json)
|
||||
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
from .system import SystemComponent
|
||||
|
||||
__all__ = ["SystemComponent"]
|
||||
|
||||
@@ -37,8 +37,8 @@ class SystemComponent(DirectiveProvider, MessageProvider, CommandProvider):
|
||||
"You are unable to interact with physical objects. "
|
||||
"If this is absolutely necessary to fulfill a task or objective or "
|
||||
"to complete a step, you must ask the user to do it for you. "
|
||||
"If the user refuses this, and there is no other way to achieve your goals, "
|
||||
"you must terminate to avoid wasting time and energy."
|
||||
"If the user refuses this, and there is no other way to achieve your "
|
||||
"goals, you must terminate to avoid wasting time and energy."
|
||||
)
|
||||
|
||||
def get_resources(self) -> Iterator[str]:
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
from .user_interaction import UserInteractionComponent
|
||||
|
||||
__all__ = ["UserInteractionComponent"]
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
from .watchdog import WatchdogComponent
|
||||
|
||||
__all__ = ["WatchdogComponent"]
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from forge.agent.base import BaseAgentConfiguration
|
||||
|
||||
from forge.agent.components import ComponentSystemError
|
||||
from forge.agent.protocols import AfterParse
|
||||
from forge.components.action_history import EpisodicActionHistory
|
||||
from forge.models.action import ActionProposal
|
||||
from forge.models.action import AnyProposal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from forge.agent.base import BaseAgentConfiguration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WatchdogComponent(AfterParse):
|
||||
class WatchdogComponent(AfterParse[AnyProposal]):
|
||||
"""
|
||||
Adds a watchdog feature to an agent class. Whenever the agent starts
|
||||
looping, the watchdog will switch from the FAST_LLM to the SMART_LLM and re-think.
|
||||
@@ -21,13 +23,13 @@ class WatchdogComponent(AfterParse):
|
||||
def __init__(
|
||||
self,
|
||||
config: "BaseAgentConfiguration",
|
||||
event_history: EpisodicActionHistory[ActionProposal],
|
||||
event_history: EpisodicActionHistory[AnyProposal],
|
||||
):
|
||||
self.config = config
|
||||
self.event_history = event_history
|
||||
self.revert_big_brain = False
|
||||
|
||||
def after_parse(self, result: ActionProposal) -> None:
|
||||
def after_parse(self, result: AnyProposal) -> None:
|
||||
if self.revert_big_brain:
|
||||
self.config.big_brain = False
|
||||
self.revert_big_brain = False
|
||||
@@ -58,4 +60,4 @@ class WatchdogComponent(AfterParse):
|
||||
self.big_brain = True
|
||||
self.revert_big_brain = True
|
||||
# Trigger retry of all pipelines prior to this component
|
||||
raise ComponentSystemError()
|
||||
raise ComponentSystemError(rethink_reason, self)
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
from .search import WebSearchComponent
|
||||
from .selenium import BrowsingError, TooMuchOutputError, WebSeleniumComponent
|
||||
from .selenium import BrowsingError, WebSeleniumComponent
|
||||
|
||||
__all__ = ["WebSearchComponent", "BrowsingError", "WebSeleniumComponent"]
|
||||
|
||||
@@ -12,7 +12,6 @@ from selenium.webdriver.chrome.options import Options as ChromeOptions
|
||||
from selenium.webdriver.chrome.service import Service as ChromeDriverService
|
||||
from selenium.webdriver.chrome.webdriver import WebDriver as ChromeDriver
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.common.options import ArgOptions as BrowserOptions
|
||||
from selenium.webdriver.edge.options import Options as EdgeOptions
|
||||
from selenium.webdriver.edge.service import Service as EdgeDriverService
|
||||
from selenium.webdriver.edge.webdriver import WebDriver as EdgeDriver
|
||||
@@ -33,7 +32,7 @@ from forge.command import Command, command
|
||||
from forge.config.config import Config
|
||||
from forge.content_processing.html import extract_hyperlinks, format_hyperlinks
|
||||
from forge.content_processing.text import extract_information, summarize_text
|
||||
from forge.llm.providers.schema import ChatModelInfo, ChatModelProvider
|
||||
from forge.llm.providers import ChatModelInfo, MultiProvider
|
||||
from forge.models.json_schema import JSONSchema
|
||||
from forge.utils.exceptions import CommandExecutionError, TooMuchOutputError
|
||||
from forge.utils.url_validator import validate_url
|
||||
@@ -45,6 +44,9 @@ MAX_RAW_CONTENT_LENGTH = 500
|
||||
LINKS_TO_RETURN = 20
|
||||
|
||||
|
||||
BrowserOptions = ChromeOptions | EdgeOptions | FirefoxOptions | SafariOptions
|
||||
|
||||
|
||||
class BrowsingError(CommandExecutionError):
|
||||
"""An error occurred while trying to browse the page"""
|
||||
|
||||
@@ -55,7 +57,7 @@ class WebSeleniumComponent(DirectiveProvider, CommandProvider):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
llm_provider: ChatModelProvider,
|
||||
llm_provider: MultiProvider,
|
||||
model_info: ChatModelInfo,
|
||||
):
|
||||
self.legacy_config = config
|
||||
@@ -251,7 +253,7 @@ class WebSeleniumComponent(DirectiveProvider, CommandProvider):
|
||||
|
||||
if isinstance(options, FirefoxOptions):
|
||||
if config.selenium_headless:
|
||||
options.headless = True
|
||||
options.headless = True # type: ignore
|
||||
options.add_argument("--disable-gpu")
|
||||
driver = FirefoxDriver(
|
||||
service=GeckoDriverService(GeckoDriverManager().install()),
|
||||
|
||||
@@ -200,7 +200,7 @@ class ConfigBuilder(Configurable[Config]):
|
||||
|
||||
if (
|
||||
config.openai_credentials
|
||||
and config.openai_credentials.api_type == "azure"
|
||||
and config.openai_credentials.api_type == SecretStr("azure")
|
||||
and (config_file := config.azure_config_file)
|
||||
):
|
||||
config.openai_credentials.load_azure_config(config_file)
|
||||
|
||||
8
forge/forge/conftest.py
Normal file
8
forge/forge/conftest.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_workspace(tmp_path: Path) -> Path:
|
||||
return tmp_path
|
||||
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from forge.json.parsing import extract_list_from_json
|
||||
from forge.llm.prompting import ChatPrompt
|
||||
from forge.llm.providers import ChatMessage, ChatModelProvider, ModelTokenizer
|
||||
from forge.llm.providers import ChatMessage, ModelTokenizer, MultiProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -56,7 +56,7 @@ def chunk_content(
|
||||
|
||||
async def summarize_text(
|
||||
text: str,
|
||||
llm_provider: ChatModelProvider,
|
||||
llm_provider: MultiProvider,
|
||||
config: Config,
|
||||
question: Optional[str] = None,
|
||||
instruction: Optional[str] = None,
|
||||
@@ -89,7 +89,7 @@ async def summarize_text(
|
||||
async def extract_information(
|
||||
source_text: str,
|
||||
topics_of_interest: list[str],
|
||||
llm_provider: ChatModelProvider,
|
||||
llm_provider: MultiProvider,
|
||||
config: Config,
|
||||
) -> list[str]:
|
||||
fmt_topics_list = "\n".join(f"* {topic}." for topic in topics_of_interest)
|
||||
@@ -113,7 +113,7 @@ async def extract_information(
|
||||
async def _process_text(
|
||||
text: str,
|
||||
instruction: str,
|
||||
llm_provider: ChatModelProvider,
|
||||
llm_provider: MultiProvider,
|
||||
config: Config,
|
||||
output_type: type[str | list[str]] = str,
|
||||
) -> tuple[str, list[tuple[str, str]]] | list[str]:
|
||||
@@ -165,7 +165,7 @@ async def _process_text(
|
||||
),
|
||||
)
|
||||
|
||||
if output_type == list[str]:
|
||||
if isinstance(response.parsed_result, list):
|
||||
logger.debug(f"Raw LLM response: {repr(response.response.content)}")
|
||||
fmt_result_bullet_list = "\n".join(f"* {r}" for r in response.parsed_result)
|
||||
logger.debug(
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import enum
|
||||
from pathlib import Path
|
||||
|
||||
from .base import FileStorage, FileStorageConfiguration
|
||||
from .gcs import GCSFileStorage, GCSFileStorageConfiguration
|
||||
from .local import LocalFileStorage
|
||||
from .s3 import S3FileStorage, S3FileStorageConfiguration
|
||||
from .base import FileStorage
|
||||
|
||||
|
||||
class FileStorageBackendName(str, enum.Enum):
|
||||
@@ -15,7 +12,7 @@ class FileStorageBackendName(str, enum.Enum):
|
||||
|
||||
def get_storage(
|
||||
backend: FileStorageBackendName,
|
||||
root_path: Path = ".",
|
||||
root_path: Path = Path("."),
|
||||
restrict_to_root: bool = True,
|
||||
) -> FileStorage:
|
||||
match backend:
|
||||
|
||||
@@ -4,17 +4,17 @@ The FileStorage class provides an interface for interacting with a file storage.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from io import IOBase, TextIOBase
|
||||
from pathlib import Path
|
||||
from typing import IO, Any, BinaryIO, Callable, Generator, Literal, TextIO, overload
|
||||
from typing import Any, BinaryIO, Callable, Generator, Literal, TextIO, overload
|
||||
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
from watchdog.events import FileSystemEvent, FileSystemEventHandler
|
||||
from watchdog.observers import Observer
|
||||
|
||||
from forge.models.config import SystemConfiguration
|
||||
@@ -66,26 +66,29 @@ class FileStorage(ABC):
|
||||
def open_file(
|
||||
self,
|
||||
path: str | Path,
|
||||
mode: Literal["w", "r"] = "r",
|
||||
mode: Literal["r", "w"] = "r",
|
||||
binary: Literal[False] = False,
|
||||
) -> TextIO | TextIOBase:
|
||||
) -> TextIO:
|
||||
"""Returns a readable text file-like object representing the file."""
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def open_file(
|
||||
self,
|
||||
path: str | Path,
|
||||
mode: Literal["w", "r"] = "r",
|
||||
binary: Literal[True] = True,
|
||||
) -> BinaryIO | IOBase:
|
||||
self, path: str | Path, mode: Literal["r", "w"], binary: Literal[True]
|
||||
) -> BinaryIO:
|
||||
"""Returns a binary file-like object representing the file."""
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def open_file(self, path: str | Path, *, binary: Literal[True]) -> BinaryIO:
|
||||
"""Returns a readable binary file-like object representing the file."""
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
|
||||
) -> IO | IOBase:
|
||||
"""Returns a readable file-like object representing the file."""
|
||||
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
|
||||
) -> TextIO | BinaryIO:
|
||||
"""Returns a file-like object representing the file."""
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
@@ -95,13 +98,15 @@ class FileStorage(ABC):
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def read_file(self, path: str | Path, binary: Literal[True] = True) -> bytes:
|
||||
def read_file(self, path: str | Path, binary: Literal[True]) -> bytes:
|
||||
"""Read a file in the storage as binary."""
|
||||
...
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the storage."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def write_file(self, path: str | Path, content: str | bytes) -> None:
|
||||
@@ -241,24 +246,32 @@ class FileSyncHandler(FileSystemEventHandler):
|
||||
self.storage = storage
|
||||
self.path = Path(path)
|
||||
|
||||
async def on_modified(self, event):
|
||||
def on_modified(self, event: FileSystemEvent):
|
||||
if event.is_directory:
|
||||
return
|
||||
|
||||
file_path = Path(event.src_path).relative_to(self.path)
|
||||
content = file_path.read_bytes()
|
||||
await self.storage.write_file(file_path, content)
|
||||
# Must execute write_file synchronously because the hook is synchronous
|
||||
# TODO: Schedule write operation using asyncio.create_task (non-blocking)
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
self.storage.write_file(file_path, content)
|
||||
)
|
||||
|
||||
async def on_created(self, event):
|
||||
def on_created(self, event: FileSystemEvent):
|
||||
if event.is_directory:
|
||||
self.storage.make_dir(event.src_path)
|
||||
return
|
||||
|
||||
file_path = Path(event.src_path).relative_to(self.path)
|
||||
content = file_path.read_bytes()
|
||||
await self.storage.write_file(file_path, content)
|
||||
# Must execute write_file synchronously because the hook is synchronous
|
||||
# TODO: Schedule write operation using asyncio.create_task (non-blocking)
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
self.storage.write_file(file_path, content)
|
||||
)
|
||||
|
||||
def on_deleted(self, event):
|
||||
def on_deleted(self, event: FileSystemEvent):
|
||||
if event.is_directory:
|
||||
self.storage.delete_dir(event.src_path)
|
||||
return
|
||||
@@ -266,5 +279,5 @@ class FileSyncHandler(FileSystemEventHandler):
|
||||
file_path = event.src_path
|
||||
self.storage.delete_file(file_path)
|
||||
|
||||
def on_moved(self, event):
|
||||
def on_moved(self, event: FileSystemEvent):
|
||||
self.storage.rename(event.src_path, event.dest_path)
|
||||
|
||||
@@ -7,12 +7,13 @@ from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from io import IOBase
|
||||
from io import TextIOWrapper
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from typing import Literal, overload
|
||||
|
||||
from google.cloud import storage
|
||||
from google.cloud.exceptions import NotFound
|
||||
from google.cloud.storage.fileio import BlobReader, BlobWriter
|
||||
|
||||
from forge.models.config import UserConfigurable
|
||||
|
||||
@@ -73,14 +74,67 @@ class GCSFileStorage(FileStorage):
|
||||
path = self.get_path(path)
|
||||
return self._bucket.blob(str(path))
|
||||
|
||||
@overload
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
|
||||
) -> IOBase:
|
||||
self,
|
||||
path: str | Path,
|
||||
mode: Literal["r", "w"] = "r",
|
||||
binary: Literal[False] = False,
|
||||
) -> TextIOWrapper:
|
||||
...
|
||||
|
||||
@overload
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["r"], binary: Literal[True]
|
||||
) -> BlobReader:
|
||||
...
|
||||
|
||||
@overload
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["w"], binary: Literal[True]
|
||||
) -> BlobWriter:
|
||||
...
|
||||
|
||||
@overload
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["r", "w"], binary: Literal[True]
|
||||
) -> BlobWriter | BlobReader:
|
||||
...
|
||||
|
||||
@overload
|
||||
def open_file(self, path: str | Path, *, binary: Literal[True]) -> BlobReader:
|
||||
...
|
||||
|
||||
@overload
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
|
||||
) -> BlobReader | BlobWriter | TextIOWrapper:
|
||||
...
|
||||
|
||||
# https://github.com/microsoft/pyright/issues/8007
|
||||
def open_file( # pyright: ignore[reportIncompatibleMethodOverride]
|
||||
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
|
||||
) -> BlobReader | BlobWriter | TextIOWrapper:
|
||||
"""Open a file in the storage."""
|
||||
blob = self._get_blob(path)
|
||||
blob.reload() # pin revision number to prevent version mixing while reading
|
||||
return blob.open(f"{mode}b" if binary else mode)
|
||||
|
||||
@overload
|
||||
def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
|
||||
"""Read a file in the storage as text."""
|
||||
...
|
||||
|
||||
@overload
|
||||
def read_file(self, path: str | Path, binary: Literal[True]) -> bytes:
|
||||
"""Read a file in the storage as binary."""
|
||||
...
|
||||
|
||||
@overload
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the storage."""
|
||||
...
|
||||
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the storage."""
|
||||
return self.open_file(path, "r", binary).read()
|
||||
|
||||
@@ -8,7 +8,7 @@ import inspect
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import IO, Any, Generator, Literal
|
||||
from typing import Any, BinaryIO, Generator, Literal, TextIO, overload
|
||||
|
||||
from .base import FileStorage, FileStorageConfiguration
|
||||
|
||||
@@ -42,16 +42,58 @@ class LocalFileStorage(FileStorage):
|
||||
def initialize(self) -> None:
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
@overload
|
||||
def open_file(
|
||||
self,
|
||||
path: str | Path,
|
||||
mode: Literal["w", "r"] = "r",
|
||||
binary: Literal[False] = False,
|
||||
) -> TextIO:
|
||||
...
|
||||
|
||||
@overload
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["w", "r"], binary: Literal[True]
|
||||
) -> BinaryIO:
|
||||
...
|
||||
|
||||
@overload
|
||||
def open_file(self, path: str | Path, *, binary: Literal[True]) -> BinaryIO:
|
||||
...
|
||||
|
||||
@overload
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
|
||||
) -> IO:
|
||||
) -> TextIO | BinaryIO:
|
||||
...
|
||||
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
|
||||
) -> TextIO | BinaryIO:
|
||||
"""Open a file in the storage."""
|
||||
return self._open_file(path, f"{mode}b" if binary else mode)
|
||||
|
||||
def _open_file(self, path: str | Path, mode: str) -> IO:
|
||||
def _open_file(self, path: str | Path, mode: str) -> TextIO | BinaryIO:
|
||||
full_path = self.get_path(path)
|
||||
if any(m in mode for m in ("w", "a", "x")):
|
||||
full_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return open(full_path, mode) # type: ignore
|
||||
|
||||
@overload
|
||||
def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
|
||||
"""Read a file in the storage as text."""
|
||||
...
|
||||
|
||||
@overload
|
||||
def read_file(self, path: str | Path, binary: Literal[True]) -> bytes:
|
||||
"""Read a file in the storage as binary."""
|
||||
...
|
||||
|
||||
@overload
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the storage."""
|
||||
...
|
||||
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the storage."""
|
||||
with self._open_file(path, "rb" if binary else "r") as file:
|
||||
@@ -60,7 +102,7 @@ class LocalFileStorage(FileStorage):
|
||||
async def write_file(self, path: str | Path, content: str | bytes) -> None:
|
||||
"""Write to a file in the storage."""
|
||||
with self._open_file(path, "wb" if type(content) is bytes else "w") as file:
|
||||
file.write(content)
|
||||
file.write(content) # type: ignore
|
||||
|
||||
if self.on_write_file:
|
||||
path = Path(path)
|
||||
|
||||
@@ -8,9 +8,9 @@ from __future__ import annotations
|
||||
import contextlib
|
||||
import inspect
|
||||
import logging
|
||||
from io import IOBase, TextIOWrapper
|
||||
from io import TextIOWrapper
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, Optional
|
||||
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, overload
|
||||
|
||||
import boto3
|
||||
import botocore.exceptions
|
||||
@@ -22,6 +22,7 @@ from .base import FileStorage, FileStorageConfiguration
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import mypy_boto3_s3
|
||||
from botocore.response import StreamingBody
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -89,18 +90,60 @@ class S3FileStorage(FileStorage):
|
||||
|
||||
def _get_obj(self, path: str | Path) -> mypy_boto3_s3.service_resource.Object:
|
||||
"""Get an S3 object."""
|
||||
path = self.get_path(path)
|
||||
obj = self._bucket.Object(str(path))
|
||||
with contextlib.suppress(botocore.exceptions.ClientError):
|
||||
obj.load()
|
||||
return obj
|
||||
|
||||
@overload
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
|
||||
) -> IOBase:
|
||||
self,
|
||||
path: str | Path,
|
||||
mode: Literal["r", "w"] = "r",
|
||||
binary: Literal[False] = False,
|
||||
) -> TextIOWrapper:
|
||||
...
|
||||
|
||||
@overload
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["r", "w"], binary: Literal[True]
|
||||
) -> S3BinaryIOWrapper:
|
||||
...
|
||||
|
||||
@overload
|
||||
def open_file(
|
||||
self, path: str | Path, *, binary: Literal[True]
|
||||
) -> S3BinaryIOWrapper:
|
||||
...
|
||||
|
||||
@overload
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
|
||||
) -> S3BinaryIOWrapper | TextIOWrapper:
|
||||
...
|
||||
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
|
||||
) -> TextIOWrapper | S3BinaryIOWrapper:
|
||||
"""Open a file in the storage."""
|
||||
obj = self._get_obj(path)
|
||||
return obj.get()["Body"] if binary else TextIOWrapper(obj.get()["Body"])
|
||||
path = self.get_path(path)
|
||||
body = S3BinaryIOWrapper(self._get_obj(path).get()["Body"], str(path))
|
||||
return body if binary else TextIOWrapper(body)
|
||||
|
||||
@overload
|
||||
def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
|
||||
"""Read a file in the storage as text."""
|
||||
...
|
||||
|
||||
@overload
|
||||
def read_file(self, path: str | Path, binary: Literal[True]) -> bytes:
|
||||
"""Read a file in the storage as binary."""
|
||||
...
|
||||
|
||||
@overload
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the storage."""
|
||||
...
|
||||
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the storage."""
|
||||
@@ -108,7 +151,7 @@ class S3FileStorage(FileStorage):
|
||||
|
||||
async def write_file(self, path: str | Path, content: str | bytes) -> None:
|
||||
"""Write to a file in the storage."""
|
||||
obj = self._get_obj(path)
|
||||
obj = self._get_obj(self.get_path(path))
|
||||
obj.put(Body=content)
|
||||
|
||||
if self.on_write_file:
|
||||
@@ -172,7 +215,7 @@ class S3FileStorage(FileStorage):
|
||||
self._s3.meta.client.head_object(Bucket=self._bucket_name, Key=str(path))
|
||||
return True
|
||||
except botocore.exceptions.ClientError as e:
|
||||
if int(e.response["ResponseMetadata"]["HTTPStatusCode"]) == 404:
|
||||
if e.response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 404:
|
||||
# If the object does not exist,
|
||||
# check for objects with the prefix (folder)
|
||||
prefix = f"{str(path).rstrip('/')}/"
|
||||
@@ -201,7 +244,7 @@ class S3FileStorage(FileStorage):
|
||||
)
|
||||
self._s3.meta.client.delete_object(Bucket=self._bucket_name, Key=old_path)
|
||||
except botocore.exceptions.ClientError as e:
|
||||
if int(e.response["ResponseMetadata"]["HTTPStatusCode"]) == 404:
|
||||
if e.response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 404:
|
||||
# If the object does not exist,
|
||||
# it may be a folder
|
||||
prefix = f"{old_path.rstrip('/')}/"
|
||||
@@ -233,7 +276,7 @@ class S3FileStorage(FileStorage):
|
||||
Key=destination,
|
||||
)
|
||||
except botocore.exceptions.ClientError as e:
|
||||
if int(e.response["ResponseMetadata"]["HTTPStatusCode"]) == 404:
|
||||
if e.response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 404:
|
||||
# If the object does not exist,
|
||||
# it may be a folder
|
||||
prefix = f"{source.rstrip('/')}/"
|
||||
@@ -254,7 +297,7 @@ class S3FileStorage(FileStorage):
|
||||
S3FileStorageConfiguration(
|
||||
bucket=self._bucket_name,
|
||||
root=Path("/").joinpath(self.get_path(subroot)),
|
||||
s3_endpoint_url=self._s3.meta.client.meta.endpoint_url,
|
||||
s3_endpoint_url=SecretStr(self._s3.meta.client.meta.endpoint_url),
|
||||
)
|
||||
)
|
||||
file_storage._s3 = self._s3
|
||||
@@ -263,3 +306,48 @@ class S3FileStorage(FileStorage):
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{__class__.__name__}(bucket='{self._bucket_name}', root={self._root})"
|
||||
|
||||
|
||||
class S3BinaryIOWrapper(BinaryIO):
|
||||
def __init__(self, body: StreamingBody, name: str):
|
||||
self.body = body
|
||||
self._name = name
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def read(self, size: int = -1) -> bytes:
|
||||
return self.body.read(size if size > 0 else None)
|
||||
|
||||
def readinto(self, b: bytearray) -> int:
|
||||
data = self.read(len(b))
|
||||
b[: len(data)] = data
|
||||
return len(data)
|
||||
|
||||
def close(self) -> None:
|
||||
self.body.close()
|
||||
|
||||
def fileno(self) -> int:
|
||||
return self.body.fileno()
|
||||
|
||||
def flush(self) -> None:
|
||||
self.body.flush()
|
||||
|
||||
def isatty(self) -> bool:
|
||||
return self.body.isatty()
|
||||
|
||||
def readable(self) -> bool:
|
||||
return self.body.readable()
|
||||
|
||||
def seekable(self) -> bool:
|
||||
return self.body.seekable()
|
||||
|
||||
def writable(self) -> bool:
|
||||
return False
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.body.close()
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
from .parsing import extract_dict_from_json, extract_list_from_json, json_loads
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import abc
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from forge.models.config import SystemConfiguration
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from forge.llm.providers import AssistantChatMessage
|
||||
@@ -10,8 +8,6 @@ from .schema import ChatPrompt, LanguageModelClassification
|
||||
|
||||
|
||||
class PromptStrategy(abc.ABC):
|
||||
default_configuration: SystemConfiguration
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def model_classification(self) -> LanguageModelClassification:
|
||||
@@ -22,5 +18,5 @@ class PromptStrategy(abc.ABC):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def parse_response_content(self, response_content: "AssistantChatMessage"):
|
||||
def parse_response_content(self, response: "AssistantChatMessage") -> Any:
|
||||
...
|
||||
|
||||
@@ -27,7 +27,7 @@ class ChatPrompt(BaseModel):
|
||||
prefill_response: str = ""
|
||||
|
||||
def raw(self) -> list[ChatMessageDict]:
|
||||
return [m.dict() for m in self.messages]
|
||||
return [m.dict() for m in self.messages] # type: ignore
|
||||
|
||||
def __str__(self):
|
||||
return "\n\n".join(
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
from .multi import CHAT_MODELS, ModelName, MultiProvider
|
||||
from .multi import (
|
||||
CHAT_MODELS,
|
||||
ChatModelProvider,
|
||||
EmbeddingModelProvider,
|
||||
ModelName,
|
||||
MultiProvider,
|
||||
)
|
||||
from .openai import (
|
||||
OPEN_AI_CHAT_MODELS,
|
||||
OPEN_AI_EMBEDDING_MODELS,
|
||||
@@ -14,15 +20,12 @@ from .schema import (
|
||||
AssistantFunctionCallDict,
|
||||
ChatMessage,
|
||||
ChatModelInfo,
|
||||
ChatModelProvider,
|
||||
ChatModelResponse,
|
||||
CompletionModelFunction,
|
||||
Embedding,
|
||||
EmbeddingModelInfo,
|
||||
EmbeddingModelProvider,
|
||||
EmbeddingModelResponse,
|
||||
ModelInfo,
|
||||
ModelProvider,
|
||||
ModelProviderBudget,
|
||||
ModelProviderCredentials,
|
||||
ModelProviderName,
|
||||
@@ -41,7 +44,6 @@ __all__ = [
|
||||
"AssistantFunctionCallDict",
|
||||
"ChatMessage",
|
||||
"ChatModelInfo",
|
||||
"ChatModelProvider",
|
||||
"ChatModelResponse",
|
||||
"CompletionModelFunction",
|
||||
"CHAT_MODELS",
|
||||
@@ -51,7 +53,7 @@ __all__ = [
|
||||
"EmbeddingModelResponse",
|
||||
"ModelInfo",
|
||||
"ModelName",
|
||||
"ModelProvider",
|
||||
"ChatModelProvider",
|
||||
"ModelProviderBudget",
|
||||
"ModelProviderCredentials",
|
||||
"ModelProviderName",
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Callable, Optional, ParamSpec, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, ParamSpec, TypeVar
|
||||
|
||||
import sentry_sdk
|
||||
import tenacity
|
||||
@@ -14,9 +14,9 @@ from forge.llm.providers.schema import (
|
||||
AssistantChatMessage,
|
||||
AssistantFunctionCall,
|
||||
AssistantToolCall,
|
||||
BaseChatModelProvider,
|
||||
ChatMessage,
|
||||
ChatModelInfo,
|
||||
ChatModelProvider,
|
||||
ChatModelResponse,
|
||||
CompletionModelFunction,
|
||||
ModelProviderBudget,
|
||||
@@ -27,7 +27,7 @@ from forge.llm.providers.schema import (
|
||||
ModelTokenizer,
|
||||
ToolResultMessage,
|
||||
)
|
||||
from forge.models.config import Configurable, UserConfigurable
|
||||
from forge.models.config import UserConfigurable
|
||||
|
||||
from .utils import validate_tool_calls
|
||||
|
||||
@@ -84,14 +84,14 @@ class AnthropicConfiguration(ModelProviderConfiguration):
|
||||
class AnthropicCredentials(ModelProviderCredentials):
|
||||
"""Credentials for Anthropic."""
|
||||
|
||||
api_key: SecretStr = UserConfigurable(from_env="ANTHROPIC_API_KEY")
|
||||
api_key: SecretStr = UserConfigurable(from_env="ANTHROPIC_API_KEY") # type: ignore
|
||||
api_base: Optional[SecretStr] = UserConfigurable(
|
||||
default=None, from_env="ANTHROPIC_API_BASE_URL"
|
||||
)
|
||||
|
||||
def get_api_access_kwargs(self) -> dict[str, str]:
|
||||
return {
|
||||
k: (v.get_secret_value() if type(v) is SecretStr else v)
|
||||
k: v.get_secret_value()
|
||||
for k, v in {
|
||||
"api_key": self.api_key,
|
||||
"base_url": self.api_base,
|
||||
@@ -101,12 +101,12 @@ class AnthropicCredentials(ModelProviderCredentials):
|
||||
|
||||
|
||||
class AnthropicSettings(ModelProviderSettings):
|
||||
configuration: AnthropicConfiguration
|
||||
credentials: Optional[AnthropicCredentials]
|
||||
budget: ModelProviderBudget
|
||||
configuration: AnthropicConfiguration # type: ignore
|
||||
credentials: Optional[AnthropicCredentials] # type: ignore
|
||||
budget: ModelProviderBudget # type: ignore
|
||||
|
||||
|
||||
class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
|
||||
class AnthropicProvider(BaseChatModelProvider[AnthropicModelName, AnthropicSettings]):
|
||||
default_settings = AnthropicSettings(
|
||||
name="anthropic_provider",
|
||||
description="Provides access to Anthropic's API.",
|
||||
@@ -136,27 +136,26 @@ class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
self._client = AsyncAnthropic(**self._credentials.get_api_access_kwargs())
|
||||
self._client = AsyncAnthropic(
|
||||
**self._credentials.get_api_access_kwargs() # type: ignore
|
||||
)
|
||||
|
||||
async def get_available_models(self) -> list[ChatModelInfo]:
|
||||
async def get_available_models(self) -> list[ChatModelInfo[AnthropicModelName]]:
|
||||
return list(ANTHROPIC_CHAT_MODELS.values())
|
||||
|
||||
def get_token_limit(self, model_name: str) -> int:
|
||||
def get_token_limit(self, model_name: AnthropicModelName) -> int:
|
||||
"""Get the token limit for a given model."""
|
||||
return ANTHROPIC_CHAT_MODELS[model_name].max_tokens
|
||||
|
||||
@classmethod
|
||||
def get_tokenizer(cls, model_name: AnthropicModelName) -> ModelTokenizer:
|
||||
def get_tokenizer(self, model_name: AnthropicModelName) -> ModelTokenizer[Any]:
|
||||
# HACK: No official tokenizer is available for Claude 3
|
||||
return tiktoken.encoding_for_model(model_name)
|
||||
|
||||
@classmethod
|
||||
def count_tokens(cls, text: str, model_name: AnthropicModelName) -> int:
|
||||
def count_tokens(self, text: str, model_name: AnthropicModelName) -> int:
|
||||
return 0 # HACK: No official tokenizer is available for Claude 3
|
||||
|
||||
@classmethod
|
||||
def count_message_tokens(
|
||||
cls,
|
||||
self,
|
||||
messages: ChatMessage | list[ChatMessage],
|
||||
model_name: AnthropicModelName,
|
||||
) -> int:
|
||||
@@ -195,7 +194,7 @@ class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
|
||||
cost,
|
||||
t_input,
|
||||
t_output,
|
||||
) = await self._create_chat_completion(completion_kwargs)
|
||||
) = await self._create_chat_completion(model_name, completion_kwargs)
|
||||
total_cost += cost
|
||||
self._logger.debug(
|
||||
f"Completion usage: {t_input} input, {t_output} output "
|
||||
@@ -245,7 +244,7 @@ class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
|
||||
)
|
||||
if attempts < self._configuration.fix_failed_parse_tries:
|
||||
anthropic_messages.append(
|
||||
_assistant_msg.dict(include={"role", "content"})
|
||||
_assistant_msg.dict(include={"role", "content"}) # type: ignore
|
||||
)
|
||||
anthropic_messages.append(
|
||||
{
|
||||
@@ -312,7 +311,6 @@ class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
|
||||
def _get_chat_completion_args(
|
||||
self,
|
||||
prompt_messages: list[ChatMessage],
|
||||
model: AnthropicModelName,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
**kwargs,
|
||||
@@ -321,7 +319,6 @@ class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
|
||||
|
||||
Args:
|
||||
prompt_messages: List of ChatMessages.
|
||||
model: The model to use.
|
||||
functions: Optional list of functions available to the LLM.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
@@ -329,8 +326,6 @@ class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
|
||||
list[MessageParam]: Prompt messages for the Anthropic call
|
||||
dict[str, Any]: Any other kwargs for the Anthropic call
|
||||
"""
|
||||
kwargs["model"] = model
|
||||
|
||||
if functions:
|
||||
kwargs["tools"] = [
|
||||
{
|
||||
@@ -433,7 +428,7 @@ class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
|
||||
return messages, kwargs # type: ignore
|
||||
|
||||
async def _create_chat_completion(
|
||||
self, completion_kwargs: MessageCreateParams
|
||||
self, model: AnthropicModelName, completion_kwargs: MessageCreateParams
|
||||
) -> tuple[Message, float, int, int]:
|
||||
"""
|
||||
Create a chat completion using the Anthropic API with retry handling.
|
||||
@@ -449,17 +444,15 @@ class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
|
||||
"""
|
||||
|
||||
@self._retry_api_request
|
||||
async def _create_chat_completion_with_retry(
|
||||
completion_kwargs: MessageCreateParams,
|
||||
) -> Message:
|
||||
async def _create_chat_completion_with_retry() -> Message:
|
||||
return await self._client.beta.tools.messages.create(
|
||||
**completion_kwargs # type: ignore
|
||||
model=model, **completion_kwargs # type: ignore
|
||||
)
|
||||
|
||||
response = await _create_chat_completion_with_retry(completion_kwargs)
|
||||
response = await _create_chat_completion_with_retry()
|
||||
|
||||
cost = self._budget.update_usage_and_cost(
|
||||
model_info=ANTHROPIC_CHAT_MODELS[completion_kwargs["model"]],
|
||||
model_info=ANTHROPIC_CHAT_MODELS[model],
|
||||
input_tokens_used=response.usage.input_tokens,
|
||||
output_tokens_used=response.usage.output_tokens,
|
||||
)
|
||||
@@ -472,7 +465,10 @@ class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
|
||||
AssistantToolCall(
|
||||
id=c.id,
|
||||
type="function",
|
||||
function=AssistantFunctionCall(name=c.name, arguments=c.input),
|
||||
function=AssistantFunctionCall(
|
||||
name=c.name,
|
||||
arguments=c.input, # type: ignore
|
||||
),
|
||||
)
|
||||
for c in assistant_message.content
|
||||
if c.type == "tool_use"
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Callable, Optional, ParamSpec, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, ParamSpec, TypeVar
|
||||
|
||||
import sentry_sdk
|
||||
import tenacity
|
||||
@@ -15,9 +15,9 @@ from forge.llm.providers.schema import (
|
||||
AssistantChatMessage,
|
||||
AssistantFunctionCall,
|
||||
AssistantToolCall,
|
||||
BaseChatModelProvider,
|
||||
ChatMessage,
|
||||
ChatModelInfo,
|
||||
ChatModelProvider,
|
||||
ChatModelResponse,
|
||||
CompletionModelFunction,
|
||||
ModelProviderBudget,
|
||||
@@ -27,8 +27,9 @@ from forge.llm.providers.schema import (
|
||||
ModelProviderSettings,
|
||||
ModelTokenizer,
|
||||
)
|
||||
from forge.models.config import Configurable, UserConfigurable
|
||||
from forge.models.config import UserConfigurable
|
||||
|
||||
from .openai import format_function_def_for_openai
|
||||
from .utils import validate_tool_calls
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -93,14 +94,14 @@ class GroqConfiguration(ModelProviderConfiguration):
|
||||
class GroqCredentials(ModelProviderCredentials):
|
||||
"""Credentials for Groq."""
|
||||
|
||||
api_key: SecretStr = UserConfigurable(from_env="GROQ_API_KEY")
|
||||
api_key: SecretStr = UserConfigurable(from_env="GROQ_API_KEY") # type: ignore
|
||||
api_base: Optional[SecretStr] = UserConfigurable(
|
||||
default=None, from_env="GROQ_API_BASE_URL"
|
||||
)
|
||||
|
||||
def get_api_access_kwargs(self) -> dict[str, str]:
|
||||
return {
|
||||
k: (v.get_secret_value() if type(v) is SecretStr else v)
|
||||
k: v.get_secret_value()
|
||||
for k, v in {
|
||||
"api_key": self.api_key,
|
||||
"base_url": self.api_base,
|
||||
@@ -110,12 +111,12 @@ class GroqCredentials(ModelProviderCredentials):
|
||||
|
||||
|
||||
class GroqSettings(ModelProviderSettings):
|
||||
configuration: GroqConfiguration
|
||||
credentials: Optional[GroqCredentials]
|
||||
budget: ModelProviderBudget
|
||||
configuration: GroqConfiguration # type: ignore
|
||||
credentials: Optional[GroqCredentials] # type: ignore
|
||||
budget: ModelProviderBudget # type: ignore
|
||||
|
||||
|
||||
class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
|
||||
class GroqProvider(BaseChatModelProvider[GroqModelName, GroqSettings]):
|
||||
default_settings = GroqSettings(
|
||||
name="groq_provider",
|
||||
description="Provides access to Groq's API.",
|
||||
@@ -145,28 +146,27 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
|
||||
|
||||
from groq import AsyncGroq
|
||||
|
||||
self._client = AsyncGroq(**self._credentials.get_api_access_kwargs())
|
||||
self._client = AsyncGroq(
|
||||
**self._credentials.get_api_access_kwargs() # type: ignore
|
||||
)
|
||||
|
||||
async def get_available_models(self) -> list[ChatModelInfo]:
|
||||
async def get_available_models(self) -> list[ChatModelInfo[GroqModelName]]:
|
||||
_models = (await self._client.models.list()).data
|
||||
return [GROQ_CHAT_MODELS[m.id] for m in _models if m.id in GROQ_CHAT_MODELS]
|
||||
|
||||
def get_token_limit(self, model_name: str) -> int:
|
||||
def get_token_limit(self, model_name: GroqModelName) -> int:
|
||||
"""Get the token limit for a given model."""
|
||||
return GROQ_CHAT_MODELS[model_name].max_tokens
|
||||
|
||||
@classmethod
|
||||
def get_tokenizer(cls, model_name: GroqModelName) -> ModelTokenizer:
|
||||
def get_tokenizer(self, model_name: GroqModelName) -> ModelTokenizer[Any]:
|
||||
# HACK: No official tokenizer is available for Groq
|
||||
return tiktoken.encoding_for_model("gpt-3.5-turbo")
|
||||
|
||||
@classmethod
|
||||
def count_tokens(cls, text: str, model_name: GroqModelName) -> int:
|
||||
return len(cls.get_tokenizer(model_name).encode(text))
|
||||
def count_tokens(self, text: str, model_name: GroqModelName) -> int:
|
||||
return len(self.get_tokenizer(model_name).encode(text))
|
||||
|
||||
@classmethod
|
||||
def count_message_tokens(
|
||||
cls,
|
||||
self,
|
||||
messages: ChatMessage | list[ChatMessage],
|
||||
model_name: GroqModelName,
|
||||
) -> int:
|
||||
@@ -174,7 +174,7 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
|
||||
messages = [messages]
|
||||
# HACK: No official tokenizer (for text or messages) is available for Groq.
|
||||
# Token overhead of messages is unknown and may be inaccurate.
|
||||
return cls.count_tokens(
|
||||
return self.count_tokens(
|
||||
"\n\n".join(f"{m.role.upper()}: {m.content}" for m in messages), model_name
|
||||
)
|
||||
|
||||
@@ -191,7 +191,6 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
|
||||
"""Create a completion using the Groq API."""
|
||||
groq_messages, completion_kwargs = self._get_chat_completion_args(
|
||||
prompt_messages=model_prompt,
|
||||
model=model_name,
|
||||
functions=functions,
|
||||
max_output_tokens=max_output_tokens,
|
||||
**kwargs,
|
||||
@@ -202,7 +201,8 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
|
||||
while True:
|
||||
completion_kwargs["messages"] = groq_messages.copy()
|
||||
_response, _cost, t_input, t_output = await self._create_chat_completion(
|
||||
completion_kwargs
|
||||
model=model_name,
|
||||
completion_kwargs=completion_kwargs,
|
||||
)
|
||||
total_cost += _cost
|
||||
|
||||
@@ -221,7 +221,7 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
|
||||
parse_errors += validate_tool_calls(tool_calls, functions)
|
||||
|
||||
assistant_msg = AssistantChatMessage(
|
||||
content=_assistant_msg.content,
|
||||
content=_assistant_msg.content or "",
|
||||
tool_calls=tool_calls or None,
|
||||
)
|
||||
|
||||
@@ -240,7 +240,7 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
|
||||
|
||||
return ChatModelResponse(
|
||||
response=AssistantChatMessage(
|
||||
content=_assistant_msg.content,
|
||||
content=_assistant_msg.content or "",
|
||||
tool_calls=tool_calls or None,
|
||||
),
|
||||
parsed_result=parsed_result,
|
||||
@@ -266,7 +266,9 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
|
||||
)
|
||||
|
||||
if attempts < self._configuration.fix_failed_parse_tries:
|
||||
groq_messages.append(_assistant_msg.dict(exclude_none=True))
|
||||
groq_messages.append(
|
||||
_assistant_msg.dict(exclude_none=True) # type: ignore
|
||||
)
|
||||
groq_messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
@@ -282,7 +284,6 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
|
||||
def _get_chat_completion_args(
|
||||
self,
|
||||
prompt_messages: list[ChatMessage],
|
||||
model: GroqModelName,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
**kwargs, # type: ignore
|
||||
@@ -291,7 +292,6 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
|
||||
|
||||
Args:
|
||||
model_prompt: List of ChatMessages.
|
||||
model_name: The model to use.
|
||||
functions: Optional list of functions available to the LLM.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
@@ -300,13 +300,13 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
|
||||
dict[str, Any]: Any other kwargs for the OpenAI call
|
||||
"""
|
||||
kwargs: CompletionCreateParams = kwargs # type: ignore
|
||||
kwargs["model"] = model
|
||||
if max_output_tokens:
|
||||
kwargs["max_tokens"] = max_output_tokens
|
||||
|
||||
if functions:
|
||||
kwargs["tools"] = [
|
||||
{"type": "function", "function": f.schema} for f in functions
|
||||
{"type": "function", "function": format_function_def_for_openai(f)}
|
||||
for f in functions
|
||||
]
|
||||
if len(functions) == 1:
|
||||
# force the model to call the only specified function
|
||||
@@ -321,7 +321,7 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
|
||||
kwargs["extra_headers"].update(extra_headers.copy()) # type: ignore
|
||||
|
||||
groq_messages: list[ChatCompletionMessageParam] = [
|
||||
message.dict(
|
||||
message.dict( # type: ignore
|
||||
include={"role", "content", "tool_calls", "tool_call_id", "name"},
|
||||
exclude_none=True,
|
||||
)
|
||||
@@ -335,7 +335,7 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
|
||||
return groq_messages, kwargs
|
||||
|
||||
async def _create_chat_completion(
|
||||
self, completion_kwargs: CompletionCreateParams
|
||||
self, model: GroqModelName, completion_kwargs: CompletionCreateParams
|
||||
) -> tuple[ChatCompletion, float, int, int]:
|
||||
"""
|
||||
Create a chat completion using the Groq API with retry handling.
|
||||
@@ -351,24 +351,31 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
|
||||
"""
|
||||
|
||||
@self._retry_api_request
|
||||
async def _create_chat_completion_with_retry(
|
||||
completion_kwargs: CompletionCreateParams,
|
||||
) -> ChatCompletion:
|
||||
return await self._client.chat.completions.create(**completion_kwargs)
|
||||
async def _create_chat_completion_with_retry() -> ChatCompletion:
|
||||
return await self._client.chat.completions.create(
|
||||
model=model, **completion_kwargs # type: ignore
|
||||
)
|
||||
|
||||
response = await _create_chat_completion_with_retry(completion_kwargs)
|
||||
response = await _create_chat_completion_with_retry()
|
||||
|
||||
cost = self._budget.update_usage_and_cost(
|
||||
model_info=GROQ_CHAT_MODELS[completion_kwargs["model"]],
|
||||
input_tokens_used=response.usage.prompt_tokens,
|
||||
output_tokens_used=response.usage.completion_tokens,
|
||||
)
|
||||
return (
|
||||
response,
|
||||
cost,
|
||||
response.usage.prompt_tokens,
|
||||
response.usage.completion_tokens,
|
||||
)
|
||||
if not response.usage:
|
||||
self._logger.warning(
|
||||
"Groq chat completion response does not contain a usage field",
|
||||
response,
|
||||
)
|
||||
return response, 0, 0, 0
|
||||
else:
|
||||
cost = self._budget.update_usage_and_cost(
|
||||
model_info=GROQ_CHAT_MODELS[model],
|
||||
input_tokens_used=response.usage.prompt_tokens,
|
||||
output_tokens_used=response.usage.completion_tokens,
|
||||
)
|
||||
return (
|
||||
response,
|
||||
cost,
|
||||
response.usage.prompt_tokens,
|
||||
response.usage.completion_tokens,
|
||||
)
|
||||
|
||||
def _parse_assistant_tool_calls(
|
||||
self, assistant_message: ChatCompletionMessage, compat_mode: bool = False
|
||||
|
||||
@@ -1,20 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Callable, Iterator, Optional, TypeVar
|
||||
from typing import Any, Callable, Iterator, Optional, TypeVar
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from forge.models.config import Configurable
|
||||
|
||||
from .anthropic import ANTHROPIC_CHAT_MODELS, AnthropicModelName, AnthropicProvider
|
||||
from .groq import GROQ_CHAT_MODELS, GroqModelName, GroqProvider
|
||||
from .openai import OPEN_AI_CHAT_MODELS, OpenAIModelName, OpenAIProvider
|
||||
from .schema import (
|
||||
AssistantChatMessage,
|
||||
BaseChatModelProvider,
|
||||
ChatMessage,
|
||||
ChatModelInfo,
|
||||
ChatModelProvider,
|
||||
ChatModelResponse,
|
||||
CompletionModelFunction,
|
||||
ModelProviderBudget,
|
||||
@@ -27,11 +25,12 @@ from .schema import (
|
||||
_T = TypeVar("_T")
|
||||
|
||||
ModelName = AnthropicModelName | GroqModelName | OpenAIModelName
|
||||
EmbeddingModelProvider = OpenAIProvider
|
||||
|
||||
CHAT_MODELS = {**ANTHROPIC_CHAT_MODELS, **GROQ_CHAT_MODELS, **OPEN_AI_CHAT_MODELS}
|
||||
|
||||
|
||||
class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
|
||||
class MultiProvider(BaseChatModelProvider[ModelName, ModelProviderSettings]):
|
||||
default_settings = ModelProviderSettings(
|
||||
name="multi_provider",
|
||||
description=(
|
||||
@@ -57,7 +56,7 @@ class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
|
||||
|
||||
self._provider_instances = {}
|
||||
|
||||
async def get_available_models(self) -> list[ChatModelInfo]:
|
||||
async def get_available_models(self) -> list[ChatModelInfo[ModelName]]:
|
||||
models = []
|
||||
for provider in self.get_available_providers():
|
||||
models.extend(await provider.get_available_models())
|
||||
@@ -65,24 +64,25 @@ class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
|
||||
|
||||
def get_token_limit(self, model_name: ModelName) -> int:
|
||||
"""Get the token limit for a given model."""
|
||||
return self.get_model_provider(model_name).get_token_limit(model_name)
|
||||
|
||||
@classmethod
|
||||
def get_tokenizer(cls, model_name: ModelName) -> ModelTokenizer:
|
||||
return cls._get_model_provider_class(model_name).get_tokenizer(model_name)
|
||||
|
||||
@classmethod
|
||||
def count_tokens(cls, text: str, model_name: ModelName) -> int:
|
||||
return cls._get_model_provider_class(model_name).count_tokens(
|
||||
text=text, model_name=model_name
|
||||
return self.get_model_provider(model_name).get_token_limit(
|
||||
model_name # type: ignore
|
||||
)
|
||||
|
||||
def get_tokenizer(self, model_name: ModelName) -> ModelTokenizer[Any]:
|
||||
return self.get_model_provider(model_name).get_tokenizer(
|
||||
model_name # type: ignore
|
||||
)
|
||||
|
||||
def count_tokens(self, text: str, model_name: ModelName) -> int:
|
||||
return self.get_model_provider(model_name).count_tokens(
|
||||
text=text, model_name=model_name # type: ignore
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def count_message_tokens(
|
||||
cls, messages: ChatMessage | list[ChatMessage], model_name: ModelName
|
||||
self, messages: ChatMessage | list[ChatMessage], model_name: ModelName
|
||||
) -> int:
|
||||
return cls._get_model_provider_class(model_name).count_message_tokens(
|
||||
messages=messages, model_name=model_name
|
||||
return self.get_model_provider(model_name).count_message_tokens(
|
||||
messages=messages, model_name=model_name # type: ignore
|
||||
)
|
||||
|
||||
async def create_chat_completion(
|
||||
@@ -98,7 +98,7 @@ class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
|
||||
"""Create a completion using the Anthropic API."""
|
||||
return await self.get_model_provider(model_name).create_chat_completion(
|
||||
model_prompt=model_prompt,
|
||||
model_name=model_name,
|
||||
model_name=model_name, # type: ignore
|
||||
completion_parser=completion_parser,
|
||||
functions=functions,
|
||||
max_output_tokens=max_output_tokens,
|
||||
@@ -136,17 +136,11 @@ class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
|
||||
) from e
|
||||
|
||||
self._provider_instances[provider_name] = _provider = Provider(
|
||||
settings=settings, logger=self._logger
|
||||
settings=settings, logger=self._logger # type: ignore
|
||||
)
|
||||
_provider._budget = self._budget # Object binding not preserved by Pydantic
|
||||
return _provider
|
||||
|
||||
@classmethod
|
||||
def _get_model_provider_class(
|
||||
cls, model_name: ModelName
|
||||
) -> type[AnthropicProvider | GroqProvider | OpenAIProvider]:
|
||||
return cls._get_provider_class(CHAT_MODELS[model_name].provider_name)
|
||||
|
||||
@classmethod
|
||||
def _get_provider_class(
|
||||
cls, provider_name: ModelProviderName
|
||||
@@ -162,3 +156,6 @@ class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}()"
|
||||
|
||||
|
||||
ChatModelProvider = AnthropicProvider | GroqProvider | OpenAIProvider | MultiProvider
|
||||
|
||||
@@ -2,7 +2,16 @@ import enum
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Iterator,
|
||||
Optional,
|
||||
ParamSpec,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
import sentry_sdk
|
||||
import tenacity
|
||||
@@ -12,9 +21,11 @@ from openai._exceptions import APIStatusError, RateLimitError
|
||||
from openai.types import CreateEmbeddingResponse
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageParam,
|
||||
)
|
||||
from openai.types.shared_params import FunctionDefinition
|
||||
from pydantic import SecretStr
|
||||
|
||||
from forge.json.parsing import json_loads
|
||||
@@ -23,14 +34,14 @@ from forge.llm.providers.schema import (
|
||||
AssistantFunctionCall,
|
||||
AssistantToolCall,
|
||||
AssistantToolCallDict,
|
||||
BaseChatModelProvider,
|
||||
BaseEmbeddingModelProvider,
|
||||
ChatMessage,
|
||||
ChatModelInfo,
|
||||
ChatModelProvider,
|
||||
ChatModelResponse,
|
||||
CompletionModelFunction,
|
||||
Embedding,
|
||||
EmbeddingModelInfo,
|
||||
EmbeddingModelProvider,
|
||||
EmbeddingModelResponse,
|
||||
ModelProviderBudget,
|
||||
ModelProviderConfiguration,
|
||||
@@ -39,7 +50,7 @@ from forge.llm.providers.schema import (
|
||||
ModelProviderSettings,
|
||||
ModelTokenizer,
|
||||
)
|
||||
from forge.models.config import Configurable, UserConfigurable
|
||||
from forge.models.config import UserConfigurable
|
||||
from forge.models.json_schema import JSONSchema
|
||||
|
||||
from .utils import validate_tool_calls
|
||||
@@ -223,43 +234,46 @@ class OpenAIConfiguration(ModelProviderConfiguration):
|
||||
class OpenAICredentials(ModelProviderCredentials):
|
||||
"""Credentials for OpenAI."""
|
||||
|
||||
api_key: SecretStr = UserConfigurable(from_env="OPENAI_API_KEY")
|
||||
api_key: SecretStr = UserConfigurable(from_env="OPENAI_API_KEY") # type: ignore
|
||||
api_base: Optional[SecretStr] = UserConfigurable(
|
||||
default=None, from_env="OPENAI_API_BASE_URL"
|
||||
)
|
||||
organization: Optional[SecretStr] = UserConfigurable(from_env="OPENAI_ORGANIZATION")
|
||||
|
||||
api_type: str = UserConfigurable(
|
||||
default="",
|
||||
from_env=lambda: (
|
||||
api_type: Optional[SecretStr] = UserConfigurable(
|
||||
default=None,
|
||||
from_env=lambda: cast(
|
||||
SecretStr | None,
|
||||
"azure"
|
||||
if os.getenv("USE_AZURE") == "True"
|
||||
else os.getenv("OPENAI_API_TYPE")
|
||||
else os.getenv("OPENAI_API_TYPE"),
|
||||
),
|
||||
)
|
||||
api_version: str = UserConfigurable("", from_env="OPENAI_API_VERSION")
|
||||
api_version: Optional[SecretStr] = UserConfigurable(
|
||||
default=None, from_env="OPENAI_API_VERSION"
|
||||
)
|
||||
azure_endpoint: Optional[SecretStr] = None
|
||||
azure_model_to_deploy_id_map: Optional[dict[str, str]] = None
|
||||
|
||||
def get_api_access_kwargs(self) -> dict[str, str]:
|
||||
kwargs = {
|
||||
k: (v.get_secret_value() if type(v) is SecretStr else v)
|
||||
k: v.get_secret_value()
|
||||
for k, v in {
|
||||
"api_key": self.api_key,
|
||||
"base_url": self.api_base,
|
||||
"organization": self.organization,
|
||||
"api_version": self.api_version,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
if self.api_type == "azure":
|
||||
kwargs["api_version"] = self.api_version
|
||||
if self.api_type == SecretStr("azure"):
|
||||
assert self.azure_endpoint, "Azure endpoint not configured"
|
||||
kwargs["azure_endpoint"] = self.azure_endpoint.get_secret_value()
|
||||
return kwargs
|
||||
|
||||
def get_model_access_kwargs(self, model: str) -> dict[str, str]:
|
||||
kwargs = {"model": model}
|
||||
if self.api_type == "azure" and model:
|
||||
if self.api_type == SecretStr("azure") and model:
|
||||
azure_kwargs = self._get_azure_access_kwargs(model)
|
||||
kwargs.update(azure_kwargs)
|
||||
return kwargs
|
||||
@@ -276,7 +290,7 @@ class OpenAICredentials(ModelProviderCredentials):
|
||||
raise ValueError(*e.args)
|
||||
|
||||
self.api_type = config_params.get("azure_api_type", "azure")
|
||||
self.api_version = config_params.get("azure_api_version", "")
|
||||
self.api_version = config_params.get("azure_api_version", None)
|
||||
self.azure_endpoint = config_params.get("azure_endpoint")
|
||||
self.azure_model_to_deploy_id_map = config_params.get("azure_model_map")
|
||||
|
||||
@@ -294,13 +308,14 @@ class OpenAICredentials(ModelProviderCredentials):
|
||||
|
||||
|
||||
class OpenAISettings(ModelProviderSettings):
|
||||
configuration: OpenAIConfiguration
|
||||
credentials: Optional[OpenAICredentials]
|
||||
budget: ModelProviderBudget
|
||||
configuration: OpenAIConfiguration # type: ignore
|
||||
credentials: Optional[OpenAICredentials] # type: ignore
|
||||
budget: ModelProviderBudget # type: ignore
|
||||
|
||||
|
||||
class OpenAIProvider(
|
||||
Configurable[OpenAISettings], ChatModelProvider, EmbeddingModelProvider
|
||||
BaseChatModelProvider[OpenAIModelName, OpenAISettings],
|
||||
BaseEmbeddingModelProvider[OpenAIModelName, OpenAISettings],
|
||||
):
|
||||
default_settings = OpenAISettings(
|
||||
name="openai_provider",
|
||||
@@ -329,37 +344,38 @@ class OpenAIProvider(
|
||||
|
||||
super(OpenAIProvider, self).__init__(settings=settings, logger=logger)
|
||||
|
||||
if self._credentials.api_type == "azure":
|
||||
if self._credentials.api_type == SecretStr("azure"):
|
||||
from openai import AsyncAzureOpenAI
|
||||
|
||||
# API key and org (if configured) are passed, the rest of the required
|
||||
# credentials is loaded from the environment by the AzureOpenAI client.
|
||||
self._client = AsyncAzureOpenAI(**self._credentials.get_api_access_kwargs())
|
||||
self._client = AsyncAzureOpenAI(
|
||||
**self._credentials.get_api_access_kwargs() # type: ignore
|
||||
)
|
||||
else:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
self._client = AsyncOpenAI(**self._credentials.get_api_access_kwargs())
|
||||
self._client = AsyncOpenAI(
|
||||
**self._credentials.get_api_access_kwargs() # type: ignore
|
||||
)
|
||||
|
||||
async def get_available_models(self) -> list[ChatModelInfo]:
|
||||
async def get_available_models(self) -> list[ChatModelInfo[OpenAIModelName]]:
|
||||
_models = (await self._client.models.list()).data
|
||||
return [OPEN_AI_MODELS[m.id] for m in _models if m.id in OPEN_AI_MODELS]
|
||||
|
||||
def get_token_limit(self, model_name: str) -> int:
|
||||
def get_token_limit(self, model_name: OpenAIModelName) -> int:
|
||||
"""Get the token limit for a given model."""
|
||||
return OPEN_AI_MODELS[model_name].max_tokens
|
||||
|
||||
@classmethod
|
||||
def get_tokenizer(cls, model_name: OpenAIModelName) -> ModelTokenizer:
|
||||
def get_tokenizer(self, model_name: OpenAIModelName) -> ModelTokenizer[int]:
|
||||
return tiktoken.encoding_for_model(model_name)
|
||||
|
||||
@classmethod
|
||||
def count_tokens(cls, text: str, model_name: OpenAIModelName) -> int:
|
||||
encoding = cls.get_tokenizer(model_name)
|
||||
def count_tokens(self, text: str, model_name: OpenAIModelName) -> int:
|
||||
encoding = self.get_tokenizer(model_name)
|
||||
return len(encoding.encode(text))
|
||||
|
||||
@classmethod
|
||||
def count_message_tokens(
|
||||
cls,
|
||||
self,
|
||||
messages: ChatMessage | list[ChatMessage],
|
||||
model_name: OpenAIModelName,
|
||||
) -> int:
|
||||
@@ -447,7 +463,7 @@ class OpenAIProvider(
|
||||
parse_errors += validate_tool_calls(tool_calls, functions)
|
||||
|
||||
assistant_msg = AssistantChatMessage(
|
||||
content=_assistant_msg.content,
|
||||
content=_assistant_msg.content or "",
|
||||
tool_calls=tool_calls or None,
|
||||
)
|
||||
|
||||
@@ -466,7 +482,7 @@ class OpenAIProvider(
|
||||
|
||||
return ChatModelResponse(
|
||||
response=AssistantChatMessage(
|
||||
content=_assistant_msg.content,
|
||||
content=_assistant_msg.content or "",
|
||||
tool_calls=tool_calls or None,
|
||||
),
|
||||
parsed_result=parsed_result,
|
||||
@@ -492,7 +508,12 @@ class OpenAIProvider(
|
||||
)
|
||||
|
||||
if attempts < self._configuration.fix_failed_parse_tries:
|
||||
openai_messages.append(_assistant_msg.dict(exclude_none=True))
|
||||
openai_messages.append(
|
||||
cast(
|
||||
ChatCompletionAssistantMessageParam,
|
||||
_assistant_msg.dict(exclude_none=True),
|
||||
)
|
||||
)
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
@@ -522,7 +543,10 @@ class OpenAIProvider(
|
||||
prompt_tokens_used=response.usage.prompt_tokens,
|
||||
completion_tokens_used=0,
|
||||
)
|
||||
self._budget.update_usage_and_cost(response)
|
||||
self._budget.update_usage_and_cost(
|
||||
model_info=response.model_info,
|
||||
input_tokens_used=response.prompt_tokens_used,
|
||||
)
|
||||
return response
|
||||
|
||||
def _get_chat_completion_args(
|
||||
@@ -549,7 +573,8 @@ class OpenAIProvider(
|
||||
if functions:
|
||||
if OPEN_AI_CHAT_MODELS[model_name].has_function_call_api:
|
||||
kwargs["tools"] = [
|
||||
{"type": "function", "function": f.schema} for f in functions
|
||||
{"type": "function", "function": format_function_def_for_openai(f)}
|
||||
for f in functions
|
||||
]
|
||||
if len(functions) == 1:
|
||||
# force the model to call the only specified function
|
||||
@@ -569,10 +594,13 @@ class OpenAIProvider(
|
||||
model_prompt += kwargs["messages"]
|
||||
del kwargs["messages"]
|
||||
|
||||
openai_messages: list[ChatCompletionMessageParam] = [
|
||||
message.dict(
|
||||
include={"role", "content", "tool_calls", "name"},
|
||||
exclude_none=True,
|
||||
openai_messages = [
|
||||
cast(
|
||||
ChatCompletionMessageParam,
|
||||
message.dict(
|
||||
include={"role", "content", "tool_calls", "name"},
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
for message in model_prompt
|
||||
]
|
||||
@@ -655,7 +683,7 @@ class OpenAIProvider(
|
||||
|
||||
def _parse_assistant_tool_calls(
|
||||
self, assistant_message: ChatCompletionMessage, compat_mode: bool = False
|
||||
):
|
||||
) -> tuple[list[AssistantToolCall], list[Exception]]:
|
||||
tool_calls: list[AssistantToolCall] = []
|
||||
parse_errors: list[Exception] = []
|
||||
|
||||
@@ -749,6 +777,24 @@ class OpenAIProvider(
|
||||
return "OpenAIProvider()"
|
||||
|
||||
|
||||
def format_function_def_for_openai(self: CompletionModelFunction) -> FunctionDefinition:
|
||||
"""Returns an OpenAI-consumable function definition"""
|
||||
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
name: param.to_dict() for name, param in self.parameters.items()
|
||||
},
|
||||
"required": [
|
||||
name for name, param in self.parameters.items() if param.required
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def format_function_specs_as_typescript_ns(
|
||||
functions: list[CompletionModelFunction],
|
||||
) -> str:
|
||||
@@ -888,6 +934,4 @@ def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCal
|
||||
|
||||
for t in tool_calls:
|
||||
t["id"] = str(uuid.uuid4())
|
||||
t["function"]["arguments"] = str(t["function"]["arguments"]) # HACK
|
||||
|
||||
yield AssistantToolCall.parse_obj(t)
|
||||
|
||||
@@ -19,14 +19,17 @@ from typing import (
|
||||
from pydantic import BaseModel, Field, SecretStr, validator
|
||||
|
||||
from forge.logging.utils import fmt_kwargs
|
||||
from forge.models.config import SystemConfiguration, UserConfigurable
|
||||
from forge.models.config import (
|
||||
Configurable,
|
||||
SystemConfiguration,
|
||||
SystemSettings,
|
||||
UserConfigurable,
|
||||
)
|
||||
from forge.models.json_schema import JSONSchema
|
||||
from forge.models.providers import (
|
||||
Embedding,
|
||||
ProviderBudget,
|
||||
ProviderCredentials,
|
||||
ProviderSettings,
|
||||
ProviderUsage,
|
||||
ResourceType,
|
||||
)
|
||||
|
||||
@@ -34,6 +37,11 @@ if TYPE_CHECKING:
|
||||
from jsonschema import ValidationError
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
_ModelName = TypeVar("_ModelName", bound=str)
|
||||
|
||||
|
||||
class ModelProviderService(str, enum.Enum):
|
||||
"""A ModelService describes what kind of service the model provides."""
|
||||
|
||||
@@ -102,13 +110,13 @@ class AssistantToolCallDict(TypedDict):
|
||||
|
||||
|
||||
class AssistantChatMessage(ChatMessage):
|
||||
role: Literal[ChatMessage.Role.ASSISTANT] = ChatMessage.Role.ASSISTANT
|
||||
content: Optional[str]
|
||||
role: Literal[ChatMessage.Role.ASSISTANT] = ChatMessage.Role.ASSISTANT # type: ignore # noqa
|
||||
content: str = ""
|
||||
tool_calls: Optional[list[AssistantToolCall]] = None
|
||||
|
||||
|
||||
class ToolResultMessage(ChatMessage):
|
||||
role: Literal[ChatMessage.Role.TOOL] = ChatMessage.Role.TOOL
|
||||
role: Literal[ChatMessage.Role.TOOL] = ChatMessage.Role.TOOL # type: ignore
|
||||
is_error: bool = False
|
||||
tool_call_id: str
|
||||
|
||||
@@ -126,32 +134,6 @@ class CompletionModelFunction(BaseModel):
|
||||
description: str
|
||||
parameters: dict[str, "JSONSchema"]
|
||||
|
||||
@property
|
||||
def schema(self) -> dict[str, str | dict | list]:
|
||||
"""Returns an OpenAI-consumable function specification"""
|
||||
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
name: param.to_dict() for name, param in self.parameters.items()
|
||||
},
|
||||
"required": [
|
||||
name for name, param in self.parameters.items() if param.required
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def parse(schema: dict) -> "CompletionModelFunction":
|
||||
return CompletionModelFunction(
|
||||
name=schema["name"],
|
||||
description=schema["description"],
|
||||
parameters=JSONSchema.parse_properties(schema["parameters"]),
|
||||
)
|
||||
|
||||
def fmt_line(self) -> str:
|
||||
params = ", ".join(
|
||||
f"{name}{'?' if not p.required else ''}: " f"{p.typescript_type}"
|
||||
@@ -184,15 +166,15 @@ class CompletionModelFunction(BaseModel):
|
||||
return params_schema.validate_object(function_call.arguments)
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
class ModelInfo(BaseModel, Generic[_ModelName]):
|
||||
"""Struct for model information.
|
||||
|
||||
Would be lovely to eventually get this directly from APIs, but needs to be
|
||||
scraped from websites for now.
|
||||
"""
|
||||
|
||||
name: str
|
||||
service: ModelProviderService
|
||||
name: _ModelName
|
||||
service: ClassVar[ModelProviderService]
|
||||
provider_name: ModelProviderName
|
||||
prompt_token_cost: float = 0.0
|
||||
completion_token_cost: float = 0.0
|
||||
@@ -220,27 +202,39 @@ class ModelProviderCredentials(ProviderCredentials):
|
||||
api_version: SecretStr | None = UserConfigurable(default=None)
|
||||
deployment_id: SecretStr | None = UserConfigurable(default=None)
|
||||
|
||||
class Config:
|
||||
class Config(ProviderCredentials.Config):
|
||||
extra = "ignore"
|
||||
|
||||
|
||||
class ModelProviderUsage(ProviderUsage):
|
||||
class ModelProviderUsage(BaseModel):
|
||||
"""Usage for a particular model from a model provider."""
|
||||
|
||||
completion_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
class ModelUsage(BaseModel):
|
||||
completion_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
|
||||
usage_per_model: dict[str, ModelUsage] = defaultdict(ModelUsage)
|
||||
|
||||
@property
|
||||
def completion_tokens(self) -> int:
|
||||
return sum(model.completion_tokens for model in self.usage_per_model.values())
|
||||
|
||||
@property
|
||||
def prompt_tokens(self) -> int:
|
||||
return sum(model.prompt_tokens for model in self.usage_per_model.values())
|
||||
|
||||
def update_usage(
|
||||
self,
|
||||
model: str,
|
||||
input_tokens_used: int,
|
||||
output_tokens_used: int = 0,
|
||||
) -> None:
|
||||
self.prompt_tokens += input_tokens_used
|
||||
self.completion_tokens += output_tokens_used
|
||||
self.usage_per_model[model].prompt_tokens += input_tokens_used
|
||||
self.usage_per_model[model].completion_tokens += output_tokens_used
|
||||
|
||||
|
||||
class ModelProviderBudget(ProviderBudget):
|
||||
usage: defaultdict[str, ModelProviderUsage] = defaultdict(ModelProviderUsage)
|
||||
class ModelProviderBudget(ProviderBudget[ModelProviderUsage]):
|
||||
usage: ModelProviderUsage = Field(default_factory=ModelProviderUsage)
|
||||
|
||||
def update_usage_and_cost(
|
||||
self,
|
||||
@@ -253,7 +247,7 @@ class ModelProviderBudget(ProviderBudget):
|
||||
Returns:
|
||||
float: The (calculated) cost of the given model response.
|
||||
"""
|
||||
self.usage[model_info.name].update_usage(input_tokens_used, output_tokens_used)
|
||||
self.usage.update_usage(model_info.name, input_tokens_used, output_tokens_used)
|
||||
incurred_cost = (
|
||||
output_tokens_used * model_info.completion_token_cost
|
||||
+ input_tokens_used * model_info.prompt_token_cost
|
||||
@@ -263,28 +257,33 @@ class ModelProviderBudget(ProviderBudget):
|
||||
return incurred_cost
|
||||
|
||||
|
||||
class ModelProviderSettings(ProviderSettings):
|
||||
resource_type: ResourceType = ResourceType.MODEL
|
||||
class ModelProviderSettings(SystemSettings):
|
||||
resource_type: ClassVar[ResourceType] = ResourceType.MODEL
|
||||
configuration: ModelProviderConfiguration
|
||||
credentials: Optional[ModelProviderCredentials] = None
|
||||
budget: Optional[ModelProviderBudget] = None
|
||||
|
||||
|
||||
class ModelProvider(abc.ABC):
|
||||
_ModelProviderSettings = TypeVar("_ModelProviderSettings", bound=ModelProviderSettings)
|
||||
|
||||
|
||||
# TODO: either use MultiProvider throughout codebase as type for `llm_provider`, or
|
||||
# replace `_ModelName` by `str` to eliminate type checking difficulties
|
||||
class BaseModelProvider(
|
||||
abc.ABC,
|
||||
Generic[_ModelName, _ModelProviderSettings],
|
||||
Configurable[_ModelProviderSettings],
|
||||
):
|
||||
"""A ModelProvider abstracts the details of a particular provider of models."""
|
||||
|
||||
default_settings: ClassVar[ModelProviderSettings]
|
||||
|
||||
_settings: ModelProviderSettings
|
||||
_configuration: ModelProviderConfiguration
|
||||
_credentials: Optional[ModelProviderCredentials] = None
|
||||
_budget: Optional[ModelProviderBudget] = None
|
||||
default_settings: ClassVar[_ModelProviderSettings] # type: ignore
|
||||
|
||||
_settings: _ModelProviderSettings
|
||||
_logger: logging.Logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: Optional[ModelProviderSettings] = None,
|
||||
settings: Optional[_ModelProviderSettings] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
):
|
||||
if not settings:
|
||||
@@ -298,15 +297,15 @@ class ModelProvider(abc.ABC):
|
||||
self._logger = logger or logging.getLogger(self.__module__)
|
||||
|
||||
@abc.abstractmethod
|
||||
def count_tokens(self, text: str, model_name: str) -> int:
|
||||
def count_tokens(self, text: str, model_name: _ModelName) -> int:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_tokenizer(self, model_name: str) -> "ModelTokenizer":
|
||||
def get_tokenizer(self, model_name: _ModelName) -> "ModelTokenizer[Any]":
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_token_limit(self, model_name: str) -> int:
|
||||
def get_token_limit(self, model_name: _ModelName) -> int:
|
||||
...
|
||||
|
||||
def get_incurred_cost(self) -> float:
|
||||
@@ -320,15 +319,15 @@ class ModelProvider(abc.ABC):
|
||||
return math.inf
|
||||
|
||||
|
||||
class ModelTokenizer(Protocol):
|
||||
class ModelTokenizer(Protocol, Generic[_T]):
|
||||
"""A ModelTokenizer provides tokenization specific to a model."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def encode(self, text: str) -> list:
|
||||
def encode(self, text: str) -> list[_T]:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def decode(self, tokens: list) -> str:
|
||||
def decode(self, tokens: list[_T]) -> str:
|
||||
...
|
||||
|
||||
|
||||
@@ -337,10 +336,10 @@ class ModelTokenizer(Protocol):
|
||||
####################
|
||||
|
||||
|
||||
class EmbeddingModelInfo(ModelInfo):
|
||||
class EmbeddingModelInfo(ModelInfo[_ModelName]):
|
||||
"""Struct for embedding model information."""
|
||||
|
||||
service: Literal[ModelProviderService.EMBEDDING] = ModelProviderService.EMBEDDING
|
||||
service = ModelProviderService.EMBEDDING
|
||||
max_tokens: int
|
||||
embedding_dimensions: int
|
||||
|
||||
@@ -350,20 +349,19 @@ class EmbeddingModelResponse(ModelResponse):
|
||||
|
||||
embedding: Embedding = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
@validator("completion_tokens_used")
|
||||
def _verify_no_completion_tokens_used(cls, v):
|
||||
def _verify_no_completion_tokens_used(cls, v: int):
|
||||
if v > 0:
|
||||
raise ValueError("Embeddings should not have completion tokens used.")
|
||||
return v
|
||||
|
||||
|
||||
class EmbeddingModelProvider(ModelProvider):
|
||||
class BaseEmbeddingModelProvider(BaseModelProvider[_ModelName, _ModelProviderSettings]):
|
||||
@abc.abstractmethod
|
||||
async def create_embedding(
|
||||
self,
|
||||
text: str,
|
||||
model_name: str,
|
||||
model_name: _ModelName,
|
||||
embedding_parser: Callable[[Embedding], Embedding],
|
||||
**kwargs,
|
||||
) -> EmbeddingModelResponse:
|
||||
@@ -375,34 +373,31 @@ class EmbeddingModelProvider(ModelProvider):
|
||||
###############
|
||||
|
||||
|
||||
class ChatModelInfo(ModelInfo):
|
||||
class ChatModelInfo(ModelInfo[_ModelName]):
|
||||
"""Struct for language model information."""
|
||||
|
||||
service: Literal[ModelProviderService.CHAT] = ModelProviderService.CHAT
|
||||
service = ModelProviderService.CHAT
|
||||
max_tokens: int
|
||||
has_function_call_api: bool = False
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class ChatModelResponse(ModelResponse, Generic[_T]):
|
||||
"""Standard response struct for a response from a language model."""
|
||||
|
||||
response: AssistantChatMessage
|
||||
parsed_result: _T = None
|
||||
parsed_result: _T
|
||||
|
||||
|
||||
class ChatModelProvider(ModelProvider):
|
||||
class BaseChatModelProvider(BaseModelProvider[_ModelName, _ModelProviderSettings]):
|
||||
@abc.abstractmethod
|
||||
async def get_available_models(self) -> list[ChatModelInfo]:
|
||||
async def get_available_models(self) -> list[ChatModelInfo[_ModelName]]:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def count_message_tokens(
|
||||
self,
|
||||
messages: ChatMessage | list[ChatMessage],
|
||||
model_name: str,
|
||||
model_name: _ModelName,
|
||||
) -> int:
|
||||
...
|
||||
|
||||
@@ -410,7 +405,7 @@ class ChatModelProvider(ModelProvider):
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: str,
|
||||
model_name: _ModelName,
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
|
||||
@@ -8,7 +8,6 @@ import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from colorama import Fore, Style
|
||||
from openai._base_client import log as openai_logger
|
||||
|
||||
from forge.models.config import SystemConfiguration, UserConfigurable
|
||||
@@ -65,7 +64,7 @@ class LoggingConfig(SystemConfiguration):
|
||||
log_dir: Path = LOG_DIR
|
||||
log_file_format: Optional[LogFormatName] = UserConfigurable(
|
||||
default=LogFormatName.SIMPLE,
|
||||
from_env=lambda: os.getenv(
|
||||
from_env=lambda: os.getenv( # type: ignore
|
||||
"LOG_FILE_FORMAT", os.getenv("LOG_FORMAT", "simple")
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -11,7 +11,10 @@ from .utils import ModelWithSummary
|
||||
|
||||
class ActionProposal(BaseModel):
|
||||
thoughts: str | ModelWithSummary
|
||||
use_tool: AssistantFunctionCall = None
|
||||
use_tool: AssistantFunctionCall
|
||||
|
||||
|
||||
AnyProposal = TypeVar("AnyProposal", bound=ActionProposal)
|
||||
|
||||
|
||||
class ActionSuccessResult(BaseModel):
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import abc
|
||||
import os
|
||||
import typing
|
||||
from typing import Any, Callable, Generic, Optional, Type, TypeVar, get_args
|
||||
@@ -85,11 +84,11 @@ class SystemSettings(BaseModel):
|
||||
S = TypeVar("S", bound=SystemSettings)
|
||||
|
||||
|
||||
class Configurable(abc.ABC, Generic[S]):
|
||||
class Configurable(Generic[S]):
|
||||
"""A base class for all configurable objects."""
|
||||
|
||||
prefix: str = ""
|
||||
default_settings: typing.ClassVar[S]
|
||||
default_settings: typing.ClassVar[S] # type: ignore
|
||||
|
||||
@classmethod
|
||||
def get_user_config(cls) -> dict[str, Any]:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import enum
|
||||
from textwrap import indent
|
||||
from typing import Optional
|
||||
from typing import Optional, overload
|
||||
|
||||
from jsonschema import Draft7Validator, ValidationError
|
||||
from pydantic import BaseModel
|
||||
@@ -57,30 +57,8 @@ class JSONSchema(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def from_dict(schema: dict) -> "JSONSchema":
|
||||
def resolve_references(schema: dict, definitions: dict) -> dict:
|
||||
"""
|
||||
Recursively resolve type $refs in the JSON schema with their definitions.
|
||||
"""
|
||||
if isinstance(schema, dict):
|
||||
if "$ref" in schema:
|
||||
ref_path = schema["$ref"].split("/")[
|
||||
2:
|
||||
] # Split and remove '#/definitions'
|
||||
ref_value = definitions
|
||||
for key in ref_path:
|
||||
ref_value = ref_value[key]
|
||||
return resolve_references(ref_value, definitions)
|
||||
else:
|
||||
return {
|
||||
k: resolve_references(v, definitions) for k, v in schema.items()
|
||||
}
|
||||
elif isinstance(schema, list):
|
||||
return [resolve_references(item, definitions) for item in schema]
|
||||
else:
|
||||
return schema
|
||||
|
||||
definitions = schema.get("definitions", {})
|
||||
schema = resolve_references(schema, definitions)
|
||||
schema = _resolve_type_refs_in_schema(schema, definitions)
|
||||
|
||||
return JSONSchema(
|
||||
description=schema.get("description"),
|
||||
@@ -147,21 +125,55 @@ class JSONSchema(BaseModel):
|
||||
|
||||
@property
|
||||
def typescript_type(self) -> str:
|
||||
if not self.type:
|
||||
return "any"
|
||||
if self.type == JSONSchema.Type.BOOLEAN:
|
||||
return "boolean"
|
||||
elif self.type in {JSONSchema.Type.INTEGER, JSONSchema.Type.NUMBER}:
|
||||
if self.type in {JSONSchema.Type.INTEGER, JSONSchema.Type.NUMBER}:
|
||||
return "number"
|
||||
elif self.type == JSONSchema.Type.STRING:
|
||||
if self.type == JSONSchema.Type.STRING:
|
||||
return "string"
|
||||
elif self.type == JSONSchema.Type.ARRAY:
|
||||
if self.type == JSONSchema.Type.ARRAY:
|
||||
return f"Array<{self.items.typescript_type}>" if self.items else "Array"
|
||||
elif self.type == JSONSchema.Type.OBJECT:
|
||||
if self.type == JSONSchema.Type.OBJECT:
|
||||
if not self.properties:
|
||||
return "Record<string, any>"
|
||||
return self.to_typescript_object_interface()
|
||||
elif self.enum:
|
||||
if self.enum:
|
||||
return " | ".join(repr(v) for v in self.enum)
|
||||
|
||||
raise NotImplementedError(
|
||||
f"JSONSchema.typescript_type does not support Type.{self.type.name} yet"
|
||||
)
|
||||
|
||||
|
||||
@overload
|
||||
def _resolve_type_refs_in_schema(schema: dict, definitions: dict) -> dict:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def _resolve_type_refs_in_schema(schema: list, definitions: dict) -> list:
|
||||
...
|
||||
|
||||
|
||||
def _resolve_type_refs_in_schema(schema: dict | list, definitions: dict) -> dict | list:
|
||||
"""
|
||||
Recursively resolve type $refs in the JSON schema with their definitions.
|
||||
"""
|
||||
if isinstance(schema, dict):
|
||||
if "$ref" in schema:
|
||||
ref_path = schema["$ref"].split("/")[2:] # Split and remove '#/definitions'
|
||||
ref_value = definitions
|
||||
for key in ref_path:
|
||||
ref_value = ref_value[key]
|
||||
return _resolve_type_refs_in_schema(ref_value, definitions)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"JSONSchema.typescript_type does not support Type.{self.type.name} yet"
|
||||
)
|
||||
return {
|
||||
k: _resolve_type_refs_in_schema(v, definitions)
|
||||
for k, v in schema.items()
|
||||
}
|
||||
elif isinstance(schema, list):
|
||||
return [_resolve_type_refs_in_schema(item, definitions) for item in schema]
|
||||
else:
|
||||
return schema
|
||||
|
||||
@@ -1,31 +1,26 @@
|
||||
import abc
|
||||
import enum
|
||||
import math
|
||||
from typing import Callable, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, SecretBytes, SecretField, SecretStr
|
||||
|
||||
from forge.models.config import SystemConfiguration, SystemSettings, UserConfigurable
|
||||
from forge.models.config import SystemConfiguration, UserConfigurable
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class ResourceType(str, enum.Enum):
|
||||
"""An enumeration of resource types."""
|
||||
|
||||
MODEL = "model"
|
||||
MEMORY = "memory"
|
||||
|
||||
|
||||
class ProviderUsage(SystemConfiguration, abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def update_usage(self, *args, **kwargs) -> None:
|
||||
"""Update the usage of the resource."""
|
||||
...
|
||||
|
||||
|
||||
class ProviderBudget(SystemConfiguration):
|
||||
class ProviderBudget(SystemConfiguration, Generic[_T]):
|
||||
total_budget: float = UserConfigurable(math.inf)
|
||||
total_cost: float = 0
|
||||
remaining_budget: float = math.inf
|
||||
usage: ProviderUsage
|
||||
usage: _T
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_usage_and_cost(self, *args, **kwargs) -> float:
|
||||
@@ -43,8 +38,8 @@ class ProviderCredentials(SystemConfiguration):
|
||||
def unmasked(self) -> dict:
|
||||
return unmask(self)
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
class Config(SystemConfiguration.Config):
|
||||
json_encoders: dict[type[SecretField], Callable[[SecretField], str | None]] = {
|
||||
SecretStr: lambda v: v.get_secret_value() if v else None,
|
||||
SecretBytes: lambda v: v.get_secret_value() if v else None,
|
||||
SecretField: lambda v: v.get_secret_value() if v else None,
|
||||
@@ -62,11 +57,5 @@ def unmask(model: BaseModel):
|
||||
return unmasked_fields
|
||||
|
||||
|
||||
class ProviderSettings(SystemSettings):
|
||||
resource_type: ResourceType
|
||||
credentials: ProviderCredentials | None = None
|
||||
budget: ProviderBudget | None = None
|
||||
|
||||
|
||||
# Used both by model providers and memory providers
|
||||
Embedding = list[float]
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
"""This module contains the speech recognition and speech synthesis functions."""
|
||||
"""This module contains the (speech recognition and) speech synthesis functions."""
|
||||
from .say import TextToSpeechProvider, TTSConfig
|
||||
|
||||
__all__ = ["TextToSpeechProvider", "TTSConfig"]
|
||||
|
||||
@@ -45,7 +45,7 @@ class VoiceBase:
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def _speech(self, text: str, voice_index: int = 0) -> bool:
|
||||
def _speech(self, text: str, voice_id: int = 0) -> bool:
|
||||
"""
|
||||
Play the given text.
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ class ElevenLabsSpeech(VoiceBase):
|
||||
if voice and voice not in PLACEHOLDERS:
|
||||
self._voices[voice_index] = voice
|
||||
|
||||
def _speech(self, text: str, voice_index: int = 0) -> bool:
|
||||
def _speech(self, text: str, voice_id: int = 0) -> bool:
|
||||
"""Speak text using elevenlabs.io's API
|
||||
|
||||
Args:
|
||||
@@ -77,7 +77,7 @@ class ElevenLabsSpeech(VoiceBase):
|
||||
bool: True if the request was successful, False otherwise
|
||||
"""
|
||||
tts_url = (
|
||||
f"https://api.elevenlabs.io/v1/text-to-speech/{self._voices[voice_index]}"
|
||||
f"https://api.elevenlabs.io/v1/text-to-speech/{self._voices[voice_id]}"
|
||||
)
|
||||
response = requests.post(tts_url, headers=self._headers, json={"text": text})
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ class GTTSVoice(VoiceBase):
|
||||
def _setup(self) -> None:
|
||||
pass
|
||||
|
||||
def _speech(self, text: str, _: int = 0) -> bool:
|
||||
def _speech(self, text: str, voice_id: int = 0) -> bool:
|
||||
"""Play the given text."""
|
||||
tts = gtts.gTTS(text)
|
||||
tts.save("speech.mp3")
|
||||
|
||||
@@ -12,11 +12,11 @@ class MacOSTTS(VoiceBase):
|
||||
def _setup(self) -> None:
|
||||
pass
|
||||
|
||||
def _speech(self, text: str, voice_index: int = 0) -> bool:
|
||||
def _speech(self, text: str, voice_id: int = 0) -> bool:
|
||||
"""Play the given text."""
|
||||
if voice_index == 0:
|
||||
if voice_id == 0:
|
||||
subprocess.run(["say", text], shell=False)
|
||||
elif voice_index == 1:
|
||||
elif voice_id == 1:
|
||||
subprocess.run(["say", "-v", "Ava (Premium)", text], shell=False)
|
||||
else:
|
||||
subprocess.run(["say", "-v", "Samantha", text], shell=False)
|
||||
|
||||
@@ -24,8 +24,7 @@ class StreamElementsSpeech(VoiceBase):
|
||||
"""Setup the voices, API key, etc."""
|
||||
self.config = config
|
||||
|
||||
def _speech(self, text: str, voice: str, _: int = 0) -> bool:
|
||||
voice = self.config.voice
|
||||
def _speech(self, text: str, voice_id: int = 0) -> bool:
|
||||
"""Speak text using the streamelements API
|
||||
|
||||
Args:
|
||||
@@ -35,6 +34,7 @@ class StreamElementsSpeech(VoiceBase):
|
||||
Returns:
|
||||
bool: True if the request was successful, False otherwise
|
||||
"""
|
||||
voice = self.config.voice
|
||||
tts_url = (
|
||||
f"https://api.streamelements.com/kappa/v2/speech?voice={voice}&text={text}"
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Optional
|
||||
def get_exception_message():
|
||||
"""Get current exception type and message."""
|
||||
exc_type, exc_value, _ = sys.exc_info()
|
||||
exception_message = f"{exc_type.__name__}: {exc_value}"
|
||||
exception_message = f"{exc_type.__name__}: {exc_value}" if exc_type else exc_value
|
||||
return exception_message
|
||||
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ class TXTParser(ParserStrategy):
|
||||
charset_match = charset_normalizer.from_bytes(file.read()).best()
|
||||
logger.debug(
|
||||
f"Reading {getattr(file, 'name', 'file')} "
|
||||
f"with encoding '{charset_match.encoding}'"
|
||||
f"with encoding '{charset_match.encoding if charset_match else None}'"
|
||||
)
|
||||
return str(charset_match)
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ def validate_url(func: Callable[P, T]) -> Callable[P, T]:
|
||||
|
||||
return func(*bound_args.args, **bound_args.kwargs)
|
||||
|
||||
return wrapper
|
||||
return wrapper # type: ignore
|
||||
|
||||
|
||||
def is_valid_url(url: str) -> bool:
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
[mypy]
|
||||
namespace_packages = True
|
||||
follow_imports = skip
|
||||
check_untyped_defs = True
|
||||
disallow_untyped_defs = True
|
||||
exclude = ^(agbenchmark/challenges/|agent/|venv|venv-dev)
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-agbenchmark.utils.data_types.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-numpy.*]
|
||||
ignore_errors = True
|
||||
759
forge/poetry.lock
generated
759
forge/poetry.lock
generated
File diff suppressed because one or more lines are too long
@@ -26,8 +26,10 @@ duckduckgo-search = "^5.0.0"
|
||||
fastapi = "^0.109.1"
|
||||
gitpython = "^3.1.32"
|
||||
google-api-python-client = "*"
|
||||
google-cloud-logging = "^3.8.0"
|
||||
google-cloud-storage = "^2.13.0"
|
||||
groq = "^0.8.0"
|
||||
gTTS = "^2.3.1"
|
||||
jinja2 = "^3.1.2"
|
||||
jsonschema = "*"
|
||||
litellm = "^1.17.9"
|
||||
@@ -57,16 +59,17 @@ webdriver-manager = "^4.0.1"
|
||||
benchmark = ["agbenchmark"]
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
isort = "^5.12.0"
|
||||
black = "^23.3.0"
|
||||
black = "^23.12.1"
|
||||
flake8 = "^7.0.0"
|
||||
isort = "^5.13.1"
|
||||
pyright = "^1.1.364"
|
||||
pre-commit = "^3.3.3"
|
||||
mypy = "^1.4.1"
|
||||
flake8 = "^6.0.0"
|
||||
boto3-stubs = { extras = ["s3"], version = "^1.33.6" }
|
||||
types-requests = "^2.31.0.2"
|
||||
pytest = "^7.4.0"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
pytest-cov = "^5.0.0"
|
||||
mock = "^5.1.0"
|
||||
autoflake = "^2.2.0"
|
||||
pydevd-pycharm = "^233.6745.319"
|
||||
|
||||
|
||||
@@ -74,20 +77,21 @@ pydevd-pycharm = "^233.6745.319"
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
target-version = ['py310']
|
||||
include = '\.pyi?$'
|
||||
packages = ["forge"]
|
||||
extend-exclude = '(/dist|/.venv|/venv|/build|/agent|agbenchmark/challenges)/'
|
||||
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
multi_line_output = 3
|
||||
include_trailing_comma = true
|
||||
force_grid_wrap = 0
|
||||
use_parentheses = true
|
||||
ensure_newline_before_comments = true
|
||||
line_length = 88
|
||||
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
|
||||
skip_glob = [".tox", "__pycache__", "*.pyc", "venv*/*", "reports", "venv", "env", "node_modules", ".env", ".venv", "dist", "agent/*", "agbenchmark/challenges/*"]
|
||||
|
||||
|
||||
[tool.pyright]
|
||||
pythonVersion = "3.10"
|
||||
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = ["forge"]
|
||||
testpaths = ["forge", "tests"]
|
||||
|
||||
Reference in New Issue
Block a user