Skip to content

Commit d214b4a

Browse files
committed
Refactor VECTORIZED support using time_axis in BaseRawIO
Moved VECTORIZED orientation logic to BaseRawIO as suggested by @samuelgarcia: - Added time_axis parameter to buffer_description (0=MULTIPLEXED, 1=VECTORIZED) - Extended BaseRawIO._get_analogsignal_chunk() to handle time_axis=1 for raw buffers - Removed custom _get_analogsignal_chunk() override from BrainVisionRawIO - Fixed _get_signal_size() to correctly handle raw buffers with time_axis=1 Benefits: - Cleaner, more general solution applicable to other readers - Consistent with existing HDF5 time_axis pattern - Reduced code duplication - All tests pass with identical MNE-Python validation
1 parent 9fc58a0 commit d214b4a

File tree

2 files changed

+66
-70
lines changed

2 files changed

+66
-70
lines changed

neo/rawio/baserawio.py

Lines changed: 61 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,9 +1577,14 @@ def __init__(self, *arg, **kwargs):
15771577
def _get_signal_size(self, block_index, seg_index, stream_index):
15781578
buffer_id = self.header["signal_streams"][stream_index]["buffer_id"]
15791579
buffer_desc = self.get_analogsignal_buffer_description(block_index, seg_index, buffer_id)
1580-
# some hdf5 revert teh buffer
1581-
time_axis = buffer_desc.get("time_axis", 0)
1582-
return buffer_desc["shape"][time_axis]
1580+
# For "raw" type, shape is always (time, channels) regardless of file layout
1581+
# For "hdf5" type, shape can be (time, channels) or (channels, time) based on time_axis
1582+
if buffer_desc["type"] == "raw":
1583+
return buffer_desc["shape"][0]
1584+
else:
1585+
# some hdf5 revert the buffer
1586+
time_axis = buffer_desc.get("time_axis", 0)
1587+
return buffer_desc["shape"][time_axis]
15831588

