Browse Source

basic svm model

pull/25/head
Alexey Velikiy 3 years ago
parent
commit
047efbd233
  1. 215
      server/Cargo.lock
  2. 9
      server/Cargo.toml
  3. 151
      server/src/services/analytic_service/pattern_detector.rs

215
server/Cargo.lock generated

@ -37,6 +37,15 @@ version = "1.0.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61604a8f862e1d5c3229fdd78f8b02c68dcf73a4c4b05fd636d12240aaa242c1"
[[package]]
name = "approx"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f2a05fd1bd10b2527e20a2cd32d8873d115b8b39fe219ee25f42a8aca6ba278"
dependencies = [
"num-traits 0.2.14",
]
[[package]]
name = "arrayvec"
version = "0.5.2"
@ -77,6 +86,15 @@ version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd"
[[package]]
name = "bincode"
version = "1.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
dependencies = [
"serde 1.0.130",
]
[[package]]
name = "bitflags"
version = "1.3.2"
@ -238,6 +256,12 @@ dependencies = [
"generic-array",
]
[[package]]
name = "either"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457"
[[package]]
name = "fallible-iterator"
version = "0.2.0"
@ -458,15 +482,20 @@ name = "hastic"
version = "0.0.1"
dependencies = [
"anyhow",
"bincode",
"chrono",
"config",
"fastrand",
"futures",
"libsqlite3-sys",
"linfa",
"linfa-svm",
"ndarray",
"openssl",
"parking_lot",
"rusqlite",
"serde 1.0.130",
"serde_json",
"subbeat",
"tokio",
"warp",
@ -616,12 +645,30 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "itertools"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69ddb889f9d0d08a67338271fa9b62996bc788c7796a5c18cf057420aaed5eaf"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4"
[[package]]
name = "kdtree"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80ee359328fc9087e9e3fc0a4567c4dd27ec69a127d6a70e8d9dd22845b8b1a2"
dependencies = [
"num-traits 0.2.14",
]
[[package]]
name = "lazy_static"
version = "1.4.0"
@ -647,6 +694,12 @@ version = "0.2.105"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "869d572136620d55835903746bcb5cdc54cb2851fd0aeec53220b4bb65ef3013"
[[package]]
name = "libm"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7d73b3f436185384286bd8098d17ec07c9a7d2388a6599f824d8502b529702a"
[[package]]
name = "libsqlite3-sys"
version = "0.23.1"
@ -658,6 +711,64 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "linfa"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7496612871c85de124d12006109360ed7ffc4740ce3fe3fe35a749e2d7f2f88f"
dependencies = [
"approx",
"ndarray",
"num-traits 0.2.14",
"rand 0.8.4",
"thiserror",
]
[[package]]
name = "linfa-kernel"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6da781111566c8b9867c8e9c0ac69fd72808dda1e111164aa23ece9f15d6a95a"
dependencies = [
"linfa",
"linfa-nn",
"ndarray",
"num-traits 0.2.14",
"serde 1.0.130",
"sprs",
]
[[package]]
name = "linfa-nn"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83748e3f71c8fa7fc490a54d2e33d40b8246a335618159fb313989ff6eb209ef"
dependencies = [
"kdtree",
"linfa",
"ndarray",
"ndarray-stats",
"noisy_float",
"num-traits 0.2.14",
"order-stat",
"thiserror",
]
[[package]]
name = "linfa-svm"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ecfffa7a24c3019dc4c4fa9c58bf8f187ff2fdbbd08b6fee454591c6ec285f3f"
dependencies = [
"linfa",
"linfa-kernel",
"ndarray",
"ndarray-rand",
"num-traits 0.2.14",
"serde 1.0.130",
"thiserror",
]
[[package]]
name = "linked-hash-map"
version = "0.5.4"
@ -688,6 +799,15 @@ version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f"
[[package]]
name = "matrixmultiply"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a8a15b776d9dfaecd44b03c5828c2199cddff5247215858aac14624f8d6b741"
dependencies = [
"rawpointer",
]
[[package]]
name = "memchr"
version = "2.4.1"
@ -768,6 +888,56 @@ dependencies = [
"tempfile",
]
[[package]]
name = "ndarray"
version = "0.15.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08e854964160a323e65baa19a0b1a027f76d590faba01f05c0cbc3187221a8c9"
dependencies = [
"approx",
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits 0.2.14",
"rawpointer",
"serde 1.0.130",
]
[[package]]
name = "ndarray-rand"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65608f937acc725f5b164dcf40f4f0bc5d67dc268ab8a649d3002606718c4588"
dependencies = [
"ndarray",
"rand 0.8.4",
"rand_distr",
]
[[package]]
name = "ndarray-stats"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22877ad014bafa2f7dcfa5d556b0c7a52b0546cc98061a1ebef6d1834957b069"
dependencies = [
"indexmap",
"itertools",
"ndarray",
"noisy_float",
"num-integer",
"num-traits 0.2.14",
"rand 0.8.4",
]
[[package]]
name = "noisy_float"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "978fe6e6ebc0bf53de533cd456ca2d9de13de13856eda1518a285d7705a213af"
dependencies = [
"num-traits 0.2.14",
]
[[package]]
name = "nom"
version = "5.1.2"
@ -788,6 +958,15 @@ dependencies = [
"winapi",
]
[[package]]
name = "num-complex"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085"
dependencies = [
"num-traits 0.2.14",
]
[[package]]
name = "num-integer"
version = "0.1.44"
@ -814,6 +993,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290"
dependencies = [
"autocfg",
"libm",
]
[[package]]
@ -881,6 +1061,12 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "order-stat"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "efa535d5117d3661134dbf1719b6f0ffe06f2375843b13935db186cd094105eb"
[[package]]
name = "parking_lot"
version = "0.11.2"
@ -1055,6 +1241,16 @@ dependencies = [
"getrandom 0.2.3",
]
[[package]]
name = "rand_distr"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "964d548f8e7d12e102ef183a0de7e98180c9f8729f555897a857b96e48122d2f"
dependencies = [
"num-traits 0.2.14",
"rand 0.8.4",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
@ -1073,6 +1269,12 @@ dependencies = [
"rand_core 0.6.3",
]
[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "redox_syscall"
version = "0.2.10"
@ -1308,6 +1510,19 @@ dependencies = [
"winapi",
]
[[package]]
name = "sprs"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ea71e48b3eab4c4b153e8e35dcaeac132720809ef68359097b8cb54a18edd70"
dependencies = [
"ndarray",
"num-complex",
"num-traits 0.2.14",
"serde 1.0.130",
"smallvec",
]
[[package]]
name = "static_assertions"
version = "1.1.0"

9
server/Cargo.toml

@ -11,6 +11,7 @@ tokio = { version = "1", features = ["full"] }
warp = "0.3"
parking_lot = "0.11.2"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
fastrand = "1.5.0"
# subbeat = { path = "../../../subbeat/subbeat/" }
subbeat = "0.0.14"
@ -21,3 +22,11 @@ rusqlite = "0.26.1"
libsqlite3-sys = { version = "*", features = ["bundled"] }
futures = "0.3.17"
chrono = "0.4.19"
linfa = "0.5.0"
linfa-svm = { version="0.5.0", features=["serde"] }
ndarray = "0.15.3"
bincode = "1.3.3"
# TODO: remove this from prod
# plotlib = "0.5.1"

151
server/src/services/analytic_service/pattern_detector.rs

@ -1,11 +1,56 @@
#[derive(Debug, Clone)]
use std::{fmt, sync::Arc};
use parking_lot::Mutex;
use serde_json;
use serde::{Serialize, Deserialize};
use linfa::prelude::*;
use linfa;
use linfa_svm::{error::Result, Svm};
use ndarray::{Array, ArrayView, Axis};
/// Result of the learning step: the fitted classifier together with the raw
/// segments it was trained on.
#[derive(Clone)]
pub struct LearningResults {
// Previous representation of the model, kept for reference.
// model: Vec<f64>,
/// SVM classifier fitted on the feature vectors of the labelled segments.
/// Held behind `Arc<Mutex<..>>` so this struct can derive `Clone`; clones
/// share the same model instance.
/// NOTE(review): presumably needed because `Svm` itself is not `Clone`
/// (see https://github.com/rust-ml/linfa/issues/174) — confirm.
model: Arc<Mutex<Svm<f64, bool>>>,
/// Values of the segments labelled as positive examples (target `true`).
patterns: Vec<Vec<f64>>,
/// Values of the segments labelled as negative examples (target `false`).
anti_patterns: Vec<Vec<f64>>,
}
const CORR_THRESHOLD: f64 = 0.95;
// impl Clone for LearningResults {
// fn clone(&self) -> Self {
// // TODO: it's a hack
// // https://github.com/rust-ml/linfa/issues/174
// let model_str = serde_json::to_string(&self.model).unwrap();
// let model = serde_json::from_str(&model_str).unwrap();
// return LearningResults {
// model,
// patterns: self.patterns.clone(),
// anti_patterns: self.anti_patterns.clone()
// };
// }
// }
/// Manual `Debug` implementation that prints only the stored pattern data.
///
/// The `model` field is intentionally left out of the output; only
/// `patterns` and `anti_patterns` are rendered, under their real field names.
impl fmt::Debug for LearningResults {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The first argument of `field` is the label shown in the output,
        // not a format string, so it must be the field's name.
        f.debug_struct("LearningResults")
            .field("patterns", &self.patterns)
            .field("anti_patterns", &self.anti_patterns)
            .finish()
    }
}
/// Number of entries in a [`Features`] vector; also the column count of the
/// training records matrix.
const FEATURES_SIZE: usize = 6;
/// Fixed-size feature vector produced by `PatternDetector::get_features`.
type Features = [f64; FEATURES_SIZE];
/// Minimum pattern match score a candidate window must reach (after the
/// classifier adjustment) to be reported by `detect`.
const SCORE_THRESHOLD: f64 = 0.95;
#[derive(Clone)]
pub struct PatternDetector {
@ -34,21 +79,65 @@ impl PatternDetector {
let mut patterns = Vec::<Vec<f64>>::new();
let mut anti_patterns = Vec::<Vec<f64>>::new();
// TODO: implement actual learning
let mut records = Array::zeros((0, FEATURES_SIZE));
let mut targets_raw = Vec::<bool>::new();
for r in reads {
let xs: Vec<f64> = r.iter().map(|e| e.1).map(nan_to_zero).collect();
let fs = PatternDetector::get_features(&xs);
records.push_row(ArrayView::from(&fs)).unwrap();
targets_raw.push(true);
patterns.push(xs);
}
for r in anti_reads {
let xs: Vec<f64> = r.iter().map(|e| e.1).map(nan_to_zero).collect();
let fs = PatternDetector::get_features(&xs);
records.push_row(ArrayView::from(&fs)).unwrap();
targets_raw.push(false);
anti_patterns.push(xs);
}
let targets = Array::from_vec(targets_raw);
// println!("{:?}", records);
// println!("{:?}", targets);
let train = linfa::Dataset::new(records, targets);
// The 'view' describes what set of data is drawn
// let v = ContinuousView::new()
// .add(s1)
// // .add(s2)
// .x_range(-500., 100.)
// .y_range(-200., 600.)
// .x_label("Some varying variable")
// .y_label("The response of something");
// Page::single(&v).save("scatter.svg").unwrap();
// let model = stat.iter().map(|(c, v)| v / *c as f64).collect();
let model = Svm::<_, bool>::params()
.pos_neg_weights(50000., 5000.)
.gaussian_kernel(80.0)
.fit(&train).unwrap();
// let prediction = model.predict(Array::from_vec(vec![
// 715.3122807017543, 761.1228070175438, 745.0, 56.135764727158595, 0.0, 0.0
// ]));
// println!("pridiction: {}", prediction );
LearningResults {
model: Arc::new(Mutex::new(model)),
patterns,
anti_patterns,
}
@ -57,12 +146,9 @@ impl PatternDetector {
// TODO: get iterator instead of vector
pub fn detect(&self, ts: &Vec<(u64, f64)>) -> Vec<(u64, u64)> {
let mut results = Vec::new();
// let mut i = 0;
let pt = &self.learning_results.patterns;
let apt = &self.learning_results.anti_patterns;
for i in 0..ts.len() {
@ -97,7 +183,21 @@ impl PatternDetector {
}
}
if pattern_match_score > anti_pattern_match_score && pattern_match_score >= CORR_THRESHOLD {
{
let mut backet = Vec::<f64>::new();
for j in 0..pattern_match_len {
backet.push(nan_to_zero(ts[i + j].1));
}
let fs = PatternDetector::get_features(&backet);
let detected = self.learning_results.model.lock().predict(Array::from_vec(fs.to_vec()));
if detected {
pattern_match_score += 0.1;
} else {
anti_pattern_match_score += 0.1;
}
}
if pattern_match_score > anti_pattern_match_score && pattern_match_score >= SCORE_THRESHOLD {
results.push((ts[i].0, ts[i + pattern_match_len - 1].0));
}
}
@ -148,4 +248,39 @@ impl PatternDetector {
return result;
}
/// Computes the feature vector of a window of values.
///
/// The returned array is laid out as `[min, max, mean, sd, 0.0, 0.0]`, where
/// `sd` is the population standard deviation (square root of the mean squared
/// deviation from the mean). The two trailing slots are placeholders that are
/// always zero for now.
///
/// An empty window yields an all-zero vector, so no `NaN` mean or
/// `f64::MAX` / `f64::MIN` sentinels end up in the result.
fn get_features(xs: &[f64]) -> Features {
    if xs.is_empty() {
        return [0f64, 0f64, 0f64, 0f64, 0f64, 0f64];
    }

    let count = xs.len() as f64;
    let min = xs.iter().copied().fold(f64::MAX, f64::min);
    let max = xs.iter().copied().fold(f64::MIN, f64::max);
    let mean = xs.iter().sum::<f64>() / count;

    // Divide by the sample count before taking the root; otherwise the value
    // grows with the window length and is not a standard deviation.
    let variance = xs
        .iter()
        .map(|x| {
            let deviation = x - mean;
            deviation * deviation
        })
        .sum::<f64>()
        / count;
    let sd = variance.sqrt();

    // TODO: add autocorrelation
    // TODO: add FFT
    [min, max, mean, sd, 0f64, 0f64]
}
}

Loading…
Cancel
Save