Skip to content

Commit 91986a5

Browse files
committed
add flag and time
1 parent f386aa3 commit 91986a5

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

imap_processing/ialirt/utils/grouping.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,18 @@
88
logger = logging.getLogger(__name__)
99

1010

11-
def filter_valid_groups(grouped_data: xr.Dataset) -> xr.Dataset:
11+
def filter_valid_groups(
12+
grouped_data: xr.Dataset, flag: str | None = None
13+
) -> xr.Dataset:
1214
"""
1315
Filter out groups where `src_seq_ctr` diff are not 1.
1416
1517
Parameters
1618
----------
1719
grouped_data : xr.Dataset
1820
Dataset with a "group" coordinate.
21+
flag : str | None
22+
Name of flag data variable.
1923
2024
Returns
2125
-------
@@ -42,6 +46,12 @@ def filter_valid_groups(grouped_data: xr.Dataset) -> xr.Dataset:
4246
drop=True,
4347
)
4448

49+
if flag:
50+
filtered_data = filtered_data.where(
51+
filtered_data[flag] != 0,
52+
drop=True,
53+
)
54+
4555
return filtered_data
4656

4757

@@ -50,6 +60,7 @@ def find_groups(
5060
sequence_range: tuple,
5161
sequence_name: str,
5262
time_name: str,
63+
flag: str | None,
5364
) -> xr.Dataset:
5465
"""
5566
Group data based on time and sequence number values.
@@ -64,6 +75,8 @@ def find_groups(
6475
Name of the sequence variable.
6576
time_name : str
6677
Name of the time variable.
78+
flag : str | None
79+
Name of flag data variable.
6780
6881
Returns
6982
-------
@@ -82,9 +95,14 @@ def find_groups(
8295

8396
# Use sequence_range == 0 to define the beginning of the group.
8497
# Find time at this index and use it as the beginning time for the group.
85-
start_times = sorted_data[time_name][
86-
(sorted_data[sequence_name] == sequence_range[0])
87-
]
98+
if flag:
99+
start_times = sorted_data[time_name][
100+
(sorted_data[sequence_name] == sequence_range[0]) & (sorted_data[flag] != 0)
101+
]
102+
else:
103+
start_times = sorted_data[time_name][
104+
(sorted_data[sequence_name] == sequence_range[0])
105+
]
88106
# Use max sequence_range to define the end of the group.
89107
end_times = sorted_data[time_name][
90108
([sorted_data[sequence_name] == sequence_range[-1]][-1])
@@ -115,6 +133,6 @@ def find_groups(
115133
grouped_data = grouped_data.assign_coords(group=("epoch", group_labels))
116134

117135
# Filter out groups with non-sequential src_seq_ctr values.
118-
filtered_data = filter_valid_groups(grouped_data)
136+
filtered_data = filter_valid_groups(grouped_data, flag)
119137

120138
return filtered_data

0 commit comments

Comments
 (0)