diff --git a/src/everything/everything.ts b/src/everything/everything.ts
index 01277579..ddf8658d 100644
--- a/src/everything/everything.ts
+++ b/src/everything/everything.ts
@@ -6,13 +6,16 @@ import {
   CompleteRequestSchema,
   CreateMessageRequest,
   CreateMessageResultSchema,
+  ElicitRequest,
   ElicitResultSchema,
+  ErrorCode,
   GetPromptRequestSchema,
   ListPromptsRequestSchema,
   ListResourcesRequestSchema,
   ListResourceTemplatesRequestSchema,
   ListToolsRequestSchema,
   LoggingLevel,
+  McpError,
   ReadResourceRequestSchema,
   Resource,
   RootsListChangedNotificationSchema,
@@ -21,8 +24,11 @@ import {
   SubscribeRequestSchema,
   Tool,
   UnsubscribeRequestSchema,
-  type Root
+  type CallToolResult,
+  type Root,
 } from "@modelcontextprotocol/sdk/types.js";
+import { ToolRegistry, BreakToolLoopError } from "./toolRegistry.js";
+import { runToolLoop } from "./toolLoop.js";
 import { z } from "zod";
 import { zodToJsonSchema } from "zod-to-json-schema";
 import { readFileSync } from "fs";
@@ -70,6 +76,12 @@ const SampleLLMSchema = z.object({
     .describe("Maximum number of tokens to generate"),
 });
 
+const AdventureGameSchema = z.object({
+  gameSynopsisOrSubject: z
+    .string()
+    .describe("Description of the game subject or possible synopsis."),
+});
+
 const GetTinyImageSchema = z.object({});
 
 const AnnotatedMessageSchema = z.object({
@@ -137,6 +149,7 @@ enum ToolName {
   LONG_RUNNING_OPERATION = "longRunningOperation",
   PRINT_ENV = "printEnv",
   SAMPLE_LLM = "sampleLLM",
+  ADVENTURE_GAME = "adventureGame",
   GET_TINY_IMAGE = "getTinyImage",
   ANNOTATED_MESSAGE = "annotatedMessage",
   GET_RESOURCE_REFERENCE = "getResourceReference",
@@ -226,36 +239,6 @@ export const createServer = () => {
     }
   };
 
-  // Helper method to request sampling from client
-  const requestSampling = async (
-    context: string,
-    uri: string,
-    maxTokens: number = 100,
-    sendRequest: SendRequest
-  ) => {
-    const request: CreateMessageRequest = {
-      method: "sampling/createMessage",
-      params: {
-        messages: [
-          {
-            role: "user",
-            content: {
-              type: "text",
-              text: `Resource ${uri} context: ${context}`,
-            },
-          },
-        ],
-        systemPrompt: "You are a helpful test server.",
-        maxTokens,
-        temperature: 0.7,
-        includeContext: "thisServer",
-      },
-    };
-
-    return await sendRequest(request, CreateMessageResultSchema);
-
-  };
-
   const ALL_RESOURCES: Resource[] = Array.from({ length: 100 }, (_, i) => {
     const uri = `test://static/resource/${i + 1}`;
     if (i % 2 === 0) {
@@ -536,6 +519,11 @@ export const createServer = () => {
       description: "Elicitation test tool that demonstrates how to request user input with various field types (string, boolean, email, uri, date, integer, number, enum)",
       inputSchema: zodToJsonSchema(ElicitationSchema) as ToolInput,
     });
+    if (clientCapabilities!.sampling && clientCapabilities!.elicitation) tools.push({
+      name: ToolName.ADVENTURE_GAME,
+      description: "Play a 'choose your own adventure' game. The user will be asked for decisions along the way via elicitation. Requires both sampling and elicitation capabilities.",
+      inputSchema: zodToJsonSchema(AdventureGameSchema) as ToolInput,
+    });
 
     return { tools };
   });
@@ -611,12 +599,25 @@ export const createServer = () => {
       const validatedArgs = SampleLLMSchema.parse(args);
       const { prompt, maxTokens } = validatedArgs;
 
-      const result = await requestSampling(
-        prompt,
-        ToolName.SAMPLE_LLM,
-        maxTokens,
-        extra.sendRequest
-      );
+      const result = await extra.sendRequest({
+        method: "sampling/createMessage",
+        params: {
+          maxTokens,
+          messages: [
+            {
+              role: "user",
+              content: {
+                type: "text",
+                text: prompt,
+              },
+            },
+          ],
+          systemPrompt: "You are a helpful test server.",
+          temperature: 0.7,
+          includeContext: "thisServer",
+        },
+      }, CreateMessageResultSchema);
+
       const content = Array.isArray(result.content) ? result.content : [result.content];
       const textResult = content.every((c) => c.type === "text")
         ? content.map(c => c.text).join("\n")
@@ -628,6 +629,176 @@ export const createServer = () => {
       };
     }
 
+    if (name === ToolName.ADVENTURE_GAME) {
+      const { gameSynopsisOrSubject } = AdventureGameSchema.parse(args);
+
+      // Helper to create error result
+      const makeErrorCallToolResult = (error: unknown): CallToolResult => ({
+        content: [
+          {
+            type: "text",
+            text: error instanceof Error ? `${error.message}\n${error.stack}` : `${error}`,
+          },
+        ],
+        isError: true,
+      });
+
+      // Create registry with game tools
+      const gameRegistry = new ToolRegistry({
+        userLost: {
+          description: "Called when the user loses",
+          inputSchema: z.object({
+            storyUpdate: z.string(),
+          }),
+          callback: async (args, gameExtra) => {
+            const { storyUpdate } = args as { storyUpdate: string };
+            await gameExtra.sendRequest({
+              method: 'elicitation/create',
+              params: {
+                mode: 'form',
+                message: 'You Lost!\n' + storyUpdate,
+                requestedSchema: {
+                  type: 'object',
+                  properties: {},
+                },
+              },
+            }, ElicitResultSchema);
+            throw new BreakToolLoopError('lost');
+          },
+        },
+        userWon: {
+          description: "Called when the user wins the game",
+          inputSchema: z.object({
+            storyUpdate: z.string(),
+          }),
+          callback: async (args, gameExtra) => {
+            const { storyUpdate } = args as { storyUpdate: string };
+            await gameExtra.sendRequest({
+              method: 'elicitation/create',
+              params: {
+                mode: 'form',
+                message: 'You Won!\n' + storyUpdate,
+                requestedSchema: {
+                  type: 'object',
+                  properties: {},
+                },
+              },
+            }, ElicitResultSchema);
+            throw new BreakToolLoopError('won');
+          },
+        },
+        nextStep: {
+          description: "Next step in the game.",
+          inputSchema: z.object({
+            storyUpdate: z.string().describe("Description of the next step of the game. Acknowledges the last decision (if any) and describes what happened because of / since it was made, then continues the story up to the point where another decision is needed from the user (if/when appropriate)."),
+            nextDecisions: z.array(z.string()).describe("The list of possible decisions the user/player can make at this point of the story. Empty list if we've reached the end of the story"),
+            decisionTimeoutSeconds: z.number().optional().describe("Optional: timeout in seconds for decision to be made. Used when a timely decision is needed."),
+          }),
+          outputSchema: z.object({
+            userDecision: z.string().optional()
+              .describe("The decision the user took, or undefined if the user let the decision time out."),
+          }),
+          callback: async (args, gameExtra) => {
+            const { storyUpdate, nextDecisions, decisionTimeoutSeconds } = args as {
+              storyUpdate: string;
+              nextDecisions: string[];
+              decisionTimeoutSeconds?: number;
+            };
+            try {
+              const result = await gameExtra.sendRequest({
+                method: 'elicitation/create',
+                params: {
+                  mode: 'form',
+                  message: storyUpdate,
+                  requestedSchema: {
+                    type: 'object',
+                    properties: {
+                      nextDecision: {
+                        title: 'Next step',
+                        type: 'string',
+                        enum: nextDecisions,
+                      },
+                    },
+                  },
+                },
+              }, ElicitResultSchema, {
+                timeout: decisionTimeoutSeconds == null ? undefined : decisionTimeoutSeconds * 1000,
+              });
+
+              if (result.action === 'accept') {
+                const structuredContent = {
+                  userDecision: result.content?.nextDecision as string,
+                };
+                return {
+                  content: [{ type: 'text' as const, text: JSON.stringify(structuredContent) }],
+                  structuredContent,
+                };
+              } else {
+                return {
+                  content: [{ type: 'text' as const, text: result.action === 'decline' ? 'Game Over' : 'Game Cancelled' }],
+                };
+              }
+            } catch (error) {
+              if (error instanceof McpError && error.code === ErrorCode.RequestTimeout) {
+                const structuredContent = {
+                  userDecision: undefined, // Means "timed out"
+                };
+                return {
+                  content: [{ type: 'text' as const, text: JSON.stringify(structuredContent) }],
+                  structuredContent,
+                };
+              }
+              return makeErrorCallToolResult(error);
+            }
+          },
+        },
+      });
+
+      try {
+        const { answer, transcript, usage } = await runToolLoop({
+          initialMessages: [{
+            role: "user",
+            content: {
+              type: "text",
+              text: gameSynopsisOrSubject,
+            },
+          }],
+          systemPrompt:
+            "You are a 'choose your own adventure' game master. " +
+            "Given an initial user request (subject and/or synopsis of the game, maybe a description of their role in the game), " +
+            "you will relentlessly walk the user forward in an imaginary story, " +
+            "giving them regular choices as to what their character can do next or what can happen next. " +
+            "If the user didn't choose a role for themselves, you can ask them to pick one of a few interesting options (first decision). " +
+            "Then you will continually develop the story and call the nextStep tool to give story updates and ask for pivotal decisions. " +
+            "Updates should fit in a page (sometimes as short as a paragraph, e.g. if doing a battle with very fast-paced action). " +
+            "Some decisions should have a timeout to create some thrills for the user, in tight action scenes. " +
+            "When / if the user loses (e.g. dies, or whatever the user expressed as a loss condition), the last call to nextStep should have zero options.",
+          defaultToolChoice: { mode: 'required' },
+          server,
+          registry: gameRegistry,
+        }, extra);
+
+        return {
+          content: [
+            {
+              type: "text",
+              text: answer,
+            },
+            {
+              type: "text",
+              text: `\n\n--- Usage: ${usage.api_calls} API calls, ${usage.input_tokens} input / ${usage.output_tokens} output tokens ---`,
+            },
+            {
+              type: "text",
+              text: `\n\n--- Debug Transcript (${transcript.length} messages) ---\n${JSON.stringify(transcript, null, 2)}`,
+            },
+          ],
+        };
+      } catch (error) {
+        return makeErrorCallToolResult(error);
+      }
+    }
+
     if (name === ToolName.GET_TINY_IMAGE) {
       GetTinyImageSchema.parse(args);
       return {
diff --git a/src/everything/toolLoop.ts b/src/everything/toolLoop.ts
new file mode 100644
index 00000000..dcb75fc1
--- /dev/null
+++ b/src/everything/toolLoop.ts
@@ -0,0 +1,146 @@
+import type { Server } from "@modelcontextprotocol/sdk/server/index.js";
+import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js";
+import type {
+  SamplingMessage,
+  ToolUseContent,
+  CreateMessageRequest,
+  CreateMessageResult,
+  ServerRequest,
+  ServerNotification,
+  ToolResultContent,
+} from "@modelcontextprotocol/sdk/types.js";
+import { ToolRegistry, BreakToolLoopError } from "./toolRegistry.js";
+
+export { BreakToolLoopError };
+
+/**
+ * Interface for tracking aggregated token usage across API calls.
+ */
+interface AggregatedUsage {
+  input_tokens: number;
+  output_tokens: number;
+  cache_creation_input_tokens: number;
+  cache_read_input_tokens: number;
+  api_calls: number;
+}
+
+/**
+ * Runs a tool loop using sampling.
+ * Continues until the LLM provides a final answer.
+ */
+export async function runToolLoop(
+  options: {
+    initialMessages: SamplingMessage[];
+    server: Server;
+    registry: ToolRegistry;
+    maxIterations?: number;
+    systemPrompt?: string;
+    defaultToolChoice?: CreateMessageRequest["params"]["toolChoice"];
+  },
+  extra: RequestHandlerExtra<ServerRequest, ServerNotification>
+): Promise<{ answer: string; transcript: SamplingMessage[]; usage: AggregatedUsage }> {
+  const messages: SamplingMessage[] = [...options.initialMessages];
+
+  // Initialize usage tracking
+  const usage: AggregatedUsage = {
+    input_tokens: 0,
+    output_tokens: 0,
+    cache_creation_input_tokens: 0,
+    cache_read_input_tokens: 0,
+    api_calls: 0,
+  };
+
+  let iteration = 0;
+  const maxIterations = options.maxIterations ?? Number.POSITIVE_INFINITY;
+  const defaultToolChoice = options.defaultToolChoice ?? { mode: "auto" };
+
+  let request: CreateMessageRequest["params"] | undefined;
+  let response: CreateMessageResult | undefined;
+
+  while (iteration < maxIterations) {
+    iteration++;
+
+    // Request message from LLM with available tools
+    response = await options.server.createMessage(request = {
+      messages,
+      systemPrompt: options.systemPrompt,
+      maxTokens: 4000,
+      tools: iteration < maxIterations ? options.registry.tools : undefined,
+      // Don't allow tool calls at the last iteration: finish with an answer no matter what!
+      toolChoice: iteration < maxIterations ? defaultToolChoice : { mode: "none" },
+    });
+
+    // Aggregate usage statistics from the response
+    if (response._meta?.usage) {
+      const responseUsage = response._meta.usage as Record<string, number>;
+      usage.input_tokens += responseUsage.input_tokens || 0;
+      usage.output_tokens += responseUsage.output_tokens || 0;
+      usage.cache_creation_input_tokens += responseUsage.cache_creation_input_tokens || 0;
+      usage.cache_read_input_tokens += responseUsage.cache_read_input_tokens || 0;
+      usage.api_calls += 1;
+    }
+
+    // Add assistant's response to message history
+    messages.push({
+      role: "assistant",
+      content: response.content,
+    });
+
+    if (response.stopReason === "toolUse") {
+      const contentArray = Array.isArray(response.content) ? response.content : [response.content];
+      const toolCalls = contentArray.filter(
+        (content): content is ToolUseContent => content.type === "tool_use"
+      );
+
+      await options.server.sendLoggingMessage({
+        level: "info",
+        data: `Loop iteration ${iteration}: ${toolCalls.length} tool invocation(s) requested`,
+      });
+
+      let toolResults: ToolResultContent[];
+      try {
+        toolResults = await options.registry.callTools(toolCalls, extra);
+      } catch (error) {
+        if (error instanceof BreakToolLoopError) {
+          return { answer: `${error.message}`, transcript: messages, usage };
+        }
+        console.error(error);
+        throw new Error(`Tool call failed: ${error}`);
+      }
+
+      messages.push({
+        role: "user",
+        content: iteration < maxIterations ? toolResults : [
+          ...toolResults,
+          {
+            type: "text",
+            text: "Using the information retrieved from the tools, please now provide a concise final answer to the original question (last iteration of the tool loop).",
+          },
+        ],
+      });
+    } else if (response.stopReason === "endTurn") {
+      const contentArray = Array.isArray(response.content) ? response.content : [response.content];
+      const unexpectedBlocks = contentArray.filter(content => content.type !== "text");
+      if (unexpectedBlocks.length > 0) {
+        throw new Error(`Expected text content in final answer, but got: ${unexpectedBlocks.map(b => b.type).join(", ")}`);
+      }
+
+      await options.server.sendLoggingMessage({
+        level: "info",
+        data: `Tool loop completed after ${iteration} iteration(s)`,
+      });
+
+      return {
+        answer: contentArray.map(block => block.type === "text" ? block.text : "").join("\n\n"),
block.text : "").join("\n\n"), + transcript: messages, + usage, + }; + } else if (response?.stopReason === "maxTokens") { + throw new Error("LLM response hit max tokens limit"); + } else { + throw new Error(`Unsupported stop reason: ${response.stopReason}`); + } + } + + throw new Error(`Tool loop exceeded maximum iterations (${maxIterations}); request: ${JSON.stringify(request)}\nresponse: ${JSON.stringify(response)}`); +} diff --git a/src/everything/toolRegistry.ts b/src/everything/toolRegistry.ts new file mode 100644 index 00000000..1e7cdd2d --- /dev/null +++ b/src/everything/toolRegistry.ts @@ -0,0 +1,75 @@ +import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js"; +import type { + Tool, + ToolAnnotations, + ToolUseContent, + ToolResultContent, + CallToolResult, + ServerRequest, + ServerNotification, +} from "@modelcontextprotocol/sdk/types.js"; +import { zodToJsonSchema } from "zod-to-json-schema"; + +export class BreakToolLoopError extends Error { + constructor(message: string) { + super(message); + } +} + +type ToolCallback = ( + args: Record, + extra: RequestHandlerExtra +) => CallToolResult | Promise; + +interface ToolDefinition { + title?: string; + description?: string; + inputSchema?: unknown; + outputSchema?: unknown; + annotations?: ToolAnnotations; + _meta?: Record; + callback: ToolCallback; +} + +export class ToolRegistry { + readonly tools: Tool[]; + + constructor(private toolDefinitions: { [name: string]: ToolDefinition }) { + this.tools = Object.entries(this.toolDefinitions).map(([name, tool]) => ({ + name, + title: tool.title, + description: tool.description, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + inputSchema: tool.inputSchema ? zodToJsonSchema(tool.inputSchema as any) : undefined, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + outputSchema: tool.outputSchema ? zodToJsonSchema(tool.outputSchema as any) : undefined, + annotations: tool.annotations, + _meta: tool._meta, + })); + } + + async callTools( + toolCalls: ToolUseContent[], + extra: RequestHandlerExtra + ): Promise { + return Promise.all(toolCalls.map(async ({ name, id, input }) => { + const tool = this.toolDefinitions[name]; + if (!tool) { + throw new Error(`Tool ${name} not found`); + } + try { + return { + type: "tool_result", + toolUseId: id, + // Copies fields: content, structuredContent?, isError? + ...await tool.callback(input as Record, extra), + }; + } catch (error) { + if (error instanceof BreakToolLoopError) { + throw error; + } + throw new Error(`Tool ${name} failed: ${error instanceof Error ? `${error.message}\n${error.stack}` : error}`); + } + })); + } +}