Streaming: Rework websocket server initialisation & authentication code (#28631)

th-new
Emelia Smith 2024-01-15 11:36:30 +01:00 committed by GitHub
parent e72676e83a
commit 58830be943
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 95 additions and 33 deletions

View File

@ -182,14 +182,74 @@ const CHANNEL_NAMES = [
]; ];
const startServer = async () => { const startServer = async () => {
const pgPool = new pg.Pool(pgConfigFromEnv(process.env));
const server = http.createServer();
const wss = new WebSocket.Server({ noServer: true });
// Set the X-Request-Id header on WebSockets:
wss.on("headers", function onHeaders(headers, req) {
headers.push(`X-Request-Id: ${req.id}`);
});
const app = express(); const app = express();
app.set('trust proxy', process.env.TRUSTED_PROXY_IP ? process.env.TRUSTED_PROXY_IP.split(/(?:\s*,\s*|\s+)/) : 'loopback,uniquelocal'); app.set('trust proxy', process.env.TRUSTED_PROXY_IP ? process.env.TRUSTED_PROXY_IP.split(/(?:\s*,\s*|\s+)/) : 'loopback,uniquelocal');
const pgPool = new pg.Pool(pgConfigFromEnv(process.env));
const server = http.createServer(app);
app.use(cors()); app.use(cors());
// Handle eventsource & other http requests:
server.on('request', app);
// Handle upgrade requests:
server.on('upgrade', async function handleUpgrade(request, socket, head) {
/** @param {Error} err */
const onSocketError = (err) => {
log.error(`Error with websocket upgrade: ${err}`);
};
socket.on('error', onSocketError);
// Authenticate:
try {
await accountFromRequest(request);
} catch (err) {
log.error(`Error authenticating request: ${err}`);
// Unfortunately for using the on('upgrade') setup, we need to manually
// write a HTTP Response to the Socket to close the connection upgrade
// attempt, so the following code is to handle all of that.
const statusCode = err.status ?? 401;
/** @type {Record<string, string | number>} */
const headers = {
'Connection': 'close',
'Content-Type': 'text/plain',
'Content-Length': 0,
'X-Request-Id': request.id,
// TODO: Send the error message via header so it can be debugged in
// developer tools
};
// Ensure the socket is closed once we've finished writing to it:
socket.once('finish', () => {
socket.destroy();
});
// Write the HTTP response manually:
socket.end(`HTTP/1.1 ${statusCode} ${http.STATUS_CODES[statusCode]}\r\n${Object.keys(headers).map((key) => `${key}: ${headers[key]}`).join('\r\n')}\r\n\r\n`);
return;
}
wss.handleUpgrade(request, socket, head, function done(ws) {
// Remove the error handler:
socket.removeListener('error', onSocketError);
// Start the connection:
wss.emit('connection', ws, request);
});
});
/** /**
* @type {Object.<string, Array.<function(Object<string, any>): void>>} * @type {Object.<string, Array.<function(Object<string, any>): void>>}
*/ */
@ -360,10 +420,19 @@ const startServer = async () => {
const isInScope = (req, necessaryScopes) => const isInScope = (req, necessaryScopes) =>
req.scopes.some(scope => necessaryScopes.includes(scope)); req.scopes.some(scope => necessaryScopes.includes(scope));
/**
* @typedef ResolvedAccount
* @property {string} accessTokenId
* @property {string[]} scopes
* @property {string} accountId
* @property {string[]} chosenLanguages
* @property {string} deviceId
*/
/** /**
* @param {string} token * @param {string} token
* @param {any} req * @param {any} req
* @returns {Promise.<void>} * @returns {Promise<ResolvedAccount>}
*/ */
const accountFromToken = (token, req) => new Promise((resolve, reject) => { const accountFromToken = (token, req) => new Promise((resolve, reject) => {
pgPool.connect((err, client, done) => { pgPool.connect((err, client, done) => {
@ -394,14 +463,20 @@ const startServer = async () => {
req.chosenLanguages = result.rows[0].chosen_languages; req.chosenLanguages = result.rows[0].chosen_languages;
req.deviceId = result.rows[0].device_id; req.deviceId = result.rows[0].device_id;
resolve(); resolve({
accessTokenId: result.rows[0].id,
scopes: result.rows[0].scopes.split(' '),
accountId: result.rows[0].account_id,
chosenLanguages: result.rows[0].chosen_languages,
deviceId: result.rows[0].device_id
});
}); });
}); });
}); });
/** /**
* @param {any} req * @param {any} req
* @returns {Promise.<void>} * @returns {Promise<ResolvedAccount>}
*/ */
const accountFromRequest = (req) => new Promise((resolve, reject) => { const accountFromRequest = (req) => new Promise((resolve, reject) => {
const authorization = req.headers.authorization; const authorization = req.headers.authorization;
@ -494,25 +569,6 @@ const startServer = async () => {
reject(err); reject(err);
}); });
/**
* @param {any} info
* @param {function(boolean, number, string): void} callback
*/
const wsVerifyClient = (info, callback) => {
// When verifying the websockets connection, we no longer pre-emptively
// check OAuth scopes and drop the connection if they're missing. We only
// drop the connection if access without token is not allowed by environment
// variables. OAuth scope checks are moved to the point of subscription
// to a specific stream.
accountFromRequest(info.req).then(() => {
callback(true, undefined, undefined);
}).catch(err => {
log.error(info.req.requestId, err.toString());
callback(false, 401, 'Unauthorized');
});
};
/** /**
* @typedef SystemMessageHandlers * @typedef SystemMessageHandlers
* @property {function(): void} onKill * @property {function(): void} onKill
@ -944,8 +1000,8 @@ const startServer = async () => {
}; };
/** /**
* @param {any} req * @param {http.IncomingMessage} req
* @param {any} ws * @param {WebSocket} ws
* @param {string[]} streamName * @param {string[]} streamName
* @returns {function(string, string): void} * @returns {function(string, string): void}
*/ */
@ -955,7 +1011,9 @@ const startServer = async () => {
return; return;
} }
ws.send(JSON.stringify({ stream: streamName, event, payload }), (err) => { const message = JSON.stringify({ stream: streamName, event, payload });
ws.send(message, (/** @type {Error} */ err) => {
if (err) { if (err) {
log.error(req.requestId, `Failed to send to websocket: ${err}`); log.error(req.requestId, `Failed to send to websocket: ${err}`);
} }
@ -992,8 +1050,6 @@ const startServer = async () => {
}); });
}); });
const wss = new WebSocket.Server({ server, verifyClient: wsVerifyClient });
/** /**
* @typedef StreamParams * @typedef StreamParams
* @property {string} [tag] * @property {string} [tag]
@ -1173,8 +1229,8 @@ const startServer = async () => {
/** /**
* @typedef WebSocketSession * @typedef WebSocketSession
* @property {any} socket * @property {WebSocket} websocket
* @property {any} request * @property {http.IncomingMessage} request
* @property {Object.<string, { channelName: string, listener: SubscriptionListener, stopHeartbeat: function(): void }>} subscriptions * @property {Object.<string, { channelName: string, listener: SubscriptionListener, stopHeartbeat: function(): void }>} subscriptions
*/ */
@ -1297,7 +1353,11 @@ const startServer = async () => {
} }
}; };
wss.on('connection', (ws, req) => { /**
* @param {WebSocket & { isAlive: boolean }} ws
* @param {http.IncomingMessage} req
*/
function onConnection(ws, req) {
// Note: url.parse could throw, which would terminate the connection, so we // Note: url.parse could throw, which would terminate the connection, so we
// increment the connected clients metric straight away when we establish // increment the connected clients metric straight away when we establish
// the connection, without waiting: // the connection, without waiting:
@ -1375,7 +1435,9 @@ const startServer = async () => {
if (location && location.query.stream) { if (location && location.query.stream) {
subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query); subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query);
} }
}); }
wss.on('connection', onConnection);
setInterval(() => { setInterval(() => {
wss.clients.forEach(ws => { wss.clients.forEach(ws => {