[RFC] Return 401 for an authentication error on WebSockets (#3411)

* Return 401 for an authentication error on WebSocket

* Use upgradeReq instead of a custom object
pull/3278/head^2
unarist 2017-05-30 01:20:53 +09:00 committed by Eugen Rochko
parent 5e2c5e95b6
commit 9a81be0d37
1 changed files with 48 additions and 39 deletions

View File

@ -95,7 +95,6 @@ const startWorker = (workerId) => {
const app = express(); const app = express();
const pgPool = new pg.Pool(Object.assign(pgConfigs[env], dbUrlToConfig(process.env.DATABASE_URL))); const pgPool = new pg.Pool(Object.assign(pgConfigs[env], dbUrlToConfig(process.env.DATABASE_URL)));
const server = http.createServer(app); const server = http.createServer(app);
const wss = new WebSocket.Server({ server });
const redisNamespace = process.env.REDIS_NAMESPACE || null; const redisNamespace = process.env.REDIS_NAMESPACE || null;
const redisParams = { const redisParams = {
@ -186,14 +185,10 @@ const startWorker = (workerId) => {
}); });
}; };
const authenticationMiddleware = (req, res, next) => { const accountFromRequest = (req, next) => {
if (req.method === 'OPTIONS') { const authorization = req.headers.authorization;
next(); const location = url.parse(req.url, true);
return; const accessToken = location.query.access_token;
}
const authorization = req.get('Authorization');
const accessToken = req.query.access_token;
if (!authorization && !accessToken) { if (!authorization && !accessToken) {
const err = new Error('Missing access token'); const err = new Error('Missing access token');
@ -208,6 +203,26 @@ const startWorker = (workerId) => {
accountFromToken(token, req, next); accountFromToken(token, req, next);
}; };
const wsVerifyClient = (info, cb) => {
accountFromRequest(info.req, err => {
if (!err) {
cb(true, undefined, undefined);
} else {
log.error(info.req.requestId, err.toString());
cb(false, 401, 'Unauthorized');
}
});
};
const authenticationMiddleware = (req, res, next) => {
if (req.method === 'OPTIONS') {
next();
return;
}
accountFromRequest(req, next);
};
const errorMiddleware = (err, req, res, next) => { const errorMiddleware = (err, req, res, next) => {
log.error(req.requestId, err.toString()); log.error(req.requestId, err.toString());
res.writeHead(err.statusCode || 500, { 'Content-Type': 'application/json' }); res.writeHead(err.statusCode || 500, { 'Content-Type': 'application/json' });
@ -352,10 +367,12 @@ const startWorker = (workerId) => {
streamFrom(`timeline:hashtag:${req.query.tag}:local`, req, streamToHttp(req, res), streamHttpEnd(req), true); streamFrom(`timeline:hashtag:${req.query.tag}:local`, req, streamToHttp(req, res), streamHttpEnd(req), true);
}); });
const wss = new WebSocket.Server({ server, verifyClient: wsVerifyClient });
wss.on('connection', ws => { wss.on('connection', ws => {
const location = url.parse(ws.upgradeReq.url, true); const req = ws.upgradeReq;
const token = location.query.access_token; const location = url.parse(req.url, true);
const req = { requestId: uuid.v4() }; req.requestId = uuid.v4();
ws.isAlive = true; ws.isAlive = true;
@ -363,33 +380,25 @@ const startWorker = (workerId) => {
ws.isAlive = true; ws.isAlive = true;
}); });
accountFromToken(token, req, err => { switch(location.query.stream) {
if (err) { case 'user':
log.error(req.requestId, err); streamFrom(`timeline:${req.accountId}`, req, streamToWs(req, ws), streamWsEnd(req, ws));
ws.close(); break;
return; case 'public':
} streamFrom('timeline:public', req, streamToWs(req, ws), streamWsEnd(req, ws), true);
break;
switch(location.query.stream) { case 'public:local':
case 'user': streamFrom('timeline:public:local', req, streamToWs(req, ws), streamWsEnd(req, ws), true);
streamFrom(`timeline:${req.accountId}`, req, streamToWs(req, ws), streamWsEnd(req, ws)); break;
break; case 'hashtag':
case 'public': streamFrom(`timeline:hashtag:${location.query.tag}`, req, streamToWs(req, ws), streamWsEnd(req, ws), true);
streamFrom('timeline:public', req, streamToWs(req, ws), streamWsEnd(req, ws), true); break;
break; case 'hashtag:local':
case 'public:local': streamFrom(`timeline:hashtag:${location.query.tag}:local`, req, streamToWs(req, ws), streamWsEnd(req, ws), true);
streamFrom('timeline:public:local', req, streamToWs(req, ws), streamWsEnd(req, ws), true); break;
break; default:
case 'hashtag': ws.close();
streamFrom(`timeline:hashtag:${location.query.tag}`, req, streamToWs(req, ws), streamWsEnd(req, ws), true); }
break;
case 'hashtag:local':
streamFrom(`timeline:hashtag:${location.query.tag}:local`, req, streamToWs(req, ws), streamWsEnd(req, ws), true);
break;
default:
ws.close();
}
});
}); });
const wsInterval = setInterval(() => { const wsInterval = setInterval(() => {