fix sampling and elicitation calls for streamable http

This commit is contained in:
evalstate
2025-09-28 17:40:11 +01:00
parent 850b1542e2
commit a753cd6f3a

View File

@@ -1,10 +1,13 @@
import { Server } from "@modelcontextprotocol/sdk/server/index.js";
import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js";
import {
CallToolRequestSchema,
ClientCapabilities,
CompleteRequestSchema,
CreateMessageRequest,
CreateMessageResultSchema,
ElicitRequest,
ElicitResultSchema,
GetPromptRequestSchema,
ListPromptsRequestSchema,
ListResourcesRequestSchema,
@@ -14,6 +17,8 @@ import {
ReadResourceRequestSchema,
Resource,
RootsListChangedNotificationSchema,
ServerNotification,
ServerRequest,
SubscribeRequestSchema,
Tool,
ToolSchema,
@@ -36,6 +41,8 @@ type ToolInput = z.infer<typeof ToolInputSchema>;
const ToolOutputSchema = ToolSchema.shape.outputSchema;
type ToolOutput = z.infer<typeof ToolOutputSchema>;
type SendRequest = RequestHandlerExtra<ServerRequest, ServerNotification>["sendRequest"];
/* Input schemas for tools implemented in this server */
const EchoSchema = z.object({
message: z.string().describe("Message to echo"),
@@ -220,7 +227,8 @@ export const createServer = () => {
const requestSampling = async (
context: string,
uri: string,
maxTokens: number = 100
maxTokens: number = 100,
sendRequest: SendRequest
) => {
const request: CreateMessageRequest = {
method: "sampling/createMessage",
@@ -241,22 +249,24 @@ export const createServer = () => {
},
};
return await server.request(request, CreateMessageResultSchema);
return await sendRequest(request, CreateMessageResultSchema);
};
const requestElicitation = async (
message: string,
requestedSchema: any
requestedSchema: any,
sendRequest: SendRequest
) => {
const request = {
const request: ElicitRequest = {
method: 'elicitation/create',
params: {
message,
requestedSchema
}
requestedSchema,
},
};
return await server.request(request, z.any());
return await sendRequest(request, ElicitResultSchema);
};
const ALL_RESOURCES: Resource[] = Array.from({ length: 100 }, (_, i) => {
@@ -334,12 +344,12 @@ export const createServer = () => {
throw new Error(`Unknown resource: ${uri}`);
});
server.setRequestHandler(SubscribeRequestSchema, async (request) => {
server.setRequestHandler(SubscribeRequestSchema, async (request, extra) => {
const { uri } = request.params;
subscriptions.add(uri);
// Request sampling from client when someone subscribes
await requestSampling("A new subscription was started", uri);
await requestSampling("A new subscription was started", uri, undefined, extra.sendRequest);
return {};
});
@@ -615,7 +625,8 @@ export const createServer = () => {
const result = await requestSampling(
prompt,
ToolName.SAMPLE_LLM,
maxTokens
maxTokens,
extra.sendRequest
);
return {
content: [
@@ -734,14 +745,20 @@ export const createServer = () => {
type: 'object',
properties: {
color: { type: 'string', description: 'Favorite color' },
number: { type: 'integer', description: 'Favorite number', minimum: 1, maximum: 100 },
number: {
type: 'integer',
description: 'Favorite number',
minimum: 1,
maximum: 100,
},
pets: {
type: 'string',
enum: ['cats', 'dogs', 'birds', 'fish', 'reptiles'],
description: 'Favorite pets'
description: 'Favorite pets',
},
}
}
},
},
extra.sendRequest
);
// Handle different response actions