fix(agent): Fix type propagation of Command and @command when used on methods (#7124)

This commit is contained in:
Reinier van der Leer
2024-05-09 11:39:09 +02:00
committed by GitHub
parent 34fdbaa26b
commit 7e02cfdc9f
2 changed files with 13 additions and 13 deletions

View File

@@ -1,21 +1,18 @@
import re
from typing import Callable, Optional, ParamSpec, TypeVar
from typing import Callable, Concatenate, Optional, TypeVar
from autogpt.agents.protocols import CommandProvider
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.models.command import Command, CommandOutput, CommandParameter
from autogpt.models.command import CO, Command, CommandParameter, P
# Unique identifier for AutoGPT commands
AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command"
P = ParamSpec("P")
CO = TypeVar("CO", bound=CommandOutput)
_CP = TypeVar("_CP", bound=CommandProvider)
def command(
names: list[str] = [],
description: Optional[str] = None,
parameters: dict[str, JSONSchema] = {},
) -> Callable[[Callable[P, CommandOutput]], Command]:
) -> Callable[[Callable[Concatenate[_CP, P], CO]], Command[P, CO]]:
"""
The command decorator is used to make a Command from a function.
@@ -29,7 +26,7 @@ def command(
that the command executes.
"""
def decorator(func: Callable[P, CO]) -> Command:
def decorator(func: Callable[Concatenate[_CP, P], CO]) -> Command[P, CO]:
doc = func.__doc__ or ""
# If names is not provided, use the function name
command_names = names or [func.__name__]

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import inspect
from typing import Any, Callable
from typing import Any, Callable, Generic, ParamSpec, TypeVar
from .command_parameter import CommandParameter
from .context_item import ContextItem
@@ -9,8 +9,11 @@ from .context_item import ContextItem
CommandReturnValue = Any
CommandOutput = CommandReturnValue | tuple[CommandReturnValue, ContextItem]
P = ParamSpec("P")
CO = TypeVar("CO", bound=CommandOutput)
class Command:
class Command(Generic[P, CO]):
"""A class representing a command.
Attributes:
@@ -23,7 +26,7 @@ class Command:
self,
names: list[str],
description: str,
method: Callable[..., CommandOutput],
method: Callable[P, CO],
parameters: list[CommandParameter],
):
# Check if all parameters are provided
@@ -55,7 +58,7 @@ class Command:
# Check if sorted lists of names/keys are equal
return sorted(func_param_names) == sorted(names)
def __call__(self, *args, **kwargs) -> Any:
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> CO:
return self.method(*args, **kwargs)
def __str__(self) -> str: