|
|
@ -168,6 +168,19 @@ class TestDataset(unittest.TestCase): |
|
|
|
except ValueError: |
|
|
|
except ValueError: |
|
|
|
self.fail('Model {} raised unexpectedly'.format(model_name)) |
|
|
|
self.fail('Model {} raised unexpectedly'.format(model_name)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_general_for_two_labeling(self): |
|
|
|
|
|
|
|
data_val = [1.0, 2.0, 5.0, 2.0, 1.0, 1.0, 3.0, 6.0, 4.0, 2.0, 1.0, 0, 0] |
|
|
|
|
|
|
|
dataframe = create_dataframe(data_val) |
|
|
|
|
|
|
|
segments = [{'_id': 'Esl7uetLhx4lCqHa', 'analyticUnitId': 'opnICRJwOmwBELK8', 'from': 1523889000001, 'to': 1523889000003, 'labeled': True, 'deleted': False}] |
|
|
|
|
|
|
|
model = models.GeneralModel() |
|
|
|
|
|
|
|
model.fit(dataframe, segments, dict()) |
|
|
|
|
|
|
|
result = len(data_val) + 1 |
|
|
|
|
|
|
|
for _ in range(2): |
|
|
|
|
|
|
|
model.do_detect(dataframe) |
|
|
|
|
|
|
|
max_pattern_index = max(model.do_detect(dataframe)) |
|
|
|
|
|
|
|
self.assertLessEqual(max_pattern_index, result) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
if __name__ == '__main__': |
|
|
|
unittest.main() |
|
|
|
unittest.main() |
|
|
|
|
|
|
|
|
|
|
|