Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions tests/aio/test_discovery_detect_local_dc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from ydb import driver, connection
from ydb.aio import pool, nearest_dc


class MockEndpointInfo:
def __init__(self, address, port, location):
self.address = address
self.port = port
self.endpoint = f"{address}:{port}"
self.location = location
self.ssl = False
self.node_id = 1

def endpoints_with_options(self):
yield (self.endpoint, connection.EndpointOptions(ssl_target_name_override=None, node_id=self.node_id))


class MockDiscoveryResult:
def __init__(self, self_location, endpoints):
self.self_location = self_location
self.endpoints = endpoints


@pytest.mark.asyncio
async def test_detect_local_dc_overrides_server_location():
"""Test that detected location overrides server's self_location for preferred endpoints."""
# Server reports dc1, but we detect dc2 as nearest
endpoints = [
MockEndpointInfo("dc1-host", 2135, "dc1"),
MockEndpointInfo("dc2-host", 2135, "dc2"),
]
mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints)

mock_resolver = MagicMock()
mock_resolver.resolve = AsyncMock(return_value=mock_result)

preferred = []

def mock_init(self, endpoint, driver_config, endpoint_options=None):
self.endpoint = endpoint
self.node_id = 1

with patch.object(nearest_dc, "detect_local_dc", AsyncMock(return_value="dc2")):
with patch("ydb.aio.connection.Connection.__init__", mock_init):
with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()):
with patch("ydb.aio.connection.Connection.close", AsyncMock()):
with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None):
config = driver.DriverConfig(
endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False
)
discovery = pool.Discovery(
store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config
)
discovery._resolver = mock_resolver

original_add = discovery._cache.add
discovery._cache.add = lambda conn, pref=False: (
preferred.append(conn.endpoint) if pref else None,
original_add(conn, pref),
)[1]

await discovery.execute_discovery()

assert any("dc2" in ep for ep in preferred), "dc2 should be preferred (detected)"
assert not any("dc1" in ep for ep in preferred), "dc1 should not be preferred"


@pytest.mark.asyncio
async def test_detect_local_dc_failure_fallback():
"""Test that detection failure falls back to server's self_location."""
endpoints = [
MockEndpointInfo("dc1-host", 2135, "dc1"),
MockEndpointInfo("dc2-host", 2135, "dc2"),
]
mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints)

mock_resolver = MagicMock()
mock_resolver.resolve = AsyncMock(return_value=mock_result)

preferred = []

def mock_init(self, endpoint, driver_config, endpoint_options=None):
self.endpoint = endpoint
self.node_id = 1

with patch.object(nearest_dc, "detect_local_dc", AsyncMock(return_value=None)):
with patch("ydb.aio.connection.Connection.__init__", mock_init):
with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()):
with patch("ydb.aio.connection.Connection.close", AsyncMock()):
with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None):
config = driver.DriverConfig(
endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False
)
discovery = pool.Discovery(
store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config
)
discovery._resolver = mock_resolver

original_add = discovery._cache.add
discovery._cache.add = lambda conn, pref=False: (
preferred.append(conn.endpoint) if pref else None,
original_add(conn, pref),
)[1]

await discovery.execute_discovery()

assert any("dc1" in ep for ep in preferred), "dc1 should be preferred (server fallback)"


@pytest.mark.asyncio
async def test_detect_local_dc_skipped_when_use_all_nodes_true():
"""Test that detect_local_dc is NOT called when use_all_nodes=True."""
endpoints = [
MockEndpointInfo("dc1-host", 2135, "dc1"),
MockEndpointInfo("dc2-host", 2135, "dc2"),
]
mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints)

mock_resolver = MagicMock()
mock_resolver.resolve = AsyncMock(return_value=mock_result)

def mock_init(self, endpoint, driver_config, endpoint_options=None):
self.endpoint = endpoint
self.node_id = 1

with patch.object(nearest_dc, "detect_local_dc", AsyncMock(return_value="dc2")) as detect_mock:
with patch("ydb.aio.connection.Connection.__init__", mock_init):
with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()):
with patch("ydb.aio.connection.Connection.close", AsyncMock()):
with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None):
config = driver.DriverConfig(
endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=True
)
discovery = pool.Discovery(
store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config
)
discovery._resolver = mock_resolver
await discovery.execute_discovery()

assert detect_mock.call_count == 0, "detect_local_dc should NOT be called when use_all_nodes=True"
161 changes: 161 additions & 0 deletions tests/aio/test_nearest_dc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import asyncio
import pytest
from ydb.aio import nearest_dc


