diff --git a/Configuration/EventContent/python/EventContent_cff.py b/Configuration/EventContent/python/EventContent_cff.py index de8f59d46602a..0ab2435492bcb 100644 --- a/Configuration/EventContent/python/EventContent_cff.py +++ b/Configuration/EventContent/python/EventContent_cff.py @@ -715,7 +715,7 @@ def SwapKeepAndDrop(l): 'keep *_hltGeneralTracks_*_*', 'keep *_hltInitialStepTrackSelectionHighPurity_*_*', 'keep *_hltHighPtTripletStepTrackSelectionHighPurity_*_*', - 'keep *_hltInitialStepTracksT5TCLST_*_*', + 'keep *_hltInitialStepTracksT4T5TCLST_*_*', 'keep *_hltOfflinePrimaryVertices_*_*', ]) diff --git a/HLTrigger/Configuration/python/HLT_75e33/modules/hltGeneralTracks_cfi.py b/HLTrigger/Configuration/python/HLT_75e33/modules/hltGeneralTracks_cfi.py index 712618b2246ff..ffc4e4da842ff 100644 --- a/HLTrigger/Configuration/python/HLT_75e33/modules/hltGeneralTracks_cfi.py +++ b/HLTrigger/Configuration/python/HLT_75e33/modules/hltGeneralTracks_cfi.py @@ -45,10 +45,10 @@ from Configuration.ProcessModifiers.ngtScouting_cff import ngtScouting from ..modules.hltPhase2PixelTracks_cfi import * _hltGeneralTracksNGTScoutingLST = hltGeneralTracks.clone( - TrackProducers = ["hltPhase2PixelTracks", "hltInitialStepTracksT5TCLST"], + TrackProducers = ["hltPhase2PixelTracks", "hltInitialStepTracksT4T5TCLST"], hasSelector = [0,0], indivShareFrac = [0.1,0.1], - selectedTrackQuals = ["hltPhase2PixelTracks", "hltInitialStepTracksT5TCLST"], + selectedTrackQuals = ["hltPhase2PixelTracks", "hltInitialStepTracksT4T5TCLST"], setsToMerge = {0: dict(pQual=True, tLists=[0,1])} ) diff --git a/HLTrigger/Configuration/python/HLT_75e33/modules/hltInitialStepTracksT5TCLST_cfi.py b/HLTrigger/Configuration/python/HLT_75e33/modules/hltInitialStepTracksT4T5TCLST_cfi.py similarity index 54% rename from HLTrigger/Configuration/python/HLT_75e33/modules/hltInitialStepTracksT5TCLST_cfi.py rename to HLTrigger/Configuration/python/HLT_75e33/modules/hltInitialStepTracksT4T5TCLST_cfi.py index 
a7c0cc2b302d1..fd19ea3e884e8 100644 --- a/HLTrigger/Configuration/python/HLT_75e33/modules/hltInitialStepTracksT5TCLST_cfi.py +++ b/HLTrigger/Configuration/python/HLT_75e33/modules/hltInitialStepTracksT4T5TCLST_cfi.py @@ -1,4 +1,4 @@ import FWCore.ParameterSet.Config as cms from ..modules.hltInitialStepTracks_cfi import hltInitialStepTracks as _hltInitialStepTracks -hltInitialStepTracksT5TCLST = _hltInitialStepTracks.clone( src = "hltInitialStepTrackCandidates:t5TCsLST" ) +hltInitialStepTracksT4T5TCLST = _hltInitialStepTracks.clone( src = "hltInitialStepTrackCandidates:t4t5TCsLST" ) diff --git a/HLTrigger/Configuration/python/HLT_75e33/sequences/HLTInitialStepSequence_cfi.py b/HLTrigger/Configuration/python/HLT_75e33/sequences/HLTInitialStepSequence_cfi.py index 83e1be0bb3315..49d9fdeeb90b5 100644 --- a/HLTrigger/Configuration/python/HLT_75e33/sequences/HLTInitialStepSequence_cfi.py +++ b/HLTrigger/Configuration/python/HLT_75e33/sequences/HLTInitialStepSequence_cfi.py @@ -56,7 +56,7 @@ (singleIterPatatrack & trackingLST & seedingLST).toReplaceWith(HLTInitialStepSequence, _HLTInitialStepSequenceSingleIterPatatrackLSTSeeding) -from ..modules.hltInitialStepTracksT5TCLST_cfi import * +from ..modules.hltInitialStepTracksT4T5TCLST_cfi import * _HLTInitialStepSequenceNGTScouting = cms.Sequence( hltInitialStepSeeds +hltInitialStepSeedTracksLST @@ -64,7 +64,7 @@ +hltInputLST +hltLST +hltInitialStepTrackCandidates - +hltInitialStepTracksT5TCLST + +hltInitialStepTracksT4T5TCLST ) from Configuration.ProcessModifiers.ngtScouting_cff import ngtScouting diff --git a/RecoTracker/FinalTrackSelectors/python/earlyGeneralTracks_cfi.py b/RecoTracker/FinalTrackSelectors/python/earlyGeneralTracks_cfi.py index 80e7d9299d50a..c7ea1ccd1ee0f 100644 --- a/RecoTracker/FinalTrackSelectors/python/earlyGeneralTracks_cfi.py +++ b/RecoTracker/FinalTrackSelectors/python/earlyGeneralTracks_cfi.py @@ -136,11 +136,11 @@ def _extend_pixelLess(x): from Configuration.ProcessModifiers.trackingLST_cff 
import trackingLST (trackingPhase2PU140 & trackingLST).toModify(earlyGeneralTracks, - TrackProducers = ['highPtTripletStepLSTpTracks', 'highPtTripletStepLSTT5Tracks'], + TrackProducers = ['highPtTripletStepLSTpTracks', 'highPtTripletStepLSTT4T5Tracks'], hasSelector = [1,0], indivShareFrac = [0.1,0.1], selectedTrackQuals = ['highPtTripletStepSelector:highPtTripletStep', - 'highPtTripletStepSelectorLSTT5:highPtTripletStepLSTT5' + 'highPtTripletStepSelectorLSTT4T5:highPtTripletStepLSTT4T5' ], setsToMerge = {0: dict(tLists = [0,1])} ) diff --git a/RecoTracker/IterativeTracking/python/HighPtTripletStep_cff.py b/RecoTracker/IterativeTracking/python/HighPtTripletStep_cff.py index 807f1dd390bea..3fc8609645d8e 100644 --- a/RecoTracker/IterativeTracking/python/HighPtTripletStep_cff.py +++ b/RecoTracker/IterativeTracking/python/HighPtTripletStep_cff.py @@ -287,16 +287,16 @@ highPtTripletStepLSTpTracks = highPtTripletStepTracks.clone( src = 'highPtTripletStepTrackCandidates:pTCsLST' ) -highPtTripletStepLSTT5Tracks = highPtTripletStepTracks.clone( - src = 'highPtTripletStepTrackCandidates:t5TCsLST' +highPtTripletStepLSTT4T5Tracks = highPtTripletStepTracks.clone( + src = 'highPtTripletStepTrackCandidates:t4t5TCsLST' ) _highPtTripletStepTracks_LST = RecoTracker.FinalTrackSelectors.trackListMerger_cfi.trackListMerger.clone( TrackProducers = ['highPtTripletStepLSTpTracks', - 'highPtTripletStepLSTT5Tracks'], + 'highPtTripletStepLSTT4T5Tracks'], hasSelector = [1,0], indivShareFrac = [0.1,0.1], selectedTrackQuals = ['highPtTripletStepSelector:highPtTripletStep', - 'highPtTripletStepSelectorLSTT5:highPtTripletStepLSTT5'], + 'highPtTripletStepSelectorLSTT4T5:highPtTripletStepLSTT4T5'], copyExtras = True, copyMVA = False, setsToMerge = [cms.PSet( tLists=cms.vint32(0,1), pQual=cms.bool(True) )] @@ -382,21 +382,21 @@ (trackingPhase2PU140 & trackingLST).toModify(highPtTripletStepSelector, src = 'highPtTripletStepLSTpTracks') # Passthrough selector to satisfy the TrackListMerger requirement 
for selector values -highPtTripletStepSelectorLSTT5 = RecoTracker.FinalTrackSelectors.multiTrackSelector_cfi.multiTrackSelector.clone( - src = 'highPtTripletStepLSTT5Tracks', +highPtTripletStepSelectorLSTT4T5 = RecoTracker.FinalTrackSelectors.multiTrackSelector_cfi.multiTrackSelector.clone( + src = 'highPtTripletStepLSTT4T5Tracks', trackSelectors = [ RecoTracker.FinalTrackSelectors.multiTrackSelector_cfi.looseMTS.clone( - name = 'highPtTripletStepLSTT5Loose', + name = 'highPtTripletStepLSTT4T5Loose', minHitsToBypassChecks = 0 ), #end of pset RecoTracker.FinalTrackSelectors.multiTrackSelector_cfi.tightMTS.clone( - name = 'highPtTripletStepLSTT5Tight', - preFilterName = 'highPtTripletStepLSTT5Loose', + name = 'highPtTripletStepLSTT4T5Tight', + preFilterName = 'highPtTripletStepLSTT4T5Loose', minHitsToBypassChecks = 0 ), RecoTracker.FinalTrackSelectors.multiTrackSelector_cfi.highpurityMTS.clone( - name = 'highPtTripletStepLSTT5', - preFilterName = 'highPtTripletStepLSTT5Tight', + name = 'highPtTripletStepLSTT4T5', + preFilterName = 'highPtTripletStepLSTT4T5Tight', minHitsToBypassChecks = 0 ), ] #end of vpset @@ -430,7 +430,7 @@ from RecoTracker.LST.lstProducerTask_cff import * _HighPtTripletStepTask_LST.add(siPhase2RecHits, lstInitialStepSeedTracks, lstHighPtTripletStepSeedTracks, lstInputProducer, - lstProducerTask, highPtTripletStepLSTpTracks, highPtTripletStepLSTT5Tracks, highPtTripletStepSelectorLSTT5) + lstProducerTask, highPtTripletStepLSTpTracks, highPtTripletStepLSTT4T5Tracks, highPtTripletStepSelectorLSTT4T5) (trackingPhase2PU140 & trackingLST).toReplaceWith(HighPtTripletStepTask, _HighPtTripletStepTask_LST) from HeterogeneousCore.AlpakaCore.functions import makeSerialClone @@ -445,23 +445,23 @@ ) highPtTripletStepLSTpTracksSerialSync = highPtTripletStepLSTpTracks.clone( src = 'highPtTripletStepTrackCandidatesSerialSync:pTCsLST') -highPtTripletStepLSTT5TracksSerialSync = highPtTripletStepLSTT5Tracks.clone( - src = 
'highPtTripletStepTrackCandidatesSerialSync:t5TCsLST') +highPtTripletStepLSTT4T5TracksSerialSync = highPtTripletStepLSTT4T5Tracks.clone( + src = 'highPtTripletStepTrackCandidatesSerialSync:t4t5TCsLST') highPtTripletStepSelectorSerialSync = highPtTripletStepSelector.clone() (trackingPhase2PU140 & trackingLST).toModify(highPtTripletStepSelectorSerialSync, src = "highPtTripletStepLSTpTracksSerialSync" ) -highPtTripletStepSelectorLSTT5SerialSync = highPtTripletStepSelectorLSTT5.clone(src = "highPtTripletStepLSTT5TracksSerialSync") +highPtTripletStepSelectorLSTT4T5SerialSync = highPtTripletStepSelectorLSTT4T5.clone(src = "highPtTripletStepLSTT4T5TracksSerialSync") highPtTripletStepTracksSerialSync = highPtTripletStepTracks.clone() (trackingPhase2PU140 & trackingLST).toModify(highPtTripletStepTracksSerialSync, TrackProducers = ['highPtTripletStepLSTpTracksSerialSync', - 'highPtTripletStepLSTT5TracksSerialSync'], + 'highPtTripletStepLSTT4T5TracksSerialSync'], selectedTrackQuals = ['highPtTripletStepSelectorSerialSync:highPtTripletStep', - 'highPtTripletStepSelectorLSTT5SerialSync:highPtTripletStepLSTT5'], + 'highPtTripletStepSelectorLSTT4T5SerialSync:highPtTripletStepLSTT4T5'], ) _HighPtTripletStepTask_LSTSerialSync = HighPtTripletStepTask.copy() _HighPtTripletStepTask_LSTSerialSync.add(siPhase2RecHits, lstInitialStepSeedTracks, lstHighPtTripletStepSeedTracks, lstInputProducerSerialSync, lstProducerSerialSync, highPtTripletStepTrackCandidatesSerialSync, - highPtTripletStepLSTpTracksSerialSync, highPtTripletStepLSTT5TracksSerialSync, - highPtTripletStepSelectorSerialSync, highPtTripletStepSelectorLSTT5SerialSync, + highPtTripletStepLSTpTracksSerialSync, highPtTripletStepLSTT4T5TracksSerialSync, + highPtTripletStepSelectorSerialSync, highPtTripletStepSelectorLSTT4T5SerialSync, highPtTripletStepTracksSerialSync ) HighPtTripletStepTaskSerialSync = cms.Task() diff --git a/RecoTracker/LST/plugins/LSTOutputConverter.cc b/RecoTracker/LST/plugins/LSTOutputConverter.cc index 
db9303e43f7cd..200db3cc5c9e6 100644 --- a/RecoTracker/LST/plugins/LSTOutputConverter.cc +++ b/RecoTracker/LST/plugins/LSTOutputConverter.cc @@ -44,18 +44,19 @@ class LSTOutputConverter : public edm::stream::EDProducer<> { const edm::ESGetToken propagatorAlongToken_; const edm::ESGetToken propagatorOppositeToken_; const edm::ESGetToken tGeomToken_; + const edm::ESGetToken tTopoToken_; std::unique_ptr seedCreator_; const edm::EDPutTokenT trajectorySeedPutToken_; const edm::EDPutTokenT trajectorySeedpLSPutToken_; const edm::EDPutTokenT trackCandidatePutToken_; const edm::EDPutTokenT trackCandidatepTCPutToken_; - const edm::EDPutTokenT trackCandidateT5TCPutToken_; + const edm::EDPutTokenT trackCandidateT4T5TCPutToken_; const edm::EDPutTokenT trackCandidateNopLSTCPutToken_; const edm::EDPutTokenT trackCandidatepTTCPutToken_; const edm::EDPutTokenT trackCandidatepLSTCPutToken_; const edm::EDPutTokenT> seedStopInfoPutToken_; const edm::EDPutTokenT> pTCsSeedStopInfoPutToken_; - const edm::EDPutTokenT> t5TCsSeedStopInfoPutToken_; + const edm::EDPutTokenT> t4t5TCsSeedStopInfoPutToken_; const edm::EDPutTokenT> pTTCsSeedStopInfoPutToken_; }; @@ -69,6 +70,7 @@ LSTOutputConverter::LSTOutputConverter(edm::ParameterSet const& iConfig) propagatorAlongToken_{esConsumes(iConfig.getParameter("propagatorAlong"))}, propagatorOppositeToken_{esConsumes(iConfig.getParameter("propagatorOpposite"))}, tGeomToken_(esConsumes()), + tTopoToken_(esConsumes()), seedCreator_(SeedCreatorFactory::get()->create("SeedFromConsecutiveHitsCreator", iConfig.getParameter("SeedCreatorPSet"), consumesCollector())), @@ -82,13 +84,13 @@ LSTOutputConverter::LSTOutputConverter(edm::ParameterSet const& iConfig) trajectorySeedpLSPutToken_(produces("pLSTSsLST")), trackCandidatePutToken_(produces("")), trackCandidatepTCPutToken_(produces("pTCsLST")), - trackCandidateT5TCPutToken_(produces("t5TCsLST")), + trackCandidateT4T5TCPutToken_(produces("t4t5TCsLST")), trackCandidateNopLSTCPutToken_(produces("nopLSTCsLST")), 
trackCandidatepTTCPutToken_(produces("pTTCsLST")), trackCandidatepLSTCPutToken_(produces("pLSTCsLST")), seedStopInfoPutToken_(produces("")), pTCsSeedStopInfoPutToken_(produces("pTCsLST")), - t5TCsSeedStopInfoPutToken_(produces("t5TCsLST")), + t4t5TCsSeedStopInfoPutToken_(produces("t4t5TCsLST")), pTTCsSeedStopInfoPutToken_(produces("pTTCsLST")) {} void LSTOutputConverter::fillDescriptions(edm::ConfigurationDescriptions& descriptions) { @@ -126,6 +128,7 @@ void LSTOutputConverter::produce(edm::Event& iEvent, const edm::EventSetup& iSet auto const& propAlo = iSetup.getData(propagatorAlongToken_); auto const& propOppo = iSetup.getData(propagatorOppositeToken_); auto const& tracker = iSetup.getData(tGeomToken_); + const TrackerTopology& tTopo = iSetup.getData(tTopoToken_); auto lstOutput_view = lstOutput.const_view(); unsigned int nTrackCandidates = lstOutput_view.nTrackCandidates(); @@ -135,10 +138,10 @@ void LSTOutputConverter::produce(edm::Event& iEvent, const edm::EventSetup& iSet TrajectorySeedCollection outputTS, outputpLSTS; outputTS.reserve(nTrackCandidates); outputpLSTS.reserve(nTrackCandidates); - TrackCandidateCollection outputTC, outputpTC, outputT5TC, outputNopLSTC, outputpTTC, outputpLSTC; + TrackCandidateCollection outputTC, outputpTC, outputT4T5TC, outputNopLSTC, outputpTTC, outputpLSTC; outputTC.reserve(nTrackCandidates); outputpTC.reserve(nTrackCandidates); - outputT5TC.reserve(nTrackCandidates); + outputT4T5TC.reserve(nTrackCandidates); outputNopLSTC.reserve(nTrackCandidates); outputpTTC.reserve(nTrackCandidates); outputpLSTC.reserve(nTrackCandidates); @@ -151,19 +154,19 @@ void LSTOutputConverter::produce(edm::Event& iEvent, const edm::EventSetup& iSet LogDebug("LSTOutputConverter") << " cand " << i << " " << iType << " " << lstOutput_view.pixelSeedIndex()[i]; TrajectorySeed seed; edm::RefToBase seedRef; - if (iType != lst::LSTObjType::T5) { + if (iType != lst::LSTObjType::T5 && iType != lst::LSTObjType::T4) { seed = 
pixelSeeds[lstOutput_view.pixelSeedIndex()[i]]; seedRef = {pixelSeedsRBP, lstOutput_view.pixelSeedIndex()[i]}; } edm::OwnVector recHits; - if (iType != lst::LSTObjType::T5) { + if (iType != lst::LSTObjType::T5 && iType != lst::LSTObjType::T4) { for (auto const& hit : seed.recHits()) recHits.push_back(hit.clone()); } // pixel-seeded TCs from LST always have 4 pixel hits - unsigned int const nPixelHits = iType == lst::LSTObjType::T5 ? 0 : 4; + unsigned int const nPixelHits = (iType == lst::LSTObjType::T5 || iType == lst::LSTObjType::T4) ? 0 : 4; unsigned int nHits = 0; switch (iType) { case lst::LSTObjType::T5: @@ -178,6 +181,9 @@ void LSTOutputConverter::produce(edm::Event& iEvent, const edm::EventSetup& iSet case lst::LSTObjType::pLS: nHits = lst::Params_pLS::kHits; break; + case lst::LSTObjType::T4: + nHits = lst::Params_T4::kHits; + break; } for (unsigned int j = nPixelHits; j < nHits; j++) recHits.push_back(OTHits[lstOutput_view.hitIndices()[i][j]]->clone()); @@ -206,17 +212,28 @@ void LSTOutputConverter::produce(edm::Event& iEvent, const edm::EventSetup& iSet if (iType != lst::LSTObjType::pLS) { // Construct a full-length TrajectorySeed always for T5s, // only when required by a flag for other pT objects. 
- if (includeNonpLSTSs_ || iType == lst::LSTObjType::T5) { + if (includeNonpLSTSs_ || (iType == lst::LSTObjType::T5 || iType == lst::LSTObjType::T4)) { using Hit = SeedingHitSet::ConstRecHitPointer; std::vector hitsForSeed; hitsForSeed.reserve(nHits); int n = 0; + unsigned int firstLayer; for (auto const& hit : recHits) { if (iType == lst::LSTObjType::T5) { auto hType = tracker.getDetectorType(hit.geographicalId()); if (hType != TrackerGeometry::ModuleType::Ph2PSP && n < 2) continue; // the first two should be P } + if (iType == lst::LSTObjType::T4) { + unsigned int hitLayer = tTopo.layer(hit.geographicalId()); + auto hType = tracker.getDetectorType(hit.geographicalId()); + if (n == 0) + firstLayer = hitLayer; + else { + if (hType == TrackerGeometry::ModuleType::Ph2PSS && hitLayer == firstLayer) + continue; + } + } hitsForSeed.emplace_back(dynamic_cast(&hit)); n++; } @@ -226,10 +243,10 @@ void LSTOutputConverter::produce(edm::Event& iEvent, const edm::EventSetup& iSet if (seeds.empty()) { edm::LogInfo("LSTOutputConverter") << "failed to convert a LST object to a seed" << i << " " << iType << " " << lstOutput_view.pixelSeedIndex()[i]; - if (iType == lst::LSTObjType::T5) + if (iType == lst::LSTObjType::T5 || iType == lst::LSTObjType::T4) continue; } - if (iType == lst::LSTObjType::T5) { + if (iType == lst::LSTObjType::T5 || iType == lst::LSTObjType::T4) { seed = seeds[0]; seedRef = edm::RefToBase(edm::Ref(outputTSRP, outputTS.size())); } @@ -257,13 +274,13 @@ void LSTOutputConverter::produce(edm::Event& iEvent, const edm::EventSetup& iSet PTrajectoryStateOnDet st = trajectoryStateTransform::persistentState(tsosPair.first, recHits[0].det()->geographicalId().rawId()); - if (!includeT5s_ && iType == lst::LSTObjType::T5) + if (!includeT5s_ && (iType == lst::LSTObjType::T5 || iType == lst::LSTObjType::T4)) continue; auto tc = TrackCandidate(recHits, seed, st, seedRef); outputTC.emplace_back(tc); - if (iType == lst::LSTObjType::T5) { - outputT5TC.emplace_back(tc); + if 
(iType == lst::LSTObjType::T5 || iType == lst::LSTObjType::T4) { + outputT4T5TC.emplace_back(tc); outputNopLSTC.emplace_back(tc); } else { outputpTC.emplace_back(tc); @@ -284,14 +301,14 @@ void LSTOutputConverter::produce(edm::Event& iEvent, const edm::EventSetup& iSet } LogDebug("LSTOutputConverter") << "done with conversion: Track candidate output size = " << outputpTC.size() - << " (p* objects) + " << outputT5TC.size() << " (T5 objects)"; + << " (p* objects) + " << outputT4T5TC.size() << " (T5 objects)"; //dummy (for now) stop infos: one per used kind of candidates std::vector seedStopInfo(pixelSeeds.size()); iEvent.emplace(seedStopInfoPutToken_, std::move(seedStopInfo)); std::vector pTCsSeedStopInfo(pixelSeeds.size()); iEvent.emplace(pTCsSeedStopInfoPutToken_, std::move(pTCsSeedStopInfo)); - std::vector t5TCsSeedStopInfo(outputTS.size()); - iEvent.emplace(t5TCsSeedStopInfoPutToken_, std::move(t5TCsSeedStopInfo)); + std::vector t4t5TCsSeedStopInfo(outputTS.size()); + iEvent.emplace(t4t5TCsSeedStopInfoPutToken_, std::move(t4t5TCsSeedStopInfo)); std::vector pTTCsSeedStopInfo(pixelSeeds.size()); iEvent.emplace(pTTCsSeedStopInfoPutToken_, std::move(pTTCsSeedStopInfo)); @@ -299,7 +316,7 @@ void LSTOutputConverter::produce(edm::Event& iEvent, const edm::EventSetup& iSet iEvent.emplace(trajectorySeedpLSPutToken_, std::move(outputpLSTS)); iEvent.emplace(trackCandidatePutToken_, std::move(outputTC)); iEvent.emplace(trackCandidatepTCPutToken_, std::move(outputpTC)); - iEvent.emplace(trackCandidateT5TCPutToken_, std::move(outputT5TC)); + iEvent.emplace(trackCandidateT4T5TCPutToken_, std::move(outputT4T5TC)); iEvent.emplace(trackCandidateNopLSTCPutToken_, std::move(outputNopLSTC)); iEvent.emplace(trackCandidatepTTCPutToken_, std::move(outputpTTC)); iEvent.emplace(trackCandidatepLSTCPutToken_, std::move(outputpLSTC)); diff --git a/RecoTracker/LSTCore/interface/Common.h b/RecoTracker/LSTCore/interface/Common.h index b383f6aea0e2b..0fa84d28aee54 100644 --- 
a/RecoTracker/LSTCore/interface/Common.h +++ b/RecoTracker/LSTCore/interface/Common.h @@ -18,7 +18,7 @@ namespace lst { enum PixelType : int8_t { kInvalid = -1, kHighPt = 0, kLowPtPosCurv = 1, kLowPtNegCurv = 2 }; // Named types for LST objects - enum LSTObjType : int8_t { T5 = 4, pT3 = 5, pT5 = 7, pLS = 8 }; + enum LSTObjType : int8_t { T5 = 4, pT3 = 5, pT5 = 7, pLS = 8, T4 = 9 }; constexpr unsigned int max_blocks = 80; constexpr unsigned int max_connected_modules = 40; @@ -73,6 +73,12 @@ namespace lst { using ArrayU16xLayers = edm::StdArray; using ArrayUxHits = edm::StdArray; }; + struct Params_T4 { + static constexpr int kLayers = 4, kHits = 8; + using ArrayU8xLayers = edm::StdArray; + using ArrayU16xLayers = edm::StdArray; + using ArrayUxHits = edm::StdArray; + }; struct Params_T5 { static constexpr int kLayers = 5, kHits = 10; static constexpr int kEmbed = 6; diff --git a/RecoTracker/LSTCore/interface/ObjectRangesSoA.h b/RecoTracker/LSTCore/interface/ObjectRangesSoA.h index ccab6b23909f6..d09effe64d47b 100644 --- a/RecoTracker/LSTCore/interface/ObjectRangesSoA.h +++ b/RecoTracker/LSTCore/interface/ObjectRangesSoA.h @@ -13,19 +13,25 @@ namespace lst { SOA_COLUMN(ArrayIx2, segmentRanges), SOA_COLUMN(ArrayIx2, tripletRanges), SOA_COLUMN(ArrayIx2, quintupletRanges), + SOA_COLUMN(ArrayIx2, quadrupletRanges), SOA_COLUMN(int, miniDoubletModuleIndices), SOA_COLUMN(int, miniDoubletModuleOccupancy), SOA_COLUMN(int, segmentModuleIndices), SOA_COLUMN(int, segmentModuleOccupancy), SOA_COLUMN(int, tripletModuleIndices), SOA_COLUMN(int, tripletModuleOccupancy), + SOA_COLUMN(int, quadrupletModuleIndices), + SOA_COLUMN(int, quadrupletModuleOccupancy), SOA_COLUMN(int, quintupletModuleIndices), SOA_COLUMN(int, quintupletModuleOccupancy), + SOA_COLUMN(uint16_t, indicesOfEligibleT4Modules), SOA_COLUMN(uint16_t, indicesOfEligibleT5Modules), SOA_SCALAR(unsigned int, nTotalMDs), SOA_SCALAR(unsigned int, nTotalSegs), SOA_SCALAR(unsigned int, nTotalTrips), + SOA_SCALAR(unsigned int, 
nTotalQuads), SOA_SCALAR(unsigned int, nTotalQuints), + SOA_SCALAR(uint16_t, nEligibleT4Modules), SOA_SCALAR(uint16_t, nEligibleT5Modules)) using ObjectRangesSoA = ObjectRangesSoALayout<>; diff --git a/RecoTracker/LSTCore/interface/QuadrupletsHostCollection.h b/RecoTracker/LSTCore/interface/QuadrupletsHostCollection.h new file mode 100644 index 0000000000000..128ed65e3ece3 --- /dev/null +++ b/RecoTracker/LSTCore/interface/QuadrupletsHostCollection.h @@ -0,0 +1,10 @@ +#ifndef RecoTracker_LSTCore_interface_QuadrupletsHostCollection_h +#define RecoTracker_LSTCore_interface_QuadrupletsHostCollection_h + +#include "RecoTracker/LSTCore/interface/QuadrupletsSoA.h" +#include "DataFormats/Portable/interface/PortableHostCollection.h" + +namespace lst { + using QuadrupletsHostCollection = PortableHostMultiCollection; +} // namespace lst +#endif diff --git a/RecoTracker/LSTCore/interface/QuadrupletsSoA.h b/RecoTracker/LSTCore/interface/QuadrupletsSoA.h new file mode 100644 index 0000000000000..0fa57a35e7f40 --- /dev/null +++ b/RecoTracker/LSTCore/interface/QuadrupletsSoA.h @@ -0,0 +1,52 @@ +#ifndef RecoTracker_LSTCore_interface_QuadrupletsSoA_h +#define RecoTracker_LSTCore_interface_QuadrupletsSoA_h + +#include +#include "DataFormats/Common/interface/StdArray.h" +#include "DataFormats/SoATemplate/interface/SoALayout.h" + +#include "RecoTracker/LSTCore/interface/Common.h" + +namespace lst { + GENERATE_SOA_LAYOUT(QuadrupletsSoALayout, + SOA_COLUMN(ArrayUx2, + preAllocatedTripletIndices), // pre-allocated the theoretical max triplet indices + SOA_COLUMN(ArrayUx2, tripletIndices), // inner and outer triplet indices + SOA_COLUMN(Params_T4::ArrayU16xLayers, lowerModuleIndices), // lower module index in each layer + SOA_COLUMN(Params_T4::ArrayU8xLayers, logicalLayers), // layer ID + SOA_COLUMN(Params_T4::ArrayUxHits, hitIndices), // hit indices + SOA_COLUMN(FPX, innerRadius), // inner triplet circle radius + SOA_COLUMN(FPX, outerRadius), // outer triplet radius + SOA_COLUMN(FPX, pt), 
+ SOA_COLUMN(FPX, eta), + SOA_COLUMN(FPX, phi), + SOA_COLUMN(FPX, score_rphisum), // r-phi based score + SOA_COLUMN(char, isDup), // duplicate flag + SOA_COLUMN(bool, partOfTC), + SOA_COLUMN(float, regressionRadius), + SOA_COLUMN(float, nonAnchorRegressionRadius), + SOA_COLUMN(float, regressionCenterX), + SOA_COLUMN(float, regressionCenterY), + SOA_COLUMN(float, rzChiSquared), // r-z only chi2 + SOA_COLUMN(float, chiSquared), + SOA_COLUMN(float, nonAnchorChiSquared), + SOA_COLUMN(float, promptScore), + SOA_COLUMN(float, displacedScore), + SOA_COLUMN(float, fakeScore), + SOA_COLUMN(int, layer), + SOA_COLUMN(float, dBeta)); + + using QuadrupletsSoA = QuadrupletsSoALayout<>; + using Quadruplets = QuadrupletsSoA::View; + using QuadrupletsConst = QuadrupletsSoA::ConstView; + + GENERATE_SOA_LAYOUT(QuadrupletsOccupancySoALayout, + SOA_COLUMN(unsigned int, nQuadruplets), + SOA_COLUMN(unsigned int, totOccupancyQuadruplets)); + + using QuadrupletsOccupancySoA = QuadrupletsOccupancySoALayout<>; + using QuadrupletsOccupancy = QuadrupletsOccupancySoA::View; + using QuadrupletsOccupancyConst = QuadrupletsOccupancySoA::ConstView; + +} // namespace lst +#endif diff --git a/RecoTracker/LSTCore/interface/TrackCandidatesSoA.h b/RecoTracker/LSTCore/interface/TrackCandidatesSoA.h index 86256e2b8be0d..b6182f167f50b 100644 --- a/RecoTracker/LSTCore/interface/TrackCandidatesSoA.h +++ b/RecoTracker/LSTCore/interface/TrackCandidatesSoA.h @@ -26,7 +26,8 @@ namespace lst { SOA_SCALAR(unsigned int, nTrackCandidatespT3), SOA_SCALAR(unsigned int, nTrackCandidatespT5), SOA_SCALAR(unsigned int, nTrackCandidatespLS), - SOA_SCALAR(unsigned int, nTrackCandidatesT5)) + SOA_SCALAR(unsigned int, nTrackCandidatesT5), + SOA_SCALAR(unsigned int, nTrackCandidatesT4)) using TrackCandidatesBaseSoA = TrackCandidatesBaseSoALayout<>; using TrackCandidatesExtendedSoA = TrackCandidatesExtendedSoALayout<>; diff --git a/RecoTracker/LSTCore/interface/TripletsSoA.h b/RecoTracker/LSTCore/interface/TripletsSoA.h index 
ca5cc21738c5c..b67bf4ca5fc52 100644 --- a/RecoTracker/LSTCore/interface/TripletsSoA.h +++ b/RecoTracker/LSTCore/interface/TripletsSoA.h @@ -23,6 +23,8 @@ namespace lst { SOA_COLUMN(float, promptScore), // DNN confidence score for real (prompt) t3 SOA_COLUMN(float, displacedScore), // DNN confidence score for real (displaced) t3 SOA_COLUMN(unsigned int, connectedMax), // number of outer-triplets that pass the MD-equality cut + SOA_COLUMN(unsigned int, connectedLSMax), // n of outer-triplets that pass the LS-equality cut + SOA_COLUMN(short, charge), #ifdef CUT_VALUE_DEBUG SOA_COLUMN(float, betaInCut), #endif diff --git a/RecoTracker/LSTCore/interface/alpaka/Common.h b/RecoTracker/LSTCore/interface/alpaka/Common.h index a69528059177d..ddef1abc9c7c5 100644 --- a/RecoTracker/LSTCore/interface/alpaka/Common.h +++ b/RecoTracker/LSTCore/interface/alpaka/Common.h @@ -118,6 +118,25 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { }; } // namespace pt3dnn + namespace t4dnn { + HOST_DEVICE_CONSTANT float kZ_max = 267.2349854f; + HOST_DEVICE_CONSTANT float kR_max = 110.1099396f; + HOST_DEVICE_CONSTANT float kEta_norm = 2.5f; + constexpr unsigned int kEtaBins = 25; + + HOST_DEVICE_CONSTANT float kWp_displaced[kPtBins][kEtaBins] = { + {0.6532, 0.2885, 0.3381, 0.3925, 0.3886, 0.3998, 0.5003, 0.4532, 0.3624, 0.5571, 0.4461, 0.3688, 0.4487, + 0.4183, 0.42073, 0.4718, 0.4004, 0.3037, 0.3010, 0.2001, 0.2483, 0.2288, 0.0990, 0.0992, 0.0847}, + {0.0245, 0.0330, 0.1931, 0.0502, 0.0179, 0.8189, 0.8216, 0.5082, 0.3526, 0.2734, 0.4204, 0.0582, 0.0184, + 0.1018, 0.0899, 0.2338, 0.2594, 0.2093, 0.1854, 0.1399, 0.2743, 0.6624, 0.7046, 0.0640, 0.2394}}; + + HOST_DEVICE_CONSTANT float kWp_fake[kPtBins][kEtaBins] = { + {0.1999, 0.4224, 0.2946, 0.3265, 0.3264, 0.2734, 0.2478, 0.1879, 0.1520, 0.2460, 0.2781, 0.3844, 0.2920, + 0.3993, 0.1187, 0.0933, 0.1248, 0.1158, 0.1441, 0.0827, 0.0738, 0.0402, 0.0314, 0.0208, 0.0131}, + {0.9115, 0.9605, 0.6660, 0.7374, 0.9263, 0.1698, 0.1485, 0.3590, 0.5302, 
0.6662, 0.1273, 0.5445, 0.5916, + 0.5985, 0.7687, 0.1317, 0.2187, 0.1160, 0.4810, 0.1532, 0.3180, 0.0155, 0.0111, 0.1336, 0.1455}}; + } // namespace t4dnn + } // namespace dnn } // namespace ALPAKA_ACCELERATOR_NAMESPACE::lst diff --git a/RecoTracker/LSTCore/interface/alpaka/QuadrupletsDeviceCollection.h b/RecoTracker/LSTCore/interface/alpaka/QuadrupletsDeviceCollection.h new file mode 100644 index 0000000000000..f24e3a9856e46 --- /dev/null +++ b/RecoTracker/LSTCore/interface/alpaka/QuadrupletsDeviceCollection.h @@ -0,0 +1,12 @@ +#ifndef RecoTracker_LSTCore_interface_QuadrupletsDeviceCollection_h +#define RecoTracker_LSTCore_interface_QuadrupletsDeviceCollection_h + +#include "DataFormats/Portable/interface/alpaka/PortableCollection.h" + +#include "RecoTracker/LSTCore/interface/alpaka/Common.h" +#include "RecoTracker/LSTCore/interface/QuadrupletsSoA.h" + +namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { + using QuadrupletsDeviceCollection = PortableCollection2; +} // namespace ALPAKA_ACCELERATOR_NAMESPACE::lst +#endif diff --git a/RecoTracker/LSTCore/src/alpaka/Kernels.h b/RecoTracker/LSTCore/src/alpaka/Kernels.h index 3eafeb6025de6..49c700f48ecc6 100644 --- a/RecoTracker/LSTCore/src/alpaka/Kernels.h +++ b/RecoTracker/LSTCore/src/alpaka/Kernels.h @@ -14,6 +14,7 @@ #include "RecoTracker/LSTCore/interface/QuintupletsSoA.h" #include "RecoTracker/LSTCore/interface/SegmentsSoA.h" #include "RecoTracker/LSTCore/interface/TripletsSoA.h" +#include "RecoTracker/LSTCore/interface/QuadrupletsSoA.h" namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { ALPAKA_FN_ACC ALPAKA_FN_INLINE void rmQuintupletFromMemory(Quintuplets quintuplets, @@ -38,6 +39,12 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { pixelSegments.isDup()[pixelSegmentArrayIndex] |= 1 + secondpass; } + ALPAKA_FN_ACC ALPAKA_FN_INLINE void rmQuadrupletFromMemory(Quadruplets quadruplets, + unsigned int quadrupletIndex, + bool secondpass = false) { + quadruplets.isDup()[quadrupletIndex] |= 1 + secondpass; + }; + ALPAKA_FN_ACC 
ALPAKA_FN_INLINE int checkHitsT5(unsigned int ix, unsigned int jx, QuintupletsConst quintuplets) { unsigned int hits1[Params_T5::kHits]; unsigned int hits2[Params_T5::kHits]; @@ -142,6 +149,31 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { matched[1] = nMatched; } + ALPAKA_FN_ACC ALPAKA_FN_INLINE int checkHitsT4(unsigned int ix, unsigned int jx, QuadrupletsConst quadruplets) { + unsigned int hits1[Params_T4::kHits]; + unsigned int hits2[Params_T4::kHits]; + + for (int i = 0; i < Params_T4::kHits; i++) { + hits1[i] = quadruplets.hitIndices()[ix][i]; + hits2[i] = quadruplets.hitIndices()[jx][i]; + } + + int nMatched = 0; + for (int i = 0; i < Params_T4::kHits; i++) { + bool matched = false; + for (int j = 0; j < Params_T4::kHits; j++) { + if (hits1[i] == hits2[j]) { + matched = true; + break; + } + } + if (matched) { + nMatched++; + } + } + return nMatched; + }; + struct RemoveDupQuintupletsAfterBuild { ALPAKA_FN_ACC void operator()(Acc3D const& acc, ModulesConst modules, @@ -274,6 +306,124 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { } }; + struct RemoveDupQuadrupletsAfterBuild { + ALPAKA_FN_ACC void operator()(Acc3D const& acc, + ModulesConst modules, + Quadruplets quadruplets, + QuadrupletsOccupancyConst quadrupletsOccupancy, + ObjectRangesConst ranges) const { + for (auto lowmod : cms::alpakatools::uniform_elements_z(acc, modules.nLowerModules())) { + unsigned int nQuadruplets_lowmod = quadrupletsOccupancy.nQuadruplets()[lowmod]; + int quadrupletModuleIndices_lowmod = ranges.quadrupletModuleIndices()[lowmod]; + + for (unsigned int ix1 : cms::alpakatools::uniform_elements_y(acc, nQuadruplets_lowmod)) { + unsigned int ix = quadrupletModuleIndices_lowmod + ix1; + const float eta1 = __H2F(quadruplets.eta()[ix]); + const float phi1 = __H2F(quadruplets.phi()[ix]); + const float score1 = quadruplets.displacedScore()[ix] - quadruplets.fakeScore()[ix]; + + for (unsigned int jx1 : cms::alpakatools::uniform_elements_x(acc, ix1 + 1, nQuadruplets_lowmod)) { + unsigned 
int jx = quadrupletModuleIndices_lowmod + jx1; + + const float eta2 = __H2F(quadruplets.eta()[jx]); + const float phi2 = __H2F(quadruplets.phi()[jx]); + float dEta = alpaka::math::abs(acc, eta1 - eta2); + float dPhi = cms::alpakatools::deltaPhi(acc, phi1, phi2); + + if (dEta > 0.1f) + continue; + + if (alpaka::math::abs(acc, dPhi) > 0.1f) + continue; + + const float score2 = quadruplets.displacedScore()[jx] - quadruplets.fakeScore()[jx]; + + int nMatched = checkHitsT4(ix, jx, quadruplets); + const int minNHitsForDup_T4 = 5; + if (nMatched >= minNHitsForDup_T4) { + if (score1 >= score2) { + rmQuadrupletFromMemory(quadruplets, jx); + } else { + rmQuadrupletFromMemory(quadruplets, ix); + } + } + } + } + } + } + }; + + struct RemoveDupQuadrupletsBeforeTC { + ALPAKA_FN_ACC void operator()(Acc2D const& acc, + Quadruplets quadruplets, + QuadrupletsOccupancyConst quadrupletsOccupancy, + ObjectRangesConst ranges) const { + for (unsigned int lowmodIdx1 : cms::alpakatools::uniform_elements_y(acc, ranges.nEligibleT4Modules())) { + uint16_t lowmod1 = ranges.indicesOfEligibleT4Modules()[lowmodIdx1]; + unsigned int nQuadruplets_lowmod1 = quadrupletsOccupancy.nQuadruplets()[lowmod1]; + if (nQuadruplets_lowmod1 == 0) + continue; + + unsigned int quadrupletModuleIndices_lowmod1 = ranges.quadrupletModuleIndices()[lowmod1]; + + for (unsigned int lowmodIdx2 : + cms::alpakatools::uniform_elements_x(acc, lowmodIdx1, ranges.nEligibleT4Modules())) { + uint16_t lowmod2 = ranges.indicesOfEligibleT4Modules()[lowmodIdx2]; + unsigned int nQuadruplets_lowmod2 = quadrupletsOccupancy.nQuadruplets()[lowmod2]; + if (nQuadruplets_lowmod2 == 0) + continue; + + unsigned int quadrupletModuleIndices_lowmod2 = ranges.quadrupletModuleIndices()[lowmod2]; + + for (unsigned int ix1 = 0; ix1 < nQuadruplets_lowmod1; ix1 += 1) { + unsigned int ix = quadrupletModuleIndices_lowmod1 + ix1; + if ((quadruplets.isDup()[ix] & 1)) + continue; + + const float eta1 = __H2F(quadruplets.eta()[ix]); + const float phi1 = 
__H2F(quadruplets.phi()[ix]); + const float score1 = quadruplets.displacedScore()[ix] - quadruplets.fakeScore()[ix]; + + for (unsigned int jx1 = 0; jx1 < nQuadruplets_lowmod2; jx1++) { + unsigned int jx = quadrupletModuleIndices_lowmod2 + jx1; + if (ix == jx) + continue; + + if ((quadruplets.isDup()[jx] & 1)) + continue; + + const float eta2 = __H2F(quadruplets.eta()[jx]); + const float phi2 = __H2F(quadruplets.phi()[jx]); + float dEta = alpaka::math::abs(acc, eta1 - eta2); + float dPhi = cms::alpakatools::deltaPhi(acc, phi1, phi2); + + if (dEta > 0.1f) + continue; + + if (alpaka::math::abs(acc, dPhi) > 0.1f) + continue; + + const float score2 = quadruplets.displacedScore()[jx] - quadruplets.fakeScore()[jx]; + + float dR2 = dEta * dEta + dPhi * dPhi; + int nMatched = checkHitsT4(ix, jx, quadruplets); + const int minNHitsForDup_T4 = 4; + if (dR2 < 0.001f || nMatched >= minNHitsForDup_T4) { + if (score1 > score2) { + rmQuadrupletFromMemory(quadruplets, jx, true); + } else if (score1 < score2) { + rmQuadrupletFromMemory(quadruplets, ix, true); + } else { + rmQuadrupletFromMemory(quadruplets, (ix < jx ? 
ix : jx), true); + } + } + } + } + } + } + } + }; + struct RemoveDupPixelTripletsFromMap { ALPAKA_FN_ACC void operator()(Acc2D const& acc, PixelTriplets pixelTriplets) const { for (unsigned int ix : cms::alpakatools::uniform_elements_y(acc, pixelTriplets.nPixelTriplets())) { diff --git a/RecoTracker/LSTCore/src/alpaka/LST.cc b/RecoTracker/LSTCore/src/alpaka/LST.cc index 501946aaf861e..055309f7e4f7b 100644 --- a/RecoTracker/LSTCore/src/alpaka/LST.cc +++ b/RecoTracker/LSTCore/src/alpaka/LST.cc @@ -105,6 +105,23 @@ void LST::run(Queue& queue, printf("# of Pixel T3s produced: %d\n", event.getNumberOfPixelTriplets()); } + event.createQuadruplets(); + if (verbose) { + alpaka::wait(queue); // event calls are asynchronous: wait before printing + printf("# of Quadruplets produced: %d\n", event.getNumberOfQuadruplets()); + printf("# of Quadruplets produced layer 1-2-3-4: %d\n", event.getNumberOfQuadrupletsByLayerBarrel(0)); + printf("# of Quadruplets produced layer 2: %d\n", event.getNumberOfQuadrupletsByLayerBarrel(1)); + printf("# of Quadruplets produced layer 3: %d\n", event.getNumberOfQuadrupletsByLayerBarrel(2)); + printf("# of Quadruplets produced layer 4: %d\n", event.getNumberOfQuadrupletsByLayerBarrel(3)); + printf("# of Quadruplets produced layer 5: %d\n", event.getNumberOfQuadrupletsByLayerBarrel(4)); + printf("# of Quadruplets produced layer 6: %d\n", event.getNumberOfQuadrupletsByLayerBarrel(5)); + printf("# of Quadruplets produced endcap layer 1: %d\n", event.getNumberOfQuadrupletsByLayerEndcap(0)); + printf("# of Quadruplets produced endcap layer 2: %d\n", event.getNumberOfQuadrupletsByLayerEndcap(1)); + printf("# of Quadruplets produced endcap layer 3: %d\n", event.getNumberOfQuadrupletsByLayerEndcap(2)); + printf("# of Quadruplets produced endcap layer 4: %d\n", event.getNumberOfQuadrupletsByLayerEndcap(3)); + printf("# of Quadruplets produced endcap layer 5: %d\n", event.getNumberOfQuadrupletsByLayerEndcap(4)); + } + 
event.createTrackCandidates(no_pls_dupclean, tc_pls_triplets); if (verbose) { alpaka::wait(queue); // event calls are asynchronous: wait before printing @@ -114,6 +131,7 @@ void LST::run(Queue& queue, printf(" # of pT3 TrackCandidates produced: %d\n", event.getNumberOfPT3TrackCandidates()); printf(" # of pLS TrackCandidates produced: %d\n", event.getNumberOfPLSTrackCandidates()); printf(" # of T5 TrackCandidates produced: %d\n", event.getNumberOfT5TrackCandidates()); + printf(" # of T4 TrackCandidates produced: %d\n", event.getNumberOfT4TrackCandidates()); } trackCandidatesBaseDC_ = event.releaseTrackCandidatesBaseDeviceCollection(); diff --git a/RecoTracker/LSTCore/src/alpaka/LSTEvent.dev.cc b/RecoTracker/LSTCore/src/alpaka/LSTEvent.dev.cc index 38ff64568a502..a683bdbd03a28 100644 --- a/RecoTracker/LSTCore/src/alpaka/LSTEvent.dev.cc +++ b/RecoTracker/LSTCore/src/alpaka/LSTEvent.dev.cc @@ -13,6 +13,7 @@ #include "Segment.h" #include "TrackCandidate.h" #include "Triplet.h" +#include "Quadruplet.h" using Device = ALPAKA_ACCELERATOR_NAMESPACE::Device; using Queue = ALPAKA_ACCELERATOR_NAMESPACE::Queue; @@ -30,11 +31,13 @@ void LSTEvent::initSync() { n_segments_by_layer_barrel_[i] = 0; n_triplets_by_layer_barrel_[i] = 0; n_quintuplets_by_layer_barrel_[i] = 0; + n_quadruplets_by_layer_barrel_[i] = 0; if (i < 5) { n_minidoublets_by_layer_endcap_[i] = 0; n_segments_by_layer_endcap_[i] = 0; n_triplets_by_layer_endcap_[i] = 0; n_quintuplets_by_layer_endcap_[i] = 0; + n_quadruplets_by_layer_endcap_[i] = 0; } } } @@ -47,11 +50,13 @@ void LSTEvent::resetEventSync() { n_segments_by_layer_barrel_[i] = 0; n_triplets_by_layer_barrel_[i] = 0; n_quintuplets_by_layer_barrel_[i] = 0; + n_quadruplets_by_layer_barrel_[i] = 0; if (i < 5) { n_minidoublets_by_layer_endcap_[i] = 0; n_segments_by_layer_endcap_[i] = 0; n_triplets_by_layer_endcap_[i] = 0; n_quintuplets_by_layer_endcap_[i] = 0; + n_quadruplets_by_layer_endcap_[i] = 0; } } lstInputDC_ = nullptr; @@ -66,6 +71,7 @@ void 
LSTEvent::resetEventSync() { trackCandidatesExtendedDC_.reset(); pixelTripletsDC_.reset(); pixelQuintupletsDC_.reset(); + quadrupletsDC_.reset(); lstInputHC_.reset(); hitsHC_.reset(); @@ -80,6 +86,7 @@ void LSTEvent::resetEventSync() { trackCandidatesBaseHC_.reset(); trackCandidatesExtendedHC_.reset(); modulesHC_.reset(); + quadrupletsHC_.reset(); } void LSTEvent::addInputToEvent(LSTInputDeviceCollection const* lstInputDC) { @@ -388,6 +395,9 @@ void LSTEvent::createTriplets() { alpaka::memset(queue_, partOfPT3_view, 0u); auto connectedMax_view = cms::alpakatools::make_device_view(queue_, triplets.connectedMax()); alpaka::memset(queue_, connectedMax_view, 0u); + auto connectedLSMax_view = + cms::alpakatools::make_device_view(queue_, triplets.connectedLSMax(), triplets.metadata().size()); + alpaka::memset(queue_, connectedLSMax_view, 0u); } uint16_t nonZeroModules = 0; @@ -448,15 +458,17 @@ void LSTEvent::createTriplets() { auto const addTripletRangesToEventExplicit_workDiv = cms::alpakatools::make_workdiv(1, 1024); - alpaka::exec(queue_, - addTripletRangesToEventExplicit_workDiv, - AddTripletRangesToEventExplicit{}, - modules_.const_view(), - tripletsDC_->const_view(), - rangesDC_->view()); + if (nonZeroModules != 0) { + alpaka::exec(queue_, + addTripletRangesToEventExplicit_workDiv, + AddTripletRangesToEventExplicit{}, + modules_.const_view(), + tripletsDC_->const_view(), + rangesDC_->view()); - if (addObjects_) { - addTripletsToEventExplicit(); + if (addObjects_) { + addTripletsToEventExplicit(); + } } } @@ -538,6 +550,55 @@ void LSTEvent::createTrackCandidates(bool no_pls_dupclean, bool tc_pls_triplets) trackCandidatesExtendedDC_->view(), rangesDC_->const_view()); + auto nEligibleModulesT4_buf_h = cms::alpakatools::make_host_buffer(queue_); + auto nEligibleModulesT4_buf_d = cms::alpakatools::make_device_view(queue_, rangesOccupancy.nEligibleT4Modules()); + alpaka::memcpy(queue_, nEligibleModulesT4_buf_h, nEligibleModulesT4_buf_d); + alpaka::wait(queue_); // wait 
to get the value before using + auto const nEligibleModulesT4 = *nEligibleModulesT4_buf_h.data(); + + auto const removeDupQuadrupletsBeforeTC_workDiv = cms::alpakatools::make_workdiv( + {std::max(nEligibleModulesT4 / threadsPerBlockY, 1), std::max(nEligibleModulesT4 / threadsPerBlockX, 1)}, + {16, 32}); + + alpaka::exec(queue_, + removeDupQuadrupletsBeforeTC_workDiv, + RemoveDupQuadrupletsBeforeTC{}, + quadrupletsDC_->view(), + quadrupletsDC_->view(), + rangesDC_->const_view()); + + auto const crossCleanT4_workDiv = cms::alpakatools::make_workdiv( + {(nLowerModules_ / threadsPerBlock) + 1, 1, max_blocks}, {threadsPerBlock, 1, threadsPerBlock}); + + alpaka::exec(queue_, + crossCleanT4_workDiv, + CrossCleanT4{}, + modules_.const_view(), + quadrupletsDC_->view(), + quadrupletsDC_->const_view(), + pixelQuintupletsDC_->const_view(), + pixelTripletsDC_->const_view(), + quintupletsDC_->const_view(), + trackCandidatesBaseDC_->view(), + trackCandidatesExtendedDC_->view(), + miniDoubletsDC_->view(), + segmentsDC_->view(), + tripletsDC_->view(), + rangesDC_->const_view()); + + auto const addT4asTrackCandidate_workDiv = cms::alpakatools::make_workdiv({8, 10}, {8, 128}); + + alpaka::exec(queue_, + addT4asTrackCandidate_workDiv, + AddT4asTrackCandidate{}, + nLowerModules_, + quadrupletsDC_->view(), + quadrupletsDC_->const_view(), + tripletsDC_->const_view(), + trackCandidatesBaseDC_->view(), + trackCandidatesExtendedDC_->view(), + rangesDC_->const_view()); + if (!no_pls_dupclean) { auto const checkHitspLS_workDiv = cms::alpakatools::make_workdiv({max_blocks * 4, max_blocks / 4}, {16, 16}); @@ -567,7 +628,8 @@ void LSTEvent::createTrackCandidates(bool no_pls_dupclean, bool tc_pls_triplets) pixelSegmentsDC_->view(), miniDoubletsDC_->const_view(), lstInputDC_->const_view(), - quintupletsDC_->const_view()); + quintupletsDC_->const_view(), + quadrupletsDC_->const_view()); auto const addpLSasTrackCandidate_workDiv = cms::alpakatools::make_workdiv(max_blocks, 384); @@ -587,6 +649,7 @@ 
void LSTEvent::createTrackCandidates(bool no_pls_dupclean, bool tc_pls_triplets) auto nTrackCanpT3Host_buf = cms::alpakatools::make_host_buffer(queue_); auto nTrackCanpLSHost_buf = cms::alpakatools::make_host_buffer(queue_); auto nTrackCanT5Host_buf = cms::alpakatools::make_host_buffer(queue_); + auto nTrackCanT4Host_buf = cms::alpakatools::make_host_buffer(queue_); alpaka::memcpy(queue_, nTrackCanpT5Host_buf, cms::alpakatools::make_device_view(queue_, (*trackCandidatesExtendedDC_)->nTrackCandidatespT5())); @@ -599,14 +662,18 @@ void LSTEvent::createTrackCandidates(bool no_pls_dupclean, bool tc_pls_triplets) alpaka::memcpy(queue_, nTrackCanT5Host_buf, cms::alpakatools::make_device_view(queue_, (*trackCandidatesExtendedDC_)->nTrackCandidatesT5())); + alpaka::memcpy(queue_, + nTrackCanT4Host_buf, + cms::alpakatools::make_device_view(queue_, (*trackCandidatesExtendedDC_)->nTrackCandidatesT4())); alpaka::wait(queue_); // wait to get the values before using them auto nTrackCandidatespT5 = *nTrackCanpT5Host_buf.data(); auto nTrackCandidatespT3 = *nTrackCanpT3Host_buf.data(); auto nTrackCandidatespLS = *nTrackCanpLSHost_buf.data(); auto nTrackCandidatesT5 = *nTrackCanT5Host_buf.data(); + auto nTrackCandidatesT4 = *nTrackCanT4Host_buf.data(); if ((nTrackCandidatespT5 + nTrackCandidatespT3 + nTrackCandidatespLS == n_max_pixel_track_candidates) || - (nTrackCandidatesT5 == n_max_nonpixel_track_candidates)) { + (nTrackCandidatesT5 + nTrackCandidatesT4 == n_max_nonpixel_track_candidates)) { lstWarning( "\ ****************************************************************************************************\n\ @@ -998,6 +1065,99 @@ void LSTEvent::createPixelQuintuplets() { #endif } +void LSTEvent::createQuadruplets() { + auto const countLSConn_workDiv = cms::alpakatools::make_workdiv({nLowerModules_, 1, 1}, {1, 8, 32}); + + alpaka::exec(queue_, + countLSConn_workDiv, + CountTripletLSConnections{}, + modules_.const_view(), + miniDoubletsDC_->const_view(), + 
segmentsDC_->const_view(), + tripletsDC_->view(), + tripletsDC_->const_view(), + rangesDC_->const_view(), + ptCut_); + + auto const createEligibleModulesListForQuadruplets_workDiv = cms::alpakatools::make_workdiv(1, 1024); + + alpaka::exec(queue_, + createEligibleModulesListForQuadruplets_workDiv, + CreateEligibleModulesListForQuadruplets{}, + modules_.const_view(), + tripletsDC_->const_view(), + rangesDC_->view(), + tripletsDC_->view()); + + auto nEligibleT4Modules_buf = cms::alpakatools::make_host_buffer(queue_); + auto nTotalQuadruplets_buf = cms::alpakatools::make_host_buffer(queue_); + auto rangesOccupancy = rangesDC_->view(); + auto nEligibleT4Modules_view_d = cms::alpakatools::make_device_view(queue_, rangesOccupancy.nEligibleT4Modules()); + auto nTotalQuadruplets_view_d = cms::alpakatools::make_device_view(queue_, rangesOccupancy.nTotalQuads()); + alpaka::memcpy(queue_, nEligibleT4Modules_buf, nEligibleT4Modules_view_d); + alpaka::memcpy(queue_, nTotalQuadruplets_buf, nTotalQuadruplets_view_d); + alpaka::wait(queue_); // wait for the values before using them + + auto nEligibleT4Modules = *nEligibleT4Modules_buf.data(); + auto nTotalQuadruplets = *nTotalQuadruplets_buf.data(); + + if (!quadrupletsDC_) { + std::array const quadruplets_sizes{{static_cast(nTotalQuadruplets), static_cast(nLowerModules_)}}; + quadrupletsDC_.emplace(quadruplets_sizes, queue_); + auto quadrupletsOccupancy = quadrupletsDC_->view(); + auto nQuadruplets_view = cms::alpakatools::make_device_view( + queue_, quadrupletsOccupancy.nQuadruplets(), quadrupletsOccupancy.metadata().size()); + alpaka::memset(queue_, nQuadruplets_view, 0u); + auto totOccupancyQuadruplets_view = cms::alpakatools::make_device_view( + queue_, quadrupletsOccupancy.totOccupancyQuadruplets(), quadrupletsOccupancy.metadata().size()); + alpaka::memset(queue_, totOccupancyQuadruplets_view, 0u); + auto quadruplets = quadrupletsDC_->view(); + auto isDup_view = cms::alpakatools::make_device_view(queue_, quadruplets.isDup(), 
quadruplets.metadata().size()); + alpaka::memset(queue_, isDup_view, 0u); + } + + auto const createQuadruplets_workDiv = + cms::alpakatools::make_workdiv({std::max((int)nEligibleT4Modules, 1), 1, 1}, {1, 8, 32}); + + alpaka::exec(queue_, + createQuadruplets_workDiv, + CreateQuadruplets{}, + modules_.const_view(), + miniDoubletsDC_->const_view(), + segmentsDC_->const_view(), + tripletsDC_->view(), + tripletsDC_->const_view(), + quadrupletsDC_->view(), + quadrupletsDC_->view(), + rangesDC_->const_view(), + nEligibleT4Modules, + ptCut_); + + auto const removeDupQuadrupletsAfterBuild_workDiv = + cms::alpakatools::make_workdiv({max_blocks, 1, 1}, {1, 16, 16}); + + alpaka::exec(queue_, + removeDupQuadrupletsAfterBuild_workDiv, + RemoveDupQuadrupletsAfterBuild{}, + modules_.const_view(), + quadrupletsDC_->view(), + quadrupletsDC_->const_view(), + rangesDC_->const_view()); + + auto const addQuadrupletRangesToEventExplicit_workDiv = cms::alpakatools::make_workdiv(1, 1024); + + alpaka::exec(queue_, + addQuadrupletRangesToEventExplicit_workDiv, + AddQuadrupletRangesToEventExplicit{}, + modules_.const_view(), + quadrupletsDC_->const_view(), + rangesDC_->view()); + + if (addObjects_) { + addQuadrupletsToEventExplicit(); + } +} + void LSTEvent::addMiniDoubletsToEventExplicit() { auto nMDsCPU_buf = cms::alpakatools::make_host_buffer(queue_, nLowerModules_); auto mdsOccupancy = miniDoubletsDC_->const_view(); @@ -1151,6 +1311,43 @@ void LSTEvent::addTripletsToEventExplicit() { } } +void LSTEvent::addQuadrupletsToEventExplicit() { + auto quadrupletsOccupancy = quadrupletsDC_->const_view(); + auto nQuadruplets_view = + cms::alpakatools::make_device_view(queue_, quadrupletsOccupancy.nQuadruplets(), nLowerModules_); + auto nQuadrupletsCPU_buf = cms::alpakatools::make_host_buffer(queue_, nLowerModules_); + alpaka::memcpy(queue_, nQuadrupletsCPU_buf, nQuadruplets_view); + + auto modules = modules_.const_view(); + + // FIXME: replace by ES host data + auto module_subdets_buf = 
cms::alpakatools::make_host_buffer(queue_, nLowerModules_); + auto module_subdets_view = + cms::alpakatools::make_device_view(queue_, modules.subdets(), nLowerModules_); // only lower modules + alpaka::memcpy(queue_, module_subdets_buf, module_subdets_view, nLowerModules_); + + auto module_layers_buf = cms::alpakatools::make_host_buffer(queue_, nLowerModules_); + auto module_layers_view = + cms::alpakatools::make_device_view(queue_, modules.layers(), nLowerModules_); // only lower modules + alpaka::memcpy(queue_, module_layers_buf, module_layers_view, nLowerModules_); + + alpaka::wait(queue_); // wait for inputs before using them + + auto const* nQuadrupletsCPU = nQuadrupletsCPU_buf.data(); + auto const* module_subdets = module_subdets_buf.data(); + auto const* module_layers = module_layers_buf.data(); + + for (uint16_t i = 0; i < nLowerModules_; i++) { + if (nQuadrupletsCPU[i] != 0) { + if (module_subdets[i] == Barrel) { + n_quadruplets_by_layer_barrel_[module_layers[i] - 1] += nQuadrupletsCPU[i]; + } else { + n_quadruplets_by_layer_endcap_[module_layers[i] - 1] += nQuadrupletsCPU[i]; + } + } + } +} + unsigned int LSTEvent::getNumberOfMiniDoublets() { unsigned int miniDoublets = 0; for (auto& it : n_minidoublets_by_layer_barrel_) { @@ -1299,6 +1496,7 @@ int LSTEvent::getNumberOfPLSTrackCandidates() { int LSTEvent::getNumberOfPixelTrackCandidates() { auto nTrackCandidates_buf_h = cms::alpakatools::make_host_buffer(queue_); auto nTrackCandidatesT5_buf_h = cms::alpakatools::make_host_buffer(queue_); + auto nTrackCandidatesT4_buf_h = cms::alpakatools::make_host_buffer(queue_); alpaka::memcpy(queue_, nTrackCandidates_buf_h, @@ -1306,9 +1504,12 @@ int LSTEvent::getNumberOfPixelTrackCandidates() { alpaka::memcpy(queue_, nTrackCandidatesT5_buf_h, cms::alpakatools::make_device_view(queue_, (*trackCandidatesExtendedDC_)->nTrackCandidatesT5())); + alpaka::memcpy(queue_, + nTrackCandidatesT4_buf_h, + cms::alpakatools::make_device_view(queue_, 
(*trackCandidatesExtendedDC_)->nTrackCandidatesT4())); alpaka::wait(queue_); - return (*nTrackCandidates_buf_h.data()) - (*nTrackCandidatesT5_buf_h.data()); + return (*nTrackCandidates_buf_h.data()) - (*nTrackCandidatesT5_buf_h.data()) - (*nTrackCandidatesT4_buf_h.data()); } int LSTEvent::getNumberOfT5TrackCandidates() { @@ -1322,6 +1523,37 @@ int LSTEvent::getNumberOfT5TrackCandidates() { return *nTrackCandidatesT5_buf_h.data(); } +int LSTEvent::getNumberOfT4TrackCandidates() { + auto nTrackCandidatesT4_buf_h = cms::alpakatools::make_host_buffer(queue_); + + alpaka::memcpy(queue_, + nTrackCandidatesT4_buf_h, + cms::alpakatools::make_device_view(queue_, (*trackCandidatesExtendedDC_)->nTrackCandidatesT4())); + alpaka::wait(queue_); + + return *nTrackCandidatesT4_buf_h.data(); +} + +unsigned int LSTEvent::getNumberOfQuadruplets() { + unsigned int quadruplets = 0; + for (auto& it : n_quadruplets_by_layer_barrel_) { + quadruplets += it; + } + for (auto& it : n_quadruplets_by_layer_endcap_) { + quadruplets += it; + } + + return quadruplets; +} + +unsigned int LSTEvent::getNumberOfQuadrupletsByLayerBarrel(unsigned int layer) { + return n_quadruplets_by_layer_barrel_[layer]; +} + +unsigned int LSTEvent::getNumberOfQuadrupletsByLayerEndcap(unsigned int layer) { + return n_quadruplets_by_layer_endcap_[layer]; +} + template typename TSoA::ConstView LSTEvent::getInput(bool sync) { if constexpr (std::is_same_v) { @@ -1448,6 +1680,25 @@ typename TSoA::ConstView LSTEvent::getTriplets(bool sync) { template TripletsConst LSTEvent::getTriplets(bool); template TripletsOccupancyConst LSTEvent::getTriplets(bool); +template +typename TSoA::ConstView LSTEvent::getQuadruplets(bool sync) { + if constexpr (std::is_same_v) { + return quadrupletsDC_->const_view(); + } else { + if (!quadrupletsHC_) { + quadrupletsHC_.emplace( + cms::alpakatools::CopyToHost>::copyAsync( + queue_, *quadrupletsDC_)); + + if (sync) + alpaka::wait(queue_); // host consumers expect filled data + } + } + return 
quadrupletsHC_->const_view(); +} +template QuadrupletsConst LSTEvent::getQuadruplets(bool); +template QuadrupletsOccupancyConst LSTEvent::getQuadruplets(bool); + template typename TSoA::ConstView LSTEvent::getQuintuplets(bool sync) { if constexpr (std::is_same_v) { diff --git a/RecoTracker/LSTCore/src/alpaka/LSTEvent.h b/RecoTracker/LSTCore/src/alpaka/LSTEvent.h index 377e0195d48f2..558c074e05101 100644 --- a/RecoTracker/LSTCore/src/alpaka/LSTEvent.h +++ b/RecoTracker/LSTCore/src/alpaka/LSTEvent.h @@ -10,6 +10,7 @@ #include "RecoTracker/LSTCore/interface/PixelQuintupletsHostCollection.h" #include "RecoTracker/LSTCore/interface/PixelTripletsHostCollection.h" #include "RecoTracker/LSTCore/interface/QuintupletsHostCollection.h" +#include "RecoTracker/LSTCore/interface/QuadrupletsHostCollection.h" #include "RecoTracker/LSTCore/interface/SegmentsHostCollection.h" #include "RecoTracker/LSTCore/interface/PixelSegmentsHostCollection.h" #include "RecoTracker/LSTCore/interface/TrackCandidatesHostCollection.h" @@ -24,6 +25,7 @@ #include "RecoTracker/LSTCore/interface/alpaka/PixelQuintupletsDeviceCollection.h" #include "RecoTracker/LSTCore/interface/alpaka/PixelTripletsDeviceCollection.h" #include "RecoTracker/LSTCore/interface/alpaka/QuintupletsDeviceCollection.h" +#include "RecoTracker/LSTCore/interface/alpaka/QuadrupletsDeviceCollection.h" #include "RecoTracker/LSTCore/interface/alpaka/SegmentsDeviceCollection.h" #include "RecoTracker/LSTCore/interface/alpaka/PixelSegmentsDeviceCollection.h" #include "RecoTracker/LSTCore/interface/alpaka/TrackCandidatesDeviceCollection.h" @@ -50,6 +52,8 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { std::array n_triplets_by_layer_endcap_{}; std::array n_quintuplets_by_layer_barrel_{}; std::array n_quintuplets_by_layer_endcap_{}; + std::array n_quadruplets_by_layer_barrel_{}; + std::array n_quadruplets_by_layer_endcap_{}; unsigned int nTotalSegments_; unsigned int pixelSize_; uint16_t pixelModuleIndex_; @@ -63,6 +67,7 @@ namespace 
ALPAKA_ACCELERATOR_NAMESPACE::lst { std::optional pixelSegmentsDC_; std::optional tripletsDC_; std::optional quintupletsDC_; + std::optional quadrupletsDC_; std::optional trackCandidatesBaseDC_; std::optional trackCandidatesExtendedDC_; std::optional pixelTripletsDC_; @@ -82,6 +87,7 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { std::optional quintupletsHC_; std::optional pixelTripletsHC_; std::optional pixelQuintupletsHC_; + std::optional quadrupletsHC_; const uint16_t nModules_; const uint16_t nLowerModules_; @@ -130,6 +136,7 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { void createQuintuplets(); void pixelLineSegmentCleaning(bool no_pls_dupclean); void createPixelQuintuplets(); + void createQuadruplets(); // functions that map the objects to the appropriate modules void addMiniDoubletsToEventExplicit(); @@ -137,6 +144,7 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { void addQuintupletsToEventExplicit(); void addTripletsToEventExplicit(); void resetObjectsInModule(); + void addQuadrupletsToEventExplicit(); unsigned int getNumberOfMiniDoublets(); unsigned int getNumberOfMiniDoubletsByLayerBarrel(unsigned int layer); @@ -163,6 +171,11 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { int getNumberOfPLSTrackCandidates(); int getNumberOfPixelTrackCandidates(); int getNumberOfT5TrackCandidates(); + int getNumberOfT4TrackCandidates(); + + unsigned int getNumberOfQuadruplets(); + unsigned int getNumberOfQuadrupletsByLayerBarrel(unsigned int layer); + unsigned int getNumberOfQuadrupletsByLayerEndcap(unsigned int layer); // sync adds alpaka::wait at the end of filling a buffer during lazy fill // (has no effect on repeated calls) @@ -181,6 +194,8 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { template typename TSoA::ConstView getTriplets(bool sync = true); template + typename TSoA::ConstView getQuadruplets(bool sync = true); + template typename TSoA::ConstView getQuintuplets(bool sync = true); template PixelTripletsConst getPixelTriplets(bool sync = true); diff 
--git a/RecoTracker/LSTCore/src/alpaka/NeuralNetwork.h b/RecoTracker/LSTCore/src/alpaka/NeuralNetwork.h index 8494b34c09c3d..9e1ff995188d2 100644 --- a/RecoTracker/LSTCore/src/alpaka/NeuralNetwork.h +++ b/RecoTracker/LSTCore/src/alpaka/NeuralNetwork.h @@ -12,6 +12,7 @@ #include "pT3NeuralNetworkWeights.h" #include "T5EmbedNetworkWeights.h" #include "pLSEmbedNetworkWeights.h" +#include "T4NeuralNetworkWeights.h" namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { @@ -432,6 +433,131 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { } // namespace plsembdnn + namespace t4dnn { + template + ALPAKA_FN_ACC ALPAKA_FN_INLINE bool runInference(TAcc const& acc, + MiniDoubletsConst mds, + ModulesConst modules, + const unsigned int mdIndex1, + const unsigned int mdIndex2, + const unsigned int mdIndex3, + const unsigned int mdIndex4, + uint16_t lowerModuleIndex1, + uint16_t lowerModuleIndex2, + uint16_t lowerModuleIndex3, + uint16_t lowerModuleIndex4, + const float innerRadius, + const float outerRadius, + float& promptScore, + float& displacedScore, + float& fakeScore, + const float regressionRadius, + const float nonAnchorRegressionRadius, + float fakeScore1, + float promptScore1, + float displacedScore1, + float fakeScore2, + float promptScore2, + float displacedScore2) { + // Constants + constexpr unsigned int kinputFeatures = 30; + constexpr unsigned int khiddenFeatures = 32; + constexpr unsigned int koutputFeatures = 3; + + float eta1 = alpaka::math::abs(acc, mds.anchorEta()[mdIndex1]); // inner T3 anchor hit 1 eta (t3_0_eta) + float eta2 = alpaka::math::abs(acc, mds.anchorEta()[mdIndex2]); // inner T3 anchor hit 2 eta (t3_2_eta) + float eta3 = alpaka::math::abs(acc, mds.anchorEta()[mdIndex3]); // inner T3 anchor hit 3 eta (t3_4_eta) + float eta4 = alpaka::math::abs(acc, mds.anchorEta()[mdIndex4]); // outer T3 anchor hit 4 eta (t3_2_eta) + + float phi1 = mds.anchorPhi()[mdIndex1]; // inner T3 anchor hit 1 phi + float phi2 = mds.anchorPhi()[mdIndex2]; // inner T3 anchor hit 2 phi 
+ float phi3 = mds.anchorPhi()[mdIndex3]; // inner T3 anchor hit 3 phi + float phi4 = mds.anchorPhi()[mdIndex4]; // outer T3 anchor hit 4 phi + + float z1 = alpaka::math::abs(acc, mds.anchorZ()[mdIndex1]); // inner T3 anchor hit 1 z (t3_0_z) + float z2 = alpaka::math::abs(acc, mds.anchorZ()[mdIndex2]); // inner T3 anchor hit 2 z (t3_2_z) + float z3 = alpaka::math::abs(acc, mds.anchorZ()[mdIndex3]); // inner T3 anchor hit 3 z (t3_4_z) + float z4 = alpaka::math::abs(acc, mds.anchorZ()[mdIndex4]); // outer T3 anchor hit 4 z (t3_2_z) + + float r1 = mds.anchorRt()[mdIndex1]; // inner T3 anchor hit 1 r (t3_0_r) + float r2 = mds.anchorRt()[mdIndex2]; // inner T3 anchor hit 2 r (t3_2_r) + float r3 = mds.anchorRt()[mdIndex3]; // inner T3 anchor hit 3 r (t3_4_r) + float r4 = mds.anchorRt()[mdIndex4]; // outer T3 anchor hit 4 r (t3_2_r) + + // Build the input feature vector using pairwise differences after the first hit + float x[kinputFeatures] = { + eta1 / dnn::t4dnn::kEta_norm, // inner T3: First hit eta normalized + alpaka::math::abs(acc, phi1) / dnn::kPhi_norm, // inner T3: First hit phi normalized + z1 / dnn::t4dnn::kZ_max, // inner T3: First hit z normalized + r1 / dnn::t4dnn::kR_max, // inner T3: First hit r normalized + + eta2 - eta1, // inner T3: Difference in eta between hit 2 and 1 + cms::alpakatools::deltaPhi(acc, phi2, phi1) / + dnn::kPhi_norm, // inner T3: Difference in phi between hit 2 and 1 + (z2 - z1) / dnn::t4dnn::kZ_max, // inner T3: Difference in z between hit 2 and 1 normalized + (r2 - r1) / dnn::t4dnn::kR_max, // inner T3: Difference in r between hit 2 and 1 normalized + + eta3 - eta2, // inner T3: Difference in eta between hit 3 and 2 + cms::alpakatools::deltaPhi(acc, phi3, phi2) / + dnn::kPhi_norm, // inner T3: Difference in phi between hit 3 and 2 + (z3 - z2) / dnn::t4dnn::kZ_max, // inner T3: Difference in z between hit 3 and 2 normalized + (r3 - r2) / dnn::t4dnn::kR_max, // inner T3: Difference in r between hit 3 and 2 normalized + + eta4 - eta3, 
// outer T3: Difference in eta between hit 4 and 3 + cms::alpakatools::deltaPhi(acc, phi4, phi3) / + dnn::kPhi_norm, // inner T3: Difference in phi between hit 4 and 3 + (z4 - z3) / dnn::t4dnn::kZ_max, // outer T3: Difference in z between hit 4 and 3 normalized + (r4 - r3) / dnn::t4dnn::kR_max, // outer T3: Difference in r between hit 4 and 3 normalized + + 1.0f / innerRadius, // T4 inner radius (t4_innerRadius) + 1.0f / outerRadius, // T4 outer radius (t4_outerRadius) + innerRadius / outerRadius, + 1.0f / regressionRadius, + 1.0f / nonAnchorRegressionRadius, + fakeScore1, + promptScore1, + displacedScore1, + fakeScore2, + promptScore2, + displacedScore2, + (fakeScore2 - fakeScore1), + (promptScore2 - promptScore1), + (displacedScore2 - displacedScore1), + }; + + float x_1[khiddenFeatures]; // Layer 1 output + float x_2[khiddenFeatures]; // Layer 2 output + float x_3[koutputFeatures]; // Layer 3 output + + // Layer 1: Linear + Relu + linear_layer(x, x_1, dnn::t4dnn::wgtT_layer1, dnn::t4dnn::bias_layer1); + relu_activation(x_1); + + // Layer 2: Linear + Relu + linear_layer(x_1, x_2, dnn::t4dnn::wgtT_layer2, dnn::t4dnn::bias_layer2); + relu_activation(x_2); + + // Layer 3: Linear + Softmax + linear_layer( + x_2, x_3, dnn::t4dnn::wgtT_output_layer, dnn::t4dnn::bias_output_layer); + softmax_activation(acc, x_3); + + // Get the bin index based on abs(eta) of first hit and t4_pt + float t4_pt = (innerRadius + outerRadius) * lst::k2Rinv1GeVf; //t4 pt is average + + uint8_t pt_index = (t4_pt > 5.f); + uint8_t bin_index = (eta1 > 2.5f) ? 
(dnn::t4dnn::kEtaBins - 1) : static_cast(eta1 / 0.1f); + + promptScore = x_3[1]; + displacedScore = x_3[2]; + fakeScore = x_3[0]; + + return (x_3[2] > dnn::t4dnn::kWp_displaced[pt_index][bin_index]) && + (x_3[0] < dnn::t4dnn::kWp_fake[pt_index][bin_index]); + } + + } //namespace t4dnn + } // namespace ALPAKA_ACCELERATOR_NAMESPACE::lst #endif diff --git a/RecoTracker/LSTCore/src/alpaka/Quadruplet.h b/RecoTracker/LSTCore/src/alpaka/Quadruplet.h new file mode 100644 index 0000000000000..b3299a26e62ce --- /dev/null +++ b/RecoTracker/LSTCore/src/alpaka/Quadruplet.h @@ -0,0 +1,995 @@ +#ifndef RecoTracker_LSTCore_src_alpaka_Quadruplet_h +#define RecoTracker_LSTCore_src_alpaka_Quadruplet_h + +#include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h" +#include "FWCore/Utilities/interface/isFinite.h" +#include "FWCore/Utilities/interface/CMSUnrollLoop.h" + +#include "RecoTracker/LSTCore/interface/ObjectRangesSoA.h" +#include "RecoTracker/LSTCore/interface/MiniDoubletsSoA.h" +#include "RecoTracker/LSTCore/interface/SegmentsSoA.h" +#include "RecoTracker/LSTCore/interface/TripletsSoA.h" +#include "RecoTracker/LSTCore/interface/QuadrupletsSoA.h" +#include "RecoTracker/LSTCore/interface/alpaka/Common.h" +#include "RecoTracker/LSTCore/interface/ModulesSoA.h" +#include "RecoTracker/LSTCore/interface/EndcapGeometry.h" +#include "RecoTracker/LSTCore/interface/ObjectRangesSoA.h" +#include "RecoTracker/LSTCore/interface/Circle.h" + +#include "Quintuplet.h" + +#include "NeuralNetwork.h" + +namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { + ALPAKA_FN_ACC ALPAKA_FN_INLINE void addQuadrupletToMemory(TripletsConst triplets, + Quadruplets quadruplets, + unsigned int innerTripletIndex, + unsigned int outerTripletIndex, + uint16_t lowerModule1, + uint16_t lowerModule2, + uint16_t lowerModule3, + uint16_t lowerModule4, + float innerRadius, + float outerRadius, + float pt, + float eta, + float phi, + float scores, + uint8_t layer, + unsigned int quadrupletIndex, + float rzChiSquared, + 
float dBeta, + float promptScore, + float displacedScore, + float fakeScore, + float regressionCenterX, + float regressionCenterY, + float regressionRadius, + float nonAnchorRegressionRadius) { + quadruplets.tripletIndices()[quadrupletIndex][0] = innerTripletIndex; + quadruplets.tripletIndices()[quadrupletIndex][1] = outerTripletIndex; + + quadruplets.lowerModuleIndices()[quadrupletIndex][0] = lowerModule1; + quadruplets.lowerModuleIndices()[quadrupletIndex][1] = lowerModule2; + quadruplets.lowerModuleIndices()[quadrupletIndex][2] = lowerModule3; + quadruplets.lowerModuleIndices()[quadrupletIndex][3] = lowerModule4; + quadruplets.innerRadius()[quadrupletIndex] = __F2H(innerRadius); + quadruplets.outerRadius()[quadrupletIndex] = __F2H(outerRadius); + quadruplets.pt()[quadrupletIndex] = __F2H(pt); + quadruplets.eta()[quadrupletIndex] = __F2H(eta); + quadruplets.phi()[quadrupletIndex] = __F2H(phi); + quadruplets.score_rphisum()[quadrupletIndex] = __F2H(scores); + quadruplets.layer()[quadrupletIndex] = layer; + quadruplets.isDup()[quadrupletIndex] = 0; + quadruplets.logicalLayers()[quadrupletIndex][0] = triplets.logicalLayers()[innerTripletIndex][0]; + quadruplets.logicalLayers()[quadrupletIndex][1] = triplets.logicalLayers()[innerTripletIndex][1]; + quadruplets.logicalLayers()[quadrupletIndex][2] = triplets.logicalLayers()[innerTripletIndex][2]; + quadruplets.logicalLayers()[quadrupletIndex][3] = triplets.logicalLayers()[outerTripletIndex][2]; + + quadruplets.hitIndices()[quadrupletIndex][0] = triplets.hitIndices()[innerTripletIndex][0]; + quadruplets.hitIndices()[quadrupletIndex][1] = triplets.hitIndices()[innerTripletIndex][1]; + quadruplets.hitIndices()[quadrupletIndex][2] = triplets.hitIndices()[innerTripletIndex][2]; + quadruplets.hitIndices()[quadrupletIndex][3] = triplets.hitIndices()[innerTripletIndex][3]; + quadruplets.hitIndices()[quadrupletIndex][4] = triplets.hitIndices()[innerTripletIndex][4]; + quadruplets.hitIndices()[quadrupletIndex][5] = 
triplets.hitIndices()[innerTripletIndex][5]; + quadruplets.hitIndices()[quadrupletIndex][6] = triplets.hitIndices()[outerTripletIndex][4]; + quadruplets.hitIndices()[quadrupletIndex][7] = triplets.hitIndices()[outerTripletIndex][5]; + + quadruplets.rzChiSquared()[quadrupletIndex] = rzChiSquared; + quadruplets.dBeta()[quadrupletIndex] = dBeta; + quadruplets.promptScore()[quadrupletIndex] = promptScore; + quadruplets.displacedScore()[quadrupletIndex] = displacedScore; + quadruplets.fakeScore()[quadrupletIndex] = fakeScore; + + quadruplets.regressionRadius()[quadrupletIndex] = regressionRadius; + quadruplets.nonAnchorRegressionRadius()[quadrupletIndex] = nonAnchorRegressionRadius; + quadruplets.regressionCenterX()[quadrupletIndex] = regressionCenterX; + quadruplets.regressionCenterY()[quadrupletIndex] = regressionCenterY; + }; + + template + ALPAKA_FN_ACC ALPAKA_FN_INLINE bool passT4RZConstraint(TAcc const& acc, + ModulesConst modules, + MiniDoubletsConst mds, + unsigned int firstMDIndex, + unsigned int secondMDIndex, + unsigned int thirdMDIndex, + unsigned int fourthMDIndex, + uint16_t lowerModuleIndex1, + uint16_t lowerModuleIndex2, + uint16_t lowerModuleIndex3, + uint16_t lowerModuleIndex4, + float& rzChiSquared, + float inner_pt, + float innerRadius, + float g, + float f, + short charge) { + //all the values are stored in the unit of cm, in the calculation below we need to be cautious if we want to use the meter unit + //get r and z + const float rt1 = mds.anchorRt()[firstMDIndex] / 100; + const float rt2 = mds.anchorRt()[secondMDIndex] / 100; + const float rt3 = mds.anchorRt()[thirdMDIndex] / 100; + const float rt4 = mds.anchorRt()[fourthMDIndex] / 100; + + const float z1 = mds.anchorZ()[firstMDIndex] / 100; + const float z2 = mds.anchorZ()[secondMDIndex] / 100; + const float z3 = mds.anchorZ()[thirdMDIndex] / 100; + const float z4 = mds.anchorZ()[fourthMDIndex] / 100; + + // Using lst_layer numbering convention defined in ModuleMethods.h + const short layer2 = 
modules.lstLayers()[lowerModuleIndex2]; + const short layer3 = modules.lstLayers()[lowerModuleIndex3]; + const short layer4 = modules.lstLayers()[lowerModuleIndex4]; + + // Get the module type of each MD: 0 is ps, 1 is 2s + const bool moduleType1 = modules.moduleType()[lowerModuleIndex1]; + const bool moduleType2 = modules.moduleType()[lowerModuleIndex2]; + const bool moduleType3 = modules.moduleType()[lowerModuleIndex3]; + const bool moduleType4 = modules.moduleType()[lowerModuleIndex4]; + + // Get the x,y position of each MD + const float x1 = mds.anchorX()[firstMDIndex] / 100; + const float x2 = mds.anchorX()[secondMDIndex] / 100; + const float x3 = mds.anchorX()[thirdMDIndex] / 100; + const float x4 = mds.anchorX()[fourthMDIndex] / 100; + const float y1 = mds.anchorY()[firstMDIndex] / 100; + const float y2 = mds.anchorY()[secondMDIndex] / 100; + const float y3 = mds.anchorY()[thirdMDIndex] / 100; + const float y4 = mds.anchorY()[fourthMDIndex] / 100; + + float residual = 0; + float error2 = 0; + // (g,f) is the center of the circle fitted by the innermost 3 points on x,y coordinates + float x_center = g / 100, y_center = f / 100; + float x_init = x3; + float y_init = y3; + float z_init = z3; + float rt_init = rt3; //use the third MD as initial point + + if (moduleType3 == 1) // if MD3 is in 2s layer, use MD2 as initial point + { + x_init = x2; + y_init = y2; + z_init = z2; + rt_init = rt2; + } + + float pt = inner_pt, px = pt * charge * (y_init - y_center) / innerRadius * 100, + py = -pt * charge * (x_init - x_center) / innerRadius * 100; + + // But if the initial T4 curve goes across quarters(i.e. 
cross axis to separate the quarters), need special redeclaration of px,py signs on these to avoid errors + if (moduleType3 == 0) { // 0 is ps + if (x4 < x3 && x3 < x2) + px = -alpaka::math::abs(acc, px); + else if (x4 > x3 && x3 > x2) + px = alpaka::math::abs(acc, px); + if (y4 < y3 && y3 < y2) + py = -alpaka::math::abs(acc, py); + else if (y4 > y3 && y3 > y2) + py = alpaka::math::abs(acc, py); + } else if (moduleType3 == 1) // 1 is 2s + { + if (x3 < x2 && x2 < x1) + px = -alpaka::math::abs(acc, px); + else if (x3 > x2 && x2 > x1) + px = alpaka::math::abs(acc, px); + if (y3 < y2 && y2 < y1) + py = -alpaka::math::abs(acc, py); + else if (y3 > y2 && y2 > y1) + py = alpaka::math::abs(acc, py); + } + + //to get pz, we use pt/pz=ds/dz, ds is the arclength between MD1 and MD3. + float AO = alpaka::math::sqrt(acc, (x1 - x_center) * (x1 - x_center) + (y1 - y_center) * (y1 - y_center)); + float BO = + alpaka::math::sqrt(acc, (x_init - x_center) * (x_init - x_center) + (y_init - y_center) * (y_init - y_center)); + float AB2 = (x1 - x_init) * (x1 - x_init) + (y1 - y_init) * (y1 - y_init); + float dPhi = alpaka::math::acos(acc, (AO * AO + BO * BO - AB2) / (2 * AO * BO)); + float ds = innerRadius / 100 * dPhi; + + float pz = (z_init - z1) / ds * pt; + float p = alpaka::math::sqrt(acc, px * px + py * py + pz * pz); + + float a = -2.f * k2Rinv1GeVf * 100 * charge; + + float zsi, rtsi; + short layeri; + bool moduleTypei; + rzChiSquared = 0; + float zs[] = {z2, z3, z4}, rts[] = {rt2, rt3, rt4}; + short layers[] = {layer2, layer3, layer4}; + bool moduleTypes[] = {moduleType2, moduleType3, moduleType4}; + for (size_t i = 2; i < 5; i++) { + size_t j = i - 2; + zsi = zs[j]; + rtsi = rts[j]; + layeri = layers[j]; + moduleTypei = moduleTypes[j]; + + if (moduleType3 == 0) { //0: ps + if (i == 3) + continue; + } else { + if (i == 2) + continue; + } + // calculation is copied from PixelTriplet.cc computePT3RZChiSquared + float diffr = 0, diffz = 0; + + float rou = a / p; + // for endcap + 
float s = (zsi - z_init) * p / pz; + float x = x_init + px / a * alpaka::math::sin(acc, rou * s) - py / a * (1 - alpaka::math::cos(acc, rou * s)); + float y = y_init + py / a * alpaka::math::sin(acc, rou * s) + px / a * (1 - alpaka::math::cos(acc, rou * s)); + diffr = (rtsi - alpaka::math::sqrt(acc, x * x + y * y)) * 100; + + // for barrel + if (layeri <= 6) { + float paraA = + rt_init * rt_init + 2 * (px * px + py * py) / (a * a) + 2 * (y_init * px - x_init * py) / a - rtsi * rtsi; + float paraB = 2 * (x_init * px + y_init * py) / a; + float paraC = 2 * (y_init * px - x_init * py) / a + 2 * (px * px + py * py) / (a * a); + float A = paraB * paraB + paraC * paraC; + float B = 2 * paraA * paraB; + float C = paraA * paraA - paraC * paraC; + float sol1 = (-B + alpaka::math::sqrt(acc, B * B - 4 * A * C)) / (2 * A); + float sol2 = (-B - alpaka::math::sqrt(acc, B * B - 4 * A * C)) / (2 * A); + float solz1 = alpaka::math::asin(acc, sol1) / rou * pz / p + z_init; + float solz2 = alpaka::math::asin(acc, sol2) / rou * pz / p + z_init; + float diffz1 = (solz1 - zsi) * 100; + float diffz2 = (solz2 - zsi) * 100; + diffz = edm::isNotFinite(diffz1) ? diffz2 + : edm::isNotFinite(diffz2) + ? diffz1 + : ((alpaka::math::abs(acc, diffz1) < alpaka::math::abs(acc, diffz2)) ? diffz1 : diffz2); + } + residual = (layeri > 6) ? diffr : diffz; + + // error + error2 = moduleTypei == 0 ? 
kPixelPSZpitch * kPixelPSZpitch : kStrip2SZpitch * kStrip2SZpitch; + + //check the tilted module, side: PosZ, NegZ, Center(for not tilted) + float drdz; + short side, subdets; + if (i == 2) { + drdz = alpaka::math::abs(acc, modules.drdzs()[lowerModuleIndex2]); + side = modules.sides()[lowerModuleIndex2]; + subdets = modules.subdets()[lowerModuleIndex2]; + } + if (i == 3) { + drdz = alpaka::math::abs(acc, modules.drdzs()[lowerModuleIndex3]); + side = modules.sides()[lowerModuleIndex3]; + subdets = modules.subdets()[lowerModuleIndex3]; + } + const bool isEndcapOrCenter = (subdets == lst::Endcap) or (side == lst::Center); + if (i == 2 || i == 3) { + residual = (layeri <= 6 && ((side == Center) or (drdz < 1))) ? diffz : diffr; + float projection_missing2 = 1.f; + if (drdz < 1) + projection_missing2 = isEndcapOrCenter ? 1.f : 1 / (1 + drdz * drdz); // cos(atan(drdz)), if dr/dz<1 + if (drdz > 1) + projection_missing2 = + isEndcapOrCenter ? 1.f : (drdz * drdz) / (1 + drdz * drdz); //sin(atan(drdz)), if dr/dz>1 + error2 = error2 * projection_missing2; + } + rzChiSquared += 12 * (residual * residual) / error2; + } + // for set rzchi2 cut + // if the 4 points are linear, helix calculation gives nan + if (inner_pt > 100 || edm::isNotFinite(rzChiSquared)) { + float slope; + const bool isPSPS2S = moduleType1 == 0 and moduleType2 == 0 and moduleType3 == 1; + slope = isPSPS2S ? (z2 - z1) / (rt2 - rt1) : (z3 - z1) / (rt3 - rt1); + float residual4_linear = (layer4 <= 6) ? ((z4 - z1) - slope * (rt4 - rt1)) : ((rt4 - rt1) - (z4 - z1) / slope); + + // creating a chi squared type quantity + // 0-> PS, 1->2S + residual4_linear = (moduleType4 == 0) ? residual4_linear / kPixelPSZpitch : residual4_linear / kStrip2SZpitch; + residual4_linear = residual4_linear * 100; + + rzChiSquared = 12 * (residual4_linear * residual4_linear); + return rzChiSquared < 5.839f; + } + float eta1 = alpaka::math::abs(acc, mds.anchorEta()[firstMDIndex]); + uint8_t bin_index = (eta1 > 2.5f) ? 
(25 - 1) : static_cast(eta1 / 0.1f); + float chi2_cuts[] = {31.5082, 24.5654, 28.9223, 35.5906, 32.0746, 22.6416, 39.1476, 41.0791, 30.2745, + 40.2882, 31.2135, 17.8911, 9.0297, 7.6862, 2.7591, 5.0587, 6.4014, 3.7348, + 4.4768, 5.3087, 15.4535, 14.1107, 23.2778, 18.3643, 26.3276}; + return rzChiSquared < chi2_cuts[bin_index]; + }; + + template + ALPAKA_FN_ACC ALPAKA_FN_INLINE bool runQuadrupletDefaultAlgo(TAcc const& acc, + ModulesConst modules, + MiniDoubletsConst mds, + SegmentsConst segments, + TripletsConst triplets, + uint16_t lowerModuleIndex1, + uint16_t lowerModuleIndex2, + uint16_t lowerModuleIndex3, + uint16_t lowerModuleIndex4, + unsigned int innerTripletIndex, + unsigned int outerTripletIndex, + float& regressionCenterX, + float& regressionCenterY, + float& regressionRadius, + float& nonAnchorRegressionRadius, + float& chiSquared, + const float ptCut, + float& rzChiSquared, + float& nonAnchorChiSquared, + float& dBeta, + float& promptScore, + float& displacedScore, + float& fakeScore) { + unsigned int firstSegmentIndex = triplets.segmentIndices()[innerTripletIndex][0]; + unsigned int secondSegmentIndex = triplets.segmentIndices()[innerTripletIndex][1]; + unsigned int thirdSegmentIndex = triplets.segmentIndices()[outerTripletIndex][1]; + + // require both T3s to have the same charge + const short innerT3charge = triplets.charge()[innerTripletIndex]; + const short outerT3charge = triplets.charge()[outerTripletIndex]; + if (innerT3charge != outerT3charge) + return false; + + unsigned int firstMDIndex = segments.mdIndices()[firstSegmentIndex][0]; + unsigned int secondMDIndex = segments.mdIndices()[secondSegmentIndex][0]; + unsigned int thirdMDIndex = segments.mdIndices()[secondSegmentIndex][1]; + unsigned int fourthMDIndex = segments.mdIndices()[thirdSegmentIndex][1]; + + float x1 = mds.anchorX()[firstMDIndex]; + float x2 = mds.anchorX()[secondMDIndex]; + float x3 = mds.anchorX()[thirdMDIndex]; + float x4 = mds.anchorX()[fourthMDIndex]; + + float y1 = 
mds.anchorY()[firstMDIndex]; + float y2 = mds.anchorY()[secondMDIndex]; + float y3 = mds.anchorY()[thirdMDIndex]; + float y4 = mds.anchorY()[fourthMDIndex]; + + float inner_circleCenterX = triplets.centerX()[innerTripletIndex]; + float inner_circleCenterY = triplets.centerY()[innerTripletIndex]; + float innerRadius = triplets.radius()[innerTripletIndex]; + float outerRadius = triplets.radius()[outerTripletIndex]; + float inner_pt = 2 * k2Rinv1GeVf * innerRadius; + float pt = (innerRadius + outerRadius) * k2Rinv1GeVf; + + // 4 categories for sigmas + float sigmas2[4], delta1[4], delta2[4], slopes[4]; + bool isFlat[4]; + + float xVec[] = {x1, x2, x3, x4}; + float yVec[] = {y1, y2, y3, y4}; + + const uint16_t lowerModuleIndices[] = {lowerModuleIndex1, lowerModuleIndex2, lowerModuleIndex3, lowerModuleIndex4}; + + computeSigmasForRegression(acc, modules, lowerModuleIndices, delta1, delta2, slopes, isFlat, Params_T4::kLayers); + regressionRadius = computeRadiusUsingRegression(acc, + Params_T4::kLayers, + xVec, + yVec, + delta1, + delta2, + slopes, + isFlat, + regressionCenterX, + regressionCenterY, + sigmas2, + chiSquared); + + //compute the other chisquared + //non anchor is always shifted for tilted and endcap! 
+ float nonAnchorSigmas2[4], nonAnchorDelta1[Params_T4::kLayers], nonAnchorDelta2[Params_T4::kLayers], + nonAnchorSlopes[Params_T4::kLayers]; + float nonAnchorxs[] = {mds.outerX()[firstMDIndex], + mds.outerX()[secondMDIndex], + mds.outerX()[thirdMDIndex], + mds.outerX()[fourthMDIndex]}; + float nonAnchorys[] = {mds.outerY()[firstMDIndex], + mds.outerY()[secondMDIndex], + mds.outerY()[thirdMDIndex], + mds.outerY()[fourthMDIndex]}; + + bool nonAnchorisFlat[4]; + float nonAnchorRegressionG, nonAnchorRegressionF; + + computeSigmasForRegression(acc, + modules, + lowerModuleIndices, + nonAnchorDelta1, + nonAnchorDelta2, + nonAnchorSlopes, + nonAnchorisFlat, + Params_T4::kLayers, + false); + + nonAnchorRegressionRadius = computeRadiusUsingRegression(acc, + Params_T4::kLayers, + nonAnchorxs, + nonAnchorys, + nonAnchorDelta1, + nonAnchorDelta2, + nonAnchorSlopes, + nonAnchorisFlat, + nonAnchorRegressionG, + nonAnchorRegressionF, + nonAnchorSigmas2, + nonAnchorChiSquared); + + bool inference = lst::t4dnn::runInference(acc, + mds, + modules, + firstMDIndex, + secondMDIndex, + thirdMDIndex, + fourthMDIndex, + lowerModuleIndex1, + lowerModuleIndex2, + lowerModuleIndex3, + lowerModuleIndex4, + innerRadius, + outerRadius, + promptScore, + displacedScore, + fakeScore, + regressionRadius, + nonAnchorRegressionRadius, + triplets.fakeScore()[innerTripletIndex], + triplets.promptScore()[innerTripletIndex], + triplets.displacedScore()[innerTripletIndex], + triplets.fakeScore()[outerTripletIndex], + triplets.promptScore()[outerTripletIndex], + triplets.displacedScore()[outerTripletIndex]); + + if (!inference) { + return false; + } + // only run dBeta selector for low/high pT to avoid removing displaced efficiency + if (pt > 10 || pt < 1) { + if (not runQuintupletdBetaAlgoSelector(acc, + modules, + mds, + segments, + lowerModuleIndex1, + lowerModuleIndex2, + lowerModuleIndex3, + lowerModuleIndex4, + firstSegmentIndex, + thirdSegmentIndex, + firstMDIndex, + secondMDIndex, + thirdMDIndex, 
+ fourthMDIndex, + dBeta, + ptCut)) + return false; + } + + if (not passT4RZConstraint(acc, + modules, + mds, + firstMDIndex, + secondMDIndex, + thirdMDIndex, + fourthMDIndex, + lowerModuleIndex1, + lowerModuleIndex2, + lowerModuleIndex3, + lowerModuleIndex4, + rzChiSquared, + inner_pt, + innerRadius, + inner_circleCenterX, + inner_circleCenterY, + innerT3charge)) + return false; + + float dxy = abs(std::hypot(regressionCenterX, regressionCenterY) - regressionRadius); + float eta_layer3; + const int layer1 = modules.layers()[lowerModuleIndex1]; + if (layer1 == 3) { + eta_layer3 = alpaka::math::abs(acc, mds.anchorEta()[firstMDIndex]); + } else if (layer1 == 2) { + eta_layer3 = alpaka::math::abs(acc, mds.anchorEta()[secondMDIndex]); + } else { + eta_layer3 = alpaka::math::abs(acc, mds.anchorEta()[thirdMDIndex]); + } + if (dxy < 0.05f && eta_layer3 < 0.5f) + return false; + else if (dxy < 0.01f && eta_layer3 < 1.5f) + return false; + + nonAnchorChiSquared = computeChiSquared(acc, + Params_T4::kLayers, + nonAnchorxs, + nonAnchorys, + nonAnchorDelta1, + nonAnchorDelta2, + nonAnchorSlopes, + isFlat, + regressionCenterX, + regressionCenterY, + regressionRadius); + + return true; + }; + + struct CreateQuadruplets { + ALPAKA_FN_ACC void operator()(Acc3D const& acc, + ModulesConst modules, + MiniDoubletsConst mds, + SegmentsConst segments, + Triplets triplets, + TripletsOccupancyConst tripletsOccupancy, + Quadruplets quadruplets, + QuadrupletsOccupancy quadrupletsOccupancy, + ObjectRangesConst ranges, + uint16_t nEligibleT4Modules, + const float ptCut) const { + ALPAKA_ASSERT_ACC((alpaka::getWorkDiv(acc)[1] == 1) && + (alpaka::getWorkDiv(acc)[2] == 1)); + + unsigned int& matchCount = alpaka::declareSharedVar(acc); + + const auto threadIdx = alpaka::getIdx(acc); + const auto blockDim = alpaka::getWorkDiv(acc); + + const int threadIdX = threadIdx.x(); + const int threadIdY = threadIdx.y(); + const int blockSizeX = blockDim.x(); + const int blockSizeY = blockDim.y(); + const 
int blockSize = blockSizeX * blockSizeY; + const int flatThreadIdxXY = threadIdY * blockSizeX + threadIdX; + const int flatThreadExtent = blockSize; // total threads per block + + const auto& mdIndices = segments.mdIndices(); + const auto& segIdx = triplets.segmentIndices(); + const auto& lmIdx = triplets.lowerModuleIndices(); + const auto& tripIdx = ranges.tripletModuleIndices(); + + for (int iter : cms::alpakatools::uniform_groups_z(acc, nEligibleT4Modules)) { + const uint16_t lowerModule1 = ranges.indicesOfEligibleT4Modules()[iter]; + + if (cms::alpakatools::once_per_block(acc)) { + matchCount = 0; + } + + short layer2_adjustment, md_adjustment; + int layer = modules.layers()[lowerModule1]; + if (layer == 1) { + if (modules.subdets()[lowerModule1] != Endcap) + continue; + layer2_adjustment = 1; + md_adjustment = 1; + } // get upper segment to be in third layer + else if (layer == 2) { + if (modules.subdets()[lowerModule1] != Endcap) + continue; + layer2_adjustment = 1; + md_adjustment = 0; + } // get lower segment to be in third layer + else { + layer2_adjustment = 0; + md_adjustment = 0; + } + const unsigned int nInnerTriplets = tripletsOccupancy.nTriplets()[lowerModule1]; + + alpaka::syncBlockThreads(acc); + + // Step 1: Make inner and outer triplet pairs + for (unsigned int innerTripletArrayIndex : cms::alpakatools::uniform_elements_y(acc, nInnerTriplets)) { + const unsigned int innerTripletIndex = tripIdx[lowerModule1] + innerTripletArrayIndex; + if (triplets.partOfPT5()[innerTripletIndex]) + continue; //don't create T4s for T3s accounted in pT5s + if (triplets.partOfT5()[innerTripletIndex]) + continue; //don't create T4s for T3s accounted in T5s + if (triplets.partOfPT3()[innerTripletIndex]) + continue; //don't create T4s for T3s accounted in pT3s + const uint16_t lowerModule2 = lmIdx[innerTripletIndex][1]; + const unsigned int nOuterTriplets = tripletsOccupancy.nTriplets()[lowerModule2]; + for (unsigned int outerTripletArrayIndex : 
cms::alpakatools::uniform_elements_x(acc, nOuterTriplets)) { + unsigned int outerTripletIndex = tripIdx[lowerModule2] + outerTripletArrayIndex; + if (triplets.partOfPT5()[outerTripletIndex]) + continue; //don't create T4s for T3s accounted in pT5s + if (triplets.partOfT5()[outerTripletIndex]) + continue; //don't create T4s for T3s accounted in T5s + if (triplets.partOfPT3()[outerTripletIndex]) + continue; //don't create T4s for T3s accounted in pT3s + + const unsigned int innerT3LS2Index = segIdx[innerTripletIndex][1]; + const unsigned int outerT3LS1Index = segIdx[outerTripletIndex][0]; + + //check if the 2 T3s have a common LS + if (innerT3LS2Index != outerT3LS1Index) + continue; + + // If densely connected, do not attempt parallel processing to avoid truncation + if (nInnerTriplets >= kNTripletThreshold || nOuterTriplets >= kNTripletThreshold) { + const uint16_t lowerModule3 = lmIdx[outerTripletIndex][1]; + const uint16_t lowerModule4 = lmIdx[outerTripletIndex][2]; + + float innerRadius = triplets.radius()[innerTripletIndex]; + float outerRadius = triplets.radius()[outerTripletIndex]; + float rzChiSquared, dBeta, nonAnchorChiSquared, regressionCenterX, regressionCenterY, regressionRadius, + nonAnchorRegressionRadius, chiSquared, promptScore, displacedScore, fakeScore; + + float pt = (innerRadius + outerRadius) * k2Rinv1GeVf; + + bool success = runQuadrupletDefaultAlgo(acc, + modules, + mds, + segments, + triplets, + lowerModule1, + lowerModule2, + lowerModule3, + lowerModule4, + innerTripletIndex, + outerTripletIndex, + regressionCenterX, + regressionCenterY, + regressionRadius, + nonAnchorRegressionRadius, + chiSquared, + ptCut, + rzChiSquared, + nonAnchorChiSquared, + dBeta, + promptScore, + displacedScore, + fakeScore); + if (success) { + int totOccupancyQuadruplets = + alpaka::atomicAdd(acc, + &quadrupletsOccupancy.totOccupancyQuadruplets()[lowerModule1], + 1u, + alpaka::hierarchy::Threads{}); + if (totOccupancyQuadruplets >= 
ranges.quadrupletModuleOccupancy()[lowerModule1]) { +#ifdef WARNINGS + printf("Quadruplet excess alert! Module index = %d, Occupancy = %d\n", + lowerModule1, + totOccupancyQuadruplets); +#endif + } else { + int quadrupletModuleIndex = alpaka::atomicAdd( + acc, &quadrupletsOccupancy.nQuadruplets()[lowerModule1], 1u, alpaka::hierarchy::Threads{}); + if (ranges.quadrupletModuleIndices()[lowerModule1] == -1) { +#ifdef WARNINGS + printf("Quadruplets : no memory for module at module index = %d\n", lowerModule1); +#endif + } else { + unsigned int quadrupletIndex = + ranges.quadrupletModuleIndices()[lowerModule1] + quadrupletModuleIndex; + const unsigned int layer3MDIndex = + mdIndices[segIdx[innerTripletIndex][md_adjustment]][layer2_adjustment]; + float phi = mds.anchorPhi()[layer3MDIndex]; + float eta = mds.anchorEta()[layer3MDIndex]; + + float scores = chiSquared + nonAnchorChiSquared; + addQuadrupletToMemory(triplets, + quadruplets, + innerTripletIndex, + outerTripletIndex, + lowerModule1, + lowerModule2, + lowerModule3, + lowerModule4, + innerRadius, + outerRadius, + pt, + eta, + phi, + scores, + layer, + quadrupletIndex, + rzChiSquared, + dBeta, + promptScore, + displacedScore, + fakeScore, + regressionCenterX, + regressionCenterY, + regressionRadius, + nonAnchorRegressionRadius); + } + } + } + continue; + } + + int mIdx = alpaka::atomicAdd(acc, &matchCount, 1u, alpaka::hierarchy::Threads{}); + const unsigned int quadrupletIndex = ranges.quadrupletModuleIndices()[lowerModule1] + mIdx; + +#ifdef WARNINGS + const unsigned int rightBound = + static_cast(ranges.quadrupletModuleIndices()[lowerModule1 + 1]); + if (quadrupletIndex >= rightBound) { + printf( + "Quadruplet module occupancy alert! 
module quadruplet starting index = %d, Pair quadruplet index = " + "%d, next module quadruplet starting index = %d\n", + ranges.quadrupletModuleIndices()[lowerModule1], + mIdx, + ranges.quadrupletModuleIndices()[lowerModule1 + 1]); + } +#endif + + quadruplets.preAllocatedTripletIndices()[quadrupletIndex][0] = innerTripletIndex; + quadruplets.preAllocatedTripletIndices()[quadrupletIndex][1] = outerTripletIndex; + } + } + + alpaka::syncBlockThreads(acc); + if (matchCount == 0) { + continue; + } + + // Step 2: Parallel processing of triplet pairs + for (unsigned int i = flatThreadIdxXY; i < matchCount; i += flatThreadExtent) { + const unsigned int quadrupletIndex = ranges.quadrupletModuleIndices()[lowerModule1] + i; + const int innerTripletIndex = quadruplets.preAllocatedTripletIndices()[quadrupletIndex][0]; + const int outerTripletIndex = quadruplets.preAllocatedTripletIndices()[quadrupletIndex][1]; + + const uint16_t lowerModule2 = lmIdx[innerTripletIndex][1]; + const uint16_t lowerModule3 = lmIdx[outerTripletIndex][1]; + const uint16_t lowerModule4 = lmIdx[outerTripletIndex][2]; + + float innerRadius = triplets.radius()[innerTripletIndex]; + float outerRadius = triplets.radius()[outerTripletIndex]; + float rzChiSquared, dBeta, nonAnchorChiSquared, regressionCenterX, regressionCenterY, regressionRadius, + nonAnchorRegressionRadius, chiSquared, promptScore, displacedScore, fakeScore; + + float pt = (innerRadius + outerRadius) * k2Rinv1GeVf; + + bool success = runQuadrupletDefaultAlgo(acc, + modules, + mds, + segments, + triplets, + lowerModule1, + lowerModule2, + lowerModule3, + lowerModule4, + innerTripletIndex, + outerTripletIndex, + regressionCenterX, + regressionCenterY, + regressionRadius, + nonAnchorRegressionRadius, + chiSquared, + ptCut, + rzChiSquared, + nonAnchorChiSquared, + dBeta, + promptScore, + displacedScore, + fakeScore); + if (success) { + int totOccupancyQuadruplets = alpaka::atomicAdd( + acc, 
&quadrupletsOccupancy.totOccupancyQuadruplets()[lowerModule1], 1u, alpaka::hierarchy::Threads{}); + if (totOccupancyQuadruplets >= ranges.quadrupletModuleOccupancy()[lowerModule1]) { +#ifdef WARNINGS + printf("Quadruplet excess alert! Module index = %d, Occupancy = %d\n", + lowerModule1, + totOccupancyQuadruplets); +#endif + } else { + int quadrupletModuleIndex = alpaka::atomicAdd( + acc, &quadrupletsOccupancy.nQuadruplets()[lowerModule1], 1u, alpaka::hierarchy::Threads{}); + if (ranges.quadrupletModuleIndices()[lowerModule1] == -1) { +#ifdef WARNINGS + printf("Quadruplets : no memory for module at module index = %d\n", lowerModule1); +#endif + } else { + const unsigned int quadrupletIndex = + ranges.quadrupletModuleIndices()[lowerModule1] + quadrupletModuleIndex; + const unsigned int layer3MDIndex = + mdIndices[segIdx[innerTripletIndex][md_adjustment]][layer2_adjustment]; + float phi = mds.anchorPhi()[layer3MDIndex]; + float eta = mds.anchorEta()[layer3MDIndex]; + + float scores = chiSquared + nonAnchorChiSquared; + addQuadrupletToMemory(triplets, + quadruplets, + innerTripletIndex, + outerTripletIndex, + lowerModule1, + lowerModule2, + lowerModule3, + lowerModule4, + innerRadius, + outerRadius, + pt, + eta, + phi, + scores, + layer, + quadrupletIndex, + rzChiSquared, + dBeta, + promptScore, + displacedScore, + fakeScore, + regressionCenterX, + regressionCenterY, + regressionRadius, + nonAnchorRegressionRadius); + } + } + } + } + } + } + }; + + ALPAKA_FN_ACC ALPAKA_FN_INLINE bool isValidQuadRegion(ModulesConst modules, uint16_t lowerModule) { + const short layer = modules.layers()[lowerModule]; + const short subdet = modules.subdets()[lowerModule]; + // Quadruplets starting outside these regions are not built. 
+ return (subdet == Barrel && layer > 2) || (subdet == Endcap); + } + + struct CountTripletLSConnections { + ALPAKA_FN_ACC void operator()(Acc3D const& acc, + ModulesConst modules, + MiniDoubletsConst mds, + SegmentsConst segments, + Triplets triplets, + TripletsOccupancyConst tripletsOcc, + ObjectRangesConst ranges, + const float ptCut) const { + // The atomicAdd below with hierarchy::Threads{} requires one block in x, y dimensions. + ALPAKA_ASSERT_ACC((alpaka::getWorkDiv(acc)[1] == 1) && + (alpaka::getWorkDiv(acc)[2] == 1)); + const auto& mdIndices = segments.mdIndices(); + const auto& segIdx = triplets.segmentIndices(); + const auto& lmIdx = triplets.lowerModuleIndices(); + const auto& tripIdx = ranges.tripletModuleIndices(); + + for (uint16_t lowerModule1 : cms::alpakatools::uniform_groups_z(acc, modules.nLowerModules())) { + if (!isValidQuadRegion(modules, lowerModule1)) + continue; + + const unsigned int nInnerTriplets = tripletsOcc.nTriplets()[lowerModule1]; + if (nInnerTriplets == 0) + continue; + + for (unsigned int innerTripletArrayIndex : cms::alpakatools::uniform_elements_y(acc, nInnerTriplets)) { + const unsigned int innerTripletIndex = tripIdx[lowerModule1] + innerTripletArrayIndex; + + const uint16_t lowerModule2 = lmIdx[innerTripletIndex][1]; + const unsigned int nOuterTriplets = tripletsOcc.nTriplets()[lowerModule2]; + if (nOuterTriplets == 0) + continue; + + const unsigned int secondSegIdx = segIdx[innerTripletIndex][1]; + const unsigned int secondMDInner = mdIndices[secondSegIdx][0]; + const unsigned int secondMDOuter = mdIndices[secondSegIdx][1]; + + for (unsigned int outerTripletArrayIndex : cms::alpakatools::uniform_elements_x(acc, nOuterTriplets)) { + const unsigned int outerTripletIndex = tripIdx[lowerModule2] + outerTripletArrayIndex; + const unsigned int thirdSegIdx = segIdx[outerTripletIndex][0]; + const unsigned int thirdMDInner = mdIndices[thirdSegIdx][0]; + const unsigned int thirdMDOuter = mdIndices[thirdSegIdx][1]; + + if 
((secondMDInner == thirdMDInner) && (secondMDOuter == thirdMDOuter)) { + // Will only perform runQuadrupletDefaultAlgorithm() checks if densely connected + if (nInnerTriplets < kNTripletThreshold && nOuterTriplets < kNTripletThreshold) { + alpaka::atomicAdd(acc, &triplets.connectedLSMax()[innerTripletIndex], 1u, alpaka::hierarchy::Threads{}); + } else { + const uint16_t lowerModule3 = lmIdx[outerTripletIndex][1]; + const uint16_t lowerModule4 = lmIdx[outerTripletIndex][2]; + + float rzChiSquared, dBeta, nonAnchorChiSquared, regressionCenterX, regressionCenterY, regressionRadius, + nonAnchorRegressionRadius, chiSquared, promptScore, displacedScore, fakeScore; + + const bool ok = runQuadrupletDefaultAlgo(acc, + modules, + mds, + segments, + triplets, + lowerModule1, + lowerModule2, + lowerModule3, + lowerModule4, + innerTripletIndex, + outerTripletIndex, + regressionCenterX, + regressionCenterY, + regressionRadius, + nonAnchorRegressionRadius, + chiSquared, + ptCut, + rzChiSquared, + nonAnchorChiSquared, + dBeta, + promptScore, + displacedScore, + fakeScore); + if (ok) { + alpaka::atomicAdd( + acc, &triplets.connectedLSMax()[innerTripletIndex], 1u, alpaka::hierarchy::Threads{}); + } + } + } + } + } + } + } + }; + + struct CreateEligibleModulesListForQuadruplets { + ALPAKA_FN_ACC void operator()(Acc1D const& acc, + ModulesConst modules, + TripletsOccupancyConst tripletsOcc, + ObjectRanges ranges, + Triplets triplets) const { + // Single-block kernel + ALPAKA_ASSERT_ACC((alpaka::getWorkDiv(acc)[0] == 1)); + + int& nEligibleT4Modulesx = alpaka::declareSharedVar(acc); + int& nTotalQuadrupletsx = alpaka::declareSharedVar(acc); + if (cms::alpakatools::once_per_block(acc)) { + nTotalQuadrupletsx = 0; + nEligibleT4Modulesx = 0; + } + alpaka::syncBlockThreads(acc); + + for (uint16_t lowerModule : cms::alpakatools::uniform_elements(acc, modules.nLowerModules())) { + if (!isValidQuadRegion(modules, lowerModule)) + continue; + + unsigned int nInnerTriplets = 
tripletsOcc.nTriplets()[lowerModule]; + if (nInnerTriplets == 0) + continue; + + // Sum the real connectivity for triplets in this module + int dynamic_count = 0; + const unsigned int firstTripletIdx = ranges.tripletModuleIndices()[lowerModule]; + for (unsigned int t = 0; t < nInnerTriplets; ++t) { + unsigned int tripletIndex = firstTripletIdx + t; + dynamic_count += triplets.connectedLSMax()[tripletIndex]; + } + + if (dynamic_count == 0) + continue; + + int nEligibleT4Modules = alpaka::atomicAdd(acc, &nEligibleT4Modulesx, 1, alpaka::hierarchy::Threads{}); + int nTotQ = alpaka::atomicAdd(acc, &nTotalQuadrupletsx, dynamic_count, alpaka::hierarchy::Threads{}); + + ranges.quadrupletModuleIndices()[lowerModule] = nTotQ; + ranges.indicesOfEligibleT4Modules()[nEligibleT4Modules] = lowerModule; + ranges.quadrupletModuleOccupancy()[lowerModule] = dynamic_count; + } + + // Wait for all threads to finish before reporting final values + alpaka::syncBlockThreads(acc); + if (cms::alpakatools::once_per_block(acc)) { + ranges.nEligibleT4Modules() = static_cast(nEligibleT4Modulesx); + ranges.nTotalQuads() = static_cast(nTotalQuadrupletsx); + } + } + }; + + struct AddQuadrupletRangesToEventExplicit { + ALPAKA_FN_ACC void operator()(Acc1D const& acc, + ModulesConst modules, + QuadrupletsOccupancyConst quadrupletsOccupancy, + ObjectRanges ranges) const { + // implementation is 1D with a single block + ALPAKA_ASSERT_ACC((alpaka::getWorkDiv(acc)[0] == 1)); + + for (uint16_t i : cms::alpakatools::uniform_elements(acc, modules.nLowerModules())) { + if (quadrupletsOccupancy.nQuadruplets()[i] == 0 or ranges.quadrupletModuleIndices()[i] == -1) { + ranges.quadrupletRanges()[i][0] = -1; + ranges.quadrupletRanges()[i][1] = -1; + } else { + ranges.quadrupletRanges()[i][0] = ranges.quadrupletModuleIndices()[i]; + ranges.quadrupletRanges()[i][1] = + ranges.quadrupletModuleIndices()[i] + quadrupletsOccupancy.nQuadruplets()[i] - 1; + } + } + } + }; +} // namespace ALPAKA_ACCELERATOR_NAMESPACE::lst 
+#endif diff --git a/RecoTracker/LSTCore/src/alpaka/T4NeuralNetworkWeights.h b/RecoTracker/LSTCore/src/alpaka/T4NeuralNetworkWeights.h new file mode 100644 index 0000000000000..2ba26ff3ee088 --- /dev/null +++ b/RecoTracker/LSTCore/src/alpaka/T4NeuralNetworkWeights.h @@ -0,0 +1,296 @@ +#ifndef RecoTracker_LSTCore_src_alpaka_T4NeuralNetworkWeights_h +#define RecoTracker_LSTCore_src_alpaka_T4NeuralNetworkWeights_h + +#include + +#include "FWCore/Utilities/interface/HostDeviceConstant.h" + +namespace ALPAKA_ACCELERATOR_NAMESPACE::lst::dnn::t4dnn { + HOST_DEVICE_CONSTANT float bias_layer1[32] = { + 0.3130181f, -0.3157252f, -0.0845900f, -0.2268437f, 0.1305549f, 0.3839142f, 0.3933745f, -0.6758229f, + -0.4188058f, -0.2523611f, 1.4036129f, 0.8239079f, 0.1575654f, 0.2041763f, 0.8787493f, 0.2706699f, + -0.1112185f, 0.8988609f, 0.9274163f, -0.1023219f, 0.2916122f, -0.2606929f, 0.3098971f, -0.0602703f, + -0.6031470f, -0.0826582f, 0.3605700f, 0.4836628f, -0.3951748f, 0.0171050f, 0.5156327f, 0.0655813f}; + + HOST_DEVICE_CONSTANT float wgtT_layer1[30][32] = { + {-0.3409404f, -0.2000102f, -0.0890483f, -0.6186467f, -1.7570605f, 0.7890699f, -0.8753229f, 0.5488843f, + -0.5376814f, -0.2228569f, -0.3573552f, 2.1554949f, 0.2248887f, -0.6073594f, 0.2075009f, -0.1408760f, + -0.7051892f, -0.0664303f, 0.2747473f, 0.1450685f, -2.2709231f, -0.4088669f, 0.5452566f, 0.3086576f, + -0.1213564f, -0.9737034f, 0.2679004f, -0.1193755f, 0.9693206f, -0.7785844f, 0.4612639f, 0.7628022f}, + {0.0309823f, -0.5052885f, 0.0509167f, 0.1379152f, 0.2392345f, 0.0935747f, 0.1001187f, 0.2049023f, + -0.1292204f, -0.1732666f, 0.0397899f, -0.1621571f, -0.5934961f, 0.1731108f, -0.0290751f, 0.1774067f, + 0.0519257f, 0.0035256f, -0.0188142f, -0.0639332f, -0.2388456f, 0.0038807f, 0.0088869f, -0.1521942f, + 0.0138392f, 0.0613728f, -0.0693704f, 0.0792511f, -0.0990796f, -0.4666552f, 0.0071133f, -0.1573301f}, + {0.0850391f, 0.1398510f, 0.7312427f, 1.6511540f, -0.1157119f, 1.8312333f, 0.4276405f, 0.1428780f, + -0.0762665f, 
-0.1204791f, 0.5231065f, -1.9012266f, -0.2759933f, 2.1613867f, -0.3789352f, -0.7384162f, + 0.9980950f, -1.3757244f, -0.3345136f, -0.1468498f, -0.8294499f, 0.0777341f, 1.4606416f, 0.5440134f, + 2.1587088f, 0.0192866f, -4.7703443f, -0.3778406f, 2.0303624f, 0.4716987f, -0.5599666f, 1.0512540f}, + {0.1225046f, 0.1352115f, -0.6140373f, -0.4254693f, 0.2660460f, -0.7739570f, 1.0955174f, 1.2222520f, + -0.1673346f, -0.0774871f, 0.4850340f, -0.2480809f, 2.9553666f, -0.6026165f, -0.5579752f, -0.3300077f, + 0.0972013f, 1.9986047f, 0.9151404f, 0.0942718f, -0.0422994f, -0.0338943f, -0.3063800f, -0.6567743f, + 2.2631214f, 0.6305654f, 0.2954915f, 0.8076106f, 0.7743834f, 0.7023419f, -1.6891261f, -0.6115909f}, + {2.2928166f, 8.8309431f, -1.5345458f, -4.7059007f, 1.2810588f, -0.4951739f, 4.4845004f, 7.7040706f, + -1.8243797f, 0.0956211f, -0.3696172f, 0.6678835f, 4.7292924f, -9.7971382f, 2.0854435f, -3.7175915f, + 3.1399553f, -0.6623941f, 0.1497208f, -0.0905410f, -3.9599996f, 3.3040602f, 2.2535894f, -8.0852966f, + -16.3342419f, 14.6141891f, 2.6647313f, -10.8139477f, 0.5194562f, -2.2337329f, 1.4479593f, 10.4768066f}, + {6.3267903f, 8.4862614f, 18.2483406f, -2.5856876f, 2.9871955f, 6.0332689f, -11.9952469f, -1.8413588f, + 1.3217239f, -0.1347535f, 0.8526750f, -4.6671519f, 1.7776268f, 0.2549058f, -14.5247345f, 3.1438231f, + -9.3088989f, 1.1253707f, 0.0789171f, 0.1527994f, 1.7165481f, -0.9345309f, -5.6528535f, -9.5160980f, + -1.0419953f, -1.0004028f, 1.2163434f, 6.5496387f, 0.0804405f, -6.1611404f, -9.2283039f, -5.5179152f}, + {-2.3399000f, 4.7614522f, -0.7949132f, -1.8299819f, -3.5392070f, -0.0867105f, -0.6724365f, 0.4369151f, + 5.0728159f, -0.2512915f, 0.8541925f, -0.0066838f, -3.1859064f, -1.5541859f, 0.0789470f, -1.3237801f, + 2.6402714f, -0.0662765f, -0.2674660f, 0.0608886f, 1.1315312f, -0.1070065f, 1.8663841f, -0.4901078f, + -1.0676215f, 1.3194273f, -1.2205451f, -1.0945961f, 0.2475978f, 0.9682984f, 0.2073012f, 1.7876559f}, + {-0.4367641f, 0.4319819f, 1.1654582f, -3.1896923f, 
-1.2711153f, -0.6044747f, -3.3841281f, 0.1955531f, + 4.4088063f, -0.1708138f, -0.3717940f, 0.1636910f, -2.8220074f, -0.1908980f, 1.0748017f, -0.3119625f, + -1.5694339f, -0.1778283f, 0.9781732f, -0.0754240f, 4.2073040f, 2.0899282f, 1.2668608f, 0.0779037f, + -1.9252455f, 2.5578146f, -0.9014751f, 0.0522716f, 0.4070499f, -0.2140355f, 0.7342044f, 0.8045168f}, + {-6.6071205f, 0.6651391f, -5.2370253f, 5.0692110f, 0.7510651f, 2.4857941f, 11.5648117f, -1.7613629f, + 1.9677536f, 0.1799172f, -0.3524120f, -1.9071139f, -7.4744873f, 2.6429889f, -0.1444485f, 2.3542099f, + 1.9711516f, 0.3886011f, 0.0591138f, 0.0699189f, -3.0183461f, 5.6741471f, -1.1488796f, -5.7687993f, + -7.4685879f, -9.6448822f, -3.4691532f, 5.0486813f, -12.6432810f, 8.1007729f, -1.7438285f, -17.6806660f}, + {2.1428952f, -6.8331566f, 13.9079485f, 0.1021993f, -3.7234893f, 4.1696281f, -2.8855472f, 2.6500645f, + 0.2383825f, 0.0089392f, -1.2698770f, -1.4507635f, 1.9835703f, 1.9673402f, 0.0971966f, -3.0400901f, + -1.6690960f, -1.8351660f, -0.5974421f, -0.1041481f, -4.3467126f, -0.8935851f, -4.2216101f, -3.5241582f, + -1.7024223f, 2.4351561f, -3.5994439f, 8.1877632f, 1.3503532f, -8.8505297f, -4.3802824f, -1.2580007f}, + {0.5120507f, -0.8167784f, 0.0356091f, 0.9550222f, -0.0856578f, -1.6678712f, -1.3589050f, -0.4634860f, + 1.7959082f, -0.0221804f, 1.6858692f, -0.2350332f, 1.5413535f, -0.8954431f, -1.9626563f, -0.8659950f, + 1.5534185f, -0.2663006f, -0.0141180f, -0.0641252f, -2.7467253f, 1.3316264f, 0.1042683f, 1.3546844f, + 1.3957758f, -1.1120845f, -0.4499652f, -1.0622435f, 1.5979129f, -2.7754719f, -1.8175740f, -0.2714084f}, + {2.8423781f, -0.5802551f, 2.4059951f, -0.1214163f, 0.6598997f, -2.7015202f, -0.6345362f, -0.1668582f, + 0.9119438f, -0.1413521f, -0.1243188f, -1.5296214f, -2.2767708f, -0.4354365f, -0.2501381f, -4.5476084f, + -2.0407526f, -0.1834106f, 0.4992486f, -0.0249097f, 4.2349467f, -2.2088041f, 0.8245230f, 3.3258171f, + -0.8291156f, 1.8809289f, 2.0291409f, -0.1829567f, 0.3753935f, -2.1255214f, 0.3044383f, 
-0.8117877f}, + {1.0955867f, -10.5214415f, -2.5690074f, 5.7836456f, 8.4926434f, -0.9013922f, 3.8609581f, -18.2928333f, + -0.1085265f, -0.1957067f, 0.4130777f, 8.3614740f, -0.1366454f, 15.5778351f, 4.3706384f, -6.1245522f, + -0.2726971f, -2.2554579f, 0.1885025f, 0.1015946f, 3.1017957f, -9.0225515f, -2.1712937f, 3.6819403f, + 0.3107212f, 5.4358387f, 0.1556205f, -5.7449660f, -8.0870762f, -15.4704018f, 2.6353786f, 3.1047711f}, + {7.6964021f, 5.5716128f, 10.1820107f, 0.3581720f, -3.3344195f, 0.0203794f, 4.8660736f, 0.9843333f, + 1.5092058f, -0.1621139f, 1.0461260f, -3.6782911f, -3.2391973f, -1.3214555f, 2.3153496f, -0.1955887f, + -0.4074941f, 1.3566486f, 0.4631876f, -0.1068150f, 0.6354957f, 0.4232795f, 1.1560757f, -1.6600587f, + 5.4032493f, -0.7661732f, 0.9706129f, -1.9504737f, -2.6826806f, -1.7078609f, 0.2543025f, -0.4549442f}, + {-0.9363269f, 3.2511759f, 1.4988805f, 0.3896330f, 3.5761805f, -1.8898439f, -0.5672647f, 0.0965158f, + 2.3343837f, -0.2087380f, -1.6115571f, -2.6963377f, -2.8213720f, -1.1722008f, -0.7269015f, -3.2212329f, + -2.5389972f, 1.8190597f, -0.3862028f, 0.0262889f, -0.8232661f, -0.2829636f, 0.9204201f, -0.5837291f, + 1.1138207f, 1.4175742f, 0.6769784f, -1.8336256f, -0.9406338f, 1.4764718f, -0.4086397f, 1.3364717f}, + {-0.2418238f, 0.2686039f, 0.5018921f, 3.8336344f, 0.0712455f, -2.2001183f, -0.9898972f, 0.2089690f, + 3.5906739f, -0.1943938f, -0.6164807f, -2.6908536f, 0.5925170f, -0.3563481f, -1.3926654f, -1.4510415f, + 0.0248772f, 0.8849464f, 0.3642189f, -0.1435763f, 1.5389724f, 0.9731716f, 0.7003614f, 1.1896435f, + 0.0179679f, -0.7365713f, 0.8497150f, 0.2604503f, -2.4675038f, 1.0341363f, 1.5702873f, 0.1958498f}, + {-2.8209245f, -8.6002207f, -7.2007174f, -4.5361319f, 1.8299714f, -10.6522865f, -11.7428637f, -5.2347007f, + 7.4452977f, -0.1641102f, 0.6734878f, 6.6603560f, 0.1524790f, 10.6615028f, -19.3698826f, 2.1673820f, + 11.9764709f, 1.0121659f, 4.0597305f, -0.1707846f, -2.1111398f, -6.5540004f, -18.7466564f, -9.4047441f, + 3.0535262f, 2.7488508f, 
9.8994598f, 1.6856064f, 19.2076931f, 6.8278632f, -21.9050026f, 9.4619055f}, + {-9.7272358f, 11.2765751f, -7.5082231f, -1.7225909f, 13.1407022f, -14.8349190f, -3.2504196f, 15.8990402f, + 10.4930105f, 0.0052292f, -1.3149272f, 4.6401610f, -0.3964378f, -4.2996163f, 0.8589647f, 3.0086250f, + 16.7304554f, -1.8806626f, -2.5309651f, -0.1554628f, -1.7752417f, 6.6541591f, -16.9228725f, -4.0883851f, + 2.5308323f, 1.2640512f, 6.7352295f, 4.6456785f, 8.4472675f, 6.6682787f, -22.7476864f, 5.0754099f}, + {-0.0464466f, 0.0371061f, -0.3150824f, 0.2278159f, 0.0781510f, 0.3682392f, -0.0076914f, -0.0199720f, + -0.0213357f, -0.0197757f, -3.6537273f, -1.3694360f, -0.1145320f, -0.1972410f, -1.1493046f, 0.5570444f, + 0.0053487f, -2.8934319f, -3.5391710f, -0.0754734f, -1.4390492f, 0.5983394f, -1.1443400f, -0.0306496f, + -0.0482763f, -0.3964492f, 0.0067964f, -0.3276957f, 0.0020439f, -0.3366909f, -0.0156225f, 0.0541534f}, + {-19.9178314f, -10.1069622f, 4.6057873f, -1.6936491f, -1.3457170f, 4.9226985f, 9.7256756f, 3.1693432f, + -14.5028820f, 0.0931624f, 0.0866399f, 0.5221885f, -5.0346432f, -9.0003071f, 6.8700786f, -0.3239930f, + 10.3604975f, 0.2665467f, -0.8745342f, 0.1671114f, 11.8470201f, 7.2739577f, 0.3534940f, 9.4823771f, + 7.6572828f, -0.6058345f, 4.8220043f, -4.4774542f, 6.0150776f, 13.1617508f, 8.9563951f, -5.3839235f}, + {-15.3784885f, -4.7021117f, 1.3418291f, -5.5577135f, -3.5915122f, 2.1950095f, 9.0849819f, 4.1605716f, + -9.9853363f, -0.2489175f, 0.0786141f, 1.3932520f, -13.3884411f, -7.0676465f, 8.2610102f, -0.6629214f, + 9.6442671f, -1.4963982f, 0.0448971f, 0.0271809f, 11.0636673f, 10.8506622f, 1.5703944f, 11.9227686f, + 8.8615837f, 2.9824717f, 6.5997915f, -5.1319938f, 6.7163310f, 8.8659868f, 8.8796968f, -0.0585466f}, + {-0.5536407f, -0.0183597f, -0.8522137f, -1.6674993f, -0.5186954f, 0.2934249f, -0.3867851f, -0.2717040f, + 0.3141194f, 0.0803582f, 0.9972823f, 1.0203497f, 0.1344238f, 0.2180226f, 0.3881106f, -0.5604802f, + 0.8635465f, 0.7390655f, 0.9010500f, -0.1396771f, 
-0.5388748f, -0.0631624f, 0.1186292f, -0.5192882f, + 0.4407078f, -1.0921305f, 0.4530205f, -0.2301908f, -0.0831401f, -1.6932087f, 0.6916737f, -7.0294051f}, + {-0.7002707f, -1.2612772f, 0.1443634f, 0.1889855f, -1.5781668f, 0.5102594f, 1.0737917f, -0.3611829f, + -9.7292194f, -0.2542418f, 0.7840535f, -0.1868368f, -3.0662625f, 0.1156164f, 1.8433700f, 0.5661708f, + -0.7161480f, 0.5610115f, 1.1667061f, 0.0282633f, 2.0658562f, -0.1299639f, 0.8088669f, -0.3167280f, + 0.0436082f, 0.2469898f, -0.0169766f, 1.4717587f, -0.6065165f, 0.3588113f, 2.0036082f, 1.5167968f}, + {0.8769321f, 0.4987610f, -0.1006753f, 0.1093301f, 0.6219906f, -0.0845869f, 0.5858616f, -0.3685570f, + 0.1055476f, 0.0447120f, 0.9492686f, 0.6054065f, 0.3753049f, -0.0302966f, 0.1571343f, 0.5852838f, + -0.0052162f, 0.6741483f, 1.0925988f, -0.0104674f, -0.2353034f, -0.2112046f, -0.1704750f, 0.2319756f, + -0.6973480f, 0.5177563f, 0.0008280f, 0.0842060f, -0.4818317f, 0.2600521f, -0.8436811f, -0.3244939f}, + {-0.2498078f, -0.0665436f, -0.6455372f, -11.8636007f, 0.5546330f, -0.1828474f, 0.4531056f, -0.5045753f, + 0.0021860f, -0.1727674f, 0.8539888f, 0.9355938f, 0.1768771f, 0.1097697f, 0.3439056f, -2.7710619f, + 0.5562497f, 0.7414805f, 0.9723032f, -0.0587594f, -0.2656697f, -0.2262570f, 0.2070577f, -0.4649415f, + 0.6072245f, 0.4071639f, 0.7271490f, -0.3399665f, -0.2358655f, -0.6572027f, 0.3878818f, -2.0380676f}, + {-1.0040656f, -15.5001106f, 0.1743044f, -0.1830439f, -12.8375502f, 0.6479514f, 0.4189325f, 0.0413221f, + -2.3929427f, 0.0917285f, 0.8877478f, -0.4985425f, -3.9854083f, 0.2641849f, 1.1492108f, 1.1804928f, + -0.4219547f, 0.7097337f, 0.7229153f, -0.1290352f, 1.0844772f, -0.2112084f, 0.2665555f, -0.0124178f, + -0.7096525f, -0.0406028f, 0.1377101f, 0.9613231f, -0.5315894f, 0.3310910f, 1.4178389f, 1.1687748f}, + {0.6737692f, 0.5887758f, -0.0501311f, 1.0043997f, 0.9804797f, 0.1708015f, 0.1819106f, -0.5253279f, + -0.4184220f, -0.0125217f, 1.0117618f, 0.4730924f, 0.6627691f, -0.1479898f, 0.2580604f, 0.7293127f, + 
0.0848437f, 0.7156146f, 0.7816427f, 0.0669967f, -0.1618401f, -0.0909416f, 0.1740791f, -0.0961089f, + -0.9237953f, 0.1705291f, 0.2970416f, 0.2243806f, -0.2949813f, 0.1089575f, -0.3233873f, -0.5519235f}, + {0.8008638f, 0.1581205f, -0.0660521f, -1.8700145f, 1.8415585f, -0.1517836f, 0.9849059f, -0.3305838f, + -0.1100404f, 0.1668913f, 0.0475755f, -0.7035441f, 0.0815494f, -0.4517657f, 0.3797469f, -0.7237183f, + -0.4326949f, -0.0045456f, -0.1520977f, 0.0641505f, 0.3261537f, 0.3451711f, 0.3400122f, 0.1905922f, + 0.5979490f, 1.8726524f, -0.0212991f, 0.5759162f, -0.3411353f, 0.6154488f, -0.6591337f, 1.1044852f}, + {-0.5229920f, -1.9910779f, 0.0082411f, -1.0964926f, -2.8197410f, 0.0379619f, -0.6362457f, 0.2922821f, + 2.1298475f, 0.0429705f, 0.0860498f, 0.0998842f, -1.5189127f, 0.8646886f, -0.3320501f, 0.8533753f, + 0.2106902f, -0.1353741f, 0.1139200f, 0.1223277f, -1.3724730f, -0.4976718f, -0.8456128f, 0.6275283f, + -0.7746818f, -0.4165801f, -0.3285437f, -1.1218213f, 0.2541133f, -0.0699065f, -0.3644446f, -0.1836386f}, + {-0.3121711f, 0.6553325f, 0.0358024f, 2.9133885f, 0.0601561f, 0.0650956f, -0.3388072f, -0.1669552f, + -1.0281963f, 0.0421454f, -0.0252209f, 0.6697660f, 0.3152214f, -0.1007413f, 0.0397672f, 0.0762647f, + 0.6438169f, -0.1315191f, -0.0042900f, -0.0298286f, 0.4026648f, 0.1243395f, 0.5714200f, -0.5934660f, + 0.1301748f, -0.8935670f, 0.3112782f, 0.2476320f, 0.3565531f, -0.2412163f, 0.8716605f, -0.7849768f}, + }; + + HOST_DEVICE_CONSTANT float bias_layer2[32] = { + -0.2339502f, -2.1117039f, 1.3029057f, 0.1388091f, -0.1596625f, -0.0429177f, -1.2913821f, -3.2888331f, + -0.1078175f, 0.3783462f, -0.1941358f, -0.4196923f, 1.5371246f, 1.0102557f, 0.0155053f, -0.1713930f, + -0.8003848f, 1.2788725f, 1.2547292f, 0.9737166f, 0.0159556f, -0.9538616f, -0.6874874f, 1.4604672f, + -0.0163722f, -0.6740248f, 0.4310561f, 0.4786355f, 0.4338695f, -0.2144726f, -0.3379008f, 0.9070537f}; + + HOST_DEVICE_CONSTANT float wgtT_layer2[32][32] = { + {-0.1525858f, -1.7981244f, 0.8225231f, 
0.9017935f, -0.1756990f, -0.1731562f, 0.3533860f, -0.1859016f, + 0.0066088f, -1.9248251f, -0.1184793f, 0.0248543f, 0.5696007f, -0.0907341f, 1.2301605f, -0.2350527f, + -1.4149488f, -0.1846859f, -1.7834854f, 1.0467646f, -0.1320335f, -0.5828664f, -0.9305819f, -1.3805257f, + -0.0421591f, -0.9069743f, -1.1278086f, 0.2304732f, -0.5726073f, 0.0060693f, -0.6733936f, 2.4653385f}, + {-0.1177371f, -13.0121050f, -1.2429829f, -2.7850580f, -0.0618916f, -0.1455433f, -2.9844699f, 2.1670566f, + 0.0382921f, 1.7592702f, 0.0663455f, -0.6991704f, 1.1692400f, -1.3941003f, -0.7780491f, -0.1350629f, + -3.8783550f, -0.2548187f, 0.3231797f, 1.3278724f, -0.3614419f, -0.6638193f, -0.3801469f, 0.3153987f, + 2.9468744f, 0.4463870f, -2.2508523f, 0.2242093f, -0.7827371f, 0.0342467f, -0.1136825f, -1.4055092f}, + {-0.0814950f, 1.1034510f, -2.7856872f, 1.1012450f, -0.0762363f, 0.0192755f, 0.0942746f, -0.8603304f, + 0.0463609f, -0.4289516f, -0.2340941f, 0.7183347f, -0.3241202f, -3.3051322f, 0.6164730f, -0.1720765f, + -2.9967003f, -0.2995018f, 1.4385067f, -3.2133255f, -0.0849749f, 1.7341144f, 0.7924995f, 2.0554399f, + 1.0967957f, 0.3311919f, -0.9581537f, -0.7702340f, 1.0316148f, 0.0471946f, 0.9495452f, 2.0788605f}, + {0.0446539f, -0.7262716f, 0.2242534f, -2.2987046f, -0.1550755f, -0.2146342f, 0.2053458f, 0.1453160f, + -0.0573685f, -0.6836352f, -0.1972867f, 0.3058005f, -0.2511542f, 1.0547395f, 0.2851569f, -0.1420605f, + 0.1955147f, -0.8554440f, -0.7066809f, -0.3464343f, -0.0067930f, -0.3269323f, -0.2229819f, -0.3971729f, + -0.6963940f, -0.2993692f, -0.1417452f, -0.7594388f, 0.5256453f, -0.1079851f, 0.1336530f, -1.2257522f}, + {-0.0354096f, -10.4679108f, 2.8168788f, 0.2949276f, -0.1779284f, -0.1868454f, -1.4323943f, 0.8472028f, + -0.0747709f, 0.2078577f, -0.1229924f, 0.0567793f, 2.0551534f, -0.4557192f, -0.7714190f, -0.0745471f, + -0.1583210f, 1.3936666f, -1.1807573f, -0.6911215f, -0.1539799f, -0.6865327f, 0.5996337f, -0.7800899f, + 0.1926797f, 0.4295036f, -1.3045609f, 1.3919017f, -0.6824273f, 
-0.0342412f, -0.1761061f, -0.6753551f}, + {-0.1453138f, -2.1702523f, -0.8408509f, 0.0124695f, 0.0548023f, -0.0599126f, -1.3220210f, -0.9635140f, + 0.0304381f, 0.8915846f, -0.0114084f, -0.1811013f, -1.0392225f, 0.9355077f, 0.5600502f, -0.1725564f, + 1.3016142f, 0.0637798f, 0.2127237f, -0.4201719f, -0.1813546f, -2.8520944f, -1.0162476f, 0.4036767f, + -1.3526999f, -0.3187918f, 0.2210670f, -0.0195748f, 0.5739521f, -0.2086164f, -4.5898190f, 0.8253796f}, + {-0.1084008f, -1.1329324f, -1.6204234f, 1.1476817f, -0.0628205f, -0.1122542f, 0.5731813f, 0.2743464f, + -0.0484099f, -0.8538088f, 0.0105360f, -0.7381154f, -1.1602422f, -1.2242011f, 0.2023281f, -0.2181710f, + -1.7745955f, -0.0181406f, 1.6978747f, -1.8673037f, -0.1872354f, 1.3988467f, -1.3061260f, 0.6495667f, + -0.6886965f, -0.6353581f, 0.2563086f, 0.4972568f, 1.2073671f, -0.0324760f, 0.6208668f, 2.0307891f}, + {-0.0193408f, -0.4770618f, 1.1954961f, -4.2316422f, -0.0323112f, -0.1459102f, -0.8704525f, 0.0916627f, + 0.1220654f, 1.4996083f, 0.1548122f, -1.9588864f, -1.5469869f, -1.3433179f, -1.0718721f, -0.0825612f, + -0.4096117f, 0.9981126f, -1.7012634f, -1.8265936f, -0.0371830f, 2.3563027f, -1.3538713f, 0.6455814f, + 1.7223636f, 0.7526782f, -2.3576136f, 1.1849345f, 2.1408458f, -0.1714138f, 0.4818093f, -0.3588967f}, + {0.0729335f, -6.1714144f, 2.1946981f, -4.1299558f, 0.0044464f, -0.1933377f, 0.6864235f, -0.3803750f, + -0.2148182f, 0.9621077f, -0.1807104f, 2.4574375f, 1.2062972f, 2.5480094f, 0.6602007f, -0.0865193f, + -1.3168519f, 1.7552648f, 0.5018759f, 1.1154609f, -0.2282289f, 0.7631954f, -1.0011274f, -1.3729336f, + 0.9056255f, -0.0414166f, -1.8079906f, 1.0183337f, 0.3374647f, 0.0572285f, -0.4237293f, 1.7220364f}, + {-0.0589474f, -0.0295253f, 0.1225147f, 0.1446855f, 0.1071389f, 0.1601978f, -0.0401577f, -0.0364199f, + 0.0953632f, 0.0868486f, 0.1608090f, 0.0642403f, 0.1249414f, 0.1899325f, 0.0255690f, -0.0968316f, + -0.1216723f, -0.1698821f, 0.0820711f, 0.1747911f, 0.0620590f, -0.1446941f, -0.1555044f, 0.0741209f, + 
-0.0763885f, -0.1246467f, 0.1337765f, -0.0873028f, 0.0942246f, 0.0860358f, 0.1234084f, 0.1226101f}, + {0.0643494f, -0.8147358f, -1.5970768f, -1.5264196f, -0.1014192f, 0.0894014f, -0.2399453f, 1.0807495f, + -0.1235767f, -1.2951756f, -0.0810054f, -1.1764668f, 0.4282590f, 0.0908309f, 0.6702700f, 0.0942179f, + -0.8752475f, 1.0613892f, 2.6491807f, 0.4649454f, 0.0426983f, -0.8645003f, -1.2832506f, 1.0818568f, + -0.6891628f, -2.6222782f, -0.0669045f, -5.2834606f, -3.7087319f, -0.1093205f, -3.0082226f, 0.0202389f}, + {-0.2165046f, -2.3887262f, 1.3350971f, 3.2154691f, -0.0899850f, -0.1132472f, 3.4892216f, 0.6763555f, + -0.0366738f, -0.3074943f, -0.0737181f, -0.2638237f, -2.1615725f, 0.9533494f, 1.0567867f, -0.0743023f, + -1.5021936f, -2.1020463f, -0.4901834f, 0.3259082f, -0.2369750f, -0.1386719f, -0.1885716f, -0.6453994f, + -0.2157237f, -0.3096344f, -0.7145861f, -0.0144529f, -1.7015969f, 0.0518598f, -0.8833122f, 0.1368720f}, + {-0.1693773f, -1.7163130f, 1.4214562f, -5.3610578f, -0.1453443f, 0.0007269f, 2.3939090f, -0.4252889f, + -0.2080639f, 0.4997855f, -0.1604623f, -0.9568118f, 0.1174134f, 1.5827537f, 0.0471937f, 0.0410202f, + -5.9910173f, -1.7195174f, 2.6566005f, -0.1300166f, -0.3056662f, -0.4734422f, 0.5415567f, -0.6322125f, + -0.6389906f, 0.8110722f, -1.7760222f, -0.1890267f, 0.4568902f, 0.0178664f, 0.9815944f, -0.6494362f}, + {-0.0803987f, -0.4049147f, -1.1458402f, -3.8709407f, -0.0748584f, 0.0285123f, 1.0476711f, -0.4686022f, + -0.1482179f, -1.3519528f, -0.1977116f, 1.0795574f, -2.0080688f, 0.8830637f, -1.8861086f, -0.1462930f, + -0.4670950f, -2.0300276f, -1.3247769f, 1.1512129f, -0.1377436f, 1.7818434f, 0.5111244f, -1.0817790f, + -1.9341105f, -0.2747863f, 1.7866142f, -0.6981304f, 0.2916693f, -0.0235312f, -0.8175534f, -1.6998042f}, + {-0.1061751f, 0.6215375f, -0.4626823f, 0.7672512f, 0.0020817f, -0.1591142f, 1.0898968f, -0.9204068f, + 0.0574663f, 0.1774926f, 0.0484907f, -0.3295842f, 0.4489160f, -1.1343844f, 1.5402520f, -0.1346410f, + -1.5354218f, -0.7182790f, 
0.2196787f, -4.9884086f, -0.1196452f, -0.7342232f, -0.2625498f, -0.1683128f, + -3.4147658f, 0.4656263f, -0.1907654f, -1.5676821f, -0.1993192f, 0.0392413f, -0.4939966f, 0.2587339f}, + {-0.1575988f, 0.3401670f, 0.3991243f, 0.7242632f, -0.1744899f, -0.1183499f, 0.0967574f, -0.3592824f, + -0.2704069f, -0.2581256f, -0.2485954f, 0.3446464f, -0.7076147f, -0.6296598f, 1.0094991f, -0.1585015f, + -0.0655680f, -0.1737440f, -0.6190755f, 0.2832435f, -0.0455194f, -1.0389528f, 1.4021788f, 0.1367040f, + -0.2918817f, -0.8456540f, -0.1551147f, -0.8092477f, 0.2053470f, 0.0517656f, 0.3124267f, -0.0847588f}, + {0.0037535f, 1.7123971f, -0.0702230f, -7.8142509f, -0.1969137f, -0.0857682f, -0.2811019f, -0.4737110f, + -0.1132331f, -0.1521158f, -0.0600281f, 0.0706206f, 1.0167074f, 0.2372540f, 0.5142798f, -0.1154348f, + -0.4991046f, 1.8288782f, -0.2266811f, 0.5554674f, -0.2068119f, -1.1852149f, 1.3449707f, -2.0556967f, + -0.9963455f, 0.2519412f, 0.1331067f, -0.4544231f, 1.0645961f, -0.1106808f, 0.9460998f, 0.4178981f}, + {-0.2119178f, 1.9799466f, -0.2797689f, -1.0932020f, 0.0028856f, -0.1588331f, -0.3786546f, 1.1156173f, + -0.0215458f, -5.3119068f, -0.0948427f, -1.5506617f, -1.4219491f, 2.4189386f, 1.4762870f, -0.0996478f, + 1.9492981f, 0.9989491f, 0.9490891f, -0.7439177f, -0.0704944f, -0.7028235f, 0.2047088f, -2.2349808f, + 0.5578019f, -2.0313745f, -0.0651128f, -0.3024914f, -2.0939598f, -0.0180862f, -1.0688573f, -1.6645123f}, + {-0.0167054f, -0.9788288f, 0.5361816f, 1.6640991f, -0.0242798f, 0.0521166f, -0.4116896f, -0.8344618f, + -0.0963774f, 3.8331950f, -0.0267836f, 2.8407450f, 1.5915482f, -1.3799198f, -2.6205218f, -0.0732574f, + 0.7984170f, -2.0659196f, -3.5870969f, 2.8777542f, -0.1461587f, 1.1945763f, 1.1590073f, 1.2939118f, + 0.4921311f, 4.3183928f, 0.2045673f, 4.7427683f, 3.8618271f, -0.0308256f, 4.8768749f, 1.0944498f}, + {-0.1358386f, -0.1472694f, -0.1158773f, -0.0939915f, 0.1222765f, 0.1192099f, 0.0518727f, 0.0441144f, + 0.0335920f, 0.1631196f, -0.0097409f, 0.0913221f, -0.1351816f, 
-0.0724972f, -0.0489905f, 0.0919923f, + -0.0411735f, 0.0600587f, 0.0222677f, 0.0181216f, -0.0119533f, 0.0149932f, 0.1105281f, 0.0120449f, + -0.0900230f, 0.1096532f, -0.0958123f, 0.0478526f, -0.0528525f, -0.0530951f, 0.1649548f, 0.0884538f}, + {-0.1365108f, -1.4470286f, -3.3190532f, -0.6094688f, -0.2091497f, -0.0999000f, -0.9306201f, -1.0360260f, + -0.1827736f, 0.0683407f, -0.0502374f, 0.3790842f, -3.5734098f, -0.9349656f, 0.3886600f, 0.1461777f, + -2.0582819f, 1.1910707f, -1.9501384f, -2.9547875f, -0.1470737f, -1.4672521f, -0.8007376f, -0.9336768f, + 1.3155514f, 0.4972472f, -5.6431427f, -3.5151341f, -5.8025484f, -0.0358306f, 0.4548545f, 0.2571939f}, + {-0.1756670f, -1.3679352f, -0.3739035f, -0.4339395f, -0.2128811f, -0.1038225f, 0.5042853f, 0.4281598f, + -0.2762718f, -0.3050812f, -0.0561761f, -0.3663020f, -0.7264599f, -0.7135370f, -1.4355675f, -0.0855041f, + -1.6003639f, -3.3269587f, 0.3331296f, -0.2510884f, -0.0435672f, 2.3782334f, -0.8790841f, -0.4602313f, + 0.3282675f, 1.0674137f, -0.4261901f, 0.6184355f, -0.6284660f, -0.2009129f, -1.3984579f, -0.1182039f}, + {-0.0179437f, -2.4176567f, -0.0757853f, -0.9053250f, -0.1604971f, 0.0653039f, -0.0456533f, -1.2324991f, + -0.2042288f, -1.4000354f, -0.1496307f, 0.7025797f, -0.8148692f, -2.2639663f, -0.1080219f, -0.1692714f, + -0.0350256f, 1.1112232f, -1.2173100f, -3.3865623f, -0.0014034f, 1.0519972f, -0.4149089f, -0.9822370f, + 1.0764426f, 1.0734540f, 0.2395243f, -0.9695317f, -0.2875118f, 0.1405493f, -0.8919160f, 0.4231086f}, + {-0.2191032f, 1.8018681f, -1.7559167f, 0.0348163f, -0.0597553f, 0.0096346f, 0.4048410f, -2.5880859f, + -0.2669998f, 0.5746318f, -0.1141291f, 0.3722134f, -0.6411135f, -0.8711343f, 0.9618454f, -0.1413043f, + 0.5999789f, 0.4171342f, -0.0654649f, -1.1597379f, -0.0378705f, -1.1590790f, 1.3731012f, -0.1211245f, + -0.4004700f, -0.8745431f, 0.3397753f, -0.4758925f, 0.4651093f, -0.1950274f, 0.7756749f, 0.2193565f}, + {-0.0431680f, 1.4789274f, -1.2237777f, 1.1632382f, -0.0625869f, -0.0474214f, -1.3440026f, 
1.4450136f, + -0.1436337f, 0.7880324f, -0.2373608f, 1.4573339f, 0.7586362f, -0.9148275f, -0.2211355f, 0.0550283f, + -6.0262470f, 0.3978752f, -0.1995126f, -0.0479593f, -0.1120174f, -1.2093679f, 3.4314570f, 0.0222109f, + -1.1163449f, 0.2131999f, 2.2462623f, -0.4972607f, -0.5182921f, -0.0701132f, 0.5019436f, -0.1937658f}, + {-0.1667943f, 0.0196575f, -0.9141287f, -0.2902696f, 0.0318007f, -0.0335459f, 0.5858797f, -0.0900941f, + -0.0614002f, 1.2998632f, -0.0909332f, -0.5147573f, -0.0871529f, 1.2078508f, -0.2408318f, 0.0943059f, + 2.0550728f, -2.3531728f, -0.4569368f, -0.0462965f, -0.1727646f, 0.8314115f, -0.0838026f, 1.1487722f, + -2.3224678f, 0.6001235f, 0.4483957f, -2.0149994f, -0.5540643f, -0.0993498f, -0.3938250f, 1.0691496f}, + {-0.2128484f, -0.9067336f, -0.8067648f, -0.5638098f, 0.0466387f, -0.1250697f, -1.0633522f, 2.0702426f, + -0.2612922f, 0.4281299f, 0.0145171f, 0.8089548f, 1.0637047f, -1.6296221f, -0.0324031f, -0.1453255f, + 0.5244270f, 0.4705407f, 0.3187122f, 0.8282533f, -0.0451930f, 0.6502138f, -0.3958839f, 0.8502606f, + 0.9307000f, 0.3706145f, -1.7802641f, 0.1567971f, 0.0740849f, -0.0895442f, 0.3950359f, -1.7540554f}, + {0.0023149f, 0.7263200f, 1.8512523f, -0.4771179f, -0.1163291f, 0.0224466f, 2.1579833f, -0.2968757f, + -0.0573871f, 0.9053006f, -0.0305074f, -0.9798708f, -0.0732942f, 1.8749733f, 0.1304994f, -0.1875225f, + 1.1224291f, 0.5052640f, -0.4813080f, -1.1107147f, -0.1909173f, -2.3484604f, 0.1204147f, -0.4592973f, + -0.2538125f, 0.4104489f, -0.7043998f, 1.0401071f, 0.0527195f, -0.1755383f, -0.0225556f, -1.7959408f}, + {-0.1738733f, -2.5466421f, 0.4712865f, -4.1575985f, -0.0673208f, -0.0929229f, 1.4804220f, -2.0162396f, + -0.0699170f, -1.2635130f, 0.0033261f, -0.2207318f, 1.6315613f, -0.2808344f, 0.6500907f, 0.0873509f, + 1.2061590f, 0.9874881f, 0.0503163f, 1.1066431f, -0.2690387f, 3.0120494f, 1.0901014f, -0.8387108f, + -0.1055128f, 0.7355948f, -1.6113559f, 0.1626149f, 0.6186574f, -0.1442034f, -1.3751197f, -0.7357581f}, + {0.0801182f, -1.7048576f, 
-0.6614535f, 0.2159117f, -0.1067486f, -0.0390127f, -1.7236508f, -4.8257847f, + -0.1392784f, 2.6549494f, -0.0076066f, -0.2928389f, -1.4872373f, -2.2248919f, 0.3740856f, 0.0089244f, + -2.1601932f, 0.2364899f, 2.2799277f, 0.6705108f, 0.0404501f, 0.8240705f, 0.9378061f, 1.7379807f, + 0.6561645f, -2.2824814f, 0.4317684f, 0.9481612f, 3.3949718f, -0.1386193f, -0.5392355f, -0.2192377f}, + {-0.0682519f, 2.1510222f, -1.3331705f, 1.7252599f, -0.0390061f, -0.0311894f, -1.4449217f, 0.7528419f, + 0.0482161f, -0.9996649f, -0.1083589f, 0.7481404f, -1.0902423f, 0.3155636f, -0.7373950f, -0.0225023f, + -0.5480027f, -1.7043972f, 0.1796622f, -1.9089470f, -0.0284076f, 0.1727537f, 0.9399895f, 1.2661113f, + -0.1629798f, 0.9542015f, -0.3868639f, -0.2013974f, 0.7035374f, -0.1902321f, 0.0618807f, 1.6585518f}, + {-0.0382449f, -0.3556600f, 0.7932450f, -1.7579209f, 0.0754038f, -0.1452644f, -0.0262194f, 0.0631723f, + -0.1482574f, 0.4631275f, 0.0041328f, -0.5673072f, 0.6279140f, -0.4053148f, 0.7965385f, 0.0043115f, + -0.7390732f, -0.1081802f, -0.3900261f, 0.2581029f, -0.0874927f, 0.5915189f, -1.1614733f, 0.6526067f, + -0.0630925f, -1.2862672f, 0.0975272f, -1.6696635f, -2.1918375f, 0.0246310f, -0.5078331f, -1.6059371f}, + }; + + HOST_DEVICE_CONSTANT float bias_output_layer[3] = {-0.0299599f, -0.5291600f, 0.4870134f}; + + HOST_DEVICE_CONSTANT float wgtT_output_layer[32][3] = { + {0.0675000f, 0.0580845f, 0.0252083f}, {-0.5480256f, 0.5014070f, -0.4848022f}, + {0.0558893f, -0.9268196f, 0.6123658f}, {-0.1926070f, 0.4826855f, -0.7119671f}, + {-0.0407576f, -0.0624052f, 0.0473613f}, {0.0114524f, -0.0661213f, 0.0861401f}, + {0.4889011f, 0.3984242f, -0.1460386f}, {-0.0307660f, -1.1088817f, 0.5681090f}, + {-0.0103742f, -0.0451352f, 0.0913075f}, {0.2240870f, -0.4127513f, -0.3115116f}, + {-0.1053132f, 0.0329629f, -0.0964765f}, {0.1293648f, -0.0118805f, -0.5233608f}, + {0.0734460f, 0.5619589f, -0.4186259f}, {-0.3173129f, 0.1465155f, -0.1484945f}, + {-0.5634108f, 0.2698199f, 0.1681544f}, {0.1714749f, 
-0.1649845f, 0.1014268f}, + {0.1057630f, 0.9072341f, -1.1890781f}, {-0.3175716f, -0.2992002f, 0.3401313f}, + {-0.4994496f, -0.1189708f, 0.3650176f}, {0.4023024f, -0.9219202f, 0.0693439f}, + {-0.1709604f, -0.0994071f, 0.0222464f}, {0.3324146f, 0.0158491f, -0.4939574f}, + {0.0952293f, -0.5191534f, 0.3818873f}, {0.0176812f, 0.2723607f, -0.3078566f}, + {0.1919187f, -0.8505318f, 0.1964855f}, {0.0822281f, 0.0565761f, -0.5816049f}, + {-0.0743144f, 0.0852944f, -0.7256451f}, {0.1760607f, 0.0578475f, -0.8243266f}, + {0.1617602f, 0.1823115f, -0.4042889f}, {-0.0384557f, -0.0115344f, 0.0508929f}, + {0.5513267f, -0.0695007f, -1.2113558f}, {0.0910288f, -0.4101452f, 0.3242988f}, + }; +} //namespace ALPAKA_ACCELERATOR_NAMESPACE::lst::dnn::t4dnn +#endif diff --git a/RecoTracker/LSTCore/src/alpaka/TrackCandidate.h b/RecoTracker/LSTCore/src/alpaka/TrackCandidate.h index a74e767b2807b..8cfe00506588b 100644 --- a/RecoTracker/LSTCore/src/alpaka/TrackCandidate.h +++ b/RecoTracker/LSTCore/src/alpaka/TrackCandidate.h @@ -17,6 +17,7 @@ #include "RecoTracker/LSTCore/interface/SegmentsSoA.h" #include "RecoTracker/LSTCore/interface/TrackCandidatesSoA.h" #include "RecoTracker/LSTCore/interface/TripletsSoA.h" +#include "RecoTracker/LSTCore/interface/QuadrupletsSoA.h" #include "NeuralNetwork.h" @@ -77,6 +78,40 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { candsExtended.radius()[trackCandidateIndex] = __F2H(radius); } + ALPAKA_FN_ACC ALPAKA_FN_INLINE void addT4TrackCandidateToMemory(TrackCandidatesBase& candsBase, + TrackCandidatesExtended& candsExtended, + LSTObjType trackCandidateType, + unsigned int innerTrackletIndex, + unsigned int outerTrackletIndex, + uint8_t* logicalLayerIndices, + uint16_t* lowerModuleIndices, + unsigned int* hitIndices, + int pixelSeedIndex, + float centerX, + float centerY, + float radius, + unsigned int trackCandidateIndex, + unsigned int directObjectIndex) { + candsBase.trackCandidateType()[trackCandidateIndex] = trackCandidateType; + 
candsExtended.directObjectIndices()[trackCandidateIndex] = directObjectIndex; + candsBase.pixelSeedIndex()[trackCandidateIndex] = pixelSeedIndex; + + candsExtended.objectIndices()[trackCandidateIndex][0] = innerTrackletIndex; + candsExtended.objectIndices()[trackCandidateIndex][1] = outerTrackletIndex; + + //send the starting pointer to the logicalLayer and hitIndices + for (int i = 0; i < Params_T4::kLayers; i++) { + candsExtended.logicalLayers()[trackCandidateIndex][i] = logicalLayerIndices[i]; + candsExtended.lowerModuleIndices()[trackCandidateIndex][i] = lowerModuleIndices[i]; + } + for (int i = 0; i < 2 * Params_T4::kLayers; i++) { + candsBase.hitIndices()[trackCandidateIndex][i] = hitIndices[i]; + } + candsExtended.centerX()[trackCandidateIndex] = __F2H(centerX); + candsExtended.centerY()[trackCandidateIndex] = __F2H(centerY); + candsExtended.radius()[trackCandidateIndex] = __F2H(radius); + } + ALPAKA_FN_ACC ALPAKA_FN_INLINE int checkPixelHits( unsigned int ix, unsigned int jx, MiniDoubletsConst mds, SegmentsConst segments, HitsBaseConst hitsBase) { int phits1[Params_pLS::kHits]; @@ -231,7 +266,8 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { PixelSegments pixelSegments, MiniDoubletsConst mds, HitsBaseConst hitsBase, - QuintupletsConst quintuplets) const { + QuintupletsConst quintuplets, + QuadrupletsConst quadruplets) const { int pixelModuleIndex = modules.nLowerModules(); unsigned int nPixels = segmentsOccupancy.nSegments()[pixelModuleIndex]; for (unsigned int pixelArrayIndex : cms::alpakatools::uniform_elements_y(acc, nPixels)) { @@ -314,6 +350,123 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { } }; + struct CrossCleanT4 { + ALPAKA_FN_ACC void operator()(Acc3D const& acc, + ModulesConst modules, + Quadruplets quadruplets, + QuadrupletsOccupancyConst quadrupletsOccupancy, + PixelQuintupletsConst pixelQuintuplets, + PixelTripletsConst pixelTriplets, + QuintupletsConst quintuplets, + TrackCandidatesBase candsBase, + TrackCandidatesExtended candsExtended, 
+ MiniDoubletsConst mds, + SegmentsConst segments, + TripletsConst triplets, + ObjectRangesConst ranges) const { + for (int lowmod : cms::alpakatools::uniform_elements_z(acc, modules.nLowerModules())) { + if (ranges.quadrupletModuleIndices()[lowmod] == -1) + continue; + + unsigned int nQuads = quadrupletsOccupancy.nQuadruplets()[lowmod]; + for (unsigned int iOff : cms::alpakatools::uniform_elements_y(acc, nQuads)) { + unsigned int iT4 = ranges.quadrupletModuleIndices()[lowmod] + iOff; + + // skip already-dup + if (quadruplets.isDup()[iT4]) + continue; + + // Cross cleaning step + float eta1 = __H2F(quadruplets.eta()[iT4]); + float phi1 = __H2F(quadruplets.phi()[iT4]); + + unsigned int nTrackCandidates = candsBase.nTrackCandidates(); + for (unsigned int trackCandidateIndex : cms::alpakatools::uniform_elements_x(acc, nTrackCandidates)) { + short type = candsBase.trackCandidateType()[trackCandidateIndex]; + unsigned int outerTrackletIdx = candsExtended.objectIndices()[trackCandidateIndex][1]; + if (type == LSTObjType::T5) { + unsigned int quintupletIndex = outerTrackletIdx; // T5 index + uint16_t t5_lowerModIdx1 = quintuplets.lowerModuleIndices()[quintupletIndex][0]; + short layer2_adjustment = 1; + short layer3_adjustment; + int layer = modules.layers()[t5_lowerModIdx1]; + if (layer == 1) { + layer3_adjustment = 1; + } else { + layer3_adjustment = 0; + } + int innerTripletIndex = quintuplets.tripletIndices()[quintupletIndex][0]; + float phi2 = + mds.anchorPhi()[segments.mdIndices()[triplets.segmentIndices()[innerTripletIndex][layer3_adjustment]] + [layer2_adjustment]]; + float eta2 = + mds.anchorEta()[segments.mdIndices()[triplets.segmentIndices()[innerTripletIndex][layer3_adjustment]] + [layer2_adjustment]]; + float dEta = alpaka::math::abs(acc, eta1 - eta2); + float dPhi = cms::alpakatools::deltaPhi(acc, phi1, phi2); + + float dR2 = dEta * dEta + dPhi * dPhi; + if (dR2 < 1e-3f) { + quadruplets.isDup()[iT4] = true; + } + } + if (type == LSTObjType::pT3) { + int 
pT3Index = outerTrackletIdx; + uint16_t pT3_lowerModIdx1 = pixelTriplets.lowerModuleIndices()[pT3Index][0]; + short layer2_adjustment = 1; + short layer3_adjustment; + int layer = modules.layers()[pT3_lowerModIdx1]; + if (layer == 1) { + layer3_adjustment = 1; + } else { + layer3_adjustment = 0; + } + int innerTripletIndex = pixelTriplets.tripletIndices()[pT3Index]; + float phi2 = + mds.anchorPhi()[segments.mdIndices()[triplets.segmentIndices()[innerTripletIndex][layer3_adjustment]] + [layer2_adjustment]]; + float eta2 = + mds.anchorEta()[segments.mdIndices()[triplets.segmentIndices()[innerTripletIndex][layer3_adjustment]] + [layer2_adjustment]]; + float dEta = alpaka::math::abs(acc, eta1 - eta2); + float dPhi = cms::alpakatools::deltaPhi(acc, phi1, phi2); + + float dR2 = dEta * dEta + dPhi * dPhi; + if (dR2 < 1e-3f) + quadruplets.isDup()[iT4] = true; + } + if (type == LSTObjType::pT5) { + unsigned int quintupletIndex = outerTrackletIdx; + uint16_t t5_lowerModIdx1 = quintuplets.lowerModuleIndices()[quintupletIndex][0]; + short layer2_adjustment = 1; + short layer3_adjustment; + int layer = modules.layers()[t5_lowerModIdx1]; + if (layer == 1) { + layer3_adjustment = 1; + } else { + layer3_adjustment = 0; + } + int innerTripletIndex = quintuplets.tripletIndices()[quintupletIndex][0]; + float phi2 = + mds.anchorPhi()[segments.mdIndices()[triplets.segmentIndices()[innerTripletIndex][layer3_adjustment]] + [layer2_adjustment]]; + float eta2 = + mds.anchorEta()[segments.mdIndices()[triplets.segmentIndices()[innerTripletIndex][layer3_adjustment]] + [layer2_adjustment]]; + float dEta = alpaka::math::abs(acc, eta1 - eta2); + float dPhi = cms::alpakatools::deltaPhi(acc, phi1, phi2); + + float dR2 = dEta * dEta + dPhi * dPhi; + if (dR2 < 1e-3f) { + quadruplets.isDup()[iT4] = true; + } + } + } + } + } + } + }; + struct AddpT3asTrackCandidates { ALPAKA_FN_ACC void operator()(Acc1D const& acc, uint16_t nLowerModules, @@ -434,8 +587,8 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst 
{ unsigned int trackCandidateIdx = alpaka::atomicAdd(acc, &candsBase.nTrackCandidates(), 1u, alpaka::hierarchy::Threads{}); - if (trackCandidateIdx - candsExtended.nTrackCandidatesT5() >= - n_max_pixel_track_candidates) // T5 TCs have already been added + if (trackCandidateIdx - candsExtended.nTrackCandidatesT5() - candsExtended.nTrackCandidatesT4() >= + n_max_pixel_track_candidates) // T5, T4 TCs have already been added { #ifdef WARNINGS printf("Track Candidate excess alert! Type = pLS"); @@ -507,6 +660,60 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { } } }; + + struct AddT4asTrackCandidate { + ALPAKA_FN_ACC void operator()(Acc2D const& acc, + uint16_t nLowerModules, + Quadruplets quadruplets, + QuadrupletsOccupancyConst quadrupletsOccupancy, + TripletsConst triplets, + TrackCandidatesBase candsBase, + TrackCandidatesExtended candsExtended, + ObjectRangesConst ranges) const { + for (int idx : cms::alpakatools::uniform_elements_y(acc, nLowerModules)) { + if (ranges.quadrupletModuleIndices()[idx] == -1) + continue; + + unsigned int nQuads = quadrupletsOccupancy.nQuadruplets()[idx]; + for (unsigned int jdx : cms::alpakatools::uniform_elements_x(acc, nQuads)) { + unsigned int quadrupletIndex = ranges.quadrupletModuleIndices()[idx] + jdx; + + if (quadruplets.isDup()[quadrupletIndex]) + continue; + + unsigned int trackCandidateIdx = + alpaka::atomicAdd(acc, &candsBase.nTrackCandidates(), 1u, alpaka::hierarchy::Threads{}); + if (trackCandidateIdx - candsExtended.nTrackCandidatespT5() - candsExtended.nTrackCandidatespT3() - + candsExtended.nTrackCandidatesT5() >= + n_max_nonpixel_track_candidates) // pT5, pT3, T5 TCs have been added, but not pLS TCs + { +#ifdef WARNINGS + printf("Track Candidate excess alert! 
Type = T4"); +#endif + alpaka::atomicSub(acc, &candsBase.nTrackCandidates(), 1u, alpaka::hierarchy::Threads{}); + break; + } else { + alpaka::atomicAdd(acc, &candsExtended.nTrackCandidatesT4(), 1u, alpaka::hierarchy::Threads{}); + addT4TrackCandidateToMemory(candsBase, + candsExtended, + LSTObjType::T4, + quadrupletIndex, + quadrupletIndex, + quadruplets.logicalLayers()[quadrupletIndex].data(), + quadruplets.lowerModuleIndices()[quadrupletIndex].data(), + quadruplets.hitIndices()[quadrupletIndex].data(), + -1 /*no pixel seed index for T4s*/, + quadruplets.regressionCenterX()[quadrupletIndex], + quadruplets.regressionCenterY()[quadrupletIndex], + quadruplets.regressionRadius()[quadrupletIndex], + trackCandidateIdx, + quadrupletIndex); + quadruplets.partOfTC()[quadrupletIndex] = true; + } + } + } + } + }; } // namespace ALPAKA_ACCELERATOR_NAMESPACE::lst ASSERT_DEVICE_MATCHES_HOST_COLLECTION(lst::TrackCandidatesBaseDeviceCollection, lst::TrackCandidatesBaseHostCollection); diff --git a/RecoTracker/LSTCore/src/alpaka/Triplet.h b/RecoTracker/LSTCore/src/alpaka/Triplet.h index 51e641d07a83c..36341e7969008 100644 --- a/RecoTracker/LSTCore/src/alpaka/Triplet.h +++ b/RecoTracker/LSTCore/src/alpaka/Triplet.h @@ -29,7 +29,8 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { float circleCenterX, float circleCenterY, unsigned int tripletIndex, - float (&t3Scores)[dnn::t3dnn::kOutputFeatures]) { + float (&t3Scores)[dnn::t3dnn::kOutputFeatures], + short charge) { triplets.segmentIndices()[tripletIndex][0] = innerSegmentIndex; triplets.segmentIndices()[tripletIndex][1] = outerSegmentIndex; triplets.lowerModuleIndices()[tripletIndex][0] = innerInnerLowerModuleIndex; @@ -57,6 +58,8 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { triplets.hitIndices()[tripletIndex][3] = mds.outerHitIndices()[secondMDIndex]; triplets.hitIndices()[tripletIndex][4] = mds.anchorHitIndices()[thirdMDIndex]; triplets.hitIndices()[tripletIndex][5] = mds.outerHitIndices()[thirdMDIndex]; + + 
triplets.charge()[tripletIndex] = charge; #ifdef CUT_VALUE_DEBUG triplets.betaInCut()[tripletIndex] = betaInCut; #endif @@ -78,7 +81,8 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { unsigned int thirdMDIndex, float circleRadius, float circleCenterX, - float circleCenterY) { + float circleCenterY, + short& charge) { // Using lst_layer numbering convention defined in ModuleMethods.h const short layer1 = modules.lstLayers()[innerInnerLowerModuleIndex]; const short layer2 = modules.lstLayers()[middleLowerModuleIndex]; @@ -97,6 +101,18 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { //use linear approximation for regions 9 and 20-24 because it works better (see https://github.com/SegmentLinking/cmssw/pull/92) float residual = alpaka::math::abs(acc, z2 - ((z3 - z1) / (r3 - r1) * (r2 - r1) + z1)); + //get the x,y position of each MD + const float x1 = mds.anchorX()[firstMDIndex] / 100; + const float x2 = mds.anchorX()[secondMDIndex] / 100; + const float x3 = mds.anchorX()[thirdMDIndex] / 100; + + const float y1 = mds.anchorY()[firstMDIndex] / 100; + const float y2 = mds.anchorY()[secondMDIndex] / 100; + const float y3 = mds.anchorY()[thirdMDIndex] / 100; + + float cross = (x2 - x1) * (y3 - y1) - (y2 - y1) * (x3 - x1); + charge = -1 * ((int)copysignf(1.0f, cross)); + //region definitions: https://github.com/user-attachments/assets/2b3c1425-66eb-4524-83de-deb6f3b31f71 if (layer1 == 1 && layer2 == 7) { return residual < 0.01f; // Region 9 @@ -121,15 +137,6 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { //get the type of module: 0 is ps, 1 is 2s const bool moduleType3 = modules.moduleType()[outerOuterLowerModuleIndex]; - //get the x,y position of each MD - const float x1 = mds.anchorX()[firstMDIndex] / 100; - const float x2 = mds.anchorX()[secondMDIndex] / 100; - const float x3 = mds.anchorX()[thirdMDIndex] / 100; - - const float y1 = mds.anchorY()[firstMDIndex] / 100; - const float y2 = mds.anchorY()[secondMDIndex] / 100; - const float y3 = mds.anchorY()[thirdMDIndex] / 
100; - //set initial and target points float x_init = x2; float y_init = y2; @@ -165,9 +172,6 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { float y_center = circleCenterY / 100; float pt = 2 * k2Rinv1GeVf * circleRadius; //k2Rinv1GeVf is already in cm^(-1) - float cross = (x2 - x1) * (y3 - y1) - (y2 - y1) * (x3 - x1); - short charge = -1 * ((int)copysignf(1.0f, cross)); - //get the px and py at the initial point float px = 2 * charge * k2Rinv1GeVf * (y_init - y_center) * 100; float py = -2 * charge * k2Rinv1GeVf * (x_init - x_center) * 100; @@ -400,7 +404,8 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { float& circleCenterX, float& circleCenterY, const float ptCut, - float (&t3Scores)[dnn::t3dnn::kOutputFeatures]) { + float (&t3Scores)[dnn::t3dnn::kOutputFeatures], + short& charge) { const unsigned int firstMDIndex = segments.mdIndices()[innerSegmentIndex][0]; const unsigned int secondMDIndex = segments.mdIndices()[outerSegmentIndex][0]; const unsigned int thirdMDIndex = segments.mdIndices()[outerSegmentIndex][1]; @@ -426,7 +431,8 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { thirdMDIndex, circleRadius, circleCenterX, - circleCenterY)) + circleCenterY, + charge)) return false; const float rt_InLo = mds.anchorRt()[firstMDIndex]; @@ -574,6 +580,7 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { uint16_t outerOuterLowerModuleIndex = segments.outerLowerModuleIndices()[outerSegmentIndex]; float betaIn, betaInCut, circleRadius, circleCenterX, circleCenterY; + short charge; float t3Scores[dnn::t3dnn::kOutputFeatures] = {0.f}; @@ -592,8 +599,8 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { circleCenterX, circleCenterY, ptCut, - t3Scores); - + t3Scores, + charge); if (success) { unsigned int totOccupancyTriplets = alpaka::atomicAdd(acc, @@ -627,7 +634,8 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { circleCenterX, circleCenterY, tripletIndex, - t3Scores); + t3Scores, + charge); } } } diff --git a/RecoTracker/LSTCore/standalone/analysis/DNN/train_T4_DNN.ipynb 
def load_root_file(file_path, branches=None, print_branches=False):
    """Read branches of the ``"tree"`` object in one ROOT file into a dict.

    Args:
        file_path: Path handed to ``uproot.open``.
        branches: Iterable of branch names to read; ``None`` reads every key
            reported by the tree.
        print_branches: When true, print the tree's branch names.

    Returns:
        Dict mapping each successfully read branch name to its array
        (``library="np"``), plus ``'event'`` holding ``tree.num_entries``.
        Branches that raise ``uproot.KeyInFileError`` are reported with
        ``print`` and left out.
    """
    loaded = {}
    with uproot.open(file_path) as root_file:
        tree = root_file["tree"]
        wanted = tree.keys() if branches is None else branches
        if print_branches:
            print("Branches:", tree.keys())
        for name in wanted:
            try:
                loaded[name] = tree[name].array(library="np")
            except uproot.KeyInFileError as e:
                print(f"KeyInFileError: {e}")
        # Number of entries in this file's tree.
        loaded['event'] = tree.num_entries
    return loaded


def load_root_files(file_path1, file_path2, branches=None, print_branches=False):
    """Read the same branches from two ROOT files and concatenate them.

    Files are processed in the order given. A branch seen in both files is
    joined with ``np.concatenate`` (first file first); ``'event'`` holds the
    sum of both trees' ``num_entries``.

    Args:
        file_path1: Path of the first ROOT file.
        file_path2: Path of the second ROOT file.
        branches: Iterable of branch names to read; ``None`` reads every key
            reported by each tree.
        print_branches: When true, print each tree's branch names.

    Returns:
        Dict of branch name -> concatenated array, plus the ``'event'`` total.
        Branches that raise ``uproot.KeyInFileError`` are reported with
        ``print`` and skipped for that file.
    """
    merged = {}
    for path in (file_path1, file_path2):
        with uproot.open(path) as root_file:
            tree = root_file["tree"]
            wanted = tree.keys() if branches is None else branches
            if print_branches:
                print(f"Branches in {path}:", tree.keys())
            for name in wanted:
                try:
                    fresh = tree[name].array(library="np")
                except uproot.KeyInFileError as e:
                    print(f"KeyInFileError in {path}: {e}")
                    continue
                if name in merged:
                    merged[name] = np.concatenate((merged[name], fresh))
                else:
                    merged[name] = fresh
            # Running total of entries across the files read so far.
            merged['event'] = merged.get('event', 0) + tree.num_entries
    return merged
def delta_phi(phi1, phi2):
    """Return ``phi1 - phi2`` wrapped into the range [-pi, pi].

    Works for scalars and for NumPy arrays (element-wise, with broadcasting),
    and for differences of any magnitude: the nearest whole number of turns
    is removed, not just a single one.

    Args:
        phi1: Angle(s) in radians.
        phi2: Angle(s) in radians.

    Returns:
        The wrapped difference: a NumPy scalar for scalar inputs, an array
        otherwise. A difference of exactly +pi or -pi is returned unchanged.
    """
    two_pi = 2 * np.pi
    delta = np.asarray(phi1) - np.asarray(phi2)
    # np.round rounds halves to the nearest even integer, so delta == +/-pi
    # (ratio +/-0.5) rounds to zero turns and keeps its sign.
    wrapped = delta - two_pi * np.round(delta / two_pi)
    # [()] turns a 0-d result into a scalar and leaves real arrays untouched.
    return wrapped[()]
branches['t4_outerRadius'][event][i]\n", + "\n", + " regRad = branches['t4_regressionRadius'][event][i]\n", + " nonAnchorRegRad = branches['t4_nonAnchorRegressionRadius'][event][i]\n", + "\n", + " f1 = branches['t4_t3_fakeScore1'][event][i]\n", + " f2 = branches['t4_t3_fakeScore2'][event][i]\n", + " p1 = branches['t4_t3_promptScore1'][event][i]\n", + " p2 = branches['t4_t3_promptScore2'][event][i]\n", + " d1 = branches['t4_t3_displacedScore1'][event][i]\n", + " d2 = branches['t4_t3_displacedScore2'][event][i]\n", + "\n", + "\n", + " # Construct the input feature vector using pairwise differences\n", + " features_iter = [\n", + " eta1 / eta_max, # First hit eta, normalized\n", + " np.abs(phi1) / phi_max, # First hit phi, normalized\n", + " z1 / z_max, # First hit z, normalized\n", + " r1 / r_max, # First hit r, normalized\n", + "\n", + " eta2 - eta1, # Difference in eta between hit 2 and 1\n", + " delta_phi(phi2, phi1) / phi_max, # Difference in phi between hit 2 and 1\n", + " (z2 - z1) / z_max, # Difference in z between hit 2 and 1, normalized\n", + " (r2 - r1) / r_max, # Difference in r between hit 2 and 1, normalized\n", + "\n", + " eta3 - eta2, # Difference in eta between hit 3 and 2\n", + " delta_phi(phi3, phi2) / phi_max, # Difference in phi between hit 3 and 2\n", + " (z3 - z2) / z_max, # Difference in z between hit 3 and 2, normalized\n", + " (r3 - r2) / r_max, # Difference in r between hit 3 and 2, normalized\n", + "\n", + " eta4 - eta3, # Difference in eta between hit 4 and 3\n", + " delta_phi(phi4, phi3) / phi_max, # Difference in phi between hit 4 and 3\n", + " (z4 - z3) / z_max, # Difference in z between hit 4 and 3, normalized\n", + " (r4 - r3) / r_max, # Difference in r between hit 4 and 3, normalized\n", + "\n", + " 1.0/innerRad,\n", + " 1.0/outerRad,\n", + " innerRad/outerRad,\n", + " 1.0/regRad,\n", + " 1.0/nonAnchorRegRad,\n", + "\n", + " f1,\n", + " p1,\n", + " d1,\n", + "\n", + " f2,\n", + " p2,\n", + " d2,\n", + "\n", + " (f2 - f1),\n", + " (p2 
- p1),\n", + " (d2 - d1),\n", + " ]\n", + "\n", + " # Use the abs eta value of first hit to select cut thresholds\n", + " eta_iter.extend([np.abs(branches['t4_t3_0_eta'][event][idx0])])\n", + " \n", + " # Append the feature vector to the list\n", + " features_list.append(features_iter)\n", + " eta_list.append(eta_iter)\n", + "\n", + "# Convert the list of features to a NumPy array\n", + "features = np.array(features_list).T\n", + "eta_list = np.array(eta_list).T" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "# Stack features along a new axis to form a single array suitable for NN input\n", + "input_features_numpy = np.stack(features, axis=-1)\n", + "\n", + "# Identify rows with NaN or Inf values\n", + "mask = ~np.isnan(input_features_numpy) & ~np.isinf(input_features_numpy)\n", + "\n", + "# Apply mask across all columns: retain a row only if all its entries are neither NaN nor Inf\n", + "filtered_input_features_numpy = input_features_numpy[np.all(mask, axis=1)]\n", + "t4_isFake_filtered = (np.concatenate(branches['t4_pMatched']) <= 0.75)[np.all(mask, axis=1)]\n", + "t4_sim_vxy_filtered = np.concatenate(branches['t4_sim_vxy'])[np.all(mask, axis=1)]\n", + "\n", + "# Convert to PyTorch tensor when ready to use with NN\n", + "input_features_tensor = torch.tensor(filtered_input_features_numpy, dtype=torch.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cuda\n", + "Total samples: 1946967\n", + "Fake count: 1932985\n", + "Real count: 13982\n", + "Prompt count: 2190\n", + "Displaced count: 11792\n", + "Initial dataset size: 1946967\n", + "Class distribution before downsampling - Fake: 1932985.0, Prompt: 2190.0, Displaced: 11792.0\n", + "Class distribution after downsampling - Fake: 966492.0, Prompt: 2190.0, Displaced: 11792.0\n" + ] + }, + { + 
"name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [1/300], Train Loss: 0.8262, Test Loss: 0.7307\n", + "Epoch [2/300], Train Loss: 0.6984, Test Loss: 0.6723\n", + "Epoch [3/300], Train Loss: 0.6558, Test Loss: 0.6354\n", + "Epoch [4/300], Train Loss: 0.6422, Test Loss: 0.6504\n", + "Epoch [5/300], Train Loss: 0.6285, Test Loss: 0.6006\n", + "Epoch [6/300], Train Loss: 0.5979, Test Loss: 0.5772\n", + "Epoch [7/300], Train Loss: 0.5928, Test Loss: 0.5758\n", + "Epoch [8/300], Train Loss: 0.5960, Test Loss: 0.5796\n", + "Epoch [9/300], Train Loss: 0.5725, Test Loss: 0.5631\n", + "Epoch [10/300], Train Loss: 0.5847, Test Loss: 0.5568\n", + "Epoch [11/300], Train Loss: 0.5568, Test Loss: 0.5725\n", + "Epoch [12/300], Train Loss: 0.5551, Test Loss: 0.5485\n", + "Epoch [13/300], Train Loss: 0.5623, Test Loss: 0.5463\n", + "Epoch [14/300], Train Loss: 0.5651, Test Loss: 0.5555\n", + "Epoch [15/300], Train Loss: 0.5495, Test Loss: 0.5448\n", + "Epoch [16/300], Train Loss: 0.5534, Test Loss: 0.5455\n", + "Epoch [17/300], Train Loss: 0.5334, Test Loss: 0.6103\n", + "Epoch [18/300], Train Loss: 0.5272, Test Loss: 0.5251\n", + "Epoch [19/300], Train Loss: 0.5258, Test Loss: 0.5222\n", + "Epoch [20/300], Train Loss: 0.5201, Test Loss: 0.5431\n", + "Epoch [21/300], Train Loss: 0.5144, Test Loss: 0.5056\n", + "Epoch [22/300], Train Loss: 0.5118, Test Loss: 0.5412\n", + "Epoch [23/300], Train Loss: 0.4975, Test Loss: 0.5118\n", + "Epoch [24/300], Train Loss: 0.4988, Test Loss: 0.5144\n", + "Epoch [25/300], Train Loss: 0.4954, Test Loss: 0.5034\n", + "Epoch [26/300], Train Loss: 0.4959, Test Loss: 0.5003\n", + "Epoch [27/300], Train Loss: 0.4833, Test Loss: 0.4957\n", + "Epoch [28/300], Train Loss: 0.4849, Test Loss: 0.5282\n", + "Epoch [29/300], Train Loss: 0.4788, Test Loss: 0.4945\n", + "Epoch [30/300], Train Loss: 0.4824, Test Loss: 0.4891\n", + "Epoch [31/300], Train Loss: 0.4792, Test Loss: 0.4802\n", + "Epoch [32/300], Train Loss: 0.4737, Test Loss: 0.5338\n", 
+ "Epoch [33/300], Train Loss: 0.4752, Test Loss: 0.5139\n", + "Epoch [34/300], Train Loss: 0.4672, Test Loss: 0.4943\n", + "Epoch [35/300], Train Loss: 0.4710, Test Loss: 0.5034\n", + "Epoch [36/300], Train Loss: 0.4575, Test Loss: 0.5106\n", + "Epoch [37/300], Train Loss: 0.4573, Test Loss: 0.4946\n", + "Epoch [38/300], Train Loss: 0.4609, Test Loss: 0.4791\n", + "Epoch [39/300], Train Loss: 0.4584, Test Loss: 0.4649\n", + "Epoch [40/300], Train Loss: 0.4688, Test Loss: 0.4792\n", + "Epoch [41/300], Train Loss: 0.4613, Test Loss: 0.4610\n", + "Epoch [42/300], Train Loss: 0.4510, Test Loss: 0.4804\n", + "Epoch [43/300], Train Loss: 0.4483, Test Loss: 0.4874\n", + "Epoch [44/300], Train Loss: 0.4445, Test Loss: 0.4716\n", + "Epoch [45/300], Train Loss: 0.4424, Test Loss: 0.4908\n", + "Epoch [46/300], Train Loss: 0.4435, Test Loss: 0.4672\n", + "Epoch [47/300], Train Loss: 0.4456, Test Loss: 0.4673\n", + "Epoch [48/300], Train Loss: 0.4424, Test Loss: 0.4532\n", + "Epoch [49/300], Train Loss: 0.4396, Test Loss: 0.4657\n", + "Epoch [50/300], Train Loss: 0.4362, Test Loss: 0.5339\n", + "Epoch [51/300], Train Loss: 0.4441, Test Loss: 0.4684\n", + "Epoch [52/300], Train Loss: 0.4574, Test Loss: 0.4573\n", + "Epoch [53/300], Train Loss: 0.4444, Test Loss: 0.4703\n", + "Epoch [54/300], Train Loss: 0.4352, Test Loss: 0.4398\n", + "Epoch [55/300], Train Loss: 0.4338, Test Loss: 0.4560\n", + "Epoch [56/300], Train Loss: 0.4468, Test Loss: 0.4672\n", + "Epoch [57/300], Train Loss: 0.4252, Test Loss: 0.4407\n", + "Epoch [58/300], Train Loss: 0.4334, Test Loss: 0.4492\n", + "Epoch [59/300], Train Loss: 0.4298, Test Loss: 0.4438\n", + "Epoch [60/300], Train Loss: 0.4231, Test Loss: 0.4564\n", + "Epoch [61/300], Train Loss: 0.4229, Test Loss: 0.4524\n", + "Epoch [62/300], Train Loss: 0.4185, Test Loss: 0.4634\n", + "Epoch [63/300], Train Loss: 0.4166, Test Loss: 0.4557\n", + "Epoch [64/300], Train Loss: 0.4214, Test Loss: 0.4489\n", + "Epoch [65/300], Train Loss: 0.4159, Test 
Loss: 0.4627\n", + "Epoch [66/300], Train Loss: 0.4113, Test Loss: 0.4788\n", + "Epoch [67/300], Train Loss: 0.4130, Test Loss: 0.4512\n", + "Epoch [68/300], Train Loss: 0.4114, Test Loss: 0.4509\n", + "Epoch [69/300], Train Loss: 0.4108, Test Loss: 0.4716\n", + "Epoch [70/300], Train Loss: 0.4154, Test Loss: 0.4326\n", + "Epoch [71/300], Train Loss: 0.4147, Test Loss: 0.4302\n", + "Epoch [72/300], Train Loss: 0.4247, Test Loss: 0.5130\n", + "Epoch [73/300], Train Loss: 0.4078, Test Loss: 0.4326\n", + "Epoch [74/300], Train Loss: 0.4171, Test Loss: 0.4450\n", + "Epoch [75/300], Train Loss: 0.4153, Test Loss: 0.4334\n", + "Epoch [76/300], Train Loss: 0.4270, Test Loss: 0.4437\n", + "Epoch [77/300], Train Loss: 0.4012, Test Loss: 0.4235\n", + "Epoch [78/300], Train Loss: 0.4010, Test Loss: 0.4442\n", + "Epoch [79/300], Train Loss: 0.4078, Test Loss: 0.4476\n", + "Epoch [80/300], Train Loss: 0.4240, Test Loss: 0.4310\n", + "Epoch [81/300], Train Loss: 0.4213, Test Loss: 0.4574\n", + "Epoch [82/300], Train Loss: 0.4130, Test Loss: 0.4812\n", + "Epoch [83/300], Train Loss: 0.3974, Test Loss: 0.4349\n", + "Epoch [84/300], Train Loss: 0.3932, Test Loss: 0.4436\n", + "Epoch [85/300], Train Loss: 0.3967, Test Loss: 0.4406\n", + "Epoch [86/300], Train Loss: 0.3921, Test Loss: 0.4468\n", + "Epoch [87/300], Train Loss: 0.3930, Test Loss: 0.4496\n", + "Epoch [88/300], Train Loss: 0.3922, Test Loss: 0.4462\n", + "Epoch [89/300], Train Loss: 0.4134, Test Loss: 0.4289\n", + "Epoch [90/300], Train Loss: 0.4080, Test Loss: 0.4400\n", + "Epoch [91/300], Train Loss: 0.4070, Test Loss: 0.4224\n", + "Epoch [92/300], Train Loss: 0.3892, Test Loss: 0.4594\n", + "Epoch [93/300], Train Loss: 0.4063, Test Loss: 0.4472\n", + "Epoch [94/300], Train Loss: 0.3997, Test Loss: 0.4308\n", + "Epoch [95/300], Train Loss: 0.4149, Test Loss: 0.5144\n", + "Epoch [96/300], Train Loss: 0.4140, Test Loss: 0.4127\n", + "Epoch [97/300], Train Loss: 0.3880, Test Loss: 0.4359\n", + "Epoch [98/300], Train Loss: 
0.4036, Test Loss: 0.4353\n", + "Epoch [99/300], Train Loss: 0.4043, Test Loss: 0.4192\n", + "Epoch [100/300], Train Loss: 0.4127, Test Loss: 0.4683\n", + "Epoch [101/300], Train Loss: 0.4072, Test Loss: 0.4290\n", + "Epoch [102/300], Train Loss: 0.3958, Test Loss: 0.4598\n", + "Epoch [103/300], Train Loss: 0.4020, Test Loss: 0.4302\n", + "Epoch [104/300], Train Loss: 0.4015, Test Loss: 0.4803\n", + "Epoch [105/300], Train Loss: 0.3887, Test Loss: 0.4258\n", + "Epoch [106/300], Train Loss: 0.3788, Test Loss: 0.4279\n", + "Epoch [107/300], Train Loss: 0.3901, Test Loss: 0.4457\n", + "Epoch [108/300], Train Loss: 0.3904, Test Loss: 0.4271\n", + "Epoch [109/300], Train Loss: 0.3809, Test Loss: 0.4260\n", + "Epoch [110/300], Train Loss: 0.3849, Test Loss: 0.4127\n", + "Epoch [111/300], Train Loss: 0.3864, Test Loss: 0.4762\n", + "Epoch [112/300], Train Loss: 0.3832, Test Loss: 0.4473\n", + "Epoch [113/300], Train Loss: 0.3796, Test Loss: 0.4374\n", + "Epoch [114/300], Train Loss: 0.3783, Test Loss: 0.4334\n", + "Epoch [115/300], Train Loss: 0.3831, Test Loss: 0.4108\n", + "Epoch [116/300], Train Loss: 0.3758, Test Loss: 0.4429\n", + "Epoch [117/300], Train Loss: 0.3785, Test Loss: 0.4300\n", + "Epoch [118/300], Train Loss: 0.3739, Test Loss: 0.4449\n", + "Epoch [119/300], Train Loss: 0.3740, Test Loss: 0.4305\n", + "Epoch [120/300], Train Loss: 0.3743, Test Loss: 0.4437\n", + "Epoch [121/300], Train Loss: 0.3741, Test Loss: 0.4362\n", + "Epoch [122/300], Train Loss: 0.3708, Test Loss: 0.4114\n", + "Epoch [123/300], Train Loss: 0.3679, Test Loss: 0.4234\n", + "Epoch [124/300], Train Loss: 0.3661, Test Loss: 0.5044\n", + "Epoch [125/300], Train Loss: 0.3712, Test Loss: 0.4327\n", + "Epoch [126/300], Train Loss: 0.3703, Test Loss: 0.4273\n", + "Epoch [127/300], Train Loss: 0.3901, Test Loss: 0.4462\n", + "Epoch [128/300], Train Loss: 0.3758, Test Loss: 0.4272\n", + "Epoch [129/300], Train Loss: 0.3759, Test Loss: 0.4172\n", + "Epoch [130/300], Train Loss: 0.3739, Test 
Loss: 0.4229\n", + "Epoch [131/300], Train Loss: 0.3668, Test Loss: 0.4256\n", + "Epoch [132/300], Train Loss: 0.3663, Test Loss: 0.4326\n", + "Epoch [133/300], Train Loss: 0.3726, Test Loss: 0.4198\n", + "Epoch [134/300], Train Loss: 0.3748, Test Loss: 0.4201\n", + "Epoch [135/300], Train Loss: 0.3669, Test Loss: 0.4145\n", + "Epoch [136/300], Train Loss: 0.3636, Test Loss: 0.4297\n", + "Epoch [137/300], Train Loss: 0.3592, Test Loss: 0.4185\n", + "Epoch [138/300], Train Loss: 0.3656, Test Loss: 0.4440\n", + "Epoch [139/300], Train Loss: 0.3650, Test Loss: 0.4338\n", + "Epoch [140/300], Train Loss: 0.3679, Test Loss: 0.4335\n", + "Epoch [141/300], Train Loss: 0.3597, Test Loss: 0.4269\n", + "Epoch [142/300], Train Loss: 0.3618, Test Loss: 0.4381\n", + "Epoch [143/300], Train Loss: 0.3720, Test Loss: 0.4318\n", + "Epoch [144/300], Train Loss: 0.3701, Test Loss: 0.4254\n", + "Epoch [145/300], Train Loss: 0.3643, Test Loss: 0.4900\n", + "Epoch [146/300], Train Loss: 0.3736, Test Loss: 0.4397\n", + "Epoch [147/300], Train Loss: 0.3625, Test Loss: 0.4248\n", + "Epoch [148/300], Train Loss: 0.3519, Test Loss: 0.4480\n", + "Epoch [149/300], Train Loss: 0.3556, Test Loss: 0.4299\n", + "Epoch [150/300], Train Loss: 0.3556, Test Loss: 0.4094\n", + "Epoch [151/300], Train Loss: 0.3525, Test Loss: 0.4243\n", + "Epoch [152/300], Train Loss: 0.3526, Test Loss: 0.4035\n", + "Epoch [153/300], Train Loss: 0.3504, Test Loss: 0.4089\n", + "Epoch [154/300], Train Loss: 0.3539, Test Loss: 0.4252\n", + "Epoch [155/300], Train Loss: 0.3576, Test Loss: 0.4084\n", + "Epoch [156/300], Train Loss: 0.3520, Test Loss: 0.4198\n", + "Epoch [157/300], Train Loss: 0.3495, Test Loss: 0.4285\n", + "Epoch [158/300], Train Loss: 0.3499, Test Loss: 0.4082\n", + "Epoch [159/300], Train Loss: 0.3475, Test Loss: 0.4138\n", + "Epoch [160/300], Train Loss: 0.3461, Test Loss: 0.4153\n", + "Epoch [161/300], Train Loss: 0.3513, Test Loss: 0.4413\n", + "Epoch [162/300], Train Loss: 0.3505, Test Loss: 
0.4231\n", + "Epoch [163/300], Train Loss: 0.3615, Test Loss: 0.4049\n", + "Epoch [164/300], Train Loss: 0.3439, Test Loss: 0.4157\n", + "Epoch [165/300], Train Loss: 0.3502, Test Loss: 0.4159\n", + "Epoch [166/300], Train Loss: 0.3454, Test Loss: 0.4066\n", + "Epoch [167/300], Train Loss: 0.3453, Test Loss: 0.4293\n", + "Epoch [168/300], Train Loss: 0.3484, Test Loss: 0.4205\n", + "Epoch [169/300], Train Loss: 0.3466, Test Loss: 0.4128\n", + "Epoch [170/300], Train Loss: 0.3473, Test Loss: 0.4063\n", + "Epoch [171/300], Train Loss: 0.3396, Test Loss: 0.4006\n", + "Epoch [172/300], Train Loss: 0.3476, Test Loss: 0.4058\n", + "Epoch [173/300], Train Loss: 0.3449, Test Loss: 0.4076\n", + "Epoch [174/300], Train Loss: 0.3450, Test Loss: 0.4103\n", + "Epoch [175/300], Train Loss: 0.3403, Test Loss: 0.4116\n", + "Epoch [176/300], Train Loss: 0.3432, Test Loss: 0.4331\n", + "Epoch [177/300], Train Loss: 0.3403, Test Loss: 0.4116\n", + "Epoch [178/300], Train Loss: 0.3400, Test Loss: 0.4104\n", + "Epoch [179/300], Train Loss: 0.3396, Test Loss: 0.4070\n", + "Epoch [180/300], Train Loss: 0.3407, Test Loss: 0.4206\n", + "Epoch [181/300], Train Loss: 0.3399, Test Loss: 0.4223\n", + "Epoch [182/300], Train Loss: 0.3404, Test Loss: 0.4279\n", + "Epoch [183/300], Train Loss: 0.3392, Test Loss: 0.4018\n", + "Epoch [184/300], Train Loss: 0.3459, Test Loss: 0.4183\n", + "Epoch [185/300], Train Loss: 0.3385, Test Loss: 0.4166\n", + "Epoch [186/300], Train Loss: 0.3394, Test Loss: 0.4159\n", + "Epoch [187/300], Train Loss: 0.3373, Test Loss: 0.4247\n", + "Epoch [188/300], Train Loss: 0.3413, Test Loss: 0.4185\n", + "Epoch [189/300], Train Loss: 0.3387, Test Loss: 0.4141\n", + "Epoch [190/300], Train Loss: 0.3391, Test Loss: 0.4188\n", + "Epoch [191/300], Train Loss: 0.3350, Test Loss: 0.4356\n", + "Epoch [192/300], Train Loss: 0.3351, Test Loss: 0.4201\n", + "Epoch [193/300], Train Loss: 0.3375, Test Loss: 0.4211\n", + "Epoch [194/300], Train Loss: 0.3372, Test Loss: 0.4053\n", + 
"Epoch [195/300], Train Loss: 0.3351, Test Loss: 0.4102\n", + "Epoch [196/300], Train Loss: 0.3372, Test Loss: 0.4055\n", + "Epoch [197/300], Train Loss: 0.3351, Test Loss: 0.4109\n", + "Epoch [198/300], Train Loss: 0.3354, Test Loss: 0.4085\n", + "Epoch [199/300], Train Loss: 0.3372, Test Loss: 0.4105\n", + "Epoch [200/300], Train Loss: 0.3377, Test Loss: 0.4089\n", + "Epoch [201/300], Train Loss: 0.3372, Test Loss: 0.4181\n", + "Epoch [202/300], Train Loss: 0.3372, Test Loss: 0.4441\n", + "Epoch [203/300], Train Loss: 0.3332, Test Loss: 0.4079\n", + "Epoch [204/300], Train Loss: 0.3342, Test Loss: 0.4259\n", + "Epoch [205/300], Train Loss: 0.3423, Test Loss: 0.4104\n", + "Epoch [206/300], Train Loss: 0.3342, Test Loss: 0.4055\n", + "Epoch [207/300], Train Loss: 0.3396, Test Loss: 0.4037\n", + "Epoch [208/300], Train Loss: 0.3294, Test Loss: 0.4051\n", + "Epoch [209/300], Train Loss: 0.3330, Test Loss: 0.4080\n", + "Epoch [210/300], Train Loss: 0.3363, Test Loss: 0.4226\n", + "Epoch [211/300], Train Loss: 0.3302, Test Loss: 0.4197\n", + "Epoch [212/300], Train Loss: 0.3328, Test Loss: 0.4133\n", + "Epoch [213/300], Train Loss: 0.3328, Test Loss: 0.4245\n", + "Epoch [214/300], Train Loss: 0.3282, Test Loss: 0.4323\n", + "Epoch [215/300], Train Loss: 0.3336, Test Loss: 0.4349\n", + "Epoch [216/300], Train Loss: 0.3318, Test Loss: 0.4043\n", + "Epoch [217/300], Train Loss: 0.3299, Test Loss: 0.4074\n", + "Epoch [218/300], Train Loss: 0.3316, Test Loss: 0.4121\n", + "Epoch [219/300], Train Loss: 0.3327, Test Loss: 0.4349\n", + "Epoch [220/300], Train Loss: 0.3291, Test Loss: 0.4175\n", + "Epoch [221/300], Train Loss: 0.3296, Test Loss: 0.4422\n", + "Epoch [222/300], Train Loss: 0.3321, Test Loss: 0.4050\n", + "Epoch [223/300], Train Loss: 0.3332, Test Loss: 0.4203\n", + "Epoch [224/300], Train Loss: 0.3320, Test Loss: 0.4201\n", + "Epoch [225/300], Train Loss: 0.3307, Test Loss: 0.4229\n", + "Epoch [226/300], Train Loss: 0.3342, Test Loss: 0.4133\n", + "Epoch 
[227/300], Train Loss: 0.3264, Test Loss: 0.4063\n", + "Epoch [228/300], Train Loss: 0.3394, Test Loss: 0.4537\n", + "Epoch [229/300], Train Loss: 0.3335, Test Loss: 0.4053\n", + "Epoch [230/300], Train Loss: 0.3278, Test Loss: 0.4050\n", + "Epoch [231/300], Train Loss: 0.3310, Test Loss: 0.4099\n", + "Epoch [232/300], Train Loss: 0.3313, Test Loss: 0.4339\n", + "Epoch [233/300], Train Loss: 0.3270, Test Loss: 0.4010\n", + "Epoch [234/300], Train Loss: 0.3248, Test Loss: 0.4435\n", + "Epoch [235/300], Train Loss: 0.3309, Test Loss: 0.4129\n", + "Epoch [236/300], Train Loss: 0.3304, Test Loss: 0.4043\n", + "Epoch [237/300], Train Loss: 0.3278, Test Loss: 0.4281\n", + "Epoch [238/300], Train Loss: 0.3266, Test Loss: 0.4177\n", + "Epoch [239/300], Train Loss: 0.3267, Test Loss: 0.4099\n", + "Epoch [240/300], Train Loss: 0.3273, Test Loss: 0.4351\n", + "Epoch [241/300], Train Loss: 0.3263, Test Loss: 0.4138\n", + "Epoch [242/300], Train Loss: 0.3266, Test Loss: 0.4077\n", + "Epoch [243/300], Train Loss: 0.3242, Test Loss: 0.4124\n", + "Epoch [244/300], Train Loss: 0.3222, Test Loss: 0.4295\n", + "Epoch [245/300], Train Loss: 0.3231, Test Loss: 0.4194\n", + "Epoch [246/300], Train Loss: 0.3260, Test Loss: 0.4018\n", + "Epoch [247/300], Train Loss: 0.3247, Test Loss: 0.4260\n", + "Epoch [248/300], Train Loss: 0.3244, Test Loss: 0.4250\n", + "Epoch [249/300], Train Loss: 0.3222, Test Loss: 0.4036\n", + "Epoch [250/300], Train Loss: 0.3214, Test Loss: 0.4194\n", + "Epoch [251/300], Train Loss: 0.3288, Test Loss: 0.4402\n", + "Epoch [252/300], Train Loss: 0.3243, Test Loss: 0.4092\n", + "Epoch [253/300], Train Loss: 0.3195, Test Loss: 0.4114\n", + "Epoch [254/300], Train Loss: 0.3278, Test Loss: 0.4083\n", + "Epoch [255/300], Train Loss: 0.3194, Test Loss: 0.4039\n", + "Epoch [256/300], Train Loss: 0.3190, Test Loss: 0.4172\n", + "Epoch [257/300], Train Loss: 0.3209, Test Loss: 0.4165\n", + "Epoch [258/300], Train Loss: 0.3220, Test Loss: 0.4042\n", + "Epoch [259/300], 
Train Loss: 0.3259, Test Loss: 0.4143\n", + "Epoch [260/300], Train Loss: 0.3220, Test Loss: 0.4093\n", + "Epoch [261/300], Train Loss: 0.3196, Test Loss: 0.4037\n", + "Epoch [262/300], Train Loss: 0.3234, Test Loss: 0.4045\n", + "Epoch [263/300], Train Loss: 0.3253, Test Loss: 0.4052\n", + "Epoch [264/300], Train Loss: 0.3191, Test Loss: 0.4276\n", + "Epoch [265/300], Train Loss: 0.3237, Test Loss: 0.4100\n", + "Epoch [266/300], Train Loss: 0.3177, Test Loss: 0.4121\n", + "Epoch [267/300], Train Loss: 0.3267, Test Loss: 0.4139\n", + "Epoch [268/300], Train Loss: 0.3169, Test Loss: 0.4129\n", + "Epoch [269/300], Train Loss: 0.3194, Test Loss: 0.3981\n", + "Epoch [270/300], Train Loss: 0.3198, Test Loss: 0.4082\n", + "Epoch [271/300], Train Loss: 0.3218, Test Loss: 0.4116\n", + "Epoch [272/300], Train Loss: 0.3179, Test Loss: 0.4261\n", + "Epoch [273/300], Train Loss: 0.3218, Test Loss: 0.4040\n", + "Epoch [274/300], Train Loss: 0.3193, Test Loss: 0.4003\n", + "Epoch [275/300], Train Loss: 0.3216, Test Loss: 0.4119\n", + "Epoch [276/300], Train Loss: 0.3163, Test Loss: 0.4251\n", + "Epoch [277/300], Train Loss: 0.3235, Test Loss: 0.4140\n", + "Epoch [278/300], Train Loss: 0.3131, Test Loss: 0.4133\n", + "Epoch [279/300], Train Loss: 0.3159, Test Loss: 0.4344\n", + "Epoch [280/300], Train Loss: 0.3205, Test Loss: 0.4107\n", + "Epoch [281/300], Train Loss: 0.3182, Test Loss: 0.4140\n", + "Epoch [282/300], Train Loss: 0.3175, Test Loss: 0.4070\n", + "Epoch [283/300], Train Loss: 0.3165, Test Loss: 0.4210\n", + "Epoch [284/300], Train Loss: 0.3131, Test Loss: 0.4466\n", + "Epoch [285/300], Train Loss: 0.3196, Test Loss: 0.4135\n", + "Epoch [286/300], Train Loss: 0.3156, Test Loss: 0.4003\n", + "Epoch [287/300], Train Loss: 0.3134, Test Loss: 0.4294\n", + "Epoch [288/300], Train Loss: 0.3205, Test Loss: 0.4209\n", + "Epoch [289/300], Train Loss: 0.3153, Test Loss: 0.4058\n", + "Epoch [290/300], Train Loss: 0.3174, Test Loss: 0.4105\n", + "Epoch [291/300], Train Loss: 
0.3142, Test Loss: 0.4304\n", + "Epoch [292/300], Train Loss: 0.3162, Test Loss: 0.4192\n", + "Epoch [293/300], Train Loss: 0.3154, Test Loss: 0.4096\n", + "Epoch [294/300], Train Loss: 0.3162, Test Loss: 0.4072\n", + "Epoch [295/300], Train Loss: 0.3135, Test Loss: 0.4309\n", + "Epoch [296/300], Train Loss: 0.3187, Test Loss: 0.4078\n", + "Epoch [297/300], Train Loss: 0.3118, Test Loss: 0.4294\n", + "Epoch [298/300], Train Loss: 0.3129, Test Loss: 0.4383\n", + "Epoch [299/300], Train Loss: 0.3192, Test Loss: 0.4008\n", + "Epoch [300/300], Train Loss: 0.3121, Test Loss: 0.4246\n" + ] + } + ], + "source": [ + "from torch import nn\n", + "from torch.optim import Adam\n", + "from torch.utils.data import DataLoader, TensorDataset, random_split\n", + "\n", + "# Set device\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Using device: {device}\")\n", + "\n", + "# Create labels tensor\n", + "def create_multiclass_labels(t4_isFake, t4_sim_vxy, displacement_threshold=0.1):\n", + " num_samples = len(t4_isFake)\n", + " labels = torch.zeros((num_samples, 3))\n", + " \n", + " # Fake tracks (class 0)\n", + " fake_mask = t4_isFake\n", + " labels[fake_mask, 0] = 1\n", + " \n", + " # Real tracks\n", + " real_mask = ~fake_mask \n", + " \n", + " # Split real tracks into prompt (class 1) and displaced (class 2)\n", + " prompt_mask = (t4_sim_vxy <= displacement_threshold) & real_mask\n", + " displaced_mask = (t4_sim_vxy > displacement_threshold) & real_mask\n", + " \n", + " labels[prompt_mask, 1] = 1\n", + " labels[displaced_mask, 2] = 1\n", + "\n", + " print(f\"Total samples: {num_samples}\")\n", + " print(f\"Fake count: {fake_mask.sum().item()}\")\n", + " print(f\"Real count: {real_mask.sum().item()}\")\n", + " print(f\"Prompt count: {prompt_mask.sum().item()}\")\n", + " print(f\"Displaced count: {displaced_mask.sum().item()}\")\n", + " \n", + " return labels\n", + "\n", + "labels_tensor = create_multiclass_labels(\n", + " 
t4_isFake_filtered,\n", + " t4_sim_vxy_filtered\n", + ")\n", + "\n", + "# Neural network for multi-class classification\n", + "class MultiClassNeuralNetwork(nn.Module):\n", + " def __init__(self):\n", + " super(MultiClassNeuralNetwork, self).__init__()\n", + " self.layer1 = nn.Linear(input_features_numpy.shape[1], 32)\n", + " self.layer2 = nn.Linear(32, 32)\n", + " self.output_layer = nn.Linear(32, 3)\n", + " \n", + " def forward(self, x):\n", + " x = self.layer1(x)\n", + " x = nn.ReLU()(x)\n", + " x = self.layer2(x)\n", + " x = nn.ReLU()(x)\n", + " x = self.output_layer(x)\n", + " return nn.functional.softmax(x, dim=1)\n", + "\n", + "# Weighted loss function for multi-class\n", + "class WeightedCrossEntropyLoss(nn.Module):\n", + " def __init__(self):\n", + " super(WeightedCrossEntropyLoss, self).__init__()\n", + " \n", + " def forward(self, outputs, targets, weights):\n", + " eps = 1e-7\n", + " log_probs = torch.log(outputs + eps)\n", + " losses = -weights * torch.sum(targets * log_probs, dim=1)\n", + " return losses.mean()\n", + "\n", + "\n", + "# Calculate class weights (each sample gets a weight to equalize class contributions)\n", + "def calculate_class_weights(labels):\n", + " class_counts = torch.sum(labels, dim=0)\n", + " total_samples = len(labels)\n", + " class_weights = total_samples / (3 * class_counts) # Normalize across 3 classes\n", + " \n", + " sample_weights = torch.zeros(len(labels))\n", + " for i in range(3):\n", + " sample_weights[labels[:, i] == 1] = class_weights[i]\n", + " \n", + " return sample_weights\n", + "\n", + "# Print initial dataset size\n", + "print(f\"Initial dataset size: {len(labels_tensor)}\")\n", + "\n", + "# Remove rows with NaN in the features or labels (sample weights are computed later from the filtered labels)\n", + "nan_mask = torch.isnan(input_features_tensor).any(dim=1) | torch.isnan(labels_tensor).any(dim=1)\n", + "filtered_inputs = input_features_tensor[~nan_mask]\n", + "filtered_labels = labels_tensor[~nan_mask]\n", + "\n", + "# Count samples in each class before 
downsampling\n", + "class_counts_before = torch.sum(filtered_labels, dim=0)\n", + "print(f\"Class distribution before downsampling - Fake: {class_counts_before[0]}, Prompt: {class_counts_before[1]}, Displaced: {class_counts_before[2]}\")\n", + "\n", + "# Option to downsample each class (multi-class: fake, prompt, displaced)\n", + "downsample_classes = True # Set to False to disable downsampling\n", + "if downsample_classes:\n", + " # Define downsampling ratios for each class:\n", + " # For example, downsample fakes (class 0) to 50% and keep prompt (class 1) and displaced (class 2) at 100%\n", + " downsample_ratios = {0: 0.5, 1: 1.0, 2: 1.0}\n", + " indices_list = []\n", + " for cls in range(3):\n", + " # Find indices for the current class\n", + " cls_mask = (filtered_labels[:, cls] == 1)\n", + " cls_indices = torch.nonzero(cls_mask).squeeze()\n", + " ratio = downsample_ratios.get(cls, 1.0)\n", + " num_cls = cls_indices.numel()\n", + " num_to_sample = int(num_cls * ratio)\n", + " # Ensure at least one sample is kept if available\n", + " if num_to_sample < 1 and num_cls > 0:\n", + " num_to_sample = 1\n", + " # Shuffle and select the desired number of samples\n", + " cls_indices_shuffled = cls_indices[torch.randperm(num_cls)]\n", + " sampled_cls_indices = cls_indices_shuffled[:num_to_sample]\n", + " indices_list.append(sampled_cls_indices)\n", + " \n", + " # Combine the indices from all classes\n", + " selected_indices = torch.cat(indices_list)\n", + " filtered_inputs = filtered_inputs[selected_indices]\n", + " filtered_labels = filtered_labels[selected_indices]\n", + "\n", + "# Print class distribution after downsampling\n", + "class_counts_after = torch.sum(filtered_labels, dim=0)\n", + "print(f\"Class distribution after downsampling - Fake: {class_counts_after[0]}, Prompt: {class_counts_after[1]}, Displaced: {class_counts_after[2]}\")\n", + "\n", + "# Recalculate sample weights after downsampling (equal weighting per class based on new counts)\n", + "sample_weights = 
calculate_class_weights(filtered_labels)\n", + "filtered_weights = sample_weights \n", + "\n", + "# Create dataset with weights\n", + "dataset = TensorDataset(filtered_inputs, filtered_labels, filtered_weights)\n", + "\n", + "# Split into train and test sets\n", + "train_size = int(0.8 * len(dataset))\n", + "test_size = len(dataset) - train_size\n", + "train_dataset, test_dataset = random_split(dataset, [train_size, test_size])\n", + "\n", + "# Create data loaders\n", + "train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True, num_workers=10, pin_memory=True)\n", + "test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False, num_workers=10, pin_memory=True)\n", + "\n", + "# Initialize model and optimizer\n", + "model = MultiClassNeuralNetwork().to(device)\n", + "loss_function = WeightedCrossEntropyLoss()\n", + "optimizer = Adam(model.parameters(), lr=0.0025)\n", + "\n", + "def evaluate_loss(loader):\n", + " model.eval()\n", + " total_loss = 0\n", + " num_batches = 0\n", + " with torch.no_grad():\n", + " for inputs, targets, weights in loader:\n", + " inputs, targets, weights = inputs.to(device), targets.to(device), weights.to(device)\n", + " outputs = model(inputs)\n", + " loss = loss_function(outputs, targets, weights)\n", + " total_loss += loss.item()\n", + " num_batches += 1\n", + " return total_loss / num_batches\n", + "\n", + "# Training loop\n", + "num_epochs = 300\n", + "train_loss_log = []\n", + "test_loss_log = []\n", + "\n", + "for epoch in range(num_epochs):\n", + " model.train()\n", + " epoch_loss = 0\n", + " num_batches = 0\n", + "\n", + " for inputs, targets, weights in train_loader:\n", + " inputs, targets, weights = inputs.to(device), targets.to(device), weights.to(device)\n", + " \n", + " # Forward pass\n", + " outputs = model(inputs)\n", + " loss = loss_function(outputs, targets, weights)\n", + " epoch_loss += loss.item()\n", + " num_batches += 1\n", + "\n", + " # Backward and optimize\n", + " optimizer.zero_grad()\n", + 
" loss.backward()\n", + " optimizer.step()\n", + "\n", + " # Calculate average losses\n", + " train_loss = epoch_loss / num_batches\n", + " test_loss = evaluate_loss(test_loader)\n", + " \n", + " train_loss_log.append(train_loss)\n", + " test_loss_log.append(test_loss)\n", + " \n", + " print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Baseline accuracy: 0.8717\n", + "\n", + "Feature importances:\n", + "Feature 28 importance: 0.0235\n", + "Feature 27 importance: 0.0229\n", + "Feature 0 importance: 0.0202\n", + "Feature 14 importance: 0.0136\n", + "Feature 18 importance: 0.0093\n", + "Feature 6 importance: 0.0086\n", + "Feature 16 importance: 0.0086\n", + "Feature 23 importance: 0.0068\n", + "Feature 24 importance: 0.0062\n", + "Feature 13 importance: 0.0047\n", + "Feature 17 importance: 0.0040\n", + "Feature 11 importance: 0.0034\n", + "Feature 22 importance: 0.0031\n", + "Feature 5 importance: 0.0023\n", + "Feature 25 importance: 0.0023\n", + "Feature 15 importance: 0.0011\n", + "Feature 10 importance: 0.0010\n", + "Feature 20 importance: 0.0008\n", + "Feature 1 importance: 0.0003\n", + "Feature 19 importance: 0.0001\n", + "Feature 8 importance: 0.0000\n", + "Feature 21 importance: -0.0001\n", + "Feature 2 importance: -0.0005\n", + "Feature 9 importance: -0.0006\n", + "Feature 3 importance: -0.0009\n", + "Feature 7 importance: -0.0026\n", + "Feature 29 importance: -0.0026\n", + "Feature 26 importance: -0.0030\n", + "Feature 12 importance: -0.0035\n", + "Feature 4 importance: -0.0044\n" + ] + } + ], + "source": [ + "import torch\n", + "import numpy as np\n", + "from sklearn.metrics import accuracy_score\n", + "\n", + "# Convert tensors to numpy for simplicity if you want to manipulate them outside of PyTorch\n", + "input_features_np = 
input_features_tensor.numpy()\n", + "labels_np = torch.argmax(labels_tensor, dim=1).numpy() # Convert one-hot to class indices\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "def model_accuracy(features, labels, model):\n", + " \"\"\"\n", + " Compute accuracy for a multi-class classification model\n", + " that outputs probabilities of size [batch_size, num_classes].\n", + " \"\"\"\n", + " model.eval() # Set the model to evaluation mode\n", + " \n", + " # Move the features and labels to the correct device\n", + " inputs = features.to(device)\n", + " labels = labels.to(device)\n", + " \n", + " with torch.no_grad():\n", + " outputs = model(inputs) # shape: [batch_size, num_classes]\n", + " # For multi-class, the predicted class is argmax of the probabilities\n", + " predicted = torch.argmax(outputs, dim=1)\n", + " # Convert one-hot encoded labels to class indices if needed\n", + " if len(labels.shape) > 1:\n", + " labels = torch.argmax(labels, dim=1)\n", + " # Compute mean accuracy\n", + " accuracy = (predicted == labels).float().mean().item()\n", + " \n", + " return accuracy\n", + "\n", + "# Compute baseline accuracy\n", + "baseline_accuracy = model_accuracy(input_features_tensor, labels_tensor, model)\n", + "print(f\"Baseline accuracy: {baseline_accuracy:.4f}\")\n", + "\n", + "# Initialize array to store feature importances\n", + "feature_importances = np.zeros(input_features_tensor.shape[1])\n", + "\n", + "# Iterate over each feature for permutation importance\n", + "for i in range(input_features_tensor.shape[1]):\n", + " # Create a copy of the original features\n", + " permuted_features = input_features_tensor.clone()\n", + " \n", + " # Permute feature i across all examples\n", + " # We do this by shuffling the rows for that specific column\n", + " permuted_features[:, i] = permuted_features[torch.randperm(permuted_features.size(0)), i]\n", + " \n", + " # Compute accuracy after permutation\n", + " 
permuted_accuracy = model_accuracy(permuted_features, labels_tensor, model)\n", + " \n", + " # The drop in accuracy is used as a measure of feature importance\n", + " feature_importances[i] = baseline_accuracy - permuted_accuracy\n", + "\n", + "# Sort features by descending importance\n", + "important_features_indices = np.argsort(feature_importances)[::-1]\n", + "important_features_scores = np.sort(feature_importances)[::-1]\n", + "\n", + "# Print out results\n", + "print(\"\\nFeature importances:\")\n", + "for idx, score in zip(important_features_indices, important_features_scores):\n", + " print(f\"Feature {idx} importance: {score:.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_layer1[32] = {\n", + "0.3130181f, -0.3157252f, -0.0845900f, -0.2268437f, 0.1305549f, 0.3839142f, 0.3933745f, -0.6758229f, -0.4188058f, -0.2523611f, 1.4036129f, 0.8239079f, 0.1575654f, 0.2041763f, 0.8787493f, 0.2706699f, -0.1112185f, 0.8988609f, 0.9274163f, -0.1023219f, 0.2916122f, -0.2606929f, 0.3098971f, -0.0602703f, -0.6031470f, -0.0826582f, 0.3605700f, 0.4836628f, -0.3951748f, 0.0171050f, 0.5156327f, 0.0655813f };\n", + "\n", + "ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_layer1[30][32] = {\n", + "{ -0.3409404f, -0.2000102f, -0.0890483f, -0.6186467f, -1.7570605f, 0.7890699f, -0.8753229f, 0.5488843f, -0.5376814f, -0.2228569f, -0.3573552f, 2.1554949f, 0.2248887f, -0.6073594f, 0.2075009f, -0.1408760f, -0.7051892f, -0.0664303f, 0.2747473f, 0.1450685f, -2.2709231f, -0.4088669f, 0.5452566f, 0.3086576f, -0.1213564f, -0.9737034f, 0.2679004f, -0.1193755f, 0.9693206f, -0.7785844f, 0.4612639f, 0.7628022f },\n", + "{ 0.0309823f, -0.5052885f, 0.0509167f, 0.1379152f, 0.2392345f, 0.0935747f, 0.1001187f, 0.2049023f, -0.1292204f, -0.1732666f, 0.0397899f, -0.1621571f, -0.5934961f, 0.1731108f, -0.0290751f, 0.1774067f, 0.0519257f, 
0.0035256f, -0.0188142f, -0.0639332f, -0.2388456f, 0.0038807f, 0.0088869f, -0.1521942f, 0.0138392f, 0.0613728f, -0.0693704f, 0.0792511f, -0.0990796f, -0.4666552f, 0.0071133f, -0.1573301f },\n", + "{ 0.0850391f, 0.1398510f, 0.7312427f, 1.6511540f, -0.1157119f, 1.8312333f, 0.4276405f, 0.1428780f, -0.0762665f, -0.1204791f, 0.5231065f, -1.9012266f, -0.2759933f, 2.1613867f, -0.3789352f, -0.7384162f, 0.9980950f, -1.3757244f, -0.3345136f, -0.1468498f, -0.8294499f, 0.0777341f, 1.4606416f, 0.5440134f, 2.1587088f, 0.0192866f, -4.7703443f, -0.3778406f, 2.0303624f, 0.4716987f, -0.5599666f, 1.0512540f },\n", + "{ 0.1225046f, 0.1352115f, -0.6140373f, -0.4254693f, 0.2660460f, -0.7739570f, 1.0955174f, 1.2222520f, -0.1673346f, -0.0774871f, 0.4850340f, -0.2480809f, 2.9553666f, -0.6026165f, -0.5579752f, -0.3300077f, 0.0972013f, 1.9986047f, 0.9151404f, 0.0942718f, -0.0422994f, -0.0338943f, -0.3063800f, -0.6567743f, 2.2631214f, 0.6305654f, 0.2954915f, 0.8076106f, 0.7743834f, 0.7023419f, -1.6891261f, -0.6115909f },\n", + "{ 2.2928166f, 8.8309431f, -1.5345458f, -4.7059007f, 1.2810588f, -0.4951739f, 4.4845004f, 7.7040706f, -1.8243797f, 0.0956211f, -0.3696172f, 0.6678835f, 4.7292924f, -9.7971382f, 2.0854435f, -3.7175915f, 3.1399553f, -0.6623941f, 0.1497208f, -0.0905410f, -3.9599996f, 3.3040602f, 2.2535894f, -8.0852966f, -16.3342419f, 14.6141891f, 2.6647313f, -10.8139477f, 0.5194562f, -2.2337329f, 1.4479593f, 10.4768066f },\n", + "{ 6.3267903f, 8.4862614f, 18.2483406f, -2.5856876f, 2.9871955f, 6.0332689f, -11.9952469f, -1.8413588f, 1.3217239f, -0.1347535f, 0.8526750f, -4.6671519f, 1.7776268f, 0.2549058f, -14.5247345f, 3.1438231f, -9.3088989f, 1.1253707f, 0.0789171f, 0.1527994f, 1.7165481f, -0.9345309f, -5.6528535f, -9.5160980f, -1.0419953f, -1.0004028f, 1.2163434f, 6.5496387f, 0.0804405f, -6.1611404f, -9.2283039f, -5.5179152f },\n", + "{ -2.3399000f, 4.7614522f, -0.7949132f, -1.8299819f, -3.5392070f, -0.0867105f, -0.6724365f, 0.4369151f, 5.0728159f, -0.2512915f, 0.8541925f, -0.0066838f, 
-3.1859064f, -1.5541859f, 0.0789470f, -1.3237801f, 2.6402714f, -0.0662765f, -0.2674660f, 0.0608886f, 1.1315312f, -0.1070065f, 1.8663841f, -0.4901078f, -1.0676215f, 1.3194273f, -1.2205451f, -1.0945961f, 0.2475978f, 0.9682984f, 0.2073012f, 1.7876559f },\n", + "{ -0.4367641f, 0.4319819f, 1.1654582f, -3.1896923f, -1.2711153f, -0.6044747f, -3.3841281f, 0.1955531f, 4.4088063f, -0.1708138f, -0.3717940f, 0.1636910f, -2.8220074f, -0.1908980f, 1.0748017f, -0.3119625f, -1.5694339f, -0.1778283f, 0.9781732f, -0.0754240f, 4.2073040f, 2.0899282f, 1.2668608f, 0.0779037f, -1.9252455f, 2.5578146f, -0.9014751f, 0.0522716f, 0.4070499f, -0.2140355f, 0.7342044f, 0.8045168f },\n", + "{ -6.6071205f, 0.6651391f, -5.2370253f, 5.0692110f, 0.7510651f, 2.4857941f, 11.5648117f, -1.7613629f, 1.9677536f, 0.1799172f, -0.3524120f, -1.9071139f, -7.4744873f, 2.6429889f, -0.1444485f, 2.3542099f, 1.9711516f, 0.3886011f, 0.0591138f, 0.0699189f, -3.0183461f, 5.6741471f, -1.1488796f, -5.7687993f, -7.4685879f, -9.6448822f, -3.4691532f, 5.0486813f, -12.6432810f, 8.1007729f, -1.7438285f, -17.6806660f },\n", + "{ 2.1428952f, -6.8331566f, 13.9079485f, 0.1021993f, -3.7234893f, 4.1696281f, -2.8855472f, 2.6500645f, 0.2383825f, 0.0089392f, -1.2698770f, -1.4507635f, 1.9835703f, 1.9673402f, 0.0971966f, -3.0400901f, -1.6690960f, -1.8351660f, -0.5974421f, -0.1041481f, -4.3467126f, -0.8935851f, -4.2216101f, -3.5241582f, -1.7024223f, 2.4351561f, -3.5994439f, 8.1877632f, 1.3503532f, -8.8505297f, -4.3802824f, -1.2580007f },\n", + "{ 0.5120507f, -0.8167784f, 0.0356091f, 0.9550222f, -0.0856578f, -1.6678712f, -1.3589050f, -0.4634860f, 1.7959082f, -0.0221804f, 1.6858692f, -0.2350332f, 1.5413535f, -0.8954431f, -1.9626563f, -0.8659950f, 1.5534185f, -0.2663006f, -0.0141180f, -0.0641252f, -2.7467253f, 1.3316264f, 0.1042683f, 1.3546844f, 1.3957758f, -1.1120845f, -0.4499652f, -1.0622435f, 1.5979129f, -2.7754719f, -1.8175740f, -0.2714084f },\n", + "{ 2.8423781f, -0.5802551f, 2.4059951f, -0.1214163f, 0.6598997f, -2.7015202f, 
-0.6345362f, -0.1668582f, 0.9119438f, -0.1413521f, -0.1243188f, -1.5296214f, -2.2767708f, -0.4354365f, -0.2501381f, -4.5476084f, -2.0407526f, -0.1834106f, 0.4992486f, -0.0249097f, 4.2349467f, -2.2088041f, 0.8245230f, 3.3258171f, -0.8291156f, 1.8809289f, 2.0291409f, -0.1829567f, 0.3753935f, -2.1255214f, 0.3044383f, -0.8117877f },\n", + "{ 1.0955867f, -10.5214415f, -2.5690074f, 5.7836456f, 8.4926434f, -0.9013922f, 3.8609581f, -18.2928333f, -0.1085265f, -0.1957067f, 0.4130777f, 8.3614740f, -0.1366454f, 15.5778351f, 4.3706384f, -6.1245522f, -0.2726971f, -2.2554579f, 0.1885025f, 0.1015946f, 3.1017957f, -9.0225515f, -2.1712937f, 3.6819403f, 0.3107212f, 5.4358387f, 0.1556205f, -5.7449660f, -8.0870762f, -15.4704018f, 2.6353786f, 3.1047711f },\n", + "{ 7.6964021f, 5.5716128f, 10.1820107f, 0.3581720f, -3.3344195f, 0.0203794f, 4.8660736f, 0.9843333f, 1.5092058f, -0.1621139f, 1.0461260f, -3.6782911f, -3.2391973f, -1.3214555f, 2.3153496f, -0.1955887f, -0.4074941f, 1.3566486f, 0.4631876f, -0.1068150f, 0.6354957f, 0.4232795f, 1.1560757f, -1.6600587f, 5.4032493f, -0.7661732f, 0.9706129f, -1.9504737f, -2.6826806f, -1.7078609f, 0.2543025f, -0.4549442f },\n", + "{ -0.9363269f, 3.2511759f, 1.4988805f, 0.3896330f, 3.5761805f, -1.8898439f, -0.5672647f, 0.0965158f, 2.3343837f, -0.2087380f, -1.6115571f, -2.6963377f, -2.8213720f, -1.1722008f, -0.7269015f, -3.2212329f, -2.5389972f, 1.8190597f, -0.3862028f, 0.0262889f, -0.8232661f, -0.2829636f, 0.9204201f, -0.5837291f, 1.1138207f, 1.4175742f, 0.6769784f, -1.8336256f, -0.9406338f, 1.4764718f, -0.4086397f, 1.3364717f },\n", + "{ -0.2418238f, 0.2686039f, 0.5018921f, 3.8336344f, 0.0712455f, -2.2001183f, -0.9898972f, 0.2089690f, 3.5906739f, -0.1943938f, -0.6164807f, -2.6908536f, 0.5925170f, -0.3563481f, -1.3926654f, -1.4510415f, 0.0248772f, 0.8849464f, 0.3642189f, -0.1435763f, 1.5389724f, 0.9731716f, 0.7003614f, 1.1896435f, 0.0179679f, -0.7365713f, 0.8497150f, 0.2604503f, -2.4675038f, 1.0341363f, 1.5702873f, 0.1958498f },\n", + "{ -2.8209245f, 
-8.6002207f, -7.2007174f, -4.5361319f, 1.8299714f, -10.6522865f, -11.7428637f, -5.2347007f, 7.4452977f, -0.1641102f, 0.6734878f, 6.6603560f, 0.1524790f, 10.6615028f, -19.3698826f, 2.1673820f, 11.9764709f, 1.0121659f, 4.0597305f, -0.1707846f, -2.1111398f, -6.5540004f, -18.7466564f, -9.4047441f, 3.0535262f, 2.7488508f, 9.8994598f, 1.6856064f, 19.2076931f, 6.8278632f, -21.9050026f, 9.4619055f },\n", + "{ -9.7272358f, 11.2765751f, -7.5082231f, -1.7225909f, 13.1407022f, -14.8349190f, -3.2504196f, 15.8990402f, 10.4930105f, 0.0052292f, -1.3149272f, 4.6401610f, -0.3964378f, -4.2996163f, 0.8589647f, 3.0086250f, 16.7304554f, -1.8806626f, -2.5309651f, -0.1554628f, -1.7752417f, 6.6541591f, -16.9228725f, -4.0883851f, 2.5308323f, 1.2640512f, 6.7352295f, 4.6456785f, 8.4472675f, 6.6682787f, -22.7476864f, 5.0754099f },\n", + "{ -0.0464466f, 0.0371061f, -0.3150824f, 0.2278159f, 0.0781510f, 0.3682392f, -0.0076914f, -0.0199720f, -0.0213357f, -0.0197757f, -3.6537273f, -1.3694360f, -0.1145320f, -0.1972410f, -1.1493046f, 0.5570444f, 0.0053487f, -2.8934319f, -3.5391710f, -0.0754734f, -1.4390492f, 0.5983394f, -1.1443400f, -0.0306496f, -0.0482763f, -0.3964492f, 0.0067964f, -0.3276957f, 0.0020439f, -0.3366909f, -0.0156225f, 0.0541534f },\n", + "{ -19.9178314f, -10.1069622f, 4.6057873f, -1.6936491f, -1.3457170f, 4.9226985f, 9.7256756f, 3.1693432f, -14.5028820f, 0.0931624f, 0.0866399f, 0.5221885f, -5.0346432f, -9.0003071f, 6.8700786f, -0.3239930f, 10.3604975f, 0.2665467f, -0.8745342f, 0.1671114f, 11.8470201f, 7.2739577f, 0.3534940f, 9.4823771f, 7.6572828f, -0.6058345f, 4.8220043f, -4.4774542f, 6.0150776f, 13.1617508f, 8.9563951f, -5.3839235f },\n", + "{ -15.3784885f, -4.7021117f, 1.3418291f, -5.5577135f, -3.5915122f, 2.1950095f, 9.0849819f, 4.1605716f, -9.9853363f, -0.2489175f, 0.0786141f, 1.3932520f, -13.3884411f, -7.0676465f, 8.2610102f, -0.6629214f, 9.6442671f, -1.4963982f, 0.0448971f, 0.0271809f, 11.0636673f, 10.8506622f, 1.5703944f, 11.9227686f, 8.8615837f, 2.9824717f, 6.5997915f, 
-5.1319938f, 6.7163310f, 8.8659868f, 8.8796968f, -0.0585466f },\n", + "{ -0.5536407f, -0.0183597f, -0.8522137f, -1.6674993f, -0.5186954f, 0.2934249f, -0.3867851f, -0.2717040f, 0.3141194f, 0.0803582f, 0.9972823f, 1.0203497f, 0.1344238f, 0.2180226f, 0.3881106f, -0.5604802f, 0.8635465f, 0.7390655f, 0.9010500f, -0.1396771f, -0.5388748f, -0.0631624f, 0.1186292f, -0.5192882f, 0.4407078f, -1.0921305f, 0.4530205f, -0.2301908f, -0.0831401f, -1.6932087f, 0.6916737f, -7.0294051f },\n", + "{ -0.7002707f, -1.2612772f, 0.1443634f, 0.1889855f, -1.5781668f, 0.5102594f, 1.0737917f, -0.3611829f, -9.7292194f, -0.2542418f, 0.7840535f, -0.1868368f, -3.0662625f, 0.1156164f, 1.8433700f, 0.5661708f, -0.7161480f, 0.5610115f, 1.1667061f, 0.0282633f, 2.0658562f, -0.1299639f, 0.8088669f, -0.3167280f, 0.0436082f, 0.2469898f, -0.0169766f, 1.4717587f, -0.6065165f, 0.3588113f, 2.0036082f, 1.5167968f },\n", + "{ 0.8769321f, 0.4987610f, -0.1006753f, 0.1093301f, 0.6219906f, -0.0845869f, 0.5858616f, -0.3685570f, 0.1055476f, 0.0447120f, 0.9492686f, 0.6054065f, 0.3753049f, -0.0302966f, 0.1571343f, 0.5852838f, -0.0052162f, 0.6741483f, 1.0925988f, -0.0104674f, -0.2353034f, -0.2112046f, -0.1704750f, 0.2319756f, -0.6973480f, 0.5177563f, 0.0008280f, 0.0842060f, -0.4818317f, 0.2600521f, -0.8436811f, -0.3244939f },\n", + "{ -0.2498078f, -0.0665436f, -0.6455372f, -11.8636007f, 0.5546330f, -0.1828474f, 0.4531056f, -0.5045753f, 0.0021860f, -0.1727674f, 0.8539888f, 0.9355938f, 0.1768771f, 0.1097697f, 0.3439056f, -2.7710619f, 0.5562497f, 0.7414805f, 0.9723032f, -0.0587594f, -0.2656697f, -0.2262570f, 0.2070577f, -0.4649415f, 0.6072245f, 0.4071639f, 0.7271490f, -0.3399665f, -0.2358655f, -0.6572027f, 0.3878818f, -2.0380676f },\n", + "{ -1.0040656f, -15.5001106f, 0.1743044f, -0.1830439f, -12.8375502f, 0.6479514f, 0.4189325f, 0.0413221f, -2.3929427f, 0.0917285f, 0.8877478f, -0.4985425f, -3.9854083f, 0.2641849f, 1.1492108f, 1.1804928f, -0.4219547f, 0.7097337f, 0.7229153f, -0.1290352f, 1.0844772f, -0.2112084f, 
0.2665555f, -0.0124178f, -0.7096525f, -0.0406028f, 0.1377101f, 0.9613231f, -0.5315894f, 0.3310910f, 1.4178389f, 1.1687748f },\n", + "{ 0.6737692f, 0.5887758f, -0.0501311f, 1.0043997f, 0.9804797f, 0.1708015f, 0.1819106f, -0.5253279f, -0.4184220f, -0.0125217f, 1.0117618f, 0.4730924f, 0.6627691f, -0.1479898f, 0.2580604f, 0.7293127f, 0.0848437f, 0.7156146f, 0.7816427f, 0.0669967f, -0.1618401f, -0.0909416f, 0.1740791f, -0.0961089f, -0.9237953f, 0.1705291f, 0.2970416f, 0.2243806f, -0.2949813f, 0.1089575f, -0.3233873f, -0.5519235f },\n", + "{ 0.8008638f, 0.1581205f, -0.0660521f, -1.8700145f, 1.8415585f, -0.1517836f, 0.9849059f, -0.3305838f, -0.1100404f, 0.1668913f, 0.0475755f, -0.7035441f, 0.0815494f, -0.4517657f, 0.3797469f, -0.7237183f, -0.4326949f, -0.0045456f, -0.1520977f, 0.0641505f, 0.3261537f, 0.3451711f, 0.3400122f, 0.1905922f, 0.5979490f, 1.8726524f, -0.0212991f, 0.5759162f, -0.3411353f, 0.6154488f, -0.6591337f, 1.1044852f },\n", + "{ -0.5229920f, -1.9910779f, 0.0082411f, -1.0964926f, -2.8197410f, 0.0379619f, -0.6362457f, 0.2922821f, 2.1298475f, 0.0429705f, 0.0860498f, 0.0998842f, -1.5189127f, 0.8646886f, -0.3320501f, 0.8533753f, 0.2106902f, -0.1353741f, 0.1139200f, 0.1223277f, -1.3724730f, -0.4976718f, -0.8456128f, 0.6275283f, -0.7746818f, -0.4165801f, -0.3285437f, -1.1218213f, 0.2541133f, -0.0699065f, -0.3644446f, -0.1836386f },\n", + "{ -0.3121711f, 0.6553325f, 0.0358024f, 2.9133885f, 0.0601561f, 0.0650956f, -0.3388072f, -0.1669552f, -1.0281963f, 0.0421454f, -0.0252209f, 0.6697660f, 0.3152214f, -0.1007413f, 0.0397672f, 0.0762647f, 0.6438169f, -0.1315191f, -0.0042900f, -0.0298286f, 0.4026648f, 0.1243395f, 0.5714200f, -0.5934660f, 0.1301748f, -0.8935670f, 0.3112782f, 0.2476320f, 0.3565531f, -0.2412163f, 0.8716605f, -0.7849768f },\n", + "};\n", + "\n", + "ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_layer2[32] = {\n", + "-0.2339502f, -2.1117039f, 1.3029057f, 0.1388091f, -0.1596625f, -0.0429177f, -1.2913821f, -3.2888331f, -0.1078175f, 0.3783462f, -0.1941358f, 
-0.4196923f, 1.5371246f, 1.0102557f, 0.0155053f, -0.1713930f, -0.8003848f, 1.2788725f, 1.2547292f, 0.9737166f, 0.0159556f, -0.9538616f, -0.6874874f, 1.4604672f, -0.0163722f, -0.6740248f, 0.4310561f, 0.4786355f, 0.4338695f, -0.2144726f, -0.3379008f, 0.9070537f };\n", + "\n", + "ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_layer2[32][32] = {\n", + "{ -0.1525858f, -1.7981244f, 0.8225231f, 0.9017935f, -0.1756990f, -0.1731562f, 0.3533860f, -0.1859016f, 0.0066088f, -1.9248251f, -0.1184793f, 0.0248543f, 0.5696007f, -0.0907341f, 1.2301605f, -0.2350527f, -1.4149488f, -0.1846859f, -1.7834854f, 1.0467646f, -0.1320335f, -0.5828664f, -0.9305819f, -1.3805257f, -0.0421591f, -0.9069743f, -1.1278086f, 0.2304732f, -0.5726073f, 0.0060693f, -0.6733936f, 2.4653385f },\n", + "{ -0.1177371f, -13.0121050f, -1.2429829f, -2.7850580f, -0.0618916f, -0.1455433f, -2.9844699f, 2.1670566f, 0.0382921f, 1.7592702f, 0.0663455f, -0.6991704f, 1.1692400f, -1.3941003f, -0.7780491f, -0.1350629f, -3.8783550f, -0.2548187f, 0.3231797f, 1.3278724f, -0.3614419f, -0.6638193f, -0.3801469f, 0.3153987f, 2.9468744f, 0.4463870f, -2.2508523f, 0.2242093f, -0.7827371f, 0.0342467f, -0.1136825f, -1.4055092f },\n", + "{ -0.0814950f, 1.1034510f, -2.7856872f, 1.1012450f, -0.0762363f, 0.0192755f, 0.0942746f, -0.8603304f, 0.0463609f, -0.4289516f, -0.2340941f, 0.7183347f, -0.3241202f, -3.3051322f, 0.6164730f, -0.1720765f, -2.9967003f, -0.2995018f, 1.4385067f, -3.2133255f, -0.0849749f, 1.7341144f, 0.7924995f, 2.0554399f, 1.0967957f, 0.3311919f, -0.9581537f, -0.7702340f, 1.0316148f, 0.0471946f, 0.9495452f, 2.0788605f },\n", + "{ 0.0446539f, -0.7262716f, 0.2242534f, -2.2987046f, -0.1550755f, -0.2146342f, 0.2053458f, 0.1453160f, -0.0573685f, -0.6836352f, -0.1972867f, 0.3058005f, -0.2511542f, 1.0547395f, 0.2851569f, -0.1420605f, 0.1955147f, -0.8554440f, -0.7066809f, -0.3464343f, -0.0067930f, -0.3269323f, -0.2229819f, -0.3971729f, -0.6963940f, -0.2993692f, -0.1417452f, -0.7594388f, 0.5256453f, -0.1079851f, 0.1336530f, -1.2257522f 
},\n", + "{ -0.0354096f, -10.4679108f, 2.8168788f, 0.2949276f, -0.1779284f, -0.1868454f, -1.4323943f, 0.8472028f, -0.0747709f, 0.2078577f, -0.1229924f, 0.0567793f, 2.0551534f, -0.4557192f, -0.7714190f, -0.0745471f, -0.1583210f, 1.3936666f, -1.1807573f, -0.6911215f, -0.1539799f, -0.6865327f, 0.5996337f, -0.7800899f, 0.1926797f, 0.4295036f, -1.3045609f, 1.3919017f, -0.6824273f, -0.0342412f, -0.1761061f, -0.6753551f },\n", + "{ -0.1453138f, -2.1702523f, -0.8408509f, 0.0124695f, 0.0548023f, -0.0599126f, -1.3220210f, -0.9635140f, 0.0304381f, 0.8915846f, -0.0114084f, -0.1811013f, -1.0392225f, 0.9355077f, 0.5600502f, -0.1725564f, 1.3016142f, 0.0637798f, 0.2127237f, -0.4201719f, -0.1813546f, -2.8520944f, -1.0162476f, 0.4036767f, -1.3526999f, -0.3187918f, 0.2210670f, -0.0195748f, 0.5739521f, -0.2086164f, -4.5898190f, 0.8253796f },\n", + "{ -0.1084008f, -1.1329324f, -1.6204234f, 1.1476817f, -0.0628205f, -0.1122542f, 0.5731813f, 0.2743464f, -0.0484099f, -0.8538088f, 0.0105360f, -0.7381154f, -1.1602422f, -1.2242011f, 0.2023281f, -0.2181710f, -1.7745955f, -0.0181406f, 1.6978747f, -1.8673037f, -0.1872354f, 1.3988467f, -1.3061260f, 0.6495667f, -0.6886965f, -0.6353581f, 0.2563086f, 0.4972568f, 1.2073671f, -0.0324760f, 0.6208668f, 2.0307891f },\n", + "{ -0.0193408f, -0.4770618f, 1.1954961f, -4.2316422f, -0.0323112f, -0.1459102f, -0.8704525f, 0.0916627f, 0.1220654f, 1.4996083f, 0.1548122f, -1.9588864f, -1.5469869f, -1.3433179f, -1.0718721f, -0.0825612f, -0.4096117f, 0.9981126f, -1.7012634f, -1.8265936f, -0.0371830f, 2.3563027f, -1.3538713f, 0.6455814f, 1.7223636f, 0.7526782f, -2.3576136f, 1.1849345f, 2.1408458f, -0.1714138f, 0.4818093f, -0.3588967f },\n", + "{ 0.0729335f, -6.1714144f, 2.1946981f, -4.1299558f, 0.0044464f, -0.1933377f, 0.6864235f, -0.3803750f, -0.2148182f, 0.9621077f, -0.1807104f, 2.4574375f, 1.2062972f, 2.5480094f, 0.6602007f, -0.0865193f, -1.3168519f, 1.7552648f, 0.5018759f, 1.1154609f, -0.2282289f, 0.7631954f, -1.0011274f, -1.3729336f, 0.9056255f, -0.0414166f, 
-1.8079906f, 1.0183337f, 0.3374647f, 0.0572285f, -0.4237293f, 1.7220364f },\n", + "{ -0.0589474f, -0.0295253f, 0.1225147f, 0.1446855f, 0.1071389f, 0.1601978f, -0.0401577f, -0.0364199f, 0.0953632f, 0.0868486f, 0.1608090f, 0.0642403f, 0.1249414f, 0.1899325f, 0.0255690f, -0.0968316f, -0.1216723f, -0.1698821f, 0.0820711f, 0.1747911f, 0.0620590f, -0.1446941f, -0.1555044f, 0.0741209f, -0.0763885f, -0.1246467f, 0.1337765f, -0.0873028f, 0.0942246f, 0.0860358f, 0.1234084f, 0.1226101f },\n", + "{ 0.0643494f, -0.8147358f, -1.5970768f, -1.5264196f, -0.1014192f, 0.0894014f, -0.2399453f, 1.0807495f, -0.1235767f, -1.2951756f, -0.0810054f, -1.1764668f, 0.4282590f, 0.0908309f, 0.6702700f, 0.0942179f, -0.8752475f, 1.0613892f, 2.6491807f, 0.4649454f, 0.0426983f, -0.8645003f, -1.2832506f, 1.0818568f, -0.6891628f, -2.6222782f, -0.0669045f, -5.2834606f, -3.7087319f, -0.1093205f, -3.0082226f, 0.0202389f },\n", + "{ -0.2165046f, -2.3887262f, 1.3350971f, 3.2154691f, -0.0899850f, -0.1132472f, 3.4892216f, 0.6763555f, -0.0366738f, -0.3074943f, -0.0737181f, -0.2638237f, -2.1615725f, 0.9533494f, 1.0567867f, -0.0743023f, -1.5021936f, -2.1020463f, -0.4901834f, 0.3259082f, -0.2369750f, -0.1386719f, -0.1885716f, -0.6453994f, -0.2157237f, -0.3096344f, -0.7145861f, -0.0144529f, -1.7015969f, 0.0518598f, -0.8833122f, 0.1368720f },\n", + "{ -0.1693773f, -1.7163130f, 1.4214562f, -5.3610578f, -0.1453443f, 0.0007269f, 2.3939090f, -0.4252889f, -0.2080639f, 0.4997855f, -0.1604623f, -0.9568118f, 0.1174134f, 1.5827537f, 0.0471937f, 0.0410202f, -5.9910173f, -1.7195174f, 2.6566005f, -0.1300166f, -0.3056662f, -0.4734422f, 0.5415567f, -0.6322125f, -0.6389906f, 0.8110722f, -1.7760222f, -0.1890267f, 0.4568902f, 0.0178664f, 0.9815944f, -0.6494362f },\n", + "{ -0.0803987f, -0.4049147f, -1.1458402f, -3.8709407f, -0.0748584f, 0.0285123f, 1.0476711f, -0.4686022f, -0.1482179f, -1.3519528f, -0.1977116f, 1.0795574f, -2.0080688f, 0.8830637f, -1.8861086f, -0.1462930f, -0.4670950f, -2.0300276f, -1.3247769f, 1.1512129f, 
-0.1377436f, 1.7818434f, 0.5111244f, -1.0817790f, -1.9341105f, -0.2747863f, 1.7866142f, -0.6981304f, 0.2916693f, -0.0235312f, -0.8175534f, -1.6998042f },\n", + "{ -0.1061751f, 0.6215375f, -0.4626823f, 0.7672512f, 0.0020817f, -0.1591142f, 1.0898968f, -0.9204068f, 0.0574663f, 0.1774926f, 0.0484907f, -0.3295842f, 0.4489160f, -1.1343844f, 1.5402520f, -0.1346410f, -1.5354218f, -0.7182790f, 0.2196787f, -4.9884086f, -0.1196452f, -0.7342232f, -0.2625498f, -0.1683128f, -3.4147658f, 0.4656263f, -0.1907654f, -1.5676821f, -0.1993192f, 0.0392413f, -0.4939966f, 0.2587339f },\n", + "{ -0.1575988f, 0.3401670f, 0.3991243f, 0.7242632f, -0.1744899f, -0.1183499f, 0.0967574f, -0.3592824f, -0.2704069f, -0.2581256f, -0.2485954f, 0.3446464f, -0.7076147f, -0.6296598f, 1.0094991f, -0.1585015f, -0.0655680f, -0.1737440f, -0.6190755f, 0.2832435f, -0.0455194f, -1.0389528f, 1.4021788f, 0.1367040f, -0.2918817f, -0.8456540f, -0.1551147f, -0.8092477f, 0.2053470f, 0.0517656f, 0.3124267f, -0.0847588f },\n", + "{ 0.0037535f, 1.7123971f, -0.0702230f, -7.8142509f, -0.1969137f, -0.0857682f, -0.2811019f, -0.4737110f, -0.1132331f, -0.1521158f, -0.0600281f, 0.0706206f, 1.0167074f, 0.2372540f, 0.5142798f, -0.1154348f, -0.4991046f, 1.8288782f, -0.2266811f, 0.5554674f, -0.2068119f, -1.1852149f, 1.3449707f, -2.0556967f, -0.9963455f, 0.2519412f, 0.1331067f, -0.4544231f, 1.0645961f, -0.1106808f, 0.9460998f, 0.4178981f },\n", + "{ -0.2119178f, 1.9799466f, -0.2797689f, -1.0932020f, 0.0028856f, -0.1588331f, -0.3786546f, 1.1156173f, -0.0215458f, -5.3119068f, -0.0948427f, -1.5506617f, -1.4219491f, 2.4189386f, 1.4762870f, -0.0996478f, 1.9492981f, 0.9989491f, 0.9490891f, -0.7439177f, -0.0704944f, -0.7028235f, 0.2047088f, -2.2349808f, 0.5578019f, -2.0313745f, -0.0651128f, -0.3024914f, -2.0939598f, -0.0180862f, -1.0688573f, -1.6645123f },\n", + "{ -0.0167054f, -0.9788288f, 0.5361816f, 1.6640991f, -0.0242798f, 0.0521166f, -0.4116896f, -0.8344618f, -0.0963774f, 3.8331950f, -0.0267836f, 2.8407450f, 1.5915482f, -1.3799198f, 
-2.6205218f, -0.0732574f, 0.7984170f, -2.0659196f, -3.5870969f, 2.8777542f, -0.1461587f, 1.1945763f, 1.1590073f, 1.2939118f, 0.4921311f, 4.3183928f, 0.2045673f, 4.7427683f, 3.8618271f, -0.0308256f, 4.8768749f, 1.0944498f },\n", + "{ -0.1358386f, -0.1472694f, -0.1158773f, -0.0939915f, 0.1222765f, 0.1192099f, 0.0518727f, 0.0441144f, 0.0335920f, 0.1631196f, -0.0097409f, 0.0913221f, -0.1351816f, -0.0724972f, -0.0489905f, 0.0919923f, -0.0411735f, 0.0600587f, 0.0222677f, 0.0181216f, -0.0119533f, 0.0149932f, 0.1105281f, 0.0120449f, -0.0900230f, 0.1096532f, -0.0958123f, 0.0478526f, -0.0528525f, -0.0530951f, 0.1649548f, 0.0884538f },\n", + "{ -0.1365108f, -1.4470286f, -3.3190532f, -0.6094688f, -0.2091497f, -0.0999000f, -0.9306201f, -1.0360260f, -0.1827736f, 0.0683407f, -0.0502374f, 0.3790842f, -3.5734098f, -0.9349656f, 0.3886600f, 0.1461777f, -2.0582819f, 1.1910707f, -1.9501384f, -2.9547875f, -0.1470737f, -1.4672521f, -0.8007376f, -0.9336768f, 1.3155514f, 0.4972472f, -5.6431427f, -3.5151341f, -5.8025484f, -0.0358306f, 0.4548545f, 0.2571939f },\n", + "{ -0.1756670f, -1.3679352f, -0.3739035f, -0.4339395f, -0.2128811f, -0.1038225f, 0.5042853f, 0.4281598f, -0.2762718f, -0.3050812f, -0.0561761f, -0.3663020f, -0.7264599f, -0.7135370f, -1.4355675f, -0.0855041f, -1.6003639f, -3.3269587f, 0.3331296f, -0.2510884f, -0.0435672f, 2.3782334f, -0.8790841f, -0.4602313f, 0.3282675f, 1.0674137f, -0.4261901f, 0.6184355f, -0.6284660f, -0.2009129f, -1.3984579f, -0.1182039f },\n", + "{ -0.0179437f, -2.4176567f, -0.0757853f, -0.9053250f, -0.1604971f, 0.0653039f, -0.0456533f, -1.2324991f, -0.2042288f, -1.4000354f, -0.1496307f, 0.7025797f, -0.8148692f, -2.2639663f, -0.1080219f, -0.1692714f, -0.0350256f, 1.1112232f, -1.2173100f, -3.3865623f, -0.0014034f, 1.0519972f, -0.4149089f, -0.9822370f, 1.0764426f, 1.0734540f, 0.2395243f, -0.9695317f, -0.2875118f, 0.1405493f, -0.8919160f, 0.4231086f },\n", + "{ -0.2191032f, 1.8018681f, -1.7559167f, 0.0348163f, -0.0597553f, 0.0096346f, 0.4048410f, -2.5880859f, 
-0.2669998f, 0.5746318f, -0.1141291f, 0.3722134f, -0.6411135f, -0.8711343f, 0.9618454f, -0.1413043f, 0.5999789f, 0.4171342f, -0.0654649f, -1.1597379f, -0.0378705f, -1.1590790f, 1.3731012f, -0.1211245f, -0.4004700f, -0.8745431f, 0.3397753f, -0.4758925f, 0.4651093f, -0.1950274f, 0.7756749f, 0.2193565f },\n", + "{ -0.0431680f, 1.4789274f, -1.2237777f, 1.1632382f, -0.0625869f, -0.0474214f, -1.3440026f, 1.4450136f, -0.1436337f, 0.7880324f, -0.2373608f, 1.4573339f, 0.7586362f, -0.9148275f, -0.2211355f, 0.0550283f, -6.0262470f, 0.3978752f, -0.1995126f, -0.0479593f, -0.1120174f, -1.2093679f, 3.4314570f, 0.0222109f, -1.1163449f, 0.2131999f, 2.2462623f, -0.4972607f, -0.5182921f, -0.0701132f, 0.5019436f, -0.1937658f },\n", + "{ -0.1667943f, 0.0196575f, -0.9141287f, -0.2902696f, 0.0318007f, -0.0335459f, 0.5858797f, -0.0900941f, -0.0614002f, 1.2998632f, -0.0909332f, -0.5147573f, -0.0871529f, 1.2078508f, -0.2408318f, 0.0943059f, 2.0550728f, -2.3531728f, -0.4569368f, -0.0462965f, -0.1727646f, 0.8314115f, -0.0838026f, 1.1487722f, -2.3224678f, 0.6001235f, 0.4483957f, -2.0149994f, -0.5540643f, -0.0993498f, -0.3938250f, 1.0691496f },\n", + "{ -0.2128484f, -0.9067336f, -0.8067648f, -0.5638098f, 0.0466387f, -0.1250697f, -1.0633522f, 2.0702426f, -0.2612922f, 0.4281299f, 0.0145171f, 0.8089548f, 1.0637047f, -1.6296221f, -0.0324031f, -0.1453255f, 0.5244270f, 0.4705407f, 0.3187122f, 0.8282533f, -0.0451930f, 0.6502138f, -0.3958839f, 0.8502606f, 0.9307000f, 0.3706145f, -1.7802641f, 0.1567971f, 0.0740849f, -0.0895442f, 0.3950359f, -1.7540554f },\n", + "{ 0.0023149f, 0.7263200f, 1.8512523f, -0.4771179f, -0.1163291f, 0.0224466f, 2.1579833f, -0.2968757f, -0.0573871f, 0.9053006f, -0.0305074f, -0.9798708f, -0.0732942f, 1.8749733f, 0.1304994f, -0.1875225f, 1.1224291f, 0.5052640f, -0.4813080f, -1.1107147f, -0.1909173f, -2.3484604f, 0.1204147f, -0.4592973f, -0.2538125f, 0.4104489f, -0.7043998f, 1.0401071f, 0.0527195f, -0.1755383f, -0.0225556f, -1.7959408f },\n", + "{ -0.1738733f, -2.5466421f, 
0.4712865f, -4.1575985f, -0.0673208f, -0.0929229f, 1.4804220f, -2.0162396f, -0.0699170f, -1.2635130f, 0.0033261f, -0.2207318f, 1.6315613f, -0.2808344f, 0.6500907f, 0.0873509f, 1.2061590f, 0.9874881f, 0.0503163f, 1.1066431f, -0.2690387f, 3.0120494f, 1.0901014f, -0.8387108f, -0.1055128f, 0.7355948f, -1.6113559f, 0.1626149f, 0.6186574f, -0.1442034f, -1.3751197f, -0.7357581f },\n", + "{ 0.0801182f, -1.7048576f, -0.6614535f, 0.2159117f, -0.1067486f, -0.0390127f, -1.7236508f, -4.8257847f, -0.1392784f, 2.6549494f, -0.0076066f, -0.2928389f, -1.4872373f, -2.2248919f, 0.3740856f, 0.0089244f, -2.1601932f, 0.2364899f, 2.2799277f, 0.6705108f, 0.0404501f, 0.8240705f, 0.9378061f, 1.7379807f, 0.6561645f, -2.2824814f, 0.4317684f, 0.9481612f, 3.3949718f, -0.1386193f, -0.5392355f, -0.2192377f },\n", + "{ -0.0682519f, 2.1510222f, -1.3331705f, 1.7252599f, -0.0390061f, -0.0311894f, -1.4449217f, 0.7528419f, 0.0482161f, -0.9996649f, -0.1083589f, 0.7481404f, -1.0902423f, 0.3155636f, -0.7373950f, -0.0225023f, -0.5480027f, -1.7043972f, 0.1796622f, -1.9089470f, -0.0284076f, 0.1727537f, 0.9399895f, 1.2661113f, -0.1629798f, 0.9542015f, -0.3868639f, -0.2013974f, 0.7035374f, -0.1902321f, 0.0618807f, 1.6585518f },\n", + "{ -0.0382449f, -0.3556600f, 0.7932450f, -1.7579209f, 0.0754038f, -0.1452644f, -0.0262194f, 0.0631723f, -0.1482574f, 0.4631275f, 0.0041328f, -0.5673072f, 0.6279140f, -0.4053148f, 0.7965385f, 0.0043115f, -0.7390732f, -0.1081802f, -0.3900261f, 0.2581029f, -0.0874927f, 0.5915189f, -1.1614733f, 0.6526067f, -0.0630925f, -1.2862672f, 0.0975272f, -1.6696635f, -2.1918375f, 0.0246310f, -0.5078331f, -1.6059371f },\n", + "};\n", + "\n", + "ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_output_layer[3] = {\n", + "-0.0299599f, -0.5291600f, 0.4870134f };\n", + "\n", + "ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_output_layer[32][3] = {\n", + "{ 0.0675000f, 0.0580845f, 0.0252083f },\n", + "{ -0.5480256f, 0.5014070f, -0.4848022f },\n", + "{ 0.0558893f, -0.9268196f, 0.6123658f },\n", + "{ 
-0.1926070f, 0.4826855f, -0.7119671f },\n",
+ "{ -0.0407576f, -0.0624052f, 0.0473613f },\n",
+ "{ 0.0114524f, -0.0661213f, 0.0861401f },\n",
+ "{ 0.4889011f, 0.3984242f, -0.1460386f },\n",
+ "{ -0.0307660f, -1.1088817f, 0.5681090f },\n",
+ "{ -0.0103742f, -0.0451352f, 0.0913075f },\n",
+ "{ 0.2240870f, -0.4127513f, -0.3115116f },\n",
+ "{ -0.1053132f, 0.0329629f, -0.0964765f },\n",
+ "{ 0.1293648f, -0.0118805f, -0.5233608f },\n",
+ "{ 0.0734460f, 0.5619589f, -0.4186259f },\n",
+ "{ -0.3173129f, 0.1465155f, -0.1484945f },\n",
+ "{ -0.5634108f, 0.2698199f, 0.1681544f },\n",
+ "{ 0.1714749f, -0.1649845f, 0.1014268f },\n",
+ "{ 0.1057630f, 0.9072341f, -1.1890781f },\n",
+ "{ -0.3175716f, -0.2992002f, 0.3401313f },\n",
+ "{ -0.4994496f, -0.1189708f, 0.3650176f },\n",
+ "{ 0.4023024f, -0.9219202f, 0.0693439f },\n",
+ "{ -0.1709604f, -0.0994071f, 0.0222464f },\n",
+ "{ 0.3324146f, 0.0158491f, -0.4939574f },\n",
+ "{ 0.0952293f, -0.5191534f, 0.3818873f },\n",
+ "{ 0.0176812f, 0.2723607f, -0.3078566f },\n",
+ "{ 0.1919187f, -0.8505318f, 0.1964855f },\n",
+ "{ 0.0822281f, 0.0565761f, -0.5816049f },\n",
+ "{ -0.0743144f, 0.0852944f, -0.7256451f },\n",
+ "{ 0.1760607f, 0.0578475f, -0.8243266f },\n",
+ "{ 0.1617602f, 0.1823115f, -0.4042889f },\n",
+ "{ -0.0384557f, -0.0115344f, 0.0508929f },\n",
+ "{ 0.5513267f, -0.0695007f, -1.2113558f },\n",
+ "{ 0.0910288f, -0.4101452f, 0.3242988f },\n",
+ "};\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "def print_formatted_weights_biases(weights, biases, layer_name):\n",
+ "    \"\"\"Print one layer's biases and transposed weights as ALPAKA C++ constant arrays.\n",
+ "\n",
+ "    weights is a 2D array; it is emitted transposed as wgtT_<layer_name> with\n",
+ "    dimensions [len(weights[0])][len(weights)]. biases is an iterable of floats\n",
+ "    emitted as bias_<layer_name>.\n",
+ "    \"\"\"\n",
+ "    # Print biases\n",
+ "    print(f\"ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_{layer_name}[{len(biases)}] = {{\")\n",
+ "    print(\", \".join(f\"{b:.7f}f\" for b in biases) + \" };\")\n",
+ "    print()\n",
+ "\n",
+ "    # Print weights\n",
+ "    print(f\"ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_{layer_name}[{len(weights[0])}][{len(weights)}] = {{\")\n",
+ "    for row in weights.T:\n",
+ "        formatted_row = \", \".join(f\"{w:.7f}f\" for w in row)\n",
+ "        print(f\"{{ {formatted_row} }},\")\n",
+ "    print(\"};\")\n",
+ "    print()\n",
+ "\n",
+ "def print_model_weights_biases(model):\n",
+ "    \"\"\"Print the weights and biases of every nn.Linear module found in model.\"\"\"\n",
+ "    # Make sure the model is in evaluation mode\n",
+ "    model.eval()\n",
+ "\n",
+ "    # Iterate through all named modules in the model\n",
+ "    for name, module in model.named_modules():\n",
+ "        # Check if the module is a linear layer\n",
+ "        if isinstance(module, nn.Linear):\n",
+ "            # Get weights and biases\n",
+ "            weights = module.weight.data.cpu().numpy()\n",
+ "            if module.bias is None:\n",
+ "                # Layer was built with bias=False: emit an explicit zero bias\n",
+ "                # so the generated array still has one entry per output\n",
+ "                biases = [0.0] * weights.shape[0]\n",
+ "            else:\n",
+ "                biases = module.bias.data.cpu().numpy()\n",
+ "\n",
+ "            # Print formatted weights and biases\n",
+ "            print_formatted_weights_biases(weights, biases, name.replace('.', '_'))\n",
+ "\n",
+ "print_model_weights_biases(model)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Ensure input_features_tensor is moved to the appropriate device\n",
+ "input_features_tensor = input_features_tensor.to(device)\n",
+ "filtered_inputs = input_features_tensor[~nan_mask]\n",
+ "filtered_labels = labels_tensor[~nan_mask]\n",
+ "\n",
+ "# Make predictions\n",
+ "with torch.no_grad():\n",
+ "    model.eval()\n",
+ "    outputs = model(input_features_tensor)\n",
+ "    predictions = outputs.squeeze().cpu().numpy()\n",
+ "\n",
+ "full_tracks = (np.concatenate(branches['t4_pMatched']) > 0.95)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Eta bin 0.00-0.10: 343773 fakes, 783 true Displaced\n",
+ "Eta bin 0.10-0.20: 277198 fakes, 685 true Displaced\n",
+ "Eta bin 0.20-0.30: 236575 fakes, 679 true Displaced\n",
+ "Eta bin 0.30-0.40: 243786 fakes, 803 true Displaced\n",
+ "Eta bin 0.40-0.50: 236255 fakes, 682 true Displaced\n",
+ "Eta bin 0.50-0.60: 215018 fakes, 856 true Displaced\n",
+ "Eta bin 0.60-0.70: 157631 fakes, 989 true Displaced\n",
+ "Eta bin 0.70-0.80: 117039 fakes, 746 true Displaced\n",
+ "Eta bin 0.80-0.90: 103566 fakes, 893 true
Displaced\n", + "Eta bin 0.90-1.00: 63672 fakes, 710 true Displaced\n", + "Eta bin 1.00-1.10: 46689 fakes, 803 true Displaced\n", + "Eta bin 1.10-1.20: 54520 fakes, 765 true Displaced\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Eta bin 1.20-1.30: 69878 fakes, 988 true Displaced\n", + "Eta bin 1.30-1.40: 28034 fakes, 447 true Displaced\n", + "Eta bin 1.40-1.50: 10484 fakes, 458 true Displaced\n", + "Eta bin 1.50-1.60: 16192 fakes, 1550 true Displaced\n", + "Eta bin 1.60-1.70: 26335 fakes, 1499 true Displaced\n", + "Eta bin 1.70-1.80: 29781 fakes, 1367 true Displaced\n", + "Eta bin 1.80-1.90: 45547 fakes, 1894 true Displaced\n", + "Eta bin 1.90-2.00: 41836 fakes, 1748 true Displaced\n", + "Eta bin 2.00-2.10: 11233 fakes, 1706 true Displaced\n", + "Eta bin 2.10-2.20: 8634 fakes, 1151 true Displaced\n", + "Eta bin 2.20-2.30: 4555 fakes, 1104 true Displaced\n", + "Eta bin 2.30-2.40: 7739 fakes, 1940 true Displaced\n", + "Eta bin 2.40-2.50: 1872 fakes, 1428 true Displaced\n", + "
\n", + "\n", + "Displaced tracks, pt: 0.0 to 5.0 GeV\n", + "Number of true displaced tracks: 26674\n", + "Number of fake tracks in pt bin: 2397842\n", + "\n", + "65% Retention Cut Values: {0.6532, 0.6635, 0.7246, 0.7622, 0.6931, 0.7138, 0.7964, 0.7897, 0.8166, 0.8314, 0.7027, 0.6347, 0.6674, 0.7002, 0.7721, 0.8819, 0.8457, 0.8013, 0.7852, 0.7547, 0.7885, 0.8362, 0.8480, 0.6366, 0.3407} Mean: 0.7376000285148621\n", + "65% Cut Fake Rejections: {96.8, 97.5, 97.9, 98.2, 97.6, 97.3, 98.0, 97.2, 97.5, 97.2, 91.8, 89.7, 88.0, 86.0, 88.2, 84.2, 89.7, 89.2, 89.3, 89.4, 82.8, 82.6, 85.9, 81.8, 76.0} Mean: 90.8%\n", + "\n", + "70% Retention Cut Values: {0.5721, 0.5600, 0.6553, 0.7036, 0.6332, 0.6271, 0.7428, 0.7396, 0.7384, 0.7892, 0.6409, 0.5853, 0.6323, 0.6704, 0.7149, 0.8548, 0.8051, 0.7519, 0.7197, 0.6995, 0.7382, 0.7676, 0.8132, 0.4517, 0.2545} Mean: 0.6744999885559082\n", + "70% Cut Fake Rejections: {95.9, 96.6, 97.2, 97.6, 97.1, 96.4, 97.4, 96.5, 96.3, 96.4, 89.7, 87.9, 86.7, 84.5, 85.9, 81.9, 87.5, 87.2, 86.8, 87.3, 80.1, 77.6, 83.7, 74.8, 66.8} Mean: 88.6%\n", + "\n", + "75% Retention Cut Values: {0.4679, 0.4943, 0.5824, 0.6214, 0.5755, 0.5700, 0.6640, 0.6667, 0.6209, 0.7128, 0.5568, 0.5166, 0.5729, 0.6201, 0.6468, 0.8087, 0.7613, 0.6858, 0.6370, 0.6208, 0.6688, 0.6535, 0.7243, 0.2648, 0.1871} Mean: 0.5960000157356262\n", + "75% Cut Fake Rejections: {94.7, 95.9, 96.5, 96.8, 96.5, 95.7, 96.6, 95.6, 94.5, 95.0, 86.6, 84.9, 84.3, 81.9, 83.2, 78.0, 85.2, 84.6, 83.5, 84.4, 77.0, 71.8, 78.3, 60.6, 55.9} Mean: 85.5%\n", + "\n", + "80% Retention Cut Values: {0.3643, 0.3997, 0.4992, 0.5466, 0.4813, 0.5027, 0.5943, 0.5423, 0.5045, 0.6272, 0.5120, 0.4585, 0.5139, 0.5219, 0.5931, 0.7501, 0.7040, 0.5831, 0.5597, 0.4881, 0.6129, 0.5526, 0.3974, 0.1495, 0.1389} Mean: 0.5038999915122986\n", + "80% Cut Fake Rejections: {93.0, 94.6, 95.7, 96.1, 95.4, 94.8, 95.7, 93.7, 92.3, 92.9, 84.8, 82.1, 81.7, 77.0, 81.1, 73.8, 82.3, 80.9, 80.6, 79.4, 74.6, 67.6, 65.3, 42.6, 46.4} Mean: 81.8%\n", 
+ "\n", + "85% Retention Cut Values: {0.2693, 0.2885, 0.3381, 0.3925, 0.3886, 0.3998, 0.5003, 0.4532, 0.3624, 0.5571, 0.4461, 0.3688, 0.4487, 0.4183, 0.4207, 0.6097, 0.5510, 0.4526, 0.4826, 0.3136, 0.4704, 0.4141, 0.2126, 0.1203, 0.1067} Mean: 0.391400009393692\n", + "85% Cut Fake Rejections: {90.9, 92.7, 93.6, 94.2, 94.3, 93.2, 94.4, 92.1, 88.4, 91.0, 81.7, 77.3, 78.6, 71.6, 73.2, 66.4, 75.2, 75.8, 77.5, 71.1, 68.6, 60.6, 48.5, 35.7, 35.8} Mean: 76.9%\n", + "\n", + "90% Retention Cut Values: {0.1981, 0.2071, 0.1651, 0.2351, 0.2320, 0.2366, 0.3681, 0.3035, 0.2574, 0.4382, 0.3126, 0.3057, 0.3526, 0.2610, 0.3053, 0.4718, 0.4004, 0.3037, 0.3010, 0.2001, 0.2483, 0.2288, 0.0990, 0.0992, 0.0847} Mean: 0.26460000872612\n", + "90% Cut Fake Rejections: {88.3, 90.4, 89.0, 90.9, 91.0, 89.4, 92.0, 88.2, 83.8, 87.0, 73.9, 73.1, 73.8, 61.3, 66.5, 59.0, 67.7, 67.9, 68.6, 63.1, 55.6, 47.6, 29.6, 30.0, 30.0} Mean: 70.3%\n", + "\n", + "95% Retention Cut Values: {0.0852, 0.0952, 0.0733, 0.1015, 0.1282, 0.1008, 0.2063, 0.1310, 0.1406, 0.2143, 0.1995, 0.1805, 0.2321, 0.1726, 0.1861, 0.3016, 0.2233, 0.1839, 0.1699, 0.1414, 0.1340, 0.1005, 0.0799, 0.0775, 0.0640} Mean: 0.14890000224113464\n", + "95% Cut Fake Rejections: {78.3, 83.4, 80.8, 83.3, 86.0, 80.4, 86.7, 77.7, 74.8, 75.3, 64.9, 61.7, 65.6, 53.5, 56.4, 48.3, 55.7, 58.6, 58.5, 56.8, 42.5, 29.4, 24.1, 24.7, 22.8} Mean: 61.2%\n", + "
\n", + "\n", + "Fake tracks, pt: 0.0 to 5.0 GeV\n", + "Number of true fake tracks: 26674\n", + "Number of fake tracks in pt bin: 2397842\n", + "\n", + "65% Retention Cut Values: {0.1999, 0.1969, 0.1692, 0.1742, 0.1824, 0.1582, 0.1384, 0.1035, 0.0906, 0.1073, 0.1916, 0.2504, 0.1690, 0.2609, 0.0723, 0.0417, 0.0597, 0.0442, 0.0552, 0.0297, 0.0270, 0.0183, 0.0123, 0.0100, 0.0057} Mean: 0.11069999635219574\n", + "65% Cut Fake Rejections: {97.2, 97.8, 98.1, 97.8, 97.5, 97.5, 97.5, 97.8, 97.8, 97.2, 89.9, 87.9, 90.0, 85.7, 88.9, 88.6, 90.1, 91.9, 91.3, 94.6, 87.0, 88.4, 89.3, 85.0, 78.9} Mean: 92.1%\n", + "\n", + "70% Retention Cut Values: {0.2629, 0.2361, 0.2003, 0.2037, 0.2165, 0.1897, 0.1541, 0.1188, 0.1068, 0.1368, 0.2128, 0.2902, 0.1947, 0.3002, 0.0869, 0.0497, 0.0713, 0.0502, 0.0650, 0.0370, 0.0340, 0.0211, 0.0137, 0.0117, 0.0063} Mean: 0.13079999387264252\n", + "70% Cut Fake Rejections: {96.1, 97.2, 97.6, 97.3, 96.8, 96.8, 97.1, 97.3, 97.3, 96.2, 88.6, 85.4, 88.5, 83.8, 86.9, 86.4, 88.6, 90.9, 90.2, 93.4, 84.2, 87.0, 87.6, 82.8, 77.6} Mean: 90.9%\n", + "\n", + "75% Retention Cut Values: {0.3379, 0.3084, 0.2357, 0.2501, 0.2501, 0.2346, 0.1864, 0.1535, 0.1249, 0.1868, 0.2382, 0.3318, 0.2401, 0.3343, 0.1031, 0.0625, 0.0818, 0.0615, 0.0831, 0.0467, 0.0445, 0.0243, 0.0153, 0.0140, 0.0074} Mean: 0.1582999974489212\n", + "75% Cut Fake Rejections: {94.6, 96.0, 97.0, 96.4, 96.2, 95.9, 96.2, 96.3, 96.7, 94.8, 87.0, 83.0, 85.9, 81.9, 85.0, 83.9, 87.3, 89.4, 88.3, 92.0, 80.8, 85.2, 86.6, 80.2, 74.5} Mean: 89.2%\n", + "\n", + "80% Retention Cut Values: {0.4175, 0.4224, 0.2946, 0.3265, 0.3264, 0.2734, 0.2478, 0.1879, 0.1520, 0.2460, 0.2781, 0.3844, 0.2920, 0.3993, 0.1187, 0.0765, 0.0982, 0.0805, 0.1061, 0.0644, 0.0557, 0.0288, 0.0187, 0.0168, 0.0097} Mean: 0.19689999520778656\n", + "80% Cut Fake Rejections: {92.9, 94.0, 95.9, 95.0, 94.6, 95.0, 94.6, 95.3, 95.8, 92.6, 84.4, 80.0, 83.1, 78.4, 83.6, 81.5, 85.6, 86.9, 86.2, 89.7, 78.0, 83.2, 83.7, 78.1, 70.2} Mean: 87.1%\n", + "\n", 
+ "85% Retention Cut Values: {0.5190, 0.5398, 0.3760, 0.4063, 0.3922, 0.3490, 0.3289, 0.2385, 0.2072, 0.3168, 0.3340, 0.4398, 0.3609, 0.5060, 0.1707, 0.0933, 0.1248, 0.1158, 0.1441, 0.0827, 0.0738, 0.0402, 0.0314, 0.0208, 0.0131} Mean: 0.24899999797344208\n", + "85% Cut Fake Rejections: {90.3, 91.5, 94.4, 93.4, 93.1, 93.2, 92.3, 93.7, 94.0, 89.9, 81.0, 76.5, 79.4, 72.9, 79.3, 78.8, 83.2, 83.6, 83.1, 87.7, 74.8, 78.8, 75.1, 74.8, 65.5} Mean: 84.0%\n", + "\n", + "90% Retention Cut Values: {0.6180, 0.6521, 0.6022, 0.5880, 0.5662, 0.4464, 0.4244, 0.3351, 0.3012, 0.4045, 0.4663, 0.5329, 0.4525, 0.6504, 0.2746, 0.1339, 0.1790, 0.1685, 0.2066, 0.1302, 0.0960, 0.0623, 0.0498, 0.0278, 0.0167} Mean: 0.3353999853134155\n", + "90% Cut Fake Rejections: {87.2, 88.4, 89.5, 89.1, 88.4, 90.7, 89.5, 90.6, 90.6, 86.3, 73.2, 70.5, 74.3, 64.5, 72.9, 73.7, 79.0, 79.6, 79.4, 83.4, 71.5, 73.1, 67.9, 70.0, 61.3} Mean: 79.4%\n", + "\n", + "95% Retention Cut Values: {0.7176, 0.7772, 0.8458, 0.8151, 0.7574, 0.7156, 0.6743, 0.4925, 0.4833, 0.5268, 0.6188, 0.6430, 0.5647, 0.8129, 0.4544, 0.2189, 0.2672, 0.3052, 0.3373, 0.2165, 0.1635, 0.0924, 0.0915, 0.0545, 0.0241} Mean: 0.4668000042438507\n", + "95% Cut Fake Rejections: {82.9, 83.2, 78.2, 79.1, 80.4, 81.1, 79.5, 84.4, 82.9, 80.6, 63.3, 62.8, 68.0, 52.1, 63.4, 65.2, 73.2, 71.3, 72.2, 77.2, 63.5, 66.3, 56.5, 59.0, 56.9} Mean: 71.3%\n", + "Eta bin 0.00-0.10: 58324 fakes, 32 true Displaced\n", + "Eta bin 0.10-0.20: 42710 fakes, 8 true Displaced\n", + "Eta bin 0.20-0.30: 37318 fakes, 13 true Displaced\n", + "Eta bin 0.30-0.40: 35874 fakes, 13 true Displaced\n", + "Eta bin 0.40-0.50: 35293 fakes, 18 true Displaced\n", + "Eta bin 0.50-0.60: 14925 fakes, 7 true Displaced\n", + "Eta bin 0.60-0.70: 13246 fakes, 10 true Displaced\n", + "Eta bin 0.70-0.80: 8969 fakes, 7 true Displaced\n", + "Eta bin 0.80-0.90: 7843 fakes, 9 true Displaced\n", + "Eta bin 0.90-1.00: 4815 fakes, 2 true Displaced\n", + "Eta bin 1.00-1.10: 4865 fakes, 10 true Displaced\n", + 
"Eta bin 1.10-1.20: 6657 fakes, 6 true Displaced\n", + "Eta bin 1.20-1.30: 9686 fakes, 24 true Displaced\n", + "Eta bin 1.30-1.40: 2782 fakes, 10 true Displaced\n", + "Eta bin 1.40-1.50: 1961 fakes, 11 true Displaced\n", + "Eta bin 1.50-1.60: 1527 fakes, 29 true Displaced\n", + "Eta bin 1.60-1.70: 2750 fakes, 63 true Displaced\n", + "Eta bin 1.70-1.80: 3531 fakes, 37 true Displaced\n", + "Eta bin 1.80-1.90: 4829 fakes, 50 true Displaced\n", + "Eta bin 1.90-2.00: 4246 fakes, 16 true Displaced\n", + "Eta bin 2.00-2.10: 1789 fakes, 16 true Displaced\n", + "Eta bin 2.10-2.20: 1601 fakes, 10 true Displaced\n", + "Eta bin 2.20-2.30: 785 fakes, 3 true Displaced\n", + "Eta bin 2.30-2.40: 1073 fakes, 16 true Displaced\n", + "Eta bin 2.40-2.50: 242 fakes, 15 true Displaced\n", + "
\n", + "\n", + "Displaced tracks, pt: 5.0 to inf GeV\n", + "Number of true displaced tracks: 435\n", + "Number of fake tracks in pt bin: 307641\n", + "\n", + "65% Retention Cut Values: {0.2265, 0.0395, 0.4312, 0.4840, 0.4537, 0.9577, 0.9294, 0.9280, 0.4712, 0.4557, 0.4738, 0.0794, 0.1110, 0.5503, 0.1385, 0.5199, 0.4993, 0.6410, 0.4735, 0.3342, 0.3112, 0.8436, 0.7717, 0.4069, 0.4031} Mean: 0.477400004863739\n", + "65% Cut Fake Rejections: {96.8, 86.8, 98.8, 98.8, 98.5, 100.0, 99.8, 99.9, 98.2, 97.7, 97.8, 72.3, 75.0, 93.5, 68.4, 86.5, 88.1, 95.1, 87.3, 83.7, 80.4, 99.4, 97.5, 81.3, 85.1} Mean: 90.7%\n", + "\n", + "70% Retention Cut Values: {0.1905, 0.0395, 0.4040, 0.3383, 0.3143, 0.9561, 0.9278, 0.9043, 0.4546, 0.4192, 0.4578, 0.0747, 0.0802, 0.5133, 0.1262, 0.5143, 0.4778, 0.6301, 0.3928, 0.3137, 0.3038, 0.8183, 0.7583, 0.3755, 0.3992} Mean: 0.4474000036716461\n", + "70% Cut Fake Rejections: {96.2, 86.8, 98.7, 98.5, 97.9, 100.0, 99.8, 99.9, 98.1, 97.4, 97.6, 70.9, 69.9, 93.2, 67.0, 86.2, 87.2, 94.4, 84.8, 82.8, 79.8, 98.6, 97.5, 78.9, 85.1} Mean: 89.9%\n", + "\n", + "75% Retention Cut Values: {0.1580, 0.0394, 0.3631, 0.1789, 0.2648, 0.9539, 0.9259, 0.8688, 0.4380, 0.3828, 0.4338, 0.0699, 0.0682, 0.4578, 0.1201, 0.3865, 0.4560, 0.5532, 0.3760, 0.2905, 0.2956, 0.8145, 0.7449, 0.3459, 0.3417} Mean: 0.413100004196167\n", + "75% Cut Fake Rejections: {95.6, 86.8, 98.5, 97.2, 97.6, 100.0, 99.8, 99.8, 97.8, 97.1, 97.2, 69.6, 67.8, 90.5, 66.0, 72.4, 86.5, 92.9, 84.2, 81.5, 79.2, 98.4, 97.2, 77.5, 79.8} Mean: 88.4%\n", + "\n", + "80% Retention Cut Values: {0.0391, 0.0394, 0.3378, 0.1402, 0.1203, 0.9516, 0.9220, 0.8333, 0.4039, 0.3463, 0.4204, 0.0652, 0.0429, 0.3625, 0.1140, 0.3114, 0.3984, 0.4187, 0.3275, 0.2243, 0.2907, 0.7854, 0.7315, 0.3334, 0.2778} Mean: 0.3695000112056732\n", + "80% Cut Fake Rejections: {84.2, 86.8, 98.4, 96.4, 95.2, 100.0, 99.8, 99.7, 97.5, 96.4, 97.1, 68.5, 60.6, 88.0, 65.2, 67.6, 83.7, 82.0, 82.1, 77.3, 78.3, 98.1, 96.9, 77.4, 72.7} Mean: 86.0%\n", + 
"\n", + "85% Retention Cut Values: {0.0246, 0.0394, 0.2889, 0.0983, 0.0281, 0.9173, 0.9160, 0.7343, 0.3697, 0.3098, 0.4204, 0.0617, 0.0345, 0.2174, 0.1020, 0.2642, 0.3695, 0.3122, 0.2431, 0.1738, 0.2899, 0.7246, 0.7180, 0.1358, 0.2513} Mean: 0.32179999351501465\n", + "85% Cut Fake Rejections: {78.3, 86.8, 98.1, 94.8, 82.5, 99.9, 99.8, 99.3, 97.0, 96.2, 97.1, 67.3, 57.5, 79.4, 63.3, 64.5, 82.5, 75.2, 78.4, 73.7, 78.3, 97.3, 96.8, 57.7, 69.8} Mean: 82.9%\n", + "\n", + "90% Retention Cut Values: {0.0245, 0.0330, 0.1931, 0.0502, 0.0179, 0.8189, 0.8216, 0.5082, 0.3526, 0.2734, 0.4204, 0.0582, 0.0184, 0.1018, 0.0899, 0.2338, 0.2594, 0.2093, 0.1854, 0.1399, 0.2743, 0.6624, 0.7046, 0.0640, 0.2394} Mean: 0.2702000141143799\n", + "90% Cut Fake Rejections: {78.3, 84.8, 97.2, 89.5, 75.9, 99.8, 99.6, 97.9, 96.7, 95.6, 97.1, 66.2, 48.2, 66.8, 61.2, 61.7, 75.9, 66.4, 74.9, 69.5, 76.0, 95.7, 96.7, 43.2, 68.6} Mean: 79.3%\n", + "\n", + "95% Retention Cut Values: {0.0129, 0.0255, 0.1427, 0.0241, 0.0030, 0.7204, 0.4184, 0.2822, 0.3526, 0.2369, 0.4204, 0.0546, 0.0125, 0.0888, 0.0824, 0.1519, 0.2437, 0.0519, 0.0186, 0.1089, 0.2072, 0.5954, 0.6912, 0.0577, 0.2156} Mean: 0.20880000293254852\n", + "95% Cut Fake Rejections: {67.7, 81.4, 96.2, 81.4, 43.9, 99.6, 97.3, 95.9, 96.7, 94.7, 97.1, 65.3, 42.5, 64.8, 60.4, 52.7, 74.9, 47.2, 43.3, 66.3, 71.0, 94.2, 96.6, 40.9, 65.3} Mean: 73.5%\n", + "
\n", + "\n", + "Fake tracks, pt: 5.0 to inf GeV\n", + "Number of true fake tracks: 435\n", + "Number of fake tracks in pt bin: 307641\n", + "\n", + "65% Retention Cut Values: {0.6532, 0.9603, 0.0759, 0.3199, 0.4190, 0.0409, 0.0602, 0.0498, 0.1935, 0.4991, 0.0806, 0.4732, 0.2191, 0.4040, 0.5088, 0.0703, 0.0402, 0.0639, 0.1041, 0.0875, 0.0805, 0.0054, 0.0079, 0.0199, 0.0103} Mean: 0.21789999306201935\n", + "65% Cut Fake Rejections: {93.7, 65.2, 99.9, 98.8, 97.7, 100.0, 99.8, 99.9, 99.2, 95.4, 99.5, 91.5, 95.0, 84.9, 72.8, 91.2, 96.3, 92.9, 94.8, 92.7, 94.2, 100.0, 99.6, 97.6, 88.4} Mean: 93.6%\n", + "\n", + "70% Retention Cut Values: {0.7254, 0.9603, 0.2475, 0.4581, 0.6331, 0.0423, 0.0618, 0.0564, 0.2225, 0.5325, 0.0835, 0.4905, 0.2978, 0.4040, 0.5088, 0.0727, 0.0461, 0.0679, 0.1579, 0.0921, 0.0845, 0.0064, 0.0085, 0.0225, 0.0123} Mean: 0.251800000667572\n", + "70% Cut Fake Rejections: {89.8, 65.2, 99.5, 97.6, 93.5, 99.9, 99.8, 99.9, 99.0, 95.0, 99.5, 90.9, 92.8, 84.9, 72.8, 91.2, 95.6, 92.6, 92.7, 92.5, 94.0, 100.0, 99.6, 97.2, 88.4} Mean: 93.0%\n", + "\n", + "75% Retention Cut Values: {0.7619, 0.9603, 0.5050, 0.6356, 0.6610, 0.0442, 0.0629, 0.0665, 0.2516, 0.5659, 0.0837, 0.5078, 0.3080, 0.4040, 0.5780, 0.0727, 0.0750, 0.0730, 0.3107, 0.1022, 0.0906, 0.0072, 0.0092, 0.0348, 0.0130} Mean: 0.2874000072479248\n", + "75% Cut Fake Rejections: {88.0, 65.2, 97.5, 94.7, 91.9, 99.9, 99.8, 99.9, 98.7, 94.3, 99.5, 90.5, 92.2, 84.9, 67.9, 91.2, 94.4, 92.1, 83.5, 92.0, 93.7, 100.0, 99.5, 95.7, 88.4} Mean: 91.8%\n", + "\n", + "80% Retention Cut Values: {0.7911, 0.9604, 0.5849, 0.6872, 0.7378, 0.0462, 0.0638, 0.0765, 0.3630, 0.5993, 0.0892, 0.5251, 0.3419, 0.4363, 0.6472, 0.0885, 0.0940, 0.0798, 0.4018, 0.1204, 0.1090, 0.0087, 0.0098, 0.0659, 0.0147} Mean: 0.31769999861717224\n", + "80% Cut Fake Rejections: {86.5, 65.2, 96.4, 91.6, 88.0, 99.9, 99.8, 99.8, 97.2, 93.2, 99.5, 89.9, 91.0, 83.9, 58.3, 89.6, 93.2, 91.7, 80.4, 91.0, 92.8, 99.8, 99.4, 90.2, 88.4} Mean: 90.3%\n", + "\n", 
+ "85% Retention Cut Values: {0.7957, 0.9605, 0.6452, 0.7256, 0.8383, 0.0780, 0.0646, 0.1521, 0.4745, 0.6328, 0.1012, 0.5348, 0.4404, 0.5091, 0.7080, 0.1168, 0.1020, 0.0817, 0.4583, 0.1210, 0.1090, 0.0113, 0.0105, 0.0698, 0.0197} Mean: 0.35040000081062317\n",
+ "85% Cut Fake Rejections: {86.2, 65.2, 94.6, 90.2, 81.1, 99.8, 99.8, 99.6, 94.6, 91.8, 99.4, 89.2, 87.3, 81.9, 52.5, 87.2, 92.9, 91.5, 77.5, 91.0, 92.8, 99.7, 99.4, 89.5, 87.6} Mean: 88.9%\n",
+ "\n",
+ "90% Retention Cut Values: {0.9115, 0.9605, 0.6660, 0.7374, 0.9263, 0.1698, 0.1485, 0.3590, 0.5302, 0.6662, 0.1273, 0.5445, 0.5916, 0.5985, 0.7687, 0.1317, 0.2187, 0.1160, 0.4810, 0.1532, 0.3180, 0.0155, 0.0111, 0.1336, 0.1455} Mean: 0.4171999990940094\n",
+ "90% Cut Fake Rejections: {74.6, 65.2, 93.5, 89.7, 70.8, 99.6, 99.5, 96.9, 93.3, 90.4, 99.0, 88.7, 79.9, 76.9, 48.0, 85.7, 88.0, 89.4, 76.5, 89.6, 73.1, 98.9, 99.4, 77.6, 57.4} Mean: 84.1%\n",
+ "\n",
+ "95% Retention Cut Values: {0.9423, 0.9605, 0.7309, 0.8130, 0.9891, 0.2616, 0.5234, 0.5660, 0.5302, 0.6996, 0.2025, 0.5542, 0.6456, 0.7462, 0.7962, 0.1357, 0.2877, 0.1279, 0.5260, 0.3493, 0.6150, 0.0258, 0.0118, 0.1965, 0.2335} Mean: 0.49880000948905945\n",
+ "95% Cut Fake Rejections: {67.2, 65.2, 90.4, 86.0, 38.0, 99.0, 93.9, 93.5, 93.3, 88.6, 97.6, 88.2, 74.3, 64.7, 46.6, 85.1, 84.4, 89.0, 74.3, 77.8, 47.1, 97.7, 99.4, 68.2, 41.7} Mean: 78.0%\n"
+ ]
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "from matplotlib import pyplot as plt\n",
+ "from matplotlib.colors import LogNorm\n",
+ "import torch\n",
+ "\n",
+ "# Ensure input_features_tensor is on the right device\n",
+ "input_features_tensor = input_features_tensor.to(device)\n",
+ "\n",
+ "t4_pt = np.concatenate(branches['t4_pt'])\n",
+ "\n",
+ "# Get model predictions\n",
+ "with torch.no_grad():\n",
+ "    model.eval()\n",
+ "    outputs = model(input_features_tensor)\n",
+ "    predictions = outputs.cpu().numpy()  # Shape will be [n_samples, 3]\n",
+ "\n",
+ "\n",
+ "def plot_for_pt_bin(pt_min, pt_max, percentiles, eta_bin_edges, t4_pt, predictions, t4_sim_vxy, eta_list):\n",
+ "    \"\"\"\n",
+ "    Calculate and plot, per |eta| bin, the DNN score cuts that retain the requested\n",
+ "    percentages of true displaced tracks in a given pt bin.\n",
+ "\n",
+ "    Two passes run over the same true displaced tracks: \"Displaced\" cuts on\n",
+ "    predictions[:, 2] (scores below the cut count as rejected) and \"Fake\" cuts on\n",
+ "    predictions[:, 0] (scores above the cut count as rejected). For every cut the\n",
+ "    percentage of fake tracks rejected is reported.\n",
+ "\n",
+ "    NOTE: pMatched is read from the notebook-level `branches`, not from an argument.\n",
+ "    \"\"\"\n",
+ "    # Filter data based on pt bin\n",
+ "    pt_mask = (t4_pt > pt_min) & (t4_pt <= pt_max)\n",
+ "\n",
+ "    # Get absolute eta values for all tracks in pt bin\n",
+ "    abs_eta = np.abs(eta_list[0][pt_mask])\n",
+ "\n",
+ "    # Get predictions for all tracks in pt bin\n",
+ "    pred_filtered = predictions[pt_mask]\n",
+ "\n",
+ "    # Get track types using pMatched and t4_sim_vxy (concatenate only once)\n",
+ "    p_matched = np.concatenate(branches['t4_pMatched'])[pt_mask]\n",
+ "    matched = p_matched > 0.95\n",
+ "    fake_tracks = p_matched <= 0.75\n",
+ "    true_displaced = (t4_sim_vxy[pt_mask] > 0.1) & matched\n",
+ "\n",
+ "    # Both passes use the true displaced tracks as the retention reference;\n",
+ "    # they differ only in which DNN output column is cut on\n",
+ "    for track_type, true_mask, pred_idx, title_suffix in [\n",
+ "        (\"Displaced\", true_displaced, 2, \"Displaced Real Tracks\"),\n",
+ "        (\"Fake\", true_displaced, 0, \"Displaced Real Tracks\")\n",
+ "    ]:\n",
+ "        # Dictionaries to store values\n",
+ "        cut_values = {p: [] for p in percentiles}\n",
+ "        fake_rejections = {p: [] for p in percentiles}\n",
+ "\n",
+ "        # Get probabilities for this class\n",
+ "        probs = pred_filtered[:, pred_idx]\n",
+ "\n",
+ "        # Create two side-by-side plots\n",
+ "        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))\n",
+ "\n",
+ "        # Plot probability distribution (only for the reference true tracks)\n",
+ "        h = ax1.hist2d(abs_eta[true_mask],\n",
+ "                       probs[true_mask],\n",
+ "                       bins=[eta_bin_edges, 50],\n",
+ "                       norm=LogNorm())\n",
+ "        plt.colorbar(h[3], ax=ax1, label='Counts')\n",
+ "\n",
+ "        # For each eta bin\n",
+ "        bin_centers = []\n",
+ "        for i in range(len(eta_bin_edges) - 1):\n",
+ "            eta_min, eta_max = eta_bin_edges[i], eta_bin_edges[i+1]\n",
+ "            bin_center = (eta_min + eta_max) / 2\n",
+ "            bin_centers.append(bin_center)\n",
+ "\n",
+ "            # Get tracks in this eta bin\n",
+ "            eta_mask = (abs_eta >= eta_min) & (abs_eta < eta_max)\n",
+ "\n",
+ "            # Reference true tracks in this bin\n",
+ "            true_type_mask = eta_mask & true_mask\n",
+ "            # Fake tracks in this bin\n",
+ "            fake_mask = eta_mask & fake_tracks\n",
+ "\n",
+ "            if track_type == \"Displaced\":\n",
+ "                print(f\"Eta bin {eta_min:.2f}-{eta_max:.2f}: {np.sum(fake_mask)} fakes, {np.sum(true_type_mask)} true {track_type}\")\n",
+ "\n",
+ "            if np.sum(true_type_mask) > 0:  # If we have true tracks in this bin\n",
+ "                for percentile in percentiles:\n",
+ "                    # Calculate cut value to keep desired percentage of true tracks\n",
+ "                    if track_type == \"Fake\":\n",
+ "                        cut_value = np.percentile(probs[true_type_mask], percentile)\n",
+ "                    else:\n",
+ "                        cut_value = np.percentile(probs[true_type_mask], 100 - percentile)\n",
+ "                    cut_values[percentile].append(cut_value)\n",
+ "\n",
+ "                    # Calculate fake rejection for this cut\n",
+ "                    if np.sum(fake_mask) > 0:\n",
+ "                        if track_type == \"Fake\":\n",
+ "                            fake_rej = 100 * np.mean(probs[fake_mask] > cut_value)\n",
+ "                        else:\n",
+ "                            fake_rej = 100 * np.mean(probs[fake_mask] < cut_value)\n",
+ "                        fake_rejections[percentile].append(fake_rej)\n",
+ "                    else:\n",
+ "                        fake_rejections[percentile].append(np.nan)\n",
+ "            else:\n",
+ "                for percentile in percentiles:\n",
+ "                    cut_values[percentile].append(np.nan)\n",
+ "                    fake_rejections[percentile].append(np.nan)\n",
+ "\n",
+ "        # Plot cut values and fake rejections\n",
+ "        colors = plt.cm.rainbow(np.linspace(0, 1, len(percentiles)))\n",
+ "        bin_centers = np.array(bin_centers)\n",
+ "\n",
+ "        for (percentile, color) in zip(percentiles, colors):\n",
+ "            values = np.array(cut_values[percentile])\n",
+ "            mask = ~np.isnan(values)\n",
+ "            if np.any(mask):\n",
+ "                # Plot cut values\n",
+ "                ax1.plot(bin_centers[mask], values[mask], '-', color=color, marker='o',\n",
+ "                         label=f'{percentile}% Retention Cut')\n",
+ "                # Plot fake rejections\n",
+ "                rej_values = np.array(fake_rejections[percentile])\n",
+ "                ax2.plot(bin_centers[mask], rej_values[mask], '-', color=color, marker='o',\n",
+ "                         label=f'{percentile}% Cut')\n",
+ "\n",
+ "        # Set plot labels and titles\n",
+ "        ax1.set_xlabel(\"Absolute Eta\")\n",
+ "        ax1.set_ylabel(f\"DNN {track_type} Probability\")\n",
+ "        ax1.set_title(f\"DNN Score vs Eta ({title_suffix})\\npt: {pt_min:.1f} to {pt_max:.1f} GeV\")\n",
+ "        ax1.legend()\n",
+ "        ax1.grid(True, alpha=0.3)\n",
+ "\n",
+ "        ax2.set_xlabel(\"Absolute Eta\")\n",
+ "        ax2.set_ylabel(\"Fake Rejection (%)\")\n",
+ "        ax2.set_title(f\"Fake Rejection vs Eta\\npt: {pt_min:.1f} to {pt_max:.1f} GeV\")\n",
+ "        ax2.legend()\n",
+ "        ax2.grid(True, alpha=0.3)\n",
+ "        ax2.set_ylim(0, 100)\n",
+ "\n",
+ "        plt.tight_layout()\n",
+ "        plt.show()\n",
+ "\n",
+ "        # Print statistics\n",
+ "        print(f\"\\n{track_type} tracks, pt: {pt_min:.1f} to {pt_max:.1f} GeV\")\n",
+ "        # true_mask is the displaced-track reference in both passes, so label it\n",
+ "        # as such instead of deriving the label from track_type\n",
+ "        print(f\"Number of true displaced tracks: {np.sum(true_mask)}\")\n",
+ "        print(f\"Number of fake tracks in pt bin: {np.sum(fake_tracks)}\")\n",
+ "\n",
+ "        for percentile in percentiles:\n",
+ "            # Format the means with explicit precision: np.round on a float32\n",
+ "            # mean prints with spurious trailing digits\n",
+ "            print(f\"\\n{percentile}% Retention Cut Values:\",\n",
+ "                  '{' + ', '.join(f\"{x:.4f}\" if not np.isnan(x) else 'nan' for x in cut_values[percentile]) + '}',\n",
+ "                  f\"Mean: {np.nanmean(cut_values[percentile]):.4f}\")\n",
+ "            print(f\"{percentile}% Cut Fake Rejections:\",\n",
+ "                  '{' + ', '.join(f\"{x:.1f}\" if not np.isnan(x) else 'nan' for x in fake_rejections[percentile]) + '}',\n",
+ "                  f\"Mean: {np.nanmean(fake_rejections[percentile]):.1f}%\")\n",
+ "\n",
+ "def analyze_pt_bins(pt_bins, percentiles, eta_bin_edges, t4_pt, predictions, t4_sim_vxy, eta_list):\n",
+ "    \"\"\"\n",
+ "    Analyze and plot for multiple pt bins and percentiles\n",
+ "\n",
+ "    Consecutive entries of pt_bins are used as the (pt_min, pt_max) edges of each bin.\n",
+ "    \"\"\"\n",
+ "    for pt_low, pt_high in zip(pt_bins[:-1], pt_bins[1:]):\n",
+ "        plot_for_pt_bin(pt_low, pt_high, percentiles, eta_bin_edges,\n",
+ "                        t4_pt, predictions, t4_sim_vxy, eta_list)\n",
+ "\n",
+ "# Run the analysis
with same parameters as before\n", + "percentiles = [65, 70, 75, 80, 85, 90, 95]\n", + "pt_bins = [0, 5, np.inf]\n", + "eta_bin_edges = np.arange(0, 2.6, 0.1)\n", + "\n", + "analyze_pt_bins(\n", + " pt_bins=pt_bins,\n", + " percentiles=percentiles,\n", + " eta_bin_edges=eta_bin_edges,\n", + " t4_pt=t4_pt,\n", + " predictions=predictions,\n", + " t4_sim_vxy=np.concatenate(branches['t4_sim_vxy']),\n", + " eta_list=eta_list\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pytorch_env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/RecoTracker/LSTCore/standalone/bin/lst.cc b/RecoTracker/LSTCore/standalone/bin/lst.cc index 205690d918856..66cc11552af6e 100644 --- a/RecoTracker/LSTCore/standalone/bin/lst.cc +++ b/RecoTracker/LSTCore/standalone/bin/lst.cc @@ -77,6 +77,7 @@ int main(int argc, char **argv) { "pls", "Write pLS branches in output ntuple.")("pt3", "Write pT3 branches in output ntuple.")( "pt5", "Write pT5 branches in output ntuple.")("occ", "Write occupancy branches in output ntuple.")( "t5dnn", "Write T5 DNN branches in output ntuple.")("t3dnn", "Write T3 DNN branches in output ntuple.")( + "t4", "Write T4 branches in output ntuple.")("t4dnn", "Write T4 DNN branches in output ntuple.")( "allobj", "Write all object branches in output ntuple.")( "J,jet", "Accounts for specific jet branches in input root file for testing")( "sim", "Write extra sim branches in output ntuple"); @@ -289,6 +290,10 @@ int main(int argc, char **argv) { // --pt5 ana.pt5_branches = result["pt5"].as() || result["allobj"].as(); + //_______________________________________________________________________________ + // --t4 + ana.t4_branches = 
result["t4"].as() || result["allobj"].as(); + //_______________________________________________________________________________ // --occ ana.occ_branches = result["occ"].as() || result["allobj"].as(); @@ -301,6 +306,10 @@ int main(int argc, char **argv) { // --t3dnn ana.t3dnn_branches = result["t3dnn"].as() || result["allobj"].as(); + //_______________________________________________________________________________ + // --t4dnn + ana.t4dnn_branches = result["t4dnn"].as() || result["allobj"].as(); + //_______________________________________________________________________________ // --jet ana.jet_branches = result["jet"].as() || result["allobj"].as(); @@ -455,6 +464,7 @@ void run_lst() { float timing_T3; float timing_T5; float timing_pLS; + float timing_T4; float timing_pT5; float timing_pT3; float timing_TC; @@ -479,7 +489,7 @@ void run_lst() { timing_T5 = runQuintuplet(events.at(omp_get_thread_num())); timing_pLS = runPixelLineSegment(events.at(omp_get_thread_num()), ana.no_pls_dupclean); - + timing_T4 = runQuadruplet(events.at(omp_get_thread_num())); timing_pT5 = runPixelQuintuplet(events.at(omp_get_thread_num())); timing_pT3 = runpT3(events.at(omp_get_thread_num())); timing_TC = runTrackCandidate(events.at(omp_get_thread_num()), ana.no_pls_dupclean, ana.tc_pls_triplets); @@ -525,6 +535,7 @@ void run_lst() { timing_T3, timing_T5, timing_pLS, + timing_T4, timing_pT5, timing_pT3, timing_TC, diff --git a/RecoTracker/LSTCore/standalone/code/core/AccessHelper.cc b/RecoTracker/LSTCore/standalone/code/core/AccessHelper.cc index 78a1d6673319e..aac6ef9fc6b55 100644 --- a/RecoTracker/LSTCore/standalone/code/core/AccessHelper.cc +++ b/RecoTracker/LSTCore/standalone/code/core/AccessHelper.cc @@ -147,6 +147,76 @@ std::tuple, std::vector> getHitIdxsAndHi return convertHitsToHitIdxsAndHitTypes(event, getHitsFromT3(event, T3)); } +// ============== +// ----* T4 *---- +// ============== + 
+//____________________________________________________________________________________________ +std::vector getT3sFromT4(LSTEvent* event, unsigned int t4) { + auto const quadruplets = event->getQuadruplets(); + unsigned int t3_1 = quadruplets.tripletIndices()[t4][0]; + unsigned int t3_2 = quadruplets.tripletIndices()[t4][1]; + return {t3_1, t3_2}; +} + +//____________________________________________________________________________________________ +std::vector getLSsFromT4(LSTEvent* event, unsigned int T4) { + std::vector T3s = getT3sFromT4(event, T4); + std::vector LSs_0 = getLSsFromT3(event, T3s[0]); + std::vector LSs_1 = getLSsFromT3(event, T3s[1]); + return {LSs_0[0], LSs_0[1], LSs_1[1]}; +} + +//____________________________________________________________________________________________ +std::vector getMDsFromT4(LSTEvent* event, unsigned int T4) { + std::vector LSs = getLSsFromT4(event, T4); + std::vector MDs_0 = getMDsFromLS(event, LSs[0]); + std::vector MDs_1 = getMDsFromLS(event, LSs[1]); + std::vector MDs_2 = getMDsFromLS(event, LSs[2]); + return {MDs_0[0], MDs_0[1], MDs_2[0], MDs_2[1]}; +} + +//____________________________________________________________________________________________ +std::vector getHitsFromT4(LSTEvent* event, unsigned int T4) { + std::vector MDs = getMDsFromT4(event, T4); + std::vector hits_0 = getHitsFromMD(event, MDs[0]); + std::vector hits_1 = getHitsFromMD(event, MDs[1]); + std::vector hits_2 = getHitsFromMD(event, MDs[2]); + std::vector hits_3 = getHitsFromMD(event, MDs[3]); + return {hits_0[0], hits_0[1], hits_1[0], hits_1[1], hits_2[0], hits_2[1], hits_3[0], hits_3[1]}; +} + +//____________________________________________________________________________________________ +std::vector getHitIdxsFromT4(LSTEvent* event, unsigned int T4) { + auto hitsBase = event->getInput(); + std::vector hits = getHitsFromT4(event, T4); + std::vector hitidxs; + for (auto& hit : hits) + hitidxs.push_back(hitsBase.idxs()[hit]); + return hitidxs; +} 
+//____________________________________________________________________________________________ +std::vector getModuleIdxsFromT4(LSTEvent* event, unsigned int T4) { + std::vector hits = getHitsFromT4(event, T4); + std::vector module_idxs; + auto hitsEvt = event->getHits(); + for (auto& hitIdx : hits) { + module_idxs.push_back(hitsEvt.moduleIndices()[hitIdx]); + } + return module_idxs; +} +//____________________________________________________________________________________________ +std::vector getHitTypesFromT4(LSTEvent* event, unsigned int T4) { + return {4, 4, 4, 4, 4, 4, 4, 4}; + ; +} + +//____________________________________________________________________________________________ +std::tuple, std::vector> getHitIdxsAndHitTypesFromT4(LSTEvent* event, + unsigned T4) { + return convertHitsToHitIdxsAndHitTypes(event, getHitsFromT4(event, T4)); +} + // ============== // ----* T5 *---- // ============== @@ -441,6 +511,9 @@ std::vector getLSsFromTC(LSTEvent* event, unsigned int iTC) { case lst::LSTObjType::pLS: return std::vector(); break; + case lst::LSTObjType::T4: + return getLSsFromT4(event, objidx); + break; } } @@ -465,5 +538,8 @@ std::tuple, std::vector> getHitIdxsAndHi case lst::LSTObjType::pLS: return getHitIdxsAndHitTypesFrompLS(event, objidx); break; + case lst::LSTObjType::T4: + return getHitIdxsAndHitTypesFromT4(event, objidx); + break; } } diff --git a/RecoTracker/LSTCore/standalone/code/core/AccessHelper.h b/RecoTracker/LSTCore/standalone/code/core/AccessHelper.h index 5790f3131fc3d..12b70ffb30b56 100644 --- a/RecoTracker/LSTCore/standalone/code/core/AccessHelper.h +++ b/RecoTracker/LSTCore/standalone/code/core/AccessHelper.h @@ -38,6 +38,17 @@ std::vector getModuleIdxsFromT3(LSTEvent* event, unsigned int T3); std::tuple, std::vector> getHitIdxsAndHitTypesFromT3(LSTEvent* event, unsigned T3); +// ----* T4 *---- +std::vector getT3sFromT4(LSTEvent* event, unsigned int T4); +std::vector getLSsFromT4(LSTEvent* event, unsigned int T4); +std::vector 
getMDsFromT4(LSTEvent* event, unsigned int T4); +std::vector getHitsFromT4(LSTEvent* event, unsigned int T4); +std::vector getHitIdxsFromT4(LSTEvent* event, unsigned int T4); +std::vector getHitTypesFromT4(LSTEvent* event, unsigned int T4); +std::vector getModuleIdxsFromT4(LSTEvent* event, unsigned int T4); +std::tuple, std::vector> getHitIdxsAndHitTypesFromT4(LSTEvent* event, + unsigned T4); + // ----* T5 *---- std::vector getT3sFromT5(LSTEvent* event, unsigned int T5); std::vector getLSsFromT5(LSTEvent* event, unsigned int T5); diff --git a/RecoTracker/LSTCore/standalone/code/core/AnalysisConfig.h b/RecoTracker/LSTCore/standalone/code/core/AnalysisConfig.h index 230135e68f433..d663568b9c423 100644 --- a/RecoTracker/LSTCore/standalone/code/core/AnalysisConfig.h +++ b/RecoTracker/LSTCore/standalone/code/core/AnalysisConfig.h @@ -153,6 +153,9 @@ class AnalysisConfig { // Boolean to enable pT5 branches bool pt5_branches; + // Boolean to enable T4 branches + bool t4_branches; + // Boolean to enable occupancy branches bool occ_branches; @@ -162,6 +165,9 @@ class AnalysisConfig { // Boolean to enable T5 DNN branches bool t5dnn_branches; + // Boolean to enable T4 DNN branches + bool t4dnn_branches; + // Boolean to enable jet branches bool jet_branches; diff --git a/RecoTracker/LSTCore/standalone/code/core/trkCore.cc b/RecoTracker/LSTCore/standalone/code/core/trkCore.cc index 633b223ce425d..9cb02d1fa96c6 100644 --- a/RecoTracker/LSTCore/standalone/code/core/trkCore.cc +++ b/RecoTracker/LSTCore/standalone/code/core/trkCore.cc @@ -169,6 +169,52 @@ float runpT3(LSTEvent* event) { return pt3_elapsed; } +//___________________________________________________________________________________________________________________________________________________________________________________________ +float runQuadruplet(LSTEvent* event) { + TStopwatch my_timer; + if (ana.verbose >= 2) + std::cout << "Reco Quadruplet start" << std::endl; + my_timer.Start(); + event->createQuadruplets(); 
+ event->wait(); // device side event calls are asynchronous: wait to measure time or print + float t4_elapsed = my_timer.RealTime(); + if (ana.verbose >= 2) + std::cout << "Reco Quadruplet processing time: " << t4_elapsed << " secs" << std::endl; + + if (ana.verbose >= 2) + std::cout << "# of Quadruplets produced: " << event->getNumberOfQuadruplets() << std::endl; + if (ana.verbose >= 2) + std::cout << "# of Quadruplets produced layer 1-2-3-4: " << event->getNumberOfQuadrupletsByLayerBarrel(0) + << std::endl; + if (ana.verbose >= 2) + std::cout << "# of Quadruplets produced layer 2: " << event->getNumberOfQuadrupletsByLayerBarrel(1) << std::endl; + if (ana.verbose >= 2) + std::cout << "# of Quadruplets produced layer 3: " << event->getNumberOfQuadrupletsByLayerBarrel(2) << std::endl; + if (ana.verbose >= 2) + std::cout << "# of Quadruplets produced layer 4: " << event->getNumberOfQuadrupletsByLayerBarrel(3) << std::endl; + if (ana.verbose >= 2) + std::cout << "# of Quadruplets produced layer 5: " << event->getNumberOfQuadrupletsByLayerBarrel(4) << std::endl; + if (ana.verbose >= 2) + std::cout << "# of Quadruplets produced layer 6: " << event->getNumberOfQuadrupletsByLayerBarrel(5) << std::endl; + if (ana.verbose >= 2) + std::cout << "# of Quadruplets produced endcap layer 1: " << event->getNumberOfQuadrupletsByLayerEndcap(0) + << std::endl; + if (ana.verbose >= 2) + std::cout << "# of Quadruplets produced endcap layer 2: " << event->getNumberOfQuadrupletsByLayerEndcap(1) + << std::endl; + if (ana.verbose >= 2) + std::cout << "# of Quadruplets produced endcap layer 3: " << event->getNumberOfQuadrupletsByLayerEndcap(2) + << std::endl; + if (ana.verbose >= 2) + std::cout << "# of Quadruplets produced endcap layer 4: " << event->getNumberOfQuadrupletsByLayerEndcap(3) + << std::endl; + if (ana.verbose >= 2) + std::cout << "# of Quadruplets produced endcap layer 5: " << event->getNumberOfQuadrupletsByLayerEndcap(4) + << std::endl; + + return t4_elapsed; +} + 
//___________________________________________________________________________________________________________________________________________________________________________________________ float runQuintuplet(LSTEvent* event) { TStopwatch my_timer; @@ -272,6 +318,8 @@ float runTrackCandidate(LSTEvent* event, bool no_pls_dupclean, bool tc_pls_tripl std::cout << " # of pLS TrackCandidates produced: " << event->getNumberOfPLSTrackCandidates() << std::endl; if (ana.verbose >= 2) std::cout << "# of T5 TrackCandidates produced: " << event->getNumberOfT5TrackCandidates() << std::endl; + if (ana.verbose >= 2) + std::cout << "# of T4 TrackCandidates produced: " << event->getNumberOfT4TrackCandidates() << std::endl; return tc_elapsed; } @@ -743,6 +791,7 @@ void printTimingInformation(std::vector>& timing_information, std::cout << " " << std::setw(6) << "T3"; std::cout << " " << std::setw(6) << "T5"; std::cout << " " << std::setw(6) << "pLS"; + std::cout << " " << std::setw(6) << "T4"; std::cout << " " << std::setw(6) << "pT5"; std::cout << " " << std::setw(6) << "pT3"; std::cout << " " << std::setw(6) << "TC"; @@ -757,81 +806,38 @@ void printTimingInformation(std::vector>& timing_information, auto timing = timing_information[ievt]; float timing_total = 0.f; float timing_total_short = 0.f; - timing_total += timing[0] * 1000; // Hits - timing_total += timing[1] * 1000; // MD - timing_total += timing[2] * 1000; // LS - timing_total += timing[3] * 1000; // T3 - timing_total += timing[4] * 1000; // T5 - timing_total += timing[5] * 1000; // pLS - timing_total += timing[6] * 1000; // pT5 - timing_total += timing[7] * 1000; // pT3 - timing_total += timing[8] * 1000; // TC - timing_total_short += timing[1] * 1000; // MD - timing_total_short += timing[2] * 1000; // LS - timing_total_short += timing[3] * 1000; // T3 - timing_total_short += timing[4] * 1000; // T5 - timing_total_short += timing[6] * 1000; // pT5 - timing_total_short += timing[7] * 1000; // pT3 - timing_total_short += 
timing[8] * 1000; // TC - timing_total_short += timing[9] * 1000; // Reset + timing_total += timing[0] * 1000; // Hits + for (size_t iobj = 1; iobj <= 9; ++iobj) { // MD-TC + timing_total += timing[iobj] * 1000; + if (iobj != 5) + timing_total_short += timing[iobj] * 1000; // exclude pLS + } + timing_total_short += timing[10] * 1000; // Reset std::cout << std::setw(6) << ievt; - std::cout << " " << std::setw(6) << timing[0] * 1000; // Hits - std::cout << " " << std::setw(6) << timing[1] * 1000; // MD - std::cout << " " << std::setw(6) << timing[2] * 1000; // LS - std::cout << " " << std::setw(6) << timing[3] * 1000; // T3 - std::cout << " " << std::setw(6) << timing[4] * 1000; // T5 - std::cout << " " << std::setw(6) << timing[5] * 1000; // pLS - std::cout << " " << std::setw(6) << timing[6] * 1000; // pT5 - std::cout << " " << std::setw(6) << timing[7] * 1000; // pT3 - std::cout << " " << std::setw(6) << timing[8] * 1000; // TC - std::cout << " " << std::setw(6) << timing[9] * 1000; // Reset + for (auto objtime : timing) { + std::cout << " " << std::setw(6) << objtime * 1000; // Print Hits-Reset + } std::cout << " " << std::setw(7) << timing_total; // Total time std::cout << " " << std::setw(7) << timing_total_short; // Total time std::cout << std::endl; - timing_sum_information[0] += timing[0] * 1000; // Hits - timing_sum_information[1] += timing[1] * 1000; // MD - timing_sum_information[2] += timing[2] * 1000; // LS - timing_sum_information[3] += timing[3] * 1000; // T3 - timing_sum_information[4] += timing[4] * 1000; // T5 - timing_sum_information[5] += timing[5] * 1000; // pLS - timing_sum_information[6] += timing[6] * 1000; // pT5 - timing_sum_information[7] += timing[7] * 1000; // pT3 - timing_sum_information[8] += timing[8] * 1000; // TC - timing_sum_information[9] += timing[9] * 1000; // Reset + for (size_t iobj = 0; iobj <= 10; ++iobj) { // Hits-Reset + timing_sum_information[iobj] += timing[iobj] * 1000; + } 
timing_shortlist.push_back(timing_total_short); // short total timing_list.push_back(timing_total); // short total } - timing_sum_information[0] /= timing_information.size(); // Hits - timing_sum_information[1] /= timing_information.size(); // MD - timing_sum_information[2] /= timing_information.size(); // LS - timing_sum_information[3] /= timing_information.size(); // T3 - timing_sum_information[4] /= timing_information.size(); // T5 - timing_sum_information[5] /= timing_information.size(); // pLS - timing_sum_information[6] /= timing_information.size(); // pT5 - timing_sum_information[7] /= timing_information.size(); // pT3 - timing_sum_information[8] /= timing_information.size(); // TC - timing_sum_information[9] /= timing_information.size(); // Reset + for (size_t iobj = 0; iobj <= 10; iobj++) { // Hits-Reset + timing_sum_information[iobj] /= timing_information.size(); + } float timing_total_avg = 0.0; - timing_total_avg += timing_sum_information[0]; // Hits - timing_total_avg += timing_sum_information[1]; // MD - timing_total_avg += timing_sum_information[2]; // LS - timing_total_avg += timing_sum_information[3]; // T3 - timing_total_avg += timing_sum_information[4]; // T5 - timing_total_avg += timing_sum_information[5]; // pLS - timing_total_avg += timing_sum_information[6]; // pT5 - timing_total_avg += timing_sum_information[7]; // pT3 - timing_total_avg += timing_sum_information[8]; // TC - timing_total_avg += timing_sum_information[9]; // Reset float timing_totalshort_avg = 0.0; - timing_totalshort_avg += timing_sum_information[1]; // MD - timing_totalshort_avg += timing_sum_information[2]; // LS - timing_totalshort_avg += timing_sum_information[3]; // T3 - timing_totalshort_avg += timing_sum_information[4]; // T5 - timing_totalshort_avg += timing_sum_information[6]; // pT5 - timing_totalshort_avg += timing_sum_information[7]; // pT3 - timing_totalshort_avg += timing_sum_information[8]; // TC - timing_totalshort_avg += timing_sum_information[9]; // Reset + 
timing_total_avg += timing_sum_information[0]; // Hits + for (size_t iobj = 1; iobj <= 10; iobj++) { // MD-Reset + timing_total_avg += timing_sum_information[iobj]; + if (iobj != 5) + timing_totalshort_avg += timing_sum_information[iobj]; // exclude pLS + } float standardDeviation = 0.0; for (auto shorttime : timing_shortlist) { @@ -847,6 +853,7 @@ void printTimingInformation(std::vector>& timing_information, std::cout << " " << std::setw(6) << "T3"; std::cout << " " << std::setw(6) << "T5"; std::cout << " " << std::setw(6) << "pLS"; + std::cout << " " << std::setw(6) << "T4"; std::cout << " " << std::setw(6) << "pT5"; std::cout << " " << std::setw(6) << "pT3"; std::cout << " " << std::setw(6) << "TC"; @@ -855,18 +862,11 @@ void printTimingInformation(std::vector>& timing_information, std::cout << " " << std::setw(7) << "Total(short)"; std::cout << std::endl; std::cout << std::setw(6) << "avg"; - std::cout << " " << std::setw(6) << timing_sum_information[0]; // Hits - std::cout << " " << std::setw(6) << timing_sum_information[1]; // MD - std::cout << " " << std::setw(6) << timing_sum_information[2]; // LS - std::cout << " " << std::setw(6) << timing_sum_information[3]; // T3 - std::cout << " " << std::setw(6) << timing_sum_information[4]; // T5 - std::cout << " " << std::setw(6) << timing_sum_information[5]; // pLS - std::cout << " " << std::setw(6) << timing_sum_information[6]; // pT5 - std::cout << " " << std::setw(6) << timing_sum_information[7]; // pT3 - std::cout << " " << std::setw(6) << timing_sum_information[8]; // TC - std::cout << " " << std::setw(6) << timing_sum_information[9]; // Reset - std::cout << " " << std::setw(7) << timing_total_avg; // Average total time - std::cout << " " << std::setw(7) << timing_totalshort_avg; // Average total time + for (auto objsum : timing_sum_information) { + std::cout << " " << std::setw(6) << objsum; // Print Hits-Reset + } + std::cout << " " << std::setw(7) << timing_total_avg; // Average total time + std::cout << " 
" << std::setw(7) << timing_totalshort_avg; // Average total time std::cout << "+/- " << std::setw(4) << stdDev; std::cout << " " << std::setw(7) << fullavg; // Average full time std::cout << " " << ana.compilation_target; diff --git a/RecoTracker/LSTCore/standalone/code/core/trkCore.h b/RecoTracker/LSTCore/standalone/code/core/trkCore.h index e1d572bfcbb49..cd8cb4f425c7a 100644 --- a/RecoTracker/LSTCore/standalone/code/core/trkCore.h +++ b/RecoTracker/LSTCore/standalone/code/core/trkCore.h @@ -23,9 +23,9 @@ float runMiniDoublet(LSTEvent* event, int evt); float runSegment(LSTEvent* event); float runT4(LSTEvent* event); float runT4x(LSTEvent* event); -float runpT4(LSTEvent* event); float runT3(LSTEvent* event); float runTrackCandidate(LSTEvent* event, bool no_pls_dupclean, bool tc_pls_triplets); +float runQuadruplet(LSTEvent* event); float runQuintuplet(LSTEvent* event); float runPixelQuintuplet(LSTEvent* event); float runPixelLineSegment(LSTEvent* event, bool no_pls_dupclean); diff --git a/RecoTracker/LSTCore/standalone/code/core/write_lst_ntuple.cc b/RecoTracker/LSTCore/standalone/code/core/write_lst_ntuple.cc index 2e8a137a8c8d6..2e8e428a65147 100644 --- a/RecoTracker/LSTCore/standalone/code/core/write_lst_ntuple.cc +++ b/RecoTracker/LSTCore/standalone/code/core/write_lst_ntuple.cc @@ -19,6 +19,8 @@ void createOutputBranches() { createLineSegmentBranches(); if (ana.t3_branches) createTripletBranches(); + if (ana.t4_branches) + createQuadrupletBranches(); if (ana.t5_branches) createQuintupletBranches(); if (ana.pls_branches) @@ -36,6 +38,8 @@ void createOutputBranches() { createT5DNNBranches(); if (ana.t3dnn_branches) createT3DNNBranches(); + if (ana.t4dnn_branches) + createT4DNNBranches(); } //________________________________________________________________________________________________________________________________ @@ -51,6 +55,8 @@ void fillOutputBranches(LSTEvent* event) { setT3DNNBranches(event, matchfrac); if (ana.t5dnn_branches) setT5DNNBranches(event); 
+ if (ana.t4dnn_branches) + setT4DNNBranches(event); auto const md_idx_map = (ana.md_branches ? setMiniDoubletBranches(event, n_accepted_simtrk, matchfrac) : std::map()); @@ -58,6 +64,8 @@ void fillOutputBranches(LSTEvent* event) { : std::map()); auto const t3_idx_map = (ana.t3_branches ? setTripletBranches(event, n_accepted_simtrk, matchfrac, ls_idx_map) : std::map()); + auto const t4_idx_map = (ana.t4_branches ? setQuadrupletBranches(event, n_accepted_simtrk, matchfrac, t3_idx_map) + : std::map()); auto const t5_idx_map = (ana.t5_branches ? setQuintupletBranches(event, n_accepted_simtrk, matchfrac, t3_idx_map) : std::map()); auto const pls_idx_map = (ana.pls_branches ? setPixelLineSegmentBranches(event, n_accepted_simtrk, matchfrac) @@ -69,7 +77,8 @@ void fillOutputBranches(LSTEvent* event) { (ana.pt5_branches ? setPixelQuintupletBranches(event, n_accepted_simtrk, matchfrac, pls_idx_map, t5_idx_map) : std::map()); - setTrackCandidateBranches(event, n_accepted_simtrk, t5_idx_map, pls_idx_map, pt3_idx_map, pt5_idx_map, matchfrac); + setTrackCandidateBranches( + event, n_accepted_simtrk, t5_idx_map, pls_idx_map, pt3_idx_map, pt5_idx_map, t4_idx_map, matchfrac); // Now actually fill the ttree ana.tx->fill(); @@ -145,6 +154,42 @@ void createT3DNNBranches() { ana.tx->createBranch>>("t3_matched_simIdx"); } +//________________________________________________________________________________________________________________________________ +void createT4DNNBranches() { + // Common branches + ana.tx->createBranch>("t4_t3_idx0"); + ana.tx->createBranch>("t4_t3_idx1"); + ana.tx->createBranch>("t4_tc_idx"); + ana.tx->createBranch>("t4_partOfTC"); + ana.tx->createBranch>("t4_t3_pt"); + ana.tx->createBranch>("t4_t3_eta"); + ana.tx->createBranch>("t4_t3_phi"); + ana.tx->createBranch>("t4_t3_fakeScore1"); + ana.tx->createBranch>("t4_t3_promptScore1"); + ana.tx->createBranch>("t4_t3_displacedScore1"); + ana.tx->createBranch>("t4_t3_fakeScore2"); + 
ana.tx->createBranch>("t4_t3_promptScore2"); + ana.tx->createBranch>("t4_t3_displacedScore2"); + ana.tx->createBranch>("t4_regressionRadius"); + ana.tx->createBranch>("t4_nonAnchorRegressionRadius"); + + // Hit-specific branches + std::vector hitIndices = {"0", "1", "2", "3", "4", "5"}; + std::vector hitProperties = { + "r", "x", "y", "z", "eta", "phi", "detId", "layer", "moduleType", "moduleIdx"}; + + for (const auto& idx : hitIndices) { + for (const auto& prop : hitProperties) { + std::string branchName = "t4_t3_" + idx + "_" + prop; + if (prop == "detId" || prop == "layer" || prop == "moduleType" || prop == "moduleIdx") { + ana.tx->createBranch>(branchName); + } else { + ana.tx->createBranch>(branchName); + } + } + } +} + //________________________________________________________________________________________________________________________________ void createJetBranches() { ana.tx->createBranch>("sim_deltaEta"); @@ -251,6 +296,12 @@ void createSimTrackContainerBranches() { // list of match fraction for each match (> 0%) to pt5_* container ana.tx->createBranch>>("sim_pt5IdxAllFrac"); } + if (ana.t4_branches) { + // list of idx to matches (> 0%) to t4_* container + ana.tx->createBranch>>("sim_t4IdxAll"); + // list of match fraction for each match (> 0%) to t4_* container + ana.tx->createBranch>>("sim_t4IdxAllFrac"); + } } //________________________________________________________________________________________________________________________________ @@ -262,7 +313,7 @@ void createTrackCandidateBranches() { ana.tx->createBranch>("tc_pt"); // pt ana.tx->createBranch>("tc_eta"); // eta ana.tx->createBranch>("tc_phi"); // phi - ana.tx->createBranch>("tc_type"); // type = 7 (pT5), 5 (pT3), 4 (T5), 8 (pLS) + ana.tx->createBranch>("tc_type"); // type = 7 (pT5), 5 (pT3), 4 (T5), 8 (pLS), 9 (T4) ana.tx->createBranch>("tc_isFake"); // 1 if tc is fake 0 other if not ana.tx->createBranch>("tc_isDuplicate"); // 1 if tc is duplicate 0 other if not 
ana.tx->createBranch>("tc_simIdx"); // idx of best matched (highest nhit and > 75%) simulated track @@ -282,6 +333,9 @@ void createTrackCandidateBranches() { if (ana.pls_branches) ana.tx->createBranch>( "tc_plsIdx"); // index to the pls_* if it is the said type, if not set to -999 + if (ana.t4_branches) + ana.tx->createBranch>( + "tc_t4Idx"); // index to the t4_* if it is the said type, if not set to -999 } //________________________________________________________________________________________________________________________________ @@ -374,6 +428,40 @@ void createTripletBranches() { ana.tx->createBranch>>("t3_simIdxAllFrac"); } +//________________________________________________________________________________________________________________________________ +void createQuadrupletBranches() { + // Quadruplets (i.e. Four mini-doublets, a.k.a. T4) + // + // The container will hold per entry a quadruplet built by LST in the event. + // + ana.tx->createBranch>("sim_T4_matched"); + ana.tx->createBranch>("t4_isFake"); + ana.tx->createBranch>("t4_isDuplicate"); + ana.tx->createBranch>("t4_moduleType_binary"); + ana.tx->createBranch>("t4_layer_binary"); + ana.tx->createBranch>("t4_innerRadius"); + ana.tx->createBranch>("t4_outerRadius"); + ana.tx->createBranch>("t4_pt"); + ana.tx->createBranch>("t4_eta"); + ana.tx->createBranch>("t4_phi"); + ana.tx->createBranch>("t4_isDup"); + ana.tx->createBranch>("t4_rzChiSquared"); + ana.tx->createBranch>("t4_pMatched"); + ana.tx->createBranch>("t4_sim_vxy"); + ana.tx->createBranch>("t4_sim_vz"); + ana.tx->createBranch>>("t4_matched_simIdx"); + ana.tx->createBranch>("t4_score_rphisum"); + ana.tx->createBranch>("t4_promptScore"); + ana.tx->createBranch>("t4_displacedScore"); + ana.tx->createBranch>("t4_fakeScore"); + + ana.tx->createBranch>("t4_simIdx"); // idx of best matched (highest nhit and > 75%) simulated track + // list of idx of all matched (> 0%) simulated track + ana.tx->createBranch>>("t4_simIdxAll"); + // list of idx of 
all matched (> 0%) simulated track + ana.tx->createBranch>>("t4_simIdxAllFrac"); +} + //________________________________________________________________________________________________________________________________ void createQuintupletBranches() { // Quintuplets (i.e. Five mini-doublets, a.k.a. T5) @@ -517,6 +605,7 @@ void createOccupancyBranches() { ana.tx->createBranch>("t3_occupancies"); ana.tx->createBranch("tc_occupancies"); ana.tx->createBranch>("t5_occupancies"); + ana.tx->createBranch>("t4_occupancies"); ana.tx->createBranch("pT3_occupancies"); ana.tx->createBranch("pT5_occupancies"); } @@ -1253,6 +1342,173 @@ std::map setTripletBranches(LSTEvent* event, return t3_idx_map; } +//________________________________________________________________________________________________________________________________ +std::map setQuadrupletBranches(LSTEvent* event, + unsigned int n_accepted_simtrk, + float matchfrac, + std::map const& t3_idx_map) { + //-------------------------------------------- + // + // + // Quadruplet + // + // + //-------------------------------------------- + + auto const& trk_sim_pt = trk.getVF("sim_pt"); + auto const& trk_sim_parentVtxIdx = trk.getVI("sim_parentVtxIdx"); + auto const& trk_simvtx_x = trk.getVF("simvtx_x"); + auto const& trk_simvtx_y = trk.getVF("simvtx_y"); + auto const& trk_simvtx_z = trk.getVF("simvtx_z"); + auto const& trk_simhit_simTrkIdx = trk.getVI("simhit_simTrkIdx"); + auto const& trk_ph2_simHitIdx = trk.getVVI("ph2_simHitIdx"); + auto const& trk_pix_simHitIdx = trk.getVVI("pix_simHitIdx"); + + auto const& hitsBase = event->getInput(); + auto const& ranges = event->getRanges(); + auto const& modules = event->getModules(); + auto const& quadruplets = event->getQuadruplets(); + auto const& quadrupletOccupancies = event->getQuadruplets(); + + int n_total_simtrk = trk_sim_pt.size(); + std::vector sim_t4_matched(n_accepted_simtrk); + std::vector> sim_t4IdxAll(n_total_simtrk); + std::vector> sim_t4IdxAllFrac(n_total_simtrk); 
+ std::vector> t4_simIdxAll; + std::vector> t4_simIdxAllFrac; + // Then obtain the lower module index + unsigned int t4_idx = 0; // global t4 index that will be used to keep track of t4 being outputted to the ntuple + // map to keep track of (GPU t4Idx) -> (t4_idx in ntuple output) + std::map t4_idx_map; + + for (unsigned int idx = 0; idx < modules.nLowerModules(); ++idx) { + unsigned int nmods = modules.nLowerModules(); + for (unsigned int iT4 = 0; iT4 < quadrupletOccupancies.nQuadruplets()[idx]; iT4++) { + unsigned int t4Idx = ranges.quadrupletModuleIndices()[idx] + iT4; + t4_idx_map[t4Idx] = t4_idx; + std::vector hit_idx, hit_type; + std::tie(hit_idx, hit_type) = getHitIdxsAndHitTypesFromT4(event, t4Idx); + std::vector simidx; + std::vector simidxfrac; + float percent_matched; + std::tie(simidx, simidxfrac) = matchedSimTrkIdxsAndFracs(hit_idx, + hit_type, + trk_simhit_simTrkIdx, + trk_ph2_simHitIdx, + trk_pix_simHitIdx, + false, + matchfrac, + &percent_matched); + std::vector t3Idxs = getT3sFromT4(event, t4Idx); + + float pt = __H2F(quadruplets.pt()[t4Idx]); + float eta = __H2F(quadruplets.eta()[t4Idx]); + float phi = __H2F(quadruplets.phi()[t4Idx]); + ana.tx->pushbackToBranch("t4_pt", pt); + ana.tx->pushbackToBranch("t4_eta", eta); + ana.tx->pushbackToBranch("t4_phi", phi); + ana.tx->pushbackToBranch("t4_innerRadius", __H2F(quadruplets.innerRadius()[t4Idx])); + ana.tx->pushbackToBranch("t4_outerRadius", __H2F(quadruplets.outerRadius()[t4Idx])); + ana.tx->pushbackToBranch("t4_pMatched", percent_matched); + ana.tx->pushbackToBranch("t4_score_rphisum", __H2F(quadruplets.score_rphisum()[t4Idx])); + ana.tx->pushbackToBranch("t4_rzChiSquared", quadruplets.rzChiSquared()[t4Idx]); + ana.tx->pushbackToBranch("t4_promptScore", quadruplets.promptScore()[t4Idx]); + ana.tx->pushbackToBranch("t4_displacedScore", quadruplets.displacedScore()[t4Idx]); + ana.tx->pushbackToBranch("t4_fakeScore", quadruplets.fakeScore()[t4Idx]); + + int layer_binary = 0; + int moduleType_binary = 
0; + std::vector layers; + std::vector module_idx = getModuleIdxsFromT4(event, t4Idx); + + for (size_t i = 0; i < module_idx.size(); i += 2) { + layer_binary |= (1 << (modules.layers()[module_idx[i]] + 6 * (modules.subdets()[module_idx[i]] == 4))); + moduleType_binary |= (modules.moduleType()[module_idx[i]] << i); + layers.push_back(modules.layers()[module_idx[i]] + 6 * (modules.subdets()[module_idx[i]] == 4) + + 5 * (modules.subdets()[module_idx[i]] == 4 && modules.moduleType()[module_idx[i]] == 1)); + } + ana.tx->pushbackToBranch("t4_layer_binary", layer_binary); + ana.tx->pushbackToBranch("t4_moduleType_binary", moduleType_binary); + + bool isfake = true; + for (size_t isim = 0; isim < simidx.size(); ++isim) { + if (simidxfrac[isim] > matchfrac) { + isfake = false; + break; + } + } + ana.tx->pushbackToBranch("t4_isFake", isfake); + t4_simIdxAll.push_back(simidx); + t4_simIdxAllFrac.push_back(simidxfrac); + for (size_t is = 0; is < simidx.size(); ++is) { + int sim_idx = simidx.at(is); + if (sim_idx < n_accepted_simtrk) { + sim_t4_matched.at(sim_idx) += 1; + } + float sim_idx_frac = simidxfrac.at(is); + if (sim_idx < n_total_simtrk) { + sim_t4IdxAll.at(sim_idx).push_back(t4_idx); + sim_t4IdxAllFrac.at(sim_idx).push_back(sim_idx_frac); + } + } + int t4_simIdx = -999; + float t4_simIdxBestFrac = 0; + for (size_t isim = 0; isim < simidx.size(); ++isim) { + int thisidx = simidx[isim]; + float thisfrac = simidxfrac[isim]; + if (thisfrac > t4_simIdxBestFrac and thisfrac > matchfrac) { + t4_simIdxBestFrac = thisfrac; + t4_simIdx = thisidx; + } + } + ana.tx->pushbackToBranch("t4_simIdx", t4_simIdx); + // count global + t4_idx++; + + // Avoid fakes when calculating the vertex distance, set default to 0.0. 
+ if (simidx.size() == 0) { + ana.tx->pushbackToBranch("t4_sim_vxy", 0.0); + ana.tx->pushbackToBranch("t4_sim_vz", 0.0); + } else { + int vtxidx = trk_sim_parentVtxIdx[simidx[0]]; + float vtx_x = trk_simvtx_x[vtxidx]; + float vtx_y = trk_simvtx_y[vtxidx]; + float vtx_z = trk_simvtx_z[vtxidx]; + + ana.tx->pushbackToBranch("t4_sim_vxy", sqrt(vtx_x * vtx_x + vtx_y * vtx_y)); + ana.tx->pushbackToBranch("t4_sim_vz", vtx_z); + } + } + } + ana.tx->setBranch>>("t4_simIdxAll", t4_simIdxAll); + ana.tx->setBranch>>("t4_simIdxAllFrac", t4_simIdxAllFrac); + std::vector> sim_t4IdxAll_to_write; + std::vector> sim_t4IdxAllFrac_to_write; + std::copy(sim_t4IdxAll.begin(), sim_t4IdxAll.begin() + n_accepted_simtrk, std::back_inserter(sim_t4IdxAll_to_write)); + std::copy(sim_t4IdxAllFrac.begin(), + sim_t4IdxAllFrac.begin() + n_accepted_simtrk, + std::back_inserter(sim_t4IdxAllFrac_to_write)); + ana.tx->setBranch>>("sim_t4IdxAll", sim_t4IdxAll_to_write); + ana.tx->setBranch>>("sim_t4IdxAllFrac", sim_t4IdxAllFrac_to_write); + + std::vector t4_isDuplicate(t4_simIdxAll.size()); + for (unsigned int i = 0; i < t4_simIdxAll.size(); i++) { + bool isDuplicate = false; + for (unsigned int isim = 0; isim < t4_simIdxAll[i].size(); isim++) { + int simidx = t4_simIdxAll[i][isim]; + if (simidx < n_accepted_simtrk) { + if (sim_t4_matched[simidx] > 1) { + isDuplicate = true; + } + } + } + t4_isDuplicate[i] = isDuplicate; + } + ana.tx->setBranch>("t4_isDuplicate", t4_isDuplicate); + + return t4_idx_map; +} + //________________________________________________________________________________________________________________________________ std::map setQuintupletBranches(LSTEvent* event, unsigned int n_accepted_simtrk, @@ -1893,6 +2149,7 @@ void setTrackCandidateBranches(LSTEvent* event, std::map pls_idx_map, std::map pt3_idx_map, std::map pt5_idx_map, + std::map t4_idx_map, float matchfrac) { //-------------------------------------------- // @@ -1964,6 +2221,8 @@ void setTrackCandidateBranches(LSTEvent* 
event, ana.tx->pushbackToBranch("tc_t5Idx", -999); if (ana.pls_branches) ana.tx->pushbackToBranch("tc_plsIdx", -999); + if (ana.t4_branches) + ana.tx->pushbackToBranch("tc_t4Idx", -999); } else if (type == LSTObjType::pT3) { if (ana.pt5_branches) ana.tx->pushbackToBranch("tc_pt5Idx", -999); @@ -1975,6 +2234,8 @@ void setTrackCandidateBranches(LSTEvent* event, ana.tx->pushbackToBranch("tc_t5Idx", -999); if (ana.pls_branches) ana.tx->pushbackToBranch("tc_plsIdx", -999); + if (ana.t4_branches) + ana.tx->pushbackToBranch("tc_t4Idx", -999); } else if (type == LSTObjType::T5) { if (ana.pt5_branches) ana.tx->pushbackToBranch("tc_pt5Idx", -999); @@ -1985,6 +2246,8 @@ void setTrackCandidateBranches(LSTEvent* event, "tc_t5Idx", (ana.t5_branches ? t5_idx_map[trackCandidatesExtended.directObjectIndices()[tc_idx]] : -999)); if (ana.pls_branches) ana.tx->pushbackToBranch("tc_plsIdx", -999); + if (ana.t4_branches) + ana.tx->pushbackToBranch("tc_t4Idx", -999); } else if (type == LSTObjType::pLS) { if (ana.pt5_branches) ana.tx->pushbackToBranch("tc_pt5Idx", -999); @@ -1998,6 +2261,20 @@ void setTrackCandidateBranches(LSTEvent* event, (ana.pls_branches ? pls_idx_map[ranges.segmentModuleIndices()[modules.nLowerModules()] + trackCandidatesExtended.directObjectIndices()[tc_idx]] : -999)); + if (ana.t4_branches) + ana.tx->pushbackToBranch("tc_t4Idx", -999); + } else if (type == LSTObjType::T4) { + if (ana.pt5_branches) + ana.tx->pushbackToBranch("tc_pt5Idx", -999); + if (ana.pt3_branches) + ana.tx->pushbackToBranch("tc_pt3Idx", -999); + if (ana.t5_branches) + ana.tx->pushbackToBranch("tc_t5Idx", -999); + if (ana.pls_branches) + ana.tx->pushbackToBranch("tc_plsIdx", -999); + if (ana.t4_branches) + ana.tx->pushbackToBranch( + "tc_t4Idx", (ana.t4_branches ? 
t4_idx_map[trackCandidatesExtended.directObjectIndices()[tc_idx]] : -999)); } ana.tx->pushbackToBranch("tc_isFake", isFake); @@ -2110,6 +2387,7 @@ void setOccupancyBranches(LSTEvent* event) { auto segments = event->getSegments(); auto triplets = event->getTriplets(); auto quintuplets = event->getQuintuplets(); + auto quadruplets = event->getQuadruplets(); auto pixelQuintuplets = event->getPixelQuintuplets(); auto pixelTriplets = event->getPixelTriplets(); auto trackCandidatesBase = event->getTrackCandidatesBase(); @@ -2127,6 +2405,7 @@ void setOccupancyBranches(LSTEvent* event) { std::vector segmentOccupancy; std::vector mdOccupancy; std::vector quintupletOccupancy; + std::vector quadrupletOccupancy; for (unsigned int lowerIdx = 0; lowerIdx <= modules.nLowerModules(); lowerIdx++) { //layer = 0, subdet = 0 => pixel module @@ -2144,6 +2423,7 @@ void setOccupancyBranches(LSTEvent* event) { if (lowerIdx < modules.nLowerModules()) { quintupletOccupancy.push_back(quintuplets.totOccupancyQuintuplets()[lowerIdx]); + quadrupletOccupancy.push_back(quadruplets.totOccupancyQuadruplets()[lowerIdx]); tripletOccupancy.push_back(triplets.totOccupancyTriplets()[lowerIdx]); } } @@ -2162,6 +2442,7 @@ void setOccupancyBranches(LSTEvent* event) { ana.tx->setBranch("tc_occupancies", trackCandidatesBase.nTrackCandidates()); ana.tx->setBranch("pT3_occupancies", pixelTriplets.totOccupancyPixelTriplets()); ana.tx->setBranch>("t5_occupancies", quintupletOccupancy); + ana.tx->setBranch>("t4_occupancies", quadrupletOccupancy); ana.tx->setBranch("pT5_occupancies", pixelQuintuplets.totOccupancyPixelQuintuplets()); } @@ -2261,6 +2542,63 @@ void fillT5DNNBranches(LSTEvent* event, unsigned int iT3) { ana.tx->pushbackToBranch("t5_t3_phi", hitObjects[0].phi()); } +//________________________________________________________________________________________________________________________________ +void fillT4DNNBranches(LSTEvent* event, unsigned int iT3) { + auto hitsBase = event->getInput(); + auto 
hitsExtended = event->getHits(); + auto modules = event->getModules(); + + std::vector hitIdx = getHitsFromT3(event, iT3); + std::vector hitObjects(hitIdx.size()); + + auto const& trk_ph2_subdet = trk.getVUS("ph2_subdet"); + auto const& trk_ph2_layer = trk.getVUS("ph2_layer"); + auto const& trk_ph2_detId = trk.getVU("ph2_detId"); + + for (int i = 0; i < hitIdx.size(); ++i) { + unsigned int hit = hitIdx[i]; + float x = hitsBase.xs()[hit]; + float y = hitsBase.ys()[hit]; + float z = hitsBase.zs()[hit]; + hitObjects[i] = lst_math::Hit(x, y, z); + + std::string idx = std::to_string(i); + ana.tx->pushbackToBranch("t4_t3_" + idx + "_r", sqrt(x * x + y * y)); + ana.tx->pushbackToBranch("t4_t3_" + idx + "_x", x); + ana.tx->pushbackToBranch("t4_t3_" + idx + "_y", y); + ana.tx->pushbackToBranch("t4_t3_" + idx + "_z", z); + ana.tx->pushbackToBranch("t4_t3_" + idx + "_eta", hitObjects[i].eta()); + ana.tx->pushbackToBranch("t4_t3_" + idx + "_phi", hitObjects[i].phi()); + + int subdet = trk_ph2_subdet[hitsBase.idxs()[hit]]; + int is_endcap = subdet == 4; + int layer = trk_ph2_layer[hitsBase.idxs()[hit]] + 6 * is_endcap; + int detId = trk_ph2_detId[hitsBase.idxs()[hit]]; + unsigned int module = hitsExtended.moduleIndices()[hit]; + + ana.tx->pushbackToBranch("t4_t3_" + idx + "_detId", detId); + ana.tx->pushbackToBranch("t4_t3_" + idx + "_layer", layer); + ana.tx->pushbackToBranch("t4_t3_" + idx + "_moduleType", modules.moduleType()[module]); + ana.tx->pushbackToBranch("t4_t3_" + idx + "_moduleIdx", module); + } + + float radius; + auto const& devHost = cms::alpakatools::host(); + std::tie(radius, std::ignore, std::ignore) = computeRadiusFromThreeAnchorHits(devHost, + hitObjects[0].x(), + hitObjects[0].y(), + hitObjects[1].x(), + hitObjects[1].y(), + hitObjects[2].x(), + hitObjects[2].y()); + + ana.tx->pushbackToBranch("t4_t3_pt", k2Rinv1GeVf * 2 * radius); + + // Angles + ana.tx->pushbackToBranch("t4_t3_eta", hitObjects[2].eta()); + ana.tx->pushbackToBranch("t4_t3_phi", 
hitObjects[0].phi()); +} + //________________________________________________________________________________________________________________________________ void setT3DNNBranches(LSTEvent* event, float matchfrac) { auto const& trk_sim_parentVtxIdx = trk.getVI("sim_parentVtxIdx"); @@ -2400,6 +2738,70 @@ void setT5DNNBranches(LSTEvent* event) { } } +//________________________________________________________________________________________________________________________________ +void setT4DNNBranches(LSTEvent* event) { + auto tripletsOcc = event->getTriplets(); + auto tripletsSoA = event->getTriplets(); + auto modules = event->getModules(); + auto ranges = event->getRanges(); + auto const quadrupletsOcc = event->getQuadruplets(); + auto const quadruplets = event->getQuadruplets(); + auto trackCandidatesBase = event->getTrackCandidatesBase(); + auto trackCandidatesExtended = event->getTrackCandidatesExtended(); + + std::unordered_set allT3s; + std::unordered_map t3_index_map; + + for (unsigned int idx = 0; idx < modules.nLowerModules(); ++idx) { + for (unsigned int jdx = 0; jdx < tripletsOcc.nTriplets()[idx]; ++jdx) { + unsigned int t3Idx = ranges.tripletModuleIndices()[idx] + jdx; + if (allT3s.insert(t3Idx).second) { + t3_index_map[t3Idx] = allT3s.size() - 1; + fillT4DNNBranches(event, t3Idx); + } + } + } + + std::unordered_map t4_tc_index_map; + std::unordered_set t4s_used_in_tc; + + for (unsigned int idx = 0; idx < trackCandidatesBase.nTrackCandidates(); idx++) { + if (trackCandidatesBase.trackCandidateType()[idx] == LSTObjType::T4) { + unsigned int objIdx = trackCandidatesExtended.directObjectIndices()[idx]; + t4s_used_in_tc.insert(objIdx); + t4_tc_index_map[objIdx] = idx; + } + } + + for (unsigned int idx = 0; idx < modules.nLowerModules(); ++idx) { + for (unsigned int jdx = 0; jdx < quadrupletsOcc.nQuadruplets()[idx]; ++jdx) { + unsigned int t4Idx = ranges.quadrupletModuleIndices()[idx] + jdx; + std::vector t3sIdx = getT3sFromT4(event, t4Idx); + + 
ana.tx->pushbackToBranch("t4_t3_idx0", t3_index_map[t3sIdx[0]]); + ana.tx->pushbackToBranch("t4_t3_idx1", t3_index_map[t3sIdx[1]]); + + ana.tx->pushbackToBranch("t4_t3_fakeScore1", tripletsSoA.fakeScore()[t3sIdx[0]]); + ana.tx->pushbackToBranch("t4_t3_promptScore1", tripletsSoA.promptScore()[t3sIdx[0]]); + ana.tx->pushbackToBranch("t4_t3_displacedScore1", tripletsSoA.displacedScore()[t3sIdx[0]]); + ana.tx->pushbackToBranch("t4_t3_fakeScore2", tripletsSoA.fakeScore()[t3sIdx[1]]); + ana.tx->pushbackToBranch("t4_t3_promptScore2", tripletsSoA.promptScore()[t3sIdx[1]]); + ana.tx->pushbackToBranch("t4_t3_displacedScore2", tripletsSoA.displacedScore()[t3sIdx[1]]); + + ana.tx->pushbackToBranch("t4_regressionRadius", quadruplets.regressionRadius()[t4Idx]); + ana.tx->pushbackToBranch("t4_nonAnchorRegressionRadius", quadruplets.nonAnchorRegressionRadius()[t4Idx]); + + if (t4s_used_in_tc.find(t4Idx) != t4s_used_in_tc.end()) { + ana.tx->pushbackToBranch("t4_partOfTC", 1); + ana.tx->pushbackToBranch("t4_tc_idx", t4_tc_index_map[t4Idx]); + } else { + ana.tx->pushbackToBranch("t4_partOfTC", 0); + ana.tx->pushbackToBranch("t4_tc_idx", -999); + } + } + } +} + //________________________________________________________________________________________________________________________________ std::tuple> parseTrackCandidate( LSTEvent* event, @@ -2428,6 +2830,9 @@ std::tuple> parseTrackCandidate( case LSTObjType::T5: std::tie(pt, eta, phi, hit_idx, hit_type) = parseT5(event, idx, trk_ph2_x, trk_ph2_y, trk_ph2_z); break; + case LSTObjType::T4: + std::tie(pt, eta, phi, hit_idx, hit_type) = parseT4(event, idx, trk_ph2_x, trk_ph2_y, trk_ph2_z); + break; case LSTObjType::pLS: std::tie(pt, eta, phi, hit_idx, hit_type) = parsepLS(event, idx); break; @@ -2469,6 +2874,9 @@ std::tuple, std::vector> case LSTObjType::T5: std::tie(pt, eta, phi, hit_idx, hit_type) = parseT5(event, idx, trk_ph2_x, trk_ph2_y, trk_ph2_z); break; + case LSTObjType::T4: + std::tie(pt, eta, phi, hit_idx, hit_type) = 
parseT4(event, idx, trk_ph2_x, trk_ph2_y, trk_ph2_z); + break; case LSTObjType::pLS: std::tie(pt, eta, phi, hit_idx, hit_type) = parsepLS(event, idx); break; @@ -2669,6 +3077,44 @@ std::tuple, std::vector, std::vector> parseT4( + LSTEvent* event, + unsigned int idx, + std::vector const& trk_ph2_x, + std::vector const& trk_ph2_y, + std::vector const& trk_ph2_z) { + auto const trackCandidatesExtended = event->getTrackCandidatesExtended(); + auto const quadruplets = event->getQuadruplets(); + unsigned int t4 = trackCandidatesExtended.directObjectIndices()[idx]; + std::vector hits = getHitsFromT4(event, t4); + + // + // pictorial representation of a T4 + // + // inner tracker outer tracker + // ------------- -------------------------- + // 01 23 45 67 (anchor hit of a minidoublet is always the first of the pair) + // (none) oo -- oo -- oo -- oo T4 + unsigned int Hit_0 = hits[0]; + unsigned int Hit_2 = hits[2]; + unsigned int Hit_6 = hits[6]; + + // T4 radius is average of the inner and outer radius + const float pt = (quadruplets.innerRadius()[t4] + quadruplets.outerRadius()[t4]) * k2Rinv1GeVf; + + // T4 eta and phi are computed using outer and innermost hits + lst_math::Hit hitA(trk_ph2_x[Hit_0], trk_ph2_y[Hit_0], trk_ph2_z[Hit_0]); + lst_math::Hit hitB(trk_ph2_x[Hit_6], trk_ph2_y[Hit_6], trk_ph2_z[Hit_6]); + const float phi = hitA.phi(); + const float eta = hitB.eta(); + + std::vector hit_idx = getHitIdxsFromT4(event, t4); + std::vector hit_type = getHitTypesFromT4(event, t4); + + return {pt, eta, phi, hit_idx, hit_type}; +} + //________________________________________________________________________________________________________________________________ std::tuple, std::vector> parsepLS(LSTEvent* event, unsigned int idx) { diff --git a/RecoTracker/LSTCore/standalone/code/core/write_lst_ntuple.h b/RecoTracker/LSTCore/standalone/code/core/write_lst_ntuple.h index fa4a0194f884a..cb47f97ba7764 100644 --- a/RecoTracker/LSTCore/standalone/code/core/write_lst_ntuple.h +++ 
b/RecoTracker/LSTCore/standalone/code/core/write_lst_ntuple.h @@ -18,11 +18,13 @@ void createOutputBranches(); void createJetBranches(); void createT5DNNBranches(); void createT3DNNBranches(); +void createT4DNNBranches(); void createSimTrackContainerBranches(); void createTrackCandidateBranches(); void createMiniDoubletBranches(); void createLineSegmentBranches(); void createTripletBranches(); +void createQuadrupletBranches(); void createQuintupletBranches(); void createPixelLineSegmentBranches(); void createPixelTripletBranches(); @@ -38,6 +40,7 @@ void setTrackCandidateBranches(LSTEvent* event, std::map pls_idx_map, std::map pt3_idx_map, std::map pt5_idx_map, + std::map t4_idx_map, float matchfrac); std::map setMiniDoubletBranches(LSTEvent* event, unsigned int n_accepted_simtrk, @@ -50,6 +53,10 @@ std::map setTripletBranches(LSTEvent* event, unsigned int n_accepted_simtrk, float matchfrac, std::map const& ls_idx_map); +std::map setQuadrupletBranches(LSTEvent* event, + unsigned int n_accepted_simtrk, + float matchfrac, + std::map const& t3_idx_map); std::map setQuintupletBranches(LSTEvent* event, unsigned int n_accepted_simtrk, float matchfrac, @@ -70,8 +77,10 @@ std::map setPixelQuintupletBranches(LSTEvent* event, void fillT5DNNBranches(LSTEvent* event, unsigned int T3); void fillT3DNNBranches(LSTEvent* event, unsigned int iT3); +void fillT4DNNBranches(LSTEvent* event, unsigned int T4); void setT5DNNBranches(LSTEvent* event); void setT3DNNBranches(LSTEvent* event, float matchfrac = 0.75); +void setT4DNNBranches(LSTEvent* event); std::tuple> parseTrackCandidate( LSTEvent* event, @@ -103,6 +112,12 @@ std::tuple, std::vector const& trk_ph2_x, std::vector const& trk_ph2_y, std::vector const& trk_ph2_z); +std::tuple, std::vector> parseT4( + LSTEvent* event, + unsigned int, + std::vector const& trk_ph2_x, + std::vector const& trk_ph2_y, + std::vector const& trk_ph2_z); std::tuple, std::vector> parsepLS(LSTEvent* event, unsigned int); @@ -112,12 +127,9 @@ void 
printHitMultiplicities(LSTEvent* event); // Print objects (GPU) void printAllObjects(LSTEvent* event); -void printpT4s(LSTEvent* event); void printMDs(LSTEvent* event); void printLSs(LSTEvent* event); void printpLSs(LSTEvent* event); void printT3s(LSTEvent* event); -void printT4s(LSTEvent* event); -void printTCs(LSTEvent* event); #endif diff --git a/RecoTracker/LSTCore/standalone/efficiency/bin/lst_timing b/RecoTracker/LSTCore/standalone/efficiency/bin/lst_timing index b3f9d1d8966e8..86d68acfa75a1 100755 --- a/RecoTracker/LSTCore/standalone/efficiency/bin/lst_timing +++ b/RecoTracker/LSTCore/standalone/efficiency/bin/lst_timing @@ -149,5 +149,5 @@ echo "Total Timing Summary" grep -h "Time for map " timing_temp.txt | cut -d " " -f 6 | awk '{ SUM += $1} END { print "Average time for map loading =",SUM/5,"ms" }' # 5 is the number of stream values run grep -h "Time for input " timing_temp.txt | cut -d " " -f 6 | awk '{ SUM += $1} END { print "Average time for input loading =",SUM/5,"ms" }' # 5 is the number of stream values run grep -h "Time for event " timing_temp.txt | cut -d " " -f 6 | awk '{ SUM += $1} END { print "Average time for lst::Event creation =",SUM/21,"ms"}' # 5 is the number of total streams run (1+2+4+6+8) -echo " Evt Hits MD LS T3 T5 pLS pT5 pT3 TC Reset Event Short Rate" +echo " Evt Hits MD LS T3 T5 pLS T4 pT5 pT3 TC Reset Event Short Rate" grep -hr "avg " timing_temp.txt # space is needed to not get certain bad lines diff --git a/RecoTracker/LSTCore/standalone/efficiency/python/lst_plot_performance.py b/RecoTracker/LSTCore/standalone/efficiency/python/lst_plot_performance.py index 1b1bc3f81276f..8faeb1d8dd978 100755 --- a/RecoTracker/LSTCore/standalone/efficiency/python/lst_plot_performance.py +++ b/RecoTracker/LSTCore/standalone/efficiency/python/lst_plot_performance.py @@ -10,7 +10,7 @@ sel_choices = ["base", "loweta", "xtr", "vtr", "none"] metric_choices = ["eff", "fakerate", "duplrate", "fakeorduplrate"] variable_choices = ["pt", "ptmtv", 
"ptlow", "eta", "phi", "dxy", "dz", "vxy", "deltaEta", "deltaPhi", "deltaR", "jet_eta", "jet_phi", "jet_pt"] -objecttype_choices = ["TC", "pT5", "T5", "pT3", "pLS", "pT5_lower", "pT3_lower", "T5_lower", "pLS_lower"] +objecttype_choices = ["TC", "pT5", "T5", "pT3", "pLS", "T4", "pT5_lower", "pT3_lower", "T5_lower", "pLS_lower"] #lowerObjectType = ["pT5_lower", "pT3_lower", "T5_lower"] r.gROOT.SetBatch(True) @@ -125,7 +125,7 @@ def plot(args): numer = [] numer.append(params["input_file"].Get(params["numer"]).Clone()) - breakdown_hist_types = ["pT5", "pT3", "T5", "pLS"] + breakdown_hist_types = ["pT5", "pT3", "T5", "pLS", "T4"] print("breakdown = ", params["breakdown"]) if params["breakdown"]: for breakdown_hist_type in breakdown_hist_types: @@ -143,7 +143,7 @@ def plot(args): if params["breakdown"]: - params["legend_labels"] = ["TC" ,"pT5" ,"pT3" ,"T5" ,"pLS"] + params["legend_labels"] = ["TC" ,"pT5" ,"pT3" ,"T5" ,"pLS", "T4"] else: params["legend_labels"] = [args.objecttype] @@ -382,7 +382,7 @@ def parse_plot_name(output_name): elif "pT4_" in output_name: rtnstr.append("Quadruplet w/ Pixel LS") elif "T4_" in output_name: - rtnstr.append("Quadruplet w/o gap") + rtnstr.append("Quadruplet") elif "T4x_" in output_name: rtnstr.append("Quadruplet w/ gap") elif "pT3_" in output_name: @@ -559,8 +559,8 @@ def draw_plot(effs, nums, dens, params): effs[0].SetTitle(parse_plot_name(output_name)) # Draw the efficiency graphs - colors = [1, 2, 3, 4, 6] - markerstyles = [20, 26, 28, 24, 27] + colors = [1, 2, 3, 4, 6, 7] + markerstyles = [20, 26, 28, 24, 27, 25] markersize = 1.2 linewidth = 2 for i, eff in enumerate(effs): @@ -708,6 +708,7 @@ def plot_standard_performance_plots(args): "pT3": [False], "T5": [False], "pLS": [False], + "T4": [False], "pT5_lower":[False], "pT3_lower":[False], "T5_lower":[False], @@ -719,6 +720,7 @@ def plot_standard_performance_plots(args): "pT3": [False], "T5": [False], "pLS": [False], + "T4": [False], "pT5_lower":[False], "pT3_lower":[False], 
"T5_lower":[False], @@ -730,6 +732,7 @@ def plot_standard_performance_plots(args): "pT3": [False], "T5": [False], "pLS": [False], + "T4": [False], "pT5_lower":[False], "pT3_lower":[False], "T5_lower":[False], @@ -741,6 +744,7 @@ def plot_standard_performance_plots(args): "pT3": [False], "T5": [False], "pLS": [False], + "T4": [False], "pT5_lower":[False], "pT3_lower":[False], "T5_lower":[False], @@ -842,4 +846,3 @@ def plot_standard_performance_plots(args): if __name__ == "__main__": main() - diff --git a/RecoTracker/LSTCore/standalone/efficiency/src/performance.cc b/RecoTracker/LSTCore/standalone/efficiency/src/performance.cc index c3585bae77223..56bac3b6de395 100644 --- a/RecoTracker/LSTCore/standalone/efficiency/src/performance.cc +++ b/RecoTracker/LSTCore/standalone/efficiency/src/performance.cc @@ -1,6 +1,6 @@ #include "performance.h" -enum { pT5 = 7, pT3 = 5, T5 = 4, pLS = 8 }; +enum { pT5 = 7, pT3 = 5, T5 = 4, pLS = 8, T4 = 9 }; //__________________________________________________________________________________________________________________________________________________________________________ int main(int argc, char** argv) { @@ -96,6 +96,18 @@ int main(int argc, char** argv) { return lstEff_sim_tcIdx.at(isim) >= 0 ? lstEff_tc_type.at(lstEff_sim_tcIdx.at(isim)) == pLS : false; }, /* sel */ sels[isel])); + list_effSetDef.push_back(SimTrackSetDefinition( + /* name */ + TString("T4_") + selnames[isel], + /* pdgid */ pdgid, + /* q */ charge, + /* pass */ + [&](unsigned int isim) { + auto& lstEff_sim_tcIdx = lstEff.getVI("sim_tcIdx"); + auto& lstEff_tc_type = lstEff.getVI("tc_type"); + return lstEff_sim_tcIdx.at(isim) >= 0 ? 
lstEff_tc_type.at(lstEff_sim_tcIdx.at(isim)) == T4 : false; + }, + /* sel */ sels[isel])); if (ana.do_lower_level) { //lower objects - name will have pT5_lower_, T5_lower_, pT3_lower_ @@ -222,6 +234,17 @@ int main(int argc, char** argv) { /* eta */ [&]() { return lstEff.getVF("tc_eta"); }, /* phi */ [&]() { return lstEff.getVF("tc_phi"); }, /* type */ [&]() { return lstEff.getVI("tc_type"); })); + list_FRSetDef.push_back( + RecoTrackSetDefinition(/* name */ + "T4", + /* pass */ + [&](unsigned int itc) { return lstEff.getVI("tc_isFake").at(itc) > 0; }, + /* sel */ + [&](unsigned int itc) { return lstEff.getVI("tc_type").at(itc) == T4; }, + /* pt */ [&]() { return lstEff.getVF("tc_pt"); }, + /* eta */ [&]() { return lstEff.getVF("tc_eta"); }, + /* phi */ [&]() { return lstEff.getVF("tc_phi"); }, + /* type */ [&]() { return lstEff.getVI("tc_type"); })); if (ana.do_lower_level) { list_FRSetDef.push_back(RecoTrackSetDefinition( @@ -320,6 +343,17 @@ int main(int argc, char** argv) { /* eta */ [&]() { return lstEff.getVF("tc_eta"); }, /* phi */ [&]() { return lstEff.getVF("tc_phi"); }, /* type */ [&]() { return lstEff.getVI("tc_type"); })); + list_DRSetDef.push_back( + RecoTrackSetDefinition(/* name */ + "T4", + /* pass */ + [&](unsigned int itc) { return lstEff.getVI("tc_isDuplicate").at(itc) > 0; }, + /* sel */ + [&](unsigned int itc) { return lstEff.getVI("tc_type").at(itc) == T4; }, + /* pt */ [&]() { return lstEff.getVF("tc_pt"); }, + /* eta */ [&]() { return lstEff.getVF("tc_eta"); }, + /* phi */ [&]() { return lstEff.getVF("tc_phi"); }, + /* type */ [&]() { return lstEff.getVI("tc_type"); })); if (ana.do_lower_level) { list_DRSetDef.push_back(RecoTrackSetDefinition( @@ -432,6 +466,20 @@ int main(int argc, char** argv) { /* eta */ [&]() { return lstEff.getVF("tc_eta"); }, /* phi */ [&]() { return lstEff.getVF("tc_phi"); }, /* type */ [&]() { return lstEff.getVI("tc_type"); })); + list_FDRSetDef.push_back( + RecoTrackSetDefinition(/* name */ + "T4", + /* pass */ + 
[&](unsigned int itc) { + return (lstEff.getVI("tc_isDuplicate").at(itc) > 0) or + (lstEff.getVI("tc_isFake").at(itc) > 0); + }, + /* sel */ + [&](unsigned int itc) { return lstEff.getVI("tc_type").at(itc) == T4; }, + /* pt */ [&]() { return lstEff.getVF("tc_pt"); }, + /* eta */ [&]() { return lstEff.getVF("tc_eta"); }, + /* phi */ [&]() { return lstEff.getVF("tc_phi"); }, + /* type */ [&]() { return lstEff.getVI("tc_type"); })); if (ana.do_lower_level) { list_FDRSetDef.push_back(RecoTrackSetDefinition( diff --git a/Validation/RecoTrack/python/HLTmultiTrackValidator_cff.py b/Validation/RecoTrack/python/HLTmultiTrackValidator_cff.py index e2418e360d245..5c246ba910a9e 100644 --- a/Validation/RecoTrack/python/HLTmultiTrackValidator_cff.py +++ b/Validation/RecoTrack/python/HLTmultiTrackValidator_cff.py @@ -71,5 +71,5 @@ def _modifyForNGTScouting(trackvalidator): (ngtScouting & ~trackingLST).toModify(hltTrackValidator, _modifyForNGTScouting) def _modifyForNGTScoutingLST(trackvalidator): - trackvalidator.label = ["hltGeneralTracks", "hltPhase2PixelTracks", "hltInitialStepTracksT5TCLST", "hltPixelLessTracks", "hltWithPixelTracks"] + trackvalidator.label = ["hltGeneralTracks", "hltPhase2PixelTracks", "hltInitialStepTracksT4T5TCLST", "hltPixelLessTracks", "hltWithPixelTracks"] (ngtScouting & trackingLST).toModify(hltTrackValidator, _modifyForNGTScoutingLST)