Integrating Websockets in Directus 🕸️🧦 (#14737)

* added emitter context

* partial items tests

* updated items handler tests

* fixed test after merge

* forgot the event context

* fixed auth message parsing for graphql subscriptions

* fixed type strictness

* fixed graphql subscription bug

* bumped websocket dependencies

* touched up some dangling code

* updated itemsservice usage

* disabled overkill logs

* double checked environment type processing

* fixed missed capitalization

* fixed subscription payloads

* Added explicit string type casting

* removed obsolete "trimUpper" utility

* using the parseJSON utility consistently

* pinned dependencies

* parse environment variables

* fixed pnpm-lock

* GraphQL Subscriptions for all events

* fixed typo

* added event data to the graphql definition

* fix payload for delete events

* Added optional chaining for type to prevent fatal crashes on invalid messages

* fix failing on getting type from undefined

* Update api/src/websocket/exceptions.ts

Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com>

* Add proper ZodError handling

* added the zod-validation-error parser

* allow disabling the rate limiter

* Update api/src/websocket/controllers/base.ts

Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com>

* updated starting logs

* fixed email/password expiration logic

* added tests for getMessageType

* simplified message parsing and dropped capitalization

* updated authenticate test

* switched to lower cased message.type to prevent spreading "toUpperCase" around

* cleaned up debug logs

* cast enabled config to boolean

* Update api/src/websocket/controllers/rest.ts

Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com>

* Update api/src/websocket/handlers/subscribe.ts

Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com>

* Update api/src/websocket/handlers/subscribe.ts

Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com>

* Update api/src/websocket/handlers/items.ts

Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com>

* Update api/src/websocket/controllers/base.ts

Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com>

* Update api/src/websocket/handlers/heartbeat.ts

Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com>

* Suggested fixes by Azri

* removed redundant try-catch

* fixed authentication timeout
added returning the refresh token when authenticating

* updated pnpm lock after merge

* Fixed authentication modes for GraphQL according to best practices

* implement useFakeTimers in heartbeat unit test

* implement useFakeTimers in items unit test

* Update api/src/services/server.ts

Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com>

* removed obsolete authentication.verbose toggle

* added email flag to message validation

* switched to ternary for consistency

* moved getSchema out of for loop

* added singleton logic to items handler

* close the socket after failed auth for non-public connections

* disabled system collections for rest subscriptions

* re-ran pnpm i

* allow for multiple subscripitions in the memory messenger

* - fixed system collection subscriptions
- abstracted hook message bus
- fixed graphql horizontal scaling

* remove logic from root context for tests

* fix reading created item

* fix linter

* typo and extra safe guard suggested by azri

* prevent setting long timeouts in favor of a shared interval

* prevent unsubscribing all existing subscriptions when omitting "uid"

* - extracted getService utility
- block system collections mutation in the items handler
- implemented the correct services for system collections

* allow numeric uid's to be used

* fixed the types for numeric uid's to be used

* added missing await's

* fixed type imports after merge

* removed unused imports

* Update api/src/websocket/controllers/hooks.ts

Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com>

* Update api/src/websocket/controllers/hooks.ts

Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com>

* Update api/src/messenger.ts

Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com>

* improved error for graphql subscriptions

* fixed TS Modernization conflicts

* fixed TS Modernization conflicts

* fixed conflicts after merge

* removed unused name property

* abstracxted environment configuration

* respond to ping messages when heartbeat disabled

* something something merge

* moved toBoolean to it's own util file

* replaced old socket naming

* removed old exception

* fixed typo

* Update api/src/env.ts

Co-authored-by: ian <licitdev@gmail.com>

* Update api/src/websocket/handlers/heartbeat.test.ts

Co-authored-by: ian <licitdev@gmail.com>

* Update api/src/websocket/handlers/heartbeat.ts

Co-authored-by: ian <licitdev@gmail.com>

* Update api/src/services/server.ts

Co-authored-by: ian <licitdev@gmail.com>

* fixed for linter

* add server_info_websocket in graphql

* Add base REST websocket tests

* do merge things

* fixing things

* fixed failing unit test

* Update dependencies

* Move tests

* Update lockfile

* Use new paths when spawning

* return websockets to opt-in

* Enable websockets for tests

* Test with ephemeral access token

* no camelcasing gql subscriptions

* use underscore for gql event

* Remove unused import

* Add base GraphQL subscription tests

* Fix accidental comment

* Add some relational tests

* Organize imports

Using VS Code's default organize import

* Run ESlint formatting

* One more opinionated formatting change

* Formatting

* Fix message sequence not in order

* Remove relational batch update tests

* Test horizontal scaling

* using toboolean util for server_info

* removed unneeded type cast

* found the gql request type

* extra usage of the toBoolean util

* merge the authentication middleware and get-account-for-token util

* updated utility test

* fixed middleware unit test

* Add return

* Remove user filtering and close conns

* Fix reused accountability

* fixed failing util test

* added subscription unit tests

* added missing mock

* trigger workflow

* Revert "trigger workflow"

This reverts commit 4f544b0c1b.

* Trigger testing for all vendors

* add unsubscription confirmation

* Wait for unsubscription confirmation

* Fix incorrect sending of unsubscription confirmation

* updated ubsubscribe logic

* Update count for unsubscription message

* Fix sequence for UUID pktype in MSSQL

* Increase auth timeout

* Add start index when getting messages

* Fix subscription retrieval and cast uid to string

* Remove nested ternary

* Revert "Increase auth timeout"

This reverts commit 10707409c4.

* Terminate connection instead of close

* fixed merge

* re-added missing packages after merge resolve

* fixed type imports

* Create lazy-cows-happen.md

Added changeset

* Minor bump for "directus" package as well

* fixed "strict" auth mode for graphql subscriptions

* removed nested ternary

* Add websocket tests to sequential flow

* Disable pressure limiter for blackbox tests

* fix merge

* WebSockets Documentation (#18254)

* Small repsonsive styling fix on Card

* REST getting started guide

* Authentication guide

* REST subscription guides

* JS Chat guide

* Sidebar websocket guides section

* Added config options

* Respoinding to brainslug's review

* Fixed incorrect header on guides/rt/subs

* Fixed spellchecker

* Correct full code example on guides/rt/chat/js

* Fixed JS chat tut

* Order of steps in js chat guide updated for easier following-along

* Realtime chat Vue Guide

* feat: create react.js file

* feat: add set up for directus project

* docs: create react boilder plate

* docs: initialize connection

* docs: set up submission methods

* docs: establish websocket connection

* docs: subscribe to messages

* docs: create new messages

* docs: display new messages

* docs: display historical messages

* docs: next steps

* docs: full code sample

* docs: clean up

* docs: add name to contributors

* docs: add react card

* docs: updates to react chat

* Added live poll result guide

* docs: intro

* docs: before you begin

* docs: install packages

* docs: authenticate connection

* docs: query and mutation

* docs: utilize hooks

* docs: subscribe to changes

* docs: create helper functions

* docs: display messages

* docs: summary

* docs: full sample code

* chore: add card for webscockets with graphql

* docs: intro

* docs: subscribe to changes

* docs: handling changes

* docs: crud operations

* docs: unsubscribing from changes

* docs: updates

* chore: add card

* chore: updates to graphql docs

* chore: updates to getting started

* chore: updates to subscription

* chore: updates to real chat guide

* Added WebSockets Operations Guide

* Consistent titles

* Contributors component for docs

* Triggering Netlify

* Add operations to sidebar

* Fix operations link

* Small formatting changes

* Clarity around property values

* Removed unused values in Contributors component

* Prompt for default choice

* Tabs & lowercase doctypes

* Semicolons

* Event overwerites -> event listeners

* Spacing

* Flipped order of websockets guide to match GQL

---------

Co-authored-by: Esther Agbaje <folasadeagbaje@gmail.com>
Co-authored-by: Rijk van Zanten <rijkvanzanten@me.com>

* fixed typo

* removed unused import

* added tests for "to-boolean" and "exceptions"

* added websocket service tests

* quote environment variable to satisfy dictionary

* GraphQL Subscriptions update (#18804)

* updated graphql subscription structure

* updated graphql examples

* Create hungry-points-rescue.md

* using `key` instead of `ID` on the toplevel

* removed changeset

* fixed the graphql type after the rename to `key`

* retrun data null for delete events to prevent non-nullable gql error

* updated missed ID reference in the docs

* updated missed ID reference in the docs

* renamed "payload" to "data" in the REST Subscription response

* fixed missed reference to payload

* added optional event filter for REST subscriptions

* updated docs for event filter

* Update docs/guides/real-time/subscriptions/websockets.md

Co-authored-by: ian <licitdev@gmail.com>

---------

Co-authored-by: ian <licitdev@gmail.com>

* added messenger unit test

* always send subscription confirmation

* Add event to subscription options

* Update tests

* Add tests for event filtering

* Revert testing for all vendors

* Remove obsolete console comment

* Update comment

* Correct event in JS WS guide

* Fix collection name to match name used in subscription

* Fix collection name in other guides

* Fix diffs in doc & enhance chart example

* Complete sentence in GraphQL guide

* Small update to config description

---------

Co-authored-by: Rijk van Zanten <rijkvanzanten@me.com>
Co-authored-by: Azri Kahar <42867097+azrikahar@users.noreply.github.com>
Co-authored-by: ian <licitdev@gmail.com>
Co-authored-by: Nitwel <mail@nitwel.de>
Co-authored-by: Kevin Lewis <kvn@lws.io>
Co-authored-by: Pascal Jufer <pascal-jufer@bluewin.ch>
Co-authored-by: Esther Agbaje <folasadeagbaje@gmail.com>
This commit is contained in:
Brainslug
2023-06-08 20:54:50 +02:00
committed by GitHub
parent 0606b93fcc
commit cfddabd9ee
82 changed files with 15844 additions and 142 deletions

View File

@@ -112,6 +112,7 @@
"fs-extra": "11.1.1",
"graphql": "16.6.0",
"graphql-compose": "9.0.10",
"graphql-ws": "5.12.0",
"helmet": "7.0.0",
"icc": "3.0.0",
"inquirer": "9.2.4",
@@ -158,7 +159,10 @@
"uuid": "9.0.0",
"uuid-validate": "0.0.3",
"vm2": "3.9.19",
"wellknown": "0.5.0"
"wellknown": "0.5.0",
"ws": "8.12.1",
"zod": "3.21.4",
"zod-validation-error": "1.0.1"
},
"devDependencies": {
"@directus/tsconfig": "workspace:*",
@@ -199,6 +203,7 @@
"@types/uuid": "9.0.1",
"@types/uuid-validate": "0.0.1",
"@types/wellknown": "0.5.4",
"@types/ws": "8.5.4",
"@vitest/coverage-c8": "0.31.1",
"copyfiles": "2.4.1",
"form-data": "4.0.0",

View File

@@ -1,6 +1,7 @@
import type { ActionHandler, EventContext, FilterHandler, InitHandler } from '@directus/types';
import ee2 from 'eventemitter2';
import logger from './logger.js';
import getDatabase from './database/index.js';
export class Emitter {
private filterEmitter;
@@ -22,11 +23,19 @@ export class Emitter {
this.initEmitter = new ee2.EventEmitter2(emitterOptions);
}
private getDefaultContext(): EventContext {
return {
database: getDatabase(),
accountability: null,
schema: null,
};
}
public async emitFilter<T>(
event: string | string[],
payload: T,
meta: Record<string, any>,
context: EventContext
context: EventContext | null = null
): Promise<T> {
const events = Array.isArray(event) ? event : [event];
@@ -39,7 +48,7 @@ export class Emitter {
for (const { event, listeners } of eventListeners) {
for (const listener of listeners) {
const result = await listener(updatedPayload, { event, ...meta }, context);
const result = await listener(updatedPayload, { event, ...meta }, context ?? this.getDefaultContext());
if (result !== undefined) {
updatedPayload = result;
@@ -50,11 +59,11 @@ export class Emitter {
return updatedPayload;
}
public emitAction(event: string | string[], meta: Record<string, any>, context: EventContext): void {
public emitAction(event: string | string[], meta: Record<string, any>, context: EventContext | null = null): void {
const events = Array.isArray(event) ? event : [event];
for (const event of events) {
this.actionEmitter.emitAsync(event, { event, ...meta }, context).catch((err) => {
this.actionEmitter.emitAsync(event, { event, ...meta }, context ?? this.getDefaultContext()).catch((err) => {
logger.warn(`An error was thrown while executing action "${event}"`);
logger.warn(err);
});

View File

@@ -7,10 +7,10 @@ import { parseJSON, toArray } from '@directus/utils';
import dotenv from 'dotenv';
import fs from 'fs';
import { clone, toNumber, toString } from 'lodash-es';
import { createRequire } from 'node:module';
import path from 'path';
import { requireYAML } from './utils/require-yaml.js';
import { createRequire } from 'node:module';
import { toBoolean } from './utils/to-boolean.js';
const require = createRequire(import.meta.url);
@@ -206,6 +206,8 @@ const allowedEnvironmentVars = [
// flows
'FLOWS_EXEC_ALLOWED_MODULES',
'FLOWS_ENV_ALLOW_LIST',
// websockets
'WEBSOCKETS_.+',
].map((name) => new RegExp(`^${name}$`));
const acceptedEnvTypes = ['string', 'number', 'regex', 'array', 'json'];
@@ -306,6 +308,18 @@ const defaults: Record<string, any> = {
GRAPHQL_INTROSPECTION: true,
WEBSOCKETS_ENABLED: false,
WEBSOCKETS_REST_ENABLED: true,
WEBSOCKETS_REST_AUTH: 'handshake',
WEBSOCKETS_REST_AUTH_TIMEOUT: 10,
WEBSOCKETS_REST_PATH: '/websocket',
WEBSOCKETS_GRAPHQL_ENABLED: true,
WEBSOCKETS_GRAPHQL_AUTH: 'handshake',
WEBSOCKETS_GRAPHQL_AUTH_TIMEOUT: 10,
WEBSOCKETS_GRAPHQL_PATH: '/graphql',
WEBSOCKETS_HEARTBEAT_ENABLED: true,
WEBSOCKETS_HEARTBEAT_PERIOD: 30,
FLOWS_EXEC_ALLOWED_MODULES: false,
FLOWS_ENV_ALLOW_LIST: false,
@@ -571,7 +585,3 @@ function tryJSON(value: any) {
return value;
}
}
function toBoolean(value: any): boolean {
return value === 'true' || value === true || value === '1' || value === 1;
}

82
api/src/messenger.test.ts Normal file
View File

@@ -0,0 +1,82 @@
import { describe, expect, test, vi, beforeEach } from 'vitest';
import { getEnv } from './env.js';
import { MessengerMemory, MessengerRedis } from './messenger.js';
vi.mock('./env');
vi.mock('ioredis');
async function dynamicMessenger(mockedEnv: Record<string, any>) {
vi.mocked(getEnv).mockReturnValue(mockedEnv);
return await import('./messenger.js');
}
beforeEach(() => {
vi.resetModules();
});
describe('MessengerMemory', () => {
test('getMessenger', async () => {
const { MessengerMemory, getMessenger } = await dynamicMessenger({
MESSENGER_STORE: 'memory',
});
const messenger = getMessenger();
expect(messenger).toBeInstanceOf(MessengerMemory);
});
test('subscribing', () => {
const messages: Record<string, string>[] = [];
const testMessage = { test: 'test' };
const messenger = new MessengerMemory();
messenger.subscribe('test', (data: Record<string, string>) => {
messages.push(data);
});
messenger.publish('test', testMessage);
expect(messenger.handlers['test']?.size ?? 0).toBe(1);
expect(messages.length).toBe(1);
expect(messages).toStrictEqual([testMessage]);
messenger.unsubscribe('test');
messenger.publish('test', testMessage);
expect(messenger.handlers['test']?.size ?? 0).toBe(0);
expect(messages.length).toBe(1);
expect(messages).toStrictEqual([testMessage]);
});
});
describe('MessengerRedis', () => {
test('getMessenger', async () => {
const { MessengerRedis, getMessenger } = await dynamicMessenger({
MESSENGER_STORE: 'redis',
});
const messenger = getMessenger();
expect(messenger).toBeInstanceOf(MessengerRedis);
});
test('subscribing', () => {
const testMessage = { test: 'test' };
const messenger = new MessengerRedis();
messenger.subscribe('test', (_data: Record<string, string>) => {
// do nothing
});
expect(messenger.sub.subscribe).toBeCalled();
expect(messenger.sub.on).toBeCalled();
messenger.publish('test', testMessage);
expect(messenger.pub.publish).toBeCalled();
messenger.unsubscribe('test');
expect(messenger.sub.unsubscribe).toBeCalled();
});
});

View File

@@ -1,6 +1,6 @@
import { parseJSON } from '@directus/utils';
import { Redis } from 'ioredis';
import env from './env.js';
import { getEnv } from './env.js';
import { getConfigFromEnv } from './utils/get-config-from-env.js';
export type MessengerSubscriptionCallback = (payload: Record<string, any>) => void;
@@ -8,26 +8,31 @@ export type MessengerSubscriptionCallback = (payload: Record<string, any>) => vo
export interface Messenger {
publish: (channel: string, payload: Record<string, any>) => void;
subscribe: (channel: string, callback: MessengerSubscriptionCallback) => void;
unsubscribe: (channel: string) => void;
unsubscribe: (channel: string, callback?: MessengerSubscriptionCallback) => void;
}
export class MessengerMemory implements Messenger {
handlers: Record<string, MessengerSubscriptionCallback>;
handlers: Record<string, Set<MessengerSubscriptionCallback>>;
constructor() {
this.handlers = {};
}
publish(channel: string, payload: Record<string, any>) {
this.handlers[channel]?.(payload);
this.handlers[channel]?.forEach((callback) => callback(payload));
}
subscribe(channel: string, callback: MessengerSubscriptionCallback) {
this.handlers[channel] = callback;
if (!this.handlers[channel]) this.handlers[channel] = new Set();
this.handlers[channel]?.add(callback);
}
unsubscribe(channel: string) {
delete this.handlers[channel];
unsubscribe(channel: string, callback?: MessengerSubscriptionCallback) {
if (!callback) {
delete this.handlers[channel];
} else {
this.handlers[channel]?.delete(callback);
}
}
}
@@ -38,7 +43,7 @@ export class MessengerRedis implements Messenger {
constructor() {
const config = getConfigFromEnv('MESSENGER_REDIS');
const env = getEnv();
this.pub = new Redis(env['MESSENGER_REDIS'] ?? config);
this.sub = new Redis(env['MESSENGER_REDIS'] ?? config);
this.namespace = env['MESSENGER_NAMESPACE'] ?? 'directus';
@@ -69,6 +74,7 @@ let messenger: Messenger;
export function getMessenger() {
if (messenger) return messenger;
const env = getEnv();
if (env['MESSENGER_STORE'] === 'redis') {
messenger = new MessengerRedis();

View File

@@ -2,18 +2,19 @@ import type { Request, Response } from 'express';
import jwt from 'jsonwebtoken';
import type { Knex } from 'knex';
import { afterEach, expect, test, vi } from 'vitest';
import '../../src/types/express.d.ts';
import '../types/express.d.ts';
import getDatabase from '../database/index.js';
import emitter from '../emitter.js';
import env from '../env.js';
import { InvalidCredentialsException } from '../exceptions/invalid-credentials.js';
import { handler } from './authenticate.js';
vi.mock('../../src/database');
vi.mock('../database/index');
vi.mock('../../src/env', () => {
vi.mock('../env', () => {
const MOCK_ENV = {
SECRET: 'test',
EXTENSIONS_PATH: './extensions',
};
return {

View File

@@ -3,12 +3,9 @@ import type { NextFunction, Request, Response } from 'express';
import { isEqual } from 'lodash-es';
import getDatabase from '../database/index.js';
import emitter from '../emitter.js';
import env from '../env.js';
import { InvalidCredentialsException } from '../exceptions/index.js';
import asyncHandler from '../utils/async-handler.js';
import { getIPFromReq } from '../utils/get-ip-from-req.js';
import isDirectusJWT from '../utils/is-directus-jwt.js';
import { verifyAccessJWT } from '../utils/jwt.js';
import { getAccountabilityForToken } from '../utils/get-accountability-for-token.js';
/**
* Verify the passed JWT and assign the user ID and role to `req`
@@ -48,41 +45,7 @@ export const handler = async (req: Request, _res: Response, next: NextFunction)
return next();
}
req.accountability = defaultAccountability;
if (req.token) {
if (isDirectusJWT(req.token)) {
const payload = verifyAccessJWT(req.token, env['SECRET']);
req.accountability.role = payload.role;
req.accountability.admin = payload.admin_access === true || payload.admin_access == 1;
req.accountability.app = payload.app_access === true || payload.app_access == 1;
if (payload.share) req.accountability.share = payload.share;
if (payload.share_scope) req.accountability.share_scope = payload.share_scope;
if (payload.id) req.accountability.user = payload.id;
} else {
// Try finding the user with the provided token
const user = await database
.select('directus_users.id', 'directus_users.role', 'directus_roles.admin_access', 'directus_roles.app_access')
.from('directus_users')
.leftJoin('directus_roles', 'directus_users.role', 'directus_roles.id')
.where({
'directus_users.token': req.token,
status: 'active',
})
.first();
if (!user) {
throw new InvalidCredentialsException();
}
req.accountability.user = user.id;
req.accountability.role = user.role;
req.accountability.admin = user.admin_access === true || user.admin_access == 1;
req.accountability.app = user.app_access === true || user.app_access == 1;
}
}
req.accountability = await getAccountabilityForToken(req.token, defaultAccountability);
return next();
};

View File

@@ -12,6 +12,14 @@ import emitter from './emitter.js';
import env from './env.js';
import logger from './logger.js';
import { getConfigFromEnv } from './utils/get-config-from-env.js';
import {
createSubscriptionController,
createWebSocketController,
getSubscriptionController,
getWebSocketController,
} from './websocket/controllers/index.js';
import { startWebSocketHandlers } from './websocket/handlers/index.js';
import { toBoolean } from './utils/to-boolean.js';
export let SERVER_ONLINE = true;
@@ -82,6 +90,12 @@ export async function createServer(): Promise<http.Server> {
res.once('close', complete.bind(null, false));
});
if (toBoolean(env['WEBSOCKETS_ENABLED']) === true) {
createSubscriptionController(server);
createWebSocketController(server);
startWebSocketHandlers();
}
const terminusOptions: TerminusOptions = {
timeout:
env['SERVER_SHUTDOWN_TIMEOUT'] >= 0 && env['SERVER_SHUTDOWN_TIMEOUT'] < Infinity
@@ -106,6 +120,8 @@ export async function createServer(): Promise<http.Server> {
}
async function onSignal() {
getSubscriptionController()?.terminate();
getWebSocketController()?.terminate();
const database = getDatabase();
await database.destroy();

View File

@@ -62,24 +62,13 @@ import { AuthenticationService } from '../authentication.js';
import { CollectionsService } from '../collections.js';
import { FieldsService } from '../fields.js';
import { FilesService } from '../files.js';
import { FlowsService } from '../flows.js';
import { FoldersService } from '../folders.js';
import { ItemsService } from '../items.js';
import { NotificationsService } from '../notifications.js';
import { OperationsService } from '../operations.js';
import { PermissionsService } from '../permissions.js';
import { PresetsService } from '../presets.js';
import { RelationsService } from '../relations.js';
import { RevisionsService } from '../revisions.js';
import { RolesService } from '../roles.js';
import { ServerService } from '../server.js';
import { SettingsService } from '../settings.js';
import { SharesService } from '../shares.js';
import { SpecificationService } from '../specifications.js';
import { TFAService } from '../tfa.js';
import { UsersService } from '../users.js';
import { UtilsService } from '../utils.js';
import { WebhooksService } from '../webhooks.js';
import { GraphQLBigInt } from './types/bigint.js';
import { GraphQLDate } from './types/date.js';
import { GraphQLGeoJSON } from './types/geojson.js';
@@ -88,6 +77,9 @@ import { GraphQLStringOrFloat } from './types/string-or-float.js';
import { GraphQLVoid } from './types/void.js';
import { addPathToValidationError } from './utils/add-path-to-validation-error.js';
import processError from './utils/process-error.js';
import { createSubscriptionGenerator } from './subscription.js';
import { getService } from '../../utils/get-service.js';
import { toBoolean } from '../../utils/to-boolean.js';
const validationRules = Array.from(specifiedRules);
@@ -198,6 +190,15 @@ export class GraphQLService {
: reduceSchema(this.schema, this.accountability?.permissions || null, ['delete']),
};
const subscriptionEventType = schemaComposer.createEnumTC({
name: 'EventEnum',
values: {
create: { value: 'create' },
update: { value: 'update' },
delete: { value: 'delete' },
},
});
const { ReadCollectionTypes } = getReadableTypes();
const { CreateCollectionTypes, UpdateCollectionTypes, DeleteCollectionTypes } = getWritableTypes();
@@ -1066,6 +1067,29 @@ export class GraphQLService {
},
});
}
const eventName = `${collection.collection}_mutated`;
if (collection.collection in ReadCollectionTypes) {
const subscriptionType = schemaComposer.createObjectTC({
name: eventName,
fields: {
key: new GraphQLNonNull(GraphQLID),
event: subscriptionEventType,
data: ReadCollectionTypes[collection.collection]!,
},
});
schemaComposer.Subscription.addFields({
[eventName]: {
type: subscriptionType,
args: {
event: subscriptionEventType,
},
subscribe: createSubscriptionGenerator(self, eventName),
},
});
}
}
for (const relation of schema.read.relations) {
@@ -1411,7 +1435,12 @@ export class GraphQLService {
return await this.upsertSingleton(collection, args['data'], query);
}
const service = this.getService(collection);
const service = getService(collection, {
knex: this.knex,
accountability: this.accountability,
schema: this.schema,
});
const hasQuery = (query.fields || []).length > 0;
try {
@@ -1466,7 +1495,11 @@ export class GraphQLService {
* Execute the read action on the correct service. Checks for singleton as well.
*/
async read(collection: string, query: Query): Promise<Partial<Item>> {
const service = this.getService(collection);
const service = getService(collection, {
knex: this.knex,
accountability: this.accountability,
schema: this.schema,
});
const result = this.schema.collections[collection]!.singleton
? await service.readSingleton(query, { stripNonRequested: false })
@@ -1483,7 +1516,11 @@ export class GraphQLService {
body: Record<string, any> | Record<string, any>[],
query: Query
): Promise<Partial<Item> | boolean> {
const service = this.getService(collection);
const service = getService(collection, {
knex: this.knex,
accountability: this.accountability,
schema: this.schema,
});
try {
await service.upsertSingleton(body);
@@ -1737,51 +1774,6 @@ export class GraphQLService {
return new GraphQLError(error.message, undefined, undefined, undefined, undefined, error);
}
/**
* Select the correct service for the given collection. This allows the individual services to run
* their custom checks (f.e. it allows UsersService to prevent updating TFA secret from outside)
*/
getService(collection: string): ItemsService {
const opts = {
knex: this.knex,
accountability: this.accountability,
schema: this.schema,
};
switch (collection) {
case 'directus_activity':
return new ActivityService(opts);
case 'directus_files':
return new FilesService(opts);
case 'directus_folders':
return new FoldersService(opts);
case 'directus_permissions':
return new PermissionsService(opts);
case 'directus_presets':
return new PresetsService(opts);
case 'directus_notifications':
return new NotificationsService(opts);
case 'directus_revisions':
return new RevisionsService(opts);
case 'directus_roles':
return new RolesService(opts);
case 'directus_settings':
return new SettingsService(opts);
case 'directus_users':
return new UsersService(opts);
case 'directus_webhooks':
return new WebhooksService(opts);
case 'directus_shares':
return new SharesService(opts);
case 'directus_flows':
return new FlowsService(opts);
case 'directus_operations':
return new OperationsService(opts);
default:
return new ItemsService(collection, opts);
}
}
/**
* Replace all fragments in a selectionset for the actual selection set as defined in the fragment
* Effectively merges the selections with the fragments used in those selections
@@ -1907,6 +1899,58 @@ export class GraphQLService {
},
}),
},
websocket: toBoolean(env['WEBSOCKETS_ENABLED'])
? {
type: new GraphQLObjectType({
name: 'server_info_websocket',
fields: {
rest: {
type: toBoolean(env['WEBSOCKETS_REST_ENABLED'])
? new GraphQLObjectType({
name: 'server_info_websocket_rest',
fields: {
authentication: {
type: new GraphQLEnumType({
name: 'server_info_websocket_rest_authentication',
values: {
public: { value: 'public' },
handshake: { value: 'handshake' },
strict: { value: 'strict' },
},
}),
},
path: { type: GraphQLString },
},
})
: GraphQLBoolean,
},
graphql: {
type: toBoolean(env['WEBSOCKETS_GRAPHQL_ENABLED'])
? new GraphQLObjectType({
name: 'server_info_websocket_graphql',
fields: {
authentication: {
type: new GraphQLEnumType({
name: 'server_info_websocket_graphql_authentication',
values: {
public: { value: 'public' },
handshake: { value: 'handshake' },
strict: { value: 'strict' },
},
}),
},
path: { type: GraphQLString },
},
})
: GraphQLBoolean,
},
heartbeat: {
type: toBoolean(env['WEBSOCKETS_HEARTBEAT_ENABLED']) ? GraphQLInt : GraphQLBoolean,
},
},
}),
}
: GraphQLBoolean,
queryLimit: {
type: new GraphQLObjectType({
name: 'server_info_query_limit',

View File

@@ -0,0 +1,103 @@
import { EventEmitter, on } from 'events';
import { getMessenger } from '../../messenger.js';
import type { GraphQLService } from './index.js';
import { getSchema } from '../../utils/get-schema.js';
import { ItemsService } from '../items.js';
import type { Query } from '@directus/types';
import type { GraphQLResolveInfo, SelectionNode } from 'graphql';
const messages = createPubSub(new EventEmitter());
export function bindPubSub() {
const messenger = getMessenger();
messenger.subscribe('websocket.event', (message: Record<string, any>) => {
messages.publish(`${message['collection']}_mutated`, message);
});
}
export function createSubscriptionGenerator(self: GraphQLService, event: string) {
return async function* (_x: unknown, _y: unknown, _z: unknown, request: GraphQLResolveInfo) {
const fields = parseFields(self, request);
const args = parseArguments(request);
for await (const payload of messages.subscribe(event)) {
const eventData = payload as Record<string, any>;
if ('event' in args && eventData['action'] !== args['event']) {
continue; // skip filtered events
}
const schema = await getSchema();
if (eventData['action'] === 'create') {
const { collection, key } = eventData;
const service = new ItemsService(collection, { schema });
const data = await service.readOne(key, { fields } as Query);
yield { [event]: { key, data, event: 'create' } };
}
if (eventData['action'] === 'update') {
const { collection, keys } = eventData;
const service = new ItemsService(collection, { schema });
for (const key of keys) {
const data = await service.readOne(key, { fields } as Query);
yield { [event]: { key, data, event: 'update' } };
}
}
if (eventData['action'] === 'delete') {
const { keys } = eventData;
for (const key of keys) {
yield { [event]: { key, data: null, event: 'delete' } };
}
}
}
};
}
function createPubSub<P extends { [key: string]: unknown }>(emitter: EventEmitter) {
return {
publish: <T extends Extract<keyof P, string>>(event: T, payload: P[T]) =>
void emitter.emit(event as string, payload),
subscribe: async function* <T extends Extract<keyof P, string>>(event: T): AsyncIterableIterator<P[T]> {
const asyncIterator = on(emitter, event);
for await (const [value] of asyncIterator) {
yield value;
}
},
};
}
function parseFields(service: GraphQLService, request: GraphQLResolveInfo) {
const selections = request.fieldNodes[0]?.selectionSet?.selections ?? [];
const dataSelections = selections.reduce((result: readonly SelectionNode[], selection: SelectionNode) => {
if (
selection.kind === 'Field' &&
selection.name.value === 'data' &&
selection.selectionSet?.kind === 'SelectionSet'
) {
return selection.selectionSet.selections;
}
return result;
}, []);
const { fields } = service.getQuery({}, dataSelections, request.variableValues);
return fields ?? [];
}
function parseArguments(request: GraphQLResolveInfo) {
const args = request.fieldNodes[0]?.arguments ?? [];
return args.reduce((result, current) => {
if ('value' in current.value && typeof current.value.value === 'string') {
result[current.name.value] = current.value.value;
}
return result;
}, {} as Record<string, string>);
}

View File

@@ -16,6 +16,7 @@ import { getStorage } from '../storage/index.js';
import type { AbstractServiceOptions } from '../types/index.js';
import { version } from '../utils/package.js';
import { SettingsService } from './settings.js';
import { toBoolean } from '../utils/to-boolean.js';
export class ServerService {
knex: Knex;
@@ -78,6 +79,32 @@ export class ServerService {
};
}
if (this.accountability?.user) {
if (toBoolean(env['WEBSOCKETS_ENABLED'])) {
info['websocket'] = {};
info['websocket'].rest = toBoolean(env['WEBSOCKETS_REST_ENABLED'])
? {
authentication: env['WEBSOCKETS_REST_AUTH'],
path: env['WEBSOCKETS_REST_PATH'],
}
: false;
info['websocket'].graphql = toBoolean(env['WEBSOCKETS_GRAPHQL_ENABLED'])
? {
authentication: env['WEBSOCKETS_GRAPHQL_AUTH'],
path: env['WEBSOCKETS_GRAPHQL_PATH'],
}
: false;
info['websocket'].heartbeat = toBoolean(env['WEBSOCKETS_HEARTBEAT_ENABLED'])
? env['WEBSOCKETS_HEARTBEAT_PERIOD']
: false;
} else {
info['websocket'] = false;
}
}
return info;
}

View File

@@ -0,0 +1,69 @@
import { describe, expect, test, vi } from 'vitest';
import type { WebSocketClient } from '../websocket/types.js';
import { WebSocketController, getWebSocketController } from '../websocket/controllers/index.js';
import type { Accountability } from '@directus/types';
import { WebSocketService } from './websocket.js';
vi.mock('../emitter');
vi.mock('../websocket/controllers/index');
function mockClient(accountability: Accountability | null = null) {
return {
on: vi.fn(),
off: vi.fn(),
send: vi.fn(),
close: vi.fn(),
accountability,
} as unknown as WebSocketClient;
}
describe('WebSocketService', () => {
test('get clients', () => {
vi.mocked(getWebSocketController).mockReturnValue({
clients: new Set([mockClient(), mockClient(), mockClient()]),
} as unknown as WebSocketController);
const wsService = new WebSocketService();
expect(wsService.clients().size).toBe(3);
});
test('broadcast', () => {
const clients = new Set([mockClient(), mockClient(), mockClient()]);
const message = 'test 123';
vi.mocked(getWebSocketController).mockReturnValue({ clients } as unknown as WebSocketController);
const wsService = new WebSocketService();
wsService.broadcast(message);
for (const client of clients) {
expect(client.send).toBeCalledWith(message);
}
});
test('broadcast with role filter', () => {
const clients = [mockClient({ user: 'test', role: 'test' }), mockClient({ user: 'test2', role: 'test2' })];
const message = 'test 123';
vi.mocked(getWebSocketController).mockReturnValue({ clients: new Set(clients) } as unknown as WebSocketController);
const wsService = new WebSocketService();
wsService.broadcast(message, { role: 'test' });
expect(clients[0]!.send).toBeCalledWith(message);
expect(clients[1]!.send).not.toBeCalled();
});
test('broadcast with user filter', () => {
const clients = [mockClient({ user: 'test', role: 'test' }), mockClient({ user: 'test2', role: 'test2' })];
const message = 'test 123';
vi.mocked(getWebSocketController).mockReturnValue({ clients: new Set(clients) } as unknown as WebSocketController);
const wsService = new WebSocketService();
wsService.broadcast(message, { user: 'test2' });
expect(clients[0]!.send).not.toBeCalled();
expect(clients[1]!.send).toBeCalledWith(message);
});
});

View File

@@ -0,0 +1,34 @@
import type { ActionHandler } from '@directus/types';
import { getWebSocketController } from '../websocket/controllers/index.js';
import type { WebSocketController } from '../websocket/controllers/rest.js';
import type { WebSocketClient } from '../websocket/types.js';
import type { WebSocketMessage } from '../websocket/messages.js';
import emitter from '../emitter.js';
export class WebSocketService {
private controller: WebSocketController;
constructor() {
this.controller = getWebSocketController();
}
on(event: 'connect' | 'message' | 'error' | 'close', callback: ActionHandler) {
emitter.onAction('websocket.' + event, callback);
}
off(event: 'connect' | 'message' | 'error' | 'close', callback: ActionHandler) {
emitter.offAction('websocket.' + event, callback);
}
broadcast(message: string | WebSocketMessage, filter?: { user?: string; role?: string }) {
this.controller.clients.forEach((client: WebSocketClient) => {
if (filter && filter.user && filter.user !== client.accountability?.user) return;
if (filter && filter.role && filter.role !== client.accountability?.role) return;
client.send(typeof message === 'string' ? message : JSON.stringify(message));
});
}
clients(): Set<WebSocketClient> {
return this.controller.clients;
}
}

View File

@@ -0,0 +1,87 @@
import { expect, describe, test, vi } from 'vitest';
import { getAccountabilityForRole } from './get-accountability-for-role.js';
vi.mock('./get-permissions', () => ({
getPermissions: vi.fn().mockReturnValue([]),
}));
function mockDatabase() {
const self: Record<string, any> = {
select: vi.fn(() => self),
from: vi.fn(() => self),
where: vi.fn(() => self),
first: vi.fn(),
};
return self;
}
describe('getAccountabilityForRole', async () => {
test('no role', async () => {
const result = await getAccountabilityForRole(null, {
accountability: null,
schema: {} as any,
database: vi.fn() as any,
});
expect(result).toStrictEqual({
admin: false,
app: false,
permissions: [],
role: null,
user: null,
});
});
test('system role', async () => {
const result = await getAccountabilityForRole('system', {
accountability: null,
schema: {} as any,
database: vi.fn() as any,
});
expect(result).toStrictEqual({
admin: true,
app: true,
permissions: [],
role: null,
user: null,
});
});
test('get role from database', async () => {
const db = mockDatabase();
db['first'].mockReturnValue({
admin_access: 'not true',
app_access: '1',
});
const result = await getAccountabilityForRole('123-456', {
accountability: null,
schema: {} as any,
database: db as any,
});
expect(result).toStrictEqual({
admin: false,
app: true,
permissions: [],
role: '123-456',
user: null,
});
});
test('database invalid role', async () => {
const db = mockDatabase();
db['first'].mockReturnValue(false);
expect(() =>
getAccountabilityForRole('456-789', {
accountability: null,
schema: {} as any,
database: db as any,
})
).rejects.toThrow('Configured role "456-789" isn\'t a valid role ID or doesn\'t exist.');
});
});

View File

@@ -0,0 +1,101 @@
import { expect, describe, test, vi } from 'vitest';
import jwt from 'jsonwebtoken';
import env from '../env.js';
import { getAccountabilityForToken } from './get-accountability-for-token.js';
import getDatabase from '../database/index.js';
vi.mock('../env', () => {
const MOCK_ENV = {
SECRET: 'super-secure-secret',
EXTENSIONS_PATH: './extensions',
};
return {
default: MOCK_ENV,
getEnv: () => MOCK_ENV,
};
});
vi.mock('../database/index', () => {
const self: Record<string, any> = {
select: vi.fn(() => self),
from: vi.fn(() => self),
leftJoin: vi.fn(() => self),
where: vi.fn(() => self),
first: vi.fn(),
};
return { default: vi.fn(() => self) };
});
describe('getAccountabilityForToken', async () => {
test('minimal token payload', async () => {
const token = jwt.sign({ role: '123-456-789', app_access: false, admin_access: false }, env['SECRET'], {
issuer: 'directus',
});
const result = await getAccountabilityForToken(token);
expect(result).toStrictEqual({ admin: false, app: false, role: '123-456-789', user: null });
});
test('full token payload', async () => {
const token = jwt.sign(
{
share: 'share-id',
share_scope: 'share-scope',
id: 'user-id',
role: 'role-id',
admin_access: 1,
app_access: 1,
},
env['SECRET'],
{ issuer: 'directus' }
);
const result = await getAccountabilityForToken(token);
expect(result.admin).toBe(true);
expect(result.app).toBe(true);
expect(result.role).toBe('role-id');
expect(result.share).toBe('share-id');
expect(result.share_scope).toBe('share-scope');
expect(result.user).toBe('user-id');
});
test('throws token expired error', async () => {
const token = jwt.sign({ role: '123-456-789' }, env['SECRET'], { issuer: 'directus', expiresIn: -1 });
expect(() => getAccountabilityForToken(token)).rejects.toThrow('Token expired.');
});
test('throws token invalid error', async () => {
const token = jwt.sign({ role: '123-456-789' }, 'bad-secret', { issuer: 'directus' });
expect(() => getAccountabilityForToken(token)).rejects.toThrow('Token invalid.');
});
test('find user in database', async () => {
const db = getDatabase();
vi.spyOn(db, 'first').mockReturnValue({
id: 'user-id',
role: 'role-id',
admin_access: false,
app_access: true,
} as any);
const token = jwt.sign({ role: '123-456-789' }, 'bad-secret');
const result = await getAccountabilityForToken(token);
expect(result).toStrictEqual({
user: 'user-id',
role: 'role-id',
admin: false,
app: true,
});
});
test('no user found', async () => {
const db = getDatabase();
vi.spyOn(db, 'first').mockReturnValue(false as any);
const token = jwt.sign({ role: '123-456-789' }, 'bad-secret');
expect(() => getAccountabilityForToken(token)).rejects.toThrow('Invalid user credentials.');
});
});

View File

@@ -0,0 +1,58 @@
import getDatabase from '../database/index.js';
import type { Accountability } from '@directus/types';
import isDirectusJWT from './is-directus-jwt.js';
import { InvalidCredentialsException } from '../index.js';
import env from '../env.js';
import { verifyAccessJWT } from './jwt.js';
export async function getAccountabilityForToken(
token?: string | null,
accountability?: Accountability
): Promise<Accountability> {
if (!accountability) {
accountability = {
user: null,
role: null,
admin: false,
app: false,
};
}
if (token) {
if (isDirectusJWT(token)) {
const payload = verifyAccessJWT(token, env['SECRET'] as string);
accountability.role = payload.role;
accountability.admin = payload.admin_access === true || payload.admin_access == 1;
accountability.app = payload.app_access === true || payload.app_access == 1;
if (payload.share) accountability.share = payload.share;
if (payload.share_scope) accountability.share_scope = payload.share_scope;
if (payload.id) accountability.user = payload.id;
} else {
// Try finding the user with the provided token
const database = getDatabase();
const user = await database
.select('directus_users.id', 'directus_users.role', 'directus_roles.admin_access', 'directus_roles.app_access')
.from('directus_users')
.leftJoin('directus_roles', 'directus_users.role', 'directus_roles.id')
.where({
'directus_users.token': token,
status: 'active',
})
.first();
if (!user) {
throw new InvalidCredentialsException();
}
accountability.user = user.id;
accountability.role = user.role;
accountability.admin = user.admin_access === true || user.admin_access == 1;
accountability.app = user.app_access === true || user.app_access == 1;
}
}
return accountability;
}

View File

@@ -0,0 +1,69 @@
import {
ActivityService,
DashboardsService,
FilesService,
FlowsService,
FoldersService,
ItemsService,
NotificationsService,
OperationsService,
PanelsService,
PermissionsService,
PresetsService,
RevisionsService,
RolesService,
SettingsService,
SharesService,
UsersService,
WebhooksService,
} from '../index.js';
import type { AbstractServiceOptions } from '../types/services.js';
/**
* Select the correct service for the given collection. This allows the individual services to run
* their custom checks (f.e. it allows UsersService to prevent updating TFA secret from outside)
*/
export function getService(collection: string, opts: AbstractServiceOptions): ItemsService {
switch (collection) {
case 'directus_activity':
return new ActivityService(opts);
// case 'directus_collections':
// return new CollectionsService(opts);
case 'directus_dashboards':
return new DashboardsService(opts);
// case 'directus_fields':
// return new FieldsService(opts);
case 'directus_files':
return new FilesService(opts);
case 'directus_flows':
return new FlowsService(opts);
case 'directus_folders':
return new FoldersService(opts);
case 'directus_notifications':
return new NotificationsService(opts);
case 'directus_operations':
return new OperationsService(opts);
case 'directus_panels':
return new PanelsService(opts);
case 'directus_permissions':
return new PermissionsService(opts);
case 'directus_presets':
return new PresetsService(opts);
// case 'directus_relations':
// return new RelationsService(opts);
case 'directus_revisions':
return new RevisionsService(opts);
case 'directus_roles':
return new RolesService(opts);
case 'directus_settings':
return new SettingsService(opts);
case 'directus_shares':
return new SharesService(opts);
case 'directus_users':
return new UsersService(opts);
case 'directus_webhooks':
return new WebhooksService(opts);
default:
return new ItemsService(collection, opts);
}
}

View File

@@ -0,0 +1,16 @@
import { expect, test } from 'vitest';
import { toBoolean } from './to-boolean.js';
test.each([
['true', true],
[true, true],
['1', true],
[1, true],
['false', false],
['anything', false],
[123, false],
[{}, false],
[['{}'], false],
])('toBoolean(%s) -> %s', (value, expected) => {
expect(toBoolean(value)).toBe(expected);
});

View File

@@ -0,0 +1,6 @@
/**
* Convert environment variable to Boolean
*/
export function toBoolean(value: any): boolean {
return value === 'true' || value === true || value === '1' || value === 1;
}

View File

@@ -0,0 +1,137 @@
import type { Accountability } from '@directus/types';
import { describe, expect, test, vi } from 'vitest';
import type { Mock } from 'vitest';
import { InvalidCredentialsException } from '../index.js';
import { getAccountabilityForRole } from '../utils/get-accountability-for-role.js';
import { getAccountabilityForToken } from '../utils/get-accountability-for-token.js';
import { authenticateConnection, authenticationSuccess, refreshAccountability } from './authenticate.js';
import type { WebSocketAuthMessage } from './messages.js';
import { getExpiresAtForToken } from './utils/get-expires-at-for-token.js';
vi.mock('../utils/get-accountability-for-token', () => ({
getAccountabilityForToken: vi.fn().mockReturnValue({
role: null, // minimum viable accountability
} as Accountability),
}));
vi.mock('../utils/get-accountability-for-role', () => ({
getAccountabilityForRole: vi.fn(),
}));
vi.mock('./utils/get-expires-at-for-token', () => ({
getExpiresAtForToken: vi.fn(),
}));
vi.mock('../utils/get-schema');
vi.mock('../services/authentication', () => ({
AuthenticationService: vi.fn(() => ({
login: vi.fn().mockReturnValue({ accessToken: '123', refreshToken: 'refresh', expires: 123456 }),
refresh: vi.fn().mockReturnValue({ accessToken: '456', refreshToken: 'refresh' }),
})),
}));
vi.mock('../database');
describe('authenticateConnection', () => {
test('Success with email/password', async () => {
const TIMESTAMP = 123456789;
(getExpiresAtForToken as Mock).mockReturnValue(TIMESTAMP);
const result = await authenticateConnection({
type: 'auth',
email: 'email',
password: 'password',
} as WebSocketAuthMessage);
expect(result).toStrictEqual({
accountability: { role: null },
expires_at: TIMESTAMP,
refresh_token: 'refresh',
});
});
test('Success with refresh_token', async () => {
const TIMESTAMP = 987654;
(getExpiresAtForToken as Mock).mockReturnValue(TIMESTAMP);
const result = await authenticateConnection({
type: 'auth',
refresh_token: 'refresh_token',
} as WebSocketAuthMessage);
expect(result).toStrictEqual({
accountability: { role: null },
expires_at: TIMESTAMP,
refresh_token: 'refresh',
});
});
test('Success with access_token', async () => {
const TIMESTAMP = 456987;
(getExpiresAtForToken as Mock).mockReturnValue(TIMESTAMP);
const result = await authenticateConnection({
type: 'auth',
access_token: 'access_token',
} as WebSocketAuthMessage);
expect(result).toStrictEqual({
accountability: { role: null },
expires_at: TIMESTAMP,
refresh_token: undefined,
});
});
test('Failure token expired', async () => {
(getAccountabilityForToken as Mock).mockImplementation(() => {
throw new InvalidCredentialsException('Token expired.');
});
expect(() =>
authenticateConnection({
type: 'auth',
access_token: 'expired',
} as WebSocketAuthMessage)
).rejects.toThrow('Token expired.');
});
test('Failure authentication failed', async () => {
expect(() =>
authenticateConnection({
type: 'auth',
access_token: '',
} as WebSocketAuthMessage)
).rejects.toThrow('Authentication failed.');
});
});
describe('refreshAccountability', () => {
test('should just work', async () => {
(getAccountabilityForRole as Mock).mockReturnValue({
role: '123-456-789',
} as Accountability);
const result = await refreshAccountability({
role: null,
user: 'abc-def-ghi',
});
expect(result).toStrictEqual({
role: '123-456-789',
user: 'abc-def-ghi',
});
});
});
describe('authenticationSuccess', () => {
test('without uid', async () => {
const result = authenticationSuccess();
expect(result).toBe('{"type":"auth","status":"ok"}');
});
test('with uid', async () => {
const result = authenticationSuccess('123456');
expect(result).toBe('{"type":"auth","status":"ok","uid":"123456"}');
});
});

View File

@@ -0,0 +1,79 @@
import type { Accountability } from '@directus/types';
import { DEFAULT_AUTH_PROVIDER } from '../constants.js';
import getDatabase from '../database/index.js';
import { InvalidCredentialsException } from '../exceptions/index.js';
import { AuthenticationService } from '../services/index.js';
import { getAccountabilityForRole } from '../utils/get-accountability-for-role.js';
import { getAccountabilityForToken } from '../utils/get-accountability-for-token.js';
import { getSchema } from '../utils/get-schema.js';
import { WebSocketException } from './exceptions.js';
import type { BasicAuthMessage, WebSocketResponse } from './messages.js';
import type { AuthenticationState } from './types.js';
import { getExpiresAtForToken } from './utils/get-expires-at-for-token.js';
export async function authenticateConnection(
message: BasicAuthMessage & Record<string, any>
): Promise<AuthenticationState> {
let access_token: string | undefined, refresh_token: string | undefined;
try {
if ('email' in message && 'password' in message) {
const authenticationService = new AuthenticationService({ schema: await getSchema() });
const { accessToken, refreshToken } = await authenticationService.login(DEFAULT_AUTH_PROVIDER, message);
access_token = accessToken;
refresh_token = refreshToken;
}
if ('refresh_token' in message) {
const authenticationService = new AuthenticationService({ schema: await getSchema() });
const { accessToken, refreshToken } = await authenticationService.refresh(message.refresh_token);
access_token = accessToken;
refresh_token = refreshToken;
}
if ('access_token' in message) {
access_token = message.access_token;
}
if (!access_token) throw new Error();
const accountability = await getAccountabilityForToken(access_token);
const expires_at = getExpiresAtForToken(access_token);
return { accountability, expires_at, refresh_token } as AuthenticationState;
} catch (error) {
if (error instanceof InvalidCredentialsException && error.message === 'Token expired.') {
throw new WebSocketException('auth', 'TOKEN_EXPIRED', 'Token expired.', message['uid']);
}
throw new WebSocketException('auth', 'AUTH_FAILED', 'Authentication failed.', message['uid']);
}
}
export async function refreshAccountability(
accountability: Accountability | null | undefined
): Promise<Accountability> {
const result: Accountability = await getAccountabilityForRole(accountability?.role || null, {
accountability: accountability || null,
schema: await getSchema(),
database: getDatabase(),
});
result.user = accountability?.user || null;
return result;
}
export function authenticationSuccess(uid?: string | number, refresh_token?: string): string {
const message: WebSocketResponse = {
type: 'auth',
status: 'ok',
};
if (uid !== undefined) {
message.uid = uid;
}
if (refresh_token !== undefined) {
message['refresh_token'] = refresh_token;
}
return JSON.stringify(message);
}

View File

@@ -0,0 +1,344 @@
import type { Accountability } from '@directus/types';
import { parseJSON } from '@directus/utils';
import type { IncomingMessage, Server as httpServer } from 'http';
import type { ParsedUrlQuery } from 'querystring';
import type { RateLimiterAbstract } from 'rate-limiter-flexible';
import type internal from 'stream';
import { parse } from 'url';
import { v4 as uuid } from 'uuid';
import WebSocket, { WebSocketServer } from 'ws';
import { fromZodError } from 'zod-validation-error';
import emitter from '../../emitter.js';
import env from '../../env.js';
import { InvalidConfigException, TokenExpiredException } from '../../exceptions/index.js';
import logger from '../../logger.js';
import { createRateLimiter } from '../../rate-limiter.js';
import { getAccountabilityForToken } from '../../utils/get-accountability-for-token.js';
import { toBoolean } from '../../utils/to-boolean.js';
import { authenticateConnection, authenticationSuccess } from '../authenticate.js';
import { WebSocketException, handleWebSocketException } from '../exceptions.js';
import { AuthMode, WebSocketAuthMessage, WebSocketMessage } from '../messages.js';
import type { AuthenticationState, UpgradeContext, WebSocketClient } from '../types.js';
import { getExpiresAtForToken } from '../utils/get-expires-at-for-token.js';
import { getMessageType } from '../utils/message.js';
import { waitForAnyMessage, waitForMessageType } from '../utils/wait-for-message.js';
import { registerWebSocketEvents } from './hooks.js';
const TOKEN_CHECK_INTERVAL = 15 * 60 * 1000; // 15 minutes
export default abstract class SocketController {
server: WebSocket.Server;
clients: Set<WebSocketClient>;
authentication: {
mode: AuthMode;
timeout: number;
};
endpoint: string;
maxConnections: number;
private rateLimiter: RateLimiterAbstract | null;
private authInterval: NodeJS.Timer | null;
constructor(httpServer: httpServer, configPrefix: string) {
this.server = new WebSocketServer({ noServer: true });
this.clients = new Set();
this.authInterval = null;
const { endpoint, authentication, maxConnections } = this.getEnvironmentConfig(configPrefix);
this.endpoint = endpoint;
this.authentication = authentication;
this.maxConnections = maxConnections;
this.rateLimiter = this.getRateLimiter();
httpServer.on('upgrade', this.handleUpgrade.bind(this));
this.checkClientTokens();
registerWebSocketEvents();
}
protected getEnvironmentConfig(configPrefix: string): {
endpoint: string;
authentication: {
mode: AuthMode;
timeout: number;
};
maxConnections: number;
} {
const endpoint = String(env[`${configPrefix}_PATH`]);
const authMode = AuthMode.safeParse(String(env[`${configPrefix}_AUTH`]).toLowerCase());
const authTimeout = Number(env[`${configPrefix}_AUTH_TIMEOUT`]) * 1000;
const maxConnections =
`${configPrefix}_CONN_LIMIT` in env ? Number(env[`${configPrefix}_CONN_LIMIT`]) : Number.POSITIVE_INFINITY;
if (!authMode.success) {
throw new InvalidConfigException(fromZodError(authMode.error, { prefix: `${configPrefix}_AUTH` }).message);
}
return {
endpoint,
maxConnections,
authentication: {
mode: authMode.data,
timeout: authTimeout,
},
};
}
protected getRateLimiter() {
if (toBoolean(env['RATE_LIMITER_ENABLED']) === true) {
return createRateLimiter('RATE_LIMITER', {
keyPrefix: 'websocket',
});
}
return null;
}
protected async handleUpgrade(request: IncomingMessage, socket: internal.Duplex, head: Buffer) {
const { pathname, query } = parse(request.url!, true);
if (pathname !== this.endpoint) return;
if (this.clients.size >= this.maxConnections) {
logger.debug('WebSocket upgrade denied - max connections reached');
socket.write('HTTP/1.1 403 Forbidden\r\n\r\n');
socket.destroy();
return;
}
const context: UpgradeContext = { request, socket, head };
if (this.authentication.mode === 'strict') {
await this.handleStrictUpgrade(context, query);
return;
}
if (this.authentication.mode === 'handshake') {
await this.handleHandshakeUpgrade(context);
return;
}
this.server.handleUpgrade(request, socket, head, async (ws) => {
const state = { accountability: null, expires_at: null } as AuthenticationState;
this.server.emit('connection', ws, state);
});
}
protected async handleStrictUpgrade({ request, socket, head }: UpgradeContext, query: ParsedUrlQuery) {
let accountability: Accountability | null, expires_at: number | null;
try {
const token = query['access_token'] as string;
accountability = await getAccountabilityForToken(token);
expires_at = getExpiresAtForToken(token);
} catch {
accountability = null;
expires_at = null;
}
if (!accountability || !accountability.user) {
logger.debug('WebSocket upgrade denied - ' + JSON.stringify(accountability || 'invalid'));
socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n');
socket.destroy();
return;
}
this.server.handleUpgrade(request, socket, head, async (ws) => {
const state = { accountability, expires_at } as AuthenticationState;
this.server.emit('connection', ws, state);
});
}
protected async handleHandshakeUpgrade({ request, socket, head }: UpgradeContext) {
this.server.handleUpgrade(request, socket, head, async (ws) => {
try {
const payload = await waitForAnyMessage(ws, this.authentication.timeout);
if (getMessageType(payload) !== 'auth') throw new Error();
const state = await authenticateConnection(WebSocketAuthMessage.parse(payload));
ws.send(authenticationSuccess(payload['uid'], state.refresh_token));
this.server.emit('connection', ws, state);
} catch {
logger.debug('WebSocket authentication handshake failed');
const error = new WebSocketException('auth', 'AUTH_FAILED', 'Authentication handshake failed.');
handleWebSocketException(ws, error, 'auth');
ws.close();
}
});
}
createClient(ws: WebSocket, { accountability, expires_at }: AuthenticationState) {
const client = ws as WebSocketClient;
client.accountability = accountability;
client.expires_at = expires_at;
client.uid = uuid();
client.auth_timer = null;
ws.on('message', async (data: WebSocket.RawData) => {
if (this.rateLimiter !== null) {
try {
await this.rateLimiter.consume(client.uid);
} catch (limit) {
const timeout = (limit as any)?.msBeforeNext ?? this.rateLimiter.msDuration;
const error = new WebSocketException(
'server',
'REQUESTS_EXCEEDED',
`Too many messages, retry after ${timeout}ms.`
);
handleWebSocketException(client, error, 'server');
logger.debug(`WebSocket#${client.uid} is rate limited`);
return;
}
}
let message: WebSocketMessage;
try {
message = this.parseMessage(data.toString());
} catch (err: any) {
handleWebSocketException(client, err, 'server');
return;
}
if (getMessageType(message) === 'auth') {
try {
await this.handleAuthRequest(client, WebSocketAuthMessage.parse(message));
} catch {
// ignore errors
}
return;
}
// this log cannot be higher in the function or it will leak credentials
logger.trace(`WebSocket#${client.uid} - ${JSON.stringify(message)}`);
ws.emit('parsed-message', message);
});
ws.on('error', () => {
logger.debug(`WebSocket#${client.uid} connection errored`);
if (client.auth_timer) {
clearTimeout(client.auth_timer);
client.auth_timer = null;
}
this.clients.delete(client);
});
ws.on('close', () => {
logger.debug(`WebSocket#${client.uid} connection closed`);
if (client.auth_timer) {
clearTimeout(client.auth_timer);
client.auth_timer = null;
}
this.clients.delete(client);
});
logger.debug(`WebSocket#${client.uid} connected`);
if (accountability) {
logger.trace(`WebSocket#${client.uid} authenticated as ${JSON.stringify(accountability)}`);
}
this.setTokenExpireTimer(client);
this.clients.add(client);
return client;
}
protected parseMessage(data: string): WebSocketMessage {
let message: WebSocketMessage;
try {
message = WebSocketMessage.parse(parseJSON(data));
} catch (err: any) {
throw new WebSocketException('server', 'INVALID_PAYLOAD', 'Unable to parse the incoming message.');
}
return message;
}
protected async handleAuthRequest(client: WebSocketClient, message: WebSocketAuthMessage) {
try {
const { accountability, expires_at, refresh_token } = await authenticateConnection(message);
client.accountability = accountability;
client.expires_at = expires_at;
this.setTokenExpireTimer(client);
emitter.emitAction('websocket.auth.success', { client });
client.send(authenticationSuccess(message.uid, refresh_token));
logger.trace(`WebSocket#${client.uid} authenticated as ${JSON.stringify(client.accountability)}`);
} catch (error) {
logger.trace(`WebSocket#${client.uid} failed authentication`);
emitter.emitAction('websocket.auth.failure', { client });
client.accountability = null;
client.expires_at = null;
const _error =
error instanceof WebSocketException
? error
: new WebSocketException('auth', 'AUTH_FAILED', 'Authentication failed.', message.uid);
handleWebSocketException(client, _error, 'auth');
if (this.authentication.mode !== 'public') {
client.close();
}
}
}
setTokenExpireTimer(client: WebSocketClient) {
if (client.auth_timer !== null) {
// clear up old timeouts if needed
clearTimeout(client.auth_timer);
client.auth_timer = null;
}
if (!client.expires_at) return;
const expiresIn = client.expires_at * 1000 - Date.now();
if (expiresIn > TOKEN_CHECK_INTERVAL) return;
client.auth_timer = setTimeout(() => {
client.accountability = null;
client.expires_at = null;
handleWebSocketException(client, new TokenExpiredException(), 'auth');
waitForMessageType(client, 'auth', this.authentication.timeout).catch((msg: WebSocketMessage) => {
const error = new WebSocketException('auth', 'AUTH_TIMEOUT', 'Authentication timed out.', msg?.uid);
handleWebSocketException(client, error, 'auth');
if (this.authentication.mode !== 'public') {
client.close();
}
});
}, expiresIn);
}
checkClientTokens() {
this.authInterval = setInterval(() => {
if (this.clients.size === 0) return;
// check the clients and set shorter timeouts if needed
for (const client of this.clients) {
if (client.expires_at === null || client.auth_timer !== null) continue;
this.setTokenExpireTimer(client);
}
}, TOKEN_CHECK_INTERVAL);
}
terminate() {
if (this.authInterval) clearInterval(this.authInterval);
this.clients.forEach((client) => {
if (client.auth_timer) clearTimeout(client.auth_timer);
});
this.server.clients.forEach((ws) => {
ws.terminate();
});
}
}

View File

@@ -0,0 +1,121 @@
import { CloseCode, MessageType, makeServer } from 'graphql-ws';
import type { Server } from 'graphql-ws';
import type { Server as httpServer } from 'http';
import type { WebSocket } from 'ws';
import env from '../../env.js';
import logger from '../../logger.js';
import { bindPubSub } from '../../services/graphql/subscription.js';
import { GraphQLService } from '../../services/index.js';
import { getSchema } from '../../utils/get-schema.js';
import { authenticateConnection, refreshAccountability } from '../authenticate.js';
import { handleWebSocketException } from '../exceptions.js';
import { ConnectionParams, WebSocketMessage } from '../messages.js';
import type { AuthenticationState, GraphQLSocket, UpgradeContext, WebSocketClient } from '../types.js';
import { getMessageType } from '../utils/message.js';
import SocketController from './base.js';
export class GraphQLSubscriptionController extends SocketController {
gql: Server<GraphQLSocket>;
constructor(httpServer: httpServer) {
super(httpServer, 'WEBSOCKETS_GRAPHQL');
this.server.on('connection', (ws: WebSocket, auth: AuthenticationState) => {
this.bindEvents(this.createClient(ws, auth));
});
this.gql = makeServer<ConnectionParams, GraphQLSocket>({
schema: async (ctx) => {
const accountability = ctx.extra.client.accountability;
// for now only the items will be watched, system events tbd
const service = new GraphQLService({
schema: await getSchema(),
scope: 'items',
accountability,
});
return service.getSchema();
},
});
bindPubSub();
logger.info(`GraphQL Subscriptions started at ws://${env['HOST']}:${env['PORT']}${this.endpoint}`);
}
private bindEvents(client: WebSocketClient) {
const closedHandler = this.gql.opened(
{
protocol: client.protocol,
send: (data) =>
new Promise((resolve, reject) => {
client.send(data, (err) => (err ? reject(err) : resolve()));
}),
close: (code, reason) => client.close(code, reason), // for standard closures
onMessage: (cb) => {
client.on('parsed-message', async (message: WebSocketMessage) => {
try {
if (getMessageType(message) === 'connection_init' && this.authentication.mode !== 'strict') {
const params = ConnectionParams.parse(message['payload']);
if (this.authentication.mode === 'handshake') {
if (typeof params.access_token === 'string') {
const { accountability, expires_at } = await authenticateConnection({
access_token: params.access_token,
});
client.accountability = accountability;
client.expires_at = expires_at;
} else {
client.close(CloseCode.Forbidden, 'Forbidden');
return;
}
}
} else if (this.authentication.mode === 'handshake' && !client.accountability?.user) {
// the first message should authenticate successfully in this mode
client.close(CloseCode.Forbidden, 'Forbidden');
return;
} else {
client.accountability = await refreshAccountability(client.accountability);
}
await cb(JSON.stringify(message));
} catch (error) {
handleWebSocketException(client, error, MessageType.Error);
}
});
},
},
{ client }
);
// notify server that the socket closed
client.once('close', (code, reason) => closedHandler(code, reason.toString()));
// check strict authentication status
if (this.authentication.mode === 'strict' && !client.accountability?.user) {
client.close(CloseCode.Forbidden, 'Forbidden');
}
}
override setTokenExpireTimer(client: WebSocketClient) {
if (client.auth_timer !== null) {
clearTimeout(client.auth_timer);
client.auth_timer = null;
}
if (this.authentication.mode !== 'handshake') return;
client.auth_timer = setTimeout(() => {
if (!client.accountability?.user) {
client.close(CloseCode.Forbidden, 'Forbidden');
}
}, this.authentication.timeout);
}
protected override async handleHandshakeUpgrade({ request, socket, head }: UpgradeContext) {
this.server.handleUpgrade(request, socket, head, async (ws) => {
this.server.emit('connection', ws, { accountability: null, expires_at: null });
// actual enforcement is handled by the setTokenExpireTimer function
});
}
}

View File

@@ -0,0 +1,140 @@
import emitter from '../../emitter.js';
import { getMessenger } from '../../messenger.js';
import type { WebSocketEvent } from '../messages.js';
let actionsRegistered = false;
export function registerWebSocketEvents() {
if (actionsRegistered) return;
actionsRegistered = true;
registerActionHooks([
'items',
'activity',
'collections',
'folders',
'permissions',
'presets',
'revisions',
'roles',
'settings',
'users',
'webhooks',
]);
registerFieldsHooks();
registerFilesHooks();
registerRelationsHooks();
}
function registerActionHooks(modules: string[]) {
// register event hooks that can be handled in an uniform manner
for (const module of modules) {
registerAction(module + '.create', ({ key, collection, payload = {} }) => ({
collection,
action: 'create',
key,
payload,
}));
registerAction(module + '.update', ({ keys, collection, payload = {} }) => ({
collection,
action: 'update',
keys,
payload,
}));
registerAction(module + '.delete', ({ keys, collection, payload = [] }) => ({
collection,
action: 'delete',
keys,
payload,
}));
}
}
function registerFieldsHooks() {
// exception for field hooks that don't report `directus_fields` as being the collection
registerAction('fields.create', ({ key, payload = {} }) => ({
collection: 'directus_fields',
action: 'create',
key,
payload,
}));
registerAction('fields.update', ({ keys, payload = {} }) => ({
collection: 'directus_fields',
action: 'update',
keys,
payload,
}));
registerAction('fields.delete', ({ keys, payload = [] }) => ({
collection: 'directus_fields',
action: 'delete',
keys,
payload,
}));
}
function registerFilesHooks() {
// extra event for file uploads that doubles as create event
registerAction('files.upload', ({ key, collection, payload = {} }) => ({
collection,
action: 'create',
key,
payload,
}));
registerAction('files.update', ({ keys, collection, payload = {} }) => ({
collection,
action: 'update',
keys,
payload,
}));
registerAction('files.delete', ({ keys, collection, payload = [] }) => ({
collection,
action: 'delete',
keys,
payload,
}));
}
function registerRelationsHooks() {
// exception for relation hooks that don't report `directus_relations` as being the collection
registerAction('relations.create', ({ key, payload = {} }) => ({
collection: 'directus_relations',
action: 'create',
key,
payload: { ...payload, key },
}));
registerAction('relations.update', ({ keys, payload = {} }) => ({
collection: 'directus_relations',
action: 'update',
keys,
payload,
}));
registerAction('relations.delete', ({ collection, payload = [] }) => ({
collection: 'directus_relations',
action: 'delete',
keys: payload,
payload: { collection, fields: payload },
}));
}
/**
* Wrapper for emitter.onAction to hook into system events
* @param event The action event to watch
* @param transform Transformer function
*/
function registerAction(event: string, transform: (args: Record<string, any>) => WebSocketEvent) {
const messenger = getMessenger();
emitter.onAction(event, async (data: Record<string, any>) => {
// push the event through the Redis pub/sub
messenger.publish('websocket.event', transform(data) as Record<string, any>);
});
}

View File

@@ -0,0 +1,44 @@
import type { Server as httpServer } from 'http';
import env from '../../env.js';
import { ServiceUnavailableException } from '../../index.js';
import { toBoolean } from '../../utils/to-boolean.js';
import { GraphQLSubscriptionController } from './graphql.js';
import { WebSocketController } from './rest.js';
let websocketController: WebSocketController | undefined;
let subscriptionController: GraphQLSubscriptionController | undefined;
export function createWebSocketController(server: httpServer) {
if (toBoolean(env['WEBSOCKETS_REST_ENABLED'])) {
websocketController = new WebSocketController(server);
}
}
export function getWebSocketController() {
if (!toBoolean(env['WEBSOCKETS_ENABLED']) || !toBoolean(env['WEBSOCKETS_REST_ENABLED'])) {
throw new ServiceUnavailableException('WebSocket server is disabled', {
service: 'get-websocket-controller',
});
}
if (!websocketController) {
throw new ServiceUnavailableException('WebSocket server is not initialized', {
service: 'get-websocket-controller',
});
}
return websocketController;
}
export function createSubscriptionController(server: httpServer) {
if (toBoolean(env['WEBSOCKETS_GRAPHQL_ENABLED'])) {
subscriptionController = new GraphQLSubscriptionController(server);
}
}
export function getSubscriptionController() {
return subscriptionController;
}
export * from './graphql.js';
export * from './rest.js';

View File

@@ -0,0 +1,58 @@
import { parseJSON } from '@directus/utils';
import type { Server as httpServer } from 'http';
import type WebSocket from 'ws';
import emitter from '../../emitter.js';
import env from '../../env.js';
import logger from '../../logger.js';
import { refreshAccountability } from '../authenticate.js';
import { WebSocketException, handleWebSocketException } from '../exceptions.js';
import { WebSocketMessage } from '../messages.js';
import type { AuthenticationState, WebSocketClient } from '../types.js';
import SocketController from './base.js';
export class WebSocketController extends SocketController {
constructor(httpServer: httpServer) {
super(httpServer, 'WEBSOCKETS_REST');
this.server.on('connection', (ws: WebSocket, auth: AuthenticationState) => {
this.bindEvents(this.createClient(ws, auth));
});
logger.info(`WebSocket Server started at ws://${env['HOST']}:${env['PORT']}${this.endpoint}`);
}
private bindEvents(client: WebSocketClient) {
client.on('parsed-message', async (message: WebSocketMessage) => {
try {
message = WebSocketMessage.parse(await emitter.emitFilter('websocket.message', message, { client }));
client.accountability = await refreshAccountability(client.accountability);
emitter.emitAction('websocket.message', { message, client });
} catch (error) {
handleWebSocketException(client, error, 'server');
return;
}
});
client.on('error', (event: WebSocket.Event) => {
emitter.emitAction('websocket.error', { client, event });
});
client.on('close', (event: WebSocket.CloseEvent) => {
emitter.emitAction('websocket.close', { client, event });
});
emitter.emitAction('websocket.connect', { client });
}
protected override parseMessage(data: string): WebSocketMessage {
let message: WebSocketMessage;
try {
message = parseJSON(data);
} catch (err: any) {
throw new WebSocketException('server', 'INVALID_PAYLOAD', 'Unable to parse the incoming message.');
}
return message;
}
}

View File

@@ -0,0 +1,106 @@
import { describe, expect, test, vi } from 'vitest';
import type { WebSocketClient } from './types.js';
import { BaseException } from '@directus/exceptions';
import { InvalidPayloadException } from '../index.js';
import { WebSocketException, handleWebSocketException } from './exceptions.js';
import { ZodError } from 'zod';
import logger from '../logger.js';
vi.mock('../logger');
function mockClient() {
return {
on: vi.fn(),
off: vi.fn(),
send: vi.fn(),
close: vi.fn(),
accountability: null,
} as unknown as WebSocketClient;
}
describe('WebSocketException', () => {
test('with uid', () => {
const error = new WebSocketException('type', 'code', 'message', 123);
const response = error.toJSON();
expect(response).toStrictEqual({
type: 'type',
status: 'error',
error: {
code: 'code',
message: 'message',
},
uid: 123,
});
expect(error.toMessage()).toBe(JSON.stringify(response));
});
test('without uid', () => {
const error = new WebSocketException('type', 'code', 'message');
const response = error.toJSON();
expect(response).toStrictEqual({
type: 'type',
status: 'error',
error: {
code: 'code',
message: 'message',
},
});
expect(error.toMessage()).toBe(JSON.stringify(response));
});
});
describe('handleWebSocketException', () => {
const type = 'testing';
test('handle BaseException', () => {
const client = mockClient();
const error = new BaseException('test', 200, '123');
const expected = WebSocketException.fromException(error, type).toMessage();
handleWebSocketException(client, error, type);
expect(client.send).toBeCalledWith(expected);
expect(logger.error).not.toBeCalled();
});
test('handle InvalidPayloadException', () => {
const client = mockClient();
const error = new InvalidPayloadException('test');
const expected = WebSocketException.fromException(error, type).toMessage();
handleWebSocketException(client, error, type);
expect(client.send).toBeCalledWith(expected);
expect(logger.error).not.toBeCalled();
});
test('handle WebSocketException', () => {
const client = mockClient();
const error = new WebSocketException('type', 'code', 'message', 123);
const expected = error.toMessage();
handleWebSocketException(client, error, type);
expect(client.send).toBeCalledWith(expected);
expect(logger.error).not.toBeCalled();
});
test('handle ZodError', () => {
const client = mockClient();
const error = new ZodError([
{ message: 'test', code: 'invalid_type', path: ['path'], expected: 'array', received: 'string' },
]);
const expected = WebSocketException.fromZodError(error, type).toMessage();
handleWebSocketException(client, error, type);
expect(client.send).toBeCalledWith(expected);
expect(logger.error).not.toBeCalled();
});
test('unhandled exception', () => {
const client = mockClient();
const error = new Error('regular error');
handleWebSocketException(client, error, type);
expect(client.send).not.toBeCalled();
expect(logger.error).toBeCalled();
});
});

View File

@@ -0,0 +1,69 @@
import { BaseException } from '@directus/exceptions';
import type { WebSocket } from 'ws';
import { ZodError } from 'zod';
import { fromZodError } from 'zod-validation-error';
import logger from '../logger.js';
import type { WebSocketResponse } from './messages.js';
import type { WebSocketClient } from './types.js';
export class WebSocketException extends Error {
type: string;
code: string;
uid: string | number | undefined;
constructor(type: string, code: string, message: string, uid?: string | number) {
super(message);
this.type = type;
this.code = code;
this.uid = uid;
}
toJSON(): WebSocketResponse {
const message: WebSocketResponse = {
type: this.type,
status: 'error',
error: {
code: this.code,
message: this.message,
},
};
if (this.uid !== undefined) {
message.uid = this.uid;
}
return message;
}
toMessage(): string {
return JSON.stringify(this.toJSON());
}
static fromException(error: BaseException, type = 'unknown') {
return new WebSocketException(type, error.code, error.message);
}
static fromZodError(error: ZodError, type = 'unknown') {
const zError = fromZodError(error);
return new WebSocketException(type, 'INVALID_PAYLOAD', zError.message);
}
}
export function handleWebSocketException(client: WebSocketClient | WebSocket, error: unknown, type?: string): void {
if (error instanceof BaseException) {
client.send(WebSocketException.fromException(error, type).toMessage());
return;
}
if (error instanceof WebSocketException) {
client.send(error.toMessage());
return;
}
if (error instanceof ZodError) {
client.send(WebSocketException.fromZodError(error, type).toMessage());
return;
}
// unhandled exceptions
logger.error(`WebSocket unhandled exception ${JSON.stringify({ type, error })}`);
}

View File

@@ -0,0 +1,98 @@
import type { EventContext } from '@directus/types';
import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest';
import type { Mock } from 'vitest';
import emitter from '../../emitter.js';
import { WebSocketController, getWebSocketController } from '../controllers/index.js';
import type { WebSocketClient } from '../types.js';
import { HeartbeatHandler } from './heartbeat.js';
// mocking
vi.mock('../controllers', () => ({
getWebSocketController: vi.fn(() => ({
clients: new Set(),
})),
}));
vi.mock('../../env', async () => {
const actual = (await vi.importActual('../../env')) as { default: Record<string, any> };
const MOCK_ENV = {
...actual.default,
WEBSOCKETS_HEARTBEAT_PERIOD: 1,
};
return {
default: MOCK_ENV,
getEnv: () => MOCK_ENV,
};
});
function mockClient() {
return {
on: vi.fn(),
off: vi.fn(),
send: vi.fn(),
close: vi.fn(),
} as unknown as WebSocketClient;
}
describe('WebSocket heartbeat handler', () => {
let controller: WebSocketController;
beforeEach(() => {
vi.useFakeTimers();
controller = getWebSocketController();
});
afterEach(() => {
vi.useRealTimers();
vi.clearAllMocks();
});
test('client should ping', async () => {
// initialize handler
new HeartbeatHandler(controller);
// connect fake client
const fakeClient = mockClient();
(fakeClient.send as Mock).mockImplementation(() => {
//respond with a message
emitter.emitAction('websocket.message', { client: fakeClient, message: { type: 'pong' } }, {} as EventContext);
});
controller.clients.add(fakeClient);
emitter.emitAction('websocket.connect', {}, {} as EventContext);
// wait for ping
vi.advanceTimersByTime(1000); // 1sec heartbeat interval
expect(fakeClient.send).toBeCalled();
// wait for another timeout
vi.advanceTimersByTime(1000); // 1sec heartbeat interval
expect(fakeClient.send).toBeCalled();
// the connection should not have been closed
expect(fakeClient.close).not.toBeCalled();
});
test('connection should be closed', async () => {
// initialize handler
new HeartbeatHandler(controller);
// connect fake client
const fakeClient = mockClient();
controller.clients.add(fakeClient);
emitter.emitAction('websocket.connect', {}, {} as EventContext);
vi.advanceTimersByTime(2 * 1000); // 2x 1sec heartbeat interval
expect(fakeClient.send).toBeCalled();
// the connection should have been closed
expect(fakeClient.close).toBeCalled();
});
test('the server should pong if the client pings', async () => {
// initialize handler
new HeartbeatHandler(controller);
// connect fake client
const fakeClient = mockClient();
controller.clients.add(fakeClient);
emitter.emitAction('websocket.connect', {}, {} as EventContext);
emitter.emitAction('websocket.message', { client: fakeClient, message: { type: 'ping' } }, {} as EventContext);
expect(fakeClient.send).toBeCalled();
});
});

View File

@@ -0,0 +1,87 @@
import type { ActionHandler } from '@directus/types';
import emitter from '../../emitter.js';
import env from '../../env.js';
import { toBoolean } from '../../utils/to-boolean.js';
import { WebSocketController, getWebSocketController } from '../controllers/index.js';
import { WebSocketMessage } from '../messages.js';
import type { WebSocketClient } from '../types.js';
import { fmtMessage, getMessageType } from '../utils/message.js';
const HEARTBEAT_FREQUENCY = Number(env['WEBSOCKETS_HEARTBEAT_PERIOD']) * 1000;
export class HeartbeatHandler {
private pulse: NodeJS.Timer | undefined;
private controller: WebSocketController;
constructor(controller?: WebSocketController) {
this.controller = controller ?? getWebSocketController();
emitter.onAction('websocket.message', ({ client, message }) => {
try {
this.onMessage(client, WebSocketMessage.parse(message));
} catch {
/* ignore errors */
}
});
if (toBoolean(env['WEBSOCKETS_HEARTBEAT_ENABLED']) === true) {
emitter.onAction('websocket.connect', () => this.checkClients());
emitter.onAction('websocket.error', () => this.checkClients());
emitter.onAction('websocket.close', () => this.checkClients());
}
}
private checkClients() {
const hasClients = this.controller.clients.size > 0;
if (hasClients && !this.pulse) {
this.pulse = setInterval(() => {
this.pingClients();
}, HEARTBEAT_FREQUENCY);
}
if (!hasClients && this.pulse) {
clearInterval(this.pulse);
this.pulse = undefined;
}
}
onMessage(client: WebSocketClient, message: WebSocketMessage) {
if (getMessageType(message) !== 'ping') return;
// send pong message back as acknowledgement
const data = 'uid' in message ? { uid: message.uid } : {};
client.send(fmtMessage('pong', data));
}
pingClients() {
const pendingClients = new Set<WebSocketClient>(this.controller.clients);
const activeClients = new Set<WebSocketClient>();
const timeout = setTimeout(() => {
// close connections that haven't responded
for (const client of pendingClients) {
client.close();
}
}, HEARTBEAT_FREQUENCY);
const messageWatcher: ActionHandler = ({ client }) => {
// any message means this connection is still open
if (!activeClients.has(client)) {
pendingClients.delete(client);
activeClients.add(client);
}
if (pendingClients.size === 0) {
clearTimeout(timeout);
emitter.offAction('websocket.message', messageWatcher);
}
};
emitter.onAction('websocket.message', messageWatcher);
// ping all the clients
for (const client of pendingClients) {
client.send(fmtMessage('ping'));
}
}
}

View File

@@ -0,0 +1,13 @@
import { HeartbeatHandler } from './heartbeat.js';
import { ItemsHandler } from './items.js';
import { SubscribeHandler } from './subscribe.js';
export function startWebSocketHandlers() {
new HeartbeatHandler();
new ItemsHandler();
new SubscribeHandler();
}
export * from './heartbeat.js';
export * from './items.js';
export * from './subscribe.js';

View File

@@ -0,0 +1,295 @@
import type { EventContext } from '@directus/types';
import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest';
import type { Mock } from 'vitest';
import emitter from '../../emitter.js';
import { ItemsService, MetaService } from '../../services/index.js';
import { getSchema } from '../../utils/get-schema.js';
import type { WebSocketClient } from '../types.js';
import { ItemsHandler } from './items.js';
// mocking
vi.mock('../controllers', () => ({
getWebSocketController: vi.fn(() => ({
clients: new Set(),
})),
}));
vi.mock('../../utils/get-schema', () => ({
getSchema: vi.fn(),
}));
vi.mock('../../services', () => ({
ItemsService: vi.fn(),
MetaService: vi.fn(),
}));
function mockClient() {
return {
on: vi.fn(),
off: vi.fn(),
send: vi.fn(),
close: vi.fn(),
accountability: null,
} as unknown as WebSocketClient;
}
describe('WebSocket heartbeat handler', () => {
let handler: ItemsHandler;
beforeEach(() => {
vi.useFakeTimers();
// initialize handler
handler = new ItemsHandler();
});
afterEach(() => {
vi.useRealTimers();
emitter.offAll();
vi.clearAllMocks();
});
test('ignore other message types', async () => {
const spy = vi.spyOn(handler, 'onMessage');
// receive message
emitter.emitAction(
'websocket.message',
{
client: mockClient(),
message: { type: 'pong' },
},
{} as EventContext
);
// expect nothing
expect(spy).not.toBeCalled();
});
test('invalid collection should error', async () => {
(getSchema as Mock).mockImplementation(() => ({ collections: {} }));
// receive message
const fakeClient = mockClient();
emitter.emitAction(
'websocket.message',
{
client: fakeClient,
message: { type: 'items', collection: 'test', action: 'create', data: {} },
},
{} as EventContext
);
await vi.runAllTimersAsync(); // flush promises to make sure the event is handled
// expect error
expect(fakeClient.send).toBeCalledWith(
'{"type":"items","status":"error","error":{"code":"INVALID_COLLECTION","message":"The provided collection does not exists or is not accessible."}}'
);
});
test('create one item', async () => {
// do mocking
(getSchema as Mock).mockImplementation(() => ({ collections: { test: [] } }));
const createOne = vi.fn(),
readOne = vi.fn();
(ItemsService as Mock).mockImplementation(() => ({ createOne, readOne }));
// receive message
const fakeClient = mockClient();
emitter.emitAction(
'websocket.message',
{
client: fakeClient,
message: { type: 'items', collection: 'test', action: 'create', data: {} },
},
{} as EventContext
);
await vi.runAllTimersAsync(); // flush promises to make sure the event is handled
// expect service functions
expect(createOne).toBeCalled();
expect(readOne).toBeCalled();
expect(fakeClient.send).toBeCalled();
});
test('create multiple items', async () => {
// do mocking
(getSchema as Mock).mockImplementation(() => ({ collections: { test: [] } }));
const createMany = vi.fn(),
readMany = vi.fn();
(ItemsService as Mock).mockImplementation(() => ({ createMany, readMany }));
// receive message
const fakeClient = mockClient();
emitter.emitAction(
'websocket.message',
{
client: fakeClient,
message: { type: 'items', collection: 'test', action: 'create', data: [{}, {}] },
},
{} as EventContext
);
await vi.runAllTimersAsync(); // flush promises to make sure the event is handled
// expect service functions
expect(createMany).toBeCalled();
expect(readMany).toBeCalled();
expect(fakeClient.send).toBeCalled();
});
test('read by query', async () => {
// do mocking
(getSchema as Mock).mockImplementation(() => ({ collections: { test: [] } }));
const readByQuery = vi.fn();
(ItemsService as Mock).mockImplementation(() => ({ readByQuery }));
const getMetaForQuery = vi.fn();
(MetaService as Mock).mockImplementation(() => ({ getMetaForQuery }));
// receive message
const fakeClient = mockClient();
emitter.emitAction(
'websocket.message',
{
client: fakeClient,
message: { type: 'items', collection: 'test', action: 'read', query: {} },
},
{} as EventContext
);
await vi.runAllTimersAsync(); // flush promises to make sure the event is handled
// expect service functions
expect(readByQuery).toBeCalled();
expect(getMetaForQuery).toBeCalled();
expect(fakeClient.send).toBeCalled();
});
test('update one item', async () => {
// do mocking
(getSchema as Mock).mockImplementation(() => ({ collections: { test: [] } }));
const updateOne = vi.fn(),
readOne = vi.fn();
(ItemsService as Mock).mockImplementation(() => ({ updateOne, readOne }));
// receive message
const fakeClient = mockClient();
emitter.emitAction(
'websocket.message',
{
client: fakeClient,
message: { type: 'items', collection: 'test', action: 'update', data: {}, id: '123' },
},
{} as EventContext
);
await vi.runAllTimersAsync(); // flush promises to make sure the event is handled
// expect service functions
expect(updateOne).toBeCalled();
expect(readOne).toBeCalled();
expect(fakeClient.send).toBeCalled();
});
test('update multiple items', async () => {
// do mocking
(getSchema as Mock).mockImplementation(() => ({ collections: { test: [] } }));
const updateMany = vi.fn(),
readMany = vi.fn();
(ItemsService as Mock).mockImplementation(() => ({ updateMany, readMany }));
const getMetaForQuery = vi.fn();
(MetaService as Mock).mockImplementation(() => ({ getMetaForQuery }));
// receive message
const fakeClient = mockClient();
emitter.emitAction(
'websocket.message',
{
client: fakeClient,
message: { type: 'items', collection: 'test', action: 'update', data: {}, ids: ['123', '456'] },
},
{} as EventContext
);
await vi.runAllTimersAsync(); // flush promises to make sure the event is handled
// expect service functions
expect(updateMany).toBeCalled();
expect(getMetaForQuery).toBeCalled();
expect(readMany).toBeCalled();
expect(fakeClient.send).toBeCalled();
});
test('delete one item', async () => {
// do mocking
(getSchema as Mock).mockImplementation(() => ({ collections: { test: [] } }));
const deleteOne = vi.fn();
(ItemsService as Mock).mockImplementation(() => ({ deleteOne }));
// receive message
const fakeClient = mockClient();
emitter.emitAction(
'websocket.message',
{
client: fakeClient,
message: { type: 'items', collection: 'test', action: 'delete', id: '123' },
},
{} as EventContext
);
await vi.runAllTimersAsync(); // flush promises to make sure the event is handled
// expect service functions
expect(deleteOne).toBeCalled();
expect(fakeClient.send).toBeCalled();
});
test('delete multiple items by id', async () => {
// do mocking
(getSchema as Mock).mockImplementation(() => ({ collections: { test: [] } }));
const deleteMany = vi.fn();
(ItemsService as Mock).mockImplementation(() => ({ deleteMany }));
// receive message
const fakeClient = mockClient();
emitter.emitAction(
'websocket.message',
{
client: fakeClient,
message: { type: 'items', collection: 'test', action: 'delete', ids: ['123', 456] },
},
{} as EventContext
);
await vi.runAllTimersAsync(); // flush promises to make sure the event is handled
// expect service functions
expect(deleteMany).toBeCalled();
expect(fakeClient.send).toBeCalled();
});
test('delete multiple items by query', async () => {
// do mocking
(getSchema as Mock).mockImplementation(() => ({ collections: { test: [] } }));
const deleteByQuery = vi.fn();
(ItemsService as Mock).mockImplementation(() => ({ deleteByQuery }));
// receive message
const fakeClient = mockClient();
emitter.emitAction(
'websocket.message',
{
client: fakeClient,
message: { type: 'items', collection: 'test', action: 'delete', query: {} },
},
{} as EventContext
);
await vi.runAllTimersAsync(); // flush promises to make sure the event is handled
// expect service functions
expect(deleteByQuery).toBeCalled();
expect(fakeClient.send).toBeCalled();
});
});

View File

@@ -0,0 +1,117 @@
import emitter from '../../emitter.js';
import { ItemsService, MetaService } from '../../services/index.js';
import { getSchema } from '../../utils/get-schema.js';
import { sanitizeQuery } from '../../utils/sanitize-query.js';
import { WebSocketException, handleWebSocketException } from '../exceptions.js';
import { WebSocketItemsMessage } from '../messages.js';
import type { WebSocketClient } from '../types.js';
import { fmtMessage, getMessageType } from '../utils/message.js';
export class ItemsHandler {
constructor() {
emitter.onAction('websocket.message', ({ client, message }) => {
if (getMessageType(message) !== 'items') return;
try {
const parsedMessage = WebSocketItemsMessage.parse(message);
this.onMessage(client, parsedMessage).catch((err) => {
// this catch is required because the async onMessage function is not awaited
handleWebSocketException(client, err, 'items');
});
} catch (err) {
handleWebSocketException(client, err, 'items');
}
});
}
async onMessage(client: WebSocketClient, message: WebSocketItemsMessage) {
const uid = message.uid;
const accountability = client.accountability;
const schema = await getSchema();
if (!schema.collections[message.collection] || message.collection.startsWith('directus_')) {
throw new WebSocketException(
'items',
'INVALID_COLLECTION',
'The provided collection does not exists or is not accessible.',
uid
);
}
const isSingleton = !!schema.collections[message.collection]?.singleton;
const service = new ItemsService(message.collection, { schema, accountability });
const metaService = new MetaService({ schema, accountability });
let result, meta;
if (message.action === 'create') {
const query = sanitizeQuery(message?.query ?? {}, accountability);
if (Array.isArray(message.data)) {
const keys = await service.createMany(message.data);
result = await service.readMany(keys, query);
} else {
const key = await service.createOne(message.data);
result = await service.readOne(key, query);
}
}
if (message.action === 'read') {
const query = sanitizeQuery(message.query ?? {}, accountability);
if (message.id) {
result = await service.readOne(message.id, query);
} else if (message.ids) {
result = await service.readMany(message.ids, query);
} else if (isSingleton) {
result = await service.readSingleton(query);
} else {
result = await service.readByQuery(query);
}
meta = await metaService.getMetaForQuery(message.collection, query);
}
if (message.action === 'update') {
const query = sanitizeQuery(message.query ?? {}, accountability);
if (message.id) {
const key = await service.updateOne(message.id, message.data);
result = await service.readOne(key);
} else if (message.ids) {
const keys = await service.updateMany(message.ids, message.data);
meta = await metaService.getMetaForQuery(message.collection, query);
result = await service.readMany(keys, query);
} else if (isSingleton) {
await service.upsertSingleton(message.data);
result = await service.readSingleton(query);
} else {
const keys = await service.updateByQuery(query, message.data);
meta = await metaService.getMetaForQuery(message.collection, query);
result = await service.readMany(keys, query);
}
}
if (message.action === 'delete') {
if (message.id) {
await service.deleteOne(message.id);
result = message.id;
} else if (message.ids) {
await service.deleteMany(message.ids);
result = message.ids;
} else if (message.query) {
const query = sanitizeQuery(message.query, accountability);
result = await service.deleteByQuery(query);
} else {
throw new WebSocketException(
'items',
'INVALID_PAYLOAD',
"Either 'ids', 'id' or 'query' is required for a DELETE request.",
uid
);
}
}
client.send(fmtMessage('items', { data: result, ...(meta ? { meta } : {}) }, uid));
}
}

View File

@@ -0,0 +1,282 @@
import { expect, describe, test, vi, beforeEach, afterEach } from 'vitest';
import emitter from '../../emitter.js';
import { SubscribeHandler } from './subscribe.js';
import type { WebSocketClient } from '../types.js';
import { getSchema } from '../../utils/get-schema.js';
import type { CollectionsOverview, Relation } from '@directus/types';
// mocking
vi.mock('../controllers', () => ({
getWebSocketController: vi.fn(() => ({
clients: new Set(),
})),
}));
vi.mock('../../utils/get-schema', () => ({
getSchema: vi.fn(),
}));
vi.mock('../../services', () => ({
ItemsService: vi.fn(() => ({
readByQuery: vi.fn(),
})),
MetaService: vi.fn(),
}));
vi.mock('../../database/index');
function mockClient() {
return {
on: vi.fn(),
off: vi.fn(),
send: vi.fn(),
close: vi.fn(),
accountability: null,
} as unknown as WebSocketClient;
}
function delay(ms: number) {
return new Promise<void>((resolve) => {
setTimeout(() => resolve(), ms);
});
}
describe('WebSocket heartbeat handler', () => {
let handler: SubscribeHandler;
beforeEach(() => {
// initialize handler
handler = new SubscribeHandler();
});
afterEach(() => {
emitter.offAll();
vi.clearAllMocks();
});
test('ignore other message types', async () => {
const spy = vi.spyOn(handler, 'onMessage');
// receive message
emitter.emitAction('websocket.message', {
client: mockClient(),
message: { type: 'ping' },
});
// expect nothing
expect(spy).not.toBeCalled();
});
test('should fail subscribe to non-existing collection', async () => {
vi.mocked(getSchema).mockImplementation(async () => ({
collections: {} as CollectionsOverview,
relations: [] as Relation[],
}));
const subscribe = vi.spyOn(handler, 'subscribe');
const onMessage = vi.spyOn(handler, 'onMessage');
// receive message
emitter.emitAction('websocket.message', {
client: mockClient(),
message: {
type: 'subscribe',
collection: 'does_not_exist',
},
});
await delay(10);
// expect
expect(onMessage).toBeCalled();
expect(subscribe).not.toBeCalled();
});
test('should subscribe/unsubscribe to collection', async () => {
const client = mockClient();
vi.mocked(getSchema).mockImplementation(async () => ({
collections: {
test_collection: {
collection: 'test_collection',
primary: 'id',
singleton: false,
sortField: null,
note: null,
accountability: null,
fields: {},
},
} as CollectionsOverview,
relations: [] as Relation[],
}));
const subscribe = vi.spyOn(handler, 'subscribe');
const onMessage = vi.spyOn(handler, 'onMessage');
// receive message
emitter.emitAction('websocket.message', {
client,
message: {
type: 'subscribe',
collection: 'test_collection',
uid: '123',
},
});
await delay(10);
// expect
expect(onMessage).toBeCalled();
expect(subscribe).toBeCalled();
expect(handler.subscriptions['test_collection']?.size).toBe(1);
});
test('unsubscribe a specific subscription', async () => {
const client = mockClient();
vi.mocked(getSchema).mockImplementation(async () => ({
collections: {
test_collection: {
collection: 'test_collection',
primary: 'id',
singleton: false,
sortField: null,
note: null,
accountability: null,
fields: {},
},
other_collection: {
collection: 'other_collection',
primary: 'id',
singleton: false,
sortField: null,
note: null,
accountability: null,
fields: {},
},
} as CollectionsOverview,
relations: [] as Relation[],
}));
const unsubscribe = vi.spyOn(handler, 'unsubscribe');
const subscribe = vi.spyOn(handler, 'subscribe');
const onMessage = vi.spyOn(handler, 'onMessage');
// subscribe
emitter.emitAction('websocket.message', {
client,
message: {
type: 'subscribe',
collection: 'test_collection',
uid: '123',
},
});
emitter.emitAction('websocket.message', {
client,
message: {
type: 'subscribe',
collection: 'other_collection',
uid: '456',
},
});
await delay(10);
// expect
expect(onMessage).toBeCalledTimes(2);
expect(subscribe).toBeCalledTimes(2);
expect(handler.subscriptions['test_collection']?.size).toBe(1);
expect(handler.subscriptions['other_collection']?.size).toBe(1);
// unsubscribe
emitter.emitAction('websocket.message', {
client,
message: {
type: 'unsubscribe',
uid: '123',
},
});
await delay(10);
// expect
expect(unsubscribe).toBeCalled();
expect(handler.subscriptions['test_collection']?.size).toBe(0);
expect(handler.subscriptions['other_collection']?.size).toBe(1);
});
test('unsubscribe all subscriptions', async () => {
const client = mockClient();
vi.mocked(getSchema).mockImplementation(async () => ({
collections: {
test_collection: {
collection: 'test_collection',
primary: 'id',
singleton: false,
sortField: null,
note: null,
accountability: null,
fields: {},
},
other_collection: {
collection: 'other_collection',
primary: 'id',
singleton: false,
sortField: null,
note: null,
accountability: null,
fields: {},
},
} as CollectionsOverview,
relations: [] as Relation[],
}));
const unsubscribe = vi.spyOn(handler, 'unsubscribe');
const subscribe = vi.spyOn(handler, 'subscribe');
const onMessage = vi.spyOn(handler, 'onMessage');
// subscribe
emitter.emitAction('websocket.message', {
client,
message: {
type: 'subscribe',
collection: 'test_collection',
uid: '123',
},
});
emitter.emitAction('websocket.message', {
client,
message: {
type: 'subscribe',
collection: 'other_collection',
uid: '456',
},
});
await delay(10);
// expect
expect(onMessage).toBeCalledTimes(2);
expect(subscribe).toBeCalledTimes(2);
expect(handler.subscriptions['test_collection']?.size).toBe(1);
expect(handler.subscriptions['other_collection']?.size).toBe(1);
// unsubscribe
emitter.emitAction('websocket.message', {
client,
message: {
type: 'unsubscribe',
},
});
await delay(10);
// expect
expect(unsubscribe).toBeCalled();
expect(handler.subscriptions['test_collection']?.size).toBe(0);
expect(handler.subscriptions['other_collection']?.size).toBe(0);
});
});

View File

@@ -0,0 +1,342 @@
import type { Accountability, SchemaOverview } from '@directus/types';
import emitter from '../../emitter.js';
import { InvalidPayloadException } from '../../index.js';
import { getMessenger } from '../../messenger.js';
import type { Messenger } from '../../messenger.js';
import { CollectionsService, FieldsService, MetaService } from '../../services/index.js';
import { getSchema } from '../../utils/get-schema.js';
import { getService } from '../../utils/get-service.js';
import { sanitizeQuery } from '../../utils/sanitize-query.js';
import { refreshAccountability } from '../authenticate.js';
import { WebSocketException, handleWebSocketException } from '../exceptions.js';
import type { WebSocketEvent } from '../messages.js';
import { WebSocketSubscribeMessage } from '../messages.js';
import type { Subscription, SubscriptionEvent, WebSocketClient } from '../types.js';
import { fmtMessage, getMessageType } from '../utils/message.js';
/**
* Handler responsible for subscriptions
*/
export class SubscribeHandler {
// storage of subscriptions per collection
subscriptions: Record<string, Set<Subscription>>;
// internal message bus
protected messenger: Messenger;
/**
* Initialize the handler
*/
constructor() {
this.subscriptions = {};
this.messenger = getMessenger();
this.bindWebSocket();
// listen to the Redis pub/sub and dispatch
this.messenger.subscribe('websocket.event', (message: Record<string, any>) => {
try {
this.dispatch(message as WebSocketEvent);
} catch {
// don't error on an invalid event from the messenger
}
});
}
/**
* Hook into websocket client lifecycle events
*/
bindWebSocket() {
// listen to incoming messages on the connected websockets
emitter.onAction('websocket.message', ({ client, message }) => {
if (!['subscribe', 'unsubscribe'].includes(getMessageType(message))) return;
try {
this.onMessage(client, WebSocketSubscribeMessage.parse(message));
} catch (error) {
handleWebSocketException(client, error, 'subscribe');
}
});
// unsubscribe when a connection drops
emitter.onAction('websocket.error', ({ client }) => this.unsubscribe(client));
emitter.onAction('websocket.close', ({ client }) => this.unsubscribe(client));
}
/**
* Register a subscription
* @param subscription
*/
subscribe(subscription: Subscription) {
const { collection } = subscription;
if ('item' in subscription && ['directus_fields', 'directus_relations'].includes(collection)) {
throw new InvalidPayloadException(`Cannot subscribe to a specific item in the ${collection} collection.`);
}
if (!this.subscriptions[collection]) {
this.subscriptions[collection] = new Set();
}
this.subscriptions[collection]?.add(subscription);
}
/**
* Remove a subscription
* @param subscription
*/
unsubscribe(client: WebSocketClient, uid?: string | number) {
if (uid !== undefined) {
const subscription = this.getSubscription(client, String(uid));
if (subscription) {
this.subscriptions[subscription.collection]?.delete(subscription);
}
} else {
for (const key of Object.keys(this.subscriptions)) {
const subscriptions = Array.from(this.subscriptions[key] || []);
for (let i = subscriptions.length - 1; i >= 0; i--) {
const subscription = subscriptions[i];
if (!subscription) continue;
if (subscription.client === client && (!uid || subscription.uid === uid)) {
this.subscriptions[key]?.delete(subscription);
}
}
}
}
}
/**
* Dispatch event to subscriptions
*/
async dispatch(event: WebSocketEvent) {
const subscriptions = this.subscriptions[event.collection];
if (!subscriptions || subscriptions.size === 0) return;
const schema = await getSchema();
for (const subscription of subscriptions) {
const { client } = subscription;
if (subscription.event !== undefined && event.action !== subscription.event) {
continue; // skip filtered events
}
try {
client.accountability = await refreshAccountability(client.accountability);
const result =
'item' in subscription
? await this.getSinglePayload(subscription, client.accountability, schema, event)
: await this.getMultiPayload(subscription, client.accountability, schema, event);
if (Array.isArray(result?.['data']) && result?.['data']?.length === 0) return;
client.send(fmtMessage('subscription', result, subscription.uid));
} catch (err) {
handleWebSocketException(client, err, 'subscribe');
}
}
}
/**
* Handle incoming (un)subscribe requests
*/
async onMessage(client: WebSocketClient, message: WebSocketSubscribeMessage) {
if (getMessageType(message) === 'subscribe') {
try {
const collection = String(message.collection!);
const accountability = client.accountability;
const schema = await getSchema();
if (!accountability?.admin && !schema.collections[collection]) {
throw new WebSocketException(
'subscribe',
'INVALID_COLLECTION',
'The provided collection does not exists or is not accessible.',
message.uid
);
}
const subscription: Subscription = {
client,
collection,
};
if ('event' in message) {
subscription.event = message.event as SubscriptionEvent;
}
if ('query' in message) {
subscription.query = sanitizeQuery(message.query!, accountability);
}
if ('item' in message) subscription.item = String(message.item);
if ('uid' in message) {
subscription.uid = String(message.uid);
// remove the subscription if it already exists
this.unsubscribe(client, subscription.uid);
}
let data: Record<string, any>;
if (subscription.event === undefined) {
data =
'item' in subscription
? await this.getSinglePayload(subscription, accountability, schema)
: await this.getMultiPayload(subscription, accountability, schema);
} else {
data = { event: 'init' };
}
// if no errors were thrown register the subscription
this.subscribe(subscription);
// send an initial response
client.send(fmtMessage('subscription', data, subscription.uid));
} catch (err) {
handleWebSocketException(client, err, 'subscribe');
}
}
if (getMessageType(message) === 'unsubscribe') {
try {
this.unsubscribe(client, message.uid);
client.send(fmtMessage('subscription', { event: 'unsubscribe' }, message.uid));
} catch (err) {
handleWebSocketException(client, err, 'unsubscribe');
}
}
}
private async getSinglePayload(
subscription: Subscription,
accountability: Accountability | null,
schema: SchemaOverview,
event?: WebSocketEvent
): Promise<Record<string, any>> {
const metaService = new MetaService({ schema, accountability });
const query = subscription.query ?? {};
const id = subscription.item!;
const result: Record<string, any> = {
event: event?.action ?? 'init',
};
if (subscription.collection === 'directus_collections') {
const service = new CollectionsService({ schema, accountability });
result['data'] = await service.readOne(String(id));
} else {
const service = getService(subscription.collection, { schema, accountability });
result['data'] = await service.readOne(id, query);
}
if ('meta' in query) {
result['meta'] = await metaService.getMetaForQuery(subscription.collection, query);
}
return result;
}
private async getMultiPayload(
subscription: Subscription,
accountability: Accountability | null,
schema: SchemaOverview,
event?: WebSocketEvent
): Promise<Record<string, any>> {
const metaService = new MetaService({ schema, accountability });
const result: Record<string, any> = {
event: event?.action ?? 'init',
};
switch (subscription.collection) {
case 'directus_collections':
result['data'] = await this.getCollectionPayload(accountability, schema, event);
break;
case 'directus_fields':
result['data'] = await this.getFieldsPayload(accountability, schema, event);
break;
case 'directus_relations':
result['data'] = event?.payload;
break;
default:
result['data'] = await this.getItemsPayload(subscription, accountability, schema, event);
break;
}
const query = subscription.query ?? {};
if ('meta' in query) {
result['meta'] = await metaService.getMetaForQuery(subscription.collection, query);
}
return result;
}
private async getCollectionPayload(
accountability: Accountability | null,
schema: SchemaOverview,
event?: WebSocketEvent
) {
const service = new CollectionsService({ schema, accountability });
if (!event?.action) {
return await service.readByQuery();
} else if (event.action === 'create') {
return await service.readMany([String(event.key)]);
} else if (event.action === 'delete') {
return event.keys;
} else {
return await service.readMany(event.keys.map((key: any) => String(key)));
}
}
private async getFieldsPayload(
accountability: Accountability | null,
schema: SchemaOverview,
event?: WebSocketEvent
) {
const service = new FieldsService({ schema, accountability });
if (!event?.action) {
return await service.readAll();
} else if (event.action === 'delete') {
return event.keys;
} else {
return await service.readOne(event.payload?.['collection'], event.payload?.['field']);
}
}
private async getItemsPayload(
subscription: Subscription,
accountability: Accountability | null,
schema: SchemaOverview,
event?: WebSocketEvent
) {
const query = subscription.query ?? {};
const service = getService(subscription.collection, { schema, accountability });
if (!event?.action) {
return await service.readByQuery(query);
} else if (event.action === 'create') {
return await service.readMany([event.key], query);
} else if (event.action === 'delete') {
return event.keys;
} else {
return await service.readMany(event.keys, query);
}
}
private getSubscription(client: WebSocketClient, uid: string | number) {
for (const userSubscriptions of Object.values(this.subscriptions)) {
for (const subscription of userSubscriptions) {
if (subscription.client === client && subscription.uid === uid) {
return subscription;
}
}
}
return undefined;
}
}

View File

@@ -0,0 +1,118 @@
import type { Item, Query } from '@directus/types';
import { z } from 'zod';
const zodStringOrNumber = z.union([z.string(), z.number()]);
export const WebSocketMessage = z
.object({
type: z.string(),
uid: zodStringOrNumber.optional(),
})
.passthrough();
export type WebSocketMessage = z.infer<typeof WebSocketMessage>;
export const WebSocketResponse = z.discriminatedUnion('status', [
WebSocketMessage.extend({
status: z.literal('ok'),
}),
WebSocketMessage.extend({
status: z.literal('error'),
error: z
.object({
code: z.string(),
message: z.string(),
})
.passthrough(),
}),
]);
export type WebSocketResponse = z.infer<typeof WebSocketResponse>;
export const ConnectionParams = z.object({ access_token: z.string().optional() });
export type ConnectionParams = z.infer<typeof ConnectionParams>;
export const BasicAuthMessage = z.union([
z.object({ email: z.string().email(), password: z.string() }),
z.object({ access_token: z.string() }),
z.object({ refresh_token: z.string() }),
]);
export type BasicAuthMessage = z.infer<typeof BasicAuthMessage>;
export const WebSocketAuthMessage = WebSocketMessage.extend({
type: z.literal('auth'),
}).and(BasicAuthMessage);
export type WebSocketAuthMessage = z.infer<typeof WebSocketAuthMessage>;
export const WebSocketSubscribeMessage = z.discriminatedUnion('type', [
WebSocketMessage.extend({
type: z.literal('subscribe'),
collection: z.string(),
event: z.union([z.literal('create'), z.literal('update'), z.literal('delete')]).optional(),
item: zodStringOrNumber.optional(),
query: z.custom<Query>().optional(),
}),
WebSocketMessage.extend({
type: z.literal('unsubscribe'),
}),
]);
export type WebSocketSubscribeMessage = z.infer<typeof WebSocketSubscribeMessage>;
const ZodItem = z.custom<Partial<Item>>();
const PartialItemsMessage = z.object({
uid: zodStringOrNumber.optional(),
type: z.literal('items'),
collection: z.string(),
});
export const WebSocketItemsMessage = z.union([
PartialItemsMessage.extend({
action: z.literal('create'),
data: z.union([z.array(ZodItem), ZodItem]),
query: z.custom<Query>().optional(),
}),
PartialItemsMessage.extend({
action: z.literal('read'),
ids: z.array(zodStringOrNumber).optional(),
id: zodStringOrNumber.optional(),
query: z.custom<Query>().optional(),
}),
PartialItemsMessage.extend({
action: z.literal('update'),
data: ZodItem,
ids: z.array(zodStringOrNumber).optional(),
id: zodStringOrNumber.optional(),
query: z.custom<Query>().optional(),
}),
PartialItemsMessage.extend({
action: z.literal('delete'),
ids: z.array(zodStringOrNumber).optional(),
id: zodStringOrNumber.optional(),
query: z.custom<Query>().optional(),
}),
]);
export type WebSocketItemsMessage = z.infer<typeof WebSocketItemsMessage>;
export const WebSocketEvent = z.discriminatedUnion('action', [
z.object({
action: z.literal('create'),
collection: z.string(),
payload: z.record(z.any()).optional(),
key: zodStringOrNumber,
}),
z.object({
action: z.literal('update'),
collection: z.string(),
payload: z.record(z.any()).optional(),
keys: z.array(zodStringOrNumber),
}),
z.object({
action: z.literal('delete'),
collection: z.string(),
payload: z.record(z.any()).optional(),
keys: z.array(zodStringOrNumber),
}),
]);
export type WebSocketEvent = z.infer<typeof WebSocketEvent>;
export const AuthMode = z.union([z.literal('public'), z.literal('handshake'), z.literal('strict')]);
export type AuthMode = z.infer<typeof AuthMode>;

View File

@@ -0,0 +1,35 @@
import type { Accountability, Query } from '@directus/types';
import type { IncomingMessage } from 'http';
import type internal from 'stream';
import type { WebSocket } from 'ws';
export type AuthenticationState = {
accountability: Accountability | null;
expires_at: number | null;
refresh_token?: string;
};
export type WebSocketClient = WebSocket &
AuthenticationState & { uid: string | number; auth_timer: NodeJS.Timer | null };
export type UpgradeRequest = IncomingMessage & AuthenticationState;
export type SubscriptionEvent = 'create' | 'update' | 'delete';
export type Subscription = {
uid?: string | number;
query?: Query;
item?: string | number;
event?: SubscriptionEvent;
collection: string;
client: WebSocketClient;
};
export type UpgradeContext = {
request: IncomingMessage;
socket: internal.Duplex;
head: Buffer;
};
export type GraphQLSocket = {
client: WebSocketClient;
};

View File

@@ -0,0 +1,23 @@
import jwt from 'jsonwebtoken';
import { describe, expect, test } from 'vitest';
import { getExpiresAtForToken } from './get-expires-at-for-token.js';
describe('getExpiresAtForToken', () => {
test('Returns null for non-jwt tokens', () => {
const result = getExpiresAtForToken('not-a-jwt');
expect(result).toBe(null);
});
test('Returns null for jwt with no exp field', () => {
const token = jwt.sign({ payload: 'content' }, 'secret', { issuer: 'tim' });
const result = getExpiresAtForToken(token);
expect(result).toBe(null);
});
test('Returns expiresAt field for jwt with exp as number', () => {
const now = Math.floor(Date.now() / 1000);
const token = jwt.sign({ payload: 'content' }, 'secret', { expiresIn: 42 });
const result = getExpiresAtForToken(token);
expect(result).toBeGreaterThan(now);
});
});

View File

@@ -0,0 +1,11 @@
import jwt from 'jsonwebtoken';
export function getExpiresAtForToken(token: string): number | null {
const decoded = jwt.decode(token);
if (decoded && typeof decoded === 'object' && decoded.exp) {
return decoded.exp;
}
return null;
}

View File

@@ -0,0 +1,75 @@
import { describe, expect, test, vi } from 'vitest';
import type { WebSocketClient } from '../types.js';
import { fmtMessage, getMessageType, safeSend } from './message.js';
describe('fmtMessage util', () => {
test('Returns formatted message', () => {
const result = fmtMessage('test', { test: 'abc' });
expect(result).toStrictEqual('{"type":"test","test":"abc"}');
});
test('Returns formatted message with uid', () => {
const result = fmtMessage('test', { test: 'abc' }, '123');
expect(result).toStrictEqual('{"type":"test","test":"abc","uid":"123"}');
});
});
describe('safeSend util', () => {
test('Ignore for closed connections', async () => {
const fakeClient = {
readyState: 3, // closed
OPEN: 1,
bufferedAmount: 0,
send: vi.fn(),
} as unknown as WebSocketClient;
const result = await safeSend(fakeClient, 'not used');
expect(result).toBe(false);
expect(fakeClient.send).not.toBeCalled();
});
test('Wait for buffer', async () => {
const fakeClient = {
readyState: 1, // open
OPEN: 1,
bufferedAmount: 4,
send: vi.fn(),
};
setTimeout(() => {
fakeClient.bufferedAmount = 0;
}, 10);
const result = await safeSend(fakeClient as unknown as WebSocketClient, 'a message', 20);
expect(result).toBe(true);
expect(fakeClient.send).toBeCalledWith('a message');
});
test('send message', async () => {
const fakeClient = {
readyState: 1, // open
OPEN: 1,
bufferedAmount: 0,
send: vi.fn(),
} as unknown as WebSocketClient;
const result = await safeSend(fakeClient as unknown as WebSocketClient, 'a message');
expect(result).toBe(true);
expect(fakeClient.send).toBeCalledWith('a message');
});
});
describe('getMessageType util', () => {
test('Fails graceously', () => {
expect(getMessageType(null)).toBe('');
expect(getMessageType(undefined)).toBe('');
expect(getMessageType(false)).toBe('');
expect(getMessageType(123456)).toBe('');
expect(getMessageType([])).toBe('');
});
test('Get the type property', () => {
expect(getMessageType({ type: 'test' })).toBe('test');
expect(getMessageType({ type: 123 })).toBe('123');
});
});

View File

@@ -0,0 +1,34 @@
import type { WebSocketClient } from '../types.js';
// a simple util for building a message object
export const fmtMessage = (type: string, data: Record<string, any> = {}, uid?: string | number) => {
const message: Record<string, any> = { type, ...data };
if (uid !== undefined) {
message['uid'] = uid;
}
return JSON.stringify(message);
};
// we may need this later for slow connections
export const safeSend = async (client: WebSocketClient, data: string, delay = 100) => {
if (client.readyState !== client.OPEN) return false;
if (client.bufferedAmount > 0) {
// wait for the buffer to clear
return new Promise((resolve) => {
setTimeout(() => {
safeSend(client, data, delay).then((success) => resolve(success));
}, delay);
});
}
client.send(data);
return true;
};
// an often used message type extractor function
export const getMessageType = (message: any): string => {
return typeof message !== 'object' || Array.isArray(message) || message === null ? '' : String(message.type);
};

View File

@@ -0,0 +1,94 @@
import { describe, expect, test, vi } from 'vitest';
import type { RawData, WebSocket } from 'ws';
import { waitForAnyMessage, waitForMessageType } from './wait-for-message.js';
function bufferMessage(msg: any): RawData {
return Buffer.from(JSON.stringify(msg));
}
function mockClient(handler: (callback: (event: RawData) => void) => void) {
return {
on: vi.fn().mockImplementation((type: string, callback: (event: RawData) => void) => {
if (type === 'message') handler(callback);
}),
off: vi.fn(),
} as unknown as WebSocket;
}
describe('Wait for messages', () => {
test('should succeed, 5ms delay, 10ms timeout', async () => {
const TEST_TIMEOUT = 10;
const TEST_MSG = { type: 'test', id: 1 };
const fakeClient = mockClient((callback) => {
setTimeout(() => {
callback(bufferMessage(TEST_MSG));
}, 5);
});
const msg = await waitForAnyMessage(fakeClient, TEST_TIMEOUT);
expect(msg).toStrictEqual(TEST_MSG);
});
test('should fail, 10ms delay, 5ms timeout', async () => {
const TEST_TIMEOUT = 5;
const TEST_MSG = { type: 'test', id: 1 };
const fakeClient = mockClient((callback) => {
setTimeout(() => {
callback(bufferMessage(TEST_MSG));
}, 10);
});
expect(() => waitForAnyMessage(fakeClient, TEST_TIMEOUT)).rejects.toBe(undefined);
});
test('should fail parsing', async () => {
const TEST_TIMEOUT = 5;
const fakeClient = mockClient((callback) => {
setTimeout(() => {
callback(Buffer.from('{invalid:json}'));
}, 10);
});
expect(() => waitForAnyMessage(fakeClient, TEST_TIMEOUT)).rejects.toBe(undefined);
});
});
describe('Wait for specific types messages', () => {
const MSG_A = { type: 'test', id: 1 };
const MSG_B = { type: 'other', id: 2 };
test('should find the correct message', async () => {
const fakeClient = mockClient((callback) => {
setTimeout(() => callback(bufferMessage(MSG_B)), 5);
setTimeout(() => callback(bufferMessage(MSG_A)), 10);
});
const msg = await waitForMessageType(fakeClient, 'test', 15);
expect(msg).toStrictEqual(MSG_A);
});
test('should fail, no matching type', async () => {
const fakeClient = mockClient((callback) => {
setTimeout(() => {
callback(bufferMessage(MSG_B));
}, 5);
});
expect(() => waitForMessageType(fakeClient, 'test', 10)).rejects.toBe(undefined);
});
test('should fail parsing', async () => {
const fakeClient = mockClient((callback) => {
setTimeout(() => {
callback(bufferMessage({ id: 2 }));
}, 5);
});
expect(() => waitForMessageType(fakeClient, 'test', 10)).rejects.toBe(undefined);
});
});

View File

@@ -0,0 +1,52 @@
import { parseJSON } from '@directus/utils';
import type { RawData, WebSocket } from 'ws';
import { WebSocketMessage } from '../messages.js';
import { getMessageType } from './message.js';
export const waitForAnyMessage = (client: WebSocket, timeout: number): Promise<Record<string, any>> => {
return new Promise((resolve, reject) => {
client.on('message', awaitMessage);
const timer = setTimeout(() => {
client.off('message', awaitMessage);
reject();
}, timeout);
function awaitMessage(event: RawData) {
try {
clearTimeout(timer);
client.off('message', awaitMessage);
resolve(parseJSON(event.toString()));
} catch (err) {
reject(err);
}
}
});
};
export const waitForMessageType = (client: WebSocket, type: string, timeout: number): Promise<WebSocketMessage> => {
return new Promise((resolve, reject) => {
client.on('message', awaitMessage);
const timer = setTimeout(() => {
client.off('message', awaitMessage);
reject();
}, timeout);
function awaitMessage(event: RawData) {
let msg: WebSocketMessage;
try {
msg = WebSocketMessage.parse(parseJSON(event.toString()));
} catch {
return;
}
if (getMessageType(msg) === type) {
clearTimeout(timer);
client.off('message', awaitMessage);
resolve(msg);
}
}
});
};