mirror of
https://github.com/simstudioai/sim.git
synced 2026-04-28 03:00:29 -04:00
fix(storage): support Azure connection string for presigned URLs (#2997)
* fix(docs): update requirements to be more accurate for deploying the app * updated kb to support 1536 dimension vectors for models other than text embedding 3 small * fix(storage): support Azure connection string for presigned URLs * fix(kb): update test for embedding dimensions parameter * fix(storage): align credential source ordering for consistency
This commit is contained in:
@@ -8,6 +8,17 @@ const logger = createLogger('EmbeddingUtils')
|
||||
|
||||
const MAX_TOKENS_PER_REQUEST = 8000
|
||||
const MAX_CONCURRENT_BATCHES = env.KB_CONFIG_CONCURRENCY_LIMIT || 50
|
||||
const EMBEDDING_DIMENSIONS = 1536
|
||||
|
||||
/**
|
||||
* Check if the model supports custom dimensions.
|
||||
* text-embedding-3-* models support the dimensions parameter.
|
||||
* Checks for 'embedding-3' to handle Azure deployments with custom naming conventions.
|
||||
*/
|
||||
function supportsCustomDimensions(modelName: string): boolean {
|
||||
const name = modelName.toLowerCase()
|
||||
return name.includes('embedding-3') && !name.includes('ada')
|
||||
}
|
||||
|
||||
export class EmbeddingAPIError extends Error {
|
||||
public status: number
|
||||
@@ -93,15 +104,19 @@ async function getEmbeddingConfig(
|
||||
async function callEmbeddingAPI(inputs: string[], config: EmbeddingConfig): Promise<number[][]> {
|
||||
return retryWithExponentialBackoff(
|
||||
async () => {
|
||||
const useDimensions = supportsCustomDimensions(config.modelName)
|
||||
|
||||
const requestBody = config.useAzure
|
||||
? {
|
||||
input: inputs,
|
||||
encoding_format: 'float',
|
||||
...(useDimensions && { dimensions: EMBEDDING_DIMENSIONS }),
|
||||
}
|
||||
: {
|
||||
input: inputs,
|
||||
model: config.modelName,
|
||||
encoding_format: 'float',
|
||||
...(useDimensions && { dimensions: EMBEDDING_DIMENSIONS }),
|
||||
}
|
||||
|
||||
const response = await fetch(config.apiUrl, {
|
||||
|
||||
@@ -18,6 +18,52 @@ const logger = createLogger('BlobClient')
|
||||
|
||||
let _blobServiceClient: BlobServiceClientInstance | null = null
|
||||
|
||||
interface ParsedCredentials {
|
||||
accountName: string
|
||||
accountKey: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract account name and key from an Azure connection string.
|
||||
* Connection strings have the format: DefaultEndpointsProtocol=https;AccountName=...;AccountKey=...;EndpointSuffix=...
|
||||
*/
|
||||
function parseConnectionString(connectionString: string): ParsedCredentials {
|
||||
const accountNameMatch = connectionString.match(/AccountName=([^;]+)/)
|
||||
if (!accountNameMatch) {
|
||||
throw new Error('Cannot extract account name from connection string')
|
||||
}
|
||||
|
||||
const accountKeyMatch = connectionString.match(/AccountKey=([^;]+)/)
|
||||
if (!accountKeyMatch) {
|
||||
throw new Error('Cannot extract account key from connection string')
|
||||
}
|
||||
|
||||
return {
|
||||
accountName: accountNameMatch[1],
|
||||
accountKey: accountKeyMatch[1],
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get account credentials from BLOB_CONFIG, extracting from connection string if necessary.
|
||||
*/
|
||||
function getAccountCredentials(): ParsedCredentials {
|
||||
if (BLOB_CONFIG.connectionString) {
|
||||
return parseConnectionString(BLOB_CONFIG.connectionString)
|
||||
}
|
||||
|
||||
if (BLOB_CONFIG.accountName && BLOB_CONFIG.accountKey) {
|
||||
return {
|
||||
accountName: BLOB_CONFIG.accountName,
|
||||
accountKey: BLOB_CONFIG.accountKey,
|
||||
}
|
||||
}
|
||||
|
||||
throw new Error(
|
||||
'Azure Blob Storage credentials are missing – set AZURE_CONNECTION_STRING or both AZURE_ACCOUNT_NAME and AZURE_ACCOUNT_KEY'
|
||||
)
|
||||
}
|
||||
|
||||
export async function getBlobServiceClient(): Promise<BlobServiceClientInstance> {
|
||||
if (_blobServiceClient) return _blobServiceClient
|
||||
|
||||
@@ -127,6 +173,8 @@ export async function getPresignedUrl(key: string, expiresIn = 3600) {
|
||||
const containerClient = blobServiceClient.getContainerClient(BLOB_CONFIG.containerName)
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(key)
|
||||
|
||||
const { accountName, accountKey } = getAccountCredentials()
|
||||
|
||||
const sasOptions = {
|
||||
containerName: BLOB_CONFIG.containerName,
|
||||
blobName: key,
|
||||
@@ -137,13 +185,7 @@ export async function getPresignedUrl(key: string, expiresIn = 3600) {
|
||||
|
||||
const sasToken = generateBlobSASQueryParameters(
|
||||
sasOptions,
|
||||
new StorageSharedKeyCredential(
|
||||
BLOB_CONFIG.accountName,
|
||||
BLOB_CONFIG.accountKey ??
|
||||
(() => {
|
||||
throw new Error('AZURE_ACCOUNT_KEY is required when using account name authentication')
|
||||
})()
|
||||
)
|
||||
new StorageSharedKeyCredential(accountName, accountKey)
|
||||
).toString()
|
||||
|
||||
return `${blockBlobClient.url}?${sasToken}`
|
||||
@@ -168,9 +210,14 @@ export async function getPresignedUrlWithConfig(
|
||||
StorageSharedKeyCredential,
|
||||
} = await import('@azure/storage-blob')
|
||||
let tempBlobServiceClient: BlobServiceClientInstance
|
||||
let accountName: string
|
||||
let accountKey: string
|
||||
|
||||
if (customConfig.connectionString) {
|
||||
tempBlobServiceClient = BlobServiceClient.fromConnectionString(customConfig.connectionString)
|
||||
const credentials = parseConnectionString(customConfig.connectionString)
|
||||
accountName = credentials.accountName
|
||||
accountKey = credentials.accountKey
|
||||
} else if (customConfig.accountName && customConfig.accountKey) {
|
||||
const sharedKeyCredential = new StorageSharedKeyCredential(
|
||||
customConfig.accountName,
|
||||
@@ -180,6 +227,8 @@ export async function getPresignedUrlWithConfig(
|
||||
`https://${customConfig.accountName}.blob.core.windows.net`,
|
||||
sharedKeyCredential
|
||||
)
|
||||
accountName = customConfig.accountName
|
||||
accountKey = customConfig.accountKey
|
||||
} else {
|
||||
throw new Error(
|
||||
'Custom blob config must include either connectionString or accountName + accountKey'
|
||||
@@ -199,13 +248,7 @@ export async function getPresignedUrlWithConfig(
|
||||
|
||||
const sasToken = generateBlobSASQueryParameters(
|
||||
sasOptions,
|
||||
new StorageSharedKeyCredential(
|
||||
customConfig.accountName,
|
||||
customConfig.accountKey ??
|
||||
(() => {
|
||||
throw new Error('Account key is required when using account name authentication')
|
||||
})()
|
||||
)
|
||||
new StorageSharedKeyCredential(accountName, accountKey)
|
||||
).toString()
|
||||
|
||||
return `${blockBlobClient.url}?${sasToken}`
|
||||
@@ -403,13 +446,9 @@ export async function getMultipartPartUrls(
|
||||
if (customConfig) {
|
||||
if (customConfig.connectionString) {
|
||||
blobServiceClient = BlobServiceClient.fromConnectionString(customConfig.connectionString)
|
||||
const match = customConfig.connectionString.match(/AccountName=([^;]+)/)
|
||||
if (!match) throw new Error('Cannot extract account name from connection string')
|
||||
accountName = match[1]
|
||||
|
||||
const keyMatch = customConfig.connectionString.match(/AccountKey=([^;]+)/)
|
||||
if (!keyMatch) throw new Error('Cannot extract account key from connection string')
|
||||
accountKey = keyMatch[1]
|
||||
const credentials = parseConnectionString(customConfig.connectionString)
|
||||
accountName = credentials.accountName
|
||||
accountKey = credentials.accountKey
|
||||
} else if (customConfig.accountName && customConfig.accountKey) {
|
||||
const credential = new StorageSharedKeyCredential(
|
||||
customConfig.accountName,
|
||||
@@ -428,12 +467,9 @@ export async function getMultipartPartUrls(
|
||||
} else {
|
||||
blobServiceClient = await getBlobServiceClient()
|
||||
containerName = BLOB_CONFIG.containerName
|
||||
accountName = BLOB_CONFIG.accountName
|
||||
accountKey =
|
||||
BLOB_CONFIG.accountKey ||
|
||||
(() => {
|
||||
throw new Error('AZURE_ACCOUNT_KEY is required')
|
||||
})()
|
||||
const credentials = getAccountCredentials()
|
||||
accountName = credentials.accountName
|
||||
accountKey = credentials.accountKey
|
||||
}
|
||||
|
||||
const containerClient = blobServiceClient.getContainerClient(containerName)
|
||||
@@ -501,12 +537,10 @@ export async function completeMultipartUpload(
|
||||
const containerClient = blobServiceClient.getContainerClient(containerName)
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(key)
|
||||
|
||||
// Sort parts by part number and extract block IDs
|
||||
const sortedBlockIds = parts
|
||||
.sort((a, b) => a.partNumber - b.partNumber)
|
||||
.map((part) => part.blockId)
|
||||
|
||||
// Commit the block list to create the final blob
|
||||
await blockBlobClient.commitBlockList(sortedBlockIds, {
|
||||
metadata: {
|
||||
multipartUpload: 'completed',
|
||||
@@ -557,10 +591,8 @@ export async function abortMultipartUpload(key: string, customConfig?: BlobConfi
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(key)
|
||||
|
||||
try {
|
||||
// Delete the blob if it exists (this also cleans up any uncommitted blocks)
|
||||
await blockBlobClient.deleteIfExists()
|
||||
} catch (error) {
|
||||
// Ignore errors since we're just cleaning up
|
||||
logger.warn('Error cleaning up multipart upload:', error)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user