Merge pull request #9 from PragmaticMachineLearning/fix/abort-streaming-request

Fix/abort streaming request

Authored by Oluwatobi Adefami on 2025-02-19 03:08:36 +01:00, committed via GitHub
3 changed files with 143 additions and 69 deletions


@@ -189,6 +189,7 @@ const SpreadsheetApp = () => {
         );
       } catch (error) {
         if (error.name === "AbortError") {
+          console.log("Request Aborted");
           // Handle abort case
           setChatHistory((prev) =>
             prev.map((msg) =>
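The AbortError branch above implies that SpreadsheetApp cancels its in-flight fetch with an AbortController, but that wiring is outside this hunk. A minimal sketch of how it might look, assuming a fetch-based client; the endpoint path and names like abortRef and stopStreaming are illustrative, not from the repo:

import { useRef } from "react";

const SpreadsheetApp = () => {
  const abortRef = useRef<AbortController | null>(null);

  const sendPrompt = async (message: string) => {
    abortRef.current = new AbortController();
    try {
      await fetch("/api/llm", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ message }),
        signal: abortRef.current.signal, // abort() rejects this fetch with AbortError
      });
      // ...read the streamed body and append chunks to chat history...
    } catch (error: any) {
      if (error.name === "AbortError") {
        console.log("Request Aborted"); // lands in the branch shown above
      }
    }
  };

  // Passed down to ChatBox as the `onStop` prop (next file).
  const stopStreaming = () => abortRef.current?.abort();
  // ...
};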


@@ -1,4 +1,4 @@
-import { useState } from "react";
+import { useEffect, useState } from "react";
 import { CellUpdate, ChatMessage } from "@/types/api";
 import { Check, X, Send, Trash2, Loader2, Square } from "lucide-react";
@@ -22,8 +22,15 @@ const ChatBox = ({
   const [message, setMessage] = useState("");
   const [isLoading, setIsLoading] = useState(false);
+  useEffect(() => {
+    if (chatHistory.length > 0) {
+      const lastMessage = chatHistory[chatHistory.length - 1];
+      setIsLoading(!!lastMessage.streaming);
+    }
+  }, [chatHistory]);
   const handleSend = async () => {
-    if (message.trim()) {
+    if (message.trim() || isLoading) {
+      if (isLoading) {
+        onStop();
+        setIsLoading(false);
@@ -36,8 +43,6 @@ const ChatBox = ({
         await onSend(currentMessage); // Changed to await for streaming
       } catch (error) {
         console.error("Error details:", error);
-      } finally {
-        setIsLoading(false);
       }
     }
   };
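With this change, ChatBox no longer resets isLoading itself after a send; the new useEffect derives it from the streaming flag on the last chat message, and handleSend doubles as a stop action while a response is streaming. The Square icon added to the lucide-react import suggests the send button is rendered as a stop button during streaming; a plausible render sketch (hypothetical, the actual JSX is not part of this diff):

<button
  onClick={handleSend}
  disabled={!message.trim() && !isLoading}
  aria-label={isLoading ? "Stop streaming" : "Send message"}
>
  {isLoading ? <Square size={18} /> : <Send size={18} />}
</button>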


@@ -43,6 +43,15 @@ async function handleLLMRequest(
   chatHistory: { role: string; content: string }[],
   res: any,
 ): Promise<void> {
+  let aborted = false;
+  let sandbox: Sandbox | null = null;
+
+  // Set up disconnect handler
+  res.on("close", () => {
+    aborted = true;
+    console.log("Client disconnected");
+  });
+
   try {
     const data = formatSpreadsheetData(spreadsheetData);
     const spreadsheetContext = spreadsheetData?.length
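The aborted flag is flipped by the response's "close" event. One caveat: on a Node ServerResponse, "close" also fires after a normal res.end(), so the flag is only meaningful while the handler still has work in flight, which is exactly how it is used below. A self-contained sketch of the pattern, assuming res is a plain Node ServerResponse rather than the any-typed parameter above:

import type { ServerResponse } from "http";

// Returns a function that reports whether the client has gone away.
function watchDisconnect(res: ServerResponse): () => boolean {
  let aborted = false;
  res.on("close", () => {
    aborted = true; // fires on client disconnect (and after a normal end)
  });
  return () => aborted;
}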
@@ -55,7 +64,8 @@ async function handleLLMRequest(
       ...chatHistory.slice(-10),
       { role: "user", content: userMessage },
     ];
-    // First, try a streaming call without tools
+
+    // First streaming call
     const stream = await openai.chat.completions.create({
       messages: messages,
       model: model,
@@ -64,6 +74,13 @@ async function handleLLMRequest(
     let accumulatedContent = "";
     for await (const chunk of stream) {
+      // Check if client disconnected
+      if (aborted) {
+        console.log("Aborting stream processing");
+        await stream.controller.abort();
+        return;
+      }
+
       const content = chunk.choices[0]?.delta?.content || "";
       if (content) {
         accumulatedContent += content;
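Here stream.controller is the AbortController that the openai Node SDK attaches to each streaming response; aborting it tears down the underlying HTTP request so no further chunks arrive. Since abort() returns void, the await on it is a no-op (harmless, though a plain call would read more directly). A condensed sketch of the same loop:

for await (const chunk of stream) {
  if (aborted) {
    stream.controller.abort(); // cancel the HTTP request to OpenAI
    return;                    // stop consuming; the outer finally still runs
  }
  accumulatedContent += chunk.choices[0]?.delta?.content || "";
}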
@@ -76,7 +93,10 @@ async function handleLLMRequest(
       }
     }
-    // After streaming text, check if we need tool calls
+    // Check again before making the tool call
+    if (aborted) return;
+
+    // Tool completion call
     const toolCompletion = await openai.chat.completions.create({
       messages: [
         ...messages,
@@ -87,10 +107,12 @@ async function handleLLMRequest(
       stream: false,
     });
+    // Check if aborted before processing tool calls
+    if (aborted) return;
     const assistantMessage = toolCompletion.choices[0]?.message;
     const toolCalls = assistantMessage?.tool_calls;
     console.log("Assistant message>>>", assistantMessage);
     console.log("Tool Calls>>>", toolCalls);
     if (toolCalls?.length) {
       const toolCall = toolCalls[0];
       let toolData: any = {
@@ -98,87 +120,117 @@ async function handleLLMRequest(
       };
       if (toolCall.function.name === "set_spreadsheet_cells") {
+        if (aborted) return;
         const updates = JSON.parse(toolCall.function.arguments).cellUpdates;
         toolData.updates = updates;
         // Add formatted tool data to the response
         toolData.response +=
           "\n\nSpreadsheet Updates:\n" +
           updates
-            .map((update) => `${update.target}: ${update.formula}`)
+            .map((update: any) => `${update.target}: ${update.formula}`)
             .join("\n");
       } else if (toolCall.function.name === "create_chart") {
+        if (aborted) return;
         const args = JSON.parse(toolCall.function.arguments);
         toolData.chartData = {
           type: args.type,
           options: { title: args.title, data: args.data },
         };
         // Add formatted chart data to the response
         toolData.response +=
           "\n\nChart Data:\n" +
           `Type: ${args.type}\n` +
           `Title: ${args.title}\n` +
-          `Data:\n${args.data.map((row) => row.join(", ")).join("\n")}`;
+          `Data:\n${args.data.map((row: any[]) => row.join(", ")).join("\n")}`;
       } else if (toolCall.function.name === "execute_python_code") {
-        const sandbox = await Sandbox.create();
-        const dirname = "/home/user";
-        const { analysis_goal, suggested_code, start_cell } = JSON.parse(
-          toolCall.function.arguments,
-        );
-        const csvData = convertToCSV(spreadsheetData);
-        await sandbox.files.write(`${dirname}/data.csv`, csvData);
-        const pythonScript = `
-import pandas as pd
-import numpy as np
-# Read the data
-df = pd.read_csv('/home/user/data.csv')
-# Execute analysis
-${suggested_code}
-`;
-        const execution = await sandbox.runCode(pythonScript);
-        const fileContent = await sandbox.files.read("/home/user/outputs.csv");
-        console.log("FILE CONTENT>", fileContent);
-        // Parse the CSV output to generate cell updates
-        const outputRows = fileContent
-          .trim()
-          .split("\n")
-          .map((row) => row.split(","));
-        const colLetter = start_cell.match(/[A-Z]+/)[0];
-        const rowNumber = parseInt(start_cell.match(/\d+/)[0]);
-        const generatedUpdates: CellUpdate[] = outputRows.flatMap(
-          (row, rowIndex) =>
-            row.map((value, colIndex) => ({
-              target: `${String.fromCharCode(colLetter.charCodeAt(0) + colIndex)}${rowNumber + rowIndex}`,
-              formula: value.toString(),
-            })),
-        );
-        // Generate cell updates based on the output data
-        console.log("GENERATED UPDATES", generatedUpdates);
-        toolData = {
-          response: `Analysis: ${analysis_goal}\n\nResults:\n${fileContent}`,
-          updates: generatedUpdates,
-          analysis: {
-            goal: analysis_goal,
-            output: fileContent,
-            error: execution.logs.stderr,
-          },
-        };
+        try {
+          if (aborted) return;
+          sandbox = await Sandbox.create();
+          const dirname = "/home/user";
+          const { analysis_goal, suggested_code, start_cell } = JSON.parse(
+            toolCall.function.arguments,
+          );
+          console.log("START CELL", start_cell);
+          const csvData = convertToCSV(spreadsheetData);
+          await sandbox.files.write(`${dirname}/data.csv`, csvData);
+          if (aborted) {
+            await sandbox.kill();
+            return;
+          }
+          const pythonScript = `
+import pandas as pd
+import numpy as np
+# Read the data
+df = pd.read_csv('/home/user/data.csv')
+# Execute analysis
+${suggested_code}
+`;
+          const execution = await sandbox.runCode(pythonScript);
+          if (aborted) {
+            await sandbox.kill();
+            return;
+          }
+          const fileContent = await sandbox.files.read(
+            "/home/user/outputs.csv",
+          );
+          // Parse the CSV output to generate cell updates
+          const outputRows = fileContent
+            .trim()
+            .split("\n")
+            .map((row) => row.split(","));
+          const colLetter = start_cell.match(/[A-Z]+/)[0];
+          const rowNumber = parseInt(start_cell.match(/\d+/)[0]);
+          const generatedUpdates: CellUpdate[] = outputRows.flatMap(
+            (row, rowIndex) =>
+              row.map((value, colIndex) => ({
+                target: `${String.fromCharCode(
+                  colLetter.charCodeAt(0) + colIndex,
+                )}${rowNumber + rowIndex}`,
+                formula: value.toString(),
+              })),
+          );
+          toolData = {
+            response: `Analysis: ${analysis_goal}\n\nResults:\n${fileContent}`,
+            updates: generatedUpdates,
+            analysis: {
+              goal: analysis_goal,
+              output: fileContent,
+              error: execution.logs.stderr,
+            },
+          };
+        } finally {
+          if (sandbox) {
+            await sandbox.kill();
+            sandbox = null;
+          }
+        }
       }
-      res.write(
-        `data: ${JSON.stringify({
-          ...toolData,
-          streaming: false,
-        })}\n\n`,
-      );
-    } else {
+      // Only send response if not aborted
+      if (!aborted) {
+        res.write(
+          `data: ${JSON.stringify({
+            ...toolData,
+            streaming: false,
+          })}\n\n`,
+        );
+      }
+    } else if (!aborted) {
       res.write(
         `data: ${JSON.stringify({
           response: accumulatedContent,
@@ -187,12 +239,17 @@ async function handleLLMRequest(
       );
     }
   } catch (error: any) {
-    console.error("LLM API error:", error);
-    res.write(
-      `data: ${JSON.stringify({ error: error.message || "Unknown error" })}\n\n`,
-    );
+    if (!aborted) {
+      console.error("LLM API error:", error);
+      res.write(
+        `data: ${JSON.stringify({ error: error.message || "Unknown error" })}\n\n`,
+      );
+    }
   } finally {
     res.end();
+    // Ensure sandbox is destroyed if it exists
+    if (sandbox) {
+      await sandbox.kill();
+    }
   }
 }
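Taken together, the function now has two layers of sandbox cleanup: the inner try/finally in the execute_python_code branch kills the sandbox and nulls the variable on every exit path, and the outer finally acts as a safety net in case an error or early return escapes before the inner block runs. Because the inner finally sets sandbox = null after killing it, the outer kill never fires twice. The general shape, as a hedged sketch:

let sandbox: Sandbox | null = null;
try {
  // ...streaming, tool calls...
  try {
    sandbox = await Sandbox.create();
    // ...write input, run code, read results...
  } finally {
    if (sandbox) {
      await sandbox.kill();
      sandbox = null; // prevents a double kill in the outer finally
    }
  }
} finally {
  res.end();
  if (sandbox) {
    await sandbox.kill(); // safety net if the inner finally never ran
  }
}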
@@ -203,9 +260,20 @@ export default async function handler(req: any, res: any): Promise<void> {
   res.setHeader("Connection", "keep-alive");
   res.flushHeaders();
+  // Create a promise that resolves when the client disconnects
+  const disconnectPromise = new Promise((resolve) => {
+    res.on("close", () => {
+      resolve(undefined);
+    });
+  });
+
   try {
     const { message, spreadsheetData, chatHistory } = req.body;
-    await handleLLMRequest(message, spreadsheetData, chatHistory, res);
+    // Race between the LLM request and the client disconnecting
+    await Promise.race([
+      handleLLMRequest(message, spreadsheetData, chatHistory, res),
+      disconnectPromise,
+    ]);
   } catch (error: any) {
     console.error("Error processing LLM request:", error);
     res.write(
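One subtlety in the race above: when disconnectPromise wins, the handler returns early, but Promise.race does not cancel the losing promise, so handleLLMRequest keeps running in the background. That is why it maintains its own aborted flag and checks it before every res.write. A hypothetical alternative (not what this PR does) would thread an AbortSignal through instead of a boolean, so downstream awaits could observe cancellation directly; the extra signal parameter here is invented for illustration:

const ac = new AbortController();
res.on("close", () => ac.abort());
await handleLLMRequest(message, spreadsheetData, chatHistory, res, ac.signal);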