crewAI/lib/crewai/tests/llms/test_tool_call_streaming.py
Greyson LaLonde f8deb0fd18 feat: add streaming tool call events; fix provider id tracking; add tests and cassettes
Adds support for streaming tool call events with test coverage, fixes tool-stream ID tracking (including OpenAI-style tracking for Azure), improves Gemini tool calling + streaming tests, adds Anthropic tests, generates Azure cassettes, and fixes Azure cassette URIs.
2026-01-05 14:33:36 -05:00


"""Tests for tool call streaming events across LLM providers.
These tests verify that when streaming is enabled and the LLM makes a tool call,
the stream chunk events include proper tool call information with
call_type=LLMCallType.TOOL_CALL.
"""
from collections.abc import Generator
from typing import Any
from unittest.mock import MagicMock, patch

import pytest

from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent, ToolCall
from crewai.llm import LLM


@pytest.fixture
def get_temperature_tool_schema() -> dict[str, Any]:
"""Create a temperature tool schema for native function calling."""
return {
"type": "function",
"function": {
"name": "get_current_temperature",
"description": "Get the current temperature in a city.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The name of the city to get the temperature for.",
}
},
"required": ["city"],
},
},
}
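
# For reference: with this schema, an OpenAI-style stream delivers the call as
# incremental deltas, roughly like the following (a sketch, not captured
# output; fragmenting and field presence vary by provider):
#
#     {"index": 0, "id": "call_abc123", "type": "function",
#      "function": {"name": "get_current_temperature", "arguments": ""}}
#     {"index": 0, "function": {"arguments": "{\"city\": \"San"}}
#     {"index": 0, "function": {"arguments": " Francisco\"}"}}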


@pytest.fixture
def mock_emit() -> Generator[MagicMock, None, None]:
    """Patch CrewAIEventsBus.emit and yield the mock for inspection.

    Patching at the class level captures emissions from every bus instance
    used during a test.
    """
    from crewai.events.event_bus import CrewAIEventsBus

    with patch.object(CrewAIEventsBus, "emit") as mock:
        yield mock


def get_tool_call_events(mock_emit: MagicMock) -> list[LLMStreamChunkEvent]:
    """Extract tool call streaming events from mock emit calls."""
    tool_call_events = []
    for call in mock_emit.call_args_list:
        # The event is expected as the `event` keyword argument of emit().
        event = call[1].get("event") if len(call) > 1 else None
        if isinstance(event, LLMStreamChunkEvent) and event.call_type == LLMCallType.TOOL_CALL:
            tool_call_events.append(event)
    return tool_call_events


def get_all_stream_events(mock_emit: MagicMock) -> list[LLMStreamChunkEvent]:
"""Extract all streaming events from mock emit calls."""
stream_events = []
for call in mock_emit.call_args_list:
event = call[1].get("event") if len(call) > 1 else None
if isinstance(event, LLMStreamChunkEvent):
stream_events.append(event)
return stream_events
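

# A sketch (not used by the tests) of how a consumer could reassemble the
# streamed fragments into complete argument strings, assuming each event
# carries an OpenAI-style incremental ``function.arguments`` fragment rather
# than the full accumulated string:
def _reassemble_tool_arguments(
    events: list[LLMStreamChunkEvent],
) -> dict[int, str]:
    """Concatenate streamed argument fragments per tool-call index."""
    buffers: dict[int, str] = {}
    for evt in events:
        if evt.tool_call is not None and evt.tool_call.function is not None:
            idx = evt.tool_call.index
            fragment = evt.tool_call.function.arguments or ""
            buffers[idx] = buffers.get(idx, "") + fragment
    return buffers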


class TestOpenAIToolCallStreaming:
    """Tests for OpenAI provider tool call streaming events."""

    @pytest.mark.vcr()
def test_openai_streaming_emits_tool_call_events(
self, get_temperature_tool_schema: dict[str, Any], mock_emit: MagicMock
) -> None:
"""Test that OpenAI streaming emits tool call events with correct call_type."""
llm = LLM(model="openai/gpt-4o-mini", stream=True)
llm.call(
messages=[
{"role": "user", "content": "What is the temperature in San Francisco?"},
],
tools=[get_temperature_tool_schema],
available_functions={
"get_current_temperature": lambda city: f"The temperature in {city} is 72°F"
},
)
tool_call_events = get_tool_call_events(mock_emit)
assert len(tool_call_events) > 0, "Should receive tool call streaming events"
first_tool_call_event = tool_call_events[0]
assert first_tool_call_event.call_type == LLMCallType.TOOL_CALL
assert first_tool_call_event.tool_call is not None
assert isinstance(first_tool_call_event.tool_call, ToolCall)
assert first_tool_call_event.tool_call.function is not None
assert first_tool_call_event.tool_call.function.name == "get_current_temperature"
assert first_tool_call_event.tool_call.type == "function"
assert first_tool_call_event.tool_call.index >= 0


class TestToolCallStreamingEventStructure:
    """Tests for the structure and content of tool call streaming events."""

    @pytest.mark.vcr()
def test_tool_call_event_accumulates_arguments(
self, get_temperature_tool_schema: dict[str, Any], mock_emit: MagicMock
) -> None:
"""Test that tool call events accumulate arguments progressively."""
llm = LLM(model="openai/gpt-4o-mini", stream=True)
llm.call(
messages=[
{"role": "user", "content": "What is the temperature in San Francisco?"},
],
tools=[get_temperature_tool_schema],
available_functions={
"get_current_temperature": lambda city: f"The temperature in {city} is 72°F"
},
)
tool_call_events = get_tool_call_events(mock_emit)
assert len(tool_call_events) >= 2, "Should receive multiple tool call streaming events"
for evt in tool_call_events:
assert evt.tool_call is not None
assert evt.tool_call.function is not None

    @pytest.mark.vcr()
def test_tool_call_events_have_consistent_tool_id(
self, get_temperature_tool_schema: dict[str, Any], mock_emit: MagicMock
) -> None:
"""Test that all events for the same tool call have the same tool ID."""
llm = LLM(model="openai/gpt-4o-mini", stream=True)
llm.call(
messages=[
{"role": "user", "content": "What is the temperature in San Francisco?"},
],
tools=[get_temperature_tool_schema],
available_functions={
"get_current_temperature": lambda city: f"The temperature in {city} is 72°F"
},
)
tool_call_events = get_tool_call_events(mock_emit)
assert len(tool_call_events) >= 1, "Should receive tool call streaming events"
        if len(tool_call_events) > 1:
            events_by_index: dict[int, list[LLMStreamChunkEvent]] = {}
            for evt in tool_call_events:
                if evt.tool_call is not None:
                    events_by_index.setdefault(evt.tool_call.index, []).append(evt)
            for idx, evts in events_by_index.items():
                # Providers may send the call ID only on the first delta for a
                # given index, so compare whichever IDs are present.
                ids = [
                    e.tool_call.id
                    for e in evts
                    if e.tool_call is not None and e.tool_call.id
                ]
                if ids:
                    assert len(set(ids)) == 1, (
                        f"Tool call ID should be consistent for index {idx}"
                    )


class TestMixedStreamingEvents:
    """Tests for scenarios with both text and tool call streaming events."""

    @pytest.mark.vcr()
def test_streaming_distinguishes_text_and_tool_calls(
self, get_temperature_tool_schema: dict[str, Any], mock_emit: MagicMock
) -> None:
"""Test that streaming correctly distinguishes between text chunks and tool calls."""
llm = LLM(model="openai/gpt-4o-mini", stream=True)
llm.call(
messages=[
{"role": "user", "content": "What is the temperature in San Francisco?"},
],
tools=[get_temperature_tool_schema],
available_functions={
"get_current_temperature": lambda city: f"The temperature in {city} is 72°F"
},
)
all_events = get_all_stream_events(mock_emit)
tool_call_events = get_tool_call_events(mock_emit)
assert len(all_events) >= 1, "Should receive streaming events"
        # Tool-call chunks must be tagged; remaining events are plain text.
        for event in tool_call_events:
            assert event.call_type == LLMCallType.TOOL_CALL
            assert event.tool_call is not None


class TestGeminiToolCallStreaming:
    """Tests for Gemini provider tool call streaming events."""

    @pytest.mark.vcr()
def test_gemini_streaming_emits_tool_call_events(
self, get_temperature_tool_schema: dict[str, Any], mock_emit: MagicMock
) -> None:
"""Test that Gemini streaming emits tool call events with correct call_type."""
llm = LLM(model="gemini/gemini-2.0-flash", stream=True)
llm.call(
messages=[
{"role": "user", "content": "What is the temperature in San Francisco?"},
],
tools=[get_temperature_tool_schema],
available_functions={
"get_current_temperature": lambda city: f"The temperature in {city} is 72°F"
},
)
tool_call_events = get_tool_call_events(mock_emit)
assert len(tool_call_events) > 0, "Should receive tool call streaming events"
first_tool_call_event = tool_call_events[0]
assert first_tool_call_event.call_type == LLMCallType.TOOL_CALL
assert first_tool_call_event.tool_call is not None
assert isinstance(first_tool_call_event.tool_call, ToolCall)
assert first_tool_call_event.tool_call.function is not None
assert first_tool_call_event.tool_call.function.name == "get_current_temperature"
assert first_tool_call_event.tool_call.type == "function"

    @pytest.mark.vcr()
def test_gemini_streaming_multiple_tool_calls_unique_ids(
self, get_temperature_tool_schema: dict[str, Any], mock_emit: MagicMock
) -> None:
"""Test that Gemini streaming assigns unique IDs to multiple tool calls."""
llm = LLM(model="gemini/gemini-2.0-flash", stream=True)
llm.call(
messages=[
{"role": "user", "content": "What is the temperature in Paris and London?"},
],
tools=[get_temperature_tool_schema],
available_functions={
"get_current_temperature": lambda city: f"The temperature in {city} is 72°F"
},
)
tool_call_events = get_tool_call_events(mock_emit)
assert len(tool_call_events) >= 2, "Should receive at least 2 tool call events"
tool_ids = [
evt.tool_call.id
for evt in tool_call_events
if evt.tool_call is not None and evt.tool_call.id
]
assert len(set(tool_ids)) >= 2, "Each tool call should have a unique ID"


class TestAzureToolCallStreaming:
    """Tests for Azure provider tool call streaming events."""

    @pytest.mark.vcr()
def test_azure_streaming_emits_tool_call_events(
self, get_temperature_tool_schema: dict[str, Any], mock_emit: MagicMock
) -> None:
"""Test that Azure streaming emits tool call events with correct call_type."""
llm = LLM(model="azure/gpt-4o-mini", stream=True)
llm.call(
messages=[
{"role": "user", "content": "What is the temperature in San Francisco?"},
],
tools=[get_temperature_tool_schema],
available_functions={
"get_current_temperature": lambda city: f"The temperature in {city} is 72°F"
},
)
tool_call_events = get_tool_call_events(mock_emit)
assert len(tool_call_events) > 0, "Should receive tool call streaming events"
first_tool_call_event = tool_call_events[0]
assert first_tool_call_event.call_type == LLMCallType.TOOL_CALL
assert first_tool_call_event.tool_call is not None
assert isinstance(first_tool_call_event.tool_call, ToolCall)
assert first_tool_call_event.tool_call.function is not None
assert first_tool_call_event.tool_call.function.name == "get_current_temperature"
assert first_tool_call_event.tool_call.type == "function"


class TestAnthropicToolCallStreaming:
    """Tests for Anthropic provider tool call streaming events."""

    @pytest.mark.vcr()
def test_anthropic_streaming_emits_tool_call_events(
self, get_temperature_tool_schema: dict[str, Any], mock_emit: MagicMock
) -> None:
"""Test that Anthropic streaming emits tool call events with correct call_type."""
llm = LLM(model="anthropic/claude-3-5-haiku-latest", stream=True)
llm.call(
messages=[
{"role": "user", "content": "What is the temperature in San Francisco?"},
],
tools=[get_temperature_tool_schema],
available_functions={
"get_current_temperature": lambda city: f"The temperature in {city} is 72°F"
},
)
tool_call_events = get_tool_call_events(mock_emit)
assert len(tool_call_events) > 0, "Should receive tool call streaming events"
first_tool_call_event = tool_call_events[0]
assert first_tool_call_event.call_type == LLMCallType.TOOL_CALL
assert first_tool_call_event.tool_call is not None
assert isinstance(first_tool_call_event.tool_call, ToolCall)
assert first_tool_call_event.tool_call.function is not None
assert first_tool_call_event.tool_call.function.name == "get_current_temperature"
assert first_tool_call_event.tool_call.type == "function"