Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions skills/docs/dev/TESTING.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Testing

All commands below assume your working directory is `skills/`.

## Setup

Set up a local dev environment:

```bash
python -m venv .venv
source .venv/bin/activate
pip install -e ".[dev]"
```

## Fast Tests

Run the fast tests:

```bash
pytest -m "not slow"
```

These are generally lightweight skill validation tests, such as verifying skill frontmatter.

## Integration Tests

Run the integration tests:

```bash
pytest -m slow -s
```

These tests deterministically fill in the template project from `skills/udf-gen-test/templates/` with fixture implementations, then actually compile and run Spark tests and benchmark scripts locally.

They therefore require a JDK, Maven with access to a Maven repository, a GPU environment, and (for the `cuda` tests) CMake and the CUDA toolkit.
38 changes: 38 additions & 0 deletions skills/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Build with setuptools (61.0 or newer) through its build_meta backend.
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

# Distribution metadata.
[project]
name = "aether-agent"
version = "0.1.0"
description = "Convert Spark UDFs into GPU implementations"
authors = [
{name = "Rishi Chandra", email = "rishic@nvidia.com"}
]
readme = "README.md"
requires-python = ">=3.10"
classifiers = [
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
]

# Exactly pinned development tooling, installed via the `dev` extra
# (e.g. `pip install -e ".[dev]"`).
[project.optional-dependencies]
dev = [
"pytest==8.4.1",
"PyYAML==6.0.3",
"isort==6.0.1",
"black==25.1.0",
"ruff==0.12.8",
]

# No Python packages are listed for setuptools to include in the distribution.
[tool.setuptools]
packages = []

[tool.pyright]
typeCheckingMode = "standard"

# Declares the `slow` marker so tests can be selected with
# `pytest -m slow` / `pytest -m "not slow"`.
[tool.pytest.ini_options]
markers = ["slow: integration tests"]
6 changes: 6 additions & 0 deletions skills/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Tests for the skills templates.
"""
224 changes: 224 additions & 0 deletions skills/tests/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Source fixtures for the JVM template integration tests.
"""

from pathlib import Path


def _read_resource(name: str) -> str:
    """Return the UTF-8 text of a fixture file.

    Args:
        name: File name inside the ``resources`` directory that sits next to
            this module.

    Returns:
        The full contents of the file decoded as UTF-8.
    """
    resources_dir = Path(__file__).parent / "resources"
    resource_path = resources_dir / name
    return resource_path.read_text(encoding="utf-8")


# ---------------------------------------------------------------------------
# UDF source code
# ---------------------------------------------------------------------------

# Class names of the fixture UDFs. They serve both as the stems of the
# resource file names read below and as the `com.udf.<name>` class references
# substituted into the snippet templates later in this module.
CPU_UDF_NAME = "IntegerMultiplyBy2UDF"
RAPIDS_UDF_NAME = "IntegerMultiplyBy2RapidsUDF"
NATIVE_UDF_NAME = "IntegerMultiplyBy2NativeRapidsUDF"

# Fixture source texts. Each is read from the resources/ directory when this
# module is imported, so a missing resource file fails at import time.
SCALA_UDF_SOURCE = _read_resource(f"{CPU_UDF_NAME}.scala")
JAVA_UDF_SOURCE = _read_resource(f"{CPU_UDF_NAME}.java")
SCALA_RAPIDS_UDF_SOURCE = _read_resource(f"{RAPIDS_UDF_NAME}.scala")
JAVA_RAPIDS_UDF_SOURCE = _read_resource(f"{RAPIDS_UDF_NAME}.java")
NATIVE_RAPIDS_UDF_SOURCE = _read_resource(f"{NATIVE_UDF_NAME}.java")
SQL_SOURCE = _read_resource("integer_multiply_by_2.sql")
JNI_SOURCE = _read_resource("IntegerMultiplyBy2Jni.cpp")
CUDA_SOURCE = _read_resource("integer_multiply_by_2.cu")
HEADER_SOURCE = _read_resource("integer_multiply_by_2.hpp")

# ---------------------------------------------------------------------------
# Unit test methods
# ---------------------------------------------------------------------------

