diff --git a/crates/polars-arrow/src/io/ipc/read/common.rs b/crates/polars-arrow/src/io/ipc/read/common.rs index d802d8d55803..2b12e46df6b1 100644 --- a/crates/polars-arrow/src/io/ipc/read/common.rs +++ b/crates/polars-arrow/src/io/ipc/read/common.rs @@ -86,7 +86,6 @@ pub fn read_record_batch( version: arrow_format::ipc::MetadataVersion, reader: &mut R, block_offset: u64, - file_size: u64, scratch: &mut Vec, ) -> PolarsResult>> { assert_eq!(fields.len(), ipc_schema.fields.len()); @@ -101,26 +100,6 @@ pub fn read_record_batch( .unwrap_or_else(VecDeque::new); let mut buffers: VecDeque = buffers.iter().collect(); - // check that the sum of the sizes of all buffers is <= than the size of the file - let buffers_size = buffers - .iter() - .map(|buffer| { - let buffer_size: u64 = buffer - .length() - .try_into() - .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; - Ok(buffer_size) - }) - .sum::>()?; - if buffers_size > file_size { - return Err(polars_err!( - oos = OutOfSpecKind::InvalidBuffersLength { - buffers_size, - file_size, - } - )); - } - let field_nodes = batch .nodes() .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferNodes(err)))? @@ -275,7 +254,6 @@ pub fn read_dictionary( dictionaries: &mut Dictionaries, reader: &mut R, block_offset: u64, - file_size: u64, scratch: &mut Vec, ) -> PolarsResult<()> { if batch @@ -322,7 +300,6 @@ pub fn read_dictionary( arrow_format::ipc::MetadataVersion::V5, reader, block_offset, - file_size, scratch, )?; diff --git a/crates/polars-arrow/src/io/ipc/read/file.rs b/crates/polars-arrow/src/io/ipc/read/file.rs index 9048ce9d6d32..693cfc34cec1 100644 --- a/crates/polars-arrow/src/io/ipc/read/file.rs +++ b/crates/polars-arrow/src/io/ipc/read/file.rs @@ -106,7 +106,6 @@ fn read_dictionary_block( dictionaries, reader, offset + length, - metadata.size, dictionary_scratch, ) } @@ -368,7 +367,6 @@ pub fn read_batch( .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferVersion(err)))?, reader, offset + length, - metadata.size, data_scratch, ) } diff --git a/crates/polars-arrow/src/io/ipc/read/flight.rs b/crates/polars-arrow/src/io/ipc/read/flight.rs index cace725f3da5..4614e0a5bf2f 100644 --- a/crates/polars-arrow/src/io/ipc/read/flight.rs +++ b/crates/polars-arrow/src/io/ipc/read/flight.rs @@ -325,7 +325,6 @@ impl FlightConsumer { // Return Batch MessageHeaderRef::RecordBatch(batch) => { if batch.compression()?.is_some() { - let data_size = msg.arrow_data.len() as u64; let mut reader = std::io::Cursor::new(msg.arrow_data.as_slice()); read_record_batch( batch, @@ -337,7 +336,6 @@ impl FlightConsumer { self.md.version, &mut reader, 0, - data_size, &mut self.scratch, ) .map(Some) diff --git a/crates/polars-arrow/src/io/ipc/read/read_basic.rs b/crates/polars-arrow/src/io/ipc/read/read_basic.rs index ad543e3f6422..c84d49826038 100644 --- a/crates/polars-arrow/src/io/ipc/read/read_basic.rs +++ b/crates/polars-arrow/src/io/ipc/read/read_basic.rs @@ -57,6 +57,9 @@ fn read_uncompressed_bytes( .take(buffer_length as u64) .read_to_end(&mut buffer) .unwrap(); + + polars_ensure!(buffer.len() == buffer_length, ComputeError: "Malformed IPC file: expected compressed buffer of len {buffer_length}, got {}", buffer.len()); + Ok(buffer) } else { unreachable!() @@ -278,6 +281,8 @@ fn read_uncompressed_bitmap( .take(bytes as u64) .read_to_end(&mut buffer)?; + polars_ensure!(buffer.len() == bytes, ComputeError: "Malformed IPC file: expected compressed buffer of len {bytes}, got {}", buffer.len()); + Ok(buffer) } diff --git a/crates/polars-arrow/src/io/ipc/read/stream.rs b/crates/polars-arrow/src/io/ipc/read/stream.rs index 64b8325d368e..0eb0d855473c 100644 --- a/crates/polars-arrow/src/io/ipc/read/stream.rs +++ b/crates/polars-arrow/src/io/ipc/read/stream.rs @@ -1,4 +1,4 @@ -use std::io::Read; +use std::io::{Read, Seek}; use arrow_format::ipc::planus::ReadAsRoot; use polars_error::{PolarsError, PolarsResult, polars_bail, polars_err}; @@ -86,12 +86,11 @@ impl StreamState { /// Reads the next item, yielding `None` if the stream is done, /// and a [`StreamState`] otherwise. -fn read_next( +fn read_next( reader: &mut R, metadata: &StreamMetadata, dictionaries: &mut Dictionaries, message_buffer: &mut Vec, - data_buffer: &mut Vec, projection: &Option, scratch: &mut Vec, ) -> PolarsResult> { @@ -153,16 +152,7 @@ fn read_next( match header { arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => { - data_buffer.clear(); - data_buffer.try_reserve(block_length)?; - reader - .by_ref() - .take(block_length as u64) - .read_to_end(data_buffer)?; - - let file_size = data_buffer.len() as u64; - - let mut reader = std::io::Cursor::new(data_buffer); + let cur_pos = reader.stream_position()?; let chunk = read_record_batch( batch, @@ -172,12 +162,18 @@ fn read_next( None, dictionaries, metadata.version, - &mut reader, + &mut (&mut *reader).take(block_length as u64), 0, - file_size, scratch, ); + let new_pos = reader.stream_position()?; + let read_size = new_pos - cur_pos; + + reader.seek(std::io::SeekFrom::Current( + block_length as i64 - read_size as i64, + ))?; + if let Some(ProjectionInfo { map, .. }) = projection { // re-order according to projection chunk @@ -188,34 +184,31 @@ fn read_next( } }, arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => { - data_buffer.clear(); - data_buffer.try_reserve(block_length)?; - reader - .by_ref() - .take(block_length as u64) - .read_to_end(data_buffer)?; - - let file_size = data_buffer.len() as u64; - let mut dict_reader = std::io::Cursor::new(&data_buffer); + let cur_pos = reader.stream_position()?; read_dictionary( batch, &metadata.schema, &metadata.ipc_schema, dictionaries, - &mut dict_reader, + &mut (&mut *reader).take(block_length as u64), 0, - file_size, scratch, )?; + let new_pos = reader.stream_position()?; + let read_size = new_pos - cur_pos; + + reader.seek(std::io::SeekFrom::Current( + block_length as i64 - read_size as i64, + ))?; + // read the next message until we encounter a RecordBatch message read_next( reader, metadata, dictionaries, message_buffer, - data_buffer, projection, scratch, ) @@ -235,13 +228,12 @@ pub struct StreamReader { metadata: StreamMetadata, dictionaries: Dictionaries, finished: bool, - data_buffer: Vec, message_buffer: Vec, projection: Option, scratch: Vec, } -impl StreamReader { +impl StreamReader { /// Try to create a new stream reader /// /// The first message in the stream is the schema, the reader will fail if it does not @@ -256,7 +248,6 @@ impl StreamReader { metadata, dictionaries: Default::default(), finished: false, - data_buffer: Default::default(), message_buffer: Default::default(), projection, scratch: Default::default(), @@ -290,7 +281,6 @@ impl StreamReader { &self.metadata, &mut self.dictionaries, &mut self.message_buffer, - &mut self.data_buffer, &self.projection, &mut self.scratch, )?; @@ -301,7 +291,7 @@ impl StreamReader { } } -impl Iterator for StreamReader { +impl Iterator for StreamReader { type Item = PolarsResult; fn next(&mut self) -> Option { diff --git a/crates/polars-core/src/serde/df.rs b/crates/polars-core/src/serde/df.rs index 41763968d0ba..0b1367368143 100644 --- a/crates/polars-core/src/serde/df.rs +++ b/crates/polars-core/src/serde/df.rs @@ -1,3 +1,4 @@ +use std::io::{Read, Seek}; use std::sync::Arc; use arrow::datatypes::Metadata; @@ -72,7 +73,7 @@ impl DataFrame { Ok(buf) } - pub fn deserialize_from_reader(reader: &mut dyn std::io::Read) -> PolarsResult { + pub fn deserialize_from_reader(reader: &mut T) -> PolarsResult { let mut md = read_stream_metadata(reader)?; let custom_metadata = md.custom_schema_metadata.take(); @@ -167,7 +168,8 @@ impl<'de> Deserialize<'de> for DataFrame { { deserialize_map_bytes(deserializer, |b| { let v = &mut b.as_ref(); - Self::deserialize_from_reader(v) + let mut reader = std::io::Cursor::new(v); + Self::deserialize_from_reader(&mut reader) })? .map_err(D::Error::custom) } diff --git a/crates/polars-core/src/serde/series.rs b/crates/polars-core/src/serde/series.rs index 0a9aee8a8d91..18679cb800ef 100644 --- a/crates/polars-core/src/serde/series.rs +++ b/crates/polars-core/src/serde/series.rs @@ -1,3 +1,5 @@ +use std::io::{Read, Seek}; + use polars_utils::pl_serialize::deserialize_map_bytes; use serde::de::Error; use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -19,7 +21,7 @@ impl Series { Ok(buf) } - pub fn deserialize_from_reader(reader: &mut dyn std::io::Read) -> PolarsResult { + pub fn deserialize_from_reader(reader: &mut T) -> PolarsResult { let df = DataFrame::deserialize_from_reader(reader)?; if df.width() != 1 { @@ -59,7 +61,8 @@ impl<'de> Deserialize<'de> for Series { { deserialize_map_bytes(deserializer, |b| { let v = &mut b.as_ref(); - Self::deserialize_from_reader(v) + let mut reader = std::io::Cursor::new(v); + Self::deserialize_from_reader(&mut reader) })? .map_err(D::Error::custom) } diff --git a/crates/polars-io/src/ipc/ipc_stream.rs b/crates/polars-io/src/ipc/ipc_stream.rs index 1c2e39143172..6b826939ac92 100644 --- a/crates/polars-io/src/ipc/ipc_stream.rs +++ b/crates/polars-io/src/ipc/ipc_stream.rs @@ -33,7 +33,7 @@ //! let df_read = IpcStreamReader::new(buf).finish().unwrap(); //! assert!(df.equals(&df_read)); //! ``` -use std::io::{Read, Write}; +use std::io::{Read, Seek, Write}; use std::path::PathBuf; use arrow::datatypes::Metadata; @@ -130,7 +130,7 @@ impl IpcStreamReader { impl ArrowReader for read::StreamReader where - R: Read, + R: Read + Seek, { fn next_record_batch(&mut self) -> PolarsResult> { self.next().map_or(Ok(None), |v| match v { @@ -145,7 +145,7 @@ where impl SerReader for IpcStreamReader where - R: Read, + R: Read + Seek, { fn new(reader: R) -> Self { IpcStreamReader { diff --git a/crates/polars-python/src/series/general.rs b/crates/polars-python/src/series/general.rs index 39358b084df8..d88605ded354 100644 --- a/crates/polars-python/src/series/general.rs +++ b/crates/polars-python/src/series/general.rs @@ -376,7 +376,8 @@ impl PySeries { use pyo3::pybacked::PyBackedBytes; match state.extract::(py) { Ok(bytes) => py.enter_polars(|| { - *self.series.write() = Series::deserialize_from_reader(&mut &*bytes)?; + let mut reader = std::io::Cursor::new(&*bytes); + *self.series.write() = Series::deserialize_from_reader(&mut reader)?; PolarsResult::Ok(()) }), Err(e) => Err(e),