From aad28f7841feb3271d5afcaa23728a0c43766b5c Mon Sep 17 00:00:00 2001 From: Gustavo Madeira Santana Date: Tue, 10 Feb 2026 06:41:36 -0500 Subject: [PATCH] fix: streamline custom endpoint onboarding (#11106) (thanks @MackDing) --- src/commands/auth-choice-prompt.ts | 4 + src/commands/onboard-custom.test.ts | 95 ++++----- src/commands/onboard-custom.ts | 298 ++++++++++------------------ 3 files changed, 147 insertions(+), 250 deletions(-) diff --git a/src/commands/auth-choice-prompt.ts b/src/commands/auth-choice-prompt.ts index 3fbacdfdb4..8eef15e079 100644 --- a/src/commands/auth-choice-prompt.ts +++ b/src/commands/auth-choice-prompt.ts @@ -42,6 +42,10 @@ export async function promptAuthChoiceGrouped(params: { continue; } + if (group.options.length === 1) { + return group.options[0].value; + } + const methodSelection = await params.prompter.select({ message: `${group.label} auth method`, options: [...group.options, { value: BACK_VALUE, label: "Back" }], diff --git a/src/commands/onboard-custom.test.ts b/src/commands/onboard-custom.test.ts index cceaa705a3..16c07c287c 100644 --- a/src/commands/onboard-custom.test.ts +++ b/src/commands/onboard-custom.test.ts @@ -13,38 +13,30 @@ describe("promptCustomApiConfig", () => { vi.useRealTimers(); }); - it("handles openai discovery and saves alias", async () => { + it("handles openai flow and saves alias", async () => { const prompter = { text: vi .fn() .mockResolvedValueOnce("http://localhost:11434/v1") // Base URL .mockResolvedValueOnce("") // API Key + .mockResolvedValueOnce("llama3") // Model ID .mockResolvedValueOnce("custom") // Endpoint ID .mockResolvedValueOnce("local"), // Alias progress: vi.fn(() => ({ update: vi.fn(), stop: vi.fn(), })), - select: vi - .fn() - .mockResolvedValueOnce("openai") // Compatibility - .mockResolvedValueOnce("llama3"), // Model selection + select: vi.fn().mockResolvedValueOnce("openai"), // Compatibility confirm: vi.fn(), note: vi.fn(), }; vi.stubGlobal( "fetch", - vi - .fn() - .mockResolvedValueOnce({ - ok: true, - json: async () => ({ data: [{ id: "llama3" }, { id: "mistral" }] }), - }) - .mockResolvedValueOnce({ - ok: true, - json: async () => ({}), - }), + vi.fn().mockResolvedValueOnce({ + ok: true, + json: async () => ({}), + }), ); const result = await promptCustomApiConfig({ @@ -53,27 +45,31 @@ describe("promptCustomApiConfig", () => { config: {}, }); - expect(prompter.text).toHaveBeenCalledTimes(4); - expect(prompter.select).toHaveBeenCalledTimes(2); + expect(prompter.text).toHaveBeenCalledTimes(5); + expect(prompter.select).toHaveBeenCalledTimes(1); expect(result.config.models?.providers?.custom?.api).toBe("openai-completions"); expect(result.config.agents?.defaults?.models?.["custom/llama3"]?.alias).toBe("local"); }); - it("falls back to manual entry when discovery fails", async () => { + it("retries when verification fails", async () => { const prompter = { text: vi .fn() .mockResolvedValueOnce("http://localhost:11434/v1") // Base URL .mockResolvedValueOnce("") // API Key + .mockResolvedValueOnce("bad-model") // Model ID + .mockResolvedValueOnce("good-model") // Model ID retry .mockResolvedValueOnce("custom") // Endpoint ID - .mockResolvedValueOnce("manual-model-id") // Manual model .mockResolvedValueOnce(""), // Alias progress: vi.fn(() => ({ update: vi.fn(), stop: vi.fn(), })), - select: vi.fn().mockResolvedValueOnce("openai"), // Compatibility only - confirm: vi.fn().mockResolvedValue(true), + select: vi + .fn() + .mockResolvedValueOnce("openai") // Compatibility + .mockResolvedValueOnce("model"), // Retry choice + confirm: vi.fn(), note: vi.fn(), }; @@ -81,8 +77,8 @@ describe("promptCustomApiConfig", () => { "fetch", vi .fn() - .mockRejectedValueOnce(new Error("Network error")) - .mockRejectedValueOnce(new Error("Network error")), + .mockResolvedValueOnce({ ok: false, status: 400, json: async () => ({}) }) + .mockResolvedValueOnce({ ok: true, json: async () => ({}) }), ); await promptCustomApiConfig({ @@ -91,8 +87,8 @@ describe("promptCustomApiConfig", () => { config: {}, }); - expect(prompter.text).toHaveBeenCalledTimes(5); - expect(prompter.confirm).toHaveBeenCalled(); + expect(prompter.text).toHaveBeenCalledTimes(6); + expect(prompter.select).toHaveBeenCalledTimes(2); }); it("detects openai compatibility when unknown", async () => { @@ -115,16 +111,10 @@ describe("promptCustomApiConfig", () => { vi.stubGlobal( "fetch", - vi - .fn() - .mockResolvedValueOnce({ - ok: true, - json: async () => ({ data: [] }), - }) - .mockResolvedValueOnce({ - ok: true, - json: async () => ({}), - }), + vi.fn().mockResolvedValueOnce({ + ok: true, + json: async () => ({}), + }), ); const result = await promptCustomApiConfig({ @@ -147,14 +137,13 @@ describe("promptCustomApiConfig", () => { .mockResolvedValueOnce("bad-model") // Model ID #1 .mockResolvedValueOnce("https://ok.example.com/v1") // Base URL #2 .mockResolvedValueOnce("ok-key") // API Key #2 - .mockResolvedValueOnce("ok-model") // Model ID #2 .mockResolvedValueOnce("custom") // Endpoint ID .mockResolvedValueOnce(""), // Alias progress: vi.fn(() => ({ update: vi.fn(), stop: vi.fn(), })), - select: vi.fn().mockResolvedValueOnce("unknown"), + select: vi.fn().mockResolvedValueOnce("unknown").mockResolvedValueOnce("baseUrl"), confirm: vi.fn(), note: vi.fn(), }; @@ -163,16 +152,8 @@ describe("promptCustomApiConfig", () => { "fetch", vi .fn() - .mockResolvedValueOnce({ - ok: true, - json: async () => ({ data: [] }), - }) .mockResolvedValueOnce({ ok: false, status: 404, json: async () => ({}) }) .mockResolvedValueOnce({ ok: false, status: 404, json: async () => ({}) }) - .mockResolvedValueOnce({ - ok: true, - json: async () => ({ data: [] }), - }) .mockResolvedValueOnce({ ok: true, json: async () => ({}) }), ); @@ -194,8 +175,8 @@ describe("promptCustomApiConfig", () => { .fn() .mockResolvedValueOnce("http://localhost:11434/v1") // Base URL .mockResolvedValueOnce("") // API Key + .mockResolvedValueOnce("llama3") // Model ID .mockResolvedValueOnce("custom") // Endpoint ID - .mockResolvedValueOnce("llama3") // Manual model .mockResolvedValueOnce(""), // Alias progress: vi.fn(() => ({ update: vi.fn(), @@ -208,13 +189,10 @@ describe("promptCustomApiConfig", () => { vi.stubGlobal( "fetch", - vi - .fn() - .mockRejectedValueOnce(new Error("Discovery failed")) - .mockResolvedValueOnce({ - ok: true, - json: async () => ({}), - }), + vi.fn().mockResolvedValueOnce({ + ok: true, + json: async () => ({}), + }), ); const result = await promptCustomApiConfig({ @@ -248,21 +226,22 @@ describe("promptCustomApiConfig", () => { expect(result.config.models?.providers?.["custom-2"]).toBeDefined(); }); - it("aborts discovery after timeout", async () => { + it("aborts verification after timeout", async () => { vi.useFakeTimers(); const prompter = { text: vi .fn() .mockResolvedValueOnce("http://localhost:11434/v1") // Base URL .mockResolvedValueOnce("") // API Key + .mockResolvedValueOnce("slow-model") // Model ID + .mockResolvedValueOnce("fast-model") // Model ID retry .mockResolvedValueOnce("custom") // Endpoint ID - .mockResolvedValueOnce("manual-model-id") // Manual model .mockResolvedValueOnce(""), // Alias progress: vi.fn(() => ({ update: vi.fn(), stop: vi.fn(), })), - select: vi.fn().mockResolvedValueOnce("openai"), + select: vi.fn().mockResolvedValueOnce("openai").mockResolvedValueOnce("model"), confirm: vi.fn(), note: vi.fn(), }; @@ -283,9 +262,9 @@ describe("promptCustomApiConfig", () => { config: {}, }); - await vi.advanceTimersByTimeAsync(5000); + await vi.advanceTimersByTimeAsync(10000); await promise; - expect(prompter.text).toHaveBeenCalledTimes(5); + expect(prompter.text).toHaveBeenCalledTimes(6); }); }); diff --git a/src/commands/onboard-custom.ts b/src/commands/onboard-custom.ts index 20bd19190a..58becb842f 100644 --- a/src/commands/onboard-custom.ts +++ b/src/commands/onboard-custom.ts @@ -8,10 +8,8 @@ import { applyPrimaryModel } from "./model-picker.js"; import { normalizeAlias } from "./models/shared.js"; const DEFAULT_OPENAI_BASE_URL = "http://127.0.0.1:11434/v1"; -const DEFAULT_ANTHROPIC_BASE_URL = "https://api.anthropic.com/v1"; const DEFAULT_CONTEXT_WINDOW = 4096; const DEFAULT_MAX_TOKENS = 4096; -const DISCOVERY_TIMEOUT_MS = 5000; const VERIFY_TIMEOUT_MS = 10000; type CustomApiCompatibility = "openai" | "anthropic"; @@ -31,7 +29,7 @@ const COMPATIBILITY_OPTIONS: Array<{ { value: "openai", label: "OpenAI-compatible", - hint: "Uses /models + /chat/completions", + hint: "Uses /chat/completions", api: "openai-completions", }, { @@ -47,19 +45,6 @@ const COMPATIBILITY_OPTIONS: Array<{ }, ]; -function resolveBaseUrlDefaults(compatibility: CustomApiCompatibilityChoice) { - if (compatibility === "anthropic") { - return { - initialValue: DEFAULT_ANTHROPIC_BASE_URL, - placeholder: "https://api.anthropic.com/v1", - }; - } - return { - initialValue: DEFAULT_OPENAI_BASE_URL, - placeholder: "http://127.0.0.1:11434/v1", - }; -} - function normalizeEndpointId(raw: string): string { const trimmed = raw.trim().toLowerCase(); if (!trimmed) { @@ -162,19 +147,6 @@ async function fetchWithTimeout( } } -function parseOpenAiModels(data: { data?: { id: string }[]; models?: { id: string }[] }) { - const rawModels = data.data || data.models || []; - return rawModels.map((m: unknown) => { - if (typeof m === "string") { - return m; - } - if (typeof m === "object" && m !== null && "id" in m) { - return (m as { id: string }).id; - } - return String(m); - }); -} - function formatVerificationError(error: unknown): string { if (!error) { return "unknown error"; @@ -192,39 +164,6 @@ function formatVerificationError(error: unknown): string { } } -async function tryDiscoverOpenAiModels(params: { - baseUrl: string; - apiKey: string; - prompter: WizardPrompter; -}): Promise { - const { baseUrl, apiKey, prompter } = params; - const spinner = prompter.progress("Connecting..."); - spinner.update(`Scanning models at ${baseUrl}...`); - try { - const discoveryUrl = new URL("models", baseUrl.endsWith("/") ? baseUrl : `${baseUrl}/`).href; - const res = await fetchWithTimeout( - discoveryUrl, - { headers: buildOpenAiHeaders(apiKey) }, - DISCOVERY_TIMEOUT_MS, - ); - if (res.ok) { - const data = (await res.json()) as { data?: { id: string }[]; models?: { id: string }[] }; - const models = parseOpenAiModels(data); - if (models.length > 0) { - spinner.stop(`Found ${models.length} models.`); - return models; - } - spinner.stop("Connected, but no models list returned."); - return []; - } - spinner.stop(`Connection succeeded, but discovery failed (${res.status}).`); - return null; - } catch { - spinner.stop("Could not auto-detect models."); - return null; - } -} - type VerificationResult = { ok: boolean; status?: number; @@ -297,14 +236,12 @@ async function requestAnthropicVerification(params: { async function promptBaseUrlAndKey(params: { prompter: WizardPrompter; - compatibility: CustomApiCompatibilityChoice; initialBaseUrl?: string; }): Promise<{ baseUrl: string; apiKey: string }> { - const defaults = resolveBaseUrlDefaults(params.compatibility); const baseUrlInput = await params.prompter.text({ message: "API Base URL", - initialValue: params.initialBaseUrl ?? defaults.initialValue, - placeholder: defaults.placeholder, + initialValue: params.initialBaseUrl ?? DEFAULT_OPENAI_BASE_URL, + placeholder: "https://api.example.com/v1", validate: (val) => { try { new URL(val); @@ -315,7 +252,7 @@ async function promptBaseUrlAndKey(params: { }, }); const apiKeyInput = await params.prompter.text({ - message: "API Key (optional)", + message: "API Key (leave blank if not required)", placeholder: "sk-...", initialValue: "", }); @@ -329,6 +266,10 @@ export async function promptCustomApiConfig(params: { }): Promise { const { prompter, runtime, config } = params; + const baseInput = await promptBaseUrlAndKey({ prompter }); + let baseUrl = baseInput.baseUrl; + let apiKey = baseInput.apiKey; + const compatibilityChoice = await prompter.select({ message: "Endpoint compatibility", options: COMPATIBILITY_OPTIONS.map((option) => ({ @@ -337,44 +278,24 @@ export async function promptCustomApiConfig(params: { hint: option.hint, })), }); + + let modelId = ( + await prompter.text({ + message: "Model ID", + placeholder: "e.g. llama3, claude-3-7-sonnet", + validate: (val) => (val.trim() ? undefined : "Model ID is required"), + }) + ).trim(); + let compatibility: CustomApiCompatibility | null = compatibilityChoice === "unknown" ? null : compatibilityChoice; let providerApi = COMPATIBILITY_OPTIONS.find((entry) => entry.value === compatibility)?.api ?? "openai-completions"; - let baseUrl = ""; - let apiKey = ""; - let modelId: string | undefined; - let discoveredModels: string[] | null = null; - let verifiedFromProbe = false; - - if (compatibilityChoice === "unknown") { - let lastBaseUrl: string | undefined; - while (!compatibility) { - const baseInput = await promptBaseUrlAndKey({ - prompter, - compatibility: compatibilityChoice, - initialBaseUrl: lastBaseUrl, - }); - baseUrl = baseInput.baseUrl; - apiKey = baseInput.apiKey; - - const models = await tryDiscoverOpenAiModels({ baseUrl, apiKey, prompter }); - if (models && models.length > 0) { - compatibility = "openai"; - providerApi = "openai-completions"; - discoveredModels = models; - break; - } - - modelId = ( - await prompter.text({ - message: "Model ID", - placeholder: "e.g. llama3, claude-3-7-sonnet", - validate: (val) => (val.trim() ? undefined : "Model ID is required"), - }) - ).trim(); + while (true) { + let verifiedFromProbe = false; + if (!compatibility) { const probeSpinner = prompter.progress("Detecting endpoint type..."); const openaiProbe = await requestOpenAiVerification({ baseUrl, apiKey, modelId }); if (openaiProbe.ok) { @@ -382,41 +303,95 @@ export async function promptCustomApiConfig(params: { compatibility = "openai"; providerApi = "openai-completions"; verifiedFromProbe = true; - break; + } else { + const anthropicProbe = await requestAnthropicVerification({ baseUrl, apiKey, modelId }); + if (anthropicProbe.ok) { + probeSpinner.stop("Detected Anthropic-compatible endpoint."); + compatibility = "anthropic"; + providerApi = "anthropic-messages"; + verifiedFromProbe = true; + } else { + probeSpinner.stop("Could not detect endpoint type."); + await prompter.note( + "This endpoint did not respond to OpenAI or Anthropic style requests.", + "Endpoint detection", + ); + const retryChoice = await prompter.select({ + message: "What would you like to change?", + options: [ + { value: "baseUrl", label: "Change base URL" }, + { value: "model", label: "Change model" }, + { value: "both", label: "Change base URL and model" }, + ], + }); + if (retryChoice === "baseUrl" || retryChoice === "both") { + const retryInput = await promptBaseUrlAndKey({ + prompter, + initialBaseUrl: baseUrl, + }); + baseUrl = retryInput.baseUrl; + apiKey = retryInput.apiKey; + } + if (retryChoice === "model" || retryChoice === "both") { + modelId = ( + await prompter.text({ + message: "Model ID", + placeholder: "e.g. llama3, claude-3-7-sonnet", + validate: (val) => (val.trim() ? undefined : "Model ID is required"), + }) + ).trim(); + } + continue; + } } - - const anthropicProbe = await requestAnthropicVerification({ baseUrl, apiKey, modelId }); - if (anthropicProbe.ok) { - probeSpinner.stop("Detected Anthropic-compatible endpoint."); - compatibility = "anthropic"; - providerApi = "anthropic-messages"; - verifiedFromProbe = true; - break; - } - - probeSpinner.stop("Could not detect endpoint type."); - await prompter.note( - "This endpoint did not respond to OpenAI or Anthropic style requests. Enter a new base URL and try again.", - "Endpoint detection", - ); - lastBaseUrl = baseUrl; - modelId = undefined; } - } else { - const baseInput = await promptBaseUrlAndKey({ - prompter, - compatibility: compatibilityChoice, - }); - baseUrl = baseInput.baseUrl; - apiKey = baseInput.apiKey; - compatibility = compatibilityChoice; - providerApi = - COMPATIBILITY_OPTIONS.find((entry) => entry.value === compatibility)?.api ?? - "openai-completions"; - } - if (!compatibility) { - return { config }; + if (verifiedFromProbe) { + break; + } + + const verifySpinner = prompter.progress("Verifying..."); + const result = + compatibility === "anthropic" + ? await requestAnthropicVerification({ baseUrl, apiKey, modelId }) + : await requestOpenAiVerification({ baseUrl, apiKey, modelId }); + if (result.ok) { + verifySpinner.stop("Verification successful."); + break; + } + if (result.status !== undefined) { + verifySpinner.stop(`Verification failed: status ${result.status}`); + } else { + verifySpinner.stop(`Verification failed: ${formatVerificationError(result.error)}`); + } + const retryChoice = await prompter.select({ + message: "What would you like to change?", + options: [ + { value: "baseUrl", label: "Change base URL" }, + { value: "model", label: "Change model" }, + { value: "both", label: "Change base URL and model" }, + ], + }); + if (retryChoice === "baseUrl" || retryChoice === "both") { + const retryInput = await promptBaseUrlAndKey({ + prompter, + initialBaseUrl: baseUrl, + }); + baseUrl = retryInput.baseUrl; + apiKey = retryInput.apiKey; + } + if (retryChoice === "model" || retryChoice === "both") { + modelId = ( + await prompter.text({ + message: "Model ID", + placeholder: "e.g. llama3, claude-3-7-sonnet", + validate: (val) => (val.trim() ? undefined : "Model ID is required"), + }) + ).trim(); + } + if (compatibilityChoice === "unknown") { + compatibility = null; + } } const providers = config.models?.providers ?? {}; @@ -446,40 +421,6 @@ export async function promptCustomApiConfig(params: { } const providerId = providerIdResult.providerId; - if (compatibility === "openai" && !discoveredModels) { - discoveredModels = await tryDiscoverOpenAiModels({ baseUrl, apiKey, prompter }); - } - - if (!modelId) { - if (compatibility === "openai" && discoveredModels && discoveredModels.length > 0) { - const selection = await prompter.select({ - message: "Select a model", - options: [ - ...discoveredModels.map((id) => ({ value: id, label: id })), - { value: "__manual", label: "(Enter manually...)" }, - ], - }); - if (selection !== "__manual") { - modelId = selection; - } - } else if (compatibility === "anthropic") { - await prompter.note( - "Anthropic-compatible endpoints do not expose a standard models endpoint. Please enter a model ID manually.", - "Model discovery", - ); - } - } - - if (!modelId) { - modelId = ( - await prompter.text({ - message: "Model ID", - placeholder: "e.g. llama3, claude-3-7-sonnet", - validate: (val) => (val.trim() ? undefined : "Model ID is required"), - }) - ).trim(); - } - const modelRef = modelKey(providerId, modelId); const aliasInput = await prompter.text({ message: "Model alias (optional)", @@ -489,33 +430,6 @@ export async function promptCustomApiConfig(params: { }); const alias = aliasInput.trim(); - let verified = verifiedFromProbe; - if (!verified) { - const verifySpinner = prompter.progress("Verifying..."); - const result = - compatibility === "anthropic" - ? await requestAnthropicVerification({ baseUrl, apiKey, modelId }) - : await requestOpenAiVerification({ baseUrl, apiKey, modelId }); - if (result.ok) { - verified = true; - verifySpinner.stop("Verification successful."); - } else if (result.status !== undefined) { - verifySpinner.stop(`Verification failed: status ${result.status}`); - } else { - verifySpinner.stop(`Verification failed: ${formatVerificationError(result.error)}`); - } - } - - if (!verified) { - const confirm = await prompter.confirm({ - message: "Could not verify model connection. Save anyway?", - initialValue: true, - }); - if (!confirm) { - return { config }; - } - } - const existingProvider = providers[providerId]; const existingModels = Array.isArray(existingProvider?.models) ? existingProvider.models : []; const hasModel = existingModels.some((model) => model.id === modelId);