# Scala `createTestData` method: builds a four-row (id, value) DataFrame whose
# nullable `value` column covers a positive, a zero, a negative and a null
# input. The `\` after the opening quotes keeps the snippet from starting with
# a blank line.
CREATE_TEST_DATA = """\
def createTestData(spark: SparkSession): DataFrame = {
val schema = StructType(Seq(
StructField("id", IntegerType, nullable = false),
StructField("value", IntegerType, nullable = true)
))
val testData = Seq(
Row(1, 123),
Row(2, 0),
Row(3, -5),
Row(4, null)
)
spark.createDataFrame(spark.sparkContext.parallelize(testData), schema)
}"""

# Scala `executeUDF` method: exposes the test DataFrame as the temp view
# `test_table` and applies the UDF registered under `udfName` to `value`,
# returning the output in a `result` column.
EXECUTE_UDF = """\
def executeUDF(spark: SparkSession, udfName: String, testDF: DataFrame): DataFrame = {
testDF.createOrReplaceTempView("test_table")
spark.sql(s"SELECT *, $udfName(value) AS result FROM test_table")
}"""

# Scala `assertUDFResults` method: sorts the collected rows by `id` and checks
# each `result` against the doubled input from CREATE_TEST_DATA
# (123 -> 246, 0 -> 0, -5 -> -10, null -> null).
ASSERT_UDF_RESULTS = """\
def assertUDFResults(resultDF: DataFrame, testDF: DataFrame): Unit = {
val results = resultDF.collect().sortBy(_.getAs[Int]("id"))
assert(results(0).getAs[Int]("result") === 246)
assert(results(1).getAs[Int]("result") === 0)
assert(results(2).getAs[Int]("result") === -10)
assert(results(3).isNullAt(results(3).fieldIndex("result")))
}"""

# `spark.udf.register` call templates. `{name}` is inserted verbatim, so it can
# be either a Scala identifier (udfName) or a quoted literal ('"udf"').
# The Java variant additionally passes an explicit IntegerType return type.
_SCALA_REGISTER_CALL = "spark.udf.register({name}, new com.udf.{cls}())"
_JAVA_REGISTER_CALL = "spark.udf.register({name}, new com.udf.{cls}(), org.apache.spark.sql.types.IntegerType)"


# Scala method wrapper around one register call. The doubled braces are
# str.format escapes that produce the literal `{` / `}` of the method body.
_REGISTER_METHOD = """\
def {method}(spark: SparkSession, udfName: String): Unit = {{
{register_call}
}}"""


# `registerUDF` variants: register the CPU UDF class under the caller-supplied
# `udfName`, using the Scala-style or Java-style register call.
SCALA_REGISTER_UDF = _REGISTER_METHOD.format(
method="registerUDF",
register_call=_SCALA_REGISTER_CALL.format(name="udfName", cls=CPU_UDF_NAME),
)
JAVA_REGISTER_UDF = _REGISTER_METHOD.format(
method="registerUDF",
register_call=_JAVA_REGISTER_CALL.format(name="udfName", cls=CPU_UDF_NAME),
)

# `registerRapidsUDF` variants: same shape, but registering the RAPIDS UDF
# class (Scala / Java) or the native RAPIDS UDF class. The native variant uses
# the Java-style register call.
SCALA_REGISTER_RAPIDS_UDF = _REGISTER_METHOD.format(
method="registerRapidsUDF",
register_call=_SCALA_REGISTER_CALL.format(name="udfName", cls=RAPIDS_UDF_NAME),
)
JAVA_REGISTER_RAPIDS_UDF = _REGISTER_METHOD.format(
method="registerRapidsUDF",
register_call=_JAVA_REGISTER_CALL.format(name="udfName", cls=RAPIDS_UDF_NAME),
)
NATIVE_REGISTER_RAPIDS_UDF = _REGISTER_METHOD.format(
method="registerRapidsUDF",
register_call=_JAVA_REGISTER_CALL.format(name="udfName", cls=NATIVE_UDF_NAME),
)

# ---------------------------------------------------------------------------
# BenchUtils methods
# ---------------------------------------------------------------------------

