mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(forge/llm): allow async completion parsers
This commit is contained in:
@@ -154,7 +154,10 @@ class BaseOpenAIChatProvider(
|
||||
self,
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: _ModelName,
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
completion_parser: (
|
||||
Callable[[AssistantChatMessage], Awaitable[_T]]
|
||||
| Callable[[AssistantChatMessage], _T]
|
||||
) = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "",
|
||||
@@ -208,9 +211,15 @@ class BaseOpenAIChatProvider(
|
||||
parsed_result: _T = None # type: ignore
|
||||
if not parse_errors:
|
||||
try:
|
||||
parsed_result = completion_parser(assistant_msg)
|
||||
if inspect.isawaitable(parsed_result):
|
||||
parsed_result = await parsed_result
|
||||
parsed_result = (
|
||||
await _result
|
||||
if inspect.isawaitable(
|
||||
_result := completion_parser(assistant_msg)
|
||||
)
|
||||
# cast(..) needed because inspect.isawaitable(..) loses type:
|
||||
# https://github.com/microsoft/pyright/issues/3690
|
||||
else cast(_T, _result)
|
||||
)
|
||||
except Exception as e:
|
||||
parse_errors.append(e)
|
||||
|
||||
|
||||
@@ -1,8 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, ParamSpec, Sequence, TypeVar
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Optional,
|
||||
ParamSpec,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
import sentry_sdk
|
||||
import tenacity
|
||||
@@ -162,7 +173,10 @@ class AnthropicProvider(BaseChatModelProvider[AnthropicModelName, AnthropicSetti
|
||||
self,
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: AnthropicModelName,
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
completion_parser: (
|
||||
Callable[[AssistantChatMessage], Awaitable[_T]]
|
||||
| Callable[[AssistantChatMessage], _T]
|
||||
) = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "",
|
||||
@@ -228,7 +242,14 @@ class AnthropicProvider(BaseChatModelProvider[AnthropicModelName, AnthropicSetti
|
||||
+ "\n".join(str(e) for e in tool_call_errors)
|
||||
)
|
||||
|
||||
parsed_result = completion_parser(assistant_msg)
|
||||
# cast(..) needed because inspect.isawaitable(..) loses type info:
|
||||
# https://github.com/microsoft/pyright/issues/3690
|
||||
parsed_result = cast(
|
||||
_T,
|
||||
await _result
|
||||
if inspect.isawaitable(_result := completion_parser(assistant_msg))
|
||||
else _result,
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.debug(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Iterator, Optional, Sequence, TypeVar
|
||||
from typing import Any, Awaitable, Callable, Iterator, Optional, Sequence, TypeVar
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
@@ -93,7 +93,10 @@ class MultiProvider(BaseChatModelProvider[ModelName, ModelProviderSettings]):
|
||||
self,
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: ModelName,
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
completion_parser: (
|
||||
Callable[[AssistantChatMessage], Awaitable[_T]]
|
||||
| Callable[[AssistantChatMessage], _T]
|
||||
) = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "",
|
||||
|
||||
@@ -6,6 +6,7 @@ from collections import defaultdict
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Generic,
|
||||
@@ -455,7 +456,10 @@ class BaseChatModelProvider(BaseModelProvider[_ModelName, _ModelProviderSettings
|
||||
self,
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: _ModelName,
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
completion_parser: (
|
||||
Callable[[AssistantChatMessage], Awaitable[_T]]
|
||||
| Callable[[AssistantChatMessage], _T]
|
||||
) = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "",
|
||||
|
||||
Reference in New Issue
Block a user