# Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-08 22:18:10 -05:00).
# Commit summary: Adds support for streaming tool call events with test coverage,
# fixes tool-stream ID tracking (including OpenAI-style tracking for Azure),
# improves Gemini tool calling + streaming tests, adds Anthropic tests,
# generates Azure cassettes, and fixes Azure cassette URIs.
# File size: 324 lines, 12 KiB, Python.
"""Tests for tool call streaming events across LLM providers.

These tests verify that when streaming is enabled and the LLM makes a tool call,
the stream chunk events include proper tool call information with
call_type=LLMCallType.TOOL_CALL.
"""
|
|
|
|
from collections.abc import Iterator
from typing import Any
from unittest.mock import MagicMock, patch

import pytest

from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent, ToolCall
from crewai.llm import LLM
|
|
|
|
|
|
@pytest.fixture
def get_temperature_tool_schema() -> dict[str, Any]:
    """Provide the function-calling schema for a temperature lookup tool.

    Returns:
        A ``{"type": "function", "function": {...}}`` mapping describing
        ``get_current_temperature``, which takes a single required string
        parameter named ``city``.
    """
    city_property = {
        "type": "string",
        "description": "The name of the city to get the temperature for.",
    }
    parameters = {
        "type": "object",
        "properties": {"city": city_property},
        "required": ["city"],
    }
    function_spec = {
        "name": "get_current_temperature",
        "description": "Get the current temperature in a city.",
        "parameters": parameters,
    }
    return {"type": "function", "function": function_spec}
|
|
|
|
|
|
@pytest.fixture
def mock_emit() -> Iterator[MagicMock]:
    """Patch ``CrewAIEventsBus.emit`` for the duration of a test.

    Yields:
        The mock that replaces ``emit`` while the patch is active. Its
        ``call_args_list`` records every call made to ``emit`` during the
        test, which the helper functions in this module inspect.
    """
    # Local import: the bus class is only needed to install the patch.
    from crewai.events.event_bus import CrewAIEventsBus

    # The context manager restores the real ``emit`` once the test finishes.
    with patch.object(CrewAIEventsBus, "emit") as patched_emit:
        yield patched_emit
|
|
|
|
|
|
def get_tool_call_events(mock_emit: MagicMock) -> list[LLMStreamChunkEvent]:
    """Extract tool call streaming events from mock emit calls.

    Args:
        mock_emit: The mock that replaced the event bus ``emit`` method.

    Returns:
        The recorded ``LLMStreamChunkEvent`` objects whose ``call_type`` is
        ``LLMCallType.TOOL_CALL``, in the order they were emitted.
    """
    # Reuse the generic extractor so both helpers agree on how an event is
    # located inside a recorded call; only the call_type filter lives here.
    return [
        event
        for event in get_all_stream_events(mock_emit)
        if event.call_type == LLMCallType.TOOL_CALL
    ]
|
|
|
|
|
|
def get_all_stream_events(mock_emit: MagicMock) -> list[LLMStreamChunkEvent]:
    """Extract all streaming events from mock emit calls.

    Args:
        mock_emit: The mock that replaced the event bus ``emit`` method.

    Returns:
        Every ``LLMStreamChunkEvent`` found in the recorded calls, in call
        order. At most one event is taken from each call.
    """
    stream_events = []
    for call in mock_emit.call_args_list:
        # Preferred form: emit(..., event=<event>).
        event = call.kwargs.get("event")
        if event is None:
            # Fall back to positional arguments so events are not silently
            # missed when ``emit`` is invoked without the ``event=`` keyword.
            event = next(
                (arg for arg in call.args if isinstance(arg, LLMStreamChunkEvent)),
                None,
            )
        if isinstance(event, LLMStreamChunkEvent):
            stream_events.append(event)
    return stream_events
|
|
|
|
|
|
class TestOpenAIToolCallStreaming:
    """Tool call streaming checks for the OpenAI provider."""

    @pytest.mark.vcr()
    def test_openai_streaming_emits_tool_call_events(
        self, get_temperature_tool_schema: dict[str, Any], mock_emit: MagicMock
    ) -> None:
        """Streaming an OpenAI tool call emits chunk events tagged TOOL_CALL."""
        prompt = [
            {"role": "user", "content": "What is the temperature in San Francisco?"},
        ]
        functions = {
            "get_current_temperature": lambda city: f"The temperature in {city} is 72°F"
        }

        LLM(model="openai/gpt-4o-mini", stream=True).call(
            messages=prompt,
            tools=[get_temperature_tool_schema],
            available_functions=functions,
        )

        streamed = get_tool_call_events(mock_emit)
        assert streamed, "Should receive tool call streaming events"

        # Only the first tool call chunk is inspected in detail.
        head = streamed[0]
        assert head.call_type == LLMCallType.TOOL_CALL

        tool_call = head.tool_call
        assert tool_call is not None
        assert isinstance(tool_call, ToolCall)
        assert tool_call.function is not None
        assert tool_call.function.name == "get_current_temperature"
        assert tool_call.type == "function"
        assert tool_call.index >= 0
