Browse Source

Timeout for learning #481 (#485)

pull/1/head
rozetko 5 years ago committed by GitHub
parent
commit
9e66e8b035
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 16
      analytics/analytics/analytic_unit_worker.py
  2. 1
      analytics/analytics/config.py
  3. 1
      analytics/analytics/services/server_service.py
  4. 3
      config.example.json

16
analytics/analytics/analytic_unit_worker.py

@ -4,7 +4,7 @@ import logging
import pandas as pd import pandas as pd
from typing import Optional, Union from typing import Optional, Union
from models import ModelCache from models import ModelCache
from concurrent.futures import Executor, CancelledError from concurrent.futures import Executor, CancelledError, TimeoutError
import asyncio import asyncio
logger = logging.getLogger('AnalyticUnitWorker') logger = logging.getLogger('AnalyticUnitWorker')
@ -16,26 +16,28 @@ class AnalyticUnitWorker:
self.analytic_unit_id = analytic_unit_id self.analytic_unit_id = analytic_unit_id
self._detector = detector self._detector = detector
self._executor: Executor = executor self._executor: Executor = executor
self._training_feature: asyncio.Future = None self._training_future: asyncio.Future = None
async def do_train( async def do_train(
self, payload: Union[list, dict], data: pd.DataFrame, cache: Optional[ModelCache] self, payload: Union[list, dict], data: pd.DataFrame, cache: Optional[ModelCache]
) -> ModelCache: ) -> ModelCache:
self._training_feature = asyncio.get_event_loop().run_in_executor( self._training_future = self._executor.submit(
self._executor, self._detector.train, data, payload, cache self._detector.train, data, payload, cache
) )
try: try:
new_cache: ModelCache = await self._training_feature new_cache: ModelCache = self._training_future.result(timeout = config.LEARNING_TIMEOUT)
return new_cache return new_cache
except CancelledError as e: except CancelledError as e:
return cache return cache
except TimeoutError:
raise Exception('Timeout ({}s) exceeded while learning'.format(config.LEARNING_TIMEOUT))
async def do_detect(self, data: pd.DataFrame, cache: Optional[ModelCache]) -> dict: async def do_detect(self, data: pd.DataFrame, cache: Optional[ModelCache]) -> dict:
return self._detector.detect(data, cache) return self._detector.detect(data, cache)
def cancel(self): def cancel(self):
if self._training_feature is not None: if self._training_future is not None:
self._training_feature.cancel() self._training_future.cancel()
async def recieve_data(self, data: pd.DataFrame, cache: Optional[ModelCache]): async def recieve_data(self, data: pd.DataFrame, cache: Optional[ModelCache]):
return self._detector.recieve_data(data, cache) return self._detector.recieve_data(data, cache)

1
analytics/analytics/config.py

@ -28,3 +28,4 @@ def get_config_field(field, default_val = None):
ZMQ_DEV_PORT = get_config_field('ZMQ_DEV_PORT', '8002') ZMQ_DEV_PORT = get_config_field('ZMQ_DEV_PORT', '8002')
ZMQ_CONNECTION_STRING = get_config_field('ZMQ_CONNECTION_STRING', 'tcp://0.0.0.0:%s' % ZMQ_DEV_PORT) ZMQ_CONNECTION_STRING = get_config_field('ZMQ_CONNECTION_STRING', 'tcp://0.0.0.0:%s' % ZMQ_DEV_PORT)
LEARNING_TIMEOUT = get_config_field('LEARNING_TIMEOUT', 120)

1
analytics/analytics/services/server_service.py

@ -88,7 +88,6 @@ class ServerService:
self.responses[message_object['requestId']] = message.payload self.responses[message_object['requestId']] = message.payload
return return
logger.debug(message.toJSON())
asyncio.ensure_future(self.on_message_handler(message)) asyncio.ensure_future(self.on_message_handler(message))
except Exception as e: except Exception as e:
error_text = traceback.format_exc() error_text = traceback.format_exc()

3
config.example.json

@ -4,5 +4,6 @@
"HASTIC_WEBHOOK_URL": "http://localhost:8080", "HASTIC_WEBHOOK_URL": "http://localhost:8080",
"HASTIC_WEBHOOK_TYPE": "application/x-www-form-urlencoded", "HASTIC_WEBHOOK_TYPE": "application/x-www-form-urlencoded",
"HASTIC_WEBHOOK_SECRET": "mysecret", "HASTIC_WEBHOOK_SECRET": "mysecret",
"GRAFANA_URL": "http://localhost:3000" "GRAFANA_URL": "http://localhost:3000",
"TRAIN_TIMEOUT": 120
} }

Loading…
Cancel
Save