# Mirror of https://github.com/acon96/home-llm.git
# Synced 2026-01-08 21:28:05 -05:00
# 145 lines, 4.5 KiB, Python
"""Tests for AI Task extraction behavior."""
|
|
|
|
from typing import Any, cast
|
|
|
|
import pytest
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.exceptions import HomeAssistantError
|
|
from homeassistant.helpers import llm
|
|
|
|
from custom_components.llama_conversation.ai_task import (
|
|
LocalLLMTaskEntity,
|
|
ResultExtractionMethod,
|
|
)
|
|
from custom_components.llama_conversation.const import (
|
|
CONF_AI_TASK_EXTRACTION_METHOD,
|
|
)
|
|
from custom_components.llama_conversation.entity import TextGenerationResult
|
|
|
|
|
|
class DummyGenTask:
    """Lightweight stand-in for an AI Task generate-data request.

    Only carries the four attributes the tests populate: ``name``,
    ``instructions``, ``attachments`` and ``structure``. All constructor
    arguments are keyword-only.
    """

    def __init__(self, *, name="task", instructions="do", attachments=None, structure=None):
        self.name, self.instructions = name, instructions
        # Falsy input (None or an empty sequence) becomes a fresh empty list.
        self.attachments = [] if not attachments else attachments
        self.structure = structure
|
|
|
|
|
|
class DummyChatLog:
    """Minimal chat-log double with ``content``, ``conversation_id`` and ``llm_api``."""

    def __init__(self, content=None):
        self.content = [] if not content else content
        self.conversation_id = "conv-id"
        # Starts unset; tests inspect whether the entity assigned an LLM API.
        self.llm_api = None
|
|
|
|
class DummyClient:
    """Fake generation client whose ``_generate`` always yields one canned result."""

    def __init__(self, result: TextGenerationResult):
        self._result = result

    def _generate_system_prompt(self, prompt_template, llm_api, entity_options):
        """Return the template as-is; no rendering is performed."""
        return prompt_template

    def _supports_vision(self, _options):  # pragma: no cover - not needed for tests
        """Always report that vision input is unsupported."""
        return False

    async def _generate(self, _messages, _llm_api, _entity_id, _options):
        """Ignore every argument and hand back the canned result."""
        canned = self._result
        return canned
|
|
|
|
|
|
class DummyTaskEntity(LocalLLMTaskEntity):
    """LocalLLMTaskEntity test double wired to a fake client and fixed options.

    The parent ``__init__`` is intentionally not called, so no real
    ConfigEntry/Subentry is required; a minimal set of attributes is
    assigned directly instead.
    """

    def __init__(self, hass, client, options):
        # Skip the parent initializer to avoid ConfigEntry/Subentry plumbing.
        self.hass = hass
        self.client = client
        self._runtime_options = options
        self.entity_id, self.entry_id, self.subentry_id = (
            "ai_task.test",
            "entry",
            "subentry",
        )
        self._attr_supported_features = 0

    @property
    def runtime_options(self):
        """Return the options mapping given to the constructor."""
        return self._runtime_options
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_no_extraction_returns_raw_text(hass):
    """With extraction disabled, the model's raw text is returned verbatim."""
    client = DummyClient(TextGenerationResult(response="raw text"))
    options = {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.NONE}
    entity = DummyTaskEntity(hass, client, options)
    log = DummyChatLog()

    generated = await entity._async_generate_data(
        cast(Any, DummyGenTask()), cast(Any, log)
    )

    assert generated.data == "raw text"
    # The chat log should be left without an LLM API in this mode.
    assert log.llm_api is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_structured_output_success(hass):
    """Structured-output mode turns a JSON response into the requested structure."""
    schema = vol.Schema({"foo": int})
    entity = DummyTaskEntity(
        hass,
        DummyClient(TextGenerationResult(response='{"foo": 1}')),
        {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT},
    )

    outcome = await entity._async_generate_data(
        cast(Any, DummyGenTask(structure=schema)),
        cast(Any, DummyChatLog()),
    )

    assert outcome.data == {"foo": 1}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_structured_output_invalid_json_raises(hass):
    """A non-JSON response in structured-output mode raises HomeAssistantError."""
    bad_client = DummyClient(TextGenerationResult(response="not-json"))
    options = {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT}
    entity = DummyTaskEntity(hass, bad_client, options)
    log = DummyChatLog()
    request = DummyGenTask(structure=vol.Schema({"foo": int}))

    with pytest.raises(HomeAssistantError):
        await entity._async_generate_data(cast(Any, request), cast(Any, log))
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_tool_extraction_success(hass):
    """Tool-extraction mode returns the submit_response tool arguments as data."""
    submitted = llm.ToolInput("submit_response", {"value": 42})
    client = DummyClient(TextGenerationResult(response="", tool_calls=[submitted]))
    entity = DummyTaskEntity(
        hass, client, {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL}
    )
    log = DummyChatLog()

    outcome = await entity._async_generate_data(
        cast(Any, DummyGenTask(structure=vol.Schema({"value": int}))),
        cast(Any, log),
    )

    assert outcome.data == {"value": 42}
    # The entity is expected to attach an LLM API to the chat log in this mode.
    assert log.llm_api is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_tool_extraction_missing_tool_args_raises(hass):
    """A submit_response tool call with no arguments must raise HomeAssistantError.

    The tool call is a real ``llm.ToolInput`` with the same tool name the
    success test uses, so only ``tool_args`` is defective. This ensures the
    test exercises the missing-arguments path rather than tripping over an
    attribute (``tool_name``, ``id``) that an ad-hoc stand-in would lack.
    """
    # ToolInput performs no runtime type validation, so None can be passed
    # to simulate a model that emitted the tool call without arguments.
    tool_call = llm.ToolInput("submit_response", cast(Any, None))
    entity = DummyTaskEntity(
        hass,
        DummyClient(TextGenerationResult(response="", tool_calls=[tool_call])),
        {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
    )
    chat_log = DummyChatLog()
    task = DummyGenTask(structure=vol.Schema({"value": int}))

    with pytest.raises(HomeAssistantError):
        await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))
|
|
|