feat(providers): add Fireworks AI provider integration (#3873)

* feat(providers): add Fireworks AI provider integration

* fix(providers): remove unused logger and dead modelInfo from fireworks

* lint

* feat(providers): add Fireworks BYOK support and official icon

* fix(providers): add workspace membership check and remove shared fetch cache for fireworks models
Authored by Waleed on 2026-03-31 19:22:04 -07:00; committed by GitHub
parent b95a0491a0
commit e39c534ee3
21 changed files with 892 additions and 14 deletions

View File

@@ -28,6 +28,7 @@ API_ENCRYPTION_KEY=your_api_encryption_key # Use `openssl rand -hex 32` to gener
# OLLAMA_URL=http://localhost:11434 # URL for local Ollama server - uncomment if using local models
# VLLM_BASE_URL=http://localhost:8000 # Base URL for your self-hosted vLLM (OpenAI-compatible)
# VLLM_API_KEY= # Optional bearer token if your vLLM instance requires auth
# FIREWORKS_API_KEY= # Optional Fireworks AI API key for model listing
# Admin API (Optional - for self-hosted GitOps)
# ADMIN_API_KEY= # Use `openssl rand -hex 32` to generate. Enables admin API for workflow export/import.

View File

@@ -0,0 +1,93 @@
import { createLogger } from '@sim/logger'
import { type NextRequest, NextResponse } from 'next/server'
import { getBYOKKey } from '@/lib/api-key/byok'
import { getSession } from '@/lib/auth'
import { env } from '@/lib/core/config/env'
import { getUserEntityPermissions } from '@/lib/workspaces/permissions/utils'
import { filterBlacklistedModels, isProviderBlacklisted } from '@/providers/utils'
const logger = createLogger('FireworksModelsAPI')
interface FireworksModel {
id: string
object?: string
created?: number
owned_by?: string
}
interface FireworksModelsResponse {
data: FireworksModel[]
object?: string
}
export async function GET(request: NextRequest) {
if (isProviderBlacklisted('fireworks')) {
logger.info('Fireworks provider is blacklisted, returning empty models')
return NextResponse.json({ models: [] })
}
let apiKey: string | undefined
const workspaceId = request.nextUrl.searchParams.get('workspaceId')
if (workspaceId) {
const session = await getSession()
if (session?.user?.id) {
const permission = await getUserEntityPermissions(session.user.id, 'workspace', workspaceId)
if (permission) {
const byokResult = await getBYOKKey(workspaceId, 'fireworks')
if (byokResult) {
apiKey = byokResult.apiKey
}
}
}
}
if (!apiKey) {
apiKey = env.FIREWORKS_API_KEY
}
if (!apiKey) {
logger.info('No Fireworks API key available, returning empty models')
return NextResponse.json({ models: [] })
}
try {
const response = await fetch('https://api.fireworks.ai/inference/v1/models', {
headers: {
Authorization: `Bearer ${apiKey}`,
'Content-Type': 'application/json',
},
cache: 'no-store',
})
if (!response.ok) {
logger.warn('Failed to fetch Fireworks models', {
status: response.status,
statusText: response.statusText,
})
return NextResponse.json({ models: [] })
}
const data = (await response.json()) as FireworksModelsResponse
const allModels: string[] = []
for (const model of data.data ?? []) {
allModels.push(`fireworks/${model.id}`)
}
const uniqueModels = Array.from(new Set(allModels))
const models = filterBlacklistedModels(uniqueModels)
logger.info('Successfully fetched Fireworks models', {
count: models.length,
filtered: uniqueModels.length - models.length,
})
return NextResponse.json({ models })
} catch (error) {
logger.error('Error fetching Fireworks models', {
error: error instanceof Error ? error.message : 'Unknown error',
})
return NextResponse.json({ models: [] })
}
}
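
A minimal sketch of calling this route from the client (the workspace ID is illustrative):

const res = await fetch('/api/providers/fireworks/models?workspaceId=ws_123')
const { models } = (await res.json()) as { models: string[] }
// models is the de-duplicated, blacklist-filtered list of 'fireworks/<id>' strings;
// the route answers { models: [] } rather than an error when no key is available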

View File

@@ -18,6 +18,7 @@ const VALID_PROVIDERS = [
'anthropic',
'google',
'mistral',
'fireworks',
'firecrawl',
'exa',
'serper',

View File

@@ -2,8 +2,10 @@
import { useEffect } from 'react'
import { createLogger } from '@sim/logger'
import { useParams } from 'next/navigation'
import { useProviderModels } from '@/hooks/queries/providers'
import {
updateFireworksProviderModels,
updateOllamaProviderModels,
updateOpenRouterProviderModels,
updateVLLMProviderModels,
@@ -12,11 +14,11 @@ import { type ProviderName, useProvidersStore } from '@/stores/providers'
const logger = createLogger('ProviderModelsLoader')
function useSyncProvider(provider: ProviderName) {
function useSyncProvider(provider: ProviderName, workspaceId?: string) {
const setProviderModels = useProvidersStore((state) => state.setProviderModels)
const setProviderLoading = useProvidersStore((state) => state.setProviderLoading)
const setOpenRouterModelInfo = useProvidersStore((state) => state.setOpenRouterModelInfo)
const { data, isLoading, isFetching, error } = useProviderModels(provider)
const { data, isLoading, isFetching, error } = useProviderModels(provider, workspaceId)
useEffect(() => {
setProviderLoading(provider, isLoading || isFetching)
@@ -35,6 +37,8 @@ function useSyncProvider(provider: ProviderName) {
if (data.modelInfo) {
setOpenRouterModelInfo(data.modelInfo)
}
} else if (provider === 'fireworks') {
void updateFireworksProviderModels(data.models)
}
} catch (syncError) {
logger.warn(`Failed to sync provider definitions for ${provider}`, syncError as Error)
@@ -51,9 +55,13 @@ function useSyncProvider(provider: ProviderName) {
}
export function ProviderModelsLoader() {
const params = useParams()
const workspaceId = params?.workspaceId as string | undefined
useSyncProvider('base')
useSyncProvider('ollama')
useSyncProvider('vllm')
useSyncProvider('openrouter')
useSyncProvider('fireworks', workspaceId)
return null
}

View File

@@ -18,6 +18,7 @@ import {
BrandfetchIcon,
ExaAIIcon,
FirecrawlIcon,
FireworksIcon,
GeminiIcon,
GoogleIcon,
JinaAIIcon,
@@ -75,6 +76,13 @@ const PROVIDERS: {
description: 'LLM calls and Knowledge Base OCR',
placeholder: 'Enter your API key',
},
{
id: 'fireworks',
name: 'Fireworks',
icon: FireworksIcon,
description: 'LLM calls',
placeholder: 'Enter your Fireworks API key',
},
{
id: 'firecrawl',
name: 'Firecrawl',

View File

@@ -21,8 +21,15 @@ export function getModelOptions() {
const ollamaModels = providersState.providers.ollama.models
const vllmModels = providersState.providers.vllm.models
const openrouterModels = providersState.providers.openrouter.models
const fireworksModels = providersState.providers.fireworks.models
const allModels = Array.from(
new Set([...baseModels, ...ollamaModels, ...vllmModels, ...openrouterModels])
new Set([
...baseModels,
...ollamaModels,
...vllmModels,
...openrouterModels,
...fireworksModels,
])
)
return allModels.map((model) => {

View File

@@ -3483,6 +3483,25 @@ export function MySQLIcon(props: SVGProps<SVGSVGElement>) {
)
}
export function FireworksIcon(props: SVGProps<SVGSVGElement>) {
return (
<svg
{...props}
viewBox='0 0 512 512'
xmlns='http://www.w3.org/2000/svg'
fillRule='evenodd'
clipRule='evenodd'
strokeLinejoin='round'
strokeMiterlimit={2}
>
<path
d='M314.333 110.167L255.98 251.729l-58.416-141.562h-37.459l64 154.75c5.23 12.854 17.771 21.312 31.646 21.312s26.417-8.437 31.646-21.27l64.396-154.792h-37.459zm24.917 215.666L446 216.583l-14.562-34.77-116.584 119.562c-9.708 9.958-12.541 24.833-7.146 37.646 5.292 12.73 17.792 21.083 31.584 21.083l.042.063L506 359.75l-14.562-34.77-152.146.853h-.042zM66 216.5l14.563-34.77 116.583 119.562a34.592 34.592 0 017.146 37.646C199 351.667 186.5 360.02 172.708 360.02l-166.666-.375-.042.042 14.563-34.771 152.145.875L66 216.5z'
fill='currentColor'
/>
</svg>
)
}
export function OpenRouterIcon(props: SVGProps<SVGSVGElement>) {
return (
<svg

View File

@@ -9,6 +9,7 @@ const providerEndpoints: Record<ProviderName, string> = {
ollama: '/api/providers/ollama/models',
vllm: '/api/providers/vllm/models',
openrouter: '/api/providers/openrouter/models',
fireworks: '/api/providers/fireworks/models',
}
interface ProviderModelsResponse {
@@ -18,14 +19,21 @@ interface ProviderModelsResponse {
export const providerKeys = {
all: ['provider-models'] as const,
models: (provider: string) => [...providerKeys.all, provider] as const,
models: (provider: string, workspaceId?: string) =>
[...providerKeys.all, provider, workspaceId ?? ''] as const,
}
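// e.g. providerKeys.models('fireworks', 'ws_123') resolves to
// ['provider-models', 'fireworks', 'ws_123'] (workspace ID illustrative),
// so models fetched under different workspace BYOK keys are cached separately.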
async function fetchProviderModels(
provider: ProviderName,
signal?: AbortSignal
signal?: AbortSignal,
workspaceId?: string
): Promise<ProviderModelsResponse> {
const response = await fetch(providerEndpoints[provider], { signal })
let url = providerEndpoints[provider]
if (provider === 'fireworks' && workspaceId) {
url = `${url}?workspaceId=${encodeURIComponent(workspaceId)}`
}
const response = await fetch(url, { signal })
if (!response.ok) {
logger.warn(`Failed to fetch ${provider} models`, {
@@ -45,10 +53,10 @@ async function fetchProviderModels(
}
}
export function useProviderModels(provider: ProviderName) {
export function useProviderModels(provider: ProviderName, workspaceId?: string) {
return useQuery({
queryKey: providerKeys.models(provider),
queryFn: ({ signal }) => fetchProviderModels(provider, signal),
queryKey: providerKeys.models(provider, workspaceId),
queryFn: ({ signal }) => fetchProviderModels(provider, signal, workspaceId),
staleTime: 5 * 60 * 1000,
})
}
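
A hedged usage sketch (component context and variable names illustrative):

const { data, isLoading } = useProviderModels('fireworks', workspaceId)
// data?.models is a string[] of 'fireworks/...' IDs once the request resolves;
// when workspaceId is omitted, the route falls back to the server-level FIREWORKS_API_KEY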

View File

@@ -73,6 +73,26 @@ export async function getApiKeyWithBYOK(
return { apiKey: userProvidedKey || env.VLLM_API_KEY || 'empty', isBYOK: false }
}
const isFireworksModel =
provider === 'fireworks' ||
useProvidersStore.getState().providers.fireworks.models.includes(model)
if (isFireworksModel) {
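// Resolution order: workspace BYOK key first, then a caller-supplied key,
// then the server-level FIREWORKS_API_KEY; with none available, fail fast.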
if (workspaceId) {
const byokResult = await getBYOKKey(workspaceId, 'fireworks')
if (byokResult) {
logger.info('Using BYOK key for Fireworks', { model, workspaceId })
return byokResult
}
}
if (userProvidedKey) {
return { apiKey: userProvidedKey, isBYOK: false }
}
if (env.FIREWORKS_API_KEY) {
return { apiKey: env.FIREWORKS_API_KEY, isBYOK: false }
}
throw new Error(`API key is required for Fireworks model ${model}`)
}
const isBedrockModel = provider === 'bedrock' || model.startsWith('bedrock/')
if (isBedrockModel) {
return { apiKey: 'bedrock-uses-own-credentials', isBYOK: false }

View File

@@ -696,14 +696,19 @@ function resolveAuthType(
/**
* Gets all available models from PROVIDER_DEFINITIONS as static options.
* This provides fallback data when store state is not available server-side.
* Excludes dynamic providers (ollama, vllm, openrouter) which require runtime fetching.
* Excludes dynamic providers (ollama, vllm, openrouter, fireworks) which require runtime fetching.
*/
function getStaticModelOptions(): { id: string; label?: string }[] {
const models: { id: string; label?: string }[] = []
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
// Skip providers with dynamic/fetched models
if (provider.id === 'ollama' || provider.id === 'vllm' || provider.id === 'openrouter') {
if (
provider.id === 'ollama' ||
provider.id === 'vllm' ||
provider.id === 'openrouter' ||
provider.id === 'fireworks'
) {
continue
}
if (provider?.models) {
@@ -737,6 +742,7 @@ function callOptionsWithFallback(
ollama: { models: [] },
vllm: { models: [] },
openrouter: { models: [] },
fireworks: { models: [] },
},
}

View File

@@ -324,7 +324,7 @@ function getStaticModelOptionsForVFS(): Array<{
hosted: boolean
}> {
const hostedProviders = new Set(['openai', 'anthropic', 'google'])
const dynamicProviders = new Set(['ollama', 'vllm', 'openrouter'])
const dynamicProviders = new Set(['ollama', 'vllm', 'openrouter', 'fireworks'])
const models: Array<{ id: string; provider: string; hosted: boolean }> = []

View File

@@ -105,6 +105,7 @@ export const env = createEnv({
OLLAMA_URL: z.string().url().optional(), // Ollama local LLM server URL
VLLM_BASE_URL: z.string().url().optional(), // vLLM self-hosted base URL (OpenAI-compatible)
VLLM_API_KEY: z.string().optional(), // Optional bearer token for vLLM
FIREWORKS_API_KEY: z.string().optional(), // Optional Fireworks AI API key for model listing
ELEVENLABS_API_KEY: z.string().min(1).optional(), // ElevenLabs API key for text-to-speech in deployed chat
SERPER_API_KEY: z.string().min(1).optional(), // Serper API key for online search
EXA_API_KEY: z.string().min(1).optional(), // Exa AI API key for enhanced online search

View File

@@ -0,0 +1,623 @@
import { createLogger } from '@sim/logger'
import OpenAI from 'openai'
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import {
checkForForcedToolUsage,
createReadableStreamFromOpenAIStream,
supportsNativeStructuredOutputs,
} from '@/providers/fireworks/utils'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import type {
FunctionCallResponse,
Message,
ProviderConfig,
ProviderRequest,
ProviderResponse,
TimeSegment,
} from '@/providers/types'
import { ProviderError } from '@/providers/types'
import {
calculateCost,
generateSchemaInstructions,
prepareToolExecution,
prepareToolsWithUsageControl,
sumToolCosts,
} from '@/providers/utils'
import { executeTool } from '@/tools'
const logger = createLogger('FireworksProvider')
/**
* Applies structured output configuration to a payload based on model capabilities.
* Uses json_schema with strict mode for supported models and falls back to json_object with prompt instructions.
*/
async function applyResponseFormat(
targetPayload: any,
messages: any[],
responseFormat: any,
model: string
): Promise<any[]> {
const useNative = await supportsNativeStructuredOutputs(model)
if (useNative) {
logger.info('Using native structured outputs for Fireworks model', { model })
targetPayload.response_format = {
type: 'json_schema',
json_schema: {
name: responseFormat.name || 'response_schema',
schema: responseFormat.schema || responseFormat,
strict: responseFormat.strict !== false,
},
}
return messages
}
logger.info('Using json_object mode with prompt instructions for Fireworks model', { model })
const schema = responseFormat.schema || responseFormat
const schemaInstructions = generateSchemaInstructions(schema, responseFormat.name)
targetPayload.response_format = { type: 'json_object' }
return [...messages, { role: 'user', content: schemaInstructions }]
}
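// Illustrative payload shapes produced above (values are examples, not fixed output):
//   native:   payload.response_format = { type: 'json_schema', json_schema: { name, schema, strict } }
//   fallback: payload.response_format = { type: 'json_object' }, and the returned message
//             list gains a trailing user message carrying the schema instructions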
export const fireworksProvider: ProviderConfig = {
id: 'fireworks',
name: 'Fireworks',
description: 'Fast inference for open-source models via Fireworks AI',
version: '1.0.0',
models: getProviderModels('fireworks'),
defaultModel: getProviderDefaultModel('fireworks'),
executeRequest: async (
request: ProviderRequest
): Promise<ProviderResponse | StreamingExecution> => {
if (!request.apiKey) {
throw new Error('API key is required for Fireworks')
}
const client = new OpenAI({
apiKey: request.apiKey,
baseURL: 'https://api.fireworks.ai/inference/v1',
})
const requestedModel = request.model.replace(/^fireworks\//, '')
logger.info('Preparing Fireworks request', {
model: requestedModel,
hasSystemPrompt: !!request.systemPrompt,
hasMessages: !!request.messages?.length,
hasTools: !!request.tools?.length,
toolCount: request.tools?.length || 0,
hasResponseFormat: !!request.responseFormat,
stream: !!request.stream,
})
const allMessages: Message[] = []
if (request.systemPrompt) {
allMessages.push({ role: 'system', content: request.systemPrompt })
}
if (request.context) {
allMessages.push({ role: 'user', content: request.context })
}
if (request.messages) {
allMessages.push(...request.messages)
}
const tools = request.tools?.length
? request.tools.map((tool) => ({
type: 'function',
function: {
name: tool.id,
description: tool.description,
parameters: tool.parameters,
},
}))
: undefined
const payload: any = {
model: requestedModel,
messages: allMessages,
}
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens != null) payload.max_tokens = request.maxTokens
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
let hasActiveTools = false
if (tools?.length) {
preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'fireworks')
const { tools: filteredTools, toolChoice } = preparedTools
if (filteredTools?.length && toolChoice) {
payload.tools = filteredTools
payload.tool_choice = toolChoice
hasActiveTools = true
}
}
const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()
try {
if (request.responseFormat && !hasActiveTools) {
payload.messages = await applyResponseFormat(
payload,
payload.messages,
request.responseFormat,
requestedModel
)
}
if (request.stream && (!tools || tools.length === 0 || !hasActiveTools)) {
const streamingParams: ChatCompletionCreateParamsStreaming = {
...payload,
stream: true,
stream_options: { include_usage: true },
}
const streamResponse = await client.chat.completions.create(
streamingParams,
request.abortSignal ? { signal: request.abortSignal } : undefined
)
const streamingResult = {
stream: createReadableStreamFromOpenAIStream(streamResponse, (content, usage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
input: usage.prompt_tokens,
output: usage.completion_tokens,
total: usage.total_tokens,
}
const costResult = calculateCost(
requestedModel,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: costResult.input,
output: costResult.output,
total: costResult.total,
}
const end = Date.now()
const endISO = new Date(end).toISOString()
if (streamingResult.execution.output.providerTiming) {
streamingResult.execution.output.providerTiming.endTime = endISO
streamingResult.execution.output.providerTiming.duration = end - providerStartTime
if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
streamingResult.execution.output.providerTiming.timeSegments[0].endTime = end
streamingResult.execution.output.providerTiming.timeSegments[0].duration =
end - providerStartTime
}
}
}),
execution: {
success: true,
output: {
content: '',
model: requestedModel,
tokens: { input: 0, output: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
timeSegments: [
{
type: 'model',
name: 'Streaming response',
startTime: providerStartTime,
endTime: Date.now(),
duration: Date.now() - providerStartTime,
},
],
},
cost: { input: 0, output: 0, total: 0 },
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
},
},
} as StreamingExecution
return streamingResult as StreamingExecution
}
const initialCallTime = Date.now()
const originalToolChoice = payload.tool_choice
const forcedTools = preparedTools?.forcedTools || []
let usedForcedTools: string[] = []
let currentResponse = await client.chat.completions.create(
payload,
request.abortSignal ? { signal: request.abortSignal } : undefined
)
const firstResponseTime = Date.now() - initialCallTime
let content = currentResponse.choices[0]?.message?.content || ''
const tokens = {
input: currentResponse.usage?.prompt_tokens || 0,
output: currentResponse.usage?.completion_tokens || 0,
total: currentResponse.usage?.total_tokens || 0,
}
const toolCalls: FunctionCallResponse[] = []
const toolResults: Record<string, unknown>[] = []
const currentMessages = [...allMessages]
let iterationCount = 0
let modelTime = firstResponseTime
let toolsTime = 0
let hasUsedForcedTool = false
const timeSegments: TimeSegment[] = [
{
type: 'model',
name: 'Initial response',
startTime: initialCallTime,
endTime: initialCallTime + firstResponseTime,
duration: firstResponseTime,
},
]
const forcedToolResult = checkForForcedToolUsage(
currentResponse,
originalToolChoice,
forcedTools,
usedForcedTools
)
hasUsedForcedTool = forcedToolResult.hasUsedForcedTool
usedForcedTools = forcedToolResult.usedForcedTools
while (iterationCount < MAX_TOOL_ITERATIONS) {
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}
const toolsStartTime = Date.now()
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
const toolCallStartTime = Date.now()
const toolName = toolCall.function.name
try {
const toolArgs = JSON.parse(toolCall.function.arguments)
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) return null
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
const result = await executeTool(toolName, executionParams)
const toolCallEndTime = Date.now()
return {
toolCall,
toolName,
toolParams,
result,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
} catch (error) {
const toolCallEndTime = Date.now()
logger.error('Error processing tool call (Fireworks):', {
error: error instanceof Error ? error.message : String(error),
toolName,
})
return {
toolCall,
toolName,
toolParams: {},
result: {
success: false,
output: undefined,
error: error instanceof Error ? error.message : 'Tool execution failed',
},
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
}
})
const executionResults = await Promise.allSettled(toolExecutionPromises)
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: toolCallsInResponse.map((tc) => ({
id: tc.id,
type: 'function',
function: {
name: tc.function.name,
arguments: tc.function.arguments,
},
})),
})
for (const settledResult of executionResults) {
if (settledResult.status === 'rejected' || !settledResult.value) continue
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
settledResult.value
timeSegments.push({
type: 'tool',
name: toolName,
startTime: startTime,
endTime: endTime,
duration: duration,
})
let resultContent: any
if (result.success) {
toolResults.push(result.output!)
resultContent = result.output
} else {
resultContent = {
error: true,
message: result.error || 'Tool execution failed',
tool: toolName,
}
}
toolCalls.push({
name: toolName,
arguments: toolParams,
startTime: new Date(startTime).toISOString(),
endTime: new Date(endTime).toISOString(),
duration: duration,
result: resultContent,
success: result.success,
})
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(resultContent),
})
}
const thisToolsTime = Date.now() - toolsStartTime
toolsTime += thisToolsTime
const nextPayload = {
...payload,
messages: currentMessages,
}
if (typeof originalToolChoice === 'object' && hasUsedForcedTool && forcedTools.length > 0) {
const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool))
if (remainingTools.length > 0) {
nextPayload.tool_choice = { type: 'function', function: { name: remainingTools[0] } }
} else {
nextPayload.tool_choice = 'auto'
}
}
const nextModelStartTime = Date.now()
currentResponse = await client.chat.completions.create(
nextPayload,
request.abortSignal ? { signal: request.abortSignal } : undefined
)
const nextForcedToolResult = checkForForcedToolUsage(
currentResponse,
nextPayload.tool_choice,
forcedTools,
usedForcedTools
)
hasUsedForcedTool = nextForcedToolResult.hasUsedForcedTool
usedForcedTools = nextForcedToolResult.usedForcedTools
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
startTime: nextModelStartTime,
endTime: nextModelEndTime,
duration: thisModelTime,
})
modelTime += thisModelTime
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}
if (currentResponse.usage) {
tokens.input += currentResponse.usage.prompt_tokens || 0
tokens.output += currentResponse.usage.completion_tokens || 0
tokens.total += currentResponse.usage.total_tokens || 0
}
iterationCount++
}
if (request.stream) {
const accumulatedCost = calculateCost(requestedModel, tokens.input, tokens.output)
const streamingParams: ChatCompletionCreateParamsStreaming = {
...payload,
messages: [...currentMessages],
tool_choice: 'auto',
stream: true,
stream_options: { include_usage: true },
}
if (request.responseFormat) {
;(streamingParams as any).messages = await applyResponseFormat(
streamingParams as any,
streamingParams.messages,
request.responseFormat,
requestedModel
)
}
const streamResponse = await client.chat.completions.create(
streamingParams,
request.abortSignal ? { signal: request.abortSignal } : undefined
)
const streamingResult = {
stream: createReadableStreamFromOpenAIStream(streamResponse, (content, usage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
input: tokens.input + usage.prompt_tokens,
output: tokens.output + usage.completion_tokens,
total: tokens.total + usage.total_tokens,
}
const streamCost = calculateCost(
requestedModel,
usage.prompt_tokens,
usage.completion_tokens
)
const tc = sumToolCosts(toolResults)
streamingResult.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
toolCost: tc || undefined,
total: accumulatedCost.total + streamCost.total + tc,
}
}),
execution: {
success: true,
output: {
content: '',
model: requestedModel,
tokens: { input: tokens.input, output: tokens.output, total: tokens.total },
toolCalls:
toolCalls.length > 0
? {
list: toolCalls,
count: toolCalls.length,
}
: undefined,
providerTiming: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
modelTime: modelTime,
toolsTime: toolsTime,
firstResponseTime: firstResponseTime,
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
cost: {
input: accumulatedCost.input,
output: accumulatedCost.output,
total: accumulatedCost.total,
},
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
},
},
} as StreamingExecution
return streamingResult as StreamingExecution
}
if (request.responseFormat && hasActiveTools) {
const finalPayload: any = {
model: payload.model,
messages: [...currentMessages],
}
if (payload.temperature !== undefined) {
finalPayload.temperature = payload.temperature
}
if (payload.max_tokens !== undefined) {
finalPayload.max_tokens = payload.max_tokens
}
finalPayload.messages = await applyResponseFormat(
finalPayload,
finalPayload.messages,
request.responseFormat,
requestedModel
)
const finalStartTime = Date.now()
const finalResponse = await client.chat.completions.create(
finalPayload,
request.abortSignal ? { signal: request.abortSignal } : undefined
)
const finalEndTime = Date.now()
const finalDuration = finalEndTime - finalStartTime
timeSegments.push({
type: 'model',
name: 'Final structured response',
startTime: finalStartTime,
endTime: finalEndTime,
duration: finalDuration,
})
modelTime += finalDuration
if (finalResponse.choices[0]?.message?.content) {
content = finalResponse.choices[0].message.content
}
if (finalResponse.usage) {
tokens.input += finalResponse.usage.prompt_tokens || 0
tokens.output += finalResponse.usage.completion_tokens || 0
tokens.total += finalResponse.usage.total_tokens || 0
}
}
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
return {
content,
model: requestedModel,
tokens,
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
timing: {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
duration: totalDuration,
modelTime: modelTime,
toolsTime: toolsTime,
firstResponseTime: firstResponseTime,
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
}
} catch (error) {
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
const errorDetails: Record<string, any> = {
error: error instanceof Error ? error.message : String(error),
duration: totalDuration,
}
if (error && typeof error === 'object') {
const err = error as any
if (err.status) errorDetails.status = err.status
if (err.code) errorDetails.code = err.code
if (err.type) errorDetails.type = err.type
if (err.error?.message) errorDetails.providerMessage = err.error.message
if (err.error?.metadata) errorDetails.metadata = err.error.metadata
}
logger.error('Error in Fireworks request:', errorDetails)
throw new ProviderError(error instanceof Error ? error.message : String(error), {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
duration: totalDuration,
})
}
},
}

View File

@@ -0,0 +1,41 @@
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import { checkForForcedToolUsageOpenAI, createOpenAICompatibleStream } from '@/providers/utils'
/**
* Checks if a model supports native structured outputs (json_schema).
* Fireworks AI supports structured outputs across its inference API.
*/
export async function supportsNativeStructuredOutputs(_modelId: string): Promise<boolean> {
return true
}
/**
* Creates a ReadableStream from a Fireworks streaming response.
* Uses the shared OpenAI-compatible streaming utility.
*/
export function createReadableStreamFromOpenAIStream(
openaiStream: AsyncIterable<ChatCompletionChunk>,
onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
return createOpenAICompatibleStream(openaiStream, 'Fireworks', onComplete)
}
/**
* Checks if a forced tool was used in a Fireworks response.
* Uses the shared OpenAI-compatible forced tool usage helper.
*/
export function checkForForcedToolUsage(
response: any,
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any },
forcedTools: string[],
usedForcedTools: string[]
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } {
return checkForForcedToolUsageOpenAI(
response,
toolChoice,
'Fireworks',
forcedTools,
usedForcedTools
)
}

View File

@@ -14,6 +14,7 @@ import {
BedrockIcon,
CerebrasIcon,
DeepseekIcon,
FireworksIcon,
GeminiIcon,
GroqIcon,
MistralIcon,
@@ -71,6 +72,20 @@ export interface ProviderDefinition {
}
export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
fireworks: {
id: 'fireworks',
name: 'Fireworks',
description: 'Fast inference for open-source models via Fireworks AI',
defaultModel: '',
modelPatterns: [/^fireworks\//],
icon: FireworksIcon,
capabilities: {
temperature: { min: 0, max: 2 },
toolUsageControl: true,
},
contextInformationAvailable: false,
models: [],
},
openrouter: {
id: 'openrouter',
name: 'OpenRouter',
@@ -2539,6 +2554,18 @@ export function updateVLLMModels(models: string[]): void {
}))
}
export function updateFireworksModels(models: string[]): void {
PROVIDER_DEFINITIONS.fireworks.models = models.map((modelId) => ({
id: modelId,
pricing: {
input: 0,
output: 0,
updatedAt: new Date().toISOString().split('T')[0],
},
capabilities: {},
}))
}
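// e.g. updateFireworksModels(['fireworks/accounts/fireworks/models/llama-v3p1-8b-instruct'])
// (model ID illustrative); fetched IDs get zero-cost placeholder pricing because
// Fireworks pricing is not retrieved alongside the model list.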
export function updateOpenRouterModels(models: string[]): void {
PROVIDER_DEFINITIONS.openrouter.models = models.map((modelId) => ({
id: modelId,

View File

@@ -5,6 +5,7 @@ import { azureOpenAIProvider } from '@/providers/azure-openai'
import { bedrockProvider } from '@/providers/bedrock'
import { cerebrasProvider } from '@/providers/cerebras'
import { deepseekProvider } from '@/providers/deepseek'
import { fireworksProvider } from '@/providers/fireworks'
import { googleProvider } from '@/providers/google'
import { groqProvider } from '@/providers/groq'
import { mistralProvider } from '@/providers/mistral'
@@ -32,6 +33,7 @@ const providerRegistry: Record<ProviderId, ProviderConfig> = {
mistral: mistralProvider,
'azure-openai': azureOpenAIProvider,
openrouter: openRouterProvider,
fireworks: fireworksProvider,
ollama: ollamaProvider,
bedrock: bedrockProvider,
}

View File

@@ -14,6 +14,7 @@ export type ProviderId =
| 'mistral'
| 'ollama'
| 'openrouter'
| 'fireworks'
| 'vllm'
| 'bedrock'

View File

@@ -147,6 +147,7 @@ export const providers: Record<ProviderId, ProviderMetadata> = {
mistral: buildProviderMetadata('mistral'),
bedrock: buildProviderMetadata('bedrock'),
openrouter: buildProviderMetadata('openrouter'),
fireworks: buildProviderMetadata('fireworks'),
}
export function updateOllamaProviderModels(models: string[]): void {
@@ -166,11 +167,20 @@ export async function updateOpenRouterProviderModels(models: string[]): Promise<
providers.openrouter.models = getProviderModelsFromDefinitions('openrouter')
}
export async function updateFireworksProviderModels(models: string[]): Promise<void> {
const { updateFireworksModels } = await import('@/providers/models')
updateFireworksModels(models)
providers.fireworks.models = getProviderModelsFromDefinitions('fireworks')
}
export function getBaseModelProviders(): Record<string, ProviderId> {
const allProviders = Object.entries(providers)
.filter(
([providerId]) =>
providerId !== 'ollama' && providerId !== 'vllm' && providerId !== 'openrouter'
providerId !== 'ollama' &&
providerId !== 'vllm' &&
providerId !== 'openrouter' &&
providerId !== 'fireworks'
)
.reduce(
(map, [providerId, config]) => {

View File

@@ -10,6 +10,7 @@ export const useProvidersStore = create<ProvidersStore>((set, get) => ({
ollama: { models: [], isLoading: false },
vllm: { models: [], isLoading: false },
openrouter: { models: [], isLoading: false },
fireworks: { models: [], isLoading: false },
},
openRouterModelInfo: {},

View File

@@ -1,4 +1,4 @@
export type ProviderName = 'ollama' | 'vllm' | 'openrouter' | 'base'
export type ProviderName = 'ollama' | 'vllm' | 'openrouter' | 'fireworks' | 'base'
export interface OpenRouterModelInfo {
id: string

View File

@@ -6,6 +6,7 @@ export type BYOKProviderId =
| 'anthropic'
| 'google'
| 'mistral'
| 'fireworks'
| 'firecrawl'
| 'exa'
| 'serper'