Skip to content

Commit bbccd30

Browse files
scal444 authored and evasnow1992 committed
Fix shared memory overflow due to config setting error (#98)
1 parent 8ce9529 commit bbccd30

File tree

3 files changed

+127
-5
lines changed

3 files changed

+127
-5
lines changed

src/substruct/recursive_preprocessor.cu

Lines changed: 22 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -89,6 +89,13 @@ void LeafSubpatterns::buildAllPatterns(const MoleculesHost& queriesHost) {
8989
}
9090
}
9191

92+
if (patternsHost.numMolecules() > 0) {
93+
for (size_t i = 0; i < patternsHost.numMolecules(); ++i) {
94+
const int atoms = patternsHost.batchAtomStarts[i + 1] - patternsHost.batchAtomStarts[i];
95+
maxPatternAtoms_ = std::max(maxPatternAtoms_, atoms);
96+
}
97+
}
98+
9299
// Second pass: build precomputed BatchedPatternEntry structures
93100
perQueryPatterns.resize(numQueries);
94101
perQueryMaxDepth.resize(numQueries, 0);
@@ -146,7 +153,7 @@ void LeafSubpatterns::buildAllPatterns(const MoleculesHost& queriesHost) {
146153
}
147154

148155
for (int d = 0; d <= queryMaxDepth; ++d) {
149-
const auto& srcEntries = perQueryPatterns[queryIdx][d];
156+
const auto& srcEntries = perQueryPatterns[queryIdx][d];
150157
auto& destEntries = allQueriesPatternsAtDepth[d];
151158
destEntries.insert(destEntries.end(), srcEntries.begin(), srcEntries.end());
152159
}
@@ -190,6 +197,11 @@ void RecursivePatternPreprocessor::preprocessMiniBatch(
190197

191198
scratch.setStream(stream);
192199

200+
const auto baseProps = getTemplateConfigProperties(templateConfig);
201+
const int paintQueryAtoms = std::max(baseProps.maxQueryAtoms, leafSubpatterns_.maxPatternAtoms());
202+
const int paintTargetAtoms = std::max(baseProps.maxTargetAtoms, paintQueryAtoms);
203+
const auto paintConfig = selectTemplateConfig(paintTargetAtoms, paintQueryAtoms, baseProps.maxBondsPerAtom);
204+
193205
constexpr int gsiBuffersPerBlock = 2;
194206

195207
const int maxPaintPairsPerSubBatch = std::max(miniBatchSize, 1024);
@@ -257,7 +269,7 @@ void RecursivePatternPreprocessor::preprocessMiniBatch(
257269
}
258270
isFirstLabelKernel = false;
259271

260-
launchLabelMatrixPaintKernel(templateConfig,
272+
launchLabelMatrixPaintKernel(paintConfig,
261273
targetsDevice.view<MoleculeType::Target>(),
262274
leafSubpatterns_.view(),
263275
scratch.patternEntries.data(),
@@ -273,7 +285,7 @@ void RecursivePatternPreprocessor::preprocessMiniBatch(
273285
zeroBuffers,
274286
stream);
275287

276-
launchSubstructPaintKernel(templateConfig,
288+
launchSubstructPaintKernel(paintConfig,
277289
algorithm,
278290
targetsDevice.view<MoleculeType::Target>(),
279291
leafSubpatterns_.view(),
@@ -376,6 +388,11 @@ void preprocessRecursiveSmarts(SubstructTemplateConfig templateConfig,
376388
const int lastTargetInMiniBatch = (miniBatchPairOffset + miniBatchSize - 1) / numQueries;
377389
const int numTargetsInMiniBatch = lastTargetInMiniBatch - firstTargetInMiniBatch + 1;
378390

391+
const auto baseProps = getTemplateConfigProperties(templateConfig);
392+
const int paintQueryAtoms = std::max(baseProps.maxQueryAtoms, leafSubpatterns.maxPatternAtoms());
393+
const int paintTargetAtoms = std::max(baseProps.maxTargetAtoms, paintQueryAtoms);
394+
const auto paintConfig = selectTemplateConfig(paintTargetAtoms, paintQueryAtoms, baseProps.maxBondsPerAtom);
395+
379396
constexpr int gsiBuffersPerBlock = 2;
380397

381398
const int maxPaintPairsPerSubBatch = std::max(miniBatchSize, 1024);
@@ -449,7 +466,7 @@ void preprocessRecursiveSmarts(SubstructTemplateConfig templateConfig,
449466
}
450467
isFirstLabelKernel = false;
451468

452-
launchLabelMatrixPaintKernel(templateConfig,
469+
launchLabelMatrixPaintKernel(paintConfig,
453470
targetsDevice.view<MoleculeType::Target>(),
454471
leafSubpatterns.view(),
455472
scratch.patternEntries.data(),
@@ -465,7 +482,7 @@ void preprocessRecursiveSmarts(SubstructTemplateConfig templateConfig,
465482
zeroBuffers,
466483
stream);
467484

468-
launchSubstructPaintKernel(templateConfig,
485+
launchSubstructPaintKernel(paintConfig,
469486
algorithm,
470487
targetsDevice.view<MoleculeType::Target>(),
471488
leafSubpatterns.view(),

src/substruct/recursive_preprocessor.h

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,8 @@ struct LeafSubpatterns {
8787
/// Max recursion depth across all queries
8888
int allQueriesMaxDepth = 0;
8989

90+
int maxPatternAtoms_ = 0;
91+
9092
LeafSubpatterns() = default;
9193

9294
/**
@@ -130,6 +132,11 @@ struct LeafSubpatterns {
130132
*/
131133
[[nodiscard]] size_t size() const { return patternIndexMap.size(); }
132134

135+
/**
136+
* @brief Max atom count across all leaf subpatterns.
137+
*/
138+
[[nodiscard]] int maxPatternAtoms() const { return maxPatternAtoms_; }
139+
133140
/**
134141
* @brief Get view for kernel access.
135142
*/

tests/test_recursive_preprocessor.cu

Lines changed: 98 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -177,3 +177,101 @@ TEST(RecursivePreprocessorTest, PaintsBitsForSimpleRecursivePattern) {
177177
EXPECT_FALSE(hasRecursiveBit(1, 0));
178178
EXPECT_FALSE(hasRecursiveBit(1, 1));
179179
}
180+
181+
/**
182+
* @brief Leaf subpattern with more atoms than the caller's MaxQueryAtoms
183+
* template tier should not overflow the shared memory label matrix.
184+
*/
185+
TEST(RecursivePreprocessorTest, LeafPatternLargerThanConfigMaxQueryAtoms) {
186+
ScopedStream stream;
187+
188+
auto target = makeMolFromSmiles("CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC");  // long carbon chain used as the single target
189+
auto query = makeMolFromSmarts("[$(*~C~C~C~C~C~C~C~C~C~C~C~C~C~C~C~C~C)]");  // one recursive SMARTS whose inner pattern is a long chain
190+
191+
ASSERT_NE(target, nullptr);
192+
ASSERT_NE(query, nullptr);
193+
194+
std::vector<const RDKit::ROMol*> targets = {target.get()};
195+
std::vector<const RDKit::ROMol*> queries = {query.get()};
196+
std::vector<int> emptySortOrder;
197+
198+
MoleculesHost targetsHost;
199+
nvMolKit::buildTargetBatchParallelInto(targetsHost, 1, targets, emptySortOrder);
200+
MoleculesHost queriesHost = nvMolKit::buildQueryBatchParallel(queries, emptySortOrder, 1);
201+
202+
const int maxTargetAtoms = maxAtomsPerTarget(targetsHost);
203+
ASSERT_GE(maxTargetAtoms, 32);  // fixture sanity check: abort the test if the target is smaller than expected
204+
205+
MoleculesDevice targetsDevice(stream.stream());
206+
targetsDevice.copyFromHost(targetsHost);
207+
208+
RecursivePatternPreprocessor preprocessor;
209+
preprocessor.buildPatterns(queriesHost);
210+
preprocessor.syncToDevice(stream.stream());
211+
212+
const LeafSubpatterns& leafSubpatterns = preprocessor.leafSubpatterns();
213+
ASSERT_FALSE(leafSubpatterns.empty());
214+
ASSERT_GT(leafSubpatterns.maxPatternAtoms(), 16);  // precondition: leaf pattern must exceed 16 atoms (presumably the Q16 limit of the config passed below - confirm)
215+
216+
const int numTargets = 1;
217+
const int numQueries = 1;
218+
const int miniBatchSize = numTargets * numQueries;
219+
220+
AsyncDeviceVector<int> pairMatchStartsDev(static_cast<size_t>(miniBatchSize + 1), stream.stream());
221+
pairMatchStartsDev.zero();
222+
MiniBatchResultsDevice miniBatchResults(stream.stream());
223+
miniBatchResults.allocateMiniBatch(miniBatchSize, pairMatchStartsDev.data(), 0, numQueries, maxTargetAtoms, 2);
224+
const std::vector<int> atomCounts = queryAtomCounts(queriesHost);
225+
miniBatchResults.setQueryAtomCounts(atomCounts.data(), atomCounts.size());
226+
miniBatchResults.zeroRecursiveBits();  // start from all-zero bits so the final anyBitSet check is meaningful
227+
228+
RecursiveScratchBuffers scratch(stream.stream());
229+
scratch.allocateBuffers(256);
230+
231+
std::array<std::vector<BatchedPatternEntry>, kMaxSmartsNestingDepth + 1> patternsAtDepth;
232+
for (auto& vec : patternsAtDepth) {  // defensive only: default-constructed vectors are already empty
233+
vec.clear();
234+
}
235+
236+
const int queryMaxDepth = leafSubpatterns.perQueryMaxDepth.empty() ? 0 : leafSubpatterns.perQueryMaxDepth[0];  // falls back to depth 0 when no per-query depth is recorded
237+
for (int depth = 0; depth <= queryMaxDepth; ++depth) {  // copy query 0's entries into the per-depth lists handed to preprocessMiniBatch
238+
const auto& src = leafSubpatterns.perQueryPatterns[0][depth];
239+
patternsAtDepth[depth].insert(patternsAtDepth[depth].end(), src.begin(), src.end());
240+
}
241+
242+
preprocessor.preprocessMiniBatch(SubstructTemplateConfig::Config_T32_Q16_B4,  // deliberately the Q16 config even though the leaf pattern was asserted to have more than 16 atoms
243+
targetsDevice,
244+
miniBatchResults,
245+
numQueries,
246+
0,
247+
miniBatchSize,
248+
SubstructAlgorithm::GSI,
249+
stream.stream(),
250+
scratch,
251+
patternsAtDepth,
252+
queryMaxDepth,
253+
0,
254+
numTargets,
255+
nullptr,
256+
0);
257+
258+
cudaCheckError(cudaStreamSynchronize(stream.stream()));  // wait for the stream so asynchronous kernel failures surface here
259+
cudaCheckError(cudaGetLastError());  // also report any launch error recorded by the runtime
260+
261+
std::vector<uint32_t> hostBits(static_cast<size_t>(miniBatchSize) * maxTargetAtoms);
262+
cudaCheckError(cudaMemcpyAsync(hostBits.data(),
263+
miniBatchResults.recursiveMatchBits(),
264+
hostBits.size() * sizeof(uint32_t),
265+
cudaMemcpyDeviceToHost,
266+
stream.stream()));
267+
cudaCheckError(cudaStreamSynchronize(stream.stream()));  // the copy is asynchronous; sync before reading hostBits on the host
268+
269+
bool anyBitSet = false;
270+
for (size_t i = 0; i < hostBits.size(); ++i) {
271+
if (hostBits[i] != 0) {
272+
anyBitSet = true;
273+
break;
274+
}
275+
}
276+
EXPECT_TRUE(anyBitSet);  // passes when at least one word copied back from recursiveMatchBits() is non-zero
277+
}

0 commit comments

Comments
 (0)