|
|
|
|
|
|
class TestToolCallStreamingEventStructure:
    """Checks on the shape and content of streamed tool call events."""

    @pytest.mark.vcr()
    def test_tool_call_event_accumulates_arguments(
        self, get_temperature_tool_schema: dict[str, Any], mock_emit: MagicMock
    ) -> None:
        """A streamed tool call arrives as several chunks, each with a function payload."""
        prompt = [
            {"role": "user", "content": "What is the temperature in San Francisco?"},
        ]
        functions = {
            "get_current_temperature": lambda city: f"The temperature in {city} is 72°F"
        }

        LLM(model="openai/gpt-4o-mini", stream=True).call(
            messages=prompt,
            tools=[get_temperature_tool_schema],
            available_functions=functions,
        )

        chunks = get_tool_call_events(mock_emit)
        assert len(chunks) >= 2, "Should receive multiple tool call streaming events"

        for chunk in chunks:
            payload = chunk.tool_call
            assert payload is not None
            assert payload.function is not None

    @pytest.mark.vcr()
    def test_tool_call_events_have_consistent_tool_id(
        self, get_temperature_tool_schema: dict[str, Any], mock_emit: MagicMock
    ) -> None:
        """Chunks that share a tool call index must also share one tool call ID."""
        prompt = [
            {"role": "user", "content": "What is the temperature in San Francisco?"},
        ]
        functions = {
            "get_current_temperature": lambda city: f"The temperature in {city} is 72°F"
        }

        LLM(model="openai/gpt-4o-mini", stream=True).call(
            messages=prompt,
            tools=[get_temperature_tool_schema],
            available_functions=functions,
        )

        chunks = get_tool_call_events(mock_emit)
        assert len(chunks) >= 1, "Should receive tool call streaming events"

        # A single chunk cannot disagree with itself; nothing more to compare.
        if len(chunks) == 1:
            return

        # Collect the distinct non-empty IDs seen for each tool call index.
        ids_by_index: dict[int, set[str]] = {}
        for chunk in chunks:
            payload = chunk.tool_call
            if payload is None:
                continue
            seen = ids_by_index.setdefault(payload.index, set())
            if payload.id:
                seen.add(payload.id)

        for idx, seen in ids_by_index.items():
            # An index whose chunks carried no ID at all is not checked.
            if not seen:
                continue
            assert len(seen) == 1, f"Tool call ID should be consistent for index {idx}"
|
|
|
|
|
|
class TestMixedStreamingEvents:
    """Checks for streams that may carry both text chunks and tool call chunks."""

    @pytest.mark.vcr()
    def test_streaming_distinguishes_text_and_tool_calls(
        self, get_temperature_tool_schema: dict[str, Any], mock_emit: MagicMock
    ) -> None:
        """Every event reported as a tool call carries TOOL_CALL type and a payload."""
        prompt = [
            {"role": "user", "content": "What is the temperature in San Francisco?"},
        ]
        functions = {
            "get_current_temperature": lambda city: f"The temperature in {city} is 72°F"
        }

        LLM(model="openai/gpt-4o-mini", stream=True).call(
            messages=prompt,
            tools=[get_temperature_tool_schema],
            available_functions=functions,
        )

        every_chunk = get_all_stream_events(mock_emit)
        tool_chunks = get_tool_call_events(mock_emit)

        assert every_chunk, "Should receive streaming events"

        for chunk in tool_chunks:
            assert chunk.call_type == LLMCallType.TOOL_CALL
            assert chunk.tool_call is not None
