From 047efbd23357fdf1260d480c3bfbd98c81755d99 Mon Sep 17 00:00:00 2001 From: Alexey Velikiy Date: Sun, 7 Nov 2021 05:08:00 +0300 Subject: [PATCH] basic svm model --- server/Cargo.lock | 215 ++++++++++++++++++ server/Cargo.toml | 9 + .../analytic_service/pattern_detector.rs | 151 +++++++++++- 3 files changed, 367 insertions(+), 8 deletions(-) diff --git a/server/Cargo.lock b/server/Cargo.lock index f354c76..9567a63 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -37,6 +37,15 @@ version = "1.0.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61604a8f862e1d5c3229fdd78f8b02c68dcf73a4c4b05fd636d12240aaa242c1" +[[package]] +name = "approx" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f2a05fd1bd10b2527e20a2cd32d8873d115b8b39fe219ee25f42a8aca6ba278" +dependencies = [ + "num-traits 0.2.14", +] + [[package]] name = "arrayvec" version = "0.5.2" @@ -77,6 +86,15 @@ version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde 1.0.130", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -238,6 +256,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "either" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" + [[package]] name = "fallible-iterator" version = "0.2.0" @@ -458,15 +482,20 @@ name = "hastic" version = "0.0.1" dependencies = [ "anyhow", + "bincode", "chrono", "config", "fastrand", "futures", "libsqlite3-sys", + "linfa", + "linfa-svm", + "ndarray", "openssl", "parking_lot", "rusqlite", "serde 1.0.130", + "serde_json", "subbeat", "tokio", "warp", @@ -616,12 +645,30 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "itertools" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69ddb889f9d0d08a67338271fa9b62996bc788c7796a5c18cf057420aaed5eaf" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4" +[[package]] +name = "kdtree" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80ee359328fc9087e9e3fc0a4567c4dd27ec69a127d6a70e8d9dd22845b8b1a2" +dependencies = [ + "num-traits 0.2.14", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -647,6 +694,12 @@ version = "0.2.105" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "869d572136620d55835903746bcb5cdc54cb2851fd0aeec53220b4bb65ef3013" +[[package]] +name = "libm" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7d73b3f436185384286bd8098d17ec07c9a7d2388a6599f824d8502b529702a" + [[package]] name = "libsqlite3-sys" version = "0.23.1" @@ -658,6 +711,64 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linfa" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7496612871c85de124d12006109360ed7ffc4740ce3fe3fe35a749e2d7f2f88f" +dependencies = [ + "approx", + "ndarray", + "num-traits 0.2.14", + "rand 0.8.4", + "thiserror", +] + +[[package]] +name = "linfa-kernel" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6da781111566c8b9867c8e9c0ac69fd72808dda1e111164aa23ece9f15d6a95a" +dependencies = [ + "linfa", + "linfa-nn", + "ndarray", + "num-traits 0.2.14", + "serde 1.0.130", + "sprs", +] + +[[package]] +name = "linfa-nn" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83748e3f71c8fa7fc490a54d2e33d40b8246a335618159fb313989ff6eb209ef" +dependencies = [ + "kdtree", + "linfa", + "ndarray", + "ndarray-stats", + "noisy_float", + "num-traits 0.2.14", + "order-stat", + "thiserror", +] + +[[package]] +name = "linfa-svm" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecfffa7a24c3019dc4c4fa9c58bf8f187ff2fdbbd08b6fee454591c6ec285f3f" +dependencies = [ + "linfa", + "linfa-kernel", + "ndarray", + "ndarray-rand", + "num-traits 0.2.14", + "serde 1.0.130", + "thiserror", +] + [[package]] name = "linked-hash-map" version = "0.5.4" @@ -688,6 +799,15 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f" +[[package]] +name = "matrixmultiply" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a8a15b776d9dfaecd44b03c5828c2199cddff5247215858aac14624f8d6b741" +dependencies = [ + "rawpointer", +] + [[package]] name = "memchr" version = "2.4.1" @@ -768,6 +888,56 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.15.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08e854964160a323e65baa19a0b1a027f76d590faba01f05c0cbc3187221a8c9" +dependencies = [ + "approx", + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits 0.2.14", + "rawpointer", + "serde 1.0.130", +] + +[[package]] +name = "ndarray-rand" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65608f937acc725f5b164dcf40f4f0bc5d67dc268ab8a649d3002606718c4588" +dependencies = [ + "ndarray", + "rand 0.8.4", + "rand_distr", +] + +[[package]] +name = "ndarray-stats" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22877ad014bafa2f7dcfa5d556b0c7a52b0546cc98061a1ebef6d1834957b069" +dependencies = [ + "indexmap", + "itertools", + "ndarray", + "noisy_float", + "num-integer", + "num-traits 0.2.14", + "rand 0.8.4", +] + +[[package]] +name = "noisy_float" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978fe6e6ebc0bf53de533cd456ca2d9de13de13856eda1518a285d7705a213af" +dependencies = [ + "num-traits 0.2.14", +] + [[package]] name = "nom" version = "5.1.2" @@ -788,6 +958,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-complex" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085" +dependencies = [ + "num-traits 0.2.14", +] + [[package]] name = "num-integer" version = "0.1.44" @@ -814,6 +993,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -881,6 +1061,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "order-stat" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efa535d5117d3661134dbf1719b6f0ffe06f2375843b13935db186cd094105eb" + [[package]] name = "parking_lot" version = "0.11.2" @@ -1055,6 +1241,16 @@ dependencies = [ "getrandom 0.2.3", ] +[[package]] +name = "rand_distr" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "964d548f8e7d12e102ef183a0de7e98180c9f8729f555897a857b96e48122d2f" +dependencies = [ + "num-traits 0.2.14", + "rand 0.8.4", +] + [[package]] name = "rand_hc" version = "0.2.0" @@ -1073,6 +1269,12 @@ dependencies = [ "rand_core 0.6.3", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "redox_syscall" version = "0.2.10" @@ -1308,6 +1510,19 @@ dependencies = [ "winapi", ] +[[package]] +name = "sprs" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea71e48b3eab4c4b153e8e35dcaeac132720809ef68359097b8cb54a18edd70" +dependencies = [ + "ndarray", + "num-complex", + "num-traits 0.2.14", + "serde 1.0.130", + "smallvec", +] + [[package]] name = "static_assertions" version = "1.1.0" diff --git a/server/Cargo.toml b/server/Cargo.toml index afdbef5..3ebfe96 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -11,6 +11,7 @@ tokio = { version = "1", features = ["full"] } warp = "0.3" parking_lot = "0.11.2" serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" fastrand = "1.5.0" # subbeat = { path = "../../../subbeat/subbeat/" } subbeat = "0.0.14" @@ -21,3 +22,11 @@ rusqlite = "0.26.1" libsqlite3-sys = { version = "*", features = ["bundled"] } futures = "0.3.17" chrono = "0.4.19" + +linfa = "0.5.0" +linfa-svm = { version="0.5.0", features=["serde"] } +ndarray = "0.15.3" +bincode = "1.3.3" + +# TODO: remove this from prod +# plotlib = "0.5.1" diff --git a/server/src/services/analytic_service/pattern_detector.rs b/server/src/services/analytic_service/pattern_detector.rs index c9e522d..e35b35e 100644 --- a/server/src/services/analytic_service/pattern_detector.rs +++ b/server/src/services/analytic_service/pattern_detector.rs @@ -1,11 +1,56 @@ -#[derive(Debug, Clone)] + +use std::{fmt, sync::Arc}; + + +use parking_lot::Mutex; +use serde_json; +use serde::{Serialize, Deserialize}; + +use linfa::prelude::*; + +use linfa; +use linfa_svm::{error::Result, Svm}; + +use ndarray::{Array, ArrayView, Axis}; + + + +#[derive(Clone)] pub struct LearningResults { - // model: Vec, + model: Arc>>, patterns: Vec>, anti_patterns: Vec>, } -const CORR_THRESHOLD: f64 = 0.95; +// impl Clone for LearningResults { +// fn clone(&self) -> Self { +// // TODO: it's a hack +// // https://github.com/rust-ml/linfa/issues/174 +// let model_str = serde_json::to_string(&self.model).unwrap(); +// let model = serde_json::from_str(&model_str).unwrap(); +// return LearningResults { +// model, +// patterns: self.patterns.clone(), +// anti_patterns: self.anti_patterns.clone() +// }; +// } +// } + +impl fmt::Debug for LearningResults { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Point") + .field("{:?}", &self.patterns) + .field("{:?}", &self.anti_patterns) + .finish() + } +} + + +const FEATURES_SIZE: usize = 6; + +type Features = [f64; FEATURES_SIZE]; + +const SCORE_THRESHOLD: f64 = 0.95; #[derive(Clone)] pub struct PatternDetector { @@ -34,21 +79,65 @@ impl PatternDetector { let mut patterns = Vec::>::new(); let mut anti_patterns = Vec::>::new(); - // TODO: implement actual learning + + let mut records = Array::zeros((0, FEATURES_SIZE)); + let mut targets_raw = Vec::::new(); for r in reads { let xs: Vec = r.iter().map(|e| e.1).map(nan_to_zero).collect(); + let fs = PatternDetector::get_features(&xs); + + records.push_row(ArrayView::from(&fs)).unwrap(); + + targets_raw.push(true); patterns.push(xs); } for r in anti_reads { let xs: Vec = r.iter().map(|e| e.1).map(nan_to_zero).collect(); + let fs = PatternDetector::get_features(&xs); + records.push_row(ArrayView::from(&fs)).unwrap(); + targets_raw.push(false); anti_patterns.push(xs); } + let targets = Array::from_vec(targets_raw); + + // println!("{:?}", records); + // println!("{:?}", targets); + + let train = linfa::Dataset::new(records, targets); + + + + // The 'view' describes what set of data is drawn + // let v = ContinuousView::new() + // .add(s1) + // // .add(s2) + // .x_range(-500., 100.) + // .y_range(-200., 600.) + // .x_label("Some varying variable") + // .y_label("The response of something"); + + // Page::single(&v).save("scatter.svg").unwrap(); + + // let model = stat.iter().map(|(c, v)| v / *c as f64).collect(); + let model = Svm::<_, bool>::params() + .pos_neg_weights(50000., 5000.) + .gaussian_kernel(80.0) + .fit(&train).unwrap(); + + + // let prediction = model.predict(Array::from_vec(vec![ + // 715.3122807017543, 761.1228070175438, 745.0, 56.135764727158595, 0.0, 0.0 + // ])); + + // println!("pridiction: {}", prediction ); + LearningResults { + model: Arc::new(Mutex::new(model)), patterns, anti_patterns, } @@ -57,12 +146,9 @@ impl PatternDetector { // TODO: get iterator instead of vector pub fn detect(&self, ts: &Vec<(u64, f64)>) -> Vec<(u64, u64)> { let mut results = Vec::new(); - // let mut i = 0; - let pt = &self.learning_results.patterns; let apt = &self.learning_results.anti_patterns; - for i in 0..ts.len() { @@ -97,7 +183,21 @@ impl PatternDetector { } } - if pattern_match_score > anti_pattern_match_score && pattern_match_score >= CORR_THRESHOLD { + { + let mut backet = Vec::::new(); + for j in 0..pattern_match_len { + backet.push(nan_to_zero(ts[i + j].1)); + } + let fs = PatternDetector::get_features(&backet); + let detected = self.learning_results.model.lock().predict(Array::from_vec(fs.to_vec())); + if detected { + pattern_match_score += 0.1; + } else { + anti_pattern_match_score += 0.1; + } + } + + if pattern_match_score > anti_pattern_match_score && pattern_match_score >= SCORE_THRESHOLD { results.push((ts[i].0, ts[i + pattern_match_len - 1].0)); } } @@ -148,4 +248,39 @@ impl PatternDetector { return result; } + + fn get_features(xs: &Vec) -> Features { + let mut min = f64::MAX; + let mut max = f64::MIN; + let mut sum = 0f64; + + for x in xs { + min = min.min(*x); + max = max.max(*x); + sum += x; + } + + let mean = sum / xs.len() as f64; + + sum = 0f64; + + for x in xs { + sum += (x - mean) * (x - mean); + } + + let sd = sum.sqrt(); + + // TODO: add autocorrelation + // TODO: add FFT + + return [ + min, + max, + mean, + sd, + 0f64,0f64, + //0f64,0f64,0f64, 0f64 + ]; + } + }