From 0aaee472e98853f4d9d374604b31def1b7fa9e8b Mon Sep 17 00:00:00 2001 From: ziqi liu Date: Fri, 19 Jun 2026 19:12:07 +0000 Subject: [PATCH 1/4] [SPARK-57547][SQL] Track distinct materialized partitions to prevent cached relation row loss `CachedRDDBuilder` tracked materialization with a plain task-completion `LongAccumulator` and summed per-batch row-count and size stats. When a cold cache is first touched concurrently, the same partition can be computed more than once across executors (duplicate cross-executor computes). Those duplicates inflate the task-completion count and double-count the summed stats. The inflated count can make `isCachedRDDLoaded` report the relation as fully materialized before the distinct partitions have actually been recorded, and `computeStats` can then expose a `rowCount` that does not reflect the cached data. AQE's `PropagateEmptyRelation` can read that incorrect `rowCount` and collapse the cached source to an empty relation, silently dropping rows. This change adds `PartitionKeyedAccumulator`, a `ConcurrentHashMap`-backed accumulator keyed by partition id with last-write-wins merge semantics. The cached relation now: - counts the DISTINCT materialized partition ids (the accumulator key set) when deciding whether the cache is fully loaded, so duplicate computes cannot inflate the count; and - derives exact, de-duplicated row-count and size stats by folding the per-partition values, counting each partition once. The behavior is gated by a new internal conf `spark.sql.inMemoryColumnarStorage.distinctPartitionTracking` (default true); setting it to false restores the prior raw task-completion-count behavior. `clearCache` resets the bookkeeping so a rebuilt cache starts clean. ### Tests - `PartitionKeyedAccumulatorSuite` - accumulator semantics (last-write-wins add/merge, distinct key count, snapshot/reset). - `ConcurrentInMemoryRelationSuite` - local-cluster reproduction: rows are preserved under concurrent first-touch with the fix on; stats are exact under duplicate cross-executor computes; and a negative control showing the row loss with the fix disabled. - Extended `CachedTableSuite` (clearCache resets bookkeeping) and `InMemoryColumnarQuerySuite` (size/row-count read through the new accessors). Co-authored-by: Isaac --- .../apache/spark/sql/internal/SQLConf.scala | 15 + .../execution/columnar/InMemoryRelation.scala | 134 ++++++- .../sql/util/PartitionKeyedAccumulator.scala | 89 +++++ .../apache/spark/sql/CachedTableSuite.scala | 35 +- .../ConcurrentInMemoryRelationSuite.scala | 334 ++++++++++++++++++ .../columnar/InMemoryColumnarQuerySuite.scala | 4 +- .../util/PartitionKeyedAccumulatorSuite.scala | 107 ++++++ 7 files changed, 702 insertions(+), 16 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ConcurrentInMemoryRelationSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/util/PartitionKeyedAccumulatorSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2813dad9400ff..04658ab3b0a61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -861,6 +861,21 @@ object SQLConf { .booleanConf .createWithDefault(true) + val IN_MEMORY_DISTINCT_PARTITION_TRACKING = + buildConf("spark.sql.inMemoryColumnarStorage.distinctPartitionTracking") + .internal() + .doc("When true, a cached relation is considered fully materialized only once the set of " + + "DISTINCT materialized partitions covers every partition, rather than once a raw " + + "task-completion count reaches the partition count; clearCache also resets that " + + "bookkeeping. This prevents concurrent first-touch of a cold cache under AQE (where a " + + "cached source referenced several times has each reference launch its own build job) " + + "from letting duplicate partition completions mark the cache loaded with rowCount 0, " + + "which would make AQE wrongly propagate an empty relation and silently drop rows. When " + + "false, restores the prior raw task-completion-count behavior.") + .version("5.0.0") + .booleanConf + .createWithDefault(true) + val IN_MEMORY_TABLE_SCAN_STATISTICS_ENABLED = buildConf("spark.sql.inMemoryTableScanStatistics.enable") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 9c012dbd58e12..98979db9c3ae0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.columnar +import scala.util.{Left, Right} + import com.esotericsoftware.kryo.{DefaultSerializer, Kryo, Serializer => KryoSerializer} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} @@ -36,6 +38,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector, WritableColumnVector} import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.PartitionKeyedAccumulator import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{LongAccumulator, Utils} @@ -250,6 +253,17 @@ class DefaultCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { } } +// The legacy per-build accumulators, allocated together only on the off path (the `Left` of +// `statsAccumulators` below). `sizeInBytesStats` / `rowCountStats` are the per-batch stat sums +// (the reported row count / size on that path); `materializedPartitions` is the raw task-completion +// count driving the loaded check. All over-count under duplicate computes -- the pre-fix behavior. +// Top-level (not nested in CachedRDDBuilder) so the accumulators carry no outer reference when +// captured by the build task closure. +private case class LegacyAccumulators( + sizeInBytesStats: LongAccumulator, + rowCountStats: LongAccumulator, + materializedPartitions: LongAccumulator) + private[sql] case class CachedRDDBuilder( serializer: CachedBatchSerializer, @@ -261,9 +275,30 @@ case class CachedRDDBuilder( @transient @volatile private var _cachedColumnBuffers: RDD[CachedBatch] = null @transient @volatile private var _cachedColumnBuffersAreLoaded: Boolean = false - val sizeInBytesStats: LongAccumulator = cachedPlan.session.sparkContext.longAccumulator - val rowCountStats: LongAccumulator = cachedPlan.session.sparkContext.longAccumulator - private val materializedPartitions = cachedPlan.session.sparkContext.longAccumulator + // The cache's materialization bookkeeping, chosen ONCE here (at construction) from the + // IN_MEMORY_DISTINCT_PARTITION_TRACKING conf, so the scheme can never change mid-build. Exactly + // one branch is allocated: + // - Right (the fix, default): a partition-keyed accumulator storing (rowCount, sizeInBytes) per + // partition. AQE creates a separate cache scan stage per reference to the same cache and each + // submits its own build job, so the same partition can be computed by several concurrent jobs + // (and speculative tasks); Spark has no global cross-executor "compute this partition once" + // barrier (only a per-executor write lock). Keying by partition id (last-write-wins) means + // those duplicate completions cannot mark the cache loaded before every partition has been + // computed -- which otherwise let AQE read rowCount 0 on a non-empty cache and propagate an + // empty relation, silently dropping rows -- and also yields exact, de-duplicated row count / + // size. + // - Left (conf off): the raw per-batch accumulators (`LegacyAccumulators`) -- the buggy pre-fix + // behavior, kept only as a safety switch. + private val statsAccumulators + : Either[LegacyAccumulators, PartitionKeyedAccumulator[(Long, Long)]] = + if (cachedPlan.conf.getConf(SQLConf.IN_MEMORY_DISTINCT_PARTITION_TRACKING)) { + val acc = new PartitionKeyedAccumulator[(Long, Long)] + cachedPlan.session.sparkContext.register(acc) + Right(acc) + } else { + val sc = cachedPlan.session.sparkContext + Left(LegacyAccumulators(sc.longAccumulator, sc.longAccumulator, sc.longAccumulator)) + } val cachedName = tableName.map(n => s"In-memory table $n") .getOrElse(Utils.abbreviate(cachedPlan.toString, 1024)) @@ -284,6 +319,16 @@ case class CachedRDDBuilder( if (_cachedColumnBuffers != null) { _cachedColumnBuffers.unpersist(blocking) _cachedColumnBuffers = null + statsAccumulators match { + case Right(keyed) => + // The buffers no longer back a live RDD. Reset the one-way "loaded" latch and the keyed + // bookkeeping so a rebuild on this builder does not inherit a stale "loaded" state or + // stale statistics. Safe to reset in place: every read of `keyed` is under this monitor. + _cachedColumnBuffersAreLoaded = false + keyed.reset() + case Left(_) => + // Pre-fix behavior, kept only for the safety switch: do not reset the bookkeeping. + } } } @@ -296,9 +341,16 @@ case class CachedRDDBuilder( // We must make sure the statistics of `sizeInBytes` and `rowCount` are accurate if // `isCachedRDDLoaded` return true. Otherwise, AQE would do a wrong optimization, // e.g., convert a non-empty plan to empty local relation if `rowCount` is 0. - // Because the statistics is based on accumulator, here we use an extra accumulator to - // track if all partitions are materialized. - val rddLoaded = _cachedColumnBuffers.partitions.length == materializedPartitions.value + val numMaterialized = statsAccumulators match { + // Count DISTINCT materialized partitions (the keyed accumulator's key set), so the cache is + // only reported loaded once every partition has been computed -- sound even if a partition + // is computed more than once by concurrent or speculative tasks. + case Right(keyed) => keyed.accumulatedNumPartitions + // Buggy pre-fix behavior (safety switch only): a raw task-completion count, which duplicate + // computes can inflate past the partition count before every partition is materialized. + case Left(legacy) => legacy.materializedPartitions.value.longValue + } + val rddLoaded = _cachedColumnBuffers.partitions.length.toLong == numMaterialized if (rddLoaded) { _cachedColumnBuffersAreLoaded = rddLoaded } @@ -306,6 +358,32 @@ case class CachedRDDBuilder( } } + // Reported row count / size for the cache's statistics. On the default path these are exact and + // de-duplicated (folded over distinct materialized partitions); on the off path they fall back to + // the raw per-batch accumulators (which can over-count). Synchronized so a fold never races a + // concurrent `clearCache` reset. + private[sql] def materializedRowCount: Long = synchronized { + statsAccumulators match { + case Right(keyed) => keyed.foldValues(0L)((sum, v) => sum + v._1) + case Left(legacy) => legacy.rowCountStats.value.longValue + } + } + + private[sql] def materializedSizeInBytes: Long = synchronized { + statsAccumulators match { + case Right(keyed) => keyed.foldValues(0L)((sum, v) => sum + v._2) + case Left(legacy) => legacy.sizeInBytesStats.value.longValue + } + } + + // The id of the accumulator backing this cache's materialization bookkeeping (the keyed + // accumulator on the default path, a legacy stat accumulator otherwise). Exposed only so + // `CachedTableSuite`'s accumulator-cleanup test can verify it is cleared after uncache + GC. + private[sql] def materializationAccumulatorId: Long = statsAccumulators match { + case Left(legacy) => legacy.sizeInBytesStats.id + case Right(keyed) => keyed.id + } + private def buildBuffers(): RDD[CachedBatch] = { val cb = try { if (supportsColumnarInput) { @@ -330,18 +408,50 @@ case class CachedRDDBuilder( session.sharedState.cacheManager.recacheByPlan(session, logicalPlan) throw e } + // Records one successful partition materialization, using the scheme selected by the conf. + // Built once on the driver so the task closure below captures only the chosen accumulator. The + // default path records this partition's (rows, bytes) keyed by its id; the off path just bumps + // a raw completion count. + val recordMaterialized: (Int, Long, Long) => Unit = statsAccumulators match { + case Right(keyed) => + (partitionId: Int, rows: Long, bytes: Long) => + keyed.add((partitionId, (rows, bytes))) + case Left(legacy) => + val completions = legacy.materializedPartitions + (_: Int, _: Long, _: Long) => completions.add(1L) + } + // On the default path the keyed accumulator is authoritative for the cache stats, so the legacy + // per-batch accumulators are not fed; only the off path uses them. Extract them once on the + // driver (None on the default path) so the task closure captures just what it needs. + val legacyStatAccs: Option[(LongAccumulator, LongAccumulator)] = statsAccumulators match { + case Left(legacy) => Some((legacy.sizeInBytesStats, legacy.rowCountStats)) + case Right(_) => None + } val cached = cb.mapPartitionsInternal { it => - TaskContext.get().addTaskCompletionListener[Unit] { context => + val taskContext = TaskContext.get() + val partitionId = taskContext.partitionId() + // This task computes exactly one partition. On the default path, tally its totals so the + // completion listener records them once (covering empty-output partitions, which produce no + // batches); on the off path the raw per-batch accumulators below are fed directly. + var localRows = 0L + var localBytes = 0L + taskContext.addTaskCompletionListener[Unit] { context => if (!context.isFailed() && !context.isInterrupted()) { - materializedPartitions.add(1L) + recordMaterialized(partitionId, localRows, localBytes) } } new Iterator[CachedBatch] { override def hasNext: Boolean = it.hasNext override def next(): CachedBatch = { val batch = it.next() - sizeInBytesStats.add(batch.sizeInBytes) - rowCountStats.add(batch.numRows) + legacyStatAccs match { + case Some((sizeAcc, rowAcc)) => + sizeAcc.add(batch.sizeInBytes) + rowAcc.add(batch.numRows) + case None => + localBytes += batch.sizeInBytes + localRows += batch.numRows + } batch } } @@ -460,8 +570,8 @@ case class InMemoryRelation( statsOfPlanToCache } else { statsOfPlanToCache.copy( - sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue, - rowCount = Some(cacheBuilder.rowCountStats.value.longValue) + sizeInBytes = cacheBuilder.materializedSizeInBytes, + rowCount = Some(cacheBuilder.materializedRowCount) ) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala new file mode 100644 index 0000000000000..4785010571a18 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.util.AccumulatorV2 + +/** + * An [[AccumulatorV2]] that records one value of type `T` per partition, keyed by partition id with + * LAST-WRITE-WINS merge. When the same partition is recorded more than once -- e.g. duplicate + * cross-executor computes, or speculative tasks -- the later value replaces the earlier one rather + * than aggregating, so each partition contributes exactly once. The key set is the set of recorded + * partitions, and callers fold the values (see [[foldValues]]) to derive de-duplicated aggregates; + * a plain summing accumulator would instead over-count under duplicate computes. + * + * `add` is expected to be called once per task (e.g. from a task completion listener) with that + * partition's value, so a partition is recorded even when it produced nothing. Updates from + * failed/interrupted tasks are dropped by the accumulator framework (it is not + * `countFailedValues`), so only complete per-partition values are ever merged. + * + * Backed by a [[ConcurrentHashMap]], whose per-entry atomicity is sufficient here: `add` and the + * `putAll` in `merge` are last-write-wins per key, and the reads (`value`, `numPartitions`, + * `foldValues`) only require thread-safety and eventual consistency -- they are weakly consistent + * during concurrent updates but exact once all updates have been merged. This avoids any explicit + * locking (and the nested-lock pattern a two-map `merge` would otherwise need). + * + * @tparam T the per-partition value type. Must be non-null ([[ConcurrentHashMap]] forbids nulls). + */ +class PartitionKeyedAccumulator[T] extends AccumulatorV2[(Int, T), java.util.Map[Int, T]] { + + // partition id -> value. + private val byPartition = new ConcurrentHashMap[Int, T]() + + override def isZero: Boolean = byPartition.isEmpty + + override def copyAndReset(): PartitionKeyedAccumulator[T] = new PartitionKeyedAccumulator[T] + + override def copy(): PartitionKeyedAccumulator[T] = { + val newAcc = new PartitionKeyedAccumulator[T] + newAcc.byPartition.putAll(byPartition) + newAcc + } + + override def reset(): Unit = byPartition.clear() + + override def add(v: (Int, T)): Unit = byPartition.put(v._1, v._2) + + override def merge(other: AccumulatorV2[(Int, T), java.util.Map[Int, T]]): Unit = other match { + case o: PartitionKeyedAccumulator[T] => + // Last-write-wins per partition id: a partition recorded by more than one task replaces + // rather than accumulates, keeping any caller-derived aggregate exact. + byPartition.putAll(o.byPartition) + case _ => throw new UnsupportedOperationException( + s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") + } + + // A read-only VIEW over the live map -- no copy. Only the accumulator framework calls `value` + // (event log / `toInfo` / `toString`); our own code reads via `accumulatedNumPartitions` / + // `foldValues`. The view is thread-safe (ConcurrentHashMap) and weakly consistent, which matches + // this accumulator's eventual-consistency contract. + override def value: java.util.Map[Int, T] = java.util.Collections.unmodifiableMap(byPartition) + + /** Number of distinct partitions that have been recorded. */ + def accumulatedNumPartitions: Long = byPartition.size().toLong + + /** Folds the per-partition values (each partition counted once) into a single aggregate. */ + def foldValues[A](zero: A)(op: (A, T) => A): A = { + var result = zero + val it = byPartition.values().iterator() + while (it.hasNext) result = op(result, it.next()) + result + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 106ee36594b38..085dbcd804665 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -474,12 +474,12 @@ class CachedTableSuite extends SharedSparkSession val toBeCleanedAccIds = new HashSet[Long] val accId1 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id + case i: InMemoryRelation => i.cacheBuilder.materializationAccumulatorId }.head toBeCleanedAccIds += accId1 val accId2 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id + case i: InMemoryRelation => i.cacheBuilder.materializationAccumulatorId }.head toBeCleanedAccIds += accId2 @@ -509,6 +509,37 @@ class CachedTableSuite extends SharedSparkSession } } + test("SPARK-57547: clearCache resets materialization bookkeeping") { + val df = spark.range(0, 100, 1, numPartitions = 4).filter($"id" >= 0) + df.cache() + try { + val cacheRelations = df.queryExecution.withCachedData.collect { + case i: InMemoryRelation => i + } + assert(cacheRelations.length == 1) + val builder = cacheRelations.head.cacheBuilder + // Force the cache build directly (a plain df action can be served from the query-result + // cache and skip the rebuild after clearCache). + builder.cachedColumnBuffers.count() + assert(builder.isCachedColumnBuffersLoaded) + assert(builder.materializedRowCount == 100L) + + builder.clearCache() + // The loaded latch and the materialization stats must not survive clearCache, otherwise a + // rebuilt cache would inherit a stale "loaded" state with stale/zero statistics. + assert(!builder.isCachedColumnBuffersLoaded) + assert(builder.materializedRowCount == 0L) + assert(builder.materializedSizeInBytes == 0L) + + // Rebuilding works and reports correct stats again. + builder.cachedColumnBuffers.count() + assert(builder.isCachedColumnBuffersLoaded) + assert(builder.materializedRowCount == 100L) + } finally { + df.unpersist(blocking = true) + } + } + test("SPARK-10327 Cache Table is not working while subquery has alias in its project list") { withTempView("abc") { sparkContext.parallelize((1, 1) :: (2, 2) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ConcurrentInMemoryRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ConcurrentInMemoryRelationSuite.scala new file mode 100644 index 0000000000000..70f4cea3e39aa --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ConcurrentInMemoryRelationSuite.scala @@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.io.File +import java.util.concurrent.CountDownLatch + +import scala.concurrent.duration._ + +import org.scalatest.concurrent.Eventually +import org.scalatest.time.{Millis, Seconds, Span} + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.columnar.CachedBatch +import org.apache.spark.sql.functions.{lit, when} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * Regression test for SPARK-57547: concurrent first-touch of one cold table cache must not let + * duplicate partition computes silently drop rows. + * + * AQE creates a separate `TableCacheQueryStageExec` for every reference to the same cache (table + * cache stages are never reused), and each one submits its own build job over the shared cache RDD. + * A query that references a cached relation several times therefore first-touches the cold cache + * from several jobs at once. Spark has no global cross-executor "compute this partition once" + * barrier, so the same partition can be computed by multiple executors. If the cache decided it was + * "loaded" from a raw task-completion count (the legacy behavior), those duplicate completions + * could push the count to the partition count while a row-producing partition was still being + * computed, falsely marking the cache loaded with rowCount 0 -- which lets AQE propagate an empty + * relation and silently lose rows. + * + * The fix (gated by [[SQLConf.IN_MEMORY_DISTINCT_PARTITION_TRACKING]], default on) counts the + * DISTINCT set of materialized partitions instead, so duplicate computes can no longer mark the + * cache loaded early. These tests reproduce the data loss deterministically: a two-stage gate holds + * the row-producing partition while the empty-output partition's duplicate cross-executor + * completions accumulate. With the fix disabled the cache latches as loaded with rowCount 0 and a + * consumer observes an empty relation (lost rows); with the fix on the cache stays correctly + * not-loaded and the consumer observes every row. A multi-executor `local-cluster` session is + * required so the duplicate computes actually land on different executors. + */ +class ConcurrentInMemoryRelationSuite extends SparkFunSuite with LocalSparkContext with Eventually { + + private def cacheBuilderOf(ds: Dataset[_]): CachedRDDBuilder = { + val relations = ds.queryExecution.withCachedData.collect { case i: InMemoryRelation => i } + assert(relations.length == 1) + relations.head.cacheBuilder + } + + private def withSession(distinctTracking: Boolean, numExecutors: Int = 4)( + f: SparkSession => Unit): Unit = { + val conf = new SparkConf() + .setMaster(s"local-cluster[$numExecutors,1,1024]") + .setAppName("ConcurrentInMemoryRelationSuite") + .set(SQLConf.IN_MEMORY_DISTINCT_PARTITION_TRACKING.key, distinctTracking.toString) + sc = new SparkContext(conf) + try { + // Wait for all executors to register so tasks spread one-per-executor as the tests assume. + eventually(timeout(Span(60, Seconds)), interval(Span(200, Millis))) { + assert(sc.getExecutorIds().size == numExecutors) + } + f(SparkSession.builder().sparkContext(sc).getOrCreate()) + } finally { + resetSparkContext() + } + } + + /** + * Drives the actual SPARK-57547 data loss deterministically. + * + * Caches a skewed join with two shuffle partitions: every partition has non-empty INPUT (so + * neither is pruned as an empty task), but only the `skewKey` bucket produces OUTPUT rows -- so + * one partition is row-producing and the other produces zero rows. A two-stage gate blocks every + * partition's build inside `mapPartitions` until released. `numReferences` threads each submit + * their own build job over the shared cache RDD (exactly as per-reference + * `TableCacheQueryStageExec`s do); on `local-cluster[4,1,...]` (= numReferences x cachePartitions + * task slots, one task per executor) the empty-output partition is computed by both references on + * two distinct executors. + * + * Sequence: (1) the threads first-touch the cold cache, gating all `numReferences x + * cachePartitions` tasks; (2) release only the empty-output partition -- with the fix disabled + * its two cross-executor completions push the RAW materialized count to cachePartitions while the + * row-producing partition is still gated; (3) read `isCachedColumnBuffersLoaded`, which latches + * that poisoned state permanently (one-way latch), removing the timing race; (4) a consumer query + * (a GROUPED aggregate, so empty propagation can collapse it) observes the cache as materialized + * with rowCount 0 and AQE propagates an empty relation; (5) release the producing partition. + * Returns (rows the consumer observed, expected rows). + */ + private def runDataLossRepro(spark: SparkSession): (Long, Long) = { + import spark.implicits._ + val numKeys = 64 + val skewKey = 42 + val rowsPerKey = 500 + val numReferences = 2 + val cachePartitions = 2 // one row-producing, one empty-output (see shuffle.partitions below) + val expected = rowsPerKey.toLong * rowsPerKey.toLong // only the skewKey bucket joins + + // Exactly two shuffle partitions (one row-producing, one empty-output), no broadcast so the + // join shuffles, and build the cache with AQE off so the skewed producing partition is not + // rebalanced away (which would defuse the window). Consumers below run with AQE on so they go + // through TableCacheQueryStageExec + empty-relation propagation. + spark.conf.set("spark.sql.shuffle.partitions", "2") + spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1") + spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "false") + spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "false") + spark.conf.set("spark.sql.adaptive.enabled", "false") + + val gateDir = Utils.createTempDir() + def file(name: String) = new File(gateDir, name) + val releaseEmpty = file("releaseEmpty").getAbsolutePath + val releaseProducing = file("releaseProducing").getAbsolutePath + val entryDir = gateDir.getAbsolutePath + + def side(matchSalt: Int, valueCol: String): DataFrame = + spark.range(0, numKeys.toLong * rowsPerKey).select( + ($"id" % numKeys).cast("int").as("k"), + when(($"id" % numKeys) === skewKey, lit(0)).otherwise(lit(matchSalt)).as("salt"), + $"id".as(valueCol)) + + val joined = side(1, "lv") + .join(side(2, "rv"), Seq("k", "salt")) + .select($"k", $"lv").as[(Int, Long)] + + // Two-stage gate: every partition signals it has entered (past the block-existence check) and + // waits for releaseEmpty; the row-producing partition (the only one with rows) waits longer. + val gated = joined.mapPartitions { iter => + val buffered = iter.buffered + val isProducing = buffered.hasNext + file(s"entered-${java.util.UUID.randomUUID()}").createNewFile() + def waitFor(path: String): Unit = { + val deadline = System.currentTimeMillis() + 60000 + while (!new File(path).exists() && System.currentTimeMillis() < deadline) Thread.sleep(50) + } + waitFor(releaseEmpty) + if (isProducing) waitFor(releaseProducing) + buffered + }.toDF("k", "lv") + + val cached = gated.cache() + try { + val builder = cacheBuilderOf(cached) + // Cache plan captured (static, 2 partitions); consumers from here on use AQE. + spark.conf.set("spark.sql.adaptive.enabled", "true") + // Every reference launches its own build job over the shared cache RDD (no dedup at this + // layer), so the empty partition is computed by every reference: numReferences x + // cachePartitions gated tasks. + val expectedEntries = numReferences * cachePartitions + val rdd = builder.cachedColumnBuffers + val submitted = new CountDownLatch(numReferences) + val pool = ThreadUtils.newDaemonFixedThreadPool(numReferences, "spark57547-dataloss") + try { + val firstTouch = (1 to numReferences).map { _ => + pool.submit(new java.util.concurrent.Callable[Unit] { + override def call(): Unit = { + val f = spark.sparkContext.submitJob( + rdd, + (_: Iterator[CachedBatch]) => (), + 0 until rdd.getNumPartitions, + (_: Int, _: Unit) => (), + ()) + submitted.countDown() + ThreadUtils.awaitResult(f, 120.seconds) + } + }) + } + assert(submitted.await(60, java.util.concurrent.TimeUnit.SECONDS)) + // Wait until every build task is parked at the gate (all have passed the block-existence + // check), so releasing the empty partition forces its cross-executor completions to run. + eventually(timeout(Span(60, Seconds)), interval(Span(100, Millis))) { + val entered = new File(entryDir).listFiles().count(_.getName.startsWith("entered-")) + assert(entered == expectedEntries, s"entered=$entered expected=$expectedEntries") + } + + // Stage 1: release ONLY the empty-output partition; the producing partition stays gated. + assert(new File(releaseEmpty).createNewFile()) + + // With the fix off, the empty partition's duplicate cross-executor completions push the raw + // materialized count to cachePartitions even though the producing partition has not run, so + // the cache latches as "loaded" with rowCount 0. We read it through the relation handle -- + // exactly what AQE's stats reads do in production -- and the one-way latch makes the poison + // permanent. That removes the timing race: the producing partition is still gated when the + // consumer runs below. With the fix on, distinct-partition accounting keeps the cache not + // loaded here, so this poll times out and we fall through to a normal (complete) build. + val poisoned = + try { + eventually(timeout(Span(30, Seconds)), interval(Span(100, Millis))) { + assert(builder.isCachedColumnBuffersLoaded) + } + true + } catch { + case _: org.scalatest.exceptions.TestFailedException => false + } + + // A GROUPED aggregate (not a global count): AQE empty-relation propagation collapses the + // whole result when the cache stage is (falsely) reported as a zero-row materialized stage; + // a global aggregate over empty would still emit one row and mask the loss. + val observed = if (poisoned) { + // The cache lied (loaded with rowCount 0 while the producing partition is still gated and + // unbuilt). The consumer plans against it and AQE propagates an empty relation, so the + // rows silently vanish. The producing partition stays gated, so this is deterministic. + val consumer = cached.groupBy("k").count() + val rows = consumer.collect() + assert(consumer.queryExecution.executedPlan.toString.contains("EmptyRelation"), + "expected AQE to propagate an empty relation from the poisoned cache stage") + assert(new File(releaseProducing).createNewFile()) // unblock the build for clean shutdown + rows.map(_.getLong(1)).sum + } else { + // Fix on: the cache is correctly not loaded, so let the producing partition finish and + // the consumer observes every row. + assert(new File(releaseProducing).createNewFile()) + cached.groupBy("k").count().collect().map(_.getLong(1)).sum + } + firstTouch.foreach(_.get(120, java.util.concurrent.TimeUnit.SECONDS)) + (observed, expected) + } finally { + pool.shutdown() + } + } finally { + cached.unpersist(blocking = true) + Utils.deleteRecursively(gateDir) + } + } + + /** + * Builds a cold cache whose partitions all carry rows and first-touches it concurrently from + * `numReferences` jobs with every partition gated, so each partition is computed once per + * reference on a distinct executor (`numReferences` duplicate cross-executor computes per + * partition). Returns (reported materialized row count, expected rows); with distinct-partition + * tracking on, the keyed accumulator de-duplicates the duplicate computes so the count is exact. + */ + private def runDuplicateComputeStats(spark: SparkSession): (Long, Long) = { + import spark.implicits._ + val numReferences = 2 + val cachePartitions = 2 + val numRows = 200L // split evenly across the partitions; every partition is non-empty + + val gateDir = Utils.createTempDir() + def file(name: String) = new File(gateDir, name) + val release = file("release").getAbsolutePath + val entryDir = gateDir.getAbsolutePath + + // Every partition has rows and blocks at the gate until released, so all references' build + // tasks are in flight (past the block-existence check) before any completes -- forcing the + // duplicate cross-executor computes that the per-batch accumulator would over-count. + val cached = spark.range(0, numRows, 1, cachePartitions).as[Long].mapPartitions { iter => + file(s"entered-${java.util.UUID.randomUUID()}").createNewFile() + val deadline = System.currentTimeMillis() + 60000 + while (!new File(release).exists() && System.currentTimeMillis() < deadline) Thread.sleep(50) + iter + }.cache() + try { + val builder = cacheBuilderOf(cached) + val rdd = builder.cachedColumnBuffers + val submitted = new CountDownLatch(numReferences) + val pool = ThreadUtils.newDaemonFixedThreadPool(numReferences, "spark57547-stats") + try { + val futures = (1 to numReferences).map { _ => + pool.submit(new java.util.concurrent.Callable[Unit] { + override def call(): Unit = { + val f = spark.sparkContext.submitJob( + rdd, + (_: Iterator[CachedBatch]) => (), + 0 until rdd.getNumPartitions, + (_: Int, _: Unit) => (), + ()) + submitted.countDown() + ThreadUtils.awaitResult(f, 120.seconds) + } + }) + } + assert(submitted.await(60, java.util.concurrent.TimeUnit.SECONDS)) + // Wait until every reference's task for every partition is parked at the gate, then release + // them so each partition is computed once per reference. + eventually(timeout(Span(60, Seconds)), interval(Span(100, Millis))) { + val entered = new File(entryDir).listFiles().count(_.getName.startsWith("entered-")) + assert(entered == numReferences * cachePartitions, s"entered=$entered") + } + assert(new File(release).createNewFile()) + futures.foreach(_.get(120, java.util.concurrent.TimeUnit.SECONDS)) + assert(builder.isCachedColumnBuffersLoaded) + (builder.materializedRowCount, numRows) + } finally { + pool.shutdown() + } + } finally { + cached.unpersist(blocking = true) + Utils.deleteRecursively(gateDir) + } + } + + test("SPARK-57547: concurrent first-touch of a cold cache does not lose rows") { + withSession(distinctTracking = true) { spark => + val (observed, expected) = runDataLossRepro(spark) + assert(observed == expected, s"consumer observed $observed rows, expected $expected") + } + } + + test("SPARK-57547: cache statistics are exact under duplicate cross-executor computes") { + // Every partition is computed by both references, so the partition-keyed accumulator sees a + // duplicate `add` per partition. Last-write-wins de-duplication keeps the reported row count + // exact -- a naive summing accumulator would over-count under these duplicate computes. + withSession(distinctTracking = true) { spark => + val (rowCount, expected) = runDuplicateComputeStats(spark) + assert(rowCount == expected, + s"partition-keyed accumulator should report exact row count $expected, got $rowCount") + } + } + + test("SPARK-57547: with the fix disabled, concurrent first-touch silently loses rows") { + // Demonstrates the actual data loss the fix prevents: a consumer that plans against the + // poisoned cache observes an empty relation instead of the real rows. + withSession(distinctTracking = false) { spark => + val (observed, expected) = runDataLossRepro(spark) + assert(observed == 0L, + s"expected the buggy path to lose all $expected rows but consumer observed $observed") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 5cd62302861ae..57da12e87979a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -361,7 +361,7 @@ class InMemoryColumnarQuerySuite extends SharedSparkSession with AdaptiveSparkPl checkAnswer(cached, expectedAnswer) // Check that the right size was calculated. - assert(cached.cacheBuilder.sizeInBytesStats.value === expectedAnswer.length * INT.defaultSize) + assert(cached.cacheBuilder.materializedSizeInBytes === expectedAnswer.length * INT.defaultSize) } test("cached row count should be calculated") { @@ -375,7 +375,7 @@ class InMemoryColumnarQuerySuite extends SharedSparkSession with AdaptiveSparkPl checkAnswer(cached, expectedAnswer) // Check that the right row count was calculated. - assert(cached.cacheBuilder.rowCountStats.value === 6) + assert(cached.cacheBuilder.materializedRowCount === 6) } test("access primitive-type columns in CachedBatch without whole stage codegen") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/PartitionKeyedAccumulatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/PartitionKeyedAccumulatorSuite.scala new file mode 100644 index 0000000000000..19e499942e310 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/PartitionKeyedAccumulatorSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import org.apache.spark.SparkFunSuite + +class PartitionKeyedAccumulatorSuite extends SparkFunSuite { + + // The cache use case records (rowCount, sizeInBytes) per partition. + private type Stats = (Long, Long) + + private def sumRows(acc: PartitionKeyedAccumulator[Stats]): Long = + acc.foldValues(0L)((sum, v) => sum + v._1) + + private def sumBytes(acc: PartitionKeyedAccumulator[Stats]): Long = + acc.foldValues(0L)((sum, v) => sum + v._2) + + test("isZero, add, value and accumulatedNumPartitions") { + val acc = new PartitionKeyedAccumulator[Stats] + assert(acc.isZero) + assert(acc.accumulatedNumPartitions == 0) + assert(acc.value.isEmpty) + + acc.add((0, (10L, 100L))) + assert(!acc.isZero) + assert(acc.accumulatedNumPartitions == 1) + assert(acc.value.get(0) == ((10L, 100L))) + + acc.add((1, (5L, 50L))) + assert(acc.accumulatedNumPartitions == 2) + assert(sumRows(acc) == 15L) + assert(sumBytes(acc) == 150L) + } + + test("add is last-write-wins for the same partition id") { + val acc = new PartitionKeyedAccumulator[Stats] + acc.add((0, (1L, 1L))) + acc.add((0, (2L, 2L))) // re-records partition 0 (e.g. a recompute) + assert(acc.accumulatedNumPartitions == 1) + assert(sumRows(acc) == 2L) // the later value wins, not 1 + 2 + assert(sumBytes(acc) == 2L) + } + + test("merge is last-write-wins per partition id (de-duplicates, does not sum)") { + // Two references compute the same partitions; partition 0 is computed by both. + val a = new PartitionKeyedAccumulator[Stats] + a.add((0, (10L, 100L))) + + val b = new PartitionKeyedAccumulator[Stats] + b.add((0, (10L, 100L))) // duplicate compute of partition 0 + b.add((1, (5L, 50L))) + + a.merge(b) + assert(a.accumulatedNumPartitions == 2) // partitions {0, 1}, not 3 + assert(sumRows(a) == 15L) // 10 (partition 0, counted once) + 5, NOT 25 + assert(sumBytes(a) == 150L) + } + + test("copy is an independent snapshot") { + val acc = new PartitionKeyedAccumulator[Stats] + acc.add((0, (10L, 100L))) + val snapshot = acc.copy() + acc.add((1, (5L, 50L))) // mutate the original after copying + + assert(snapshot.accumulatedNumPartitions == 1) + assert(sumRows(snapshot) == 10L) + assert(acc.accumulatedNumPartitions == 2) + assert(sumRows(acc) == 15L) + } + + test("reset and copyAndReset") { + val acc = new PartitionKeyedAccumulator[Stats] + acc.add((0, (10L, 100L))) + assert(!acc.isZero) + + assert(acc.copyAndReset().isZero) + assert(!acc.isZero) // copyAndReset does not mutate the source + + acc.reset() + assert(acc.isZero) + assert(acc.accumulatedNumPartitions == 0) + } + + test("works for an arbitrary value type") { + val acc = new PartitionKeyedAccumulator[String] + acc.add((0, "a")) + acc.add((1, "b")) + acc.add((0, "c")) // last-write-wins + assert(acc.accumulatedNumPartitions == 2) + assert(acc.foldValues("")((s, v) => s + v).length == 2) // "c" + "b" (each partition once) + } +} From 48ef1c49affb031fd3c1655c959a3f20ca1c78cb Mon Sep 17 00:00:00 2001 From: ziqi liu Date: Sat, 20 Jun 2026 01:55:29 +0000 Subject: [PATCH 2/4] [SPARK-57547][SQL][FOLLOWUP] Always track distinct partitions; drop the conf and legacy path Addresses review feedback on the original change: the legacy raw task-completion-count behavior is simply incorrect (it is what loses rows), and OSS has no backport-auditing or emergency-rollback process that would need a kill switch for it. So this removes the `spark.sql.inMemoryColumnarStorage.distinctPartitionTracking` conf and the legacy code path entirely, and tracks the distinct set of materialized partitions unconditionally. `InMemoryRelation`'s `Either[LegacyAccumulators, PartitionKeyedAccumulator]` (and the conf branch that chose between them) collapses to a single `PartitionKeyedAccumulator`; `clearCache`, `isCachedRDDLoaded`, the `materialized*` accessors, and `buildBuffers` lose their dual-path handling. The now-unused `LongAccumulator` and `scala.util.{Left, Right}` imports are dropped. Also fixes the Java unidoc `reference not found` errors by replacing the `[[AccumulatorV2]]` / `[[ConcurrentHashMap]]` scaladoc links in `PartitionKeyedAccumulator` (whose generated `{@link}`s javadoc could not resolve) with backtick code spans. ### Tests - `ConcurrentInMemoryRelationSuite`: removed the conf toggle and the disabled-mode negative-control test (the disabled path no longer exists); the deterministic `runDataLossRepro` stays and still detects a regression (its `poisoned` branch would make `observed != expected`). - `PartitionKeyedAccumulatorSuite`, `CachedTableSuite`, and `InMemoryColumnarQuerySuite` are unchanged and still pass. Co-authored-by: Isaac --- .../apache/spark/sql/internal/SQLConf.scala | 15 -- .../execution/columnar/InMemoryRelation.scala | 147 +++++------------- .../sql/util/PartitionKeyedAccumulator.scala | 6 +- .../ConcurrentInMemoryRelationSuite.scala | 68 ++++---- 4 files changed, 73 insertions(+), 163 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 04658ab3b0a61..2813dad9400ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -861,21 +861,6 @@ object SQLConf { .booleanConf .createWithDefault(true) - val IN_MEMORY_DISTINCT_PARTITION_TRACKING = - buildConf("spark.sql.inMemoryColumnarStorage.distinctPartitionTracking") - .internal() - .doc("When true, a cached relation is considered fully materialized only once the set of " + - "DISTINCT materialized partitions covers every partition, rather than once a raw " + - "task-completion count reaches the partition count; clearCache also resets that " + - "bookkeeping. This prevents concurrent first-touch of a cold cache under AQE (where a " + - "cached source referenced several times has each reference launch its own build job) " + - "from letting duplicate partition completions mark the cache loaded with rowCount 0, " + - "which would make AQE wrongly propagate an empty relation and silently drop rows. When " + - "false, restores the prior raw task-completion-count behavior.") - .version("5.0.0") - .booleanConf - .createWithDefault(true) - val IN_MEMORY_TABLE_SCAN_STATISTICS_ENABLED = buildConf("spark.sql.inMemoryTableScanStatistics.enable") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 98979db9c3ae0..f79742907779f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.columnar -import scala.util.{Left, Right} - import com.esotericsoftware.kryo.{DefaultSerializer, Kryo, Serializer => KryoSerializer} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} @@ -41,8 +39,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.PartitionKeyedAccumulator import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{LongAccumulator, Utils} import org.apache.spark.util.ArrayImplicits._ +import org.apache.spark.util.Utils /** * The default implementation of CachedBatch. @@ -253,17 +251,6 @@ class DefaultCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { } } -// The legacy per-build accumulators, allocated together only on the off path (the `Left` of -// `statsAccumulators` below). `sizeInBytesStats` / `rowCountStats` are the per-batch stat sums -// (the reported row count / size on that path); `materializedPartitions` is the raw task-completion -// count driving the loaded check. All over-count under duplicate computes -- the pre-fix behavior. -// Top-level (not nested in CachedRDDBuilder) so the accumulators carry no outer reference when -// captured by the build task closure. -private case class LegacyAccumulators( - sizeInBytesStats: LongAccumulator, - rowCountStats: LongAccumulator, - materializedPartitions: LongAccumulator) - private[sql] case class CachedRDDBuilder( serializer: CachedBatchSerializer, @@ -275,30 +262,20 @@ case class CachedRDDBuilder( @transient @volatile private var _cachedColumnBuffers: RDD[CachedBatch] = null @transient @volatile private var _cachedColumnBuffersAreLoaded: Boolean = false - // The cache's materialization bookkeeping, chosen ONCE here (at construction) from the - // IN_MEMORY_DISTINCT_PARTITION_TRACKING conf, so the scheme can never change mid-build. Exactly - // one branch is allocated: - // - Right (the fix, default): a partition-keyed accumulator storing (rowCount, sizeInBytes) per - // partition. AQE creates a separate cache scan stage per reference to the same cache and each - // submits its own build job, so the same partition can be computed by several concurrent jobs - // (and speculative tasks); Spark has no global cross-executor "compute this partition once" - // barrier (only a per-executor write lock). Keying by partition id (last-write-wins) means - // those duplicate completions cannot mark the cache loaded before every partition has been - // computed -- which otherwise let AQE read rowCount 0 on a non-empty cache and propagate an - // empty relation, silently dropping rows -- and also yields exact, de-duplicated row count / - // size. - // - Left (conf off): the raw per-batch accumulators (`LegacyAccumulators`) -- the buggy pre-fix - // behavior, kept only as a safety switch. - private val statsAccumulators - : Either[LegacyAccumulators, PartitionKeyedAccumulator[(Long, Long)]] = - if (cachedPlan.conf.getConf(SQLConf.IN_MEMORY_DISTINCT_PARTITION_TRACKING)) { - val acc = new PartitionKeyedAccumulator[(Long, Long)] - cachedPlan.session.sparkContext.register(acc) - Right(acc) - } else { - val sc = cachedPlan.session.sparkContext - Left(LegacyAccumulators(sc.longAccumulator, sc.longAccumulator, sc.longAccumulator)) - } + // The cache's materialization bookkeeping: a partition-keyed accumulator storing + // (rowCount, sizeInBytes) per partition. AQE creates a separate cache scan stage per reference to + // the same cache and each submits its own build job, so the same partition can be computed by + // several concurrent jobs (and speculative tasks); Spark has no global cross-executor "compute + // this partition once" barrier (only a per-executor write lock). Keying by partition id + // (last-write-wins) means those duplicate completions cannot mark the cache loaded before every + // partition has been computed -- which otherwise let AQE read rowCount 0 on a non-empty cache and + // propagate an empty relation, silently dropping rows -- and also yields exact, de-duplicated row + // count / size. + private val partitionStats: PartitionKeyedAccumulator[(Long, Long)] = { + val acc = new PartitionKeyedAccumulator[(Long, Long)] + cachedPlan.session.sparkContext.register(acc) + acc + } val cachedName = tableName.map(n => s"In-memory table $n") .getOrElse(Utils.abbreviate(cachedPlan.toString, 1024)) @@ -319,16 +296,11 @@ case class CachedRDDBuilder( if (_cachedColumnBuffers != null) { _cachedColumnBuffers.unpersist(blocking) _cachedColumnBuffers = null - statsAccumulators match { - case Right(keyed) => - // The buffers no longer back a live RDD. Reset the one-way "loaded" latch and the keyed - // bookkeeping so a rebuild on this builder does not inherit a stale "loaded" state or - // stale statistics. Safe to reset in place: every read of `keyed` is under this monitor. - _cachedColumnBuffersAreLoaded = false - keyed.reset() - case Left(_) => - // Pre-fix behavior, kept only for the safety switch: do not reset the bookkeeping. - } + // The buffers no longer back a live RDD. Reset the one-way "loaded" latch and the keyed + // bookkeeping so a rebuild on this builder does not inherit a stale "loaded" state or stale + // statistics. Safe to reset in place: every read of the accumulator is under this monitor. + _cachedColumnBuffersAreLoaded = false + partitionStats.reset() } } @@ -341,15 +313,10 @@ case class CachedRDDBuilder( // We must make sure the statistics of `sizeInBytes` and `rowCount` are accurate if // `isCachedRDDLoaded` return true. Otherwise, AQE would do a wrong optimization, // e.g., convert a non-empty plan to empty local relation if `rowCount` is 0. - val numMaterialized = statsAccumulators match { - // Count DISTINCT materialized partitions (the keyed accumulator's key set), so the cache is - // only reported loaded once every partition has been computed -- sound even if a partition - // is computed more than once by concurrent or speculative tasks. - case Right(keyed) => keyed.accumulatedNumPartitions - // Buggy pre-fix behavior (safety switch only): a raw task-completion count, which duplicate - // computes can inflate past the partition count before every partition is materialized. - case Left(legacy) => legacy.materializedPartitions.value.longValue - } + // Count DISTINCT materialized partitions (the keyed accumulator's key set), so the cache is + // only reported loaded once every partition has been computed -- sound even if a partition is + // computed more than once by concurrent or speculative tasks. + val numMaterialized = partitionStats.accumulatedNumPartitions val rddLoaded = _cachedColumnBuffers.partitions.length.toLong == numMaterialized if (rddLoaded) { _cachedColumnBuffersAreLoaded = rddLoaded @@ -358,31 +325,20 @@ case class CachedRDDBuilder( } } - // Reported row count / size for the cache's statistics. On the default path these are exact and - // de-duplicated (folded over distinct materialized partitions); on the off path they fall back to - // the raw per-batch accumulators (which can over-count). Synchronized so a fold never races a - // concurrent `clearCache` reset. + // Reported row count / size for the cache's statistics: exact and de-duplicated, folded over the + // distinct materialized partitions. Synchronized so a fold never races a concurrent `clearCache` + // reset. private[sql] def materializedRowCount: Long = synchronized { - statsAccumulators match { - case Right(keyed) => keyed.foldValues(0L)((sum, v) => sum + v._1) - case Left(legacy) => legacy.rowCountStats.value.longValue - } + partitionStats.foldValues(0L)((sum, v) => sum + v._1) } private[sql] def materializedSizeInBytes: Long = synchronized { - statsAccumulators match { - case Right(keyed) => keyed.foldValues(0L)((sum, v) => sum + v._2) - case Left(legacy) => legacy.sizeInBytesStats.value.longValue - } + partitionStats.foldValues(0L)((sum, v) => sum + v._2) } - // The id of the accumulator backing this cache's materialization bookkeeping (the keyed - // accumulator on the default path, a legacy stat accumulator otherwise). Exposed only so + // The id of the accumulator backing this cache's materialization bookkeeping. Exposed only so // `CachedTableSuite`'s accumulator-cleanup test can verify it is cleared after uncache + GC. - private[sql] def materializationAccumulatorId: Long = statsAccumulators match { - case Left(legacy) => legacy.sizeInBytesStats.id - case Right(keyed) => keyed.id - } + private[sql] def materializationAccumulatorId: Long = partitionStats.id private def buildBuffers(): RDD[CachedBatch] = { val cb = try { @@ -408,50 +364,29 @@ case class CachedRDDBuilder( session.sharedState.cacheManager.recacheByPlan(session, logicalPlan) throw e } - // Records one successful partition materialization, using the scheme selected by the conf. - // Built once on the driver so the task closure below captures only the chosen accumulator. The - // default path records this partition's (rows, bytes) keyed by its id; the off path just bumps - // a raw completion count. - val recordMaterialized: (Int, Long, Long) => Unit = statsAccumulators match { - case Right(keyed) => - (partitionId: Int, rows: Long, bytes: Long) => - keyed.add((partitionId, (rows, bytes))) - case Left(legacy) => - val completions = legacy.materializedPartitions - (_: Int, _: Long, _: Long) => completions.add(1L) - } - // On the default path the keyed accumulator is authoritative for the cache stats, so the legacy - // per-batch accumulators are not fed; only the off path uses them. Extract them once on the - // driver (None on the default path) so the task closure captures just what it needs. - val legacyStatAccs: Option[(LongAccumulator, LongAccumulator)] = statsAccumulators match { - case Left(legacy) => Some((legacy.sizeInBytesStats, legacy.rowCountStats)) - case Right(_) => None - } + // Records one successful partition materialization: this partition's (rows, bytes) keyed by its + // id. Bound to a local so the task closure below captures only the accumulator, not the + // enclosing CachedRDDBuilder (whose cachedPlan is not serializable). + val accumulator = partitionStats val cached = cb.mapPartitionsInternal { it => val taskContext = TaskContext.get() val partitionId = taskContext.partitionId() - // This task computes exactly one partition. On the default path, tally its totals so the - // completion listener records them once (covering empty-output partitions, which produce no - // batches); on the off path the raw per-batch accumulators below are fed directly. + // This task computes exactly one partition. Tally its totals so the completion listener + // records them once, keyed by partition id (covering empty-output partitions, which produce + // no batches). var localRows = 0L var localBytes = 0L taskContext.addTaskCompletionListener[Unit] { context => if (!context.isFailed() && !context.isInterrupted()) { - recordMaterialized(partitionId, localRows, localBytes) + accumulator.add((partitionId, (localRows, localBytes))) } } new Iterator[CachedBatch] { override def hasNext: Boolean = it.hasNext override def next(): CachedBatch = { val batch = it.next() - legacyStatAccs match { - case Some((sizeAcc, rowAcc)) => - sizeAcc.add(batch.sizeInBytes) - rowAcc.add(batch.numRows) - case None => - localBytes += batch.sizeInBytes - localRows += batch.numRows - } + localBytes += batch.sizeInBytes + localRows += batch.numRows batch } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala index 4785010571a18..30252a4a51bbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala @@ -22,7 +22,7 @@ import java.util.concurrent.ConcurrentHashMap import org.apache.spark.util.AccumulatorV2 /** - * An [[AccumulatorV2]] that records one value of type `T` per partition, keyed by partition id with + * An `AccumulatorV2` that records one value of type `T` per partition, keyed by partition id with * LAST-WRITE-WINS merge. When the same partition is recorded more than once -- e.g. duplicate * cross-executor computes, or speculative tasks -- the later value replaces the earlier one rather * than aggregating, so each partition contributes exactly once. The key set is the set of recorded @@ -34,13 +34,13 @@ import org.apache.spark.util.AccumulatorV2 * failed/interrupted tasks are dropped by the accumulator framework (it is not * `countFailedValues`), so only complete per-partition values are ever merged. * - * Backed by a [[ConcurrentHashMap]], whose per-entry atomicity is sufficient here: `add` and the + * Backed by a `ConcurrentHashMap`, whose per-entry atomicity is sufficient here: `add` and the * `putAll` in `merge` are last-write-wins per key, and the reads (`value`, `numPartitions`, * `foldValues`) only require thread-safety and eventual consistency -- they are weakly consistent * during concurrent updates but exact once all updates have been merged. This avoids any explicit * locking (and the nested-lock pattern a two-map `merge` would otherwise need). * - * @tparam T the per-partition value type. Must be non-null ([[ConcurrentHashMap]] forbids nulls). + * @tparam T the per-partition value type. Must be non-null (`ConcurrentHashMap` forbids nulls). */ class PartitionKeyedAccumulator[T] extends AccumulatorV2[(Int, T), java.util.Map[Int, T]] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ConcurrentInMemoryRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ConcurrentInMemoryRelationSuite.scala index 70f4cea3e39aa..0f8faceaa2b3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ConcurrentInMemoryRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ConcurrentInMemoryRelationSuite.scala @@ -29,7 +29,6 @@ import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSui import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.columnar.CachedBatch import org.apache.spark.sql.functions.{lit, when} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{ThreadUtils, Utils} /** @@ -46,14 +45,15 @@ import org.apache.spark.util.{ThreadUtils, Utils} * computed, falsely marking the cache loaded with rowCount 0 -- which lets AQE propagate an empty * relation and silently lose rows. * - * The fix (gated by [[SQLConf.IN_MEMORY_DISTINCT_PARTITION_TRACKING]], default on) counts the - * DISTINCT set of materialized partitions instead, so duplicate computes can no longer mark the - * cache loaded early. These tests reproduce the data loss deterministically: a two-stage gate holds - * the row-producing partition while the empty-output partition's duplicate cross-executor - * completions accumulate. With the fix disabled the cache latches as loaded with rowCount 0 and a - * consumer observes an empty relation (lost rows); with the fix on the cache stays correctly - * not-loaded and the consumer observes every row. A multi-executor `local-cluster` session is - * required so the duplicate computes actually land on different executors. + * The fix counts the DISTINCT set of materialized partitions instead, so duplicate computes can no + * longer mark the cache loaded early. These tests reproduce the race deterministically: a two-stage + * gate holds the row-producing partition while the empty-output partition's duplicate cross-executor + * completions accumulate. With distinct tracking the cache stays correctly not-loaded while a + * partition is still building, so the consumer observes every row; were the loaded check to fall + * back to a raw task-completion count it would latch the cache as loaded with rowCount 0 and let AQE + * propagate an empty relation, losing rows (which the repro detects as a row-count mismatch). A + * multi-executor `local-cluster` session is required so the duplicate computes land on different + * executors. */ class ConcurrentInMemoryRelationSuite extends SparkFunSuite with LocalSparkContext with Eventually { @@ -63,12 +63,11 @@ class ConcurrentInMemoryRelationSuite extends SparkFunSuite with LocalSparkConte relations.head.cacheBuilder } - private def withSession(distinctTracking: Boolean, numExecutors: Int = 4)( + private def withSession(numExecutors: Int = 4)( f: SparkSession => Unit): Unit = { val conf = new SparkConf() .setMaster(s"local-cluster[$numExecutors,1,1024]") .setAppName("ConcurrentInMemoryRelationSuite") - .set(SQLConf.IN_MEMORY_DISTINCT_PARTITION_TRACKING.key, distinctTracking.toString) sc = new SparkContext(conf) try { // Wait for all executors to register so tasks spread one-per-executor as the tests assume. @@ -94,13 +93,14 @@ class ConcurrentInMemoryRelationSuite extends SparkFunSuite with LocalSparkConte * two distinct executors. * * Sequence: (1) the threads first-touch the cold cache, gating all `numReferences x - * cachePartitions` tasks; (2) release only the empty-output partition -- with the fix disabled - * its two cross-executor completions push the RAW materialized count to cachePartitions while the - * row-producing partition is still gated; (3) read `isCachedColumnBuffersLoaded`, which latches - * that poisoned state permanently (one-way latch), removing the timing race; (4) a consumer query - * (a GROUPED aggregate, so empty propagation can collapse it) observes the cache as materialized - * with rowCount 0 and AQE propagates an empty relation; (5) release the producing partition. - * Returns (rows the consumer observed, expected rows). + * cachePartitions` tasks; (2) release only the empty-output partition, so its two cross-executor + * completions land while the row-producing partition is still gated; (3) poll + * `isCachedColumnBuffersLoaded` -- distinct-partition tracking keeps it false (a raw + * task-completion count would instead reach cachePartitions here and latch a poisoned "loaded" + * state with rowCount 0); (4) a consumer query (a GROUPED aggregate, where empty propagation could + * collapse the result) sees the cache not-loaded and plans against the real rows -- had it been + * poisoned, AQE would have propagated an empty relation and dropped rows; (5) release the + * producing partition. Returns (rows the consumer observed, expected rows), equal unless poisoned. */ private def runDataLossRepro(spark: SparkSession): (Long, Long) = { import spark.implicits._ @@ -190,13 +190,13 @@ class ConcurrentInMemoryRelationSuite extends SparkFunSuite with LocalSparkConte // Stage 1: release ONLY the empty-output partition; the producing partition stays gated. assert(new File(releaseEmpty).createNewFile()) - // With the fix off, the empty partition's duplicate cross-executor completions push the raw - // materialized count to cachePartitions even though the producing partition has not run, so - // the cache latches as "loaded" with rowCount 0. We read it through the relation handle -- - // exactly what AQE's stats reads do in production -- and the one-way latch makes the poison - // permanent. That removes the timing race: the producing partition is still gated when the - // consumer runs below. With the fix on, distinct-partition accounting keeps the cache not - // loaded here, so this poll times out and we fall through to a normal (complete) build. + // Were the loaded check to fall back to a raw task-completion count, the empty partition's + // duplicate cross-executor completions would push that count to cachePartitions even though + // the producing partition has not run, latching the cache as "loaded" with rowCount 0. We + // read it through the relation handle -- exactly what AQE's stats reads do in production -- + // and the one-way latch would make the poison permanent (the producing partition is still + // gated when the consumer runs below). With distinct-partition accounting the cache stays + // not loaded here, so this poll times out and we fall through to a normal (complete) build. val poisoned = try { eventually(timeout(Span(30, Seconds)), interval(Span(100, Millis))) { @@ -221,8 +221,8 @@ class ConcurrentInMemoryRelationSuite extends SparkFunSuite with LocalSparkConte assert(new File(releaseProducing).createNewFile()) // unblock the build for clean shutdown rows.map(_.getLong(1)).sum } else { - // Fix on: the cache is correctly not loaded, so let the producing partition finish and - // the consumer observes every row. + // The cache is correctly not loaded, so let the producing partition finish and the + // consumer observes every row. assert(new File(releaseProducing).createNewFile()) cached.groupBy("k").count().collect().map(_.getLong(1)).sum } @@ -305,7 +305,7 @@ class ConcurrentInMemoryRelationSuite extends SparkFunSuite with LocalSparkConte } test("SPARK-57547: concurrent first-touch of a cold cache does not lose rows") { - withSession(distinctTracking = true) { spark => + withSession() { spark => val (observed, expected) = runDataLossRepro(spark) assert(observed == expected, s"consumer observed $observed rows, expected $expected") } @@ -315,20 +315,10 @@ class ConcurrentInMemoryRelationSuite extends SparkFunSuite with LocalSparkConte // Every partition is computed by both references, so the partition-keyed accumulator sees a // duplicate `add` per partition. Last-write-wins de-duplication keeps the reported row count // exact -- a naive summing accumulator would over-count under these duplicate computes. - withSession(distinctTracking = true) { spark => + withSession() { spark => val (rowCount, expected) = runDuplicateComputeStats(spark) assert(rowCount == expected, s"partition-keyed accumulator should report exact row count $expected, got $rowCount") } } - - test("SPARK-57547: with the fix disabled, concurrent first-touch silently loses rows") { - // Demonstrates the actual data loss the fix prevents: a consumer that plans against the - // poisoned cache observes an empty relation instead of the real rows. - withSession(distinctTracking = false) { spark => - val (observed, expected) = runDataLossRepro(spark) - assert(observed == 0L, - s"expected the buggy path to lose all $expected rows but consumer observed $observed") - } - } } From 577e053ce14c7f9eb4f47d190876617a427527a3 Mon Sep 17 00:00:00 2001 From: ziqi liu Date: Sat, 20 Jun 2026 05:32:00 +0000 Subject: [PATCH 3/4] [SPARK-57547][SQL][FOLLOWUP] Wrap test scaladoc to satisfy scalastyle line length The prior follow-up's reworded scaladoc in ConcurrentInMemoryRelationSuite left four comment lines over 100 characters (the Scala linter checks test sources too, which a main-only scalastyle run missed). Re-wrap them; no behavior change. Co-authored-by: Isaac --- .../ConcurrentInMemoryRelationSuite.scala | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ConcurrentInMemoryRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ConcurrentInMemoryRelationSuite.scala index 0f8faceaa2b3f..161f8ec647a22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ConcurrentInMemoryRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ConcurrentInMemoryRelationSuite.scala @@ -45,15 +45,15 @@ import org.apache.spark.util.{ThreadUtils, Utils} * computed, falsely marking the cache loaded with rowCount 0 -- which lets AQE propagate an empty * relation and silently lose rows. * - * The fix counts the DISTINCT set of materialized partitions instead, so duplicate computes can no - * longer mark the cache loaded early. These tests reproduce the race deterministically: a two-stage - * gate holds the row-producing partition while the empty-output partition's duplicate cross-executor - * completions accumulate. With distinct tracking the cache stays correctly not-loaded while a - * partition is still building, so the consumer observes every row; were the loaded check to fall - * back to a raw task-completion count it would latch the cache as loaded with rowCount 0 and let AQE - * propagate an empty relation, losing rows (which the repro detects as a row-count mismatch). A - * multi-executor `local-cluster` session is required so the duplicate computes land on different - * executors. + * The fix counts the DISTINCT set of materialized partitions instead, so duplicate computes can + * no longer mark the cache loaded early. These tests reproduce the race deterministically: a + * two-stage gate holds the row-producing partition while the empty-output partition's duplicate + * cross-executor completions accumulate. With distinct tracking the cache stays correctly + * not-loaded while a partition is still building, so the consumer observes every row; were the + * loaded check to fall back to a raw task-completion count it would latch the cache as loaded + * with rowCount 0 and let AQE propagate an empty relation, losing rows (which the repro detects + * as a row-count mismatch). A multi-executor `local-cluster` session is required so the duplicate + * computes land on different executors. */ class ConcurrentInMemoryRelationSuite extends SparkFunSuite with LocalSparkContext with Eventually { @@ -93,14 +93,15 @@ class ConcurrentInMemoryRelationSuite extends SparkFunSuite with LocalSparkConte * two distinct executors. * * Sequence: (1) the threads first-touch the cold cache, gating all `numReferences x - * cachePartitions` tasks; (2) release only the empty-output partition, so its two cross-executor - * completions land while the row-producing partition is still gated; (3) poll + * cachePartitions` tasks; (2) release only the empty-output partition, so its two + * cross-executor completions land while the row-producing partition is still gated; (3) poll * `isCachedColumnBuffersLoaded` -- distinct-partition tracking keeps it false (a raw - * task-completion count would instead reach cachePartitions here and latch a poisoned "loaded" - * state with rowCount 0); (4) a consumer query (a GROUPED aggregate, where empty propagation could - * collapse the result) sees the cache not-loaded and plans against the real rows -- had it been - * poisoned, AQE would have propagated an empty relation and dropped rows; (5) release the - * producing partition. Returns (rows the consumer observed, expected rows), equal unless poisoned. + * task-completion count would instead reach cachePartitions here and latch a poisoned + * "loaded" state with rowCount 0); (4) a consumer query (a GROUPED aggregate, where empty + * propagation could collapse the result) sees the cache not-loaded and plans against the + * real rows -- had it been poisoned, AQE would have propagated an empty relation and dropped + * rows; (5) release the producing partition. Returns (rows the consumer observed, expected + * rows), equal unless poisoned. */ private def runDataLossRepro(spark: SparkSession): (Long, Long) = { import spark.implicits._ From e76e02ec14cbbfd2d7bcfa695a4f141607e022d3 Mon Sep 17 00:00:00 2001 From: Ziqi Liu Date: Sun, 21 Jun 2026 23:50:21 -0700 Subject: [PATCH 4/4] Update sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala Co-authored-by: Wenchen Fan --- .../org/apache/spark/sql/util/PartitionKeyedAccumulator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala index 30252a4a51bbb..2081762e39720 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala @@ -35,7 +35,7 @@ import org.apache.spark.util.AccumulatorV2 * `countFailedValues`), so only complete per-partition values are ever merged. * * Backed by a `ConcurrentHashMap`, whose per-entry atomicity is sufficient here: `add` and the - * `putAll` in `merge` are last-write-wins per key, and the reads (`value`, `numPartitions`, + * `putAll` in `merge` are last-write-wins per key, and the reads (`value`, `accumulatedNumPartitions`, * `foldValues`) only require thread-safety and eventual consistency -- they are weakly consistent * during concurrent updates but exact once all updates have been merged. This avoids any explicit * locking (and the nested-lock pattern a two-map `merge` would otherwise need).