|
|
|
|
|
|
class TestGeminiToolCallStreaming:
    """Tool call streaming checks for the Gemini provider."""

    @pytest.mark.vcr()
    def test_gemini_streaming_emits_tool_call_events(
        self, get_temperature_tool_schema: dict[str, Any], mock_emit: MagicMock
    ) -> None:
        """Streaming a Gemini tool call emits chunk events tagged TOOL_CALL."""
        prompt = [
            {"role": "user", "content": "What is the temperature in San Francisco?"},
        ]
        functions = {
            "get_current_temperature": lambda city: f"The temperature in {city} is 72°F"
        }

        LLM(model="gemini/gemini-2.0-flash", stream=True).call(
            messages=prompt,
            tools=[get_temperature_tool_schema],
            available_functions=functions,
        )

        streamed = get_tool_call_events(mock_emit)
        assert streamed, "Should receive tool call streaming events"

        # Only the first tool call chunk is inspected in detail.
        head = streamed[0]
        assert head.call_type == LLMCallType.TOOL_CALL

        tool_call = head.tool_call
        assert tool_call is not None
        assert isinstance(tool_call, ToolCall)
        assert tool_call.function is not None
        assert tool_call.function.name == "get_current_temperature"
        assert tool_call.type == "function"

    @pytest.mark.vcr()
    def test_gemini_streaming_multiple_tool_calls_unique_ids(
        self, get_temperature_tool_schema: dict[str, Any], mock_emit: MagicMock
    ) -> None:
        """A prompt about two cities yields tool call events with at least two distinct IDs."""
        prompt = [
            {"role": "user", "content": "What is the temperature in Paris and London?"},
        ]
        functions = {
            "get_current_temperature": lambda city: f"The temperature in {city} is 72°F"
        }

        LLM(model="gemini/gemini-2.0-flash", stream=True).call(
            messages=prompt,
            tools=[get_temperature_tool_schema],
            available_functions=functions,
        )

        streamed = get_tool_call_events(mock_emit)
        assert len(streamed) >= 2, "Should receive at least 2 tool call events"

        # Events without a payload or with an empty ID do not count.
        distinct_ids = {
            chunk.tool_call.id
            for chunk in streamed
            if chunk.tool_call is not None and chunk.tool_call.id
        }
        assert len(distinct_ids) >= 2, "Each tool call should have a unique ID"
|
|
|
|
|
|
class TestAzureToolCallStreaming:
    """Tool call streaming checks for the Azure provider."""

    @pytest.mark.vcr()
    def test_azure_streaming_emits_tool_call_events(
        self, get_temperature_tool_schema: dict[str, Any], mock_emit: MagicMock
    ) -> None:
        """Streaming an Azure tool call emits chunk events tagged TOOL_CALL."""
        prompt = [
            {"role": "user", "content": "What is the temperature in San Francisco?"},
        ]
        functions = {
            "get_current_temperature": lambda city: f"The temperature in {city} is 72°F"
        }

        LLM(model="azure/gpt-4o-mini", stream=True).call(
            messages=prompt,
            tools=[get_temperature_tool_schema],
            available_functions=functions,
        )

        streamed = get_tool_call_events(mock_emit)
        assert streamed, "Should receive tool call streaming events"

        # Only the first tool call chunk is inspected in detail.
        head = streamed[0]
        assert head.call_type == LLMCallType.TOOL_CALL

        tool_call = head.tool_call
        assert tool_call is not None
        assert isinstance(tool_call, ToolCall)
        assert tool_call.function is not None
        assert tool_call.function.name == "get_current_temperature"
        assert tool_call.type == "function"
|
|
|
|
|
|
class TestAnthropicToolCallStreaming:
    """Tool call streaming checks for the Anthropic provider."""

    @pytest.mark.vcr()
    def test_anthropic_streaming_emits_tool_call_events(
        self, get_temperature_tool_schema: dict[str, Any], mock_emit: MagicMock
    ) -> None:
        """Streaming an Anthropic tool call emits chunk events tagged TOOL_CALL."""
        prompt = [
            {"role": "user", "content": "What is the temperature in San Francisco?"},
        ]
        functions = {
            "get_current_temperature": lambda city: f"The temperature in {city} is 72°F"
        }

        LLM(model="anthropic/claude-3-5-haiku-latest", stream=True).call(
            messages=prompt,
            tools=[get_temperature_tool_schema],
            available_functions=functions,
        )

        streamed = get_tool_call_events(mock_emit)
        assert streamed, "Should receive tool call streaming events"

        # Only the first tool call chunk is inspected in detail.
        head = streamed[0]
        assert head.call_type == LLMCallType.TOOL_CALL

        tool_call = head.tool_call
        assert tool_call is not None
        assert isinstance(tool_call, ToolCall)
        assert tool_call.function is not None
        assert tool_call.function.name == "get_current_temperature"
        assert tool_call.type == "function"