mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 15:17:59 -05:00
feat(blocks): Improve JSON generation+parsing in AI Structured Response block (#10960)
The AI Structured Response Generator block currently doesn't support responses that aren't pure JSON. This prohibits multi-step prompting because reasoning content is not allowed in the response, which in turn limits performance. ### Changes 🏗️ - Adjust prompt to enclose JSON in pre-defined tags so we can extract it from a response that isn't pure JSON - Adjust mechanism to extract and parse JSON - Add `force_json_output` input (advanced, default `False`) - Update incorrect `max_output_tokens` values for Claude 4 and 3.7 to prevent responses from being cut off due to `max_tokens` ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] LLMs correctly follows response generation instructions - [x] LLMs follow system response format instructions even if user prompt contains conflicting instructions - [x] JSON is extracted from response successfully - [x] `force_json_output` works (at least for models that support it) Tested with Claude 4 Sonnet, various GPT models, and Llama 3.3 70B. --------- Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
This commit is contained in:
committed by
GitHub
parent
7bd571d9ce
commit
0b267f573e
@@ -1,5 +1,9 @@
|
||||
# This file contains a lot of prompt block strings that would trigger "line too long"
|
||||
# flake8: noqa: E501
|
||||
import ast
|
||||
import logging
|
||||
import re
|
||||
import secrets
|
||||
from abc import ABC
|
||||
from enum import Enum, EnumMeta
|
||||
from json import JSONDecodeError
|
||||
@@ -204,13 +208,13 @@ MODEL_METADATA = {
|
||||
"anthropic", 200000, 32000
|
||||
), # claude-opus-4-1-20250805
|
||||
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 8192
|
||||
"anthropic", 200000, 32000
|
||||
), # claude-4-opus-20250514
|
||||
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 8192
|
||||
"anthropic", 200000, 64000
|
||||
), # claude-4-sonnet-20250514
|
||||
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 8192
|
||||
"anthropic", 200000, 64000
|
||||
), # claude-3-7-sonnet-20250219
|
||||
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 8192
|
||||
@@ -382,7 +386,9 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
|
||||
return None
|
||||
|
||||
|
||||
def get_parallel_tool_calls_param(llm_model: LlmModel, parallel_tool_calls):
|
||||
def get_parallel_tool_calls_param(
|
||||
llm_model: LlmModel, parallel_tool_calls: bool | None
|
||||
):
|
||||
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
||||
if llm_model.startswith("o") or parallel_tool_calls is None:
|
||||
return openai.NOT_GIVEN
|
||||
@@ -393,8 +399,8 @@ async def llm_call(
|
||||
credentials: APIKeyCredentials,
|
||||
llm_model: LlmModel,
|
||||
prompt: list[dict],
|
||||
json_format: bool,
|
||||
max_tokens: int | None,
|
||||
force_json_output: bool = False,
|
||||
tools: list[dict] | None = None,
|
||||
ollama_host: str = "localhost:11434",
|
||||
parallel_tool_calls=None,
|
||||
@@ -407,7 +413,7 @@ async def llm_call(
|
||||
credentials: The API key credentials to use.
|
||||
llm_model: The LLM model to use.
|
||||
prompt: The prompt to send to the LLM.
|
||||
json_format: Whether the response should be in JSON format.
|
||||
force_json_output: Whether the response should be in JSON format.
|
||||
max_tokens: The maximum number of tokens to generate in the chat completion.
|
||||
tools: The tools to use in the chat completion.
|
||||
ollama_host: The host for ollama to use.
|
||||
@@ -446,7 +452,7 @@ async def llm_call(
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
if json_format:
|
||||
if force_json_output:
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
response = await oai_client.chat.completions.create(
|
||||
@@ -559,7 +565,7 @@ async def llm_call(
|
||||
raise ValueError("Groq does not support tools.")
|
||||
|
||||
client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = {"type": "json_object"} if json_format else None
|
||||
response_format = {"type": "json_object"} if force_json_output else None
|
||||
response = await client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
@@ -717,7 +723,7 @@ async def llm_call(
|
||||
)
|
||||
|
||||
response_format = None
|
||||
if json_format:
|
||||
if force_json_output:
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
parallel_tool_calls_param = get_parallel_tool_calls_param(
|
||||
@@ -780,6 +786,17 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
force_json_output: bool = SchemaField(
|
||||
title="Restrict LLM to pure JSON output",
|
||||
default=False,
|
||||
description=(
|
||||
"Whether to force the LLM to produce a JSON-only response. "
|
||||
"This can increase the block's reliability, "
|
||||
"but may also reduce the quality of the response "
|
||||
"because it prohibits the LLM from reasoning "
|
||||
"before providing its JSON response."
|
||||
),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
sys_prompt: str = SchemaField(
|
||||
title="System Prompt",
|
||||
@@ -848,17 +865,18 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
"llm_call": lambda *args, **kwargs: LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[""],
|
||||
response=json.dumps(
|
||||
{
|
||||
"key1": "key1Value",
|
||||
"key2": "key2Value",
|
||||
}
|
||||
response=(
|
||||
'<json_output id="test123456">{\n'
|
||||
' "key1": "key1Value",\n'
|
||||
' "key2": "key2Value"\n'
|
||||
"}</json_output>"
|
||||
),
|
||||
tool_calls=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
reasoning=None,
|
||||
)
|
||||
),
|
||||
"get_collision_proof_output_tag_id": lambda *args: "test123456",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -867,9 +885,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
credentials: APIKeyCredentials,
|
||||
llm_model: LlmModel,
|
||||
prompt: list[dict],
|
||||
json_format: bool,
|
||||
compress_prompt_to_fit: bool,
|
||||
max_tokens: int | None,
|
||||
force_json_output: bool = False,
|
||||
compress_prompt_to_fit: bool = True,
|
||||
tools: list[dict] | None = None,
|
||||
ollama_host: str = "localhost:11434",
|
||||
) -> LLMResponse:
|
||||
@@ -882,8 +900,8 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
credentials=credentials,
|
||||
llm_model=llm_model,
|
||||
prompt=prompt,
|
||||
json_format=json_format,
|
||||
max_tokens=max_tokens,
|
||||
force_json_output=force_json_output,
|
||||
tools=tools,
|
||||
ollama_host=ollama_host,
|
||||
compress_prompt_to_fit=compress_prompt_to_fit,
|
||||
@@ -895,11 +913,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
logger.debug(f"Calling LLM with input data: {input_data}")
|
||||
prompt = [json.to_dict(p) for p in input_data.conversation_history]
|
||||
|
||||
def trim_prompt(s: str) -> str:
|
||||
"""Removes indentation up to and including `|` from a multi-line prompt."""
|
||||
lines = s.strip().split("\n")
|
||||
return "\n".join([line.strip().lstrip("|") for line in lines])
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
input_data.prompt = fmt.format_string(input_data.prompt, values)
|
||||
@@ -908,28 +921,15 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
if input_data.sys_prompt:
|
||||
prompt.append({"role": "system", "content": input_data.sys_prompt})
|
||||
|
||||
# Use a one-time unique tag to prevent collisions with user/LLM content
|
||||
output_tag_id = self.get_collision_proof_output_tag_id()
|
||||
output_tag_start = f'<json_output id="{output_tag_id}">'
|
||||
if input_data.expected_format:
|
||||
expected_format = [
|
||||
f"{json.dumps(k)}: {json.dumps(v)}"
|
||||
for k, v in input_data.expected_format.items()
|
||||
]
|
||||
if input_data.list_result:
|
||||
format_prompt = (
|
||||
f'"results": [\n {{\n {", ".join(expected_format)}\n }}\n]'
|
||||
)
|
||||
else:
|
||||
format_prompt = ",\n| ".join(expected_format)
|
||||
|
||||
sys_prompt = trim_prompt(
|
||||
f"""
|
||||
|Reply with pure JSON strictly following this JSON format:
|
||||
|{{
|
||||
| {format_prompt}
|
||||
|}}
|
||||
|
|
||||
|Ensure the response is valid JSON. DO NOT include any additional text (e.g. markdown code block fences) outside of the JSON.
|
||||
|If you cannot provide all the keys, provide an empty string for the values you cannot answer.
|
||||
"""
|
||||
sys_prompt = self.response_format_instructions(
|
||||
input_data.expected_format,
|
||||
list_mode=input_data.list_result,
|
||||
pure_json_mode=input_data.force_json_output,
|
||||
output_tag_start=output_tag_start,
|
||||
)
|
||||
prompt.append({"role": "system", "content": sys_prompt})
|
||||
|
||||
@@ -947,18 +947,21 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
except JSONDecodeError as e:
|
||||
return f"JSON decode error: {e}"
|
||||
|
||||
logger.debug(f"LLM request: {prompt}")
|
||||
error_feedback_message = ""
|
||||
llm_model = input_data.model
|
||||
|
||||
for retry_count in range(input_data.retry):
|
||||
logger.debug(f"LLM request: {prompt}")
|
||||
try:
|
||||
llm_response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=llm_model,
|
||||
prompt=prompt,
|
||||
compress_prompt_to_fit=input_data.compress_prompt_to_fit,
|
||||
json_format=bool(input_data.expected_format),
|
||||
force_json_output=(
|
||||
input_data.force_json_output
|
||||
and bool(input_data.expected_format)
|
||||
),
|
||||
ollama_host=input_data.ollama_host,
|
||||
max_tokens=input_data.max_tokens,
|
||||
)
|
||||
@@ -973,30 +976,52 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
if input_data.expected_format:
|
||||
try:
|
||||
response_obj = json.loads(response_text)
|
||||
except JSONDecodeError as json_error:
|
||||
response_obj = self.get_json_from_response(
|
||||
response_text,
|
||||
pure_json_mode=input_data.force_json_output,
|
||||
output_tag_start=output_tag_start,
|
||||
)
|
||||
except (ValueError, JSONDecodeError) as parse_error:
|
||||
censored_response = re.sub(r"[A-Za-z0-9]", "*", response_text)
|
||||
response_snippet = (
|
||||
f"{censored_response[:50]}...{censored_response[-30:]}"
|
||||
)
|
||||
logger.warning(
|
||||
f"Error getting JSON from LLM response: {parse_error}\n\n"
|
||||
f"Response start+end: `{response_snippet}`"
|
||||
)
|
||||
prompt.append({"role": "assistant", "content": response_text})
|
||||
|
||||
indented_json_error = str(json_error).replace("\n", "\n|")
|
||||
error_feedback_message = trim_prompt(
|
||||
f"""
|
||||
|Your previous response could not be parsed as valid JSON:
|
||||
|
|
||||
|{indented_json_error}
|
||||
|
|
||||
|Please provide a valid JSON response that matches the expected format.
|
||||
"""
|
||||
error_feedback_message = self.invalid_response_feedback(
|
||||
parse_error,
|
||||
was_parseable=False,
|
||||
list_mode=input_data.list_result,
|
||||
pure_json_mode=input_data.force_json_output,
|
||||
output_tag_start=output_tag_start,
|
||||
)
|
||||
prompt.append(
|
||||
{"role": "user", "content": error_feedback_message}
|
||||
)
|
||||
continue
|
||||
|
||||
# Handle object response for `force_json_output`+`list_result`
|
||||
if input_data.list_result and isinstance(response_obj, dict):
|
||||
if "results" in response_obj:
|
||||
response_obj = response_obj.get("results", [])
|
||||
elif len(response_obj) == 1:
|
||||
response_obj = list(response_obj.values())
|
||||
if "results" in response_obj and isinstance(
|
||||
response_obj["results"], list
|
||||
):
|
||||
response_obj = response_obj["results"]
|
||||
else:
|
||||
error_feedback_message = (
|
||||
"Expected an array of objects in the 'results' key, "
|
||||
f"but got: {response_obj}"
|
||||
)
|
||||
prompt.append(
|
||||
{"role": "assistant", "content": response_text}
|
||||
)
|
||||
prompt.append(
|
||||
{"role": "user", "content": error_feedback_message}
|
||||
)
|
||||
continue
|
||||
|
||||
validation_errors = "\n".join(
|
||||
[
|
||||
@@ -1022,12 +1047,12 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
return
|
||||
|
||||
prompt.append({"role": "assistant", "content": response_text})
|
||||
error_feedback_message = trim_prompt(
|
||||
f"""
|
||||
|Your response did not match the expected format:
|
||||
|
|
||||
|{validation_errors}
|
||||
"""
|
||||
error_feedback_message = self.invalid_response_feedback(
|
||||
validation_errors,
|
||||
was_parseable=True,
|
||||
list_mode=input_data.list_result,
|
||||
pure_json_mode=input_data.force_json_output,
|
||||
output_tag_start=output_tag_start,
|
||||
)
|
||||
prompt.append({"role": "user", "content": error_feedback_message})
|
||||
else:
|
||||
@@ -1059,6 +1084,127 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
raise RuntimeError(error_feedback_message)
|
||||
|
||||
def response_format_instructions(
|
||||
self,
|
||||
expected_object_format: dict[str, str],
|
||||
*,
|
||||
list_mode: bool,
|
||||
pure_json_mode: bool,
|
||||
output_tag_start: str,
|
||||
) -> str:
|
||||
expected_output_format = json.dumps(expected_object_format, indent=2)
|
||||
output_type = "object" if not list_mode else "array"
|
||||
outer_output_type = "object" if pure_json_mode else output_type
|
||||
|
||||
if output_type == "array":
|
||||
indented_obj_format = expected_output_format.replace("\n", "\n ")
|
||||
expected_output_format = f"[\n {indented_obj_format},\n ...\n]"
|
||||
if pure_json_mode:
|
||||
indented_list_format = expected_output_format.replace("\n", "\n ")
|
||||
expected_output_format = (
|
||||
"{\n"
|
||||
' "reasoning": "... (optional)",\n' # for better performance
|
||||
f' "results": {indented_list_format}\n'
|
||||
"}"
|
||||
)
|
||||
|
||||
# Preserve indentation in prompt
|
||||
expected_output_format = expected_output_format.replace("\n", "\n|")
|
||||
|
||||
# Prepare prompt
|
||||
if not pure_json_mode:
|
||||
expected_output_format = (
|
||||
f"{output_tag_start}\n{expected_output_format}\n</json_output>"
|
||||
)
|
||||
|
||||
instructions = f"""
|
||||
|In your response you MUST include a valid JSON {outer_output_type} strictly following this format:
|
||||
|{expected_output_format}
|
||||
|
|
||||
|If you cannot provide all the keys, you MUST provide an empty string for the values you cannot answer.
|
||||
""".strip()
|
||||
|
||||
if not pure_json_mode:
|
||||
instructions += f"""
|
||||
|
|
||||
|You MUST enclose your final JSON answer in {output_tag_start}...</json_output> tags, even if the user specifies a different tag.
|
||||
|There MUST be exactly ONE {output_tag_start}...</json_output> block in your response, which MUST ONLY contain the JSON {outer_output_type} and nothing else. Other text outside this block is allowed.
|
||||
""".strip()
|
||||
|
||||
return trim_prompt(instructions)
|
||||
|
||||
def invalid_response_feedback(
|
||||
self,
|
||||
error,
|
||||
*,
|
||||
was_parseable: bool,
|
||||
list_mode: bool,
|
||||
pure_json_mode: bool,
|
||||
output_tag_start: str,
|
||||
) -> str:
|
||||
outer_output_type = "object" if not list_mode or pure_json_mode else "array"
|
||||
|
||||
if was_parseable:
|
||||
complaint = f"Your previous response did not match the expected {outer_output_type} format."
|
||||
else:
|
||||
complaint = f"Your previous response did not contain a parseable JSON {outer_output_type}."
|
||||
|
||||
indented_parse_error = str(error).replace("\n", "\n|")
|
||||
|
||||
instruction = (
|
||||
f"Please provide a {output_tag_start}...</json_output> block containing a"
|
||||
if not pure_json_mode
|
||||
else "Please provide a"
|
||||
) + f" valid JSON {outer_output_type} that matches the expected format."
|
||||
|
||||
return trim_prompt(
|
||||
f"""
|
||||
|{complaint}
|
||||
|
|
||||
|{indented_parse_error}
|
||||
|
|
||||
|{instruction}
|
||||
"""
|
||||
)
|
||||
|
||||
def get_json_from_response(
|
||||
self, response_text: str, *, pure_json_mode: bool, output_tag_start: str
|
||||
) -> dict[str, Any] | list[dict[str, Any]]:
|
||||
if pure_json_mode:
|
||||
# Handle pure JSON responses
|
||||
try:
|
||||
return json.loads(response_text)
|
||||
except JSONDecodeError as first_parse_error:
|
||||
# If that didn't work, try finding the { and } to deal with possible ```json fences etc.
|
||||
json_start = response_text.find("{")
|
||||
json_end = response_text.rfind("}")
|
||||
try:
|
||||
return json.loads(response_text[json_start : json_end + 1])
|
||||
except JSONDecodeError:
|
||||
# Raise the original error, as it's more likely to be relevant
|
||||
raise first_parse_error from None
|
||||
|
||||
if output_tag_start not in response_text:
|
||||
raise ValueError(
|
||||
"Response does not contain the expected "
|
||||
f"{output_tag_start}...</json_output> block."
|
||||
)
|
||||
json_output = (
|
||||
response_text.split(output_tag_start, 1)[1]
|
||||
.rsplit("</json_output>", 1)[0]
|
||||
.strip()
|
||||
)
|
||||
return json.loads(json_output)
|
||||
|
||||
def get_collision_proof_output_tag_id(self) -> str:
|
||||
return secrets.token_hex(8)
|
||||
|
||||
|
||||
def trim_prompt(s: str) -> str:
|
||||
"""Removes indentation up to and including `|` from a multi-line prompt."""
|
||||
lines = s.strip().split("\n")
|
||||
return "\n".join([line.strip().lstrip("|") for line in lines])
|
||||
|
||||
|
||||
class AITextGeneratorBlock(AIBlockBase):
|
||||
class Input(BlockSchema):
|
||||
|
||||
@@ -523,7 +523,6 @@ class SmartDecisionMakerBlock(Block):
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
json_format=False,
|
||||
max_tokens=input_data.max_tokens,
|
||||
tools=tool_functions,
|
||||
ollama_host=input_data.ollama_host,
|
||||
|
||||
@@ -30,7 +30,6 @@ class TestLLMStatsTracking:
|
||||
credentials=llm.TEST_CREDENTIALS,
|
||||
llm_model=llm.LlmModel.GPT4O,
|
||||
prompt=[{"role": "user", "content": "Hello"}],
|
||||
json_format=False,
|
||||
max_tokens=100,
|
||||
)
|
||||
|
||||
@@ -42,6 +41,8 @@ class TestLLMStatsTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_ai_structured_response_block_tracks_stats(self):
|
||||
"""Test that AIStructuredResponseGeneratorBlock correctly tracks stats."""
|
||||
from unittest.mock import patch
|
||||
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
@@ -51,7 +52,7 @@ class TestLLMStatsTracking:
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='{"key1": "value1", "key2": "value2"}',
|
||||
response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
|
||||
tool_calls=None,
|
||||
prompt_tokens=15,
|
||||
completion_tokens=25,
|
||||
@@ -69,10 +70,12 @@ class TestLLMStatsTracking:
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
# Mock secrets.token_hex to return consistent ID
|
||||
with patch("secrets.token_hex", return_value="test123456"):
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Check stats
|
||||
assert block.execution_stats.input_token_count == 15
|
||||
@@ -143,7 +146,7 @@ class TestLLMStatsTracking:
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='{"wrong": "format"}',
|
||||
response='<json_output id="test123456">{"wrong": "format"}</json_output>',
|
||||
tool_calls=None,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=15,
|
||||
@@ -154,7 +157,7 @@ class TestLLMStatsTracking:
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='{"key1": "value1", "key2": "value2"}',
|
||||
response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
|
||||
tool_calls=None,
|
||||
prompt_tokens=20,
|
||||
completion_tokens=25,
|
||||
@@ -173,10 +176,12 @@ class TestLLMStatsTracking:
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
# Mock secrets.token_hex to return consistent ID
|
||||
with patch("secrets.token_hex", return_value="test123456"):
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Check stats - should accumulate both calls
|
||||
# For 2 attempts: attempt 1 (failed) + attempt 2 (success) = 2 total
|
||||
@@ -269,7 +274,8 @@ class TestLLMStatsTracking:
|
||||
mock_response.choices = [
|
||||
MagicMock(
|
||||
message=MagicMock(
|
||||
content='{"summary": "Test chunk summary"}', tool_calls=None
|
||||
content='<json_output id="test123456">{"summary": "Test chunk summary"}</json_output>',
|
||||
tool_calls=None,
|
||||
)
|
||||
)
|
||||
]
|
||||
@@ -277,7 +283,7 @@ class TestLLMStatsTracking:
|
||||
mock_response.choices = [
|
||||
MagicMock(
|
||||
message=MagicMock(
|
||||
content='{"final_summary": "Test final summary"}',
|
||||
content='<json_output id="test123456">{"final_summary": "Test final summary"}</json_output>',
|
||||
tool_calls=None,
|
||||
)
|
||||
)
|
||||
@@ -298,11 +304,13 @@ class TestLLMStatsTracking:
|
||||
max_tokens=1000, # Large enough to avoid chunking
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
# Mock secrets.token_hex to return consistent ID
|
||||
with patch("secrets.token_hex", return_value="test123456"):
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
print(f"Actual calls made: {call_count}")
|
||||
print(f"Block stats: {block.execution_stats}")
|
||||
@@ -457,7 +465,7 @@ class TestLLMStatsTracking:
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='{"result": "test"}',
|
||||
response='<json_output id="test123456">{"result": "test"}</json_output>',
|
||||
tool_calls=None,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=20,
|
||||
@@ -476,10 +484,12 @@ class TestLLMStatsTracking:
|
||||
|
||||
# Run the block
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
# Mock secrets.token_hex to return consistent ID
|
||||
with patch("secrets.token_hex", return_value="test123456"):
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Block finished - now grab and assert stats
|
||||
assert block.execution_stats is not None
|
||||
|
||||
@@ -423,7 +423,6 @@ async def _call_llm_direct(
|
||||
credentials=credentials,
|
||||
llm_model=LlmModel.GPT4O_MINI,
|
||||
prompt=prompt,
|
||||
json_format=False,
|
||||
max_tokens=150,
|
||||
compress_prompt_to_fit=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user