From 139dde2e07dcde3d9b79ba1158d621b186b48fab Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Mon, 26 Jan 2026 06:36:36 +0300 Subject: [PATCH 01/12] feat: add nearest DC detection with TCP race --- tests/aio/test_nearest_dc.py | 145 ++++++++++++++++++++++++++++++ tests/test_nearest_dc.py | 132 ++++++++++++++++++++++++++++ ydb/aio/nearest_dc.py | 164 ++++++++++++++++++++++++++++++++++ ydb/aio/pool.py | 26 +++++- ydb/driver.py | 4 + ydb/nearest_dc.py | 166 +++++++++++++++++++++++++++++++++++ ydb/pool.py | 26 +++++- 7 files changed, 659 insertions(+), 4 deletions(-) create mode 100644 tests/aio/test_nearest_dc.py create mode 100644 tests/test_nearest_dc.py create mode 100644 ydb/aio/nearest_dc.py create mode 100644 ydb/nearest_dc.py diff --git a/tests/aio/test_nearest_dc.py b/tests/aio/test_nearest_dc.py new file mode 100644 index 00000000..bd252cff --- /dev/null +++ b/tests/aio/test_nearest_dc.py @@ -0,0 +1,145 @@ +import asyncio +import pytest +from ydb.aio import nearest_dc + + +class MockEndpoint: + def __init__(self, address, port, location): + self.address = address + self.port = port + self.endpoint = f"{address}:{port}" + self.location = location + + +class MockWriter: + def __init__(self): + self.closed = False + + def close(self): + self.closed = True + + async def wait_closed(self): + await asyncio.sleep(0) + + +@pytest.mark.asyncio +async def test_check_fastest_endpoint_empty(): + assert await nearest_dc._check_fastest_endpoint([]) is None + + +@pytest.mark.asyncio +async def test_check_fastest_endpoint_all_fail(monkeypatch): + async def fake_open_connection(host, port): + raise OSError("connect failed") + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + + endpoints = [ + MockEndpoint("a", 1, "dc1"), + MockEndpoint("b", 1, "dc2"), + ] + assert await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) is None + + +@pytest.mark.asyncio +async def test_check_fastest_endpoint_fastest_wins(monkeypatch): + async def fake_open_connection(host, port): + if host == "slow": + await asyncio.sleep(0.05) + return None, MockWriter() + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + + endpoints = [ + MockEndpoint("slow", 1, "dc_slow"), + MockEndpoint("fast", 1, "dc_fast"), + ] + winner = await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.2) + assert winner is not None + assert winner.location == "dc_fast" + + +@pytest.mark.asyncio +async def test_check_fastest_endpoint_respects_main_timeout(monkeypatch): + async def fake_open_connection(host, port): + await asyncio.sleep(0.2) + return None, MockWriter() + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + + endpoints = [ + MockEndpoint("hang1", 1, "dc1"), + MockEndpoint("hang2", 1, "dc2"), + ] + + winner = await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) + + assert winner is None + + +@pytest.mark.asyncio +async def test_detect_local_dc_empty_endpoints(): + with pytest.raises(ValueError, match="Empty endpoints"): + await nearest_dc.detect_local_dc([]) + + +@pytest.mark.asyncio +async def test_detect_local_dc_single_location_returns_immediately(monkeypatch): + async def fail_if_called(*args, **kwargs): + raise AssertionError("open_connection should not be called for single location") + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fail_if_called) + + endpoints = [ + MockEndpoint("h1", 1, "dc1"), + MockEndpoint("h2", 1, "dc1"), + ] + assert await nearest_dc.detect_local_dc(endpoints) == "dc1" + + +@pytest.mark.asyncio +async def test_detect_local_dc_fallback_to_first_location_when_all_fail(monkeypatch): + async def fake_open_connection(host, port): + raise OSError("connect failed") + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + + endpoints = [ + MockEndpoint("bad1", 9999, "dc1"), + MockEndpoint("bad2", 9999, "dc2"), + ] + assert await nearest_dc.detect_local_dc(endpoints, timeout=0.05) == "dc1" + + +@pytest.mark.asyncio +async def test_detect_local_dc_returns_location_of_fastest(monkeypatch): + async def fake_open_connection(host, port): + if host == "dc1_host": + await asyncio.sleep(0.05) + return None, MockWriter() + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + + endpoints = [ + MockEndpoint("dc1_host", 1, "dc1"), + MockEndpoint("dc2_host", 1, "dc2"), + ] + assert await nearest_dc.detect_local_dc(endpoints, max_per_location=5, timeout=0.2) == "dc2" + + +@pytest.mark.asyncio +async def test_detect_local_dc_respects_max_per_location(monkeypatch): + calls = [] + + async def fake_open_connection(host, port): + calls.append((host, port)) + raise OSError("connect failed") + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + + endpoints = [MockEndpoint(f"dc1_{i}", 1, "dc1") for i in range(5)] + [ + MockEndpoint(f"dc2_{i}", 1, "dc2") for i in range(5) + ] + await nearest_dc.detect_local_dc(endpoints, max_per_location=2, timeout=0.2) + + assert len(calls) == 4 diff --git a/tests/test_nearest_dc.py b/tests/test_nearest_dc.py new file mode 100644 index 00000000..97c53d68 --- /dev/null +++ b/tests/test_nearest_dc.py @@ -0,0 +1,132 @@ +import time +import pytest +from ydb import nearest_dc + + +class MockEndpoint: + def __init__(self, address, port, location): + self.address = address + self.port = port + self.endpoint = f"{address}:{port}" + self.location = location + + +class DummySock: + def close(self): + pass + + +def test_check_fastest_endpoint_empty(): + assert nearest_dc._check_fastest_endpoint([]) is None + + +def test_check_fastest_endpoint_all_fail(monkeypatch): + def fake_create_connection(addr_port, timeout=None): + raise OSError("connect failed") + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + + endpoints = [ + MockEndpoint("a", 1, "dc1"), + MockEndpoint("b", 1, "dc2"), + ] + assert nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) is None + + +def test_check_fastest_endpoint_fastest_wins(monkeypatch): + def fake_create_connection(addr_port, timeout=None): + host, _ = addr_port + if host == "slow": + time.sleep(0.05) + return DummySock() + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + + endpoints = [ + MockEndpoint("slow", 1, "dc_slow"), + MockEndpoint("fast", 1, "dc_fast"), + ] + winner = nearest_dc._check_fastest_endpoint(endpoints, timeout=0.2) + assert winner is not None + assert winner.location == "dc_fast" + + +def test_check_fastest_endpoint_respects_main_timeout(monkeypatch): + def fake_create_connection(addr_port, timeout=None): + time.sleep(0.2) + return DummySock() + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + + endpoints = [ + MockEndpoint("hang1", 1, "dc1"), + MockEndpoint("hang2", 1, "dc2"), + ] + + winner = nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) + + assert winner is None + + +def test_detect_local_dc_empty_endpoints(): + with pytest.raises(ValueError, match="Empty endpoints"): + nearest_dc.detect_local_dc([]) + + +def test_detect_local_dc_single_location_returns_immediately(monkeypatch): + def fail_if_called(*args, **kwargs): + raise AssertionError("create_connection should not be called for single location") + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fail_if_called) + + endpoints = [ + MockEndpoint("h1", 1, "dc1"), + MockEndpoint("h2", 1, "dc1"), + ] + assert nearest_dc.detect_local_dc(endpoints) == "dc1" + + +def test_detect_local_dc_fallback_to_first_location_when_all_fail(monkeypatch): + def fake_create_connection(addr_port, timeout=None): + raise OSError("connect failed") + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + + endpoints = [ + MockEndpoint("bad1", 9999, "dc1"), + MockEndpoint("bad2", 9999, "dc2"), + ] + assert nearest_dc.detect_local_dc(endpoints, timeout=0.05) == "dc1" + + +def test_detect_local_dc_returns_location_of_fastest(monkeypatch): + def fake_create_connection(addr_port, timeout=None): + host, _ = addr_port + if host == "dc1_host": + time.sleep(0.05) + return DummySock() + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + + endpoints = [ + MockEndpoint("dc1_host", 1, "dc1"), + MockEndpoint("dc2_host", 1, "dc2"), + ] + assert nearest_dc.detect_local_dc(endpoints, max_per_location=5, timeout=0.2) == "dc2" + + +def test_detect_local_dc_respects_max_per_location(monkeypatch): + calls = [] + + def fake_create_connection(addr_port, timeout=None): + calls.append(addr_port) + raise OSError("connect failed") + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + + endpoints = [MockEndpoint(f"dc1_{i}", 1, "dc1") for i in range(5)] + [ + MockEndpoint(f"dc2_{i}", 1, "dc2") for i in range(5) + ] + nearest_dc.detect_local_dc(endpoints, max_per_location=2, timeout=0.2) + + assert len(calls) == 4 diff --git a/ydb/aio/nearest_dc.py b/ydb/aio/nearest_dc.py new file mode 100644 index 00000000..3d2cd001 --- /dev/null +++ b/ydb/aio/nearest_dc.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +import asyncio +import logging +import random +import time +from typing import List, Dict, Optional + +from .. import resolver + + +logger = logging.getLogger(__name__) + + +async def _check_fastest_endpoint( + endpoints: List[resolver.EndpointInfo], timeout: float = 5.0 +) -> Optional[resolver.EndpointInfo]: + """ + Perform async TCP race: connect to all endpoints concurrently and return the fastest one. + + This function starts async TCP connections to all provided endpoints concurrently using + asyncio tasks and returns the first one that successfully connects. Other connection + attempts are cancelled once a winner is found. + + :param endpoints: List of resolver.EndpointInfo objects + :param timeout: Maximum time to wait for any connection (seconds) + :return: Fastest endpoint that connected successfully, or None if all failed or timeout + """ + if not endpoints: + return None + + deadline = time.monotonic() + timeout + + async def try_connect(endpoint): + remaining = deadline - time.monotonic() + if remaining <= 0: + return None + + try: + _, writer = await asyncio.wait_for( + asyncio.open_connection(endpoint.address, endpoint.port), + timeout=remaining, + ) + writer.close() + await writer.wait_closed() + return endpoint + except (OSError, asyncio.TimeoutError) as e: + logger.debug("Failed to connect to %s: %s", endpoint.endpoint, e) + return None + + tasks = [asyncio.create_task(try_connect(endpoint)) for endpoint in endpoints] + try: + for task in asyncio.as_completed(tasks, timeout=timeout): + endpoint = await task + if endpoint is not None: + return endpoint + return None + except asyncio.TimeoutError: + logger.debug("TCP race timeout after %.2fs, no endpoint connected in time", timeout) + return None + finally: + for t in tasks: + if not t.done(): + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + + +def _split_endpoints_by_location( + endpoints: List[resolver.EndpointInfo], +) -> Dict[str, List[resolver.EndpointInfo]]: + """ + Group endpoints by their location. + + :param endpoints: List of resolver.EndpointInfo objects + :return: Dictionary mapping location -> list of resolver.EndpointInfo + """ + result = {} + for endpoint in endpoints: + location = endpoint.location + if location not in result: + result[location] = [] + result[location].append(endpoint) + return result + + +def _get_random_endpoints(endpoints: List[resolver.EndpointInfo], count: int) -> List[resolver.EndpointInfo]: + """ + Get random sample of endpoints. + + :param endpoints: List of resolver.EndpointInfo objects + :param count: Maximum number of endpoints to return + :return: Random sample of resolver.EndpointInfo + """ + if len(endpoints) <= count: + return endpoints + + endpoints_copy = list(endpoints) + random.shuffle(endpoints_copy) + return endpoints_copy[:count] + + +async def detect_local_dc( + endpoints: List[resolver.EndpointInfo], max_per_location: int = 3, timeout: float = 5.0 +) -> str: + """ + Detect nearest datacenter by performing async TCP race between endpoints. + + This function groups endpoints by location, selects random samples from each location, + and performs parallel TCP connections to find the fastest one. The location of the + fastest endpoint is considered the nearest datacenter. + + Algorithm: + 1. Group endpoints by location + 2. If only one location exists, return it immediately + 3. Select up to max_per_location random endpoints from each location + 4. Perform TCP race: connect to all selected endpoints simultaneously + 5. Return the location of the first endpoint that connects successfully + + :param endpoints: List of resolver.EndpointInfo objects from discovery + :param max_per_location: Maximum number of endpoints to test per location (default: 3) + :param timeout: TCP connection timeout in seconds (default: 5.0) + :return: Location string of the nearest datacenter + :raises ValueError: If endpoints list is empty or detection fails + """ + if not endpoints: + raise ValueError("Empty endpoints list for local DC detection") + + endpoints_by_location = _split_endpoints_by_location(endpoints) + + logger.debug( + "Detecting local DC from %d endpoints across %d locations", + len(endpoints), + len(endpoints_by_location), + ) + + if len(endpoints_by_location) == 1: + location = list(endpoints_by_location.keys())[0] + logger.info("Only one location found: %s", location) + return location + + endpoints_to_test = [] + for location, location_endpoints in endpoints_by_location.items(): + sample = _get_random_endpoints(location_endpoints, max_per_location) + endpoints_to_test.extend(sample) + logger.debug( + "Selected %d/%d endpoints from location '%s' for testing", + len(sample), + len(location_endpoints), + location, + ) + + fastest_endpoint = await _check_fastest_endpoint(endpoints_to_test, timeout=timeout) + + if fastest_endpoint is None: + fallback_location = endpoints[0].location + logger.warning( + "Failed to detect local DC via TCP race, falling back to first endpoint location: %s", + fallback_location, + ) + return fallback_location + + detected_location = fastest_endpoint.location + logger.info("Detected local DC: %s", detected_location) + + return detected_location diff --git a/ydb/aio/pool.py b/ydb/aio/pool.py index 7739035e..ecfb23e3 100644 --- a/ydb/aio/pool.py +++ b/ydb/aio/pool.py @@ -10,7 +10,7 @@ from .connection import Connection, EndpointKey -from . import resolver +from . import nearest_dc, resolver if TYPE_CHECKING: from ydb.driver import DriverConfig @@ -145,6 +145,28 @@ async def execute_discovery(self) -> bool: if cached_endpoint.endpoint not in resolved_endpoints: self._cache.make_outdated(cached_endpoint) + local_dc = resolve_details.self_location + + # Detect local DC using TCP latency if enabled + if self._driver_config.detect_local_dc: + try: + detected_location = await nearest_dc.detect_local_dc( + resolve_details.endpoints, max_per_location=3, timeout=self._ready_timeout + ) + if detected_location: + local_dc = detected_location + self.logger.info( + "Detected local DC via TCP latency: %s (server reported: %s)", + local_dc, + resolve_details.self_location, + ) + except Exception as e: + self.logger.warning( + "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", + resolve_details.self_location, + e, + ) + for resolved_endpoint in resolve_details.endpoints: if self._ssl_required and not resolved_endpoint.ssl: continue @@ -152,7 +174,7 @@ async def execute_discovery(self) -> bool: if not self._ssl_required and resolved_endpoint.ssl: continue - preferred = resolve_details.self_location == resolved_endpoint.location + preferred = local_dc == resolved_endpoint.location for ( endpoint, diff --git a/ydb/driver.py b/ydb/driver.py index 72602a8c..983d1b1b 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -105,6 +105,7 @@ class DriverConfig(object): "discovery_request_timeout", "compression", "disable_discovery", + "detect_local_dc", ) def __init__( @@ -130,6 +131,7 @@ def __init__( discovery_request_timeout: int = 10, compression: Optional[grpc.Compression] = None, disable_discovery: bool = False, + detect_local_dc=False, ) -> None: """ A driver config to initialize a driver instance @@ -151,6 +153,7 @@ def __init__( :param grpc_lb_policy_name: A load balancing policy to be used for discovery channel construction. Default value is `round_round` :param discovery_request_timeout: A default timeout to complete the discovery. The default value is 10 seconds. :param disable_discovery: If True, endpoint discovery is disabled and only the start endpoint is used for all requests. + :param detect_local_dc: If True, detect nearest datacenter using TCP latency measurement instead of using server-provided self_location. """ self.endpoint = endpoint @@ -179,6 +182,7 @@ def __init__( self.discovery_request_timeout = discovery_request_timeout self.compression = compression self.disable_discovery = disable_discovery + self.detect_local_dc = detect_local_dc def set_database(self, database: str) -> "DriverConfig": self.database = database diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py new file mode 100644 index 00000000..f2840683 --- /dev/null +++ b/ydb/nearest_dc.py @@ -0,0 +1,166 @@ +# -*- coding: utf-8 -*- +import socket +import threading +import logging +import random +import time +from typing import List, Dict, Optional + +from . import resolver + + +logger = logging.getLogger(__name__) + + +def _check_fastest_endpoint( + endpoints: List[resolver.EndpointInfo], timeout: float = 5.0 +) -> Optional[resolver.EndpointInfo]: + """ + Perform TCP race: connect to all endpoints simultaneously and return the fastest one. + + This function starts TCP connections to all provided endpoints in parallel + and returns the first one that successfully connects. Other connection attempts + will continue until their socket timeout expires (they cannot be interrupted). + + :param endpoints: List of resolver.EndpointInfo objects + :param timeout: Maximum time to wait for any connection (seconds) + :return: Fastest endpoint that connected successfully, or None if all failed + """ + if not endpoints: + return None + + result = {"endpoint": None, "lock": threading.Lock()} + stop_event = threading.Event() + deadline = time.monotonic() + timeout + + def try_connect(endpoint: resolver.EndpointInfo): + """Try to connect to endpoint and report if successful.""" + remaining = deadline - time.monotonic() + if remaining <= 0 or stop_event.is_set(): + return + + try: + sock = socket.create_connection((endpoint.address, endpoint.port), timeout=remaining) + + try: + with result["lock"]: + if result["endpoint"] is None: + result["endpoint"] = endpoint + stop_event.set() + logger.debug("TCP race winner: %s (location: %s)", endpoint.endpoint, endpoint.location) + finally: + sock.close() + + except Exception as e: + logger.warning("Unexpected error connecting to %s: %s", endpoint.endpoint, e) + + threads: List[threading.Thread] = [] + for ep in endpoints: + thread = threading.Thread(target=try_connect, args=(ep,), daemon=True) + thread.start() + threads.append(thread) + + for thread in threads: + remaining = deadline - time.monotonic() + if remaining <= 0 or stop_event.is_set(): + break + + thread.join(timeout=remaining) + + return result["endpoint"] + + +def _split_endpoints_by_location(endpoints: List[resolver.EndpointInfo]) -> Dict[str, List[resolver.EndpointInfo]]: + """ + Group endpoints by their location. + + :param endpoints: List of resolver.EndpointInfo objects + :return: Dictionary mapping location -> list of resolver.EndpointInfo + """ + result = {} + for endpoint in endpoints: + location = endpoint.location + if location not in result: + result[location] = [] + result[location].append(endpoint) + return result + + +def _get_random_endpoints(endpoints: List[resolver.EndpointInfo], count: int) -> List[resolver.EndpointInfo]: + """ + Get random sample of endpoints. + + :param endpoints: List of resolver.EndpointInfo objects + :param count: Maximum number of endpoints to return + :return: Random sample of resolver.EndpointInfo + """ + if len(endpoints) <= count: + return endpoints + + endpoints_copy = list(endpoints) + random.shuffle(endpoints_copy) + return endpoints_copy[:count] + + +def detect_local_dc(endpoints: List[resolver.EndpointInfo], max_per_location: int = 3, timeout: float = 5.0) -> str: + """ + Detect nearest datacenter by performing TCP race between endpoints. + + This function groups endpoints by location, selects random samples from each location, + and performs parallel TCP connections to find the fastest one. The location of the + fastest endpoint is considered the nearest datacenter. + + Algorithm: + 1. Group endpoints by location + 2. If only one location exists, return it immediately + 3. Select up to max_per_location random endpoints from each location + 4. Perform TCP race: connect to all selected endpoints simultaneously + 5. Return the location of the first endpoint that connects successfully + + :param endpoints: List of resolver.EndpointInfo objects from discovery + :param max_per_location: Maximum number of endpoints to test per location (default: 3) + :param timeout: TCP connection timeout in seconds (default: 5.0) + :return: Location string of the nearest datacenter + :raises ValueError: If endpoints list is empty or detection fails + """ + if not endpoints: + raise ValueError("Empty endpoints list for local DC detection") + + endpoints_by_location = _split_endpoints_by_location(endpoints) + + logger.debug( + "Detecting local DC from %d endpoints across %d locations", + len(endpoints), + len(endpoints_by_location), + ) + + if len(endpoints_by_location) == 1: + location = list(endpoints_by_location.keys())[0] + logger.info("Only one location found: %s", location) + return location + + endpoints_to_test = [] + for location, location_endpoints in endpoints_by_location.items(): + sample = _get_random_endpoints(location_endpoints, max_per_location) + endpoints_to_test.extend(sample) + logger.debug( + "Selected %d/%d endpoints from location '%s' for testing", + len(sample), + len(location_endpoints), + location, + ) + + fastest_endpoint = _check_fastest_endpoint(endpoints_to_test, timeout=timeout) + + if fastest_endpoint is None: + fallback_location = endpoints[0].location + logger.warning( + "Failed to detect local DC via TCP race, falling back to first endpoint location: %s", + fallback_location, + ) + return fallback_location + + detected_location = fastest_endpoint.location + logger.info("Detected local DC: %s", detected_location) + + return detected_location diff --git a/ydb/pool.py b/ydb/pool.py index 1d1374e6..1c011f22 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -9,7 +9,7 @@ import random from typing import Any, Callable, ContextManager, List, Optional, Set, Tuple, TYPE_CHECKING -from . import connection as connection_impl, issues, resolver, _utilities, tracing +from . import connection as connection_impl, issues, nearest_dc, resolver, _utilities, tracing from abc import abstractmethod from .connection import Connection, EndpointKey @@ -232,6 +232,28 @@ def execute_discovery(self) -> bool: if cached_endpoint.endpoint not in resolved_endpoints: self._cache.make_outdated(cached_endpoint) + local_dc = resolve_details.self_location + + # Detect local DC using TCP latency if enabled + if self._driver_config.detect_local_dc: + try: + detected_location = nearest_dc.detect_local_dc( + resolve_details.endpoints, max_per_location=3, timeout=self._ready_timeout + ) + if detected_location: + local_dc = detected_location + self.logger.info( + "Detected local DC via TCP latency: %s (server reported: %s)", + local_dc, + resolve_details.self_location, + ) + except Exception as e: + self.logger.warning( + "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", + resolve_details.self_location, + e, + ) + for resolved_endpoint in resolve_details.endpoints: if self._ssl_required and not resolved_endpoint.ssl: continue @@ -239,7 +261,7 @@ def execute_discovery(self) -> bool: if not self._ssl_required and resolved_endpoint.ssl: continue - preferred = resolve_details.self_location == resolved_endpoint.location + preferred = local_dc == resolved_endpoint.location for ( endpoint, From 22257fee10dd9456153600be81f9ecdd80cb397b Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Thu, 5 Feb 2026 07:04:33 +0300 Subject: [PATCH 02/12] fix: fix linting issues --- ydb/aio/nearest_dc.py | 4 ++-- ydb/nearest_dc.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ydb/aio/nearest_dc.py b/ydb/aio/nearest_dc.py index 3d2cd001..e819718c 100644 --- a/ydb/aio/nearest_dc.py +++ b/ydb/aio/nearest_dc.py @@ -3,7 +3,7 @@ import logging import random import time -from typing import List, Dict, Optional +from typing import Dict, List, Optional from .. import resolver @@ -73,7 +73,7 @@ def _split_endpoints_by_location( :param endpoints: List of resolver.EndpointInfo objects :return: Dictionary mapping location -> list of resolver.EndpointInfo """ - result = {} + result: Dict[str, List[resolver.EndpointInfo]] = {} for endpoint in endpoints: location = endpoint.location if location not in result: diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py index f2840683..12b8f555 100644 --- a/ydb/nearest_dc.py +++ b/ydb/nearest_dc.py @@ -4,7 +4,7 @@ import logging import random import time -from typing import List, Dict, Optional +from typing import Any, Dict, List, Optional from . import resolver @@ -29,7 +29,7 @@ def _check_fastest_endpoint( if not endpoints: return None - result = {"endpoint": None, "lock": threading.Lock()} + result: Dict[str, Any] = {"endpoint": None, "lock": threading.Lock()} stop_event = threading.Event() deadline = time.monotonic() + timeout @@ -77,7 +77,7 @@ def _split_endpoints_by_location(endpoints: List[resolver.EndpointInfo]) -> Dict :param endpoints: List of resolver.EndpointInfo objects :return: Dictionary mapping location -> list of resolver.EndpointInfo """ - result = {} + result: Dict[str, List[resolver.EndpointInfo]] = {} for endpoint in endpoints: location = endpoint.location if location not in result: From 5674d2474be994b42ab5ba15aefef81239e887db Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Sat, 7 Feb 2026 12:06:26 +0300 Subject: [PATCH 03/12] fix: fixing flaws --- tests/aio/test_nearest_dc.py | 4 ++-- tests/test_nearest_dc.py | 4 ++-- ydb/aio/nearest_dc.py | 21 ++++++++++----------- ydb/aio/pool.py | 5 +++++ ydb/driver.py | 2 +- ydb/nearest_dc.py | 28 ++++++++++++---------------- ydb/pool.py | 5 +++++ 7 files changed, 37 insertions(+), 32 deletions(-) diff --git a/tests/aio/test_nearest_dc.py b/tests/aio/test_nearest_dc.py index bd252cff..be9b1f08 100644 --- a/tests/aio/test_nearest_dc.py +++ b/tests/aio/test_nearest_dc.py @@ -98,7 +98,7 @@ async def fail_if_called(*args, **kwargs): @pytest.mark.asyncio -async def test_detect_local_dc_fallback_to_first_location_when_all_fail(monkeypatch): +async def test_detect_local_dc_returns_none_when_all_fail(monkeypatch): async def fake_open_connection(host, port): raise OSError("connect failed") @@ -108,7 +108,7 @@ async def fake_open_connection(host, port): MockEndpoint("bad1", 9999, "dc1"), MockEndpoint("bad2", 9999, "dc2"), ] - assert await nearest_dc.detect_local_dc(endpoints, timeout=0.05) == "dc1" + assert await nearest_dc.detect_local_dc(endpoints, timeout=0.05) is None @pytest.mark.asyncio diff --git a/tests/test_nearest_dc.py b/tests/test_nearest_dc.py index 97c53d68..c626fb7f 100644 --- a/tests/test_nearest_dc.py +++ b/tests/test_nearest_dc.py @@ -86,7 +86,7 @@ def fail_if_called(*args, **kwargs): assert nearest_dc.detect_local_dc(endpoints) == "dc1" -def test_detect_local_dc_fallback_to_first_location_when_all_fail(monkeypatch): +def test_detect_local_dc_returns_none_when_all_fail(monkeypatch): def fake_create_connection(addr_port, timeout=None): raise OSError("connect failed") @@ -96,7 +96,7 @@ def fake_create_connection(addr_port, timeout=None): MockEndpoint("bad1", 9999, "dc1"), MockEndpoint("bad2", 9999, "dc2"), ] - assert nearest_dc.detect_local_dc(endpoints, timeout=0.05) == "dc1" + assert nearest_dc.detect_local_dc(endpoints, timeout=0.05) is None def test_detect_local_dc_returns_location_of_fastest(monkeypatch): diff --git a/ydb/aio/nearest_dc.py b/ydb/aio/nearest_dc.py index e819718c..ae50d52d 100644 --- a/ydb/aio/nearest_dc.py +++ b/ydb/aio/nearest_dc.py @@ -43,8 +43,10 @@ async def try_connect(endpoint): writer.close() await writer.wait_closed() return endpoint - except (OSError, asyncio.TimeoutError) as e: - logger.debug("Failed to connect to %s: %s", endpoint.endpoint, e) + except (OSError, asyncio.TimeoutError): + return None + except Exception as e: + logger.debug("Unexpected error connecting to %s: %s", endpoint.endpoint, e) return None tasks = [asyncio.create_task(try_connect(endpoint)) for endpoint in endpoints] @@ -100,7 +102,7 @@ def _get_random_endpoints(endpoints: List[resolver.EndpointInfo], count: int) -> async def detect_local_dc( endpoints: List[resolver.EndpointInfo], max_per_location: int = 3, timeout: float = 5.0 -) -> str: +) -> Optional[str]: """ Detect nearest datacenter by performing async TCP race between endpoints. @@ -114,12 +116,13 @@ async def detect_local_dc( 3. Select up to max_per_location random endpoints from each location 4. Perform TCP race: connect to all selected endpoints simultaneously 5. Return the location of the first endpoint that connects successfully + 6. If all connections fail, return None :param endpoints: List of resolver.EndpointInfo objects from discovery :param max_per_location: Maximum number of endpoints to test per location (default: 3) :param timeout: TCP connection timeout in seconds (default: 5.0) - :return: Location string of the nearest datacenter - :raises ValueError: If endpoints list is empty or detection fails + :return: Location string of the nearest datacenter, or None if detection failed + :raises ValueError: If endpoints list is empty """ if not endpoints: raise ValueError("Empty endpoints list for local DC detection") @@ -151,12 +154,8 @@ async def detect_local_dc( fastest_endpoint = await _check_fastest_endpoint(endpoints_to_test, timeout=timeout) if fastest_endpoint is None: - fallback_location = endpoints[0].location - logger.warning( - "Failed to detect local DC via TCP race, falling back to first endpoint location: %s", - fallback_location, - ) - return fallback_location + logger.warning("Failed to detect local DC via TCP race: no endpoint connected in time") + return None detected_location = fastest_endpoint.location logger.info("Detected local DC: %s", detected_location) diff --git a/ydb/aio/pool.py b/ydb/aio/pool.py index ecfb23e3..f4a6ac23 100644 --- a/ydb/aio/pool.py +++ b/ydb/aio/pool.py @@ -160,6 +160,11 @@ async def execute_discovery(self) -> bool: local_dc, resolve_details.self_location, ) + else: + self.logger.warning( + "Failed to detect local DC via TCP latency, using server location: %s", + resolve_details.self_location, + ) except Exception as e: self.logger.warning( "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", diff --git a/ydb/driver.py b/ydb/driver.py index 983d1b1b..d60c282c 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -131,7 +131,7 @@ def __init__( discovery_request_timeout: int = 10, compression: Optional[grpc.Compression] = None, disable_discovery: bool = False, - detect_local_dc=False, + detect_local_dc: bool = False, ) -> None: """ A driver config to initialize a driver instance diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py index 12b8f555..a107e1ad 100644 --- a/ydb/nearest_dc.py +++ b/ydb/nearest_dc.py @@ -51,8 +51,10 @@ def try_connect(endpoint: resolver.EndpointInfo): finally: sock.close() + except (OSError, socket.timeout): + pass except Exception as e: - logger.warning("Unexpected error connecting to %s: %s", endpoint.endpoint, e) + logger.debug("Unexpected error connecting to %s: %s", endpoint.endpoint, e) threads: List[threading.Thread] = [] for ep in endpoints: @@ -60,12 +62,7 @@ def try_connect(endpoint: resolver.EndpointInfo): thread.start() threads.append(thread) - for thread in threads: - remaining = deadline - time.monotonic() - if remaining <= 0 or stop_event.is_set(): - break - - thread.join(timeout=remaining) + stop_event.wait(timeout=max(0.0, deadline - time.monotonic())) return result["endpoint"] @@ -102,7 +99,9 @@ def _get_random_endpoints(endpoints: List[resolver.EndpointInfo], count: int) -> return endpoints_copy[:count] -def detect_local_dc(endpoints: List[resolver.EndpointInfo], max_per_location: int = 3, timeout: float = 5.0) -> str: +def detect_local_dc( + endpoints: List[resolver.EndpointInfo], max_per_location: int = 3, timeout: float = 5.0 +) -> Optional[str]: """ Detect nearest datacenter by performing TCP race between endpoints. @@ -116,12 +115,13 @@ def detect_local_dc(endpoints: List[resolver.EndpointInfo], max_per_location: in 3. Select up to max_per_location random endpoints from each location 4. Perform TCP race: connect to all selected endpoints simultaneously 5. Return the location of the first endpoint that connects successfully + 6. If all connections fail, return None :param endpoints: List of resolver.EndpointInfo objects from discovery :param max_per_location: Maximum number of endpoints to test per location (default: 3) :param timeout: TCP connection timeout in seconds (default: 5.0) - :return: Location string of the nearest datacenter - :raises ValueError: If endpoints list is empty or detection fails + :return: Location string of the nearest datacenter, or None if detection failed + :raises ValueError: If endpoints list is empty """ if not endpoints: raise ValueError("Empty endpoints list for local DC detection") @@ -153,12 +153,8 @@ def detect_local_dc(endpoints: List[resolver.EndpointInfo], max_per_location: in fastest_endpoint = _check_fastest_endpoint(endpoints_to_test, timeout=timeout) if fastest_endpoint is None: - fallback_location = endpoints[0].location - logger.warning( - "Failed to detect local DC via TCP race, falling back to first endpoint location: %s", - fallback_location, - ) - return fallback_location + logger.warning("Failed to detect local DC via TCP race: no endpoint connected in time") + return None detected_location = fastest_endpoint.location logger.info("Detected local DC: %s", detected_location) diff --git a/ydb/pool.py b/ydb/pool.py index 1c011f22..70599931 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -247,6 +247,11 @@ def execute_discovery(self) -> bool: local_dc, resolve_details.self_location, ) + else: + self.logger.warning( + "Failed to detect local DC via TCP latency, using server location: %s", + resolve_details.self_location, + ) except Exception as e: self.logger.warning( "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", From 9b911fd96446ed924c7af29c1fa39c8f2b255a64 Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Tue, 10 Feb 2026 05:25:14 +0300 Subject: [PATCH 04/12] fix: fixing flaws --- ydb/aio/nearest_dc.py | 6 +++--- ydb/driver.py | 4 +++- ydb/nearest_dc.py | 7 ++++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/ydb/aio/nearest_dc.py b/ydb/aio/nearest_dc.py index ae50d52d..e9699f64 100644 --- a/ydb/aio/nearest_dc.py +++ b/ydb/aio/nearest_dc.py @@ -137,7 +137,7 @@ async def detect_local_dc( if len(endpoints_by_location) == 1: location = list(endpoints_by_location.keys())[0] - logger.info("Only one location found: %s", location) + logger.debug("Only one location found: %s", location) return location endpoints_to_test = [] @@ -154,10 +154,10 @@ async def detect_local_dc( fastest_endpoint = await _check_fastest_endpoint(endpoints_to_test, timeout=timeout) if fastest_endpoint is None: - logger.warning("Failed to detect local DC via TCP race: no endpoint connected in time") + logger.debug("Failed to detect local DC via TCP race: no endpoint connected in time") return None detected_location = fastest_endpoint.location - logger.info("Detected local DC: %s", detected_location) + logger.debug("Detected local DC: %s", detected_location) return detected_location diff --git a/ydb/driver.py b/ydb/driver.py index d60c282c..fc781c25 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -153,7 +153,9 @@ def __init__( :param grpc_lb_policy_name: A load balancing policy to be used for discovery channel construction. Default value is `round_round` :param discovery_request_timeout: A default timeout to complete the discovery. The default value is 10 seconds. :param disable_discovery: If True, endpoint discovery is disabled and only the start endpoint is used for all requests. - :param detect_local_dc: If True, detect nearest datacenter using TCP latency measurement instead of using server-provided self_location. + :param detect_local_dc: If True, detect nearest datacenter using TCP latency measurement instead of using\ + server-provided self_location. **Note**: This option only affects endpoint selection when use_all_nodes=False.\ + When use_all_nodes=True (default), all endpoints are used regardless of detected location. """ self.endpoint = endpoint diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py index a107e1ad..5e0f4ec5 100644 --- a/ydb/nearest_dc.py +++ b/ydb/nearest_dc.py @@ -52,6 +52,7 @@ def try_connect(endpoint: resolver.EndpointInfo): sock.close() except (OSError, socket.timeout): + # Ignore expected connection errors; endpoints that fail simply lose the TCP race. pass except Exception as e: logger.debug("Unexpected error connecting to %s: %s", endpoint.endpoint, e) @@ -136,7 +137,7 @@ def detect_local_dc( if len(endpoints_by_location) == 1: location = list(endpoints_by_location.keys())[0] - logger.info("Only one location found: %s", location) + logger.debug("Only one location found: %s", location) return location endpoints_to_test = [] @@ -153,10 +154,10 @@ def detect_local_dc( fastest_endpoint = _check_fastest_endpoint(endpoints_to_test, timeout=timeout) if fastest_endpoint is None: - logger.warning("Failed to detect local DC via TCP race: no endpoint connected in time") + logger.debug("Failed to detect local DC via TCP race: no endpoint connected in time") return None detected_location = fastest_endpoint.location - logger.info("Detected local DC: %s", detected_location) + logger.debug("Detected local DC: %s", detected_location) return detected_location From 4e1375faa8f63a66b2032af4ce47d47851dfa576 Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Wed, 11 Feb 2026 06:39:15 +0300 Subject: [PATCH 05/12] fix: fixing flaws --- ydb/aio/nearest_dc.py | 5 +-- ydb/aio/pool.py | 1 + ydb/nearest_dc.py | 71 ++++++++++++++++++++++++++----------------- ydb/pool.py | 1 + 4 files changed, 46 insertions(+), 32 deletions(-) diff --git a/ydb/aio/nearest_dc.py b/ydb/aio/nearest_dc.py index e9699f64..27a33b33 100644 --- a/ydb/aio/nearest_dc.py +++ b/ydb/aio/nearest_dc.py @@ -94,10 +94,7 @@ def _get_random_endpoints(endpoints: List[resolver.EndpointInfo], count: int) -> """ if len(endpoints) <= count: return endpoints - - endpoints_copy = list(endpoints) - random.shuffle(endpoints_copy) - return endpoints_copy[:count] + return random.sample(endpoints, count) async def detect_local_dc( diff --git a/ydb/aio/pool.py b/ydb/aio/pool.py index f4a6ac23..711bf2c4 100644 --- a/ydb/aio/pool.py +++ b/ydb/aio/pool.py @@ -170,6 +170,7 @@ async def execute_discovery(self) -> bool: "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", resolve_details.self_location, e, + exc_info=True, ) for resolved_endpoint in resolve_details.endpoints: diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py index 5e0f4ec5..d4f7f7d7 100644 --- a/ydb/nearest_dc.py +++ b/ydb/nearest_dc.py @@ -1,26 +1,37 @@ # -*- coding: utf-8 -*- +import atexit +import concurrent.futures import socket import threading import logging import random import time -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional from . import resolver logger = logging.getLogger(__name__) +# Module-level thread pool for TCP race (reused across discovery cycles) +_TCP_RACE_MAX_WORKERS = 15 +_TCP_RACE_EXECUTOR = concurrent.futures.ThreadPoolExecutor( + max_workers=_TCP_RACE_MAX_WORKERS, + thread_name_prefix="ydb-tcp-race", +) + +# Ensure executor is shut down on process exit +atexit.register(lambda: _TCP_RACE_EXECUTOR.shutdown(wait=False, cancel_futures=True)) + def _check_fastest_endpoint( endpoints: List[resolver.EndpointInfo], timeout: float = 5.0 ) -> Optional[resolver.EndpointInfo]: """ - Perform TCP race: connect to all endpoints simultaneously and return the fastest one. + Perform TCP race using a bounded thread pool and return the fastest endpoint. - This function starts TCP connections to all provided endpoints in parallel - and returns the first one that successfully connects. Other connection attempts - will continue until their socket timeout expires (they cannot be interrupted). + Uses a module-level ThreadPoolExecutor to avoid creating new threads on every + discovery cycle. Returns immediately when the first endpoint connects successfully. :param endpoints: List of resolver.EndpointInfo objects :param timeout: Maximum time to wait for any connection (seconds) @@ -29,43 +40,50 @@ def _check_fastest_endpoint( if not endpoints: return None - result: Dict[str, Any] = {"endpoint": None, "lock": threading.Lock()} + endpoints = _get_random_endpoints(endpoints, _TCP_RACE_MAX_WORKERS) + stop_event = threading.Event() + winner_lock = threading.Lock() deadline = time.monotonic() + timeout - def try_connect(endpoint: resolver.EndpointInfo): - """Try to connect to endpoint and report if successful.""" + def try_connect(endpoint: resolver.EndpointInfo) -> Optional[resolver.EndpointInfo]: + """Try to connect to endpoint and return it if successful.""" remaining = deadline - time.monotonic() if remaining <= 0 or stop_event.is_set(): - return + return None try: sock = socket.create_connection((endpoint.address, endpoint.port), timeout=remaining) - try: - with result["lock"]: - if result["endpoint"] is None: - result["endpoint"] = endpoint - stop_event.set() - logger.debug("TCP race winner: %s (location: %s)", endpoint.endpoint, endpoint.location) + with winner_lock: + if stop_event.is_set(): + return None + stop_event.set() + return endpoint finally: sock.close() - except (OSError, socket.timeout): # Ignore expected connection errors; endpoints that fail simply lose the TCP race. - pass + return None except Exception as e: logger.debug("Unexpected error connecting to %s: %s", endpoint.endpoint, e) + return None - threads: List[threading.Thread] = [] - for ep in endpoints: - thread = threading.Thread(target=try_connect, args=(ep,), daemon=True) - thread.start() - threads.append(thread) + futures: List[concurrent.futures.Future] = [_TCP_RACE_EXECUTOR.submit(try_connect, ep) for ep in endpoints] - stop_event.wait(timeout=max(0.0, deadline - time.monotonic())) + try: + for fut in concurrent.futures.as_completed(futures, timeout=timeout): + result = fut.result() + if result is not None: + return result + except concurrent.futures.TimeoutError: + # Overall timeout expired + pass + finally: + for f in futures: + f.cancel() - return result["endpoint"] + return None def _split_endpoints_by_location(endpoints: List[resolver.EndpointInfo]) -> Dict[str, List[resolver.EndpointInfo]]: @@ -94,10 +112,7 @@ def _get_random_endpoints(endpoints: List[resolver.EndpointInfo], count: int) -> """ if len(endpoints) <= count: return endpoints - - endpoints_copy = list(endpoints) - random.shuffle(endpoints_copy) - return endpoints_copy[:count] + return random.sample(endpoints, count) def detect_local_dc( diff --git a/ydb/pool.py b/ydb/pool.py index 70599931..a0c9e673 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -257,6 +257,7 @@ def execute_discovery(self) -> bool: "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", resolve_details.self_location, e, + exc_info=True, ) for resolved_endpoint in resolve_details.endpoints: From b2394062831171b7825d47495ba86ea886b9ef61 Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Wed, 11 Feb 2026 06:55:39 +0300 Subject: [PATCH 06/12] fix: add tests --- tests/aio/test_discovery_detect_local_dc.py | 106 ++++++++++++++++++++ tests/test_discovery_detect_local_dc.py | 90 +++++++++++++++++ 2 files changed, 196 insertions(+) create mode 100644 tests/aio/test_discovery_detect_local_dc.py create mode 100644 tests/test_discovery_detect_local_dc.py diff --git a/tests/aio/test_discovery_detect_local_dc.py b/tests/aio/test_discovery_detect_local_dc.py new file mode 100644 index 00000000..1e7faf5a --- /dev/null +++ b/tests/aio/test_discovery_detect_local_dc.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +import pytest +from unittest.mock import MagicMock, patch, AsyncMock +from ydb import driver, connection +from ydb.aio import pool, nearest_dc + + +class MockEndpointInfo: + def __init__(self, address, port, location): + self.address = address + self.port = port + self.endpoint = f"{address}:{port}" + self.location = location + self.ssl = False + self.node_id = 1 + + def endpoints_with_options(self): + yield (self.endpoint, connection.EndpointOptions(ssl_target_name_override=None, node_id=self.node_id)) + + +class MockDiscoveryResult: + def __init__(self, self_location, endpoints): + self.self_location = self_location + self.endpoints = endpoints + + +@pytest.mark.asyncio +async def test_detect_local_dc_overrides_server_location(): + """Test that detected location overrides server's self_location for preferred endpoints.""" + # Server reports dc1, but we detect dc2 as nearest + endpoints = [ + MockEndpointInfo("dc1-host", 2135, "dc1"), + MockEndpointInfo("dc2-host", 2135, "dc2"), + ] + mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints) + + mock_resolver = MagicMock() + mock_resolver.resolve = AsyncMock(return_value=mock_result) + + preferred = [] + + def mock_init(self, endpoint, driver_config, endpoint_options=None): + self.endpoint = endpoint + self.node_id = 1 + + with patch.object(nearest_dc, "detect_local_dc", AsyncMock(return_value="dc2")): + with patch("ydb.aio.connection.Connection.__init__", mock_init): + with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()): + with patch("ydb.aio.connection.Connection.close", AsyncMock()): + with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None): + config = driver.DriverConfig( + endpoint="grpc://test:2135", database="/local", detect_local_dc=True + ) + discovery = pool.Discovery(store=pool.ConnectionsCache(), driver_config=config) + discovery._resolver = mock_resolver + + original_add = discovery._cache.add + discovery._cache.add = lambda conn, pref=False: ( + preferred.append(conn.endpoint) if pref else None, + original_add(conn, pref), + )[1] + + await discovery.execute_discovery() + + assert any("dc2" in ep for ep in preferred), "dc2 should be preferred (detected)" + assert not any("dc1" in ep for ep in preferred), "dc1 should not be preferred" + + +@pytest.mark.asyncio +async def test_detect_local_dc_failure_fallback(): + """Test that detection failure falls back to server's self_location.""" + endpoints = [ + MockEndpointInfo("dc1-host", 2135, "dc1"), + MockEndpointInfo("dc2-host", 2135, "dc2"), + ] + mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints) + + mock_resolver = MagicMock() + mock_resolver.resolve = AsyncMock(return_value=mock_result) + + preferred = [] + + def mock_init(self, endpoint, driver_config, endpoint_options=None): + self.endpoint = endpoint + self.node_id = 1 + + with patch.object(nearest_dc, "detect_local_dc", AsyncMock(return_value=None)): + with patch("ydb.aio.connection.Connection.__init__", mock_init): + with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()): + with patch("ydb.aio.connection.Connection.close", AsyncMock()): + with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None): + config = driver.DriverConfig( + endpoint="grpc://test:2135", database="/local", detect_local_dc=True + ) + discovery = pool.Discovery(store=pool.ConnectionsCache(), driver_config=config) + discovery._resolver = mock_resolver + + original_add = discovery._cache.add + discovery._cache.add = lambda conn, pref=False: ( + preferred.append(conn.endpoint) if pref else None, + original_add(conn, pref), + )[1] + + await discovery.execute_discovery() + + assert any("dc1" in ep for ep in preferred), "dc1 should be preferred (server fallback)" diff --git a/tests/test_discovery_detect_local_dc.py b/tests/test_discovery_detect_local_dc.py new file mode 100644 index 00000000..291db3ac --- /dev/null +++ b/tests/test_discovery_detect_local_dc.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +from unittest.mock import Mock, MagicMock, patch +from ydb import driver, pool, nearest_dc, connection + + +class MockEndpointInfo: + def __init__(self, address, port, location): + self.address = address + self.port = port + self.endpoint = f"{address}:{port}" + self.location = location + self.ssl = False + self.node_id = 1 + + def endpoints_with_options(self): + yield (self.endpoint, connection.EndpointOptions(ssl_target_name_override=None, node_id=self.node_id)) + + +class MockDiscoveryResult: + def __init__(self, self_location, endpoints): + self.self_location = self_location + self.endpoints = endpoints + + +def test_detect_local_dc_overrides_server_location(): + """Test that detected location overrides server's self_location for preferred endpoints.""" + # Server reports dc1, but we detect dc2 as nearest + endpoints = [ + MockEndpointInfo("dc1-host", 2135, "dc1"), + MockEndpointInfo("dc2-host", 2135, "dc2"), + ] + mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints) + + mock_resolver = MagicMock() + mock_resolver.context_resolve.return_value.__enter__.return_value = mock_result + mock_resolver.context_resolve.return_value.__exit__.return_value = None + + preferred = [] + + with patch.object(nearest_dc, "detect_local_dc", Mock(return_value="dc2")): + with patch( + "ydb.connection.Connection.ready_factory", lambda *args, **kw: MagicMock(endpoint=args[0], node_id=1) + ): + config = driver.DriverConfig(endpoint="grpc://test:2135", database="/local", detect_local_dc=True) + discovery = pool.Discovery(store=pool.ConnectionsCache(), driver_config=config) + discovery._resolver = mock_resolver + + original_add = discovery._cache.add + discovery._cache.add = lambda conn, pref=False: ( + preferred.append(conn.endpoint) if pref else None, + original_add(conn, pref), + )[1] + + discovery.execute_discovery() + + assert any("dc2" in ep for ep in preferred), "dc2 should be preferred (detected)" + assert not any("dc1" in ep for ep in preferred), "dc1 should not be preferred" + + +def test_detect_local_dc_failure_fallback(): + """Test that detection failure falls back to server's self_location.""" + endpoints = [ + MockEndpointInfo("dc1-host", 2135, "dc1"), + MockEndpointInfo("dc2-host", 2135, "dc2"), + ] + mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints) + + mock_resolver = MagicMock() + mock_resolver.context_resolve.return_value.__enter__.return_value = mock_result + mock_resolver.context_resolve.return_value.__exit__.return_value = None + + preferred = [] + + with patch.object(nearest_dc, "detect_local_dc", Mock(return_value=None)): + with patch( + "ydb.connection.Connection.ready_factory", lambda *args, **kw: MagicMock(endpoint=args[0], node_id=1) + ): + config = driver.DriverConfig(endpoint="grpc://test:2135", database="/local", detect_local_dc=True) + discovery = pool.Discovery(store=pool.ConnectionsCache(), driver_config=config) + discovery._resolver = mock_resolver + + original_add = discovery._cache.add + discovery._cache.add = lambda conn, pref=False: ( + preferred.append(conn.endpoint) if pref else None, + original_add(conn, pref), + )[1] + + discovery.execute_discovery() + + assert any("dc1" in ep for ep in preferred), "dc1 should be preferred (server fallback)" From 92553117f0ca81cf7c7af4b8500835d15c252203 Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Wed, 11 Feb 2026 12:05:41 +0300 Subject: [PATCH 07/12] fix: fixing flaws --- tests/aio/test_discovery_detect_local_dc.py | 12 ++++++++---- tests/test_discovery_detect_local_dc.py | 12 ++++++++---- ydb/nearest_dc.py | 21 +++++++++++++++++---- 3 files changed, 33 insertions(+), 12 deletions(-) diff --git a/tests/aio/test_discovery_detect_local_dc.py b/tests/aio/test_discovery_detect_local_dc.py index 1e7faf5a..93e8ddb1 100644 --- a/tests/aio/test_discovery_detect_local_dc.py +++ b/tests/aio/test_discovery_detect_local_dc.py @@ -49,9 +49,11 @@ def mock_init(self, endpoint, driver_config, endpoint_options=None): with patch("ydb.aio.connection.Connection.close", AsyncMock()): with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None): config = driver.DriverConfig( - endpoint="grpc://test:2135", database="/local", detect_local_dc=True + endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False + ) + discovery = pool.Discovery( + store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config ) - discovery = pool.Discovery(store=pool.ConnectionsCache(), driver_config=config) discovery._resolver = mock_resolver original_add = discovery._cache.add @@ -90,9 +92,11 @@ def mock_init(self, endpoint, driver_config, endpoint_options=None): with patch("ydb.aio.connection.Connection.close", AsyncMock()): with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None): config = driver.DriverConfig( - endpoint="grpc://test:2135", database="/local", detect_local_dc=True + endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False + ) + discovery = pool.Discovery( + store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config ) - discovery = pool.Discovery(store=pool.ConnectionsCache(), driver_config=config) discovery._resolver = mock_resolver original_add = discovery._cache.add diff --git a/tests/test_discovery_detect_local_dc.py b/tests/test_discovery_detect_local_dc.py index 291db3ac..9be8efa1 100644 --- a/tests/test_discovery_detect_local_dc.py +++ b/tests/test_discovery_detect_local_dc.py @@ -41,8 +41,10 @@ def test_detect_local_dc_overrides_server_location(): with patch( "ydb.connection.Connection.ready_factory", lambda *args, **kw: MagicMock(endpoint=args[0], node_id=1) ): - config = driver.DriverConfig(endpoint="grpc://test:2135", database="/local", detect_local_dc=True) - discovery = pool.Discovery(store=pool.ConnectionsCache(), driver_config=config) + config = driver.DriverConfig( + endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False + ) + discovery = pool.Discovery(store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config) discovery._resolver = mock_resolver original_add = discovery._cache.add @@ -75,8 +77,10 @@ def test_detect_local_dc_failure_fallback(): with patch( "ydb.connection.Connection.ready_factory", lambda *args, **kw: MagicMock(endpoint=args[0], node_id=1) ): - config = driver.DriverConfig(endpoint="grpc://test:2135", database="/local", detect_local_dc=True) - discovery = pool.Discovery(store=pool.ConnectionsCache(), driver_config=config) + config = driver.DriverConfig( + endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False + ) + discovery = pool.Discovery(store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config) discovery._resolver = mock_resolver original_add = discovery._cache.add diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py index d4f7f7d7..611a15dd 100644 --- a/ydb/nearest_dc.py +++ b/ydb/nearest_dc.py @@ -2,6 +2,7 @@ import atexit import concurrent.futures import socket +import sys import threading import logging import random @@ -14,14 +15,21 @@ logger = logging.getLogger(__name__) # Module-level thread pool for TCP race (reused across discovery cycles) -_TCP_RACE_MAX_WORKERS = 15 +_TCP_RACE_MAX_WORKERS = 30 _TCP_RACE_EXECUTOR = concurrent.futures.ThreadPoolExecutor( max_workers=_TCP_RACE_MAX_WORKERS, thread_name_prefix="ydb-tcp-race", ) -# Ensure executor is shut down on process exit -atexit.register(lambda: _TCP_RACE_EXECUTOR.shutdown(wait=False, cancel_futures=True)) + +def _shutdown_executor(): + if sys.version_info >= (3, 9): + _TCP_RACE_EXECUTOR.shutdown(wait=False, cancel_futures=True) + else: + _TCP_RACE_EXECUTOR.shutdown(wait=False) + + +atexit.register(_shutdown_executor) def _check_fastest_endpoint( @@ -33,6 +41,9 @@ def _check_fastest_endpoint( Uses a module-level ThreadPoolExecutor to avoid creating new threads on every discovery cycle. Returns immediately when the first endpoint connects successfully. + If there are more endpoints than the thread pool size, takes one random endpoint + per location to ensure fair representation of all locations in the race. + :param endpoints: List of resolver.EndpointInfo objects :param timeout: Maximum time to wait for any connection (seconds) :return: Fastest endpoint that connected successfully, or None if all failed @@ -40,7 +51,9 @@ def _check_fastest_endpoint( if not endpoints: return None - endpoints = _get_random_endpoints(endpoints, _TCP_RACE_MAX_WORKERS) + if len(endpoints) > _TCP_RACE_MAX_WORKERS: + endpoints_by_location = _split_endpoints_by_location(endpoints) + endpoints = [random.choice(location_eps) for location_eps in endpoints_by_location.values()] stop_event = threading.Event() winner_lock = threading.Lock() From 9dec804ac85d4ddb9c6b66611ad54629490e74d9 Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Wed, 11 Feb 2026 12:39:55 +0300 Subject: [PATCH 08/12] fix: fixing flaws --- tests/aio/test_nearest_dc.py | 14 ++++++++++++++ tests/test_nearest_dc.py | 12 ++++++++++++ ydb/aio/nearest_dc.py | 27 +++++++++++++++++++++------ ydb/nearest_dc.py | 16 ++++++++++++---- 4 files changed, 59 insertions(+), 10 deletions(-) diff --git a/tests/aio/test_nearest_dc.py b/tests/aio/test_nearest_dc.py index be9b1f08..4bf0b458 100644 --- a/tests/aio/test_nearest_dc.py +++ b/tests/aio/test_nearest_dc.py @@ -143,3 +143,17 @@ async def fake_open_connection(host, port): await nearest_dc.detect_local_dc(endpoints, max_per_location=2, timeout=0.2) assert len(calls) == 4 + + +@pytest.mark.asyncio +async def test_detect_local_dc_validates_max_per_location(): + endpoints = [MockEndpoint("h1", 1, "dc1")] + with pytest.raises(ValueError, match="max_per_location must be >= 1"): + await nearest_dc.detect_local_dc(endpoints, max_per_location=0) + + +@pytest.mark.asyncio +async def test_detect_local_dc_validates_timeout(): + endpoints = [MockEndpoint("h1", 1, "dc1")] + with pytest.raises(ValueError, match="timeout must be > 0"): + await nearest_dc.detect_local_dc(endpoints, timeout=0) diff --git a/tests/test_nearest_dc.py b/tests/test_nearest_dc.py index c626fb7f..6f7f0c43 100644 --- a/tests/test_nearest_dc.py +++ b/tests/test_nearest_dc.py @@ -130,3 +130,15 @@ def fake_create_connection(addr_port, timeout=None): nearest_dc.detect_local_dc(endpoints, max_per_location=2, timeout=0.2) assert len(calls) == 4 + + +def test_detect_local_dc_validates_max_per_location(): + endpoints = [MockEndpoint("h1", 1, "dc1")] + with pytest.raises(ValueError, match="max_per_location must be >= 1"): + nearest_dc.detect_local_dc(endpoints, max_per_location=0) + + +def test_detect_local_dc_validates_timeout(): + endpoints = [MockEndpoint("h1", 1, "dc1")] + with pytest.raises(ValueError, match="timeout must be > 0"): + nearest_dc.detect_local_dc(endpoints, timeout=0) diff --git a/ydb/aio/nearest_dc.py b/ydb/aio/nearest_dc.py index 27a33b33..0802bc5c 100644 --- a/ydb/aio/nearest_dc.py +++ b/ydb/aio/nearest_dc.py @@ -111,18 +111,23 @@ async def detect_local_dc( 1. Group endpoints by location 2. If only one location exists, return it immediately 3. Select up to max_per_location random endpoints from each location - 4. Perform TCP race: connect to all selected endpoints simultaneously - 5. Return the location of the first endpoint that connects successfully - 6. If all connections fail, return None + 4. If too many endpoints, reduce to one per location and cap at limit + 5. Perform TCP race: connect to all selected endpoints simultaneously + 6. Return the location of the first endpoint that connects successfully + 7. If all connections fail, return None :param endpoints: List of resolver.EndpointInfo objects from discovery - :param max_per_location: Maximum number of endpoints to test per location (default: 3) - :param timeout: TCP connection timeout in seconds (default: 5.0) + :param max_per_location: Maximum number of endpoints to test per location (default: 3, must be >= 1) + :param timeout: TCP connection timeout in seconds (default: 5.0, must be > 0) :return: Location string of the nearest datacenter, or None if detection failed - :raises ValueError: If endpoints list is empty + :raises ValueError: If endpoints list is empty, max_per_location < 1, or timeout <= 0 """ if not endpoints: raise ValueError("Empty endpoints list for local DC detection") + if max_per_location < 1: + raise ValueError(f"max_per_location must be >= 1, got {max_per_location}") + if timeout <= 0: + raise ValueError(f"timeout must be > 0, got {timeout}") endpoints_by_location = _split_endpoints_by_location(endpoints) @@ -137,6 +142,8 @@ async def detect_local_dc( logger.debug("Only one location found: %s", location) return location + _MAX_CONCURRENT_TASKS = 99 + endpoints_to_test = [] for location, location_endpoints in endpoints_by_location.items(): sample = _get_random_endpoints(location_endpoints, max_per_location) @@ -148,6 +155,14 @@ async def detect_local_dc( location, ) + if len(endpoints_to_test) > _MAX_CONCURRENT_TASKS: + endpoints_to_test = [random.choice(location_eps) for location_eps in endpoints_by_location.values()] + + if len(endpoints_to_test) > _MAX_CONCURRENT_TASKS: + endpoints_to_test = random.sample(endpoints_to_test, _MAX_CONCURRENT_TASKS) + + logger.debug("Capped endpoints to %d to limit concurrent tasks", len(endpoints_to_test)) + fastest_endpoint = await _check_fastest_endpoint(endpoints_to_test, timeout=timeout) if fastest_endpoint is None: diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py index 611a15dd..306149d4 100644 --- a/ydb/nearest_dc.py +++ b/ydb/nearest_dc.py @@ -42,7 +42,8 @@ def _check_fastest_endpoint( discovery cycle. Returns immediately when the first endpoint connects successfully. If there are more endpoints than the thread pool size, takes one random endpoint - per location to ensure fair representation of all locations in the race. + per location to ensure fair representation of all locations in the race. If there + are still too many locations, randomly samples them to stay within the limit. :param endpoints: List of resolver.EndpointInfo objects :param timeout: Maximum time to wait for any connection (seconds) @@ -55,6 +56,9 @@ def _check_fastest_endpoint( endpoints_by_location = _split_endpoints_by_location(endpoints) endpoints = [random.choice(location_eps) for location_eps in endpoints_by_location.values()] + if len(endpoints) > _TCP_RACE_MAX_WORKERS: + endpoints = random.sample(endpoints, _TCP_RACE_MAX_WORKERS) + stop_event = threading.Event() winner_lock = threading.Lock() deadline = time.monotonic() + timeout @@ -147,13 +151,17 @@ def detect_local_dc( 6. If all connections fail, return None :param endpoints: List of resolver.EndpointInfo objects from discovery - :param max_per_location: Maximum number of endpoints to test per location (default: 3) - :param timeout: TCP connection timeout in seconds (default: 5.0) + :param max_per_location: Maximum number of endpoints to test per location (default: 3, must be >= 1) + :param timeout: TCP connection timeout in seconds (default: 5.0, must be > 0) :return: Location string of the nearest datacenter, or None if detection failed - :raises ValueError: If endpoints list is empty + :raises ValueError: If endpoints list is empty, max_per_location < 1, or timeout <= 0 """ if not endpoints: raise ValueError("Empty endpoints list for local DC detection") + if max_per_location < 1: + raise ValueError(f"max_per_location must be >= 1, got {max_per_location}") + if timeout <= 0: + raise ValueError(f"timeout must be > 0, got {timeout}") endpoints_by_location = _split_endpoints_by_location(endpoints) From 15fb6cf0645d5cc6f48bae7b8cd635e025098424 Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Wed, 11 Feb 2026 13:16:47 +0300 Subject: [PATCH 09/12] fix: fixing flaws --- ydb/nearest_dc.py | 45 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py index 306149d4..ce3cea24 100644 --- a/ydb/nearest_dc.py +++ b/ydb/nearest_dc.py @@ -16,20 +16,42 @@ # Module-level thread pool for TCP race (reused across discovery cycles) _TCP_RACE_MAX_WORKERS = 30 -_TCP_RACE_EXECUTOR = concurrent.futures.ThreadPoolExecutor( - max_workers=_TCP_RACE_MAX_WORKERS, - thread_name_prefix="ydb-tcp-race", -) +_TCP_RACE_EXECUTOR: Optional[concurrent.futures.ThreadPoolExecutor] = None +_EXECUTOR_LOCK = threading.Lock() +_ATEXIT_REGISTERED = False -def _shutdown_executor(): - if sys.version_info >= (3, 9): - _TCP_RACE_EXECUTOR.shutdown(wait=False, cancel_futures=True) - else: - _TCP_RACE_EXECUTOR.shutdown(wait=False) +def _get_executor() -> concurrent.futures.ThreadPoolExecutor: + """ + Lazily create and return the thread pool executor. + + The executor is created on first use to avoid import-time side effects. + The atexit hook is registered only when the executor is actually created. + """ + global _TCP_RACE_EXECUTOR, _ATEXIT_REGISTERED + if _TCP_RACE_EXECUTOR is None: + with _EXECUTOR_LOCK: + if _TCP_RACE_EXECUTOR is None: + _TCP_RACE_EXECUTOR = concurrent.futures.ThreadPoolExecutor( + max_workers=_TCP_RACE_MAX_WORKERS, + thread_name_prefix="ydb-tcp-race", + ) -atexit.register(_shutdown_executor) + if not _ATEXIT_REGISTERED: + atexit.register(_shutdown_executor) + _ATEXIT_REGISTERED = True + + return _TCP_RACE_EXECUTOR + + +def _shutdown_executor(): + """Shutdown the executor if it was created.""" + if _TCP_RACE_EXECUTOR is not None: + if sys.version_info >= (3, 9): + _TCP_RACE_EXECUTOR.shutdown(wait=False, cancel_futures=True) + else: + _TCP_RACE_EXECUTOR.shutdown(wait=False) def _check_fastest_endpoint( @@ -86,7 +108,8 @@ def try_connect(endpoint: resolver.EndpointInfo) -> Optional[resolver.EndpointIn logger.debug("Unexpected error connecting to %s: %s", endpoint.endpoint, e) return None - futures: List[concurrent.futures.Future] = [_TCP_RACE_EXECUTOR.submit(try_connect, ep) for ep in endpoints] + executor = _get_executor() + futures: List[concurrent.futures.Future] = [executor.submit(try_connect, ep) for ep in endpoints] try: for fut in concurrent.futures.as_completed(futures, timeout=timeout): From f3a4cb61a067bfd7b43145549cd461c981d1c78e Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Wed, 11 Feb 2026 14:12:19 +0300 Subject: [PATCH 10/12] fix: fixing flaws --- ydb/aio/nearest_dc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ydb/aio/nearest_dc.py b/ydb/aio/nearest_dc.py index 0802bc5c..a5c3cda6 100644 --- a/ydb/aio/nearest_dc.py +++ b/ydb/aio/nearest_dc.py @@ -63,7 +63,6 @@ async def try_connect(endpoint): for t in tasks: if not t.done(): t.cancel() - await asyncio.gather(*tasks, return_exceptions=True) def _split_endpoints_by_location( @@ -142,7 +141,7 @@ async def detect_local_dc( logger.debug("Only one location found: %s", location) return location - _MAX_CONCURRENT_TASKS = 99 + _MAX_CONCURRENT_TASKS = 30 endpoints_to_test = [] for location, location_endpoints in endpoints_by_location.items(): From 35e059cc8fd23cf4783666ed1ced3008d4564180 Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Wed, 11 Feb 2026 15:17:28 +0300 Subject: [PATCH 11/12] fix: fixing flaws --- tests/aio/test_nearest_dc.py | 4 +++- tests/test_nearest_dc.py | 4 +++- ydb/aio/nearest_dc.py | 9 ++++++++- ydb/nearest_dc.py | 9 ++++++++- 4 files changed, 22 insertions(+), 4 deletions(-) diff --git a/tests/aio/test_nearest_dc.py b/tests/aio/test_nearest_dc.py index 4bf0b458..8ce6d06b 100644 --- a/tests/aio/test_nearest_dc.py +++ b/tests/aio/test_nearest_dc.py @@ -4,11 +4,13 @@ class MockEndpoint: - def __init__(self, address, port, location): + def __init__(self, address, port, location, ipv4_addrs=(), ipv6_addrs=()): self.address = address self.port = port self.endpoint = f"{address}:{port}" self.location = location + self.ipv4_addrs = ipv4_addrs + self.ipv6_addrs = ipv6_addrs class MockWriter: diff --git a/tests/test_nearest_dc.py b/tests/test_nearest_dc.py index 6f7f0c43..cc76e10d 100644 --- a/tests/test_nearest_dc.py +++ b/tests/test_nearest_dc.py @@ -4,11 +4,13 @@ class MockEndpoint: - def __init__(self, address, port, location): + def __init__(self, address, port, location, ipv4_addrs=(), ipv6_addrs=()): self.address = address self.port = port self.endpoint = f"{address}:{port}" self.location = location + self.ipv4_addrs = ipv4_addrs + self.ipv6_addrs = ipv6_addrs class DummySock: diff --git a/ydb/aio/nearest_dc.py b/ydb/aio/nearest_dc.py index a5c3cda6..c17e656d 100644 --- a/ydb/aio/nearest_dc.py +++ b/ydb/aio/nearest_dc.py @@ -35,9 +35,16 @@ async def try_connect(endpoint): if remaining <= 0: return None + if endpoint.ipv6_addrs: + target_host = endpoint.ipv6_addrs[0] + elif endpoint.ipv4_addrs: + target_host = endpoint.ipv4_addrs[0] + else: + target_host = endpoint.address + try: _, writer = await asyncio.wait_for( - asyncio.open_connection(endpoint.address, endpoint.port), + asyncio.open_connection(target_host, endpoint.port), timeout=remaining, ) writer.close() diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py index ce3cea24..6b82adeb 100644 --- a/ydb/nearest_dc.py +++ b/ydb/nearest_dc.py @@ -91,8 +91,15 @@ def try_connect(endpoint: resolver.EndpointInfo) -> Optional[resolver.EndpointIn if remaining <= 0 or stop_event.is_set(): return None + if endpoint.ipv6_addrs: + target_host = endpoint.ipv6_addrs[0] + elif endpoint.ipv4_addrs: + target_host = endpoint.ipv4_addrs[0] + else: + target_host = endpoint.address + try: - sock = socket.create_connection((endpoint.address, endpoint.port), timeout=remaining) + sock = socket.create_connection((target_host, endpoint.port), timeout=remaining) try: with winner_lock: if stop_event.is_set(): From 75c84c61c1c56ef3ec6105403d1e73c10bfbd85f Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Wed, 11 Feb 2026 17:17:37 +0300 Subject: [PATCH 12/12] fix: fixing flaws --- tests/aio/test_discovery_detect_local_dc.py | 33 ++++++++++++++ tests/test_discovery_detect_local_dc.py | 26 +++++++++++ ydb/aio/pool.py | 49 +++++++++++++-------- ydb/pool.py | 49 +++++++++++++-------- 4 files changed, 121 insertions(+), 36 deletions(-) diff --git a/tests/aio/test_discovery_detect_local_dc.py b/tests/aio/test_discovery_detect_local_dc.py index 93e8ddb1..4f042396 100644 --- a/tests/aio/test_discovery_detect_local_dc.py +++ b/tests/aio/test_discovery_detect_local_dc.py @@ -108,3 +108,36 @@ def mock_init(self, endpoint, driver_config, endpoint_options=None): await discovery.execute_discovery() assert any("dc1" in ep for ep in preferred), "dc1 should be preferred (server fallback)" + + +@pytest.mark.asyncio +async def test_detect_local_dc_skipped_when_use_all_nodes_true(): + """Test that detect_local_dc is NOT called when use_all_nodes=True.""" + endpoints = [ + MockEndpointInfo("dc1-host", 2135, "dc1"), + MockEndpointInfo("dc2-host", 2135, "dc2"), + ] + mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints) + + mock_resolver = MagicMock() + mock_resolver.resolve = AsyncMock(return_value=mock_result) + + def mock_init(self, endpoint, driver_config, endpoint_options=None): + self.endpoint = endpoint + self.node_id = 1 + + with patch.object(nearest_dc, "detect_local_dc", AsyncMock(return_value="dc2")) as detect_mock: + with patch("ydb.aio.connection.Connection.__init__", mock_init): + with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()): + with patch("ydb.aio.connection.Connection.close", AsyncMock()): + with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None): + config = driver.DriverConfig( + endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=True + ) + discovery = pool.Discovery( + store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config + ) + discovery._resolver = mock_resolver + await discovery.execute_discovery() + + assert detect_mock.call_count == 0, "detect_local_dc should NOT be called when use_all_nodes=True" diff --git a/tests/test_discovery_detect_local_dc.py b/tests/test_discovery_detect_local_dc.py index 9be8efa1..4123dd3f 100644 --- a/tests/test_discovery_detect_local_dc.py +++ b/tests/test_discovery_detect_local_dc.py @@ -92,3 +92,29 @@ def test_detect_local_dc_failure_fallback(): discovery.execute_discovery() assert any("dc1" in ep for ep in preferred), "dc1 should be preferred (server fallback)" + + +def test_detect_local_dc_skipped_when_use_all_nodes_true(): + """Test that detect_local_dc is NOT called when use_all_nodes=True.""" + endpoints = [ + MockEndpointInfo("dc1-host", 2135, "dc1"), + MockEndpointInfo("dc2-host", 2135, "dc2"), + ] + mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints) + + mock_resolver = MagicMock() + mock_resolver.context_resolve.return_value.__enter__.return_value = mock_result + mock_resolver.context_resolve.return_value.__exit__.return_value = None + + with patch.object(nearest_dc, "detect_local_dc", Mock(return_value="dc2")) as detect_mock: + with patch( + "ydb.connection.Connection.ready_factory", lambda *args, **kw: MagicMock(endpoint=args[0], node_id=1) + ): + config = driver.DriverConfig( + endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=True + ) + discovery = pool.Discovery(store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config) + discovery._resolver = mock_resolver + discovery.execute_discovery() + + assert detect_mock.call_count == 0, "detect_local_dc should NOT be called when use_all_nodes=True" diff --git a/ydb/aio/pool.py b/ydb/aio/pool.py index 711bf2c4..0d094431 100644 --- a/ydb/aio/pool.py +++ b/ydb/aio/pool.py @@ -147,30 +147,43 @@ async def execute_discovery(self) -> bool: local_dc = resolve_details.self_location - # Detect local DC using TCP latency if enabled - if self._driver_config.detect_local_dc: - try: - detected_location = await nearest_dc.detect_local_dc( - resolve_details.endpoints, max_per_location=3, timeout=self._ready_timeout - ) - if detected_location: - local_dc = detected_location - self.logger.info( - "Detected local DC via TCP latency: %s (server reported: %s)", - local_dc, - resolve_details.self_location, + # Detect local DC using TCP latency if enabled and preferred is meaningful + if self._driver_config.detect_local_dc and not self._driver_config.use_all_nodes: + # Use only endpoints that match the SSL requirements for detection + ssl_filtered_endpoints = [ + endpoint + for endpoint in resolve_details.endpoints + if (self._ssl_required and endpoint.ssl) or (not self._ssl_required and not endpoint.ssl) + ] + + if ssl_filtered_endpoints: + try: + detected_location = await nearest_dc.detect_local_dc( + ssl_filtered_endpoints, max_per_location=3, timeout=self._ready_timeout ) - else: + if detected_location: + local_dc = detected_location + self.logger.info( + "Detected local DC via TCP latency: %s (server reported: %s)", + local_dc, + resolve_details.self_location, + ) + else: + self.logger.warning( + "Failed to detect local DC via TCP latency, using server location: %s", + resolve_details.self_location, + ) + except Exception as e: self.logger.warning( - "Failed to detect local DC via TCP latency, using server location: %s", + "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", resolve_details.self_location, + e, + exc_info=True, ) - except Exception as e: + else: self.logger.warning( - "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", + "No SSL-compatible endpoints for local DC detection, using server location: %s", resolve_details.self_location, - e, - exc_info=True, ) for resolved_endpoint in resolve_details.endpoints: diff --git a/ydb/pool.py b/ydb/pool.py index a0c9e673..a45a931c 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -234,30 +234,43 @@ def execute_discovery(self) -> bool: local_dc = resolve_details.self_location - # Detect local DC using TCP latency if enabled - if self._driver_config.detect_local_dc: - try: - detected_location = nearest_dc.detect_local_dc( - resolve_details.endpoints, max_per_location=3, timeout=self._ready_timeout - ) - if detected_location: - local_dc = detected_location - self.logger.info( - "Detected local DC via TCP latency: %s (server reported: %s)", - local_dc, - resolve_details.self_location, + # Detect local DC using TCP latency if enabled and preferred is meaningful + if self._driver_config.detect_local_dc and not self._driver_config.use_all_nodes: + # Use only endpoints that match the SSL requirements for detection + ssl_filtered_endpoints = [ + endpoint + for endpoint in resolve_details.endpoints + if (self._ssl_required and endpoint.ssl) or (not self._ssl_required and not endpoint.ssl) + ] + + if ssl_filtered_endpoints: + try: + detected_location = nearest_dc.detect_local_dc( + ssl_filtered_endpoints, max_per_location=3, timeout=self._ready_timeout ) - else: + if detected_location: + local_dc = detected_location + self.logger.info( + "Detected local DC via TCP latency: %s (server reported: %s)", + local_dc, + resolve_details.self_location, + ) + else: + self.logger.warning( + "Failed to detect local DC via TCP latency, using server location: %s", + resolve_details.self_location, + ) + except Exception as e: self.logger.warning( - "Failed to detect local DC via TCP latency, using server location: %s", + "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", resolve_details.self_location, + e, + exc_info=True, ) - except Exception as e: + else: self.logger.warning( - "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", + "No SSL-compatible endpoints for local DC detection, using server location: %s", resolve_details.self_location, - e, - exc_info=True, ) for resolved_endpoint in resolve_details.endpoints: