Compare commits

...

1 Commits

Author SHA1 Message Date
openhands 4fb3b0dcd5 Fix issue #4184: '[LLM] Support LLM routing through notdiamond' 2024-10-03 11:44:20 +00:00
4 changed files with 164 additions and 1 deletions
+2
View File
@@ -47,3 +47,5 @@ class ConfigType(str, Enum):
WORKSPACE_MOUNT_PATH = 'WORKSPACE_MOUNT_PATH' WORKSPACE_MOUNT_PATH = 'WORKSPACE_MOUNT_PATH'
WORKSPACE_MOUNT_PATH_IN_SANDBOX = 'WORKSPACE_MOUNT_PATH_IN_SANDBOX' WORKSPACE_MOUNT_PATH_IN_SANDBOX = 'WORKSPACE_MOUNT_PATH_IN_SANDBOX'
WORKSPACE_MOUNT_REWRITE = 'WORKSPACE_MOUNT_REWRITE' WORKSPACE_MOUNT_REWRITE = 'WORKSPACE_MOUNT_REWRITE'
LLM_PROVIDERS = 'LLM_PROVIDERS'
LLM_ROUTER_ENABLED = 'LLM_ROUTER_ENABLED'
+43 -1
View File
@@ -2,7 +2,7 @@ import copy
import time import time
import warnings import warnings
from functools import partial from functools import partial
from typing import Any from typing import Any, List, Tuple
from openhands.core.config import LLMConfig from openhands.core.config import LLMConfig
@@ -26,6 +26,7 @@ from openhands.core.message import Message
from openhands.core.metrics import Metrics from openhands.core.metrics import Metrics
from openhands.llm.debug_mixin import DebugMixin from openhands.llm.debug_mixin import DebugMixin
from openhands.llm.retry_mixin import RetryMixin from openhands.llm.retry_mixin import RetryMixin
from openhands.llm.llm_router import LLMRouter
__all__ = ['LLM'] __all__ = ['LLM']
@@ -77,6 +78,11 @@ class LLM(RetryMixin, DebugMixin):
# list of LLM completions (for logging purposes). Each completion is a dict with the following keys: # list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
# - 'messages': list of messages # - 'messages': list of messages
# - 'response': response from the LLM # - 'response': response from the LLM
if self.config.llm_router_enabled:
self.router = LLMRouter(config, metrics)
else:
self.router = None
self.llm_completions: list[dict[str, Any]] = [] self.llm_completions: list[dict[str, Any]] = []
# litellm actually uses base Exception here for unknown model # litellm actually uses base Exception here for unknown model
@@ -123,6 +129,7 @@ class LLM(RetryMixin, DebugMixin):
litellm_completion, litellm_completion,
model=self.config.model, model=self.config.model,
api_key=self.config.api_key, api_key=self.config.api_key,
base_url=self.config.base_url, base_url=self.config.base_url,
api_version=self.config.api_version, api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider, custom_llm_provider=self.config.custom_llm_provider,
@@ -173,6 +180,7 @@ class LLM(RetryMixin, DebugMixin):
if not messages: if not messages:
raise ValueError( raise ValueError(
'The messages list is empty. At least one message is required.' 'The messages list is empty. At least one message is required.'
) )
# log the entire LLM prompt # log the entire LLM prompt
@@ -211,6 +219,40 @@ class LLM(RetryMixin, DebugMixin):
self._completion = wrapper self._completion = wrapper
def complete(
    self,
    messages: List[Message],
    **kwargs: Any,
) -> Tuple[str, float]:
    """Complete the given messages using the routed model or the default model.

    Args:
        messages: Conversation to send. Each item must expose ``role`` and
            ``content`` attributes.
        **kwargs: Forwarded unchanged to the router, or to the wrapped
            completion call when no router is configured.

    Returns:
        A ``(text, latency)`` tuple: the text of the first choice and the
        wall-clock seconds spent inside this call.
    """
    start_time = time.time()
    if self.router:
        # The router reports its own latency; it is discarded so the value
        # returned here also covers the model-selection round trip.
        response, _ = self.router.complete(messages, **kwargs)
    else:
        response = self._completion(
            messages=[{"role": msg.role, "content": msg.content} for msg in messages],
            **kwargs
        )
    latency = time.time() - start_time

    # The router's ``complete`` is annotated to return the completion text
    # itself, while the direct path yields a response object. Indexing
    # ``.choices`` on a plain string would raise AttributeError, so accept both.
    if isinstance(response, str):
        return response, latency
    return response.choices[0].message.content, latency
