Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector, WritableColumnVector}
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.PartitionKeyedAccumulator
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{LongAccumulator, Utils}
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.Utils

/**
* The default implementation of CachedBatch.
Expand Down Expand Up @@ -261,9 +262,20 @@ case class CachedRDDBuilder(
@transient @volatile private var _cachedColumnBuffers: RDD[CachedBatch] = null
@transient @volatile private var _cachedColumnBuffersAreLoaded: Boolean = false

val sizeInBytesStats: LongAccumulator = cachedPlan.session.sparkContext.longAccumulator
val rowCountStats: LongAccumulator = cachedPlan.session.sparkContext.longAccumulator
private val materializedPartitions = cachedPlan.session.sparkContext.longAccumulator
// The cache's materialization bookkeeping: a partition-keyed accumulator storing
// (rowCount, sizeInBytes) per partition. AQE creates a separate cache scan stage per reference to
// the same cache and each submits its own build job, so the same partition can be computed by
// several concurrent jobs (and speculative tasks); Spark has no global cross-executor "compute
// this partition once" barrier (only a per-executor write lock). Keying by partition id
// (last-write-wins) means those duplicate completions cannot mark the cache loaded before every
// partition has been computed -- which otherwise let AQE read rowCount 0 on a non-empty cache and
// propagate an empty relation, silently dropping rows -- and also yields exact, de-duplicated row
// count / size.
private val partitionStats: PartitionKeyedAccumulator[(Long, Long)] = {
val acc = new PartitionKeyedAccumulator[(Long, Long)]
cachedPlan.session.sparkContext.register(acc)
acc
}

val cachedName = tableName.map(n => s"In-memory table $n")
.getOrElse(Utils.abbreviate(cachedPlan.toString, 1024))
Expand All @@ -284,6 +296,11 @@ case class CachedRDDBuilder(
if (_cachedColumnBuffers != null) {
_cachedColumnBuffers.unpersist(blocking)
_cachedColumnBuffers = null
// The buffers no longer back a live RDD. Reset the one-way "loaded" latch and the keyed
// bookkeeping so a rebuild on this builder does not inherit a stale "loaded" state or stale
// statistics. Safe to reset in place: every read of the accumulator is under this monitor.
_cachedColumnBuffersAreLoaded = false
partitionStats.reset()
}
}

Expand All @@ -296,16 +313,33 @@ case class CachedRDDBuilder(
// We must make sure the statistics of `sizeInBytes` and `rowCount` are accurate if
// `isCachedRDDLoaded` return true. Otherwise, AQE would do a wrong optimization,
// e.g., convert a non-empty plan to empty local relation if `rowCount` is 0.
// Because the statistics is based on accumulator, here we use an extra accumulator to
// track if all partitions are materialized.
val rddLoaded = _cachedColumnBuffers.partitions.length == materializedPartitions.value
// Count DISTINCT materialized partitions (the keyed accumulator's key set), so the cache is
// only reported loaded once every partition has been computed -- sound even if a partition is
// computed more than once by concurrent or speculative tasks.
val numMaterialized = partitionStats.accumulatedNumPartitions
val rddLoaded = _cachedColumnBuffers.partitions.length.toLong == numMaterialized
if (rddLoaded) {
_cachedColumnBuffersAreLoaded = rddLoaded
}
rddLoaded
}
}

// Reported row count / size for the cache's statistics: exact and de-duplicated, folded over the
// distinct materialized partitions. Synchronized so a fold never races a concurrent `clearCache`
// reset.
private[sql] def materializedRowCount: Long = synchronized {
partitionStats.foldValues(0L)((sum, v) => sum + v._1)
}

private[sql] def materializedSizeInBytes: Long = synchronized {
partitionStats.foldValues(0L)((sum, v) => sum + v._2)
}

// The id of the accumulator backing this cache's materialization bookkeeping. Exposed only so
// `CachedTableSuite`'s accumulator-cleanup test can verify it is cleared after uncache + GC.
private[sql] def materializationAccumulatorId: Long = partitionStats.id

private def buildBuffers(): RDD[CachedBatch] = {
val cb = try {
if (supportsColumnarInput) {
Expand All @@ -330,18 +364,29 @@ case class CachedRDDBuilder(
session.sharedState.cacheManager.recacheByPlan(session, logicalPlan)
throw e
}
// Records one successful partition materialization: this partition's (rows, bytes) keyed by its
// id. Bound to a local so the task closure below captures only the accumulator, not the
// enclosing CachedRDDBuilder (whose cachedPlan is not serializable).
val accumulator = partitionStats
val cached = cb.mapPartitionsInternal { it =>
TaskContext.get().addTaskCompletionListener[Unit] { context =>
val taskContext = TaskContext.get()
val partitionId = taskContext.partitionId()
// This task computes exactly one partition. Tally its totals so the completion listener
// records them once, keyed by partition id (covering empty-output partitions, which produce
// no batches).
var localRows = 0L
var localBytes = 0L
taskContext.addTaskCompletionListener[Unit] { context =>
if (!context.isFailed() && !context.isInterrupted()) {
materializedPartitions.add(1L)
accumulator.add((partitionId, (localRows, localBytes)))
}
}
new Iterator[CachedBatch] {
override def hasNext: Boolean = it.hasNext
override def next(): CachedBatch = {
val batch = it.next()
sizeInBytesStats.add(batch.sizeInBytes)
rowCountStats.add(batch.numRows)
localBytes += batch.sizeInBytes
localRows += batch.numRows
batch
}
}
Expand Down Expand Up @@ -460,8 +505,8 @@ case class InMemoryRelation(
statsOfPlanToCache
} else {
statsOfPlanToCache.copy(
sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue,
rowCount = Some(cacheBuilder.rowCountStats.value.longValue)
sizeInBytes = cacheBuilder.materializedSizeInBytes,
rowCount = Some(cacheBuilder.materializedRowCount)
)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.util

import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.util.AccumulatorV2

/**
* An `AccumulatorV2` that records one value of type `T` per partition, keyed by partition id with
* LAST-WRITE-WINS merge. When the same partition is recorded more than once -- e.g. duplicate
* cross-executor computes, or speculative tasks -- the later value replaces the earlier one rather
* than aggregating, so each partition contributes exactly once. The key set is the set of recorded
* partitions, and callers fold the values (see [[foldValues]]) to derive de-duplicated aggregates;
* a plain summing accumulator would instead over-count under duplicate computes.
*
* `add` is expected to be called once per task (e.g. from a task completion listener) with that
* partition's value, so a partition is recorded even when it produced nothing. Updates from
* failed/interrupted tasks are dropped by the accumulator framework (it is not
* `countFailedValues`), so only complete per-partition values are ever merged.
*
* Backed by a `ConcurrentHashMap`, whose per-entry atomicity is sufficient here: `add` and the
* `putAll` in `merge` are last-write-wins per key, and the reads (`value`, `accumulatedNumPartitions`,
* `foldValues`) only require thread-safety and eventual consistency -- they are weakly consistent
* during concurrent updates but exact once all updates have been merged. This avoids any explicit
* locking (and the nested-lock pattern a two-map `merge` would otherwise need).
*
* @tparam T the per-partition value type. Must be non-null (`ConcurrentHashMap` forbids nulls).
*/
class PartitionKeyedAccumulator[T] extends AccumulatorV2[(Int, T), java.util.Map[Int, T]] {

// partition id -> value.
private val byPartition = new ConcurrentHashMap[Int, T]()

override def isZero: Boolean = byPartition.isEmpty

override def copyAndReset(): PartitionKeyedAccumulator[T] = new PartitionKeyedAccumulator[T]

override def copy(): PartitionKeyedAccumulator[T] = {
val newAcc = new PartitionKeyedAccumulator[T]
newAcc.byPartition.putAll(byPartition)
newAcc
}

override def reset(): Unit = byPartition.clear()

override def add(v: (Int, T)): Unit = byPartition.put(v._1, v._2)

override def merge(other: AccumulatorV2[(Int, T), java.util.Map[Int, T]]): Unit = other match {
case o: PartitionKeyedAccumulator[T] =>
// Last-write-wins per partition id: a partition recorded by more than one task replaces
// rather than accumulates, keeping any caller-derived aggregate exact.
byPartition.putAll(o.byPartition)
case _ => throw new UnsupportedOperationException(
s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}

// A read-only VIEW over the live map -- no copy. Only the accumulator framework calls `value`
// (event log / `toInfo` / `toString`); our own code reads via `accumulatedNumPartitions` /
// `foldValues`. The view is thread-safe (ConcurrentHashMap) and weakly consistent, which matches
// this accumulator's eventual-consistency contract.
override def value: java.util.Map[Int, T] = java.util.Collections.unmodifiableMap(byPartition)

/** Number of distinct partitions that have been recorded. */
def accumulatedNumPartitions: Long = byPartition.size().toLong

/** Folds the per-partition values (each partition counted once) into a single aggregate. */
def foldValues[A](zero: A)(op: (A, T) => A): A = {
var result = zero
val it = byPartition.values().iterator()
while (it.hasNext) result = op(result, it.next())
result
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -474,12 +474,12 @@ class CachedTableSuite extends SharedSparkSession
val toBeCleanedAccIds = new HashSet[Long]

val accId1 = spark.table("t1").queryExecution.withCachedData.collect {
case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id
case i: InMemoryRelation => i.cacheBuilder.materializationAccumulatorId
}.head
toBeCleanedAccIds += accId1

val accId2 = spark.table("t1").queryExecution.withCachedData.collect {
case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id
case i: InMemoryRelation => i.cacheBuilder.materializationAccumulatorId
}.head
toBeCleanedAccIds += accId2

Expand Down Expand Up @@ -509,6 +509,37 @@ class CachedTableSuite extends SharedSparkSession
}
}

test("SPARK-57547: clearCache resets materialization bookkeeping") {
val df = spark.range(0, 100, 1, numPartitions = 4).filter($"id" >= 0)
df.cache()
try {
val cacheRelations = df.queryExecution.withCachedData.collect {
case i: InMemoryRelation => i
}
assert(cacheRelations.length == 1)
val builder = cacheRelations.head.cacheBuilder
// Force the cache build directly (a plain df action can be served from the query-result
// cache and skip the rebuild after clearCache).
builder.cachedColumnBuffers.count()
assert(builder.isCachedColumnBuffersLoaded)
assert(builder.materializedRowCount == 100L)

builder.clearCache()
// The loaded latch and the materialization stats must not survive clearCache, otherwise a
// rebuilt cache would inherit a stale "loaded" state with stale/zero statistics.
assert(!builder.isCachedColumnBuffersLoaded)
assert(builder.materializedRowCount == 0L)
assert(builder.materializedSizeInBytes == 0L)

// Rebuilding works and reports correct stats again.
builder.cachedColumnBuffers.count()
assert(builder.isCachedColumnBuffersLoaded)
assert(builder.materializedRowCount == 100L)
} finally {
df.unpersist(blocking = true)
}
}

test("SPARK-10327 Cache Table is not working while subquery has alias in its project list") {
withTempView("abc") {
sparkContext.parallelize((1, 1) :: (2, 2) :: Nil)
Expand Down
Loading