feat(agent): Persist AgentContext in Agent.state (#7125)

Persist the agent's `AgentContext` so that it works in rehydrated agent instances. This makes context usable in the `AgentProtocolServer`, where the agent instance is loaded and destroyed for every step.

- Make `AgentContext` a Pydantic model
- Add `context` parameter to `ContextComponent.__init__` so we can pass in an existing instance
- Add `context: AgentContext` to `AgentSettings` so it is persisted
- Add `type` attribute to `ContextItem` implementations as a discriminator
- Rename `ContextItem` base class to `BaseContextItem` and make new `ContextItem` type alias (union of the implementation types)
This commit is contained in:
Reinier van der Leer
2024-05-10 09:30:12 +02:00
committed by GitHub
parent 7e02cfdc9f
commit b0cbf711dc
3 changed files with 25 additions and 15 deletions

View File

@@ -50,7 +50,7 @@ from autogpt.utils.exceptions import (
from .base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
from .features.agent_file_manager import FileManagerComponent
from .features.context import ContextComponent
from .features.context import AgentContext, ContextComponent
from .features.watchdog import WatchdogComponent
from .prompt_strategies.one_shot import (
OneShotAgentActionProposal,
@@ -82,6 +82,8 @@ class AgentSettings(BaseAgentSettings):
)
"""(STATE) The action history of the agent."""
context: AgentContext = Field(default_factory=AgentContext)
class Agent(BaseAgent, Configurable[AgentSettings]):
default_settings: AgentSettings = AgentSettings(
@@ -132,7 +134,7 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
)
self.web_search = WebSearchComponent(legacy_config)
self.web_selenium = WebSeleniumComponent(legacy_config, llm_provider, self.llm)
self.context = ContextComponent(self.file_manager.workspace)
self.context = ContextComponent(self.file_manager.workspace, settings.context)
self.watchdog = WatchdogComponent(settings.config, settings.history)
self.created_at = datetime.now().strftime("%Y%m%d_%H%M%S")

View File

@@ -1,6 +1,9 @@
import contextlib
from pathlib import Path
from typing import Iterator, Optional
from typing import Iterator
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from autogpt.agents.protocols import CommandProvider, MessageProvider
from autogpt.command_decorator import command
@@ -12,11 +15,10 @@ from autogpt.models.context_item import ContextItem, FileContextItem, FolderCont
from autogpt.utils.exceptions import InvalidArgumentError
class AgentContext:
items: list[ContextItem]
def __init__(self, items: Optional[list[ContextItem]] = None):
self.items = items or []
class AgentContext(BaseModel):
items: list[Annotated[ContextItem, Field(discriminator="type")]] = Field(
default_factory=list
)
def __bool__(self) -> bool:
return len(self.items) > 0
@@ -42,8 +44,8 @@ class AgentContext:
class ContextComponent(MessageProvider, CommandProvider):
"""Adds ability to keep files and folders open in the context (prompt)."""
def __init__(self, workspace: FileStorage):
self.context = AgentContext()
def __init__(self, workspace: FileStorage, context: AgentContext):
self.context = context
self.workspace = workspace
def get_messages(self) -> Iterator[ChatMessage]:

View File

@@ -1,7 +1,7 @@
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional
from typing import Literal, Optional
from pydantic import BaseModel, Field
@@ -11,7 +11,7 @@ from autogpt.utils.file_operations_utils import decode_textual_file
logger = logging.getLogger(__name__)
class ContextItem(ABC):
class BaseContextItem(ABC):
@property
@abstractmethod
def description(self) -> str:
@@ -38,8 +38,9 @@ class ContextItem(ABC):
)
class FileContextItem(BaseModel, ContextItem):
class FileContextItem(BaseModel, BaseContextItem):
path: Path
type: Literal["file"] = "file"
@property
def description(self) -> str:
@@ -54,8 +55,9 @@ class FileContextItem(BaseModel, ContextItem):
return decode_textual_file(file, self.path.suffix, logger)
class FolderContextItem(BaseModel, ContextItem):
class FolderContextItem(BaseModel, BaseContextItem):
path: Path
type: Literal["folder"] = "folder"
@property
def description(self) -> str:
@@ -73,7 +75,11 @@ class FolderContextItem(BaseModel, ContextItem):
return "\n".join(items)
class StaticContextItem(BaseModel, ContextItem):
class StaticContextItem(BaseModel, BaseContextItem):
item_description: str = Field(alias="description")
item_source: Optional[str] = Field(alias="source")
item_content: str = Field(alias="content")
type: Literal["static"] = "static"
ContextItem = FileContextItem | FolderContextItem | StaticContextItem