@@ -89,6 +89,13 @@ void LeafSubpatterns::buildAllPatterns(const MoleculesHost& queriesHost) {
8989 }
9090 }
9191
92+ if (patternsHost.numMolecules () > 0 ) {
93+ for (size_t i = 0 ; i < patternsHost.numMolecules (); ++i) {
94+ const int atoms = patternsHost.batchAtomStarts [i + 1 ] - patternsHost.batchAtomStarts [i];
95+ maxPatternAtoms_ = std::max (maxPatternAtoms_, atoms);
96+ }
97+ }
98+
9299 // Second pass: build precomputed BatchedPatternEntry structures
93100 perQueryPatterns.resize (numQueries);
94101 perQueryMaxDepth.resize (numQueries, 0 );
@@ -146,7 +153,7 @@ void LeafSubpatterns::buildAllPatterns(const MoleculesHost& queriesHost) {
146153 }
147154
148155 for (int d = 0 ; d <= queryMaxDepth; ++d) {
149- const auto & srcEntries = perQueryPatterns[queryIdx][d];
156+ const auto & srcEntries = perQueryPatterns[queryIdx][d];
150157 auto & destEntries = allQueriesPatternsAtDepth[d];
151158 destEntries.insert (destEntries.end (), srcEntries.begin (), srcEntries.end ());
152159 }
@@ -190,6 +197,11 @@ void RecursivePatternPreprocessor::preprocessMiniBatch(
190197
191198 scratch.setStream (stream);
192199
200+ const auto baseProps = getTemplateConfigProperties (templateConfig);
201+ const int paintQueryAtoms = std::max (baseProps.maxQueryAtoms , leafSubpatterns_.maxPatternAtoms ());
202+ const int paintTargetAtoms = std::max (baseProps.maxTargetAtoms , paintQueryAtoms);
203+ const auto paintConfig = selectTemplateConfig (paintTargetAtoms, paintQueryAtoms, baseProps.maxBondsPerAtom );
204+
193205 constexpr int gsiBuffersPerBlock = 2 ;
194206
195207 const int maxPaintPairsPerSubBatch = std::max (miniBatchSize, 1024 );
@@ -257,7 +269,7 @@ void RecursivePatternPreprocessor::preprocessMiniBatch(
257269 }
258270 isFirstLabelKernel = false ;
259271
260- launchLabelMatrixPaintKernel (templateConfig ,
272+ launchLabelMatrixPaintKernel (paintConfig ,
261273 targetsDevice.view <MoleculeType::Target>(),
262274 leafSubpatterns_.view (),
263275 scratch.patternEntries .data (),
@@ -273,7 +285,7 @@ void RecursivePatternPreprocessor::preprocessMiniBatch(
273285 zeroBuffers,
274286 stream);
275287
276- launchSubstructPaintKernel (templateConfig ,
288+ launchSubstructPaintKernel (paintConfig ,
277289 algorithm,
278290 targetsDevice.view <MoleculeType::Target>(),
279291 leafSubpatterns_.view (),
@@ -376,6 +388,11 @@ void preprocessRecursiveSmarts(SubstructTemplateConfig templateConfig,
376388 const int lastTargetInMiniBatch = (miniBatchPairOffset + miniBatchSize - 1 ) / numQueries;
377389 const int numTargetsInMiniBatch = lastTargetInMiniBatch - firstTargetInMiniBatch + 1 ;
378390
391+ const auto baseProps = getTemplateConfigProperties (templateConfig);
392+ const int paintQueryAtoms = std::max (baseProps.maxQueryAtoms , leafSubpatterns.maxPatternAtoms ());
393+ const int paintTargetAtoms = std::max (baseProps.maxTargetAtoms , paintQueryAtoms);
394+ const auto paintConfig = selectTemplateConfig (paintTargetAtoms, paintQueryAtoms, baseProps.maxBondsPerAtom );
395+
379396 constexpr int gsiBuffersPerBlock = 2 ;
380397
381398 const int maxPaintPairsPerSubBatch = std::max (miniBatchSize, 1024 );
@@ -449,7 +466,7 @@ void preprocessRecursiveSmarts(SubstructTemplateConfig templateConfig,
449466 }
450467 isFirstLabelKernel = false ;
451468
452- launchLabelMatrixPaintKernel (templateConfig ,
469+ launchLabelMatrixPaintKernel (paintConfig ,
453470 targetsDevice.view <MoleculeType::Target>(),
454471 leafSubpatterns.view (),
455472 scratch.patternEntries .data (),
@@ -465,7 +482,7 @@ void preprocessRecursiveSmarts(SubstructTemplateConfig templateConfig,
465482 zeroBuffers,
466483 stream);
467484
468- launchSubstructPaintKernel (templateConfig ,
485+ launchSubstructPaintKernel (paintConfig ,
469486 algorithm,
470487 targetsDevice.view <MoleculeType::Target>(),
471488 leafSubpatterns.view (),
0 commit comments