diff --git a/tests/aio/test_discovery_detect_local_dc.py b/tests/aio/test_discovery_detect_local_dc.py new file mode 100644 index 00000000..4f042396 --- /dev/null +++ b/tests/aio/test_discovery_detect_local_dc.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- +import pytest +from unittest.mock import MagicMock, patch, AsyncMock +from ydb import driver, connection +from ydb.aio import pool, nearest_dc + + +class MockEndpointInfo: + def __init__(self, address, port, location): + self.address = address + self.port = port + self.endpoint = f"{address}:{port}" + self.location = location + self.ssl = False + self.node_id = 1 + + def endpoints_with_options(self): + yield (self.endpoint, connection.EndpointOptions(ssl_target_name_override=None, node_id=self.node_id)) + + +class MockDiscoveryResult: + def __init__(self, self_location, endpoints): + self.self_location = self_location + self.endpoints = endpoints + + +@pytest.mark.asyncio +async def test_detect_local_dc_overrides_server_location(): + """Test that detected location overrides server's self_location for preferred endpoints.""" + # Server reports dc1, but we detect dc2 as nearest + endpoints = [ + MockEndpointInfo("dc1-host", 2135, "dc1"), + MockEndpointInfo("dc2-host", 2135, "dc2"), + ] + mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints) + + mock_resolver = MagicMock() + mock_resolver.resolve = AsyncMock(return_value=mock_result) + + preferred = [] + + def mock_init(self, endpoint, driver_config, endpoint_options=None): + self.endpoint = endpoint + self.node_id = 1 + + with patch.object(nearest_dc, "detect_local_dc", AsyncMock(return_value="dc2")): + with patch("ydb.aio.connection.Connection.__init__", mock_init): + with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()): + with patch("ydb.aio.connection.Connection.close", AsyncMock()): + with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None): + config = driver.DriverConfig( + endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False + ) + discovery = pool.Discovery( + store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config + ) + discovery._resolver = mock_resolver + + original_add = discovery._cache.add + discovery._cache.add = lambda conn, pref=False: ( + preferred.append(conn.endpoint) if pref else None, + original_add(conn, pref), + )[1] + + await discovery.execute_discovery() + + assert any("dc2" in ep for ep in preferred), "dc2 should be preferred (detected)" + assert not any("dc1" in ep for ep in preferred), "dc1 should not be preferred" + + +@pytest.mark.asyncio +async def test_detect_local_dc_failure_fallback(): + """Test that detection failure falls back to server's self_location.""" + endpoints = [ + MockEndpointInfo("dc1-host", 2135, "dc1"), + MockEndpointInfo("dc2-host", 2135, "dc2"), + ] + mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints) + + mock_resolver = MagicMock() + mock_resolver.resolve = AsyncMock(return_value=mock_result) + + preferred = [] + + def mock_init(self, endpoint, driver_config, endpoint_options=None): + self.endpoint = endpoint + self.node_id = 1 + + with patch.object(nearest_dc, "detect_local_dc", AsyncMock(return_value=None)): + with patch("ydb.aio.connection.Connection.__init__", mock_init): + with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()): + with patch("ydb.aio.connection.Connection.close", AsyncMock()): + with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None): + config = driver.DriverConfig( + endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False + ) + discovery = pool.Discovery( + store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config + ) + discovery._resolver = mock_resolver + + original_add = discovery._cache.add + discovery._cache.add = lambda conn, pref=False: ( + preferred.append(conn.endpoint) if pref else None, + original_add(conn, pref), + )[1] + + await discovery.execute_discovery() + + assert any("dc1" in ep for ep in preferred), "dc1 should be preferred (server fallback)" + + +@pytest.mark.asyncio +async def test_detect_local_dc_skipped_when_use_all_nodes_true(): + """Test that detect_local_dc is NOT called when use_all_nodes=True.""" + endpoints = [ + MockEndpointInfo("dc1-host", 2135, "dc1"), + MockEndpointInfo("dc2-host", 2135, "dc2"), + ] + mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints) + + mock_resolver = MagicMock() + mock_resolver.resolve = AsyncMock(return_value=mock_result) + + def mock_init(self, endpoint, driver_config, endpoint_options=None): + self.endpoint = endpoint + self.node_id = 1 + + with patch.object(nearest_dc, "detect_local_dc", AsyncMock(return_value="dc2")) as detect_mock: + with patch("ydb.aio.connection.Connection.__init__", mock_init): + with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()): + with patch("ydb.aio.connection.Connection.close", AsyncMock()): + with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None): + config = driver.DriverConfig( + endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=True + ) + discovery = pool.Discovery( + store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config + ) + discovery._resolver = mock_resolver + await discovery.execute_discovery() + + assert detect_mock.call_count == 0, "detect_local_dc should NOT be called when use_all_nodes=True" diff --git a/tests/aio/test_nearest_dc.py b/tests/aio/test_nearest_dc.py new file mode 100644 index 00000000..8ce6d06b --- /dev/null +++ b/tests/aio/test_nearest_dc.py @@ -0,0 +1,161 @@ +import asyncio +import pytest +from ydb.aio import nearest_dc + + +class MockEndpoint: + def __init__(self, address, port, location, ipv4_addrs=(), ipv6_addrs=()): + self.address = address + self.port = port + self.endpoint = f"{address}:{port}" + self.location = location + self.ipv4_addrs = ipv4_addrs + self.ipv6_addrs = ipv6_addrs + + +class MockWriter: + def __init__(self): + self.closed = False + + def close(self): + self.closed = True + + async def wait_closed(self): + await asyncio.sleep(0) + + +@pytest.mark.asyncio +async def test_check_fastest_endpoint_empty(): + assert await nearest_dc._check_fastest_endpoint([]) is None + + +@pytest.mark.asyncio +async def test_check_fastest_endpoint_all_fail(monkeypatch): + async def fake_open_connection(host, port): + raise OSError("connect failed") + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + + endpoints = [ + MockEndpoint("a", 1, "dc1"), + MockEndpoint("b", 1, "dc2"), + ] + assert await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) is None + + +@pytest.mark.asyncio +async def test_check_fastest_endpoint_fastest_wins(monkeypatch): + async def fake_open_connection(host, port): + if host == "slow": + await asyncio.sleep(0.05) + return None, MockWriter() + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + + endpoints = [ + MockEndpoint("slow", 1, "dc_slow"), + MockEndpoint("fast", 1, "dc_fast"), + ] + winner = await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.2) + assert winner is not None + assert winner.location == "dc_fast" + + +@pytest.mark.asyncio +async def test_check_fastest_endpoint_respects_main_timeout(monkeypatch): + async def fake_open_connection(host, port): + await asyncio.sleep(0.2) + return None, MockWriter() + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + + endpoints = [ + MockEndpoint("hang1", 1, "dc1"), + MockEndpoint("hang2", 1, "dc2"), + ] + + winner = await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) + + assert winner is None + + +@pytest.mark.asyncio +async def test_detect_local_dc_empty_endpoints(): + with pytest.raises(ValueError, match="Empty endpoints"): + await nearest_dc.detect_local_dc([]) + + +@pytest.mark.asyncio +async def test_detect_local_dc_single_location_returns_immediately(monkeypatch): + async def fail_if_called(*args, **kwargs): + raise AssertionError("open_connection should not be called for single location") + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fail_if_called) + + endpoints = [ + MockEndpoint("h1", 1, "dc1"), + MockEndpoint("h2", 1, "dc1"), + ] + assert await nearest_dc.detect_local_dc(endpoints) == "dc1" + + +@pytest.mark.asyncio +async def test_detect_local_dc_returns_none_when_all_fail(monkeypatch): + async def fake_open_connection(host, port): + raise OSError("connect failed") + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + + endpoints = [ + MockEndpoint("bad1", 9999, "dc1"), + MockEndpoint("bad2", 9999, "dc2"), + ] + assert await nearest_dc.detect_local_dc(endpoints, timeout=0.05) is None + + +@pytest.mark.asyncio +async def test_detect_local_dc_returns_location_of_fastest(monkeypatch): + async def fake_open_connection(host, port): + if host == "dc1_host": + await asyncio.sleep(0.05) + return None, MockWriter() + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + + endpoints = [ + MockEndpoint("dc1_host", 1, "dc1"), + MockEndpoint("dc2_host", 1, "dc2"), + ] + assert await nearest_dc.detect_local_dc(endpoints, max_per_location=5, timeout=0.2) == "dc2" + + +@pytest.mark.asyncio +async def test_detect_local_dc_respects_max_per_location(monkeypatch): + calls = [] + + async def fake_open_connection(host, port): + calls.append((host, port)) + raise OSError("connect failed") + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + + endpoints = [MockEndpoint(f"dc1_{i}", 1, "dc1") for i in range(5)] + [ + MockEndpoint(f"dc2_{i}", 1, "dc2") for i in range(5) + ] + await nearest_dc.detect_local_dc(endpoints, max_per_location=2, timeout=0.2) + + assert len(calls) == 4 + + +@pytest.mark.asyncio +async def test_detect_local_dc_validates_max_per_location(): + endpoints = [MockEndpoint("h1", 1, "dc1")] + with pytest.raises(ValueError, match="max_per_location must be >= 1"): + await nearest_dc.detect_local_dc(endpoints, max_per_location=0) + + +@pytest.mark.asyncio +async def test_detect_local_dc_validates_timeout(): + endpoints = [MockEndpoint("h1", 1, "dc1")] + with pytest.raises(ValueError, match="timeout must be > 0"): + await nearest_dc.detect_local_dc(endpoints, timeout=0) diff --git a/tests/test_discovery_detect_local_dc.py b/tests/test_discovery_detect_local_dc.py new file mode 100644 index 00000000..4123dd3f --- /dev/null +++ b/tests/test_discovery_detect_local_dc.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- +from unittest.mock import Mock, MagicMock, patch +from ydb import driver, pool, nearest_dc, connection + + +class MockEndpointInfo: + def __init__(self, address, port, location): + self.address = address + self.port = port + self.endpoint = f"{address}:{port}" + self.location = location + self.ssl = False + self.node_id = 1 + + def endpoints_with_options(self): + yield (self.endpoint, connection.EndpointOptions(ssl_target_name_override=None, node_id=self.node_id)) + + +class MockDiscoveryResult: + def __init__(self, self_location, endpoints): + self.self_location = self_location + self.endpoints = endpoints + + +def test_detect_local_dc_overrides_server_location(): + """Test that detected location overrides server's self_location for preferred endpoints.""" + # Server reports dc1, but we detect dc2 as nearest + endpoints = [ + MockEndpointInfo("dc1-host", 2135, "dc1"), + MockEndpointInfo("dc2-host", 2135, "dc2"), + ] + mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints) + + mock_resolver = MagicMock() + mock_resolver.context_resolve.return_value.__enter__.return_value = mock_result + mock_resolver.context_resolve.return_value.__exit__.return_value = None + + preferred = [] + + with patch.object(nearest_dc, "detect_local_dc", Mock(return_value="dc2")): + with patch( + "ydb.connection.Connection.ready_factory", lambda *args, **kw: MagicMock(endpoint=args[0], node_id=1) + ): + config = driver.DriverConfig( + endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False + ) + discovery = pool.Discovery(store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config) + discovery._resolver = mock_resolver + + original_add = discovery._cache.add + discovery._cache.add = lambda conn, pref=False: ( + preferred.append(conn.endpoint) if pref else None, + original_add(conn, pref), + )[1] + + discovery.execute_discovery() + + assert any("dc2" in ep for ep in preferred), "dc2 should be preferred (detected)" + assert not any("dc1" in ep for ep in preferred), "dc1 should not be preferred" + + +def test_detect_local_dc_failure_fallback(): + """Test that detection failure falls back to server's self_location.""" + endpoints = [ + MockEndpointInfo("dc1-host", 2135, "dc1"), + MockEndpointInfo("dc2-host", 2135, "dc2"), + ] + mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints) + + mock_resolver = MagicMock() + mock_resolver.context_resolve.return_value.__enter__.return_value = mock_result + mock_resolver.context_resolve.return_value.__exit__.return_value = None + + preferred = [] + + with patch.object(nearest_dc, "detect_local_dc", Mock(return_value=None)): + with patch( + "ydb.connection.Connection.ready_factory", lambda *args, **kw: MagicMock(endpoint=args[0], node_id=1) + ): + config = driver.DriverConfig( + endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False + ) + discovery = pool.Discovery(store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config) + discovery._resolver = mock_resolver + + original_add = discovery._cache.add + discovery._cache.add = lambda conn, pref=False: ( + preferred.append(conn.endpoint) if pref else None, + original_add(conn, pref), + )[1] + + discovery.execute_discovery() + + assert any("dc1" in ep for ep in preferred), "dc1 should be preferred (server fallback)" + + +def test_detect_local_dc_skipped_when_use_all_nodes_true(): + """Test that detect_local_dc is NOT called when use_all_nodes=True.""" + endpoints = [ + MockEndpointInfo("dc1-host", 2135, "dc1"), + MockEndpointInfo("dc2-host", 2135, "dc2"), + ] + mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints) + + mock_resolver = MagicMock() + mock_resolver.context_resolve.return_value.__enter__.return_value = mock_result + mock_resolver.context_resolve.return_value.__exit__.return_value = None + + with patch.object(nearest_dc, "detect_local_dc", Mock(return_value="dc2")) as detect_mock: + with patch( + "ydb.connection.Connection.ready_factory", lambda *args, **kw: MagicMock(endpoint=args[0], node_id=1) + ): + config = driver.DriverConfig( + endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=True + ) + discovery = pool.Discovery(store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config) + discovery._resolver = mock_resolver + discovery.execute_discovery() + + assert detect_mock.call_count == 0, "detect_local_dc should NOT be called when use_all_nodes=True" diff --git a/tests/test_nearest_dc.py b/tests/test_nearest_dc.py new file mode 100644 index 00000000..cc76e10d --- /dev/null +++ b/tests/test_nearest_dc.py @@ -0,0 +1,146 @@ +import time +import pytest +from ydb import nearest_dc + + +class MockEndpoint: + def __init__(self, address, port, location, ipv4_addrs=(), ipv6_addrs=()): + self.address = address + self.port = port + self.endpoint = f"{address}:{port}" + self.location = location + self.ipv4_addrs = ipv4_addrs + self.ipv6_addrs = ipv6_addrs + + +class DummySock: + def close(self): + pass + + +def test_check_fastest_endpoint_empty(): + assert nearest_dc._check_fastest_endpoint([]) is None + + +def test_check_fastest_endpoint_all_fail(monkeypatch): + def fake_create_connection(addr_port, timeout=None): + raise OSError("connect failed") + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + + endpoints = [ + MockEndpoint("a", 1, "dc1"), + MockEndpoint("b", 1, "dc2"), + ] + assert nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) is None + + +def test_check_fastest_endpoint_fastest_wins(monkeypatch): + def fake_create_connection(addr_port, timeout=None): + host, _ = addr_port + if host == "slow": + time.sleep(0.05) + return DummySock() + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + + endpoints = [ + MockEndpoint("slow", 1, "dc_slow"), + MockEndpoint("fast", 1, "dc_fast"), + ] + winner = nearest_dc._check_fastest_endpoint(endpoints, timeout=0.2) + assert winner is not None + assert winner.location == "dc_fast" + + +def test_check_fastest_endpoint_respects_main_timeout(monkeypatch): + def fake_create_connection(addr_port, timeout=None): + time.sleep(0.2) + return DummySock() + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + + endpoints = [ + MockEndpoint("hang1", 1, "dc1"), + MockEndpoint("hang2", 1, "dc2"), + ] + + winner = nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) + + assert winner is None + + +def test_detect_local_dc_empty_endpoints(): + with pytest.raises(ValueError, match="Empty endpoints"): + nearest_dc.detect_local_dc([]) + + +def test_detect_local_dc_single_location_returns_immediately(monkeypatch): + def fail_if_called(*args, **kwargs): + raise AssertionError("create_connection should not be called for single location") + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fail_if_called) + + endpoints = [ + MockEndpoint("h1", 1, "dc1"), + MockEndpoint("h2", 1, "dc1"), + ] + assert nearest_dc.detect_local_dc(endpoints) == "dc1" + + +def test_detect_local_dc_returns_none_when_all_fail(monkeypatch): + def fake_create_connection(addr_port, timeout=None): + raise OSError("connect failed") + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + + endpoints = [ + MockEndpoint("bad1", 9999, "dc1"), + MockEndpoint("bad2", 9999, "dc2"), + ] + assert nearest_dc.detect_local_dc(endpoints, timeout=0.05) is None + + +def test_detect_local_dc_returns_location_of_fastest(monkeypatch): + def fake_create_connection(addr_port, timeout=None): + host, _ = addr_port + if host == "dc1_host": + time.sleep(0.05) + return DummySock() + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + + endpoints = [ + MockEndpoint("dc1_host", 1, "dc1"), + MockEndpoint("dc2_host", 1, "dc2"), + ] + assert nearest_dc.detect_local_dc(endpoints, max_per_location=5, timeout=0.2) == "dc2" + + +def test_detect_local_dc_respects_max_per_location(monkeypatch): + calls = [] + + def fake_create_connection(addr_port, timeout=None): + calls.append(addr_port) + raise OSError("connect failed") + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + + endpoints = [MockEndpoint(f"dc1_{i}", 1, "dc1") for i in range(5)] + [ + MockEndpoint(f"dc2_{i}", 1, "dc2") for i in range(5) + ] + nearest_dc.detect_local_dc(endpoints, max_per_location=2, timeout=0.2) + + assert len(calls) == 4 + + +def test_detect_local_dc_validates_max_per_location(): + endpoints = [MockEndpoint("h1", 1, "dc1")] + with pytest.raises(ValueError, match="max_per_location must be >= 1"): + nearest_dc.detect_local_dc(endpoints, max_per_location=0) + + +def test_detect_local_dc_validates_timeout(): + endpoints = [MockEndpoint("h1", 1, "dc1")] + with pytest.raises(ValueError, match="timeout must be > 0"): + nearest_dc.detect_local_dc(endpoints, timeout=0) diff --git a/ydb/aio/nearest_dc.py b/ydb/aio/nearest_dc.py new file mode 100644 index 00000000..c17e656d --- /dev/null +++ b/ydb/aio/nearest_dc.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +import asyncio +import logging +import random +import time +from typing import Dict, List, Optional + +from .. import resolver + + +logger = logging.getLogger(__name__) + + +async def _check_fastest_endpoint( + endpoints: List[resolver.EndpointInfo], timeout: float = 5.0 +) -> Optional[resolver.EndpointInfo]: + """ + Perform async TCP race: connect to all endpoints concurrently and return the fastest one. + + This function starts async TCP connections to all provided endpoints concurrently using + asyncio tasks and returns the first one that successfully connects. Other connection + attempts are cancelled once a winner is found. + + :param endpoints: List of resolver.EndpointInfo objects + :param timeout: Maximum time to wait for any connection (seconds) + :return: Fastest endpoint that connected successfully, or None if all failed or timeout + """ + if not endpoints: + return None + + deadline = time.monotonic() + timeout + + async def try_connect(endpoint): + remaining = deadline - time.monotonic() + if remaining <= 0: + return None + + if endpoint.ipv6_addrs: + target_host = endpoint.ipv6_addrs[0] + elif endpoint.ipv4_addrs: + target_host = endpoint.ipv4_addrs[0] + else: + target_host = endpoint.address + + try: + _, writer = await asyncio.wait_for( + asyncio.open_connection(target_host, endpoint.port), + timeout=remaining, + ) + writer.close() + await writer.wait_closed() + return endpoint + except (OSError, asyncio.TimeoutError): + return None + except Exception as e: + logger.debug("Unexpected error connecting to %s: %s", endpoint.endpoint, e) + return None + + tasks = [asyncio.create_task(try_connect(endpoint)) for endpoint in endpoints] + try: + for task in asyncio.as_completed(tasks, timeout=timeout): + endpoint = await task + if endpoint is not None: + return endpoint + return None + except asyncio.TimeoutError: + logger.debug("TCP race timeout after %.2fs, no endpoint connected in time", timeout) + return None + finally: + for t in tasks: + if not t.done(): + t.cancel() + + +def _split_endpoints_by_location( + endpoints: List[resolver.EndpointInfo], +) -> Dict[str, List[resolver.EndpointInfo]]: + """ + Group endpoints by their location. + + :param endpoints: List of resolver.EndpointInfo objects + :return: Dictionary mapping location -> list of resolver.EndpointInfo + """ + result: Dict[str, List[resolver.EndpointInfo]] = {} + for endpoint in endpoints: + location = endpoint.location + if location not in result: + result[location] = [] + result[location].append(endpoint) + return result + + +def _get_random_endpoints(endpoints: List[resolver.EndpointInfo], count: int) -> List[resolver.EndpointInfo]: + """ + Get random sample of endpoints. + + :param endpoints: List of resolver.EndpointInfo objects + :param count: Maximum number of endpoints to return + :return: Random sample of resolver.EndpointInfo + """ + if len(endpoints) <= count: + return endpoints + return random.sample(endpoints, count) + + +async def detect_local_dc( + endpoints: List[resolver.EndpointInfo], max_per_location: int = 3, timeout: float = 5.0 +) -> Optional[str]: + """ + Detect nearest datacenter by performing async TCP race between endpoints. + + This function groups endpoints by location, selects random samples from each location, + and performs parallel TCP connections to find the fastest one. The location of the + fastest endpoint is considered the nearest datacenter. + + Algorithm: + 1. Group endpoints by location + 2. If only one location exists, return it immediately + 3. Select up to max_per_location random endpoints from each location + 4. If too many endpoints, reduce to one per location and cap at limit + 5. Perform TCP race: connect to all selected endpoints simultaneously + 6. Return the location of the first endpoint that connects successfully + 7. If all connections fail, return None + + :param endpoints: List of resolver.EndpointInfo objects from discovery + :param max_per_location: Maximum number of endpoints to test per location (default: 3, must be >= 1) + :param timeout: TCP connection timeout in seconds (default: 5.0, must be > 0) + :return: Location string of the nearest datacenter, or None if detection failed + :raises ValueError: If endpoints list is empty, max_per_location < 1, or timeout <= 0 + """ + if not endpoints: + raise ValueError("Empty endpoints list for local DC detection") + if max_per_location < 1: + raise ValueError(f"max_per_location must be >= 1, got {max_per_location}") + if timeout <= 0: + raise ValueError(f"timeout must be > 0, got {timeout}") + + endpoints_by_location = _split_endpoints_by_location(endpoints) + + logger.debug( + "Detecting local DC from %d endpoints across %d locations", + len(endpoints), + len(endpoints_by_location), + ) + + if len(endpoints_by_location) == 1: + location = list(endpoints_by_location.keys())[0] + logger.debug("Only one location found: %s", location) + return location + + _MAX_CONCURRENT_TASKS = 30 + + endpoints_to_test = [] + for location, location_endpoints in endpoints_by_location.items(): + sample = _get_random_endpoints(location_endpoints, max_per_location) + endpoints_to_test.extend(sample) + logger.debug( + "Selected %d/%d endpoints from location '%s' for testing", + len(sample), + len(location_endpoints), + location, + ) + + if len(endpoints_to_test) > _MAX_CONCURRENT_TASKS: + endpoints_to_test = [random.choice(location_eps) for location_eps in endpoints_by_location.values()] + + if len(endpoints_to_test) > _MAX_CONCURRENT_TASKS: + endpoints_to_test = random.sample(endpoints_to_test, _MAX_CONCURRENT_TASKS) + + logger.debug("Capped endpoints to %d to limit concurrent tasks", len(endpoints_to_test)) + + fastest_endpoint = await _check_fastest_endpoint(endpoints_to_test, timeout=timeout) + + if fastest_endpoint is None: + logger.debug("Failed to detect local DC via TCP race: no endpoint connected in time") + return None + + detected_location = fastest_endpoint.location + logger.debug("Detected local DC: %s", detected_location) + + return detected_location diff --git a/ydb/aio/pool.py b/ydb/aio/pool.py index 7739035e..0d094431 100644 --- a/ydb/aio/pool.py +++ b/ydb/aio/pool.py @@ -10,7 +10,7 @@ from .connection import Connection, EndpointKey -from . import resolver +from . import nearest_dc, resolver if TYPE_CHECKING: from ydb.driver import DriverConfig @@ -145,6 +145,47 @@ async def execute_discovery(self) -> bool: if cached_endpoint.endpoint not in resolved_endpoints: self._cache.make_outdated(cached_endpoint) + local_dc = resolve_details.self_location + + # Detect local DC using TCP latency if enabled and preferred is meaningful + if self._driver_config.detect_local_dc and not self._driver_config.use_all_nodes: + # Use only endpoints that match the SSL requirements for detection + ssl_filtered_endpoints = [ + endpoint + for endpoint in resolve_details.endpoints + if (self._ssl_required and endpoint.ssl) or (not self._ssl_required and not endpoint.ssl) + ] + + if ssl_filtered_endpoints: + try: + detected_location = await nearest_dc.detect_local_dc( + ssl_filtered_endpoints, max_per_location=3, timeout=self._ready_timeout + ) + if detected_location: + local_dc = detected_location + self.logger.info( + "Detected local DC via TCP latency: %s (server reported: %s)", + local_dc, + resolve_details.self_location, + ) + else: + self.logger.warning( + "Failed to detect local DC via TCP latency, using server location: %s", + resolve_details.self_location, + ) + except Exception as e: + self.logger.warning( + "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", + resolve_details.self_location, + e, + exc_info=True, + ) + else: + self.logger.warning( + "No SSL-compatible endpoints for local DC detection, using server location: %s", + resolve_details.self_location, + ) + for resolved_endpoint in resolve_details.endpoints: if self._ssl_required and not resolved_endpoint.ssl: continue @@ -152,7 +193,7 @@ async def execute_discovery(self) -> bool: if not self._ssl_required and resolved_endpoint.ssl: continue - preferred = resolve_details.self_location == resolved_endpoint.location + preferred = local_dc == resolved_endpoint.location for ( endpoint, diff --git a/ydb/driver.py b/ydb/driver.py index 72602a8c..fc781c25 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -105,6 +105,7 @@ class DriverConfig(object): "discovery_request_timeout", "compression", "disable_discovery", + "detect_local_dc", ) def __init__( @@ -130,6 +131,7 @@ def __init__( discovery_request_timeout: int = 10, compression: Optional[grpc.Compression] = None, disable_discovery: bool = False, + detect_local_dc: bool = False, ) -> None: """ A driver config to initialize a driver instance @@ -151,6 +153,9 @@ def __init__( :param grpc_lb_policy_name: A load balancing policy to be used for discovery channel construction. Default value is `round_round` :param discovery_request_timeout: A default timeout to complete the discovery. The default value is 10 seconds. :param disable_discovery: If True, endpoint discovery is disabled and only the start endpoint is used for all requests. + :param detect_local_dc: If True, detect nearest datacenter using TCP latency measurement instead of using\ + server-provided self_location. **Note**: This option only affects endpoint selection when use_all_nodes=False.\ + When use_all_nodes=True (default), all endpoints are used regardless of detected location. """ self.endpoint = endpoint @@ -179,6 +184,7 @@ def __init__( self.discovery_request_timeout = discovery_request_timeout self.compression = compression self.disable_discovery = disable_discovery + self.detect_local_dc = detect_local_dc def set_database(self, database: str) -> "DriverConfig": self.database = database diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py new file mode 100644 index 00000000..6b82adeb --- /dev/null +++ b/ydb/nearest_dc.py @@ -0,0 +1,229 @@ +# -*- coding: utf-8 -*- +import atexit +import concurrent.futures +import socket +import sys +import threading +import logging +import random +import time +from typing import Dict, List, Optional + +from . import resolver + + +logger = logging.getLogger(__name__) + +# Module-level thread pool for TCP race (reused across discovery cycles) +_TCP_RACE_MAX_WORKERS = 30 +_TCP_RACE_EXECUTOR: Optional[concurrent.futures.ThreadPoolExecutor] = None +_EXECUTOR_LOCK = threading.Lock() +_ATEXIT_REGISTERED = False + + +def _get_executor() -> concurrent.futures.ThreadPoolExecutor: + """ + Lazily create and return the thread pool executor. + + The executor is created on first use to avoid import-time side effects. + The atexit hook is registered only when the executor is actually created. + """ + global _TCP_RACE_EXECUTOR, _ATEXIT_REGISTERED + + if _TCP_RACE_EXECUTOR is None: + with _EXECUTOR_LOCK: + if _TCP_RACE_EXECUTOR is None: + _TCP_RACE_EXECUTOR = concurrent.futures.ThreadPoolExecutor( + max_workers=_TCP_RACE_MAX_WORKERS, + thread_name_prefix="ydb-tcp-race", + ) + + if not _ATEXIT_REGISTERED: + atexit.register(_shutdown_executor) + _ATEXIT_REGISTERED = True + + return _TCP_RACE_EXECUTOR + + +def _shutdown_executor(): + """Shutdown the executor if it was created.""" + if _TCP_RACE_EXECUTOR is not None: + if sys.version_info >= (3, 9): + _TCP_RACE_EXECUTOR.shutdown(wait=False, cancel_futures=True) + else: + _TCP_RACE_EXECUTOR.shutdown(wait=False) + + +def _check_fastest_endpoint( + endpoints: List[resolver.EndpointInfo], timeout: float = 5.0 +) -> Optional[resolver.EndpointInfo]: + """ + Perform TCP race using a bounded thread pool and return the fastest endpoint. + + Uses a module-level ThreadPoolExecutor to avoid creating new threads on every + discovery cycle. Returns immediately when the first endpoint connects successfully. + + If there are more endpoints than the thread pool size, takes one random endpoint + per location to ensure fair representation of all locations in the race. If there + are still too many locations, randomly samples them to stay within the limit. + + :param endpoints: List of resolver.EndpointInfo objects + :param timeout: Maximum time to wait for any connection (seconds) + :return: Fastest endpoint that connected successfully, or None if all failed + """ + if not endpoints: + return None + + if len(endpoints) > _TCP_RACE_MAX_WORKERS: + endpoints_by_location = _split_endpoints_by_location(endpoints) + endpoints = [random.choice(location_eps) for location_eps in endpoints_by_location.values()] + + if len(endpoints) > _TCP_RACE_MAX_WORKERS: + endpoints = random.sample(endpoints, _TCP_RACE_MAX_WORKERS) + + stop_event = threading.Event() + winner_lock = threading.Lock() + deadline = time.monotonic() + timeout + + def try_connect(endpoint: resolver.EndpointInfo) -> Optional[resolver.EndpointInfo]: + """Try to connect to endpoint and return it if successful.""" + remaining = deadline - time.monotonic() + if remaining <= 0 or stop_event.is_set(): + return None + + if endpoint.ipv6_addrs: + target_host = endpoint.ipv6_addrs[0] + elif endpoint.ipv4_addrs: + target_host = endpoint.ipv4_addrs[0] + else: + target_host = endpoint.address + + try: + sock = socket.create_connection((target_host, endpoint.port), timeout=remaining) + try: + with winner_lock: + if stop_event.is_set(): + return None + stop_event.set() + return endpoint + finally: + sock.close() + except (OSError, socket.timeout): + # Ignore expected connection errors; endpoints that fail simply lose the TCP race. + return None + except Exception as e: + logger.debug("Unexpected error connecting to %s: %s", endpoint.endpoint, e) + return None + + executor = _get_executor() + futures: List[concurrent.futures.Future] = [executor.submit(try_connect, ep) for ep in endpoints] + + try: + for fut in concurrent.futures.as_completed(futures, timeout=timeout): + result = fut.result() + if result is not None: + return result + except concurrent.futures.TimeoutError: + # Overall timeout expired + pass + finally: + for f in futures: + f.cancel() + + return None + + +def _split_endpoints_by_location(endpoints: List[resolver.EndpointInfo]) -> Dict[str, List[resolver.EndpointInfo]]: + """ + Group endpoints by their location. + + :param endpoints: List of resolver.EndpointInfo objects + :return: Dictionary mapping location -> list of resolver.EndpointInfo + """ + result: Dict[str, List[resolver.EndpointInfo]] = {} + for endpoint in endpoints: + location = endpoint.location + if location not in result: + result[location] = [] + result[location].append(endpoint) + return result + + +def _get_random_endpoints(endpoints: List[resolver.EndpointInfo], count: int) -> List[resolver.EndpointInfo]: + """ + Get random sample of endpoints. + + :param endpoints: List of resolver.EndpointInfo objects + :param count: Maximum number of endpoints to return + :return: Random sample of resolver.EndpointInfo + """ + if len(endpoints) <= count: + return endpoints + return random.sample(endpoints, count) + + +def detect_local_dc( + endpoints: List[resolver.EndpointInfo], max_per_location: int = 3, timeout: float = 5.0 +) -> Optional[str]: + """ + Detect nearest datacenter by performing TCP race between endpoints. + + This function groups endpoints by location, selects random samples from each location, + and performs parallel TCP connections to find the fastest one. The location of the + fastest endpoint is considered the nearest datacenter. + + Algorithm: + 1. Group endpoints by location + 2. If only one location exists, return it immediately + 3. Select up to max_per_location random endpoints from each location + 4. Perform TCP race: connect to all selected endpoints simultaneously + 5. Return the location of the first endpoint that connects successfully + 6. If all connections fail, return None + + :param endpoints: List of resolver.EndpointInfo objects from discovery + :param max_per_location: Maximum number of endpoints to test per location (default: 3, must be >= 1) + :param timeout: TCP connection timeout in seconds (default: 5.0, must be > 0) + :return: Location string of the nearest datacenter, or None if detection failed + :raises ValueError: If endpoints list is empty, max_per_location < 1, or timeout <= 0 + """ + if not endpoints: + raise ValueError("Empty endpoints list for local DC detection") + if max_per_location < 1: + raise ValueError(f"max_per_location must be >= 1, got {max_per_location}") + if timeout <= 0: + raise ValueError(f"timeout must be > 0, got {timeout}") + + endpoints_by_location = _split_endpoints_by_location(endpoints) + + logger.debug( + "Detecting local DC from %d endpoints across %d locations", + len(endpoints), + len(endpoints_by_location), + ) + + if len(endpoints_by_location) == 1: + location = list(endpoints_by_location.keys())[0] + logger.debug("Only one location found: %s", location) + return location + + endpoints_to_test = [] + for location, location_endpoints in endpoints_by_location.items(): + sample = _get_random_endpoints(location_endpoints, max_per_location) + endpoints_to_test.extend(sample) + logger.debug( + "Selected %d/%d endpoints from location '%s' for testing", + len(sample), + len(location_endpoints), + location, + ) + + fastest_endpoint = _check_fastest_endpoint(endpoints_to_test, timeout=timeout) + + if fastest_endpoint is None: + logger.debug("Failed to detect local DC via TCP race: no endpoint connected in time") + return None + + detected_location = fastest_endpoint.location + logger.debug("Detected local DC: %s", detected_location) + + return detected_location diff --git a/ydb/pool.py b/ydb/pool.py index 1d1374e6..a45a931c 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -9,7 +9,7 @@ import random from typing import Any, Callable, ContextManager, List, Optional, Set, Tuple, TYPE_CHECKING -from . import connection as connection_impl, issues, resolver, _utilities, tracing +from . import connection as connection_impl, issues, nearest_dc, resolver, _utilities, tracing from abc import abstractmethod from .connection import Connection, EndpointKey @@ -232,6 +232,47 @@ def execute_discovery(self) -> bool: if cached_endpoint.endpoint not in resolved_endpoints: self._cache.make_outdated(cached_endpoint) + local_dc = resolve_details.self_location + + # Detect local DC using TCP latency if enabled and preferred is meaningful + if self._driver_config.detect_local_dc and not self._driver_config.use_all_nodes: + # Use only endpoints that match the SSL requirements for detection + ssl_filtered_endpoints = [ + endpoint + for endpoint in resolve_details.endpoints + if (self._ssl_required and endpoint.ssl) or (not self._ssl_required and not endpoint.ssl) + ] + + if ssl_filtered_endpoints: + try: + detected_location = nearest_dc.detect_local_dc( + ssl_filtered_endpoints, max_per_location=3, timeout=self._ready_timeout + ) + if detected_location: + local_dc = detected_location + self.logger.info( + "Detected local DC via TCP latency: %s (server reported: %s)", + local_dc, + resolve_details.self_location, + ) + else: + self.logger.warning( + "Failed to detect local DC via TCP latency, using server location: %s", + resolve_details.self_location, + ) + except Exception as e: + self.logger.warning( + "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", + resolve_details.self_location, + e, + exc_info=True, + ) + else: + self.logger.warning( + "No SSL-compatible endpoints for local DC detection, using server location: %s", + resolve_details.self_location, + ) + for resolved_endpoint in resolve_details.endpoints: if self._ssl_required and not resolved_endpoint.ssl: continue @@ -239,7 +280,7 @@ def execute_discovery(self) -> bool: if not self._ssl_required and resolved_endpoint.ssl: continue - preferred = resolve_details.self_location == resolved_endpoint.location + preferred = local_dc == resolved_endpoint.location for ( endpoint,