148 changes: 119 additions & 29 deletions nifi/user_python_extensions/record_decompress_cerner_blob.py
@@ -33,6 +33,8 @@ class Java:

class ProcessorDetails:
version = '0.0.1'
description = "Decompresses Cerner LZW compressed blobs from a JSON input stream"
tags = ["cerner", "oracle", "blob"]

def __init__(self, jvm: JVMView):
super().__init__(jvm)
@@ -110,6 +112,7 @@ def transform(self, context: ProcessContext, flowFile: JavaObject) -> FlowFileTr
"""

output_contents: list = []
attributes: dict = {k: str(v) for k, v in flowFile.getAttributes().items()}

try:
self.process_context = context
@@ -118,7 +121,7 @@ def transform(self, context: ProcessContext, flowFile: JavaObject) -> FlowFileTr
# read raw FlowFile content (JSON-encoded records)
input_raw_bytes: bytes | bytearray = flowFile.getContentsAsBytes()

records = []
records: list | dict = []

try:
records = json.loads(input_raw_bytes.decode())
@@ -131,35 +134,70 @@ def transform(self, context: ProcessContext, flowFile: JavaObject) -> FlowFileTr
try:
records = json.loads(input_raw_bytes.decode("windows-1252"))
except json.JSONDecodeError as e:
self.logger.error(f"Error decoding JSON: {str(e)} \n with windows-1252")
raise
return self.build_failure_result(
flowFile,
ValueError(f"Error decoding JSON: {str(e)} \n with windows-1252"),
attributes=attributes,
contents=input_raw_bytes,
)

if not isinstance(records, list):
records = [records]

if not records:
raise ValueError("No records found in JSON input")
return self.build_failure_result(
flowFile,
ValueError("No records found in JSON input"),
attributes=attributes,
contents=input_raw_bytes,
)

# sanity check: blobs are from the same document_id
doc_ids: set = {str(r.get(self.document_id_field_name, "")) for r in records}
if len(doc_ids) > 1:
return self.build_failure_result(
flowFile,
ValueError(f"Multiple document IDs in one FlowFile: {list(doc_ids)}"),
attributes=attributes,
contents=input_raw_bytes,
)

concatenated_blob_sequence_order = {}
output_merged_record = {}
concatenated_blob_sequence_order: dict = {}
output_merged_record: dict = {}

have_any_sequence = any(self.blob_sequence_order_field_name in record for record in records)
have_any_no_sequence = any(self.blob_sequence_order_field_name not in record for record in records)
have_any_sequence: bool = any(self.blob_sequence_order_field_name in record for record in records)
have_any_no_sequence: bool = any(self.blob_sequence_order_field_name not in record for record in records)

if have_any_sequence and have_any_no_sequence:
raise ValueError(
f"Mixed records: some have '{self.blob_sequence_order_field_name}', some don't. "
"Cannot safely reconstruct blob stream."
return self.build_failure_result(
flowFile,
ValueError(
f"Mixed records: some have '{self.blob_sequence_order_field_name}', some don't. "
"Cannot safely reconstruct blob stream."
),
attributes=attributes,
contents=input_raw_bytes,
)

for record in records:
if self.binary_field_name not in record or record[self.binary_field_name] in (None, ""):
raise ValueError(f"Missing '{self.binary_field_name}' in a record")
return self.build_failure_result(
flowFile,
ValueError(f"Missing '{self.binary_field_name}' in a record"),
attributes=attributes,
contents=input_raw_bytes,
)

if have_any_sequence:
seq = int(record[self.blob_sequence_order_field_name])
if seq in concatenated_blob_sequence_order:
raise ValueError(f"Duplicate {self.blob_sequence_order_field_name}: {seq}")
return self.build_failure_result(
flowFile,
ValueError(f"Duplicate {self.blob_sequence_order_field_name}: {seq}"),
attributes=attributes,
contents=input_raw_bytes,
)

concatenated_blob_sequence_order[seq] = record[self.binary_field_name]
else:
# no sequence anywhere: preserve record order (0..n-1)
@@ -174,48 +212,100 @@ def transform(self, context: ProcessContext, flowFile: JavaObject) -> FlowFileTr

full_compressed_blob = bytearray()

for k in sorted(concatenated_blob_sequence_order.keys()):
# double-check that there is no gap in the blob sequence, i.e. a missing blob part.
order_of_blobs_keys = sorted(concatenated_blob_sequence_order.keys())
for i in range(1, len(order_of_blobs_keys)):
if order_of_blobs_keys[i] != order_of_blobs_keys[i-1] + 1:
return self.build_failure_result(
flowFile,
ValueError(
f"Sequence gap: missing {order_of_blobs_keys[i-1] + 1} "
f"(have {order_of_blobs_keys[i-1]} then {order_of_blobs_keys[i]})"
),
attributes=attributes,
contents=input_raw_bytes,
)

for k in order_of_blobs_keys:
v = concatenated_blob_sequence_order[k]

temporary_blob: bytes = b""

if self.binary_field_source_encoding == "base64":
if not isinstance(v, str):
raise ValueError(f"Expected base64 string in {self.binary_field_name} for part {k}, got {type(v)}")
return self.build_failure_result(
flowFile,
ValueError(
f"Expected base64 string in {self.binary_field_name} for part {k}, got {type(v)}"
),
attributes=attributes,
contents=input_raw_bytes,
)
try:
temporary_blob = base64.b64decode(v, validate=True)
except Exception as e:
raise ValueError(f"Error decoding base64 blob part {k}: {e}")
return self.build_failure_result(
flowFile,
ValueError(f"Error decoding base64 blob part {k}: {e}"),
attributes=attributes,
contents=input_raw_bytes,
)
else:
# raw bytes path
if isinstance(v, (bytes, bytearray)):
temporary_blob = v
else:
raise ValueError(f"Expected bytes in {self.binary_field_name} for part {k}, got {type(v)}")

return self.build_failure_result(
flowFile,
ValueError(
f"Expected bytes in {self.binary_field_name} for part {k}, got {type(v)}"
),
attributes=attributes,
contents=input_raw_bytes,
)

full_compressed_blob.extend(temporary_blob)

# build / add new attributes to the dict before doing anything else, so there is some trace.
attributes["document_id_field_name"] = str(self.document_id_field_name)
attributes["document_id"] = str(output_merged_record.get(self.document_id_field_name, ""))
attributes["binary_field"] = str(self.binary_field_name)
attributes["output_text_field_name"] = str(self.output_text_field_name)
attributes["mime.type"] = "application/json"
attributes["blob_parts"] = str(len(order_of_blobs_keys))
attributes["blob_seq_min"] = str(order_of_blobs_keys[0]) if order_of_blobs_keys else ""
attributes["blob_seq_max"] = str(order_of_blobs_keys[-1]) if order_of_blobs_keys else ""
attributes["compressed_len"] = str(len(full_compressed_blob))
attributes["compressed_head_hex"] = bytes(full_compressed_blob[:16]).hex()

try:
decompress_blob = DecompressLzwCernerBlob()
decompress_blob.decompress(full_compressed_blob)
output_merged_record[self.binary_field_name] = decompress_blob.output_stream
output_merged_record[self.binary_field_name] = bytes(decompress_blob.output_stream)
except Exception as exception:
self.logger.error(f"Error decompressing cerner blob: {str(exception)} \n")
raise exception
return self.build_failure_result(
flowFile,
exception=exception,
attributes=attributes,
include_flowfile_attributes=False,
contents=input_raw_bytes
)

if self.output_mode == "base64":
output_merged_record[self.binary_field_name] = \
base64.b64encode(output_merged_record[self.binary_field_name]).decode(self.output_charset)

output_contents.append(output_merged_record)

attributes: dict = {k: str(v) for k, v in flowFile.getAttributes().items()}
attributes["document_id_field_name"] = str(self.document_id_field_name)
attributes["binary_field"] = str(self.binary_field_name)
attributes["output_text_field_name"] = str(self.output_text_field_name)
attributes["mime.type"] = "application/json"

return FlowFileTransformResult(relationship="success",
return FlowFileTransformResult(relationship=self.REL_SUCCESS,
attributes=attributes,
contents=json.dumps(output_contents).encode("utf-8"))
except Exception as exception:
self.logger.error("Exception during flowfile processing: " + traceback.format_exc())
raise exception
return self.build_failure_result(
flowFile,
exception,
attributes=attributes,
contents=locals().get("input_raw_bytes", flowFile.getContentsAsBytes()),
include_flowfile_attributes=False
)
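
Every error path in transform() now routes through build_failure_result, a helper that is not part of this diff. For context, here is a minimal sketch of what such a helper could look like, with its signature inferred from the call sites above (flowFile, exception, attributes, contents, include_flowfile_attributes); the "failure" relationship name and the failure.reason attribute are assumptions, not the actual implementation:

    def build_failure_result(self, flowFile, exception, attributes=None,
                             contents=None, include_flowfile_attributes=True):
        # Hypothetical sketch; the real helper lives outside this diff.
        # Signature inferred from the call sites in transform() above.
        attrs = dict(attributes or {})
        if include_flowfile_attributes:
            # Merge incoming FlowFile attributes without overwriting values
            # the caller already set.
            for k, v in flowFile.getAttributes().items():
                attrs.setdefault(k, str(v))
        attrs["failure.reason"] = str(exception)  # assumed attribute name
        self.logger.error(f"Routing FlowFile to failure: {exception}")
        return FlowFileTransformResult(
            relationship="failure",  # assumed relationship name
            attributes=attrs,
            contents=contents if contents is not None else flowFile.getContentsAsBytes(),
        )

Returning a failure result instead of raising keeps the original content and the accumulated attributes on the FlowFile, so the failure relationship carries enough trace information (document_id, blob_parts, compressed_head_hex, etc.) to debug bad blobs downstream.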
67 changes: 41 additions & 26 deletions nifi/user_scripts/utils/codecs/cerner_blob.py
@@ -1,4 +1,3 @@

class LzwItem:
def __init__(self, _prefix: int = 0, _suffix: int = 0) -> None:
self.prefix = _prefix
@@ -9,7 +8,7 @@ class DecompressLzwCernerBlob:
def __init__(self) -> None:
self.MAX_CODES: int = 8192
self.tmp_decompression_buffer: list[int] = [0] * self.MAX_CODES
self.lzw_lookup_table: list[LzwItem] = [LzwItem()] * self.MAX_CODES
self.lzw_lookup_table: list[LzwItem] = [LzwItem() for _ in range(self.MAX_CODES)]
self.tmp_buffer_index: int = 0
self.current_byte_buffer_index: int = 0

@@ -21,19 +20,24 @@ def save_to_lookup_table(self, compressed_code: int):
self.tmp_buffer_index = -1
while compressed_code >= 258:
self.tmp_buffer_index += 1
self.tmp_decompression_buffer[self.tmp_buffer_index] = \
self.lzw_lookup_table[compressed_code].suffix
self.tmp_decompression_buffer[self.tmp_buffer_index] = self.lzw_lookup_table[compressed_code].suffix
compressed_code = self.lzw_lookup_table[compressed_code].prefix

self.tmp_buffer_index += 1
self.tmp_decompression_buffer[self.tmp_buffer_index] = compressed_code

for i in reversed(list(range(self.tmp_buffer_index + 1))):
self.output_stream.append(self.tmp_decompression_buffer[i])
for i in reversed(range(self.tmp_buffer_index + 1)):
v = self.tmp_decompression_buffer[i]
if not (0 <= v <= 255):
raise ValueError(f"Invalid output byte {v} (expected 0..255)")
self.output_stream.append(v)

def decompress(self, input_stream: bytearray = bytearray()):
def decompress(self, input_stream: bytearray):
if not input_stream:
raise ValueError("Empty input_stream")

byte_buffer_index: int = 0
self.output_stream = bytearray()

# used for bit shifts
shift: int = 1
@@ -49,46 +53,56 @@ def decompress(self, input_stream: bytearray = bytearray()):

while True:
if current_shift >= 9:

current_shift -= 8

if first_code != 0:
byte_buffer_index += 1
if byte_buffer_index >= len(input_stream):
raise ValueError("Truncated input_stream")

middle_code = input_stream[byte_buffer_index]

first_code = (first_code << current_shift +
8) | (middle_code << current_shift)
first_code = (first_code << (current_shift + 8)) | (middle_code << current_shift)

byte_buffer_index += 1

if byte_buffer_index >= len(input_stream):
raise ValueError("Truncated input_stream")

middle_code = input_stream[byte_buffer_index]

tmp_code = middle_code >> (8 - current_shift)
lookup_index = first_code | tmp_code

skip_flag = True
else:
byte_buffer_index += 1
if byte_buffer_index >= len(input_stream):
raise ValueError("Truncated input_stream")
first_code = input_stream[byte_buffer_index]
byte_buffer_index += 1
if byte_buffer_index >= len(input_stream):
raise ValueError("Truncated input_stream")
middle_code = input_stream[byte_buffer_index]
else:
byte_buffer_index += 1
if byte_buffer_index >= len(input_stream):
raise ValueError("Truncated input_stream")
middle_code = input_stream[byte_buffer_index]

if not skip_flag:
lookup_index = (first_code << current_shift) | (
middle_code >> 8 - current_shift)
lookup_index = (first_code << current_shift) | (middle_code >> (8 - current_shift))

if lookup_index == 256:
shift = 1
current_shift += 1
first_code = input_stream[byte_buffer_index]

current_shift = 1
previous_code = 0
skip_flag = False

self.tmp_decompression_buffer = [0] * self.MAX_CODES
self.tmp_buffer_index = 0

self.lzw_lookup_table = [LzwItem()] * self.MAX_CODES
self.lzw_lookup_table = [LzwItem() for _ in range(self.MAX_CODES)]
self.code_count = 257

first_code = input_stream[byte_buffer_index]
continue

elif lookup_index == 257: # EOF marker
@@ -99,18 +113,18 @@ def decompress(self, input_stream: bytearray = bytearray()):
# skipit part
if previous_code == 0:
self.tmp_decompression_buffer[0] = lookup_index

if lookup_index < self.code_count:
self.save_to_lookup_table(lookup_index)
if self.code_count < self.MAX_CODES:
self.lzw_lookup_table[self.code_count] = LzwItem(
previous_code,
self.tmp_decompression_buffer[self.tmp_buffer_index])
self.lzw_lookup_table[self.code_count] = \
LzwItem(previous_code, self.tmp_decompression_buffer[self.tmp_buffer_index])
self.code_count += 1
else:
self.lzw_lookup_table[self.code_count] = LzwItem(
previous_code,
self.tmp_decompression_buffer[self.tmp_buffer_index])
self.code_count += 1
if self.code_count < self.MAX_CODES:
self.lzw_lookup_table[self.code_count] = \
LzwItem(previous_code, self.tmp_decompression_buffer[self.tmp_buffer_index])
self.code_count += 1
self.save_to_lookup_table(lookup_index)
# end of skipit

@@ -120,4 +134,5 @@ def decompress(self, input_stream: bytearray = bytearray()):
if self.code_count in [511, 1023, 2047, 4095]:
shift += 1
current_shift += 1

previous_code = lookup_index
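
For reference, the decoder is driven the same way the processor above uses it: build the concatenated compressed byte stream, call decompress(), then read output_stream. A small standalone sketch (the import path is an assumption based on the file location nifi/user_scripts/utils/codecs/cerner_blob.py):

    import base64

    # Assumed import path for the class defined in this file.
    from utils.codecs.cerner_blob import DecompressLzwCernerBlob

    def decompress_cerner_parts(parts_base64: list[str]) -> bytes:
        # Concatenate the base64-encoded blob parts, which must already be in
        # sequence order, then run the combined stream through the LZW decoder.
        full_compressed_blob = bytearray()
        for part in parts_base64:
            full_compressed_blob.extend(base64.b64decode(part, validate=True))

        decoder = DecompressLzwCernerBlob()
        decoder.decompress(full_compressed_blob)  # raises ValueError on empty or truncated input
        return bytes(decoder.output_stream)

Note that decompress() now rejects an empty stream up front and raises ValueError on truncated input instead of accepting a mutable default-argument bytearray, and the lookup table is rebuilt with a list comprehension so each of the 8192 slots gets its own LzwItem rather than sharing a single instance.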