mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4fb3b0dcd5 |
@@ -47,3 +47,5 @@ class ConfigType(str, Enum):
|
|||||||
WORKSPACE_MOUNT_PATH = 'WORKSPACE_MOUNT_PATH'
|
WORKSPACE_MOUNT_PATH = 'WORKSPACE_MOUNT_PATH'
|
||||||
WORKSPACE_MOUNT_PATH_IN_SANDBOX = 'WORKSPACE_MOUNT_PATH_IN_SANDBOX'
|
WORKSPACE_MOUNT_PATH_IN_SANDBOX = 'WORKSPACE_MOUNT_PATH_IN_SANDBOX'
|
||||||
WORKSPACE_MOUNT_REWRITE = 'WORKSPACE_MOUNT_REWRITE'
|
WORKSPACE_MOUNT_REWRITE = 'WORKSPACE_MOUNT_REWRITE'
|
||||||
|
LLM_PROVIDERS = 'LLM_PROVIDERS'
|
||||||
|
LLM_ROUTER_ENABLED = 'LLM_ROUTER_ENABLED'
|
||||||
|
|||||||
+43
-1
@@ -2,7 +2,7 @@ import copy
|
|||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any
|
from typing import Any, List, Tuple
|
||||||
|
|
||||||
from openhands.core.config import LLMConfig
|
from openhands.core.config import LLMConfig
|
||||||
|
|
||||||
@@ -26,6 +26,7 @@ from openhands.core.message import Message
|
|||||||
from openhands.core.metrics import Metrics
|
from openhands.core.metrics import Metrics
|
||||||
from openhands.llm.debug_mixin import DebugMixin
|
from openhands.llm.debug_mixin import DebugMixin
|
||||||
from openhands.llm.retry_mixin import RetryMixin
|
from openhands.llm.retry_mixin import RetryMixin
|
||||||
|
from openhands.llm.llm_router import LLMRouter
|
||||||
|
|
||||||
__all__ = ['LLM']
|
__all__ = ['LLM']
|
||||||
|
|
||||||
@@ -77,6 +78,11 @@ class LLM(RetryMixin, DebugMixin):
|
|||||||
# list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
|
# list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
|
||||||
# - 'messages': list of messages
|
# - 'messages': list of messages
|
||||||
# - 'response': response from the LLM
|
# - 'response': response from the LLM
|
||||||
|
|
||||||
|
if self.config.llm_router_enabled:
|
||||||
|
self.router = LLMRouter(config, metrics)
|
||||||
|
else:
|
||||||
|
self.router = None
|
||||||
self.llm_completions: list[dict[str, Any]] = []
|
self.llm_completions: list[dict[str, Any]] = []
|
||||||
|
|
||||||
# litellm actually uses base Exception here for unknown model
|
# litellm actually uses base Exception here for unknown model
|
||||||
@@ -123,6 +129,7 @@ class LLM(RetryMixin, DebugMixin):
|
|||||||
litellm_completion,
|
litellm_completion,
|
||||||
model=self.config.model,
|
model=self.config.model,
|
||||||
api_key=self.config.api_key,
|
api_key=self.config.api_key,
|
||||||
|
|
||||||
base_url=self.config.base_url,
|
base_url=self.config.base_url,
|
||||||
api_version=self.config.api_version,
|
api_version=self.config.api_version,
|
||||||
custom_llm_provider=self.config.custom_llm_provider,
|
custom_llm_provider=self.config.custom_llm_provider,
|
||||||
@@ -173,6 +180,7 @@ class LLM(RetryMixin, DebugMixin):
|
|||||||
if not messages:
|
if not messages:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'The messages list is empty. At least one message is required.'
|
'The messages list is empty. At least one message is required.'
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# log the entire LLM prompt
|
# log the entire LLM prompt
|
||||||
@@ -211,6 +219,40 @@ class LLM(RetryMixin, DebugMixin):
|
|||||||
|
|
||||||
self._completion = wrapper
|
self._completion = wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def complete(
    self,
    messages: List[Message],
    **kwargs: Any,
) -> Tuple[str, float]:
    """Complete the given messages using the router, if any, or the default model.

    Args:
        messages: Conversation to send; each item must expose ``role`` and
            ``content`` attributes.
        **kwargs: Extra keyword arguments forwarded unchanged to the router or
            to the underlying completion callable.

    Returns:
        A ``(text, latency)`` tuple: the text of the first choice and the
        wall-clock seconds spent in this call (model selection included when a
        router is used).
    """
    start_time = time.time()

    if self.router:
        # The latency reported by the router is discarded on purpose: the
        # value returned below covers the whole call, selection included.
        response, _ = self.router.complete(messages, **kwargs)
    else:
        response = self._completion(
            messages=[{"role": msg.role, "content": msg.content} for msg in messages],
            **kwargs,
        )

    latency = time.time() - start_time

    # LLMRouter.complete forwards the (text, latency) tuple produced by the
    # selected LLM's own complete(), so ``response`` may already be plain text.
    # Only response objects need the choices/message unwrapping.
    if isinstance(response, str):
        return response, latency
    return response.choices[0].message.content, latency
