Set up unified pre-commit + CI w/ linting + type checking & FIX EVERYTHING (#7171)

- **FIX ALL LINT/TYPE ERRORS IN AUTOGPT, FORGE, AND BENCHMARK**

### Linting
- Clean up linter configs for `autogpt`, `forge`, and `benchmark`
- Add type checking with Pyright
- Create unified pre-commit config
- Create unified linting and type checking CI workflow

### Testing
- Synchronize CI test setups for `autogpt`, `forge`, and `benchmark`
   - Add missing pytest-cov to benchmark dependencies
- Mark GCS tests as slow to speed up pre-commit test runs
- Repair `forge` test suite
  - Add `AgentDB.close()` method for test DB teardown in db_test.py
  - Use actual temporary dir instead of forge/test_workspace/
- Move left-behind dependencies for moved `forge`-code from autogpt to forge

### Notable type changes
- Replace uses of `ChatModelProvider` by `MultiProvider`
- Removed unnecessary exports from various __init__.py
- Simplify `FileStorage.open_file` signature by removing `IOBase` from return type union
  - Implement `S3BinaryIOWrapper(BinaryIO)` type interposer for `S3FileStorage`

- Expand overloads of `GCSFileStorage.open_file` for improved typing of read and write modes

  Had to silence type checking for the extra overloads, because (I think) Pyright is reporting a false-positive:
  https://github.com/microsoft/pyright/issues/8007

- Change `count_tokens`, `get_tokenizer`, `count_message_tokens` methods on `ModelProvider`s from class methods to instance methods

- Move `CompletionModelFunction.schema` method -> helper function `format_function_def_for_openai` in `forge.llm.providers.openai`

- Rename `ModelProvider` -> `BaseModelProvider`
- Rename `ChatModelProvider` -> `BaseChatModelProvider`
- Add type `ChatModelProvider` which is a union of all subclasses of `BaseChatModelProvider`

### Removed rather than fixed
- Remove deprecated and broken autogpt/agbenchmark_config/benchmarks.py
- Various base classes and properties on base classes in `forge.llm.providers.schema` and `forge.models.providers`

### Fixes for other issues that came to light
- Clean up `forge.agent_protocol.api_router`, `forge.agent_protocol.database`, and `forge.agent.agent`

- Add fallback behavior to `ImageGeneratorComponent`
   - Remove test for deprecated failure behavior

- Fix `agbenchmark.challenges.builtin` challenge exclusion mechanism on Windows

- Fix `_tool_calls_compat_extract_calls` in `forge.llm.providers.openai`

- Add support for `any` (= no type specified) in `JSONSchema.typescript_type`
This commit is contained in:
Reinier van der Leer
2024-05-28 05:04:21 +02:00
committed by GitHub
parent 2c13a2706c
commit f107ff8cf0
147 changed files with 2897 additions and 2425 deletions

View File

@@ -1,15 +1,11 @@
[flake8]
max-line-length = 88
select = "E303, W293, W292, E305, E231, E302"
# Ignore rules that conflict with Black code style
extend-ignore = E203, W503
exclude =
.tox,
__pycache__,
.git,
__pycache__/,
*.pyc,
.env
venv*/*,
.venv/*,
reports/*,
dist/*,
agent/*,
code,
agbenchmark/challenges/*
.pytest_cache/,
venv*/,
.venv/,

5
forge/.gitignore vendored
View File

@@ -160,7 +160,8 @@ CURRENT_BULLETIN.md
agbenchmark_config/workspace
agbenchmark_config/reports
*.sqlite
*.sqlite*
*.db
.agbench
.agbenchmark
.benchmarks
@@ -168,7 +169,7 @@ agbenchmark_config/reports
.pytest_cache
.vscode
ig_*
agent.db
agbenchmark_config/updates.json
agbenchmark_config/challenges_already_beaten.json
agbenchmark_config/temp_folder/*
test_workspace/

View File

@@ -1,43 +0,0 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-added-large-files
args: ['--maxkb=500']
- id: check-byte-order-marker
- id: check-case-conflict
- id: check-merge-conflict
- id: check-symlinks
- id: debug-statements
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
language_version: python3.11
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
language_version: python3.11
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: 'v1.3.0'
# hooks:
# - id: mypy
- repo: local
hooks:
- id: autoflake
name: autoflake
entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring forge/autogpt
language: python
types: [ python ]
# Mono repo has broken this. TODO: fix
# - id: pytest-check
# name: pytest-check
# entry: pytest
# language: system
# pass_filenames: false
# always_run: true

View File

@@ -9,27 +9,24 @@ from forge.logging.config import configure_logging
logger = logging.getLogger(__name__)
logo = """\n\n
d8888 888 .d8888b. 8888888b. 88888888888
d88888 888 d88P Y88b 888 Y88b 888
d88P888 888 888 888 888 888 888
d88P 888 888 888 888888 .d88b. 888 888 d88P 888
d88P 888 888 888 888 d88""88b 888 88888 8888888P" 888
d88P 888 888 888 888 888 888 888 888 888 888
d8888888888 Y88b 888 Y88b. Y88..88P Y88b d88P 888 888
d88P 888 "Y88888 "Y888 "Y88P" "Y8888P88 888 888
8888888888
888
888
8888888 .d88b. 888d888 .d88b. .d88b.
888 d88""88b 888P" d88P"88b d8P Y8b
888 888 888 888 888 888 88888888
888 Y88..88P 888 Y88b 888 Y8b.
888 "Y88P" 888 "Y88888 "Y8888
888
Y8b d88P
d8888 888 .d8888b. 8888888b. 88888888888
d88P888 888 888 888 888 888 888
d88P 888 888 888 888888 .d88b. 888 888 d88P 888
d88P 888 888 888 888 d88""88b 888 88888 8888888P" 888
d88P 888 888 888 888 888 888 888 888 888 888
d8888888888 Y88b 888 Y88b. Y88..88P Y88b d88P 888 888
d88P 888 "Y88888 "Y888 "Y88P" "Y8888P88 888 888
8888888888
888
888 .d88b. 888d888 .d88b. .d88b.
888888 d88""88b 888P" d88P"88b d8P Y8b
888 888 888 888 888 888 88888888
888 Y88..88P 888 Y88b 888 Y8b.
888 "Y88P" 888 "Y88888 "Y8888
888
Y8b d88P
"Y88P" v0.1.0
\n"""

View File

@@ -1,15 +1,7 @@
from .base import AgentMeta, BaseAgent, BaseAgentConfiguration, BaseAgentSettings
from .components import (
AgentComponent,
ComponentEndpointError,
ComponentSystemError,
EndpointPipelineError,
)
from .protocols import (
AfterExecute,
AfterParse,
CommandProvider,
DirectiveProvider,
ExecutionFailure,
MessageProvider,
)
from .base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
__all__ = [
"BaseAgent",
"BaseAgentConfiguration",
"BaseAgentSettings",
]

View File

@@ -24,7 +24,6 @@ from forge.agent_protocol.models.task import (
TaskStepsListResponse,
)
from forge.file_storage.base import FileStorage
from forge.utils.exceptions import NotFoundError
logger = logging.getLogger(__name__)
@@ -79,7 +78,8 @@ class Agent:
else:
logger.warning(
f"Frontend not found. {frontend_path} does not exist. The frontend will not be served"
f"Frontend not found. {frontend_path} does not exist. "
"The frontend will not be served."
)
app.add_middleware(AgentMiddleware, agent=self)
@@ -94,34 +94,25 @@ class Agent:
"""
Create a task for the agent.
"""
try:
task = await self.db.create_task(
input=task_request.input,
additional_input=task_request.additional_input,
)
return task
except Exception as e:
raise
task = await self.db.create_task(
input=task_request.input,
additional_input=task_request.additional_input,
)
return task
async def list_tasks(self, page: int = 1, pageSize: int = 10) -> TaskListResponse:
"""
List all tasks that the agent has created.
"""
try:
tasks, pagination = await self.db.list_tasks(page, pageSize)
response = TaskListResponse(tasks=tasks, pagination=pagination)
return response
except Exception as e:
raise
tasks, pagination = await self.db.list_tasks(page, pageSize)
response = TaskListResponse(tasks=tasks, pagination=pagination)
return response
async def get_task(self, task_id: str) -> Task:
"""
Get a task by ID.
"""
try:
task = await self.db.get_task(task_id)
except Exception as e:
raise
task = await self.db.get_task(task_id)
return task
async def list_steps(
@@ -130,12 +121,9 @@ class Agent:
"""
List the IDs of all steps that the task has created.
"""
try:
steps, pagination = await self.db.list_steps(task_id, page, pageSize)
response = TaskStepsListResponse(steps=steps, pagination=pagination)
return response
except Exception as e:
raise
steps, pagination = await self.db.list_steps(task_id, page, pageSize)
response = TaskStepsListResponse(steps=steps, pagination=pagination)
return response
async def execute_step(self, task_id: str, step_request: StepRequestBody) -> Step:
"""
@@ -147,11 +135,8 @@ class Agent:
"""
Get a step by ID.
"""
try:
step = await self.db.get_step(task_id, step_id)
return step
except Exception as e:
raise
step = await self.db.get_step(task_id, step_id)
return step
async def list_artifacts(
self, task_id: str, page: int = 1, pageSize: int = 10
@@ -159,62 +144,45 @@ class Agent:
"""
List the artifacts that the task has created.
"""
try:
artifacts, pagination = await self.db.list_artifacts(
task_id, page, pageSize
)
return TaskArtifactsListResponse(artifacts=artifacts, pagination=pagination)
except Exception as e:
raise
artifacts, pagination = await self.db.list_artifacts(task_id, page, pageSize)
return TaskArtifactsListResponse(artifacts=artifacts, pagination=pagination)
async def create_artifact(
self, task_id: str, file: UploadFile, relative_path: str
self, task_id: str, file: UploadFile, relative_path: str = ""
) -> Artifact:
"""
Create an artifact for the task.
"""
data = None
file_name = file.filename or str(uuid4())
try:
data = b""
while contents := file.file.read(1024 * 1024):
data += contents
# Check if relative path ends with filename
if relative_path.endswith(file_name):
file_path = relative_path
else:
file_path = os.path.join(relative_path, file_name)
data = b""
while contents := file.file.read(1024 * 1024):
data += contents
# Check if relative path ends with filename
if relative_path.endswith(file_name):
file_path = relative_path
else:
file_path = os.path.join(relative_path, file_name)
await self.workspace.write_file(file_path, data)
await self.workspace.write_file(file_path, data)
artifact = await self.db.create_artifact(
task_id=task_id,
file_name=file_name,
relative_path=relative_path,
agent_created=False,
)
except Exception as e:
raise
artifact = await self.db.create_artifact(
task_id=task_id,
file_name=file_name,
relative_path=relative_path,
agent_created=False,
)
return artifact
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
async def get_artifact(self, task_id: str, artifact_id: str) -> StreamingResponse:
"""
Get an artifact by ID.
"""
try:
artifact = await self.db.get_artifact(artifact_id)
if artifact.file_name not in artifact.relative_path:
file_path = os.path.join(artifact.relative_path, artifact.file_name)
else:
file_path = artifact.relative_path
retrieved_artifact = self.workspace.read_file(file_path)
except NotFoundError as e:
raise
except FileNotFoundError as e:
raise
except Exception as e:
raise
artifact = await self.db.get_artifact(artifact_id)
if artifact.file_name not in artifact.relative_path:
file_path = os.path.join(artifact.relative_path, artifact.file_name)
else:
file_path = artifact.relative_path
retrieved_artifact = self.workspace.read_file(file_path, binary=True)
return StreamingResponse(
BytesIO(retrieved_artifact),

View File

@@ -1,6 +1,7 @@
from pathlib import Path
import pytest
from fastapi import UploadFile
from forge.agent_protocol.database.db import AgentDB
from forge.agent_protocol.models.task import (
@@ -16,16 +17,23 @@ from .agent import Agent
@pytest.fixture
def agent():
def agent(test_workspace: Path):
db = AgentDB("sqlite:///test.db")
config = FileStorageConfiguration(root=Path("./test_workspace"))
config = FileStorageConfiguration(root=test_workspace)
workspace = LocalFileStorage(config)
return Agent(db, workspace)
@pytest.mark.skip
@pytest.fixture
def file_upload():
this_file = Path(__file__)
file_handle = this_file.open("rb")
yield UploadFile(file_handle, filename=this_file.name)
file_handle.close()
@pytest.mark.asyncio
async def test_create_task(agent):
async def test_create_task(agent: Agent):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
@@ -33,20 +41,18 @@ async def test_create_task(agent):
assert task.input == "test_input"
@pytest.mark.skip
@pytest.mark.asyncio
async def test_list_tasks(agent):
async def test_list_tasks(agent: Agent):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
task = await agent.create_task(task_request)
await agent.create_task(task_request)
tasks = await agent.list_tasks()
assert isinstance(tasks, TaskListResponse)
@pytest.mark.skip
@pytest.mark.asyncio
async def test_get_task(agent):
async def test_get_task(agent: Agent):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
@@ -55,9 +61,9 @@ async def test_get_task(agent):
assert retrieved_task.task_id == task.task_id
@pytest.mark.skip
@pytest.mark.xfail(reason="execute_step is not implemented")
@pytest.mark.asyncio
async def test_create_and_execute_step(agent):
async def test_execute_step(agent: Agent):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
@@ -65,14 +71,14 @@ async def test_create_and_execute_step(agent):
step_request = StepRequestBody(
input="step_input", additional_input={"input": "additional_test_input"}
)
step = await agent.create_and_execute_step(task.task_id, step_request)
step = await agent.execute_step(task.task_id, step_request)
assert step.input == "step_input"
assert step.additional_input == {"input": "additional_test_input"}
@pytest.mark.skip
@pytest.mark.xfail(reason="execute_step is not implemented")
@pytest.mark.asyncio
async def test_get_step(agent):
async def test_get_step(agent: Agent):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
@@ -80,38 +86,52 @@ async def test_get_step(agent):
step_request = StepRequestBody(
input="step_input", additional_input={"input": "additional_test_input"}
)
step = await agent.create_and_execute_step(task.task_id, step_request)
step = await agent.execute_step(task.task_id, step_request)
retrieved_step = await agent.get_step(task.task_id, step.step_id)
assert retrieved_step.step_id == step.step_id
@pytest.mark.skip
@pytest.mark.asyncio
async def test_list_artifacts(agent):
artifacts = await agent.list_artifacts()
assert isinstance(artifacts, list)
async def test_list_artifacts(agent: Agent):
tasks = await agent.list_tasks()
assert tasks.tasks, "No tasks in test.db"
artifacts = await agent.list_artifacts(tasks.tasks[0].task_id)
assert isinstance(artifacts.artifacts, list)
@pytest.mark.skip
@pytest.mark.asyncio
async def test_create_artifact(agent):
async def test_create_artifact(agent: Agent, file_upload: UploadFile):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
task = await agent.create_task(task_request)
artifact_request = ArtifactRequestBody(file=None, uri="test_uri")
artifact = await agent.create_artifact(task.task_id, artifact_request)
assert artifact.uri == "test_uri"
artifact = await agent.create_artifact(
task_id=task.task_id,
file=file_upload,
relative_path=f"a_dir/{file_upload.filename}",
)
assert artifact.file_name == file_upload.filename
assert artifact.relative_path == f"a_dir/{file_upload.filename}"
@pytest.mark.skip
@pytest.mark.asyncio
async def test_get_artifact(agent):
async def test_create_and_get_artifact(agent: Agent, file_upload: UploadFile):
task_request = TaskRequestBody(
input="test_input", additional_input={"input": "additional_test_input"}
)
task = await agent.create_task(task_request)
artifact_request = ArtifactRequestBody(file=None, uri="test_uri")
artifact = await agent.create_artifact(task.task_id, artifact_request)
artifact = await agent.create_artifact(
task_id=task.task_id,
file=file_upload,
relative_path=f"b_dir/{file_upload.filename}",
)
await file_upload.seek(0)
file_upload_content = await file_upload.read()
retrieved_artifact = await agent.get_artifact(task.task_id, artifact.artifact_id)
assert retrieved_artifact.artifact_id == artifact.artifact_id
retrieved_artifact_content = bytearray()
async for b in retrieved_artifact.body_iterator:
retrieved_artifact_content.extend(b) # type: ignore
assert retrieved_artifact_content == file_upload_content

View File

@@ -5,22 +5,21 @@ import inspect
import logging
from abc import ABCMeta, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Generic,
Iterator,
Optional,
ParamSpec,
TypeVar,
cast,
overload,
)
from colorama import Fore
from pydantic import BaseModel, Field, validator
if TYPE_CHECKING:
from forge.models.action import ActionProposal, ActionResult
from forge.agent import protocols
from forge.agent.components import (
AgentComponent,
@@ -29,15 +28,10 @@ from forge.agent.components import (
)
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.config.config import ConfigBuilder
from forge.llm.providers import CHAT_MODELS, ModelName, OpenAIModelName
from forge.llm.providers.schema import ChatModelInfo
from forge.models.config import (
Configurable,
SystemConfiguration,
SystemSettings,
UserConfigurable,
)
from forge.models.action import ActionResult, AnyProposal
from forge.models.config import SystemConfiguration, SystemSettings, UserConfigurable
logger = logging.getLogger(__name__)
@@ -133,17 +127,7 @@ class AgentMeta(ABCMeta):
return instance
class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
C = TypeVar("C", bound=AgentComponent)
default_settings = BaseAgentSettings(
name="BaseAgent",
description=__doc__ if __doc__ else "",
)
class BaseAgent(Generic[AnyProposal], metaclass=AgentMeta):
def __init__(
self,
settings: BaseAgentSettings,
@@ -173,13 +157,13 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
return self.config.send_token_limit or self.llm.max_tokens * 3 // 4
@abstractmethod
async def propose_action(self) -> ActionProposal:
async def propose_action(self) -> AnyProposal:
...
@abstractmethod
async def execute(
self,
proposal: ActionProposal,
proposal: AnyProposal,
user_feedback: str = "",
) -> ActionResult:
...
@@ -187,7 +171,7 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
@abstractmethod
async def do_not_execute(
self,
denied_proposal: ActionProposal,
denied_proposal: AnyProposal,
user_feedback: str,
) -> ActionResult:
...
@@ -203,13 +187,16 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
@overload
async def run_pipeline(
self, protocol_method: Callable[P, None], *args, retry_limit: int = 3
self,
protocol_method: Callable[P, None | Awaitable[None]],
*args,
retry_limit: int = 3,
) -> list[None]:
...
async def run_pipeline(
self,
protocol_method: Callable[P, Iterator[T] | None],
protocol_method: Callable[P, Iterator[T] | None | Awaitable[None]],
*args,
retry_limit: int = 3,
) -> list[T] | list[None]:
@@ -240,7 +227,10 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
)
continue
method = getattr(component, method_name, None)
method = cast(
Callable[..., Iterator[T] | None | Awaitable[None]] | None,
getattr(component, method_name, None),
)
if not callable(method):
continue
@@ -248,10 +238,9 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
while component_attempts < retry_limit:
try:
component_args = self._selective_copy(args)
if inspect.iscoroutinefunction(method):
result = await method(*component_args)
else:
result = method(*component_args)
result = method(*component_args)
if inspect.isawaitable(result):
result = await result
if result is not None:
method_result.extend(result)
args = component_args
@@ -269,9 +258,9 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
break
# Successful pipeline execution
break
except EndpointPipelineError:
except EndpointPipelineError as e:
self._trace.append(
f"{Fore.LIGHTRED_EX}{component.__class__.__name__}: "
f"{Fore.LIGHTRED_EX}{e.triggerer.__class__.__name__}: "
f"EndpointPipelineError{Fore.RESET}"
)
# Restart from the beginning on EndpointPipelineError

View File

@@ -36,8 +36,9 @@ class AgentComponent(ABC):
class ComponentEndpointError(Exception):
"""Error of a single protocol method on a component."""
def __init__(self, message: str = ""):
def __init__(self, message: str, component: AgentComponent):
self.message = message
self.triggerer = component
super().__init__(message)

View File

@@ -1,14 +1,13 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Iterator
from typing import TYPE_CHECKING, Awaitable, Generic, Iterator
from forge.models.action import ActionResult, AnyProposal
from .components import AgentComponent
if TYPE_CHECKING:
from forge.command.command import Command
from forge.llm.providers import ChatMessage
from forge.models.action import ActionResult
from .base import ActionProposal
class DirectiveProvider(AgentComponent):
@@ -34,19 +33,19 @@ class MessageProvider(AgentComponent):
...
class AfterParse(AgentComponent):
class AfterParse(AgentComponent, Generic[AnyProposal]):
@abstractmethod
def after_parse(self, result: "ActionProposal") -> None:
def after_parse(self, result: AnyProposal) -> None | Awaitable[None]:
...
class ExecutionFailure(AgentComponent):
@abstractmethod
def execution_failure(self, error: Exception) -> None:
def execution_failure(self, error: Exception) -> None | Awaitable[None]:
...
class AfterExecute(AgentComponent):
@abstractmethod
def after_execute(self, result: "ActionResult") -> None:
def after_execute(self, result: "ActionResult") -> None | Awaitable[None]:
...

View File

@@ -1,39 +1,16 @@
"""
Routes for the Agent Service.
This module defines the API routes for the Agent service. While there are multiple endpoints provided by the service,
the ones that require special attention due to their complexity are:
This module defines the API routes for the Agent service.
1. `execute_agent_task_step`:
This route is significant because this is where the agent actually performs the work. The function handles
executing the next step for a task based on its current state, and it requires careful implementation to ensure
all scenarios (like the presence or absence of steps or a step marked as `last_step`) are handled correctly.
2. `upload_agent_task_artifacts`:
This route allows for the upload of artifacts, supporting various URI types (e.g., s3, gcs, ftp, http).
The support for different URI types makes it a bit more complex, and it's important to ensure that all
supported URI types are correctly managed. NOTE: The AutoGPT team will eventually handle the most common
uri types for you.
3. `create_agent_task`:
While this is a simpler route, it plays a crucial role in the workflow, as it's responsible for the creation
of a new task.
Developers and contributors should be especially careful when making modifications to these routes to ensure
consistency and correctness in the system's behavior.
Developers and contributors should be especially careful when making modifications
to these routes to ensure consistency and correctness in the system's behavior.
"""
import json
import logging
from typing import Optional
from typing import TYPE_CHECKING, Optional
from fastapi import APIRouter, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
from forge.utils.exceptions import (
NotFoundError,
get_detailed_traceback,
get_exception_message,
)
from fastapi import APIRouter, HTTPException, Query, Request, Response, UploadFile
from fastapi.responses import StreamingResponse
from .models import (
Artifact,
@@ -46,6 +23,9 @@ from .models import (
TaskStepsListResponse,
)
if TYPE_CHECKING:
from forge.agent.agent import Agent
base_router = APIRouter()
logger = logging.getLogger(__name__)
@@ -73,10 +53,10 @@ async def create_agent_task(request: Request, task_request: TaskRequestBody) ->
Args:
request (Request): FastAPI request object.
task (TaskRequestBody): The task request containing input and additional input data.
task (TaskRequestBody): The task request containing input data.
Returns:
Task: A new task with task_id, input, additional_input, and empty lists for artifacts and steps.
Task: A new task with task_id, input, and additional_input set.
Example:
Request (TaskRequestBody defined in schema.py):
@@ -93,46 +73,32 @@ async def create_agent_task(request: Request, task_request: TaskRequestBody) ->
"artifacts": [],
}
"""
agent = request["agent"]
agent: "Agent" = request["agent"]
try:
task_request = await agent.create_task(task_request)
return Response(
content=task_request.json(),
status_code=200,
media_type="application/json",
)
task = await agent.create_task(task_request)
return task
except Exception:
logger.exception(f"Error whilst trying to create a task: {task_request}")
return Response(
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
raise
@base_router.get("/agent/tasks", tags=["agent"], response_model=TaskListResponse)
async def list_agent_tasks(
request: Request,
page: Optional[int] = Query(1, ge=1),
page_size: Optional[int] = Query(10, ge=1),
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1),
) -> TaskListResponse:
"""
Retrieves a paginated list of all tasks.
Args:
request (Request): FastAPI request object.
page (int, optional): The page number for pagination. Defaults to 1.
page_size (int, optional): The number of tasks per page for pagination. Defaults to 10.
page (int, optional): Page number for pagination. Default: 1
page_size (int, optional): Number of tasks per page for pagination. Default: 10
Returns:
TaskListResponse: A response object containing a list of tasks and pagination details.
TaskListResponse: A list of tasks, and pagination details.
Example:
Request:
@@ -158,34 +124,13 @@ async def list_agent_tasks(
}
}
"""
agent = request["agent"]
agent: "Agent" = request["agent"]
try:
tasks = await agent.list_tasks(page, page_size)
return Response(
content=tasks.json(),
status_code=200,
media_type="application/json",
)
except NotFoundError:
logger.exception("Error whilst trying to list tasks")
return Response(
content=json.dumps({"error": "Tasks not found"}),
status_code=404,
media_type="application/json",
)
return tasks
except Exception:
logger.exception("Error whilst trying to list tasks")
return Response(
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
raise
@base_router.get("/agent/tasks/{task_id}", tags=["agent"], response_model=Task)
@@ -239,36 +184,14 @@ async def get_agent_task(request: Request, task_id: str) -> Task:
}
]
}
"""
agent = request["agent"]
""" # noqa: E501
agent: "Agent" = request["agent"]
try:
task = await agent.get_task(task_id)
return Response(
content=task.json(),
status_code=200,
media_type="application/json",
)
except NotFoundError:
logger.exception(f"Error whilst trying to get task: {task_id}")
return Response(
content=json.dumps({"error": "Task not found"}),
status_code=404,
media_type="application/json",
)
return task
except Exception:
logger.exception(f"Error whilst trying to get task: {task_id}")
return Response(
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
raise
@base_router.get(
@@ -279,8 +202,8 @@ async def get_agent_task(request: Request, task_id: str) -> Task:
async def list_agent_task_steps(
request: Request,
task_id: str,
page: Optional[int] = Query(1, ge=1),
page_size: Optional[int] = Query(10, ge=1, alias="pageSize"),
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, alias="pageSize"),
) -> TaskStepsListResponse:
"""
Retrieves a paginated list of steps associated with a specific task.
@@ -289,10 +212,10 @@ async def list_agent_task_steps(
request (Request): FastAPI request object.
task_id (str): The ID of the task.
page (int, optional): The page number for pagination. Defaults to 1.
page_size (int, optional): The number of steps per page for pagination. Defaults to 10.
page_size (int, optional): Number of steps per page for pagination. Default: 10.
Returns:
TaskStepsListResponse: A response object containing a list of steps and pagination details.
TaskStepsListResponse: A list of steps, and pagination details.
Example:
Request:
@@ -315,54 +238,40 @@ async def list_agent_task_steps(
"pageSize": 10
}
}
"""
agent = request["agent"]
""" # noqa: E501
agent: "Agent" = request["agent"]
try:
steps = await agent.list_steps(task_id, page, page_size)
return Response(
content=steps.json(),
status_code=200,
media_type="application/json",
)
except NotFoundError:
logger.exception("Error whilst trying to list steps")
return Response(
content=json.dumps({"error": "Steps not found"}),
status_code=404,
media_type="application/json",
)
return steps
except Exception:
logger.exception("Error whilst trying to list steps")
return Response(
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
raise
@base_router.post("/agent/tasks/{task_id}/steps", tags=["agent"], response_model=Step)
async def execute_agent_task_step(
request: Request, task_id: str, step: Optional[StepRequestBody] = None
request: Request, task_id: str, step_request: Optional[StepRequestBody] = None
) -> Step:
"""
Executes the next step for a specified task based on the current task status and returns the
executed step with additional feedback fields.
Executes the next step for a specified task based on the current task status and
returns the executed step with additional feedback fields.
Depending on the current state of the task, the following scenarios are supported:
This route is significant because this is where the agent actually performs work.
The function handles executing the next step for a task based on its current state,
and it requires careful implementation to ensure all scenarios (like the presence
or absence of steps or a step marked as `last_step`) are handled correctly.
Depending on the current state of the task, the following scenarios are possible:
1. No steps exist for the task.
2. There is at least one step already for the task, and the task does not have a completed step marked as `last_step`.
2. There is at least one step already for the task, and the task does not have a
completed step marked as `last_step`.
3. There is a completed step marked as `last_step` already on the task.
In each of these scenarios, a step object will be returned with two additional fields: `output` and `additional_output`.
In each of these scenarios, a step object will be returned with two additional
fields: `output` and `additional_output`.
- `output`: Provides the primary response or feedback to the user.
- `additional_output`: Supplementary information or data. Its specific content is not strictly defined and can vary based on the step or agent's implementation.
- `additional_output`: Supplementary information or data. Its specific content is
not strictly defined and can vary based on the step or agent's implementation.
Args:
request (Request): FastAPI request object.
@@ -389,39 +298,17 @@ async def execute_agent_task_step(
...
}
"""
agent = request["agent"]
agent: "Agent" = request["agent"]
try:
# An empty step request represents a yes to continue command
if not step:
step = StepRequestBody(input="y")
if not step_request:
step_request = StepRequestBody(input="y")
step = await agent.execute_step(task_id, step)
return Response(
content=step.json(),
status_code=200,
media_type="application/json",
)
except NotFoundError:
logger.exception(f"Error whilst trying to execute a task step: {task_id}")
return Response(
content=json.dumps({"error": f"Task not found {task_id}"}),
status_code=404,
media_type="application/json",
)
step = await agent.execute_step(task_id, step_request)
return step
except Exception:
logger.exception(f"Error whilst trying to execute a task step: {task_id}")
return Response(
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
raise
@base_router.get(
@@ -450,31 +337,13 @@ async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> S
...
}
"""
agent = request["agent"]
agent: "Agent" = request["agent"]
try:
step = await agent.get_step(task_id, step_id)
return Response(content=step.json(), status_code=200)
except NotFoundError:
logger.exception(f"Error whilst trying to get step: {step_id}")
return Response(
content=json.dumps({"error": "Step not found"}),
status_code=404,
media_type="application/json",
)
return step
except Exception:
logger.exception(f"Error whilst trying to get step: {step_id}")
return Response(
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
raise
@base_router.get(
@@ -485,8 +354,8 @@ async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> S
async def list_agent_task_artifacts(
request: Request,
task_id: str,
page: Optional[int] = Query(1, ge=1),
page_size: Optional[int] = Query(10, ge=1, alias="pageSize"),
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, alias="pageSize"),
) -> TaskArtifactsListResponse:
"""
Retrieves a paginated list of artifacts associated with a specific task.
@@ -495,10 +364,10 @@ async def list_agent_task_artifacts(
request (Request): FastAPI request object.
task_id (str): The ID of the task.
page (int, optional): The page number for pagination. Defaults to 1.
page_size (int, optional): The number of items per page for pagination. Defaults to 10.
page_size (int, optional): Number of items per page for pagination. Default: 10.
Returns:
TaskArtifactsListResponse: A response object containing a list of artifacts and pagination details.
TaskArtifactsListResponse: A list of artifacts, and pagination details.
Example:
Request:
@@ -518,52 +387,33 @@ async def list_agent_task_artifacts(
"pageSize": 10
}
}
"""
agent = request["agent"]
""" # noqa: E501
agent: "Agent" = request["agent"]
try:
artifacts: TaskArtifactsListResponse = await agent.list_artifacts(
task_id, page, page_size
)
artifacts = await agent.list_artifacts(task_id, page, page_size)
return artifacts
except NotFoundError:
logger.exception("Error whilst trying to list artifacts")
return Response(
content=json.dumps({"error": "Artifacts not found for task_id"}),
status_code=404,
media_type="application/json",
)
except Exception:
logger.exception("Error whilst trying to list artifacts")
return Response(
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
raise
@base_router.post(
"/agent/tasks/{task_id}/artifacts", tags=["agent"], response_model=Artifact
)
async def upload_agent_task_artifacts(
request: Request, task_id: str, file: UploadFile, relative_path: Optional[str] = ""
request: Request, task_id: str, file: UploadFile, relative_path: str = ""
) -> Artifact:
"""
This endpoint is used to upload an artifact associated with a specific task. The artifact is provided as a file.
This endpoint is used to upload an artifact (file) associated with a specific task.
Args:
request (Request): The FastAPI request object.
task_id (str): The unique identifier of the task for which the artifact is being uploaded.
task_id (str): The ID of the task for which the artifact is being uploaded.
file (UploadFile): The file being uploaded as an artifact.
relative_path (str): The relative path for the file. This is a query parameter.
Returns:
Artifact: An object containing metadata of the uploaded artifact, including its unique identifier.
Artifact: Metadata object for the uploaded artifact, including its ID and path.
Example:
Request:
@@ -579,35 +429,17 @@ async def upload_agent_task_artifacts(
"relative_path": "/my_folder/my_other_folder/",
"file_name": "main.py"
}
"""
agent = request["agent"]
""" # noqa: E501
agent: "Agent" = request["agent"]
if file is None:
return Response(
content=json.dumps({"error": "File must be specified"}),
status_code=404,
media_type="application/json",
)
raise HTTPException(status_code=400, detail="File must be specified")
try:
artifact = await agent.create_artifact(task_id, file, relative_path)
return Response(
content=artifact.json(),
status_code=200,
media_type="application/json",
)
return artifact
except Exception:
logger.exception(f"Error whilst trying to upload artifact: {task_id}")
return Response(
content=json.dumps(
{
"error": "Internal server error",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
raise
@base_router.get(
@@ -617,7 +449,7 @@ async def upload_agent_task_artifacts(
)
async def download_agent_task_artifact(
request: Request, task_id: str, artifact_id: str
) -> FileResponse:
) -> StreamingResponse:
"""
Downloads an artifact associated with a specific task.
@@ -636,32 +468,9 @@ async def download_agent_task_artifact(
Response:
<file_content_of_artifact>
"""
agent = request["agent"]
agent: "Agent" = request["agent"]
try:
return await agent.get_artifact(task_id, artifact_id)
except NotFoundError:
logger.exception(f"Error whilst trying to download artifact: {task_id}")
return Response(
content=json.dumps(
{
"error": f"Artifact not found "
"- task_id: {task_id}, artifact_id: {artifact_id}"
}
),
status_code=404,
media_type="application/json",
)
except Exception:
logger.exception(f"Error whilst trying to download artifact: {task_id}")
return Response(
content=json.dumps(
{
"error": f"Internal server error "
"- task_id: {task_id}, artifact_id: {artifact_id}",
"exception": get_exception_message(),
"traceback": get_detailed_traceback(),
}
),
status_code=500,
media_type="application/json",
)
raise

View File

@@ -1 +1,3 @@
from .db import AgentDB
__all__ = ["AgentDB"]

View File

@@ -4,23 +4,22 @@ It uses SQLite as the database and file store backend.
IT IS NOT ADVISED TO USE THIS IN PRODUCTION!
"""
import datetime
import logging
import math
import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Tuple
from sqlalchemy import (
JSON,
Boolean,
Column,
DateTime,
ForeignKey,
String,
create_engine,
)
from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, create_engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import DeclarativeBase, joinedload, relationship, sessionmaker
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
joinedload,
mapped_column,
relationship,
sessionmaker,
)
from forge.utils.exceptions import NotFoundError
@@ -32,18 +31,20 @@ logger = logging.getLogger(__name__)
class Base(DeclarativeBase):
pass
type_annotation_map = {
dict[str, Any]: JSON,
}
class TaskModel(Base):
__tablename__ = "tasks"
task_id = Column(String, primary_key=True, index=True)
input = Column(String)
additional_input = Column(JSON)
created_at = Column(DateTime, default=datetime.datetime.utcnow)
modified_at = Column(
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
task_id: Mapped[str] = mapped_column(primary_key=True, index=True)
input: Mapped[str]
additional_input: Mapped[dict[str, Any]] = mapped_column(default=dict)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
modified_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)
artifacts = relationship("ArtifactModel", back_populates="task")
@@ -52,35 +53,35 @@ class TaskModel(Base):
class StepModel(Base):
__tablename__ = "steps"
step_id = Column(String, primary_key=True, index=True)
task_id = Column(String, ForeignKey("tasks.task_id"))
name = Column(String)
input = Column(String)
status = Column(String)
output = Column(String)
is_last = Column(Boolean, default=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow)
modified_at = Column(
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
step_id: Mapped[str] = mapped_column(primary_key=True, index=True)
task_id: Mapped[str] = mapped_column(ForeignKey("tasks.task_id"))
name: Mapped[str]
input: Mapped[str]
status: Mapped[str]
output: Mapped[Optional[str]]
is_last: Mapped[bool] = mapped_column(Boolean, default=False)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
modified_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)
additional_input = Column(JSON)
additional_output = Column(JSON)
additional_input: Mapped[dict[str, Any]] = mapped_column(default=dict)
additional_output: Mapped[Optional[dict[str, Any]]]
artifacts = relationship("ArtifactModel", back_populates="step")
class ArtifactModel(Base):
__tablename__ = "artifacts"
artifact_id = Column(String, primary_key=True, index=True)
task_id = Column(String, ForeignKey("tasks.task_id"))
step_id = Column(String, ForeignKey("steps.step_id"))
agent_created = Column(Boolean, default=False)
file_name = Column(String)
relative_path = Column(String)
created_at = Column(DateTime, default=datetime.datetime.utcnow)
modified_at = Column(
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
artifact_id: Mapped[str] = mapped_column(primary_key=True, index=True)
task_id: Mapped[str] = mapped_column(ForeignKey("tasks.task_id"))
step_id: Mapped[Optional[str]] = mapped_column(ForeignKey("steps.step_id"))
agent_created: Mapped[bool] = mapped_column(default=False)
file_name: Mapped[str]
relative_path: Mapped[str]
created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
modified_at: Mapped[datetime] = mapped_column(
default=datetime.utcnow, onupdate=datetime.utcnow
)
step = relationship("StepModel", back_populates="artifacts")
@@ -150,6 +151,10 @@ class AgentDB:
Base.metadata.create_all(self.engine)
self.Session = sessionmaker(bind=self.engine)
def close(self) -> None:
self.Session.close_all()
self.engine.dispose()
async def create_task(
self, input: Optional[str], additional_input: Optional[dict] = {}
) -> Task:
@@ -172,8 +177,6 @@ class AgentDB:
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while creating task: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while creating task: {e}")
raise
@@ -207,8 +210,6 @@ class AgentDB:
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while creating step: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while creating step: {e}")
raise
@@ -237,7 +238,7 @@ class AgentDB:
session.close()
if self.debug_enabled:
logger.debug(
f"Artifact already exists with relative_path: {relative_path}"
f"Artifact {file_name} already exists at {relative_path}/"
)
return convert_to_artifact(existing_artifact)
@@ -254,14 +255,12 @@ class AgentDB:
session.refresh(new_artifact)
if self.debug_enabled:
logger.debug(
f"Created new artifact with artifact_id: {new_artifact.artifact_id}"
f"Created new artifact with ID: {new_artifact.artifact_id}"
)
return convert_to_artifact(new_artifact)
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while creating step: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while creating step: {e}")
raise
@@ -285,8 +284,6 @@ class AgentDB:
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while getting task: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while getting task: {e}")
raise
@@ -312,8 +309,6 @@ class AgentDB:
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while getting step: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while getting step: {e}")
raise
@@ -337,8 +332,6 @@ class AgentDB:
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while getting artifact: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while getting artifact: {e}")
raise
@@ -375,14 +368,13 @@ class AgentDB:
return await self.get_step(task_id, step_id)
else:
logger.error(
f"Step not found for update with task_id: {task_id} and step_id: {step_id}"
"Can't update non-existent Step with "
f"task_id: {task_id} and step_id: {step_id}"
)
raise NotFoundError("Step not found")
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while getting step: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while getting step: {e}")
raise
@@ -441,8 +433,6 @@ class AgentDB:
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while listing tasks: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while listing tasks: {e}")
raise
@@ -475,8 +465,6 @@ class AgentDB:
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while listing steps: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while listing steps: {e}")
raise
@@ -509,8 +497,6 @@ class AgentDB:
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy error while listing artifacts: {e}")
raise
except NotFoundError as e:
raise
except Exception as e:
logger.error(f"Unexpected error while listing artifacts: {e}")
raise

View File

@@ -22,14 +22,27 @@ from forge.agent_protocol.models import (
)
from forge.utils.exceptions import NotFoundError as DataNotFoundError
TEST_DB_FILENAME = "test_db.sqlite3"
TEST_DB_URL = f"sqlite:///{TEST_DB_FILENAME}"
@pytest.mark.asyncio
def test_table_creation():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
conn = sqlite3.connect("test_db.sqlite3")
cursor = conn.cursor()
@pytest.fixture
def agent_db():
db = AgentDB(TEST_DB_URL)
yield db
db.close()
os.remove(TEST_DB_FILENAME)
@pytest.fixture
def raw_db_connection(agent_db: AgentDB):
connection = sqlite3.connect(TEST_DB_FILENAME)
yield connection
connection.close()
def test_table_creation(raw_db_connection: sqlite3.Connection):
cursor = raw_db_connection.cursor()
# Test for tasks table existence
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='tasks'")
@@ -45,8 +58,6 @@ def test_table_creation():
)
assert cursor.fetchone() is not None
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio
async def test_task_schema():
@@ -84,7 +95,10 @@ async def test_step_schema():
name="Write to file",
input="Write the words you receive to the file 'output.txt'.",
status=StepStatus.created,
output="I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>",
output=(
"I am going to use the write_to_file command and write Washington "
"to a file called output.txt <write_to_file('output.txt', 'Washington')>"
),
artifacts=[
Artifact(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
@@ -101,13 +115,13 @@ async def test_step_schema():
assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e"
assert step.name == "Write to file"
assert step.status == StepStatus.created
assert (
step.output
== "I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>"
assert step.output == (
"I am going to use the write_to_file command and write Washington "
"to a file called output.txt <write_to_file('output.txt', 'Washington')>"
)
assert len(step.artifacts) == 1
assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
assert step.is_last == False
assert step.is_last is False
@pytest.mark.asyncio
@@ -118,6 +132,7 @@ async def test_convert_to_task():
created_at=now,
modified_at=now,
input="Write the words you receive to the file 'output.txt'.",
additional_input={},
artifacts=[
ArtifactModel(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
@@ -147,6 +162,7 @@ async def test_convert_to_step():
name="Write to file",
status="created",
input="Write the words you receive to the file 'output.txt'.",
additional_input={},
artifacts=[
ArtifactModel(
artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56",
@@ -166,7 +182,7 @@ async def test_convert_to_step():
assert step.status == StepStatus.created
assert len(step.artifacts) == 1
assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
assert step.is_last == False
assert step.is_last is False
@pytest.mark.asyncio
@@ -183,91 +199,67 @@ async def test_convert_to_artifact():
artifact = convert_to_artifact(artifact_model)
assert artifact.artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
assert artifact.relative_path == "file:///path/to/main.py"
assert artifact.agent_created == True
assert artifact.agent_created is True
@pytest.mark.asyncio
async def test_create_task():
# Having issues with pytest fixture so added setup and teardown in each test as a rapid workaround
# TODO: Fix this!
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
async def test_create_task(agent_db: AgentDB):
task = await agent_db.create_task("task_input")
assert task.input == "task_input"
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio
async def test_create_and_get_task():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
async def test_create_and_get_task(agent_db: AgentDB):
task = await agent_db.create_task("test_input")
fetched_task = await agent_db.get_task(task.task_id)
assert fetched_task.input == "test_input"
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio
async def test_get_task_not_found():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
async def test_get_task_not_found(agent_db: AgentDB):
with pytest.raises(DataNotFoundError):
await agent_db.get_task(9999)
os.remove(db_name.split("///")[1])
await agent_db.get_task("9999")
@pytest.mark.asyncio
async def test_create_and_get_step():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
async def test_create_and_get_step(agent_db: AgentDB):
task = await agent_db.create_task("task_input")
step_input = StepInput(type="python/code")
step_input = {"type": "python/code"}
request = StepRequestBody(input="test_input debug", additional_input=step_input)
step = await agent_db.create_step(task.task_id, request)
step = await agent_db.get_step(task.task_id, step.step_id)
assert step.input == "test_input debug"
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio
async def test_updating_step():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
async def test_updating_step(agent_db: AgentDB):
created_task = await agent_db.create_task("task_input")
step_input = StepInput(type="python/code")
step_input = {"type": "python/code"}
request = StepRequestBody(input="test_input debug", additional_input=step_input)
created_step = await agent_db.create_step(created_task.task_id, request)
await agent_db.update_step(created_task.task_id, created_step.step_id, "completed")
step = await agent_db.get_step(created_task.task_id, created_step.step_id)
assert step.status.value == "completed"
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio
async def test_get_step_not_found():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
async def test_get_step_not_found(agent_db: AgentDB):
with pytest.raises(DataNotFoundError):
await agent_db.get_step(9999, 9999)
os.remove(db_name.split("///")[1])
await agent_db.get_step("9999", "9999")
@pytest.mark.asyncio
async def test_get_artifact():
db_name = "sqlite:///test_db.sqlite3"
db = AgentDB(db_name)
async def test_get_artifact(agent_db: AgentDB):
# Given: A task and its corresponding artifact
task = await db.create_task("test_input debug")
step_input = StepInput(type="python/code")
task = await agent_db.create_task("test_input debug")
step_input = {"type": "python/code"}
requst = StepRequestBody(input="test_input debug", additional_input=step_input)
step = await db.create_step(task.task_id, requst)
step = await agent_db.create_step(task.task_id, requst)
# Create an artifact
artifact = await db.create_artifact(
artifact = await agent_db.create_artifact(
task_id=task.task_id,
file_name="test_get_artifact_sample_file.txt",
relative_path="file:///path/to/test_get_artifact_sample_file.txt",
@@ -276,7 +268,7 @@ async def test_get_artifact():
)
# When: The artifact is fetched by its ID
fetched_artifact = await db.get_artifact(artifact.artifact_id)
fetched_artifact = await agent_db.get_artifact(artifact.artifact_id)
# Then: The fetched artifact matches the original
assert fetched_artifact.artifact_id == artifact.artifact_id
@@ -285,47 +277,37 @@ async def test_get_artifact():
== "file:///path/to/test_get_artifact_sample_file.txt"
)
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio
async def test_list_tasks():
db_name = "sqlite:///test_db.sqlite3"
db = AgentDB(db_name)
async def test_list_tasks(agent_db: AgentDB):
# Given: Multiple tasks in the database
task1 = await db.create_task("test_input_1")
task2 = await db.create_task("test_input_2")
task1 = await agent_db.create_task("test_input_1")
task2 = await agent_db.create_task("test_input_2")
# When: All tasks are fetched
fetched_tasks, pagination = await db.list_tasks()
fetched_tasks, pagination = await agent_db.list_tasks()
# Then: The fetched tasks list includes the created tasks
task_ids = [task.task_id for task in fetched_tasks]
assert task1.task_id in task_ids
assert task2.task_id in task_ids
os.remove(db_name.split("///")[1])
@pytest.mark.asyncio
async def test_list_steps():
db_name = "sqlite:///test_db.sqlite3"
db = AgentDB(db_name)
step_input = StepInput(type="python/code")
requst = StepRequestBody(input="test_input debug", additional_input=step_input)
async def test_list_steps(agent_db: AgentDB):
step_input = {"type": "python/code"}
request = StepRequestBody(input="test_input debug", additional_input=step_input)
# Given: A task and multiple steps for that task
task = await db.create_task("test_input")
step1 = await db.create_step(task.task_id, requst)
requst = StepRequestBody(input="step two", additional_input=step_input)
step2 = await db.create_step(task.task_id, requst)
task = await agent_db.create_task("test_input")
step1 = await agent_db.create_step(task.task_id, request)
request = StepRequestBody(input="step two")
step2 = await agent_db.create_step(task.task_id, request)
# When: All steps for the task are fetched
fetched_steps, pagination = await db.list_steps(task.task_id)
fetched_steps, pagination = await agent_db.list_steps(task.task_id)
# Then: The fetched steps list includes the created steps
step_ids = [step.step_id for step in fetched_steps]
assert step1.step_id in step_ids
assert step2.step_id in step_ids
os.remove(db_name.split("///")[1])

View File

@@ -1,4 +1,4 @@
from .artifact import Artifact, ArtifactUpload
from .artifact import Artifact
from .pagination import Pagination
from .task import (
Step,
@@ -10,3 +10,16 @@ from .task import (
TaskRequestBody,
TaskStepsListResponse,
)
__all__ = [
"Artifact",
"Pagination",
"Step",
"StepRequestBody",
"StepStatus",
"Task",
"TaskArtifactsListResponse",
"TaskListResponse",
"TaskRequestBody",
"TaskStepsListResponse",
]

View File

@@ -3,15 +3,6 @@ from datetime import datetime
from pydantic import BaseModel, Field
class ArtifactUpload(BaseModel):
file: str = Field(..., description="File to upload.", format="binary")
relative_path: str = Field(
...,
description="Relative path of the artifact in the agent's workspace.",
example="python/code",
)
class Artifact(BaseModel):
created_at: datetime = Field(
...,

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import List, Optional
from typing import Any, List, Optional
from pydantic import BaseModel, Field
@@ -17,7 +17,7 @@ class TaskRequestBody(BaseModel):
description="Input prompt for the task.",
example="Write the words you receive to the file 'output.txt'.",
)
additional_input: Optional[dict] = None
additional_input: dict[str, Any] = Field(default_factory=dict)
class Task(TaskRequestBody):
@@ -38,8 +38,8 @@ class Task(TaskRequestBody):
description="The ID of the task.",
example="50da533e-3904-4401-8a07-c49adf88b5eb",
)
artifacts: Optional[List[Artifact]] = Field(
[],
artifacts: list[Artifact] = Field(
default_factory=list,
description="A list of artifacts that the task has produced.",
example=[
"7a49f31c-f9c6-4346-a22c-e32bc5af4d8e",
@@ -50,14 +50,12 @@ class Task(TaskRequestBody):
class StepRequestBody(BaseModel):
name: Optional[str] = Field(
None, description="The name of the task step.", example="Write to file"
default=None, description="The name of the task step.", example="Write to file"
)
input: Optional[str] = Field(
None,
description="Input prompt for the step.",
example="Washington",
input: str = Field(
..., description="Input prompt for the step.", example="Washington"
)
additional_input: Optional[dict] = None
additional_input: dict[str, Any] = Field(default_factory=dict)
class StepStatus(Enum):
@@ -90,19 +88,23 @@ class Step(StepRequestBody):
example="6bb1801a-fd80-45e8-899a-4dd723cc602e",
)
name: Optional[str] = Field(
None, description="The name of the task step.", example="Write to file"
default=None, description="The name of the task step.", example="Write to file"
)
status: StepStatus = Field(
..., description="The status of the task step.", example="created"
)
output: Optional[str] = Field(
None,
default=None,
description="Output of the task step.",
example="I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')",
example=(
"I am going to use the write_to_file command and write Washington "
"to a file called output.txt <write_to_file('output.txt', 'Washington')"
),
)
additional_output: Optional[dict] = None
artifacts: Optional[List[Artifact]] = Field(
[], description="A list of artifacts that the step has produced."
additional_output: Optional[dict[str, Any]] = None
artifacts: list[Artifact] = Field(
default_factory=list,
description="A list of artifacts that the step has produced.",
)
is_last: bool = Field(
..., description="Whether this is the last step in the task.", example=True

View File

@@ -1,3 +1,5 @@
from .command import Command, CommandOutput, CommandParameter
from .command import Command
from .decorator import command
from .parameter import CommandParameter
__all__ = ["Command", "CommandParameter", "command"]

View File

@@ -1,14 +1,16 @@
from __future__ import annotations
import inspect
from typing import Any, Callable, Generic, ParamSpec, TypeVar
from typing import Callable, Concatenate, Generic, ParamSpec, TypeVar, cast
from forge.agent.protocols import CommandProvider
from .parameter import CommandParameter
CommandOutput = Any
P = ParamSpec("P")
CO = TypeVar("CO", bound=CommandOutput)
CO = TypeVar("CO") # command output
_CP = TypeVar("_CP", bound=CommandProvider)
class Command(Generic[P, CO]):
@@ -24,7 +26,7 @@ class Command(Generic[P, CO]):
self,
names: list[str],
description: str,
method: Callable[P, CO],
method: Callable[Concatenate[_CP, P], CO],
parameters: list[CommandParameter],
):
# Check if all parameters are provided
@@ -34,7 +36,9 @@ class Command(Generic[P, CO]):
)
self.names = names
self.description = description
self.method = method
# Method technically has a `self` parameter, but we can ignore that
# since Python passes it internally.
self.method = cast(Callable[P, CO], method)
self.parameters = parameters
@property
@@ -62,7 +66,8 @@ class Command(Generic[P, CO]):
def __str__(self) -> str:
params = [
f"{param.name}: "
+ ("%s" if param.spec.required else "Optional[%s]") % param.spec.type.value
+ ("%s" if param.spec.required else "Optional[%s]")
% (param.spec.type.value if param.spec.type else "Any")
for param in self.parameters
]
return (

View File

@@ -1,2 +1,4 @@
from .action_history import ActionHistoryComponent
from .model import Episode, EpisodicActionHistory
__all__ = ["ActionHistoryComponent", "Episode", "EpisodicActionHistory"]

View File

@@ -1,27 +1,27 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable, Generic, Iterator, Optional
from typing import TYPE_CHECKING, Callable, Iterator, Optional
from forge.agent.protocols import AfterExecute, AfterParse, MessageProvider
from forge.llm.prompting.utils import indent
from forge.llm.providers import ChatMessage, ChatModelProvider
from forge.llm.providers import ChatMessage, MultiProvider
if TYPE_CHECKING:
from forge.config.config import Config
from .model import AP, ActionResult, Episode, EpisodicActionHistory
from .model import ActionResult, AnyProposal, Episode, EpisodicActionHistory
class ActionHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[AP]):
class ActionHistoryComponent(MessageProvider, AfterParse[AnyProposal], AfterExecute):
"""Keeps track of the event history and provides a summary of the steps."""
def __init__(
self,
event_history: EpisodicActionHistory[AP],
event_history: EpisodicActionHistory[AnyProposal],
max_tokens: int,
count_tokens: Callable[[str], int],
legacy_config: Config,
llm_provider: ChatModelProvider,
llm_provider: MultiProvider,
) -> None:
self.event_history = event_history
self.max_tokens = max_tokens
@@ -37,7 +37,7 @@ class ActionHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[
):
yield ChatMessage.system(f"## Progress on your Task so far\n\n{progress}")
def after_parse(self, result: AP) -> None:
def after_parse(self, result: AnyProposal) -> None:
self.event_history.register_action(result)
async def after_execute(self, result: ActionResult) -> None:
@@ -48,7 +48,7 @@ class ActionHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[
def _compile_progress(
self,
episode_history: list[Episode],
episode_history: list[Episode[AnyProposal]],
max_tokens: Optional[int] = None,
count_tokens: Optional[Callable[[str], int]] = None,
) -> str:

View File

@@ -1,25 +1,23 @@
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING, Generic, Iterator, TypeVar
from typing import TYPE_CHECKING, Generic
from pydantic import Field
from pydantic.generics import GenericModel
from forge.content_processing.text import summarize_text
from forge.llm.prompting.utils import format_numbered_list, indent
from forge.models.action import ActionProposal, ActionResult
from forge.models.action import ActionResult, AnyProposal
from forge.models.utils import ModelWithSummary
if TYPE_CHECKING:
from forge.config.config import Config
from forge.llm.providers import ChatModelProvider
AP = TypeVar("AP", bound=ActionProposal)
from forge.llm.providers import MultiProvider
class Episode(GenericModel, Generic[AP]):
action: AP
class Episode(GenericModel, Generic[AnyProposal]):
action: AnyProposal
result: ActionResult | None
summary: str | None = None
@@ -54,32 +52,29 @@ class Episode(GenericModel, Generic[AP]):
return executed_action + action_result
class EpisodicActionHistory(GenericModel, Generic[AP]):
class EpisodicActionHistory(GenericModel, Generic[AnyProposal]):
"""Utility container for an action history"""
episodes: list[Episode[AP]] = Field(default_factory=list)
episodes: list[Episode[AnyProposal]] = Field(default_factory=list)
cursor: int = 0
_lock = asyncio.Lock()
@property
def current_episode(self) -> Episode[AP] | None:
def current_episode(self) -> Episode[AnyProposal] | None:
if self.cursor == len(self):
return None
return self[self.cursor]
def __getitem__(self, key: int) -> Episode[AP]:
def __getitem__(self, key: int) -> Episode[AnyProposal]:
return self.episodes[key]
def __iter__(self) -> Iterator[Episode[AP]]:
return iter(self.episodes)
def __len__(self) -> int:
return len(self.episodes)
def __bool__(self) -> bool:
return len(self.episodes) > 0
def register_action(self, action: AP) -> None:
def register_action(self, action: AnyProposal) -> None:
if not self.current_episode:
self.episodes.append(Episode(action=action, result=None))
assert self.current_episode
@@ -113,7 +108,7 @@ class EpisodicActionHistory(GenericModel, Generic[AP]):
self.cursor = len(self.episodes)
async def handle_compression(
self, llm_provider: ChatModelProvider, app_config: Config
self, llm_provider: MultiProvider, app_config: Config
) -> None:
"""Compresses each episode in the action history using an LLM.

View File

@@ -3,6 +3,11 @@ from .code_executor import (
DENYLIST_CONTROL,
CodeExecutionError,
CodeExecutorComponent,
is_docker_available,
we_are_running_in_a_docker_container,
)
__all__ = [
"ALLOWLIST_CONTROL",
"DENYLIST_CONTROL",
"CodeExecutionError",
"CodeExecutorComponent",
]

View File

@@ -11,7 +11,8 @@ import docker
from docker.errors import DockerException, ImageNotFound, NotFound
from docker.models.containers import Container as DockerContainer
from forge.agent import BaseAgentSettings, CommandProvider
from forge.agent import BaseAgentSettings
from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.file_storage import FileStorage

View File

@@ -5,3 +5,11 @@ from .context_item import (
FolderContextItem,
StaticContextItem,
)
__all__ = [
"ContextComponent",
"ContextItem",
"FileContextItem",
"FolderContextItem",
"StaticContextItem",
]

View File

@@ -1 +1,3 @@
from .file_manager import FileManagerComponent
__all__ = ["FileManagerComponent"]

View File

@@ -1 +1,3 @@
from .git_operations import GitOperationsComponent
__all__ = ["GitOperationsComponent"]

View File

@@ -1 +1,3 @@
from .image_gen import ImageGeneratorComponent
__all__ = ["ImageGeneratorComponent"]

View File

@@ -32,7 +32,12 @@ class ImageGeneratorComponent(CommandProvider):
self.legacy_config = config
def get_commands(self) -> Iterator[Command]:
yield self.generate_image
if (
self.legacy_config.openai_credentials
or self.legacy_config.huggingface_api_token
or self.legacy_config.sd_webui_auth
):
yield self.generate_image
@command(
parameters={
@@ -60,17 +65,26 @@ class ImageGeneratorComponent(CommandProvider):
str: The filename of the image
"""
filename = self.workspace.root / f"{str(uuid.uuid4())}.jpg"
cfg = self.legacy_config
# DALL-E
if self.legacy_config.image_provider == "dalle":
if cfg.openai_credentials and (
cfg.image_provider == "dalle"
or not (cfg.huggingface_api_token or cfg.sd_webui_url)
):
return self.generate_image_with_dalle(prompt, filename, size)
# HuggingFace
elif self.legacy_config.image_provider == "huggingface":
elif cfg.huggingface_api_token and (
cfg.image_provider == "huggingface"
or not (cfg.openai_credentials or cfg.sd_webui_url)
):
return self.generate_image_with_hf(prompt, filename)
# SD WebUI
elif self.legacy_config.image_provider == "sdwebui":
elif cfg.sd_webui_url and (
cfg.image_provider == "sdwebui" or cfg.sd_webui_auth
):
return self.generate_image_with_sd_webui(prompt, filename, size)
return "No Image Provider Set"
return "Error: No image generation provider available"
def generate_image_with_hf(self, prompt: str, output_file: Path) -> str:
"""Generate an image with HuggingFace's API.
@@ -142,6 +156,7 @@ class ImageGeneratorComponent(CommandProvider):
Returns:
str: The filename of the image
"""
assert self.legacy_config.openai_credentials # otherwise this tool is disabled
# Check for supported image sizes
if size not in [256, 512, 1024]:
@@ -152,16 +167,19 @@ class ImageGeneratorComponent(CommandProvider):
)
size = closest
# TODO: integrate in `forge.llm.providers`(?)
response = OpenAI(
api_key=self.legacy_config.openai_credentials.api_key.get_secret_value()
).images.generate(
prompt=prompt,
n=1,
size=f"{size}x{size}",
# TODO: improve typing of size config item(s)
size=f"{size}x{size}", # type: ignore
response_format="b64_json",
)
assert response.data[0].b64_json is not None # response_format = "b64_json"
logger.info(f"Image Generated for prompt:{prompt}")
logger.info(f"Image Generated for prompt: {prompt}")
image_data = b64decode(response.data[0].b64_json)

View File

@@ -1 +1,3 @@
from .system import SystemComponent
__all__ = ["SystemComponent"]

View File

@@ -37,8 +37,8 @@ class SystemComponent(DirectiveProvider, MessageProvider, CommandProvider):
"You are unable to interact with physical objects. "
"If this is absolutely necessary to fulfill a task or objective or "
"to complete a step, you must ask the user to do it for you. "
"If the user refuses this, and there is no other way to achieve your goals, "
"you must terminate to avoid wasting time and energy."
"If the user refuses this, and there is no other way to achieve your "
"goals, you must terminate to avoid wasting time and energy."
)
def get_resources(self) -> Iterator[str]:

View File

@@ -1 +1,3 @@
from .user_interaction import UserInteractionComponent
__all__ = ["UserInteractionComponent"]

View File

@@ -1 +1,3 @@
from .watchdog import WatchdogComponent
__all__ = ["WatchdogComponent"]

View File

@@ -1,18 +1,20 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from forge.agent.base import BaseAgentConfiguration
from forge.agent.components import ComponentSystemError
from forge.agent.protocols import AfterParse
from forge.components.action_history import EpisodicActionHistory
from forge.models.action import ActionProposal
from forge.models.action import AnyProposal
if TYPE_CHECKING:
from forge.agent.base import BaseAgentConfiguration
logger = logging.getLogger(__name__)
class WatchdogComponent(AfterParse):
class WatchdogComponent(AfterParse[AnyProposal]):
"""
Adds a watchdog feature to an agent class. Whenever the agent starts
looping, the watchdog will switch from the FAST_LLM to the SMART_LLM and re-think.
@@ -21,13 +23,13 @@ class WatchdogComponent(AfterParse):
def __init__(
self,
config: "BaseAgentConfiguration",
event_history: EpisodicActionHistory[ActionProposal],
event_history: EpisodicActionHistory[AnyProposal],
):
self.config = config
self.event_history = event_history
self.revert_big_brain = False
def after_parse(self, result: ActionProposal) -> None:
def after_parse(self, result: AnyProposal) -> None:
if self.revert_big_brain:
self.config.big_brain = False
self.revert_big_brain = False
@@ -58,4 +60,4 @@ class WatchdogComponent(AfterParse):
self.big_brain = True
self.revert_big_brain = True
# Trigger retry of all pipelines prior to this component
raise ComponentSystemError()
raise ComponentSystemError(rethink_reason, self)

View File

@@ -1,2 +1,4 @@
from .search import WebSearchComponent
from .selenium import BrowsingError, TooMuchOutputError, WebSeleniumComponent
from .selenium import BrowsingError, WebSeleniumComponent
__all__ = ["WebSearchComponent", "BrowsingError", "WebSeleniumComponent"]

View File

@@ -12,7 +12,6 @@ from selenium.webdriver.chrome.options import Options as ChromeOptions
from selenium.webdriver.chrome.service import Service as ChromeDriverService
from selenium.webdriver.chrome.webdriver import WebDriver as ChromeDriver
from selenium.webdriver.common.by import By
from selenium.webdriver.common.options import ArgOptions as BrowserOptions
from selenium.webdriver.edge.options import Options as EdgeOptions
from selenium.webdriver.edge.service import Service as EdgeDriverService
from selenium.webdriver.edge.webdriver import WebDriver as EdgeDriver
@@ -33,7 +32,7 @@ from forge.command import Command, command
from forge.config.config import Config
from forge.content_processing.html import extract_hyperlinks, format_hyperlinks
from forge.content_processing.text import extract_information, summarize_text
from forge.llm.providers.schema import ChatModelInfo, ChatModelProvider
from forge.llm.providers import ChatModelInfo, MultiProvider
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import CommandExecutionError, TooMuchOutputError
from forge.utils.url_validator import validate_url
@@ -45,6 +44,9 @@ MAX_RAW_CONTENT_LENGTH = 500
LINKS_TO_RETURN = 20
BrowserOptions = ChromeOptions | EdgeOptions | FirefoxOptions | SafariOptions
class BrowsingError(CommandExecutionError):
"""An error occurred while trying to browse the page"""
@@ -55,7 +57,7 @@ class WebSeleniumComponent(DirectiveProvider, CommandProvider):
def __init__(
self,
config: Config,
llm_provider: ChatModelProvider,
llm_provider: MultiProvider,
model_info: ChatModelInfo,
):
self.legacy_config = config
@@ -251,7 +253,7 @@ class WebSeleniumComponent(DirectiveProvider, CommandProvider):
if isinstance(options, FirefoxOptions):
if config.selenium_headless:
options.headless = True
options.headless = True # type: ignore
options.add_argument("--disable-gpu")
driver = FirefoxDriver(
service=GeckoDriverService(GeckoDriverManager().install()),

View File

@@ -200,7 +200,7 @@ class ConfigBuilder(Configurable[Config]):
if (
config.openai_credentials
and config.openai_credentials.api_type == "azure"
and config.openai_credentials.api_type == SecretStr("azure")
and (config_file := config.azure_config_file)
):
config.openai_credentials.load_azure_config(config_file)

8
forge/forge/conftest.py Normal file
View File

@@ -0,0 +1,8 @@
from pathlib import Path
import pytest
@pytest.fixture()
def test_workspace(tmp_path: Path) -> Path:
return tmp_path

View File

@@ -12,7 +12,7 @@ if TYPE_CHECKING:
from forge.json.parsing import extract_list_from_json
from forge.llm.prompting import ChatPrompt
from forge.llm.providers import ChatMessage, ChatModelProvider, ModelTokenizer
from forge.llm.providers import ChatMessage, ModelTokenizer, MultiProvider
logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ def chunk_content(
async def summarize_text(
text: str,
llm_provider: ChatModelProvider,
llm_provider: MultiProvider,
config: Config,
question: Optional[str] = None,
instruction: Optional[str] = None,
@@ -89,7 +89,7 @@ async def summarize_text(
async def extract_information(
source_text: str,
topics_of_interest: list[str],
llm_provider: ChatModelProvider,
llm_provider: MultiProvider,
config: Config,
) -> list[str]:
fmt_topics_list = "\n".join(f"* {topic}." for topic in topics_of_interest)
@@ -113,7 +113,7 @@ async def extract_information(
async def _process_text(
text: str,
instruction: str,
llm_provider: ChatModelProvider,
llm_provider: MultiProvider,
config: Config,
output_type: type[str | list[str]] = str,
) -> tuple[str, list[tuple[str, str]]] | list[str]:
@@ -165,7 +165,7 @@ async def _process_text(
),
)
if output_type == list[str]:
if isinstance(response.parsed_result, list):
logger.debug(f"Raw LLM response: {repr(response.response.content)}")
fmt_result_bullet_list = "\n".join(f"* {r}" for r in response.parsed_result)
logger.debug(

View File

@@ -1,10 +1,7 @@
import enum
from pathlib import Path
from .base import FileStorage, FileStorageConfiguration
from .gcs import GCSFileStorage, GCSFileStorageConfiguration
from .local import LocalFileStorage
from .s3 import S3FileStorage, S3FileStorageConfiguration
from .base import FileStorage
class FileStorageBackendName(str, enum.Enum):
@@ -15,7 +12,7 @@ class FileStorageBackendName(str, enum.Enum):
def get_storage(
backend: FileStorageBackendName,
root_path: Path = ".",
root_path: Path = Path("."),
restrict_to_root: bool = True,
) -> FileStorage:
match backend:

View File

@@ -4,17 +4,17 @@ The FileStorage class provides an interface for interacting with a file storage.
from __future__ import annotations
import asyncio
import logging
import os
import shutil
import tempfile
from abc import ABC, abstractmethod
from contextlib import contextmanager
from io import IOBase, TextIOBase
from pathlib import Path
from typing import IO, Any, BinaryIO, Callable, Generator, Literal, TextIO, overload
from typing import Any, BinaryIO, Callable, Generator, Literal, TextIO, overload
from watchdog.events import FileSystemEventHandler
from watchdog.events import FileSystemEvent, FileSystemEventHandler
from watchdog.observers import Observer
from forge.models.config import SystemConfiguration
@@ -66,26 +66,29 @@ class FileStorage(ABC):
def open_file(
self,
path: str | Path,
mode: Literal["w", "r"] = "r",
mode: Literal["r", "w"] = "r",
binary: Literal[False] = False,
) -> TextIO | TextIOBase:
) -> TextIO:
"""Returns a readable text file-like object representing the file."""
@overload
@abstractmethod
def open_file(
self,
path: str | Path,
mode: Literal["w", "r"] = "r",
binary: Literal[True] = True,
) -> BinaryIO | IOBase:
self, path: str | Path, mode: Literal["r", "w"], binary: Literal[True]
) -> BinaryIO:
"""Returns a binary file-like object representing the file."""
@overload
@abstractmethod
def open_file(self, path: str | Path, *, binary: Literal[True]) -> BinaryIO:
"""Returns a readable binary file-like object representing the file."""
@overload
@abstractmethod
def open_file(
self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
) -> IO | IOBase:
"""Returns a readable file-like object representing the file."""
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
) -> TextIO | BinaryIO:
"""Returns a file-like object representing the file."""
@overload
@abstractmethod
@@ -95,13 +98,15 @@ class FileStorage(ABC):
@overload
@abstractmethod
def read_file(self, path: str | Path, binary: Literal[True] = True) -> bytes:
def read_file(self, path: str | Path, binary: Literal[True]) -> bytes:
"""Read a file in the storage as binary."""
...
@overload
@abstractmethod
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
"""Read a file in the storage."""
...
@abstractmethod
async def write_file(self, path: str | Path, content: str | bytes) -> None:
@@ -241,24 +246,32 @@ class FileSyncHandler(FileSystemEventHandler):
self.storage = storage
self.path = Path(path)
async def on_modified(self, event):
def on_modified(self, event: FileSystemEvent):
if event.is_directory:
return
file_path = Path(event.src_path).relative_to(self.path)
content = file_path.read_bytes()
await self.storage.write_file(file_path, content)
# Must execute write_file synchronously because the hook is synchronous
# TODO: Schedule write operation using asyncio.create_task (non-blocking)
asyncio.get_event_loop().run_until_complete(
self.storage.write_file(file_path, content)
)
async def on_created(self, event):
def on_created(self, event: FileSystemEvent):
if event.is_directory:
self.storage.make_dir(event.src_path)
return
file_path = Path(event.src_path).relative_to(self.path)
content = file_path.read_bytes()
await self.storage.write_file(file_path, content)
# Must execute write_file synchronously because the hook is synchronous
# TODO: Schedule write operation using asyncio.create_task (non-blocking)
asyncio.get_event_loop().run_until_complete(
self.storage.write_file(file_path, content)
)
def on_deleted(self, event):
def on_deleted(self, event: FileSystemEvent):
if event.is_directory:
self.storage.delete_dir(event.src_path)
return
@@ -266,5 +279,5 @@ class FileSyncHandler(FileSystemEventHandler):
file_path = event.src_path
self.storage.delete_file(file_path)
def on_moved(self, event):
def on_moved(self, event: FileSystemEvent):
self.storage.rename(event.src_path, event.dest_path)

View File

@@ -7,12 +7,13 @@ from __future__ import annotations
import inspect
import logging
from io import IOBase
from io import TextIOWrapper
from pathlib import Path
from typing import Literal
from typing import Literal, overload
from google.cloud import storage
from google.cloud.exceptions import NotFound
from google.cloud.storage.fileio import BlobReader, BlobWriter
from forge.models.config import UserConfigurable
@@ -73,14 +74,67 @@ class GCSFileStorage(FileStorage):
path = self.get_path(path)
return self._bucket.blob(str(path))
@overload
def open_file(
self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
) -> IOBase:
self,
path: str | Path,
mode: Literal["r", "w"] = "r",
binary: Literal[False] = False,
) -> TextIOWrapper:
...
@overload
def open_file(
self, path: str | Path, mode: Literal["r"], binary: Literal[True]
) -> BlobReader:
...
@overload
def open_file(
self, path: str | Path, mode: Literal["w"], binary: Literal[True]
) -> BlobWriter:
...
@overload
def open_file(
self, path: str | Path, mode: Literal["r", "w"], binary: Literal[True]
) -> BlobWriter | BlobReader:
...
@overload
def open_file(self, path: str | Path, *, binary: Literal[True]) -> BlobReader:
...
@overload
def open_file(
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
) -> BlobReader | BlobWriter | TextIOWrapper:
...
# https://github.com/microsoft/pyright/issues/8007
def open_file( # pyright: ignore[reportIncompatibleMethodOverride]
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
) -> BlobReader | BlobWriter | TextIOWrapper:
"""Open a file in the storage."""
blob = self._get_blob(path)
blob.reload() # pin revision number to prevent version mixing while reading
return blob.open(f"{mode}b" if binary else mode)
@overload
def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
"""Read a file in the storage as text."""
...
@overload
def read_file(self, path: str | Path, binary: Literal[True]) -> bytes:
"""Read a file in the storage as binary."""
...
@overload
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
"""Read a file in the storage."""
...
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
"""Read a file in the storage."""
return self.open_file(path, "r", binary).read()

View File

@@ -8,7 +8,7 @@ import inspect
import logging
from contextlib import contextmanager
from pathlib import Path
from typing import IO, Any, Generator, Literal
from typing import Any, BinaryIO, Generator, Literal, TextIO, overload
from .base import FileStorage, FileStorageConfiguration
@@ -42,16 +42,58 @@ class LocalFileStorage(FileStorage):
def initialize(self) -> None:
self.root.mkdir(exist_ok=True, parents=True)
@overload
def open_file(
self,
path: str | Path,
mode: Literal["w", "r"] = "r",
binary: Literal[False] = False,
) -> TextIO:
...
@overload
def open_file(
self, path: str | Path, mode: Literal["w", "r"], binary: Literal[True]
) -> BinaryIO:
...
@overload
def open_file(self, path: str | Path, *, binary: Literal[True]) -> BinaryIO:
...
@overload
def open_file(
self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
) -> IO:
) -> TextIO | BinaryIO:
...
def open_file(
self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
) -> TextIO | BinaryIO:
"""Open a file in the storage."""
return self._open_file(path, f"{mode}b" if binary else mode)
def _open_file(self, path: str | Path, mode: str) -> IO:
def _open_file(self, path: str | Path, mode: str) -> TextIO | BinaryIO:
full_path = self.get_path(path)
if any(m in mode for m in ("w", "a", "x")):
full_path.parent.mkdir(parents=True, exist_ok=True)
return open(full_path, mode) # type: ignore
@overload
def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
"""Read a file in the storage as text."""
...
@overload
def read_file(self, path: str | Path, binary: Literal[True]) -> bytes:
"""Read a file in the storage as binary."""
...
@overload
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
"""Read a file in the storage."""
...
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
"""Read a file in the storage."""
with self._open_file(path, "rb" if binary else "r") as file:
@@ -60,7 +102,7 @@ class LocalFileStorage(FileStorage):
async def write_file(self, path: str | Path, content: str | bytes) -> None:
"""Write to a file in the storage."""
with self._open_file(path, "wb" if type(content) is bytes else "w") as file:
file.write(content)
file.write(content) # type: ignore
if self.on_write_file:
path = Path(path)

View File

@@ -8,9 +8,9 @@ from __future__ import annotations
import contextlib
import inspect
import logging
from io import IOBase, TextIOWrapper
from io import TextIOWrapper
from pathlib import Path
from typing import TYPE_CHECKING, Literal, Optional
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, overload
import boto3
import botocore.exceptions
@@ -22,6 +22,7 @@ from .base import FileStorage, FileStorageConfiguration
if TYPE_CHECKING:
import mypy_boto3_s3
from botocore.response import StreamingBody
logger = logging.getLogger(__name__)
@@ -89,18 +90,60 @@ class S3FileStorage(FileStorage):
def _get_obj(self, path: str | Path) -> mypy_boto3_s3.service_resource.Object:
"""Get an S3 object."""
path = self.get_path(path)
obj = self._bucket.Object(str(path))
with contextlib.suppress(botocore.exceptions.ClientError):
obj.load()
return obj
@overload
def open_file(
self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
) -> IOBase:
self,
path: str | Path,
mode: Literal["r", "w"] = "r",
binary: Literal[False] = False,
) -> TextIOWrapper:
...
@overload
def open_file(
self, path: str | Path, mode: Literal["r", "w"], binary: Literal[True]
) -> S3BinaryIOWrapper:
...
@overload
def open_file(
self, path: str | Path, *, binary: Literal[True]
) -> S3BinaryIOWrapper:
...
@overload
def open_file(
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
) -> S3BinaryIOWrapper | TextIOWrapper:
...
def open_file(
self, path: str | Path, mode: Literal["r", "w"] = "r", binary: bool = False
) -> TextIOWrapper | S3BinaryIOWrapper:
"""Open a file in the storage."""
obj = self._get_obj(path)
return obj.get()["Body"] if binary else TextIOWrapper(obj.get()["Body"])
path = self.get_path(path)
body = S3BinaryIOWrapper(self._get_obj(path).get()["Body"], str(path))
return body if binary else TextIOWrapper(body)
@overload
def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
"""Read a file in the storage as text."""
...
@overload
def read_file(self, path: str | Path, binary: Literal[True]) -> bytes:
"""Read a file in the storage as binary."""
...
@overload
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
"""Read a file in the storage."""
...
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
"""Read a file in the storage."""
@@ -108,7 +151,7 @@ class S3FileStorage(FileStorage):
async def write_file(self, path: str | Path, content: str | bytes) -> None:
"""Write to a file in the storage."""
obj = self._get_obj(path)
obj = self._get_obj(self.get_path(path))
obj.put(Body=content)
if self.on_write_file:
@@ -172,7 +215,7 @@ class S3FileStorage(FileStorage):
self._s3.meta.client.head_object(Bucket=self._bucket_name, Key=str(path))
return True
except botocore.exceptions.ClientError as e:
if int(e.response["ResponseMetadata"]["HTTPStatusCode"]) == 404:
if e.response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 404:
# If the object does not exist,
# check for objects with the prefix (folder)
prefix = f"{str(path).rstrip('/')}/"
@@ -201,7 +244,7 @@ class S3FileStorage(FileStorage):
)
self._s3.meta.client.delete_object(Bucket=self._bucket_name, Key=old_path)
except botocore.exceptions.ClientError as e:
if int(e.response["ResponseMetadata"]["HTTPStatusCode"]) == 404:
if e.response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 404:
# If the object does not exist,
# it may be a folder
prefix = f"{old_path.rstrip('/')}/"
@@ -233,7 +276,7 @@ class S3FileStorage(FileStorage):
Key=destination,
)
except botocore.exceptions.ClientError as e:
if int(e.response["ResponseMetadata"]["HTTPStatusCode"]) == 404:
if e.response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 404:
# If the object does not exist,
# it may be a folder
prefix = f"{source.rstrip('/')}/"
@@ -254,7 +297,7 @@ class S3FileStorage(FileStorage):
S3FileStorageConfiguration(
bucket=self._bucket_name,
root=Path("/").joinpath(self.get_path(subroot)),
s3_endpoint_url=self._s3.meta.client.meta.endpoint_url,
s3_endpoint_url=SecretStr(self._s3.meta.client.meta.endpoint_url),
)
)
file_storage._s3 = self._s3
@@ -263,3 +306,48 @@ class S3FileStorage(FileStorage):
def __repr__(self) -> str:
return f"{__class__.__name__}(bucket='{self._bucket_name}', root={self._root})"
class S3BinaryIOWrapper(BinaryIO):
def __init__(self, body: StreamingBody, name: str):
self.body = body
self._name = name
@property
def name(self) -> str:
return self._name
def read(self, size: int = -1) -> bytes:
return self.body.read(size if size > 0 else None)
def readinto(self, b: bytearray) -> int:
data = self.read(len(b))
b[: len(data)] = data
return len(data)
def close(self) -> None:
self.body.close()
def fileno(self) -> int:
return self.body.fileno()
def flush(self) -> None:
self.body.flush()
def isatty(self) -> bool:
return self.body.isatty()
def readable(self) -> bool:
return self.body.readable()
def seekable(self) -> bool:
return self.body.seekable()
def writable(self) -> bool:
return False
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.body.close()

View File

@@ -1 +0,0 @@
from .parsing import extract_dict_from_json, extract_list_from_json, json_loads

View File

@@ -1,7 +1,5 @@
import abc
from typing import TYPE_CHECKING
from forge.models.config import SystemConfiguration
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from forge.llm.providers import AssistantChatMessage
@@ -10,8 +8,6 @@ from .schema import ChatPrompt, LanguageModelClassification
class PromptStrategy(abc.ABC):
default_configuration: SystemConfiguration
@property
@abc.abstractmethod
def model_classification(self) -> LanguageModelClassification:
@@ -22,5 +18,5 @@ class PromptStrategy(abc.ABC):
...
@abc.abstractmethod
def parse_response_content(self, response_content: "AssistantChatMessage"):
def parse_response_content(self, response: "AssistantChatMessage") -> Any:
...

View File

@@ -27,7 +27,7 @@ class ChatPrompt(BaseModel):
prefill_response: str = ""
def raw(self) -> list[ChatMessageDict]:
return [m.dict() for m in self.messages]
return [m.dict() for m in self.messages] # type: ignore
def __str__(self):
return "\n\n".join(

View File

@@ -1,4 +1,10 @@
from .multi import CHAT_MODELS, ModelName, MultiProvider
from .multi import (
CHAT_MODELS,
ChatModelProvider,
EmbeddingModelProvider,
ModelName,
MultiProvider,
)
from .openai import (
OPEN_AI_CHAT_MODELS,
OPEN_AI_EMBEDDING_MODELS,
@@ -14,15 +20,12 @@ from .schema import (
AssistantFunctionCallDict,
ChatMessage,
ChatModelInfo,
ChatModelProvider,
ChatModelResponse,
CompletionModelFunction,
Embedding,
EmbeddingModelInfo,
EmbeddingModelProvider,
EmbeddingModelResponse,
ModelInfo,
ModelProvider,
ModelProviderBudget,
ModelProviderCredentials,
ModelProviderName,
@@ -41,7 +44,6 @@ __all__ = [
"AssistantFunctionCallDict",
"ChatMessage",
"ChatModelInfo",
"ChatModelProvider",
"ChatModelResponse",
"CompletionModelFunction",
"CHAT_MODELS",
@@ -51,7 +53,7 @@ __all__ = [
"EmbeddingModelResponse",
"ModelInfo",
"ModelName",
"ModelProvider",
"ChatModelProvider",
"ModelProviderBudget",
"ModelProviderCredentials",
"ModelProviderName",

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import enum
import logging
from typing import TYPE_CHECKING, Callable, Optional, ParamSpec, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Optional, ParamSpec, TypeVar
import sentry_sdk
import tenacity
@@ -14,9 +14,9 @@ from forge.llm.providers.schema import (
AssistantChatMessage,
AssistantFunctionCall,
AssistantToolCall,
BaseChatModelProvider,
ChatMessage,
ChatModelInfo,
ChatModelProvider,
ChatModelResponse,
CompletionModelFunction,
ModelProviderBudget,
@@ -27,7 +27,7 @@ from forge.llm.providers.schema import (
ModelTokenizer,
ToolResultMessage,
)
from forge.models.config import Configurable, UserConfigurable
from forge.models.config import UserConfigurable
from .utils import validate_tool_calls
@@ -84,14 +84,14 @@ class AnthropicConfiguration(ModelProviderConfiguration):
class AnthropicCredentials(ModelProviderCredentials):
"""Credentials for Anthropic."""
api_key: SecretStr = UserConfigurable(from_env="ANTHROPIC_API_KEY")
api_key: SecretStr = UserConfigurable(from_env="ANTHROPIC_API_KEY") # type: ignore
api_base: Optional[SecretStr] = UserConfigurable(
default=None, from_env="ANTHROPIC_API_BASE_URL"
)
def get_api_access_kwargs(self) -> dict[str, str]:
return {
k: (v.get_secret_value() if type(v) is SecretStr else v)
k: v.get_secret_value()
for k, v in {
"api_key": self.api_key,
"base_url": self.api_base,
@@ -101,12 +101,12 @@ class AnthropicCredentials(ModelProviderCredentials):
class AnthropicSettings(ModelProviderSettings):
configuration: AnthropicConfiguration
credentials: Optional[AnthropicCredentials]
budget: ModelProviderBudget
configuration: AnthropicConfiguration # type: ignore
credentials: Optional[AnthropicCredentials] # type: ignore
budget: ModelProviderBudget # type: ignore
class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
class AnthropicProvider(BaseChatModelProvider[AnthropicModelName, AnthropicSettings]):
default_settings = AnthropicSettings(
name="anthropic_provider",
description="Provides access to Anthropic's API.",
@@ -136,27 +136,26 @@ class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
from anthropic import AsyncAnthropic
self._client = AsyncAnthropic(**self._credentials.get_api_access_kwargs())
self._client = AsyncAnthropic(
**self._credentials.get_api_access_kwargs() # type: ignore
)
async def get_available_models(self) -> list[ChatModelInfo]:
async def get_available_models(self) -> list[ChatModelInfo[AnthropicModelName]]:
return list(ANTHROPIC_CHAT_MODELS.values())
def get_token_limit(self, model_name: str) -> int:
def get_token_limit(self, model_name: AnthropicModelName) -> int:
"""Get the token limit for a given model."""
return ANTHROPIC_CHAT_MODELS[model_name].max_tokens
@classmethod
def get_tokenizer(cls, model_name: AnthropicModelName) -> ModelTokenizer:
def get_tokenizer(self, model_name: AnthropicModelName) -> ModelTokenizer[Any]:
# HACK: No official tokenizer is available for Claude 3
return tiktoken.encoding_for_model(model_name)
@classmethod
def count_tokens(cls, text: str, model_name: AnthropicModelName) -> int:
def count_tokens(self, text: str, model_name: AnthropicModelName) -> int:
return 0 # HACK: No official tokenizer is available for Claude 3
@classmethod
def count_message_tokens(
cls,
self,
messages: ChatMessage | list[ChatMessage],
model_name: AnthropicModelName,
) -> int:
@@ -195,7 +194,7 @@ class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
cost,
t_input,
t_output,
) = await self._create_chat_completion(completion_kwargs)
) = await self._create_chat_completion(model_name, completion_kwargs)
total_cost += cost
self._logger.debug(
f"Completion usage: {t_input} input, {t_output} output "
@@ -245,7 +244,7 @@ class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
)
if attempts < self._configuration.fix_failed_parse_tries:
anthropic_messages.append(
_assistant_msg.dict(include={"role", "content"})
_assistant_msg.dict(include={"role", "content"}) # type: ignore
)
anthropic_messages.append(
{
@@ -312,7 +311,6 @@ class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
def _get_chat_completion_args(
self,
prompt_messages: list[ChatMessage],
model: AnthropicModelName,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
**kwargs,
@@ -321,7 +319,6 @@ class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
Args:
prompt_messages: List of ChatMessages.
model: The model to use.
functions: Optional list of functions available to the LLM.
kwargs: Additional keyword arguments.
@@ -329,8 +326,6 @@ class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
list[MessageParam]: Prompt messages for the Anthropic call
dict[str, Any]: Any other kwargs for the Anthropic call
"""
kwargs["model"] = model
if functions:
kwargs["tools"] = [
{
@@ -433,7 +428,7 @@ class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
return messages, kwargs # type: ignore
async def _create_chat_completion(
self, completion_kwargs: MessageCreateParams
self, model: AnthropicModelName, completion_kwargs: MessageCreateParams
) -> tuple[Message, float, int, int]:
"""
Create a chat completion using the Anthropic API with retry handling.
@@ -449,17 +444,15 @@ class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
"""
@self._retry_api_request
async def _create_chat_completion_with_retry(
completion_kwargs: MessageCreateParams,
) -> Message:
async def _create_chat_completion_with_retry() -> Message:
return await self._client.beta.tools.messages.create(
**completion_kwargs # type: ignore
model=model, **completion_kwargs # type: ignore
)
response = await _create_chat_completion_with_retry(completion_kwargs)
response = await _create_chat_completion_with_retry()
cost = self._budget.update_usage_and_cost(
model_info=ANTHROPIC_CHAT_MODELS[completion_kwargs["model"]],
model_info=ANTHROPIC_CHAT_MODELS[model],
input_tokens_used=response.usage.input_tokens,
output_tokens_used=response.usage.output_tokens,
)
@@ -472,7 +465,10 @@ class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
AssistantToolCall(
id=c.id,
type="function",
function=AssistantFunctionCall(name=c.name, arguments=c.input),
function=AssistantFunctionCall(
name=c.name,
arguments=c.input, # type: ignore
),
)
for c in assistant_message.content
if c.type == "tool_use"

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import enum
import logging
from typing import TYPE_CHECKING, Callable, Optional, ParamSpec, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Optional, ParamSpec, TypeVar
import sentry_sdk
import tenacity
@@ -15,9 +15,9 @@ from forge.llm.providers.schema import (
AssistantChatMessage,
AssistantFunctionCall,
AssistantToolCall,
BaseChatModelProvider,
ChatMessage,
ChatModelInfo,
ChatModelProvider,
ChatModelResponse,
CompletionModelFunction,
ModelProviderBudget,
@@ -27,8 +27,9 @@ from forge.llm.providers.schema import (
ModelProviderSettings,
ModelTokenizer,
)
from forge.models.config import Configurable, UserConfigurable
from forge.models.config import UserConfigurable
from .openai import format_function_def_for_openai
from .utils import validate_tool_calls
if TYPE_CHECKING:
@@ -93,14 +94,14 @@ class GroqConfiguration(ModelProviderConfiguration):
class GroqCredentials(ModelProviderCredentials):
"""Credentials for Groq."""
api_key: SecretStr = UserConfigurable(from_env="GROQ_API_KEY")
api_key: SecretStr = UserConfigurable(from_env="GROQ_API_KEY") # type: ignore
api_base: Optional[SecretStr] = UserConfigurable(
default=None, from_env="GROQ_API_BASE_URL"
)
def get_api_access_kwargs(self) -> dict[str, str]:
return {
k: (v.get_secret_value() if type(v) is SecretStr else v)
k: v.get_secret_value()
for k, v in {
"api_key": self.api_key,
"base_url": self.api_base,
@@ -110,12 +111,12 @@ class GroqCredentials(ModelProviderCredentials):
class GroqSettings(ModelProviderSettings):
configuration: GroqConfiguration
credentials: Optional[GroqCredentials]
budget: ModelProviderBudget
configuration: GroqConfiguration # type: ignore
credentials: Optional[GroqCredentials] # type: ignore
budget: ModelProviderBudget # type: ignore
class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
class GroqProvider(BaseChatModelProvider[GroqModelName, GroqSettings]):
default_settings = GroqSettings(
name="groq_provider",
description="Provides access to Groq's API.",
@@ -145,28 +146,27 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
from groq import AsyncGroq
self._client = AsyncGroq(**self._credentials.get_api_access_kwargs())
self._client = AsyncGroq(
**self._credentials.get_api_access_kwargs() # type: ignore
)
async def get_available_models(self) -> list[ChatModelInfo]:
async def get_available_models(self) -> list[ChatModelInfo[GroqModelName]]:
_models = (await self._client.models.list()).data
return [GROQ_CHAT_MODELS[m.id] for m in _models if m.id in GROQ_CHAT_MODELS]
def get_token_limit(self, model_name: str) -> int:
def get_token_limit(self, model_name: GroqModelName) -> int:
"""Get the token limit for a given model."""
return GROQ_CHAT_MODELS[model_name].max_tokens
@classmethod
def get_tokenizer(cls, model_name: GroqModelName) -> ModelTokenizer:
def get_tokenizer(self, model_name: GroqModelName) -> ModelTokenizer[Any]:
# HACK: No official tokenizer is available for Groq
return tiktoken.encoding_for_model("gpt-3.5-turbo")
@classmethod
def count_tokens(cls, text: str, model_name: GroqModelName) -> int:
return len(cls.get_tokenizer(model_name).encode(text))
def count_tokens(self, text: str, model_name: GroqModelName) -> int:
return len(self.get_tokenizer(model_name).encode(text))
@classmethod
def count_message_tokens(
cls,
self,
messages: ChatMessage | list[ChatMessage],
model_name: GroqModelName,
) -> int:
@@ -174,7 +174,7 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
messages = [messages]
# HACK: No official tokenizer (for text or messages) is available for Groq.
# Token overhead of messages is unknown and may be inaccurate.
return cls.count_tokens(
return self.count_tokens(
"\n\n".join(f"{m.role.upper()}: {m.content}" for m in messages), model_name
)
@@ -191,7 +191,6 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
"""Create a completion using the Groq API."""
groq_messages, completion_kwargs = self._get_chat_completion_args(
prompt_messages=model_prompt,
model=model_name,
functions=functions,
max_output_tokens=max_output_tokens,
**kwargs,
@@ -202,7 +201,8 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
while True:
completion_kwargs["messages"] = groq_messages.copy()
_response, _cost, t_input, t_output = await self._create_chat_completion(
completion_kwargs
model=model_name,
completion_kwargs=completion_kwargs,
)
total_cost += _cost
@@ -221,7 +221,7 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
parse_errors += validate_tool_calls(tool_calls, functions)
assistant_msg = AssistantChatMessage(
content=_assistant_msg.content,
content=_assistant_msg.content or "",
tool_calls=tool_calls or None,
)
@@ -240,7 +240,7 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
return ChatModelResponse(
response=AssistantChatMessage(
content=_assistant_msg.content,
content=_assistant_msg.content or "",
tool_calls=tool_calls or None,
),
parsed_result=parsed_result,
@@ -266,7 +266,9 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
)
if attempts < self._configuration.fix_failed_parse_tries:
groq_messages.append(_assistant_msg.dict(exclude_none=True))
groq_messages.append(
_assistant_msg.dict(exclude_none=True) # type: ignore
)
groq_messages.append(
{
"role": "system",
@@ -282,7 +284,6 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
def _get_chat_completion_args(
self,
prompt_messages: list[ChatMessage],
model: GroqModelName,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
**kwargs, # type: ignore
@@ -291,7 +292,6 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
Args:
model_prompt: List of ChatMessages.
model_name: The model to use.
functions: Optional list of functions available to the LLM.
kwargs: Additional keyword arguments.
@@ -300,13 +300,13 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
dict[str, Any]: Any other kwargs for the OpenAI call
"""
kwargs: CompletionCreateParams = kwargs # type: ignore
kwargs["model"] = model
if max_output_tokens:
kwargs["max_tokens"] = max_output_tokens
if functions:
kwargs["tools"] = [
{"type": "function", "function": f.schema} for f in functions
{"type": "function", "function": format_function_def_for_openai(f)}
for f in functions
]
if len(functions) == 1:
# force the model to call the only specified function
@@ -321,7 +321,7 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
kwargs["extra_headers"].update(extra_headers.copy()) # type: ignore
groq_messages: list[ChatCompletionMessageParam] = [
message.dict(
message.dict( # type: ignore
include={"role", "content", "tool_calls", "tool_call_id", "name"},
exclude_none=True,
)
@@ -335,7 +335,7 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
return groq_messages, kwargs
async def _create_chat_completion(
self, completion_kwargs: CompletionCreateParams
self, model: GroqModelName, completion_kwargs: CompletionCreateParams
) -> tuple[ChatCompletion, float, int, int]:
"""
Create a chat completion using the Groq API with retry handling.
@@ -351,24 +351,31 @@ class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
"""
@self._retry_api_request
async def _create_chat_completion_with_retry(
completion_kwargs: CompletionCreateParams,
) -> ChatCompletion:
return await self._client.chat.completions.create(**completion_kwargs)
async def _create_chat_completion_with_retry() -> ChatCompletion:
return await self._client.chat.completions.create(
model=model, **completion_kwargs # type: ignore
)
response = await _create_chat_completion_with_retry(completion_kwargs)
response = await _create_chat_completion_with_retry()
cost = self._budget.update_usage_and_cost(
model_info=GROQ_CHAT_MODELS[completion_kwargs["model"]],
input_tokens_used=response.usage.prompt_tokens,
output_tokens_used=response.usage.completion_tokens,
)
return (
response,
cost,
response.usage.prompt_tokens,
response.usage.completion_tokens,
)
if not response.usage:
self._logger.warning(
"Groq chat completion response does not contain a usage field",
response,
)
return response, 0, 0, 0
else:
cost = self._budget.update_usage_and_cost(
model_info=GROQ_CHAT_MODELS[model],
input_tokens_used=response.usage.prompt_tokens,
output_tokens_used=response.usage.completion_tokens,
)
return (
response,
cost,
response.usage.prompt_tokens,
response.usage.completion_tokens,
)
def _parse_assistant_tool_calls(
self, assistant_message: ChatCompletionMessage, compat_mode: bool = False

View File

@@ -1,20 +1,18 @@
from __future__ import annotations
import logging
from typing import Callable, Iterator, Optional, TypeVar
from typing import Any, Callable, Iterator, Optional, TypeVar
from pydantic import ValidationError
from forge.models.config import Configurable
from .anthropic import ANTHROPIC_CHAT_MODELS, AnthropicModelName, AnthropicProvider
from .groq import GROQ_CHAT_MODELS, GroqModelName, GroqProvider
from .openai import OPEN_AI_CHAT_MODELS, OpenAIModelName, OpenAIProvider
from .schema import (
AssistantChatMessage,
BaseChatModelProvider,
ChatMessage,
ChatModelInfo,
ChatModelProvider,
ChatModelResponse,
CompletionModelFunction,
ModelProviderBudget,
@@ -27,11 +25,12 @@ from .schema import (
_T = TypeVar("_T")
ModelName = AnthropicModelName | GroqModelName | OpenAIModelName
EmbeddingModelProvider = OpenAIProvider
CHAT_MODELS = {**ANTHROPIC_CHAT_MODELS, **GROQ_CHAT_MODELS, **OPEN_AI_CHAT_MODELS}
class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
class MultiProvider(BaseChatModelProvider[ModelName, ModelProviderSettings]):
default_settings = ModelProviderSettings(
name="multi_provider",
description=(
@@ -57,7 +56,7 @@ class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
self._provider_instances = {}
async def get_available_models(self) -> list[ChatModelInfo]:
async def get_available_models(self) -> list[ChatModelInfo[ModelName]]:
models = []
for provider in self.get_available_providers():
models.extend(await provider.get_available_models())
@@ -65,24 +64,25 @@ class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
def get_token_limit(self, model_name: ModelName) -> int:
"""Get the token limit for a given model."""
return self.get_model_provider(model_name).get_token_limit(model_name)
@classmethod
def get_tokenizer(cls, model_name: ModelName) -> ModelTokenizer:
return cls._get_model_provider_class(model_name).get_tokenizer(model_name)
@classmethod
def count_tokens(cls, text: str, model_name: ModelName) -> int:
return cls._get_model_provider_class(model_name).count_tokens(
text=text, model_name=model_name
return self.get_model_provider(model_name).get_token_limit(
model_name # type: ignore
)
def get_tokenizer(self, model_name: ModelName) -> ModelTokenizer[Any]:
return self.get_model_provider(model_name).get_tokenizer(
model_name # type: ignore
)
def count_tokens(self, text: str, model_name: ModelName) -> int:
return self.get_model_provider(model_name).count_tokens(
text=text, model_name=model_name # type: ignore
)
@classmethod
def count_message_tokens(
cls, messages: ChatMessage | list[ChatMessage], model_name: ModelName
self, messages: ChatMessage | list[ChatMessage], model_name: ModelName
) -> int:
return cls._get_model_provider_class(model_name).count_message_tokens(
messages=messages, model_name=model_name
return self.get_model_provider(model_name).count_message_tokens(
messages=messages, model_name=model_name # type: ignore
)
async def create_chat_completion(
@@ -98,7 +98,7 @@ class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
"""Create a completion using the Anthropic API."""
return await self.get_model_provider(model_name).create_chat_completion(
model_prompt=model_prompt,
model_name=model_name,
model_name=model_name, # type: ignore
completion_parser=completion_parser,
functions=functions,
max_output_tokens=max_output_tokens,
@@ -136,17 +136,11 @@ class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
) from e
self._provider_instances[provider_name] = _provider = Provider(
settings=settings, logger=self._logger
settings=settings, logger=self._logger # type: ignore
)
_provider._budget = self._budget # Object binding not preserved by Pydantic
return _provider
@classmethod
def _get_model_provider_class(
cls, model_name: ModelName
) -> type[AnthropicProvider | GroqProvider | OpenAIProvider]:
return cls._get_provider_class(CHAT_MODELS[model_name].provider_name)
@classmethod
def _get_provider_class(
cls, provider_name: ModelProviderName
@@ -162,3 +156,6 @@ class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
def __repr__(self):
return f"{self.__class__.__name__}()"
ChatModelProvider = AnthropicProvider | GroqProvider | OpenAIProvider | MultiProvider

View File

@@ -2,7 +2,16 @@ import enum
import logging
import os
from pathlib import Path
from typing import Any, Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar
from typing import (
Any,
Callable,
Coroutine,
Iterator,
Optional,
ParamSpec,
TypeVar,
cast,
)
import sentry_sdk
import tenacity
@@ -12,9 +21,11 @@ from openai._exceptions import APIStatusError, RateLimitError
from openai.types import CreateEmbeddingResponse
from openai.types.chat import (
ChatCompletion,
ChatCompletionAssistantMessageParam,
ChatCompletionMessage,
ChatCompletionMessageParam,
)
from openai.types.shared_params import FunctionDefinition
from pydantic import SecretStr
from forge.json.parsing import json_loads
@@ -23,14 +34,14 @@ from forge.llm.providers.schema import (
AssistantFunctionCall,
AssistantToolCall,
AssistantToolCallDict,
BaseChatModelProvider,
BaseEmbeddingModelProvider,
ChatMessage,
ChatModelInfo,
ChatModelProvider,
ChatModelResponse,
CompletionModelFunction,
Embedding,
EmbeddingModelInfo,
EmbeddingModelProvider,
EmbeddingModelResponse,
ModelProviderBudget,
ModelProviderConfiguration,
@@ -39,7 +50,7 @@ from forge.llm.providers.schema import (
ModelProviderSettings,
ModelTokenizer,
)
from forge.models.config import Configurable, UserConfigurable
from forge.models.config import UserConfigurable
from forge.models.json_schema import JSONSchema
from .utils import validate_tool_calls
@@ -223,43 +234,46 @@ class OpenAIConfiguration(ModelProviderConfiguration):
class OpenAICredentials(ModelProviderCredentials):
"""Credentials for OpenAI."""
api_key: SecretStr = UserConfigurable(from_env="OPENAI_API_KEY")
api_key: SecretStr = UserConfigurable(from_env="OPENAI_API_KEY") # type: ignore
api_base: Optional[SecretStr] = UserConfigurable(
default=None, from_env="OPENAI_API_BASE_URL"
)
organization: Optional[SecretStr] = UserConfigurable(from_env="OPENAI_ORGANIZATION")
api_type: str = UserConfigurable(
default="",
from_env=lambda: (
api_type: Optional[SecretStr] = UserConfigurable(
default=None,
from_env=lambda: cast(
SecretStr | None,
"azure"
if os.getenv("USE_AZURE") == "True"
else os.getenv("OPENAI_API_TYPE")
else os.getenv("OPENAI_API_TYPE"),
),
)
api_version: str = UserConfigurable("", from_env="OPENAI_API_VERSION")
api_version: Optional[SecretStr] = UserConfigurable(
default=None, from_env="OPENAI_API_VERSION"
)
azure_endpoint: Optional[SecretStr] = None
azure_model_to_deploy_id_map: Optional[dict[str, str]] = None
def get_api_access_kwargs(self) -> dict[str, str]:
kwargs = {
k: (v.get_secret_value() if type(v) is SecretStr else v)
k: v.get_secret_value()
for k, v in {
"api_key": self.api_key,
"base_url": self.api_base,
"organization": self.organization,
"api_version": self.api_version,
}.items()
if v is not None
}
if self.api_type == "azure":
kwargs["api_version"] = self.api_version
if self.api_type == SecretStr("azure"):
assert self.azure_endpoint, "Azure endpoint not configured"
kwargs["azure_endpoint"] = self.azure_endpoint.get_secret_value()
return kwargs
def get_model_access_kwargs(self, model: str) -> dict[str, str]:
kwargs = {"model": model}
if self.api_type == "azure" and model:
if self.api_type == SecretStr("azure") and model:
azure_kwargs = self._get_azure_access_kwargs(model)
kwargs.update(azure_kwargs)
return kwargs
@@ -276,7 +290,7 @@ class OpenAICredentials(ModelProviderCredentials):
raise ValueError(*e.args)
self.api_type = config_params.get("azure_api_type", "azure")
self.api_version = config_params.get("azure_api_version", "")
self.api_version = config_params.get("azure_api_version", None)
self.azure_endpoint = config_params.get("azure_endpoint")
self.azure_model_to_deploy_id_map = config_params.get("azure_model_map")
@@ -294,13 +308,14 @@ class OpenAICredentials(ModelProviderCredentials):
class OpenAISettings(ModelProviderSettings):
configuration: OpenAIConfiguration
credentials: Optional[OpenAICredentials]
budget: ModelProviderBudget
configuration: OpenAIConfiguration # type: ignore
credentials: Optional[OpenAICredentials] # type: ignore
budget: ModelProviderBudget # type: ignore
class OpenAIProvider(
Configurable[OpenAISettings], ChatModelProvider, EmbeddingModelProvider
BaseChatModelProvider[OpenAIModelName, OpenAISettings],
BaseEmbeddingModelProvider[OpenAIModelName, OpenAISettings],
):
default_settings = OpenAISettings(
name="openai_provider",
@@ -329,37 +344,38 @@ class OpenAIProvider(
super(OpenAIProvider, self).__init__(settings=settings, logger=logger)
if self._credentials.api_type == "azure":
if self._credentials.api_type == SecretStr("azure"):
from openai import AsyncAzureOpenAI
# API key and org (if configured) are passed, the rest of the required
# credentials is loaded from the environment by the AzureOpenAI client.
self._client = AsyncAzureOpenAI(**self._credentials.get_api_access_kwargs())
self._client = AsyncAzureOpenAI(
**self._credentials.get_api_access_kwargs() # type: ignore
)
else:
from openai import AsyncOpenAI
self._client = AsyncOpenAI(**self._credentials.get_api_access_kwargs())
self._client = AsyncOpenAI(
**self._credentials.get_api_access_kwargs() # type: ignore
)
async def get_available_models(self) -> list[ChatModelInfo]:
async def get_available_models(self) -> list[ChatModelInfo[OpenAIModelName]]:
_models = (await self._client.models.list()).data
return [OPEN_AI_MODELS[m.id] for m in _models if m.id in OPEN_AI_MODELS]
def get_token_limit(self, model_name: str) -> int:
def get_token_limit(self, model_name: OpenAIModelName) -> int:
"""Get the token limit for a given model."""
return OPEN_AI_MODELS[model_name].max_tokens
@classmethod
def get_tokenizer(cls, model_name: OpenAIModelName) -> ModelTokenizer:
def get_tokenizer(self, model_name: OpenAIModelName) -> ModelTokenizer[int]:
return tiktoken.encoding_for_model(model_name)
@classmethod
def count_tokens(cls, text: str, model_name: OpenAIModelName) -> int:
encoding = cls.get_tokenizer(model_name)
def count_tokens(self, text: str, model_name: OpenAIModelName) -> int:
encoding = self.get_tokenizer(model_name)
return len(encoding.encode(text))
@classmethod
def count_message_tokens(
cls,
self,
messages: ChatMessage | list[ChatMessage],
model_name: OpenAIModelName,
) -> int:
@@ -447,7 +463,7 @@ class OpenAIProvider(
parse_errors += validate_tool_calls(tool_calls, functions)
assistant_msg = AssistantChatMessage(
content=_assistant_msg.content,
content=_assistant_msg.content or "",
tool_calls=tool_calls or None,
)
@@ -466,7 +482,7 @@ class OpenAIProvider(
return ChatModelResponse(
response=AssistantChatMessage(
content=_assistant_msg.content,
content=_assistant_msg.content or "",
tool_calls=tool_calls or None,
),
parsed_result=parsed_result,
@@ -492,7 +508,12 @@ class OpenAIProvider(
)
if attempts < self._configuration.fix_failed_parse_tries:
openai_messages.append(_assistant_msg.dict(exclude_none=True))
openai_messages.append(
cast(
ChatCompletionAssistantMessageParam,
_assistant_msg.dict(exclude_none=True),
)
)
openai_messages.append(
{
"role": "system",
@@ -522,7 +543,10 @@ class OpenAIProvider(
prompt_tokens_used=response.usage.prompt_tokens,
completion_tokens_used=0,
)
self._budget.update_usage_and_cost(response)
self._budget.update_usage_and_cost(
model_info=response.model_info,
input_tokens_used=response.prompt_tokens_used,
)
return response
def _get_chat_completion_args(
@@ -549,7 +573,8 @@ class OpenAIProvider(
if functions:
if OPEN_AI_CHAT_MODELS[model_name].has_function_call_api:
kwargs["tools"] = [
{"type": "function", "function": f.schema} for f in functions
{"type": "function", "function": format_function_def_for_openai(f)}
for f in functions
]
if len(functions) == 1:
# force the model to call the only specified function
@@ -569,10 +594,13 @@ class OpenAIProvider(
model_prompt += kwargs["messages"]
del kwargs["messages"]
openai_messages: list[ChatCompletionMessageParam] = [
message.dict(
include={"role", "content", "tool_calls", "name"},
exclude_none=True,
openai_messages = [
cast(
ChatCompletionMessageParam,
message.dict(
include={"role", "content", "tool_calls", "name"},
exclude_none=True,
),
)
for message in model_prompt
]
@@ -655,7 +683,7 @@ class OpenAIProvider(
def _parse_assistant_tool_calls(
self, assistant_message: ChatCompletionMessage, compat_mode: bool = False
):
) -> tuple[list[AssistantToolCall], list[Exception]]:
tool_calls: list[AssistantToolCall] = []
parse_errors: list[Exception] = []
@@ -749,6 +777,24 @@ class OpenAIProvider(
return "OpenAIProvider()"
def format_function_def_for_openai(self: CompletionModelFunction) -> FunctionDefinition:
"""Returns an OpenAI-consumable function definition"""
return {
"name": self.name,
"description": self.description,
"parameters": {
"type": "object",
"properties": {
name: param.to_dict() for name, param in self.parameters.items()
},
"required": [
name for name, param in self.parameters.items() if param.required
],
},
}
def format_function_specs_as_typescript_ns(
functions: list[CompletionModelFunction],
) -> str:
@@ -888,6 +934,4 @@ def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCal
for t in tool_calls:
t["id"] = str(uuid.uuid4())
t["function"]["arguments"] = str(t["function"]["arguments"]) # HACK
yield AssistantToolCall.parse_obj(t)

View File

@@ -19,14 +19,17 @@ from typing import (
from pydantic import BaseModel, Field, SecretStr, validator
from forge.logging.utils import fmt_kwargs
from forge.models.config import SystemConfiguration, UserConfigurable
from forge.models.config import (
Configurable,
SystemConfiguration,
SystemSettings,
UserConfigurable,
)
from forge.models.json_schema import JSONSchema
from forge.models.providers import (
Embedding,
ProviderBudget,
ProviderCredentials,
ProviderSettings,
ProviderUsage,
ResourceType,
)
@@ -34,6 +37,11 @@ if TYPE_CHECKING:
from jsonschema import ValidationError
_T = TypeVar("_T")
_ModelName = TypeVar("_ModelName", bound=str)
class ModelProviderService(str, enum.Enum):
"""A ModelService describes what kind of service the model provides."""
@@ -102,13 +110,13 @@ class AssistantToolCallDict(TypedDict):
class AssistantChatMessage(ChatMessage):
role: Literal[ChatMessage.Role.ASSISTANT] = ChatMessage.Role.ASSISTANT
content: Optional[str]
role: Literal[ChatMessage.Role.ASSISTANT] = ChatMessage.Role.ASSISTANT # type: ignore # noqa
content: str = ""
tool_calls: Optional[list[AssistantToolCall]] = None
class ToolResultMessage(ChatMessage):
role: Literal[ChatMessage.Role.TOOL] = ChatMessage.Role.TOOL
role: Literal[ChatMessage.Role.TOOL] = ChatMessage.Role.TOOL # type: ignore
is_error: bool = False
tool_call_id: str
@@ -126,32 +134,6 @@ class CompletionModelFunction(BaseModel):
description: str
parameters: dict[str, "JSONSchema"]
@property
def schema(self) -> dict[str, str | dict | list]:
"""Returns an OpenAI-consumable function specification"""
return {
"name": self.name,
"description": self.description,
"parameters": {
"type": "object",
"properties": {
name: param.to_dict() for name, param in self.parameters.items()
},
"required": [
name for name, param in self.parameters.items() if param.required
],
},
}
@staticmethod
def parse(schema: dict) -> "CompletionModelFunction":
return CompletionModelFunction(
name=schema["name"],
description=schema["description"],
parameters=JSONSchema.parse_properties(schema["parameters"]),
)
def fmt_line(self) -> str:
params = ", ".join(
f"{name}{'?' if not p.required else ''}: " f"{p.typescript_type}"
@@ -184,15 +166,15 @@ class CompletionModelFunction(BaseModel):
return params_schema.validate_object(function_call.arguments)
class ModelInfo(BaseModel):
class ModelInfo(BaseModel, Generic[_ModelName]):
"""Struct for model information.
Would be lovely to eventually get this directly from APIs, but needs to be
scraped from websites for now.
"""
name: str
service: ModelProviderService
name: _ModelName
service: ClassVar[ModelProviderService]
provider_name: ModelProviderName
prompt_token_cost: float = 0.0
completion_token_cost: float = 0.0
@@ -220,27 +202,39 @@ class ModelProviderCredentials(ProviderCredentials):
api_version: SecretStr | None = UserConfigurable(default=None)
deployment_id: SecretStr | None = UserConfigurable(default=None)
class Config:
class Config(ProviderCredentials.Config):
extra = "ignore"
class ModelProviderUsage(ProviderUsage):
class ModelProviderUsage(BaseModel):
"""Usage for a particular model from a model provider."""
completion_tokens: int = 0
prompt_tokens: int = 0
class ModelUsage(BaseModel):
completion_tokens: int = 0
prompt_tokens: int = 0
usage_per_model: dict[str, ModelUsage] = defaultdict(ModelUsage)
@property
def completion_tokens(self) -> int:
return sum(model.completion_tokens for model in self.usage_per_model.values())
@property
def prompt_tokens(self) -> int:
return sum(model.prompt_tokens for model in self.usage_per_model.values())
def update_usage(
self,
model: str,
input_tokens_used: int,
output_tokens_used: int = 0,
) -> None:
self.prompt_tokens += input_tokens_used
self.completion_tokens += output_tokens_used
self.usage_per_model[model].prompt_tokens += input_tokens_used
self.usage_per_model[model].completion_tokens += output_tokens_used
class ModelProviderBudget(ProviderBudget):
usage: defaultdict[str, ModelProviderUsage] = defaultdict(ModelProviderUsage)
class ModelProviderBudget(ProviderBudget[ModelProviderUsage]):
usage: ModelProviderUsage = Field(default_factory=ModelProviderUsage)
def update_usage_and_cost(
self,
@@ -253,7 +247,7 @@ class ModelProviderBudget(ProviderBudget):
Returns:
float: The (calculated) cost of the given model response.
"""
self.usage[model_info.name].update_usage(input_tokens_used, output_tokens_used)
self.usage.update_usage(model_info.name, input_tokens_used, output_tokens_used)
incurred_cost = (
output_tokens_used * model_info.completion_token_cost
+ input_tokens_used * model_info.prompt_token_cost
@@ -263,28 +257,33 @@ class ModelProviderBudget(ProviderBudget):
return incurred_cost
class ModelProviderSettings(ProviderSettings):
resource_type: ResourceType = ResourceType.MODEL
class ModelProviderSettings(SystemSettings):
resource_type: ClassVar[ResourceType] = ResourceType.MODEL
configuration: ModelProviderConfiguration
credentials: Optional[ModelProviderCredentials] = None
budget: Optional[ModelProviderBudget] = None
class ModelProvider(abc.ABC):
_ModelProviderSettings = TypeVar("_ModelProviderSettings", bound=ModelProviderSettings)
# TODO: either use MultiProvider throughout codebase as type for `llm_provider`, or
# replace `_ModelName` by `str` to eliminate type checking difficulties
class BaseModelProvider(
abc.ABC,
Generic[_ModelName, _ModelProviderSettings],
Configurable[_ModelProviderSettings],
):
"""A ModelProvider abstracts the details of a particular provider of models."""
default_settings: ClassVar[ModelProviderSettings]
_settings: ModelProviderSettings
_configuration: ModelProviderConfiguration
_credentials: Optional[ModelProviderCredentials] = None
_budget: Optional[ModelProviderBudget] = None
default_settings: ClassVar[_ModelProviderSettings] # type: ignore
_settings: _ModelProviderSettings
_logger: logging.Logger
def __init__(
self,
settings: Optional[ModelProviderSettings] = None,
settings: Optional[_ModelProviderSettings] = None,
logger: Optional[logging.Logger] = None,
):
if not settings:
@@ -298,15 +297,15 @@ class ModelProvider(abc.ABC):
self._logger = logger or logging.getLogger(self.__module__)
@abc.abstractmethod
def count_tokens(self, text: str, model_name: str) -> int:
def count_tokens(self, text: str, model_name: _ModelName) -> int:
...
@abc.abstractmethod
def get_tokenizer(self, model_name: str) -> "ModelTokenizer":
def get_tokenizer(self, model_name: _ModelName) -> "ModelTokenizer[Any]":
...
@abc.abstractmethod
def get_token_limit(self, model_name: str) -> int:
def get_token_limit(self, model_name: _ModelName) -> int:
...
def get_incurred_cost(self) -> float:
@@ -320,15 +319,15 @@ class ModelProvider(abc.ABC):
return math.inf
class ModelTokenizer(Protocol):
class ModelTokenizer(Protocol, Generic[_T]):
"""A ModelTokenizer provides tokenization specific to a model."""
@abc.abstractmethod
def encode(self, text: str) -> list:
def encode(self, text: str) -> list[_T]:
...
@abc.abstractmethod
def decode(self, tokens: list) -> str:
def decode(self, tokens: list[_T]) -> str:
...
@@ -337,10 +336,10 @@ class ModelTokenizer(Protocol):
####################
class EmbeddingModelInfo(ModelInfo):
class EmbeddingModelInfo(ModelInfo[_ModelName]):
"""Struct for embedding model information."""
service: Literal[ModelProviderService.EMBEDDING] = ModelProviderService.EMBEDDING
service = ModelProviderService.EMBEDDING
max_tokens: int
embedding_dimensions: int
@@ -350,20 +349,19 @@ class EmbeddingModelResponse(ModelResponse):
embedding: Embedding = Field(default_factory=list)
@classmethod
@validator("completion_tokens_used")
def _verify_no_completion_tokens_used(cls, v):
def _verify_no_completion_tokens_used(cls, v: int):
if v > 0:
raise ValueError("Embeddings should not have completion tokens used.")
return v
class EmbeddingModelProvider(ModelProvider):
class BaseEmbeddingModelProvider(BaseModelProvider[_ModelName, _ModelProviderSettings]):
@abc.abstractmethod
async def create_embedding(
self,
text: str,
model_name: str,
model_name: _ModelName,
embedding_parser: Callable[[Embedding], Embedding],
**kwargs,
) -> EmbeddingModelResponse:
@@ -375,34 +373,31 @@ class EmbeddingModelProvider(ModelProvider):
###############
class ChatModelInfo(ModelInfo):
class ChatModelInfo(ModelInfo[_ModelName]):
"""Struct for language model information."""
service: Literal[ModelProviderService.CHAT] = ModelProviderService.CHAT
service = ModelProviderService.CHAT
max_tokens: int
has_function_call_api: bool = False
_T = TypeVar("_T")
class ChatModelResponse(ModelResponse, Generic[_T]):
"""Standard response struct for a response from a language model."""
response: AssistantChatMessage
parsed_result: _T = None
parsed_result: _T
class ChatModelProvider(ModelProvider):
class BaseChatModelProvider(BaseModelProvider[_ModelName, _ModelProviderSettings]):
@abc.abstractmethod
async def get_available_models(self) -> list[ChatModelInfo]:
async def get_available_models(self) -> list[ChatModelInfo[_ModelName]]:
...
@abc.abstractmethod
def count_message_tokens(
self,
messages: ChatMessage | list[ChatMessage],
model_name: str,
model_name: _ModelName,
) -> int:
...
@@ -410,7 +405,7 @@ class ChatModelProvider(ModelProvider):
async def create_chat_completion(
self,
model_prompt: list[ChatMessage],
model_name: str,
model_name: _ModelName,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,

View File

@@ -8,7 +8,6 @@ import sys
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from colorama import Fore, Style
from openai._base_client import log as openai_logger
from forge.models.config import SystemConfiguration, UserConfigurable
@@ -65,7 +64,7 @@ class LoggingConfig(SystemConfiguration):
log_dir: Path = LOG_DIR
log_file_format: Optional[LogFormatName] = UserConfigurable(
default=LogFormatName.SIMPLE,
from_env=lambda: os.getenv(
from_env=lambda: os.getenv( # type: ignore
"LOG_FILE_FORMAT", os.getenv("LOG_FORMAT", "simple")
),
)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, TypeVar
from pydantic import BaseModel
@@ -11,7 +11,10 @@ from .utils import ModelWithSummary
class ActionProposal(BaseModel):
thoughts: str | ModelWithSummary
use_tool: AssistantFunctionCall = None
use_tool: AssistantFunctionCall
AnyProposal = TypeVar("AnyProposal", bound=ActionProposal)
class ActionSuccessResult(BaseModel):

View File

@@ -1,4 +1,3 @@
import abc
import os
import typing
from typing import Any, Callable, Generic, Optional, Type, TypeVar, get_args
@@ -85,11 +84,11 @@ class SystemSettings(BaseModel):
S = TypeVar("S", bound=SystemSettings)
class Configurable(abc.ABC, Generic[S]):
class Configurable(Generic[S]):
"""A base class for all configurable objects."""
prefix: str = ""
default_settings: typing.ClassVar[S]
default_settings: typing.ClassVar[S] # type: ignore
@classmethod
def get_user_config(cls) -> dict[str, Any]:

View File

@@ -1,6 +1,6 @@
import enum
from textwrap import indent
from typing import Optional
from typing import Optional, overload
from jsonschema import Draft7Validator, ValidationError
from pydantic import BaseModel
@@ -57,30 +57,8 @@ class JSONSchema(BaseModel):
@staticmethod
def from_dict(schema: dict) -> "JSONSchema":
def resolve_references(schema: dict, definitions: dict) -> dict:
"""
Recursively resolve type $refs in the JSON schema with their definitions.
"""
if isinstance(schema, dict):
if "$ref" in schema:
ref_path = schema["$ref"].split("/")[
2:
] # Split and remove '#/definitions'
ref_value = definitions
for key in ref_path:
ref_value = ref_value[key]
return resolve_references(ref_value, definitions)
else:
return {
k: resolve_references(v, definitions) for k, v in schema.items()
}
elif isinstance(schema, list):
return [resolve_references(item, definitions) for item in schema]
else:
return schema
definitions = schema.get("definitions", {})
schema = resolve_references(schema, definitions)
schema = _resolve_type_refs_in_schema(schema, definitions)
return JSONSchema(
description=schema.get("description"),
@@ -147,21 +125,55 @@ class JSONSchema(BaseModel):
@property
def typescript_type(self) -> str:
if not self.type:
return "any"
if self.type == JSONSchema.Type.BOOLEAN:
return "boolean"
elif self.type in {JSONSchema.Type.INTEGER, JSONSchema.Type.NUMBER}:
if self.type in {JSONSchema.Type.INTEGER, JSONSchema.Type.NUMBER}:
return "number"
elif self.type == JSONSchema.Type.STRING:
if self.type == JSONSchema.Type.STRING:
return "string"
elif self.type == JSONSchema.Type.ARRAY:
if self.type == JSONSchema.Type.ARRAY:
return f"Array<{self.items.typescript_type}>" if self.items else "Array"
elif self.type == JSONSchema.Type.OBJECT:
if self.type == JSONSchema.Type.OBJECT:
if not self.properties:
return "Record<string, any>"
return self.to_typescript_object_interface()
elif self.enum:
if self.enum:
return " | ".join(repr(v) for v in self.enum)
raise NotImplementedError(
f"JSONSchema.typescript_type does not support Type.{self.type.name} yet"
)
@overload
def _resolve_type_refs_in_schema(schema: dict, definitions: dict) -> dict:
...
@overload
def _resolve_type_refs_in_schema(schema: list, definitions: dict) -> list:
...
def _resolve_type_refs_in_schema(schema: dict | list, definitions: dict) -> dict | list:
"""
Recursively resolve type $refs in the JSON schema with their definitions.
"""
if isinstance(schema, dict):
if "$ref" in schema:
ref_path = schema["$ref"].split("/")[2:] # Split and remove '#/definitions'
ref_value = definitions
for key in ref_path:
ref_value = ref_value[key]
return _resolve_type_refs_in_schema(ref_value, definitions)
else:
raise NotImplementedError(
f"JSONSchema.typescript_type does not support Type.{self.type.name} yet"
)
return {
k: _resolve_type_refs_in_schema(v, definitions)
for k, v in schema.items()
}
elif isinstance(schema, list):
return [_resolve_type_refs_in_schema(item, definitions) for item in schema]
else:
return schema

View File

@@ -1,31 +1,26 @@
import abc
import enum
import math
from typing import Callable, Generic, TypeVar
from pydantic import BaseModel, SecretBytes, SecretField, SecretStr
from forge.models.config import SystemConfiguration, SystemSettings, UserConfigurable
from forge.models.config import SystemConfiguration, UserConfigurable
_T = TypeVar("_T")
class ResourceType(str, enum.Enum):
"""An enumeration of resource types."""
MODEL = "model"
MEMORY = "memory"
class ProviderUsage(SystemConfiguration, abc.ABC):
@abc.abstractmethod
def update_usage(self, *args, **kwargs) -> None:
"""Update the usage of the resource."""
...
class ProviderBudget(SystemConfiguration):
class ProviderBudget(SystemConfiguration, Generic[_T]):
total_budget: float = UserConfigurable(math.inf)
total_cost: float = 0
remaining_budget: float = math.inf
usage: ProviderUsage
usage: _T
@abc.abstractmethod
def update_usage_and_cost(self, *args, **kwargs) -> float:
@@ -43,8 +38,8 @@ class ProviderCredentials(SystemConfiguration):
def unmasked(self) -> dict:
return unmask(self)
class Config:
json_encoders = {
class Config(SystemConfiguration.Config):
json_encoders: dict[type[SecretField], Callable[[SecretField], str | None]] = {
SecretStr: lambda v: v.get_secret_value() if v else None,
SecretBytes: lambda v: v.get_secret_value() if v else None,
SecretField: lambda v: v.get_secret_value() if v else None,
@@ -62,11 +57,5 @@ def unmask(model: BaseModel):
return unmasked_fields
class ProviderSettings(SystemSettings):
resource_type: ResourceType
credentials: ProviderCredentials | None = None
budget: ProviderBudget | None = None
# Used both by model providers and memory providers
Embedding = list[float]

View File

@@ -1,2 +1,4 @@
"""This module contains the speech recognition and speech synthesis functions."""
"""This module contains the (speech recognition and) speech synthesis functions."""
from .say import TextToSpeechProvider, TTSConfig
__all__ = ["TextToSpeechProvider", "TTSConfig"]

View File

@@ -45,7 +45,7 @@ class VoiceBase:
"""
@abc.abstractmethod
def _speech(self, text: str, voice_index: int = 0) -> bool:
def _speech(self, text: str, voice_id: int = 0) -> bool:
"""
Play the given text.

View File

@@ -66,7 +66,7 @@ class ElevenLabsSpeech(VoiceBase):
if voice and voice not in PLACEHOLDERS:
self._voices[voice_index] = voice
def _speech(self, text: str, voice_index: int = 0) -> bool:
def _speech(self, text: str, voice_id: int = 0) -> bool:
"""Speak text using elevenlabs.io's API
Args:
@@ -77,7 +77,7 @@ class ElevenLabsSpeech(VoiceBase):
bool: True if the request was successful, False otherwise
"""
tts_url = (
f"https://api.elevenlabs.io/v1/text-to-speech/{self._voices[voice_index]}"
f"https://api.elevenlabs.io/v1/text-to-speech/{self._voices[voice_id]}"
)
response = requests.post(tts_url, headers=self._headers, json={"text": text})

View File

@@ -15,7 +15,7 @@ class GTTSVoice(VoiceBase):
def _setup(self) -> None:
pass
def _speech(self, text: str, _: int = 0) -> bool:
def _speech(self, text: str, voice_id: int = 0) -> bool:
"""Play the given text."""
tts = gtts.gTTS(text)
tts.save("speech.mp3")

View File

@@ -12,11 +12,11 @@ class MacOSTTS(VoiceBase):
def _setup(self) -> None:
pass
def _speech(self, text: str, voice_index: int = 0) -> bool:
def _speech(self, text: str, voice_id: int = 0) -> bool:
"""Play the given text."""
if voice_index == 0:
if voice_id == 0:
subprocess.run(["say", text], shell=False)
elif voice_index == 1:
elif voice_id == 1:
subprocess.run(["say", "-v", "Ava (Premium)", text], shell=False)
else:
subprocess.run(["say", "-v", "Samantha", text], shell=False)

View File

@@ -24,8 +24,7 @@ class StreamElementsSpeech(VoiceBase):
"""Setup the voices, API key, etc."""
self.config = config
def _speech(self, text: str, voice: str, _: int = 0) -> bool:
voice = self.config.voice
def _speech(self, text: str, voice_id: int = 0) -> bool:
"""Speak text using the streamelements API
Args:
@@ -35,6 +34,7 @@ class StreamElementsSpeech(VoiceBase):
Returns:
bool: True if the request was successful, False otherwise
"""
voice = self.config.voice
tts_url = (
f"https://api.streamelements.com/kappa/v2/speech?voice={voice}&text={text}"
)

View File

@@ -7,7 +7,7 @@ from typing import Optional
def get_exception_message():
"""Get current exception type and message."""
exc_type, exc_value, _ = sys.exc_info()
exception_message = f"{exc_type.__name__}: {exc_value}"
exception_message = f"{exc_type.__name__}: {exc_value}" if exc_type else exc_value
return exception_message

View File

@@ -25,7 +25,7 @@ class TXTParser(ParserStrategy):
charset_match = charset_normalizer.from_bytes(file.read()).best()
logger.debug(
f"Reading {getattr(file, 'name', 'file')} "
f"with encoding '{charset_match.encoding}'"
f"with encoding '{charset_match.encoding if charset_match else None}'"
)
return str(charset_match)

View File

@@ -39,7 +39,7 @@ def validate_url(func: Callable[P, T]) -> Callable[P, T]:
return func(*bound_args.args, **bound_args.kwargs)
return wrapper
return wrapper # type: ignore
def is_valid_url(url: str) -> bool:

View File

@@ -1,13 +0,0 @@
[mypy]
namespace_packages = True
follow_imports = skip
check_untyped_defs = True
disallow_untyped_defs = True
exclude = ^(agbenchmark/challenges/|agent/|venv|venv-dev)
ignore_missing_imports = True
[mypy-agbenchmark.utils.data_types.*]
ignore_errors = True
[mypy-numpy.*]
ignore_errors = True

759
forge/poetry.lock generated

File diff suppressed because one or more lines are too long

View File

@@ -26,8 +26,10 @@ duckduckgo-search = "^5.0.0"
fastapi = "^0.109.1"
gitpython = "^3.1.32"
google-api-python-client = "*"
google-cloud-logging = "^3.8.0"
google-cloud-storage = "^2.13.0"
groq = "^0.8.0"
gTTS = "^2.3.1"
jinja2 = "^3.1.2"
jsonschema = "*"
litellm = "^1.17.9"
@@ -57,16 +59,17 @@ webdriver-manager = "^4.0.1"
benchmark = ["agbenchmark"]
[tool.poetry.group.dev.dependencies]
isort = "^5.12.0"
black = "^23.3.0"
black = "^23.12.1"
flake8 = "^7.0.0"
isort = "^5.13.1"
pyright = "^1.1.364"
pre-commit = "^3.3.3"
mypy = "^1.4.1"
flake8 = "^6.0.0"
boto3-stubs = { extras = ["s3"], version = "^1.33.6" }
types-requests = "^2.31.0.2"
pytest = "^7.4.0"
pytest-asyncio = "^0.21.1"
pytest-cov = "^5.0.0"
mock = "^5.1.0"
autoflake = "^2.2.0"
pydevd-pycharm = "^233.6745.319"
@@ -74,20 +77,21 @@ pydevd-pycharm = "^233.6745.319"
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.black]
line-length = 88
target-version = ['py310']
include = '\.pyi?$'
packages = ["forge"]
extend-exclude = '(/dist|/.venv|/venv|/build|/agent|agbenchmark/challenges)/'
[tool.isort]
profile = "black"
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 88
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
skip_glob = [".tox", "__pycache__", "*.pyc", "venv*/*", "reports", "venv", "env", "node_modules", ".env", ".venv", "dist", "agent/*", "agbenchmark/challenges/*"]
[tool.pyright]
pythonVersion = "3.10"
[tool.pytest.ini_options]
pythonpath = ["forge"]
testpaths = ["forge", "tests"]