diff --git a/batbot/spectrogram/__init__.py b/batbot/spectrogram/__init__.py index 99fb956..8b94716 100644 --- a/batbot/spectrogram/__init__.py +++ b/batbot/spectrogram/__init__.py @@ -1380,6 +1380,20 @@ def calculate_harmonic_and_echo_flags( return harmonic_flag, harmonic_peak, echo_flag, echo_peak +def create_masked_image(stft_db, costs, kernel=11): + weights = costs.copy() + weights[weights < weights.mean()] = 0 + weights[weights > 0] = weights.max() + weights = weights.astype(np.float32) + weights = cv2.GaussianBlur( + weights, (kernel, kernel), sigmaX=4, sigmaY=4, borderType=cv2.BORDER_DEFAULT + ) + weights /= weights.max() + masked = stft_db * weights + masked = normalize_stft(masked, None, np.uint8) + return masked + + # @lp def compute_wrapper( wav_filepath, @@ -1823,6 +1837,8 @@ def compute_wrapper( output_paths = [] compressed_paths = [] + mask_paths = [] + masked_paths = [] if not fast_mode: datas = [ (output_paths, 'jpg', stft_db), @@ -1834,6 +1850,14 @@ def compute_wrapper( (compressed_paths, 'compressed.jpg', segments['stft_db']), ] + # Create masked image + if 'costs' in segments and 'stft_db' in segments: + masked = create_masked_image(segments['stft_db'], segments['costs']) + datas += [ + (mask_paths, 'mask.jpg', segments['costs']), + (masked_paths, 'masked.jpg', masked), + ] + for accumulator, tag, data in datas: if data.dtype != np.uint8: data_ = data.astype(np.float32) @@ -1863,6 +1887,8 @@ def compute_wrapper( 'spectrogram': { 'uncompressed.path': output_paths, 'compressed.path': compressed_paths, + 'mask.path': mask_paths, + 'masked.path': masked_paths, }, 'global_threshold.amp': int(round(255.0 * (global_threshold / max_value))), 'sr.hz': int(sr), @@ -1886,6 +1912,9 @@ def compute_wrapper( 'width.px': segments['stft_db'].shape[1], 'height.px': segments['stft_db'].shape[0], } + if 'costs' in segments and 'stft_db' in segments: + metadata['size']['mask'] = metadata['size']['compressed'] + metadata['size']['masked'] = metadata['size']['compressed'] metadata_path = f'{out_file_stem}.metadata.json' with open(metadata_path, 'w') as metafile: diff --git a/setup.cfg b/setup.cfg index d73eb80..3cf199b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,7 +38,7 @@ python_requires = >=3.7 [options.entry_points] console_scripts = - batbot = batbot.batbot:cli + batbot = batbot.batbot_cli:cli [tool:pytest] minversion = 5.4