def start_search_export(
    api_key: str,
    workspace_url: str,
    query: str,
    format: str,
    dataset: Optional[str] = None,
    annotation_group: Optional[str] = None,
    name: Optional[str] = None,
) -> str:
    """Start a search export job.

    Args:
        api_key: API key appended to the request URL.
        workspace_url: Workspace slug used in the request path.
        query: Search query sent as ``query`` in the JSON body.
        format: Annotation format sent as ``format`` in the JSON body.
        dataset: Optional value sent as ``dataset``; omitted when ``None``.
        annotation_group: Optional value sent as ``annotationGroup``; omitted when ``None``.
        name: Optional value sent as ``name``; omitted when ``None``.

    Returns:
        The ``link`` field of the response, i.e. the export id used to poll
        for completion with :func:`get_search_export`.

    Raises:
        RoboflowError: On non-202 responses, or when a 202 response is not
            JSON or carries no ``link`` field.
    """
    url = f"{API_URL}/{workspace_url}/search/export?api_key={api_key}"
    body: Dict[str, str] = {"query": query, "format": format}
    optional_fields = {"dataset": dataset, "annotationGroup": annotation_group, "name": name}
    body.update({key: value for key, value in optional_fields.items() if value is not None})

    response = requests.post(url, json=body)
    if response.status_code != 202:
        raise RoboflowError(response.text)

    # A malformed success payload should surface as the documented error type,
    # not as a bare KeyError/ValueError from deep inside the adapter.
    try:
        return response.json()["link"]
    except (ValueError, KeyError, TypeError) as err:
        raise RoboflowError(f"Unexpected search export response: {response.text}") from err


def get_search_export(api_key: str, workspace_url: str, export_id: str) -> dict:
    """Poll the status of a search export job.

    Args:
        api_key: API key appended to the request URL.
        workspace_url: Workspace slug used in the request path.
        export_id: Identifier returned by :func:`start_search_export`.

    Returns:
        The decoded JSON response: a dict with ``ready`` (bool) and ``link``
        (str, present when ready).

    Raises:
        RoboflowError: On non-200 responses, or when the body is not valid JSON.
    """
    url = f"{API_URL}/{workspace_url}/search/export/{export_id}?api_key={api_key}"
    response = requests.get(url)
    if response.status_code != 200:
        raise RoboflowError(response.text)
    try:
        return response.json()
    except ValueError as err:
        raise RoboflowError(f"Unexpected search export response: {response.text}") from err
def unwrap_version_id(version_id: str) -> str:
    """Return the last ``/``-separated segment of *version_id*.

    Values whose string form contains no ``/`` are returned unchanged (same
    object, no conversion). Otherwise the text after the final ``/`` of the
    string form is returned, so ``"workspace/project/3"`` yields ``"3"``.

    Args:
        version_id: A bare version id or a slash-separated path ending in one.

    Returns:
        The trailing segment, or *version_id* itself when there is no ``/``.
    """
    # The guard inspects str(version_id), so split that same string; calling
    # .rsplit on the raw object would fail for non-str inputs containing "/".
    text = str(version_id)
    if "/" not in text:
        return version_id
    return text.rsplit("/", maxsplit=1)[-1]