# Scala `generateSyntheticData` method: produces `numRows` rows across
# `numPartitions` partitions with an `id` column from spark.range and a
# `value` column of rand() * 1000 cast to IntegerType.
BENCH_GENERATE = """\
def generateSyntheticData(
spark: SparkSession,
numRows: Long,
numPartitions: Int
): DataFrame = {
val baseDF = spark.range(0, numRows, 1, numPartitions)
baseDF.select(
col("id"),
(rand() * 1000).cast(IntegerType).alias("value")
)
}"""


# Scala benchmark method template: registers the input as the temp view
# `bench_table`, runs the `{register}` statement, then applies the function
# named `udf` to `value`. Doubled braces are str.format escapes.
_BENCH_EXECUTE_METHOD = """\
def {method}(spark: SparkSession, df: DataFrame): DataFrame = {{
df.createOrReplaceTempView("bench_table")
{register}
spark.sql("SELECT *, udf(value) AS result FROM bench_table")
}}"""


# Concrete benchmark methods. Each registers its class under the literal name
# "udf", which is the function name the template's SQL invokes.
# `executeCpu` uses the CPU UDF class; `executeGpu` uses the RAPIDS UDF class
# or, for BENCH_EXECUTE_CUDA, the native RAPIDS UDF class.
BENCH_EXECUTE_SCALA_CPU = _BENCH_EXECUTE_METHOD.format(
method="executeCpu",
register=_SCALA_REGISTER_CALL.format(name='"udf"', cls=CPU_UDF_NAME),
)
BENCH_EXECUTE_JAVA_CPU = _BENCH_EXECUTE_METHOD.format(
method="executeCpu",
register=_JAVA_REGISTER_CALL.format(name='"udf"', cls=CPU_UDF_NAME),
)
BENCH_EXECUTE_SCALA_CUDF = _BENCH_EXECUTE_METHOD.format(
method="executeGpu",
register=_SCALA_REGISTER_CALL.format(name='"udf"', cls=RAPIDS_UDF_NAME),
)
BENCH_EXECUTE_JAVA_CUDF = _BENCH_EXECUTE_METHOD.format(
method="executeGpu",
register=_JAVA_REGISTER_CALL.format(name='"udf"', cls=RAPIDS_UDF_NAME),
)
BENCH_EXECUTE_CUDA = _BENCH_EXECUTE_METHOD.format(
method="executeGpu",
register=_JAVA_REGISTER_CALL.format(name='"udf"', cls=NATIVE_UDF_NAME),
)

# SQL flavour of `executeGpu`: no UDF is registered. The method reads the SQL
# text from src/main/resources/integer_multiply_by_2.sql (a relative path),
# rewrites the table name from `test_table` to `bench_table`, and runs it.
BENCH_EXECUTE_SQL = """\
def executeGpu(spark: SparkSession, df: DataFrame): DataFrame = {
df.createOrReplaceTempView("bench_table")
val sqlSource = scala.io.Source.fromFile("src/main/resources/integer_multiply_by_2.sql")
val sqlContent = try sqlSource.mkString finally sqlSource.close()
val benchSql = sqlContent.replace("test_table", "bench_table")
spark.sql(benchSql)
}"""

# ---------------------------------------------------------------------------
# MicroBenchRunner methods
# ---------------------------------------------------------------------------

# Scala `prepareCpuData` method: copies host column 1 into a boxed array
# (null preserved for null rows) and returns it as the only element of an
# Array[AnyRef].
MICRO_PREPARE_CPU = """\
def prepareCpuData(
hostColumns: Array[HostColumnVector],
numRows: Int
): Array[AnyRef] = {
val values = Array.tabulate(numRows) { i =>
if (hostColumns(1).isNull(i)) null
else Int.box(hostColumns(1).getInt(i))
}
Array[AnyRef](values)
}"""


# Scala `executeCpu` template: casts data(0) to Array[Integer] (the array built
# by prepareCpuData) and calls the UDF's `{invoke}` method once per row in a
# while loop. Doubled braces are str.format escapes.
_MICRO_EXECUTE_CPU_METHOD = """\
def executeCpu(data: Array[AnyRef], numRows: Int): Unit = {{
val values = data(0).asInstanceOf[Array[Integer]]
val udf = new com.udf.{cls}()
var i = 0
while (i < numRows) {{
udf.{invoke}(values(i))
i += 1
}}
}}"""

