feat(forge/llm): allow async completion parsers

This commit is contained in:
Reinier van der Leer
2024-06-08 21:29:35 +02:00
parent 8144d26cef
commit 111e8585b5
4 changed files with 47 additions and 10 deletions

View File

@@ -154,7 +154,10 @@ class BaseOpenAIChatProvider(
self,
model_prompt: list[ChatMessage],
model_name: _ModelName,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
completion_parser: (
Callable[[AssistantChatMessage], Awaitable[_T]]
| Callable[[AssistantChatMessage], _T]
) = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
prefill_response: str = "",
@@ -208,9 +211,15 @@ class BaseOpenAIChatProvider(
parsed_result: _T = None # type: ignore
if not parse_errors:
try:
parsed_result = completion_parser(assistant_msg)
if inspect.isawaitable(parsed_result):
parsed_result = await parsed_result
parsed_result = (
await _result
if inspect.isawaitable(
_result := completion_parser(assistant_msg)
)
# cast(..) needed because inspect.isawaitable(..) loses type info:
# https://github.com/microsoft/pyright/issues/3690
else cast(_T, _result)
)
except Exception as e:
parse_errors.append(e)

View File

@@ -1,8 +1,19 @@
from __future__ import annotations
import enum
import inspect
import logging
from typing import TYPE_CHECKING, Any, Callable, Optional, ParamSpec, Sequence, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Optional,
ParamSpec,
Sequence,
TypeVar,
cast,
)
import sentry_sdk
import tenacity
@@ -162,7 +173,10 @@ class AnthropicProvider(BaseChatModelProvider[AnthropicModelName, AnthropicSetti
self,
model_prompt: list[ChatMessage],
model_name: AnthropicModelName,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
completion_parser: (
Callable[[AssistantChatMessage], Awaitable[_T]]
| Callable[[AssistantChatMessage], _T]
) = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
prefill_response: str = "",
@@ -228,7 +242,14 @@ class AnthropicProvider(BaseChatModelProvider[AnthropicModelName, AnthropicSetti
+ "\n".join(str(e) for e in tool_call_errors)
)
parsed_result = completion_parser(assistant_msg)
# cast(..) needed because inspect.isawaitable(..) loses type info:
# https://github.com/microsoft/pyright/issues/3690
parsed_result = cast(
_T,
await _result
if inspect.isawaitable(_result := completion_parser(assistant_msg))
else _result,
)
break
except Exception as e:
self._logger.debug(

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import logging
from typing import Any, Callable, Iterator, Optional, Sequence, TypeVar
from typing import Any, Awaitable, Callable, Iterator, Optional, Sequence, TypeVar
from pydantic import ValidationError
@@ -93,7 +93,10 @@ class MultiProvider(BaseChatModelProvider[ModelName, ModelProviderSettings]):
self,
model_prompt: list[ChatMessage],
model_name: ModelName,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
completion_parser: (
Callable[[AssistantChatMessage], Awaitable[_T]]
| Callable[[AssistantChatMessage], _T]
) = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
prefill_response: str = "",

View File

@@ -6,6 +6,7 @@ from collections import defaultdict
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
ClassVar,
Generic,
@@ -455,7 +456,10 @@ class BaseChatModelProvider(BaseModelProvider[_ModelName, _ModelProviderSettings
self,
model_prompt: list[ChatMessage],
model_name: _ModelName,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
completion_parser: (
Callable[[AssistantChatMessage], Awaitable[_T]]
| Callable[[AssistantChatMessage], _T]
) = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
prefill_response: str = "",