mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 15:17:59 -05:00
feat(agent): Persist AgentContext in Agent.state (#7125)
Persist the agent's `AgentContext` so that it works in rehydrated agent instances. This makes context usable in the `AgentProtocolServer`, where the agent instance is loaded and destroyed for every step. - Make `AgentContext` a Pydantic model - Add `context` parameter to `ContextComponent.__init__` so we can pass in an existing instance - Add `context: AgentContext` to `AgentSettings` so it is persisted - Add `type` attribute to `ContextItem` implementations as a discriminator - Rename `ContextItem` base class to `BaseContextItem` and make new `ContextItem` type alias (union of the implementation types)
This commit is contained in:
committed by
GitHub
parent
7e02cfdc9f
commit
b0cbf711dc
@@ -50,7 +50,7 @@ from autogpt.utils.exceptions import (
|
||||
|
||||
from .base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
|
||||
from .features.agent_file_manager import FileManagerComponent
|
||||
from .features.context import ContextComponent
|
||||
from .features.context import AgentContext, ContextComponent
|
||||
from .features.watchdog import WatchdogComponent
|
||||
from .prompt_strategies.one_shot import (
|
||||
OneShotAgentActionProposal,
|
||||
@@ -82,6 +82,8 @@ class AgentSettings(BaseAgentSettings):
|
||||
)
|
||||
"""(STATE) The action history of the agent."""
|
||||
|
||||
context: AgentContext = Field(default_factory=AgentContext)
|
||||
|
||||
|
||||
class Agent(BaseAgent, Configurable[AgentSettings]):
|
||||
default_settings: AgentSettings = AgentSettings(
|
||||
@@ -132,7 +134,7 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
|
||||
)
|
||||
self.web_search = WebSearchComponent(legacy_config)
|
||||
self.web_selenium = WebSeleniumComponent(legacy_config, llm_provider, self.llm)
|
||||
self.context = ContextComponent(self.file_manager.workspace)
|
||||
self.context = ContextComponent(self.file_manager.workspace, settings.context)
|
||||
self.watchdog = WatchdogComponent(settings.config, settings.history)
|
||||
|
||||
self.created_at = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import contextlib
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional
|
||||
from typing import Iterator
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from autogpt.agents.protocols import CommandProvider, MessageProvider
|
||||
from autogpt.command_decorator import command
|
||||
@@ -12,11 +15,10 @@ from autogpt.models.context_item import ContextItem, FileContextItem, FolderCont
|
||||
from autogpt.utils.exceptions import InvalidArgumentError
|
||||
|
||||
|
||||
class AgentContext:
|
||||
items: list[ContextItem]
|
||||
|
||||
def __init__(self, items: Optional[list[ContextItem]] = None):
|
||||
self.items = items or []
|
||||
class AgentContext(BaseModel):
|
||||
items: list[Annotated[ContextItem, Field(discriminator="type")]] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return len(self.items) > 0
|
||||
@@ -42,8 +44,8 @@ class AgentContext:
|
||||
class ContextComponent(MessageProvider, CommandProvider):
|
||||
"""Adds ability to keep files and folders open in the context (prompt)."""
|
||||
|
||||
def __init__(self, workspace: FileStorage):
|
||||
self.context = AgentContext()
|
||||
def __init__(self, workspace: FileStorage, context: AgentContext):
|
||||
self.context = context
|
||||
self.workspace = workspace
|
||||
|
||||
def get_messages(self) -> Iterator[ChatMessage]:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -11,7 +11,7 @@ from autogpt.utils.file_operations_utils import decode_textual_file
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContextItem(ABC):
|
||||
class BaseContextItem(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def description(self) -> str:
|
||||
@@ -38,8 +38,9 @@ class ContextItem(ABC):
|
||||
)
|
||||
|
||||
|
||||
class FileContextItem(BaseModel, ContextItem):
|
||||
class FileContextItem(BaseModel, BaseContextItem):
|
||||
path: Path
|
||||
type: Literal["file"] = "file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
@@ -54,8 +55,9 @@ class FileContextItem(BaseModel, ContextItem):
|
||||
return decode_textual_file(file, self.path.suffix, logger)
|
||||
|
||||
|
||||
class FolderContextItem(BaseModel, ContextItem):
|
||||
class FolderContextItem(BaseModel, BaseContextItem):
|
||||
path: Path
|
||||
type: Literal["folder"] = "folder"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
@@ -73,7 +75,11 @@ class FolderContextItem(BaseModel, ContextItem):
|
||||
return "\n".join(items)
|
||||
|
||||
|
||||
class StaticContextItem(BaseModel, ContextItem):
|
||||
class StaticContextItem(BaseModel, BaseContextItem):
|
||||
item_description: str = Field(alias="description")
|
||||
item_source: Optional[str] = Field(alias="source")
|
||||
item_content: str = Field(alias="content")
|
||||
type: Literal["static"] = "static"
|
||||
|
||||
|
||||
ContextItem = FileContextItem | FolderContextItem | StaticContextItem
|
||||
|
||||
Reference in New Issue
Block a user