-
Notifications
You must be signed in to change notification settings - Fork 4.6k
Track classifier using TensorFlow #31682
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
1756a7b
804ece6
28df534
80bac8f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| <use name="PhysicsTools/TensorFlow"/> | ||
| <use name="RecoTracker/Record"/> | ||
| <export> | ||
| <lib name="1"/> | ||
| </export> |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| #ifndef TrackTfGraph_TfGraphDefWrapper_h | ||
| #define TrackTfGraph_TfGraphDefWrapper_h | ||
|
|
||
| #include "PhysicsTools/TensorFlow/interface/TensorFlow.h" | ||
|
|
||
| class TfGraphDefWrapper { | ||
| public: | ||
| TfGraphDefWrapper(tensorflow::GraphDef*); | ||
| tensorflow::GraphDef* GetGraphDef() const; | ||
|
|
||
| private: | ||
| tensorflow::GraphDef* graphDef_; | ||
| }; | ||
|
|
||
|
|
||
| #endif |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| <use name="FWCore/Framework"/> | ||
| <use name="FWCore/ParameterSet"/> | ||
| <use name="FWCore/PluginManager"/> | ||
| <use name="PhysicsTools/TensorFlow"/> | ||
| <use name="FWCore/Utilities"/> | ||
| <use name="RecoTracker/Record"/> | ||
| <use name="DataFormats/TrackTfGraph"/> | ||
| <flags EDM_PLUGIN="1"/> | ||
| <export> | ||
| <lib name="1"/> | ||
| </export> | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| // -*- C++ -*- | ||
| // | ||
| // Package: test/TFGraphDefProducer | ||
| // Class: TFGraphDefProducer | ||
| // | ||
| /**\class TFGraphDefProducer | ||
| Description: Produces TfGraphRecord into the event containing a tensorflow GraphDef object that can be used for running inference on a pretrained network | ||
| */ | ||
| // | ||
| // Original Author: Joona Havukainen | ||
| // Created: Fri, 24 Jul 2020 08:04:00 GMT | ||
| // | ||
| // | ||
|
|
||
| // system include files | ||
| #include <memory> | ||
|
|
||
| // user include files | ||
| #include "FWCore/Framework/interface/ModuleFactory.h" | ||
| #include "FWCore/Framework/interface/ESProducer.h" | ||
|
|
||
| #include "FWCore/Framework/interface/ESHandle.h" | ||
| #include "TrackingTools/Records/interface/TfGraphRecord.h" | ||
| #include "DataFormats/TrackTfGraph/interface/TfGraphDefWrapper.h" | ||
|
|
||
| // class declaration | ||
|
|
||
| class TfGraphDefProducer : public edm::ESProducer { | ||
| public: | ||
| TfGraphDefProducer(const edm::ParameterSet&); | ||
|
|
||
| using ReturnType = std::unique_ptr<TfGraphDefWrapper>; | ||
|
|
||
| ReturnType produce(const TfGraphRecord&); | ||
|
|
||
| static void fillDescriptions(edm::ConfigurationDescriptions& descriptions); | ||
|
|
||
| private: | ||
| TfGraphDefWrapper wrapper_; | ||
|
|
||
| // ----------member data --------------------------- | ||
| }; | ||
|
|
||
| TfGraphDefProducer::TfGraphDefProducer(const edm::ParameterSet& iConfig): | ||
| wrapper_(TfGraphDefWrapper(tensorflow::loadGraphDef(iConfig.getParameter<edm::FileInPath>("FileName").fullPath()))) | ||
| { | ||
| auto componentName = iConfig.getParameter<std::string>("ComponentName"); | ||
| setWhatProduced(this, componentName); | ||
| } | ||
|
|
||
| // ------------ method called to produce the data ------------ | ||
| std::unique_ptr<TfGraphDefWrapper> TfGraphDefProducer::produce(const TfGraphRecord& iRecord) { | ||
| return std::unique_ptr<TfGraphDefWrapper>(&wrapper_); | ||
| } | ||
|
|
||
| void TfGraphDefProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) { | ||
| edm::ParameterSetDescription desc; | ||
| desc.add<std::string>("ComponentName", "tfGraphDef"); | ||
| desc.add<edm::FileInPath>("FileName", edm::FileInPath()); | ||
| descriptions.add("tfGraphDefProducer", desc); | ||
| } | ||
|
|
||
| //define this as a plug-in | ||
| #include "FWCore/PluginManager/interface/ModuleDef.h" | ||
| #include "FWCore/Framework/interface/MakerMacros.h" | ||
| DEFINE_FWK_EVENTSETUP_MODULE(TfGraphDefProducer); |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| #include "DataFormats/TrackTfGraph/interface/TfGraphDefWrapper.h" | ||
| #include "FWCore/Utilities/interface/typelookup.h" | ||
|
|
||
| TYPELOOKUP_DATA_REG(TfGraphDefWrapper); |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| #include "DataFormats/TrackTfGraph/interface/TfGraphDefWrapper.h" | ||
|
|
||
| TfGraphDefWrapper::TfGraphDefWrapper(tensorflow::GraphDef* graph) {graphDef_ = graph;} | ||
|
|
||
| tensorflow::GraphDef* TfGraphDefWrapper::GetGraphDef() const { | ||
| return graphDef_; | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,122 @@ | ||
| #include "RecoTracker/FinalTrackSelectors/interface/TrackMVAClassifier.h" | ||
|
|
||
| #include "FWCore/Framework/interface/EventSetup.h" | ||
| #include "FWCore/Framework/interface/ESHandle.h" | ||
|
|
||
| #include "DataFormats/TrackReco/interface/Track.h" | ||
| #include "DataFormats/VertexReco/interface/Vertex.h" | ||
|
|
||
| #include "getBestVertex.h" | ||
|
|
||
| #include "TrackingTools/Records/interface/TfGraphRecord.h" | ||
| #include "PhysicsTools/TensorFlow/interface/TensorFlow.h" | ||
| #include "DataFormats/TrackTfGraph/interface/TfGraphDefWrapper.h" | ||
|
|
||
| struct TfDnnCache { | ||
| TfDnnCache() : session_(nullptr) {} | ||
| tensorflow::Session* session_; | ||
| }; | ||
|
|
||
| namespace { | ||
| struct tfDnn { | ||
| tfDnn(const edm::ParameterSet& cfg): | ||
| tfDnnLabel_(cfg.getParameter<std::string>("tfDnnLabel")) | ||
|
|
||
| {} | ||
| TfDnnCache* cache_ = new TfDnnCache(); | ||
|
|
||
| ~tfDnn() { | ||
| if(cache_->session_) { | ||
| tensorflow::closeSession(cache_->session_); | ||
| delete cache_; | ||
| } | ||
| } | ||
|
|
||
| static const char *name() { return "TrackTfClassifier"; } | ||
|
|
||
| static void fillDescriptions(edm::ParameterSetDescription& desc) { | ||
| desc.add<std::string>("tfDnnLabel", "trackSelectionTf"); | ||
| } | ||
|
|
||
| void beginStream() {} | ||
|
|
||
| void initEvent(const edm::EventSetup& es) { | ||
| if (!cache_->session_) { | ||
| edm::ESHandle<TfGraphDefWrapper> tfDnnHandle; | ||
| es.get<TfGraphRecord>().get(tfDnnLabel_, tfDnnHandle); | ||
| tensorflow::GraphDef *graphDef_ = tfDnnHandle.product()->GetGraphDef(); | ||
| cache_->session_ = tensorflow::createSession(graphDef_, 1); //The integer controls how many threads are used for running inference | ||
| } | ||
| session_ = cache_->session_; | ||
| } | ||
|
|
||
| float operator()(reco::Track const & trk, | ||
| reco::BeamSpot const & beamSpot, | ||
| reco::VertexCollection const & vertices) const { | ||
|
|
||
| Point bestVertex = getBestVertex(trk, vertices); | ||
|
|
||
| tensorflow::Tensor input1(tensorflow::DT_FLOAT, {1, 29}); | ||
| tensorflow::Tensor input2(tensorflow::DT_FLOAT, {1, 1}); | ||
|
|
||
| input1.matrix<float>()(0, 0) = trk.pt(); | ||
| input1.matrix<float>()(0, 1) = trk.innerMomentum().x(); | ||
| input1.matrix<float>()(0, 2) = trk.innerMomentum().y(); | ||
| input1.matrix<float>()(0, 3) = trk.innerMomentum().z(); | ||
| input1.matrix<float>()(0, 4) = trk.innerMomentum().rho(); | ||
| input1.matrix<float>()(0, 5) = trk.outerMomentum().x(); | ||
| input1.matrix<float>()(0, 6) = trk.outerMomentum().y(); | ||
| input1.matrix<float>()(0, 7) = trk.outerMomentum().z(); | ||
| input1.matrix<float>()(0, 8) = trk.outerMomentum().rho(); | ||
| input1.matrix<float>()(0, 9) = trk.ptError(); | ||
| input1.matrix<float>()(0, 10) = trk.dxy(bestVertex); | ||
| input1.matrix<float>()(0, 11) = trk.dz(bestVertex); | ||
| input1.matrix<float>()(0, 12) = trk.dxy(beamSpot.position()); | ||
| input1.matrix<float>()(0, 13) = trk.dz(beamSpot.position()); | ||
| input1.matrix<float>()(0, 14) = trk.dxyError(); | ||
| input1.matrix<float>()(0, 15) = trk.dzError(); | ||
| input1.matrix<float>()(0, 16) = trk.normalizedChi2(); | ||
| input1.matrix<float>()(0, 17) = trk.eta(); | ||
| input1.matrix<float>()(0, 18) = trk.phi(); | ||
| input1.matrix<float>()(0, 19) = trk.etaError(); | ||
| input1.matrix<float>()(0, 20) = trk.phiError(); | ||
| input1.matrix<float>()(0, 21) = trk.hitPattern().numberOfValidPixelHits(); | ||
| input1.matrix<float>()(0, 22) = trk.hitPattern().numberOfValidStripHits(); | ||
| input1.matrix<float>()(0, 23) = trk.ndof(); | ||
| input1.matrix<float>()(0, 24) = trk.hitPattern().numberOfLostTrackerHits(reco::HitPattern::MISSING_INNER_HITS); | ||
| input1.matrix<float>()(0, 25) = trk.hitPattern().numberOfLostTrackerHits(reco::HitPattern::MISSING_OUTER_HITS); | ||
| input1.matrix<float>()(0, 26) = trk.hitPattern().trackerLayersTotallyOffOrBad(reco::HitPattern::MISSING_INNER_HITS); | ||
| input1.matrix<float>()(0, 27) = trk.hitPattern().trackerLayersTotallyOffOrBad(reco::HitPattern::MISSING_OUTER_HITS); | ||
| input1.matrix<float>()(0, 28) = trk.hitPattern().trackerLayersWithoutMeasurement(reco::HitPattern::TRACK_HITS); | ||
|
|
||
| //Original algo as its own input, it will enter the graph so that it gets one-hot encoded, as is the preferred | ||
| //format for categorical inputs, where the labels do not have any metric amongst them | ||
| input2.matrix<float>()(0, 0) = trk.originalAlgo(); | ||
|
|
||
| //The names for the input tensors get locked when freezing the trained tensorflow model. The NamedTensors must | ||
| //match those names | ||
| tensorflow::NamedTensorList inputs; | ||
| inputs.resize(2); | ||
| inputs[0] = tensorflow::NamedTensor("x", input1); | ||
| inputs[1] = tensorflow::NamedTensor("y", input2); | ||
| std::vector<tensorflow::Tensor> outputs; | ||
|
|
||
| //evaluate the input | ||
| tensorflow::run(session_, inputs, { "Identity" }, &outputs); | ||
|
|
||
| //scale output to be [-1, 1] due to convention | ||
| float output = 2.0*outputs[0].matrix<float>()(0, 0)-1.0; | ||
| return output; | ||
| } | ||
|
|
||
| std::string tfDnnLabel_; | ||
| tensorflow::Session *session_; | ||
| }; | ||
|
|
||
| using TrackTfClassifier = TrackMVAClassifier<tfDnn>; | ||
| } | ||
| #include "FWCore/PluginManager/interface/ModuleDef.h" | ||
| #include "FWCore/Framework/interface/MakerMacros.h" | ||
|
|
||
| DEFINE_FWK_MODULE(TrackTfClassifier); | ||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| from DataFormats.TrackTfGraph.tfGraphDefProducer_cfi import tfGraphDefProducer as _tfGraphDefProducer | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This needs to be changed to reflect the new location of the TfGraphDefProducer. |
||
| trackSelectionTf = _tfGraphDefProducer.clone( | ||
| ComponentName = "trackSelectionTf", | ||
| FileName = "RecoTracker/FinalTrackSelectors/data/frozen_graph.pb" | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| #include "PhysicsTools/TensorFlow/interface/TensorFlow.h" | ||
| #include "FWCore/Utilities/interface/typelookup.h" | ||
| TYPELOOKUP_DATA_REG(tensorflow::Session); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should indeed be unnecessary. |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
|
|
||
| #for dnn classifier | ||
| from Configuration.ProcessModifiers.trackdnn_cff import trackdnn | ||
| from dnnQualityCuts import qualityCutDictionary | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we have a convention of using full import paths, i.e. from RecoTracker.IterativeTracking.dnnQualityCuts import qualityCutDictionaryThis comment applies everywhere where this is imported.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. as it is, this import also won't work with python3. |
||
|
|
||
| ############################################### | ||
| # Low pT and detached tracks from pixel quadruplets | ||
|
|
@@ -212,11 +213,11 @@ | |
| qualityCuts = [-0.5,0.0,0.5] | ||
| ) | ||
|
|
||
| from RecoTracker.FinalTrackSelectors.TrackLwtnnClassifier_cfi import * | ||
| from RecoTracker.FinalTrackSelectors.trackSelectionLwtnn_cfi import * | ||
| trackdnn.toReplaceWith(detachedQuadStep, TrackLwtnnClassifier.clone( | ||
| src = 'detachedQuadStepTracks', | ||
| qualityCuts = [-0.6, 0.05, 0.7] | ||
| from RecoTracker.FinalTrackSelectors.TrackTfClassifier_cfi import * | ||
| from RecoTracker.FinalTrackSelectors.trackSelectionTf_cfi import * | ||
| trackdnn.toReplaceWith(detachedQuadStep, TrackTfClassifier.clone( | ||
| src = 'detachedQuadStepTracks', | ||
| qualityCuts = qualityCutDictionary['DetachedQuadStep'] | ||
| )) | ||
|
|
||
| highBetaStar_2018.toModify(detachedQuadStep,qualityCuts = [-0.7,0.0,0.5]) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hi @hajohajo - these dependencies on RecoTracker and PhysicsTools are not allowed in DataFormats.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you should move TrackTfGraphDefProducer, TfGraphDefWrapper to RecoTracker/FinalTrackSelectors.