diff --git a/.env.example b/.env.example index 9864a4148209..552e59e2e0a3 100644 --- a/.env.example +++ b/.env.example @@ -511,6 +511,10 @@ OPENID_AUTO_REDIRECT=false OPENID_USE_PKCE=false #Set to true to reuse openid tokens for authentication management instead of using the mongodb session and the custom refresh token. OPENID_REUSE_TOKENS= +# Set to true to expose a JWT-signed cookie containing the OpenID `sub` claim with sameSite=lax. +# This enables cross-origin OAuth callback flows (e.g., AWS Bedrock AgentCore 3LO). +# Can be used independently of OPENID_REUSE_TOKENS. +OPENID_EXPOSE_SUB_COOKIE= #By default, signing key verification results are cached in order to prevent excessive HTTP requests to the JWKS endpoint. #If a signing key matching the kid is found, this will be cached and the next time this kid is requested the signing key will be served from the cache. #Default is true. diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js index 40aac08ee633..59406899575b 100644 --- a/api/cache/getLogStores.js +++ b/api/cache/getLogStores.js @@ -51,6 +51,10 @@ const namespaces = { CacheKeys.OPENID_EXCHANGED_TOKENS, Time.TEN_MINUTES, ), + [CacheKeys.ADMIN_OAUTH_EXCHANGE]: standardCache( + CacheKeys.ADMIN_OAUTH_EXCHANGE, + Time.THIRTY_SECONDS, + ), }; /** diff --git a/api/models/Agent.js b/api/models/Agent.js index 11789ca63b05..663285183a4b 100644 --- a/api/models/Agent.js +++ b/api/models/Agent.js @@ -589,10 +589,16 @@ const deleteAgent = async (searchParameter) => { const agent = await Agent.findOneAndDelete(searchParameter); if (agent) { await removeAgentFromAllProjects(agent.id); - await removeAllPermissions({ - resourceType: ResourceType.AGENT, - resourceId: agent._id, - }); + await Promise.all([ + removeAllPermissions({ + resourceType: ResourceType.AGENT, + resourceId: agent._id, + }), + removeAllPermissions({ + resourceType: ResourceType.REMOTE_AGENT, + resourceId: agent._id, + }), + ]); try { await Agent.updateMany({ 'edges.to': agent.id }, { $pull: { edges: { to: agent.id } } }); } catch (error) { @@ -631,7 +637,7 @@ const deleteUserAgents = async (userId) => { } await AclEntry.deleteMany({ - resourceType: ResourceType.AGENT, + resourceType: { $in: [ResourceType.AGENT, ResourceType.REMOTE_AGENT] }, resourceId: { $in: agentObjectIds }, }); diff --git a/api/models/File.js b/api/models/File.js index 5e90c86fe4b5..1a01ef12f97e 100644 --- a/api/models/File.js +++ b/api/models/File.js @@ -26,7 +26,8 @@ const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => { }; /** - * Retrieves tool files (files that are embedded or have a fileIdentifier) from an array of file IDs + * Retrieves tool files (files that are embedded or have a fileIdentifier) from an array of file IDs. + * Note: execute_code files are handled separately by getCodeGeneratedFiles. * @param {string[]} fileIds - Array of file_id strings to search for * @param {Set} toolResourceSet - Optional filter for tool resources * @returns {Promise>} Files that match the criteria @@ -37,21 +38,25 @@ const getToolFilesByIds = async (fileIds, toolResourceSet) => { } try { - const filter = { - file_id: { $in: fileIds }, - $or: [], - }; + const orConditions = []; if (toolResourceSet.has(EToolResources.context)) { - filter.$or.push({ text: { $exists: true, $ne: null }, context: FileContext.agents }); + orConditions.push({ text: { $exists: true, $ne: null }, context: FileContext.agents }); } if (toolResourceSet.has(EToolResources.file_search)) { - filter.$or.push({ embedded: true }); + orConditions.push({ embedded: true }); } - if (toolResourceSet.has(EToolResources.execute_code)) { - filter.$or.push({ 'metadata.fileIdentifier': { $exists: true } }); + + if (orConditions.length === 0) { + return []; } + const filter = { + file_id: { $in: fileIds }, + context: { $ne: FileContext.execute_code }, // Exclude code-generated files + $or: orConditions, + }; + const selectFields = { text: 0 }; const sortOptions = { updatedAt: -1 }; @@ -62,6 +67,70 @@ const getToolFilesByIds = async (fileIds, toolResourceSet) => { } }; +/** + * Retrieves files generated by code execution for a given conversation. + * These files are stored locally with fileIdentifier metadata for code env re-upload. + * @param {string} conversationId - The conversation ID to search for + * @param {string[]} [messageIds] - Optional array of messageIds to filter by (for linear thread filtering) + * @returns {Promise>} Files generated by code execution in the conversation + */ +const getCodeGeneratedFiles = async (conversationId, messageIds) => { + if (!conversationId) { + return []; + } + + /** messageIds are required for proper thread filtering of code-generated files */ + if (!messageIds || messageIds.length === 0) { + return []; + } + + try { + const filter = { + conversationId, + context: FileContext.execute_code, + messageId: { $exists: true, $in: messageIds }, + 'metadata.fileIdentifier': { $exists: true }, + }; + + const selectFields = { text: 0 }; + const sortOptions = { createdAt: 1 }; + + return await getFiles(filter, sortOptions, selectFields); + } catch (error) { + logger.error('[getCodeGeneratedFiles] Error retrieving code generated files:', error); + return []; + } +}; + +/** + * Retrieves user-uploaded execute_code files (not code-generated) by their file IDs. + * These are files with fileIdentifier metadata but context is NOT execute_code (e.g., agents or message_attachment). + * File IDs should be collected from message.files arrays in the current thread. + * @param {string[]} fileIds - Array of file IDs to fetch (from message.files in the thread) + * @returns {Promise>} User-uploaded execute_code files + */ +const getUserCodeFiles = async (fileIds) => { + if (!fileIds || fileIds.length === 0) { + return []; + } + + try { + const filter = { + file_id: { $in: fileIds }, + context: { $ne: FileContext.execute_code }, + 'metadata.fileIdentifier': { $exists: true }, + }; + + const selectFields = { text: 0 }; + const sortOptions = { createdAt: 1 }; + + return await getFiles(filter, sortOptions, selectFields); + } catch (error) { + logger.error('[getUserCodeFiles] Error retrieving user code files:', error); + return []; + } +}; + /** * Creates a new file with a TTL of 1 hour. * @param {MongoFile} data - The file data to be created, must contain file_id. @@ -169,6 +238,8 @@ module.exports = { findFileById, getFiles, getToolFilesByIds, + getCodeGeneratedFiles, + getUserCodeFiles, createFile, updateFile, updateFileUsage, diff --git a/api/package.json b/api/package.json index ab0e5130bc69..8916a8487056 100644 --- a/api/package.json +++ b/api/package.json @@ -36,7 +36,7 @@ "dependencies": { "@anthropic-ai/sdk": "^0.71.0", "@anthropic-ai/vertex-sdk": "^0.14.0", - "@aws-sdk/client-bedrock-runtime": "^3.941.0", + "@aws-sdk/client-bedrock-runtime": "^3.970.0", "@aws-sdk/client-s3": "^3.758.0", "@aws-sdk/s3-request-presigner": "^3.758.0", "@azure/identity": "^4.7.0", @@ -45,7 +45,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.0.776", + "@librechat/agents": "^3.1.0", "@librechat/api": "*", "@librechat/data-schemas": "*", "@microsoft/microsoft-graph-client": "^3.0.7", diff --git a/api/server/controllers/PermissionsController.js b/api/server/controllers/PermissionsController.js index e22e9532c91b..51993d083c3c 100644 --- a/api/server/controllers/PermissionsController.js +++ b/api/server/controllers/PermissionsController.js @@ -5,6 +5,7 @@ const mongoose = require('mongoose'); const { logger } = require('@librechat/data-schemas'); const { ResourceType, PrincipalType, PermissionBits } = require('librechat-data-provider'); +const { enrichRemoteAgentPrincipals, backfillRemoteAgentPermissions } = require('@librechat/api'); const { bulkUpdateResourcePermissions, ensureGroupPrincipalExists, @@ -14,7 +15,6 @@ const { findAccessibleResources, getResourcePermissionsMap, } = require('~/server/services/PermissionService'); -const { AclEntry } = require('~/db/models'); const { searchPrincipals: searchLocalPrincipals, sortPrincipalsByRelevance, @@ -24,6 +24,7 @@ const { entraIdPrincipalFeatureEnabled, searchEntraIdPrincipals, } = require('~/server/services/GraphApiService'); +const { AclEntry, AccessRole } = require('~/db/models'); /** * Generic controller for resource permission endpoints @@ -234,7 +235,7 @@ const getResourcePermissions = async (req, res) => { }, ]); - const principals = []; + let principals = []; let publicPermission = null; // Process aggregation results @@ -280,6 +281,13 @@ const getResourcePermissions = async (req, res) => { } } + if (resourceType === ResourceType.REMOTE_AGENT) { + const enricherDeps = { AclEntry, AccessRole, logger }; + const enrichResult = await enrichRemoteAgentPrincipals(enricherDeps, resourceId, principals); + principals = enrichResult.principals; + backfillRemoteAgentPermissions(enricherDeps, resourceId, enrichResult.entriesToBackfill); + } + // Return response in format expected by frontend const response = { resourceType, diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index b0cfd7ede27c..0f17b4d3a975 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -22,6 +22,7 @@ const { } = require('~/models'); const { ConversationTag, + AgentApiKey, Transaction, MemoryEntry, Assistant, @@ -256,6 +257,7 @@ const deleteUserController = async (req, res) => { await deleteFiles(null, user.id); // delete database files in case of orphaned files from previous steps await deleteToolCalls(user.id); // delete user tool calls await deleteUserAgents(user.id); // delete user agents + await AgentApiKey.deleteMany({ user: user._id }); // delete user agent API keys await Assistant.deleteMany({ user: user.id }); // delete user assistants await ConversationTag.deleteMany({ user: user.id }); // delete user conversation tags await MemoryEntry.deleteMany({ userId: user.id }); // delete user memory entries diff --git a/api/server/controllers/agents/__tests__/callbacks.spec.js b/api/server/controllers/agents/__tests__/callbacks.spec.js index 7922c31efabb..103f9f3236ae 100644 --- a/api/server/controllers/agents/__tests__/callbacks.spec.js +++ b/api/server/controllers/agents/__tests__/callbacks.spec.js @@ -16,9 +16,7 @@ jest.mock('@librechat/data-schemas', () => ({ })); jest.mock('@librechat/agents', () => ({ - EnvVar: { CODE_API_KEY: 'CODE_API_KEY' }, - Providers: { GOOGLE: 'google' }, - GraphEvents: {}, + ...jest.requireActual('@librechat/agents'), getMessageId: jest.fn(), ToolEndHandler: jest.fn(), handleToolCalls: jest.fn(), diff --git a/api/server/controllers/agents/callbacks.js b/api/server/controllers/agents/callbacks.js index 0d2a7bc31760..c27f89fdf80d 100644 --- a/api/server/controllers/agents/callbacks.js +++ b/api/server/controllers/agents/callbacks.js @@ -1,6 +1,7 @@ const { nanoid } = require('nanoid'); -const { sendEvent, GenerationJobManager } = require('@librechat/api'); +const { Constants } = require('@librechat/agents'); const { logger } = require('@librechat/data-schemas'); +const { sendEvent, GenerationJobManager, writeAttachmentEvent } = require('@librechat/api'); const { Tools, StepTypes, FileContext, ErrorTypes } = require('librechat-data-provider'); const { EnvVar, @@ -441,10 +442,10 @@ function createToolEndCallback({ req, res, artifactPromises, streamId = null }) return; } - { - if (output.name !== Tools.execute_code) { - return; - } + const isCodeTool = + output.name === Tools.execute_code || output.name === Constants.PROGRAMMATIC_TOOL_CALLING; + if (!isCodeTool) { + return; } if (!output.artifact.files) { @@ -488,7 +489,226 @@ function createToolEndCallback({ req, res, artifactPromises, streamId = null }) }; } +/** + * Helper to write attachment events in Open Responses format (librechat:attachment) + * @param {ServerResponse} res - The server response object + * @param {Object} tracker - The response tracker with sequence number + * @param {Object} attachment - The attachment data + * @param {Object} metadata - Additional metadata (messageId, conversationId) + */ +function writeResponsesAttachment(res, tracker, attachment, metadata) { + const sequenceNumber = tracker.nextSequence(); + writeAttachmentEvent(res, sequenceNumber, attachment, { + messageId: metadata.run_id, + conversationId: metadata.thread_id, + }); +} + +/** + * Creates a tool end callback specifically for the Responses API. + * Emits attachments as `librechat:attachment` events per the Open Responses extension spec. + * + * @param {Object} params + * @param {ServerRequest} params.req + * @param {ServerResponse} params.res + * @param {Object} params.tracker - Response tracker with sequence number + * @param {Promise[]} params.artifactPromises + * @returns {ToolEndCallback} The tool end callback. + */ +function createResponsesToolEndCallback({ req, res, tracker, artifactPromises }) { + /** + * @type {ToolEndCallback} + */ + return async (data, metadata) => { + const output = data?.output; + if (!output) { + return; + } + + if (!output.artifact) { + return; + } + + if (output.artifact[Tools.file_search]) { + artifactPromises.push( + (async () => { + const user = req.user; + const attachment = await processFileCitations({ + user, + metadata, + appConfig: req.config, + toolArtifact: output.artifact, + toolCallId: output.tool_call_id, + }); + if (!attachment) { + return null; + } + // For Responses API, emit attachment during streaming + if (res.headersSent && !res.writableEnded) { + writeResponsesAttachment(res, tracker, attachment, metadata); + } + return attachment; + })().catch((error) => { + logger.error('Error processing file citations:', error); + return null; + }), + ); + } + + if (output.artifact[Tools.ui_resources]) { + artifactPromises.push( + (async () => { + const attachment = { + type: Tools.ui_resources, + toolCallId: output.tool_call_id, + [Tools.ui_resources]: output.artifact[Tools.ui_resources].data, + }; + // For Responses API, always emit attachment during streaming + if (res.headersSent && !res.writableEnded) { + writeResponsesAttachment(res, tracker, attachment, metadata); + } + return attachment; + })().catch((error) => { + logger.error('Error processing artifact content:', error); + return null; + }), + ); + } + + if (output.artifact[Tools.web_search]) { + artifactPromises.push( + (async () => { + const attachment = { + type: Tools.web_search, + toolCallId: output.tool_call_id, + [Tools.web_search]: { ...output.artifact[Tools.web_search] }, + }; + // For Responses API, always emit attachment during streaming + if (res.headersSent && !res.writableEnded) { + writeResponsesAttachment(res, tracker, attachment, metadata); + } + return attachment; + })().catch((error) => { + logger.error('Error processing artifact content:', error); + return null; + }), + ); + } + + if (output.artifact.content) { + /** @type {FormattedContent[]} */ + const content = output.artifact.content; + for (let i = 0; i < content.length; i++) { + const part = content[i]; + if (!part) { + continue; + } + if (part.type !== 'image_url') { + continue; + } + const { url } = part.image_url; + artifactPromises.push( + (async () => { + const filename = `${output.name}_img_${nanoid()}`; + const file_id = output.artifact.file_ids?.[i]; + const file = await saveBase64Image(url, { + req, + file_id, + filename, + endpoint: metadata.provider, + context: FileContext.image_generation, + }); + const fileMetadata = Object.assign(file, { + toolCallId: output.tool_call_id, + }); + + if (!fileMetadata) { + return null; + } + + // For Responses API, emit attachment during streaming + if (res.headersSent && !res.writableEnded) { + const attachment = { + file_id: fileMetadata.file_id, + filename: fileMetadata.filename, + type: fileMetadata.type, + url: fileMetadata.filepath, + width: fileMetadata.width, + height: fileMetadata.height, + tool_call_id: output.tool_call_id, + }; + writeResponsesAttachment(res, tracker, attachment, metadata); + } + + return fileMetadata; + })().catch((error) => { + logger.error('Error processing artifact content:', error); + return null; + }), + ); + } + return; + } + + const isCodeTool = + output.name === Tools.execute_code || output.name === Constants.PROGRAMMATIC_TOOL_CALLING; + if (!isCodeTool) { + return; + } + + if (!output.artifact.files) { + return; + } + + for (const file of output.artifact.files) { + const { id, name } = file; + artifactPromises.push( + (async () => { + const result = await loadAuthValues({ + userId: req.user.id, + authFields: [EnvVar.CODE_API_KEY], + }); + const fileMetadata = await processCodeOutput({ + req, + id, + name, + apiKey: result[EnvVar.CODE_API_KEY], + messageId: metadata.run_id, + toolCallId: output.tool_call_id, + conversationId: metadata.thread_id, + session_id: output.artifact.session_id, + }); + + if (!fileMetadata) { + return null; + } + + // For Responses API, emit attachment during streaming + if (res.headersSent && !res.writableEnded) { + const attachment = { + file_id: fileMetadata.file_id, + filename: fileMetadata.filename, + type: fileMetadata.type, + url: fileMetadata.filepath, + width: fileMetadata.width, + height: fileMetadata.height, + tool_call_id: output.tool_call_id, + }; + writeResponsesAttachment(res, tracker, attachment, metadata); + } + + return fileMetadata; + })().catch((error) => { + logger.error('Error processing code output:', error); + return null; + }), + ); + } + }; +} + module.exports = { getDefaultHandlers, createToolEndCallback, + createResponsesToolEndCallback, }; diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 35cf7de784fe..0977f3468326 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -649,6 +649,7 @@ class AgentClient extends BaseClient { updateFilesUsage: db.updateFilesUsage, getUserKeyValues: db.getUserKeyValues, getToolFilesByIds: db.getToolFilesByIds, + getCodeGeneratedFiles: db.getCodeGeneratedFiles, }, ); @@ -1028,6 +1029,7 @@ class AgentClient extends BaseClient { run = await createRun({ agents, + messages, indexTokenCountMap, runId: this.responseMessageId, signal: abortController.signal, diff --git a/api/server/controllers/agents/openai.js b/api/server/controllers/agents/openai.js new file mode 100644 index 000000000000..331179c7f468 --- /dev/null +++ b/api/server/controllers/agents/openai.js @@ -0,0 +1,660 @@ +const { nanoid } = require('nanoid'); +const { logger } = require('@librechat/data-schemas'); +const { EModelEndpoint, ResourceType, PermissionBits } = require('librechat-data-provider'); +const { + Callback, + ToolEndHandler, + formatAgentMessages, + ChatModelStreamHandler, +} = require('@librechat/agents'); +const { + writeSSE, + createRun, + createChunk, + sendFinalChunk, + createSafeUser, + validateRequest, + initializeAgent, + createErrorResponse, + buildNonStreamingResponse, + createOpenAIStreamTracker, + createOpenAIContentAggregator, + isChatCompletionValidationFailure, +} = require('@librechat/api'); +const { createToolEndCallback } = require('~/server/controllers/agents/callbacks'); +const { findAccessibleResources } = require('~/server/services/PermissionService'); +const { loadAgentTools } = require('~/server/services/ToolService'); +const { getConvoFiles } = require('~/models/Conversation'); +const { getAgent, getAgents } = require('~/models/Agent'); +const db = require('~/models'); + +/** + * Creates a tool loader function for the agent. + * @param {AbortSignal} signal - The abort signal + */ +function createToolLoader(signal) { + return async function loadTools({ + req, + res, + tools, + model, + agentId, + provider, + tool_options, + tool_resources, + }) { + const agent = { id: agentId, tools, provider, model, tool_options }; + try { + return await loadAgentTools({ + req, + res, + agent, + signal, + tool_resources, + streamId: null, // No resumable stream for OpenAI compat + }); + } catch (error) { + logger.error('Error loading tools for agent ' + agentId, error); + } + }; +} + +/** + * Convert content part to internal format + * @param {Object} part - Content part + * @returns {Object} Converted part + */ +function convertContentPart(part) { + if (part.type === 'text') { + return { type: 'text', text: part.text }; + } + if (part.type === 'image_url') { + return { type: 'image_url', image_url: part.image_url }; + } + return part; +} + +/** + * Convert OpenAI messages to internal format + * @param {Array} messages - OpenAI format messages + * @returns {Array} Internal format messages + */ +function convertMessages(messages) { + return messages.map((msg) => { + let content; + if (typeof msg.content === 'string') { + content = msg.content; + } else if (msg.content) { + content = msg.content.map(convertContentPart); + } else { + content = ''; + } + + return { + role: msg.role, + content, + ...(msg.name && { name: msg.name }), + ...(msg.tool_calls && { tool_calls: msg.tool_calls }), + ...(msg.tool_call_id && { tool_call_id: msg.tool_call_id }), + }; + }); +} + +/** + * Send an error response in OpenAI format + */ +function sendErrorResponse(res, statusCode, message, type = 'invalid_request_error', code = null) { + res.status(statusCode).json(createErrorResponse(message, type, code)); +} + +/** + * OpenAI-compatible chat completions controller for agents. + * + * POST /v1/chat/completions + * + * Request format: + * { + * "model": "agent_id_here", + * "messages": [{"role": "user", "content": "Hello!"}], + * "stream": true, + * "conversation_id": "optional", + * "parent_message_id": "optional" + * } + */ +const OpenAIChatCompletionController = async (req, res) => { + const appConfig = req.config; + + // Validate request + const validation = validateRequest(req.body); + if (isChatCompletionValidationFailure(validation)) { + return sendErrorResponse(res, 400, validation.error); + } + + const request = validation.request; + const agentId = request.model; + + // Look up the agent + const agent = await getAgent({ id: agentId }); + if (!agent) { + return sendErrorResponse( + res, + 404, + `Agent not found: ${agentId}`, + 'invalid_request_error', + 'model_not_found', + ); + } + + // Generate IDs + const requestId = `chatcmpl-${nanoid()}`; + const conversationId = request.conversation_id ?? nanoid(); + const parentMessageId = request.parent_message_id ?? null; + const created = Math.floor(Date.now() / 1000); + + const context = { + created, + requestId, + model: agentId, + }; + + // Set up abort controller + const abortController = new AbortController(); + + // Handle client disconnect + req.on('close', () => { + if (!abortController.signal.aborted) { + abortController.abort(); + logger.debug('[OpenAI API] Client disconnected, aborting'); + } + }); + + try { + // Build allowed providers set + const allowedProviders = new Set( + appConfig?.endpoints?.[EModelEndpoint.agents]?.allowedProviders, + ); + + // Create tool loader + const loadTools = createToolLoader(abortController.signal); + + // Initialize the agent first to check for disableStreaming + const endpointOption = { + endpoint: agent.provider, + model_parameters: agent.model_parameters ?? {}, + }; + + const primaryConfig = await initializeAgent( + { + req, + res, + loadTools, + requestFiles: [], + conversationId, + parentMessageId, + agent, + endpointOption, + allowedProviders, + isInitialAgent: true, + }, + { + getConvoFiles, + getFiles: db.getFiles, + getUserKey: db.getUserKey, + getMessages: db.getMessages, + updateFilesUsage: db.updateFilesUsage, + getUserKeyValues: db.getUserKeyValues, + getUserCodeFiles: db.getUserCodeFiles, + getToolFilesByIds: db.getToolFilesByIds, + getCodeGeneratedFiles: db.getCodeGeneratedFiles, + }, + ); + + // Determine if streaming is enabled (check both request and agent config) + const streamingDisabled = !!primaryConfig.model_parameters?.disableStreaming; + const isStreaming = request.stream === true && !streamingDisabled; + + // Create tracker for streaming or aggregator for non-streaming + const tracker = isStreaming ? createOpenAIStreamTracker() : null; + const aggregator = isStreaming ? null : createOpenAIContentAggregator(); + + // Set up response for streaming + if (isStreaming) { + res.setHeader('Content-Type', 'text/event-stream'); + res.setHeader('Cache-Control', 'no-cache'); + res.setHeader('Connection', 'keep-alive'); + res.setHeader('X-Accel-Buffering', 'no'); + res.flushHeaders(); + + // Send initial chunk with role + const initialChunk = createChunk(context, { role: 'assistant' }); + writeSSE(res, initialChunk); + } + + // Create handler config for OpenAI streaming (only used when streaming) + const handlerConfig = isStreaming + ? { + res, + context, + tracker, + } + : null; + + // We need custom handlers that stream in OpenAI format + const collectedUsage = []; + /** @type {Promise[]} */ + const artifactPromises = []; + + // Create tool end callback for processing artifacts (images, file citations, code output) + const toolEndCallback = createToolEndCallback({ req, res, artifactPromises, streamId: null }); + + // Convert messages to internal format + const openaiMessages = convertMessages(request.messages); + + // Format for agent + const toolSet = new Set((primaryConfig.tools ?? []).map((tool) => tool && tool.name)); + const { messages: formattedMessages, indexTokenCountMap } = formatAgentMessages( + openaiMessages, + {}, + toolSet, + ); + + /** + * Create a simple handler that processes data + */ + const createHandler = (processor) => ({ + handle: (_event, data) => { + if (processor) { + processor(data); + } + }, + }); + + /** + * Stream text content in OpenAI format + */ + const streamText = (text) => { + if (!text) { + return; + } + if (isStreaming) { + tracker.addText(); + writeSSE(res, createChunk(context, { content: text })); + } else { + aggregator.addText(text); + } + }; + + /** + * Stream reasoning content in OpenAI format (OpenRouter convention) + */ + const streamReasoning = (text) => { + if (!text) { + return; + } + if (isStreaming) { + tracker.addReasoning(); + writeSSE(res, createChunk(context, { reasoning: text })); + } else { + aggregator.addReasoning(text); + } + }; + + // Built-in handler for processing raw model stream chunks + const chatModelStreamHandler = new ChatModelStreamHandler(); + + // Event handlers for OpenAI-compatible streaming + const handlers = { + // Process raw model chunks and dispatch message/reasoning deltas + on_chat_model_stream: { + handle: async (event, data, metadata, graph) => { + await chatModelStreamHandler.handle(event, data, metadata, graph); + }, + }, + + // Text content streaming + on_message_delta: createHandler((data) => { + const content = data?.delta?.content; + if (Array.isArray(content)) { + for (const part of content) { + if (part.type === 'text' && part.text) { + streamText(part.text); + } + } + } + }), + + // Reasoning/thinking content streaming + on_reasoning_delta: createHandler((data) => { + const content = data?.delta?.content; + if (Array.isArray(content)) { + for (const part of content) { + const text = part.think || part.text; + if (text) { + streamReasoning(text); + } + } + } + }), + + // Tool call initiation - streams id and name (from on_run_step) + on_run_step: createHandler((data) => { + const stepDetails = data?.stepDetails; + if (stepDetails?.type === 'tool_calls' && stepDetails.tool_calls) { + for (const tc of stepDetails.tool_calls) { + const toolIndex = data.index ?? 0; + const toolId = tc.id ?? ''; + const toolName = tc.name ?? ''; + const toolCall = { + id: toolId, + type: 'function', + function: { name: toolName, arguments: '' }, + }; + + // Track tool call in tracker or aggregator + if (isStreaming) { + if (!tracker.toolCalls.has(toolIndex)) { + tracker.toolCalls.set(toolIndex, toolCall); + } + // Stream initial tool call chunk (like OpenAI does) + writeSSE( + res, + createChunk(context, { + tool_calls: [{ index: toolIndex, ...toolCall }], + }), + ); + } else { + if (!aggregator.toolCalls.has(toolIndex)) { + aggregator.toolCalls.set(toolIndex, toolCall); + } + } + } + } + }), + + // Tool call argument streaming (from on_run_step_delta) + on_run_step_delta: createHandler((data) => { + const delta = data?.delta; + if (delta?.type === 'tool_calls' && delta.tool_calls) { + for (const tc of delta.tool_calls) { + const args = tc.args ?? ''; + if (!args) { + continue; + } + + const toolIndex = tc.index ?? 0; + + // Update tool call arguments + const targetMap = isStreaming ? tracker.toolCalls : aggregator.toolCalls; + const tracked = targetMap.get(toolIndex); + if (tracked) { + tracked.function.arguments += args; + } + + // Stream argument delta (only for streaming) + if (isStreaming) { + writeSSE( + res, + createChunk(context, { + tool_calls: [ + { + index: toolIndex, + function: { arguments: args }, + }, + ], + }), + ); + } + } + } + }), + + // Usage tracking + on_chat_model_end: createHandler((data) => { + const usage = data?.output?.usage_metadata; + if (usage) { + collectedUsage.push(usage); + const target = isStreaming ? tracker : aggregator; + target.usage.promptTokens += usage.input_tokens ?? 0; + target.usage.completionTokens += usage.output_tokens ?? 0; + } + }), + on_run_step_completed: createHandler(), + // Use proper ToolEndHandler for processing artifacts (images, file citations, code output) + on_tool_end: new ToolEndHandler(toolEndCallback, logger), + on_chain_stream: createHandler(), + on_chain_end: createHandler(), + on_agent_update: createHandler(), + on_custom_event: createHandler(), + }; + + // Create and run the agent + const userId = req.user?.id ?? 'api-user'; + + // Extract userMCPAuthMap from primaryConfig (needed for MCP tool connections) + const userMCPAuthMap = primaryConfig.userMCPAuthMap; + + const run = await createRun({ + agents: [primaryConfig], + messages: formattedMessages, + indexTokenCountMap, + runId: requestId, + signal: abortController.signal, + customHandlers: handlers, + requestBody: { + messageId: requestId, + conversationId, + }, + user: { id: userId }, + }); + + if (!run) { + throw new Error('Failed to create agent run'); + } + + // Process the stream + const config = { + runName: 'AgentRun', + configurable: { + thread_id: conversationId, + user_id: userId, + user: createSafeUser(req.user), + ...(userMCPAuthMap != null && { userMCPAuthMap }), + }, + signal: abortController.signal, + streamMode: 'values', + version: 'v2', + }; + + await run.processStream({ messages: formattedMessages }, config, { + callbacks: { + [Callback.TOOL_ERROR]: (graph, error, toolId) => { + logger.error(`[OpenAI API] Tool Error "${toolId}"`, error); + }, + }, + }); + + // Finalize response + if (isStreaming) { + sendFinalChunk(handlerConfig); + res.end(); + + // Wait for artifact processing after response ends (non-blocking) + if (artifactPromises.length > 0) { + Promise.all(artifactPromises).catch((artifactError) => { + logger.warn('[OpenAI API] Error processing artifacts:', artifactError); + }); + } + } else { + // For non-streaming, wait for artifacts before sending response + if (artifactPromises.length > 0) { + try { + await Promise.all(artifactPromises); + } catch (artifactError) { + logger.warn('[OpenAI API] Error processing artifacts:', artifactError); + } + } + + // Build usage from aggregated data + const usage = { + prompt_tokens: aggregator.usage.promptTokens, + completion_tokens: aggregator.usage.completionTokens, + total_tokens: aggregator.usage.promptTokens + aggregator.usage.completionTokens, + }; + + if (aggregator.usage.reasoningTokens > 0) { + usage.completion_tokens_details = { + reasoning_tokens: aggregator.usage.reasoningTokens, + }; + } + + const response = buildNonStreamingResponse( + context, + aggregator.getText(), + aggregator.getReasoning(), + aggregator.toolCalls, + usage, + ); + res.json(response); + } + } catch (error) { + const errorMessage = error instanceof Error ? error.message : 'An error occurred'; + logger.error('[OpenAI API] Error:', error); + + // Check if we already started streaming (headers sent) + if (res.headersSent) { + // Headers already sent, send error in stream + const errorChunk = createChunk(context, { content: `\n\nError: ${errorMessage}` }, 'stop'); + writeSSE(res, errorChunk); + writeSSE(res, '[DONE]'); + res.end(); + } else { + sendErrorResponse(res, 500, errorMessage, 'server_error'); + } + } +}; + +/** + * List available agents as models (filtered by remote access permissions) + * + * GET /v1/models + */ +const ListModelsController = async (req, res) => { + try { + const userId = req.user?.id; + const userRole = req.user?.role; + + if (!userId) { + return sendErrorResponse(res, 401, 'Authentication required', 'auth_error'); + } + + // Find agents the user has remote access to (VIEW permission on REMOTE_AGENT) + const accessibleAgentIds = await findAccessibleResources({ + userId, + role: userRole, + resourceType: ResourceType.REMOTE_AGENT, + requiredPermissions: PermissionBits.VIEW, + }); + + // Get the accessible agents + let agents = []; + if (accessibleAgentIds.length > 0) { + agents = await getAgents({ _id: { $in: accessibleAgentIds } }); + } + + const models = agents.map((agent) => ({ + id: agent.id, + object: 'model', + created: Math.floor(new Date(agent.createdAt || Date.now()).getTime() / 1000), + owned_by: 'librechat', + permission: [], + root: agent.id, + parent: null, + // LibreChat extensions + name: agent.name, + description: agent.description, + provider: agent.provider, + })); + + res.json({ + object: 'list', + data: models, + }); + } catch (error) { + const errorMessage = error instanceof Error ? error.message : 'Failed to list models'; + logger.error('[OpenAI API] Error listing models:', error); + sendErrorResponse(res, 500, errorMessage, 'server_error'); + } +}; + +/** + * Get a specific model/agent (with remote access permission check) + * + * GET /v1/models/:model + */ +const GetModelController = async (req, res) => { + try { + const { model } = req.params; + const userId = req.user?.id; + const userRole = req.user?.role; + + if (!userId) { + return sendErrorResponse(res, 401, 'Authentication required', 'auth_error'); + } + + const agent = await getAgent({ id: model }); + + if (!agent) { + return sendErrorResponse( + res, + 404, + `Model not found: ${model}`, + 'invalid_request_error', + 'model_not_found', + ); + } + + // Check if user has remote access to this agent + const accessibleAgentIds = await findAccessibleResources({ + userId, + role: userRole, + resourceType: ResourceType.REMOTE_AGENT, + requiredPermissions: PermissionBits.VIEW, + }); + + const hasAccess = accessibleAgentIds.some((id) => id.toString() === agent._id.toString()); + + if (!hasAccess) { + return sendErrorResponse( + res, + 403, + `No remote access to model: ${model}`, + 'permission_error', + 'access_denied', + ); + } + + res.json({ + id: agent.id, + object: 'model', + created: Math.floor(new Date(agent.createdAt || Date.now()).getTime() / 1000), + owned_by: 'librechat', + permission: [], + root: agent.id, + parent: null, + // LibreChat extensions + name: agent.name, + description: agent.description, + provider: agent.provider, + }); + } catch (error) { + const errorMessage = error instanceof Error ? error.message : 'Failed to get model'; + logger.error('[OpenAI API] Error getting model:', error); + sendErrorResponse(res, 500, errorMessage, 'server_error'); + } +}; + +module.exports = { + OpenAIChatCompletionController, + ListModelsController, + GetModelController, +}; diff --git a/api/server/controllers/agents/responses.js b/api/server/controllers/agents/responses.js new file mode 100644 index 000000000000..bf52edcf7d97 --- /dev/null +++ b/api/server/controllers/agents/responses.js @@ -0,0 +1,800 @@ +const { nanoid } = require('nanoid'); +const { v4: uuidv4 } = require('uuid'); +const { logger } = require('@librechat/data-schemas'); +const { EModelEndpoint, ResourceType, PermissionBits } = require('librechat-data-provider'); +const { + Callback, + ToolEndHandler, + formatAgentMessages, + ChatModelStreamHandler, +} = require('@librechat/agents'); +const { + createRun, + createSafeUser, + initializeAgent, + // Responses API + writeDone, + buildResponse, + generateResponseId, + isValidationFailure, + emitResponseCreated, + createResponseContext, + createResponseTracker, + setupStreamingResponse, + emitResponseInProgress, + convertInputToMessages, + validateResponseRequest, + buildAggregatedResponse, + createResponseAggregator, + sendResponsesErrorResponse, + createResponsesEventHandlers, + createAggregatorEventHandlers, +} = require('@librechat/api'); +const { + createResponsesToolEndCallback, + createToolEndCallback, +} = require('~/server/controllers/agents/callbacks'); +const { findAccessibleResources } = require('~/server/services/PermissionService'); +const { getConvoFiles, saveConvo, getConvo } = require('~/models/Conversation'); +const { loadAgentTools } = require('~/server/services/ToolService'); +const { getAgent, getAgents } = require('~/models/Agent'); +const db = require('~/models'); + +/** @type {import('@librechat/api').AppConfig | null} */ +let appConfig = null; + +/** + * Set the app config for the controller + * @param {import('@librechat/api').AppConfig} config + */ +function setAppConfig(config) { + appConfig = config; +} + +/** + * Creates a tool loader function for the agent. + * @param {AbortSignal} signal - The abort signal + */ +function createToolLoader(signal) { + return async function loadTools({ + req, + res, + tools, + model, + agentId, + provider, + tool_options, + tool_resources, + }) { + const agent = { id: agentId, tools, provider, model, tool_options }; + try { + return await loadAgentTools({ + req, + res, + agent, + signal, + tool_resources, + streamId: null, + }); + } catch (error) { + logger.error('Error loading tools for agent ' + agentId, error); + } + }; +} + +/** + * Convert Open Responses input items to internal messages + * @param {import('@librechat/api').InputItem[]} input + * @returns {Array} Internal messages + */ +function convertToInternalMessages(input) { + return convertInputToMessages(input); +} + +/** + * Load messages from a previous response/conversation + * @param {string} conversationId - The conversation/response ID + * @param {string} userId - The user ID + * @returns {Promise} Messages from the conversation + */ +async function loadPreviousMessages(conversationId, userId) { + try { + const messages = await db.getMessages({ conversationId, user: userId }); + if (!messages || messages.length === 0) { + return []; + } + + // Convert stored messages to internal format + return messages.map((msg) => { + const internalMsg = { + role: msg.isCreatedByUser ? 'user' : 'assistant', + content: '', + messageId: msg.messageId, + }; + + // Handle content - could be string or array + if (typeof msg.text === 'string') { + internalMsg.content = msg.text; + } else if (Array.isArray(msg.content)) { + // Handle content parts + internalMsg.content = msg.content; + } else if (msg.text) { + internalMsg.content = String(msg.text); + } + + return internalMsg; + }); + } catch (error) { + logger.error('[Responses API] Error loading previous messages:', error); + return []; + } +} + +/** + * Save input messages to database + * @param {import('express').Request} req + * @param {string} conversationId + * @param {Array} inputMessages - Internal format messages + * @param {string} agentId + * @returns {Promise} + */ +async function saveInputMessages(req, conversationId, inputMessages, agentId) { + for (const msg of inputMessages) { + if (msg.role === 'user') { + await db.saveMessage( + req, + { + messageId: msg.messageId || nanoid(), + conversationId, + parentMessageId: null, + isCreatedByUser: true, + text: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content), + sender: 'User', + endpoint: EModelEndpoint.agents, + model: agentId, + }, + { context: 'Responses API - save user input' }, + ); + } + } +} + +/** + * Save response output to database + * @param {import('express').Request} req + * @param {string} conversationId + * @param {string} responseId + * @param {import('@librechat/api').Response} response + * @param {string} agentId + * @returns {Promise} + */ +async function saveResponseOutput(req, conversationId, responseId, response, agentId) { + // Extract text content from output items + let responseText = ''; + for (const item of response.output) { + if (item.type === 'message' && item.content) { + for (const part of item.content) { + if (part.type === 'output_text' && part.text) { + responseText += part.text; + } + } + } + } + + // Save the assistant message + await db.saveMessage( + req, + { + messageId: responseId, + conversationId, + parentMessageId: null, + isCreatedByUser: false, + text: responseText, + sender: 'Agent', + endpoint: EModelEndpoint.agents, + model: agentId, + finish_reason: response.status === 'completed' ? 'stop' : response.status, + tokenCount: response.usage?.output_tokens, + }, + { context: 'Responses API - save assistant response' }, + ); +} + +/** + * Save or update conversation + * @param {import('express').Request} req + * @param {string} conversationId + * @param {string} agentId + * @param {object} agent + * @returns {Promise} + */ +async function saveConversation(req, conversationId, agentId, agent) { + await saveConvo( + req, + { + conversationId, + endpoint: EModelEndpoint.agents, + agentId, + title: agent?.name || 'Open Responses Conversation', + model: agent?.model, + }, + { context: 'Responses API - save conversation' }, + ); +} + +/** + * Convert stored messages to Open Responses output format + * @param {Array} messages - Stored messages + * @returns {Array} Output items + */ +function convertMessagesToOutputItems(messages) { + const output = []; + + for (const msg of messages) { + if (!msg.isCreatedByUser) { + output.push({ + type: 'message', + id: msg.messageId, + role: 'assistant', + status: 'completed', + content: [ + { + type: 'output_text', + text: msg.text || '', + annotations: [], + }, + ], + }); + } + } + + return output; +} + +/** + * Create Response - POST /v1/responses + * + * Creates a model response following the Open Responses API specification. + * Supports both streaming and non-streaming responses. + * + * @param {import('express').Request} req + * @param {import('express').Response} res + */ +const createResponse = async (req, res) => { + // Validate request + const validation = validateResponseRequest(req.body); + if (isValidationFailure(validation)) { + return sendResponsesErrorResponse(res, 400, validation.error); + } + + const request = validation.request; + const agentId = request.model; + const isStreaming = request.stream === true; + + // Look up the agent + const agent = await getAgent({ id: agentId }); + if (!agent) { + return sendResponsesErrorResponse( + res, + 404, + `Agent not found: ${agentId}`, + 'not_found', + 'model_not_found', + ); + } + + // Generate IDs + const responseId = generateResponseId(); + const conversationId = request.previous_response_id ?? uuidv4(); + const parentMessageId = null; + + // Create response context + const context = createResponseContext(request, responseId); + + // Set up abort controller + const abortController = new AbortController(); + + // Handle client disconnect + req.on('close', () => { + if (!abortController.signal.aborted) { + abortController.abort(); + logger.debug('[Responses API] Client disconnected, aborting'); + } + }); + + try { + // Build allowed providers set + const allowedProviders = new Set( + appConfig?.endpoints?.[EModelEndpoint.agents]?.allowedProviders, + ); + + // Create tool loader + const loadTools = createToolLoader(abortController.signal); + + // Initialize the agent first to check for disableStreaming + const endpointOption = { + endpoint: agent.provider, + model_parameters: agent.model_parameters ?? {}, + }; + + const primaryConfig = await initializeAgent( + { + req, + res, + loadTools, + requestFiles: [], + conversationId, + parentMessageId, + agent, + endpointOption, + allowedProviders, + isInitialAgent: true, + }, + { + getConvoFiles, + getFiles: db.getFiles, + getUserKey: db.getUserKey, + getMessages: db.getMessages, + updateFilesUsage: db.updateFilesUsage, + getUserKeyValues: db.getUserKeyValues, + getUserCodeFiles: db.getUserCodeFiles, + getToolFilesByIds: db.getToolFilesByIds, + getCodeGeneratedFiles: db.getCodeGeneratedFiles, + }, + ); + + // Determine if streaming is enabled (check both request and agent config) + const streamingDisabled = !!primaryConfig.model_parameters?.disableStreaming; + const actuallyStreaming = isStreaming && !streamingDisabled; + + // Load previous messages if previous_response_id is provided + let previousMessages = []; + if (request.previous_response_id) { + const userId = req.user?.id ?? 'api-user'; + previousMessages = await loadPreviousMessages(request.previous_response_id, userId); + } + + // Convert input to internal messages + const inputMessages = convertToInternalMessages( + typeof request.input === 'string' ? request.input : request.input, + ); + + // Merge previous messages with new input + const allMessages = [...previousMessages, ...inputMessages]; + + // Format for agent + const toolSet = new Set((primaryConfig.tools ?? []).map((tool) => tool && tool.name)); + const { messages: formattedMessages, indexTokenCountMap } = formatAgentMessages( + allMessages, + {}, + toolSet, + ); + + // Create tracker for streaming or aggregator for non-streaming + const tracker = actuallyStreaming ? createResponseTracker() : null; + const aggregator = actuallyStreaming ? null : createResponseAggregator(); + + // Set up response for streaming + if (actuallyStreaming) { + setupStreamingResponse(res); + + // Create handler config + const handlerConfig = { + res, + context, + tracker, + }; + + // Emit response.created then response.in_progress per Open Responses spec + emitResponseCreated(handlerConfig); + emitResponseInProgress(handlerConfig); + + // Create event handlers + const { handlers: responsesHandlers, finalizeStream } = + createResponsesEventHandlers(handlerConfig); + + // Built-in handler for processing raw model stream chunks + const chatModelStreamHandler = new ChatModelStreamHandler(); + + // Artifact promises for processing tool outputs + /** @type {Promise[]} */ + const artifactPromises = []; + // Use Responses API-specific callback that emits librechat:attachment events + const toolEndCallback = createResponsesToolEndCallback({ + req, + res, + tracker, + artifactPromises, + }); + + // Combine handlers + const handlers = { + on_chat_model_stream: { + handle: async (event, data, metadata, graph) => { + await chatModelStreamHandler.handle(event, data, metadata, graph); + }, + }, + on_message_delta: responsesHandlers.on_message_delta, + on_reasoning_delta: responsesHandlers.on_reasoning_delta, + on_run_step: responsesHandlers.on_run_step, + on_run_step_delta: responsesHandlers.on_run_step_delta, + on_chat_model_end: responsesHandlers.on_chat_model_end, + on_tool_end: new ToolEndHandler(toolEndCallback, logger), + on_run_step_completed: { handle: () => {} }, + on_chain_stream: { handle: () => {} }, + on_chain_end: { handle: () => {} }, + on_agent_update: { handle: () => {} }, + on_custom_event: { handle: () => {} }, + }; + + // Create and run the agent + const userId = req.user?.id ?? 'api-user'; + const userMCPAuthMap = primaryConfig.userMCPAuthMap; + + const run = await createRun({ + agents: [primaryConfig], + messages: formattedMessages, + indexTokenCountMap, + runId: responseId, + signal: abortController.signal, + customHandlers: handlers, + requestBody: { + messageId: responseId, + conversationId, + }, + user: { id: userId }, + }); + + if (!run) { + throw new Error('Failed to create agent run'); + } + + // Process the stream + const config = { + runName: 'AgentRun', + configurable: { + thread_id: conversationId, + user_id: userId, + user: createSafeUser(req.user), + ...(userMCPAuthMap != null && { userMCPAuthMap }), + }, + signal: abortController.signal, + streamMode: 'values', + version: 'v2', + }; + + await run.processStream({ messages: formattedMessages }, config, { + callbacks: { + [Callback.TOOL_ERROR]: (graph, error, toolId) => { + logger.error(`[Responses API] Tool Error "${toolId}"`, error); + }, + }, + }); + + // Finalize the stream + finalizeStream(); + res.end(); + + // Save to database if store: true + if (request.store === true) { + try { + // Save conversation + await saveConversation(req, conversationId, agentId, agent); + + // Save input messages + await saveInputMessages(req, conversationId, inputMessages, agentId); + + // Build response for saving (use tracker with buildResponse for streaming) + const finalResponse = buildResponse(context, tracker, 'completed'); + await saveResponseOutput(req, conversationId, responseId, finalResponse, agentId); + + logger.debug( + `[Responses API] Stored response ${responseId} in conversation ${conversationId}`, + ); + } catch (saveError) { + logger.error('[Responses API] Error saving response:', saveError); + // Don't fail the request if saving fails + } + } + + // Wait for artifact processing after response ends (non-blocking) + if (artifactPromises.length > 0) { + Promise.all(artifactPromises).catch((artifactError) => { + logger.warn('[Responses API] Error processing artifacts:', artifactError); + }); + } + } else { + // Non-streaming response + const aggregatorHandlers = createAggregatorEventHandlers(aggregator); + + // Built-in handler for processing raw model stream chunks + const chatModelStreamHandler = new ChatModelStreamHandler(); + + // Artifact promises for processing tool outputs + /** @type {Promise[]} */ + const artifactPromises = []; + const toolEndCallback = createToolEndCallback({ req, res, artifactPromises, streamId: null }); + + // Combine handlers + const handlers = { + on_chat_model_stream: { + handle: async (event, data, metadata, graph) => { + await chatModelStreamHandler.handle(event, data, metadata, graph); + }, + }, + on_message_delta: aggregatorHandlers.on_message_delta, + on_reasoning_delta: aggregatorHandlers.on_reasoning_delta, + on_run_step: aggregatorHandlers.on_run_step, + on_run_step_delta: aggregatorHandlers.on_run_step_delta, + on_chat_model_end: aggregatorHandlers.on_chat_model_end, + on_tool_end: new ToolEndHandler(toolEndCallback, logger), + on_run_step_completed: { handle: () => {} }, + on_chain_stream: { handle: () => {} }, + on_chain_end: { handle: () => {} }, + on_agent_update: { handle: () => {} }, + on_custom_event: { handle: () => {} }, + }; + + // Create and run the agent + const userId = req.user?.id ?? 'api-user'; + const userMCPAuthMap = primaryConfig.userMCPAuthMap; + + const run = await createRun({ + agents: [primaryConfig], + messages: formattedMessages, + indexTokenCountMap, + runId: responseId, + signal: abortController.signal, + customHandlers: handlers, + requestBody: { + messageId: responseId, + conversationId, + }, + user: { id: userId }, + }); + + if (!run) { + throw new Error('Failed to create agent run'); + } + + // Process the stream + const config = { + runName: 'AgentRun', + configurable: { + thread_id: conversationId, + user_id: userId, + user: createSafeUser(req.user), + ...(userMCPAuthMap != null && { userMCPAuthMap }), + }, + signal: abortController.signal, + streamMode: 'values', + version: 'v2', + }; + + await run.processStream({ messages: formattedMessages }, config, { + callbacks: { + [Callback.TOOL_ERROR]: (graph, error, toolId) => { + logger.error(`[Responses API] Tool Error "${toolId}"`, error); + }, + }, + }); + + // Wait for artifacts before sending response + if (artifactPromises.length > 0) { + try { + await Promise.all(artifactPromises); + } catch (artifactError) { + logger.warn('[Responses API] Error processing artifacts:', artifactError); + } + } + + // Build and send the response + const response = buildAggregatedResponse(context, aggregator); + + // Save to database if store: true + if (request.store === true) { + try { + // Save conversation + await saveConversation(req, conversationId, agentId, agent); + + // Save input messages + await saveInputMessages(req, conversationId, inputMessages, agentId); + + // Save response output + await saveResponseOutput(req, conversationId, responseId, response, agentId); + + logger.debug( + `[Responses API] Stored response ${responseId} in conversation ${conversationId}`, + ); + } catch (saveError) { + logger.error('[Responses API] Error saving response:', saveError); + // Don't fail the request if saving fails + } + } + + res.json(response); + } + } catch (error) { + const errorMessage = error instanceof Error ? error.message : 'An error occurred'; + logger.error('[Responses API] Error:', error); + + // Check if we already started streaming (headers sent) + if (res.headersSent) { + // Headers already sent, write error event and close + writeDone(res); + res.end(); + } else { + sendResponsesErrorResponse(res, 500, errorMessage, 'server_error'); + } + } +}; + +/** + * List available agents as models - GET /v1/models (also works with /v1/responses/models) + * + * Returns a list of available agents the user has remote access to. + * + * @param {import('express').Request} req + * @param {import('express').Response} res + */ +const listModels = async (req, res) => { + try { + const userId = req.user?.id; + const userRole = req.user?.role; + + if (!userId) { + return sendResponsesErrorResponse(res, 401, 'Authentication required', 'auth_error'); + } + + // Find agents the user has remote access to (VIEW permission on REMOTE_AGENT) + const accessibleAgentIds = await findAccessibleResources({ + userId, + role: userRole, + resourceType: ResourceType.REMOTE_AGENT, + requiredPermissions: PermissionBits.VIEW, + }); + + // Get the accessible agents + let agents = []; + if (accessibleAgentIds.length > 0) { + agents = await getAgents({ _id: { $in: accessibleAgentIds } }); + } + + // Convert to models format + const models = agents.map((agent) => ({ + id: agent.id, + object: 'model', + created: Math.floor(new Date(agent.createdAt).getTime() / 1000), + owned_by: agent.author ?? 'librechat', + // Additional metadata + name: agent.name, + description: agent.description, + provider: agent.provider, + })); + + res.json({ + object: 'list', + data: models, + }); + } catch (error) { + logger.error('[Responses API] Error listing models:', error); + sendResponsesErrorResponse( + res, + 500, + error instanceof Error ? error.message : 'Failed to list models', + 'server_error', + ); + } +}; + +/** + * Get Response - GET /v1/responses/:id + * + * Retrieves a stored response by its ID. + * The response ID maps to a conversationId in LibreChat's storage. + * + * @param {import('express').Request} req + * @param {import('express').Response} res + */ +const getResponse = async (req, res) => { + try { + const responseId = req.params.id; + const userId = req.user?.id; + + if (!responseId) { + return sendResponsesErrorResponse(res, 400, 'Response ID is required'); + } + + // The responseId could be either the response ID or the conversation ID + // Try to find a conversation with this ID + const conversation = await getConvo(userId, responseId); + + if (!conversation) { + return sendResponsesErrorResponse( + res, + 404, + `Response not found: ${responseId}`, + 'not_found', + 'response_not_found', + ); + } + + // Load messages for this conversation + const messages = await db.getMessages({ conversationId: responseId, user: userId }); + + if (!messages || messages.length === 0) { + return sendResponsesErrorResponse( + res, + 404, + `No messages found for response: ${responseId}`, + 'not_found', + 'response_not_found', + ); + } + + // Convert messages to Open Responses output format + const output = convertMessagesToOutputItems(messages); + + // Find the last assistant message for usage info + const lastAssistantMessage = messages.filter((m) => !m.isCreatedByUser).pop(); + + // Build the response object + const response = { + id: responseId, + object: 'response', + created_at: Math.floor(new Date(conversation.createdAt || Date.now()).getTime() / 1000), + completed_at: Math.floor(new Date(conversation.updatedAt || Date.now()).getTime() / 1000), + status: 'completed', + incomplete_details: null, + model: conversation.agentId || conversation.model || 'unknown', + previous_response_id: null, + instructions: null, + output, + error: null, + tools: [], + tool_choice: 'auto', + truncation: 'disabled', + parallel_tool_calls: true, + text: { format: { type: 'text' } }, + temperature: 1, + top_p: 1, + presence_penalty: 0, + frequency_penalty: 0, + top_logprobs: null, + reasoning: null, + user: userId, + usage: lastAssistantMessage?.tokenCount + ? { + input_tokens: 0, + output_tokens: lastAssistantMessage.tokenCount, + total_tokens: lastAssistantMessage.tokenCount, + } + : null, + max_output_tokens: null, + max_tool_calls: null, + store: true, + background: false, + service_tier: 'default', + metadata: {}, + safety_identifier: null, + prompt_cache_key: null, + }; + + res.json(response); + } catch (error) { + logger.error('[Responses API] Error getting response:', error); + sendResponsesErrorResponse( + res, + 500, + error instanceof Error ? error.message : 'Failed to get response', + 'server_error', + ); + } +}; + +module.exports = { + createResponse, + getResponse, + listModels, + setAppConfig, +}; diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js index 9f0a4a227974..34078b225000 100644 --- a/api/server/controllers/agents/v1.js +++ b/api/server/controllers/agents/v1.js @@ -11,7 +11,9 @@ const { convertOcrToContextInPlace, } = require('@librechat/api'); const { + Time, Tools, + CacheKeys, Constants, FileSources, ResourceType, @@ -21,8 +23,6 @@ const { PermissionBits, actionDelimiter, removeNullishValues, - CacheKeys, - Time, } = require('librechat-data-provider'); const { getListAgentsByAccess, @@ -94,16 +94,25 @@ const createAgentHandler = async (req, res) => { const agent = await createAgent(agentData); - // Automatically grant owner permissions to the creator try { - await grantPermission({ - principalType: PrincipalType.USER, - principalId: userId, - resourceType: ResourceType.AGENT, - resourceId: agent._id, - accessRoleId: AccessRoleIds.AGENT_OWNER, - grantedBy: userId, - }); + await Promise.all([ + grantPermission({ + principalType: PrincipalType.USER, + principalId: userId, + resourceType: ResourceType.AGENT, + resourceId: agent._id, + accessRoleId: AccessRoleIds.AGENT_OWNER, + grantedBy: userId, + }), + grantPermission({ + principalType: PrincipalType.USER, + principalId: userId, + resourceType: ResourceType.REMOTE_AGENT, + resourceId: agent._id, + accessRoleId: AccessRoleIds.REMOTE_AGENT_OWNER, + grantedBy: userId, + }), + ]); logger.debug( `[createAgent] Granted owner permissions to user ${userId} for agent ${agent.id}`, ); @@ -396,16 +405,25 @@ const duplicateAgentHandler = async (req, res) => { newAgentData.actions = agentActions; const newAgent = await createAgent(newAgentData); - // Automatically grant owner permissions to the duplicator try { - await grantPermission({ - principalType: PrincipalType.USER, - principalId: userId, - resourceType: ResourceType.AGENT, - resourceId: newAgent._id, - accessRoleId: AccessRoleIds.AGENT_OWNER, - grantedBy: userId, - }); + await Promise.all([ + grantPermission({ + principalType: PrincipalType.USER, + principalId: userId, + resourceType: ResourceType.AGENT, + resourceId: newAgent._id, + accessRoleId: AccessRoleIds.AGENT_OWNER, + grantedBy: userId, + }), + grantPermission({ + principalType: PrincipalType.USER, + principalId: userId, + resourceType: ResourceType.REMOTE_AGENT, + resourceId: newAgent._id, + accessRoleId: AccessRoleIds.REMOTE_AGENT_OWNER, + grantedBy: userId, + }), + ]); logger.debug( `[duplicateAgent] Granted owner permissions to user ${userId} for duplicated agent ${newAgent.id}`, ); diff --git a/api/server/controllers/auth/LogoutController.js b/api/server/controllers/auth/LogoutController.js index ec6631628535..16da12aadde3 100644 --- a/api/server/controllers/auth/LogoutController.js +++ b/api/server/controllers/auth/LogoutController.js @@ -24,6 +24,9 @@ const logoutController = async (req, res) => { res.clearCookie('openid_access_token'); res.clearCookie('openid_user_id'); res.clearCookie('token_provider'); + if (isEnabled(process.env.OPENID_EXPOSE_SUB_COOKIE)) { + res.clearCookie('openid_sub'); + } const response = { message }; if ( isOpenIdUser && diff --git a/api/server/controllers/auth/oauth.js b/api/server/controllers/auth/oauth.js new file mode 100644 index 000000000000..80c2ced0026a --- /dev/null +++ b/api/server/controllers/auth/oauth.js @@ -0,0 +1,79 @@ +const { CacheKeys } = require('librechat-data-provider'); +const { logger, DEFAULT_SESSION_EXPIRY } = require('@librechat/data-schemas'); +const { + isEnabled, + getAdminPanelUrl, + isAdminPanelRedirect, + generateAdminExchangeCode, +} = require('@librechat/api'); +const { syncUserEntraGroupMemberships } = require('~/server/services/PermissionService'); +const { setAuthTokens, setOpenIDAuthTokens } = require('~/server/services/AuthService'); +const getLogStores = require('~/cache/getLogStores'); +const { checkBan } = require('~/server/middleware'); +const { generateToken } = require('~/models'); + +const domains = { + client: process.env.DOMAIN_CLIENT, + server: process.env.DOMAIN_SERVER, +}; + +function createOAuthHandler(redirectUri = domains.client) { + /** + * A handler to process OAuth authentication results. + * @type {Function} + * @param {ServerRequest} req - Express request object. + * @param {ServerResponse} res - Express response object. + * @param {NextFunction} next - Express next middleware function. + */ + return async (req, res, next) => { + try { + if (res.headersSent) { + return; + } + + await checkBan(req, res); + if (req.banned) { + return; + } + + /** Check if this is an admin panel redirect (cross-origin) */ + if (isAdminPanelRedirect(redirectUri, getAdminPanelUrl(), domains.client)) { + /** For admin panel, generate exchange code instead of setting cookies */ + const cache = getLogStores(CacheKeys.ADMIN_OAUTH_EXCHANGE); + const sessionExpiry = Number(process.env.SESSION_EXPIRY) || DEFAULT_SESSION_EXPIRY; + const token = await generateToken(req.user, sessionExpiry); + + /** Get refresh token from tokenset for OpenID users */ + const refreshToken = + req.user.tokenset?.refresh_token || req.user.federatedTokens?.refresh_token; + + const exchangeCode = await generateAdminExchangeCode(cache, req.user, token, refreshToken); + + const callbackUrl = new URL(redirectUri); + callbackUrl.searchParams.set('code', exchangeCode); + logger.info(`[OAuth] Admin panel redirect with exchange code for user: ${req.user.email}`); + return res.redirect(callbackUrl.toString()); + } + + /** Standard OAuth flow - set cookies and redirect */ + if ( + req.user && + req.user.provider == 'openid' && + isEnabled(process.env.OPENID_REUSE_TOKENS) === true + ) { + await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token); + setOpenIDAuthTokens(req.user.tokenset, req, res, req.user._id.toString()); + } else { + await setAuthTokens(req.user._id, res); + } + res.redirect(redirectUri); + } catch (err) { + logger.error('Error in setting authentication tokens:', err); + next(err); + } + }; +} + +module.exports = { + createOAuthHandler, +}; diff --git a/api/server/experimental.js b/api/server/experimental.js index 91ef9ef28671..4a457abf6107 100644 --- a/api/server/experimental.js +++ b/api/server/experimental.js @@ -299,6 +299,7 @@ if (cluster.isMaster) { app.use('/api/auth', routes.auth); app.use('/api/actions', routes.actions); app.use('/api/keys', routes.keys); + app.use('/api/api-keys', routes.apiKeys); app.use('/api/user', routes.user); app.use('/api/search', routes.search); app.use('/api/messages', routes.messages); diff --git a/api/server/index.js b/api/server/index.js index a7ddd47f375d..fcd0229c9f3d 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -134,8 +134,10 @@ const startServer = async () => { app.use('/oauth', routes.oauth); /* API Endpoints */ app.use('/api/auth', routes.auth); + app.use('/api/admin', routes.adminAuth); app.use('/api/actions', routes.actions); app.use('/api/keys', routes.keys); + app.use('/api/api-keys', routes.apiKeys); app.use('/api/user', routes.user); app.use('/api/search', routes.search); app.use('/api/messages', routes.messages); diff --git a/api/server/middleware/checkSharePublicAccess.js b/api/server/middleware/checkSharePublicAccess.js index c094d54acb3a..0e95b9f6f859 100644 --- a/api/server/middleware/checkSharePublicAccess.js +++ b/api/server/middleware/checkSharePublicAccess.js @@ -9,6 +9,7 @@ const resourceToPermissionType = { [ResourceType.AGENT]: PermissionTypes.AGENTS, [ResourceType.PROMPTGROUP]: PermissionTypes.PROMPTS, [ResourceType.MCPSERVER]: PermissionTypes.MCP_SERVERS, + [ResourceType.REMOTE_AGENT]: PermissionTypes.REMOTE_AGENTS, }; /** diff --git a/api/server/routes/accessPermissions.js b/api/server/routes/accessPermissions.js index 79e7f3ddcaca..45afec133b77 100644 --- a/api/server/routes/accessPermissions.js +++ b/api/server/routes/accessPermissions.js @@ -53,6 +53,12 @@ const checkResourcePermissionAccess = (requiredPermission) => (req, res, next) = requiredPermission, resourceIdParam: 'resourceId', }); + } else if (resourceType === ResourceType.REMOTE_AGENT) { + middleware = canAccessResource({ + resourceType: ResourceType.REMOTE_AGENT, + requiredPermission, + resourceIdParam: 'resourceId', + }); } else if (resourceType === ResourceType.PROMPTGROUP) { middleware = canAccessResource({ resourceType: ResourceType.PROMPTGROUP, diff --git a/api/server/routes/admin/auth.js b/api/server/routes/admin/auth.js new file mode 100644 index 000000000000..291b5eaaf85f --- /dev/null +++ b/api/server/routes/admin/auth.js @@ -0,0 +1,127 @@ +const express = require('express'); +const passport = require('passport'); +const { randomState } = require('openid-client'); +const { logger } = require('@librechat/data-schemas'); +const { CacheKeys } = require('librechat-data-provider'); +const { + requireAdmin, + getAdminPanelUrl, + exchangeAdminCode, + createSetBalanceConfig, +} = require('@librechat/api'); +const { loginController } = require('~/server/controllers/auth/LoginController'); +const { createOAuthHandler } = require('~/server/controllers/auth/oauth'); +const { getAppConfig } = require('~/server/services/Config'); +const getLogStores = require('~/cache/getLogStores'); +const { getOpenIdConfig } = require('~/strategies'); +const middleware = require('~/server/middleware'); +const { Balance } = require('~/db/models'); + +const setBalanceConfig = createSetBalanceConfig({ + getAppConfig, + Balance, +}); + +const router = express.Router(); + +router.post( + '/login/local', + middleware.logHeaders, + middleware.loginLimiter, + middleware.checkBan, + middleware.requireLocalAuth, + requireAdmin, + setBalanceConfig, + loginController, +); + +router.get('/verify', middleware.requireJwtAuth, requireAdmin, (req, res) => { + const { password: _p, totpSecret: _t, __v, ...user } = req.user; + user.id = user._id.toString(); + res.status(200).json({ user }); +}); + +router.get('/oauth/openid/check', (req, res) => { + const openidConfig = getOpenIdConfig(); + if (!openidConfig) { + return res.status(404).json({ + error: 'OpenID configuration not found', + error_code: 'OPENID_NOT_CONFIGURED', + }); + } + res.status(200).json({ message: 'OpenID check successful' }); +}); + +router.get('/oauth/openid', (req, res, next) => { + return passport.authenticate('openidAdmin', { + session: false, + state: randomState(), + })(req, res, next); +}); + +router.get( + '/oauth/openid/callback', + passport.authenticate('openidAdmin', { + failureRedirect: `${getAdminPanelUrl()}/auth/openid/callback?error=auth_failed&error_description=Authentication+failed`, + failureMessage: true, + session: false, + }), + requireAdmin, + setBalanceConfig, + middleware.checkDomainAllowed, + createOAuthHandler(`${getAdminPanelUrl()}/auth/openid/callback`), +); + +/** Regex pattern for valid exchange codes: 64 hex characters */ +const EXCHANGE_CODE_PATTERN = /^[a-f0-9]{64}$/i; + +/** + * Exchange OAuth authorization code for tokens. + * This endpoint is called server-to-server by the admin panel. + * The code is one-time-use and expires in 30 seconds. + * + * POST /api/admin/oauth/exchange + * Body: { code: string } + * Response: { token: string, refreshToken: string, user: object } + */ +router.post('/oauth/exchange', middleware.loginLimiter, async (req, res) => { + try { + const { code } = req.body; + + if (!code) { + logger.warn('[admin/oauth/exchange] Missing authorization code'); + return res.status(400).json({ + error: 'Missing authorization code', + error_code: 'MISSING_CODE', + }); + } + + if (typeof code !== 'string' || !EXCHANGE_CODE_PATTERN.test(code)) { + logger.warn('[admin/oauth/exchange] Invalid authorization code format'); + return res.status(400).json({ + error: 'Invalid authorization code format', + error_code: 'INVALID_CODE_FORMAT', + }); + } + + const cache = getLogStores(CacheKeys.ADMIN_OAUTH_EXCHANGE); + const result = await exchangeAdminCode(cache, code); + + if (!result) { + return res.status(401).json({ + error: 'Invalid or expired authorization code', + error_code: 'INVALID_OR_EXPIRED_CODE', + }); + } + + res.json(result); + } catch (error) { + logger.error('[admin/oauth/exchange] Error:', error); + res.status(500).json({ + error: 'Internal server error', + error_code: 'INTERNAL_ERROR', + }); + } +}); + +module.exports = router; diff --git a/api/server/routes/agents/__tests__/abort.spec.js b/api/server/routes/agents/__tests__/abort.spec.js index e879d51452b5..442665d97376 100644 --- a/api/server/routes/agents/__tests__/abort.spec.js +++ b/api/server/routes/agents/__tests__/abort.spec.js @@ -26,10 +26,12 @@ const mockGenerationJobManager = { const mockSaveMessage = jest.fn(); jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), logger: mockLogger, })); jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), isEnabled: jest.fn().mockReturnValue(false), GenerationJobManager: mockGenerationJobManager, })); diff --git a/api/server/routes/agents/__tests__/responses.spec.js b/api/server/routes/agents/__tests__/responses.spec.js new file mode 100644 index 000000000000..4d83219b84b3 --- /dev/null +++ b/api/server/routes/agents/__tests__/responses.spec.js @@ -0,0 +1,1125 @@ +/** + * Open Responses API Integration Tests + * + * Tests the /v1/responses endpoint against the Open Responses specification + * compliance tests. Uses real Anthropic API for LLM calls. + * + * @see https://openresponses.org/specification + * @see https://github.com/openresponses/openresponses/blob/main/src/lib/compliance-tests.ts + */ + +// Load environment variables from root .env file for API keys +require('dotenv').config({ path: require('path').resolve(__dirname, '../../../../../.env') }); + +const originalEnv = { + CREDS_KEY: process.env.CREDS_KEY, + CREDS_IV: process.env.CREDS_IV, +}; + +process.env.CREDS_KEY = '0123456789abcdef0123456789abcdef'; +process.env.CREDS_IV = '0123456789abcdef'; + +/** Skip tests if ANTHROPIC_API_KEY is not available */ +const SKIP_INTEGRATION_TESTS = !process.env.ANTHROPIC_API_KEY; +if (SKIP_INTEGRATION_TESTS) { + console.warn('ANTHROPIC_API_KEY not found - skipping integration tests'); +} + +jest.mock('meilisearch', () => ({ + MeiliSearch: jest.fn().mockImplementation(() => ({ + getIndex: jest.fn().mockRejectedValue(new Error('mocked')), + index: jest.fn().mockReturnValue({ + getRawInfo: jest.fn().mockResolvedValue({ primaryKey: 'id' }), + updateSettings: jest.fn().mockResolvedValue({}), + addDocuments: jest.fn().mockResolvedValue({}), + updateDocuments: jest.fn().mockResolvedValue({}), + deleteDocument: jest.fn().mockResolvedValue({}), + }), + })), +})); + +jest.mock('~/server/services/Config', () => ({ + loadCustomConfig: jest.fn(() => Promise.resolve({})), + getAppConfig: jest.fn().mockResolvedValue({ + paths: { + uploads: '/tmp', + dist: '/tmp/dist', + fonts: '/tmp/fonts', + assets: '/tmp/assets', + }, + fileStrategy: 'local', + imageOutputType: 'PNG', + endpoints: { + agents: { + allowedProviders: ['anthropic', 'openAI'], + }, + }, + }), + setCachedTools: jest.fn(), + getCachedTools: jest.fn(), + getMCPServerTools: jest.fn().mockReturnValue([]), +})); + +jest.mock('~/app/clients/tools', () => ({ + createOpenAIImageTools: jest.fn(() => []), + createYouTubeTools: jest.fn(() => []), + manifestToolMap: {}, + toolkits: [], +})); + +jest.mock('~/config', () => ({ + createMCPServersRegistry: jest.fn(), + createMCPManager: jest.fn().mockResolvedValue({ + getAppToolFunctions: jest.fn().mockResolvedValue({}), + }), +})); + +const express = require('express'); +const request = require('supertest'); +const mongoose = require('mongoose'); +const { v4: uuidv4 } = require('uuid'); +const { MongoMemoryServer } = require('mongodb-memory-server'); +const { hashToken, getRandomValues, createModels } = require('@librechat/data-schemas'); +const { + SystemRoles, + ResourceType, + AccessRoleIds, + PrincipalType, + PrincipalModel, + PermissionBits, + EModelEndpoint, +} = require('librechat-data-provider'); + +/** @type {import('mongoose').Model} */ +let Agent; +/** @type {import('mongoose').Model} */ +let AgentApiKey; +/** @type {import('mongoose').Model} */ +let User; +/** @type {import('mongoose').Model} */ +let AclEntry; +/** @type {import('mongoose').Model} */ +let AccessRole; + +/** + * Parse SSE stream into events + * @param {string} text - Raw SSE text + * @returns {Array<{event: string, data: unknown}>} + */ +function parseSSEEvents(text) { + const events = []; + const lines = text.split('\n'); + + let currentEvent = ''; + let currentData = ''; + + for (const line of lines) { + if (line.startsWith('event:')) { + currentEvent = line.slice(6).trim(); + } else if (line.startsWith('data:')) { + currentData = line.slice(5).trim(); + } else if (line === '' && currentData) { + if (currentData === '[DONE]') { + events.push({ event: 'done', data: '[DONE]' }); + } else { + try { + const parsed = JSON.parse(currentData); + events.push({ + event: currentEvent || parsed.type || 'unknown', + data: parsed, + }); + } catch { + // Skip unparseable data + } + } + currentEvent = ''; + currentData = ''; + } + } + + return events; +} + +/** + * Valid streaming event types per Open Responses specification + * @see https://github.com/openresponses/openresponses/blob/main/src/lib/sse-parser.ts + */ +const VALID_STREAMING_EVENT_TYPES = new Set([ + // Standard Open Responses events + 'response.created', + 'response.queued', + 'response.in_progress', + 'response.completed', + 'response.failed', + 'response.incomplete', + 'response.output_item.added', + 'response.output_item.done', + 'response.content_part.added', + 'response.content_part.done', + 'response.output_text.delta', + 'response.output_text.done', + 'response.refusal.delta', + 'response.refusal.done', + 'response.function_call_arguments.delta', + 'response.function_call_arguments.done', + 'response.reasoning_summary_part.added', + 'response.reasoning_summary_part.done', + 'response.reasoning.delta', + 'response.reasoning.done', + 'response.reasoning_summary_text.delta', + 'response.reasoning_summary_text.done', + 'response.output_text.annotation.added', + 'error', + // LibreChat extension events (prefixed per Open Responses spec) + // @see https://openresponses.org/specification#extending-streaming-events + 'librechat:attachment', +]); + +/** + * Validate a streaming event against Open Responses spec + * @param {Object} event - Parsed event with data + * @returns {string[]} Array of validation errors + */ +function validateStreamingEvent(event) { + const errors = []; + const data = event.data; + + if (!data || typeof data !== 'object') { + return errors; // Skip non-object data (e.g., [DONE]) + } + + const eventType = data.type; + + // Check event type is valid + if (!VALID_STREAMING_EVENT_TYPES.has(eventType)) { + errors.push(`Invalid event type: ${eventType}`); + return errors; + } + + // Validate required fields based on event type + switch (eventType) { + case 'response.output_text.delta': + if (typeof data.sequence_number !== 'number') { + errors.push('response.output_text.delta: missing sequence_number'); + } + if (typeof data.item_id !== 'string') { + errors.push('response.output_text.delta: missing item_id'); + } + if (typeof data.output_index !== 'number') { + errors.push('response.output_text.delta: missing output_index'); + } + if (typeof data.content_index !== 'number') { + errors.push('response.output_text.delta: missing content_index'); + } + if (typeof data.delta !== 'string') { + errors.push('response.output_text.delta: missing delta'); + } + if (!Array.isArray(data.logprobs)) { + errors.push('response.output_text.delta: missing logprobs array'); + } + break; + + case 'response.output_text.done': + if (typeof data.sequence_number !== 'number') { + errors.push('response.output_text.done: missing sequence_number'); + } + if (typeof data.item_id !== 'string') { + errors.push('response.output_text.done: missing item_id'); + } + if (typeof data.output_index !== 'number') { + errors.push('response.output_text.done: missing output_index'); + } + if (typeof data.content_index !== 'number') { + errors.push('response.output_text.done: missing content_index'); + } + if (typeof data.text !== 'string') { + errors.push('response.output_text.done: missing text'); + } + if (!Array.isArray(data.logprobs)) { + errors.push('response.output_text.done: missing logprobs array'); + } + break; + + case 'response.reasoning.delta': + if (typeof data.sequence_number !== 'number') { + errors.push('response.reasoning.delta: missing sequence_number'); + } + if (typeof data.item_id !== 'string') { + errors.push('response.reasoning.delta: missing item_id'); + } + if (typeof data.output_index !== 'number') { + errors.push('response.reasoning.delta: missing output_index'); + } + if (typeof data.content_index !== 'number') { + errors.push('response.reasoning.delta: missing content_index'); + } + if (typeof data.delta !== 'string') { + errors.push('response.reasoning.delta: missing delta'); + } + break; + + case 'response.reasoning.done': + if (typeof data.sequence_number !== 'number') { + errors.push('response.reasoning.done: missing sequence_number'); + } + if (typeof data.item_id !== 'string') { + errors.push('response.reasoning.done: missing item_id'); + } + if (typeof data.output_index !== 'number') { + errors.push('response.reasoning.done: missing output_index'); + } + if (typeof data.content_index !== 'number') { + errors.push('response.reasoning.done: missing content_index'); + } + if (typeof data.text !== 'string') { + errors.push('response.reasoning.done: missing text'); + } + break; + + case 'response.in_progress': + case 'response.completed': + case 'response.failed': + if (!data.response || typeof data.response !== 'object') { + errors.push(`${eventType}: missing response object`); + } + break; + + case 'response.output_item.added': + case 'response.output_item.done': + if (typeof data.output_index !== 'number') { + errors.push(`${eventType}: missing output_index`); + } + if (!data.item || typeof data.item !== 'object') { + errors.push(`${eventType}: missing item object`); + } + break; + } + + return errors; +} + +/** + * Validate all streaming events and return errors + * @param {Array} events - Array of parsed events + * @returns {string[]} Array of all validation errors + */ +function validateAllStreamingEvents(events) { + const allErrors = []; + for (const event of events) { + const errors = validateStreamingEvent(event); + allErrors.push(...errors); + } + return allErrors; +} + +/** + * Create a test agent with Anthropic provider + * @param {Object} overrides + * @returns {Promise} + */ +async function createTestAgent(overrides = {}) { + const timestamp = new Date(); + const agentData = { + id: `agent_${uuidv4().replace(/-/g, '').substring(0, 21)}`, + name: 'Test Anthropic Agent', + description: 'An agent for testing Open Responses API', + instructions: 'You are a helpful assistant. Be concise.', + provider: EModelEndpoint.anthropic, + model: 'claude-sonnet-4-5-20250929', + author: new mongoose.Types.ObjectId(), + tools: [], + model_parameters: {}, + ...overrides, + }; + + const versionData = { ...agentData }; + delete versionData.author; + + const initialAgentData = { + ...agentData, + versions: [ + { + ...versionData, + createdAt: timestamp, + updatedAt: timestamp, + }, + ], + category: 'general', + }; + + return (await Agent.create(initialAgentData)).toObject(); +} + +/** + * Create an agent with extended thinking enabled + * @param {Object} overrides + * @returns {Promise} + */ +async function createThinkingAgent(overrides = {}) { + return createTestAgent({ + name: 'Test Thinking Agent', + description: 'An agent with extended thinking enabled', + model_parameters: { + thinking: { + type: 'enabled', + budget_tokens: 5000, + }, + }, + ...overrides, + }); +} + +const describeWithApiKey = SKIP_INTEGRATION_TESTS ? describe.skip : describe; + +describeWithApiKey('Open Responses API Integration Tests', () => { + // Increase timeout for real API calls + jest.setTimeout(120000); + + let mongoServer; + let app; + let testAgent; + let thinkingAgent; + let testUser; + let testApiKey; // The raw API key for Authorization header + + afterAll(() => { + process.env.CREDS_KEY = originalEnv.CREDS_KEY; + process.env.CREDS_IV = originalEnv.CREDS_IV; + }); + + beforeAll(async () => { + // Start MongoDB Memory Server + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + + // Connect to MongoDB + await mongoose.connect(mongoUri); + + // Register all models + const models = createModels(mongoose); + + // Get models + Agent = models.Agent; + AgentApiKey = models.AgentApiKey; + User = models.User; + AclEntry = models.AclEntry; + AccessRole = models.AccessRole; + + // Create minimal Express app with just the responses routes + app = express(); + app.use(express.json()); + + // Mount the responses routes + const responsesRoutes = require('~/server/routes/agents/responses'); + app.use('/api/agents/v1/responses', responsesRoutes); + + // Create test user + testUser = await User.create({ + name: 'Test API User', + username: 'testapiuser', + email: 'testapiuser@test.com', + emailVerified: true, + provider: 'local', + role: SystemRoles.ADMIN, + }); + + // Create REMOTE_AGENT access roles (if they don't exist) + const existingRoles = await AccessRole.find({ + accessRoleId: { + $in: [ + AccessRoleIds.REMOTE_AGENT_VIEWER, + AccessRoleIds.REMOTE_AGENT_EDITOR, + AccessRoleIds.REMOTE_AGENT_OWNER, + ], + }, + }); + + if (existingRoles.length === 0) { + await AccessRole.create([ + { + accessRoleId: AccessRoleIds.REMOTE_AGENT_VIEWER, + name: 'API Viewer', + description: 'Can query the agent via API', + resourceType: ResourceType.REMOTE_AGENT, + permBits: PermissionBits.VIEW, + }, + { + accessRoleId: AccessRoleIds.REMOTE_AGENT_EDITOR, + name: 'API Editor', + description: 'Can view and modify the agent via API', + resourceType: ResourceType.REMOTE_AGENT, + permBits: PermissionBits.VIEW | PermissionBits.EDIT, + }, + { + accessRoleId: AccessRoleIds.REMOTE_AGENT_OWNER, + name: 'API Owner', + description: 'Full API access + can grant remote access to others', + resourceType: ResourceType.REMOTE_AGENT, + permBits: + PermissionBits.VIEW | + PermissionBits.EDIT | + PermissionBits.DELETE | + PermissionBits.SHARE, + }, + ]); + } + + // Generate and create an API key for the test user + const rawKey = `sk-${await getRandomValues(32)}`; + const keyHash = await hashToken(rawKey); + const keyPrefix = rawKey.substring(0, 8); + + await AgentApiKey.create({ + userId: testUser._id, + name: 'Test API Key', + keyHash, + keyPrefix, + }); + + testApiKey = rawKey; + + // Create test agents with the test user as author + testAgent = await createTestAgent({ author: testUser._id }); + thinkingAgent = await createThinkingAgent({ author: testUser._id }); + + // Grant REMOTE_AGENT permissions for the test agents + await AclEntry.create([ + { + principalType: PrincipalType.USER, + principalModel: PrincipalModel.USER, + principalId: testUser._id, + resourceType: ResourceType.REMOTE_AGENT, + resourceId: testAgent._id, + accessRoleId: AccessRoleIds.REMOTE_AGENT_OWNER, + permBits: + PermissionBits.VIEW | PermissionBits.EDIT | PermissionBits.DELETE | PermissionBits.SHARE, + }, + { + principalType: PrincipalType.USER, + principalModel: PrincipalModel.USER, + principalId: testUser._id, + resourceType: ResourceType.REMOTE_AGENT, + resourceId: thinkingAgent._id, + accessRoleId: AccessRoleIds.REMOTE_AGENT_OWNER, + permBits: + PermissionBits.VIEW | PermissionBits.EDIT | PermissionBits.DELETE | PermissionBits.SHARE, + }, + ]); + }, 60000); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + // Clean up any test data between tests if needed + }); + + /* =========================================================================== + * COMPLIANCE TESTS + * Based on: https://github.com/openresponses/openresponses/blob/main/src/lib/compliance-tests.ts + * =========================================================================== */ + + /** Helper to add auth header to requests */ + const authRequest = () => ({ + post: (url) => request(app).post(url).set('Authorization', `Bearer ${testApiKey}`), + get: (url) => request(app).get(url).set('Authorization', `Bearer ${testApiKey}`), + }); + + describe('Compliance Tests', () => { + describe('basic-response', () => { + it('should return a valid ResponseResource for a simple text request', async () => { + const response = await authRequest() + .post('/api/agents/v1/responses') + .send({ + model: testAgent.id, + input: [ + { + type: 'message', + role: 'user', + content: 'Say hello in exactly 3 words.', + }, + ], + }); + + expect(response.status).toBe(200); + expect(response.body).toBeDefined(); + + // Validate ResponseResource schema + const body = response.body; + expect(body.id).toMatch(/^resp_/); + expect(body.object).toBe('response'); + expect(typeof body.created_at).toBe('number'); + expect(body.status).toBe('completed'); + expect(body.model).toBe(testAgent.id); + + // Validate output + expect(Array.isArray(body.output)).toBe(true); + expect(body.output.length).toBeGreaterThan(0); + + // Should have at least one message item + const messageItem = body.output.find((item) => item.type === 'message'); + expect(messageItem).toBeDefined(); + expect(messageItem.role).toBe('assistant'); + expect(messageItem.status).toBe('completed'); + expect(Array.isArray(messageItem.content)).toBe(true); + }); + }); + + describe('streaming-response', () => { + it('should return valid SSE streaming events', async () => { + const response = await authRequest() + .post('/api/agents/v1/responses') + .send({ + model: testAgent.id, + input: [ + { + type: 'message', + role: 'user', + content: 'Count from 1 to 5.', + }, + ], + stream: true, + }) + .buffer(true) + .parse((res, callback) => { + let data = ''; + res.on('data', (chunk) => { + data += chunk.toString(); + }); + res.on('end', () => { + callback(null, data); + }); + }); + + expect(response.status).toBe(200); + expect(response.headers['content-type']).toMatch(/text\/event-stream/); + + const events = parseSSEEvents(response.body); + expect(events.length).toBeGreaterThan(0); + + // Validate all streaming events against Open Responses spec + // This catches issues like: + // - Invalid event types (e.g., response.reasoning_text.delta instead of response.reasoning.delta) + // - Missing required fields (e.g., logprobs on output_text events) + const validationErrors = validateAllStreamingEvents(events); + if (validationErrors.length > 0) { + console.error('Streaming event validation errors:', validationErrors); + } + expect(validationErrors).toEqual([]); + + // Validate streaming event types + const eventTypes = events.map((e) => e.event); + + // Should have response.created first (per Open Responses spec) + expect(eventTypes).toContain('response.created'); + + // Should have response.in_progress + expect(eventTypes).toContain('response.in_progress'); + + // response.created should come before response.in_progress + const createdIdx = eventTypes.indexOf('response.created'); + const inProgressIdx = eventTypes.indexOf('response.in_progress'); + expect(createdIdx).toBeLessThan(inProgressIdx); + + // Should have response.completed or response.failed + expect(eventTypes.some((t) => t === 'response.completed' || t === 'response.failed')).toBe( + true, + ); + + // Should have [DONE] + expect(eventTypes).toContain('done'); + + // Validate response.completed has full response + const completedEvent = events.find((e) => e.event === 'response.completed'); + if (completedEvent) { + expect(completedEvent.data.response).toBeDefined(); + expect(completedEvent.data.response.status).toBe('completed'); + expect(completedEvent.data.response.output.length).toBeGreaterThan(0); + } + }); + + it('should emit valid event types per Open Responses spec', async () => { + const response = await authRequest() + .post('/api/agents/v1/responses') + .send({ + model: testAgent.id, + input: [ + { + type: 'message', + role: 'user', + content: 'Say hi.', + }, + ], + stream: true, + }) + .buffer(true) + .parse((res, callback) => { + let data = ''; + res.on('data', (chunk) => { + data += chunk.toString(); + }); + res.on('end', () => { + callback(null, data); + }); + }); + + expect(response.status).toBe(200); + + const events = parseSSEEvents(response.body); + + // Check all event types are valid + for (const event of events) { + if (event.data && typeof event.data === 'object' && event.data.type) { + expect(VALID_STREAMING_EVENT_TYPES.has(event.data.type)).toBe(true); + } + } + }); + + it('should include logprobs array in output_text events', async () => { + const response = await authRequest() + .post('/api/agents/v1/responses') + .send({ + model: testAgent.id, + input: [ + { + type: 'message', + role: 'user', + content: 'Say one word.', + }, + ], + stream: true, + }) + .buffer(true) + .parse((res, callback) => { + let data = ''; + res.on('data', (chunk) => { + data += chunk.toString(); + }); + res.on('end', () => { + callback(null, data); + }); + }); + + expect(response.status).toBe(200); + + const events = parseSSEEvents(response.body); + + // Find output_text delta/done events and verify logprobs + const textDeltaEvents = events.filter( + (e) => e.data && e.data.type === 'response.output_text.delta', + ); + const textDoneEvents = events.filter( + (e) => e.data && e.data.type === 'response.output_text.done', + ); + + // Should have at least one output_text event + expect(textDeltaEvents.length + textDoneEvents.length).toBeGreaterThan(0); + + // All output_text.delta events must have logprobs array + for (const event of textDeltaEvents) { + expect(Array.isArray(event.data.logprobs)).toBe(true); + } + + // All output_text.done events must have logprobs array + for (const event of textDoneEvents) { + expect(Array.isArray(event.data.logprobs)).toBe(true); + } + }); + }); + + describe('system-prompt', () => { + it('should handle developer role messages in input (as system)', async () => { + // Note: For Anthropic, system messages must be first and there can only be one. + // Since the agent already has instructions, we use 'developer' role which + // gets merged into the system prompt, or we test with a simple user message + // that instructs the behavior. + const response = await authRequest() + .post('/api/agents/v1/responses') + .send({ + model: testAgent.id, + input: [ + { + type: 'message', + role: 'user', + content: 'Pretend you are a pirate and say hello in pirate speak.', + }, + ], + }); + + expect(response.status).toBe(200); + expect(response.body.status).toBe('completed'); + expect(response.body.output.length).toBeGreaterThan(0); + + // The response should reflect the pirate persona + const messageItem = response.body.output.find((item) => item.type === 'message'); + expect(messageItem).toBeDefined(); + expect(messageItem.content.length).toBeGreaterThan(0); + }); + }); + + describe('multi-turn', () => { + it('should handle multi-turn conversation history', async () => { + const response = await authRequest() + .post('/api/agents/v1/responses') + .send({ + model: testAgent.id, + input: [ + { + type: 'message', + role: 'user', + content: 'My name is Alice.', + }, + { + type: 'message', + role: 'assistant', + content: 'Hello Alice! Nice to meet you. How can I help you today?', + }, + { + type: 'message', + role: 'user', + content: 'What is my name?', + }, + ], + }); + + expect(response.status).toBe(200); + expect(response.body.status).toBe('completed'); + + // The response should reference "Alice" + const messageItem = response.body.output.find((item) => item.type === 'message'); + expect(messageItem).toBeDefined(); + + const textContent = messageItem.content.find((c) => c.type === 'output_text'); + expect(textContent).toBeDefined(); + expect(textContent.text.toLowerCase()).toContain('alice'); + }); + }); + + // Note: tool-calling test requires tool setup which may need additional configuration + // Note: image-input test requires vision-capable model + + describe('string-input', () => { + it('should accept simple string input', async () => { + const response = await authRequest().post('/api/agents/v1/responses').send({ + model: testAgent.id, + input: 'Hello!', + }); + + expect(response.status).toBe(200); + expect(response.body.status).toBe('completed'); + expect(response.body.output.length).toBeGreaterThan(0); + }); + }); + }); + + /* =========================================================================== + * EXTENDED THINKING TESTS + * Tests reasoning output from Claude models with extended thinking enabled + * =========================================================================== */ + + describe('Extended Thinking', () => { + it('should return reasoning output when thinking is enabled', async () => { + const response = await authRequest() + .post('/api/agents/v1/responses') + .send({ + model: thinkingAgent.id, + input: [ + { + type: 'message', + role: 'user', + content: 'What is 15 * 7? Think step by step.', + }, + ], + }); + + expect(response.status).toBe(200); + expect(response.body.status).toBe('completed'); + + // Check for reasoning item in output + const reasoningItem = response.body.output.find((item) => item.type === 'reasoning'); + // If reasoning is present, validate its structure per Open Responses spec + // Note: reasoning items do NOT have a 'status' field per the spec + // @see https://github.com/openresponses/openresponses/blob/main/src/generated/kubb/zod/reasoningBodySchema.ts + if (reasoningItem) { + expect(reasoningItem).toHaveProperty('id'); + expect(reasoningItem).toHaveProperty('type', 'reasoning'); + // Note: 'status' is NOT a field on reasoning items per the spec + expect(reasoningItem).toHaveProperty('summary'); + expect(Array.isArray(reasoningItem.summary)).toBe(true); + + // Validate content items + if (reasoningItem.content && reasoningItem.content.length > 0) { + const reasoningContent = reasoningItem.content[0]; + expect(reasoningContent).toHaveProperty('type', 'reasoning_text'); + expect(reasoningContent).toHaveProperty('text'); + } + } + + const messageItem = response.body.output.find((item) => item.type === 'message'); + expect(messageItem).toBeDefined(); + }); + + it('should stream reasoning events when thinking is enabled', async () => { + const response = await authRequest() + .post('/api/agents/v1/responses') + .send({ + model: thinkingAgent.id, + input: [ + { + type: 'message', + role: 'user', + content: 'What is 12 + 8? Think step by step.', + }, + ], + stream: true, + }) + .buffer(true) + .parse((res, callback) => { + let data = ''; + res.on('data', (chunk) => { + data += chunk.toString(); + }); + res.on('end', () => { + callback(null, data); + }); + }); + + expect(response.status).toBe(200); + + const events = parseSSEEvents(response.body); + + // Validate all events against Open Responses spec + const validationErrors = validateAllStreamingEvents(events); + if (validationErrors.length > 0) { + console.error('Reasoning streaming event validation errors:', validationErrors); + } + expect(validationErrors).toEqual([]); + + // Check for reasoning-related events using correct event types per Open Responses spec + // Note: The spec uses response.reasoning.delta NOT response.reasoning_text.delta + const reasoningDeltaEvents = events.filter( + (e) => e.data && e.data.type === 'response.reasoning.delta', + ); + const reasoningDoneEvents = events.filter( + (e) => e.data && e.data.type === 'response.reasoning.done', + ); + + // If reasoning events are present, validate their structure + if (reasoningDeltaEvents.length > 0) { + const deltaEvent = reasoningDeltaEvents[0]; + expect(deltaEvent.data).toHaveProperty('item_id'); + expect(deltaEvent.data).toHaveProperty('delta'); + expect(deltaEvent.data).toHaveProperty('output_index'); + expect(deltaEvent.data).toHaveProperty('content_index'); + expect(deltaEvent.data).toHaveProperty('sequence_number'); + } + + if (reasoningDoneEvents.length > 0) { + const doneEvent = reasoningDoneEvents[0]; + expect(doneEvent.data).toHaveProperty('item_id'); + expect(doneEvent.data).toHaveProperty('text'); + expect(doneEvent.data).toHaveProperty('output_index'); + expect(doneEvent.data).toHaveProperty('content_index'); + expect(doneEvent.data).toHaveProperty('sequence_number'); + } + + // Verify stream completed properly + const eventTypes = events.map((e) => e.event); + expect(eventTypes).toContain('response.completed'); + }); + }); + + /* =========================================================================== + * SCHEMA VALIDATION TESTS + * Verify response schema compliance + * =========================================================================== */ + + describe('Schema Validation', () => { + it('should include all required fields in response', async () => { + const response = await authRequest().post('/api/agents/v1/responses').send({ + model: testAgent.id, + input: 'Test', + }); + + expect(response.status).toBe(200); + const body = response.body; + + // Required fields per Open Responses spec + expect(body).toHaveProperty('id'); + expect(body).toHaveProperty('object', 'response'); + expect(body).toHaveProperty('created_at'); + expect(body).toHaveProperty('completed_at'); + expect(body).toHaveProperty('status'); + expect(body).toHaveProperty('model'); + expect(body).toHaveProperty('output'); + expect(body).toHaveProperty('tools'); + expect(body).toHaveProperty('tool_choice'); + expect(body).toHaveProperty('truncation'); + expect(body).toHaveProperty('parallel_tool_calls'); + expect(body).toHaveProperty('text'); + expect(body).toHaveProperty('temperature'); + expect(body).toHaveProperty('top_p'); + expect(body).toHaveProperty('presence_penalty'); + expect(body).toHaveProperty('frequency_penalty'); + expect(body).toHaveProperty('top_logprobs'); + expect(body).toHaveProperty('store'); + expect(body).toHaveProperty('background'); + expect(body).toHaveProperty('service_tier'); + expect(body).toHaveProperty('metadata'); + + // top_logprobs must be a number (not null) + expect(typeof body.top_logprobs).toBe('number'); + + // Usage must have required detail fields + expect(body).toHaveProperty('usage'); + expect(body.usage).toHaveProperty('input_tokens'); + expect(body.usage).toHaveProperty('output_tokens'); + expect(body.usage).toHaveProperty('total_tokens'); + expect(body.usage).toHaveProperty('input_tokens_details'); + expect(body.usage).toHaveProperty('output_tokens_details'); + expect(body.usage.input_tokens_details).toHaveProperty('cached_tokens'); + expect(body.usage.output_tokens_details).toHaveProperty('reasoning_tokens'); + }); + + it('should have valid message item structure', async () => { + const response = await authRequest().post('/api/agents/v1/responses').send({ + model: testAgent.id, + input: 'Hello', + }); + + expect(response.status).toBe(200); + + const messageItem = response.body.output.find((item) => item.type === 'message'); + expect(messageItem).toBeDefined(); + + // Message item required fields + expect(messageItem).toHaveProperty('type', 'message'); + expect(messageItem).toHaveProperty('id'); + expect(messageItem).toHaveProperty('status'); + expect(messageItem).toHaveProperty('role', 'assistant'); + expect(messageItem).toHaveProperty('content'); + expect(Array.isArray(messageItem.content)).toBe(true); + + // Content part structure - verify all required fields + if (messageItem.content.length > 0) { + const textContent = messageItem.content.find((c) => c.type === 'output_text'); + if (textContent) { + expect(textContent).toHaveProperty('type', 'output_text'); + expect(textContent).toHaveProperty('text'); + expect(textContent).toHaveProperty('annotations'); + expect(textContent).toHaveProperty('logprobs'); + expect(Array.isArray(textContent.annotations)).toBe(true); + expect(Array.isArray(textContent.logprobs)).toBe(true); + } + } + + // Verify reasoning item has required summary field + const reasoningItem = response.body.output.find((item) => item.type === 'reasoning'); + if (reasoningItem) { + expect(reasoningItem).toHaveProperty('type', 'reasoning'); + expect(reasoningItem).toHaveProperty('id'); + expect(reasoningItem).toHaveProperty('summary'); + expect(Array.isArray(reasoningItem.summary)).toBe(true); + } + }); + }); + + /* =========================================================================== + * RESPONSE STORAGE TESTS + * Tests for store: true and GET /v1/responses/:id + * =========================================================================== */ + + describe('Response Storage', () => { + it('should store response when store: true and retrieve it', async () => { + // Create a stored response + const createResponse = await authRequest().post('/api/agents/v1/responses').send({ + model: testAgent.id, + input: 'Remember this: The answer is 42.', + store: true, + }); + + expect(createResponse.status).toBe(200); + expect(createResponse.body.status).toBe('completed'); + + const responseId = createResponse.body.id; + expect(responseId).toMatch(/^resp_/); + + // Small delay to ensure database write completes + await new Promise((resolve) => setTimeout(resolve, 500)); + + // Retrieve the stored response + const getResponseResult = await authRequest().get(`/api/agents/v1/responses/${responseId}`); + + // Note: The response might be stored under conversationId, not responseId + // If we get 404, that's expected behavior for now since we store by conversationId + if (getResponseResult.status === 200) { + expect(getResponseResult.body.object).toBe('response'); + expect(getResponseResult.body.status).toBe('completed'); + expect(getResponseResult.body.output.length).toBeGreaterThan(0); + } + }); + + it('should return 404 for non-existent response', async () => { + const response = await authRequest().get('/api/agents/v1/responses/resp_nonexistent123'); + + expect(response.status).toBe(404); + expect(response.body.error).toBeDefined(); + }); + }); + + /* =========================================================================== + * ERROR HANDLING TESTS + * =========================================================================== */ + + describe('Error Handling', () => { + it('should return error for missing model', async () => { + const response = await authRequest().post('/api/agents/v1/responses').send({ + input: 'Hello', + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBeDefined(); + }); + + it('should return error for missing input', async () => { + const response = await authRequest().post('/api/agents/v1/responses').send({ + model: testAgent.id, + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBeDefined(); + }); + + it('should return error for non-existent agent', async () => { + const response = await authRequest().post('/api/agents/v1/responses').send({ + model: 'agent_nonexistent123456789', + input: 'Hello', + }); + + expect(response.status).toBe(404); + expect(response.body.error).toBeDefined(); + }); + }); + + /* =========================================================================== + * MODELS ENDPOINT TESTS + * =========================================================================== */ + + describe('GET /v1/responses/models', () => { + it('should list available agents as models', async () => { + const response = await authRequest().get('/api/agents/v1/responses/models'); + + expect(response.status).toBe(200); + expect(response.body.object).toBe('list'); + expect(Array.isArray(response.body.data)).toBe(true); + + // Should include our test agent + const foundAgent = response.body.data.find((m) => m.id === testAgent.id); + expect(foundAgent).toBeDefined(); + expect(foundAgent.object).toBe('model'); + expect(foundAgent.name).toBe(testAgent.name); + }); + }); +}); diff --git a/api/server/routes/agents/index.js b/api/server/routes/agents/index.js index bf790aeee8c6..f8d39cb4d82b 100644 --- a/api/server/routes/agents/index.js +++ b/api/server/routes/agents/index.js @@ -10,6 +10,8 @@ const { messageUserLimiter, } = require('~/server/middleware'); const { saveMessage } = require('~/models'); +const openai = require('./openai'); +const responses = require('./responses'); const { v1 } = require('./v1'); const chat = require('./chat'); @@ -17,6 +19,20 @@ const { LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {}; const router = express.Router(); +/** + * Open Responses API routes (API key authentication handled in route file) + * Mounted at /agents/v1/responses (full path: /api/agents/v1/responses) + * NOTE: Must be mounted BEFORE /v1 to avoid being caught by the less specific route + * @see https://openresponses.org/specification + */ +router.use('/v1/responses', responses); + +/** + * OpenAI-compatible API routes (API key authentication handled in route file) + * Mounted at /agents/v1 (full path: /api/agents/v1/chat/completions) + */ +router.use('/v1', openai); + router.use(requireJwtAuth); router.use(checkBan); router.use(uaParser); diff --git a/api/server/routes/agents/openai.js b/api/server/routes/agents/openai.js new file mode 100644 index 000000000000..9a0d9a356452 --- /dev/null +++ b/api/server/routes/agents/openai.js @@ -0,0 +1,110 @@ +/** + * OpenAI-compatible API routes for LibreChat agents. + * + * Provides a /v1/chat/completions compatible interface for + * interacting with LibreChat agents remotely via API. + * + * Usage: + * POST /v1/chat/completions - Chat with an agent + * GET /v1/models - List available agents + * GET /v1/models/:model - Get agent details + * + * Request format: + * { + * "model": "agent_id_here", + * "messages": [{"role": "user", "content": "Hello!"}], + * "stream": true + * } + */ +const express = require('express'); +const { PermissionTypes, Permissions } = require('librechat-data-provider'); +const { + generateCheckAccess, + createRequireApiKeyAuth, + createCheckRemoteAgentAccess, +} = require('@librechat/api'); +const { + OpenAIChatCompletionController, + ListModelsController, + GetModelController, +} = require('~/server/controllers/agents/openai'); +const { getEffectivePermissions } = require('~/server/services/PermissionService'); +const { validateAgentApiKey, findUser } = require('~/models'); +const { configMiddleware } = require('~/server/middleware'); +const { getRoleByName } = require('~/models/Role'); +const { getAgent } = require('~/models/Agent'); + +const router = express.Router(); + +const requireApiKeyAuth = createRequireApiKeyAuth({ + validateAgentApiKey, + findUser, +}); + +const checkRemoteAgentsFeature = generateCheckAccess({ + permissionType: PermissionTypes.REMOTE_AGENTS, + permissions: [Permissions.USE], + getRoleByName, +}); + +const checkAgentPermission = createCheckRemoteAgentAccess({ + getAgent, + getEffectivePermissions, +}); + +router.use(requireApiKeyAuth); +router.use(configMiddleware); +router.use(checkRemoteAgentsFeature); + +/** + * @route POST /v1/chat/completions + * @desc OpenAI-compatible chat completions with agents + * @access Private (API key auth required) + * + * Request body: + * { + * "model": "agent_id", // Required: The agent ID to use + * "messages": [...], // Required: Array of chat messages + * "stream": true, // Optional: Whether to stream (default: false) + * "conversation_id": "...", // Optional: Conversation ID for context + * "parent_message_id": "..." // Optional: Parent message for threading + * } + * + * Response (streaming): + * - SSE stream with OpenAI chat.completion.chunk format + * - Includes delta.reasoning for thinking/reasoning content + * + * Response (non-streaming): + * - Standard OpenAI chat.completion format + */ +router.post('/chat/completions', checkAgentPermission, OpenAIChatCompletionController); + +/** + * @route GET /v1/models + * @desc List available agents as models + * @access Private (API key auth required) + * + * Response: + * { + * "object": "list", + * "data": [ + * { + * "id": "agent_id", + * "object": "model", + * "name": "Agent Name", + * "provider": "openai", + * ... + * } + * ] + * } + */ +router.get('/models', ListModelsController); + +/** + * @route GET /v1/models/:model + * @desc Get details for a specific agent/model + * @access Private (API key auth required) + */ +router.get('/models/:model', GetModelController); + +module.exports = router; diff --git a/api/server/routes/agents/responses.js b/api/server/routes/agents/responses.js new file mode 100644 index 000000000000..431942e921c9 --- /dev/null +++ b/api/server/routes/agents/responses.js @@ -0,0 +1,144 @@ +/** + * Open Responses API routes for LibreChat agents. + * + * Implements the Open Responses specification for a forward-looking, + * agentic API that uses items as the fundamental unit and semantic + * streaming events. + * + * Usage: + * POST /v1/responses - Create a response + * GET /v1/models - List available agents + * + * Request format: + * { + * "model": "agent_id_here", + * "input": "Hello!" or [{ type: "message", role: "user", content: "Hello!" }], + * "stream": true, + * "previous_response_id": "optional_conversation_id" + * } + * + * @see https://openresponses.org/specification + */ +const express = require('express'); +const { PermissionTypes, Permissions } = require('librechat-data-provider'); +const { + generateCheckAccess, + createRequireApiKeyAuth, + createCheckRemoteAgentAccess, +} = require('@librechat/api'); +const { + createResponse, + getResponse, + listModels, +} = require('~/server/controllers/agents/responses'); +const { getEffectivePermissions } = require('~/server/services/PermissionService'); +const { validateAgentApiKey, findUser } = require('~/models'); +const { configMiddleware } = require('~/server/middleware'); +const { getRoleByName } = require('~/models/Role'); +const { getAgent } = require('~/models/Agent'); + +const router = express.Router(); + +const requireApiKeyAuth = createRequireApiKeyAuth({ + validateAgentApiKey, + findUser, +}); + +const checkRemoteAgentsFeature = generateCheckAccess({ + permissionType: PermissionTypes.REMOTE_AGENTS, + permissions: [Permissions.USE], + getRoleByName, +}); + +const checkAgentPermission = createCheckRemoteAgentAccess({ + getAgent, + getEffectivePermissions, +}); + +router.use(requireApiKeyAuth); +router.use(configMiddleware); +router.use(checkRemoteAgentsFeature); + +/** + * @route POST /v1/responses + * @desc Create a model response following Open Responses specification + * @access Private (API key auth required) + * + * Request body: + * { + * "model": "agent_id", // Required: The agent ID to use + * "input": "..." | [...], // Required: String or array of input items + * "stream": true, // Optional: Whether to stream (default: false) + * "previous_response_id": "...", // Optional: Previous response for continuation + * "instructions": "...", // Optional: Additional instructions + * "tools": [...], // Optional: Additional tools + * "tool_choice": "auto", // Optional: Tool choice mode + * "max_output_tokens": 4096, // Optional: Max tokens + * "temperature": 0.7 // Optional: Temperature + * } + * + * Response (streaming): + * - SSE stream with semantic events: + * - response.in_progress + * - response.output_item.added + * - response.content_part.added + * - response.output_text.delta + * - response.output_text.done + * - response.function_call_arguments.delta + * - response.output_item.done + * - response.completed + * - [DONE] + * + * Response (non-streaming): + * { + * "id": "resp_xxx", + * "object": "response", + * "created_at": 1234567890, + * "status": "completed", + * "model": "agent_id", + * "output": [...], // Array of output items + * "usage": { ... } + * } + */ +router.post('/', checkAgentPermission, createResponse); + +/** + * @route GET /v1/responses/models + * @desc List available agents as models + * @access Private (API key auth required) + * + * Response: + * { + * "object": "list", + * "data": [ + * { + * "id": "agent_id", + * "object": "model", + * "name": "Agent Name", + * "provider": "openai", + * ... + * } + * ] + * } + */ +router.get('/models', listModels); + +/** + * @route GET /v1/responses/:id + * @desc Retrieve a stored response by ID + * @access Private (API key auth required) + * + * Response: + * { + * "id": "resp_xxx", + * "object": "response", + * "created_at": 1234567890, + * "status": "completed", + * "model": "agent_id", + * "output": [...], + * "usage": { ... } + * } + */ +router.get('/:id', getResponse); + +module.exports = router; diff --git a/api/server/routes/apiKeys.js b/api/server/routes/apiKeys.js new file mode 100644 index 000000000000..29dcc326f5e3 --- /dev/null +++ b/api/server/routes/apiKeys.js @@ -0,0 +1,36 @@ +const express = require('express'); +const { generateCheckAccess, createApiKeyHandlers } = require('@librechat/api'); +const { PermissionTypes, Permissions } = require('librechat-data-provider'); +const { + getAgentApiKeyById, + createAgentApiKey, + deleteAgentApiKey, + listAgentApiKeys, +} = require('~/models'); +const { requireJwtAuth } = require('~/server/middleware'); +const { getRoleByName } = require('~/models/Role'); + +const router = express.Router(); + +const handlers = createApiKeyHandlers({ + createAgentApiKey, + listAgentApiKeys, + deleteAgentApiKey, + getAgentApiKeyById, +}); + +const checkRemoteAgentsUse = generateCheckAccess({ + permissionType: PermissionTypes.REMOTE_AGENTS, + permissions: [Permissions.USE], + getRoleByName, +}); + +router.post('/', requireJwtAuth, checkRemoteAgentsUse, handlers.createApiKey); + +router.get('/', requireJwtAuth, checkRemoteAgentsUse, handlers.listApiKeys); + +router.get('/:id', requireJwtAuth, checkRemoteAgentsUse, handlers.getApiKey); + +router.delete('/:id', requireJwtAuth, checkRemoteAgentsUse, handlers.deleteApiKey); + +module.exports = router; diff --git a/api/server/routes/index.js b/api/server/routes/index.js index f3571099cb5c..6a48919db337 100644 --- a/api/server/routes/index.js +++ b/api/server/routes/index.js @@ -1,6 +1,7 @@ const accessPermissions = require('./accessPermissions'); const assistants = require('./assistants'); const categories = require('./categories'); +const adminAuth = require('./admin/auth'); const endpoints = require('./endpoints'); const staticRoute = require('./static'); const messages = require('./messages'); @@ -9,6 +10,7 @@ const presets = require('./presets'); const prompts = require('./prompts'); const balance = require('./balance'); const actions = require('./actions'); +const apiKeys = require('./apiKeys'); const banner = require('./banner'); const search = require('./search'); const models = require('./models'); @@ -28,7 +30,9 @@ const mcp = require('./mcp'); module.exports = { mcp, auth, + adminAuth, keys, + apiKeys, user, tags, roles, diff --git a/api/server/routes/oauth.js b/api/server/routes/oauth.js index 64d29210ac17..4a2e2f70c616 100644 --- a/api/server/routes/oauth.js +++ b/api/server/routes/oauth.js @@ -4,10 +4,9 @@ const passport = require('passport'); const { randomState } = require('openid-client'); const { logger } = require('@librechat/data-schemas'); const { ErrorTypes } = require('librechat-data-provider'); -const { isEnabled, createSetBalanceConfig } = require('@librechat/api'); -const { checkDomainAllowed, loginLimiter, logHeaders, checkBan } = require('~/server/middleware'); -const { syncUserEntraGroupMemberships } = require('~/server/services/PermissionService'); -const { setAuthTokens, setOpenIDAuthTokens } = require('~/server/services/AuthService'); +const { createSetBalanceConfig } = require('@librechat/api'); +const { checkDomainAllowed, loginLimiter, logHeaders } = require('~/server/middleware'); +const { createOAuthHandler } = require('~/server/controllers/auth/oauth'); const { getAppConfig } = require('~/server/services/Config'); const { Balance } = require('~/db/models'); @@ -26,32 +25,7 @@ const domains = { router.use(logHeaders); router.use(loginLimiter); -const oauthHandler = async (req, res, next) => { - try { - if (res.headersSent) { - return; - } - - await checkBan(req, res); - if (req.banned) { - return; - } - if ( - req.user && - req.user.provider == 'openid' && - isEnabled(process.env.OPENID_REUSE_TOKENS) === true - ) { - await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token); - setOpenIDAuthTokens(req.user.tokenset, req, res, req.user._id.toString()); - } else { - await setAuthTokens(req.user._id, res); - } - res.redirect(domains.client); - } catch (err) { - logger.error('Error in setting authentication tokens:', err); - next(err); - } -}; +const oauthHandler = createOAuthHandler(); router.get('/error', (req, res) => { /** A single error message is pushed by passport when authentication fails. */ diff --git a/api/server/routes/roles.js b/api/server/routes/roles.js index abb53141bda9..12e18c7624ec 100644 --- a/api/server/routes/roles.js +++ b/api/server/routes/roles.js @@ -6,9 +6,10 @@ const { agentPermissionsSchema, promptPermissionsSchema, memoryPermissionsSchema, + mcpServersPermissionsSchema, marketplacePermissionsSchema, peoplePickerPermissionsSchema, - mcpServersPermissionsSchema, + remoteAgentsPermissionsSchema, } = require('librechat-data-provider'); const { checkAdmin, requireJwtAuth } = require('~/server/middleware'); const { updateRoleByName, getRoleByName } = require('~/models/Role'); @@ -51,6 +52,11 @@ const permissionConfigs = { permissionType: PermissionTypes.MARKETPLACE, errorMessage: 'Invalid marketplace permissions.', }, + 'remote-agents': { + schema: remoteAgentsPermissionsSchema, + permissionType: PermissionTypes.REMOTE_AGENTS, + errorMessage: 'Invalid remote agents permissions.', + }, }; /** @@ -160,4 +166,10 @@ router.put('/:roleName/mcp-servers', checkAdmin, createPermissionUpdateHandler(' */ router.put('/:roleName/marketplace', checkAdmin, createPermissionUpdateHandler('marketplace')); +/** + * PUT /api/roles/:roleName/remote-agents + * Update remote agents (API) permissions for a specific role + */ +router.put('/:roleName/remote-agents', checkAdmin, createPermissionUpdateHandler('remote-agents')); + module.exports = router; diff --git a/api/server/services/AuthService.js b/api/server/services/AuthService.js index a400bce8b7d3..ab1b9c56530c 100644 --- a/api/server/services/AuthService.js +++ b/api/server/services/AuthService.js @@ -6,8 +6,14 @@ const { DEFAULT_SESSION_EXPIRY, DEFAULT_REFRESH_TOKEN_EXPIRY, } = require('@librechat/data-schemas'); +const { + math, + isEnabled, + checkEmailConfig, + isEmailDomainAllowed, + extractSubFromAccessToken, +} = require('@librechat/api'); const { ErrorTypes, SystemRoles, errorsToString } = require('librechat-data-provider'); -const { isEnabled, checkEmailConfig, isEmailDomainAllowed, math } = require('@librechat/api'); const { findUser, findToken, @@ -490,6 +496,27 @@ const setOpenIDAuthTokens = (tokenset, req, res, userId, existingRefreshToken) = sameSite: 'strict', }); } + + if (isEnabled(process.env.OPENID_EXPOSE_SUB_COOKIE)) { + if (!process.env.JWT_REFRESH_SECRET) { + logger.error( + '[setOpenIDAuthTokens] JWT_REFRESH_SECRET not configured for openid_sub cookie', + ); + return tokenset.access_token; + } + const { sub } = extractSubFromAccessToken(tokenset.access_token); + if (sub) { + const signedSub = jwt.sign({ sub }, process.env.JWT_REFRESH_SECRET, { + expiresIn: expiryInMilliseconds / 1000, + }); + res.cookie('openid_sub', signedSub, { + expires: expirationDate, + httpOnly: true, + secure: isProduction, + sameSite: 'lax', + }); + } + } return tokenset.access_token; } catch (error) { logger.error('[setOpenIDAuthTokens] Error in setting authentication tokens:', error); diff --git a/api/server/services/Config/loadConfigModels.js b/api/server/services/Config/loadConfigModels.js index 6354d10331b6..2bc83ecc3aaa 100644 --- a/api/server/services/Config/loadConfigModels.js +++ b/api/server/services/Config/loadConfigModels.js @@ -28,6 +28,11 @@ async function loadConfigModels(req) { modelsConfig[EModelEndpoint.azureAssistants] = azureConfig.assistantModels; } + const bedrockConfig = appConfig.endpoints?.[EModelEndpoint.bedrock]; + if (bedrockConfig?.models && Array.isArray(bedrockConfig.models)) { + modelsConfig[EModelEndpoint.bedrock] = bedrockConfig.models; + } + if (!Array.isArray(appConfig.endpoints?.[EModelEndpoint.custom])) { return modelsConfig; } diff --git a/api/server/services/Endpoints/agents/addedConvo.js b/api/server/services/Endpoints/agents/addedConvo.js index 240622ed9f1a..7e9385267aeb 100644 --- a/api/server/services/Endpoints/agents/addedConvo.js +++ b/api/server/services/Endpoints/agents/addedConvo.js @@ -31,6 +31,7 @@ setGetAgent(getAgent); * @param {Function} params.loadTools - Function to load agent tools * @param {Array} params.requestFiles - Request files * @param {string} params.conversationId - The conversation ID + * @param {string} [params.parentMessageId] - The parent message ID for thread filtering * @param {Set} params.allowedProviders - Set of allowed providers * @param {Map} params.agentConfigs - Map of agent configs to add to * @param {string} params.primaryAgentId - The primary agent ID @@ -46,6 +47,7 @@ const processAddedConvo = async ({ loadTools, requestFiles, conversationId, + parentMessageId, allowedProviders, agentConfigs, primaryAgentId, @@ -91,6 +93,7 @@ const processAddedConvo = async ({ loadTools, requestFiles, conversationId, + parentMessageId, agent: addedAgent, endpointOption, allowedProviders, @@ -99,9 +102,12 @@ const processAddedConvo = async ({ getConvoFiles, getFiles: db.getFiles, getUserKey: db.getUserKey, + getMessages: db.getMessages, updateFilesUsage: db.updateFilesUsage, + getUserCodeFiles: db.getUserCodeFiles, getUserKeyValues: db.getUserKeyValues, getToolFilesByIds: db.getToolFilesByIds, + getCodeGeneratedFiles: db.getCodeGeneratedFiles, }, ); diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index a6914801198d..65d38c022615 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -44,13 +44,23 @@ function createToolLoader(signal, streamId = null) { * @param {string} params.model * @param {AgentToolResources} params.tool_resources * @returns {Promise<{ - * tools: StructuredTool[], - * toolContextMap: Record, - * userMCPAuthMap?: Record> + * tools: StructuredTool[], + * toolContextMap: Record, + * userMCPAuthMap?: Record>, + * toolRegistry?: import('@librechat/agents').LCToolRegistry * } | undefined>} */ - return async function loadTools({ req, res, agentId, tools, provider, model, tool_resources }) { - const agent = { id: agentId, tools, provider, model }; + return async function loadTools({ + req, + res, + tools, + model, + agentId, + provider, + tool_options, + tool_resources, + }) { + const agent = { id: agentId, tools, provider, model, tool_options }; try { return await loadAgentTools({ req, @@ -120,6 +130,8 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { const requestFiles = req.body.files ?? []; /** @type {string} */ const conversationId = req.body.conversationId; + /** @type {string | undefined} */ + const parentMessageId = req.body.parentMessageId; const primaryConfig = await initializeAgent( { @@ -128,6 +140,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { loadTools, requestFiles, conversationId, + parentMessageId, agent: primaryAgent, endpointOption, allowedProviders, @@ -137,9 +150,12 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { getConvoFiles, getFiles: db.getFiles, getUserKey: db.getUserKey, + getMessages: db.getMessages, updateFilesUsage: db.updateFilesUsage, getUserKeyValues: db.getUserKeyValues, + getUserCodeFiles: db.getUserCodeFiles, getToolFilesByIds: db.getToolFilesByIds, + getCodeGeneratedFiles: db.getCodeGeneratedFiles, }, ); @@ -179,6 +195,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { loadTools, requestFiles, conversationId, + parentMessageId, endpointOption, allowedProviders, }, @@ -186,9 +203,12 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { getConvoFiles, getFiles: db.getFiles, getUserKey: db.getUserKey, + getMessages: db.getMessages, updateFilesUsage: db.updateFilesUsage, getUserKeyValues: db.getUserKeyValues, + getUserCodeFiles: db.getUserCodeFiles, getToolFilesByIds: db.getToolFilesByIds, + getCodeGeneratedFiles: db.getCodeGeneratedFiles, }, ); if (userMCPAuthMap != null) { @@ -243,17 +263,18 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { const { userMCPAuthMap: updatedMCPAuthMap } = await processAddedConvo({ req, res, - endpointOption, - modelsConfig, - logViolation, loadTools, + logViolation, + modelsConfig, requestFiles, - conversationId, - allowedProviders, agentConfigs, - primaryAgentId: primaryConfig.id, primaryAgent, + endpointOption, userMCPAuthMap, + conversationId, + parentMessageId, + allowedProviders, + primaryAgentId: primaryConfig.id, }); if (updatedMCPAuthMap) { diff --git a/api/server/services/Files/Code/process.js b/api/server/services/Files/Code/process.js index 15df6de0d665..b7e7f5655263 100644 --- a/api/server/services/Files/Code/process.js +++ b/api/server/services/Files/Code/process.js @@ -6,27 +6,112 @@ const { getCodeBaseURL } = require('@librechat/agents'); const { logAxiosError, getBasePath } = require('@librechat/api'); const { Tools, + megabyte, + fileConfig, FileContext, FileSources, imageExtRegex, + inferMimeType, EToolResources, + EModelEndpoint, + mergeFileConfig, + getEndpointFileConfig, } = require('librechat-data-provider'); const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { convertImage } = require('~/server/services/Files/images/convert'); const { createFile, getFiles, updateFile } = require('~/models'); +const { determineFileType } = require('~/server/utils'); /** - * Process OpenAI image files, convert to target format, save and return file metadata. + * Creates a fallback download URL response when file cannot be processed locally. + * Used when: file exceeds size limit, storage strategy unavailable, or download error occurs. + * @param {Object} params - The parameters. + * @param {string} params.name - The filename. + * @param {string} params.session_id - The code execution session ID. + * @param {string} params.id - The file ID from the code environment. + * @param {string} params.conversationId - The current conversation ID. + * @param {string} params.toolCallId - The tool call ID that generated the file. + * @param {string} params.messageId - The current message ID. + * @param {number} params.expiresAt - Expiration timestamp (24 hours from creation). + * @returns {Object} Fallback response with download URL. + */ +const createDownloadFallback = ({ + id, + name, + messageId, + expiresAt, + session_id, + toolCallId, + conversationId, +}) => { + const basePath = getBasePath(); + return { + filename: name, + filepath: `${basePath}/api/files/code/download/${session_id}/${id}`, + expiresAt, + conversationId, + toolCallId, + messageId, + }; +}; + +/** + * Find an existing code-generated file by filename in the conversation. + * Used to update existing files instead of creating duplicates. + * + * ## Deduplication Strategy + * + * Files are deduplicated by `(conversationId, filename)` - NOT including `messageId`. + * This is an intentional design decision to handle iterative code development patterns: + * + * **Rationale:** + * - When users iteratively refine code (e.g., "regenerate that chart with red bars"), + * the same logical file (e.g., "chart.png") is produced multiple times + * - Without deduplication, each iteration would create a new file, leading to storage bloat + * - The latest version is what matters for re-upload to the code environment + * + * **Implications:** + * - Different messages producing files with the same name will update the same file record + * - The `messageId` field tracks which message last updated the file + * - The `usage` counter tracks how many times the file has been generated + * + * **Future Considerations:** + * - If file versioning is needed, consider adding a `versions` array or separate version collection + * - The current approach prioritizes storage efficiency over history preservation + * + * @param {string} filename - The filename to search for. + * @param {string} conversationId - The conversation ID. + * @returns {Promise} The existing file or null. + */ +const findExistingCodeFile = async (filename, conversationId) => { + if (!filename || !conversationId) { + return null; + } + const files = await getFiles( + { + filename, + conversationId, + context: FileContext.execute_code, + }, + { createdAt: -1 }, + { text: 0 }, + ); + return files?.[0] ?? null; +}; + +/** + * Process code execution output files - downloads and saves both images and non-image files. + * All files are saved to local storage with fileIdentifier metadata for code env re-upload. * @param {ServerRequest} params.req - The Express request object. - * @param {string} params.id - The file ID. + * @param {string} params.id - The file ID from the code environment. * @param {string} params.name - The filename. * @param {string} params.apiKey - The code execution API key. * @param {string} params.toolCallId - The tool call ID that generated the file. * @param {string} params.session_id - The code execution session ID. * @param {string} params.conversationId - The current conversation ID. * @param {string} params.messageId - The current message ID. - * @returns {Promise} The file metadata or undefined if an error occurs. + * @returns {Promise} The file metadata or undefined if an error occurs. */ const processCodeOutput = async ({ req, @@ -41,19 +126,15 @@ const processCodeOutput = async ({ const appConfig = req.config; const currentDate = new Date(); const baseURL = getCodeBaseURL(); - const basePath = getBasePath(); - const fileExt = path.extname(name); - if (!fileExt || !imageExtRegex.test(name)) { - return { - filename: name, - filepath: `${basePath}/api/files/code/download/${session_id}/${id}`, - /** Note: expires 24 hours after creation */ - expiresAt: currentDate.getTime() + 86400000, - conversationId, - toolCallId, - messageId, - }; - } + const fileExt = path.extname(name).toLowerCase(); + const isImage = fileExt && imageExtRegex.test(name); + + const mergedFileConfig = mergeFileConfig(appConfig.fileConfig); + const endpointFileConfig = getEndpointFileConfig({ + fileConfig: mergedFileConfig, + endpoint: EModelEndpoint.agents, + }); + const fileSizeLimit = endpointFileConfig.fileSizeLimit ?? mergedFileConfig.serverFileSizeLimit; try { const formattedDate = currentDate.toISOString(); @@ -70,29 +151,135 @@ const processCodeOutput = async ({ const buffer = Buffer.from(response.data, 'binary'); - const file_id = v4(); - const _file = await convertImage(req, buffer, 'high', `${file_id}${fileExt}`); + // Enforce file size limit + if (buffer.length > fileSizeLimit) { + logger.warn( + `[processCodeOutput] File "${name}" (${(buffer.length / megabyte).toFixed(2)} MB) exceeds size limit of ${(fileSizeLimit / megabyte).toFixed(2)} MB, falling back to download URL`, + ); + return createDownloadFallback({ + id, + name, + messageId, + toolCallId, + session_id, + conversationId, + expiresAt: currentDate.getTime() + 86400000, + }); + } + + const fileIdentifier = `${session_id}/${id}`; + + /** + * Check for existing file with same filename in this conversation. + * If found, we'll update it instead of creating a duplicate. + */ + const existingFile = await findExistingCodeFile(name, conversationId); + const file_id = existingFile?.file_id ?? v4(); + const isUpdate = !!existingFile; + + if (isUpdate) { + logger.debug( + `[processCodeOutput] Updating existing file "${name}" (${file_id}) instead of creating duplicate`, + ); + } + + if (isImage) { + const _file = await convertImage(req, buffer, 'high', `${file_id}${fileExt}`); + const file = { + ..._file, + file_id, + messageId, + usage: isUpdate ? (existingFile.usage ?? 0) + 1 : 1, + filename: name, + conversationId, + user: req.user.id, + type: `image/${appConfig.imageOutputType}`, + createdAt: isUpdate ? existingFile.createdAt : formattedDate, + updatedAt: formattedDate, + source: appConfig.fileStrategy, + context: FileContext.execute_code, + metadata: { fileIdentifier }, + }; + createFile(file, true); + return Object.assign(file, { messageId, toolCallId }); + } + + // For non-image files, save to configured storage strategy + const { saveBuffer } = getStrategyFunctions(appConfig.fileStrategy); + if (!saveBuffer) { + logger.warn( + `[processCodeOutput] saveBuffer not available for strategy ${appConfig.fileStrategy}, falling back to download URL`, + ); + return createDownloadFallback({ + id, + name, + messageId, + toolCallId, + session_id, + conversationId, + expiresAt: currentDate.getTime() + 86400000, + }); + } + + // Determine MIME type from buffer or extension + const detectedType = await determineFileType(buffer, true); + const mimeType = detectedType?.mime || inferMimeType(name, '') || 'application/octet-stream'; + + /** Check MIME type support - for code-generated files, we're lenient but log unsupported types */ + const isSupportedMimeType = fileConfig.checkType( + mimeType, + endpointFileConfig.supportedMimeTypes, + ); + if (!isSupportedMimeType) { + logger.warn( + `[processCodeOutput] File "${name}" has unsupported MIME type "${mimeType}", proceeding with storage but may not be usable as tool resource`, + ); + } + + const fileName = `${file_id}__${name}`; + const filepath = await saveBuffer({ + userId: req.user.id, + buffer, + fileName, + basePath: 'uploads', + }); + const file = { - ..._file, file_id, - usage: 1, + filepath, + messageId, + object: 'file', filename: name, + type: mimeType, conversationId, user: req.user.id, - type: `image/${appConfig.imageOutputType}`, - createdAt: formattedDate, + bytes: buffer.length, updatedAt: formattedDate, + metadata: { fileIdentifier }, source: appConfig.fileStrategy, context: FileContext.execute_code, + usage: isUpdate ? (existingFile.usage ?? 0) + 1 : 1, + createdAt: isUpdate ? existingFile.createdAt : formattedDate, }; + createFile(file, true); - /** Note: `messageId` & `toolCallId` are not part of file DB schema; message object records associated file ID */ return Object.assign(file, { messageId, toolCallId }); } catch (error) { logAxiosError({ - message: 'Error downloading code environment file', + message: 'Error downloading/processing code environment file', error, }); + + // Fallback for download errors - return download URL so user can still manually download + return createDownloadFallback({ + id, + name, + messageId, + toolCallId, + session_id, + conversationId, + expiresAt: currentDate.getTime() + 86400000, + }); } }; @@ -204,9 +391,16 @@ const primeFiles = async (options, apiKey) => { if (!toolContext) { toolContext = `- Note: The following files are available in the "${Tools.execute_code}" tool environment:`; } - toolContext += `\n\t- /mnt/data/${file.filename}${ - agentResourceIds.has(file.file_id) ? '' : ' (just attached by user)' - }`; + + let fileSuffix = ''; + if (!agentResourceIds.has(file.file_id)) { + fileSuffix = + file.context === FileContext.execute_code + ? ' (from previous code execution)' + : ' (attached by user)'; + } + + toolContext += `\n\t- /mnt/data/${file.filename}${fileSuffix}`; files.push({ id, session_id, diff --git a/api/server/services/Files/Code/process.spec.js b/api/server/services/Files/Code/process.spec.js new file mode 100644 index 000000000000..7e15888876fe --- /dev/null +++ b/api/server/services/Files/Code/process.spec.js @@ -0,0 +1,418 @@ +// Configurable file size limit for tests - use a getter so it can be changed per test +const fileSizeLimitConfig = { value: 20 * 1024 * 1024 }; // Default 20MB + +// Mock librechat-data-provider with configurable file size limit +jest.mock('librechat-data-provider', () => { + const actual = jest.requireActual('librechat-data-provider'); + return { + ...actual, + mergeFileConfig: jest.fn((config) => { + const merged = actual.mergeFileConfig(config); + // Override the serverFileSizeLimit with our test value + return { + ...merged, + get serverFileSizeLimit() { + return fileSizeLimitConfig.value; + }, + }; + }), + getEndpointFileConfig: jest.fn((options) => { + const config = actual.getEndpointFileConfig(options); + // Override fileSizeLimit with our test value + return { + ...config, + get fileSizeLimit() { + return fileSizeLimitConfig.value; + }, + }; + }), + }; +}); + +const { FileContext } = require('librechat-data-provider'); + +// Mock uuid +jest.mock('uuid', () => ({ + v4: jest.fn(() => 'mock-uuid-1234'), +})); + +// Mock axios +jest.mock('axios'); +const axios = require('axios'); + +// Mock logger +jest.mock('@librechat/data-schemas', () => ({ + logger: { + warn: jest.fn(), + debug: jest.fn(), + error: jest.fn(), + }, +})); + +// Mock getCodeBaseURL +jest.mock('@librechat/agents', () => ({ + getCodeBaseURL: jest.fn(() => 'https://code-api.example.com'), +})); + +// Mock logAxiosError and getBasePath +jest.mock('@librechat/api', () => ({ + logAxiosError: jest.fn(), + getBasePath: jest.fn(() => ''), +})); + +// Mock models +jest.mock('~/models', () => ({ + createFile: jest.fn(), + getFiles: jest.fn(), + updateFile: jest.fn(), +})); + +// Mock permissions (must be before process.js import) +jest.mock('~/server/services/Files/permissions', () => ({ + filterFilesByAgentAccess: jest.fn((options) => Promise.resolve(options.files)), +})); + +// Mock strategy functions +jest.mock('~/server/services/Files/strategies', () => ({ + getStrategyFunctions: jest.fn(), +})); + +// Mock convertImage +jest.mock('~/server/services/Files/images/convert', () => ({ + convertImage: jest.fn(), +})); + +// Mock determineFileType +jest.mock('~/server/utils', () => ({ + determineFileType: jest.fn(), +})); + +const { createFile, getFiles } = require('~/models'); +const { getStrategyFunctions } = require('~/server/services/Files/strategies'); +const { convertImage } = require('~/server/services/Files/images/convert'); +const { determineFileType } = require('~/server/utils'); +const { logger } = require('@librechat/data-schemas'); + +// Import after mocks +const { processCodeOutput } = require('./process'); + +describe('Code Process', () => { + const mockReq = { + user: { id: 'user-123' }, + config: { + fileConfig: {}, + fileStrategy: 'local', + imageOutputType: 'webp', + }, + }; + + const baseParams = { + req: mockReq, + id: 'file-id-123', + name: 'test-file.txt', + apiKey: 'test-api-key', + toolCallId: 'tool-call-123', + conversationId: 'conv-123', + messageId: 'msg-123', + session_id: 'session-123', + }; + + beforeEach(() => { + jest.clearAllMocks(); + // Default mock implementations + getFiles.mockResolvedValue(null); + createFile.mockResolvedValue({}); + getStrategyFunctions.mockReturnValue({ + saveBuffer: jest.fn().mockResolvedValue('/uploads/mock-file-path.txt'), + }); + determineFileType.mockResolvedValue({ mime: 'text/plain' }); + }); + + describe('findExistingCodeFile (via processCodeOutput)', () => { + it('should find existing file by filename and conversationId', async () => { + const existingFile = { + file_id: 'existing-file-id', + filename: 'test-file.txt', + usage: 2, + createdAt: '2024-01-01T00:00:00.000Z', + }; + getFiles.mockResolvedValue([existingFile]); + + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + const result = await processCodeOutput(baseParams); + + // Verify getFiles was called with correct deduplication query + expect(getFiles).toHaveBeenCalledWith( + { + filename: 'test-file.txt', + conversationId: 'conv-123', + context: FileContext.execute_code, + }, + { createdAt: -1 }, + { text: 0 }, + ); + + // Verify the existing file_id was reused + expect(result.file_id).toBe('existing-file-id'); + // Verify usage was incremented + expect(result.usage).toBe(3); + // Verify original createdAt was preserved + expect(result.createdAt).toBe('2024-01-01T00:00:00.000Z'); + }); + + it('should create new file when no existing file found', async () => { + getFiles.mockResolvedValue(null); + + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + const result = await processCodeOutput(baseParams); + + // Should use the mocked uuid + expect(result.file_id).toBe('mock-uuid-1234'); + // Should have usage of 1 for new file + expect(result.usage).toBe(1); + }); + + it('should return null for invalid inputs (empty filename)', async () => { + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + // The function handles this internally - with empty name + // findExistingCodeFile returns null early for empty filename (guard clause) + const result = await processCodeOutput({ ...baseParams, name: '' }); + + // getFiles should NOT be called due to early return in findExistingCodeFile + expect(getFiles).not.toHaveBeenCalled(); + // A new file_id should be generated since no existing file was found + expect(result.file_id).toBe('mock-uuid-1234'); + }); + }); + + describe('processCodeOutput', () => { + describe('image file processing', () => { + it('should process image files using convertImage', async () => { + const imageParams = { ...baseParams, name: 'chart.png' }; + const imageBuffer = Buffer.alloc(500); + axios.mockResolvedValue({ data: imageBuffer }); + + const convertedFile = { + filepath: '/uploads/converted-image.webp', + bytes: 400, + }; + convertImage.mockResolvedValue(convertedFile); + getFiles.mockResolvedValue(null); + + const result = await processCodeOutput(imageParams); + + expect(convertImage).toHaveBeenCalledWith( + mockReq, + imageBuffer, + 'high', + 'mock-uuid-1234.png', + ); + expect(result.type).toBe('image/webp'); + expect(result.context).toBe(FileContext.execute_code); + expect(result.filename).toBe('chart.png'); + }); + + it('should update existing image file and increment usage', async () => { + const imageParams = { ...baseParams, name: 'chart.png' }; + const existingFile = { + file_id: 'existing-img-id', + usage: 1, + createdAt: '2024-01-01T00:00:00.000Z', + }; + getFiles.mockResolvedValue([existingFile]); + + const imageBuffer = Buffer.alloc(500); + axios.mockResolvedValue({ data: imageBuffer }); + convertImage.mockResolvedValue({ filepath: '/uploads/img.webp' }); + + const result = await processCodeOutput(imageParams); + + expect(result.file_id).toBe('existing-img-id'); + expect(result.usage).toBe(2); + expect(logger.debug).toHaveBeenCalledWith( + expect.stringContaining('Updating existing file'), + ); + }); + }); + + describe('non-image file processing', () => { + it('should process non-image files using saveBuffer', async () => { + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + const mockSaveBuffer = jest.fn().mockResolvedValue('/uploads/saved-file.txt'); + getStrategyFunctions.mockReturnValue({ saveBuffer: mockSaveBuffer }); + determineFileType.mockResolvedValue({ mime: 'text/plain' }); + + const result = await processCodeOutput(baseParams); + + expect(mockSaveBuffer).toHaveBeenCalledWith({ + userId: 'user-123', + buffer: smallBuffer, + fileName: 'mock-uuid-1234__test-file.txt', + basePath: 'uploads', + }); + expect(result.type).toBe('text/plain'); + expect(result.filepath).toBe('/uploads/saved-file.txt'); + expect(result.bytes).toBe(100); + }); + + it('should detect MIME type from buffer', async () => { + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + determineFileType.mockResolvedValue({ mime: 'application/pdf' }); + + const result = await processCodeOutput({ ...baseParams, name: 'document.pdf' }); + + expect(determineFileType).toHaveBeenCalledWith(smallBuffer, true); + expect(result.type).toBe('application/pdf'); + }); + + it('should fallback to application/octet-stream for unknown types', async () => { + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + determineFileType.mockResolvedValue(null); + + const result = await processCodeOutput({ ...baseParams, name: 'unknown.xyz' }); + + expect(result.type).toBe('application/octet-stream'); + }); + }); + + describe('file size limit enforcement', () => { + it('should fallback to download URL when file exceeds size limit', async () => { + // Set a small file size limit for this test + fileSizeLimitConfig.value = 1000; // 1KB limit + + const largeBuffer = Buffer.alloc(5000); // 5KB - exceeds 1KB limit + axios.mockResolvedValue({ data: largeBuffer }); + + const result = await processCodeOutput(baseParams); + + expect(logger.warn).toHaveBeenCalledWith(expect.stringContaining('exceeds size limit')); + expect(result.filepath).toContain('/api/files/code/download/session-123/file-id-123'); + expect(result.expiresAt).toBeDefined(); + // Should not call createFile for oversized files (fallback path) + expect(createFile).not.toHaveBeenCalled(); + + // Reset to default for other tests + fileSizeLimitConfig.value = 20 * 1024 * 1024; + }); + }); + + describe('fallback behavior', () => { + it('should fallback to download URL when saveBuffer is not available', async () => { + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + getStrategyFunctions.mockReturnValue({ saveBuffer: null }); + + const result = await processCodeOutput(baseParams); + + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('saveBuffer not available'), + ); + expect(result.filepath).toContain('/api/files/code/download/'); + expect(result.filename).toBe('test-file.txt'); + }); + + it('should fallback to download URL on axios error', async () => { + axios.mockRejectedValue(new Error('Network error')); + + const result = await processCodeOutput(baseParams); + + expect(result.filepath).toContain('/api/files/code/download/session-123/file-id-123'); + expect(result.conversationId).toBe('conv-123'); + expect(result.messageId).toBe('msg-123'); + expect(result.toolCallId).toBe('tool-call-123'); + }); + }); + + describe('usage counter increment', () => { + it('should set usage to 1 for new files', async () => { + getFiles.mockResolvedValue(null); + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + const result = await processCodeOutput(baseParams); + + expect(result.usage).toBe(1); + }); + + it('should increment usage for existing files', async () => { + const existingFile = { file_id: 'existing-id', usage: 5, createdAt: '2024-01-01' }; + getFiles.mockResolvedValue([existingFile]); + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + const result = await processCodeOutput(baseParams); + + expect(result.usage).toBe(6); + }); + + it('should handle existing file with undefined usage', async () => { + const existingFile = { file_id: 'existing-id', createdAt: '2024-01-01' }; + getFiles.mockResolvedValue([existingFile]); + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + const result = await processCodeOutput(baseParams); + + // (undefined ?? 0) + 1 = 1 + expect(result.usage).toBe(1); + }); + }); + + describe('metadata and file properties', () => { + it('should include fileIdentifier in metadata', async () => { + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + const result = await processCodeOutput(baseParams); + + expect(result.metadata).toEqual({ + fileIdentifier: 'session-123/file-id-123', + }); + }); + + it('should set correct context for code-generated files', async () => { + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + const result = await processCodeOutput(baseParams); + + expect(result.context).toBe(FileContext.execute_code); + }); + + it('should include toolCallId and messageId in result', async () => { + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + const result = await processCodeOutput(baseParams); + + expect(result.toolCallId).toBe('tool-call-123'); + expect(result.messageId).toBe('msg-123'); + }); + + it('should call createFile with upsert enabled', async () => { + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + await processCodeOutput(baseParams); + + expect(createFile).toHaveBeenCalledWith( + expect.objectContaining({ + file_id: 'mock-uuid-1234', + context: FileContext.execute_code, + }), + true, // upsert flag + ); + }); + }); + }); +}); diff --git a/api/server/services/Files/Local/crud.js b/api/server/services/Files/Local/crud.js index db553f57ddc0..b43ab7532619 100644 --- a/api/server/services/Files/Local/crud.js +++ b/api/server/services/Files/Local/crud.js @@ -67,7 +67,12 @@ async function saveLocalBuffer({ userId, buffer, fileName, basePath = 'images' } try { const { publicPath, uploads } = paths; - const directoryPath = path.join(basePath === 'images' ? publicPath : uploads, basePath, userId); + /** + * For 'images': save to publicPath/images/userId (images are served statically) + * For 'uploads': save to uploads/userId (files downloaded via API) + * */ + const directoryPath = + basePath === 'images' ? path.join(publicPath, basePath, userId) : path.join(uploads, userId); if (!fs.existsSync(directoryPath)) { fs.mkdirSync(directoryPath, { recursive: true }); diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index 81d7107de40c..df1e637b1bb3 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -29,6 +29,7 @@ const { getMCPManager, } = require('~/config'); const { findToken, createToken, updateToken } = require('~/models'); +const { getGraphApiToken } = require('./GraphTokenService'); const { reinitMCPServer } = require('./Tools/mcp'); const { getAppConfig } = require('./Config'); const { getLogStores } = require('~/cache'); @@ -501,6 +502,7 @@ function createToolInstance({ }, oauthStart, oauthEnd, + graphTokenResolver: getGraphApiToken, }); if (isAssistantsEndpoint(provider) && Array.isArray(result)) { @@ -548,6 +550,7 @@ function createToolInstance({ }); toolInstance.mcp = true; toolInstance.mcpRawServerName = serverName; + toolInstance.mcpJsonSchema = parameters; return toolInstance; } diff --git a/api/server/services/MCP.spec.js b/api/server/services/MCP.spec.js index cb2f0081a3d9..5d7eb093be68 100644 --- a/api/server/services/MCP.spec.js +++ b/api/server/services/MCP.spec.js @@ -120,6 +120,10 @@ jest.mock('./Tools/mcp', () => ({ reinitMCPServer: jest.fn(), })); +jest.mock('./GraphTokenService', () => ({ + getGraphApiToken: jest.fn(), +})); + describe('tests for the new helper functions used by the MCP connection status endpoints', () => { let mockGetMCPManager; let mockGetFlowStateManager; diff --git a/api/server/services/PermissionService.js b/api/server/services/PermissionService.js index c35faf7c8d91..a843f48f6fe0 100644 --- a/api/server/services/PermissionService.js +++ b/api/server/services/PermissionService.js @@ -141,7 +141,6 @@ const checkPermission = async ({ userId, role, resourceType, resourceId, require validateResourceType(resourceType); - // Get all principals for the user (user + groups + public) const principals = await getUserPrincipals({ userId, role }); if (principals.length === 0) { @@ -151,7 +150,6 @@ const checkPermission = async ({ userId, role, resourceType, resourceId, require return await hasPermission(principals, resourceType, resourceId, requiredPermission); } catch (error) { logger.error(`[PermissionService.checkPermission] Error: ${error.message}`); - // Re-throw validation errors if (error.message.includes('requiredPermission must be')) { throw error; } @@ -172,12 +170,12 @@ const getEffectivePermissions = async ({ userId, role, resourceType, resourceId try { validateResourceType(resourceType); - // Get all principals for the user (user + groups + public) const principals = await getUserPrincipals({ userId, role }); if (principals.length === 0) { return 0; } + return await getEffectivePermissionsACL(principals, resourceType, resourceId); } catch (error) { logger.error(`[PermissionService.getEffectivePermissions] Error: ${error.message}`); diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index 62d25b23eb70..f72b4169dc57 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -6,6 +6,7 @@ const { hasCustomUserVars, getUserMCPAuthMap, isActionDomainAllowed, + buildToolClassification, } = require('@librechat/api'); const { Tools, @@ -36,6 +37,7 @@ const { recordUsage } = require('~/server/services/Threads'); const { loadTools } = require('~/app/clients/tools/util'); const { redactMessage } = require('~/config/parsers'); const { findPluginAuthsByKeys } = require('~/models'); +const { loadAuthValues } = require('~/server/services/Tools/credentials'); /** * Processes the required actions by calling the appropriate tools and returning the outputs. * @param {OpenAIClient} client - OpenAI or StreamRunManager Client. @@ -367,7 +369,13 @@ async function processRequiredActions(client, requiredActions) { * @param {AbortSignal} params.signal * @param {Pick> }>} The agent tools. + * @returns {Promise<{ + * tools?: StructuredTool[]; + * toolContextMap?: Record; + * userMCPAuthMap?: Record>; + * toolRegistry?: Map; + * hasDeferredTools?: boolean; + * }>} The agent tools and registry. */ async function loadAgentTools({ req, @@ -401,8 +409,14 @@ async function loadAgentTools({ const checkCapability = (capability) => { const enabled = enabledCapabilities.has(capability); if (!enabled) { + const isToolCapability = [ + AgentCapabilities.file_search, + AgentCapabilities.execute_code, + AgentCapabilities.web_search, + ].includes(capability); + const suffix = isToolCapability ? ' despite configured tool.' : '.'; logger.warn( - `Capability "${capability}" disabled${capability === AgentCapabilities.tools ? '.' : ' despite configured tool.'} User: ${req.user.id} | Agent: ${agent.id}`, + `Capability "${capability}" disabled${suffix} User: ${req.user.id} | Agent: ${agent.id}`, ); } return enabled; @@ -510,11 +524,25 @@ async function loadAgentTools({ return map; }, {}); + /** Build tool registry from MCP tools and create PTC/tool search tools if configured */ + const deferredToolsEnabled = checkCapability(AgentCapabilities.deferred_tools); + const { toolRegistry, additionalTools, hasDeferredTools } = await buildToolClassification({ + loadedTools, + userId: req.user.id, + agentId: agent.id, + agentToolOptions: agent.tool_options, + deferredToolsEnabled, + loadAuthValues, + }); + agentTools.push(...additionalTools); + if (!checkCapability(AgentCapabilities.actions)) { return { tools: agentTools, userMCPAuthMap, toolContextMap, + toolRegistry, + hasDeferredTools, }; } @@ -527,6 +555,8 @@ async function loadAgentTools({ tools: agentTools, userMCPAuthMap, toolContextMap, + toolRegistry, + hasDeferredTools, }; } @@ -654,6 +684,8 @@ async function loadAgentTools({ tools: agentTools, toolContextMap, userMCPAuthMap, + toolRegistry, + hasDeferredTools, }; } diff --git a/api/server/services/__tests__/AuthService.openid-lax-cookie.spec.js b/api/server/services/__tests__/AuthService.openid-lax-cookie.spec.js new file mode 100644 index 000000000000..4a1dd96283cd --- /dev/null +++ b/api/server/services/__tests__/AuthService.openid-lax-cookie.spec.js @@ -0,0 +1,389 @@ +jest.mock('@librechat/data-schemas', () => ({ + logger: { + error: jest.fn(), + debug: jest.fn(), + info: jest.fn(), + warn: jest.fn(), + }, + hashToken: jest.fn((token) => `hashed-${token}`), + createMethods: jest.fn(() => ({})), + DEFAULT_REFRESH_TOKEN_EXPIRY: 1000 * 60 * 60 * 24 * 7, // 7 days in milliseconds + DEFAULT_SESSION_EXPIRY: 1000 * 60 * 15, // 15 minutes in milliseconds +})); + +jest.mock('librechat-data-provider', () => ({ + ...jest.requireActual('librechat-data-provider'), +})); + +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + checkEmailConfig: jest.fn(() => false), + isEmailDomainAllowed: jest.fn(() => true), + extractSubFromAccessToken: jest.fn((token) => { + if (!token) { + return { sub: null, error: 'No access token provided' }; + } + if (token === 'test-access-token') { + return { sub: 'cognito-sub-12345' }; + } + if (token === 'token-without-sub') { + return { sub: null, error: 'No sub claim in access token' }; + } + if (token === 'invalid-token') { + return { sub: null, error: 'Failed to decode access token' }; + } + return { sub: 'default-sub-claim' }; + }), +})); + +jest.mock('jsonwebtoken', () => ({ + sign: jest.fn((payload, _secret, _options) => { + if (payload.id) { + return `mocked-jwt-token-${payload.id}`; + } + if (payload.sub) { + return `mocked-jwt-token-${payload.sub}`; + } + return 'mocked-jwt-token'; + }), + verify: jest.fn(), + decode: jest.fn((token) => { + if (token === 'test-access-token') { + return { sub: 'cognito-sub-12345', aud: 'test-client-id' }; + } + return null; + }), +})); + +jest.mock('bcryptjs', () => ({ + genSaltSync: jest.fn(() => 'mock-salt'), + hashSync: jest.fn((password) => `hashed-${password}`), + compareSync: jest.fn(() => true), +})); + +jest.mock('~/models', () => ({ + findUser: jest.fn(), + findToken: jest.fn(), + createUser: jest.fn(), + updateUser: jest.fn(), + countUsers: jest.fn(), + getUserById: jest.fn(), + findSession: jest.fn(), + createToken: jest.fn(), + deleteTokens: jest.fn(), + deleteSession: jest.fn(), + createSession: jest.fn(), + generateToken: jest.fn(), + deleteUserById: jest.fn(), + generateRefreshToken: jest.fn(), +})); + +jest.mock('~/strategies/validators', () => ({ + registerSchema: { + safeParse: jest.fn(() => ({ error: null })), + }, +})); + +jest.mock('~/server/services/Config', () => ({ + getAppConfig: jest.fn(() => Promise.resolve({})), +})); + +jest.mock('~/server/utils', () => ({ + sendEmail: jest.fn(() => Promise.resolve()), +})); + +const { setOpenIDAuthTokens } = require('../AuthService'); + +describe('setOpenIDAuthTokens - OPENID_EXPOSE_SUB_COOKIE feature', () => { + let mockReq; + let mockRes; + let mockTokenset; + const testUserId = 'test-user-id-12345'; + + beforeEach(() => { + // mockReq without session to test cookie fallback path + mockReq = {}; + + mockRes = { + cookie: jest.fn(), + }; + + mockTokenset = { + access_token: 'test-access-token', + refresh_token: 'test-refresh-token', + id_token: 'test-id-token', + }; + + // Reset NODE_ENV to production for secure cookies + const originalEnv = process.env.NODE_ENV; + process.env.NODE_ENV = 'production'; + // Store original value to restore later + process.env._ORIGINAL_NODE_ENV = originalEnv; + process.env.JWT_REFRESH_SECRET = 'test-secret'; + jest.clearAllMocks(); + }); + + afterEach(() => { + // Restore original NODE_ENV value + if (process.env._ORIGINAL_NODE_ENV) { + process.env.NODE_ENV = process.env._ORIGINAL_NODE_ENV; + delete process.env._ORIGINAL_NODE_ENV; + } else { + delete process.env.NODE_ENV; + } + delete process.env.JWT_REFRESH_SECRET; + delete process.env.OPENID_REUSE_TOKENS; + delete process.env.REFRESH_TOKEN_EXPIRY; + delete process.env.OPENID_EXPOSE_SUB_COOKIE; + }); + + describe('cookie count verification', () => { + it('should set 3 cookies when OPENID_REUSE_TOKENS is false and OPENID_EXPOSE_SUB_COOKIE is false', () => { + process.env.OPENID_REUSE_TOKENS = 'false'; + process.env.OPENID_EXPOSE_SUB_COOKIE = 'false'; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + expect(mockRes.cookie).toHaveBeenCalledTimes(3); + const cookieNames = mockRes.cookie.mock.calls.map((call) => call[0]); + expect(cookieNames).toEqual(['refreshToken', 'openid_access_token', 'token_provider']); + }); + + it('should set 4 cookies when OPENID_REUSE_TOKENS is true and OPENID_EXPOSE_SUB_COOKIE is false', () => { + process.env.OPENID_REUSE_TOKENS = 'true'; + process.env.OPENID_EXPOSE_SUB_COOKIE = 'false'; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + expect(mockRes.cookie).toHaveBeenCalledTimes(4); + const cookieNames = mockRes.cookie.mock.calls.map((call) => call[0]); + expect(cookieNames).toEqual([ + 'refreshToken', + 'openid_access_token', + 'token_provider', + 'openid_user_id', + ]); + }); + + it('should set 5 cookies when both OPENID_REUSE_TOKENS and OPENID_EXPOSE_SUB_COOKIE are true', () => { + process.env.OPENID_REUSE_TOKENS = 'true'; + process.env.OPENID_EXPOSE_SUB_COOKIE = 'true'; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + expect(mockRes.cookie).toHaveBeenCalledTimes(5); + const cookieNames = mockRes.cookie.mock.calls.map((call) => call[0]); + expect(cookieNames).toEqual([ + 'refreshToken', + 'openid_access_token', + 'token_provider', + 'openid_user_id', + 'openid_sub', + ]); + }); + + it('should set 4 cookies when OPENID_EXPOSE_SUB_COOKIE is true but OPENID_REUSE_TOKENS is false', () => { + process.env.OPENID_REUSE_TOKENS = 'false'; + process.env.OPENID_EXPOSE_SUB_COOKIE = 'true'; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + expect(mockRes.cookie).toHaveBeenCalledTimes(4); + const cookieNames = mockRes.cookie.mock.calls.map((call) => call[0]); + expect(cookieNames).toEqual([ + 'refreshToken', + 'openid_access_token', + 'token_provider', + 'openid_sub', + ]); + }); + }); + + describe('other cookies remain strict', () => { + it('should keep all other cookies with strict sameSite when OPENID_EXPOSE_SUB_COOKIE is true', () => { + process.env.OPENID_REUSE_TOKENS = 'true'; + process.env.OPENID_EXPOSE_SUB_COOKIE = 'true'; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const refreshTokenCall = mockRes.cookie.mock.calls.find((call) => call[0] === 'refreshToken'); + const accessTokenCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_access_token', + ); + const providerCall = mockRes.cookie.mock.calls.find((call) => call[0] === 'token_provider'); + const userIdCall = mockRes.cookie.mock.calls.find((call) => call[0] === 'openid_user_id'); + + // All other cookies should remain strict + expect(refreshTokenCall[2].sameSite).toBe('strict'); + expect(accessTokenCall[2].sameSite).toBe('strict'); + expect(providerCall[2].sameSite).toBe('strict'); + expect(userIdCall[2].sameSite).toBe('strict'); + }); + }); + + describe('openid_sub cookie for Cognito sub claim', () => { + it('should create JWT-signed openid_sub cookie when OPENID_EXPOSE_SUB_COOKIE is true', () => { + process.env.OPENID_EXPOSE_SUB_COOKIE = 'true'; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const openidSubCall = mockRes.cookie.mock.calls.find((call) => call[0] === 'openid_sub'); + expect(openidSubCall).toBeDefined(); + expect(openidSubCall[1]).toBe('mocked-jwt-token-cognito-sub-12345'); // JWT-signed + expect(openidSubCall[2]).toMatchObject({ + httpOnly: true, + sameSite: 'lax', + }); + }); + + it('should not create openid_sub cookie when OPENID_EXPOSE_SUB_COOKIE is false', () => { + process.env.OPENID_EXPOSE_SUB_COOKIE = 'false'; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const openidSubCall = mockRes.cookie.mock.calls.find((call) => call[0] === 'openid_sub'); + expect(openidSubCall).toBeUndefined(); + }); + + it('should sign the Cognito sub claim with JWT_REFRESH_SECRET', () => { + process.env.OPENID_EXPOSE_SUB_COOKIE = 'true'; + const jwt = require('jsonwebtoken'); + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + expect(jwt.sign).toHaveBeenCalledWith( + { sub: 'cognito-sub-12345' }, + 'test-secret', + expect.objectContaining({ + expiresIn: expect.any(Number), + }), + ); + }); + + it('should set correct expiration on openid_sub cookie', () => { + process.env.OPENID_EXPOSE_SUB_COOKIE = 'true'; + const expiryInMs = 1000 * 60 * 60 * 24 * 7; // 7 days + const beforeTime = Date.now() + expiryInMs; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const afterTime = Date.now() + expiryInMs; + const openidSubCall = mockRes.cookie.mock.calls.find((call) => call[0] === 'openid_sub'); + + expect(openidSubCall[2].expires).toBeInstanceOf(Date); + expect(openidSubCall[2].expires.getTime()).toBeGreaterThanOrEqual(beforeTime); + expect(openidSubCall[2].expires.getTime()).toBeLessThanOrEqual(afterTime); + }); + }); + + describe('AWS Bedrock AgentCore 3LO use case', () => { + it('should support 3LO callback flow with JWT-signed lax openid_sub cookie', () => { + // Simulate AWS Bedrock AgentCore setup + process.env.OPENID_EXPOSE_SUB_COOKIE = 'true'; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const openidSubCall = mockRes.cookie.mock.calls.find((call) => call[0] === 'openid_sub'); + + // Verify cookie allows cross-site GET requests (OAuth callbacks) + expect(openidSubCall[2].sameSite).toBe('lax'); + + // Verify cookie maintains security + expect(openidSubCall[2].httpOnly).toBe(true); + expect(openidSubCall[2]).toHaveProperty('secure'); + + // Verify it's JWT-signed for verification in callback service + expect(openidSubCall[1]).toMatch(/^mocked-jwt-token-/); + }); + }); + + describe('backwards compatibility', () => { + it('should not create openid_sub cookie when feature flag is not set', () => { + process.env.OPENID_REUSE_TOKENS = 'true'; + // OPENID_EXPOSE_SUB_COOKIE not set + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const openidSubCall = mockRes.cookie.mock.calls.find((call) => call[0] === 'openid_sub'); + expect(openidSubCall).toBeUndefined(); + }); + + it('should not break existing functionality when feature flag is enabled but OPENID_REUSE_TOKENS is false', () => { + process.env.OPENID_REUSE_TOKENS = 'false'; + process.env.OPENID_EXPOSE_SUB_COOKIE = 'true'; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + // Should set 4 cookies: 3 standard + openid_sub + expect(mockRes.cookie).toHaveBeenCalledTimes(4); + const cookieNames = mockRes.cookie.mock.calls.map((call) => call[0]); + expect(cookieNames).toEqual([ + 'refreshToken', + 'openid_access_token', + 'token_provider', + 'openid_sub', + ]); + }); + }); + + describe('error handling and edge cases', () => { + it('should handle missing JWT_REFRESH_SECRET gracefully', () => { + const { logger } = require('@librechat/data-schemas'); + delete process.env.JWT_REFRESH_SECRET; + process.env.OPENID_EXPOSE_SUB_COOKIE = 'true'; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const openidSubCall = mockRes.cookie.mock.calls.find((call) => call[0] === 'openid_sub'); + expect(openidSubCall).toBeUndefined(); + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining('JWT_REFRESH_SECRET not configured'), + ); + }); + + it('should not create openid_sub cookie when access token has no sub claim', () => { + const { extractSubFromAccessToken } = require('@librechat/api'); + extractSubFromAccessToken.mockReturnValueOnce({ + sub: null, + error: 'No sub claim in access token', + }); + + process.env.OPENID_EXPOSE_SUB_COOKIE = 'true'; + mockTokenset.access_token = 'token-without-sub'; + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const openidSubCall = mockRes.cookie.mock.calls.find((call) => call[0] === 'openid_sub'); + expect(openidSubCall).toBeUndefined(); + }); + + it('should not create openid_sub cookie when extractSubFromAccessToken returns null', () => { + const { extractSubFromAccessToken } = require('@librechat/api'); + extractSubFromAccessToken.mockReturnValueOnce({ + sub: null, + error: 'Invalid JWT format', + }); + + process.env.OPENID_EXPOSE_SUB_COOKIE = 'true'; + mockTokenset.access_token = 'invalid-token'; + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const openidSubCall = mockRes.cookie.mock.calls.find((call) => call[0] === 'openid_sub'); + expect(openidSubCall).toBeUndefined(); + }); + + it('should handle decode errors gracefully', () => { + const { extractSubFromAccessToken } = require('@librechat/api'); + extractSubFromAccessToken.mockReturnValueOnce({ + sub: null, + error: 'Failed to decode access token', + }); + + process.env.OPENID_EXPOSE_SUB_COOKIE = 'true'; + mockTokenset.access_token = 'invalid-token'; + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const openidSubCall = mockRes.cookie.mock.calls.find((call) => call[0] === 'openid_sub'); + expect(openidSubCall).toBeUndefined(); + }); + }); +}); diff --git a/api/server/services/__tests__/AuthService.userid-cookie.spec.js b/api/server/services/__tests__/AuthService.userid-cookie.spec.js new file mode 100644 index 000000000000..7b8fb7f16f42 --- /dev/null +++ b/api/server/services/__tests__/AuthService.userid-cookie.spec.js @@ -0,0 +1,510 @@ +jest.mock('@librechat/data-schemas', () => ({ + logger: { + error: jest.fn(), + debug: jest.fn(), + info: jest.fn(), + warn: jest.fn(), + }, + hashToken: jest.fn((token) => `hashed-${token}`), + createMethods: jest.fn(() => ({})), + DEFAULT_REFRESH_TOKEN_EXPIRY: 1000 * 60 * 60 * 24 * 7, // 7 days in milliseconds + DEFAULT_SESSION_EXPIRY: 1000 * 60 * 15, // 15 minutes in milliseconds +})); + +jest.mock('librechat-data-provider', () => ({ + ...jest.requireActual('librechat-data-provider'), +})); + +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + checkEmailConfig: jest.fn(() => false), + isEmailDomainAllowed: jest.fn(() => true), + extractSubFromAccessToken: jest.fn((token) => { + if (!token) { + return { sub: null, error: 'No access token provided' }; + } + if (token === 'test-access-token') { + return { sub: 'openid-provider-sub-67890' }; + } + if (token === 'token-without-sub') { + return { sub: null, error: 'No sub claim in access token' }; + } + if (token === 'invalid-token') { + return { sub: null, error: 'Failed to decode access token' }; + } + return { sub: 'openid-provider-sub-67890' }; + }), +})); + +jest.mock('jsonwebtoken', () => ({ + sign: jest.fn((payload, secret, options) => { + if (payload.id) { + return `mocked-jwt-token-${payload.id}`; + } + if (payload.sub) { + return `mocked-jwt-token-${payload.sub}`; + } + return 'mocked-jwt-token'; + }), + verify: jest.fn(), +})); + +jest.mock('jsonwebtoken/decode', () => + jest.fn((token) => ({ + sub: 'openid-provider-sub-67890', + email: 'test@example.com', + name: 'Test User', + })), +); + +jest.mock('bcryptjs', () => ({ + genSaltSync: jest.fn(() => 'mock-salt'), + hashSync: jest.fn((password) => `hashed-${password}`), + compareSync: jest.fn(() => true), +})); + +jest.mock('~/models', () => ({ + findUser: jest.fn(), + findToken: jest.fn(), + createUser: jest.fn(), + updateUser: jest.fn(), + countUsers: jest.fn(), + getUserById: jest.fn(), + findSession: jest.fn(), + createToken: jest.fn(), + deleteTokens: jest.fn(), + deleteSession: jest.fn(), + createSession: jest.fn(), + generateToken: jest.fn(), + deleteUserById: jest.fn(), + generateRefreshToken: jest.fn(), +})); + +jest.mock('~/strategies/validators', () => ({ + registerSchema: { + safeParse: jest.fn(() => ({ error: null })), + }, +})); + +jest.mock('~/server/services/Config', () => ({ + getAppConfig: jest.fn(() => Promise.resolve({})), +})); + +jest.mock('~/server/utils', () => ({ + sendEmail: jest.fn(() => Promise.resolve()), +})); + +const { setOpenIDAuthTokens } = require('../AuthService'); +const { logger } = require('@librechat/data-schemas'); + +describe('setOpenIDAuthTokens - openid_sub cookie functionality', () => { + let mockReq; + let mockRes; + let mockTokenset; + const testUserId = 'test-user-id-12345'; + const testOpenIdSub = 'openid-provider-sub-67890'; + + beforeEach(() => { + // Mock request object without session to trigger cookie fallback + mockReq = { + session: null, + }; + + mockRes = { + cookie: jest.fn(), + }; + + // Mock jwt.decode to return decoded token with sub claim + const jwtDecode = require('jsonwebtoken/decode'); + jwtDecode.mockReturnValue({ + sub: testOpenIdSub, + email: 'test@example.com', + name: 'Test User', + }); + + mockTokenset = { + access_token: 'test-access-token', + refresh_token: 'test-refresh-token', + id_token: 'test-id-token', + }; + + // Reset NODE_ENV to production for secure cookies + const originalEnv = process.env.NODE_ENV; + process.env.NODE_ENV = 'production'; + // Store original value to restore later + process.env._ORIGINAL_NODE_ENV = originalEnv; + process.env.JWT_REFRESH_SECRET = 'test-secret'; + // Enable feature flag by default + process.env.OPENID_EXPOSE_SUB_COOKIE = 'true'; + jest.clearAllMocks(); + }); + + afterEach(() => { + delete process.env.NODE_ENV; + delete process.env.JWT_REFRESH_SECRET; + delete process.env.OPENID_REUSE_TOKENS; + delete process.env.REFRESH_TOKEN_EXPIRY; + delete process.env.OPENID_EXPOSE_SUB_COOKIE; + }); + + describe('openid_sub cookie setting', () => { + it('should set openid_sub cookie with lax sameSite when access token contains sub', () => { + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const openidSubCookieCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_sub', + ); + expect(openidSubCookieCall).toBeDefined(); + expect(openidSubCookieCall[1]).toBe(`mocked-jwt-token-${testOpenIdSub}`); + expect(openidSubCookieCall[2]).toMatchObject({ + httpOnly: true, + sameSite: 'lax', + }); + // Verify secure flag exists (value depends on NODE_ENV) + expect(openidSubCookieCall[2]).toHaveProperty('secure'); + expect(typeof openidSubCookieCall[2].secure).toBe('boolean'); + }); + + it('should not set openid_sub cookie when access token has no sub claim', () => { + const { extractSubFromAccessToken } = require('@librechat/api'); + extractSubFromAccessToken.mockReturnValueOnce({ sub: null, error: 'No sub claim' }); + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const openidSubCookieCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_sub', + ); + expect(openidSubCookieCall).toBeUndefined(); + }); + + it('should not set openid_sub cookie when jwt decode returns null', () => { + const { extractSubFromAccessToken } = require('@librechat/api'); + extractSubFromAccessToken.mockReturnValueOnce({ sub: null, error: 'Decode error' }); + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const openidSubCookieCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_sub', + ); + expect(openidSubCookieCall).toBeUndefined(); + }); + + it('should set openid_sub cookie with secure=false in non-production environment', () => { + process.env.NODE_ENV = 'development'; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const openidSubCookieCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_sub', + ); + expect(openidSubCookieCall).toBeDefined(); + expect(openidSubCookieCall[2]).toMatchObject({ + httpOnly: true, + secure: false, + sameSite: 'lax', + }); + }); + + it('should set openid_sub cookie with correct expiration date', () => { + const expiryInMs = 1000 * 60 * 60 * 24 * 7; // 7 days + const beforeTime = Date.now() + expiryInMs; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const afterTime = Date.now() + expiryInMs; + const openidSubCookieCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_sub', + ); + + expect(openidSubCookieCall[2].expires).toBeInstanceOf(Date); + expect(openidSubCookieCall[2].expires.getTime()).toBeGreaterThanOrEqual(beforeTime); + expect(openidSubCookieCall[2].expires.getTime()).toBeLessThanOrEqual(afterTime); + }); + + it('should use custom REFRESH_TOKEN_EXPIRY for openid_sub cookie expiration', () => { + process.env.REFRESH_TOKEN_EXPIRY = '1000 * 60 * 60 * 24 * 14'; // 14 days + const expiryInMs = 1000 * 60 * 60 * 24 * 14; + const beforeTime = Date.now() + expiryInMs; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const afterTime = Date.now() + expiryInMs; + const openidSubCookieCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_sub', + ); + + expect(openidSubCookieCall[2].expires.getTime()).toBeGreaterThanOrEqual(beforeTime); + expect(openidSubCookieCall[2].expires.getTime()).toBeLessThanOrEqual(afterTime); + }); + }); + + describe('openid_sub cookie with other cookies', () => { + it('should set all cookies including openid_sub', () => { + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + expect(mockRes.cookie).toHaveBeenCalledTimes(4); + expect(mockRes.cookie).toHaveBeenCalledWith( + 'refreshToken', + expect.any(String), + expect.any(Object), + ); + expect(mockRes.cookie).toHaveBeenCalledWith( + 'openid_access_token', + expect.any(String), + expect.any(Object), + ); + expect(mockRes.cookie).toHaveBeenCalledWith('token_provider', 'openid', expect.any(Object)); + expect(mockRes.cookie).toHaveBeenCalledWith( + 'openid_sub', + `mocked-jwt-token-${testOpenIdSub}`, + expect.any(Object), + ); + }); + + it('should set openid_user_id and openid_sub when OPENID_REUSE_TOKENS is enabled', () => { + process.env.OPENID_REUSE_TOKENS = 'true'; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + expect(mockRes.cookie).toHaveBeenCalledTimes(5); + expect(mockRes.cookie).toHaveBeenCalledWith( + 'openid_user_id', + expect.any(String), + expect.any(Object), + ); + expect(mockRes.cookie).toHaveBeenCalledWith( + 'openid_sub', + `mocked-jwt-token-${testOpenIdSub}`, + expect.any(Object), + ); + + const openidUserIdCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_user_id', + ); + const openidSubCall = mockRes.cookie.mock.calls.find((call) => call[0] === 'openid_sub'); + + // openid_user_id uses strict sameSite + expect(openidUserIdCall[2].sameSite).toBe('strict'); + // openid_sub uses lax sameSite + expect(openidSubCall[2].sameSite).toBe('lax'); + }); + + it('should set 3 cookies when access token has no sub claim', () => { + const { extractSubFromAccessToken } = require('@librechat/api'); + extractSubFromAccessToken.mockReturnValueOnce({ sub: null, error: 'No sub claim' }); + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + expect(mockRes.cookie).toHaveBeenCalledTimes(3); + const cookieNames = mockRes.cookie.mock.calls.map((call) => call[0]); + expect(cookieNames).toEqual(['refreshToken', 'openid_access_token', 'token_provider']); + }); + }); + + describe('openid_sub cookie security', () => { + it('should set httpOnly flag on openid_sub cookie', () => { + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const openidSubCookieCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_sub', + ); + expect(openidSubCookieCall[2].httpOnly).toBe(true); + }); + + it('should set sameSite to lax (not strict) on openid_sub cookie', () => { + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const openidSubCookieCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_sub', + ); + expect(openidSubCookieCall[2].sameSite).toBe('lax'); + expect(openidSubCookieCall[2].sameSite).not.toBe('strict'); + }); + + it('should verify other cookies still use strict sameSite', () => { + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const refreshTokenCall = mockRes.cookie.mock.calls.find((call) => call[0] === 'refreshToken'); + const accessTokenCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_access_token', + ); + const providerCall = mockRes.cookie.mock.calls.find((call) => call[0] === 'token_provider'); + + expect(refreshTokenCall[2].sameSite).toBe('strict'); + expect(accessTokenCall[2].sameSite).toBe('strict'); + expect(providerCall[2].sameSite).toBe('strict'); + }); + }); + + describe('error handling', () => { + it('should not throw error when userId is provided but tokenset is invalid', () => { + expect(() => { + setOpenIDAuthTokens(null, mockReq, mockRes, testUserId); + }).not.toThrow(); + + expect(logger.error).toHaveBeenCalledWith( + '[setOpenIDAuthTokens] No tokenset found in request', + ); + }); + + it('should not set openid_sub cookie when tokenset is missing access_token', () => { + const invalidTokenset = { + refresh_token: 'test-refresh-token', + }; + + setOpenIDAuthTokens(invalidTokenset, mockReq, mockRes, testUserId); + + const openidSubCookieCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_sub', + ); + expect(openidSubCookieCall).toBeUndefined(); + }); + + it('should not set openid_sub cookie when refresh_token is missing', () => { + const invalidTokenset = { + access_token: 'test-access-token', + }; + + setOpenIDAuthTokens(invalidTokenset, mockReq, mockRes, testUserId); + + const openidSubCookieCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_sub', + ); + expect(openidSubCookieCall).toBeUndefined(); + }); + + it('should handle error during cookie setting gracefully', () => { + mockRes.cookie = jest.fn(() => { + throw new Error('Cookie setting failed'); + }); + + expect(() => { + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + }).toThrow('Cookie setting failed'); + + expect(logger.error).toHaveBeenCalledWith( + '[setOpenIDAuthTokens] Error in setting authentication tokens:', + expect.any(Error), + ); + }); + + it('should handle jwt decode throwing an error', () => { + const { extractSubFromAccessToken } = require('@librechat/api'); + extractSubFromAccessToken.mockReturnValueOnce({ sub: null, error: 'Invalid JWT' }); + + // Should not throw, just skip setting openid_sub cookie + expect(() => { + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + }).not.toThrow(); + + const openidSubCookieCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_sub', + ); + expect(openidSubCookieCall).toBeUndefined(); + }); + }); + + describe('integration with existing refresh token', () => { + it('should set openid_sub cookie when using existingRefreshToken', () => { + const tokensetWithoutRefresh = { + access_token: 'test-access-token', + }; + const existingRefreshToken = 'existing-refresh-token'; + + setOpenIDAuthTokens( + tokensetWithoutRefresh, + mockReq, + mockRes, + testUserId, + existingRefreshToken, + ); + + const openidSubCookieCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_sub', + ); + expect(openidSubCookieCall).toBeDefined(); + expect(openidSubCookieCall[1]).toBe(`mocked-jwt-token-${testOpenIdSub}`); + }); + + it('should prefer tokenset refresh_token over existingRefreshToken', () => { + const existingRefreshToken = 'existing-refresh-token'; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId, existingRefreshToken); + + const refreshTokenCall = mockRes.cookie.mock.calls.find((call) => call[0] === 'refreshToken'); + expect(refreshTokenCall[1]).toBe(mockTokenset.refresh_token); + expect(refreshTokenCall[1]).not.toBe(existingRefreshToken); + + const openidSubCookieCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_sub', + ); + expect(openidSubCookieCall).toBeDefined(); + }); + }); + + describe('feature flag behavior', () => { + it('should not set openid_sub cookie when OPENID_EXPOSE_SUB_COOKIE is false', () => { + process.env.OPENID_EXPOSE_SUB_COOKIE = 'false'; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const openidSubCookieCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_sub', + ); + expect(openidSubCookieCall).toBeUndefined(); + + // Other cookies should still be set + expect(mockRes.cookie).toHaveBeenCalledTimes(3); + const cookieNames = mockRes.cookie.mock.calls.map((call) => call[0]); + expect(cookieNames).toEqual(['refreshToken', 'openid_access_token', 'token_provider']); + }); + + it('should not set openid_sub cookie when OPENID_EXPOSE_SUB_COOKIE is not set', () => { + delete process.env.OPENID_EXPOSE_SUB_COOKIE; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const openidSubCookieCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_sub', + ); + expect(openidSubCookieCall).toBeUndefined(); + + // Other cookies should still be set + expect(mockRes.cookie).toHaveBeenCalledTimes(3); + }); + + it('should set openid_sub cookie when OPENID_EXPOSE_SUB_COOKIE is true', () => { + process.env.OPENID_EXPOSE_SUB_COOKIE = 'true'; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + const openidSubCookieCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_sub', + ); + expect(openidSubCookieCall).toBeDefined(); + expect(openidSubCookieCall[1]).toBe(`mocked-jwt-token-${testOpenIdSub}`); + }); + + it('should set openid_sub cookie with OPENID_REUSE_TOKENS and OPENID_EXPOSE_SUB_COOKIE both enabled', () => { + process.env.OPENID_REUSE_TOKENS = 'true'; + process.env.OPENID_EXPOSE_SUB_COOKIE = 'true'; + + setOpenIDAuthTokens(mockTokenset, mockReq, mockRes, testUserId); + + expect(mockRes.cookie).toHaveBeenCalledTimes(5); + + const openidUserIdCall = mockRes.cookie.mock.calls.find( + (call) => call[0] === 'openid_user_id', + ); + const openidSubCall = mockRes.cookie.mock.calls.find((call) => call[0] === 'openid_sub'); + + expect(openidUserIdCall).toBeDefined(); + expect(openidSubCall).toBeDefined(); + + // openid_user_id uses strict sameSite + expect(openidUserIdCall[2].sameSite).toBe('strict'); + // openid_sub uses lax sameSite + expect(openidSubCall[2].sameSite).toBe('lax'); + }); + }); +}); diff --git a/api/server/services/__tests__/ToolService.spec.js b/api/server/services/__tests__/ToolService.spec.js new file mode 100644 index 000000000000..2f00bbc3d6ac --- /dev/null +++ b/api/server/services/__tests__/ToolService.spec.js @@ -0,0 +1,149 @@ +const { AgentCapabilities, defaultAgentCapabilities } = require('librechat-data-provider'); + +/** + * Tests for ToolService capability checking logic. + * The actual loadAgentTools function has many dependencies, so we test + * the capability checking logic in isolation. + */ +describe('ToolService - Capability Checking', () => { + describe('checkCapability logic', () => { + /** + * Simulates the checkCapability function from loadAgentTools + */ + const createCheckCapability = (enabledCapabilities, logger = { warn: jest.fn() }) => { + return (capability) => { + const enabled = enabledCapabilities.has(capability); + if (!enabled) { + const isToolCapability = [ + AgentCapabilities.file_search, + AgentCapabilities.execute_code, + AgentCapabilities.web_search, + ].includes(capability); + const suffix = isToolCapability ? ' despite configured tool.' : '.'; + logger.warn(`Capability "${capability}" disabled${suffix}`); + } + return enabled; + }; + }; + + it('should return true when capability is enabled', () => { + const enabledCapabilities = new Set([AgentCapabilities.deferred_tools]); + const checkCapability = createCheckCapability(enabledCapabilities); + + expect(checkCapability(AgentCapabilities.deferred_tools)).toBe(true); + }); + + it('should return false when capability is not enabled', () => { + const enabledCapabilities = new Set([]); + const checkCapability = createCheckCapability(enabledCapabilities); + + expect(checkCapability(AgentCapabilities.deferred_tools)).toBe(false); + }); + + it('should log warning with "despite configured tool" for tool capabilities', () => { + const logger = { warn: jest.fn() }; + const enabledCapabilities = new Set([]); + const checkCapability = createCheckCapability(enabledCapabilities, logger); + + checkCapability(AgentCapabilities.file_search); + expect(logger.warn).toHaveBeenCalledWith(expect.stringContaining('despite configured tool')); + + logger.warn.mockClear(); + checkCapability(AgentCapabilities.execute_code); + expect(logger.warn).toHaveBeenCalledWith(expect.stringContaining('despite configured tool')); + + logger.warn.mockClear(); + checkCapability(AgentCapabilities.web_search); + expect(logger.warn).toHaveBeenCalledWith(expect.stringContaining('despite configured tool')); + }); + + it('should log warning without "despite configured tool" for non-tool capabilities', () => { + const logger = { warn: jest.fn() }; + const enabledCapabilities = new Set([]); + const checkCapability = createCheckCapability(enabledCapabilities, logger); + + checkCapability(AgentCapabilities.deferred_tools); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Capability "deferred_tools" disabled.'), + ); + expect(logger.warn).not.toHaveBeenCalledWith( + expect.stringContaining('despite configured tool'), + ); + + logger.warn.mockClear(); + checkCapability(AgentCapabilities.tools); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Capability "tools" disabled.'), + ); + expect(logger.warn).not.toHaveBeenCalledWith( + expect.stringContaining('despite configured tool'), + ); + + logger.warn.mockClear(); + checkCapability(AgentCapabilities.actions); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Capability "actions" disabled.'), + ); + }); + + it('should not log warning when capability is enabled', () => { + const logger = { warn: jest.fn() }; + const enabledCapabilities = new Set([ + AgentCapabilities.deferred_tools, + AgentCapabilities.file_search, + ]); + const checkCapability = createCheckCapability(enabledCapabilities, logger); + + checkCapability(AgentCapabilities.deferred_tools); + checkCapability(AgentCapabilities.file_search); + + expect(logger.warn).not.toHaveBeenCalled(); + }); + }); + + describe('defaultAgentCapabilities', () => { + it('should include deferred_tools capability by default', () => { + expect(defaultAgentCapabilities).toContain(AgentCapabilities.deferred_tools); + }); + + it('should include all expected default capabilities', () => { + expect(defaultAgentCapabilities).toContain(AgentCapabilities.execute_code); + expect(defaultAgentCapabilities).toContain(AgentCapabilities.file_search); + expect(defaultAgentCapabilities).toContain(AgentCapabilities.web_search); + expect(defaultAgentCapabilities).toContain(AgentCapabilities.artifacts); + expect(defaultAgentCapabilities).toContain(AgentCapabilities.actions); + expect(defaultAgentCapabilities).toContain(AgentCapabilities.context); + expect(defaultAgentCapabilities).toContain(AgentCapabilities.tools); + expect(defaultAgentCapabilities).toContain(AgentCapabilities.chain); + expect(defaultAgentCapabilities).toContain(AgentCapabilities.ocr); + }); + }); + + describe('deferredToolsEnabled integration', () => { + it('should correctly determine deferredToolsEnabled from capabilities set', () => { + const createCheckCapability = (enabledCapabilities) => { + return (capability) => enabledCapabilities.has(capability); + }; + + // When deferred_tools is in capabilities + const withDeferred = new Set([AgentCapabilities.deferred_tools, AgentCapabilities.tools]); + const checkWithDeferred = createCheckCapability(withDeferred); + expect(checkWithDeferred(AgentCapabilities.deferred_tools)).toBe(true); + + // When deferred_tools is NOT in capabilities + const withoutDeferred = new Set([AgentCapabilities.tools, AgentCapabilities.actions]); + const checkWithoutDeferred = createCheckCapability(withoutDeferred); + expect(checkWithoutDeferred(AgentCapabilities.deferred_tools)).toBe(false); + }); + + it('should use defaultAgentCapabilities when no capabilities configured', () => { + // Simulates the fallback behavior in loadAgentTools + const endpointsConfig = {}; // No capabilities configured + const enabledCapabilities = new Set( + endpointsConfig?.capabilities ?? defaultAgentCapabilities, + ); + + expect(enabledCapabilities.has(AgentCapabilities.deferred_tools)).toBe(true); + }); + }); +}); diff --git a/api/strategies/index.js b/api/strategies/index.js index 725e04224a03..b4f7bd3cac34 100644 --- a/api/strategies/index.js +++ b/api/strategies/index.js @@ -1,14 +1,14 @@ -const appleLogin = require('./appleStrategy'); +const { setupOpenId, getOpenIdConfig } = require('./openidStrategy'); +const openIdJwtLogin = require('./openIdJwtStrategy'); +const facebookLogin = require('./facebookStrategy'); +const discordLogin = require('./discordStrategy'); const passportLogin = require('./localStrategy'); const googleLogin = require('./googleStrategy'); const githubLogin = require('./githubStrategy'); -const discordLogin = require('./discordStrategy'); -const facebookLogin = require('./facebookStrategy'); -const { setupOpenId, getOpenIdConfig } = require('./openidStrategy'); -const jwtLogin = require('./jwtStrategy'); -const ldapLogin = require('./ldapStrategy'); const { setupSaml } = require('./samlStrategy'); -const openIdJwtLogin = require('./openIdJwtStrategy'); +const appleLogin = require('./appleStrategy'); +const ldapLogin = require('./ldapStrategy'); +const jwtLogin = require('./jwtStrategy'); module.exports = { appleLogin, diff --git a/api/strategies/openidStrategy.js b/api/strategies/openidStrategy.js index a4369e601b42..84458ce99256 100644 --- a/api/strategies/openidStrategy.js +++ b/api/strategies/openidStrategy.js @@ -6,8 +6,8 @@ const client = require('openid-client'); const jwtDecode = require('jsonwebtoken/decode'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { hashToken, logger } = require('@librechat/data-schemas'); -const { CacheKeys, ErrorTypes } = require('librechat-data-provider'); const { Strategy: OpenIDStrategy } = require('openid-client/passport'); +const { CacheKeys, ErrorTypes, SystemRoles } = require('librechat-data-provider'); const { isEnabled, logHeaders, @@ -287,6 +287,274 @@ function convertToUsername(input, defaultValue = '') { return defaultValue; } +/** + * Process OpenID authentication tokenset and userinfo + * This is the core logic extracted from the passport strategy callback + * Can be reused by both the passport strategy and proxy authentication + * + * @param {Object} tokenset - The OpenID tokenset containing access_token, id_token, etc. + * @param {boolean} existingUsersOnly - If true, only existing users will be processed + * @returns {Promise} The authenticated user object with tokenset + */ +async function processOpenIDAuth(tokenset, existingUsersOnly = false) { + const claims = tokenset.claims ? tokenset.claims() : tokenset; + const userinfo = { + ...claims, + }; + + if (tokenset.access_token) { + const providerUserinfo = await getUserInfo(openidConfig, tokenset.access_token, claims.sub); + Object.assign(userinfo, providerUserinfo); + } + + const appConfig = await getAppConfig(); + /** Azure AD sometimes doesn't return email, use preferred_username as fallback */ + const email = userinfo.email || userinfo.preferred_username || userinfo.upn; + if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + logger.error( + `[OpenID Strategy] Authentication blocked - email domain not allowed [Email: ${userinfo.email}]`, + ); + throw new Error('Email domain not allowed'); + } + + const result = await findOpenIDUser({ + findUser, + email: email, + openidId: claims.sub || userinfo.sub, + idOnTheSource: claims.oid || userinfo.oid, + strategyName: 'openidStrategy', + }); + let user = result.user; + const error = result.error; + + if (error) { + throw new Error(ErrorTypes.AUTH_FAILED); + } + + const fullName = getFullName(userinfo); + + const requiredRole = process.env.OPENID_REQUIRED_ROLE; + if (requiredRole) { + const requiredRoles = requiredRole + .split(',') + .map((role) => role.trim()) + .filter(Boolean); + const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH; + const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND; + + let decodedToken = ''; + if (requiredRoleTokenKind === 'access' && tokenset.access_token) { + decodedToken = jwtDecode(tokenset.access_token); + } else if (requiredRoleTokenKind === 'id' && tokenset.id_token) { + decodedToken = jwtDecode(tokenset.id_token); + } + + let roles = get(decodedToken, requiredRoleParameterPath); + if (!roles || (!Array.isArray(roles) && typeof roles !== 'string')) { + logger.error( + `[openidStrategy] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`, + ); + const rolesList = + requiredRoles.length === 1 + ? `"${requiredRoles[0]}"` + : `one of: ${requiredRoles.map((r) => `"${r}"`).join(', ')}`; + throw new Error(`You must have ${rolesList} role to log in.`); + } + + if (!requiredRoles.some((role) => roles.includes(role))) { + const rolesList = + requiredRoles.length === 1 + ? `"${requiredRoles[0]}"` + : `one of: ${requiredRoles.map((r) => `"${r}"`).join(', ')}`; + throw new Error(`You must have ${rolesList} role to log in.`); + } + } + + let username = ''; + if (process.env.OPENID_USERNAME_CLAIM) { + username = userinfo[process.env.OPENID_USERNAME_CLAIM]; + } else { + username = convertToUsername( + userinfo.preferred_username || userinfo.username || userinfo.email, + ); + } + + if (existingUsersOnly && !user) { + throw new Error('User does not exist'); + } + + if (!user) { + user = { + provider: 'openid', + openidId: userinfo.sub, + username, + email: email || '', + emailVerified: userinfo.email_verified || false, + name: fullName, + idOnTheSource: userinfo.oid, + }; + + const balanceConfig = getBalanceConfig(appConfig); + user = await createUser(user, balanceConfig, true, true); + } else { + user.provider = 'openid'; + user.openidId = userinfo.sub; + user.username = username; + user.name = fullName; + user.idOnTheSource = userinfo.oid; + if (email && email !== user.email) { + user.email = email; + user.emailVerified = userinfo.email_verified || false; + } + } + + const adminRole = process.env.OPENID_ADMIN_ROLE; + const adminRoleParameterPath = process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH; + const adminRoleTokenKind = process.env.OPENID_ADMIN_ROLE_TOKEN_KIND; + + if (adminRole && adminRoleParameterPath && adminRoleTokenKind) { + let adminRoleObject; + switch (adminRoleTokenKind) { + case 'access': + adminRoleObject = jwtDecode(tokenset.access_token); + break; + case 'id': + adminRoleObject = jwtDecode(tokenset.id_token); + break; + case 'userinfo': + adminRoleObject = userinfo; + break; + default: + logger.error( + `[openidStrategy] Invalid admin role token kind: ${adminRoleTokenKind}. Must be one of 'access', 'id', or 'userinfo'.`, + ); + throw new Error('Invalid admin role token kind'); + } + + const adminRoles = get(adminRoleObject, adminRoleParameterPath); + + if ( + adminRoles && + (adminRoles === true || + adminRoles === adminRole || + (Array.isArray(adminRoles) && adminRoles.includes(adminRole))) + ) { + user.role = SystemRoles.ADMIN; + logger.info(`[openidStrategy] User ${username} is an admin based on role: ${adminRole}`); + } else if (user.role === SystemRoles.ADMIN) { + user.role = SystemRoles.USER; + logger.info( + `[openidStrategy] User ${username} demoted from admin - role no longer present in token`, + ); + } + } + + if (!!userinfo && userinfo.picture && !user.avatar?.includes('manual=true')) { + /** @type {string | undefined} */ + const imageUrl = userinfo.picture; + + let fileName; + if (crypto) { + fileName = (await hashToken(userinfo.sub)) + '.png'; + } else { + fileName = userinfo.sub + '.png'; + } + + const imageBuffer = await downloadImage( + imageUrl, + openidConfig, + tokenset.access_token, + userinfo.sub, + ); + if (imageBuffer) { + const { saveBuffer } = getStrategyFunctions( + appConfig?.fileStrategy ?? process.env.CDN_PROVIDER, + ); + const imagePath = await saveBuffer({ + fileName, + userId: user._id.toString(), + buffer: imageBuffer, + }); + user.avatar = imagePath ?? ''; + } + } + + user = await updateUser(user._id, user); + + logger.info( + `[openidStrategy] login success openidId: ${user.openidId} | email: ${user.email} | username: ${user.username} `, + { + user: { + openidId: user.openidId, + username: user.username, + email: user.email, + name: user.name, + }, + }, + ); + + return { + ...user, + tokenset, + federatedTokens: { + access_token: tokenset.access_token, + refresh_token: tokenset.refresh_token, + expires_at: tokenset.expires_at, + }, + }; +} + +/** + * @param {boolean | undefined} [existingUsersOnly] + */ +function createOpenIDCallback(existingUsersOnly) { + return async (tokenset, done) => { + try { + const user = await processOpenIDAuth(tokenset, existingUsersOnly); + done(null, user); + } catch (err) { + if (err.message === 'Email domain not allowed') { + return done(null, false, { message: err.message }); + } + if (err.message === ErrorTypes.AUTH_FAILED) { + return done(null, false, { message: err.message }); + } + if (err.message && err.message.includes('role to log in')) { + return done(null, false, { message: err.message }); + } + logger.error('[openidStrategy] login failed', err); + done(err); + } + }; +} + +/** + * Sets up the OpenID strategy specifically for admin authentication. + * @param {Configuration} openidConfig + */ +const setupOpenIdAdmin = (openidConfig) => { + try { + if (!openidConfig) { + throw new Error('OpenID configuration not initialized'); + } + + const openidAdminLogin = new CustomOpenIDStrategy( + { + config: openidConfig, + scope: process.env.OPENID_SCOPE, + usePKCE: isEnabled(process.env.OPENID_USE_PKCE), + clockTolerance: process.env.OPENID_CLOCK_TOLERANCE || 300, + callbackURL: process.env.DOMAIN_SERVER + '/api/admin/oauth/openid/callback', + }, + createOpenIDCallback(true), + ); + + passport.use('openidAdmin', openidAdminLogin); + } catch (err) { + logger.error('[openidStrategy] setupOpenIdAdmin', err); + } +}; + /** * Sets up the OpenID strategy for authentication. * This function configures the OpenID client, handles proxy settings, @@ -324,10 +592,6 @@ async function setupOpenId() { }, ); - const requiredRole = process.env.OPENID_REQUIRED_ROLE; - const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH; - const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND; - const usePKCE = isEnabled(process.env.OPENID_USE_PKCE); logger.info(`[openidStrategy] OpenID authentication configuration`, { generateNonce: shouldGenerateNonce, reason: shouldGenerateNonce @@ -335,241 +599,25 @@ async function setupOpenId() { : 'OPENID_GENERATE_NONCE=false - Standard flow without explicit nonce or metadata', }); - // Set of env variables that specify how to set if a user is an admin - // If not set, all users will be treated as regular users - const adminRole = process.env.OPENID_ADMIN_ROLE; - const adminRoleParameterPath = process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH; - const adminRoleTokenKind = process.env.OPENID_ADMIN_ROLE_TOKEN_KIND; - const openidLogin = new CustomOpenIDStrategy( { config: openidConfig, scope: process.env.OPENID_SCOPE, callbackURL: process.env.DOMAIN_SERVER + process.env.OPENID_CALLBACK_URL, clockTolerance: process.env.OPENID_CLOCK_TOLERANCE || 300, - usePKCE, - }, - /** - * @param {import('openid-client').TokenEndpointResponseHelpers} tokenset - * @param {import('passport-jwt').VerifyCallback} done - */ - async (tokenset, done) => { - try { - const claims = tokenset.claims(); - const userinfo = { - ...claims, - ...(await getUserInfo(openidConfig, tokenset.access_token, claims.sub)), - }; - - const appConfig = await getAppConfig(); - /** Azure AD sometimes doesn't return email, use preferred_username as fallback */ - const email = userinfo.email || userinfo.preferred_username || userinfo.upn; - if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { - logger.error( - `[OpenID Strategy] Authentication blocked - email domain not allowed [Email: ${email}]`, - ); - return done(null, false, { message: 'Email domain not allowed' }); - } - - const result = await findOpenIDUser({ - findUser, - email: email, - openidId: claims.sub, - idOnTheSource: claims.oid, - strategyName: 'openidStrategy', - }); - let user = result.user; - const error = result.error; - - if (error) { - return done(null, false, { - message: ErrorTypes.AUTH_FAILED, - }); - } - - const fullName = getFullName(userinfo); - - if (requiredRole) { - const requiredRoles = requiredRole - .split(',') - .map((role) => role.trim()) - .filter(Boolean); - let decodedToken = ''; - if (requiredRoleTokenKind === 'access') { - decodedToken = jwtDecode(tokenset.access_token); - } else if (requiredRoleTokenKind === 'id') { - decodedToken = jwtDecode(tokenset.id_token); - } - - let roles = get(decodedToken, requiredRoleParameterPath); - if (!roles || (!Array.isArray(roles) && typeof roles !== 'string')) { - logger.error( - `[openidStrategy] Key '${requiredRoleParameterPath}' not found or invalid type in ${requiredRoleTokenKind} token!`, - ); - const rolesList = - requiredRoles.length === 1 - ? `"${requiredRoles[0]}"` - : `one of: ${requiredRoles.map((r) => `"${r}"`).join(', ')}`; - return done(null, false, { - message: `You must have ${rolesList} role to log in.`, - }); - } - - if (!requiredRoles.some((role) => roles.includes(role))) { - const rolesList = - requiredRoles.length === 1 - ? `"${requiredRoles[0]}"` - : `one of: ${requiredRoles.map((r) => `"${r}"`).join(', ')}`; - return done(null, false, { - message: `You must have ${rolesList} role to log in.`, - }); - } - } - - let username = ''; - if (process.env.OPENID_USERNAME_CLAIM) { - username = userinfo[process.env.OPENID_USERNAME_CLAIM]; - } else { - username = convertToUsername( - userinfo.preferred_username || userinfo.username || userinfo.email, - ); - } - - if (!user) { - user = { - provider: 'openid', - openidId: userinfo.sub, - username, - email: email || '', - emailVerified: userinfo.email_verified || false, - name: fullName, - idOnTheSource: userinfo.oid, - }; - - const balanceConfig = getBalanceConfig(appConfig); - user = await createUser(user, balanceConfig, true, true); - } else { - user.provider = 'openid'; - user.openidId = userinfo.sub; - user.username = username; - user.name = fullName; - user.idOnTheSource = userinfo.oid; - if (email && email !== user.email) { - user.email = email; - user.emailVerified = userinfo.email_verified || false; - } - } - - if (adminRole && adminRoleParameterPath && adminRoleTokenKind) { - let adminRoleObject; - switch (adminRoleTokenKind) { - case 'access': - adminRoleObject = jwtDecode(tokenset.access_token); - break; - case 'id': - adminRoleObject = jwtDecode(tokenset.id_token); - break; - case 'userinfo': - adminRoleObject = userinfo; - break; - default: - logger.error( - `[openidStrategy] Invalid admin role token kind: ${adminRoleTokenKind}. Must be one of 'access', 'id', or 'userinfo'.`, - ); - return done(new Error('Invalid admin role token kind')); - } - - const adminRoles = get(adminRoleObject, adminRoleParameterPath); - - // Accept 3 types of values for the object extracted from adminRoleParameterPath: - // 1. A boolean value indicating if the user is an admin - // 2. A string with a single role name - // 3. An array of role names - - if ( - adminRoles && - (adminRoles === true || - adminRoles === adminRole || - (Array.isArray(adminRoles) && adminRoles.includes(adminRole))) - ) { - user.role = 'ADMIN'; - logger.info( - `[openidStrategy] User ${username} is an admin based on role: ${adminRole}`, - ); - } else if (user.role === 'ADMIN') { - user.role = 'USER'; - logger.info( - `[openidStrategy] User ${username} demoted from admin - role no longer present in token`, - ); - } - } - - if (!!userinfo && userinfo.picture && !user.avatar?.includes('manual=true')) { - /** @type {string | undefined} */ - const imageUrl = userinfo.picture; - - let fileName; - if (crypto) { - fileName = (await hashToken(userinfo.sub)) + '.png'; - } else { - fileName = userinfo.sub + '.png'; - } - - const imageBuffer = await downloadImage( - imageUrl, - openidConfig, - tokenset.access_token, - userinfo.sub, - ); - if (imageBuffer) { - const { saveBuffer } = getStrategyFunctions( - appConfig?.fileStrategy ?? process.env.CDN_PROVIDER, - ); - const imagePath = await saveBuffer({ - fileName, - userId: user._id.toString(), - buffer: imageBuffer, - }); - user.avatar = imagePath ?? ''; - } - } - - user = await updateUser(user._id, user); - - logger.info( - `[openidStrategy] login success openidId: ${user.openidId} | email: ${user.email} | username: ${user.username} `, - { - user: { - openidId: user.openidId, - username: user.username, - email: user.email, - name: user.name, - }, - }, - ); - - done(null, { - ...user, - tokenset, - federatedTokens: { - access_token: tokenset.access_token, - refresh_token: tokenset.refresh_token, - expires_at: tokenset.expires_at, - }, - }); - } catch (err) { - logger.error('[openidStrategy] login failed', err); - done(err); - } + usePKCE: isEnabled(process.env.OPENID_USE_PKCE), }, + createOpenIDCallback(), ); passport.use('openid', openidLogin); + setupOpenIdAdmin(openidConfig); return openidConfig; } catch (err) { logger.error('[openidStrategy]', err); return null; } } + /** * @function getOpenIdConfig * @description Returns the OpenID client instance. diff --git a/api/strategies/openidStrategy.spec.js b/api/strategies/openidStrategy.spec.js index 9ac22ff42f23..ada27cca1727 100644 --- a/api/strategies/openidStrategy.spec.js +++ b/api/strategies/openidStrategy.spec.js @@ -64,21 +64,36 @@ jest.mock('openid-client', () => { }); jest.mock('openid-client/passport', () => { - let verifyCallback; + /** Store callbacks by strategy name - 'openid' and 'openidAdmin' */ + const verifyCallbacks = {}; + let lastVerifyCallback; + const mockStrategy = jest.fn((options, verify) => { - verifyCallback = verify; + lastVerifyCallback = verify; return { name: 'openid', options, verify }; }); return { Strategy: mockStrategy, - __getVerifyCallback: () => verifyCallback, + /** Get the last registered callback (for backward compatibility) */ + __getVerifyCallback: () => lastVerifyCallback, + /** Store callback by name when passport.use is called */ + __setVerifyCallback: (name, callback) => { + verifyCallbacks[name] = callback; + }, + /** Get callback by strategy name */ + __getVerifyCallbackByName: (name) => verifyCallbacks[name], }; }); -// Mock passport +// Mock passport - capture strategy name and callback jest.mock('passport', () => ({ - use: jest.fn(), + use: jest.fn((name, strategy) => { + const passportMock = require('openid-client/passport'); + if (strategy && strategy.verify) { + passportMock.__setVerifyCallback(name, strategy.verify); + } + }), })); describe('setupOpenId', () => { @@ -159,9 +174,10 @@ describe('setupOpenId', () => { }; fetch.mockResolvedValue(fakeResponse); - // Call the setup function and capture the verify callback + // Call the setup function and capture the verify callback for the regular 'openid' strategy + // (not 'openidAdmin' which requires existing users) await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); }); it('should create a new user with correct username when preferred_username claim exists', async () => { @@ -389,7 +405,7 @@ describe('setupOpenId', () => { // Arrange process.env.OPENID_REQUIRED_ROLE = 'someRole,anotherRole,admin'; await setupOpenId(); // Re-initialize the strategy - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); jwtDecode.mockReturnValue({ roles: ['anotherRole', 'aThirdRole'], }); @@ -406,7 +422,7 @@ describe('setupOpenId', () => { // Arrange process.env.OPENID_REQUIRED_ROLE = 'someRole,anotherRole,admin'; await setupOpenId(); // Re-initialize the strategy - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); jwtDecode.mockReturnValue({ roles: ['aThirdRole', 'aFourthRole'], }); @@ -425,7 +441,7 @@ describe('setupOpenId', () => { // Arrange process.env.OPENID_REQUIRED_ROLE = ' someRole , anotherRole , admin '; await setupOpenId(); // Re-initialize the strategy - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); jwtDecode.mockReturnValue({ roles: ['someRole'], }); @@ -560,7 +576,7 @@ describe('setupOpenId', () => { delete process.env.OPENID_ADMIN_ROLE_TOKEN_KIND; await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); // Simulate an existing admin user const existingAdminUser = { @@ -611,7 +627,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); @@ -634,7 +650,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); @@ -655,14 +671,12 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user, details } = await validate(tokenset); expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining( - "Key 'resource_access.nonexistent.roles' not found or invalid type in id token!", - ), + expect.stringContaining("Key 'resource_access.nonexistent.roles' not found in id token!"), ); expect(user).toBe(false); expect(details.message).toContain('role to log in'); @@ -680,12 +694,12 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'org.team.roles' not found or invalid type in id token!"), + expect.stringContaining("Key 'org.team.roles' not found in id token!"), ); expect(user).toBe(false); }); @@ -709,7 +723,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); @@ -739,7 +753,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate({ ...tokenset, @@ -759,7 +773,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); @@ -776,7 +790,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); @@ -793,7 +807,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); @@ -810,7 +824,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); @@ -827,7 +841,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); @@ -847,7 +861,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); @@ -864,12 +878,12 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'access.roles' not found or invalid type in id token!"), + expect.stringContaining("Key 'access.roles' not found in id token!"), ); expect(user).toBe(false); }); @@ -884,12 +898,12 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'data.roles' not found or invalid type in id token!"), + expect.stringContaining("Key 'data.roles' not found in id token!"), ); expect(user).toBe(false); }); @@ -906,7 +920,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); await expect(validate(tokenset)).rejects.toThrow('Invalid admin role token kind'); @@ -927,12 +941,12 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user, details } = await validate(tokenset); expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'roles' not found or invalid type in id token!"), + expect.stringContaining("Key 'roles' not found in id token!"), ); expect(user).toBe(false); expect(details.message).toContain('role to log in'); @@ -948,12 +962,12 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'roleCount' not found or invalid type in id token!"), + expect.stringContaining("Key 'roleCount' not found in id token!"), ); expect(user).toBe(false); }); diff --git a/client/src/common/agents-types.ts b/client/src/common/agents-types.ts index 9ac6b440a397..c3832b7ff8da 100644 --- a/client/src/common/agents-types.ts +++ b/client/src/common/agents-types.ts @@ -1,6 +1,7 @@ import { AgentCapabilities, ArtifactModes } from 'librechat-data-provider'; import type { AgentModelParameters, + AgentToolOptions, SupportContact, AgentProvider, GraphEdge, @@ -33,6 +34,8 @@ export type AgentForm = { model: string | null; model_parameters: AgentModelParameters; tools?: string[]; + /** Per-tool configuration options (deferred loading, allowed callers, etc.) */ + tool_options?: AgentToolOptions; provider?: AgentProvider | OptionWithIcon; /** @deprecated Use edges instead */ agent_ids?: string[]; diff --git a/client/src/components/Chat/Messages/Content/Part.tsx b/client/src/components/Chat/Messages/Content/Part.tsx index bfa2b28fac65..4a74e3606f77 100644 --- a/client/src/components/Chat/Messages/Content/Part.tsx +++ b/client/src/components/Chat/Messages/Content/Part.tsx @@ -91,7 +91,11 @@ const Part = memo( const isToolCall = 'args' in toolCall && (!toolCall.type || toolCall.type === ToolCallTypes.TOOL_CALL); - if (isToolCall && toolCall.name === Tools.execute_code) { + if ( + isToolCall && + (toolCall.name === Tools.execute_code || + toolCall.name === Constants.PROGRAMMATIC_TOOL_CALLING) + ) { return ( }) => { const [isVisible, setIsVisible] = useState(false); + const file = attachment as TFile & TAttachmentMetadata; const { handleDownload } = useAttachmentLink({ href: attachment.filepath ?? '', filename: attachment.filename ?? '', + file_id: file.file_id, + user: file.user, + source: file.source, }); const extension = attachment.filename?.split('.').pop(); diff --git a/client/src/components/Chat/Messages/Content/Parts/ExecuteCode.tsx b/client/src/components/Chat/Messages/Content/Parts/ExecuteCode.tsx index 729011bdbd5e..2f14ac0f13d9 100644 --- a/client/src/components/Chat/Messages/Content/Parts/ExecuteCode.tsx +++ b/client/src/components/Chat/Messages/Content/Parts/ExecuteCode.tsx @@ -67,7 +67,7 @@ export default function ExecuteCode({ const [contentHeight, setContentHeight] = useState(0); const prevShowCodeRef = useRef(showCode); - const { lang, code } = useParseArgs(args) ?? ({} as ParsedArgs); + const { lang = 'py', code } = useParseArgs(args) ?? ({} as ParsedArgs); const progress = useProgress(initialProgress); useEffect(() => { diff --git a/client/src/components/Chat/Messages/Content/Parts/LogContent.tsx b/client/src/components/Chat/Messages/Content/Parts/LogContent.tsx index d2a303f49f33..da2a8f175eb9 100644 --- a/client/src/components/Chat/Messages/Content/Parts/LogContent.tsx +++ b/client/src/components/Chat/Messages/Content/Parts/LogContent.tsx @@ -65,6 +65,7 @@ const LogContent: React.FC = ({ output = '', renderImages, atta return `${filename} ${localize('com_download_expired')}`; } + const fileData = file as TFile & TAttachmentMetadata; const filepath = file.filepath || ''; // const expirationText = expiresAt @@ -72,7 +73,13 @@ const LogContent: React.FC = ({ output = '', renderImages, atta // : ` ${localize('com_click_to_download')}`; return ( - + {'- '} {filename} {localize('com_click_to_download')} diff --git a/client/src/components/Chat/Messages/Content/Parts/LogLink.tsx b/client/src/components/Chat/Messages/Content/Parts/LogLink.tsx index d328f202ee93..070becf51776 100644 --- a/client/src/components/Chat/Messages/Content/Parts/LogLink.tsx +++ b/client/src/components/Chat/Messages/Content/Parts/LogLink.tsx @@ -1,21 +1,56 @@ import React from 'react'; +import { FileSources } from 'librechat-data-provider'; import { useToastContext } from '@librechat/client'; -import { useCodeOutputDownload } from '~/data-provider'; +import { useCodeOutputDownload, useFileDownload } from '~/data-provider'; interface LogLinkProps { href: string; filename: string; + file_id?: string; + user?: string; + source?: string; children: React.ReactNode; } -export const useAttachmentLink = ({ href, filename }: Pick) => { +interface AttachmentLinkOptions { + href: string; + filename: string; + file_id?: string; + user?: string; + source?: string; +} + +/** + * Determines if a file is stored locally (not an external API URL). + * Files with these sources are stored on the LibreChat server and should + * use the /api/files/download endpoint instead of direct URL access. + */ +const isLocallyStoredSource = (source?: string): boolean => { + if (!source) { + return false; + } + return [FileSources.local, FileSources.firebase, FileSources.s3, FileSources.azure_blob].includes( + source as FileSources, + ); +}; + +export const useAttachmentLink = ({ + href, + filename, + file_id, + user, + source, +}: AttachmentLinkOptions) => { const { showToast } = useToastContext(); - const { refetch: downloadFile } = useCodeOutputDownload(href); + + const useLocalDownload = isLocallyStoredSource(source) && !!file_id && !!user; + const { refetch: downloadFromApi } = useFileDownload(user, file_id); + const { refetch: downloadFromUrl } = useCodeOutputDownload(href); const handleDownload = async (event: React.MouseEvent) => { event.preventDefault(); try { - const stream = await downloadFile(); + const stream = useLocalDownload ? await downloadFromApi() : await downloadFromUrl(); if (stream.data == null || stream.data === '') { console.error('Error downloading file: No data found'); showToast({ @@ -39,8 +74,8 @@ export const useAttachmentLink = ({ href, filename }: Pick = ({ href, filename, children }) => { - const { handleDownload } = useAttachmentLink({ href, filename }); +const LogLink: React.FC = ({ href, filename, file_id, user, source, children }) => { + const { handleDownload } = useAttachmentLink({ href, filename, file_id, user, source }); return ( void }) { + const localize = useLocalize(); + const { showToast } = useToastContext(); + const [open, setOpen] = useState(false); + const [name, setName] = useState(''); + const [newKey, setNewKey] = useState(null); + const [showKey, setShowKey] = useState(false); + const [isCopying, setIsCopying] = useState(false); + const createMutation = useCreateAgentApiKeyMutation(); + const copyKey = useCopyToClipboard({ text: newKey || '' }); + + const handleCreate = async () => { + if (!name.trim()) { + showToast({ message: localize('com_ui_api_key_name_required'), status: 'error' }); + return; + } + + try { + const result = await createMutation.mutateAsync({ name: name.trim() }); + setNewKey(result.key); + showToast({ message: localize('com_ui_api_key_created'), status: 'success' }); + onKeyCreated?.(); + } catch { + showToast({ message: localize('com_ui_api_key_create_error'), status: 'error' }); + } + }; + + const handleClose = () => { + setName(''); + setNewKey(null); + setShowKey(false); + setOpen(false); + }; + + const handleCopy = () => { + if (isCopying) { + return; + } + copyKey(setIsCopying); + showToast({ message: localize('com_ui_api_key_copied'), status: 'success' }); + }; + + return ( + + + + + + {localize('com_ui_create_api_key')} +
+ {!newKey ? ( + <> +
+ + setName(e.target.value)} + placeholder={localize('com_ui_api_key_name_placeholder')} + /> +
+
+ + + + +
+ + ) : ( + <> +
+

+ {localize('com_ui_api_key_warning')} +

+
+
+ +
+ + + +
+
+
+ +
+ + )} +
+
+
+ ); +} + +function KeyItem({ + id, + name, + keyPrefix, + createdAt, + lastUsedAt, +}: { + id: string; + name: string; + keyPrefix: string; + createdAt: string; + lastUsedAt?: string; +}) { + const localize = useLocalize(); + const { showToast } = useToastContext(); + const [confirmDelete, setConfirmDelete] = useState(false); + const deleteMutation = useDeleteAgentApiKeyMutation(); + + const handleDelete = async () => { + try { + await deleteMutation.mutateAsync(id); + showToast({ message: localize('com_ui_api_key_deleted'), status: 'success' }); + } catch { + showToast({ message: localize('com_ui_api_key_delete_error'), status: 'error' }); + } + setConfirmDelete(false); + }; + + const formatDate = (dateStr: string) => { + return new Date(dateStr).toLocaleDateString(undefined, { + year: 'numeric', + month: 'short', + day: 'numeric', + }); + }; + + return ( +
+
+ +
+
{name}
+
+ {keyPrefix}... + • + + {localize('com_ui_created')} {formatDate(createdAt)} + + {lastUsedAt && ( + <> + • + + {localize('com_ui_last_used')} {formatDate(lastUsedAt)} + + + )} +
+
+
+
+ {confirmDelete ? ( +
+ + +
+ ) : ( + + )} +
+
+ ); +} + +function ApiKeysContent({ isOpen }: { isOpen: boolean }) { + const localize = useLocalize(); + const { data, isLoading, error } = useGetAgentApiKeysQuery({ enabled: isOpen }); + + if (error) { + return
{localize('com_ui_api_keys_load_error')}
; + } + + return ( +
+
+ + +
+ +
+ {isLoading && ( +
+ +
+ )} + {!isLoading && + data?.keys && + data.keys.length > 0 && + data.keys.map((key) => ( + + ))} + {!isLoading && (!data?.keys || data.keys.length === 0) && ( +
+ +

{localize('com_ui_no_api_keys')}

+
+ )} +
+
+ ); +} + +const remoteAgentsPermissions: PermissionConfig[] = [ + { permission: Permissions.USE, labelKey: 'com_ui_remote_agents_allow_use' }, + { permission: Permissions.CREATE, labelKey: 'com_ui_remote_agents_allow_create' }, + { permission: Permissions.SHARE, labelKey: 'com_ui_remote_agents_allow_share' }, + { permission: Permissions.SHARE_PUBLIC, labelKey: 'com_ui_remote_agents_allow_share_public' }, +]; + +function RemoteAgentsAdminSettings() { + const localize = useLocalize(); + const { showToast } = useToastContext(); + + const mutation = useUpdateRemoteAgentsPermissionsMutation({ + onSuccess: () => { + showToast({ status: 'success', message: localize('com_ui_saved') }); + }, + onError: () => { + showToast({ status: 'error', message: localize('com_ui_error_save_admin_settings') }); + }, + }); + + const trigger = ( + + ); + + return ( + + ); +} + +export function AgentApiKeys() { + const localize = useLocalize(); + const [isOpen, setIsOpen] = useState(false); + + return ( +
+ + + + + + + + + + {localize('com_ui_agent_api_keys')} +

+ {localize('com_ui_agent_api_keys_description')} +

+
+ +
+
+
+ ); +} diff --git a/client/src/components/Nav/SettingsTabs/Data/Data.tsx b/client/src/components/Nav/SettingsTabs/Data/Data.tsx index 0bba5a152e17..eb8cea98c294 100644 --- a/client/src/components/Nav/SettingsTabs/Data/Data.tsx +++ b/client/src/components/Nav/SettingsTabs/Data/Data.tsx @@ -1,15 +1,22 @@ import React, { useState, useRef } from 'react'; import { useOnClickOutside } from '@librechat/client'; +import { Permissions, PermissionTypes } from 'librechat-data-provider'; import ImportConversations from './ImportConversations'; -import { RevokeKeys } from './RevokeKeys'; +import { AgentApiKeys } from './AgentApiKeys'; import { DeleteCache } from './DeleteCache'; +import { RevokeKeys } from './RevokeKeys'; import { ClearChats } from './ClearChats'; import SharedLinks from './SharedLinks'; +import { useHasAccess } from '~/hooks'; function Data() { const dataTabRef = useRef(null); const [confirmClearConvos, setConfirmClearConvos] = useState(false); useOnClickOutside(dataTabRef, () => confirmClearConvos && setConfirmClearConvos(false), []); + const hasAccessToApiKeys = useHasAccess({ + permissionType: PermissionTypes.REMOTE_AGENTS, + permission: Permissions.USE, + }); return (
@@ -19,6 +26,11 @@ function Data() {
+ {hasAccessToApiKeys && ( +
+ +
+ )}
diff --git a/client/src/components/SidePanel/Agents/AgentFooter.tsx b/client/src/components/SidePanel/Agents/AgentFooter.tsx index 80a449bb2d77..b2fa99659626 100644 --- a/client/src/components/SidePanel/Agents/AgentFooter.tsx +++ b/client/src/components/SidePanel/Agents/AgentFooter.tsx @@ -1,3 +1,4 @@ +import { Globe } from 'lucide-react'; import { Spinner } from '@librechat/client'; import { useWatch, useFormContext } from 'react-hook-form'; import { @@ -44,13 +45,20 @@ export default function AgentFooter({ permissionType: PermissionTypes.AGENTS, permission: Permissions.SHARE, }); + const hasAccessToShareRemoteAgents = useHasAccess({ + permissionType: PermissionTypes.REMOTE_AGENTS, + permission: Permissions.SHARE, + }); const { hasPermission, isLoading: permissionsLoading } = useResourcePermissions( ResourceType.AGENT, agent?._id || '', ); + const { hasPermission: hasRemoteAgentPermission, isLoading: remotePermissionsLoading } = + useResourcePermissions(ResourceType.REMOTE_AGENT, agent?._id || ''); const canShareThisAgent = hasPermission(PermissionBits.SHARE); const canDeleteThisAgent = hasPermission(PermissionBits.DELETE); + const canShareRemoteAgent = hasRemoteAgentPermission(PermissionBits.SHARE); const isSaving = createMutation.isLoading || updateMutation.isLoading || isAvatarUploading; const renderSaveButton = () => { if (isSaving) { @@ -91,6 +99,25 @@ export default function AgentFooter({ resourceType={ResourceType.AGENT} /> )} + {(agent?.author === user?.id || user?.role === SystemRoles.ADMIN || canShareRemoteAgent) && + hasAccessToShareRemoteAgents && + !remotePermissionsLoading && + agent?._id && ( + + + + )} {agent && agent.author === user?.id && } {/* Submit Button */}
+ {/* Defer All toggle - icon only with tooltip */} + {deferredToolsEnabled && ( + { + e.stopPropagation(); + toggleDeferAll(); + }} + onKeyDown={(e) => { + if (e.key === 'Enter' || e.key === ' ') { + e.preventDefault(); + e.stopPropagation(); + toggleDeferAll(); + } + }} + > + + + )} +
{/* Caret button for accordion */} @@ -230,52 +352,97 @@ export default function MCPTool({ serverInfo }: { serverInfo?: MCPServerInfo })
- {serverInfo.tools?.map((subTool) => ( -