mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Update pyright extends for core (#562)
* Update pyright extends for core * Fixes
This commit is contained in:
@@ -75,9 +75,10 @@ exclude = ["build", "dist", "src/autogen_core/application/protos"]
|
||||
include = ["src/**", "samples/*.py", "docs/**/*.ipynb", "tests/**"]
|
||||
|
||||
[tool.pyright]
|
||||
extend = "../../pyproject.toml"
|
||||
extends = "../../pyproject.toml"
|
||||
include = ["src", "tests", "samples"]
|
||||
exclude = ["src/autogen_core/application/protos"]
|
||||
reportDeprecated = false
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
minversion = "6.0"
|
||||
|
||||
@@ -9,10 +9,9 @@ from autogen_core.components import (
|
||||
message_handler,
|
||||
)
|
||||
from autogen_core.components.model_context import ChatCompletionContext
|
||||
from autogen_core.components.models import AssistantMessage, UserMessage
|
||||
from autogen_core.components.models import UserMessage
|
||||
|
||||
from ..types import (
|
||||
Message,
|
||||
MultiModalMessage,
|
||||
PublishNow,
|
||||
Reset,
|
||||
|
||||
@@ -7,7 +7,6 @@ from autogen_core.components.model_context import ChatCompletionContext
|
||||
from autogen_core.components.models import ChatCompletionClient, UserMessage
|
||||
|
||||
from ..types import (
|
||||
Message,
|
||||
MultiModalMessage,
|
||||
PublishNow,
|
||||
Reset,
|
||||
|
||||
@@ -8,7 +8,6 @@ from autogen_core.base import (
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
MessageContext,
|
||||
try_get_known_serializers_for_type,
|
||||
)
|
||||
from autogen_core.components import DefaultSubscription, DefaultTopicId, RoutedAgent, message_handler
|
||||
|
||||
|
||||
@@ -113,8 +113,7 @@ class HostConnection:
|
||||
if self._connection_task is None:
|
||||
raise RuntimeError("Connection is not open.")
|
||||
await self._channel.close()
|
||||
if self._connection_task is not None:
|
||||
await self._connection_task
|
||||
await self._connection_task
|
||||
|
||||
@staticmethod
|
||||
async def _connect( # type: ignore
|
||||
@@ -227,6 +226,8 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case None:
|
||||
logger.warning("No message")
|
||||
case other:
|
||||
logger.error(f"Unknown message type: {other}")
|
||||
except Exception as e:
|
||||
logger.error("Error in read loop", exc_info=e)
|
||||
|
||||
|
||||
@@ -125,6 +125,8 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
||||
logger.warning(f"Received unexpected message type: {oneofcase}")
|
||||
case None:
|
||||
logger.warning("Received empty message")
|
||||
case other:
|
||||
logger.error(f"Received unexpected message: {other}")
|
||||
|
||||
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None:
|
||||
# Deliver the message to a client given the target agent type.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Mapping, Protocol, runtime_checkable
|
||||
from typing import Any, Mapping, Protocol, runtime_checkable
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Awaitable, Callable, List, Mapping, Protocol, Type, TypeVar, overload, runtime_checkable
|
||||
from typing import Any, Awaitable, Callable, Mapping, Protocol, Type, TypeVar, overload, runtime_checkable
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
|
||||
@@ -4,8 +4,7 @@ import inspect
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from re import S
|
||||
from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Tuple, Type, TypeVar, overload
|
||||
from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Tuple, Type, TypeVar
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
@@ -18,7 +17,7 @@ from ._agent_type import AgentType
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._message_context import MessageContext
|
||||
from ._serialization import MessageSerializer, try_get_known_serializers_for_type
|
||||
from ._subscription import UnboundSubscription
|
||||
from ._subscription import Subscription, UnboundSubscription
|
||||
from ._subscription_context import SubscriptionInstantiationContext
|
||||
from ._topic import TopicId
|
||||
|
||||
@@ -30,7 +29,7 @@ BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent")
|
||||
# Decorator for adding an unbound subscription to an agent
|
||||
def subscription_factory(subscription: UnboundSubscription) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]:
|
||||
def decorator(cls: Type[BaseAgentType]) -> Type[BaseAgentType]:
|
||||
cls._unbound_subscriptions_list.append(subscription)
|
||||
cls.internal_unbound_subscriptions_list.append(subscription)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
@@ -48,29 +47,29 @@ def handles(
|
||||
if len(serializer_list) == 0:
|
||||
raise ValueError(f"No serializers found for type {type}. Please provide an explicit serializer.")
|
||||
|
||||
cls._extra_handles_types.append((type, serializer_list))
|
||||
cls.internal_extra_handles_types.append((type, serializer_list))
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class BaseAgent(ABC, Agent):
|
||||
_unbound_subscriptions_list: ClassVar[List[UnboundSubscription]] = []
|
||||
_extra_handles_types: ClassVar[List[Tuple[Type[Any], List[MessageSerializer[Any]]]]] = []
|
||||
internal_unbound_subscriptions_list: ClassVar[List[UnboundSubscription]] = []
|
||||
internal_extra_handles_types: ClassVar[List[Tuple[Type[Any], List[MessageSerializer[Any]]]]] = []
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
# Automatically set class_variable in each subclass so that they are not shared between subclasses
|
||||
cls._extra_handles_types = []
|
||||
cls._unbound_subscriptions_list = []
|
||||
cls.internal_extra_handles_types = []
|
||||
cls.internal_unbound_subscriptions_list = []
|
||||
|
||||
@classmethod
|
||||
def _handles_types(cls) -> List[Tuple[Type[Any], List[MessageSerializer[Any]]]]:
|
||||
return cls._extra_handles_types
|
||||
return cls.internal_extra_handles_types
|
||||
|
||||
@classmethod
|
||||
def _unbound_subscriptions(cls) -> List[UnboundSubscription]:
|
||||
return cls._unbound_subscriptions_list
|
||||
return cls.internal_unbound_subscriptions_list
|
||||
|
||||
@property
|
||||
def metadata(self) -> AgentMetadata:
|
||||
@@ -155,7 +154,7 @@ class BaseAgent(ABC, Agent):
|
||||
agent_type = await runtime.register_factory(type=agent_type, agent_factory=factory, expected_class=cls)
|
||||
if not skip_class_subscriptions:
|
||||
with SubscriptionInstantiationContext.populate_context(agent_type):
|
||||
subscriptions = []
|
||||
subscriptions: List[Subscription] = []
|
||||
for unbound_subscription in cls._unbound_subscriptions():
|
||||
subscriptions_list_result = unbound_subscription()
|
||||
if inspect.isawaitable(subscriptions_list_result):
|
||||
|
||||
@@ -3,7 +3,6 @@ from dataclasses import asdict, dataclass, fields
|
||||
from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from autogen_core.base._type_helpers import is_union
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Awaitable, Callable, Protocol, runtime_checkable
|
||||
from typing import Awaitable, Callable, Protocol, runtime_checkable
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_type import AgentType
|
||||
from ._topic import TopicId
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import inspect
|
||||
from typing import Any, Awaitable, Callable, List, Mapping, Sequence, TypeVar, cast, get_type_hints
|
||||
from typing import Any, Awaitable, Callable, List, Mapping, Sequence, TypeVar, get_type_hints
|
||||
|
||||
from ..base import (
|
||||
Agent,
|
||||
@@ -112,9 +112,10 @@ class ClosureAgent(Agent):
|
||||
with SubscriptionInstantiationContext.populate_context(agent_type):
|
||||
subscriptions_list_result = subscriptions()
|
||||
if inspect.isawaitable(subscriptions_list_result):
|
||||
subscriptions_list.extend(cast(List[Subscription], await subscriptions_list_result))
|
||||
subscriptions_list.extend(await subscriptions_list_result)
|
||||
else:
|
||||
subscriptions_list.extend(cast(List[Subscription], subscriptions_list_result))
|
||||
# just ignore mypy here
|
||||
subscriptions_list.extend(subscriptions_list_result) # type: ignore
|
||||
|
||||
agent_type = await runtime.register_factory(
|
||||
type=agent_type,
|
||||
|
||||
@@ -4,12 +4,12 @@ import base64
|
||||
import re
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import aiohttp
|
||||
from openai.types.chat import ChatCompletionContentPartImageParam
|
||||
from PIL import Image as PILImage
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, ValidationInfo
|
||||
from pydantic import GetCoreSchemaHandler, ValidationInfo
|
||||
from pydantic_core import core_schema
|
||||
from typing_extensions import Literal
|
||||
|
||||
@@ -68,7 +68,7 @@ class Image:
|
||||
# Custom validation
|
||||
def validate(value: Any, validation_info: ValidationInfo) -> Image:
|
||||
if isinstance(value, dict):
|
||||
base_64 = value.get("data")
|
||||
base_64 = cast(str | None, value.get("data")) # type: ignore
|
||||
if base_64 is None:
|
||||
raise ValueError("Expected 'data' key in the dictionary")
|
||||
return cls.from_base64(base_64)
|
||||
|
||||
@@ -19,8 +19,6 @@ from typing import (
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..base import BaseAgent, MessageContext, MessageSerializer, try_get_known_serializers_for_type
|
||||
from ..base._type_helpers import AnyType, get_types
|
||||
from ..base.exceptions import CantHandleException
|
||||
@@ -477,7 +475,7 @@ class RoutedAgent(BaseAgent):
|
||||
|
||||
@classmethod
|
||||
def _discover_handlers(cls) -> Sequence[MessageHandler[Any, Any, Any]]:
|
||||
handlers = []
|
||||
handlers: List[MessageHandler[Any, Any, Any]] = []
|
||||
for attr in dir(cls):
|
||||
if callable(getattr(cls, attr, None)):
|
||||
# Since we are getting it from the class, self is not bound
|
||||
@@ -491,7 +489,7 @@ class RoutedAgent(BaseAgent):
|
||||
# TODO handle deduplication
|
||||
handlers = cls._discover_handlers()
|
||||
types: List[Tuple[Type[Any], List[MessageSerializer[Any]]]] = []
|
||||
types.extend(cls._extra_handles_types)
|
||||
types.extend(cls.internal_extra_handles_types)
|
||||
for handler in handlers:
|
||||
for t in handler.target_types:
|
||||
# TODO: support different serializers
|
||||
|
||||
@@ -13,7 +13,7 @@ from collections.abc import Sequence
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from types import TracebackType
|
||||
from typing import Any, Callable, ClassVar, Dict, List, Optional, ParamSpec, Type, Union
|
||||
from typing import Any, Callable, ClassVar, List, Optional, ParamSpec, Type, Union
|
||||
|
||||
import docker
|
||||
import docker.models
|
||||
@@ -212,8 +212,8 @@ $functions"""
|
||||
if len(code_blocks) == 0:
|
||||
raise ValueError("No code blocks to execute.")
|
||||
|
||||
outputs = []
|
||||
files = []
|
||||
outputs: List[str] = []
|
||||
files: List[Path] = []
|
||||
last_exit_code = 0
|
||||
for code_block in code_blocks:
|
||||
lang = code_block.language.lower()
|
||||
@@ -237,7 +237,7 @@ $functions"""
|
||||
|
||||
command = ["timeout", str(self._timeout), lang_to_cmd(lang), filename]
|
||||
|
||||
result = await asyncio.to_thread(self._container.exec_run, command)
|
||||
result = await asyncio.to_thread(self._container.exec_run, command) # type: ignore
|
||||
exit_code = result.exit_code
|
||||
output = result.output.decode("utf-8")
|
||||
if exit_code == 124:
|
||||
@@ -277,7 +277,7 @@ $functions"""
|
||||
raise ValueError("Container is not running. Must first be started with either start or a context manager.")
|
||||
|
||||
"""(Experimental) Restart the code executor."""
|
||||
await asyncio.to_thread(self._container.restart)
|
||||
await asyncio.to_thread(self._container.restart) # type: ignore
|
||||
if self._container.status != "running":
|
||||
self._running = False
|
||||
logs_str = self._container.logs().decode("utf-8")
|
||||
|
||||
@@ -33,7 +33,7 @@ def docker_tests_enabled() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
@pytest_asyncio.fixture(scope="function") # type: ignore
|
||||
async def executor_and_temp_dir(
|
||||
request: pytest.FixtureRequest,
|
||||
) -> AsyncGenerator[tuple[LocalCommandLineCodeExecutor | DockerCommandLineCodeExecutor, str], None]:
|
||||
@@ -55,7 +55,7 @@ ExecutorFixture: TypeAlias = tuple[LocalCommandLineCodeExecutor | DockerCommandL
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("executor_and_temp_dir", ["local", "docker"], indirect=True)
|
||||
async def test_execute_code(executor_and_temp_dir: ExecutorFixture) -> None:
|
||||
executor, temp_dir = executor_and_temp_dir
|
||||
executor, _temp_dir = executor_and_temp_dir
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
# Test single code block.
|
||||
@@ -138,7 +138,7 @@ async def test_local_commandline_code_executor_restart() -> None:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("executor_and_temp_dir", ["local", "docker"], indirect=True)
|
||||
async def test_invalid_relative_path(executor_and_temp_dir: ExecutorFixture) -> None:
|
||||
executor, temp_dir = executor_and_temp_dir
|
||||
executor, _temp_dir = executor_and_temp_dir
|
||||
cancellation_token = CancellationToken()
|
||||
code = """# filename: /tmp/test.py
|
||||
|
||||
|
||||
@@ -3,11 +3,10 @@ from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core.base import AgentId, AgentRuntime, MessageContext, TopicId
|
||||
from autogen_core.base import AgentId, AgentRuntime, MessageContext
|
||||
from autogen_core.components import ClosureAgent
|
||||
from autogen_core.components._default_subscription import DefaultSubscription
|
||||
from autogen_core.components._default_topic import DefaultTopicId
|
||||
from autogen_core.components._type_subscription import TypeSubscription
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from typing import Any, AsyncGenerator, List, Tuple
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from autogen_core.base import CancellationToken
|
||||
|
||||
@@ -86,7 +86,7 @@ class RoutedAgentMessageCustomMatch(RoutedAgent):
|
||||
async def handler_one(self, message: TestMessage, ctx: MessageContext) -> None:
|
||||
self.handler_one_called = True
|
||||
|
||||
@message_handler(match=cast(Callable[[TestMessage, MessageContext], bool], lambda msg, ctx: msg.value == "two"))
|
||||
@message_handler(match=cast(Callable[[TestMessage, MessageContext], bool], lambda msg, ctx: msg.value == "two")) # type: ignore
|
||||
async def handler_two(self, message: TestMessage, ctx: MessageContext) -> None:
|
||||
self.handler_two_called = True
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from autogen_core.base import (
|
||||
try_get_known_serializers_for_type,
|
||||
)
|
||||
from autogen_core.components import (
|
||||
DefaultSubscription,
|
||||
DefaultTopicId,
|
||||
TypeSubscription,
|
||||
default_subscription,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from autogen_core.application import WorkerAgentRuntime, WorkerAgentRuntimeHost
|
||||
@@ -10,7 +11,6 @@ from autogen_core.base import (
|
||||
try_get_known_serializers_for_type,
|
||||
)
|
||||
from autogen_core.components import (
|
||||
DefaultSubscription,
|
||||
DefaultTopicId,
|
||||
TypeSubscription,
|
||||
default_subscription,
|
||||
@@ -163,7 +163,7 @@ async def test_register_receives_publish_cascade_multiple_workers() -> None:
|
||||
total_num_calls_expected += num_initial_messages * ((num_agents - 1) ** i)
|
||||
|
||||
# Run multiple workers one for each agent.
|
||||
workers = []
|
||||
workers: List[WorkerAgentRuntime] = []
|
||||
# Register agents
|
||||
for i in range(num_agents):
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
|
||||
Reference in New Issue
Block a user