Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 0 additions & 23 deletions crates/polars-arrow/src/io/ipc/read/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ pub fn read_record_batch<R: Read + Seek>(
version: arrow_format::ipc::MetadataVersion,
reader: &mut R,
block_offset: u64,
file_size: u64,
scratch: &mut Vec<u8>,
) -> PolarsResult<RecordBatchT<Box<dyn Array>>> {
assert_eq!(fields.len(), ipc_schema.fields.len());
Expand All @@ -101,26 +100,6 @@ pub fn read_record_batch<R: Read + Seek>(
.unwrap_or_else(VecDeque::new);
let mut buffers: VecDeque<arrow_format::ipc::BufferRef> = buffers.iter().collect();

// check that the sum of the sizes of all buffers is <= than the size of the file
let buffers_size = buffers
.iter()
.map(|buffer| {
let buffer_size: u64 = buffer
.length()
.try_into()
.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
Ok(buffer_size)
})
.sum::<PolarsResult<u64>>()?;
if buffers_size > file_size {
return Err(polars_err!(
oos = OutOfSpecKind::InvalidBuffersLength {
buffers_size,
file_size,
}
));
}

let field_nodes = batch
.nodes()
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferNodes(err)))?
Expand Down Expand Up @@ -275,7 +254,6 @@ pub fn read_dictionary<R: Read + Seek>(
dictionaries: &mut Dictionaries,
reader: &mut R,
block_offset: u64,
file_size: u64,
scratch: &mut Vec<u8>,
) -> PolarsResult<()> {
if batch
Expand Down Expand Up @@ -322,7 +300,6 @@ pub fn read_dictionary<R: Read + Seek>(
arrow_format::ipc::MetadataVersion::V5,
reader,
block_offset,
file_size,
scratch,
)?;

Expand Down
2 changes: 0 additions & 2 deletions crates/polars-arrow/src/io/ipc/read/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ fn read_dictionary_block<R: Read + Seek>(
dictionaries,
reader,
offset + length,
metadata.size,
dictionary_scratch,
)
}
Expand Down Expand Up @@ -368,7 +367,6 @@ pub fn read_batch<R: Read + Seek>(
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferVersion(err)))?,
reader,
offset + length,
metadata.size,
data_scratch,
)
}
2 changes: 0 additions & 2 deletions crates/polars-arrow/src/io/ipc/read/flight.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,6 @@ impl FlightConsumer {
// Return Batch
MessageHeaderRef::RecordBatch(batch) => {
if batch.compression()?.is_some() {
let data_size = msg.arrow_data.len() as u64;
let mut reader = std::io::Cursor::new(msg.arrow_data.as_slice());
read_record_batch(
batch,
Expand All @@ -337,7 +336,6 @@ impl FlightConsumer {
self.md.version,
&mut reader,
0,
data_size,
&mut self.scratch,
)
.map(Some)
Expand Down
5 changes: 5 additions & 0 deletions crates/polars-arrow/src/io/ipc/read/read_basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ fn read_uncompressed_bytes<R: Read + Seek>(
.take(buffer_length as u64)
.read_to_end(&mut buffer)
.unwrap();

polars_ensure!(buffer.len() == buffer_length, ComputeError: "Malformed IPC file: expected compressed buffer of len {buffer_length}, got {}", buffer.len());

Ok(buffer)
} else {
unreachable!()
Expand Down Expand Up @@ -278,6 +281,8 @@ fn read_uncompressed_bitmap<R: Read + Seek>(
.take(bytes as u64)
.read_to_end(&mut buffer)?;

polars_ensure!(buffer.len() == bytes, ComputeError: "Malformed IPC file: expected compressed buffer of len {bytes}, got {}", buffer.len());

Ok(buffer)
}

Expand Down
54 changes: 22 additions & 32 deletions crates/polars-arrow/src/io/ipc/read/stream.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::io::Read;
use std::io::{Read, Seek};

use arrow_format::ipc::planus::ReadAsRoot;
use polars_error::{PolarsError, PolarsResult, polars_bail, polars_err};
Expand Down Expand Up @@ -86,12 +86,11 @@ impl StreamState {

/// Reads the next item, yielding `None` if the stream is done,
/// and a [`StreamState`] otherwise.
fn read_next<R: Read>(
fn read_next<R: Read + Seek>(
reader: &mut R,
metadata: &StreamMetadata,
dictionaries: &mut Dictionaries,
message_buffer: &mut Vec<u8>,
data_buffer: &mut Vec<u8>,
projection: &Option<ProjectionInfo>,
scratch: &mut Vec<u8>,
) -> PolarsResult<Option<StreamState>> {
Expand Down Expand Up @@ -153,16 +152,7 @@ fn read_next<R: Read>(

match header {
arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => {
data_buffer.clear();
data_buffer.try_reserve(block_length)?;
reader
.by_ref()
.take(block_length as u64)
.read_to_end(data_buffer)?;

let file_size = data_buffer.len() as u64;

let mut reader = std::io::Cursor::new(data_buffer);
let cur_pos = reader.stream_position()?;

let chunk = read_record_batch(
batch,
Expand All @@ -172,12 +162,18 @@ fn read_next<R: Read>(
None,
dictionaries,
metadata.version,
&mut reader,
&mut (&mut *reader).take(block_length as u64),
0,
file_size,
scratch,
);

let new_pos = reader.stream_position()?;
let read_size = new_pos - cur_pos;

reader.seek(std::io::SeekFrom::Current(
block_length as i64 - read_size as i64,
))?;

if let Some(ProjectionInfo { map, .. }) = projection {
// re-order according to projection
chunk
Expand All @@ -188,34 +184,31 @@ fn read_next<R: Read>(
}
},
arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => {
data_buffer.clear();
data_buffer.try_reserve(block_length)?;
reader
.by_ref()
.take(block_length as u64)
.read_to_end(data_buffer)?;

let file_size = data_buffer.len() as u64;
let mut dict_reader = std::io::Cursor::new(&data_buffer);
let cur_pos = reader.stream_position()?;

read_dictionary(
batch,
&metadata.schema,
&metadata.ipc_schema,
dictionaries,
&mut dict_reader,
&mut (&mut *reader).take(block_length as u64),
0,
file_size,
scratch,
)?;

let new_pos = reader.stream_position()?;
let read_size = new_pos - cur_pos;

reader.seek(std::io::SeekFrom::Current(
block_length as i64 - read_size as i64,
))?;

// read the next message until we encounter a RecordBatch message
read_next(
reader,
metadata,
dictionaries,
message_buffer,
data_buffer,
projection,
scratch,
)
Expand All @@ -235,13 +228,12 @@ pub struct StreamReader<R: Read> {
metadata: StreamMetadata,
dictionaries: Dictionaries,
finished: bool,
data_buffer: Vec<u8>,
message_buffer: Vec<u8>,
projection: Option<ProjectionInfo>,
scratch: Vec<u8>,
}

impl<R: Read> StreamReader<R> {
impl<R: Read + Seek> StreamReader<R> {
/// Try to create a new stream reader
///
/// The first message in the stream is the schema, the reader will fail if it does not
Expand All @@ -256,7 +248,6 @@ impl<R: Read> StreamReader<R> {
metadata,
dictionaries: Default::default(),
finished: false,
data_buffer: Default::default(),
message_buffer: Default::default(),
projection,
scratch: Default::default(),
Expand Down Expand Up @@ -290,7 +281,6 @@ impl<R: Read> StreamReader<R> {
&self.metadata,
&mut self.dictionaries,
&mut self.message_buffer,
&mut self.data_buffer,
&self.projection,
&mut self.scratch,
)?;
Expand All @@ -301,7 +291,7 @@ impl<R: Read> StreamReader<R> {
}
}

impl<R: Read> Iterator for StreamReader<R> {
impl<R: Read + Seek> Iterator for StreamReader<R> {
type Item = PolarsResult<StreamState>;

fn next(&mut self) -> Option<Self::Item> {
Expand Down
6 changes: 4 additions & 2 deletions crates/polars-core/src/serde/df.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::io::{Read, Seek};
use std::sync::Arc;

use arrow::datatypes::Metadata;
Expand Down Expand Up @@ -72,7 +73,7 @@ impl DataFrame {
Ok(buf)
}

pub fn deserialize_from_reader(reader: &mut dyn std::io::Read) -> PolarsResult<Self> {
pub fn deserialize_from_reader<T: Read + Seek>(reader: &mut T) -> PolarsResult<Self> {
let mut md = read_stream_metadata(reader)?;

let custom_metadata = md.custom_schema_metadata.take();
Expand Down Expand Up @@ -167,7 +168,8 @@ impl<'de> Deserialize<'de> for DataFrame {
{
deserialize_map_bytes(deserializer, |b| {
let v = &mut b.as_ref();
Self::deserialize_from_reader(v)
let mut reader = std::io::Cursor::new(v);
Self::deserialize_from_reader(&mut reader)
})?
.map_err(D::Error::custom)
}
Expand Down
7 changes: 5 additions & 2 deletions crates/polars-core/src/serde/series.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::io::{Read, Seek};

use polars_utils::pl_serialize::deserialize_map_bytes;
use serde::de::Error;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
Expand All @@ -19,7 +21,7 @@ impl Series {
Ok(buf)
}

pub fn deserialize_from_reader(reader: &mut dyn std::io::Read) -> PolarsResult<Self> {
pub fn deserialize_from_reader<T: Read + Seek>(reader: &mut T) -> PolarsResult<Self> {
let df = DataFrame::deserialize_from_reader(reader)?;

if df.width() != 1 {
Expand Down Expand Up @@ -59,7 +61,8 @@ impl<'de> Deserialize<'de> for Series {
{
deserialize_map_bytes(deserializer, |b| {
let v = &mut b.as_ref();
Self::deserialize_from_reader(v)
let mut reader = std::io::Cursor::new(v);
Self::deserialize_from_reader(&mut reader)
})?
.map_err(D::Error::custom)
}
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-io/src/ipc/ipc_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
//! let df_read = IpcStreamReader::new(buf).finish().unwrap();
//! assert!(df.equals(&df_read));
//! ```
use std::io::{Read, Write};
use std::io::{Read, Seek, Write};
use std::path::PathBuf;

use arrow::datatypes::Metadata;
Expand Down Expand Up @@ -130,7 +130,7 @@ impl<R: Read> IpcStreamReader<R> {

impl<R> ArrowReader for read::StreamReader<R>
where
R: Read,
R: Read + Seek,
{
fn next_record_batch(&mut self) -> PolarsResult<Option<RecordBatch>> {
self.next().map_or(Ok(None), |v| match v {
Expand All @@ -145,7 +145,7 @@ where

impl<R> SerReader<R> for IpcStreamReader<R>
where
R: Read,
R: Read + Seek,
{
fn new(reader: R) -> Self {
IpcStreamReader {
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-python/src/series/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,8 @@ impl PySeries {
use pyo3::pybacked::PyBackedBytes;
match state.extract::<PyBackedBytes>(py) {
Ok(bytes) => py.enter_polars(|| {
*self.series.write() = Series::deserialize_from_reader(&mut &*bytes)?;
let mut reader = std::io::Cursor::new(&*bytes);
*self.series.write() = Series::deserialize_from_reader(&mut reader)?;
PolarsResult::Ok(())
}),
Err(e) => Err(e),
Expand Down
Loading