# --- roboflow/core/workspace.py: method of Workspace ---
def search_export(
    self,
    query: str,
    format: str = "coco",
    location: Optional[str] = None,
    dataset: Optional[str] = None,
    annotation_group: Optional[str] = None,
    name: Optional[str] = None,
    extract_zip: bool = True,
) -> str:
    """Export search results as a downloaded dataset.

    Starts an export job, polls it every 5 seconds (up to 600 seconds of
    accumulated sleep), downloads the resulting archive to
    ``<location>/roboflow.zip`` and optionally extracts it.

    Args:
        query: Search query string (e.g. ``"tag:annotate"`` or ``"*"``).
        format: Annotation format for the export (default ``"coco"``).
        location: Local directory to save the exported dataset.
            Defaults to ``./search-export-{format}``.
        dataset: Limit export to a specific dataset (project) slug.
        annotation_group: Limit export to a specific annotation group.
        name: Optional name for the export.
        extract_zip: If True (default), extract the zip and remove it.
            If False, keep the zip file as-is.

    Returns:
        Absolute path to the extracted directory or the zip file.

    Raises:
        ValueError: If both *dataset* and *annotation_group* are provided.
        RoboflowError: On API errors, export timeout, a ready export without
            a download link, or a failed download request.
    """
    if dataset is not None and annotation_group is not None:
        raise ValueError("dataset and annotation_group are mutually exclusive; provide only one")

    if location is None:
        location = f"./search-export-{format}"
    location = os.path.abspath(location)

    # 1. Start the export
    export_id = rfapi.start_search_export(
        api_key=self.__api_key,
        workspace_url=self.url,
        query=query,
        format=format,
        dataset=dataset,
        annotation_group=annotation_group,
        name=name,
    )
    print(f"Export started (id={export_id}). Polling for completion...")

    # 2. Poll until ready
    timeout = 600
    poll_interval = 5
    elapsed = 0
    while elapsed < timeout:
        status = rfapi.get_search_export(
            api_key=self.__api_key,
            workspace_url=self.url,
            export_id=export_id,
        )
        if status.get("ready"):
            break
        time.sleep(poll_interval)
        elapsed += poll_interval
    else:
        raise RoboflowError(f"Search export timed out after {timeout}s")

    download_url = status.get("link")
    if not download_url:
        # Report a malformed "ready" payload as an API error, not a KeyError.
        raise RoboflowError("Search export reported ready but returned no download link")

    # 3. Download zip
    zip_path = os.path.join(location, "roboflow.zip")
    response = requests.get(download_url, stream=True)
    try:
        try:
            response.raise_for_status()
        except HTTPError as e:
            raise RoboflowError(f"Failed to download search export: {e}") from e

        # Progress total in 1 KiB chunks (rounded up); unknown when the header
        # is missing (int(None) -> TypeError) or not numeric (ValueError).
        try:
            total_kib = (int(response.headers.get("content-length")) + 1023) // 1024
        except (TypeError, ValueError):
            total_kib = None

        # Create the directory only once the download is known to be available,
        # so a failed request does not leave an empty directory behind.
        os.makedirs(location, exist_ok=True)
        with open(zip_path, "wb") as f:
            for chunk in tqdm(
                response.iter_content(chunk_size=1024),
                desc=f"Downloading search export to {location}",
                total=total_kib,
            ):
                if chunk:
                    f.write(chunk)
    finally:
        # With stream=True the connection is held until the body is consumed
        # or the response is closed; close explicitly to cover error paths.
        response.close()

    if extract_zip:
        _extract_zip(location, desc=f"Extracting search export to {location}")
        print(f"Search export extracted to {location}")
        return location

    print(f"Search export saved to {zip_path}")
    return zip_path


# --- roboflow/roboflowpy.py: CLI handler for the ``search-export`` subcommand ---
def search_export(args):
    """Run ``Workspace.search_export`` from parsed CLI arguments and print the result path.

    Args:
        args: Parsed arguments providing ``workspace``, ``query``, ``format``,
            ``location``, ``dataset``, ``annotation_group``, ``name`` and
            ``no_extract``.
    """
    rf = roboflow.Roboflow()
    workspace = rf.workspace(args.workspace)
    result = workspace.search_export(
        query=args.query,
        format=args.format,
        location=args.location,
        dataset=args.dataset,
        annotation_group=args.annotation_group,
        name=args.name,
        extract_zip=not args.no_extract,
    )
    print(result)
# --- roboflow/roboflowpy.py ---
def _add_search_export_parser(subparsers):
    """Register the ``search-export`` subcommand and bind it to ``search_export``."""
    p = subparsers.add_parser("search-export", help="Export search results as a dataset")
    p.add_argument("query", help="Search query (e.g. 'tag:annotate' or '*')")
    p.add_argument("-f", dest="format", default="coco", help="Annotation format (default: coco)")
    p.add_argument("-w", dest="workspace", help="Workspace url or id (uses default workspace if not specified)")
    p.add_argument("-l", dest="location", help="Local directory to save the export")
    p.add_argument("-d", dest="dataset", help="Limit export to a specific dataset (project slug)")
    p.add_argument("-g", dest="annotation_group", help="Limit export to a specific annotation group")
    p.add_argument("-n", dest="name", help="Optional name for the export")
    p.add_argument("--no-extract", dest="no_extract", action="store_true", help="Skip extraction, keep the zip file")
    p.set_defaults(func=search_export)


# --- roboflow/util/general.py ---
def extract_zip(location: str, desc: str = "Extracting"):
    """Extract ``roboflow.zip`` inside *location* and remove the archive.

    Args:
        location: Directory containing ``roboflow.zip``; members are extracted here.
        desc: Description for the tqdm progress bar. The description is
            omitted when ``TQDM_DISABLE`` is truthy.

    Raises:
        RuntimeError: If a member cannot be extracted because the archive is
            invalid. The archive is not removed in that case.
    """
    zip_path = os.path.join(location, "roboflow.zip")
    tqdm_desc = None if TQDM_DISABLE else desc
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        for member in tqdm(zip_ref.infolist(), desc=tqdm_desc):
            try:
                zip_ref.extract(member, location)
            except zipfile.BadZipFile as err:
                # Keep the underlying zip error attached for debugging.
                raise RuntimeError("Error unzipping download") from err

    os.remove(zip_path)


