Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,20 +85,19 @@ object MicroBenchRunner {
def executeCpu(data: Array[AnyRef], numRows: Int): Unit = ???

/**
* TODO: Execute the GPU UDF via evaluateColumnar.
* TODO: Execute the GPU UDF via evaluateColumnar and close its result.
*
* Example:
* {{{
* val udf = new com.udf.PlaceholderRapidsUDFName()
* udf.evaluateColumnar(numRows,
* table.getColumn(0), table.getColumn(1))
* withResource(udf.evaluateColumnar(numRows,
* table.getColumn(0), table.getColumn(1))) { _ => }
* }}}
*
* @param table the dataset loaded on GPU
* @param numRows number of rows in the dataset
* @return result ColumnVector (NOTE: caller must close)
*/
def executeGpu(table: Table, numRows: Int): ColumnVector = ???
def executeGpu(table: Table, numRows: Int): Unit = ???

def main(args: Array[String]): Unit = {
val parsed = parseArgs(args)
Expand Down Expand Up @@ -165,7 +164,7 @@ object MicroBenchRunner {
if (runGpu) {
try {
val times = runBenchmark(warmup, measured, profile = profile) {
withResource(executeGpu(table, numRows)) { _ => }
executeGpu(table, numRows)
}
val medianMs = times(times.length / 2) / 1e6
val minMs = times(0) / 1e6
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,31 +112,26 @@ object SparkBenchRunner {
val resultDir = new File(path).getParentFile
if (resultDir != null) resultDir.mkdirs()

try {
import java.util.{LinkedHashMap => JLinkedHashMap, Arrays => JArrays}
val report = new JLinkedHashMap[String, AnyRef]()
report.put("mode", mode)
report.put("data_path", dataPath)
report.put("status", status)
report.put("e2e_runtime", java.lang.Double.valueOf(elapsed))
report.put("cli_args", JArrays.asList(cliArgs: _*))
errorMessage.foreach { msg =>
val error = new JLinkedHashMap[String, String]()
error.put("error_message", msg)
errorLogFile.foreach(f => error.put("error_log_file", f))
report.put("error", error)
}

val mapper = new ObjectMapper()
mapper.enable(SerializationFeature.INDENT_OUTPUT)
val printer = new DefaultPrettyPrinter()
printer.indentArraysWith(DefaultIndenter.SYSTEM_LINEFEED_INSTANCE)
mapper.writer(printer).writeValue(new File(path), report)
System.err.println(s"Report written to: $path")
} catch {
case e: Exception =>
System.err.println(s"Failed to write report: ${e.getMessage}")
import java.util.{LinkedHashMap => JLinkedHashMap, Arrays => JArrays}
val report = new JLinkedHashMap[String, AnyRef]()
report.put("mode", mode)
report.put("data_path", dataPath)
report.put("status", status)
report.put("e2e_runtime", java.lang.Double.valueOf(elapsed))
report.put("cli_args", JArrays.asList(cliArgs: _*))
errorMessage.foreach { msg =>
val error = new JLinkedHashMap[String, String]()
error.put("error_message", msg)
errorLogFile.foreach(f => error.put("error_log_file", f))
report.put("error", error)
}

val mapper = new ObjectMapper()
mapper.enable(SerializationFeature.INDENT_OUTPUT)
val printer = new DefaultPrettyPrinter()
printer.indentArraysWith(DefaultIndenter.SYSTEM_LINEFEED_INSTANCE)
mapper.writer(printer).writeValue(new File(path), report)
System.err.println(s"Report written to: $path")
}

/** Write an exception to an error log file. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,18 @@ class CudfComparisonTest extends AnyFunSuite with BeforeAndAfterAll {
def registerRapidsUDF(spark: SparkSession, udfName: String): Unit = ???

test("UDF vs RapidsUDF") {
val testDF = UnitTest.createTestData(spark).repartition(1)
// Repartition down to 2 tasks to ensure we exercise multi-row columns.
val testDF = UnitTest.createTestData(spark).repartition(2)

// Run CPU UDF
UnitTest.registerUDF(spark, "placeholder_udf_name")
val cpuResultDF = UnitTest.executeUDF(spark, "placeholder_udf_name", testDF)
UnitTest.verifyUDFResults(cpuResultDF, testDF)
UnitTest.assertUDFResults(cpuResultDF, testDF)

// Run RapidsUDF
registerRapidsUDF(spark, "placeholder_rapids_udf_name")
val gpuResultDF = UnitTest.executeUDF(spark, "placeholder_rapids_udf_name", testDF)
UnitTest.verifyUDFResults(gpuResultDF, testDF)
UnitTest.assertUDFResults(gpuResultDF, testDF)

// Compare
TestUtils.assertDataFrameEquals(actual = gpuResultDF, expected = cpuResultDF)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,20 @@ class SqlComparisonTest extends AnyFunSuite with BeforeAndAfterAll {
}

test("UDF vs SQL expression") {
val testDF = UnitTest.createTestData(spark).repartition(1)
// Repartition down to 2 tasks to ensure we exercise multi-row columns.
val testDF = UnitTest.createTestData(spark).repartition(2)

// Run CPU UDF
UnitTest.registerUDF(spark, "placeholder_udf_name")
val udfResultDF = UnitTest.executeUDF(spark, "placeholder_udf_name", testDF)
UnitTest.verifyUDFResults(udfResultDF, testDF)
UnitTest.assertUDFResults(udfResultDF, testDF)

// Read and execute SQL expression
testDF.createOrReplaceTempView("test_table")
val sqlSource = scala.io.Source.fromFile("src/main/resources/placeholder_udf_name.sql")
val sqlContent = try sqlSource.mkString finally sqlSource.close()
val sqlResultDF = spark.sql(sqlContent)
UnitTest.verifyUDFResults(sqlResultDF, testDF)
UnitTest.assertUDFResults(sqlResultDF, testDF)

// Compare results
TestUtils.assertDataFrameEquals(actual = sqlResultDF, expected = udfResultDF)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ import org.scalatest.BeforeAndAfterAll

object UnitTest extends Assertions {
/**
* TODO: Create a test DataFrame with diverse test cases including edge cases.
* TODO: Create a test DataFrame with diverse test cases including edge cases
* (at least 10 cases).
*
* Example:
* {{{
Expand All @@ -25,6 +26,7 @@ object UnitTest extends Assertions {
* Row(1, 800),
* Row(2, 550),
* Row(3, null)
* // ...
Comment thread
rishic3 marked this conversation as resolved.
Outdated
* )
* spark.createDataFrame(spark.sparkContext.parallelize(testData), schema)
* }}}
Expand Down Expand Up @@ -55,7 +57,7 @@ object UnitTest extends Assertions {
def executeUDF(spark: SparkSession, udfName: String, testDF: DataFrame): DataFrame = ???

/**
* TODO: Verify UDF results using assert statements.
* TODO: Assert the UDF results match expectations.
*
* Example:
* {{{
Expand All @@ -65,7 +67,7 @@ object UnitTest extends Assertions {
* assert(results(2).getAs[String]("risk_level") === "UNKNOWN")
* }}}
*/
def verifyUDFResults(resultDF: DataFrame, testDF: DataFrame): Unit = ???
def assertUDFResults(resultDF: DataFrame, testDF: DataFrame): Unit = ???
}

class UnitTest extends AnyFunSuite with BeforeAndAfterAll {
Expand All @@ -89,11 +91,12 @@ class UnitTest extends AnyFunSuite with BeforeAndAfterAll {
}

test("UDF produces correct results") {
val testDF = UnitTest.createTestData(spark).repartition(1)
// Repartition down to 2 tasks to ensure we exercise multi-row columns.
val testDF = UnitTest.createTestData(spark).repartition(2)

UnitTest.registerUDF(spark, "placeholder_udf_name")
val resultDF = UnitTest.executeUDF(spark, "placeholder_udf_name", testDF)

UnitTest.verifyUDFResults(resultDF, testDF)
UnitTest.assertUDFResults(resultDF, testDF)
}
}
2 changes: 1 addition & 1 deletion skills/udf-judge-conversion/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Check that:
- Assertions verify schema, row count, deterministic ordering, output values, null propagation, and exception/default behavior where applicable.
- The test exercises visible CPU UDF branches. Coverage reports should support this when available.
- Assertions reflect the CPU UDF's actual behavior and do not merely assert weak properties such as non-null output.
- Extra unit tests outside the shared `verifyUDFResults` path are mirrored in the comparison test and run against both CPU and GPU/SQL paths.
- Extra unit tests outside the shared `assertUDFResults` path are mirrored in the comparison test and run against both CPU and GPU/SQL paths.

## Comparison Test Checks

Expand Down
Loading