Skip to content

Commit d08ae78

Browse files
authored
Resolve vector test failures after enabling AzureDB runs (#2757)
* Fixed test failures post enabling test in azure DB * Updated test * Fixed failure in testBulkCopyTableToTableJsonToVector()
1 parent 2bfd53e commit d08ae78

File tree

2 files changed

+91
-85
lines changed

2 files changed

+91
-85
lines changed

src/test/java/com/microsoft/sqlserver/jdbc/bulkCopy/BulkCopyISQLServerBulkRecordTest.java

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -591,58 +591,54 @@ public void testBulkCopyWithCountVerification() throws SQLException {
591591
@AzureDB
592592
@Tag(Constants.vectorTest)
593593
public void testBulkCopyTableToTableJsonToVector() throws Exception {
594-
String srcTable = TestUtils.escapeSingleQuotes(AbstractSQLGenerator.escapeIdentifier("testSrcJsonTable"));
595-
String dstTable = TestUtils.escapeSingleQuotes(AbstractSQLGenerator.escapeIdentifier("testDstVectorTable"));
594+
String srcTable = TestUtils.escapeSingleQuotes(
595+
AbstractSQLGenerator.escapeIdentifier(RandomUtil.getIdentifier("testSrcJsonTable")));
596+
String dstTable = TestUtils.escapeSingleQuotes(
597+
AbstractSQLGenerator.escapeIdentifier(RandomUtil.getIdentifier("testDstVectorTable")));
596598
String vectorJson = "[1.0, 2.0, 3.0]";
597599
Object[] expectedVector = new Float[] { 1.0f, 2.0f, 3.0f };
598600

599-
// Create source table and insert JSON vector
600601
try (Connection conn = DriverManager.getConnection(connectionString);
601602
Statement stmt = conn.createStatement()) {
603+
604+
// Create source table and insert JSON vector
602605
stmt.executeUpdate("CREATE TABLE " + srcTable + " (vectorJsonCol JSON)");
603606
stmt.executeUpdate("INSERT INTO " + srcTable + " (vectorJsonCol) VALUES ('" + vectorJson + "')");
604-
}
605607

606-
// Create destination table with VECTOR column
607-
try (Connection conn = DriverManager.getConnection(connectionString);
608-
Statement stmt = conn.createStatement()) {
608+
// Create destination table with VECTOR column
609609
stmt.executeUpdate("CREATE TABLE " + dstTable + " (vectorCol VECTOR(3))");
610-
}
611610

612-
// Table-to-table bulk copy: read JSON, parse, and write as VECTOR
613-
try (Connection conn = DriverManager.getConnection(connectionString);
614-
Statement stmt = conn.createStatement();
615-
ResultSet rs = stmt.executeQuery("SELECT vectorJsonCol FROM " + srcTable);
616-
SQLServerBulkCopy bulkCopy = new SQLServerBulkCopy(conn)) {
617-
618-
bulkCopy.setDestinationTableName(dstTable);
619-
// For each row, parse JSON and bulk copy as VECTOR
620-
while (rs.next()) {
621-
String json = rs.getString(1);
622-
Object[] vector = parseJsonArrayToFloatArray(json);
623-
Vector vectorObj = new Vector(vector.length, VectorDimensionType.FLOAT32, vector);
624-
VectorBulkData vectorBulkData = new VectorBulkData(vectorObj, vector.length, VectorDimensionType.FLOAT32);
625-
bulkCopy.writeToServer(vectorBulkData);
626-
}
627-
}
611+
// Table-to-table bulk copy: read JSON, parse, and write as VECTOR
612+
try (ResultSet rs = stmt.executeQuery("SELECT vectorJsonCol FROM " + srcTable);
613+
SQLServerBulkCopy bulkCopy = new SQLServerBulkCopy(conn)) {
628614

629-
// Validate the data in the destination table
630-
try (Connection conn = DriverManager.getConnection(connectionString);
631-
Statement stmt = conn.createStatement();
632-
ResultSet rs = stmt.executeQuery("SELECT vectorCol FROM " + dstTable)) {
633-
assertTrue(rs.next(), "No data found in the destination table.");
634-
Vector resultVector = rs.getObject(1, Vector.class);
635-
assertNotNull(resultVector, "Retrieved vector is null.");
636-
assertEquals(3, resultVector.getDimensionCount(), "Dimension count mismatch.");
637-
assertEquals(VectorDimensionType.FLOAT32, resultVector.getVectorDimensionType(), "Vector dimension type mismatch.");
638-
assertArrayEquals(expectedVector, resultVector.getData(), "Vector data mismatch.");
639-
}
615+
bulkCopy.setDestinationTableName(dstTable);
616+
// For each row, parse JSON and bulk copy as VECTOR
617+
while (rs.next()) {
618+
String json = rs.getString(1);
619+
Object[] vector = parseJsonArrayToFloatArray(json);
620+
Vector vectorObj = new Vector(vector.length, VectorDimensionType.FLOAT32, vector);
621+
VectorBulkData vectorBulkData = new VectorBulkData(vectorObj, vector.length, VectorDimensionType.FLOAT32);
622+
bulkCopy.writeToServer(vectorBulkData);
623+
}
624+
}
640625

641-
// Cleanup
642-
try (Connection conn = DriverManager.getConnection(connectionString);
643-
Statement stmt = conn.createStatement()) {
644-
TestUtils.dropTableIfExists(srcTable, stmt);
645-
TestUtils.dropTableIfExists(dstTable, stmt);
626+
// Validate the data in the destination table
627+
try (ResultSet rs = stmt.executeQuery("SELECT vectorCol FROM " + dstTable)) {
628+
assertTrue(rs.next(), "No data found in the destination table.");
629+
Vector resultVector = rs.getObject(1, Vector.class);
630+
assertNotNull(resultVector, "Retrieved vector is null.");
631+
assertEquals(3, resultVector.getDimensionCount(), "Dimension count mismatch.");
632+
assertEquals(VectorDimensionType.FLOAT32, resultVector.getVectorDimensionType(), "Vector dimension type mismatch.");
633+
assertArrayEquals(expectedVector, resultVector.getData(), "Vector data mismatch.");
634+
}
635+
} finally {
636+
// Cleanup
637+
try (Connection conn = DriverManager.getConnection(connectionString);
638+
Statement stmt = conn.createStatement()) {
639+
TestUtils.dropTableIfExists(srcTable, stmt);
640+
TestUtils.dropTableIfExists(dstTable, stmt);
641+
}
646642
}
647643
}
648644

src/test/java/com/microsoft/sqlserver/jdbc/datatypes/VectorTest.java

Lines changed: 55 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.sql.ResultSetMetaData;
2222
import java.sql.SQLException;
2323
import java.sql.Statement;
24+
import java.util.UUID;
2425

2526
import org.junit.jupiter.api.AfterAll;
2627
import org.junit.jupiter.api.BeforeAll;
@@ -61,6 +62,7 @@ public class VectorTest extends AbstractTest {
6162
private static final String TABLE_NAME = RandomUtil.getIdentifier("VECTOR_TVP_Test");
6263
private static final String TVP_NAME = RandomUtil.getIdentifier("VECTOR_TVP_Test_Type");
6364
private static final String TVP = RandomUtil.getIdentifier("VECTOR_TVP_UDF_Test_Type");
65+
private static final String uuid = UUID.randomUUID().toString().replaceAll("-", "");
6466

6567
@BeforeAll
6668
private static void setupTest() throws Exception {
@@ -1198,9 +1200,7 @@ public void testSelectIntoForVector() throws SQLException {
11981200
}
11991201

12001202
// Drop the destination table if it already exists
1201-
String dropTableSql = "IF OBJECT_ID('" + destinationTable + "', 'U') IS NOT NULL DROP TABLE "
1202-
+ destinationTable;
1203-
stmt.executeUpdate(dropTableSql);
1203+
TestUtils.dropTableIfExists(destinationTable, stmt);
12041204

12051205
// Perform the SELECT INTO operation
12061206
String selectIntoSql = "SELECT * INTO " + destinationTable + " FROM " + sourceTable;
@@ -1378,69 +1378,79 @@ public void testVectorNormalizeUdf() throws SQLException {
13781378
}
13791379
}
13801380

1381-
/**
1382-
* Test for vector normalization using a scalar-valued function.
1383-
* The function normalizes the input vector and returns the normalized vector.
1384-
*/
1385-
@Test
1386-
public void testVectorNormalizeScalarFunction() throws SQLException {
1387-
String vectorsTable = TestUtils
1388-
.escapeSingleQuotes(AbstractSQLGenerator.escapeIdentifier(RandomUtil.getIdentifier("Vectors")));
1389-
String udfName = TestUtils
1390-
.escapeSingleQuotes(AbstractSQLGenerator.escapeIdentifier(RandomUtil.getIdentifier("svf")));
1381+
private void setupSVF(String schemaName, String funcName, String tableName) throws SQLException {
1382+
String escapedSchema = AbstractSQLGenerator.escapeIdentifier(schemaName);
1383+
String escapedFunc = AbstractSQLGenerator.escapeIdentifier(funcName);
13911384

13921385
try (Statement stmt = connection.createStatement()) {
1393-
// Drop table and UDF if they already exist
1394-
TestUtils.dropTableIfExists(vectorsTable, stmt);
1395-
String dropUdfSQL = "IF OBJECT_ID('" + udfName + "', 'FN') IS NOT NULL DROP FUNCTION " + udfName;
1396-
stmt.execute(dropUdfSQL);
1386+
1387+
// Create schema if not exists
1388+
stmt.executeUpdate(
1389+
"IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '" + schemaName + "') " +
1390+
"EXEC('CREATE SCHEMA " + escapedSchema + "')");
13971391

1398-
// Create the scalar-valued function
1399-
String createUdfSQL = "CREATE FUNCTION " + udfName + " (@p VECTOR(3)) " +
1392+
// Create scalar-valued function
1393+
String createSvfSQL = "CREATE FUNCTION " + escapedSchema + "." + escapedFunc + " (@p VECTOR(3)) " +
14001394
"RETURNS VECTOR(3) AS " +
14011395
"BEGIN " +
14021396
" DECLARE @v VECTOR(3); " +
14031397
" SET @v = vector_normalize(@p, 'norm2'); " +
14041398
" RETURN @v; " +
14051399
"END";
1406-
stmt.execute(createUdfSQL);
1400+
stmt.executeUpdate(createSvfSQL);
14071401

1408-
// Create the table
1409-
String createTableSQL = "CREATE TABLE " + vectorsTable + " (id INT PRIMARY KEY, data VECTOR(3))";
1402+
String createTableSQL = "CREATE TABLE " + tableName + " (id INT PRIMARY KEY, vec VECTOR(3))";
14101403
stmt.execute(createTableSQL);
14111404

1412-
// Insert sample data
1413-
String insertSQL = "INSERT INTO " + vectorsTable + " (id, data) VALUES (?, ?)";
1414-
try (PreparedStatement pstmt = connection.prepareStatement(insertSQL)) {
1415-
Object[] vectorData = new Float[] { 1.0f, 2.0f, 3.0f };
1416-
Vector vector = new Vector(3, VectorDimensionType.FLOAT32, vectorData);
1405+
}
1406+
}
14171407

1408+
/**
1409+
* Test for vector normalization using a scalar-valued function.
1410+
* The function normalizes the input vector and returns the normalized vector.
1411+
*/
1412+
@Test
1413+
public void testVectorIdentityScalarFunction() throws SQLException {
1414+
String schemaName = "testschemaVector" + uuid;
1415+
String funcName = "svf" + uuid;
1416+
String escapedSchema = AbstractSQLGenerator.escapeIdentifier(schemaName);
1417+
String escapedFunc = AbstractSQLGenerator.escapeIdentifier(funcName);
1418+
String tableName = escapedSchema + "." + "Vectors" + uuid;
1419+
1420+
try {
1421+
// Setup: create schema, function, and table
1422+
setupSVF(schemaName, funcName, tableName);
1423+
1424+
// Insert a vector row
1425+
try (PreparedStatement pstmt = connection.prepareStatement(
1426+
"INSERT INTO " + tableName + " (id, vec) VALUES (?, ?)")) {
1427+
Vector v = new Vector(3, VectorDimensionType.FLOAT32, new Float[] { 1.0f, 2.0f, 3.0f });
14181428
pstmt.setInt(1, 1);
1419-
pstmt.setObject(2, vector, microsoft.sql.Types.VECTOR);
1429+
pstmt.setObject(2, v, microsoft.sql.Types.VECTOR);
14201430
pstmt.executeUpdate();
14211431
}
14221432

1423-
// Test the scalar-valued function
1424-
String udfTestSQL = "DECLARE @v VECTOR(3) = (SELECT data FROM " + vectorsTable + " WHERE id = 1); " +
1425-
"SELECT " + udfName + "(@v) AS normalizedVector";
1426-
try (ResultSet rs = stmt.executeQuery(udfTestSQL)) {
1427-
assertTrue(rs.next(), "No result returned from scalar-valued function.");
1428-
Vector normalizedVector = rs.getObject("normalizedVector", Vector.class);
1429-
assertNotNull(normalizedVector, "Normalized vector is null.");
1430-
1431-
Object[] expectedNormalizedData = new Float[] { 0.2673f, 0.5345f, 0.8018f }; // Normalized values for [1, 2, 3]
1432-
Object[] actualNormalizedData = normalizedVector.getData();
1433+
// Call the scalar function and validate output
1434+
String svfTestSQL = "DECLARE @v VECTOR(3) = (SELECT vec FROM " + tableName + " WHERE id = 1); " +
1435+
"SELECT " + escapedSchema + "." + escapedFunc + "(@v) AS normalizedVector";
14331436

1434-
for (int i = 0; i < expectedNormalizedData.length; i++) {
1435-
assertEquals((float) expectedNormalizedData[i], (float) actualNormalizedData[i], 0.0001f,
1436-
"Normalized vector mismatch at index " + i);
1437-
}
1437+
try (Statement stmt = connection.createStatement();
1438+
ResultSet rs = stmt.executeQuery(svfTestSQL)) {
1439+
1440+
assertTrue(rs.next(), "No result from SVF.");
1441+
Vector normalizedVector = rs.getObject(1, Vector.class);
1442+
assertNotNull(normalizedVector, "Returned vector is null.");
1443+
1444+
Object[] expectedNormalizedData = new Float[] { 0.26726124f, 0.5345225f, 0.8017837f }; // Normalized values for [1, 2, 3]
1445+
assertArrayEquals(expectedNormalizedData, normalizedVector.getData(), "Vector roundtrip mismatch.");
14381446
}
1447+
14391448
} finally {
1440-
// Cleanup: Drop the UDF and table
1449+
// Cleanup: drop function, table, and schema
14411450
try (Statement stmt = connection.createStatement()) {
1442-
TestUtils.dropFunctionIfExists(udfName, stmt);
1443-
TestUtils.dropTableIfExists(vectorsTable, stmt);
1451+
TestUtils.dropFunctionWithSchemaIfExists(schemaName + "." + funcName, stmt);
1452+
TestUtils.dropTableWithSchemaIfExists(tableName, stmt);
1453+
TestUtils.dropSchemaIfExists(schemaName, stmt);
14441454
}
14451455
}
14461456
}

0 commit comments

Comments
 (0)