Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions roboflow/adapters/rfapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,52 @@ def get_version_export(
return payload


def start_search_export(
    api_key: str,
    workspace_url: str,
    query: str,
    format: str,
    dataset: Optional[str] = None,
    annotation_group: Optional[str] = None,
    name: Optional[str] = None,
) -> str:
    """Start a search export job.

    Args:
        api_key: Roboflow API key.
        workspace_url: Workspace url/slug the search runs against.
        query: Search query string (e.g. ``"tag:annotate"`` or ``"*"``).
        format: Annotation format for the export.
        dataset: Optional dataset (project) slug to restrict the export to.
        annotation_group: Optional annotation group to restrict the export to.
        name: Optional human-readable name for the export.

    Returns:
        The export identifier used to poll for completion — taken from the
        ``link`` field of the API response.

    Raises:
        RoboflowError: If the API does not respond with HTTP 202.
    """
    url = f"{API_URL}/{workspace_url}/search/export?api_key={api_key}"
    # Only include optional filters when provided, so the API does not
    # receive accidental null/empty filters.
    body: Dict[str, str] = {"query": query, "format": format}
    if dataset is not None:
        body["dataset"] = dataset
    if annotation_group is not None:
        body["annotationGroup"] = annotation_group
    if name is not None:
        body["name"] = name

    response = requests.post(url, json=body)
    # 202 Accepted: the export job was queued asynchronously.
    if response.status_code != 202:
        raise RoboflowError(response.text)

    payload = response.json()
    return payload["link"]


def get_search_export(api_key: str, workspace_url: str, export_id: str) -> dict:
    """Fetch the current status of a search export job.

    Returns:
        dict: Contains ``ready`` (bool) and, once the job has finished,
        ``link`` (str) pointing at the downloadable archive.

    Raises:
        RoboflowError: If the API responds with anything other than HTTP 200.
    """
    status_url = f"{API_URL}/{workspace_url}/search/export/{export_id}?api_key={api_key}"
    resp = requests.get(status_url)
    if resp.status_code == 200:
        return resp.json()
    raise RoboflowError(resp.text)


def upload_image(
api_key,
project_url,
Expand Down
31 changes: 3 additions & 28 deletions roboflow/core/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os
import sys
import time
import zipfile
from typing import TYPE_CHECKING, Optional, Union

import requests
Expand All @@ -32,7 +31,7 @@
from roboflow.models.object_detection import ObjectDetectionModel
from roboflow.models.semantic_segmentation import SemanticSegmentationModel
from roboflow.util.annotations import amend_data_yaml
from roboflow.util.general import write_line
from roboflow.util.general import extract_zip, write_line
from roboflow.util.model_processor import process
from roboflow.util.versions import get_model_format, get_wrong_dependencies_versions, normalize_yolo_model_type

Expand Down Expand Up @@ -239,7 +238,7 @@ def download(self, model_format=None, location=None, overwrite: bool = False):
link = export_info["export"]["link"]

self.__download_zip(link, location, model_format)
self.__extract_zip(location, model_format)
extract_zip(location, desc=f"Extracting Dataset Version Zip to {location} in {model_format}:")
self.__reformat_yaml(location, model_format) # TODO: is roboflow-python a place to be munging yaml files?

return Dataset(self.name, self.version, model_format, os.path.abspath(location))
Expand Down Expand Up @@ -577,30 +576,6 @@ def bar_progress(current, total, width=80):
sys.stdout.write("\n")
sys.stdout.flush()

def __extract_zip(self, location, format):
"""
Extracts the contents of a downloaded ZIP file and then deletes the zipped file.

Args:
location (str): filepath of the data directory that contains the ZIP file
format (str): the format identifier string

Raises:
RuntimeError: If there is an error unzipping the file
""" # noqa: E501 // docs
desc = None if TQDM_DISABLE else f"Extracting Dataset Version Zip to {location} in {format}:"
with zipfile.ZipFile(location + "/roboflow.zip", "r") as zip_ref:
for member in tqdm(
zip_ref.infolist(),
desc=desc,
):
try:
zip_ref.extract(member, location)
except zipfile.error:
raise RuntimeError("Error unzipping download")

os.remove(location + "/roboflow.zip")

def __get_download_location(self):
"""
Get the local path to save a downloaded dataset to
Expand Down Expand Up @@ -707,4 +682,4 @@ def __str__(self):


def unwrap_version_id(version_id: str) -> str:
    """Return the bare version number from a possibly fully-qualified id.

    Accepts either a plain id (``"3"``) or a slash-qualified path
    (``"workspace/project/3"``) and returns the final path segment.
    """
    # rsplit with maxsplit=1 splits off only the last segment, avoiding
    # needless splitting of the whole path.
    return version_id if "/" not in str(version_id) else version_id.rsplit("/", maxsplit=1)[-1]
106 changes: 106 additions & 0 deletions roboflow/core/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,21 @@
import json
import os
import sys
import time
from typing import Any, Dict, List, Optional

import requests
from PIL import Image
from requests.exceptions import HTTPError
from tqdm import tqdm

from roboflow.adapters import rfapi
from roboflow.adapters.rfapi import AnnotationSaveError, ImageUploadError, RoboflowError
from roboflow.config import API_URL, APP_URL, CLIP_FEATURIZE_URL, DEMO_KEYS
from roboflow.core.project import Project
from roboflow.util import folderparser
from roboflow.util.active_learning_utils import check_box_size, clip_encode, count_comparisons
from roboflow.util.general import extract_zip as _extract_zip
from roboflow.util.image_utils import load_labelmap
from roboflow.util.model_processor import process
from roboflow.util.two_stage_utils import ocr_infer
Expand Down Expand Up @@ -662,6 +666,108 @@ def _upload_zip(
except Exception as e:
print(f"An error occured when uploading the model: {e}")

def search_export(
self,
query: str,
format: str = "coco",
location: Optional[str] = None,
dataset: Optional[str] = None,
annotation_group: Optional[str] = None,
name: Optional[str] = None,
extract_zip: bool = True,
) -> str:
"""Export search results as a downloaded dataset.

Args:
query: Search query string (e.g. ``"tag:annotate"`` or ``"*"``).
format: Annotation format for the export (default ``"coco"``).
location: Local directory to save the exported dataset.
Defaults to ``./search-export-{format}``.
dataset: Limit export to a specific dataset (project) slug.
annotation_group: Limit export to a specific annotation group.
name: Optional name for the export.
extract_zip: If True (default), extract the zip and remove it.
If False, keep the zip file as-is.

Returns:
Absolute path to the extracted directory or the zip file.

Raises:
ValueError: If both *dataset* and *annotation_group* are provided.
RoboflowError: On API errors or export timeout.
"""
if dataset is not None and annotation_group is not None:
raise ValueError("dataset and annotation_group are mutually exclusive; provide only one")

if location is None:
location = f"./search-export-{format}"
location = os.path.abspath(location)

# 1. Start the export
export_id = rfapi.start_search_export(
api_key=self.__api_key,
workspace_url=self.url,
query=query,
format=format,
dataset=dataset,
annotation_group=annotation_group,
name=name,
)
print(f"Export started (id={export_id}). Polling for completion...")

# 2. Poll until ready
timeout = 600
poll_interval = 5
elapsed = 0
while elapsed < timeout:
status = rfapi.get_search_export(
api_key=self.__api_key,
workspace_url=self.url,
export_id=export_id,
)
if status.get("ready"):
break
time.sleep(poll_interval)
elapsed += poll_interval
else:
raise RoboflowError(f"Search export timed out after {timeout}s")

download_url = status["link"]

# 3. Download zip
if not os.path.exists(location):
os.makedirs(location)

zip_path = os.path.join(location, "roboflow.zip")
response = requests.get(download_url, stream=True)
try:
response.raise_for_status()
except HTTPError as e:
raise RoboflowError(f"Failed to download search export: {e}")

total_length = response.headers.get("content-length")
try:
total_kib = int(total_length) // 1024 + 1 if total_length is not None else None
except (TypeError, ValueError):
total_kib = None
with open(zip_path, "wb") as f:
for chunk in tqdm(
response.iter_content(chunk_size=1024),
desc=f"Downloading search export to {location}",
total=total_kib,
):
if chunk:
f.write(chunk)
f.flush()

if extract_zip:
_extract_zip(location, desc=f"Extracting search export to {location}")
print(f"Search export extracted to {location}")
return location
else:
print(f"Search export saved to {zip_path}")
return zip_path

def __str__(self):
projects = self.projects()
json_value = {"name": self.name, "url": self.url, "projects": projects}
Expand Down
29 changes: 29 additions & 0 deletions roboflow/roboflowpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,21 @@ def infer(args):
print(group)


def search_export(args):
    """CLI entry point for the ``search-export`` subcommand."""
    workspace = roboflow.Roboflow().workspace(args.workspace)
    export_kwargs = {
        "query": args.query,
        "format": args.format,
        "location": args.location,
        "dataset": args.dataset,
        "annotation_group": args.annotation_group,
        "name": args.name,
        "extract_zip": not args.no_extract,
    }
    print(workspace.search_export(**export_kwargs))


def _argparser():
parser = argparse.ArgumentParser(description="Welcome to the roboflow CLI: computer vision at your fingertips 🪄")
subparsers = parser.add_subparsers(title="subcommands")
Expand All @@ -218,6 +233,7 @@ def _argparser():
_add_run_video_inference_api_parser(subparsers)
deployment.add_deployment_parser(subparsers)
_add_whoami_parser(subparsers)
_add_search_export_parser(subparsers)

parser.add_argument("-v", "--version", help="show version info", action="store_true")
parser.set_defaults(func=show_version)
Expand Down Expand Up @@ -594,6 +610,19 @@ def _add_get_workspace_project_version_parser(subparsers):
workspace_project_version_parser.set_defaults(func=get_workspace_project_version)


def _add_search_export_parser(subparsers):
    """Register the ``search-export`` subcommand and its arguments."""
    sub = subparsers.add_parser("search-export", help="Export search results as a dataset")
    sub.add_argument("query", help="Search query (e.g. 'tag:annotate' or '*')")
    # Single-letter option flags, declared data-driven to keep them uniform.
    # Registration order is preserved so --help output is unchanged.
    optional_flags = [
        ("-f", "format", "coco", "Annotation format (default: coco)"),
        ("-w", "workspace", None, "Workspace url or id (uses default workspace if not specified)"),
        ("-l", "location", None, "Local directory to save the export"),
        ("-d", "dataset", None, "Limit export to a specific dataset (project slug)"),
        ("-g", "annotation_group", None, "Limit export to a specific annotation group"),
        ("-n", "name", None, "Optional name for the export"),
    ]
    for flag, dest, default, help_text in optional_flags:
        sub.add_argument(flag, dest=dest, default=default, help=help_text)
    sub.add_argument("--no-extract", dest="no_extract", action="store_true", help="Skip extraction, keep the zip file")
    sub.set_defaults(func=search_export)


def _add_login_parser(subparsers):
login_parser = subparsers.add_parser("login", help="Log in to Roboflow")
login_parser.add_argument(
Expand Down
25 changes: 25 additions & 0 deletions roboflow/util/general.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import os
import sys
import time
import zipfile
from random import random

from tqdm import tqdm

from roboflow.config import TQDM_DISABLE


def write_line(line):
sys.stdout.write("\r" + line)
Expand Down Expand Up @@ -40,3 +46,22 @@ def __call__(self, func, *args, **kwargs):
self.retries += 1
else:
raise


def extract_zip(location: str, desc: str = "Extracting"):
    """Extract ``roboflow.zip`` inside *location* and remove the archive.

    Args:
        location: Directory containing ``roboflow.zip``; members are
            extracted into this same directory.
        desc: Description shown in the tqdm progress bar.

    Raises:
        RuntimeError: If any archive member fails to extract; the underlying
            ``zipfile`` error is chained as ``__cause__``.
    """
    zip_path = os.path.join(location, "roboflow.zip")
    # NOTE(review): TQDM_DISABLE only blanks the description here — the bar
    # itself still renders. Passing disable=TQDM_DISABLE to tqdm may be the
    # real intent; confirm before changing.
    tqdm_desc = None if TQDM_DISABLE else desc
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        # Extract member-by-member so tqdm can report per-file progress.
        for member in tqdm(zip_ref.infolist(), desc=tqdm_desc):
            try:
                zip_ref.extract(member, location)
            except zipfile.error as err:
                # Chain the cause so the original zipfile failure survives.
                raise RuntimeError("Error unzipping download") from err

    # Remove the archive only after a fully successful extraction.
    os.remove(zip_path)
Loading