diff --git a/src/post_processing/dataclass/data_aplose.py b/src/post_processing/dataclass/data_aplose.py index 98e6d9c..e1d3cae 100644 --- a/src/post_processing/dataclass/data_aplose.py +++ b/src/post_processing/dataclass/data_aplose.py @@ -393,7 +393,6 @@ def plot( color = kwargs.get("color") season = kwargs.get("season") effort = kwargs.get("effort") - if not bin_size: msg = "'bin_size' missing for histogram plot." raise ValueError(msg) @@ -417,18 +416,20 @@ def plot( season = kwargs.get("season", False) bin_size = kwargs.get("bin_size") - return heatmap(df=df_filtered, - ax=ax, - bin_size=bin_size, - time_range=time, - show_rise_set=show_rise_set, - season=season, - coordinates=self.coordinates, - ) + return heatmap( + df=df_filtered, + ax=ax, + bin_size=bin_size, + time_range=time, + show_rise_set=show_rise_set, + season=season, + coordinates=self.coordinates, + ) if mode == "scatter": show_rise_set = kwargs.get("show_rise_set", True) season = kwargs.get("season", False) + effort = kwargs.get("effort") return scatter(df=df_filtered, ax=ax, @@ -436,6 +437,7 @@ def plot( show_rise_set=show_rise_set, season=season, coordinates=self.coordinates, + effort=effort, ) if mode == "agreement": diff --git a/src/post_processing/dataclass/detection_filter.py b/src/post_processing/dataclass/detection_filter.py index d636c4c..b28c023 100644 --- a/src/post_processing/dataclass/detection_filter.py +++ b/src/post_processing/dataclass/detection_filter.py @@ -7,7 +7,7 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, fields from pathlib import Path from typing import TYPE_CHECKING, Literal @@ -44,6 +44,12 @@ class DetectionFilter: box: bool = False filename_format: str = None + def __getitem__(self, key: str): + """Return the value of the given key.""" + if key in {f.name for f in fields(self)}: + return getattr(self, key) + raise KeyError(key) + @classmethod def from_yaml( cls, diff --git a/src/post_processing/dataclass/recording_period.py b/src/post_processing/dataclass/recording_period.py index 4c09722..d0d48b7 100644 --- a/src/post_processing/dataclass/recording_period.py +++ b/src/post_processing/dataclass/recording_period.py @@ -8,19 +8,16 @@ from dataclasses import dataclass from typing import TYPE_CHECKING -from osekit.config import TIMESTAMP_FORMATS_EXPORTED_FILES -from osekit.utils.timestamp_utils import strptime_from_text from pandas import ( + IntervalIndex, Series, Timedelta, - cut, + date_range, read_csv, + to_datetime, ) -from post_processing.utils.core_utils import ( - get_time_range_and_bin_size, - localize_timestamps, -) +from post_processing.utils.core_utils import round_begin_end_timestamps from post_processing.utils.filtering_utils import ( find_delimiter, ) @@ -33,7 +30,7 @@ @dataclass(frozen=True) class RecordingPeriod: - """A class to handle recording periods.""" + """Represents recording effort over time, aggregated into bins.""" counts: Series timebin_origin: Timedelta @@ -42,33 +39,125 @@ class RecordingPeriod: def from_path( cls, config: DetectionFilter, - date_format: str = TIMESTAMP_FORMATS_EXPORTED_FILES, *, bin_size: Timedelta | BaseOffset, ) -> RecordingPeriod: - """Return a list of Timestamps corresponding to recording periods.""" + """Vectorised creation of recording coverage from CSV with start/end datetimes. + + This method reads a CSV with columns: + - "start_recording" + - "end_recording" + - "start_deployment" + - "end_deployment" + + It computes the **effective recording interval** as the intersection between + recording and deployment periods, builds a fine-grained timeline at + `timebin_origin` resolution, and aggregates effort into `bin_size` bins. + + Parameters + ---------- + config + Configuration object containing at least: + - `timestamp_file`: path to CSV + - `timebin_origin`: Timedelta resolution of detections + bin_size : Timedelta or BaseOffset + Size of the aggregation bin (e.g. Timedelta("1H") or "1D"). + + Returns + ------- + RecordingPeriod + Object containing `counts` (Series indexed by IntervalIndex) and + `timebin_origin`. + + """ + # Read CSV and parse datetime columns timestamp_file = config.timestamp_file delim = find_delimiter(timestamp_file) - timestamp_df = read_csv(timestamp_file, delimiter=delim) - - if "timestamp" in timestamp_df.columns: - msg = "Parsing 'timestamp' column not implemented yet." - raise NotImplementedError(msg) - - if "filename" in timestamp_df.columns: - timestamps = [ - strptime_from_text(ts, date_format) - for ts in timestamp_df["filename"] - ] - timestamps = localize_timestamps(timestamps, config.timezone) - time_vector, bin_size = get_time_range_and_bin_size(timestamps, bin_size) - - binned = cut(timestamps, time_vector) - max_annot = bin_size / config.timebin_origin - - return cls(counts=binned.value_counts().sort_index().clip(upper=max_annot), - timebin_origin=config.timebin_origin, - ) - - msg = "Could not parse timestamps." - raise ValueError(msg) + df = read_csv( + config.timestamp_file, + parse_dates=[ + "start_recording", + "end_recording", + "start_deployment", + "end_deployment", + ], + delimiter=delim, + ) + + if df.empty: + msg = "CSV is empty." + raise ValueError(msg) + + # Ensure all required columns are present + required_columns = { + "start_recording", + "end_recording", + "start_deployment", + "end_deployment", + } + + missing = required_columns - set(df.columns) + + if missing: + msg = f"CSV is missing required columns: {', '.join(sorted(missing))}" + raise ValueError(msg) + + # Normalise timezones: convert to UTC, then remove tz info (naive) + for col in [ + "start_recording", + "end_recording", + "start_deployment", + "end_deployment", + ]: + df[col] = to_datetime(df[col], utc=True).dt.tz_convert(None) + + # Compute effective recording intervals (intersection) + df["effective_start_recording"] = df[ + ["start_recording", "start_deployment"] + ].max(axis=1) + + df["effective_end_recording"] = df[ + ["end_recording", "end_deployment"] + ].min(axis=1) + + # Remove rows with no actual recording interval + df = df.loc[ + df["effective_start_recording"] < df["effective_end_recording"] + ].copy() + + if df.empty: + msg = "No valid recording intervals after deployment intersection." + raise ValueError(msg) + + # Build fine-grained timeline at `timebin_origin` resolution + origin = config.timebin_origin + time_index = date_range( + start=df["effective_start_recording"].min(), + end=df["effective_end_recording"].max(), + freq=origin, + ) + + # Initialise effort vector (0 = no recording, 1 = recording) + # Compare each timestamp to all intervals in a vectorised manner + effort = Series(0, index=time_index) + + # Vectorised interval coverage + t_vals = time_index.to_numpy()[:, None] + start_vals = df["effective_start_recording"].to_numpy() + end_vals = df["effective_end_recording"].to_numpy() + + # Boolean matrix: True if the timestamp is within any recording interval + covered = (t_vals >= start_vals) & (t_vals < end_vals) + effort[:] = covered.any(axis=1).astype(int) + + # Aggregate effort into user-defined bin_size + counts = effort.resample(bin_size, closed="left", label="left").sum() + + counts.index = IntervalIndex.from_arrays( + counts.index, + counts.index + + round_begin_end_timestamps(list(counts.index), bin_size)[-1], + closed="left", + ) + + return cls(counts=counts, timebin_origin=origin) diff --git a/src/post_processing/utils/core_utils.py b/src/post_processing/utils/core_utils.py index 5a831e1..c149155 100644 --- a/src/post_processing/utils/core_utils.py +++ b/src/post_processing/utils/core_utils.py @@ -3,7 +3,7 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import astral import easygui @@ -11,7 +11,7 @@ from astral.sun import sunrise, sunset from matplotlib import pyplot as plt from osekit.config import TIMESTAMP_FORMAT_AUDIO_FILE -from osekit.utils.timestamp_utils import strptime_from_text, strftime_osmose_format +from osekit.utils.timestamp_utils import strftime_osmose_format, strptime_from_text from pandas import ( DataFrame, DatetimeIndex, @@ -44,10 +44,10 @@ def get_season(ts: Timestamp, *, northern: bool = True) -> tuple[str, int]: """Determine the meteorological season from a Timestamp. - In the Northern hemisphere + In the Northern Hemisphere Winter: Dec-Feb, Spring: Mar-May, Summer: Jun-Aug, Autumn: Sep-Nov - In the Southern hemisphere + In the Southern Hemisphere Winter: Jun-Aug, Spring: Sep-Nov, Summer: Dec-Feb, Autumn: Mar-May Parameters @@ -133,8 +133,18 @@ def get_sun_times( dt_sunset = Timestamp(sunset(gps.observer, date=date)).tz_convert(tz) # Convert sunrise and sunset to decimal hours - h_sunrise.append(dt_sunrise.hour + dt_sunrise.minute / 60 + dt_sunrise.second / 3600 + dt_sunrise.microsecond / 3_600_000_000) - h_sunset.append(dt_sunset.hour + dt_sunset.minute / 60 + dt_sunset.second / 3600 + dt_sunset.microsecond / 3_600_000_000) + h_sunrise.append( + dt_sunrise.hour + + dt_sunrise.minute / 60 + + dt_sunrise.second / 3600 + + dt_sunrise.microsecond / 3_600_000_000, + ) + h_sunset.append( + dt_sunset.hour + + dt_sunset.minute / 60 + + dt_sunset.second / 3600 + + dt_sunset.microsecond / 3_600_000_000, + ) return h_sunrise, h_sunset @@ -215,8 +225,12 @@ def add_weak_detection( if not max_time: max_time = Timedelta(get_max_time(df), "s") - df["start_datetime"] = [strftime_osmose_format(start) for start in df["start_datetime"]] - df["end_datetime"] = [strftime_osmose_format(stop) for stop in df["end_datetime"]] + df["start_datetime"] = [ + strftime_osmose_format(start) for start in df["start_datetime"] + ] + df["end_datetime"] = [ + strftime_osmose_format(stop) for stop in df["end_datetime"] + ] for ant in annotators: for lbl in labels: @@ -255,17 +269,16 @@ def add_weak_detection( new_line.append(np.nan) df.loc[df.index.max() + 1] = new_line - return df.sort_values(by=["start_datetime", "annotator"]).reset_index(drop=True) def json2df(json_path: Path) -> DataFrame: - """Convert a metadatax json file into a DataFrame. + """Convert a metadatax JSON file into a DataFrame. Parameters ---------- json_path: Path - Json file path + JSON file path """ with json_path.open(encoding="utf-8") as f: @@ -301,15 +314,15 @@ def add_season_period( raise ValueError(msg) bins = date_range( - start=Timestamp(ax.get_xlim()[0], unit="D").floor("1D"), - end=Timestamp(ax.get_xlim()[1], unit="D").ceil("1D"), + start=Timestamp(ax.get_xlim()[0], unit="D"), + end=Timestamp(ax.get_xlim()[1], unit="D"), ) season_colors = { - "winter": "#2ce5e3", - "spring": "#4fcf50", - "summer": "#ffcf50", - "autumn": "#fb9a67", + "winter": "#84eceb", + "spring": "#91de92", + "summer": "#fce097", + "autumn": "#f9c1a5", } bin_centers = [ @@ -330,8 +343,9 @@ def add_season_period( width=(bins[i + 1] - bins[i]), color=season_colors[season], align="center", - zorder=0, - alpha=0.6, + zorder=2, + alpha=1, + linewidth=0, ) ax.set_ylim(ax.dataLim.ymin, ax.dataLim.ymax) @@ -480,7 +494,7 @@ def get_labels_and_annotators(df: DataFrame) -> tuple[list, list]: def localize_timestamps(timestamps: list[Timestamp], tz: tzinfo) -> list[Timestamp]: - """Localize timestamps if necessary.""" + """Localise timestamps if necessary.""" localized = [] for ts in timestamps: if ts.tzinfo is None or ts.tzinfo.utcoffset(ts) is None: @@ -509,20 +523,20 @@ def get_time_range_and_bin_size( if isinstance(bin_size, Timedelta): return timestamp_range, bin_size - elif isinstance(bin_size, BaseOffset): + if isinstance(bin_size, BaseOffset): return timestamp_range, timestamp_range[1] - timestamp_range[0] - else: - msg = "bin_size must be a Timedelta or BaseOffset." - raise TypeError(msg) + msg = "bin_size must be a Timedelta or BaseOffset." + raise TypeError(msg) def round_begin_end_timestamps( timestamp_list: list[Timestamp], bin_size: Timedelta | BaseOffset, -) -> tuple[Timestamp, Timestamp, Timedelta]: +) -> tuple[Any, Any, Any]: """Return time vector given a bin size.""" - if (not isinstance(timestamp_list, list) or - not all(isinstance(ts, Timestamp) for ts in timestamp_list)): + if not isinstance(timestamp_list, list) or not all( + isinstance(ts, Timestamp) for ts in timestamp_list + ): msg = "timestamp_list must be a list[Timestamp]" raise TypeError(msg) @@ -546,14 +560,16 @@ def round_begin_end_timestamps( timestamp_range = date_range(start=start, end=end, freq=bin_size) bin_size = timestamp_range[1] - timestamp_range[0] - return start.floor(bin_size), end.ceil(bin_size), bin_size + if bin_size.resolution_string in {"s", "min", "h"}: + return start.floor(bin_size), end.ceil(bin_size), bin_size + return start, end, bin_size msg = "Could not get start/end timestamps." raise ValueError(msg) def timedelta_to_str(td: Timedelta) -> str: - """From a Timedelta to corresponding string.""" + """From a Timedelta to the corresponding string.""" seconds = int(td.total_seconds()) if seconds % 86400 == 0: diff --git a/src/post_processing/utils/filtering_utils.py b/src/post_processing/utils/filtering_utils.py index c391ff6..faf2fd7 100644 --- a/src/post_processing/utils/filtering_utils.py +++ b/src/post_processing/utils/filtering_utils.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING import pytz +from osekit.utils.timestamp_utils import strptime_from_text from pandas import ( DataFrame, Timedelta, @@ -509,8 +510,8 @@ def reshape_timebin( timebin_new: Timedelta The size of the new time bin. timestamp_audio: list[Timestamp] - A list of Timestamp objects corresponding to the shape - in which the data should be reshaped. + A list of Timestamp objects corresponding to the start of each wav + that corresponds to a detection Returns ------- @@ -570,16 +571,17 @@ def get_filename_timestamps(df: DataFrame, date_parser: str) -> list[Timestamp]: """ tz = get_timezone(df) - try: - return [ - to_datetime( - ts, - format=date_parser, - ).tz_localize(tz) for ts in df["filename"] - ] - except ValueError: - msg = """Could not parse timestamps from `df["filename"]`.""" - raise ValueError(msg) from None + timestamps = [ + strptime_from_text( + ts, + datetime_template=date_parser, + ) for ts in df["filename"] + ] + + if all(t.tz is None for t in timestamps): + timestamps = [t.tz_localize(tz) for t in timestamps] + + return timestamps def ensure_in_list(value: str, candidates: list[str], label: str) -> None: diff --git a/src/post_processing/utils/plot_utils.py b/src/post_processing/utils/plot_utils.py index f21e343..8d12fa3 100644 --- a/src/post_processing/utils/plot_utils.py +++ b/src/post_processing/utils/plot_utils.py @@ -11,9 +11,17 @@ import numpy as np from matplotlib import dates as mdates from matplotlib.dates import num2date -from matplotlib.ticker import PercentFormatter +from matplotlib.patches import Patch from numpy import ceil, histogram, polyfit -from pandas import DataFrame, DatetimeIndex, Index, Timedelta, Timestamp, date_range +from pandas import ( + DataFrame, + DatetimeIndex, + Series, + Timedelta, + Timestamp, + concat, + date_range, +) from pandas.tseries import frequencies from scipy.stats import pearsonr from seaborn import scatterplot @@ -28,11 +36,10 @@ timedelta_to_str, ) from post_processing.utils.filtering_utils import ( + filter_by_annotator, get_max_time, get_timezone, - filter_by_annotator, ) -from post_processing.utils.metrics_utils import normalize_counts_by_effort if TYPE_CHECKING: from datetime import tzinfo @@ -68,15 +75,15 @@ def histo( - legend: bool Whether to show the legend. - color: str | list[str] - Color or list of colors for the histogram bars. - If not provided, default colors will be used. + Colour or list of colours for the histogram bars. + If not provided, default colours will be used. - season: bool Whether to show the season. - coordinates: tuple[float, float] The coordinates of the plotted detections. - effort: RecordingPeriod Object corresponding to the observation effort. - If provided, data will be normalized by observation effort. + If provided, data will be normalised by observation effort. """ labels, annotators = zip(*[col.rsplit("-", 1) for col in df.columns], strict=False) @@ -107,9 +114,6 @@ def histo( else: legend_labels = None - if effort: - normalize_counts_by_effort(df, effort, time_bin) - n_groups = len(labels) if legend_labels else 1 bar_width = bin_size / n_groups bin_starts = mdates.date2num(df.index) @@ -130,31 +134,29 @@ def histo( ax.bar(bin_starts + offset, df.iloc[:, i], **bar_kwargs) if len(df.columns) > 1 and legend: - ax.legend(labels=legend_labels, bbox_to_anchor=(1.01, 1), loc="upper left") + ax.legend( + labels=legend_labels, + bbox_to_anchor=(1.01, 1), + loc="upper left", + ) - y_label = ( - f"Detections{(' normalized by effort' if effort else '')}" - f"\n(detections: {timedelta_to_str(time_bin)}" - f" - bin size: {bin_size_str})" - ) - ax.set_ylabel(y_label) - set_y_axis_to_percentage(ax) if effort else set_dynamic_ylim(ax, df) + ax.set_ylabel(f"Detections ({timedelta_to_str(time_bin)})") + ax.set_xlabel(f"Bin size ({bin_size_str})") set_plot_title(ax, annotators, labels) ax.set_xlim(begin, end) - if season: - if lat is None or lon is None: - get_coordinates() - add_season_period(ax, northern=lat >= 0) - if effort: shade_no_effort( ax=ax, - bin_starts=df.index, observed=effort, - bar_width=bin_size, + legend=legend, ) + if season: + if lat is None or lon is None: + get_coordinates() + add_season_period(ax, northern=lat >= 0) + def _prepare_timeline_plot( df: DataFrame, @@ -196,7 +198,6 @@ def _prepare_timeline_plot( ax.set_ylim(0, 24) ax.set_yticks(range(0, 25, 2)) ax.set_ylabel("Hour") - ax.set_xlabel("Date") ax.grid(color="k", linestyle="-", linewidth=0.2) set_plot_title(ax=ax, annotators=annotators, labels=labels) @@ -213,7 +214,7 @@ def scatter( df: DataFrame, ax: Axes, time_range: DatetimeIndex, - **kwargs: bool | tuple[float, float], + **kwargs: bool | tuple[float, float] | RecordingPeriod, ) -> None: """Scatter-plot of detections for a given annotator and label. @@ -237,6 +238,7 @@ def scatter( show_rise_set = kwargs.get("show_rise_set", False) season = kwargs.get("season", False) coordinates = kwargs.get("coordinates", False) + effort = kwargs.get("effort", False) _prepare_timeline_plot( df=df, @@ -276,6 +278,12 @@ def scatter( framealpha=0.6, ) + if effort: + shade_no_effort( + ax=ax, + observed=effort, + ) + def heatmap(df: DataFrame, ax: Axes, @@ -359,7 +367,7 @@ def heatmap(df: DataFrame, ) if coordinates and season: - lat, lon = coordinates + lat, _ = coordinates add_season_period(ax, northern=lat >= 0) bin_size_str = get_bin_size_str(bin_size) @@ -466,7 +474,7 @@ def agreement( bin_size: Timedelta | BaseOffset, ax: plt.Axes, ) -> None: - """Compute and visualize agreement between two annotators. + """Compute and visualise agreement between two annotators. This function compares annotation timestamps from two annotators over a time range. It also fits and plots a linear regression line and displays the coefficient @@ -487,41 +495,33 @@ def agreement( """ labels, annotators = get_labels_and_annotators(df) - datetimes1 = list( - df[(df["annotator"] == annotators[0]) & (df["annotation"] == labels[0])][ - "start_datetime" - ], - ) - datetimes2 = list( - df[(df["annotator"] == annotators[1]) & (df["annotation"] == labels[1])][ - "start_datetime" - ], - ) + datetimes = [ + list( + df[ + (df["annotator"] == annotators[i]) & (df["annotation"] == labels[i]) + ]["start_datetime"], + ) + for i in range(2) + ] # scatter plot n_annot_max = bin_size.total_seconds() / df["end_time"].iloc[0] - start = df["start_datetime"].min() - stop = df["start_datetime"].max() - freq = ( bin_size if isinstance(bin_size, Timedelta) else str(bin_size.n) + bin_size.name ) bins = date_range( - start=start.floor(bin_size), - end=stop.ceil(bin_size), + start=df["start_datetime"].min().floor(bin_size), + end=df["start_datetime"].max().ceil(bin_size), freq=freq, ) - hist1, _ = histogram(datetimes1, bins=bins) - hist2, _ = histogram(datetimes2, bins=bins) - df_hist = ( DataFrame( { - annotators[0]: hist1, - annotators[1]: hist2, + annotators[0]: histogram(datetimes[0], bins=bins)[0], + annotators[1]: histogram(datetimes[1], bins=bins)[0], }, ) / n_annot_max @@ -557,8 +557,8 @@ def timeline( Matplotlib axes object where the scatterplot and regression line will be drawn. **kwargs: Additional keyword arguments depending on the mode. - color: str | list[str] - Color or list of colors for the histogram bars. - If not provided, default colors will be used. + Colour or list of colours for the histogram bars. + If not provided, default colours will be used. """ color = kwargs.get("color") @@ -606,12 +606,15 @@ def get_bin_size_str(bin_size: Timedelta | BaseOffset) -> str: return str(bin_size.n) + bin_size.freqstr -def set_y_axis_to_percentage( - ax: plt.Axes, -) -> None: +def set_y_axis_to_percentage(ax: plt.Axes, max_val: float) -> None: """Set y-axis to percentage.""" - ax.yaxis.set_major_formatter(PercentFormatter(xmax=1.0)) - ax.set_yticks(np.arange(0, 1.02, 0.2)) + ax.yaxis.set_major_formatter( + plt.FuncFormatter(lambda y, _: f"{(y / max_val) * 100:.0f}%"), + ) + + current_label = ax.get_ylabel() + if current_label and "%" not in current_label: + ax.set_ylabel(f"{current_label} (%)") def set_dynamic_ylim(ax: plt.Axes, @@ -639,9 +642,8 @@ def set_plot_title(ax: plt.Axes, annotators: list[str], labels: list[str]) -> No def shade_no_effort( ax: plt.Axes, - bin_starts: Index, observed: RecordingPeriod, - bar_width: Timedelta, + legend: bool, ) -> None: """Shade areas of the plot where no observation effort was made. @@ -649,31 +651,96 @@ def shade_no_effort( ---------- ax : plt.Axes The axes on which to draw the shaded regions. - bin_starts : Index - A datetime index representing the start times of each bin. observed : RecordingPeriod A Series with observation counts or flags, indexed by datetime. Should be aligned or re-indexable to `bin_starts`. - bar_width : Timedelta - Width of each time bin. Used to compute the span of the shaded areas. - + legend : bool + Wether to add the legend entry for the shaded regions. """ + # Convert effort IntervalIndex → DatetimeIndex (bin starts) + effort_by_start = Series( + observed.counts.values, + index=[i.left for i in observed.counts.index], + ) + + bar_width = effort_by_start.index[1] - effort_by_start.index[0] width_days = bar_width.total_seconds() / 86400 - no_effort_bins = bin_starts[observed.counts.reindex(bin_starts) == 0] - for ts in no_effort_bins: - start = mdates.date2num(ts) - ax.axvspan(start, start + width_days, color="grey", alpha=0.08, zorder=1.5) - x_min, x_max = ax.get_xlim() - data_min = mdates.date2num(bin_starts[0]) - data_max = mdates.date2num(bin_starts[-1]) + width_days - - if x_min < data_min: - ax.axvspan(x_min, data_min, color="grey", alpha=0.08, zorder=1.5) - if x_max > data_max: - ax.axvspan(data_max, x_max, color="grey", alpha=0.08, zorder=1.5) - ax.set_xlim(x_min, x_max) + max_effort = bar_width / observed.timebin_origin + effort_fraction = effort_by_start / max_effort + + first_elem = Series([0], index=[effort_fraction.index[0] - bar_width]) + last_elem = Series([0], index=[effort_fraction.index[-1] + bar_width]) + effort_fraction = concat([first_elem, effort_fraction, last_elem]) + + no_effort = effort_fraction[effort_fraction == 0] + partial_effort = effort_fraction[(effort_fraction > 0) & (effort_fraction < 1)] + + # Get legend handle + handles1, labels1 = ax.get_legend_handles_labels() + + _draw_effort_spans( + ax=ax, + effort_index=partial_effort.index, + width_days=width_days, + facecolor="0.65", + alpha=0.1, + label="partial data", + ) + + _draw_effort_spans( + ax=ax, + effort_index=no_effort.index, + width_days=width_days, + facecolor="0.45", + alpha=0.15, + label="no data", + ) + + # Add effort legend to current plot legend + handles_effort = [] + if len(partial_effort) > 0: + handles_effort.append( + Patch(facecolor="0.65", alpha=0.1, label="partial data"), + ) + if len(no_effort) > 0: + handles_effort.append( + Patch(facecolor="0.45", alpha=0.15, label="no data"), + ) + if handles_effort and legend: + labels_effort = [h.get_label() for h in handles_effort] + handles = handles1 + handles_effort + labels = labels1 + labels_effort + ax.legend( + handles, + labels, + bbox_to_anchor=(1.01, 1), + loc="upper left", + ) + + +def _draw_effort_spans( + ax: plt.Axes, + effort_index: DatetimeIndex, + width_days: float, + *, + facecolor: str, + alpha: float, + label: str, +) -> None: + """Draw vertical lines for effort plot.""" + for ts in effort_index: + start = mdates.date2num(ts) + ax.axvspan( + start, + start + width_days, + facecolor=facecolor, + alpha=alpha, + linewidth=0, + zorder=1, + label=label, + ) def add_sunrise_sunset(ax: Axes, lat: float, lon: float, tz: tzinfo) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index e03bf43..a6299e3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ import yaml from osekit.utils.timestamp_utils import strftime_osmose_format from pandas import DataFrame, read_csv +from pandas.tseries import frequencies SAMPLE = """dataset,filename,start_time,end_time,start_frequency,end_frequency,annotation,annotator,start_datetime,end_datetime,type,score sample_dataset,2025_01_25_06_20_00,0.0,10.0,0.0,72000.0,lbl2,ann2,2025-01-25T06:20:00.000+00:00,2025-01-25T06:20:10.000+00:00,WEAK,0.11 @@ -122,8 +123,6 @@ """ - - STATUS = """dataset,filename,ann1,ann2,ann3,ann4,ann5,ann6 sample_dataset,2025_01_25_06_20_00,FINISHED,FINISHED,FINISHED,FINISHED,FINISHED,FINISHED sample_dataset,2025_01_25_06_20_10,FINISHED,FINISHED,FINISHED,FINISHED,FINISHED,FINISHED @@ -134,6 +133,14 @@ sample_dataset,2025_01_26_06_20_20,FINISHED,FINISHED,FINISHED,FINISHED,FINISHED,FINISHED """ +# --------------------------------------------------------------------------- +# Fake recording planning CSV used for tests +# --------------------------------------------------------------------------- +RECORDING_PLANNING_CSV = """start_recording,end_recording,start_deployment,end_deployment +2024-01-01 00:00:00+0000,2024-04-09 02:00:00+0000,2024-01-02 00:00:00+0000,2024-04-30 02:00:00+0000 +2024-04-30 01:00:00+0000,2024-07-14 06:00:00+0000,2024-04-30 02:00:00+0000,2024-07-06 14:00:00+0000 +""" + @pytest.fixture def sample_df() -> DataFrame: @@ -228,3 +235,21 @@ def create_file(path: Path, size: int = 2048): create_file(nested / "file4.wav") (tmp_path / "ignore.txt").write_text("not audio") return tmp_path + + +@pytest.fixture +def recording_planning_csv(tmp_path) -> Path: + """Create a temporary CSV file simulating a recording planning.""" + path = tmp_path / "recording_planning.csv" + path.write_text(RECORDING_PLANNING_CSV) + return path + + +@pytest.fixture +def recording_planning_config(recording_planning_csv): + """Minimal config object compatible with RecordingPeriod.from_path.""" + class RecordingPlanningConfig: + timestamp_file: Path = recording_planning_csv + timebin_origin = frequencies.to_offset("1min") + + return RecordingPlanningConfig() diff --git a/tests/test_DataAplose.py b/tests/test_DataAplose.py index 5ad1b04..9b9516c 100644 --- a/tests/test_DataAplose.py +++ b/tests/test_DataAplose.py @@ -19,6 +19,7 @@ def test_data_aplose_init(sample_df: DataFrame) -> None: assert data.begin == sample_df["start_datetime"].min() assert data.end == sample_df["end_datetime"].max() + def test_filter_df_single_pair(sample_df: DataFrame) -> None: data = DataAplose(sample_df) filtered_data = data.filter_df(annotator="ann1", label="lbl1") @@ -30,17 +31,19 @@ def test_filter_df_single_pair(sample_df: DataFrame) -> None: ].reset_index(drop=True) assert filtered_data.equals(expected) + def test_change_tz(sample_df: DataFrame) -> None: data = DataAplose(sample_df) - new_tz = 'Etc/GMT-7' + new_tz = "Etc/GMT-7" data.change_tz(new_tz) - start_dt = data.df['start_datetime'] - end_dt = data.df['end_datetime'] + start_dt = data.df["start_datetime"] + end_dt = data.df["end_datetime"] assert all(ts.tz.zone == new_tz for ts in start_dt), f"The detection start timestamps have to be in {new_tz} timezone" assert all(ts.tz.zone == new_tz for ts in end_dt), f"The detection end timestamps have to be in {new_tz} timezone" assert data.begin.tz.zone == new_tz, f"The begin value of the DataAplose has to be in {new_tz} timezone" assert data.end.tz.zone == new_tz, f"The end value of the DataAplose has to be in {new_tz} timezone" + def test_filter_df_multiple_pairs(sample_df: DataFrame) -> None: data = DataAplose(sample_df) filtered_data = data.filter_df(annotator=["ann1", "ann2"], label=["lbl1", "lbl2"]) diff --git a/tests/test_core_utils.py b/tests/test_core_utils.py index a1a3d73..e72e482 100644 --- a/tests/test_core_utils.py +++ b/tests/test_core_utils.py @@ -8,6 +8,8 @@ from post_processing.dataclass.data_aplose import DataAplose from post_processing.utils.core_utils import ( + add_recording_period, + add_season_period, add_weak_detection, get_coordinates, get_count, @@ -15,13 +17,11 @@ get_season, get_sun_times, get_time_range_and_bin_size, + json2df, localize_timestamps, round_begin_end_timestamps, - timedelta_to_str, - add_season_period, - add_recording_period, set_bar_height, - json2df, + timedelta_to_str, ) @@ -409,10 +409,11 @@ def test_add_season_no_data() -> None: # %% add_recording_period + def test_add_recording_period_valid() -> None: fig, ax = plt.subplots() start = Timestamp("2025-01-01T00:00:00+00:00") - stop = Timestamp("2025-01-02T00:00:00+00:00") + stop = Timestamp("2025-01-02T00:00:00+00:00") ts = date_range(start=start, end=stop, freq="H", tz="UTC") values = list(range(len(ts))) @@ -423,7 +424,7 @@ def test_add_recording_period_valid() -> None: [ Timestamp("2025-01-01T00:00:00+00:00"), Timestamp("2025-01-02T00:00:00+00:00"), - ] + ], ], columns=["deployment_date", "recovery_date"], ) @@ -438,6 +439,7 @@ def test_add_recording_period_no_data() -> None: # %% set_bar_height + def test_set_bar_height_valid() -> None: fig, ax = plt.subplots() start = Timestamp("2025-01-01T00:00:00+00:00") @@ -457,6 +459,7 @@ def test_set_bar_height_no_data() -> None: # %% json2df + def test_json2df_valid(tmp_path): fake_json = { "deployment_date": "2025-01-01T00:00:00+00:00", @@ -474,9 +477,9 @@ def test_json2df_valid(tmp_path): [ Timestamp("2025-01-01T00:00:00+00:00"), Timestamp("2025-01-02T00:00:00+00:00"), - ] + ], ], columns=["deployment_date", "recovery_date"], ) - assert df.equals(expected) \ No newline at end of file + assert df.equals(expected) diff --git a/tests/test_filtering_utils.py b/tests/test_filtering_utils.py index 95fd987..3ec3760 100644 --- a/tests/test_filtering_utils.py +++ b/tests/test_filtering_utils.py @@ -77,7 +77,7 @@ def test_find_delimiter_unsupported_delimiter(tmp_path: Path) -> None: with pytest.raises( ValueError, - match=r"unsupported delimiter '&'" + match=r"unsupported delimiter '&'", ): find_delimiter(file) @@ -199,6 +199,7 @@ def test_filter_by_freq_valid(sample_df: DataFrame, f_min, f_max): if f_max is not None: assert (result["end_frequency"] <= f_max).all() + @pytest.mark.parametrize( "f_min, f_max, expected_msg", [ @@ -216,8 +217,6 @@ def test_filter_by_freq_valid(sample_df: DataFrame, f_min, f_max): ), ], ) - - def test_filter_by_freq_out_of_range(sample_df: DataFrame, f_min, f_max, expected_msg): with pytest.raises(ValueError, match=expected_msg): filter_by_freq(sample_df, f_min=f_min, f_max=f_max) @@ -331,7 +330,7 @@ def test_get_timezone_several(sample_df: DataFrame) -> None: } sample_df = concat( [sample_df, DataFrame([new_row])], - ignore_index=False + ignore_index=False, ) tz = get_timezone(sample_df) assert len(tz) == 2 @@ -340,6 +339,7 @@ def test_get_timezone_several(sample_df: DataFrame) -> None: # %% read DataFrame + def test_read_dataframe_comma_delimiter(tmp_path: Path) -> None: csv_file = tmp_path / "test.csv" csv_file.write_text( @@ -417,7 +417,7 @@ def test_no_timebin_several_tz(sample_df: DataFrame) -> None: } sample_df = concat( [sample_df, DataFrame([new_row])], - ignore_index=False + ignore_index=False, ) timestamp_wav = to_datetime(sample_df["filename"], format="%Y_%m_%d_%H_%M_%S").dt.tz_localize(pytz.UTC) @@ -429,7 +429,7 @@ def test_no_timebin_original_timebin(sample_df: DataFrame) -> None: tz = get_timezone(sample_df) timestamp_wav = to_datetime( sample_df["filename"], - format="%Y_%m_%d_%H_%M_%S" + format="%Y_%m_%d_%H_%M_%S", ).dt.tz_localize(tz) df_out = reshape_timebin( sample_df, @@ -520,7 +520,7 @@ def test_simple_reshape_hourly(sample_df: DataFrame) -> None: tz = get_timezone(sample_df) timestamp_wav = to_datetime( sample_df["filename"], - format="%Y_%m_%d_%H_%M_%S" + format="%Y_%m_%d_%H_%M_%S", ).dt.tz_localize(tz) df_out = reshape_timebin( sample_df, @@ -538,7 +538,7 @@ def test_reshape_daily_multiple_bins(sample_df: DataFrame) -> None: tz = get_timezone(sample_df) timestamp_wav = to_datetime( sample_df["filename"], - format="%Y_%m_%d_%H_%M_%S" + format="%Y_%m_%d_%H_%M_%S", ).dt.tz_localize(tz) df_out = reshape_timebin(sample_df, timestamp_audio=timestamp_wav, timebin_new=Timedelta(days=1)) assert not df_out.empty @@ -555,7 +555,7 @@ def test_with_manual_timestamps_vector(sample_df: DataFrame) -> None: df_out = reshape_timebin( sample_df, timestamp_audio=timestamp_wav, - timebin_new=Timedelta(hours=1) + timebin_new=Timedelta(hours=1), ) assert not df_out.empty @@ -589,6 +589,7 @@ def test_ensure_no_invalid_with_elements() -> None: assert "bar" in str(exc_info.value) assert "columns" in str(exc_info.value) + def test_ensure_no_invalid_single_element() -> None: invalid_items = ["baz"] with pytest.raises(ValueError) as exc_info: @@ -598,6 +599,7 @@ def test_ensure_no_invalid_single_element() -> None: # %% intersection / union + def test_intersection(sample_df) -> None: df_result = intersection_or_union(sample_df[sample_df["annotator"].isin(["ann1", "ann2"])], user_sel="intersection") @@ -628,7 +630,7 @@ def test_not_enough_annotators_raises() -> None: "annotation": ["cat"], "start_datetime": to_datetime(["2025-01-01 10:00"]), "end_datetime": to_datetime(["2025-01-01 10:01"]), - "annotator": ["A"] + "annotator": ["A"], }) with pytest.raises(ValueError, match="Not enough annotators detected"): - intersection_or_union(df_single_annotator, user_sel="intersection") \ No newline at end of file + intersection_or_union(df_single_annotator, user_sel="intersection") diff --git a/tests/test_glider_utils.py b/tests/test_glider_utils.py index 12d83df..d0247c5 100644 --- a/tests/test_glider_utils.py +++ b/tests/test_glider_utils.py @@ -56,7 +56,7 @@ def test_get_position_from_timestamp(nav_df: DataFrame) -> None: def test_plot_detections_with_nav_data( df_detections: DataFrame, - nav_df: DataFrame + nav_df: DataFrame, ) -> None: plot_detections_with_nav_data( df=df_detections, diff --git a/tests/test_metric_utils.py b/tests/test_metric_utils.py index 34ce769..35717e7 100644 --- a/tests/test_metric_utils.py +++ b/tests/test_metric_utils.py @@ -3,6 +3,7 @@ from post_processing.utils.metrics_utils import detection_perf + def test_detection_perf(sample_df: DataFrame) -> None: try: detection_perf(df=sample_df[sample_df["annotator"].isin(["ann1", "ann4"])], ref=("ann1", "lbl1")) @@ -12,4 +13,4 @@ def test_detection_perf(sample_df: DataFrame) -> None: def test_detection_perf_one_annotator(sample_df: DataFrame) -> None: with pytest.raises(ValueError, match="Two annotators needed"): - detection_perf(df=sample_df[sample_df["annotator"] == "ann1"], ref=("ann1", "lbl1")) \ No newline at end of file + detection_perf(df=sample_df[sample_df["annotator"] == "ann1"], ref=("ann1", "lbl1")) diff --git a/tests/test_plot_utils.py b/tests/test_plot_utils.py index d7392cf..4306c38 100644 --- a/tests/test_plot_utils.py +++ b/tests/test_plot_utils.py @@ -1,13 +1,11 @@ import matplotlib.pyplot as plt import pytest -from matplotlib.ticker import PercentFormatter -from numpy import arange, testing from post_processing.utils.plot_utils import ( - overview, _wrap_xtick_labels, - set_y_axis_to_percentage, get_legend, + overview, + set_y_axis_to_percentage, ) @@ -57,16 +55,15 @@ def test_wrap_xtick_labels_no_spaces(): assert wrapped_labels[0] == expected -def test_y_axis_formatter_and_ticks(): +def test_set_y_axis_to_percentage(): fig, ax = plt.subplots() - - set_y_axis_to_percentage(ax) - - assert isinstance(ax.yaxis.get_major_formatter(), PercentFormatter) - assert ax.yaxis.get_major_formatter().xmax == 1.0 - - expected_ticks = arange(0, 1.02, 0.2) - testing.assert_allclose(ax.get_yticks(), expected_ticks) + ax.set_ylabel("Accuracy") + set_y_axis_to_percentage(ax, max_val=200) + formatter = ax.yaxis.get_major_formatter() + assert formatter(100, None) == "50%" + assert formatter(200, None) == "100%" + assert ax.get_ylabel() == "Accuracy (%)" + plt.close(fig) def test_single_annotator_multiple_labels(): @@ -103,4 +100,4 @@ def test_lists_and_strings_combined(): labels = ["Label1", "Label2"] result = get_legend(annotators, labels) expected = ["Alice\nLabel1", "Bob\nLabel2"] - assert result == expected \ No newline at end of file + assert result == expected diff --git a/tests/test_recording_period.py b/tests/test_recording_period.py new file mode 100644 index 0000000..064c6b5 --- /dev/null +++ b/tests/test_recording_period.py @@ -0,0 +1,80 @@ +from pandas import Timedelta, read_csv, to_datetime + +from post_processing.dataclass.detection_filter import DetectionFilter +from post_processing.dataclass.recording_period import RecordingPeriod + + +def test_recording_period_with_gaps(recording_planning_config: DetectionFilter) -> None: + """RecordingPeriod correctly represents long gaps with no recording effort. + + The planning contains two recording blocks separated by ~3 weeks with no + recording at all. Weekly aggregation must reflect: + - weeks with full effort, + - weeks with partial effort, + - weeks with zero effort. + """ + histo_x_bin_size = Timedelta("7D") + recording_period = RecordingPeriod.from_path( + config=recording_planning_config, + bin_size=histo_x_bin_size, + ) + + counts = recording_period.counts + origin = recording_planning_config.timebin_origin + nb_timebin_origin_per_histo_x_bin_size = int(histo_x_bin_size / origin) + + # Computes effective recording intervals from recording planning csv + df_planning = read_csv( + recording_planning_config.timestamp_file, + parse_dates=[ + "start_recording", + "end_recording", + "start_deployment", + "end_deployment", + ], + ) + for col in [ + "start_recording", + "end_recording", + "start_deployment", + "end_deployment", + ]: + df_planning[col] = ( + to_datetime(df_planning[col], utc=True) + .dt.tz_convert(None) + ) + + df_planning["start"] = df_planning[ + ["start_recording", "start_deployment"] + ].max(axis=1) + df_planning["end"] = df_planning[ + ["end_recording", "end_deployment"] + ].min(axis=1) + + planning = df_planning.loc[df_planning["start"] < df_planning["end"]] + # ------------------------------------------------------------------ + # Structural checks + # ------------------------------------------------------------------ + assert not counts.empty + assert counts.index.is_interval() + assert counts.min() >= 0 + assert counts.max() <= nb_timebin_origin_per_histo_x_bin_size + + # ------------------------------------------------------------------ + # Find overlap (number of timebin_origin) within each effective recording period + # ------------------------------------------------------------------ + for interval in counts.index: + bin_start = interval.left + bin_end = interval.right + + # Compute overlap with all recording intervals + overlap_start = planning["start"].clip(lower=bin_start, upper=bin_end) + overlap_end = planning["end"].clip(lower=bin_start, upper=bin_end) + + overlap = (overlap_end - overlap_start).clip(lower=Timedelta(0)) + expected_minutes = int(overlap.sum() / recording_planning_config.timebin_origin) + + assert counts.loc[interval] == expected_minutes, ( + f"Mismatch for bin {interval}: " + f"expected {expected_minutes}, got {counts.loc[interval]}" + )