Skip to content

Commit d6c16ea

Browse files
committed
LLVMBuildUtils: Check size in string equal comparison
1 parent ad21ea9 commit d6c16ea

File tree

2 files changed

+75
-38
lines changed

2 files changed

+75
-38
lines changed

src/engine/internal/llvm/llvmbuildutils.cpp

Lines changed: 73 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1357,8 +1357,7 @@ llvm::Value *LLVMBuildUtils::createStringComparison(LLVMRegister *arg1, LLVMRegi
13571357
// Explicitly cast to string
13581358
llvm::Value *string1 = castValue(arg1, Compiler::StaticType::String);
13591359
llvm::Value *string2 = castValue(arg2, Compiler::StaticType::String);
1360-
llvm::Value *cmp = m_builder.CreateCall(caseSensitive ? m_functions.resolve_string_compare_case_sensitive() : m_functions.resolve_string_compare_case_insensitive(), { string1, string2 });
1361-
llvm::Value *result = m_builder.CreateICmpEQ(cmp, m_builder.getInt32(0));
1360+
llvm::Value *result = createStringsEqualComparison(string1, string2, caseSensitive);
13621361
m_builder.CreateBr(nextBlock);
13631362

13641363
llvm::BasicBlock *compareBlockNext = m_builder.GetInsertBlock();
@@ -1846,13 +1845,13 @@ llvm::Value *LLVMBuildUtils::createStringAndStringComparison(LLVMRegister *arg1,
18461845
llvm::Value *value1 = castValue(arg1, Compiler::StaticType::String);
18471846
llvm::Value *value2 = castValue(arg2, Compiler::StaticType::String);
18481847

1848+
if (type == Comparison::EQ)
1849+
return createStringsEqualComparison(value1, value2, false);
1850+
18491851
llvm::Value *cmp = m_builder.CreateCall(m_functions.resolve_string_compare_case_insensitive(), { value1, value2 });
18501852
llvm::Value *zero = llvm::ConstantInt::get(m_builder.getInt32Ty(), 0, true);
18511853

18521854
switch (type) {
1853-
case Comparison::EQ:
1854-
return m_builder.CreateICmpEQ(cmp, zero);
1855-
18561855
case Comparison::GT:
18571856
return m_builder.CreateICmpSGT(cmp, zero);
18581857

@@ -1970,37 +1969,39 @@ llvm::Value *LLVMBuildUtils::createNumberAndStringComparison(LLVMRegister *arg1,
19701969
m_builder.SetInsertPoint(stringBlock);
19711970
llvm::Value *stringValue = addStringAlloca();
19721971
m_builder.CreateCall(m_functions.resolve_value_doubleToStringPtr(), { value1, stringValue });
1973-
llvm::Value *cmp = m_builder.CreateCall(m_functions.resolve_string_compare_case_insensitive(), { stringValue, value2 });
19741972

1975-
llvm::Value *zero = llvm::ConstantInt::get(m_builder.getInt32Ty(), 0, true);
19761973
llvm::Value *stringCmp;
19771974

1978-
switch (type) {
1979-
case Comparison::EQ:
1980-
stringCmp = m_builder.CreateICmpEQ(cmp, zero);
1981-
break;
1975+
if (type == Comparison::EQ)
1976+
stringCmp = createStringsEqualComparison(stringValue, value2, false);
1977+
else {
1978+
llvm::Value *cmp = m_builder.CreateCall(m_functions.resolve_string_compare_case_insensitive(), { stringValue, value2 });
1979+
llvm::Value *zero = llvm::ConstantInt::get(m_builder.getInt32Ty(), 0, true);
19821980

1983-
case Comparison::GT:
1984-
stringCmp = m_builder.CreateICmpSGT(cmp, zero);
1985-
break;
1981+
switch (type) {
1982+
case Comparison::GT:
1983+
stringCmp = m_builder.CreateICmpSGT(cmp, zero);
1984+
break;
19861985

1987-
case Comparison::LT:
1988-
stringCmp = m_builder.CreateICmpSLT(cmp, zero);
1989-
break;
1986+
case Comparison::LT:
1987+
stringCmp = m_builder.CreateICmpSLT(cmp, zero);
1988+
break;
19901989

1991-
default:
1992-
assert(false);
1993-
return nullptr;
1990+
default:
1991+
assert(false);
1992+
return nullptr;
1993+
}
19941994
}
19951995

1996+
llvm::BasicBlock *stringBlockNext = m_builder.GetInsertBlock();
19961997
m_builder.CreateBr(nextBlock);
19971998

19981999
// Merge the results
19992000
m_builder.SetInsertPoint(nextBlock);
20002001

20012002
llvm::PHINode *result = m_builder.CreatePHI(m_builder.getInt1Ty(), 2);
20022003
result->addIncoming(numberCmp, numberBlock);
2003-
result->addIncoming(stringCmp, stringBlock);
2004+
result->addIncoming(stringCmp, stringBlockNext);
20042005

20052006
return result;
20062007
}
@@ -2051,42 +2052,76 @@ llvm::Value *LLVMBuildUtils::createBoolAndStringComparison(LLVMRegister *arg1, L
20512052
// String comparison
20522053
m_builder.SetInsertPoint(stringBlock);
20532054
llvm::Value *stringValue = m_builder.CreateCall(m_functions.resolve_value_boolToStringPtr(), { value1 });
2054-
llvm::Value *cmp = m_builder.CreateCall(m_functions.resolve_string_compare_case_insensitive(), { stringValue, value2 });
2055-
// NOTE: Do not free the string!
20562055

2057-
llvm::Value *zero = llvm::ConstantInt::get(m_builder.getInt32Ty(), 0, true);
20582056
llvm::Value *stringCmp;
20592057

2060-
switch (type) {
2061-
case Comparison::EQ:
2062-
stringCmp = m_builder.CreateICmpEQ(cmp, zero);
2063-
break;
2058+
if (type == Comparison::EQ)
2059+
stringCmp = createStringsEqualComparison(stringValue, value2, false);
2060+
else {
2061+
llvm::Value *cmp = m_builder.CreateCall(m_functions.resolve_string_compare_case_insensitive(), { stringValue, value2 });
2062+
llvm::Value *zero = llvm::ConstantInt::get(m_builder.getInt32Ty(), 0, true);
20642063

2065-
case Comparison::GT:
2066-
stringCmp = m_builder.CreateICmpSGT(cmp, zero);
2067-
break;
2064+
switch (type) {
2065+
case Comparison::GT:
2066+
stringCmp = m_builder.CreateICmpSGT(cmp, zero);
2067+
break;
20682068

2069-
case Comparison::LT:
2070-
stringCmp = m_builder.CreateICmpSLT(cmp, zero);
2071-
break;
2069+
case Comparison::LT:
2070+
stringCmp = m_builder.CreateICmpSLT(cmp, zero);
2071+
break;
20722072

2073-
default:
2074-
assert(false);
2075-
return nullptr;
2073+
default:
2074+
assert(false);
2075+
return nullptr;
2076+
}
20762077
}
20772078

2079+
llvm::BasicBlock *stringBlockNext = m_builder.GetInsertBlock();
20782080
m_builder.CreateBr(nextBlock);
20792081

20802082
// Merge the results
20812083
m_builder.SetInsertPoint(nextBlock);
20822084

20832085
llvm::PHINode *result = m_builder.CreatePHI(m_builder.getInt1Ty(), 2);
20842086
result->addIncoming(numberCmp, numberBlock);
2085-
result->addIncoming(stringCmp, stringBlock);
2087+
result->addIncoming(stringCmp, stringBlockNext);
20862088

20872089
return result;
20882090
}
20892091

2092+
llvm::Value *LLVMBuildUtils::createStringsEqualComparison(llvm::Value *stringPtr1, llvm::Value *stringPtr2, bool caseSensitive)
2093+
{
2094+
llvm::Value *sizePtr1 = m_builder.CreateStructGEP(m_stringPtrType, stringPtr1, 1);
2095+
llvm::Value *size1 = m_builder.CreateLoad(m_builder.getInt64Ty(), sizePtr1);
2096+
2097+
llvm::Value *sizePtr2 = m_builder.CreateStructGEP(m_stringPtrType, stringPtr2, 1);
2098+
llvm::Value *size2 = m_builder.CreateLoad(m_builder.getInt64Ty(), sizePtr2);
2099+
2100+
llvm::Value *sameSize = m_builder.CreateICmpEQ(size1, size2);
2101+
2102+
llvm::BasicBlock *compareBlock = llvm::BasicBlock::Create(m_llvmCtx, "stringsEqual.compare", m_function);
2103+
llvm::BasicBlock *differentBlock = llvm::BasicBlock::Create(m_llvmCtx, "stringsEqual.different", m_function);
2104+
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "stringsEqual.next", m_function);
2105+
m_builder.CreateCondBr(sameSize, compareBlock, differentBlock);
2106+
2107+
m_builder.SetInsertPoint(compareBlock);
2108+
llvm::FunctionCallee func = caseSensitive ? m_functions.resolve_string_compare_case_sensitive() : m_functions.resolve_string_compare_case_insensitive();
2109+
llvm::Value *cmp = m_builder.CreateCall(func, { stringPtr1, stringPtr2 });
2110+
llvm::Value *result = m_builder.CreateICmpEQ(cmp, m_builder.getInt32(0));
2111+
m_builder.CreateBr(nextBlock);
2112+
2113+
m_builder.SetInsertPoint(differentBlock);
2114+
m_builder.CreateBr(nextBlock);
2115+
2116+
m_builder.SetInsertPoint(nextBlock);
2117+
2118+
llvm::PHINode *phi = m_builder.CreatePHI(m_builder.getInt1Ty(), 2, "stringsEqual");
2119+
phi->addIncoming(result, compareBlock);
2120+
phi->addIncoming(m_builder.getInt1(false), differentBlock);
2121+
2122+
return phi;
2123+
}
2124+
20902125
llvm::Value *LLVMBuildUtils::getVariablePtr(llvm::Value *targetVariables, Variable *variable)
20912126
{
20922127
if (!m_target->isStage() && variable->target() == m_target) {

src/engine/internal/llvm/llvmbuildutils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ class LLVMBuildUtils
148148
llvm::Value *createNumberAndStringComparison(LLVMRegister *arg1, LLVMRegister *arg2, Comparison type);
149149
llvm::Value *createBoolAndStringComparison(LLVMRegister *arg1, LLVMRegister *arg2, Comparison type);
150150

151+
llvm::Value *createStringsEqualComparison(llvm::Value *stringPtr1, llvm::Value *stringPtr2, bool caseSensitive);
152+
151153
llvm::Value *getVariablePtr(llvm::Value *targetVariables, Variable *variable);
152154
llvm::Value *getListPtr(llvm::Value *targetLists, List *list);
153155
llvm::Value *getListDataPtr(const LLVMListPtr &listPtr);

0 commit comments

Comments
 (0)