fix(ai): support dynamic config

This commit is contained in:
0xjojo1
2025-07-15 17:41:05 +08:00
parent 19bec12ef8
commit 528faddec7
15 changed files with 336 additions and 144 deletions

View File

@@ -1,25 +1,23 @@
import { EmbeddingModelV1, LanguageModelV1 } from '@ai-sdk/provider'
import { AICustomConfig, ProviderType } from '../config/type'
import { AbstractProvider } from './providers/abstractProvider'
import { OpenAIModelProvider } from './providers/openAI'
import { GlobalConfig, ProviderType } from './type'
import { PgLiteVector } from './vector'
export class AIModelFactory {
private static instance: AIModelFactory
private llmProviderCache = new Map<string, AbstractProvider>()
private embeddingProviderCache = new Map<string, AbstractProvider>()
private vectorInstance?: PgLiteVector
private globalConfig: GlobalConfig
private config: AICustomConfig
private constructor(config: GlobalConfig) {
this.globalConfig = config
private constructor(config: AICustomConfig) {
this.config = config
}
/**
* Initialize the singleton instance with configuration
* This should be called once at app startup
*/
static initialize(config: GlobalConfig): AIModelFactory {
static initialize(config: AICustomConfig): AIModelFactory {
AIModelFactory.instance = new AIModelFactory(config)
return AIModelFactory.instance
}
@@ -47,15 +45,15 @@ export class AIModelFactory {
/**
* Get current configuration
*/
getConfig(): GlobalConfig {
return this.globalConfig
getConfig(): AICustomConfig {
return this.config
}
/**
* Update configuration and clear caches
*/
updateConfig(config: GlobalConfig): void {
this.globalConfig = config
updateConfig(config: AICustomConfig): void {
this.config = config
this.clearCache()
}
@@ -64,7 +62,7 @@ export class AIModelFactory {
*/
private createProvider(
type: ProviderType,
config: GlobalConfig,
config: AICustomConfig,
): AbstractProvider {
switch (type) {
case ProviderType.OPENAI:
@@ -78,17 +76,17 @@ export class AIModelFactory {
* Get or create a LLM provider instance with caching
*/
async getLLMProvider(): Promise<AbstractProvider> {
const providerType = this.globalConfig.llmProvider || 'openai'
const providerType = this.config.llmProvider || 'openai'
const cacheKey = `llm-${providerType}-${JSON.stringify({
apiKey: this.globalConfig.aiApiKey,
apiEndpoint: this.globalConfig.aiApiEndpoint,
model: this.globalConfig.languageModel,
apiKey: this.config.llmApiKey,
apiEndpoint: this.config.llmApiEndpoint,
model: this.config.languageModel,
})}`
if (!this.llmProviderCache.has(cacheKey)) {
const provider = this.createProvider(
providerType as ProviderType,
this.globalConfig,
this.config,
)
this.llmProviderCache.set(cacheKey, provider)
}
@@ -100,22 +98,19 @@ export class AIModelFactory {
* Get or create an embedding provider instance with caching
*/
async getEmbeddingProvider(): Promise<AbstractProvider> {
const providerType = this.globalConfig.embeddingProvider || 'openai'
const providerType = this.config.embeddingProvider || 'openai'
const cacheKey = `embedding-${providerType}-${JSON.stringify({
apiKey: this.globalConfig.embeddingApiKey,
apiEndpoint: this.globalConfig.embeddingApiEndpoint,
model: this.globalConfig.embeddingModel,
apiKey: this.config.embeddingApiKey,
apiEndpoint: this.config.embeddingApiEndpoint,
model: this.config.embeddingModel,
})}`
if (!this.embeddingProviderCache.has(cacheKey)) {
// Create a config specifically for embedding provider
const embeddingConfig: GlobalConfig = {
...this.globalConfig,
aiApiKey:
this.globalConfig.embeddingApiKey || this.globalConfig.aiApiKey,
aiApiEndpoint:
this.globalConfig.embeddingApiEndpoint ||
this.globalConfig.aiApiEndpoint,
const embeddingConfig: AICustomConfig = {
...this.config,
embeddingApiKey: this.config.embeddingApiKey,
embeddingApiEndpoint: this.config.embeddingApiEndpoint,
}
const provider = this.createProvider(
@@ -144,27 +139,12 @@ export class AIModelFactory {
return provider.Embeddings()
}
/**
* Facade method: Get vector database instance
*/
async getVectorDatabase(): Promise<PgLiteVector> {
if (!this.vectorInstance) {
this.vectorInstance = new PgLiteVector({
dataDir: this.globalConfig.dbPath,
schemaName: this.globalConfig.vectorSchemaName,
})
}
return this.vectorInstance
}
/**
* Clear provider cache
*/
clearCache(): void {
this.llmProviderCache.clear()
this.embeddingProviderCache.clear()
this.vectorInstance = undefined
}
/**
@@ -176,7 +156,7 @@ export class AIModelFactory {
}
// Convenience functions for easy access
export function initializeAI(config: GlobalConfig): AIModelFactory {
export function initializeAI(config: AICustomConfig): AIModelFactory {
return AIModelFactory.initialize(config)
}
@@ -185,7 +165,7 @@ export function getAI(): AIModelFactory {
}
// Legacy function - now just delegates to getInstance
export function createAIModelFactory(config: GlobalConfig): AIModelFactory {
export function createAIModelFactory(config: AICustomConfig): AIModelFactory {
if (AIModelFactory.isInitialized()) {
const factory = AIModelFactory.getInstance()
factory.updateConfig(config)
@@ -202,7 +182,3 @@ export async function getLanguageModel(): Promise<LanguageModelV1> {
export async function getEmbeddingModel(): Promise<EmbeddingModelV1<string>> {
return AIModelFactory.getInstance().getEmbeddingModel()
}
export async function getVectorDatabase(): Promise<PgLiteVector> {
return AIModelFactory.getInstance().getVectorDatabase()
}

View File

@@ -1,37 +1,52 @@
import { QueryResult } from '@mastra/core'
import { embed, embedMany } from 'ai'
import { INode } from '@penx/model-type'
import { VectorService } from '../db/vectorService'
import { AIModelFactory } from './aiModelFactory'
import { buildMDocument, ProcessingOptions } from './utils/nodeToDocument'
import { PgLiteVector } from './vector'
/**
* AI Service - Business logic layer for AI operations
* Responsibilities:
* - Coordinate AI models and vector operations
* - Handle embedding generation and search
* - Process business data (INode) into AI-ready format
*/
export class AIService {
static async embeddingDeleteAll(VectorStore: PgLiteVector) {
await VectorStore.truncateIndex({ indexName: 'penx_embedding' })
/**
* Clear all embeddings from the index
*/
static async embeddingDeleteAll(indexName: string = 'penx_embedding') {
const vectorService = VectorService.getInstance()
await vectorService.clearEmbeddingIndex(indexName)
}
/**
* Upsert embeddings for ICreationNode data with optimized processing
* Process and upsert embeddings for node data
* Business logic: Convert INode[] -> embeddings -> vector store
*/
static async embeddingUpsert(data: INode[], options?: ProcessingOptions) {
// 1. Get AI services
const embeddingModel =
await AIModelFactory.getInstance().getEmbeddingModel()
const VectorStore = await AIModelFactory.getInstance().getVectorDatabase()
const vectorService = VectorService.getInstance()
const vectorStore = await vectorService.getVectorStore()
// 2. Process business data into documents
const documents = await buildMDocument(data, options)
const chunks = await documents.chunk()
// 3. Generate embeddings
const { embeddings } = await embedMany({
values: chunks.map((chunk) => chunk.text),
model: embeddingModel,
})
await VectorStore.createIndex({
indexName: 'penx_embedding',
dimension: 1536,
})
// 4. Ensure index exists
await vectorService.ensureEmbeddingIndex('penx_embedding', 1536)
await VectorStore.upsert({
// 5. Store embeddings
await vectorStore.upsert({
indexName: 'penx_embedding',
vectors: embeddings,
metadata: chunks.map((chunk) => ({
@@ -43,20 +58,30 @@ export class AIService {
})
}
static async embeddingSearch(query: string): Promise<QueryResult[]> {
/**
* Search embeddings with text query
*/
static async embeddingSearch(
query: string,
topK: number = 10,
): Promise<QueryResult[]> {
// 1. Get AI services
const embeddingModel =
await AIModelFactory.getInstance().getEmbeddingModel()
const VectorStore = await AIModelFactory.getInstance().getVectorDatabase()
const vectorService = VectorService.getInstance()
const vectorStore = await vectorService.getVectorStore()
// 2. Generate query embedding
const { embedding } = await embed({
value: query,
model: embeddingModel,
})
const results = await VectorStore.query({
// 3. Search vectors
const results = await vectorStore.query({
indexName: 'penx_embedding',
queryVector: embedding,
topK: 10,
topK,
})
return results
@@ -74,29 +99,44 @@ export class AIService {
siteId?: string
featured?: boolean
},
topK: number = 10,
): Promise<QueryResult[]> {
// 1. Get AI services
const embeddingModel =
await AIModelFactory.getInstance().getEmbeddingModel()
const VectorStore = await AIModelFactory.getInstance().getVectorDatabase()
const vectorService = VectorService.getInstance()
const vectorStore = await vectorService.getVectorStore()
// 2. Generate query embedding
const { embedding } = await embed({
value: query,
model: embeddingModel,
})
const results = await VectorStore.query({
// 3. Build filter with creation-specific constraints
const searchFilter = filters
? {
...filters,
nodeType: 'CREATION', // Ensure we only get creation nodes
}
: { nodeType: 'CREATION' }
// 4. Search vectors with filters
const results = await vectorStore.query({
indexName: 'penx_embedding',
queryVector: embedding,
topK: 10,
// Add metadata filters if supported by your vector store
filter: filters
? {
...filters,
nodeType: 'CREATION', // Ensure we only get creation nodes
}
: { nodeType: 'CREATION' },
topK,
filter: searchFilter,
})
return results
}
/**
* Get embedding index statistics
*/
static async getEmbeddingStats(indexName: string = 'penx_embedding') {
const vectorService = VectorService.getInstance()
return await vectorService.getIndexStats(indexName)
}
}

View File

@@ -1,41 +1,15 @@
import { Hono } from 'hono'
import { ICreationNode, NodeType } from '@penx/model-type'
import { CreationStatus, GateType } from '@penx/types'
import { AIModelFactory } from './aiModelFactory'
import { ICreationNode } from '@penx/model-type'
import { AICustomConfig } from '../config/type'
import { initializeAI } from './aiModelFactory'
import { AIService } from './aiService'
import { GlobalConfig } from './type'
const aiHonoServer = new Hono()
// temporary config
const tempConfig: GlobalConfig = {
dbPath: './data',
llmProvider: 'openai',
languageModel: 'gpt-4o-mini',
embeddingProvider: 'openai',
embeddingModel: 'text-embedding-3-small',
embeddingDimensions: 1536,
aiApiKey: 'sk-test',
aiApiEndpoint: 'https://openai.com/v1',
embeddingApiKey: 'sk-test',
embeddingApiEndpoint: 'https://openai.com/v1',
}
AIModelFactory.initialize(tempConfig)
// Health check endpoint
aiHonoServer.get('/health', (c) => {
return c.json({
status: 'ok',
message: 'AI service is running',
config: {
llmProvider: tempConfig.llmProvider,
embeddingProvider: tempConfig.embeddingProvider,
models: {
llm: tempConfig.languageModel,
embedding: tempConfig.embeddingModel,
},
},
})
aiHonoServer.post('/config', async (c) => {
const config = await c.req.json()
initializeAI(config as AICustomConfig)
return c.json({ status: 'ok' })
})
// Embedding upsert for creation content (main endpoint)

View File

@@ -1,13 +1,13 @@
import { createOpenAI } from '@ai-sdk/openai'
import { EmbeddingModelV1, LanguageModelV1, ProviderV1 } from '@ai-sdk/provider'
import { GlobalConfig } from '../type'
import { AICustomConfig } from '../../config/type'
export abstract class AbstractProvider {
globalConfig: GlobalConfig
globalConfig: AICustomConfig
provider!: ProviderV1
protected ready: Promise<void>
constructor(globalConfig: GlobalConfig) {
constructor(globalConfig: AICustomConfig) {
this.globalConfig = globalConfig
// Create and store the initialization promise

View File

@@ -1,17 +1,17 @@
import { createOpenAI } from '@ai-sdk/openai'
import { LanguageModelV1 } from '@ai-sdk/provider'
import { GlobalConfig } from '../type'
import { AICustomConfig } from '../../config/type'
import { AbstractProvider } from './abstractProvider'
export class OpenAIModelProvider extends AbstractProvider {
constructor(globalConfig: GlobalConfig) {
constructor(globalConfig: AICustomConfig) {
super(globalConfig)
}
createProvider() {
return createOpenAI({
apiKey: this.globalConfig.aiApiKey,
baseURL: this.globalConfig.aiApiEndpoint || undefined,
apiKey: this.globalConfig.llmApiKey,
baseURL: this.globalConfig.llmApiEndpoint || undefined,
})
}

View File

@@ -19,13 +19,12 @@ export enum ProviderType {
* Defines the structure and validation for AI-related settings
*/
export const ConfigSchema = z.object({
// Vector Database
dbPath: z.string(), // Path to the database file
enable: z.boolean().optional(),
// Language Model Configuration
llmProvider: z.string().optional(), // Provider for language model (openai, anthropic, etc.)
aiApiKey: z.string().optional(), // API key for language model
aiApiEndpoint: z.string().optional(), // API endpoint for language model
llmApiKey: z.string().optional(), // API key for language model
llmApiEndpoint: z.string().optional(), // API endpoint for language model
languageModel: z.string().optional(), // Language model identifier
// Embedding Model Configuration (can be different provider)
@@ -46,4 +45,4 @@ export const ConfigSchema = z.object({
* Global AI Configuration Type
* Inferred from ConfigSchema for type safety
*/
export type GlobalConfig = z.infer<typeof ConfigSchema>
export type AICustomConfig = z.infer<typeof ConfigSchema>

View File

@@ -0,0 +1,134 @@
import { homedir } from 'os'
import { join } from 'path'
import { PGlite } from '@electric-sql/pglite'
import { vector } from '@electric-sql/pglite/vector'
export class PgLiteServer {
  private static instance: PgLiteServer | null = null
  private db: PGlite | null = null
  // Path the database was actually initialized with (null until initialized).
  // Needed so getInfo() reports the real location, not just the default one.
  private dbPath: string | null = null

  private constructor() {
    // Private constructor for singleton pattern
  }

  /**
   * Get the default database path.
   * Currently the same layout (~/PenX/Database) on every platform; branch
   * here again if platform-specific locations are ever needed.
   * @returns Database directory path
   */
  private getDefaultDbPath(): string {
    return join(homedir(), 'PenX', 'Database')
  }

  /**
   * Initialize PGlite database with vector extension.
   * NOTE(review): not guarded against concurrent first calls — two
   * overlapping getDb() calls during startup could each construct a PGlite
   * instance; confirm whether callers serialize initialization.
   * @param customPath Optional custom database path; ignored if the database
   *                   has already been initialized
   * @returns PGlite instance
   */
  private async initializeDb(customPath?: string): Promise<PGlite> {
    // Reuse the existing connection if one is already open.
    if (this.db) {
      return this.db
    }

    const dbPath = customPath || this.getDefaultDbPath()
    console.log(`Initializing PGlite database at: ${dbPath}`)

    this.db = new PGlite(dbPath, {
      extensions: {
        vector, // Enable vector extension by default
      },
    })
    // Remember the path actually used so getInfo() is accurate.
    this.dbPath = dbPath

    // Ensure vector extension is installed; failure is logged but not fatal.
    try {
      await this.db.query('CREATE EXTENSION IF NOT EXISTS vector;')
      console.log('Vector extension installed successfully')
    } catch (error) {
      console.warn('Failed to install vector extension:', error)
    }

    return this.db
  }

  /**
   * Get the singleton PgLiteServer instance
   * @returns PgLiteServer instance
   */
  public static getInstance(): PgLiteServer {
    if (!PgLiteServer.instance) {
      PgLiteServer.instance = new PgLiteServer()
    }
    return PgLiteServer.instance
  }

  /**
   * Get the PGlite database instance
   * @param customPath Optional custom database path (only honored on the
   *                   first call; subsequent calls return the existing db)
   * @returns PGlite instance
   */
  public async getDb(customPath?: string): Promise<PGlite> {
    return await this.initializeDb(customPath)
  }

  /**
   * Close the database connection and reset cached state.
   */
  public async close(): Promise<void> {
    if (this.db) {
      await this.db.close()
      this.db = null
      this.dbPath = null
      console.log('PGlite database connection closed')
    }
  }

  /**
   * Get database information.
   * Fixed: previously always reported the default path, even when the
   * database had been opened with a custom path.
   * @returns Database status and the path in use (null when disconnected)
   */
  public getInfo(): { isConnected: boolean; path: string | null } {
    return {
      isConnected: this.db !== null,
      path: this.dbPath,
    }
  }
}
/**
 * Convenience wrapper around the PgLiteServer singleton.
 * @param customPath Optional custom database path
 * @returns PGlite instance
 */
export async function getPgLiteDb(customPath?: string): Promise<PGlite> {
  return PgLiteServer.getInstance().getDb(customPath)
}
/**
 * Close the singleton database connection, if open.
 */
export async function closePgLiteDb(): Promise<void> {
  await PgLiteServer.getInstance().close()
}
/**
 * Report connection status and path of the singleton PGlite database.
 */
export function getPgLiteInfo(): { isConnected: boolean; path: string | null } {
  return PgLiteServer.getInstance().getInfo()
}

View File

@@ -1,5 +1,4 @@
import { PGlite } from '@electric-sql/pglite'
import { vector } from '@electric-sql/pglite/vector'
import { ErrorCategory, ErrorDomain, MastraError } from '@mastra/core/error'
import { parseSqlIdentifier } from '@mastra/core/utils'
import { MastraVector } from '@mastra/core/vector'
@@ -62,28 +61,12 @@ export class PgLiteVector extends MastraVector<PGVectorFilter> {
private createdIndexes = new Map<string, number>()
private mutexesByName = new Map<string, Mutex>()
private schema?: string
private setupSchemaPromise: Promise<void> | null = null
private installVectorExtensionPromise: Promise<void> | null = null
private vectorExtensionInstalled: boolean | undefined = undefined
private schemaSetupComplete: boolean | undefined = undefined
constructor(
config: {
/**
* PGlite storage location:
* - undefined or ":memory:" for in-memory database
* - "idb://name" for IndexedDB storage (browser)
* - "./path/to/data" for file system storage (Node/Bun/Deno)
*/
dataDir?: string
/**
* Schema name for tables
*/
schemaName?: string
} = {},
) {
if (!config.dataDir) {
throw new Error('PgLiteVector: dataDir is required')
constructor(db: PGlite) {
if (!db) {
throw new Error('PgLiteVector: PGlite instance is required')
}
super()
@@ -92,12 +75,8 @@ export class PgLiteVector extends MastraVector<PGVectorFilter> {
// Custom schemas can cause transaction issues
this.schema = undefined // Force use of public schema
// Initialize PGlite
this.db = new PGlite(config.dataDir, {
extensions: {
vector,
},
})
// Use the provided PGlite instance
this.db = db
void (async () => {
// warm the created indexes cache so we don't need to check if indexes exist every time

View File

@@ -0,0 +1,88 @@
import { PGlite } from '@electric-sql/pglite'
import { getPgLiteDb } from './index'
import { PgLiteVector } from './vector'
export class VectorService {
  // Lazily-created process-wide instance.
  private static instance: VectorService | null = null
  // Lazily-created vector store backed by the shared PGlite database.
  private vectorStore: PgLiteVector | null = null

  // Singleton: obtain instances through getInstance().
  private constructor() {}

  /**
   * Return the process-wide VectorService, creating it on first use.
   */
  public static getInstance(): VectorService {
    VectorService.instance ??= new VectorService()
    return VectorService.instance
  }

  /**
   * Build the PgLiteVector store on top of the shared PGlite database,
   * caching it for subsequent calls.
   */
  private async initializeVectorStore(): Promise<PgLiteVector> {
    if (this.vectorStore === null) {
      const database: PGlite = await getPgLiteDb()
      this.vectorStore = new PgLiteVector(database)
    }
    return this.vectorStore
  }

  /**
   * Get the (lazily initialized) vector store instance.
   */
  public async getVectorStore(): Promise<PgLiteVector> {
    return this.initializeVectorStore()
  }

  /**
   * Create the embedding index if it does not already exist.
   * @param indexName Index to create (defaults to the shared penx index)
   * @param dimension Embedding vector dimension
   */
  public async ensureEmbeddingIndex(
    indexName: string = 'penx_embedding',
    dimension: number = 1536,
  ): Promise<void> {
    const store = await this.getVectorStore()
    await store.createIndex({ indexName, dimension })
  }

  /**
   * Remove all rows from the given embedding index.
   */
  public async clearEmbeddingIndex(
    indexName: string = 'penx_embedding',
  ): Promise<void> {
    const store = await this.getVectorStore()
    await store.truncateIndex({ indexName })
  }

  /**
   * Fetch index metadata via the store's describeIndex call.
   */
  public async getIndexStats(indexName: string = 'penx_embedding') {
    const store = await this.getVectorStore()
    return await store.describeIndex({ indexName })
  }
}
// Convenience functions
/** Async accessor for the VectorService singleton. */
export async function getVectorService(): Promise<VectorService> {
  return VectorService.getInstance()
}
/** Shortcut: resolve the shared PgLiteVector store directly. */
export async function getVectorStore(): Promise<PgLiteVector> {
  return VectorService.getInstance().getVectorStore()
}

View File

@@ -0,0 +1,2 @@
export * from './ai'
export * from './db'