From d303bb367d93eeda01c6d5876c38c6afd2be0363 Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Sun, 7 Jun 2026 06:49:10 -0700 Subject: [PATCH] Add default common unshim packaging flow --- build/buildall | 241 +++++++- dist/README.md | 10 +- dist/build/package-parallel-worlds.py | 42 +- dist/keep-in-spark-shared.txt | 6 + dist/maven-antrun/build-parallel-worlds.xml | 9 +- dist/scripts/analyze-parallel-world-deps.py | 617 ++++++++++++++++++++ dist/scripts/binary-dedupe.sh | 311 ++++++++-- dist/scripts/build-unshim-parallel-world.py | 292 +++++++++ dist/unshimmed-common-from-single-shim.txt | 54 +- dist/unshimmed-from-each-spark3xx.txt | 2 + docs/dev/shimplify.md | 2 +- docs/dev/shims.md | 94 +++ 12 files changed, 1531 insertions(+), 149 deletions(-) create mode 100644 dist/keep-in-spark-shared.txt create mode 100644 dist/scripts/analyze-parallel-world-deps.py create mode 100644 dist/scripts/build-unshim-parallel-world.py diff --git a/build/buildall b/build/buildall index 6c977f3ba5c..33c8c445f5f 100755 --- a/build/buildall +++ b/build/buildall @@ -22,6 +22,25 @@ shopt -s extglob SKIP_CLEAN=1 BUILD_ALL_DEBUG=0 SCALA213=0 +UNSHIM_FAST=0 +UNSHIM_PARALLEL_WORLD_ONLY=0 +UNSHIM_REUSE_BUILT_JARS=0 +UNSHIM_ALLOWLIST_ONLY=0 + +SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd) +SOURCE_DIR=$(cd "$SCRIPT_DIR/.." 
>/dev/null 2>&1 && pwd) + +function first_pom_value() { + local key="$1" + local pom="$2" + sed -n "0,/<$key>/{s|.*<$key>\([^<]*\).*|\1|p}" "$pom" | head -n 1 +} + +function last_pom_value() { + local key="$1" + local pom="$2" + sed -n "s|.*<$key>\([^<]*\).*|\1|p" "$pom" | tail -n 1 +} function join_by { local IFS="$1"; shift; echo "$*"; } @@ -56,6 +75,14 @@ function print_usage() { echo " repackage the dist module artifact using installed dependencies" echo " --scala213" echo " build 2.13 shims" + echo " --unshim-fast" + echo " skip Maven checks/docs, tests, build metadata, coverage, enforcer, and snapshot refresh for repeated unshim/dist iteration" + echo " --parallel-world-only, --unshim-parallel-world-only" + echo " build analyzer-only parallel-world output without the final Maven dist invocation" + echo " --unshim-reuse-built-jars" + echo " with --unshim-fast --parallel-world-only, skip shim Maven builds and reuse existing target jars" + echo " --unshim-allowlist-only" + echo " imply --unshim-fast --parallel-world-only --unshim-reuse-built-jars and require only unshim allowlist changes" } function bloopInstall() { @@ -148,6 +175,25 @@ case "$1" in SCALA213=1 ;; +--unshim-fast|--fast-unshim) + UNSHIM_FAST=1 + ;; + +--parallel-world-only|--unshim-parallel-world-only) + UNSHIM_PARALLEL_WORLD_ONLY=1 + ;; + +--unshim-reuse-built-jars) + UNSHIM_REUSE_BUILT_JARS=1 + ;; + +--unshim-allowlist-only) + UNSHIM_ALLOWLIST_ONLY=1 + UNSHIM_FAST=1 + UNSHIM_PARALLEL_WORLD_ONLY=1 + UNSHIM_REUSE_BUILT_JARS=1 + ;; + --rebuild-dist-only) SKIP_DIST_DEPS="1" MODULE="dist" @@ -174,14 +220,62 @@ if [[ "$DIST_PROFILE" == *Scala213 ]]; then SCALA213=1 fi +if [[ "$UNSHIM_PARALLEL_WORLD_ONLY" == "1" ]]; then + FINAL_OP="generate-resources" + MODULE="${MODULE:-dist}" +fi + MVN=${MVN:-"mvn"} # include options to mvn command export MVN="$MVN -Dmaven.wagon.http.retryHandler.count=3 ${MVN_OPT}" +if [[ "$UNSHIM_FAST" == "1" ]]; then + export MAVEN_REFRESH_OPT="--no-snapshot-updates" + export 
MVN_FAST_SKIP_OPTS="-Dmaven.test.skip=true -Drat.skip=true -Dmaven.scalastyle.skip=true -Dmaven.scaladoc.skip=true -Dmaven.javadoc.skip=true -Ddist.jar.compress=false -Djacoco.skip=true -Denforcer.skip=true -Drapids.build.info.skip=true -Dignore.shim.revisions.check=true" +else + export MAVEN_REFRESH_OPT="-U" + export MVN_FAST_SKIP_OPTS="" +fi +export UNSHIM_FAST +export UNSHIM_PARALLEL_WORLD_ONLY +export UNSHIM_ALLOWLIST_ONLY + +if [[ "$UNSHIM_REUSE_BUILT_JARS" == "1" && \ + ( "$UNSHIM_FAST" != "1" || "$UNSHIM_PARALLEL_WORLD_ONLY" != "1" ) ]]; then + echo >&2 "--unshim-reuse-built-jars requires --unshim-fast --parallel-world-only" + exit 1 +fi + +if [[ "$UNSHIM_ALLOWLIST_ONLY" == "1" ]] && \ + git -C "$SOURCE_DIR" rev-parse --is-inside-work-tree >/dev/null 2>&1; then + ALLOWLIST_ONLY_DIRTY=$( + { + git -C "$SOURCE_DIR" diff --name-only -- \ + . \ + ':(exclude)dist/unshimmed-common-from-single-shim.txt' \ + ':(exclude)dist/unshimmed-from-each-spark3xx.txt' \ + ':(exclude)dist/keep-in-spark-shared.txt' + git -C "$SOURCE_DIR" diff --cached --name-only -- \ + . \ + ':(exclude)dist/unshimmed-common-from-single-shim.txt' \ + ':(exclude)dist/unshimmed-from-each-spark3xx.txt' \ + ':(exclude)dist/keep-in-spark-shared.txt' + } | sort -u + ) + if [[ -n "$ALLOWLIST_ONLY_DIRTY" ]]; then + echo >&2 "--unshim-allowlist-only can only reuse jars when tracked changes are limited to dist/unshimmed*.txt or dist/keep-in-spark-shared.txt" + echo >&2 "$ALLOWLIST_ONLY_DIRTY" + exit 1 + fi +fi if [[ "$SCALA213" == "1" ]]; then POM_FILE="scala2.13/pom.xml" export MVN="$MVN -f scala2.13/" - $(dirname $0)/make-scala-version-build-files.sh 2.13 + if [[ "$UNSHIM_FAST" == "1" && -f "$POM_FILE" ]]; then + echo "Unshim fast: reusing existing Scala 2.13 POMs" + else + "$SCRIPT_DIR"/make-scala-version-build-files.sh 2.13 + fi else POM_FILE="pom.xml" fi @@ -216,7 +310,26 @@ case $DIST_PROFILE in esac echo "Spark versions involved: ${SPARK_SHIM_VERSIONS[@]} ..." 
-export MVN_BASE_DIR=$($MVN help:evaluate -Dexpression=project.basedir -q -DforceStdout) +if [[ "$UNSHIM_FAST" == "1" ]]; then + if [[ "$SCALA213" == "1" ]]; then + export MVN_BASE_DIR="$SOURCE_DIR/scala2.13" + else + export MVN_BASE_DIR="$SOURCE_DIR" + fi + export RAPIDS_PROJECT_VERSION=$(first_pom_value version "$POM_FILE") + export RAPIDS_SCALA_BINARY_VERSION=$(last_pom_value scala.binary.version "$POM_FILE") +else + export MVN_BASE_DIR=$($MVN help:evaluate -Dexpression=project.basedir -q -DforceStdout) +fi + +if [[ "$UNSHIM_PARALLEL_WORLD_ONLY" == "1" ]]; then + echo "Unshim parallel-world-only: preparing analyzer-only output and skipping JNI unpack, shimplify, and reduced POM generation" + MVN_FAST_SKIP_OPTS="$MVN_FAST_SKIP_OPTS -Drapids.jni.unpack.skip=true -Drapids.shimplify.skip=true -Drapids.parallel.world.skip.reduced.pom=true -Drapids.aggregator.downstream.refresh.skip=true" +elif [[ "$UNSHIM_FAST" == "1" && -d "$MVN_BASE_DIR/dist/target/jni-deps" ]]; then + echo "Unshim fast: reusing existing JNI deps from $MVN_BASE_DIR/dist/target/jni-deps" + MVN_FAST_SKIP_OPTS="$MVN_FAST_SKIP_OPTS -Drapids.jni.unpack.skip=true" +fi +export MVN_FAST_SKIP_OPTS if [[ "$GEN_BLOOP" == "true" ]]; then bloopInstall @@ -237,9 +350,45 @@ fi echo "Building a combined dist jar with Shims for ${SPARK_SHIM_VERSIONS[@]} ..." +function refresh_fast_aggregator_jar() { + [[ "$UNSHIM_FAST" == "1" ]] || return 0 + local BUILD_VER=$1 + local agg_dir="$MVN_BASE_DIR/aggregator/target/spark$BUILD_VER" + local agg_base="rapids-4-spark-aggregator_${RAPIDS_SCALA_BINARY_VERSION}-${RAPIDS_PROJECT_VERSION}" + local shaded_jar="$agg_dir/${agg_base}-shaded.jar" + local downstream_jar="$agg_dir/${agg_base}-spark$BUILD_VER.jar" + if [[ ! 
-f "$shaded_jar" ]]; then + echo >&2 "Expected shaded aggregator jar missing: $shaded_jar" + exit 255 + fi + if [[ -f "$downstream_jar" ]] && cmp -s "$shaded_jar" "$downstream_jar"; then + return 0 + fi + cp -p "$shaded_jar" "$downstream_jar" +} +export -f refresh_fast_aggregator_jar + +function verify_reusable_unshim_artifacts() { + local BUILD_VER=$1 + local classifier="spark$BUILD_VER" + local api_base="rapids-4-spark-sql-plugin-api_${RAPIDS_SCALA_BINARY_VERSION}-${RAPIDS_PROJECT_VERSION}" + local agg_base="rapids-4-spark-aggregator_${RAPIDS_SCALA_BINARY_VERSION}-${RAPIDS_PROJECT_VERSION}" + local api_jar="$MVN_BASE_DIR/sql-plugin-api/target/$classifier/${api_base}-$classifier.jar" + local agg_shaded_jar="$MVN_BASE_DIR/aggregator/target/$classifier/${agg_base}-shaded.jar" + local jar_path + for jar_path in "$api_jar" "$agg_shaded_jar"; do + if [[ ! -f "$jar_path" ]]; then + echo >&2 "Expected reusable unshim artifact missing: $jar_path" + echo >&2 "Re-run without --unshim-reuse-built-jars after source or dependency changes." 
+ exit 255 + fi + done +} +export -f verify_reusable_unshim_artifacts + function build_single_shim() { [[ "$BUILD_ALL_DEBUG" == "1" ]] && set -x - BUILD_VER=$1 + local BUILD_VER=$1 mkdir -p "$MVN_BASE_DIR/target" if (( BUILD_PARALLEL == 1 || NUM_SHIMS == 1 )); then # Single-shim/serial build: stream Maven output live rather than to a log @@ -255,8 +404,8 @@ function build_single_shim() { LOG_FILE="$MVN_BASE_DIR/target/mvn-build-$BUILD_VER.log" fi - if [[ "$BUILD_VER" == "$BASE_VER" ]]; then - SKIP_CHECKS="false" + if [[ "$BUILD_VER" == "$BASE_VER" && \ + ( "$UNSHIM_FAST" != "1" || "$UNSHIM_PARALLEL_WORLD_ONLY" != "1" ) ]]; then # WORKAROUND: # maven build on L193 currently relies on aggregator dependency which # will removed by @@ -267,10 +416,20 @@ function build_single_shim() { # MVN_PHASE="install" else - SKIP_CHECKS="true" MVN_PHASE="package" fi + if [[ "$UNSHIM_FAST" == "1" || "$BUILD_VER" != "$BASE_VER" ]]; then + SKIP_CHECKS="true" + else + SKIP_CHECKS="false" + fi + + local BUILD_PROJECTS="tools" + if [[ "$UNSHIM_FAST" == "1" ]]; then + BUILD_PROJECTS="aggregator" + fi + echo "#### REDIRECTING mvn output to ${LOG_FILE:-stdout} ####" ( if [[ "$LOG_FILE" == "" ]]; then @@ -278,13 +437,15 @@ function build_single_shim() { else exec > "$LOG_FILE" 2>&1 || exit $? fi - $MVN -U "$MVN_PHASE" \ + $MVN $MAVEN_REFRESH_OPT "$MVN_PHASE" \ -DskipTests \ -Dbuildver="$BUILD_VER" \ -Drat.skip="$SKIP_CHECKS" \ - -Dmaven.scaladoc.skip \ + -Dmaven.scaladoc.skip=true \ + -Dmaven.javadoc.skip=true \ -Dmaven.scalastyle.skip="$SKIP_CHECKS" \ - -pl tools -am + $MVN_FAST_SKIP_OPTS \ + -pl "$BUILD_PROJECTS" -am ) || { # Only tail when output went to a real log file; for a live stream # (/dev/tty or existing stdout) the failure output is already on screen. 
@@ -294,6 +455,7 @@ function build_single_shim() { esac exit 255 } + refresh_fast_aggregator_jar "$BUILD_VER" } export -f build_single_shim @@ -310,25 +472,62 @@ export -f build_single_shim time ( # printf a single buildver array element per line if [[ "$SKIP_DIST_DEPS" != "1" ]]; then + if [[ "$UNSHIM_REUSE_BUILT_JARS" == "1" ]]; then + echo "Unshim fast: reusing existing per-shim jars and skipping Maven shim builds" + for bv in "${SPARK_SHIM_VERSIONS[@]}"; do + verify_reusable_unshim_artifacts "$bv" + refresh_fast_aggregator_jar "$bv" + done + else # Execute initialize to download a massive jar for spark-rapids-jni in a single thread to - # avoid repeating this work in parallel - # Initialize sql-plugin-api only to avoid dealing with missing submodule dependencies - # - $MVN initialize -pl sql-plugin-api -am + # avoid repeating this work in parallel. This is unnecessary in unshim-fast modes that skip + # JNI unpacking. + if [[ "$UNSHIM_FAST" == "1" && "$MVN_FAST_SKIP_OPTS" == *"-Drapids.jni.unpack.skip=true"* ]]; then + echo "Unshim fast: skipping serial Maven initialize preflight" + else + # Initialize sql-plugin-api only to avoid dealing with missing submodule dependencies. + $MVN initialize -pl sql-plugin-api -am + fi printf "%s\n" "${SPARK_SHIM_VERSIONS[@]}" | \ xargs -t -I% -P "$BUILD_PARALLEL" -n 1 \ bash -c 'build_single_shim "$@"' _ % + fi fi - # This used to resume from dist. However, without including aggregator in the build - # the build does not properly initialize spark.version property via buildver profiles - # in the root pom, and we get a missing spark330 dependency even for --profile=330,331 - # where the build does not require it. Moving it to aggregator resolves this issue with - # a negligible increase of the build time by ~2 seconds. 
+ if [[ "$UNSHIM_FAST" == "1" && "$UNSHIM_REUSE_BUILT_JARS" != "1" ]]; then + for bv in "${SPARK_SHIM_VERSIONS[@]}"; do + refresh_fast_aggregator_jar "$bv" + done + fi + # Non-fast builds resume from aggregator so Maven initializes the buildver-derived + # spark.version.classifier before dist resolves its aggregator dependency. The unshim-fast + # dist path can skip that extra aggregator pass because the per-shim builds above already + # installed the base aggregator jar and refreshed all target aggregator jars. joinShimBuildFrom="aggregator" INCLUDED_BUILDVERS_OPT=-Dincluded_buildvers=$(join_by , "${SPARK_SHIM_VERSIONS[@]}") - echo "Resuming from $joinShimBuildFrom build only using $BASE_VER" - $MVN $FINAL_OP -rf $joinShimBuildFrom $MODULE_OPT $MVN_PROFILE_OPT $INCLUDED_BUILDVERS_OPT \ + if [[ "$UNSHIM_FAST" == "1" && "$MODULE" == "dist" ]]; then + if [[ "$UNSHIM_PARALLEL_WORLD_ONLY" == "1" ]]; then + echo "Unshim fast: assembling parallel-world directly without final Maven dist invocation" + python3 "$SOURCE_DIR/dist/scripts/build-unshim-parallel-world.py" \ + --mvn-base-dir "$MVN_BASE_DIR" \ + --source-dir "$SOURCE_DIR" \ + --project-version "$RAPIDS_PROJECT_VERSION" \ + --scala-binary-version "$RAPIDS_SCALA_BINARY_VERSION" \ + --buildvers "$(join_by , "${SPARK_SHIM_VERSIONS[@]}")" \ + --ignore-shim-revisions-check + exit 0 + else + echo "Resuming at dist only using $BASE_VER" + FINAL_RESUME_OPT="" + FINAL_MODULE_OPT="--projects dist" + fi + else + echo "Resuming from $joinShimBuildFrom build only using $BASE_VER" + FINAL_RESUME_OPT="-rf $joinShimBuildFrom" + FINAL_MODULE_OPT="$MODULE_OPT" + fi + $MVN $FINAL_OP $FINAL_RESUME_OPT $FINAL_MODULE_OPT $MVN_PROFILE_OPT $INCLUDED_BUILDVERS_OPT \ -Dbuildver="$BASE_VER" \ - -DskipTests -Dmaven.scaladoc.skip + -DskipTests -Dmaven.scaladoc.skip=true -Dmaven.javadoc.skip=true \ + $MVN_FAST_SKIP_OPTS ) diff --git a/dist/README.md b/dist/README.md index aa23b6a6332..840f9a52ee6 100644 --- a/dist/README.md +++ b/dist/README.md 
@@ -28,10 +28,8 @@ provider discovery mechanism [ParallelWorldClassloader](https://github.com/openjdk/jdk/blob/jdk8-b120/jaxws/src/share/jaxws_classes/com/sun/istack/internal/tools/ParallelWorldClassLoader.java)) for each version of Spark supported in the jar, i.e., spark330/, spark341/, etc. -If you have to change the contents of the uber jar the following files control what goes into the base jar as classes that are not shaded. +If you have to change the contents of the uber jar, the packaging defaults common classes to the base jar when binary dedupe proves they are bitwise-identical across shims. New common classes should normally remain unshimmed by default. The following files control explicit exceptions and non-class resources. -1. `unshimmed-common-from-single-shim.txt` - This has classes and files that should go into the base jar with their normal -package name (not shaded). This includes user visible classes (i.e., com/nvidia/spark/SQLPlugin), python files, -and other files that aren't version specific. Uses Spark 3.2.0 built jar for these base classes as explained above. -2. `unshimmed-from-each-spark3xx.txt` - This is applied to all the individual Spark specific version jars to pull -any files that need to go into the base of the jar and not into the Spark specific directory. +1. `keep-in-spark-shared.txt` - Patterns for bitwise-identical common `spark-shared` class files that must stay in `spark-shared` instead of being promoted to the base jar. This should stay small; add entries only for compatibility or packaging exceptions. +2. `unshimmed-common-from-single-shim.txt` - Files that must go into the base jar from one representative shim but are not selected by default class promotion, such as root `META-INF` resources and Python worker files. Avoid adding class files here unless they need special root-layout treatment outside bitwise-identical default promotion. +3. 
`unshimmed-from-each-spark3xx.txt` - This is applied to all the individual Spark specific version jars to pull any files that need to go into the base of the jar and not into the Spark specific directory. These are per-shim root artifacts rather than common `spark-shared` classes. diff --git a/dist/build/package-parallel-worlds.py b/dist/build/package-parallel-worlds.py index 4698c4a8ca0..e612b05b490 100644 --- a/dist/build/package-parallel-worlds.py +++ b/dist/build/package-parallel-worlds.py @@ -26,6 +26,30 @@ def shell_exec(shell_cmd): self.fail("failed to execute %s" % shell_cmd) +def has_fnmatch_magic(pattern): + return "*" in pattern or "?" in pattern or "[" in pattern + + +def select_matching_members(namelist, patterns): + if os.environ.get("UNSHIM_FAST") != "1": + matching_members = [] + for pat in patterns: + matching_members += fnmatch.filter(namelist, pat) + return matching_members + + names_by_entry = {} + for name in namelist: + names_by_entry.setdefault(name, []).append(name) + + matching_members = [] + for pat in patterns: + if has_fnmatch_magic(pat): + matching_members += fnmatch.filter(namelist, pat) + else: + matching_members += names_by_entry.get(pat, []) + return matching_members + + artifacts = attributes.get('artifact_csv').split(',') buildver_list = re.sub(r'\s+', '', project.getProperty('included_buildvers'), flags=re.UNICODE).split(',') @@ -40,6 +64,12 @@ def shell_exec(shell_cmd): art_url = project.getProperty('env.ART_URL') jenkins_settings = os.sep.join([source_basedir, 'jenkins', 'settings.xml']) repo_local = project.getProperty('maven.repo.local') +dist_dir = os.sep.join([source_basedir, 'dist']) +with open(os.sep.join([dist_dir, 'unshimmed-common-from-single-shim.txt']), 'r') as f: + from_single_shim = f.read().splitlines() +with open(os.sep.join([dist_dir, 'unshimmed-from-each-spark3xx.txt']), 'r') as f: + from_each = f.read().splitlines() +from_single_shim_or_each = from_single_shim + from_each for bv in buildver_list: classifier = 
'spark' + bv @@ -73,11 +103,6 @@ def shell_exec(shell_cmd): mvn_cmd.append('='.join(['-Dmaven.repo.local', repo_local])) shell_exec(mvn_cmd) - dist_dir = os.sep.join([source_basedir, 'dist']) - with open(os.sep.join([dist_dir, 'unshimmed-common-from-single-shim.txt']), 'r') as f: - from_single_shim = f.read().splitlines() - with open(os.sep.join([dist_dir, 'unshimmed-from-each-spark3xx.txt']), 'r') as f: - from_each = f.read().splitlines() with zipfile.ZipFile(os.sep.join([deps_dir, art_jar]), 'r') as zip_handle: if project.getProperty('should.build.conventional.jar'): zip_handle.extractall(path=top_dist_jar_dir) @@ -88,9 +113,6 @@ def shell_exec(shell_cmd): zip_handle.extractall(path=top_dist_jar_dir) # TODO deprecate namelist = zip_handle.namelist() - matching_members = [] - glob_list = from_single_shim + from_each if bv == buildver_list[0] else from_each - for pat in glob_list: - new_matches = fnmatch.filter(namelist, pat) - matching_members += new_matches + glob_list = from_single_shim_or_each if bv == buildver_list[0] else from_each + matching_members = select_matching_members(namelist, glob_list) zip_handle.extractall(path=top_dist_jar_dir, members=matching_members) diff --git a/dist/keep-in-spark-shared.txt b/dist/keep-in-spark-shared.txt new file mode 100644 index 00000000000..5fc420febc9 --- /dev/null +++ b/dist/keep-in-spark-shared.txt @@ -0,0 +1,6 @@ +# Patterns for common spark-shared class files that must not be promoted to +# the root layout even when binary dedupe marks them bitwise-identical. +# +# Add entries only when a class is bitwise-identical but must remain loaded +# from spark-shared for compatibility or packaging reasons. New common classes +# should normally stay unshimmed by default. 
diff --git a/dist/maven-antrun/build-parallel-worlds.xml b/dist/maven-antrun/build-parallel-worlds.xml index afde7c2d755..f6ccf8cb0b9 100644 --- a/dist/maven-antrun/build-parallel-worlds.xml +++ b/dist/maven-antrun/build-parallel-worlds.xml @@ -123,6 +123,10 @@ failonerror="false"> + + @@ -132,13 +136,14 @@ - + - + Generating dependency-reduced-pom.xml <dependency> diff --git a/dist/scripts/analyze-parallel-world-deps.py b/dist/scripts/analyze-parallel-world-deps.py new file mode 100644 index 00000000000..ab2867db7c4 --- /dev/null +++ b/dist/scripts/analyze-parallel-world-deps.py @@ -0,0 +1,617 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Analyze dependencies between conventional, spark-shared, and shim classes. + +The dist jar contains classes in the conventional root layout, in spark-shared, +and in one or more Spark-version-specific directories. This script inspects the +class files and reports which root or spark-shared classes still have a static +dependency path to version-specific bytecode. 
+""" + +import argparse +import collections +import json +import os +import re +import struct +import sys +import zipfile + + +SHIM_DIR_RE = re.compile(r"^spark[0-9][0-9a-z]*$") +CLASSIFIER_PACKAGE_RE = re.compile(r"(^|\.)spark[0-9][0-9a-z]*($|\.)") +DESCRIPTOR_CLASS_RE = re.compile(r"L([^;<>\[\]\(\)]+);") + +DEFAULT_EXCLUDES = ( + "ai.rapids.cudf.", + "com.nvidia.shaded.", + "org.openucx.", +) + + +ClassInfo = collections.namedtuple("ClassInfo", ("name", "location", "entry", "deps")) + + +def _read_u1(data, offset): + return data[offset], offset + 1 + + +def _read_u2(data, offset): + return struct.unpack_from(">H", data, offset)[0], offset + 2 + + +def _read_u4(data, offset): + return struct.unpack_from(">I", data, offset)[0], offset + 4 + + +def _class_names_from_descriptor(value): + for match in DESCRIPTOR_CLASS_RE.finditer(value): + yield match.group(1) + + +def _normalize_internal_name(value): + if not value: + return [] + if value.startswith("["): + return list(_class_names_from_descriptor(value)) + if "/" in value and not value.startswith("("): + return [value] + return list(_class_names_from_descriptor(value)) + + +def parse_class_file(data): + magic, offset = _read_u4(data, 0) + if magic != 0xCAFEBABE: + raise ValueError("not a class file") + + # minor_version, major_version + _, offset = _read_u2(data, offset) + _, offset = _read_u2(data, offset) + + cp_count, offset = _read_u2(data, offset) + constant_pool = [None] * cp_count + class_name_indexes = [] + utf8_values = [] + + index = 1 + while index < cp_count: + tag, offset = _read_u1(data, offset) + if tag == 1: # CONSTANT_Utf8 + length, offset = _read_u2(data, offset) + raw = data[offset:offset + length] + offset += length + value = raw.decode("utf-8", errors="replace") + constant_pool[index] = value + utf8_values.append(value) + elif tag in (3, 4): # Integer, Float + offset += 4 + elif tag in (5, 6): # Long, Double + offset += 8 + index += 1 + elif tag == 7: # Class + name_index, offset = 
_read_u2(data, offset) + constant_pool[index] = name_index + class_name_indexes.append(name_index) + elif tag == 8: # String + offset += 2 + elif tag in (9, 10, 11, 12, 17, 18): # refs, NameAndType, Dynamic, InvokeDynamic + offset += 4 + elif tag == 15: # MethodHandle + offset += 3 + elif tag in (16, 19, 20): # MethodType, Module, Package + offset += 2 + else: + raise ValueError("unknown constant pool tag %s" % tag) + index += 1 + + # access_flags + _, offset = _read_u2(data, offset) + this_class_index, offset = _read_u2(data, offset) + this_name_index = constant_pool[this_class_index] + this_name = constant_pool[this_name_index] + + deps = set() + for name_index in class_name_indexes: + for dep in _normalize_internal_name(constant_pool[name_index]): + deps.add(dep.replace("/", ".")) + for value in utf8_values: + for dep in _class_names_from_descriptor(value): + deps.add(dep.replace("/", ".")) + + class_name = this_name.replace("/", ".") + deps.discard(class_name) + return class_name, deps + + +def location_from_entry(entry): + first = entry.split("/", 1)[0] + if first == "spark-shared": + return "spark-shared" + if SHIM_DIR_RE.match(first): + return first + return "root" + + +def is_classifier_class(class_name): + return bool(CLASSIFIER_PACKAGE_RE.search(class_name)) + + +def is_version_location(location): + return bool(SHIM_DIR_RE.match(location)) + + +def is_version_node(node): + class_name, location = node + return is_version_location(location) or is_classifier_class(class_name) + + +def iter_class_entries(path): + if zipfile.is_zipfile(path): + with zipfile.ZipFile(path) as zf: + for name in zf.namelist(): + if name.endswith(".class") and not name.endswith("/module-info.class"): + yield name, zf.read(name) + return + + for root, _, files in os.walk(path): + for file_name in files: + if not file_name.endswith(".class") or file_name == "module-info.class": + continue + full_path = os.path.join(root, file_name) + rel_path = os.path.relpath(full_path, 
path).replace(os.sep, "/") + with open(full_path, "rb") as fh: + yield rel_path, fh.read() + + +def should_exclude(class_name, prefixes): + return any(class_name.startswith(prefix) for prefix in prefixes) + + +def load_classes(path, exclude_prefixes): + classes = {} + name_locations = collections.defaultdict(set) + errors = [] + for entry, data in iter_class_entries(path): + try: + class_name, deps = parse_class_file(data) + except ValueError as exc: + errors.append("%s: %s" % (entry, exc)) + continue + if should_exclude(class_name, exclude_prefixes): + continue + location = location_from_entry(entry) + info = ClassInfo(class_name, location, entry, deps) + node = (class_name, location) + classes[node] = info + name_locations[class_name].add(location) + return classes, name_locations, errors + + +def resolve_dependency_targets(source_location, dep_name, name_locations): + locations = name_locations.get(dep_name) + if not locations: + return [] + + # Parent/root class loading wins in the current layout. Prefer a conventional + # class when one exists, then the source archive, then spark-shared, then the + # remaining version-specific locations. 
+ ordered = [] + for preferred in ("root", source_location, "spark-shared"): + if preferred in locations and preferred not in ordered: + ordered.append(preferred) + ordered.extend(sorted(loc for loc in locations if loc not in ordered)) + return [(dep_name, loc) for loc in ordered] + + +def build_graph(classes, name_locations): + graph = {node: set() for node in classes} + for node, info in classes.items(): + for dep_name in info.deps: + for target in resolve_dependency_targets(info.location, dep_name, name_locations): + if target in classes: + graph[node].add(target) + return graph + + +def reverse_graph(graph): + rev = {node: set() for node in graph} + for source, targets in graph.items(): + for target in targets: + rev[target].add(source) + return rev + + +def reachable_to_version_specific(graph): + rev = reverse_graph(graph) + version_nodes = {node for node in graph if is_version_node(node)} + marked = set(version_nodes) + queue = collections.deque(version_nodes) + while queue: + node = queue.popleft() + for parent in rev[node]: + if parent not in marked: + marked.add(parent) + queue.append(parent) + return marked, version_nodes + + +def shortest_path_to_version(graph, start): + queue = collections.deque([(start, [start])]) + seen = {start} + while queue: + node, path = queue.popleft() + if node != start and is_version_node(node): + return path + for next_node in sorted(graph[node]): + if next_node not in seen: + seen.add(next_node) + queue.append((next_node, path + [next_node])) + return None + + +def tarjan_scc(graph): + sys.setrecursionlimit(max(sys.getrecursionlimit(), len(graph) * 2 + 1000)) + + index = 0 + stack = [] + on_stack = set() + indexes = {} + lowlinks = {} + components = [] + + def strongconnect(node): + nonlocal index + indexes[node] = index + lowlinks[node] = index + index += 1 + stack.append(node) + on_stack.add(node) + + for next_node in graph[node]: + if next_node not in indexes: + strongconnect(next_node) + lowlinks[node] = 
min(lowlinks[node], lowlinks[next_node]) + elif next_node in on_stack: + lowlinks[node] = min(lowlinks[node], indexes[next_node]) + + if lowlinks[node] == indexes[node]: + component = [] + while True: + item = stack.pop() + on_stack.remove(item) + component.append(item) + if item == node: + break + components.append(component) + + for node in graph: + if node not in indexes: + strongconnect(node) + return components + + +def dependency_first_component_order(graph, components): + comp_by_node = {} + for comp_id, component in enumerate(components): + for node in component: + comp_by_node[node] = comp_id + + # Source -> target means "source depends on target". Reverse component + # edges so Kahn's algorithm emits dependencies before their users. + prereq_edges = collections.defaultdict(set) + indegree = collections.Counter() + for source, targets in graph.items(): + source_comp = comp_by_node[source] + indegree.setdefault(source_comp, 0) + for target in targets: + target_comp = comp_by_node[target] + if source_comp == target_comp: + continue + if source_comp not in prereq_edges[target_comp]: + prereq_edges[target_comp].add(source_comp) + indegree[source_comp] += 1 + indegree.setdefault(target_comp, indegree[target_comp]) + + ready = collections.deque(sorted( + comp_id for comp_id in range(len(components)) if indegree[comp_id] == 0)) + ordered = [] + while ready: + comp_id = ready.popleft() + ordered.append(comp_id) + for dependent in sorted(prereq_edges[comp_id]): + indegree[dependent] -= 1 + if indegree[dependent] == 0: + ready.append(dependent) + return ordered + + +def format_node(node): + class_name, location = node + return "%s (%s)" % (class_name, location) + + +def print_path(path): + return " -> ".join(format_node(node) for node in path) + + +def json_node(node): + class_name, location = node + return { + "className": class_name, + "location": location, + } + + +def location_relative_entry(info): + parts = info.entry.split("/", 1) + if info.location == "root": 
+ return info.entry + if len(parts) == 2: + return parts[1] + return info.entry + + +def direct_classifier_edges(graph): + edges = [] + for source, targets in graph.items(): + if is_classifier_class(source[0]): + continue + for target in targets: + if is_classifier_class(target[0]): + edges.append((source, target)) + return sorted(edges) + + +def version_blocker_counts(graph, version_nodes, root_or_shared): + """Count root/shared classes that can reach each version-specific node.""" + rev = reverse_graph(graph) + counts = [] + for version_node in sorted(version_nodes): + seen = {version_node} + queue = collections.deque([version_node]) + impacted = set() + while queue: + node = queue.popleft() + for parent in rev[node]: + if parent in seen: + continue + seen.add(parent) + queue.append(parent) + if parent in root_or_shared: + impacted.add(parent) + if impacted: + counts.append((len(impacted), version_node)) + return sorted(counts, key=lambda item: (-item[0], item[1])) + + +def nearest_version_target_counts(graph, blocked): + """Count terminal version nodes from each blocked node's shortest path.""" + rev = reverse_graph(graph) + distance = {} + queue = collections.deque() + for node in sorted(node for node in graph if is_version_node(node)): + distance[node] = 0 + queue.append(node) + + while queue: + node = queue.popleft() + for parent in sorted(rev[node]): + if parent in distance: + continue + distance[parent] = distance[node] + 1 + queue.append(parent) + + def rebuild_path(start): + path = [start] + node = start + while not is_version_node(node): + next_node = None + for candidate in sorted(graph[node]): + if distance.get(candidate) == distance[node] - 1: + next_node = candidate + break + if next_node is None: + return None + path.append(next_node) + node = next_node + return path + + counts = collections.Counter() + examples = {} + paths = [] + for node in blocked: + if node not in distance: + continue + path = rebuild_path(node) + if not path: + continue + 
paths.append((node, path)) + target = path[-1] + counts[target] += 1 + examples.setdefault(target, path) + ranked = sorted( + ((count, target, examples[target]) for target, count in counts.items()), + key=lambda item: (-item[0], item[1])) + return ranked, paths + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("path", help="dist/target/parallel-world directory or a dist jar") + parser.add_argument("--limit", type=int, default=20, + help="maximum number of examples to print per section") + parser.add_argument("--exclude-prefix", action="append", default=[], + help="class name prefix to exclude; may be passed more than once") + parser.add_argument("--show-safe", action="store_true", + help="print examples of spark-shared classes with no path to version-specific code") + parser.add_argument("--show-topo", action="store_true", + help="print root-safe spark-shared SCCs in dependency-first order") + parser.add_argument("--show-reachability", action="store_true", + help="print overlapping reachability counts for version-specific nodes") + parser.add_argument("--format", choices=("text", "json"), default="text", + help="output format") + parser.add_argument("--write-safe-paths", + help="write root-safe spark-shared class paths, one per line") + args = parser.parse_args() + + exclude_prefixes = tuple(DEFAULT_EXCLUDES) + tuple(args.exclude_prefix) + classes, name_locations, errors = load_classes(args.path, exclude_prefixes) + graph = build_graph(classes, name_locations) + contaminated, version_nodes = reachable_to_version_specific(graph) + components = tarjan_scc(graph) + component_order = dependency_first_component_order(graph, components) + + by_location = collections.Counter(info.location for info in classes.values()) + root_or_shared = { + node for node, info in classes.items() + if info.location in ("root", "spark-shared") and not is_classifier_class(info.name) + } + blocked = sorted(root_or_shared & contaminated) + 
safe_shared = sorted( + node for node in root_or_shared - contaminated + if classes[node].location == "spark-shared") + classifier_edges = direct_classifier_edges(graph) + version_components = [comp for comp in components if any(is_version_node(node) for node in comp)] + safe_sccs = [] + for comp_id in component_order: + component = components[comp_id] + safe_members = sorted(node for node in component if node in safe_shared) + if safe_members: + safe_sccs.append((comp_id, safe_members)) + version_blockers = ( + version_blocker_counts(graph, version_nodes, root_or_shared) + if args.show_reachability or args.format == "json" else []) + nearest_targets, blocked_paths = nearest_version_target_counts(graph, blocked) + safe_shared_paths = sorted(location_relative_entry(classes[node]) for node in safe_shared) + + if args.write_safe_paths: + with open(args.write_safe_paths, "w", encoding="utf-8") as out: + for path in safe_shared_paths: + out.write(path) + out.write("\n") + + if args.format == "json": + output = { + "path": args.path, + "classCount": len(classes), + "locationCounts": dict(sorted(by_location.items())), + "versionSpecificNodeCount": len(version_nodes), + "rootOrSharedBlockedCount": len(blocked), + "rootSafeSparkSharedCount": len(safe_shared), + "sccCount": len(components), + "versionSpecificSccCount": len(version_components), + "directClassifierDependencyCount": len(classifier_edges), + "rootSafeSparkSharedPaths": safe_shared_paths, + "directClassifierDependencyExamples": [ + { + "source": json_node(source), + "target": json_node(target), + } + for source, target in classifier_edges[:args.limit] + ], + "topVersionBlockersByReachability": [ + { + "blockedRootOrSharedCount": count, + "target": json_node(target), + } + for count, target in version_blockers[:args.limit] + ], + "nearestVersionTargetCounts": [ + { + "blockedShortestPathCount": count, + "target": json_node(target), + "examplePath": [json_node(node) for node in path], + } + for count, target, path 
in nearest_targets[:args.limit] + ], + "rootSafeSparkSharedSccCount": len(safe_sccs), + "rootSafeSparkSharedSccExamples": [ + { + "componentId": comp_id, + "classCount": len(members), + "classExamples": [json_node(node) for node in members[:args.limit]], + } + for comp_id, members in safe_sccs[:args.limit] + ], + "blockedExamples": [ + [json_node(node) for node in path] + for _, path in blocked_paths[:args.limit] + ], + } + json.dump(output, sys.stdout, indent=2, sort_keys=True) + print() + return + + print("Loaded %d classes from %s" % (len(classes), args.path)) + if errors: + print("Skipped %d malformed class files" % len(errors)) + print("Class locations:") + for location, count in sorted(by_location.items()): + print(" %s: %d" % (location, count)) + print("Version-specific/classifier nodes: %d" % len(version_nodes)) + print("Root or spark-shared nodes with a path to version-specific code: %d" % len(blocked)) + print("Root-safe spark-shared nodes: %d" % len(safe_shared)) + print("SCCs: %d total, %d containing version-specific code" % + (len(components), len(version_components))) + + print("\nDirect classifier-package dependencies: %d" % len(classifier_edges)) + for source, target in classifier_edges[:args.limit]: + print(" %s -> %s" % (format_node(source), format_node(target))) + if len(classifier_edges) > args.limit: + print(" ... %d more" % (len(classifier_edges) - args.limit)) + + if args.show_reachability: + print("\nTop version-specific blockers by upstream root/shared reachability:") + for count, target in version_blockers[:args.limit]: + print(" %d <- %s" % (count, format_node(target))) + if len(version_blockers) > args.limit: + print(" ... %d more" % (len(version_blockers) - args.limit)) + + print("\nNearest version targets from shortest blocked paths:") + for count, target, path in nearest_targets[:args.limit]: + print(" %d -> %s" % (count, format_node(target))) + print(" e.g. %s" % print_path(path)) + if len(nearest_targets) > args.limit: + print(" ... 
%d more" % (len(nearest_targets) - args.limit)) + + print("\nNearest paths from root/spark-shared code to version-specific code:") + for _, path in blocked_paths[:args.limit]: + print(" %s" % print_path(path)) + if len(blocked) > args.limit: + print(" ... %d more blocked classes" % (len(blocked) - args.limit)) + + if args.show_safe: + print("\nSpark-shared classes with no path to version-specific code:") + for node in safe_shared[:args.limit]: + print(" %s" % format_node(node)) + if len(safe_shared) > args.limit: + print(" ... %d more" % (len(safe_shared) - args.limit)) + + if args.show_topo: + print("\nRoot-safe spark-shared SCCs in dependency-first order:") + for printed, (comp_id, safe_members) in enumerate(safe_sccs): + print(" component %d, %d class(es)" % (comp_id, len(safe_members))) + for node in safe_members[:3]: + print(" %s" % format_node(node)) + if len(safe_members) > 3: + print(" ... %d more in component" % (len(safe_members) - 3)) + if printed + 1 >= args.limit: + break + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dist/scripts/binary-dedupe.sh b/dist/scripts/binary-dedupe.sh index 2054e18ccf9..ea3ac931413 100755 --- a/dist/scripts/binary-dedupe.sh +++ b/dist/scripts/binary-dedupe.sh @@ -35,10 +35,34 @@ esac STEP=0 export SPARK_SHARED_TXT="$PWD/spark-shared.txt" +export SPARK_SHARED_CLASSES_TXT="$PWD/spark-shared-classes.txt" export SPARK_SHARED_COPY_LIST="$PWD/spark-shared-copy-list.txt" export DELETE_DUPLICATES_TXT="$PWD/delete-duplicates.txt" export SPARK_SHARED_DIR="$PWD/spark-shared" export UNSHIMMED_FROM_SPARK_SHARED_COPY_LIST="$PWD/unshimmed-from-spark-shared-copy-list.txt" +export ROOT_SAFE_SPARK_SHARED_TXT="$PWD/root-safe-spark-shared.txt" +export DEFAULT_UNSHIMMED_SPARK_SHARED_TXT="$PWD/default-unshimmed-spark-shared.txt" +export UNSHIMMED_NEED_SHARED_TXT="$PWD/unshimmed-need-shared.txt" +export UNSHIMMED_MISSING_SHARED_TXT="$PWD/unshimmed-missing-shared.txt" + +SPARK_SHIM_DIRS=() +if [[ "${UNSHIM_FAST:-0}" == "1" ]]; then 
+ while IFS= read -r shim_dir; do + SPARK_SHIM_DIRS+=("$shim_dir") + done < <(find ./parallel-world -maxdepth 1 -mindepth 1 -type d -name 'spark[34]*' | sort) +fi + +DEDUPE_CACHE_DIR="${UNSHIM_DEDUPE_CACHE_DIR:-}" +DEDUPE_CACHE_SPARK_SHARED_TXT="" +DEDUPE_CACHE_SHA1_FILES_TXT="" +DEDUPE_CACHE_SHIM_SHA_PACKAGE_FILES_TXT="" +DEDUPE_CACHE_COUNT_SHIM_SHA_PACKAGE_FILES_TXT="" +if [[ -n "$DEDUPE_CACHE_DIR" ]]; then + DEDUPE_CACHE_SPARK_SHARED_TXT="$DEDUPE_CACHE_DIR/spark-shared.txt" + DEDUPE_CACHE_SHA1_FILES_TXT="$DEDUPE_CACHE_DIR/tmp-sha1-files.txt" + DEDUPE_CACHE_SHIM_SHA_PACKAGE_FILES_TXT="$DEDUPE_CACHE_DIR/tmp-shim-sha-package-files.txt" + DEDUPE_CACHE_COUNT_SHIM_SHA_PACKAGE_FILES_TXT="$DEDUPE_CACHE_DIR/tmp-count-shim-sha-package-files.txt" +fi # This script de-duplicates .class files at the binary level. # We could also diff classes using scalap / javap outputs. @@ -55,24 +79,54 @@ export UNSHIMMED_FROM_SPARK_SHARED_COPY_LIST="$PWD/unshimmed-from-spark-shared-c # - put the path starting with /sparkxyz back together for the final list echo "Retrieving class files hashing to a single value ..." 
- -echo "$((++STEP))/ SHA1 of all non-META files > tmp-sha1-files.txt" -find ./parallel-world/spark[34]* -name META-INF -prune -o -name webapps -prune -o \( -type f -print0 \) | \ - xargs --null $SHASUM > tmp-sha1-files.txt - -echo "$((++STEP))/ make shim column 1 > tmp-shim-sha-package-files.txt" -< tmp-sha1-files.txt awk -F/ '$1=$1' | \ - awk '{checksum=$1; shim=$4; $1=shim; $2=$3=""; $4=checksum; print $0}' | \ - tr -s ' ' > tmp-shim-sha-package-files.txt - -echo "$((++STEP))/ sort by path, sha1; output first from each group > tmp-count-shim-sha-package-files.txt" -sort -k3 -k2,2 -u tmp-shim-sha-package-files.txt | \ - uniq -f 2 -c > tmp-count-shim-sha-package-files.txt - -echo "$((++STEP))/ files with unique sha1 > $SPARK_SHARED_TXT" -grep '^\s\+1 .*' tmp-count-shim-sha-package-files.txt | \ - awk '{$1=""; $3=""; print $0 }' | \ - tr -s ' ' | sed 's/\ /\//g' > "$SPARK_SHARED_TXT" +CACHE_HIT=0 +if [[ -n "$DEDUPE_CACHE_SPARK_SHARED_TXT" && \ + -f "$DEDUPE_CACHE_SPARK_SHARED_TXT" && \ + -f "$DEDUPE_CACHE_SHA1_FILES_TXT" && \ + -f "$DEDUPE_CACHE_SHIM_SHA_PACKAGE_FILES_TXT" && \ + -f "$DEDUPE_CACHE_COUNT_SHIM_SHA_PACKAGE_FILES_TXT" ]]; then + echo "$((++STEP))/ reusing cached files with unique sha1 > $SPARK_SHARED_TXT" + cp "$DEDUPE_CACHE_SPARK_SHARED_TXT" "$SPARK_SHARED_TXT" + cp "$DEDUPE_CACHE_SHA1_FILES_TXT" tmp-sha1-files.txt + cp "$DEDUPE_CACHE_SHIM_SHA_PACKAGE_FILES_TXT" tmp-shim-sha-package-files.txt + cp "$DEDUPE_CACHE_COUNT_SHIM_SHA_PACKAGE_FILES_TXT" tmp-count-shim-sha-package-files.txt + CACHE_HIT=1 +# With one shim there is no cross-shim identity proof to perform; every +# non-META file is the sole representative for its path. 
+elif [[ "${UNSHIM_FAST:-0}" == "1" && "${#SPARK_SHIM_DIRS[@]}" == "1" ]]; then + echo "$((++STEP))/ single shim fast path; listing files > $SPARK_SHARED_TXT" + : > tmp-sha1-files.txt + : > tmp-shim-sha-package-files.txt + : > tmp-count-shim-sha-package-files.txt + find "${SPARK_SHIM_DIRS[0]}" -name META-INF -prune -o -name webapps -prune -o \( -type f -print \) | \ + sort | sed 's|^\./parallel-world||' > "$SPARK_SHARED_TXT" +else + echo "$((++STEP))/ SHA1 of all non-META files > tmp-sha1-files.txt" + find ./parallel-world/spark[34]* -name META-INF -prune -o -name webapps -prune -o \( -type f -print0 \) | \ + xargs --null $SHASUM > tmp-sha1-files.txt + + echo "$((++STEP))/ make shim column 1 > tmp-shim-sha-package-files.txt" + < tmp-sha1-files.txt awk -F/ '$1=$1' | \ + awk '{checksum=$1; shim=$4; $1=shim; $2=$3=""; $4=checksum; print $0}' | \ + tr -s ' ' > tmp-shim-sha-package-files.txt + + echo "$((++STEP))/ sort by path, sha1; output first from each group > tmp-count-shim-sha-package-files.txt" + sort -k3 -k2,2 -u tmp-shim-sha-package-files.txt | \ + uniq -f 2 -c > tmp-count-shim-sha-package-files.txt + + echo "$((++STEP))/ files with unique sha1 > $SPARK_SHARED_TXT" + grep '^\s\+1 .*' tmp-count-shim-sha-package-files.txt | \ + awk '{$1=""; $3=""; print $0 }' | \ + tr -s ' ' | sed 's/\ /\//g' > "$SPARK_SHARED_TXT" +fi + +if [[ "$CACHE_HIT" == "0" && -n "$DEDUPE_CACHE_SPARK_SHARED_TXT" ]]; then + mkdir -p "$DEDUPE_CACHE_DIR" + cp "$SPARK_SHARED_TXT" "$DEDUPE_CACHE_SPARK_SHARED_TXT" + cp tmp-sha1-files.txt "$DEDUPE_CACHE_SHA1_FILES_TXT" + cp tmp-shim-sha-package-files.txt "$DEDUPE_CACHE_SHIM_SHA_PACKAGE_FILES_TXT" + cp tmp-count-shim-sha-package-files.txt "$DEDUPE_CACHE_COUNT_SHIM_SHA_PACKAGE_FILES_TXT" +fi function retain_single_copy() { set -e @@ -100,9 +154,10 @@ function retain_single_copy() { done >> "$DELETE_DUPLICATES_TXT" || exit 255 } -function copy_unshimmed_from_spark_shared() { +function append_matching_spark_shared_patterns() { set -e - local 
unshimmed_patterns_txt="${UNSHIMMED_COMMON_FROM_SINGLE_SHIM_TXT:-}" + local unshimmed_patterns_txt="$1" + local output_txt="$2" [[ -n "$unshimmed_patterns_txt" ]] || return 0 [[ -f "$unshimmed_patterns_txt" ]] || { @@ -110,23 +165,102 @@ function copy_unshimmed_from_spark_shared() { exit 255 } - : > "$UNSHIMMED_FROM_SPARK_SHARED_COPY_LIST" - while read -r shared_path; do - local rel_path="${shared_path#./parallel-world/spark-shared/}" - local pattern - while read -r pattern; do - [[ -n "$pattern" ]] || continue - [[ "$pattern" =~ ^[[:space:]]*# ]] && continue - # shellcheck disable=SC2053 - if [[ "$rel_path" == $pattern ]]; then - echo "$rel_path" >> "$UNSHIMMED_FROM_SPARK_SHARED_COPY_LIST" - break - fi - done < "$unshimmed_patterns_txt" - done < <(find ./parallel-world/spark-shared -type f) + local shared_dir="./parallel-world/spark-shared" + local pattern + while IFS= read -r pattern; do + [[ -n "$pattern" ]] || continue + [[ "$pattern" =~ ^[[:space:]]*# ]] && continue + case "$pattern" in + *[\*\?\[]*) + find "$shared_dir" -type f -path "$shared_dir/$pattern" | + sed "s|^\./parallel-world/spark-shared/||" >> "$output_txt" + ;; + *) + if [[ -f "$shared_dir/$pattern" ]]; then + echo "$pattern" >> "$output_txt" + fi + ;; + esac + done < "$unshimmed_patterns_txt" +} + +function write_root_safe_spark_shared_classes() { + set -e + local analyzer_script="${UNSHIM_ANALYZER_SCRIPT:-}" + if [[ -z "$analyzer_script" && -n "${UNSHIMMED_COMMON_FROM_SINGLE_SHIM_TXT:-}" ]]; then + analyzer_script="$(dirname "$UNSHIMMED_COMMON_FROM_SINGLE_SHIM_TXT")/scripts/analyze-parallel-world-deps.py" + fi + [[ -n "$analyzer_script" && -f "$analyzer_script" ]] || { + echo >&2 "Cannot locate analyze-parallel-world-deps.py for default unshim analysis" + exit 255 + } + + echo "$((++STEP))/ analyzing spark-shared dependency paths > $ROOT_SAFE_SPARK_SHARED_TXT" + python3 "$analyzer_script" ./parallel-world \ + --write-safe-paths "$ROOT_SAFE_SPARK_SHARED_TXT" +} +function 
write_default_unshimmed_spark_shared_classes() { + set -e + echo "$((++STEP))/ selecting all bitwise-identical spark-shared classes > $DEFAULT_UNSHIMMED_SPARK_SHARED_TXT" + sed -E "s|^/spark[^/]*/||" "$SPARK_SHARED_TXT" | \ + grep '\.class$' | sort -u > "$DEFAULT_UNSHIMMED_SPARK_SHARED_TXT" +} + +function keep_in_spark_shared() { + set -e + local class_file="$1" + local keep_patterns_txt="${KEEP_IN_SPARK_SHARED_TXT:-}" + [[ -n "$keep_patterns_txt" ]] || return 1 + [[ -f "$keep_patterns_txt" ]] || { + echo >&2 "Keep-in-spark-shared list does not exist: $keep_patterns_txt" + exit 255 + } + + local pattern + while IFS= read -r pattern; do + [[ -n "$pattern" ]] || continue + [[ "$pattern" =~ ^[[:space:]]*# ]] && continue + # shellcheck disable=SC2053 + if [[ "$class_file" == $pattern ]]; then + return 0 + fi + done < "$keep_patterns_txt" + return 1 +} + +function filter_keep_in_spark_shared() { + set -e + local input_txt="$1" + local output_txt="$2" + local class_file + : > "$output_txt" + while IFS= read -r class_file; do + [[ -n "$class_file" ]] || continue + if keep_in_spark_shared "$class_file"; then + continue + fi + echo "$class_file" + done < "$input_txt" > "$output_txt.tmp" + mv "$output_txt.tmp" "$output_txt" +} + +function copy_unshimmed_from_spark_shared() { + set -e + local raw_copy_list="$UNSHIMMED_FROM_SPARK_SHARED_COPY_LIST.raw" + local sorted_copy_list="$UNSHIMMED_FROM_SPARK_SHARED_COPY_LIST.sorted" + + : > "$raw_copy_list" + write_root_safe_spark_shared_classes + write_default_unshimmed_spark_shared_classes + cat "$DEFAULT_UNSHIMMED_SPARK_SHARED_TXT" >> "$raw_copy_list" + append_matching_spark_shared_patterns \ + "${UNSHIMMED_COMMON_FROM_SINGLE_SHIM_TXT:-}" "$raw_copy_list" + + sort -u "$raw_copy_list" > "$sorted_copy_list" + filter_keep_in_spark_shared "$sorted_copy_list" "$UNSHIMMED_FROM_SPARK_SHARED_COPY_LIST" if [[ -s "$UNSHIMMED_FROM_SPARK_SHARED_COPY_LIST" ]]; then - echo "Promoting root-layout files from spark-shared via $unshimmed_patterns_txt" 
+ echo "Promoting root-layout files from spark-shared by default" rsync --files-from="$UNSHIMMED_FROM_SPARK_SHARED_COPY_LIST" \ ./parallel-world/spark-shared ./parallel-world fi @@ -141,9 +275,23 @@ rm -rf "$SPARK_SHARED_DIR" mkdir -p "$SPARK_SHARED_DIR" echo "$((++STEP))/ retaining a single copy of spark-shared classes" -while read -r spark_common_class; do - retain_single_copy "$spark_common_class" -done < "$SPARK_SHARED_TXT" +awk -F/ " + NF >= 3 { + shim = \$2 + package_class = \$0 + sub(\"^/spark[34][^/]*/\", \"\", package_class) + print package_class >> (\"from-\" shim \"-to-spark-shared.txt\") + } +" "$SPARK_SHARED_TXT" +for pw in ./parallel-world/spark[34]* ; do + awk -v pw="$pw" " + { + package_class = \$0 + sub(\"^/spark[34][^/]*/\", \"\", package_class) + print pw \"/\" package_class + } + " "$SPARK_SHARED_TXT" +done >> "$DELETE_DUPLICATES_TXT" echo "$((++STEP))/ rsyncing common classes to $SPARK_SHARED_DIR" for copy_list in from-spark[34]*-to-spark-shared.txt; do @@ -157,7 +305,7 @@ done mv "$SPARK_SHARED_DIR" parallel-world/ -echo "$((++STEP))/ promoting allowlisted spark-shared files to root layout" +echo "$((++STEP))/ promoting default spark-shared files to root layout" copy_unshimmed_from_spark_shared # Verify that all class files in the conventional jar location are bitwise @@ -184,11 +332,16 @@ copy_unshimmed_from_spark_shared # Determine the list of unshimmed class files UNSHIMMED_LIST_TXT=unshimmed-result.txt -echo "$((++STEP))/ creating sorted list of unshimmed classes > $UNSHIMMED_LIST_TXT" -find ./parallel-world -name '*.class' -not -path './parallel-world/spark[34-]*' | \ +echo "$((++STEP))/ creating sorted list of root-layout unshimmed classes > $UNSHIMMED_LIST_TXT" +find ./parallel-world -name '*.class' \ + -not -path './parallel-world/spark[34-]*' \ + -not -path './parallel-world/spark-shared/*' | \ cut -d/ -f 3- | sort > "$UNSHIMMED_LIST_TXT" -function verify_same_sha_for_unshimmed() { +echo "$((++STEP))/ creating sorted list of 
spark-shared classes > $SPARK_SHARED_CLASSES_TXT" +sed -E "s|^/spark[^/]*/||" "$SPARK_SHARED_TXT" | sort -u > "$SPARK_SHARED_CLASSES_TXT" + +function unshimmed_class_needs_shared_identity() { set -e class_file="$1" @@ -196,7 +349,7 @@ function verify_same_sha_for_unshimmed() { # including the ones that are unshimmed. Instead of expensively recomputing # sha1 look up if there is an entry with the unshimmed class as a suffix - class_file_quoted=$(printf '%q' "$class_file") + class_file_quoted=$(printf "%q" "$class_file") # TODO currently RapidsShuffleManager is "removed" from /spark* by construction in # dist pom.xml via ant. We could delegate this logic to this script # and make both simmpler @@ -211,34 +364,72 @@ function verify_same_sha_for_unshimmed() { # the class provides concrete implementations for ALL getReader variants, # so the JVM resolves the correct one at runtime regardless of which # ShuffleManager version the class was compiled against. - if [[ ! "$class_file_quoted" =~ com/nvidia/spark/rapids/spark[34].*/.*ShuffleManager.class && \ - "$class_file_quoted" != "com/nvidia/spark/ParquetCachedBatchSerializer.class" && \ - ! "$class_file_quoted" =~ org/apache/spark/sql/rapids/ProxyRapidsShuffleInternalManagerBase ]]; then - if ! grep -q "/spark.\+/$class_file_quoted" "$SPARK_SHARED_TXT"; then - echo >&2 "$class_file is not bitwise-identical across shims" - exit 255 - fi + # GpuShuffleDependency has identical JVM bytecode and descriptors between + # Spark 3.5 and 4.1. Only ScalaSignature metadata differs after compiling + # the same source against different Spark dependency jars. WindowInPandasExecTypeShim + # has no methods in the class shell; its companion carries the behavior. + # CloseableColumnBatchIterator has identical descriptors and code; Scala 2.13 only + # renames generic Signature-attribute type variables across the Spark 3.5/4.1 compiles. 
+ # GpuReadCSVFileFormat and GpuReadJsonFileFormat have identical descriptors and + # executable javap output; only ScalaSignature metadata differs across Spark deps. + # PythonMapInArrowExecShims and PythonArgumentUtils class shells have identical + # executable bytecode; only source-file metadata differs across shim source names. + # GpuUnionExecShim and RapidsErrorUtils class shells have identical executable + # bytecode; only ScalaSignature metadata differs. + # GpuStringTrim* differs after Spark 4.1 because String2TrimExpression adds + # collation/context-independent foldability methods. The case-class fields, + # product surface, and Spark 3.5-callable methods remain compatible; Spark 3.x + # does not invoke the added methods. + # GpuAtomicCreateTableAsSelectExec companion has identical executable bytecode; + # only line-number debug metadata differs across shim sources. + if [[ "$class_file_quoted" =~ com/nvidia/spark/rapids/spark[34].*/.*ShuffleManager.class || \ + "$class_file_quoted" == "com/nvidia/spark/ParquetCachedBatchSerializer.class" || \ + "$class_file_quoted" =~ org/apache/spark/sql/rapids/ProxyRapidsShuffleInternalManagerBase || \ + "$class_file_quoted" == "org/apache/spark/sql/rapids/GpuShuffleDependency.class" || \ + "$class_file_quoted" == "com/nvidia/spark/rapids/parquet/CloseableColumnBatchIterator.class" || \ + "$class_file_quoted" == "com/nvidia/spark/rapids/GpuReadCSVFileFormat.class" || \ + "$class_file_quoted" == "org/apache/spark/sql/catalyst/json/rapids/GpuReadJsonFileFormat.class" || \ + "$class_file_quoted" == "com/nvidia/spark/rapids/shims/PythonMapInArrowExecShims.class" || \ + "$class_file_quoted" == "org/apache/spark/sql/rapids/execution/python/shims/PythonArgumentUtils.class" || \ + "$class_file_quoted" == "com/nvidia/spark/rapids/shims/GpuUnionExecShim.class" || \ + "$class_file_quoted" == "org/apache/spark/sql/rapids/GpuStringTrim.class" || \ + "$class_file_quoted" == "org/apache/spark/sql/rapids/GpuStringTrimLeft.class" || \ 
+ "$class_file_quoted" == "org/apache/spark/sql/rapids/GpuStringTrimRight.class" || \ + "$class_file" == "org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec$.class" || \ + "$class_file_quoted" == "org/apache/spark/sql/rapids/shims/RapidsErrorUtils.class" || \ + "$class_file_quoted" == "org/apache/spark/sql/rapids/execution/python/shims/WindowInPandasExecTypeShim.class" ]]; then + return 1 fi + return 0 } -echo "$((++STEP))/ verifying unshimmed classes have unique sha1 across shims" +echo "$((++STEP))/ filtering unshimmed classes that require shared identity > $UNSHIMMED_NEED_SHARED_TXT" while read -r unshimmed_class; do - verify_same_sha_for_unshimmed "$unshimmed_class" -done < "$UNSHIMMED_LIST_TXT" + if unshimmed_class_needs_shared_identity "$unshimmed_class"; then + echo "$unshimmed_class" + fi +done < "$UNSHIMMED_LIST_TXT" | sort -u > "$UNSHIMMED_NEED_SHARED_TXT" + +echo "$((++STEP))/ verifying unshimmed classes have unique sha1 across shims" +comm -23 "$UNSHIMMED_NEED_SHARED_TXT" "$SPARK_SHARED_CLASSES_TXT" > "$UNSHIMMED_MISSING_SHARED_TXT" +if [[ -s "$UNSHIMMED_MISSING_SHARED_TXT" ]]; then + read -r missing_unshimmed_class < "$UNSHIMMED_MISSING_SHARED_TXT" + echo >&2 "$missing_unshimmed_class is not bitwise-identical across shims" + exit 255 +fi # Remove unshimmed classes from parallel worlds # TODO rework with low priority, only a few classes. 
echo "$((++STEP))/ removing duplicates of unshimmed classes" - -while read -r unshimmed_class; do +{ + sed "s|^|./parallel-world/spark-shared/|" "$UNSHIMMED_LIST_TXT" for pw in ./parallel-world/spark[34-]* ; do - unshimmed_path="$pw/$unshimmed_class" - [[ -f "$unshimmed_path" ]] && echo "$unshimmed_path" || true - done >> "$DELETE_DUPLICATES_TXT" -done < "$UNSHIMMED_LIST_TXT" + awk -v pw="$pw" "{ print pw \"/\" \$0 }" "$UNSHIMMED_LIST_TXT" + done +} >> "$DELETE_DUPLICATES_TXT" echo "$((++STEP))/ deleting all class files listed in $DELETE_DUPLICATES_TXT" -< "$DELETE_DUPLICATES_TXT" sort -u | xargs rm +< "$DELETE_DUPLICATES_TXT" sort -u | xargs rm -f end_time=$(date +%s) echo "binary-dedupe completed in $((end_time - start_time)) seconds" diff --git a/dist/scripts/build-unshim-parallel-world.py b/dist/scripts/build-unshim-parallel-world.py new file mode 100644 index 00000000000..9f41be63558 --- /dev/null +++ b/dist/scripts/build-unshim-parallel-world.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Build dist/target/parallel-world directly for repeated unshim analysis. + +This mirrors the analyzer-relevant part of dist/maven-antrun/build-parallel-worlds.xml +without starting a final Maven dist generate-resources invocation. It assumes buildall +has already built the per-shim sql-plugin-api and aggregator jars under target/sparkXYZ. 
+""" + +import argparse +import fnmatch +import hashlib +import os +from pathlib import Path +import shutil +import subprocess +import sys +import zipfile + + +ARTIFACTS = ("sql-plugin-api", "aggregator") + + +def read_patterns(path): + with path.open() as fh: + return [ + line.strip() + for line in fh + if line.strip() and not line.lstrip().startswith("#") + ] + + +def has_fnmatch_magic(pattern): + return any(ch in pattern for ch in "*?[") + + +def matching_members(namelist, patterns): + names_by_entry = {} + for name in namelist: + names_by_entry.setdefault(name, []).append(name) + + matches = [] + for pattern in patterns: + if has_fnmatch_magic(pattern): + matches.extend(fnmatch.filter(namelist, pattern)) + else: + matches.extend(names_by_entry.get(pattern, [])) + return matches + + +def safe_extract(zip_handle, destination, members=None): + destination = destination.resolve() + for member in members if members is not None else zip_handle.namelist(): + target = (destination / member).resolve() + if not str(target).startswith(str(destination) + os.sep): + raise RuntimeError("refusing to extract outside destination: %s" % member) + zip_handle.extract(member, destination) + + +def clean_output(target_dir): + for dirname in ("parallel-world", "deps", "extra-resources"): + path = target_dir / dirname + if path.exists(): + shutil.rmtree(path) + path.mkdir(parents=True, exist_ok=True) + for jar_path in target_dir.glob("*.jar"): + jar_path.unlink() + + +def artifact_jar(base_dir, artifact, scala_binary_version, project_version, buildver): + artifact_id = "rapids-4-spark-%s_%s" % (artifact, scala_binary_version) + classifier = "spark%s" % buildver + jar_name = "%s-%s-%s.jar" % (artifact_id, project_version, classifier) + jar_path = base_dir / artifact / "target" / classifier / jar_name + if not jar_path.is_file(): + raise FileNotFoundError( + "expected built %s jar missing: %s" % (artifact, jar_path)) + return jar_path + + +def jar_signature(jar_path): + stat = 
jar_path.stat() + return "\n".join(( + "path=%s" % jar_path, + "size=%s" % stat.st_size, + "mtime_ns=%s" % stat.st_mtime_ns, + "", + )) + + +def dedupe_cache_key(base_dir, scala_binary_version, project_version, buildvers): + parts = [] + for buildver in sorted(buildvers, reverse=True): + for artifact in ARTIFACTS: + jar_path = artifact_jar( + base_dir, artifact, scala_binary_version, project_version, buildver) + parts.extend(( + "buildver=%s" % buildver, + "artifact=%s" % artifact, + jar_signature(jar_path), + )) + return hashlib.sha1("\n".join(parts).encode("utf-8")).hexdigest() + + +def ensure_extracted_cache(jar_path, cache_dir): + contents_dir = cache_dir / "contents" + marker = cache_dir / ".source" + signature = jar_signature(jar_path) + + if marker.is_file() and marker.read_text() == signature: + return contents_dir + + if cache_dir.exists(): + shutil.rmtree(cache_dir) + contents_dir.mkdir(parents=True, exist_ok=True) + with zipfile.ZipFile(jar_path) as zip_handle: + safe_extract(zip_handle, contents_dir) + marker.write_text(signature) + return contents_dir + + +def link_or_copy(src, dst): + dst.parent.mkdir(parents=True, exist_ok=True) + if dst.exists() or dst.is_symlink(): + dst.unlink() + try: + os.link(src, dst) + except OSError: + shutil.copy2(src, dst) + + +def link_tree_contents(src_dir, dst_dir): + for root, _, files in os.walk(src_dir): + root_path = Path(root) + rel_root = root_path.relative_to(src_dir) + target_root = dst_dir / rel_root + target_root.mkdir(parents=True, exist_ok=True) + for name in files: + link_or_copy(root_path / name, target_root / name) + + +def link_members(contents_dir, destination, members): + for member in members: + if member.endswith("/"): + continue + src = contents_dir / member + if src.is_file(): + link_or_copy(src, destination / member) + + +def copy_and_extract_jars( + base_dir, + target_dir, + scala_binary_version, + project_version, + buildvers, + from_single_shim, + from_each): + parallel_world = target_dir / 
"parallel-world"
+    cache_root = target_dir / "unshim-parallel-world-cache"
+    sorted_buildvers = sorted(buildvers, reverse=True)
+    root_buildver = sorted_buildvers[0]
+
+    for buildver in sorted_buildvers:
+        classifier = "spark%s" % buildver
+        for artifact in ARTIFACTS:
+            jar_path = artifact_jar(
+                base_dir, artifact, scala_binary_version, project_version, buildver)
+            contents_dir = ensure_extracted_cache(
+                jar_path, cache_root / classifier / artifact)
+            with zipfile.ZipFile(jar_path) as zip_handle:
+                namelist = zip_handle.namelist()
+
+            link_tree_contents(contents_dir, parallel_world / classifier)
+            if buildver == root_buildver and artifact == "sql-plugin-api":
+                link_tree_contents(contents_dir, parallel_world)
+
+            patterns = from_each
+            if buildver == root_buildver:
+                patterns = from_single_shim + from_each
+            members = matching_members(namelist, patterns)
+            link_members(contents_dir, parallel_world, members)
+
+
+def run_checked(command, cwd, env=None):
+    """Run command in cwd with the optional env; raise on a non-zero exit."""
+    subprocess.run(command, cwd=str(cwd), env=env, check=True)
+
+
+def remove_allowlisted_from_spark_shared(parallel_world, from_single_shim):
+    """Delete files under spark-shared that match the single-shim allowlist.
+
+    Literal entries are unlinked by path. Wildcard entries are matched against
+    one snapshot of the tree instead of re-walking it for every pattern, so no
+    file is deleted while the directory walk is still in progress.
+    """
+    shared_dir = parallel_world / "spark-shared"
+    if not shared_dir.is_dir():
+        return
+
+    literals, globs = [], []
+    for pattern in from_single_shim:
+        (globs if has_fnmatch_magic(pattern) else literals).append(pattern)
+    for path in (shared_dir / pattern for pattern in literals):
+        if path.is_file():
+            path.unlink()
+    if not globs:
+        return
+    for path in [item for item in shared_dir.rglob("*") if item.is_file()]:
+        relative = path.relative_to(shared_dir).as_posix()
+        if any(fnmatch.fnmatch(relative, pattern) for pattern in globs):
+            path.unlink()
+
+
+def main():
+    """Parse the CLI, assemble the parallel-world tree, then check and dedupe it."""
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("--mvn-base-dir", required=True,
+                        help="Maven build root containing module target directories")
+    parser.add_argument("--source-dir", required=True,
+                        help="Top-level spark-rapids source directory")
+    parser.add_argument("--project-version", required=True)
+    parser.add_argument("--scala-binary-version", required=True)
+    parser.add_argument("--buildvers", required=True,
+                        help="Comma-separated Spark build versions, for example 350,411")
+    parser.add_argument("--ignore-shim-revisions-check", action="store_true",
+                        help="Continue when per-shim build metadata revisions differ")
+    args = parser.parse_args()
+
+    base_dir = Path(args.mvn_base_dir).resolve()
+    dist_dir = Path(args.source_dir).resolve() / "dist"
+    scripts_dir = dist_dir / "scripts"
+    single_shim_txt = dist_dir / "unshimmed-common-from-single-shim.txt"
+    target_dir = base_dir / "dist" / "target"
+    parallel_world = target_dir / "parallel-world"
+    buildvers = [item.strip() for item in args.buildvers.split(",") if item.strip()]
+    if not buildvers:
+        raise RuntimeError("no build versions were supplied")
+
+    from_single_shim = read_patterns(single_shim_txt)
+    from_each = read_patterns(dist_dir / "unshimmed-from-each-spark3xx.txt")
+
+    print("Direct unshim parallel-world assembly for Spark versions: %s" %
+          ", ".join(buildvers), flush=True)
+    clean_output(target_dir)
+    copy_and_extract_jars(
+        base_dir, target_dir, args.scala_binary_version, args.project_version,
+        buildvers, from_single_shim, from_each)
+
+    # Revision mismatches are fatal unless the caller opted to ignore them.
+    revision_check = subprocess.run(
+        [str(scripts_dir / "check-shims-revisions.sh"), ",".join(buildvers)],
+        cwd=str(target_dir), check=False)
+    if revision_check.returncode != 0:
+        if not args.ignore_shim_revisions_check:
+            revision_check.check_returncode()
+        print("Ignoring shim revision check failure for direct unshim parallel-world assembly",
+              flush=True)
+
+    cache_key = dedupe_cache_key(
+        base_dir, args.scala_binary_version, args.project_version, buildvers)
+    dedupe_env = os.environ.copy()
+    dedupe_env.update({
+        "UNSHIM_FAST": "1",
+        "UNSHIM_DEDUPE_CACHE_DIR": str(target_dir / "unshim-dedupe-cache" / cache_key),
+        "UNSHIMMED_COMMON_FROM_SINGLE_SHIM_TXT": str(single_shim_txt),
+        "KEEP_IN_SPARK_SHARED_TXT": str(dist_dir / "keep-in-spark-shared.txt"),
+        "UNSHIM_ANALYZER_SCRIPT": str(scripts_dir / "analyze-parallel-world-deps.py"),
+    })
+    run_checked([str(scripts_dir / "binary-dedupe.sh")], cwd=target_dir, env=dedupe_env)
+    # Drop allowlisted entries that are still present under spark-shared.
+    remove_allowlisted_from_spark_shared(parallel_world, from_single_shim)
+
+    print("Direct unshim parallel-world output: %s" % parallel_world, flush=True)
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/dist/unshimmed-common-from-single-shim.txt b/dist/unshimmed-common-from-single-shim.txt
index 5802807a250..a3dc3ed0214 100644
--- a/dist/unshimmed-common-from-single-shim.txt
+++ b/dist/unshimmed-common-from-single-shim.txt
@@ -1,53 +1,9 @@
+# Files that must be promoted to the root layout from one representative shim
+# but are not selected by default class promotion. Common class files are
+# unshimmed by default when binary-dedupe proves they are bitwise-identical
+# across shims.
 META-INF/DEPENDENCIES
 META-INF/LICENSE
 META-INF/NOTICE
-com/nvidia/spark/rapids/ExplainPlan.class
-com/nvidia/spark/rapids/ExplainPlan$.class
-com/nvidia/spark/rapids/ExplainPlanBase.class
-com/nvidia/spark/rapids/Optimizer.class
-com/nvidia/spark/rapids/optimizer/SQLOptimizerPlugin*
-com/nvidia/spark/rapids/ShimLoaderTemp*
-com/nvidia/spark/rapids/SparkShims*
-com/nvidia/spark/rapids/fileio/iceberg/IcebergInputFile.class
-com/nvidia/spark/rapids/fileio/iceberg/IcebergInputStream.class
-com/nvidia/spark/rapids/fileio/iceberg/IcebergOutputFile.class
-com/nvidia/spark/rapids/fileio/iceberg/IcebergOutputStream.class
-com/nvidia/spark/rapids/iceberg/GpuInternalRow.class
-com/nvidia/spark/rapids/iceberg/GpuInternalRowBase.class
-com/nvidia/spark/rapids/iceberg/data/GpuDeleteFilter2.class
-com/nvidia/spark/rapids/iceberg/package.class
-com/nvidia/spark/rapids/iceberg/package$.class
-com/nvidia/spark/rapids/iceberg/parquet/FileSchemaAccessors.class
-com/nvidia/spark/rapids/iceberg/parquet/GpuIcebergParquetReader$.class
-com/nvidia/spark/rapids/iceberg/parquet/SingleFile.class
-com/nvidia/spark/rapids/iceberg/parquet/SingleFile$.class
-com/nvidia/spark/rapids/iceberg/parquet/ThreadConf.class -com/nvidia/spark/rapids/iceberg/spark/GpuSparkReadOptions.class -com/nvidia/spark/rapids/iceberg/spark/GpuSparkReadOptions$.class -com/nvidia/spark/rapids/iceberg/spark/GpuSparkSQLProperties.class -com/nvidia/spark/rapids/iceberg/spark/GpuSparkSQLProperties$.class -com/nvidia/spark/rapids/iceberg/spark/GpuSparkUtil.class -com/nvidia/spark/rapids/iceberg/spark/GpuSparkUtil$.class -com/nvidia/spark/rapids/iceberg/spark/RapidsSparkCatalog.class -com/nvidia/spark/rapids/iceberg/spark/RapidsSparkSessionCatalog.class -com/nvidia/spark/rapids/iceberg/spark/source/RapidsSparkTable.class -org/apache/iceberg/aws/s3/IcebergS3InputFileAccess.class -org/apache/iceberg/data/GpuFileHelpers.class -org/apache/iceberg/io/GpuClusteredWriterBridge.class -org/apache/iceberg/io/GpuFanoutWriterBridge.class -org/apache/iceberg/io/GpuPositionDeleteFileWriter$.class -org/apache/iceberg/parquet/GpuParquetIOAccess.class -org/apache/iceberg/spark/GpuTypeToSparkType.class -org/apache/iceberg/spark/GpuTypeToSparkType$.class -org/apache/iceberg/spark/GpuSparkReadConf.class -org/apache/iceberg/spark/GpuSparkReadConfAccess.class -org/apache/iceberg/spark/package.class -org/apache/iceberg/spark/package$.class -org/apache/iceberg/spark/source/GpuBaseReader.class -org/apache/iceberg/spark/source/GpuSparkPlanningUtil.class -org/apache/iceberg/spark/source/GpuSparkScanAccess.class -org/apache/iceberg/spark/source/GpuSparkWriteAccess.class -org/apache/iceberg/spark/source/GpuStructInternalRow.class -org/apache/spark/sql/rapids/AdaptiveSparkPlanHelperShim* -org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback* +rapids4spark-private-version-info.properties rapids/*.py diff --git a/dist/unshimmed-from-each-spark3xx.txt b/dist/unshimmed-from-each-spark3xx.txt index 918a572722b..1f96d9d0781 100644 --- a/dist/unshimmed-from-each-spark3xx.txt +++ b/dist/unshimmed-from-each-spark3xx.txt @@ -9,4 +9,6 @@ com/nvidia/spark/rapids/delta/DeltaProbe.class 
com/nvidia/spark/rapids/delta/DeltaProvider.class com/nvidia/spark/rapids/delta/DeltaProvider$.class com/nvidia/spark/rapids/PlanShims* +org/apache/spark/sql/rapids/GpuShuffleDependency.class +org/apache/spark/sql/rapids/execution/python/shims/WindowInPandasExecTypeShim.class spark-*-info.properties diff --git a/docs/dev/shimplify.md b/docs/dev/shimplify.md index 4fefd824c7c..dd1f83f871d 100644 --- a/docs/dev/shimplify.md +++ b/docs/dev/shimplify.md @@ -266,4 +266,4 @@ See [CPD user doc][7] for more details about the options you can pass inside `cp [4]: https://jsonlines.org/ [5]: https://spark.apache.org/versioning-policy.html [6]: https://plugins.jetbrains.com/plugin/16429-idea-resolve-symlinks -[7]: https://docs.pmd-code.org/latest/pmd_userdocs_cpd.html +[7]: https://pmd.github.io/pmd/pmd_userdocs_cpd.html diff --git a/docs/dev/shims.md b/docs/dev/shims.md index 38a368df73b..f68b5e61e81 100644 --- a/docs/dev/shims.md +++ b/docs/dev/shims.md @@ -22,6 +22,100 @@ class as a tight entry point for interacting with the host Spark runtime. In the following we provide recipes for typical scenarios addressed by the Shim layer. +## One-way Shim Module Boundary + +Shim source can be split between three layers when the implementation does not have to live +in the same module as the Spark-version-specific API reference. + +1. `sql-plugin-api` contains the narrow shared types that both sides can see. These types must + not depend on `sql-plugin` implementation classes. +2. `sql-plugin-shims` depends on `sql-plugin-api` and Spark. It may reference Spark classes whose + source or binary shape varies by build version, but it must not reference implementation types + such as `GpuOverrides`, `RapidsMeta`, `ExprRule`, `ExecRule`, or GPU meta classes. +3. `sql-plugin` depends on `sql-plugin-shims`. It turns API-level shim descriptors into concrete + plugin rules and owns the RAPIDS metadata factories. 
+ +For replacement rules, use descriptor objects when the shim only needs to identify a Spark class +and provide stable rule metadata. For example, `ShimDataWritingCommandRule`, +`ShimRunnableCommandRule`, and `ShimExecRule` live in `sql-plugin-api`; versioned objects in +`sql-plugin-shims` instantiate those descriptors with Spark-specific class tags; `sql-plugin` +then calls the corresponding `GpuOverrides.*FromShim` method and supplies the actual `RapidsMeta` +factory. This keeps the call direction one-way: shared plugin code can consume shim descriptors, +while shim code cannot call back into shared plugin implementation. + +Classes whose `spark-rapids-shim-json-lines` entries cover all build versions can be unshimmed +into a common source root when there is no special-version sibling and the source is truly +compatible across the supported Spark APIs. When a file has Databricks-specific, Spark 4.1-specific, +or otherwise divergent siblings, keep the version-specific source and move only the API-safe part +behind the one-way boundary. + +## Reducing Parallel-World Classes + +The long-term goal is to maximize bytecode in the conventional jar layout and shrink the amount +of code that must be loaded through the parallel-world mechanism. A class can move from +`spark-shared` to the conventional layout only when it has no static dependency path to +Spark-version-specific bytecode. The dependency path matters transitively: a `spark-shared` class +that calls another `spark-shared` class that eventually calls a `sparkXYZ` class is not root-safe. + +`dist/unshimmed-common-from-single-shim.txt` names classes and resources that are allowed to be +stored in the conventional layout after the dist jar is assembled. During `binary-dedupe.sh`, files +from that allowlist may be promoted out of `spark-shared` into the root layout before the bitwise +identity check runs. 
This is important for profiles where the highest Spark build contributes only a +stub module, while a lower Spark build contributes the real implementation. For example, root-safe +Iceberg helpers can still be placed in the conventional layout even when the Spark 4.1 shim uses the +Iceberg stub. + +Use a small bootstrap allowlist for classes that are allowed to refer to packages generated with +`$_spark.version.classifier_`, such as `com.nvidia.spark.rapids.spark330.RapidsShuffleManager`. +Ordinary shared implementation classes should not have direct static dependencies on those +classifier packages. They should instead call through stable contracts in `sql-plugin-api` or +through descriptor objects in `sql-plugin-shims`. + +For an inventory of a released artifact, download the complete dist jar from Maven Central and run +the dependency analyzer directly against the jar: + +```bash +VERSION=26.04.2 +curl -fL -o /tmp/rapids-4-spark_2.12-${VERSION}-cuda12.jar \ + https://repo.maven.apache.org/maven2/com/nvidia/rapids-4-spark_2.12/${VERSION}/rapids-4-spark_2.12-${VERSION}-cuda12.jar + +python3 dist/scripts/analyze-parallel-world-deps.py \ + /tmp/rapids-4-spark_2.12-${VERSION}-cuda12.jar \ + --show-topo +``` + +Run the same command for the Scala 2.13 artifact when checking Spark 4.x coverage. Internal +snapshot artifacts can be analyzed the same way after downloading a timestamped dist jar from the +configured artifact repository; keep repository credentials in local Maven or environment +configuration rather than embedding them in scripts or docs. 
+ +For local branch validation, build representative two-shim dist jars that span the widest +differences in each Scala line: + +```bash +./build/buildall --profile=350,411 --scala213 --module=dist +python3 dist/scripts/analyze-parallel-world-deps.py \ + scala2.13/dist/target/parallel-world \ + --show-topo + +./build/buildall --profile=330,358 --module=dist +python3 dist/scripts/analyze-parallel-world-deps.py \ + dist/target/parallel-world \ + --show-topo +``` + +The analyzer reports: + +1. direct classifier-package dependencies, which should remain limited to bootstrap/facade code; +2. root or `spark-shared` classes with transitive paths to version-specific classes; +3. root-safe `spark-shared` strongly connected components in dependency-first order. + +Use `--format=json` when comparing safe components across artifacts or build outputs. JSON output +keeps counts exact and bounds example sections with `--limit`. +Shortest paths explain why a class is blocked and usually identify the adapter boundary to cut. +Strongly connected components, not shortest paths, provide the migration ordering because classes in +the same component have to move or be refactored together. + ## Method signature discrepancies It's among the easiest issues to resolve. We define a method in SparkShims