mirror of
https://github.com/PragmaticMachineLearning/probly.git
synced 2026-01-10 05:47:56 -05:00
Merge pull request #9 from PragmaticMachineLearning/fix/abort-streaming-request
Fix/abort streaming request
This commit is contained in:
@@ -189,6 +189,7 @@ const SpreadsheetApp = () => {
|
||||
);
|
||||
} catch (error) {
|
||||
if (error.name === "AbortError") {
|
||||
console.log("Request Aborted");
|
||||
// Handle abort case
|
||||
setChatHistory((prev) =>
|
||||
prev.map((msg) =>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useState } from "react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { CellUpdate, ChatMessage } from "@/types/api";
|
||||
import { Check, X, Send, Trash2, Loader2, Square } from "lucide-react";
|
||||
|
||||
@@ -22,8 +22,15 @@ const ChatBox = ({
|
||||
const [message, setMessage] = useState("");
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (chatHistory.length > 0) {
|
||||
const lastMessage = chatHistory[chatHistory.length - 1];
|
||||
setIsLoading(!!lastMessage.streaming);
|
||||
}
|
||||
}, [chatHistory]);
|
||||
|
||||
const handleSend = async () => {
|
||||
if (message.trim()) {
|
||||
if (message.trim() || isLoading) {
|
||||
if (isLoading) {
|
||||
onStop();
|
||||
setIsLoading(false);
|
||||
@@ -36,8 +43,6 @@ const ChatBox = ({
|
||||
await onSend(currentMessage); // Changed to await for streaming
|
||||
} catch (error) {
|
||||
console.error("Error details:", error);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -43,6 +43,15 @@ async function handleLLMRequest(
|
||||
chatHistory: { role: string; content: string }[],
|
||||
res: any,
|
||||
): Promise<void> {
|
||||
let aborted = false;
|
||||
let sandbox: Sandbox | null = null;
|
||||
|
||||
// Set up disconnect handler
|
||||
res.on("close", () => {
|
||||
aborted = true;
|
||||
console.log("Client disconnected");
|
||||
});
|
||||
|
||||
try {
|
||||
const data = formatSpreadsheetData(spreadsheetData);
|
||||
const spreadsheetContext = spreadsheetData?.length
|
||||
@@ -55,7 +64,8 @@ async function handleLLMRequest(
|
||||
...chatHistory.slice(-10),
|
||||
{ role: "user", content: userMessage },
|
||||
];
|
||||
// First, try a streaming call without tools
|
||||
|
||||
// First streaming call
|
||||
const stream = await openai.chat.completions.create({
|
||||
messages: messages,
|
||||
model: model,
|
||||
@@ -64,6 +74,13 @@ async function handleLLMRequest(
|
||||
|
||||
let accumulatedContent = "";
|
||||
for await (const chunk of stream) {
|
||||
// Check if client disconnected
|
||||
if (aborted) {
|
||||
console.log("Aborting stream processing");
|
||||
await stream.controller.abort();
|
||||
return;
|
||||
}
|
||||
|
||||
const content = chunk.choices[0]?.delta?.content || "";
|
||||
if (content) {
|
||||
accumulatedContent += content;
|
||||
@@ -76,7 +93,10 @@ async function handleLLMRequest(
|
||||
}
|
||||
}
|
||||
|
||||
// After streaming text, check if we need tool calls
|
||||
// Check again before making the tool call
|
||||
if (aborted) return;
|
||||
|
||||
// Tool completion call
|
||||
const toolCompletion = await openai.chat.completions.create({
|
||||
messages: [
|
||||
...messages,
|
||||
@@ -87,10 +107,12 @@ async function handleLLMRequest(
|
||||
stream: false,
|
||||
});
|
||||
|
||||
// Check if aborted before processing tool calls
|
||||
if (aborted) return;
|
||||
|
||||
const assistantMessage = toolCompletion.choices[0]?.message;
|
||||
const toolCalls = assistantMessage?.tool_calls;
|
||||
console.log("Assistant message>>>", assistantMessage);
|
||||
console.log("Tool Calls>>>", toolCalls);
|
||||
|
||||
if (toolCalls?.length) {
|
||||
const toolCall = toolCalls[0];
|
||||
let toolData: any = {
|
||||
@@ -98,87 +120,117 @@ async function handleLLMRequest(
|
||||
};
|
||||
|
||||
if (toolCall.function.name === "set_spreadsheet_cells") {
|
||||
if (aborted) return;
|
||||
const updates = JSON.parse(toolCall.function.arguments).cellUpdates;
|
||||
toolData.updates = updates;
|
||||
|
||||
// Add formatted tool data to the response
|
||||
toolData.response +=
|
||||
"\n\nSpreadsheet Updates:\n" +
|
||||
updates
|
||||
.map((update) => `${update.target}: ${update.formula}`)
|
||||
.map((update: any) => `${update.target}: ${update.formula}`)
|
||||
.join("\n");
|
||||
} else if (toolCall.function.name === "create_chart") {
|
||||
if (aborted) return;
|
||||
const args = JSON.parse(toolCall.function.arguments);
|
||||
toolData.chartData = {
|
||||
type: args.type,
|
||||
options: { title: args.title, data: args.data },
|
||||
};
|
||||
|
||||
// Add formatted chart data to the response
|
||||
toolData.response +=
|
||||
"\n\nChart Data:\n" +
|
||||
`Type: ${args.type}\n` +
|
||||
`Title: ${args.title}\n` +
|
||||
`Data:\n${args.data.map((row) => row.join(", ")).join("\n")}`;
|
||||
`Data:\n${args.data.map((row: any[]) => row.join(", ")).join("\n")}`;
|
||||
} else if (toolCall.function.name === "execute_python_code") {
|
||||
const sandbox = await Sandbox.create();
|
||||
const dirname = "/home/user";
|
||||
try {
|
||||
if (aborted) return;
|
||||
sandbox = await Sandbox.create();
|
||||
const dirname = "/home/user";
|
||||
|
||||
const { analysis_goal, suggested_code, start_cell } = JSON.parse(
|
||||
toolCall.function.arguments,
|
||||
);
|
||||
console.log("START CELL", start_cell);
|
||||
const csvData = convertToCSV(spreadsheetData);
|
||||
await sandbox.files.write(`${dirname}/data.csv`, csvData);
|
||||
const { analysis_goal, suggested_code, start_cell } = JSON.parse(
|
||||
toolCall.function.arguments,
|
||||
);
|
||||
|
||||
const pythonScript = `
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
# Read the data
|
||||
df = pd.read_csv('/home/user/data.csv')
|
||||
# Execute analysis
|
||||
${suggested_code}
|
||||
`;
|
||||
if (aborted) {
|
||||
await sandbox.kill();
|
||||
return;
|
||||
}
|
||||
|
||||
const execution = await sandbox.runCode(pythonScript);
|
||||
const fileContent = await sandbox.files.read("/home/user/outputs.csv");
|
||||
console.log("FILE CONTENT>", fileContent);
|
||||
// Parse the CSV output to generate cell updates
|
||||
const outputRows = fileContent
|
||||
.trim()
|
||||
.split("\n")
|
||||
.map((row) => row.split(","));
|
||||
const csvData = convertToCSV(spreadsheetData);
|
||||
await sandbox.files.write(`${dirname}/data.csv`, csvData);
|
||||
|
||||
const colLetter = start_cell.match(/[A-Z]+/)[0];
|
||||
const rowNumber = parseInt(start_cell.match(/\d+/)[0]);
|
||||
const generatedUpdates: CellUpdate[] = outputRows.flatMap(
|
||||
(row, rowIndex) =>
|
||||
row.map((value, colIndex) => ({
|
||||
target: `${String.fromCharCode(colLetter.charCodeAt(0) + colIndex)}${rowNumber + rowIndex}`,
|
||||
formula: value.toString(),
|
||||
})),
|
||||
);
|
||||
if (aborted) {
|
||||
await sandbox.kill();
|
||||
return;
|
||||
}
|
||||
|
||||
// Generate cell updates based on the output data
|
||||
console.log("GENERATED UPDATES", generatedUpdates);
|
||||
toolData = {
|
||||
response: `Analysis: ${analysis_goal}\n\nResults:\n${fileContent}`,
|
||||
updates: generatedUpdates,
|
||||
analysis: {
|
||||
goal: analysis_goal,
|
||||
output: fileContent,
|
||||
error: execution.logs.stderr,
|
||||
},
|
||||
};
|
||||
const pythonScript = `
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
# Read the data
|
||||
df = pd.read_csv('/home/user/data.csv')
|
||||
# Execute analysis
|
||||
${suggested_code}
|
||||
`;
|
||||
|
||||
const execution = await sandbox.runCode(pythonScript);
|
||||
|
||||
if (aborted) {
|
||||
await sandbox.kill();
|
||||
return;
|
||||
}
|
||||
|
||||
const fileContent = await sandbox.files.read(
|
||||
"/home/user/outputs.csv",
|
||||
);
|
||||
|
||||
// Parse the CSV output to generate cell updates
|
||||
const outputRows = fileContent
|
||||
.trim()
|
||||
.split("\n")
|
||||
.map((row) => row.split(","));
|
||||
|
||||
const colLetter = start_cell.match(/[A-Z]+/)[0];
|
||||
const rowNumber = parseInt(start_cell.match(/\d+/)[0]);
|
||||
|
||||
const generatedUpdates: CellUpdate[] = outputRows.flatMap(
|
||||
(row, rowIndex) =>
|
||||
row.map((value, colIndex) => ({
|
||||
target: `${String.fromCharCode(
|
||||
colLetter.charCodeAt(0) + colIndex,
|
||||
)}${rowNumber + rowIndex}`,
|
||||
formula: value.toString(),
|
||||
})),
|
||||
);
|
||||
|
||||
toolData = {
|
||||
response: `Analysis: ${analysis_goal}\n\nResults:\n${fileContent}`,
|
||||
updates: generatedUpdates,
|
||||
analysis: {
|
||||
goal: analysis_goal,
|
||||
output: fileContent,
|
||||
error: execution.logs.stderr,
|
||||
},
|
||||
};
|
||||
} finally {
|
||||
if (sandbox) {
|
||||
await sandbox.kill();
|
||||
sandbox = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
res.write(
|
||||
`data: ${JSON.stringify({
|
||||
...toolData,
|
||||
streaming: false,
|
||||
})}\n\n`,
|
||||
);
|
||||
} else {
|
||||
// Only send response if not aborted
|
||||
if (!aborted) {
|
||||
res.write(
|
||||
`data: ${JSON.stringify({
|
||||
...toolData,
|
||||
streaming: false,
|
||||
})}\n\n`,
|
||||
);
|
||||
}
|
||||
} else if (!aborted) {
|
||||
res.write(
|
||||
`data: ${JSON.stringify({
|
||||
response: accumulatedContent,
|
||||
@@ -187,12 +239,17 @@ async function handleLLMRequest(
|
||||
);
|
||||
}
|
||||
} catch (error: any) {
|
||||
console.error("LLM API error:", error);
|
||||
res.write(
|
||||
`data: ${JSON.stringify({ error: error.message || "Unknown error" })}\n\n`,
|
||||
);
|
||||
if (!aborted) {
|
||||
console.error("LLM API error:", error);
|
||||
res.write(
|
||||
`data: ${JSON.stringify({ error: error.message || "Unknown error" })}\n\n`,
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
res.end();
|
||||
// Ensure sandbox is destroyed if it exists
|
||||
if (sandbox) {
|
||||
await sandbox.kill();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,9 +260,20 @@ export default async function handler(req: any, res: any): Promise<void> {
|
||||
res.setHeader("Connection", "keep-alive");
|
||||
res.flushHeaders();
|
||||
|
||||
// Create a promise that resolves when the client disconnects
|
||||
const disconnectPromise = new Promise((resolve) => {
|
||||
res.on("close", () => {
|
||||
resolve(undefined);
|
||||
});
|
||||
});
|
||||
|
||||
try {
|
||||
const { message, spreadsheetData, chatHistory } = req.body;
|
||||
await handleLLMRequest(message, spreadsheetData, chatHistory, res);
|
||||
// Race between the LLM request and the client disconnecting
|
||||
await Promise.race([
|
||||
handleLLMRequest(message, spreadsheetData, chatHistory, res),
|
||||
disconnectPromise,
|
||||
]);
|
||||
} catch (error: any) {
|
||||
console.error("Error processing LLM request:", error);
|
||||
res.write(
|
||||
|
||||
Reference in New Issue
Block a user