Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions cpp/src/arrow/ipc/reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -900,15 +900,6 @@ Status ReadDictionary(const Buffer& metadata, const IpcReadContext& context,
return Status::OK();
}

Status ReadDictionary(const Message& message, const IpcReadContext& context,
DictionaryKind* kind) {
// Only invoke this method if we already know we have a dictionary message
DCHECK_EQ(message.type(), MessageType::DICTIONARY_BATCH);
CHECK_HAS_BODY(message);
ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
return ReadDictionary(*message.metadata(), context, kind, reader.get());
}

} // namespace

Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
Expand Down Expand Up @@ -948,6 +939,15 @@ Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
reader.get());
}

// Public entry point: decode a DICTIONARY_BATCH message and record its
// contents in `dictionary_memo`, replacing any existing entry with the
// same id. No endian swapping is performed (native-endian data assumed).
Status ReadDictionary(const Message& message, DictionaryMemo* dictionary_memo,
                      const IpcReadOptions& options) {
  CHECK_MESSAGE_TYPE(MessageType::DICTIONARY_BATCH, message.type());
  CHECK_HAS_BODY(message);
  ARROW_ASSIGN_OR_RAISE(auto body_reader, Buffer::GetReader(message.body()));
  IpcReadContext context(dictionary_memo, options, /*swap=*/false);
  // Callers of this overload do not need to know whether the dictionary
  // was new, a delta, or a replacement, hence kind == nullptr.
  return ReadDictionary(*message.metadata(), context, /*kind=*/nullptr,
                        body_reader.get());
}

