misc: addressed token refresh issue

This commit is contained in:
Sheen Capadngan
2025-12-17 04:02:20 +08:00
parent 139c346fac
commit 4e6069439b
5 changed files with 121 additions and 78 deletions

View File

@@ -144,7 +144,9 @@ export const registerAiMcpServerRouter = async (server: FastifyZodProvider) => {
accessToken: z.string().optional(),
refreshToken: z.string().optional(),
expiresAt: z.number().optional(),
tokenType: z.string().optional()
tokenType: z.string().optional(),
clientId: z.string().optional(),
clientSecret: z.string().optional()
})
}
},

View File

@@ -65,19 +65,29 @@ const TOKEN_REFRESH_BUFFER_MS = 5 * 60 * 1000;
*/
const refreshOAuthToken = async (
serverUrl: string,
refreshToken: string
refreshToken: string,
clientId: string,
clientSecret?: string
): Promise<{ accessToken: string; refreshToken?: string; expiresAt?: number }> => {
const issuer = new URL(serverUrl).origin;
const { data: serverMetadata } = await request.get<TOAuthAuthorizationServerMetadata>(
`${issuer}/.well-known/oauth-authorization-server`
);
const tokenParams: Record<string, string> = {
grant_type: "refresh_token",
refresh_token: refreshToken,
client_id: clientId
};
// Add client_secret for confidential clients
if (clientSecret) {
tokenParams.client_secret = clientSecret;
}
const { data: tokenResponse } = await request.post<TOAuthTokenResponse>(
serverMetadata.token_endpoint,
new URLSearchParams({
grant_type: "refresh_token",
refresh_token: refreshToken
}).toString(),
new URLSearchParams(tokenParams).toString(),
{
headers: {
"Content-Type": "application/x-www-form-urlencoded"
@@ -486,7 +496,9 @@ export const aiMcpServerServiceFactory = ({
accessToken: sessionData.accessToken,
refreshToken: sessionData.refreshToken,
expiresAt: sessionData.expiresAt,
tokenType: sessionData.tokenType
tokenType: sessionData.tokenType,
clientId: sessionData.clientId,
clientSecret: sessionData.clientSecret
};
};
@@ -791,29 +803,47 @@ export const aiMcpServerServiceFactory = ({
Date.now() >= credentials.expiresAt - TOKEN_REFRESH_BUFFER_MS;
if (isExpired && "refreshToken" in credentials && credentials.refreshToken) {
try {
const refreshedTokens = await refreshOAuthToken(server.url, credentials.refreshToken);
credentials = {
...credentials,
accessToken: refreshedTokens.accessToken,
refreshToken: refreshedTokens.refreshToken || credentials.refreshToken,
expiresAt: refreshedTokens.expiresAt
// Decrypt OAuth config to get client credentials for refresh
let oauthConfig: { clientId: string; clientSecret?: string } | undefined;
if (server.encryptedOauthConfig) {
oauthConfig = JSON.parse(decryptor({ cipherTextBlob: server.encryptedOauthConfig }).toString()) as {
clientId: string;
clientSecret?: string;
};
}
// Persist the refreshed credentials
const { cipherTextBlob: newEncryptedCredentials } = encryptor({
plainText: Buffer.from(JSON.stringify(credentials))
});
if (!oauthConfig?.clientId) {
logger.error({ serverId }, "Cannot refresh OAuth token: missing client_id in OAuth config");
} else {
try {
const refreshedTokens = await refreshOAuthToken(
server.url,
credentials.refreshToken,
oauthConfig.clientId,
oauthConfig.clientSecret
);
await aiMcpServerDAL.updateById(serverId, {
encryptedCredentials: newEncryptedCredentials
});
credentials = {
...credentials,
accessToken: refreshedTokens.accessToken,
refreshToken: refreshedTokens.refreshToken || credentials.refreshToken,
expiresAt: refreshedTokens.expiresAt
};
logger.info({ serverId }, "Refreshed OAuth token for MCP server");
} catch (refreshError) {
logger.error(refreshError, `Failed to refresh OAuth token for MCP server ${serverId}`);
// Return expired token - caller can decide how to handle the error
// Persist the refreshed credentials
const { cipherTextBlob: newEncryptedCredentials } = encryptor({
plainText: Buffer.from(JSON.stringify(credentials))
});
await aiMcpServerDAL.updateById(serverId, {
encryptedCredentials: newEncryptedCredentials
});
logger.info({ serverId }, "Refreshed OAuth token for MCP server");
} catch (refreshError) {
logger.error(refreshError, `Failed to refresh OAuth token for MCP server ${serverId}`);
// Return expired token - caller can decide how to handle the error
}
}
}
@@ -829,6 +859,24 @@ export const aiMcpServerServiceFactory = ({
return { credentials, accessToken: undefined };
};
// Get stored OAuth config (client ID/secret) for a server
const getServerOAuthConfig = async (
serverId: string
): Promise<{ clientId: string; clientSecret?: string } | null> => {
const server = await aiMcpServerDAL.findById(serverId);
if (!server || !server.encryptedOauthConfig) {
return null;
}
const { decryptor } = await kmsService.createCipherPairWithDataKey({
type: KmsDataKey.SecretManager,
projectId: server.projectId
});
const decrypted = decryptor({ cipherTextBlob: server.encryptedOauthConfig });
return JSON.parse(decrypted.toString()) as { clientId: string; clientSecret?: string };
};
// Get user's personal credentials for a server (with token refresh)
const getUserServerCredentials = async ({
serverId,
@@ -863,31 +911,43 @@ export const aiMcpServerServiceFactory = ({
const isExpired = credentials.expiresAt && Date.now() >= credentials.expiresAt - TOKEN_REFRESH_BUFFER_MS;
if (isExpired && credentials.refreshToken) {
try {
const refreshedTokens = await refreshOAuthToken(serverUrl, credentials.refreshToken);
// Get OAuth config (client_id is needed for refresh)
const oauthConfig = await getServerOAuthConfig(serverId);
credentials = {
...credentials,
accessToken: refreshedTokens.accessToken,
refreshToken: refreshedTokens.refreshToken || credentials.refreshToken,
expiresAt: refreshedTokens.expiresAt
};
if (!oauthConfig?.clientId) {
logger.error({ serverId, userId }, "Cannot refresh OAuth token: missing client_id in OAuth config");
} else {
try {
const refreshedTokens = await refreshOAuthToken(
serverUrl,
credentials.refreshToken,
oauthConfig.clientId,
oauthConfig.clientSecret
);
// Persist the refreshed credentials
const { cipherTextBlob: newEncryptedCredentials } = encryptor({
plainText: Buffer.from(JSON.stringify(credentials))
});
credentials = {
...credentials,
accessToken: refreshedTokens.accessToken,
refreshToken: refreshedTokens.refreshToken || credentials.refreshToken,
expiresAt: refreshedTokens.expiresAt
};
await aiMcpServerUserCredentialDAL.upsertCredential({
userId,
aiMcpServerId: serverId,
encryptedCredentials: newEncryptedCredentials
});
// Persist the refreshed credentials
const { cipherTextBlob: newEncryptedCredentials } = encryptor({
plainText: Buffer.from(JSON.stringify(credentials))
});
logger.info({ serverId, userId }, "Refreshed OAuth token for user's MCP server credentials");
} catch (refreshError) {
logger.error(refreshError, `Failed to refresh OAuth token for user ${userId} on MCP server ${serverId}`);
// Return expired token - caller can decide how to handle the error
await aiMcpServerUserCredentialDAL.upsertCredential({
userId,
aiMcpServerId: serverId,
encryptedCredentials: newEncryptedCredentials
});
logger.info({ serverId, userId }, "Refreshed OAuth token for user's MCP server credentials");
} catch (refreshError) {
logger.error(refreshError, `Failed to refresh OAuth token for user ${userId} on MCP server ${serverId}`);
// Return expired token - caller can decide how to handle the error
}
}
}
@@ -914,24 +974,6 @@ export const aiMcpServerServiceFactory = ({
return getServerCredentials({ serverId, projectId });
};
// Get stored OAuth config (client ID/secret) for a server
const getServerOAuthConfig = async (
serverId: string
): Promise<{ clientId: string; clientSecret?: string } | null> => {
const server = await aiMcpServerDAL.findById(serverId);
if (!server || !server.encryptedOauthConfig) {
return null;
}
const { decryptor } = await kmsService.createCipherPairWithDataKey({
type: KmsDataKey.SecretManager,
projectId: server.projectId
});
const decrypted = decryptor({ cipherTextBlob: server.encryptedOauthConfig });
return JSON.parse(decrypted.toString()) as { clientId: string; clientSecret?: string };
};
return {
initiateOAuth,
handleOAuthCallback,
@@ -943,8 +985,6 @@ export const aiMcpServerServiceFactory = ({
deleteMcpServer,
listMcpServerTools,
syncMcpServerTools,
getServerCredentials,
getUserServerCredentials,
getCredentialsForServer,
getServerOAuthConfig
};

View File

@@ -97,6 +97,8 @@ export type TOAuthStatusResponse = {
refreshToken?: string;
expiresAt?: number;
tokenType?: string;
clientId?: string;
clientSecret?: string;
};
// List MCP Servers DTO

View File

@@ -72,6 +72,9 @@ export const AuthenticationStep = ({ onOAuthSuccess }: Props) => {
{ shouldDirty: true, shouldTouch: true, shouldValidate: true }
);
setValue("oauthClientId", oauthStatus.clientId || "");
setValue("oauthClientSecret", oauthStatus.clientSecret || "");
setIsOAuthPending(false);
setOauthSessionId(null);

View File

@@ -1,8 +1,4 @@
import { faRefresh } from "@fortawesome/free-solid-svg-icons";
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
import {
Button,
EmptyState,
Table,
TableContainer,
@@ -13,7 +9,7 @@ import {
Tooltip,
Tr
} from "@app/components/v2";
import { useListAiMcpServerTools, useSyncAiMcpServerTools } from "@app/hooks/api";
import { useListAiMcpServerTools } from "@app/hooks/api";
type Props = {
serverId: string;
@@ -21,13 +17,13 @@ type Props = {
export const MCPServerAvailableToolsSection = ({ serverId }: Props) => {
const { data: toolsData, isPending } = useListAiMcpServerTools({ serverId });
const syncTools = useSyncAiMcpServerTools();
// const syncTools = useSyncAiMcpServerTools();
const tools = toolsData?.tools || [];
const handleSyncTools = async () => {
await syncTools.mutateAsync({ serverId });
};
// const handleSyncTools = async () => {
// await syncTools.mutateAsync({ serverId });
// };
return (
<div className="flex w-full flex-col rounded-lg border border-mineshaft-600 bg-mineshaft-900">
@@ -38,7 +34,7 @@ export const MCPServerAvailableToolsSection = ({ serverId }: Props) => {
Tools provided by this MCP server that can be enabled in endpoints
</p>
</div>
<Button
{/* <Button
variant="outline_bg"
size="sm"
leftIcon={<FontAwesomeIcon icon={faRefresh} />}
@@ -46,7 +42,7 @@ export const MCPServerAvailableToolsSection = ({ serverId }: Props) => {
isLoading={syncTools.isPending}
>
Sync Tools
</Button>
</Button> */}
</div>
<div className="p-4">
<TableContainer>