# --- tests/test_search_export.py ---
import io
import os
import shutil
import unittest
import zipfile
from unittest.mock import MagicMock, patch

import requests
import responses

from roboflow.adapters.rfapi import RoboflowError, get_search_export, start_search_export
from roboflow.config import API_URL


def _make_workspace():
    """Build a Workspace from a minimal info dict, without any network access."""
    from roboflow.core.workspace import Workspace

    info = {
        "workspace": {
            "name": "Test",
            "url": "test-ws",
            "projects": [],
            "members": [],
        }
    }
    return Workspace(info, api_key="test_key", default_workspace="test-ws", model_format="yolov8")


class TestStartSearchExport(unittest.TestCase):
    """HTTP-level tests for ``start_search_export``."""

    API_KEY = "test_key"
    WORKSPACE = "my-workspace"
    URL = f"{API_URL}/{WORKSPACE}/search/export?api_key={API_KEY}"

    @responses.activate
    def test_success(self):
        responses.add(responses.POST, self.URL, json={"success": True, "link": "export_123"}, status=202)

        export_id = start_search_export(self.API_KEY, self.WORKSPACE, query="*", format="coco")
        self.assertEqual(export_id, "export_123")

        body = responses.calls[0].request.body
        self.assertIn(b'"query"', body)
        self.assertIn(b'"format"', body)

    @responses.activate
    def test_with_dataset(self):
        responses.add(responses.POST, self.URL, json={"success": True, "link": "export_456"}, status=202)

        export_id = start_search_export(
            self.API_KEY, self.WORKSPACE, query="tag:train", format="yolov8", dataset="my-dataset"
        )
        self.assertEqual(export_id, "export_456")

        body = responses.calls[0].request.body
        self.assertIn(b'"dataset"', body)

    @responses.activate
    def test_error_response(self):
        responses.add(responses.POST, self.URL, body="Bad Request", status=400)

        with self.assertRaises(RoboflowError):
            start_search_export(self.API_KEY, self.WORKSPACE, query="*", format="coco")


class TestGetSearchExport(unittest.TestCase):
    """HTTP-level tests for ``get_search_export``."""

    API_KEY = "test_key"
    WORKSPACE = "my-workspace"
    URL = f"{API_URL}/{WORKSPACE}/search/export/exp1?api_key={API_KEY}"

    @responses.activate
    def test_not_ready(self):
        responses.add(responses.GET, self.URL, json={"ready": False}, status=200)

        result = get_search_export(self.API_KEY, self.WORKSPACE, "exp1")
        self.assertFalse(result["ready"])

    @responses.activate
    def test_ready(self):
        responses.add(
            responses.GET, self.URL, json={"ready": True, "link": "https://download.url/file.zip"}, status=200
        )

        result = get_search_export(self.API_KEY, self.WORKSPACE, "exp1")
        self.assertTrue(result["ready"])
        self.assertEqual(result["link"], "https://download.url/file.zip")

    @responses.activate
    def test_error_response(self):
        responses.add(responses.GET, self.URL, body="Not Found", status=404)

        with self.assertRaises(RoboflowError):
            get_search_export(self.API_KEY, self.WORKSPACE, "exp1")


class TestWorkspaceSearchExportValidation(unittest.TestCase):
    """Argument validation that happens before any API call."""

    def test_mutual_exclusion(self):
        ws = _make_workspace()
        with self.assertRaises(ValueError) as ctx:
            ws.search_export(query="*", dataset="ds", annotation_group="ag")
        self.assertIn("mutually exclusive", str(ctx.exception))


