From 775c663de5ef9399d920db0be89856c7345ba4e4 Mon Sep 17 00:00:00 2001 From: Alexandr Velikiy <39257464+VargBurz@users.noreply.github.com> Date: Thu, 7 Feb 2019 21:15:22 +0300 Subject: [PATCH] KeyError 45058 #389 (#397) * fix general * test * minor fix * fix test * minor fix --- analytics/analytics/models/general_model.py | 1 + analytics/tests/test_dataset.py | 13 +++++++++++++ 2 files changed, 14 insertions(+) diff --git a/analytics/analytics/models/general_model.py b/analytics/analytics/models/general_model.py index 8465c33..176a912 100644 --- a/analytics/analytics/models/general_model.py +++ b/analytics/analytics/models/general_model.py @@ -66,6 +66,7 @@ class GeneralModel(Model): if pat_data.count(0) == len(pat_data): raise ValueError('Labeled patterns must not be empty') + self.all_conv = [] for i in range(self.state['WINDOW_SIZE'] * 2, len(data)): watch_data = data[i - self.state['WINDOW_SIZE'] * 2: i] watch_data = utils.subtract_min_without_nan(watch_data) diff --git a/analytics/tests/test_dataset.py b/analytics/tests/test_dataset.py index 487bac7..e1c7824 100644 --- a/analytics/tests/test_dataset.py +++ b/analytics/tests/test_dataset.py @@ -168,6 +168,19 @@ class TestDataset(unittest.TestCase): except ValueError: self.fail('Model {} raised unexpectedly'.format(model_name)) + def test_general_for_two_labeling(self): + data_val = [1.0, 2.0, 5.0, 2.0, 1.0, 1.0, 3.0, 6.0, 4.0, 2.0, 1.0, 0, 0] + dataframe = create_dataframe(data_val) + segments = [{'_id': 'Esl7uetLhx4lCqHa', 'analyticUnitId': 'opnICRJwOmwBELK8', 'from': 1523889000001, 'to': 1523889000003, 'labeled': True, 'deleted': False}] + model = models.GeneralModel() + model.fit(dataframe, segments, dict()) + result = len(data_val) + 1 + for _ in range(2): + model.do_detect(dataframe) + max_pattern_index = max(model.do_detect(dataframe)) + self.assertLessEqual(max_pattern_index, result) + + if __name__ == '__main__': unittest.main()