diff --git a/analytics/analytics/utils/common.py b/analytics/analytics/utils/common.py index a0ae298..39bd8f5 100644 --- a/analytics/analytics/utils/common.py +++ b/analytics/analytics/utils/common.py @@ -133,12 +133,10 @@ def best_pattern(pattern_list: list, data: pd.Series, dir: str) -> list: return new_pattern_list def find_nan_indexes(segment: pd.Series) -> list: - nan_list = np.isnan(segment) - nan_indexes = [] - for i, val in enumerate(nan_list): - if val: - nan_indexes.append(i) - return nan_indexes + nan_list = pd.isnull(segment) + nan_list = np.array(nan_list) + nan_indexes = np.where(nan_list == True)[0] + return list(nan_indexes) def check_nan_values(segment: Union[pd.Series, list]) -> Union[pd.Series, list]: nan_list = utils.find_nan_indexes(segment) diff --git a/analytics/tests/test_utils.py b/analytics/tests/test_utils.py index 17d9a3f..ed4cceb 100644 --- a/analytics/tests/test_utils.py +++ b/analytics/tests/test_utils.py @@ -219,6 +219,23 @@ class TestUtils(unittest.TestCase): utils.get_av_model(patterns_list) except ValueError: self.fail('Method get_convolve raised unexpectedly') + + def test_find_nan_indexes(self): + data = [1, 1, 1, 0, 0, np.NaN, None, []] + data = pd.Series(data) + result = [5, 6] + self.assertEqual(utils.find_nan_indexes(data), result) + + def test_find_nan_indexes_normal_values(self): + data = [1, 1, 1, 0, 0, 0, 1, 1] + data = pd.Series(data) + result = [] + self.assertEqual(utils.find_nan_indexes(data), result) + + def test_find_nan_indexes_empty_values(self): + data = [] + result = [] + self.assertEqual(utils.find_nan_indexes(data), result) if __name__ == '__main__': unittest.main()