# The Scala fixture is invoked through `apply`, the Java fixture through `call`.
MICRO_EXECUTE_SCALA_CPU = _MICRO_EXECUTE_CPU_METHOD.format(
cls=CPU_UDF_NAME,
invoke="apply",
)
MICRO_EXECUTE_JAVA_CPU = _MICRO_EXECUTE_CPU_METHOD.format(
cls=CPU_UDF_NAME,
invoke="call",
)


# Scala `executeGpu` template: runs evaluateColumnar on table column 1 and
# wraps the returned column in withResource with an empty body, i.e. the
# result is discarded after the call.
_MICRO_EXECUTE_GPU_METHOD = """\
def executeGpu(table: Table, numRows: Int): Unit = {{
val udf = new com.udf.{cls}()
withResource(udf.evaluateColumnar(numRows, table.getColumn(1))) {{ _ => }}
}}"""

MICRO_EXECUTE_CUDF = _MICRO_EXECUTE_GPU_METHOD.format(cls=RAPIDS_UDF_NAME)
MICRO_EXECUTE_CUDA = _MICRO_EXECUTE_GPU_METHOD.format(cls=NATIVE_UDF_NAME)

# ---------------------------------------------------------------------------
# Native source paths
# ---------------------------------------------------------------------------

# CMake `set(SOURCE_FILES ...)` snippet listing the JNI and CUDA fixture
# sources. The string ends with a trailing newline.
CMAKE_SOURCE_FILES = """\
set(SOURCE_FILES
"src/IntegerMultiplyBy2Jni.cpp"
"src/integer_multiply_by_2.cu"
)
"""

# Relative paths of placeholder-named native files.
# NOTE(review): presumably these are template files that the tests remove or
# replace with the fixture sources below — confirm against the test code.
NATIVE_PLACEHOLDER_FILES = (
"src/main/java/com/udf/PlaceholderUDFNameNativeRapidsUDF.java",
"native/src/main/cpp/src/PlaceholderUDFNameJni.cpp",
"native/src/main/cpp/src/placeholder_udf_name.cu",
"native/src/main/cpp/src/placeholder_udf_name.hpp",
)

# Maps each native fixture's relative path to its source text. Only the
# C++/CUDA files are listed here; the Java native UDF source
# (NATIVE_RAPIDS_UDF_SOURCE) is not part of this mapping.
NATIVE_SOURCE_FILES = {
"native/src/main/cpp/src/IntegerMultiplyBy2Jni.cpp": JNI_SOURCE,
"native/src/main/cpp/src/integer_multiply_by_2.cu": CUDA_SOURCE,
"native/src/main/cpp/src/integer_multiply_by_2.hpp": HEADER_SOURCE,
}
63 changes: 63 additions & 0 deletions skills/tests/resources/IntegerMultiplyBy2Jni.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

#include "integer_multiply_by_2.hpp"

#include <cudf/column/column.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/types.hpp>

#include <jni.h>

#include <exception>
#include <memory>
#include <new>
#include <stdexcept>
#include <string>

namespace {

constexpr char const* RUNTIME_ERROR_CLASS = "java/lang/RuntimeException";
constexpr char const* ILLEGAL_ARG_CLASS = "java/lang/IllegalArgumentException";

/**
 * Raise a Java exception of the named class through the given JNI environment.
 *
 * @param env        JNI environment of the current native call.
 * @param class_name Slash-separated class name handed to FindClass,
 *                   e.g. "java/lang/RuntimeException".
 * @param message    Message handed to ThrowNew.
 *
 * If FindClass returns null, ThrowNew is not called and the function simply
 * returns.
 */
void throw_java_exception(JNIEnv* env, char const* class_name, char const* message)
{
  jclass exception_type = env->FindClass(class_name);
  if (exception_type == nullptr) { return; }
  env->ThrowNew(exception_type, message);
}

} // namespace