class TestWorkspaceSearchExportFlow(unittest.TestCase):
    """End-to-end ``Workspace.search_export`` with ``rfapi`` and ``requests`` mocked."""

    @staticmethod
    def _build_zip_bytes(files):
        """Return the bytes of an in-memory zip holding ``{filename: content}``."""
        buffer = io.BytesIO()
        with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
            for filename, content in files.items():
                zip_file.writestr(filename, content)
        return buffer.getvalue()

    @staticmethod
    def _export_is_ready(mock_rfapi):
        """Make the mocked API report an export that is ready on the first poll."""
        mock_rfapi.start_search_export.return_value = "exp_abc"
        mock_rfapi.get_search_export.return_value = {"ready": True, "link": "https://example.com/export.zip"}

    @staticmethod
    def _stub_download(mock_requests, chunks, headers):
        """Make ``requests.get`` return a successful response streaming *chunks*."""
        mock_response = MagicMock()
        mock_response.headers = headers
        mock_response.raise_for_status.return_value = None
        mock_response.iter_content.return_value = chunks
        mock_requests.get.return_value = mock_response
        return mock_response

    def _use_location(self, location):
        """Register removal of *location* so no test leaves files in the CWD."""
        self.addCleanup(shutil.rmtree, location, ignore_errors=True)
        return location

    @patch("roboflow.core.workspace.rfapi")
    @patch("roboflow.core.workspace.requests")
    def test_full_flow(self, mock_requests, mock_rfapi):
        ws = _make_workspace()
        self._export_is_ready(mock_rfapi)

        fake_zip = self._build_zip_bytes({"images/sample.jpg": "fake-image-data"})
        mock_response = self._stub_download(
            mock_requests,
            chunks=[fake_zip[:1024], fake_zip[1024:]],
            headers={"content-length": str(len(fake_zip))},
        )

        location = self._use_location("./test_search_export_output")
        result = ws.search_export(query="*", format="coco", location=location)

        expected_location = os.path.abspath(location)
        self.assertEqual(result, expected_location)
        self.assertTrue(os.path.exists(os.path.join(expected_location, "images", "sample.jpg")))
        self.assertFalse(os.path.exists(os.path.join(expected_location, "roboflow.zip")))

        mock_rfapi.start_search_export.assert_called_once_with(
            api_key="test_key",
            workspace_url="test-ws",
            query="*",
            format="coco",
            dataset=None,
            annotation_group=None,
            name=None,
        )
        mock_rfapi.get_search_export.assert_called_once_with(
            api_key="test_key",
            workspace_url="test-ws",
            export_id="exp_abc",
        )
        mock_response.raise_for_status.assert_called_once()
        mock_response.iter_content.assert_called_once_with(chunk_size=1024)

    @patch("roboflow.core.workspace.rfapi")
    @patch("roboflow.core.workspace.requests")
    def test_full_flow_without_content_length_still_streams(self, mock_requests, mock_rfapi):
        ws = _make_workspace()
        self._export_is_ready(mock_rfapi)

        fake_zip = self._build_zip_bytes({"annotations/instances.json": "{}"})
        mock_response = self._stub_download(mock_requests, chunks=[fake_zip], headers={})

        location = self._use_location("./test_search_export_no_content_length")
        result = ws.search_export(query="*", format="coco", location=location)

        expected_location = os.path.abspath(location)
        self.assertEqual(result, expected_location)
        self.assertTrue(os.path.exists(os.path.join(expected_location, "annotations", "instances.json")))
        mock_response.iter_content.assert_called_once_with(chunk_size=1024)

    @patch("roboflow.core.workspace.rfapi")
    @patch("roboflow.core.workspace.requests")
    def test_download_http_error_raises_roboflow_error(self, mock_requests, mock_rfapi):
        ws = _make_workspace()
        self._export_is_ready(mock_rfapi)

        mock_response = MagicMock()
        mock_response.raise_for_status.side_effect = requests.HTTPError("403 Client Error")
        mock_requests.get.return_value = mock_response

        # search_export may create the target directory before the download
        # fails, so it must be cleaned up here as well.
        location = self._use_location("./test_search_export_http_error")
        with self.assertRaises(RoboflowError) as context:
            ws.search_export(query="*", format="coco", location=location)

        self.assertIn("Failed to download search export", str(context.exception))

    @patch("roboflow.core.workspace.rfapi")
    @patch("roboflow.core.workspace.requests")
    def test_no_extract(self, mock_requests, mock_rfapi):
        ws = _make_workspace()
        self._export_is_ready(mock_rfapi)

        fake_zip = self._build_zip_bytes({"images/sample.jpg": "fake-image-data"})
        self._stub_download(mock_requests, chunks=[fake_zip], headers={"content-length": str(len(fake_zip))})

        location = self._use_location("./test_search_export_no_extract")
        result = ws.search_export(query="*", format="coco", location=location, extract_zip=False)

        expected_zip = os.path.join(os.path.abspath(location), "roboflow.zip")
        self.assertEqual(result, expected_zip)
        self.assertTrue(os.path.exists(expected_zip))


if __name__ == "__main__":
    unittest.main()