Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkFiles.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package org.apache.spark

import java.io.File

import org.apache.spark.util.Utils

/**
* Resolves paths to files added through `SparkContext.addFile()`.
*/
Expand All @@ -31,7 +33,20 @@ object SparkFiles {
val jobArtifactUUID = JobArtifactSet
.getCurrentJobArtifactState.map(_.uuid).getOrElse("default")
val withUuid = if (jobArtifactUUID == "default") filename else s"$jobArtifactUUID/$filename"
new File(getRootDirectory(), withUuid).getAbsolutePath
val file = new File(getRootDirectory(), withUuid)
// In local mode, `SparkContext.addFile` places files directly under the root directory
// rather than under the job-specific artifact directory used by session isolation. Fall back
// to the root directory when the file is not found under the job-specific directory so that
// files added through `SparkContext.addFile` remain resolvable. This is scoped to local mode
// to preserve session isolation semantics on real executors. See SPARK-53478.
if (jobArtifactUUID != "default" && !file.exists() &&
Utils.isLocalMaster(SparkEnv.get.conf)) {
val fallbackFile = new File(getRootDirectory(), filename)
if (fallbackFile.exists()) {
return fallbackFile.getAbsolutePath
}
}
file.getAbsolutePath
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
envVars.put("PYTHON_UDF_BATCH_SIZE", batchSizeForPythonUDF.toString)

envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default"))
// Tells the Python worker whether Spark is running in local mode, so that the worker-side
// `SparkFiles.get` fallback to the root directory is applied only in local mode, mirroring
// `org.apache.spark.SparkFiles.get`. See SPARK-53478.
envVars.put("SPARK_LOCAL_MODE", Utils.isLocalMaster(conf).toString)
envVars.put("SPARK_PYTHON_RUNTIME", "PYTHON_WORKER")

val (worker: PythonWorker, handle: Option[ProcessHandle]) = env.createPythonWorker(
Expand Down
17 changes: 17 additions & 0 deletions python/pyspark/core/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,23 @@ def get(cls, filename: str) -> str:
{'.../test.py'}
"""
path = os.path.join(SparkFiles.getRootDirectory(), filename)
# In local mode, `SparkContext.addFile` places files directly under the root directory
# rather than under the job-specific artifact directory used by session isolation. When a
# non-default job artifact UUID is active, the worker root directory is the job-specific
# directory, so fall back to its parent (the actual root directory) so that files added
# through `SparkContext.addFile` remain resolvable. This is scoped to local mode and a
# non-default UUID to preserve session isolation semantics on real executors. See
# SPARK-53478.
if (
cls._is_running_on_worker
and not os.path.exists(path)
and os.environ.get("SPARK_LOCAL_MODE", "false") == "true"
and os.environ.get("SPARK_JOB_ARTIFACT_UUID", "default") != "default"
):
parent_dir = os.path.dirname(SparkFiles.getRootDirectory())
parent_path = os.path.join(parent_dir, filename)
if os.path.exists(parent_path):
return os.path.abspath(parent_path)
return os.path.abspath(path)

@classmethod
Expand Down
23 changes: 23 additions & 0 deletions python/pyspark/sql/tests/test_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,29 @@ def test_add_file(self):
# file from different session.
self.check_add_file(self.spark.newSession())

def test_spark_files_get_with_sc_add_file(self):
    # SPARK-53478: a file registered with SparkContext.addFile must be
    # resolvable through SparkFiles.get from inside a UDF in local mode,
    # including when the query runs in a session other than the one that
    # owns the SparkContext handle used for registration.
    from pyspark.core.files import SparkFiles

    added_name = "my_sc_file.txt"
    expected_text = "Hello from SparkContext.addFile"

    @udf("string")
    def read_added_file(x):
        # Resolve the file by its bare name and return its trimmed contents.
        with open(SparkFiles.get(added_name), "r") as handle:
            return handle.read().strip()

    with tempfile.TemporaryDirectory(prefix="test_spark_files_get_with_sc_add_file") as tmp:
        source_path = os.path.join(tmp, added_name)
        with open(source_path, "w") as out:
            out.write(expected_text)

        self.spark.sparkContext.addFile(source_path)

        # Run the UDF from a freshly created session; assert_true fails the
        # query if the contents read by the UDF do not match.
        other_session = self.spark.newSession()
        matches = read_added_file("id") == lit(expected_text)
        other_session.range(1).select(assert_true(matches)).show()

def test_add_archive(self):
self.check_add_archive(self.spark)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,25 @@ class ArtifactManagerSuite extends SharedSparkSession {
}
}

test("SPARK-53478: SparkFiles.get resolves files added via SparkContext.addFile " +
  "in local mode") {
  // Name under which the file is registered and later looked up.
  val addedFileName = "test_file.txt"
  // Text written on the driver side and expected back from the map function.
  val expectedContents = "Hello from SparkContext.addFile"

  withTempDir { tempDir =>
    val source = new File(tempDir, addedFileName)
    Files.writeString(source.toPath, expectedContents, StandardCharsets.UTF_8)
    // Register the file through SparkContext.addFile.
    spark.sparkContext.addFile(source.getAbsolutePath)

    val session = spark
    import session.implicits._
    // Resolve and read the file inside the Dataset map function, so that
    // SparkFiles.get is evaluated as part of the job rather than in the test body.
    val singleRow = Seq(1).toDF("value")
    val contentsPerRow = singleRow.map { _ =>
      val resolved = org.apache.spark.SparkFiles.get(addedFileName)
      Files.readString(new File(resolved).toPath, StandardCharsets.UTF_8)
    }.collect()

    assert(contentsPerRow.head === expectedContents)
  }
}

private def getCodegenCount: Long = CodegenMetrics.METRIC_COMPILATION_TIME.getCount

private def runCodegenTest(msg: String)(addOneArtifact: => Unit): Unit = {
Expand Down