From 9a81be0d3715eb846d940794f8b34cbbe4ba67a5 Mon Sep 17 00:00:00 2001 From: unarist Date: Tue, 30 May 2017 01:20:53 +0900 Subject: [PATCH] [RFC] Return 401 for an authentication error on WebSockets (#3411) * Return 401 for an authentication error on WebSocket * Use upgradeReq instead of a custom object --- streaming/index.js | 87 +++++++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 39 deletions(-) diff --git a/streaming/index.js b/streaming/index.js index fe39cf21df6..0411ae8efe2 100644 --- a/streaming/index.js +++ b/streaming/index.js @@ -95,7 +95,6 @@ const startWorker = (workerId) => { const app = express(); const pgPool = new pg.Pool(Object.assign(pgConfigs[env], dbUrlToConfig(process.env.DATABASE_URL))); const server = http.createServer(app); - const wss = new WebSocket.Server({ server }); const redisNamespace = process.env.REDIS_NAMESPACE || null; const redisParams = { @@ -186,14 +185,10 @@ const startWorker = (workerId) => { }); }; - const authenticationMiddleware = (req, res, next) => { - if (req.method === 'OPTIONS') { - next(); - return; - } - - const authorization = req.get('Authorization'); - const accessToken = req.query.access_token; + const accountFromRequest = (req, next) => { + const authorization = req.headers.authorization; + const location = url.parse(req.url, true); + const accessToken = location.query.access_token; if (!authorization && !accessToken) { const err = new Error('Missing access token'); @@ -208,6 +203,26 @@ const startWorker = (workerId) => { accountFromToken(token, req, next); }; + const wsVerifyClient = (info, cb) => { + accountFromRequest(info.req, err => { + if (!err) { + cb(true, undefined, undefined); + } else { + log.error(info.req.requestId, err.toString()); + cb(false, 401, 'Unauthorized'); + } + }); + }; + + const authenticationMiddleware = (req, res, next) => { + if (req.method === 'OPTIONS') { + next(); + return; + } + + accountFromRequest(req, next); + }; + const errorMiddleware = (err, req, res, next) => { log.error(req.requestId, err.toString()); res.writeHead(err.statusCode || 500, { 'Content-Type': 'application/json' }); @@ -352,10 +367,12 @@ const startWorker = (workerId) => { streamFrom(`timeline:hashtag:${req.query.tag}:local`, req, streamToHttp(req, res), streamHttpEnd(req), true); }); + const wss = new WebSocket.Server({ server, verifyClient: wsVerifyClient }); + wss.on('connection', ws => { - const location = url.parse(ws.upgradeReq.url, true); - const token = location.query.access_token; - const req = { requestId: uuid.v4() }; + const req = ws.upgradeReq; + const location = url.parse(req.url, true); + req.requestId = uuid.v4(); ws.isAlive = true; @@ -363,33 +380,25 @@ const startWorker = (workerId) => { ws.isAlive = true; }); - accountFromToken(token, req, err => { - if (err) { - log.error(req.requestId, err); - ws.close(); - return; - } - - switch(location.query.stream) { - case 'user': - streamFrom(`timeline:${req.accountId}`, req, streamToWs(req, ws), streamWsEnd(req, ws)); - break; - case 'public': - streamFrom('timeline:public', req, streamToWs(req, ws), streamWsEnd(req, ws), true); - break; - case 'public:local': - streamFrom('timeline:public:local', req, streamToWs(req, ws), streamWsEnd(req, ws), true); - break; - case 'hashtag': - streamFrom(`timeline:hashtag:${location.query.tag}`, req, streamToWs(req, ws), streamWsEnd(req, ws), true); - break; - case 'hashtag:local': - streamFrom(`timeline:hashtag:${location.query.tag}:local`, req, streamToWs(req, ws), streamWsEnd(req, ws), true); - break; - default: - ws.close(); - } - }); + switch(location.query.stream) { + case 'user': + streamFrom(`timeline:${req.accountId}`, req, streamToWs(req, ws), streamWsEnd(req, ws)); + break; + case 'public': + streamFrom('timeline:public', req, streamToWs(req, ws), streamWsEnd(req, ws), true); + break; + case 'public:local': + streamFrom('timeline:public:local', req, streamToWs(req, ws), streamWsEnd(req, ws), true); + break; + case 'hashtag': + streamFrom(`timeline:hashtag:${location.query.tag}`, req, streamToWs(req, ws), streamWsEnd(req, ws), true); + break; + case 'hashtag:local': + streamFrom(`timeline:hashtag:${location.query.tag}:local`, req, streamToWs(req, ws), streamWsEnd(req, ws), true); + break; + default: + ws.close(); + } }); const wsInterval = setInterval(() => {