|
21 | 21 | import java.sql.ResultSetMetaData; |
22 | 22 | import java.sql.SQLException; |
23 | 23 | import java.sql.Statement; |
| 24 | +import java.util.UUID; |
24 | 25 |
|
25 | 26 | import org.junit.jupiter.api.AfterAll; |
26 | 27 | import org.junit.jupiter.api.BeforeAll; |
@@ -61,6 +62,7 @@ public class VectorTest extends AbstractTest { |
61 | 62 | private static final String TABLE_NAME = RandomUtil.getIdentifier("VECTOR_TVP_Test"); |
62 | 63 | private static final String TVP_NAME = RandomUtil.getIdentifier("VECTOR_TVP_Test_Type"); |
63 | 64 | private static final String TVP = RandomUtil.getIdentifier("VECTOR_TVP_UDF_Test_Type"); |
| 65 | + private static final String uuid = UUID.randomUUID().toString().replaceAll("-", ""); |
64 | 66 |
|
65 | 67 | @BeforeAll |
66 | 68 | private static void setupTest() throws Exception { |
@@ -1198,9 +1200,7 @@ public void testSelectIntoForVector() throws SQLException { |
1198 | 1200 | } |
1199 | 1201 |
|
1200 | 1202 | // Drop the destination table if it already exists |
1201 | | - String dropTableSql = "IF OBJECT_ID('" + destinationTable + "', 'U') IS NOT NULL DROP TABLE " |
1202 | | - + destinationTable; |
1203 | | - stmt.executeUpdate(dropTableSql); |
| 1203 | + TestUtils.dropTableIfExists(destinationTable, stmt); |
1204 | 1204 |
|
1205 | 1205 | // Perform the SELECT INTO operation |
1206 | 1206 | String selectIntoSql = "SELECT * INTO " + destinationTable + " FROM " + sourceTable; |
@@ -1378,69 +1378,79 @@ public void testVectorNormalizeUdf() throws SQLException { |
1378 | 1378 | } |
1379 | 1379 | } |
1380 | 1380 |
|
1381 | | - /** |
1382 | | - * Test for vector normalization using a scalar-valued function. |
1383 | | - * The function normalizes the input vector and returns the normalized vector. |
1384 | | - */ |
1385 | | - @Test |
1386 | | - public void testVectorNormalizeScalarFunction() throws SQLException { |
1387 | | - String vectorsTable = TestUtils |
1388 | | - .escapeSingleQuotes(AbstractSQLGenerator.escapeIdentifier(RandomUtil.getIdentifier("Vectors"))); |
1389 | | - String udfName = TestUtils |
1390 | | - .escapeSingleQuotes(AbstractSQLGenerator.escapeIdentifier(RandomUtil.getIdentifier("svf"))); |
| 1381 | + private void setupSVF(String schemaName, String funcName, String tableName) throws SQLException { |
| 1382 | + String escapedSchema = AbstractSQLGenerator.escapeIdentifier(schemaName); |
| 1383 | + String escapedFunc = AbstractSQLGenerator.escapeIdentifier(funcName); |
1391 | 1384 |
|
1392 | 1385 | try (Statement stmt = connection.createStatement()) { |
1393 | | - // Drop table and UDF if they already exist |
1394 | | - TestUtils.dropTableIfExists(vectorsTable, stmt); |
1395 | | - String dropUdfSQL = "IF OBJECT_ID('" + udfName + "', 'FN') IS NOT NULL DROP FUNCTION " + udfName; |
1396 | | - stmt.execute(dropUdfSQL); |
| 1386 | + |
| 1387 | + // Create schema if not exists |
| 1388 | + stmt.executeUpdate( |
| 1389 | + "IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '" + schemaName + "') " + |
| 1390 | + "EXEC('CREATE SCHEMA " + escapedSchema + "')"); |
1397 | 1391 |
|
1398 | | - // Create the scalar-valued function |
1399 | | - String createUdfSQL = "CREATE FUNCTION " + udfName + " (@p VECTOR(3)) " + |
| 1392 | + // Create scalar-valued function |
| 1393 | + String createSvfSQL = "CREATE FUNCTION " + escapedSchema + "." + escapedFunc + " (@p VECTOR(3)) " + |
1400 | 1394 | "RETURNS VECTOR(3) AS " + |
1401 | 1395 | "BEGIN " + |
1402 | 1396 | " DECLARE @v VECTOR(3); " + |
1403 | 1397 | " SET @v = vector_normalize(@p, 'norm2'); " + |
1404 | 1398 | " RETURN @v; " + |
1405 | 1399 | "END"; |
1406 | | - stmt.execute(createUdfSQL); |
| 1400 | + stmt.executeUpdate(createSvfSQL); |
1407 | 1401 |
|
1408 | | - // Create the table |
1409 | | - String createTableSQL = "CREATE TABLE " + vectorsTable + " (id INT PRIMARY KEY, data VECTOR(3))"; |
| 1402 | + String createTableSQL = "CREATE TABLE " + tableName + " (id INT PRIMARY KEY, vec VECTOR(3))"; |
1410 | 1403 | stmt.execute(createTableSQL); |
1411 | 1404 |
|
1412 | | - // Insert sample data |
1413 | | - String insertSQL = "INSERT INTO " + vectorsTable + " (id, data) VALUES (?, ?)"; |
1414 | | - try (PreparedStatement pstmt = connection.prepareStatement(insertSQL)) { |
1415 | | - Object[] vectorData = new Float[] { 1.0f, 2.0f, 3.0f }; |
1416 | | - Vector vector = new Vector(3, VectorDimensionType.FLOAT32, vectorData); |
| 1405 | + } |
| 1406 | + } |
1417 | 1407 |
|
| 1408 | + /** |
| 1409 | + * Test for vector normalization using a scalar-valued function. |
| 1410 | + * The function normalizes the input vector and returns the normalized vector. |
| 1411 | + */ |
| 1412 | + @Test |
| 1413 | + public void testVectorIdentityScalarFunction() throws SQLException { |
| 1414 | + String schemaName = "testschemaVector" + uuid; |
| 1415 | + String funcName = "svf" + uuid; |
| 1416 | + String escapedSchema = AbstractSQLGenerator.escapeIdentifier(schemaName); |
| 1417 | + String escapedFunc = AbstractSQLGenerator.escapeIdentifier(funcName); |
| 1418 | + String tableName = escapedSchema + "." + "Vectors" + uuid; |
| 1419 | + |
| 1420 | + try { |
| 1421 | + // Setup: create schema, function, and table |
| 1422 | + setupSVF(schemaName, funcName, tableName); |
| 1423 | + |
| 1424 | + // Insert a vector row |
| 1425 | + try (PreparedStatement pstmt = connection.prepareStatement( |
| 1426 | + "INSERT INTO " + tableName + " (id, vec) VALUES (?, ?)")) { |
| 1427 | + Vector v = new Vector(3, VectorDimensionType.FLOAT32, new Float[] { 1.0f, 2.0f, 3.0f }); |
1418 | 1428 | pstmt.setInt(1, 1); |
1419 | | - pstmt.setObject(2, vector, microsoft.sql.Types.VECTOR); |
| 1429 | + pstmt.setObject(2, v, microsoft.sql.Types.VECTOR); |
1420 | 1430 | pstmt.executeUpdate(); |
1421 | 1431 | } |
1422 | 1432 |
|
1423 | | - // Test the scalar-valued function |
1424 | | - String udfTestSQL = "DECLARE @v VECTOR(3) = (SELECT data FROM " + vectorsTable + " WHERE id = 1); " + |
1425 | | - "SELECT " + udfName + "(@v) AS normalizedVector"; |
1426 | | - try (ResultSet rs = stmt.executeQuery(udfTestSQL)) { |
1427 | | - assertTrue(rs.next(), "No result returned from scalar-valued function."); |
1428 | | - Vector normalizedVector = rs.getObject("normalizedVector", Vector.class); |
1429 | | - assertNotNull(normalizedVector, "Normalized vector is null."); |
1430 | | - |
1431 | | - Object[] expectedNormalizedData = new Float[] { 0.2673f, 0.5345f, 0.8018f }; // Normalized values for [1, 2, 3] |
1432 | | - Object[] actualNormalizedData = normalizedVector.getData(); |
| 1433 | + // Call the scalar function and validate output |
| 1434 | + String svfTestSQL = "DECLARE @v VECTOR(3) = (SELECT vec FROM " + tableName + " WHERE id = 1); " + |
| 1435 | + "SELECT " + escapedSchema + "." + escapedFunc + "(@v) AS normalizedVector"; |
1433 | 1436 |
|
1434 | | - for (int i = 0; i < expectedNormalizedData.length; i++) { |
1435 | | - assertEquals((float) expectedNormalizedData[i], (float) actualNormalizedData[i], 0.0001f, |
1436 | | - "Normalized vector mismatch at index " + i); |
1437 | | - } |
| 1437 | + try (Statement stmt = connection.createStatement(); |
| 1438 | + ResultSet rs = stmt.executeQuery(svfTestSQL)) { |
| 1439 | + |
| 1440 | + assertTrue(rs.next(), "No result from SVF."); |
| 1441 | + Vector normalizedVector = rs.getObject(1, Vector.class); |
| 1442 | + assertNotNull(normalizedVector, "Returned vector is null."); |
| 1443 | + |
| 1444 | + Object[] expectedNormalizedData = new Float[] { 0.26726124f, 0.5345225f, 0.8017837f }; // Normalized values for [1, 2, 3] |
| 1445 | + assertArrayEquals(expectedNormalizedData, normalizedVector.getData(), "Vector roundtrip mismatch."); |
1438 | 1446 | } |
| 1447 | + |
1439 | 1448 | } finally { |
1440 | | - // Cleanup: Drop the UDF and table |
| 1449 | + // Cleanup: drop function, table, and schema |
1441 | 1450 | try (Statement stmt = connection.createStatement()) { |
1442 | | - TestUtils.dropFunctionIfExists(udfName, stmt); |
1443 | | - TestUtils.dropTableIfExists(vectorsTable, stmt); |
| 1451 | + TestUtils.dropFunctionWithSchemaIfExists(schemaName + "." + funcName, stmt); |
| 1452 | + TestUtils.dropTableWithSchemaIfExists(tableName, stmt); |
| 1453 | + TestUtils.dropSchemaIfExists(schemaName, stmt); |
1444 | 1454 | } |
1445 | 1455 | } |
1446 | 1456 | } |
|
0 commit comments