reformat response piles to include confirmation and conclusion

Alex O'Connell
2025-12-07 17:28:38 -05:00
parent 55f254149a
commit dac9973cb5
12 changed files with 3770 additions and 3327 deletions
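For reference, pile_of_responses.csv moves from a single response column to a starting/confirmed pair. The new column order is taken from data/reformat_responses.py below; the old order is inferred and may differ:

before: service,response,persona,short
after:  service,response_starting,response_confirmed,persona,short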


@@ -128,15 +128,15 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
service_action = service_name.split(".")[1]
tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)
response = get_random_response(
response_starting, response_confirmed = get_random_response(
piles.pile_of_responses,
service=service_name,
persona=persona,
question_template="",
short=False
).lower()
)
answer_list = [response]
answer_list = [response_confirmed]
tool_args = {}
question = question.replace("<device_name>", target_device)
@@ -229,6 +229,7 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
"available_tools": available_tools,
"question": question.lower(),
"answers": answer_list,
"answer_starting": response_starting,
"tool_calls": [ tool_call ]
}
@@ -291,7 +292,7 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
# pick an appropriate response and generate the question
if len(template_device_types) == 1:
answer_template = get_random_response(
answer_starting, answer_confirmed = get_random_response(
piles.pile_of_responses,
service=service_names[0],
persona=persona,
@@ -300,20 +301,23 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
)
question = question_template.replace("<device_name>", chosen_devices[0]["description"])
answer_list = [ answer_template.replace("<device_name>", chosen_devices[0]["description"]) ]
answer_starting = answer_starting.replace("<device_name>", chosen_devices[0]["description"])
answer_list = [ answer_confirmed.replace("<device_name>", chosen_devices[0]["description"]) ]
else:
question = question_template
answers = []
answer_starting = ""
for i in range(len(template_device_types)):
question = question.replace(f"<device_name{(i + 1)}>", chosen_devices[i]["description"])
answer_response = get_random_response(
answer_starting_part, answer_confirmed = get_random_response(
piles.pile_of_responses,
service=service_names[i],
persona=persona,
question_template=question_template,
short=True
)
answers.append(answer_response.replace(f"<device_name>", chosen_devices[i]["description"]))
answer_starting += answer_starting_part.replace(f"<device_name>", chosen_devices[i]["description"]) + " "
answers.append(answer_confirmed.replace(f"<device_name>", chosen_devices[i]["description"]))
answer_list = []
for word in piles.and_words:
@@ -414,6 +418,7 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
"states": device_list,
"available_tools": available_tools,
"question": question.lower(),
"answer_starting": answer_starting.lower(),
"answers": [ sentence.lower() for sentence in answer_list ],
"tool_calls": tool_calls
}
@@ -520,6 +525,7 @@ def format_example_sharegpt(example, persona, language, use_system_role, use_ser
sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts)
question = example["question"]
answers = " ".join(example["answers"])
answer_starting = example.get("answer_starting", "")
tool_calls = []
tool_results = []
@@ -564,8 +570,7 @@ def format_example_sharegpt(example, persona, language, use_system_role, use_ser
conversation.extend([
{
"role": "assistant",
# FIXME: use the "confirmation" response here instead of a canned text
"content": [{ "type": "text", "text": "I will perform the requested user action." }],
"content": [{ "type": "text", "text": answer_starting }],
"tool_calls": tool_calls
},
{
@@ -581,7 +586,7 @@ def format_example_sharegpt(example, persona, language, use_system_role, use_ser
conversation.extend([
{
"role": "assistant",
"content": [{ "type": "text", "text": answers }],
"content": [{ "type": "text", "text": answer_starting + answers }],
}
])
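In sharegpt terms, the assistant turn that carries tool calls now reads roughly like this (a sketch: the role/content/tool_calls keys come from the hunk above, while the tool-call payload shape and all values are illustrative assumptions):

{
    "role": "assistant",
    "content": [{ "type": "text", "text": "turning on the kitchen light now." }],
    "tool_calls": [ { "name": "HassTurnOn", "arguments": { "device_name": "kitchen light" } } ]
}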

File diff suppressed because it is too large (5 files)

data/reformat_responses.py (new file, 398 lines)

@@ -0,0 +1,398 @@
#!/usr/bin/env python3
"""
Script to reformat CSV responses by splitting the 'response' column into
'response_starting' (action in progress) and 'response_confirmed' (action completed)
using llama.cpp's native chat completion endpoint with concurrent aiohttp calls.
pip3 install aiohttp pydantic tqdm
"""
import argparse
import asyncio
import csv
import json
import os
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import aiohttp
from pydantic import BaseModel, Field
from tqdm import tqdm
class ResponseFormat(BaseModel):
response_starting: str = Field(description="Response indicating the action is in progress")
response_confirmed: str = Field(description="Response indicating the action has been completed")
TASK_DESCRIPTION = """
You are reformatting assistant responses for a smart home system.
Given an original response that describes an action, generate TWO variations:
1. response_starting: A response indicating the action is IN PROGRESS (starting, initiating, working on it)
2. response_confirmed: A response indicating the action has been SUCCESSFULLY COMPLETED (done, completed, finished)
Both responses should:
- Maintain the same tone and persona as the original
- Keep any device names or parameters (like <device_name>, <temp_f>, etc.) exactly as they appear
- Be natural and conversational
- Be concise (similar length to the original)
- Avoid overly formal language
- Preserve the language of the original data even if it is not English
Example:
Original: "Opening the blinds for you."
{
"response_starting": "Opening the blinds now."
"response_confirmed": "The blinds are now open."
}
Original: "Setting temperature to <temp_f> degrees."
{
"response_starting": "Setting temperature to <temp_f> degrees."
"response_confirmed": "Temperature has been set to <temp_f> degrees."
}
Respond ONLY with a JSON object in this exact format:
{
"response_starting": "your starting response here",
"response_confirmed": "your confirmed response here"
}
"""
JSON_SCHEMA = ResponseFormat.model_json_schema()
def load_system_prompts(system_prompts_path: Path) -> Dict[str, str]:
"""Load system prompts from CSV file."""
prompts: Dict[str, str] = {}
with open(system_prompts_path, 'r', encoding='utf-8') as f:
reader = csv.DictReader(f)
for row in reader:
prompts[row['persona']] = row['prompt']
return prompts
def _extract_message_content(payload: Dict) -> str:
"""Extract assistant message content from diverse llama.cpp response shapes."""
if not isinstance(payload, dict):
return ''
if 'choices' in payload and payload['choices']:
message = payload['choices'][0].get('message', {})
content = message.get('content')
else:
content = payload.get('content')
if isinstance(content, list):
parts: List[str] = []
for item in content:
if isinstance(item, dict):
parts.append(item.get('text', ''))
elif isinstance(item, str):
parts.append(item)
return ''.join(parts)
if isinstance(content, str):
return content
return ''
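# Payload shapes handled above (illustrative, not exhaustive):
#   {"choices": [{"message": {"content": "..."}}]}                              OpenAI-style string content
#   {"choices": [{"message": {"content": [{"type": "text", "text": "..."}]}}]}  content-parts list
#   {"content": "..."}                                                          llama.cpp completion-style field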
async def generate_reformatted_responses(
session: aiohttp.ClientSession,
url: str,
model: str,
system_prompt: str,
original_response: str,
service: str,
semaphore: asyncio.Semaphore,
max_attempts: int,
) -> Dict[str, str]:
"""Use llama.cpp chat endpoint to generate structured responses with retries."""
user_message = f"""Service: {service}
Original response: {original_response}
Generate the two response variations as specified."""
conversation = [
{
"role": "system",
"content": "You are a helpful synthetic data generation assistant. Your task is to respond only using JSON format as specified."
},
{
"role": "user",
"content": f"{TASK_DESCRIPTION}\n\nCurrent Persona System Prompt: {system_prompt}\n{user_message}"
},
]
schema_payload = {
"type": "json_schema",
"json_schema": {
"name": "response_format",
"schema": JSON_SCHEMA,
},
}
attempts_remaining = max_attempts
last_error: Optional[Exception] = None
content = None
while attempts_remaining > 0:
attempts_remaining -= 1
payload = {
"model": model,
"messages": conversation,
"seed": random.randint(1, 1_000_000),
"response_format": schema_payload,
"stream": False,
}
try:
async with semaphore:
async with session.post(url, json=payload) as response:
response.raise_for_status()
data = await response.json()
content = _extract_message_content(data)
result = json.loads(content)
if 'response_starting' not in result or 'response_confirmed' not in result:
raise ValueError(f"Invalid response format: {result}")
return {
'response_starting': result['response_starting'],
'response_confirmed': result['response_confirmed'],
}
except Exception as exc: # noqa: BLE001
last_error = exc
print(f"\nError processing response '{original_response}': {exc} - new response '{content}'")
conversation.append({
"role": "user",
"content": "The previous response was invalid. Respond ONLY with JSON matching the schema: " + str(JSON_SCHEMA),
})
print(
f"Failed to reformat response after multiple attempts. Using original response of '{original_response}'."
)
if last_error:
print(f"Last error: {last_error}")
return {
'response_starting': original_response,
'response_confirmed': original_response,
}
async def process_csv(
input_path: Path,
output_path: Path,
system_prompts_path: Path,
base_endpoint: str,
route: str,
api_key: str | None,
model: str,
max_concurrency: int,
max_attempts: int,
request_timeout: float,
):
"""Process the CSV file and generate reformatted responses concurrently."""
print("Loading system prompts...")
system_prompts = load_system_prompts(system_prompts_path)
print(f"Loaded {len(system_prompts)} personas: {', '.join(system_prompts.keys())}")
print(f"\nReading input file: {input_path}")
with open(input_path, 'r', encoding='utf-8') as input_file:
reader = csv.DictReader(input_file)
rows = list(reader)
total_rows = len(rows)
print(f"Processing {total_rows} rows with concurrency={max_concurrency}...")
headers = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
timeout = aiohttp.ClientTimeout(total=request_timeout)
endpoint = base_endpoint.rstrip('/') + '/' + route.lstrip('/')
semaphore = asyncio.Semaphore(max_concurrency)
output_file = open(output_path, 'w', encoding='utf-8', newline='')
fieldnames = ['service', 'response_starting', 'response_confirmed', 'persona', 'short']
writer = csv.DictWriter(output_file, fieldnames=fieldnames)
writer.writeheader()
async with aiohttp.ClientSession(headers=headers, timeout=timeout) as session:
tasks = []
for idx, row in enumerate(rows):
tasks.append(
asyncio.create_task(
process_row(
idx,
row,
session,
endpoint,
model,
system_prompts,
semaphore,
max_attempts,
)
)
)
progress = tqdm(total=total_rows, desc="Generating responses")
try:
for coro in asyncio.as_completed(tasks):
idx, output_row = await coro
writer.writerow(output_row)
output_file.flush()
progress.update(1)
finally:
progress.close()
output_file.close()
print(f"✓ Successfully processed {total_rows} rows")
print(f"✓ Output saved to: {output_path}")
async def process_row(
idx: int,
row: Dict[str, str],
session: aiohttp.ClientSession,
endpoint: str,
model: str,
system_prompts: Dict[str, str],
semaphore: asyncio.Semaphore,
max_attempts: int,
) -> Tuple[int, Dict[str, str]]:
"""Process a single CSV row asynchronously."""
service = row['service']
original_response = row['response']
persona = row['persona']
short = row['short']
system_prompt = system_prompts.get(persona, system_prompts.get('assistant', ''))
reformatted = await generate_reformatted_responses(
session=session,
url=endpoint,
model=model,
system_prompt=system_prompt,
original_response=original_response,
service=service,
semaphore=semaphore,
max_attempts=max_attempts,
)
output_row = {
'service': service,
'response_starting': reformatted['response_starting'],
'response_confirmed': reformatted['response_confirmed'],
'persona': persona,
'short': short,
}
return idx, output_row
def main() -> int:
parser = argparse.ArgumentParser(
description="Reformat CSV responses using llama.cpp chat completions with concurrent aiohttp calls",
)
parser.add_argument('input_csv', type=Path, help='Input CSV file path (e.g., pile_of_responses.csv)')
parser.add_argument('output_csv', type=Path, help='Output CSV file path')
parser.add_argument(
'--system-prompts',
type=Path,
help='Path to system prompts CSV file (default: same directory as input, pile_of_system_prompts.csv)',
)
parser.add_argument(
'--endpoint',
type=str,
default='https://ai.cloud.alexoconnell.net/v1',
help='Base URL for the llama.cpp server (default: https://ai.cloud.alexoconnell.net/v1)',
)
parser.add_argument(
'--route',
type=str,
default='/chat/completions',
help='Route for the llama.cpp chat completion endpoint (default: /chat/completions)',
)
parser.add_argument(
'--api-key',
type=str,
default=None,
help='Optional API key for the llama.cpp server (default: use OPENAI_API_KEY env var if set)',
)
parser.add_argument(
'--model',
type=str,
default='gpt-oss-120b',
help='Model to use (default: gpt-oss-120b)',
)
parser.add_argument(
'--max-concurrency',
type=int,
default=4,
help='Maximum number of concurrent generations (default: 4)',
)
parser.add_argument(
'--max-attempts',
type=int,
default=3,
help='Maximum number of retries per row (default: 3)',
)
parser.add_argument(
'--timeout',
type=float,
default=60.0,
help='Total request timeout in seconds (default: 60.0)',
)
args = parser.parse_args()
if not args.input_csv.exists():
print(f"Error: Input file not found: {args.input_csv}")
return 1
if args.system_prompts:
system_prompts_path = args.system_prompts
else:
system_prompts_path = args.input_csv.parent / 'pile_of_system_prompts.csv'
if not system_prompts_path.exists():
print(f"Error: System prompts file not found: {system_prompts_path}")
return 1
api_key = args.api_key or os.environ.get('OPENAI_API_KEY')
args.output_csv.parent.mkdir(parents=True, exist_ok=True)
try:
asyncio.run(
process_csv(
input_path=args.input_csv,
output_path=args.output_csv,
system_prompts_path=system_prompts_path,
base_endpoint=args.endpoint,
route=args.route,
api_key=api_key,
model=args.model,
max_concurrency=args.max_concurrency,
max_attempts=args.max_attempts,
request_timeout=args.timeout,
)
)
return 0
except Exception as exc: # noqa: BLE001
print(f"Error: {exc}")
import traceback
traceback.print_exc()
return 1
if __name__ == '__main__':
raise SystemExit(main())
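For context, a minimal sketch of how the generator consumes the reformatted CSV (mirroring the DatasetPiles and get_random_response hunks further below; the path is an example):

import pandas
pile_of_responses = pandas.read_csv("piles/english/pile_of_responses.csv")
row = pile_of_responses.sample().iloc[0]
response_starting, response_confirmed = row["response_starting"], row["response_confirmed"]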


@@ -3,4 +3,7 @@ webcolors>=1.13
pandas>=2.2.3
deep-translator>=1.11.4
langcodes>=3.5.0
babel>=2.15.0
babel>=2.15.0
aiohttp>=3.9.0
pydantic>=2.5.0
tqdm>=4.66.0


@@ -66,8 +66,7 @@ def generate_random_parameter(param_name, piles_of_data):
return param_generator()
# FIXME: return 2 responses, 1 to confirm the action and one to confirm completion of the action
def get_random_response(pile_of_responses, *, service: str, persona: str, question_template: str, short: bool) -> str:
def get_random_response(pile_of_responses, *, service: str, persona: str, question_template: str, short: bool) -> tuple[str, str]:
required_vars = list(set([var for var in var_pattern.findall(question_template) if "device_name" not in var]))
@@ -80,7 +79,7 @@ def get_random_response(pile_of_responses, *, service: str, persona: str, questi
if len(possible_results) == 0:
raise NoResponseAvailableException(f"No responses matched the provided filters: {persona}, {service}, {required_vars}, {short}")
return possible_results.sample()["response"].values[0]
sampled = possible_results.sample().iloc[0]
return sampled["response_starting"], sampled["response_confirmed"]
class DatasetPiles:
def __init__(self, supported_devices, language="english"):
@@ -130,7 +129,7 @@ class DatasetPiles:
self.pile_of_specific_actions = list(reader)
self.pile_of_responses = pandas.read_csv(f"piles/{language}/pile_of_responses.csv")
self.pile_of_responses["contains_vars"] = self.pile_of_responses["response"].apply(get_included_vars)
self.pile_of_responses["contains_vars"] = self.pile_of_responses["response_starting"].apply(get_included_vars)
with open(f"piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f:
reader = csv.DictReader(f)


@@ -3,11 +3,8 @@
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- endif %}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "Tools:\n" }}
{%- for tool in tools %} {{ tool['name'] }}({% for param in tool['parameters']['properties'].keys() if param != 'target_device' %}{{ param }}{% if not loop.last %},{% endif %}{% endfor -%}),{% if not loop.last -%}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
@@ -36,7 +33,7 @@
{{- tool_call.name }}
{{- '", "arguments": ' }}
{%- if tool_call.arguments is string %}
{{- tool_call.arguments }}
"{{- tool_call.arguments }}
{%- else %}
{{- tool_call.arguments | tojson }}
{%- endif %}

train/zephyr_legacy.j2 (new file, 58 lines)

@@ -0,0 +1,58 @@
{%- if tools %}
{{- '<|system|>\n' }}
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- endif %}
{{- "Tools:\n" }}
{%- for tool in tools %} {{ tool['name'] }}({% for param in tool['parameters']['properties'].keys() if param != 'target_device' %}{{ param }}{% if not loop.last %},{% endif %}{% endfor -%}),{% if not loop.last -%}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|endoftext|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|system|>\n' + messages[0].content + '\nNo tools were provided. If the user requests you interact with a device, tell them you are unable to do so.' + '<|endoftext|>\n' }}
{%- endif %}
{%- endif %}
{%- for message in messages %}
{%- if message.content is string %}
{%- set content = message.content %}
{%- else %}
{%- set content = '' %}
{%- endif %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|' + message.role + '|>' + '\n' + content + '<|endoftext|>' + '\n' }}
{%- elif message.role == "assistant" %}
{{- '<|' + message.role + '|>' + '\n' + content }}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if (loop.first and content) or (not loop.first) %}
{{- '\n' }}
{%- endif %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{%- if tool_call.arguments is string %}
"{{- tool_call.arguments }}
{%- else %}
{{- tool_call.arguments | tojson }}
{%- endif %}
{{- '}\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|endoftext|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|user|>' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|endoftext|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|assistant|>\n' }}
{%- endif %}


@@ -1,17 +0,0 @@
{% for message in messages %}
{%- if message['role'] == 'user' or message['role'] == 'tool' -%}
<|user|> {{ message['content'] }}{{ eos_token }}
{%- elif message['role'] == 'system' -%}
<|system|> {{ message['content'] }}
Services:
{%- for tool in tools %} {{ tool['function']['name'] }}({% for param in tool['function']['parameters']['properties'].keys() if param != 'target_device' %}{{ param }}{% if not loop.last %},{% endif %}{% endfor -%}),{% if not loop.last -%}
{%- if tools | length == 0 %}No tools were provided. If the user requests you interact with a device, tell them you are unable to do so.{% endif %}
{%- endif -%}{%- endfor -%}
{{ eos_token }}
{%- elif message['role'] == 'assistant' -%}
<|assistant|> {{ message['content'] }}{{ eos_token }}
{%- endif -%}
{%- if loop.last and add_generation_prompt %}
<|assistant|>
{%- endif %}
{% endfor -%}