88logger = logging .getLogger (__name__ )
99
1010
11- def filter_valid_groups (grouped_data : xr .Dataset ) -> xr .Dataset :
11+ def filter_valid_groups (
12+ grouped_data : xr .Dataset , flag : str | None = None
13+ ) -> xr .Dataset :
1214 """
1315 Filter out groups whose `src_seq_ctr` differences are not 1.
1416
1517 Parameters
1618 ----------
1719 grouped_data : xr.Dataset
1820 Dataset with a "group" coordinate.
21+ flag : str | None
22+ Name of the flag data variable. Entries where it equals 0 are filtered out; if None, no flag filtering is applied.
1923
2024 Returns
2125 -------
@@ -42,6 +46,12 @@ def filter_valid_groups(grouped_data: xr.Dataset) -> xr.Dataset:
4246 drop = True ,
4347 )
4448
49+ if flag :
50+ filtered_data = filtered_data .where (
51+ filtered_data [flag ] != 0 ,
52+ drop = True ,
53+ )
54+
4555 return filtered_data
4656
4757
@@ -50,6 +60,7 @@ def find_groups(
5060 sequence_range : tuple ,
5161 sequence_name : str ,
5262 time_name : str ,
63+ flag : str | None ,
5364) -> xr .Dataset :
5465 """
5566 Group data based on time and sequence number values.
@@ -64,6 +75,8 @@ def find_groups(
6475 Name of the sequence variable.
6576 time_name : str
6677 Name of the time variable.
78+ flag : str | None
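79+ Name of the flag data variable. Group start times are only taken where it is nonzero, and it is passed on to `filter_valid_groups`; if None, the flag is ignored.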
6780
6881 Returns
6982 -------
@@ -82,9 +95,14 @@ def find_groups(
8295
8396 # Use sequence_name == sequence_range[0] to define the beginning of the group.
8497 # Find time at this index and use it as the beginning time for the group.
85- start_times = sorted_data [time_name ][
86- (sorted_data [sequence_name ] == sequence_range [0 ])
87- ]
98+ if flag :
99+ start_times = sorted_data [time_name ][
100+ (sorted_data [sequence_name ] == sequence_range [0 ]) & (sorted_data [flag ] != 0 )
101+ ]
102+ else :
103+ start_times = sorted_data [time_name ][
104+ (sorted_data [sequence_name ] == sequence_range [0 ])
105+ ]
88106 # Use max sequence_range to define the end of the group.
89107 end_times = sorted_data [time_name ][
90108 ([sorted_data [sequence_name ] == sequence_range [- 1 ]][- 1 ])
@@ -115,6 +133,6 @@ def find_groups(
115133 grouped_data = grouped_data .assign_coords (group = ("epoch" , group_labels ))
116134
117135 # Filter out groups with non-sequential src_seq_ctr values.
118- filtered_data = filter_valid_groups (grouped_data )
136+ filtered_data = filter_valid_groups (grouped_data , flag )
119137
120138 return filtered_data
0 commit comments