Make JSON output optional in LLM calls

Reinier van der Leer
2025-09-23 16:40:38 +02:00
parent 30402979bb
commit e7b15c39aa
4 changed files with 25 additions and 13 deletions

View File

@@ -382,7 +382,9 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
     return None


-def get_parallel_tool_calls_param(llm_model: LlmModel, parallel_tool_calls):
+def get_parallel_tool_calls_param(
+    llm_model: LlmModel, parallel_tool_calls: bool | None
+):
     """Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
     if llm_model.startswith("o") or parallel_tool_calls is None:
         return openai.NOT_GIVEN
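
For context, this helper omits the parameter entirely for OpenAI "o"-series reasoning models and for unset values. A minimal sanity-check sketch, assuming LlmModel is a string-valued enum (plain strings stand in below) and that the branch truncated by the hunk passes the value through unchanged:

import openai

# Stand-in mirroring the helper above; the final pass-through return is an
# assumption, since the hunk cuts off after the NOT_GIVEN branch.
def get_parallel_tool_calls_param(llm_model: str, parallel_tool_calls: bool | None):
    if llm_model.startswith("o") or parallel_tool_calls is None:
        return openai.NOT_GIVEN  # sentinel: leave the field out of the request
    return parallel_tool_calls

assert get_parallel_tool_calls_param("o3-mini", True) is openai.NOT_GIVEN  # "o"-series: never send it
assert get_parallel_tool_calls_param("gpt-4o", None) is openai.NOT_GIVEN   # unset: keep the API default
assert get_parallel_tool_calls_param("gpt-4o", False) is False             # explicit value passes through
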
@@ -393,8 +395,8 @@ async def llm_call(
     credentials: APIKeyCredentials,
     llm_model: LlmModel,
     prompt: list[dict],
-    json_format: bool,
     max_tokens: int | None,
+    force_json_output: bool = False,
     tools: list[dict] | None = None,
     ollama_host: str = "localhost:11434",
     parallel_tool_calls=None,
@@ -407,7 +409,7 @@ async def llm_call(
         credentials: The API key credentials to use.
         llm_model: The LLM model to use.
         prompt: The prompt to send to the LLM.
-        json_format: Whether the response should be in JSON format.
+        force_json_output: Whether the response should be in JSON format.
         max_tokens: The maximum number of tokens to generate in the chat completion.
         tools: The tools to use in the chat completion.
         ollama_host: The host for ollama to use.
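
Because force_json_output now defaults to False, the common plain-text path needs no JSON flag at all, while structured callers opt in explicitly. A rough usage sketch; the import path is assumed from the test file's `llm.` prefix, and the credentials object and argument values are placeholders:

from backend.blocks.llm import LlmModel, llm_call  # assumed module path

async def examples(credentials):
    # Plain chat completion: JSON mode is simply left off.
    await llm_call(
        credentials=credentials,
        llm_model=LlmModel.GPT4O,
        prompt=[{"role": "user", "content": "Summarize this ticket."}],
        max_tokens=512,
    )

    # Structured extraction: opt in to JSON-only output explicitly.
    await llm_call(
        credentials=credentials,
        llm_model=LlmModel.GPT4O,
        prompt=[{"role": "user", "content": "Return the result as JSON."}],
        max_tokens=512,
        force_json_output=True,
    )
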
@@ -446,7 +448,7 @@ async def llm_call(
             llm_model, parallel_tool_calls
         )
-        if json_format:
+        if force_json_output:
             response_format = {"type": "json_object"}
         response = await oai_client.chat.completions.create(
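
For reference, the response_format value requested here is the OpenAI SDK's JSON mode: it guarantees syntactically valid JSON but not conformance to any particular schema, and the API expects the word "JSON" to appear somewhere in the messages. A standalone sketch against the OpenAI Python SDK, not the repo's wrapper; the model name and prompts are placeholders:

from openai import AsyncOpenAI

async def json_only_completion(api_key: str, user_prompt: str) -> str:
    client = AsyncOpenAI(api_key=api_key)
    response = await client.chat.completions.create(
        model="gpt-4o",  # placeholder model
        messages=[
            # JSON mode expects "JSON" to be mentioned somewhere in the messages.
            {"role": "system", "content": "Reply with a single JSON object."},
            {"role": "user", "content": user_prompt},
        ],
        response_format={"type": "json_object"},  # valid JSON guaranteed; a schema is not
    )
    return response.choices[0].message.content or ""
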
@@ -559,7 +561,7 @@ async def llm_call(
             raise ValueError("Groq does not support tools.")
         client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
-        response_format = {"type": "json_object"} if json_format else None
+        response_format = {"type": "json_object"} if force_json_output else None
         response = await client.chat.completions.create(
             model=llm_model.value,
             messages=prompt,  # type: ignore
@@ -717,7 +719,7 @@ async def llm_call(
         )
         response_format = None
-        if json_format:
+        if force_json_output:
             response_format = {"type": "json_object"}
         parallel_tool_calls_param = get_parallel_tool_calls_param(
@@ -769,6 +771,16 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
description="Expected format of the response. If provided, the response will be validated against this format. "
"The keys should be the expected fields in the response, and the values should be the description of the field.",
)
force_json_output: bool = SchemaField(
title="Restrict LLM to pure JSON output",
default=False,
description=(
"Whether to force the LLM to produce a JSON-only response. "
"This can increase a model's reliability of outputting valid JSON. "
"However, it may also reduce the quality of the response, because it "
"prohibits the LLM from reasoning before providing its JSON response."
),
)
list_result: bool = SchemaField(
title="List Result",
default=False,
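
The new field's description spells out the tradeoff: JSON mode stops the model from reasoning in prose before it answers. With the flag left at its default, the block can still let the model think out loud and recover the object afterwards; a naive, purely illustrative extractor, not the repo's actual parsing logic:

import json

def extract_json_object(text: str) -> dict:
    """Naive illustration: parse the outermost {...} span of a free-form answer."""
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end <= start:
        raise ValueError("no JSON object found in the response")
    return json.loads(text[start : end + 1])

answer = 'Let me reason first... Final answer:\n{"sentiment": "positive", "score": 0.9}'
print(extract_json_object(answer))  # {'sentiment': 'positive', 'score': 0.9}
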
@@ -867,9 +879,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
         credentials: APIKeyCredentials,
         llm_model: LlmModel,
         prompt: list[dict],
-        json_format: bool,
-        compress_prompt_to_fit: bool,
         max_tokens: int | None,
+        force_json_output: bool = False,
+        compress_prompt_to_fit: bool = True,
         tools: list[dict] | None = None,
         ollama_host: str = "localhost:11434",
     ) -> LLMResponse:
@@ -882,8 +894,8 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
             credentials=credentials,
             llm_model=llm_model,
             prompt=prompt,
-            json_format=json_format,
             max_tokens=max_tokens,
+            force_json_output=force_json_output,
             tools=tools,
             ollama_host=ollama_host,
             compress_prompt_to_fit=compress_prompt_to_fit,
@@ -954,7 +966,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
             llm_model=llm_model,
             prompt=prompt,
             compress_prompt_to_fit=input_data.compress_prompt_to_fit,
-            json_format=bool(input_data.expected_format),
+            force_json_output=(
+                input_data.force_json_output
+                and bool(input_data.expected_format)
+            ),
             ollama_host=input_data.ollama_host,
             max_tokens=input_data.max_tokens,
         )
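
One behavioral note on the hunk above: previously JSON mode was requested whenever expected_format was set, whereas now it also requires the new flag, and the flag alone does nothing without an expected_format to validate against. A quick check of the gated expression:

# (force_json_output input, expected_format provided) -> value passed to llm_call
cases = {
    (False, False): False,
    (False, True): False,   # old behavior would have forced JSON here
    (True, False): False,   # nothing to validate against, so JSON mode stays off
    (True, True): True,
}
for (flag, has_format), expected in cases.items():
    assert (flag and has_format) == expected
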

View File

@@ -523,7 +523,6 @@ class SmartDecisionMakerBlock(Block):
             credentials=credentials,
             llm_model=input_data.model,
             prompt=prompt,
-            json_format=False,
             max_tokens=input_data.max_tokens,
             tools=tool_functions,
             ollama_host=input_data.ollama_host,

View File

@@ -30,7 +30,6 @@ class TestLLMStatsTracking:
             credentials=llm.TEST_CREDENTIALS,
             llm_model=llm.LlmModel.GPT4O,
             prompt=[{"role": "user", "content": "Hello"}],
-            json_format=False,
             max_tokens=100,
         )

View File

@@ -423,7 +423,6 @@ async def _call_llm_direct(
         credentials=credentials,
         llm_model=LlmModel.GPT4O_MINI,
         prompt=prompt,
-        json_format=False,
         max_tokens=150,
         compress_prompt_to_fit=True,
     )
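
Since json_format was removed from the signature rather than deprecated, the call sites above simply drop the argument, and any straggling caller that still passes it will fail loudly. A sketch of that failure mode; the import path is assumed and the credentials object is a placeholder:

from backend.blocks.llm import LlmModel, llm_call  # assumed module path

async def outdated_caller(credentials):
    # Raises TypeError: llm_call() got an unexpected keyword argument 'json_format'
    await llm_call(
        credentials=credentials,
        llm_model=LlmModel.GPT4O_MINI,
        prompt=[{"role": "user", "content": "Hello"}],
        json_format=False,  # removed parameter
        max_tokens=150,
    )
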