diff --git a/analysis_configurations/template_analysis/generate.py b/analysis_configurations/template_analysis/generate.py
index ba1cda60..62a72389 100644
--- a/analysis_configurations/template_analysis/generate.py
+++ b/analysis_configurations/template_analysis/generate.py
@@ -42,6 +42,7 @@ def run(args):
available_samples,
available_eras,
available_scopes,
+ DAG_dir=f"{args.output}/visualization",
)
# create a CodeGenerator object
generator = CodeGenerator(
diff --git a/analysis_configurations/template_analysis/generate_friends.py b/analysis_configurations/template_analysis/generate_friends.py
index 805e60a5..b5ec0200 100644
--- a/analysis_configurations/template_analysis/generate_friends.py
+++ b/analysis_configurations/template_analysis/generate_friends.py
@@ -58,6 +58,7 @@ def run(args):
available_eras,
available_scopes,
args.quantities_map,
+ DAG_dir=f"{args.output}/visualization",
)
# check if the config is of type FriendTreeConfiguration
if not isinstance(code_generation_config, FriendTreeConfiguration):
diff --git a/analysis_configurations/template_analysis/producers/event.py b/analysis_configurations/template_analysis/producers/event.py
index 770f5b8b..6566fc79 100644
--- a/analysis_configurations/template_analysis/producers/event.py
+++ b/analysis_configurations/template_analysis/producers/event.py
@@ -79,6 +79,7 @@
output=None,
scopes=["global"],
vec_configs=["met_filters"],
+ is_filter=True,
)
Lumi = Producer(
diff --git a/analysis_configurations/template_analysis/producers/scalefactors.py b/analysis_configurations/template_analysis/producers/scalefactors.py
index e6c55de0..8aa0a703 100644
--- a/analysis_configurations/template_analysis/producers/scalefactors.py
+++ b/analysis_configurations/template_analysis/producers/scalefactors.py
@@ -9,8 +9,8 @@
Muon_1_Reco_SF = Producer(
name="Muon_1_Reco_SF",
call="""physicsobject::muon::scalefactor::Reco(
- {df},
- correctionManager,
+ {df},
+ correctionManager,
{output},
{input},
"{muon_sf_file}",
@@ -24,8 +24,8 @@
Muon_1_ID_SF = Producer(
name="Muon_1_ID_SF",
call="""physicsobject::muon::scalefactor::Id(
- {df},
- correctionManager,
+ {df},
+ correctionManager,
{output},
{input},
"{muon_sf_file}",
@@ -39,8 +39,8 @@
Muon_1_Iso_SF = Producer(
name="Muon_1_Iso_SF",
call="""physicsobject::muon::scalefactor::Iso(
- {df},
- correctionManager,
+ {df},
+ correctionManager,
{output},
{input},
"{muon_sf_file}",
@@ -55,8 +55,8 @@
Muon_2_Reco_SF = Producer(
name="Muon_2_Reco_SF",
call="""physicsobject::muon::scalefactor::Reco(
- {df},
- correctionManager,
+ {df},
+ correctionManager,
{output},
{input},
"{muon_sf_file}",
@@ -70,8 +70,8 @@
Muon_2_ID_SF = Producer(
name="Muon_2_ID_SF",
call="""physicsobject::muon::scalefactor::Id(
- {df},
- correctionManager,
+ {df},
+ correctionManager,
{output},
{input},
"{muon_sf_file}",
@@ -85,8 +85,8 @@
Muon_2_Iso_SF = Producer(
name="Muon_2_Iso_SF",
call="""physicsobject::muon::scalefactor::Iso(
- {df},
- correctionManager,
+ {df},
+ correctionManager,
{output},
{input},
"{muon_sf_file}",
diff --git a/analysis_configurations/template_analysis/template_config.py b/analysis_configurations/template_analysis/template_config.py
index 21495a68..733dcfe9 100644
--- a/analysis_configurations/template_analysis/template_config.py
+++ b/analysis_configurations/template_analysis/template_config.py
@@ -14,6 +14,7 @@
from code_generation.modifiers import EraModifier
from code_generation.rules import RemoveProducer, AppendProducer
from code_generation.systematics import SystematicShift
+from code_generation.utility.generate_DAG import create_graph
def build_config(
@@ -24,6 +25,7 @@ def build_config(
available_sample_types: List[str],
available_eras: List[str],
available_scopes: List[str],
+ DAG_dir: str = "",
):
configuration = Configuration(
era,
@@ -260,4 +262,8 @@ def build_config(
configuration.optimize()
configuration.validate()
configuration.report()
- return configuration.expanded_configuration()
+ if DAG_dir:
+ nanoAOD_inputs = [n for n in dir(nanoAOD) if not n.startswith("__")]
+ create_graph(configuration, nanoAOD_inputs, DAG_dir, "CROWNelements")
+ configuration = configuration.expanded_configuration()
+ return configuration
diff --git a/analysis_configurations/template_analysis/template_friend_config.py b/analysis_configurations/template_analysis/template_friend_config.py
index 6de2c931..22fd732d 100644
--- a/analysis_configurations/template_analysis/template_friend_config.py
+++ b/analysis_configurations/template_analysis/template_friend_config.py
@@ -4,9 +4,11 @@
from .producers import pairselection as pairselection
from .producers import genparticles as genparticles
+from .quantities import nanoAOD as nanoAOD
from .quantities import output as q
from code_generation.friend_trees import FriendTreeConfiguration
from code_generation.rules import RemoveProducer
+from code_generation.utility.generate_DAG import create_graph
def build_config(
@@ -18,6 +20,7 @@ def build_config(
available_eras: List[str],
available_scopes: List[str],
quantities_map: Union[str, None] = None,
+ DAG_dir: str = "",
):
configuration = FriendTreeConfiguration(
era,
@@ -67,4 +70,8 @@ def build_config(
configuration.optimize()
configuration.validate()
configuration.report()
- return configuration.expanded_configuration()
+ if DAG_dir:
+ nanoAOD_inputs = [n for n in dir(nanoAOD) if not n.startswith("__")]
+ create_graph(configuration, nanoAOD_inputs, DAG_dir, "CROWNelements")
+ configuration = configuration.expanded_configuration()
+ return configuration
diff --git a/analysis_configurations/template_analysis/template_multifriend_config.py b/analysis_configurations/template_analysis/template_multifriend_config.py
index 23eb54e0..7329ae41 100644
--- a/analysis_configurations/template_analysis/template_multifriend_config.py
+++ b/analysis_configurations/template_analysis/template_multifriend_config.py
@@ -4,9 +4,11 @@
from .producers import pairselection as pairselection
from .producers import genparticles as genparticles
+from .quantities import nanoAOD as nanoAOD
from .quantities import output as q
from code_generation.friend_trees import FriendTreeConfiguration
from code_generation.rules import RemoveProducer
+from code_generation.utility.generate_DAG import create_graph
def build_config(
@@ -18,6 +20,7 @@ def build_config(
available_eras: List[str],
available_scopes: List[str],
quantities_map: Union[str, None] = None,
+ DAG_dir: str = "",
):
configuration = FriendTreeConfiguration(
era,
@@ -62,4 +65,8 @@ def build_config(
configuration.optimize()
configuration.validate()
configuration.report()
- return configuration.expanded_configuration()
+ if DAG_dir:
+ nanoAOD_inputs = [n for n in dir(nanoAOD) if not n.startswith("__")]
+ create_graph(configuration, nanoAOD_inputs, DAG_dir, "CROWNelements")
+ configuration = configuration.expanded_configuration()
+ return configuration
diff --git a/code_generation/analysis_template.cxx b/code_generation/analysis_template.cxx
index 990f2939..31b05dd6 100644
--- a/code_generation/analysis_template.cxx
+++ b/code_generation/analysis_template.cxx
@@ -156,73 +156,80 @@ int main(int argc, char *argv[]) {
const std::string analysis = {ANALYSISTAG};
const std::string config = {CONFIGTAG};
const std::string era = {ERATAG};
- const std::string sample = {SAMPLETAG};
+ const std::string sample_type = {SAMPLETAG};
const std::string commit_hash = {COMMITHASH};
bool setup_clean = {CROWN_IS_CLEAN};
const std::string analysis_commit_hash = {ANALYSIS_COMMITHASH};
bool analysis_setup_clean = {ANALYSIS_IS_CLEAN};
int scope_counter = 0;
for (auto const &output : output_quantities) {
- // output.first is the output file name
- // output.second is the list of quantities
+
TFile outputfile(output.first.c_str(), "UPDATE");
- TTree quantities_meta = TTree("quantities", "quantities");
- for (auto const &quantity : output.second) {
- quantities_meta.Branch(quantity.c_str(), &setup_clean);
- }
- quantities_meta.Write();
- TTree variations_meta = TTree("variations", "variations");
- for (auto const &variation : variations.at(output.first)) {
- variations_meta.Branch(variation.c_str(), &setup_clean);
- }
- variations_meta.Write();
- TTree conditions_meta = TTree("conditions", "conditions");
- conditions_meta.Branch(analysis.c_str(), &setup_clean);
- conditions_meta.Branch(config.c_str(), &setup_clean);
- conditions_meta.Branch(era.c_str(), &setup_clean);
- conditions_meta.Branch(sample.c_str(), &setup_clean);
- conditions_meta.Write();
- TTree commit_meta = TTree("commit", "commit");
- commit_meta.Branch(commit_hash.c_str(), &setup_clean);
- commit_meta.Fill();
- commit_meta.Write();
- TTree analysis_commit_meta =
- TTree("analysis_commit", "analysis_commit");
- analysis_commit_meta.Branch(analysis_commit_hash.c_str(),
- &analysis_setup_clean);
- analysis_commit_meta.Fill();
- analysis_commit_meta.Write();
+
+ // -----------------------------
+ // Unified metadata object
+ // -----------------------------
+ nlohmann::json j;
+
+ j["metadata"] = {
+ {"analysis", analysis},
+ {"config", config},
+ {"era", era},
+ {"sample_type", sample_type},
+ {"commit", commit_hash},
+ {"analysis_commit", analysis_commit_hash},
+ {"is_clean", setup_clean},
+ {"analysis_is_clean", analysis_setup_clean}
+ };
+
+ j["content"] = {
+ {"quantities", output.second},
+ {"variations", variations.at(output.first)}
+ };
+
+ TObjString json(j.dump().c_str());
+ outputfile.WriteObject(&json, "metadata");
+
+ // -----------------------------
+ // Cutflow histogram
+ // -----------------------------
if (nevents != 0) {
+
TH1D cutflow;
cutflow.SetName("cutflow");
cutflow.SetTitle("cutflow");
- // iterate through the cutflow vector and fill the histogram with
- // the .GetPass() values
- if (scope_counter >= cutReports.size()) {
+
            if (cutReports.size() <= scope_counter || cutReports.empty()) {
Logger::get("main")->critical(
- "Cutflow vector is too small, this should not happen");
+ "cutReports vector is too small, this should not happen");
return 1;
}
+
for (auto cut = cutReports[scope_counter].begin();
- cut != cutReports[scope_counter].end(); cut++) {
- cutflow.SetBinContent(
- std::distance(cutReports[scope_counter].begin(), cut) + 1,
- cut->GetPass());
- cutflow.GetXaxis()->SetBinLabel(
- std::distance(cutReports[scope_counter].begin(), cut) + 1,
- cut->GetName().c_str());
+ cut != cutReports[scope_counter].end(); cut++) {
+
+ int bin = std::distance(cutReports[scope_counter].begin(), cut) + 1;
+
+ cutflow.SetBinContent(bin, cut->GetPass());
+ cutflow.GetXaxis()->SetBinLabel(bin, cut->GetName().c_str());
}
- // store it in the output file
+
cutflow.Write();
}
- outputfile.Close();
- TFile *fout = TFile::Open(output.first.c_str(), "UPDATE");
+
+ // -----------------------------
+ // Additional shift maps
+ // -----------------------------
Logger::get("main")->info("Writing quantities map to {}", output.first);
- fout->WriteObject(&shift_quantities_map.at(output.first),
- "shift_quantities_map");
- fout->WriteObject(&quantities_shift_map.at(output.first),
- "quantities_shift_map");
- fout->Close();
+
+ outputfile.WriteObject(&shift_quantities_map.at(output.first),
+ "shift_quantities_map");
+
+ outputfile.WriteObject(&quantities_shift_map.at(output.first),
+ "quantities_shift_map");
+
+ outputfile.Close();
+
scope_counter++;
}
diff --git a/code_generation/analysis_template_friends.cxx b/code_generation/analysis_template_friends.cxx
index 1b1615d8..ace2d942 100644
--- a/code_generation/analysis_template_friends.cxx
+++ b/code_generation/analysis_template_friends.cxx
@@ -211,73 +211,80 @@ int main(int argc, char *argv[]) {
const std::string analysis = {ANALYSISTAG};
const std::string config = {CONFIGTAG};
const std::string era = {ERATAG};
- const std::string sample = {SAMPLETAG};
+ const std::string sample_type = {SAMPLETAG};
const std::string commit_hash = {COMMITHASH};
bool setup_clean = {CROWN_IS_CLEAN};
const std::string analysis_commit_hash = {ANALYSIS_COMMITHASH};
bool analysis_setup_clean = {ANALYSIS_IS_CLEAN};
int scope_counter = 0;
for (auto const &output : output_quantities) {
- // output.first is the output file name
- // output.second is the list of quantities
+
TFile outputfile(output.first.c_str(), "UPDATE");
- TTree quantities_meta = TTree("quantities", "quantities");
- for (auto const &quantity : output.second) {
- quantities_meta.Branch(quantity.c_str(), &setup_clean);
- }
- quantities_meta.Write();
- TTree variations_meta = TTree("variations", "variations");
- for (auto const &variation : variations.at(output.first)) {
- variations_meta.Branch(variation.c_str(), &setup_clean);
- }
- variations_meta.Write();
- TTree conditions_meta = TTree("conditions", "conditions");
- conditions_meta.Branch(analysis.c_str(), &setup_clean);
- conditions_meta.Branch(config.c_str(), &setup_clean);
- conditions_meta.Branch(era.c_str(), &setup_clean);
- conditions_meta.Branch(sample.c_str(), &setup_clean);
- conditions_meta.Write();
- TTree commit_meta = TTree("commit", "commit");
- commit_meta.Branch(commit_hash.c_str(), &setup_clean);
- commit_meta.Fill();
- commit_meta.Write();
- TTree analysis_commit_meta =
- TTree("analysis_commit", "analysis_commit");
- analysis_commit_meta.Branch(analysis_commit_hash.c_str(),
- &analysis_setup_clean);
- analysis_commit_meta.Fill();
- analysis_commit_meta.Write();
+
+ // -----------------------------
+ // Unified metadata container
+ // -----------------------------
+ nlohmann::json j;
+
+ j["metadata"] = {
+ {"analysis", analysis},
+ {"config", config},
+ {"era", era},
+ {"sample_type", sample_type},
+ {"commit", commit_hash},
+ {"analysis_commit", analysis_commit_hash},
+ {"is_clean", setup_clean},
+ {"analysis_is_clean", analysis_setup_clean}
+ };
+
+ j["content"] = {
+ {"quantities", output.second},
+ {"variations", variations.at(output.first)}
+ };
+
+ TObjString json(j.dump().c_str());
+ outputfile.WriteObject(&json, "metadata");
+
+ // -----------------------------
+ // Cutflow
+ // -----------------------------
if (nevents != 0) {
+
TH1D cutflow;
cutflow.SetName("cutflow");
cutflow.SetTitle("cutflow");
- // iterate through the cutflow vector and fill the histogram with
- // the .GetPass() values
+
if (cutReports.size() < scope_counter || cutReports.empty()) {
Logger::get("main")->critical(
"cutReports vector is too small, this should not happen");
return 1;
}
+
for (auto cut = cutReports[scope_counter].begin();
- cut != cutReports[scope_counter].end(); cut++) {
- cutflow.SetBinContent(
- std::distance(cutReports[scope_counter].begin(), cut) + 1,
- cut->GetPass());
- cutflow.GetXaxis()->SetBinLabel(
- std::distance(cutReports[scope_counter].begin(), cut) + 1,
- cut->GetName().c_str());
+ cut != cutReports[scope_counter].end(); cut++) {
+
+ int bin = std::distance(cutReports[scope_counter].begin(), cut) + 1;
+
+ cutflow.SetBinContent(bin, cut->GetPass());
+ cutflow.GetXaxis()->SetBinLabel(bin, cut->GetName().c_str());
}
- // store it in the output file
+
cutflow.Write();
}
- outputfile.Close();
- TFile *fout = TFile::Open(output.first.c_str(), "UPDATE");
+
+ // -----------------------------
+ // Shift maps
+ // -----------------------------
Logger::get("main")->info("Writing quantities map to {}", output.first);
- fout->WriteObject(&shift_quantities_map.at(output.first),
- "shift_quantities_map");
- fout->WriteObject(&quantities_shift_map.at(output.first),
- "quantities_shift_map");
- fout->Close();
+
+ outputfile.WriteObject(&shift_quantities_map.at(output.first),
+ "shift_quantities_map");
+
+ outputfile.WriteObject(&quantities_shift_map.at(output.first),
+ "quantities_shift_map");
+
+ outputfile.Close();
+
scope_counter++;
}
diff --git a/code_generation/code_generation.py b/code_generation/code_generation.py
index 6593831e..9fd9bfa8 100644
--- a/code_generation/code_generation.py
+++ b/code_generation/code_generation.py
@@ -390,12 +390,10 @@ def write_code(self, calls: str, includes: str, run_commands: str) -> None:
.replace(" // {RUN_COMMANDS}", run_commands)
.replace("// {MULTITHREADING}", threadcall)
.replace("// {DEBUGLEVEL}", self.set_debug_flag())
- .replace("{ERATAG}", '"Era={}"'.format(self.configuration.era))
- .replace(
- "{SAMPLETAG}", '"Samplegroup={}"'.format(self.configuration.sample)
- )
- .replace("{ANALYSISTAG}", '"Analysis={}"'.format(self.analysis_name))
- .replace("{CONFIGTAG}", '"Config={}"'.format(self.config_name))
+ .replace("{ERATAG}", '"{}"'.format(self.configuration.era))
+ .replace("{SAMPLETAG}", '"{}"'.format(self.configuration.sample))
+ .replace("{ANALYSISTAG}", '"{}"'.format(self.analysis_name))
+ .replace("{CONFIGTAG}", '"{}"'.format(self.config_name))
.replace("{OUTPUT_QUANTITIES}", self.set_output_quantities())
.replace("{SHIFT_QUANTITIES_MAP}", self.set_shift_quantities_map())
.replace("{QUANTITIES_SHIFT_MAP}", self.set_quantities_shift_map())
diff --git a/code_generation/friend_trees.py b/code_generation/friend_trees.py
index 84636249..69dcb9af 100644
--- a/code_generation/friend_trees.py
+++ b/code_generation/friend_trees.py
@@ -6,7 +6,7 @@
import os
from time import time
from code_generation.configuration import Configuration
-from typing import List, Union, Dict, Set
+from typing import List, Union, Dict, Set, Tuple, Any
from code_generation.exceptions import (
ConfigurationError,
@@ -141,21 +141,22 @@ def _determine_requested_shifts(self, shiftset: Set[str]) -> Dict[str, List[str]
def _readout_input_information(
self,
input_information_list: Union[List[str], List[Dict[str, List[str]]]],
+ metadata: Dict[str, Any] = {},
) -> Dict[str, Dict[str, List[str]]]:
def update_input_information(existing_data, new_data):
- if existing_data == {}:
- return new_data
- else:
- # otherwise we have to merge the contents, while not overwriting existing data
- for scope in new_data.keys():
- if scope not in existing_data.keys():
- existing_data[scope] = {}
- for shift in new_data[scope].keys():
- if shift not in existing_data[scope].keys():
- existing_data[scope][shift] = []
- for quantity in new_data[scope][shift]:
- if quantity not in existing_data[scope][shift]:
- existing_data[scope][shift].append(quantity)
+ # Merge contents, while not overwriting existing data
+ for scope in new_data.keys():
+ if scope not in existing_data.keys():
+ existing_data[scope] = {}
+ for shift in new_data[scope].keys():
+ if shift not in existing_data[scope].keys():
+ existing_data[scope][shift] = []
+ for quantity in new_data[scope][shift]:
+ if quantity not in existing_data[scope][shift]:
+                            # Add origin config to quantity for navigation with multifriends
+ existing_data[scope][shift].append(
+ (quantity, metadata["config"])
+ )
return existing_data
# first check if the input is a root file or a json file
@@ -164,13 +165,15 @@ def update_input_information(existing_data, new_data):
log.info(f"adding input information from {input_information}")
if isinstance(input_information, str):
if input_information.endswith(".root"):
- data = update_input_information(
- data, self._readout_input_root_file(input_information)
+ shift_map, metadata = self._readout_input_root_file(
+ input_information
)
+ data = update_input_information(data, shift_map)
elif input_information.endswith(".json"):
- data = update_input_information(
- data, self._readout_input_json_file(input_information)
+ shift_map, metadata = self._readout_input_json_file(
+ input_information
)
+ data = update_input_information(data, shift_map)
else:
error_message = f"\n Input information file {input_information} is not a json or root file \n"
error_message += (
@@ -183,7 +186,7 @@ def update_input_information(existing_data, new_data):
def _readout_input_root_file(
self, input_file: str
- ) -> Dict[str, Dict[str, List[str]]]:
+ ) -> Tuple[Dict[str, Dict[str, List[str]]], Dict[str, Any]]:
"""Read the shift_quantities_map from the input root file and return it as a dictionary
Args:
@@ -207,7 +210,7 @@ def _readout_input_root_file(
if not os.path.exists(lib_path):
log.error(f"Missing library: {lib_path}")
# Evaluate ROOT-specific return codes
- result = ROOT.gSystem.Load(lib_path)
+ result = ROOT.gSystem.Load(lib_path) # type: ignore
if result < 0:
err_type = (
"Version mismatch"
@@ -217,19 +220,19 @@ def _readout_input_root_file(
log.error(f"Load failed ({result}): {err_type} for {lib_path}")
f = ROOT.TFile.Open(input_file) # type: ignore
- name = "shift_quantities_map"
- m = f.Get(name)
+ m = f.Get("shift_quantities_map")
for shift, quantities in m:
data[str(shift)] = [str(quantity) for quantity in quantities]
+ metadata = json.loads(f.Get("metadata").GetString().Data())
f.Close()
log.debug(
f"Reading quantities information took {round(time() - start,2)} seconds"
)
- return {list(self.selected_scopes)[0]: data}
+ return {list(self.selected_scopes)[0]: data}, metadata["metadata"]
def _readout_input_json_file(
self, input_file: str
- ) -> Dict[str, Dict[str, List[str]]]:
+ ) -> Tuple[Dict[str, Dict[str, List[str]]], Dict[str, Any]]:
"""Read the shift_quantities_map from the input json file and return it as a dictionary
Args:
@@ -240,26 +243,33 @@ def _readout_input_json_file(
"""
with open(input_file) as f:
data = json.load(f)
+ quantity_data = data["quantities"]
+ metadata = data["metadata"]
# json file structure is: {era: {sampletype: {scope: {shift: [quantities]}}}
- if self.era not in data:
+ if self.era not in quantity_data or self.era != metadata["era"]:
errorstring = (
f"Era {self.era} not found in input information file {input_file}.\n"
)
- errorstring += f"Available eras are: {data.keys()}"
+ errorstring += f"Available eras are: {quantity_data.keys()}"
raise ConfigurationError(errorstring)
- if self.sample not in data[self.era].keys():
+ if (
+ self.sample not in quantity_data[self.era].keys()
+ or self.sample != metadata["sample_type"]
+ ):
errorstring = f"Sampletype {self.sample} not found in input information file {input_file}.\n"
- errorstring += f"Available sampletypes are: {data[self.era].keys()}"
+ errorstring += (
+ f"Available sampletypes are: {quantity_data[self.era].keys()}"
+ )
raise ConfigurationError(errorstring)
if not set(self.selected_scopes).issubset(
- set(data[self.era][self.sample].keys())
+ set(quantity_data[self.era][self.sample].keys())
):
errorstring = f"Scopes {self.selected_scopes} not found in input information file {input_file}.\n"
- errorstring += f"Available scopes are: {data[self.era][self.sample].keys()}"
+ errorstring += (
+ f"Available scopes are: {quantity_data[self.era][self.sample].keys()}"
+ )
raise ConfigurationError(errorstring)
- else:
- data = data[self.era][self.sample]
- return data
+ return quantity_data[self.era][self.sample], metadata
def optimize(self) -> None:
"""
@@ -375,7 +385,9 @@ def _validate_inputs(self) -> None:
[x.name for x in producer.get_outputs(scope)]
)
# get all available inputs
- for input_quantitiy in self.input_quantities_mapping[scope][""]:
+ for input_quantitiy, quantity_origin in self.input_quantities_mapping[
+ scope
+ ][""]:
available_inputs.add(input_quantitiy)
# now check if all inputs are available
missing_inputs = required_inputs - available_inputs
diff --git a/code_generation/helpers.py b/code_generation/helpers.py
index 75a5ce04..7e17987f 100644
--- a/code_generation/helpers.py
+++ b/code_generation/helpers.py
@@ -1,9 +1,10 @@
from __future__ import annotations # needed for type annotations in > python 3.7
+from typing import Any
# File with helper functions for the CROWN code generation
-def is_empty(value):
+def is_empty(value: Any) -> bool:
"""
Check if a value is empty.
@@ -13,12 +14,11 @@ def is_empty(value):
Returns:
bool: Whether the input value is considered 'empty'
"""
- # List of all values that should be considered empty despite not having a length.
empty_values = [None]
try:
length = len(value)
except TypeError:
length = -1
- bool_val = value in empty_values or length == 0
- return bool_val
+
+ return value in empty_values or length == 0
diff --git a/code_generation/producer.py b/code_generation/producer.py
index 69fa45f8..0332c576 100644
--- a/code_generation/producer.py
+++ b/code_generation/producer.py
@@ -28,6 +28,7 @@ def __init__(
input: Union[List[q.Quantity], Dict[str, List[q.Quantity]]],
output: Union[List[q.Quantity], None],
scopes: List[str],
+ is_filter: bool = False,
):
"""
A Producer is a class that holds all information about a producer. Input quantities are
@@ -38,6 +39,7 @@ def __init__(
input: A list of input quantities or a dict with scope specific input quantities
output: A list of output quantities
scopes: A list of scopes in which the producer is used
+ is_filter: True if the producer is a filter
"""
log.debug("Setting up a new producer {}".format(name))
@@ -57,6 +59,7 @@ def __init__(
self.call: str = call
self.output: Union[List[q.Quantity], None] = output
self.scopes = scopes
+ self.is_filter = is_filter
self.parameters: Dict[str, Set[str]] = self.extract_parameters()
# if input not given as dict and therfore not scope specific transform into dict with all scopes
if not isinstance(input, dict):
@@ -73,7 +76,10 @@ def __init__(
for output_quantity in self.output:
input_quantity.adopt(output_quantity, scope)
log.debug("-----------------------------------------")
- log.debug("| Producer: {}".format(self.name))
+ if self.is_filter:
+ log.debug("| Filter Producer: {}".format(self.name))
+ else:
+ log.debug("| Producer: {}".format(self.name))
log.debug("| Call: {}".format(self.call))
for scope in self.scopes:
if is_empty(self.input[scope]):
@@ -374,6 +380,7 @@ def __init__(
output: Union[List[q.Quantity], None],
scopes: List[str],
vec_configs: List[str],
+ is_filter: bool = False,
):
"""A Vector Producer is a Producer which can be configured to produce multiple calls and outputs at once, deprecated in favor of the ExtendedVectorProducer
@@ -384,9 +391,10 @@ def __init__(
output (Union[List[q.Quantity], None]): The outputs of the producer, either a list of Quantity objects, or None if the producer does not produce any output
scopes (List[str]): The scopes in which the producer is to be called
vec_configs (List[str]): A list of strings, which are the names of the parameters to be varied in the vectorized call
+ is_filter (bool, optional): Whether the vector producer is a filter. Defaults to False.
"""
self.name = name
- super().__init__(name, call, input, output, scopes)
+ super().__init__(name, call, input, output, scopes, is_filter)
self.vec_configs = vec_configs
def __str__(self) -> str:
@@ -463,6 +471,7 @@ def __init__(
output: str,
scope: Union[List[str], str],
vec_config: str,
+ is_filter: bool = False,
):
"""A ExtendedVectorProducer is a Producer which can be configured to produce multiple calls and outputs at once
@@ -473,6 +482,7 @@ def __init__(
output (Union[List[q.Quantity], None]): The outputs of the producer, either a list of Quantity objects, or None if the producer does not produce any output
scopes (List[str]): The scopes in which the producer is to be called
vec_configs (List[str]): The key of the vec config in the config dict
+            is_filter (bool, optional): Whether the vector producer is a filter. Defaults to False.
"""
# we create a Quantity Group, which is updated during the writecalls() step
self.outputname = output
@@ -482,7 +492,7 @@ def __init__(
quantity_group = q.QuantityGroup(name)
# set the vec config key of the quantity group
quantity_group.set_vec_config(vec_config)
- super().__init__(name, call, input, [quantity_group], scope)
+ super().__init__(name, call, input, [quantity_group], scope, is_filter)
if is_empty(self.output):
raise InvalidProducerConfigurationError(self.name)
# add the vec config to the parameters of the producer
@@ -566,7 +576,7 @@ def __init__(
either a list of Quantity objects, or a dict with the scope as key and a list of Quantity objects as value
scopes (List[str]): The scopes in which the filter is to be called
"""
- super().__init__(name, call, input, None, scopes)
+ super().__init__(name, call, input, None, scopes, True)
def __str__(self) -> str:
return "BaseFilter: {}".format(self.name)
diff --git a/code_generation/utility/CROWN_visualization.html b/code_generation/utility/CROWN_visualization.html
new file mode 100644
index 00000000..631d67f2
--- /dev/null
+++ b/code_generation/utility/CROWN_visualization.html
@@ -0,0 +1,995 @@
+
+
+
+
+ CROWN Graph Explorer
+
+
+
+
+
+
+
+
+
+
+
Controls
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Input Quantities
+
+
+
+
+
+
+
+ Output Quantities
+
+
+
+
+
+
+
+
+
+
Graph Legend
+
+
+
+ Reads from NanoAOD
+
+
+
+ Reads from Ntuple
+
+
+
+ Writes to file
+
+
+
+ Affected by Shift
+
+
+
+ Group
+
+
+
+ Vector Group
+
+
+
+ Filter
+
+
+
+ Unused Producer
+
+
+
+
+
+
+
+
+
+
Family Name
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/code_generation/utility/generate_DAG.py b/code_generation/utility/generate_DAG.py
new file mode 100644
index 00000000..830d9f9f
--- /dev/null
+++ b/code_generation/utility/generate_DAG.py
@@ -0,0 +1,878 @@
+import re
+import json
+import shutil
+import logging
+from pathlib import Path
+from collections import defaultdict
+from code_generation.quantity import QuantityGroup
+from code_generation.helpers import is_empty
+from typing import Never
+
+log = logging.getLogger(__name__)
+
+
+def log_and_fail(msg: str) -> Never:
+ log.exception(msg, stack_info=True)
+ raise RuntimeError(msg)
+
+
+def create_graph(configuration, NanoAOD_inputs, DAG_dir, json_name):
+ """Instantiate a GraphParser and execute DAG generation.
+
+ Args:
+ configuration (object): The configuration object containing scopes and producers.
+ NanoAOD_inputs (list): List of NanoAOD input quantities.
+ DAG_dir (str): Directory where the generated DAG JSON files will be saved.
+ json_name (str): The base name for the output JSON file.
+ """
+ for active_scope in configuration.scopes:
+ # Skip global scope as it is included in all other scopes
+ if active_scope == "global":
+ pass
+ else:
+ # Add Ntuple quantities for friends
+
+ if type(configuration).__name__ == "FriendTreeConfiguration":
+ external_inputs = {}
+ is_friend_config = True
+ elif type(configuration).__name__ == "Configuration":
+ external_inputs = {"NanoAOD": NanoAOD_inputs}
+ is_friend_config = False
+ else:
+ log_and_fail(
+ f"Unknown configuration type: {type(configuration).__name__}"
+ )
+
+ if hasattr(configuration, "input_quantities_mapping"):
+ if not is_friend_config:
+ log_and_fail(
+ "input_quantities_mapping is only supported for friend configurations."
+ )
+ external_inputs_tuples = configuration.input_quantities_mapping[
+ active_scope
+ ][""]
+ config_sources = defaultdict(list)
+ for quantity, config in external_inputs_tuples:
+ config_sources[config].append(quantity)
+ external_inputs.update(config_sources)
+ shifted_inputs = {
+ shift: [original for original, *_ in e_in]
+ for shift, e_in in configuration.input_quantities_mapping[
+ active_scope
+ ].items()
+ if shift != ""
+ }
+ else:
+ shifted_inputs = None
+ graph = GraphParser(
+ configuration,
+ external_inputs,
+ active_scope,
+ DAG_dir,
+ is_friend_config=is_friend_config,
+ shifted_inputs=shifted_inputs,
+ )
+ # Generate and save graph for each scope separately
+ graph.generate_graph()
+ graph.save_graph(json_name)
+
+
+class GraphParser:
+ """Parses configuration data to generate Directed Acyclic Graphs (DAGs).
+
+ Represents the dependencies and flow of producers, filters, and I/O.
+ """
+
+ def __init__(
+ self,
+ config,
+ external_inputs,
+ active_scope,
+ DAG_dir="",
+ is_friend_config=False,
+ shifted_inputs=None,
+ ):
+ """Initialize the GraphParser object.
+
+ Args:
+ config (object): The configuration object.
+ external_inputs (dict): Mapping of external input sources (e.g. "NanoAOD") to their quantities.
+ active_scope (str): The specific scope being parsed in addition to the global scope.
+ DAG_dir (str, optional): The directory for saving DAG files. Defaults to "".
+ is_friend_config (bool, optional): Indicates if it's a friend configuration. Defaults to False.
+ """
+ self.config = config
+ self.active_scope = active_scope
+ if active_scope == "global":
+ self.scopes = ["global"]
+ else:
+ self.scopes = ["global", self.active_scope]
+ self.external_inputs = external_inputs
+ self.is_friend_config = is_friend_config
+ self.shifted_inputs = shifted_inputs
+ self.connections = defaultdict(lambda: defaultdict(list))
+ self.edges = []
+ self.inputs = defaultdict(lambda: defaultdict(list))
+ self.outputs = defaultdict(lambda: defaultdict(list))
+ self.direct_in_to_out = defaultdict(list)
+ self.vec_output_mappings = {}
+ self.node_register = defaultdict(
+ lambda: {
+ "name": None,
+ "parent": None,
+ "type": None,
+ "file_in": defaultdict(list),
+ "file_out": [],
+ "node_call": None,
+ "node_call_configs": {},
+ }
+ )
+ self.shift_registry = defaultdict()
+ self.DAG_dir = DAG_dir
+
+ def generate_graph(self):
+ """Execute the main orchestration to generate nodes and edges.
+
+ Iterates through all scopes of the configuration, builds the DAG components,
+ verifies output ambiguity, bundles edges, and generates formatting metadata
+ for visualization tools like Cytoscape.
+ """
+ # Determine producers from configuration for global and active scope
+ for scope in self.scopes:
+ self.parse_scope_producers(scope)
+
+ # Derive connections from inputs/outputs
+ self.assemble_connections()
+
+ # Identify nodes without a specific type and without any outgoing connections
+ # This excludes filters, groups, vector groups, and scopes.
+ # While they are only defined afterwards this also excludes proxy nodes and edges (branch, twig, leaf).
+ for node_id, node_data in self.node_register.items():
+ if (
+ not node_data.get("type")
+ and is_empty(node_data.get("file_out"))
+ and node_id not in self.connections
+ ):
+ node_data["type"] = "stump"
+
+ self.compile_shift_registry()
+
+ def parse_scope_producers(self, scope):
+ """Parse producers and outputs for a specific scope.
+
+ Args:
+ scope (str): The scope name to process.
+
+ Raises:
+ RuntimeError: If an output has multiple origins (ambiguous assignments).
+ """
+ if not is_empty(self.config.producers[scope]):
+ log.debug(f"For scope {scope}:")
+ # Add top level scope node
+ self.add_node(
+ id_name="scope",
+ name=f"{scope} scope",
+ scope=scope,
+ node_type="scope",
+ )
+
+ # Parse all producers in this scope
+ for p in self.config.producers[scope]:
+ self.parse_Producer_routing(
+ producer=p, parent="scope", scope=scope, align=" "
+ )
+
+ # Check outputs for ambiguous assignments (multiple origins for the same output)
+ all_keys = list(
+ set(list(self.outputs["global"].keys()) + list(self.outputs[scope].keys()))
+ )
+ for key in all_keys:
+ total_output_nodes = self.outputs["global"].get(key, []) + self.outputs[
+ scope
+ ].get(key, [])
+ if len(set(total_output_nodes)) != 1:
+ log_and_fail(f"Output {key} has multiple origins: {total_output_nodes}")
+
+ # Determine nodes writing to Ntuple by tracing configured scope outputs back to their producers
+ if hasattr(self.config, "outputs") and scope in self.config.outputs:
+ for output in self.config.outputs[scope]:
+ if isinstance(output, QuantityGroup):
+ if output.vec_config in self.config.config_parameters.get(
+ scope, {}
+ ):
+ vec_config = self.config.config_parameters[scope][
+ output.vec_config
+ ]
+ vec_output_name = self.vec_output_mappings.get(output.name)
+ if vec_output_name:
+ for o in vec_config:
+ req_out = o.get(vec_output_name)
+ if req_out:
+ self.set_is_out(scope, req_out)
+ else:
+ self.set_is_out(scope, output.name)
+
+ def parse_Producer_routing(self, producer, parent, scope, align=""):
+ """Route parsing logic based on the specific producer class type.
+
+ Args:
+ producer (object): The producer instance to be parsed.
+ parent (str): The ID of the parent node.
+ scope (str): The active scope.
+ align (str, optional): Spacing string used for debug print alignment. Defaults to "".
+
+ Raises:
+ RuntimeError: If the producer's class is unknown.
+ """
+ class_name = producer.__class__.__name__
+
+ if class_name == "VectorProducer":
+ self.parse_VectorProducer(
+ producer=producer, parent=parent, scope=scope, align=align
+ )
+ elif class_name == "ExtendedVectorProducer":
+ self.parse_ExtendedVectorProducer(
+ producer=producer, parent=parent, scope=scope, align=align
+ )
+ elif class_name == "ProducerGroup":
+ self.parse_ProducerGroup(
+ producer=producer, parent=parent, scope=scope, align=align
+ )
+ elif class_name == "Filter":
+ self.parse_Filter(
+ producer=producer, parent=parent, scope=scope, align=align
+ )
+ elif class_name == "BaseFilter":
+ self.parse_BaseFilter(
+ producer=producer, parent=parent, scope=scope, align=align
+ )
+ elif class_name == "Producer":
+ self.parse_Producer(
+ producer=producer, parent=parent, scope=scope, align=align
+ )
+ else:
+ log_and_fail(f"Unknown Producer class {class_name}")
+
+ def parse_VectorProducer(self, producer, parent, scope, align=""):
+ """Map legacy vector configurations to graph nodes and inputs using regex.
+
+ Note:
+ Currently only supports 'event::filter::Flag'. This is a legacy function
+ and producers should be migrated to ExtendedVectorProducer.
+
+ Args:
+ producer (object): The legacy vector producer instance.
+ parent (str): The ID of the parent node.
+ scope (str): The active scope.
+ align (str, optional): Debug print alignment spacing. Defaults to "".
+
+ Raises:
+ RuntimeError: If the producer call does not have legacy support.
+ """
+ log.debug(f"{align}Adding VectorProducer: {producer.name}")
+ log.warning(
+ f"!!! {producer.name} is a legacy producer and should be replaced with ExtendedVectorProducer !!!"
+ )
+
+ # Only accept whitelisted calls for legacy support
+ if producer.call.startswith("event::filter::Flag"):
+ self.parse_Flag_from_call(
+ producer=producer, parent=parent, scope=scope, align=align
+ )
+ else:
+ log_and_fail(f"The call {producer.call} does not have legacy support.")
+
+ def parse_Flag_from_call(self, producer, parent, scope, align=""):
+ """Extract vector configurations from a legacy producer's call string using regex.
+
+ Args:
+ producer (object): The legacy flag producer instance.
+ parent (str): The ID of the parent node.
+ scope (str): The active scope.
+ align (str, optional): Debug print alignment spacing. Defaults to "".
+
+ Raises:
+ RuntimeError: If input vector configs cannot be parsed or matched,
+ or if list outputs are used (unsupported in legacy).
+ """
+ group_id = f"{producer.name}_v"
+ call = producer.call
+ # Pattern consists of: event::filter::Flag(df, "Flag_name", "Flag_name")
+ pattern = r'event::filter::Flag\({df}, "{(.*)}", "{(.*)}"\)'
+ match = re.search(pattern, call)
+
+ if match:
+ input_vec_config = match.group(1)
+ else:
+ log_and_fail(f"Input vector config could not be parsed from {call}")
+
+ if input_vec_config not in producer.vec_configs:
+ log_and_fail(
+ f"Input name from {call} not in producer vector configs {producer.vec_configs}"
+ )
+ # Determine the index of the input vector config from the call
+ vec_input_index = producer.vec_configs.index(input_vec_config)
+ call = producer.call
+ self.add_node(
+ id_name=group_id,
+ name=producer.name,
+ scope=scope,
+ parent=parent,
+ node_type="vector",
+ )
+ vec_configs = [
+ self.config.config_parameters[scope][c] for c in producer.vec_configs
+ ]
+ # Loop over all vector configurations
+ for i_c, c in enumerate(zip(*vec_configs)):
+ input_name = c[vec_input_index]
+ vector_id = f"{group_id}_{i_c}"
+ vector_name = f"{producer.name}_{i_c}"
+ vector_config_dict = {vc: cc for vc, cc in zip(producer.vec_configs, c)}
+ config_data = self.extract_configs(call, scope, vector_config_dict)
+ log.debug(f"{align} Adding Filter: {vector_name}")
+ self.add_node(
+ id_name=vector_id,
+ name=vector_name,
+ scope=scope,
+ parent=group_id,
+ is_filter=producer.is_filter,
+ node_call=call,
+ node_call_configs=config_data,
+ )
+ self.add_input(input_name, vector_id, scope)
+
+ # outputs are not supported for this legacy producer
+ if isinstance(producer.output, list):
+ log_and_fail("List outputs for legacy parsed calls are not supported.")
+
+ def parse_ExtendedVectorProducer(self, producer, parent, scope, align=""):
+ """Parse an ExtendedVectorProducer class instance into graph nodes, inputs, and outputs.
+
+ Args:
+ producer (object): The ExtendedVectorProducer instance.
+ parent (str): The ID of the parent node.
+ scope (str): The active scope.
+ align (str, optional): Debug print alignment spacing. Defaults to "".
+ """
+ log.debug(f"{align}Adding ExtendedVectorProducer: {producer.name}")
+ group_id = f"{producer.name}_v"
+ self.add_node(
+ id_name=group_id,
+ name=producer.name,
+ scope=scope,
+ parent=parent,
+ node_type="vector",
+ )
+
+ # ExtendedVectorProducer is better defined and doesn't require regex
+ call = producer.call
+ vec_config = self.config.config_parameters[scope][producer.vec_config]
+ # Loop over all vector configurations
+ for i_c, c in enumerate(vec_config):
+ vector_id = f"{group_id}_{i_c}"
+ vector_name = f"{producer.name}_{i_c}"
+ config_data = self.extract_configs(call, scope, c)
+ log.debug(f"{align} Adding Producer: {vector_name}")
+ self.add_node(
+ id_name=vector_id,
+ name=vector_name,
+ scope=scope,
+ parent=group_id,
+ is_filter=producer.is_filter,
+ node_call=call,
+ node_call_configs=config_data,
+ )
+
+ if isinstance(producer.input[scope], list):
+ for n in set(producer.input[scope]):
+ self.add_input(n.name, vector_id, scope)
+ log.debug(f"{align} Adding Input: {n.name}")
+
+ if isinstance(producer.output, list):
+ vec_output = c[producer.outputname]
+ self.vec_output_mappings[producer.output_group.name] = (
+ producer.outputname
+ )
+ self.add_output(vec_output, vector_id, scope)
+ log.debug(f"{align} Adding Output: {vec_output}")
+
+ def parse_ProducerGroup(self, producer, parent, scope, align=""):
+ """Parse a ProducerGroup class into graph nodes and recursively route its producers.
+
+ Args:
+ producer (object): The ProducerGroup instance.
+ parent (str): The ID of the parent node.
+ scope (str): The active scope.
+ align (str, optional): Debug print alignment spacing. Defaults to "".
+ """
+ log.debug(f"{align}Adding ProducerGroup: {producer.name}")
+
+ group_id = f"{producer.name}_g"
+ self.add_node(
+ id_name=group_id,
+ name=producer.name,
+ scope=scope,
+ parent=parent,
+ node_type="group",
+ )
+ # Add all producers in group recursively
+ for p in producer.producers[scope]:
+ self.parse_Producer_routing(
+ producer=p, parent=group_id, scope=scope, align=align + " "
+ )
+
+ def parse_Filter(self, producer, parent, scope, align=""):
+ """Parse a Filter class into graph nodes and recursively route its components.
+
+ Args:
+ producer (object): The Filter instance.
+ parent (str): The ID of the parent node.
+ scope (str): The active scope.
+ align (str, optional): Debug print alignment spacing. Defaults to "".
+ """
+ log.debug(f"{align}Adding Filter Group: {producer.name}")
+
+ group_id = f"{producer.name}_f"
+ self.add_node(
+ id_name=group_id,
+ name=producer.name,
+ scope=scope,
+ parent=parent,
+ node_type="group",
+ )
+
+ # Add all filters in group recursively
+ for p in producer.producers[scope]:
+ self.parse_Producer_routing(
+ producer=p, parent=group_id, scope=scope, align=align + " "
+ )
+
+ def parse_BaseFilter(self, producer, parent, scope, align=""):
+ """Parse a legacy BaseFilter class into graph nodes and inputs.
+
+ Args:
+ producer (object): The legacy BaseFilter instance.
+ parent (str): The ID of the parent node.
+ scope (str): The active scope.
+ align (str, optional): Debug print alignment spacing. Defaults to "".
+ """
+ log.debug(f"{align}Adding BaseFilter: {producer.name}")
+ log.warning(
+ f"!!! {producer.name} is a legacy producer and should be replaced with Filter !!!"
+ )
+ call = producer.call
+ config_data = self.extract_configs(call, scope)
+ self.add_node(
+ id_name=producer.name,
+ scope=scope,
+ parent=parent,
+ is_filter=producer.is_filter,
+ node_call=call,
+ node_call_configs=config_data,
+ )
+ # BaseFilters don't have outputs
+ for n in set(producer.input[scope]):
+ self.add_input(n.name, producer.name, scope)
+ log.debug(f"{align} Adding Input: {n.name}")
+
+ def parse_Producer(self, producer, parent, scope, align=""):
+ """Parse a generic Producer class into graph nodes, inputs, and outputs.
+
+ Args:
+ producer (object): The Producer instance.
+ parent (str): The ID of the parent node.
+ scope (str): The active scope.
+ align (str, optional): Debug print alignment spacing. Defaults to "".
+ """
+ log.debug(f"{align}Adding Producer: {producer.name}")
+ call = producer.call
+ config_data = self.extract_configs(call, scope)
+ self.add_node(
+ id_name=producer.name,
+ scope=scope,
+ parent=parent,
+ is_filter=producer.is_filter,
+ node_call=call,
+ node_call_configs=config_data,
+ )
+
+ if isinstance(producer.input[scope], list):
+ for n in set(producer.input[scope]):
+ self.add_input(n.name, producer.name, scope)
+ log.debug(f"{align} Adding Input: {n.name}")
+
+ if isinstance(producer.output, list):
+ for n in set(producer.output):
+ self.add_output(n.name, producer.name, scope)
+ log.debug(f"{align} Adding Output: {n.name}")
+
+ def set_is_out(self, scope, req_out):
+ """Designate a node as the source of an Ntuple output.
+
+ Args:
+ scope (str): The active scope.
+ req_out (str): The name of the requested output quantity.
+
+ Raises:
+ RuntimeError: If the number of producers is invalid, the source is missing,
+ or the output is not provided by NanoAOD/Ntuple.
+ """
+ producers = self.outputs[scope].get(req_out, []) + self.outputs["global"].get(
+ req_out, []
+ )
+ if len(producers) > 0:
+ if len(producers) != 1:
+ log_and_fail(f"Num producers for out {req_out}: {len(producers)}")
+ if self.node_register.get(producers[0]):
+ self.node_register[producers[0]]["file_out"].append(req_out)
+ else:
+ log_and_fail(
+ f"Source {producers[0]} is neither part of {scope} nor global scope."
+ )
+ else:
+ if req_out in set().union(*self.external_inputs.values()):
+ if self.is_friend_config:
+ keys = [
+ k
+ for k, values in self.external_inputs.items()
+ if req_out in values
+ ]
+ if len(keys) > 1:
+ log.warning(
+ f"Input {req_out} available from multiple sources {keys}. Picking the first one {keys[0]}."
+ )
+ quantity_source = keys[0]
+ log.debug(
+ f"Requested output quantity {req_out} provided by {quantity_source} Ntuple."
+ )
+ self.direct_in_to_out[quantity_source].append(req_out)
+ else:
+ log.debug(
+ f"Requested output quantity {req_out} provided by NanoAOD."
+ )
+ self.direct_in_to_out["NanoAOD"].append(req_out)
+ else:
+ log_and_fail(
+ f"Requested output quantity {req_out} not provided by NanoAOD/Ntuple."
+ )
+
+ def assemble_connections(self):
+ """Resolve mappings to construct the actual connecting edges in the DAG.
+
+ Determines the precise source of each input requirement, whether it
+ originates from another Producer, NanoAOD, or Ntuple.
+
+ Raises:
+ RuntimeError: If an input is entirely missing from both external sources and internal producers.
+ """
+ all_external_inputs = set().union(*self.external_inputs.values())
+ for scope in self.scopes:
+ # Assemble connections by iterating through all nodes
+ # {target_node: [required_inputs]} is matched to {input: [source_nodes]}
+ # Connection is of shape {source: {required_inputs: [targets]}}
+ for target_node, required_inputs in self.inputs[scope].items():
+ compose = defaultdict(list)
+
+ # Determine where each input comes from
+ for req_input in required_inputs:
+ if self.outputs["global"].get(req_input):
+ source = self.outputs["global"][req_input][0]
+ compose[source].append(req_input)
+ elif self.outputs[scope].get(req_input):
+ source = self.outputs[scope][req_input][0]
+ compose[source].append(req_input)
+ elif req_input in all_external_inputs:
+ # CROWN friend production may only read quantities from CROWN Ntuples
+ if self.is_friend_config:
+ keys = [
+ k
+ for k, values in self.external_inputs.items()
+ if req_input in values
+ ]
+ if len(keys) > 1:
+ log.warning(
+ f"Input {req_input} available from multiple sources {keys}. Picking the first one {keys[0]}."
+ )
+ quantity_source = keys[0]
+ self.node_register[target_node]["file_in"][
+ quantity_source
+ ].append(req_input)
+ else:
+ self.node_register[target_node]["file_in"][
+ "NanoAOD"
+ ].append(req_input)
+ else:
+ log_and_fail(
+ f"Input {req_input} is missing from NanoAOD/Ntuple and producers."
+ )
+
+ # Create connections grouped by source node
+ for source_node, input_names in compose.items():
+ for input_name in input_names:
+ log.debug(
+ f"Adding Connection {input_name} from {source_node} to {target_node}"
+ )
+ self.add_connection(
+ source=source_node, target=target_node, name=input_name
+ )
+
+ def compile_shift_registry(self):
+ # Aggregate all shifts
+ # Fails if configurations would be overwritten by another scope
+ aggregated_shifts = defaultdict(defaultdict)
+ for scope in self.scopes:
+ for shift, value in self.config.shifts[scope].items():
+ # Strip "__" of shift name
+ if not shift[0:2] == "__":
+ log_and_fail(f"Shift names must start with '__' -> Shift {shift}")
+ shift = shift[2:]
+ merger = aggregated_shifts[shift]
+ conflict = merger.keys() & value.keys() and merger != value
+ if conflict:
+ log_and_fail(f"Merge conflict on keys: {conflict}")
+ merger |= value
+
+ for shift, shift_configs in aggregated_shifts.items():
+ shift_heads = set()
+ for shift_cfg in shift_configs:
+ for node, node_data in self.node_register.items():
+ if shift_cfg in node_data["node_call_configs"]:
+ shift_heads.add(node)
+ if self.shifted_inputs:
+ shifted_inputs = self.shifted_inputs[shift]
+ else:
+ shifted_inputs = []
+ self.shift_registry[shift] = {
+ "heads": list(shift_heads),
+ "shift_configs": shift_configs,
+ "shifted_inputs": shifted_inputs,
+ }
+
+ def get_downstream(self, source_node, downstream=None):
+ # Get all downstream nodes
+ if downstream is None:
+ downstream = {"edges": set(), "nodes": set()}
+ for edge, nodes in self.connections[source_node].items():
+ downstream["edges"].add(f"edge_{edge}")
+ for node in nodes:
+ if node not in downstream["nodes"]:
+ downstream["nodes"].add(node)
+ ds_dat = self.get_downstream(node, downstream)
+ downstream["edges"].update(ds_dat["edges"])
+ downstream["nodes"].update(ds_dat["nodes"])
+
+ return downstream
+
+ def extract_configs(self, call, scope, vector_configs=None, ignore=None):
+ """Parse out explicit configuration string replacements from a call parameter.
+
+ Args:
+ call (str): The raw call string containing formatting templates.
+ scope (str): The active scope context.
+ vector_configs (dict, optional): Specific vector configurations to prefer. Defaults to None.
+ ignore (set/list, optional): Strings/keys to ignore when extracting. Defaults to None.
+
+ Returns:
+ dict: The mapped configurations dict resolving template markers to true values.
+
+ Raises:
+ RuntimeError: If an unknown configuration parameter is encountered.
+ """
+ # Ignore parameters that are not config parameters
+ if is_empty(ignore):
+ ignore = {
+ "df",
+ "output",
+ "input",
+ "output_vec",
+ "input_vec",
+ "vec_open",
+ "vec_close",
+ }
+ else:
+ ignore = set(ignore) # type: ignore
+
+ pattern = r"\{(\w+)\}"
+ matches = re.findall(pattern, call)
+
+ config_parameters = [m for m in matches if m not in ignore]
+
+ # Get config parameter values from vector or general configs
+ config_dict = {}
+ for c in config_parameters:
+ if vector_configs != None and not is_empty(vector_configs.get(c)):
+ config_dict[c] = vector_configs[c]
+ elif not is_empty(self.config.config_parameters[scope].get(c)):
+ config_dict[c] = self.config.config_parameters[scope][c]
+ else:
+ log_and_fail(f"Unknown config parameter {c}")
+ return config_dict
+
+ def add_output(self, output_name, output_node, scope):
+ """Register a node internally as the generator of a specific output.
+
+ Args:
+ output_name (str): The name of the output being produced.
+ output_node (str): The ID of the node producing it.
+ scope (str): The active scope context.
+
+ Raises:
+ RuntimeError: If the node is already registered for this specific output.
+ """
+ node_id = f"{scope}_{output_node}"
+ if node_id in self.outputs[scope][output_name]:
+ log_and_fail(f"Node {node_id} already exists in output nodes.")
+ self.outputs[scope][output_name].append(node_id)
+
+ def add_input(self, input_name, input_node, scope):
+ """Register a node internally as requiring a specific input dependency.
+
+ Args:
+ input_name (str): The name of the required input.
+ input_node (str): The ID of the node needing it.
+ scope (str): The active scope context.
+
+ Raises:
+ RuntimeError: If the input is already registered for this node.
+ """
+ scoped_input = f"{scope}_{input_node}"
+ if input_name in self.inputs[scope][scoped_input]:
+ log_and_fail(f"Input {input_name} already exists in inputs.")
+ self.inputs[scope][scoped_input].append(input_name)
+
+ def add_node(
+ self,
+ id_name,
+ name=None,
+ scope=None,
+ parent=None,
+ node_type=None,
+ is_filter=False,
+ family=None,
+ node_call=None,
+ node_call_configs=None,
+ ):
+ """Create a standardized node object and append it to the graph's internal list.
+
+ Also manages appending metadata to the node and edge family registers.
+
+ Args:
+ id_name (str): The base string ID of the node.
+ name (str, optional): Visual label for the node. Defaults to id_name.
+ scope (str, optional): The active scope context. Defaults to None.
+ parent (str, optional): The ID of the parent node. Defaults to None.
+ node_type (str, optional): Specific class or type of the node. Defaults to None.
+ is_filter (bool, optional): Indicates if the node acts as a filter. Defaults to False.
+ family (str, optional): Ties the node to a grouping/proxy family. Defaults to None.
+ node_call_configs (dict, optional): Extracted config values for the node's call string. Defaults to None.
+
+ Raises:
+ RuntimeError: If conflicting args (node_type and is_filter) are supplied,
+ or if a full ID already exists within the standard family register.
+ """
+ if node_type and is_filter:
+ log_and_fail("Cannot specify both node_type and is_filter")
+ if is_filter:
+ node_type = "filter"
+ if not name:
+ name = id_name
+ if scope:
+ id_name = f"{scope}_{id_name}"
+ if parent:
+ if scope:
+ parent = f"{scope}_{parent}"
+ if self.node_register.get(id_name):
+ log_and_fail(f"Node {id_name} already exists in node register")
+ self.node_register[id_name]["name"] = name
+ if parent:
+ self.node_register[id_name]["parent"] = parent
+ if node_type:
+ self.node_register[id_name]["type"] = node_type
+ if node_call:
+ self.node_register[id_name]["node_call"] = node_call
+ self.node_register[id_name]["node_call_configs"] = node_call_configs
+
+ def add_connection(self, source, target, name):
+ """Log a relational connection locally between a source and a target.
+
+ Args:
+ source (str): ID of the source node providing data.
+ target (str): ID of the target node consuming data.
+ name (str): Connection quantity label.
+ """
+ self.connections[source][name].append(target)
+
+ def save_graph(self, name):
+ """Compile the DAG data structures, inject metadata, and export to a JSON file.
+
+ Also triggers an automatic update to the overarching DAG file tracker.
+
+ Args:
+ name (str): The base filename used to construct the final exported JSON path.
+ """
+ path = f"{name}_{self.config.era}_{self.config.sample}_{self.active_scope}.json"
+ if self.DAG_dir:
+ path = Path(self.DAG_dir) / path
+ Path(self.DAG_dir).mkdir(parents=True, exist_ok=True)
+
+ # Compile DAG data with metadata
+ full_data = {
+ "nodeRegister": self.node_register,
+ "edgeRegister": self.connections,
+ "directInputOutput": self.direct_in_to_out,
+ "shiftRegistry": self.shift_registry,
+ }
+
+ # Copy visualization file to build dir
+ script_current_dir = Path(__file__).parent.resolve()
+ source = script_current_dir / "CROWN_visualization.html"
+ target = Path(self.DAG_dir) / "index.html"
+ if not target.exists():
+ shutil.copy(source, target)
+
+ # Write DAG data to json
+ with open(path, "w") as f:
+ json.dump(full_data, f, indent=4)
+ log.info(
+ f"Generated DAG file for {self.config.era}/{self.config.sample}/{self.active_scope}: {path}"
+ )
+
+ # Update master DAG file list
+ self.update_DAG_file_list(
+ Path(self.DAG_dir) / "DAG_files.json",
+ self.config.era,
+ self.config.sample,
+ self.active_scope,
+ )
+
+ def update_DAG_file_list(self, config_path, new_era, new_sample, new_scope):
+ """Maintain and append to the master JSON manifest tracking generated DAG elements.
+
+ Prevents duplicates while verifying the registration of distinct eras, samples, and scopes.
+
+ Args:
+ config_path (str): Filepath pointing to the master 'DAG_files.json'.
+ new_era (str/int): The era identifier to check/add.
+ new_sample (str): The sample identifier to check/add.
+ new_scope (str): The scope string to check/add.
+ """
+ if Path(config_path).exists():
+ with open(config_path, "r") as f:
+ try:
+ data = json.load(f)
+ except json.JSONDecodeError:
+ data = {"era": [], "sample": [], "scope": []}
+ else:
+ data = {"era": [], "sample": [], "scope": []}
+
+ # Update master DAG file data
+ data["era"] = sorted(list(set(data["era"] + [str(new_era)])))
+ data["sample"] = sorted(list(set(data["sample"] + [str(new_sample)])))
+ data["scope"] = sorted(list(set(data["scope"] + [str(new_scope)])))
+
+ # Write updated master DAG file data back
+ with open(config_path, "w") as f:
+ json.dump(data, f, indent=4)
+
+ log.debug(f"Updated {config_path} with {new_era}/{new_sample}/{new_scope}")