split into example

This commit is contained in:
Jack Gerrits
2024-05-15 09:59:23 -04:00
parent 17eb9f8ecd
commit 813a9e1ddb
3 changed files with 150 additions and 144 deletions

137
examples/example.py Normal file
View File

@@ -0,0 +1,137 @@
import asyncio
import random
from dataclasses import dataclass
from typing import Awaitable, Callable, List, Optional, Sequence, cast
from agnext.prototype import Agent, Event, EventQueue, EventRouter, TypeRoutedAgent, event_handler
@dataclass
class InputEvent(Event):
message: str
sender: str
@dataclass
class NewEvent(Event):
message: str
sender: str
recipient: str
@dataclass
class ResponseEvent(Event):
message: Optional[str]
sender: str
GroupChatEvents = InputEvent | NewEvent | ResponseEvent
class GroupChatManager(TypeRoutedAgent[GroupChatEvents]):
def __init__(
self, name: str, emit_event: Callable[[GroupChatEvents], Awaitable[None]], agents: Sequence[Agent]
) -> None:
super().__init__(name, emit_event)
self._agents = agents
self._current_speaker = 0
self._events: List[GroupChatEvents] = []
self._responses: List[ResponseEvent] = []
@event_handler(InputEvent)
async def on_input_event(self, event: InputEvent) -> None:
# New group chat
self._events.clear()
recipient_agent = self._agents[self._current_speaker]
self._current_speaker = (self._current_speaker + 1) % len(self._agents)
new_event = NewEvent(message=event.message, sender=self.name, recipient=recipient_agent.name)
self._events.append(event)
await self.emit_event(new_event)
@event_handler(ResponseEvent)
async def on_group_chat_event(self, event: ResponseEvent) -> None:
self._responses.append(event)
# TODO: Handle termination and replying to original sender
# Received response from all - proceeed
if len(self._responses) == len(self._agents):
recipient_agent = self._agents[self._current_speaker]
self._current_speaker = (self._current_speaker + 1) % len(self._agents)
responses_with_content = [x for x in self._responses if x.message is not None]
if len(responses_with_content) != 1:
raise ValueError("Can't handle anything other than 1 response right now.")
new_event = NewEvent(
message=cast(str, responses_with_content[0].message), sender=self.name, recipient=recipient_agent.name
)
self._events.append(new_event)
self._responses.clear()
await self.emit_event(new_event)
async def on_unhandled_event(self, event: GroupChatEvents) -> None:
raise ValueError("Unknown")
class Critic(TypeRoutedAgent[GroupChatEvents]):
def __init__(self, name: str, emit_event: Callable[[GroupChatEvents], Awaitable[None]]) -> None:
super().__init__(name, emit_event)
@event_handler(NewEvent)
async def on_new_event(self, event: NewEvent) -> None:
if event.recipient == self.name:
response = random.choice([" is a good idea", " is a bad idea"])
await self.emit_event(ResponseEvent(event.message + response, sender=self.name))
else:
await self.emit_event(ResponseEvent(None, sender=self.name))
async def on_unhandled_event(self, event: GroupChatEvents) -> None:
raise ValueError("Unknown")
class Suggester(TypeRoutedAgent[GroupChatEvents]):
def __init__(self, name: str, emit_event: Callable[[GroupChatEvents], Awaitable[None]]) -> None:
super().__init__(name, emit_event)
@event_handler(NewEvent)
async def on_new_event(self, event: NewEvent) -> None:
if event.recipient == self.name:
response = random.choice(
["Attach wheels to a laptop", "merge a banana and an apple", "Cheese but made with oats"]
)
await self.emit_event(ResponseEvent(response, sender=self.name))
else:
await self.emit_event(ResponseEvent(None, sender=self.name))
async def on_unhandled_event(self, event: GroupChatEvents) -> None:
raise ValueError("Unknown")
async def main():
event_queue = EventQueue[GroupChatEvents]()
critic = Critic("Critic", event_queue.into_callable())
suggester = Suggester("Suggester", event_queue.into_callable())
group_chat_manager = GroupChatManager("Manager", event_queue.into_callable(), [critic, suggester])
processor = EventRouter(event_queue, [critic, suggester, group_chat_manager])
await event_queue.emit(InputEvent(message="Go", sender="external"))
await processor.process_next()
await processor.process_next()
await processor.process_next()
await processor.process_next()
await processor.process_next()
await processor.process_next()
await processor.process_next()
await processor.process_next()
await processor.process_next()
await processor.process_next()
await processor.process_next()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -25,15 +25,14 @@ line-length = 120
fix = true
exclude = ["build", "dist", "my_project/__init__.py", "my_project/main.py"]
target-version = "py310"
include = ["src/**", "examples/**", "tests/**"]
[tool.ruff.lint]
select = ["E", "F", "W", "B", "Q", "I"]
ignore = ["F401", "E501"]
[tool.mypy]
files = [
"src"
]
files = ["src", "examples", "tests"]
strict = true
python_version = "3.10"
@@ -52,9 +51,7 @@ disallow_untyped_decorators = true
disallow_any_unimported = true
[tool.pyright]
include = [
"src"
]
include = ["src", "examples", "tests"]
typeCheckingMode = "strict"
reportUnnecessaryIsInstance = false
reportMissingTypeStubs = false

