diff --git a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/BytecodeGraphStatement.java b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/BytecodeGraphStatement.java index b6fe05a987c..e8e4554e81f 100644 --- a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/BytecodeGraphStatement.java +++ b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/BytecodeGraphStatement.java @@ -19,9 +19,11 @@ import com.datastax.dse.driver.api.core.graph.FluentGraphStatement; import com.datastax.oss.driver.api.core.ConsistencyLevel; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.cql.Statement; import com.datastax.oss.driver.api.core.metadata.Node; +import edu.umd.cs.findbugs.annotations.NonNull; import java.nio.ByteBuffer; import java.time.Duration; import java.util.Collections; @@ -127,4 +129,10 @@ protected BytecodeGraphStatement newInstance( readConsistencyLevel, writeConsistencyLevel); } + + @NonNull + @Override + public RequestRoutingType getRequestRoutingType() { + return RequestRoutingType.REGULAR; + } } diff --git a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultBatchGraphStatement.java b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultBatchGraphStatement.java index e16287c415d..632d45e61d6 100644 --- a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultBatchGraphStatement.java +++ b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultBatchGraphStatement.java @@ -19,6 +19,7 @@ import com.datastax.dse.driver.api.core.graph.BatchGraphStatement; import com.datastax.oss.driver.api.core.ConsistencyLevel; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; @@ -151,4 +152,10 @@ protected BatchGraphStatement newInstance( public Iterator iterator() { return this.traversals.iterator(); } + + @NonNull + @Override + public RequestRoutingType getRequestRoutingType() { + return RequestRoutingType.REGULAR; + } } diff --git a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultFluentGraphStatement.java b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultFluentGraphStatement.java index 0f6f1faabbf..44fa9e41853 100644 --- a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultFluentGraphStatement.java +++ b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultFluentGraphStatement.java @@ -19,6 +19,7 @@ import com.datastax.dse.driver.api.core.graph.FluentGraphStatement; import com.datastax.oss.driver.api.core.ConsistencyLevel; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.metadata.Node; import edu.umd.cs.findbugs.annotations.NonNull; @@ -103,4 +104,10 @@ protected FluentGraphStatement newInstance( public GraphTraversal getTraversal() { return traversal; } + + @NonNull + @Override + public RequestRoutingType getRequestRoutingType() { + return RequestRoutingType.REGULAR; + } } diff --git a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultScriptGraphStatement.java b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultScriptGraphStatement.java index 71f79134237..587e1221b41 100644 --- a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultScriptGraphStatement.java +++ b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultScriptGraphStatement.java @@ -19,6 +19,7 @@ import com.datastax.dse.driver.api.core.graph.ScriptGraphStatement; import com.datastax.oss.driver.api.core.ConsistencyLevel; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.protocol.internal.util.collection.NullAllowingImmutableMap; @@ -204,4 +205,10 @@ protected ScriptGraphStatement newInstance( public String toString() { return String.format("ScriptGraphStatement['%s', params: %s]", this.script, this.queryParams); } + + @NonNull + @Override + public RequestRoutingType getRequestRoutingType() { + return RequestRoutingType.REGULAR; + } } diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingType.java b/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingType.java new file mode 100644 index 00000000000..43bffe99589 --- /dev/null +++ b/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingType.java @@ -0,0 +1,9 @@ +package com.datastax.oss.driver.api.core; + +/** The type of routing for a given request. */ +public enum RequestRoutingType { + /** A regular (non-LWT) request. */ + REGULAR, + /** A lightweight transaction (LWT) request. */ + LWT +} diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/config/DefaultDriverOption.java b/core/src/main/java/com/datastax/oss/driver/api/core/config/DefaultDriverOption.java index e651b1d999e..9e0119903df 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/config/DefaultDriverOption.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/config/DefaultDriverOption.java @@ -718,7 +718,7 @@ public enum DefaultDriverOption implements DriverOption { /** * CQL 4.x has a known issue where prepared statement invalidation may be bypassed on the client - * side. Reference: https://github.com/scylladb/scylladb/issues/20860 + * side. Reference: link * *

When this occurs, the client's metadata can become outdated, leading to various * deserialization errors. @@ -1063,7 +1063,17 @@ public enum DefaultDriverOption implements DriverOption { *

Value type: {@link java.util.List List}<{@link String}> */ LOAD_BALANCING_DC_FAILOVER_PREFERRED_REMOTE_DCS( - "advanced.load-balancing-policy.dc-failover.preferred-remote-dcs"); + "advanced.load-balancing-policy.dc-failover.preferred-remote-dcs"), + + /** + * The default routing method to use for LWT (Lightweight Transaction) requests. REGULAR uses the + * standard load balancing algorithm with slow replica avoidance and shuffling. + * PRESERVE_REPLICA_ORDER maintains the replica order from the partitioner. + * + *

Value-type: string + */ + LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD( + "advanced.load-balancing-policy.default-lwt-request-routing-method"); private final String path; diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/config/OptionsMap.java b/core/src/main/java/com/datastax/oss/driver/api/core/config/OptionsMap.java index ed95389f57b..28559ea8556 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/config/OptionsMap.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/config/OptionsMap.java @@ -393,6 +393,9 @@ protected static void fillWithDriverDefaults(OptionsMap map) { map.put(TypedDriverOption.METRICS_GENERATE_AGGREGABLE_HISTOGRAMS, true); map.put( TypedDriverOption.LOAD_BALANCING_DC_FAILOVER_PREFERRED_REMOTE_DCS, ImmutableList.of("")); + map.put( + TypedDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD, + "PRESERVE_REPLICA_ORDER"); } @Immutable diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/config/TypedDriverOption.java b/core/src/main/java/com/datastax/oss/driver/api/core/config/TypedDriverOption.java index 1fa752783d8..818468ee9d5 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/config/TypedDriverOption.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/config/TypedDriverOption.java @@ -933,6 +933,12 @@ public String toString() { DefaultDriverOption.LOAD_BALANCING_DC_FAILOVER_PREFERRED_REMOTE_DCS, GenericType.listOf(String.class)); + /** The request routing method to use in the request routing load balancing policy. */ + public static final TypedDriverOption LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD = + new TypedDriverOption<>( + DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD, + GenericType.STRING); + private static Iterable> introspectBuiltInValues() { try { ImmutableList.Builder> result = ImmutableList.builder(); diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatement.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatement.java index e831ed62369..63afd227425 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatement.java @@ -280,13 +280,4 @@ default int computeSizeInBytes(@NonNull DriverContext context) { return size; } - - /** - * Overrides LWT state to a specific value. If unset or set to {@code null} the {@link - * Statement#isLWT()} method will infer result from the statments in the batch. - * - * @param newIsLWT new Boolean to set - * @return new BatchStatement with updated isLWT field. - */ - BatchStatement setIsLWT(Boolean newIsLWT); } diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java index 26e0aef8ca1..8e34c916ea1 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java @@ -39,7 +39,6 @@ public class BatchStatementBuilder extends StatementBuilder> statementsBuilder; private int statementsCount; - @Nullable private Boolean isLWT = null; public BatchStatementBuilder(@NonNull BatchType batchType) { this.batchType = batchType; @@ -76,19 +75,6 @@ public BatchStatementBuilder setKeyspace(@NonNull String keyspaceName) { return setKeyspace(CqlIdentifier.fromCql(keyspaceName)); } - /** - * Forces driver to see this batch as LWT or non-LWT. Note that if never explicitly set or set to - * {@code null}, the resulting {@code DefaultBatchStatement} will decide its LWT state based on - * contained statements. - * - * @return this builder; never {@code null}. - */ - @NonNull - public BatchStatementBuilder setIsLWT(Boolean newIsLWT) { - this.isLWT = newIsLWT; - return this; - } - /** * Adds a new statement to the batch. * @@ -172,7 +158,7 @@ public BatchStatement build() { timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } public int getStatementsCount() { diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BoundStatementBuilder.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BoundStatementBuilder.java index 7e8f8723e1b..58a3a2319a2 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BoundStatementBuilder.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BoundStatementBuilder.java @@ -20,6 +20,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.metadata.token.Token; import com.datastax.oss.driver.api.core.type.DataType; @@ -67,7 +68,8 @@ public BoundStatementBuilder( @Nullable ConsistencyLevel serialConsistencyLevel, @Nullable Duration timeout, @NonNull CodecRegistry codecRegistry, - @NonNull ProtocolVersion protocolVersion) { + @NonNull ProtocolVersion protocolVersion, + @Nullable RequestRoutingType requestRoutingType) { this.preparedStatement = preparedStatement; this.variableDefinitions = variableDefinitions; this.values = values; @@ -89,6 +91,7 @@ public BoundStatementBuilder( this.timeout = timeout; this.codecRegistry = codecRegistry; this.protocolVersion = protocolVersion; + this.requestRoutingType = requestRoutingType; } public BoundStatementBuilder(@NonNull BoundStatement template) { @@ -204,6 +207,7 @@ public BoundStatement build() { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } } diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/PreparedStatement.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/PreparedStatement.java index 982db8b3b41..7ad77463aed 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/PreparedStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/PreparedStatement.java @@ -25,6 +25,7 @@ import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.DefaultProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.metadata.token.Partitioner; import edu.umd.cs.findbugs.annotations.NonNull; import edu.umd.cs.findbugs.annotations.Nullable; @@ -133,6 +134,10 @@ public interface PreparedStatement { */ boolean isLWT(); + /** Returns the request routing type for this prepared statement. */ + @Nullable + RequestRoutingType getRequestRoutingType(); + /** * Updates {@link #getResultMetadataId()} and {@link #getResultSetDefinitions()} atomically. * diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/SimpleStatement.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/SimpleStatement.java index ef04cd14a5b..20f17fa716e 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/SimpleStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/SimpleStatement.java @@ -20,6 +20,7 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.DefaultProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.context.DriverContext; import com.datastax.oss.driver.api.core.session.Request; import com.datastax.oss.driver.internal.core.cql.DefaultSimpleStatement; @@ -84,7 +85,8 @@ static SimpleStatement newInstance(@NonNull String cqlQuery) { null, null, null, - Statement.NO_NOW_IN_SECONDS); + Statement.NO_NOW_IN_SECONDS, + RequestRoutingType.REGULAR); } /** @@ -118,7 +120,8 @@ static SimpleStatement newInstance( null, null, null, - Statement.NO_NOW_IN_SECONDS); + Statement.NO_NOW_IN_SECONDS, + RequestRoutingType.REGULAR); } /** @@ -149,7 +152,8 @@ static SimpleStatement newInstance( null, null, null, - Statement.NO_NOW_IN_SECONDS); + Statement.NO_NOW_IN_SECONDS, + RequestRoutingType.REGULAR); } /** diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/SimpleStatementBuilder.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/SimpleStatementBuilder.java index 1ac910ff6a7..38deffe404c 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/SimpleStatementBuilder.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/SimpleStatementBuilder.java @@ -185,6 +185,7 @@ public SimpleStatement build() { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } } diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java index 464a0a92a53..e88831e7925 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java @@ -27,6 +27,7 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.NoNodeAvailableException; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.context.DriverContext; @@ -525,6 +526,20 @@ default SelfT setNowInSeconds(int nowInSeconds) { return (SelfT) this; } + /** + * Sets the request routing type to use when applying the request (for testing purposes). + * + *

This method's default implementation returns the statement unchanged. The only reason it + * exists is to preserve binary compatibility. Internally, the driver overrides it to record the + * new value. + */ + @NonNull + @CheckReturnValue + @SuppressWarnings("unchecked") + default SelfT setRequestRoutingType(@Nullable RequestRoutingType requestRoutingType) { + return (SelfT) this; + } + /** * Informs if this is a prepared LWT query. * @@ -540,7 +555,9 @@ default SelfT setNowInSeconds(int nowInSeconds) { * * @see Docs about LWT */ - boolean isLWT(); + default boolean isLWT() { + return getRequestRoutingType() == RequestRoutingType.LWT; // treating null as non-LWT + } /** * Calculates the approximate size in bytes that the statement will have when encoded. diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/StatementBuilder.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/StatementBuilder.java index 531070b854c..9894dd9c813 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/StatementBuilder.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/StatementBuilder.java @@ -19,6 +19,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.driver.api.core.metadata.token.Token; @@ -61,6 +62,7 @@ public abstract class StatementBuilder< @Nullable protected Duration timeout; @Nullable protected Node node; protected int nowInSeconds = Statement.NO_NOW_IN_SECONDS; + @Nullable protected RequestRoutingType requestRoutingType; protected StatementBuilder() { // nothing to do @@ -87,6 +89,7 @@ protected StatementBuilder(StatementT template) { this.timeout = template.getTimeout(); this.node = template.getNode(); this.nowInSeconds = template.getNowInSeconds(); + this.requestRoutingType = template.getRequestRoutingType(); } /** @see Statement#setExecutionProfileName(String) */ @@ -282,6 +285,12 @@ public SelfT setNowInSeconds(int nowInSeconds) { return self; } + @NonNull + public SelfT setRequestRoutingType(@Nullable RequestRoutingType requestRoutingType) { + this.requestRoutingType = requestRoutingType; + return self; + } + @NonNull protected Map buildCustomPayload() { return (customPayloadBuilder == null) diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java b/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java index 92c25e146c7..c3035f2bf12 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java @@ -25,6 +25,7 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.DefaultProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverConfig; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; @@ -101,7 +102,7 @@ public interface Request { * The table to use for tablet-aware routing. Infers the table from available ColumnDefinitions or * {@code null} if it is not possible. * - * @return + * @return The table to use for tablet-aware routing, or {@code null} if not set. */ @Nullable default CqlIdentifier getRoutingTable() { @@ -199,4 +200,16 @@ default Partitioner getPartitioner() { /** @return The node configured on this statement, or null if none is configured. */ @Nullable Node getNode(); + + /** + * Returns the routing type for this request. + * + *

The value represents how the request is handled on the server side (for example, regular vs + * lightweight transaction). Load balancing policies use this signal to shape the execution plan + * (eligible coordinators and ordering). + * + * @return The routing type configured on this request + */ + @Nullable + RequestRoutingType getRequestRoutingType(); } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/Conversions.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/Conversions.java index 88f35eb75a0..0a864293b0d 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/Conversions.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/Conversions.java @@ -27,6 +27,7 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverConfig; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; @@ -427,7 +428,9 @@ public static DefaultPreparedStatement toPreparedStatement( request.areBoundStatementsTracing(), context.getCodecRegistry(), context.getProtocolVersion(), - lwtInfo != null && lwtInfo.isLwt(response.variablesMetadata.flags)); + lwtInfo != null && lwtInfo.isLwt(response.variablesMetadata.flags) + ? RequestRoutingType.LWT + : RequestRoutingType.REGULAR); } public static ColumnDefinitions toColumnDefinitions( diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandler.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandler.java index 80eece271a8..4008dd528f0 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandler.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandler.java @@ -97,11 +97,9 @@ import java.util.List; import java.util.Map; import java.util.Queue; -import java.util.Set; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; -import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -208,14 +206,6 @@ public void onThrottleReady(boolean wasDelayed) { Queue queryPlan; if (this.initialStatement.getNode() != null) { queryPlan = new SimpleQueryPlan(this.initialStatement.getNode()); - } else if (this.initialStatement.isLWT()) { - queryPlan = - getReplicas( - session.getKeyspace().orElse(null), - this.initialStatement, - context - .getLoadBalancingPolicyWrapper() - .newQueryPlan(initialStatement, executionProfile.getName(), session)); } else { queryPlan = context @@ -226,26 +216,6 @@ public void onThrottleReady(boolean wasDelayed) { sendRequest(initialStatement, null, queryPlan, 0, 0, true); } - private Queue getReplicas( - CqlIdentifier loggedKeyspace, Statement statement, Queue fallback) { - Token routingToken = getRoutingToken(statement); - CqlIdentifier keyspace = statement.getKeyspace(); - if (keyspace == null) { - keyspace = statement.getRoutingKeyspace(); - if (keyspace == null) { - keyspace = loggedKeyspace; - } - } - - TokenMap tokenMap = context.getMetadataManager().getMetadata().getTokenMap().orElse(null); - if (routingToken == null || keyspace == null || tokenMap == null) { - return fallback; - } - - Set replicas = tokenMap.getReplicas(keyspace, routingToken); - return new ConcurrentLinkedQueue<>(replicas); - } - public CompletionStage handle() { return result; } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java index c8cb5b7a084..cde8d91e4c9 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java @@ -25,6 +25,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.cql.BatchStatement; import com.datastax.oss.driver.api.core.cql.BatchType; @@ -42,6 +43,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Objects; import net.jcip.annotations.Immutable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -69,7 +71,8 @@ public class DefaultBatchStatement implements BatchStatement { private final Duration timeout; private final Node node; private final int nowInSeconds; - private final Boolean isLWT; + @Nullable private final RequestRoutingType requestRoutingType; + private RequestRoutingType cachedStatementsRequestRoutingType; public DefaultBatchStatement( BatchType batchType, @@ -91,7 +94,7 @@ public DefaultBatchStatement( Duration timeout, Node node, int nowInSeconds, - Boolean isLWT) { + @Nullable RequestRoutingType requestRoutingType) { for (BatchableStatement statement : statements) { if (statement != null && (statement.getConsistencyLevel() != null @@ -123,7 +126,7 @@ public DefaultBatchStatement( this.timeout = timeout; this.node = node; this.nowInSeconds = nowInSeconds; - this.isLWT = isLWT; + this.requestRoutingType = requestRoutingType; } @NonNull @@ -155,7 +158,7 @@ public BatchStatement setBatchType(@NonNull BatchType newBatchType) { timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } @NonNull @@ -181,7 +184,7 @@ public BatchStatement setKeyspace(@Nullable CqlIdentifier newKeyspace) { timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } @NonNull @@ -211,7 +214,7 @@ public BatchStatement add(@NonNull BatchableStatement statement) { timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } } @@ -245,7 +248,7 @@ public BatchStatement addAll(@NonNull Iterable> timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } } @@ -277,7 +280,7 @@ public BatchStatement clear() { timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } @NonNull @@ -314,7 +317,7 @@ public BatchStatement setPagingState(ByteBuffer newPagingState) { timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } @Override @@ -345,7 +348,7 @@ public BatchStatement setPageSize(int newPageSize) { timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } @Nullable @@ -377,7 +380,7 @@ public BatchStatement setConsistencyLevel(@Nullable ConsistencyLevel newConsiste timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } @Nullable @@ -410,7 +413,7 @@ public BatchStatement setSerialConsistencyLevel( timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } @Override @@ -441,7 +444,7 @@ public BatchStatement setExecutionProfileName(@Nullable String newConfigProfileN timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } @Override @@ -472,7 +475,7 @@ public DefaultBatchStatement setExecutionProfile(@Nullable DriverExecutionProfil timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } @Override @@ -538,7 +541,7 @@ public BatchStatement setRoutingKeyspace(CqlIdentifier newRoutingKeyspace) { timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } @NonNull @@ -564,7 +567,7 @@ public BatchStatement setNode(@Nullable Node newNode) { timeout, newNode, nowInSeconds, - isLWT); + requestRoutingType); } @Nullable @@ -611,7 +614,7 @@ public BatchStatement setRoutingKey(ByteBuffer newRoutingKey) { timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } @Override @@ -652,7 +655,7 @@ public BatchStatement setRoutingToken(Token newRoutingToken) { timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } @NonNull @@ -684,7 +687,7 @@ public DefaultBatchStatement setCustomPayload(@NonNull Map n timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } @Override @@ -721,7 +724,7 @@ public DefaultBatchStatement setIdempotent(Boolean newIdempotence) { timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } @Override @@ -752,7 +755,7 @@ public BatchStatement setTracing(boolean newTracing) { timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } @Override @@ -783,7 +786,7 @@ public BatchStatement setQueryTimestamp(long newTimestamp) { timeout, node, nowInSeconds, - isLWT); + requestRoutingType); } @NonNull @@ -809,7 +812,7 @@ public BatchStatement setTimeout(@Nullable Duration newTimeout) { newTimeout, node, nowInSeconds, - isLWT); + requestRoutingType); } @Override @@ -840,12 +843,36 @@ public BatchStatement setNowInSeconds(int newNowInSeconds) { timeout, node, newNowInSeconds, - isLWT); + requestRoutingType); + } + + /** + * Returns the request routing type for this batch statement based on {@link + * DefaultBatchStatement#isLWT()} implementation while maintaining non-null contract. + * + * @return the request routing type, never null + */ + @Nullable + @Override + public RequestRoutingType getRequestRoutingType() { + if (Objects.nonNull(requestRoutingType)) { + return requestRoutingType; + } else if (Objects.isNull( + cachedStatementsRequestRoutingType)) { // Immutability of the statement list and statements + // allows us to cache the result + cachedStatementsRequestRoutingType = + statements.stream() + .map(Statement::getRequestRoutingType) + .filter((rt) -> Objects.nonNull(rt) && rt == RequestRoutingType.LWT) + .findFirst() + .orElse(RequestRoutingType.REGULAR); + } + return cachedStatementsRequestRoutingType; } @NonNull @Override - public BatchStatement setIsLWT(Boolean newIsLWT) { + public BatchStatement setRequestRoutingType(RequestRoutingType requestRoutingType) { return new DefaultBatchStatement( batchType, statements, @@ -866,12 +893,6 @@ public BatchStatement setIsLWT(Boolean newIsLWT) { timeout, node, nowInSeconds, - newIsLWT); - } - - @Override - public boolean isLWT() { - if (isLWT != null) return isLWT; - return statements.stream().anyMatch(Statement::isLWT); + requestRoutingType); } } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java index 05673692ce9..2c3ad902f39 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java @@ -26,6 +26,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.cql.BoundStatement; import com.datastax.oss.driver.api.core.cql.ColumnDefinitions; @@ -43,6 +44,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Objects; import net.jcip.annotations.Immutable; @Immutable @@ -69,6 +71,7 @@ public class DefaultBoundStatement implements BoundStatement { private final ProtocolVersion protocolVersion; private final Node node; private final int nowInSeconds; + @Nullable private final RequestRoutingType requestRoutingType; public DefaultBoundStatement( PreparedStatement preparedStatement, @@ -91,7 +94,8 @@ public DefaultBoundStatement( CodecRegistry codecRegistry, ProtocolVersion protocolVersion, Node node, - int nowInSeconds) { + int nowInSeconds, + @Nullable RequestRoutingType requestRoutingType) { this.preparedStatement = preparedStatement; this.variableDefinitions = variableDefinitions; this.values = values; @@ -113,6 +117,7 @@ public DefaultBoundStatement( this.protocolVersion = protocolVersion; this.node = node; this.nowInSeconds = nowInSeconds; + this.requestRoutingType = requestRoutingType; } @Override @@ -207,7 +212,8 @@ public BoundStatement setBytesUnsafe(int i, ByteBuffer v) { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @NonNull @@ -251,7 +257,8 @@ public BoundStatement setExecutionProfileName(@Nullable String newConfigProfileN codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Override @@ -283,7 +290,8 @@ public BoundStatement setExecutionProfile(@Nullable DriverExecutionProfile newPr codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Override @@ -333,7 +341,8 @@ public BoundStatement setRoutingKeyspace(@Nullable CqlIdentifier newRoutingKeysp codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @NonNull @@ -360,7 +369,8 @@ public BoundStatement setNode(@Nullable Node newNode) { codecRegistry, protocolVersion, newNode, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -420,7 +430,8 @@ public BoundStatement setRoutingKey(@Nullable ByteBuffer newRoutingKey) { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Override @@ -452,7 +463,8 @@ public BoundStatement setRoutingToken(@Nullable Token newRoutingToken) { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @NonNull @@ -485,7 +497,8 @@ public BoundStatement setCustomPayload(@NonNull Map newCusto codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Override @@ -517,7 +530,8 @@ public BoundStatement setIdempotent(@Nullable Boolean newIdempotence) { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Override @@ -549,7 +563,8 @@ public BoundStatement setTracing(boolean newTracing) { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Override @@ -581,7 +596,8 @@ public BoundStatement setQueryTimestamp(long newTimestamp) { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -614,7 +630,8 @@ public BoundStatement setTimeout(@Nullable Duration newTimeout) { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Override @@ -646,7 +663,8 @@ public BoundStatement setPagingState(@Nullable ByteBuffer newPagingState) { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Override @@ -678,7 +696,8 @@ public BoundStatement setPageSize(int newPageSize) { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -711,7 +730,8 @@ public BoundStatement setConsistencyLevel(@Nullable ConsistencyLevel newConsiste codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -745,7 +765,8 @@ public BoundStatement setSerialConsistencyLevel( codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Override @@ -777,7 +798,44 @@ public BoundStatement setNowInSeconds(int newNowInSeconds) { codecRegistry, protocolVersion, node, - newNowInSeconds); + newNowInSeconds, + requestRoutingType); + } + + @Nullable + @Override + public RequestRoutingType getRequestRoutingType() { + return Objects.nonNull(requestRoutingType) + ? requestRoutingType + : preparedStatement.getRequestRoutingType(); + } + + @NonNull + @Override + public BoundStatement setRequestRoutingType(@Nullable RequestRoutingType requestRoutingType) { + return new DefaultBoundStatement( + preparedStatement, + variableDefinitions, + values, + executionProfileName, + executionProfile, + routingKeyspace, + routingKey, + routingToken, + customPayload, + idempotent, + tracing, + timestamp, + pagingState, + pageSize, + consistencyLevel, + serialConsistencyLevel, + timeout, + codecRegistry, + protocolVersion, + node, + nowInSeconds, + requestRoutingType); } @Override diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPrepareRequest.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPrepareRequest.java index 7f87dbe5b51..019b56dbb1f 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPrepareRequest.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPrepareRequest.java @@ -20,6 +20,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.cql.PrepareRequest; import com.datastax.oss.driver.api.core.cql.SimpleStatement; @@ -197,6 +198,12 @@ public Node getNode() { return null; } + @NonNull + @Override + public RequestRoutingType getRequestRoutingType() { + return RequestRoutingType.REGULAR; + } + @Override public boolean areBoundStatementsTracing() { return statement.isTracing(); diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java index dace3647645..754a89ac228 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java @@ -27,6 +27,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.cql.BoundStatement; @@ -47,6 +48,7 @@ import com.datastax.oss.driver.internal.core.session.RepreparePayload; import com.datastax.oss.driver.shaded.guava.common.base.Splitter; import edu.umd.cs.findbugs.annotations.NonNull; +import edu.umd.cs.findbugs.annotations.Nullable; import java.nio.ByteBuffer; import java.time.Duration; import java.util.List; @@ -82,7 +84,7 @@ public class DefaultPreparedStatement implements PreparedStatement { private final ConsistencyLevel serialConsistencyLevelForBoundStatements; private final Duration timeoutForBoundStatements; private final Partitioner partitioner; - private final boolean isLWT; + @Nullable private final RequestRoutingType requestRoutingType; private volatile boolean skipMetadata; public DefaultPreparedStatement( @@ -110,7 +112,7 @@ public DefaultPreparedStatement( boolean areBoundStatementsTracing, CodecRegistry codecRegistry, ProtocolVersion protocolVersion, - boolean isLWT) { + @Nullable RequestRoutingType requestRoutingType) { this.id = id; this.partitionKeyIndices = partitionKeyIndices; // It's important that we keep a reference to this object, so that it only gets evicted from @@ -136,7 +138,7 @@ public DefaultPreparedStatement( this.codecRegistry = codecRegistry; this.protocolVersion = protocolVersion; - this.isLWT = isLWT; + this.requestRoutingType = requestRoutingType; this.skipMetadata = resolveSkipMetadata( query, resultMetadataId, resultSetDefinitions, this.executionProfileForBoundStatements); @@ -188,7 +190,13 @@ public ColumnDefinitions getResultSetDefinitions() { @Override public boolean isLWT() { - return isLWT; + return requestRoutingType == RequestRoutingType.LWT; + } + + @Nullable + @Override + public RequestRoutingType getRequestRoutingType() { + return requestRoutingType; } @Override @@ -229,7 +237,8 @@ public BoundStatement bind(@NonNull Object... values) { codecRegistry, protocolVersion, null, - Statement.NO_NOW_IN_SECONDS); + Statement.NO_NOW_IN_SECONDS, + requestRoutingType); } @NonNull @@ -255,7 +264,8 @@ public BoundStatementBuilder boundStatementBuilder(@NonNull Object... values) { serialConsistencyLevelForBoundStatements, timeoutForBoundStatements, codecRegistry, - protocolVersion); + protocolVersion, + requestRoutingType); } public RepreparePayload getRepreparePayload() { @@ -263,8 +273,8 @@ public RepreparePayload getRepreparePayload() { } private static class ResultMetadata { - private ByteBuffer resultMetadataId; - private ColumnDefinitions resultSetDefinitions; + private final ByteBuffer resultMetadataId; + private final ColumnDefinitions resultSetDefinitions; private ResultMetadata(ByteBuffer resultMetadataId, ColumnDefinitions resultSetDefinitions) { this.resultMetadataId = resultMetadataId; diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java index 0af32b988fe..0268689d86f 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java @@ -25,6 +25,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.cql.SimpleStatement; import com.datastax.oss.driver.api.core.metadata.Node; @@ -64,6 +65,7 @@ public class DefaultSimpleStatement implements SimpleStatement { private final Duration timeout; private final Node node; private final int nowInSeconds; + @Nullable private final RequestRoutingType requestRoutingType; /** @see SimpleStatement#builder(String) */ public DefaultSimpleStatement( @@ -86,7 +88,8 @@ public DefaultSimpleStatement( ConsistencyLevel serialConsistencyLevel, Duration timeout, Node node, - int nowInSeconds) { + int nowInSeconds, + @Nullable RequestRoutingType requestRoutingType) { if (!positionalValues.isEmpty() && !namedValues.isEmpty()) { throw new IllegalArgumentException("Can't have both positional and named values"); } @@ -110,6 +113,7 @@ public DefaultSimpleStatement( this.timeout = timeout; this.node = node; this.nowInSeconds = nowInSeconds; + this.requestRoutingType = requestRoutingType; } @NonNull @@ -141,7 +145,8 @@ public SimpleStatement setQuery(@NonNull String newQuery) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @NonNull @@ -173,7 +178,8 @@ public SimpleStatement setPositionalValues(@NonNull List newPositionalVa serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @NonNull @@ -205,7 +211,8 @@ public SimpleStatement setNamedValuesWithIds(@NonNull Map serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -237,7 +244,8 @@ public SimpleStatement setExecutionProfileName(@Nullable String newConfigProfile serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -269,7 +277,8 @@ public SimpleStatement setExecutionProfile(@Nullable DriverExecutionProfile newP serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -301,7 +310,8 @@ public SimpleStatement setKeyspace(@Nullable CqlIdentifier newKeyspace) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -333,7 +343,8 @@ public SimpleStatement setRoutingKeyspace(@Nullable CqlIdentifier newRoutingKeys serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @NonNull @@ -359,7 +370,8 @@ public SimpleStatement setNode(@Nullable Node newNode) { serialConsistencyLevel, timeout, newNode, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -397,7 +409,8 @@ public SimpleStatement setRoutingKey(@Nullable ByteBuffer newRoutingKey) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -429,7 +442,8 @@ public SimpleStatement setRoutingToken(@Nullable Token newRoutingToken) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @NonNull @@ -461,7 +475,8 @@ public SimpleStatement setCustomPayload(@NonNull Map newCust serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -493,7 +508,8 @@ public SimpleStatement setIdempotent(@Nullable Boolean newIdempotence) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Override @@ -524,7 +540,8 @@ public SimpleStatement setTracing(boolean newTracing) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Override @@ -555,7 +572,8 @@ public SimpleStatement setQueryTimestamp(long newTimestamp) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -587,7 +605,8 @@ public SimpleStatement setTimeout(@Nullable Duration newTimeout) { serialConsistencyLevel, newTimeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -619,7 +638,8 @@ public SimpleStatement setPagingState(@Nullable ByteBuffer newPagingState) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Override @@ -650,7 +670,8 @@ public SimpleStatement setPageSize(int newPageSize) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -682,7 +703,8 @@ public SimpleStatement setConsistencyLevel(@Nullable ConsistencyLevel newConsist serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -715,7 +737,8 @@ public SimpleStatement setSerialConsistencyLevel( newSerialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Override @@ -746,12 +769,46 @@ public SimpleStatement setNowInSeconds(int newNowInSeconds) { serialConsistencyLevel, timeout, node, - newNowInSeconds); + newNowInSeconds, + requestRoutingType); + } + + @Nullable + @Override + public RequestRoutingType getRequestRoutingType() { + return requestRoutingType; + } + + @NonNull + @Override + public SimpleStatement setRequestRoutingType(@Nullable RequestRoutingType requestRoutingType) { + return new DefaultSimpleStatement( + query, + positionalValues, + namedValues, + executionProfileName, + executionProfile, + keyspace, + routingKeyspace, + routingKey, + routingToken, + customPayload, + idempotent, + tracing, + timestamp, + pagingState, + pageSize, + consistencyLevel, + serialConsistencyLevel, + timeout, + node, + nowInSeconds, + requestRoutingType); } @Override public boolean isLWT() { - return false; + return requestRoutingType == RequestRoutingType.LWT; } public static Map wrapKeys(Map namedValues) { diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java b/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java index 0b78141227a..f798ff033c2 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java @@ -20,6 +20,7 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MINUTES; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.context.DriverContext; @@ -50,6 +51,7 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicLongArray; import net.jcip.annotations.ThreadSafe; +import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -68,7 +70,7 @@ * } * * - * See {@code reference.conf} (in the manual or core driver JAR) for more details. + *

See {@code reference.conf} (in the manual or core driver JAR) for more details. * *

Local datacenter: This implementation requires a local datacenter to be defined, * otherwise it will throw an {@link IllegalStateException}. A local datacenter can be supplied @@ -95,6 +97,11 @@ @ThreadSafe public class DefaultLoadBalancingPolicy extends BasicLoadBalancingPolicy implements RequestTracker { + public enum RequestRoutingMethod { + REGULAR, + PRESERVE_REPLICA_ORDER + } + private static final Logger LOG = LoggerFactory.getLogger(DefaultLoadBalancingPolicy.class); private static final long NEWLY_UP_INTERVAL_NANOS = MINUTES.toNanos(1); @@ -104,14 +111,31 @@ public class DefaultLoadBalancingPolicy extends BasicLoadBalancingPolicy impleme protected final ConcurrentMap responseTimes; protected final Map upTimes = new ConcurrentHashMap<>(); private final boolean avoidSlowReplicas; + private final RequestRoutingMethod lwtRequestRoutingMethod; public DefaultLoadBalancingPolicy(@NonNull DriverContext context, @NonNull String profileName) { super(context, profileName); this.avoidSlowReplicas = profile.getBoolean(DefaultDriverOption.LOAD_BALANCING_POLICY_SLOW_AVOIDANCE, true); + this.lwtRequestRoutingMethod = getDefaultLWTRequestRoutingMethod(); this.responseTimes = new MapMaker().weakKeys().makeMap(); } + @NonNull + private RequestRoutingMethod getDefaultLWTRequestRoutingMethod() { + String methodString = + profile.getString(DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD); + try { + return RequestRoutingMethod.valueOf(methodString.toUpperCase()); + } catch (IllegalArgumentException e) { + LOG.warn( + "[{}] Unknown request routing method '{}', defaulting to PRESERVE_REPLICA_ORDER", + logPrefix, + methodString); + return RequestRoutingMethod.PRESERVE_REPLICA_ORDER; + } + } + @NonNull @Override public Optional getRequestTracker() { @@ -128,116 +152,66 @@ protected Optional discoverLocalDc(@NonNull Map nodes) { return new MandatoryLocalDcHelper(context, profile, logPrefix).discoverLocalDc(nodes); } + @NonNull + public RequestRoutingMethod getDefaultLWTRequestRoutingMethod(@Nullable Request request) { + if (request == null) { + return RequestRoutingMethod.REGULAR; + } + if (request.getRequestRoutingType() == RequestRoutingType.LWT) { + return lwtRequestRoutingMethod; + } else { + return RequestRoutingMethod.REGULAR; + } + } + @NonNull @Override public Queue newQueryPlan(@Nullable Request request, @Nullable Session session) { - if (!avoidSlowReplicas) { - return super.newQueryPlan(request, session); + switch (getDefaultLWTRequestRoutingMethod(request)) { + case PRESERVE_REPLICA_ORDER: + return newQueryPlanPreserveReplicas(request, session); + case REGULAR: + default: + return newQueryPlanRegular(request, session); } + } - // Take a snapshot since the set is concurrent: - Object[] currentNodes = getLiveNodes().dc(getLocalDatacenter()).toArray(); + /** + * Builds a query plan that preserves the replica order as returned by the load balancing + * strategy, while pushing non-local replicas after local ones. + */ + @NonNull + public Queue newQueryPlanPreserveReplicas( + @Nullable Request request, @Nullable Session session) { + List replicas = getReplicas(request, session); + String localDc = getLocalDatacenter(); + if (localDc == null || replicas.isEmpty()) { + return new SimpleQueryPlan(replicas.toArray()); + } + + return new SimpleQueryPlan(moveNonLocalReplicasToTheEnd(replicas, localDc)); + } - List allReplicas = getReplicas(request, session); + /** + * Builds a query plan that prioritizes local replicas, shuffles them for balance, and then + * round-robins the remaining local nodes. + */ + @NonNull + public Queue newQueryPlanRegular(@Nullable Request request, @Nullable Session session) { + List replicas = getReplicas(request, session); + Object[] currentNodes = getLiveNodes().dc(getLocalDatacenter()).toArray(); int replicaCount = 0; // in currentNodes - int localRackReplicaCount = 0; // in currentNodes - String localRack = getLocalRack(); - - if (!allReplicas.isEmpty()) { - - // Move replicas to the beginning of the plan - // Replicas in local rack should precede other replicas - for (int i = 0; i < currentNodes.length; i++) { - Node node = (Node) currentNodes[i]; - if (allReplicas.contains(node)) { - if (Objects.equals(node.getRack(), localRack) - && Objects.equals(node.getDatacenter(), getLocalDatacenter())) { - ArrayUtils.bubbleUp(currentNodes, i, localRackReplicaCount); - localRackReplicaCount++; - } else { - ArrayUtils.bubbleUp(currentNodes, i, replicaCount); - } - replicaCount++; - } - } + if (!replicas.isEmpty()) { + Pair counts = moveReplicasToFront(currentNodes, replicas); + replicaCount = counts.getLeft(); + + int localRackReplicaCount = counts.getRight(); // in currentNodes if (replicaCount > 1) { - if (localRack != null && localRackReplicaCount > 0) { - // Shuffle only replicas that are in the local rack - shuffleHead(currentNodes, localRackReplicaCount); - // Shuffles only replicas that are not in local rack - shuffleInRange(currentNodes, localRackReplicaCount, replicaCount - 1); - } else { - shuffleHead(currentNodes, replicaCount); - } + shuffleLocalRackReplicasAndReplicas(currentNodes, replicaCount, localRackReplicaCount); - if (replicaCount > 2) { - - assert session != null; - - // Test replicas health - Node newestUpReplica = null; - BitSet unhealthyReplicas = null; // bit mask storing indices of unhealthy replicas - long mostRecentUpTimeNanos = -1; - long now = nanoTime(); - for (int i = 0; i < replicaCount; i++) { - Node node = (Node) currentNodes[i]; - assert node != null; - Long upTimeNanos = upTimes.get(node); - if (upTimeNanos != null - && now - upTimeNanos - NEWLY_UP_INTERVAL_NANOS < 0 - && upTimeNanos - mostRecentUpTimeNanos > 0) { - newestUpReplica = node; - mostRecentUpTimeNanos = upTimeNanos; - } - if (newestUpReplica == null && isUnhealthy(node, session, now)) { - if (unhealthyReplicas == null) { - unhealthyReplicas = new BitSet(replicaCount); - } - unhealthyReplicas.set(i); - } - } - - // When: - // - there isn't any newly UP replica and - // - there is one or more unhealthy replicas and - // - there is a majority of healthy replicas - int unhealthyReplicasCount = - unhealthyReplicas == null ? 0 : unhealthyReplicas.cardinality(); - if (newestUpReplica == null - && unhealthyReplicasCount > 0 - && unhealthyReplicasCount < (replicaCount / 2.0)) { - - // Reorder the unhealthy replicas to the back of the list - // Start from the back of the replicas, then move backwards; - // stop once all unhealthy replicas are moved to the back. - int counter = 0; - for (int i = replicaCount - 1; i >= 0 && counter < unhealthyReplicasCount; i--) { - if (unhealthyReplicas.get(i)) { - ArrayUtils.bubbleDown(currentNodes, i, replicaCount - 1 - counter); - counter++; - } - } - } - - // When: - // - there is a newly UP replica and - // - the replica in first or second position is the most recent replica marked as UP and - // - dice roll 1d4 != 1 - else if ((newestUpReplica == currentNodes[0] || newestUpReplica == currentNodes[1]) - && diceRoll1d4() != 1) { - - // Send it to the back of the replicas - ArrayUtils.bubbleDown( - currentNodes, newestUpReplica == currentNodes[0] ? 0 : 1, replicaCount - 1); - } - - // Reorder the first two replicas in the shuffled list based on the number of - // in-flight requests - if (getInFlight((Node) currentNodes[0], session) - > getInFlight((Node) currentNodes[1], session)) { - ArrayUtils.swap(currentNodes, 0, 1); - } + if (replicaCount > 2 && avoidSlowReplicas) { + avoidSlowReplicas(Objects.requireNonNull(session), currentNodes, replicaCount); } } } @@ -255,6 +229,123 @@ > getInFlight((Node) currentNodes[1], session)) { return maybeAddDcFailover(request, plan); } + /** + * Returns a replica array with local-datacenter replicas first and remote replicas preserved at + * the end. + */ + private static Object[] moveNonLocalReplicasToTheEnd(List replicas, String localDc) { + Object[] orderedReplicas = new Object[replicas.size()]; + int index = 0; + for (Node replica : replicas) { + if (Objects.equals(replica.getDatacenter(), localDc)) { + orderedReplicas[index++] = replica; + } + } + for (Node replica : replicas) { + if (!Objects.equals(replica.getDatacenter(), localDc)) { + orderedReplicas[index++] = replica; + } + } + return orderedReplicas; + } + + private Pair moveReplicasToFront( + Object[] currentNodes, List allReplicas) { + int replicaCount = 0, localRackReplicaCount = 0; + for (int i = 0; i < currentNodes.length; i++) { + Node node = (Node) currentNodes[i]; + if (allReplicas.contains(node)) { + if (Objects.equals(node.getRack(), getLocalRack()) + && Objects.equals(node.getDatacenter(), getLocalDatacenter())) { + ArrayUtils.bubbleUp(currentNodes, i, localRackReplicaCount); + localRackReplicaCount++; + } else { + ArrayUtils.bubbleUp(currentNodes, i, replicaCount); + } + replicaCount++; + } + } + return Pair.of(replicaCount, localRackReplicaCount); + } + + private void shuffleLocalRackReplicasAndReplicas( + Object[] currentNodes, int replicaCount, int localRackReplicaCount) { + if (getLocalRack() != null && localRackReplicaCount > 0) { + // Shuffle only replicas that are in the local rack + shuffleHead(currentNodes, localRackReplicaCount); + // Shuffles only replicas that are not in local rack + shuffleInRange(currentNodes, localRackReplicaCount, replicaCount - 1); + } else { + shuffleHead(currentNodes, replicaCount); + } + } + + private void avoidSlowReplicas( + @NonNull Session session, Object[] currentNodes, int replicaCount) { + // Test replicas health + Node newestUpReplica = null; + BitSet unhealthyReplicas = null; // bit mask storing indices of unhealthy replicas + long mostRecentUpTimeNanos = -1; + long now = nanoTime(); + for (int i = 0; i < replicaCount; i++) { + Node node = (Node) currentNodes[i]; + assert node != null; + Long upTimeNanos = upTimes.get(node); + if (upTimeNanos != null + && now - upTimeNanos - NEWLY_UP_INTERVAL_NANOS < 0 + && upTimeNanos - mostRecentUpTimeNanos > 0) { + newestUpReplica = node; + mostRecentUpTimeNanos = upTimeNanos; + } + if (newestUpReplica == null && isUnhealthy(node, session, now)) { + if (unhealthyReplicas == null) { + unhealthyReplicas = new BitSet(replicaCount); + } + unhealthyReplicas.set(i); + } + } + + // When: + // - there isn't any newly UP replica and + // - there is one or more unhealthy replicas and + // - there is a majority of healthy replicas + int unhealthyReplicasCount = unhealthyReplicas == null ? 0 : unhealthyReplicas.cardinality(); + if (newestUpReplica == null + && unhealthyReplicasCount > 0 + && unhealthyReplicasCount < (replicaCount / 2.0)) { + + // Reorder the unhealthy replicas to the back of the list + // Start from the back of the replicas, then move backwards; + // stop once all unhealthy replicas are moved to the back. + int counter = 0; + for (int i = replicaCount - 1; i >= 0 && counter < unhealthyReplicasCount; i--) { + if (unhealthyReplicas.get(i)) { + ArrayUtils.bubbleDown(currentNodes, i, replicaCount - 1 - counter); + counter++; + } + } + } + + // When: + // - there is a newly UP replica and + // - the replica in first or second position is the most recent replica marked as UP and + // - dice roll 1d4 != 1 + else if ((newestUpReplica == currentNodes[0] || newestUpReplica == currentNodes[1]) + && diceRoll1d4() != 1) { + + // Send it to the back of the replicas + ArrayUtils.bubbleDown( + currentNodes, newestUpReplica == currentNodes[0] ? 0 : 1, replicaCount - 1); + } + + // Reorder the first two replicas in the shuffled list based on the number of + // in-flight requests + if (getInFlight((Node) currentNodes[0], session) + > getInFlight((Node) currentNodes[1], session)) { + ArrayUtils.swap(currentNodes, 0, 1); + } + } + @Override public void onNodeSuccess( @NonNull Request request, @@ -325,8 +416,7 @@ protected class NodeResponseRateSample { @VisibleForTesting protected final OptionalLong newest; private NodeResponseRateSample() { - long now = nanoTime(); - this.oldest = now; + this.oldest = nanoTime(); this.newest = OptionalLong.empty(); } diff --git a/core/src/main/resources/reference.conf b/core/src/main/resources/reference.conf index 161cd4bc91a..40d56d67341 100644 --- a/core/src/main/resources/reference.conf +++ b/core/src/main/resources/reference.conf @@ -651,6 +651,14 @@ datastax-java-driver { # Overridable in a profile: no preferred-remote-dcs = [""] } + # The method to use when routing requests. + # Options are: + # - "REGULAR": default behavior of the load balancing policy includes avoiding slow replicas and shuffling nodes + # - "PRESERVE_REPLICA_ORDER": tries to preserve the order of replicas as returned by the partitioner when building the query plan. When dc is provided, move replicas from non-local dc to the back of query plan, but ignores local rack. + # Required: no + # Modifiable at runtime: no + # Overridable in a profile: yes + default-lwt-request-routing-method = "PRESERVE_REPLICA_ORDER" } # Whether to schedule reconnection attempts if all contact points are unreachable on the first diff --git a/core/src/test/java/com/datastax/oss/driver/api/core/cql/StatementBuilderTest.java b/core/src/test/java/com/datastax/oss/driver/api/core/cql/StatementBuilderTest.java index 9904b1e27d7..a10208645fd 100644 --- a/core/src/test/java/com/datastax/oss/driver/api/core/cql/StatementBuilderTest.java +++ b/core/src/test/java/com/datastax/oss/driver/api/core/cql/StatementBuilderTest.java @@ -22,6 +22,7 @@ import static org.mockito.Mockito.when; import com.datastax.oss.driver.shaded.guava.common.base.Charsets; +import edu.umd.cs.findbugs.annotations.NonNull; import java.nio.ByteBuffer; import org.junit.Test; @@ -38,6 +39,7 @@ public MockSimpleStatementBuilder(SimpleStatement template) { super(template); } + @NonNull @Override public SimpleStatement build() { diff --git a/core/src/test/java/com/datastax/oss/driver/api/core/cql/StatementProfileTest.java b/core/src/test/java/com/datastax/oss/driver/api/core/cql/StatementProfileTest.java index af2dccd0432..d59d3a460b9 100644 --- a/core/src/test/java/com/datastax/oss/driver/api/core/cql/StatementProfileTest.java +++ b/core/src/test/java/com/datastax/oss/driver/api/core/cql/StatementProfileTest.java @@ -22,6 +22,7 @@ import static org.mockito.Mockito.when; import com.datastax.oss.driver.TestDataProviders; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.internal.core.cql.DefaultBoundStatement; import com.tngtech.java.junit.dataprovider.DataProvider; @@ -191,6 +192,7 @@ private static BoundStatement newBoundStatement() { null, null, null, - Statement.NO_NOW_IN_SECONDS); + Statement.NO_NOW_IN_SECONDS, + RequestRoutingType.REGULAR); } } diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatementTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatementTest.java index 2377968b4fc..3f38ddaf3cb 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatementTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatementTest.java @@ -26,6 +26,7 @@ import ch.qos.logback.classic.Level; import ch.qos.logback.classic.spi.ILoggingEvent; import com.datastax.oss.driver.api.core.DefaultConsistencyLevel; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.cql.BatchStatement; import com.datastax.oss.driver.api.core.cql.BatchStatementBuilder; import com.datastax.oss.driver.api.core.cql.BatchType; @@ -109,43 +110,50 @@ public void should_infer_lwt_status() { SimpleStatement.builder("SELECT * FROM some_table WHERE a = ?").build(); BoundStatement lwtBoundStatement = mock(DefaultBoundStatement.class); when(lwtBoundStatement.isLWT()).thenReturn(true); + when(lwtBoundStatement.getRequestRoutingType()).thenReturn(RequestRoutingType.LWT); // Without LWT statements added BatchStatementBuilder batchStatementBuilder = new BatchStatementBuilder(BatchType.UNLOGGED); batchStatementBuilder.addStatement(simpleStatement); - assertThat(batchStatementBuilder.build().isLWT()).isFalse(); + BatchStatement batch = batchStatementBuilder.build(); + assertThat(batch.isLWT()).isFalse(); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); // Check if implicitly set to true after adding LWT bound statement batchStatementBuilder.addStatement(lwtBoundStatement); assertThat(batchStatementBuilder.build().isLWT()).isTrue(); // Check if explicit set overrides implicit resolution - batchStatementBuilder.setIsLWT(false); - assertThat(batchStatementBuilder.build().isLWT()).isFalse(); + batchStatementBuilder.setRequestRoutingType(RequestRoutingType.REGULAR); + batch = batchStatementBuilder.build(); + assertThat(batch.isLWT()).isFalse(); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); batchStatementBuilder = new BatchStatementBuilder(BatchType.UNLOGGED); batchStatementBuilder.addStatement(simpleStatement); - batchStatementBuilder.setIsLWT(true); - assertThat(batchStatementBuilder.build().isLWT()).isTrue(); + batchStatementBuilder.setRequestRoutingType(RequestRoutingType.LWT); + batch = batchStatementBuilder.build(); + assertThat(batch.isLWT()).isTrue(); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.LWT); // Check if explicit set remains after clear assertThat(batchStatementBuilder.build().clear().isLWT()).isTrue(); // Similar checks without using builder - BatchStatement batch = BatchStatement.newInstance(BatchType.UNLOGGED); + batch = BatchStatement.newInstance(BatchType.UNLOGGED); assertThat(batch.isLWT()).isFalse(); batch = batch.add(simpleStatement); assertThat(batch.isLWT()).isFalse(); batch = batch.add(lwtBoundStatement); assertThat(batch.isLWT()).isTrue(); - batch = batch.setIsLWT(false); + batch = batch.setRequestRoutingType(RequestRoutingType.REGULAR); assertThat(batch.isLWT()).isFalse(); batch = batch.add(lwtBoundStatement); assertThat(batch.isLWT()).isFalse(); - batch = batch.setIsLWT(true); + batch = batch.setRequestRoutingType(RequestRoutingType.LWT); assertThat(batch.isLWT()).isTrue(); batch = batch.clear(); assertThat(batch.isLWT()).isTrue(); - batch = batch.setIsLWT(null); + batch = batch.setRequestRoutingType(null); assertThat(batch.isLWT()).isFalse(); assertThat(BatchStatement.newInstance(BatchType.UNLOGGED).isLWT()).isFalse(); @@ -155,4 +163,169 @@ public void should_infer_lwt_status() { assertThat(BatchStatement.newInstance(BatchType.LOGGED, lwtBoundStatement).isLWT()).isTrue(); assertThat(BatchStatement.newInstance(BatchType.COUNTER, lwtBoundStatement).isLWT()).isTrue(); } + + @Test + public void should_handle_null_routing_type_in_empty_batch() { + // Empty batch should return REGULAR (not null) and isLWT should be false + BatchStatement batch = BatchStatement.newInstance(BatchType.UNLOGGED); + assertThat(batch.getRequestRoutingType()).isNotNull(); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); + assertThat(batch.isLWT()).isFalse(); + + // Same for other batch types + batch = BatchStatement.newInstance(BatchType.LOGGED); + assertThat(batch.getRequestRoutingType()).isNotNull(); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); + assertThat(batch.isLWT()).isFalse(); + + batch = BatchStatement.newInstance(BatchType.COUNTER); + assertThat(batch.getRequestRoutingType()).isNotNull(); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); + assertThat(batch.isLWT()).isFalse(); + } + + @Test + public void should_handle_statements_with_null_routing_types() { + // Create statements that return null routing type + BoundStatement nullRoutingStatement1 = mock(DefaultBoundStatement.class); + when(nullRoutingStatement1.isLWT()).thenReturn(false); + when(nullRoutingStatement1.getRequestRoutingType()).thenReturn(null); + + BoundStatement nullRoutingStatement2 = mock(DefaultBoundStatement.class); + when(nullRoutingStatement2.isLWT()).thenReturn(false); + when(nullRoutingStatement2.getRequestRoutingType()).thenReturn(null); + + // Batch with only null routing type statements should return REGULAR + BatchStatement batch = BatchStatement.newInstance(BatchType.UNLOGGED); + batch = batch.add(nullRoutingStatement1); + batch = batch.add(nullRoutingStatement2); + + assertThat(batch.getRequestRoutingType()).isNotNull(); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); + assertThat(batch.isLWT()).isFalse(); + } + + @Test + public void should_handle_mixed_null_and_non_null_routing_types() { + // Create statements with different routing types + BoundStatement nullRoutingStatement = mock(DefaultBoundStatement.class); + when(nullRoutingStatement.isLWT()).thenReturn(false); + when(nullRoutingStatement.getRequestRoutingType()).thenReturn(null); + + BoundStatement regularStatement = mock(DefaultBoundStatement.class); + when(regularStatement.isLWT()).thenReturn(false); + when(regularStatement.getRequestRoutingType()).thenReturn(RequestRoutingType.REGULAR); + + BoundStatement lwtStatement = mock(DefaultBoundStatement.class); + when(lwtStatement.isLWT()).thenReturn(true); + when(lwtStatement.getRequestRoutingType()).thenReturn(RequestRoutingType.LWT); + + // Test 1: null + regular -> REGULAR + BatchStatement batch = BatchStatement.newInstance(BatchType.UNLOGGED); + batch = batch.add(nullRoutingStatement); + batch = batch.add(regularStatement); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); + assertThat(batch.isLWT()).isFalse(); + + // Test 2: null + LWT -> LWT (LWT should be detected) + batch = BatchStatement.newInstance(BatchType.UNLOGGED); + batch = batch.add(nullRoutingStatement); + batch = batch.add(lwtStatement); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.LWT); + assertThat(batch.isLWT()).isTrue(); + + // Test 3: regular + null + LWT -> LWT (LWT should be detected regardless of order) + batch = BatchStatement.newInstance(BatchType.UNLOGGED); + batch = batch.add(regularStatement); + batch = batch.add(nullRoutingStatement); + batch = batch.add(lwtStatement); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.LWT); + assertThat(batch.isLWT()).isTrue(); + + // Test 4: LWT + null + regular -> LWT (order shouldn't matter) + batch = BatchStatement.newInstance(BatchType.UNLOGGED); + batch = batch.add(lwtStatement); + batch = batch.add(nullRoutingStatement); + batch = batch.add(regularStatement); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.LWT); + assertThat(batch.isLWT()).isTrue(); + } + + @Test + public void should_handle_explicit_null_routing_type_override() { + BoundStatement lwtStatement = mock(DefaultBoundStatement.class); + when(lwtStatement.isLWT()).thenReturn(true); + when(lwtStatement.getRequestRoutingType()).thenReturn(RequestRoutingType.LWT); + + BoundStatement regularStatement = mock(DefaultBoundStatement.class); + when(regularStatement.isLWT()).thenReturn(false); + when(regularStatement.getRequestRoutingType()).thenReturn(RequestRoutingType.REGULAR); + + // Test 1: Batch with LWT statement, then set routing type to null + // Should fall back to inference and detect LWT + BatchStatement batch = BatchStatement.newInstance(BatchType.UNLOGGED); + batch = batch.add(lwtStatement); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.LWT); + assertThat(batch.isLWT()).isTrue(); + + batch = batch.setRequestRoutingType(null); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.LWT); + assertThat(batch.isLWT()).isTrue(); + + // Test 2: Batch with regular statement, set routing type to null + // Should infer REGULAR + batch = BatchStatement.newInstance(BatchType.UNLOGGED); + batch = batch.add(regularStatement); + batch = batch.setRequestRoutingType(null); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); + assertThat(batch.isLWT()).isFalse(); + + // Test 3: Empty batch with explicit null routing type + // Should return REGULAR + batch = BatchStatement.newInstance(BatchType.UNLOGGED); + batch = batch.setRequestRoutingType(null); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); + assertThat(batch.isLWT()).isFalse(); + } + + @Test + public void should_return_non_null_routing_type_consistently() { + // Verify that getRequestRoutingType never returns null + SimpleStatement simpleStatement = + SimpleStatement.builder("SELECT * FROM some_table WHERE a = ?").build(); + + BoundStatement lwtStatement = mock(DefaultBoundStatement.class); + when(lwtStatement.isLWT()).thenReturn(true); + when(lwtStatement.getRequestRoutingType()).thenReturn(RequestRoutingType.LWT); + + BoundStatement nullRoutingStatement = mock(DefaultBoundStatement.class); + when(nullRoutingStatement.isLWT()).thenReturn(false); + when(nullRoutingStatement.getRequestRoutingType()).thenReturn(null); + + // Test various batch configurations + BatchStatement batch1 = BatchStatement.newInstance(BatchType.UNLOGGED); + assertThat(batch1.getRequestRoutingType()).isNotNull(); + + BatchStatement batch2 = batch1.add(simpleStatement); + assertThat(batch2.getRequestRoutingType()).isNotNull(); + + BatchStatement batch3 = batch2.add(lwtStatement); + assertThat(batch3.getRequestRoutingType()).isNotNull(); + + BatchStatement batch4 = batch3.setRequestRoutingType(null); + assertThat(batch4.getRequestRoutingType()).isNotNull(); + + BatchStatement batch5 = + BatchStatement.newInstance(BatchType.UNLOGGED).add(nullRoutingStatement); + assertThat(batch5.getRequestRoutingType()).isNotNull(); + assertThat(batch5.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); + + BatchStatement batch6 = batch5.setRequestRoutingType(RequestRoutingType.LWT); + assertThat(batch6.getRequestRoutingType()).isNotNull(); + assertThat(batch6.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.LWT); + + BatchStatement batch7 = batch6.setRequestRoutingType(null); + assertThat(batch7.getRequestRoutingType()).isNotNull(); + assertThat(batch7.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); + } } diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/StatementSizeTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/StatementSizeTest.java index dc3ab0702f7..1291e0f8a49 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/StatementSizeTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/StatementSizeTest.java @@ -21,6 +21,7 @@ import static org.mockito.Mockito.when; import com.datastax.oss.driver.api.core.DefaultProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverConfig; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.cql.BatchStatement; @@ -287,6 +288,7 @@ private BoundStatement newBoundStatement( CodecRegistry.DEFAULT, DefaultProtocolVersion.V5, null, - Statement.NO_NOW_IN_SECONDS); + Statement.NO_NOW_IN_SECONDS, + RequestRoutingType.REGULAR); } } diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/BasicLoadBalancingPolicyQueryPlanTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/BasicLoadBalancingPolicyQueryPlanTest.java index c2e89cdf07c..428bb5db4f6 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/BasicLoadBalancingPolicyQueryPlanTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/BasicLoadBalancingPolicyQueryPlanTest.java @@ -38,6 +38,7 @@ import static org.mockito.Mockito.when; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.metadata.Metadata; import com.datastax.oss.driver.api.core.metadata.TokenMap; @@ -80,6 +81,7 @@ public void setup() { when(metadataManager.getContactPoints()).thenReturn(ImmutableSet.of(node1)); when(metadataManager.getMetadata()).thenReturn(metadata); when(metadata.getTokenMap()).thenAnswer(invocation -> Optional.of(this.tokenMap)); + when(request.getRequestRoutingType()).thenReturn(RequestRoutingType.REGULAR); policy = createAndInitPolicy(); } diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DcInferringLoadBalancingPolicyInitTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DcInferringLoadBalancingPolicyInitTest.java index 20de3afe9c3..8440fa3bd6b 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DcInferringLoadBalancingPolicyInitTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DcInferringLoadBalancingPolicyInitTest.java @@ -36,7 +36,10 @@ import edu.umd.cs.findbugs.annotations.NonNull; import java.util.UUID; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; +@RunWith(MockitoJUnitRunner.Silent.class) public class DcInferringLoadBalancingPolicyInitTest extends LoadBalancingPolicyTestBase { @Test diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyConfigTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyConfigTest.java new file mode 100644 index 00000000000..768722e0e86 --- /dev/null +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyConfigTest.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright (C) 2020 ScyllaDB + * + * Modified by ScyllaDB + */ +package com.datastax.oss.driver.internal.core.loadbalancing; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.datastax.oss.driver.api.core.config.DefaultDriverOption; +import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableSet; +import com.tngtech.java.junit.dataprovider.DataProvider; +import com.tngtech.java.junit.dataprovider.DataProviderRunner; +import com.tngtech.java.junit.dataprovider.UseDataProvider; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.MockitoAnnotations; + +@RunWith(DataProviderRunner.class) +public class DefaultLoadBalancingPolicyConfigTest extends LoadBalancingPolicyTestBase { + + @Before + @Override + public void setup() { + MockitoAnnotations.initMocks(this); + super.setup(); + } + + @Test + @DataProvider(value = {"REGULAR", "regular", "PRESERVE_REPLICA_ORDER", "Preserve_Replica_Order"}) + public void should_accept_valid_routing_methods(String routingMethod) { + when(metadataManager.getContactPoints()).thenReturn(ImmutableSet.of(node1)); + + when(defaultProfile.getString( + DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD)) + .thenReturn(routingMethod); + DefaultLoadBalancingPolicy policy = + new DefaultLoadBalancingPolicy(context, DriverExecutionProfile.DEFAULT_NAME); + assertThat(policy).isNotNull(); + } + + @Test + @DataProvider( + value = {"INVALID_METHOD", "", "@#$%^&*()", " REGULAR "}, + trimValues = false) + public void should_default_to_preserve_replica_order_for_invalid_routing_methods( + String invalidValue) { + when(metadataManager.getContactPoints()).thenReturn(ImmutableSet.of(node1)); + + when(defaultProfile.getString( + DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD)) + .thenReturn(invalidValue); + DefaultLoadBalancingPolicy policy = + new DefaultLoadBalancingPolicy(context, DriverExecutionProfile.DEFAULT_NAME); + + assertThat(policy).isNotNull(); + + verify(appender).doAppend(loggingEventCaptor.capture()); + assertThat(loggingEventCaptor.getValue().getFormattedMessage()) + .contains("Unknown request routing method") + .contains("defaulting to PRESERVE_REPLICA_ORDER"); + } + + @Test + @UseDataProvider("configurationCombinations") + public void should_accept_configuration_combinations( + String routingMethod, boolean slowAvoidance) { + when(metadataManager.getContactPoints()).thenReturn(ImmutableSet.of(node1)); + + when(defaultProfile.getString( + DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD)) + .thenReturn(routingMethod); + when(defaultProfile.getBoolean(DefaultDriverOption.LOAD_BALANCING_POLICY_SLOW_AVOIDANCE, true)) + .thenReturn(slowAvoidance); + + DefaultLoadBalancingPolicy policy = + new DefaultLoadBalancingPolicy(context, DriverExecutionProfile.DEFAULT_NAME); + assertThat(policy).isNotNull(); + + verify(defaultProfile, atLeast(1)) + .getBoolean(DefaultDriverOption.LOAD_BALANCING_POLICY_SLOW_AVOIDANCE, true); + } + + @DataProvider + public static Object[][] configurationCombinations() { + return new Object[][] { + {"PRESERVE_REPLICA_ORDER", false}, + {"REGULAR", true} + }; + } +} diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyInitTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyInitTest.java index 77887d627f9..53d9633a23d 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyInitTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyInitTest.java @@ -36,7 +36,10 @@ import edu.umd.cs.findbugs.annotations.NonNull; import java.util.UUID; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; +@RunWith(MockitoJUnitRunner.Silent.class) public class DefaultLoadBalancingPolicyInitTest extends LoadBalancingPolicyTestBase { @Test diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyLwtRoutingTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyLwtRoutingTest.java new file mode 100644 index 00000000000..1e16aafa5f2 --- /dev/null +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyLwtRoutingTest.java @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright (C) 2020 ScyllaDB + * + * Modified by ScyllaDB + */ +package com.datastax.oss.driver.internal.core.loadbalancing; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.when; + +import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; +import com.datastax.oss.driver.api.core.config.DefaultDriverOption; +import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; +import com.datastax.oss.driver.api.core.metadata.Metadata; +import com.datastax.oss.driver.api.core.metadata.Node; +import com.datastax.oss.driver.api.core.metadata.TokenMap; +import com.datastax.oss.driver.api.core.metadata.token.Token; +import com.datastax.oss.driver.api.core.session.Request; +import com.datastax.oss.driver.internal.core.session.DefaultSession; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableMap; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableSet; +import com.datastax.oss.protocol.internal.util.Bytes; +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.Queue; +import java.util.UUID; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.Silent.class) +public class DefaultLoadBalancingPolicyLwtRoutingTest extends LoadBalancingPolicyTestBase { + + private static final CqlIdentifier KEYSPACE = CqlIdentifier.fromInternal("ks"); + private static final ByteBuffer ROUTING_KEY = Bytes.fromHexString("0xdeadbeef"); + + @Mock protected Request request; + @Mock protected DefaultSession session; + @Mock protected Metadata metadata; + @Mock protected TokenMap tokenMap; + @Mock protected Token routingToken; + + private DefaultLoadBalancingPolicy policy; + + @Before + @Override + public void setup() { + super.setup(); + when(metadataManager.getContactPoints()).thenReturn(ImmutableSet.of(node1)); + when(metadataManager.getMetadata()).thenReturn(metadata); + when(metadata.getTokenMap()).thenAnswer(invocation -> Optional.of(this.tokenMap)); + + // Set up nodes with proper DCs + when(node1.getDatacenter()).thenReturn("dc1"); + when(node2.getDatacenter()).thenReturn("dc1"); + when(node3.getDatacenter()).thenReturn("dc1"); + when(node4.getDatacenter()).thenReturn("dc2"); + when(node5.getDatacenter()).thenReturn("dc2"); + + // Configure for PRESERVE_REPLICA_ORDER routing for LWT + when(defaultProfile.getString( + DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD)) + .thenReturn("PRESERVE_REPLICA_ORDER"); + + policy = new DefaultLoadBalancingPolicy(context, DriverExecutionProfile.DEFAULT_NAME); + policy.init( + ImmutableMap.of( + UUID.randomUUID(), node1, + UUID.randomUUID(), node2, + UUID.randomUUID(), node3, + UUID.randomUUID(), node4, + UUID.randomUUID(), node5), + distanceReporter); + } + + @Test + public void should_preserve_replica_order_with_empty_replicas() { + // Given + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)).willReturn(ImmutableList.of()); + + // When + Queue plan = policy.newQueryPlan(request, session); + + // Then + assertThat(plan).isEmpty(); + } + + @Test + public void should_preserve_replica_order_with_single_local_replica() { + // Given + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node2)); + + // When + Queue plan = policy.newQueryPlan(request, session); + + // Then + assertThat(plan).containsExactly(node2); + } + + @Test + public void should_preserve_replica_order_with_multiple_local_replicas() { + // Given + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node3, node1, node2)); + + // When + Queue plan = policy.newQueryPlan(request, session); + + // Then - order preserved exactly as returned from token map + assertThat(plan).containsExactly(node3, node1, node2); + } + + @Test + public void should_push_remote_replicas_to_end() { + // Given + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + // Token map returns replicas in mixed order: remote, local, remote, local + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node4, node1, node5, node2)); + + // When + Queue plan = policy.newQueryPlan(request, session); + + // Then - local replicas first (preserving their order), remote replicas last (preserving their + // order) + assertThat(plan).containsExactly(node1, node2, node4, node5); + } + + @Test + public void should_preserve_replica_order_with_all_remote_replicas() { + // Given + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node5, node4)); + + // When + Queue plan = policy.newQueryPlan(request, session); + + // Then - all remote replicas, order preserved + assertThat(plan).containsExactly(node5, node4); + } + + @Test + public void should_handle_null_local_datacenter() { + // Given + when(defaultProfile.isDefined(DefaultDriverOption.LOAD_BALANCING_LOCAL_DATACENTER)) + .thenReturn(false); + + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1, node2)); + + // When - calling with request that might not have local DC set + // The method should handle null localDc gracefully and just return replicas as-is + Queue plan = policy.newQueryPlanPreserveReplicas(request, session); + + // Then - returns all replicas in order when localDc is not defined + assertThat(plan).containsExactly(node1, node2); + } + + @Test + public void should_preserve_order_when_no_routing_key() { + // Given + given(request.getRoutingKeyspace()).willReturn(null); + given(request.getRoutingKey()).willReturn(null); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.REGULAR); + + // When + Queue plan = policy.newQueryPlan(request, session); + + // Then - with no routing key, no replicas identified, falls back to empty or default behavior + // This tests the edge case where getReplicas returns empty list + assertThat(plan).isNotNull(); + } + + @Test + public void should_dispatch_to_preserve_replicas_when_lwt_and_config_set() { + // Given + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1, node2)); + + // When + Queue plan = policy.newQueryPlan(request, session); + + // Then - verify it used preserve replica order (no shuffling) + // Call multiple times to ensure order is always preserved (not shuffled) + Queue plan2 = policy.newQueryPlan(request, session); + Queue plan3 = policy.newQueryPlan(request, session); + + assertThat(plan).containsExactly(node1, node2); + assertThat(plan2).containsExactly(node1, node2); + assertThat(plan3).containsExactly(node1, node2); + } + + @Test + public void should_not_add_non_replicas_in_preserve_mode() { + // Given + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + // Only node1 is a replica + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1)); + + // When + Queue plan = policy.newQueryPlan(request, session); + + // Then - only the replica is in the plan, other live nodes are NOT added + assertThat(plan).containsExactly(node1); + } +} diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyQueryPlanTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyQueryPlanTest.java index f016323c16b..f9445b84d76 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyQueryPlanTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyQueryPlanTest.java @@ -33,6 +33,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.driver.internal.core.pool.ChannelPool; import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; @@ -44,8 +45,11 @@ import java.util.concurrent.atomic.AtomicLongArray; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +@RunWith(MockitoJUnitRunner.Silent.class) public class DefaultLoadBalancingPolicyQueryPlanTest extends BasicLoadBalancingPolicyQueryPlanTest { private static final long T0 = Long.MIN_VALUE; @@ -387,6 +391,73 @@ public void should_prefer_local_rack_replica_with_less_inflight_requests() { assertThat(plan2).containsExactly(node5, node3, node1, node4, node2); } + @Test + public void should_ignore_local_rack_prioritization_for_lwt_requests() { + // Given - LWT request with local rack configured + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1, node3, node5)); + + String localRack = "rack1"; + given(dsePolicy.getLocalRack()).willReturn(localRack); + // Only node1 is in the local rack + given(node1.getRack()).willReturn(localRack); + given(node3.getRack()).willReturn("rack2"); + given(node5.getRack()).willReturn("rack3"); + + given(pool1.getInFlight()).willReturn(0); + given(pool3.getInFlight()).willReturn(0); + given(pool5.getInFlight()).willReturn(0); + + // When + Queue plan1 = dsePolicy.newQueryPlan(request, session); + Queue plan2 = dsePolicy.newQueryPlan(request, session); + + // Then - for LWT requests (RequestRoutingType.LWT) the policy should ignore local rack + // prioritization and preserve the replica order returned by the token map. + // The shuffle methods are still invoked for the non-replica range, so only the non-replica + // nodes (node2 and node4) are permuted between successive plans. + then(dsePolicy).should(times(2)).shuffleHead(any(), anyInt()); + then(dsePolicy).should(times(2)).shuffleInRange(any(), anyInt(), anyInt()); + assertThat(plan1).containsExactly(node1, node3, node5, node2, node4); + assertThat(plan2).containsExactly(node1, node3, node5, node4, node2); + } + + @Test + public void should_respect_local_rack_prioritization_for_regular_requests() { + // Given - REGULAR request (not LWT) with local rack configured + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()) + .willReturn(com.datastax.oss.driver.api.core.RequestRoutingType.REGULAR); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1, node3, node5)); + + String localRack = "rack1"; + given(dsePolicy.getLocalRack()).willReturn(localRack); + // node1 is in the local rack + given(node1.getRack()).willReturn(localRack); + given(node3.getRack()).willReturn("rack2"); + given(node5.getRack()).willReturn("rack3"); + + given(pool1.getInFlight()).willReturn(0); + given(pool3.getInFlight()).willReturn(0); + given(pool5.getInFlight()).willReturn(0); + + // When + Queue plan1 = dsePolicy.newQueryPlan(request, session); + Queue plan2 = dsePolicy.newQueryPlan(request, session); + + // Then - local rack replica prioritized and shuffled separately from others + // Verify that local rack replicas and non-local-rack replicas are shuffled separately + then(dsePolicy).should(times(2)).shuffleHead(any(), anyInt()); + then(dsePolicy).should(times(2)).shuffleInRange(any(), anyInt(), anyInt()); + assertThat(plan1).containsExactly(node1, node3, node5, node2, node4); + assertThat(plan2).containsExactly(node1, node3, node5, node4, node2); + } + @Override protected DefaultLoadBalancingPolicy createAndInitPolicy() { DefaultLoadBalancingPolicy policy = diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestRoutingTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestRoutingTest.java new file mode 100644 index 00000000000..9aef1825329 --- /dev/null +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestRoutingTest.java @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright (C) 2020 ScyllaDB + * + * Modified by ScyllaDB + */ +package com.datastax.oss.driver.internal.core.loadbalancing; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.when; + +import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; +import com.datastax.oss.driver.api.core.config.DefaultDriverOption; +import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; +import com.datastax.oss.driver.api.core.metadata.Metadata; +import com.datastax.oss.driver.api.core.metadata.Node; +import com.datastax.oss.driver.api.core.metadata.TokenMap; +import com.datastax.oss.driver.api.core.metadata.token.Token; +import com.datastax.oss.driver.api.core.session.Request; +import com.datastax.oss.driver.internal.core.loadbalancing.DefaultLoadBalancingPolicy.RequestRoutingMethod; +import com.datastax.oss.driver.internal.core.session.DefaultSession; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableMap; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableSet; +import com.datastax.oss.protocol.internal.util.Bytes; +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.Queue; +import java.util.UUID; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.Silent.class) +public class DefaultLoadBalancingPolicyRequestRoutingTest extends LoadBalancingPolicyTestBase { + + private static final CqlIdentifier KEYSPACE = CqlIdentifier.fromInternal("ks"); + private static final ByteBuffer ROUTING_KEY = Bytes.fromHexString("0xdeadbeef"); + + @Mock protected Request request; + @Mock protected DefaultSession session; + @Mock protected Metadata metadata; + @Mock protected TokenMap tokenMap; + @Mock protected Token routingToken; + + private DefaultLoadBalancingPolicy policy; + + @Before + @Override + public void setup() { + super.setup(); + when(metadataManager.getContactPoints()).thenReturn(ImmutableSet.of(node1)); + when(metadataManager.getMetadata()).thenReturn(metadata); + when(metadata.getTokenMap()).thenAnswer(invocation -> Optional.of(this.tokenMap)); + + when(node1.getDatacenter()).thenReturn("dc1"); + when(node2.getDatacenter()).thenReturn("dc1"); + when(node3.getDatacenter()).thenReturn("dc1"); + } + + private void initPolicy(String routingMethod) { + when(defaultProfile.getString( + DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD)) + .thenReturn(routingMethod); + + policy = new DefaultLoadBalancingPolicy(context, DriverExecutionProfile.DEFAULT_NAME); + policy.init( + ImmutableMap.of( + UUID.randomUUID(), node1, + UUID.randomUUID(), node2, + UUID.randomUUID(), node3), + distanceReporter); + } + + @Test + public void should_return_regular_when_request_is_null() { + // Given + initPolicy("REGULAR"); + + // When + RequestRoutingMethod method = policy.getDefaultLWTRequestRoutingMethod(null); + + // Then + assertThat(method).isEqualTo(RequestRoutingMethod.REGULAR); + } + + @Test + public void should_return_regular_when_routing_type_is_regular() { + // Given + initPolicy("PRESERVE_REPLICA_ORDER"); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.REGULAR); + + // When + RequestRoutingMethod method = policy.getDefaultLWTRequestRoutingMethod(request); + + // Then + assertThat(method).isEqualTo(RequestRoutingMethod.REGULAR); + } + + @Test + public void should_return_regular_for_lwt_when_config_is_regular() { + // Given + initPolicy("REGULAR"); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + + // When + RequestRoutingMethod method = policy.getDefaultLWTRequestRoutingMethod(request); + + // Then + assertThat(method).isEqualTo(RequestRoutingMethod.REGULAR); + } + + @Test + public void should_return_preserve_replica_order_for_lwt_when_config_is_preserve() { + // Given + initPolicy("PRESERVE_REPLICA_ORDER"); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + + // When + RequestRoutingMethod method = policy.getDefaultLWTRequestRoutingMethod(request); + + // Then + assertThat(method).isEqualTo(RequestRoutingMethod.PRESERVE_REPLICA_ORDER); + } + + @Test + public void should_dispatch_to_regular_query_plan_when_request_is_regular() { + // Given + initPolicy("PRESERVE_REPLICA_ORDER"); + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.REGULAR); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1, node2)); + + // When + Queue plan1 = policy.newQueryPlan(request, session); + Queue plan2 = policy.newQueryPlan(request, session); + + // Then - regular routing shuffles replicas (node1, node2), and also adds the local + // non-replica node (node3); order may vary between plans but the same three nodes + // must be present in each plan + assertThat(plan1).containsExactlyInAnyOrder(node1, node2, node3); + assertThat(plan2).containsExactlyInAnyOrder(node1, node2, node3); + } + + @Test + public void should_dispatch_to_preserve_query_plan_when_lwt_and_config_preserve() { + // Given + initPolicy("PRESERVE_REPLICA_ORDER"); + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node2, node1)); + + // When + Queue plan1 = policy.newQueryPlan(request, session); + Queue plan2 = policy.newQueryPlan(request, session); + Queue plan3 = policy.newQueryPlan(request, session); + + // Then - preserve routing maintains exact order + assertThat(plan1).containsExactly(node2, node1); + assertThat(plan2).containsExactly(node2, node1); + assertThat(plan3).containsExactly(node2, node1); + } + + @Test + public void should_dispatch_to_regular_query_plan_when_lwt_but_config_regular() { + // Given + initPolicy("REGULAR"); + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1, node2)); + + // When + Queue plan = policy.newQueryPlan(request, session); + + // Then - uses regular routing which may shuffle and add non-replicas + assertThat(plan).containsExactlyInAnyOrder(node1, node2, node3); + } + + @Test + public void should_handle_null_request_in_new_query_plan() { + // Given + initPolicy("PRESERVE_REPLICA_ORDER"); + + // When + Queue plan = policy.newQueryPlan(null, session); + + // Then - null request should use regular routing + assertThat(plan).isNotNull(); + assertThat(plan).containsExactlyInAnyOrder(node1, node2, node3); + } + + @Test + public void should_use_regular_routing_for_unknown_routing_type() { + // Given + initPolicy("PRESERVE_REPLICA_ORDER"); + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + // Use REGULAR as a stand-in for any "unknown" type - the switch has a default case + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.REGULAR); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1)); + + // When + RequestRoutingMethod method = policy.getDefaultLWTRequestRoutingMethod(request); + + // Then - defaults to REGULAR for any unrecognized type + assertThat(method).isEqualTo(RequestRoutingMethod.REGULAR); + } + + @Test + public void should_consistently_route_same_request_type() { + // Given + initPolicy("PRESERVE_REPLICA_ORDER"); + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1, node2, node3)); + + // When - call multiple times + RequestRoutingMethod method1 = policy.getDefaultLWTRequestRoutingMethod(request); + RequestRoutingMethod method2 = policy.getDefaultLWTRequestRoutingMethod(request); + RequestRoutingMethod method3 = policy.getDefaultLWTRequestRoutingMethod(request); + + // Then - should always return the same method + assertThat(method1).isEqualTo(RequestRoutingMethod.PRESERVE_REPLICA_ORDER); + assertThat(method2).isEqualTo(RequestRoutingMethod.PRESERVE_REPLICA_ORDER); + assertThat(method3).isEqualTo(RequestRoutingMethod.PRESERVE_REPLICA_ORDER); + } +} diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestTrackerTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestTrackerTest.java index 757af43ef67..aa890778804 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestTrackerTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestTrackerTest.java @@ -28,8 +28,11 @@ import java.util.UUID; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +@RunWith(MockitoJUnitRunner.Silent.class) public class DefaultLoadBalancingPolicyRequestTrackerTest extends LoadBalancingPolicyTestBase { @Mock Request request; diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/LoadBalancingPolicyTestBase.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/LoadBalancingPolicyTestBase.java index c9149efa69f..b301433ed64 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/LoadBalancingPolicyTestBase.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/LoadBalancingPolicyTestBase.java @@ -35,14 +35,11 @@ import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; import org.junit.After; import org.junit.Before; -import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; -import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.LoggerFactory; -@RunWith(MockitoJUnitRunner.class) public abstract class LoadBalancingPolicyTestBase { @Mock protected DefaultNode node1; @@ -81,6 +78,9 @@ public void setup() { DefaultDriverOption.LOAD_BALANCING_DC_FAILOVER_ALLOW_FOR_LOCAL_CONSISTENCY_LEVELS)) .thenReturn(false); when(defaultProfile.getString(DefaultDriverOption.REQUEST_CONSISTENCY)).thenReturn("ONE"); + when(defaultProfile.getString( + DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD)) + .thenReturn("REGULAR"); when(context.getMetadataManager()).thenReturn(metadataManager); diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/tracker/RequestLogFormatterTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/tracker/RequestLogFormatterTest.java index c482afe7a47..e175acc267b 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/tracker/RequestLogFormatterTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/tracker/RequestLogFormatterTest.java @@ -28,6 +28,7 @@ import com.datastax.oss.driver.api.core.DefaultProtocolVersion; import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.context.DriverContext; import com.datastax.oss.driver.api.core.cql.BatchStatement; import com.datastax.oss.driver.api.core.cql.BoundStatement; @@ -294,6 +295,6 @@ private PreparedStatement mockPreparedStatement(String query, Map sessionRule = SessionRule.builder(ccmRule).build(); + private final SessionRule SESSION_RULE = SessionRule.builder(CCM_RULE).build(); - @Rule public TestRule chain = RuleChain.outerRule(ccmRule).around(sessionRule); + @Rule public TestRule chain = RuleChain.outerRule(CCM_RULE).around(SESSION_RULE); @Rule public TestName name = new TestName(); @@ -89,11 +91,11 @@ public void createTable() { SchemaChangeSynchronizer.withLock( () -> { for (String schemaStatement : schemaStatements) { - sessionRule + SESSION_RULE .session() .execute( SimpleStatement.newInstance(schemaStatement) - .setExecutionProfile(sessionRule.slowProfile())); + .setExecutionProfile(SESSION_RULE.slowProfile())); } }); } @@ -103,7 +105,7 @@ public void should_issue_log_warn_if_batched_statement_have_consistency_level_se SimpleStatement simpleStatement = SimpleStatement.builder("INSERT INTO test (k0, k1, v) values ('123123', ?, ?)").build(); - try (CqlSession session = SessionUtils.newSession(ccmRule, sessionRule.keyspace())) { + try (CqlSession session = SessionUtils.newSession(CCM_RULE, SESSION_RULE.keyspace())) { PreparedStatement prep = session.prepare(simpleStatement); BatchStatementBuilder batch = BatchStatement.builder(DefaultBatchType.UNLOGGED); batch.addStatement(prep.bind(1, 2).setConsistencyLevel(ConsistencyLevel.QUORUM)); @@ -139,7 +141,7 @@ public void should_execute_batch_of_simple_statements_with_variables() { } BatchStatement batchStatement = builder.build(); - sessionRule.session().execute(batchStatement); + SESSION_RULE.session().execute(batchStatement); verifyBatchInsert(); } @@ -154,14 +156,14 @@ public void should_execute_batch_of_bound_statements_with_variables() { String.format( "INSERT INTO test (k0, k1, v) values ('%s', ? , ?)", name.getMethodName())) .build(); - PreparedStatement preparedStatement = sessionRule.session().prepare(insert); + PreparedStatement preparedStatement = SESSION_RULE.session().prepare(insert); for (int i = 0; i < batchCount; i++) { builder.addStatement(preparedStatement.bind(i, i + 1)); } BatchStatement batchStatement = builder.build(); - sessionRule.session().execute(batchStatement); + SESSION_RULE.session().execute(batchStatement); verifyBatchInsert(); } @@ -178,14 +180,14 @@ public void should_execute_batch_of_bound_statements_with_unset_values() { String.format( "INSERT INTO test (k0, k1, v) values ('%s', ? , ?)", name.getMethodName())) .build(); - PreparedStatement preparedStatement = sessionRule.session().prepare(insert); + PreparedStatement preparedStatement = SESSION_RULE.session().prepare(insert); for (int i = 0; i < batchCount; i++) { builder.addStatement(preparedStatement.bind(i, i + 1)); } BatchStatement batchStatement = builder.build(); - sessionRule.session().execute(batchStatement); + SESSION_RULE.session().execute(batchStatement); verifyBatchInsert(); @@ -196,17 +198,17 @@ public void should_execute_batch_of_bound_statements_with_unset_values() { if (i % 20 == 0) { boundStatement = boundStatement.unset(1); } - builder.addStatement(boundStatement); + builder2.addStatement(boundStatement); } - sessionRule.session().execute(builder2.build()); + SESSION_RULE.session().execute(builder2.build()); Statement select = SimpleStatement.builder("SELECT * from test where k0 = ?") .addPositionalValue(name.getMethodName()) .build(); - ResultSet result = sessionRule.session().execute(select); + ResultSet result = SESSION_RULE.session().execute(select); List rows = result.all(); assertThat(rows).hasSize(100); @@ -230,7 +232,7 @@ public void should_execute_batch_of_bound_statements_with_named_variables() { // variable values. BatchStatementBuilder builder = BatchStatement.builder(DefaultBatchType.UNLOGGED); PreparedStatement preparedStatement = - sessionRule.session().prepare("INSERT INTO test (k0, k1, v) values (:k0, :k1, :v)"); + SESSION_RULE.session().prepare("INSERT INTO test (k0, k1, v) values (:k0, :k1, :v)"); for (int i = 0; i < batchCount; i++) { builder.addStatement( @@ -243,7 +245,7 @@ public void should_execute_batch_of_bound_statements_with_named_variables() { } BatchStatement batchStatement = builder.build(); - sessionRule.session().execute(batchStatement); + SESSION_RULE.session().execute(batchStatement); verifyBatchInsert(); } @@ -257,7 +259,7 @@ public void should_execute_batch_of_bound_and_simple_statements_with_variables() String.format( "INSERT INTO test (k0, k1, v) values ('%s', ? , ?)", name.getMethodName())) .build(); - PreparedStatement preparedStatement = sessionRule.session().prepare(insert); + PreparedStatement preparedStatement = SESSION_RULE.session().prepare(insert); for (int i = 0; i < batchCount; i++) { if (i % 2 == 1) { @@ -274,7 +276,7 @@ public void should_execute_batch_of_bound_and_simple_statements_with_variables() } BatchStatement batchStatement = builder.build(); - sessionRule.session().execute(batchStatement); + SESSION_RULE.session().execute(batchStatement); verifyBatchInsert(); } @@ -284,25 +286,53 @@ public void should_execute_cas_batch() { // Build a batch with CAS operations on the same partition. BatchStatementBuilder builder = BatchStatement.builder(DefaultBatchType.UNLOGGED); SimpleStatement insert = - SimpleStatement.builder( - String.format( - "INSERT INTO test (k0, k1, v) values ('%s', ? , ?) IF NOT EXISTS", - name.getMethodName())) + SimpleStatement.builder("INSERT INTO test (k0, k1, v) values (?, ?, ?) IF NOT EXISTS") .build(); - PreparedStatement preparedStatement = sessionRule.session().prepare(insert); + PreparedStatement preparedStatement = SESSION_RULE.session().prepare(insert); for (int i = 0; i < batchCount; i++) { - builder.addStatement(preparedStatement.bind(i, i + 1)); + builder.addStatement(preparedStatement.bind(name.getMethodName(), i, i + 1)); + } + + // Ensure LWT routing has a concrete routing key to compute replicas. + BoundStatement routingKeyStmt = preparedStatement.bind(name.getMethodName(), 0, 1); + builder.setRoutingKey(routingKeyStmt.getRoutingKey()); + builder.setSerialConsistencyLevel(ConsistencyLevel.SERIAL); + // Enforce LWT routing only for Cassandra where prepare metadata lacks LWT flags. + if (CcmBridge.isDistributionOf(BackendType.CASSANDRA)) { + builder.setRequestRoutingType(RequestRoutingType.LWT); } BatchStatement batchStatement = builder.build(); - ResultSet result = sessionRule.session().execute(batchStatement); + // Validate serial consistency and LWT routing on the batch itself. + assertThat(batchStatement.getSerialConsistencyLevel()).isEqualTo(ConsistencyLevel.SERIAL); + assertThat(batchStatement.isLWT()).isEqualTo(true); + assertThat(batchStatement.getRoutingKey()).isNotNull(); + + ResultSet result = SESSION_RULE.session().execute(batchStatement); + // Validate that executed request preserved serial consistency level. + assertThat(result.getExecutionInfo().getRequest()).isInstanceOf(Statement.class); + assertThat(((Statement) result.getExecutionInfo().getRequest()).getSerialConsistencyLevel()) + .isEqualTo(ConsistencyLevel.SERIAL); assertThat(result.wasApplied()).isTrue(); verifyBatchInsert(); - // re execute same batch and ensure wasn't applied. - result = sessionRule.session().execute(batchStatement); + // Rebuild an equivalent batch and ensure it isn't applied. + BatchStatementBuilder rerunBuilder = BatchStatement.builder(DefaultBatchType.UNLOGGED); + rerunBuilder.setSerialConsistencyLevel(ConsistencyLevel.SERIAL); + for (int i = 0; i < batchCount; i++) { + rerunBuilder.addStatement(preparedStatement.bind(name.getMethodName(), i, i + 1)); + } + // Use the same routing key to target the same partition for LWT. + rerunBuilder.setRoutingKey(routingKeyStmt.getRoutingKey()); + // Enforce LWT routing only for Cassandra where prepare metadata lacks LWT flags. + if (CcmBridge.isDistributionOf(BackendType.CASSANDRA)) { + rerunBuilder.setRequestRoutingType(RequestRoutingType.LWT); + } + BatchStatement rerunBatch = rerunBuilder.build(); + assertThat(rerunBatch.isLWT()).isEqualTo(true); + result = SESSION_RULE.session().execute(rerunBatch); assertThat(result.wasApplied()).isFalse(); } @@ -322,11 +352,11 @@ public void should_execute_counter_batch() { } BatchStatement batchStatement = builder.build(); - sessionRule.session().execute(batchStatement); + SESSION_RULE.session().execute(batchStatement); for (int i = 1; i <= 3; i++) { ResultSet result = - sessionRule + SESSION_RULE .session() .execute( String.format( @@ -356,7 +386,7 @@ public void should_fail_logged_batch_with_counter_increment() { } BatchStatement batchStatement = builder.build(); - sessionRule.session().execute(batchStatement); + SESSION_RULE.session().execute(batchStatement); } @Test(expected = InvalidQueryException.class) @@ -383,7 +413,7 @@ public void should_fail_counter_batch_with_non_counter_increment() { builder.addStatement(simpleInsert); BatchStatement batchStatement = builder.build(); - sessionRule.session().execute(batchStatement); + SESSION_RULE.session().execute(batchStatement); } @Test @@ -394,13 +424,13 @@ public void should_not_allow_unset_value_when_protocol_less_than_v4() { SessionUtils.configLoaderBuilder() .withString(DefaultDriverOption.PROTOCOL_VERSION, "V3") .build(); - try (CqlSession v3Session = SessionUtils.newSession(ccmRule, loader)) { + try (CqlSession v3Session = SessionUtils.newSession(CCM_RULE, loader)) { // Intentionally use fully qualified table here to avoid warnings as these are not supported // by v3 protocol version, see JAVA-3068 PreparedStatement prepared = v3Session.prepare( String.format( - "INSERT INTO %s.test (k0, k1, v) values (?, ?, ?)", sessionRule.keyspace())); + "INSERT INTO %s.test (k0, k1, v) values (?, ?, ?)", SESSION_RULE.keyspace())); BatchStatementBuilder builder = BatchStatement.builder(DefaultBatchType.LOGGED); builder.addStatements( @@ -427,7 +457,7 @@ private void verifyBatchInsert() { .addPositionalValue(name.getMethodName()) .build(); - ResultSet result = sessionRule.session().execute(select); + ResultSet result = SESSION_RULE.session().execute(select); List rows = result.all(); assertThat(rows).hasSize(100); diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java index 9e2d034a19f..0586b3236bb 100644 --- a/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java +++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java @@ -27,6 +27,7 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.cql.BatchStatement; import com.datastax.oss.driver.api.core.cql.BatchType; import com.datastax.oss.driver.api.core.cql.PreparedStatement; @@ -43,6 +44,7 @@ import com.datastax.oss.driver.api.testinfra.session.SessionUtils; import java.nio.ByteBuffer; import java.util.HashSet; +import java.util.List; import java.util.Set; import org.junit.BeforeClass; import org.junit.ClassRule; @@ -76,24 +78,26 @@ public static void setup() { } @Test - public void should_use_only_one_node_when_lwt_detected() { + public void should_use_replicas_when_lwt_detected() { assumeTrue( CcmBridge.isDistributionOf(BackendType.SCYLLA)); // Functionality only available in Scylla CqlSession session = SESSION_RULE.session(); int pk = 1234; ByteBuffer routingKey = TypeCodecs.INT.encodePrimitive(pk, ProtocolVersion.DEFAULT); TokenMap tokenMap = SESSION_RULE.session().getMetadata().getTokenMap().get(); - Node owner = - tokenMap.getReplicasList(session.getKeyspace().get(), routingKey).iterator().next(); + List replicas = tokenMap.getReplicasList(session.getKeyspace().get(), routingKey); PreparedStatement statement = SESSION_RULE .session() .prepare("INSERT INTO foo (pk, ck, v) VALUES (?, ?, ?) IF NOT EXISTS"); assertThat(statement.isLWT()).isTrue(); - for (int i = 0; i < 30; i++) { + Set coordinators = new HashSet<>(); + for (int i = 0; i < 100; i++) { ResultSet result = session.execute(statement.bind(pk, i, 123)); - assertThat(result.getExecutionInfo().getCoordinator()).isEqualTo(owner); + coordinators.add(result.getExecutionInfo().getCoordinator()); } + assertThat(coordinators).isSubsetOf(replicas); + assertThat(coordinators.size()).isGreaterThan(0).isLessThanOrEqualTo(replicas.size()); } @Test @@ -116,22 +120,22 @@ public void should_not_use_only_one_node_when_non_lwt() { } @Test - public void should_use_only_one_node_when_lwt_batch_detected() { + public void should_use_replicas_when_lwt_batch_detected() { assumeTrue( CcmBridge.isDistributionOf(BackendType.SCYLLA)); // Functionality only available in Scylla CqlSession session = SESSION_RULE.session(); int pk = 1234; ByteBuffer routingKey = TypeCodecs.INT.encodePrimitive(pk, ProtocolVersion.DEFAULT); TokenMap tokenMap = SESSION_RULE.session().getMetadata().getTokenMap().get(); - Node owner = - tokenMap.getReplicasList(session.getKeyspace().get(), routingKey).iterator().next(); + List replicas = tokenMap.getReplicasList(session.getKeyspace().get(), routingKey); PreparedStatement statement = SESSION_RULE .session() .prepare("INSERT INTO foo (pk, ck, v) VALUES (?, ?, ?) IF NOT EXISTS"); assertThat(statement.isLWT()).isTrue(); - for (int i = 0; i < 30; i++) { + Set coordinatorsLwt = new HashSet<>(); + for (int i = 0; i < 100; i++) { BatchStatement batch = BatchStatement.newInstance(BatchType.UNLOGGED); SimpleStatement simpleStatement = SimpleStatement.newInstance( @@ -142,8 +146,10 @@ public void should_use_only_one_node_when_lwt_batch_detected() { batch = batch.add(statement.bind(pk, i, 123)); assertThat(batch.isLWT()).isTrue(); ResultSet result = session.execute(batch); - assertThat(result.getExecutionInfo().getCoordinator()).isEqualTo(owner); + coordinatorsLwt.add(result.getExecutionInfo().getCoordinator()); } + assertThat(coordinatorsLwt).isSubsetOf(replicas); + assertThat(coordinatorsLwt.size()).isGreaterThan(0).isLessThanOrEqualTo(replicas.size()); // Check if multiple coordinators are used when forcibly set to non-LWT Set coordinators = new HashSet<>(); @@ -156,7 +162,7 @@ public void should_use_only_one_node_when_lwt_batch_detected() { assertThat(simpleStatement.isLWT()).isFalse(); batch = batch.add(simpleStatement); batch = batch.add(statement.bind(pk, i, 123)); - batch = batch.setIsLWT(false); + batch = batch.setRequestRoutingType(RequestRoutingType.REGULAR); assertThat(batch.isLWT()).isFalse(); ResultSet result = session.execute(batch); coordinators.add(result.getExecutionInfo().getCoordinator()); diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingMultiDcIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingMultiDcIT.java new file mode 100644 index 00000000000..011c1f3ea0a --- /dev/null +++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingMultiDcIT.java @@ -0,0 +1,210 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright (C) 2026 ScyllaDB + * + * Modified by ScyllaDB + */ +package com.datastax.oss.driver.core.loadbalancing; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; +import com.datastax.oss.driver.api.core.Version; +import com.datastax.oss.driver.api.core.config.DefaultDriverOption; +import com.datastax.oss.driver.api.core.cql.BatchStatement; +import com.datastax.oss.driver.api.core.cql.BatchStatementBuilder; +import com.datastax.oss.driver.api.core.cql.BatchType; +import com.datastax.oss.driver.api.core.cql.PreparedStatement; +import com.datastax.oss.driver.api.core.cql.ResultSet; +import com.datastax.oss.driver.api.core.metadata.Node; +import com.datastax.oss.driver.api.core.metadata.TokenMap; +import com.datastax.oss.driver.api.core.type.codec.TypeCodecs; +import com.datastax.oss.driver.api.testinfra.ccm.CcmBridge; +import com.datastax.oss.driver.api.testinfra.ccm.CustomCcmRule; +import com.datastax.oss.driver.api.testinfra.requirement.BackendType; +import com.datastax.oss.driver.api.testinfra.session.SessionRule; +import com.datastax.oss.driver.api.testinfra.session.SessionUtils; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.RuleChain; +import org.junit.rules.TestRule; + +public class LWTLoadBalancingMultiDcIT { + private static final String LOCAL_DC = "dc1"; + private static final String KEYSPACE = "test"; + + private static final CustomCcmRule CCM_RULE = + CustomCcmRule.builder().withNodes(2, 1).build(); // 2 nodes in DC1, 1 node in DC2 + + private static final SessionRule SESSION_RULE = + SessionRule.builder(CCM_RULE) + .withKeyspace(false) + .withConfigLoader( + SessionUtils.configLoaderBuilder() + .withString(DefaultDriverOption.LOAD_BALANCING_LOCAL_DATACENTER, LOCAL_DC) + .withDuration(DefaultDriverOption.REQUEST_TIMEOUT, Duration.ofSeconds(30)) + .build()) + .build(); + + @ClassRule + public static final TestRule CHAIN = RuleChain.outerRule(CCM_RULE).around(SESSION_RULE); + + public static final int FIRST_TEST_PARTITION_KEY = 4242; + public static final int SECOND_TEST_PARTITION_KEY = 4343; + public static final int NUM_TEST_ITERATIONS = 30; + + @BeforeClass + public static void setup() { + CqlSession session = SESSION_RULE.session(); + + // Create multi-DC keyspace and table similarly to DefaultLoadBalancingPolicyIT. + if (CcmBridge.isDistributionOf(BackendType.SCYLLA) + && ((CcmBridge.SCYLLA_ENTERPRISE + && CcmBridge.getDistributionVersion().compareTo(Version.parse("2023.1.0")) >= 0) + || (!CcmBridge.SCYLLA_ENTERPRISE + && CcmBridge.getDistributionVersion().compareTo(Version.parse("6.1.0")) >= 0))) { + session.execute( + "CREATE KEYSPACE test " + + "WITH replication = {'class': 'NetworkTopologyStrategy', 'dc1': 2, 'dc2': 1} " + + "AND tablets = { 'enabled': false }"); + } else { + session.execute( + "CREATE KEYSPACE test " + + "WITH replication = {'class': 'NetworkTopologyStrategy', 'dc1': 2, 'dc2': 1}"); + } + + session.execute("CREATE TABLE test.foo (pk int, ck int, v int, PRIMARY KEY (pk, ck))"); + + // Wait for schema readiness + await() + .pollInterval(200, TimeUnit.MILLISECONDS) + .atMost(60, TimeUnit.SECONDS) + .untilAsserted( + () -> { + assertThat(session.getMetadata().getKeyspace(KEYSPACE)).isPresent(); + TokenMap tm = session.getMetadata().getTokenMap().get(); + ByteBuffer routingKey = + TypeCodecs.INT.encodePrimitive(FIRST_TEST_PARTITION_KEY, ProtocolVersion.DEFAULT); + Set replicas = + new HashSet<>(tm.getReplicasList(CqlIdentifier.fromCql(KEYSPACE), routingKey)); + assertThat(replicas).hasSize(3); // RF 2 in dc1, 1 in dc2 + assertThat(replicas.stream().filter(n -> LOCAL_DC.equals(n.getDatacenter()))) + .hasSizeGreaterThanOrEqualTo(1); + }); + } + + @Test + public void should_route_lwt_to_local_dc_replicas() { + int pk = FIRST_TEST_PARTITION_KEY; + CqlIdentifier keyspace = CqlIdentifier.fromCql(KEYSPACE); + ByteBuffer routingKey = TypeCodecs.INT.encodePrimitive(pk, ProtocolVersion.DEFAULT); + + TokenMap tokenMap = SESSION_RULE.session().getMetadata().getTokenMap().get(); + Set localReplicas = new HashSet<>(); + Set allReplicas = new HashSet<>(tokenMap.getReplicasList(keyspace, routingKey)); + for (Node replica : allReplicas) { + if (LOCAL_DC.equals(replica.getDatacenter())) { + localReplicas.add(replica); + } + } + assertThat(localReplicas).isNotEmpty(); + + PreparedStatement lwt = + SESSION_RULE + .session() + .prepare("INSERT INTO test.foo (pk, ck, v) VALUES (?, ?, ?) IF NOT EXISTS"); + // Cassandra does not expose LWT flag via prepare metadata; driver may not detect LWT. + if (!CcmBridge.isDistributionOf(BackendType.CASSANDRA)) { + assertThat(lwt.isLWT()).isTrue(); + } + + Set coordinators = new HashSet<>(); + Set coordinatorDcs = new HashSet<>(); + for (int i = 0; i < NUM_TEST_ITERATIONS; i++) { + ResultSet result = SESSION_RULE.session().execute(lwt.bind(pk, i, 7)); + Node coord = result.getExecutionInfo().getCoordinator(); + coordinators.add(coord); + coordinatorDcs.add(coord.getDatacenter()); + } + + assertThat(coordinators).isSubsetOf(allReplicas); + assertThat(coordinators).isSubsetOf(localReplicas); + assertThat(coordinatorDcs).containsOnly(LOCAL_DC); + } + + @Test + public void should_route_lwt_batch_to_local_dc_replicas() { + int pk = SECOND_TEST_PARTITION_KEY; + CqlIdentifier keyspace = CqlIdentifier.fromCql(KEYSPACE); + ByteBuffer routingKey = TypeCodecs.INT.encodePrimitive(pk, ProtocolVersion.DEFAULT); + + TokenMap tokenMap = SESSION_RULE.session().getMetadata().getTokenMap().get(); + Set localReplicas = new HashSet<>(); + Set allReplicas = new HashSet<>(tokenMap.getReplicasList(keyspace, routingKey)); + for (Node replica : allReplicas) { + if (LOCAL_DC.equals(replica.getDatacenter())) { + localReplicas.add(replica); + } + } + assertThat(localReplicas).isNotEmpty(); + + PreparedStatement lwt = + SESSION_RULE + .session() + .prepare("INSERT INTO test.foo (pk, ck, v) VALUES (?, ?, ?) IF NOT EXISTS"); + PreparedStatement nonLwtPrepared = + SESSION_RULE.session().prepare("INSERT INTO test.foo (pk, ck, v) VALUES (?, ?, ?)"); + + // Run a bunch of times to exercise load balancing. + Set coordinators = new HashSet<>(); + Set coordinatorDcs = new HashSet<>(); + for (int i = 0; i < NUM_TEST_ITERATIONS; i++) { + BatchStatementBuilder builder = + BatchStatement.builder(BatchType.UNLOGGED) + .setRoutingKeyspace(keyspace) + .setRoutingKey(routingKey) + .addStatement(nonLwtPrepared.bind(pk, 0, 101)) + .addStatement(lwt.bind(pk, i, 202)); + // Ensure LWT routing type on Cassandra where detection may be absent + if (CcmBridge.isDistributionOf(BackendType.CASSANDRA)) { + builder = builder.setRequestRoutingType(RequestRoutingType.LWT); + } + BatchStatement batch = builder.build(); + assertThat(batch.isLWT()).isTrue(); + + ResultSet result = SESSION_RULE.session().execute(batch); + Node coord = result.getExecutionInfo().getCoordinator(); + coordinators.add(coord); + coordinatorDcs.add(coord.getDatacenter()); + } + + assertThat(coordinators).isSubsetOf(allReplicas); + assertThat(coordinators).isSubsetOf(localReplicas); + assertThat(coordinatorDcs).containsOnly(LOCAL_DC); + } +} diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/metadata/NodeStateIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/metadata/NodeStateIT.java index e468e0a10d7..dc7590da2ec 100644 --- a/integration-tests/src/test/java/com/datastax/oss/driver/core/metadata/NodeStateIT.java +++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/metadata/NodeStateIT.java @@ -57,6 +57,7 @@ import com.datastax.oss.simulacron.server.BoundNode; import com.datastax.oss.simulacron.server.RejectScope; import edu.umd.cs.findbugs.annotations.NonNull; +import edu.umd.cs.findbugs.annotations.Nullable; import java.io.IOException; import java.net.InetSocketAddress; import java.net.ServerSocket; @@ -703,7 +704,7 @@ public void stopIgnoring(Node node) { @NonNull @Override - public Queue newQueryPlan(@NonNull Request request, @NonNull Session session) { + public Queue newQueryPlan(@Nullable Request request, @Nullable Session session) { Object[] snapshot = liveNodes.toArray(); Queue queryPlan = new ConcurrentLinkedQueue<>(); int start = offset.getAndIncrement(); // Note: offset overflow won't be an issue in tests diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/example/guava/internal/KeyRequest.java b/integration-tests/src/test/java/com/datastax/oss/driver/example/guava/internal/KeyRequest.java index ef582cce1b9..3c15ed4db52 100644 --- a/integration-tests/src/test/java/com/datastax/oss/driver/example/guava/internal/KeyRequest.java +++ b/integration-tests/src/test/java/com/datastax/oss/driver/example/guava/internal/KeyRequest.java @@ -18,6 +18,7 @@ package com.datastax.oss.driver.example.guava.internal; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.driver.api.core.metadata.token.Token; @@ -94,4 +95,10 @@ public Duration getTimeout() { public Node getNode() { return null; } + + @NonNull + @Override + public RequestRoutingType getRequestRoutingType() { + return RequestRoutingType.REGULAR; + } }