extern "C" {

/**
 * JNI entry point behind IntegerMultiplyBy2NativeRapidsUDF.integerMultiplyBy2.
 *
 * @param env        JNI environment of the calling thread.
 * @param input_view Address of a cudf::column_view, passed from Java as a long.
 * @return Address of a newly created cudf::column whose ownership is released
 *         to the caller, or 0 when a Java exception has been raised instead.
 *
 * Errors are reported as Java exceptions rather than C++ ones:
 *   - null or non-INT32 input, std::invalid_argument -> IllegalArgumentException
 *   - std::bad_alloc, any other std::exception, or any other thrown object
 *     -> RuntimeException
 */
JNIEXPORT jlong JNICALL
Java_com_udf_IntegerMultiplyBy2NativeRapidsUDF_integerMultiplyBy2(JNIEnv* env,
                                                                  jclass,
                                                                  jlong input_view)
{
  try {
    auto input = reinterpret_cast<cudf::column_view const*>(input_view);
    if (input == nullptr) {
      throw_java_exception(env, ILLEGAL_ARG_CLASS, "input column view is null");
      return 0;
    }
    if (input->type().id() != cudf::type_id::INT32) {
      throw_java_exception(env, ILLEGAL_ARG_CLASS, "input must be INT32");
      return 0;
    }

    std::unique_ptr<cudf::column> result = integer_multiply_by_2(*input);
    // release() hands ownership to the caller; nothing on the native side
    // frees the column after this point.
    return reinterpret_cast<jlong>(result.release());
  } catch (std::bad_alloc const& e) {
    auto message = std::string("Unable to allocate native memory: ") + e.what();
    throw_java_exception(env, RUNTIME_ERROR_CLASS, message.c_str());
  } catch (std::invalid_argument const& e) {
    throw_java_exception(env, ILLEGAL_ARG_CLASS, e.what());
  } catch (std::exception const& e) {
    throw_java_exception(env, RUNTIME_ERROR_CLASS, e.what());
  } catch (...) {
    // Anything not derived from std::exception must be stopped here as well:
    // a C++ exception must never unwind out of a native method into the JVM.
    throw_java_exception(env, RUNTIME_ERROR_CLASS, "Unknown native exception");
  }
  // Reached only after one of the handlers above raised a Java exception.
  return 0;
}

}
27 changes: 27 additions & 0 deletions skills/tests/resources/IntegerMultiplyBy2NativeRapidsUDF.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package com.udf;

import ai.rapids.cudf.ColumnVector;
import com.nvidia.spark.RapidsUDF;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.spark.sql.api.java.UDF1;

/**
 * Fixture UDF that doubles an integer, with a row-wise implementation
 * ({@link #call}) and a columnar implementation ({@link #evaluateColumnar})
 * that delegates to a native method.
 */
public class IntegerMultiplyBy2NativeRapidsUDF extends UDF
    implements UDF1<Integer, Integer>, RapidsUDF {

    /**
     * Row-wise implementation.
     *
     * @param value input value, may be {@code null}
     * @return {@code value * 2}, or {@code null} when {@code value} is {@code null}
     */
    @Override
    public Integer call(Integer value) {
        if (value == null) {
            return null;
        }
        return value * 2;
    }

    /**
     * Columnar implementation: passes the native view of the first argument
     * column to the native method and wraps the returned handle in a
     * {@link ColumnVector}.
     *
     * @param numRows row count supplied by the caller; not used here
     * @param args    input columns; only {@code args[0]} is read
     * @return the column built from the handle returned by the native call
     */
    @Override
    public ColumnVector evaluateColumnar(int numRows, ColumnVector... args) {
        // NOTE(review): NativeUDFLoader is defined elsewhere; presumably this
        // loads the native library that provides integerMultiplyBy2 — confirm.
        NativeUDFLoader.ensureLoaded();
        final long inputView = args[0].getNativeView();
        final long resultHandle = integerMultiplyBy2(inputView);
        return new ColumnVector(resultHandle);
    }

    /** Native implementation; takes a native column view and returns a native handle. */
    private static native long integerMultiplyBy2(long inputView);
}
Loading
Loading