From 374a6ca9e92222903d35e4174e5524e08221ddab Mon Sep 17 00:00:00 2001 From: Patrick Loeber <98830383+ploeber@users.noreply.github.com> Date: Thu, 19 Dec 2024 11:49:43 +0100 Subject: [PATCH 1/6] feat: Add HTTP response info - Add status code to all Error classes - Store the latest HTTP response in a new instance variable `Client.last_response` --- assemblyai/api.py | 56 +++++++++++++++++++++++----------- assemblyai/client.py | 16 ++++++++++ assemblyai/transcriber.py | 25 ++++++++------- assemblyai/types.py | 22 +++++++++++++ tests/unit/test_transcriber.py | 9 ++++-- 5 files changed, 97 insertions(+), 31 deletions(-) diff --git a/assemblyai/api.py b/assemblyai/api.py index 6c16f1d..b2f666a 100644 --- a/assemblyai/api.py +++ b/assemblyai/api.py @@ -43,7 +43,8 @@ def create_transcript( ) if response.status_code != httpx.codes.OK: raise types.TranscriptError( - f"failed to transcribe url {request.audio_url}: {_get_error_message(response)}" + f"failed to transcribe url {request.audio_url}: {_get_error_message(response)}", + response.status_code, ) return types.TranscriptResponse.parse_obj(response.json()) @@ -60,6 +61,7 @@ def get_transcript( if response.status_code != httpx.codes.OK: raise types.TranscriptError( f"failed to retrieve transcript {transcript_id}: {_get_error_message(response)}", + response.status_code, ) return types.TranscriptResponse.parse_obj(response.json()) @@ -76,6 +78,7 @@ def delete_transcript( if response.status_code != httpx.codes.OK: raise types.TranscriptError( f"failed to delete transcript {transcript_id}: {_get_error_message(response)}", + response.status_code, ) return types.TranscriptResponse.parse_obj(response.json()) @@ -102,7 +105,8 @@ def upload_file( if response.status_code != httpx.codes.OK: raise types.TranscriptError( - f"Failed to upload audio file: {_get_error_message(response)}" + f"Failed to upload audio file: {_get_error_message(response)}", + response.status_code, ) return response.json()["upload_url"] @@ -127,7 +131,8 @@ def export_subtitles_srt( if response.status_code != httpx.codes.OK: raise types.TranscriptError( - f"failed to export SRT for transcript {transcript_id}: {_get_error_message(response)}" + f"failed to export SRT for transcript {transcript_id}: {_get_error_message(response)}", + response.status_code, ) return response.text @@ -152,7 +157,8 @@ def export_subtitles_vtt( if response.status_code != httpx.codes.OK: raise types.TranscriptError( - f"failed to export VTT for transcript {transcript_id}: {_get_error_message(response)}" + f"failed to export VTT for transcript {transcript_id}: {_get_error_message(response)}", + response.status_code, ) return response.text @@ -174,7 +180,8 @@ def word_search( if response.status_code != httpx.codes.OK: raise types.TranscriptError( - f"failed to search words in transcript {transcript_id}: {_get_error_message(response)}" + f"failed to search words in transcript {transcript_id}: {_get_error_message(response)}", + response.status_code, ) return types.WordSearchMatchResponse.parse_obj(response.json()) @@ -199,17 +206,20 @@ def get_redacted_audio( if response.status_code == httpx.codes.ACCEPTED: raise types.RedactedAudioIncompleteError( - f"redacted audio for transcript {transcript_id} is not ready yet" + f"redacted audio for transcript {transcript_id} is not ready yet", + response.status_code, ) if response.status_code == httpx.codes.BAD_REQUEST: raise types.RedactedAudioExpiredError( - f"redacted audio for transcript {transcript_id} is no longer available" + f"redacted audio for transcript {transcript_id} is no longer available", + response.status_code, ) if response.status_code != httpx.codes.OK: raise types.TranscriptError( - f"failed to retrieve redacted audio for transcript {transcript_id}: {_get_error_message(response)}" + f"failed to retrieve redacted audio for transcript {transcript_id}: {_get_error_message(response)}", + response.status_code, ) return types.RedactedAudioResponse.parse_obj(response.json()) @@ -225,7 +235,8 @@ def get_sentences( if response.status_code != httpx.codes.OK: raise types.TranscriptError( - f"failed to retrieve sentences for transcript {transcript_id}: {_get_error_message(response)}" + f"failed to retrieve sentences for transcript {transcript_id}: {_get_error_message(response)}", + response.status_code, ) return types.SentencesResponse.parse_obj(response.json()) @@ -241,7 +252,8 @@ def get_paragraphs( if response.status_code != httpx.codes.OK: raise types.TranscriptError( - f"failed to retrieve paragraphs for transcript {transcript_id}: {_get_error_message(response)}" + f"failed to retrieve paragraphs for transcript {transcript_id}: {_get_error_message(response)}", + response.status_code, ) return types.ParagraphsResponse.parse_obj(response.json()) @@ -264,7 +276,8 @@ def list_transcripts( if response.status_code != httpx.codes.OK: raise types.AssemblyAIError( - f"failed to retrieve transcripts: {_get_error_message(response)}" + f"failed to retrieve transcripts: {_get_error_message(response)}", + response.status_code, ) return types.ListTranscriptResponse.parse_obj(response.json()) @@ -285,7 +298,8 @@ def lemur_question( if response.status_code != httpx.codes.OK: raise types.LemurError( - f"failed to call Lemur questions: {_get_error_message(response)}" + f"failed to call Lemur questions: {_get_error_message(response)}", + response.status_code, ) return types.LemurQuestionResponse.parse_obj(response.json()) @@ -306,7 +320,8 @@ def lemur_summarize( if response.status_code != httpx.codes.OK: raise types.LemurError( - f"failed to call Lemur summary: {_get_error_message(response)}" + f"failed to call Lemur summary: {_get_error_message(response)}", + response.status_code, ) return types.LemurSummaryResponse.parse_obj(response.json()) @@ -327,7 +342,8 @@ def lemur_action_items( if response.status_code != httpx.codes.OK: raise types.LemurError( - f"failed to call Lemur action items: {_get_error_message(response)}" + f"failed to call Lemur action items: {_get_error_message(response)}", + response.status_code, ) return types.LemurActionItemsResponse.parse_obj(response.json()) @@ -348,7 +364,8 @@ def lemur_task( if response.status_code != httpx.codes.OK: raise types.LemurError( - f"failed to call Lemur task: {_get_error_message(response)}" + f"failed to call Lemur task: {_get_error_message(response)}", + response.status_code, ) return types.LemurTaskResponse.parse_obj(response.json()) @@ -366,7 +383,8 @@ def lemur_purge_request_data( if response.status_code != httpx.codes.OK: raise types.LemurError( - f"Failed to purge LeMUR request data for provided request ID: {request.request_id}. Error: {_get_error_message(response)}" + f"Failed to purge LeMUR request data for provided request ID: {request.request_id}. Error: {_get_error_message(response)}", + response.status_code, ) return types.LemurPurgeResponse.parse_obj(response.json()) @@ -387,7 +405,8 @@ def lemur_get_response_data( if response.status_code != httpx.codes.OK: raise types.LemurError( - f"Failed to get LeMUR response data for provided request ID: {request_id}. Error: {_get_error_message(response)}" + f"Failed to get LeMUR response data for provided request ID: {request_id}. Error: {_get_error_message(response)}", + response.status_code, ) json_data = response.json() @@ -411,7 +430,8 @@ def create_temporary_token( if response.status_code != httpx.codes.OK: raise types.AssemblyAIError( - f"Failed to create temporary token: {_get_error_message(response)}" + f"Failed to create temporary token: {_get_error_message(response)}", + response.status_code, ) data = types.RealtimeCreateTemporaryTokenResponse.parse_obj(response.json()) diff --git a/assemblyai/client.py b/assemblyai/client.py index a7d4697..99653d1 100644 --- a/assemblyai/client.py +++ b/assemblyai/client.py @@ -43,12 +43,28 @@ def __init__( if self._settings.api_key: headers["authorization"] = self.settings.api_key + self._last_response: Optional[httpx.Response] = None + + def _store_response(response): + self._last_response = response + self._http_client = httpx.Client( base_url=self.settings.base_url, headers=headers, timeout=self.settings.http_timeout, + event_hooks={"response": [_store_response]}, ) + @property + def last_response(self) -> Optional[httpx.Response]: + """ + Get the last HTTP response, corresponding to the last request sent from this client. + + Returns: + The last HTTP response. + """ + return self._last_response + @property def settings(self) -> types.Settings: """ diff --git a/assemblyai/transcriber.py b/assemblyai/transcriber.py index c07e84c..753c253 100644 --- a/assemblyai/transcriber.py +++ b/assemblyai/transcriber.py @@ -175,7 +175,8 @@ def save_redacted_audio(self, filepath: str): with httpx.stream(method="GET", url=self.get_redacted_audio_url()) as response: if response.status_code not in (httpx.codes.OK, httpx.codes.NOT_MODIFIED): raise types.RedactedAudioUnavailableError( - f"Fetching redacted audio failed with status code {response.status_code}" + f"Fetching redacted audio failed with status code {response.status_code}", + response.status_code, ) with open(filepath, "wb") as f: for chunk in response.iter_bytes(): @@ -556,7 +557,9 @@ def add_transcript(self, transcript: Union[Transcript, str]) -> None: return self - def wait_for_completion(self, return_failures) -> Union[None, List[str]]: + def wait_for_completion( + self, return_failures + ) -> Union[None, List[types.AssemblyAIError]]: transcripts: List[Transcript] = [] failures: List[str] = [] @@ -572,7 +575,7 @@ def wait_for_completion(self, return_failures) -> Union[None, List[str]]: try: transcripts.append(future.result()) except types.TranscriptError as e: - failures.append(str(e)) + failures.append(e) self.transcripts = transcripts @@ -672,7 +675,7 @@ def add_transcript( def wait_for_completion( self, return_failures: Optional[bool] = False, - ) -> Union[Self, Tuple[Self, List[str]]]: + ) -> Union[Self, Tuple[Self, List[types.AssemblyAIError]]]: """ Polls each transcript within the `TranscriptGroup`. @@ -695,7 +698,7 @@ def wait_for_completion_async( return_failures: Optional[bool] = False, ) -> Union[ concurrent.futures.Future[Self], - concurrent.futures.Future[Tuple[Self, List[str]]], + concurrent.futures.Future[Tuple[Self, List[types.AssemblyAIError]]], ]: return self._executor.submit( self.wait_for_completion, return_failures=return_failures @@ -799,7 +802,7 @@ def transcribe_group( config: Optional[types.TranscriptionConfig], poll: bool, return_failures: Optional[bool] = False, - ) -> Union[TranscriptGroup, Tuple[TranscriptGroup, List[str]]]: + ) -> Union[TranscriptGroup, Tuple[TranscriptGroup, List[types.AssemblyAIError]]]: if config is None: config = self.config @@ -827,7 +830,7 @@ def transcribe_group( try: transcript_group.add_transcript(future.result()) except types.TranscriptError as e: - failures.append(f"Error processing {future_transcripts[future]}: {e}") + failures.append(e) if poll and return_failures: transcript_group, completion_failures = ( @@ -969,7 +972,7 @@ def submit_group( data: List[Union[str, BinaryIO]], config: Optional[types.TranscriptionConfig] = None, return_failures: Optional[bool] = False, - ) -> Union[TranscriptGroup, Tuple[TranscriptGroup, List[str]]]: + ) -> Union[TranscriptGroup, Tuple[TranscriptGroup, List[types.AssemblyAIError]]]: """ Submits multiple transcription jobs without waiting for their completion. @@ -1032,7 +1035,7 @@ def transcribe_group( data: List[Union[str, BinaryIO]], config: Optional[types.TranscriptionConfig] = None, return_failures: Optional[bool] = False, - ) -> Union[TranscriptGroup, Tuple[TranscriptGroup, List[str]]]: + ) -> Union[TranscriptGroup, Tuple[TranscriptGroup, List[types.AssemblyAIError]]]: """ Transcribes a list of files (as local paths, URLs, or binary objects). @@ -1057,7 +1060,7 @@ def transcribe_group_async( return_failures: Optional[bool] = False, ) -> Union[ concurrent.futures.Future[TranscriptGroup], - concurrent.futures.Future[Tuple[TranscriptGroup, List[str]]], + concurrent.futures.Future[Tuple[TranscriptGroup, List[types.AssemblyAIError]]], ]: """ Transcribes a list of files (as local paths, URLs, or binary objects) asynchronously. @@ -1358,7 +1361,7 @@ def _handle_error(self, error: websockets.exceptions.ConnectionClosed) -> None: error_message = error.reason if error.code != 1000: - self._on_error(types.RealtimeError(error_message)) + self._on_error(types.RealtimeError(error_message, error.code)) self.close() diff --git a/assemblyai/types.py b/assemblyai/types.py index 412d248..2e8283f 100644 --- a/assemblyai/types.py +++ b/assemblyai/types.py @@ -27,12 +27,19 @@ class AssemblyAIError(Exception): Base exception for all AssemblyAI errors """ + def __init__(self, message: str, status_code: Optional[int] = None): + super().__init__(message) + self.status_code = status_code + class TranscriptError(AssemblyAIError): """ Error class when a transcription fails """ + def __init__(self, message: str, status_code: Optional[int] = None): + super().__init__(message, status_code) + class RedactedAudioIncompleteError(AssemblyAIError): """ @@ -40,6 +47,9 @@ class RedactedAudioIncompleteError(AssemblyAIError): before the file has finished processing """ + def __init__(self, message: str, status_code: Optional[int] = None): + super().__init__(message, status_code) + class RedactedAudioExpiredError(AssemblyAIError): """ @@ -47,6 +57,9 @@ class RedactedAudioExpiredError(AssemblyAIError): but the file has expired and is no longer available """ + def __init__(self, message: str, status_code: Optional[int] = None): + super().__init__(message, status_code) + class RedactedAudioUnavailableError(AssemblyAIError): """ @@ -54,12 +67,18 @@ class RedactedAudioUnavailableError(AssemblyAIError): but it is not available at the given URL """ + def __init__(self, message: str, status_code: Optional[int] = None): + super().__init__(message, status_code) + class LemurError(AssemblyAIError): """ Error class when a Lemur request fails """ + def __init__(self, message: str, status_code: Optional[int] = None): + super().__init__(message, status_code) + class Sourcable: """ @@ -2291,6 +2310,9 @@ class RealtimeError(AssemblyAIError): Real-time error message """ + def __init__(self, message: str, status_code: Optional[int] = None): + super().__init__(message, status_code) + RealtimeErrorMapping = { 4000: "Sample rate must be a positive integer", diff --git a/tests/unit/test_transcriber.py b/tests/unit/test_transcriber.py index b40a7dc..5f7a165 100644 --- a/tests/unit/test_transcriber.py +++ b/tests/unit/test_transcriber.py @@ -70,6 +70,7 @@ def test_upload_file_fails(httpx_mock: HTTPXMock): # check wheter the TranscriptError contains the specified error message assert returned_error_message in str(excinfo.value) + assert httpx.codes.INTERNAL_SERVER_ERROR == excinfo.value.status_code def test_submit_url_succeeds(httpx_mock: HTTPXMock): @@ -120,6 +121,7 @@ def test_submit_url_fails(httpx_mock: HTTPXMock): transcriber.submit("https://example.org/audio.wav") assert "something went wrong" in str(excinfo) + assert httpx.codes.INTERNAL_SERVER_ERROR == excinfo.value.status_code # check whether we mocked everything assert len(httpx_mock.get_requests()) == 1 @@ -148,6 +150,7 @@ def test_submit_file_fails_due_api_error(httpx_mock: HTTPXMock): # check wheter the Exception contains the specified error message assert "something went wrong" in str(excinfo.value) + assert httpx.codes.INTERNAL_SERVER_ERROR == excinfo.value.status_code # check whether we mocked everything assert len(httpx_mock.get_requests()) == 1 @@ -430,7 +433,8 @@ def test_transcribe_group_urls_fails_during_upload(httpx_mock: HTTPXMock): assert len(failures) == 1 # Check whether the error message corresponds to the raised TranscriptError message - assert f"Error processing {expect_failed_audio_url}" in failures[0] + assert f"failed to transcribe url" in str(failures[0]) + assert failures[0].status_code == httpx.codes.INTERNAL_SERVER_ERROR def test_transcribe_group_urls_fails_during_polling(httpx_mock: HTTPXMock): @@ -501,7 +505,8 @@ def test_transcribe_group_urls_fails_during_polling(httpx_mock: HTTPXMock): assert len(failures) == 1 # Check whether the error message is correct - assert "failed to retrieve transcript" in failures[0] + assert "failed to retrieve transcript" in str(failures[0]) + assert failures[0].status_code == httpx.codes.INTERNAL_SERVER_ERROR def test_transcribe_async_url_succeeds(httpx_mock: HTTPXMock): From 0fe1bcb5a90f82b7fd0414907825a8d286858d61 Mon Sep 17 00:00:00 2001 From: Patrick Loeber <98830383+ploeber@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:22:29 +0100 Subject: [PATCH 2/6] fix mypy errors --- assemblyai/__version__.py | 2 +- assemblyai/client.py | 2 +- assemblyai/transcriber.py | 94 ++++++++++++++++++++++++++++++++------- 3 files changed, 81 insertions(+), 17 deletions(-) diff --git a/assemblyai/__version__.py b/assemblyai/__version__.py index 98bb08f..d9f2629 100644 --- a/assemblyai/__version__.py +++ b/assemblyai/__version__.py @@ -1 +1 @@ -__version__ = "0.35.1" +__version__ = "0.36.0" diff --git a/assemblyai/client.py b/assemblyai/client.py index 99653d1..bce5889 100644 --- a/assemblyai/client.py +++ b/assemblyai/client.py @@ -41,7 +41,7 @@ def __init__( headers = {"user-agent": user_agent} if self._settings.api_key: - headers["authorization"] = self.settings.api_key + headers["authorization"] = self._settings.api_key self._last_response: Optional[httpx.Response] = None diff --git a/assemblyai/transcriber.py b/assemblyai/transcriber.py index 753c253..7b96e78 100644 --- a/assemblyai/transcriber.py +++ b/assemblyai/transcriber.py @@ -17,6 +17,7 @@ Iterator, List, Optional, + Set, Tuple, Union, ) @@ -47,6 +48,8 @@ def __init__( @property def config(self) -> types.TranscriptionConfig: "Returns the configuration from the internal Transcript object" + if self.transcript is None: + raise ValueError("Canot access the configuration. The internal Transcript object is None.") return types.TranscriptionConfig( **self.transcript.dict( @@ -74,6 +77,8 @@ def wait_for_completion(self) -> Self: """ polls the given transcript until we have a status other than `processing` or `queued` """ + if not self.transcript_id: + raise ValueError("Cannot wait for completion. The internal transcript ID is None.") while True: # No try-except - if there is an HTTP error then surface it to user @@ -97,6 +102,9 @@ def export_subtitles_srt( *, chars_per_caption: Optional[int], ) -> str: + if not self.transcript or not self.transcript.id: + raise ValueError("Cannot export subtitles. The internal Transcript object is None.") + return api.export_subtitles_srt( client=self._client.http_client, transcript_id=self.transcript.id, @@ -108,6 +116,9 @@ def export_subtitles_vtt( *, chars_per_caption: Optional[int], ) -> str: + if not self.transcript or not self.transcript.id: + raise ValueError("Cannot export subtitles. The internal Transcript object is None.") + return api.export_subtitles_vtt( client=self._client.http_client, transcript_id=self.transcript.id, @@ -119,6 +130,9 @@ def word_search( *, words: List[str], ) -> List[types.WordSearchMatch]: + if not self.transcript or not self.transcript.id: + raise ValueError("Cannot perform word search. The internal Transcript object is None.") + response = api.word_search( client=self._client.http_client, transcript_id=self.transcript.id, @@ -128,6 +142,9 @@ def word_search( return response.matches def get_sentences(self) -> List[types.Sentence]: + if not self.transcript or not self.transcript.id: + raise ValueError("Cannot get sentences. The internal Transcript object is None.") + response = api.get_sentences( client=self._client.http_client, transcript_id=self.transcript.id, @@ -136,6 +153,9 @@ def get_sentences(self) -> List[types.Sentence]: return response.sentences def get_paragraphs(self) -> List[types.Paragraph]: + if not self.transcript or not self.transcript.id: + raise ValueError("Cannot get paragraphs. The internal Transcript object is None.") + response = api.get_paragraphs( client=self._client.http_client, transcript_id=self.transcript.id, @@ -156,6 +176,9 @@ def get_redacted_audio_url(self) -> str: "Redacted audio is only available when `redact_pii` and `redact_pii_audio` are set to `True`." ) + if not self.transcript_id: + raise ValueError("Cannot get redacted audio url. The internal transcript ID is None.") + while True: try: return api.get_redacted_audio( @@ -184,8 +207,8 @@ def save_redacted_audio(self, filepath: str): @classmethod def delete_by_id(cls, transcript_id: str) -> types.Transcript: - client = _client.Client.get_default().http_client - response = api.delete_transcript(client=client, transcript_id=transcript_id) + client = _client.Client.get_default() + response = api.delete_transcript(client=client.http_client, transcript_id=transcript_id) return Transcript.from_response(client=client, response=response) @@ -306,83 +329,112 @@ def config(self) -> types.TranscriptionConfig: @property def json_response(self) -> Optional[dict]: "The full JSON response associated with the transcript." + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.dict() @property def audio_url(self) -> str: "The corresponding audio url" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.audio_url @property def speech_model(self) -> Optional[str]: "The speech model used for the transcription" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") + return self._impl.transcript.speech_model @property def text(self) -> Optional[str]: "The text transcription of your media file" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.text @property def summary(self) -> Optional[str]: "The summarization of the transcript" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.summary @property def chapters(self) -> Optional[List[types.Chapter]]: "The list of auto-chapters results" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.chapters @property def content_safety(self) -> Optional[types.ContentSafetyResponse]: "The results from the content safety analysis" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.content_safety_labels @property def sentiment_analysis(self) -> Optional[List[types.Sentiment]]: "The list of sentiment analysis results" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.sentiment_analysis_results @property def entities(self) -> Optional[List[types.Entity]]: "The list of entity detection results" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.entities @property def iab_categories(self) -> Optional[types.IABResponse]: "The results from the IAB category detection" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.iab_categories_result @property def auto_highlights(self) -> Optional[types.AutohighlightResponse]: "The results from the auto-highlights model" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.auto_highlights_result @property def status(self) -> types.TranscriptStatus: "The current status of the transcript" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.status @property def error(self) -> Optional[str]: "The error message in case the transcription fails" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.error @property def words(self) -> Optional[List[types.Word]]: "The list of words in the transcript" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.words @@ -392,30 +444,40 @@ def utterances(self) -> Optional[List[types.Utterance]]: When `dual_channel` or `speaker_labels` is enabled, a list of utterances in the transcript. """ + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.utterances @property def confidence(self) -> Optional[float]: "The confidence our model has in the transcribed text, between 0 and 1" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.confidence @property def audio_duration(self) -> Optional[int]: "The duration of the audio in seconds" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.audio_duration @property def webhook_status_code(self) -> Optional[int]: "The status code we received from your server when delivering your webhook" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.webhook_status_code @property def webhook_auth(self) -> Optional[bool]: "Whether the webhook was sent with an HTTP authentication header" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.webhook_auth @@ -540,7 +602,9 @@ def __init__( @property def transcript_ids(self) -> List[str]: - return [t.id for t in self.transcripts] + if any(t.id is None for t in self.transcripts): + raise ValueError("All transcripts must have a transcript ID.") + return [t.id for t in self.transcripts if t.id] # include the if check for mypy type checker def add_transcript(self, transcript: Union[Transcript, str]) -> None: if isinstance(transcript, Transcript): @@ -555,19 +619,17 @@ def add_transcript(self, transcript: Union[Transcript, str]) -> None: else: raise TypeError("Unsupported type for `transcript`") - return self - def wait_for_completion( self, return_failures ) -> Union[None, List[types.AssemblyAIError]]: transcripts: List[Transcript] = [] - failures: List[str] = [] + failures: List[types.AssemblyAIError] = [] - future_transcripts: Dict[concurrent.futures.Future[Transcript], str] = {} + future_transcripts: Set[concurrent.futures.Future[Transcript]] = set() for transcript in self.transcripts: future = transcript.wait_for_completion_async() - future_transcripts[future] = transcript + future_transcripts.add(future) finished_futures, _ = concurrent.futures.wait(future_transcripts) @@ -619,13 +681,13 @@ def __iter__(self) -> Iterator[Transcript]: return iter(self.transcripts) @classmethod - def get_by_ids(cls, transcript_ids: List[str]) -> Self: + def get_by_ids(cls, transcript_ids: List[str]) -> Union[Self, Tuple[Self, List[types.AssemblyAIError]]]: return cls(transcript_ids=transcript_ids).wait_for_completion() @classmethod def get_by_ids_async( cls, transcript_ids: List[str] - ) -> concurrent.futures.Future[Self]: + ) -> concurrent.futures.Future[Union[Self, Tuple[Self, List[types.AssemblyAIError]]]]: return cls(transcript_ids=transcript_ids).wait_for_completion_async() @property @@ -646,6 +708,8 @@ def status(self) -> types.TranscriptStatus: return types.TranscriptStatus.processing elif all(s == types.TranscriptStatus.completed for s in all_status): return types.TranscriptStatus.completed + else: + raise ValueError(f"Unexpected status type: {all_status}") @property def lemur(self) -> lemur.Lemur: @@ -687,6 +751,8 @@ def wait_for_completion( """ if return_failures: failures = self._impl.wait_for_completion(return_failures=return_failures) + if not failures: + raise ValueError("return_failures was set but failures object is None") return self, failures self._impl.wait_for_completion(return_failures=return_failures) @@ -696,9 +762,7 @@ def wait_for_completion( def wait_for_completion_async( self, return_failures: Optional[bool] = False, - ) -> Union[ - concurrent.futures.Future[Self], - concurrent.futures.Future[Tuple[Self, List[types.AssemblyAIError]]], + ) -> concurrent.futures.Future[Union[Self, Tuple[Self, List[types.AssemblyAIError]]], ]: return self._executor.submit( self.wait_for_completion, return_failures=return_failures @@ -806,7 +870,7 @@ def transcribe_group( if config is None: config = self.config - future_transcripts: Dict[concurrent.futures.Future[Transcript], str] = {} + future_transcripts: Set[concurrent.futures.Future[Transcript]] = set() with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor: for d in data: @@ -817,7 +881,7 @@ def transcribe_group( poll=False, ) - future_transcripts[transcript_future] = d + future_transcripts.add(transcript_future) finished_futures, _ = concurrent.futures.wait(future_transcripts) From 1e1abc7e0bfc8fc7e79fc6bfcd1a00ed34d5e87a Mon Sep 17 00:00:00 2001 From: Patrick Loeber <98830383+ploeber@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:23:37 +0100 Subject: [PATCH 3/6] fix ruff errors --- assemblyai/transcriber.py | 53 ++++++++++++++++++++++++++++----------- 1 file changed, 39 insertions(+), 14 deletions(-) diff --git a/assemblyai/transcriber.py b/assemblyai/transcriber.py index 7b96e78..c3bd8cf 100644 --- a/assemblyai/transcriber.py +++ b/assemblyai/transcriber.py @@ -49,7 +49,9 @@ def __init__( def config(self) -> types.TranscriptionConfig: "Returns the configuration from the internal Transcript object" if self.transcript is None: - raise ValueError("Canot access the configuration. The internal Transcript object is None.") + raise ValueError( + "Canot access the configuration. The internal Transcript object is None." + ) return types.TranscriptionConfig( **self.transcript.dict( @@ -78,7 +80,9 @@ def wait_for_completion(self) -> Self: polls the given transcript until we have a status other than `processing` or `queued` """ if not self.transcript_id: - raise ValueError("Cannot wait for completion. The internal transcript ID is None.") + raise ValueError( + "Cannot wait for completion. The internal transcript ID is None." + ) while True: # No try-except - if there is an HTTP error then surface it to user @@ -103,7 +107,9 @@ def export_subtitles_srt( chars_per_caption: Optional[int], ) -> str: if not self.transcript or not self.transcript.id: - raise ValueError("Cannot export subtitles. The internal Transcript object is None.") + raise ValueError( + "Cannot export subtitles. The internal Transcript object is None." + ) return api.export_subtitles_srt( client=self._client.http_client, @@ -117,7 +123,9 @@ def export_subtitles_vtt( chars_per_caption: Optional[int], ) -> str: if not self.transcript or not self.transcript.id: - raise ValueError("Cannot export subtitles. The internal Transcript object is None.") + raise ValueError( + "Cannot export subtitles. The internal Transcript object is None." + ) return api.export_subtitles_vtt( client=self._client.http_client, @@ -131,7 +139,9 @@ def word_search( words: List[str], ) -> List[types.WordSearchMatch]: if not self.transcript or not self.transcript.id: - raise ValueError("Cannot perform word search. The internal Transcript object is None.") + raise ValueError( + "Cannot perform word search. The internal Transcript object is None." + ) response = api.word_search( client=self._client.http_client, @@ -143,7 +153,9 @@ def word_search( def get_sentences(self) -> List[types.Sentence]: if not self.transcript or not self.transcript.id: - raise ValueError("Cannot get sentences. The internal Transcript object is None.") + raise ValueError( + "Cannot get sentences. The internal Transcript object is None." + ) response = api.get_sentences( client=self._client.http_client, @@ -154,7 +166,9 @@ def get_sentences(self) -> List[types.Sentence]: def get_paragraphs(self) -> List[types.Paragraph]: if not self.transcript or not self.transcript.id: - raise ValueError("Cannot get paragraphs. The internal Transcript object is None.") + raise ValueError( + "Cannot get paragraphs. The internal Transcript object is None." + ) response = api.get_paragraphs( client=self._client.http_client, @@ -177,7 +191,9 @@ def get_redacted_audio_url(self) -> str: ) if not self.transcript_id: - raise ValueError("Cannot get redacted audio url. The internal transcript ID is None.") + raise ValueError( + "Cannot get redacted audio url. The internal transcript ID is None." + ) while True: try: @@ -208,7 +224,9 @@ def save_redacted_audio(self, filepath: str): @classmethod def delete_by_id(cls, transcript_id: str) -> types.Transcript: client = _client.Client.get_default() - response = api.delete_transcript(client=client.http_client, transcript_id=transcript_id) + response = api.delete_transcript( + client=client.http_client, transcript_id=transcript_id + ) return Transcript.from_response(client=client, response=response) @@ -370,7 +388,7 @@ def summary(self) -> Optional[str]: def chapters(self) -> Optional[List[types.Chapter]]: "The list of auto-chapters results" if not self._impl.transcript: - raise ValueError("The internal Transcript object is None.") + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.chapters @@ -604,7 +622,9 @@ def __init__( def transcript_ids(self) -> List[str]: if any(t.id is None for t in self.transcripts): raise ValueError("All transcripts must have a transcript ID.") - return [t.id for t in self.transcripts if t.id] # include the if check for mypy type checker + return [ + t.id for t in self.transcripts if t.id + ] # include the if check for mypy type checker def add_transcript(self, transcript: Union[Transcript, str]) -> None: if isinstance(transcript, Transcript): @@ -681,13 +701,17 @@ def __iter__(self) -> Iterator[Transcript]: return iter(self.transcripts) @classmethod - def get_by_ids(cls, transcript_ids: List[str]) -> Union[Self, Tuple[Self, List[types.AssemblyAIError]]]: + def get_by_ids( + cls, transcript_ids: List[str] + ) -> Union[Self, Tuple[Self, List[types.AssemblyAIError]]]: return cls(transcript_ids=transcript_ids).wait_for_completion() @classmethod def get_by_ids_async( cls, transcript_ids: List[str] - ) -> concurrent.futures.Future[Union[Self, Tuple[Self, List[types.AssemblyAIError]]]]: + ) -> concurrent.futures.Future[ + Union[Self, Tuple[Self, List[types.AssemblyAIError]]] + ]: return cls(transcript_ids=transcript_ids).wait_for_completion_async() @property @@ -762,7 +786,8 @@ def wait_for_completion( def wait_for_completion_async( self, return_failures: Optional[bool] = False, - ) -> concurrent.futures.Future[Union[Self, Tuple[Self, List[types.AssemblyAIError]]], + ) -> concurrent.futures.Future[ + Union[Self, Tuple[Self, List[types.AssemblyAIError]]], ]: return self._executor.submit( self.wait_for_completion, return_failures=return_failures From c594b9a92d9e3144a0269c2c566f2309737d6e17 Mon Sep 17 00:00:00 2001 From: Patrick Loeber <98830383+ploeber@users.noreply.github.com> Date: Thu, 19 Dec 2024 14:09:45 +0100 Subject: [PATCH 4/6] fix more mypy errors --- assemblyai/client.py | 2 +- assemblyai/transcriber.py | 53 ++++++++++++++++++++++++++------------- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/assemblyai/client.py b/assemblyai/client.py index bce5889..5454c5c 100644 --- a/assemblyai/client.py +++ b/assemblyai/client.py @@ -88,7 +88,7 @@ def http_client(self) -> httpx.Client: return self._http_client @classmethod - def get_default(cls, api_key_required: bool = True) -> Self: + def get_default(cls, api_key_required: bool = True): """ Return the default client. diff --git a/assemblyai/transcriber.py b/assemblyai/transcriber.py index c3bd8cf..97b9060 100644 --- a/assemblyai/transcriber.py +++ b/assemblyai/transcriber.py @@ -663,6 +663,7 @@ def wait_for_completion( if return_failures: return failures + return None class TranscriptGroup: @@ -913,7 +914,7 @@ def transcribe_group( transcript_group = TranscriptGroup( client=self._client, ) - failures = [] + failures: List[types.AssemblyAIError] = [] for future in finished_futures: try: @@ -922,14 +923,20 @@ def transcribe_group( failures.append(e) if poll and return_failures: - transcript_group, completion_failures = ( - transcript_group.wait_for_completion(return_failures=return_failures) - ) + res = transcript_group.wait_for_completion(return_failures=return_failures) + if not isinstance(res, tuple): + raise ValueError( + "return_failures is set but did not receive failures object" + ) + transcript_group, completion_failures = res failures.extend(completion_failures) elif poll: - transcript_group = transcript_group.wait_for_completion( - return_failures=return_failures - ) + res = transcript_group.wait_for_completion(return_failures=return_failures) + if not isinstance(res, TranscriptGroup): + raise ValueError( + "return_failures is not set but did receive failures object" + ) + transcript_group = res if return_failures: return transcript_group, failures @@ -987,7 +994,11 @@ def __init__( ) if not max_workers: - max_workers = max(1, os.cpu_count() - 1) + cpu_count = os.cpu_count() + if not cpu_count: + max_workers = 1 + else: + max_workers = max(1, cpu_count - 1) self._executor = concurrent.futures.ThreadPoolExecutor( max_workers=max_workers, @@ -1147,9 +1158,8 @@ def transcribe_group_async( data: List[Union[str, BinaryIO]], config: Optional[types.TranscriptionConfig] = None, return_failures: Optional[bool] = False, - ) -> Union[ - concurrent.futures.Future[TranscriptGroup], - concurrent.futures.Future[Tuple[TranscriptGroup, List[types.AssemblyAIError]]], + ) -> concurrent.futures.Future[ + Union[TranscriptGroup, Tuple[TranscriptGroup, List[types.AssemblyAIError]]] ]: """ Transcribes a list of files (as local paths, URLs, or binary objects) asynchronously. @@ -1339,7 +1349,8 @@ def close(self, terminate: bool = False) -> None: try: self._read_thread.join() self._write_thread.join() - self._websocket.close() + if self._websocket: + self._websocket.close() except Exception: pass @@ -1354,15 +1365,18 @@ def _read(self) -> None: """ while not self._stop_event.is_set(): + if not self._websocket: + raise ValueError("Websocket is None") + try: - message = self._websocket.recv(timeout=1) + recv_message = self._websocket.recv(timeout=1) except TimeoutError: continue except websockets.exceptions.ConnectionClosed as exc: return self._handle_error(exc) try: - message = json.loads(message) + message = json.loads(recv_message) except json.JSONDecodeError as exc: self._on_error( types.RealtimeError( @@ -1387,7 +1401,9 @@ def _write(self) -> None: continue try: - if isinstance(data, dict): + if not self._websocket: + raise ValueError("websocket is None") + elif isinstance(data, dict): self._websocket.send(json.dumps(data)) elif isinstance(data, bytes): self._websocket.send(data) @@ -1425,9 +1441,10 @@ def _handle_message( message["message_type"] == types.RealtimeMessageTypes.session_information ): - self._on_extra_session_information( - types.RealtimeSessionInformation(**message) - ) + if self._on_extra_session_information is not None: + self._on_extra_session_information( + types.RealtimeSessionInformation(**message) + ) elif "error" in message: self._on_error(types.RealtimeError(message["error"])) From e2a82d1bed63ff5ab693520dc7e2d93ba43e0619 Mon Sep 17 00:00:00 2001 From: Patrick Loeber <98830383+ploeber@users.noreply.github.com> Date: Thu, 19 Dec 2024 14:12:05 +0100 Subject: [PATCH 5/6] fix ruff error --- assemblyai/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/assemblyai/client.py b/assemblyai/client.py index 5454c5c..e8c7c09 100644 --- a/assemblyai/client.py +++ b/assemblyai/client.py @@ -3,7 +3,6 @@ from typing import ClassVar, Optional import httpx -from typing_extensions import Self from . import types from .__version__ import __version__ From e1016b7b85640cb8ed77b624d7be5b0a603e0224 Mon Sep 17 00:00:00 2001 From: Patrick Loeber <98830383+ploeber@users.noreply.github.com> Date: Thu, 19 Dec 2024 14:57:51 +0100 Subject: [PATCH 6/6] fix truthy/falsey checks --- assemblyai/transcriber.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/assemblyai/transcriber.py b/assemblyai/transcriber.py index 97b9060..9103cf7 100644 --- a/assemblyai/transcriber.py +++ b/assemblyai/transcriber.py @@ -661,7 +661,7 @@ def wait_for_completion( self.transcripts = transcripts - if return_failures: + if return_failures is True: return failures return None @@ -774,9 +774,9 @@ def wait_for_completion( Args: return_failures: Whether to return a list of errors for transcripts that failed due to HTTP errors. """ - if return_failures: + if return_failures is True: failures = self._impl.wait_for_completion(return_failures=return_failures) - if not failures: + if failures is None: raise ValueError("return_failures was set but failures object is None") return self, failures @@ -922,11 +922,11 @@ def transcribe_group( except types.TranscriptError as e: failures.append(e) - if poll and return_failures: + if poll is True and return_failures is True: res = transcript_group.wait_for_completion(return_failures=return_failures) if not isinstance(res, tuple): raise ValueError( - "return_failures is set but did not receive failures object" + "return_failures was set but did not receive failures object" ) transcript_group, completion_failures = res failures.extend(completion_failures) @@ -934,11 +934,11 @@ def transcribe_group( res = transcript_group.wait_for_completion(return_failures=return_failures) if not isinstance(res, TranscriptGroup): raise ValueError( - "return_failures is not set but did receive failures object" + "return_failures was not set but did receive failures object" ) transcript_group = res - if return_failures: + if return_failures is True: return transcript_group, failures else: return transcript_group