From b9e5d4e2899080556de6b707921877384adafa0d Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Sat, 7 Feb 2026 11:11:38 +0530 Subject: [PATCH 1/6] chore: handle unary gRPC call ordering in KeyAwareChannel --- .../cloud/spanner/spi/v1/KeyAwareChannel.java | 76 ++++++++++++++++++- 1 file changed, 75 insertions(+), 1 deletion(-) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java index 53790bf524..c085265f9f 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java @@ -245,6 +245,12 @@ static final class KeyAwareClientCall @Nullable private ChannelEndpoint selectedEndpoint; @Nullable private ByteString transactionIdToClear; private boolean allowDefaultAffinity; + private long pendingRequests; + private boolean pendingHalfClose; + @Nullable private Boolean pendingMessageCompression; + private boolean cancelled; + @Nullable private String cancelMessage; + @Nullable private Throwable cancelCause; KeyAwareClientCall( KeyAwareChannel parentChannel, @@ -268,11 +274,22 @@ protected ClientCall delegate() { public void start(Listener responseListener, Metadata headers) { this.responseListener = new KeyAwareClientCallListener<>(responseListener, this); this.headers = headers; + if (cancelled) { + this.responseListener.onClose( + io.grpc.Status.CANCELLED.withDescription(cancelMessage).withCause(cancelCause), + new Metadata()); + } } @Override @SuppressWarnings("unchecked") public void sendMessage(RequestT message) { + if (cancelled) { + return; + } + if (responseListener == null || headers == null) { + throw new IllegalStateException("start must be called before sendMessage"); + } ChannelEndpoint endpoint = null; ChannelFinder finder = null; @@ -326,8 +343,15 @@ public void sendMessage(RequestT message) { this.channelFinder = finder; delegate = endpoint.getChannel().newCall(methodDescriptor, callOptions); + if (pendingMessageCompression != null) { + delegate.setMessageCompression(pendingMessageCompression); + } delegate.start(responseListener, headers); + drainPendingRequests(); delegate.sendMessage(message); + if (pendingHalfClose) { + delegate.halfClose(); + } } @Override @@ -335,7 +359,7 @@ public void halfClose() { if (delegate != null) { delegate.halfClose(); } else { - throw new IllegalStateException("halfClose called before sendMessage"); + pendingHalfClose = true; } } @@ -346,6 +370,56 @@ public void cancel(@Nullable String message, @Nullable Throwable cause) { } else if (responseListener != null) { responseListener.onClose( io.grpc.Status.CANCELLED.withDescription(message).withCause(cause), new Metadata()); + cancelled = true; + cancelMessage = message; + cancelCause = cause; + } else { + cancelled = true; + cancelMessage = message; + cancelCause = cause; + } + } + + @Override + public void request(int numMessages) { + if (delegate != null) { + delegate.request(numMessages); + return; + } + if (numMessages <= 0) { + return; + } + long updated = pendingRequests + numMessages; + if (updated < 0L) { + updated = Long.MAX_VALUE; + } + pendingRequests = updated; + } + + @Override + public boolean isReady() { + if (delegate == null) { + return false; + } + return delegate.isReady(); + } + + @Override + public void setMessageCompression(boolean enabled) { + if (delegate != null) { + delegate.setMessageCompression(enabled); + } else { + pendingMessageCompression = enabled; + } + } + + private void drainPendingRequests() { + long requests = pendingRequests; + pendingRequests = 0L; + while (requests > 0) { + int batch = requests > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requests; + delegate.request(batch); + requests -= batch; } } From a4fb7f89dfa2df62ecd00095574649ed7e79ef06 Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Sat, 7 Feb 2026 12:49:37 +0530 Subject: [PATCH 2/6] incorporate suggestions --- .../cloud/spanner/spi/v1/KeyAwareChannel.java | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java index c085265f9f..1ff1880e0b 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java @@ -367,16 +367,14 @@ public void halfClose() { public void cancel(@Nullable String message, @Nullable Throwable cause) { if (delegate != null) { delegate.cancel(message, cause); - } else if (responseListener != null) { - responseListener.onClose( - io.grpc.Status.CANCELLED.withDescription(message).withCause(cause), new Metadata()); - cancelled = true; - cancelMessage = message; - cancelCause = cause; } else { cancelled = true; cancelMessage = message; cancelCause = cause; + if (responseListener != null) { + responseListener.onClose( + io.grpc.Status.CANCELLED.withDescription(message).withCause(cause), new Metadata()); + } } } @@ -493,6 +491,9 @@ public void onMessage(ResponseT message) { transactionId = transactionIdFromMetadata(response); } else if (message instanceof ResultSet) { ResultSet response = (ResultSet) message; + if (response.hasCacheUpdate() && call.channelFinder != null) { + call.channelFinder.update(response.getCacheUpdate()); + } transactionId = transactionIdFromMetadata(response); } else if (message instanceof Transaction) { Transaction response = (Transaction) message; From 4bab7897ecfb99dc54cfd29b5e8607a3a445d0dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Sat, 7 Feb 2026 11:50:57 +0100 Subject: [PATCH 3/6] test: add tests for location API --- .../google/cloud/spanner/SpannerOptions.java | 19 +++ .../cloud/spanner/spi/v1/GapicSpannerRpc.java | 5 +- .../cloud/spanner/SpanFEBypassTest.java | 140 ++++++++++++++++++ .../connection/AbstractMockServerTest.java | 2 +- 4 files changed, 162 insertions(+), 4 deletions(-) create mode 100644 google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpanFEBypassTest.java diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java index 014bbe7c0e..2b8e1d28ce 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java @@ -18,6 +18,7 @@ import static com.google.api.gax.util.TimeConversionUtils.toJavaTimeDuration; import static com.google.api.gax.util.TimeConversionUtils.toThreetenDuration; +import static com.google.cloud.spanner.spi.v1.GapicSpannerRpc.EXPERIMENTAL_LOCATION_API_ENV_VAR; import com.google.api.core.ApiFunction; import com.google.api.core.BetaApi; @@ -257,6 +258,7 @@ public static GcpChannelPoolOptions createDefaultDynamicChannelPoolOptions() { private final OpenTelemetry openTelemetry; private final boolean enableApiTracing; private final boolean enableBuiltInMetrics; + private final boolean enableLocationApi; private final boolean enableExtendedTracing; private final boolean enableEndToEndTracing; private final String monitoringHost; @@ -926,6 +928,7 @@ protected SpannerOptions(Builder builder) { } else { enableBuiltInMetrics = builder.enableBuiltInMetrics; } + enableLocationApi = builder.enableLocationApi; enableEndToEndTracing = builder.enableEndToEndTracing; monitoringHost = builder.monitoringHost; defaultTransactionOptions = builder.defaultTransactionOptions; @@ -993,6 +996,10 @@ default boolean isEnableEndToEndTracing() { return false; } + default boolean isEnableLocationApi() { + return false; + } + @Deprecated @ObsoleteApi( "This will be removed in an upcoming version without a major version bump. You should use" @@ -1084,6 +1091,11 @@ public boolean isEnableEndToEndTracing() { return Boolean.parseBoolean(System.getenv(SPANNER_ENABLE_END_TO_END_TRACING)); } + @Override + public boolean isEnableLocationApi() { + return Boolean.parseBoolean(System.getenv(EXPERIMENTAL_LOCATION_API_ENV_VAR)); + } + @Override public String getMonitoringHost() { return System.getenv(SPANNER_MONITORING_HOST); @@ -1164,6 +1176,7 @@ public static class Builder private boolean enableExtendedTracing = SpannerOptions.environment.isEnableExtendedTracing(); private boolean enableEndToEndTracing = SpannerOptions.environment.isEnableEndToEndTracing(); private boolean enableBuiltInMetrics = SpannerOptions.environment.isEnableBuiltInMetrics(); + private boolean enableLocationApi = SpannerOptions.environment.isEnableLocationApi(); private String monitoringHost = SpannerOptions.environment.getMonitoringHost(); private SslContext mTLSContext = null; private String experimentalHost = null; @@ -1270,6 +1283,7 @@ protected Builder() { this.enableApiTracing = options.enableApiTracing; this.enableExtendedTracing = options.enableExtendedTracing; this.enableBuiltInMetrics = options.enableBuiltInMetrics; + this.enableLocationApi = options.enableLocationApi; this.enableEndToEndTracing = options.enableEndToEndTracing; this.monitoringHost = options.monitoringHost; this.defaultTransactionOptions = options.defaultTransactionOptions; @@ -2434,6 +2448,11 @@ public boolean isEnableBuiltInMetrics() { return enableBuiltInMetrics; } + @InternalApi + public boolean isEnableLocationApi() { + return enableLocationApi; + } + /** Returns the override metrics Host. */ String getMonitoringHost() { return monitoringHost; diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index 212e1ac538..0c912a9f98 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -226,7 +226,7 @@ public class GapicSpannerRpc implements SpannerRpc { private static final PathTemplate PROJECT_NAME_TEMPLATE = PathTemplate.create("projects/{project}"); - private static final String EXPERIMENTAL_LOCATION_API_ENV_VAR = + public static final String EXPERIMENTAL_LOCATION_API_ENV_VAR = "GOOGLE_SPANNER_EXPERIMENTAL_LOCATION_API"; private static final PathTemplate OPERATION_NAME_TEMPLATE = PathTemplate.create("{database=projects/*/instances/*/databases/*}/operations/{operation}"); @@ -399,8 +399,7 @@ public GapicSpannerRpc(final SpannerOptions options) { // If it is enabled in options uses the channel pool provided by the gRPC-GCP extension. maybeEnableGrpcGcpExtension(defaultChannelProviderBuilder, options); - boolean enableLocationApi = - Boolean.parseBoolean(System.getenv(EXPERIMENTAL_LOCATION_API_ENV_VAR)); + boolean enableLocationApi = options.isEnableLocationApi(); TransportChannelProvider baseChannelProvider = MoreObjects.firstNonNull( options.getChannelProvider(), defaultChannelProviderBuilder.build()); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpanFEBypassTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpanFEBypassTest.java new file mode 100644 index 0000000000..17cfcdd0bc --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpanFEBypassTest.java @@ -0,0 +1,140 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.NoCredentials; +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.cloud.spanner.connection.RandomResultSetGenerator; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.MoreExecutors; +import io.grpc.ManagedChannelBuilder; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class SpanFEBypassTest extends AbstractMockServerTest { + private static final Statement SELECT_RANDOM_STATEMENT = Statement.of("select * from random"); + private static final int RANDOM_RESULT_ROW_COUNT = 20; + private static Spanner spanner; + private static DatabaseClient client; + + @BeforeClass + public static void enableLocationApiAndSetupClient() { + SpannerOptions.useEnvironment( + new SpannerOptions.SpannerEnvironment() { + @Override + public boolean isEnableLocationApi() { + return true; + } + }); + spanner = + SpannerOptions.newBuilder() + .setProjectId("my-project") + .setHost(String.format("http://localhost:%d", getPort())) + .setChannelConfigurator(ManagedChannelBuilder::usePlaintext) + .setCredentials(NoCredentials.getInstance()) + .build() + .getService(); + client = spanner.getDatabaseClient(DatabaseId.of("my-project", "my-instance", "my-database")); + + RandomResultSetGenerator generator = new RandomResultSetGenerator(RANDOM_RESULT_ROW_COUNT); + mockSpanner.putStatementResult( + StatementResult.query(SELECT_RANDOM_STATEMENT, generator.generate())); + } + + @AfterClass + public static void cleanup() { + SpannerOptions.useDefaultEnvironment(); + if (spanner != null) { + spanner.close(); + } + } + + @Test + public void testSingleQuery() { + int rowCount = 0; + try (ResultSet resultSet = client.singleUse().executeQuery(SELECT_RANDOM_STATEMENT)) { + while (resultSet.next()) { + rowCount++; + } + } + assertEquals(RANDOM_RESULT_ROW_COUNT, rowCount); + } + + @Test + public void testParallelQueries() throws Exception { + int numThreads = 10; + ListeningExecutorService executor = + MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(numThreads)); + List> results = new ArrayList<>(); + for (int i = 0; i < numThreads; i++) { + results.add( + executor.submit( + () -> { + try (ResultSet resultSet = + client.singleUse().executeQuery(SELECT_RANDOM_STATEMENT)) { + while (resultSet.next()) { + // Randomly stop consuming results somewhere halfway the results (sometimes). + if (ThreadLocalRandom.current().nextInt(RANDOM_RESULT_ROW_COUNT * 2) == 5) { + break; + } + } + } + return null; + })); + } + executor.shutdown(); + Futures.allAsList(results).get(); + } + + @Test + public void testSingleReadWriteTransaction() { + client.readWriteTransaction().run(transaction -> transaction.executeUpdate(INSERT_STATEMENT)); + } + + @Test + public void testParallelReadWriteTransactions() throws Exception { + int numThreads = 10; + ListeningExecutorService executor = + MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(numThreads)); + List> results = new ArrayList<>(); + for (int i = 0; i < numThreads; i++) { + results.add( + executor.submit( + () -> { + client + .readWriteTransaction() + .run(transaction -> transaction.executeUpdate(INSERT_STATEMENT)); + return null; + })); + } + executor.shutdown(); + Futures.allAsList(results).get(); + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbstractMockServerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbstractMockServerTest.java index a78df0471e..6bbd6a4198 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbstractMockServerTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/AbstractMockServerTest.java @@ -314,7 +314,7 @@ protected String getBaseUrl() { server.getPort()); } - protected int getPort() { + protected static int getPort() { return server.getPort(); } From d719633a0ebc83cc716d555e725ee9ac474f3dec Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Sat, 7 Feb 2026 17:11:55 +0530 Subject: [PATCH 4/6] fix test --- .../java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java | 4 +++- .../spanner/{SpanFEBypassTest.java => LocationAwareTest.java} | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) rename google-cloud-spanner/src/test/java/com/google/cloud/spanner/{SpanFEBypassTest.java => LocationAwareTest.java} (98%) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index 0c912a9f98..27f262c825 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -399,7 +399,9 @@ public GapicSpannerRpc(final SpannerOptions options) { // If it is enabled in options uses the channel pool provided by the gRPC-GCP extension. maybeEnableGrpcGcpExtension(defaultChannelProviderBuilder, options); - boolean enableLocationApi = options.isEnableLocationApi(); + boolean enableLocationApi = + options.isEnableLocationApi() + || Boolean.parseBoolean(System.getenv(EXPERIMENTAL_LOCATION_API_ENV_VAR)); TransportChannelProvider baseChannelProvider = MoreObjects.firstNonNull( options.getChannelProvider(), defaultChannelProviderBuilder.build()); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpanFEBypassTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareTest.java similarity index 98% rename from google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpanFEBypassTest.java rename to google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareTest.java index 17cfcdd0bc..24075d8c36 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpanFEBypassTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareTest.java @@ -38,7 +38,7 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) -public class SpanFEBypassTest extends AbstractMockServerTest { +public class LocationAwareTest extends AbstractMockServerTest { private static final Statement SELECT_RANDOM_STATEMENT = Statement.of("select * from random"); private static final int RANDOM_RESULT_ROW_COUNT = 20; private static Spanner spanner; From 33175a592ed578d2c04993de7f1059dd73b9c769 Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Mon, 9 Feb 2026 14:16:50 +0530 Subject: [PATCH 5/6] add more tests --- .../cloud/spanner/spi/v1/GapicSpannerRpc.java | 4 +- .../cloud/spanner/spi/v1/KeyAwareChannel.java | 29 +- .../cloud/spanner/LocationAwareTest.java | 130 ++++++ .../spanner/spi/v1/GapicSpannerRpcTest.java | 90 ++-- .../spanner/spi/v1/KeyAwareChannelTest.java | 405 ++++++++++++++++++ 5 files changed, 594 insertions(+), 64 deletions(-) create mode 100644 google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index 27f262c825..0c912a9f98 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -399,9 +399,7 @@ public GapicSpannerRpc(final SpannerOptions options) { // If it is enabled in options uses the channel pool provided by the gRPC-GCP extension. maybeEnableGrpcGcpExtension(defaultChannelProviderBuilder, options); - boolean enableLocationApi = - options.isEnableLocationApi() - || Boolean.parseBoolean(System.getenv(EXPERIMENTAL_LOCATION_API_ENV_VAR)); + boolean enableLocationApi = options.isEnableLocationApi(); TransportChannelProvider baseChannelProvider = MoreObjects.firstNonNull( options.getChannelProvider(), defaultChannelProviderBuilder.build()); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java index 1ff1880e0b..4fcf88130d 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java @@ -248,9 +248,8 @@ static final class KeyAwareClientCall private long pendingRequests; private boolean pendingHalfClose; @Nullable private Boolean pendingMessageCompression; - private boolean cancelled; - @Nullable private String cancelMessage; - @Nullable private Throwable cancelCause; + @Nullable private io.grpc.Status cancelledStatus; + @Nullable private Metadata cancelledTrailers; KeyAwareClientCall( KeyAwareChannel parentChannel, @@ -274,17 +273,17 @@ protected ClientCall delegate() { public void start(Listener responseListener, Metadata headers) { this.responseListener = new KeyAwareClientCallListener<>(responseListener, this); this.headers = headers; - if (cancelled) { + if (this.cancelledStatus != null) { this.responseListener.onClose( - io.grpc.Status.CANCELLED.withDescription(cancelMessage).withCause(cancelCause), - new Metadata()); + this.cancelledStatus, + this.cancelledTrailers == null ? new Metadata() : this.cancelledTrailers); } } @Override @SuppressWarnings("unchecked") public void sendMessage(RequestT message) { - if (cancelled) { + if (this.cancelledStatus != null) { return; } if (responseListener == null || headers == null) { @@ -311,10 +310,7 @@ public void sendMessage(RequestT message) { String databaseId = parentChannel.extractDatabaseIdFromSession(reqBuilder.getSession()); if (databaseId != null && reqBuilder.hasMutationKey()) { finder = parentChannel.getOrCreateChannelFinder(databaseId); - ChannelEndpoint routed = finder.findServer(reqBuilder); - if (endpoint == null) { - endpoint = routed; - } + endpoint = finder.findServer(reqBuilder); } allowDefaultAffinity = true; message = (RequestT) reqBuilder.build(); @@ -345,6 +341,7 @@ public void sendMessage(RequestT message) { delegate = endpoint.getChannel().newCall(methodDescriptor, callOptions); if (pendingMessageCompression != null) { delegate.setMessageCompression(pendingMessageCompression); + pendingMessageCompression = null; } delegate.start(responseListener, headers); drainPendingRequests(); @@ -368,12 +365,12 @@ public void cancel(@Nullable String message, @Nullable Throwable cause) { if (delegate != null) { delegate.cancel(message, cause); } else { - cancelled = true; - cancelMessage = message; - cancelCause = cause; + cancelledStatus = io.grpc.Status.CANCELLED.withDescription(message).withCause(cause); + Metadata trailers = + cause == null ? new Metadata() : io.grpc.Status.trailersFromThrowable(cause); + cancelledTrailers = trailers == null ? new Metadata() : trailers; if (responseListener != null) { - responseListener.onClose( - io.grpc.Status.CANCELLED.withDescription(message).withCause(cause), new Metadata()); + responseListener.onClose(cancelledStatus, cancelledTrailers); } } } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareTest.java index 24075d8c36..aa038d512e 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareTest.java @@ -16,21 +16,39 @@ package com.google.cloud.spanner; +import static com.google.cloud.spanner.SpannerApiFutures.get; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import com.google.api.core.ApiFuture; +import com.google.api.gax.rpc.ApiCallContext; import com.google.cloud.NoCredentials; +import com.google.cloud.spanner.AsyncResultSet.CallbackResponse; +import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.SpannerOptions.CallContextConfigurator; import com.google.cloud.spanner.connection.AbstractMockServerTest; import com.google.cloud.spanner.connection.RandomResultSetGenerator; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.SpannerGrpc; +import io.grpc.Context; import io.grpc.ManagedChannelBuilder; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import java.time.Duration; import java.util.ArrayList; +import java.util.LinkedList; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -44,6 +62,10 @@ public class LocationAwareTest extends AbstractMockServerTest { private static Spanner spanner; private static DatabaseClient client; + private static final class TimeoutHolder { + private Duration timeout; + } + @BeforeClass public static void enableLocationApiAndSetupClient() { SpannerOptions.useEnvironment( @@ -137,4 +159,112 @@ public void testParallelReadWriteTransactions() throws Exception { executor.shutdown(); Futures.allAsList(results).get(); } + + @Test + public void testExecuteStreamingSqlCallContextTimeout_locationAware() { + final TimeoutHolder timeoutHolder = new TimeoutHolder(); + CallContextConfigurator configurator = + new CallContextConfigurator() { + @Override + public ApiCallContext configure( + ApiCallContext context, ReqT request, MethodDescriptor method) { + if (request instanceof ExecuteSqlRequest + && method.equals(SpannerGrpc.getExecuteStreamingSqlMethod())) { + return context.withTimeoutDuration(timeoutHolder.timeout); + } + return null; + } + }; + + mockSpanner.setExecuteStreamingSqlExecutionTime( + SimulatedExecutionTime.ofMinimumAndRandomTime(10, 0)); + Context context = + Context.current().withValue(SpannerOptions.CALL_CONTEXT_CONFIGURATOR_KEY, configurator); + try { + context.run( + () -> { + timeoutHolder.timeout = Duration.ofNanos(1L); + SpannerException e = + assertThrows( + SpannerException.class, + () -> { + try (ResultSet rs = + client.singleUse().executeQuery(SELECT_RANDOM_STATEMENT)) { + rs.next(); + } + }); + assertEquals(ErrorCode.DEADLINE_EXCEEDED, e.getErrorCode()); + + timeoutHolder.timeout = Duration.ofMinutes(1L); + try (ResultSet rs = client.singleUse().executeQuery(SELECT_RANDOM_STATEMENT)) { + assertTrue(rs.next()); + } + }); + } finally { + mockSpanner.removeAllExecutionTimes(); + } + } + + @Test + public void testExecuteStreamingSqlInvalidArgumentPropagates_locationAware() { + mockSpanner.setExecuteStreamingSqlExecutionTime( + SimulatedExecutionTime.ofException( + Status.INVALID_ARGUMENT.withDescription("invalid request").asRuntimeException())); + try { + SpannerException e = + assertThrows( + SpannerException.class, + () -> { + try (ResultSet rs = client.singleUse().executeQuery(SELECT_RANDOM_STATEMENT)) { + rs.next(); + } + }); + assertEquals(ErrorCode.INVALID_ARGUMENT, e.getErrorCode()); + } finally { + mockSpanner.removeAllExecutionTimes(); + } + } + + @Test + public void testExecuteQueryAsyncCancelReturnsCancelled_locationAware() throws Exception { + final List values = new LinkedList<>(); + final CountDownLatch receivedFirstRow = new CountDownLatch(1); + final CountDownLatch cancelled = new CountDownLatch(1); + final ApiFuture callbackResult; + + ExecutorService executor = Executors.newSingleThreadExecutor(); + try (AsyncResultSet rs = client.singleUse().executeQueryAsync(SELECT_RANDOM_STATEMENT)) { + callbackResult = + rs.setCallback( + executor, + resultSet -> { + try { + while (true) { + switch (resultSet.tryNext()) { + case DONE: + return CallbackResponse.DONE; + case NOT_READY: + return CallbackResponse.CONTINUE; + case OK: + values.add(1); + receivedFirstRow.countDown(); + cancelled.await(); + break; + } + } + } catch (Throwable t) { + return CallbackResponse.DONE; + } + }); + + assertTrue(receivedFirstRow.await(30L, TimeUnit.SECONDS)); + rs.cancel(); + cancelled.countDown(); + SpannerException e = assertThrows(SpannerException.class, () -> get(callbackResult)); + assertEquals(ErrorCode.CANCELLED, e.getErrorCode()); + assertEquals(1, values.size()); + } finally { + executor.shutdownNow(); + } + } } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java index 707c4468a4..09d0ea7d25 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java @@ -46,7 +46,6 @@ import com.google.cloud.spanner.DatabaseId; import com.google.cloud.spanner.Dialect; import com.google.cloud.spanner.ErrorCode; -import com.google.cloud.spanner.JavaVersionUtil; import com.google.cloud.spanner.MockSpannerServiceImpl; import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; @@ -888,18 +887,7 @@ public void testCreateSession_whenMultiplexedSessionIsFalse_assertSessionProto() } @Test - public void testChannelEndpointCacheFactoryUsedWhenLocationApiEnabled() throws Exception { - assumeTrue(isJava8() && !isWindows()); - String envVar = "GOOGLE_SPANNER_EXPERIMENTAL_LOCATION_API"; - - Class classOfMap = System.getenv().getClass(); - java.lang.reflect.Field field = classOfMap.getDeclaredField("m"); - field.setAccessible(true); - @SuppressWarnings("unchecked") - Map writeableEnvironmentVariables = - (Map) field.get(System.getenv()); - String originalValue = writeableEnvironmentVariables.get(envVar); - + public void testChannelEndpointCacheFactoryUsedWhenLocationApiEnabled() { AtomicBoolean factoryCalled = new AtomicBoolean(false); ChannelEndpointCacheFactory factory = baseProvider -> { @@ -908,34 +896,25 @@ public void testChannelEndpointCacheFactoryUsedWhenLocationApiEnabled() throws E }; try { - writeableEnvironmentVariables.put(envVar, "true"); + SpannerOptions.useEnvironment( + new SpannerOptions.SpannerEnvironment() { + @Override + public boolean isEnableLocationApi() { + return true; + } + }); SpannerOptions options = createSpannerOptions().toBuilder().setChannelEndpointCacheFactory(factory).build(); GapicSpannerRpc rpc = new GapicSpannerRpc(options, true); rpc.shutdown(); assertTrue(factoryCalled.get()); } finally { - if (originalValue == null) { - writeableEnvironmentVariables.remove(envVar); - } else { - writeableEnvironmentVariables.put(envVar, originalValue); - } + SpannerOptions.useDefaultEnvironment(); } } @Test - public void testLocationApiDoesNotOverrideExplicitChannelProvider() throws Exception { - assumeTrue(isJava8() && !isWindows()); - String envVar = "GOOGLE_SPANNER_EXPERIMENTAL_LOCATION_API"; - - Class classOfMap = System.getenv().getClass(); - java.lang.reflect.Field field = classOfMap.getDeclaredField("m"); - field.setAccessible(true); - @SuppressWarnings("unchecked") - Map writeableEnvironmentVariables = - (Map) field.get(System.getenv()); - String originalValue = writeableEnvironmentVariables.get(envVar); - + public void testLocationApiDoesNotOverrideExplicitChannelProvider() { AtomicBoolean factoryCalled = new AtomicBoolean(false); ChannelEndpointCacheFactory factory = baseProvider -> { @@ -949,7 +928,13 @@ public void testLocationApiDoesNotOverrideExplicitChannelProvider() throws Excep address.getHostString(), server.getPort(), providerUsed); try { - writeableEnvironmentVariables.put(envVar, "true"); + SpannerOptions.useEnvironment( + new SpannerOptions.SpannerEnvironment() { + @Override + public boolean isEnableLocationApi() { + return true; + } + }); SpannerOptions options = createSpannerOptions().toBuilder() .setChannelProvider(channelProvider) @@ -960,11 +945,34 @@ public void testLocationApiDoesNotOverrideExplicitChannelProvider() throws Excep assertTrue(providerUsed.get()); assertFalse(factoryCalled.get()); } finally { - if (originalValue == null) { - writeableEnvironmentVariables.remove(envVar); - } else { - writeableEnvironmentVariables.put(envVar, originalValue); - } + SpannerOptions.useDefaultEnvironment(); + } + } + + @Test + public void testLocationApiDisabledInOptionsDoesNotCreateKeyAwareChannelProvider() { + AtomicBoolean factoryCalled = new AtomicBoolean(false); + ChannelEndpointCacheFactory factory = + baseProvider -> { + factoryCalled.set(true); + return new GrpcChannelEndpointCache(baseProvider); + }; + + try { + SpannerOptions.useEnvironment( + new SpannerOptions.SpannerEnvironment() { + @Override + public boolean isEnableLocationApi() { + return false; + } + }); + SpannerOptions options = + createSpannerOptions().toBuilder().setChannelEndpointCacheFactory(factory).build(); + GapicSpannerRpc rpc = new GapicSpannerRpc(options, true); + rpc.shutdown(); + assertFalse(factoryCalled.get()); + } finally { + SpannerOptions.useDefaultEnvironment(); } } @@ -1117,12 +1125,4 @@ private SpannerOptions createSpannerOptions() { .setCallCredentialsProvider(() -> MoreCallCredentials.from(VARIABLE_CREDENTIALS)) .build(); } - - private boolean isJava8() { - return JavaVersionUtil.getJavaMajorVersion() == 8; - } - - private boolean isWindows() { - return System.getProperty("os.name").toLowerCase().contains("windows"); - } } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java new file mode 100644 index 0000000000..2b54f71329 --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java @@ -0,0 +1,405 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.spi.v1; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider; +import com.google.protobuf.ByteString; +import com.google.protobuf.Empty; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.CommitResponse; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.ResultSet; +import com.google.spanner.v1.RollbackRequest; +import com.google.spanner.v1.SpannerGrpc; +import com.google.spanner.v1.Transaction; +import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class KeyAwareChannelTest { + private static final String DEFAULT_ADDRESS = "default:1234"; + private static final String SESSION = + "projects/p/instances/i/databases/d/sessions/test-session-id"; + + @Test + public void cancelBeforeStartPreservesTrailersAndSkipsDelegateCreation() throws Exception { + TestHarness harness = createHarness(); + ClientCall call = + harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), CallOptions.DEFAULT); + + Metadata causeTrailers = new Metadata(); + Metadata.Key key = Metadata.Key.of("debug", Metadata.ASCII_STRING_MARSHALLER); + causeTrailers.put(key, "timeout"); + RuntimeException cause = + Status.DEADLINE_EXCEEDED + .withDescription("server timeout") + .asRuntimeException(causeTrailers); + + call.cancel("cancelled by client", cause); + CapturingListener listener = new CapturingListener<>(); + call.start(listener, new Metadata()); + + assertThat(harness.defaultManagedChannel.callCount()).isEqualTo(0); + assertThat(listener.closeCount).isEqualTo(1); + assertThat(listener.closedStatus.getCode()).isEqualTo(Status.Code.CANCELLED); + assertThat(listener.closedStatus.getDescription()).isEqualTo("cancelled by client"); + assertThat(listener.closedTrailers.get(key)).isEqualTo("timeout"); + } + + @Test + public void cancelAfterStartBeforeSendSkipsDelegateCreation() throws Exception { + TestHarness harness = createHarness(); + ClientCall call = + harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), CallOptions.DEFAULT); + + CapturingListener listener = new CapturingListener<>(); + call.start(listener, new Metadata()); + call.cancel("cancel", null); + call.sendMessage(ExecuteSqlRequest.newBuilder().setSession(SESSION).build()); + + assertThat(harness.defaultManagedChannel.callCount()).isEqualTo(0); + assertThat(listener.closeCount).isEqualTo(1); + assertThat(listener.closedStatus.getCode()).isEqualTo(Status.Code.CANCELLED); + } + + @Test + public void cancelAfterDelegateCreationDelegatesToUnderlyingCall() throws Exception { + TestHarness harness = createHarness(); + ClientCall call = + harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), CallOptions.DEFAULT); + + CapturingListener listener = new CapturingListener<>(); + call.start(listener, new Metadata()); + call.sendMessage(ExecuteSqlRequest.newBuilder().setSession(SESSION).build()); + + @SuppressWarnings("unchecked") + RecordingClientCall delegate = + (RecordingClientCall) + harness.defaultManagedChannel.latestCall(); + + RuntimeException cause = new RuntimeException("boom"); + call.cancel("cancel now", cause); + + assertThat(delegate.cancelCalled).isTrue(); + assertThat(delegate.cancelMessage).isEqualTo("cancel now"); + assertThat(delegate.cancelCause).isSameInstanceAs(cause); + assertThat(listener.closeCount).isEqualTo(0); + } + + @Test + public void sendMessageBeforeStartThrows() throws Exception { + TestHarness harness = createHarness(); + ClientCall call = + harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), CallOptions.DEFAULT); + + assertThrows( + IllegalStateException.class, + () -> call.sendMessage(ExecuteSqlRequest.newBuilder().setSession(SESSION).build())); + } + + @Test + public void deadlineExceededFromDelegateIsForwardedToListener() throws Exception { + TestHarness harness = createHarness(); + ClientCall call = + harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), CallOptions.DEFAULT); + CapturingListener listener = new CapturingListener<>(); + + call.start(listener, new Metadata()); + call.sendMessage(ExecuteSqlRequest.newBuilder().setSession(SESSION).build()); + + @SuppressWarnings("unchecked") + RecordingClientCall delegate = + (RecordingClientCall) + harness.defaultManagedChannel.latestCall(); + + Metadata trailers = new Metadata(); + Metadata.Key key = Metadata.Key.of("timeout", Metadata.ASCII_STRING_MARSHALLER); + trailers.put(key, "true"); + Status status = Status.DEADLINE_EXCEEDED.withDescription("rpc timeout"); + delegate.emitOnClose(status, trailers); + + assertThat(listener.closeCount).isEqualTo(1); + assertThat(listener.closedStatus).isEqualTo(status); + assertThat(listener.closedTrailers.get(key)).isEqualTo("true"); + } + + @Test + public void timeoutOnCommitClearsTransactionAffinity() throws Exception { + TestHarness harness = createHarness(); + ByteString transactionId = ByteString.copyFromUtf8("tx-1"); + + ClientCall beginCall = + harness.channel.newCall(SpannerGrpc.getBeginTransactionMethod(), CallOptions.DEFAULT); + beginCall.start(new CapturingListener(), new Metadata()); + beginCall.sendMessage(BeginTransactionRequest.newBuilder().setSession(SESSION).build()); + + @SuppressWarnings("unchecked") + RecordingClientCall beginDelegate = + (RecordingClientCall) + harness.defaultManagedChannel.latestCall(); + beginDelegate.emitOnMessage(Transaction.newBuilder().setId(transactionId).build()); + beginDelegate.emitOnClose(Status.OK, new Metadata()); + + ClientCall commitCall = + harness.channel.newCall(SpannerGrpc.getCommitMethod(), CallOptions.DEFAULT); + commitCall.start(new CapturingListener(), new Metadata()); + commitCall.sendMessage( + CommitRequest.newBuilder().setSession(SESSION).setTransactionId(transactionId).build()); + + assertThat(harness.endpointCache.getCount(DEFAULT_ADDRESS)).isEqualTo(1); + + @SuppressWarnings("unchecked") + RecordingClientCall commitDelegate = + (RecordingClientCall) + harness.defaultManagedChannel.latestCall(); + commitDelegate.emitOnClose(Status.DEADLINE_EXCEEDED, new Metadata()); + + ClientCall rollbackCall = + harness.channel.newCall(SpannerGrpc.getRollbackMethod(), CallOptions.DEFAULT); + rollbackCall.start(new CapturingListener(), new Metadata()); + rollbackCall.sendMessage( + RollbackRequest.newBuilder().setSession(SESSION).setTransactionId(transactionId).build()); + + assertThat(harness.endpointCache.getCount(DEFAULT_ADDRESS)).isEqualTo(1); + } + + private static TestHarness createHarness() throws IOException { + FakeEndpointCache endpointCache = new FakeEndpointCache(DEFAULT_ADDRESS); + InstantiatingGrpcChannelProvider provider = + InstantiatingGrpcChannelProvider.newBuilder().setEndpoint("localhost:9999").build(); + KeyAwareChannel channel = KeyAwareChannel.create(provider, baseProvider -> endpointCache); + return new TestHarness(channel, endpointCache, endpointCache.defaultManagedChannel()); + } + + private static final class TestHarness { + private final KeyAwareChannel channel; + private final FakeEndpointCache endpointCache; + private final FakeManagedChannel defaultManagedChannel; + + private TestHarness( + KeyAwareChannel channel, + FakeEndpointCache endpointCache, + FakeManagedChannel defaultManagedChannel) { + this.channel = channel; + this.endpointCache = endpointCache; + this.defaultManagedChannel = defaultManagedChannel; + } + } + + private static final class CapturingListener extends ClientCall.Listener { + private int closeCount; + @Nullable private Status closedStatus; + @Nullable private Metadata closedTrailers; + + @Override + public void onClose(Status status, Metadata trailers) { + this.closeCount++; + this.closedStatus = status; + this.closedTrailers = trailers; + } + } + + private static final class FakeEndpointCache implements ChannelEndpointCache { + private final String defaultAddress; + private final FakeEndpoint defaultEndpoint; + private final Map endpoints = new HashMap<>(); + private final Map getCount = new HashMap<>(); + + private FakeEndpointCache(String defaultAddress) { + this.defaultAddress = defaultAddress; + this.defaultEndpoint = new FakeEndpoint(defaultAddress); + } + + @Override + public ChannelEndpoint defaultChannel() { + return defaultEndpoint; + } + + @Override + public ChannelEndpoint get(String address) { + getCount.put(address, getCount.getOrDefault(address, 0) + 1); + if (defaultAddress.equals(address)) { + return defaultEndpoint; + } + return endpoints.computeIfAbsent(address, FakeEndpoint::new); + } + + @Override + public void evict(String address) { + endpoints.remove(address); + } + + @Override + public void shutdown() { + defaultEndpoint.channel.shutdown(); + for (FakeEndpoint endpoint : endpoints.values()) { + endpoint.channel.shutdown(); + } + endpoints.clear(); + } + + int getCount(String address) { + return getCount.getOrDefault(address, 0); + } + + FakeManagedChannel defaultManagedChannel() { + return defaultEndpoint.channel; + } + } + + private static final class FakeEndpoint implements ChannelEndpoint { + private final String address; + private final FakeManagedChannel channel; + + private FakeEndpoint(String address) { + this.address = address; + this.channel = new FakeManagedChannel(address); + } + + @Override + public String getAddress() { + return address; + } + + @Override + public boolean isHealthy() { + return true; + } + + @Override + public ManagedChannel getChannel() { + return channel; + } + } + + private static final class FakeManagedChannel extends ManagedChannel { + private final String authority; + private final List> calls = new ArrayList<>(); + private boolean shutdown; + + private FakeManagedChannel(String authority) { + this.authority = authority; + } + + @Override + public ManagedChannel shutdown() { + shutdown = true; + return this; + } + + @Override + public ManagedChannel shutdownNow() { + shutdown = true; + return this; + } + + @Override + public boolean isShutdown() { + return shutdown; + } + + @Override + public boolean isTerminated() { + return shutdown; + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) { + return shutdown; + } + + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, CallOptions callOptions) { + RecordingClientCall call = new RecordingClientCall<>(); + calls.add(call); + return call; + } + + @Override + public String authority() { + return authority; + } + + int callCount() { + return calls.size(); + } + + RecordingClientCall latestCall() { + return calls.get(calls.size() - 1); + } + } + + private static final class RecordingClientCall + extends ClientCall { + @Nullable private ClientCall.Listener listener; + private boolean cancelCalled; + @Nullable private String cancelMessage; + @Nullable private Throwable cancelCause; + + @Override + public void start(ClientCall.Listener responseListener, Metadata headers) { + this.listener = responseListener; + } + + @Override + public void request(int numMessages) {} + + @Override + public void cancel(@Nullable String message, @Nullable Throwable cause) { + this.cancelCalled = true; + this.cancelMessage = message; + this.cancelCause = cause; + } + + @Override + public void halfClose() {} + + @Override + public void sendMessage(RequestT message) {} + + void emitOnMessage(ResponseT response) { + if (listener != null) { + listener.onMessage(response); + } + } + + void emitOnClose(Status status, Metadata trailers) { + if (listener != null) { + listener.onClose(status, trailers); + } + } + } +} From 7e2d482f3f30ee14393253f3ffc02d820257787f Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Mon, 9 Feb 2026 15:16:14 +0530 Subject: [PATCH 6/6] add test for executesSql update cache --- .../cloud/spanner/spi/v1/KeyAwareChannel.java | 259 +++++++++++------- .../spanner/spi/v1/KeyAwareChannelTest.java | 86 ++++++ 2 files changed, 244 insertions(+), 101 deletions(-) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java index 4fcf88130d..3014b9db46 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java @@ -250,6 +250,7 @@ static final class KeyAwareClientCall @Nullable private Boolean pendingMessageCompression; @Nullable private io.grpc.Status cancelledStatus; @Nullable private Metadata cancelledTrailers; + private final Object lock = new Object(); KeyAwareClientCall( KeyAwareChannel parentChannel, @@ -262,158 +263,214 @@ static final class KeyAwareClientCall @Override protected ClientCall delegate() { - if (delegate == null) { - throw new IllegalStateException( - "Delegate call not initialized before use. sendMessage was likely not called."); + synchronized (lock) { + if (delegate == null) { + throw new IllegalStateException( + "Delegate call not initialized before use. sendMessage was likely not called."); + } + return delegate; } - return delegate; } @Override public void start(Listener responseListener, Metadata headers) { - this.responseListener = new KeyAwareClientCallListener<>(responseListener, this); - this.headers = headers; - if (this.cancelledStatus != null) { - this.responseListener.onClose( - this.cancelledStatus, - this.cancelledTrailers == null ? new Metadata() : this.cancelledTrailers); + Listener listenerToClose = null; + io.grpc.Status statusToClose = null; + Metadata trailersToClose = null; + synchronized (lock) { + this.responseListener = new KeyAwareClientCallListener<>(responseListener, this); + this.headers = headers; + if (this.cancelledStatus != null) { + listenerToClose = this.responseListener; + statusToClose = this.cancelledStatus; + trailersToClose = + this.cancelledTrailers == null ? new Metadata() : this.cancelledTrailers; + } + } + if (listenerToClose != null) { + listenerToClose.onClose(statusToClose, trailersToClose); } } @Override @SuppressWarnings("unchecked") public void sendMessage(RequestT message) { - if (this.cancelledStatus != null) { - return; - } - if (responseListener == null || headers == null) { - throw new IllegalStateException("start must be called before sendMessage"); - } - ChannelEndpoint endpoint = null; - ChannelFinder finder = null; - - if (message instanceof ReadRequest) { - ReadRequest.Builder reqBuilder = ((ReadRequest) message).toBuilder(); - RoutingDecision routing = routeFromRequest(reqBuilder); - finder = routing.finder; - endpoint = routing.endpoint; - message = (RequestT) reqBuilder.build(); - } else if (message instanceof ExecuteSqlRequest) { - ExecuteSqlRequest.Builder reqBuilder = ((ExecuteSqlRequest) message).toBuilder(); - RoutingDecision routing = routeFromRequest(reqBuilder); - finder = routing.finder; - endpoint = routing.endpoint; - message = (RequestT) reqBuilder.build(); - } else if (message instanceof BeginTransactionRequest) { - BeginTransactionRequest.Builder reqBuilder = - ((BeginTransactionRequest) message).toBuilder(); - String databaseId = parentChannel.extractDatabaseIdFromSession(reqBuilder.getSession()); - if (databaseId != null && reqBuilder.hasMutationKey()) { - finder = parentChannel.getOrCreateChannelFinder(databaseId); - endpoint = finder.findServer(reqBuilder); + synchronized (lock) { + if (this.cancelledStatus != null) { + return; } - allowDefaultAffinity = true; - message = (RequestT) reqBuilder.build(); - } else if (message instanceof CommitRequest) { - CommitRequest request = (CommitRequest) message; - if (!request.getTransactionId().isEmpty()) { - endpoint = parentChannel.affinityEndpoint(request.getTransactionId()); - transactionIdToClear = request.getTransactionId(); + if (responseListener == null || headers == null) { + throw new IllegalStateException("start must be called before sendMessage"); } - } else if (message instanceof RollbackRequest) { - RollbackRequest request = (RollbackRequest) message; - if (!request.getTransactionId().isEmpty()) { - endpoint = parentChannel.affinityEndpoint(request.getTransactionId()); - transactionIdToClear = request.getTransactionId(); + ChannelEndpoint endpoint = null; + ChannelFinder finder = null; + + if (message instanceof ReadRequest) { + ReadRequest.Builder reqBuilder = ((ReadRequest) message).toBuilder(); + RoutingDecision routing = routeFromRequest(reqBuilder); + finder = routing.finder; + endpoint = routing.endpoint; + message = (RequestT) reqBuilder.build(); + } else if (message instanceof ExecuteSqlRequest) { + ExecuteSqlRequest.Builder reqBuilder = ((ExecuteSqlRequest) message).toBuilder(); + RoutingDecision routing = routeFromRequest(reqBuilder); + finder = routing.finder; + endpoint = routing.endpoint; + message = (RequestT) reqBuilder.build(); + } else if (message instanceof BeginTransactionRequest) { + BeginTransactionRequest.Builder reqBuilder = + ((BeginTransactionRequest) message).toBuilder(); + String databaseId = parentChannel.extractDatabaseIdFromSession(reqBuilder.getSession()); + if (databaseId != null && reqBuilder.hasMutationKey()) { + finder = parentChannel.getOrCreateChannelFinder(databaseId); + endpoint = finder.findServer(reqBuilder); + } + allowDefaultAffinity = true; + message = (RequestT) reqBuilder.build(); + } else if (message instanceof CommitRequest) { + CommitRequest request = (CommitRequest) message; + if (!request.getTransactionId().isEmpty()) { + endpoint = parentChannel.affinityEndpoint(request.getTransactionId()); + transactionIdToClear = request.getTransactionId(); + } + } else if (message instanceof RollbackRequest) { + RollbackRequest request = (RollbackRequest) message; + if (!request.getTransactionId().isEmpty()) { + endpoint = parentChannel.affinityEndpoint(request.getTransactionId()); + transactionIdToClear = request.getTransactionId(); + } + } else { + throw new IllegalStateException( + "Only read, query, begin transaction, commit, and rollback requests are supported for" + + " key-aware calls."); } - } else { - throw new IllegalStateException( - "Only read, query, begin transaction, commit, and rollback requests are supported for" - + " key-aware calls."); - } - if (endpoint == null) { - endpoint = parentChannel.endpointCache.defaultChannel(); - } - selectedEndpoint = endpoint; - this.channelFinder = finder; + if (endpoint == null) { + endpoint = parentChannel.endpointCache.defaultChannel(); + } + selectedEndpoint = endpoint; + this.channelFinder = finder; - delegate = endpoint.getChannel().newCall(methodDescriptor, callOptions); - if (pendingMessageCompression != null) { - delegate.setMessageCompression(pendingMessageCompression); - pendingMessageCompression = null; - } - delegate.start(responseListener, headers); - drainPendingRequests(); - delegate.sendMessage(message); - if (pendingHalfClose) { - delegate.halfClose(); + delegate = endpoint.getChannel().newCall(methodDescriptor, callOptions); + if (pendingMessageCompression != null) { + delegate.setMessageCompression(pendingMessageCompression); + pendingMessageCompression = null; + } + delegate.start(responseListener, headers); + drainPendingRequests(); + delegate.sendMessage(message); + if (pendingHalfClose) { + delegate.halfClose(); + } } } @Override public void halfClose() { - if (delegate != null) { - delegate.halfClose(); - } else { - pendingHalfClose = true; + ClientCall currentDelegate; + synchronized (lock) { + if (this.cancelledStatus != null) { + return; + } + if (delegate == null) { + pendingHalfClose = true; + return; + } + currentDelegate = delegate; } + currentDelegate.halfClose(); } @Override public void cancel(@Nullable String message, @Nullable Throwable cause) { - if (delegate != null) { - delegate.cancel(message, cause); - } else { - cancelledStatus = io.grpc.Status.CANCELLED.withDescription(message).withCause(cause); - Metadata trailers = - cause == null ? new Metadata() : io.grpc.Status.trailersFromThrowable(cause); - cancelledTrailers = trailers == null ? new Metadata() : trailers; - if (responseListener != null) { - responseListener.onClose(cancelledStatus, cancelledTrailers); + ClientCall currentDelegate; + Listener listenerToClose = null; + io.grpc.Status statusToClose = null; + Metadata trailersToClose = null; + synchronized (lock) { + currentDelegate = delegate; + if (currentDelegate == null) { + cancelledStatus = io.grpc.Status.CANCELLED.withDescription(message).withCause(cause); + Metadata trailers = + cause == null ? new Metadata() : io.grpc.Status.trailersFromThrowable(cause); + cancelledTrailers = trailers == null ? new Metadata() : trailers; + if (responseListener != null) { + listenerToClose = responseListener; + statusToClose = cancelledStatus; + trailersToClose = cancelledTrailers; + } } } + if (currentDelegate != null) { + currentDelegate.cancel(message, cause); + } else if (listenerToClose != null) { + listenerToClose.onClose(statusToClose, trailersToClose); + } } @Override public void request(int numMessages) { - if (delegate != null) { - delegate.request(numMessages); - return; - } - if (numMessages <= 0) { - return; - } - long updated = pendingRequests + numMessages; - if (updated < 0L) { - updated = Long.MAX_VALUE; + ClientCall currentDelegate; + synchronized (lock) { + if (cancelledStatus != null) { + return; + } + if (delegate != null) { + currentDelegate = delegate; + } else { + if (numMessages <= 0) { + return; + } + long updated = pendingRequests + numMessages; + if (updated < 0L) { + updated = Long.MAX_VALUE; + } + pendingRequests = updated; + return; + } } - pendingRequests = updated; + currentDelegate.request(numMessages); } @Override public boolean isReady() { - if (delegate == null) { + ClientCall currentDelegate; + synchronized (lock) { + currentDelegate = delegate; + } + if (currentDelegate == null) { return false; } - return delegate.isReady(); + return currentDelegate.isReady(); } @Override public void setMessageCompression(boolean enabled) { - if (delegate != null) { - delegate.setMessageCompression(enabled); - } else { - pendingMessageCompression = enabled; + ClientCall currentDelegate; + synchronized (lock) { + if (cancelledStatus != null) { + return; + } + if (delegate != null) { + currentDelegate = delegate; + } else { + pendingMessageCompression = enabled; + return; + } } + currentDelegate.setMessageCompression(enabled); } private void drainPendingRequests() { + ClientCall currentDelegate = delegate; + if (currentDelegate == null) { + return; + } long requests = pendingRequests; pendingRequests = 0L; while (requests > 0) { int batch = requests > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requests; - delegate.request(batch); + currentDelegate.request(batch); requests -= batch; } } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java index 2b54f71329..21c27604c9 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java @@ -23,12 +23,17 @@ import com.google.protobuf.ByteString; import com.google.protobuf.Empty; import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CacheUpdate; import com.google.spanner.v1.CommitRequest; import com.google.spanner.v1.CommitResponse; import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.Group; +import com.google.spanner.v1.Range; import com.google.spanner.v1.ResultSet; import com.google.spanner.v1.RollbackRequest; +import com.google.spanner.v1.RoutingHint; import com.google.spanner.v1.SpannerGrpc; +import com.google.spanner.v1.Tablet; import com.google.spanner.v1.Transaction; import io.grpc.CallOptions; import io.grpc.ClientCall; @@ -195,6 +200,75 @@ public void timeoutOnCommitClearsTransactionAffinity() throws Exception { assertThat(harness.endpointCache.getCount(DEFAULT_ADDRESS)).isEqualTo(1); } + @Test + public void requestAfterCancelBeforeSendIsIgnored() throws Exception { + TestHarness harness = createHarness(); + ClientCall call = + harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), CallOptions.DEFAULT); + + CapturingListener listener = new CapturingListener<>(); + call.start(listener, new Metadata()); + call.cancel("cancel", null); + call.request(10); + call.sendMessage(ExecuteSqlRequest.newBuilder().setSession(SESSION).build()); + + assertThat(harness.defaultManagedChannel.callCount()).isEqualTo(0); + assertThat(listener.closeCount).isEqualTo(1); + assertThat(listener.closedStatus.getCode()).isEqualTo(Status.Code.CANCELLED); + } + + @Test + public void resultSetCacheUpdateRoutesSubsequentRequest() throws Exception { + TestHarness harness = createHarness(); + ExecuteSqlRequest request = + ExecuteSqlRequest.newBuilder() + .setSession(SESSION) + .setRoutingHint(RoutingHint.newBuilder().setKey(bytes("a")).build()) + .build(); + + ClientCall firstCall = + harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), CallOptions.DEFAULT); + firstCall.start(new CapturingListener(), new Metadata()); + firstCall.sendMessage(request); + + @SuppressWarnings("unchecked") + RecordingClientCall firstDelegate = + (RecordingClientCall) + harness.defaultManagedChannel.latestCall(); + + CacheUpdate cacheUpdate = + CacheUpdate.newBuilder() + .setDatabaseId(7L) + .addRange( + Range.newBuilder() + .setStartKey(bytes("a")) + .setLimitKey(bytes("z")) + .setGroupUid(9L) + .setSplitId(1L) + .setGeneration(bytes("1"))) + .addGroup( + Group.newBuilder() + .setGroupUid(9L) + .setGeneration(bytes("1")) + .addTablets( + Tablet.newBuilder() + .setTabletUid(3L) + .setServerAddress("routed:1234") + .setIncarnation(bytes("1")) + .setDistance(0))) + .build(); + + firstDelegate.emitOnMessage(ResultSet.newBuilder().setCacheUpdate(cacheUpdate).build()); + + ClientCall secondCall = + harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), CallOptions.DEFAULT); + secondCall.start(new CapturingListener(), new Metadata()); + secondCall.sendMessage(request); + + assertThat(harness.endpointCache.callCountForAddress(DEFAULT_ADDRESS)).isEqualTo(1); + assertThat(harness.endpointCache.callCountForAddress("routed:1234")).isEqualTo(1); + } + private static TestHarness createHarness() throws IOException { FakeEndpointCache endpointCache = new FakeEndpointCache(DEFAULT_ADDRESS); InstantiatingGrpcChannelProvider provider = @@ -277,6 +351,14 @@ int getCount(String address) { FakeManagedChannel defaultManagedChannel() { return defaultEndpoint.channel; } + + int callCountForAddress(String address) { + if (defaultAddress.equals(address)) { + return defaultEndpoint.channel.callCount(); + } + FakeEndpoint endpoint = endpoints.get(address); + return endpoint == null ? 0 : endpoint.channel.callCount(); + } } private static final class FakeEndpoint implements ChannelEndpoint { @@ -402,4 +484,8 @@ void emitOnClose(Status status, Metadata trailers) { } } } + + private static ByteString bytes(String value) { + return ByteString.copyFromUtf8(value); + } }