--- a +++ b/qiita_pet/handlers/websocket_handlers.py @@ -0,0 +1,144 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2014--, The Qiita Development Team. +# +# Distributed under the terms of the BSD 3-clause License. +# +# The full license is in the file LICENSE, distributed with this software. +# ----------------------------------------------------------------------------- + +# adapted from +# https://github.com/leporo/tornado-redis/blob/master/demos/websockets +from json import loads, dumps +from itertools import chain + +import toredis +from tornado.web import authenticated +from tornado.websocket import WebSocketHandler +from tornado.gen import engine, Task + +from qiita_core.qiita_settings import r_client +from qiita_pet.handlers.base_handlers import BaseHandler +from qiita_db.artifact import Artifact +from qiita_core.util import execute_as_transaction + + +class MessageHandler(WebSocketHandler): + def __init__(self, *args, **kwargs): + super(MessageHandler, self).__init__(*args, **kwargs) + # The redis server + self.r_client = r_client + + # The toredis server that allows event-based message handling + self.toredis = toredis.Client() + self.toredis.connect() + + self.channel = None + self.channel_messages = None + + def get_current_user(self): + user = self.get_secure_cookie("user") + if user is None: + raise ValueError("No user associated with the websocket!") + else: + return user.strip('" ') + + # Open allows for any number arguments, unlike what pylint thinks. + # pylint: disable=W0221 + @authenticated + def open(self): + self.write_message('hello') + + @authenticated + def on_message(self, msg): + # When the websocket receives a message from the javascript client, + # parse into JSON + msginfo = loads(msg) + + # Determine which Redis communication channel the server needs to + # listen on + self.channel = msginfo.get('user', None) + + if self.channel is not None: + self.channel_messages = '%s:messages' % self.channel + self.listen() + + def listen(self): + # Attach a callback on the channel to listen too. This callback is + # executed when anything is placed onto the channel. + self.toredis.subscribe(self.channel, callback=self.callback) + + # Potential race-condition where a separate process may have placed + # messages into the queue before we've been able to attach listen. + oldmessages = self.r_client.lrange(self.channel_messages, 0, -1) + if oldmessages is not None: + for message in oldmessages: + self.write_message(message) + + def callback(self, msg): + message_type, channel, payload = msg + + # if a compute process wrote to the Redis channel that we are + # listening too, and if it is actually a message, send the payload to + # the javascript client via the websocket + if channel == self.channel and message_type == 'message': + self.write_message(payload) + + @engine + def on_close(self): + yield Task(self.toredis.unsubscribe, self.channel) + self.r_client.delete('%s:messages' % self.channel) + self.redis.disconnect() + + +class SelectedSocketHandler(WebSocketHandler, BaseHandler): + """Websocket for removing samples on default analysis display page""" + @authenticated + @execute_as_transaction + def on_message(self, msg): + # When the websocket receives a message from the javascript client, + # parse into JSON + msginfo = loads(msg) + default = self.current_user.default_analysis + + if 'remove_sample' in msginfo: + data = msginfo['remove_sample'] + artifact = Artifact(data['proc_data']) + default.remove_samples([artifact], data['samples']) + elif 'remove_pd' in msginfo: + data = msginfo['remove_pd'] + default.remove_samples([Artifact(data['proc_data'])]) + elif 'clear' in msginfo: + data = msginfo['clear'] + artifacts = [Artifact(_id) for _id in data['pids']] + default.remove_samples(artifacts) + self.write_message(msg) + + # Open allows for any number arguments, unlike what pylint thinks. + # pylint: disable=W0221 + @authenticated + @execute_as_transaction + def open(self): + self.write_message('hello') + + +class SelectSamplesHandler(WebSocketHandler, BaseHandler): + """Websocket for selecting and deselecting samples on list studies page""" + @authenticated + @execute_as_transaction + def on_message(self, msg): + """Selects samples on a message from the user + + Parameters + ---------- + msg : JSON str + Message containing sample and prc_data information, in the form + {proc_data_id': [s1, s2, ...], ...]} + """ + msginfo = loads(msg) + default = self.current_user.default_analysis + default.add_samples(msginfo['sel']) + # Count total number of unique samples selected and return + self.write_message(dumps({ + 'sel': len(set( + chain.from_iterable(s for s in msginfo['sel'].values()))) + }))