diff --git a/docs/implementing_data_loaders.md b/docs/implementing_data_loaders.md index 3699532..38f791e 100644 --- a/docs/implementing_data_loaders.md +++ b/docs/implementing_data_loaders.md @@ -305,71 +305,207 @@ def _get_table_metadata(self, table: pa.Table, duration: float, batch_count: int ## Testing -### Integration Test Structure +### Generalized Test Infrastructure -Create integration tests in `tests/integration/test_{system}_loader.py`: +The project uses a generalized test infrastructure that eliminates code duplication across loader tests. Instead of writing standalone tests for each loader, you inherit from shared base test classes. + +### Architecture + +``` +tests/integration/loaders/ +├── conftest.py # Base classes and fixtures +├── test_base_loader.py # 7 core tests (all loaders inherit) +├── test_base_streaming.py # 5 streaming tests (for loaders with reorg support) +└── backends/ + ├── test_postgresql.py # PostgreSQL-specific config + tests + ├── test_redis.py # Redis-specific config + tests + └── test_example.py # Your loader tests here +``` + +### Step 1: Create Configuration Fixture + +Add your loader's configuration fixture to `tests/conftest.py`: ```python -# tests/integration/test_example_loader.py +@pytest.fixture(scope='session') +def example_test_config(request): + """Example loader configuration from testcontainer or environment""" + # Use testcontainers for CI, or fall back to environment variables + if TESTCONTAINERS_AVAILABLE and USE_TESTCONTAINERS: + # Set up testcontainer (if applicable) + example_container = request.getfixturevalue('example_container') + return { + 'host': example_container.get_container_host_ip(), + 'port': example_container.get_exposed_port(5432), + 'database': 'test_db', + 'user': 'test_user', + 'password': 'test_pass', + } + else: + # Fall back to environment variables + return { + 'host': os.getenv('EXAMPLE_HOST', 'localhost'), + 'port': int(os.getenv('EXAMPLE_PORT', '5432')), + 'database': os.getenv('EXAMPLE_DB', 'test_db'), + 'user': os.getenv('EXAMPLE_USER', 'test_user'), + 'password': os.getenv('EXAMPLE_PASSWORD', 'test_pass'), + } +``` + +### Step 2: Create Test Configuration Class + +Create `tests/integration/loaders/backends/test_example.py`: +```python +""" +Example loader integration tests using generalized test infrastructure. 
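+
+Inherits the shared core suite (BaseLoaderTests) and, for loaders with streaming
+support, the streaming suite (BaseStreamingTests); intended as a template for new
+loader backends.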
+""" + +from typing import Any, Dict, List, Optional import pytest -import pyarrow as pa -from src.amp.loaders.base import LoadMode + from src.amp.loaders.implementations.example_loader import ExampleLoader +from tests.integration.loaders.conftest import LoaderTestConfig +from tests.integration.loaders.test_base_loader import BaseLoaderTests +from tests.integration.loaders.test_base_streaming import BaseStreamingTests + + +class ExampleTestConfig(LoaderTestConfig): + """Example-specific test configuration""" + + loader_class = ExampleLoader + config_fixture_name = 'example_test_config' + + # Declare loader capabilities + supports_overwrite = True + supports_streaming = True # Set to False if no streaming support + supports_multi_network = True # For blockchain loaders with reorg + supports_null_values = True + + def get_row_count(self, loader: ExampleLoader, table_name: str) -> int: + """Get row count from table""" + # Implement using your loader's API + return loader._connection.query(f"SELECT COUNT(*) FROM {table_name}")[0]['count'] + + def query_rows( + self, + loader: ExampleLoader, + table_name: str, + where: Optional[str] = None, + order_by: Optional[str] = None + ) -> List[Dict[str, Any]]: + """Query rows from table""" + query = f"SELECT * FROM {table_name}" + if where: + query += f" WHERE {where}" + if order_by: + query += f" ORDER BY {order_by}" + return loader._connection.query(query) + + def cleanup_table(self, loader: ExampleLoader, table_name: str) -> None: + """Drop table""" + loader._connection.execute(f"DROP TABLE IF EXISTS {table_name}") + + def get_column_names(self, loader: ExampleLoader, table_name: str) -> List[str]: + """Get column names from table""" + result = loader._connection.query( + f"SELECT column_name FROM information_schema.columns WHERE table_name = '{table_name}'" + ) + return [row['column_name'] for row in result] + + +# Core tests - ALL loaders must inherit these +class TestExampleCore(BaseLoaderTests): + """Inherits 7 core tests: connection, context manager, batching, modes, null handling, errors""" + config = ExampleTestConfig() -@pytest.fixture -def example_config(): - return { - 'host': 'localhost', - 'port': 5432, - 'database': 'test_db', - 'user': 'test_user', - 'password': 'test_pass' - } -@pytest.fixture -def test_data(): - return pa.Table.from_pydict({ - 'id': [1, 2, 3], - 'name': ['a', 'b', 'c'], - 'value': [1.0, 2.0, 3.0] - }) +# Streaming tests - Only for loaders with streaming/reorg support +class TestExampleStreaming(BaseStreamingTests): + """Inherits 5 streaming tests: metadata columns, reorg deletion, overlapping ranges, multi-network, microbatch dedup""" + config = ExampleTestConfig() + +# Loader-specific tests @pytest.mark.integration @pytest.mark.example -class TestExampleLoaderIntegration: - def test_connection(self, example_config): - loader = ExampleLoader(example_config) - - loader.connect() - assert loader.is_connected - - loader.disconnect() - assert not loader.is_connected - - def test_basic_loading(self, example_config, test_data): - loader = ExampleLoader(example_config) - +class TestExampleSpecific: + """Example-specific functionality tests""" + config = ExampleTestConfig() + + def test_custom_feature(self, loader, test_table_name, cleanup_tables): + """Test example-specific functionality""" + cleanup_tables.append(test_table_name) + with loader: - result = loader.load_table(test_data, 'test_table') - + # Test your loader's unique features + result = loader.some_custom_method(test_table_name) assert result.success - assert 
result.rows_loaded == 3 - assert result.metadata['operation'] == 'load_table' - assert result.metadata['batches_processed'] > 0 +``` + +### What You Get Automatically + +By inheriting from the base test classes, you automatically get: + +**From `BaseLoaderTests` (7 core tests):** +- `test_connection` - Connection establishment and disconnection +- `test_context_manager` - Context manager functionality +- `test_batch_loading` - Basic batch loading +- `test_append_mode` - Append mode operations +- `test_overwrite_mode` - Overwrite mode operations +- `test_null_handling` - Null value handling +- `test_error_handling` - Error scenarios + +**From `BaseStreamingTests` (5 streaming tests):** +- `test_streaming_metadata_columns` - Metadata column creation +- `test_reorg_deletion` - Blockchain reorganization handling +- `test_reorg_overlapping_ranges` - Overlapping range invalidation +- `test_reorg_multi_network` - Multi-network reorg isolation +- `test_microbatch_deduplication` - Microbatch duplicate detection + +### Required LoaderTestConfig Methods + +You must implement these four methods in your `LoaderTestConfig` subclass: + +```python +def get_row_count(self, loader, table_name: str) -> int: + """Return number of rows in table""" + +def query_rows(self, loader, table_name: str, where=None, order_by=None) -> List[Dict]: + """Query and return rows as list of dicts""" + +def cleanup_table(self, loader, table_name: str) -> None: + """Drop/delete the table""" + +def get_column_names(self, loader, table_name: str) -> List[str]: + """Return list of column names""" +``` + +### Capability Flags + +Set these flags in your `LoaderTestConfig` to control which tests run: + +```python +supports_overwrite = True # Can overwrite existing data +supports_streaming = True # Supports streaming with metadata +supports_multi_network = True # Supports multi-network isolation (blockchain loaders) +supports_null_values = True # Handles NULL values correctly ``` ### Running Tests ```bash -# Run all integration tests -make test-integration +# Run all tests for your loader +uv run pytest tests/integration/loaders/backends/test_example.py -v + +# Run only core tests +uv run pytest tests/integration/loaders/backends/test_example.py::TestExampleCore -v -# Run specific loader tests -make test-example +# Run only streaming tests +uv run pytest tests/integration/loaders/backends/test_example.py::TestExampleStreaming -v -# Run with environment variables -uv run --env-file .test.env pytest tests/integration/test_example_loader.py -v +# Run specific test +uv run pytest tests/integration/loaders/backends/test_example.py::TestExampleCore::test_connection -v ``` ## Best Practices @@ -645,5 +781,3 @@ class KeyValueLoader(DataLoader[KeyValueConfig]): 'database': self.config.database } ``` - -This documentation provides everything needed to implement new data loaders efficiently and consistently! \ No newline at end of file diff --git a/src/amp/admin/async_client.py b/src/amp/admin/async_client.py new file mode 100644 index 0000000..33ae241 --- /dev/null +++ b/src/amp/admin/async_client.py @@ -0,0 +1,168 @@ +"""Async HTTP client for Amp Admin API. + +This module provides the async AdminClient class for communicating +with the Amp Admin API over HTTP using asyncio and httpx. +""" + +import os +from typing import Optional + +import httpx + +from .errors import map_error_response + + +class AsyncAdminClient: + """Async HTTP client for Amp Admin API. 
+ + Provides access to Admin API endpoints through sub-clients for + datasets, jobs, and schema operations using async/await. + + Args: + base_url: Base URL for Admin API (e.g., 'http://localhost:8080') + auth_token: Optional Bearer token for authentication (highest priority) + auth: If True, load auth token from ~/.amp/cache (shared with TS CLI) + + Authentication Priority (highest to lowest): + 1. Explicit auth_token parameter + 2. AMP_AUTH_TOKEN environment variable + 3. auth=True - reads from ~/.amp/cache/amp_cli_auth + + Example: + >>> # Use amp auth from file + >>> async with AsyncAdminClient('http://localhost:8080', auth=True) as client: + ... datasets = await client.datasets.list_all() + >>> + >>> # Use manual token + >>> async with AsyncAdminClient('http://localhost:8080', auth_token='your-token') as client: + ... job = await client.jobs.get(123) + """ + + def __init__(self, base_url: str, auth_token: Optional[str] = None, auth: bool = False): + """Initialize async Admin API client. + + Args: + base_url: Base URL for Admin API (e.g., 'http://localhost:8080') + auth_token: Optional Bearer token for authentication + auth: If True, load auth token from ~/.amp/cache + + Raises: + ValueError: If both auth=True and auth_token are provided + """ + if auth and auth_token: + raise ValueError('Cannot specify both auth=True and auth_token. Choose one authentication method.') + + self.base_url = base_url.rstrip('/') + + # Resolve auth token provider with priority: explicit param > env var > auth file + self._get_token = None + if auth_token: + # Priority 1: Explicit auth_token parameter (static token) + self._get_token = lambda: auth_token + elif os.getenv('AMP_AUTH_TOKEN'): + # Priority 2: AMP_AUTH_TOKEN environment variable (static token) + env_token = os.getenv('AMP_AUTH_TOKEN') + self._get_token = lambda: env_token + elif auth: + # Priority 3: Load from ~/.amp-cli-config/amp_cli_auth (auto-refreshing) + from amp.auth import AuthService + + auth_service = AuthService() + self._get_token = auth_service.get_token # Callable that auto-refreshes + + # Create async HTTP client (no auth header yet - will be added per-request) + self._http = httpx.AsyncClient( + base_url=self.base_url, + timeout=30.0, + follow_redirects=True, + ) + + async def _request( + self, method: str, path: str, json: Optional[dict] = None, params: Optional[dict] = None, **kwargs + ) -> httpx.Response: + """Make async HTTP request with error handling. + + Args: + method: HTTP method (GET, POST, DELETE, etc.) + path: API endpoint path (e.g., '/datasets') + json: Optional JSON request body + params: Optional query parameters + **kwargs: Additional arguments passed to httpx.request() + + Returns: + HTTP response object + + Raises: + AdminAPIError: If the API returns an error response + """ + # Add auth header dynamically (auto-refreshes if needed) + headers = kwargs.get('headers', {}) + if self._get_token: + headers['Authorization'] = f'Bearer {self._get_token()}' + kwargs['headers'] = headers + + response = await self._http.request(method, path, json=json, params=params, **kwargs) + + # Handle error responses + if response.status_code >= 400: + try: + error_data = response.json() + raise map_error_response(response.status_code, error_data) + except ValueError: + # Response is not JSON, fall back to generic HTTP error + response.raise_for_status() + + return response + + @property + def datasets(self): + """Access async datasets client. 
+ + Returns: + AsyncDatasetsClient for dataset operations + """ + from .async_datasets import AsyncDatasetsClient + + return AsyncDatasetsClient(self) + + @property + def jobs(self): + """Access async jobs client. + + Returns: + AsyncJobsClient for job operations + """ + from .async_jobs import AsyncJobsClient + + return AsyncJobsClient(self) + + @property + def schema(self): + """Access async schema client. + + Returns: + AsyncSchemaClient for schema operations + """ + from .async_schema import AsyncSchemaClient + + return AsyncSchemaClient(self) + + async def close(self): + """Close the HTTP client and release resources. + + Example: + >>> client = AsyncAdminClient('http://localhost:8080') + >>> try: + ... datasets = await client.datasets.list_all() + ... finally: + ... await client.close() + """ + await self._http.aclose() + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() diff --git a/src/amp/admin/async_datasets.py b/src/amp/admin/async_datasets.py new file mode 100644 index 0000000..f87668c --- /dev/null +++ b/src/amp/admin/async_datasets.py @@ -0,0 +1,244 @@ +"""Async datasets client for Admin API. + +This module provides the AsyncDatasetsClient class for managing datasets, +including registration, deployment, versioning, and manifest operations. +""" + +from typing import TYPE_CHECKING, Dict, Optional + +from amp.utils.manifest_inspector import describe_manifest + +from . import models + +if TYPE_CHECKING: + from .async_client import AsyncAdminClient + + +class AsyncDatasetsClient: + """Async client for dataset operations. + + Provides async methods for registering, deploying, listing, and managing datasets + through the Admin API. + + Args: + admin_client: Parent AsyncAdminClient instance + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... datasets = await client.datasets.list_all() + """ + + def __init__(self, admin_client: 'AsyncAdminClient'): + """Initialize async datasets client. + + Args: + admin_client: Parent AsyncAdminClient instance + """ + self._admin = admin_client + + async def register(self, namespace: str, name: str, version: str, manifest: dict) -> None: + """Register a dataset manifest. + + Registers a new dataset configuration in the server's local registry. + The manifest defines tables, dependencies, and extraction logic. + + Args: + namespace: Dataset namespace (e.g., '_') + name: Dataset name + version: Semantic version (e.g., '1.0.0') or tag ('latest', 'dev') + manifest: Dataset manifest dict (kind='manifest') + + Raises: + InvalidManifestError: If manifest is invalid + DependencyValidationError: If dependencies are invalid + ManifestRegistrationError: If registration fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... await client.datasets.register('_', 'my_dataset', '1.0.0', manifest) + """ + request_data = models.RegisterRequest(namespace=namespace, name=name, version=version, manifest=manifest) + + await self._admin._request('POST', '/datasets', json=request_data.model_dump(mode='json', exclude_none=True)) + + async def deploy( + self, + namespace: str, + name: str, + revision: str, + end_block: Optional[str] = None, + parallelism: Optional[int] = None, + ) -> models.DeployResponse: + """Deploy a dataset version. + + Triggers data extraction for the specified dataset version. 
+ + Args: + namespace: Dataset namespace + name: Dataset name + revision: Version tag ('latest', 'dev', '1.0.0', etc.) + end_block: Optional end block ('latest', '-100', '1000000', or null) + parallelism: Optional number of parallel workers + + Returns: + DeployResponse with job_id + + Raises: + DatasetNotFoundError: If dataset/version not found + SchedulerError: If deployment fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... response = await client.datasets.deploy('_', 'my_dataset', '1.0.0', parallelism=4) + ... print(f'Job ID: {response.job_id}') + """ + path = f'/datasets/{namespace}/{name}/versions/{revision}/deploy' + + # Build request body (POST requires JSON body, not query params) + body = {} + if end_block is not None: + body['end_block'] = end_block + if parallelism is not None: + body['parallelism'] = parallelism + + response = await self._admin._request('POST', path, json=body if body else {}) + return models.DeployResponse.model_validate(response.json()) + + async def list_all(self) -> models.DatasetsResponse: + """List all registered datasets. + + Returns all datasets across all namespaces with version information. + + Returns: + DatasetsResponse with list of datasets + + Raises: + ListAllDatasetsError: If listing fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... datasets = await client.datasets.list_all() + ... for ds in datasets.datasets: + ... print(f'{ds.namespace}/{ds.name}: {ds.latest_version}') + """ + response = await self._admin._request('GET', '/datasets') + return models.DatasetsResponse.model_validate(response.json()) + + async def get_versions(self, namespace: str, name: str) -> models.VersionsResponse: + """List all versions of a dataset. + + Returns version information including semantic versions and special tags. + + Args: + namespace: Dataset namespace + name: Dataset name + + Returns: + VersionsResponse with version list + + Raises: + DatasetNotFoundError: If dataset not found + ListDatasetVersionsError: If listing fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... versions = await client.datasets.get_versions('_', 'eth_firehose') + ... print(f'Latest: {versions.special_tags.latest}') + """ + path = f'/datasets/{namespace}/{name}/versions' + response = await self._admin._request('GET', path) + return models.VersionsResponse.model_validate(response.json()) + + async def get_version(self, namespace: str, name: str, revision: str) -> models.VersionInfo: + """Get detailed information about a specific dataset version. + + Args: + namespace: Dataset namespace + name: Dataset name + revision: Version tag or semantic version + + Returns: + VersionInfo with dataset details + + Raises: + DatasetNotFoundError: If dataset/version not found + GetDatasetVersionError: If retrieval fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... info = await client.datasets.get_version('_', 'eth_firehose', '1.0.0') + ... print(f'Kind: {info.kind}') + """ + path = f'/datasets/{namespace}/{name}/versions/{revision}' + response = await self._admin._request('GET', path) + return models.VersionInfo.model_validate(response.json()) + + async def get_manifest(self, namespace: str, name: str, revision: str) -> dict: + """Get the manifest for a specific dataset version. 
+ + Args: + namespace: Dataset namespace + name: Dataset name + revision: Version tag or semantic version + + Returns: + Manifest dict + + Raises: + DatasetNotFoundError: If dataset/version not found + GetManifestError: If retrieval fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... manifest = await client.datasets.get_manifest('_', 'eth_firehose', '1.0.0') + ... print(manifest['kind']) + """ + path = f'/datasets/{namespace}/{name}/versions/{revision}/manifest' + response = await self._admin._request('GET', path) + return response.json() + + async def describe( + self, namespace: str, name: str, revision: str = 'latest' + ) -> Dict[str, list[Dict[str, str | bool]]]: + """Get a structured summary of tables and columns in a dataset. + + Returns a dictionary mapping table names to lists of column information, + making it easy to programmatically inspect the dataset schema. + + Args: + namespace: Dataset namespace + name: Dataset name + revision: Version tag (default: 'latest') + + Returns: + dict: Mapping of table names to column information. + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... schema = await client.datasets.describe('_', 'eth_firehose', 'latest') + ... for table_name, columns in schema.items(): + ... print(f"Table: {table_name}") + """ + manifest = await self.get_manifest(namespace, name, revision) + return describe_manifest(manifest) + + async def delete(self, namespace: str, name: str) -> None: + """Delete all versions and metadata for a dataset. + + Removes all manifest links and version tags for the dataset. + Orphaned manifests (not referenced by other datasets) are also deleted. + + Args: + namespace: Dataset namespace + name: Dataset name + + Raises: + InvalidPathError: If namespace/name invalid + UnlinkDatasetManifestsError: If deletion fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... await client.datasets.delete('_', 'my_old_dataset') + """ + path = f'/datasets/{namespace}/{name}' + await self._admin._request('DELETE', path) diff --git a/src/amp/admin/async_jobs.py b/src/amp/admin/async_jobs.py new file mode 100644 index 0000000..c811e09 --- /dev/null +++ b/src/amp/admin/async_jobs.py @@ -0,0 +1,187 @@ +"""Async jobs client for Admin API. + +This module provides the AsyncJobsClient class for monitoring and managing +extraction jobs using async/await. +""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Optional + +from . import models + +if TYPE_CHECKING: + from .async_client import AsyncAdminClient + + +class AsyncJobsClient: + """Async client for job operations. + + Provides async methods for monitoring, managing, and waiting for extraction jobs. + + Args: + admin_client: Parent AsyncAdminClient instance + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... job = await client.jobs.get(123) + ... print(f'Status: {job.status}') + """ + + def __init__(self, admin_client: 'AsyncAdminClient'): + """Initialize async jobs client. + + Args: + admin_client: Parent AsyncAdminClient instance + """ + self._admin = admin_client + + async def get(self, job_id: int) -> models.JobInfo: + """Get job information by ID. + + Args: + job_id: Job ID to retrieve + + Returns: + JobInfo with job details + + Raises: + JobNotFoundError: If job not found + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... job = await client.jobs.get(123) + ... 
print(f'Status: {job.status}') + """ + path = f'/jobs/{job_id}' + response = await self._admin._request('GET', path) + return models.JobInfo.model_validate(response.json()) + + async def list(self, limit: int = 50, last_job_id: Optional[int] = None) -> models.JobsResponse: + """List jobs with pagination. + + Args: + limit: Maximum number of jobs to return (default: 50, max: 1000) + last_job_id: Cursor from previous page's next_cursor field + + Returns: + JobsResponse with jobs and optional next_cursor + + Raises: + ListJobsError: If listing fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... response = await client.jobs.list(limit=100) + ... for job in response.jobs: + ... print(f'{job.id}: {job.status}') + """ + params = {'limit': limit} + if last_job_id is not None: + params['last_job_id'] = last_job_id + + response = await self._admin._request('GET', '/jobs', params=params) + return models.JobsResponse.model_validate(response.json()) + + async def wait_for_completion( + self, job_id: int, poll_interval: int = 5, timeout: Optional[int] = None + ) -> models.JobInfo: + """Poll job until completion or timeout. + + Continuously polls the job status until it reaches a terminal state + (Completed, Failed, or Stopped). Uses asyncio.sleep for non-blocking waits. + + Args: + job_id: Job ID to monitor + poll_interval: Seconds between status checks (default: 5) + timeout: Optional timeout in seconds (default: None = infinite) + + Returns: + Final JobInfo when job completes + + Raises: + JobNotFoundError: If job not found + TimeoutError: If timeout is reached before completion + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... deploy_resp = await client.datasets.deploy('_', 'my_dataset', '1.0.0') + ... final_job = await client.jobs.wait_for_completion(deploy_resp.job_id) + ... print(f'Final status: {final_job.status}') + """ + elapsed = 0.0 + terminal_states = {'Completed', 'Failed', 'Stopped'} + + while True: + job = await self.get(job_id) + + # Check if job reached terminal state + if job.status in terminal_states: + return job + + # Check timeout + if timeout is not None and elapsed >= timeout: + raise TimeoutError( + f'Job {job_id} did not complete within {timeout} seconds. Current status: {job.status}' + ) + + # Wait before next poll (non-blocking) + await asyncio.sleep(poll_interval) + elapsed += poll_interval + + async def stop(self, job_id: int) -> None: + """Stop a running job. + + Requests the job to stop gracefully. The job will transition through + StopRequested and Stopping states before reaching Stopped. + + Args: + job_id: Job ID to stop + + Raises: + JobNotFoundError: If job not found + JobStopError: If stop request fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... await client.jobs.stop(123) + """ + path = f'/jobs/{job_id}/stop' + await self._admin._request('POST', path) + + async def delete(self, job_id: int) -> None: + """Delete a job in terminal state. + + Only jobs in terminal states (Completed, Failed, Stopped) can be deleted. + + Args: + job_id: Job ID to delete + + Raises: + JobNotFoundError: If job not found + JobDeleteError: If job is not in terminal state or deletion fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... 
await client.jobs.delete(123) + """ + path = f'/jobs/{job_id}' + await self._admin._request('DELETE', path) + + async def delete_many(self, job_ids: list[int]) -> None: + """Delete multiple jobs in bulk. + + All specified jobs must be in terminal states. + + Args: + job_ids: List of job IDs to delete + + Raises: + JobsDeleteError: If any deletion fails + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... await client.jobs.delete_many([123, 124, 125]) + """ + await self._admin._request('DELETE', '/jobs', json={'job_ids': job_ids}) diff --git a/src/amp/admin/async_schema.py b/src/amp/admin/async_schema.py new file mode 100644 index 0000000..74297eb --- /dev/null +++ b/src/amp/admin/async_schema.py @@ -0,0 +1,65 @@ +"""Async schema client for Admin API. + +This module provides the AsyncSchemaClient class for querying output schemas +of SQL queries without executing them using async/await. +""" + +from typing import TYPE_CHECKING + +from . import models + +if TYPE_CHECKING: + from .async_client import AsyncAdminClient + + +class AsyncSchemaClient: + """Async client for schema operations. + + Provides async methods for validating SQL queries and determining output schemas + using DataFusion's query planner. + + Args: + admin_client: Parent AsyncAdminClient instance + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... schema = await client.schema.get_output_schema('SELECT * FROM eth.blocks', True) + """ + + def __init__(self, admin_client: 'AsyncAdminClient'): + """Initialize async schema client. + + Args: + admin_client: Parent AsyncAdminClient instance + """ + self._admin = admin_client + + async def get_output_schema(self, sql_query: str, is_sql_dataset: bool = True) -> models.OutputSchemaResponse: + """Get output schema for a SQL query. + + Validates the query and returns the Arrow schema that would be produced, + without actually executing the query. + + Args: + sql_query: SQL query to analyze + is_sql_dataset: Whether this is for a SQL dataset (default: True) + + Returns: + OutputSchemaResponse with Arrow schema + + Raises: + GetOutputSchemaError: If schema analysis fails + DependencyValidationError: If query references invalid dependencies + + Example: + >>> async with AsyncAdminClient('http://localhost:8080') as client: + ... schema_resp = await client.schema.get_output_schema( + ... 'SELECT block_num, hash FROM eth.blocks WHERE block_num > 1000000', + ... is_sql_dataset=True + ... ) + ... print(schema_resp.schema) + """ + request_data = models.OutputSchemaRequest(sql_query=sql_query, is_sql_dataset=is_sql_dataset) + + response = await self._admin._request('POST', '/schema', json=request_data.model_dump(mode='json')) + return models.OutputSchemaResponse.model_validate(response.json()) diff --git a/src/amp/async_client.py b/src/amp/async_client.py new file mode 100644 index 0000000..9b5ecc1 --- /dev/null +++ b/src/amp/async_client.py @@ -0,0 +1,718 @@ +"""Async Flight SQL client with data loading capabilities. + +This module provides the AsyncAmpClient class for async operations +with the Flight SQL server and Admin/Registry APIs. + +The async client is optimized for: +- Non-blocking HTTP API calls (Admin, Registry) +- Concurrent operations using asyncio +- Streaming data with async iteration + +Note: Flight SQL (gRPC) operations currently remain synchronous as PyArrow's +Flight client doesn't have native async support. For streaming operations, +consider using run_in_executor or the sync Client. 
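+
+Example (illustrative sketch, not from the test suite; assumes a local Flight SQL
+endpoint at grpc://localhost:1602, an Admin API at http://localhost:8080, and
+cached credentials for auth=True):
+
+    >>> # Admin API calls are natively async and can be awaited concurrently,
+    >>> # while Flight SQL queries run in a thread pool under the hood.
+    >>> async with AsyncAmpClient(
+    ...     query_url='grpc://localhost:1602',
+    ...     admin_url='http://localhost:8080',
+    ...     auth=True
+    ... ) as client:
+    ...     datasets, jobs = await asyncio.gather(
+    ...         client.datasets.list_all(),
+    ...         client.jobs.list(limit=10),
+    ...     )
+    ...     table = await client.sql('SELECT * FROM eth.blocks LIMIT 10').to_arrow()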
+""" + +import asyncio +import logging +import os +from typing import AsyncIterator, Dict, Iterator, List, Optional, Union + +import pyarrow as pa +from google.protobuf.any_pb2 import Any +from pyarrow import flight +from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory + +from . import FlightSql_pb2 +from .config.connection_manager import ConnectionManager +from .config.label_manager import LabelManager +from .loaders.registry import create_loader, get_available_loaders +from .loaders.types import LabelJoinConfig, LoadConfig, LoadMode, LoadResult +from .streaming import ( + ReorgAwareStream, + ResumeWatermark, + StreamingResultIterator, +) + + +class AuthMiddleware(ClientMiddleware): + """Flight middleware to add Bearer token authentication header.""" + + def __init__(self, get_token): + """Initialize auth middleware. + + Args: + get_token: Callable that returns the current access token + """ + self.get_token = get_token + + def sending_headers(self): + """Add Authorization header to outgoing requests.""" + return {'authorization': f'Bearer {self.get_token()}'} + + +class AuthMiddlewareFactory(ClientMiddlewareFactory): + """Factory for creating auth middleware instances.""" + + def __init__(self, get_token): + """Initialize auth middleware factory. + + Args: + get_token: Callable that returns the current access token + """ + self.get_token = get_token + + def start_call(self, info): + """Create auth middleware for each call.""" + return AuthMiddleware(self.get_token) + + +class AsyncQueryBuilder: + """Async chainable query builder for data loading operations. + + Provides async versions of query operations. + """ + + def __init__(self, client: 'AsyncAmpClient', query: str): + self.client = client + self.query = query + self._result_cache = None + self._dependencies: Dict[str, str] = {} + self.logger = logging.getLogger(__name__) + + async def load( + self, + connection: str, + destination: str, + config: Dict[str, any] = None, + label_config: Optional[LabelJoinConfig] = None, + **kwargs, + ) -> Union[LoadResult, AsyncIterator[LoadResult]]: + """ + Async load query results to specified destination. + + Note: The actual data loading operations run synchronously in a thread + pool executor since PyArrow Flight doesn't support native async. + + Args: + connection: Named connection or connection name for auto-discovery + destination: Target destination (table name, key, path, etc.) 
+ config: Inline configuration dict (alternative to connection) + label_config: Optional LabelJoinConfig for joining with label data + **kwargs: Additional loader-specific options + + Returns: + LoadResult or async iterator of LoadResults + """ + # Handle streaming mode + if kwargs.get('stream', False): + kwargs.pop('stream') + streaming_query = self._ensure_streaming_query(self.query) + return await self.client.query_and_load_streaming( + query=streaming_query, + destination=destination, + connection_name=connection, + config=config, + label_config=label_config, + **kwargs, + ) + + # Validate that parallel_config is only used with stream=True + if kwargs.get('parallel_config'): + raise ValueError('parallel_config requires stream=True') + + kwargs.setdefault('read_all', False) + + return await self.client.query_and_load( + query=self.query, + destination=destination, + connection_name=connection, + config=config, + label_config=label_config, + **kwargs, + ) + + def _ensure_streaming_query(self, query: str) -> str: + """Ensure query has SETTINGS stream = true""" + query = query.strip().rstrip(';') + if 'SETTINGS stream = true' not in query.upper(): + query += ' SETTINGS stream = true' + return query + + async def stream(self) -> AsyncIterator[pa.RecordBatch]: + """Stream query results as Arrow batches asynchronously.""" + self.logger.debug(f'Starting async stream for query: {self.query[:50]}...') + # Run synchronous Flight SQL operation in executor + loop = asyncio.get_event_loop() + batches = await loop.run_in_executor(None, lambda: list(self.client.get_sql_sync(self.query, read_all=False))) + for batch in batches: + yield batch + + async def to_arrow(self) -> pa.Table: + """Get query results as Arrow table asynchronously.""" + if self._result_cache is None: + self.logger.debug(f'Executing query for Arrow table: {self.query[:50]}...') + loop = asyncio.get_event_loop() + self._result_cache = await loop.run_in_executor( + None, lambda: self.client.get_sql_sync(self.query, read_all=True) + ) + return self._result_cache + + async def to_manifest(self, table_name: str, network: str = 'mainnet') -> dict: + """Generate a dataset manifest from this query asynchronously. + + Automatically fetches the Arrow schema using the Admin API /schema endpoint. + Requires the Client to be initialized with admin_url. + + Args: + table_name: Name for the table in the manifest + network: Network name (default: 'mainnet') + + Returns: + Complete manifest dict ready for registration + """ + # Get schema from Admin API + schema_response = await self.client.schema.get_output_schema(self.query, is_sql_dataset=True) + + # Build manifest structure + manifest = { + 'kind': 'manifest', + 'dependencies': self._dependencies, + 'tables': { + table_name: { + 'input': {'sql': self.query}, + 'schema': schema_response.schema_, + 'network': network, + } + }, + 'functions': {}, + } + return manifest + + def with_dependency(self, alias: str, reference: str) -> 'AsyncQueryBuilder': + """Add a dataset dependency for manifest generation.""" + self._dependencies[alias] = reference + return self + + def __repr__(self): + return f"AsyncQueryBuilder(query='{self.query[:50]}{'...' if len(self.query) > 50 else ''}')" + + +class AsyncAmpClient: + """Async Flight SQL client with data loading capabilities. + + Supports both query operations (via Flight SQL) and optional admin operations + (via async HTTP Admin API) and registry operations (via async Registry API). 
+ + The Flight SQL operations are run in a thread pool executor since PyArrow's + Flight client doesn't have native async support. HTTP operations (Admin, + Registry) are fully async. + + Args: + url: Flight SQL URL (for backward compatibility, treated as query_url) + query_url: Query endpoint URL via Flight SQL (e.g., 'grpc://localhost:1602') + admin_url: Optional Admin API URL (e.g., 'http://localhost:8080') + registry_url: Optional Registry API URL (default: staging registry) + auth_token: Optional Bearer token for authentication (highest priority) + auth: If True, load auth token from ~/.amp/cache (shared with TS CLI) + + Authentication Priority (highest to lowest): + 1. Explicit auth_token parameter + 2. AMP_AUTH_TOKEN environment variable + 3. auth=True - reads from ~/.amp/cache/amp_cli_auth + + Example: + >>> # Query with async admin operations + >>> async with AsyncAmpClient( + ... query_url='grpc://localhost:1602', + ... admin_url='http://localhost:8080', + ... auth=True + ... ) as client: + ... datasets = await client.datasets.list_all() + ... table = await client.sql("SELECT * FROM eth.blocks LIMIT 10").to_arrow() + """ + + def __init__( + self, + url: Optional[str] = None, + query_url: Optional[str] = None, + admin_url: Optional[str] = None, + registry_url: str = 'https://api.registry.amp.staging.thegraph.com', + auth_token: Optional[str] = None, + auth: bool = False, + ): + # Backward compatibility: url parameter → query_url + if url and not query_url: + query_url = url + + # Resolve auth token provider with priority: explicit param > env var > auth file + get_token = None + if auth_token: + def get_token(): + return auth_token + elif os.getenv('AMP_AUTH_TOKEN'): + env_token = os.getenv('AMP_AUTH_TOKEN') + + def get_token(): + return env_token + elif auth: + from amp.auth import AuthService + + auth_service = AuthService() + get_token = auth_service.get_token + + # Initialize Flight SQL client + if query_url: + if get_token: + middleware = [AuthMiddlewareFactory(get_token)] + self.conn = flight.connect(query_url, middleware=middleware) + else: + self.conn = flight.connect(query_url) + else: + raise ValueError('Either url or query_url must be provided for Flight SQL connection') + + # Initialize managers + self.connection_manager = ConnectionManager() + self.label_manager = LabelManager() + self.logger = logging.getLogger(__name__) + + # Store URLs and auth params for lazy initialization of async clients + self._admin_url = admin_url + self._registry_url = registry_url + self._auth_token = auth_token + self._auth = auth + + # Lazy-initialized async clients + self._admin_client = None + self._registry_client = None + + def sql(self, query: str) -> AsyncQueryBuilder: + """ + Create an async chainable query builder. 
+ + Args: + query: SQL query string + + Returns: + AsyncQueryBuilder instance for chaining operations + """ + return AsyncQueryBuilder(self, query) + + def configure_connection(self, name: str, loader: str, config: Dict[str, any]) -> None: + """Configure a named connection for reuse.""" + self.connection_manager.add_connection(name, loader, config) + + def configure_label(self, name: str, csv_path: str, binary_columns: Optional[List[str]] = None) -> None: + """Configure a label dataset from a CSV file for joining with streaming data.""" + self.label_manager.add_label(name, csv_path, binary_columns) + + def list_connections(self) -> Dict[str, str]: + """List all configured connections.""" + return self.connection_manager.list_connections() + + def get_available_loaders(self) -> List[str]: + """Get list of available data loaders.""" + return get_available_loaders() + + # Async Admin API access (optional, requires admin_url) + @property + def datasets(self): + """Access async datasets client for Admin API operations. + + Returns: + AsyncDatasetsClient for dataset registration, deployment, and management + + Raises: + ValueError: If admin_url was not provided during Client initialization + """ + if not self._admin_url: + raise ValueError( + 'Admin API not configured. Provide admin_url parameter to AsyncAmpClient() ' + 'to enable dataset management operations.' + ) + if not self._admin_client: + from amp.admin.async_client import AsyncAdminClient + + if self._auth: + self._admin_client = AsyncAdminClient(self._admin_url, auth=True) + elif self._auth_token or os.getenv('AMP_AUTH_TOKEN'): + token = self._auth_token or os.getenv('AMP_AUTH_TOKEN') + self._admin_client = AsyncAdminClient(self._admin_url, auth_token=token) + else: + self._admin_client = AsyncAdminClient(self._admin_url) + return self._admin_client.datasets + + @property + def jobs(self): + """Access async jobs client for Admin API operations. + + Returns: + AsyncJobsClient for job monitoring and management + + Raises: + ValueError: If admin_url was not provided during Client initialization + """ + if not self._admin_url: + raise ValueError( + 'Admin API not configured. Provide admin_url parameter to AsyncAmpClient() ' + 'to enable job monitoring operations.' + ) + if not self._admin_client: + from amp.admin.async_client import AsyncAdminClient + + if self._auth: + self._admin_client = AsyncAdminClient(self._admin_url, auth=True) + elif self._auth_token or os.getenv('AMP_AUTH_TOKEN'): + token = self._auth_token or os.getenv('AMP_AUTH_TOKEN') + self._admin_client = AsyncAdminClient(self._admin_url, auth_token=token) + else: + self._admin_client = AsyncAdminClient(self._admin_url) + return self._admin_client.jobs + + @property + def schema(self): + """Access async schema client for Admin API operations. + + Returns: + AsyncSchemaClient for SQL query schema analysis + + Raises: + ValueError: If admin_url was not provided during Client initialization + """ + if not self._admin_url: + raise ValueError( + 'Admin API not configured. Provide admin_url parameter to AsyncAmpClient() ' + 'to enable schema analysis operations.' 
+ ) + if not self._admin_client: + from amp.admin.async_client import AsyncAdminClient + + if self._auth: + self._admin_client = AsyncAdminClient(self._admin_url, auth=True) + elif self._auth_token or os.getenv('AMP_AUTH_TOKEN'): + token = self._auth_token or os.getenv('AMP_AUTH_TOKEN') + self._admin_client = AsyncAdminClient(self._admin_url, auth_token=token) + else: + self._admin_client = AsyncAdminClient(self._admin_url) + return self._admin_client.schema + + @property + def registry(self): + """Access async registry client for Registry API operations. + + Returns: + AsyncRegistryClient for dataset discovery, search, and publishing + + Raises: + ValueError: If registry_url was not provided during Client initialization + """ + if not self._registry_url: + raise ValueError( + 'Registry API not configured. Provide registry_url parameter to AsyncAmpClient() ' + 'to enable dataset discovery and search operations.' + ) + if not self._registry_client: + from amp.registry.async_client import AsyncRegistryClient + + if self._auth: + self._registry_client = AsyncRegistryClient(self._registry_url, auth=True) + elif self._auth_token or os.getenv('AMP_AUTH_TOKEN'): + token = self._auth_token or os.getenv('AMP_AUTH_TOKEN') + self._registry_client = AsyncRegistryClient(self._registry_url, auth_token=token) + else: + self._registry_client = AsyncRegistryClient(self._registry_url) + return self._registry_client + + # Synchronous Flight SQL methods (run in executor for async context) + def get_sql_sync(self, query: str, read_all: bool = False): + """Execute SQL query and return Arrow data (synchronous). + + This is the underlying synchronous method used by async wrappers. + """ + command_query = FlightSql_pb2.CommandStatementQuery() + command_query.query = query + + any_command = Any() + any_command.Pack(command_query) + cmd = any_command.SerializeToString() + + flight_descriptor = flight.FlightDescriptor.for_command(cmd) + info = self.conn.get_flight_info(flight_descriptor) + reader = self.conn.do_get(info.endpoints[0].ticket) + + if read_all: + return reader.read_all() + else: + return self._batch_generator(reader) + + def _batch_generator(self, reader) -> Iterator[pa.RecordBatch]: + """Generate batches from Flight reader.""" + while True: + try: + chunk = reader.read_chunk() + yield chunk.data + except StopIteration: + break + + async def get_sql(self, query: str, read_all: bool = False): + """Execute SQL query asynchronously and return Arrow data. + + Runs the synchronous Flight SQL operation in a thread pool executor. + """ + loop = asyncio.get_event_loop() + if read_all: + return await loop.run_in_executor(None, lambda: self.get_sql_sync(query, read_all=True)) + else: + batches = await loop.run_in_executor(None, lambda: list(self.get_sql_sync(query, read_all=False))) + return batches + + async def query_and_load( + self, + query: str, + destination: str, + connection_name: str, + config: Optional[Dict[str, any]] = None, + label_config: Optional[LabelJoinConfig] = None, + **kwargs, + ) -> Union[LoadResult, AsyncIterator[LoadResult]]: + """Execute query and load results directly into target system asynchronously. + + Runs the data loading operation in a thread pool executor. 
+ """ + loop = asyncio.get_event_loop() + + # Run the synchronous query_and_load in executor + def sync_load(): + # Get connection configuration and determine loader type + if connection_name: + try: + connection_info = self.connection_manager.get_connection_info(connection_name) + loader_config = connection_info['config'] + loader_type = connection_info['loader'] + except ValueError as e: + self.logger.error(f'Connection error: {e}') + raise + elif config: + loader_type = config.pop('loader_type', None) + if not loader_type: + raise ValueError("When using inline config, 'loader_type' must be specified") + loader_config = config + else: + raise ValueError('Either connection_name or config must be provided') + + # Extract load options + read_all = kwargs.pop('read_all', False) + load_config = LoadConfig( + batch_size=kwargs.pop('batch_size', 10000), + mode=LoadMode(kwargs.pop('mode', 'append')), + create_table=kwargs.pop('create_table', True), + schema_evolution=kwargs.pop('schema_evolution', False), + **{k: v for k, v in kwargs.items() if k in ['max_retries', 'retry_delay']}, + ) + + for key in ['max_retries', 'retry_delay']: + kwargs.pop(key, None) + + loader_specific_kwargs = kwargs + + if read_all: + table = self.get_sql_sync(query, read_all=True) + return self._load_table_sync( + table, + loader_type, + destination, + loader_config, + load_config, + label_config=label_config, + **loader_specific_kwargs, + ) + else: + batch_stream = self.get_sql_sync(query, read_all=False) + return list( + self._load_stream_sync( + batch_stream, + loader_type, + destination, + loader_config, + load_config, + label_config=label_config, + **loader_specific_kwargs, + ) + ) + + return await loop.run_in_executor(None, sync_load) + + def _load_table_sync( + self, + table: pa.Table, + loader: str, + table_name: str, + config: Dict[str, any], + load_config: LoadConfig, + **kwargs, + ) -> LoadResult: + """Load a complete Arrow Table synchronously.""" + try: + loader_instance = create_loader(loader, config, label_manager=self.label_manager) + + with loader_instance: + return loader_instance.load_table(table, table_name, **load_config.__dict__, **kwargs) + except Exception as e: + self.logger.error(f'Failed to load table: {e}') + return LoadResult( + rows_loaded=0, + duration=0.0, + ops_per_second=0.0, + table_name=table_name, + loader_type=loader, + success=False, + error=str(e), + ) + + def _load_stream_sync( + self, + batch_stream: Iterator[pa.RecordBatch], + loader: str, + table_name: str, + config: Dict[str, any], + load_config: LoadConfig, + **kwargs, + ) -> Iterator[LoadResult]: + """Load from a stream of batches synchronously.""" + try: + loader_instance = create_loader(loader, config, label_manager=self.label_manager) + + with loader_instance: + yield from loader_instance.load_stream(batch_stream, table_name, **load_config.__dict__, **kwargs) + except Exception as e: + self.logger.error(f'Failed to load stream: {e}') + yield LoadResult( + rows_loaded=0, + duration=0.0, + ops_per_second=0.0, + table_name=table_name, + loader_type=loader, + success=False, + error=str(e), + ) + + async def query_and_load_streaming( + self, + query: str, + destination: str, + connection_name: str, + config: Optional[Dict[str, any]] = None, + label_config: Optional[LabelJoinConfig] = None, + with_reorg_detection: bool = True, + resume_watermark: Optional[ResumeWatermark] = None, + **kwargs, + ) -> AsyncIterator[LoadResult]: + """Execute a streaming query and continuously load results asynchronously. 
+ + Runs the streaming operation in a thread pool executor and yields results. + """ + loop = asyncio.get_event_loop() + + # Run streaming query synchronously and collect results + def sync_streaming(): + # Get connection configuration + if connection_name: + try: + connection_info = self.connection_manager.get_connection_info(connection_name) + loader_config = connection_info['config'] + loader_type = connection_info['loader'] + except ValueError as e: + self.logger.error(f'Connection error: {e}') + raise + elif config: + loader_type = config.pop('loader_type', None) + if not loader_type: + raise ValueError("When using inline config, 'loader_type' must be specified") + loader_config = config + else: + raise ValueError('Either connection_name or config must be provided') + + # Extract load config + load_config = LoadConfig( + batch_size=kwargs.pop('batch_size', 10000), + mode=LoadMode(kwargs.pop('mode', 'append')), + create_table=kwargs.pop('create_table', True), + schema_evolution=kwargs.pop('schema_evolution', False), + **{k: v for k, v in kwargs.items() if k in ['max_retries', 'retry_delay']}, + ) + + self.logger.info(f'Starting async streaming query to {loader_type}:{destination}') + + loader_instance = create_loader(loader_type, loader_config, label_manager=self.label_manager) + + results = [] + + try: + # Execute streaming query with Flight SQL + command_query = FlightSql_pb2.CommandStatementQuery() + command_query.query = query + + any_command = Any() + any_command.Pack(command_query) + cmd = any_command.SerializeToString() + + flight_descriptor = flight.FlightDescriptor.for_command(cmd) + info = self.conn.get_flight_info(flight_descriptor) + reader = self.conn.do_get(info.endpoints[0].ticket) + + stream_iterator = StreamingResultIterator(reader) + + if with_reorg_detection: + stream_iterator = ReorgAwareStream(stream_iterator) + + with loader_instance: + for result in loader_instance.load_stream_continuous( + stream_iterator, destination, connection_name=connection_name, **load_config.__dict__ + ): + results.append(result) + + except Exception as e: + self.logger.error(f'Streaming query failed: {e}') + results.append( + LoadResult( + rows_loaded=0, + duration=0.0, + ops_per_second=0.0, + table_name=destination, + loader_type=loader_type, + success=False, + error=str(e), + metadata={'streaming_error': True}, + ) + ) + + return results + + results = await loop.run_in_executor(None, sync_streaming) + for result in results: + yield result + + async def close(self): + """Close all connections and release resources.""" + # Close Flight SQL connection + if hasattr(self, 'conn') and self.conn: + try: + self.conn.close() + except Exception as e: + self.logger.warning(f'Error closing Flight connection: {e}') + + # Close async admin client if initialized + if self._admin_client: + await self._admin_client.close() + + # Close async registry client if initialized + if self._registry_client: + await self._registry_client.close() + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() diff --git a/src/amp/loaders/base.py b/src/amp/loaders/base.py index cc8a9a9..967c1fd 100644 --- a/src/amp/loaders/base.py +++ b/src/amp/loaders/base.py @@ -302,8 +302,8 @@ def _try_load_batch(self, batch: pa.RecordBatch, table_name: str, **kwargs) -> L f'Please create any tables needed before running the loader. 
' ) - # Handle overwrite mode - if mode == LoadMode.OVERWRITE and hasattr(self, '_clear_table'): + # Handle overwrite mode (only if not already cleared by load_table) + if mode == LoadMode.OVERWRITE and not kwargs.get('_already_cleared') and hasattr(self, '_clear_table'): self._clear_table(table_name) # Perform the actual load @@ -347,6 +347,14 @@ def load_table(self, table: pa.Table, table_name: str, **kwargs) -> LoadResult: start_time = time.time() batch_size = kwargs.get('batch_size', getattr(self, 'batch_size', 10000)) + # Handle overwrite mode ONCE before processing batches + mode = kwargs.get('mode', LoadMode.APPEND) + if mode == LoadMode.OVERWRITE and hasattr(self, '_clear_table'): + self._clear_table(table_name) + # Prevent subsequent batch loads from clearing again + kwargs = kwargs.copy() + kwargs['_already_cleared'] = True + rows_loaded = 0 batch_count = 0 errors = [] @@ -374,7 +382,7 @@ def load_table(self, table: pa.Table, table_name: str, **kwargs) -> LoadResult: loader_type=self.__class__.__name__.replace('Loader', '').lower(), success=len(errors) == 0, error='; '.join(errors[:3]) if errors else None, - metadata=self._get_table_metadata(table, duration, batch_count, **kwargs), + metadata=self._get_table_metadata(table, duration, batch_count, table_name=table_name, **kwargs), ) except Exception as e: diff --git a/src/amp/loaders/implementations/deltalake_loader.py b/src/amp/loaders/implementations/deltalake_loader.py index e609d09..925b800 100644 --- a/src/amp/loaders/implementations/deltalake_loader.py +++ b/src/amp/loaders/implementations/deltalake_loader.py @@ -127,6 +127,11 @@ def _detect_storage_backend(self) -> None: self.storage_backend = 'Unknown' self.logger.warning(f'Unknown storage backend: {parsed_path.scheme}') + @property + def partition_by(self) -> Optional[List[str]]: + """Convenient access to partition_by configuration""" + return self.config.partition_by + def _get_required_config_fields(self) -> list[str]: """Return required configuration fields""" return ['table_path'] @@ -462,10 +467,15 @@ def _get_loader_table_metadata( mode = kwargs.get('mode', LoadMode.APPEND) delta_mode = self._convert_load_mode(mode) + version = table_info.get('version', 0) + num_files = table_info.get('num_files', 0) + return { 'write_mode': delta_mode.value, - 'table_version': table_info.get('version', 0), - 'total_files': table_info.get('num_files', 0), + 'table_version': version, + 'delta_version': version, # Alias for compatibility + 'total_files': num_files, + 'files_added': num_files, # Alias for compatibility 'total_size_bytes': table_info.get('size_bytes', 0), 'partition_columns': self.config.partition_by or [], 'storage_backend': self.storage_backend, diff --git a/src/amp/loaders/implementations/iceberg_loader.py b/src/amp/loaders/implementations/iceberg_loader.py index a0e0b7b..d018833 100644 --- a/src/amp/loaders/implementations/iceberg_loader.py +++ b/src/amp/loaders/implementations/iceberg_loader.py @@ -1,6 +1,5 @@ # src/amp/loaders/implementations/iceberg_loader.py -import json from dataclasses import dataclass from typing import Any, Dict, List, Optional @@ -261,6 +260,7 @@ def _create_table_from_schema(self, schema: pa.Schema, table_name: str) -> None: fixed_schema = self._fix_schema_timestamps(schema) # Use create_table_if_not_exists for simpler logic + # PyIceberg's create_table_if_not_exists can handle partition_spec parameter if self.config.partition_spec: table = self._catalog.create_table_if_not_exists( identifier=table_identifier, schema=fixed_schema, 
partition_spec=self.config.partition_spec @@ -414,7 +414,22 @@ def _get_loader_table_metadata( self, table: pa.Table, duration: float, batch_count: int, **kwargs ) -> Dict[str, Any]: """Get Iceberg-specific metadata for table operation""" - return {'namespace': self.config.namespace} + metadata = {'namespace': self.config.namespace} + + # Try to get snapshot info from the last loaded table + table_name = kwargs.get('table_name') + if table_name: + try: + table_identifier = f'{self.config.namespace}.{table_name}' + if table_identifier in self._table_cache: + iceberg_table = self._table_cache[table_identifier] + current_snapshot = iceberg_table.current_snapshot() + if current_snapshot: + metadata['snapshot_id'] = current_snapshot.snapshot_id + except Exception as e: + self.logger.debug(f'Could not get snapshot info for metadata: {e}') + + return metadata def _table_exists(self, table_name: str) -> bool: """Check if a table exists""" @@ -524,6 +539,19 @@ def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str, return try: + # Collect all affected batch IDs from state store + all_affected_batch_ids = [] + for range_obj in invalidation_ranges: + # Get batch IDs that need to be deleted from state store + affected_batch_ids = self.state_store.invalidate_from_block( + connection_name, table_name, range_obj.network, range_obj.start + ) + all_affected_batch_ids.extend(affected_batch_ids) + + if not all_affected_batch_ids: + self.logger.info(f'No batches to delete for reorg in {table_name}') + return + # Load the Iceberg table table_identifier = f'{self.config.namespace}.{table_name}' try: @@ -532,58 +560,31 @@ def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str, self.logger.warning(f"Table '{table_identifier}' does not exist, skipping reorg handling") return - # Build delete predicate for all invalidation ranges - # For Iceberg, we'll use PyArrow expressions which get converted automatically - delete_conditions = [] - - for range_obj in invalidation_ranges: - network = range_obj.network - reorg_start = range_obj.start - - # Create condition for this network's reorg - # Delete all rows where the block range metadata for this network has end >= reorg_start - # This catches both overlapping ranges and forward ranges from the reorg point - - # Build expression to check _meta_block_ranges JSON array - # We need to parse the JSON and check if any range for this network - # has an end block >= reorg_start - delete_conditions.append( - f'_meta_block_ranges LIKE \'%"network":"{network}"%\' AND ' - f'EXISTS (SELECT 1 FROM JSON_ARRAY_ELEMENTS(_meta_block_ranges) AS range_elem ' - f"WHERE range_elem->>'network' = '{network}' AND " - f"(range_elem->>'end')::int >= {reorg_start})" - ) - - # Process reorg if we have deletion conditions - if delete_conditions: - self.logger.info( - f'Executing blockchain reorg deletion for {len(invalidation_ranges)} networks ' - f"in Iceberg table '{table_name}'" - ) + self.logger.info( + f'Executing blockchain reorg deletion for {len(invalidation_ranges)} networks ' + f"in Iceberg table '{table_name}' ({len(all_affected_batch_ids)} batch IDs)" + ) - # Since PyIceberg doesn't have a direct delete API yet, we'll use overwrite - # with filtered data as a workaround - # Future: Use SQL delete when available: - # combined_condition = ' OR '.join(f'({cond})' for cond in delete_conditions) - # delete_expr = f"DELETE FROM {table_identifier} WHERE {combined_condition}" - self._perform_reorg_deletion(iceberg_table, invalidation_ranges, 
table_name) + self._perform_reorg_deletion(iceberg_table, all_affected_batch_ids, table_name) except Exception as e: self.logger.error(f"Failed to handle blockchain reorg for table '{table_name}': {str(e)}") raise - def _perform_reorg_deletion( - self, iceberg_table: IcebergTable, invalidation_ranges: List[BlockRange], table_name: str - ) -> None: + def _perform_reorg_deletion(self, iceberg_table: IcebergTable, all_affected_batch_ids, table_name: str) -> None: """ - Perform the actual deletion for reorg handling using Iceberg's capabilities. + Perform the actual deletion for reorg handling using batch IDs. Since PyIceberg doesn't have a direct DELETE API yet, we'll use scan and overwrite to achieve the same effect while maintaining ACID guarantees. + + Args: + iceberg_table: The Iceberg table to delete from + all_affected_batch_ids: List of BatchIdentifier objects to delete + table_name: Table name for logging """ try: # First, scan the table to get current data - # We'll filter out the invalidated ranges during the scan scan = iceberg_table.scan() # Read all data into memory (for now - could be optimized with streaming) @@ -593,79 +594,35 @@ def _perform_reorg_deletion( self.logger.info(f"Table '{table_name}' is empty, nothing to delete for reorg") return - # Check if the table has the metadata column - if '_meta_block_ranges' not in arrow_table.schema.names: + # Check if the table has the batch ID column + if '_amp_batch_id' not in arrow_table.schema.names: self.logger.warning( - f"Table '{table_name}' doesn't have '_meta_block_ranges' column, skipping reorg handling" + f"Table '{table_name}' doesn't have '_amp_batch_id' column, skipping reorg handling" ) return - # Filter out invalidated rows - import pyarrow.compute as pc - - # Start with all rows marked as valid - keep_mask = pc.equal(pc.scalar(True), pc.scalar(True)) + # Build set of unique batch IDs to delete + unique_batch_ids = set(bid.unique_id for bid in all_affected_batch_ids) - for range_obj in invalidation_ranges: - network = range_obj.network - reorg_start = range_obj.start - - # For each row, check if it should be invalidated - # This is complex with JSON, so we'll parse and check each row - for i in range(arrow_table.num_rows): - meta_json = arrow_table['_meta_block_ranges'][i].as_py() - if meta_json: - try: - ranges_data = json.loads(meta_json) - # Check if any range for this network should be invalidated - for range_info in ranges_data: - if range_info['network'] == network and range_info['end'] >= reorg_start: - # Mark this row for deletion - keep_mask = pc.and_(keep_mask, pc.not_equal(pc.scalar(i), pc.scalar(i))) - break - except (json.JSONDecodeError, KeyError): - continue - - # Create a filtered table with only the rows we want to keep - # For a more efficient implementation, build a boolean array + # Filter out rows with matching batch IDs + # A row should be deleted if its _amp_batch_id contains any of the affected IDs + # (handles multi-network batches with "|"-separated IDs) keep_indices = [] deleted_count = 0 for i in range(arrow_table.num_rows): - should_delete = False - meta_json = arrow_table['_meta_block_ranges'][i].as_py() - - if meta_json: - try: - ranges_data = json.loads(meta_json) - - # Ensure ranges_data is a list - if not isinstance(ranges_data, list): - continue - - # Check each invalidation range - for range_obj in invalidation_ranges: - network = range_obj.network - reorg_start = range_obj.start - - # Check if any range for this network should be invalidated - for range_info in ranges_data: - if ( 
- isinstance(range_info, dict) - and range_info.get('network') == network - and range_info.get('end', 0) >= reorg_start - ): - should_delete = True - deleted_count += 1 - break - - if should_delete: - break - - except (json.JSONDecodeError, KeyError): - pass - - if not should_delete: + batch_id_value = arrow_table['_amp_batch_id'][i].as_py() + + if batch_id_value: + # Check if any affected batch ID appears in this row's batch ID + should_delete = any(bid in batch_id_value for bid in unique_batch_ids) + + if should_delete: + deleted_count += 1 + else: + keep_indices.append(i) + else: + # Keep rows without batch ID keep_indices.append(i) if deleted_count == 0: @@ -684,7 +641,8 @@ def _perform_reorg_deletion( iceberg_table.overwrite(filtered_table) self.logger.info( - f"Blockchain reorg deleted {deleted_count} rows from Iceberg table '{table_name}'. " + f"Blockchain reorg deleted {deleted_count} rows from Iceberg table '{table_name}' " + f'({len(all_affected_batch_ids)} batch IDs). ' f'New snapshot created with {filtered_table.num_rows} remaining rows.' ) diff --git a/src/amp/loaders/implementations/lmdb_loader.py b/src/amp/loaders/implementations/lmdb_loader.py index 8d4efbd..dba8bbb 100644 --- a/src/amp/loaders/implementations/lmdb_loader.py +++ b/src/amp/loaders/implementations/lmdb_loader.py @@ -218,18 +218,20 @@ def _clear_data(self, table_name: str) -> None: try: db = self._get_or_create_db(self.config.database_name) - # Clear all entries by iterating through and deleting - with self.env.begin(write=True, db=db) as txn: + # Collect all keys in a read transaction first + with self.env.begin(db=db) as txn: cursor = txn.cursor() - # Delete all key-value pairs - if cursor.first(): - while True: - if not cursor.delete(): - break - if not cursor.next(): - break + keys_to_delete = list(cursor.iternext(values=False)) - self.logger.info(f"Cleared all data for table '{table_name}'") + # Delete all keys in a write transaction + if keys_to_delete: + with self.env.begin(write=True, db=db) as txn: + for key in keys_to_delete: + txn.delete(key) + + self.logger.info(f"Cleared {len(keys_to_delete)} entries from LMDB for table '{table_name}'") + else: + self.logger.info(f"No data to clear for table '{table_name}'") except Exception as e: self.logger.error(f'Error in _clear_data: {e}') raise diff --git a/src/amp/loaders/implementations/postgresql_loader.py b/src/amp/loaders/implementations/postgresql_loader.py index 7bae9f1..2591b52 100644 --- a/src/amp/loaders/implementations/postgresql_loader.py +++ b/src/amp/loaders/implementations/postgresql_loader.py @@ -187,6 +187,10 @@ def load_batch_transactional( def _clear_table(self, table_name: str) -> None: """Clear table for overwrite mode""" + # Check if table exists first + if not self.table_exists(table_name): + return # Nothing to clear if table doesn't exist + conn = self.pool.getconn() try: with conn.cursor() as cur: diff --git a/src/amp/loaders/implementations/redis_loader.py b/src/amp/loaders/implementations/redis_loader.py index 5e2a421..9b3307f 100644 --- a/src/amp/loaders/implementations/redis_loader.py +++ b/src/amp/loaders/implementations/redis_loader.py @@ -555,7 +555,7 @@ def _clear_data(self, table_name: str) -> None: self.logger.info(f'Deleted {keys_deleted} existing keys') else: - # For collection-based structures, delete the main key + # For collection-based structures, use table_name:structure format collection_key = f'{table_name}:{self.data_structure.value}' if self.redis_client.exists(collection_key): 
self.redis_client.delete(collection_key) diff --git a/src/amp/registry/async_client.py b/src/amp/registry/async_client.py new file mode 100644 index 0000000..acb454f --- /dev/null +++ b/src/amp/registry/async_client.py @@ -0,0 +1,180 @@ +"""Async Registry API client.""" + +import logging +import os +from typing import Optional + +import httpx + +from . import errors + +logger = logging.getLogger(__name__) + + +class AsyncRegistryClient: + """Async client for interacting with the Amp Registry API. + + The Registry API provides dataset discovery, search, and publishing capabilities. + + Args: + base_url: Base URL for the Registry API (default: staging registry) + auth_token: Optional Bearer token for authenticated operations (highest priority) + auth: If True, load auth token from ~/.amp/cache (shared with TS CLI) + + Authentication Priority (highest to lowest): + 1. Explicit auth_token parameter + 2. AMP_AUTH_TOKEN environment variable + 3. auth=True - reads from ~/.amp/cache/amp_cli_auth + + Example: + >>> # Read-only operations (no auth required) + >>> async with AsyncRegistryClient() as client: + ... datasets = await client.datasets.search('ethereum') + >>> + >>> # Authenticated operations with explicit token + >>> async with AsyncRegistryClient(auth_token='your-token') as client: + ... await client.datasets.publish(...) + """ + + def __init__( + self, + base_url: str = 'https://api.registry.amp.staging.thegraph.com', + auth_token: Optional[str] = None, + auth: bool = False, + ): + """Initialize async Registry client. + + Args: + base_url: Base URL for the Registry API + auth_token: Optional Bearer token for authentication + auth: If True, load auth token from ~/.amp/cache + + Raises: + ValueError: If both auth=True and auth_token are provided + """ + if auth and auth_token: + raise ValueError('Cannot specify both auth=True and auth_token. Choose one authentication method.') + + self.base_url = base_url.rstrip('/') + + # Resolve auth token provider with priority: explicit param > env var > auth file + self._get_token = None + if auth_token: + # Priority 1: Explicit auth_token parameter (static token) + def get_token(): + return auth_token + + self._get_token = get_token + elif os.getenv('AMP_AUTH_TOKEN'): + # Priority 2: AMP_AUTH_TOKEN environment variable (static token) + env_token = os.getenv('AMP_AUTH_TOKEN') + + def get_token(): + return env_token + + self._get_token = get_token + elif auth: + # Priority 3: Load from ~/.amp/cache/amp_cli_auth (auto-refreshing) + from amp.auth import AuthService + + auth_service = AuthService() + self._get_token = auth_service.get_token # Callable that auto-refreshes + + # Create async HTTP client (no auth header yet - will be added per-request) + self._http = httpx.AsyncClient( + base_url=self.base_url, + headers={ + 'Content-Type': 'application/json', + 'Accept': 'application/json', + }, + timeout=30.0, + ) + + logger.info(f'Initialized async Registry client for {base_url}') + + @property + def datasets(self): + """Access the async datasets client. + + Returns: + AsyncRegistryDatasetsClient: Client for dataset operations + """ + from .async_datasets import AsyncRegistryDatasetsClient + + return AsyncRegistryDatasetsClient(self) + + async def _request( + self, + method: str, + path: str, + **kwargs, + ) -> httpx.Response: + """Make an async HTTP request to the Registry API. + + Args: + method: HTTP method (GET, POST, etc.) 
+ path: API path (without base URL) + **kwargs: Additional arguments to pass to httpx + + Returns: + httpx.Response: HTTP response + + Raises: + RegistryError: If the request fails + """ + url = path if path.startswith('http') else f'{self.base_url}{path}' + + # Add auth header dynamically (auto-refreshes if needed) + headers = kwargs.get('headers', {}) + if self._get_token: + headers['Authorization'] = f'Bearer {self._get_token()}' + kwargs['headers'] = headers + + try: + response = await self._http.request(method, url, **kwargs) + + # Handle error responses + if response.status_code >= 400: + self._handle_error(response) + + return response + + except httpx.RequestError as e: + raise errors.RegistryError(f'Request failed: {e}') from e + + def _handle_error(self, response: httpx.Response) -> None: + """Handle error responses from the API. + + Args: + response: HTTP error response + + Raises: + RegistryError: Mapped exception for the error + """ + try: + error_data = response.json() + error_code = error_data.get('error_code', '') + error_message = error_data.get('error_message', response.text) + request_id = error_data.get('request_id', '') + + # Map to specific exception + raise errors.map_error(error_code, error_message, request_id) + + except (ValueError, KeyError): + # Couldn't parse error response, raise generic error + raise errors.RegistryError( + f'HTTP {response.status_code}: {response.text}', + error_code=str(response.status_code), + ) from None + + async def close(self): + """Close the HTTP client.""" + await self._http.aclose() + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() diff --git a/src/amp/registry/async_datasets.py b/src/amp/registry/async_datasets.py new file mode 100644 index 0000000..012ae0b --- /dev/null +++ b/src/amp/registry/async_datasets.py @@ -0,0 +1,437 @@ +"""Async Registry datasets client.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Dict, Optional + +from amp.utils.manifest_inspector import describe_manifest + +from . import models + +if TYPE_CHECKING: + from .async_client import AsyncRegistryClient + +logger = logging.getLogger(__name__) + + +class AsyncRegistryDatasetsClient: + """Async client for dataset operations in the Registry API. + + Provides async methods for: + - Searching and discovering datasets + - Fetching dataset details and manifests + - Publishing datasets (requires authentication) + - Managing dataset visibility and versions + + Args: + registry_client: Parent AsyncRegistryClient instance + """ + + def __init__(self, registry_client: AsyncRegistryClient): + """Initialize async datasets client. + + Args: + registry_client: Parent AsyncRegistryClient instance + """ + self._registry = registry_client + + # Read Operations (Public - No Auth Required) + + async def list( + self, limit: int = 50, page: int = 1, sort_by: Optional[str] = None, direction: Optional[str] = None + ) -> models.DatasetListResponse: + """List all published datasets with pagination. + + Args: + limit: Maximum number of datasets to return (default: 50, max: 1000) + page: Page number (1-indexed, default: 1) + sort_by: Field to sort by (e.g., 'name', 'created_at', 'updated_at') + direction: Sort direction ('asc' or 'desc') + + Returns: + DatasetListResponse: Paginated list of datasets + + Example: + >>> async with AsyncRegistryClient() as client: + ... 
response = await client.datasets.list(limit=10, page=1) + ... print(f"Found {response.total_count} datasets") + """ + params: Dict[str, Any] = {'limit': limit, 'page': page} + if sort_by: + params['sort_by'] = sort_by + if direction: + params['direction'] = direction + + response = await self._registry._request('GET', '/api/v1/datasets', params=params) + return models.DatasetListResponse.model_validate(response.json()) + + async def search(self, query: str, limit: int = 50, page: int = 1) -> models.DatasetSearchResponse: + """Search datasets using full-text keyword search. + + Results are ranked by relevance score. + + Args: + query: Search query string + limit: Maximum number of results (default: 50, max: 1000) + page: Page number (1-indexed, default: 1) + + Returns: + DatasetSearchResponse: Search results with relevance scores + + Example: + >>> async with AsyncRegistryClient() as client: + ... results = await client.datasets.search('ethereum blocks') + ... for dataset in results.datasets: + ... print(f"[{dataset.score}] {dataset.namespace}/{dataset.name}") + """ + params = {'search': query, 'limit': limit, 'page': page} + response = await self._registry._request('GET', '/api/v1/datasets/search', params=params) + return models.DatasetSearchResponse.model_validate(response.json()) + + async def ai_search(self, query: str, limit: int = 50) -> list[models.DatasetWithScore]: + """Search datasets using AI-powered semantic search. + + Uses embeddings and natural language processing for better matching. + + Args: + query: Natural language search query + limit: Maximum number of results (default: 50) + + Returns: + list[DatasetWithScore]: List of datasets with relevance scores + + Example: + >>> async with AsyncRegistryClient() as client: + ... results = await client.datasets.ai_search('find NFT transfer data') + ... for dataset in results: + ... print(f"[{dataset.score}] {dataset.namespace}/{dataset.name}") + """ + params = {'search': query, 'limit': limit} + response = await self._registry._request('GET', '/api/v1/datasets/search/ai', params=params) + return [models.DatasetWithScore.model_validate(d) for d in response.json()] + + async def get(self, namespace: str, name: str) -> models.Dataset: + """Get detailed information about a specific dataset. + + Args: + namespace: Dataset namespace (e.g., 'edgeandnode', 'edgeandnode') + name: Dataset name (e.g., 'ethereum-mainnet') + + Returns: + Dataset: Complete dataset information + + Example: + >>> async with AsyncRegistryClient() as client: + ... dataset = await client.datasets.get('edgeandnode', 'ethereum-mainnet') + ... print(f"Latest version: {dataset.latest_version}") + """ + path = f'/api/v1/datasets/{namespace}/{name}' + response = await self._registry._request('GET', path) + return models.Dataset.model_validate(response.json()) + + async def list_versions(self, namespace: str, name: str) -> list[models.DatasetVersion]: + """List all versions of a dataset. + + Versions are returned sorted by latest first. + + Args: + namespace: Dataset namespace + name: Dataset name + + Returns: + list[DatasetVersion]: List of dataset versions + + Example: + >>> async with AsyncRegistryClient() as client: + ... versions = await client.datasets.list_versions('edgeandnode', 'ethereum-mainnet') + ... for version in versions: + ... 
print(f" - v{version.version} ({version.status})") + """ + path = f'/api/v1/datasets/{namespace}/{name}/versions' + response = await self._registry._request('GET', path) + return [models.DatasetVersion.model_validate(v) for v in response.json()] + + async def get_version(self, namespace: str, name: str, version: str) -> models.DatasetVersion: + """Get details of a specific dataset version. + + Args: + namespace: Dataset namespace + name: Dataset name + version: Version tag (e.g., '1.0.0', 'latest') + + Returns: + DatasetVersion: Version details + + Example: + >>> async with AsyncRegistryClient() as client: + ... version = await client.datasets.get_version('edgeandnode', 'ethereum-mainnet', 'latest') + ... print(f"Version: {version.version}") + """ + path = f'/api/v1/datasets/{namespace}/{name}/versions/{version}' + response = await self._registry._request('GET', path) + return models.DatasetVersion.model_validate(response.json()) + + async def get_manifest(self, namespace: str, name: str, version: str) -> dict: + """Fetch the manifest JSON for a specific dataset version. + + Manifests define the dataset structure, dependencies, and ETL logic. + + Args: + namespace: Dataset namespace + name: Dataset name + version: Version tag (e.g., '1.0.0', 'latest') + + Returns: + dict: Manifest JSON content + + Example: + >>> async with AsyncRegistryClient() as client: + ... manifest = await client.datasets.get_manifest('edgeandnode', 'ethereum-mainnet', 'latest') + ... print(f"Tables: {list(manifest.get('tables', {}).keys())}") + """ + path = f'/api/v1/datasets/{namespace}/{name}/versions/{version}/manifest' + response = await self._registry._request('GET', path) + return response.json() + + async def describe( + self, namespace: str, name: str, version: str = 'latest' + ) -> Dict[str, list[Dict[str, str | bool]]]: + """Get a structured summary of tables and columns in a dataset. + + Returns a dictionary mapping table names to lists of column information, + making it easy to programmatically inspect the dataset schema. + + Args: + namespace: Dataset namespace + name: Dataset name + version: Version tag (default: 'latest') + + Returns: + dict: Mapping of table names to column information. + + Example: + >>> async with AsyncRegistryClient() as client: + ... schema = await client.datasets.describe('edgeandnode', 'ethereum-mainnet', 'latest') + ... for table_name, columns in schema.items(): + ... print(f"Table: {table_name}") + """ + manifest = await self.get_manifest(namespace, name, version) + return describe_manifest(manifest) + + # Write Operations (Require Authentication) + + async def publish( + self, + namespace: str, + name: str, + version: str, + manifest: dict, + visibility: str = 'public', + description: Optional[str] = None, + tags: Optional[list[str]] = None, + chains: Optional[list[str]] = None, + sources: Optional[list[str]] = None, + ) -> models.Dataset: + """Publish a new dataset to the registry. + + Requires authentication (Bearer token). + + Args: + namespace: Dataset namespace (owner's username or org) + name: Dataset name + version: Initial version tag (e.g., '1.0.0') + manifest: Dataset manifest JSON + visibility: Dataset visibility ('public' or 'private', default: 'public') + description: Dataset description + tags: Optional list of tags/keywords + chains: Optional list of blockchain networks + sources: Optional list of data sources + + Returns: + Dataset: Created dataset + + Example: + >>> async with AsyncRegistryClient(auth_token='your-token') as client: + ... 
dataset = await client.datasets.publish( + ... namespace='myuser', + ... name='my_dataset', + ... version='1.0.0', + ... manifest=manifest, + ... description='My custom dataset' + ... ) + """ + body = { + 'name': name, + 'version': version, + 'manifest': manifest, + 'visibility': visibility, + } + if description: + body['description'] = description + if tags: + body['tags'] = tags + if chains: + body['chains'] = chains + if sources: + body['sources'] = sources + + response = await self._registry._request('POST', '/api/v1/owners/@me/datasets/publish', json=body) + return models.Dataset.model_validate(response.json()) + + async def publish_version( + self, + namespace: str, + name: str, + version: str, + manifest: dict, + description: Optional[str] = None, + ) -> models.DatasetVersion: + """Publish a new version of an existing dataset. + + Requires authentication and ownership of the dataset. + + Args: + namespace: Dataset namespace + name: Dataset name + version: New version tag (e.g., '1.1.0') + manifest: Dataset manifest JSON for this version + description: Optional version description + + Returns: + DatasetVersion: Created version + + Example: + >>> async with AsyncRegistryClient(auth_token='your-token') as client: + ... version = await client.datasets.publish_version( + ... namespace='myuser', + ... name='my_dataset', + ... version='1.1.0', + ... manifest=manifest + ... ) + """ + body = {'version': version, 'manifest': manifest} + if description: + body['description'] = description + + path = f'/api/v1/owners/@me/datasets/{namespace}/{name}/versions/publish' + response = await self._registry._request('POST', path, json=body) + return models.DatasetVersion.model_validate(response.json()) + + async def update( + self, + namespace: str, + name: str, + description: Optional[str] = None, + tags: Optional[list[str]] = None, + chains: Optional[list[str]] = None, + sources: Optional[list[str]] = None, + ) -> models.Dataset: + """Update dataset metadata. + + Requires authentication and ownership of the dataset. + + Args: + namespace: Dataset namespace + name: Dataset name + description: Updated description + tags: Updated tags + chains: Updated chains + sources: Updated sources + + Returns: + Dataset: Updated dataset + + Example: + >>> async with AsyncRegistryClient(auth_token='your-token') as client: + ... dataset = await client.datasets.update( + ... namespace='myuser', + ... name='my_dataset', + ... description='Updated description' + ... ) + """ + body = {} + if description is not None: + body['description'] = description + if tags is not None: + body['tags'] = tags + if chains is not None: + body['chains'] = chains + if sources is not None: + body['sources'] = sources + + path = f'/api/v1/owners/@me/datasets/{namespace}/{name}' + response = await self._registry._request('PUT', path, json=body) + return models.Dataset.model_validate(response.json()) + + async def update_visibility(self, namespace: str, name: str, visibility: str) -> models.Dataset: + """Update dataset visibility (public/private). + + Requires authentication and ownership of the dataset. + + Args: + namespace: Dataset namespace + name: Dataset name + visibility: New visibility ('public' or 'private') + + Returns: + Dataset: Updated dataset + + Example: + >>> async with AsyncRegistryClient(auth_token='your-token') as client: + ... 
dataset = await client.datasets.update_visibility('myuser', 'my_dataset', 'private') + """ + body = {'visibility': visibility} + path = f'/api/v1/owners/@me/datasets/{namespace}/{name}/visibility' + response = await self._registry._request('PATCH', path, json=body) + return models.Dataset.model_validate(response.json()) + + async def update_version_status( + self, namespace: str, name: str, version: str, status: str + ) -> models.DatasetVersion: + """Update the status of a dataset version. + + Requires authentication and ownership of the dataset. + + Args: + namespace: Dataset namespace + name: Dataset name + version: Version tag + status: New status ('draft', 'published', 'deprecated', or 'archived') + + Returns: + DatasetVersion: Updated version + + Example: + >>> async with AsyncRegistryClient(auth_token='your-token') as client: + ... version = await client.datasets.update_version_status( + ... 'myuser', 'my_dataset', '1.0.0', 'deprecated' + ... ) + """ + body = {'status': status} + path = f'/api/v1/owners/@me/datasets/{namespace}/{name}/versions/{version}' + response = await self._registry._request('PATCH', path, json=body) + return models.DatasetVersion.model_validate(response.json()) + + async def delete_version( + self, namespace: str, name: str, version: str + ) -> models.ArchiveDatasetVersionResponse: + """Delete (archive) a dataset version. + + Requires authentication and ownership of the dataset. + + Args: + namespace: Dataset namespace + name: Dataset name + version: Version tag to delete + + Returns: + ArchiveDatasetVersionResponse: Confirmation of deletion + + Example: + >>> async with AsyncRegistryClient(auth_token='your-token') as client: + ... response = await client.datasets.delete_version('myuser', 'my_dataset', '0.1.0') + """ + path = f'/api/v1/owners/@me/datasets/{namespace}/{name}/versions/{version}' + response = await self._registry._request('DELETE', path) + return models.ArchiveDatasetVersionResponse.model_validate(response.json()) diff --git a/tests/conftest.py b/tests/conftest.py index 2180725..3644ef4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -203,6 +203,16 @@ def redis_test_config(request): return request.getfixturevalue('redis_config') +@pytest.fixture +def redis_streaming_config(redis_test_config): + """Redis config for streaming tests with blockchain data (uses tx_hash instead of id)""" + return { + **redis_test_config, + 'key_pattern': '{table}:{tx_hash}', # Use tx_hash from blockchain data + 'data_structure': 'hash', + } + + @pytest.fixture(scope='session') def delta_test_env(): """Create Delta Lake test environment for the session""" @@ -216,10 +226,12 @@ def delta_test_env(): @pytest.fixture -def delta_basic_config(delta_test_env): +def delta_basic_config(delta_test_env, request): """Basic Delta Lake configuration for testing""" + # Create unique table path for each test to avoid data accumulation + unique_suffix = f'{request.node.name}_{id(request)}' return { - 'table_path': str(Path(delta_test_env) / 'basic_table'), + 'table_path': str(Path(delta_test_env) / f'basic_table_{unique_suffix}'), 'partition_by': ['year', 'month'], 'optimize_after_write': True, 'vacuum_after_write': False, @@ -243,6 +255,20 @@ def delta_partitioned_config(delta_test_env): } +@pytest.fixture +def delta_streaming_config(delta_test_env): + """Delta Lake configuration for streaming tests (no partitioning)""" + return { + 'table_path': str(Path(delta_test_env) / 'streaming_table'), + 'partition_by': [], # No partitioning for streaming tests + 
'optimize_after_write': True, + 'vacuum_after_write': False, + 'schema_evolution': True, + 'merge_schema': True, + 'storage_options': {}, + } + + @pytest.fixture def delta_temp_config(delta_test_env): """Temporary Delta Lake configuration with unique path""" @@ -287,6 +313,24 @@ def iceberg_basic_config(iceberg_test_env): } +@pytest.fixture +def iceberg_streaming_config(iceberg_test_env): + """Iceberg configuration for streaming tests (no partitioning)""" + return { + 'catalog_config': { + 'type': 'sql', + 'uri': f'sqlite:///{iceberg_test_env}/streaming_catalog.db', + 'warehouse': f'file://{iceberg_test_env}/streaming_warehouse', + }, + 'namespace': 'test_data', + 'create_namespace': True, + 'create_table': True, + 'schema_evolution': True, + 'batch_size': 10000, + 'partition_spec': [], # No partitioning for streaming tests + } + + @pytest.fixture def lmdb_test_env(): """Create LMDB test environment for the session""" diff --git a/tests/integration/loaders/backends/test_deltalake.py b/tests/integration/loaders/backends/test_deltalake.py new file mode 100644 index 0000000..aa0c661 --- /dev/null +++ b/tests/integration/loaders/backends/test_deltalake.py @@ -0,0 +1,303 @@ +""" +DeltaLake-specific loader integration tests. + +This module provides DeltaLake-specific test configuration and tests that +inherit from the generalized base test classes. +""" + +from pathlib import Path +from typing import Any, Dict, List, Optional + +import pytest + +try: + from deltalake import DeltaTable + + from src.amp.loaders.implementations.deltalake_loader import DeltaLakeLoader + from tests.integration.loaders.conftest import LoaderTestConfig + from tests.integration.loaders.test_base_loader import BaseLoaderTests + from tests.integration.loaders.test_base_streaming import BaseStreamingTests +except ImportError: + pytest.skip('amp modules not available', allow_module_level=True) + + +class DeltaLakeTestConfig(LoaderTestConfig): + """DeltaLake-specific test configuration""" + + loader_class = DeltaLakeLoader + + supports_overwrite = True + supports_streaming = True + supports_multi_network = True + supports_null_values = True + requires_existing_table = False # DeltaLake auto-creates tables + + def __init__(self, config_fixture_name='delta_basic_config'): + """ + Initialize DeltaLake test config. 
+ + Args: + config_fixture_name: Name of the pytest fixture providing loader config + (default: 'delta_basic_config' for core tests) + """ + self.config_fixture_name = config_fixture_name + + def get_row_count(self, loader: DeltaLakeLoader, table_name: str) -> int: + """Get row count from DeltaLake table""" + # DeltaLake uses the table_path as the identifier + table_path = loader.config.table_path + dt = DeltaTable(table_path) + df = dt.to_pyarrow_table() + return len(df) + + def query_rows( + self, loader: DeltaLakeLoader, table_name: str, where: Optional[str] = None, order_by: Optional[str] = None + ) -> List[Dict[str, Any]]: + """Query rows from DeltaLake table""" + table_path = loader.config.table_path + dt = DeltaTable(table_path) + df = dt.to_pyarrow_table() + + # Convert to list of dicts (simple implementation, no filtering) + result = [] + for i in range(min(100, len(df))): + row = {col: df[col][i].as_py() for col in df.column_names} + result.append(row) + return result + + def cleanup_table(self, loader: DeltaLakeLoader, table_name: str) -> None: + """Delete DeltaLake table directory""" + import shutil + + table_path = loader.config.table_path + if Path(table_path).exists(): + shutil.rmtree(table_path, ignore_errors=True) + + def get_column_names(self, loader: DeltaLakeLoader, table_name: str) -> List[str]: + """Get column names from DeltaLake table""" + table_path = loader.config.table_path + dt = DeltaTable(table_path) + schema = dt.schema() + return [field.name for field in schema.fields] + + +@pytest.mark.delta_lake +class TestDeltaLakeCore(BaseLoaderTests): + """DeltaLake core loader tests (inherited from base)""" + + config = DeltaLakeTestConfig() + + +@pytest.mark.delta_lake +class TestDeltaLakeStreaming(BaseStreamingTests): + """DeltaLake streaming tests (inherited from base)""" + + config = DeltaLakeTestConfig('delta_streaming_config') # Use non-partitioned config + + +@pytest.mark.delta_lake +class TestDeltaLakeSpecific: + """DeltaLake-specific tests that cannot be generalized""" + + def test_partitioning(self, delta_partitioned_config, small_test_data): + """Test DeltaLake partitioning functionality""" + + loader = DeltaLakeLoader(delta_partitioned_config) + + # Verify partitioning is configured + assert loader.partition_by == ['year', 'month', 'day'] + + with loader: + result = loader.load_table(small_test_data, 'test_table') + assert result.success == True + assert result.rows_loaded == 5 + + # Verify partitions were created + dt = DeltaTable(loader.config.table_path) + files = dt.file_uris() + # Partitioned tables create subdirectories + assert len(files) > 0 + + def test_optimization_operations(self, delta_basic_config, comprehensive_test_data): + """Test DeltaLake OPTIMIZE operations""" + loader = DeltaLakeLoader(delta_basic_config) + + with loader: + # Load data + result = loader.load_table(comprehensive_test_data, 'test_table') + assert result.success == True + + # DeltaLake may auto-optimize if configured + dt = DeltaTable(loader.config.table_path) + version = dt.version() + assert version >= 0 + + # Verify table can be read after optimization + df = dt.to_pyarrow_table() + assert len(df) == 1000 + + def test_schema_evolution(self, delta_basic_config, small_test_data): + """Test DeltaLake schema evolution""" + import pyarrow as pa + + loader = DeltaLakeLoader(delta_basic_config) + + with loader: + # Load initial data + result = loader.load_table(small_test_data, 'test_table') + assert result.success == True + + # Add new column to schema + extended_data = { + 
**{col: small_test_data[col].to_pylist() for col in small_test_data.column_names}, + 'new_column': [100, 200, 300, 400, 500], + } + extended_table = pa.Table.from_pydict(extended_data) + + # Load with new schema (if schema evolution enabled) + from src.amp.loaders.base import LoadMode + + loader.load_table(extended_table, 'test_table', mode=LoadMode.APPEND) + + # Result depends on merge_schema configuration + dt = DeltaTable(loader.config.table_path) + schema = dt.schema() + # New column may or may not be present depending on config + assert len(schema.fields) >= len(small_test_data.schema) + + def test_table_history(self, delta_basic_config, small_test_data): + """Test DeltaLake table version history""" + from src.amp.loaders.base import LoadMode + + loader = DeltaLakeLoader(delta_basic_config) + + with loader: + # Load initial data + loader.load_table(small_test_data, 'test_table') + + # Append more data + loader.load_table(small_test_data, 'test_table', mode=LoadMode.APPEND) + + # Check version history + dt = DeltaTable(loader.config.table_path) + version = dt.version() + assert version >= 1 # At least 2 operations (create + append) + + # Verify history is accessible + history = dt.history() + assert len(history) >= 1 + + def test_metadata_completeness(self, delta_basic_config, comprehensive_test_data): + """Test DeltaLake metadata in load results""" + loader = DeltaLakeLoader(delta_basic_config) + + with loader: + result = loader.load_table(comprehensive_test_data, 'test_table') + + assert result.success == True + assert 'delta_version' in result.metadata + assert 'files_added' in result.metadata + assert result.metadata['delta_version'] >= 0 + + def test_query_operations(self, delta_basic_config, comprehensive_test_data): + """Test querying DeltaLake tables""" + loader = DeltaLakeLoader(delta_basic_config) + + with loader: + result = loader.load_table(comprehensive_test_data, 'test_table') + assert result.success == True + + # Query the table + dt = DeltaTable(loader.config.table_path) + df = dt.to_pyarrow_table() + + # Verify data integrity + assert len(df) == 1000 + assert 'id' in df.column_names + assert 'user_id' in df.column_names + + def test_file_size_calculation(self, delta_basic_config, comprehensive_test_data): + """Test file size calculation for DeltaLake tables""" + loader = DeltaLakeLoader(delta_basic_config) + + with loader: + result = loader.load_table(comprehensive_test_data, 'test_table') + assert result.success == True + + # Get table size + dt = DeltaTable(loader.config.table_path) + files = dt.file_uris() + assert len(files) > 0 + + # Calculate total size + total_size = 0 + for file_uri in files: + # Remove file:// prefix if present + file_path = file_uri.replace('file://', '') + if Path(file_path).exists(): + total_size += Path(file_path).stat().st_size + + assert total_size > 0 + + def test_concurrent_operations_safety(self, delta_basic_config, small_test_data): + """Test that DeltaLake handles concurrent operations safely""" + from concurrent.futures import ThreadPoolExecutor, as_completed + + from src.amp.loaders.base import LoadMode + + loader = DeltaLakeLoader(delta_basic_config) + + with loader: + # Load initial data + loader.load_table(small_test_data, 'test_table') + + # Try concurrent appends + def append_data(i): + return loader.load_table(small_test_data, 'test_table', mode=LoadMode.APPEND) + + with ThreadPoolExecutor(max_workers=3) as executor: + futures = [executor.submit(append_data, i) for i in range(3)] + results = [future.result() for future in 
as_completed(futures)] + + # All operations should succeed + assert all(r.success for r in results) + + # Verify final row count + dt = DeltaTable(loader.config.table_path) + df = dt.to_pyarrow_table() + assert len(df) == 20 # 5 initial + 3 * 5 appends + + +@pytest.mark.delta_lake +@pytest.mark.slow +class TestDeltaLakePerformance: + """DeltaLake performance tests""" + + def test_large_data_loading(self, delta_basic_config): + """Test loading large datasets to DeltaLake""" + import pyarrow as pa + + # Create large dataset + large_data = { + 'id': list(range(50000)), + 'value': [i * 0.123 for i in range(50000)], + 'category': [f'cat_{i % 100}' for i in range(50000)], + 'year': [2024 if i < 40000 else 2023 for i in range(50000)], + 'month': [(i // 100) % 12 + 1 for i in range(50000)], + 'day': [(i // 10) % 28 + 1 for i in range(50000)], + } + large_table = pa.Table.from_pydict(large_data) + + loader = DeltaLakeLoader(delta_basic_config) + + with loader: + result = loader.load_table(large_table, 'test_table') + + assert result.success == True + assert result.rows_loaded == 50000 + assert result.duration < 60 # Should complete within 60 seconds + + # Verify data integrity + dt = DeltaTable(loader.config.table_path) + df = dt.to_pyarrow_table() + assert len(df) == 50000 diff --git a/tests/integration/loaders/backends/test_iceberg.py b/tests/integration/loaders/backends/test_iceberg.py new file mode 100644 index 0000000..7fc856d --- /dev/null +++ b/tests/integration/loaders/backends/test_iceberg.py @@ -0,0 +1,319 @@ +""" +Iceberg-specific loader integration tests. + +This module provides Iceberg-specific test configuration and tests that +inherit from the generalized base test classes. +""" + +from typing import Any, Dict, List, Optional + +import pytest + +try: + from src.amp.loaders.implementations.iceberg_loader import IcebergLoader + from tests.integration.loaders.conftest import LoaderTestConfig + from tests.integration.loaders.test_base_loader import BaseLoaderTests + from tests.integration.loaders.test_base_streaming import BaseStreamingTests +except ImportError: + pytest.skip('amp modules not available', allow_module_level=True) + + +class IcebergTestConfig(LoaderTestConfig): + """Iceberg-specific test configuration""" + + loader_class = IcebergLoader + + supports_overwrite = True + supports_streaming = True + supports_multi_network = True + supports_null_values = True + + def __init__(self, config_fixture_name='iceberg_basic_config'): + """ + Initialize Iceberg test config. 
+ + Args: + config_fixture_name: Name of the pytest fixture providing loader config + (default: 'iceberg_basic_config' for core tests) + """ + self.config_fixture_name = config_fixture_name + + def get_row_count(self, loader: IcebergLoader, table_name: str) -> int: + """Get row count from Iceberg table""" + catalog = loader._catalog + table = catalog.load_table((loader.config.namespace, table_name)) + df = table.scan().to_arrow() + return len(df) + + def query_rows( + self, loader: IcebergLoader, table_name: str, where: Optional[str] = None, order_by: Optional[str] = None + ) -> List[Dict[str, Any]]: + """Query rows from Iceberg table""" + catalog = loader._catalog + table = catalog.load_table((loader.config.namespace, table_name)) + df = table.scan().to_arrow() + + # Convert to list of dicts (simple implementation, no filtering) + result = [] + for i in range(min(100, len(df))): + row = {col: df[col][i].as_py() for col in df.column_names} + result.append(row) + return result + + def cleanup_table(self, loader: IcebergLoader, table_name: str) -> None: + """Drop Iceberg table""" + try: + catalog = loader._catalog + catalog.drop_table((loader.config.namespace, table_name)) + except Exception: + pass # Table may not exist + + def get_column_names(self, loader: IcebergLoader, table_name: str) -> List[str]: + """Get column names from Iceberg table""" + catalog = loader._catalog + table = catalog.load_table((loader.config.namespace, table_name)) + schema = table.schema() + return [field.name for field in schema.fields] + + +@pytest.mark.iceberg +class TestIcebergCore(BaseLoaderTests): + """Iceberg core loader tests (inherited from base)""" + + config = IcebergTestConfig() + + +@pytest.mark.iceberg +class TestIcebergStreaming(BaseStreamingTests): + """Iceberg streaming tests (inherited from base)""" + + config = IcebergTestConfig('iceberg_streaming_config') # Use non-partitioned config + + +@pytest.mark.iceberg +class TestIcebergSpecific: + """Iceberg-specific tests that cannot be generalized""" + + def test_catalog_initialization(self, iceberg_basic_config): + """Test Iceberg catalog initialization""" + loader = IcebergLoader(iceberg_basic_config) + + with loader: + assert loader._catalog is not None + assert loader.config.namespace is not None + + # Verify namespace exists or was created + namespaces = loader._catalog.list_namespaces() + assert any(ns == (loader.config.namespace,) for ns in namespaces) + + @pytest.mark.skip(reason='Partitioning with list format not yet fully implemented') + def test_partitioning(self, iceberg_basic_config, small_test_data): + """Test Iceberg partitioning (partition spec)""" + + # Create config with partitioning + config = {**iceberg_basic_config, 'partition_spec': [('year', 'identity'), ('month', 'identity')]} + loader = IcebergLoader(config) + + table_name = 'test_partitioned' + + with loader: + result = loader.load_table(small_test_data, table_name) + assert result.success == True + assert result.rows_loaded == 5 + + # Verify table was created with partition spec + catalog = loader._catalog + table = catalog.load_table((loader.config.namespace, table_name)) + spec = table.spec() + # Partition spec should have fields + assert len(spec.fields) > 0 + + # Cleanup + catalog.drop_table((loader.config.namespace, table_name)) + + def test_schema_evolution(self, iceberg_basic_config, small_test_data): + """Test Iceberg schema evolution""" + import pyarrow as pa + + from src.amp.loaders.base import LoadMode + + loader = IcebergLoader(iceberg_basic_config) + table_name 
= 'test_schema_evolution' + + with loader: + # Load initial data + result = loader.load_table(small_test_data, table_name) + assert result.success == True + + # Add new column + extended_data = { + **{col: small_test_data[col].to_pylist() for col in small_test_data.column_names}, + 'new_column': [100, 200, 300, 400, 500], + } + extended_table = pa.Table.from_pydict(extended_data) + + # Load with new schema + loader.load_table(extended_table, table_name, mode=LoadMode.APPEND) + + # Schema evolution depends on config + catalog = loader._catalog + table = catalog.load_table((loader.config.namespace, table_name)) + schema = table.schema() + # New column may or may not be present + assert len(schema.fields) >= len(small_test_data.schema) + + # Cleanup + catalog.drop_table((loader.config.namespace, table_name)) + + def test_timestamp_conversion(self, iceberg_basic_config): + """Test timestamp conversion for Iceberg""" + from datetime import datetime + + import pyarrow as pa + + # Create data with timestamps + data = { + 'id': [1, 2, 3], + 'timestamp': [datetime(2024, 1, 1), datetime(2024, 1, 2), datetime(2024, 1, 3)], + 'value': [100, 200, 300], + } + test_data = pa.Table.from_pydict(data) + + loader = IcebergLoader(iceberg_basic_config) + table_name = 'test_timestamps' + + with loader: + result = loader.load_table(test_data, table_name) + assert result.success == True + assert result.rows_loaded == 3 + + # Verify timestamps were converted correctly + catalog = loader._catalog + table = catalog.load_table((loader.config.namespace, table_name)) + df = table.scan().to_arrow() + + assert len(df) == 3 + assert 'timestamp' in df.column_names + + # Cleanup + catalog.drop_table((loader.config.namespace, table_name)) + + def test_multiple_tables(self, iceberg_basic_config, small_test_data): + """Test managing multiple tables with same loader""" + + loader = IcebergLoader(iceberg_basic_config) + + with loader: + # Create first table + result1 = loader.load_table(small_test_data, 'table1') + assert result1.success == True + + # Create second table + result2 = loader.load_table(small_test_data, 'table2') + assert result2.success == True + + # Verify both exist + catalog = loader._catalog + tables = catalog.list_tables(loader.config.namespace) + table_names = [t[1] for t in tables] + + assert 'table1' in table_names + assert 'table2' in table_names + + # Cleanup + catalog.drop_table((loader.config.namespace, 'table1')) + catalog.drop_table((loader.config.namespace, 'table2')) + + def test_upsert_operations(self, iceberg_basic_config): + """Test Iceberg upsert operations (merge on read)""" + import pyarrow as pa + + from src.amp.loaders.base import LoadMode + + loader = IcebergLoader(iceberg_basic_config) + table_name = 'test_upsert' + + # Initial data + data1 = {'id': [1, 2, 3], 'value': [100, 200, 300]} + table1 = pa.Table.from_pydict(data1) + + # Updated data (overlapping IDs) + data2 = {'id': [2, 3, 4], 'value': [250, 350, 400]} + table2 = pa.Table.from_pydict(data2) + + with loader: + # Load initial + loader.load_table(table1, table_name) + + # Upsert (if supported) or append + try: + # Try upsert mode if supported + loader.load_table(table2, table_name, mode=LoadMode.APPEND) + + # Verify row count + catalog = loader._catalog + table = catalog.load_table((loader.config.namespace, table_name)) + df = table.scan().to_arrow() + + # With append, we should have 6 rows (3 + 3) + # With upsert, we should have 4 rows (deduplicated) + assert len(df) >= 3 + + # Cleanup + 
catalog.drop_table((loader.config.namespace, table_name)) + except Exception: + # Upsert may not be supported + catalog = loader._catalog + catalog.drop_table((loader.config.namespace, table_name)) + + def test_metadata_completeness(self, iceberg_basic_config, comprehensive_test_data): + """Test Iceberg metadata in load results""" + loader = IcebergLoader(iceberg_basic_config) + table_name = 'test_metadata' + + with loader: + result = loader.load_table(comprehensive_test_data, table_name) + + assert result.success == True + assert 'snapshot_id' in result.metadata or 'files_written' in result.metadata + + # Cleanup + catalog = loader._catalog + catalog.drop_table((loader.config.namespace, table_name)) + + +@pytest.mark.iceberg +@pytest.mark.slow +class TestIcebergPerformance: + """Iceberg performance tests""" + + def test_large_data_loading(self, iceberg_basic_config): + """Test loading large datasets to Iceberg""" + import pyarrow as pa + + # Create large dataset + large_data = { + 'id': list(range(50000)), + 'value': [i * 0.123 for i in range(50000)], + 'category': [f'cat_{i % 100}' for i in range(50000)], + } + large_table = pa.Table.from_pydict(large_data) + + loader = IcebergLoader(iceberg_basic_config) + table_name = 'test_large' + + with loader: + result = loader.load_table(large_table, table_name) + + assert result.success == True + assert result.rows_loaded == 50000 + assert result.duration < 60 # Should complete within 60 seconds + + # Verify data integrity + catalog = loader._catalog + table = catalog.load_table((loader.config.namespace, table_name)) + df = table.scan().to_arrow() + assert len(df) == 50000 + + # Cleanup + catalog.drop_table((loader.config.namespace, table_name)) diff --git a/tests/integration/loaders/backends/test_lmdb.py b/tests/integration/loaders/backends/test_lmdb.py new file mode 100644 index 0000000..680ce6f --- /dev/null +++ b/tests/integration/loaders/backends/test_lmdb.py @@ -0,0 +1,360 @@ +""" +LMDB-specific loader integration tests. + +This module provides LMDB-specific test configuration and tests that +inherit from the generalized base test classes. 
+""" + +from pathlib import Path +from typing import Any, Dict, List, Optional + +import pytest + +try: + from src.amp.loaders.implementations.lmdb_loader import LMDBLoader + from tests.integration.loaders.conftest import LoaderTestConfig + from tests.integration.loaders.test_base_loader import BaseLoaderTests + from tests.integration.loaders.test_base_streaming import BaseStreamingTests +except ImportError: + pytest.skip('amp modules not available', allow_module_level=True) + + +class LMDBTestConfig(LoaderTestConfig): + """LMDB-specific test configuration""" + + loader_class = LMDBLoader + config_fixture_name = 'lmdb_config' # LMDB uses config fixture + + supports_overwrite = True + supports_streaming = True + supports_multi_network = True + supports_null_values = True + requires_existing_table = False # LMDB auto-creates databases + + def get_row_count(self, loader: LMDBLoader, table_name: str) -> int: + """Get row count from LMDB database""" + count = 0 + db = loader._get_or_create_db(getattr(loader.config, 'database_name', None)) + with loader.env.begin(db=db) as txn: + cursor = txn.cursor() + for _key, _value in cursor: + count += 1 + return count + + def query_rows( + self, loader: LMDBLoader, table_name: str, where: Optional[str] = None, order_by: Optional[str] = None + ) -> List[Dict[str, Any]]: + """Query rows from LMDB database""" + import json + + rows = [] + db = loader._get_or_create_db(getattr(loader.config, 'database_name', None)) + with loader.env.begin(db=db) as txn: + cursor = txn.cursor() + for i, (key, value) in enumerate(cursor): + if i >= 100: # Limit to 100 + break + try: + # Try to decode as JSON + row_data = json.loads(value.decode()) + row_data['_key'] = key.decode() + rows.append(row_data) + except Exception: + # Fallback to raw value + rows.append({'_key': key.decode(), '_value': value.decode()}) + return rows + + def cleanup_table(self, loader: LMDBLoader, table_name: str) -> None: + """Clear LMDB database""" + # LMDB doesn't have tables - clear the entire database + db = loader._get_or_create_db(getattr(loader.config, 'database_name', None)) + with loader.env.begin(db=db, write=True) as txn: + cursor = txn.cursor() + # Delete all keys + keys_to_delete = [key for key, _ in cursor] + for key in keys_to_delete: + txn.delete(key) + + def get_column_names(self, loader: LMDBLoader, table_name: str) -> List[str]: + """Get column names from LMDB database (from first record)""" + import pyarrow as pa + + db = loader._get_or_create_db(getattr(loader.config, 'database_name', None)) + with loader.env.begin(db=db) as txn: + cursor = txn.cursor() + for _key, value in cursor: + try: + # Deserialize Arrow IPC format (LMDB stores Arrow batches, not JSON) + reader = pa.ipc.open_stream(value) + batch = reader.read_next_batch() + return batch.schema.names # Returns all column names including metadata + except Exception: + return ['_value'] # Fallback + return [] + + +@pytest.fixture +def lmdb_test_env(): + """Create and cleanup temporary directory for LMDB databases""" + import shutil + import tempfile + + temp_dir = tempfile.mkdtemp(prefix='lmdb_test_') + yield temp_dir + + # Cleanup + shutil.rmtree(temp_dir, ignore_errors=True) + + +@pytest.fixture +def lmdb_config(lmdb_test_env): + """Create LMDB config from test env directory""" + return { + 'db_path': str(Path(lmdb_test_env) / 'test.lmdb'), + 'map_size': 100 * 1024**2, # 100MB + 'transaction_size': 1000, + 'create_if_missing': True, + } + + +@pytest.fixture +def lmdb_perf_config(lmdb_test_env): + """LMDB configuration for 
performance testing""" + return { + 'db_path': str(Path(lmdb_test_env) / 'perf.lmdb'), + 'map_size': 500 * 1024**2, # 500MB for performance tests + 'transaction_size': 5000, + 'create_if_missing': True, + } + + +@pytest.mark.lmdb +class TestLMDBCore(BaseLoaderTests): + """LMDB core loader tests (inherited from base)""" + + # Note: LMDB config needs special handling + config = LMDBTestConfig() + + +@pytest.mark.lmdb +class TestLMDBStreaming(BaseStreamingTests): + """LMDB streaming tests (inherited from base)""" + + config = LMDBTestConfig() + + +@pytest.mark.lmdb +class TestLMDBSpecific: + """LMDB-specific tests that cannot be generalized""" + + def test_key_column_strategy(self, lmdb_config, small_test_data): + """Test LMDB key column strategy""" + config = {**lmdb_config, 'key_column': 'id'} + loader = LMDBLoader(config) + + with loader: + result = loader.load_table(small_test_data, 'test_table') + assert result.success == True + assert result.rows_loaded == 5 + + # Verify keys are based on id column + db = loader._get_or_create_db(getattr(loader.config, 'database_name', None)) + with loader.env.begin(db=db) as txn: + cursor = txn.cursor() + keys = [key.decode() for key, _ in cursor] + + # Keys should be string representations of IDs + assert len(keys) == 5 + # Keys may be prefixed with table name + + def test_key_pattern_strategy(self, lmdb_config): + """Test custom key pattern generation""" + import pyarrow as pa + + data = {'tx_hash': ['0x100', '0x101', '0x102'], 'block': [100, 101, 102], 'value': [10.0, 11.0, 12.0]} + test_data = pa.Table.from_pydict(data) + + config = {**lmdb_config, 'key_column': 'tx_hash', 'key_pattern': 'tx:{key}'} + loader = LMDBLoader(config) + + with loader: + result = loader.load_table(test_data, 'test_table') + assert result.success == True + + # Verify custom key pattern + db = loader._get_or_create_db(getattr(loader.config, 'database_name', None)) + with loader.env.begin(db=db) as txn: + cursor = txn.cursor() + keys = [key.decode() for key, _ in cursor] + + # Keys should follow pattern + assert len(keys) == 3 + # At least one key should contain the pattern + assert any('tx:' in key or '0x' in key for key in keys) + + def test_composite_key_strategy(self, lmdb_config): + """Test composite key generation""" + import pyarrow as pa + + data = {'network': ['eth', 'eth', 'poly'], 'block': [100, 101, 100], 'tx_index': [0, 0, 0]} + test_data = pa.Table.from_pydict(data) + + config = {**lmdb_config, 'composite_key_columns': ['network', 'block', 'tx_index']} + loader = LMDBLoader(config) + + with loader: + result = loader.load_table(test_data, 'test_table') + assert result.success == True + assert result.rows_loaded == 3 + + # Verify composite keys + db = loader._get_or_create_db(getattr(loader.config, 'database_name', None)) + with loader.env.begin(db=db) as txn: + cursor = txn.cursor() + keys = [key.decode() for key, _ in cursor] + assert len(keys) == 3 + + def test_named_database(self, lmdb_config): + """Test LMDB named databases (sub-databases)""" + import pyarrow as pa + + data = {'id': [1, 2, 3], 'value': [100, 200, 300]} + test_data = pa.Table.from_pydict(data) + + config = {**lmdb_config, 'database_name': 'my_table'} + loader = LMDBLoader(config) + + with loader: + result = loader.load_table(test_data, 'test_table') + assert result.success == True + assert result.rows_loaded == 3 + + # Verify data in named database + db = loader._get_or_create_db(getattr(loader.config, 'database_name', None)) + with loader.env.begin(db=db) as txn: + cursor = txn.cursor() + 
count = sum(1 for _ in cursor) + assert count == 3 + + def test_transaction_batching(self, lmdb_test_env): + """Test transaction batching for large datasets""" + import pyarrow as pa + + # Create large dataset + large_data = {'id': list(range(5000)), 'value': [i * 10 for i in range(5000)]} + large_table = pa.Table.from_pydict(large_data) + + config = { + 'db_path': str(Path(lmdb_test_env) / 'batch_test.lmdb'), + 'map_size': 100 * 1024**2, + 'transaction_size': 500, # Batch every 500 rows + 'create_if_missing': True, + 'key_column': 'id', + } + loader = LMDBLoader(config) + + with loader: + result = loader.load_table(large_table, 'test_table') + assert result.success == True + assert result.rows_loaded == 5000 + + # Verify all data was loaded + db = loader._get_or_create_db(getattr(loader.config, 'database_name', None)) + with loader.env.begin(db=db) as txn: + cursor = txn.cursor() + count = sum(1 for _ in cursor) + assert count == 5000 + + def test_byte_key_handling(self, lmdb_config): + """Test handling of byte keys""" + import pyarrow as pa + + data = {'key': [b'key1', b'key2', b'key3'], 'value': [100, 200, 300]} + test_data = pa.Table.from_pydict(data) + + config = {**lmdb_config, 'key_column': 'key'} + loader = LMDBLoader(config) + + with loader: + result = loader.load_table(test_data, 'test_table') + assert result.success == True + assert result.rows_loaded == 3 + + def test_data_persistence(self, lmdb_config, small_test_data): + """Test that data persists after closing and reopening""" + import pyarrow as pa + + from src.amp.loaders.base import LoadMode + + # Load data + loader1 = LMDBLoader(lmdb_config) + with loader1: + result = loader1.load_table(small_test_data, 'test_table') + assert result.success == True + + # Close and reopen + loader2 = LMDBLoader(lmdb_config) + with loader2: + # Data should still be there + db = loader2._get_or_create_db(getattr(loader2.config, 'database_name', None)) + with loader2.env.begin(db=db) as txn: + cursor = txn.cursor() + count = sum(1 for _ in cursor) + assert count == 5 + + # Can append more (use different IDs to avoid key conflicts in key-value store) + additional_data = pa.Table.from_pydict( + { + 'id': [6, 7, 8, 9, 10], + 'name': ['f', 'g', 'h', 'i', 'j'], + 'value': [60.6, 70.7, 80.8, 90.9, 100.0], + 'year': [2024, 2024, 2024, 2024, 2024], + 'month': [1, 1, 1, 1, 1], + 'day': [6, 7, 8, 9, 10], + 'active': [False, True, False, True, False], + } + ) + result2 = loader2.load_table(additional_data, 'test_table', mode=LoadMode.APPEND) + assert result2.success == True + + # Now should have 10 + db = loader2._get_or_create_db(getattr(loader2.config, 'database_name', None)) + with loader2.env.begin(db=db) as txn: + cursor = txn.cursor() + count = sum(1 for _ in cursor) + assert count == 10 + + +@pytest.mark.lmdb +@pytest.mark.slow +class TestLMDBPerformance: + """LMDB performance tests""" + + def test_large_data_loading(self, lmdb_perf_config): + """Test loading large datasets to LMDB""" + import pyarrow as pa + + # Create large dataset + large_data = { + 'id': list(range(50000)), + 'value': [i * 0.123 for i in range(50000)], + 'category': [f'cat_{i % 100}' for i in range(50000)], + } + large_table = pa.Table.from_pydict(large_data) + + config = {**lmdb_perf_config, 'key_column': 'id'} + loader = LMDBLoader(config) + + with loader: + result = loader.load_table(large_table, 'test_table') + + assert result.success == True + assert result.rows_loaded == 50000 + assert result.duration < 60 # Should complete within 60 seconds + + # Verify data integrity 
+ db = loader._get_or_create_db(getattr(loader.config, 'database_name', None)) + with loader.env.begin(db=db) as txn: + cursor = txn.cursor() + count = sum(1 for _ in cursor) + assert count == 50000 diff --git a/tests/integration/loaders/backends/test_postgresql.py b/tests/integration/loaders/backends/test_postgresql.py new file mode 100644 index 0000000..765223d --- /dev/null +++ b/tests/integration/loaders/backends/test_postgresql.py @@ -0,0 +1,354 @@ +""" +PostgreSQL-specific loader integration tests. + +This module provides PostgreSQL-specific test configuration and tests that +inherit from the generalized base test classes. +""" + +import time +from typing import Any, Dict, List, Optional + +import pytest + +try: + from src.amp.loaders.implementations.postgresql_loader import PostgreSQLLoader + from tests.integration.loaders.conftest import LoaderTestConfig + from tests.integration.loaders.test_base_loader import BaseLoaderTests + from tests.integration.loaders.test_base_streaming import BaseStreamingTests +except ImportError: + pytest.skip('amp modules not available', allow_module_level=True) + + +class PostgreSQLTestConfig(LoaderTestConfig): + """PostgreSQL-specific test configuration""" + + loader_class = PostgreSQLLoader + config_fixture_name = 'postgresql_test_config' + + supports_overwrite = True + supports_streaming = True + supports_multi_network = True + supports_null_values = True + + def get_row_count(self, loader: PostgreSQLLoader, table_name: str) -> int: + """Get row count from PostgreSQL table""" + conn = loader.pool.getconn() + try: + with conn.cursor() as cur: + cur.execute(f'SELECT COUNT(*) FROM {table_name}') + return cur.fetchone()[0] + finally: + loader.pool.putconn(conn) + + def query_rows( + self, loader: PostgreSQLLoader, table_name: str, where: Optional[str] = None, order_by: Optional[str] = None + ) -> List[Dict[str, Any]]: + """Query rows from PostgreSQL table""" + conn = loader.pool.getconn() + try: + with conn.cursor() as cur: + # Get column names first + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_name = %s + ORDER BY ordinal_position + """, + (table_name,), + ) + columns = [row[0] for row in cur.fetchall()] + + # Build query + query = f'SELECT * FROM {table_name}' + if where: + query += f' WHERE {where}' + if order_by: + query += f' ORDER BY {order_by}' + + cur.execute(query) + rows = cur.fetchall() + + # Convert to list of dicts + return [dict(zip(columns, row, strict=False)) for row in rows] + finally: + loader.pool.putconn(conn) + + def cleanup_table(self, loader: PostgreSQLLoader, table_name: str) -> None: + """Drop PostgreSQL table""" + conn = loader.pool.getconn() + try: + with conn.cursor() as cur: + cur.execute(f'DROP TABLE IF EXISTS {table_name} CASCADE') + conn.commit() + finally: + loader.pool.putconn(conn) + + def get_column_names(self, loader: PostgreSQLLoader, table_name: str) -> List[str]: + """Get column names from PostgreSQL table""" + conn = loader.pool.getconn() + try: + with conn.cursor() as cur: + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_name = %s + ORDER BY ordinal_position + """, + (table_name,), + ) + return [row[0] for row in cur.fetchall()] + finally: + loader.pool.putconn(conn) + + +@pytest.mark.postgresql +class TestPostgreSQLCore(BaseLoaderTests): + """PostgreSQL core loader tests (inherited from base)""" + + config = PostgreSQLTestConfig() + + +@pytest.mark.postgresql +class TestPostgreSQLStreaming(BaseStreamingTests): + """PostgreSQL 
streaming tests (inherited from base)""" + + config = PostgreSQLTestConfig() + + +@pytest.fixture +def cleanup_tables(postgresql_test_config): + """Cleanup test tables after tests""" + tables_to_clean = [] + + yield tables_to_clean + + # Cleanup + loader = PostgreSQLLoader(postgresql_test_config) + try: + loader.connect() + conn = loader.pool.getconn() + try: + with conn.cursor() as cur: + for table in tables_to_clean: + try: + cur.execute(f'DROP TABLE IF EXISTS {table} CASCADE') + conn.commit() + except Exception: + pass + finally: + loader.pool.putconn(conn) + loader.disconnect() + except Exception: + pass + + +@pytest.mark.postgresql +class TestPostgreSQLSpecific: + """PostgreSQL-specific tests that cannot be generalized""" + + def test_connection_pooling(self, postgresql_test_config, small_test_data, test_table_name, cleanup_tables): + """Test PostgreSQL connection pooling behavior""" + from src.amp.loaders.base import LoadMode + + cleanup_tables.append(test_table_name) + + loader = PostgreSQLLoader(postgresql_test_config) + + with loader: + # Perform multiple operations to test pool reuse + for i in range(5): + subset = small_test_data.slice(i, 1) + mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND + + result = loader.load_table(subset, test_table_name, mode=mode) + assert result.success == True + + # Verify pool is managing connections properly + # Note: _used is a dict in ThreadedConnectionPool, not an int + assert len(loader.pool._used) <= loader.pool.maxconn + + def test_binary_data_handling(self, postgresql_test_config, test_table_name, cleanup_tables): + """Test binary data handling with INSERT fallback""" + import pyarrow as pa + + cleanup_tables.append(test_table_name) + + # Create data with binary columns + data = {'id': [1, 2, 3], 'binary_data': [b'hello', b'world', b'test'], 'text_data': ['a', 'b', 'c']} + table = pa.Table.from_pydict(data) + + loader = PostgreSQLLoader(postgresql_test_config) + + with loader: + result = loader.load_table(table, test_table_name) + assert result.success == True + assert result.rows_loaded == 3 + + # Verify binary data was stored correctly + conn = loader.pool.getconn() + try: + with conn.cursor() as cur: + cur.execute(f'SELECT id, binary_data FROM {test_table_name} ORDER BY id') + rows = cur.fetchall() + assert rows[0][1].tobytes() == b'hello' + assert rows[1][1].tobytes() == b'world' + assert rows[2][1].tobytes() == b'test' + finally: + loader.pool.putconn(conn) + + def test_schema_retrieval(self, postgresql_test_config, small_test_data, test_table_name, cleanup_tables): + """Test schema retrieval functionality""" + cleanup_tables.append(test_table_name) + + loader = PostgreSQLLoader(postgresql_test_config) + + with loader: + # Create table + result = loader.load_table(small_test_data, test_table_name) + assert result.success == True + + # Get schema + schema = loader.get_table_schema(test_table_name) + assert schema is not None + + # Filter out metadata columns added by PostgreSQL loader + non_meta_fields = [ + field for field in schema if not (field.name.startswith('_meta_') or field.name.startswith('_amp_')) + ] + + assert len(non_meta_fields) == len(small_test_data.schema) + + # Verify column names match (excluding metadata columns) + original_names = set(small_test_data.schema.names) + retrieved_names = set(field.name for field in non_meta_fields) + assert original_names == retrieved_names + + def test_performance_metrics(self, postgresql_test_config, medium_test_table, test_table_name, cleanup_tables): + """Test performance 
metrics in results""" + cleanup_tables.append(test_table_name) + + loader = PostgreSQLLoader(postgresql_test_config) + + with loader: + start_time = time.time() + result = loader.load_table(medium_test_table, test_table_name) + end_time = time.time() + + assert result.success == True + assert result.duration > 0 + assert result.duration <= (end_time - start_time) + assert result.rows_loaded == 10000 + + # Check metadata contains performance info + assert 'table_size_bytes' in result.metadata + assert result.metadata['table_size_bytes'] > 0 + + def test_null_value_handling_detailed( + self, postgresql_test_config, null_test_data, test_table_name, cleanup_tables + ): + """Test comprehensive null value handling across all PostgreSQL data types""" + cleanup_tables.append(test_table_name) + + loader = PostgreSQLLoader(postgresql_test_config) + + with loader: + result = loader.load_table(null_test_data, test_table_name) + assert result.success == True + assert result.rows_loaded == 10 + + conn = loader.pool.getconn() + try: + with conn.cursor() as cur: + # Check text field nulls (rows 3, 6, 9 have index 2, 5, 8) + cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE text_field IS NULL') + text_nulls = cur.fetchone()[0] + assert text_nulls == 3 + + # Check int field nulls (rows 2, 5, 8 have index 1, 4, 7) + cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE int_field IS NULL') + int_nulls = cur.fetchone()[0] + assert int_nulls == 3 + + # Check float field nulls (rows 3, 6, 9 have index 2, 5, 8) + cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE float_field IS NULL') + float_nulls = cur.fetchone()[0] + assert float_nulls == 3 + + # Check bool field nulls (rows 3, 6, 9 have index 2, 5, 8) + cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE bool_field IS NULL') + bool_nulls = cur.fetchone()[0] + assert bool_nulls == 3 + + # Check timestamp field nulls + cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE timestamp_field IS NULL') + timestamp_nulls = cur.fetchone()[0] + assert timestamp_nulls == 4 + + # Check json field nulls + cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE json_field IS NULL') + json_nulls = cur.fetchone()[0] + assert json_nulls == 3 + + # Verify non-null values are intact + cur.execute(f'SELECT text_field FROM {test_table_name} WHERE id = 1') + text_val = cur.fetchone()[0] + assert text_val in ['a', '"a"'] # Handle potential CSV quoting + + cur.execute(f'SELECT int_field FROM {test_table_name} WHERE id = 1') + int_val = cur.fetchone()[0] + assert int_val == 1 + + cur.execute(f'SELECT float_field FROM {test_table_name} WHERE id = 1') + float_val = cur.fetchone()[0] + assert abs(float_val - 1.1) < 0.01 + + cur.execute(f'SELECT bool_field FROM {test_table_name} WHERE id = 1') + bool_val = cur.fetchone()[0] + assert bool_val == True + finally: + loader.pool.putconn(conn) + + +@pytest.mark.postgresql +@pytest.mark.slow +class TestPostgreSQLPerformance: + """PostgreSQL performance tests""" + + def test_large_data_loading(self, postgresql_test_config, test_table_name, cleanup_tables): + """Test loading large datasets""" + from datetime import datetime + + import pyarrow as pa + + cleanup_tables.append(test_table_name) + + # Create large dataset + large_data = { + 'id': list(range(50000)), + 'value': [i * 0.123 for i in range(50000)], + 'category': [f'category_{i % 100}' for i in range(50000)], + 'description': [f'This is a longer text description for row {i}' for i in range(50000)], + 'created_at': [datetime.now() for _ in range(50000)], 
+ } + large_table = pa.Table.from_pydict(large_data) + + loader = PostgreSQLLoader(postgresql_test_config) + + with loader: + result = loader.load_table(large_table, test_table_name) + + assert result.success == True + assert result.rows_loaded == 50000 + assert result.duration < 60 # Should complete within 60 seconds + + # Verify data integrity + conn = loader.pool.getconn() + try: + with conn.cursor() as cur: + cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') + count = cur.fetchone()[0] + assert count == 50000 + finally: + loader.pool.putconn(conn) diff --git a/tests/integration/loaders/backends/test_redis.py b/tests/integration/loaders/backends/test_redis.py new file mode 100644 index 0000000..f892034 --- /dev/null +++ b/tests/integration/loaders/backends/test_redis.py @@ -0,0 +1,389 @@ +""" +Redis-specific loader integration tests. + +This module provides Redis-specific test configuration and tests that +inherit from the generalized base test classes. +""" + +import json +from typing import Any, Dict, List, Optional + +import pytest + +try: + import redis + + from src.amp.loaders.implementations.redis_loader import RedisLoader + from tests.integration.loaders.conftest import LoaderTestConfig + from tests.integration.loaders.test_base_loader import BaseLoaderTests + from tests.integration.loaders.test_base_streaming import BaseStreamingTests +except ImportError: + pytest.skip('amp modules not available', allow_module_level=True) + + +class RedisTestConfig(LoaderTestConfig): + """Redis-specific test configuration""" + + loader_class = RedisLoader + + supports_overwrite = True + supports_streaming = True + supports_multi_network = True + supports_null_values = True + requires_existing_table = False # Redis auto-creates keys/structures + + def __init__(self, config_fixture_name='redis_test_config'): + """ + Initialize Redis test config. 
+ + Args: + config_fixture_name: Name of the pytest fixture providing loader config + (default: 'redis_test_config' for core tests, + use 'redis_streaming_config' for streaming tests) + """ + self.config_fixture_name = config_fixture_name + + def get_row_count(self, loader: RedisLoader, table_name: str) -> int: + """Get row count from Redis based on data structure""" + data_structure = getattr(loader.config, 'data_structure', 'hash') + + if data_structure == 'hash': + # Count hash keys matching pattern + pattern = f'{table_name}:*' + count = 0 + for _ in loader.redis_client.scan_iter(match=pattern, count=1000): + count += 1 + return count + elif data_structure == 'string': + # Count string keys matching pattern + pattern = f'{table_name}:*' + count = 0 + for _ in loader.redis_client.scan_iter(match=pattern, count=1000): + count += 1 + return count + elif data_structure == 'stream': + # Get stream length + try: + return loader.redis_client.xlen(table_name) + except Exception: + return 0 + else: + # For other structures, scan for keys + pattern = f'{table_name}:*' + count = 0 + for _ in loader.redis_client.scan_iter(match=pattern, count=1000): + count += 1 + return count + + def query_rows( + self, loader: RedisLoader, table_name: str, where: Optional[str] = None, order_by: Optional[str] = None + ) -> List[Dict[str, Any]]: + """Query rows from Redis - limited functionality due to Redis architecture""" + # Redis doesn't support SQL-like queries, so this is a simplified implementation + data_structure = getattr(loader.config, 'data_structure', 'hash') + rows = [] + + if data_structure == 'hash': + pattern = f'{table_name}:*' + for key in loader.redis_client.scan_iter(match=pattern, count=100): + data = loader.redis_client.hgetall(key) + row = {k.decode(): v.decode() if isinstance(v, bytes) else v for k, v in data.items()} + rows.append(row) + elif data_structure == 'stream': + # Read from stream + messages = loader.redis_client.xrange(table_name, count=100) + for msg_id, data in messages: + row = {k.decode(): v.decode() if isinstance(v, bytes) else v for k, v in data.items()} + row['_stream_id'] = msg_id.decode() + rows.append(row) + + return rows[:100] # Limit to 100 rows + + def cleanup_table(self, loader: RedisLoader, table_name: str) -> None: + """Drop Redis keys for table""" + # Delete all keys matching the table pattern + patterns = [f'{table_name}:*', table_name] + + for pattern in patterns: + for key in loader.redis_client.scan_iter(match=pattern, count=1000): + loader.redis_client.delete(key) + + def get_column_names(self, loader: RedisLoader, table_name: str) -> List[str]: + """Get column names from Redis - return fields from first record""" + data_structure = getattr(loader.config, 'data_structure', 'hash') + + if data_structure == 'hash': + pattern = f'{table_name}:*' + for key in loader.redis_client.scan_iter(match=pattern, count=1): + data = loader.redis_client.hgetall(key) + return [k.decode() if isinstance(k, bytes) else k for k in data.keys()] + elif data_structure == 'stream': + messages = loader.redis_client.xrange(table_name, count=1) + if messages: + _, data = messages[0] + return [k.decode() if isinstance(k, bytes) else k for k in data.keys()] + + return [] + + +@pytest.fixture +def cleanup_redis(redis_test_config): + """Cleanup Redis data after tests""" + keys_to_clean = [] + patterns_to_clean = [] + + yield (keys_to_clean, patterns_to_clean) + + # Cleanup + try: + r = redis.Redis( + host=redis_test_config['host'], + port=redis_test_config['port'], + 
db=redis_test_config['db'], + password=redis_test_config['password'], + ) + + # Delete specific keys + for key in keys_to_clean: + r.delete(key) + + # Delete keys matching patterns + for pattern in patterns_to_clean: + for key in r.scan_iter(match=pattern, count=1000): + r.delete(key) + + r.close() + except Exception: + pass + + +@pytest.mark.redis +class TestRedisCore(BaseLoaderTests): + """Redis core loader tests (inherited from base)""" + + config = RedisTestConfig() + + +@pytest.mark.redis +class TestRedisStreaming(BaseStreamingTests): + """Redis streaming tests (inherited from base)""" + + config = RedisTestConfig('redis_streaming_config') + + +@pytest.mark.redis +class TestRedisSpecific: + """Redis-specific tests that cannot be generalized""" + + def test_hash_storage(self, redis_test_config, small_test_data, cleanup_redis): + """Test Redis hash data structure storage""" + + keys_to_clean, patterns_to_clean = cleanup_redis + table_name = 'test_hash' + patterns_to_clean.append(f'{table_name}:*') + + config = {**redis_test_config, 'data_structure': 'hash'} # Uses default key_pattern {table}:{id} + loader = RedisLoader(config) + + with loader: + result = loader.load_table(small_test_data, table_name) + assert result.success == True + assert result.rows_loaded == 5 + + # Verify data is stored as hashes + pattern = f'{table_name}:*' + keys = list(loader.redis_client.scan_iter(match=pattern, count=100)) + assert len(keys) == 5 + + # Verify hash structure + first_key = keys[0] + data = loader.redis_client.hgetall(first_key) + assert len(data) > 0 + + def test_string_storage(self, redis_test_config, small_test_data, cleanup_redis): + """Test Redis string data structure storage""" + keys_to_clean, patterns_to_clean = cleanup_redis + table_name = 'test_string' + patterns_to_clean.append(f'{table_name}:*') + + config = {**redis_test_config, 'data_structure': 'string'} # Uses default key_pattern {table}:{id} + loader = RedisLoader(config) + + with loader: + result = loader.load_table(small_test_data, table_name) + assert result.success == True + assert result.rows_loaded == 5 + + # Verify data is stored as strings + pattern = f'{table_name}:*' + keys = list(loader.redis_client.scan_iter(match=pattern, count=100)) + assert len(keys) == 5 + + # Verify string structure (should be JSON) + first_key = keys[0] + data = loader.redis_client.get(first_key) + assert data is not None + # Should be JSON-encoded + json.loads(data) + + def test_stream_storage(self, redis_test_config, small_test_data, cleanup_redis): + """Test Redis stream data structure storage""" + keys_to_clean, patterns_to_clean = cleanup_redis + table_name = 'test_stream' + stream_key = f'{table_name}:stream' # Stream keys use table_name:stream format + keys_to_clean.append(stream_key) + + config = {**redis_test_config, 'data_structure': 'stream'} + loader = RedisLoader(config) + + with loader: + result = loader.load_table(small_test_data, table_name) + assert result.success == True + assert result.rows_loaded == 5 + + # Verify data is in stream (Redis stream key format is table_name:stream) + stream_len = loader.redis_client.xlen(stream_key) + assert stream_len == 5 + + def test_set_storage(self, redis_test_config, small_test_data, cleanup_redis): + """Test Redis set data structure storage""" + keys_to_clean, patterns_to_clean = cleanup_redis + table_name = 'test_set' + patterns_to_clean.append(f'{table_name}:*') + + config = {**redis_test_config, 'data_structure': 'set'} # Uses default key_pattern {table}:{id} + loader = 
RedisLoader(config) + + with loader: + result = loader.load_table(small_test_data, table_name) + assert result.success == True + assert result.rows_loaded == 5 + + def test_ttl_functionality(self, redis_test_config, small_test_data, cleanup_redis): + """Test TTL (time-to-live) functionality""" + import time + + keys_to_clean, patterns_to_clean = cleanup_redis + table_name = 'test_ttl' + patterns_to_clean.append(f'{table_name}:*') + + config = {**redis_test_config, 'data_structure': 'hash', 'ttl': 2} # 2 second TTL, uses default key_pattern + loader = RedisLoader(config) + + with loader: + result = loader.load_table(small_test_data, table_name) + assert result.success == True + + # Verify data exists + pattern = f'{table_name}:*' + keys_before = list(loader.redis_client.scan_iter(match=pattern, count=100)) + assert len(keys_before) == 5 + + # Verify TTL is set + ttl = loader.redis_client.ttl(keys_before[0]) + assert ttl > 0 and ttl <= 2 + + # Wait for TTL to expire + time.sleep(3) + + # Verify data has expired + keys_after = list(loader.redis_client.scan_iter(match=pattern, count=100)) + assert len(keys_after) == 0 + + def test_key_pattern_generation(self, redis_test_config, cleanup_redis): + """Test custom key pattern generation""" + import pyarrow as pa + + keys_to_clean, patterns_to_clean = cleanup_redis + table_name = 'test_key_pattern' + patterns_to_clean.append(f'{table_name}:*') + + data = {'user_id': ['user1', 'user2', 'user3'], 'score': [100, 200, 300], 'level': [1, 2, 3]} + test_data = pa.Table.from_pydict(data) + + config = { + **redis_test_config, + 'data_structure': 'hash', + 'key_pattern': '{table}:user:{user_id}', # Custom key pattern using user_id field + } + loader = RedisLoader(config) + + with loader: + result = loader.load_table(test_data, table_name) + assert result.success == True + + # Verify custom key pattern + pattern = f'{table_name}:user:*' + keys = list(loader.redis_client.scan_iter(match=pattern, count=100)) + assert len(keys) == 3 + + # Verify key format + key_str = keys[0].decode() if isinstance(keys[0], bytes) else keys[0] + assert key_str.startswith(f'{table_name}:user:') + + def test_data_structure_comparison(self, redis_test_config, comprehensive_test_data, cleanup_redis): + """Test performance comparison between different data structures""" + import time + + keys_to_clean, patterns_to_clean = cleanup_redis + + structures = ['hash', 'string', 'stream'] + results = {} + + for structure in structures: + table_name = f'test_perf_{structure}' + patterns_to_clean.append(f'{table_name}:*') + keys_to_clean.append(table_name) + + config = {**redis_test_config, 'data_structure': structure} + # Hash and string structures use default key_pattern {table}:{id} + + loader = RedisLoader(config) + + with loader: + start_time = time.time() + result = loader.load_table(comprehensive_test_data, table_name) + duration = time.time() - start_time + + results[structure] = { + 'success': result.success, + 'duration': duration, + 'rows_loaded': result.rows_loaded, + } + + # Verify all structures work + for _structure, data in results.items(): + assert data['success'] == True + assert data['rows_loaded'] == 1000 + + +@pytest.mark.redis +@pytest.mark.slow +class TestRedisPerformance: + """Redis performance tests""" + + def test_large_data_loading(self, redis_test_config, cleanup_redis): + """Test loading large datasets to Redis""" + import pyarrow as pa + + keys_to_clean, patterns_to_clean = cleanup_redis + table_name = 'test_large' + patterns_to_clean.append(f'{table_name}:*') + + 
# Create large dataset + large_data = { + 'id': list(range(10000)), + 'value': [i * 0.123 for i in range(10000)], + 'category': [f'cat_{i % 100}' for i in range(10000)], + } + large_table = pa.Table.from_pydict(large_data) + + config = {**redis_test_config, 'data_structure': 'hash'} # Uses default key_pattern {table}:{id} + loader = RedisLoader(config) + + with loader: + result = loader.load_table(large_table, table_name) + + assert result.success == True + assert result.rows_loaded == 10000 + assert result.duration < 60 # Should complete within 60 seconds diff --git a/tests/integration/loaders/backends/test_snowflake.py b/tests/integration/loaders/backends/test_snowflake.py new file mode 100644 index 0000000..820e923 --- /dev/null +++ b/tests/integration/loaders/backends/test_snowflake.py @@ -0,0 +1,306 @@ +""" +Snowflake-specific loader integration tests. + +This module provides Snowflake-specific test configuration and tests that +inherit from the generalized base test classes. + +Note: Snowflake tests require valid Snowflake credentials and are typically +skipped in CI/CD. Run manually with: pytest -m snowflake +""" + +from typing import Any, Dict, List, Optional + +import pytest + +try: + from src.amp.loaders.implementations.snowflake_loader import SnowflakeLoader + from tests.integration.loaders.conftest import LoaderTestConfig + from tests.integration.loaders.test_base_loader import BaseLoaderTests + from tests.integration.loaders.test_base_streaming import BaseStreamingTests +except ImportError: + pytest.skip('amp modules not available', allow_module_level=True) + + +class SnowflakeTestConfig(LoaderTestConfig): + """Snowflake-specific test configuration""" + + loader_class = SnowflakeLoader + config_fixture_name = 'snowflake_config' + + supports_overwrite = False # Snowflake doesn't support OVERWRITE mode + supports_streaming = True + supports_multi_network = True + supports_null_values = True + requires_existing_table = False # Snowflake auto-creates tables + + def get_row_count(self, loader: SnowflakeLoader, table_name: str) -> int: + """Get row count from Snowflake table""" + with loader.connection.cursor() as cur: + cur.execute(f'SELECT COUNT(*) FROM {table_name}') + return cur.fetchone()[0] + + def query_rows( + self, loader: SnowflakeLoader, table_name: str, where: Optional[str] = None, order_by: Optional[str] = None + ) -> List[Dict[str, Any]]: + """Query rows from Snowflake table""" + query = f'SELECT * FROM {table_name}' + if where: + query += f' WHERE {where}' + if order_by: + query += f' ORDER BY {order_by}' + query += ' LIMIT 100' + + with loader.connection.cursor() as cur: + cur.execute(query) + columns = [col[0] for col in cur.description] + rows = cur.fetchall() + return [dict(zip(columns, row, strict=False)) for row in rows] + + def cleanup_table(self, loader: SnowflakeLoader, table_name: str) -> None: + """Drop Snowflake table""" + with loader.connection.cursor() as cur: + cur.execute(f'DROP TABLE IF EXISTS {table_name}') + + def get_column_names(self, loader: SnowflakeLoader, table_name: str) -> List[str]: + """Get column names from Snowflake table""" + with loader.connection.cursor() as cur: + cur.execute(f'SELECT * FROM {table_name} LIMIT 0') + return [col[0] for col in cur.description] + + +@pytest.mark.snowflake +class TestSnowflakeCore(BaseLoaderTests): + """Snowflake core loader tests (inherited from base)""" + + config = SnowflakeTestConfig() + + +@pytest.mark.snowflake +class TestSnowflakeStreaming(BaseStreamingTests): + """Snowflake streaming tests 
(inherited from base)""" + + config = SnowflakeTestConfig() + + +@pytest.fixture +def cleanup_tables(snowflake_config): + """Cleanup Snowflake tables after tests""" + tables_to_clean = [] + + yield tables_to_clean + + # Cleanup + if tables_to_clean: + try: + from snowflake.connector import connect + + conn = connect(**snowflake_config) + with conn.cursor() as cur: + for table_name in tables_to_clean: + try: + cur.execute(f'DROP TABLE IF EXISTS {table_name}') + except Exception: + pass + conn.close() + except Exception: + pass + + +@pytest.mark.snowflake +class TestSnowflakeSpecific: + """Snowflake-specific tests that cannot be generalized""" + + def test_stage_loading_method(self, snowflake_config, small_test_table, test_table_name, cleanup_tables): + """Test Snowflake stage-based loading (Snowflake-specific optimization)""" + + cleanup_tables.append(test_table_name) + + # Configure for stage loading + config = {**snowflake_config, 'loading_method': 'stage'} + loader = SnowflakeLoader(config) + + with loader: + result = loader.load_table(small_test_table, test_table_name) + assert result.success == True + assert result.rows_loaded == 100 + + # Verify data loaded + with loader.connection.cursor() as cur: + cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') + count = cur.fetchone()[0] + assert count == 100 + + def test_insert_loading_method(self, snowflake_config, small_test_table, test_table_name, cleanup_tables): + """Test Snowflake INSERT-based loading""" + cleanup_tables.append(test_table_name) + + # Configure for INSERT loading + config = {**snowflake_config, 'loading_method': 'insert'} + loader = SnowflakeLoader(config) + + with loader: + result = loader.load_table(small_test_table, test_table_name) + assert result.success == True + assert result.rows_loaded == 100 + + def test_table_info_retrieval(self, snowflake_config, small_test_table, test_table_name, cleanup_tables): + """Test Snowflake table information retrieval""" + cleanup_tables.append(test_table_name) + + loader = SnowflakeLoader(snowflake_config) + + with loader: + # Load data first + loader.load_table(small_test_table, test_table_name) + + # Get table info + info = loader.get_table_info(test_table_name) + + assert info is not None + assert 'row_count' in info + assert 'bytes' in info + assert 'clustering_key' in info + assert info['row_count'] == 100 + + def test_concurrent_batch_loading(self, snowflake_config, medium_test_table, test_table_name, cleanup_tables): + """Test concurrent batch loading to Snowflake""" + from concurrent.futures import ThreadPoolExecutor, as_completed + + from src.amp.loaders.base import LoadMode + + cleanup_tables.append(test_table_name) + + loader = SnowflakeLoader(snowflake_config) + + with loader: + # Split table into batches + batches = medium_test_table.to_batches(max_chunksize=2000) + + # Load batches concurrently + def load_batch_with_mode(batch_tuple): + i, batch = batch_tuple + mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND + return loader.load_batch(batch, test_table_name, mode=mode) + + with ThreadPoolExecutor(max_workers=4) as executor: + futures = [executor.submit(load_batch_with_mode, (i, batch)) for i, batch in enumerate(batches)] + + results = [] + for future in as_completed(futures): + result = future.result() + results.append(result) + + # Verify all batches succeeded + assert all(r.success for r in results) + + # Verify total row count + with loader.connection.cursor() as cur: + cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') + count = cur.fetchone()[0] + 
assert count == 10000 + + def test_schema_special_characters(self, snowflake_config, test_table_name, cleanup_tables): + """Test handling of special characters in column names""" + import pyarrow as pa + + cleanup_tables.append(test_table_name) + + # Create data with special characters in column names + data = { + 'id': [1, 2, 3], + 'user name': ['Alice', 'Bob', 'Charlie'], # Space in name + 'email@address': ['a@ex.com', 'b@ex.com', 'c@ex.com'], # @ symbol + 'data-value': [100, 200, 300], # Hyphen + } + test_data = pa.Table.from_pydict(data) + + loader = SnowflakeLoader(snowflake_config) + + with loader: + result = loader.load_table(test_data, test_table_name) + assert result.success == True + assert result.rows_loaded == 3 + + # Verify columns were properly escaped + with loader.connection.cursor() as cur: + cur.execute(f'SELECT * FROM {test_table_name} LIMIT 1') + columns = [col[0] for col in cur.description] + # Snowflake normalizes column names + assert len(columns) >= 4 + + def test_history_preservation_with_reorg(self, snowflake_config, test_table_name, cleanup_tables): + """Test Snowflake's history preservation during reorg (Time Travel feature)""" + import pyarrow as pa + + from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch + + cleanup_tables.append(test_table_name) + + config = {**snowflake_config, 'preserve_history': True} + loader = SnowflakeLoader(config) + + with loader: + # Load initial data + batch1 = pa.RecordBatch.from_pydict({'tx_hash': ['0x100', '0x101'], 'block_num': [100, 101]}) + + loader._create_table_from_schema(batch1.schema, test_table_name) + + response1 = ResponseBatch.data_batch( + data=batch1, + metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=101, hash='0xaaa')]), + ) + + results = list(loader.load_stream_continuous(iter([response1]), test_table_name)) + assert len(results) == 1 + assert results[0].success + + # Verify initial count + with loader.connection.cursor() as cur: + cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') + initial_count = cur.fetchone()[0] + assert initial_count == 2 + + # Perform reorg (with history preservation, data is soft-deleted) + reorg_response = ResponseBatch.reorg_batch( + invalidation_ranges=[BlockRange(network='ethereum', start=100, end=101)] + ) + reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name)) + assert len(reorg_results) == 1 + + # With history preservation, data may be marked deleted rather than physically removed + # Verify the reorg was processed (exact behavior depends on implementation) + cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') + after_count = cur.fetchone()[0] + # Count should be 0 or data should have _deleted flag + assert after_count >= 0 # Flexible assertion for different implementations + + +@pytest.mark.snowflake +@pytest.mark.slow +class TestSnowflakePerformance: + """Snowflake performance tests""" + + def test_large_batch_loading(self, snowflake_config, performance_test_data, test_table_name, cleanup_tables): + """Test loading large batches to Snowflake""" + cleanup_tables.append(test_table_name) + + loader = SnowflakeLoader(snowflake_config) + + with loader: + result = loader.load_table(performance_test_data, test_table_name) + + assert result.success == True + assert result.rows_loaded == 50000 + assert result.duration < 120 # Should complete within 2 minutes + + # Verify data integrity + with loader.connection.cursor() as cur: + cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') + count = 
cur.fetchone()[0] + assert count == 50000 + + +# Note: Snowpipe Streaming tests (TestSnowpipeStreamingIntegration) are kept +# in the original test file as they are highly Snowflake-specific and involve +# complex channel management that doesn't generalize well. See: +# tests/integration/test_snowflake_loader.py::TestSnowpipeStreamingIntegration diff --git a/tests/integration/loaders/conftest.py b/tests/integration/loaders/conftest.py new file mode 100644 index 0000000..86bf650 --- /dev/null +++ b/tests/integration/loaders/conftest.py @@ -0,0 +1,177 @@ +""" +Base test classes and configuration for loader integration tests. + +This module provides abstract base classes that enable test generalization +across different loader implementations while maintaining full test coverage. +""" + +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Dict, List, Optional, Type + +import pytest + +from src.amp.loaders.base import DataLoader + + +@pytest.fixture +def test_table_name(): + """Generate unique table name for each test""" + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f') + return f'test_table_{timestamp}' + + +class LoaderTestConfig(ABC): + """ + Configuration for a specific loader's tests. + + Each loader implementation must provide a concrete implementation + that defines loader-specific query methods and capabilities. + """ + + # Required: Specify the loader class being tested + loader_class: Type[DataLoader] = None + + # Required: Name of the pytest fixture that provides the loader config dict + config_fixture_name: str = None + + # Loader capability flags + supports_overwrite: bool = True + supports_streaming: bool = True + supports_multi_network: bool = True + supports_null_values: bool = True + requires_existing_table: bool = True # False for loaders that auto-create (DeltaLake, LMDB) + + @abstractmethod + def get_row_count(self, loader: DataLoader, table_name: str) -> int: + """ + Get the number of rows in a table. + + Args: + loader: The loader instance + table_name: Name of the table + + Returns: + Number of rows in the table + """ + raise NotImplementedError + + @abstractmethod + def query_rows( + self, loader: DataLoader, table_name: str, where: Optional[str] = None, order_by: Optional[str] = None + ) -> List[Dict[str, Any]]: + """ + Query rows from a table. + + Args: + loader: The loader instance + table_name: Name of the table + where: Optional WHERE clause (without 'WHERE' keyword) + order_by: Optional ORDER BY clause (without 'ORDER BY' keyword) + + Returns: + List of dictionaries representing rows + """ + raise NotImplementedError + + @abstractmethod + def cleanup_table(self, loader: DataLoader, table_name: str) -> None: + """ + Clean up/drop a table. + + Args: + loader: The loader instance + table_name: Name of the table to drop + """ + raise NotImplementedError + + @abstractmethod + def get_column_names(self, loader: DataLoader, table_name: str) -> List[str]: + """ + Get column names for a table. + + Args: + loader: The loader instance + table_name: Name of the table + + Returns: + List of column names + """ + raise NotImplementedError + + +class LoaderTestBase: + """ + Base class for all loader tests. + + Test classes should inherit from this and set the `config` class attribute + to a LoaderTestConfig instance. 
+ + Example: + class TestPostgreSQLCore(BaseLoaderTests): + config = PostgreSQLTestConfig() + """ + + # Override this in subclasses + config: LoaderTestConfig = None + + @pytest.fixture + def loader(self, request): + """ + Create a loader instance from the config. + + This fixture dynamically retrieves the loader config from the + fixture name specified in LoaderTestConfig.config_fixture_name. + + For streaming tests, state management is enabled by default to support + reorg operations and batch tracking. + """ + if self.config is None: + raise ValueError('Test class must define a config attribute') + if self.config.loader_class is None: + raise ValueError('LoaderTestConfig must define loader_class') + if self.config.config_fixture_name is None: + raise ValueError('LoaderTestConfig must define config_fixture_name') + + # Get the loader config from the specified fixture + loader_config = request.getfixturevalue(self.config.config_fixture_name) + + # Enable state management for streaming tests (needed for reorg) + # Make a copy to avoid modifying the original fixture + if isinstance(loader_config, dict): + loader_config = loader_config.copy() + # Only enable state if not explicitly configured + if 'state' not in loader_config: + loader_config['state'] = {'enabled': True, 'storage': 'memory', 'store_batch_id': True} + + # Create and return the loader instance + return self.config.loader_class(loader_config) + + @pytest.fixture + def cleanup_tables(self, request): + """ + Cleanup fixture that drops tables after tests. + + Tests should append table names to this list, and they will be + cleaned up automatically after the test completes. + """ + tables_to_clean = [] + + yield tables_to_clean + + # Cleanup after test + if tables_to_clean and self.config: + loader_config = request.getfixturevalue(self.config.config_fixture_name) + loader = self.config.loader_class(loader_config) + try: + loader.connect() + for table_name in tables_to_clean: + try: + self.config.cleanup_table(loader, table_name) + except Exception: + # Ignore cleanup errors + pass + loader.disconnect() + except Exception: + # Ignore connection errors during cleanup + pass diff --git a/tests/integration/loaders/test_base_loader.py b/tests/integration/loaders/test_base_loader.py new file mode 100644 index 0000000..6384779 --- /dev/null +++ b/tests/integration/loaders/test_base_loader.py @@ -0,0 +1,159 @@ +""" +Generalized core loader tests that work across all loader implementations. + +These tests use the LoaderTestConfig abstraction to run identical test logic +across different storage backends (PostgreSQL, Redis, Snowflake, etc.). +""" + +import pytest + +from src.amp.loaders.base import LoadMode +from tests.integration.loaders.conftest import LoaderTestBase + + +class BaseLoaderTests(LoaderTestBase): + """ + Base test class with core loader functionality tests. + + All loaders should inherit from this class and provide a LoaderTestConfig. 
+ + Example: + class TestPostgreSQLCore(BaseLoaderTests): + config = PostgreSQLTestConfig() + """ + + @pytest.mark.integration + def test_connection(self, loader): + """Test basic connection and disconnection to the storage backend""" + # Test connection + loader.connect() + assert loader._is_connected == True + + # Test disconnection + loader.disconnect() + assert loader._is_connected == False + + @pytest.mark.integration + def test_context_manager(self, loader, small_test_data, test_table_name, cleanup_tables): + """Test context manager functionality""" + cleanup_tables.append(test_table_name) + + with loader: + assert loader._is_connected == True + + result = loader.load_table(small_test_data, test_table_name) + assert result.success == True + + # Should be disconnected after context + assert loader._is_connected == False + + @pytest.mark.integration + def test_batch_loading(self, loader, medium_test_table, test_table_name, cleanup_tables): + """Test batch loading functionality with sequential batches""" + cleanup_tables.append(test_table_name) + + with loader: + # Test loading individual batches + batches = medium_test_table.to_batches(max_chunksize=250) + + for i, batch in enumerate(batches): + mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND + result = loader.load_batch(batch, test_table_name, mode=mode) + + assert result.success == True + assert result.rows_loaded == batch.num_rows + assert result.metadata['batch_size'] == batch.num_rows + + # Verify all data was loaded + total_rows = self.config.get_row_count(loader, test_table_name) + assert total_rows == 10000 + + @pytest.mark.integration + def test_append_mode(self, loader, small_test_data, test_table_name, cleanup_tables): + """Test append mode functionality""" + import pyarrow as pa + + cleanup_tables.append(test_table_name) + + with loader: + # Initial load + result = loader.load_table(small_test_data, test_table_name, mode=LoadMode.APPEND) + assert result.success == True + assert result.rows_loaded == 5 + + # Append additional data with DIFFERENT keys (6-10 instead of 1-5) + # This avoids duplicate key conflicts in key-value stores (LMDB, Redis) + additional_data = pa.Table.from_pydict( + { + 'id': [6, 7, 8, 9, 10], + 'name': ['f', 'g', 'h', 'i', 'j'], + 'value': [60.6, 70.7, 80.8, 90.9, 100.0], + 'year': [2024, 2024, 2024, 2024, 2024], + 'month': [1, 1, 1, 1, 1], + 'day': [6, 7, 8, 9, 10], + 'active': [False, True, False, True, False], + } + ) + result = loader.load_table(additional_data, test_table_name, mode=LoadMode.APPEND) + assert result.success == True + assert result.rows_loaded == 5 + + # Verify total rows + total_rows = self.config.get_row_count(loader, test_table_name) + assert total_rows == 10 # 5 + 5 + + @pytest.mark.integration + def test_overwrite_mode(self, loader, small_test_data, test_table_name, cleanup_tables): + """Test overwrite mode functionality""" + if not self.config.supports_overwrite: + pytest.skip('Loader does not support overwrite mode') + + cleanup_tables.append(test_table_name) + + with loader: + # Initial load + result = loader.load_table(small_test_data, test_table_name, mode=LoadMode.OVERWRITE) + assert result.success == True + assert result.rows_loaded == 5 + + # Overwrite with different data + new_data = small_test_data.slice(0, 3) # First 3 rows + result = loader.load_table(new_data, test_table_name, mode=LoadMode.OVERWRITE) + assert result.success == True + assert result.rows_loaded == 3 + + # Verify only new data remains + total_rows = self.config.get_row_count(loader, 
test_table_name) + assert total_rows == 3 + + @pytest.mark.integration + def test_null_handling(self, loader, null_test_data, test_table_name, cleanup_tables): + """Test null value handling across all data types""" + if not self.config.supports_null_values: + pytest.skip('Loader does not support null values') + + cleanup_tables.append(test_table_name) + + with loader: + result = loader.load_table(null_test_data, test_table_name) + assert result.success == True + assert result.rows_loaded == 10 + + # Verify data was loaded (basic sanity check) + # Specific null value verification is loader-specific and tested in backend tests + total_rows = self.config.get_row_count(loader, test_table_name) + assert total_rows == 10 + + @pytest.mark.integration + def test_error_handling(self, loader, small_test_data): + """Test error handling scenarios""" + if not self.config.requires_existing_table: + pytest.skip('Loader auto-creates tables, cannot test missing table error') + + with loader: + # Test loading to non-existent table without create_table + result = loader.load_table(small_test_data, 'non_existent_table', create_table=False) + + assert result.success == False + assert result.error is not None + assert result.rows_loaded == 0 diff --git a/tests/integration/loaders/test_base_streaming.py b/tests/integration/loaders/test_base_streaming.py new file mode 100644 index 0000000..5543620 --- /dev/null +++ b/tests/integration/loaders/test_base_streaming.py @@ -0,0 +1,356 @@ +""" +Generalized streaming and reorg tests that work across all loader implementations. + +These tests use the LoaderTestConfig abstraction to run identical streaming test logic +across different storage backends (PostgreSQL, Redis, Snowflake, etc.). +""" + +import pyarrow as pa +import pytest + +from tests.integration.loaders.conftest import LoaderTestBase + + +class BaseStreamingTests(LoaderTestBase): + """ + Base test class with streaming and reorg functionality tests. + + Loaders that support streaming should inherit from this class and provide a LoaderTestConfig. 
+ + Example: + class TestPostgreSQLStreaming(BaseStreamingTests): + config = PostgreSQLTestConfig() + """ + + @pytest.mark.integration + def test_streaming_metadata_columns(self, loader, test_table_name, cleanup_tables): + """Test that streaming data creates tables with metadata columns""" + if not self.config.supports_streaming: + pytest.skip('Loader does not support streaming') + + cleanup_tables.append(test_table_name) + + from src.amp.streaming.types import BlockRange + + # Create test data with metadata + data = { + 'block_number': [100, 101, 102], + 'tx_hash': ['0xabc', '0xdef', '0x123'], # Use tx_hash for consistency with other streaming tests + 'value': [1.0, 2.0, 3.0], + } + batch = pa.RecordBatch.from_pydict(data) + + # Create metadata with block ranges + block_ranges = [BlockRange(network='ethereum', start=100, end=102)] + + with loader: + # Add metadata columns (simulating what load_stream_continuous does) + batch_with_metadata = loader._add_metadata_columns(batch, block_ranges) + + # Load the batch + result = loader.load_batch(batch_with_metadata, test_table_name, create_table=True) + assert result.success == True + assert result.rows_loaded == 3 + + # Verify metadata columns were created using backend-specific method + column_names = self.config.get_column_names(loader, test_table_name) + + # Should have original columns plus metadata columns + assert '_amp_batch_id' in column_names + + # Verify data was loaded + row_count = self.config.get_row_count(loader, test_table_name) + assert row_count == 3 + + @pytest.mark.integration + def test_reorg_deletion(self, loader, test_table_name, cleanup_tables): + """Test that _handle_reorg correctly deletes invalidated ranges""" + if not self.config.supports_streaming: + pytest.skip('Loader does not support streaming') + + cleanup_tables.append(test_table_name) + + from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch + + with loader: + # Create streaming batches with metadata + batch1 = pa.RecordBatch.from_pydict( + { + 'tx_hash': ['0x100', '0x101', '0x102'], + 'block_num': [100, 101, 102], + 'value': [10.0, 11.0, 12.0], + } + ) + batch2 = pa.RecordBatch.from_pydict( + {'tx_hash': ['0x200', '0x201'], 'block_num': [103, 104], 'value': [12.0, 33.0]} + ) + batch3 = pa.RecordBatch.from_pydict( + {'tx_hash': ['0x300', '0x301'], 'block_num': [105, 106], 'value': [7.0, 9.0]} + ) + batch4 = pa.RecordBatch.from_pydict( + {'tx_hash': ['0x400', '0x401'], 'block_num': [107, 108], 'value': [6.0, 73.0]} + ) + + # Create table from first batch schema + loader._create_table_from_schema(batch1.schema, test_table_name) + + # Create response batches with hashes + response1 = ResponseBatch.data_batch( + data=batch1, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=100, end=102, hash='0xaaa')], + ranges_complete=True, # Mark as complete so state tracking works + ), + ) + response2 = ResponseBatch.data_batch( + data=batch2, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=103, end=104, hash='0xbbb')], + ranges_complete=True, + ), + ) + response3 = ResponseBatch.data_batch( + data=batch3, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=105, end=106, hash='0xccc')], + ranges_complete=True, + ), + ) + response4 = ResponseBatch.data_batch( + data=batch4, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=107, end=108, hash='0xddd')], + ranges_complete=True, + ), + ) + + # Load via streaming API (with connection_name for state tracking) + stream = 
[response1, response2, response3, response4] + results = list(loader.load_stream_continuous(iter(stream), test_table_name, connection_name='test_conn')) + assert len(results) == 4 + assert all(r.success for r in results) + + # Verify initial data count + initial_count = self.config.get_row_count(loader, test_table_name) + assert initial_count == 9 # 3 + 2 + 2 + 2 + + # Test reorg deletion - invalidate blocks 104-108 on ethereum + reorg_response = ResponseBatch.reorg_batch( + invalidation_ranges=[BlockRange(network='ethereum', start=104, end=108)] + ) + reorg_results = list( + loader.load_stream_continuous(iter([reorg_response]), test_table_name, connection_name='test_conn') + ) + assert len(reorg_results) == 1 + assert reorg_results[0].success + + # Should delete batch2, batch3 and batch4 leaving only the 3 rows from batch1 + after_reorg_count = self.config.get_row_count(loader, test_table_name) + assert after_reorg_count == 3 + + @pytest.mark.integration + def test_reorg_overlapping_ranges(self, loader, test_table_name, cleanup_tables): + """Test reorg deletion with overlapping block ranges""" + if not self.config.supports_streaming: + pytest.skip('Loader does not support streaming') + + cleanup_tables.append(test_table_name) + + from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch + + with loader: + # Load data with overlapping ranges that should be invalidated + batch = pa.RecordBatch.from_pydict( + {'tx_hash': ['0x150', '0x175', '0x250'], 'block_num': [150, 175, 250], 'value': [15.0, 17.5, 25.0]} + ) + + # Create table from batch schema + loader._create_table_from_schema(batch.schema, test_table_name) + + response = ResponseBatch.data_batch( + data=batch, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=150, end=175, hash='0xaaa')], + ranges_complete=True, + ), + ) + + # Load via streaming API (with connection_name for state tracking) + results = list( + loader.load_stream_continuous(iter([response]), test_table_name, connection_name='test_conn') + ) + assert len(results) == 1 + assert results[0].success + + # Verify initial data + initial_count = self.config.get_row_count(loader, test_table_name) + assert initial_count == 3 + + # Test partial overlap invalidation (160-180) + # This should invalidate our range [150,175] because they overlap + reorg_response = ResponseBatch.reorg_batch( + invalidation_ranges=[BlockRange(network='ethereum', start=160, end=180)] + ) + reorg_results = list( + loader.load_stream_continuous(iter([reorg_response]), test_table_name, connection_name='test_conn') + ) + assert len(reorg_results) == 1 + assert reorg_results[0].success + + # All data should be deleted due to overlap + after_reorg_count = self.config.get_row_count(loader, test_table_name) + assert after_reorg_count == 0 + + @pytest.mark.integration + def test_reorg_multi_network(self, loader, test_table_name, cleanup_tables): + """Test that reorg only affects specified network""" + if not self.config.supports_streaming: + pytest.skip('Loader does not support streaming') + if not self.config.supports_multi_network: + pytest.skip('Loader does not support multi-network isolation') + + cleanup_tables.append(test_table_name) + + from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch + + with loader: + # Load data from multiple networks with same block ranges + batch_eth = pa.RecordBatch.from_pydict( + {'tx_hash': ['0x100_eth'], 'network_id': ['ethereum'], 'block_num': [100], 'value': [10.0]} + ) + batch_poly = 
pa.RecordBatch.from_pydict( + {'tx_hash': ['0x100_poly'], 'network_id': ['polygon'], 'block_num': [100], 'value': [10.0]} + ) + + # Create table from batch schema + loader._create_table_from_schema(batch_eth.schema, test_table_name) + + response_eth = ResponseBatch.data_batch( + data=batch_eth, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=100, end=100, hash='0xaaa')], + ranges_complete=True, + ), + ) + response_poly = ResponseBatch.data_batch( + data=batch_poly, + metadata=BatchMetadata( + ranges=[BlockRange(network='polygon', start=100, end=100, hash='0xbbb')], + ranges_complete=True, + ), + ) + + # Load both batches via streaming API (with connection_name for state tracking) + stream = [response_eth, response_poly] + results = list(loader.load_stream_continuous(iter(stream), test_table_name, connection_name='test_conn')) + assert len(results) == 2 + assert all(r.success for r in results) + + # Verify both networks' data exists + initial_count = self.config.get_row_count(loader, test_table_name) + assert initial_count == 2 + + # Invalidate only ethereum network + reorg_response = ResponseBatch.reorg_batch( + invalidation_ranges=[BlockRange(network='ethereum', start=100, end=100)] + ) + reorg_results = list( + loader.load_stream_continuous(iter([reorg_response]), test_table_name, connection_name='test_conn') + ) + assert len(reorg_results) == 1 + assert reorg_results[0].success + + # Should only delete ethereum data, polygon should remain + after_reorg_count = self.config.get_row_count(loader, test_table_name) + assert after_reorg_count == 1 + + @pytest.mark.integration + def test_microbatch_deduplication(self, loader, test_table_name, cleanup_tables): + """ + Test that multiple RecordBatches within the same microbatch are all loaded, + and deduplication only happens at microbatch boundaries when ranges_complete=True. + + This test verifies the fix for the critical bug where we were marking batches + as processed after every RecordBatch instead of waiting for ranges_complete=True. 
+ """ + if not self.config.supports_streaming: + pytest.skip('Loader does not support streaming') + + cleanup_tables.append(test_table_name) + + from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch + + with loader: + # Create table first from the schema (include tx_hash for Redis key pattern compatibility) + batch1_data = pa.RecordBatch.from_pydict({'id': [1, 2], 'tx_hash': ['0x1', '0x2'], 'value': [100, 200]}) + loader._create_table_from_schema(batch1_data.schema, test_table_name) + + # Simulate a microbatch sent as 3 RecordBatches with the same BlockRange + # First RecordBatch of the microbatch (ranges_complete=False) + response1 = ResponseBatch.data_batch( + data=batch1_data, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], + ranges_complete=False, # Not the last batch in this microbatch + ), + ) + + # Second RecordBatch (same BlockRange, ranges_complete=False) + batch2_data = pa.RecordBatch.from_pydict({'id': [3, 4], 'tx_hash': ['0x3', '0x4'], 'value': [300, 400]}) + response2 = ResponseBatch.data_batch( + data=batch2_data, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], + ranges_complete=False, + ), + ) + + # Third RecordBatch (same BlockRange, ranges_complete=True) + batch3_data = pa.RecordBatch.from_pydict({'id': [5, 6], 'tx_hash': ['0x5', '0x6'], 'value': [500, 600]}) + response3 = ResponseBatch.data_batch( + data=batch3_data, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], + ranges_complete=True, # Last batch - safe to mark as processed + ), + ) + + # Process the microbatch stream + stream = [response1, response2, response3] + results = list( + loader.load_stream_continuous(iter(stream), test_table_name, connection_name='test_connection') + ) + + # CRITICAL: All 3 RecordBatches should be loaded successfully + assert len(results) == 3, 'All RecordBatches within microbatch should be processed' + assert all(r.success for r in results), 'All batches should succeed' + assert results[0].rows_loaded == 2, 'First batch should load 2 rows' + assert results[1].rows_loaded == 2, 'Second batch should load 2 rows (not skipped!)' + assert results[2].rows_loaded == 2, 'Third batch should load 2 rows (not skipped!)' + + # Verify all 6 rows in table + total_count = self.config.get_row_count(loader, test_table_name) + assert total_count == 6, 'All 6 rows from 3 RecordBatches should be in the table' + + # Test duplicate detection - send the same microbatch again + duplicate_batch = pa.RecordBatch.from_pydict({'id': [7, 8], 'value': [700, 800]}) + duplicate_response = ResponseBatch.data_batch( + data=duplicate_batch, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], + ranges_complete=True, + ), + ) + + duplicate_results = list( + loader.load_stream_continuous( + iter([duplicate_response]), test_table_name, connection_name='test_connection' + ) + ) + + # Duplicate should be skipped + assert duplicate_results[0].rows_loaded == 0, 'Duplicate microbatch should be skipped' + + # Verify count hasn't changed + final_count = self.config.get_row_count(loader, test_table_name) + assert final_count == 6, 'No additional rows should be added after duplicate' diff --git a/tests/integration/test_deltalake_loader.py b/tests/integration/test_deltalake_loader.py deleted file mode 100644 index c19c9a9..0000000 --- a/tests/integration/test_deltalake_loader.py +++ 
/dev/null @@ -1,870 +0,0 @@ -# tests/integration/test_deltalake_loader.py -""" -Integration tests for Delta Lake loader implementation. -These tests require actual Delta Lake functionality and local filesystem access. -""" - -import json -import shutil -import tempfile -from datetime import datetime, timedelta -from pathlib import Path - -import pyarrow as pa -import pytest - -from src.amp.loaders.base import LoadMode - -try: - from src.amp.loaders.implementations.deltalake_loader import DELTALAKE_AVAILABLE, DeltaLakeLoader - - # Skip all tests if deltalake is not available - if not DELTALAKE_AVAILABLE: - pytest.skip('Delta Lake not available', allow_module_level=True) - -except ImportError: - pytest.skip('amp modules not available', allow_module_level=True) - - -@pytest.fixture(scope='session') -def delta_test_env(): - """Setup Delta Lake test environment for the session""" - temp_dir = tempfile.mkdtemp(prefix='delta_test_') - yield temp_dir - # Cleanup - shutil.rmtree(temp_dir, ignore_errors=True) - - -@pytest.fixture -def delta_basic_config(delta_test_env): - """Get basic Delta Lake configuration""" - return { - 'table_path': str(Path(delta_test_env) / 'basic_table'), - 'partition_by': ['year', 'month'], - 'optimize_after_write': True, - 'vacuum_after_write': False, - 'schema_evolution': True, - 'merge_schema': True, - 'storage_options': {}, - } - - -@pytest.fixture -def delta_partitioned_config(delta_test_env): - """Get partitioned Delta Lake configuration""" - return { - 'table_path': str(Path(delta_test_env) / 'partitioned_table'), - 'partition_by': ['year', 'month', 'day'], - 'optimize_after_write': True, - 'vacuum_after_write': True, - 'schema_evolution': True, - 'merge_schema': True, - 'storage_options': {}, - } - - -@pytest.fixture -def comprehensive_test_data(): - """Create comprehensive test data for Delta Lake testing""" - base_date = datetime(2024, 1, 1) - - data = { - 'id': list(range(1000)), - 'user_id': [f'user_{i % 100}' for i in range(1000)], - 'transaction_amount': [round((i * 12.34) % 1000, 2) for i in range(1000)], - 'category': [['electronics', 'clothing', 'books', 'food', 'travel'][i % 5] for i in range(1000)], - 'timestamp': [(base_date + timedelta(days=i // 50, hours=i % 24)).isoformat() for i in range(1000)], - 'year': [2024 if i < 800 else 2023 for i in range(1000)], - 'month': [(i // 80) % 12 + 1 for i in range(1000)], - 'day': [(i // 30) % 28 + 1 for i in range(1000)], - 'is_weekend': [i % 7 in [0, 6] for i in range(1000)], - 'metadata': [ - json.dumps( - { - 'session_id': f'session_{i}', - 'device': ['mobile', 'desktop', 'tablet'][i % 3], - 'location': ['US', 'UK', 'DE', 'FR', 'JP'][i % 5], - } - ) - for i in range(1000) - ], - 'score': [i * 0.123 for i in range(1000)], - 'active': [i % 2 == 0 for i in range(1000)], - } - - return pa.Table.from_pydict(data) - - -@pytest.fixture -def small_test_data(): - """Create small test data for quick tests""" - data = { - 'id': [1, 2, 3, 4, 5], - 'name': ['a', 'b', 'c', 'd', 'e'], - 'value': [10.1, 20.2, 30.3, 40.4, 50.5], - 'year': [2024, 2024, 2024, 2024, 2024], - 'month': [1, 1, 1, 1, 1], - 'day': [1, 2, 3, 4, 5], - 'active': [True, False, True, False, True], - } - - return pa.Table.from_pydict(data) - - -@pytest.mark.integration -@pytest.mark.delta_lake -class TestDeltaLakeLoaderIntegration: - """Integration tests for Delta Lake loader""" - - def test_loader_initialization(self, delta_basic_config): - """Test loader initialization and connection""" - loader = DeltaLakeLoader(delta_basic_config) - - # Test 
configuration - assert loader.config.table_path == delta_basic_config['table_path'] - assert loader.config.partition_by == ['year', 'month'] - assert loader.config.optimize_after_write == True - assert loader.storage_backend == 'Local' - - # Test connection - loader.connect() - assert loader._is_connected == True - - # Test disconnection - loader.disconnect() - assert loader._is_connected == False - - def test_basic_table_operations(self, delta_basic_config, comprehensive_test_data): - """Test basic table creation and data loading""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Test initial table creation - result = loader.load_table(comprehensive_test_data, 'test_transactions', mode=LoadMode.OVERWRITE) - - assert result.success == True - assert result.rows_loaded == 1000 - assert result.metadata['write_mode'] == 'overwrite' - assert result.metadata['storage_backend'] == 'Local' - assert result.metadata['partition_columns'] == ['year', 'month'] - - # Verify table exists - assert loader._table_exists == True - assert loader._delta_table is not None - - # Test table statistics - stats = loader.get_table_stats() - assert 'version' in stats - assert stats['storage_backend'] == 'Local' - assert stats['partition_columns'] == ['year', 'month'] - - def test_append_mode(self, delta_basic_config, comprehensive_test_data): - """Test append mode functionality""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Initial load - result = loader.load_table(comprehensive_test_data, 'test_append', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 1000 - - # Append additional data - additional_data = comprehensive_test_data.slice(0, 100) # First 100 rows - result = loader.load_table(additional_data, 'test_append', mode=LoadMode.APPEND) - - assert result.success == True - assert result.rows_loaded == 100 - assert result.metadata['write_mode'] == 'append' - - # Verify total data - final_query = loader.query_table() - assert final_query.num_rows == 1100 # 1000 + 100 - - def test_batch_loading(self, delta_basic_config, comprehensive_test_data): - """Test batch loading functionality""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Test loading individual batches - batches = comprehensive_test_data.to_batches(max_chunksize=200) - - for i, batch in enumerate(batches): - mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND - result = loader.load_batch(batch, 'test_batches', mode=mode) - - assert result.success == True - assert result.rows_loaded == batch.num_rows - assert result.metadata['operation'] == 'load_batch' - assert result.metadata['batch_size'] == batch.num_rows - - # Verify all data was loaded - final_query = loader.query_table() - assert final_query.num_rows == 1000 - - def test_partitioning(self, delta_partitioned_config, small_test_data): - """Test table partitioning functionality""" - loader = DeltaLakeLoader(delta_partitioned_config) - - with loader: - # Load partitioned data - result = loader.load_table(small_test_data, 'test_partitioned', mode=LoadMode.OVERWRITE) - - assert result.success == True - assert result.metadata['partition_columns'] == ['year', 'month', 'day'] - - # Verify partition structure exists - table_path = Path(delta_partitioned_config['table_path']) - assert table_path.exists() - - def test_schema_evolution(self, delta_basic_config, small_test_data): - """Test schema evolution functionality""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Load initial data 
- result = loader.load_table(small_test_data, 'test_schema_evolution', mode=LoadMode.OVERWRITE) - - assert result.success == True - initial_schema = loader.get_table_schema() - initial_columns = set(initial_schema.names) - - # Create data with additional columns - extended_data_dict = small_test_data.to_pydict() - extended_data_dict['new_column'] = list(range(len(extended_data_dict['id']))) - extended_data_dict['another_field'] = ['test_value'] * len(extended_data_dict['id']) - extended_table = pa.Table.from_pydict(extended_data_dict) - - # Load extended data (should add new columns) - result = loader.load_table(extended_table, 'test_schema_evolution', mode=LoadMode.APPEND) - - assert result.success == True - - # Verify schema has evolved - evolved_schema = loader.get_table_schema() - evolved_columns = set(evolved_schema.names) - - assert 'new_column' in evolved_columns - assert 'another_field' in evolved_columns - assert evolved_columns.issuperset(initial_columns) - - def test_optimization_operations(self, delta_basic_config, comprehensive_test_data): - """Test table optimization operations""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Load data multiple times to create multiple files - for i in range(3): - subset = comprehensive_test_data.slice(i * 300, 300) - mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND - - result = loader.load_table(subset, 'test_optimization', mode=mode) - assert result.success == True - - optimize_result = loader.optimize_table() - - assert optimize_result['success'] == True - assert 'duration_seconds' in optimize_result - assert 'metrics' in optimize_result - - # Verify data integrity after optimization - final_data = loader.query_table() - assert final_data.num_rows == 900 # 3 * 300 - - def test_query_operations(self, delta_basic_config, comprehensive_test_data): - """Test table querying operations""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Load data - result = loader.load_table(comprehensive_test_data, 'test_query', mode=LoadMode.OVERWRITE) - assert result.success == True - - # Test basic query - query_result = loader.query_table() - assert query_result.num_rows == 1000 - - # Test column selection - query_result = loader.query_table(columns=['id', 'user_id', 'transaction_amount']) - assert query_result.num_rows == 1000 - assert query_result.column_names == ['id', 'user_id', 'transaction_amount'] - - # Test limit - query_result = loader.query_table(limit=50) - assert query_result.num_rows == 50 - - # Test combined options - query_result = loader.query_table(columns=['id', 'category'], limit=10) - assert query_result.num_rows == 10 - assert query_result.column_names == ['id', 'category'] - - def test_error_handling(self, delta_temp_config): - """Test error handling scenarios""" - loader = DeltaLakeLoader(delta_temp_config) - - with loader: - # Test loading invalid data (missing partition columns) - invalid_data = pa.table( - { - 'id': [1, 2, 3], - 'name': ['a', 'b', 'c'], - # Missing 'year' and 'month' partition columns - } - ) - - result = loader.load_table(invalid_data, 'test_errors', mode=LoadMode.OVERWRITE) - - # Should handle error gracefully - assert result.success == False - assert result.error is not None - assert result.rows_loaded == 0 - - def test_table_history(self, delta_basic_config, small_test_data): - """Test table history functionality""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Create multiple versions - for i in range(3): - subset = small_test_data.slice(i, 
1) - mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND - - result = loader.load_table(subset, 'test_history', mode=mode) - assert result.success == True - - # Get history - history = loader.get_table_history() - assert len(history) >= 3 - - # Verify history structure - for entry in history: - assert 'version' in entry - assert 'operation' in entry - assert 'timestamp' in entry - - def test_context_manager(self, delta_basic_config, small_test_data): - """Test context manager functionality""" - loader = DeltaLakeLoader(delta_basic_config) - - # Test context manager - with loader: - assert loader._is_connected == True - - result = loader.load_table(small_test_data, 'test_context', mode=LoadMode.OVERWRITE) - assert result.success == True - - # Should be disconnected after context - assert loader._is_connected == False - - def test_metadata_completeness(self, delta_basic_config, comprehensive_test_data): - """Test metadata completeness in results""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - result = loader.load_table(comprehensive_test_data, 'test_metadata', mode=LoadMode.OVERWRITE) - - assert result.success == True - - # Check required metadata fields - metadata = result.metadata - required_fields = [ - 'write_mode', - 'storage_backend', - 'partition_columns', - 'throughput_rows_per_sec', - 'table_version', - ] - - for field in required_fields: - assert field in metadata, f'Missing metadata field: {field}' - - # Verify metadata values - assert metadata['write_mode'] == 'overwrite' - assert metadata['storage_backend'] == 'Local' - assert metadata['partition_columns'] == ['year', 'month'] - assert metadata['throughput_rows_per_sec'] > 0 - - def test_null_value_handling(self, delta_basic_config, null_test_data): - """Test comprehensive null value handling across all data types""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - result = loader.load_table(null_test_data, 'test_nulls', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 10 - - query_result = loader.query_table() - assert query_result.num_rows == 10 - - df = query_result.to_pandas() - - text_nulls = df['text_field'].isna().sum() - assert text_nulls == 3 # Rows 3, 6, 9 should be NULL - - int_nulls = df['int_field'].isna().sum() - assert int_nulls == 3 # Rows 2, 5, 8 should be NULL - - float_nulls = df['float_field'].isna().sum() - assert float_nulls == 3 # Rows 3, 6, 9 should be NULL - - bool_nulls = df['bool_field'].isna().sum() - assert bool_nulls == 3 # Rows 3, 6, 9 should be NULL - - timestamp_nulls = df['timestamp_field'].isna().sum() - assert timestamp_nulls == 4 # Rows where i % 3 == 0 - - # Verify non-null values are intact - assert df.loc[df['id'] == 1, 'text_field'].iloc[0] == 'a' - assert df.loc[df['id'] == 1, 'int_field'].iloc[0] == 1 - assert abs(df.loc[df['id'] == 1, 'float_field'].iloc[0] - 1.1) < 0.01 - assert df.loc[df['id'] == 1, 'bool_field'].iloc[0] == True - - # Test schema evolution with null values - from datetime import datetime - - additional_data = pa.table( - { - 'id': [11, 12], - 'text_field': ['k', None], - 'int_field': [None, 12], - 'float_field': [11.1, None], - 'bool_field': [None, False], - 'timestamp_field': [datetime.now(), None], # At least one non-null to preserve type - 'json_field': [None, '{"test": "value"}'], - 'year': [2024, 2024], - 'month': [1, 1], - 'day': [11, 12], - 'new_nullable_field': [None, 'new_value'], # New field with nulls - } - ) - - result = loader.load_table(additional_data, 'test_nulls', 
mode=LoadMode.APPEND) - assert result.success == True - assert result.rows_loaded == 2 - - # Verify schema evolved and nulls handled in new column - final_query = loader.query_table() - assert final_query.num_rows == 12 - - final_df = final_query.to_pandas() - new_field_nulls = final_df['new_nullable_field'].isna().sum() - assert new_field_nulls == 11 # All original rows + 1 new null row - - def test_file_size_calculation_modern_api(self, delta_basic_config, comprehensive_test_data): - """Test file size calculation using modern get_add_file_sizes API""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - result = loader.load_table(comprehensive_test_data, 'test_file_sizes', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 1000 - - table_info = loader._get_table_info() - - # Verify size calculation worked - assert 'size_bytes' in table_info - assert table_info['size_bytes'] > 0, 'File size should be greater than 0' - assert table_info['num_files'] > 0, 'Should have at least one file' - - # Verify metadata includes size information - assert 'total_size_bytes' in result.metadata - assert result.metadata['total_size_bytes'] > 0 - - -@pytest.mark.integration -@pytest.mark.delta_lake -@pytest.mark.slow -class TestDeltaLakeLoaderAdvanced: - """Advanced integration tests for Delta Lake loader""" - - def test_large_data_performance(self, delta_basic_config): - """Test performance with larger datasets""" - # Create larger test dataset - large_data = { - 'id': list(range(50000)), - 'value': [i * 0.123 for i in range(50000)], - 'category': [f'category_{i % 10}' for i in range(50000)], - 'year': [2024] * 50000, - 'month': [(i // 4000) % 12 + 1 for i in range(50000)], - 'timestamp': [datetime.now().isoformat() for _ in range(50000)], - } - - large_table = pa.Table.from_pydict(large_data) - - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Load large dataset - result = loader.load_table(large_table, 'test_performance', mode=LoadMode.OVERWRITE) - - assert result.success == True - assert result.rows_loaded == 50000 - - # Verify performance metrics - assert result.metadata['throughput_rows_per_sec'] > 100 # Should be reasonably fast - assert result.duration < 120 # Should complete within reasonable time - - def test_concurrent_operations_safety(self, delta_basic_config, small_test_data): - """Test that operations are handled safely (basic concurrency test)""" - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Load initial data - result = loader.load_table(small_test_data, 'test_concurrent', mode=LoadMode.OVERWRITE) - assert result.success == True - - # Perform multiple operations in sequence (simulating concurrent-like scenario) - operations = [] - - # Append operations - for i in range(3): - subset = small_test_data.slice(i, 1) - result = loader.load_table(subset, 'test_concurrent', mode=LoadMode.APPEND) - operations.append(result) - - # Verify all operations succeeded - for result in operations: - assert result.success == True - - # Verify final data integrity - final_data = loader.query_table() - assert final_data.num_rows == 8 # 5 + 3 * 1 - - def test_handle_reorg_no_table(self, delta_basic_config): - """Test reorg handling when table doesn't exist""" - from src.amp.streaming.types import BlockRange - - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Call handle reorg on non-existent table - invalidation_ranges = [BlockRange(network='ethereum', start=100, end=200)] - - # Should not raise any errors 
- loader._handle_reorg(invalidation_ranges, 'test_reorg_empty', 'test_connection') - - def test_handle_reorg_no_metadata_column(self, delta_basic_config): - """Test reorg handling when table lacks metadata column""" - from src.amp.streaming.types import BlockRange - - loader = DeltaLakeLoader(delta_basic_config) - - with loader: - # Create table without metadata column - data = pa.table( - { - 'id': [1, 2, 3], - 'block_num': [100, 150, 200], - 'value': [10.0, 20.0, 30.0], - 'year': [2024, 2024, 2024], - 'month': [1, 1, 1], - } - ) - loader.load_table(data, 'test_reorg_no_meta', mode=LoadMode.OVERWRITE) - - # Call handle reorg - invalidation_ranges = [BlockRange(network='ethereum', start=150, end=250)] - - # Should log warning and not modify data - loader._handle_reorg(invalidation_ranges, 'test_reorg_no_meta', 'test_connection') - - # Verify data unchanged - remaining_data = loader.query_table() - assert remaining_data.num_rows == 3 - - def test_handle_reorg_single_network(self, delta_temp_config): - """Test reorg handling for single network data""" - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - loader = DeltaLakeLoader(delta_temp_config) - - with loader: - # Create streaming batches with metadata - batch1 = pa.RecordBatch.from_pydict({'id': [1], 'block_num': [105], 'year': [2024], 'month': [1]}) - batch2 = pa.RecordBatch.from_pydict({'id': [2], 'block_num': [155], 'year': [2024], 'month': [1]}) - batch3 = pa.RecordBatch.from_pydict({'id': [3], 'block_num': [205], 'year': [2024], 'month': [1]}) - - # Create response batches with hashes - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=200, end=210, hash='0x123')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - stream = [response1, response2, response3] - results = list(loader.load_stream_continuous(iter(stream), 'test_reorg_single')) - assert len(results) == 3 - assert all(r.success for r in results) - - # Verify all data exists - initial_data = loader.query_table() - assert initial_data.num_rows == 3 - - # Reorg from block 155 - should delete rows 2 and 3 - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=155, end=300)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), 'test_reorg_single')) - assert len(reorg_results) == 1 - assert reorg_results[0].success - assert reorg_results[0].is_reorg - - # Verify only first row remains - remaining_data = loader.query_table() - assert remaining_data.num_rows == 1 - assert remaining_data['id'][0].as_py() == 1 - - def test_handle_reorg_multi_network(self, delta_temp_config): - """Test reorg handling preserves data from unaffected networks""" - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - loader = DeltaLakeLoader(delta_temp_config) - - with loader: - # Create streaming batches from multiple networks - batch1 = 
pa.RecordBatch.from_pydict({'id': [1], 'network': ['ethereum'], 'year': [2024], 'month': [1]}) - batch2 = pa.RecordBatch.from_pydict({'id': [2], 'network': ['polygon'], 'year': [2024], 'month': [1]}) - batch3 = pa.RecordBatch.from_pydict({'id': [3], 'network': ['ethereum'], 'year': [2024], 'month': [1]}) - batch4 = pa.RecordBatch.from_pydict({'id': [4], 'network': ['polygon'], 'year': [2024], 'month': [1]}) - - # Create response batches with network-specific ranges - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='polygon', start=100, end=110, hash='0xbbb')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xccc')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response4 = ResponseBatch.data_batch( - data=batch4, - metadata=BatchMetadata( - ranges=[BlockRange(network='polygon', start=150, end=160, hash='0xddd')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - stream = [response1, response2, response3, response4] - results = list(loader.load_stream_continuous(iter(stream), 'test_reorg_multi')) - assert len(results) == 4 - assert all(r.success for r in results) - - # Reorg only ethereum from block 150 - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), 'test_reorg_multi')) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # Verify ethereum row 3 deleted, but polygon rows preserved - remaining_data = loader.query_table() - assert remaining_data.num_rows == 3 - remaining_ids = sorted([id.as_py() for id in remaining_data['id']]) - assert remaining_ids == [1, 2, 4] # Row 3 deleted - - def test_handle_reorg_overlapping_ranges(self, delta_temp_config): - """Test reorg with overlapping block ranges""" - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - loader = DeltaLakeLoader(delta_temp_config) - - with loader: - # Create streaming batches with different ranges - batch1 = pa.RecordBatch.from_pydict({'id': [1], 'year': [2024], 'month': [1]}) - batch2 = pa.RecordBatch.from_pydict({'id': [2], 'year': [2024], 'month': [1]}) - batch3 = pa.RecordBatch.from_pydict({'id': [3], 'year': [2024], 'month': [1]}) - - # Batch 1: 90-110 (ends before reorg start of 150) - # Batch 2: 140-160 (overlaps with reorg) - # Batch 3: 170-190 (after reorg, but should be deleted as 170 >= 150) - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=90, end=110, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=140, end=160, hash='0xbbb')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - 
metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=170, end=190, hash='0xccc')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - stream = [response1, response2, response3] - results = list(loader.load_stream_continuous(iter(stream), 'test_reorg_overlap')) - assert len(results) == 3 - assert all(r.success for r in results) - - # Reorg from block 150 - should delete batches 2 and 3 - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), 'test_reorg_overlap')) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # Only first row should remain (ends at 110 < 150) - remaining_data = loader.query_table() - assert remaining_data.num_rows == 1 - assert remaining_data['id'][0].as_py() == 1 - - def test_handle_reorg_version_history(self, delta_temp_config): - """Test that reorg creates proper version history in Delta Lake""" - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - loader = DeltaLakeLoader(delta_temp_config) - - with loader: - # Create streaming batches - batch1 = pa.RecordBatch.from_pydict({'id': [1], 'year': [2024], 'month': [1]}) - batch2 = pa.RecordBatch.from_pydict({'id': [2], 'year': [2024], 'month': [1]}) - batch3 = pa.RecordBatch.from_pydict({'id': [3], 'year': [2024], 'month': [1]}) - - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=0, end=10, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=50, end=60, hash='0xbbb')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xccc')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - stream = [response1, response2, response3] - results = list(loader.load_stream_continuous(iter(stream), 'test_reorg_history')) - assert len(results) == 3 - - initial_version = loader._delta_table.version() - - # Perform reorg - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=50, end=200)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), 'test_reorg_history')) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # Check that version increased - final_version = loader._delta_table.version() - assert final_version > initial_version - - # Check history - history = loader.get_table_history(limit=5) - assert len(history) >= 2 - # Latest operation should be an overwrite (from reorg) - assert history[0]['operation'] == 'WRITE' - - def test_streaming_with_reorg(self, delta_temp_config): - """Test streaming data with reorg support""" - from src.amp.streaming.types import ( - BatchMetadata, - BlockRange, - ResponseBatch, - ) - - loader = DeltaLakeLoader(delta_temp_config) - - with loader: - # Create streaming data with metadata - data1 = pa.RecordBatch.from_pydict( - {'id': [1, 2], 'value': [100, 200], 'year': [2024, 2024], 'month': [1, 1]} - ) - - data2 = pa.RecordBatch.from_pydict( 
- {'id': [3, 4], 'value': [300, 400], 'year': [2024, 2024], 'month': [1, 1]} - ) - - # Create response batches using factory methods (with hashes for proper state management) - response1 = ResponseBatch.data_batch( - data=data1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - response2 = ResponseBatch.data_batch( - data=data2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef456')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Simulate reorg event using factory method - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - - # Process streaming data - stream = [response1, response2, reorg_response] - results = list(loader.load_stream_continuous(iter(stream), 'test_streaming_reorg')) - - # Verify results - assert len(results) == 3 - assert results[0].success - assert results[0].rows_loaded == 2 - assert results[1].success - assert results[1].rows_loaded == 2 - assert results[2].success - assert results[2].is_reorg - - # Verify reorg deleted the second batch - final_data = loader.query_table() - assert final_data.num_rows == 2 - remaining_ids = sorted([id.as_py() for id in final_data['id']]) - assert remaining_ids == [1, 2] # 3 and 4 deleted by reorg diff --git a/tests/integration/test_iceberg_loader.py b/tests/integration/test_iceberg_loader.py deleted file mode 100644 index cbbe4bf..0000000 --- a/tests/integration/test_iceberg_loader.py +++ /dev/null @@ -1,745 +0,0 @@ -# tests/integration/test_iceberg_loader.py -""" -Integration tests for Apache Iceberg loader implementation. -These tests require actual Iceberg functionality and catalog access. 
-""" - -import json -import tempfile -from datetime import datetime, timedelta - -import pyarrow as pa -import pytest - -from src.amp.loaders.base import LoadMode - -try: - from src.amp.loaders.implementations.iceberg_loader import ICEBERG_AVAILABLE, IcebergLoader - - # Skip all tests if iceberg is not available - if not ICEBERG_AVAILABLE: - pytest.skip('Apache Iceberg not available', allow_module_level=True) - -except ImportError: - pytest.skip('amp modules not available', allow_module_level=True) - - -@pytest.fixture(scope='session') -def iceberg_test_env(): - """Setup Iceberg test environment for the session""" - temp_dir = tempfile.mkdtemp(prefix='iceberg_test_') - yield temp_dir - # Note: cleanup is handled by temp directory auto-cleanup - - -@pytest.fixture -def iceberg_basic_config(iceberg_test_env): - """Get basic Iceberg configuration with local file catalog""" - return { - 'catalog_config': { - 'type': 'sql', - 'uri': f'sqlite:///{iceberg_test_env}/catalog.db', - 'warehouse': f'file://{iceberg_test_env}/warehouse', - }, - 'namespace': 'test_data', - 'create_namespace': True, - 'create_table': True, - 'schema_evolution': True, - 'batch_size': 1000, - } - - -@pytest.fixture -def iceberg_partitioned_config(iceberg_test_env): - """Get partitioned Iceberg configuration""" - # Note: partition_spec should be created with actual PartitionSpec when needed - # For now, return config without partitioning since we need schema first - return { - 'catalog_config': { - 'type': 'sql', - 'uri': f'sqlite:///{iceberg_test_env}/catalog.db', - 'warehouse': f'file://{iceberg_test_env}/warehouse', - }, - 'namespace': 'partitioned_data', - 'create_namespace': True, - 'create_table': True, - 'partition_spec': None, # Will be set in test if needed - 'schema_evolution': True, - 'batch_size': 500, - } - - -@pytest.fixture -def iceberg_temp_config(iceberg_test_env): - """Get temporary Iceberg configuration with unique namespace""" - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - return { - 'catalog_config': { - 'type': 'sql', - 'uri': f'sqlite:///{iceberg_test_env}/catalog.db', - 'warehouse': f'file://{iceberg_test_env}/warehouse', - }, - 'namespace': f'temp_data_{timestamp}', - 'create_namespace': True, - 'create_table': True, - 'schema_evolution': True, - 'batch_size': 2000, - } - - -@pytest.fixture -def comprehensive_test_data(): - """Create comprehensive test data for Iceberg testing""" - base_date = datetime(2024, 1, 1) - - data = { - 'id': list(range(1000)), - 'user_id': [f'user_{i % 100}' for i in range(1000)], - 'transaction_amount': [round((i * 12.34) % 1000, 2) for i in range(1000)], - 'category': [['electronics', 'clothing', 'books', 'food', 'travel'][i % 5] for i in range(1000)], - 'timestamp': pa.array( - [(base_date + timedelta(days=i // 50, hours=i % 24)) for i in range(1000)], - type=pa.timestamp('ns', tz='UTC'), - ), - 'year': [2024 if i < 800 else 2023 for i in range(1000)], - 'month': [(i // 80) % 12 + 1 for i in range(1000)], - 'day': [(i // 30) % 28 + 1 for i in range(1000)], - 'is_weekend': [i % 7 in [0, 6] for i in range(1000)], - 'metadata': [ - json.dumps( - { - 'session_id': f'session_{i}', - 'device': ['mobile', 'desktop', 'tablet'][i % 3], - 'location': ['US', 'UK', 'DE', 'FR', 'JP'][i % 5], - } - ) - for i in range(1000) - ], - 'score': [i * 0.123 for i in range(1000)], - 'active': [i % 2 == 0 for i in range(1000)], - } - - return pa.Table.from_pydict(data) - - -@pytest.fixture -def small_test_data(): - """Create small test data for quick tests""" - data = { - 'id': [1, 
2, 3, 4, 5], - 'name': ['a', 'b', 'c', 'd', 'e'], - 'value': [10.1, 20.2, 30.3, 40.4, 50.5], - 'timestamp': pa.array( - [ - datetime(2024, 1, 1, 10, 0, 0), - datetime(2024, 1, 1, 11, 0, 0), - datetime(2024, 1, 1, 12, 0, 0), - datetime(2024, 1, 1, 13, 0, 0), - datetime(2024, 1, 1, 14, 0, 0), - ], - type=pa.timestamp('ns', tz='UTC'), - ), - 'year': [2024, 2024, 2024, 2024, 2024], - 'month': [1, 1, 1, 1, 1], - 'day': [1, 2, 3, 4, 5], - 'active': [True, False, True, False, True], - } - - return pa.Table.from_pydict(data) - - -@pytest.mark.integration -@pytest.mark.iceberg -class TestIcebergLoaderIntegration: - """Integration tests for Iceberg loader""" - - def test_loader_initialization(self, iceberg_basic_config): - """Test loader initialization and connection""" - loader = IcebergLoader(iceberg_basic_config) - - assert loader.config.namespace == iceberg_basic_config['namespace'] - assert loader.config.create_namespace == True - assert loader.config.create_table == True - - loader.connect() - assert loader._is_connected == True - assert loader._catalog is not None - assert loader._namespace_exists == True - - loader.disconnect() - assert loader._is_connected == False - assert loader._catalog is None - - def test_basic_table_operations(self, iceberg_basic_config, comprehensive_test_data): - """Test basic table creation and data loading""" - loader = IcebergLoader(iceberg_basic_config) - - with loader: - result = loader.load_table(comprehensive_test_data, 'test_transactions', mode=LoadMode.OVERWRITE) - - assert result.success == True - assert result.rows_loaded == 1000 - assert result.loader_type == 'iceberg' - assert result.table_name == 'test_transactions' - assert 'operation' in result.metadata - assert 'rows_loaded' in result.metadata - assert result.metadata['namespace'] == iceberg_basic_config['namespace'] - - def test_append_mode(self, iceberg_basic_config, comprehensive_test_data): - """Test append mode functionality""" - loader = IcebergLoader(iceberg_basic_config) - - with loader: - result = loader.load_table(comprehensive_test_data, 'test_append', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 1000 - - additional_data = comprehensive_test_data.slice(0, 100) - result = loader.load_table(additional_data, 'test_append', mode=LoadMode.APPEND) - - assert result.success == True - assert result.rows_loaded == 100 - assert result.metadata['operation'] == 'load_table' - - def test_batch_loading(self, iceberg_basic_config, comprehensive_test_data): - """Test batch loading functionality""" - loader = IcebergLoader(iceberg_basic_config) - - with loader: - batches = comprehensive_test_data.to_batches(max_chunksize=200) - - for i, batch in enumerate(batches): - mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND - result = loader.load_batch(batch, 'test_batches', mode=mode) - - assert result.success == True - assert result.rows_loaded == batch.num_rows - assert result.metadata['operation'] == 'load_batch' - assert result.metadata['batch_size'] == batch.num_rows - assert result.metadata['schema_fields'] == len(batch.schema) - - def test_partitioning(self, iceberg_partitioned_config, small_test_data): - """Test table partitioning functionality""" - loader = IcebergLoader(iceberg_partitioned_config) - - with loader: - # Load partitioned data - result = loader.load_table(small_test_data, 'test_partitioned', mode=LoadMode.OVERWRITE) - - assert result.success == True - # Note: Partitioning requires creating PartitionSpec objects now - assert 
result.metadata['namespace'] == iceberg_partitioned_config['namespace'] - - def test_timestamp_conversion(self, iceberg_basic_config): - """Test timestamp precision conversion (ns -> us)""" - # Create data with nanosecond timestamps - timestamp_data = pa.table( - { - 'id': [1, 2, 3], - 'event_time': pa.array( - [ - datetime(2024, 1, 1, 10, 0, 0, 123456), # microsecond precision - datetime(2024, 1, 1, 11, 0, 0, 654321), - datetime(2024, 1, 1, 12, 0, 0, 987654), - ], - type=pa.timestamp('ns', tz='UTC'), - ), # nanosecond type - 'name': ['event1', 'event2', 'event3'], - } - ) - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Just verify that we can load nanosecond timestamps successfully - result = loader.load_table(timestamp_data, 'test_timestamps', mode=LoadMode.OVERWRITE) - - assert result.success == True - assert result.rows_loaded == 3 - - # The conversion happens internally - we just care that it works - - def test_schema_evolution(self, iceberg_basic_config, small_test_data): - """Test schema evolution functionality""" - loader = IcebergLoader(iceberg_basic_config) - - with loader: - result = loader.load_table(small_test_data, 'test_schema_evolution', mode=LoadMode.OVERWRITE) - assert result.success == True - - extended_data_dict = small_test_data.to_pydict() - extended_data_dict['new_column'] = list(range(len(extended_data_dict['id']))) - extended_data_dict['another_field'] = ['test_value'] * len(extended_data_dict['id']) - extended_table = pa.Table.from_pydict(extended_data_dict) - - result = loader.load_table(extended_table, 'test_schema_evolution', mode=LoadMode.APPEND) - - # Schema evolution should work successfully - assert result.success == True - assert result.rows_loaded > 0 - - # Verify that the new columns were added to the schema - table_info = loader.get_table_info('test_schema_evolution') - assert table_info['exists'] == True - assert 'new_column' in table_info['columns'] - assert 'another_field' in table_info['columns'] - - def test_error_handling_invalid_catalog(self): - """Test error handling with invalid catalog configuration""" - invalid_config = { - 'catalog_config': {'type': 'invalid_catalog_type', 'uri': 'invalid://invalid'}, - 'namespace': 'test', - 'create_namespace': True, - 'create_table': True, - } - - loader = IcebergLoader(invalid_config) - - # Should raise an error on connect - with pytest.raises(ValueError): - loader.connect() - - def test_error_handling_invalid_namespace(self, iceberg_test_env): - """Test error handling when namespace creation fails""" - config = { - 'catalog_config': { - 'type': 'sql', - 'uri': f'sqlite:///{iceberg_test_env}/catalog.db', - 'warehouse': f'file://{iceberg_test_env}/warehouse', - }, - 'namespace': 'test_namespace', - 'create_namespace': False, # Don't create namespace - 'create_table': True, - } - - loader = IcebergLoader(config) - - # Should fail if namespace doesn't exist and create_namespace=False - from pyiceberg.exceptions import NoSuchNamespaceError - - with pytest.raises(NoSuchNamespaceError): - loader.connect() - - def test_load_mode_overwrite(self, iceberg_basic_config, small_test_data): - """Test overwrite mode functionality""" - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Initial load - result = loader.load_table(small_test_data, 'test_overwrite', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 5 - - # Create different data - different_data = pa.table( - { - 'id': [10, 20], - 'name': ['x', 'y'], - 'value': [100.0, 200.0], - 
'timestamp': pa.array( - [datetime(2024, 2, 1, 10, 0, 0), datetime(2024, 2, 1, 11, 0, 0)], - type=pa.timestamp('ns', tz='UTC'), - ), - 'year': [2024, 2024], - 'month': [2, 2], - 'day': [1, 1], - 'active': [True, True], - } - ) - - # Overwrite with different data - result = loader.load_table(different_data, 'test_overwrite', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 2 - - def test_context_manager(self, iceberg_basic_config, small_test_data): - """Test context manager functionality""" - loader = IcebergLoader(iceberg_basic_config) - - # Test context manager auto-connect/disconnect - assert not loader._is_connected - - with loader: - assert loader._is_connected == True - - result = loader.load_table(small_test_data, 'test_context', mode=LoadMode.OVERWRITE) - assert result.success == True - - # Should be disconnected after context exit - assert loader._is_connected == False - - def test_load_result_metadata(self, iceberg_basic_config, comprehensive_test_data): - """Test that LoadResult contains proper metadata""" - loader = IcebergLoader(iceberg_basic_config) - - with loader: - result = loader.load_table(comprehensive_test_data, 'test_metadata', mode=LoadMode.OVERWRITE) - - assert result.success == True - assert result.loader_type == 'iceberg' - assert result.table_name == 'test_metadata' - assert result.rows_loaded == 1000 - assert result.duration > 0 - - # Check metadata content - metadata = result.metadata - assert 'operation' in metadata - assert 'rows_loaded' in metadata - assert 'columns' in metadata - assert 'namespace' in metadata - assert metadata['namespace'] == iceberg_basic_config['namespace'] - assert metadata['rows_loaded'] == 1000 - assert metadata['columns'] == len(comprehensive_test_data.schema) - - -@pytest.mark.integration -@pytest.mark.iceberg -@pytest.mark.slow -class TestIcebergLoaderAdvanced: - """Advanced integration tests for Iceberg loader""" - - def test_large_data_performance(self, iceberg_basic_config): - """Test performance with larger datasets""" - # Create larger test dataset - large_data = { - 'id': list(range(10000)), - 'value': [i * 0.123 for i in range(10000)], - 'category': [f'category_{i % 10}' for i in range(10000)], - 'year': [2024] * 10000, - 'month': [(i // 800) % 12 + 1 for i in range(10000)], - 'timestamp': pa.array( - [datetime(2024, 1, 1) + timedelta(seconds=i) for i in range(10000)], type=pa.timestamp('ns', tz='UTC') - ), - } - - large_table = pa.Table.from_pydict(large_data) - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Load large dataset - result = loader.load_table(large_table, 'test_performance', mode=LoadMode.OVERWRITE) - - assert result.success == True - assert result.rows_loaded == 10000 - assert result.duration < 300 # Should complete within reasonable time - - def test_multiple_tables_same_loader(self, iceberg_basic_config, small_test_data): - """Test loading multiple tables with the same loader instance""" - loader = IcebergLoader(iceberg_basic_config) - - with loader: - table_names = ['table_1', 'table_2', 'table_3'] - - for table_name in table_names: - result = loader.load_table(small_test_data, table_name, mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.table_name == table_name - assert result.rows_loaded == 5 - - def test_batch_streaming(self, iceberg_basic_config, comprehensive_test_data): - """Test streaming batch operations""" - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Convert to batch iterator - batches = 
comprehensive_test_data.to_batches(max_chunksize=100) - batch_list = list(batches) - - # Load using load_stream method from base class - results = list(loader.load_stream(iter(batch_list), 'test_streaming')) - - # Verify all batches were processed - total_rows = sum(r.rows_loaded for r in results if r.success) - assert total_rows == 1000 - - # Verify all operations succeeded - for result in results: - assert result.success == True - assert result.loader_type == 'iceberg' - - def test_upsert_operations(self, iceberg_basic_config): - """Test UPSERT/MERGE operations with automatic matching""" - # Use basic config - no special configuration needed for upsert - upsert_config = iceberg_basic_config.copy() - - # Initial data - initial_data = {'id': [1, 2, 3], 'name': ['Alice', 'Bob', 'Charlie'], 'value': [100, 200, 300]} - initial_table = pa.Table.from_pydict(initial_data) - - loader = IcebergLoader(upsert_config) - - with loader: - # Load initial data - result1 = loader.load_table(initial_table, 'test_upsert', mode=LoadMode.APPEND) - assert result1.success == True - assert result1.rows_loaded == 3 - - # Upsert data (update existing + insert new) - upsert_data = { - 'id': [2, 3, 4], # 2,3 exist (update), 4 is new (insert) - 'name': ['Bob_Updated', 'Charlie_Updated', 'David'], - 'value': [250, 350, 400], - } - upsert_table = pa.Table.from_pydict(upsert_data) - - result2 = loader.load_table(upsert_table, 'test_upsert', mode=LoadMode.UPSERT) - assert result2.success == True - assert result2.rows_loaded == 3 - - def test_upsert_simple(self, iceberg_basic_config): - """Test simple UPSERT operations with default behavior""" - - test_data = {'id': [1, 2, 3], 'name': ['Alice', 'Bob', 'Charlie'], 'value': [100, 200, 300]} - test_table = pa.Table.from_pydict(test_data) - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Simple upsert with default settings - result = loader.load_table(test_table, 'test_simple_upsert', mode=LoadMode.UPSERT) - assert result.success == True - assert result.rows_loaded == 3 - - def test_upsert_fallback_to_append(self, iceberg_basic_config): - """Test that UPSERT falls back to APPEND when upsert fails""" - - test_data = {'id': [1, 2, 3], 'name': ['Alice', 'Bob', 'Charlie'], 'value': [100, 200, 300]} - test_table = pa.Table.from_pydict(test_data) - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Even if upsert fails, should fallback gracefully - result = loader.load_table(test_table, 'test_upsert_fallback', mode=LoadMode.UPSERT) - assert result.success == True - assert result.rows_loaded == 3 - - def test_handle_reorg_empty_table(self, iceberg_basic_config): - """Test reorg handling on empty table""" - from src.amp.streaming.types import BlockRange - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Create table with one row first - initial_data = pa.table( - {'id': [999], 'block_num': [999], '_meta_block_ranges': ['[{"network": "test", "start": 1, "end": 2}]']} - ) - loader.load_table(initial_data, 'test_reorg_empty', mode=LoadMode.OVERWRITE) - - # Now overwrite with empty data to simulate empty table - empty_data = pa.table({'id': [], 'block_num': [], '_meta_block_ranges': []}) - loader.load_table(empty_data, 'test_reorg_empty', mode=LoadMode.OVERWRITE) - - # Call handle reorg on empty table - invalidation_ranges = [BlockRange(network='ethereum', start=100, end=200)] - - # Should not raise any errors - loader._handle_reorg(invalidation_ranges, 'test_reorg_empty', 'test_connection') - - # Verify table still exists - 
table_info = loader.get_table_info('test_reorg_empty') - assert table_info['exists'] == True - - def test_handle_reorg_no_metadata_column(self, iceberg_basic_config): - """Test reorg handling when table lacks metadata column""" - from src.amp.streaming.types import BlockRange - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Create table without metadata column - data = pa.table({'id': [1, 2, 3], 'block_num': [100, 150, 200], 'value': [10.0, 20.0, 30.0]}) - loader.load_table(data, 'test_reorg_no_meta', mode=LoadMode.OVERWRITE) - - # Call handle reorg - invalidation_ranges = [BlockRange(network='ethereum', start=150, end=250)] - - # Should log warning and not modify data - loader._handle_reorg(invalidation_ranges, 'test_reorg_no_meta', 'test_connection') - - # Verify data unchanged - table_info = loader.get_table_info('test_reorg_no_meta') - assert table_info['exists'] == True - - def test_handle_reorg_single_network(self, iceberg_basic_config): - """Test reorg handling for single network data""" - from src.amp.streaming.types import BlockRange - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Create table with metadata - block_ranges = [ - [{'network': 'ethereum', 'start': 100, 'end': 110}], - [{'network': 'ethereum', 'start': 150, 'end': 160}], - [{'network': 'ethereum', 'start': 200, 'end': 210}], - ] - - data = pa.table( - { - 'id': [1, 2, 3], - 'block_num': [105, 155, 205], - '_meta_block_ranges': [json.dumps(ranges) for ranges in block_ranges], - } - ) - - # Load initial data - result = loader.load_table(data, 'test_reorg_single', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 3 - - # Reorg from block 155 - should delete rows 2 and 3 - invalidation_ranges = [BlockRange(network='ethereum', start=155, end=300)] - loader._handle_reorg(invalidation_ranges, 'test_reorg_single', 'test_connection') - - # Verify only first row remains - # Since we can't easily query Iceberg tables in tests, we'll verify through table info - table_info = loader.get_table_info('test_reorg_single') - assert table_info['exists'] == True - # The actual row count verification would require scanning the table - - def test_handle_reorg_multi_network(self, iceberg_basic_config): - """Test reorg handling preserves data from unaffected networks""" - from src.amp.streaming.types import BlockRange - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Create table with data from multiple networks - block_ranges = [ - [{'network': 'ethereum', 'start': 100, 'end': 110}], - [{'network': 'polygon', 'start': 100, 'end': 110}], - [{'network': 'ethereum', 'start': 150, 'end': 160}], - [{'network': 'polygon', 'start': 150, 'end': 160}], - ] - - data = pa.table( - { - 'id': [1, 2, 3, 4], - 'network': ['ethereum', 'polygon', 'ethereum', 'polygon'], - '_meta_block_ranges': [json.dumps(r) for r in block_ranges], - } - ) - - # Load initial data - result = loader.load_table(data, 'test_reorg_multi', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 4 - - # Reorg only ethereum from block 150 - invalidation_ranges = [BlockRange(network='ethereum', start=150, end=200)] - loader._handle_reorg(invalidation_ranges, 'test_reorg_multi', 'test_connection') - - # Verify ethereum row 3 deleted, but polygon rows preserved - table_info = loader.get_table_info('test_reorg_multi') - assert table_info['exists'] == True - - def test_handle_reorg_overlapping_ranges(self, iceberg_basic_config): - """Test reorg with 
overlapping block ranges""" - from src.amp.streaming.types import BlockRange - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Create data with overlapping ranges - block_ranges = [ - [{'network': 'ethereum', 'start': 90, 'end': 110}], # Overlaps with reorg - [{'network': 'ethereum', 'start': 140, 'end': 160}], # Overlaps with reorg - [{'network': 'ethereum', 'start': 170, 'end': 190}], # After reorg - ] - - data = pa.table({'id': [1, 2, 3], '_meta_block_ranges': [json.dumps(ranges) for ranges in block_ranges]}) - - # Load initial data - result = loader.load_table(data, 'test_reorg_overlap', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 3 - - # Reorg from block 150 - should delete rows where end >= 150 - invalidation_ranges = [BlockRange(network='ethereum', start=150, end=200)] - loader._handle_reorg(invalidation_ranges, 'test_reorg_overlap', 'test_connection') - - # Only first row should remain (ends at 110 < 150) - table_info = loader.get_table_info('test_reorg_overlap') - assert table_info['exists'] == True - - def test_handle_reorg_multiple_invalidations(self, iceberg_basic_config): - """Test handling multiple invalidation ranges""" - from src.amp.streaming.types import BlockRange - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Create data from multiple networks - block_ranges = [ - [{'network': 'ethereum', 'start': 100, 'end': 110}], - [{'network': 'polygon', 'start': 200, 'end': 210}], - [{'network': 'arbitrum', 'start': 300, 'end': 310}], - [{'network': 'ethereum', 'start': 150, 'end': 160}], - [{'network': 'polygon', 'start': 250, 'end': 260}], - ] - - data = pa.table({'id': [1, 2, 3, 4, 5], '_meta_block_ranges': [json.dumps(r) for r in block_ranges]}) - - # Load initial data - result = loader.load_table(data, 'test_reorg_multiple', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 5 - - # Multiple reorgs - invalidation_ranges = [ - BlockRange(network='ethereum', start=150, end=200), # Affects row 4 - BlockRange(network='polygon', start=250, end=300), # Affects row 5 - ] - loader._handle_reorg(invalidation_ranges, 'test_reorg_multiple', 'test_connection') - - # Rows 1, 2, 3 should remain - table_info = loader.get_table_info('test_reorg_multiple') - assert table_info['exists'] == True - - def test_streaming_with_reorg(self, iceberg_basic_config): - """Test streaming data with reorg support""" - from src.amp.streaming.types import ( - BatchMetadata, - BlockRange, - ResponseBatch, - ) - - loader = IcebergLoader(iceberg_basic_config) - - with loader: - # Create streaming data with metadata - data1 = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]}) - - data2 = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]}) - - # Create response batches using factory methods (with hashes for proper state management) - response1 = ResponseBatch.data_batch( - data=data1, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')]), - ) - - response2 = ResponseBatch.data_batch( - data=data2, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef456')]), - ) - - # Simulate reorg event using factory method - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - - # Process streaming data - stream = [response1, response2, reorg_response] - results = list(loader.load_stream_continuous(iter(stream), 
'test_streaming_reorg')) - - # Verify results - assert len(results) == 3 - assert results[0].success == True - assert results[0].rows_loaded == 2 - assert results[1].success == True - assert results[1].rows_loaded == 2 - assert results[2].success == True - assert results[2].is_reorg == True diff --git a/tests/integration/test_lmdb_loader.py b/tests/integration/test_lmdb_loader.py deleted file mode 100644 index 20e2f67..0000000 --- a/tests/integration/test_lmdb_loader.py +++ /dev/null @@ -1,653 +0,0 @@ -# tests/integration/test_lmdb_loader.py -""" -Integration tests for LMDB loader implementation. -These tests use a local LMDB instance. -""" - -import os -import shutil -import tempfile -import time -from datetime import datetime - -import pyarrow as pa -import pytest - -try: - from src.amp.loaders.base import LoadMode - from src.amp.loaders.implementations.lmdb_loader import LMDBLoader -except ImportError: - pytest.skip('LMD loader modules not available', allow_module_level=True) - - -@pytest.fixture -def lmdb_test_dir(): - """Create and cleanup temporary directory for LMDB databases""" - temp_dir = tempfile.mkdtemp(prefix='lmdb_test_') - yield temp_dir - - # Cleanup - shutil.rmtree(temp_dir, ignore_errors=True) - - -@pytest.fixture -def lmdb_config(lmdb_test_dir): - """LMDB configuration for testing""" - return { - 'db_path': os.path.join(lmdb_test_dir, 'test.lmdb'), - 'map_size': 100 * 1024**2, # 100MB for tests - 'transaction_size': 1000, - } - - -@pytest.fixture -def test_table_name(): - """Generate unique table name for each test""" - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f') - return f'test_table_{timestamp}' - - -@pytest.fixture -def sample_test_data(): - """Create sample test data""" - data = { - 'id': list(range(100)), - 'name': [f'name_{i}' for i in range(100)], - 'value': [i * 1.5 for i in range(100)], - 'active': [i % 2 == 0 for i in range(100)], - } - return pa.Table.from_pydict(data) - - -@pytest.fixture -def blockchain_test_data(): - """Create blockchain-like test data""" - data = { - 'block_number': list(range(1000, 2000)), - 'block_hash': [f'0x{i:064x}' for i in range(1000, 2000)], - 'timestamp': [1600000000 + i * 12 for i in range(1000)], - 'transaction_count': [i % 100 + 1 for i in range(1000)], - 'gas_used': [21000 + i * 1000 for i in range(1000)], - } - return pa.Table.from_pydict(data) - - -@pytest.mark.integration -@pytest.mark.lmdb -class TestLMDBLoaderIntegration: - """Test LMDB loader implementation""" - - def test_connection(self, lmdb_config): - """Test basic connection to LMDB""" - loader = LMDBLoader(lmdb_config) - - # Should not be connected initially - assert not loader.is_connected - - # Connect - loader.connect() - assert loader.is_connected - assert loader.env is not None - - # Check that database was created - assert os.path.exists(lmdb_config['db_path']) - - # Disconnect - loader.disconnect() - assert not loader.is_connected - - def test_load_table_basic(self, lmdb_config, sample_test_data, test_table_name): - """Test basic table loading""" - loader = LMDBLoader(lmdb_config) - loader.connect() - - # Load data - result = loader.load_table(sample_test_data, test_table_name) - - assert result.success - assert result.rows_loaded == 100 - assert result.table_name == test_table_name - assert result.loader_type == 'lmdb' - - # Verify data was stored - with loader.env.begin() as txn: - cursor = txn.cursor() - count = sum(1 for _ in cursor) - assert count == 100 - - loader.disconnect() - - def test_key_column_strategy(self, lmdb_config, 
sample_test_data, test_table_name): - """Test key generation using specific column""" - config = {**lmdb_config, 'key_column': 'id'} - loader = LMDBLoader(config) - loader.connect() - - result = loader.load_table(sample_test_data, test_table_name) - assert result.success - - # Verify keys were generated correctly - with loader.env.begin() as txn: - # Check specific keys - for i in range(10): - key = str(i).encode('utf-8') - value = txn.get(key) - assert value is not None - - loader.disconnect() - - def test_key_pattern_strategy(self, lmdb_config, blockchain_test_data, test_table_name): - """Test pattern-based key generation""" - config = {**lmdb_config, 'key_pattern': '{table}:block:{block_number}'} - loader = LMDBLoader(config) - loader.connect() - - result = loader.load_table(blockchain_test_data, test_table_name) - assert result.success - - # Verify pattern-based keys - with loader.env.begin() as txn: - key = f'{test_table_name}:block:1000'.encode('utf-8') - value = txn.get(key) - assert value is not None - - loader.disconnect() - - def test_composite_key_strategy(self, lmdb_config, sample_test_data, test_table_name): - """Test composite key generation""" - config = {**lmdb_config, 'composite_key_columns': ['name', 'id']} - loader = LMDBLoader(config) - loader.connect() - - result = loader.load_table(sample_test_data, test_table_name) - assert result.success - - # Verify composite keys - with loader.env.begin() as txn: - key = 'name_0:0'.encode('utf-8') - value = txn.get(key) - assert value is not None - - loader.disconnect() - - def test_named_database(self, lmdb_config, sample_test_data): - """Test using named databases""" - config = {**lmdb_config, 'database_name': 'blocks', 'key_column': 'id'} - loader = LMDBLoader(config) - loader.connect() - - # Load to named database - result = loader.load_table(sample_test_data, 'any_table') - assert result.success - - # Verify data is in named database - db = loader.env.open_db(b'blocks') - with loader.env.begin(db=db) as txn: - cursor = txn.cursor() - count = sum(1 for _ in cursor) - assert count == 100 - - loader.disconnect() - - def test_load_modes(self, lmdb_config, sample_test_data, test_table_name): - """Test different load modes""" - config = {**lmdb_config, 'key_column': 'id'} - loader = LMDBLoader(config) - loader.connect() - - # Initial load - result1 = loader.load_table(sample_test_data, test_table_name) - assert result1.success - assert result1.rows_loaded == 100 - - # Append mode (should fail for duplicate keys) - result2 = loader.load_table(sample_test_data, test_table_name, mode=LoadMode.APPEND) - assert not result2.success # Should fail due to duplicate keys - assert result2.rows_loaded == 0 # No new rows added - assert 'Key already exists' in str(result2.error) - - # Overwrite mode - result3 = loader.load_table(sample_test_data, test_table_name, mode=LoadMode.OVERWRITE) - assert result3.success - assert result3.rows_loaded == 100 - - loader.disconnect() - - def test_transaction_batching(self, lmdb_test_dir, blockchain_test_data, test_table_name): - """Test transaction batching performance""" - # Small transactions - config1 = { - 'db_path': os.path.join(lmdb_test_dir, 'test1.lmdb'), - 'map_size': 100 * 1024**2, - 'transaction_size': 100, - 'key_column': 'block_number', - } - loader1 = LMDBLoader(config1) - loader1.connect() - - start = time.time() - result1 = loader1.load_table(blockchain_test_data, test_table_name) - time1 = time.time() - start - - loader1.disconnect() - - # Large transactions - use different database - 
config2 = { - 'db_path': os.path.join(lmdb_test_dir, 'test2.lmdb'), - 'map_size': 100 * 1024**2, - 'transaction_size': 1000, - 'key_column': 'block_number', - } - loader2 = LMDBLoader(config2) - loader2.connect() - - start = time.time() - result2 = loader2.load_table(blockchain_test_data, test_table_name) - time2 = time.time() - start - - loader2.disconnect() - - # Both should succeed - assert result1.success - assert result2.success - - # Larger transactions should generally be faster - # (though this might not always be true in small test datasets) - print(f'Small txn time: {time1:.3f}s, Large txn time: {time2:.3f}s') - - def test_byte_key_handling(self, lmdb_config): - """Test handling of byte array keys""" - # Create data with byte keys - data = {'key': [b'key1', b'key2', b'key3'], 'value': [1, 2, 3]} - table = pa.Table.from_pydict(data) - - config = {**lmdb_config, 'key_column': 'key'} - loader = LMDBLoader(config) - loader.connect() - - result = loader.load_table(table, 'byte_test') - assert result.success - assert result.rows_loaded == 3 - - # Verify byte keys work correctly - with loader.env.begin() as txn: - assert txn.get(b'key1') is not None - assert txn.get(b'key2') is not None - assert txn.get(b'key3') is not None - - loader.disconnect() - - def test_large_batch_loading(self, lmdb_config): - """Test loading large batches""" - # Create larger dataset - size = 50000 - data = { - 'id': list(range(size)), - 'data': ['x' * 100 for _ in range(size)], # 100 chars per row - } - table = pa.Table.from_pydict(data) - - config = { - **lmdb_config, - 'map_size': 500 * 1024**2, # 500MB - 'transaction_size': 5000, - 'key_column': 'id', - } - loader = LMDBLoader(config) - loader.connect() - - start = time.time() - result = loader.load_table(table, 'large_test') - duration = time.time() - start - - assert result.success - assert result.rows_loaded == size - - # Check performance metrics - throughput = result.metadata['throughput_rows_per_sec'] - print(f'Loaded {size} rows in {duration:.2f}s ({throughput:.0f} rows/sec)') - assert throughput > 10000 # Should handle at least 10k rows/sec - - loader.disconnect() - - def test_error_handling(self, lmdb_config, sample_test_data): - """Test error handling""" - # Test with invalid key column - config = {**lmdb_config, 'key_column': 'nonexistent_column'} - loader = LMDBLoader(config) - loader.connect() - - result = loader.load_table(sample_test_data, 'error_test') - assert not result.success - assert 'not found in data' in result.error - - loader.disconnect() - - def test_data_persistence(self, lmdb_config, sample_test_data, test_table_name): - """Test that data persists across connections""" - config = {**lmdb_config, 'key_column': 'id'} - - # First connection - write data - loader1 = LMDBLoader(config) - loader1.connect() - result = loader1.load_table(sample_test_data, test_table_name) - assert result.success - loader1.disconnect() - - # Second connection - verify data - loader2 = LMDBLoader(config) - loader2.connect() - - with loader2.env.begin() as txn: - # Check a few keys - for i in range(10): - key = str(i).encode('utf-8') - value = txn.get(key) - assert value is not None - - # Deserialize and verify it's valid Arrow data - import pyarrow as pa - - reader = pa.ipc.open_stream(value) - batch = reader.read_next_batch() - assert batch.num_rows == 1 - - loader2.disconnect() - - def test_handle_reorg_empty_db(self, lmdb_config): - """Test reorg handling on empty database""" - from src.amp.streaming.types import BlockRange - - loader = 
LMDBLoader(lmdb_config) - loader.connect() - - # Call handle reorg on empty database - invalidation_ranges = [BlockRange(network='ethereum', start=100, end=200)] - - # Should not raise any errors - loader._handle_reorg(invalidation_ranges, 'test_reorg_empty', 'test_connection') - - loader.disconnect() - - def test_handle_reorg_no_metadata(self, lmdb_config): - """Test reorg handling when data lacks metadata column""" - from src.amp.streaming.types import BlockRange - - config = {**lmdb_config, 'key_column': 'id'} - loader = LMDBLoader(config) - loader.connect() - - # Create data without metadata column - data = pa.table({'id': [1, 2, 3], 'block_num': [100, 150, 200], 'value': [10.0, 20.0, 30.0]}) - loader.load_table(data, 'test_reorg_no_meta', mode=LoadMode.OVERWRITE) - - # Call handle reorg - invalidation_ranges = [BlockRange(network='ethereum', start=150, end=250)] - - # Should not delete any data (no metadata to check) - loader._handle_reorg(invalidation_ranges, 'test_reorg_no_meta', 'test_connection') - - # Verify data still exists - with loader.env.begin() as txn: - assert txn.get(b'1') is not None - assert txn.get(b'2') is not None - assert txn.get(b'3') is not None - - loader.disconnect() - - def test_handle_reorg_single_network(self, lmdb_config): - """Test reorg handling for single network data""" - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - config = {**lmdb_config, 'key_column': 'id'} - loader = LMDBLoader(config) - loader.connect() - - # Create streaming batches with metadata - batch1 = pa.RecordBatch.from_pydict({'id': [1], 'block_num': [105]}) - batch2 = pa.RecordBatch.from_pydict({'id': [2], 'block_num': [155]}) - batch3 = pa.RecordBatch.from_pydict({'id': [3], 'block_num': [205]}) - - # Create response batches with hashes - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=200, end=210, hash='0x123')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - stream = [response1, response2, response3] - results = list(loader.load_stream_continuous(iter(stream), 'test_reorg_single')) - assert len(results) == 3 - assert all(r.success for r in results) - - # Verify all data exists - with loader.env.begin() as txn: - assert txn.get(b'1') is not None - assert txn.get(b'2') is not None - assert txn.get(b'3') is not None - - # Reorg from block 155 - should delete rows 2 and 3 - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=155, end=300)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), 'test_reorg_single')) - assert len(reorg_results) == 1 - assert reorg_results[0].success - assert reorg_results[0].is_reorg - - # Verify only first row remains - with loader.env.begin() as txn: - assert txn.get(b'1') is not None - assert txn.get(b'2') is None # Deleted - assert txn.get(b'3') is None # Deleted - - loader.disconnect() - - def 
test_handle_reorg_multi_network(self, lmdb_config): - """Test reorg handling preserves data from unaffected networks""" - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - config = {**lmdb_config, 'key_column': 'id'} - loader = LMDBLoader(config) - loader.connect() - - # Create streaming batches from multiple networks - batch1 = pa.RecordBatch.from_pydict({'id': [1], 'network': ['ethereum']}) - batch2 = pa.RecordBatch.from_pydict({'id': [2], 'network': ['polygon']}) - batch3 = pa.RecordBatch.from_pydict({'id': [3], 'network': ['ethereum']}) - batch4 = pa.RecordBatch.from_pydict({'id': [4], 'network': ['polygon']}) - - # Create response batches with network-specific ranges - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='polygon', start=100, end=110, hash='0xbbb')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xccc')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response4 = ResponseBatch.data_batch( - data=batch4, - metadata=BatchMetadata( - ranges=[BlockRange(network='polygon', start=150, end=160, hash='0xddd')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - stream = [response1, response2, response3, response4] - results = list(loader.load_stream_continuous(iter(stream), 'test_reorg_multi')) - assert len(results) == 4 - assert all(r.success for r in results) - - # Reorg only ethereum from block 150 - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), 'test_reorg_multi')) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # Verify ethereum row 3 deleted, but polygon rows preserved - with loader.env.begin() as txn: - assert txn.get(b'1') is not None # ethereum block 100 - assert txn.get(b'2') is not None # polygon block 100 - assert txn.get(b'3') is None # ethereum block 150 (deleted) - assert txn.get(b'4') is not None # polygon block 150 - - loader.disconnect() - - def test_handle_reorg_overlapping_ranges(self, lmdb_config): - """Test reorg with overlapping block ranges""" - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - config = {**lmdb_config, 'key_column': 'id'} - loader = LMDBLoader(config) - loader.connect() - - # Create streaming batches with different ranges - batch1 = pa.RecordBatch.from_pydict({'id': [1]}) - batch2 = pa.RecordBatch.from_pydict({'id': [2]}) - batch3 = pa.RecordBatch.from_pydict({'id': [3]}) - - # Batch 1: 90-110 (ends before reorg start of 150) - # Batch 2: 140-160 (overlaps with reorg) - # Batch 3: 170-190 (after reorg, but should be deleted as 170 >= 150) - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=90, end=110, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response2 = ResponseBatch.data_batch( - 
data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=140, end=160, hash='0xbbb')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=170, end=190, hash='0xccc')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - stream = [response1, response2, response3] - results = list(loader.load_stream_continuous(iter(stream), 'test_reorg_overlap')) - assert len(results) == 3 - assert all(r.success for r in results) - - # Reorg from block 150 - should delete batches 2 and 3 - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), 'test_reorg_overlap')) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # Only first row should remain (ends at 110 < 150) - with loader.env.begin() as txn: - assert txn.get(b'1') is not None - assert txn.get(b'2') is None # Deleted (end=160 >= 150) - assert txn.get(b'3') is None # Deleted (end=190 >= 150) - - loader.disconnect() - - def test_streaming_with_reorg(self, lmdb_config): - """Test streaming data with reorg support""" - from src.amp.streaming.types import ( - BatchMetadata, - BlockRange, - ResponseBatch, - ) - - config = {**lmdb_config, 'key_column': 'id'} - loader = LMDBLoader(config) - loader.connect() - - # Create streaming data with metadata - data1 = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]}) - - data2 = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]}) - - # Create response batches using factory methods (with hashes for proper state management) - response1 = ResponseBatch.data_batch( - data=data1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - response2 = ResponseBatch.data_batch( - data=data2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef456')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Simulate reorg event using factory method - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - - # Process streaming data - stream = [response1, response2, reorg_response] - results = list(loader.load_stream_continuous(iter(stream), 'test_streaming_reorg')) - - # Verify results - assert len(results) == 3 - assert results[0].success - assert results[0].rows_loaded == 2 - assert results[1].success - assert results[1].rows_loaded == 2 - assert results[2].success - assert results[2].is_reorg - - # Verify reorg deleted the second batch - with loader.env.begin() as txn: - assert txn.get(b'1') is not None - assert txn.get(b'2') is not None - assert txn.get(b'3') is None # Deleted by reorg - assert txn.get(b'4') is None # Deleted by reorg - - loader.disconnect() - - -if __name__ == '__main__': - pytest.main([__file__, '-v']) diff --git a/tests/integration/test_postgresql_loader.py b/tests/integration/test_postgresql_loader.py deleted file mode 100644 index 649868c..0000000 --- a/tests/integration/test_postgresql_loader.py +++ /dev/null @@ -1,838 +0,0 @@ -# tests/integration/test_postgresql_loader.py -""" 
-Integration tests for PostgreSQL loader implementation. -These tests require a running PostgreSQL instance. -""" - -import time -from datetime import datetime - -import pyarrow as pa -import pytest - -try: - from src.amp.loaders.base import LoadMode - from src.amp.loaders.implementations.postgresql_loader import PostgreSQLLoader -except ImportError: - pytest.skip('amp modules not available', allow_module_level=True) - - -@pytest.fixture -def test_table_name(): - """Generate unique table name for each test""" - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f') - return f'test_table_{timestamp}' - - -@pytest.fixture -def postgresql_type_test_data(): - """Create test data specifically for PostgreSQL data type testing""" - data = { - 'id': list(range(1000)), - 'text_field': [f'text_{i}' for i in range(1000)], - 'float_field': [i * 1.23 for i in range(1000)], - 'bool_field': [i % 2 == 0 for i in range(1000)], - } - return pa.Table.from_pydict(data) - - -@pytest.fixture -def cleanup_tables(postgresql_test_config): - """Cleanup test tables after tests""" - tables_to_clean = [] - - yield tables_to_clean - - # Cleanup - loader = PostgreSQLLoader(postgresql_test_config) - try: - loader.connect() - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - for table in tables_to_clean: - try: - cur.execute(f'DROP TABLE IF EXISTS {table} CASCADE') - conn.commit() - except Exception: - pass - finally: - loader.pool.putconn(conn) - loader.disconnect() - except Exception: - pass - - -@pytest.mark.integration -@pytest.mark.postgresql -class TestPostgreSQLLoaderIntegration: - """Integration tests for PostgreSQL loader""" - - def test_loader_connection(self, postgresql_test_config): - """Test basic connection to PostgreSQL""" - loader = PostgreSQLLoader(postgresql_test_config) - - # Test connection - loader.connect() - assert loader._is_connected == True - assert loader.pool is not None - - # Test disconnection - loader.disconnect() - assert loader._is_connected == False - assert loader.pool is None - - def test_context_manager(self, postgresql_test_config, small_test_data, test_table_name, cleanup_tables): - """Test context manager functionality""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - assert loader._is_connected == True - - result = loader.load_table(small_test_data, test_table_name) - assert result.success == True - - # Should be disconnected after context - assert loader._is_connected == False - - def test_basic_table_operations(self, postgresql_test_config, small_test_data, test_table_name, cleanup_tables): - """Test basic table creation and data loading""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Test initial table creation - result = loader.load_table(small_test_data, test_table_name, create_table=True) - - assert result.success == True - assert result.rows_loaded == 5 - assert result.loader_type == 'postgresql' - assert result.table_name == test_table_name - assert 'columns' in result.metadata - assert result.metadata['columns'] == 7 - - def test_append_mode(self, postgresql_test_config, small_test_data, test_table_name, cleanup_tables): - """Test append mode functionality""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Initial load - result = loader.load_table(small_test_data, test_table_name, mode=LoadMode.APPEND) - assert result.success == True - assert 
result.rows_loaded == 5 - - # Append additional data - result = loader.load_table(small_test_data, test_table_name, mode=LoadMode.APPEND) - assert result.success == True - assert result.rows_loaded == 5 - - # Verify total rows - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = cur.fetchone()[0] - assert count == 10 # 5 + 5 - finally: - loader.pool.putconn(conn) - - def test_overwrite_mode(self, postgresql_test_config, small_test_data, test_table_name, cleanup_tables): - """Test overwrite mode functionality""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Initial load - result = loader.load_table(small_test_data, test_table_name, mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 5 - - # Overwrite with different data - new_data = small_test_data.slice(0, 3) # First 3 rows - result = loader.load_table(new_data, test_table_name, mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 3 - - # Verify only new data remains - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = cur.fetchone()[0] - assert count == 3 - finally: - loader.pool.putconn(conn) - - def test_batch_loading(self, postgresql_test_config, medium_test_table, test_table_name, cleanup_tables): - """Test batch loading functionality""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Test loading individual batches - batches = medium_test_table.to_batches(max_chunksize=250) - - for i, batch in enumerate(batches): - mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND - result = loader.load_batch(batch, test_table_name, mode=mode) - - assert result.success == True - assert result.rows_loaded == batch.num_rows - assert result.metadata['batch_size'] == batch.num_rows - - # Verify all data was loaded - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = cur.fetchone()[0] - assert count == 10000 - finally: - loader.pool.putconn(conn) - - def test_data_types(self, postgresql_test_config, postgresql_type_test_data, test_table_name, cleanup_tables): - """Test various data types are handled correctly""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - result = loader.load_table(postgresql_type_test_data, test_table_name) - assert result.success == True - assert result.rows_loaded == 1000 - - # Verify data integrity - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - # Check various data types - cur.execute(f'SELECT id, text_field, float_field, bool_field FROM {test_table_name} WHERE id = 10') - row = cur.fetchone() - assert row[0] == 10 - assert row[1] in ['text_10', '"text_10"'] # Handle potential CSV quoting - assert abs(row[2] - 12.3) < 0.01 # 10 * 1.23 = 12.3 - assert row[3] == True - finally: - loader.pool.putconn(conn) - - def test_null_value_handling(self, postgresql_test_config, null_test_data, test_table_name, cleanup_tables): - """Test comprehensive null value handling across all data types""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - result = loader.load_table(null_test_data, test_table_name) - assert result.success == True - assert 
result.rows_loaded == 10 - - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - # Check text field nulls (rows 3, 6, 9 have index 2, 5, 8) - cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE text_field IS NULL') - text_nulls = cur.fetchone()[0] - assert text_nulls == 3 - - # Check int field nulls (rows 2, 5, 8 have index 1, 4, 7) - cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE int_field IS NULL') - int_nulls = cur.fetchone()[0] - assert int_nulls == 3 - - # Check float field nulls (rows 3, 6, 9 have index 2, 5, 8) - cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE float_field IS NULL') - float_nulls = cur.fetchone()[0] - assert float_nulls == 3 - - # Check bool field nulls (rows 3, 6, 9 have index 2, 5, 8) - cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE bool_field IS NULL') - bool_nulls = cur.fetchone()[0] - assert bool_nulls == 3 - - # Check timestamp field nulls - # (rows where i % 3 == 0, which are ids 3, 6, 9, plus id 1 due to zero indexing) - cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE timestamp_field IS NULL') - timestamp_nulls = cur.fetchone()[0] - assert timestamp_nulls == 4 - - # Check json field nulls (rows where i % 4 == 0, which are ids 4, 8 due to zero indexing pattern) - cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE json_field IS NULL') - json_nulls = cur.fetchone()[0] - assert json_nulls == 3 - - # Verify non-null values are intact - cur.execute(f'SELECT text_field FROM {test_table_name} WHERE id = 1') - text_val = cur.fetchone()[0] - assert text_val in ['a', '"a"'] # Handle potential CSV quoting - - cur.execute(f'SELECT int_field FROM {test_table_name} WHERE id = 1') - int_val = cur.fetchone()[0] - assert int_val == 1 - - cur.execute(f'SELECT float_field FROM {test_table_name} WHERE id = 1') - float_val = cur.fetchone()[0] - assert abs(float_val - 1.1) < 0.01 - - cur.execute(f'SELECT bool_field FROM {test_table_name} WHERE id = 1') - bool_val = cur.fetchone()[0] - assert bool_val == True - finally: - loader.pool.putconn(conn) - - def test_binary_data_handling(self, postgresql_test_config, test_table_name, cleanup_tables): - """Test binary data handling with INSERT fallback""" - cleanup_tables.append(test_table_name) - - # Create data with binary columns - data = {'id': [1, 2, 3], 'binary_data': [b'hello', b'world', b'test'], 'text_data': ['a', 'b', 'c']} - table = pa.Table.from_pydict(data) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - result = loader.load_table(table, test_table_name) - assert result.success == True - assert result.rows_loaded == 3 - - # Verify binary data was stored correctly - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - cur.execute(f'SELECT id, binary_data FROM {test_table_name} ORDER BY id') - rows = cur.fetchall() - assert rows[0][1].tobytes() == b'hello' - assert rows[1][1].tobytes() == b'world' - assert rows[2][1].tobytes() == b'test' - finally: - loader.pool.putconn(conn) - - def test_schema_retrieval(self, postgresql_test_config, small_test_data, test_table_name, cleanup_tables): - """Test schema retrieval functionality""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Create table - result = loader.load_table(small_test_data, test_table_name) - assert result.success == True - - # Get schema - schema = loader.get_table_schema(test_table_name) - assert schema is not None - - # Filter out metadata columns added by PostgreSQL loader - 
non_meta_fields = [ - field for field in schema if not (field.name.startswith('_meta_') or field.name.startswith('_amp_')) - ] - - assert len(non_meta_fields) == len(small_test_data.schema) - - # Verify column names match (excluding metadata columns) - original_names = set(small_test_data.schema.names) - retrieved_names = set(field.name for field in non_meta_fields) - assert original_names == retrieved_names - - def test_error_handling(self, postgresql_test_config, small_test_data): - """Test error handling scenarios""" - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Test loading to non-existent table without create_table - result = loader.load_table(small_test_data, 'non_existent_table', create_table=False) - - assert result.success == False - assert result.error is not None - assert result.rows_loaded == 0 - assert 'does not exist' in result.error - - def test_connection_pooling(self, postgresql_test_config, small_test_data, test_table_name, cleanup_tables): - """Test connection pooling behavior""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Perform multiple operations to test pool reuse - for i in range(5): - subset = small_test_data.slice(i, 1) - mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND - - result = loader.load_table(subset, test_table_name, mode=mode) - assert result.success == True - - # Verify pool is managing connections properly - # Note: _used is a dict in ThreadedConnectionPool, not an int - assert len(loader.pool._used) <= loader.pool.maxconn - - def test_performance_metrics(self, postgresql_test_config, medium_test_table, test_table_name, cleanup_tables): - """Test performance metrics in results""" - cleanup_tables.append(test_table_name) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - start_time = time.time() - result = loader.load_table(medium_test_table, test_table_name) - end_time = time.time() - - assert result.success == True - assert result.duration > 0 - assert result.duration <= (end_time - start_time) - assert result.rows_loaded == 10000 - - # Check metadata contains performance info - assert 'table_size_bytes' in result.metadata - assert result.metadata['table_size_bytes'] > 0 - - -@pytest.mark.integration -@pytest.mark.postgresql -@pytest.mark.slow -class TestPostgreSQLLoaderPerformance: - """Performance tests for PostgreSQL loader""" - - def test_large_data_loading(self, postgresql_test_config, test_table_name, cleanup_tables): - """Test loading large datasets""" - cleanup_tables.append(test_table_name) - - # Create large dataset - large_data = { - 'id': list(range(50000)), - 'value': [i * 0.123 for i in range(50000)], - 'category': [f'category_{i % 100}' for i in range(50000)], - 'description': [f'This is a longer text description for row {i}' for i in range(50000)], - 'created_at': [datetime.now() for _ in range(50000)], - } - large_table = pa.Table.from_pydict(large_data) - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - result = loader.load_table(large_table, test_table_name) - - assert result.success == True - assert result.rows_loaded == 50000 - assert result.duration < 60 # Should complete within 60 seconds - - # Verify data integrity - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = cur.fetchone()[0] - assert count == 50000 - finally: - loader.pool.putconn(conn) - - -@pytest.mark.integration -@pytest.mark.postgresql 
-class TestPostgreSQLLoaderStreaming: - """Integration tests for PostgreSQL loader streaming functionality""" - - def test_streaming_metadata_columns(self, postgresql_test_config, test_table_name, cleanup_tables): - """Test that streaming data creates tables with metadata columns""" - cleanup_tables.append(test_table_name) - - # Import streaming types - from src.amp.streaming.types import BlockRange - - # Create test data with metadata - data = { - 'block_number': [100, 101, 102], - 'transaction_hash': ['0xabc', '0xdef', '0x123'], - 'value': [1.0, 2.0, 3.0], - } - batch = pa.RecordBatch.from_pydict(data) - - # Create metadata with block ranges - block_ranges = [BlockRange(network='ethereum', start=100, end=102)] - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Add metadata columns (simulating what load_stream_continuous does) - batch_with_metadata = loader._add_metadata_columns(batch, block_ranges) - - # Load the batch - result = loader.load_batch(batch_with_metadata, test_table_name, create_table=True) - assert result.success == True - assert result.rows_loaded == 3 - - # Verify metadata columns were created in the table - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - # Check table schema includes metadata columns - cur.execute( - """ - SELECT column_name, data_type - FROM information_schema.columns - WHERE table_name = %s - ORDER BY ordinal_position - """, - (test_table_name,), - ) - - columns = cur.fetchall() - column_names = [col[0] for col in columns] - - # Should have original columns plus metadata columns - assert '_amp_batch_id' in column_names - - # Verify metadata column types - column_types = {col[0]: col[1] for col in columns} - assert ( - 'text' in column_types['_amp_batch_id'].lower() - or 'varchar' in column_types['_amp_batch_id'].lower() - ) - - # Verify data was stored correctly - cur.execute(f'SELECT "_amp_batch_id" FROM {test_table_name} LIMIT 1') - meta_row = cur.fetchone() - - # _amp_batch_id contains a compact 16-char hex string (or multiple separated by |) - batch_id_str = meta_row[0] - assert batch_id_str is not None - assert isinstance(batch_id_str, str) - assert len(batch_id_str) >= 16 # At least one 16-char batch ID - - finally: - loader.pool.putconn(conn) - - def test_handle_reorg_deletion(self, postgresql_test_config, test_table_name, cleanup_tables): - """Test that _handle_reorg correctly deletes invalidated ranges""" - cleanup_tables.append(test_table_name) - - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Create streaming batches with metadata - batch1 = pa.RecordBatch.from_pydict( - { - 'tx_hash': ['0x100', '0x101', '0x102'], - 'block_num': [100, 101, 102], - 'value': [10.0, 11.0, 12.0], - } - ) - batch2 = pa.RecordBatch.from_pydict( - {'tx_hash': ['0x200', '0x201'], 'block_num': [103, 104], 'value': [12.0, 33.0]} - ) - batch3 = pa.RecordBatch.from_pydict( - {'tx_hash': ['0x300', '0x301'], 'block_num': [105, 106], 'value': [7.0, 9.0]} - ) - batch4 = pa.RecordBatch.from_pydict( - {'tx_hash': ['0x400', '0x401'], 'block_num': [107, 108], 'value': [6.0, 73.0]} - ) - - # Create table from first batch schema - loader._create_table_from_schema(batch1.schema, test_table_name) - - # Create response batches with hashes - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=102, hash='0xaaa')], - ranges_complete=True, # Mark as complete 
so it gets tracked in state store - ), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=103, end=104, hash='0xbbb')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=105, end=106, hash='0xccc')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response4 = ResponseBatch.data_batch( - data=batch4, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=107, end=108, hash='0xddd')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - stream = [response1, response2, response3, response4] - results = list(loader.load_stream_continuous(iter(stream), test_table_name)) - assert len(results) == 4 - assert all(r.success for r in results) - - # Verify initial data count - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - initial_count = cur.fetchone()[0] - assert initial_count == 9 # 3 + 2 + 2 + 2 - - # Test reorg deletion - invalidate blocks 104-108 on ethereum - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=104, end=108)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name)) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # Should delete batch2, batch3 and batch4 leaving only the 3 rows from batch1 - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - after_reorg_count = cur.fetchone()[0] - assert after_reorg_count == 3 - - finally: - loader.pool.putconn(conn) - - def test_reorg_with_overlapping_ranges(self, postgresql_test_config, test_table_name, cleanup_tables): - """Test reorg deletion with overlapping block ranges""" - cleanup_tables.append(test_table_name) - - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Load data with overlapping ranges that should be invalidated - batch = pa.RecordBatch.from_pydict( - {'tx_hash': ['0x150', '0x175', '0x250'], 'block_num': [150, 175, 250], 'value': [15.0, 17.5, 25.0]} - ) - - # Create table from batch schema - loader._create_table_from_schema(batch.schema, test_table_name) - - response = ResponseBatch.data_batch( - data=batch, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=150, end=175, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - results = list(loader.load_stream_continuous(iter([response]), test_table_name)) - assert len(results) == 1 - assert results[0].success - - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - # Verify initial data - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - assert cur.fetchone()[0] == 3 - - # Test partial overlap invalidation (160-180) - # This should invalidate our range [150,175] because they overlap - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=160, end=180)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name)) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # All data should be deleted due 
to overlap - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - assert cur.fetchone()[0] == 0 - - finally: - loader.pool.putconn(conn) - - def test_reorg_preserves_different_networks(self, postgresql_test_config, test_table_name, cleanup_tables): - """Test that reorg only affects specified network""" - cleanup_tables.append(test_table_name) - - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - loader = PostgreSQLLoader(postgresql_test_config) - - with loader: - # Load data from multiple networks with same block ranges - batch_eth = pa.RecordBatch.from_pydict( - {'tx_hash': ['0x100_eth'], 'network_id': ['ethereum'], 'block_num': [100], 'value': [10.0]} - ) - batch_poly = pa.RecordBatch.from_pydict( - {'tx_hash': ['0x100_poly'], 'network_id': ['polygon'], 'block_num': [100], 'value': [10.0]} - ) - - # Create table from batch schema - loader._create_table_from_schema(batch_eth.schema, test_table_name) - - response_eth = ResponseBatch.data_batch( - data=batch_eth, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=100, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response_poly = ResponseBatch.data_batch( - data=batch_poly, - metadata=BatchMetadata( - ranges=[BlockRange(network='polygon', start=100, end=100, hash='0xbbb')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load both batches via streaming API - stream = [response_eth, response_poly] - results = list(loader.load_stream_continuous(iter(stream), test_table_name)) - assert len(results) == 2 - assert all(r.success for r in results) - - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - # Verify both networks' data exists - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - assert cur.fetchone()[0] == 2 - - # Invalidate only ethereum network - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=100, end=100)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name)) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # Should only delete ethereum data, polygon should remain - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - assert cur.fetchone()[0] == 1 - - finally: - loader.pool.putconn(conn) - - def test_microbatch_deduplication(self, postgresql_test_config, test_table_name, cleanup_tables): - """ - Test that multiple RecordBatches within the same microbatch are all loaded, - and deduplication only happens at microbatch boundaries when ranges_complete=True. - - This test verifies the fix for the critical bug where we were marking batches - as processed after every RecordBatch instead of waiting for ranges_complete=True. 
- """ - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - cleanup_tables.append(test_table_name) - - # Enable state management to test deduplication - config_with_state = { - **postgresql_test_config, - 'state': {'enabled': True, 'storage': 'memory', 'store_batch_id': True}, - } - loader = PostgreSQLLoader(config_with_state) - - with loader: - # Create table first from the schema - batch1_data = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]}) - loader._create_table_from_schema(batch1_data.schema, test_table_name) - - # Simulate a microbatch sent as 3 RecordBatches with the same BlockRange - # This happens when the server sends large microbatches in smaller chunks - - # First RecordBatch of the microbatch (ranges_complete=False) - response1 = ResponseBatch.data_batch( - data=batch1_data, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], - ranges_complete=False, # Not the last batch in this microbatch - ), - ) - - # Second RecordBatch of the microbatch (ranges_complete=False) - batch2_data = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]}) - response2 = ResponseBatch.data_batch( - data=batch2_data, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], # Same BlockRange! - ranges_complete=False, # Still not the last batch - ), - ) - - # Third RecordBatch of the microbatch (ranges_complete=True) - batch3_data = pa.RecordBatch.from_pydict({'id': [5, 6], 'value': [500, 600]}) - response3 = ResponseBatch.data_batch( - data=batch3_data, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], # Same BlockRange! - ranges_complete=True, # Last batch in this microbatch - safe to mark as processed - ), - ) - - # Process the microbatch stream - stream = [response1, response2, response3] - results = list( - loader.load_stream_continuous(iter(stream), test_table_name, connection_name='test_connection') - ) - - # CRITICAL: All 3 RecordBatches should be loaded successfully - # Before the fix, only the first batch would load (the other 2 would be skipped as "duplicates") - assert len(results) == 3, 'All RecordBatches within microbatch should be processed' - assert all(r.success for r in results), 'All batches should succeed' - assert results[0].rows_loaded == 2, 'First batch should load 2 rows' - assert results[1].rows_loaded == 2, 'Second batch should load 2 rows (not skipped!)' - assert results[2].rows_loaded == 2, 'Third batch should load 2 rows (not skipped!)' - - # Verify total rows in table (all batches loaded) - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - total_count = cur.fetchone()[0] - assert total_count == 6, 'All 6 rows from 3 RecordBatches should be in the table' - - # Verify the actual IDs are present - cur.execute(f'SELECT id FROM {test_table_name} ORDER BY id') - all_ids = [row[0] for row in cur.fetchall()] - assert all_ids == [1, 2, 3, 4, 5, 6], 'All rows from all RecordBatches should be present' - - finally: - loader.pool.putconn(conn) - - # Now test that re-sending the complete microbatch is properly deduplicated - # This time, the first batch has ranges_complete=True (entire microbatch in one RecordBatch) - duplicate_batch = pa.RecordBatch.from_pydict({'id': [7, 8], 'value': [700, 800]}) - duplicate_response = ResponseBatch.data_batch( - data=duplicate_batch, - metadata=BatchMetadata( - 
ranges=[ - BlockRange(network='ethereum', start=100, end=110, hash='0xabc123') - ], # Same range as before! - ranges_complete=True, # Complete microbatch - ), - ) - - # Process duplicate microbatch - duplicate_results = list( - loader.load_stream_continuous( - iter([duplicate_response]), test_table_name, connection_name='test_connection' - ) - ) - - # The duplicate microbatch should be skipped (already processed) - assert len(duplicate_results) == 1 - assert duplicate_results[0].success is True - assert duplicate_results[0].rows_loaded == 0, 'Duplicate microbatch should be skipped' - assert duplicate_results[0].metadata.get('operation') == 'skip_duplicate', 'Should be marked as duplicate' - - # Verify row count unchanged (duplicate was skipped) - conn = loader.pool.getconn() - try: - with conn.cursor() as cur: - cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') - final_count = cur.fetchone()[0] - assert final_count == 6, 'Row count should not increase after duplicate microbatch' - - finally: - loader.pool.putconn(conn) diff --git a/tests/integration/test_redis_loader.py b/tests/integration/test_redis_loader.py deleted file mode 100644 index ce23da7..0000000 --- a/tests/integration/test_redis_loader.py +++ /dev/null @@ -1,981 +0,0 @@ -# tests/integration/test_redis_loader.py -""" -Integration tests for Redis loader implementation. -These tests require a running Redis instance. -""" - -import json -import time -from datetime import datetime, timedelta - -import pyarrow as pa -import pytest - -try: - from src.amp.loaders.base import LoadMode - from src.amp.loaders.implementations.redis_loader import RedisLoader -except ImportError: - pytest.skip('amp modules not available', allow_module_level=True) - - -@pytest.fixture -def small_test_data(): - """Create small test data for quick tests""" - data = { - 'id': [1, 2, 3, 4, 5], - 'name': ['Alice', 'Bob', 'Charlie', 'David', 'Eve'], - 'score': [100, 200, 150, 300, 250], - 'active': [True, False, True, False, True], - 'created_at': [datetime.now() - timedelta(days=i) for i in range(5)], - } - return pa.Table.from_pydict(data) - - -@pytest.fixture -def comprehensive_test_data(): - """Create comprehensive test data with various data types""" - base_date = datetime(2024, 1, 1) - - data = { - 'id': list(range(1000)), - 'user_id': [f'user_{i % 100}' for i in range(1000)], - 'score': [i * 10 for i in range(1000)], - 'text_field': [f'text_{i}' for i in range(1000)], - 'float_field': [i * 0.123 for i in range(1000)], - 'bool_field': [i % 2 == 0 for i in range(1000)], - 'timestamp': [(base_date + timedelta(days=i // 10, hours=i % 24)).timestamp() for i in range(1000)], - 'binary_field': [f'binary_{i}'.encode() for i in range(1000)], - 'json_data': [json.dumps({'index': i, 'value': f'val_{i}'}) for i in range(1000)], - 'nullable_field': [i if i % 10 != 0 else None for i in range(1000)], - } - - return pa.Table.from_pydict(data) - - -@pytest.fixture -def cleanup_redis(redis_test_config): - """Cleanup Redis data after tests""" - keys_to_clean = [] - patterns_to_clean = [] - - yield (keys_to_clean, patterns_to_clean) - - # Cleanup - try: - import redis - - r = redis.Redis( - host=redis_test_config['host'], - port=redis_test_config['port'], - db=redis_test_config['db'], - password=redis_test_config['password'], - ) - - # Delete specific keys - for key in keys_to_clean: - r.delete(key) - - # Delete keys matching patterns - for pattern in patterns_to_clean: - for key in r.scan_iter(match=pattern, count=1000): - r.delete(key) - - r.close() - except Exception: 
- pass - - -@pytest.mark.integration -@pytest.mark.redis -class TestRedisLoaderIntegration: - """Integration tests for Redis loader""" - - def test_loader_connection(self, redis_test_config): - """Test basic connection to Redis""" - loader = RedisLoader(redis_test_config) - - # Test connection - loader.connect() - assert loader._is_connected == True - assert loader.redis_client is not None - assert loader.connection_pool is not None - - # Test health check - health = loader.health_check() - assert health['healthy'] == True - assert 'redis_version' in health - - # Test disconnection - loader.disconnect() - assert loader._is_connected == False - assert loader.redis_client is None - - def test_context_manager(self, redis_test_config, small_test_data, cleanup_redis): - """Test context manager functionality""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_context:*') - - loader = RedisLoader({**redis_test_config, 'data_structure': 'hash'}) - - with loader: - assert loader._is_connected == True - - result = loader.load_table(small_test_data, 'test_context') - assert result.success == True - - # Should be disconnected after context - assert loader._is_connected == False - - def test_hash_storage(self, redis_test_config, small_test_data, cleanup_redis): - """Test hash data structure storage""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_hash:*') - - config = {**redis_test_config, 'data_structure': 'hash', 'key_pattern': 'test_hash:{id}'} - loader = RedisLoader(config) - - with loader: - # Test loading - result = loader.load_table(small_test_data, 'test_hash') - - assert result.success == True - assert result.rows_loaded == 5 - assert result.metadata['data_structure'] == 'hash' - - # Verify data was stored - for i in range(5): - key = f'test_hash:{i + 1}' - assert loader.redis_client.exists(key) - - # Check specific fields - name = loader.redis_client.hget(key, 'name') - assert name.decode() == ['Alice', 'Bob', 'Charlie', 'David', 'Eve'][i] - - score = loader.redis_client.hget(key, 'score') - assert int(score.decode()) == [100, 200, 150, 300, 250][i] - - def test_string_storage(self, redis_test_config, small_test_data, cleanup_redis): - """Test string (JSON) data structure storage""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_string:*') - - config = {**redis_test_config, 'data_structure': 'string', 'key_pattern': 'test_string:{id}'} - loader = RedisLoader(config) - - with loader: - result = loader.load_table(small_test_data, 'test_string') - - assert result.success == True - assert result.rows_loaded == 5 - - # Verify data was stored as JSON - for i in range(5): - key = f'test_string:{i + 1}' - assert loader.redis_client.exists(key) - - # Parse JSON data - json_data = json.loads(loader.redis_client.get(key)) - assert json_data['name'] == ['Alice', 'Bob', 'Charlie', 'David', 'Eve'][i] - assert json_data['score'] == [100, 200, 150, 300, 250][i] - - def test_stream_storage(self, redis_test_config, small_test_data, cleanup_redis): - """Test stream data structure storage""" - keys_to_clean, patterns_to_clean = cleanup_redis - keys_to_clean.append('test_stream:stream') - - config = {**redis_test_config, 'data_structure': 'stream'} - loader = RedisLoader(config) - - with loader: - result = loader.load_table(small_test_data, 'test_stream') - - assert result.success == True - assert result.rows_loaded == 5 - - # Verify stream was created - stream_key = 'test_stream:stream' - assert 
loader.redis_client.exists(stream_key) - - # Check stream length - info = loader.redis_client.xinfo_stream(stream_key) - assert info['length'] == 5 - - def test_set_storage(self, redis_test_config, small_test_data, cleanup_redis): - """Test set data structure storage""" - keys_to_clean, patterns_to_clean = cleanup_redis - keys_to_clean.append('test_set:set') - - config = {**redis_test_config, 'data_structure': 'set', 'unique_field': 'name'} - loader = RedisLoader(config) - - with loader: - result = loader.load_table(small_test_data, 'test_set') - - assert result.success == True - assert result.rows_loaded == 5 - - # Verify set was created - set_key = 'test_set:set' - assert loader.redis_client.exists(set_key) - assert loader.redis_client.scard(set_key) == 5 - - # Check members - members = loader.redis_client.smembers(set_key) - names = {m.decode() for m in members} - assert names == {'Alice', 'Bob', 'Charlie', 'David', 'Eve'} - - def test_sorted_set_storage(self, redis_test_config, small_test_data, cleanup_redis): - """Test sorted set data structure storage""" - keys_to_clean, patterns_to_clean = cleanup_redis - keys_to_clean.append('test_zset:zset') - - config = {**redis_test_config, 'data_structure': 'sorted_set', 'score_field': 'score'} - loader = RedisLoader(config) - - with loader: - result = loader.load_table(small_test_data, 'test_zset') - - assert result.success == True - assert result.rows_loaded == 5 - - # Verify sorted set was created - zset_key = 'test_zset:zset' - assert loader.redis_client.exists(zset_key) - assert loader.redis_client.zcard(zset_key) == 5 - - # Check score ordering - members_with_scores = loader.redis_client.zrange(zset_key, 0, -1, withscores=True) - scores = [score for _, score in members_with_scores] - assert scores == [100.0, 150.0, 200.0, 250.0, 300.0] # Should be sorted - - def test_list_storage(self, redis_test_config, small_test_data, cleanup_redis): - """Test list data structure storage""" - keys_to_clean, patterns_to_clean = cleanup_redis - keys_to_clean.append('test_list:list') - - config = {**redis_test_config, 'data_structure': 'list'} - loader = RedisLoader(config) - - with loader: - result = loader.load_table(small_test_data, 'test_list') - - assert result.success == True - assert result.rows_loaded == 5 - - # Verify list was created - list_key = 'test_list:list' - assert loader.redis_client.exists(list_key) - assert loader.redis_client.llen(list_key) == 5 - - def test_append_mode(self, redis_test_config, small_test_data, cleanup_redis): - """Test append mode functionality""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_append:*') - - config = {**redis_test_config, 'data_structure': 'hash', 'key_pattern': 'test_append:{id}'} - loader = RedisLoader(config) - - with loader: - # Initial load - result = loader.load_table(small_test_data, 'test_append', mode=LoadMode.APPEND) - assert result.success == True - assert result.rows_loaded == 5 - - # Append more data (with different IDs) - # Convert to pydict, modify, and create new table - new_data_dict = small_test_data.to_pydict() - new_data_dict['id'] = [6, 7, 8, 9, 10] - new_table = pa.Table.from_pydict(new_data_dict) - - result = loader.load_table(new_table, 'test_append', mode=LoadMode.APPEND) - assert result.success == True - assert result.rows_loaded == 5 - - # Verify all keys exist - for i in range(1, 11): - key = f'test_append:{i}' - assert loader.redis_client.exists(key) - - def test_overwrite_mode(self, redis_test_config, small_test_data, cleanup_redis): 
- """Test overwrite mode functionality""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_overwrite:*') - - config = {**redis_test_config, 'data_structure': 'hash', 'key_pattern': 'test_overwrite:{id}'} - loader = RedisLoader(config) - - with loader: - # Initial load - result = loader.load_table(small_test_data, 'test_overwrite', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 5 - - # Overwrite with less data - new_data = small_test_data.slice(0, 3) - result = loader.load_table(new_data, 'test_overwrite', mode=LoadMode.OVERWRITE) - assert result.success == True - assert result.rows_loaded == 3 - - # Verify old keys were deleted - assert not loader.redis_client.exists('test_overwrite:4') - assert not loader.redis_client.exists('test_overwrite:5') - - def test_batch_loading(self, redis_test_config, comprehensive_test_data, cleanup_redis): - """Test batch loading functionality""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_batch:*') - - config = {**redis_test_config, 'data_structure': 'hash', 'key_pattern': 'test_batch:{id}', 'batch_size': 250} - loader = RedisLoader(config) - - with loader: - # Test loading individual batches - batches = comprehensive_test_data.to_batches(max_chunksize=250) - total_rows = 0 - - for i, batch in enumerate(batches): - mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND - result = loader.load_batch(batch, 'test_batch', mode=mode) - - assert result.success == True - assert result.rows_loaded == batch.num_rows - total_rows += batch.num_rows - - assert total_rows == 1000 - - def test_ttl_functionality(self, redis_test_config, small_test_data, cleanup_redis): - """Test TTL (time-to-live) functionality""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_ttl:*') - - config = { - **redis_test_config, - 'data_structure': 'hash', - 'key_pattern': 'test_ttl:{id}', - 'ttl': 2, # 2 seconds TTL - } - loader = RedisLoader(config) - - with loader: - result = loader.load_table(small_test_data, 'test_ttl') - assert result.success == True - - # Check keys exist - key = 'test_ttl:1' - assert loader.redis_client.exists(key) - - # Check TTL is set - ttl = loader.redis_client.ttl(key) - assert ttl > 0 and ttl <= 2 - - # Wait for expiration - time.sleep(3) - assert not loader.redis_client.exists(key) - - def test_null_value_handling(self, redis_test_config, null_test_data, cleanup_redis): - """Test comprehensive null value handling across all data types""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_nulls:*') - - config = {**redis_test_config, 'data_structure': 'hash', 'key_pattern': 'test_nulls:{id}'} - loader = RedisLoader(config) - - with loader: - result = loader.load_table(null_test_data, 'test_nulls') - assert result.success == True - assert result.rows_loaded == 10 - - # Verify null values were handled correctly (skipped in hash fields) - for i in range(1, 11): - key = f'test_nulls:{i}' - assert loader.redis_client.exists(key) - - # Check that null fields are not stored (Redis hash skips null values) - fields = loader.redis_client.hkeys(key) - field_names = {f.decode() if isinstance(f, bytes) else f for f in fields} - - # Based on our null test data pattern, verify expected fields - # text_field pattern: ['a', 'b', None, 'd', 'e', None, 'g', 'h', None, 'j'] - # So ids 3, 6, 9 should have None (indices 2, 5, 8) - if i in [3, 6, 9]: # text_field is None for these IDs - assert 
'text_field' not in field_names - else: - assert 'text_field' in field_names, f'Expected text_field for id={i}, but got fields: {field_names}' - - # int_field pattern: [1, None, 3, 4, None, 6, 7, None, 9, 10] - # So ids 2, 5, 8 should have None (indices 1, 4, 7) - if i in [2, 5, 8]: # int_field is None for these IDs - assert 'int_field' not in field_names - else: - assert 'int_field' in field_names - - # float_field pattern: [1.1, 2.2, None, 4.4, 5.5, None, 7.7, 8.8, None, 10.0] - # So ids 3, 6, 9 should have None (indices 2, 5, 8) - if i in [3, 6, 9]: # float_field is None for these IDs - assert 'float_field' not in field_names - else: - assert 'float_field' in field_names - - # Verify non-null values are intact - if 'text_field' in field_names: - text_val = loader.redis_client.hget(key, 'text_field') - expected_chars = ['a', 'b', None, 'd', 'e', None, 'g', 'h', None, 'j'] - expected_char = expected_chars[i - 1] # Convert id to index - assert text_val.decode() == expected_char - - if 'int_field' in field_names: - int_val = loader.redis_client.hget(key, 'int_field') - expected_ints = [1, None, 3, 4, None, 6, 7, None, 9, 10] - expected_int = expected_ints[i - 1] # Convert id to index - assert int(int_val.decode()) == expected_int - - def test_null_value_handling_string_structure(self, redis_test_config, null_test_data, cleanup_redis): - """Test null value handling with string (JSON) data structure""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_json_nulls:*') - - config = {**redis_test_config, 'data_structure': 'string', 'key_pattern': 'test_json_nulls:{id}'} - loader = RedisLoader(config) - - with loader: - result = loader.load_table(null_test_data, 'test_json_nulls') - assert result.success == True - assert result.rows_loaded == 10 - - for i in range(1, 11): - key = f'test_json_nulls:{i}' - assert loader.redis_client.exists(key) - - json_str = loader.redis_client.get(key) - json_data = json.loads(json_str) - - if i in [3, 6, 9]: - assert 'text_field' not in json_data - else: - expected_chars = ['a', 'b', None, 'd', 'e', None, 'g', 'h', None, 'j'] - expected_char = expected_chars[i - 1] - assert json_data['text_field'] == expected_char - - if i in [2, 5, 8]: - assert 'int_field' not in json_data - else: - expected_ints = [1, None, 3, 4, None, 6, 7, None, 9, 10] - expected_int = expected_ints[i - 1] - assert json_data['int_field'] == expected_int - - def test_binary_data_handling(self, redis_test_config, cleanup_redis): - """Test binary data handling""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_binary:*') - - # Create data with binary columns - data = {'id': [1, 2, 3], 'binary_data': [b'hello', b'world', b'\x00\x01\x02\x03'], 'text_data': ['a', 'b', 'c']} - table = pa.Table.from_pydict(data) - - config = {**redis_test_config, 'data_structure': 'hash', 'key_pattern': 'test_binary:{id}'} - loader = RedisLoader(config) - - with loader: - result = loader.load_table(table, 'test_binary') - assert result.success == True - assert result.rows_loaded == 3 - - # Verify binary data was stored correctly - assert loader.redis_client.hget('test_binary:1', 'binary_data') == b'hello' - assert loader.redis_client.hget('test_binary:2', 'binary_data') == b'world' - assert loader.redis_client.hget('test_binary:3', 'binary_data') == b'\x00\x01\x02\x03' - - def test_comprehensive_stats(self, redis_test_config, small_test_data, cleanup_redis): - """Test comprehensive statistics functionality""" - keys_to_clean, patterns_to_clean = 
cleanup_redis - patterns_to_clean.append('test_stats:*') - - config = {**redis_test_config, 'data_structure': 'hash', 'key_pattern': 'test_stats:{id}'} - loader = RedisLoader(config) - - with loader: - result = loader.load_table(small_test_data, 'test_stats') - assert result.success == True - - # Get stats - stats = loader.get_comprehensive_stats('test_stats') - assert 'table_name' in stats - assert stats['data_structure'] == 'hash' - assert stats['key_count'] == 5 - assert 'estimated_memory_bytes' in stats - assert 'estimated_memory_mb' in stats - - def test_error_handling(self, redis_test_config, small_test_data): - """Test error handling scenarios""" - # Test with invalid configuration - invalid_config = {**redis_test_config, 'host': 'invalid-host-that-does-not-exist', 'socket_connect_timeout': 1} - loader = RedisLoader(invalid_config) - - import redis - - with pytest.raises(redis.exceptions.ConnectionError): - loader.connect() - - def test_key_pattern_generation(self, redis_test_config, cleanup_redis): - """Test various key pattern generations""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('complex:*') - - # Create data with multiple fields for complex key - data = {'user_id': ['u1', 'u2', 'u3'], 'session_id': ['s1', 's2', 's3'], 'timestamp': [100, 200, 300]} - table = pa.Table.from_pydict(data) - - config = { - **redis_test_config, - 'data_structure': 'hash', - 'key_pattern': 'complex:{user_id}:{session_id}:{timestamp}', - } - loader = RedisLoader(config) - - with loader: - result = loader.load_table(table, 'complex') - assert result.success == True - - # Verify complex keys were created - assert loader.redis_client.exists('complex:u1:s1:100') - assert loader.redis_client.exists('complex:u2:s2:200') - assert loader.redis_client.exists('complex:u3:s3:300') - - def test_performance_metrics(self, redis_test_config, comprehensive_test_data, cleanup_redis): - """Test performance metrics in results""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_perf:*') - - config = { - **redis_test_config, - 'data_structure': 'hash', - 'key_pattern': 'test_perf:{id}', - 'batch_size': 100, - 'pipeline_size': 500, - } - loader = RedisLoader(config) - - with loader: - start_time = time.time() - result = loader.load_table(comprehensive_test_data, 'test_perf') - end_time = time.time() - - assert result.success == True - assert result.duration > 0 - assert result.duration <= (end_time - start_time) - - # Check performance info - ops_per_second is now a direct attribute - assert hasattr(result, 'ops_per_second') - assert result.ops_per_second > 0 - assert 'batches_processed' in result.metadata - assert 'avg_batch_size' in result.metadata - - -@pytest.mark.integration -@pytest.mark.redis -@pytest.mark.slow -class TestRedisLoaderPerformance: - """Performance tests for Redis loader""" - - def test_large_data_loading(self, redis_test_config, cleanup_redis): - """Test loading large datasets""" - keys_to_clean, patterns_to_clean = cleanup_redis - patterns_to_clean.append('test_large:*') - - # Create large dataset - large_data = { - 'id': list(range(10000)), - 'value': [i * 0.123 for i in range(10000)], - 'category': [f'category_{i % 100}' for i in range(10000)], - 'description': [f'This is a longer text description for row {i}' for i in range(10000)], - 'active': [i % 2 == 0 for i in range(10000)], - } - large_table = pa.Table.from_pydict(large_data) - - config = { - **redis_test_config, - 'data_structure': 'hash', - 'key_pattern': 
'test_large:{id}', - 'batch_size': 1000, - 'pipeline_size': 1000, - } - loader = RedisLoader(config) - - with loader: - result = loader.load_table(large_table, 'test_large') - - assert result.success == True - assert result.rows_loaded == 10000 - assert result.duration < 30 # Should complete within 30 seconds - - # Verify performance metrics - assert result.ops_per_second > 100 # Should handle >100 ops/sec - - def test_data_structure_performance_comparison(self, redis_test_config, cleanup_redis): - """Compare performance across different data structures""" - keys_to_clean, patterns_to_clean = cleanup_redis - - # Create test data - data = {'id': list(range(1000)), 'score': list(range(1000)), 'data': [f'value_{i}' for i in range(1000)]} - table = pa.Table.from_pydict(data) - - structures = ['hash', 'string', 'set', 'sorted_set', 'list'] - results = {} - - for structure in structures: - patterns_to_clean.append(f'perf_{structure}:*') - keys_to_clean.append(f'perf_{structure}:{structure}') - - config = { - **redis_test_config, - 'data_structure': structure, - 'key_pattern': f'perf_{structure}:{{id}}', - 'score_field': 'score' if structure == 'sorted_set' else None, - } - loader = RedisLoader(config) - - with loader: - result = loader.load_table(table, f'perf_{structure}') - results[structure] = result.ops_per_second - - # All structures should perform reasonably well - for structure, ops_per_sec in results.items(): - assert ops_per_sec > 50, f'{structure} performance too low: {ops_per_sec} ops/sec' - - -@pytest.mark.integration -@pytest.mark.redis -class TestRedisLoaderStreaming: - """Integration tests for Redis loader streaming functionality""" - - def test_streaming_metadata_columns(self, redis_test_config, cleanup_redis): - """Test that streaming data stores batch ID metadata""" - keys_to_clean, patterns_to_clean = cleanup_redis - table_name = 'streaming_test' - patterns_to_clean.append(f'{table_name}:*') - patterns_to_clean.append(f'block_index:{table_name}:*') - - # Import streaming types - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - # Create test data with metadata - data = { - 'id': [1, 2, 3], # Required for Redis key generation - 'block_number': [100, 101, 102], - 'transaction_hash': ['0xabc', '0xdef', '0x123'], - 'value': [1.0, 2.0, 3.0], - } - batch = pa.RecordBatch.from_pydict(data) - - # Create metadata with block ranges - block_ranges = [BlockRange(network='ethereum', start=100, end=102, hash='0xabc')] - - config = {**redis_test_config, 'data_structure': 'hash'} - loader = RedisLoader(config) - - with loader: - # Load via streaming API - response = ResponseBatch.data_batch(data=batch, metadata=BatchMetadata(ranges=block_ranges)) - results = list(loader.load_stream_continuous(iter([response]), table_name)) - assert len(results) == 1 - assert results[0].success == True - assert results[0].rows_loaded == 3 - - # Verify data was stored - primary_keys = [f'{table_name}:1', f'{table_name}:2', f'{table_name}:3'] - for key in primary_keys: - assert loader.redis_client.exists(key) - # Check that batch_id metadata was stored - batch_id_field = loader.redis_client.hget(key, '_amp_batch_id') - assert batch_id_field is not None - batch_id_str = batch_id_field.decode('utf-8') - assert isinstance(batch_id_str, str) - assert len(batch_id_str) >= 16 # At least one 16-char batch ID - - def test_handle_reorg_deletion(self, redis_test_config, cleanup_redis): - """Test that _handle_reorg correctly deletes invalidated ranges""" - keys_to_clean, patterns_to_clean = 
cleanup_redis - table_name = 'reorg_test' - patterns_to_clean.append(f'{table_name}:*') - patterns_to_clean.append(f'block_index:{table_name}:*') - - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - config = {**redis_test_config, 'data_structure': 'hash'} - loader = RedisLoader(config) - - with loader: - # Create streaming batches with metadata - batch1 = pa.RecordBatch.from_pydict( - { - 'id': [1, 2, 3], # Required for Redis key generation - 'tx_hash': ['0x100', '0x101', '0x102'], - 'block_num': [100, 101, 102], - 'value': [10.0, 11.0, 12.0], - } - ) - batch2 = pa.RecordBatch.from_pydict( - {'id': [4, 5], 'tx_hash': ['0x200', '0x201'], 'block_num': [103, 104], 'value': [13.0, 14.0]} - ) - batch3 = pa.RecordBatch.from_pydict( - {'id': [6, 7], 'tx_hash': ['0x300', '0x301'], 'block_num': [105, 106], 'value': [15.0, 16.0]} - ) - - # Create response batches with hashes - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=102, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=103, end=104, hash='0xbbb')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=105, end=106, hash='0xccc')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - stream = [response1, response2, response3] - results = list(loader.load_stream_continuous(iter(stream), table_name)) - assert len(results) == 3 - assert all(r.success for r in results) - - # Verify initial data - initial_keys = [] - pattern = f'{table_name}:*' - for key in loader.redis_client.scan_iter(match=pattern): - if not key.decode('utf-8').startswith('block_index'): - initial_keys.append(key) - assert len(initial_keys) == 7 # 3 + 2 + 2 - - # Test reorg deletion - invalidate blocks 104-108 on ethereum - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=104, end=108)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), table_name)) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # Should delete batch2 and batch3, leaving only batch1 (3 keys) - remaining_keys = [] - for key in loader.redis_client.scan_iter(match=pattern): - if not key.decode('utf-8').startswith('block_index'): - remaining_keys.append(key) - assert len(remaining_keys) == 3 - - def test_reorg_with_overlapping_ranges(self, redis_test_config, cleanup_redis): - """Test reorg deletion with overlapping block ranges""" - keys_to_clean, patterns_to_clean = cleanup_redis - table_name = 'overlap_test' - patterns_to_clean.append(f'{table_name}:*') - patterns_to_clean.append(f'block_index:{table_name}:*') - - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - config = {**redis_test_config, 'data_structure': 'hash'} - loader = RedisLoader(config) - - with loader: - # Load data with overlapping ranges that should be invalidated - batch = pa.RecordBatch.from_pydict( - { - 'id': [1, 2, 3], - 'tx_hash': ['0x150', '0x175', '0x250'], - 'block_num': [150, 175, 250], - 'value': [15.0, 17.5, 25.0], - } - ) - - response = ResponseBatch.data_batch( - data=batch, - 
metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=150, end=175, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load via streaming API - results = list(loader.load_stream_continuous(iter([response]), table_name)) - assert len(results) == 1 - assert results[0].success - - # Verify initial data - pattern = f'{table_name}:*' - initial_keys = [] - for key in loader.redis_client.scan_iter(match=pattern): - if not key.decode('utf-8').startswith('block_index'): - initial_keys.append(key) - assert len(initial_keys) == 3 - - # Test partial overlap invalidation (160-180) - # This should invalidate our range [150,175] because they overlap - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=160, end=180)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), table_name)) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # All data should be deleted due to overlap - remaining_keys = [] - for key in loader.redis_client.scan_iter(match=pattern): - if not key.decode('utf-8').startswith('block_index'): - remaining_keys.append(key) - assert len(remaining_keys) == 0 - - def test_reorg_preserves_different_networks(self, redis_test_config, cleanup_redis): - """Test that reorg only affects specified network""" - keys_to_clean, patterns_to_clean = cleanup_redis - table_name = 'multinetwork_test' - patterns_to_clean.append(f'{table_name}:*') - patterns_to_clean.append(f'block_index:{table_name}:*') - - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch - - config = {**redis_test_config, 'data_structure': 'hash'} - loader = RedisLoader(config) - - with loader: - # Load data from multiple networks with same block ranges - batch_eth = pa.RecordBatch.from_pydict( - { - 'id': [1], - 'tx_hash': ['0x100_eth'], - 'network_id': ['ethereum'], - 'block_num': [100], - 'value': [10.0], - } - ) - batch_poly = pa.RecordBatch.from_pydict( - { - 'id': [2], - 'tx_hash': ['0x100_poly'], - 'network_id': ['polygon'], - 'block_num': [100], - 'value': [10.0], - } - ) - - response_eth = ResponseBatch.data_batch( - data=batch_eth, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=100, end=100, hash='0xaaa')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - response_poly = ResponseBatch.data_batch( - data=batch_poly, - metadata=BatchMetadata( - ranges=[BlockRange(network='polygon', start=100, end=100, hash='0xbbb')], - ranges_complete=True, # Mark as complete so it gets tracked in state store - ), - ) - - # Load both batches via streaming API - stream = [response_eth, response_poly] - results = list(loader.load_stream_continuous(iter(stream), table_name)) - assert len(results) == 2 - assert all(r.success for r in results) - - # Verify both networks' data exists - pattern = f'{table_name}:*' - initial_keys = [] - for key in loader.redis_client.scan_iter(match=pattern): - if not key.decode('utf-8').startswith('block_index'): - initial_keys.append(key) - assert len(initial_keys) == 2 - - # Invalidate only ethereum network - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=100, end=100)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), table_name)) - assert len(reorg_results) == 1 - assert reorg_results[0].success - - # Should only delete ethereum data, polygon should remain - 
            remaining_keys = []
-            for key in loader.redis_client.scan_iter(match=pattern):
-                if not key.decode('utf-8').startswith('block_index'):
-                    remaining_keys.append(key)
-            assert len(remaining_keys) == 1
-
-            # Verify remaining data is from polygon (just check batch_id exists)
-            remaining_key = remaining_keys[0]
-            batch_id_field = loader.redis_client.hget(remaining_key, '_amp_batch_id')
-            assert batch_id_field is not None
-            # Batch ID is a compact string, not network-specific, so we just verify it exists
-
-    def test_streaming_with_string_data_structure(self, redis_test_config, cleanup_redis):
-        """Test streaming support with string data structure"""
-        keys_to_clean, patterns_to_clean = cleanup_redis
-        table_name = 'string_streaming_test'
-        patterns_to_clean.append(f'{table_name}:*')
-        patterns_to_clean.append(f'block_index:{table_name}:*')
-
-        from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch
-
-        config = {**redis_test_config, 'data_structure': 'string'}
-        loader = RedisLoader(config)
-
-        with loader:
-            # Create test data
-            data = {
-                'id': [1, 2, 3],
-                'transaction_hash': ['0xaaa', '0xbbb', '0xccc'],
-                'value': [100.0, 200.0, 300.0],
-            }
-            batch = pa.RecordBatch.from_pydict(data)
-            block_ranges = [BlockRange(network='polygon', start=200, end=202, hash='0xabc')]
-
-            # Load via streaming API
-            response = ResponseBatch.data_batch(
-                data=batch,
-                metadata=BatchMetadata(
-                    ranges=block_ranges,
-                    ranges_complete=True,  # Mark as complete so it gets tracked in state store
-                ),
-            )
-            results = list(loader.load_stream_continuous(iter([response]), table_name))
-            assert len(results) == 1
-            assert results[0].success == True
-            assert results[0].rows_loaded == 3
-
-            # Verify data was stored as JSON strings
-            for _i, id_val in enumerate([1, 2, 3]):
-                key = f'{table_name}:{id_val}'
-                assert loader.redis_client.exists(key)
-
-                # Get and parse JSON data
-                json_data = loader.redis_client.get(key)
-                parsed_data = json.loads(json_data.decode('utf-8'))
-                assert '_amp_batch_id' in parsed_data
-                batch_id_str = parsed_data['_amp_batch_id']
-                assert isinstance(batch_id_str, str)
-                assert len(batch_id_str) >= 16  # At least one 16-char batch ID
-
-            # Verify reorg handling works with string data structure
-            reorg_response = ResponseBatch.reorg_batch(
-                invalidation_ranges=[BlockRange(network='polygon', start=201, end=205)]
-            )
-            reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), table_name))
-            assert len(reorg_results) == 1
-            assert reorg_results[0].success
-
-            # All data should be deleted since ranges overlap
-            pattern = f'{table_name}:*'
-            remaining_keys = []
-            for key in loader.redis_client.scan_iter(match=pattern):
-                if not key.decode('utf-8').startswith('block_index'):
-                    remaining_keys.append(key)
-            assert len(remaining_keys) == 0
diff --git a/tests/integration/test_snowflake_loader.py b/tests/integration/test_snowflake_loader.py
deleted file mode 100644
index 50f96c7..0000000
--- a/tests/integration/test_snowflake_loader.py
+++ /dev/null
@@ -1,1215 +0,0 @@
-# tests/integration/test_snowflake_loader.py
-"""
-Integration tests for Snowflake loader implementation.
-These tests require a running Snowflake instance with proper credentials.
-
-NOTE: Snowflake integration tests are currently disabled because they require an active snowflake account
-To re-enable these tests:
-1. Set up a valid Snowflake account with billing enabled
-2.
Configure the following environment variables: - - SNOWFLAKE_ACCOUNT - - SNOWFLAKE_USER - - SNOWFLAKE_PASSWORD - - SNOWFLAKE_WAREHOUSE - - SNOWFLAKE_DATABASE - - SNOWFLAKE_SCHEMA (optional, defaults to PUBLIC) -3. Remove the skip decorator below -""" - -import time -from datetime import datetime - -import pyarrow as pa -import pytest - -try: - from src.amp.loaders.base import LoadMode - from src.amp.loaders.implementations.snowflake_loader import SnowflakeLoader - from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch -except ImportError: - pytest.skip('amp modules not available', allow_module_level=True) - - -def wait_for_snowpipe_data(loader, table_name, expected_count, max_wait=30, poll_interval=2): - """ - Wait for Snowpipe streaming data to become queryable. - - Snowpipe streaming has eventual consistency, so data may not be immediately - queryable after insertion. This helper polls until the expected row count is visible. - - Args: - loader: SnowflakeLoader instance with active connection - table_name: Name of the table to query - expected_count: Expected number of rows - max_wait: Maximum seconds to wait (default 30) - poll_interval: Seconds between poll attempts (default 2) - - Returns: - int: Actual row count found - - Raises: - AssertionError: If expected count not reached within max_wait seconds - """ - elapsed = 0 - while elapsed < max_wait: - loader.cursor.execute(f'SELECT COUNT(*) FROM {table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - if count == expected_count: - return count - time.sleep(poll_interval) - elapsed += poll_interval - - # Final check before giving up - loader.cursor.execute(f'SELECT COUNT(*) FROM {table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == expected_count, f'Expected {expected_count} rows after {max_wait}s, but found {count}' - return count - - -# Skip all Snowflake tests -pytestmark = pytest.mark.skip(reason='Requires active Snowflake account - see module docstring for details') - - -@pytest.fixture -def test_table_name(): - """Generate unique table name for each test""" - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f') - return f'test_table_{timestamp}' - - -@pytest.fixture -def cleanup_tables(snowflake_config): - """Cleanup test tables after tests""" - tables_to_clean = [] - - yield tables_to_clean - - loader = SnowflakeLoader(snowflake_config) - try: - loader.connect() - for table in tables_to_clean: - try: - loader.cursor.execute(f'DROP TABLE IF EXISTS {table}') - loader.connection.commit() - except Exception: - pass - except Exception: - pass - finally: - if loader._is_connected: - loader.disconnect() - - -@pytest.mark.integration -@pytest.mark.snowflake -class TestSnowflakeLoaderIntegration: - """Integration tests for Snowflake loader""" - - def test_loader_connection(self, snowflake_config): - """Test basic connection to Snowflake""" - loader = SnowflakeLoader(snowflake_config) - - loader.connect() - assert loader._is_connected is True - assert loader.connection is not None - assert loader.cursor is not None - - loader.disconnect() - assert loader._is_connected is False - assert loader.connection is None - assert loader.cursor is None - - def test_basic_table_loading_via_stage(self, snowflake_config, small_test_table, test_table_name, cleanup_tables): - """Test basic table loading using stage""" - cleanup_tables.append(test_table_name) - - config = {**snowflake_config, 'loading_method': 'stage'} - loader = SnowflakeLoader(config) - - with loader: - result = 
loader.load_table(small_test_table, test_table_name, create_table=True) - - assert result.success is True - assert result.rows_loaded == small_test_table.num_rows - assert result.table_name == test_table_name - assert result.loader_type == 'snowflake' - assert result.metadata['loading_method'] == 'stage' - - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == small_test_table.num_rows - - def test_basic_table_loading_via_insert(self, snowflake_config, small_test_table, test_table_name, cleanup_tables): - """Test basic table loading using INSERT (Note: currently defaults to stage for performance)""" - cleanup_tables.append(test_table_name) - - # Use insert loading (Note: implementation may default to stage for small tables) - config = {**snowflake_config, 'loading_method': 'insert'} - loader = SnowflakeLoader(config) - - with loader: - result = loader.load_table(small_test_table, test_table_name, create_table=True) - - assert result.success is True - assert result.rows_loaded == small_test_table.num_rows - # Note: Implementation uses stage by default for performance - assert result.metadata['loading_method'] in ['insert', 'stage'] - - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == small_test_table.num_rows - - def test_batch_loading(self, snowflake_config, medium_test_table, test_table_name, cleanup_tables): - """Test loading data in batches""" - cleanup_tables.append(test_table_name) - - loader = SnowflakeLoader(snowflake_config) - - with loader: - # Use smaller batch size to force multiple batches (medium_test_table has 10000 rows) - result = loader.load_table(medium_test_table, test_table_name, create_table=True, batch_size=5000) - - assert result.success is True - assert result.rows_loaded == medium_test_table.num_rows - # Implementation may optimize batching, so just check >= 1 - assert result.metadata.get('batches_processed', 1) >= 1 - - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == medium_test_table.num_rows - - def test_overwrite_mode(self, snowflake_config, small_test_table, test_table_name, cleanup_tables): - """Test OVERWRITE mode is not supported""" - cleanup_tables.append(test_table_name) - - loader = SnowflakeLoader(snowflake_config) - - with loader: - result1 = loader.load_table(small_test_table, test_table_name, create_table=True) - assert result1.success is True - - # OVERWRITE mode should fail with error message - result2 = loader.load_table(small_test_table, test_table_name, mode=LoadMode.OVERWRITE) - assert result2.success is False - assert 'Unsupported mode LoadMode.OVERWRITE' in result2.error - - def test_append_mode(self, snowflake_config, small_test_table, test_table_name, cleanup_tables): - """Test APPEND mode adds to existing data""" - cleanup_tables.append(test_table_name) - - loader = SnowflakeLoader(snowflake_config) - - with loader: - result1 = loader.load_table(small_test_table, test_table_name, create_table=True) - assert result1.success is True - - result2 = loader.load_table(small_test_table, test_table_name, mode=LoadMode.APPEND) - assert result2.success is True - - # Should have double the rows - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == small_test_table.num_rows * 2 - - def test_comprehensive_data_types(self, snowflake_config, 
comprehensive_test_data, test_table_name, cleanup_tables): - """Test various data types from comprehensive test data""" - cleanup_tables.append(test_table_name) - - loader = SnowflakeLoader(snowflake_config) - - with loader: - result = loader.load_table(comprehensive_test_data, test_table_name, create_table=True) - assert result.success is True - - loader.cursor.execute(f""" - SELECT - "id", - "user_id", - "transaction_amount", - "category", - "timestamp", - "is_weekend", - "score", - "active" - FROM {test_table_name} - WHERE "id" = 0 - """) - - row = loader.cursor.fetchone() - assert row['id'] == 0 - assert row['user_id'] == 'user_0' - assert abs(row['transaction_amount'] - 0.0) < 0.001 - assert row['category'] == 'electronics' - assert row['is_weekend'] is True # id=0 is weekend - assert abs(row['score'] - 0.0) < 0.001 - assert row['active'] is True - - def test_null_handling(self, snowflake_config, null_test_data, test_table_name, cleanup_tables): - """Test proper handling of NULL values""" - cleanup_tables.append(test_table_name) - - loader = SnowflakeLoader(snowflake_config) - - with loader: - result = loader.load_table(null_test_data, test_table_name, create_table=True) - assert result.success is True - - loader.cursor.execute(f""" - SELECT COUNT(*) as null_count - FROM {test_table_name} - WHERE "text_field" IS NULL - """) - - null_count = loader.cursor.fetchone()['NULL_COUNT'] - expected_nulls = sum(1 for val in null_test_data.column('text_field').to_pylist() if val is None) - assert null_count == expected_nulls - - loader.cursor.execute(f""" - SELECT - COUNT(CASE WHEN "int_field" IS NULL THEN 1 END) as int_nulls, - COUNT(CASE WHEN "float_field" IS NULL THEN 1 END) as float_nulls, - COUNT(CASE WHEN "bool_field" IS NULL THEN 1 END) as bool_nulls - FROM {test_table_name} - """) - - null_counts = loader.cursor.fetchone() - assert null_counts['INT_NULLS'] > 0 - assert null_counts['FLOAT_NULLS'] > 0 - assert null_counts['BOOL_NULLS'] > 0 - - def test_table_info(self, snowflake_config, small_test_table, test_table_name, cleanup_tables): - """Test getting table information""" - cleanup_tables.append(test_table_name) - - loader = SnowflakeLoader(snowflake_config) - - with loader: - result = loader.load_table(small_test_table, test_table_name, create_table=True) - assert result.success is True - - info = loader.get_table_info(test_table_name) - - assert info is not None - assert info['table_name'] == test_table_name.upper() - assert info['schema'] == snowflake_config.get('schema', 'PUBLIC') - # Table should have original columns + _amp_batch_id metadata column - assert len(info['columns']) == len(small_test_table.schema) + 1 - - # Verify _amp_batch_id column exists - batch_id_col = next((col for col in info['columns'] if col['name'].lower() == '_amp_batch_id'), None) - assert batch_id_col is not None, 'Expected _amp_batch_id metadata column' - - # In Snowflake, quoted column names are case-sensitive but INFORMATION_SCHEMA may return them differently - # Let's find the ID column by looking for either case variant - id_col = None - for col in info['columns']: - if col['name'].upper() == 'ID' or col['name'] == 'id': - id_col = col - break - - assert id_col is not None, f'Could not find ID column in {[col["name"] for col in info["columns"]]}' - assert 'INT' in id_col['type'] or 'NUMBER' in id_col['type'] - - @pytest.mark.slow - def test_performance_batch_loading(self, snowflake_config, performance_test_data, test_table_name, cleanup_tables): - """Test performance with larger dataset""" - 
cleanup_tables.append(test_table_name) - - config = {**snowflake_config, 'loading_method': 'stage'} - loader = SnowflakeLoader(config) - - with loader: - start_time = time.time() - - result = loader.load_table(performance_test_data, test_table_name, create_table=True) - - duration = time.time() - start_time - - assert result.success is True - assert result.rows_loaded == performance_test_data.num_rows - - rows_per_second = result.rows_loaded / duration - mb_per_second = (performance_test_data.nbytes / 1024 / 1024) / duration - - print('\nPerformance metrics:') - print(f' Total rows: {result.rows_loaded:,}') - print(f' Duration: {duration:.2f}s') - print(f' Throughput: {rows_per_second:,.0f} rows/sec') - print(f' Data rate: {mb_per_second:.2f} MB/sec') - print(f' Batches: {result.metadata.get("batches_processed", "N/A")}') - - def test_error_handling_invalid_table(self, snowflake_config, small_test_table): - """Test error handling for invalid table operations""" - loader = SnowflakeLoader(snowflake_config) - - with loader: - # Try to load without creating table - result = loader.load_table(small_test_table, 'non_existent_table_xyz', create_table=False) - - assert result.success is False - assert result.error is not None - - @pytest.mark.slow - def test_concurrent_batch_loading(self, snowflake_config, medium_test_table, test_table_name, cleanup_tables): - """Test loading multiple batches concurrently""" - cleanup_tables.append(test_table_name) - - loader = SnowflakeLoader(snowflake_config) - - with loader: - # Create table first - batch = medium_test_table.to_batches(max_chunksize=1)[0] - loader.load_batch(batch, test_table_name, create_table=True) - - # Load multiple batches - total_rows = 0 - for _i, batch in enumerate(medium_test_table.to_batches(max_chunksize=500)): - result = loader.load_batch(batch, test_table_name) - assert result.success is True - total_rows += result.rows_loaded - - assert total_rows == medium_test_table.num_rows - - # Verify all data loaded - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == medium_test_table.num_rows + 1 # +1 for initial batch - - # Removed test_stage_and_compression_options - compression parameter not supported in current config - - def test_schema_with_special_characters(self, snowflake_config, test_table_name, cleanup_tables): - """Test handling of column names with special characters""" - cleanup_tables.append(test_table_name) - - data = { - 'user-id': [1, 2, 3], - 'first name': ['Alice', 'Bob', 'Charlie'], - 'total$amount': [100.0, 200.0, 300.0], - '2024_data': ['a', 'b', 'c'], - } - special_table = pa.Table.from_pydict(data) - - loader = SnowflakeLoader(snowflake_config) - - with loader: - result = loader.load_table(special_table, test_table_name, create_table=True) - assert result.success is True - - # Verify row count - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == 3 - - # Verify we can query columns with special characters - # Note: Snowflake typically converts column names to uppercase and may need quoting - loader.cursor.execute(f""" - SELECT - "user-id", - "first name", - "total$amount", - "2024_data" - FROM {test_table_name} - WHERE "user-id" = 1 - """) - - row = loader.cursor.fetchone() - - assert row['user-id'] == 1 - assert row['first name'] == 'Alice' - assert abs(row['total$amount'] - 100.0) < 0.001 - assert row['2024_data'] == 'a' - - def 
test_handle_reorg_no_metadata_column(self, snowflake_config, test_table_name, cleanup_tables): - """Test reorg handling when table lacks metadata column""" - from src.amp.streaming.types import BlockRange - - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_config) - - with loader: - # Create table without metadata column - data = pa.table({'id': [1, 2, 3], 'block_num': [100, 150, 200], 'value': [10.0, 20.0, 30.0]}) - loader.load_table(data, test_table_name, create_table=True) - - # Call handle reorg - invalidation_ranges = [BlockRange(network='ethereum', start=150, end=250)] - - # Should log warning and not modify data - loader._handle_reorg(invalidation_ranges, test_table_name, 'test_connection') - - # Verify data unchanged - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == 3 - - def test_handle_reorg_single_network(self, snowflake_config, test_table_name, cleanup_tables): - """Test reorg handling for single network data""" - - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_config) - - with loader: - # Create batches with proper metadata - batch1 = pa.RecordBatch.from_pydict({'id': [1], 'block_num': [105]}) - batch2 = pa.RecordBatch.from_pydict({'id': [2], 'block_num': [155]}) - batch3 = pa.RecordBatch.from_pydict({'id': [3], 'block_num': [205]}) - - # Create streaming responses with block ranges - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc')]), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef')]), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=200, end=210, hash='0xghi')]), - ) - - # Load data via streaming API - stream = [response1, response2, response3] - results = list(loader.load_stream_continuous(iter(stream), test_table_name)) - - # Verify all data loaded successfully - assert len(results) == 3 - assert all(r.success for r in results) - - # Verify all data exists - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == 3 - - # Trigger reorg from block 155 - should delete rows 2 and 3 - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=155, end=300)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name)) - - # Verify reorg processed - assert len(reorg_results) == 1 - assert reorg_results[0].is_reorg - - # Verify only first row remains - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == 1 - - loader.cursor.execute(f'SELECT "id" FROM {test_table_name}') - remaining_id = loader.cursor.fetchone()['id'] - assert remaining_id == 1 - - def test_handle_reorg_multi_network(self, snowflake_config, test_table_name, cleanup_tables): - """Test reorg handling preserves data from unaffected networks""" - - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_config) - - with loader: - # Create batches from multiple networks - batch1 = pa.RecordBatch.from_pydict({'id': [1], 'network': ['ethereum']}) - batch2 = pa.RecordBatch.from_pydict({'id': [2], 'network': ['polygon']}) - batch3 = 
pa.RecordBatch.from_pydict({'id': [3], 'network': ['ethereum']}) - batch4 = pa.RecordBatch.from_pydict({'id': [4], 'network': ['polygon']}) - - # Create streaming responses with block ranges - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xa')]), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata(ranges=[BlockRange(network='polygon', start=100, end=110, hash='0xb')]), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xc')]), - ) - response4 = ResponseBatch.data_batch( - data=batch4, - metadata=BatchMetadata(ranges=[BlockRange(network='polygon', start=150, end=160, hash='0xd')]), - ) - - # Load data via streaming API - stream = [response1, response2, response3, response4] - results = list(loader.load_stream_continuous(iter(stream), test_table_name)) - - # Verify all data loaded successfully - assert len(results) == 4 - assert all(r.success for r in results) - - # Trigger reorg for ethereum only from block 150 - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name)) - - # Verify reorg processed - assert len(reorg_results) == 1 - assert reorg_results[0].is_reorg - - # Verify ethereum row 3 deleted, but polygon rows preserved - loader.cursor.execute(f'SELECT "id" FROM {test_table_name} ORDER BY "id"') - remaining_ids = [row['id'] for row in loader.cursor.fetchall()] - assert remaining_ids == [1, 2, 4] # Row 3 deleted - - def test_handle_reorg_overlapping_ranges(self, snowflake_config, test_table_name, cleanup_tables): - """Test reorg with overlapping block ranges""" - - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_config) - - with loader: - # Create batches with overlapping ranges - batch1 = pa.RecordBatch.from_pydict({'id': [1]}) - batch2 = pa.RecordBatch.from_pydict({'id': [2]}) - batch3 = pa.RecordBatch.from_pydict({'id': [3]}) - - # Create streaming responses with block ranges - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=90, end=110, hash='0xa')] - ), # Before reorg - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=140, end=160, hash='0xb')] - ), # Overlaps - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata( - ranges=[BlockRange(network='ethereum', start=170, end=190, hash='0xc')] - ), # Overlaps - ) - - # Load data via streaming API - stream = [response1, response2, response3] - results = list(loader.load_stream_continuous(iter(stream), test_table_name)) - - # Verify all data loaded successfully - assert len(results) == 3 - assert all(r.success for r in results) - - # Trigger reorg from block 150 - should delete rows where end >= 150 - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name)) - - # Verify reorg processed - assert len(reorg_results) == 1 - assert reorg_results[0].is_reorg - - # Only first row should remain (ends at 110 < 150) - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = 
loader.cursor.fetchone()['COUNT(*)'] - assert count == 1 - - loader.cursor.execute(f'SELECT "id" FROM {test_table_name}') - remaining_id = loader.cursor.fetchone()['id'] - assert remaining_id == 1 - - def test_handle_reorg_with_history_preservation(self, snowflake_config, test_table_name, cleanup_tables): - """Test reorg history preservation mode - rows are updated instead of deleted""" - - cleanup_tables.append(test_table_name) - cleanup_tables.append(f'{test_table_name}_current') - cleanup_tables.append(f'{test_table_name}_history') - - # Enable history preservation - config_with_history = {**snowflake_config, 'preserve_reorg_history': True} - loader = SnowflakeLoader(config_with_history) - - with loader: - # Create batches with proper metadata - batch1 = pa.RecordBatch.from_pydict({'id': [1], 'block_num': [105]}) - batch2 = pa.RecordBatch.from_pydict({'id': [2], 'block_num': [155]}) - batch3 = pa.RecordBatch.from_pydict({'id': [3], 'block_num': [205]}) - - # Create streaming responses with block ranges - response1 = ResponseBatch.data_batch( - data=batch1, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc')]), - ) - response2 = ResponseBatch.data_batch( - data=batch2, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef')]), - ) - response3 = ResponseBatch.data_batch( - data=batch3, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=200, end=210, hash='0xghi')]), - ) - - # Load data via streaming API - stream = [response1, response2, response3] - results = list(loader.load_stream_continuous(iter(stream), test_table_name)) - - # Verify all data loaded successfully - assert len(results) == 3 - assert all(r.success for r in results) - - # Verify temporal columns exist and are set correctly - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE "_amp_is_current" = TRUE') - current_count = loader.cursor.fetchone()['COUNT(*)'] - assert current_count == 3 - - # Verify reorg columns exist - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE "_amp_reorg_batch_id" IS NULL') - not_reorged_count = loader.cursor.fetchone()['COUNT(*)'] - assert not_reorged_count == 3 # All current rows should have NULL reorg_batch_id - - # Verify views exist - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}_current') - view_count = loader.cursor.fetchone()['COUNT(*)'] - assert view_count == 3 - - # Trigger reorg from block 155 - should UPDATE rows 2 and 3, not delete them - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=155, end=300)] - ) - reorg_results = list(loader.load_stream_continuous(iter([reorg_response]), test_table_name)) - - # Verify reorg processed - assert len(reorg_results) == 1 - assert reorg_results[0].is_reorg - - # Verify ALL 3 rows still exist in base table - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - total_count = loader.cursor.fetchone()['COUNT(*)'] - assert total_count == 3 - - # Verify only first row is current - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE "_amp_is_current" = TRUE') - current_count = loader.cursor.fetchone()['COUNT(*)'] - assert current_count == 1 - - # Verify _current view shows only active row - loader.cursor.execute(f'SELECT "id" FROM {test_table_name}_current') - current_ids = [row['id'] for row in loader.cursor.fetchall()] - assert current_ids == [1] - - # Verify _history view shows all rows - 
loader.cursor.execute(f'SELECT "id" FROM {test_table_name}_history ORDER BY "id"') - history_ids = [row['id'] for row in loader.cursor.fetchall()] - assert history_ids == [1, 2, 3] - - # Verify reorged rows have simplified reorg columns set correctly - loader.cursor.execute( - f'''SELECT "id", "_amp_is_current", "_amp_batch_id", "_amp_reorg_batch_id" - FROM {test_table_name} - WHERE "_amp_is_current" = FALSE - ORDER BY "id"''' - ) - reorged_rows = loader.cursor.fetchall() - assert len(reorged_rows) == 2 - assert reorged_rows[0]['id'] == 2 - assert reorged_rows[1]['id'] == 3 - # Verify reorg_batch_id is set (identifies which reorg event superseded these rows) - assert reorged_rows[0]['_amp_reorg_batch_id'] is not None - assert reorged_rows[1]['_amp_reorg_batch_id'] is not None - # Both rows superseded by same reorg event - assert reorged_rows[0]['_amp_reorg_batch_id'] == reorged_rows[1]['_amp_reorg_batch_id'] - - def test_parallel_streaming_with_stage(self, snowflake_config, test_table_name, cleanup_tables): - """Test parallel streaming using stage loading method""" - import threading - - cleanup_tables.append(test_table_name) - config = {**snowflake_config, 'loading_method': 'stage'} - loader = SnowflakeLoader(config) - - with loader: - # Create table first - initial_batch = pa.RecordBatch.from_pydict({'id': [1], 'partition': ['partition_0'], 'value': [100]}) - loader.load_batch(initial_batch, test_table_name, create_table=True) - - # Thread lock for serializing access to shared Snowflake connection - # (Snowflake connector is not thread-safe) - load_lock = threading.Lock() - - # Load multiple batches in parallel from different "streams" - def load_partition_data(partition_id: int, start_id: int): - """Simulate a stream partition loading data""" - for batch_num in range(3): - batch_start = start_id + (batch_num * 10) - batch = pa.RecordBatch.from_pydict( - { - 'id': list(range(batch_start, batch_start + 10)), - 'partition': [f'partition_{partition_id}'] * 10, - 'value': list(range(batch_start * 100, (batch_start + 10) * 100, 100)), - } - ) - # Use lock to ensure thread-safe access to shared connection - with load_lock: - result = loader.load_batch(batch, test_table_name, create_table=False) - assert result.success, f'Partition {partition_id} batch {batch_num} failed: {result.error}' - - # Launch 3 parallel "streams" (threads simulating parallel streaming) - threads = [] - for partition_id in range(3): - start_id = 100 + (partition_id * 100) - thread = threading.Thread(target=load_partition_data, args=(partition_id, start_id)) - threads.append(thread) - thread.start() - - # Wait for all streams to complete - for thread in threads: - thread.join() - - # Verify all data loaded correctly - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - # 1 initial + (3 partitions * 3 batches * 10 rows) = 91 rows - assert count == 91 - - # Verify each partition loaded correctly - for partition_id in range(3): - loader.cursor.execute( - f'SELECT COUNT(*) FROM {test_table_name} WHERE "partition" = \'partition_{partition_id}\'' - ) - partition_count = loader.cursor.fetchone()['COUNT(*)'] - # partition_0 has 31 rows (1 initial + 30 from thread), others have 30 - expected_count = 31 if partition_id == 0 else 30 - assert partition_count == expected_count - - def test_streaming_with_reorg(self, snowflake_config, test_table_name, cleanup_tables): - """Test streaming data with reorg support""" - from src.amp.streaming.types import BatchMetadata, 
BlockRange, ResponseBatch - - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_config) - - with loader: - # Create streaming data with metadata - data1 = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]}) - - data2 = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]}) - - # Create response batches using factory methods (with hashes for proper state management) - response1 = ResponseBatch.data_batch( - data=data1, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')]), - ) - - response2 = ResponseBatch.data_batch( - data=data2, - metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef456')]), - ) - - # Simulate reorg event using factory method - reorg_response = ResponseBatch.reorg_batch( - invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)] - ) - - # Process streaming data - stream = [response1, response2, reorg_response] - results = list(loader.load_stream_continuous(iter(stream), test_table_name)) - - # Verify results - assert len(results) == 3 - assert results[0].success - assert results[0].rows_loaded == 2 - assert results[1].success - assert results[1].rows_loaded == 2 - assert results[2].success - assert results[2].is_reorg - - # Verify reorg deleted the second batch - loader.cursor.execute(f'SELECT "id" FROM {test_table_name} ORDER BY "id"') - remaining_ids = [row['id'] for row in loader.cursor.fetchall()] - assert remaining_ids == [1, 2] # 3 and 4 deleted by reorg - - -@pytest.fixture -def snowflake_streaming_config(): - """ - Snowflake Snowpipe Streaming configuration from environment. - - Requires: - - SNOWFLAKE_ACCOUNT: Account identifier - - SNOWFLAKE_USER: Username - - SNOWFLAKE_WAREHOUSE: Warehouse name - - SNOWFLAKE_DATABASE: Database name - - SNOWFLAKE_PRIVATE_KEY: Private key in PEM format (as string) - - SNOWFLAKE_SCHEMA: Schema name (optional, defaults to PUBLIC) - - SNOWFLAKE_ROLE: Role (optional) - """ - import os - - config = { - 'account': os.getenv('SNOWFLAKE_ACCOUNT', 'test_account'), - 'user': os.getenv('SNOWFLAKE_USER', 'test_user'), - 'warehouse': os.getenv('SNOWFLAKE_WAREHOUSE', 'test_warehouse'), - 'database': os.getenv('SNOWFLAKE_DATABASE', 'test_database'), - 'schema': os.getenv('SNOWFLAKE_SCHEMA', 'PUBLIC'), - 'loading_method': 'snowpipe_streaming', - 'streaming_channel_prefix': 'test_amp', - 'streaming_max_retries': 3, - 'streaming_buffer_flush_interval': 1, - } - - # Private key is required for Snowpipe Streaming - if os.getenv('SNOWFLAKE_PRIVATE_KEY'): - config['private_key'] = os.getenv('SNOWFLAKE_PRIVATE_KEY') - else: - pytest.skip('Snowpipe Streaming requires SNOWFLAKE_PRIVATE_KEY environment variable') - - if os.getenv('SNOWFLAKE_ROLE'): - config['role'] = os.getenv('SNOWFLAKE_ROLE') - - return config - - -@pytest.mark.integration -@pytest.mark.snowflake -class TestSnowpipeStreamingIntegration: - """Integration tests for Snowpipe Streaming functionality""" - - def test_streaming_connection(self, snowflake_streaming_config): - """Test connection with Snowpipe Streaming enabled""" - loader = SnowflakeLoader(snowflake_streaming_config) - - loader.connect() - assert loader._is_connected is True - assert loader.connection is not None - # Streaming channels dict is initialized empty (channels created on first load) - assert hasattr(loader, 'streaming_channels') - - loader.disconnect() - assert loader._is_connected is False - - def test_basic_streaming_batch_load( - self, snowflake_streaming_config, 
small_test_table, test_table_name, cleanup_tables - ): - """Test basic batch loading via Snowpipe Streaming""" - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_streaming_config) - - with loader: - # Load first batch - batch = small_test_table.to_batches(max_chunksize=50)[0] - result = loader.load_batch(batch, test_table_name, create_table=True) - - assert result.success is True - assert result.rows_loaded == batch.num_rows - assert result.table_name == test_table_name - assert result.metadata['loading_method'] == 'snowpipe_streaming' - - # Wait for Snowpipe streaming data to become queryable (eventual consistency) - count = wait_for_snowpipe_data(loader, test_table_name, batch.num_rows) - assert count == batch.num_rows - - def test_streaming_multiple_batches( - self, snowflake_streaming_config, medium_test_table, test_table_name, cleanup_tables - ): - """Test loading multiple batches via Snowpipe Streaming""" - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_streaming_config) - - with loader: - # Load multiple batches - total_rows = 0 - for i, batch in enumerate(medium_test_table.to_batches(max_chunksize=1000)): - result = loader.load_batch(batch, test_table_name, create_table=(i == 0)) - assert result.success is True - total_rows += result.rows_loaded - - assert total_rows == medium_test_table.num_rows - - # Wait for Snowpipe streaming data to become queryable (eventual consistency) - count = wait_for_snowpipe_data(loader, test_table_name, medium_test_table.num_rows) - assert count == medium_test_table.num_rows - - def test_streaming_channel_management( - self, snowflake_streaming_config, small_test_table, test_table_name, cleanup_tables - ): - """Test that channels are created and reused properly""" - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_streaming_config) - - with loader: - # Load batches with same channel suffix - batch = small_test_table.to_batches(max_chunksize=50)[0] - - result1 = loader.load_batch(batch, test_table_name, create_table=True, channel_suffix='partition_0') - assert result1.success is True - - result2 = loader.load_batch(batch, test_table_name, channel_suffix='partition_0') - assert result2.success is True - - # Verify channel was reused (check loader's channel cache) - channel_key = f'{test_table_name}:test_amp_{test_table_name}_partition_0' - assert channel_key in loader.streaming_channels - - # Wait for Snowpipe streaming data to become queryable (eventual consistency) - count = wait_for_snowpipe_data(loader, test_table_name, batch.num_rows * 2) - assert count == batch.num_rows * 2 - - def test_streaming_multiple_partitions( - self, snowflake_streaming_config, small_test_table, test_table_name, cleanup_tables - ): - """Test parallel streaming with multiple partition channels""" - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_streaming_config) - - with loader: - batch = small_test_table.to_batches(max_chunksize=30)[0] - - # Load to different partitions - result1 = loader.load_batch(batch, test_table_name, create_table=True, channel_suffix='partition_0') - result2 = loader.load_batch(batch, test_table_name, channel_suffix='partition_1') - result3 = loader.load_batch(batch, test_table_name, channel_suffix='partition_2') - - assert result1.success and result2.success and result3.success - - # Verify multiple channels created - assert len(loader.streaming_channels) == 3 - - # Wait for Snowpipe streaming data to become queryable (eventual 
consistency) - count = wait_for_snowpipe_data(loader, test_table_name, batch.num_rows * 3) - assert count == batch.num_rows * 3 - - def test_streaming_data_types( - self, snowflake_streaming_config, comprehensive_test_data, test_table_name, cleanup_tables - ): - """Test Snowpipe Streaming with various data types""" - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_streaming_config) - - with loader: - result = loader.load_table(comprehensive_test_data, test_table_name, create_table=True) - assert result.success is True - - # Wait for Snowpipe streaming data to become queryable (eventual consistency) - count = wait_for_snowpipe_data(loader, test_table_name, comprehensive_test_data.num_rows) - assert count == comprehensive_test_data.num_rows - - # Verify specific row - loader.cursor.execute(f'SELECT * FROM {test_table_name} WHERE "id" = 0') - row = loader.cursor.fetchone() - assert row['id'] == 0 - - def test_streaming_null_handling(self, snowflake_streaming_config, null_test_data, test_table_name, cleanup_tables): - """Test Snowpipe Streaming with NULL values""" - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_streaming_config) - - with loader: - result = loader.load_table(null_test_data, test_table_name, create_table=True) - assert result.success is True - - # Wait for Snowpipe streaming data to become queryable (eventual consistency) - wait_for_snowpipe_data(loader, test_table_name, null_test_data.num_rows) - - # Verify NULL handling - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE "text_field" IS NULL') - null_count = loader.cursor.fetchone()['COUNT(*)'] - expected_nulls = sum(1 for val in null_test_data.column('text_field').to_pylist() if val is None) - assert null_count == expected_nulls - - def test_streaming_reorg_channel_closure(self, snowflake_streaming_config, test_table_name, cleanup_tables): - """Test that reorg properly closes streaming channels""" - import json - - from src.amp.streaming.types import BlockRange - - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_streaming_config) - - with loader: - # Load initial data with multiple channels - batch = pa.RecordBatch.from_pydict( - { - 'id': [1, 2, 3], - 'value': [100, 200, 300], - '_meta_block_ranges': [json.dumps([{'network': 'ethereum', 'start': 100, 'end': 110}])] * 3, - } - ) - - loader.load_batch(batch, test_table_name, create_table=True, channel_suffix='partition_0') - loader.load_batch(batch, test_table_name, channel_suffix='partition_1') - - # Verify channels exist - assert len(loader.streaming_channels) == 2 - - # Wait for data to be queryable - time.sleep(5) - - # Trigger reorg - invalidation_ranges = [BlockRange(network='ethereum', start=100, end=200)] - loader._handle_reorg(invalidation_ranges, test_table_name, 'test_connection') - - # Verify channels were closed - assert len(loader.streaming_channels) == 0 - - # Verify data was deleted - loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}') - count = loader.cursor.fetchone()['COUNT(*)'] - assert count == 0 - - @pytest.mark.slow - def test_streaming_performance( - self, snowflake_streaming_config, performance_test_data, test_table_name, cleanup_tables - ): - """Test Snowpipe Streaming performance with larger dataset""" - cleanup_tables.append(test_table_name) - loader = SnowflakeLoader(snowflake_streaming_config) - - with loader: - start_time = time.time() - result = loader.load_table(performance_test_data, test_table_name, create_table=True) - 
-            duration = time.time() - start_time
-
-            assert result.success is True
-            assert result.rows_loaded == performance_test_data.num_rows
-
-            rows_per_second = result.rows_loaded / duration
-
-            print('\nSnowpipe Streaming Performance:')
-            print(f'  Total rows: {result.rows_loaded:,}')
-            print(f'  Duration: {duration:.2f}s')
-            print(f'  Throughput: {rows_per_second:,.0f} rows/sec')
-            print(f'  Loading method: {result.metadata.get("loading_method")}')
-
-            # Wait for Snowpipe streaming data to become queryable
-            # (eventual consistency, larger dataset may take longer)
-            count = wait_for_snowpipe_data(loader, test_table_name, performance_test_data.num_rows, max_wait=60)
-            assert count == performance_test_data.num_rows
-
-    def test_streaming_error_handling(self, snowflake_streaming_config, test_table_name, cleanup_tables):
-        """Test error handling in Snowpipe Streaming"""
-        cleanup_tables.append(test_table_name)
-        loader = SnowflakeLoader(snowflake_streaming_config)
-
-        with loader:
-            # Create table first
-            initial_data = pa.table({'id': [1, 2, 3], 'value': [100, 200, 300]})
-            result = loader.load_table(initial_data, test_table_name, create_table=True)
-            assert result.success is True
-
-            # Try to load data with extra column (Snowpipe streaming handles gracefully)
-            # Note: Snowpipe streaming accepts data with extra columns and silently ignores them
-            incompatible_data = pa.RecordBatch.from_pydict(
-                {
-                    'id': [4, 5],
-                    'different_column': ['a', 'b'],  # Extra column not in table schema
-                }
-            )
-
-            result = loader.load_batch(incompatible_data, test_table_name)
-            # Snowpipe streaming handles this gracefully - it loads the matching columns
-            # and ignores columns that don't exist in the table
-            assert result.success is True
-            assert result.rows_loaded == 2
-
-    def test_microbatch_deduplication(self, snowflake_config, test_table_name, cleanup_tables):
-        """
-        Test that multiple RecordBatches within the same microbatch are all loaded,
-        and deduplication only happens at microbatch boundaries when ranges_complete=True.
-
-        This test verifies the fix for the critical bug where we were marking batches
-        as processed after every RecordBatch instead of waiting for ranges_complete=True.
-        """
-        from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch
-
-        cleanup_tables.append(test_table_name)
-
-        # Enable state management to test deduplication
-        config_with_state = {
-            **snowflake_config,
-            'state': {'enabled': True, 'storage': 'memory', 'store_batch_id': True},
-        }
-        loader = SnowflakeLoader(config_with_state)
-
-        with loader:
-            # Simulate a microbatch sent as 3 RecordBatches with the same BlockRange
-            # This happens when the server sends large microbatches in smaller chunks
-
-            # First RecordBatch of the microbatch (ranges_complete=False)
-            batch1_data = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]})
-            response1 = ResponseBatch.data_batch(
-                data=batch1_data,
-                metadata=BatchMetadata(
-                    ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],
-                    ranges_complete=False,  # Not the last batch in this microbatch
-                ),
-            )
-
-            # Second RecordBatch of the microbatch (ranges_complete=False)
-            batch2_data = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]})
-            response2 = ResponseBatch.data_batch(
-                data=batch2_data,
-                metadata=BatchMetadata(
-                    ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],  # Same BlockRange!
-                    ranges_complete=False,  # Still not the last batch
-                ),
-            )
-
-            # Third RecordBatch of the microbatch (ranges_complete=True)
-            batch3_data = pa.RecordBatch.from_pydict({'id': [5, 6], 'value': [500, 600]})
-            response3 = ResponseBatch.data_batch(
-                data=batch3_data,
-                metadata=BatchMetadata(
-                    ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],  # Same BlockRange!
-                    ranges_complete=True,  # Last batch in this microbatch - safe to mark as processed
-                ),
-            )
-
-            # Process the microbatch stream
-            stream = [response1, response2, response3]
-            results = list(
-                loader.load_stream_continuous(iter(stream), test_table_name, connection_name='test_connection')
-            )
-
-            # CRITICAL: All 3 RecordBatches should be loaded successfully
-            # Before the fix, only the first batch would load (the other 2 would be skipped as "duplicates")
-            assert len(results) == 3, 'All RecordBatches within microbatch should be processed'
-            assert all(r.success for r in results), 'All batches should succeed'
-            assert results[0].rows_loaded == 2, 'First batch should load 2 rows'
-            assert results[1].rows_loaded == 2, 'Second batch should load 2 rows (not skipped!)'
-            assert results[2].rows_loaded == 2, 'Third batch should load 2 rows (not skipped!)'
-
-            # Verify total rows in table (all batches loaded)
-            loader.cursor.execute(f'SELECT COUNT(*) as count FROM {test_table_name}')
-            total_count = loader.cursor.fetchone()['COUNT']
-            assert total_count == 6, 'All 6 rows from 3 RecordBatches should be in the table'
-
-            # Verify the actual IDs are present
-            loader.cursor.execute(f'SELECT "id" FROM {test_table_name} ORDER BY "id"')
-            all_ids = [row['id'] for row in loader.cursor.fetchall()]
-            assert all_ids == [1, 2, 3, 4, 5, 6], 'All rows from all RecordBatches should be present'
-
-            # Now test that re-sending the complete microbatch is properly deduplicated
-            # This time, the first batch has ranges_complete=True (entire microbatch in one RecordBatch)
-            duplicate_batch = pa.RecordBatch.from_pydict({'id': [7, 8], 'value': [700, 800]})
-            duplicate_response = ResponseBatch.data_batch(
-                data=duplicate_batch,
-                metadata=BatchMetadata(
-                    ranges=[
-                        BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')
-                    ],  # Same range as before!
-                    ranges_complete=True,  # Complete microbatch
-                ),
-            )
-
-            # Process duplicate microbatch
-            duplicate_results = list(
-                loader.load_stream_continuous(
-                    iter([duplicate_response]), test_table_name, connection_name='test_connection'
-                )
-            )
-
-            # The duplicate microbatch should be skipped (already processed)
-            assert len(duplicate_results) == 1
-            assert duplicate_results[0].success is True
-            assert duplicate_results[0].rows_loaded == 0, 'Duplicate microbatch should be skipped'
-            assert duplicate_results[0].metadata.get('operation') == 'skip_duplicate', 'Should be marked as duplicate'
-
-            # Verify row count unchanged (duplicate was skipped)
-            loader.cursor.execute(f'SELECT COUNT(*) as count FROM {test_table_name}')
-            final_count = loader.cursor.fetchone()['COUNT']
-            assert final_count == 6, 'Row count should not increase after duplicate microbatch'
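The removed test above encodes a general rule: a microbatch's block ranges are marked as processed only when `ranges_complete=True`, and a later batch is skipped only if its range-set was previously marked complete. A minimal sketch of that rule, assuming a simple in-memory store — the class and method names are illustrative, not the loader's actual state-management API:

```python
from collections import namedtuple

# Illustrative stand-in for the real BlockRange type.
Range = namedtuple('Range', ['network', 'start', 'end', 'hash'])


class MicrobatchDedupState:
    """Tracks which microbatch range-sets have been fully processed."""

    def __init__(self) -> None:
        self._completed = set()

    @staticmethod
    def _key(ranges):
        # Identify a microbatch by its full set of (network, start, end, hash) ranges.
        return tuple(sorted((r.network, r.start, r.end, r.hash) for r in ranges))

    def is_duplicate(self, ranges) -> bool:
        # Skip a batch only if its range-set was previously marked complete.
        return self._key(ranges) in self._completed

    def mark_if_complete(self, ranges, ranges_complete: bool) -> None:
        # Only the final RecordBatch of a microbatch (ranges_complete=True) marks
        # the range-set as processed; earlier RecordBatches sharing the same
        # ranges must still be loaded.
        if ranges_complete:
            self._completed.add(self._key(ranges))


state = MicrobatchDedupState()
ranges = [Range('ethereum', 100, 110, '0xabc123')]

for is_last in (False, False, True):       # three RecordBatches, one microbatch
    assert not state.is_duplicate(ranges)  # none of them are skipped
    state.mark_if_complete(ranges, is_last)

assert state.is_duplicate(ranges)          # a re-sent microbatch is now skipped
```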
+""" + +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from src.amp.async_client import AsyncAmpClient, AsyncQueryBuilder + + +@pytest.mark.unit +class TestAsyncQueryBuilder: + """Test AsyncQueryBuilder pure methods and logic""" + + def test_with_dependency_chaining(self): + """Test adding and chaining dependencies""" + qb = AsyncQueryBuilder(client=None, query='SELECT * FROM eth.blocks JOIN btc.blocks') + + result = qb.with_dependency('eth', '_/eth_firehose@0.0.0').with_dependency('btc', '_/btc_firehose@1.2.3') + + assert result is qb # Returns self for chaining + assert qb._dependencies == {'eth': '_/eth_firehose@0.0.0', 'btc': '_/btc_firehose@1.2.3'} + + def test_with_dependency_overwrites_existing_alias(self): + """Test that same alias overwrites previous dependency""" + qb = AsyncQueryBuilder(client=None, query='SELECT * FROM eth.blocks') + qb.with_dependency('eth', '_/eth_firehose@0.0.0') + qb.with_dependency('eth', '_/eth_firehose@1.0.0') + + assert qb._dependencies == {'eth': '_/eth_firehose@1.0.0'} + + def test_ensure_streaming_query_adds_settings(self): + """Test that streaming settings are added when not present""" + qb = AsyncQueryBuilder(client=None, query='SELECT * FROM eth.blocks') + + result = qb._ensure_streaming_query('SELECT * FROM eth.blocks') + assert result == 'SELECT * FROM eth.blocks SETTINGS stream = true' + + # Strips semicolons + result = qb._ensure_streaming_query('SELECT * FROM eth.blocks;') + assert result == 'SELECT * FROM eth.blocks SETTINGS stream = true' + + def test_ensure_streaming_query_preserves_existing_settings(self): + """Test that existing SETTINGS stream = true is preserved""" + qb = AsyncQueryBuilder(client=None, query='SELECT * FROM eth.blocks') + + # Should not duplicate when already present + result = qb._ensure_streaming_query('SELECT * FROM eth.blocks SETTINGS stream = true') + assert 'SETTINGS stream = true' in result + + def test_querybuilder_repr(self): + """Test AsyncQueryBuilder string representation""" + qb = AsyncQueryBuilder(client=None, query='SELECT * FROM eth.blocks') + repr_str = repr(qb) + + assert 'AsyncQueryBuilder' in repr_str + assert 'SELECT * FROM eth.blocks' in repr_str + + # Long queries are truncated + long_query = 'SELECT ' + ', '.join([f'col{i}' for i in range(100)]) + ' FROM eth.blocks' + qb_long = AsyncQueryBuilder(client=None, query=long_query) + assert '...' 
+        assert '...' in repr(qb_long)
+
+    def test_dependencies_initialized_empty(self):
+        """Test that dependencies and cache are initialized correctly"""
+        qb = AsyncQueryBuilder(client=None, query='SELECT * FROM eth.blocks')
+
+        assert qb._dependencies == {}
+        assert qb._result_cache is None
+
+
+@pytest.mark.unit
+class TestAsyncClientInitialization:
+    """Test AsyncAmpClient initialization logic"""
+
+    def test_client_requires_url_or_query_url(self):
+        """Test that AsyncAmpClient requires either url or query_url"""
+        with pytest.raises(ValueError, match='Either url or query_url must be provided'):
+            AsyncAmpClient()
+
+
+@pytest.mark.unit
+class TestAsyncClientAuthPriority:
+    """Test AsyncAmpClient authentication priority (explicit token > env var > auth file)"""
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_explicit_token_highest_priority(self, mock_connect, mock_getenv):
+        """Test that explicit auth_token parameter has highest priority"""
+        mock_getenv.return_value = 'env-var-token'
+
+        AsyncAmpClient(query_url='grpc://localhost:1602', auth_token='explicit-token')
+
+        # Verify that explicit token was used (not env var)
+        mock_connect.assert_called_once()
+        call_args = mock_connect.call_args
+        middleware = call_args[1].get('middleware', [])
+        assert len(middleware) == 1
+        assert middleware[0].get_token() == 'explicit-token'
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_env_var_second_priority(self, mock_connect, mock_getenv):
+        """Test that AMP_AUTH_TOKEN env var has second priority"""
+
+        # Return 'env-var-token' for AMP_AUTH_TOKEN, None for others
+        def getenv_side_effect(key, default=None):
+            if key == 'AMP_AUTH_TOKEN':
+                return 'env-var-token'
+            return default
+
+        mock_getenv.side_effect = getenv_side_effect
+
+        AsyncAmpClient(query_url='grpc://localhost:1602')
+
+        # Verify env var was checked
+        calls = [str(call) for call in mock_getenv.call_args_list]
+        assert any('AMP_AUTH_TOKEN' in call for call in calls)
+        mock_connect.assert_called_once()
+        call_args = mock_connect.call_args
+        middleware = call_args[1].get('middleware', [])
+        assert len(middleware) == 1
+        assert middleware[0].get_token() == 'env-var-token'
+
+    @patch('amp.auth.AuthService')
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_auth_file_lowest_priority(self, mock_connect, mock_getenv, mock_auth_service):
+        """Test that auth=True has lowest priority"""
+
+        # Return None for all getenv calls
+        def getenv_side_effect(key, default=None):
+            return default
+
+        mock_getenv.side_effect = getenv_side_effect
+
+        mock_service_instance = Mock()
+        mock_service_instance.get_token.return_value = 'file-token'
+        mock_auth_service.return_value = mock_service_instance
+
+        AsyncAmpClient(query_url='grpc://localhost:1602', auth=True)
+
+        # Verify auth file was used
+        mock_auth_service.assert_called_once()
+        mock_connect.assert_called_once()
+        call_args = mock_connect.call_args
+        middleware = call_args[1].get('middleware', [])
+        assert len(middleware) == 1
+        # The middleware should use the auth service's get_token method directly
+        assert middleware[0].get_token == mock_service_instance.get_token
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_no_auth_when_nothing_provided(self, mock_connect, mock_getenv):
+        """Test that no auth middleware is added when no auth is provided"""
+
+        # Return None/default for all getenv calls
+        def getenv_side_effect(key, default=None):
+            return default
+
+        mock_getenv.side_effect = getenv_side_effect
+
+        AsyncAmpClient(query_url='grpc://localhost:1602')
+
+        # Verify no middleware was added
+        mock_connect.assert_called_once()
+        call_args = mock_connect.call_args
+        middleware = call_args[1].get('middleware')
+        assert middleware is None or len(middleware) == 0
+
+
+@pytest.mark.unit
+class TestAsyncClientSqlMethod:
+    """Test AsyncAmpClient.sql() method"""
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_sql_returns_async_query_builder(self, mock_connect, mock_getenv):
+        """Test that sql() returns an AsyncQueryBuilder instance"""
+        mock_getenv.return_value = None
+
+        client = AsyncAmpClient(query_url='grpc://localhost:1602')
+        result = client.sql('SELECT * FROM eth.blocks')
+
+        assert isinstance(result, AsyncQueryBuilder)
+        assert result.query == 'SELECT * FROM eth.blocks'
+        assert result.client is client
+
+
+@pytest.mark.unit
+class TestAsyncClientProperties:
+    """Test AsyncAmpClient properties for Admin and Registry access"""
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_datasets_raises_without_admin_url(self, mock_connect, mock_getenv):
+        """Test that datasets property raises when admin_url not provided"""
+        mock_getenv.return_value = None
+
+        client = AsyncAmpClient(query_url='grpc://localhost:1602')
+
+        with pytest.raises(ValueError, match='Admin API not configured'):
+            _ = client.datasets
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_jobs_raises_without_admin_url(self, mock_connect, mock_getenv):
+        """Test that jobs property raises when admin_url not provided"""
+        mock_getenv.return_value = None
+
+        client = AsyncAmpClient(query_url='grpc://localhost:1602')
+
+        with pytest.raises(ValueError, match='Admin API not configured'):
+            _ = client.jobs
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_schema_raises_without_admin_url(self, mock_connect, mock_getenv):
+        """Test that schema property raises when admin_url not provided"""
+        mock_getenv.return_value = None
+
+        client = AsyncAmpClient(query_url='grpc://localhost:1602')
+
+        with pytest.raises(ValueError, match='Admin API not configured'):
+            _ = client.schema
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_registry_raises_without_registry_url(self, mock_connect, mock_getenv):
+        """Test that registry property raises when registry_url not provided"""
+        mock_getenv.return_value = None
+
+        client = AsyncAmpClient(query_url='grpc://localhost:1602', registry_url=None)
+
+        with pytest.raises(ValueError, match='Registry API not configured'):
+            _ = client.registry
+
+
+@pytest.mark.unit
+class TestAsyncClientConfigurationMethods:
+    """Test AsyncAmpClient configuration methods"""
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_configure_connection(self, mock_connect, mock_getenv):
+        """Test that configure_connection stores connection config"""
+        mock_getenv.return_value = None
+
+        client = AsyncAmpClient(query_url='grpc://localhost:1602')
+        client.configure_connection('test_conn', 'postgresql', {'host': 'localhost', 'database': 'test'})
+
+        # Verify connection was stored in manager
+        connections = client.list_connections()
+        assert 'test_conn' in connections
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    def test_get_available_loaders(self, mock_connect, mock_getenv):
+        """Test that get_available_loaders returns list of loaders"""
+        mock_getenv.return_value = None
+
+        client = AsyncAmpClient(query_url='grpc://localhost:1602')
+        loaders = client.get_available_loaders()
+
+        assert isinstance(loaders, list)
+        # Should include postgresql, or at least one registered loader
+        assert 'postgresql' in loaders or len(loaders) > 0
+
+
+@pytest.mark.unit
+@pytest.mark.asyncio
+class TestAsyncQueryBuilderLoad:
+    """Test AsyncQueryBuilder.load() method"""
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    async def test_load_raises_for_parallel_config_without_stream(self, mock_connect, mock_getenv):
+        """Test that load() raises error when parallel_config used without stream=True"""
+        mock_getenv.return_value = None
+
+        client = AsyncAmpClient(query_url='grpc://localhost:1602')
+        qb = client.sql('SELECT * FROM eth.blocks')
+
+        with pytest.raises(ValueError, match='parallel_config requires stream=True'):
+            await qb.load(
+                connection='test_conn',
+                destination='test_table',
+                parallel_config={'partitions': 4},
+            )
+
+
+@pytest.mark.unit
+class TestAsyncClientContextManager:
+    """Test AsyncAmpClient context manager support"""
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    @pytest.mark.asyncio
+    async def test_async_context_manager(self, mock_connect, mock_getenv):
+        """Test that AsyncAmpClient works as async context manager"""
+        mock_getenv.return_value = None
+        mock_conn = Mock()
+        mock_connect.return_value = mock_conn
+
+        async with AsyncAmpClient(query_url='grpc://localhost:1602') as client:
+            assert client is not None
+            assert client.conn is mock_conn
+
+        # Verify connection was closed on exit
+        mock_conn.close.assert_called_once()
+
+    @patch('amp.async_client.os.getenv')
+    @patch('amp.async_client.flight.connect')
+    @pytest.mark.asyncio
+    async def test_close_method(self, mock_connect, mock_getenv):
+        """Test that close() properly closes all connections"""
+        mock_getenv.return_value = None
+        mock_conn = Mock()
+        mock_connect.return_value = mock_conn
+
+        client = AsyncAmpClient(query_url='grpc://localhost:1602')
+        await client.close()
+
+        mock_conn.close.assert_called_once()
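Taken together, these unit tests describe the public surface of `AsyncAmpClient`. The sketch below is illustrative only: it assumes a reachable Flight SQL endpoint, uses placeholder connection and table names, and infers a `stream=True` parameter for `load()` from the parallel_config error message, so the real signature may differ.

```python
import asyncio

from src.amp.async_client import AsyncAmpClient


async def main() -> None:
    # Auth resolution order exercised by the tests above: explicit auth_token,
    # then the AMP_AUTH_TOKEN environment variable, then auth=True (auth file).
    async with AsyncAmpClient(query_url='grpc://localhost:1602', auth_token='my-token') as client:
        # Register a destination connection (placeholder name and settings).
        client.configure_connection('analytics_pg', 'postgresql', {'host': 'localhost', 'database': 'analytics'})

        # sql() returns an AsyncQueryBuilder; with_dependency() chains and returns self.
        qb = client.sql('SELECT * FROM eth.blocks').with_dependency('eth', '_/eth_firehose@0.0.0')

        # parallel_config is only valid together with stream=True,
        # as the load() test above asserts.
        await qb.load(
            connection='analytics_pg',
            destination='eth_blocks',
            stream=True,
            parallel_config={'partitions': 4},
        )


if __name__ == '__main__':
    asyncio.run(main())
```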