diff --git a/analytics/analytics/services/data_service.py b/analytics/analytics/services/data_service.py index 0723fa4..9978243 100644 --- a/analytics/analytics/services/data_service.py +++ b/analytics/analytics/services/data_service.py @@ -71,13 +71,13 @@ class DataService: 'content': content } message = ServerMessage('FILE_SAVE', message_payload) - await self.server_service.send_request(message) + await self.server_service.send_request_to_server(message) async def load_file_content(self, file_descriptor: FileDescriptor) -> str: self.__check_lock(file_descriptor) message_payload = { 'filename': file_descriptor.filename } message = ServerMessage('FILE_LOAD', message_payload) - return await self.server_service.send_request(message) + return await self.server_service.send_request_to_server(message) def __check_lock(self, file_descriptor: FileDescriptor): filename = file_descriptor.filename diff --git a/analytics/analytics/services/server_service.py b/analytics/analytics/services/server_service.py index afc5db9..33ecbc2 100644 --- a/analytics/analytics/services/server_service.py +++ b/analytics/analytics/services/server_service.py @@ -7,11 +7,18 @@ import logging import json import asyncio import traceback + +import utils.concurrent + from typing import Optional logger = logging.getLogger('SERVER_SERVICE') +PARSE_MESSAGE_OR_SAVE_LOOP_INTERRUPTED = False +SERVER_SOCKET_RECV_LOOP_INTERRUPTED = False + + class ServerMessage: def __init__(self, method: str, payload: object = None, request_id: int = None): self.method = method @@ -39,61 +46,81 @@ class ServerMessage: request_id = json['requestId'] return ServerMessage(method, payload, request_id) -class ServerService: +class ServerService(utils.concurrent.AsyncZmqActor): def __init__(self): - logger.info("Binding to %s ..." % config.ZMQ_CONNECTION_STRING) - self.context = zmq.asyncio.Context() - self.socket = self.context.socket(zmq.PAIR) - self.socket.bind(config.ZMQ_CONNECTION_STRING) - self.request_next_id = 1 - self.responses = dict() - self._aiter_inited = False + super(ServerService, self).__init__() + self.__aiter_inited = False + self.__request_next_id = 1 + self.__responses = dict() + self.start() + + async def send_message_to_server(self, message: ServerMessage): + # Following message will be sent to actor's self._on_message() + # We do it cuz we created self.__server_socket in self._run() method, + # which runs in the actor's thread, not the thread we created ServerService + + # in theory, we can try to use zmq.proxy: + # zmq.proxy(self.__actor_socket, self.__server_socket) + # and do here something like: + # self.__actor_socket.send_string(json.dumps(message.toJSON())) + await self._put_message_to_thread(json.dumps(message.toJSON())) + + async def send_request_to_server(self, message: ServerMessage) -> object: + if message.request_id is not None: + raise ValueError('Message can`t have request_id before it is scheduled') + request_id = message.request_id = self.__request_next_id + self.request_next_id = self.__request_next_id + 1 + asyncio.ensure_future(self.send_message_to_server(message)) + # you should await self.__responses[request_id] which should be a task, + # which you resolve somewhere else + while request_id not in self.__responses: + await asyncio.sleep(1) + response = self.__responses[request_id] + del self.__responses[request_id] + return response def __aiter__(self): - if self._aiter_inited: + if self.__aiter_inited: raise RuntimeError('Can`t iterate twice') - _aiter_inited = True + __aiter_inited = True return self async def __anext__(self) -> ServerMessage: - while True: - received_bytes = await self.socket.recv(copy=True) - text = received_bytes.decode('utf-8') - - if text == 'PING': - asyncio.ensure_future(self.__handle_ping()) + while not PARSE_MESSAGE_OR_SAVE_LOOP_INTERRUPTED: + thread_message = await self._recv_message_from_thread() + server_message = self.__parse_message_or_save(thread_message) + if server_message is None: + continue else: - message = self.__parse_message_or_save(text) - if message is None: - continue - else: - return message + return server_message - async def send_message(self, message: ServerMessage): - await self.socket.send_string(json.dumps(message.toJSON())) + async def _run_thread(self): + logger.info("Binding to %s ..." % config.ZMQ_CONNECTION_STRING) + self.__server_socket = self._zmq_context.socket(zmq.PAIR) + self.__server_socket.bind(config.ZMQ_CONNECTION_STRING) + await self.__server_socket_recv_loop() - async def send_request(self, message: ServerMessage) -> object: - if message.request_id is not None: - raise ValueError('Message can`t have request_id before it is scheduled') - request_id = message.request_id = self.request_next_id - self.request_next_id = self.request_next_id + 1 - asyncio.ensure_future(self.send_message(message)) - while request_id not in self.responses: - await asyncio.sleep(1) - response = self.responses[request_id] - del self.responses[request_id] - return response + async def _on_message_to_thread(self, message: str): + await self.__server_socket.send_string(message) + + async def __server_socket_recv_loop(self): + while not SERVER_SOCKET_RECV_LOOP_INTERRUPTED: + received_string = await self.__server_socket.recv_string() + if received_string == 'PING': + asyncio.ensure_future(self.__handle_ping()) + else: + asyncio.ensure_future(self._send_message_from_thread(received_string)) async def __handle_ping(self): - await self.socket.send(b'PONG') + await self.__server_socket.send_string('PONG') def __parse_message_or_save(self, text: str) -> Optional[ServerMessage]: try: message_object = json.loads(text) message = ServerMessage.fromJSON(message_object) if message.request_id is not None: - self.responses[message_object['requestId']] = message.payload + self.__responses[message_object['requestId']] = message.payload return None return message except Exception: diff --git a/analytics/bin/server b/analytics/bin/server index 996c9ec..a96b0c9 100755 --- a/analytics/bin/server +++ b/analytics/bin/server @@ -52,14 +52,14 @@ async def handle_task(task: object): if not task_type == 'PUSH': message = services.server_service.ServerMessage('TASK_RESULT', task_result_payload) - await server_service.send_message(message) + await server_service.send_message_to_server(message) res = await analytic_unit_manager.handle_analytic_task(task) res['_id'] = task['_id'] if not task_type == 'PUSH': message = services.server_service.ServerMessage('TASK_RESULT', res) - await server_service.send_message(message) + await server_service.send_message_to_server(message) except Exception as e: error_text = traceback.format_exc() @@ -71,7 +71,7 @@ async def handle_data(task: object): if res['status'] == 'SUCCESS' and res['payload'] is not None: res['_id'] = task['_id'] message = services.server_service.ServerMessage('DETECT', res) - await server_service.send_message(message) + await server_service.send_message_to_server(message) async def handle_message(message: services.ServerMessage): if message.method == 'TASK':