diff --git a/api/src/utils/async-handler.test.ts b/api/src/utils/async-handler.test.ts new file mode 100644 index 0000000000..403534f4f6 --- /dev/null +++ b/api/src/utils/async-handler.test.ts @@ -0,0 +1,19 @@ +import type { RequestHandler, NextFunction, Request, Response } from 'express'; +import '../../src/types/express.d.ts'; +import asyncHandler from './async-handler'; + +let mockRequest: Partial; +let mockResponse: Partial; +const nextFunction: NextFunction = jest.fn(); + +test('Wraps async middleware in Promise resolve that will catch rejects and pass them to the nextFn', async () => { + const err = new Error('testing'); + + const middleware: RequestHandler = async (req, res, next) => { + throw err; + }; + + await asyncHandler(middleware)(mockRequest as Request, mockResponse as Response, nextFunction as NextFunction); + + expect(nextFunction).toHaveBeenCalledWith(err); +}); diff --git a/api/src/utils/async-handler.ts b/api/src/utils/async-handler.ts index e50cb7b6d8..3ae29adf1e 100644 --- a/api/src/utils/async-handler.ts +++ b/api/src/utils/async-handler.ts @@ -1,22 +1,6 @@ -import { ErrorRequestHandler, RequestHandler } from 'express'; +import type { RequestHandler, Request, Response, NextFunction } from 'express'; -/** - * Handles promises in routes. - */ -function asyncHandler(handler: RequestHandler): RequestHandler; -function asyncHandler(handler: ErrorRequestHandler): ErrorRequestHandler; -function asyncHandler(handler: RequestHandler | ErrorRequestHandler): RequestHandler | ErrorRequestHandler { - if (handler.length === 2 || handler.length === 3) { - const scoped: RequestHandler = (req, res, next) => - Promise.resolve((handler as RequestHandler)(req, res, next)).catch(next); - return scoped; - } else if (handler.length === 4) { - const scoped: ErrorRequestHandler = (err, req, res, next) => - Promise.resolve((handler as ErrorRequestHandler)(err, req, res, next)).catch(next); - return scoped; - } else { - throw new Error(`Failed to asyncHandle() function "${handler.name}"`); - } -} +const asyncHandler = (fn: RequestHandler) => (req: Request, res: Response, next: NextFunction) => + Promise.resolve(fn(req, res, next)).catch(next); export default asyncHandler;