diff --git a/core/src/main/java/org/opensearch/sql/planner/Planner.java b/core/src/main/java/org/opensearch/sql/planner/Planner.java index 4625d72d3fc..8a015bc072b 100644 --- a/core/src/main/java/org/opensearch/sql/planner/Planner.java +++ b/core/src/main/java/org/opensearch/sql/planner/Planner.java @@ -14,6 +14,7 @@ import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.read.TableScanBuilder; /** Planner that plans and chooses the optimal physical plan. */ @RequiredArgsConstructor @@ -34,7 +35,35 @@ public PhysicalPlan plan(LogicalPlan plan) { if (table == null) { return plan.accept(new DefaultImplementor<>(), null); } - return table.implement(table.optimize(optimize(plan))); + LogicalPlan optimized = table.optimize(optimize(plan)); + // Give scan builders a chance to reject shapes that push-down alone cannot express safely + // (e.g. operators that land above the scan but outside its push-down contract). + validateScanBuilders(optimized); + return table.implement(optimized); + } + + /** + * Walk the optimized plan and invoke {@link TableScanBuilder#validatePlan(LogicalPlan)} on every + * scan builder, passing the fully optimized root so scan builders can inspect their ancestors. 
+ */ + private void validateScanBuilders(LogicalPlan optimized) { + optimized.accept( + new LogicalPlanNodeVisitor() { + @Override + public Void visitNode(LogicalPlan node, Object context) { + for (LogicalPlan child : node.getChild()) { + child.accept(this, context); + } + return null; + } + + @Override + public Void visitTableScanBuilder(TableScanBuilder node, Object context) { + node.validatePlan(optimized); + return null; + } + }, + null); } private Table findTable(LogicalPlan plan) { diff --git a/core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java b/core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java index b2da0b67a4b..3d2fb2872e5 100644 --- a/core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java +++ b/core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java @@ -119,6 +119,19 @@ public boolean pushDownPageSize(LogicalPaginate paginate) { return false; } + /** + * Post-optimization validation hook. Called once by the planner after all push-down rules have + * run, with the fully optimized plan root. Subclasses may inspect the ancestors of this scan + * builder to reject planner shapes that push-down alone cannot express safely (for example, + * operators that land above the scan but outside its push-down contract and would be executed + * after the scan has already returned a bounded result set). Default is no-op. 
+ * + * @param root the fully optimized logical plan containing this scan builder + */ + public void validatePlan(LogicalPlan root) { + // no-op by default + } + @Override public R accept(LogicalPlanNodeVisitor visitor, C context) { return visitor.visitTableScanBuilder(this, context); diff --git a/docs/user/dql/vector-search.rst b/docs/user/dql/vector-search.rst new file mode 100644 index 00000000000..8b0237a6ef0 --- /dev/null +++ b/docs/user/dql/vector-search.rst @@ -0,0 +1,331 @@ + +============================== +Vector Search [Experimental] +============================== + +.. rubric:: Table of contents + +.. contents:: + :local: + :depth: 2 + +Introduction +============ + +``vectorSearch()`` is an experimental feature. Syntax, options, and +pushdown behavior may change in future releases based on feedback. + +The ``vectorSearch()`` table function runs a k-NN query against a ``knn_vector`` +field and exposes the matching documents as a relation in the ``FROM`` clause. +It relies on the OpenSearch `k-NN plugin +`_. The target index must +map the vector field as ``knn_vector`` and the index must be created with +``index.knn: true``. + +The SQL layer translates ``vectorSearch()`` into an OpenSearch search +request whose body is native k-NN query DSL; the query vector is parsed +into a numeric array before that DSL is emitted. + +Relevance is expressed through the OpenSearch ``_score`` metadata field, and +results are returned ordered by ``_score DESC`` by default. + +vectorSearch +============ + +Description +----------- + +``vectorSearch(table='', field='', vector='', option='')`` + +All four arguments are required and must be passed by name as string +literals. Positional arguments, or a mix of positional and named +arguments, are not supported. For example, the following is invalid:: + + FROM vectorSearch('my_vectors', field='embedding', + vector='[0.1,0.2]', option='k=5') AS v + +A table alias is required. 
Projected fields are referenced through the +alias (``v._id``, ``v._score``, ``v.category``). + +If the ``opensearch-knn`` plugin is not installed on the target cluster, +query execution fails with a ``vectorSearch() requires the k-NN plugin`` +error. ``_explain`` continues to work without the plugin. + +Arguments +--------- + +- ``table``: single concrete index or alias to search. Wildcards + (``*``), comma-separated multi-index targets, ``_all``, ``.``, and + ``..`` are not supported. The target index must have + ``index.knn: true`` and map the target field as ``knn_vector``. A + normal alias name is accepted. If the alias resolves to multiple + backing indices, the SQL layer does not prevalidate that every + backing index has a compatible ``knn_vector`` mapping, dimension, or + engine; OpenSearch execution remains the source of truth for those + checks. +- ``field``: name of the ``knn_vector`` field. +- ``vector``: query vector as a JSON-style array of numbers, passed as a + string (for example, ``'[0.1, 0.2, 0.3]'``). Components must be + comma-separated finite numbers. Semicolon, colon, and pipe separators + are not supported, and empty components (for example, ``'[1.0,,2.0]'`` + or ``'[1.0,]'``) return an error. The vector dimension must match the + ``knn_vector`` mapping on the target index. +- ``option``: comma-separated ``key=value`` pairs. Exactly one of ``k``, + ``max_distance``, or ``min_score`` is required. ``filter_type`` is + optional. + +Supported option keys +--------------------- + +Option keys are lower-case and case-sensitive. ``K=5`` or +``Filter_Type=post`` returns an "Unknown option key" error. + +- ``k``: top-k mode. Integer between 1 and 10000. The query returns up to + ``k`` nearest neighbors. +- ``max_distance``: radial mode. Non-negative number. Matches documents + within the given distance of the query vector. ``LIMIT`` is required and + caps the returned rows. +- ``min_score``: radial mode. Non-negative number. 
Matches documents with + score at or above the given threshold. ``LIMIT`` is required and caps + the returned rows. +- ``filter_type``: ``post`` or ``efficient``. Controls how a ``WHERE`` + clause is applied. See `Filtering`_. + +``k``, ``max_distance``, and ``min_score`` are mutually exclusive; specify +exactly one. + +Native k-NN tuning options (for example, ``method_parameters.ef_search``, +``method_parameters.nprobes``, ``rescore.oversample_factor``) are not +supported through ``vectorSearch()`` and return an "Unknown option +key" error. + +Syntax +------ + +:: + + SELECT <projection> + FROM vectorSearch( + table='<index>', + field='<vector_field>', + vector='<vector>', + option='<options>' + ) AS <alias> + [WHERE <predicate>] + [ORDER BY <alias>._score DESC] + [LIMIT <n>] + +Example 1: Top-k +---------------- + +Return the five nearest neighbors of a query vector:: + + POST /_plugins/_sql + { + "query" : """ + SELECT v._id, v._score + FROM vectorSearch( + table='my_vectors', + field='embedding', + vector='[0.1, 0.2, 0.3]', + option='k=5' + ) AS v + """ + } + +In top-k mode, the request size defaults to ``k``; adding ``LIMIT n`` further +reduces the row count, but ``n`` must not exceed ``k``. + +Example 2: Radial search (``max_distance``) +------------------------------------------- + +Return up to the specified ``LIMIT`` documents within a maximum distance +of the query vector. ``LIMIT`` is required for radial searches; without +it the result set would be unbounded:: + + POST /_plugins/_sql + { + "query" : """ + SELECT v._id, v._score + FROM vectorSearch( + table='my_vectors', + field='embedding', + vector='[0.1, 0.2, 0.3]', + option='max_distance=0.5' + ) AS v + LIMIT 100 + """ + } + +Example 3: Radial search (``min_score``) +---------------------------------------- + +Return up to the specified ``LIMIT`` documents whose score is at or +above the given threshold. 
``LIMIT`` is required for radial searches; +without it the result set would be unbounded:: + + POST /_plugins/_sql + { + "query" : """ + SELECT v._id, v._score + FROM vectorSearch( + table='my_vectors', + field='embedding', + vector='[0.1, 0.2, 0.3]', + option='min_score=0.8' + ) AS v + LIMIT 100 + """ + } + +Filtering +========= + +A ``WHERE`` clause on non-vector fields of the ``vectorSearch()`` alias is +pushed down to OpenSearch when it can be translated to an OpenSearch filter. +Two placement strategies are available via the ``filter_type`` option: + +- ``efficient`` (default): the ``WHERE`` predicate is embedded directly + inside the k-NN query (``knn.filter``), enabling native efficient + k-NN filtering during vector search. Efficient filtering depends on + native k-NN engine and method support; if the target index does not + support ``knn.filter`` for the configured engine and method, set + ``filter_type=post``. See the `k-NN filtering guide + <https://opensearch.org/docs/latest/search-plugins/knn/filter-search-knn/>`_ + for engine and method requirements. +- ``post``: the k-NN query is placed in a scoring (``bool.must``) + context and the ``WHERE`` predicate is placed as a non-scoring + ``bool.filter`` outside the k-NN clause. This is Boolean filter + placement, not the REST ``post_filter`` parameter, and may return + fewer than ``k`` rows when the filter is selective. + +Full-text predicates (``match``, ``match_phrase``, ``multi_match``, and +the rest of the full-text family) under a ``WHERE`` clause are used as +filters, not as hybrid keyword-vector score fusion. Their placement +follows ``filter_type``: the default (``efficient``) embeds supported +full-text predicates under ``knn.filter``, while ``post`` places them +in ``bool.filter`` outside the k-NN clause. In both cases they restrict +which candidates are retained but their text relevance score does not +combine with the vector ``_score``. ``vectorSearch()`` is not a hybrid +vector + text relevance scorer. 
+ +Behavior depends on whether ``filter_type`` is specified: + +- **Omitted (default, ``efficient``)**: the ``WHERE`` predicate is + embedded under ``knn.filter`` so the k-NN engine applies native + efficient filtering during vector search. A query with no ``WHERE`` + clause is valid. ``efficient`` supports simple native filters: + ``term``, ``range``, ``wildcard``, ``exists``, full-text family + (``match``, ``match_phrase``, ``match_phrase_prefix``, + ``match_bool_prefix``, ``multi_match``, ``query_string``, + ``simple_query_string``), and boolean combinations of those filters. + Predicates that compile to script queries (arithmetic, function calls + on indexed fields, ``CASE``, date math), nested predicates, and other + query shapes are not supported under ``knn.filter`` and return an + error. Set ``filter_type=post`` to apply such predicates after the + k-NN search. If the predicate cannot be translated to an OpenSearch + filter query at all (a distinct translation failure from the + unsupported-shape cases above), the default path falls back to + evaluating the ``WHERE`` clause in memory after the k-NN results are + returned. +- **Explicit ``efficient``**: same contract as the default. Specifying + it is useful when a query should be explicit about the placement + strategy and should fail if the predicate cannot be safely embedded + under ``knn.filter``. +- **Explicit ``post``**: a ``WHERE`` clause is required and must be + translatable to an OpenSearch filter query. Predicates that translate + to native OpenSearch queries are pushed down as a ``bool.filter`` + alongside the k-NN query. Predicates that do not have a native + equivalent (for example, arithmetic or function calls on indexed + fields) are pushed down as an OpenSearch script query and evaluated + server-side. If predicate translation itself fails, the query returns + an error; there is no silent in-memory fallback under explicit + ``post``. 
Use ``filter_type=post`` when the predicate shape is not + supported by efficient filtering. + +Example 4: Default efficient filtering (no ``filter_type``) +----------------------------------------------------------- + +:: + + POST /_plugins/_sql + { + "query" : """ + SELECT v._id, v._score, v.category + FROM vectorSearch( + table='my_vectors', + field='embedding', + vector='[0.1, 0.2, 0.3]', + option='k=10' + ) AS v + WHERE v.category = 'books' + """ + } + +The predicate is embedded under ``knn.filter`` so the k-NN engine +applies native efficient filtering during vector search. + +Example 5: Post-filtering for predicates not supported by efficient mode +------------------------------------------------------------------------ + +Use ``filter_type=post`` for predicates that do not fit the ``efficient`` +allow-list, such as arithmetic or function calls on indexed fields:: + + POST /_plugins/_sql + { + "query" : """ + SELECT v._id, v._score, v.category + FROM vectorSearch( + table='my_vectors', + field='embedding', + vector='[0.1, 0.2, 0.3]', + option='k=10,filter_type=post' + ) AS v + WHERE v.price * 1.1 < 100 + """ + } + +Scoring, sorting, and limits +============================ + +- ``vectorSearch()`` exposes the OpenSearch ``_score`` metadata field on the + alias. For an alias ``v``, select it as ``v._score``. +- ``_score`` can be selected and referenced in ``ORDER BY``, but it cannot + appear in ``WHERE``. Use ``option='min_score=...'`` for score-threshold + vector search. +- Results are returned in ``_score DESC`` order by default. The only + supported ``ORDER BY`` expression is ``<alias>._score DESC`` (for + example, ``v._score DESC``). +- In top-k mode (``k=N``), ``LIMIT n`` is optional; when present, ``n`` must + be ``≤ k``. +- In radial mode (``max_distance`` or ``min_score``), ``LIMIT`` is required. +- ``OFFSET`` is not supported on ``vectorSearch()``. Use ``LIMIT`` only. 
+ +Limitations +=========== + +The following limitations apply to ``vectorSearch()``: + +- ``GROUP BY`` and aggregations directly over a ``vectorSearch()`` + relation are not supported and return an error. +- Operators wrapped around a ``vectorSearch()`` subquery are rejected + when they would run after ``vectorSearch()`` has already produced a + finite result set, because they would otherwise silently yield zero, skipped, or + incorrectly ordered rows. Specifically, an outer ``WHERE``, + ``ORDER BY``, ``OFFSET`` (non-zero), ``GROUP BY``, aggregation, or + ``DISTINCT`` applied to a ``vectorSearch()`` subquery returns an + error. Place ``WHERE`` predicates inside the subquery, directly on + the ``vectorSearch()`` alias, so that they participate in ``WHERE`` + pushdown. A plain outer ``LIMIT`` (without ``OFFSET``) wrapping a + ``vectorSearch()`` subquery is allowed and caps the returned rows. +- ``JOIN`` between a ``vectorSearch()`` relation and another relation is + not supported. +- ``UNION`` / ``INTERSECT`` / ``EXCEPT`` combining a ``vectorSearch()`` + relation with another relation is not supported. +- Multiple ``vectorSearch()`` calls in the same query are not supported. +- The query vector must be supplied as a literal. Parameterized vectors + (for example, values bound from another column) are not supported. +- Indexes that define a user field named ``_score`` cannot be queried + with ``vectorSearch()`` because ``_score`` is reserved for the + synthetic vector score exposed on the alias. Rename the field or query + the index with a plain ``SELECT``. 
diff --git a/docs/user/index.rst b/docs/user/index.rst index bb4b6399198..32ce39ed93d 100644 --- a/docs/user/index.rst +++ b/docs/user/index.rst @@ -43,6 +43,8 @@ OpenSearch SQL enables you to extract insights out of OpenSearch using the famil - `Window Functions `_ + - `Vector Search `_ + * **Beyond SQL** - `PartiQL (JSON) Support `_ diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/ExistsPushdownIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/ExistsPushdownIT.java new file mode 100644 index 00000000000..08ceb8c35f9 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/ExistsPushdownIT.java @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import java.io.IOException; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.junit.Test; +import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.legacy.TestsConstants; + +/** + * Explain-plan integration tests asserting that {@code IS NOT NULL} / {@code IS NULL} predicates + * push down as native OpenSearch {@code exists} DSL rather than as serialized script queries. + * + *

Before this change both predicates serialized through the compounded script engine, producing + * a {@code "script"} clause in the pushdown DSL. After this change the v2 filter builder emits + * {@code {"exists": {"field": ...}}} directly for {@code IS NOT NULL}, and a {@code bool} query + * with a single {@code must_not[exists]} child for {@code IS NULL}. This matches what downstream + * tooling, serverless / AOSS, and the Calcite path already produce. + */ +public class ExistsPushdownIT extends SQLIntegTestCase { + + // Anchored on the surrounding `sourceBuilder=...`, `pitId=` tokens in OpenSearchRequest's + // toString() output. Test-only coupling: if that request-string format changes (token renamed, + // pitId removed), this helper breaks even when the DSL shape is still correct. Update the regex + // anchors if that happens. + private static final Pattern SOURCE_BUILDER_JSON = + Pattern.compile("sourceBuilder=(\\{.*?\\}), pitId=", Pattern.DOTALL); + + /** Extracts and unescapes the sourceBuilder JSON embedded in the explain request string. 
*/ + private static String extractSourceBuilderJson(String explain) { + Matcher m = SOURCE_BUILDER_JSON.matcher(explain); + assertTrue("Explain should contain sourceBuilder JSON:\n" + explain, m.find()); + return m.group(1).replace("\\\"", "\""); + } + + @Override + protected void init() throws Exception { + loadIndex(Index.ACCOUNT); + } + + private static final String TEST_INDEX = TestsConstants.TEST_INDEX_ACCOUNT; + + @Test + public void testIsNotNullPushesDownAsExistsQuery() throws IOException { + String explain = + explainQuery("SELECT age FROM " + TEST_INDEX + " WHERE age IS NOT NULL LIMIT 1"); + String sourceBuilder = extractSourceBuilderJson(explain); + + assertTrue( + "IS NOT NULL should push down as native exists DSL:\n" + sourceBuilder, + sourceBuilder.contains("\"exists\"")); + assertTrue( + "IS NOT NULL exists DSL should target the 'age' field:\n" + sourceBuilder, + sourceBuilder.contains("\"field\":\"age\"")); + assertFalse( + "IS NOT NULL should not fall through to a script query:\n" + sourceBuilder, + sourceBuilder.contains("\"script\"")); + } + + @Test + public void testIsNullPushesDownAsMustNotExistsQuery() throws IOException { + String explain = explainQuery("SELECT age FROM " + TEST_INDEX + " WHERE age IS NULL LIMIT 1"); + String sourceBuilder = extractSourceBuilderJson(explain); + + assertTrue( + "IS NULL should push down as bool/must_not[exists] DSL:\n" + sourceBuilder, + sourceBuilder.contains("\"must_not\"")); + assertTrue( + "IS NULL should wrap a native exists clause:\n" + sourceBuilder, + sourceBuilder.contains("\"exists\"")); + assertTrue( + "IS NULL exists DSL should target the 'age' field:\n" + sourceBuilder, + sourceBuilder.contains("\"field\":\"age\"")); + assertFalse( + "IS NULL should not fall through to a script query:\n" + sourceBuilder, + sourceBuilder.contains("\"script\"")); + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchExecutionIT.java 
b/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchExecutionIT.java new file mode 100644 index 00000000000..36e78567d54 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchExecutionIT.java @@ -0,0 +1,227 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import static org.opensearch.sql.util.TestUtils.createIndexByRestClient; +import static org.opensearch.sql.util.TestUtils.isIndexExist; +import static org.opensearch.sql.util.TestUtils.performRequest; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; +import org.json.JSONArray; +import org.json.JSONObject; +import org.junit.Assume; +import org.junit.Test; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.sql.legacy.SQLIntegTestCase; + +/** + * Happy-path execution tests for the vectorSearch() SQL table function. These tests run an actual + * k-NN query against a small in-memory knn_vector index and assert that results come back ordered + * by score and respect any WHERE filters. + * + *

The k-NN plugin is not provisioned by the default integ-test cluster — each test calls {@link + * Assume#assumeTrue} on {@link #isKnnPluginInstalled()} so the class is silently skipped when k-NN + * is absent. Run locally against a cluster that has opensearch-knn installed. Provisioning k-NN in + * CI is a separate follow-up. + */ +public class VectorSearchExecutionIT extends SQLIntegTestCase { + + private static final String TEST_INDEX = "vector_exec_test"; + + // 6 docs in 2D — two clusters so filter/radial tests have distinguishable results. + // Cluster A near [1, 1]: docs 1-3 (state=TX, ages 25/30/40). + // Cluster B near [9, 9]: docs 4-6 (state=CA, ages 28/35/45). + // Pin Lucene HNSW + L2 so efficient filtering is deterministic (k-NN supports efficient + // filtering only on lucene+hnsw and faiss+hnsw/ivf) and the L2 → 1/(1+d) scoring used by the + // radial min_score test is well-defined. + private static final String MAPPING = + "{" + + " \"settings\": {\"index\": {\"knn\": true}}," + + " \"mappings\": {" + + " \"properties\": {" + + " \"embedding\": {" + + " \"type\": \"knn_vector\"," + + " \"dimension\": 2," + + " \"method\": {" + + " \"name\": \"hnsw\"," + + " \"engine\": \"lucene\"," + + " \"space_type\": \"l2\"" + + " }" + + " }," + + " \"state\": {\"type\": \"keyword\"}," + + " \"age\": {\"type\": \"integer\"}" + + " }" + + " }" + + "}"; + + private static final String BULK_BODY = + "{\"index\":{\"_id\":\"1\"}}\n" + + "{\"embedding\":[1.0,1.0],\"state\":\"TX\",\"age\":25}\n" + + "{\"index\":{\"_id\":\"2\"}}\n" + + "{\"embedding\":[1.1,0.9],\"state\":\"TX\",\"age\":30}\n" + + "{\"index\":{\"_id\":\"3\"}}\n" + + "{\"embedding\":[0.9,1.2],\"state\":\"TX\",\"age\":40}\n" + + "{\"index\":{\"_id\":\"4\"}}\n" + + "{\"embedding\":[9.0,9.0],\"state\":\"CA\",\"age\":28}\n" + + "{\"index\":{\"_id\":\"5\"}}\n" + + "{\"embedding\":[9.1,8.8],\"state\":\"CA\",\"age\":35}\n" + + "{\"index\":{\"_id\":\"6\"}}\n" + + 
"{\"embedding\":[8.7,9.3],\"state\":\"CA\",\"age\":45}\n"; + + @Override + protected void init() throws Exception { + Assume.assumeTrue("k-NN plugin not installed on test cluster", isKnnPluginInstalled()); + if (!isIndexExist(client(), TEST_INDEX)) { + createIndexByRestClient(client(), TEST_INDEX, MAPPING); + Request bulk = new Request("POST", "/" + TEST_INDEX + "/_bulk?refresh=true"); + bulk.setJsonEntity(BULK_BODY); + performRequest(client(), bulk); + } + } + + private static boolean isKnnPluginInstalled() { + try { + Response response = client().performRequest(new Request("GET", "/_cat/plugins?h=component")); + String body = new String(response.getEntity().getContent().readAllBytes()); + return body.contains("opensearch-knn"); + } catch (IOException e) { + return false; + } + } + + // ── Top-k happy path ──────────────────────────────────────────────── + + @Test + public void testTopKReturnsNearestSortedByScore() throws IOException { + JSONObject result = + executeJdbcRequest( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 1.0]', option='k=3') AS v " + + "LIMIT 3"); + + // All 3 returned docs should be from cluster A (ids 1-3), ordered by score desc. + JSONArray rows = result.getJSONArray("datarows"); + assertEquals("Expected 3 rows:\n" + result, 3, rows.length()); + for (int i = 0; i < rows.length(); i++) { + String id = rows.getJSONArray(i).getString(0); + assertTrue( + "Row " + i + " id=" + id + " should be from cluster A (1,2,3):\n" + result, + id.equals("1") || id.equals("2") || id.equals("3")); + } + // Scores must be non-increasing. 
+ double prev = Double.POSITIVE_INFINITY; + for (int i = 0; i < rows.length(); i++) { + double score = rows.getJSONArray(i).getDouble(1); + assertTrue( + "Scores must be sorted desc, got " + score + " after " + prev + ":\n" + result, + score <= prev); + prev = score; + } + } + + // ── POST filter happy path ────────────────────────────────────────── + + @Test + public void testPostFilterReturnsOnlyMatchingDocs() throws IOException { + // Query from cluster B with WHERE state='TX' forces POST filtering to surface TX docs + // (cluster A) even though the vector is closer to cluster B. k=10 covers all 6 docs so + // post-filtering to state='TX' deterministically yields exactly {1,2,3}. filter_type=post + // is specified explicitly because the default placement is EFFICIENT — this test + // guarantees POST continues to work when the user opts into it. + JSONObject result = + executeJdbcRequest( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[9.0, 9.0]', option='k=10,filter_type=post') AS v " + + "WHERE v.state = 'TX' " + + "LIMIT 10"); + + assertRowIdsEqual(result, "1", "2", "3"); + } + + // ── EFFICIENT filter happy path ───────────────────────────────────── + + @Test + public void testEfficientFilterReturnsOnlyMatchingDocs() throws IOException { + // Query vector sits on cluster A (TX) but WHERE state='CA' forces EFFICIENT filtering to + // navigate HNSW toward CA docs. With k=3, a POST-filter implementation would return 0 rows + // (the 3 nearest candidates are all TX, which get filtered out); an efficient-filter + // implementation returns exactly the 3 CA docs {4,5,6}. This asymmetry makes the test + // discriminate between the two filter modes. 
+ JSONObject result = + executeJdbcRequest( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 1.0]', option='k=3,filter_type=efficient') AS v " + + "WHERE v.state = 'CA' " + + "LIMIT 3"); + + assertRowIdsEqual(result, "4", "5", "6"); + } + + // ── Radial happy paths ────────────────────────────────────────────── + + @Test + public void testRadialMaxDistanceReturnsOnlyNearDocs() throws IOException { + // max_distance=1.0 (L2) centered on [1,1] includes all 3 cluster A docs (max L2 ≈ 0.22) + // and excludes cluster B which is ~11 units away. + JSONObject result = + executeJdbcRequest( + "SELECT v._id " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 1.0]', option='max_distance=1.0') AS v " + + "LIMIT 10"); + + assertRowIdsEqual(result, "1", "2", "3"); + } + + @Test + public void testRadialMinScoreReturnsOnlyHighScoreDocs() throws IOException { + // For L2 space, OpenSearch score = 1/(1+distance). Centered on [1,1], cluster A docs + // score ~0.82-1.0 and cluster B scores ~0.08. min_score=0.5 yields exactly {1,2,3}. + JSONObject result = + executeJdbcRequest( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 1.0]', option='min_score=0.5') AS v " + + "LIMIT 10"); + + JSONArray rows = result.getJSONArray("datarows"); + for (int i = 0; i < rows.length(); i++) { + double score = rows.getJSONArray(i).getDouble(1); + assertTrue("Row " + i + " score=" + score + " should be >= 0.5:\n" + result, score >= 0.5); + } + assertRowIdsEqual(result, "1", "2", "3"); + } + + /** Asserts the result's datarows column 0 contains exactly the given ids (as a set). */ + private static void assertRowIdsEqual(JSONObject result, String... 
expectedIds) { + JSONArray rows = result.getJSONArray("datarows"); + assertEquals( + "Expected " + expectedIds.length + " rows:\n" + result, expectedIds.length, rows.length()); + Set expected = new HashSet<>(Arrays.asList(expectedIds)); + Set actual = new HashSet<>(); + for (int i = 0; i < rows.length(); i++) { + actual.add(rows.getJSONArray(i).getString(0)); + } + assertEquals("Row id set mismatch:\n" + result, expected, actual); + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchExplainIT.java new file mode 100644 index 00000000000..8719189b13a --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchExplainIT.java @@ -0,0 +1,559 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.junit.Test; +import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.legacy.TestsConstants; + +/** + * Explain-plan integration tests for vectorSearch SQL table function. These tests verify DSL + * push-down shape via _explain. They do NOT require the k-NN plugin since _explain only parses and + * plans the query without executing it against a knn index. + */ +public class VectorSearchExplainIT extends SQLIntegTestCase { + + // Matches WrapperQueryBuilder's base64 payload in explain JSON. The explain output escapes + // quotes as \", so the regex tolerates both \" and " forms around the query key/value. + private static final Pattern WRAPPER_PAYLOAD = + Pattern.compile("\\\\?\"query\\\\?\":\\\\?\"([A-Za-z0-9+/=]+)\\\\?\""); + // Anchored on the surrounding `sourceBuilder=...`, `pitId=` tokens in OpenSearchRequest's + // toString() output. 
Test-only coupling: if that request-string format changes (token renamed, + // pitId removed), this helper breaks even when the DSL shape is still correct. Update the regex + // anchors if that happens. + private static final Pattern SOURCE_BUILDER_JSON = + Pattern.compile("sourceBuilder=(\\{.*?\\}), pitId=", Pattern.DOTALL); + + /** Decodes every base64-encoded wrapper payload in the explain output into its knn JSON. */ + private static List decodeWrapperKnnJsons(String explain) { + List payloads = new ArrayList<>(); + Matcher m = WRAPPER_PAYLOAD.matcher(explain); + while (m.find()) { + payloads.add(new String(Base64.getDecoder().decode(m.group(1)), StandardCharsets.UTF_8)); + } + return payloads; + } + + /** Returns the single wrapper knn JSON, asserting exactly one is present. */ + private static String decodeSoleKnnJson(String explain) { + List payloads = decodeWrapperKnnJsons(explain); + assertEquals( + "Expected exactly one wrapper query payload in explain:\n" + explain, 1, payloads.size()); + return payloads.get(0); + } + + /** Extracts and unescapes the sourceBuilder JSON embedded in the explain request string. */ + private static String extractSourceBuilderJson(String explain) { + Matcher m = SOURCE_BUILDER_JSON.matcher(explain); + assertTrue("Explain should contain sourceBuilder JSON:\n" + explain, m.find()); + return m.group(1).replace("\\\"", "\""); + } + + @Override + protected void init() throws Exception { + // _explain needs the index to exist for field resolution. 
+ loadIndex(Index.ACCOUNT); + } + + private static final String TEST_INDEX = TestsConstants.TEST_INDEX_ACCOUNT; + + // ── Top-k / radial DSL shape ───────────────────────────────────────── + + @Test + public void testExplainTopKProducesKnnQuery() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0, 3.0]', option='k=5') AS v " + + "LIMIT 5"); + + assertTrue( + "Explain should contain track_scores:\n" + explain, explain.contains("track_scores")); + + // Top-k without WHERE should have the knn at the root, not wrapped in an outer bool. + String sourceBuilderJson = extractSourceBuilderJson(explain); + assertFalse( + "Top-k without WHERE should not wrap knn in an outer bool:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"bool\"")); + + String knnJson = decodeSoleKnnJson(explain); + assertTrue("knn JSON should contain knn key:\n" + knnJson, knnJson.contains("\"knn\"")); + assertTrue( + "knn JSON should target the embedding field:\n" + knnJson, + knnJson.contains("\"embedding\"")); + assertTrue( + "knn JSON should contain the vector values:\n" + knnJson, + knnJson.contains("[1.0,2.0,3.0]")); + assertTrue("knn JSON should contain k=5:\n" + knnJson, knnJson.contains("\"k\":5")); + assertFalse( + "Top-k without WHERE should not embed a filter:\n" + knnJson, knnJson.contains("filter")); + } + + @Test + public void testExplainRadialMaxDistanceProducesKnnQuery() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='max_distance=10.5') AS v " + + "LIMIT 100"); + + // Radial without WHERE should have the knn at the root, not wrapped in an outer bool. 
+ String sourceBuilderJson = extractSourceBuilderJson(explain); + assertFalse( + "Radial without WHERE should not wrap knn in an outer bool:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"bool\"")); + + String knnJson = decodeSoleKnnJson(explain); + assertTrue("knn JSON should contain knn key:\n" + knnJson, knnJson.contains("\"knn\"")); + assertTrue( + "knn JSON should target the embedding field:\n" + knnJson, + knnJson.contains("\"embedding\"")); + assertTrue( + "knn JSON should contain the vector values:\n" + knnJson, knnJson.contains("[1.0,2.0]")); + assertTrue( + "knn JSON should contain max_distance=10.5:\n" + knnJson, + knnJson.contains("\"max_distance\":10.5")); + assertFalse( + "Radial without WHERE should not embed a filter:\n" + knnJson, knnJson.contains("filter")); + } + + @Test + public void testExplainRadialMinScoreProducesKnnQuery() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='min_score=0.8') AS v " + + "LIMIT 100"); + + // Radial without WHERE should have the knn at the root, not wrapped in an outer bool. 
+ String sourceBuilderJson = extractSourceBuilderJson(explain); + assertFalse( + "Radial without WHERE should not wrap knn in an outer bool:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"bool\"")); + + String knnJson = decodeSoleKnnJson(explain); + assertTrue("knn JSON should contain knn key:\n" + knnJson, knnJson.contains("\"knn\"")); + assertTrue( + "knn JSON should target the embedding field:\n" + knnJson, + knnJson.contains("\"embedding\"")); + assertTrue( + "knn JSON should contain the vector values:\n" + knnJson, knnJson.contains("[1.0,2.0]")); + assertTrue( + "knn JSON should contain min_score=0.8:\n" + knnJson, + knnJson.contains("\"min_score\":0.8")); + assertFalse( + "Radial without WHERE should not embed a filter:\n" + knnJson, knnJson.contains("filter")); + } + + // ── Default (EFFICIENT) pre-filter DSL shape ──────────────────────── + + @Test + public void testExplainDefaultFilterProducesKnnWithFilter() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0, 3.0]', option='k=10') AS v " + + "WHERE v.state = 'TX' " + + "LIMIT 10"); + + // Default (EFFICIENT) shape: WHERE is embedded inside knn.filter, the knn JSON is base64- + // encoded inside a WrapperQueryBuilder, and there is no outer bool/must wrapping. 
+ String sourceBuilderJson = extractSourceBuilderJson(explain); + assertFalse( + "Default EFFICIENT mode should not produce bool query:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"bool\"")); + assertFalse( + "Default EFFICIENT mode should not contain must clause:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"must\"")); + + String knnJson = decodeSoleKnnJson(explain); + assertTrue("knn JSON should contain knn key:\n" + knnJson, knnJson.contains("\"knn\"")); + assertTrue( + "knn JSON should target the embedding field:\n" + knnJson, + knnJson.contains("\"embedding\"")); + assertTrue("knn JSON should contain k=10:\n" + knnJson, knnJson.contains("\"k\":10")); + assertTrue( + "Default EFFICIENT mode must embed filter inside knn:\n" + knnJson, + knnJson.contains("filter")); + assertTrue( + "Default EFFICIENT mode must embed the WHERE predicate inside knn:\n" + knnJson, + knnJson.contains("state")); + } + + @Test + public void testExplainDefaultCompoundPredicateProducesKnnWithFilter() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0, 3.0]', option='k=10') AS v " + + "WHERE v.state = 'TX' AND v.age > 30 " + + "LIMIT 10"); + + // Compound default-mode WHERE must also route through knn.filter: no outer bool/must, and + // both predicate fields embedded inside the knn payload. 
+ String sourceBuilderJson = extractSourceBuilderJson(explain); + assertFalse( + "Default EFFICIENT mode should not produce bool query:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"bool\"")); + assertFalse( + "Default EFFICIENT mode should not contain must clause:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"must\"")); + + String knnJson = decodeSoleKnnJson(explain); + assertTrue("knn JSON should contain knn key:\n" + knnJson, knnJson.contains("\"knn\"")); + assertTrue( + "knn JSON should target the embedding field:\n" + knnJson, + knnJson.contains("\"embedding\"")); + assertTrue("knn JSON should contain k=10:\n" + knnJson, knnJson.contains("\"k\":10")); + assertTrue( + "Compound default EFFICIENT must embed filter inside knn:\n" + knnJson, + knnJson.contains("filter")); + assertTrue( + "Compound default EFFICIENT must embed the state predicate inside knn:\n" + knnJson, + knnJson.contains("state")); + assertTrue( + "Compound default EFFICIENT must embed the age predicate inside knn:\n" + knnJson, + knnJson.contains("age")); + } + + @Test + public void testExplainDefaultRadialWithWhereProducesKnnWithFilter() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='max_distance=10.5') AS v " + + "WHERE v.state = 'TX' " + + "LIMIT 100"); + + // Radial + default WHERE must also use the EFFICIENT shape: no outer bool/must, radial + // parameters preserved inside the knn payload, and the WHERE predicate embedded alongside. 
+ String sourceBuilderJson = extractSourceBuilderJson(explain); + assertFalse( + "Default EFFICIENT mode should not produce bool query:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"bool\"")); + assertFalse( + "Default EFFICIENT mode should not contain must clause:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"must\"")); + + String knnJson = decodeSoleKnnJson(explain); + assertTrue("knn JSON should contain knn key:\n" + knnJson, knnJson.contains("\"knn\"")); + assertTrue( + "knn JSON should target the embedding field:\n" + knnJson, + knnJson.contains("\"embedding\"")); + assertTrue( + "knn JSON should contain max_distance=10.5:\n" + knnJson, + knnJson.contains("\"max_distance\":10.5")); + assertTrue( + "Radial default EFFICIENT must embed filter inside knn:\n" + knnJson, + knnJson.contains("filter")); + assertTrue( + "Radial default EFFICIENT must embed the WHERE predicate inside knn:\n" + knnJson, + knnJson.contains("state")); + } + + // ── Sort + LIMIT explain ───────────────────────────────────────────── + + @Test + public void testOrderByScoreDescExplainSucceeds() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5') AS v " + + "ORDER BY v._score DESC " + + "LIMIT 5"); + + assertTrue( + "Explain should succeed with ORDER BY _score DESC:\n" + explain, + explain.contains("wrapper")); + } + + @Test + public void testExplainLimitWithinKSucceeds() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=10') AS v " + + "LIMIT 5"); + + assertTrue("Explain should succeed with LIMIT <= k:\n" + explain, explain.contains("wrapper")); + } + + // ── filter_type explain ───────────────────────────────────────────── + + @Test + public void 
testExplainFilterTypePostProducesBoolQuery() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0, 3.0]', option='k=10,filter_type=post') AS v " + + "WHERE v.state = 'TX' " + + "LIMIT 10"); + + // Explicit filter_type=post must produce the post-filter shape bool.must=[knn] / + // bool.filter=[term] rather than the default efficient shape, and the WHERE predicate must NOT + // leak into the knn payload (that would be efficient mode). This is the key false-positive + // guard: substring-only checks would pass for efficient mode too. + String sourceBuilderJson = extractSourceBuilderJson(explain); + assertTrue( + "Explain should contain bool query:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"bool\"")); + assertTrue( + "Explain should contain must:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"must\"")); + assertTrue( + "Explain should contain filter:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"filter\"")); + assertTrue( + "Explain should contain the outer state predicate:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"state.keyword\"")); + + String knnJson = decodeSoleKnnJson(explain); + assertTrue("knn JSON should contain knn key:\n" + knnJson, knnJson.contains("\"knn\"")); + assertTrue( + "knn JSON should target the embedding field:\n" + knnJson, + knnJson.contains("\"embedding\"")); + assertFalse( + "filter_type=post must not embed the WHERE predicate inside knn:\n" + knnJson, + knnJson.contains("state")); + assertFalse( + "filter_type=post must not embed a filter inside knn:\n" + knnJson, + knnJson.contains("filter")); + } + + @Test + public void testExplainFilterTypeEfficientProducesKnnWithFilter() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5,filter_type=efficient') AS v " + + 
"WHERE v.state = 'TX' " + + "LIMIT 5"); + + // Efficient mode: knn rebuilt with filter inside, wrapped in WrapperQueryBuilder. + // The knn JSON (including the embedded filter) is base64-encoded inside the wrapper, + // so we verify structure by: (1) no bool/must in plaintext (that would be post-filter shape), + // (2) decode the base64 payload to confirm the filter and predicate field are embedded inside + // the knn query. + String sourceBuilderJson = extractSourceBuilderJson(explain); + assertFalse( + "Efficient mode should not produce bool query (that is post-filter shape):\n" + + sourceBuilderJson, + sourceBuilderJson.contains("\"bool\"")); + assertFalse( + "Efficient mode should not contain must clause:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"must\"")); + + String knnJson = decodeSoleKnnJson(explain); + assertTrue( + "Efficient mode knn JSON should contain filter:\n" + knnJson, knnJson.contains("filter")); + assertTrue( + "Efficient mode knn JSON should contain the WHERE predicate field:\n" + knnJson, + knnJson.contains("state")); + } + + @Test + public void testEfficientFilterWithOrderByScoreDescSucceeds() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5,filter_type=efficient') AS v " + + "WHERE v.state = 'TX' " + + "ORDER BY v._score DESC " + + "LIMIT 5"); + + // Same efficient-mode shape guarantee as testExplainFilterTypeEfficientProducesKnnWithFilter, + // with an added ORDER BY _score DESC: no outer bool/must, and the WHERE predicate must be + // embedded inside the knn payload (efficient filtering, not post-filter). 
+ String sourceBuilderJson = extractSourceBuilderJson(explain); + assertFalse( + "Efficient mode should not produce bool query (that is post-filter shape):\n" + + sourceBuilderJson, + sourceBuilderJson.contains("\"bool\"")); + assertFalse( + "Efficient mode should not contain must clause:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"must\"")); + + String knnJson = decodeSoleKnnJson(explain); + assertTrue( + "Efficient mode knn JSON should contain filter:\n" + knnJson, knnJson.contains("filter")); + assertTrue( + "Efficient mode knn JSON should contain the WHERE predicate field:\n" + knnJson, + knnJson.contains("state")); + } + + // ── BETWEEN / NOT IN pushdown regression guards ───────────────────── + // These tests lock in the DSL shape currently produced for BETWEEN and NOT IN predicates + // when pushed down through vectorSearch(). They exist to catch silent regressions where a + // change in the v2 FilterQueryBuilder pipeline would fall back to a serialized script query + // instead of the native range/bool shape the cluster can index-accelerate. + + @Test + public void testBetweenPushesAsRange() throws IOException { + // Pin filter_type=post to keep the regression guard aimed at the post-filter serialization + // path: these assertions lock in the outer bool/must/filter shape that only appears when + // WHERE is applied alongside knn rather than embedded under knn.filter. + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0, 3.0]', option='k=10,filter_type=post') AS v " + + "WHERE v.balance BETWEEN 50 AND 200 " + + "LIMIT 10"); + + // BETWEEN is desugared by the analyzer into AND(>=, <=), which FilterQueryBuilder renders as + // two range clauses combined under a bool. The goal here is regression lock-in: ensure the + // pushed filter is native range DSL, not a serialized script query. 
+ String sourceBuilderJson = extractSourceBuilderJson(explain); + assertTrue( + "Explain should contain bool query:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"bool\"")); + assertTrue( + "Explain should contain must clause (knn in scoring context):\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"must\"")); + assertTrue( + "Explain should contain filter clause (WHERE in non-scoring context):\n" + + sourceBuilderJson, + sourceBuilderJson.contains("\"filter\"")); + assertTrue( + "BETWEEN should push as native range DSL:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"range\"")); + assertTrue( + "Range should target balance field:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"balance\"")); + // RangeQueryBuilder serializes inclusive bounds as from/to + include_lower/include_upper. Lock + // in that both the lower bound (50) and upper bound (200) are present in the pushed DSL. + assertTrue( + "Range should contain lower bound 50:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"from\" : 50") || sourceBuilderJson.contains("\"from\":50")); + assertTrue( + "Range should contain upper bound 200:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"to\" : 200") || sourceBuilderJson.contains("\"to\":200")); + // Script-query fallback sentinel: the CompoundedScriptEngine lang marker must NOT appear when + // BETWEEN is pushed down natively. + assertFalse( + "BETWEEN must not fall back to a serialized script query:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"script\"")); + + // Post-filter mode (pinned above): the WHERE predicate must live OUTSIDE the knn payload. 
+ String knnJson = decodeSoleKnnJson(explain); + assertTrue("knn JSON should contain knn key:\n" + knnJson, knnJson.contains("\"knn\"")); + assertFalse( + "Post-filter mode must not embed the balance predicate inside knn:\n" + knnJson, + knnJson.contains("balance")); + assertFalse( + "Post-filter mode must not embed a range inside knn:\n" + knnJson, + knnJson.contains("range")); + } + + @Test + public void testNotInPushesAsMustNotTerms() throws IOException { + // Pin filter_type=post to keep the regression guard aimed at the post-filter serialization + // path: these assertions lock in the outer bool/must/filter shape that only appears when + // WHERE is applied alongside knn rather than embedded under knn.filter. + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0, 3.0]', option='k=10,filter_type=post') AS v " + + "WHERE v.gender NOT IN ('M', 'F') " + + "LIMIT 10"); + + // v2 analyzer desugars `x NOT IN (a, b)` into `NOT(x = a OR x = b)`. FilterQueryBuilder maps + // NOT to bool.must_not and OR to bool.should, so the pushed DSL is must_not[should[term,term]] + // rather than a single terms clause. The shape we're locking in is: native bool with must_not + // on the keyword subfield, *not* a serialized script query. 
+ String sourceBuilderJson = extractSourceBuilderJson(explain); + assertTrue( + "Explain should contain bool query:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"bool\"")); + assertTrue( + "Explain should contain must clause (knn in scoring context):\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"must\"")); + assertTrue( + "Explain should contain filter clause (WHERE in non-scoring context):\n" + + sourceBuilderJson, + sourceBuilderJson.contains("\"filter\"")); + assertTrue( + "NOT IN should push as bool.must_not:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"must_not\"")); + // OR-of-equals desugaring means the two literals land in a bool.should of term clauses. + assertTrue( + "NOT IN should contain should clause for OR-of-equals desugaring:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"should\"")); + assertTrue( + "NOT IN should produce term clauses for each literal:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"term\"")); + // Terms target the keyword subfield of gender (text field with .keyword multi-field). + assertTrue( + "NOT IN term clauses should target gender.keyword:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"gender.keyword\"")); + // Both literals must be present in the pushed DSL. + assertTrue( + "NOT IN should contain the 'M' literal:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"M\"")); + assertTrue( + "NOT IN should contain the 'F' literal:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"F\"")); + // Script-query fallback sentinel: native pushdown must not degrade to a serialized script. + assertFalse( + "NOT IN must not fall back to a serialized script query:\n" + sourceBuilderJson, + sourceBuilderJson.contains("\"script\"")); + + // Post-filter mode (pinned above): the WHERE predicate must live OUTSIDE the knn payload. 
+ String knnJson = decodeSoleKnnJson(explain); + assertTrue("knn JSON should contain knn key:\n" + knnJson, knnJson.contains("\"knn\"")); + assertFalse( + "Post-filter mode must not embed the gender predicate inside knn:\n" + knnJson, + knnJson.contains("gender")); + assertFalse( + "Post-filter mode must not embed must_not inside knn:\n" + knnJson, + knnJson.contains("must_not")); + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchIT.java new file mode 100644 index 00000000000..c10b3a219f6 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchIT.java @@ -0,0 +1,755 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import static org.hamcrest.Matchers.containsString; + +import java.io.IOException; +import org.junit.Test; +import org.opensearch.client.Request; +import org.opensearch.client.ResponseException; +import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.legacy.TestsConstants; + +/** + * Integration tests for vectorSearch SQL table function — validation and error paths. These tests + * verify that invalid inputs are rejected with clear error messages. Explain-plan DSL shape tests + * live in {@link VectorSearchExplainIT}. 
+ */ +public class VectorSearchIT extends SQLIntegTestCase { + + @Override + protected void init() throws Exception { + loadIndex(Index.ACCOUNT); + } + + private static final String TEST_INDEX = TestsConstants.TEST_INDEX_ACCOUNT; + + // ── Validation error paths ──────────────────────────────────────────── + + @Test + public void testMutualExclusivityRejectsKAndMaxDistance() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='k=5,max_distance=10') AS v")); + + assertThat(ex.getMessage(), containsString("Only one of")); + } + + @Test + public void testMutualExclusivityRejectsKAndMinScore() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='k=5,min_score=0.5') AS v")); + + assertThat(ex.getMessage(), containsString("Only one of")); + } + + @Test + public void testKTooLargeRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='k=10001') AS v")); + + assertThat(ex.getMessage(), containsString("k must be between 1 and 10000")); + } + + @Test + public void testKZeroRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='k=0') AS v")); + + assertThat(ex.getMessage(), containsString("k must be between 1 and 10000")); + } + + @Test + public void testUnknownOptionKeyRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', 
option='k=5,method.ef_search=100') AS v")); + + assertThat(ex.getMessage(), containsString("Unknown option key")); + } + + @Test + public void testEmptyVectorRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[]', option='k=5') AS v")); + + assertThat(ex.getMessage(), containsString("must not be empty")); + } + + @Test + public void testInvalidFieldNameRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', " + + "field='field\\\"injection', vector='[1.0]', option='k=5') AS v")); + + assertThat(ex.getMessage(), containsString("Invalid field name")); + } + + @Test + public void testMissingRequiredOptionRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='') AS v")); + + assertThat(ex.getMessage(), containsString("Missing required option")); + } + + @Test + public void testRadialWithoutLimitRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='max_distance=10.5') AS v")); + + assertThat(ex.getMessage(), containsString("LIMIT is required for radial vector search")); + } + + // ── Sort restriction validation ───────────────────────────────────────── + + @Test + public void testOrderByNonScoreFieldRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5') AS v " + + "ORDER BY v.firstname ASC 
" + + "LIMIT 5")); + + assertThat(ex.getMessage(), containsString("unsupported sort expression")); + } + + @Test + public void testOrderByScoreAscRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5') AS v " + + "ORDER BY v._score ASC " + + "LIMIT 5")); + + assertThat(ex.getMessage(), containsString("_score ASC is not supported")); + } + + // ── filter_type validation ──────────────────────────────────────────── + + @Test + public void testFilterTypeEfficientWithoutWhereRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5,filter_type=efficient') AS v " + + "LIMIT 5")); + + assertThat(ex.getMessage(), containsString("filter_type requires a pushdownable WHERE clause")); + } + + @Test + public void testFilterTypePostWithoutWhereRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5,filter_type=post') AS v " + + "LIMIT 5")); + + assertThat(ex.getMessage(), containsString("filter_type requires a pushdownable WHERE clause")); + } + + @Test + public void testInvalidFilterTypeRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='k=5,filter_type=bogus') AS v")); + + assertThat(ex.getMessage(), containsString("filter_type must be one of")); + } + + @Test + public void testGroupByRejects() throws IOException { + ResponseException ex = + expectThrows( 
+ ResponseException.class, + () -> + executeQuery( + "SELECT v.gender, COUNT(*) FROM vectorSearch(table='" + + TEST_INDEX + + "', field='f', vector='[1.0]', option='k=5') AS v GROUP BY v.gender")); + + assertThat( + ex.getMessage(), + containsString("Aggregations are not supported on vectorSearch() relations")); + } + + @Test + public void testBareAggregateRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT COUNT(*) FROM vectorSearch(table='" + + TEST_INDEX + + "', field='f', vector='[1.0]', option='k=5') AS v")); + + assertThat( + ex.getMessage(), + containsString("Aggregations are not supported on vectorSearch() relations")); + } + + // ── OFFSET / WHERE _score / filter_type=efficient script rejection ─── + + @Test + public void testOffsetRejected() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5') AS v " + + "LIMIT 5 OFFSET 2")); + + assertThat(ex.getMessage(), containsString("OFFSET is not supported on vectorSearch()")); + assertThat(ex.getMessage(), containsString("LIMIT only")); + } + + @Test + public void testScoreInWhereRejected() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5') AS v " + + "WHERE v._score > 0.5 " + + "LIMIT 5")); + + assertThat(ex.getMessage(), containsString("WHERE on _score is not supported")); + assertThat(ex.getMessage(), containsString("min_score")); + } + + @Test + public void testOrderByScoreDescLimitOffsetRejected() throws IOException { + // The natural user shape pairs sort with pagination: ORDER BY _score DESC LIMIT N OFFSET M. 
+ // The planner's pushDownSort() path can collapse the sort+limit into a top-k size, so OFFSET + // must still be rejected by pushDownLimit when the combined form is used. Without this guard + // the parent builder would push a non-zero `from` offset and silently shift the top-k window. + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5') AS v " + + "ORDER BY v._score DESC " + + "LIMIT 5 OFFSET 2")); + + assertThat(ex.getMessage(), containsString("OFFSET is not supported on vectorSearch()")); + } + + @Test + public void testEfficientModeRejectsScriptPredicate() throws IOException { + // WHERE age + 1 > 30 compiles to a ScriptQueryBuilder under the hood because the outer > + // is applied to an arithmetic expression, not a direct field reference. Efficient mode + // cannot embed script queries under knn.filter, so this must be rejected up front with a + // clear remediation hint instead of a cluster-side failure. + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5,filter_type=efficient') AS v " + + "WHERE v.age + 1 > 30 " + + "LIMIT 5")); + + assertThat( + ex.getMessage(), containsString("vectorSearch WHERE pre-filtering does not support")); + assertThat(ex.getMessage(), containsString("script queries")); + } + + // ── k-NN plugin capability check ────────────────────────────────────── + // The default integ-test cluster does not have the k-NN plugin installed. Execution-path + // queries against vectorSearch() should therefore fail with the clear "k-NN plugin missing" + // error from KnnPluginCapability, while _explain continues to work because the capability + // probe is deferred to scan open() and does not run during analysis/planning. 
+ + @Test + public void testExecutionWithoutKnnPluginReturnsCapabilityError() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v " + + "LIMIT 5")); + + // Lock in the full user-facing sentence, not just loose substrings. The exact wording is + // part of the contract and regressions should fail loudly rather than keep passing on a + // subtly reworded message. + assertThat( + ex.getMessage(), + containsString( + "vectorSearch() requires the k-NN plugin, which is not installed on this cluster.")); + } + + @Test + public void testExplainWithoutKnnPluginStillWorks() throws IOException { + // _explain only parses and plans the query. It must NOT require the k-NN plugin — the + // capability probe is intentionally deferred to scan open() so pluginless clusters can + // still inspect query plans. If this test starts failing with "k-NN plugin not installed", + // the probe has leaked back into an analysis-time path. + String explain = + explainQuery( + "SELECT v._id FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v " + + "LIMIT 5"); + + // Assert the scan-operator name, not just "wrapper": the name confirms the plan reached + // the vectorSearch scan builder rather than some other scan shape. + assertThat(explain, containsString("VectorSearchIndexScan")); + assertThat(explain, containsString("wrapper")); + } + + // ── Argument shape validation ───────────────────────────────────────── + + @Test + public void testInvalidTableNameRejected() throws IOException { + // A slash is outside the SAFE_FIELD_NAME regex and is not a valid OpenSearch index character, + // so it should be rejected at the SQL layer before any cluster call is attempted. 
+ ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='idx/evil', field='f', " + + "vector='[1.0]', option='k=5') AS v")); + + assertThat(ex.getMessage(), containsString("Invalid table name")); + } + + @Test + public void testWildcardTableRejectedWithDedicatedMessage() throws IOException { + // Wildcards in a table name fan out to multiple indices, which vectorSearch() does not + // support (top-k semantics, dimension checks, and embedded filter JSON are not defined + // across heterogeneous shards). Surface a dedicated user-facing error instead of the + // generic "must contain only alphanumeric..." fallback. + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='sql_vector_*', field='f', " + + "vector='[1.0]', option='k=5') AS v")); + + assertThat(ex.getMessage(), containsString("Invalid table name")); + assertThat(ex.getMessage(), containsString("wildcards")); + assertThat(ex.getMessage(), containsString("single concrete index")); + } + + @Test + public void testMultiTargetTableRejectedWithDedicatedMessage() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='idx_a,idx_b', field='f', " + + "vector='[1.0]', option='k=5') AS v")); + + assertThat(ex.getMessage(), containsString("Invalid table name")); + assertThat(ex.getMessage(), containsString("multi-target")); + } + + @Test + public void testDuplicateNamedArgRejected() throws IOException { + // Previously this crashed the server with 500 ArrayIndexOutOfBoundsException. Must now + // surface as a clean 400 with a user-facing message. 
+ ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='a', table='b', " + + "vector='[1.0]', option='k=5') AS v")); + + assertThat(ex.getMessage(), containsString("Duplicate argument name")); + } + + @Test + public void testUnknownNamedArgRejected() throws IOException { + // A grammar-legal but unknown name must surface as a clean 400 from the resolver. + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(bogus='idx', field='f', " + + "vector='[1.0]', option='k=5') AS v")); + + assertThat(ex.getMessage(), containsString("Unknown argument name")); + } + + @Test + public void testPositionalArgRejected() throws IOException { + // The real shape a user would hit: `vectorSearch('idx', field=..., vector=..., option=...)`. + // The V2 grammar now accepts this form so the AstBuilder can surface a clean + // SemanticCheckException instead of letting the request fall back to the legacy SQL engine, + // which previously returned 200 with zero rows. + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch('idx', field='embedding', " + + "vector='[1.0, 1.0]', option='k=3') AS v LIMIT 3")); + + assertThat(ex.getMessage(), containsString("requires named arguments")); + } + + @Test + public void testCaseInsensitiveDuplicateArgRejected() throws IOException { + // Argument names are normalized to lower-case, so `table` and `TABLE` must be treated as the + // same key and rejected as a duplicate rather than silently keeping one of the two values. 
+ ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='a', TABLE='b', " + + "vector='[1.0]', option='k=5') AS v")); + + assertThat(ex.getMessage(), containsString("Duplicate argument name")); + } + + @Test + public void testTableNameAllRejected() throws IOException { + // `_all` would fan out to every index. The preview contract is a single concrete index or + // alias, so it must be rejected explicitly rather than allowed to route broadly. + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='_all', field='f', " + + "vector='[1.0]', option='k=5') AS v")); + + assertThat(ex.getMessage(), containsString("Invalid table name")); + } + + @Test + public void testTableNameSingleDotRejected() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='.', field='f', " + + "vector='[1.0]', option='k=5') AS v")); + + assertThat(ex.getMessage(), containsString("Invalid table name")); + } + + @Test + public void testTableNameDoubleDotRejected() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='..', field='f', " + + "vector='[1.0]', option='k=5') AS v")); + + assertThat(ex.getMessage(), containsString("Invalid table name")); + } + + @Test + public void testMissingRequiredArgRejected() throws IOException { + // Omitting a required named argument (here: `field`) must produce a clean 400 rather than a + // NullPointerException or a legacy-engine fallback. 
+ ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='a', " + + "vector='[1.0]', option='k=5') AS v")); + + assertThat(ex.getMessage(), containsString("requires 4 arguments")); + } + + /** + * Users running FROM vectorSearch(...) without an AS alias previously received an opaque parser + * error from the legacy SQL engine fallback. The clearer SemanticCheckException from the v2 + * engine must surface to the user instead. + */ + @Test + public void testVectorSearchRequiresAlias() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT * FROM vectorSearch(" + + "table='t', field='f', vector='[1.0]', option='k=5') " + + "LIMIT 3")); + + String body = ex.getMessage(); + assertThat(body, containsString("requires a table alias")); + assertThat(body, containsString("vectorSearch")); + } + + // Synthetic column collision (metadata vs. user field). + // vectorSearch() exposes synthetic v._id and v._score columns. A user mapping property of the + // same name would collide on the response tuple key. OpenSearch blocks _id at mapping time; + // _score is not blocked, so VectorSearchIndex rejects it at scan-build time. + + @Test + public void testUserMappingWithIdFieldIsRejectedByOpenSearch() throws IOException { + // Locks in OpenSearch's rejection of a user property named _id: without it, v._id could + // collide with a user field at response time. The exact error message belongs to OpenSearch. 
+ String indexName = "vs_collision_id"; + deleteIndexIfExists(indexName); + + Request createIndex = new Request("PUT", "/" + indexName); + createIndex.setJsonEntity("{\"mappings\":{\"properties\":{\"_id\":{\"type\":\"keyword\"}}}}"); + + expectThrows(ResponseException.class, () -> client().performRequest(createIndex)); + } + + @Test + public void testVectorSearchAgainstIndexWithScoreFieldRejects() throws IOException { + // _explain exercises planning (where the guard runs) without needing the k-NN plugin. + String indexName = "vs_collision_score"; + deleteIndexIfExists(indexName); + + Request createIndex = new Request("PUT", "/" + indexName); + createIndex.setJsonEntity("{\"mappings\":{\"properties\":{\"_score\":{\"type\":\"float\"}}}}"); + client().performRequest(createIndex); + + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + explainQuery( + "SELECT v._score FROM vectorSearch(table='" + + indexName + + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v " + + "LIMIT 5")); + + assertEquals(400, ex.getResponse().getStatusLine().getStatusCode()); + assertThat(ex.getMessage(), containsString("_score")); + assertThat(ex.getMessage(), containsString("collides")); + } + + @Test + public void testSemicolonSeparatorInVectorRejected() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0;2.0]', option='k=5') AS v")); + + assertThat(ex.getMessage(), containsString("vector=")); + assertThat(ex.getMessage(), containsString("comma-separated")); + } + + @Test + public void testNegativeMinScoreRejected() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='min_score=-0.5') AS v")); + + assertThat(ex.getMessage(), containsString("min_score")); + 
assertThat(ex.getMessage(), containsString("non-negative")); + } + + @Test + public void testNegativeMaxDistanceRejected() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='max_distance=-1.0') AS v")); + + assertThat(ex.getMessage(), containsString("max_distance")); + assertThat(ex.getMessage(), containsString("non-negative")); + } + + @Test + public void testTrailingCommaInVectorRejected() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0,2.0,]', option='k=5') AS v")); + + assertThat(ex.getMessage(), containsString("Invalid vector component")); + assertThat(ex.getMessage(), containsString("trailing or consecutive commas")); + } + + // ── Alias with multiple backing indices ─────────────────────────────── + // vectorSearch() accepts an alias as `table=`. When the alias points at multiple backing + // indices, planning must accept the alias string instead of treating it as a wildcard or + // multi-target. Execution correctness over compatible knn_vector mappings is a separate + // concern covered by k-NN-enabled tests/follow-up; these tests lock in planning acceptance + // only, via _explain on the default no-kNN cluster. + + @Test + public void testExplainOverAliasWithMultipleBackingIndices() throws IOException { + // Create two indices with identical keyword mappings (no knn_vector, since the plugin is + // not installed) and a shared alias. We only assert the planner accepts the alias; whether + // k-NN accepts the alias at execution is a separate concern tested on a k-NN-enabled + // cluster. + // Randomized names so a stale alias/index left by an aborted prior run of this class does + // not shadow a fresh setup, which is a concrete risk on local reruns. 
+ String suffix = java.util.UUID.randomUUID().toString().replace("-", "").substring(0, 8); + String idx1 = "vector_alias_backing_1_" + suffix; + String idx2 = "vector_alias_backing_2_" + suffix; + String alias = "vector_alias_combined_" + suffix; + try { + createSimpleIndex(idx1); + createSimpleIndex(idx2); + addToAlias(idx1, alias); + addToAlias(idx2, alias); + + String explain = + explainQuery( + "SELECT v._id FROM vectorSearch(table='" + + alias + + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v"); + + assertThat(explain, containsString("VectorSearchIndexScan")); + assertThat(explain, containsString(alias)); + } finally { + // Deleting the backing indices removes the alias automatically, but delete the alias + // first for robustness against partial setup failures. + deleteAliasIfExists(alias); + deleteIndexIfExists(idx1); + deleteIndexIfExists(idx2); + } + } + + private void createSimpleIndex(String indexName) throws IOException { + Request create = new Request("PUT", "/" + indexName); + create.setJsonEntity("{\"mappings\":{\"properties\":{\"state\":{\"type\":\"keyword\"}}}}"); + client().performRequest(create); + } + + private void addToAlias(String indexName, String aliasName) throws IOException { + Request req = new Request("POST", "/_aliases"); + req.setJsonEntity( + "{\"actions\":[{\"add\":{\"index\":\"" + + indexName + + "\",\"alias\":\"" + + aliasName + + "\"}}]}"); + client().performRequest(req); + } + + private void deleteIndexIfExists(String indexName) { + try { + client().performRequest(new Request("DELETE", "/" + indexName)); + } catch (IOException ignored) { + // Index does not exist, which is fine. + } + } + + private void deleteAliasIfExists(String aliasName) { + try { + client().performRequest(new Request("DELETE", "/_all/_alias/" + aliasName)); + } catch (IOException ignored) { + // Alias does not exist, which is fine. 
+ } + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchSubqueryIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchSubqueryIT.java new file mode 100644 index 00000000000..04346f87a76 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchSubqueryIT.java @@ -0,0 +1,306 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import static org.hamcrest.Matchers.containsString; + +import java.io.IOException; +import org.junit.Test; +import org.opensearch.client.ResponseException; +import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.legacy.TestsConstants; + +/** + * Integration tests for vectorSearch() used inside subqueries. Locks in the rejection of outer + * WHERE on a vectorSearch() subquery, which would otherwise silently yield zero rows because the + * outer predicate is applied only after the k-NN search has already selected top-k documents by + * vector distance. + * + *

<p>Uses _explain-only plus error-path queries, so the k-NN plugin is not required — the planner + * validation fires during planning, before any k-NN execution. + */ +public class VectorSearchSubqueryIT extends SQLIntegTestCase { + + @Override + protected void init() throws Exception { + loadIndex(Index.ACCOUNT); + } + + private static final String TEST_INDEX = TestsConstants.TEST_INDEX_ACCOUNT; + + @Test + public void testOuterWhereOnSubqueryRejected() throws IOException { + // Without the guard the outer predicate is dropped from the pushed DSL and applied only in + // memory after k-NN returned top-k, which can yield silent zero rows. + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT * FROM (SELECT v.firstname, v.state " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v) t " + + "WHERE t.state = 'TX'")); + + assertThat( + ex.getMessage(), + containsString("Outer WHERE on a vectorSearch() subquery is not supported")); + assertThat(ex.getMessage(), containsString("silently yield zero rows")); + } + + @Test + public void testOuterWhereOnSubqueryRejectedWithLimit() throws IOException { + // Same shape with an outer LIMIT — exercises a second planner path (LogicalLimit above + // LogicalFilter above LogicalProject above scan builder).
+ ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT * FROM (SELECT v.firstname, v.state " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v) t " + + "WHERE t.state = 'TX' " + + "LIMIT 3")); + + assertThat( + ex.getMessage(), + containsString("Outer WHERE on a vectorSearch() subquery is not supported")); + } + + @Test + public void testOuterWhereOnSubqueryRejectedExplain() throws IOException { + // The guard must fire during planning, before any k-NN execution — so _explain must also + // return the validation error rather than a silently dropped predicate in the DSL. + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + explainQuery( + "SELECT * FROM (SELECT v.firstname, v.state " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v) t " + + "WHERE t.state = 'TX'")); + + assertThat( + ex.getMessage(), + containsString("Outer WHERE on a vectorSearch() subquery is not supported")); + } + + @Test + public void testOuterWhereWithInnerWhereStillRejected() throws IOException { + // Outer WHERE must be rejected even when the subquery already has its own inner WHERE. + // The shape reaches the planner as Filter(outer) -> Project -> Filter(inner) -> Scan, and + // the outer predicate is still separated from the k-NN search by the subquery project + // boundary. Without preserving the project marker across the inner filter, the walker + // would miss this shape and the outer predicate would silently produce zero rows. 
+ ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT * FROM (SELECT v.firstname, v.state, v.age " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v " + + "WHERE v.age > 10) t " + + "WHERE t.state = 'TX'")); + + assertThat( + ex.getMessage(), + containsString("Outer WHERE on a vectorSearch() subquery is not supported")); + } + + @Test + public void testInnerWhereStillWorks() throws IOException { + // Positive control: WHERE directly on vectorSearch() inside the subquery must still plan + // successfully — the rejection is scoped to OUTER filters that cannot reach the push-down + // contract. We use _explain because the default integ-test cluster has no k-NN plugin. + String explain = + explainQuery( + "SELECT * FROM (SELECT v.firstname, v.state " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v " + + "WHERE v.state = 'TX') t"); + + assertThat(explain, containsString("wrapper")); + // Inner WHERE should push down, so the state predicate appears in the DSL. + assertThat(explain, containsString("state")); + } + + @Test + public void testInnerWhereWithOuterProjectStillWorks() throws IOException { + // Another positive control: the outer layer can still project and limit columns from the + // subquery without the guard firing — only outer WHERE is rejected. 
+ String explain = + explainQuery( + "SELECT t.firstname FROM (SELECT v.firstname, v.state " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v " + + "WHERE v.state = 'TX') t " + + "LIMIT 3"); + + assertThat(explain, containsString("wrapper")); + } + + @Test + public void testSubqueryNoWhereStillWorks() throws IOException { + // Baseline: a subquery with no WHERE anywhere must not be rejected — the guard fires only + // when an outer LogicalFilter sits above a subquery project boundary. + String explain = + explainQuery( + "SELECT * FROM (SELECT v.firstname " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v) t " + + "LIMIT 3"); + + assertThat(explain, containsString("wrapper")); + } + + @Test + public void testInnerOrderByScoreDescInSubqueryAllowed() throws IOException { + // Positive control: inner ORDER BY _score DESC on the vectorSearch() relation inside the + // subquery is the only supported sort, and must continue to plan successfully even when + // wrapped in an outer SELECT. Proves the walker does not over-reject sort shapes that are + // below the subquery Project rather than above it. + String explain = + explainQuery( + "SELECT * FROM (SELECT v.firstname, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v " + + "ORDER BY v._score DESC) t " + + "LIMIT 3"); + + assertThat(explain, containsString("wrapper")); + } + + @Test + public void testOuterOrderByOnSubqueryRejected() throws IOException { + // Outer ORDER BY over a vectorSearch() subquery would run on a truncated top-k slice rather + // than the full relation, silently reordering only the already-ANN-selected rows. 
+ ResponseException ex = + expectThrows( + ResponseException.class, + () -> + explainQuery( + "SELECT * FROM (SELECT v.firstname, v.state " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v) t " + + "ORDER BY t.state")); + + assertThat( + ex.getMessage(), + containsString("Outer ORDER BY on a vectorSearch() subquery is not supported")); + } + + @Test + public void testOuterOffsetOnSubqueryRejected() throws IOException { + // Outer OFFSET silently drops top-k rows by vector distance. The inner query already caps at + // k and any outer OFFSET shifts that window in an opaque way, so reject it. + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + explainQuery( + "SELECT * FROM (SELECT v.firstname " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v) t " + + "LIMIT 3 OFFSET 2")); + + assertThat( + ex.getMessage(), + containsString("Outer OFFSET on a vectorSearch() subquery is not supported")); + } + + @Test + public void testOuterLimitWithoutOffsetOnSubqueryAllowed() throws IOException { + // Positive control: outer LIMIT without OFFSET just caps the row count and must plan without + // error. Locks in the offset==0 boundary of the OFFSET rejection. + String explain = + explainQuery( + "SELECT * FROM (SELECT v.firstname " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v) t " + + "LIMIT 3"); + + assertThat(explain, containsString("wrapper")); + } + + @Test + public void testOuterAggregationOnSubqueryRejected() throws IOException { + // Outer aggregation (here COUNT(*)) over a vectorSearch() subquery would run on the + // truncated top-k slice, producing a count that silently depends on k rather than the full + // population. 
vectorSearch() does not support aggregations, so reject the outer-subquery + // variant with the same subquery-boundary walker that catches outer WHERE. + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + explainQuery( + "SELECT COUNT(*) FROM (SELECT v.firstname " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v) t")); + + assertThat( + ex.getMessage(), + containsString( + "Outer GROUP BY / aggregation / DISTINCT on a vectorSearch() subquery is not" + + " supported")); + } + + @Test + public void testOuterGroupByOnSubqueryRejected() throws IOException { + // GROUP BY rewrites to LogicalAggregation and is caught by the same subquery-boundary walker. + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + explainQuery( + "SELECT t.state, COUNT(*) FROM (SELECT v.firstname, v.state " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v) t " + + "GROUP BY t.state")); + + assertThat( + ex.getMessage(), + containsString( + "Outer GROUP BY / aggregation / DISTINCT on a vectorSearch() subquery is not" + + " supported")); + } + + @Test + public void testOuterDistinctOnSubqueryRejected() throws IOException { + // SELECT DISTINCT rewrites to a LogicalAggregation with empty aggregator list and the select + // items as the group-by list. The subquery-boundary walker must catch this shape too. 
+ ResponseException ex = + expectThrows( + ResponseException.class, + () -> + explainQuery( + "SELECT DISTINCT t.state FROM (SELECT v.firstname, v.state " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v) t")); + + assertThat( + ex.getMessage(), + containsString( + "Outer GROUP BY / aggregation / DISTINCT on a vectorSearch() subquery is not" + + " supported")); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchDataType.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchDataType.java index 837a2a062ef..79d49a143de 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchDataType.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchDataType.java @@ -43,7 +43,8 @@ public enum MappingType { ScaledFloat("scaled_float", ExprCoreType.DOUBLE), Double("double", ExprCoreType.DOUBLE), Boolean("boolean", ExprCoreType.BOOLEAN), - Alias("alias", ExprCoreType.UNKNOWN); + Alias("alias", ExprCoreType.UNKNOWN), + KnnVector("knn_vector", ExprCoreType.ARRAY); // TODO: ranges, geo shape, point, shape private final String name; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/FilterType.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/FilterType.java new file mode 100644 index 00000000000..cc42bb35f58 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/FilterType.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage; + +import java.util.Arrays; +import java.util.Set; +import java.util.stream.Collectors; +import org.opensearch.sql.exception.ExpressionEvaluationException; + +/** Filter placement strategy for vectorSearch() WHERE clauses. 
*/ +public enum FilterType { + /** WHERE placed in bool.filter outside the knn clause (post-filtering). */ + POST("post"), + + /** WHERE placed inside knn.filter for efficient pre-filtering. */ + EFFICIENT("efficient"); + + private final String value; + + FilterType(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + private static final Set<String> VALID_VALUES = + Arrays.stream(values()).map(FilterType::getValue).collect(Collectors.toSet()); + + public static FilterType fromString(String str) { + for (FilterType ft : values()) { + if (ft.value.equals(str)) { + return ft; + } + } + throw new ExpressionEvaluationException( + String.format("filter_type must be one of %s, got '%s'", VALID_VALUES, str)); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java index ce6740cd784..1b7de315fb6 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java @@ -7,10 +7,13 @@ import static org.opensearch.sql.utils.SystemIndexUtils.isSystemIndex; +import java.util.Collection; +import java.util.List; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.opensearch.sql.DataSourceSchemaName; import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.storage.system.OpenSearchSystemIndex; import org.opensearch.sql.storage.StorageEngine; @@ -25,6 +28,11 @@ public class OpenSearchStorageEngine implements StorageEngine { @Getter private final Settings settings; + @Override + public Collection<FunctionResolver> getFunctions() { + return List.of(new VectorSearchTableFunctionResolver(client, settings)); + } +
@Override public Table getTable(DataSourceSchemaName dataSourceSchemaName, String name) { if (isSystemIndex(name)) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java new file mode 100644 index 00000000000..06727a5462b --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java @@ -0,0 +1,191 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage; + +import java.util.Map; +import java.util.function.Function; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.WrapperQueryBuilder; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.storage.capability.KnnPluginCapability; +import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan; +import org.opensearch.sql.opensearch.storage.scan.VectorSearchIndexScan; +import org.opensearch.sql.opensearch.storage.scan.VectorSearchIndexScanBuilder; +import org.opensearch.sql.opensearch.storage.scan.VectorSearchQueryBuilder; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** + * Vector-search-aware OpenSearch index. Seeds the scan with a knn query and enables score tracking. + */ +public class VectorSearchIndex extends OpenSearchIndex { + + private final String field; + private final float[] vector; + private final Map<String, String> options; + private final FilterType filterType; // null means default (EFFICIENT) + // Nullable for back-compat with existing tests and the non-vector-search constructor. 
When + // present, the scan defers a lazy k-NN plugin probe to open() so execution fails fast with a + // clear SQL error if the plugin is missing. + private final KnnPluginCapability knnCapability; + + public VectorSearchIndex( + OpenSearchClient client, + Settings settings, + String indexName, + String field, + float[] vector, + Map<String, String> options, + FilterType filterType, + KnnPluginCapability knnCapability) { + super(client, settings, indexName); + this.field = field; + this.vector = vector; + this.options = options; + this.filterType = filterType; + this.knnCapability = knnCapability; + } + + public VectorSearchIndex( + OpenSearchClient client, + Settings settings, + String indexName, + String field, + float[] vector, + Map<String, String> options, + FilterType filterType) { + this(client, settings, indexName, field, vector, options, filterType, null); + } + + /** + * Default constructor — preserves existing call sites; uses no explicit filter type, so the scan + * falls back to the default placement ({@link FilterType#EFFICIENT}). + */ + public VectorSearchIndex( + OpenSearchClient client, + Settings settings, + String indexName, + String field, + float[] vector, + Map<String, String> options) { + this(client, settings, indexName, field, vector, options, null, null); + } + + @Override + public TableScanBuilder createScanBuilder() { + // _score is not blocked at mapping time, so a user field named _score would collide with the + // synthetic v._score column on the response tuple and fail with an opaque duplicate-key error. + // Reject here so the user sees a clear SQL error (and _explain surfaces the problem without a + // k-NN request). + if (getFieldTypes().containsKey(METADATA_FIELD_SCORE)) { + throw new IllegalArgumentException( + String.format( + "Index '%s' defines a user field named '_score' that collides with the synthetic" + + " _score column exposed by vectorSearch(). 
Rename the field or query the index" + + " without vectorSearch().", + getIndexName())); + } + final TimeValue cursorKeepAlive = + getSettings().getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); + var requestBuilder = createRequestBuilder(); + + // Callback for efficient filtering: serialize WHERE QueryBuilder to JSON, + // rebuild knn query with filter embedded. JSON handling stays in this class. + Function<QueryBuilder, QueryBuilder> rebuildWithFilter = + whereQuery -> new WrapperQueryBuilder(buildKnnQueryJson(whereQuery.toString())); + + boolean filterTypeExplicit = filterType != null; + FilterType effectiveFilterType = filterType != null ? filterType : FilterType.EFFICIENT; + + var queryBuilder = + new VectorSearchQueryBuilder( + requestBuilder, + buildKnnQuery(), + options, + effectiveFilterType, + filterTypeExplicit, + rebuildWithFilter); + requestBuilder.pushDownTrackedScore(true); + + // Default size policy: LIMIT pushdown will further reduce if present. + if (options.containsKey("k")) { + // Top-k mode: default size to k so queries without LIMIT return k results. + requestBuilder.pushDownLimitToRequestTotal(Integer.parseInt(options.get("k")), 0); + } else { + // Radial mode (max_distance/min_score): cap at maxResultWindow. + // Without an explicit cap, radial queries could return unbounded results. 
+ requestBuilder.pushDownLimitToRequestTotal(getMaxResultWindow(), 0); + } + + Function<OpenSearchRequestBuilder, OpenSearchIndexScan> createScanOperator = + rb -> { + var request = + rb.build(getIndexName(), cursorKeepAlive, getClient(), getFieldTypes().isEmpty()); + if (knnCapability != null) { + return new VectorSearchIndexScan( + getClient(), rb.getMaxResponseSize(), request, knnCapability); + } + return new OpenSearchIndexScan(getClient(), rb.getMaxResponseSize(), request); + }; + return new VectorSearchIndexScanBuilder(queryBuilder, createScanOperator); + } + + private QueryBuilder buildKnnQuery() { + return new WrapperQueryBuilder(buildKnnQueryJson()); + } + + // Package-private for testing + String buildKnnQueryJson() { + return buildKnnQueryJson(null); + } + + /** + * Builds knn query JSON, optionally embedding a filter clause for efficient filtering. + * + * @param filterJson serialized filter JSON to embed in knn.field.filter, or null for no filter + */ + String buildKnnQueryJson(String filterJson) { + StringBuilder vectorJson = new StringBuilder("["); + for (int i = 0; i < vector.length; i++) { + if (i > 0) vectorJson.append(","); + vectorJson.append(vector[i]); + } + vectorJson.append("]"); + + StringBuilder optionsJson = new StringBuilder(); + for (Map.Entry<String, String> entry : options.entrySet()) { + optionsJson.append(","); + String value = entry.getValue(); + // All P0 option values are canonicalized to numeric strings by validateOptions(). + // The quoted fallback is retained for forward compatibility with future non-numeric options. 
+ if (isNumeric(value)) { + optionsJson.append(String.format("\"%s\":%s", entry.getKey(), value)); + } else { + optionsJson.append(String.format("\"%s\":\"%s\"", entry.getKey(), value)); + } + } + + String filterClause = ""; + if (filterJson != null) { + filterClause = String.format(",\"filter\":%s", filterJson); + } + + return String.format( + "{\"knn\":{\"%s\":{\"vector\":%s%s%s}}}", + field, vectorJson.toString(), optionsJson.toString(), filterClause); + } + + /** + * Returns whether {@code str} is a valid JSON number literal and can be emitted unquoted. + * Double.parseDouble is too lenient for this: it also accepts NaN, Infinity, hex floats, type + * suffixes such as 1d, and a leading plus sign, all of which would corrupt the knn query JSON. + */ + private static boolean isNumeric(String str) { + return str != null && str.matches("-?(0|[1-9][0-9]*)(\\.[0-9]+)?([eE][+-]?[0-9]+)?"); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java new file mode 100644 index 00000000000..c1b5354f4b1 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java @@ -0,0 +1,370 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage; + +import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.FIELD; +import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.OPTION; +import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.TABLE; +import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.VECTOR; + +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; 
+import java.util.stream.Collectors; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.data.type.ExprType; +import
org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.NamedArgumentExpression; +import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.TableFunctionImplementation; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.storage.capability.KnnPluginCapability; +import org.opensearch.sql.storage.Table; + +public class VectorSearchTableFunctionImplementation extends FunctionExpression + implements TableFunctionImplementation { + + /** + * P0 allowed option keys. Rejects unknown/future keys to prevent unvalidated DSL injection. A + * {@link List} (rather than a {@link Set}) so the unknown-key error message renders the supported + * keys in a stable, user-friendly order. + */ + static final List ALLOWED_OPTION_KEYS = + List.of("k", "max_distance", "min_score", "filter_type"); + + /** + * Field names must be safe for JSON interpolation: alphanumeric, dots (nested), underscores, + * hyphens. Rejects characters that could corrupt the WrapperQueryBuilder JSON. The same regex is + * reused for table names so user-supplied identifiers cannot break out of the JSON context. 
+ */ + private static final Pattern SAFE_FIELD_NAME = Pattern.compile("^[a-zA-Z0-9._\\-]+$"); + + private final FunctionName functionName; + private final List arguments; + private final OpenSearchClient client; + private final Settings settings; + private final KnnPluginCapability knnCapability; + + public VectorSearchTableFunctionImplementation( + FunctionName functionName, + List arguments, + OpenSearchClient client, + Settings settings, + KnnPluginCapability knnCapability) { + super(functionName, arguments); + this.functionName = functionName; + this.arguments = arguments; + this.client = client; + this.settings = settings; + this.knnCapability = knnCapability; + } + + @Override + public ExprValue valueOf(Environment valueEnv) { + throw new UnsupportedOperationException( + String.format("vectorSearch function [%s] is only supported in FROM clause", functionName)); + } + + @Override + public ExprType type() { + return ExprCoreType.STRUCT; + } + + @Override + public String toString() { + List args = + arguments.stream() + .map( + arg -> { + if (arg instanceof NamedArgumentExpression) { + NamedArgumentExpression named = (NamedArgumentExpression) arg; + return String.format("%s=%s", named.getArgName(), named.getValue().toString()); + } + return arg.toString(); + }) + .collect(Collectors.toList()); + return String.format("%s(%s)", functionName, String.join(", ", args)); + } + + @Override + public Table applyArguments() { + // Local validation runs first so that malformed queries return stable SQL validation errors + // regardless of cluster state. The k-NN plugin presence is checked later, lazily at scan + // open() time, so analysis-time paths (_explain, local validation) stay functional on + // clusters without k-NN. 
+ validateNamedArgs(); + String tableName = getArgumentValue(TABLE); + validateTableName(tableName); + String fieldName = getArgumentValue(FIELD); + validateFieldName(fieldName); + String vectorLiteral = getArgumentValue(VECTOR); + String optionStr = getArgumentValue(OPTION); + + float[] vector = parseVector(vectorLiteral); + Map options = parseOptions(optionStr); + validateOptions(options); + + // Strip filter_type — it's a SQL-layer directive, not a knn parameter + FilterType filterType = null; + if (options.containsKey("filter_type")) { + filterType = FilterType.fromString(options.remove("filter_type")); + } + + return new VectorSearchIndex( + client, settings, tableName, fieldName, vector, options, filterType, knnCapability); + } + + private float[] parseVector(String vectorLiteral) { + String cleaned = vectorLiteral.replaceAll("[\\[\\]]", "").trim(); + if (cleaned.isEmpty()) { + throw new ExpressionEvaluationException("Vector literal must not be empty"); + } + // Reject common non-comma separators before Float.parseFloat fails with a generic + // "Invalid vector component" that doesn't hint the user at the separator. + if (cleaned.indexOf(';') >= 0 || cleaned.indexOf(':') >= 0 || cleaned.indexOf('|') >= 0) { + throw new ExpressionEvaluationException( + String.format( + "Invalid vector literal '%s': vector= requires comma-separated components," + + " e.g., vector='[1.0,2.0,3.0]'", + vectorLiteral)); + } + // Preserve trailing empties (split(",", -1)) so malformed literals like "[1.0,]" or + // "[1.0,,2.0]" surface an explicit error instead of silently shrinking the vector. 
+ String[] parts = cleaned.split(",", -1); + float[] vector = new float[parts.length]; + for (int i = 0; i < parts.length; i++) { + String component = parts[i].trim(); + if (component.isEmpty()) { + throw new ExpressionEvaluationException( + String.format( + "Invalid vector component at position %d: must be a number (check for" + + " trailing or consecutive commas in '%s')", + i, vectorLiteral)); + } + try { + vector[i] = Float.parseFloat(component); + } catch (NumberFormatException e) { + throw new ExpressionEvaluationException( + String.format("Invalid vector component '%s': must be a number", component)); + } + if (!Float.isFinite(vector[i])) { + throw new ExpressionEvaluationException( + String.format("Invalid vector component '%s': must be a finite number", component)); + } + } + return vector; + } + + static Map parseOptions(String optionStr) { + Map options = new LinkedHashMap<>(); + // A wholly empty option string is handled downstream with a clearer "missing required option" + // message than a generic malformed-segment error. + if (optionStr.trim().isEmpty()) { + return options; + } + // split(",", -1) preserves trailing empties so malformed inputs like "k=5," or "k=5,,k2=v" + // surface an explicit error instead of being silently dropped. 
+ String[] pairs = optionStr.split(",", -1); + for (String pair : pairs) { + String trimmed = pair.trim(); + if (trimmed.isEmpty()) { + throw new ExpressionEvaluationException( + "Malformed option segment '': expected key=value (check for trailing or" + + " consecutive commas)"); + } + String[] kv = trimmed.split("=", 2); + if (kv.length != 2 || kv[0].trim().isEmpty() || kv[1].trim().isEmpty()) { + throw new ExpressionEvaluationException( + String.format("Malformed option segment '%s': expected key=value", trimmed)); + } + String key = kv[0].trim(); + if (options.containsKey(key)) { + throw new ExpressionEvaluationException(String.format("Duplicate option key '%s'", key)); + } + options.put(key, kv[1].trim()); + } + return options; + } + + /** + * Reject non-named arguments, null arg names, and duplicate named arguments early. Runs before + * any list-index-based lookup so a malformed argument list can never cause an AIOOBE downstream. + */ + private void validateNamedArgs() { + HashSet seen = new HashSet<>(); + for (Expression arg : arguments) { + if (!(arg instanceof NamedArgumentExpression)) { + throw new ExpressionEvaluationException( + "vectorSearch() requires named arguments (e.g., table='index'), " + + "but received: " + + arg.getClass().getSimpleName()); + } + String name = ((NamedArgumentExpression) arg).getArgName(); + if (name == null || name.isEmpty()) { + throw new ExpressionEvaluationException( + "vectorSearch() requires named arguments (e.g., table='index'), " + + "but received an argument with no name"); + } + if (!seen.add(name.toLowerCase(java.util.Locale.ROOT))) { + throw new ExpressionEvaluationException( + "Duplicate argument name '" + + name + + "' in vectorSearch(); each named argument may appear at most once"); + } + } + } + + /** + * Reject table names with characters that could corrupt the WrapperQueryBuilder JSON or escape + * the target index name. 
Allows alphanumeric, dots, underscores, and hyphens (the characters + * OpenSearch index names already permit). Explicitly rejects wildcards ('*') and multi-target + * patterns (comma-separated) with a dedicated message, because vectorSearch() targets a single + * concrete index or alias and fan-out patterns would otherwise fall through to the generic regex + * message. Also rejects the `_all` routing target and the pathologic `.` / `..` names because + * those either fan out to every index or are not valid concrete index names. Other native-invalid + * names (leading dot, leading hyphen, bare underscore, uppercase, and so on) are intentionally + * passed through for the OpenSearch client to reject with its own error message. + */ + private void validateTableName(String tableName) { + // Dedicated error for fan-out patterns ('*' and ',') before the generic regex; see Javadoc + // for why vectorSearch() targets a single index. + if (tableName.indexOf('*') >= 0 || tableName.indexOf(',') >= 0) { + throw new ExpressionEvaluationException( + String.format( + "Invalid table name '%s': vectorSearch() requires a single concrete index or alias;" + + " wildcards ('*') and multi-target patterns (comma-separated) are not" + + " supported", + tableName)); + } + if (!SAFE_FIELD_NAME.matcher(tableName).matches()) { + throw new ExpressionEvaluationException( + String.format( + "Invalid table name '%s': must contain only alphanumeric characters," + + " dots, underscores, or hyphens", + tableName)); + } + String lower = tableName.toLowerCase(java.util.Locale.ROOT); + if (lower.equals("_all") || tableName.equals(".") || tableName.equals("..")) { + throw new ExpressionEvaluationException( + String.format( + "Invalid table name '%s': vectorSearch() requires a single concrete index or alias;" + + " '_all', '.', and '..' are not supported", + tableName)); + } + } + + /** + * Reject field names with characters that could corrupt the WrapperQueryBuilder JSON. 
Allows + * alphanumeric, dots (nested fields), underscores, and hyphens. + */ + private void validateFieldName(String fieldName) { + if (!SAFE_FIELD_NAME.matcher(fieldName).matches()) { + throw new ExpressionEvaluationException( + String.format( + "Invalid field name '%s': must contain only alphanumeric characters," + + " dots, underscores, or hyphens", + fieldName)); + } + } + + /** + * Validates and canonicalizes option values. All P0 option values must be numeric. Parsing them + * here prevents non-numeric strings from reaching the raw JSON construction in buildKnnQuery(). + */ + private void validateOptions(Map options) { + // Reject unknown option keys — only P0 keys are allowed + for (String key : options.keySet()) { + if (!ALLOWED_OPTION_KEYS.contains(key)) { + throw new ExpressionEvaluationException( + String.format("Unknown option key '%s'. Supported keys: %s", key, ALLOWED_OPTION_KEYS)); + } + } + if (options.containsKey("filter_type")) { + // Validate early — fromString throws if invalid + FilterType.fromString(options.get("filter_type")); + } + boolean hasK = options.containsKey("k"); + boolean hasMaxDistance = options.containsKey("max_distance"); + boolean hasMinScore = options.containsKey("min_score"); + if (!hasK && !hasMaxDistance && !hasMinScore) { + throw new ExpressionEvaluationException( + "Missing required option: one of k, max_distance, or min_score"); + } + // Mutual exclusivity: exactly one search mode allowed + int modeCount = (hasK ? 1 : 0) + (hasMaxDistance ? 1 : 0) + (hasMinScore ? 
1 : 0); + if (modeCount > 1) { + throw new ExpressionEvaluationException( + "Only one of k, max_distance, or min_score may be specified"); + } + // Parse and canonicalize numeric values — closes JSON injection via option values + if (hasK) { + int k = parseIntOption(options, "k"); + if (k < 1 || k > 10000) { + throw new ExpressionEvaluationException( + String.format("k must be between 1 and 10000, got %d", k)); + } + } + if (hasMaxDistance) { + double maxDistance = parseDoubleOption(options, "max_distance"); + if (maxDistance < 0) { + throw new ExpressionEvaluationException( + String.format( + "max_distance must be non-negative, got %s", options.get("max_distance"))); + } + } + if (hasMinScore) { + double minScore = parseDoubleOption(options, "min_score"); + if (minScore < 0) { + throw new ExpressionEvaluationException( + String.format("min_score must be non-negative, got %s", options.get("min_score"))); + } + } + } + + private int parseIntOption(Map options, String key) { + try { + int value = Integer.parseInt(options.get(key)); + options.put(key, Integer.toString(value)); + return value; + } catch (NumberFormatException e) { + throw new ExpressionEvaluationException( + String.format("Option '%s' must be an integer, got '%s'", key, options.get(key))); + } + } + + private double parseDoubleOption(Map options, String key) { + try { + double value = Double.parseDouble(options.get(key)); + if (!Double.isFinite(value)) { + throw new ExpressionEvaluationException( + String.format("Option '%s' must be a finite number, got '%s'", key, options.get(key))); + } + options.put(key, Double.toString(value)); + return value; + } catch (NumberFormatException e) { + throw new ExpressionEvaluationException( + String.format("Option '%s' must be a number, got '%s'", key, options.get(key))); + } + } + + private String getArgumentValue(String name) { + return arguments.stream() + .filter(arg -> ((NamedArgumentExpression) arg).getArgName().equalsIgnoreCase(name)) + .map(arg -> 
((NamedArgumentExpression) arg).getValue().valueOf().stringValue()) + .findFirst() + .orElseThrow( + () -> + new ExpressionEvaluationException( + String.format("Missing required argument: %s", name))); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolver.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolver.java new file mode 100644 index 00000000000..8db1f270afd --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolver.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage; + +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import java.util.HashSet; +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.NamedArgumentExpression; +import org.opensearch.sql.expression.function.FunctionBuilder; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionResolver; +import org.opensearch.sql.expression.function.FunctionSignature; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.storage.capability.KnnPluginCapability; + +public class VectorSearchTableFunctionResolver implements FunctionResolver { + + public static final String VECTOR_SEARCH = "vectorsearch"; + public static final String TABLE = "table"; + public static final String FIELD = "field"; + public static final String VECTOR = "vector"; + public static final String OPTION = "option"; + public static final List ARGUMENT_NAMES = List.of(TABLE, FIELD, VECTOR, OPTION); + + private final 
OpenSearchClient client; + private final Settings settings; + private final KnnPluginCapability knnCapability; + + public VectorSearchTableFunctionResolver(OpenSearchClient client, Settings settings) { + this(client, settings, new KnnPluginCapability(client)); + } + + VectorSearchTableFunctionResolver( + OpenSearchClient client, Settings settings, KnnPluginCapability knnCapability) { + this.client = client; + this.settings = settings; + this.knnCapability = knnCapability; + } + + @Override + public Pair resolve(FunctionSignature unresolvedSignature) { + FunctionName functionName = FunctionName.of(VECTOR_SEARCH); + FunctionSignature functionSignature = + new FunctionSignature(functionName, List.of(STRING, STRING, STRING, STRING)); + FunctionBuilder functionBuilder = + (functionProperties, arguments) -> { + validateArguments(arguments); + return new VectorSearchTableFunctionImplementation( + functionName, arguments, client, settings, knnCapability); + }; + return Pair.of(functionSignature, functionBuilder); + } + + @Override + public FunctionName getFunctionName() { + return FunctionName.of(VECTOR_SEARCH); + } + + private void validateArguments(List arguments) { + if (arguments.size() != ARGUMENT_NAMES.size()) { + throw new ExpressionEvaluationException( + String.format( + "vectorSearch requires %d arguments (%s), got %d", + ARGUMENT_NAMES.size(), String.join(", ", ARGUMENT_NAMES), arguments.size())); + } + // Shape check at the resolver so positional or unknown-named args produce a clean 400 before + // planning proceeds. The Implementation layer repeats the non-named and duplicate-name checks + // as defense-in-depth; the unknown-name allowlist is enforced only here because the + // Implementation looks up values by known keys and does not need to re-validate the allowlist. 
+ HashSet seen = new HashSet<>(); + for (Expression arg : arguments) { + if (!(arg instanceof NamedArgumentExpression)) { + throw new ExpressionEvaluationException( + "vectorSearch() requires named arguments (e.g., table='index'), " + + "but received: " + + arg.getClass().getSimpleName()); + } + String name = ((NamedArgumentExpression) arg).getArgName(); + if (name == null || name.isEmpty()) { + throw new ExpressionEvaluationException( + "vectorSearch() requires named arguments (e.g., table='index'), " + + "but received an argument with no name"); + } + String lower = name.toLowerCase(java.util.Locale.ROOT); + if (!ARGUMENT_NAMES.contains(lower)) { + throw new ExpressionEvaluationException( + String.format( + "Unknown argument name '%s' in vectorSearch(); allowed names are %s", + name, ARGUMENT_NAMES)); + } + if (!seen.add(lower)) { + throw new ExpressionEvaluationException( + "Duplicate argument name '" + + name + + "' in vectorSearch(); each named argument may appear at most once"); + } + } + // At this point `seen` holds exactly ARGUMENT_NAMES.size() entries (no duplicates, no unknowns, + // and arity matches), so every required name is present. No separate missing-name check needed. 
+ } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/capability/KnnPluginCapability.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/capability/KnnPluginCapability.java new file mode 100644 index 00000000000..9ba59915e1d --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/capability/KnnPluginCapability.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.capability; + +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; +import org.opensearch.action.admin.cluster.node.info.NodesInfoRequest; +import org.opensearch.action.admin.cluster.node.info.NodesInfoResponse; +import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; +import org.opensearch.plugins.PluginInfo; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.transport.client.node.NodeClient; + +/** + * Probes the cluster's Nodes Info API once and caches whether the k-NN plugin is installed, so + * vectorSearch() fails fast with a clear error when the plugin is absent instead of surfacing a + * native OpenSearch error deep in execution. + * + *

<p>The probe requires a {@link NodeClient}. In REST-client mode (standalone SQL service) the node + * client is absent and the check is skipped — execution-time errors remain the signal there. + * + *

The check runs lazily at scan open() — i.e. only when a vectorSearch() query is actually + * executed — so analysis-time paths like _explain and local argument validation keep working on + * clusters without k-NN. + */ +public class KnnPluginCapability { + + /** + * Canonical k-NN plugin class. Using the class name (not artifact name) so the check is stable + * across packaging variants. + */ + private static final String KNN_PLUGIN_CLASSNAME = "org.opensearch.knn.plugin.KNNPlugin"; + + private final OpenSearchClient client; + private final AtomicReference cached = new AtomicReference<>(); + + public KnnPluginCapability(OpenSearchClient client) { + this.client = client; + } + + /** + * Throws {@link ExpressionEvaluationException} with a user-facing message if the k-NN plugin is + * not installed on any node in the cluster. The result is cached after the first successful + * probe; probe failures are not cached so the next call retries. + */ + public void requireInstalled() { + Boolean hit = cached.get(); + if (hit == null) { + Optional probed = probe(); + if (probed.isEmpty()) { + // Probe unavailable (REST-client mode, no NodeClient). Don't block — execution-time + // errors will surface if k-NN is genuinely missing. + return; + } + hit = probed.get(); + cached.set(hit); + } + if (!hit) { + throw new ExpressionEvaluationException( + "vectorSearch() requires the k-NN plugin, which is not installed on this cluster." 
+ + " Install opensearch-knn or use a cluster that has it."); + } + } + + private Optional probe() { + Optional maybeNode = client.getNodeClient(); + if (maybeNode.isEmpty()) { + return Optional.empty(); + } + NodeClient node = maybeNode.get(); + try { + NodesInfoRequest request = new NodesInfoRequest().clear().addMetric("plugins"); + NodesInfoResponse response = node.admin().cluster().nodesInfo(request).actionGet(); + boolean installed = + response.getNodes().stream() + .map(info -> info.getInfo(PluginsAndModules.class)) + .filter(Objects::nonNull) + .flatMap(p -> p.getPluginInfos().stream()) + .map(PluginInfo::getClassname) + .anyMatch(KNN_PLUGIN_CLASSNAME::equals); + return Optional.of(installed); + } catch (Exception e) { + // Probe failed (IO error, timeout). Don't cache — let the next call retry. + return Optional.empty(); + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java index 70e6f0f2157..af9d46cd745 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java @@ -45,8 +45,8 @@ public OpenSearchIndexScanBuilder( this.scanFactory = scanFactory; } - /** Constructor used for unit tests. */ - protected OpenSearchIndexScanBuilder( + /** Constructor that accepts a custom PushDownQueryBuilder delegate. 
*/ + public OpenSearchIndexScanBuilder( PushDownQueryBuilder translator, Function scanFactory) { this.delegate = translator; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchIndexScan.java new file mode 100644 index 00000000000..86d1934f132 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchIndexScan.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.storage.capability.KnnPluginCapability; + +/** + * OpenSearch scan for vector-search relations. Delegates everything to {@link OpenSearchIndexScan} + * except for {@link #open()}, where it first verifies the k-NN plugin is installed so we fail fast + * with a clear SQL error before the native request would fail deep in execution. The check is + * deferred to open() (not applyArguments() or the scan builder) so that analysis-time paths like + * _explain continue to work on clusters without k-NN. 
+ */ +public class VectorSearchIndexScan extends OpenSearchIndexScan { + + private final KnnPluginCapability knnCapability; + + public VectorSearchIndexScan( + OpenSearchClient client, + int maxResponseSize, + OpenSearchRequest request, + KnnPluginCapability knnCapability) { + super(client, maxResponseSize, request); + this.knnCapability = knnCapability; + } + + @Override + public void open() { + knnCapability.requireInstalled(); + super.open(); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchIndexScanBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchIndexScanBuilder.java new file mode 100644 index 00000000000..a898ac41299 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchIndexScanBuilder.java @@ -0,0 +1,160 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import java.util.function.Function; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalSort; + +/** + * Scan builder for vector search relations. + * + *

<p>Rejects planner shapes that the SQL surface cannot express safely: + * + *

    + *
<ul>
<li>Aggregations — native OpenSearch k-NN supports aggregations alongside similarity + * search, but the SQL layer does not plumb them through, so we fail fast rather than return + * silently unaggregated results. + *
<li>Outer operators over a vectorSearch() subquery — when vectorSearch() is wrapped in a + * subquery (e.g. {@code SELECT * FROM (SELECT v.id FROM vectorSearch(...) AS v) t WHERE + * t.price < 150}), outer WHERE / ORDER BY / OFFSET / GROUP BY / aggregation / DISTINCT do not + * participate in the vectorSearch pushdown contract (the inner {@link LogicalProject} sits + * between the outer operator and this scan builder, so those nodes never match the + * direct-adjacency push-down patterns). They would then be applied in memory after + * top-k results have been selected by vector distance, which can silently yield zero rows or + * mis-ordered results. We detect these shapes in {@link #validatePlan(LogicalPlan)} and + * reject with a clear error. + * </ul>
+ */ +public class VectorSearchIndexScanBuilder extends OpenSearchIndexScanBuilder { + + public VectorSearchIndexScanBuilder( + PushDownQueryBuilder translator, + Function scanFactory) { + super(translator, scanFactory); + } + + @Override + public boolean pushDownAggregation(LogicalAggregation aggregation) { + throw new ExpressionEvaluationException( + "Aggregations are not supported on vectorSearch() relations."); + } + + /** + * Walk the fully-optimized plan and reject outer-operator-over-subquery shapes. We look for an + * outer {@link LogicalFilter}, {@link LogicalSort}, {@link LogicalLimit} with non-zero offset, or + * {@link LogicalAggregation} whose descendant chain reaches this scan builder through one or more + * {@link LogicalProject} nodes (the subquery-boundary marker). An operator directly above this + * scan builder is fine — those go through the push-down contract in the query builder. + */ + @Override + public void validatePlan(LogicalPlan root) { + checkForOuterOperator(root, null, false); + } + + /** + * Recursive walker that tracks the outermost "risky" operator seen on the current walk path and + * whether a {@link LogicalProject} has been crossed since then: + * + *
    + *
<ul>
<li>{@code outerOp} — name of the outermost filter/sort/offset/aggregation ancestor, or + * {@code null} if none. Projects only matter below such an operator — without one, a + * project is just the outer SELECT and should not trigger rejection. + *
<li>{@code sawProjectSinceOuter} — true iff a {@link LogicalProject} has been seen between + * the outermost risky ancestor and the current position. Once separation by a Project has + * been established, it is permanent — a lower {@link LogicalFilter} below the Project does + * not undo the outer boundary. + * </ul>
+ * + *

This matters for shapes like {@code Filter(outer) -> Project(subquery) -> Filter(inner) -> + * Scan}, where the outer predicate is still blocked from reaching the push-down contract by the + * subquery Project regardless of the inner filter. Resetting on the inner filter would make the + * walker miss this shape. + */ + private void checkForOuterOperator( + LogicalPlan node, String outerOp, boolean sawProjectSinceOuter) { + if (node == this) { + if (outerOp != null && sawProjectSinceOuter) { + throw new ExpressionEvaluationException(rejectionMessage(outerOp)); + } + return; + } + String nextOuterOp = outerOp; + boolean nextSawProject = sawProjectSinceOuter; + if (outerOp == null) { + String operator = classifyOuterOperator(node); + if (operator != null) { + nextOuterOp = operator; + } + } else if (node instanceof LogicalProject) { + nextSawProject = true; + } + for (LogicalPlan child : node.getChild()) { + checkForOuterOperator(child, nextOuterOp, nextSawProject); + } + } + + /** + * Returns a user-facing label for operators that cannot safely sit above a vectorSearch() + * subquery, or {@code null} for operators that are fine (Project, scan, etc.). {@link + * LogicalLimit} with {@code offset == 0} is safe — plain LIMIT wrapping a subquery just caps the + * row count. Non-zero OFFSET skips top-k rows by distance and is rejected. + */ + private static String classifyOuterOperator(LogicalPlan node) { + if (node instanceof LogicalFilter) { + return "WHERE"; + } + if (node instanceof LogicalSort) { + return "ORDER BY"; + } + if (node instanceof LogicalAggregation) { + return "GROUP BY / aggregation / DISTINCT"; + } + if (node instanceof LogicalLimit) { + Integer offset = ((LogicalLimit) node).getOffset(); + if (offset != null && offset != 0) { + return "OFFSET"; + } + } + return null; + } + + // Operator-specific messages: the generic "move it inside the subquery" advice is only right + // for WHERE and for ORDER BY _score DESC. 
OFFSET, aggregation, GROUP BY, and DISTINCT are
+  // themselves unsupported on vectorSearch() directly, so the message must not claim a workaround
+  // that would only trip the user on a second validation error.
+  private static String rejectionMessage(String outerOp) {
+    switch (outerOp) {
+      case "WHERE":
+        return "Outer WHERE on a vectorSearch() subquery is not supported: the predicate does not"
+            + " participate in the vectorSearch pushdown contract and would be applied only"
+            + " after top-k results have been selected by vector distance, which can silently"
+            + " yield zero rows. Move the WHERE into the same SELECT block as vectorSearch() so"
+            + " it participates in the vectorSearch WHERE pushdown contract.";
+      case "ORDER BY":
+        return "Outer ORDER BY on a vectorSearch() subquery is not supported: sorting does not"
+            + " participate in the vectorSearch pushdown contract and would be applied only"
+            + " after top-k results have been selected by vector distance, which can yield"
+            + " mis-ordered results. Use ORDER BY <alias>._score DESC in the same SELECT block"
+            + " as vectorSearch(), or omit ORDER BY.";
+      case "OFFSET":
+        return "Outer OFFSET on a vectorSearch() subquery is not supported. OFFSET is not"
+            + " supported on vectorSearch(); use LIMIT only.";
+      case "GROUP BY / aggregation / DISTINCT":
+        return "Outer GROUP BY / aggregation / DISTINCT on a vectorSearch() subquery is not"
+            + " supported. Aggregations and DISTINCT are not supported on vectorSearch()"
+            + " relations.";
+      default:
+        return "Outer " + outerOp + " on a vectorSearch() subquery is not supported.";
+    }
+  }
+}
diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java
new file mode 100644
index 00000000000..33714a793ab
--- /dev/null
+++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java
@@ -0,0 +1,285 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.opensearch.storage.scan;
+
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Function;
+import org.apache.commons.lang3.tuple.Pair;
+import org.opensearch.index.query.BoolQueryBuilder;
+import org.opensearch.index.query.ConstantScoreQueryBuilder;
+import org.opensearch.index.query.ExistsQueryBuilder;
+import org.opensearch.index.query.MatchBoolPrefixQueryBuilder;
+import org.opensearch.index.query.MatchPhrasePrefixQueryBuilder;
+import org.opensearch.index.query.MatchPhraseQueryBuilder;
+import org.opensearch.index.query.MatchQueryBuilder;
+import org.opensearch.index.query.MultiMatchQueryBuilder;
+import org.opensearch.index.query.NestedQueryBuilder;
+import org.opensearch.index.query.QueryBuilder;
+import org.opensearch.index.query.QueryBuilders;
+import org.opensearch.index.query.QueryStringQueryBuilder;
+import org.opensearch.index.query.RangeQueryBuilder;
+import org.opensearch.index.query.ScriptQueryBuilder;
+import org.opensearch.index.query.SimpleQueryStringBuilder;
+import org.opensearch.index.query.TermQueryBuilder;
+import org.opensearch.index.query.WildcardQueryBuilder;
+import org.opensearch.sql.ast.tree.Sort;
+import org.opensearch.sql.ast.tree.Sort.SortOption;
+import
org.opensearch.sql.exception.ExpressionEvaluationException;
+import org.opensearch.sql.expression.Expression;
+import org.opensearch.sql.expression.ExpressionNodeVisitor;
+import org.opensearch.sql.expression.ReferenceExpression;
+import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder;
+import org.opensearch.sql.opensearch.storage.FilterType;
+import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder;
+import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder.ScriptQueryUnSupportedException;
+import org.opensearch.sql.opensearch.storage.serde.DefaultExpressionSerializer;
+import org.opensearch.sql.planner.logical.LogicalFilter;
+import org.opensearch.sql.planner.logical.LogicalLimit;
+import org.opensearch.sql.planner.logical.LogicalSort;
+
+/**
+ * Query builder for vector search. The knn relevance score is preserved regardless of placement
+ * strategy — in {@code EFFICIENT} mode the knn query carries its own scores, and in {@code POST}
+ * mode the knn query sits in a scoring ({@code must}) context while the WHERE clause is applied as
+ * a non-scoring ({@code filter}) clause.
+ *
+ * <p>Supports two filter placement strategies via {@link FilterType}:
+ *
+ * <ul>
+ *   <li>{@code EFFICIENT} — WHERE inside {@code knn.filter} for pre-filtering during ANN search
+ *       (default).
+ *   <li>{@code POST} — WHERE in {@code bool.filter} outside knn (post-filtering fallback, used when
+ *       the WHERE shape is not compatible with pre-filtering).
+ * </ul>
+ */
+public class VectorSearchQueryBuilder extends OpenSearchIndexScanQueryBuilder {
+
+  private final QueryBuilder knnQuery;
+  private final Map<String, String> options;
+  private final FilterType filterType;
+  private final boolean filterTypeExplicit;
+  private final Function<QueryBuilder, QueryBuilder> rebuildKnnWithFilter;
+  private boolean filterPushed = false;
+  private boolean limitPushed = false;
+
+  /** Full constructor with filter type support. */
+  public VectorSearchQueryBuilder(
+      OpenSearchRequestBuilder requestBuilder,
+      QueryBuilder knnQuery,
+      Map<String, String> options,
+      FilterType filterType,
+      boolean filterTypeExplicit,
+      Function<QueryBuilder, QueryBuilder> rebuildKnnWithFilter) {
+    super(requestBuilder);
+    requestBuilder.getSourceBuilder().query(knnQuery);
+    this.knnQuery = knnQuery;
+    this.options = options;
+    this.filterType = filterType != null ? filterType : FilterType.EFFICIENT;
+    this.filterTypeExplicit = filterTypeExplicit;
+    if (this.filterType == FilterType.EFFICIENT && rebuildKnnWithFilter == null) {
+      throw new IllegalArgumentException(
+          "EFFICIENT filter mode requires a non-null rebuildKnnWithFilter callback");
+    }
+    this.rebuildKnnWithFilter = rebuildKnnWithFilter;
+  }
+
+  /**
+   * Test-only constructor — pins {@link FilterType#POST} so callers that do not wire a {@code
+   * rebuildKnnWithFilter} callback (unit tests) can still exercise the push-down contract.
+   * Production callers always go through the full constructor, which defaults to {@link
+   * FilterType#EFFICIENT}.
+   */
+  public VectorSearchQueryBuilder(
+      OpenSearchRequestBuilder requestBuilder, QueryBuilder knnQuery, Map<String, String> options) {
+    this(requestBuilder, knnQuery, options, FilterType.POST, false, null);
+  }
+
+  @Override
+  public boolean pushDownFilter(LogicalFilter filter) {
+    FilterQueryBuilder queryBuilder = new FilterQueryBuilder(new DefaultExpressionSerializer());
+    Expression queryCondition = filter.getCondition();
+
+    // _score is synthetic, not a stored field; a range query on it silently returns 0 rows.
+    // Users who want a score floor should use option='min_score=...'.
+    if (containsScoreReference(queryCondition)) {
+      throw new ExpressionEvaluationException(
+          "WHERE on _score is not supported on vectorSearch()."
+              + " Use option='min_score=...' for score-floor filtering.");
+    }
+
+    QueryBuilder whereQuery;
+    try {
+      whereQuery = queryBuilder.build(queryCondition);
+    } catch (ScriptQueryUnSupportedException e) {
+      if (filterTypeExplicit) {
+        throw new ExpressionEvaluationException(
+            "filter_type only works when the WHERE clause can be translated to an"
+                + " OpenSearch filter. Rewrite the WHERE clause or omit filter_type.");
+      }
+      // Default mode: fall back to in-memory filtering (matches base class behavior)
+      return false;
+    }
+    filterPushed = true;
+
+    if (filterType == FilterType.EFFICIENT) {
+      // Fail closed: knn.filter on AOSS rejects script queries and nested predicates expand the
+      // preview contract. Allow-list validator beats a blacklist walker.
+      validateEfficientFilterSafe(whereQuery);
+      QueryBuilder rebuiltKnn = rebuildKnnWithFilter.apply(whereQuery);
+      requestBuilder.getSourceBuilder().query(rebuiltKnn);
+    } else {
+      // POST mode: knn in must (scores), WHERE in filter (no scoring impact)
+      BoolQueryBuilder combined = QueryBuilders.boolQuery().must(knnQuery).filter(whereQuery);
+      requestBuilder.getSourceBuilder().query(combined);
+    }
+    return true;
+  }
+
+  @Override
+  public boolean pushDownLimit(LogicalLimit limit) {
+    // OFFSET would shift the search window and silently drop top results; reject with a clear
+    // error rather than have the parent path push `from: <offset>` into the request.
+    if (limit.getOffset() != null && limit.getOffset() != 0) {
+      throw new ExpressionEvaluationException(
+          "OFFSET is not supported on vectorSearch(). Remove OFFSET and use LIMIT only.");
+    }
+    validateLimitWithinK(limit.getLimit());
+    limitPushed = true;
+    return super.pushDownLimit(limit);
+  }
+
+  @Override
+  public boolean pushDownSort(LogicalSort sort) {
+    // Vector search returns results sorted by _score DESC by default.
+    // Only _score DESC is meaningful; reject all other sort expressions.
+    for (Pair<SortOption, Expression> sortItem : sort.getSortList()) {
+      Expression expr = sortItem.getRight();
+      if (!(expr instanceof ReferenceExpression)
+          || !"_score".equals(((ReferenceExpression) expr).getAttr())) {
+        throw new ExpressionEvaluationException(
+            String.format(
+                "vectorSearch only supports ORDER BY _score DESC; "
+                    + "unsupported sort expression: %s",
+                expr));
+      }
+      if (sortItem.getLeft().getSortOrder() != Sort.SortOrder.DESC) {
+        throw new ExpressionEvaluationException(
+            "vectorSearch only supports ORDER BY _score DESC; _score ASC is not supported");
+      }
+    }
+    // _score DESC is knn's natural order, so the sort itself is not pushed. Preserve the
+    // parent's sort.getCount() → limit contract; SQL sends 0, PPL may combine sort+limit.
+    if (sort.getCount() != 0) {
+      validateLimitWithinK(sort.getCount());
+      limitPushed = true;
+      requestBuilder.pushDownLimit(sort.getCount(), 0);
+    }
+    return true;
+  }
+
+  /** Validates that the requested limit does not exceed k in top-k mode. */
+  private void validateLimitWithinK(int limit) {
+    if (options.containsKey("k")) {
+      int k = Integer.parseInt(options.get("k"));
+      if (limit > k) {
+        throw new ExpressionEvaluationException(
+            String.format("LIMIT %d exceeds k=%d in top-k vector search", limit, k));
+      }
+    }
+  }
+
+  // True if any ReferenceExpression in the tree names _score (case-insensitive, so quoted/
+  // backticked variants cannot bypass the guard).
+  private static boolean containsScoreReference(Expression expr) {
+    AtomicBoolean found = new AtomicBoolean(false);
+    expr.accept(
+        new ExpressionNodeVisitor<Void, Void>() {
+          @Override
+          public Void visitReference(ReferenceExpression node, Void context) {
+            if (node.getAttr() != null && "_score".equalsIgnoreCase(node.getAttr())) {
+              found.set(true);
+            }
+            return null;
+          }
+        },
+        null);
+    return found.get();
+  }
+
+  // Allow-list of leaf query types FilterQueryBuilder emits today. Any new wrapper or container
+  // appearing here must fail closed rather than silently embed under knn.filter.
+  private static final Set<Class<? extends QueryBuilder>> SAFE_EFFICIENT_FILTER_LEAVES =
+      Set.of(
+          TermQueryBuilder.class,
+          RangeQueryBuilder.class,
+          WildcardQueryBuilder.class,
+          MatchQueryBuilder.class,
+          MatchPhraseQueryBuilder.class,
+          MatchPhrasePrefixQueryBuilder.class,
+          MultiMatchQueryBuilder.class,
+          QueryStringQueryBuilder.class,
+          SimpleQueryStringBuilder.class,
+          MatchBoolPrefixQueryBuilder.class,
+          ExistsQueryBuilder.class);
+
+  // Package-private for direct branch coverage in unit tests. Fail-closed: recurse known
+  // containers, reject ScriptQueryBuilder/NestedQueryBuilder with targeted messages, allow
+  // listed leaves, reject everything else as unsupported shape.
+  static void validateEfficientFilterSafe(QueryBuilder qb) {
+    if (qb == null) {
+      return;
+    }
+    if (qb instanceof ScriptQueryBuilder) {
+      throw new ExpressionEvaluationException(
+          "vectorSearch WHERE pre-filtering does not support predicates that compile to"
+              + " script queries (arithmetic, function calls, CASE, date math). Rewrite the"
+              + " WHERE clause to use term/range/bool predicates, or set filter_type=post to"
+              + " apply the predicate after the k-NN search.");
+    }
+    if (qb instanceof BoolQueryBuilder) {
+      BoolQueryBuilder bool = (BoolQueryBuilder) qb;
+      bool.must().forEach(VectorSearchQueryBuilder::validateEfficientFilterSafe);
+      bool.filter().forEach(VectorSearchQueryBuilder::validateEfficientFilterSafe);
+      bool.should().forEach(VectorSearchQueryBuilder::validateEfficientFilterSafe);
+      bool.mustNot().forEach(VectorSearchQueryBuilder::validateEfficientFilterSafe);
+      return;
+    }
+    if (qb instanceof ConstantScoreQueryBuilder) {
+      validateEfficientFilterSafe(((ConstantScoreQueryBuilder) qb).innerQuery());
+      return;
+    }
+    if (qb instanceof NestedQueryBuilder) {
+      throw new ExpressionEvaluationException(
+          "vectorSearch WHERE pre-filtering does not support nested predicates in this"
+              + " preview. Rewrite the WHERE clause using non-nested fields, or set"
+              + " filter_type=post to apply the predicate after the k-NN search.");
+    }
+    if (SAFE_EFFICIENT_FILTER_LEAVES.contains(qb.getClass())) {
+      return;
+    }
+    throw new ExpressionEvaluationException(
+        "vectorSearch WHERE pre-filtering encountered an unsupported filter query shape: "
+            + qb.getClass().getSimpleName()
+            + ". Rewrite the WHERE clause using simple term/range/bool predicates, or set"
+            + " filter_type=post to apply the predicate after the k-NN search.");
+  }
+
+  @Override
+  public OpenSearchRequestBuilder build() {
+    if (filterTypeExplicit && !filterPushed) {
+      throw new ExpressionEvaluationException("filter_type requires a pushdownable WHERE clause");
+    }
+    boolean isRadial = !options.containsKey("k");
+    if (isRadial && !limitPushed) {
+      throw new ExpressionEvaluationException(
+          "LIMIT is required for radial vector search (max_distance or min_score)."
+              + " Without LIMIT, the result set size is unbounded.");
+    }
+    return super.build();
+  }
+}
diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilder.java
index 6ca25b7e9b7..2ff0dfa4a50 100644
--- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilder.java
+++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilder.java
@@ -29,6 +29,7 @@
 import org.opensearch.sql.expression.function.FunctionName;
 import org.opensearch.sql.opensearch.storage.script.CompoundedScriptEngine.ScriptEngineType;
 import org.opensearch.sql.opensearch.storage.script.core.ExpressionScript;
+import org.opensearch.sql.opensearch.storage.script.filter.lucene.ExistsQuery;
 import org.opensearch.sql.opensearch.storage.script.filter.lucene.LikeQuery;
 import org.opensearch.sql.opensearch.storage.script.filter.lucene.LuceneQuery;
 import org.opensearch.sql.opensearch.storage.script.filter.lucene.NestedQuery;
@@ -86,6 +87,8 @@ public ScriptQueryUnSupportedException(String message) {
           .put(BuiltinFunctionName.WILDCARD_QUERY.getName(), new WildcardQuery())
           .put(BuiltinFunctionName.WILDCARDQUERY.getName(), new WildcardQuery())
           .put(BuiltinFunctionName.NESTED.getName(), new NestedQuery())
+          .put(BuiltinFunctionName.IS_NULL.getName(), new ExistsQuery(true /* negated */))
+          .put(BuiltinFunctionName.IS_NOT_NULL.getName(), new ExistsQuery(false))
           .build();
 
   /**
diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/ExistsQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/ExistsQuery.java
new file mode 100644
index 00000000000..5822f2f416a
--- /dev/null
+++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/ExistsQuery.java
@@ -0,0 +1,69 @@
+/*
+ * Copyright
OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.opensearch.storage.script.filter.lucene;
+
+import static org.opensearch.sql.analysis.NestedAnalyzer.isNestedFunction;
+
+import lombok.RequiredArgsConstructor;
+import org.opensearch.index.query.QueryBuilder;
+import org.opensearch.index.query.QueryBuilders;
+import org.opensearch.sql.expression.FunctionExpression;
+import org.opensearch.sql.expression.ReferenceExpression;
+
+/**
+ * Lucene query that builds a native {@code exists} DSL fragment for {@code IS NULL} / {@code IS NOT
+ * NULL} predicates.
+ *
+ * <p>This replaces the previous behavior of serializing these unary predicates as compounded script
+ * queries. The native {@code exists} query is cheaper, AOSS / serverless compatible, and the
+ * expected DSL shape downstream consumers look for.
+ *
+ * <p>Unlike most {@link LuceneQuery} subclasses this predicate family is unary (a single reference
+ * argument) rather than the standard {ref, literal} pair, so this class overrides both {@link
+ * #canSupport(FunctionExpression)} and {@link #build(FunctionExpression)}.
+ *
+ * <p>Nested-field predicates are intentionally NOT supported here: OpenSearch DSL does not handle
+ * {@code IS_NULL} / {@code IS_NOT_NULL} on nested fields correctly (see the equivalent guard in
+ * {@code PredicateAnalyzer} for the Calcite path). When the reference is a nested function, {@link
+ * #canSupport} returns {@code false} and {@link
+ * org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder} falls back to the script
+ * query path, preserving correctness.
+ */
+@RequiredArgsConstructor
+public class ExistsQuery extends LuceneQuery {
+
+  /** When true, the predicate is {@code IS NULL} and the exists query is wrapped in must_not. */
+  private final boolean negated;
+
+  @Override
+  public boolean canSupport(FunctionExpression func) {
+    return func.getArguments().size() == 1
+        && func.getArguments().get(0) instanceof ReferenceExpression
+        && !isNestedFunction(func.getArguments().get(0));
+  }
+
+  /**
+   * Unary IS NULL / IS NOT NULL has no {@code arg[1]}, so we must never route through {@link
+   * org.opensearch.sql.opensearch.storage.script.filter.lucene.NestedQuery#buildNested} — that path
+   * reads {@code func.getArguments().get(1)} and would throw. Returning {@code false} here forces
+   * {@code FilterQueryBuilder} to fall back to the script-query path for nested-field predicates.
+   */
+  @Override
+  public boolean isNestedPredicate(FunctionExpression func) {
+    return false;
+  }
+
+  @Override
+  public QueryBuilder build(FunctionExpression func) {
+    ReferenceExpression ref = (ReferenceExpression) func.getArguments().get(0);
+    String fieldName = ref.getRawPath();
+    QueryBuilder existsQuery = QueryBuilders.existsQuery(fieldName);
+    if (negated) {
+      return QueryBuilders.boolQuery().mustNot(existsQuery);
+    }
+    return existsQuery;
+  }
+}
diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java
index 38f2ae495e0..fa04395e065 100644
--- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java
+++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java
@@ -11,12 +11,14 @@
 import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME;
 import static org.opensearch.sql.utils.SystemIndexUtils.TABLE_INFO;
 
+import java.util.Collection;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
 import org.opensearch.sql.DataSourceSchemaName;
 import org.opensearch.sql.common.setting.Settings;
+import org.opensearch.sql.expression.function.FunctionResolver;
 import org.opensearch.sql.opensearch.client.OpenSearchClient;
 import org.opensearch.sql.opensearch.storage.system.OpenSearchSystemIndex;
 import org.opensearch.sql.storage.Table;
@@ -36,6 +38,15 @@ public void getTable() {
     assertAll(() -> assertNotNull(table), () -> assertTrue(table instanceof OpenSearchIndex));
   }
 
+  @Test
+  public void getFunctionsReturnsVectorSearchResolver() {
+    OpenSearchStorageEngine engine = new OpenSearchStorageEngine(client, settings);
+    Collection<FunctionResolver> functions = engine.getFunctions();
+    assertTrue(
+
functions.stream().anyMatch(f -> f instanceof VectorSearchTableFunctionResolver),
+        "getFunctions() should contain a VectorSearchTableFunctionResolver");
+  }
+
   @Test
   public void getSystemTable() {
     OpenSearchStorageEngine engine = new OpenSearchStorageEngine(client, settings);
diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchIndexTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchIndexTest.java
new file mode 100644
index 00000000000..6a9a76a48f0
--- /dev/null
+++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchIndexTest.java
@@ -0,0 +1,266 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.opensearch.storage;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.lenient;
+import static org.mockito.Mockito.when;
+
+import com.google.common.collect.ImmutableMap;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.common.unit.TimeValue;
+import org.opensearch.sql.common.setting.Settings;
+import org.opensearch.sql.opensearch.client.OpenSearchClient;
+import org.opensearch.sql.opensearch.data.type.OpenSearchDataType;
+import org.opensearch.sql.opensearch.data.type.OpenSearchDataType.MappingType;
+import org.opensearch.sql.opensearch.mapping.IndexMapping;
+
+@ExtendWith(MockitoExtension.class)
+class VectorSearchIndexTest {
+
+  @Mock private OpenSearchClient client;
+
+  @Mock private Settings settings;
+
+  @Mock private IndexMapping indexMapping;
+
+  @Test
+  void buildKnnQueryJsonTopK() {
+    VectorSearchIndex index =
+        new VectorSearchIndex(
+            client,
+            settings,
+            "test-index",
+            "embedding",
+            new float[] {1.0f, 2.0f, 3.0f},
+            Map.of("k", "5"));
+
+    String json = index.buildKnnQueryJson();
+    assertEquals("{\"knn\":{\"embedding\":{\"vector\":[1.0,2.0,3.0],\"k\":5}}}", json);
+  }
+
+  @Test
+  void buildKnnQueryJsonRadialMaxDistance() {
+    VectorSearchIndex index =
+        new VectorSearchIndex(
+            client,
+            settings,
+            "test-index",
+            "embedding",
+            new float[] {1.0f, 2.0f},
+            Map.of("max_distance", "10.5"));
+
+    String json = index.buildKnnQueryJson();
+    assertEquals("{\"knn\":{\"embedding\":{\"vector\":[1.0,2.0],\"max_distance\":10.5}}}", json);
+  }
+
+  @Test
+  void buildKnnQueryJsonRadialMinScore() {
+    VectorSearchIndex index =
+        new VectorSearchIndex(
+            client,
+            settings,
+            "test-index",
+            "embedding",
+            new float[] {0.5f},
+            Map.of("min_score", "0.8"));
+
+    String json = index.buildKnnQueryJson();
+    assertEquals("{\"knn\":{\"embedding\":{\"vector\":[0.5],\"min_score\":0.8}}}", json);
+  }
+
+  @Test
+  void buildKnnQueryJsonNestedFieldName() {
+    VectorSearchIndex index =
+        new VectorSearchIndex(
+            client,
+            settings,
+            "test-index",
+            "doc.embedding",
+            new float[] {1.0f, 2.0f},
+            Map.of("k", "10"));
+
+    String json = index.buildKnnQueryJson();
+    assertTrue(json.contains("\"doc.embedding\""), "Should contain nested field name with dot");
+  }
+
+  @Test
+  void buildKnnQueryJsonMultiElementVector() {
+    VectorSearchIndex index =
+        new VectorSearchIndex(
+            client,
+            settings,
+            "test-index",
+            "embedding",
+            new float[] {1.0f, -2.5f, 0.0f, 3.14f, 100.0f},
+            Map.of("k", "3"));
+
+    String json = index.buildKnnQueryJson();
+    assertTrue(
+        json.contains("[1.0,-2.5,0.0,3.14,100.0]"),
+        "Should contain all vector components with correct comma separation");
+  }
+
+  @Test
+  void buildKnnQueryJsonSingleElementVector() {
+    VectorSearchIndex index =
+        new VectorSearchIndex(
+            client, settings, "test-index", "embedding", new float[] {42.0f}, Map.of("k", "1"));
+
+    String json = index.buildKnnQueryJson();
+    assertTrue(json.contains("[42.0]"), "Should contain single-element vector");
+  }
+
+  @Test
+  void buildKnnQueryJsonNumericOptionRenderedUnquoted() {
+    LinkedHashMap<String, String> options = new LinkedHashMap<>();
+    options.put("k", "5");
+
+    VectorSearchIndex index =
+        new VectorSearchIndex(
+            client, settings, "test-index", "embedding", new float[] {1.0f}, options);
+
+    String json = index.buildKnnQueryJson();
+    assertTrue(json.contains("\"k\":5"), "Numeric option should be unquoted");
+  }
+
+  @Test
+  void buildKnnQueryJsonNonNumericOptionRenderedQuoted() {
+    LinkedHashMap<String, String> options = new LinkedHashMap<>();
+    options.put("k", "5");
+    options.put("method", "hnsw");
+
+    VectorSearchIndex index =
+        new VectorSearchIndex(
+            client, settings, "test-index", "embedding", new float[] {1.0f}, options);
+
+    String json = index.buildKnnQueryJson();
+    assertTrue(json.contains("\"method\":\"hnsw\""), "Non-numeric option should be quoted");
+    assertTrue(json.contains("\"k\":5"), "Numeric option should be unquoted");
+  }
+
+  @Test
+  void buildKnnQueryJsonWithFilterEmbeds() {
+    VectorSearchIndex index =
+        new VectorSearchIndex(
+            client,
+            settings,
+            "test-index",
+            "embedding",
+            new float[] {1.0f, 2.0f},
+            Map.of("k", "5"),
+            FilterType.EFFICIENT);
+
+    String filterJson = "{\"term\":{\"city\":{\"value\":\"Miami\"}}}";
+    String json = index.buildKnnQueryJson(filterJson);
+
+    assertTrue(json.contains("\"filter\""), "Should contain filter field");
+    assertTrue(json.contains("\"term\""), "Should contain the filter content");
+    assertTrue(json.contains("\"k\":5"), "Should still contain k");
+    assertTrue(json.contains("\"vector\":[1.0,2.0]"), "Should contain vector");
+  }
+
+  @Test
+  void buildKnnQueryJsonWithFilterRadial() {
+    VectorSearchIndex index =
+        new VectorSearchIndex(
+            client,
+            settings,
+            "test-index",
+            "embedding",
+            new float[] {1.0f},
+            Map.of("max_distance", "10.5"),
+            FilterType.EFFICIENT);
+
+    String filterJson = "{\"range\":{\"rating\":{\"gte\":4.0}}}";
+    String json = index.buildKnnQueryJson(filterJson);
+
+    assertTrue(json.contains("\"max_distance\":10.5"), "Should contain max_distance");
+    assertTrue(json.contains("\"filter\""), "Should contain filter");
+  }
+
+  @Test
+  void buildKnnQueryJsonNullFilterProducesBaseJson() {
+    VectorSearchIndex index =
+        new VectorSearchIndex(
+            client,
+            settings,
+            "test-index",
+            "embedding",
+            new float[] {1.0f},
+            Map.of("k", "5"),
+            null);
+
+    String json = index.buildKnnQueryJson(null);
+    String baseJson = index.buildKnnQueryJson();
+
+    assertEquals(baseJson, json, "null filter should produce same JSON as no-arg version");
+    assertFalse(json.contains("\"filter\""), "Should not contain filter field");
+  }
+
+  @Test
+  void buildKnnQueryJsonExcludesFilterType() {
+    LinkedHashMap<String, String> options = new LinkedHashMap<>();
+    options.put("k", "5");
+
+    VectorSearchIndex index =
+        new VectorSearchIndex(
+            client,
+            settings,
+            "test-index",
+            "embedding",
+            new float[] {1.0f},
+            options,
+            FilterType.EFFICIENT);
+
+    String json = index.buildKnnQueryJson();
+    assertFalse(json.contains("filter_type"), "filter_type should not appear in knn JSON");
+    assertTrue(json.contains("\"k\":5"), "k should still be present");
+  }
+
+  @Test
+  void isInstanceOfOpenSearchIndex() {
+    VectorSearchIndex index =
+        new VectorSearchIndex(
+            client, settings, "test-index", "embedding", new float[] {1.0f}, Map.of("k", "5"));
+    assertTrue(index instanceof OpenSearchIndex);
+  }
+
+  @Test
+  void createScanBuilderRejectsIndexWithScoreField() {
+    // A mapping that declares a user field named _score cannot coexist with the synthetic
+    // v._score column exposed by vectorSearch(); the guard in createScanBuilder should reject
+    // it with a clear, user-facing error.
+    lenient()
+        .when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE))
+        .thenReturn(TimeValue.timeValueMinutes(1));
+    when(indexMapping.getFieldMappings())
+        .thenReturn(Map.of("_score", OpenSearchDataType.of(MappingType.Float)));
+    when(client.getIndexMappings("test-index"))
+        .thenReturn(ImmutableMap.of("test-index", indexMapping));
+
+    VectorSearchIndex index =
+        new VectorSearchIndex(
+            client, settings, "test-index", "embedding", new float[] {1.0f}, Map.of("k", "5"));
+
+    IllegalArgumentException ex =
+        assertThrows(IllegalArgumentException.class, index::createScanBuilder);
+    assertTrue(
+        ex.getMessage().contains("_score"),
+        "error message should mention the colliding _score field");
+    assertTrue(
+        ex.getMessage().contains("collides"),
+        "error message should describe the collision, got: " + ex.getMessage());
+  }
+}
diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java
new file mode 100644
index 00000000000..7bd64838876
--- /dev/null
+++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java
@@ -0,0 +1,778 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.opensearch.storage;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.util.List;
+import java.util.Map;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.sql.common.setting.Settings;
+import org.opensearch.sql.data.type.ExprCoreType;
+import org.opensearch.sql.exception.ExpressionEvaluationException;
+import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.storage.capability.KnnPluginCapability; +import org.opensearch.sql.storage.Table; + +@ExtendWith(MockitoExtension.class) +class VectorSearchTableFunctionImplementationTest { + + @Mock private OpenSearchClient client; + + @Mock private Settings settings; + + // No-op capability — tests in this class don't exercise the k-NN plugin probe. + // Dedicated tests for the probe live in KnnPluginCapabilityTest. + private final KnnPluginCapability knnCapability = + org.mockito.Mockito.mock(KnnPluginCapability.class); + + @Test + void testValueOfThrows() { + VectorSearchTableFunctionImplementation impl = createImpl(); + UnsupportedOperationException ex = + assertThrows(UnsupportedOperationException.class, () -> impl.valueOf()); + assertTrue(ex.getMessage().contains("only supported in FROM clause")); + } + + @Test + void testType() { + VectorSearchTableFunctionImplementation impl = createImpl(); + assertEquals(ExprCoreType.STRUCT, impl.type()); + } + + @Test + void testToString() { + VectorSearchTableFunctionImplementation impl = createImpl(); + String str = impl.toString(); + assertTrue(str.contains("vectorsearch")); + assertTrue(str.contains("table=")); + assertTrue(str.contains("my-index")); + } + + @Test + void testApplyArguments() { + VectorSearchTableFunctionImplementation impl = createImpl(); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testApplyArgumentsDoesNotProbeKnnCapability() { + // Contract: applyArguments() runs during analysis (including _explain) and must NOT invoke + // the k-NN plugin probe. The probe is deferred to scan open() so pluginless clusters can + // still explain and validate vectorSearch() queries locally. 
+ KnnPluginCapability observingCapability = org.mockito.Mockito.mock(KnnPluginCapability.class); + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = + List.of( + DSL.namedArgument("table", DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding")), + DSL.namedArgument("vector", DSL.literal("[1.0, 2.0]")), + DSL.namedArgument("option", DSL.literal("k=5"))); + VectorSearchTableFunctionImplementation impl = + new VectorSearchTableFunctionImplementation( + functionName, args, client, settings, observingCapability); + impl.applyArguments(); + org.mockito.Mockito.verify(observingCapability, org.mockito.Mockito.never()).requireInstalled(); + } + + @Test + void testApplyArgumentsWithBracketedVector() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0, 3.0]", "k=5"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testApplyArgumentsWithUnbracketedVector() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "1.0, 2.0, 3.0", "k=5"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testUnknownOptionKeyThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=10,method.ef_search=100"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Unknown option key")); + assertTrue(ex.getMessage().contains("method.ef_search")); + } + + @Test + void testApplyArgumentsWithMaxDistance() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "max_distance=10.0"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void 
testApplyArgumentsWithMinScore() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "min_score=0.5"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testUnknownOptionKeyOnlyThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "not_a_key=100"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Unknown option key")); + } + + @Test + void testParseOptionsMultiple() { + Map opts = + VectorSearchTableFunctionImplementation.parseOptions("k=5,max_distance=10.0"); + assertEquals("5", opts.get("k")); + assertEquals("10.0", opts.get("max_distance")); + } + + @Test + void testMalformedOptionSegmentThrows() { + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchTableFunctionImplementation.parseOptions("k=5,badoption")); + assertTrue(ex.getMessage().contains("Malformed option segment")); + } + + @Test + void testDuplicateOptionKeyThrows() { + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchTableFunctionImplementation.parseOptions("k=5,k=10")); + assertTrue(ex.getMessage().contains("Duplicate option key")); + } + + @Test + void testNoRequiredOptionThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", ""); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Missing required option")); + } + + @Test + void testEmptyVectorThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[]", "k=5"); + ExpressionEvaluationException ex = + 
assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must not be empty")); + } + + @Test + void testMalformedVectorComponentThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, abc, 3.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Invalid vector component")); + } + + @Test + void testNonFiniteVectorComponentThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, Infinity, 3.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be a finite number")); + } + + @Test + void testMissingArgumentThrows() { + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = + List.of( + DSL.namedArgument("table", DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding")), + DSL.namedArgument("vector", DSL.literal("[1.0, 2.0]"))); + VectorSearchTableFunctionImplementation impl = + new VectorSearchTableFunctionImplementation( + functionName, args, client, settings, knnCapability); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertEquals("Missing required argument: option", ex.getMessage()); + } + + @Test + void testInvalidFieldNameThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "field\"injection", "[1.0, 2.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Invalid field name")); + } + + @Test + void testNestedFieldNameAllowed() { + VectorSearchTableFunctionImplementation impl = + 
createImplWithArgs("my-index", "doc.embedding", "[1.0, 2.0]", "k=5"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testNonNumericKThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=abc"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be an integer")); + } + + @Test + void testNonNumericMaxDistanceThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "max_distance=notanumber"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be a number")); + } + + @Test + void testInfiniteMinScoreThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "min_score=Infinity"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be a finite number")); + } + + @Test + void testMutualExclusivityKAndMaxDistanceThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=5,max_distance=10.0"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Only one of")); + } + + @Test + void testMutualExclusivityKAndMinScoreThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=5,min_score=0.5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Only one of")); + } 
+ + @Test + void testMutualExclusivityAllThreeThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs( + "my-index", "embedding", "[1.0, 2.0]", "k=5,max_distance=10.0,min_score=0.5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Only one of")); + } + + @Test + void testKTooSmallThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=0"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("k must be between 1 and 10000")); + } + + @Test + void testKTooLargeThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=10001"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("k must be between 1 and 10000")); + } + + @Test + void testKBoundaryValuesAllowed() { + // k=1 should work + VectorSearchTableFunctionImplementation impl1 = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=1"); + assertTrue(impl1.applyArguments() instanceof VectorSearchIndex); + + // k=10000 should work + VectorSearchTableFunctionImplementation impl2 = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=10000"); + assertTrue(impl2.applyArguments() instanceof VectorSearchIndex); + } + + @Test + void testNonNamedArgThrows() { + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = List.of(DSL.literal("my-index")); + VectorSearchTableFunctionImplementation impl = + new VectorSearchTableFunctionImplementation( + functionName, args, client, settings, knnCapability); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + 
assertTrue(ex.getMessage().contains("requires named arguments")); + } + + @Test + void testNullArgNameThrows() { + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = + List.of( + DSL.namedArgument(null, DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding")), + DSL.namedArgument("vector", DSL.literal("[1.0, 2.0]")), + DSL.namedArgument("option", DSL.literal("k=5"))); + VectorSearchTableFunctionImplementation impl = + new VectorSearchTableFunctionImplementation( + functionName, args, client, settings, knnCapability); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("requires named arguments")); + } + + @Test + void testNaNVectorComponentThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, NaN, 3.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be a finite number")); + } + + @Test + void testEmptyOptionKeyThrows() { + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchTableFunctionImplementation.parseOptions("=value")); + assertTrue(ex.getMessage().contains("Malformed option segment")); + } + + @Test + void testEmptyOptionValueThrows() { + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchTableFunctionImplementation.parseOptions("key=")); + assertTrue(ex.getMessage().contains("Malformed option segment")); + } + + @Test + void testNegativeKThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=-1"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + 
assertTrue(ex.getMessage().contains("k must be between 1 and 10000")); + } + + @Test + void testNaNMaxDistanceThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "max_distance=NaN"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be a finite number")); + } + + @Test + void testNaNMinScoreThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "min_score=NaN"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be a finite number")); + } + + @Test + void testCaseInsensitiveArgLookup() { + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = + List.of( + DSL.namedArgument("TABLE", DSL.literal("my-index")), + DSL.namedArgument("FIELD", DSL.literal("embedding")), + DSL.namedArgument("VECTOR", DSL.literal("[1.0, 2.0]")), + DSL.namedArgument("OPTION", DSL.literal("k=5"))); + VectorSearchTableFunctionImplementation impl = + new VectorSearchTableFunctionImplementation( + functionName, args, client, settings, knnCapability); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testInvalidFilterTypeRejects() { + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = + List.of( + DSL.namedArgument("table", DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding")), + DSL.namedArgument("vector", DSL.literal("[1.0, 2.0]")), + DSL.namedArgument("option", DSL.literal("k=5,filter_type=invalid"))); + VectorSearchTableFunctionImplementation impl = + new VectorSearchTableFunctionImplementation( + functionName, args, client, settings, knnCapability); + ExpressionEvaluationException ex = + 
assertThrows(ExpressionEvaluationException.class, impl::applyArguments); + assertTrue(ex.getMessage().contains("filter_type must be one of")); + } + + @Test + void testFilterTypePostAccepted() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=5,filter_type=post"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testFilterTypeEfficientAccepted() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=5,filter_type=efficient"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testParseOptionsPreservesFilterTypeValue() { + Map options = + VectorSearchTableFunctionImplementation.parseOptions("k=5,filter_type=post"); + assertEquals("post", options.get("filter_type")); + } + + @Test + void applyArguments_rejectsInvalidTableName() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("idx\"; DROP", "embedding", "[1.0, 2.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Invalid table name")); + assertTrue( + ex.getMessage() + .contains("must contain only alphanumeric characters, dots, underscores, or hyphens")); + } + + @Test + void applyArguments_rejectsAllRoutingTarget() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("_all", "embedding", "[1.0, 2.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Invalid table name")); + assertTrue(ex.getMessage().contains("_all")); + } + + @Test + void applyArguments_rejectsSingleDotTable() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs(".", "embedding", "[1.0, 2.0]", "k=5"); + 
ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Invalid table name")); + } + + @Test + void applyArguments_rejectsDoubleDotTable() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("..", "embedding", "[1.0, 2.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Invalid table name")); + } + + @Test + void applyArguments_rejectsWildcardTableWithDedicatedMessage() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("sql_vector_*", "embedding", "[1.0, 2.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Invalid table name")); + assertTrue(ex.getMessage().contains("wildcards ('*')")); + assertTrue(ex.getMessage().contains("single concrete index")); + } + + @Test + void applyArguments_rejectsBareStarTableWithDedicatedMessage() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("*", "embedding", "[1.0, 2.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("wildcards ('*')")); + } + + @Test + void applyArguments_rejectsMultiTargetTableWithDedicatedMessage() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("idx_a,idx_b", "embedding", "[1.0, 2.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Invalid table name")); + assertTrue(ex.getMessage().contains("multi-target")); + assertTrue(ex.getMessage().contains("single concrete index")); + } + + @Test + void applyArguments_rejectsMidNameStarTable() { + 
VectorSearchTableFunctionImplementation impl = + createImplWithArgs("foo*bar", "embedding", "[1.0, 2.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("wildcards ('*')")); + } + + @Test + void validateNamedArgs_rejectsDuplicateNames() { + // Two occurrences of "table" reach the Implementation layer directly (bypassing the resolver). + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = + List.of( + DSL.namedArgument("table", DSL.literal("a")), + DSL.namedArgument("table", DSL.literal("b")), + DSL.namedArgument("vector", DSL.literal("[1.0]")), + DSL.namedArgument("option", DSL.literal("k=5"))); + VectorSearchTableFunctionImplementation impl = + new VectorSearchTableFunctionImplementation( + functionName, args, client, settings, knnCapability); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Duplicate argument name")); + assertTrue(ex.getMessage().contains("table")); + } + + // ── Option parsing: empty value, whitespace, unknown keys ──────────── + + @Test + void parseOptions_rejectsEmptyValue() { + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchTableFunctionImplementation.parseOptions("k=")); + assertTrue(ex.getMessage().contains("Malformed option segment")); + } + + @Test + void parseOptions_rejectsEmptyValueInMidSegment() { + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchTableFunctionImplementation.parseOptions("k=,filter_type=post")); + assertTrue(ex.getMessage().contains("Malformed option segment")); + } + + @Test + void parseOptions_trimsWhitespaceAroundKeyAndValue() { + Map options = + VectorSearchTableFunctionImplementation.parseOptions(" k = 5 , filter_type = post "); + 
assertEquals("5", options.get("k")); + assertEquals("post", options.get("filter_type")); + } + + @Test + void applyArguments_rejectsUnknownOptionKey() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs( + "my-index", "embedding", "[1.0, 2.0]", "k=5,method_parameters.ef_search=100"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Unknown option key")); + assertTrue(ex.getMessage().contains("method_parameters.ef_search")); + } + + // ── Vector parsing: non-comma separator ───────────────────────────── + + @Test + void applyArguments_rejectsSemicolonSeparatorInVector() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0;2.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("vector=")); + assertTrue(ex.getMessage().contains("comma-separated")); + } + + @Test + void applyArguments_rejectsColonSeparatorInVector() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0:2.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("vector=")); + } + + @Test + void applyArguments_rejectsPipeSeparatorInVector() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0|2.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("vector=")); + } + + // ── Option bounds: negative k, min_score, max_distance ────────────── + + @Test + void applyArguments_negativeKMessageCitesRange() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", 
"embedding", "[1.0, 2.0]", "k=-3"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("1")); + assertTrue(ex.getMessage().contains("10000")); + } + + @Test + void applyArguments_rejectsNegativeMinScore() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "min_score=-0.5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("min_score")); + assertTrue(ex.getMessage().contains("non-negative")); + } + + @Test + void applyArguments_rejectsNegativeMaxDistance() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "max_distance=-1.0"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("max_distance")); + assertTrue(ex.getMessage().contains("non-negative")); + } + + @Test + void applyArguments_acceptsZeroMinScore() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "min_score=0"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void applyArguments_acceptsZeroMaxDistance() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "max_distance=0"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + // ── Vector parsing: trailing / empty components (PR #5381 review) ───── + + @Test + void applyArguments_rejectsTrailingCommaInVector() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0,2.0,]", "k=5"); + ExpressionEvaluationException ex = + 
assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Invalid vector component")); + assertTrue(ex.getMessage().contains("trailing or consecutive commas")); + } + + @Test + void applyArguments_rejectsConsecutiveCommasInVector() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0,,2.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Invalid vector component")); + assertTrue(ex.getMessage().contains("trailing or consecutive commas")); + } + + @Test + void applyArguments_rejectsLeadingCommaInVector() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[,1.0,2.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Invalid vector component")); + } + + // ── Option parsing: empty segments (PR #5381 review) ───────────────── + + @Test + void parseOptions_rejectsTrailingEmptySegment() { + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchTableFunctionImplementation.parseOptions("k=5,")); + assertTrue(ex.getMessage().contains("Malformed option segment")); + assertTrue(ex.getMessage().contains("trailing or consecutive commas")); + } + + @Test + void parseOptions_rejectsLeadingEmptySegment() { + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchTableFunctionImplementation.parseOptions(",k=5")); + assertTrue(ex.getMessage().contains("Malformed option segment")); + } + + @Test + void parseOptions_rejectsConsecutiveCommas() { + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> 
VectorSearchTableFunctionImplementation.parseOptions("k=5,,filter_type=post")); + assertTrue(ex.getMessage().contains("Malformed option segment")); + } + + // ── Unknown-key error lists supported keys in stable order (PR #5381 review) ── + + @Test + void applyArguments_unknownOptionKeyErrorListsSupportedKeysInStableOrder() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=5,bogus=1"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + // Match the rendered list literal (e.g. "[k, max_distance, min_score, filter_type]") rather + // than searching for the substring "k", which would match the first "k" in "Unknown option + // key" and reduce the assertion to a tautology. + assertTrue( + ex.getMessage().contains("[k, max_distance, min_score, filter_type]"), + "expected stable key order in error; got: " + ex.getMessage()); + } + + @Test + void parseOptions_emptyStringReturnsEmptyMap() { + // The wholly empty option string is explicitly allowed through parseOptions so it flows to + // the "Missing required option" gate in validateOptions. Pins that contract. 
+ Map opts = VectorSearchTableFunctionImplementation.parseOptions(""); + assertTrue(opts.isEmpty()); + } + + private VectorSearchTableFunctionImplementation createImpl() { + return createImplWithArgs("my-index", "embedding", "[1.0, 2.0, 3.0]", "k=5"); + } + + private VectorSearchTableFunctionImplementation createImplWithArgs( + String table, String field, String vector, String option) { + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = + List.of( + DSL.namedArgument("table", DSL.literal(table)), + DSL.namedArgument("field", DSL.literal(field)), + DSL.namedArgument("vector", DSL.literal(vector)), + DSL.namedArgument("option", DSL.literal(option))); + return new VectorSearchTableFunctionImplementation( + functionName, args, client, settings, knnCapability); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java new file mode 100644 index 00000000000..c6fece7bf32 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java @@ -0,0 +1,208 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import java.util.List; +import java.util.stream.Collectors; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.common.setting.Settings; +import 
org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.FunctionBuilder; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionProperties; +import org.opensearch.sql.expression.function.FunctionSignature; +import org.opensearch.sql.expression.function.TableFunctionImplementation; +import org.opensearch.sql.opensearch.client.OpenSearchClient; + +@ExtendWith(MockitoExtension.class) +class VectorSearchTableFunctionResolverTest { + + @Mock private OpenSearchClient client; + + @Mock private Settings settings; + + @Mock private FunctionProperties functionProperties; + + @Test + void testResolve() { + VectorSearchTableFunctionResolver resolver = + new VectorSearchTableFunctionResolver(client, settings); + FunctionName functionName = FunctionName.of("vectorsearch"); + List expressions = + List.of( + DSL.namedArgument("table", DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding")), + DSL.namedArgument("vector", DSL.literal("[1.0, 2.0, 3.0]")), + DSL.namedArgument("option", DSL.literal("k=5"))); + FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + + Pair resolution = resolver.resolve(functionSignature); + + assertEquals(functionName, resolution.getKey().getFunctionName()); + assertEquals(functionName, resolver.getFunctionName()); + assertEquals(List.of(STRING, STRING, STRING, STRING), resolution.getKey().getParamTypeList()); + + TableFunctionImplementation impl = + (TableFunctionImplementation) resolution.getValue().apply(functionProperties, expressions); + assertTrue(impl instanceof VectorSearchTableFunctionImplementation); + } + + @Test + void testWrongArgumentCount() { + VectorSearchTableFunctionResolver resolver = + new 
VectorSearchTableFunctionResolver(client, settings); + FunctionName functionName = FunctionName.of("vectorsearch"); + List expressions = + List.of( + DSL.namedArgument("table", DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding"))); + FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + + Pair resolution = resolver.resolve(functionSignature); + FunctionBuilder builder = resolution.getValue(); + + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> builder.apply(functionProperties, expressions)); + assertTrue(ex.getMessage().contains("requires 4 arguments")); + } + + @Test + void testTooManyArguments() { + VectorSearchTableFunctionResolver resolver = + new VectorSearchTableFunctionResolver(client, settings); + FunctionName functionName = FunctionName.of("vectorsearch"); + List expressions = + List.of( + DSL.namedArgument("table", DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding")), + DSL.namedArgument("vector", DSL.literal("[1.0]")), + DSL.namedArgument("option", DSL.literal("k=5")), + DSL.namedArgument("extra", DSL.literal("unexpected"))); + FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + + Pair resolution = resolver.resolve(functionSignature); + FunctionBuilder builder = resolution.getValue(); + + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> builder.apply(functionProperties, expressions)); + assertTrue(ex.getMessage().contains("requires 4 arguments")); + } + + @Test + void testZeroArguments() { + VectorSearchTableFunctionResolver resolver = + new VectorSearchTableFunctionResolver(client, settings); + FunctionName functionName = FunctionName.of("vectorsearch"); + List expressions = List.of(); + 
FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + + Pair resolution = resolver.resolve(functionSignature); + FunctionBuilder builder = resolution.getValue(); + + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> builder.apply(functionProperties, expressions)); + assertTrue(ex.getMessage().contains("requires 4 arguments")); + } + + @Test + void resolve_rejectsPositionalArgument() { + VectorSearchTableFunctionResolver resolver = + new VectorSearchTableFunctionResolver(client, settings); + FunctionName functionName = FunctionName.of("vectorsearch"); + // One positional literal mixed with three named arguments. Arity passes, but the resolver + // must reject this before planning so the SQL layer returns a clean 400 rather than a 200 + // with zero rows. + List expressions = + List.of( + DSL.literal("my-index"), + DSL.namedArgument("field", DSL.literal("embedding")), + DSL.namedArgument("vector", DSL.literal("[1.0, 2.0]")), + DSL.namedArgument("option", DSL.literal("k=5"))); + FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + FunctionBuilder builder = resolver.resolve(functionSignature).getValue(); + + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> builder.apply(functionProperties, expressions)); + assertTrue(ex.getMessage().contains("requires named arguments")); + } + + @Test + void resolve_rejectsDuplicateNamedArgument() { + VectorSearchTableFunctionResolver resolver = + new VectorSearchTableFunctionResolver(client, settings); + FunctionName functionName = FunctionName.of("vectorsearch"); + List expressions = + List.of( + DSL.namedArgument("table", DSL.literal("a")), + DSL.namedArgument("table", DSL.literal("b")), + DSL.namedArgument("vector", 
DSL.literal("[1.0]")), + DSL.namedArgument("option", DSL.literal("k=5"))); + FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + FunctionBuilder builder = resolver.resolve(functionSignature).getValue(); + + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> builder.apply(functionProperties, expressions)); + assertTrue(ex.getMessage().contains("Duplicate argument name")); + assertTrue(ex.getMessage().contains("table")); + } + + @Test + void resolve_rejectsUnknownArgumentName() { + VectorSearchTableFunctionResolver resolver = + new VectorSearchTableFunctionResolver(client, settings); + FunctionName functionName = FunctionName.of("vectorsearch"); + List expressions = + List.of( + DSL.namedArgument("table", DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding")), + DSL.namedArgument("vector", DSL.literal("[1.0, 2.0]")), + DSL.namedArgument("bogus", DSL.literal("k=5"))); + FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + FunctionBuilder builder = resolver.resolve(functionSignature).getValue(); + + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> builder.apply(functionProperties, expressions)); + assertTrue(ex.getMessage().contains("Unknown argument name")); + assertTrue(ex.getMessage().contains("bogus")); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/capability/KnnPluginCapabilityTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/capability/KnnPluginCapabilityTest.java new file mode 100644 index 00000000000..147a5a093ce --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/capability/KnnPluginCapabilityTest.java @@ -0,0 +1,129 @@ +/* + * Copyright OpenSearch 
Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.capability; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Optional; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.admin.cluster.node.info.NodeInfo; +import org.opensearch.action.admin.cluster.node.info.NodesInfoRequest; +import org.opensearch.action.admin.cluster.node.info.NodesInfoResponse; +import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.plugins.PluginInfo; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.transport.client.AdminClient; +import org.opensearch.transport.client.ClusterAdminClient; +import org.opensearch.transport.client.node.NodeClient; + +@ExtendWith(MockitoExtension.class) +class KnnPluginCapabilityTest { + + @Mock private OpenSearchClient client; + @Mock private NodeClient nodeClient; + @Mock private AdminClient adminClient; + @Mock private ClusterAdminClient clusterAdminClient; + @Mock private ActionFuture nodesInfoFuture; + + @Test + void skipsWhenNodeClientAbsent() { + when(client.getNodeClient()).thenReturn(Optional.empty()); + KnnPluginCapability capability = new KnnPluginCapability(client); + // No exception — REST-client mode cannot probe; execution-time errors remain the signal. 
+ assertDoesNotThrow(capability::requireInstalled); + } + + @Test + void passesWhenKnnPluginInstalled() { + stubNodesInfo(pluginInfo("org.opensearch.knn.plugin.KNNPlugin")); + KnnPluginCapability capability = new KnnPluginCapability(client); + assertDoesNotThrow(capability::requireInstalled); + } + + @Test + void throwsWhenKnnPluginAbsent() { + stubNodesInfo(pluginInfo("org.opensearch.security.OpenSearchSecurityPlugin")); + KnnPluginCapability capability = new KnnPluginCapability(client); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, capability::requireInstalled); + assertTrue( + ex.getMessage().contains("k-NN plugin"), + "Expected k-NN plugin message, got: " + ex.getMessage()); + assertTrue( + ex.getMessage().contains("not installed"), + "Expected 'not installed' phrasing, got: " + ex.getMessage()); + } + + @Test + void cachesSuccessfulProbeResult() { + stubNodesInfo(pluginInfo("org.opensearch.knn.plugin.KNNPlugin")); + KnnPluginCapability capability = new KnnPluginCapability(client); + capability.requireInstalled(); + capability.requireInstalled(); + capability.requireInstalled(); + // Probe fires once regardless of how many times requireInstalled() is called. 
+ verify(clusterAdminClient, times(1)).nodesInfo(any(NodesInfoRequest.class)); + } + + @Test + void cachesNegativeProbeResult() { + stubNodesInfo(pluginInfo("org.opensearch.security.OpenSearchSecurityPlugin")); + KnnPluginCapability capability = new KnnPluginCapability(client); + assertThrows(ExpressionEvaluationException.class, capability::requireInstalled); + assertThrows(ExpressionEvaluationException.class, capability::requireInstalled); + verify(clusterAdminClient, times(1)).nodesInfo(any(NodesInfoRequest.class)); + } + + @Test + void doesNotCacheOnProbeFailure() { + when(client.getNodeClient()).thenReturn(Optional.of(nodeClient)); + when(nodeClient.admin()).thenReturn(adminClient); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + when(clusterAdminClient.nodesInfo(any(NodesInfoRequest.class))).thenReturn(nodesInfoFuture); + when(nodesInfoFuture.actionGet()).thenThrow(new RuntimeException("transport error")); + + KnnPluginCapability capability = new KnnPluginCapability(client); + assertDoesNotThrow(capability::requireInstalled); // probe failed — treat as unknown + assertDoesNotThrow(capability::requireInstalled); + // Probe retries on each call after a failure — failures are not cached. + verify(clusterAdminClient, times(2)).nodesInfo(any(NodesInfoRequest.class)); + } + + private void stubNodesInfo(PluginInfo... 
plugins) { + when(client.getNodeClient()).thenReturn(Optional.of(nodeClient)); + when(nodeClient.admin()).thenReturn(adminClient); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + when(clusterAdminClient.nodesInfo(any(NodesInfoRequest.class))).thenReturn(nodesInfoFuture); + + NodeInfo nodeInfo = mock(NodeInfo.class); + PluginsAndModules pam = mock(PluginsAndModules.class); + when(nodeInfo.getInfo(PluginsAndModules.class)).thenReturn(pam); + when(pam.getPluginInfos()).thenReturn(List.of(plugins)); + + NodesInfoResponse response = mock(NodesInfoResponse.class); + when(response.getNodes()).thenReturn(List.of(nodeInfo)); + when(nodesInfoFuture.actionGet()).thenReturn(response); + } + + private PluginInfo pluginInfo(String classname) { + PluginInfo pluginInfo = mock(PluginInfo.class); + when(pluginInfo.getClassname()).thenReturn(classname); + return pluginInfo; + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchIndexScanBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchIndexScanBuilderTest.java new file mode 100644 index 00000000000..ce2f2efb824 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchIndexScanBuilderTest.java @@ -0,0 +1,234 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +import com.google.common.collect.ImmutableList; +import java.util.Collections; +import org.junit.jupiter.api.Test; +import org.opensearch.index.query.WrapperQueryBuilder; +import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.common.setting.Settings; +import 
org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.NamedExpression; +import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalSort; +import org.opensearch.sql.planner.logical.LogicalValues; + +class VectorSearchIndexScanBuilderTest { + + private VectorSearchIndexScanBuilder newScanBuilder() { + var requestBuilder = + new OpenSearchRequestBuilder( + mock(OpenSearchExprValueFactory.class), 10000, mock(Settings.class)); + var queryBuilder = + new VectorSearchQueryBuilder( + requestBuilder, new WrapperQueryBuilder("{\"knn\":{}}"), java.util.Map.of("k", "5")); + return new VectorSearchIndexScanBuilder(queryBuilder, rb -> mock(OpenSearchIndexScan.class)); + } + + private static LogicalProject project(LogicalPlan input) { + NamedExpression field = DSL.named("id", DSL.ref("id", ExprCoreType.STRING)); + return new LogicalProject(input, ImmutableList.of(field), ImmutableList.of()); + } + + private static LogicalFilter filter(LogicalPlan input) { + return new LogicalFilter( + input, DSL.less(DSL.ref("price", ExprCoreType.INTEGER), DSL.literal(150))); + } + + private static LogicalSort sort(LogicalPlan input) { + return new LogicalSort( + input, + ImmutableList.of( + org.apache.commons.lang3.tuple.Pair.of( + Sort.SortOption.DEFAULT_DESC, DSL.ref("price", ExprCoreType.INTEGER)))); + } + + private static LogicalLimit limit(LogicalPlan input, int offset) { + return new LogicalLimit(input, 10, offset); + } + + 
private static LogicalAggregation aggregation(LogicalPlan input) { + return new LogicalAggregation(input, Collections.emptyList(), Collections.emptyList(), false); + } + + @Test + void pushDownAggregationIsRejected() { + var scanBuilder = newScanBuilder(); + + var agg = + new LogicalAggregation( + new LogicalValues(Collections.emptyList()), + Collections.emptyList(), + Collections.emptyList(), + false); + + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, () -> scanBuilder.pushDownAggregation(agg)); + assertTrue( + ex.getMessage().contains("Aggregations are not supported"), + "Error should state aggregations are not supported; actual: " + ex.getMessage()); + assertTrue( + ex.getMessage().contains("vectorSearch"), + "Error should mention vectorSearch; actual: " + ex.getMessage()); + } + + @Test + void validatePlanRejectsOuterFilterOverSubqueryProject() { + // Models: SELECT * FROM (SELECT v.id FROM vs(...) AS v) t WHERE t.price < 150 + // Shape after optimizer: Project(outer) → Filter → Project(inner) → scanBuilder + var scanBuilder = newScanBuilder(); + LogicalPlan root = project(filter(project(scanBuilder))); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> scanBuilder.validatePlan(root)); + assertTrue( + ex.getMessage().contains("Outer WHERE on a vectorSearch() subquery"), + "Error should mention outer WHERE on subquery; actual: " + ex.getMessage()); + assertTrue( + ex.getMessage().contains("silently yield zero rows"), + "Error should explain silent zero rows; actual: " + ex.getMessage()); + } + + @Test + void validatePlanRejectsDoubleWrappedOuterFilter() { + // Models nested subqueries: + // SELECT * FROM (SELECT * FROM (SELECT v.id FROM vs(...) 
AS v) t1) t2 WHERE t2.price < 150 + var scanBuilder = newScanBuilder(); + LogicalPlan root = filter(project(project(scanBuilder))); + + assertThrows(ExpressionEvaluationException.class, () -> scanBuilder.validatePlan(root)); + } + + @Test + void validatePlanAllowsFilterDirectlyAboveScanBuilder() { + // Models: SELECT v.id FROM vs(...) AS v WHERE v.gender='M' + // Here the filter would normally be pushed down and removed, but if it were kept (e.g. a + // non-pushdownable predicate), validatePlan must not reject it — it is already at the + // vectorSearch level, not an outer filter. + var scanBuilder = newScanBuilder(); + LogicalPlan root = project(filter(scanBuilder)); + + assertDoesNotThrow(() -> scanBuilder.validatePlan(root)); + } + + @Test + void validatePlanAllowsInnerFilterWrappedInOuterProject() { + // Models: SELECT * FROM (SELECT v.id FROM vs(...) AS v WHERE v.gender='M') t + // After pushdown the inner filter may remain when non-pushdownable; importantly, there is no + // outer filter — only outer projects wrapping an inner filter directly on scanBuilder. + var scanBuilder = newScanBuilder(); + LogicalPlan root = project(project(filter(scanBuilder))); + + assertDoesNotThrow(() -> scanBuilder.validatePlan(root)); + } + + @Test + void validatePlanRejectsFilterProjectFilterShape() { + // Models: SELECT * FROM (SELECT v.id FROM vs(...) AS v WHERE v.gender='M') t + // WHERE t.price < 150 + // Shape: Filter(outer) → Project(subquery) → Filter(inner) → scanBuilder + // The outer filter is still separated from the scan by the subquery Project; the inner + // filter sitting between the Project and the scan does not erase that boundary. Without + // preserving the project marker across the inner filter, the walker would miss this shape. 
+ var scanBuilder = newScanBuilder(); + LogicalPlan root = filter(project(filter(scanBuilder))); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> scanBuilder.validatePlan(root)); + assertTrue( + ex.getMessage().contains("Outer WHERE on a vectorSearch() subquery"), + "Error should mention outer WHERE on subquery; actual: " + ex.getMessage()); + } + + @Test + void validatePlanAllowsNoFilterAtAll() { + // Baseline: no WHERE anywhere. SELECT * FROM (SELECT v.id FROM vs(...) AS v) t + var scanBuilder = newScanBuilder(); + LogicalPlan root = project(project(scanBuilder)); + + assertDoesNotThrow(() -> scanBuilder.validatePlan(root)); + } + + @Test + void validatePlanAllowsBareScanBuilder() { + // Defensive: a plan that is just the scan builder itself. + var scanBuilder = newScanBuilder(); + + assertDoesNotThrow(() -> scanBuilder.validatePlan(scanBuilder)); + } + + @Test + void validatePlanRejectsOuterSortOverSubqueryProject() { + // Models: SELECT * FROM (SELECT v.id FROM vs(...) AS v) t ORDER BY t.price + // Shape: Sort(outer) → Project(subquery) → scanBuilder + // Outer ORDER BY would be applied only after top-k ANN results, producing an order the user + // did not ask for (vector distance ordering leaks through when rows are fewer than expected). + var scanBuilder = newScanBuilder(); + LogicalPlan root = sort(project(scanBuilder)); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> scanBuilder.validatePlan(root)); + assertTrue( + ex.getMessage().contains("Outer ORDER BY on a vectorSearch() subquery"), + "Error should mention outer ORDER BY on subquery; actual: " + ex.getMessage()); + } + + @Test + void validatePlanRejectsOuterOffsetOverSubqueryProject() { + // Models: SELECT * FROM (SELECT v.id FROM vs(...) 
AS v) t LIMIT 10 OFFSET 5 + // Outer OFFSET silently skips the top-N nearest rows chosen by ANN, so the remaining rows + // would be a truncated tail of the k-NN result set rather than the user's intended window. + var scanBuilder = newScanBuilder(); + LogicalPlan root = limit(project(scanBuilder), 5); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> scanBuilder.validatePlan(root)); + assertTrue( + ex.getMessage().contains("Outer OFFSET on a vectorSearch() subquery"), + "Error should mention outer OFFSET on subquery; actual: " + ex.getMessage()); + } + + @Test + void validatePlanAllowsOuterLimitWithoutOffsetOverSubquery() { + // Outer LIMIT with offset=0 just caps row count and is safe over a subquery — reject only + // non-zero OFFSET. Locks in the offset==0 boundary of the guard. + var scanBuilder = newScanBuilder(); + LogicalPlan root = limit(project(scanBuilder), 0); + + assertDoesNotThrow(() -> scanBuilder.validatePlan(root)); + } + + @Test + void validatePlanRejectsOuterAggregationOverSubqueryProject() { + // Models: SELECT COUNT(*) FROM (SELECT v.id FROM vs(...) AS v) t + // (Or outer GROUP BY / DISTINCT, both of which rewrite to LogicalAggregation.) The outer + // aggregation would run on a truncated top-k slice rather than a meaningful population, + // masking the fact that aggregations are not supported on vectorSearch() in this preview. 
+ var scanBuilder = newScanBuilder(); + LogicalPlan root = aggregation(project(scanBuilder)); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> scanBuilder.validatePlan(root)); + assertTrue( + ex.getMessage().contains("Outer GROUP BY / aggregation / DISTINCT on a vectorSearch()"), + "Error should mention outer aggregation on subquery; actual: " + ex.getMessage()); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchIndexScanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchIndexScanTest.java new file mode 100644 index 00000000000..3fa2adec88a --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchIndexScanTest.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +import org.junit.jupiter.api.Test; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.storage.capability.KnnPluginCapability; + +class VectorSearchIndexScanTest { + + @Test + void openProbesKnnPluginBeforeFetch() { + OpenSearchClient client = mock(OpenSearchClient.class); + OpenSearchRequest request = mock(OpenSearchRequest.class); + KnnPluginCapability capability = mock(KnnPluginCapability.class); + doThrow(new ExpressionEvaluationException("k-NN plugin missing")) + .when(capability) + .requireInstalled(); + + VectorSearchIndexScan scan = new 
VectorSearchIndexScan(client, 10, request, capability); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, scan::open); + assertTrue(ex.getMessage().contains("k-NN plugin")); + // Capability threw, so the underlying client must not have been touched for this scan. + verify(client, never()).search(request); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java new file mode 100644 index 00000000000..b02d680af15 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java @@ -0,0 +1,857 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import org.apache.lucene.search.join.ScoreMode; +import org.junit.jupiter.api.Test; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.WrapperQueryBuilder; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.expression.DSL; +import 
org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.storage.FilterType; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalValues; + +class VectorSearchQueryBuilderTest { + + @Test + void knnQuerySetAsScoringQuery() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + + new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + QueryBuilder query = requestBuilder.getSourceBuilder().query(); + assertTrue( + query instanceof WrapperQueryBuilder, + "knn query should be set directly as top-level query (scoring context)"); + } + + @Test + void pushDownFilterKeepsKnnInScoringContext() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + // Simulate WHERE name = 'John' + var condition = DSL.equal(new ReferenceExpression("name", STRING), DSL.literal("John")); + var dummyChild = new LogicalValues(Collections.emptyList()); + var filter = new LogicalFilter(dummyChild, condition); + + boolean pushed = builder.pushDownFilter(filter); + + assertTrue(pushed, "pushDownFilter should succeed"); + QueryBuilder resultQuery = requestBuilder.getSourceBuilder().query(); + assertTrue(resultQuery instanceof BoolQueryBuilder, "Result should be a BoolQuery"); + BoolQueryBuilder boolQuery = (BoolQueryBuilder) resultQuery; + assertEquals(1, boolQuery.must().size(), "knn query should be in must (scoring context)"); + assertEquals(1, boolQuery.filter().size(), "WHERE predicate should be in filter (non-scoring)"); + assertTrue( + boolQuery.must().get(0) instanceof 
WrapperQueryBuilder, + "must clause should contain the original knn WrapperQueryBuilder"); + } + + @Test + void pushDownLimitWithinKSucceeds() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var limit = new LogicalLimit(dummyChild, 3, 0); + + boolean pushed = builder.pushDownLimit(limit); + assertTrue(pushed, "LIMIT within k should succeed"); + } + + @Test + void pushDownLimitExceedingKThrows() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var limit = new LogicalLimit(dummyChild, 10, 0); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownLimit(limit)); + assertTrue(ex.getMessage().contains("LIMIT 10 exceeds k=5")); + } + + @Test + void pushDownLimitEqualToKSucceeds() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var limit = new LogicalLimit(dummyChild, 5, 0); + + boolean pushed = builder.pushDownLimit(limit); + assertTrue(pushed, "LIMIT equal to k should succeed"); + } + + @Test + void pushDownLimitRadialModeNoRestriction() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("max_distance", "10.0")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var limit = new LogicalLimit(dummyChild, 100, 0); + + boolean 
pushed = builder.pushDownLimit(limit); + assertTrue(pushed, "Radial mode should not restrict LIMIT"); + } + + @Test + void pushDownLimitMinScoreModeNoRestriction() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("min_score", "0.5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var limit = new LogicalLimit(dummyChild, 100, 0); + + boolean pushed = builder.pushDownLimit(limit); + assertTrue(pushed, "min_score mode should not restrict LIMIT"); + } + + @Test + void pushDownSortScoreDescAccepted() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC, + new ReferenceExpression("_score", ExprCoreType.FLOAT)))); + + boolean pushed = builder.pushDownSort(sort); + assertTrue(pushed, "ORDER BY _score DESC should be accepted"); + } + + @Test + void pushDownSortPreservesSortCountAsLimit() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "10")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + // LogicalSort with count=7 simulates a sort+limit combined node (PPL path) + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + 7, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC, + new ReferenceExpression("_score", ExprCoreType.FLOAT)))); + + boolean pushed = 
builder.pushDownSort(sort); + assertTrue(pushed, "ORDER BY _score DESC with count should be accepted"); + assertEquals( + 7, + requestBuilder.getMaxResponseSize(), + "sort.getCount() should be pushed down as request size"); + } + + @Test + void pushDownSortCountExceedingKRejects() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + // LogicalSort with count=10 exceeds k=5 — should be rejected + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + 10, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC, + new ReferenceExpression("_score", ExprCoreType.FLOAT)))); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownSort(sort)); + assertTrue(ex.getMessage().contains("LIMIT 10 exceeds k=5")); + } + + @Test + void pushDownSortNonScoreFieldRejected() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC, + new ReferenceExpression("name", STRING)))); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownSort(sort)); + assertTrue(ex.getMessage().contains("unsupported sort expression")); + } + + @Test + void pushDownSortMultipleExpressionsRejectsNonScore() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new 
WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC, + new ReferenceExpression("_score", ExprCoreType.FLOAT)), + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC, + new ReferenceExpression("name", STRING)))); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownSort(sort)); + assertTrue(ex.getMessage().contains("unsupported sort expression")); + } + + @Test + void pushDownSortScoreAscRejected() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC, + new ReferenceExpression("_score", ExprCoreType.FLOAT)))); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownSort(sort)); + assertTrue(ex.getMessage().contains("_score ASC is not supported")); + } + + @Test + void pushDownFilterCompoundPredicateSurvives() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + // Simulate WHERE name = 'John' AND age > 30 + var condition = + DSL.and( + DSL.equal(new ReferenceExpression("name", STRING), DSL.literal("John")), + 
DSL.greater(new ReferenceExpression("age", ExprCoreType.INTEGER), DSL.literal(30))); + var dummyChild = new LogicalValues(Collections.emptyList()); + var filter = new LogicalFilter(dummyChild, condition); + + boolean pushed = builder.pushDownFilter(filter); + + assertTrue(pushed, "pushDownFilter with compound predicate should succeed"); + QueryBuilder resultQuery = requestBuilder.getSourceBuilder().query(); + assertTrue(resultQuery instanceof BoolQueryBuilder, "Result should be a BoolQuery"); + BoolQueryBuilder boolQuery = (BoolQueryBuilder) resultQuery; + assertEquals(1, boolQuery.must().size(), "knn query should be in must (scoring context)"); + assertEquals(1, boolQuery.filter().size(), "compound WHERE should be in filter (non-scoring)"); + } + + @Test + void pushDownFilterEfficientPlacesInsideKnn() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + // Callback simulates VectorSearchIndex rebuilding knn with filter + Function rebuildWithFilter = + whereQuery -> new WrapperQueryBuilder("{\"knn\":{\"filter\":\"embedded\"}}"); + var builder = + new VectorSearchQueryBuilder( + requestBuilder, + knnQuery, + Map.of("k", "5"), + FilterType.EFFICIENT, + true, + rebuildWithFilter); + + var condition = DSL.equal(new ReferenceExpression("city", STRING), DSL.literal("Miami")); + var dummyChild = new LogicalValues(Collections.emptyList()); + var filter = new LogicalFilter(dummyChild, condition); + + boolean pushed = builder.pushDownFilter(filter); + + assertTrue(pushed, "pushDownFilter should succeed"); + QueryBuilder resultQuery = requestBuilder.getSourceBuilder().query(); + assertTrue( + resultQuery instanceof WrapperQueryBuilder, + "Efficient filter should produce a WrapperQueryBuilder (rebuilt knn), not BoolQuery"); + } + + @Test + void pushDownFilterExplicitPostProducesBool() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new 
VectorSearchQueryBuilder( + requestBuilder, knnQuery, Map.of("k", "5"), FilterType.POST, true, null); + + var condition = DSL.equal(new ReferenceExpression("name", STRING), DSL.literal("John")); + var dummyChild = new LogicalValues(Collections.emptyList()); + var filter = new LogicalFilter(dummyChild, condition); + + boolean pushed = builder.pushDownFilter(filter); + + assertTrue(pushed); + QueryBuilder resultQuery = requestBuilder.getSourceBuilder().query(); + assertTrue(resultQuery instanceof BoolQueryBuilder); + BoolQueryBuilder boolQuery = (BoolQueryBuilder) resultQuery; + assertEquals(1, boolQuery.must().size()); + assertEquals(1, boolQuery.filter().size()); + } + + // ── Constructor validation ────────────────────────────────────────── + + @Test + void constructorRejectsEfficientModeWithNullCallback() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + + assertThrows( + IllegalArgumentException.class, + () -> + new VectorSearchQueryBuilder( + requestBuilder, knnQuery, Map.of("k", "5"), FilterType.EFFICIENT, true, null)); + } + + // ── Build-time validation ──────────────────────────────────────────── + + @Test + void buildRejectsExplicitFilterTypePostWithoutWhere() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder( + requestBuilder, knnQuery, Map.of("k", "5"), FilterType.POST, true, null); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, builder::build); + assertTrue(ex.getMessage().contains("filter_type requires a pushdownable WHERE clause")); + } + + @Test + void buildRejectsExplicitFilterTypeEfficientWithoutWhere() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + Function rebuildWithFilter = + whereQuery -> new WrapperQueryBuilder("{\"knn\":{\"filter\":\"embedded\"}}"); + var builder = + new 
VectorSearchQueryBuilder( + requestBuilder, + knnQuery, + Map.of("k", "5"), + FilterType.EFFICIENT, + true, + rebuildWithFilter); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, builder::build); + assertTrue(ex.getMessage().contains("filter_type requires a pushdownable WHERE clause")); + } + + @Test + void buildSucceedsWithNoFilterTypeAndNoWhere() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + OpenSearchRequestBuilder result = builder.build(); + assertNotNull(result); + } + + @Test + void buildSucceedsWithFilterTypeAndPushedWhere() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder( + requestBuilder, knnQuery, Map.of("k", "5"), FilterType.POST, true, null); + + var condition = DSL.equal(new ReferenceExpression("name", STRING), DSL.literal("John")); + var dummyChild = new LogicalValues(Collections.emptyList()); + builder.pushDownFilter(new LogicalFilter(dummyChild, condition)); + + OpenSearchRequestBuilder result = builder.build(); + assertNotNull(result); + } + + // ── Radial without LIMIT rejection ───────────────────────────────── + + @Test + void buildRejectsRadialMaxDistanceWithoutLimit() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("max_distance", "10.0")); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, builder::build); + assertTrue(ex.getMessage().contains("LIMIT is required for radial vector search")); + } + + @Test + void buildRejectsRadialMinScoreWithoutLimit() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder 
= + new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("min_score", "0.5")); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, builder::build); + assertTrue(ex.getMessage().contains("LIMIT is required for radial vector search")); + } + + @Test + void buildSucceedsRadialWithLimit() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("max_distance", "10.0")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + builder.pushDownLimit(new LogicalLimit(dummyChild, 50, 0)); + + OpenSearchRequestBuilder result = builder.build(); + assertNotNull(result); + } + + @Test + void buildSucceedsTopKWithoutLimit() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + OpenSearchRequestBuilder result = builder.build(); + assertNotNull(result); + } + + // ── Regression: LIMIT and sort invariants under efficient mode ────── + + @Test + void pushDownLimitExceedingKThrowsUnderEfficientMode() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + Function rebuildWithFilter = + whereQuery -> new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder( + requestBuilder, + knnQuery, + Map.of("k", "5"), + FilterType.EFFICIENT, + true, + rebuildWithFilter); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var limit = new LogicalLimit(dummyChild, 10, 0); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownLimit(limit)); + assertTrue(ex.getMessage().contains("LIMIT 10 exceeds k=5")); + } + + @Test + void pushDownSortScoreDescAcceptedUnderEfficientMode() { + var requestBuilder = 
createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + Function rebuildWithFilter = + whereQuery -> new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder( + requestBuilder, + knnQuery, + Map.of("k", "5"), + FilterType.EFFICIENT, + true, + rebuildWithFilter); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC, + new ReferenceExpression("_score", ExprCoreType.FLOAT)))); + + boolean pushed = builder.pushDownSort(sort); + assertTrue(pushed, "ORDER BY _score DESC should be accepted under efficient mode"); + } + + @Test + void pushDownSortNonScoreRejectedUnderEfficientMode() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + Function rebuildWithFilter = + whereQuery -> new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder( + requestBuilder, + knnQuery, + Map.of("k", "5"), + FilterType.EFFICIENT, + true, + rebuildWithFilter); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC, + new ReferenceExpression("name", STRING)))); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownSort(sort)); + assertTrue(ex.getMessage().contains("unsupported sort expression")); + } + + // ── Non-pushdownable filter handling ────────────────────────────────── + + @Test + void pushDownFilterNonPushdownableWithExplicitFilterTypeThrows() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new 
VectorSearchQueryBuilder( + requestBuilder, knnQuery, Map.of("k", "5"), FilterType.POST, true, null); + + // STRUCT = STRUCT triggers ScriptQueryUnSupportedException in FilterQueryBuilder + var condition = + DSL.equal( + new ReferenceExpression("nested_field", ExprCoreType.STRUCT), + new ReferenceExpression("other_field", ExprCoreType.STRUCT)); + var dummyChild = new LogicalValues(Collections.emptyList()); + var filter = new LogicalFilter(dummyChild, condition); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownFilter(filter)); + assertTrue( + ex.getMessage().contains("filter_type only works when the WHERE clause can be translated")); + assertTrue(ex.getMessage().contains("Rewrite the WHERE clause or omit filter_type")); + } + + @Test + void pushDownFilterNonPushdownableWithoutExplicitFilterTypeFallsBack() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + // STRUCT = STRUCT triggers ScriptQueryUnSupportedException in FilterQueryBuilder + var condition = + DSL.equal( + new ReferenceExpression("nested_field", ExprCoreType.STRUCT), + new ReferenceExpression("other_field", ExprCoreType.STRUCT)); + var dummyChild = new LogicalValues(Collections.emptyList()); + var filter = new LogicalFilter(dummyChild, condition); + + boolean pushed = builder.pushDownFilter(filter); + assertFalse(pushed, "Non-pushdownable filter should return false for in-memory fallback"); + } + + // ── OFFSET rejection ──────────────────────────────────────────────── + + @Test + void pushDownLimit_rejectsNonZeroOffset() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + // LIMIT 3 
OFFSET 2: the planner passes both through LogicalLimit + var limit = new LogicalLimit(dummyChild, 3, 2); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownLimit(limit)); + assertTrue( + ex.getMessage().contains("OFFSET is not supported on vectorSearch()"), + "Expected OFFSET rejection message, got: " + ex.getMessage()); + assertTrue( + ex.getMessage().contains("LIMIT only"), + "Expected remediation guidance in message, got: " + ex.getMessage()); + } + + @Test + void pushDownLimit_acceptsZeroOffset() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var limit = new LogicalLimit(dummyChild, 3, 0); + + // Zero offset is the normal case; must continue to succeed. + assertTrue(builder.pushDownLimit(limit)); + } + + // ── WHERE on _score rejection ──────────────────────────────────────── + + @Test + void pushDownFilter_rejectsScoreReferenceInWhere() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + // WHERE _score > 0.5 (note: _score is a synthetic column, not a stored field) + var condition = + DSL.greater(new ReferenceExpression("_score", ExprCoreType.FLOAT), DSL.literal(0.5)); + var dummyChild = new LogicalValues(Collections.emptyList()); + var filter = new LogicalFilter(dummyChild, condition); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownFilter(filter)); + assertTrue( + ex.getMessage().contains("WHERE on _score is not supported"), + "Expected _score rejection message, got: " + ex.getMessage()); + assertTrue( + ex.getMessage().contains("min_score"), + "Expected remediation guidance 
pointing at option='min_score=...', got: " + + ex.getMessage()); + } + + @Test + void pushDownFilter_rejectsScoreReferenceInsideCompound() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + // WHERE state = 'TX' AND _score > 0.5: rejection must walk compound predicates + var condition = + DSL.and( + DSL.equal(new ReferenceExpression("state", STRING), DSL.literal("TX")), + DSL.greater(new ReferenceExpression("_score", ExprCoreType.FLOAT), DSL.literal(0.5))); + var dummyChild = new LogicalValues(Collections.emptyList()); + var filter = new LogicalFilter(dummyChild, condition); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownFilter(filter)); + assertTrue( + ex.getMessage().contains("WHERE on _score is not supported"), + "Expected _score rejection message, got: " + ex.getMessage()); + } + + @Test + void pushDownFilter_rejectsUppercaseScoreReference() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + // WHERE _SCORE > 0.5 must be rejected the same way as _score; the check is case-insensitive + // so variants that preserve original casing cannot bypass the guard. 
+ var condition = + DSL.greater(new ReferenceExpression("_SCORE", ExprCoreType.FLOAT), DSL.literal(0.5)); + var dummyChild = new LogicalValues(Collections.emptyList()); + var filter = new LogicalFilter(dummyChild, condition); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownFilter(filter)); + assertTrue( + ex.getMessage().contains("WHERE on _score is not supported"), + "Expected _score rejection message, got: " + ex.getMessage()); + } + + // ── filter_type=efficient rejects script subtrees ─────────────────── + + @Test + void pushDownFilter_efficient_rejectsScriptSubtree() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + Function rebuildWithFilter = + whereQuery -> new WrapperQueryBuilder("{\"knn\":{\"filter\":\"embedded\"}}"); + var builder = + new VectorSearchQueryBuilder( + requestBuilder, + knnQuery, + Map.of("k", "5"), + FilterType.EFFICIENT, + true, + rebuildWithFilter); + + // price + 1 > 100 lowers to a ScriptQueryBuilder; embedding it under knn.filter would + // trigger the AOSS rejection this PR guards against. 
+ var condition = + DSL.greater( + DSL.add(new ReferenceExpression("price", ExprCoreType.INTEGER), DSL.literal(1)), + DSL.literal(100)); + var dummyChild = new LogicalValues(Collections.emptyList()); + var filter = new LogicalFilter(dummyChild, condition); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownFilter(filter)); + assertTrue( + ex.getMessage().contains("vectorSearch WHERE pre-filtering does not support"), + "Expected script rejection message, got: " + ex.getMessage()); + assertTrue( + ex.getMessage().contains("script queries"), + "Expected script queries guidance in message, got: " + ex.getMessage()); + assertTrue( + ex.getMessage().contains("filter_type=post"), + "Expected filter_type=post fallback guidance, got: " + ex.getMessage()); + } + + @Test + void pushDownFilter_post_allowsScriptSubtree() { + // POST puts WHERE in an outer bool.filter, not under knn.filter, so scripts are fine. + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var condition = + DSL.greater( + DSL.add(new ReferenceExpression("price", ExprCoreType.INTEGER), DSL.literal(1)), + DSL.literal(100)); + var dummyChild = new LogicalValues(Collections.emptyList()); + var filter = new LogicalFilter(dummyChild, condition); + + assertTrue(builder.pushDownFilter(filter), "POST mode must still accept script predicates"); + } + + @Test + void buildSucceedsRadialWithSortEmbeddedLimit() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("max_distance", "10.0")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + // LogicalSort with count=50 simulates PPL sort-with-limit path + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( 
+ dummyChild, + 50, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC, + new ReferenceExpression("_score", ExprCoreType.FLOAT)))); + + builder.pushDownSort(sort); + + // build() should not reject — limitPushed must be true via pushDownSort's count path + OpenSearchRequestBuilder result = builder.build(); + assertNotNull(result); + } + + // ── filter_type=efficient allow-list validator ────────────────────── + + @Test + void validateEfficientFilterSafe_rejectsNestedQuery() { + // FilterQueryBuilder emits NestedQueryBuilder for SQL nested(field, pred); nested vector + // semantics are outside the P0 preview so rejection must be targeted, not generic. + QueryBuilder nested = + QueryBuilders.nestedQuery( + "parent", QueryBuilders.termQuery("parent.f", "v"), ScoreMode.None); + + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchQueryBuilder.validateEfficientFilterSafe(nested)); + assertTrue( + ex.getMessage().contains("vectorSearch WHERE pre-filtering does not support nested"), + "Expected targeted nested rejection, got: " + ex.getMessage()); + } + + @Test + void validateEfficientFilterSafe_rejectsNestedBuriedInBool() { + // AND-ing nested() with a term must still be caught; otherwise the guard is trivially bypassed. 
+ QueryBuilder tree = + QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("state", "CA")) + .filter( + QueryBuilders.nestedQuery( + "parent", QueryBuilders.termQuery("parent.f", "v"), ScoreMode.None)); + + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchQueryBuilder.validateEfficientFilterSafe(tree)); + assertTrue(ex.getMessage().contains("nested predicates")); + } + + @Test + void validateEfficientFilterSafe_acceptsBoolOfSafeLeaves() { + QueryBuilder tree = + QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("category", "shoes")) + .filter(QueryBuilders.rangeQuery("price").gte(80).lte(150)); + + VectorSearchQueryBuilder.validateEfficientFilterSafe(tree); + } + + @Test + void validateEfficientFilterSafe_acceptsExistsLeaf() { + // IS NOT NULL lowers to ExistsQueryBuilder; locks in allow-list coverage for that path. + QueryBuilder exists = QueryBuilders.existsQuery("brand"); + + VectorSearchQueryBuilder.validateEfficientFilterSafe(exists); + } + + @Test + void validateEfficientFilterSafe_rejectsUnknownWrapper() { + // Unknown shapes must fail closed so future FilterQueryBuilder additions cannot silently + // re-introduce the AOSS-rejection bug class this PR is guarding against. 
+ QueryBuilder unknown = new WrapperQueryBuilder("{\"term\":{\"f\":\"v\"}}"); + + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchQueryBuilder.validateEfficientFilterSafe(unknown)); + assertTrue( + ex.getMessage().contains("unsupported filter query shape"), + "Expected unknown-shape rejection, got: " + ex.getMessage()); + assertTrue( + ex.getMessage().contains("WrapperQueryBuilder"), + "Expected class name in message, got: " + ex.getMessage()); + } + + private OpenSearchRequestBuilder createRequestBuilder() { + return new OpenSearchRequestBuilder( + mock(OpenSearchExprValueFactory.class), 10000, mock(Settings.class)); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java index 310bb5e73c5..e930056474a 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java @@ -174,20 +174,80 @@ void should_build_wildcard_query_for_like_expression() { } @Test - void should_build_script_query_for_unsupported_lucene_query() { + void should_build_exists_query_for_is_not_null() { + assertJsonEquals( + "{\n" + + " \"exists\" : {\n" + + " \"field\" : \"age\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + "}", + buildQuery(DSL.isnotnull(ref("age", INTEGER)))); + } + + @Test + void should_build_must_not_exists_query_for_is_null() { + assertJsonEquals( + "{\n" + + " \"bool\" : {\n" + + " \"must_not\" : [\n" + + " {\n" + + " \"exists\" : {\n" + + " \"field\" : \"age\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"adjust_pure_negative\" : true,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + "}", + buildQuery(DSL.is_null(ref("age", INTEGER)))); + } + + @Test + void 
should_fallback_to_script_for_nested_is_not_null() { + // Nested IS_NOT_NULL must NOT route through NestedQuery.buildNested(): that path reads + // arg[1] and unary IS_NOT_NULL only has arg[0]. ExistsQuery.isNestedPredicate() returns + // false precisely to force the script fallback here. mockToStringSerializer(); assertJsonEquals( "{\n" + " \"script\" : {\n" + " \"script\" : {\n" - + " \"source\" : \"{\\\"langType\\\":\\\"v2\\\",\\\"script\\\":\\\"is not" - + " null(age)\\\"}\",\n" + + " \"source\" :" + + " \"{\\\"langType\\\":\\\"v2\\\",\\\"script\\\":\\\"is" + + " not null(FunctionExpression(functionName=nested, arguments=[message.info," + + " message]))\\\"}\",\n" + " \"lang\" : \"opensearch_compounded_script\"\n" + " },\n" + " \"boost\" : 1.0\n" + " }\n" + "}", - buildQuery(DSL.isnotnull(ref("age", INTEGER)))); + buildQuery( + DSL.isnotnull( + DSL.nested(DSL.ref("message.info", STRING), DSL.ref("message", STRING))))); + } + + @Test + void should_fallback_to_script_for_nested_is_null() { + // Symmetric to the IS_NOT_NULL case: must not crash with an arg[1] lookup via NestedQuery. + mockToStringSerializer(); + assertJsonEquals( + "{\n" + + " \"script\" : {\n" + + " \"script\" : {\n" + + " \"source\" :" + + " \"{\\\"langType\\\":\\\"v2\\\",\\\"script\\\":\\\"is" + + " null(FunctionExpression(functionName=nested, arguments=[message.info," + + " message]))\\\"}\",\n" + + " \"lang\" : \"opensearch_compounded_script\"\n" + + " },\n" + + " \"boost\" : 1.0\n" + + " }\n" + + "}", + buildQuery( + DSL.is_null(DSL.nested(DSL.ref("message.info", STRING), DSL.ref("message", STRING))))); } @Test diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 5f7361160b3..6b34507eacc 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -109,8 +109,18 @@ fromClause ; relation - : tableName (AS? alias)? # tableAsRelation - | LR_BRACKET subquery = querySpecification RR_BRACKET AS? 
alias # subqueryAsRelation + : tableName (AS? alias)? # tableAsRelation + | LR_BRACKET subquery = querySpecification RR_BRACKET AS? alias # subqueryAsRelation + | qualifiedName LR_BRACKET tableFunctionArgs RR_BRACKET (AS? alias)? # tableFunctionRelation + ; + +tableFunctionArgs + : tableFunctionArg (COMMA tableFunctionArg)* + ; + +tableFunctionArg + : ident EQUAL_SYMBOL functionArg + | functionArg ; whereClause diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java index bdbc360713c..5250ab7fb0f 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java @@ -13,6 +13,7 @@ import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.SelectElementContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.SubqueryAsRelationContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.TableAsRelationContext; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.TableFunctionRelationContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.WhereClauseContext; import static org.opensearch.sql.sql.parser.ParserUtils.getTextInQuery; import static org.opensearch.sql.utils.SystemIndexUtils.TABLE_INFO; @@ -20,12 +21,14 @@ import com.google.common.collect.ImmutableList; import java.util.Collections; +import java.util.Locale; import java.util.Optional; import lombok.RequiredArgsConstructor; import org.antlr.v4.runtime.tree.ParseTree; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.UnresolvedArgument; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.tree.DescribeRelation; import org.opensearch.sql.ast.tree.Filter; @@ -34,10 +37,12 @@ 
import org.opensearch.sql.ast.tree.Relation; import org.opensearch.sql.ast.tree.RelationSubquery; import org.opensearch.sql.ast.tree.SubqueryAlias; +import org.opensearch.sql.ast.tree.TableFunction; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Values; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.QuerySpecificationContext; @@ -189,6 +194,57 @@ public UnresolvedPlan visitSubqueryAsRelation(SubqueryAsRelationContext ctx) { return new RelationSubquery(visit(ctx.subquery), subqueryAlias); } + @Override + public UnresolvedPlan visitTableFunctionRelation(TableFunctionRelationContext ctx) { + // The grammar accepts both `ident = value` and bare `value` forms for each table function + // argument so that the real positional shape (e.g. `vectorSearch('idx', field='f', ...)`) + // reaches this V2 builder instead of failing to parse and silently falling back to the + // legacy SQL engine. Reject the positional shape here with a SemanticCheckException so the + // user receives a clean 400 rather than an opaque legacy parser error. + ctx.tableFunctionArgs() + .tableFunctionArg() + .forEach( + arg -> { + if (arg.ident() == null) { + String functionName = ctx.qualifiedName().getText(); + throw new SemanticCheckException( + String.format( + Locale.ROOT, + "Table function '%s' requires named arguments (e.g. 
name='value')," + + " but received a positional argument: %s", + functionName, + arg.functionArg().getText())); + } + }); + ImmutableList.Builder args = ImmutableList.builder(); + ctx.tableFunctionArgs() + .tableFunctionArg() + .forEach( + arg -> { + String argName = + StringUtils.unquoteIdentifier(arg.ident().getText()).toLowerCase(Locale.ROOT); + UnresolvedExpression argValue = visitAstExpression(arg.functionArg()); + args.add(new UnresolvedArgument(argName, argValue)); + }); + TableFunction tableFunction = + new TableFunction(visitAstExpression(ctx.qualifiedName()), args.build()); + if (ctx.alias() == null) { + String functionName = ctx.qualifiedName().getText(); + // Use SemanticCheckException (not SyntaxCheckException) so the request does not fall back + // to the legacy SQL engine, whose opaque parser error would mask this message. + throw new SemanticCheckException( + String.format( + Locale.ROOT, + "Table function '%s' requires a table alias." + + " Add an alias after the closing parenthesis, for example:" + + " FROM %s(...) 
AS v", + functionName, + functionName)); + } + String alias = StringUtils.unquoteIdentifier(ctx.alias().getText()); + return new SubqueryAlias(alias, tableFunction); + } + @Override public UnresolvedPlan visitWhereClause(WhereClauseContext ctx) { return new Filter(visitAstExpression(ctx.expression())); diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java index 1ecaa181e6f..695cf85b144 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java @@ -6,6 +6,8 @@ package org.opensearch.sql.sql.parser; import static java.util.Collections.emptyList; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.opensearch.sql.ast.dsl.AstDSL.agg; @@ -40,7 +42,11 @@ import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.NestedAllTupleFields; +import org.opensearch.sql.ast.expression.UnresolvedArgument; +import org.opensearch.sql.ast.tree.SubqueryAlias; +import org.opensearch.sql.ast.tree.TableFunction; import org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.exception.SemanticCheckException; class AstBuilderTest extends AstBuilderTestBase { @@ -131,6 +137,142 @@ public void can_build_from_index_with_alias_quoted() { buildAST("SELECT `t`.name FROM test `t` WHERE `t`.age = 30")); } + @Test + public void can_build_from_table_function() { + assertEquals( + project( + new SubqueryAlias( + "v", + new TableFunction( + qualifiedName("vectorSearch"), + ImmutableList.of( + new UnresolvedArgument("table", stringLiteral("products")), + new UnresolvedArgument("field", stringLiteral("embedding")), + 
new UnresolvedArgument("vector", stringLiteral("[0.1,0.2]")), + new UnresolvedArgument("option", stringLiteral("k=10"))))), + AllFields.of()), + buildAST( + "SELECT * FROM vectorSearch(" + + "table='products', field='embedding', " + + "vector='[0.1,0.2]', option='k=10') AS v")); + } + + @Test + public void can_build_from_table_function_with_where_order_limit() { + assertEquals( + project( + limit( + sort( + filter( + new SubqueryAlias( + "s", + new TableFunction( + qualifiedName("vectorSearch"), + ImmutableList.of( + new UnresolvedArgument("table", stringLiteral("products")), + new UnresolvedArgument("field", stringLiteral("embedding")), + new UnresolvedArgument("vector", stringLiteral("[0.1,0.2]")), + new UnresolvedArgument("option", stringLiteral("k=10"))))), + function("=", qualifiedName("s", "category"), stringLiteral("shoes"))), + field(qualifiedName("s", "_score"), argument("asc", booleanLiteral(false)))), + 5, + 0), + alias("s.title", qualifiedName("s", "title")), + alias("s._score", qualifiedName("s", "_score"))), + buildAST( + "SELECT s.title, s._score FROM vectorSearch(" + + "table='products', field='embedding', " + + "vector='[0.1,0.2]', option='k=10') AS s " + + "WHERE s.category = 'shoes' " + + "ORDER BY s._score DESC " + + "LIMIT 5")); + } + + @Test + public void table_function_args_are_resolved_by_name_not_position() { + assertEquals( + project( + new SubqueryAlias( + "v", + new TableFunction( + qualifiedName("vectorSearch"), + ImmutableList.of( + new UnresolvedArgument("option", stringLiteral("k=10")), + new UnresolvedArgument("field", stringLiteral("embedding")), + new UnresolvedArgument("table", stringLiteral("products")), + new UnresolvedArgument("vector", stringLiteral("[0.1,0.2]"))))), + AllFields.of()), + buildAST( + "SELECT * FROM vectorSearch(" + + "option='k=10', field='embedding', " + + "table='products', vector='[0.1,0.2]') AS v")); + } + + @Test + public void table_function_arg_names_are_canonicalized() { + assertEquals( + project( + new 
SubqueryAlias( + "v", + new TableFunction( + qualifiedName("vectorSearch"), + ImmutableList.of( + new UnresolvedArgument("table", stringLiteral("products")), + new UnresolvedArgument("field", stringLiteral("embedding")), + new UnresolvedArgument("vector", stringLiteral("[0.1,0.2]")), + new UnresolvedArgument("option", stringLiteral("k=10"))))), + AllFields.of()), + buildAST( + "SELECT * FROM vectorSearch(" + + "TABLE='products', FIELD='embedding', " + + "VECTOR='[0.1,0.2]', OPTION='k=10') AS v")); + } + + @Test + public void table_function_allows_alias_without_as_keyword() { + assertEquals( + project( + new SubqueryAlias( + "v", + new TableFunction( + qualifiedName("vectorSearch"), + ImmutableList.of( + new UnresolvedArgument("table", stringLiteral("products")), + new UnresolvedArgument("vector", stringLiteral("[0.1]"))))), + AllFields.of()), + buildAST("SELECT * FROM vectorSearch(table='products', vector='[0.1]') v")); + } + + @Test + public void table_function_relation_requires_alias() { + SemanticCheckException ex = + assertThrows( + SemanticCheckException.class, + () -> + buildAST( + "SELECT * FROM vectorSearch(" + + "table='products', field='embedding', " + + "vector='[0.1,0.2]', option='k=10')")); + assertThat(ex.getMessage(), containsString("requires a table alias")); + assertThat(ex.getMessage(), containsString("vectorSearch")); + } + + @Test + public void table_function_relation_rejects_positional_argument() { + // Grammar accepts both `ident=value` and bare `value` for each table function argument so + // the real positional shape reaches the V2 AstBuilder. The AstBuilder must reject it with a + // SemanticCheckException rather than let the request fall back to the legacy engine. 
+ SemanticCheckException ex = + assertThrows( + SemanticCheckException.class, + () -> + buildAST( + "SELECT * FROM vectorSearch('products', field='embedding', " + + "vector='[0.1,0.2]', option='k=10') AS v")); + org.junit.jupiter.api.Assertions.assertTrue( + ex.getMessage().contains("requires named arguments")); + } + @Test public void can_build_where_clause() { assertEquals(