// Streaming format decoder
class StreamDecoderInternal : public MessageDecoderListener {
public:
Expand All @@ -966,11 +966,11 @@ class StreamDecoderInternal : public MessageDecoderListener {
field_inclusion_mask_(),
num_required_initial_dictionaries_(0),
num_read_initial_dictionaries_(0),
dictionary_memo_(),
schema_(nullptr),
filtered_schema_(nullptr),
stats_(),
swap_endian_(false) {}
swap_endian_(false),
dictionary_memo_() {}

Status OnMessageDecoded(std::unique_ptr<Message> message) override {
++stats_.num_messages;
Expand Down Expand Up @@ -1036,7 +1036,7 @@ class StreamDecoderInternal : public MessageDecoderListener {
num_required_initial_dictionaries_,
") of dictionaries at the start of the stream");
}
RETURN_NOT_OK(ReadDictionary(*message));
RETURN_NOT_OK(ReadDictionaryMessage(*message));
num_read_initial_dictionaries_++;
if (num_read_initial_dictionaries_ == num_required_initial_dictionaries_) {
state_ = State::RECORD_BATCHES;
Expand All @@ -1047,7 +1047,7 @@ class StreamDecoderInternal : public MessageDecoderListener {

Status OnRecordBatchMessageDecoded(std::unique_ptr<Message> message) {
if (message->type() == MessageType::DICTIONARY_BATCH) {
return ReadDictionary(*message);
return ReadDictionaryMessage(*message);
} else {
CHECK_HAS_BODY(*message);
ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
Expand All @@ -1062,10 +1062,13 @@ class StreamDecoderInternal : public MessageDecoderListener {
}

// Read dictionary from dictionary batch
Status ReadDictionary(const Message& message) {
Status ReadDictionaryMessage(const Message& message) {
DictionaryKind kind;
IpcReadContext context(&dictionary_memo_, options_, swap_endian_);
RETURN_NOT_OK(::arrow::ipc::ReadDictionary(message, context, &kind));
DCHECK_EQ(message.type(), MessageType::DICTIONARY_BATCH);
CHECK_HAS_BODY(message);
ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
RETURN_NOT_OK(ReadDictionary(*message.metadata(), context, &kind, reader.get()));
++stats_.num_dictionary_batches;
switch (kind) {
case DictionaryKind::New:
Expand All @@ -1086,11 +1089,13 @@ class StreamDecoderInternal : public MessageDecoderListener {
std::vector<bool> field_inclusion_mask_;
int num_required_initial_dictionaries_;
int num_read_initial_dictionaries_;
DictionaryMemo dictionary_memo_;
std::shared_ptr<Schema> schema_;
std::shared_ptr<Schema> filtered_schema_;
ReadStats stats_;
bool swap_endian_;

protected:
DictionaryMemo dictionary_memo_;
};

// ----------------------------------------------------------------------
Expand Down Expand Up @@ -1158,6 +1163,8 @@ class RecordBatchStreamReaderImpl : public RecordBatchStreamReader,

ReadStats stats() const override { return StreamDecoderInternal::stats(); }

DictionaryMemo* dictionary_memo() override { return &dictionary_memo_; }

private:
std::unique_ptr<MessageReader> message_reader_;
};
Expand Down Expand Up @@ -1490,6 +1497,8 @@ class RecordBatchFileReaderImpl : public RecordBatchFileReader {

ReadStats stats() const override { return stats_.poll(); }

DictionaryMemo* dictionary_memo() override { return &dictionary_memo_; }

Result<AsyncGenerator<std::shared_ptr<RecordBatch>>> GetRecordBatchGenerator(
const bool coalesce, const io::IOContext& io_context,
const io::CacheOptions cache_options,
Expand Down
19 changes: 19 additions & 0 deletions cpp/src/arrow/ipc/reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ class ARROW_EXPORT RecordBatchStreamReader : public RecordBatchReader {

/// \brief Return current read statistics
virtual ReadStats stats() const = 0;

/// \brief Return the DictionaryMemo used by this reader
virtual DictionaryMemo* dictionary_memo() = 0;
};

/// \brief Reads the record batch file format
Expand Down Expand Up @@ -200,6 +203,9 @@ class ARROW_EXPORT RecordBatchFileReader
/// \brief Return current read statistics
virtual ReadStats stats() const = 0;

/// \brief Return the DictionaryMemo used by this reader
virtual DictionaryMemo* dictionary_memo() = 0;

/// \brief Computes the total number of rows in the file.
virtual Result<int64_t> CountRows() = 0;

Expand Down Expand Up @@ -580,6 +586,19 @@ Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
const DictionaryMemo* dictionary_memo, const IpcReadOptions& options,
io::RandomAccessFile* file);

/// \brief Read a dictionary message and add its contents to a DictionaryMemo
///
/// If the memo already contains a dictionary with the same id, it is replaced.
/// Does not perform endian swapping; intended for use with native-endian data.
/// For cross-endian support, use RecordBatchStreamReader or RecordBatchFileReader.
///
/// \param[in] message a Message of type DICTIONARY_BATCH
/// \param[in,out] dictionary_memo DictionaryMemo to populate with the dictionary data
/// \param[in] options IPC options for reading
ARROW_EXPORT
Status ReadDictionary(const Message& message, DictionaryMemo* dictionary_memo,
const IpcReadOptions& options = IpcReadOptions::Defaults());

/// \brief Read arrow::Tensor as encapsulated IPC message in file
///
/// \param[in] file an InputStream pointed at the start of the message
Expand Down
35 changes: 35 additions & 0 deletions cpp/src/arrow/ipc/writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,41 @@ Status GetDictionaryPayload(int64_t id, bool is_delta,
return assembler.Assemble(dictionary);
}

// Serialize the dictionary IPC messages required before writing `batch`.
//
// For each dictionary-encoded field, `dictionary_memo` decides whether a
// message is needed: unseen ids are serialized and recorded; ids whose
// memoized ArrayData is pointer-identical to the batch's dictionary are
// skipped (deduplicated); otherwise a replacement message is emitted and
// the memo is updated.
//
// Returns the serialized dictionary messages as encapsulated IPC buffers,
// in the order produced by CollectDictionaries.
Result<std::vector<std::shared_ptr<Buffer>>> CollectAndSerializeDictionaries(
    const RecordBatch& batch, DictionaryMemo* dictionary_memo,
    const IpcWriteOptions& options) {
  DictionaryFieldMapper mapper(*batch.schema());
  ARROW_ASSIGN_OR_RAISE(auto dictionaries, CollectDictionaries(batch, mapper));

  std::vector<std::shared_ptr<Buffer>> result;
  // At most one output buffer per collected dictionary.
  result.reserve(dictionaries.size());
  for (const auto& pair : dictionaries) {
    const int64_t id = pair.first;
    const auto& dictionary = pair.second;

    if (dictionary_memo->HasDictionary(id)) {
      ARROW_ASSIGN_OR_RAISE(auto existing,
                            dictionary_memo->GetDictionary(id, options.memory_pool));
      if (existing.get() == dictionary->data().get()) {
        // Identical dictionary object was already serialized: skip it.
        continue;
      }
    }

    IpcPayload payload;
    RETURN_NOT_OK(GetDictionaryPayload(id, dictionary, options, &payload));

    ARROW_ASSIGN_OR_RAISE(auto stream,
                          io::BufferOutputStream::Create(1024, options.memory_pool));
    int32_t metadata_length = 0;
    RETURN_NOT_OK(WriteIpcPayload(payload, options, stream.get(), &metadata_length));
    ARROW_ASSIGN_OR_RAISE(auto buffer, stream->Finish());
    result.push_back(std::move(buffer));

    // Update the memo only after a successful write so that a failed
    // serialization leaves it untouched.
    ARROW_ASSIGN_OR_RAISE(
        std::ignore, dictionary_memo->AddOrReplaceDictionary(id, dictionary->data()));
  }
  return result;
}

Status GetRecordBatchPayload(const RecordBatch& batch, const IpcWriteOptions& options,
IpcPayload* out) {
return GetRecordBatchPayload(batch, NULLPTR, options, out);
Expand Down
18 changes: 18 additions & 0 deletions cpp/src/arrow/ipc/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,24 @@ Status GetDictionaryPayload(int64_t id, bool is_delta,
const std::shared_ptr<Array>& dictionary,
const IpcWriteOptions& options, IpcPayload* payload);

/// \brief Collect and serialize dictionary messages for a RecordBatch
///
/// For each dictionary-encoded field in the batch, checks the memo to determine
/// whether serialization is needed. If the memo has no dictionary for a given id,
/// the dictionary is serialized and added. If the memo already has a dictionary
/// for that id, the ArrayData pointers are compared: if they are the same object
/// the dictionary is skipped (deduplicated), otherwise a replacement dictionary
/// message is serialized and the memo is updated.
///
/// \param[in] batch the RecordBatch to collect dictionaries from
/// \param[in,out] dictionary_memo tracks which dictionaries have been serialized
/// \param[in] options IPC write options
/// \return vector of serialized dictionary IPC message buffers
ARROW_EXPORT
Result<std::vector<std::shared_ptr<Buffer>>> CollectAndSerializeDictionaries(
const RecordBatch& batch, DictionaryMemo* dictionary_memo,
const IpcWriteOptions& options = IpcWriteOptions::Defaults());

/// \brief Compute IpcPayload for the given record batch
/// \param[in] batch the RecordBatch that is being serialized
/// \param[in] options options for serialization
Expand Down
2 changes: 2 additions & 0 deletions docs/source/python/api/ipc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ Inter-Process Communication
ipc.open_stream
ipc.read_message
ipc.read_record_batch
ipc.read_dictionary_message
ipc.serialize_dictionaries
ipc.get_record_batch_size
ipc.read_tensor
ipc.write_tensor
Expand Down
13 changes: 12 additions & 1 deletion python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -1916,7 +1916,7 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
int64_t num_replaced_dictionaries

cdef cppclass CDictionaryMemo" arrow::ipc::DictionaryMemo":
pass
c_bool HasDictionary(int64_t id) const

cdef cppclass CIpcPayload" arrow::ipc::IpcPayload":
MessageType type
Expand Down Expand Up @@ -1970,6 +1970,7 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
const CIpcReadOptions& options)

CIpcReadStats stats()
CDictionaryMemo* dictionary_memo()

cdef cppclass CRecordBatchFileReader \
" arrow::ipc::RecordBatchFileReader":
Expand All @@ -1992,6 +1993,7 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
CResult[CRecordBatchWithMetadata] ReadRecordBatchWithCustomMetadata(int i)

CIpcReadStats stats()
CDictionaryMemo* dictionary_memo()

shared_ptr[const CKeyValueMetadata] metadata()

Expand Down Expand Up @@ -2020,12 +2022,21 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
CDictionaryMemo* dictionary_memo,
const CIpcReadOptions& options)

CStatus ReadDictionary(const CMessage& message,
CDictionaryMemo* dictionary_memo,
const CIpcReadOptions& options)

CResult[shared_ptr[CBuffer]] SerializeSchema(
const CSchema& schema, CMemoryPool* pool)

CResult[shared_ptr[CBuffer]] SerializeRecordBatch(
const CRecordBatch& schema, const CIpcWriteOptions& options)

CResult[vector[shared_ptr[CBuffer]]] CollectAndSerializeDictionaries(
const CRecordBatch& batch,
CDictionaryMemo* dictionary_memo,
const CIpcWriteOptions& options)

CResult[shared_ptr[CSchema]] ReadSchema(const CMessage& message,
CDictionaryMemo* dictionary_memo)

Expand Down
87 changes: 87 additions & 0 deletions python/pyarrow/ipc.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,15 @@ cdef class _RecordBatchStreamReader(RecordBatchReader):
raise ValueError("Operation on closed reader")
return _wrap_read_stats(self.stream_reader.stats())

@property
def dictionary_memo(self):
    """
    The DictionaryMemo associated with this reader.

    Raises ValueError if the reader has been closed.
    """
    # Guard against use after close(), mirroring the stats property above.
    if not self.reader:
        raise ValueError("Operation on closed reader")
    # NOTE(review): `self` is passed as the wrapper's owner — presumably to
    # keep this reader (and the C++ memo it owns) alive; confirm against
    # DictionaryMemo.wrap.
    return DictionaryMemo.wrap(self.stream_reader.dictionary_memo(), self)


cdef class _RecordBatchFileWriter(_RecordBatchStreamWriter):

Expand Down Expand Up @@ -1293,6 +1302,15 @@ cdef class _RecordBatchFileReader(_Weakrefable):
wrapped = pyarrow_wrap_metadata(self.reader.get().metadata())
return wrapped.to_dict() if wrapped is not None else None

@property
def dictionary_memo(self):
    """
    The DictionaryMemo associated with this reader.

    Raises ValueError if the reader has been closed.
    """
    # Guard against use after close(); same pattern as the other accessors.
    if not self.reader:
        raise ValueError("Operation on closed reader")
    # NOTE(review): `self` is passed as the wrapper's owner — presumably to
    # keep this reader alive while the memo wrapper exists; confirm against
    # DictionaryMemo.wrap.
    return DictionaryMemo.wrap(self.reader.get().dictionary_memo(), self)


def get_tensor_size(Tensor tensor):
"""
Expand Down Expand Up @@ -1502,3 +1520,72 @@ def read_record_batch(obj, Schema schema,
CIpcReadOptions.Defaults()))

return pyarrow_wrap_batch(result)


def read_dictionary_message(obj, DictionaryMemo dictionary_memo):
    """
    Read a DICTIONARY_BATCH message into a DictionaryMemo.

    An existing dictionary with the same id is replaced. The memo is
    expected to already know the dictionary types, typically because the
    schema was read earlier with the same memo.

    Parameters
    ----------
    obj : Message or Buffer-like
        A message of type DICTIONARY_BATCH, or bytes to read one from.
    dictionary_memo : DictionaryMemo
        Memo to populate with the dictionary data.
    """
    cdef Message message

    # Accept either an already-parsed Message or raw message bytes.
    message = obj if isinstance(obj, Message) else read_message(obj)

    with nogil:
        check_status(ReadDictionary(deref(message.message.get()),
                                    dictionary_memo.memo,
                                    CIpcReadOptions.Defaults()))


def serialize_dictionaries(RecordBatch batch,
                           DictionaryMemo dictionary_memo,
                           memory_pool=None):
    """
    Serialize the IPC dictionary messages a RecordBatch depends on.

    The memo determines, per dictionary-encoded column, whether a message
    must be written. Deduplication is by pointer identity: a dictionary
    object already present in the memo is skipped, while a different object
    under the same id produces a replacement message and updates the memo.

    Parameters
    ----------
    batch : RecordBatch
        The record batch whose dictionaries should be serialized.
    dictionary_memo : DictionaryMemo
        Tracks which dictionaries have already been serialized.
        Updated in place with newly serialized dictionaries.
    memory_pool : MemoryPool, default None
        Uses default memory pool if not specified.

    Returns
    -------
    list of Buffer
        Serialized dictionary IPC messages, in dependency order.
    """
    batch._assert_cpu()
    cdef:
        vector[shared_ptr[CBuffer]] c_bufs
        CIpcWriteOptions options = CIpcWriteOptions.Defaults()
    # Plain statement, deliberately outside the cdef declaration block.
    options.memory_pool = maybe_unbox_memory_pool(memory_pool)

    with nogil:
        c_bufs = GetResultValue(
            CollectAndSerializeDictionaries(
                deref(batch.batch), dictionary_memo.memo, options))
    return [pyarrow_wrap_buffer(b) for b in c_bufs]
2 changes: 2 additions & 0 deletions python/pyarrow/ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
RecordBatchReader, _ReadPandasMixin,
MetadataVersion, Alignment,
read_message, read_record_batch, read_schema,
read_dictionary_message,
serialize_dictionaries,
read_tensor, write_tensor,
get_record_batch_size, get_tensor_size)
import pyarrow.lib as lib
Expand Down
Loading