Update pyright extends for core (#562)

* Update pyright extends for core

* Fixes
This commit is contained in:
Jack Gerrits
2024-09-20 15:51:38 -04:00
committed by GitHub
parent d60536c9fe
commit ab6ba80a98
21 changed files with 44 additions and 49 deletions

View File

@@ -75,9 +75,10 @@ exclude = ["build", "dist", "src/autogen_core/application/protos"]
include = ["src/**", "samples/*.py", "docs/**/*.ipynb", "tests/**"]
[tool.pyright]
extend = "../../pyproject.toml"
extends = "../../pyproject.toml"
include = ["src", "tests", "samples"]
exclude = ["src/autogen_core/application/protos"]
reportDeprecated = false
[tool.pytest.ini_options]
minversion = "6.0"

View File

@@ -9,10 +9,9 @@ from autogen_core.components import (
message_handler,
)
from autogen_core.components.model_context import ChatCompletionContext
from autogen_core.components.models import AssistantMessage, UserMessage
from autogen_core.components.models import UserMessage
from ..types import (
Message,
MultiModalMessage,
PublishNow,
Reset,

View File

@@ -7,7 +7,6 @@ from autogen_core.components.model_context import ChatCompletionContext
from autogen_core.components.models import ChatCompletionClient, UserMessage
from ..types import (
Message,
MultiModalMessage,
PublishNow,
Reset,

View File

@@ -8,7 +8,6 @@ from autogen_core.base import (
AgentId,
AgentInstantiationContext,
MessageContext,
try_get_known_serializers_for_type,
)
from autogen_core.components import DefaultSubscription, DefaultTopicId, RoutedAgent, message_handler

View File

@@ -113,8 +113,7 @@ class HostConnection:
if self._connection_task is None:
raise RuntimeError("Connection is not open.")
await self._channel.close()
if self._connection_task is not None:
await self._connection_task
await self._connection_task
@staticmethod
async def _connect( # type: ignore
@@ -227,6 +226,8 @@ class WorkerAgentRuntime(AgentRuntime):
task.add_done_callback(self._background_tasks.discard)
case None:
logger.warning("No message")
case other:
logger.error(f"Unknown message type: {other}")
except Exception as e:
logger.error("Error in read loop", exc_info=e)

View File

@@ -125,6 +125,8 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
logger.warning(f"Received unexpected message type: {oneofcase}")
case None:
logger.warning("Received empty message")
case other:
logger.error(f"Received unexpected message: {other}")
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None:
# Deliver the message to a client given the target agent type.

View File

@@ -1,4 +1,4 @@
from typing import Any, List, Mapping, Protocol, runtime_checkable
from typing import Any, Mapping, Protocol, runtime_checkable
from ._agent_id import AgentId
from ._agent_metadata import AgentMetadata

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, Awaitable, Callable, List, Mapping, Protocol, Type, TypeVar, overload, runtime_checkable
from typing import Any, Awaitable, Callable, Mapping, Protocol, Type, TypeVar, overload, runtime_checkable
from typing_extensions import deprecated

View File

@@ -4,8 +4,7 @@ import inspect
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from re import S
from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Tuple, Type, TypeVar, overload
from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Tuple, Type, TypeVar
from typing_extensions import Self
@@ -18,7 +17,7 @@ from ._agent_type import AgentType
from ._cancellation_token import CancellationToken
from ._message_context import MessageContext
from ._serialization import MessageSerializer, try_get_known_serializers_for_type
from ._subscription import UnboundSubscription
from ._subscription import Subscription, UnboundSubscription
from ._subscription_context import SubscriptionInstantiationContext
from ._topic import TopicId
@@ -30,7 +29,7 @@ BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent")
# Decorator for adding an unbound subscription to an agent
def subscription_factory(subscription: UnboundSubscription) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]:
def decorator(cls: Type[BaseAgentType]) -> Type[BaseAgentType]:
cls._unbound_subscriptions_list.append(subscription)
cls.internal_unbound_subscriptions_list.append(subscription)
return cls
return decorator
@@ -48,29 +47,29 @@ def handles(
if len(serializer_list) == 0:
raise ValueError(f"No serializers found for type {type}. Please provide an explicit serializer.")
cls._extra_handles_types.append((type, serializer_list))
cls.internal_extra_handles_types.append((type, serializer_list))
return cls
return decorator
class BaseAgent(ABC, Agent):
_unbound_subscriptions_list: ClassVar[List[UnboundSubscription]] = []
_extra_handles_types: ClassVar[List[Tuple[Type[Any], List[MessageSerializer[Any]]]]] = []
internal_unbound_subscriptions_list: ClassVar[List[UnboundSubscription]] = []
internal_extra_handles_types: ClassVar[List[Tuple[Type[Any], List[MessageSerializer[Any]]]]] = []
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
# Automatically set class_variable in each subclass so that they are not shared between subclasses
cls._extra_handles_types = []
cls._unbound_subscriptions_list = []
cls.internal_extra_handles_types = []
cls.internal_unbound_subscriptions_list = []
@classmethod
def _handles_types(cls) -> List[Tuple[Type[Any], List[MessageSerializer[Any]]]]:
return cls._extra_handles_types
return cls.internal_extra_handles_types
@classmethod
def _unbound_subscriptions(cls) -> List[UnboundSubscription]:
return cls._unbound_subscriptions_list
return cls.internal_unbound_subscriptions_list
@property
def metadata(self) -> AgentMetadata:
@@ -155,7 +154,7 @@ class BaseAgent(ABC, Agent):
agent_type = await runtime.register_factory(type=agent_type, agent_factory=factory, expected_class=cls)
if not skip_class_subscriptions:
with SubscriptionInstantiationContext.populate_context(agent_type):
subscriptions = []
subscriptions: List[Subscription] = []
for unbound_subscription in cls._unbound_subscriptions():
subscriptions_list_result = unbound_subscription()
if inspect.isawaitable(subscriptions_list_result):

View File

@@ -3,7 +3,6 @@ from dataclasses import asdict, dataclass, fields
from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable
from pydantic import BaseModel
from typing_extensions import deprecated
from autogen_core.base._type_helpers import is_union

View File

@@ -1,9 +1,8 @@
from __future__ import annotations
from typing import Any, Awaitable, Callable, Protocol, runtime_checkable
from typing import Awaitable, Callable, Protocol, runtime_checkable
from ._agent_id import AgentId
from ._agent_type import AgentType
from ._topic import TopicId

View File

@@ -1,5 +1,5 @@
import inspect
from typing import Any, Awaitable, Callable, List, Mapping, Sequence, TypeVar, cast, get_type_hints
from typing import Any, Awaitable, Callable, List, Mapping, Sequence, TypeVar, get_type_hints
from ..base import (
Agent,
@@ -112,9 +112,10 @@ class ClosureAgent(Agent):
with SubscriptionInstantiationContext.populate_context(agent_type):
subscriptions_list_result = subscriptions()
if inspect.isawaitable(subscriptions_list_result):
subscriptions_list.extend(cast(List[Subscription], await subscriptions_list_result))
subscriptions_list.extend(await subscriptions_list_result)
else:
subscriptions_list.extend(cast(List[Subscription], subscriptions_list_result))
# just ignore mypy here
subscriptions_list.extend(subscriptions_list_result) # type: ignore
agent_type = await runtime.register_factory(
type=agent_type,

View File

@@ -4,12 +4,12 @@ import base64
import re
from io import BytesIO
from pathlib import Path
from typing import Any
from typing import Any, cast
import aiohttp
from openai.types.chat import ChatCompletionContentPartImageParam
from PIL import Image as PILImage
from pydantic import BaseModel, GetCoreSchemaHandler, ValidationInfo
from pydantic import GetCoreSchemaHandler, ValidationInfo
from pydantic_core import core_schema
from typing_extensions import Literal
@@ -68,7 +68,7 @@ class Image:
# Custom validation
def validate(value: Any, validation_info: ValidationInfo) -> Image:
if isinstance(value, dict):
base_64 = value.get("data")
base_64 = cast(str | None, value.get("data")) # type: ignore
if base_64 is None:
raise ValueError("Expected 'data' key in the dictionary")
return cls.from_base64(base_64)

View File

@@ -19,8 +19,6 @@ from typing import (
runtime_checkable,
)
from typing_extensions import Self
from ..base import BaseAgent, MessageContext, MessageSerializer, try_get_known_serializers_for_type
from ..base._type_helpers import AnyType, get_types
from ..base.exceptions import CantHandleException
@@ -477,7 +475,7 @@ class RoutedAgent(BaseAgent):
@classmethod
def _discover_handlers(cls) -> Sequence[MessageHandler[Any, Any, Any]]:
handlers = []
handlers: List[MessageHandler[Any, Any, Any]] = []
for attr in dir(cls):
if callable(getattr(cls, attr, None)):
# Since we are getting it from the class, self is not bound
@@ -491,7 +489,7 @@ class RoutedAgent(BaseAgent):
# TODO handle deduplication
handlers = cls._discover_handlers()
types: List[Tuple[Type[Any], List[MessageSerializer[Any]]]] = []
types.extend(cls._extra_handles_types)
types.extend(cls.internal_extra_handles_types)
for handler in handlers:
for t in handler.target_types:
# TODO: support different serializers

View File

@@ -13,7 +13,7 @@ from collections.abc import Sequence
from hashlib import md5
from pathlib import Path
from types import TracebackType
from typing import Any, Callable, ClassVar, Dict, List, Optional, ParamSpec, Type, Union
from typing import Any, Callable, ClassVar, List, Optional, ParamSpec, Type, Union
import docker
import docker.models
@@ -212,8 +212,8 @@ $functions"""
if len(code_blocks) == 0:
raise ValueError("No code blocks to execute.")
outputs = []
files = []
outputs: List[str] = []
files: List[Path] = []
last_exit_code = 0
for code_block in code_blocks:
lang = code_block.language.lower()
@@ -237,7 +237,7 @@ $functions"""
command = ["timeout", str(self._timeout), lang_to_cmd(lang), filename]
result = await asyncio.to_thread(self._container.exec_run, command)
result = await asyncio.to_thread(self._container.exec_run, command) # type: ignore
exit_code = result.exit_code
output = result.output.decode("utf-8")
if exit_code == 124:
@@ -277,7 +277,7 @@ $functions"""
raise ValueError("Container is not running. Must first be started with either start or a context manager.")
"""(Experimental) Restart the code executor."""
await asyncio.to_thread(self._container.restart)
await asyncio.to_thread(self._container.restart) # type: ignore
if self._container.status != "running":
self._running = False
logs_str = self._container.logs().decode("utf-8")

View File

@@ -33,7 +33,7 @@ def docker_tests_enabled() -> bool:
return False
@pytest_asyncio.fixture(scope="function")
@pytest_asyncio.fixture(scope="function") # type: ignore
async def executor_and_temp_dir(
request: pytest.FixtureRequest,
) -> AsyncGenerator[tuple[LocalCommandLineCodeExecutor | DockerCommandLineCodeExecutor, str], None]:
@@ -55,7 +55,7 @@ ExecutorFixture: TypeAlias = tuple[LocalCommandLineCodeExecutor | DockerCommandL
@pytest.mark.asyncio
@pytest.mark.parametrize("executor_and_temp_dir", ["local", "docker"], indirect=True)
async def test_execute_code(executor_and_temp_dir: ExecutorFixture) -> None:
executor, temp_dir = executor_and_temp_dir
executor, _temp_dir = executor_and_temp_dir
cancellation_token = CancellationToken()
# Test single code block.
@@ -138,7 +138,7 @@ async def test_local_commandline_code_executor_restart() -> None:
@pytest.mark.asyncio
@pytest.mark.parametrize("executor_and_temp_dir", ["local", "docker"], indirect=True)
async def test_invalid_relative_path(executor_and_temp_dir: ExecutorFixture) -> None:
executor, temp_dir = executor_and_temp_dir
executor, _temp_dir = executor_and_temp_dir
cancellation_token = CancellationToken()
code = """# filename: /tmp/test.py

View File

@@ -3,11 +3,10 @@ from dataclasses import dataclass
import pytest
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, AgentRuntime, MessageContext, TopicId
from autogen_core.base import AgentId, AgentRuntime, MessageContext
from autogen_core.components import ClosureAgent
from autogen_core.components._default_subscription import DefaultSubscription
from autogen_core.components._default_topic import DefaultTopicId
from autogen_core.components._type_subscription import TypeSubscription
@dataclass

View File

@@ -1,6 +1,6 @@
import asyncio
from typing import Any, AsyncGenerator, List, Tuple
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from autogen_core.base import CancellationToken

View File

@@ -86,7 +86,7 @@ class RoutedAgentMessageCustomMatch(RoutedAgent):
async def handler_one(self, message: TestMessage, ctx: MessageContext) -> None:
self.handler_one_called = True
@message_handler(match=cast(Callable[[TestMessage, MessageContext], bool], lambda msg, ctx: msg.value == "two"))
@message_handler(match=cast(Callable[[TestMessage, MessageContext], bool], lambda msg, ctx: msg.value == "two")) # type: ignore
async def handler_two(self, message: TestMessage, ctx: MessageContext) -> None:
self.handler_two_called = True

View File

@@ -12,7 +12,6 @@ from autogen_core.base import (
try_get_known_serializers_for_type,
)
from autogen_core.components import (
DefaultSubscription,
DefaultTopicId,
TypeSubscription,
default_subscription,

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
from typing import List
import pytest
from autogen_core.application import WorkerAgentRuntime, WorkerAgentRuntimeHost
@@ -10,7 +11,6 @@ from autogen_core.base import (
try_get_known_serializers_for_type,
)
from autogen_core.components import (
DefaultSubscription,
DefaultTopicId,
TypeSubscription,
default_subscription,
@@ -163,7 +163,7 @@ async def test_register_receives_publish_cascade_multiple_workers() -> None:
total_num_calls_expected += num_initial_messages * ((num_agents - 1) ** i)
# Run multiple workers one for each agent.
workers = []
workers: List[WorkerAgentRuntime] = []
# Register agents
for i in range(num_agents):
runtime = WorkerAgentRuntime(host_address=host_address)