Compare commits

...

5 Commits

4 changed files with 125 additions and 4 deletions
+4
View File
@@ -20,6 +20,9 @@ class AsyncLLM(LLM):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
# Apply extra headers from env if defined
_extra_headers = self._get_extra_headers()
self._async_completion = partial(
self._call_acompletion,
model=self.config.model,
@@ -35,6 +38,7 @@ class AsyncLLM(LLM):
top_p=self.config.top_p,
drop_params=self.config.drop_params,
seed=self.config.seed,
**({'extra_headers': _extra_headers} if _extra_headers is not None else {}),
)
async_completion_unwrapped = self._async_completion
+27 -3
View File
@@ -1,4 +1,5 @@
import copy
import json as _json
import os
import time
import warnings
@@ -58,6 +59,24 @@ class LLM(RetryMixin, DebugMixin):
config: an LLMConfig object specifying the configuration of the LLM.
"""
def _get_extra_headers(self) -> dict[str, Any] | None:
"""Read and validate extra headers from LLM_EXTRA_HEADERS env.
Returns a dict if valid JSON object, otherwise None.
"""
_extra_headers_env = os.getenv('LLM_EXTRA_HEADERS')
if not _extra_headers_env:
return None
try:
_extra_headers = _json.loads(_extra_headers_env)
if not isinstance(_extra_headers, dict):
logger.warning('LLM_EXTRA_HEADERS must be a JSON object; ignoring')
return None
return _extra_headers
except Exception as _e:
logger.warning(f'Failed parsing LLM_EXTRA_HEADERS: {_e}')
return None
def __init__(
self,
config: LLMConfig,
@@ -201,12 +220,17 @@ class LLM(RetryMixin, DebugMixin):
if self.config.completion_kwargs is not None:
kwargs.update(self.config.completion_kwargs)
# Apply extra headers from env if defined and not already provided via completion_kwargs
_extra_headers = self._get_extra_headers()
if _extra_headers is not None and 'extra_headers' not in kwargs:
kwargs['extra_headers'] = _extra_headers
self._completion = partial(
litellm_completion,
model=self.config.model,
api_key=self.config.api_key.get_secret_value()
if self.config.api_key
else None,
api_key=(
self.config.api_key.get_secret_value() if self.config.api_key else None
),
base_url=self.config.base_url,
api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider,
+5 -1
View File
@@ -14,6 +14,9 @@ class StreamingLLM(AsyncLLM):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
# Apply extra headers from env if defined
_extra_headers = self._get_extra_headers()
self._async_streaming_completion = partial(
self._call_acompletion,
model=self.config.model,
@@ -28,7 +31,8 @@ class StreamingLLM(AsyncLLM):
temperature=self.config.temperature,
top_p=self.config.top_p,
drop_params=self.config.drop_params,
stream=True, # Ensure streaming is enabled
stream=True,
**({'extra_headers': _extra_headers} if _extra_headers is not None else {}),
)
async_streaming_completion_unwrapped = self._async_streaming_completion
+89
View File
@@ -0,0 +1,89 @@
import asyncio
import json
from unittest.mock import AsyncMock, patch
import pytest
from openhands.core.config import LLMConfig
from openhands.llm.async_llm import AsyncLLM
from openhands.llm.llm import LLM
from openhands.llm.streaming_llm import StreamingLLM
@pytest.fixture
def extra_headers_env(monkeypatch):
headers = {
'editor-version': 'vscode/1.85.1',
'Copilot-Integration-Id': 'vscode-chat',
}
monkeypatch.setenv('LLM_EXTRA_HEADERS', json.dumps(headers))
return headers
def test_llm_passes_extra_headers_to_litellm_completion(extra_headers_env):
cfg = LLMConfig(model='gpt-4o', api_key='test_key')
with patch('openhands.llm.llm.litellm_completion') as mock_completion:
def _side_effect(*args, **kwargs):
assert 'extra_headers' in kwargs
assert kwargs['extra_headers'] == extra_headers_env
# minimal response structure expected by wrapper
return {'id': 'resp-1', 'choices': [{'message': {'content': 'ok'}}]}
mock_completion.side_effect = _side_effect
llm = LLM(cfg, service_id='svc')
resp = llm.completion(messages=[{'role': 'user', 'content': 'hi'}])
assert resp['choices'][0]['message']['content'] == 'ok'
@pytest.mark.asyncio
async def test_async_llm_passes_extra_headers_to_litellm_acompletion(extra_headers_env):
cfg = LLMConfig(model='gpt-4o', api_key='test_key')
async def _async_side_effect(*args, **kwargs):
assert 'extra_headers' in kwargs
assert kwargs['extra_headers'] == extra_headers_env
return {'id': 'resp-2', 'choices': [{'message': {'content': 'ok'}}]}
with patch(
'openhands.llm.async_llm.litellm_acompletion',
new=AsyncMock(side_effect=_async_side_effect),
):
llm = AsyncLLM(cfg, service_id='svc')
resp = await llm.async_completion(
messages=[{'role': 'user', 'content': 'hi'}], stream=False
)
assert resp['choices'][0]['message']['content'] == 'ok'
@pytest.mark.asyncio
async def test_streaming_llm_passes_extra_headers_to_litellm_acompletion(
extra_headers_env,
):
cfg = LLMConfig(model='gpt-4o', api_key='test_key')
async def _gen():
for chunk in [
{'choices': [{'delta': {'content': 'hello'}}]},
{'choices': [{'delta': {'content': ' world'}}]},
]:
yield chunk
await asyncio.sleep(0)
async def _async_side_effect(*args, **kwargs):
assert 'extra_headers' in kwargs
assert kwargs['extra_headers'] == extra_headers_env
return _gen()
with patch(
'openhands.llm.async_llm.litellm_acompletion',
new=AsyncMock(side_effect=_async_side_effect),
):
llm = StreamingLLM(cfg, service_id='svc')
collected = []
async for chunk in llm.async_streaming_completion(
messages=[{'role': 'user', 'content': 'hi'}], stream=True
):
collected.append(chunk['choices'][0]['delta']['content'])
assert ''.join(collected) == 'hello world'