diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 9c012dbd58e12..f79742907779f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -36,10 +36,11 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector, WritableColumnVector} import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.PartitionKeyedAccumulator import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{LongAccumulator, Utils} import org.apache.spark.util.ArrayImplicits._ +import org.apache.spark.util.Utils /** * The default implementation of CachedBatch. @@ -261,9 +262,20 @@ case class CachedRDDBuilder( @transient @volatile private var _cachedColumnBuffers: RDD[CachedBatch] = null @transient @volatile private var _cachedColumnBuffersAreLoaded: Boolean = false - val sizeInBytesStats: LongAccumulator = cachedPlan.session.sparkContext.longAccumulator - val rowCountStats: LongAccumulator = cachedPlan.session.sparkContext.longAccumulator - private val materializedPartitions = cachedPlan.session.sparkContext.longAccumulator + // The cache's materialization bookkeeping: a partition-keyed accumulator storing + // (rowCount, sizeInBytes) per partition. AQE creates a separate cache scan stage per reference to + // the same cache and each submits its own build job, so the same partition can be computed by + // several concurrent jobs (and speculative tasks); Spark has no global cross-executor "compute + // this partition once" barrier (only a per-executor write lock). Keying by partition id + // (last-write-wins) means those duplicate completions cannot mark the cache loaded before every + // partition has been computed -- which otherwise let AQE read rowCount 0 on a non-empty cache and + // propagate an empty relation, silently dropping rows -- and also yields exact, de-duplicated row + // count / size. + private val partitionStats: PartitionKeyedAccumulator[(Long, Long)] = { + val acc = new PartitionKeyedAccumulator[(Long, Long)] + cachedPlan.session.sparkContext.register(acc) + acc + } val cachedName = tableName.map(n => s"In-memory table $n") .getOrElse(Utils.abbreviate(cachedPlan.toString, 1024)) @@ -284,6 +296,11 @@ case class CachedRDDBuilder( if (_cachedColumnBuffers != null) { _cachedColumnBuffers.unpersist(blocking) _cachedColumnBuffers = null + // The buffers no longer back a live RDD. Reset the one-way "loaded" latch and the keyed + // bookkeeping so a rebuild on this builder does not inherit a stale "loaded" state or stale + // statistics. Safe to reset in place: every read of the accumulator is under this monitor. + _cachedColumnBuffersAreLoaded = false + partitionStats.reset() } } @@ -296,9 +313,11 @@ case class CachedRDDBuilder( // We must make sure the statistics of `sizeInBytes` and `rowCount` are accurate if // `isCachedRDDLoaded` return true. Otherwise, AQE would do a wrong optimization, // e.g., convert a non-empty plan to empty local relation if `rowCount` is 0. - // Because the statistics is based on accumulator, here we use an extra accumulator to - // track if all partitions are materialized. - val rddLoaded = _cachedColumnBuffers.partitions.length == materializedPartitions.value + // Count DISTINCT materialized partitions (the keyed accumulator's key set), so the cache is + // only reported loaded once every partition has been computed -- sound even if a partition is + // computed more than once by concurrent or speculative tasks. + val numMaterialized = partitionStats.accumulatedNumPartitions + val rddLoaded = _cachedColumnBuffers.partitions.length.toLong == numMaterialized if (rddLoaded) { _cachedColumnBuffersAreLoaded = rddLoaded } @@ -306,6 +325,21 @@ case class CachedRDDBuilder( } } + // Reported row count / size for the cache's statistics: exact and de-duplicated, folded over the + // distinct materialized partitions. Synchronized so a fold never races a concurrent `clearCache` + // reset. + private[sql] def materializedRowCount: Long = synchronized { + partitionStats.foldValues(0L)((sum, v) => sum + v._1) + } + + private[sql] def materializedSizeInBytes: Long = synchronized { + partitionStats.foldValues(0L)((sum, v) => sum + v._2) + } + + // The id of the accumulator backing this cache's materialization bookkeeping. Exposed only so + // `CachedTableSuite`'s accumulator-cleanup test can verify it is cleared after uncache + GC. + private[sql] def materializationAccumulatorId: Long = partitionStats.id + private def buildBuffers(): RDD[CachedBatch] = { val cb = try { if (supportsColumnarInput) { @@ -330,18 +364,29 @@ case class CachedRDDBuilder( session.sharedState.cacheManager.recacheByPlan(session, logicalPlan) throw e } + // Records one successful partition materialization: this partition's (rows, bytes) keyed by its + // id. Bound to a local so the task closure below captures only the accumulator, not the + // enclosing CachedRDDBuilder (whose cachedPlan is not serializable). + val accumulator = partitionStats val cached = cb.mapPartitionsInternal { it => - TaskContext.get().addTaskCompletionListener[Unit] { context => + val taskContext = TaskContext.get() + val partitionId = taskContext.partitionId() + // This task computes exactly one partition. Tally its totals so the completion listener + // records them once, keyed by partition id (covering empty-output partitions, which produce + // no batches). + var localRows = 0L + var localBytes = 0L + taskContext.addTaskCompletionListener[Unit] { context => if (!context.isFailed() && !context.isInterrupted()) { - materializedPartitions.add(1L) + accumulator.add((partitionId, (localRows, localBytes))) } } new Iterator[CachedBatch] { override def hasNext: Boolean = it.hasNext override def next(): CachedBatch = { val batch = it.next() - sizeInBytesStats.add(batch.sizeInBytes) - rowCountStats.add(batch.numRows) + localBytes += batch.sizeInBytes + localRows += batch.numRows batch } } @@ -460,8 +505,8 @@ case class InMemoryRelation( statsOfPlanToCache } else { statsOfPlanToCache.copy( - sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue, - rowCount = Some(cacheBuilder.rowCountStats.value.longValue) + sizeInBytes = cacheBuilder.materializedSizeInBytes, + rowCount = Some(cacheBuilder.materializedRowCount) ) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala new file mode 100644 index 0000000000000..bb8f04a8a5565 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.util.AccumulatorV2 + +/** + * An `AccumulatorV2` that records one value of type `T` per partition, keyed by partition id with + * LAST-WRITE-WINS merge. When the same partition is recorded more than once -- e.g. duplicate + * cross-executor computes, or speculative tasks -- the later value replaces the earlier one rather + * than aggregating, so each partition contributes exactly once. The key set is the set of recorded + * partitions, and callers fold the values (see [[foldValues]]) to derive de-duplicated aggregates; + * a plain summing accumulator would instead over-count under duplicate computes. + * + * `add` is expected to be called once per task (e.g. from a task completion listener) with that + * partition's value, so a partition is recorded even when it produced nothing. Updates from + * failed/interrupted tasks are dropped by the accumulator framework (it is not + * `countFailedValues`), so only complete per-partition values are ever merged. + * + * Backed by a `ConcurrentHashMap`, whose per-entry atomicity is sufficient here: `add` and the + * `putAll` in `merge` are last-write-wins per key, and the reads (`value`, + * `accumulatedNumPartitions`, `foldValues`) only require thread-safety and eventual consistency + * -- they are weakly consistent during concurrent updates but exact once all updates have been + * merged. This avoids any explicit locking (and the nested-lock pattern a two-map `merge` would + * otherwise need). + * + * @tparam T the per-partition value type. Must be non-null (`ConcurrentHashMap` forbids nulls). + */ +class PartitionKeyedAccumulator[T] extends AccumulatorV2[(Int, T), java.util.Map[Int, T]] { + + // partition id -> value. + private val byPartition = new ConcurrentHashMap[Int, T]() + + override def isZero: Boolean = byPartition.isEmpty + + override def copyAndReset(): PartitionKeyedAccumulator[T] = new PartitionKeyedAccumulator[T] + + override def copy(): PartitionKeyedAccumulator[T] = { + val newAcc = new PartitionKeyedAccumulator[T] + newAcc.byPartition.putAll(byPartition) + newAcc + } + + override def reset(): Unit = byPartition.clear() + + override def add(v: (Int, T)): Unit = byPartition.put(v._1, v._2) + + override def merge(other: AccumulatorV2[(Int, T), java.util.Map[Int, T]]): Unit = other match { + case o: PartitionKeyedAccumulator[T] => + // Last-write-wins per partition id: a partition recorded by more than one task replaces + // rather than accumulates, keeping any caller-derived aggregate exact. + byPartition.putAll(o.byPartition) + case _ => throw new UnsupportedOperationException( + s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") + } + + // A read-only VIEW over the live map -- no copy. Only the accumulator framework calls `value` + // (event log / `toInfo` / `toString`); our own code reads via `accumulatedNumPartitions` / + // `foldValues`. The view is thread-safe (ConcurrentHashMap) and weakly consistent, which matches + // this accumulator's eventual-consistency contract. + override def value: java.util.Map[Int, T] = java.util.Collections.unmodifiableMap(byPartition) + + /** Number of distinct partitions that have been recorded. */ + def accumulatedNumPartitions: Long = byPartition.size().toLong + + /** Folds the per-partition values (each partition counted once) into a single aggregate. */ + def foldValues[A](zero: A)(op: (A, T) => A): A = { + var result = zero + val it = byPartition.values().iterator() + while (it.hasNext) result = op(result, it.next()) + result + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 106ee36594b38..085dbcd804665 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -474,12 +474,12 @@ class CachedTableSuite extends SharedSparkSession val toBeCleanedAccIds = new HashSet[Long] val accId1 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id + case i: InMemoryRelation => i.cacheBuilder.materializationAccumulatorId }.head toBeCleanedAccIds += accId1 val accId2 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id + case i: InMemoryRelation => i.cacheBuilder.materializationAccumulatorId }.head toBeCleanedAccIds += accId2 @@ -509,6 +509,37 @@ class CachedTableSuite extends SharedSparkSession } } + test("SPARK-57547: clearCache resets materialization bookkeeping") { + val df = spark.range(0, 100, 1, numPartitions = 4).filter($"id" >= 0) + df.cache() + try { + val cacheRelations = df.queryExecution.withCachedData.collect { + case i: InMemoryRelation => i + } + assert(cacheRelations.length == 1) + val builder = cacheRelations.head.cacheBuilder + // Force the cache build directly (a plain df action can be served from the query-result + // cache and skip the rebuild after clearCache). + builder.cachedColumnBuffers.count() + assert(builder.isCachedColumnBuffersLoaded) + assert(builder.materializedRowCount == 100L) + + builder.clearCache() + // The loaded latch and the materialization stats must not survive clearCache, otherwise a + // rebuilt cache would inherit a stale "loaded" state with stale/zero statistics. + assert(!builder.isCachedColumnBuffersLoaded) + assert(builder.materializedRowCount == 0L) + assert(builder.materializedSizeInBytes == 0L) + + // Rebuilding works and reports correct stats again. + builder.cachedColumnBuffers.count() + assert(builder.isCachedColumnBuffersLoaded) + assert(builder.materializedRowCount == 100L) + } finally { + df.unpersist(blocking = true) + } + } + test("SPARK-10327 Cache Table is not working while subquery has alias in its project list") { withTempView("abc") { sparkContext.parallelize((1, 1) :: (2, 2) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ConcurrentInMemoryRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ConcurrentInMemoryRelationSuite.scala new file mode 100644 index 0000000000000..161f8ec647a22 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ConcurrentInMemoryRelationSuite.scala @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.io.File +import java.util.concurrent.CountDownLatch + +import scala.concurrent.duration._ + +import org.scalatest.concurrent.Eventually +import org.scalatest.time.{Millis, Seconds, Span} + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.columnar.CachedBatch +import org.apache.spark.sql.functions.{lit, when} +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * Regression test for SPARK-57547: concurrent first-touch of one cold table cache must not let + * duplicate partition computes silently drop rows. + * + * AQE creates a separate `TableCacheQueryStageExec` for every reference to the same cache (table + * cache stages are never reused), and each one submits its own build job over the shared cache RDD. + * A query that references a cached relation several times therefore first-touches the cold cache + * from several jobs at once. Spark has no global cross-executor "compute this partition once" + * barrier, so the same partition can be computed by multiple executors. If the cache decided it was + * "loaded" from a raw task-completion count (the legacy behavior), those duplicate completions + * could push the count to the partition count while a row-producing partition was still being + * computed, falsely marking the cache loaded with rowCount 0 -- which lets AQE propagate an empty + * relation and silently lose rows. + * + * The fix counts the DISTINCT set of materialized partitions instead, so duplicate computes can + * no longer mark the cache loaded early. These tests reproduce the race deterministically: a + * two-stage gate holds the row-producing partition while the empty-output partition's duplicate + * cross-executor completions accumulate. With distinct tracking the cache stays correctly + * not-loaded while a partition is still building, so the consumer observes every row; were the + * loaded check to fall back to a raw task-completion count it would latch the cache as loaded + * with rowCount 0 and let AQE propagate an empty relation, losing rows (which the repro detects + * as a row-count mismatch). A multi-executor `local-cluster` session is required so the duplicate + * computes land on different executors. + */ +class ConcurrentInMemoryRelationSuite extends SparkFunSuite with LocalSparkContext with Eventually { + + private def cacheBuilderOf(ds: Dataset[_]): CachedRDDBuilder = { + val relations = ds.queryExecution.withCachedData.collect { case i: InMemoryRelation => i } + assert(relations.length == 1) + relations.head.cacheBuilder + } + + private def withSession(numExecutors: Int = 4)( + f: SparkSession => Unit): Unit = { + val conf = new SparkConf() + .setMaster(s"local-cluster[$numExecutors,1,1024]") + .setAppName("ConcurrentInMemoryRelationSuite") + sc = new SparkContext(conf) + try { + // Wait for all executors to register so tasks spread one-per-executor as the tests assume. + eventually(timeout(Span(60, Seconds)), interval(Span(200, Millis))) { + assert(sc.getExecutorIds().size == numExecutors) + } + f(SparkSession.builder().sparkContext(sc).getOrCreate()) + } finally { + resetSparkContext() + } + } + + /** + * Drives the actual SPARK-57547 data loss deterministically. + * + * Caches a skewed join with two shuffle partitions: every partition has non-empty INPUT (so + * neither is pruned as an empty task), but only the `skewKey` bucket produces OUTPUT rows -- so + * one partition is row-producing and the other produces zero rows. A two-stage gate blocks every + * partition's build inside `mapPartitions` until released. `numReferences` threads each submit + * their own build job over the shared cache RDD (exactly as per-reference + * `TableCacheQueryStageExec`s do); on `local-cluster[4,1,...]` (= numReferences x cachePartitions + * task slots, one task per executor) the empty-output partition is computed by both references on + * two distinct executors. + * + * Sequence: (1) the threads first-touch the cold cache, gating all `numReferences x + * cachePartitions` tasks; (2) release only the empty-output partition, so its two + * cross-executor completions land while the row-producing partition is still gated; (3) poll + * `isCachedColumnBuffersLoaded` -- distinct-partition tracking keeps it false (a raw + * task-completion count would instead reach cachePartitions here and latch a poisoned + * "loaded" state with rowCount 0); (4) a consumer query (a GROUPED aggregate, where empty + * propagation could collapse the result) sees the cache not-loaded and plans against the + * real rows -- had it been poisoned, AQE would have propagated an empty relation and dropped + * rows; (5) release the producing partition. Returns (rows the consumer observed, expected + * rows), equal unless poisoned. + */ + private def runDataLossRepro(spark: SparkSession): (Long, Long) = { + import spark.implicits._ + val numKeys = 64 + val skewKey = 42 + val rowsPerKey = 500 + val numReferences = 2 + val cachePartitions = 2 // one row-producing, one empty-output (see shuffle.partitions below) + val expected = rowsPerKey.toLong * rowsPerKey.toLong // only the skewKey bucket joins + + // Exactly two shuffle partitions (one row-producing, one empty-output), no broadcast so the + // join shuffles, and build the cache with AQE off so the skewed producing partition is not + // rebalanced away (which would defuse the window). Consumers below run with AQE on so they go + // through TableCacheQueryStageExec + empty-relation propagation. + spark.conf.set("spark.sql.shuffle.partitions", "2") + spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1") + spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "false") + spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "false") + spark.conf.set("spark.sql.adaptive.enabled", "false") + + val gateDir = Utils.createTempDir() + def file(name: String) = new File(gateDir, name) + val releaseEmpty = file("releaseEmpty").getAbsolutePath + val releaseProducing = file("releaseProducing").getAbsolutePath + val entryDir = gateDir.getAbsolutePath + + def side(matchSalt: Int, valueCol: String): DataFrame = + spark.range(0, numKeys.toLong * rowsPerKey).select( + ($"id" % numKeys).cast("int").as("k"), + when(($"id" % numKeys) === skewKey, lit(0)).otherwise(lit(matchSalt)).as("salt"), + $"id".as(valueCol)) + + val joined = side(1, "lv") + .join(side(2, "rv"), Seq("k", "salt")) + .select($"k", $"lv").as[(Int, Long)] + + // Two-stage gate: every partition signals it has entered (past the block-existence check) and + // waits for releaseEmpty; the row-producing partition (the only one with rows) waits longer. + val gated = joined.mapPartitions { iter => + val buffered = iter.buffered + val isProducing = buffered.hasNext + file(s"entered-${java.util.UUID.randomUUID()}").createNewFile() + def waitFor(path: String): Unit = { + val deadline = System.currentTimeMillis() + 60000 + while (!new File(path).exists() && System.currentTimeMillis() < deadline) Thread.sleep(50) + } + waitFor(releaseEmpty) + if (isProducing) waitFor(releaseProducing) + buffered + }.toDF("k", "lv") + + val cached = gated.cache() + try { + val builder = cacheBuilderOf(cached) + // Cache plan captured (static, 2 partitions); consumers from here on use AQE. + spark.conf.set("spark.sql.adaptive.enabled", "true") + // Every reference launches its own build job over the shared cache RDD (no dedup at this + // layer), so the empty partition is computed by every reference: numReferences x + // cachePartitions gated tasks. + val expectedEntries = numReferences * cachePartitions + val rdd = builder.cachedColumnBuffers + val submitted = new CountDownLatch(numReferences) + val pool = ThreadUtils.newDaemonFixedThreadPool(numReferences, "spark57547-dataloss") + try { + val firstTouch = (1 to numReferences).map { _ => + pool.submit(new java.util.concurrent.Callable[Unit] { + override def call(): Unit = { + val f = spark.sparkContext.submitJob( + rdd, + (_: Iterator[CachedBatch]) => (), + 0 until rdd.getNumPartitions, + (_: Int, _: Unit) => (), + ()) + submitted.countDown() + ThreadUtils.awaitResult(f, 120.seconds) + } + }) + } + assert(submitted.await(60, java.util.concurrent.TimeUnit.SECONDS)) + // Wait until every build task is parked at the gate (all have passed the block-existence + // check), so releasing the empty partition forces its cross-executor completions to run. + eventually(timeout(Span(60, Seconds)), interval(Span(100, Millis))) { + val entered = new File(entryDir).listFiles().count(_.getName.startsWith("entered-")) + assert(entered == expectedEntries, s"entered=$entered expected=$expectedEntries") + } + + // Stage 1: release ONLY the empty-output partition; the producing partition stays gated. + assert(new File(releaseEmpty).createNewFile()) + + // Were the loaded check to fall back to a raw task-completion count, the empty partition's + // duplicate cross-executor completions would push that count to cachePartitions even though + // the producing partition has not run, latching the cache as "loaded" with rowCount 0. We + // read it through the relation handle -- exactly what AQE's stats reads do in production -- + // and the one-way latch would make the poison permanent (the producing partition is still + // gated when the consumer runs below). With distinct-partition accounting the cache stays + // not loaded here, so this poll times out and we fall through to a normal (complete) build. + val poisoned = + try { + eventually(timeout(Span(30, Seconds)), interval(Span(100, Millis))) { + assert(builder.isCachedColumnBuffersLoaded) + } + true + } catch { + case _: org.scalatest.exceptions.TestFailedException => false + } + + // A GROUPED aggregate (not a global count): AQE empty-relation propagation collapses the + // whole result when the cache stage is (falsely) reported as a zero-row materialized stage; + // a global aggregate over empty would still emit one row and mask the loss. + val observed = if (poisoned) { + // The cache lied (loaded with rowCount 0 while the producing partition is still gated and + // unbuilt). The consumer plans against it and AQE propagates an empty relation, so the + // rows silently vanish. The producing partition stays gated, so this is deterministic. + val consumer = cached.groupBy("k").count() + val rows = consumer.collect() + assert(consumer.queryExecution.executedPlan.toString.contains("EmptyRelation"), + "expected AQE to propagate an empty relation from the poisoned cache stage") + assert(new File(releaseProducing).createNewFile()) // unblock the build for clean shutdown + rows.map(_.getLong(1)).sum + } else { + // The cache is correctly not loaded, so let the producing partition finish and the + // consumer observes every row. + assert(new File(releaseProducing).createNewFile()) + cached.groupBy("k").count().collect().map(_.getLong(1)).sum + } + firstTouch.foreach(_.get(120, java.util.concurrent.TimeUnit.SECONDS)) + (observed, expected) + } finally { + pool.shutdown() + } + } finally { + cached.unpersist(blocking = true) + Utils.deleteRecursively(gateDir) + } + } + + /** + * Builds a cold cache whose partitions all carry rows and first-touches it concurrently from + * `numReferences` jobs with every partition gated, so each partition is computed once per + * reference on a distinct executor (`numReferences` duplicate cross-executor computes per + * partition). Returns (reported materialized row count, expected rows); with distinct-partition + * tracking on, the keyed accumulator de-duplicates the duplicate computes so the count is exact. + */ + private def runDuplicateComputeStats(spark: SparkSession): (Long, Long) = { + import spark.implicits._ + val numReferences = 2 + val cachePartitions = 2 + val numRows = 200L // split evenly across the partitions; every partition is non-empty + + val gateDir = Utils.createTempDir() + def file(name: String) = new File(gateDir, name) + val release = file("release").getAbsolutePath + val entryDir = gateDir.getAbsolutePath + + // Every partition has rows and blocks at the gate until released, so all references' build + // tasks are in flight (past the block-existence check) before any completes -- forcing the + // duplicate cross-executor computes that the per-batch accumulator would over-count. + val cached = spark.range(0, numRows, 1, cachePartitions).as[Long].mapPartitions { iter => + file(s"entered-${java.util.UUID.randomUUID()}").createNewFile() + val deadline = System.currentTimeMillis() + 60000 + while (!new File(release).exists() && System.currentTimeMillis() < deadline) Thread.sleep(50) + iter + }.cache() + try { + val builder = cacheBuilderOf(cached) + val rdd = builder.cachedColumnBuffers + val submitted = new CountDownLatch(numReferences) + val pool = ThreadUtils.newDaemonFixedThreadPool(numReferences, "spark57547-stats") + try { + val futures = (1 to numReferences).map { _ => + pool.submit(new java.util.concurrent.Callable[Unit] { + override def call(): Unit = { + val f = spark.sparkContext.submitJob( + rdd, + (_: Iterator[CachedBatch]) => (), + 0 until rdd.getNumPartitions, + (_: Int, _: Unit) => (), + ()) + submitted.countDown() + ThreadUtils.awaitResult(f, 120.seconds) + } + }) + } + assert(submitted.await(60, java.util.concurrent.TimeUnit.SECONDS)) + // Wait until every reference's task for every partition is parked at the gate, then release + // them so each partition is computed once per reference. + eventually(timeout(Span(60, Seconds)), interval(Span(100, Millis))) { + val entered = new File(entryDir).listFiles().count(_.getName.startsWith("entered-")) + assert(entered == numReferences * cachePartitions, s"entered=$entered") + } + assert(new File(release).createNewFile()) + futures.foreach(_.get(120, java.util.concurrent.TimeUnit.SECONDS)) + assert(builder.isCachedColumnBuffersLoaded) + (builder.materializedRowCount, numRows) + } finally { + pool.shutdown() + } + } finally { + cached.unpersist(blocking = true) + Utils.deleteRecursively(gateDir) + } + } + + test("SPARK-57547: concurrent first-touch of a cold cache does not lose rows") { + withSession() { spark => + val (observed, expected) = runDataLossRepro(spark) + assert(observed == expected, s"consumer observed $observed rows, expected $expected") + } + } + + test("SPARK-57547: cache statistics are exact under duplicate cross-executor computes") { + // Every partition is computed by both references, so the partition-keyed accumulator sees a + // duplicate `add` per partition. Last-write-wins de-duplication keeps the reported row count + // exact -- a naive summing accumulator would over-count under these duplicate computes. + withSession() { spark => + val (rowCount, expected) = runDuplicateComputeStats(spark) + assert(rowCount == expected, + s"partition-keyed accumulator should report exact row count $expected, got $rowCount") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 5cd62302861ae..57da12e87979a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -361,7 +361,7 @@ class InMemoryColumnarQuerySuite extends SharedSparkSession with AdaptiveSparkPl checkAnswer(cached, expectedAnswer) // Check that the right size was calculated. - assert(cached.cacheBuilder.sizeInBytesStats.value === expectedAnswer.length * INT.defaultSize) + assert(cached.cacheBuilder.materializedSizeInBytes === expectedAnswer.length * INT.defaultSize) } test("cached row count should be calculated") { @@ -375,7 +375,7 @@ class InMemoryColumnarQuerySuite extends SharedSparkSession with AdaptiveSparkPl checkAnswer(cached, expectedAnswer) // Check that the right row count was calculated. - assert(cached.cacheBuilder.rowCountStats.value === 6) + assert(cached.cacheBuilder.materializedRowCount === 6) } test("access primitive-type columns in CachedBatch without whole stage codegen") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/PartitionKeyedAccumulatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/PartitionKeyedAccumulatorSuite.scala new file mode 100644 index 0000000000000..19e499942e310 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/PartitionKeyedAccumulatorSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import org.apache.spark.SparkFunSuite + +class PartitionKeyedAccumulatorSuite extends SparkFunSuite { + + // The cache use case records (rowCount, sizeInBytes) per partition. + private type Stats = (Long, Long) + + private def sumRows(acc: PartitionKeyedAccumulator[Stats]): Long = + acc.foldValues(0L)((sum, v) => sum + v._1) + + private def sumBytes(acc: PartitionKeyedAccumulator[Stats]): Long = + acc.foldValues(0L)((sum, v) => sum + v._2) + + test("isZero, add, value and accumulatedNumPartitions") { + val acc = new PartitionKeyedAccumulator[Stats] + assert(acc.isZero) + assert(acc.accumulatedNumPartitions == 0) + assert(acc.value.isEmpty) + + acc.add((0, (10L, 100L))) + assert(!acc.isZero) + assert(acc.accumulatedNumPartitions == 1) + assert(acc.value.get(0) == ((10L, 100L))) + + acc.add((1, (5L, 50L))) + assert(acc.accumulatedNumPartitions == 2) + assert(sumRows(acc) == 15L) + assert(sumBytes(acc) == 150L) + } + + test("add is last-write-wins for the same partition id") { + val acc = new PartitionKeyedAccumulator[Stats] + acc.add((0, (1L, 1L))) + acc.add((0, (2L, 2L))) // re-records partition 0 (e.g. a recompute) + assert(acc.accumulatedNumPartitions == 1) + assert(sumRows(acc) == 2L) // the later value wins, not 1 + 2 + assert(sumBytes(acc) == 2L) + } + + test("merge is last-write-wins per partition id (de-duplicates, does not sum)") { + // Two references compute the same partitions; partition 0 is computed by both. + val a = new PartitionKeyedAccumulator[Stats] + a.add((0, (10L, 100L))) + + val b = new PartitionKeyedAccumulator[Stats] + b.add((0, (10L, 100L))) // duplicate compute of partition 0 + b.add((1, (5L, 50L))) + + a.merge(b) + assert(a.accumulatedNumPartitions == 2) // partitions {0, 1}, not 3 + assert(sumRows(a) == 15L) // 10 (partition 0, counted once) + 5, NOT 25 + assert(sumBytes(a) == 150L) + } + + test("copy is an independent snapshot") { + val acc = new PartitionKeyedAccumulator[Stats] + acc.add((0, (10L, 100L))) + val snapshot = acc.copy() + acc.add((1, (5L, 50L))) // mutate the original after copying + + assert(snapshot.accumulatedNumPartitions == 1) + assert(sumRows(snapshot) == 10L) + assert(acc.accumulatedNumPartitions == 2) + assert(sumRows(acc) == 15L) + } + + test("reset and copyAndReset") { + val acc = new PartitionKeyedAccumulator[Stats] + acc.add((0, (10L, 100L))) + assert(!acc.isZero) + + assert(acc.copyAndReset().isZero) + assert(!acc.isZero) // copyAndReset does not mutate the source + + acc.reset() + assert(acc.isZero) + assert(acc.accumulatedNumPartitions == 0) + } + + test("works for an arbitrary value type") { + val acc = new PartitionKeyedAccumulator[String] + acc.add((0, "a")) + acc.add((1, "b")) + acc.add((0, "c")) // last-write-wins + assert(acc.accumulatedNumPartitions == 2) + assert(acc.foldValues("")((s, v) => s + v).length == 2) // "c" + "b" (each partition once) + } +}