Files
home-llm/tests/llama_conversation/test_ai_task.py
2025-12-14 12:35:41 -05:00

145 lines
4.5 KiB
Python

"""Tests for AI Task extraction behavior."""
from typing import Any, cast
import pytest
import voluptuous as vol
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from custom_components.llama_conversation.ai_task import (
LocalLLMTaskEntity,
ResultExtractionMethod,
)
from custom_components.llama_conversation.const import (
CONF_AI_TASK_EXTRACTION_METHOD,
)
from custom_components.llama_conversation.entity import TextGenerationResult
class DummyGenTask:
    """Minimal stand-in for the generate-data task handed to the entity.

    All constructor arguments are keyword-only. ``attachments`` falls back to
    a fresh empty list whenever a falsy value (including ``None``) is given.
    """

    def __init__(
        self,
        *,
        name="task",
        instructions="do",
        attachments=None,
        structure=None,
    ):
        # Identity / prompt text.
        self.name, self.instructions = name, instructions
        # Falsy input -> new list, so separate instances never share one.
        self.attachments = attachments if attachments else []
        # Tests in this module pass a voluptuous schema here (or leave it None).
        self.structure = structure
class DummyChatLog:
    """Minimal stand-in for the chat log passed to ``_async_generate_data``.

    Tests read ``llm_api`` afterwards to check whether the entity attached
    an LLM API to the log.
    """

    def __init__(self, content=None):
        # Any falsy content (including None) is replaced by a fresh list.
        self.content = content if content else []
        self.conversation_id = "conv-id"
        # Starts unset; the entity under test may assign it.
        self.llm_api = None
class DummyClient:
    """Fake LLM client whose generation call always yields a canned result."""

    def __init__(self, result: TextGenerationResult):
        # Returned verbatim by every _generate call.
        self._result = result

    def _generate_system_prompt(self, prompt_template, llm_api, entity_options):
        """Echo the template back untouched; nothing is rendered."""
        return prompt_template

    def _supports_vision(self, _options):  # pragma: no cover - not needed for tests
        """Always report that vision input is unsupported."""
        return False

    async def _generate(self, _messages, _llm_api, _entity_id, _options):
        """Ignore every argument and hand back the construction-time result."""
        return self._result
class DummyTaskEntity(LocalLLMTaskEntity):
    """``LocalLLMTaskEntity`` assembled by hand for unit tests.

    ``__init__`` deliberately does not call ``super().__init__()``, so no
    config entry or subentry has to be constructed. Instead, a fixed set of
    attributes is assigned directly.
    """

    def __init__(self, hass, client, options):
        # Parent __init__ intentionally skipped (avoids ConfigEntry/Subentry plumbing).
        self.hass = hass
        self.client = client
        # Served back through the runtime_options property below.
        self._runtime_options = options
        # Fixed identifiers for the fake entity.
        self.entity_id = "ai_task.test"
        self.entry_id = "entry"
        self.subentry_id = "subentry"
        self._attr_supported_features = 0

    @property
    def runtime_options(self):
        """Return the options mapping supplied at construction time."""
        return self._runtime_options
@pytest.mark.asyncio
async def test_no_extraction_returns_raw_text(hass):
    """With extraction method NONE the raw response text is the result."""
    client = DummyClient(TextGenerationResult(response="raw text"))
    options = {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.NONE}
    entity = DummyTaskEntity(hass, client, options)
    log = DummyChatLog()

    outcome = await entity._async_generate_data(
        cast(Any, DummyGenTask()),
        cast(Any, log),
    )

    assert outcome.data == "raw text"
    # In this mode the chat log's llm_api is expected to stay unset.
    assert log.llm_api is None
@pytest.mark.asyncio
async def test_structured_output_success(hass):
    """A JSON response is decoded into the structured result in STRUCTURED_OUTPUT mode."""
    extraction = {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT}
    entity = DummyTaskEntity(
        hass,
        DummyClient(TextGenerationResult(response='{"foo": 1}')),
        extraction,
    )
    log = DummyChatLog()
    schema = vol.Schema({"foo": int})

    outcome = await entity._async_generate_data(
        cast(Any, DummyGenTask(structure=schema)),
        cast(Any, log),
    )

    assert outcome.data == {"foo": 1}
@pytest.mark.asyncio
async def test_structured_output_invalid_json_raises(hass):
    """A non-JSON response in STRUCTURED_OUTPUT mode surfaces as HomeAssistantError."""
    client = DummyClient(TextGenerationResult(response="not-json"))
    entity = DummyTaskEntity(
        hass,
        client,
        {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.STRUCTURED_OUTPUT},
    )
    log = DummyChatLog()
    task = DummyGenTask(structure=vol.Schema({"foo": int}))

    with pytest.raises(HomeAssistantError):
        await entity._async_generate_data(cast(Any, task), cast(Any, log))
@pytest.mark.asyncio
async def test_tool_extraction_success(hass):
    """In TOOL mode the returned tool call's arguments become the result."""
    submitted = llm.ToolInput("submit_response", {"value": 42})
    generation = TextGenerationResult(response="", tool_calls=[submitted])
    entity = DummyTaskEntity(
        hass,
        DummyClient(generation),
        {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
    )
    log = DummyChatLog()

    outcome = await entity._async_generate_data(
        cast(Any, DummyGenTask(structure=vol.Schema({"value": int}))),
        cast(Any, log),
    )

    assert outcome.data == {"value": 42}
    # Tool mode is expected to attach an LLM API to the chat log.
    assert log.llm_api is not None
@pytest.mark.asyncio
async def test_tool_extraction_missing_tool_args_raises(hass):
    """A tool call whose arguments are ``None`` is rejected with HomeAssistantError.

    The fake tool call carries the same ``tool_name`` as the success test
    (``"submit_response"``), mirroring the shape of ``llm.ToolInput``. If the
    implementation selects tool calls by name, the error is then attributable
    to the missing arguments rather than to "no matching tool call".
    """

    class DummyToolCall:
        """Tool-call stand-in that, unlike ``llm.ToolInput``, allows ``tool_args=None``."""

        def __init__(self, tool_name="submit_response", tool_args=None):
            self.tool_name = tool_name
            self.tool_args = tool_args

    tool_call = DummyToolCall(tool_args=None)
    entity = DummyTaskEntity(
        hass,
        DummyClient(TextGenerationResult(response="", tool_calls=cast(Any, [tool_call]))),
        {CONF_AI_TASK_EXTRACTION_METHOD: ResultExtractionMethod.TOOL},
    )
    chat_log = DummyChatLog()
    task = DummyGenTask(structure=vol.Schema({"value": int}))

    with pytest.raises(HomeAssistantError):
        await entity._async_generate_data(cast(Any, task), cast(Any, chat_log))