Browse Source

Timeout for learning #481 (#485)

pull/1/head
rozetko 5 years ago committed by GitHub
parent
commit
9e66e8b035
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 16
      analytics/analytics/analytic_unit_worker.py
  2. 1
      analytics/analytics/config.py
  3. 1
      analytics/analytics/services/server_service.py
  4. 3
      config.example.json

16
analytics/analytics/analytic_unit_worker.py

@ -4,7 +4,7 @@ import logging
import pandas as pd
from typing import Optional, Union
from models import ModelCache
from concurrent.futures import Executor, CancelledError
from concurrent.futures import Executor, CancelledError, TimeoutError
import asyncio
logger = logging.getLogger('AnalyticUnitWorker')
@ -16,26 +16,28 @@ class AnalyticUnitWorker:
self.analytic_unit_id = analytic_unit_id
self._detector = detector
self._executor: Executor = executor
self._training_feature: asyncio.Future = None
self._training_future: asyncio.Future = None
async def do_train(
self, payload: Union[list, dict], data: pd.DataFrame, cache: Optional[ModelCache]
) -> ModelCache:
self._training_feature = asyncio.get_event_loop().run_in_executor(
self._executor, self._detector.train, data, payload, cache
self._training_future = self._executor.submit(
self._detector.train, data, payload, cache
)
try:
new_cache: ModelCache = await self._training_feature
new_cache: ModelCache = self._training_future.result(timeout = config.LEARNING_TIMEOUT)
return new_cache
except CancelledError as e:
return cache
except TimeoutError:
raise Exception('Timeout ({}s) exceeded while learning'.format(config.LEARNING_TIMEOUT))
async def do_detect(self, data: pd.DataFrame, cache: Optional[ModelCache]) -> dict:
return self._detector.detect(data, cache)
def cancel(self):
if self._training_feature is not None:
self._training_feature.cancel()
if self._training_future is not None:
self._training_future.cancel()
async def recieve_data(self, data: pd.DataFrame, cache: Optional[ModelCache]):
return self._detector.recieve_data(data, cache)

1
analytics/analytics/config.py

@ -28,3 +28,4 @@ def get_config_field(field, default_val = None):
ZMQ_DEV_PORT = get_config_field('ZMQ_DEV_PORT', '8002')
ZMQ_CONNECTION_STRING = get_config_field('ZMQ_CONNECTION_STRING', 'tcp://0.0.0.0:%s' % ZMQ_DEV_PORT)
LEARNING_TIMEOUT = get_config_field('LEARNING_TIMEOUT', 120)

1
analytics/analytics/services/server_service.py

@ -88,7 +88,6 @@ class ServerService:
self.responses[message_object['requestId']] = message.payload
return
logger.debug(message.toJSON())
asyncio.ensure_future(self.on_message_handler(message))
except Exception as e:
error_text = traceback.format_exc()

3
config.example.json

@ -4,5 +4,6 @@
"HASTIC_WEBHOOK_URL": "http://localhost:8080",
"HASTIC_WEBHOOK_TYPE": "application/x-www-form-urlencoded",
"HASTIC_WEBHOOK_SECRET": "mysecret",
"GRAFANA_URL": "http://localhost:3000"
"GRAFANA_URL": "http://localhost:3000",
"TRAIN_TIMEOUT": 120
}

Loading…
Cancel
Save