fix: streamline custom endpoint onboarding (#11106) (thanks @MackDing)

Gustavo Madeira Santana
2026-02-10 06:41:36 -05:00
parent ed20354abf
commit aad28f7841
3 changed files with 147 additions and 250 deletions

View File

@@ -42,6 +42,10 @@ export async function promptAuthChoiceGrouped(params: {
continue;
}
+ if (group.options.length === 1) {
+ return group.options[0].value;
+ }
const methodSelection = await params.prompter.select({
message: `${group.label} auth method`,
options: [...group.options, { value: BACK_VALUE, label: "Back" }],
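The hunk above adds an early return: when an auth group exposes exactly one method, the wizard picks it directly instead of rendering a one-item menu. A minimal standalone sketch of that behavior (the AuthGroup and Prompter shapes and the BACK_VALUE sentinel value are illustrative assumptions; only the identifiers come from the diff):

type AuthOption = { value: string; label: string };
type AuthGroup = { label: string; options: AuthOption[] };
type Prompter = {
  select(params: { message: string; options: AuthOption[] }): Promise<string>;
};
const BACK_VALUE = "__back"; // hypothetical sentinel; the diff only shows the identifier

async function pickAuthMethod(group: AuthGroup, prompter: Prompter): Promise<string> {
  // With exactly one option there is nothing to choose, so skip the sub-menu.
  if (group.options.length === 1) {
    return group.options[0].value;
  }
  return prompter.select({
    message: `${group.label} auth method`,
    options: [...group.options, { value: BACK_VALUE, label: "Back" }],
  });
}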

View File

@@ -13,38 +13,30 @@ describe("promptCustomApiConfig", () => {
vi.useRealTimers();
});
it("handles openai discovery and saves alias", async () => {
it("handles openai flow and saves alias", async () => {
const prompter = {
text: vi
.fn()
.mockResolvedValueOnce("http://localhost:11434/v1") // Base URL
.mockResolvedValueOnce("") // API Key
.mockResolvedValueOnce("llama3") // Model ID
.mockResolvedValueOnce("custom") // Endpoint ID
.mockResolvedValueOnce("local"), // Alias
progress: vi.fn(() => ({
update: vi.fn(),
stop: vi.fn(),
})),
- select: vi
- .fn()
- .mockResolvedValueOnce("openai") // Compatibility
- .mockResolvedValueOnce("llama3"), // Model selection
+ select: vi.fn().mockResolvedValueOnce("openai"), // Compatibility
confirm: vi.fn(),
note: vi.fn(),
};
vi.stubGlobal(
"fetch",
- vi
- .fn()
- .mockResolvedValueOnce({
- ok: true,
- json: async () => ({ data: [{ id: "llama3" }, { id: "mistral" }] }),
- })
- .mockResolvedValueOnce({
- ok: true,
- json: async () => ({}),
- }),
+ vi.fn().mockResolvedValueOnce({
+ ok: true,
+ json: async () => ({}),
+ }),
);
const result = await promptCustomApiConfig({
@@ -53,27 +45,31 @@ describe("promptCustomApiConfig", () => {
config: {},
});
- expect(prompter.text).toHaveBeenCalledTimes(4);
- expect(prompter.select).toHaveBeenCalledTimes(2);
+ expect(prompter.text).toHaveBeenCalledTimes(5);
+ expect(prompter.select).toHaveBeenCalledTimes(1);
expect(result.config.models?.providers?.custom?.api).toBe("openai-completions");
expect(result.config.agents?.defaults?.models?.["custom/llama3"]?.alias).toBe("local");
});
it("falls back to manual entry when discovery fails", async () => {
it("retries when verification fails", async () => {
const prompter = {
text: vi
.fn()
.mockResolvedValueOnce("http://localhost:11434/v1") // Base URL
.mockResolvedValueOnce("") // API Key
.mockResolvedValueOnce("bad-model") // Model ID
.mockResolvedValueOnce("good-model") // Model ID retry
.mockResolvedValueOnce("custom") // Endpoint ID
.mockResolvedValueOnce("manual-model-id") // Manual model
.mockResolvedValueOnce(""), // Alias
progress: vi.fn(() => ({
update: vi.fn(),
stop: vi.fn(),
})),
select: vi.fn().mockResolvedValueOnce("openai"), // Compatibility only
confirm: vi.fn().mockResolvedValue(true),
select: vi
.fn()
.mockResolvedValueOnce("openai") // Compatibility
.mockResolvedValueOnce("model"), // Retry choice
confirm: vi.fn(),
note: vi.fn(),
};
@@ -81,8 +77,8 @@ describe("promptCustomApiConfig", () => {
"fetch",
vi
.fn()
.mockRejectedValueOnce(new Error("Network error"))
.mockRejectedValueOnce(new Error("Network error")),
.mockResolvedValueOnce({ ok: false, status: 400, json: async () => ({}) })
.mockResolvedValueOnce({ ok: true, json: async () => ({}) }),
);
await promptCustomApiConfig({
@@ -91,8 +87,8 @@ describe("promptCustomApiConfig", () => {
config: {},
});
- expect(prompter.text).toHaveBeenCalledTimes(5);
- expect(prompter.confirm).toHaveBeenCalled();
+ expect(prompter.text).toHaveBeenCalledTimes(6);
+ expect(prompter.select).toHaveBeenCalledTimes(2);
});
it("detects openai compatibility when unknown", async () => {
@@ -115,16 +111,10 @@ describe("promptCustomApiConfig", () => {
vi.stubGlobal(
"fetch",
- vi
- .fn()
- .mockResolvedValueOnce({
- ok: true,
- json: async () => ({ data: [] }),
- })
- .mockResolvedValueOnce({
- ok: true,
- json: async () => ({}),
- }),
+ vi.fn().mockResolvedValueOnce({
+ ok: true,
+ json: async () => ({}),
+ }),
);
const result = await promptCustomApiConfig({
@@ -147,14 +137,13 @@ describe("promptCustomApiConfig", () => {
.mockResolvedValueOnce("bad-model") // Model ID #1
.mockResolvedValueOnce("https://ok.example.com/v1") // Base URL #2
.mockResolvedValueOnce("ok-key") // API Key #2
.mockResolvedValueOnce("ok-model") // Model ID #2
.mockResolvedValueOnce("custom") // Endpoint ID
.mockResolvedValueOnce(""), // Alias
progress: vi.fn(() => ({
update: vi.fn(),
stop: vi.fn(),
})),
select: vi.fn().mockResolvedValueOnce("unknown"),
select: vi.fn().mockResolvedValueOnce("unknown").mockResolvedValueOnce("baseUrl"),
confirm: vi.fn(),
note: vi.fn(),
};
@@ -163,16 +152,8 @@ describe("promptCustomApiConfig", () => {
"fetch",
vi
.fn()
- .mockResolvedValueOnce({
- ok: true,
- json: async () => ({ data: [] }),
- })
.mockResolvedValueOnce({ ok: false, status: 404, json: async () => ({}) })
.mockResolvedValueOnce({ ok: false, status: 404, json: async () => ({}) })
- .mockResolvedValueOnce({
- ok: true,
- json: async () => ({ data: [] }),
- })
.mockResolvedValueOnce({ ok: true, json: async () => ({}) }),
);
@@ -194,8 +175,8 @@ describe("promptCustomApiConfig", () => {
.fn()
.mockResolvedValueOnce("http://localhost:11434/v1") // Base URL
.mockResolvedValueOnce("") // API Key
.mockResolvedValueOnce("llama3") // Model ID
.mockResolvedValueOnce("custom") // Endpoint ID
.mockResolvedValueOnce("llama3") // Manual model
.mockResolvedValueOnce(""), // Alias
progress: vi.fn(() => ({
update: vi.fn(),
@@ -208,13 +189,10 @@ describe("promptCustomApiConfig", () => {
vi.stubGlobal(
"fetch",
- vi
- .fn()
- .mockRejectedValueOnce(new Error("Discovery failed"))
- .mockResolvedValueOnce({
- ok: true,
- json: async () => ({}),
- }),
+ vi.fn().mockResolvedValueOnce({
+ ok: true,
+ json: async () => ({}),
+ }),
);
const result = await promptCustomApiConfig({
@@ -248,21 +226,22 @@ describe("promptCustomApiConfig", () => {
expect(result.config.models?.providers?.["custom-2"]).toBeDefined();
});
it("aborts discovery after timeout", async () => {
it("aborts verification after timeout", async () => {
vi.useFakeTimers();
const prompter = {
text: vi
.fn()
.mockResolvedValueOnce("http://localhost:11434/v1") // Base URL
.mockResolvedValueOnce("") // API Key
.mockResolvedValueOnce("slow-model") // Model ID
.mockResolvedValueOnce("fast-model") // Model ID retry
.mockResolvedValueOnce("custom") // Endpoint ID
.mockResolvedValueOnce("manual-model-id") // Manual model
.mockResolvedValueOnce(""), // Alias
progress: vi.fn(() => ({
update: vi.fn(),
stop: vi.fn(),
})),
select: vi.fn().mockResolvedValueOnce("openai"),
select: vi.fn().mockResolvedValueOnce("openai").mockResolvedValueOnce("model"),
confirm: vi.fn(),
note: vi.fn(),
};
@@ -283,9 +262,9 @@ describe("promptCustomApiConfig", () => {
config: {},
});
- await vi.advanceTimersByTimeAsync(5000);
+ await vi.advanceTimersByTimeAsync(10000);
await promise;
- expect(prompter.text).toHaveBeenCalledTimes(5);
+ expect(prompter.text).toHaveBeenCalledTimes(6);
});
});
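The suite above repeats one stubbing pattern: queue answers for each prompt, stub the global fetch, then assert call counts. A condensed restatement of that harness (the factory below is not part of the diff, and the prompter interface is inferred from the mocks):

import { vi } from "vitest";

function makePrompter(textAnswers: string[], selectAnswers: string[]) {
  const text = vi.fn();
  for (const answer of textAnswers) text.mockResolvedValueOnce(answer);
  const select = vi.fn();
  for (const answer of selectAnswers) select.mockResolvedValueOnce(answer);
  return {
    text,
    select,
    confirm: vi.fn(),
    note: vi.fn(),
    progress: vi.fn(() => ({ update: vi.fn(), stop: vi.fn() })),
  };
}

// Happy path: five text prompts (base URL, key, model, endpoint ID, alias), one select.
const prompter = makePrompter(
  ["http://localhost:11434/v1", "", "llama3", "custom", "local"],
  ["openai"],
);
// One successful response now satisfies verification, since model discovery is gone.
vi.stubGlobal("fetch", vi.fn().mockResolvedValueOnce({ ok: true, json: async () => ({}) }));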

View File

@@ -8,10 +8,8 @@ import { applyPrimaryModel } from "./model-picker.js";
import { normalizeAlias } from "./models/shared.js";
const DEFAULT_OPENAI_BASE_URL = "http://127.0.0.1:11434/v1";
- const DEFAULT_ANTHROPIC_BASE_URL = "https://api.anthropic.com/v1";
const DEFAULT_CONTEXT_WINDOW = 4096;
const DEFAULT_MAX_TOKENS = 4096;
- const DISCOVERY_TIMEOUT_MS = 5000;
const VERIFY_TIMEOUT_MS = 10000;
type CustomApiCompatibility = "openai" | "anthropic";
@@ -31,7 +29,7 @@ const COMPATIBILITY_OPTIONS: Array<{
{
value: "openai",
label: "OpenAI-compatible",
hint: "Uses /models + /chat/completions",
hint: "Uses /chat/completions",
api: "openai-completions",
},
{
@@ -47,19 +45,6 @@ const COMPATIBILITY_OPTIONS: Array<{
},
];
- function resolveBaseUrlDefaults(compatibility: CustomApiCompatibilityChoice) {
- if (compatibility === "anthropic") {
- return {
- initialValue: DEFAULT_ANTHROPIC_BASE_URL,
- placeholder: "https://api.anthropic.com/v1",
- };
- }
- return {
- initialValue: DEFAULT_OPENAI_BASE_URL,
- placeholder: "http://127.0.0.1:11434/v1",
- };
- }
function normalizeEndpointId(raw: string): string {
const trimmed = raw.trim().toLowerCase();
if (!trimmed) {
@@ -162,19 +147,6 @@ async function fetchWithTimeout(
}
}
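fetchWithTimeout is invoked in this file with a URL, a request init, and a millisecond budget, but its body sits outside the visible hunks. A plausible implementation under that signature, assuming the standard AbortController pattern rather than quoting the file's actual code:

async function fetchWithTimeout(
  url: string,
  init: RequestInit,
  timeoutMs: number,
): Promise<Response> {
  const controller = new AbortController();
  const timer = setTimeout(() => controller.abort(), timeoutMs);
  try {
    // The abort signal turns a hung endpoint into a rejected promise after timeoutMs.
    return await fetch(url, { ...init, signal: controller.signal });
  } finally {
    clearTimeout(timer); // clear whether the request resolved, failed, or aborted
  }
}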
- function parseOpenAiModels(data: { data?: { id: string }[]; models?: { id: string }[] }) {
- const rawModels = data.data || data.models || [];
- return rawModels.map((m: unknown) => {
- if (typeof m === "string") {
- return m;
- }
- if (typeof m === "object" && m !== null && "id" in m) {
- return (m as { id: string }).id;
- }
- return String(m);
- });
- }
function formatVerificationError(error: unknown): string {
if (!error) {
return "unknown error";
@@ -192,39 +164,6 @@ function formatVerificationError(error: unknown): string {
}
}
- async function tryDiscoverOpenAiModels(params: {
- baseUrl: string;
- apiKey: string;
- prompter: WizardPrompter;
- }): Promise<string[] | null> {
- const { baseUrl, apiKey, prompter } = params;
- const spinner = prompter.progress("Connecting...");
- spinner.update(`Scanning models at ${baseUrl}...`);
- try {
- const discoveryUrl = new URL("models", baseUrl.endsWith("/") ? baseUrl : `${baseUrl}/`).href;
- const res = await fetchWithTimeout(
- discoveryUrl,
- { headers: buildOpenAiHeaders(apiKey) },
- DISCOVERY_TIMEOUT_MS,
- );
- if (res.ok) {
- const data = (await res.json()) as { data?: { id: string }[]; models?: { id: string }[] };
- const models = parseOpenAiModels(data);
- if (models.length > 0) {
- spinner.stop(`Found ${models.length} models.`);
- return models;
- }
- spinner.stop("Connected, but no models list returned.");
- return [];
- }
- spinner.stop(`Connection succeeded, but discovery failed (${res.status}).`);
- return null;
- } catch {
- spinner.stop("Could not auto-detect models.");
- return null;
- }
- }
type VerificationResult = {
ok: boolean;
status?: number;
@@ -297,14 +236,12 @@ async function requestAnthropicVerification(params: {
async function promptBaseUrlAndKey(params: {
prompter: WizardPrompter;
- compatibility: CustomApiCompatibilityChoice;
initialBaseUrl?: string;
}): Promise<{ baseUrl: string; apiKey: string }> {
- const defaults = resolveBaseUrlDefaults(params.compatibility);
const baseUrlInput = await params.prompter.text({
message: "API Base URL",
- initialValue: params.initialBaseUrl ?? defaults.initialValue,
- placeholder: defaults.placeholder,
+ initialValue: params.initialBaseUrl ?? DEFAULT_OPENAI_BASE_URL,
+ placeholder: "https://api.example.com/v1",
validate: (val) => {
try {
new URL(val);
@@ -315,7 +252,7 @@ async function promptBaseUrlAndKey(params: {
},
});
const apiKeyInput = await params.prompter.text({
message: "API Key (optional)",
message: "API Key (leave blank if not required)",
placeholder: "sk-...",
initialValue: "",
});
@@ -329,6 +266,10 @@ export async function promptCustomApiConfig(params: {
}): Promise<CustomApiResult> {
const { prompter, runtime, config } = params;
+ const baseInput = await promptBaseUrlAndKey({ prompter });
+ let baseUrl = baseInput.baseUrl;
+ let apiKey = baseInput.apiKey;
const compatibilityChoice = await prompter.select({
message: "Endpoint compatibility",
options: COMPATIBILITY_OPTIONS.map((option) => ({
@@ -337,44 +278,24 @@ export async function promptCustomApiConfig(params: {
hint: option.hint,
})),
});
+ let modelId = (
+ await prompter.text({
+ message: "Model ID",
+ placeholder: "e.g. llama3, claude-3-7-sonnet",
+ validate: (val) => (val.trim() ? undefined : "Model ID is required"),
+ })
+ ).trim();
let compatibility: CustomApiCompatibility | null =
compatibilityChoice === "unknown" ? null : compatibilityChoice;
let providerApi =
COMPATIBILITY_OPTIONS.find((entry) => entry.value === compatibility)?.api ??
"openai-completions";
let baseUrl = "";
let apiKey = "";
let modelId: string | undefined;
let discoveredModels: string[] | null = null;
let verifiedFromProbe = false;
if (compatibilityChoice === "unknown") {
let lastBaseUrl: string | undefined;
while (!compatibility) {
const baseInput = await promptBaseUrlAndKey({
prompter,
compatibility: compatibilityChoice,
initialBaseUrl: lastBaseUrl,
});
baseUrl = baseInput.baseUrl;
apiKey = baseInput.apiKey;
const models = await tryDiscoverOpenAiModels({ baseUrl, apiKey, prompter });
if (models && models.length > 0) {
compatibility = "openai";
providerApi = "openai-completions";
discoveredModels = models;
break;
}
modelId = (
await prompter.text({
message: "Model ID",
placeholder: "e.g. llama3, claude-3-7-sonnet",
validate: (val) => (val.trim() ? undefined : "Model ID is required"),
})
).trim();
+ while (true) {
+ let verifiedFromProbe = false;
+ if (!compatibility) {
const probeSpinner = prompter.progress("Detecting endpoint type...");
const openaiProbe = await requestOpenAiVerification({ baseUrl, apiKey, modelId });
if (openaiProbe.ok) {
@@ -382,41 +303,95 @@ export async function promptCustomApiConfig(params: {
compatibility = "openai";
providerApi = "openai-completions";
verifiedFromProbe = true;
- break;
+ } else {
+ const anthropicProbe = await requestAnthropicVerification({ baseUrl, apiKey, modelId });
+ if (anthropicProbe.ok) {
+ probeSpinner.stop("Detected Anthropic-compatible endpoint.");
+ compatibility = "anthropic";
+ providerApi = "anthropic-messages";
+ verifiedFromProbe = true;
+ } else {
+ probeSpinner.stop("Could not detect endpoint type.");
+ await prompter.note(
+ "This endpoint did not respond to OpenAI or Anthropic style requests.",
+ "Endpoint detection",
+ );
+ const retryChoice = await prompter.select({
+ message: "What would you like to change?",
+ options: [
+ { value: "baseUrl", label: "Change base URL" },
+ { value: "model", label: "Change model" },
+ { value: "both", label: "Change base URL and model" },
+ ],
+ });
+ if (retryChoice === "baseUrl" || retryChoice === "both") {
+ const retryInput = await promptBaseUrlAndKey({
+ prompter,
+ initialBaseUrl: baseUrl,
+ });
+ baseUrl = retryInput.baseUrl;
+ apiKey = retryInput.apiKey;
+ }
+ if (retryChoice === "model" || retryChoice === "both") {
+ modelId = (
+ await prompter.text({
+ message: "Model ID",
+ placeholder: "e.g. llama3, claude-3-7-sonnet",
+ validate: (val) => (val.trim() ? undefined : "Model ID is required"),
+ })
+ ).trim();
+ }
+ continue;
+ }
+ }
- const anthropicProbe = await requestAnthropicVerification({ baseUrl, apiKey, modelId });
- if (anthropicProbe.ok) {
- probeSpinner.stop("Detected Anthropic-compatible endpoint.");
- compatibility = "anthropic";
- providerApi = "anthropic-messages";
- verifiedFromProbe = true;
- break;
- }
- probeSpinner.stop("Could not detect endpoint type.");
- await prompter.note(
- "This endpoint did not respond to OpenAI or Anthropic style requests. Enter a new base URL and try again.",
- "Endpoint detection",
- );
- lastBaseUrl = baseUrl;
- modelId = undefined;
- }
- } else {
- const baseInput = await promptBaseUrlAndKey({
- prompter,
- compatibility: compatibilityChoice,
- });
- baseUrl = baseInput.baseUrl;
- apiKey = baseInput.apiKey;
- compatibility = compatibilityChoice;
- providerApi =
- COMPATIBILITY_OPTIONS.find((entry) => entry.value === compatibility)?.api ??
- "openai-completions";
- }
- if (!compatibility) {
- return { config };
+ if (verifiedFromProbe) {
+ break;
}
+ const verifySpinner = prompter.progress("Verifying...");
+ const result =
+ compatibility === "anthropic"
+ ? await requestAnthropicVerification({ baseUrl, apiKey, modelId })
+ : await requestOpenAiVerification({ baseUrl, apiKey, modelId });
+ if (result.ok) {
+ verifySpinner.stop("Verification successful.");
+ break;
+ }
+ if (result.status !== undefined) {
+ verifySpinner.stop(`Verification failed: status ${result.status}`);
+ } else {
+ verifySpinner.stop(`Verification failed: ${formatVerificationError(result.error)}`);
+ }
+ const retryChoice = await prompter.select({
+ message: "What would you like to change?",
+ options: [
+ { value: "baseUrl", label: "Change base URL" },
+ { value: "model", label: "Change model" },
+ { value: "both", label: "Change base URL and model" },
+ ],
+ });
+ if (retryChoice === "baseUrl" || retryChoice === "both") {
+ const retryInput = await promptBaseUrlAndKey({
+ prompter,
+ initialBaseUrl: baseUrl,
+ });
+ baseUrl = retryInput.baseUrl;
+ apiKey = retryInput.apiKey;
+ }
+ if (retryChoice === "model" || retryChoice === "both") {
+ modelId = (
+ await prompter.text({
+ message: "Model ID",
+ placeholder: "e.g. llama3, claude-3-7-sonnet",
+ validate: (val) => (val.trim() ? undefined : "Model ID is required"),
+ })
+ ).trim();
+ }
+ if (compatibilityChoice === "unknown") {
+ compatibility = null;
+ }
+ }
const providers = config.models?.providers ?? {};
@@ -446,40 +421,6 @@ export async function promptCustomApiConfig(params: {
}
const providerId = providerIdResult.providerId;
if (compatibility === "openai" && !discoveredModels) {
discoveredModels = await tryDiscoverOpenAiModels({ baseUrl, apiKey, prompter });
}
if (!modelId) {
if (compatibility === "openai" && discoveredModels && discoveredModels.length > 0) {
const selection = await prompter.select({
message: "Select a model",
options: [
...discoveredModels.map((id) => ({ value: id, label: id })),
{ value: "__manual", label: "(Enter manually...)" },
],
});
if (selection !== "__manual") {
modelId = selection;
}
} else if (compatibility === "anthropic") {
await prompter.note(
"Anthropic-compatible endpoints do not expose a standard models endpoint. Please enter a model ID manually.",
"Model discovery",
);
}
}
if (!modelId) {
modelId = (
await prompter.text({
message: "Model ID",
placeholder: "e.g. llama3, claude-3-7-sonnet",
validate: (val) => (val.trim() ? undefined : "Model ID is required"),
})
).trim();
}
const modelRef = modelKey(providerId, modelId);
const aliasInput = await prompter.text({
message: "Model alias (optional)",
@@ -489,33 +430,6 @@ export async function promptCustomApiConfig(params: {
});
const alias = aliasInput.trim();
- let verified = verifiedFromProbe;
- if (!verified) {
- const verifySpinner = prompter.progress("Verifying...");
- const result =
- compatibility === "anthropic"
- ? await requestAnthropicVerification({ baseUrl, apiKey, modelId })
- : await requestOpenAiVerification({ baseUrl, apiKey, modelId });
- if (result.ok) {
- verified = true;
- verifySpinner.stop("Verification successful.");
- } else if (result.status !== undefined) {
- verifySpinner.stop(`Verification failed: status ${result.status}`);
- } else {
- verifySpinner.stop(`Verification failed: ${formatVerificationError(result.error)}`);
- }
- }
- if (!verified) {
- const confirm = await prompter.confirm({
- message: "Could not verify model connection. Save anyway?",
- initialValue: true,
- });
- if (!confirm) {
- return { config };
- }
- }
const existingProvider = providers[providerId];
const existingModels = Array.isArray(existingProvider?.models) ? existingProvider.models : [];
const hasModel = existingModels.some((model) => model.id === modelId);
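Read end to end, the new promptCustomApiConfig flow reduces to a single verify-and-retry loop: probe unknown endpoints as OpenAI then Anthropic, verify known ones once per attempt, and on failure let the user change the base URL and/or model before trying again. A distilled sketch of that control flow (helper shapes abbreviated; a restatement, not the literal file contents):

type Compat = "openai" | "anthropic" | null;

async function verifyLoop(
  deps: {
    probe(kind: "openai" | "anthropic"): Promise<{ ok: boolean }>;
    askRetry(): Promise<"baseUrl" | "model" | "both">;
    reprompt(choice: "baseUrl" | "model" | "both"): Promise<void>;
  },
  initial: Compat,
  choseUnknown: boolean,
): Promise<"openai" | "anthropic"> {
  let compatibility = initial;
  while (true) {
    let verifiedFromProbe = false;
    if (!compatibility) {
      // Unknown endpoints: probe OpenAI first, then Anthropic, as in the diff.
      if ((await deps.probe("openai")).ok) {
        compatibility = "openai";
        verifiedFromProbe = true;
      } else if ((await deps.probe("anthropic")).ok) {
        compatibility = "anthropic";
        verifiedFromProbe = true;
      } else {
        // Neither probe succeeded: change base URL and/or model, then re-detect.
        await deps.reprompt(await deps.askRetry());
        continue;
      }
    }
    if (verifiedFromProbe) return compatibility;
    // A user-chosen compatibility still gets one verification request per attempt.
    if ((await deps.probe(compatibility)).ok) return compatibility;
    await deps.reprompt(await deps.askRetry());
    if (choseUnknown) compatibility = null; // fall back to auto-detection next pass
  }
}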