diff --git a/analytics/analytics/analytic_unit_manager.py b/analytics/analytics/analytic_unit_manager.py index ec65053..88b087f 100644 --- a/analytics/analytics/analytic_unit_manager.py +++ b/analytics/analytics/analytic_unit_manager.py @@ -48,32 +48,36 @@ class AnalyticUnitManager: self.analytic_workers[analytic_unit_id] = worker return worker + + async def __handle_analytic_task(self, task) -> dict: + """ + returns payload or None + """ + analytic_unit_id: AnalyticUnitId = task['analyticUnitId'] + + if task['type'] == 'CANCEL': + if analytic_unit_id in self.analytic_workers: + self.analytic_workers[analytic_unit_id].cancel() + return + + payload = task['payload'] + worker = self.__ensure_worker(analytic_unit_id, payload['pattern']) + data = prepare_data(payload['data']) + if task['type'] == 'LEARN': + return await worker.do_train(payload['segments'], data, payload['cache']) + elif task['type'] == 'PREDICT': + return await worker.do_predict(data, payload['cache']) + + raise ValueError('Unknown task type "%s"' % task['type']) + + async def handle_analytic_task(self, task): try: - analytic_unit_id: AnalyticUnitId = task['analyticUnitId'] - - if task['type'] == 'CANCEL': - if analytic_unit_id in self.analytic_workers: - self.analytic_workers[analytic_unit_id].cancel() - return { - 'status': 'SUCCESS' - } - - payload = task['payload'] - worker = self.__ensure_worker(analytic_unit_id, payload['pattern']) - data = prepare_data(payload['data']) - result_payload = {} - if task['type'] == 'LEARN': - result_payload = await worker.do_train(payload['segments'], data, payload['cache']) - elif task['type'] == 'PREDICT': - result_payload = await worker.do_predict(data, payload['cache']) - else: - raise ValueError('Unknown task type "%s"' % task['type']) + result_payload = await self.__handle_analytic_task(task) return { 'status': 'SUCCESS', 'payload': result_payload } - except Exception as e: error_text = traceback.format_exc() logger.error("handle_analytic_task exception: '%s'" % error_text) @@ -82,5 +86,3 @@ class AnalyticUnitManager: 'status': 'FAILED', 'error': str(e) } - - diff --git a/analytics/analytics/analytic_unit_worker.py b/analytics/analytics/analytic_unit_worker.py index 4f221fe..9cb8c94 100644 --- a/analytics/analytics/analytic_unit_worker.py +++ b/analytics/analytics/analytic_unit_worker.py @@ -30,7 +30,6 @@ class AnalyticUnitWorker: except CancelledError as e: return cache - async def do_predict(self, data: pd.DataFrame, cache: Optional[AnalyticUnitCache]) -> dict: return self._detector.predict(data, cache) diff --git a/analytics/analytics/services/server_service.py b/analytics/analytics/services/server_service.py index 2868959..0d8477a 100644 --- a/analytics/analytics/services/server_service.py +++ b/analytics/analytics/services/server_service.py @@ -17,7 +17,7 @@ class ServerMessage: self.payload = payload self.request_id = request_id - def toJSON(self): + def toJSON(self) -> dict: result = { 'method': self.method } diff --git a/server/src/controllers/analytics_controller.ts b/server/src/controllers/analytics_controller.ts index f363d37..6c0ddc3 100644 --- a/server/src/controllers/analytics_controller.ts +++ b/server/src/controllers/analytics_controller.ts @@ -13,6 +13,7 @@ import * as _ from 'lodash'; type TaskResult = any; +type PredictionResult = any; export type TaskResolver = (taskResult: TaskResult) => void; const taskResolvers = new Map(); @@ -37,16 +38,25 @@ function onTaskResult(taskResult: TaskResult) { } } +function onPredict(predictionResult: PredictionResult) { + processPredictionResult(predictionResult.analyticUnitId, predictionResult); +} + async function onMessage(message: AnalyticsMessage) { let responsePayload = null; - let resolvedMethod = false; + let methodResolved = false; if(message.method === AnalyticsMessageMethod.TASK_RESULT) { onTaskResult(message.payload); - resolvedMethod = true; + methodResolved = true; } - if(!resolvedMethod) { + if(message.method === AnalyticsMessageMethod.PREDICT) { + onPredict(message.payload); + methodResolved = true; + } + + if(!methodResolved) { throw new TypeError('Unknown method ' + message.method); } @@ -182,7 +192,7 @@ export async function runPredict(id: AnalyticUnit.AnalyticUnitId) { return []; } - let payload = processPredictionResult(id, result); + let payload = processPredictionResult(id, result.payload); // TODO: implement segments merging without removing labeled // if(segments.length > 0 && payload.segments.length > 0) { @@ -223,30 +233,29 @@ export async function deleteNonpredictedSegments(id, payload) { Segment.removeSegments(segmentsToRemove.map(s => s.id)); } -function processPredictionResult(analyticUnitId: AnalyticUnit.AnalyticUnitId, taskResult: any): { +function processPredictionResult(analyticUnitId: AnalyticUnit.AnalyticUnitId, predictionResult: PredictionResult): { lastPredictionTime: number, segments: Segment.Segment[], cache: any } { - let payload = taskResult.payload; - if (payload === undefined) { - throw new Error(`Missing payload in result: ${taskResult}`); - } - if (payload.segments === undefined || !Array.isArray(payload.segments)) { - throw new Error(`Missing segments in result or it is corrupted: ${JSON.stringify(payload)}`); + + if (predictionResult.segments === undefined || !Array.isArray(predictionResult.segments)) { + throw new Error(`Missing segments in result or it is corrupted: ${JSON.stringify(predictionResult)}`); } - if (payload.lastPredictionTime === undefined || isNaN(+payload.lastPredictionTime)) { + if (predictionResult.lastPredictionTime === undefined || isNaN(+predictionResult.lastPredictionTime)) { throw new Error( - `Missing lastPredictionTime is result or it is corrupted: ${JSON.stringify(payload)}` + `Missing lastPredictionTime is result or it is corrupted: ${JSON.stringify(predictionResult)}` ); } - let segments = payload.segments.map(segment => new Segment.Segment(analyticUnitId, segment.from, segment.to, false, false)); + let segments = predictionResult.segments.map( + segment => new Segment.Segment(analyticUnitId, segment.from, segment.to, false, false) + ); return { - lastPredictionTime: payload.lastPredictionTime, + lastPredictionTime: predictionResult.lastPredictionTime, segments: segments, - cache: payload.cache + cache: predictionResult.cache }; } diff --git a/server/src/models/analytics_message_model.ts b/server/src/models/analytics_message_model.ts index a82b7db..88039e4 100644 --- a/server/src/models/analytics_message_model.ts +++ b/server/src/models/analytics_message_model.ts @@ -1,6 +1,7 @@ export enum AnalyticsMessageMethod { TASK = 'TASK', - TASK_RESULT = 'TASK_RESULT' + TASK_RESULT = 'TASK_RESULT', + PREDICT = 'PREDICT' } export class AnalyticsMessage {