diff --git a/nifi/user_python_extensions/record_decompress_cerner_blob.py b/nifi/user_python_extensions/record_decompress_cerner_blob.py
index b650beea..718a8a14 100644
--- a/nifi/user_python_extensions/record_decompress_cerner_blob.py
+++ b/nifi/user_python_extensions/record_decompress_cerner_blob.py
@@ -33,6 +33,8 @@ class Java:
 
     class ProcessorDetails:
         version = '0.0.1'
+        description = "Decompresses Cerner LZW compressed blobs from a JSON input stream"
+        tags = ["cerner", "oracle", "blob"]
 
     def __init__(self, jvm: JVMView):
         super().__init__(jvm)
@@ -110,6 +112,7 @@ def transform(self, context: ProcessContext, flowFile: JavaObject) -> FlowFileTr
         """
 
         output_contents: list = []
+        attributes: dict = {k: str(v) for k, v in flowFile.getAttributes().items()}
 
         try:
             self.process_context = context
@@ -118,7 +121,7 @@ def transform(self, context: ProcessContext, flowFile: JavaObject) -> FlowFileTr
             # read avro record
             input_raw_bytes: bytes | bytearray = flowFile.getContentsAsBytes()
 
-            records = []
+            records: list | dict = []
 
            try:
                 records = json.loads(input_raw_bytes.decode())
@@ -131,35 +134,70 @@ def transform(self, context: ProcessContext, flowFile: JavaObject) -> FlowFileTr
                 try:
                     records = json.loads(input_raw_bytes.decode("windows-1252"))
                 except json.JSONDecodeError as e:
-                    self.logger.error(f"Error decoding JSON: {str(e)} \n with windows-1252")
-                    raise
+                    return self.build_failure_result(
+                        flowFile,
+                        ValueError(f"Error decoding JSON: {str(e)} \n with windows-1252"),
+                        attributes=attributes,
+                        contents=input_raw_bytes,
+                    )
 
             if not isinstance(records, list):
                 records = [records]
 
             if not records:
-                raise ValueError("No records found in JSON input")
+                return self.build_failure_result(
+                    flowFile,
+                    ValueError("No records found in JSON input"),
+                    attributes=attributes,
+                    contents=input_raw_bytes,
+                )
+
+            # sanity check: blobs are from the same document_id
+            doc_ids: set = {str(r.get(self.document_id_field_name, "")) for r in records}
+            if len(doc_ids) > 1:
+                return self.build_failure_result(
+                    flowFile,
+                    ValueError(f"Multiple document IDs in one FlowFile: {list(doc_ids)}"),
+                    attributes=attributes,
+                    contents=input_raw_bytes,
+                )
 
-            concatenated_blob_sequence_order = {}
-            output_merged_record = {}
+            concatenated_blob_sequence_order: dict = {}
+            output_merged_record: dict = {}
 
-            have_any_sequence = any(self.blob_sequence_order_field_name in record for record in records)
-            have_any_no_sequence = any(self.blob_sequence_order_field_name not in record for record in records)
+            have_any_sequence: bool = any(self.blob_sequence_order_field_name in record for record in records)
+            have_any_no_sequence: bool = any(self.blob_sequence_order_field_name not in record for record in records)
 
             if have_any_sequence and have_any_no_sequence:
-                raise ValueError(
-                    f"Mixed records: some have '{self.blob_sequence_order_field_name}', some don't. "
-                    "Cannot safely reconstruct blob stream."
+                return self.build_failure_result(
+                    flowFile,
+                    ValueError(
+                        f"Mixed records: some have '{self.blob_sequence_order_field_name}', some don't. "
+                        "Cannot safely reconstruct blob stream."
+                    ),
+                    attributes=attributes,
+                    contents=input_raw_bytes,
                 )
 
             for record in records:
                 if self.binary_field_name not in record or record[self.binary_field_name] in (None, ""):
-                    raise ValueError(f"Missing '{self.binary_field_name}' in a record")
+                    return self.build_failure_result(
+                        flowFile,
+                        ValueError(f"Missing '{self.binary_field_name}' in a record"),
+                        attributes=attributes,
+                        contents=input_raw_bytes,
+                    )
 
                 if have_any_sequence:
                     seq = int(record[self.blob_sequence_order_field_name])
                     if seq in concatenated_blob_sequence_order:
-                        raise ValueError(f"Duplicate {self.blob_sequence_order_field_name}: {seq}")
+                        return self.build_failure_result(
+                            flowFile,
+                            ValueError(f"Duplicate {self.blob_sequence_order_field_name}: {seq}"),
+                            attributes=attributes,
+                            contents=input_raw_bytes,
+                        )
+                    concatenated_blob_sequence_order[seq] = record[self.binary_field_name]
                 else:
                     # no sequence anywhere: preserve record order (0..n-1)
@@ -174,32 +212,84 @@ def transform(self, context: ProcessContext, flowFile: JavaObject) -> FlowFileTr
 
             full_compressed_blob = bytearray()
 
-            for k in sorted(concatenated_blob_sequence_order.keys()):
+            # double check to make sure there is no gap in the blob sequence, i.e missing blob.
+            order_of_blobs_keys = sorted(concatenated_blob_sequence_order.keys())
+            for i in range(1, len(order_of_blobs_keys)):
+                if order_of_blobs_keys[i] != order_of_blobs_keys[i-1] + 1:
+                    return self.build_failure_result(
+                        flowFile,
+                        ValueError(
+                            f"Sequence gap: missing {order_of_blobs_keys[i-1] + 1} "
+                            f"(have {order_of_blobs_keys[i-1]} then {order_of_blobs_keys[i]})"
+                        ),
+                        attributes=attributes,
+                        contents=input_raw_bytes,
+                    )
+
+            for k in order_of_blobs_keys:
                 v = concatenated_blob_sequence_order[k]
+                temporary_blob: bytes = b""
+
                 if self.binary_field_source_encoding == "base64":
                     if not isinstance(v, str):
-                        raise ValueError(f"Expected base64 string in {self.binary_field_name} for part {k}, got {type(v)}")
+                        return self.build_failure_result(
+                            flowFile,
+                            ValueError(
+                                f"Expected base64 string in {self.binary_field_name} for part {k}, got {type(v)}"
+                            ),
+                            attributes=attributes,
+                            contents=input_raw_bytes,
+                        )
                     try:
                         temporary_blob = base64.b64decode(v, validate=True)
                     except Exception as e:
-                        raise ValueError(f"Error decoding base64 blob part {k}: {e}")
+                        return self.build_failure_result(
+                            flowFile,
+                            ValueError(f"Error decoding base64 blob part {k}: {e}"),
+                            attributes=attributes,
+                            contents=input_raw_bytes,
+                        )
                 else:
                     # raw bytes path
                     if isinstance(v, (bytes, bytearray)):
                         temporary_blob = v
                     else:
-                        raise ValueError(f"Expected bytes in {self.binary_field_name} for part {k}, got {type(v)}")
-
+                        return self.build_failure_result(
+                            flowFile,
+                            ValueError(
+                                f"Expected bytes in {self.binary_field_name} for part {k}, got {type(v)}"
+                            ),
+                            attributes=attributes,
+                            contents=input_raw_bytes,
+                        )
+
                 full_compressed_blob.extend(temporary_blob)
 
+            # build / add new attributes to dict before doing anything else to have some trace.
+            attributes["document_id_field_name"] = str(self.document_id_field_name)
+            attributes["document_id"] = str(output_merged_record.get(self.document_id_field_name, ""))
+            attributes["binary_field"] = str(self.binary_field_name)
+            attributes["output_text_field_name"] = str(self.output_text_field_name)
+            attributes["mime.type"] = "application/json"
+            attributes["blob_parts"] = str(len(order_of_blobs_keys))
+            attributes["blob_seq_min"] = str(order_of_blobs_keys[0]) if order_of_blobs_keys else ""
+            attributes["blob_seq_max"] = str(order_of_blobs_keys[-1]) if order_of_blobs_keys else ""
+            attributes["compressed_len"] = str(len(full_compressed_blob))
+            attributes["compressed_head_hex"] = bytes(full_compressed_blob[:16]).hex()
+
             try:
                 decompress_blob = DecompressLzwCernerBlob()
                 decompress_blob.decompress(full_compressed_blob)
-                output_merged_record[self.binary_field_name] = decompress_blob.output_stream
+                output_merged_record[self.binary_field_name] = bytes(decompress_blob.output_stream)
             except Exception as exception:
-                self.logger.error(f"Error decompressing cerner blob: {str(exception)} \n")
-                raise exception
+                return self.build_failure_result(
+                    flowFile,
+                    exception=exception,
+                    attributes=attributes,
+                    include_flowfile_attributes=False,
+                    contents=input_raw_bytes
+                )
 
             if self.output_mode == "base64":
                 output_merged_record[self.binary_field_name] = \
@@ -207,15 +297,15 @@ def transform(self, context: ProcessContext, flowFile: JavaObject) -> FlowFileTr
 
             output_contents.append(output_merged_record)
 
-            attributes: dict = {k: str(v) for k, v in flowFile.getAttributes().items()}
-            attributes["document_id_field_name"] = str(self.document_id_field_name)
-            attributes["binary_field"] = str(self.binary_field_name)
-            attributes["output_text_field_name"] = str(self.output_text_field_name)
-            attributes["mime.type"] = "application/json"
-
-            return FlowFileTransformResult(relationship="success",
+            return FlowFileTransformResult(relationship=self.REL_SUCCESS,
                                            attributes=attributes,
                                            contents=json.dumps(output_contents).encode("utf-8"))
         except Exception as exception:
             self.logger.error("Exception during flowfile processing: " + traceback.format_exc())
-            raise exception
+            return self.build_failure_result(
+                flowFile,
+                exception,
+                attributes=attributes,
+                contents=locals().get("input_raw_bytes", flowFile.getContentsAsBytes()),
+                include_flowfile_attributes=False
+            )
diff --git a/nifi/user_scripts/utils/codecs/cerner_blob.py b/nifi/user_scripts/utils/codecs/cerner_blob.py
index cfd299bf..9596016c 100644
--- a/nifi/user_scripts/utils/codecs/cerner_blob.py
+++ b/nifi/user_scripts/utils/codecs/cerner_blob.py
@@ -1,4 +1,3 @@
-
 class LzwItem:
     def __init__(self, _prefix: int = 0, _suffix: int = 0) -> None:
         self.prefix = _prefix
@@ -9,7 +8,7 @@ class DecompressLzwCernerBlob:
     def __init__(self) -> None:
         self.MAX_CODES: int = 8192
         self.tmp_decompression_buffer: list[int] = [0] * self.MAX_CODES
-        self.lzw_lookup_table: list[LzwItem] = [LzwItem()] * self.MAX_CODES
+        self.lzw_lookup_table: list[LzwItem] = [LzwItem() for _ in range(self.MAX_CODES)]
 
         self.tmp_buffer_index: int = 0
         self.current_byte_buffer_index: int = 0
@@ -21,19 +20,24 @@ def save_to_lookup_table(self, compressed_code: int):
         self.tmp_buffer_index = -1
         while compressed_code >= 258:
             self.tmp_buffer_index += 1
-            self.tmp_decompression_buffer[self.tmp_buffer_index] = \
-                self.lzw_lookup_table[compressed_code].suffix
+            self.tmp_decompression_buffer[self.tmp_buffer_index] = self.lzw_lookup_table[compressed_code].suffix
             compressed_code = self.lzw_lookup_table[compressed_code].prefix
 
         self.tmp_buffer_index += 1
         self.tmp_decompression_buffer[self.tmp_buffer_index] = compressed_code
 
-        for i in reversed(list(range(self.tmp_buffer_index + 1))):
-            self.output_stream.append(self.tmp_decompression_buffer[i])
+        for i in reversed(range(self.tmp_buffer_index + 1)):
+            v = self.tmp_decompression_buffer[i]
+            if not (0 <= v <= 255):
+                raise ValueError(f"Invalid output byte {v} (expected 0..255)")
+            self.output_stream.append(v)
 
-    def decompress(self, input_stream: bytearray = bytearray()):
+    def decompress(self, input_stream: bytearray):
+        if not input_stream:
+            raise ValueError("Empty input_stream")
 
         byte_buffer_index: int = 0
+        self.output_stream = bytearray()
 
         # used for bit shifts
         shift: int = 1
@@ -49,46 +53,56 @@ def decompress(self, input_stream: bytearray = bytearray()):
 
         while True:
             if current_shift >= 9:
-
                 current_shift -= 8
-
                 if first_code != 0:
                     byte_buffer_index += 1
+                    if byte_buffer_index >= len(input_stream):
+                        raise ValueError("Truncated input_stream")
+
                     middle_code = input_stream[byte_buffer_index]
-                    first_code = (first_code << current_shift +
-                                  8) | (middle_code << current_shift)
+                    first_code = (first_code << (current_shift + 8)) | (middle_code << current_shift)
                     byte_buffer_index += 1
+
+                    if byte_buffer_index >= len(input_stream):
+                        raise ValueError("Truncated input_stream")
+
                    middle_code = input_stream[byte_buffer_index]
                     tmp_code = middle_code >> (8 - current_shift)
                     lookup_index = first_code | tmp_code
-
                     skip_flag = True
                 else:
                     byte_buffer_index += 1
+                    if byte_buffer_index >= len(input_stream):
+                        raise ValueError("Truncated input_stream")
                     first_code = input_stream[byte_buffer_index]
                     byte_buffer_index += 1
+                    if byte_buffer_index >= len(input_stream):
+                        raise ValueError("Truncated input_stream")
                     middle_code = input_stream[byte_buffer_index]
             else:
                 byte_buffer_index += 1
+                if byte_buffer_index >= len(input_stream):
+                    raise ValueError("Truncated input_stream")
                 middle_code = input_stream[byte_buffer_index]
 
             if not skip_flag:
-                lookup_index = (first_code << current_shift) | (
-                    middle_code >> 8 - current_shift)
+                lookup_index = (first_code << current_shift) | (middle_code >> (8 - current_shift))
 
             if lookup_index == 256:
                 shift = 1
-                current_shift += 1
-                first_code = input_stream[byte_buffer_index]
-
+                current_shift = 1
+                previous_code = 0
+                skip_flag = False
+
                 self.tmp_decompression_buffer = [0] * self.MAX_CODES
                 self.tmp_buffer_index = 0
-
-                self.lzw_lookup_table = [LzwItem()] * self.MAX_CODES
+                self.lzw_lookup_table = [LzwItem() for _ in range(self.MAX_CODES)]
                 self.code_count = 257
+
+                first_code = input_stream[byte_buffer_index]
                 continue
             elif lookup_index == 257:
                 # EOF marker
@@ -99,18 +113,18 @@ def decompress(self, input_stream: bytearray = bytearray()):
             # skipit part
             if previous_code == 0:
                 self.tmp_decompression_buffer[0] = lookup_index
+
             if lookup_index < self.code_count:
                 self.save_to_lookup_table(lookup_index)
                 if self.code_count < self.MAX_CODES:
-                    self.lzw_lookup_table[self.code_count] = LzwItem(
-                        previous_code,
-                        self.tmp_decompression_buffer[self.tmp_buffer_index])
+                    self.lzw_lookup_table[self.code_count] = \
+                        LzwItem(previous_code, self.tmp_decompression_buffer[self.tmp_buffer_index])
                     self.code_count += 1
             else:
-                self.lzw_lookup_table[self.code_count] = LzwItem(
-                    previous_code,
-                    self.tmp_decompression_buffer[self.tmp_buffer_index])
-                self.code_count += 1
+                if self.code_count < self.MAX_CODES:
+                    self.lzw_lookup_table[self.code_count] = \
+                        LzwItem(previous_code, self.tmp_decompression_buffer[self.tmp_buffer_index])
+                    self.code_count += 1
                 self.save_to_lookup_table(lookup_index)
             # end of skipit
 
@@ -120,4 +134,5 @@ def decompress(self, input_stream: bytearray = bytearray()):
 
             if self.code_count in [511, 1023, 2047, 4095]:
                 shift += 1
                 current_shift += 1
+            previous_code = lookup_index
diff --git a/nifi/user_scripts/utils/nifi/base_nifi_processor.py b/nifi/user_scripts/utils/nifi/base_nifi_processor.py
index 8e28ef3b..f29478cf 100644
--- a/nifi/user_scripts/utils/nifi/base_nifi_processor.py
+++ b/nifi/user_scripts/utils/nifi/base_nifi_processor.py
@@ -163,8 +163,11 @@ def build_failure_result(
         self,
         flowFile: JavaObject,
         exception: Exception,
-        *,
+        attributes: dict | None = None,
         include_flowfile_attributes: bool = False,
+        contents: bytes | bytearray | None = None,
+        *args,
+        **kwargs,
     ) -> FlowFileTransformResult:
         """
         Build a failure FlowFileTransformResult with exception metadata.
