diff --git a/client/src/views/Home.vue b/client/src/views/Home.vue index 0f48f60..ce3434d 100644 --- a/client/src/views/Home.vue +++ b/client/src/views/Home.vue @@ -5,7 +5,8 @@
- Hold
ctrl
to label patterns
+ Hold
S
to label patterns
+ Hold
A
to label anti patterns
Holde key
D
to delete patterns
diff --git a/server/src/api/analytics.rs b/server/src/api/analytics.rs index a18213c..8b156e0 100644 --- a/server/src/api/analytics.rs +++ b/server/src/api/analytics.rs @@ -7,7 +7,9 @@ pub mod filters { pub fn filters( client: Client, ) -> impl Filter + Clone { - list(client.clone()).or(status(client.clone())) + list(client.clone()) + .or(status(client.clone())) + .or(list_train(client.clone())) // .or(create(db.clone())) // // .or(update(db.clone())) // .or(delete(db.clone())) @@ -34,6 +36,16 @@ pub mod filters { .and_then(handlers::status) } + /// GET /analytics/model + pub fn list_train( + client: Client, + ) -> impl Filter + Clone { + warp::path!("analytics" / "model") + .and(warp::get()) + .and(with_client(client)) + .and_then(handlers::list_train) + } + fn with_client( client: Client, ) -> impl Filter + Clone { @@ -69,6 +81,16 @@ mod handlers { } } } + + pub async fn list_train(client: Client) -> Result { + match client.get_train().await { + Ok(lt) => Ok(API::json(<)), + Err(e) => { + println!("{:?}", e); + Err(warp::reject::custom(BadQuery)) + } + } + } } mod models { diff --git a/server/src/services/analytic_service/analytic_client.rs b/server/src/services/analytic_service/analytic_client.rs index 60b0d74..32654de 100644 --- a/server/src/services/analytic_service/analytic_client.rs +++ b/server/src/services/analytic_service/analytic_client.rs @@ -3,8 +3,10 @@ use tokio::sync::oneshot; use crate::services::segments_service::Segment; +use super::pattern_detector::LearningResults; use super::types::DetectionTask; use super::types::LearningStatus; +use super::types::LearningTrain; use super::types::{AnalyticServiceMessage, RequestType}; /// Client to be used multithreaded @@ -33,6 +35,14 @@ impl AnalyticClient { Ok(r) } + pub async fn get_train(&self) -> anyhow::Result { + let (tx, rx) = oneshot::channel(); + let req = AnalyticServiceMessage::Request(RequestType::GetLearningTrain(tx)); + self.tx.send(req).await?; + let r = rx.await?; + Ok(r) + } + pub async fn get_pattern_detection(&self, from: u64, to: u64) -> anyhow::Result> { let (tx, rx) = oneshot::channel(); let req = AnalyticServiceMessage::Request(RequestType::RunDetection(DetectionTask { diff --git a/server/src/services/analytic_service/analytic_service.rs b/server/src/services/analytic_service/analytic_service.rs index 3a66900..a929f95 100644 --- a/server/src/services/analytic_service/analytic_service.rs +++ b/server/src/services/analytic_service/analytic_service.rs @@ -1,4 +1,4 @@ -use super::types::{self, DetectionRunnerConfig}; +use super::types::{self, DetectionRunnerConfig, LearningTrain}; use super::{ analytic_client::AnalyticClient, pattern_detector::{self, LearningResults, PatternDetector}, @@ -13,7 +13,6 @@ use crate::utils::{self, get_random_str}; use anyhow; -use subbeat::metric::MetricResult; use tokio::sync::{mpsc, oneshot}; use futures::future; @@ -167,6 +166,20 @@ impl AnalyticService { RequestType::GetStatus(tx) => { tx.send(self.learning_status.clone()).unwrap(); } + RequestType::GetLearningTrain(tx) => { + if self.learning_results.is_none() { + tx.send(LearningTrain::default()).unwrap(); + } else { + tx.send( + self.learning_results + .as_ref() + .unwrap() + .learning_train + .clone(), + ) + .unwrap(); + } + } }; } @@ -237,7 +250,10 @@ impl AnalyticService { // be careful if decide to store detections in db let segments = ss.get_segments_inside(0, u64::MAX / 2).unwrap(); - let has_segments_label = segments.iter().find(|s| s.segment_type == SegmentType::Label).is_some(); + let has_segments_label = segments + .iter() + .find(|s| s.segment_type == SegmentType::Label) + .is_some(); if !has_segments_label { match tx @@ -257,7 +273,7 @@ impl AnalyticService { let mut learn_tss = Vec::new(); let mut learn_anti_tss = Vec::new(); - + for r in rs { if r.is_err() { println!("Error extracting metrics from datasource"); @@ -274,7 +290,7 @@ impl AnalyticService { } let sd = r.unwrap(); - if sd.data.is_empty() { + if sd.data.is_empty() { continue; } if sd.label { @@ -294,7 +310,6 @@ impl AnalyticService { Ok(_) => {} Err(_e) => println!("Fail to send learning results"), } - } async fn get_pattern_detection( diff --git a/server/src/services/analytic_service/pattern_detector.rs b/server/src/services/analytic_service/pattern_detector.rs index e35b35e..02e8425 100644 --- a/server/src/services/analytic_service/pattern_detector.rs +++ b/server/src/services/analytic_service/pattern_detector.rs @@ -1,10 +1,8 @@ - use std::{fmt, sync::Arc}; - use parking_lot::Mutex; +use serde::{Deserialize, Serialize}; use serde_json; -use serde::{Serialize, Deserialize}; use linfa::prelude::*; @@ -13,11 +11,14 @@ use linfa_svm::{error::Result, Svm}; use ndarray::{Array, ArrayView, Axis}; - +use super::types::LearningTrain; #[derive(Clone)] pub struct LearningResults { model: Arc>>, + + pub learning_train: LearningTrain, + patterns: Vec>, anti_patterns: Vec>, } @@ -39,18 +40,17 @@ pub struct LearningResults { impl fmt::Debug for LearningResults { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Point") - .field("{:?}", &self.patterns) - .field("{:?}", &self.anti_patterns) - .finish() + .field("{:?}", &self.patterns) + .field("{:?}", &self.anti_patterns) + .finish() } } +pub const FEATURES_SIZE: usize = 2; -const FEATURES_SIZE: usize = 6; +pub type Features = [f64; FEATURES_SIZE]; -type Features = [f64; FEATURES_SIZE]; - -const SCORE_THRESHOLD: f64 = 0.95; +pub const SCORE_THRESHOLD: f64 = 0.95; #[derive(Clone)] pub struct PatternDetector { @@ -79,16 +79,14 @@ impl PatternDetector { let mut patterns = Vec::>::new(); let mut anti_patterns = Vec::>::new(); - - let mut records = Array::zeros((0, FEATURES_SIZE)); + let mut records_raw = Vec::::new(); let mut targets_raw = Vec::::new(); for r in reads { let xs: Vec = r.iter().map(|e| e.1).map(nan_to_zero).collect(); let fs = PatternDetector::get_features(&xs); - - records.push_row(ArrayView::from(&fs)).unwrap(); - + + records_raw.push(fs); targets_raw.push(true); patterns.push(xs); } @@ -96,20 +94,22 @@ impl PatternDetector { for r in anti_reads { let xs: Vec = r.iter().map(|e| e.1).map(nan_to_zero).collect(); let fs = PatternDetector::get_features(&xs); - records.push_row(ArrayView::from(&fs)).unwrap(); + records_raw.push(fs); targets_raw.push(false); anti_patterns.push(xs); } - let targets = Array::from_vec(targets_raw); + let records = Array::from_shape_fn((records_raw.len(), FEATURES_SIZE), |(i, j)| { + records_raw[i][j] + }); + + let targets = Array::from_vec(targets_raw.clone()); // println!("{:?}", records); // println!("{:?}", targets); let train = linfa::Dataset::new(records, targets); - - // The 'view' describes what set of data is drawn // let v = ContinuousView::new() // .add(s1) @@ -118,18 +118,17 @@ impl PatternDetector { // .y_range(-200., 600.) // .x_label("Some varying variable") // .y_label("The response of something"); - - // Page::single(&v).save("scatter.svg").unwrap(); + // Page::single(&v).save("scatter.svg").unwrap(); // let model = stat.iter().map(|(c, v)| v / *c as f64).collect(); let model = Svm::<_, bool>::params() .pos_neg_weights(50000., 5000.) .gaussian_kernel(80.0) - .fit(&train).unwrap(); - - + .fit(&train) + .unwrap(); + // let prediction = model.predict(Array::from_vec(vec![ // 715.3122807017543, 761.1228070175438, 745.0, 56.135764727158595, 0.0, 0.0 // ])); @@ -138,6 +137,12 @@ impl PatternDetector { LearningResults { model: Arc::new(Mutex::new(model)), + + learning_train: LearningTrain { + features: records_raw, + target: targets_raw, + }, + patterns, anti_patterns, } @@ -149,9 +154,8 @@ impl PatternDetector { let pt = &self.learning_results.patterns; let apt = &self.learning_results.anti_patterns; - - for i in 0..ts.len() { + for i in 0..ts.len() { let mut pattern_match_score = 0f64; let mut pattern_match_len = 0usize; let mut anti_pattern_match_score = 0f64; @@ -183,13 +187,17 @@ impl PatternDetector { } } - { + { let mut backet = Vec::::new(); for j in 0..pattern_match_len { backet.push(nan_to_zero(ts[i + j].1)); } let fs = PatternDetector::get_features(&backet); - let detected = self.learning_results.model.lock().predict(Array::from_vec(fs.to_vec())); + let detected = self + .learning_results + .model + .lock() + .predict(Array::from_vec(fs.to_vec())); if detected { pattern_match_score += 0.1; } else { @@ -197,7 +205,9 @@ impl PatternDetector { } } - if pattern_match_score > anti_pattern_match_score && pattern_match_score >= SCORE_THRESHOLD { + if pattern_match_score > anti_pattern_match_score + && pattern_match_score >= SCORE_THRESHOLD + { results.push((ts[i].0, ts[i + pattern_match_len - 1].0)); } } @@ -253,7 +263,7 @@ impl PatternDetector { let mut min = f64::MAX; let mut max = f64::MIN; let mut sum = 0f64; - + for x in xs { min = min.min(*x); max = max.max(*x); @@ -272,15 +282,15 @@ impl PatternDetector { // TODO: add autocorrelation // TODO: add FFT + // TODO: add DWT return [ min, - max, - mean, - sd, - 0f64,0f64, - //0f64,0f64,0f64, 0f64 + max, + // mean, + // sd, + // 0f64,0f64, + // 0f64,0f64,0f64, 0f64 ]; } - } diff --git a/server/src/services/analytic_service/types.rs b/server/src/services/analytic_service/types.rs index fdcbb7e..0f0c8c5 100644 --- a/server/src/services/analytic_service/types.rs +++ b/server/src/services/analytic_service/types.rs @@ -1,10 +1,10 @@ use crate::services::segments_service::Segment; -use super::pattern_detector::LearningResults; +use super::pattern_detector::{self, LearningResults, PatternDetector}; use anyhow::Result; use serde::Serialize; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::oneshot; #[derive(Debug, Clone, PartialEq, Serialize)] pub enum LearningStatus { @@ -15,6 +15,21 @@ pub enum LearningStatus { Ready, } +#[derive(Clone, Serialize, Debug)] +pub struct LearningTrain { + pub features: Vec, + pub target: Vec, +} + +impl Default for LearningTrain { + fn default() -> Self { + return LearningTrain { + features: Vec::new(), + target: Vec::new(), + }; + } +} + #[derive(Debug)] pub enum ResponseType { LearningStarted, @@ -43,6 +58,7 @@ pub enum RequestType { RunLearning, RunDetection(DetectionTask), GetStatus(oneshot::Sender), + GetLearningTrain(oneshot::Sender), } #[derive(Debug)]