|
||||||
|
|
||||||
|
def stream(
    self,
    messages: List[Message],
    **kwargs: Any,
):
    """Yield response chunks from the router when present, else from the default model.

    Args:
        messages: Conversation to send; each item must expose ``role`` and
            ``content`` attributes.
        **kwargs: Extra keyword arguments forwarded unchanged to the router or
            to the underlying completion callable.

    Yields:
        Whatever the router's ``stream`` or the completion callable (invoked
        with ``stream=True``) produces, item by item.
    """
    if not self.router:
        # No router configured: convert the messages to plain dicts and
        # stream straight from the default completion callable.
        payload = [{"role": item.role, "content": item.content} for item in messages]
        yield from self._completion(messages=payload, stream=True, **kwargs)
        return

    yield from self.router.stream(messages, **kwargs)
|
||||||
@property
|
@property
|
||||||
def completion(self):
|
def completion(self):
|
||||||
"""Decorator for the litellm completion function.
|
"""Decorator for the litellm completion function.
|
||||||
|
|||||||
@@ -0,0 +1,65 @@
|
|||||||
|
|
||||||
|
import os
|
||||||
|
from typing import List, Tuple, Any
|
||||||
|
from openhands.core.config import LLMConfig
|
||||||
|
from openhands.llm.llm import LLM
|
||||||
|
from openhands.core.message import Message
|
||||||
|
from openhands.core.metrics import Metrics
|
||||||
|
|
||||||
|
class LLMRouter(LLM):
    """LLM that asks Not Diamond to pick a model per request, then delegates to it.

    NOTE(review): this module imports ``LLM`` at import time while the LLM
    module imports ``LLMRouter`` at import time; that circular import has to
    be resolved at file level and is not addressed by this class.
    """

    def __init__(
        self,
        config: LLMConfig,
        metrics: Metrics | None = None,
    ):
        """Initialise the base LLM and the Not Diamond client.

        Args:
            config: LLM configuration; ``config.llm_providers`` lists the
                candidate models offered to the selector.
            metrics: Optional metrics object passed through to the base class.

        Raises:
            ValueError: If the ``NOTDIAMOND_API_KEY`` environment variable is
                unset or empty.
        """
        # Function-scope import so this block does not depend on the file's
        # import section.
        import copy

        # The LLM constructor creates an LLMRouter whenever
        # ``config.llm_router_enabled`` is truthy. Handing it the caller's
        # config unchanged would therefore re-enter this constructor without
        # end, so the base class gets a shallow copy with the flag cleared.
        # The caller's config object is left untouched.
        base_config = copy.copy(config)
        base_config.llm_router_enabled = False
        super().__init__(base_config, metrics)

        self.llm_providers: List[str] = config.llm_providers
        self.notdiamond_api_key = os.environ.get("NOTDIAMOND_API_KEY")
        if not self.notdiamond_api_key:
            raise ValueError("NOTDIAMOND_API_KEY environment variable is not set")

        # Imported lazily so the dependency is only needed when routing is used.
        from notdiamond import NotDiamond

        self.client = NotDiamond()

    def _select_model(self, messages: List[Message]) -> Tuple[str, Any]:
        """Ask the selector which candidate model should handle ``messages``.

        Returns:
            ``(model, session_id)`` where ``model`` is ``provider.model`` of
            the chosen provider and ``session_id`` is the identifier returned
            alongside it by ``model_select``.
        """
        formatted_messages = [
            {"role": msg.role, "content": msg.content} for msg in messages
        ]
        session_id, provider = self.client.chat.completions.model_select(
            messages=formatted_messages,
            model=self.llm_providers,
        )
        return provider.model, session_id

    def _llm_for_messages(self, messages: List[Message]) -> LLM:
        """Build a plain LLM (sharing this router's metrics) for the selected model."""
        selected_model, _ = self._select_model(messages)
        # A fresh config carrying only the model name, as before; all other
        # settings take LLMConfig's defaults.
        selected_config = LLMConfig(model=selected_model)
        return LLM(config=selected_config, metrics=self.metrics)

    def complete(
        self,
        messages: List[Message],
        **kwargs: Any,
    ) -> Tuple[str, float]:
        """Complete the given messages using the best selected model.

        Returns:
            The ``(text, latency)`` tuple produced by the selected LLM's
            ``complete``.
        """
        selected_llm = self._llm_for_messages(messages)
        response, latency = selected_llm.complete(messages, **kwargs)
        return response, latency

    def stream(
        self,
        messages: List[Message],
        **kwargs: Any,
    ):
        """Stream the response using the best selected model.

        Yields:
            Each item produced by the selected LLM's ``stream``.
        """
        selected_llm = self._llm_for_messages(messages)
        yield from selected_llm.stream(messages, **kwargs)
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
from openhands.core.config import LLMConfig
|
||||||
|
from openhands.core.message import Message
|
||||||
|
from openhands.llm.llm import LLM
|
||||||
|
from openhands.llm.llm_router import LLMRouter
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_notdiamond():
    """Provide a stand-in for the ``notdiamond`` package and yield its ``NotDiamond`` mock.

    ``LLMRouter.__init__`` imports ``NotDiamond`` from ``notdiamond`` inside the
    constructor, so the name never exists as an attribute of
    ``openhands.llm.llm_router`` and cannot be patched there. Replacing the
    module entry in ``sys.modules`` makes that late import resolve to the mock,
    whether or not the real package is installed.
    """
    fake_module = Mock()
    with patch.dict('sys.modules', {'notdiamond': fake_module}):
        yield fake_module.NotDiamond