15841589
def _get_analogsignal_chunk(
15851590
self,
@@ -1603,29 +1608,61 @@ def _get_analogsignal_chunk(
16031608

16041609
if buffer_desc["type"] == "raw":
16051610

1606-
# open files on demand and keep reference to opened file
1607-
if not hasattr(self, "_memmap_analogsignal_buffers"):
1608-
self._memmap_analogsignal_buffers = {}
1609-
if block_index not in self._memmap_analogsignal_buffers:
1610-
self._memmap_analogsignal_buffers[block_index] = {}
1611-
if seg_index not in self._memmap_analogsignal_buffers[block_index]:
1612-
self._memmap_analogsignal_buffers[block_index][seg_index] = {}
1613-
if buffer_id not in self._memmap_analogsignal_buffers[block_index][seg_index]:
1614-
fid = open(buffer_desc["file_path"], mode="rb")
1615-
self._memmap_analogsignal_buffers[block_index][seg_index][buffer_id] = fid
1616-
else:
1617-
fid = self._memmap_analogsignal_buffers[block_index][seg_index][buffer_id]
1611+
time_axis = buffer_desc.get("time_axis", 0)
16181612

1619-
num_channels = buffer_desc["shape"][1]
1613+
if time_axis == 0:
1614+
# MULTIPLEXED: time_axis=0 means (time, channels) layout
1615+
# open files on demand and keep reference to opened file
1616+
if not hasattr(self, "_memmap_analogsignal_buffers"):
1617+
self._memmap_analogsignal_buffers = {}
1618+
if block_index not in self._memmap_analogsignal_buffers:
1619+
self._memmap_analogsignal_buffers[block_index] = {}
1620+
if seg_index not in self._memmap_analogsignal_buffers[block_index]:
1621+
self._memmap_analogsignal_buffers[block_index][seg_index] = {}
1622+
if buffer_id not in self._memmap_analogsignal_buffers[block_index][seg_index]:
1623+
fid = open(buffer_desc["file_path"], mode="rb")
1624+
self._memmap_analogsignal_buffers[block_index][seg_index][buffer_id] = fid
1625+
else:
1626+
fid = self._memmap_analogsignal_buffers[block_index][seg_index][buffer_id]
1627+
1628+
num_channels = buffer_desc["shape"][1]
1629+
1630+
raw_sigs = get_memmap_chunk_from_opened_file(
1631+
fid,
1632+
num_channels,
1633+
i_start,
1634+
i_stop,
1635+
np.dtype(buffer_desc["dtype"]),
1636+
file_offset=buffer_desc["file_offset"],
1637+
)
16201638

1621-
raw_sigs = get_memmap_chunk_from_opened_file(
1622-
fid,
1623-
num_channels,
1624-
i_start,
1625-
i_stop,
1626-
np.dtype(buffer_desc["dtype"]),
1627-
file_offset=buffer_desc["file_offset"],
1628-
)
1639+
elif time_axis == 1:
1640+
# VECTORIZED: time_axis=1 means (channels, time) layout
1641+
# Data is stored as [all_samples_ch1, all_samples_ch2, ...]
1642+
dtype = np.dtype(buffer_desc["dtype"])
1643+
num_channels = buffer_desc["shape"][1]
1644+
num_samples = i_stop - i_start
1645+
total_samples_per_channel = buffer_desc["shape"][0]
1646+
1647+
# Determine which channels to read
1648+
if channel_indexes is None:
1649+
chan_indices = np.arange(num_channels)
1650+
else:
1651+
chan_indices = np.asarray(channel_indexes)
1652+
1653+
raw_sigs = np.empty((num_samples, len(chan_indices)), dtype=dtype)
1654+
1655+
for i, chan_idx in enumerate(chan_indices):
1656+
offset = buffer_desc["file_offset"] + chan_idx * total_samples_per_channel * dtype.itemsize
1657+
channel_data = np.memmap(buffer_desc["file_path"], dtype=dtype, mode='r',
1658+
offset=offset, shape=(total_samples_per_channel,))
1659+
raw_sigs[:, i] = channel_data[i_start:i_stop]
1660+
1661+
# Channel slicing already done above, so skip later channel_indexes slicing
1662+
channel_indexes = None
1663+
1664+
else:
1665+
raise ValueError(f"time_axis must be 0 or 1, got {time_axis}")
16291666

16301667
elif buffer_desc["type"] == "hdf5":
16311668

neo/rawio/brainvisionrawio.py

Lines changed: 5 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,10 @@ def _parse_header(self):
9191
self._buffer_descriptions = {0: {0: {}}}
9292
self._stream_buffer_slice = {}
9393

94-
# Calculate the shape based on orientation
95-
if self._data_orientation == "MULTIPLEXED":
96-
shape = get_memmap_shape(binary_filename, sig_dtype, num_channels=nb_channel, offset=0)
97-
else: # VECTORIZED
98-
# For VECTORIZED, data is stored as [all_samples_ch1, all_samples_ch2, ...]
99-
# We still report shape as (num_samples, num_channels) for compatibility
100-
shape = get_memmap_shape(binary_filename, sig_dtype, num_channels=nb_channel, offset=0)
94+
shape = get_memmap_shape(binary_filename, sig_dtype, num_channels=nb_channel, offset=0)
95+
96+
# time_axis indicates data layout: 0 for MULTIPLEXED (time, channels), 1 for VECTORIZED (channels, time)
97+
time_axis = 0 if self._data_orientation == "MULTIPLEXED" else 1
10198

10299
self._buffer_descriptions[0][0][buffer_id] = {
103100
"type": "raw",
@@ -106,12 +103,10 @@ def _parse_header(self):
106103
"order": "C",
107104
"file_offset": 0,
108105
"shape": shape,
106+
"time_axis": time_axis,
109107
}
110108
self._stream_buffer_slice[stream_id] = None
111109

112-
# Store number of channels for VECTORIZED reading
113-
self._nb_channel = nb_channel
114-
115110
signal_buffers = np.array([("Signals", "0")], dtype=_signal_buffer_dtype)
116111
signal_streams = np.array([("Signals", "0", "0")], dtype=_signal_stream_dtype)
117112

@@ -253,42 +248,6 @@ def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index)
253248
def _get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id):
254249
return self._buffer_descriptions[block_index][seg_index][buffer_id]
255250

256-
def _get_analogsignal_chunk(
257-
self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes
258-
):
259-
"""
260-
Override to handle VECTORIZED orientation.
261-
VECTORIZED: all samples for ch1, then all samples for ch2, etc.
262-
"""
263-
if self._data_orientation == "MULTIPLEXED":
264-
return super()._get_analogsignal_chunk(
265-
block_index, seg_index, i_start, i_stop, stream_index, channel_indexes
266-
)
267-
268-
# VECTORIZED: use memmap to read each channel's data block
269-
buffer_id = self.header["signal_streams"][stream_index]["buffer_id"]
270-
buffer_desc = self.get_analogsignal_buffer_description(block_index, seg_index, buffer_id)
271-
272-
i_start = i_start or 0
273-
i_stop = i_stop or buffer_desc["shape"][0]
274-
275-
if channel_indexes is None:
276-
channel_indexes = np.arange(self._nb_channel)
277-
278-
dtype = np.dtype(buffer_desc["dtype"])
279-
num_samples = i_stop - i_start
280-
total_samples_per_channel = buffer_desc["shape"][0]
281-
282-
raw_sigs = np.empty((num_samples, len(channel_indexes)), dtype=dtype)
283-
284-
for i, chan_idx in enumerate(channel_indexes):
285-
offset = buffer_desc["file_offset"] + chan_idx * total_samples_per_channel * dtype.itemsize
286-
channel_data = np.memmap(buffer_desc["file_path"], dtype=dtype, mode='r',
287-
offset=offset, shape=(total_samples_per_channel,))
288-
raw_sigs[:, i] = channel_data[i_start:i_stop]
289-
290-
return raw_sigs
291-
292251
def _ensure_filename(self, filename, kind, entry_name):
293252
if not os.path.exists(filename):
294253
# file not found, subsequent import stage would fail

0 commit comments

Comments
 (0)