class MockEndpoint:
def __init__(self, address, port, location, ipv4_addrs=(), ipv6_addrs=()):
self.address = address
self.port = port
self.endpoint = f"{address}:{port}"
self.location = location
self.ipv4_addrs = ipv4_addrs
self.ipv6_addrs = ipv6_addrs


class MockWriter:
def __init__(self):
self.closed = False

def close(self):
self.closed = True

async def wait_closed(self):
await asyncio.sleep(0)


@pytest.mark.asyncio
async def test_check_fastest_endpoint_empty():
assert await nearest_dc._check_fastest_endpoint([]) is None


@pytest.mark.asyncio
async def test_check_fastest_endpoint_all_fail(monkeypatch):
async def fake_open_connection(host, port):
raise OSError("connect failed")

monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection)

endpoints = [
MockEndpoint("a", 1, "dc1"),
MockEndpoint("b", 1, "dc2"),
]
assert await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) is None


@pytest.mark.asyncio
async def test_check_fastest_endpoint_fastest_wins(monkeypatch):
async def fake_open_connection(host, port):
if host == "slow":
await asyncio.sleep(0.05)
return None, MockWriter()

monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection)

endpoints = [
MockEndpoint("slow", 1, "dc_slow"),
MockEndpoint("fast", 1, "dc_fast"),
]
winner = await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.2)
assert winner is not None
assert winner.location == "dc_fast"


@pytest.mark.asyncio
async def test_check_fastest_endpoint_respects_main_timeout(monkeypatch):
async def fake_open_connection(host, port):
await asyncio.sleep(0.2)
return None, MockWriter()

monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection)

endpoints = [
MockEndpoint("hang1", 1, "dc1"),
MockEndpoint("hang2", 1, "dc2"),
]

winner = await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05)

assert winner is None


@pytest.mark.asyncio
async def test_detect_local_dc_empty_endpoints():
with pytest.raises(ValueError, match="Empty endpoints"):
await nearest_dc.detect_local_dc([])


@pytest.mark.asyncio
async def test_detect_local_dc_single_location_returns_immediately(monkeypatch):
async def fail_if_called(*args, **kwargs):
raise AssertionError("open_connection should not be called for single location")

monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fail_if_called)

endpoints = [
MockEndpoint("h1", 1, "dc1"),
MockEndpoint("h2", 1, "dc1"),
]
assert await nearest_dc.detect_local_dc(endpoints) == "dc1"


@pytest.mark.asyncio
async def test_detect_local_dc_returns_none_when_all_fail(monkeypatch):
async def fake_open_connection(host, port):
raise OSError("connect failed")

monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection)

endpoints = [
MockEndpoint("bad1", 9999, "dc1"),
MockEndpoint("bad2", 9999, "dc2"),
]
assert await nearest_dc.detect_local_dc(endpoints, timeout=0.05) is None


@pytest.mark.asyncio
async def test_detect_local_dc_returns_location_of_fastest(monkeypatch):
async def fake_open_connection(host, port):
if host == "dc1_host":
await asyncio.sleep(0.05)
return None, MockWriter()

monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection)

endpoints = [
MockEndpoint("dc1_host", 1, "dc1"),
MockEndpoint("dc2_host", 1, "dc2"),
]
assert await nearest_dc.detect_local_dc(endpoints, max_per_location=5, timeout=0.2) == "dc2"


@pytest.mark.asyncio
async def test_detect_local_dc_respects_max_per_location(monkeypatch):
calls = []

async def fake_open_connection(host, port):
calls.append((host, port))
raise OSError("connect failed")

monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection)

endpoints = [MockEndpoint(f"dc1_{i}", 1, "dc1") for i in range(5)] + [
MockEndpoint(f"dc2_{i}", 1, "dc2") for i in range(5)
]
await nearest_dc.detect_local_dc(endpoints, max_per_location=2, timeout=0.2)

assert len(calls) == 4


@pytest.mark.asyncio
async def test_detect_local_dc_validates_max_per_location():
endpoints = [MockEndpoint("h1", 1, "dc1")]
with pytest.raises(ValueError, match="max_per_location must be >= 1"):
await nearest_dc.detect_local_dc(endpoints, max_per_location=0)


@pytest.mark.asyncio
async def test_detect_local_dc_validates_timeout():
endpoints = [MockEndpoint("h1", 1, "dc1")]
with pytest.raises(ValueError, match="timeout must be > 0"):
await nearest_dc.detect_local_dc(endpoints, timeout=0)
Loading
Loading