View File

@@ -1,8 +1,5 @@
import asyncio
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, List, Optional, Protocol, Sequence, Type, cast
from typing import Any, Awaitable, Callable, Dict, List, Protocol, Sequence, Type
# Type based routing for event
@@ -19,28 +16,24 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional, Protocol, Seq
# on event with receipt
class Event(Protocol):
sender: str
# reply_to: Optional[str]
# T must encompass all subscribed types for a given agent
class Agent(Protocol):
@property
def name(self) -> str:
...
def name(self) -> str: ...
class EventBasedAgent[T: Event](Agent):
@property
def subscriptions(self) -> Sequence[Type[T]]:
...
def subscriptions(self) -> Sequence[Type[T]]: ...
async def on_event(self, event: T) -> None:
...
async def on_event(self, event: T) -> None: ...
# async def _send_event(self, event: T) -> None:
# ...
@@ -52,8 +45,9 @@ class EventBasedAgent[T: Event](Agent):
# NOTE: this works on concrete types and not inheritance
def event_handler[T: Event](target_type: Type[T]):
def decorator(func: Callable[..., Awaitable[None]]) -> Callable[..., Awaitable[None]]:
func._target_type = target_type # type: ignore
func._target_type = target_type # type: ignore
return func
return decorator
@@ -88,104 +82,9 @@ class TypeRoutedAgent[T: Event](EventBasedAgent[T], ABC):
else:
await self.on_unhandled_event(event)
@abstractmethod
async def on_unhandled_event(self, event: T) -> None:
...
async def on_unhandled_event(self, event: T) -> None: ...
@dataclass
class InputEvent(Event):
message: str
sender: str
@dataclass
class NewEvent(Event):
message: str
sender: str
recipient: str
@dataclass
class ResponseEvent(Event):
message: Optional[str]
sender: str
GroupChatEvents = InputEvent | NewEvent | ResponseEvent
class GroupChatManager(TypeRoutedAgent[GroupChatEvents]):
def __init__(self, name: str, emit_event: Callable[[GroupChatEvents], Awaitable[None]], agents: Sequence[Agent]) -> None:
super().__init__(name, emit_event)
self._agents = agents
self._current_speaker = 0
self._events: List[GroupChatEvents] = []
self._responses: List[ResponseEvent] = []
@event_handler(InputEvent)
async def on_input_event(self, event: InputEvent) -> None:
# New group chat
self._events.clear()
recipient_agent = self._agents[self._current_speaker]
self._current_speaker = (self._current_speaker + 1) % len(self._agents)
new_event = NewEvent(message=event.message, sender=self.name, recipient=recipient_agent.name)
self._events.append(event)
await self.emit_event(new_event)
@event_handler(ResponseEvent)
async def on_group_chat_event(self, event: ResponseEvent) -> None:
self._responses.append(event)
# TODO: Handle termination and replying to original sender
# Received response from all - proceeed
if len(self._responses) == len(self._agents):
recipient_agent = self._agents[self._current_speaker]
self._current_speaker = (self._current_speaker + 1) % len(self._agents)
responses_with_content = [x for x in self._responses if x.message is not None]
if len(responses_with_content) != 1:
raise ValueError("Can't handle anything other than 1 response right now.")
new_event = NewEvent(message=cast(str, responses_with_content[0].message), sender=self.name, recipient=recipient_agent.name)
self._events.append(new_event)
self._responses.clear()
await self.emit_event(new_event)
async def on_unhandled_event(self, event: GroupChatEvents) -> None:
raise ValueError("Unknown")
class Critic(TypeRoutedAgent[GroupChatEvents]):
def __init__(self, name: str, emit_event: Callable[[GroupChatEvents], Awaitable[None]]) -> None:
super().__init__(name, emit_event)
@event_handler(NewEvent)
async def on_new_event(self, event: NewEvent) -> None:
if event.recipient == self.name:
response = random.choice([" is a good idea", " is a bad idea"])
await self.emit_event(ResponseEvent(event.message + response, sender=self.name))
else:
await self.emit_event(ResponseEvent(None, sender=self.name))
async def on_unhandled_event(self, event: GroupChatEvents) -> None:
raise ValueError("Unknown")
class Suggester(TypeRoutedAgent[GroupChatEvents]):
def __init__(self, name: str, emit_event: Callable[[GroupChatEvents], Awaitable[None]]) -> None:
super().__init__(name, emit_event)
@event_handler(NewEvent)
async def on_new_event(self, event: NewEvent) -> None:
if event.recipient == self.name:
response = random.choice(["Attach wheels to a laptop", "merge a banana and an apple", "Cheese but made with oats"])
await self.emit_event(ResponseEvent(response, sender=self.name))
else:
await self.emit_event(ResponseEvent(None, sender=self.name))
async def on_unhandled_event(self, event: GroupChatEvents) -> None:
raise ValueError("Unknown")
class EventQueue[U]:
def __init__(self) -> None:
@@ -205,7 +104,7 @@ class EventQueue[U]:
return self.emit
class EventRouter[T: Event]():
class EventRouter[T: Event]:
def __init__(self, event_queue: EventQueue[T], agents: Sequence[EventBasedAgent[T]]) -> None:
self._event_queue = event_queue
# Use default dict i just cant remember the syntax and im without internet
@@ -218,7 +117,6 @@ class EventRouter[T: Event]():
self._per_type_subscribers[subscription].append(agent)
async def process_next(self) -> None:
if self._event_queue.empty():
return
@@ -230,29 +128,3 @@ class EventRouter[T: Event]():
await subscriber.on_event(event)
else:
print(f"Event {event} has no recipient agent")
async def main():
event_queue = EventQueue[GroupChatEvents]()
critic = Critic("Critic", event_queue.into_callable())
suggester = Suggester("Suggester", event_queue.into_callable())
group_chat_manager = GroupChatManager("Manager", event_queue.into_callable(), [critic, suggester])
processor = EventRouter(event_queue, [critic, suggester, group_chat_manager])
await event_queue.emit(InputEvent(message="Go", sender="external"))
await processor.process_next()
await processor.process_next()
await processor.process_next()
await processor.process_next()
await processor.process_next()
await processor.process_next()
await processor.process_next()
await processor.process_next()
await processor.process_next()
await processor.process_next()
await processor.process_next()
if __name__ == "__main__":
asyncio.run(main())