Skip to content

Commit 6dce942

Browse files
branch-4.0: [Enhancement](ai) relax the matching restriction of some AI-Functions #58077 (#58113)
Cherry-picked from #58077 Co-authored-by: linrrarity <[email protected]>
1 parent b19ca07 commit 6dce942

File tree

2 files changed

+190
-2
lines changed

2 files changed

+190
-2
lines changed

be/src/vec/functions/ai/ai_functions.h

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
#include <gen_cpp/FrontendService.h>
2121
#include <gen_cpp/PaloInternalService_types.h>
2222

23+
#include <algorithm>
24+
#include <cctype>
25+
#include <cstdlib>
2326
#include <memory>
2427
#include <string>
2528
#include <type_traits>
@@ -122,8 +125,14 @@ class AIFunction : public IFunction {
122125
}
123126
case PrimitiveType::TYPE_BOOLEAN: { // boolean for AI_FILTER
124127
#ifdef BE_TEST
125-
string_result = "0";
128+
const char* test_result = std::getenv("AI_TEST_RESULT");
129+
if (test_result != nullptr) {
130+
string_result = test_result;
131+
} else {
132+
string_result = "0";
133+
}
126134
#endif
135+
trim_string(string_result);
127136
if (string_result != "1" && string_result != "0") {
128137
return Status::RuntimeError("Failed to parse boolean value: " +
129138
string_result);
@@ -133,7 +142,22 @@ class AIFunction : public IFunction {
133142
break;
134143
}
135144
case PrimitiveType::TYPE_FLOAT: { // float for AI_SIMILARITY
136-
assert_cast<ColumnFloat32&>(*col_result).insert_value(std::stof(string_result));
145+
#ifdef BE_TEST
146+
const char* test_result = std::getenv("AI_TEST_RESULT");
147+
if (test_result != nullptr) {
148+
string_result = test_result;
149+
} else {
150+
string_result = "0.0";
151+
}
152+
#endif
153+
trim_string(string_result);
154+
try {
155+
float float_value = std::stof(string_result);
156+
assert_cast<ColumnFloat32&>(*col_result).insert_value(float_value);
157+
} catch (...) {
158+
return Status::RuntimeError("Failed to parse float value: " +
159+
string_result);
160+
}
137161
break;
138162
}
139163
default:
@@ -147,6 +171,16 @@ class AIFunction : public IFunction {
147171
}
148172

149173
private:
174+
// Trim whitespace and newlines from string
175+
static void trim_string(std::string& str) {
176+
str.erase(str.begin(), std::find_if(str.begin(), str.end(),
177+
[](unsigned char ch) { return !std::isspace(ch); }));
178+
str.erase(std::find_if(str.rbegin(), str.rend(),
179+
[](unsigned char ch) { return !std::isspace(ch); })
180+
.base(),
181+
str.end());
182+
}
183+
150184
// The ai resource must be literal
151185
Status _init_from_resource(FunctionContext* context, const Block& block,
152186
const ColumnNumbers& arguments, TAIResource& config,

be/test/ai/ai_function_test.cpp

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,79 @@ TEST(AIFunctionTest, AISimilarityTest) {
295295
ASSERT_EQ(prompt, "Text 1: I like this dish\nText 2: This dish is very good");
296296
}
297297

298+
TEST(AIFunctionTest, AISimilarityExecuteTest) {
299+
auto runtime_state = std::make_unique<MockRuntimeState>();
300+
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
301+
302+
std::vector<std::string> resources = {"mock_resource"};
303+
std::vector<std::string> text1 = {"I like this dish"};
304+
std::vector<std::string> text2 = {"This dish is very good"};
305+
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
306+
auto col_text1 = ColumnHelper::create_column<DataTypeString>(text1);
307+
auto col_text2 = ColumnHelper::create_column<DataTypeString>(text2);
308+
309+
Block block;
310+
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
311+
block.insert({std::move(col_text1), std::make_shared<DataTypeString>(), "text1"});
312+
block.insert({std::move(col_text2), std::make_shared<DataTypeString>(), "text2"});
313+
block.insert({nullptr, std::make_shared<DataTypeFloat32>(), "result"});
314+
315+
ColumnNumbers arguments = {0, 1, 2};
316+
size_t result_idx = 3;
317+
318+
auto similarity_func = FunctionAISimilarity::create();
319+
Status exec_status =
320+
similarity_func->execute_impl(ctx.get(), block, arguments, result_idx, text1.size());
321+
322+
ASSERT_TRUE(exec_status.ok());
323+
}
324+
325+
TEST(AIFunctionTest, AISimilarityTrimWhitespace) {
326+
auto runtime_state = std::make_unique<MockRuntimeState>();
327+
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
328+
329+
std::vector<std::pair<std::string, float>> test_cases = {
330+
{"0.5", 0.5f}, {"1.0", 1.0f}, {"0.0", 0.0f}, {" 0.5", 0.5f},
331+
{"0.5 ", 0.5f}, {" 0.5 ", 0.5f}, {"\n0.8", 0.8f}, {"0.3\n", 0.3f},
332+
{"\n0.7\n", 0.7f}, {"\t0.2\t", 0.2f}, {" \n\t0.9 \n\t", 0.9f}, {" 0.1 ", 0.1f},
333+
{"\r\n0.6\r\n", 0.6f}};
334+
335+
for (const auto& test_case : test_cases) {
336+
setenv("AI_TEST_RESULT", test_case.first.c_str(), 1);
337+
338+
std::vector<std::string> resources = {"mock_resource"};
339+
std::vector<std::string> text1 = {"Test text 1"};
340+
std::vector<std::string> text2 = {"Test text 2"};
341+
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
342+
auto col_text1 = ColumnHelper::create_column<DataTypeString>(text1);
343+
auto col_text2 = ColumnHelper::create_column<DataTypeString>(text2);
344+
345+
Block block;
346+
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
347+
block.insert({std::move(col_text1), std::make_shared<DataTypeString>(), "text1"});
348+
block.insert({std::move(col_text2), std::make_shared<DataTypeString>(), "text2"});
349+
block.insert({nullptr, std::make_shared<DataTypeFloat32>(), "result"});
350+
351+
ColumnNumbers arguments = {0, 1, 2};
352+
size_t result_idx = 3;
353+
354+
auto similarity_func = FunctionAISimilarity::create();
355+
Status exec_status = similarity_func->execute_impl(ctx.get(), block, arguments, result_idx,
356+
text1.size());
357+
358+
ASSERT_TRUE(exec_status.ok()) << "Failed for test case: '" << test_case.first << "'";
359+
360+
const auto& res_col =
361+
assert_cast<const ColumnFloat32&>(*block.get_by_position(result_idx).column);
362+
float val = res_col.get_data()[0];
363+
ASSERT_FLOAT_EQ(val, test_case.second)
364+
<< "Failed for test case: '" << test_case.first
365+
<< "', expected: " << test_case.second << ", got: " << val;
366+
}
367+
368+
unsetenv("AI_TEST_RESULT");
369+
}
370+
298371
TEST(AIFunctionTest, AIFilterTest) {
299372
FunctionAIFilter function;
300373

@@ -343,6 +416,87 @@ TEST(AIFunctionTest, AIFilterExecuteTest) {
343416
ASSERT_TRUE(val == 0);
344417
}
345418

419+
TEST(AIFunctionTest, AIFilterTrimWhitespace) {
420+
auto runtime_state = std::make_unique<MockRuntimeState>();
421+
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
422+
423+
std::vector<std::pair<std::string, UInt8>> test_cases = {
424+
{"0", 0}, {"1", 1}, {" 0", 0}, {"0 ", 0},
425+
{" 0 ", 0}, {"\n0", 0}, {"0\n", 0}, {"\n0\n", 0},
426+
{"\t1\t", 1}, {" \n\t1 \n\t", 1}, {" 1 ", 1}, {"\r\n0\r\n", 0}};
427+
428+
for (const auto& test_case : test_cases) {
429+
setenv("AI_TEST_RESULT", test_case.first.c_str(), 1);
430+
431+
std::vector<std::string> resources = {"mock_resource"};
432+
std::vector<std::string> texts = {"Test input"};
433+
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
434+
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
435+
436+
Block block;
437+
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
438+
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
439+
block.insert({nullptr, std::make_shared<DataTypeBool>(), "result"});
440+
441+
ColumnNumbers arguments = {0, 1};
442+
size_t result_idx = 2;
443+
444+
auto filter_func = FunctionAIFilter::create();
445+
Status exec_status =
446+
filter_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
447+
448+
ASSERT_TRUE(exec_status.ok()) << "Failed for test case: '" << test_case.first << "'";
449+
450+
const auto& res_col =
451+
assert_cast<const ColumnUInt8&>(*block.get_by_position(result_idx).column);
452+
UInt8 val = res_col.get_data()[0];
453+
ASSERT_EQ(val, test_case.second)
454+
<< "Failed for test case: '" << test_case.first
455+
<< "', expected: " << (int)test_case.second << ", got: " << (int)val;
456+
}
457+
458+
unsetenv("AI_TEST_RESULT");
459+
}
460+
461+
TEST(AIFunctionTest, AIFilterInvalidValue) {
462+
auto runtime_state = std::make_unique<MockRuntimeState>();
463+
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
464+
465+
std::vector<std::string> invalid_cases = {
466+
"2", "maybe", "ok", "", " ", "01", "0.5", "sure", "truee", "falsee",
467+
"yess", "noo", "true", "false", "yes", "no", "TRUE", "FALSE", "YES", "NO"};
468+
469+
for (const auto& invalid_value : invalid_cases) {
470+
setenv("AI_TEST_RESULT", invalid_value.c_str(), 1);
471+
472+
std::vector<std::string> resources = {"mock_resource"};
473+
std::vector<std::string> texts = {"Test input"};
474+
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
475+
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
476+
477+
Block block;
478+
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
479+
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
480+
block.insert({nullptr, std::make_shared<DataTypeBool>(), "result"});
481+
482+
ColumnNumbers arguments = {0, 1};
483+
size_t result_idx = 2;
484+
485+
auto filter_func = FunctionAIFilter::create();
486+
Status exec_status =
487+
filter_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
488+
489+
ASSERT_FALSE(exec_status.ok())
490+
<< "Should have failed for invalid value: '" << invalid_value << "'";
491+
ASSERT_TRUE(exec_status.to_string().find("Failed to parse boolean value") !=
492+
std::string::npos)
493+
<< "Error message should mention boolean parsing for value: '" << invalid_value
494+
<< "'";
495+
}
496+
497+
unsetenv("AI_TEST_RESULT");
498+
}
499+
346500
TEST(AIFunctionTest, ResourceNotFound) {
347501
auto runtime_state = std::make_unique<MockRuntimeState>();
348502
auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});

0 commit comments

Comments
 (0)