fix(storage): support Azure connection string for presigned URLs (#2997)

* fix(docs): update requirements to be more accurate for deploying the app

* updated kb to support 1536 dimension vectors for models other than text embedding 3 small

* fix(storage): support Azure connection string for presigned URLs

* fix(kb): update test for embedding dimensions parameter

* fix(storage): align credential source ordering for consistency
This commit is contained in:
Waleed
2026-01-25 13:06:12 -08:00
committed by GitHub
parent 1bf5ed4586
commit be2a9ef0f8
15 changed files with 183 additions and 87 deletions

View File

@@ -8,6 +8,17 @@ const logger = createLogger('EmbeddingUtils')
const MAX_TOKENS_PER_REQUEST = 8000
const MAX_CONCURRENT_BATCHES = env.KB_CONFIG_CONCURRENCY_LIMIT || 50
const EMBEDDING_DIMENSIONS = 1536
/**
* Check if the model supports custom dimensions.
* text-embedding-3-* models support the dimensions parameter.
* Checks for 'embedding-3' to handle Azure deployments with custom naming conventions.
*/
function supportsCustomDimensions(modelName: string): boolean {
const name = modelName.toLowerCase()
return name.includes('embedding-3') && !name.includes('ada')
}
export class EmbeddingAPIError extends Error {
public status: number
@@ -93,15 +104,19 @@ async function getEmbeddingConfig(
async function callEmbeddingAPI(inputs: string[], config: EmbeddingConfig): Promise<number[][]> {
return retryWithExponentialBackoff(
async () => {
const useDimensions = supportsCustomDimensions(config.modelName)
const requestBody = config.useAzure
? {
input: inputs,
encoding_format: 'float',
...(useDimensions && { dimensions: EMBEDDING_DIMENSIONS }),
}
: {
input: inputs,
model: config.modelName,
encoding_format: 'float',
...(useDimensions && { dimensions: EMBEDDING_DIMENSIONS }),
}
const response = await fetch(config.apiUrl, {

View File

@@ -18,6 +18,52 @@ const logger = createLogger('BlobClient')
let _blobServiceClient: BlobServiceClientInstance | null = null
interface ParsedCredentials {
accountName: string
accountKey: string
}
/**
* Extract account name and key from an Azure connection string.
* Connection strings have the format: DefaultEndpointsProtocol=https;AccountName=...;AccountKey=...;EndpointSuffix=...
*/
function parseConnectionString(connectionString: string): ParsedCredentials {
const accountNameMatch = connectionString.match(/AccountName=([^;]+)/)
if (!accountNameMatch) {
throw new Error('Cannot extract account name from connection string')
}
const accountKeyMatch = connectionString.match(/AccountKey=([^;]+)/)
if (!accountKeyMatch) {
throw new Error('Cannot extract account key from connection string')
}
return {
accountName: accountNameMatch[1],
accountKey: accountKeyMatch[1],
}
}
/**
* Get account credentials from BLOB_CONFIG, extracting from connection string if necessary.
*/
function getAccountCredentials(): ParsedCredentials {
if (BLOB_CONFIG.connectionString) {
return parseConnectionString(BLOB_CONFIG.connectionString)
}
if (BLOB_CONFIG.accountName && BLOB_CONFIG.accountKey) {
return {
accountName: BLOB_CONFIG.accountName,
accountKey: BLOB_CONFIG.accountKey,
}
}
throw new Error(
'Azure Blob Storage credentials are missing set AZURE_CONNECTION_STRING or both AZURE_ACCOUNT_NAME and AZURE_ACCOUNT_KEY'
)
}
export async function getBlobServiceClient(): Promise<BlobServiceClientInstance> {
if (_blobServiceClient) return _blobServiceClient
@@ -127,6 +173,8 @@ export async function getPresignedUrl(key: string, expiresIn = 3600) {
const containerClient = blobServiceClient.getContainerClient(BLOB_CONFIG.containerName)
const blockBlobClient = containerClient.getBlockBlobClient(key)
const { accountName, accountKey } = getAccountCredentials()
const sasOptions = {
containerName: BLOB_CONFIG.containerName,
blobName: key,
@@ -137,13 +185,7 @@ export async function getPresignedUrl(key: string, expiresIn = 3600) {
const sasToken = generateBlobSASQueryParameters(
sasOptions,
new StorageSharedKeyCredential(
BLOB_CONFIG.accountName,
BLOB_CONFIG.accountKey ??
(() => {
throw new Error('AZURE_ACCOUNT_KEY is required when using account name authentication')
})()
)
new StorageSharedKeyCredential(accountName, accountKey)
).toString()
return `${blockBlobClient.url}?${sasToken}`
@@ -168,9 +210,14 @@ export async function getPresignedUrlWithConfig(
StorageSharedKeyCredential,
} = await import('@azure/storage-blob')
let tempBlobServiceClient: BlobServiceClientInstance
let accountName: string
let accountKey: string
if (customConfig.connectionString) {
tempBlobServiceClient = BlobServiceClient.fromConnectionString(customConfig.connectionString)
const credentials = parseConnectionString(customConfig.connectionString)
accountName = credentials.accountName
accountKey = credentials.accountKey
} else if (customConfig.accountName && customConfig.accountKey) {
const sharedKeyCredential = new StorageSharedKeyCredential(
customConfig.accountName,
@@ -180,6 +227,8 @@ export async function getPresignedUrlWithConfig(
`https://${customConfig.accountName}.blob.core.windows.net`,
sharedKeyCredential
)
accountName = customConfig.accountName
accountKey = customConfig.accountKey
} else {
throw new Error(
'Custom blob config must include either connectionString or accountName + accountKey'
@@ -199,13 +248,7 @@ export async function getPresignedUrlWithConfig(
const sasToken = generateBlobSASQueryParameters(
sasOptions,
new StorageSharedKeyCredential(
customConfig.accountName,
customConfig.accountKey ??
(() => {
throw new Error('Account key is required when using account name authentication')
})()
)
new StorageSharedKeyCredential(accountName, accountKey)
).toString()
return `${blockBlobClient.url}?${sasToken}`
@@ -403,13 +446,9 @@ export async function getMultipartPartUrls(
if (customConfig) {
if (customConfig.connectionString) {
blobServiceClient = BlobServiceClient.fromConnectionString(customConfig.connectionString)
const match = customConfig.connectionString.match(/AccountName=([^;]+)/)
if (!match) throw new Error('Cannot extract account name from connection string')
accountName = match[1]
const keyMatch = customConfig.connectionString.match(/AccountKey=([^;]+)/)
if (!keyMatch) throw new Error('Cannot extract account key from connection string')
accountKey = keyMatch[1]
const credentials = parseConnectionString(customConfig.connectionString)
accountName = credentials.accountName
accountKey = credentials.accountKey
} else if (customConfig.accountName && customConfig.accountKey) {
const credential = new StorageSharedKeyCredential(
customConfig.accountName,
@@ -428,12 +467,9 @@ export async function getMultipartPartUrls(
} else {
blobServiceClient = await getBlobServiceClient()
containerName = BLOB_CONFIG.containerName
accountName = BLOB_CONFIG.accountName
accountKey =
BLOB_CONFIG.accountKey ||
(() => {
throw new Error('AZURE_ACCOUNT_KEY is required')
})()
const credentials = getAccountCredentials()
accountName = credentials.accountName
accountKey = credentials.accountKey
}
const containerClient = blobServiceClient.getContainerClient(containerName)
@@ -501,12 +537,10 @@ export async function completeMultipartUpload(
const containerClient = blobServiceClient.getContainerClient(containerName)
const blockBlobClient = containerClient.getBlockBlobClient(key)
// Sort parts by part number and extract block IDs
const sortedBlockIds = parts
.sort((a, b) => a.partNumber - b.partNumber)
.map((part) => part.blockId)
// Commit the block list to create the final blob
await blockBlobClient.commitBlockList(sortedBlockIds, {
metadata: {
multipartUpload: 'completed',
@@ -557,10 +591,8 @@ export async function abortMultipartUpload(key: string, customConfig?: BlobConfi
const blockBlobClient = containerClient.getBlockBlobClient(key)
try {
// Delete the blob if it exists (this also cleans up any uncommitted blocks)
await blockBlobClient.deleteIfExists()
} catch (error) {
// Ignore errors since we're just cleaning up
logger.warn('Error cleaning up multipart upload:', error)
}
}