@@ -172,7 +175,9 @@
         Args:
             flowFile: The FlowFile being processed.
             exception: The exception raised during processing.
-            include_flowfile_attributes: If true, include all FlowFile attributes.
+            attributes: Optional pre-built attributes dict to use/extend.
+            include_flowfile_attributes: If true, merge in all FlowFile attributes.
+            contents: Optional override for contents; defaults to the incoming FlowFile contents.
 
         Returns:
             A FlowFileTransformResult targeting the failure relationship.
@@ -180,19 +185,24 @@
         exception_name = type(exception).__name__
         exception_message = str(exception)
 
-        exception_value = (
-            f"{exception_name}: {exception_message}" if exception_message else exception_name
-        )
+        exception_value = f"{exception_name}: {exception_message}" if exception_message else exception_name
+
+        merged_attributes: dict[str, str] = {}
+        if attributes:
+            merged_attributes.update({k: str(v) for k, v in attributes.items()})
 
-        attributes = {}
         if include_flowfile_attributes:
-            attributes = {k: str(v) for k, v in flowFile.getAttributes().items()}
-        attributes["exception"] = exception_value
+            merged_attributes.update({k: str(v) for k, v in flowFile.getAttributes().items()})
+
+        merged_attributes["exception"] = exception_value
+
+        if contents is None:
+            contents = flowFile.getContentsAsBytes()
 
         return FlowFileTransformResult(
-            relationship="failure",
-            attributes=attributes,
-            contents=flowFile.getContentsAsBytes(),
+            relationship=self.REL_FAILURE,
+            attributes=merged_attributes,
+            contents=contents
         )
 
     def onScheduled(self, context: ProcessContext) -> None:
diff --git a/typings/nifiapi/flowfiletransform.pyi b/typings/nifiapi/flowfiletransform.pyi
index 65306797..c1950c51 100644
--- a/typings/nifiapi/flowfiletransform.pyi
+++ b/typings/nifiapi/flowfiletransform.pyi
@@ -1,5 +1,6 @@
 from typing import Any, Protocol
 
+from nifiapi.relationship import Relationship
 from py4j.java_gateway import JavaObject
 
 from .properties import ProcessContext
@@ -13,7 +14,7 @@ class FlowFileTransform(Protocol):
 class FlowFileTransformResult:
     def __init__(
         self,
-        relationship: str,
+        relationship: str | Relationship,
         attributes: dict[str, str],
         contents: bytes | None = None,
     ) -> None: ...