def stream(
    self,
    messages: List[Message],
    **kwargs: Any,
):
    """Yield response chunks from the routed model or the default model.

    Args:
        messages: Conversation to send. Each item must expose ``role`` and
            ``content`` attributes.
        **kwargs: Forwarded unchanged to the router, or to the wrapped
            completion call when no router is configured.

    Yields:
        Whatever the underlying stream produces, passed through untouched.
    """
    if not self.router:
        # No router configured: serialise the messages ourselves and ask the
        # wrapped completion function for a streaming response.
        payload = [
            {"role": item.role, "content": item.content} for item in messages
        ]
        yield from self._completion(messages=payload, stream=True, **kwargs)
        return

    yield from self.router.stream(messages, **kwargs)
@property @property
def completion(self): def completion(self):
"""Decorator for the litellm completion function. """Decorator for the litellm completion function.
+65
View File
@@ -0,0 +1,65 @@
import os
from typing import List, Tuple, Any
from openhands.core.config import LLMConfig
from openhands.llm.llm import LLM
from openhands.core.message import Message
from openhands.core.metrics import Metrics
class LLMRouter(LLM):
    """LLM that asks NotDiamond which of ``config.llm_providers`` should answer each request.

    NOTE(review): this module imports ``LLM`` from ``openhands.llm.llm`` at
    import time. If that module also imports ``LLMRouter`` at top level, the
    two form an import cycle; confirm that one side defers its import.
    """

    def __init__(
        self,
        config: LLMConfig,
        metrics: Metrics | None = None,
    ):
        """Initialise the base LLM and the NotDiamond client.

        Args:
            config: LLM configuration; ``config.llm_providers`` lists the
                candidate models offered to NotDiamond.
            metrics: Optional metrics object passed to the base LLM.

        Raises:
            ValueError: If ``NOTDIAMOND_API_KEY`` is not set in the environment.
        """
        import copy

        # The router is itself an LLM, and the base initialiser builds a
        # router whenever ``llm_router_enabled`` is set. Hand it a copy with
        # routing switched off so construction does not recurse forever; the
        # caller's config object is left untouched.
        base_config = copy.copy(config)
        base_config.llm_router_enabled = False
        super().__init__(base_config, metrics)

        self.llm_providers: List[str] = config.llm_providers
        self.notdiamond_api_key = os.environ.get("NOTDIAMOND_API_KEY")
        if not self.notdiamond_api_key:
            raise ValueError("NOTDIAMOND_API_KEY environment variable is not set")

        # Imported lazily so the dependency is only needed when routing is used.
        from notdiamond import NotDiamond

        self.client = NotDiamond()

    def _select_model(self, messages: List[Message]) -> Tuple[str, Any]:
        """Select the best model for the given messages.

        Returns:
            A ``(model_name, session_id)`` tuple taken from NotDiamond's reply.
        """
        formatted_messages = [
            {"role": msg.role, "content": msg.content} for msg in messages
        ]
        session_id, provider = self.client.chat.completions.model_select(
            messages=formatted_messages,
            model=self.llm_providers,
        )
        return provider.model, session_id

    def _llm_for(self, model: str) -> LLM:
        """Create a plain (non-routing) LLM for ``model`` that shares this router's metrics."""
        return LLM(
            config=LLMConfig(model=model, llm_router_enabled=False),
            metrics=self.metrics,
        )

    def complete(
        self,
        messages: List[Message],
        **kwargs: Any,
    ) -> Tuple[str, float]:
        """Complete the given messages using the best selected model.

        Args:
            messages: Conversation to send.
            **kwargs: Forwarded to the selected LLM's ``complete``.

        Returns:
            The ``(response, latency)`` pair produced by the selected LLM.
        """
        selected_model, _ = self._select_model(messages)
        response, latency = self._llm_for(selected_model).complete(messages, **kwargs)
        return response, latency

    def stream(
        self,
        messages: List[Message],
        **kwargs: Any,
    ):
        """Stream the response using the best selected model.

        Args:
            messages: Conversation to send.
            **kwargs: Forwarded to the selected LLM's ``stream``.

        Yields:
            Chunks produced by the selected LLM's stream, unchanged.
        """
        selected_model, _ = self._select_model(messages)
        yield from self._llm_for(selected_model).stream(messages, **kwargs)
