Track classifier using TensorFlow #31682
New file: RecoTracker/FinalTrackSelectors/interface/TfGraphDefWrapper.h

```cpp
#ifndef TrackTfGraph_TfGraphDefWrapper_h
#define TrackTfGraph_TfGraphDefWrapper_h

#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

class TfGraphDefWrapper {
public:
  TfGraphDefWrapper(tensorflow::GraphDef*);
  tensorflow::GraphDef* GetGraphDef() const;
```
Contributor (on `GetGraphDef()`): This must return `const tensorflow::GraphDef*`, to ensure the graph is used in a const-thread-safe manner.
```cpp
private:
  tensorflow::GraphDef* graphDef_;
```
Contributor (on `graphDef_`): How about making the member a pointer-to-const as well?
```cpp
};

#endif
```
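Taking the two suggestions above together with the header-guard comment at the end of the review, a const-correct version of the wrapper might look like the sketch below. This is an illustration of the reviewers' requests, not the code actually committed in this diff.

```cpp
#ifndef RecoTracker_FinalTrackSelectors_TfGraphDefWrapper_h
#define RecoTracker_FinalTrackSelectors_TfGraphDefWrapper_h

#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

// Sketch only: the wrapper stores and hands out a pointer-to-const, so
// consumers can read the graph definition but not modify it.
class TfGraphDefWrapper {
public:
  explicit TfGraphDefWrapper(const tensorflow::GraphDef* graph) : graphDef_(graph) {}
  const tensorflow::GraphDef* GetGraphDef() const { return graphDef_; }

private:
  const tensorflow::GraphDef* graphDef_;
};

#endif
```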
New file: the TfGraphDefProducer ESProducer

```cpp
// -*- C++ -*-
//
// Package:    test/TFGraphDefProducer
// Class:      TFGraphDefProducer
//
/**\class TFGraphDefProducer
 Description: Produces TfGraphRecord into the event containing a tensorflow GraphDef object that can be used for running inference on a pretrained network
*/
//
// Original Author:  Joona Havukainen
//         Created:  Fri, 24 Jul 2020 08:04:00 GMT
//
//

// system include files
#include <memory>

// user include files
#include "FWCore/Framework/interface/ModuleFactory.h"
#include "FWCore/Framework/interface/ESProducer.h"

#include "FWCore/Framework/interface/ESHandle.h"
#include "TrackingTools/Records/interface/TfGraphRecord.h"
#include "RecoTracker/FinalTrackSelectors/interface/TfGraphDefWrapper.h"

// class declaration

class TfGraphDefProducer : public edm::ESProducer {
```
Contributor: Just to note that if this class was made an ESSource, an …

Contributor: Given that the ESProducer itself is generic (even if the Session is made part of the ESProduct), I can easily imagine other users in the future, in which case it would be better to stay with an ESProducer.
```cpp
public:
  TfGraphDefProducer(const edm::ParameterSet&);

  using ReturnType = std::unique_ptr<TfGraphDefWrapper>;

  ReturnType produce(const TfGraphRecord&);

  static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);

private:
  TfGraphDefWrapper wrapper_;

  // ----------member data ---------------------------
};

TfGraphDefProducer::TfGraphDefProducer(const edm::ParameterSet& iConfig)
    : wrapper_(
          TfGraphDefWrapper(tensorflow::loadGraphDef(iConfig.getParameter<edm::FileInPath>("FileName").fullPath()))) {
  auto componentName = iConfig.getParameter<std::string>("ComponentName");
  setWhatProduced(this, componentName);
}

// ------------ method called to produce the data ------------
std::unique_ptr<TfGraphDefWrapper> TfGraphDefProducer::produce(const TfGraphRecord& iRecord) {
  return std::unique_ptr<TfGraphDefWrapper>(&wrapper_);
```
Contributor (on the `return` statement): This leads to a double delete (once by the destructor of the returned `std::unique_ptr`, once when the `wrapper_` member is destroyed together with the producer). I would actually suggest loading the graph inside `produce()` and returning a newly allocated wrapper; this way the graph is loaded only if some other module consumes it, instead of at construction time.
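A rough sketch of that direction, assuming a new string member `filename_` replaces the eagerly constructed `wrapper_` (this is an illustration, not the reviewer's actual suggested change):

```cpp
// Sketch: store only the file path in the constructor and load the graph
// lazily in produce(). The framework takes ownership of the returned object,
// so nothing is deleted twice.
TfGraphDefProducer::TfGraphDefProducer(const edm::ParameterSet& iConfig)
    : filename_(iConfig.getParameter<edm::FileInPath>("FileName").fullPath()) {
  setWhatProduced(this, iConfig.getParameter<std::string>("ComponentName"));
}

std::unique_ptr<TfGraphDefWrapper> TfGraphDefProducer::produce(const TfGraphRecord& iRecord) {
  // loadGraphDef returns a newly allocated GraphDef, handed to the wrapper as before
  return std::make_unique<TfGraphDefWrapper>(tensorflow::loadGraphDef(filename_));
}
```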
```cpp
}

void TfGraphDefProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
  edm::ParameterSetDescription desc;
  desc.add<std::string>("ComponentName", "tfGraphDef");
  desc.add<edm::FileInPath>("FileName", edm::FileInPath());
```
Contributor (on the `FileName` default): Leaving out the default would make it clear already in the python that a necessary parameter has not been set.
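If the default is dropped as suggested, the description might look like this sketch (same parameter names as in the diff):

```cpp
// Sketch: no default for FileName, so the parameter must be set explicitly in
// the python configuration and a missing value is caught at configuration time.
void TfGraphDefProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
  edm::ParameterSetDescription desc;
  desc.add<std::string>("ComponentName", "tfGraphDef");
  desc.add<edm::FileInPath>("FileName");  // required, no default value
  descriptions.add("tfGraphDefProducer", desc);
}
```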
```cpp
  descriptions.add("tfGraphDefProducer", desc);
}

//define this as a plug-in
#include "FWCore/PluginManager/interface/ModuleDef.h"
#include "FWCore/Framework/interface/MakerMacros.h"
DEFINE_FWK_EVENTSETUP_MODULE(TfGraphDefProducer);
```
New file: the TrackTfClassifier plugin

```cpp
#include "RecoTracker/FinalTrackSelectors/interface/TrackMVAClassifier.h"

#include "FWCore/Framework/interface/EventSetup.h"
#include "FWCore/Framework/interface/ESHandle.h"

#include "DataFormats/TrackReco/interface/Track.h"
#include "DataFormats/VertexReco/interface/Vertex.h"

#include "getBestVertex.h"

#include "TrackingTools/Records/interface/TfGraphRecord.h"
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
#include "RecoTracker/FinalTrackSelectors/interface/TfGraphDefWrapper.h"

struct TfDnnCache {
  TfDnnCache() : session_(nullptr) {}
  tensorflow::Session* session_;
};

namespace {
  struct tfDnn {
```
Contributor (on `struct tfDnn`): coding rule 2.6 [1]: …

Contributor: In addition, why a struct and not a class? This seems like a nontrivial object with state, initialization, etc.
```cpp
    tfDnn(const edm::ParameterSet& cfg) : tfDnnLabel_(cfg.getParameter<std::string>("tfDnnLabel")) {}

    TfDnnCache* cache_ = new TfDnnCache();
```
Contributor (on `cache_`): Would the suggested change work here?
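The suggested change itself is not preserved here; assuming it was about replacing the raw owning pointer, a smart-pointer variant might look like the sketch below. It would also free the cache even when no session was ever created, whereas the current destructor deletes it only inside the `if`.

```cpp
#include <memory>

// Sketch under the assumption above: the per-stream cache is owned by a
// std::unique_ptr, so only the TensorFlow session needs explicit cleanup.
// TfDnnCache and the TensorFlow header come from the surrounding file.
struct tfDnn {
  // ... constructor and the rest of the interface as in the diff ...

  std::unique_ptr<TfDnnCache> cache_ = std::make_unique<TfDnnCache>();

  ~tfDnn() {
    if (cache_->session_ != nullptr) {
      tensorflow::closeSession(cache_->session_);
    }
  }
};
```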
```cpp
    ~tfDnn() {
      if (cache_->session_) {
        tensorflow::closeSession(cache_->session_);
        delete cache_;
      }
    }

    static const char* name() { return "TrackTfClassifier"; }

    static void fillDescriptions(edm::ParameterSetDescription& desc) {
      desc.add<std::string>("tfDnnLabel", "trackSelectionTf");
    }

    void beginStream() {}

    void initEvent(const edm::EventSetup& es) {
      if (!cache_->session_) {
        edm::ESHandle<TfGraphDefWrapper> tfDnnHandle;
        es.get<TfGraphRecord>().get(tfDnnLabel_, tfDnnHandle);
        tensorflow::GraphDef* graphDef_ = tfDnnHandle.product()->GetGraphDef();
        cache_->session_ =
            tensorflow::createSession(graphDef_, 1);  //The integer controls how many threads are used for running inference
      }
      session_ = cache_->session_;
    }
```

Contributor (on the `createSession` call): Ah, ok, this creates a new session per stream.

Contributor: Ah, yes, I could have pointed you to that, sorry. I thought you were asking why it was implemented that way.

Contributor: We need …

Contributor: I thought the Session is not thread safe, and therefore may not be shared across stream instances.

Contributor: All other TF sessions that we have are in streams, although I'm still not sure whether this was done deliberately because the session is not MT-safe.
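One alternative mentioned earlier in the review is to make the Session part of the ES product. A purely illustrative sketch follows: the class name is invented, and it assumes that sharing one Session across streams is actually safe, which is exactly the open question in the thread above.

```cpp
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

// Illustrative sketch only: an EventSetup product that owns both the GraphDef
// and a Session created from it, so streams would share one session instead of
// creating their own. Only viable if running inference on a shared Session is
// thread safe, which the discussion above leaves unresolved.
class TfGraphSessionWrapper {
public:
  explicit TfGraphSessionWrapper(tensorflow::GraphDef* graph)
      : graphDef_(graph), session_(tensorflow::createSession(graph, 1)) {}
  TfGraphSessionWrapper(const TfGraphSessionWrapper&) = delete;
  TfGraphSessionWrapper& operator=(const TfGraphSessionWrapper&) = delete;
  ~TfGraphSessionWrapper() {
    tensorflow::closeSession(session_);
    delete graphDef_;
  }

  tensorflow::Session* getSession() const { return session_; }

private:
  tensorflow::GraphDef* graphDef_;
  tensorflow::Session* session_;
};
```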
```cpp
    float operator()(reco::Track const& trk,
                     reco::BeamSpot const& beamSpot,
                     reco::VertexCollection const& vertices) const {
      Point bestVertex = getBestVertex(trk, vertices);

      tensorflow::Tensor input1(tensorflow::DT_FLOAT, {1, 29});
      tensorflow::Tensor input2(tensorflow::DT_FLOAT, {1, 1});

      input1.matrix<float>()(0, 0) = trk.pt();
      input1.matrix<float>()(0, 1) = trk.innerMomentum().x();
      input1.matrix<float>()(0, 2) = trk.innerMomentum().y();
      input1.matrix<float>()(0, 3) = trk.innerMomentum().z();
      input1.matrix<float>()(0, 4) = trk.innerMomentum().rho();
      input1.matrix<float>()(0, 5) = trk.outerMomentum().x();
      input1.matrix<float>()(0, 6) = trk.outerMomentum().y();
      input1.matrix<float>()(0, 7) = trk.outerMomentum().z();
      input1.matrix<float>()(0, 8) = trk.outerMomentum().rho();
      input1.matrix<float>()(0, 9) = trk.ptError();
      input1.matrix<float>()(0, 10) = trk.dxy(bestVertex);
      input1.matrix<float>()(0, 11) = trk.dz(bestVertex);
      input1.matrix<float>()(0, 12) = trk.dxy(beamSpot.position());
      input1.matrix<float>()(0, 13) = trk.dz(beamSpot.position());
      input1.matrix<float>()(0, 14) = trk.dxyError();
      input1.matrix<float>()(0, 15) = trk.dzError();
      input1.matrix<float>()(0, 16) = trk.normalizedChi2();
      input1.matrix<float>()(0, 17) = trk.eta();
      input1.matrix<float>()(0, 18) = trk.phi();
      input1.matrix<float>()(0, 19) = trk.etaError();
      input1.matrix<float>()(0, 20) = trk.phiError();
      input1.matrix<float>()(0, 21) = trk.hitPattern().numberOfValidPixelHits();
      input1.matrix<float>()(0, 22) = trk.hitPattern().numberOfValidStripHits();
      input1.matrix<float>()(0, 23) = trk.ndof();
      input1.matrix<float>()(0, 24) = trk.hitPattern().numberOfLostTrackerHits(reco::HitPattern::MISSING_INNER_HITS);
      input1.matrix<float>()(0, 25) = trk.hitPattern().numberOfLostTrackerHits(reco::HitPattern::MISSING_OUTER_HITS);
      input1.matrix<float>()(0, 26) = trk.hitPattern().trackerLayersTotallyOffOrBad(reco::HitPattern::MISSING_INNER_HITS);
      input1.matrix<float>()(0, 27) = trk.hitPattern().trackerLayersTotallyOffOrBad(reco::HitPattern::MISSING_OUTER_HITS);
      input1.matrix<float>()(0, 28) = trk.hitPattern().trackerLayersWithoutMeasurement(reco::HitPattern::TRACK_HITS);

      //Original algo as its own input, it will enter the graph so that it gets one-hot encoded, as is the preferred
      //format for categorical inputs, where the labels do not have any metric amongst them
      input2.matrix<float>()(0, 0) = trk.originalAlgo();

      //The names for the input tensors get locked when freezing the trained tensorflow model. The NamedTensors must
      //match those names
      tensorflow::NamedTensorList inputs;
      inputs.resize(2);
      inputs[0] = tensorflow::NamedTensor("x", input1);
      inputs[1] = tensorflow::NamedTensor("y", input2);
      std::vector<tensorflow::Tensor> outputs;

      //evaluate the input
      tensorflow::run(session_, inputs, {"Identity"}, &outputs);

      //scale output to be [-1, 1] due to convention
      float output = 2.0 * outputs[0].matrix<float>()(0, 0) - 1.0;
      return output;
    }

    std::string tfDnnLabel_;
```
Contributor (on `tfDnnLabel_`): can be const, I think?
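Assuming it is only ever set in the constructor's initializer list, the member declarations could become:

```cpp
// Sketch: tfDnnLabel_ is fixed at construction and can be const;
// session_ is (re)assigned in initEvent() and therefore stays non-const.
const std::string tfDnnLabel_;
tensorflow::Session* session_;
```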
```cpp
    tensorflow::Session* session_;
  };

  using TrackTfClassifier = TrackMVAClassifier<tfDnn>;
}  // namespace

#include "FWCore/PluginManager/interface/ModuleDef.h"
#include "FWCore/Framework/interface/MakerMacros.h"

DEFINE_FWK_MODULE(TrackTfClassifier);
```
New file: trackSelectionTf_cfi.py

```python
from DataFormats.TrackTfGraph.tfGraphDefProducer_cfi import tfGraphDefProducer as _tfGraphDefProducer

trackSelectionTf = _tfGraphDefProducer.clone(
    ComponentName = "trackSelectionTf",
    FileName = "RecoTracker/FinalTrackSelectors/data/QCDFlatPU_QCDHighPt_ZEE_DisplacedSUSY_2020.pb"
)
```

Contributor (on the import line): This needs to be changed to reflect the new location of the TfGraphDefProducer.
New file: EventSetup typelookup registration for tensorflow::Session

```cpp
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
#include "FWCore/Utilities/interface/typelookup.h"
TYPELOOKUP_DATA_REG(tensorflow::Session);
```

Contributor: This should indeed be unnecessary.
New file: EventSetup typelookup registration for TfGraphDefWrapper

```cpp
#include "RecoTracker/FinalTrackSelectors/interface/TfGraphDefWrapper.h"
#include "FWCore/Utilities/interface/typelookup.h"

TYPELOOKUP_DATA_REG(TfGraphDefWrapper);
```
New file: TfGraphDefWrapper implementation

```cpp
#include "RecoTracker/FinalTrackSelectors/interface/TfGraphDefWrapper.h"

TfGraphDefWrapper::TfGraphDefWrapper(tensorflow::GraphDef* graph) { graphDef_ = graph; }

tensorflow::GraphDef* TfGraphDefWrapper::GetGraphDef() const { return graphDef_; }
```
Changes to an existing configuration fragment (the detachedQuadStep iteration):

```diff
@@ -4,6 +4,7 @@
 #for dnn classifier
 from Configuration.ProcessModifiers.trackdnn_cff import trackdnn
+from dnnQualityCuts import qualityCutDictionary
```
Contributor (on the new import): I think we have a convention of using full import paths, i.e. `from RecoTracker.IterativeTracking.dnnQualityCuts import qualityCutDictionary`. This comment applies everywhere this is imported.

Contributor: As it is, this import also won't work with python3.
```diff
 ###############################################
 # Low pT and detached tracks from pixel quadruplets
@@ -212,11 +213,11 @@
     qualityCuts = [-0.5,0.0,0.5]
 )

-from RecoTracker.FinalTrackSelectors.TrackLwtnnClassifier_cfi import *
-from RecoTracker.FinalTrackSelectors.trackSelectionLwtnn_cfi import *
-trackdnn.toReplaceWith(detachedQuadStep, TrackLwtnnClassifier.clone(
-    src = 'detachedQuadStepTracks',
-    qualityCuts = [-0.6, 0.05, 0.7]
+from RecoTracker.FinalTrackSelectors.TrackTfClassifier_cfi import *
+from RecoTracker.FinalTrackSelectors.trackSelectionTf_cfi import *
+trackdnn.toReplaceWith(detachedQuadStep, TrackTfClassifier.clone(
+    src = 'detachedQuadStepTracks',
+    qualityCuts = qualityCutDictionary['DetachedQuadStep']
 ))

 highBetaStar_2018.toModify(detachedQuadStep,qualityCuts = [-0.7,0.0,0.5])
```
Contributor (on TfGraphDefWrapper.h): the header guard should be updated to match the path: `RecoTracker_FinalTrackSelectors_TfGraphDefWrapper_h`.