mirror of
https://github.com/directus/directus.git
synced 2026-04-25 03:00:53 -04:00
Integrating Websockets in Directus 🕸️🧦 (#14737)
* added emitter context * partial items tests * updated items handler tests * fixed test after merge * forgot the event context * fixed auth message parsing for graphql subscriptions * fixed type strictness * fixed graphql subscription bug * bumped websocket dependencies * touched up some dangling code * updated itemsservice usage * disabled overkill logs * double checked environment type processing * fixed missed capitalization * fixed subscription payloads * Added explicit string type casting * removed obsolete "trimUpper" utility * using the parseJSON utility consistently * pinned dependencies * parse environment variables * fixed pnpm-lock * GraphQL Subscriptions for all events * fixed typo * added event data to the graphql definition * fix payload for delete events * Added optional chaining for type to prevent fatal crashes on invalid messages * fix failing on getting type from undefined * Update api/src/websocket/exceptions.ts Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com> * Add proper ZodError handling * added the zod-validation-error parser * allow disabling the rate limiter * Update api/src/websocket/controllers/base.ts Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com> * updated starting logs * fixed email/password expiration logic * added tests for getMessageType * simplified message parsing and dropped capitalization * updated authenticate test * switched to lower cased message.type to prevent spreading "toUpperCase" around * cleaned up debug logs * cast enabled config to boolean * Update api/src/websocket/controllers/rest.ts Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com> * Update api/src/websocket/handlers/subscribe.ts Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com> * Update api/src/websocket/handlers/subscribe.ts Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com> * Update api/src/websocket/handlers/items.ts Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com> * Update api/src/websocket/controllers/base.ts Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com> * Update api/src/websocket/handlers/heartbeat.ts Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com> * Suggested fixes by Azri * removed redundant try-catch * fixed authentication timeout added returning the refresh token when authenticating * updated pnpm lock after merge * Fixed authentication modes for GraphQL according to best practices * implement useFakeTimers in heartbeat unit test * implement useFakeTimers in items unit test * Update api/src/services/server.ts Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com> * removed obsolete authentication.verbose toggle * added email flag to message validation * switched to ternary for consistency * moved getSchema out of for loop * added singleton logic to items handler * close the socket after failed auth for non-public connections * disabled system collections for rest subscriptions * re-ran pnpm i * allow for multiple subscripitions in the memory messenger * - fixed system collection subscriptions - abstracted hook message bus - fixed graphql horizontal scaling * remove logic from root context for tests * fix reading created item * fix linter * typo and extra safe guard suggested by azri * prevent setting long timeouts in favor of a shared interval * prevent unsubscribing all existing subscriptions when omitting "uid" * - extracted getService utility - block system collections mutation in the items handler - implemented the correct services for system collections * allow numeric uid's to be used * fixed the types for numeric uid's to be used * added missing await's * fixed type imports after merge * removed unused imports * Update api/src/websocket/controllers/hooks.ts Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com> * Update api/src/websocket/controllers/hooks.ts Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com> * Update api/src/messenger.ts Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com> * improved error for graphql subscriptions * fixed TS Modernization conflicts * fixed TS Modernization conflicts * fixed conflicts after merge * removed unused name property * abstracxted environment configuration * respond to ping messages when heartbeat disabled * something something merge * moved toBoolean to it's own util file * replaced old socket naming * removed old exception * fixed typo * Update api/src/env.ts Co-authored-by: ian <licitdev@gmail.com> * Update api/src/websocket/handlers/heartbeat.test.ts Co-authored-by: ian <licitdev@gmail.com> * Update api/src/websocket/handlers/heartbeat.ts Co-authored-by: ian <licitdev@gmail.com> * Update api/src/services/server.ts Co-authored-by: ian <licitdev@gmail.com> * fixed for linter * add server_info_websocket in graphql * Add base REST websocket tests * do merge things * fixing things * fixed failing unit test * Update dependencies * Move tests * Update lockfile * Use new paths when spawning * return websockets to opt-in * Enable websockets for tests * Test with ephemeral access token * no camelcasing gql subscriptions * use underscore for gql event * Remove unused import * Add base GraphQL subscription tests * Fix accidental comment * Add some relational tests * Organize imports Using VS Code's default organize import * Run ESlint formatting * One more opinionated formatting change * Formatting * Fix message sequence not in order * Remove relational batch update tests * Test horizontal scaling * using toboolean util for server_info * removed unneeded type cast * found the gql request type * extra usage of the toBoolean util * merge the authentication middleware and get-account-for-token util * updated utility test * fixed middleware unit test * Add return * Remove user filtering and close conns * Fix reused accountability * fixed failing util test * added subscription unit tests * added missing mock * trigger workflow * Revert "trigger workflow" This reverts commit4f544b0c1b. * Trigger testing for all vendors * add unsubscription confirmation * Wait for unsubscription confirmation * Fix incorrect sending of unsubscription confirmation * updated ubsubscribe logic * Update count for unsubscription message * Fix sequence for UUID pktype in MSSQL * Increase auth timeout * Add start index when getting messages * Fix subscription retrieval and cast uid to string * Remove nested ternary * Revert "Increase auth timeout" This reverts commit10707409c4. * Terminate connection instead of close * fixed merge * re-added missing packages after merge resolve * fixed type imports * Create lazy-cows-happen.md Added changeset * Minor bump for "directus" package as well * fixed "strict" auth mode for graphql subscriptions * removed nested ternary * Add websocket tests to sequential flow * Disable pressure limiter for blackbox tests * fix merge * WebSockets Documentation (#18254) * Small repsonsive styling fix on Card * REST getting started guide * Authentication guide * REST subscription guides * JS Chat guide * Sidebar websocket guides section * Added config options * Respoinding to brainslug's review * Fixed incorrect header on guides/rt/subs * Fixed spellchecker * Correct full code example on guides/rt/chat/js * Fixed JS chat tut * Order of steps in js chat guide updated for easier following-along * Realtime chat Vue Guide * feat: create react.js file * feat: add set up for directus project * docs: create react boilder plate * docs: initialize connection * docs: set up submission methods * docs: establish websocket connection * docs: subscribe to messages * docs: create new messages * docs: display new messages * docs: display historical messages * docs: next steps * docs: full code sample * docs: clean up * docs: add name to contributors * docs: add react card * docs: updates to react chat * Added live poll result guide * docs: intro * docs: before you begin * docs: install packages * docs: authenticate connection * docs: query and mutation * docs: utilize hooks * docs: subscribe to changes * docs: create helper functions * docs: display messages * docs: summary * docs: full sample code * chore: add card for webscockets with graphql * docs: intro * docs: subscribe to changes * docs: handling changes * docs: crud operations * docs: unsubscribing from changes * docs: updates * chore: add card * chore: updates to graphql docs * chore: updates to getting started * chore: updates to subscription * chore: updates to real chat guide * Added WebSockets Operations Guide * Consistent titles * Contributors component for docs * Triggering Netlify * Add operations to sidebar * Fix operations link * Small formatting changes * Clarity around property values * Removed unused values in Contributors component * Prompt for default choice * Tabs & lowercase doctypes * Semicolons * Event overwerites -> event listeners * Spacing * Flipped order of websockets guide to match GQL --------- Co-authored-by: Esther Agbaje <folasadeagbaje@gmail.com> Co-authored-by: Rijk van Zanten <rijkvanzanten@me.com> * fixed typo * removed unused import * added tests for "to-boolean" and "exceptions" * added websocket service tests * quote environment variable to satisfy dictionary * GraphQL Subscriptions update (#18804) * updated graphql subscription structure * updated graphql examples * Create hungry-points-rescue.md * using `key` instead of `ID` on the toplevel * removed changeset * fixed the graphql type after the rename to `key` * retrun data null for delete events to prevent non-nullable gql error * updated missed ID reference in the docs * updated missed ID reference in the docs * renamed "payload" to "data" in the REST Subscription response * fixed missed reference to payload * added optional event filter for REST subscriptions * updated docs for event filter * Update docs/guides/real-time/subscriptions/websockets.md Co-authored-by: ian <licitdev@gmail.com> --------- Co-authored-by: ian <licitdev@gmail.com> * added messenger unit test * always send subscription confirmation * Add event to subscription options * Update tests * Add tests for event filtering * Revert testing for all vendors * Remove obsolete console comment * Update comment * Correct event in JS WS guide * Fix collection name to match name used in subscription * Fix collection name in other guides * Fix diffs in doc & enhance chart example * Complete sentence in GraphQL guide * Small update to config description --------- Co-authored-by: Rijk van Zanten <rijkvanzanten@me.com> Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com> Co-authored-by: ian <licitdev@gmail.com> Co-authored-by: Nitwel <mail@nitwel.de> Co-authored-by: Kevin Lewis <kvn@lws.io> Co-authored-by: Pascal Jufer <pascal-jufer@bluewin.ch> Co-authored-by: Esther Agbaje <folasadeagbaje@gmail.com>
This commit is contained in:
@@ -112,6 +112,7 @@
|
||||
"fs-extra": "11.1.1",
|
||||
"graphql": "16.6.0",
|
||||
"graphql-compose": "9.0.10",
|
||||
"graphql-ws": "5.12.0",
|
||||
"helmet": "7.0.0",
|
||||
"icc": "3.0.0",
|
||||
"inquirer": "9.2.4",
|
||||
@@ -158,7 +159,10 @@
|
||||
"uuid": "9.0.0",
|
||||
"uuid-validate": "0.0.3",
|
||||
"vm2": "3.9.19",
|
||||
"wellknown": "0.5.0"
|
||||
"wellknown": "0.5.0",
|
||||
"ws": "8.12.1",
|
||||
"zod": "3.21.4",
|
||||
"zod-validation-error": "1.0.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@directus/tsconfig": "workspace:*",
|
||||
@@ -199,6 +203,7 @@
|
||||
"@types/uuid": "9.0.1",
|
||||
"@types/uuid-validate": "0.0.1",
|
||||
"@types/wellknown": "0.5.4",
|
||||
"@types/ws": "8.5.4",
|
||||
"@vitest/coverage-c8": "0.31.1",
|
||||
"copyfiles": "2.4.1",
|
||||
"form-data": "4.0.0",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { ActionHandler, EventContext, FilterHandler, InitHandler } from '@directus/types';
|
||||
import ee2 from 'eventemitter2';
|
||||
import logger from './logger.js';
|
||||
import getDatabase from './database/index.js';
|
||||
|
||||
export class Emitter {
|
||||
private filterEmitter;
|
||||
@@ -22,11 +23,19 @@ export class Emitter {
|
||||
this.initEmitter = new ee2.EventEmitter2(emitterOptions);
|
||||
}
|
||||
|
||||
private getDefaultContext(): EventContext {
|
||||
return {
|
||||
database: getDatabase(),
|
||||
accountability: null,
|
||||
schema: null,
|
||||
};
|
||||
}
|
||||
|
||||
public async emitFilter<T>(
|
||||
event: string | string[],
|
||||
payload: T,
|
||||
meta: Record<string, any>,
|
||||
context: EventContext
|
||||
context: EventContext | null = null
|
||||
): Promise<T> {
|
||||
const events = Array.isArray(event) ? event : [event];
|
||||
|
||||
@@ -39,7 +48,7 @@ export class Emitter {
|
||||
|
||||
for (const { event, listeners } of eventListeners) {
|
||||
for (const listener of listeners) {
|
||||
const result = await listener(updatedPayload, { event, ...meta }, context);
|
||||
const result = await listener(updatedPayload, { event, ...meta }, context ?? this.getDefaultContext());
|
||||
|
||||
if (result !== undefined) {
|
||||
updatedPayload = result;
|
||||
@@ -50,11 +59,11 @@ export class Emitter {
|
||||
return updatedPayload;
|
||||
}
|
||||
|
||||
public emitAction(event: string | string[], meta: Record<string, any>, context: EventContext): void {
|
||||
public emitAction(event: string | string[], meta: Record<string, any>, context: EventContext | null = null): void {
|
||||
const events = Array.isArray(event) ? event : [event];
|
||||
|
||||
for (const event of events) {
|
||||
this.actionEmitter.emitAsync(event, { event, ...meta }, context).catch((err) => {
|
||||
this.actionEmitter.emitAsync(event, { event, ...meta }, context ?? this.getDefaultContext()).catch((err) => {
|
||||
logger.warn(`An error was thrown while executing action "${event}"`);
|
||||
logger.warn(err);
|
||||
});
|
||||
|
||||
@@ -7,10 +7,10 @@ import { parseJSON, toArray } from '@directus/utils';
|
||||
import dotenv from 'dotenv';
|
||||
import fs from 'fs';
|
||||
import { clone, toNumber, toString } from 'lodash-es';
|
||||
import { createRequire } from 'node:module';
|
||||
import path from 'path';
|
||||
import { requireYAML } from './utils/require-yaml.js';
|
||||
|
||||
import { createRequire } from 'node:module';
|
||||
import { toBoolean } from './utils/to-boolean.js';
|
||||
|
||||
const require = createRequire(import.meta.url);
|
||||
|
||||
@@ -206,6 +206,8 @@ const allowedEnvironmentVars = [
|
||||
// flows
|
||||
'FLOWS_EXEC_ALLOWED_MODULES',
|
||||
'FLOWS_ENV_ALLOW_LIST',
|
||||
// websockets
|
||||
'WEBSOCKETS_.+',
|
||||
].map((name) => new RegExp(`^${name}$`));
|
||||
|
||||
const acceptedEnvTypes = ['string', 'number', 'regex', 'array', 'json'];
|
||||
@@ -306,6 +308,18 @@ const defaults: Record<string, any> = {
|
||||
|
||||
GRAPHQL_INTROSPECTION: true,
|
||||
|
||||
WEBSOCKETS_ENABLED: false,
|
||||
WEBSOCKETS_REST_ENABLED: true,
|
||||
WEBSOCKETS_REST_AUTH: 'handshake',
|
||||
WEBSOCKETS_REST_AUTH_TIMEOUT: 10,
|
||||
WEBSOCKETS_REST_PATH: '/websocket',
|
||||
WEBSOCKETS_GRAPHQL_ENABLED: true,
|
||||
WEBSOCKETS_GRAPHQL_AUTH: 'handshake',
|
||||
WEBSOCKETS_GRAPHQL_AUTH_TIMEOUT: 10,
|
||||
WEBSOCKETS_GRAPHQL_PATH: '/graphql',
|
||||
WEBSOCKETS_HEARTBEAT_ENABLED: true,
|
||||
WEBSOCKETS_HEARTBEAT_PERIOD: 30,
|
||||
|
||||
FLOWS_EXEC_ALLOWED_MODULES: false,
|
||||
FLOWS_ENV_ALLOW_LIST: false,
|
||||
|
||||
@@ -571,7 +585,3 @@ function tryJSON(value: any) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
function toBoolean(value: any): boolean {
|
||||
return value === 'true' || value === true || value === '1' || value === 1;
|
||||
}
|
||||
|
||||
82
api/src/messenger.test.ts
Normal file
82
api/src/messenger.test.ts
Normal file
@@ -0,0 +1,82 @@
|
||||
import { describe, expect, test, vi, beforeEach } from 'vitest';
|
||||
import { getEnv } from './env.js';
|
||||
import { MessengerMemory, MessengerRedis } from './messenger.js';
|
||||
|
||||
vi.mock('./env');
|
||||
vi.mock('ioredis');
|
||||
|
||||
async function dynamicMessenger(mockedEnv: Record<string, any>) {
|
||||
vi.mocked(getEnv).mockReturnValue(mockedEnv);
|
||||
|
||||
return await import('./messenger.js');
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetModules();
|
||||
});
|
||||
|
||||
describe('MessengerMemory', () => {
|
||||
test('getMessenger', async () => {
|
||||
const { MessengerMemory, getMessenger } = await dynamicMessenger({
|
||||
MESSENGER_STORE: 'memory',
|
||||
});
|
||||
|
||||
const messenger = getMessenger();
|
||||
|
||||
expect(messenger).toBeInstanceOf(MessengerMemory);
|
||||
});
|
||||
|
||||
test('subscribing', () => {
|
||||
const messages: Record<string, string>[] = [];
|
||||
const testMessage = { test: 'test' };
|
||||
const messenger = new MessengerMemory();
|
||||
|
||||
messenger.subscribe('test', (data: Record<string, string>) => {
|
||||
messages.push(data);
|
||||
});
|
||||
|
||||
messenger.publish('test', testMessage);
|
||||
|
||||
expect(messenger.handlers['test']?.size ?? 0).toBe(1);
|
||||
expect(messages.length).toBe(1);
|
||||
expect(messages).toStrictEqual([testMessage]);
|
||||
|
||||
messenger.unsubscribe('test');
|
||||
|
||||
messenger.publish('test', testMessage);
|
||||
|
||||
expect(messenger.handlers['test']?.size ?? 0).toBe(0);
|
||||
expect(messages.length).toBe(1);
|
||||
expect(messages).toStrictEqual([testMessage]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('MessengerRedis', () => {
|
||||
test('getMessenger', async () => {
|
||||
const { MessengerRedis, getMessenger } = await dynamicMessenger({
|
||||
MESSENGER_STORE: 'redis',
|
||||
});
|
||||
|
||||
const messenger = getMessenger();
|
||||
|
||||
expect(messenger).toBeInstanceOf(MessengerRedis);
|
||||
});
|
||||
|
||||
test('subscribing', () => {
|
||||
const testMessage = { test: 'test' };
|
||||
const messenger = new MessengerRedis();
|
||||
|
||||
messenger.subscribe('test', (_data: Record<string, string>) => {
|
||||
// do nothing
|
||||
});
|
||||
|
||||
expect(messenger.sub.subscribe).toBeCalled();
|
||||
expect(messenger.sub.on).toBeCalled();
|
||||
|
||||
messenger.publish('test', testMessage);
|
||||
expect(messenger.pub.publish).toBeCalled();
|
||||
|
||||
messenger.unsubscribe('test');
|
||||
expect(messenger.sub.unsubscribe).toBeCalled();
|
||||
});
|
||||
});
|
||||
@@ -1,6 +1,6 @@
|
||||
import { parseJSON } from '@directus/utils';
|
||||
import { Redis } from 'ioredis';
|
||||
import env from './env.js';
|
||||
import { getEnv } from './env.js';
|
||||
import { getConfigFromEnv } from './utils/get-config-from-env.js';
|
||||
|
||||
export type MessengerSubscriptionCallback = (payload: Record<string, any>) => void;
|
||||
@@ -8,26 +8,31 @@ export type MessengerSubscriptionCallback = (payload: Record<string, any>) => vo
|
||||
export interface Messenger {
|
||||
publish: (channel: string, payload: Record<string, any>) => void;
|
||||
subscribe: (channel: string, callback: MessengerSubscriptionCallback) => void;
|
||||
unsubscribe: (channel: string) => void;
|
||||
unsubscribe: (channel: string, callback?: MessengerSubscriptionCallback) => void;
|
||||
}
|
||||
|
||||
export class MessengerMemory implements Messenger {
|
||||
handlers: Record<string, MessengerSubscriptionCallback>;
|
||||
handlers: Record<string, Set<MessengerSubscriptionCallback>>;
|
||||
|
||||
constructor() {
|
||||
this.handlers = {};
|
||||
}
|
||||
|
||||
publish(channel: string, payload: Record<string, any>) {
|
||||
this.handlers[channel]?.(payload);
|
||||
this.handlers[channel]?.forEach((callback) => callback(payload));
|
||||
}
|
||||
|
||||
subscribe(channel: string, callback: MessengerSubscriptionCallback) {
|
||||
this.handlers[channel] = callback;
|
||||
if (!this.handlers[channel]) this.handlers[channel] = new Set();
|
||||
this.handlers[channel]?.add(callback);
|
||||
}
|
||||
|
||||
unsubscribe(channel: string) {
|
||||
delete this.handlers[channel];
|
||||
unsubscribe(channel: string, callback?: MessengerSubscriptionCallback) {
|
||||
if (!callback) {
|
||||
delete this.handlers[channel];
|
||||
} else {
|
||||
this.handlers[channel]?.delete(callback);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,7 +43,7 @@ export class MessengerRedis implements Messenger {
|
||||
|
||||
constructor() {
|
||||
const config = getConfigFromEnv('MESSENGER_REDIS');
|
||||
|
||||
const env = getEnv();
|
||||
this.pub = new Redis(env['MESSENGER_REDIS'] ?? config);
|
||||
this.sub = new Redis(env['MESSENGER_REDIS'] ?? config);
|
||||
this.namespace = env['MESSENGER_NAMESPACE'] ?? 'directus';
|
||||
@@ -69,6 +74,7 @@ let messenger: Messenger;
|
||||
|
||||
export function getMessenger() {
|
||||
if (messenger) return messenger;
|
||||
const env = getEnv();
|
||||
|
||||
if (env['MESSENGER_STORE'] === 'redis') {
|
||||
messenger = new MessengerRedis();
|
||||
|
||||
@@ -2,18 +2,19 @@ import type { Request, Response } from 'express';
|
||||
import jwt from 'jsonwebtoken';
|
||||
import type { Knex } from 'knex';
|
||||
import { afterEach, expect, test, vi } from 'vitest';
|
||||
import '../../src/types/express.d.ts';
|
||||
import '../types/express.d.ts';
|
||||
import getDatabase from '../database/index.js';
|
||||
import emitter from '../emitter.js';
|
||||
import env from '../env.js';
|
||||
import { InvalidCredentialsException } from '../exceptions/invalid-credentials.js';
|
||||
import { handler } from './authenticate.js';
|
||||
|
||||
vi.mock('../../src/database');
|
||||
vi.mock('../database/index');
|
||||
|
||||
vi.mock('../../src/env', () => {
|
||||
vi.mock('../env', () => {
|
||||
const MOCK_ENV = {
|
||||
SECRET: 'test',
|
||||
EXTENSIONS_PATH: './extensions',
|
||||
};
|
||||
|
||||
return {
|
||||
|
||||
@@ -3,12 +3,9 @@ import type { NextFunction, Request, Response } from 'express';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import getDatabase from '../database/index.js';
|
||||
import emitter from '../emitter.js';
|
||||
import env from '../env.js';
|
||||
import { InvalidCredentialsException } from '../exceptions/index.js';
|
||||
import asyncHandler from '../utils/async-handler.js';
|
||||
import { getIPFromReq } from '../utils/get-ip-from-req.js';
|
||||
import isDirectusJWT from '../utils/is-directus-jwt.js';
|
||||
import { verifyAccessJWT } from '../utils/jwt.js';
|
||||
import { getAccountabilityForToken } from '../utils/get-accountability-for-token.js';
|
||||
|
||||
/**
|
||||
* Verify the passed JWT and assign the user ID and role to `req`
|
||||
@@ -48,41 +45,7 @@ export const handler = async (req: Request, _res: Response, next: NextFunction)
|
||||
return next();
|
||||
}
|
||||
|
||||
req.accountability = defaultAccountability;
|
||||
|
||||
if (req.token) {
|
||||
if (isDirectusJWT(req.token)) {
|
||||
const payload = verifyAccessJWT(req.token, env['SECRET']);
|
||||
|
||||
req.accountability.role = payload.role;
|
||||
req.accountability.admin = payload.admin_access === true || payload.admin_access == 1;
|
||||
req.accountability.app = payload.app_access === true || payload.app_access == 1;
|
||||
|
||||
if (payload.share) req.accountability.share = payload.share;
|
||||
if (payload.share_scope) req.accountability.share_scope = payload.share_scope;
|
||||
if (payload.id) req.accountability.user = payload.id;
|
||||
} else {
|
||||
// Try finding the user with the provided token
|
||||
const user = await database
|
||||
.select('directus_users.id', 'directus_users.role', 'directus_roles.admin_access', 'directus_roles.app_access')
|
||||
.from('directus_users')
|
||||
.leftJoin('directus_roles', 'directus_users.role', 'directus_roles.id')
|
||||
.where({
|
||||
'directus_users.token': req.token,
|
||||
status: 'active',
|
||||
})
|
||||
.first();
|
||||
|
||||
if (!user) {
|
||||
throw new InvalidCredentialsException();
|
||||
}
|
||||
|
||||
req.accountability.user = user.id;
|
||||
req.accountability.role = user.role;
|
||||
req.accountability.admin = user.admin_access === true || user.admin_access == 1;
|
||||
req.accountability.app = user.app_access === true || user.app_access == 1;
|
||||
}
|
||||
}
|
||||
req.accountability = await getAccountabilityForToken(req.token, defaultAccountability);
|
||||
|
||||
return next();
|
||||
};
|
||||
|
||||
@@ -12,6 +12,14 @@ import emitter from './emitter.js';
|
||||
import env from './env.js';
|
||||
import logger from './logger.js';
|
||||
import { getConfigFromEnv } from './utils/get-config-from-env.js';
|
||||
import {
|
||||
createSubscriptionController,
|
||||
createWebSocketController,
|
||||
getSubscriptionController,
|
||||
getWebSocketController,
|
||||
} from './websocket/controllers/index.js';
|
||||
import { startWebSocketHandlers } from './websocket/handlers/index.js';
|
||||
import { toBoolean } from './utils/to-boolean.js';
|
||||
|
||||
export let SERVER_ONLINE = true;
|
||||
|
||||
@@ -82,6 +90,12 @@ export async function createServer(): Promise<http.Server> {
|
||||
res.once('close', complete.bind(null, false));
|
||||
});
|
||||
|
||||
if (toBoolean(env['WEBSOCKETS_ENABLED']) === true) {
|
||||
createSubscriptionController(server);
|
||||
createWebSocketController(server);
|
||||
startWebSocketHandlers();
|
||||
}
|
||||
|
||||
const terminusOptions: TerminusOptions = {
|
||||
timeout:
|
||||
env['SERVER_SHUTDOWN_TIMEOUT'] >= 0 && env['SERVER_SHUTDOWN_TIMEOUT'] < Infinity
|
||||
@@ -106,6 +120,8 @@ export async function createServer(): Promise<http.Server> {
|
||||
}
|
||||
|
||||
async function onSignal() {
|
||||
getSubscriptionController()?.terminate();
|
||||
getWebSocketController()?.terminate();
|
||||
const database = getDatabase();
|
||||
await database.destroy();
|
||||
|
||||
|
||||
@@ -62,24 +62,13 @@ import { AuthenticationService } from '../authentication.js';
|
||||
import { CollectionsService } from '../collections.js';
|
||||
import { FieldsService } from '../fields.js';
|
||||
import { FilesService } from '../files.js';
|
||||
import { FlowsService } from '../flows.js';
|
||||
import { FoldersService } from '../folders.js';
|
||||
import { ItemsService } from '../items.js';
|
||||
import { NotificationsService } from '../notifications.js';
|
||||
import { OperationsService } from '../operations.js';
|
||||
import { PermissionsService } from '../permissions.js';
|
||||
import { PresetsService } from '../presets.js';
|
||||
import { RelationsService } from '../relations.js';
|
||||
import { RevisionsService } from '../revisions.js';
|
||||
import { RolesService } from '../roles.js';
|
||||
import { ServerService } from '../server.js';
|
||||
import { SettingsService } from '../settings.js';
|
||||
import { SharesService } from '../shares.js';
|
||||
import { SpecificationService } from '../specifications.js';
|
||||
import { TFAService } from '../tfa.js';
|
||||
import { UsersService } from '../users.js';
|
||||
import { UtilsService } from '../utils.js';
|
||||
import { WebhooksService } from '../webhooks.js';
|
||||
import { GraphQLBigInt } from './types/bigint.js';
|
||||
import { GraphQLDate } from './types/date.js';
|
||||
import { GraphQLGeoJSON } from './types/geojson.js';
|
||||
@@ -88,6 +77,9 @@ import { GraphQLStringOrFloat } from './types/string-or-float.js';
|
||||
import { GraphQLVoid } from './types/void.js';
|
||||
import { addPathToValidationError } from './utils/add-path-to-validation-error.js';
|
||||
import processError from './utils/process-error.js';
|
||||
import { createSubscriptionGenerator } from './subscription.js';
|
||||
import { getService } from '../../utils/get-service.js';
|
||||
import { toBoolean } from '../../utils/to-boolean.js';
|
||||
|
||||
const validationRules = Array.from(specifiedRules);
|
||||
|
||||
@@ -198,6 +190,15 @@ export class GraphQLService {
|
||||
: reduceSchema(this.schema, this.accountability?.permissions || null, ['delete']),
|
||||
};
|
||||
|
||||
const subscriptionEventType = schemaComposer.createEnumTC({
|
||||
name: 'EventEnum',
|
||||
values: {
|
||||
create: { value: 'create' },
|
||||
update: { value: 'update' },
|
||||
delete: { value: 'delete' },
|
||||
},
|
||||
});
|
||||
|
||||
const { ReadCollectionTypes } = getReadableTypes();
|
||||
|
||||
const { CreateCollectionTypes, UpdateCollectionTypes, DeleteCollectionTypes } = getWritableTypes();
|
||||
@@ -1066,6 +1067,29 @@ export class GraphQLService {
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const eventName = `${collection.collection}_mutated`;
|
||||
|
||||
if (collection.collection in ReadCollectionTypes) {
|
||||
const subscriptionType = schemaComposer.createObjectTC({
|
||||
name: eventName,
|
||||
fields: {
|
||||
key: new GraphQLNonNull(GraphQLID),
|
||||
event: subscriptionEventType,
|
||||
data: ReadCollectionTypes[collection.collection]!,
|
||||
},
|
||||
});
|
||||
|
||||
schemaComposer.Subscription.addFields({
|
||||
[eventName]: {
|
||||
type: subscriptionType,
|
||||
args: {
|
||||
event: subscriptionEventType,
|
||||
},
|
||||
subscribe: createSubscriptionGenerator(self, eventName),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
for (const relation of schema.read.relations) {
|
||||
@@ -1411,7 +1435,12 @@ export class GraphQLService {
|
||||
return await this.upsertSingleton(collection, args['data'], query);
|
||||
}
|
||||
|
||||
const service = this.getService(collection);
|
||||
const service = getService(collection, {
|
||||
knex: this.knex,
|
||||
accountability: this.accountability,
|
||||
schema: this.schema,
|
||||
});
|
||||
|
||||
const hasQuery = (query.fields || []).length > 0;
|
||||
|
||||
try {
|
||||
@@ -1466,7 +1495,11 @@ export class GraphQLService {
|
||||
* Execute the read action on the correct service. Checks for singleton as well.
|
||||
*/
|
||||
async read(collection: string, query: Query): Promise<Partial<Item>> {
|
||||
const service = this.getService(collection);
|
||||
const service = getService(collection, {
|
||||
knex: this.knex,
|
||||
accountability: this.accountability,
|
||||
schema: this.schema,
|
||||
});
|
||||
|
||||
const result = this.schema.collections[collection]!.singleton
|
||||
? await service.readSingleton(query, { stripNonRequested: false })
|
||||
@@ -1483,7 +1516,11 @@ export class GraphQLService {
|
||||
body: Record<string, any> | Record<string, any>[],
|
||||
query: Query
|
||||
): Promise<Partial<Item> | boolean> {
|
||||
const service = this.getService(collection);
|
||||
const service = getService(collection, {
|
||||
knex: this.knex,
|
||||
accountability: this.accountability,
|
||||
schema: this.schema,
|
||||
});
|
||||
|
||||
try {
|
||||
await service.upsertSingleton(body);
|
||||
@@ -1737,51 +1774,6 @@ export class GraphQLService {
|
||||
return new GraphQLError(error.message, undefined, undefined, undefined, undefined, error);
|
||||
}
|
||||
|
||||
/**
|
||||
* Select the correct service for the given collection. This allows the individual services to run
|
||||
* their custom checks (f.e. it allows UsersService to prevent updating TFA secret from outside)
|
||||
*/
|
||||
getService(collection: string): ItemsService {
|
||||
const opts = {
|
||||
knex: this.knex,
|
||||
accountability: this.accountability,
|
||||
schema: this.schema,
|
||||
};
|
||||
|
||||
switch (collection) {
|
||||
case 'directus_activity':
|
||||
return new ActivityService(opts);
|
||||
case 'directus_files':
|
||||
return new FilesService(opts);
|
||||
case 'directus_folders':
|
||||
return new FoldersService(opts);
|
||||
case 'directus_permissions':
|
||||
return new PermissionsService(opts);
|
||||
case 'directus_presets':
|
||||
return new PresetsService(opts);
|
||||
case 'directus_notifications':
|
||||
return new NotificationsService(opts);
|
||||
case 'directus_revisions':
|
||||
return new RevisionsService(opts);
|
||||
case 'directus_roles':
|
||||
return new RolesService(opts);
|
||||
case 'directus_settings':
|
||||
return new SettingsService(opts);
|
||||
case 'directus_users':
|
||||
return new UsersService(opts);
|
||||
case 'directus_webhooks':
|
||||
return new WebhooksService(opts);
|
||||
case 'directus_shares':
|
||||
return new SharesService(opts);
|
||||
case 'directus_flows':
|
||||
return new FlowsService(opts);
|
||||
case 'directus_operations':
|
||||
return new OperationsService(opts);
|
||||
default:
|
||||
return new ItemsService(collection, opts);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Replace all fragments in a selectionset for the actual selection set as defined in the fragment
|
||||
* Effectively merges the selections with the fragments used in those selections
|
||||
@@ -1907,6 +1899,58 @@ export class GraphQLService {
|
||||
},
|
||||
}),
|
||||
},
|
||||
websocket: toBoolean(env['WEBSOCKETS_ENABLED'])
|
||||
? {
|
||||
type: new GraphQLObjectType({
|
||||
name: 'server_info_websocket',
|
||||
fields: {
|
||||
rest: {
|
||||
type: toBoolean(env['WEBSOCKETS_REST_ENABLED'])
|
||||
? new GraphQLObjectType({
|
||||
name: 'server_info_websocket_rest',
|
||||
fields: {
|
||||
authentication: {
|
||||
type: new GraphQLEnumType({
|
||||
name: 'server_info_websocket_rest_authentication',
|
||||
values: {
|
||||
public: { value: 'public' },
|
||||
handshake: { value: 'handshake' },
|
||||
strict: { value: 'strict' },
|
||||
},
|
||||
}),
|
||||
},
|
||||
path: { type: GraphQLString },
|
||||
},
|
||||
})
|
||||
: GraphQLBoolean,
|
||||
},
|
||||
graphql: {
|
||||
type: toBoolean(env['WEBSOCKETS_GRAPHQL_ENABLED'])
|
||||
? new GraphQLObjectType({
|
||||
name: 'server_info_websocket_graphql',
|
||||
fields: {
|
||||
authentication: {
|
||||
type: new GraphQLEnumType({
|
||||
name: 'server_info_websocket_graphql_authentication',
|
||||
values: {
|
||||
public: { value: 'public' },
|
||||
handshake: { value: 'handshake' },
|
||||
strict: { value: 'strict' },
|
||||
},
|
||||
}),
|
||||
},
|
||||
path: { type: GraphQLString },
|
||||
},
|
||||
})
|
||||
: GraphQLBoolean,
|
||||
},
|
||||
heartbeat: {
|
||||
type: toBoolean(env['WEBSOCKETS_HEARTBEAT_ENABLED']) ? GraphQLInt : GraphQLBoolean,
|
||||
},
|
||||
},
|
||||
}),
|
||||
}
|
||||
: GraphQLBoolean,
|
||||
queryLimit: {
|
||||
type: new GraphQLObjectType({
|
||||
name: 'server_info_query_limit',
|
||||
|
||||
103
api/src/services/graphql/subscription.ts
Normal file
103
api/src/services/graphql/subscription.ts
Normal file
@@ -0,0 +1,103 @@
|
||||
import { EventEmitter, on } from 'events';
|
||||
import { getMessenger } from '../../messenger.js';
|
||||
import type { GraphQLService } from './index.js';
|
||||
import { getSchema } from '../../utils/get-schema.js';
|
||||
import { ItemsService } from '../items.js';
|
||||
import type { Query } from '@directus/types';
|
||||
import type { GraphQLResolveInfo, SelectionNode } from 'graphql';
|
||||
|
||||
const messages = createPubSub(new EventEmitter());
|
||||
|
||||
export function bindPubSub() {
|
||||
const messenger = getMessenger();
|
||||
|
||||
messenger.subscribe('websocket.event', (message: Record<string, any>) => {
|
||||
messages.publish(`${message['collection']}_mutated`, message);
|
||||
});
|
||||
}
|
||||
|
||||
export function createSubscriptionGenerator(self: GraphQLService, event: string) {
|
||||
return async function* (_x: unknown, _y: unknown, _z: unknown, request: GraphQLResolveInfo) {
|
||||
const fields = parseFields(self, request);
|
||||
const args = parseArguments(request);
|
||||
|
||||
for await (const payload of messages.subscribe(event)) {
|
||||
const eventData = payload as Record<string, any>;
|
||||
|
||||
if ('event' in args && eventData['action'] !== args['event']) {
|
||||
continue; // skip filtered events
|
||||
}
|
||||
|
||||
const schema = await getSchema();
|
||||
|
||||
if (eventData['action'] === 'create') {
|
||||
const { collection, key } = eventData;
|
||||
const service = new ItemsService(collection, { schema });
|
||||
const data = await service.readOne(key, { fields } as Query);
|
||||
yield { [event]: { key, data, event: 'create' } };
|
||||
}
|
||||
|
||||
if (eventData['action'] === 'update') {
|
||||
const { collection, keys } = eventData;
|
||||
const service = new ItemsService(collection, { schema });
|
||||
|
||||
for (const key of keys) {
|
||||
const data = await service.readOne(key, { fields } as Query);
|
||||
yield { [event]: { key, data, event: 'update' } };
|
||||
}
|
||||
}
|
||||
|
||||
if (eventData['action'] === 'delete') {
|
||||
const { keys } = eventData;
|
||||
|
||||
for (const key of keys) {
|
||||
yield { [event]: { key, data: null, event: 'delete' } };
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
function createPubSub<P extends { [key: string]: unknown }>(emitter: EventEmitter) {
|
||||
return {
|
||||
publish: <T extends Extract<keyof P, string>>(event: T, payload: P[T]) =>
|
||||
void emitter.emit(event as string, payload),
|
||||
subscribe: async function* <T extends Extract<keyof P, string>>(event: T): AsyncIterableIterator<P[T]> {
|
||||
const asyncIterator = on(emitter, event);
|
||||
|
||||
for await (const [value] of asyncIterator) {
|
||||
yield value;
|
||||
}
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function parseFields(service: GraphQLService, request: GraphQLResolveInfo) {
|
||||
const selections = request.fieldNodes[0]?.selectionSet?.selections ?? [];
|
||||
|
||||
const dataSelections = selections.reduce((result: readonly SelectionNode[], selection: SelectionNode) => {
|
||||
if (
|
||||
selection.kind === 'Field' &&
|
||||
selection.name.value === 'data' &&
|
||||
selection.selectionSet?.kind === 'SelectionSet'
|
||||
) {
|
||||
return selection.selectionSet.selections;
|
||||
}
|
||||
|
||||
return result;
|
||||
}, []);
|
||||
|
||||
const { fields } = service.getQuery({}, dataSelections, request.variableValues);
|
||||
return fields ?? [];
|
||||
}
|
||||
|
||||
function parseArguments(request: GraphQLResolveInfo) {
|
||||
const args = request.fieldNodes[0]?.arguments ?? [];
|
||||
return args.reduce((result, current) => {
|
||||
if ('value' in current.value && typeof current.value.value === 'string') {
|
||||
result[current.name.value] = current.value.value;
|
||||
}
|
||||
|
||||
return result;
|
||||
}, {} as Record<string, string>);
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import { getStorage } from '../storage/index.js';
|
||||
import type { AbstractServiceOptions } from '../types/index.js';
|
||||
import { version } from '../utils/package.js';
|
||||
import { SettingsService } from './settings.js';
|
||||
import { toBoolean } from '../utils/to-boolean.js';
|
||||
|
||||
export class ServerService {
|
||||
knex: Knex;
|
||||
@@ -78,6 +79,32 @@ export class ServerService {
|
||||
};
|
||||
}
|
||||
|
||||
if (this.accountability?.user) {
|
||||
if (toBoolean(env['WEBSOCKETS_ENABLED'])) {
|
||||
info['websocket'] = {};
|
||||
|
||||
info['websocket'].rest = toBoolean(env['WEBSOCKETS_REST_ENABLED'])
|
||||
? {
|
||||
authentication: env['WEBSOCKETS_REST_AUTH'],
|
||||
path: env['WEBSOCKETS_REST_PATH'],
|
||||
}
|
||||
: false;
|
||||
|
||||
info['websocket'].graphql = toBoolean(env['WEBSOCKETS_GRAPHQL_ENABLED'])
|
||||
? {
|
||||
authentication: env['WEBSOCKETS_GRAPHQL_AUTH'],
|
||||
path: env['WEBSOCKETS_GRAPHQL_PATH'],
|
||||
}
|
||||
: false;
|
||||
|
||||
info['websocket'].heartbeat = toBoolean(env['WEBSOCKETS_HEARTBEAT_ENABLED'])
|
||||
? env['WEBSOCKETS_HEARTBEAT_PERIOD']
|
||||
: false;
|
||||
} else {
|
||||
info['websocket'] = false;
|
||||
}
|
||||
}
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
|
||||
69
api/src/services/websocket.test.ts
Normal file
69
api/src/services/websocket.test.ts
Normal file
@@ -0,0 +1,69 @@
|
||||
import { describe, expect, test, vi } from 'vitest';
|
||||
import type { WebSocketClient } from '../websocket/types.js';
|
||||
import { WebSocketController, getWebSocketController } from '../websocket/controllers/index.js';
|
||||
import type { Accountability } from '@directus/types';
|
||||
import { WebSocketService } from './websocket.js';
|
||||
|
||||
vi.mock('../emitter');
|
||||
vi.mock('../websocket/controllers/index');
|
||||
|
||||
function mockClient(accountability: Accountability | null = null) {
|
||||
return {
|
||||
on: vi.fn(),
|
||||
off: vi.fn(),
|
||||
send: vi.fn(),
|
||||
close: vi.fn(),
|
||||
accountability,
|
||||
} as unknown as WebSocketClient;
|
||||
}
|
||||
|
||||
describe('WebSocketService', () => {
|
||||
test('get clients', () => {
|
||||
vi.mocked(getWebSocketController).mockReturnValue({
|
||||
clients: new Set([mockClient(), mockClient(), mockClient()]),
|
||||
} as unknown as WebSocketController);
|
||||
|
||||
const wsService = new WebSocketService();
|
||||
expect(wsService.clients().size).toBe(3);
|
||||
});
|
||||
|
||||
test('broadcast', () => {
|
||||
const clients = new Set([mockClient(), mockClient(), mockClient()]);
|
||||
const message = 'test 123';
|
||||
|
||||
vi.mocked(getWebSocketController).mockReturnValue({ clients } as unknown as WebSocketController);
|
||||
|
||||
const wsService = new WebSocketService();
|
||||
wsService.broadcast(message);
|
||||
|
||||
for (const client of clients) {
|
||||
expect(client.send).toBeCalledWith(message);
|
||||
}
|
||||
});
|
||||
|
||||
test('broadcast with role filter', () => {
|
||||
const clients = [mockClient({ user: 'test', role: 'test' }), mockClient({ user: 'test2', role: 'test2' })];
|
||||
const message = 'test 123';
|
||||
|
||||
vi.mocked(getWebSocketController).mockReturnValue({ clients: new Set(clients) } as unknown as WebSocketController);
|
||||
|
||||
const wsService = new WebSocketService();
|
||||
wsService.broadcast(message, { role: 'test' });
|
||||
|
||||
expect(clients[0]!.send).toBeCalledWith(message);
|
||||
expect(clients[1]!.send).not.toBeCalled();
|
||||
});
|
||||
|
||||
test('broadcast with user filter', () => {
|
||||
const clients = [mockClient({ user: 'test', role: 'test' }), mockClient({ user: 'test2', role: 'test2' })];
|
||||
const message = 'test 123';
|
||||
|
||||
vi.mocked(getWebSocketController).mockReturnValue({ clients: new Set(clients) } as unknown as WebSocketController);
|
||||
|
||||
const wsService = new WebSocketService();
|
||||
wsService.broadcast(message, { user: 'test2' });
|
||||
|
||||
expect(clients[0]!.send).not.toBeCalled();
|
||||
expect(clients[1]!.send).toBeCalledWith(message);
|
||||
});
|
||||
});
|
||||
34
api/src/services/websocket.ts
Normal file
34
api/src/services/websocket.ts
Normal file
@@ -0,0 +1,34 @@
|
||||
import type { ActionHandler } from '@directus/types';
|
||||
import { getWebSocketController } from '../websocket/controllers/index.js';
|
||||
import type { WebSocketController } from '../websocket/controllers/rest.js';
|
||||
import type { WebSocketClient } from '../websocket/types.js';
|
||||
import type { WebSocketMessage } from '../websocket/messages.js';
|
||||
import emitter from '../emitter.js';
|
||||
|
||||
export class WebSocketService {
|
||||
private controller: WebSocketController;
|
||||
|
||||
constructor() {
|
||||
this.controller = getWebSocketController();
|
||||
}
|
||||
|
||||
on(event: 'connect' | 'message' | 'error' | 'close', callback: ActionHandler) {
|
||||
emitter.onAction('websocket.' + event, callback);
|
||||
}
|
||||
|
||||
off(event: 'connect' | 'message' | 'error' | 'close', callback: ActionHandler) {
|
||||
emitter.offAction('websocket.' + event, callback);
|
||||
}
|
||||
|
||||
broadcast(message: string | WebSocketMessage, filter?: { user?: string; role?: string }) {
|
||||
this.controller.clients.forEach((client: WebSocketClient) => {
|
||||
if (filter && filter.user && filter.user !== client.accountability?.user) return;
|
||||
if (filter && filter.role && filter.role !== client.accountability?.role) return;
|
||||
client.send(typeof message === 'string' ? message : JSON.stringify(message));
|
||||
});
|
||||
}
|
||||
|
||||
clients(): Set<WebSocketClient> {
|
||||
return this.controller.clients;
|
||||
}
|
||||
}
|
||||
87
api/src/utils/get-accountability-for-role.test.ts
Normal file
87
api/src/utils/get-accountability-for-role.test.ts
Normal file
@@ -0,0 +1,87 @@
|
||||
import { expect, describe, test, vi } from 'vitest';
|
||||
import { getAccountabilityForRole } from './get-accountability-for-role.js';
|
||||
|
||||
vi.mock('./get-permissions', () => ({
|
||||
getPermissions: vi.fn().mockReturnValue([]),
|
||||
}));
|
||||
|
||||
function mockDatabase() {
|
||||
const self: Record<string, any> = {
|
||||
select: vi.fn(() => self),
|
||||
from: vi.fn(() => self),
|
||||
where: vi.fn(() => self),
|
||||
first: vi.fn(),
|
||||
};
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
describe('getAccountabilityForRole', async () => {
|
||||
test('no role', async () => {
|
||||
const result = await getAccountabilityForRole(null, {
|
||||
accountability: null,
|
||||
schema: {} as any,
|
||||
database: vi.fn() as any,
|
||||
});
|
||||
|
||||
expect(result).toStrictEqual({
|
||||
admin: false,
|
||||
app: false,
|
||||
permissions: [],
|
||||
role: null,
|
||||
user: null,
|
||||
});
|
||||
});
|
||||
|
||||
test('system role', async () => {
|
||||
const result = await getAccountabilityForRole('system', {
|
||||
accountability: null,
|
||||
schema: {} as any,
|
||||
database: vi.fn() as any,
|
||||
});
|
||||
|
||||
expect(result).toStrictEqual({
|
||||
admin: true,
|
||||
app: true,
|
||||
permissions: [],
|
||||
role: null,
|
||||
user: null,
|
||||
});
|
||||
});
|
||||
|
||||
test('get role from database', async () => {
|
||||
const db = mockDatabase();
|
||||
|
||||
db['first'].mockReturnValue({
|
||||
admin_access: 'not true',
|
||||
app_access: '1',
|
||||
});
|
||||
|
||||
const result = await getAccountabilityForRole('123-456', {
|
||||
accountability: null,
|
||||
schema: {} as any,
|
||||
database: db as any,
|
||||
});
|
||||
|
||||
expect(result).toStrictEqual({
|
||||
admin: false,
|
||||
app: true,
|
||||
permissions: [],
|
||||
role: '123-456',
|
||||
user: null,
|
||||
});
|
||||
});
|
||||
|
||||
test('database invalid role', async () => {
|
||||
const db = mockDatabase();
|
||||
db['first'].mockReturnValue(false);
|
||||
|
||||
expect(() =>
|
||||
getAccountabilityForRole('456-789', {
|
||||
accountability: null,
|
||||
schema: {} as any,
|
||||
database: db as any,
|
||||
})
|
||||
).rejects.toThrow('Configured role "456-789" isn\'t a valid role ID or doesn\'t exist.');
|
||||
});
|
||||
});
|
||||
101
api/src/utils/get-accountability-for-token.test.ts
Normal file
101
api/src/utils/get-accountability-for-token.test.ts
Normal file
@@ -0,0 +1,101 @@
|
||||
import { expect, describe, test, vi } from 'vitest';
|
||||
import jwt from 'jsonwebtoken';
|
||||
import env from '../env.js';
|
||||
import { getAccountabilityForToken } from './get-accountability-for-token.js';
|
||||
import getDatabase from '../database/index.js';
|
||||
|
||||
vi.mock('../env', () => {
|
||||
const MOCK_ENV = {
|
||||
SECRET: 'super-secure-secret',
|
||||
EXTENSIONS_PATH: './extensions',
|
||||
};
|
||||
|
||||
return {
|
||||
default: MOCK_ENV,
|
||||
getEnv: () => MOCK_ENV,
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('../database/index', () => {
|
||||
const self: Record<string, any> = {
|
||||
select: vi.fn(() => self),
|
||||
from: vi.fn(() => self),
|
||||
leftJoin: vi.fn(() => self),
|
||||
where: vi.fn(() => self),
|
||||
first: vi.fn(),
|
||||
};
|
||||
|
||||
return { default: vi.fn(() => self) };
|
||||
});
|
||||
|
||||
describe('getAccountabilityForToken', async () => {
|
||||
test('minimal token payload', async () => {
|
||||
const token = jwt.sign({ role: '123-456-789', app_access: false, admin_access: false }, env['SECRET'], {
|
||||
issuer: 'directus',
|
||||
});
|
||||
|
||||
const result = await getAccountabilityForToken(token);
|
||||
expect(result).toStrictEqual({ admin: false, app: false, role: '123-456-789', user: null });
|
||||
});
|
||||
|
||||
test('full token payload', async () => {
|
||||
const token = jwt.sign(
|
||||
{
|
||||
share: 'share-id',
|
||||
share_scope: 'share-scope',
|
||||
id: 'user-id',
|
||||
role: 'role-id',
|
||||
admin_access: 1,
|
||||
app_access: 1,
|
||||
},
|
||||
env['SECRET'],
|
||||
{ issuer: 'directus' }
|
||||
);
|
||||
|
||||
const result = await getAccountabilityForToken(token);
|
||||
expect(result.admin).toBe(true);
|
||||
expect(result.app).toBe(true);
|
||||
expect(result.role).toBe('role-id');
|
||||
expect(result.share).toBe('share-id');
|
||||
expect(result.share_scope).toBe('share-scope');
|
||||
expect(result.user).toBe('user-id');
|
||||
});
|
||||
|
||||
test('throws token expired error', async () => {
|
||||
const token = jwt.sign({ role: '123-456-789' }, env['SECRET'], { issuer: 'directus', expiresIn: -1 });
|
||||
expect(() => getAccountabilityForToken(token)).rejects.toThrow('Token expired.');
|
||||
});
|
||||
|
||||
test('throws token invalid error', async () => {
|
||||
const token = jwt.sign({ role: '123-456-789' }, 'bad-secret', { issuer: 'directus' });
|
||||
expect(() => getAccountabilityForToken(token)).rejects.toThrow('Token invalid.');
|
||||
});
|
||||
|
||||
test('find user in database', async () => {
|
||||
const db = getDatabase();
|
||||
|
||||
vi.spyOn(db, 'first').mockReturnValue({
|
||||
id: 'user-id',
|
||||
role: 'role-id',
|
||||
admin_access: false,
|
||||
app_access: true,
|
||||
} as any);
|
||||
|
||||
const token = jwt.sign({ role: '123-456-789' }, 'bad-secret');
|
||||
const result = await getAccountabilityForToken(token);
|
||||
|
||||
expect(result).toStrictEqual({
|
||||
user: 'user-id',
|
||||
role: 'role-id',
|
||||
admin: false,
|
||||
app: true,
|
||||
});
|
||||
});
|
||||
|
||||
test('no user found', async () => {
|
||||
const db = getDatabase();
|
||||
vi.spyOn(db, 'first').mockReturnValue(false as any);
|
||||
const token = jwt.sign({ role: '123-456-789' }, 'bad-secret');
|
||||
expect(() => getAccountabilityForToken(token)).rejects.toThrow('Invalid user credentials.');
|
||||
});
|
||||
});
|
||||
58
api/src/utils/get-accountability-for-token.ts
Normal file
58
api/src/utils/get-accountability-for-token.ts
Normal file
@@ -0,0 +1,58 @@
|
||||
import getDatabase from '../database/index.js';
|
||||
import type { Accountability } from '@directus/types';
|
||||
import isDirectusJWT from './is-directus-jwt.js';
|
||||
import { InvalidCredentialsException } from '../index.js';
|
||||
import env from '../env.js';
|
||||
import { verifyAccessJWT } from './jwt.js';
|
||||
|
||||
export async function getAccountabilityForToken(
|
||||
token?: string | null,
|
||||
accountability?: Accountability
|
||||
): Promise<Accountability> {
|
||||
if (!accountability) {
|
||||
accountability = {
|
||||
user: null,
|
||||
role: null,
|
||||
admin: false,
|
||||
app: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (token) {
|
||||
if (isDirectusJWT(token)) {
|
||||
const payload = verifyAccessJWT(token, env['SECRET'] as string);
|
||||
|
||||
accountability.role = payload.role;
|
||||
accountability.admin = payload.admin_access === true || payload.admin_access == 1;
|
||||
accountability.app = payload.app_access === true || payload.app_access == 1;
|
||||
|
||||
if (payload.share) accountability.share = payload.share;
|
||||
if (payload.share_scope) accountability.share_scope = payload.share_scope;
|
||||
if (payload.id) accountability.user = payload.id;
|
||||
} else {
|
||||
// Try finding the user with the provided token
|
||||
const database = getDatabase();
|
||||
|
||||
const user = await database
|
||||
.select('directus_users.id', 'directus_users.role', 'directus_roles.admin_access', 'directus_roles.app_access')
|
||||
.from('directus_users')
|
||||
.leftJoin('directus_roles', 'directus_users.role', 'directus_roles.id')
|
||||
.where({
|
||||
'directus_users.token': token,
|
||||
status: 'active',
|
||||
})
|
||||
.first();
|
||||
|
||||
if (!user) {
|
||||
throw new InvalidCredentialsException();
|
||||
}
|
||||
|
||||
accountability.user = user.id;
|
||||
accountability.role = user.role;
|
||||
accountability.admin = user.admin_access === true || user.admin_access == 1;
|
||||
accountability.app = user.app_access === true || user.app_access == 1;
|
||||
}
|
||||
}
|
||||
|
||||
return accountability;
|
||||
}
|
||||
69
api/src/utils/get-service.ts
Normal file
69
api/src/utils/get-service.ts
Normal file
@@ -0,0 +1,69 @@
|
||||
import {
|
||||
ActivityService,
|
||||
DashboardsService,
|
||||
FilesService,
|
||||
FlowsService,
|
||||
FoldersService,
|
||||
ItemsService,
|
||||
NotificationsService,
|
||||
OperationsService,
|
||||
PanelsService,
|
||||
PermissionsService,
|
||||
PresetsService,
|
||||
RevisionsService,
|
||||
RolesService,
|
||||
SettingsService,
|
||||
SharesService,
|
||||
UsersService,
|
||||
WebhooksService,
|
||||
} from '../index.js';
|
||||
import type { AbstractServiceOptions } from '../types/services.js';
|
||||
|
||||
/**
|
||||
* Select the correct service for the given collection. This allows the individual services to run
|
||||
* their custom checks (f.e. it allows UsersService to prevent updating TFA secret from outside)
|
||||
*/
|
||||
export function getService(collection: string, opts: AbstractServiceOptions): ItemsService {
|
||||
switch (collection) {
|
||||
case 'directus_activity':
|
||||
return new ActivityService(opts);
|
||||
// case 'directus_collections':
|
||||
// return new CollectionsService(opts);
|
||||
case 'directus_dashboards':
|
||||
return new DashboardsService(opts);
|
||||
// case 'directus_fields':
|
||||
// return new FieldsService(opts);
|
||||
case 'directus_files':
|
||||
return new FilesService(opts);
|
||||
case 'directus_flows':
|
||||
return new FlowsService(opts);
|
||||
case 'directus_folders':
|
||||
return new FoldersService(opts);
|
||||
case 'directus_notifications':
|
||||
return new NotificationsService(opts);
|
||||
case 'directus_operations':
|
||||
return new OperationsService(opts);
|
||||
case 'directus_panels':
|
||||
return new PanelsService(opts);
|
||||
case 'directus_permissions':
|
||||
return new PermissionsService(opts);
|
||||
case 'directus_presets':
|
||||
return new PresetsService(opts);
|
||||
// case 'directus_relations':
|
||||
// return new RelationsService(opts);
|
||||
case 'directus_revisions':
|
||||
return new RevisionsService(opts);
|
||||
case 'directus_roles':
|
||||
return new RolesService(opts);
|
||||
case 'directus_settings':
|
||||
return new SettingsService(opts);
|
||||
case 'directus_shares':
|
||||
return new SharesService(opts);
|
||||
case 'directus_users':
|
||||
return new UsersService(opts);
|
||||
case 'directus_webhooks':
|
||||
return new WebhooksService(opts);
|
||||
default:
|
||||
return new ItemsService(collection, opts);
|
||||
}
|
||||
}
|
||||
16
api/src/utils/to-boolean.test.ts
Normal file
16
api/src/utils/to-boolean.test.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import { expect, test } from 'vitest';
|
||||
import { toBoolean } from './to-boolean.js';
|
||||
|
||||
test.each([
|
||||
['true', true],
|
||||
[true, true],
|
||||
['1', true],
|
||||
[1, true],
|
||||
['false', false],
|
||||
['anything', false],
|
||||
[123, false],
|
||||
[{}, false],
|
||||
[['{}'], false],
|
||||
])('toBoolean(%s) -> %s', (value, expected) => {
|
||||
expect(toBoolean(value)).toBe(expected);
|
||||
});
|
||||
6
api/src/utils/to-boolean.ts
Normal file
6
api/src/utils/to-boolean.ts
Normal file
@@ -0,0 +1,6 @@
|
||||
/**
|
||||
* Convert environment variable to Boolean
|
||||
*/
|
||||
export function toBoolean(value: any): boolean {
|
||||
return value === 'true' || value === true || value === '1' || value === 1;
|
||||
}
|
||||
137
api/src/websocket/authenticate.test.ts
Normal file
137
api/src/websocket/authenticate.test.ts
Normal file
@@ -0,0 +1,137 @@
|
||||
import type { Accountability } from '@directus/types';
|
||||
import { describe, expect, test, vi } from 'vitest';
|
||||
import type { Mock } from 'vitest';
|
||||
import { InvalidCredentialsException } from '../index.js';
|
||||
import { getAccountabilityForRole } from '../utils/get-accountability-for-role.js';
|
||||
import { getAccountabilityForToken } from '../utils/get-accountability-for-token.js';
|
||||
import { authenticateConnection, authenticationSuccess, refreshAccountability } from './authenticate.js';
|
||||
import type { WebSocketAuthMessage } from './messages.js';
|
||||
import { getExpiresAtForToken } from './utils/get-expires-at-for-token.js';
|
||||
|
||||
vi.mock('../utils/get-accountability-for-token', () => ({
|
||||
getAccountabilityForToken: vi.fn().mockReturnValue({
|
||||
role: null, // minimum viable accountability
|
||||
} as Accountability),
|
||||
}));
|
||||
|
||||
vi.mock('../utils/get-accountability-for-role', () => ({
|
||||
getAccountabilityForRole: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('./utils/get-expires-at-for-token', () => ({
|
||||
getExpiresAtForToken: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../utils/get-schema');
|
||||
|
||||
vi.mock('../services/authentication', () => ({
|
||||
AuthenticationService: vi.fn(() => ({
|
||||
login: vi.fn().mockReturnValue({ accessToken: '123', refreshToken: 'refresh', expires: 123456 }),
|
||||
refresh: vi.fn().mockReturnValue({ accessToken: '456', refreshToken: 'refresh' }),
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('../database');
|
||||
|
||||
describe('authenticateConnection', () => {
|
||||
test('Success with email/password', async () => {
|
||||
const TIMESTAMP = 123456789;
|
||||
(getExpiresAtForToken as Mock).mockReturnValue(TIMESTAMP);
|
||||
|
||||
const result = await authenticateConnection({
|
||||
type: 'auth',
|
||||
email: 'email',
|
||||
password: 'password',
|
||||
} as WebSocketAuthMessage);
|
||||
|
||||
expect(result).toStrictEqual({
|
||||
accountability: { role: null },
|
||||
expires_at: TIMESTAMP,
|
||||
refresh_token: 'refresh',
|
||||
});
|
||||
});
|
||||
|
||||
test('Success with refresh_token', async () => {
|
||||
const TIMESTAMP = 987654;
|
||||
(getExpiresAtForToken as Mock).mockReturnValue(TIMESTAMP);
|
||||
|
||||
const result = await authenticateConnection({
|
||||
type: 'auth',
|
||||
refresh_token: 'refresh_token',
|
||||
} as WebSocketAuthMessage);
|
||||
|
||||
expect(result).toStrictEqual({
|
||||
accountability: { role: null },
|
||||
expires_at: TIMESTAMP,
|
||||
refresh_token: 'refresh',
|
||||
});
|
||||
});
|
||||
|
||||
test('Success with access_token', async () => {
|
||||
const TIMESTAMP = 456987;
|
||||
(getExpiresAtForToken as Mock).mockReturnValue(TIMESTAMP);
|
||||
|
||||
const result = await authenticateConnection({
|
||||
type: 'auth',
|
||||
access_token: 'access_token',
|
||||
} as WebSocketAuthMessage);
|
||||
|
||||
expect(result).toStrictEqual({
|
||||
accountability: { role: null },
|
||||
expires_at: TIMESTAMP,
|
||||
refresh_token: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
test('Failure token expired', async () => {
|
||||
(getAccountabilityForToken as Mock).mockImplementation(() => {
|
||||
throw new InvalidCredentialsException('Token expired.');
|
||||
});
|
||||
|
||||
expect(() =>
|
||||
authenticateConnection({
|
||||
type: 'auth',
|
||||
access_token: 'expired',
|
||||
} as WebSocketAuthMessage)
|
||||
).rejects.toThrow('Token expired.');
|
||||
});
|
||||
|
||||
test('Failure authentication failed', async () => {
|
||||
expect(() =>
|
||||
authenticateConnection({
|
||||
type: 'auth',
|
||||
access_token: '',
|
||||
} as WebSocketAuthMessage)
|
||||
).rejects.toThrow('Authentication failed.');
|
||||
});
|
||||
});
|
||||
|
||||
describe('refreshAccountability', () => {
|
||||
test('should just work', async () => {
|
||||
(getAccountabilityForRole as Mock).mockReturnValue({
|
||||
role: '123-456-789',
|
||||
} as Accountability);
|
||||
|
||||
const result = await refreshAccountability({
|
||||
role: null,
|
||||
user: 'abc-def-ghi',
|
||||
});
|
||||
|
||||
expect(result).toStrictEqual({
|
||||
role: '123-456-789',
|
||||
user: 'abc-def-ghi',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('authenticationSuccess', () => {
|
||||
test('without uid', async () => {
|
||||
const result = authenticationSuccess();
|
||||
expect(result).toBe('{"type":"auth","status":"ok"}');
|
||||
});
|
||||
|
||||
test('with uid', async () => {
|
||||
const result = authenticationSuccess('123456');
|
||||
expect(result).toBe('{"type":"auth","status":"ok","uid":"123456"}');
|
||||
});
|
||||
});
|
||||
79
api/src/websocket/authenticate.ts
Normal file
79
api/src/websocket/authenticate.ts
Normal file
@@ -0,0 +1,79 @@
|
||||
import type { Accountability } from '@directus/types';
|
||||
import { DEFAULT_AUTH_PROVIDER } from '../constants.js';
|
||||
import getDatabase from '../database/index.js';
|
||||
import { InvalidCredentialsException } from '../exceptions/index.js';
|
||||
import { AuthenticationService } from '../services/index.js';
|
||||
import { getAccountabilityForRole } from '../utils/get-accountability-for-role.js';
|
||||
import { getAccountabilityForToken } from '../utils/get-accountability-for-token.js';
|
||||
import { getSchema } from '../utils/get-schema.js';
|
||||
import { WebSocketException } from './exceptions.js';
|
||||
import type { BasicAuthMessage, WebSocketResponse } from './messages.js';
|
||||
import type { AuthenticationState } from './types.js';
|
||||
import { getExpiresAtForToken } from './utils/get-expires-at-for-token.js';
|
||||
|
||||
export async function authenticateConnection(
|
||||
message: BasicAuthMessage & Record<string, any>
|
||||
): Promise<AuthenticationState> {
|
||||
let access_token: string | undefined, refresh_token: string | undefined;
|
||||
|
||||
try {
|
||||
if ('email' in message && 'password' in message) {
|
||||
const authenticationService = new AuthenticationService({ schema: await getSchema() });
|
||||
const { accessToken, refreshToken } = await authenticationService.login(DEFAULT_AUTH_PROVIDER, message);
|
||||
access_token = accessToken;
|
||||
refresh_token = refreshToken;
|
||||
}
|
||||
|
||||
if ('refresh_token' in message) {
|
||||
const authenticationService = new AuthenticationService({ schema: await getSchema() });
|
||||
const { accessToken, refreshToken } = await authenticationService.refresh(message.refresh_token);
|
||||
access_token = accessToken;
|
||||
refresh_token = refreshToken;
|
||||
}
|
||||
|
||||
if ('access_token' in message) {
|
||||
access_token = message.access_token;
|
||||
}
|
||||
|
||||
if (!access_token) throw new Error();
|
||||
const accountability = await getAccountabilityForToken(access_token);
|
||||
const expires_at = getExpiresAtForToken(access_token);
|
||||
return { accountability, expires_at, refresh_token } as AuthenticationState;
|
||||
} catch (error) {
|
||||
if (error instanceof InvalidCredentialsException && error.message === 'Token expired.') {
|
||||
throw new WebSocketException('auth', 'TOKEN_EXPIRED', 'Token expired.', message['uid']);
|
||||
}
|
||||
|
||||
throw new WebSocketException('auth', 'AUTH_FAILED', 'Authentication failed.', message['uid']);
|
||||
}
|
||||
}
|
||||
|
||||
export async function refreshAccountability(
|
||||
accountability: Accountability | null | undefined
|
||||
): Promise<Accountability> {
|
||||
const result: Accountability = await getAccountabilityForRole(accountability?.role || null, {
|
||||
accountability: accountability || null,
|
||||
schema: await getSchema(),
|
||||
database: getDatabase(),
|
||||
});
|
||||
|
||||
result.user = accountability?.user || null;
|
||||
return result;
|
||||
}
|
||||
|
||||
export function authenticationSuccess(uid?: string | number, refresh_token?: string): string {
|
||||
const message: WebSocketResponse = {
|
||||
type: 'auth',
|
||||
status: 'ok',
|
||||
};
|
||||
|
||||
if (uid !== undefined) {
|
||||
message.uid = uid;
|
||||
}
|
||||
|
||||
if (refresh_token !== undefined) {
|
||||
message['refresh_token'] = refresh_token;
|
||||
}
|
||||
|
||||
return JSON.stringify(message);
|
||||
}
|
||||
344
api/src/websocket/controllers/base.ts
Normal file
344
api/src/websocket/controllers/base.ts
Normal file
@@ -0,0 +1,344 @@
|
||||
import type { Accountability } from '@directus/types';
|
||||
import { parseJSON } from '@directus/utils';
|
||||
import type { IncomingMessage, Server as httpServer } from 'http';
|
||||
import type { ParsedUrlQuery } from 'querystring';
|
||||
import type { RateLimiterAbstract } from 'rate-limiter-flexible';
|
||||
import type internal from 'stream';
|
||||
import { parse } from 'url';
|
||||
import { v4 as uuid } from 'uuid';
|
||||
import WebSocket, { WebSocketServer } from 'ws';
|
||||
import { fromZodError } from 'zod-validation-error';
|
||||
import emitter from '../../emitter.js';
|
||||
import env from '../../env.js';
|
||||
import { InvalidConfigException, TokenExpiredException } from '../../exceptions/index.js';
|
||||
import logger from '../../logger.js';
|
||||
import { createRateLimiter } from '../../rate-limiter.js';
|
||||
import { getAccountabilityForToken } from '../../utils/get-accountability-for-token.js';
|
||||
import { toBoolean } from '../../utils/to-boolean.js';
|
||||
import { authenticateConnection, authenticationSuccess } from '../authenticate.js';
|
||||
import { WebSocketException, handleWebSocketException } from '../exceptions.js';
|
||||
import { AuthMode, WebSocketAuthMessage, WebSocketMessage } from '../messages.js';
|
||||
import type { AuthenticationState, UpgradeContext, WebSocketClient } from '../types.js';
|
||||
import { getExpiresAtForToken } from '../utils/get-expires-at-for-token.js';
|
||||
import { getMessageType } from '../utils/message.js';
|
||||
import { waitForAnyMessage, waitForMessageType } from '../utils/wait-for-message.js';
|
||||
import { registerWebSocketEvents } from './hooks.js';
|
||||
|
||||
const TOKEN_CHECK_INTERVAL = 15 * 60 * 1000; // 15 minutes
|
||||
|
||||
export default abstract class SocketController {
|
||||
server: WebSocket.Server;
|
||||
clients: Set<WebSocketClient>;
|
||||
authentication: {
|
||||
mode: AuthMode;
|
||||
timeout: number;
|
||||
};
|
||||
|
||||
endpoint: string;
|
||||
maxConnections: number;
|
||||
private rateLimiter: RateLimiterAbstract | null;
|
||||
private authInterval: NodeJS.Timer | null;
|
||||
|
||||
constructor(httpServer: httpServer, configPrefix: string) {
|
||||
this.server = new WebSocketServer({ noServer: true });
|
||||
this.clients = new Set();
|
||||
this.authInterval = null;
|
||||
|
||||
const { endpoint, authentication, maxConnections } = this.getEnvironmentConfig(configPrefix);
|
||||
this.endpoint = endpoint;
|
||||
this.authentication = authentication;
|
||||
this.maxConnections = maxConnections;
|
||||
this.rateLimiter = this.getRateLimiter();
|
||||
|
||||
httpServer.on('upgrade', this.handleUpgrade.bind(this));
|
||||
this.checkClientTokens();
|
||||
registerWebSocketEvents();
|
||||
}
|
||||
|
||||
protected getEnvironmentConfig(configPrefix: string): {
|
||||
endpoint: string;
|
||||
authentication: {
|
||||
mode: AuthMode;
|
||||
timeout: number;
|
||||
};
|
||||
maxConnections: number;
|
||||
} {
|
||||
const endpoint = String(env[`${configPrefix}_PATH`]);
|
||||
const authMode = AuthMode.safeParse(String(env[`${configPrefix}_AUTH`]).toLowerCase());
|
||||
const authTimeout = Number(env[`${configPrefix}_AUTH_TIMEOUT`]) * 1000;
|
||||
|
||||
const maxConnections =
|
||||
`${configPrefix}_CONN_LIMIT` in env ? Number(env[`${configPrefix}_CONN_LIMIT`]) : Number.POSITIVE_INFINITY;
|
||||
|
||||
if (!authMode.success) {
|
||||
throw new InvalidConfigException(fromZodError(authMode.error, { prefix: `${configPrefix}_AUTH` }).message);
|
||||
}
|
||||
|
||||
return {
|
||||
endpoint,
|
||||
maxConnections,
|
||||
authentication: {
|
||||
mode: authMode.data,
|
||||
timeout: authTimeout,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
protected getRateLimiter() {
|
||||
if (toBoolean(env['RATE_LIMITER_ENABLED']) === true) {
|
||||
return createRateLimiter('RATE_LIMITER', {
|
||||
keyPrefix: 'websocket',
|
||||
});
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
protected async handleUpgrade(request: IncomingMessage, socket: internal.Duplex, head: Buffer) {
|
||||
const { pathname, query } = parse(request.url!, true);
|
||||
if (pathname !== this.endpoint) return;
|
||||
|
||||
if (this.clients.size >= this.maxConnections) {
|
||||
logger.debug('WebSocket upgrade denied - max connections reached');
|
||||
socket.write('HTTP/1.1 403 Forbidden\r\n\r\n');
|
||||
socket.destroy();
|
||||
return;
|
||||
}
|
||||
|
||||
const context: UpgradeContext = { request, socket, head };
|
||||
|
||||
if (this.authentication.mode === 'strict') {
|
||||
await this.handleStrictUpgrade(context, query);
|
||||
return;
|
||||
}
|
||||
|
||||
if (this.authentication.mode === 'handshake') {
|
||||
await this.handleHandshakeUpgrade(context);
|
||||
return;
|
||||
}
|
||||
|
||||
this.server.handleUpgrade(request, socket, head, async (ws) => {
|
||||
const state = { accountability: null, expires_at: null } as AuthenticationState;
|
||||
this.server.emit('connection', ws, state);
|
||||
});
|
||||
}
|
||||
|
||||
protected async handleStrictUpgrade({ request, socket, head }: UpgradeContext, query: ParsedUrlQuery) {
|
||||
let accountability: Accountability | null, expires_at: number | null;
|
||||
|
||||
try {
|
||||
const token = query['access_token'] as string;
|
||||
accountability = await getAccountabilityForToken(token);
|
||||
expires_at = getExpiresAtForToken(token);
|
||||
} catch {
|
||||
accountability = null;
|
||||
expires_at = null;
|
||||
}
|
||||
|
||||
if (!accountability || !accountability.user) {
|
||||
logger.debug('WebSocket upgrade denied - ' + JSON.stringify(accountability || 'invalid'));
|
||||
socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n');
|
||||
socket.destroy();
|
||||
return;
|
||||
}
|
||||
|
||||
this.server.handleUpgrade(request, socket, head, async (ws) => {
|
||||
const state = { accountability, expires_at } as AuthenticationState;
|
||||
this.server.emit('connection', ws, state);
|
||||
});
|
||||
}
|
||||
|
||||
protected async handleHandshakeUpgrade({ request, socket, head }: UpgradeContext) {
|
||||
this.server.handleUpgrade(request, socket, head, async (ws) => {
|
||||
try {
|
||||
const payload = await waitForAnyMessage(ws, this.authentication.timeout);
|
||||
if (getMessageType(payload) !== 'auth') throw new Error();
|
||||
|
||||
const state = await authenticateConnection(WebSocketAuthMessage.parse(payload));
|
||||
ws.send(authenticationSuccess(payload['uid'], state.refresh_token));
|
||||
this.server.emit('connection', ws, state);
|
||||
} catch {
|
||||
logger.debug('WebSocket authentication handshake failed');
|
||||
const error = new WebSocketException('auth', 'AUTH_FAILED', 'Authentication handshake failed.');
|
||||
handleWebSocketException(ws, error, 'auth');
|
||||
ws.close();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
createClient(ws: WebSocket, { accountability, expires_at }: AuthenticationState) {
|
||||
const client = ws as WebSocketClient;
|
||||
client.accountability = accountability;
|
||||
client.expires_at = expires_at;
|
||||
client.uid = uuid();
|
||||
client.auth_timer = null;
|
||||
|
||||
ws.on('message', async (data: WebSocket.RawData) => {
|
||||
if (this.rateLimiter !== null) {
|
||||
try {
|
||||
await this.rateLimiter.consume(client.uid);
|
||||
} catch (limit) {
|
||||
const timeout = (limit as any)?.msBeforeNext ?? this.rateLimiter.msDuration;
|
||||
|
||||
const error = new WebSocketException(
|
||||
'server',
|
||||
'REQUESTS_EXCEEDED',
|
||||
`Too many messages, retry after ${timeout}ms.`
|
||||
);
|
||||
|
||||
handleWebSocketException(client, error, 'server');
|
||||
logger.debug(`WebSocket#${client.uid} is rate limited`);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let message: WebSocketMessage;
|
||||
|
||||
try {
|
||||
message = this.parseMessage(data.toString());
|
||||
} catch (err: any) {
|
||||
handleWebSocketException(client, err, 'server');
|
||||
return;
|
||||
}
|
||||
|
||||
if (getMessageType(message) === 'auth') {
|
||||
try {
|
||||
await this.handleAuthRequest(client, WebSocketAuthMessage.parse(message));
|
||||
} catch {
|
||||
// ignore errors
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// this log cannot be higher in the function or it will leak credentials
|
||||
logger.trace(`WebSocket#${client.uid} - ${JSON.stringify(message)}`);
|
||||
ws.emit('parsed-message', message);
|
||||
});
|
||||
|
||||
ws.on('error', () => {
|
||||
logger.debug(`WebSocket#${client.uid} connection errored`);
|
||||
|
||||
if (client.auth_timer) {
|
||||
clearTimeout(client.auth_timer);
|
||||
client.auth_timer = null;
|
||||
}
|
||||
|
||||
this.clients.delete(client);
|
||||
});
|
||||
|
||||
ws.on('close', () => {
|
||||
logger.debug(`WebSocket#${client.uid} connection closed`);
|
||||
|
||||
if (client.auth_timer) {
|
||||
clearTimeout(client.auth_timer);
|
||||
client.auth_timer = null;
|
||||
}
|
||||
|
||||
this.clients.delete(client);
|
||||
});
|
||||
|
||||
logger.debug(`WebSocket#${client.uid} connected`);
|
||||
|
||||
if (accountability) {
|
||||
logger.trace(`WebSocket#${client.uid} authenticated as ${JSON.stringify(accountability)}`);
|
||||
}
|
||||
|
||||
this.setTokenExpireTimer(client);
|
||||
this.clients.add(client);
|
||||
return client;
|
||||
}
|
||||
|
||||
protected parseMessage(data: string): WebSocketMessage {
|
||||
let message: WebSocketMessage;
|
||||
|
||||
try {
|
||||
message = WebSocketMessage.parse(parseJSON(data));
|
||||
} catch (err: any) {
|
||||
throw new WebSocketException('server', 'INVALID_PAYLOAD', 'Unable to parse the incoming message.');
|
||||
}
|
||||
|
||||
return message;
|
||||
}
|
||||
|
||||
protected async handleAuthRequest(client: WebSocketClient, message: WebSocketAuthMessage) {
|
||||
try {
|
||||
const { accountability, expires_at, refresh_token } = await authenticateConnection(message);
|
||||
client.accountability = accountability;
|
||||
client.expires_at = expires_at;
|
||||
this.setTokenExpireTimer(client);
|
||||
emitter.emitAction('websocket.auth.success', { client });
|
||||
client.send(authenticationSuccess(message.uid, refresh_token));
|
||||
logger.trace(`WebSocket#${client.uid} authenticated as ${JSON.stringify(client.accountability)}`);
|
||||
} catch (error) {
|
||||
logger.trace(`WebSocket#${client.uid} failed authentication`);
|
||||
emitter.emitAction('websocket.auth.failure', { client });
|
||||
|
||||
client.accountability = null;
|
||||
client.expires_at = null;
|
||||
|
||||
const _error =
|
||||
error instanceof WebSocketException
|
||||
? error
|
||||
: new WebSocketException('auth', 'AUTH_FAILED', 'Authentication failed.', message.uid);
|
||||
|
||||
handleWebSocketException(client, _error, 'auth');
|
||||
|
||||
if (this.authentication.mode !== 'public') {
|
||||
client.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setTokenExpireTimer(client: WebSocketClient) {
|
||||
if (client.auth_timer !== null) {
|
||||
// clear up old timeouts if needed
|
||||
clearTimeout(client.auth_timer);
|
||||
client.auth_timer = null;
|
||||
}
|
||||
|
||||
if (!client.expires_at) return;
|
||||
|
||||
const expiresIn = client.expires_at * 1000 - Date.now();
|
||||
if (expiresIn > TOKEN_CHECK_INTERVAL) return;
|
||||
|
||||
client.auth_timer = setTimeout(() => {
|
||||
client.accountability = null;
|
||||
client.expires_at = null;
|
||||
handleWebSocketException(client, new TokenExpiredException(), 'auth');
|
||||
|
||||
waitForMessageType(client, 'auth', this.authentication.timeout).catch((msg: WebSocketMessage) => {
|
||||
const error = new WebSocketException('auth', 'AUTH_TIMEOUT', 'Authentication timed out.', msg?.uid);
|
||||
handleWebSocketException(client, error, 'auth');
|
||||
|
||||
if (this.authentication.mode !== 'public') {
|
||||
client.close();
|
||||
}
|
||||
});
|
||||
}, expiresIn);
|
||||
}
|
||||
|
||||
checkClientTokens() {
|
||||
this.authInterval = setInterval(() => {
|
||||
if (this.clients.size === 0) return;
|
||||
|
||||
// check the clients and set shorter timeouts if needed
|
||||
for (const client of this.clients) {
|
||||
if (client.expires_at === null || client.auth_timer !== null) continue;
|
||||
this.setTokenExpireTimer(client);
|
||||
}
|
||||
}, TOKEN_CHECK_INTERVAL);
|
||||
}
|
||||
|
||||
terminate() {
|
||||
if (this.authInterval) clearInterval(this.authInterval);
|
||||
|
||||
this.clients.forEach((client) => {
|
||||
if (client.auth_timer) clearTimeout(client.auth_timer);
|
||||
});
|
||||
|
||||
this.server.clients.forEach((ws) => {
|
||||
ws.terminate();
|
||||
});
|
||||
}
|
||||
}
|
||||
121
api/src/websocket/controllers/graphql.ts
Normal file
121
api/src/websocket/controllers/graphql.ts
Normal file
@@ -0,0 +1,121 @@
|
||||
import { CloseCode, MessageType, makeServer } from 'graphql-ws';
|
||||
import type { Server } from 'graphql-ws';
|
||||
import type { Server as httpServer } from 'http';
|
||||
import type { WebSocket } from 'ws';
|
||||
import env from '../../env.js';
|
||||
import logger from '../../logger.js';
|
||||
import { bindPubSub } from '../../services/graphql/subscription.js';
|
||||
import { GraphQLService } from '../../services/index.js';
|
||||
import { getSchema } from '../../utils/get-schema.js';
|
||||
import { authenticateConnection, refreshAccountability } from '../authenticate.js';
|
||||
import { handleWebSocketException } from '../exceptions.js';
|
||||
import { ConnectionParams, WebSocketMessage } from '../messages.js';
|
||||
import type { AuthenticationState, GraphQLSocket, UpgradeContext, WebSocketClient } from '../types.js';
|
||||
import { getMessageType } from '../utils/message.js';
|
||||
import SocketController from './base.js';
|
||||
|
||||
export class GraphQLSubscriptionController extends SocketController {
|
||||
gql: Server<GraphQLSocket>;
|
||||
constructor(httpServer: httpServer) {
|
||||
super(httpServer, 'WEBSOCKETS_GRAPHQL');
|
||||
|
||||
this.server.on('connection', (ws: WebSocket, auth: AuthenticationState) => {
|
||||
this.bindEvents(this.createClient(ws, auth));
|
||||
});
|
||||
|
||||
this.gql = makeServer<ConnectionParams, GraphQLSocket>({
|
||||
schema: async (ctx) => {
|
||||
const accountability = ctx.extra.client.accountability;
|
||||
|
||||
// for now only the items will be watched, system events tbd
|
||||
const service = new GraphQLService({
|
||||
schema: await getSchema(),
|
||||
scope: 'items',
|
||||
accountability,
|
||||
});
|
||||
|
||||
return service.getSchema();
|
||||
},
|
||||
});
|
||||
|
||||
bindPubSub();
|
||||
logger.info(`GraphQL Subscriptions started at ws://${env['HOST']}:${env['PORT']}${this.endpoint}`);
|
||||
}
|
||||
|
||||
private bindEvents(client: WebSocketClient) {
|
||||
const closedHandler = this.gql.opened(
|
||||
{
|
||||
protocol: client.protocol,
|
||||
send: (data) =>
|
||||
new Promise((resolve, reject) => {
|
||||
client.send(data, (err) => (err ? reject(err) : resolve()));
|
||||
}),
|
||||
close: (code, reason) => client.close(code, reason), // for standard closures
|
||||
onMessage: (cb) => {
|
||||
client.on('parsed-message', async (message: WebSocketMessage) => {
|
||||
try {
|
||||
if (getMessageType(message) === 'connection_init' && this.authentication.mode !== 'strict') {
|
||||
const params = ConnectionParams.parse(message['payload']);
|
||||
|
||||
if (this.authentication.mode === 'handshake') {
|
||||
if (typeof params.access_token === 'string') {
|
||||
const { accountability, expires_at } = await authenticateConnection({
|
||||
access_token: params.access_token,
|
||||
});
|
||||
|
||||
client.accountability = accountability;
|
||||
client.expires_at = expires_at;
|
||||
} else {
|
||||
client.close(CloseCode.Forbidden, 'Forbidden');
|
||||
return;
|
||||
}
|
||||
}
|
||||
} else if (this.authentication.mode === 'handshake' && !client.accountability?.user) {
|
||||
// the first message should authenticate successfully in this mode
|
||||
client.close(CloseCode.Forbidden, 'Forbidden');
|
||||
return;
|
||||
} else {
|
||||
client.accountability = await refreshAccountability(client.accountability);
|
||||
}
|
||||
|
||||
await cb(JSON.stringify(message));
|
||||
} catch (error) {
|
||||
handleWebSocketException(client, error, MessageType.Error);
|
||||
}
|
||||
});
|
||||
},
|
||||
},
|
||||
{ client }
|
||||
);
|
||||
|
||||
// notify server that the socket closed
|
||||
client.once('close', (code, reason) => closedHandler(code, reason.toString()));
|
||||
|
||||
// check strict authentication status
|
||||
if (this.authentication.mode === 'strict' && !client.accountability?.user) {
|
||||
client.close(CloseCode.Forbidden, 'Forbidden');
|
||||
}
|
||||
}
|
||||
|
||||
override setTokenExpireTimer(client: WebSocketClient) {
|
||||
if (client.auth_timer !== null) {
|
||||
clearTimeout(client.auth_timer);
|
||||
client.auth_timer = null;
|
||||
}
|
||||
|
||||
if (this.authentication.mode !== 'handshake') return;
|
||||
|
||||
client.auth_timer = setTimeout(() => {
|
||||
if (!client.accountability?.user) {
|
||||
client.close(CloseCode.Forbidden, 'Forbidden');
|
||||
}
|
||||
}, this.authentication.timeout);
|
||||
}
|
||||
|
||||
protected override async handleHandshakeUpgrade({ request, socket, head }: UpgradeContext) {
|
||||
this.server.handleUpgrade(request, socket, head, async (ws) => {
|
||||
this.server.emit('connection', ws, { accountability: null, expires_at: null });
|
||||
// actual enforcement is handled by the setTokenExpireTimer function
|
||||
});
|
||||
}
|
||||
}
|
||||
140
api/src/websocket/controllers/hooks.ts
Normal file
140
api/src/websocket/controllers/hooks.ts
Normal file
@@ -0,0 +1,140 @@
|
||||
import emitter from '../../emitter.js';
|
||||
import { getMessenger } from '../../messenger.js';
|
||||
import type { WebSocketEvent } from '../messages.js';
|
||||
|
||||
let actionsRegistered = false;
|
||||
|
||||
export function registerWebSocketEvents() {
|
||||
if (actionsRegistered) return;
|
||||
actionsRegistered = true;
|
||||
|
||||
registerActionHooks([
|
||||
'items',
|
||||
'activity',
|
||||
'collections',
|
||||
'folders',
|
||||
'permissions',
|
||||
'presets',
|
||||
'revisions',
|
||||
'roles',
|
||||
'settings',
|
||||
'users',
|
||||
'webhooks',
|
||||
]);
|
||||
|
||||
registerFieldsHooks();
|
||||
registerFilesHooks();
|
||||
registerRelationsHooks();
|
||||
}
|
||||
|
||||
function registerActionHooks(modules: string[]) {
|
||||
// register event hooks that can be handled in an uniform manner
|
||||
for (const module of modules) {
|
||||
registerAction(module + '.create', ({ key, collection, payload = {} }) => ({
|
||||
collection,
|
||||
action: 'create',
|
||||
key,
|
||||
payload,
|
||||
}));
|
||||
|
||||
registerAction(module + '.update', ({ keys, collection, payload = {} }) => ({
|
||||
collection,
|
||||
action: 'update',
|
||||
keys,
|
||||
payload,
|
||||
}));
|
||||
|
||||
registerAction(module + '.delete', ({ keys, collection, payload = [] }) => ({
|
||||
collection,
|
||||
action: 'delete',
|
||||
keys,
|
||||
payload,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
function registerFieldsHooks() {
|
||||
// exception for field hooks that don't report `directus_fields` as being the collection
|
||||
registerAction('fields.create', ({ key, payload = {} }) => ({
|
||||
collection: 'directus_fields',
|
||||
action: 'create',
|
||||
key,
|
||||
payload,
|
||||
}));
|
||||
|
||||
registerAction('fields.update', ({ keys, payload = {} }) => ({
|
||||
collection: 'directus_fields',
|
||||
action: 'update',
|
||||
keys,
|
||||
payload,
|
||||
}));
|
||||
|
||||
registerAction('fields.delete', ({ keys, payload = [] }) => ({
|
||||
collection: 'directus_fields',
|
||||
action: 'delete',
|
||||
keys,
|
||||
payload,
|
||||
}));
|
||||
}
|
||||
|
||||
function registerFilesHooks() {
|
||||
// extra event for file uploads that doubles as create event
|
||||
registerAction('files.upload', ({ key, collection, payload = {} }) => ({
|
||||
collection,
|
||||
action: 'create',
|
||||
key,
|
||||
payload,
|
||||
}));
|
||||
|
||||
registerAction('files.update', ({ keys, collection, payload = {} }) => ({
|
||||
collection,
|
||||
action: 'update',
|
||||
keys,
|
||||
payload,
|
||||
}));
|
||||
|
||||
registerAction('files.delete', ({ keys, collection, payload = [] }) => ({
|
||||
collection,
|
||||
action: 'delete',
|
||||
keys,
|
||||
payload,
|
||||
}));
|
||||
}
|
||||
|
||||
function registerRelationsHooks() {
|
||||
// exception for relation hooks that don't report `directus_relations` as being the collection
|
||||
registerAction('relations.create', ({ key, payload = {} }) => ({
|
||||
collection: 'directus_relations',
|
||||
action: 'create',
|
||||
key,
|
||||
payload: { ...payload, key },
|
||||
}));
|
||||
|
||||
registerAction('relations.update', ({ keys, payload = {} }) => ({
|
||||
collection: 'directus_relations',
|
||||
action: 'update',
|
||||
keys,
|
||||
payload,
|
||||
}));
|
||||
|
||||
registerAction('relations.delete', ({ collection, payload = [] }) => ({
|
||||
collection: 'directus_relations',
|
||||
action: 'delete',
|
||||
keys: payload,
|
||||
payload: { collection, fields: payload },
|
||||
}));
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrapper for emitter.onAction to hook into system events
|
||||
* @param event The action event to watch
|
||||
* @param transform Transformer function
|
||||
*/
|
||||
function registerAction(event: string, transform: (args: Record<string, any>) => WebSocketEvent) {
|
||||
const messenger = getMessenger();
|
||||
|
||||
emitter.onAction(event, async (data: Record<string, any>) => {
|
||||
// push the event through the Redis pub/sub
|
||||
messenger.publish('websocket.event', transform(data) as Record<string, any>);
|
||||
});
|
||||
}
|
||||
44
api/src/websocket/controllers/index.ts
Normal file
44
api/src/websocket/controllers/index.ts
Normal file
@@ -0,0 +1,44 @@
|
||||
import type { Server as httpServer } from 'http';
|
||||
import env from '../../env.js';
|
||||
import { ServiceUnavailableException } from '../../index.js';
|
||||
import { toBoolean } from '../../utils/to-boolean.js';
|
||||
import { GraphQLSubscriptionController } from './graphql.js';
|
||||
import { WebSocketController } from './rest.js';
|
||||
|
||||
let websocketController: WebSocketController | undefined;
|
||||
let subscriptionController: GraphQLSubscriptionController | undefined;
|
||||
|
||||
export function createWebSocketController(server: httpServer) {
|
||||
if (toBoolean(env['WEBSOCKETS_REST_ENABLED'])) {
|
||||
websocketController = new WebSocketController(server);
|
||||
}
|
||||
}
|
||||
|
||||
export function getWebSocketController() {
|
||||
if (!toBoolean(env['WEBSOCKETS_ENABLED']) || !toBoolean(env['WEBSOCKETS_REST_ENABLED'])) {
|
||||
throw new ServiceUnavailableException('WebSocket server is disabled', {
|
||||
service: 'get-websocket-controller',
|
||||
});
|
||||
}
|
||||
|
||||
if (!websocketController) {
|
||||
throw new ServiceUnavailableException('WebSocket server is not initialized', {
|
||||
service: 'get-websocket-controller',
|
||||
});
|
||||
}
|
||||
|
||||
return websocketController;
|
||||
}
|
||||
|
||||
export function createSubscriptionController(server: httpServer) {
|
||||
if (toBoolean(env['WEBSOCKETS_GRAPHQL_ENABLED'])) {
|
||||
subscriptionController = new GraphQLSubscriptionController(server);
|
||||
}
|
||||
}
|
||||
|
||||
export function getSubscriptionController() {
|
||||
return subscriptionController;
|
||||
}
|
||||
|
||||
export * from './graphql.js';
|
||||
export * from './rest.js';
|
||||
58
api/src/websocket/controllers/rest.ts
Normal file
58
api/src/websocket/controllers/rest.ts
Normal file
@@ -0,0 +1,58 @@
|
||||
import { parseJSON } from '@directus/utils';
|
||||
import type { Server as httpServer } from 'http';
|
||||
import type WebSocket from 'ws';
|
||||
import emitter from '../../emitter.js';
|
||||
import env from '../../env.js';
|
||||
import logger from '../../logger.js';
|
||||
import { refreshAccountability } from '../authenticate.js';
|
||||
import { WebSocketException, handleWebSocketException } from '../exceptions.js';
|
||||
import { WebSocketMessage } from '../messages.js';
|
||||
import type { AuthenticationState, WebSocketClient } from '../types.js';
|
||||
import SocketController from './base.js';
|
||||
|
||||
export class WebSocketController extends SocketController {
|
||||
constructor(httpServer: httpServer) {
|
||||
super(httpServer, 'WEBSOCKETS_REST');
|
||||
|
||||
this.server.on('connection', (ws: WebSocket, auth: AuthenticationState) => {
|
||||
this.bindEvents(this.createClient(ws, auth));
|
||||
});
|
||||
|
||||
logger.info(`WebSocket Server started at ws://${env['HOST']}:${env['PORT']}${this.endpoint}`);
|
||||
}
|
||||
|
||||
private bindEvents(client: WebSocketClient) {
|
||||
client.on('parsed-message', async (message: WebSocketMessage) => {
|
||||
try {
|
||||
message = WebSocketMessage.parse(await emitter.emitFilter('websocket.message', message, { client }));
|
||||
client.accountability = await refreshAccountability(client.accountability);
|
||||
emitter.emitAction('websocket.message', { message, client });
|
||||
} catch (error) {
|
||||
handleWebSocketException(client, error, 'server');
|
||||
return;
|
||||
}
|
||||
});
|
||||
|
||||
client.on('error', (event: WebSocket.Event) => {
|
||||
emitter.emitAction('websocket.error', { client, event });
|
||||
});
|
||||
|
||||
client.on('close', (event: WebSocket.CloseEvent) => {
|
||||
emitter.emitAction('websocket.close', { client, event });
|
||||
});
|
||||
|
||||
emitter.emitAction('websocket.connect', { client });
|
||||
}
|
||||
|
||||
protected override parseMessage(data: string): WebSocketMessage {
|
||||
let message: WebSocketMessage;
|
||||
|
||||
try {
|
||||
message = parseJSON(data);
|
||||
} catch (err: any) {
|
||||
throw new WebSocketException('server', 'INVALID_PAYLOAD', 'Unable to parse the incoming message.');
|
||||
}
|
||||
|
||||
return message;
|
||||
}
|
||||
}
|
||||
106
api/src/websocket/exceptions.test.ts
Normal file
106
api/src/websocket/exceptions.test.ts
Normal file
@@ -0,0 +1,106 @@
|
||||
import { describe, expect, test, vi } from 'vitest';
|
||||
import type { WebSocketClient } from './types.js';
|
||||
import { BaseException } from '@directus/exceptions';
|
||||
import { InvalidPayloadException } from '../index.js';
|
||||
import { WebSocketException, handleWebSocketException } from './exceptions.js';
|
||||
import { ZodError } from 'zod';
|
||||
import logger from '../logger.js';
|
||||
|
||||
vi.mock('../logger');
|
||||
|
||||
function mockClient() {
|
||||
return {
|
||||
on: vi.fn(),
|
||||
off: vi.fn(),
|
||||
send: vi.fn(),
|
||||
close: vi.fn(),
|
||||
accountability: null,
|
||||
} as unknown as WebSocketClient;
|
||||
}
|
||||
|
||||
describe('WebSocketException', () => {
|
||||
test('with uid', () => {
|
||||
const error = new WebSocketException('type', 'code', 'message', 123);
|
||||
const response = error.toJSON();
|
||||
|
||||
expect(response).toStrictEqual({
|
||||
type: 'type',
|
||||
status: 'error',
|
||||
error: {
|
||||
code: 'code',
|
||||
message: 'message',
|
||||
},
|
||||
uid: 123,
|
||||
});
|
||||
|
||||
expect(error.toMessage()).toBe(JSON.stringify(response));
|
||||
});
|
||||
|
||||
test('without uid', () => {
|
||||
const error = new WebSocketException('type', 'code', 'message');
|
||||
const response = error.toJSON();
|
||||
|
||||
expect(response).toStrictEqual({
|
||||
type: 'type',
|
||||
status: 'error',
|
||||
error: {
|
||||
code: 'code',
|
||||
message: 'message',
|
||||
},
|
||||
});
|
||||
|
||||
expect(error.toMessage()).toBe(JSON.stringify(response));
|
||||
});
|
||||
});
|
||||
|
||||
describe('handleWebSocketException', () => {
|
||||
const type = 'testing';
|
||||
|
||||
test('handle BaseException', () => {
|
||||
const client = mockClient();
|
||||
const error = new BaseException('test', 200, '123');
|
||||
const expected = WebSocketException.fromException(error, type).toMessage();
|
||||
handleWebSocketException(client, error, type);
|
||||
expect(client.send).toBeCalledWith(expected);
|
||||
expect(logger.error).not.toBeCalled();
|
||||
});
|
||||
|
||||
test('handle InvalidPayloadException', () => {
|
||||
const client = mockClient();
|
||||
const error = new InvalidPayloadException('test');
|
||||
const expected = WebSocketException.fromException(error, type).toMessage();
|
||||
handleWebSocketException(client, error, type);
|
||||
expect(client.send).toBeCalledWith(expected);
|
||||
expect(logger.error).not.toBeCalled();
|
||||
});
|
||||
|
||||
test('handle WebSocketException', () => {
|
||||
const client = mockClient();
|
||||
const error = new WebSocketException('type', 'code', 'message', 123);
|
||||
const expected = error.toMessage();
|
||||
handleWebSocketException(client, error, type);
|
||||
expect(client.send).toBeCalledWith(expected);
|
||||
expect(logger.error).not.toBeCalled();
|
||||
});
|
||||
|
||||
test('handle ZodError', () => {
|
||||
const client = mockClient();
|
||||
|
||||
const error = new ZodError([
|
||||
{ message: 'test', code: 'invalid_type', path: ['path'], expected: 'array', received: 'string' },
|
||||
]);
|
||||
|
||||
const expected = WebSocketException.fromZodError(error, type).toMessage();
|
||||
handleWebSocketException(client, error, type);
|
||||
expect(client.send).toBeCalledWith(expected);
|
||||
expect(logger.error).not.toBeCalled();
|
||||
});
|
||||
|
||||
test('unhandled exception', () => {
|
||||
const client = mockClient();
|
||||
const error = new Error('regular error');
|
||||
handleWebSocketException(client, error, type);
|
||||
expect(client.send).not.toBeCalled();
|
||||
expect(logger.error).toBeCalled();
|
||||
});
|
||||
});
|
||||
69
api/src/websocket/exceptions.ts
Normal file
69
api/src/websocket/exceptions.ts
Normal file
@@ -0,0 +1,69 @@
|
||||
import { BaseException } from '@directus/exceptions';
|
||||
import type { WebSocket } from 'ws';
|
||||
import { ZodError } from 'zod';
|
||||
import { fromZodError } from 'zod-validation-error';
|
||||
import logger from '../logger.js';
|
||||
import type { WebSocketResponse } from './messages.js';
|
||||
import type { WebSocketClient } from './types.js';
|
||||
|
||||
export class WebSocketException extends Error {
|
||||
type: string;
|
||||
code: string;
|
||||
uid: string | number | undefined;
|
||||
constructor(type: string, code: string, message: string, uid?: string | number) {
|
||||
super(message);
|
||||
this.type = type;
|
||||
this.code = code;
|
||||
this.uid = uid;
|
||||
}
|
||||
|
||||
toJSON(): WebSocketResponse {
|
||||
const message: WebSocketResponse = {
|
||||
type: this.type,
|
||||
status: 'error',
|
||||
error: {
|
||||
code: this.code,
|
||||
message: this.message,
|
||||
},
|
||||
};
|
||||
|
||||
if (this.uid !== undefined) {
|
||||
message.uid = this.uid;
|
||||
}
|
||||
|
||||
return message;
|
||||
}
|
||||
|
||||
toMessage(): string {
|
||||
return JSON.stringify(this.toJSON());
|
||||
}
|
||||
|
||||
static fromException(error: BaseException, type = 'unknown') {
|
||||
return new WebSocketException(type, error.code, error.message);
|
||||
}
|
||||
|
||||
static fromZodError(error: ZodError, type = 'unknown') {
|
||||
const zError = fromZodError(error);
|
||||
return new WebSocketException(type, 'INVALID_PAYLOAD', zError.message);
|
||||
}
|
||||
}
|
||||
|
||||
export function handleWebSocketException(client: WebSocketClient | WebSocket, error: unknown, type?: string): void {
|
||||
if (error instanceof BaseException) {
|
||||
client.send(WebSocketException.fromException(error, type).toMessage());
|
||||
return;
|
||||
}
|
||||
|
||||
if (error instanceof WebSocketException) {
|
||||
client.send(error.toMessage());
|
||||
return;
|
||||
}
|
||||
|
||||
if (error instanceof ZodError) {
|
||||
client.send(WebSocketException.fromZodError(error, type).toMessage());
|
||||
return;
|
||||
}
|
||||
|
||||
// unhandled exceptions
|
||||
logger.error(`WebSocket unhandled exception ${JSON.stringify({ type, error })}`);
|
||||
}
|
||||
98
api/src/websocket/handlers/heartbeat.test.ts
Normal file
98
api/src/websocket/handlers/heartbeat.test.ts
Normal file
@@ -0,0 +1,98 @@
|
||||
import type { EventContext } from '@directus/types';
|
||||
import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest';
|
||||
import type { Mock } from 'vitest';
|
||||
import emitter from '../../emitter.js';
|
||||
import { WebSocketController, getWebSocketController } from '../controllers/index.js';
|
||||
import type { WebSocketClient } from '../types.js';
|
||||
import { HeartbeatHandler } from './heartbeat.js';
|
||||
|
||||
// mocking
|
||||
vi.mock('../controllers', () => ({
|
||||
getWebSocketController: vi.fn(() => ({
|
||||
clients: new Set(),
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('../../env', async () => {
|
||||
const actual = (await vi.importActual('../../env')) as { default: Record<string, any> };
|
||||
|
||||
const MOCK_ENV = {
|
||||
...actual.default,
|
||||
WEBSOCKETS_HEARTBEAT_PERIOD: 1,
|
||||
};
|
||||
|
||||
return {
|
||||
default: MOCK_ENV,
|
||||
getEnv: () => MOCK_ENV,
|
||||
};
|
||||
});
|
||||
|
||||
function mockClient() {
|
||||
return {
|
||||
on: vi.fn(),
|
||||
off: vi.fn(),
|
||||
send: vi.fn(),
|
||||
close: vi.fn(),
|
||||
} as unknown as WebSocketClient;
|
||||
}
|
||||
|
||||
describe('WebSocket heartbeat handler', () => {
|
||||
let controller: WebSocketController;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers();
|
||||
controller = getWebSocketController();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
test('client should ping', async () => {
|
||||
// initialize handler
|
||||
new HeartbeatHandler(controller);
|
||||
// connect fake client
|
||||
const fakeClient = mockClient();
|
||||
|
||||
(fakeClient.send as Mock).mockImplementation(() => {
|
||||
//respond with a message
|
||||
emitter.emitAction('websocket.message', { client: fakeClient, message: { type: 'pong' } }, {} as EventContext);
|
||||
});
|
||||
|
||||
controller.clients.add(fakeClient);
|
||||
emitter.emitAction('websocket.connect', {}, {} as EventContext);
|
||||
// wait for ping
|
||||
vi.advanceTimersByTime(1000); // 1sec heartbeat interval
|
||||
expect(fakeClient.send).toBeCalled();
|
||||
// wait for another timeout
|
||||
vi.advanceTimersByTime(1000); // 1sec heartbeat interval
|
||||
expect(fakeClient.send).toBeCalled();
|
||||
// the connection should not have been closed
|
||||
expect(fakeClient.close).not.toBeCalled();
|
||||
});
|
||||
|
||||
test('connection should be closed', async () => {
|
||||
// initialize handler
|
||||
new HeartbeatHandler(controller);
|
||||
// connect fake client
|
||||
const fakeClient = mockClient();
|
||||
controller.clients.add(fakeClient);
|
||||
emitter.emitAction('websocket.connect', {}, {} as EventContext);
|
||||
vi.advanceTimersByTime(2 * 1000); // 2x 1sec heartbeat interval
|
||||
expect(fakeClient.send).toBeCalled();
|
||||
// the connection should have been closed
|
||||
expect(fakeClient.close).toBeCalled();
|
||||
});
|
||||
|
||||
test('the server should pong if the client pings', async () => {
|
||||
// initialize handler
|
||||
new HeartbeatHandler(controller);
|
||||
// connect fake client
|
||||
const fakeClient = mockClient();
|
||||
controller.clients.add(fakeClient);
|
||||
emitter.emitAction('websocket.connect', {}, {} as EventContext);
|
||||
emitter.emitAction('websocket.message', { client: fakeClient, message: { type: 'ping' } }, {} as EventContext);
|
||||
expect(fakeClient.send).toBeCalled();
|
||||
});
|
||||
});
|
||||
87
api/src/websocket/handlers/heartbeat.ts
Normal file
87
api/src/websocket/handlers/heartbeat.ts
Normal file
@@ -0,0 +1,87 @@
|
||||
import type { ActionHandler } from '@directus/types';
|
||||
import emitter from '../../emitter.js';
|
||||
import env from '../../env.js';
|
||||
import { toBoolean } from '../../utils/to-boolean.js';
|
||||
import { WebSocketController, getWebSocketController } from '../controllers/index.js';
|
||||
import { WebSocketMessage } from '../messages.js';
|
||||
import type { WebSocketClient } from '../types.js';
|
||||
import { fmtMessage, getMessageType } from '../utils/message.js';
|
||||
|
||||
const HEARTBEAT_FREQUENCY = Number(env['WEBSOCKETS_HEARTBEAT_PERIOD']) * 1000;
|
||||
|
||||
export class HeartbeatHandler {
|
||||
private pulse: NodeJS.Timer | undefined;
|
||||
private controller: WebSocketController;
|
||||
|
||||
constructor(controller?: WebSocketController) {
|
||||
this.controller = controller ?? getWebSocketController();
|
||||
|
||||
emitter.onAction('websocket.message', ({ client, message }) => {
|
||||
try {
|
||||
this.onMessage(client, WebSocketMessage.parse(message));
|
||||
} catch {
|
||||
/* ignore errors */
|
||||
}
|
||||
});
|
||||
|
||||
if (toBoolean(env['WEBSOCKETS_HEARTBEAT_ENABLED']) === true) {
|
||||
emitter.onAction('websocket.connect', () => this.checkClients());
|
||||
emitter.onAction('websocket.error', () => this.checkClients());
|
||||
emitter.onAction('websocket.close', () => this.checkClients());
|
||||
}
|
||||
}
|
||||
|
||||
private checkClients() {
|
||||
const hasClients = this.controller.clients.size > 0;
|
||||
|
||||
if (hasClients && !this.pulse) {
|
||||
this.pulse = setInterval(() => {
|
||||
this.pingClients();
|
||||
}, HEARTBEAT_FREQUENCY);
|
||||
}
|
||||
|
||||
if (!hasClients && this.pulse) {
|
||||
clearInterval(this.pulse);
|
||||
this.pulse = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
onMessage(client: WebSocketClient, message: WebSocketMessage) {
|
||||
if (getMessageType(message) !== 'ping') return;
|
||||
// send pong message back as acknowledgement
|
||||
const data = 'uid' in message ? { uid: message.uid } : {};
|
||||
client.send(fmtMessage('pong', data));
|
||||
}
|
||||
|
||||
pingClients() {
|
||||
const pendingClients = new Set<WebSocketClient>(this.controller.clients);
|
||||
const activeClients = new Set<WebSocketClient>();
|
||||
|
||||
const timeout = setTimeout(() => {
|
||||
// close connections that haven't responded
|
||||
for (const client of pendingClients) {
|
||||
client.close();
|
||||
}
|
||||
}, HEARTBEAT_FREQUENCY);
|
||||
|
||||
const messageWatcher: ActionHandler = ({ client }) => {
|
||||
// any message means this connection is still open
|
||||
if (!activeClients.has(client)) {
|
||||
pendingClients.delete(client);
|
||||
activeClients.add(client);
|
||||
}
|
||||
|
||||
if (pendingClients.size === 0) {
|
||||
clearTimeout(timeout);
|
||||
emitter.offAction('websocket.message', messageWatcher);
|
||||
}
|
||||
};
|
||||
|
||||
emitter.onAction('websocket.message', messageWatcher);
|
||||
|
||||
// ping all the clients
|
||||
for (const client of pendingClients) {
|
||||
client.send(fmtMessage('ping'));
|
||||
}
|
||||
}
|
||||
}
|
||||
13
api/src/websocket/handlers/index.ts
Normal file
13
api/src/websocket/handlers/index.ts
Normal file
@@ -0,0 +1,13 @@
|
||||
import { HeartbeatHandler } from './heartbeat.js';
|
||||
import { ItemsHandler } from './items.js';
|
||||
import { SubscribeHandler } from './subscribe.js';
|
||||
|
||||
export function startWebSocketHandlers() {
|
||||
new HeartbeatHandler();
|
||||
new ItemsHandler();
|
||||
new SubscribeHandler();
|
||||
}
|
||||
|
||||
export * from './heartbeat.js';
|
||||
export * from './items.js';
|
||||
export * from './subscribe.js';
|
||||
295
api/src/websocket/handlers/items.test.ts
Normal file
295
api/src/websocket/handlers/items.test.ts
Normal file
@@ -0,0 +1,295 @@
|
||||
import type { EventContext } from '@directus/types';
|
||||
import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest';
|
||||
import type { Mock } from 'vitest';
|
||||
import emitter from '../../emitter.js';
|
||||
import { ItemsService, MetaService } from '../../services/index.js';
|
||||
import { getSchema } from '../../utils/get-schema.js';
|
||||
import type { WebSocketClient } from '../types.js';
|
||||
import { ItemsHandler } from './items.js';
|
||||
|
||||
// mocking
|
||||
vi.mock('../controllers', () => ({
|
||||
getWebSocketController: vi.fn(() => ({
|
||||
clients: new Set(),
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('../../utils/get-schema', () => ({
|
||||
getSchema: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../../services', () => ({
|
||||
ItemsService: vi.fn(),
|
||||
MetaService: vi.fn(),
|
||||
}));
|
||||
|
||||
function mockClient() {
|
||||
return {
|
||||
on: vi.fn(),
|
||||
off: vi.fn(),
|
||||
send: vi.fn(),
|
||||
close: vi.fn(),
|
||||
accountability: null,
|
||||
} as unknown as WebSocketClient;
|
||||
}
|
||||
|
||||
describe('WebSocket heartbeat handler', () => {
|
||||
let handler: ItemsHandler;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers();
|
||||
// initialize handler
|
||||
handler = new ItemsHandler();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
emitter.offAll();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
test('ignore other message types', async () => {
|
||||
const spy = vi.spyOn(handler, 'onMessage');
|
||||
|
||||
// receive message
|
||||
emitter.emitAction(
|
||||
'websocket.message',
|
||||
{
|
||||
client: mockClient(),
|
||||
message: { type: 'pong' },
|
||||
},
|
||||
{} as EventContext
|
||||
);
|
||||
|
||||
// expect nothing
|
||||
expect(spy).not.toBeCalled();
|
||||
});
|
||||
|
||||
test('invalid collection should error', async () => {
|
||||
(getSchema as Mock).mockImplementation(() => ({ collections: {} }));
|
||||
// receive message
|
||||
const fakeClient = mockClient();
|
||||
|
||||
emitter.emitAction(
|
||||
'websocket.message',
|
||||
{
|
||||
client: fakeClient,
|
||||
message: { type: 'items', collection: 'test', action: 'create', data: {} },
|
||||
},
|
||||
{} as EventContext
|
||||
);
|
||||
|
||||
await vi.runAllTimersAsync(); // flush promises to make sure the event is handled
|
||||
|
||||
// expect error
|
||||
expect(fakeClient.send).toBeCalledWith(
|
||||
'{"type":"items","status":"error","error":{"code":"INVALID_COLLECTION","message":"The provided collection does not exists or is not accessible."}}'
|
||||
);
|
||||
});
|
||||
|
||||
test('create one item', async () => {
|
||||
// do mocking
|
||||
(getSchema as Mock).mockImplementation(() => ({ collections: { test: [] } }));
|
||||
|
||||
const createOne = vi.fn(),
|
||||
readOne = vi.fn();
|
||||
|
||||
(ItemsService as Mock).mockImplementation(() => ({ createOne, readOne }));
|
||||
// receive message
|
||||
const fakeClient = mockClient();
|
||||
|
||||
emitter.emitAction(
|
||||
'websocket.message',
|
||||
{
|
||||
client: fakeClient,
|
||||
message: { type: 'items', collection: 'test', action: 'create', data: {} },
|
||||
},
|
||||
{} as EventContext
|
||||
);
|
||||
|
||||
await vi.runAllTimersAsync(); // flush promises to make sure the event is handled
|
||||
// expect service functions
|
||||
expect(createOne).toBeCalled();
|
||||
expect(readOne).toBeCalled();
|
||||
expect(fakeClient.send).toBeCalled();
|
||||
});
|
||||
|
||||
test('create multiple items', async () => {
|
||||
// do mocking
|
||||
(getSchema as Mock).mockImplementation(() => ({ collections: { test: [] } }));
|
||||
|
||||
const createMany = vi.fn(),
|
||||
readMany = vi.fn();
|
||||
|
||||
(ItemsService as Mock).mockImplementation(() => ({ createMany, readMany }));
|
||||
// receive message
|
||||
const fakeClient = mockClient();
|
||||
|
||||
emitter.emitAction(
|
||||
'websocket.message',
|
||||
{
|
||||
client: fakeClient,
|
||||
message: { type: 'items', collection: 'test', action: 'create', data: [{}, {}] },
|
||||
},
|
||||
{} as EventContext
|
||||
);
|
||||
|
||||
await vi.runAllTimersAsync(); // flush promises to make sure the event is handled
|
||||
// expect service functions
|
||||
expect(createMany).toBeCalled();
|
||||
expect(readMany).toBeCalled();
|
||||
expect(fakeClient.send).toBeCalled();
|
||||
});
|
||||
|
||||
test('read by query', async () => {
|
||||
// do mocking
|
||||
(getSchema as Mock).mockImplementation(() => ({ collections: { test: [] } }));
|
||||
const readByQuery = vi.fn();
|
||||
(ItemsService as Mock).mockImplementation(() => ({ readByQuery }));
|
||||
const getMetaForQuery = vi.fn();
|
||||
(MetaService as Mock).mockImplementation(() => ({ getMetaForQuery }));
|
||||
// receive message
|
||||
const fakeClient = mockClient();
|
||||
|
||||
emitter.emitAction(
|
||||
'websocket.message',
|
||||
{
|
||||
client: fakeClient,
|
||||
message: { type: 'items', collection: 'test', action: 'read', query: {} },
|
||||
},
|
||||
{} as EventContext
|
||||
);
|
||||
|
||||
await vi.runAllTimersAsync(); // flush promises to make sure the event is handled
|
||||
// expect service functions
|
||||
expect(readByQuery).toBeCalled();
|
||||
expect(getMetaForQuery).toBeCalled();
|
||||
expect(fakeClient.send).toBeCalled();
|
||||
});
|
||||
|
||||
test('update one item', async () => {
|
||||
// do mocking
|
||||
(getSchema as Mock).mockImplementation(() => ({ collections: { test: [] } }));
|
||||
|
||||
const updateOne = vi.fn(),
|
||||
readOne = vi.fn();
|
||||
|
||||
(ItemsService as Mock).mockImplementation(() => ({ updateOne, readOne }));
|
||||
// receive message
|
||||
const fakeClient = mockClient();
|
||||
|
||||
emitter.emitAction(
|
||||
'websocket.message',
|
||||
{
|
||||
client: fakeClient,
|
||||
message: { type: 'items', collection: 'test', action: 'update', data: {}, id: '123' },
|
||||
},
|
||||
{} as EventContext
|
||||
);
|
||||
|
||||
await vi.runAllTimersAsync(); // flush promises to make sure the event is handled
|
||||
// expect service functions
|
||||
expect(updateOne).toBeCalled();
|
||||
expect(readOne).toBeCalled();
|
||||
expect(fakeClient.send).toBeCalled();
|
||||
});
|
||||
|
||||
test('update multiple items', async () => {
|
||||
// do mocking
|
||||
(getSchema as Mock).mockImplementation(() => ({ collections: { test: [] } }));
|
||||
|
||||
const updateMany = vi.fn(),
|
||||
readMany = vi.fn();
|
||||
|
||||
(ItemsService as Mock).mockImplementation(() => ({ updateMany, readMany }));
|
||||
const getMetaForQuery = vi.fn();
|
||||
(MetaService as Mock).mockImplementation(() => ({ getMetaForQuery }));
|
||||
// receive message
|
||||
const fakeClient = mockClient();
|
||||
|
||||
emitter.emitAction(
|
||||
'websocket.message',
|
||||
{
|
||||
client: fakeClient,
|
||||
message: { type: 'items', collection: 'test', action: 'update', data: {}, ids: ['123', '456'] },
|
||||
},
|
||||
{} as EventContext
|
||||
);
|
||||
|
||||
await vi.runAllTimersAsync(); // flush promises to make sure the event is handled
|
||||
// expect service functions
|
||||
expect(updateMany).toBeCalled();
|
||||
expect(getMetaForQuery).toBeCalled();
|
||||
expect(readMany).toBeCalled();
|
||||
expect(fakeClient.send).toBeCalled();
|
||||
});
|
||||
|
||||
test('delete one item', async () => {
|
||||
// do mocking
|
||||
(getSchema as Mock).mockImplementation(() => ({ collections: { test: [] } }));
|
||||
const deleteOne = vi.fn();
|
||||
(ItemsService as Mock).mockImplementation(() => ({ deleteOne }));
|
||||
// receive message
|
||||
const fakeClient = mockClient();
|
||||
|
||||
emitter.emitAction(
|
||||
'websocket.message',
|
||||
{
|
||||
client: fakeClient,
|
||||
message: { type: 'items', collection: 'test', action: 'delete', id: '123' },
|
||||
},
|
||||
{} as EventContext
|
||||
);
|
||||
|
||||
await vi.runAllTimersAsync(); // flush promises to make sure the event is handled
|
||||
// expect service functions
|
||||
expect(deleteOne).toBeCalled();
|
||||
expect(fakeClient.send).toBeCalled();
|
||||
});
|
||||
|
||||
test('delete multiple items by id', async () => {
|
||||
// do mocking
|
||||
(getSchema as Mock).mockImplementation(() => ({ collections: { test: [] } }));
|
||||
const deleteMany = vi.fn();
|
||||
(ItemsService as Mock).mockImplementation(() => ({ deleteMany }));
|
||||
// receive message
|
||||
const fakeClient = mockClient();
|
||||
|
||||
emitter.emitAction(
|
||||
'websocket.message',
|
||||
{
|
||||
client: fakeClient,
|
||||
message: { type: 'items', collection: 'test', action: 'delete', ids: ['123', 456] },
|
||||
},
|
||||
{} as EventContext
|
||||
);
|
||||
|
||||
await vi.runAllTimersAsync(); // flush promises to make sure the event is handled
|
||||
// expect service functions
|
||||
expect(deleteMany).toBeCalled();
|
||||
expect(fakeClient.send).toBeCalled();
|
||||
});
|
||||
|
||||
test('delete multiple items by query', async () => {
|
||||
// do mocking
|
||||
(getSchema as Mock).mockImplementation(() => ({ collections: { test: [] } }));
|
||||
const deleteByQuery = vi.fn();
|
||||
(ItemsService as Mock).mockImplementation(() => ({ deleteByQuery }));
|
||||
// receive message
|
||||
const fakeClient = mockClient();
|
||||
|
||||
emitter.emitAction(
|
||||
'websocket.message',
|
||||
{
|
||||
client: fakeClient,
|
||||
message: { type: 'items', collection: 'test', action: 'delete', query: {} },
|
||||
},
|
||||
{} as EventContext
|
||||
);
|
||||
|
||||
await vi.runAllTimersAsync(); // flush promises to make sure the event is handled
|
||||
// expect service functions
|
||||
expect(deleteByQuery).toBeCalled();
|
||||
expect(fakeClient.send).toBeCalled();
|
||||
});
|
||||
});
|
||||
117
api/src/websocket/handlers/items.ts
Normal file
117
api/src/websocket/handlers/items.ts
Normal file
@@ -0,0 +1,117 @@
|
||||
import emitter from '../../emitter.js';
|
||||
import { ItemsService, MetaService } from '../../services/index.js';
|
||||
import { getSchema } from '../../utils/get-schema.js';
|
||||
import { sanitizeQuery } from '../../utils/sanitize-query.js';
|
||||
import { WebSocketException, handleWebSocketException } from '../exceptions.js';
|
||||
import { WebSocketItemsMessage } from '../messages.js';
|
||||
import type { WebSocketClient } from '../types.js';
|
||||
import { fmtMessage, getMessageType } from '../utils/message.js';
|
||||
|
||||
export class ItemsHandler {
|
||||
constructor() {
|
||||
emitter.onAction('websocket.message', ({ client, message }) => {
|
||||
if (getMessageType(message) !== 'items') return;
|
||||
|
||||
try {
|
||||
const parsedMessage = WebSocketItemsMessage.parse(message);
|
||||
|
||||
this.onMessage(client, parsedMessage).catch((err) => {
|
||||
// this catch is required because the async onMessage function is not awaited
|
||||
handleWebSocketException(client, err, 'items');
|
||||
});
|
||||
} catch (err) {
|
||||
handleWebSocketException(client, err, 'items');
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async onMessage(client: WebSocketClient, message: WebSocketItemsMessage) {
|
||||
const uid = message.uid;
|
||||
const accountability = client.accountability;
|
||||
const schema = await getSchema();
|
||||
|
||||
if (!schema.collections[message.collection] || message.collection.startsWith('directus_')) {
|
||||
throw new WebSocketException(
|
||||
'items',
|
||||
'INVALID_COLLECTION',
|
||||
'The provided collection does not exists or is not accessible.',
|
||||
uid
|
||||
);
|
||||
}
|
||||
|
||||
const isSingleton = !!schema.collections[message.collection]?.singleton;
|
||||
const service = new ItemsService(message.collection, { schema, accountability });
|
||||
const metaService = new MetaService({ schema, accountability });
|
||||
let result, meta;
|
||||
|
||||
if (message.action === 'create') {
|
||||
const query = sanitizeQuery(message?.query ?? {}, accountability);
|
||||
|
||||
if (Array.isArray(message.data)) {
|
||||
const keys = await service.createMany(message.data);
|
||||
result = await service.readMany(keys, query);
|
||||
} else {
|
||||
const key = await service.createOne(message.data);
|
||||
result = await service.readOne(key, query);
|
||||
}
|
||||
}
|
||||
|
||||
if (message.action === 'read') {
|
||||
const query = sanitizeQuery(message.query ?? {}, accountability);
|
||||
|
||||
if (message.id) {
|
||||
result = await service.readOne(message.id, query);
|
||||
} else if (message.ids) {
|
||||
result = await service.readMany(message.ids, query);
|
||||
} else if (isSingleton) {
|
||||
result = await service.readSingleton(query);
|
||||
} else {
|
||||
result = await service.readByQuery(query);
|
||||
}
|
||||
|
||||
meta = await metaService.getMetaForQuery(message.collection, query);
|
||||
}
|
||||
|
||||
if (message.action === 'update') {
|
||||
const query = sanitizeQuery(message.query ?? {}, accountability);
|
||||
|
||||
if (message.id) {
|
||||
const key = await service.updateOne(message.id, message.data);
|
||||
result = await service.readOne(key);
|
||||
} else if (message.ids) {
|
||||
const keys = await service.updateMany(message.ids, message.data);
|
||||
meta = await metaService.getMetaForQuery(message.collection, query);
|
||||
result = await service.readMany(keys, query);
|
||||
} else if (isSingleton) {
|
||||
await service.upsertSingleton(message.data);
|
||||
result = await service.readSingleton(query);
|
||||
} else {
|
||||
const keys = await service.updateByQuery(query, message.data);
|
||||
meta = await metaService.getMetaForQuery(message.collection, query);
|
||||
result = await service.readMany(keys, query);
|
||||
}
|
||||
}
|
||||
|
||||
if (message.action === 'delete') {
|
||||
if (message.id) {
|
||||
await service.deleteOne(message.id);
|
||||
result = message.id;
|
||||
} else if (message.ids) {
|
||||
await service.deleteMany(message.ids);
|
||||
result = message.ids;
|
||||
} else if (message.query) {
|
||||
const query = sanitizeQuery(message.query, accountability);
|
||||
result = await service.deleteByQuery(query);
|
||||
} else {
|
||||
throw new WebSocketException(
|
||||
'items',
|
||||
'INVALID_PAYLOAD',
|
||||
"Either 'ids', 'id' or 'query' is required for a DELETE request.",
|
||||
uid
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
client.send(fmtMessage('items', { data: result, ...(meta ? { meta } : {}) }, uid));
|
||||
}
|
||||
}
|
||||
282
api/src/websocket/handlers/subscribe.test.ts
Normal file
282
api/src/websocket/handlers/subscribe.test.ts
Normal file
@@ -0,0 +1,282 @@
|
||||
import { expect, describe, test, vi, beforeEach, afterEach } from 'vitest';
|
||||
import emitter from '../../emitter.js';
|
||||
import { SubscribeHandler } from './subscribe.js';
|
||||
import type { WebSocketClient } from '../types.js';
|
||||
import { getSchema } from '../../utils/get-schema.js';
|
||||
import type { CollectionsOverview, Relation } from '@directus/types';
|
||||
|
||||
// mocking
|
||||
vi.mock('../controllers', () => ({
|
||||
getWebSocketController: vi.fn(() => ({
|
||||
clients: new Set(),
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('../../utils/get-schema', () => ({
|
||||
getSchema: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../../services', () => ({
|
||||
ItemsService: vi.fn(() => ({
|
||||
readByQuery: vi.fn(),
|
||||
})),
|
||||
MetaService: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../../database/index');
|
||||
|
||||
function mockClient() {
|
||||
return {
|
||||
on: vi.fn(),
|
||||
off: vi.fn(),
|
||||
send: vi.fn(),
|
||||
close: vi.fn(),
|
||||
accountability: null,
|
||||
} as unknown as WebSocketClient;
|
||||
}
|
||||
|
||||
function delay(ms: number) {
|
||||
return new Promise<void>((resolve) => {
|
||||
setTimeout(() => resolve(), ms);
|
||||
});
|
||||
}
|
||||
|
||||
describe('WebSocket heartbeat handler', () => {
|
||||
let handler: SubscribeHandler;
|
||||
|
||||
beforeEach(() => {
|
||||
// initialize handler
|
||||
handler = new SubscribeHandler();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
emitter.offAll();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
test('ignore other message types', async () => {
|
||||
const spy = vi.spyOn(handler, 'onMessage');
|
||||
|
||||
// receive message
|
||||
emitter.emitAction('websocket.message', {
|
||||
client: mockClient(),
|
||||
message: { type: 'ping' },
|
||||
});
|
||||
|
||||
// expect nothing
|
||||
expect(spy).not.toBeCalled();
|
||||
});
|
||||
|
||||
test('should fail subscribe to non-existing collection', async () => {
|
||||
vi.mocked(getSchema).mockImplementation(async () => ({
|
||||
collections: {} as CollectionsOverview,
|
||||
relations: [] as Relation[],
|
||||
}));
|
||||
|
||||
const subscribe = vi.spyOn(handler, 'subscribe');
|
||||
const onMessage = vi.spyOn(handler, 'onMessage');
|
||||
|
||||
// receive message
|
||||
emitter.emitAction('websocket.message', {
|
||||
client: mockClient(),
|
||||
message: {
|
||||
type: 'subscribe',
|
||||
collection: 'does_not_exist',
|
||||
},
|
||||
});
|
||||
|
||||
await delay(10);
|
||||
|
||||
// expect
|
||||
expect(onMessage).toBeCalled();
|
||||
expect(subscribe).not.toBeCalled();
|
||||
});
|
||||
|
||||
test('should subscribe/unsubscribe to collection', async () => {
|
||||
const client = mockClient();
|
||||
|
||||
vi.mocked(getSchema).mockImplementation(async () => ({
|
||||
collections: {
|
||||
test_collection: {
|
||||
collection: 'test_collection',
|
||||
primary: 'id',
|
||||
singleton: false,
|
||||
sortField: null,
|
||||
note: null,
|
||||
accountability: null,
|
||||
fields: {},
|
||||
},
|
||||
} as CollectionsOverview,
|
||||
relations: [] as Relation[],
|
||||
}));
|
||||
|
||||
const subscribe = vi.spyOn(handler, 'subscribe');
|
||||
const onMessage = vi.spyOn(handler, 'onMessage');
|
||||
|
||||
// receive message
|
||||
emitter.emitAction('websocket.message', {
|
||||
client,
|
||||
message: {
|
||||
type: 'subscribe',
|
||||
collection: 'test_collection',
|
||||
uid: '123',
|
||||
},
|
||||
});
|
||||
|
||||
await delay(10);
|
||||
|
||||
// expect
|
||||
expect(onMessage).toBeCalled();
|
||||
expect(subscribe).toBeCalled();
|
||||
expect(handler.subscriptions['test_collection']?.size).toBe(1);
|
||||
});
|
||||
|
||||
test('unsubscribe a specific subscription', async () => {
|
||||
const client = mockClient();
|
||||
|
||||
vi.mocked(getSchema).mockImplementation(async () => ({
|
||||
collections: {
|
||||
test_collection: {
|
||||
collection: 'test_collection',
|
||||
primary: 'id',
|
||||
singleton: false,
|
||||
sortField: null,
|
||||
note: null,
|
||||
accountability: null,
|
||||
fields: {},
|
||||
},
|
||||
other_collection: {
|
||||
collection: 'other_collection',
|
||||
primary: 'id',
|
||||
singleton: false,
|
||||
sortField: null,
|
||||
note: null,
|
||||
accountability: null,
|
||||
fields: {},
|
||||
},
|
||||
} as CollectionsOverview,
|
||||
relations: [] as Relation[],
|
||||
}));
|
||||
|
||||
const unsubscribe = vi.spyOn(handler, 'unsubscribe');
|
||||
const subscribe = vi.spyOn(handler, 'subscribe');
|
||||
const onMessage = vi.spyOn(handler, 'onMessage');
|
||||
|
||||
// subscribe
|
||||
emitter.emitAction('websocket.message', {
|
||||
client,
|
||||
message: {
|
||||
type: 'subscribe',
|
||||
collection: 'test_collection',
|
||||
uid: '123',
|
||||
},
|
||||
});
|
||||
|
||||
emitter.emitAction('websocket.message', {
|
||||
client,
|
||||
message: {
|
||||
type: 'subscribe',
|
||||
collection: 'other_collection',
|
||||
uid: '456',
|
||||
},
|
||||
});
|
||||
|
||||
await delay(10);
|
||||
|
||||
// expect
|
||||
expect(onMessage).toBeCalledTimes(2);
|
||||
expect(subscribe).toBeCalledTimes(2);
|
||||
expect(handler.subscriptions['test_collection']?.size).toBe(1);
|
||||
expect(handler.subscriptions['other_collection']?.size).toBe(1);
|
||||
|
||||
// unsubscribe
|
||||
emitter.emitAction('websocket.message', {
|
||||
client,
|
||||
message: {
|
||||
type: 'unsubscribe',
|
||||
uid: '123',
|
||||
},
|
||||
});
|
||||
|
||||
await delay(10);
|
||||
|
||||
// expect
|
||||
expect(unsubscribe).toBeCalled();
|
||||
expect(handler.subscriptions['test_collection']?.size).toBe(0);
|
||||
expect(handler.subscriptions['other_collection']?.size).toBe(1);
|
||||
});
|
||||
|
||||
test('unsubscribe all subscriptions', async () => {
|
||||
const client = mockClient();
|
||||
|
||||
vi.mocked(getSchema).mockImplementation(async () => ({
|
||||
collections: {
|
||||
test_collection: {
|
||||
collection: 'test_collection',
|
||||
primary: 'id',
|
||||
singleton: false,
|
||||
sortField: null,
|
||||
note: null,
|
||||
accountability: null,
|
||||
fields: {},
|
||||
},
|
||||
other_collection: {
|
||||
collection: 'other_collection',
|
||||
primary: 'id',
|
||||
singleton: false,
|
||||
sortField: null,
|
||||
note: null,
|
||||
accountability: null,
|
||||
fields: {},
|
||||
},
|
||||
} as CollectionsOverview,
|
||||
relations: [] as Relation[],
|
||||
}));
|
||||
|
||||
const unsubscribe = vi.spyOn(handler, 'unsubscribe');
|
||||
const subscribe = vi.spyOn(handler, 'subscribe');
|
||||
const onMessage = vi.spyOn(handler, 'onMessage');
|
||||
|
||||
// subscribe
|
||||
emitter.emitAction('websocket.message', {
|
||||
client,
|
||||
message: {
|
||||
type: 'subscribe',
|
||||
collection: 'test_collection',
|
||||
uid: '123',
|
||||
},
|
||||
});
|
||||
|
||||
emitter.emitAction('websocket.message', {
|
||||
client,
|
||||
message: {
|
||||
type: 'subscribe',
|
||||
collection: 'other_collection',
|
||||
uid: '456',
|
||||
},
|
||||
});
|
||||
|
||||
await delay(10);
|
||||
|
||||
// expect
|
||||
expect(onMessage).toBeCalledTimes(2);
|
||||
expect(subscribe).toBeCalledTimes(2);
|
||||
expect(handler.subscriptions['test_collection']?.size).toBe(1);
|
||||
expect(handler.subscriptions['other_collection']?.size).toBe(1);
|
||||
|
||||
// unsubscribe
|
||||
emitter.emitAction('websocket.message', {
|
||||
client,
|
||||
message: {
|
||||
type: 'unsubscribe',
|
||||
},
|
||||
});
|
||||
|
||||
await delay(10);
|
||||
|
||||
// expect
|
||||
expect(unsubscribe).toBeCalled();
|
||||
expect(handler.subscriptions['test_collection']?.size).toBe(0);
|
||||
expect(handler.subscriptions['other_collection']?.size).toBe(0);
|
||||
});
|
||||
});
|
||||
342
api/src/websocket/handlers/subscribe.ts
Normal file
342
api/src/websocket/handlers/subscribe.ts
Normal file
@@ -0,0 +1,342 @@
|
||||
import type { Accountability, SchemaOverview } from '@directus/types';
|
||||
import emitter from '../../emitter.js';
|
||||
import { InvalidPayloadException } from '../../index.js';
|
||||
import { getMessenger } from '../../messenger.js';
|
||||
import type { Messenger } from '../../messenger.js';
|
||||
import { CollectionsService, FieldsService, MetaService } from '../../services/index.js';
|
||||
import { getSchema } from '../../utils/get-schema.js';
|
||||
import { getService } from '../../utils/get-service.js';
|
||||
import { sanitizeQuery } from '../../utils/sanitize-query.js';
|
||||
import { refreshAccountability } from '../authenticate.js';
|
||||
import { WebSocketException, handleWebSocketException } from '../exceptions.js';
|
||||
import type { WebSocketEvent } from '../messages.js';
|
||||
import { WebSocketSubscribeMessage } from '../messages.js';
|
||||
import type { Subscription, SubscriptionEvent, WebSocketClient } from '../types.js';
|
||||
import { fmtMessage, getMessageType } from '../utils/message.js';
|
||||
|
||||
/**
|
||||
* Handler responsible for subscriptions
|
||||
*/
|
||||
export class SubscribeHandler {
|
||||
// storage of subscriptions per collection
|
||||
subscriptions: Record<string, Set<Subscription>>;
|
||||
// internal message bus
|
||||
protected messenger: Messenger;
|
||||
/**
|
||||
* Initialize the handler
|
||||
*/
|
||||
constructor() {
|
||||
this.subscriptions = {};
|
||||
this.messenger = getMessenger();
|
||||
this.bindWebSocket();
|
||||
|
||||
// listen to the Redis pub/sub and dispatch
|
||||
this.messenger.subscribe('websocket.event', (message: Record<string, any>) => {
|
||||
try {
|
||||
this.dispatch(message as WebSocketEvent);
|
||||
} catch {
|
||||
// don't error on an invalid event from the messenger
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook into websocket client lifecycle events
|
||||
*/
|
||||
bindWebSocket() {
|
||||
// listen to incoming messages on the connected websockets
|
||||
emitter.onAction('websocket.message', ({ client, message }) => {
|
||||
if (!['subscribe', 'unsubscribe'].includes(getMessageType(message))) return;
|
||||
|
||||
try {
|
||||
this.onMessage(client, WebSocketSubscribeMessage.parse(message));
|
||||
} catch (error) {
|
||||
handleWebSocketException(client, error, 'subscribe');
|
||||
}
|
||||
});
|
||||
|
||||
// unsubscribe when a connection drops
|
||||
emitter.onAction('websocket.error', ({ client }) => this.unsubscribe(client));
|
||||
emitter.onAction('websocket.close', ({ client }) => this.unsubscribe(client));
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a subscription
|
||||
* @param subscription
|
||||
*/
|
||||
subscribe(subscription: Subscription) {
|
||||
const { collection } = subscription;
|
||||
|
||||
if ('item' in subscription && ['directus_fields', 'directus_relations'].includes(collection)) {
|
||||
throw new InvalidPayloadException(`Cannot subscribe to a specific item in the ${collection} collection.`);
|
||||
}
|
||||
|
||||
if (!this.subscriptions[collection]) {
|
||||
this.subscriptions[collection] = new Set();
|
||||
}
|
||||
|
||||
this.subscriptions[collection]?.add(subscription);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a subscription
|
||||
* @param subscription
|
||||
*/
|
||||
unsubscribe(client: WebSocketClient, uid?: string | number) {
|
||||
if (uid !== undefined) {
|
||||
const subscription = this.getSubscription(client, String(uid));
|
||||
|
||||
if (subscription) {
|
||||
this.subscriptions[subscription.collection]?.delete(subscription);
|
||||
}
|
||||
} else {
|
||||
for (const key of Object.keys(this.subscriptions)) {
|
||||
const subscriptions = Array.from(this.subscriptions[key] || []);
|
||||
|
||||
for (let i = subscriptions.length - 1; i >= 0; i--) {
|
||||
const subscription = subscriptions[i];
|
||||
if (!subscription) continue;
|
||||
|
||||
if (subscription.client === client && (!uid || subscription.uid === uid)) {
|
||||
this.subscriptions[key]?.delete(subscription);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Dispatch event to subscriptions
|
||||
*/
|
||||
async dispatch(event: WebSocketEvent) {
|
||||
const subscriptions = this.subscriptions[event.collection];
|
||||
if (!subscriptions || subscriptions.size === 0) return;
|
||||
const schema = await getSchema();
|
||||
|
||||
for (const subscription of subscriptions) {
|
||||
const { client } = subscription;
|
||||
|
||||
if (subscription.event !== undefined && event.action !== subscription.event) {
|
||||
continue; // skip filtered events
|
||||
}
|
||||
|
||||
try {
|
||||
client.accountability = await refreshAccountability(client.accountability);
|
||||
|
||||
const result =
|
||||
'item' in subscription
|
||||
? await this.getSinglePayload(subscription, client.accountability, schema, event)
|
||||
: await this.getMultiPayload(subscription, client.accountability, schema, event);
|
||||
|
||||
if (Array.isArray(result?.['data']) && result?.['data']?.length === 0) return;
|
||||
|
||||
client.send(fmtMessage('subscription', result, subscription.uid));
|
||||
} catch (err) {
|
||||
handleWebSocketException(client, err, 'subscribe');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle incoming (un)subscribe requests
|
||||
*/
|
||||
async onMessage(client: WebSocketClient, message: WebSocketSubscribeMessage) {
|
||||
if (getMessageType(message) === 'subscribe') {
|
||||
try {
|
||||
const collection = String(message.collection!);
|
||||
const accountability = client.accountability;
|
||||
const schema = await getSchema();
|
||||
|
||||
if (!accountability?.admin && !schema.collections[collection]) {
|
||||
throw new WebSocketException(
|
||||
'subscribe',
|
||||
'INVALID_COLLECTION',
|
||||
'The provided collection does not exists or is not accessible.',
|
||||
message.uid
|
||||
);
|
||||
}
|
||||
|
||||
const subscription: Subscription = {
|
||||
client,
|
||||
collection,
|
||||
};
|
||||
|
||||
if ('event' in message) {
|
||||
subscription.event = message.event as SubscriptionEvent;
|
||||
}
|
||||
|
||||
if ('query' in message) {
|
||||
subscription.query = sanitizeQuery(message.query!, accountability);
|
||||
}
|
||||
|
||||
if ('item' in message) subscription.item = String(message.item);
|
||||
|
||||
if ('uid' in message) {
|
||||
subscription.uid = String(message.uid);
|
||||
// remove the subscription if it already exists
|
||||
this.unsubscribe(client, subscription.uid);
|
||||
}
|
||||
|
||||
let data: Record<string, any>;
|
||||
|
||||
if (subscription.event === undefined) {
|
||||
data =
|
||||
'item' in subscription
|
||||
? await this.getSinglePayload(subscription, accountability, schema)
|
||||
: await this.getMultiPayload(subscription, accountability, schema);
|
||||
} else {
|
||||
data = { event: 'init' };
|
||||
}
|
||||
|
||||
// if no errors were thrown register the subscription
|
||||
this.subscribe(subscription);
|
||||
|
||||
// send an initial response
|
||||
client.send(fmtMessage('subscription', data, subscription.uid));
|
||||
} catch (err) {
|
||||
handleWebSocketException(client, err, 'subscribe');
|
||||
}
|
||||
}
|
||||
|
||||
if (getMessageType(message) === 'unsubscribe') {
|
||||
try {
|
||||
this.unsubscribe(client, message.uid);
|
||||
|
||||
client.send(fmtMessage('subscription', { event: 'unsubscribe' }, message.uid));
|
||||
} catch (err) {
|
||||
handleWebSocketException(client, err, 'unsubscribe');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async getSinglePayload(
|
||||
subscription: Subscription,
|
||||
accountability: Accountability | null,
|
||||
schema: SchemaOverview,
|
||||
event?: WebSocketEvent
|
||||
): Promise<Record<string, any>> {
|
||||
const metaService = new MetaService({ schema, accountability });
|
||||
const query = subscription.query ?? {};
|
||||
const id = subscription.item!;
|
||||
|
||||
const result: Record<string, any> = {
|
||||
event: event?.action ?? 'init',
|
||||
};
|
||||
|
||||
if (subscription.collection === 'directus_collections') {
|
||||
const service = new CollectionsService({ schema, accountability });
|
||||
result['data'] = await service.readOne(String(id));
|
||||
} else {
|
||||
const service = getService(subscription.collection, { schema, accountability });
|
||||
result['data'] = await service.readOne(id, query);
|
||||
}
|
||||
|
||||
if ('meta' in query) {
|
||||
result['meta'] = await metaService.getMetaForQuery(subscription.collection, query);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private async getMultiPayload(
|
||||
subscription: Subscription,
|
||||
accountability: Accountability | null,
|
||||
schema: SchemaOverview,
|
||||
event?: WebSocketEvent
|
||||
): Promise<Record<string, any>> {
|
||||
const metaService = new MetaService({ schema, accountability });
|
||||
|
||||
const result: Record<string, any> = {
|
||||
event: event?.action ?? 'init',
|
||||
};
|
||||
|
||||
switch (subscription.collection) {
|
||||
case 'directus_collections':
|
||||
result['data'] = await this.getCollectionPayload(accountability, schema, event);
|
||||
break;
|
||||
case 'directus_fields':
|
||||
result['data'] = await this.getFieldsPayload(accountability, schema, event);
|
||||
break;
|
||||
case 'directus_relations':
|
||||
result['data'] = event?.payload;
|
||||
break;
|
||||
default:
|
||||
result['data'] = await this.getItemsPayload(subscription, accountability, schema, event);
|
||||
break;
|
||||
}
|
||||
|
||||
const query = subscription.query ?? {};
|
||||
|
||||
if ('meta' in query) {
|
||||
result['meta'] = await metaService.getMetaForQuery(subscription.collection, query);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private async getCollectionPayload(
|
||||
accountability: Accountability | null,
|
||||
schema: SchemaOverview,
|
||||
event?: WebSocketEvent
|
||||
) {
|
||||
const service = new CollectionsService({ schema, accountability });
|
||||
|
||||
if (!event?.action) {
|
||||
return await service.readByQuery();
|
||||
} else if (event.action === 'create') {
|
||||
return await service.readMany([String(event.key)]);
|
||||
} else if (event.action === 'delete') {
|
||||
return event.keys;
|
||||
} else {
|
||||
return await service.readMany(event.keys.map((key: any) => String(key)));
|
||||
}
|
||||
}
|
||||
|
||||
private async getFieldsPayload(
|
||||
accountability: Accountability | null,
|
||||
schema: SchemaOverview,
|
||||
event?: WebSocketEvent
|
||||
) {
|
||||
const service = new FieldsService({ schema, accountability });
|
||||
|
||||
if (!event?.action) {
|
||||
return await service.readAll();
|
||||
} else if (event.action === 'delete') {
|
||||
return event.keys;
|
||||
} else {
|
||||
return await service.readOne(event.payload?.['collection'], event.payload?.['field']);
|
||||
}
|
||||
}
|
||||
|
||||
private async getItemsPayload(
|
||||
subscription: Subscription,
|
||||
accountability: Accountability | null,
|
||||
schema: SchemaOverview,
|
||||
event?: WebSocketEvent
|
||||
) {
|
||||
const query = subscription.query ?? {};
|
||||
const service = getService(subscription.collection, { schema, accountability });
|
||||
|
||||
if (!event?.action) {
|
||||
return await service.readByQuery(query);
|
||||
} else if (event.action === 'create') {
|
||||
return await service.readMany([event.key], query);
|
||||
} else if (event.action === 'delete') {
|
||||
return event.keys;
|
||||
} else {
|
||||
return await service.readMany(event.keys, query);
|
||||
}
|
||||
}
|
||||
|
||||
private getSubscription(client: WebSocketClient, uid: string | number) {
|
||||
for (const userSubscriptions of Object.values(this.subscriptions)) {
|
||||
for (const subscription of userSubscriptions) {
|
||||
if (subscription.client === client && subscription.uid === uid) {
|
||||
return subscription;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
118
api/src/websocket/messages.ts
Normal file
118
api/src/websocket/messages.ts
Normal file
@@ -0,0 +1,118 @@
|
||||
import type { Item, Query } from '@directus/types';
|
||||
import { z } from 'zod';
|
||||
|
||||
const zodStringOrNumber = z.union([z.string(), z.number()]);
|
||||
|
||||
export const WebSocketMessage = z
|
||||
.object({
|
||||
type: z.string(),
|
||||
uid: zodStringOrNumber.optional(),
|
||||
})
|
||||
.passthrough();
|
||||
export type WebSocketMessage = z.infer<typeof WebSocketMessage>;
|
||||
|
||||
export const WebSocketResponse = z.discriminatedUnion('status', [
|
||||
WebSocketMessage.extend({
|
||||
status: z.literal('ok'),
|
||||
}),
|
||||
WebSocketMessage.extend({
|
||||
status: z.literal('error'),
|
||||
error: z
|
||||
.object({
|
||||
code: z.string(),
|
||||
message: z.string(),
|
||||
})
|
||||
.passthrough(),
|
||||
}),
|
||||
]);
|
||||
export type WebSocketResponse = z.infer<typeof WebSocketResponse>;
|
||||
|
||||
export const ConnectionParams = z.object({ access_token: z.string().optional() });
|
||||
export type ConnectionParams = z.infer<typeof ConnectionParams>;
|
||||
|
||||
export const BasicAuthMessage = z.union([
|
||||
z.object({ email: z.string().email(), password: z.string() }),
|
||||
z.object({ access_token: z.string() }),
|
||||
z.object({ refresh_token: z.string() }),
|
||||
]);
|
||||
export type BasicAuthMessage = z.infer<typeof BasicAuthMessage>;
|
||||
|
||||
export const WebSocketAuthMessage = WebSocketMessage.extend({
|
||||
type: z.literal('auth'),
|
||||
}).and(BasicAuthMessage);
|
||||
export type WebSocketAuthMessage = z.infer<typeof WebSocketAuthMessage>;
|
||||
|
||||
export const WebSocketSubscribeMessage = z.discriminatedUnion('type', [
|
||||
WebSocketMessage.extend({
|
||||
type: z.literal('subscribe'),
|
||||
collection: z.string(),
|
||||
event: z.union([z.literal('create'), z.literal('update'), z.literal('delete')]).optional(),
|
||||
item: zodStringOrNumber.optional(),
|
||||
query: z.custom<Query>().optional(),
|
||||
}),
|
||||
WebSocketMessage.extend({
|
||||
type: z.literal('unsubscribe'),
|
||||
}),
|
||||
]);
|
||||
export type WebSocketSubscribeMessage = z.infer<typeof WebSocketSubscribeMessage>;
|
||||
|
||||
const ZodItem = z.custom<Partial<Item>>();
|
||||
|
||||
const PartialItemsMessage = z.object({
|
||||
uid: zodStringOrNumber.optional(),
|
||||
type: z.literal('items'),
|
||||
collection: z.string(),
|
||||
});
|
||||
|
||||
export const WebSocketItemsMessage = z.union([
|
||||
PartialItemsMessage.extend({
|
||||
action: z.literal('create'),
|
||||
data: z.union([z.array(ZodItem), ZodItem]),
|
||||
query: z.custom<Query>().optional(),
|
||||
}),
|
||||
PartialItemsMessage.extend({
|
||||
action: z.literal('read'),
|
||||
ids: z.array(zodStringOrNumber).optional(),
|
||||
id: zodStringOrNumber.optional(),
|
||||
query: z.custom<Query>().optional(),
|
||||
}),
|
||||
PartialItemsMessage.extend({
|
||||
action: z.literal('update'),
|
||||
data: ZodItem,
|
||||
ids: z.array(zodStringOrNumber).optional(),
|
||||
id: zodStringOrNumber.optional(),
|
||||
query: z.custom<Query>().optional(),
|
||||
}),
|
||||
PartialItemsMessage.extend({
|
||||
action: z.literal('delete'),
|
||||
ids: z.array(zodStringOrNumber).optional(),
|
||||
id: zodStringOrNumber.optional(),
|
||||
query: z.custom<Query>().optional(),
|
||||
}),
|
||||
]);
|
||||
export type WebSocketItemsMessage = z.infer<typeof WebSocketItemsMessage>;
|
||||
|
||||
export const WebSocketEvent = z.discriminatedUnion('action', [
|
||||
z.object({
|
||||
action: z.literal('create'),
|
||||
collection: z.string(),
|
||||
payload: z.record(z.any()).optional(),
|
||||
key: zodStringOrNumber,
|
||||
}),
|
||||
z.object({
|
||||
action: z.literal('update'),
|
||||
collection: z.string(),
|
||||
payload: z.record(z.any()).optional(),
|
||||
keys: z.array(zodStringOrNumber),
|
||||
}),
|
||||
z.object({
|
||||
action: z.literal('delete'),
|
||||
collection: z.string(),
|
||||
payload: z.record(z.any()).optional(),
|
||||
keys: z.array(zodStringOrNumber),
|
||||
}),
|
||||
]);
|
||||
export type WebSocketEvent = z.infer<typeof WebSocketEvent>;
|
||||
|
||||
export const AuthMode = z.union([z.literal('public'), z.literal('handshake'), z.literal('strict')]);
|
||||
export type AuthMode = z.infer<typeof AuthMode>;
|
||||
35
api/src/websocket/types.ts
Normal file
35
api/src/websocket/types.ts
Normal file
@@ -0,0 +1,35 @@
|
||||
import type { Accountability, Query } from '@directus/types';
|
||||
import type { IncomingMessage } from 'http';
|
||||
import type internal from 'stream';
|
||||
import type { WebSocket } from 'ws';
|
||||
|
||||
export type AuthenticationState = {
|
||||
accountability: Accountability | null;
|
||||
expires_at: number | null;
|
||||
refresh_token?: string;
|
||||
};
|
||||
|
||||
export type WebSocketClient = WebSocket &
|
||||
AuthenticationState & { uid: string | number; auth_timer: NodeJS.Timer | null };
|
||||
export type UpgradeRequest = IncomingMessage & AuthenticationState;
|
||||
|
||||
export type SubscriptionEvent = 'create' | 'update' | 'delete';
|
||||
|
||||
export type Subscription = {
|
||||
uid?: string | number;
|
||||
query?: Query;
|
||||
item?: string | number;
|
||||
event?: SubscriptionEvent;
|
||||
collection: string;
|
||||
client: WebSocketClient;
|
||||
};
|
||||
|
||||
export type UpgradeContext = {
|
||||
request: IncomingMessage;
|
||||
socket: internal.Duplex;
|
||||
head: Buffer;
|
||||
};
|
||||
|
||||
export type GraphQLSocket = {
|
||||
client: WebSocketClient;
|
||||
};
|
||||
23
api/src/websocket/utils/get-expires-at-for-token.test.ts
Normal file
23
api/src/websocket/utils/get-expires-at-for-token.test.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
import jwt from 'jsonwebtoken';
|
||||
import { describe, expect, test } from 'vitest';
|
||||
import { getExpiresAtForToken } from './get-expires-at-for-token.js';
|
||||
|
||||
describe('getExpiresAtForToken', () => {
|
||||
test('Returns null for non-jwt tokens', () => {
|
||||
const result = getExpiresAtForToken('not-a-jwt');
|
||||
expect(result).toBe(null);
|
||||
});
|
||||
|
||||
test('Returns null for jwt with no exp field', () => {
|
||||
const token = jwt.sign({ payload: 'content' }, 'secret', { issuer: 'tim' });
|
||||
const result = getExpiresAtForToken(token);
|
||||
expect(result).toBe(null);
|
||||
});
|
||||
|
||||
test('Returns expiresAt field for jwt with exp as number', () => {
|
||||
const now = Math.floor(Date.now() / 1000);
|
||||
const token = jwt.sign({ payload: 'content' }, 'secret', { expiresIn: 42 });
|
||||
const result = getExpiresAtForToken(token);
|
||||
expect(result).toBeGreaterThan(now);
|
||||
});
|
||||
});
|
||||
11
api/src/websocket/utils/get-expires-at-for-token.ts
Normal file
11
api/src/websocket/utils/get-expires-at-for-token.ts
Normal file
@@ -0,0 +1,11 @@
|
||||
import jwt from 'jsonwebtoken';
|
||||
|
||||
export function getExpiresAtForToken(token: string): number | null {
|
||||
const decoded = jwt.decode(token);
|
||||
|
||||
if (decoded && typeof decoded === 'object' && decoded.exp) {
|
||||
return decoded.exp;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
75
api/src/websocket/utils/message.test.ts
Normal file
75
api/src/websocket/utils/message.test.ts
Normal file
@@ -0,0 +1,75 @@
|
||||
import { describe, expect, test, vi } from 'vitest';
|
||||
import type { WebSocketClient } from '../types.js';
|
||||
import { fmtMessage, getMessageType, safeSend } from './message.js';
|
||||
|
||||
describe('fmtMessage util', () => {
|
||||
test('Returns formatted message', () => {
|
||||
const result = fmtMessage('test', { test: 'abc' });
|
||||
expect(result).toStrictEqual('{"type":"test","test":"abc"}');
|
||||
});
|
||||
|
||||
test('Returns formatted message with uid', () => {
|
||||
const result = fmtMessage('test', { test: 'abc' }, '123');
|
||||
expect(result).toStrictEqual('{"type":"test","test":"abc","uid":"123"}');
|
||||
});
|
||||
});
|
||||
|
||||
describe('safeSend util', () => {
|
||||
test('Ignore for closed connections', async () => {
|
||||
const fakeClient = {
|
||||
readyState: 3, // closed
|
||||
OPEN: 1,
|
||||
bufferedAmount: 0,
|
||||
send: vi.fn(),
|
||||
} as unknown as WebSocketClient;
|
||||
|
||||
const result = await safeSend(fakeClient, 'not used');
|
||||
expect(result).toBe(false);
|
||||
expect(fakeClient.send).not.toBeCalled();
|
||||
});
|
||||
|
||||
test('Wait for buffer', async () => {
|
||||
const fakeClient = {
|
||||
readyState: 1, // open
|
||||
OPEN: 1,
|
||||
bufferedAmount: 4,
|
||||
send: vi.fn(),
|
||||
};
|
||||
|
||||
setTimeout(() => {
|
||||
fakeClient.bufferedAmount = 0;
|
||||
}, 10);
|
||||
|
||||
const result = await safeSend(fakeClient as unknown as WebSocketClient, 'a message', 20);
|
||||
expect(result).toBe(true);
|
||||
expect(fakeClient.send).toBeCalledWith('a message');
|
||||
});
|
||||
|
||||
test('send message', async () => {
|
||||
const fakeClient = {
|
||||
readyState: 1, // open
|
||||
OPEN: 1,
|
||||
bufferedAmount: 0,
|
||||
send: vi.fn(),
|
||||
} as unknown as WebSocketClient;
|
||||
|
||||
const result = await safeSend(fakeClient as unknown as WebSocketClient, 'a message');
|
||||
expect(result).toBe(true);
|
||||
expect(fakeClient.send).toBeCalledWith('a message');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getMessageType util', () => {
|
||||
test('Fails graceously', () => {
|
||||
expect(getMessageType(null)).toBe('');
|
||||
expect(getMessageType(undefined)).toBe('');
|
||||
expect(getMessageType(false)).toBe('');
|
||||
expect(getMessageType(123456)).toBe('');
|
||||
expect(getMessageType([])).toBe('');
|
||||
});
|
||||
|
||||
test('Get the type property', () => {
|
||||
expect(getMessageType({ type: 'test' })).toBe('test');
|
||||
expect(getMessageType({ type: 123 })).toBe('123');
|
||||
});
|
||||
});
|
||||
34
api/src/websocket/utils/message.ts
Normal file
34
api/src/websocket/utils/message.ts
Normal file
@@ -0,0 +1,34 @@
|
||||
import type { WebSocketClient } from '../types.js';
|
||||
|
||||
// a simple util for building a message object
|
||||
export const fmtMessage = (type: string, data: Record<string, any> = {}, uid?: string | number) => {
|
||||
const message: Record<string, any> = { type, ...data };
|
||||
|
||||
if (uid !== undefined) {
|
||||
message['uid'] = uid;
|
||||
}
|
||||
|
||||
return JSON.stringify(message);
|
||||
};
|
||||
|
||||
// we may need this later for slow connections
|
||||
export const safeSend = async (client: WebSocketClient, data: string, delay = 100) => {
|
||||
if (client.readyState !== client.OPEN) return false;
|
||||
|
||||
if (client.bufferedAmount > 0) {
|
||||
// wait for the buffer to clear
|
||||
return new Promise((resolve) => {
|
||||
setTimeout(() => {
|
||||
safeSend(client, data, delay).then((success) => resolve(success));
|
||||
}, delay);
|
||||
});
|
||||
}
|
||||
|
||||
client.send(data);
|
||||
return true;
|
||||
};
|
||||
|
||||
// an often used message type extractor function
|
||||
export const getMessageType = (message: any): string => {
|
||||
return typeof message !== 'object' || Array.isArray(message) || message === null ? '' : String(message.type);
|
||||
};
|
||||
94
api/src/websocket/utils/wait-for-message.test.ts
Normal file
94
api/src/websocket/utils/wait-for-message.test.ts
Normal file
@@ -0,0 +1,94 @@
|
||||
import { describe, expect, test, vi } from 'vitest';
|
||||
import type { RawData, WebSocket } from 'ws';
|
||||
import { waitForAnyMessage, waitForMessageType } from './wait-for-message.js';
|
||||
|
||||
function bufferMessage(msg: any): RawData {
|
||||
return Buffer.from(JSON.stringify(msg));
|
||||
}
|
||||
|
||||
function mockClient(handler: (callback: (event: RawData) => void) => void) {
|
||||
return {
|
||||
on: vi.fn().mockImplementation((type: string, callback: (event: RawData) => void) => {
|
||||
if (type === 'message') handler(callback);
|
||||
}),
|
||||
off: vi.fn(),
|
||||
} as unknown as WebSocket;
|
||||
}
|
||||
|
||||
describe('Wait for messages', () => {
|
||||
test('should succeed, 5ms delay, 10ms timeout', async () => {
|
||||
const TEST_TIMEOUT = 10;
|
||||
const TEST_MSG = { type: 'test', id: 1 };
|
||||
|
||||
const fakeClient = mockClient((callback) => {
|
||||
setTimeout(() => {
|
||||
callback(bufferMessage(TEST_MSG));
|
||||
}, 5);
|
||||
});
|
||||
|
||||
const msg = await waitForAnyMessage(fakeClient, TEST_TIMEOUT);
|
||||
|
||||
expect(msg).toStrictEqual(TEST_MSG);
|
||||
});
|
||||
|
||||
test('should fail, 10ms delay, 5ms timeout', async () => {
|
||||
const TEST_TIMEOUT = 5;
|
||||
const TEST_MSG = { type: 'test', id: 1 };
|
||||
|
||||
const fakeClient = mockClient((callback) => {
|
||||
setTimeout(() => {
|
||||
callback(bufferMessage(TEST_MSG));
|
||||
}, 10);
|
||||
});
|
||||
|
||||
expect(() => waitForAnyMessage(fakeClient, TEST_TIMEOUT)).rejects.toBe(undefined);
|
||||
});
|
||||
|
||||
test('should fail parsing', async () => {
|
||||
const TEST_TIMEOUT = 5;
|
||||
|
||||
const fakeClient = mockClient((callback) => {
|
||||
setTimeout(() => {
|
||||
callback(Buffer.from('{invalid:json}'));
|
||||
}, 10);
|
||||
});
|
||||
|
||||
expect(() => waitForAnyMessage(fakeClient, TEST_TIMEOUT)).rejects.toBe(undefined);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Wait for specific types messages', () => {
|
||||
const MSG_A = { type: 'test', id: 1 };
|
||||
const MSG_B = { type: 'other', id: 2 };
|
||||
|
||||
test('should find the correct message', async () => {
|
||||
const fakeClient = mockClient((callback) => {
|
||||
setTimeout(() => callback(bufferMessage(MSG_B)), 5);
|
||||
setTimeout(() => callback(bufferMessage(MSG_A)), 10);
|
||||
});
|
||||
|
||||
const msg = await waitForMessageType(fakeClient, 'test', 15);
|
||||
|
||||
expect(msg).toStrictEqual(MSG_A);
|
||||
});
|
||||
|
||||
test('should fail, no matching type', async () => {
|
||||
const fakeClient = mockClient((callback) => {
|
||||
setTimeout(() => {
|
||||
callback(bufferMessage(MSG_B));
|
||||
}, 5);
|
||||
});
|
||||
|
||||
expect(() => waitForMessageType(fakeClient, 'test', 10)).rejects.toBe(undefined);
|
||||
});
|
||||
|
||||
test('should fail parsing', async () => {
|
||||
const fakeClient = mockClient((callback) => {
|
||||
setTimeout(() => {
|
||||
callback(bufferMessage({ id: 2 }));
|
||||
}, 5);
|
||||
});
|
||||
|
||||
expect(() => waitForMessageType(fakeClient, 'test', 10)).rejects.toBe(undefined);
|
||||
});
|
||||
});
|
||||
52
api/src/websocket/utils/wait-for-message.ts
Normal file
52
api/src/websocket/utils/wait-for-message.ts
Normal file
@@ -0,0 +1,52 @@
|
||||
import { parseJSON } from '@directus/utils';
|
||||
import type { RawData, WebSocket } from 'ws';
|
||||
import { WebSocketMessage } from '../messages.js';
|
||||
import { getMessageType } from './message.js';
|
||||
|
||||
export const waitForAnyMessage = (client: WebSocket, timeout: number): Promise<Record<string, any>> => {
|
||||
return new Promise((resolve, reject) => {
|
||||
client.on('message', awaitMessage);
|
||||
|
||||
const timer = setTimeout(() => {
|
||||
client.off('message', awaitMessage);
|
||||
reject();
|
||||
}, timeout);
|
||||
|
||||
function awaitMessage(event: RawData) {
|
||||
try {
|
||||
clearTimeout(timer);
|
||||
client.off('message', awaitMessage);
|
||||
resolve(parseJSON(event.toString()));
|
||||
} catch (err) {
|
||||
reject(err);
|
||||
}
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
export const waitForMessageType = (client: WebSocket, type: string, timeout: number): Promise<WebSocketMessage> => {
|
||||
return new Promise((resolve, reject) => {
|
||||
client.on('message', awaitMessage);
|
||||
|
||||
const timer = setTimeout(() => {
|
||||
client.off('message', awaitMessage);
|
||||
reject();
|
||||
}, timeout);
|
||||
|
||||
function awaitMessage(event: RawData) {
|
||||
let msg: WebSocketMessage;
|
||||
|
||||
try {
|
||||
msg = WebSocketMessage.parse(parseJSON(event.toString()));
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
|
||||
if (getMessageType(msg) === type) {
|
||||
clearTimeout(timer);
|
||||
client.off('message', awaitMessage);
|
||||
resolve(msg);
|
||||
}
|
||||
}
|
||||
});
|
||||
};
|
||||
Reference in New Issue
Block a user