diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc
index 046eacb6ced2..7bd2abcdaa87 100644
--- a/cpp/src/arrow/ipc/reader.cc
+++ b/cpp/src/arrow/ipc/reader.cc
@@ -900,15 +900,6 @@ Status ReadDictionary(const Buffer& metadata, const IpcReadContext& context,
   return Status::OK();
 }
 
-Status ReadDictionary(const Message& message, const IpcReadContext& context,
-                      DictionaryKind* kind) {
-  // Only invoke this method if we already know we have a dictionary message
-  DCHECK_EQ(message.type(), MessageType::DICTIONARY_BATCH);
-  CHECK_HAS_BODY(message);
-  ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
-  return ReadDictionary(*message.metadata(), context, kind, reader.get());
-}
-
 }  // namespace
 
 Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
@@ -948,6 +939,15 @@ Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
                          reader.get());
 }
 
+Status ReadDictionary(const Message& message, DictionaryMemo* dictionary_memo,
+                      const IpcReadOptions& options) {
+  CHECK_MESSAGE_TYPE(MessageType::DICTIONARY_BATCH, message.type());
+  CHECK_HAS_BODY(message);
+  IpcReadContext context(dictionary_memo, options, /*swap=*/false);
+  ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
+  return ReadDictionary(*message.metadata(), context, /*kind=*/nullptr, reader.get());
+}
+
 // Streaming format decoder
 class StreamDecoderInternal : public MessageDecoderListener {
  public:
@@ -966,11 +966,11 @@ class StreamDecoderInternal : public MessageDecoderListener {
         field_inclusion_mask_(),
         num_required_initial_dictionaries_(0),
         num_read_initial_dictionaries_(0),
-        dictionary_memo_(),
         schema_(nullptr),
         filtered_schema_(nullptr),
         stats_(),
-        swap_endian_(false) {}
+        swap_endian_(false),
+        dictionary_memo_() {}
 
   Status OnMessageDecoded(std::unique_ptr<Message> message) override {
     ++stats_.num_messages;
@@ -1036,7 +1036,7 @@ class StreamDecoderInternal : public MessageDecoderListener {
                              num_required_initial_dictionaries_,
                              ") of dictionaries at the start of the stream");
     }
-    RETURN_NOT_OK(ReadDictionary(*message));
+    RETURN_NOT_OK(ReadDictionaryMessage(*message));
     num_read_initial_dictionaries_++;
     if (num_read_initial_dictionaries_ == num_required_initial_dictionaries_) {
       state_ = State::RECORD_BATCHES;
@@ -1047,7 +1047,7 @@ class StreamDecoderInternal : public MessageDecoderListener {
 
   Status OnRecordBatchMessageDecoded(std::unique_ptr<Message> message) {
     if (message->type() == MessageType::DICTIONARY_BATCH) {
-      return ReadDictionary(*message);
+      return ReadDictionaryMessage(*message);
     } else {
       CHECK_HAS_BODY(*message);
       ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
@@ -1062,10 +1062,13 @@ class StreamDecoderInternal : public MessageDecoderListener {
   }
 
   // Read dictionary from dictionary batch
-  Status ReadDictionary(const Message& message) {
+  Status ReadDictionaryMessage(const Message& message) {
     DictionaryKind kind;
     IpcReadContext context(&dictionary_memo_, options_, swap_endian_);
-    RETURN_NOT_OK(::arrow::ipc::ReadDictionary(message, context, &kind));
+    DCHECK_EQ(message.type(), MessageType::DICTIONARY_BATCH);
+    CHECK_HAS_BODY(message);
+    ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
+    RETURN_NOT_OK(ReadDictionary(*message.metadata(), context, &kind, reader.get()));
     ++stats_.num_dictionary_batches;
     switch (kind) {
       case DictionaryKind::New:
@@ -1086,11 +1089,13 @@ class StreamDecoderInternal : public MessageDecoderListener {
   std::vector<bool> field_inclusion_mask_;
   int num_required_initial_dictionaries_;
  int num_read_initial_dictionaries_;
-  DictionaryMemo dictionary_memo_;
   std::shared_ptr<Schema> schema_;
  std::shared_ptr<Schema> filtered_schema_;
   ReadStats stats_;
   bool swap_endian_;
+
+ protected:
+  DictionaryMemo dictionary_memo_;
 };
 
 // ----------------------------------------------------------------------
@@ -1158,6 +1163,8 @@ class RecordBatchStreamReaderImpl : public RecordBatchStreamReader,
 
   ReadStats stats() const override { return StreamDecoderInternal::stats(); }
 
+  DictionaryMemo* dictionary_memo() override { return &dictionary_memo_; }
+
  private:
   std::unique_ptr<MessageReader> message_reader_;
 };
@@ -1490,6 +1497,8 @@ class RecordBatchFileReaderImpl : public RecordBatchFileReader {
 
   ReadStats stats() const override { return stats_.poll(); }
 
+  DictionaryMemo* dictionary_memo() override { return &dictionary_memo_; }
+
   Result<AsyncGenerator<std::shared_ptr<RecordBatch>>> GetRecordBatchGenerator(
       const bool coalesce, const io::IOContext& io_context,
       const io::CacheOptions cache_options,
diff --git a/cpp/src/arrow/ipc/reader.h b/cpp/src/arrow/ipc/reader.h
index 888f59a62777..947ab2ae4d85 100644
--- a/cpp/src/arrow/ipc/reader.h
+++ b/cpp/src/arrow/ipc/reader.h
@@ -98,6 +98,9 @@ class ARROW_EXPORT RecordBatchStreamReader : public RecordBatchReader {
 
   /// \brief Return current read statistics
   virtual ReadStats stats() const = 0;
+
+  /// \brief Return the DictionaryMemo used by this reader
+  virtual DictionaryMemo* dictionary_memo() = 0;
 };
 
 /// \brief Reads the record batch file format
@@ -200,6 +203,9 @@ class ARROW_EXPORT RecordBatchFileReader
   /// \brief Return current read statistics
   virtual ReadStats stats() const = 0;
 
+  /// \brief Return the DictionaryMemo used by this reader
+  virtual DictionaryMemo* dictionary_memo() = 0;
+
   /// \brief Computes the total number of rows in the file.
   virtual Result<int64_t> CountRows() = 0;
 
@@ -580,6 +586,19 @@ Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
     const DictionaryMemo* dictionary_memo, const IpcReadOptions& options,
     io::RandomAccessFile* file);
 
+/// \brief Read a dictionary message and add its contents to a DictionaryMemo
+///
+/// If the memo already contains a dictionary with the same id, it is replaced.
+/// Does not perform endian swapping; intended for use with native-endian data.
+/// For cross-endian support, use RecordBatchStreamReader or RecordBatchFileReader.
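+///
+/// A minimal sketch of decoupled reading with this function. The names
+/// `schema_message`, `dict_message` and `batch_message` are assumed to be
+/// Messages obtained elsewhere (e.g. from a MessageReader) sharing one memo:
+///
+///     DictionaryMemo memo;
+///     ARROW_ASSIGN_OR_RAISE(auto schema, ReadSchema(*schema_message, &memo));
+///     RETURN_NOT_OK(ReadDictionary(*dict_message, &memo));
+///     ARROW_ASSIGN_OR_RAISE(
+///         auto batch, ReadRecordBatch(*batch_message, schema, &memo,
+///                                     IpcReadOptions::Defaults()));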
+///
+/// \param[in] message a Message of type DICTIONARY_BATCH
+/// \param[in,out] dictionary_memo DictionaryMemo to populate with the dictionary data
+/// \param[in] options IPC options for reading
+ARROW_EXPORT
+Status ReadDictionary(const Message& message, DictionaryMemo* dictionary_memo,
+                      const IpcReadOptions& options = IpcReadOptions::Defaults());
+
 /// \brief Read arrow::Tensor as encapsulated IPC message in file
 ///
 /// \param[in] file an InputStream pointed at the start of the message
diff --git a/cpp/src/arrow/ipc/writer.cc b/cpp/src/arrow/ipc/writer.cc
index cba484af1584..f27f9d292c66 100644
--- a/cpp/src/arrow/ipc/writer.cc
+++ b/cpp/src/arrow/ipc/writer.cc
@@ -821,6 +821,41 @@ Status GetDictionaryPayload(int64_t id, bool is_delta,
   return assembler.Assemble(dictionary);
 }
 
+Result<std::vector<std::shared_ptr<Buffer>>> CollectAndSerializeDictionaries(
+    const RecordBatch& batch, DictionaryMemo* dictionary_memo,
+    const IpcWriteOptions& options) {
+  DictionaryFieldMapper mapper(*batch.schema());
+  ARROW_ASSIGN_OR_RAISE(auto dictionaries, CollectDictionaries(batch, mapper));
+
+  std::vector<std::shared_ptr<Buffer>> result;
+  for (const auto& pair : dictionaries) {
+    int64_t id = pair.first;
+    const auto& dictionary = pair.second;
+
+    if (dictionary_memo->HasDictionary(id)) {
+      ARROW_ASSIGN_OR_RAISE(auto existing,
+                            dictionary_memo->GetDictionary(id, options.memory_pool));
+      if (existing.get() == dictionary->data().get()) {
+        continue;
+      }
+    }
+
+    IpcPayload payload;
+    RETURN_NOT_OK(GetDictionaryPayload(id, dictionary, options, &payload));
+
+    ARROW_ASSIGN_OR_RAISE(auto stream,
+                          io::BufferOutputStream::Create(1024, options.memory_pool));
+    int32_t metadata_length = 0;
+    RETURN_NOT_OK(WriteIpcPayload(payload, options, stream.get(), &metadata_length));
+    ARROW_ASSIGN_OR_RAISE(auto buffer, stream->Finish());
+    result.push_back(std::move(buffer));
+
+    ARROW_ASSIGN_OR_RAISE(
+        std::ignore, dictionary_memo->AddOrReplaceDictionary(id, dictionary->data()));
+  }
+  return result;
+}
+
 Status GetRecordBatchPayload(const RecordBatch& batch, const IpcWriteOptions& options,
                              IpcPayload* out) {
   return GetRecordBatchPayload(batch, NULLPTR, options, out);
diff --git a/cpp/src/arrow/ipc/writer.h b/cpp/src/arrow/ipc/writer.h
index aefb59f3136e..af69a11968c8 100644
--- a/cpp/src/arrow/ipc/writer.h
+++ b/cpp/src/arrow/ipc/writer.h
@@ -377,6 +377,24 @@ Status GetDictionaryPayload(int64_t id, bool is_delta,
                             const std::shared_ptr<Array>& dictionary,
                             const IpcWriteOptions& options, IpcPayload* payload);
 
+/// \brief Collect and serialize dictionary messages for a RecordBatch
+///
+/// For each dictionary-encoded field in the batch, checks the memo to determine
+/// whether serialization is needed. If the memo has no dictionary for a given id,
+/// the dictionary is serialized and added. If the memo already has a dictionary
+/// for that id, the ArrayData pointers are compared: if they are the same object,
+/// the dictionary is skipped (deduplicated); otherwise a replacement dictionary
+/// message is serialized and the memo is updated.
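+///
+/// A minimal writer-side sketch, reusing one memo across batches so repeated
+/// dictionaries are emitted only once (`batch` is assumed given):
+///
+///     DictionaryMemo memo;
+///     ARROW_ASSIGN_OR_RAISE(auto dict_buffers,
+///                           CollectAndSerializeDictionaries(*batch, &memo));
+///     ARROW_ASSIGN_OR_RAISE(auto batch_buffer,
+///                           SerializeRecordBatch(*batch, IpcWriteOptions::Defaults()));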
+///
+/// \param[in] batch the RecordBatch to collect dictionaries from
+/// \param[in,out] dictionary_memo tracks which dictionaries have been serialized
+/// \param[in] options IPC write options
+/// \return vector of serialized dictionary IPC message buffers
+ARROW_EXPORT
+Result<std::vector<std::shared_ptr<Buffer>>> CollectAndSerializeDictionaries(
+    const RecordBatch& batch, DictionaryMemo* dictionary_memo,
+    const IpcWriteOptions& options = IpcWriteOptions::Defaults());
+
 /// \brief Compute IpcPayload for the given record batch
 /// \param[in] batch the RecordBatch that is being serialized
 /// \param[in] options options for serialization
diff --git a/docs/source/python/api/ipc.rst b/docs/source/python/api/ipc.rst
index 027fee583ec1..d27efbb858df 100644
--- a/docs/source/python/api/ipc.rst
+++ b/docs/source/python/api/ipc.rst
@@ -34,6 +34,8 @@ Inter-Process Communication
    ipc.open_stream
    ipc.read_message
    ipc.read_record_batch
+   ipc.read_dictionary_message
+   ipc.serialize_dictionaries
    ipc.get_record_batch_size
    ipc.read_tensor
    ipc.write_tensor
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index e96a7d84696d..4ffad0f272c4 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1916,7 +1916,7 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
         int64_t num_replaced_dictionaries
 
     cdef cppclass CDictionaryMemo" arrow::ipc::DictionaryMemo":
-        pass
+        c_bool HasDictionary(int64_t id) const
 
     cdef cppclass CIpcPayload" arrow::ipc::IpcPayload":
         MessageType type
@@ -1970,6 +1970,7 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
             const CIpcReadOptions& options)
 
         CIpcReadStats stats()
+        CDictionaryMemo* dictionary_memo()
 
     cdef cppclass CRecordBatchFileReader \
             " arrow::ipc::RecordBatchFileReader":
@@ -1992,6 +1993,7 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
         CResult[CRecordBatchWithMetadata] ReadRecordBatchWithCustomMetadata(int i)
 
         CIpcReadStats stats()
+        CDictionaryMemo* dictionary_memo()
 
         shared_ptr[const CKeyValueMetadata] metadata()
 
@@ -2020,12 +2022,21 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
                               CDictionaryMemo* dictionary_memo,
                               const CIpcReadOptions& options)
 
+    CStatus ReadDictionary(const CMessage& message,
+                           CDictionaryMemo* dictionary_memo,
+                           const CIpcReadOptions& options)
+
     CResult[shared_ptr[CBuffer]] SerializeSchema(const CSchema& schema,
                                                  CMemoryPool* pool)
 
     CResult[shared_ptr[CBuffer]] SerializeRecordBatch(
         const CRecordBatch& schema, const CIpcWriteOptions& options)
 
+    CResult[vector[shared_ptr[CBuffer]]] CollectAndSerializeDictionaries(
+        const CRecordBatch& batch,
+        CDictionaryMemo* dictionary_memo,
+        const CIpcWriteOptions& options)
+
     CResult[shared_ptr[CSchema]] ReadSchema(const CMessage& message,
                                             CDictionaryMemo* dictionary_memo)
 
diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi
index 6477579af21d..7c5f02fd1021 100644
--- a/python/pyarrow/ipc.pxi
+++ b/python/pyarrow/ipc.pxi
@@ -1102,6 +1102,15 @@ cdef class _RecordBatchStreamReader(RecordBatchReader):
             raise ValueError("Operation on closed reader")
         return _wrap_read_stats(self.stream_reader.stats())
 
+    @property
+    def dictionary_memo(self):
+        """
+        The DictionaryMemo associated with this reader.
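+
+        The returned wrapper refers to the reader's internal memo and keeps
+        the reader alive while it is in use.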
+ """ + if not self.reader: + raise ValueError("Operation on closed reader") + return DictionaryMemo.wrap(self.stream_reader.dictionary_memo(), self) + cdef class _RecordBatchFileWriter(_RecordBatchStreamWriter): @@ -1293,6 +1302,15 @@ cdef class _RecordBatchFileReader(_Weakrefable): wrapped = pyarrow_wrap_metadata(self.reader.get().metadata()) return wrapped.to_dict() if wrapped is not None else None + @property + def dictionary_memo(self): + """ + The DictionaryMemo associated with this reader. + """ + if not self.reader: + raise ValueError("Operation on closed reader") + return DictionaryMemo.wrap(self.reader.get().dictionary_memo(), self) + def get_tensor_size(Tensor tensor): """ @@ -1502,3 +1520,72 @@ def read_record_batch(obj, Schema schema, CIpcReadOptions.Defaults())) return pyarrow_wrap_batch(result) + + +def read_dictionary_message(obj, DictionaryMemo dictionary_memo): + """ + Read a dictionary message into a DictionaryMemo. + + If the memo already contains a dictionary with the same id, it is + replaced. The memo must already have dictionary types registered, + typically from a prior read_schema call with the same memo. + + Parameters + ---------- + obj : Message or Buffer-like + A message of type DICTIONARY_BATCH. + dictionary_memo : DictionaryMemo + Memo to populate with the dictionary data. + """ + cdef Message message + + if isinstance(obj, Message): + message = obj + else: + message = read_message(obj) + + with nogil: + check_status(ReadDictionary(deref(message.message.get()), + dictionary_memo.memo, + CIpcReadOptions.Defaults())) + + +def serialize_dictionaries(RecordBatch batch, + DictionaryMemo dictionary_memo, + memory_pool=None): + """ + Serialize IPC dictionary messages needed for a RecordBatch. + + For each dictionary-encoded column, checks the memo to determine + whether serialization is needed. Dictionaries are deduplicated by + pointer identity: if the memo already contains the same dictionary + object, it is skipped. If a different dictionary object exists for + the same field, a replacement message is emitted and the memo is + updated. + + Parameters + ---------- + batch : RecordBatch + The record batch whose dictionaries should be serialized. + dictionary_memo : DictionaryMemo + Tracks which dictionaries have already been serialized. + Updated in place with newly serialized dictionaries. + memory_pool : MemoryPool, default None + Uses default memory pool if not specified. + + Returns + ------- + list of Buffer + Serialized dictionary IPC messages, in dependency order. 
+ """ + batch._assert_cpu() + cdef: + vector[shared_ptr[CBuffer]] c_buffers + CIpcWriteOptions options = CIpcWriteOptions.Defaults() + options.memory_pool = maybe_unbox_memory_pool(memory_pool) + + with nogil: + c_buffers = GetResultValue( + CollectAndSerializeDictionaries( + deref(batch.batch), dictionary_memo.memo, options)) + return [pyarrow_wrap_buffer(buf) for buf in c_buffers] diff --git a/python/pyarrow/ipc.py b/python/pyarrow/ipc.py index 4e236678788a..48adb75cc6df 100644 --- a/python/pyarrow/ipc.py +++ b/python/pyarrow/ipc.py @@ -26,6 +26,8 @@ RecordBatchReader, _ReadPandasMixin, MetadataVersion, Alignment, read_message, read_record_batch, read_schema, + read_dictionary_message, + serialize_dictionaries, read_tensor, write_tensor, get_record_batch_size, get_tensor_size) import pyarrow.lib as lib diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 683faa7855c5..ee1c6319964d 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -128,6 +128,10 @@ cdef class DictionaryMemo(_Weakrefable): # it on the heap so as to avoid C++ ABI issues with Python wheels. shared_ptr[CDictionaryMemo] sp_memo CDictionaryMemo* memo + object _parent + + @staticmethod + cdef DictionaryMemo wrap(CDictionaryMemo* memo, object parent) cdef class DictionaryType(DataType): diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py index 6813ed777234..5e641e910d53 100644 --- a/python/pyarrow/tests/test_ipc.py +++ b/python/pyarrow/tests/test_ipc.py @@ -1429,3 +1429,256 @@ def read_options_args(request): def test_read_options_repr(read_options_args): # https://github.com/apache/arrow/issues/47358 check_ipc_options_repr(pa.ipc.IpcReadOptions, read_options_args) + + +def test_dictionary_memo_on_readers(): + # Create data with dictionary-encoded columns + arr = pa.array(["foo", "bar", "foo", "baz"]).dictionary_encode() + batch = pa.record_batch([arr], names=["dict_col"]) + + # Write to stream format + sink = pa.BufferOutputStream() + with pa.ipc.new_stream(sink, batch.schema) as writer: + writer.write_batch(batch) + stream_buf = sink.getvalue() + + # Test RecordBatchStreamReader.dictionary_memo + reader = pa.ipc.open_stream(stream_buf) + memo = reader.dictionary_memo + assert isinstance(memo, pa.DictionaryMemo) + + # Write to file format + sink = pa.BufferOutputStream() + with pa.ipc.new_file(sink, batch.schema) as writer: + writer.write_batch(batch) + file_buf = sink.getvalue() + + # Test RecordBatchFileReader.dictionary_memo: use it to read a batch + file_reader = pa.ipc.open_file(file_buf) + file_memo = file_reader.dictionary_memo + assert isinstance(file_memo, pa.DictionaryMemo) + # The file reader populates its memo on open; use it to read back + file_batch = file_reader.get_batch(0) + assert file_batch.equals(batch) + # Feed a dictionary message from a separately-written stream into + # the file reader's memo (cross-population). 
+    sink_extra = pa.BufferOutputStream()
+    arr_extra = pa.array(["p", "q", "r"]).dictionary_encode()
+    batch_extra = pa.record_batch([arr_extra], names=["dict_col"])
+    with pa.ipc.new_stream(sink_extra, batch_extra.schema) as writer:
+        writer.write_batch(batch_extra)
+    extra_buf = sink_extra.getvalue()
+    msg_reader_extra = pa.ipc.MessageReader.open_stream(extra_buf)
+    next(msg_reader_extra)  # skip the schema message
+    dict_msg_extra = next(msg_reader_extra)
+    pa.ipc.read_dictionary_message(dict_msg_extra, file_memo)
+
+    # Demonstrate cross-population with read_dictionary_message:
+    # open a stream reader, extract its memo, then feed a dictionary message
+    # from a separately written stream into that memo.
+    sink2 = pa.BufferOutputStream()
+    arr2 = pa.array(["alpha", "beta", "gamma"]).dictionary_encode()
+    batch2 = pa.record_batch([arr2], names=["dict_col"])
+    with pa.ipc.new_stream(sink2, batch2.schema) as writer:
+        writer.write_batch(batch2)
+    stream_buf2 = sink2.getvalue()
+
+    # Read dictionary messages from the second stream
+    msg_reader = pa.ipc.MessageReader.open_stream(stream_buf2)
+    schema_msg = next(msg_reader)
+    assert schema_msg.type == "schema"
+    dict_msg = next(msg_reader)
+    assert dict_msg.type == "dictionary"
+
+    # Use read_dictionary_message with the first reader's memo
+    reader1 = pa.ipc.open_stream(stream_buf)
+    memo1 = reader1.dictionary_memo
+    pa.ipc.read_dictionary_message(dict_msg, memo1)
+
+
+def test_serialize_dictionaries_basic():
+    arr = pa.array(["a", "b", "a"]).dictionary_encode()
+    batch = pa.record_batch([arr], names=["col"])
+    memo = pa.DictionaryMemo()
+
+    # First call should return dictionary buffers
+    buffers = pa.ipc.serialize_dictionaries(batch, memo)
+    assert len(buffers) == 1
+    assert isinstance(buffers[0], pa.Buffer)
+    assert len(buffers[0]) > 0
+
+    # Second call with the same memo should return nothing (already tracked)
+    buffers2 = pa.ipc.serialize_dictionaries(batch, memo)
+    assert len(buffers2) == 0
+
+
+def test_serialize_dictionaries_roundtrip():
+    arr = pa.array(["a", "b", "a"]).dictionary_encode()
+    batch = pa.record_batch([arr], names=["col"])
+
+    # Serialize
+    write_memo = pa.DictionaryMemo()
+    dict_bufs = pa.ipc.serialize_dictionaries(batch, write_memo)
+    batch_buf = batch.serialize()
+
+    # Read back
+    read_memo = pa.DictionaryMemo()
+    schema = pa.ipc.read_schema(batch.schema.serialize(), read_memo)
+    for buf in dict_bufs:
+        pa.ipc.read_dictionary_message(buf, read_memo)
+    result = pa.ipc.read_record_batch(batch_buf, schema, read_memo)
+    assert result.equals(batch)
+
+
+def test_serialize_dictionaries_no_dict_columns():
+    arr = pa.array([1, 2, 3])
+    batch = pa.record_batch([arr], names=["col"])
+    memo = pa.DictionaryMemo()
+
+    buffers = pa.ipc.serialize_dictionaries(batch, memo)
+    assert len(buffers) == 0
+
+
+def test_serialize_dictionaries_multiple_dict_columns():
+    arr1 = pa.array(["a", "b", "a"]).dictionary_encode()
+    arr2 = pa.array(["x", "y", "z"]).dictionary_encode()
+    batch = pa.record_batch([arr1, arr2], names=["col1", "col2"])
+    memo = pa.DictionaryMemo()
+
+    buffers = pa.ipc.serialize_dictionaries(batch, memo)
+    assert len(buffers) == 2
+
+    # Full roundtrip with multiple dict columns
+    batch_buf = batch.serialize()
+    read_memo = pa.DictionaryMemo()
+    schema = pa.ipc.read_schema(batch.schema.serialize(), read_memo)
+    for buf in buffers:
+        pa.ipc.read_dictionary_message(buf, read_memo)
+    result = pa.ipc.read_record_batch(batch_buf, schema, read_memo)
+    assert result.equals(batch)
+
+
+def test_serialize_dictionaries_multi_batch_memo_dedup():
+    # Multiple batches sharing the same dictionary object (same pointer).
+    # serialize_dictionaries deduplicates by pointer: the second batch
+    # skips serialization because its dictionary is the same object.
+    dictionary = pa.array(["a", "b", "c"])
+    arr1 = pa.DictionaryArray.from_arrays(
+        pa.array([0, 1, 0], type=pa.int8()), dictionary)
+    arr2 = pa.DictionaryArray.from_arrays(
+        pa.array([2, 1, 0], type=pa.int8()), dictionary)
+    batch1 = pa.record_batch([arr1], names=["col"])
+    batch2 = pa.record_batch([arr2], names=["col"])
+
+    memo = pa.DictionaryMemo()
+
+    # First batch emits the dictionary
+    bufs1 = pa.ipc.serialize_dictionaries(batch1, memo)
+    assert len(bufs1) == 1
+
+    # Second batch shares the same dictionary object, so it is skipped
+    bufs2 = pa.ipc.serialize_dictionaries(batch2, memo)
+    assert len(bufs2) == 0
+
+    # Roundtrip both batches using only batch1's dictionary messages
+    read_memo = pa.DictionaryMemo()
+    schema = pa.ipc.read_schema(batch1.schema.serialize(), read_memo)
+    for buf in bufs1:
+        pa.ipc.read_dictionary_message(buf, read_memo)
+
+    result1 = pa.ipc.read_record_batch(batch1.serialize(), schema, read_memo)
+    assert result1.equals(batch1)
+
+    result2 = pa.ipc.read_record_batch(batch2.serialize(), schema, read_memo)
+    assert result2.equals(batch2)
+
+
+def test_serialize_dictionaries_nested_in_struct():
+    # Dictionary-encoded field nested inside a struct type
+    dict_arr = pa.array(["x", "y", "x"]).dictionary_encode()
+    int_arr = pa.array([1, 2, 3])
+    struct_arr = pa.StructArray.from_arrays(
+        [int_arr, dict_arr], names=["i", "d"])
+    batch = pa.record_batch([struct_arr], names=["s"])
+
+    memo = pa.DictionaryMemo()
+    bufs = pa.ipc.serialize_dictionaries(batch, memo)
+    assert len(bufs) == 1
+
+    # Second call: already tracked
+    assert len(pa.ipc.serialize_dictionaries(batch, memo)) == 0
+
+    # Roundtrip
+    batch_buf = batch.serialize()
+    read_memo = pa.DictionaryMemo()
+    schema = pa.ipc.read_schema(batch.schema.serialize(), read_memo)
+    for buf in bufs:
+        pa.ipc.read_dictionary_message(buf, read_memo)
+    result = pa.ipc.read_record_batch(batch_buf, schema, read_memo)
+    assert result.equals(batch)
+
+
+def test_serialize_dictionaries_changed_values_across_batches():
+    # When batch2 has a different dictionary object than batch1,
+    # serialize_dictionaries detects the difference (by pointer) and
+    # emits a replacement dictionary message.
+    arr1 = pa.array(["a", "b", "a"]).dictionary_encode()
+    batch1 = pa.record_batch([arr1], names=["col"])
+
+    arr2 = pa.array(["x", "y", "z"]).dictionary_encode()
+    batch2 = pa.record_batch([arr2], names=["col"])
+
+    memo = pa.DictionaryMemo()
+    bufs1 = pa.ipc.serialize_dictionaries(batch1, memo)
+    assert len(bufs1) == 1
+
+    # batch2 has a different dictionary object, so a new message is emitted
+    bufs2 = pa.ipc.serialize_dictionaries(batch2, memo)
+    assert len(bufs2) == 1
+
+    # Sequential roundtrip with a single read memo: batch2's dictionary
+    # replaces batch1's via read_dictionary_message (AddOrReplace).
+    read_memo = pa.DictionaryMemo()
+    schema = pa.ipc.read_schema(batch1.schema.serialize(), read_memo)
+
+    for buf in bufs1:
+        pa.ipc.read_dictionary_message(buf, read_memo)
+    result1 = pa.ipc.read_record_batch(batch1.serialize(), schema, read_memo)
+    assert result1.equals(batch1)
+
+    # Feed batch2's replacement dictionary into the same memo
+    for buf in bufs2:
+        pa.ipc.read_dictionary_message(buf, read_memo)
+    result2 = pa.ipc.read_record_batch(batch2.serialize(), schema, read_memo)
+    assert result2.equals(batch2)
+
+
+def test_serialize_dictionaries_same_values_different_objects():
+    # Dedup is by pointer identity, not by value equality.
+    # Two independently created arrays with identical values are different
+    # objects, so both get serialized.
+    arr1 = pa.array(["a", "b", "a"]).dictionary_encode()
+    arr2 = pa.array(["a", "b", "a"]).dictionary_encode()
+    batch1 = pa.record_batch([arr1], names=["col"])
+    batch2 = pa.record_batch([arr2], names=["col"])
+
+    memo = pa.DictionaryMemo()
+    bufs1 = pa.ipc.serialize_dictionaries(batch1, memo)
+    assert len(bufs1) == 1
+
+    # Same values, but different Python/C++ objects: not deduplicated
+    bufs2 = pa.ipc.serialize_dictionaries(batch2, memo)
+    assert len(bufs2) == 1
+
+    # Both round-trip correctly
+    read_memo = pa.DictionaryMemo()
+    schema = pa.ipc.read_schema(batch1.schema.serialize(), read_memo)
+    for buf in bufs1:
+        pa.ipc.read_dictionary_message(buf, read_memo)
+    result1 = pa.ipc.read_record_batch(batch1.serialize(), schema, read_memo)
+    assert result1.equals(batch1)
+
+    for buf in bufs2:
+        pa.ipc.read_dictionary_message(buf, read_memo)
+    result2 = pa.ipc.read_record_batch(batch2.serialize(), schema, read_memo)
+    assert result2.equals(batch2)
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index e84f1b073f6c..718e5ec81ba1 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -481,6 +481,14 @@ cdef class DictionaryMemo(_Weakrefable):
         self.sp_memo.reset(new CDictionaryMemo())
         self.memo = self.sp_memo.get()
 
+    @staticmethod
+    cdef DictionaryMemo wrap(CDictionaryMemo* memo, object parent):
+        cdef DictionaryMemo self = DictionaryMemo.__new__(DictionaryMemo)
+        self.memo = memo
+        self.sp_memo.reset()  # the wrapper does not own the memo
+        self._parent = parent  # keep the owning reader alive (prevent GC)
+        return self
+
 
 cdef class DictionaryType(DataType):
     """