diff --git a/bats_ai/core/admin/pulse_metadata.py b/bats_ai/core/admin/pulse_metadata.py index d1d07e25..b6ba8922 100644 --- a/bats_ai/core/admin/pulse_metadata.py +++ b/bats_ai/core/admin/pulse_metadata.py @@ -5,5 +5,5 @@ @admin.register(PulseMetadata) class PulseMetadataAdmin(admin.ModelAdmin): - list_display = ('recording', 'index', 'bounding_box') + list_display = ('recording', 'index', 'bounding_box', 'curve', 'char_freq', 'knee', 'heel') list_select_related = True diff --git a/bats_ai/core/migrations/0029_pulsemetadata_char_freq_pulsemetadata_curve_and_more.py b/bats_ai/core/migrations/0029_pulsemetadata_char_freq_pulsemetadata_curve_and_more.py new file mode 100644 index 00000000..33f0ee62 --- /dev/null +++ b/bats_ai/core/migrations/0029_pulsemetadata_char_freq_pulsemetadata_curve_and_more.py @@ -0,0 +1,35 @@ +# Generated by Django 4.2.23 on 2026-02-03 19:43 + +import django.contrib.gis.db.models.fields +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ('core', '0028_alter_spectrogramimage_type_pulsemetadata'), + ] + + operations = [ + migrations.AddField( + model_name='pulsemetadata', + name='char_freq', + field=django.contrib.gis.db.models.fields.PointField(blank=True, null=True, srid=4326), + ), + migrations.AddField( + model_name='pulsemetadata', + name='curve', + field=django.contrib.gis.db.models.fields.LineStringField( + blank=True, null=True, srid=4326 + ), + ), + migrations.AddField( + model_name='pulsemetadata', + name='heel', + field=django.contrib.gis.db.models.fields.PointField(blank=True, null=True, srid=4326), + ), + migrations.AddField( + model_name='pulsemetadata', + name='knee', + field=django.contrib.gis.db.models.fields.PointField(blank=True, null=True, srid=4326), + ), + ] diff --git a/bats_ai/core/models/pulse_metadata.py b/bats_ai/core/models/pulse_metadata.py index b1dcc8bb..b3622fc6 100644 --- a/bats_ai/core/models/pulse_metadata.py +++ b/bats_ai/core/models/pulse_metadata.py @@ -8,4 +8,7 @@ class PulseMetadata(models.Model): index = models.IntegerField(null=False, blank=False) bounding_box = models.PolygonField(null=False, blank=False) contours = models.JSONField(null=True, blank=True) - # TODO: Add in metadata from batbot + curve = models.LineStringField(null=True, blank=True) + char_freq = models.PointField(null=True, blank=True) + knee = models.PointField(null=True, blank=True) + heel = models.PointField(null=True, blank=True) diff --git a/bats_ai/core/tasks/nabat/tasks.py b/bats_ai/core/tasks/nabat/tasks.py index 45c7bf90..68b24e4f 100644 --- a/bats_ai/core/tasks/nabat/tasks.py +++ b/bats_ai/core/tasks/nabat/tasks.py @@ -2,9 +2,10 @@ from pathlib import Path import tempfile +from django.contrib.gis.geos import LineString, Point, Polygon import requests -from bats_ai.core.models import ProcessingTask +from bats_ai.core.models import ProcessingTask, PulseMetadata from bats_ai.core.models.nabat import NABatRecording from bats_ai.utils.spectrogram_utils import ( generate_nabat_compressed_spectrogram, @@ -54,7 +55,51 @@ def generate_spectrograms( compressed = results['compressed'] - generate_nabat_compressed_spectrogram(nabat_recording, spectrogram, compressed) + compressed_obj = generate_nabat_compressed_spectrogram( + nabat_recording, spectrogram, compressed + ) + segment_index_map = {} + for segment in compressed['contours']['segments']: + pulse_metadata_obj, _ = PulseMetadata.objects.get_or_create( + recording=compressed_obj.recording, + index=segment['segment_index'], + defaults={ + 'contours': segment['contours'], + 'bounding_box': Polygon( + ( + (segment['start_ms'], segment['freq_max']), + (segment['stop_ms'], segment['freq_max']), + (segment['stop_ms'], segment['freq_min']), + (segment['start_ms'], segment['freq_min']), + (segment['start_ms'], segment['freq_max']), + ) + ), + }, + ) + segment_index_map[segment['segment_index']] = pulse_metadata_obj + for segment in compressed['segments']: + if segment['segment_index'] not in segment_index_map: + PulseMetadata.objects.update_or_create( + recording=compressed_obj.recording, + index=segment['segment_index'], + defaults={ + 'curve': LineString([Point(x[1], x[0]) for x in segment['curve_hz_ms']]), + 'char_freq': Point(segment['char_freq_ms'], segment['char_freq_hz']), + 'knee': Point(segment['knee_ms'], segment['knee_hz']), + 'heel': Point(segment['heel_ms'], segment['heel_hz']), + }, + ) + else: + pulse_metadata_obj = segment_index_map[segment['segment_index']] + pulse_metadata_obj.curve = LineString( + [Point(x[1], x[0]) for x in segment['curve_hz_ms']] + ) + pulse_metadata_obj.char_freq = Point( + segment['char_freq_ms'], segment['char_freq_hz'] + ) + pulse_metadata_obj.knee = Point(segment['knee_ms'], segment['knee_hz']) + pulse_metadata_obj.heel = Point(segment['heel_ms'], segment['heel_hz']) + pulse_metadata_obj.save() processing_task.status = ProcessingTask.Status.COMPLETE processing_task.save() diff --git a/bats_ai/core/tasks/tasks.py b/bats_ai/core/tasks/tasks.py index 271ab1c7..4f91f1c4 100644 --- a/bats_ai/core/tasks/tasks.py +++ b/bats_ai/core/tasks/tasks.py @@ -4,7 +4,7 @@ import tempfile from django.contrib.contenttypes.models import ContentType -from django.contrib.gis.geos import Polygon +from django.contrib.gis.geos import LineString, Point, Polygon from django.core.files import File from bats_ai.celery import app @@ -101,8 +101,9 @@ def recording_compute_spectrogram(recording_id: int): ) # Create SpectrogramContour objects for each segment - for segment in results['segments']['segments']: - PulseMetadata.objects.update_or_create( + segment_index_map = {} + for segment in compressed['contours']['segments']: + pulse_metadata_obj, _ = PulseMetadata.objects.update_or_create( recording=compressed_obj.recording, index=segment['segment_index'], defaults={ @@ -118,5 +119,29 @@ def recording_compute_spectrogram(recording_id: int): ), }, ) + segment_index_map[segment['segment_index']] = pulse_metadata_obj + for segment in compressed['segments']: + if segment['segment_index'] not in segment_index_map: + PulseMetadata.objects.update_or_create( + recording=compressed_obj.recording, + index=segment['segment_index'], + defaults={ + 'curve': LineString([Point(x[1], x[0]) for x in segment['curve_hz_ms']]), + 'char_freq': Point(segment['char_freq_ms'], segment['char_freq_hz']), + 'knee': Point(segment['knee_ms'], segment['knee_hz']), + 'heel': Point(segment['heel_ms'], segment['heel_hz']), + }, + ) + else: + pulse_metadata_obj = segment_index_map[segment['segment_index']] + pulse_metadata_obj.curve = LineString( + [Point(x[1], x[0]) for x in segment['curve_hz_ms']] + ) + pulse_metadata_obj.char_freq = Point( + segment['char_freq_ms'], segment['char_freq_hz'] + ) + pulse_metadata_obj.knee = Point(segment['knee_ms'], segment['knee_hz']) + pulse_metadata_obj.heel = Point(segment['heel_ms'], segment['heel_hz']) + pulse_metadata_obj.save() return {'spectrogram_id': spectrogram.id, 'compressed_id': compressed_obj.id} diff --git a/bats_ai/core/utils/batbot_metadata.py b/bats_ai/core/utils/batbot_metadata.py index db0cb473..1caa9bb2 100644 --- a/bats_ai/core/utils/batbot_metadata.py +++ b/bats_ai/core/utils/batbot_metadata.py @@ -1,5 +1,6 @@ from contextlib import contextmanager import json +import logging import os from pathlib import Path from typing import Any, TypedDict @@ -15,6 +16,8 @@ from .contour_utils import process_spectrogram_assets_for_contours +logger = logging.getLogger(__name__) + class SpectrogramMetadata(BaseModel): """Metadata about the spectrogram.""" @@ -261,6 +264,17 @@ class SpectrogramContourSegment(TypedDict): stop_ms: float +class BatBotMetadataCurve(TypedDict): + segment_index: int + curve_hz_ms: list[float] + char_freq_ms: float + char_freq_hz: float + knee_ms: float + knee_hz: float + heel_ms: float + heel_hz: float + + class SpectrogramContours(TypedDict): segments: list[SpectrogramContourSegment] total_segments: int @@ -272,7 +286,7 @@ class SpectrogramAssets(TypedDict): freq_max: int normal: SpectrogramAssetResult compressed: SpectrogramCompressedAssetResult - segments: SpectrogramContours | None + contours: SpectrogramContours | None @contextmanager @@ -285,6 +299,25 @@ def working_directory(path): os.chdir(previous) +def convert_to_segment_data( + metadata: BatbotMetadata, +) -> list[BatBotMetadataCurve]: + segment_data: list[BatBotMetadataCurve] = [] + for index, segment in enumerate(metadata.segments): + segment_data_item: BatBotMetadataCurve = { + 'segment_index': index, + 'curve_hz_ms': segment.curve_hz_ms, + 'char_freq_ms': segment.fc_ms, + 'char_freq_hz': segment.fc_hz, + 'knee_ms': segment.hi_fc_knee_ms, + 'knee_hz': segment.hi_fc_knee_hz, + 'heel_ms': segment.lo_fc_heel_ms, + 'heel_hz': segment.lo_fc_heel_hz, + } + segment_data.append(segment_data_item) + return segment_data + + def generate_spectrogram_assets(recording_path: str, output_folder: str): batbot.pipeline(recording_path, output_folder=output_folder) # There should be a .metadata.json file in the output_base directory by replacing extentions @@ -300,6 +333,7 @@ def generate_spectrogram_assets(recording_path: str, output_folder: str): metadata.frequencies.max_hz compressed_metadata = convert_to_compressed_spectrogram_data(metadata) + segment_curve_data = convert_to_segment_data(metadata) result: SpectrogramAssets = { 'duration': metadata.duration_ms, 'freq_min': metadata.frequencies.min_hz, @@ -317,10 +351,11 @@ def generate_spectrogram_assets(recording_path: str, output_folder: str): 'widths': compressed_metadata.widths, 'starts': compressed_metadata.starts, 'stops': compressed_metadata.stops, + 'segments': segment_curve_data, }, } - segments_data = process_spectrogram_assets_for_contours(result) - result['segments'] = segments_data + contour_segments_data = process_spectrogram_assets_for_contours(result) + result['compressed']['contours'] = contour_segments_data return result diff --git a/bats_ai/core/views/recording.py b/bats_ai/core/views/recording.py index 300d6fb2..65351020 100644 --- a/bats_ai/core/views/recording.py +++ b/bats_ai/core/views/recording.py @@ -129,7 +129,7 @@ class UpdateAnnotationsSchema(Schema): id: int | None -class PulseMetadataSchema(Schema): +class PulseContourSchema(Schema): id: int | None index: int bounding_box: Any @@ -145,6 +145,36 @@ def from_orm(cls, obj: PulseMetadata): ) +class PulseMetadataSchema(Schema): + id: int | None + index: int + curve: list[list[float]] | None = None # list of [time, frequency] + char_freq: list[float] | None = None # point [time, frequency] + knee: list[float] | None = None # point [time, frequency] + heel: list[float] | None = None # point [time, frequency] + + @classmethod + def from_orm(cls, obj: PulseMetadata): + def point_to_list(pt): + if pt is None: + return None + return [pt.x, pt.y] + + def linestring_to_list(ls): + if ls is None: + return None + return [[c[0], c[1]] for c in ls.coords] + + return cls( + id=obj.id, + index=obj.index, + curve=linestring_to_list(obj.curve), + char_freq=point_to_list(obj.char_freq), + knee=point_to_list(obj.knee), + heel=point_to_list(obj.heel), + ) + + @router.post('/') def create_recording( request: HttpRequest, @@ -559,6 +589,25 @@ def get_annotations(request: HttpRequest, id: int): return {'error': 'Recording not found'} +@router.get('/{id}/pulse_contours') +def get_pulse_contours(request: HttpRequest, id: int): + try: + recording = Recording.objects.get(pk=id) + if recording.owner == request.user or recording.public: + computed_pulse_annotation_qs = PulseMetadata.objects.filter( + recording=recording + ).order_by('index') + return [ + PulseContourSchema.from_orm(pulse) for pulse in computed_pulse_annotation_qs.all() + ] + else: + return { + 'error': 'Permission denied. You do not own this recording, and it is not public.' + } + except Recording.DoesNotExist: + return {'error': 'Recording not found'} + + @router.get('/{id}/pulse_data') def get_pulse_data(request: HttpRequest, id: int): try: diff --git a/client/src/api/api.ts b/client/src/api/api.ts index e8612c94..b6505957 100644 --- a/client/src/api/api.ts +++ b/client/src/api/api.ts @@ -574,14 +574,28 @@ export interface Contour { index: number; } -export interface ComputedPulseAnnotation { +export interface ComputedPulseContour { id: number; index: number; contours: Contour[]; } -async function getComputedPulseAnnotations(recordingId: number) { - const result = await axiosInstance.get(`/recording/${recordingId}/pulse_data`); +async function getComputedPulseContour(recordingId: number) { + const result = await axiosInstance.get(`/recording/${recordingId}/pulse_contours`); + return result.data; +} + +export interface PulseMetadata { + id: number; + index: number; + curve: number[][] | null; // list of [time, frequency] + char_freq: number[] | null; // point [time, frequency] + knee: number[] | null; // point [time, frequency] + heel: number[] | null; // point [time, frequency] +} + +async function getPulseMetadata(recordingId: number) { + const result = await axiosInstance.get(`/recording/${recordingId}/pulse_data`); return result.data; } @@ -622,7 +636,8 @@ export { getFileAnnotationDetails, getExportStatus, getRecordingTags, - getComputedPulseAnnotations, + getComputedPulseContour, + getPulseMetadata, getCurrentUser, getVettingDetailsForUser, createOrUpdateVettingDetailsForUser, diff --git a/client/src/components/PulseMetadataButton.vue b/client/src/components/PulseMetadataButton.vue new file mode 100644 index 00000000..4dcf61f3 --- /dev/null +++ b/client/src/components/PulseMetadataButton.vue @@ -0,0 +1,228 @@ + + + + + diff --git a/client/src/components/SpectrogramImageContentMenu.vue b/client/src/components/SpectrogramImageContentMenu.vue index e60bc5d9..8e07b557 100644 --- a/client/src/components/SpectrogramImageContentMenu.vue +++ b/client/src/components/SpectrogramImageContentMenu.vue @@ -1,5 +1,5 @@