diff --git a/autotools-init.sh b/autotools-init.sh old mode 100644 new mode 100755 diff --git a/src/IntaRNA/HelixHandlerNoBulgeMax.h b/src/IntaRNA/HelixHandlerNoBulgeMax.h index 278f7a9b..755e6d2b 100644 --- a/src/IntaRNA/HelixHandlerNoBulgeMax.h +++ b/src/IntaRNA/HelixHandlerNoBulgeMax.h @@ -27,7 +27,7 @@ class HelixHandlerNoBulgeMax : public HelixHandler { //! container type for sparse helix information for a given left-most //! base pair (i1,i2) //! it holds both the energy (first) as well as the length of the helix - typedef boost::unordered_map< Interaction::BasePair, HelixData > HelixHash; + typedef boost::unordered_map< Interaction::BasePair, HelixData, Interaction::BasePair::Hash, Interaction::BasePair::Equal > HelixHash; protected: diff --git a/src/IntaRNA/Interaction.cpp b/src/IntaRNA/Interaction.cpp index 6a97b066..301dd3fb 100644 --- a/src/IntaRNA/Interaction.cpp +++ b/src/IntaRNA/Interaction.cpp @@ -27,6 +27,16 @@ operator<<(std::ostream& out, const Interaction::BasePair& bp) //////////////////////////////////////////////////////////////////////////// +std::ostream& +operator<<(std::ostream& out, const Interaction::Boundary& b) +{ + out <<"("< BasePair; + struct BasePair { + size_t first; //!< index in first sequence + size_t second; //!< index in second sequence + + /** + * Construction + * @param first index in first sequence + * @param second index in second sequence + */ + BasePair( const size_t first=RnaSequence::lastPos, const size_t second=RnaSequence::lastPos ) + : first(first), second(second) + {} + + //! hash value computation + struct Hash { + /** + * Hash value computation for an instance + * @param i instance to hash + * @return the respective hash value + */ + size_t operator()(const BasePair &i ) const + { + size_t key = 0; + boost::hash_combine(key, i.first); + boost::hash_combine(key, i.second); + return key; + } + }; + + //! equality check + struct Equal { + /** + * check equality of two instances + * @param lhs instance 1 + * @param rhs instance 2 + * @param true if all indices are equal; false otherwise + */ + bool operator()( const BasePair & lhs, const BasePair & rhs ) const + { + return lhs.first == rhs.first + && lhs.second == rhs.second ; + } + }; + + /** + * @param bp the basepair to compare to + * @return true if this basepair is considered smaller than bp + */ + const bool + operator < ( const BasePair &bp ) const { + if (first < bp.first) + return true; + else if (bp.first < first) + return false; + else if (second < bp.second) + return true; + return false; + } + + /** + * equality check + * @param bp the basepair to compare to + * @return true if this basepair equals bp + */ + const bool + operator == ( const BasePair &bp ) const { + return first == bp.first + && second == bp.second; + } + + /** + * inequality check + * @param bp the basepair to compare to + * @return false if this basepair equals bp + */ + const bool + operator != ( const BasePair &bp ) const { + return first != bp.first + || second != bp.second; + } + + }; //! type of a vector encoding base pair indices that are interacting typedef std::vector PairingVec; @@ -88,6 +169,7 @@ class Interaction { && lhs.j2 == rhs.j2 ; } }; + }; @@ -311,7 +393,15 @@ class Interaction { * @param bp the Interaction base pair object to add * @return the altered stream out */ - friend std::ostream& operator<<(std::ostream& out, const Interaction::BasePair& bp); + friend std::ostream& operator<<(std::ostream& out, const BasePair& bp); + + /** + * Prints the interaction boundary to stream + * @param out the ostream to write to + * @param b the Boundary object to add + * @return the altered stream out + */ + friend std::ostream& operator<<(std::ostream& out, const Boundary& b); /** * Prints the interacting base pairs to stream diff --git a/src/IntaRNA/Makefile.am b/src/IntaRNA/Makefile.am index c41bdf65..f7ff906b 100644 --- a/src/IntaRNA/Makefile.am +++ b/src/IntaRNA/Makefile.am @@ -58,6 +58,7 @@ libIntaRNA_a_HEADERS = \ PredictionTrackerSpotProb.h \ PredictionTrackerSpotProbAll.h \ PredictionTrackerProfileSpotProb.h \ + PredictionTrackerBasePairProb.h \ Predictor.h \ PredictorMfe.h \ PredictorMfeSeedOnly.h \ @@ -120,6 +121,7 @@ libIntaRNA_a_SOURCES = \ PredictionTrackerSpotProb.cpp \ PredictionTrackerSpotProbAll.cpp \ PredictionTrackerProfileSpotProb.cpp \ + PredictionTrackerBasePairProb.cpp \ PredictorMfe.cpp \ PredictorMfeSeedOnly.cpp \ PredictorMfe2d.cpp \ diff --git a/src/IntaRNA/NussinovHandler.cpp b/src/IntaRNA/NussinovHandler.cpp index 7ca7e4f7..37727add 100644 --- a/src/IntaRNA/NussinovHandler.cpp +++ b/src/IntaRNA/NussinovHandler.cpp @@ -24,7 +24,7 @@ NussinovHandler::getBasePairs( getBasePairs(k + 1, to - 1, traceback, pairs); } - pairs.push_back(std::pair(k, to)); + pairs.push_back(Interaction::BasePair(k, to)); } } diff --git a/src/IntaRNA/PredictionTracker.h b/src/IntaRNA/PredictionTracker.h index 2bab2328..a9a464d0 100644 --- a/src/IntaRNA/PredictionTracker.h +++ b/src/IntaRNA/PredictionTracker.h @@ -4,6 +4,7 @@ #include "IntaRNA/general.h" +#include "IntaRNA/SeedHandlerIdxOffset.h" namespace IntaRNA { @@ -11,6 +12,7 @@ namespace IntaRNA { * Generic interface to track prediction progress of Predictor instances. * */ +class PredictorMfeEns; class PredictionTracker { @@ -43,9 +45,16 @@ class PredictionTracker , const E_type energy ) = 0; -}; - + /** + * Updates the probability information. + * + * @param predictor the predictor providing the probability information + */ + virtual + void + updateZ( PredictorMfeEns *predictor, SeedHandler* seedHandler ); +}; /////////////////////////////////////////////////////////////////////////// @@ -63,6 +72,15 @@ PredictionTracker::~PredictionTracker() /////////////////////////////////////////////////////////////////////////// +inline +void +PredictionTracker::updateZ( PredictorMfeEns *predictor, SeedHandler* seedHandler ) +{ + // override in PredictionTrackers +} + +/////////////////////////////////////////////////////////////////////////// + } // namespace diff --git a/src/IntaRNA/PredictionTrackerBasePairProb.cpp b/src/IntaRNA/PredictionTrackerBasePairProb.cpp new file mode 100644 index 00000000..3caf2201 --- /dev/null +++ b/src/IntaRNA/PredictionTrackerBasePairProb.cpp @@ -0,0 +1,853 @@ + +#include "PredictionTrackerBasePairProb.h" + +extern "C" { + #include + #include +} + +#include +#include + +namespace IntaRNA { + +////////////////////////////////////////////////////////////////////// + +PredictionTrackerBasePairProb:: +PredictionTrackerBasePairProb( + const InteractionEnergy & energy + , const std::string & fileName + ) + : PredictionTracker() + , energy(energy) + , fileName(fileName) + , probabilityThreshold(0.0001) + , maxDotPlotSize(640) +{ +} + +////////////////////////////////////////////////////////////////////// + +PredictionTrackerBasePairProb:: +~PredictionTrackerBasePairProb() +{ +} + +////////////////////////////////////////////////////////////////////// + +void +PredictionTrackerBasePairProb:: +updateOptimumCalled( const size_t i1, const size_t j1 + , const size_t i2, const size_t j2 + , const E_type curE + ) +{ +} + +////////////////////////////////////////////////////////////////////// + +void +PredictionTrackerBasePairProb:: +updateZ( PredictorMfeEns *predictor, SeedHandler *seedHandler ) +{ + PredictorMfeEns2dSeedExtension* seedPredictor = dynamic_cast(predictor); + isSeedPredictor = (seedPredictor != nullptr); + + // sequence strings + const std::string & rna1 = energy.getAccessibility1().getSequence().asString(); + const std::string & reverseRna2 = energy.getAccessibility2().getAccessibilityOrigin().getSequence().asString(); + + size_t s1 = energy.size1(); + size_t s2 = energy.size2(); + size_t n1 = energy.getAccessibility1().getMaxLength(); + size_t n2 = energy.getAccessibility2().getMaxLength(); + + Z_type maxZ = 0.0; + Interaction::Boundary interactionBoundary; + + // initialize Z_partition + const PredictorMfeEns::Site2Z_hash & Z_partition = predictor->getZPartition(); + + // create index of left/right boundaries + for (auto z = Z_partition.begin(); z != Z_partition.end(); ++z) { + // identify best interaction boundary + Z_type Zstruct = z->second * energy.getBoltzmannWeight(energy.getE(z->first.i1, z->first.j1, z->first.i2, z->first.j2, E_type(0))); + if (Zstruct > maxZ) { + maxZ = Zstruct; + interactionBoundary = z->first; + } + + Interaction::BasePair iBP(z->first.i1, z->first.i2); + Interaction::BasePair jBP(z->first.j1, z->first.j2); + + // create left and jBP index + if (z->first.i1 != z->first.j1 && z->first.i2 != z->first.j2) { + // encode iBP/jBP boundary + // create left index + rightExt[iBP].insert(jBP); + // create right index + if (seedHandler != NULL) { + leftExt[jBP].insert(iBP); + } + } + + } // it (Z_partition) + + // Compute base-pair probabilities via combinations + if (!isSeedPredictor) { + computeBasePairProbsNoSeed(predictor); + } else { + computeBasePairProbs(seedPredictor, seedHandler); + } + + // build plist + struct vrna_elem_prob_s plist[structureProbs.size()+1]; + size_t i = 0; + const Z_type Zall = predictor->getZall(); + for (auto sp = structureProbs.begin(); sp != structureProbs.end(); ++sp) { + if ( (sp->second /Zall) > probabilityThreshold) { + plist[i].i = sp->first.first + 1; + plist[i].j = sp->first.second + 1; + plist[i].p = (sp->second /Zall); + plist[i].type = 0; // base-pair prob + i++; + } + } + plist[i].i = 0; // list end + + // create dot plot + char *name = strdup(fileName.c_str()); + std::string comment = + "Intermolecular base-pair probabilities generated by " + INTARNA_PACKAGE_STRING + " using Vienna RNA package " + VRNA_VERSION; + + generateDotPlotSvg(strdup(rna1.c_str()), strdup(reverseRna2.c_str()), name, plist, comment.c_str(), interactionBoundary, energy); + +} + +//////////////////////////////////////////////////////////////////////////// + +void +PredictionTrackerBasePairProb:: +computeBasePairProbsNoSeed( const PredictorMfeEns *predictor ) +{ + for (auto z = predictor->getZPartition().begin(); z != predictor->getZPartition().end(); ++z) { + assert( !Z_equal(z->second, 0) ); + + Z_type bpProb = z->second * energy.getBoltzmannWeight(energy.getE(z->first.i1, z->first.j1, z->first.i2, z->first.j2, E_type(0))); + + Interaction::BasePair iBP(z->first.i1, z->first.i2); + Interaction::BasePair jBP(z->first.j1, z->first.j2); + + // left end and single bp + updateProb(iBP, bpProb); + + if (iBP < jBP) { + // right end + updateProb(jBP, bpProb); + + // inner (rightExt > jBP) + for (auto right = rightExt[jBP].begin(); right != rightExt[jBP].end(); ++right) { + Z_type innerProb = z->second * getHybridZ(z->first.j1, right->first, z->first.j2, right->second, predictor) + * energy.getBoltzmannWeight(energy.getE(z->first.i1, right->first, z->first.i2, right->second, -energy.getE_init())); + updateProb(jBP, innerProb); + } + } + } +} + +//////////////////////////////////////////////////////////////////////////// + +void +PredictionTrackerBasePairProb:: +computeBasePairProbs( const PredictorMfeEns2dSeedExtension *predictor, const SeedHandler* seedHandler ) +{ + auto ZL_partition = predictor->getZLPartition(); + size_t i1, j1, i2, j2; + for (auto z = predictor->getZPartition().begin(); z != predictor->getZPartition().end(); ++z) { + i1 = z->first.i1; + j1 = z->first.j1; + i2 = z->first.i2; + j2 = z->first.j2; + + Z_type bpZ; + // i external left + bpZ = getHybridZ(i1, j1, i2, j2, predictor) + * energy.getBoltzmannWeight(energy.getE(i1, j1, i2, j2, E_type(0))); + updateProb(Interaction::BasePair(i1,i2), bpZ); + + // j external right (and not single bp) + if (i1!=j1) { + updateProb(Interaction::BasePair(j1,j2), bpZ); + } + + Z_type tempZ; + + // loop internal k (inbetween i and j) + for (size_t k1 = i1+1; k1 < j1; k1++) { + for (size_t k2 = i2+1; k2 < j2; k2++) { + + Interaction::BasePair kBP(k1, k2); + bpZ = 0; + tempZ = 0; + + // k internal + Z_type ZSleft = getHybridZ(i1, k1, i2, k2, predictor); + Z_type ZSright = getHybridZ(k1, j1, k2, j2, predictor); + + // ... ZS:ZS + tempZ = ZSleft * ZSright / energy.getBoltzmannWeight(energy.getE_init()); + bpZ += tempZ; + + // ... ZS:ZN + tempZ = ZSleft * getZRPartition(predictor, seedHandler, k1, j1, k2, j2); + bpZ += tempZ; + + // ... ZN:ZS + tempZ = getZPartitionValue(&ZL_partition, Interaction::Boundary(i1,k1,i2,k2), false) * ZSright / energy.getBoltzmannWeight(energy.getE_init()); + bpZ += tempZ; + + // seeds overlapping k + size_t si1 = RnaSequence::lastPos, si2 = RnaSequence::lastPos; + if (seedHandler->getConstraint().getBasePairs() > 2) { + while( seedHandler->updateToNextSeedWithK(si1, si2, k1, k2, false)) + { + size_t sj1 = si1+seedHandler->getSeedLength1(si1,si2)-1; + size_t sj2 = si2+seedHandler->getSeedLength2(si1,si2)-1; + // seed region mismatch + if ((sj1 == j1 && sj2 != j2) || (sj1 != j1 && sj2 == j2)) { + continue; + } + // check if still in region + if (si1 < i1 || si2 < i2 || sj1 > j1 || sj2 > j2) { + continue; + } + // ZN:seed:ZS + tempZ = getZPartitionValue(&ZL_partition, Interaction::Boundary(i1,si1,i2,si2), false) // contains E_init + * energy.getBoltzmannWeight(seedHandler->getSeedE(si1, si2)) + * getHybridZ(sj1, j1, sj2, j2, predictor) + / energy.getBoltzmannWeight(energy.getE_init()); + bpZ += tempZ; + + // ZS:seed:ZN + tempZ = getHybridZ(i1, si1, i2, si2, predictor) + * energy.getBoltzmannWeight(seedHandler->getSeedE(si1, si2)) + * getZRPartition(predictor, seedHandler, sj1, j1, sj2, j2); + bpZ += tempZ; + + // ZN:seed:ZP + tempZ = getZPartitionValue(&ZL_partition, Interaction::Boundary(i1,si1,i2,si2), false) // contains E_init + * energy.getBoltzmannWeight(seedHandler->getSeedE(si1, si2)) + * (getZHPartition(predictor, seedHandler, sj1, j1, sj2, j2) - getHybridZ(sj1, j1, sj2, j2, predictor)) + / energy.getBoltzmannWeight(energy.getE_init()); + bpZ += tempZ; + + // ZNL:seed':ZNR + size_t spi1 = RnaSequence::lastPos, spi2 = RnaSequence::lastPos; + while( seedHandler->updateToNextSeedWithK(spi1,spi2,k1,k2)) + { + if (seedHandler->areLoopOverlapping(spi1, spi2, si1, si2)) { + size_t spj1 = spi1+seedHandler->getSeedLength1(spi1,spi2)-1; + size_t spj2 = spi2+seedHandler->getSeedLength2(spi1,spi2)-1; + if (k1 < spj1 && k2 < spj2) { + continue; + } + // check if still in region + if (spi1 < i1 || spi2 < i2 ||spj1 > j1 || spj2 > j2) { + continue; + } + tempZ = getZPartitionValue(&ZL_partition, Interaction::Boundary(i1,spi1,i2,spi2), false) // contains E_init + * energy.getBoltzmannWeight(PredictorMfeEns2dSeedExtension::getNonOverlappingEnergy(spi1, spi2, si1, si2, energy, *seedHandler, false)) + * energy.getBoltzmannWeight(seedHandler->getSeedE(si1, si2)) + * getZRPartition(predictor, seedHandler, sj1, j1, sj2, j2); + bpZ += tempZ; + } + } + + } + } + + // add ED values, dangling ends, etc. + bpZ *= energy.getBoltzmannWeight(energy.getE(i1, j1, i2, j2, E_type(0))); + updateProb(kBP, bpZ); + + } // k2 + } // k1 + } // Z_partitions +} + +//////////////////////////////////////////////////////////////////////////// + +Z_type +PredictionTrackerBasePairProb:: +getZPartitionValue( const Site2Z_hash *Zpartition, const Interaction::Boundary & boundary, const bool addZInit ) +{ + auto keyEntry = Zpartition->find(boundary); + if ( Zpartition->find(boundary) == Zpartition->end() ) { + return 0; + } else { + if (addZInit) { + return keyEntry->second * energy.getBoltzmannWeight(energy.getE_init()); + } else { + return keyEntry->second; + } + } +} + +//////////////////////////////////////////////////////////////////////////// + +Z_type +PredictionTrackerBasePairProb:: +getZHPartition( const PredictorMfeEns2dSeedExtension *predictor, const SeedHandler* seedHandler + , const size_t i1, const size_t j1 + , const size_t i2, const size_t j2 ) +{ + // sanity check + if (!energy.areComplementary(i1,i2) || !energy.areComplementary(j1,j2)) { + return Z_type(0); + } + // single bp boundary with ZH == 1 + if (i1==j1 && i2==j2) { + return Z_type(energy.getBoltzmannWeight(energy.getE_init())); + } + Interaction::Boundary boundary(i1,j1,i2,j2); + + // check if ZH available from predictor + Z_type ZH = getZPartitionValue(&predictor->getZHPartition(), boundary, true); + if (Z_equal(ZH, 0.0)) { + // check if ZH available in missing ZH partitions + auto ZHentry = ZH_partition_missing.find(boundary); + if ( ZHentry != ZH_partition_missing.end() ) { + ZH = ZHentry->second; + } else { + ZH = fillHybridZ(i1, j1, i2, j2, seedHandler); + } + } + + return ZH; +} + +//////////////////////////////////////////////////////////////////////////// + +Z_type +PredictionTrackerBasePairProb:: +getZRPartition( const PredictorMfeEns2dSeedExtension *predictor, const SeedHandler* seedHandler + , const size_t i1, const size_t j1 + , const size_t i2, const size_t j2 ) +{ + // memoization + Interaction::Boundary boundary(i1,j1,i2,j2); + auto keyEntry = ZR_partition.find(boundary); + if ( ZR_partition.find(boundary) != ZR_partition.end() ) { + return ZR_partition[boundary]; + } + + // single bp boundary with ZNR == 1 + if (i1==j1 && i2==j2 && energy.areComplementary(i1,i2)) { + return Z_type(1); + } + + Z_type partZ = getZHPartition(predictor, seedHandler, i1, j1, i2, j2); + Z_type ZS = getHybridZ(boundary, predictor); + assert( ZS <= partZ ); + partZ -= ZS; + + // remove Einit + partZ /= energy.getBoltzmannWeight(energy.getE_init()); + + // iterate all seeds that overlap anchor seed sa on the right side + size_t si1 = RnaSequence::lastPos, si2 = RnaSequence::lastPos; + while( seedHandler->updateToNextSeedWithK(si1, si2, i1, i2, false)) + { + size_t sj1 = si1 + seedHandler->getSeedLength1(si1, si2) - 1; + size_t sj2 = si2 + seedHandler->getSeedLength2(si1, si2) - 1; + // check if still in region + if ( j1 < sj1 || j2 < sj2 ) continue; + E_type Eoverlap = seedHandler->getSeedE(si1, si2) + - PredictorMfeEns2dSeedExtension::getNonOverlappingEnergy(si1, si2, i1, i2, energy, *seedHandler, false); + Z_type corrZterm = energy.getBoltzmannWeight(Eoverlap) * getZRPartition(predictor, seedHandler, sj1, j1, sj2, j2); + assert(corrZterm <= partZ); + partZ -= corrZterm; + } + + // store ZR_partition + ZR_partition[boundary] = partZ; + + return partZ; +} + +//////////////////////////////////////////////////////////////////////////// + +Z_type +PredictionTrackerBasePairProb:: +getHybridZ( const size_t i1, const size_t j1 + , const size_t i2, const size_t j2 + , const PredictorMfeEns *predictor) +{ + Interaction::Boundary boundary(i1, j1, i2, j2); + return getHybridZ(boundary, predictor); +} + +//////////////////////////////////////////////////////////////////////////// + +Z_type +PredictionTrackerBasePairProb:: +getHybridZ( const Interaction::Boundary & boundary + , const PredictorMfeEns *predictor) +{ + // check in original data + if ( predictor->getZPartition().find(boundary) != predictor->getZPartition().end() ) { + return predictor->getZPartition().find(boundary)->second; + } else { + // fall back + return Z_type(0); + } +} + +//////////////////////////////////////////////////////////////////////////// + +Z_type +PredictionTrackerBasePairProb:: +getBasePairProb( const size_t i1, const size_t i2 + , const PredictorMfeEns *predictor) +{ + Interaction::BasePair bp(i1, i2); + if ( structureProbs.find(bp) == structureProbs.end() ) { + return Z_type(0); + } else { + return structureProbs[bp] / predictor->getZall(); + } +} + +//////////////////////////////////////////////////////////////////////////// + +Z_type +PredictionTrackerBasePairProb:: +fillHybridZ( const size_t ll1, const size_t si1, const size_t ll2, const size_t si2, const SeedHandler* seedHandler ) +{ +#if INTARNA_IN_DEBUG_MODE + // check indices + if (!energy.areComplementary(si1,si2) ) + throw std::runtime_error("PredictorMfeEns2dSeedExtension::fillHybridZ("+toString(si1)+","+toString(si2)+",..) are not complementary"); +#endif + + hybridZ.resize( si1-ll1+1, si2-ll2+1 ); + + // global vars to avoid reallocation + size_t i1,i2,k1,k2; + + // determine whether or not lonely base pairs are allowed or if we have to + // ensure a stacking to the right of the left boundary (i1,i2) + const size_t noLpShift = seedHandler->getConstraint().isLpAllowed() ? 0 : 1; + Z_type iStackZ = Z_type(1); + + Z_type overallZ = Z_type(0); + + // iterate over all window starts i1 (seq1) and i2 (seq2) + for (size_t l1=0; l1 < hybridZ.size1(); l1++) { + for (size_t l2=0; l2 < hybridZ.size2(); l2++) { + i1 = si1-l1; + i2 = si2-l2; + + // referencing cell access + Z_type & curZ = hybridZ(si1-i1,si2-i2); + + // init current cell (0 if not just right-most (j1,j2) base pair) + curZ = (i1==si1 && i2==si2) ? energy.getBoltzmannWeight(energy.getE_init()) : 0.0; + + // check if complementary (use global sequence indexing) + if( i1getConstraint().isLpAllowed()) { + // skip if no stacking possible + if (!energy.areComplementary(i1+noLpShift,i2+noLpShift)) { + continue; + } + // get stacking energy to avoid recomputation in recursion below + iStackZ = energy.getBoltzmannWeight(energy.getE_interLeft(i1,i1+noLpShift,i2,i2+noLpShift)); + // check just stacked + curZ += iStackZ + hybridZ(l1-noLpShift,l2-noLpShift); + } + + // check all combinations of decompositions into (i1,i2)..(k1,k2)-(j1,j2) + for (k1=i1+noLpShift; k1++ < si1; ) { + // ensure maximal loop length + if (k1-i1-noLpShift > energy.getMaxInternalLoopSize1()+1) break; + for (k2=i2+noLpShift; k2++ < si2; ) { + // ensure maximal loop length + if (k2-i2-noLpShift > energy.getMaxInternalLoopSize2()+1) break; + // check if (k1,k2) are valid left boundary + if ( ! Z_equal(hybridZ(si1-k1,si2-k2), 0.0) ) { + curZ += (iStackZ + * energy.getBoltzmannWeight(energy.getE_interLeft(i1+noLpShift,k1,i2+noLpShift,k2)) + * hybridZ(si1-k1,si2-k2)); + } + } // k2 + } // k1 + } // complementary + + // store partial Z + Interaction::Boundary key(i1,si1,i2,si2); + auto keyEntry = ZH_partition_missing.find(key); + if ( ZH_partition_missing.find(key) == ZH_partition_missing.end() ) { + ZH_partition_missing[key] = curZ; + } + // store Z of full region for return value + if (i1 == ll1 && i2 == ll2) { + overallZ = curZ; + } + + } // i2 + } // i1 + + return overallZ; +} + +//////////////////////////////////////////////////////////////////////////// + +bool +PredictionTrackerBasePairProb:: +generateDotPlot( const char *seq1, const char *seq2, const char *fileName + , const plist *pl, const char *comment + , const Interaction::Boundary interactionBoundary ) +{ + FILE *file; + file = fopen(fileName,"w"); + if (file == NULL) return false; /* failure */ + + size_t bbox[4]; + bbox[0] = 0; + bbox[1] = 0; + bbox[2] = 72 * (strlen(seq1) + 3); + bbox[3] = 72 * (strlen(seq2) + 3); + + size_t maxSize = std::max(bbox[2], bbox[3]); + float scale = 1; + if (maxSize > maxDotPlotSize) { + scale = maxDotPlotSize / (float)maxSize; + bbox[2] *= scale; + bbox[3] *= scale; + } + + fprintf(file, + "%%!PS-Adobe-3.0 EPSF-3.0\n" + "%%%%Creator: IntaRNA\n" + "%%%%Title: RNA Dot Plot\n" + "%%%%BoundingBox: %zu %zu %zu %zu\n" + "%%%%DocumentFonts: Helvetica\n" + "%%%%Pages: 1\n" + "%%%%EndComments\n", + bbox[0], bbox[1], bbox[2], bbox[3]); + + // scaling + fprintf(file, + "%%%%BeginProcSet: epsffit 1 0\n" + "gsave\n" + "%f 0 translate\n" + "%f %f scale\n" + "%%%%EndProcSet\n\n", + scale, scale, scale); + + // comment + + if (comment) { + fprintf(file, "%%%% %s\n", comment); + } + + fprintf(file, "/DPdict 100 dict def\n"); + fprintf(file, "DPdict begin\n"); + + // ps template + + fprintf(file, "%s", dotplotTemplate); + fprintf(file, "end\n"); + fprintf(file, "DPdict begin\n"); + + // sequences + + unsigned int i, length; + length = strlen(seq1); + fprintf(file, "/sequence1 { (\\\n"); + i = 0; + while (i < length) { + fprintf(file, "%.255s\\\n", seq1 + i); /* no lines longer than 255 */ + i += 255; + } + fprintf(file, ") } def\n"); + fprintf(file, "/len { sequence1 length } bind def\n\n"); + length = strlen(seq2); + fprintf(file, "/sequence2 { (\\\n"); + i = 0; + while (i < length) { + fprintf(file, "%.255s\\\n", seq2 + i); /* no lines longer than 255 */ + i += 255; + } + fprintf(file, ") } def\n"); + fprintf(file, "/len2 { sequence2 length } bind def\n\n"); + + fprintf(file, "72 72 translate\n" + "72 72 scale\n"); + + fprintf(file, "/Helvetica findfont 0.95 scalefont setfont\n\n"); + + // basepair data + + fprintf(file,"drawseq1\n"); + fprintf(file,"drawseq2\n"); + + fprintf(file,"%%data starts here\n"); + + fprintf(file,"%%start of base pair probability data\n"); + + fprintf(file, "/coor [\n"); + + for (const plist *pl1 = pl; pl1->i > 0; pl1++) { + if (pl1->type == 0) { + fprintf(file, "%1.9f %d %d boxgray\n", sqrt(pl1->p), pl1->i, pl1->j); + } + } + + fprintf(file, "] def\n"); + + fprintf(file, "0.25 0.25 0.25 setrgbcolor\n"); + + fprintf(file, "\n%%draw the grid\ndrawgrid\n\n"); + + // print frame + fprintf(file, + "0.03 setlinewidth\n\ + %1.1f %1.1f %zu %zu rectangle\n\ + 0 0 0 setrgbcolor\n\ + stroke\n", 0.5, 0.5, strlen(seq1), strlen(seq2)); + + // print best interaction outline + fprintf(file, + "0.03 setlinewidth\n\ + %1.1f %1.1f %zu %zu rectangle\n\ + 1 0 0 setrgbcolor\n\ + stroke\n", (float)interactionBoundary.i1 + 0.5, (float)interactionBoundary.i2 + 0.5, interactionBoundary.j1 - interactionBoundary.i1 + 1, interactionBoundary.j2 - interactionBoundary.i2 + 1); + + fprintf(file, "showpage\n" + "end\n" + "%%%%EOF\n"); + + fclose(file); + return true; /* success */ +} + +//////////////////////////////////////////////////////////////////////////// + +bool +PredictionTrackerBasePairProb:: +generateDotPlotSvg( const char *seq1, const char *seq2, const char *fileName + , const plist *pl, const char *comment + , const Interaction::Boundary interactionBoundary + , const InteractionEnergy & energy ) +{ + FILE *file; + file = fopen(fileName, "w"); + if (file == NULL) return false; /* failure */ + + // file information + fprintf(file, "\n\n"); + + const size_t unitSize = 1; + const size_t boxSize = 2 * unitSize; + const size_t maxWidth = (2.5+strlen(seq1)) * boxSize; + const size_t maxHeight = (2.5+strlen(seq2)) * boxSize; + + fprintf(file, + "\n", + maxWidth+boxSize, maxHeight+boxSize, maxDotPlotSize, maxDotPlotSize); + + // draw dots + fprintf(file, "\n\n\n"); + for (const plist *pl1 = pl; pl1->i > 0; pl1++) { + if (pl1->type == 0) { + std::ostringstream message; + message << "(" << (pl1->i) << "," << (strlen(seq2)-pl1->j+1) << ") = " << pl1->p; + fprintf(file, "%s", drawSvgSquare(pl1->i, strlen(seq2)-pl1->j+1, boxSize, pl1->p, "bp", message.str().c_str()).c_str()); + } + } + + fprintf(file, "\n\n\n"); + // unpaired probs seq1 + for (size_t i = 0; seq1[i] != '\0'; i++) { + float acc = 1-energy.getBoltzmannWeight(energy.getED1(i, i)); + if (acc > 0) { + std::ostringstream message; + message << "(" << (i+1) << ") = " << acc; + fprintf(file, "%s", drawSvgSquare(i+1, -0.5, boxSize, acc, "unpaired", message.str().c_str()).c_str()); + } + } + + // unpaired probs seq2 + for (size_t i = 0; seq2[i] != '\0'; i++) { + float acc = 1-energy.getBoltzmannWeight(energy.getED2(i, i)); + if (acc > 0) { + std::ostringstream message; + message << "(" << (strlen(seq2)-i) << ") = " << acc; + fprintf(file, "%s", drawSvgSquare(-0.5, strlen(seq2)-i, boxSize, acc, "unpaired", message.str().c_str()).c_str()); + } + } + + // draw grid + const float strokeWidth = 0.02 * boxSize; + fprintf(file, "\n\n\n"); + for (size_t i = 0; i <= strlen(seq1); i++) { + if (i % 5 == 0) { + fprintf(file, + "\n" + , size_t((i+2.5)*2*unitSize), 5*unitSize, size_t((i+2.5)*2*unitSize), maxHeight, strokeWidth); + } + if (i % 10 == 0) { + fprintf(file, + "\n" + , size_t((i+2.5)*2*unitSize), 5*unitSize, size_t((i+2.5)*2*unitSize), maxHeight, strokeWidth); + } + if (i % 50 == 0) { + fprintf(file, + "\n" + , size_t((i+2.5)*2*unitSize), 5*unitSize, size_t((i+2.5)*2*unitSize), maxHeight, strokeWidth); + } + } + for (size_t i = 0; i <= strlen(seq2); i++) { + if (i % 5 == 0) { + fprintf(file, + "\n" + , 5*unitSize, size_t((i+2.5)*2*unitSize), maxWidth, size_t((i+2.5)*2*unitSize), strokeWidth); + } + if (i % 10 == 0) { + fprintf(file, + "\n" + , 5*unitSize, size_t((i+2.5)*2*unitSize), maxWidth, size_t((i+2.5)*2*unitSize), strokeWidth); + } + if (i % 50 == 0) { + fprintf(file, + "\n" + , 5*unitSize, size_t((i+2.5)*2*unitSize), maxWidth, size_t((i+2.5)*2*unitSize), strokeWidth); + } + } + + // draw frames + fprintf(file, + "\n" + , 5*unitSize, 5*unitSize, maxWidth-5*unitSize, (maxHeight-5*unitSize), 2*strokeWidth); + fprintf(file, + "\n" + , 5*unitSize, boxSize, maxWidth-5*unitSize, boxSize, 2*strokeWidth); + fprintf(file, + "\n" + , boxSize, 5*unitSize, boxSize, maxHeight-5*unitSize, 2*strokeWidth); + + fprintf(file, "\n\n\n"); + fprintf(file, "\n", unitSize/4.0); + // draw sequence 1 + for (size_t i = 0; seq1[i] != '\0'; i++) { + fprintf(file, + "%c(%zu)\n" + , (i+3) * boxSize, boxSize, boxSize, seq1[i], i+1); + fprintf(file, + "%c(%zu)\n" + , (i+3) * boxSize, maxHeight + boxSize, boxSize, seq1[i], i+1); + } + + // draw sequence 2 + for (size_t i = 0; seq2[i] != '\0'; i++) { + fprintf(file, + "%c(%zu)\n" + , unitSize, size_t((i+3.5)*2*unitSize), boxSize, seq2[i], i+1); + fprintf(file, + "%c(%zu)\n" + , maxWidth + unitSize, size_t((i+3.5)*2*unitSize), boxSize, seq2[i], i+1); + } + fprintf(file, "\n"); + + // draw best interaction outline + fprintf(file, "\n\n\n"); + fprintf(file, + "\n" + , size_t((interactionBoundary.i1+2.5)*2*unitSize) + , size_t((strlen(seq2)-interactionBoundary.j2+1.5)*2*unitSize) + , (interactionBoundary.j1-interactionBoundary.i1+1) * boxSize + , (interactionBoundary.j2-interactionBoundary.i2+1) * boxSize + , 2*strokeWidth); + + fprintf(file, "\n\n\n"); + fprintf(file, "\n", boxSize, boxSize, boxSize, boxSize, boxSize, strokeWidth, strokeWidth, strokeWidth, 2*strokeWidth, 4*strokeWidth); + fprintf(file, "\n"); + + fclose(file); + return true; /* success */ +} + +//////////////////////////////////////////////////////////////////////////// + +std::string +PredictionTrackerBasePairProb:: +drawSvgSquare(const float x, const float y, const size_t size, const float probability, const char* className, const char* tooltip) +{ + std::ostringstream svg; + svg << ""; + svg << "" << tooltip << ""; + svg << "\n"; + return svg.str(); +} + +//////////////////////////////////////////////////////////////////////////// + +} // namespace diff --git a/src/IntaRNA/PredictionTrackerBasePairProb.h b/src/IntaRNA/PredictionTrackerBasePairProb.h new file mode 100644 index 00000000..943349f0 --- /dev/null +++ b/src/IntaRNA/PredictionTrackerBasePairProb.h @@ -0,0 +1,389 @@ + +#ifndef INTARNA_PREDICTIONTRACKERBASEPAIRPROB_H_ +#define INTARNA_PREDICTIONTRACKERBASEPAIRPROB_H_ + +#include "IntaRNA/PredictionTracker.h" +#include "IntaRNA/InteractionEnergy.h" +#include "IntaRNA/Interaction.h" +#include "IntaRNA/PredictorMfeEns.h" +#include "IntaRNA/PredictorMfeEns2dSeedExtension.h" + +#include + +#include + +namespace IntaRNA { + +/** + * Collects partition function parts Z(i,j), computes base-pair probabilities + * and prints them as a dot plot + * + * The information is written to stream on destruction. + * + * @author Frank Gelhausen + * + */ +class PredictionTrackerBasePairProb: public PredictionTracker +{ + +public: + + typedef std::unordered_map Site2Z_hash; + typedef std::unordered_map BasePair2Prob_hash; + typedef std::unordered_map, Interaction::BasePair::Hash, Interaction::BasePair::Equal> BasePairIndex; + typedef boost::numeric::ublas::matrix Z2dMatrix; + +public: + + /** + * Constructs a PredictionTracker that collects probability information + * for an interaction by computing the Boltzmann probabilities and + * generates a basepair-probabilities dotplot + * + * @param energy the energy function used for energy calculation + * @param fileName the name of the generated postscript file containing + * the dotplot + */ + PredictionTrackerBasePairProb( + const InteractionEnergy & energy + , const std::string & fileName + ); + + /** + * destruction: write the probabilities to stream. + */ + virtual ~PredictionTrackerBasePairProb(); + + + /** + * Updates the probability information for each Predictor.updateOptima() call. + * + * @param i1 the index of the first sequence interacting with i2 + * @param j1 the index of the first sequence interacting with j2 + * @param i2 the index of the second sequence interacting with i1 + * @param j2 the index of the second sequence interacting with j1 + * @param energy the overall energy of the interaction site + */ + virtual + void + updateOptimumCalled( const size_t i1, const size_t j1 + , const size_t i2, const size_t j2 + , const E_type energy + ); + + /** + * Updates the probability information. + * + * @param predictor the predictor providing the probability information + * @param seedHandler the seedHandler of the predictor (NULL if noseed predictior) + */ + virtual + void + updateZ( PredictorMfeEns *predictor, SeedHandler* seedHandler ) override; + + /** + * Access to the base pair probability + * of basepair (i1, i2) + * @param i1 index in first sequence + * @param i2 index in second sequence + * @param predictor the predictor providing the probability information + * + * @return the base pair probability of given basepair + */ + Z_type + getBasePairProb( const size_t i1, const size_t i2 + , const PredictorMfeEns *predictor); + +protected: + + //! energy handler used for predictions + const InteractionEnergy & energy; + + //! filename of the generated dotplot + const std::string fileName; + + //! threshold used to draw probabilities in dotplot + const Z_type probabilityThreshold; + + //! map storing structure probabilities + BasePair2Prob_hash structureProbs; + + //! left side index + BasePairIndex rightExt; + + //! right side index + BasePairIndex leftExt; + + //! flag for seed-based predictors + bool isSeedPredictor; + + //! maximum postscript width/height in ps units + const size_t maxDotPlotSize; + + //! partition function of all interaction hybrids that start on the right side of the seed including E_init + Z2dMatrix hybridZ; + + //! map storing the missing partitions of ZR for all considered interaction sites + Site2Z_hash ZR_partition; + + //! map storing the missing partitions of ZH for all considered interaction sites + Site2Z_hash ZH_partition_missing; + + /** + * Access to the given partition function covering + * the given boundary + * @param Site2Z_hash partition function hash + * @param Boundary boundary + * @param addZInit whether or not to add Z_init + * + * @return the partition function at given boundary + */ + Z_type + getZPartitionValue( const Site2Z_hash *Zpartition, const Interaction::Boundary & boundary, const bool addZInit ); + + /** + * Compute ZH partition function for given region + * @param predictor the predictor providing the probability information + * @param seedHandler the seedHandler of the predictor + * @param i1 region index + * @param j1 region index + * @param i2 region index + * @param j2 region index + * + * @return the ZH partition function at given region + */ + Z_type + getZHPartition( const PredictorMfeEns2dSeedExtension *predictor, const SeedHandler* seedHandler + , const size_t i1, const size_t j1 + , const size_t i2, const size_t j2 ); + + /** + * Compute ZR partition function for given region + * @note Does not contain E_init + * @param predictor the predictor providing the probability information + * @param seedHandler the seedHandler of the predictor + * @param i1 region index + * @param j1 region index + * @param i2 region index + * @param j2 region index + * @param si1 index of seed bordering the left side of ZR + * @param si2 index of seed bordering the left side of ZR + * + * @return the ZR partition function at given region + */ + Z_type + getZRPartition( const PredictorMfeEns2dSeedExtension *predictor, const SeedHandler* seedHandler + , const size_t i1, const size_t j1 + , const size_t i2, const size_t j2 ); + + /** + * Access to the current partition function covering + * the interaction at region (i1, j1, i2, j2). + * @param i1 region index + * @param j1 region index + * @param i2 region index + * @param j2 region index + * @param predictor the predictor providing the probability information + * + * @return the hybridization partition function at given region + */ + Z_type + getHybridZ( const size_t i1, const size_t j1 + , const size_t i2, const size_t j2 + , const PredictorMfeEns *predictor); + /** + * Access to the current partition function covering + * the interaction at region (i1, j1, i2, j2). + * @param boundary the region of interest + * @param predictor the predictor providing the probability information + * + * @return the hybridization partition function at given region + */ + Z_type + getHybridZ( const Interaction::Boundary & boundary + , const PredictorMfeEns *predictor); + + void + updateProb( const Interaction::BasePair & bp, const Z_type prob ) { + if (structureProbs.find(bp)==structureProbs.end()) { + structureProbs[bp] = prob; + } else { + structureProbs[bp] += prob; + } + } + + /** + * Generates a dotplot of the given base pair probabilities + * @param seq1 first RNA sequence + * @param seq2 second RNA sequence + * @param fileName name of the output file + * @param pl plist containing base pair probabilities + * @param comment comment to include in postscript + * @param interactionBoundary boundary of the predicted interaction + * + * @return false in case of failure + */ + bool + generateDotPlot( const char *seq1, const char *seq2, const char *fileName + , const plist *pl, const char *comment + , const Interaction::Boundary interactionBoundary ); + + /** + * Generates a dotplot of the given base pair probabilities + * @param seq1 first RNA sequence + * @param seq2 second RNA sequence + * @param fileName name of the output file + * @param pl plist containing base pair probabilities + * @param comment comment to include in postscript + * @param interactionBoundary boundary of the predicted interaction + * @param energy interaction energy + * + * @return false in case of failure + */ + bool + generateDotPlotSvg( const char *seq1, const char *seq2, const char *fileName + , const plist *pl, const char *comment + , const Interaction::Boundary interactionBoundary + , const InteractionEnergy & energy ); + + /** + * Draw an SVG square at given position with given size and opacity + * @param x horizontal position of square center + * @param y vertical position of square center + * @param size size of square + * @param probability bp-probability of square [0-1] + * @param tooltip tooltip shown on hover + * @param className class of svg element + * + * @return svg tag for square + */ + std::string + drawSvgSquare(const float x, const float y, const size_t size, const float probability, const char* className, const char* tooltip = ""); + + /** + * Compute basepair probabilities and store in structureProbs + * @param predictor the predictor providing the probability information + * @param seedHandler the seedHandler of the predictor + */ + void + computeBasePairProbs( const PredictorMfeEns2dSeedExtension *predictor, const SeedHandler* seedHandler ); + + /** + * Compute basepair probabilities for no-seed predictions and store in structureProbs + * @param predictor the predictor providing the probability information + */ + void + computeBasePairProbsNoSeed( const PredictorMfeEns *predictor ); + + /** + * Computes hybridZ + * + * Note: (i1,i2) have to be complementary (right-most base pair of seed) + * + * @param l1 start of the interaction within seq 1 + * @param si1 start of anchor seed in seq 1 + * @param l2 start of the interaction within seq 2 + * @param si2 start of anchor seed in seq 2 + * @param seedHandler the seedHandler of the predictor + * @return Z of region + */ + Z_type + fillHybridZ( const size_t l1, const size_t si1, const size_t l2, const size_t si2, const SeedHandler* seedHandler ); + + //! postscript template for dotplots + const char* const dotplotTemplate = + "/box { %%size x y box - draws box centered on x,y\n\ + 2 index 0.5 mul sub %% x -= 0.5\n\ + exch 2 index 0.5 mul sub exch %% y -= 0.5\n\ + 3 -1 roll dup rectfill\n\ + } bind def\n\ + /boxgray { %%size x y box - draws box centered on x,y\n\ + 0 0 1 5 index sub sethsbcolor %% grayscale\n\ + 1 index 0.5 sub %% x -= 0.5 s x y x'\n\ + 1 index 0.5 sub %% y -= 0.5 s x y x' y'\n\ + 5 2 roll %% x' y' s x y\n\ + pop pop pop\n\ + 1 dup rectfill\n\ + } bind def\n\ + \n\ + /drawseq1 { %% print sequence1\n\ + [ [0.7 -0.3 ]\n\ + [0.7 0.7 len2 add]\n\ + ] {\n\ + gsave\n\ + aload pop translate\n\ + 0 1 len 1 sub {\n\ + dup 0 moveto\n\ + sequence1 exch 1 getinterval\n\ + show\n\ + } for\n\ + grestore\n\ + } forall\n\ + } bind def\n\ + \n\ + /drawseq2 { %% print sequence2\n\ + [ [-0.3 len2 sub -0.4 -90]\n\ + [-0.3 len2 sub 0.7 len add -90]\n\ + ] {\n\ + gsave\n\ + aload pop rotate translate\n\ + 0 1 len2 1 sub {\n\ + dup 0 moveto\n\ + sequence2 exch 1 getinterval\n\ + show\n\ + } for\n\ + grestore\n\ + } forall\n\ + } bind def\n\ + \n\ + /rectangle {%% x y w h RT -\n\ + %% draw a rectangle size w h at x y\n\ + 4 -2 roll moveto %% lower left corner\n\ + dup 0 exch rlineto %% to upper left\n\ + exch 0 rlineto %% to upper right\n\ + neg 0 exch rlineto %% to lower right\n\ + closepath\n\ + } def\n\ + /drawgrid{\n\ + gsave\n\ + 0.5 dup translate\n\ + 1 %% len log 0.9 sub cvi 10 exch exp %% grid spacing\n\ + 0 exch len {\n\ + dup\n\ + dup cvi 10 mod 0 eq {\n\ + 0.01 setlinewidth\n\ + [1 0] 0 setdash\n\ + } {\n\ + 0.01 setlinewidth\n\ + [0.3 0.7] 0.15 setdash\n\ + } ifelse\n\ + 0 moveto\n\ + len2 lineto %% vertical\n\ + stroke\n\ + } for\n\ + \n\ + 1 %% len log 0.9 sub cvi 10 exch exp %% grid spacing\n\ + 0 exch len2 {\n\ + dup\n\ + dup cvi 10 mod 0 eq {\n\ + 0.01 setlinewidth\n\ + [1 0] 0 setdash\n\ + } {\n\ + 0.01 setlinewidth\n\ + [0.3 0.7] 0.15 setdash\n\ + } ifelse\n\ + len2 exch sub 0 exch moveto\n\ + len exch len2 exch sub lineto %% horizontal\n\ + stroke\n\ + } for\n\ + \n\ + grestore\n\ + } bind def\n"; + +}; + +////////////////////////////////////////////////////////////////////// + +} // namespace + +#endif /* INTARNA_PREDICTIONTRACKERBASEPAIRPROB_H_ */ diff --git a/src/IntaRNA/PredictionTrackerHub.h b/src/IntaRNA/PredictionTrackerHub.h index 2718f92b..fc572b9c 100644 --- a/src/IntaRNA/PredictionTrackerHub.h +++ b/src/IntaRNA/PredictionTrackerHub.h @@ -63,6 +63,15 @@ class PredictionTrackerHub: public PredictionTracker , const E_type energy ); + /** + * Updates the probability information. + * + * @param predictor the predictor providing the probability information + */ + virtual + void + updateZ( PredictorMfeEns *predictor, SeedHandler* seedHandler ) override; + /** * Adds a new PredictionTracker to the forwarding list. * @param tracker pointer to the tracker to forward to @@ -194,6 +203,19 @@ updateOptimumCalled( const size_t i1, const size_t j1 ///////////////////////////////////////////////////////////////////////// +inline +void +PredictionTrackerHub:: +updateZ( PredictorMfeEns *predictor, SeedHandler* seedHandler ) +{ + // forward to all in list + for (auto trackIt=trackList.begin(); trackIt!=trackList.end(); trackIt++) { + (*trackIt)->updateZ(predictor, seedHandler); + } +} + +///////////////////////////////////////////////////////////////////////// + inline void PredictionTrackerHub:: diff --git a/src/IntaRNA/PredictorMfe.h b/src/IntaRNA/PredictorMfe.h index bcd69e65..696712f1 100644 --- a/src/IntaRNA/PredictorMfe.h +++ b/src/IntaRNA/PredictorMfe.h @@ -96,7 +96,7 @@ class PredictorMfe : public Predictor { InteractionList mfeInteractions; //! hash to map index pairs to BestInteractionE entries - typedef boost::unordered_map< Interaction::BasePair, BestInteractionE > HashIdx2E; + typedef boost::unordered_map< Interaction::BasePair, BestInteractionE, Interaction::BasePair::Hash, Interaction::BasePair::Equal > HashIdx2E; //! if non-overlapping output is required, this data structure is filled //! to find non-overlapping interactions diff --git a/src/IntaRNA/PredictorMfeEns.cpp b/src/IntaRNA/PredictorMfeEns.cpp index 23a06f3d..ac6becbf 100644 --- a/src/IntaRNA/PredictorMfeEns.cpp +++ b/src/IntaRNA/PredictorMfeEns.cpp @@ -36,6 +36,25 @@ initZ() //////////////////////////////////////////////////////////////////////////// +const PredictorMfeEns::Site2Z_hash & +PredictorMfeEns:: +getZPartition() const { + return Z_partition; +} + +//////////////////////////////////////////////////////////////////////////// + +void +PredictorMfeEns:: +reportZ( SeedHandler* seedHandler ) +{ + if (predTracker != NULL) { + predTracker->updateZ(this, seedHandler); + } +} + +//////////////////////////////////////////////////////////////////////////// + void PredictorMfeEns:: updateZ( const size_t i1, const size_t j1 diff --git a/src/IntaRNA/PredictorMfeEns.h b/src/IntaRNA/PredictorMfeEns.h index 90e2d990..f0d3871b 100644 --- a/src/IntaRNA/PredictorMfeEns.h +++ b/src/IntaRNA/PredictorMfeEns.h @@ -20,6 +20,9 @@ namespace IntaRNA { */ class PredictorMfeEns : public PredictorMfe { +public: + + typedef std::unordered_map Site2Z_hash; public: @@ -38,16 +41,21 @@ class PredictorMfeEns : public PredictorMfe { virtual ~PredictorMfeEns(); -protected: + /** + * Access to Z_partition + * + * @return Z_partition + */ + const Site2Z_hash & + getZPartition() const; - //! data container to encode a site with respective partition function - struct ZPartition { - size_t i1; - size_t j1; - size_t i2; - size_t j2; - Z_type partZ; - }; + /** + * Report Z information to the prediction trackers + */ + void + reportZ( SeedHandler *seedHandler = NULL ); + +protected: //! access to the interaction energy handler of the super class using PredictorMfe::energy; @@ -59,7 +67,7 @@ class PredictorMfeEns : public PredictorMfe { using PredictorMfe::predTracker; //! map storing the partition of Zall for all considered interaction sites - std::unordered_map Z_partition; + Site2Z_hash Z_partition; /** diff --git a/src/IntaRNA/PredictorMfeEns2d.cpp b/src/IntaRNA/PredictorMfeEns2d.cpp index 707c509d..4a971f89 100644 --- a/src/IntaRNA/PredictorMfeEns2d.cpp +++ b/src/IntaRNA/PredictorMfeEns2d.cpp @@ -86,6 +86,8 @@ predict( const IndexRange & r1 // report mfe interaction reportOptima(); + + reportZ(); } //////////////////////////////////////////////////////////////////////////// diff --git a/src/IntaRNA/PredictorMfeEns2dHeuristic.cpp b/src/IntaRNA/PredictorMfeEns2dHeuristic.cpp index 00398d44..6db74533 100644 --- a/src/IntaRNA/PredictorMfeEns2dHeuristic.cpp +++ b/src/IntaRNA/PredictorMfeEns2dHeuristic.cpp @@ -71,6 +71,8 @@ predict( const IndexRange & r1 // trace back and output handler update reportOptima(); + reportZ(); + } diff --git a/src/IntaRNA/PredictorMfeEns2dHeuristicSeedExtension.cpp b/src/IntaRNA/PredictorMfeEns2dHeuristicSeedExtension.cpp index bc1fe839..b6dffe00 100644 --- a/src/IntaRNA/PredictorMfeEns2dHeuristicSeedExtension.cpp +++ b/src/IntaRNA/PredictorMfeEns2dHeuristicSeedExtension.cpp @@ -11,7 +11,7 @@ PredictorMfeEns2dHeuristicSeedExtension( , PredictionTracker * predTracker , SeedHandler * seedHandlerInstance ) : - PredictorMfeEns2dSeedExtension(energy,output,predTracker,seedHandlerInstance) + PredictorMfeEns2dSeedExtension(energy,output,predTracker,seedHandlerInstance,false) , E_right_opt(E_INF) , j1opt(0) , j2opt(0) diff --git a/src/IntaRNA/PredictorMfeEns2dSeedExtension.cpp b/src/IntaRNA/PredictorMfeEns2dSeedExtension.cpp index cc7aa459..d707a993 100644 --- a/src/IntaRNA/PredictorMfeEns2dSeedExtension.cpp +++ b/src/IntaRNA/PredictorMfeEns2dSeedExtension.cpp @@ -10,12 +10,14 @@ PredictorMfeEns2dSeedExtension( const InteractionEnergy & energy , OutputHandler & output , PredictionTracker * predTracker - , SeedHandler * seedHandlerInstance ) + , SeedHandler * seedHandlerInstance + , bool trackBasePairProbs ) : PredictorMfeEns(energy,output,predTracker) , seedHandler(seedHandlerInstance) , hybridZ_left( 0,0 ) , hybridZ_right( 0,0 ) + , trackBasePairProbs(trackBasePairProbs) { assert( seedHandler.getConstraint().getBasePairs() > 1 ); } @@ -127,24 +129,45 @@ predict( const IndexRange & r1, const IndexRange & r2 ) // report mfe interaction reportOptima(); + // report to predictionTracker + reportZ( &seedHandler ); + +} + +//////////////////////////////////////////////////////////////////////////// + +const PredictorMfeEns::Site2Z_hash & +PredictorMfeEns2dSeedExtension:: +getZLPartition() const { + return ZL_partition; +} + +//////////////////////////////////////////////////////////////////////////// + +const PredictorMfeEns::Site2Z_hash & +PredictorMfeEns2dSeedExtension:: +getZHPartition() const { + return ZH_partition; } ////////////////////////////////////////////////////////////////////////// E_type PredictorMfeEns2dSeedExtension:: -getNonOverlappingEnergy( const size_t si1, const size_t si2, const size_t si1p, const size_t si2p ) { +getNonOverlappingEnergy( const size_t si1, const size_t si2, const size_t si1p, const size_t si2p + , const InteractionEnergy & energy, const SeedHandler & seedHandler + , const bool sipIsSeed) { #if INTARNA_IN_DEBUG_MODE // check indices if( !seedHandler.isSeedBound(si1,si2) ) throw std::runtime_error("PredictorMfeEns2dSeedExtension::getNonOverlappingEnergy( si "+toString(si1)+","+toString(si2)+",..) is no seed bound"); - if( !seedHandler.isSeedBound(si1p,si2p) ) + if( sipIsSeed && !seedHandler.isSeedBound(si1p,si2p) ) throw std::runtime_error("PredictorMfeEns2dSeedExtension::getNonOverlappingEnergy( sip "+toString(si1p)+","+toString(si2p)+",..) is no seed bound"); if( si1 > si1p ) throw std::runtime_error("PredictorMfeEns2dSeedExtension::getNonOverlappingEnergy( si "+toString(si1)+","+toString(si2)+", sip "+toString(si1p)+","+toString(si2p)+",..) si1 > sj1 !"); // check if loop-overlapping (i.e. share at least one loop) - if( !seedHandler.areLoopOverlapping(si1,si2,si1p,si2p) ) { + if( sipIsSeed && !seedHandler.areLoopOverlapping(si1,si2,si1p,si2p) ) { throw std::runtime_error("PredictorMfeEns2dSeedExtension::getNonOverlappingEnergy( si "+toString(si1)+","+toString(si2)+", sip "+toString(si1p)+","+toString(si2p)+",..) are not loop overlapping"); } #endif @@ -183,14 +206,15 @@ void PredictorMfeEns2dSeedExtension:: fillHybridZ_left( const size_t si1, const size_t si2 ) { - // temporary access - const OutputConstraint & outConstraint = output.getOutputConstraint(); #if INTARNA_IN_DEBUG_MODE // check indices if (!energy.areComplementary(si1,si2) ) throw std::runtime_error("PredictorMfeEns2dSeedExtension::fillHybridZ_left("+toString(si1)+","+toString(si2)+",..) are not complementary"); #endif + // temporary access + const OutputConstraint & outConstraint = output.getOutputConstraint(); + // global vars to avoid reallocation size_t i1,i2,k1,k2; @@ -243,15 +267,42 @@ fillHybridZ_left( const size_t si1, const size_t si2 ) } // k1 // correction for left seeds - if ( i1destruction. * @param seedHandler the seed handler to be used + * @param trackBasePairProbs indicates the presence of a basePairProb tracker */ PredictorMfeEns2dSeedExtension( const InteractionEnergy & energy , OutputHandler & output , PredictionTracker * predTracker - , SeedHandler * seedHandler ); + , SeedHandler * seedHandler + , bool trackBasePairProbs ); /** @@ -69,9 +71,41 @@ class PredictorMfeEns2dSeedExtension: public PredictorMfeEns { predict( const IndexRange & r1 = IndexRange(0,RnaSequence::lastPos) , const IndexRange & r2 = IndexRange(0,RnaSequence::lastPos) ); + /** + * Access to ZL_partition + * + * @return ZL_partition + */ + const Site2Z_hash & + getZLPartition() const; -protected: + /** + * Access to ZH_partition + * + * @return ZH_partition + */ + const Site2Z_hash & + getZHPartition() const; + + /** + * Returns the hybridization energy of the non-overlapping part of seeds + * starting at si and sj + * + * @param si1 the index of seed1 in the first sequence + * @param si2 the index of seed1 in the second sequence + * @param sj1 the index of seed2 in the first sequence + * @param sj2 the index of seed2 in the second sequence + * @param energy the interaction energy handler + * @param seedHandler the seedHandler of the predictor + * @param sipIsSeed whether or not sj is a left seed boundary (for sanity checking) + */ + static + E_type + getNonOverlappingEnergy( const size_t si1, const size_t si2, const size_t sj1, const size_t sj2 + , const InteractionEnergy & energy, const SeedHandler & seedHandler + , const bool sjIsSeed = true ); +protected: //! access to the interaction energy handler of the super class using PredictorMfeEns::energy; @@ -79,15 +113,24 @@ class PredictorMfeEns2dSeedExtension: public PredictorMfeEns { //! access to the output handler of the super class using PredictorMfeEns::output; - //! partition function of all interaction hybrids that start on the left side of the seed including E_init - Z2dMatrix hybridZ_left; - //! the seed handler (with idx offset) SeedHandlerIdxOffset seedHandler; + //! partition function of all interaction hybrids that start on the left side of the seed including E_init + Z2dMatrix hybridZ_left; + //! partition function of all interaction hybrids that start on the right side of the seed excluding E_init Z2dMatrix hybridZ_right; + //! boolean indicating the presence of a basePairProb tracker + bool trackBasePairProbs; + + //! map storing the partition of ZL for all considered interaction sites + Site2Z_hash ZL_partition; + + //! map storing the partition of ZH for all considered interaction sites + Site2Z_hash ZH_partition; + protected: /** @@ -124,19 +167,6 @@ class PredictorMfeEns2dSeedExtension: public PredictorMfeEns { void traceBack( Interaction & interaction ); - /** - * Returns the hybridization energy of the non overlapping part of seeds - * starting at si and sj - * - * @param si1 the index of seed1 in the first sequence - * @param si2 the index of seed1 in the second sequence - * @param sj1 the index of seed2 in the first sequence - * @param sj2 the index of seed2 in the second sequence - */ - virtual - E_type - getNonOverlappingEnergy( const size_t si1, const size_t si2, const size_t sj1, const size_t sj2 ); - // debug function void printMatrix( const Z2dMatrix & matrix ); diff --git a/src/IntaRNA/SeedHandler.cpp b/src/IntaRNA/SeedHandler.cpp index ed1196eb..b0d532f1 100644 --- a/src/IntaRNA/SeedHandler.cpp +++ b/src/IntaRNA/SeedHandler.cpp @@ -27,6 +27,28 @@ isFeasibleSeedBasePair( const size_t i1, const size_t i2, const bool atEndOfSeed ////////////////////////////////////////////////////////////////////////// +bool +SeedHandler:: +isSeedBasePair( const size_t i1, const size_t i2 + , const size_t k1, const size_t k2, const bool includeBoundaries ) const +{ + if (!isSeedBound(i1, i2)) { + return false; + } + + // trace seed at (i1,i2) + Interaction interaction = Interaction(energy.getAccessibility1().getSequence(), energy.getAccessibility2().getAccessibilityOrigin().getSequence()); + traceBackSeed( interaction, i1, i2 ); + if (includeBoundaries) { + interaction.basePairs.push_back( energy.getBasePair(i1, i2) ); + interaction.basePairs.push_back( energy.getBasePair(i1+getSeedLength1(i1,i2)-1, i2+getSeedLength2(i1,i2)-1) ); + } + + return (std::find(interaction.basePairs.begin(), interaction.basePairs.end(), energy.getBasePair(k1, k2)) != interaction.basePairs.end()); +} + +////////////////////////////////////////////////////////////////////////// + bool SeedHandler:: updateToNextSeed( size_t & i1_out, size_t & i2_out @@ -75,6 +97,40 @@ updateToNextSeed( size_t & i1_out, size_t & i2_out ////////////////////////////////////////////////////////////////////////// +bool +SeedHandler:: +updateToNextSeedWithK( size_t & i1_out, size_t & i2_out + , const size_t k1, const size_t k2, const bool includeBoundaries + ) const +{ + // if no left-shift left + if ((i1_out == 0 && i2_out == 0) + || (!includeBoundaries && (k1 == 0 || k2 == 0))) + { + return false; + } + + const size_t seedBP = getConstraint().getBasePairs(); + const size_t maxDistI1 = (seedBP - 1) + getConstraint().getMaxUnpaired1(); + const size_t maxDistI2 = (seedBP - 1) + getConstraint().getMaxUnpaired2(); + + size_t min_i1 = k1 - std::min(k1, maxDistI1); + size_t min_i2 = k2 - std::min(k2, maxDistI2); + + // check all seeds within the range + while (updateToNextSeed(i1_out,i2_out,min_i1, (includeBoundaries?k1:k1-1), min_i2, (includeBoundaries?k2:k2-1))) { + // check if we found a seed including k in the range + if (isSeedBasePair(i1_out, i2_out, k1, k2, includeBoundaries)) { + return true; + } + } + + // no valid seed including k found + return false; +} + +////////////////////////////////////////////////////////////////////////// + void SeedHandler:: addSeeds( Interaction & i ) const diff --git a/src/IntaRNA/SeedHandler.h b/src/IntaRNA/SeedHandler.h index 7682cb45..e30af8c8 100644 --- a/src/IntaRNA/SeedHandler.h +++ b/src/IntaRNA/SeedHandler.h @@ -151,6 +151,22 @@ class SeedHandler , const size_t i2min = 0, const size_t i2max = RnaSequence::lastPos ) const; + /** + * updateToNextSeed for seeds including base pair k + * + * @param i1 seq1 seed index to be changed; set to > k1 to find first valid i1 + * @param i2 seq2 seed index to be changed; set to > k2 to find first valid i2 + * @param k1 first position within seq1 (inclusive) + * @param k2 last position within seq1 (inclusive) + * @param includeBoundaries whether boundaries count as seed base pair + * @return true if the input variables have been changed; false otherwise + */ + virtual + bool + updateToNextSeedWithK( size_t & i1, size_t & i2 + , const size_t k1, const size_t k2, const bool includeBoundaries = true + ) const; + /** * Adds all seeds to a given interaction that are completely covered by its * base pairs. @@ -194,6 +210,21 @@ class SeedHandler , const size_t i2 , const bool atEndOfSeed = false ) const; + /** + * Checks whether or not a given index pair is a valid seed base of a given seed + * + * @param i1 the left most interacting base of seq1 of a seed + * @param i2 the left most interacting base of seq2 of a seed + * @param k1 the interacting base of seq1 + * @param k2 the interacting base of seq2 + * @param includeBoundaries whether boundaries count as seed base pair + * @return true if (k1,k2) is a valid base pair of seed(i1,i2); false otherwise + */ + virtual + bool + isSeedBasePair( const size_t i1, const size_t i2 + , const size_t k1, const size_t k2, const bool includeBoundaries = true ) const; + protected: //! the used energy function diff --git a/src/IntaRNA/SeedHandlerExplicit.h b/src/IntaRNA/SeedHandlerExplicit.h index 9d2743e5..4d8445cd 100644 --- a/src/IntaRNA/SeedHandlerExplicit.h +++ b/src/IntaRNA/SeedHandlerExplicit.h @@ -206,7 +206,7 @@ class SeedHandlerExplicit : public SeedHandler protected: //! container to store - boost::unordered_map< Interaction::BasePair, SeedData > seedForLeftEnd; + boost::unordered_map< Interaction::BasePair, SeedData, Interaction::BasePair::Hash, Interaction::BasePair::Equal > seedForLeftEnd; }; diff --git a/src/IntaRNA/SeedHandlerNoBulge.cpp b/src/IntaRNA/SeedHandlerNoBulge.cpp index 832ce1d5..155f39e9 100644 --- a/src/IntaRNA/SeedHandlerNoBulge.cpp +++ b/src/IntaRNA/SeedHandlerNoBulge.cpp @@ -183,6 +183,48 @@ updateToNextSeed( size_t & i1_out, size_t & i2_out return false; } +////////////////////////////////////////////////////////////////////////// + +bool +SeedHandlerNoBulge:: +updateToNextSeedWithK( size_t & i1_out, size_t & i2_out + , const size_t k1, const size_t k2, const bool includeBoundaries + ) const +{ + // if no left-shift left + if ((i1_out == 0 && i2_out == 0) + || (!includeBoundaries && (k1 == 0 || k2 == 0))) + { + return false; + } + + if ((i1_out > k1 || i2_out > k2) // right of k + || (k1-i1_out)!=(k2-i2_out) ) // different distances + { + i1_out = k1; + i2_out = k2; + if( includeBoundaries && isSeedBound(i1_out,i2_out) ) { + return true; + } + } + + const size_t seedBP = getConstraint().getBasePairs(); + const size_t maxDistI = std::min( seedBP - 1, std::min(k1,k2)); + + // check all seeds within the range + while( i1_out > 0 && i2_out > 0 && (k1-i1_out) StackingEnergyList; //! container type for sparse seed information - typedef boost::unordered_map< Interaction::BasePair, E_type > SeedHash; + typedef boost::unordered_map< Interaction::BasePair, E_type, Interaction::BasePair::Hash, Interaction::BasePair::Equal > SeedHash; //! container to store seeds' hybridization energies SeedHash seedForLeftEnd; @@ -148,6 +148,38 @@ class SeedHandlerNoBulge : public SeedHandler , const size_t i2min = 0, const size_t i2max = RnaSequence::lastPos ) const; + /** + * updateToNextSeed for seeds including base pair k + * + * @param i1 seq1 seed index to be changed; set to > k1 to find first valid i1 + * @param i2 seq2 seed index to be changed; set to > k2 to find first valid i2 + * @param k1 first position within seq1 (inclusive) + * @param k2 last position within seq1 (inclusive) + * @param includeBoundaries whether boundaries count as seed base pair + * @return true if the input variables have been changed; false otherwise + */ + virtual + bool + updateToNextSeedWithK( size_t & i1, size_t & i2 + , const size_t k1, const size_t k2, const bool includeBoundaries = true + ) const; + + + /** + * Checks whether or not a given index pair is a valid seed base of a given seed + * + * @param i1 the left most interacting base of seq1 of a seed + * @param i2 the left most interacting base of seq2 of a seed + * @param k1 the interacting base of seq1 + * @param k2 the interacting base of seq2 + * @param includeBoundaries whether boundaries count as seed base pair + * @return true if (k1,k2) is a valid base pair of seed(i1,i2); false otherwise + */ + virtual + bool + isSeedBasePair( const size_t i1, const size_t i2 + , const size_t k1, const size_t k2, const bool includeBoundaries = true ) const; + protected: @@ -247,6 +279,30 @@ isSeedBound( const size_t i1, const size_t i2 ) const ////////////////////////////////////////////////////////////////////////// +inline +bool +SeedHandlerNoBulge:: +isSeedBasePair( const size_t i1, const size_t i2 + , const size_t k1, const size_t k2, const bool includeBoundaries ) const +{ + if (!isSeedBound(i1, i2)) { + return false; + } + + if (includeBoundaries) { + return k1 >= i1 && k2 >= i2 + && (k1 - i1) == (k2 - i2) + && (k1-i1) i1 && k2 > i2 + && (k1 - i1) == (k2 - i2) + && (k1-i1)::min()) +#define IntaRNA_precisionEpsilon std::pow(10, -10) #ifdef E_2_Ekcal #error E_2_Ekcal already defined diff --git a/src/bin/CommandLineParsing.cpp b/src/bin/CommandLineParsing.cpp index 697e13b5..1a512b85 100644 --- a/src/bin/CommandLineParsing.cpp +++ b/src/bin/CommandLineParsing.cpp @@ -62,6 +62,7 @@ extern "C" { #include "IntaRNA/PredictionTrackerSpotProb.h" #include "IntaRNA/PredictionTrackerSpotProbAll.h" #include "IntaRNA/PredictionTrackerProfileSpotProb.h" +#include "IntaRNA/PredictionTrackerBasePairProb.h" #include "IntaRNA/SeedHandlerMfe.h" #include "IntaRNA/SeedHandlerNoBulge.h" @@ -903,6 +904,7 @@ CommandLineParsing::CommandLineParsing( const Personality personality ) "\n 'tPu:' (target) unpaired probabilities values (RNAplfold format)." "\n 'pMinE:' (target+query) for each index pair the minimal energy of any interaction covering the pair (CSV format)" "\n 'spotProb:' (target+query) tracks for a given set of interaction spots their probability to be covered by an interaction. If no spots are provided, probabilities for all index combinations are computed. Spots are encoded by comma-separated 'idxT&idxQ' pairs (target-query). For each spot a probability is provided in concert with the probability that none of the spots (encoded by '0&0') is covered (CSV format). The spot encoding is followed colon-separated by the output stream/file name, eg. '--out=\"spotProb:3&76,59&2:STDERR\"'. NOTE: value has to be quoted due to '&' symbol!" + "\n 'basePairProb:' (target+query) tracks inter-molecular base pair probabilities and produces a dotplot in SVG format (requires model=P and mode=M)." "\nFor each, provide a file name or STDOUT/STDERR to write to the respective output stream." ).c_str()) ("outSep" @@ -1469,6 +1471,11 @@ parse(int argc, char** argv) throw error("--out argument shows multiple times '"+toString(c1->second)+"' as target file/stream."); } } + // check if base pair probability output possible + if (c1->first == OutPrefixCode::OP_basePairProb) { + if (model.val != 'P') { throw error("--out=basePairProb requires --model=P"); } + if (mode.val != 'M') { throw error("--out=basePairProb requires --mode=M"); } + } } } @@ -1542,7 +1549,7 @@ parse(int argc, char** argv) +") exceeds the maximally allowed number of helix base pairs ("+toString(helixMaxBP.val)+")"); } } - + // check for minimal sequence length for(size_t i=0; iaddPredictionTracker( + new PredictionTrackerBasePairProb( energy + , outPrefix2streamName.at(OutPrefixCode::OP_basePairProb) ) + ); + } + // check if any tracker registered if (predTracker->empty()) { // cleanup to avoid overhead @@ -2400,7 +2416,7 @@ getPredictor( const InteractionEnergy & energy, OutputHandler & output ) const case 'P' : { switch ( mode.val ) { case 'H' : return new PredictorMfeEns2dHeuristicSeedExtension( energy, output, predTracker, getSeedHandler( energy ) ); - case 'M' : return new PredictorMfeEns2dSeedExtension( energy, output, predTracker, getSeedHandler( energy ) ); + case 'M' : return new PredictorMfeEns2dSeedExtension( energy, output, predTracker, getSeedHandler( energy ), !outPrefix2streamName.at(OutPrefixCode::OP_basePairProb).empty() ); case 'S' : return new PredictorMfeEnsSeedOnly( energy, output, predTracker, getSeedHandler(energy) ); default : INTARNA_NOT_IMPLEMENTED("mode "+toString(mode.val)+" not available for model "+toString(model.val)); } @@ -2713,6 +2729,3 @@ getPersonality( int argc, char ** argv ) //////////////////////////////////////////////////////////////////////////// - - - diff --git a/src/bin/CommandLineParsing.h b/src/bin/CommandLineParsing.h index f505388d..a1ff2788 100644 --- a/src/bin/CommandLineParsing.h +++ b/src/bin/CommandLineParsing.h @@ -350,6 +350,7 @@ class CommandLineParsing { OP_tPu, OP_spotProb, OP_spotProbAll, + OP_basePairProb, OP_UNKNOWN }; @@ -377,6 +378,7 @@ class CommandLineParsing { if (prefLC == "qpu") { return OutPrefixCode::OP_qPu; } else if (prefLC == "tpu") { return OutPrefixCode::OP_tPu; } else if (prefLC == "spotprob") { return OutPrefixCode::OP_spotProb; } else + if (prefLC == "basepairprob") { return OutPrefixCode::OP_basePairProb; } else // not known return OutPrefixCode::OP_UNKNOWN; } diff --git a/tests/Makefile.am b/tests/Makefile.am index ff862502..63f667ad 100644 --- a/tests/Makefile.am +++ b/tests/Makefile.am @@ -18,7 +18,7 @@ SH_LOG_COMPILER = $(SHELL) dist_check_SCRIPTS = runIntaRNA.sh # the program to build -check_PROGRAMS = runApiTests +check_PROGRAMS = runApiTests # test sources runApiTests_SOURCES = \ @@ -42,10 +42,12 @@ runApiTests_SOURCES = \ Interaction_test.cpp \ InteractionEnergyBasePair_test.cpp \ InteractionRange_test.cpp \ + PredictionTrackerBasePairProb_test.cpp \ PredictionTrackerProfileMinE_test.cpp \ PredictionTrackerSpotProb_test.cpp \ PredictorMfe2dHelixBlockHeuristic_test.cpp \ PredictorMfe2dHelixBlockHeuristicSeed_test.cpp \ + PredictorMfeEns2d_test.cpp \ NussinovHandler_test.cpp \ RnaSequence_test.cpp \ OutputStreamHandlerSortedCsv_test.cpp \ @@ -65,6 +67,3 @@ LIBS= -L$(top_builddir)/src/IntaRNA -lIntaRNA \ runApiTests_CXXFLAGS = -I$(top_builddir)/src \ @AM_CXXFLAGS@ @CXXFLAGS@ \ -DELPP_NO_LOG_TO_FILE - - - \ No newline at end of file diff --git a/tests/PredictionTrackerBasePairProb_test.cpp b/tests/PredictionTrackerBasePairProb_test.cpp new file mode 100644 index 00000000..68d303dd --- /dev/null +++ b/tests/PredictionTrackerBasePairProb_test.cpp @@ -0,0 +1,340 @@ +#include "catch.hpp" + +#undef NDEBUG + +#include "IntaRNA/PredictionTrackerBasePairProb.h" +#include "IntaRNA/RnaSequence.h" +#include "IntaRNA/AccessibilityDisabled.h" +#include "IntaRNA/InteractionEnergyBasePair.h" +#include "IntaRNA/Interaction.h" +#include "IntaRNA/ReverseAccessibility.h" +#include "IntaRNA/PredictorMfeEns2d.h" +#include "IntaRNA/PredictorMfeEns2dSeedExtension.h" +#include "IntaRNA/OutputHandlerInteractionList.h" +#include "IntaRNA/SeedHandlerNoBulge.h" + +#include + +using namespace IntaRNA; + +//! Helper function to generate Boltzmann weights for different numbers of base pairs +inline std::vector getBPWeights(const InteractionEnergyBasePair & energy, const size_t maxBasepairs) +{ + std::vector bpWeights; + for (int i = 0; i <= maxBasepairs; i++) { + bpWeights.push_back(energy.getBoltzmannWeight(i * energy.getE_basePair())); + } + return bpWeights; +} + +TEST_CASE( "PredictionTrackerBasePairProb", "[PredictionTrackerBasePairProb]" ) { + + // setup easylogging++ stuff if not already done + #include "testEasyLoggingSetup.icc" + + SECTION("base pair probs - case 1 noseed") { + RnaSequence r1("r1", "GG"); + RnaSequence r2("r2", "CC"); + AccessibilityDisabled acc1(r1, 0, NULL); + AccessibilityDisabled acc2(r2, 0, NULL); + ReverseAccessibility racc(acc2); + InteractionEnergyBasePair energy(acc1, racc); + + OutputConstraint outC(1,OutputConstraint::OVERLAP_SEQ2,0,100); + OutputHandlerInteractionList out(outC, 1); + + IndexRange idx1(0,r1.lastPos); + IndexRange idx2(0,r2.lastPos); + + PredictionTrackerBasePairProb * tracker = new PredictionTrackerBasePairProb(energy, ""); + PredictorMfeEns2d predictor(energy, out, tracker); + + predictor.predict(idx1,idx2); + + std::vector wBP = getBPWeights(energy, 2); + + REQUIRE(Z_equal(tracker->getBasePairProb(0, 1, &predictor), wBP[1] / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(1, 0, &predictor), wBP[1] / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(0, 0, &predictor), (wBP[1] + wBP[2]) / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(1, 1, &predictor), (wBP[1] + wBP[2]) / predictor.getZall())); + } + + SECTION("base pair probs - case 1 seed") { + RnaSequence r1("r1", "GG"); + RnaSequence r2("r2", "CC"); + AccessibilityDisabled acc1(r1, 0, NULL); + AccessibilityDisabled acc2(r2, 0, NULL); + ReverseAccessibility racc(acc2); + InteractionEnergyBasePair energy(acc1, racc); + + OutputConstraint outC(1,OutputConstraint::OVERLAP_SEQ2,0,100); + OutputHandlerInteractionList out(outC, 1); + + IndexRange idx1(0,r1.lastPos); + IndexRange idx2(0,r2.lastPos); + + PredictionTrackerBasePairProb * tracker = new PredictionTrackerBasePairProb(energy, ""); + SeedConstraint sC(2,0,0,0,0 + , AccessibilityDisabled::ED_UPPER_BOUND + , 0 + , IndexRangeList("") + , IndexRangeList("") + , "" + , false, false, false + ); + + PredictorMfeEns2dSeedExtension predictor(energy, out, tracker, new SeedHandlerNoBulge(energy, sC), true); + predictor.predict(idx1,idx2); + + REQUIRE(Z_equal(tracker->getBasePairProb(0, 1, &predictor), 0)); + REQUIRE(Z_equal(tracker->getBasePairProb(1, 0, &predictor), 0)); + REQUIRE(Z_equal(tracker->getBasePairProb(0, 0, &predictor), 1)); + REQUIRE(Z_equal(tracker->getBasePairProb(1, 1, &predictor), 1)); + } + + SECTION("base pair probs - case 2 noseed") { + RnaSequence r1("r1", "GGG"); + RnaSequence r2("r2", "CCC"); + AccessibilityDisabled acc1(r1, 0, NULL); + AccessibilityDisabled acc2(r2, 0, NULL); + ReverseAccessibility racc(acc2); + InteractionEnergyBasePair energy(acc1, racc); + + OutputConstraint outC(1,OutputConstraint::OVERLAP_SEQ2,0,100); + OutputHandlerInteractionList out(outC, 1); + + IndexRange idx1(0,r1.lastPos); + IndexRange idx2(0,r2.lastPos); + + PredictionTrackerBasePairProb * tracker = new PredictionTrackerBasePairProb(energy, ""); + PredictorMfeEns2d predictor(energy, out, tracker); + + predictor.predict(idx1,idx2); + + std::vector wBP = getBPWeights(energy, 3); + + REQUIRE(Z_equal(tracker->getBasePairProb(0, 1, &predictor), (wBP[1] + 2 * wBP[2]) / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(1, 0, &predictor), (wBP[1] + 2 * wBP[2]) / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(2, 1, &predictor), (wBP[1] + 2 * wBP[2]) / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(1, 2, &predictor), (wBP[1] + 2 * wBP[2]) / predictor.getZall())); + + REQUIRE(Z_equal(tracker->getBasePairProb(0, 2, &predictor), wBP[1] / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(2, 0, &predictor), wBP[1] / predictor.getZall())); + + REQUIRE(Z_equal(tracker->getBasePairProb(0, 0, &predictor), (wBP[1] + 4 * wBP[2] + wBP[3]) / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(2, 2, &predictor), (wBP[1] + 4 * wBP[2] + wBP[3]) / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(1, 1, &predictor), (wBP[1] + 2 * wBP[2] + wBP[3]) / predictor.getZall())); + } + + SECTION("base pair probs - case 2 seed") { + RnaSequence r1("r1", "GGG"); + RnaSequence r2("r2", "CCC"); + AccessibilityDisabled acc1(r1, 0, NULL); + AccessibilityDisabled acc2(r2, 0, NULL); + ReverseAccessibility racc(acc2); + InteractionEnergyBasePair energy(acc1, racc); + + OutputConstraint outC(1,OutputConstraint::OVERLAP_SEQ2,0,10000,0,0,0,1,1); + OutputHandlerInteractionList out(outC, 1); + + IndexRange idx1(0,r1.lastPos); + IndexRange idx2(0,r2.lastPos); + + PredictionTrackerBasePairProb * tracker = new PredictionTrackerBasePairProb(energy, ""); + SeedConstraint sC(3,0,0,0,0 + , AccessibilityDisabled::ED_UPPER_BOUND + , 0 + , IndexRangeList("") + , IndexRangeList("") + , "" + , false, false, false + ); + + PredictorMfeEns2dSeedExtension predictor(energy, out, tracker, new SeedHandlerNoBulge(energy, sC), true); + predictor.predict(idx1,idx2); + + REQUIRE(Z_equal(tracker->getBasePairProb(0, 1, &predictor), 0)); + REQUIRE(Z_equal(tracker->getBasePairProb(1, 0, &predictor), 0)); + REQUIRE(Z_equal(tracker->getBasePairProb(2, 1, &predictor), 0)); + REQUIRE(Z_equal(tracker->getBasePairProb(1, 2, &predictor), 0)); + + REQUIRE(Z_equal(tracker->getBasePairProb(0, 2, &predictor), 0)); + REQUIRE(Z_equal(tracker->getBasePairProb(2, 0, &predictor), 0)); + + REQUIRE(Z_equal(tracker->getBasePairProb(0, 0, &predictor), 1)); + REQUIRE(Z_equal(tracker->getBasePairProb(2, 2, &predictor), 1)); + REQUIRE(Z_equal(tracker->getBasePairProb(1, 1, &predictor), 1)); + } + + SECTION("base pair probs - case 3 noseed") { + RnaSequence r1("r1", "GGCGC"); + RnaSequence r2("r2", "GGCC"); + AccessibilityDisabled acc1(r1, 0, NULL); + AccessibilityDisabled acc2(r2, 0, NULL); + ReverseAccessibility racc(acc2); + InteractionEnergyBasePair energy(acc1, racc); + + OutputConstraint outC(1,OutputConstraint::OVERLAP_SEQ2,0,100); + OutputHandlerInteractionList out(outC, 1); + + IndexRange idx1(0,r1.lastPos); + IndexRange idx2(0,r2.lastPos); + + PredictionTrackerBasePairProb * tracker = new PredictionTrackerBasePairProb(energy, ""); + PredictorMfeEns2d predictor(energy, out, tracker); + + predictor.predict(idx1,idx2); + + std::vector wBP = getBPWeights(energy, 4); + + REQUIRE(Z_equal(tracker->getBasePairProb(1, 1, &predictor), (wBP[1] + 5 * wBP[2] + 5 * wBP[3] + wBP[4]) / predictor.getZall())); + } + + SECTION("base pair probs - case 4 seed") { + RnaSequence r1("r1", "GGGG"); + RnaSequence r2("r2", "CCCC"); + AccessibilityDisabled acc1(r1, 0, NULL); + AccessibilityDisabled acc2(r2, 0, NULL); + ReverseAccessibility racc(acc2); + InteractionEnergyBasePair energy(acc1, racc); + + OutputConstraint outC(1,OutputConstraint::OVERLAP_SEQ2,0,10000,0,0,0,1,1); + OutputHandlerInteractionList out(outC, 1); + + IndexRange idx1(0,r1.lastPos); + IndexRange idx2(0,r2.lastPos); + + PredictionTrackerBasePairProb * tracker = new PredictionTrackerBasePairProb(energy, ""); + SeedConstraint sC(3,0,0,0,0 + , AccessibilityDisabled::ED_UPPER_BOUND + , 0 + , IndexRangeList("") + , IndexRangeList("") + , "" + , false, false, false + ); + + + PredictorMfeEns2dSeedExtension predictor(energy, out, tracker, new SeedHandlerNoBulge(energy, sC), true); + predictor.predict(idx1,idx2); + + std::vector wBP = getBPWeights(energy, 4); + //LOG(DEBUG) <<" w 1 "<getBasePairProb(0, 0, &predictor), (wBP[3] + wBP[4]) / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(1, 1, &predictor), (2 * wBP[3] + wBP[4]) / predictor.getZall())); + //LOG(DEBUG) <<"Z(1,1)"<getBasePairProb(2, 2, &predictor)*predictor.getZall() <<" Zall "<getBasePairProb(2, 2, &predictor), (2 * wBP[3] + wBP[4]) / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(3, 3, &predictor), (wBP[3] + wBP[4]) / predictor.getZall())); + + REQUIRE(Z_equal(tracker->getBasePairProb(0, 1, &predictor), wBP[3] / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(1, 2, &predictor), wBP[3] / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(2, 3, &predictor), wBP[3] / predictor.getZall())); + + REQUIRE(Z_equal(tracker->getBasePairProb(3, 2, &predictor), wBP[3] / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(2, 1, &predictor), wBP[3] / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(1, 0, &predictor), wBP[3] / predictor.getZall())); + } + + SECTION("base pair probs - case 5 seed") { + RnaSequence r1("r1", "GGGG"); + RnaSequence r2("r2", "CCCC"); + AccessibilityDisabled acc1(r1, 0, NULL); + AccessibilityDisabled acc2(r2, 0, NULL); + ReverseAccessibility racc(acc2); + InteractionEnergyBasePair energy(acc1, racc); + + OutputConstraint outC(1,OutputConstraint::OVERLAP_SEQ2,0,10000,0,0,0,1,1); + OutputHandlerInteractionList out(outC, 1); + + IndexRange idx1(0,r1.lastPos); + IndexRange idx2(0,r2.lastPos); + + PredictionTrackerBasePairProb * tracker = new PredictionTrackerBasePairProb(energy, ""); + SeedConstraint sC(2,0,0,0,0 + , AccessibilityDisabled::ED_UPPER_BOUND + , 0 + , IndexRangeList("") + , IndexRangeList("") + , "" + , false, false, false + ); + + PredictorMfeEns2dSeedExtension predictor(energy, out, tracker, new SeedHandlerNoBulge(energy, sC), true); + predictor.predict(idx1,idx2); + + std::vector wBP = getBPWeights(energy, 4); + + REQUIRE(Z_equal(tracker->getBasePairProb(0, 0, &predictor), (1 * wBP[2] + 7 * wBP[3] + wBP[4]) / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(1, 1, &predictor), (2 * wBP[2] + 5 * wBP[3] + wBP[4]) / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(2, 2, &predictor), (2 * wBP[2] + 5 * wBP[3] + wBP[4]) / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(3, 3, &predictor), (1 * wBP[2] + 7 * wBP[3] + wBP[4]) / predictor.getZall())); + } + + SECTION("base pair probs - case 6 seed") { + RnaSequence r1("r1", "GGGGGGG"); + RnaSequence r2("r2", "CCCCCCC"); + AccessibilityDisabled acc1(r1, 0, NULL); + AccessibilityDisabled acc2(r2, 0, NULL); + ReverseAccessibility racc(acc2); + InteractionEnergyBasePair energy(acc1, racc); + + OutputConstraint outC(1,OutputConstraint::OVERLAP_SEQ2,0,10000,0,0,0,1,1); + OutputHandlerInteractionList out(outC, 1); + + IndexRange idx1(0,r1.lastPos); + IndexRange idx2(0,r2.lastPos); + + PredictionTrackerBasePairProb * tracker = new PredictionTrackerBasePairProb(energy, ""); + SeedConstraint sC(5,0,0,0,0 + , AccessibilityDisabled::ED_UPPER_BOUND + , 0 + , IndexRangeList("") + , IndexRangeList("") + , "" + , false, false, false + ); + + PredictorMfeEns2dSeedExtension predictor(energy, out, tracker, new SeedHandlerNoBulge(energy, sC), true); + predictor.predict(idx1,idx2); + + std::vector wBP = getBPWeights(energy, 7); + LOG(DEBUG) <<" w 4 "<getBasePairProb(3, 3, &predictor), (3 * wBP[5] + 8 * wBP[6] + wBP[7]) / predictor.getZall())); + } + + SECTION("base pair probs - case 7 seed") { + RnaSequence r1("r1", "GGGGGG"); + RnaSequence r2("r2", "CCCCCC"); + AccessibilityDisabled acc1(r1, 0, NULL); + AccessibilityDisabled acc2(r2, 0, NULL); + ReverseAccessibility racc(acc2); + InteractionEnergyBasePair energy(acc1, racc); + + OutputConstraint outC(1,OutputConstraint::OVERLAP_SEQ2,0,10000,0,0,0,1,1); + OutputHandlerInteractionList out(outC, 1); + + IndexRange idx1(0,r1.lastPos); + IndexRange idx2(0,r2.lastPos); + + PredictionTrackerBasePairProb * tracker = new PredictionTrackerBasePairProb(energy, ""); + SeedConstraint sC(4,0,0,0,0 + , AccessibilityDisabled::ED_UPPER_BOUND + , 0 + , IndexRangeList("") + , IndexRangeList("") + , "" + , false, false, false + ); + + PredictorMfeEns2dSeedExtension predictor(energy, out, tracker, new SeedHandlerNoBulge(energy, sC), true); + predictor.predict(idx1,idx2); + + std::vector wBP = getBPWeights(energy, 6); + + REQUIRE(Z_equal(tracker->getBasePairProb(2, 2, &predictor), (3 * wBP[4] + 8 * wBP[5] + wBP[6]) / predictor.getZall())); + REQUIRE(Z_equal(tracker->getBasePairProb(3, 3, &predictor), (3 * wBP[4] + 8 * wBP[5] + wBP[6]) / predictor.getZall())); + } + +} diff --git a/tests/PredictorMfeEns2d_test.cpp b/tests/PredictorMfeEns2d_test.cpp new file mode 100644 index 00000000..b1b6a56f --- /dev/null +++ b/tests/PredictorMfeEns2d_test.cpp @@ -0,0 +1,148 @@ +#include "catch.hpp" + +#undef NDEBUG + +#include "IntaRNA/RnaSequence.h" +#include "IntaRNA/AccessibilityDisabled.h" +#include "IntaRNA/InteractionEnergyBasePair.h" +#include "IntaRNA/Interaction.h" +#include "IntaRNA/ReverseAccessibility.h" +#include "IntaRNA/PredictorMfeEns2d.h" +#include "IntaRNA/OutputHandlerInteractionList.h" + +using namespace IntaRNA; + +#include + +TEST_CASE( "PredictorMfeEns2d", "[PredictorMfeEns2d]" ) { + + // setup easylogging++ stuff if not already done + #include "testEasyLoggingSetup.icc" + + SECTION("Zall case 1: check value") { + RnaSequence r1("r1", "GG"); + RnaSequence r2("r2", "CC"); + AccessibilityDisabled acc1(r1, 0, NULL); + AccessibilityDisabled acc2(r2, 0, NULL); + ReverseAccessibility racc(acc2); + InteractionEnergyBasePair energy(acc1, racc); + + OutputConstraint outC(1,OutputConstraint::OVERLAP_SEQ2,0,100); + OutputHandlerInteractionList out(outC, 1); + + PredictorMfeEns2d predictor(energy, out, NULL); + + IndexRange idx1(0,r1.lastPos); + IndexRange idx2(0,r2.lastPos); + + predictor.predict(idx1,idx2); + + Z_type boltzmannSum = 4 * energy.getBoltzmannWeight(Ekcal_2_E(-1.0)) + + 1 * energy.getBoltzmannWeight(Ekcal_2_E(-2.0)); + + REQUIRE(Z_equal(predictor.getZall(), boltzmannSum)); + } + + SECTION("Zall case 2: check value") { + RnaSequence r1("r1", "GGG"); + RnaSequence r2("r2", "CCC"); + AccessibilityDisabled acc1(r1, 0, NULL); + AccessibilityDisabled acc2(r2, 0, NULL); + ReverseAccessibility racc(acc2); + InteractionEnergyBasePair energy(acc1, racc); + + OutputConstraint outC(1,OutputConstraint::OVERLAP_SEQ2,0,100); + OutputHandlerInteractionList out(outC, 1); + + PredictorMfeEns2d predictor(energy, out, NULL); + + IndexRange idx1(0,r1.lastPos); + IndexRange idx2(0,r2.lastPos); + + predictor.predict(idx1,idx2); + + Z_type boltzmannSum = 9 * energy.getBoltzmannWeight(Ekcal_2_E(-1.0)) + + 9 * energy.getBoltzmannWeight(Ekcal_2_E(-2.0)) + + 1 * energy.getBoltzmannWeight(Ekcal_2_E(-3.0)); + + REQUIRE(Z_equal(predictor.getZall(), boltzmannSum)); + } + + SECTION("Zall case 3: check value") { + RnaSequence r1("r1", "GGGG"); + RnaSequence r2("r2", "CCCC"); + AccessibilityDisabled acc1(r1, 0, NULL); + AccessibilityDisabled acc2(r2, 0, NULL); + ReverseAccessibility racc(acc2); + InteractionEnergyBasePair energy(acc1, racc); + + OutputConstraint outC(1,OutputConstraint::OVERLAP_SEQ2,0,100); + OutputHandlerInteractionList out(outC, 1); + + PredictorMfeEns2d predictor(energy, out, NULL); + + IndexRange idx1(0,r1.lastPos); + IndexRange idx2(0,r2.lastPos); + + predictor.predict(idx1,idx2); + + Z_type boltzmannSum = 16 * energy.getBoltzmannWeight(Ekcal_2_E(-1.0)) + + 36 * energy.getBoltzmannWeight(Ekcal_2_E(-2.0)) + + 16 * energy.getBoltzmannWeight(Ekcal_2_E(-3.0)) + + 1 * energy.getBoltzmannWeight(Ekcal_2_E(-4.0)); + + REQUIRE(Z_equal(predictor.getZall(), boltzmannSum)); + } + + SECTION("Zall case 4: check value") { + RnaSequence r1("r1", "GGC"); + RnaSequence r2("r2", "GCC"); + AccessibilityDisabled acc1(r1, 0, NULL); + AccessibilityDisabled acc2(r2, 0, NULL); + ReverseAccessibility racc(acc2); + InteractionEnergyBasePair energy(acc1, racc); + + OutputConstraint outC(1,OutputConstraint::OVERLAP_SEQ2,0,100); + OutputHandlerInteractionList out(outC, 1); + + PredictorMfeEns2d predictor(energy, out, NULL); + + IndexRange idx1(0,r1.lastPos); + IndexRange idx2(0,r2.lastPos); + + predictor.predict(idx1,idx2); + + Z_type boltzmannSum = 5 * energy.getBoltzmannWeight(Ekcal_2_E(-1.0)) + + 5 * energy.getBoltzmannWeight(Ekcal_2_E(-2.0)) + + 1 * energy.getBoltzmannWeight(Ekcal_2_E(-3.0)); + + REQUIRE(Z_equal(predictor.getZall(), boltzmannSum)); + } + + SECTION("Zall case 5: check value") { + RnaSequence r1("r1", "GGCGC"); + RnaSequence r2("r2", "GGCC"); + AccessibilityDisabled acc1(r1, 0, NULL); + AccessibilityDisabled acc2(r2, 0, NULL); + ReverseAccessibility racc(acc2); + InteractionEnergyBasePair energy(acc1, racc); + + OutputConstraint outC(1,OutputConstraint::OVERLAP_SEQ2,0,100); + OutputHandlerInteractionList out(outC, 1); + + PredictorMfeEns2d predictor(energy, out, NULL); + + IndexRange idx1(0,r1.lastPos); + IndexRange idx2(0,r2.lastPos); + + predictor.predict(idx1,idx2); + + Z_type boltzmannSum = 10 * energy.getBoltzmannWeight(Ekcal_2_E(-1.0)) + + 24 * energy.getBoltzmannWeight(Ekcal_2_E(-2.0)) + + 12 * energy.getBoltzmannWeight(Ekcal_2_E(-3.0)) + + 1 * energy.getBoltzmannWeight(Ekcal_2_E(-4.0)); + + REQUIRE(Z_equal(predictor.getZall(), boltzmannSum)); + } + +}