+54
View File
@@ -0,0 +1,54 @@
import pytest
from unittest.mock import Mock, patch
from openhands.core.config import LLMConfig
from openhands.core.message import Message
from openhands.llm.llm import LLM
from openhands.llm.llm_router import LLMRouter
@pytest.fixture
def mock_notdiamond():
    """Provide a stand-in for ``notdiamond.NotDiamond`` and a fake API key.

    The router imports ``NotDiamond`` from the ``notdiamond`` package inside its
    ``__init__``, so there is no ``NotDiamond`` attribute on the router module
    to patch. Instead the whole package is stubbed in ``sys.modules``, which
    the function-level import then resolves to. The router also refuses to
    start without ``NOTDIAMOND_API_KEY``, so a dummy value is set for the
    duration of the test.

    Yields:
        The mock that takes the place of the ``NotDiamond`` class.
    """
    fake_module = Mock()
    with patch.dict('sys.modules', {'notdiamond': fake_module}):
        with patch.dict('os.environ', {'NOTDIAMOND_API_KEY': 'test-key'}):
            yield fake_module.NotDiamond
def test_llm_router_enabled(mock_notdiamond):
    """With routing enabled, ``LLM.complete`` delegates to the router."""
    config = LLMConfig(
        model="test-model",
        llm_router_enabled=True,
        llm_providers=["model1", "model2"],
    )
    llm = LLM(config)
    assert isinstance(llm.router, LLMRouter)

    messages = [Message(role="user", content="Hello")]

    # A plain Mock is not subscriptable, so ``choices`` has to be a real list
    # before ``choices[0]`` can be configured.
    mock_response = Mock()
    mock_response.choices = [Mock()]
    mock_response.choices[0].message.content = "Hello, how can I help you?"
    llm.router.complete = Mock(return_value=(mock_response, 0.5))

    response, latency = llm.complete(messages)

    assert response == "Hello, how can I help you?"
    assert isinstance(latency, float)
    llm.router.complete.assert_called_once_with(messages)
def test_llm_router_disabled():
    """With routing disabled, ``LLM.complete`` calls the wrapped completion directly."""
    config = LLMConfig(
        model="test-model",
        llm_router_enabled=False,
    )
    llm = LLM(config)
    assert llm.router is None

    messages = [Message(role="user", content="Hello")]
    with patch.object(llm, '_completion') as mock_completion:
        # A plain Mock is not subscriptable, so ``choices`` has to be a real
        # list before ``choices[0]`` can be configured.
        mock_response = Mock()
        mock_response.choices = [Mock()]
        mock_response.choices[0].message.content = "Hello, how can I help you?"
        mock_completion.return_value = mock_response

        response, latency = llm.complete(messages)

        assert response == "Hello, how can I help you?"
        assert isinstance(latency, float)
        mock_completion.assert_called_once()