|
||||||
|
|
||||||
|
def test_llm_router_enabled(mock_notdiamond, monkeypatch):
    """With routing enabled, ``LLM.complete`` must delegate to the router."""
    # LLMRouter raises ValueError when this variable is missing, so set it
    # explicitly instead of relying on the machine running the tests.
    monkeypatch.setenv("NOTDIAMOND_API_KEY", "test-key")

    config = LLMConfig(
        model="test-model",
        llm_router_enabled=True,
        llm_providers=["model1", "model2"],
    )
    llm = LLM(config)

    assert isinstance(llm.router, LLMRouter)

    messages = [Message(role="user", content="Hello")]
    # The stub returns a response-like object because LLM.complete reads
    # ``.choices[0].message.content`` from whatever the router hands back.
    mock_response = Mock()
    mock_response.choices[0].message.content = "Hello, how can I help you?"
    llm.router.complete = Mock(return_value=(mock_response, 0.5))

    response, latency = llm.complete(messages)

    assert response == "Hello, how can I help you?"
    assert isinstance(latency, float)
    llm.router.complete.assert_called_once_with(messages)
|
||||||
|
|
||||||
|
def test_llm_router_disabled():
    """With routing disabled, ``LLM.complete`` must use the default completion path."""
    llm = LLM(LLMConfig(model="test-model", llm_router_enabled=False))

    assert llm.router is None

    messages = [Message(role="user", content="Hello")]

    # Canned response exposing the choices/message/content chain that
    # LLM.complete unwraps.
    canned = Mock()
    canned.choices[0].message.content = "Hello, how can I help you?"

    with patch.object(llm, '_completion') as fake_completion:
        fake_completion.return_value = canned
        text, elapsed = llm.complete(messages)

        assert text == "Hello, how can I help you?"
        assert isinstance(elapsed, float)
        fake_completion.assert_called_once()
|
||||||
Reference in New Issue
Block a user